diff options
-rw-r--r-- | astroid/modutils.py | 61 | ||||
-rw-r--r-- | astroid/tests/unittest_modutils.py | 22 |
2 files changed, 71 insertions, 12 deletions
diff --git a/astroid/modutils.py b/astroid/modutils.py index 83d0ba74..4d8dfec0 100644 --- a/astroid/modutils.py +++ b/astroid/modutils.py @@ -285,26 +285,63 @@ def check_modpath_has_init(path, mod_path): return True +def _get_relative_base_path(filename, path_to_check): + """Extracts the relative mod path of the file to import from + + Check if a file is within the passed in path and if so, returns the + relative mod path from the one passed in. + + If the filename is no in path_to_check, returns None + + Note this function will look for both abs and realpath of the file, + this allows to find the relative base path even if the file is a + symlink of a file in the passed in path + + Examples: + _get_relative_base_path("/a/b/c/d.py", "/a/b") -> ["c","d"] + _get_relative_base_path("/a/b/c/d.py", "/dev") -> None + """ + importable_path = None + path_to_check = os.path.normcase(path_to_check) + abs_filename = os.path.abspath(filename) + if os.path.normcase(abs_filename).startswith(path_to_check): + importable_path = abs_filename + + real_filename = os.path.realpath(filename) + if os.path.normcase(real_filename).startswith(path_to_check): + importable_path = real_filename + + if importable_path: + base_path = os.path.splitext(importable_path)[0] + relative_base_path = base_path[len(path_to_check):] + return [pkg for pkg in relative_base_path.split(os.sep) if pkg] + + return None + + def modpath_from_file_with_callback(filename, extrapath=None, is_package_cb=None): - filename = _path_from_filename(filename) - filename = os.path.realpath(os.path.expanduser(filename)) - base = os.path.splitext(filename)[0] + filename = os.path.expanduser(_path_from_filename(filename)) if extrapath is not None: for path_ in six.moves.map(_canonicalize_path, extrapath): path = os.path.abspath(path_) - if path and os.path.normcase(base[:len(path)]) == os.path.normcase(path): - submodpath = [pkg for pkg in base[len(path):].split(os.sep) - if pkg] - if is_package_cb(path, submodpath[:-1]): - return extrapath[path_].split('.') + submodpath + if not path: + continue + submodpath = _get_relative_base_path(filename, path) + if not submodpath: + continue + if is_package_cb(path, submodpath[:-1]): + return extrapath[path_].split('.') + submodpath for path in six.moves.map(_canonicalize_path, sys.path): path = _cache_normalize_path(path) - if path and os.path.normcase(base).startswith(path): - modpath = [pkg for pkg in base[len(path):].split(os.sep) if pkg] - if is_package_cb(path, modpath[:-1]): - return modpath + if not path: + continue + modpath = _get_relative_base_path(filename, path) + if not modpath: + continue + if is_package_cb(path, modpath[:-1]): + return modpath raise ImportError('Unable to find module for %s in %s' % ( filename, ', \n'.join(sys.path))) diff --git a/astroid/tests/unittest_modutils.py b/astroid/tests/unittest_modutils.py index ef3c460c..01a4d1fb 100644 --- a/astroid/tests/unittest_modutils.py +++ b/astroid/tests/unittest_modutils.py @@ -15,6 +15,7 @@ import os import sys import unittest from xml import etree +import tempfile import astroid from astroid.interpreter._import import spec @@ -110,6 +111,27 @@ class ModPathFromFileTest(unittest.TestCase): def test_raise_modpath_from_file_Exception(self): self.assertRaises(Exception, modutils.modpath_from_file, '/turlututu') + def test_import_symlink_with_source_outside_of_path(self): + with tempfile.NamedTemporaryFile() as tmpfile: + linked_file_name = 'symlinked_file.py' + try: + os.symlink(tmpfile.name, linked_file_name) + self.assertEqual(modutils.modpath_from_file(linked_file_name), + ['symlinked_file']) + finally: + os.remove(linked_file_name) + + def test_import_symlink_both_outside_of_path(self): + with tempfile.NamedTemporaryFile() as tmpfile: + linked_file_name = os.path.join(tempfile.gettempdir(), + 'symlinked_file.py') + try: + os.symlink(tmpfile.name, linked_file_name) + self.assertRaises(ImportError, + modutils.modpath_from_file, linked_file_name) + finally: + os.remove(linked_file_name) + class LoadModuleFromPathTest(resources.SysPathSetup, unittest.TestCase): |