summaryrefslogtreecommitdiff
path: root/astroid
diff options
context:
space:
mode:
authorMark Byrne <31762852+mbyrnepr2@users.noreply.github.com>2023-05-16 10:30:12 +0200
committerGitHub <noreply@github.com>2023-05-16 10:30:12 +0200
commit12c3d1556be3bacaf4816898616b6124e0d44da9 (patch)
tree0155475c0086e78480d38709a06796be7bd93151 /astroid
parent1177accd39acffe1288d0f8fa08dbea0454035ba (diff)
downloadastroid-git-12c3d1556be3bacaf4816898616b6124e0d44da9.tar.gz
Recognize stub ``pyi`` Python files. (#2182)HEADmain
Recognize stub ``pyi`` Python files. Refs pylint-dev/pylint#4987 Co-authored-by: Jacob Walls <jacobtylerwalls@gmail.com>
Diffstat (limited to 'astroid')
-rw-r--r--astroid/interpreter/_import/spec.py2
-rw-r--r--astroid/modutils.py13
2 files changed, 7 insertions, 8 deletions
diff --git a/astroid/interpreter/_import/spec.py b/astroid/interpreter/_import/spec.py
index b1f8e8db..3c21fd73 100644
--- a/astroid/interpreter/_import/spec.py
+++ b/astroid/interpreter/_import/spec.py
@@ -163,7 +163,7 @@ class ImportlibFinder(Finder):
for entry in submodule_path:
package_directory = os.path.join(entry, modname)
- for suffix in (".py", importlib.machinery.BYTECODE_SUFFIXES[0]):
+ for suffix in (".py", ".pyi", importlib.machinery.BYTECODE_SUFFIXES[0]):
package_file_name = "__init__" + suffix
file_path = os.path.join(package_directory, package_file_name)
if os.path.isfile(file_path):
diff --git a/astroid/modutils.py b/astroid/modutils.py
index b4f3b6e3..33fd3eeb 100644
--- a/astroid/modutils.py
+++ b/astroid/modutils.py
@@ -44,10 +44,10 @@ logger = logging.getLogger(__name__)
if sys.platform.startswith("win"):
- PY_SOURCE_EXTS = ("py", "pyw")
+ PY_SOURCE_EXTS = ("py", "pyw", "pyi")
PY_COMPILED_EXTS = ("dll", "pyd")
else:
- PY_SOURCE_EXTS = ("py",)
+ PY_SOURCE_EXTS = ("py", "pyi")
PY_COMPILED_EXTS = ("so",)
@@ -274,9 +274,6 @@ def _get_relative_base_path(filename: str, path_to_check: str) -> list[str] | No
if os.path.normcase(real_filename).startswith(path_to_check):
importable_path = real_filename
- # if "var" in path_to_check:
- # breakpoint()
-
if importable_path:
base_path = os.path.splitext(importable_path)[0]
relative_base_path = base_path[len(path_to_check) :]
@@ -476,7 +473,7 @@ def get_module_files(
continue
_handle_blacklist(blacklist, dirnames, filenames)
# check for __init__.py
- if not list_all and "__init__.py" not in filenames:
+ if not list_all and {"__init__.py", "__init__.pyi"}.isdisjoint(filenames):
dirnames[:] = ()
continue
for filename in filenames:
@@ -499,6 +496,8 @@ def get_source_file(filename: str, include_no_ext: bool = False) -> str:
"""
filename = os.path.abspath(_path_from_filename(filename))
base, orig_ext = os.path.splitext(filename)
+ if orig_ext == ".pyi" and os.path.exists(f"{base}{orig_ext}"):
+ return f"{base}{orig_ext}"
for ext in PY_SOURCE_EXTS:
source_path = f"{base}.{ext}"
if os.path.exists(source_path):
@@ -663,7 +662,7 @@ def _is_python_file(filename: str) -> bool:
.pyc and .pyo are ignored
"""
- return filename.endswith((".py", ".so", ".pyd", ".pyw"))
+ return filename.endswith((".py", ".pyi", ".so", ".pyd", ".pyw"))
def _has_init(directory: str) -> str | None: