summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--alembic/context.pyi9
-rw-r--r--alembic/runtime/environment.py14
-rw-r--r--alembic/util/compat.py1
-rw-r--r--docs/build/unreleased/1147.rst8
-rw-r--r--tests/requirements.py2
-rw-r--r--tools/write_pyi.py63
6 files changed, 79 insertions, 18 deletions
diff --git a/alembic/context.pyi b/alembic/context.pyi
index 9871fad..86345c4 100644
--- a/alembic/context.pyi
+++ b/alembic/context.pyi
@@ -7,7 +7,9 @@ from typing import Callable
from typing import ContextManager
from typing import Dict
from typing import List
+from typing import Literal
from typing import Optional
+from typing import overload
from typing import TextIO
from typing import Tuple
from typing import TYPE_CHECKING
@@ -644,8 +646,13 @@ def get_tag_argument() -> Optional[str]:
"""
+@overload
+def get_x_argument(as_dictionary: Literal[False]) -> List[str]: ...
+@overload
+def get_x_argument(as_dictionary: Literal[True]) -> Dict[str, str]: ...
+@overload
def get_x_argument(
- as_dictionary: bool = False,
+ as_dictionary: bool = ...,
) -> Union[List[str], Dict[str, str]]:
"""Return the value(s) passed for the ``-x`` argument, if any.
diff --git a/alembic/runtime/environment.py b/alembic/runtime/environment.py
index 44dcd72..a441d1f 100644
--- a/alembic/runtime/environment.py
+++ b/alembic/runtime/environment.py
@@ -269,15 +269,17 @@ class EnvironmentContext(util.ModuleClsProxy):
return self.context_opts.get("tag", None)
@overload
- def get_x_argument( # type:ignore[misc]
- self, as_dictionary: Literal[False] = ...
- ) -> List[str]:
+ def get_x_argument(self, as_dictionary: Literal[False]) -> List[str]:
...
@overload
- def get_x_argument( # type:ignore[misc]
- self, as_dictionary: Literal[True] = ...
- ) -> Dict[str, str]:
+ def get_x_argument(self, as_dictionary: Literal[True]) -> Dict[str, str]:
+ ...
+
+ @overload
+ def get_x_argument(
+ self, as_dictionary: bool = ...
+ ) -> Union[List[str], Dict[str, str]]:
...
def get_x_argument(
diff --git a/alembic/util/compat.py b/alembic/util/compat.py
index 289aaa2..2fe4957 100644
--- a/alembic/util/compat.py
+++ b/alembic/util/compat.py
@@ -10,6 +10,7 @@ from sqlalchemy.util.compat import inspect_formatargspec # noqa
is_posix = os.name == "posix"
+py311 = sys.version_info >= (3, 11)
py39 = sys.version_info >= (3, 9)
py38 = sys.version_info >= (3, 8)
diff --git a/docs/build/unreleased/1147.rst b/docs/build/unreleased/1147.rst
new file mode 100644
index 0000000..5f6f7de
--- /dev/null
+++ b/docs/build/unreleased/1147.rst
@@ -0,0 +1,8 @@
+.. change::
+ :tags: bug, typing
+ :tickets: 1146, 1147
+
+ Fixed typing definitions for :meth:`.EnvironmentContext.get_x_argument`.
+
+ Typing stubs are now generated for overloaded proxied methods such as
+ :meth:`.EnvironmentContext.get_x_argument`. \ No newline at end of file
diff --git a/tests/requirements.py b/tests/requirements.py
index aa88f66..c774e67 100644
--- a/tests/requirements.py
+++ b/tests/requirements.py
@@ -402,7 +402,7 @@ class DefaultRequirements(SuiteRequirements):
requirements, "black and zimports are required for this test"
)
version = exclusions.only_if(
- lambda _: compat.py39, "python 3.9 is required"
+ lambda _: compat.py311, "python 3.11 is required"
)
sqlalchemy = exclusions.only_if(
diff --git a/tools/write_pyi.py b/tools/write_pyi.py
index e5112fd..e3feb36 100644
--- a/tools/write_pyi.py
+++ b/tools/write_pyi.py
@@ -52,8 +52,10 @@ def generate_pyi_for_proxy(
):
ignore_items = IGNORE_ITEMS.get(file_key, set())
context_managers = CONTEXT_MANAGERS.get(file_key, [])
- if sys.version_info < (3, 9):
- raise RuntimeError("This script must be run with Python 3.9 or higher")
+ if sys.version_info < (3, 11):
+ raise RuntimeError(
+ "This script must be run with Python 3.11 or higher"
+ )
# When using an absolute path on windows, this will generate the correct
# relative path that shall be written to the top comment of the pyi file.
@@ -99,9 +101,30 @@ def generate_pyi_for_proxy(
continue
meth = getattr(cls, name, None)
if callable(meth):
- _generate_stub_for_meth(
- cls, name, printer, env, name in context_managers
- )
+ # If there are overloads, generate only those
+ # Do not generate the base implementation to avoid mypy errors
+ overloads = typing.get_overloads(meth)
+ if overloads:
+ # use enumerate so we can generate docs on the last overload
+ for i, ovl in enumerate(overloads, 1):
+ _generate_stub_for_meth(
+ ovl,
+ cls,
+ printer,
+ env,
+ is_context_manager=name in context_managers,
+ is_overload=True,
+ base_method=meth,
+ gen_docs=(i == len(overloads)),
+ )
+ else:
+ _generate_stub_for_meth(
+ meth,
+ cls,
+ printer,
+ env,
+ is_context_manager=name in context_managers,
+ )
else:
_generate_stub_for_attr(cls, name, printer, env)
@@ -133,12 +156,20 @@ def _generate_stub_for_attr(cls, name, printer, env):
printer.writeline(f"{name}: {type_}")
-def _generate_stub_for_meth(cls, name, printer, env, is_context_manager):
-
- fn = getattr(cls, name)
+def _generate_stub_for_meth(
+ fn,
+ cls,
+ printer,
+ env,
+ is_context_manager,
+ is_overload=False,
+ base_method=None,
+ gen_docs=True,
+):
while hasattr(fn, "__wrapped__"):
fn = fn.__wrapped__
+ name = fn.__name__
spec = inspect_getfullargspec(fn)
try:
annotations = typing.get_type_hints(fn, env)
@@ -168,17 +199,29 @@ def _generate_stub_for_meth(cls, name, printer, env, is_context_manager):
retval = re.sub("NoneType", "None", retval)
return retval
+ def _formatvalue(value):
+ return "=" + ("..." if value is Ellipsis else repr(value))
+
argspec = inspect_formatargspec(
*spec,
formatannotation=_formatannotation,
+ formatvalue=_formatvalue,
formatreturns=lambda val: f"-> {_formatannotation(val)}",
)
+
+ overload = "@overload" if is_overload else ""
contextmanager = "@contextmanager" if is_context_manager else ""
+
+ fn_doc = base_method.__doc__ if base_method else fn.__doc__
+ has_docs = gen_docs and fn_doc is not None
+ docs = '"""' + f"{fn_doc}" + '"""' if has_docs else ""
+
func_text = textwrap.dedent(
f"""
+ {overload}
{contextmanager}
- def {name}{argspec}:
- '''{fn.__doc__}'''
+ def {name}{argspec}: {"..." if not docs else ""}
+ {docs}
"""
)