summaryrefslogtreecommitdiff
path: root/astroid/modutils.py
diff options
context:
space:
mode:
Diffstat (limited to 'astroid/modutils.py')
-rw-r--r--astroid/modutils.py61
1 files changed, 49 insertions, 12 deletions
diff --git a/astroid/modutils.py b/astroid/modutils.py
index 83d0ba74..4d8dfec0 100644
--- a/astroid/modutils.py
+++ b/astroid/modutils.py
@@ -285,26 +285,63 @@ def check_modpath_has_init(path, mod_path):
return True
+def _get_relative_base_path(filename, path_to_check):
+ """Extracts the relative mod path of the file to import from
+
+ Check if a file is within the passed in path and if so, returns the
+ relative mod path from the one passed in.
+
+ If the filename is no in path_to_check, returns None
+
+ Note this function will look for both abs and realpath of the file,
+ this allows to find the relative base path even if the file is a
+ symlink of a file in the passed in path
+
+ Examples:
+ _get_relative_base_path("/a/b/c/d.py", "/a/b") -> ["c","d"]
+ _get_relative_base_path("/a/b/c/d.py", "/dev") -> None
+ """
+ importable_path = None
+ path_to_check = os.path.normcase(path_to_check)
+ abs_filename = os.path.abspath(filename)
+ if os.path.normcase(abs_filename).startswith(path_to_check):
+ importable_path = abs_filename
+
+ real_filename = os.path.realpath(filename)
+ if os.path.normcase(real_filename).startswith(path_to_check):
+ importable_path = real_filename
+
+ if importable_path:
+ base_path = os.path.splitext(importable_path)[0]
+ relative_base_path = base_path[len(path_to_check):]
+ return [pkg for pkg in relative_base_path.split(os.sep) if pkg]
+
+ return None
+
+
def modpath_from_file_with_callback(filename, extrapath=None, is_package_cb=None):
- filename = _path_from_filename(filename)
- filename = os.path.realpath(os.path.expanduser(filename))
- base = os.path.splitext(filename)[0]
+ filename = os.path.expanduser(_path_from_filename(filename))
if extrapath is not None:
for path_ in six.moves.map(_canonicalize_path, extrapath):
path = os.path.abspath(path_)
- if path and os.path.normcase(base[:len(path)]) == os.path.normcase(path):
- submodpath = [pkg for pkg in base[len(path):].split(os.sep)
- if pkg]
- if is_package_cb(path, submodpath[:-1]):
- return extrapath[path_].split('.') + submodpath
+ if not path:
+ continue
+ submodpath = _get_relative_base_path(filename, path)
+ if not submodpath:
+ continue
+ if is_package_cb(path, submodpath[:-1]):
+ return extrapath[path_].split('.') + submodpath
for path in six.moves.map(_canonicalize_path, sys.path):
path = _cache_normalize_path(path)
- if path and os.path.normcase(base).startswith(path):
- modpath = [pkg for pkg in base[len(path):].split(os.sep) if pkg]
- if is_package_cb(path, modpath[:-1]):
- return modpath
+ if not path:
+ continue
+ modpath = _get_relative_base_path(filename, path)
+ if not modpath:
+ continue
+ if is_package_cb(path, modpath[:-1]):
+ return modpath
raise ImportError('Unable to find module for %s in %s' % (
filename, ', \n'.join(sys.path)))