diff options
-rw-r--r-- | alembic/context.pyi | 9 | ||||
-rw-r--r-- | alembic/runtime/environment.py | 14 | ||||
-rw-r--r-- | alembic/util/compat.py | 1 | ||||
-rw-r--r-- | docs/build/unreleased/1147.rst | 8 | ||||
-rw-r--r-- | tests/requirements.py | 2 | ||||
-rw-r--r-- | tools/write_pyi.py | 63 |
6 files changed, 79 insertions, 18 deletions
diff --git a/alembic/context.pyi b/alembic/context.pyi index 9871fad..86345c4 100644 --- a/alembic/context.pyi +++ b/alembic/context.pyi @@ -7,7 +7,9 @@ from typing import Callable from typing import ContextManager from typing import Dict from typing import List +from typing import Literal from typing import Optional +from typing import overload from typing import TextIO from typing import Tuple from typing import TYPE_CHECKING @@ -644,8 +646,13 @@ def get_tag_argument() -> Optional[str]: """ +@overload +def get_x_argument(as_dictionary: Literal[False]) -> List[str]: ... +@overload +def get_x_argument(as_dictionary: Literal[True]) -> Dict[str, str]: ... +@overload def get_x_argument( - as_dictionary: bool = False, + as_dictionary: bool = ..., ) -> Union[List[str], Dict[str, str]]: """Return the value(s) passed for the ``-x`` argument, if any. diff --git a/alembic/runtime/environment.py b/alembic/runtime/environment.py index 44dcd72..a441d1f 100644 --- a/alembic/runtime/environment.py +++ b/alembic/runtime/environment.py @@ -269,15 +269,17 @@ class EnvironmentContext(util.ModuleClsProxy): return self.context_opts.get("tag", None) @overload - def get_x_argument( # type:ignore[misc] - self, as_dictionary: Literal[False] = ... - ) -> List[str]: + def get_x_argument(self, as_dictionary: Literal[False]) -> List[str]: ... @overload - def get_x_argument( # type:ignore[misc] - self, as_dictionary: Literal[True] = ... - ) -> Dict[str, str]: + def get_x_argument(self, as_dictionary: Literal[True]) -> Dict[str, str]: + ... + + @overload + def get_x_argument( + self, as_dictionary: bool = ... + ) -> Union[List[str], Dict[str, str]]: ... def get_x_argument( diff --git a/alembic/util/compat.py b/alembic/util/compat.py index 289aaa2..2fe4957 100644 --- a/alembic/util/compat.py +++ b/alembic/util/compat.py @@ -10,6 +10,7 @@ from sqlalchemy.util.compat import inspect_formatargspec # noqa is_posix = os.name == "posix" +py311 = sys.version_info >= (3, 11) py39 = sys.version_info >= (3, 9) py38 = sys.version_info >= (3, 8) diff --git a/docs/build/unreleased/1147.rst b/docs/build/unreleased/1147.rst new file mode 100644 index 0000000..5f6f7de --- /dev/null +++ b/docs/build/unreleased/1147.rst @@ -0,0 +1,8 @@ +.. change:: + :tags: bug, typing + :tickets: 1146, 1147 + + Fixed typing definitions for :meth:`.EnvironmentContext.get_x_argument`. + + Typing stubs are now generated for overloaded proxied methods such as + :meth:`.EnvironmentContext.get_x_argument`.
\ No newline at end of file diff --git a/tests/requirements.py b/tests/requirements.py index aa88f66..c774e67 100644 --- a/tests/requirements.py +++ b/tests/requirements.py @@ -402,7 +402,7 @@ class DefaultRequirements(SuiteRequirements): requirements, "black and zimports are required for this test" ) version = exclusions.only_if( - lambda _: compat.py39, "python 3.9 is required" + lambda _: compat.py311, "python 3.11 is required" ) sqlalchemy = exclusions.only_if( diff --git a/tools/write_pyi.py b/tools/write_pyi.py index e5112fd..e3feb36 100644 --- a/tools/write_pyi.py +++ b/tools/write_pyi.py @@ -52,8 +52,10 @@ def generate_pyi_for_proxy( ): ignore_items = IGNORE_ITEMS.get(file_key, set()) context_managers = CONTEXT_MANAGERS.get(file_key, []) - if sys.version_info < (3, 9): - raise RuntimeError("This script must be run with Python 3.9 or higher") + if sys.version_info < (3, 11): + raise RuntimeError( + "This script must be run with Python 3.11 or higher" + ) # When using an absolute path on windows, this will generate the correct # relative path that shall be written to the top comment of the pyi file. @@ -99,9 +101,30 @@ def generate_pyi_for_proxy( continue meth = getattr(cls, name, None) if callable(meth): - _generate_stub_for_meth( - cls, name, printer, env, name in context_managers - ) + # If there are overloads, generate only those + # Do not generate the base implementation to avoid mypy errors + overloads = typing.get_overloads(meth) + if overloads: + # use enumerate so we can generate docs on the last overload + for i, ovl in enumerate(overloads, 1): + _generate_stub_for_meth( + ovl, + cls, + printer, + env, + is_context_manager=name in context_managers, + is_overload=True, + base_method=meth, + gen_docs=(i == len(overloads)), + ) + else: + _generate_stub_for_meth( + meth, + cls, + printer, + env, + is_context_manager=name in context_managers, + ) else: _generate_stub_for_attr(cls, name, printer, env) @@ -133,12 +156,20 @@ def _generate_stub_for_attr(cls, name, printer, env): printer.writeline(f"{name}: {type_}") -def _generate_stub_for_meth(cls, name, printer, env, is_context_manager): - - fn = getattr(cls, name) +def _generate_stub_for_meth( + fn, + cls, + printer, + env, + is_context_manager, + is_overload=False, + base_method=None, + gen_docs=True, +): while hasattr(fn, "__wrapped__"): fn = fn.__wrapped__ + name = fn.__name__ spec = inspect_getfullargspec(fn) try: annotations = typing.get_type_hints(fn, env) @@ -168,17 +199,29 @@ def _generate_stub_for_meth(cls, name, printer, env, is_context_manager): retval = re.sub("NoneType", "None", retval) return retval + def _formatvalue(value): + return "=" + ("..." if value is Ellipsis else repr(value)) + argspec = inspect_formatargspec( *spec, formatannotation=_formatannotation, + formatvalue=_formatvalue, formatreturns=lambda val: f"-> {_formatannotation(val)}", ) + + overload = "@overload" if is_overload else "" contextmanager = "@contextmanager" if is_context_manager else "" + + fn_doc = base_method.__doc__ if base_method else fn.__doc__ + has_docs = gen_docs and fn_doc is not None + docs = '"""' + f"{fn_doc}" + '"""' if has_docs else "" + func_text = textwrap.dedent( f""" + {overload} {contextmanager} - def {name}{argspec}: - '''{fn.__doc__}''' + def {name}{argspec}: {"..." if not docs else ""} + {docs} """ ) |