diff options
author | Stefan Behnel <stefan_ml@behnel.de> | 2021-04-18 16:49:00 +0200 |
---|---|---|
committer | Stefan Behnel <stefan_ml@behnel.de> | 2021-04-18 16:49:00 +0200 |
commit | f3f7b612dc005abdac2e0a0a48dcf9be7b4b0122 (patch) | |
tree | 82ccd43a826e9bee220bc362468402cd6fe1f16c | |
parent | f79916155905490e59b3aeac635aefdacbbbc994 (diff) | |
download | cython-f3f7b612dc005abdac2e0a0a48dcf9be7b4b0122.tar.gz |
Allow searching for include/import files without passing a source pos tuple and clean up the call chains a little.
-rw-r--r-- | Cython/Compiler/Main.py | 47 | ||||
-rw-r--r-- | Cython/Compiler/ModuleNode.py | 2 |
2 files changed, 26 insertions, 23 deletions
diff --git a/Cython/Compiler/Main.py b/Cython/Compiler/Main.py index 4ff5a6420..31c0e1747 100644 --- a/Cython/Compiler/Main.py +++ b/Cython/Compiler/Main.py @@ -209,7 +209,8 @@ class Context(object): # Set pxd_file_loaded such that we don't need to # look for the non-existing pxd file next time. scope.pxd_file_loaded = True - package_pathname = self.search_include_directories(qualified_name, ".py", pos) + package_pathname = self.search_include_directories( + qualified_name, suffix=".py", source_pos=pos) if package_pathname and package_pathname.endswith(Utils.PACKAGE_FILES): pass else: @@ -232,7 +233,7 @@ class Context(object): pass return scope - def find_pxd_file(self, qualified_name, pos, sys_path=True): + def find_pxd_file(self, qualified_name, pos=None, sys_path=True, source_file_path=None): # Search include path (and sys.path if sys_path is True) for # the .pxd file corresponding to the given fully-qualified # module name. @@ -241,34 +242,36 @@ class Context(object): # the directory containing the source file is searched first # for a dotted filename, and its containing package root # directory is searched first for a non-dotted filename. - pxd = self.search_include_directories(qualified_name, ".pxd", pos, sys_path=sys_path) + pxd = self.search_include_directories( + qualified_name, suffix=".pxd", source_pos=pos, sys_path=sys_path, source_file_path=source_file_path) if pxd is None and Options.cimport_from_pyx: return self.find_pyx_file(qualified_name, pos) return pxd - def find_pyx_file(self, qualified_name, pos): + def find_pyx_file(self, qualified_name, pos=None, source_file_path=None): # Search include path for the .pyx file corresponding to the # given fully-qualified module name, as for find_pxd_file(). - return self.search_include_directories(qualified_name, ".pyx", pos) + return self.search_include_directories( + qualified_name, suffix=".pyx", source_pos=pos, source_file_path=source_file_path) - def find_include_file(self, filename, pos): + def find_include_file(self, filename, pos=None, source_file_path=None): # Search list of include directories for filename. # Reports an error and returns None if not found. - path = self.search_include_directories(filename, "", pos, - include=True) + path = self.search_include_directories( + filename, source_pos=pos, include=True, source_file_path=source_file_path) if not path: error(pos, "'%s' not found" % filename) return path - def search_include_directories(self, qualified_name, suffix, pos, - include=False, sys_path=False): + def search_include_directories(self, qualified_name, + suffix=None, source_pos=None, include=False, sys_path=False, source_file_path=None): include_dirs = self.include_directories if sys_path: include_dirs = include_dirs + sys.path # include_dirs must be hashable for caching in @cached_function include_dirs = tuple(include_dirs + [standard_include_path]) - return search_include_directories(include_dirs, qualified_name, - suffix, pos, include) + return search_include_directories( + include_dirs, qualified_name, suffix or "", source_pos, include, source_file_path) def find_root_package_dir(self, file_path): return Utils.find_root_package_dir(file_path) @@ -282,15 +285,14 @@ class Context(object): c_time = Utils.modification_time(output_path) if Utils.file_newer_than(source_path, c_time): return 1 - pos = [source_path] pxd_path = Utils.replace_suffix(source_path, ".pxd") if os.path.exists(pxd_path) and Utils.file_newer_than(pxd_path, c_time): return 1 for kind, name in self.read_dependency_file(source_path): if kind == "cimport": - dep_path = self.find_pxd_file(name, pos) + dep_path = self.find_pxd_file(name, source_file_path=source_path) elif kind == "include": - dep_path = self.search_include_directories(name, "", pos) + dep_path = self.search_include_directories(name, source_file_path=source_path) else: continue if dep_path and Utils.file_newer_than(dep_path, c_time): @@ -629,24 +631,25 @@ def compile(source, options = None, full_module_name = None, **kwds): @Utils.cached_function -def search_include_directories(dirs, qualified_name, suffix, pos, include=False): +def search_include_directories(dirs, qualified_name, suffix="", pos=None, include=False, source_file_path=None): """ Search the list of include directories for the given file name. - If a source file position is given, first searches the directory - containing that file. Returns None if not found, but does not - report an error. + If a source file path or position is given, first searches the directory + containing that file. Returns None if not found, but does not report an error. The 'include' option will disable package dereferencing. """ - if pos: + if pos and not source_file_path: file_desc = pos[0] if not isinstance(file_desc, FileSourceDescriptor): raise RuntimeError("Only file sources for code supported") + source_file_path = file_desc.filename + if source_file_path: if include: - dirs = (os.path.dirname(file_desc.filename),) + dirs + dirs = (os.path.dirname(source_file_path),) + dirs else: - dirs = (Utils.find_root_package_dir(file_desc.filename),) + dirs + dirs = (Utils.find_root_package_dir(source_file_path),) + dirs # search for dotted filename e.g. <dir>/foo.bar.pxd dotted_filename = qualified_name diff --git a/Cython/Compiler/ModuleNode.py b/Cython/Compiler/ModuleNode.py index 77a92368c..e301c715b 100644 --- a/Cython/Compiler/ModuleNode.py +++ b/Cython/Compiler/ModuleNode.py @@ -520,7 +520,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): if not target_file_dir.startswith(target_dir): # any other directories may not be writable => avoid trying continue - source_file = search_include_file(included_file, "", self.pos, include=True) + source_file = search_include_file(included_file, source_pos=self.pos, include=True) if not source_file: continue if target_file_dir != target_dir and not os.path.exists(target_file_dir): |