summaryrefslogtreecommitdiff
path: root/tools
diff options
context:
space:
mode:
authorCaselIT <cfederico87@gmail.com>2022-04-21 23:23:00 +0200
committerCaselIT <cfederico87@gmail.com>2022-04-23 22:04:36 +0200
commitcfe92fac6794515d3aa3b995e288b11d5c9437fa (patch)
tree7e00ebf10db2bd5a95b5f4b3a49a31c24dffe8c5 /tools
parente539704aae92bee5d266b1e4e5cfe54b14d544f1 (diff)
downloadalembic-cfe92fac6794515d3aa3b995e288b11d5c9437fa.tar.gz
Various typing related updates
Change-Id: I778b63b1c438f31964d841576f0dd54ae1a5fadc
Diffstat (limited to 'tools')
-rw-r--r--tools/write_pyi.py56
1 files changed, 48 insertions, 8 deletions
diff --git a/tools/write_pyi.py b/tools/write_pyi.py
index 60728a8..ec928cc 100644
--- a/tools/write_pyi.py
+++ b/tools/write_pyi.py
@@ -4,6 +4,7 @@ import re
import sys
from tempfile import NamedTemporaryFile
import textwrap
+import typing
from mako.pygen import PythonPrinter
@@ -15,6 +16,8 @@ if True: # avoid flake/zimports messing with the order
from alembic.script.write_hooks import console_scripts
from alembic.util.compat import inspect_formatargspec
from alembic.util.compat import inspect_getfullargspec
+ from alembic.operations import ops
+ import sqlalchemy as sa
IGNORE_ITEMS = {
"op": {"context", "create_module_class_proxy"},
@@ -24,6 +27,17 @@ IGNORE_ITEMS = {
"requires_connection",
},
}
+TRIM_MODULE = [
+ "alembic.runtime.migration.",
+ "alembic.operations.ops.",
+ "sqlalchemy.engine.base.",
+ "sqlalchemy.sql.schema.",
+ "sqlalchemy.sql.selectable.",
+ "sqlalchemy.sql.elements.",
+ "sqlalchemy.sql.type_api.",
+ "sqlalchemy.sql.functions.",
+ "sqlalchemy.sql.dml.",
+]
def generate_pyi_for_proxy(
@@ -66,14 +80,22 @@ def generate_pyi_for_proxy(
printer.writeline("### end imports ###")
buf.write("\n\n")
+ module = sys.modules[cls.__module__]
+ env = {
+ **sa.__dict__,
+ **sa.types.__dict__,
+ **ops.__dict__,
+ **module.__dict__,
+ }
+
for name in dir(cls):
if name.startswith("_") or name in ignore_items:
continue
- meth = getattr(cls, name)
+ meth = getattr(cls, name, None)
if callable(meth):
- _generate_stub_for_meth(cls, name, printer)
+ _generate_stub_for_meth(cls, name, printer, env)
else:
- _generate_stub_for_attr(cls, name, printer)
+ _generate_stub_for_attr(cls, name, printer, env)
printer.close()
@@ -92,18 +114,29 @@ def generate_pyi_for_proxy(
)
-def _generate_stub_for_attr(cls, name, printer):
- type_ = cls.__annotations__.get(name, "Any")
+def _generate_stub_for_attr(cls, name, printer, env):
+ try:
+ annotations = typing.get_type_hints(cls, env)
+ except NameError as e:
+ annotations = cls.__annotations__
+ type_ = annotations.get(name, "Any")
+ if isinstance(type_, str) and type_[0] in "'\"":
+ type_ = type_[1:-1]
printer.writeline(f"{name}: {type_}")
-def _generate_stub_for_meth(cls, name, printer):
+def _generate_stub_for_meth(cls, name, printer, env):
fn = getattr(cls, name)
while hasattr(fn, "__wrapped__"):
fn = fn.__wrapped__
spec = inspect_getfullargspec(fn)
+ try:
+ annotations = typing.get_type_hints(fn, env)
+ spec.annotations.update(annotations)
+ except NameError as e:
+ pass
name_args = spec[0]
assert name_args[0:1] == ["self"] or name_args[0:1] == ["cls"]
@@ -119,7 +152,10 @@ def _generate_stub_for_meth(cls, name, printer):
else:
retval = annotation.__module__ + "." + annotation.__qualname__
else:
- retval = repr(annotation)
+ retval = annotation
+
+ for trim in TRIM_MODULE:
+ retval = retval.replace(trim, "")
retval = re.sub(
r'ForwardRef\(([\'"].+?[\'"])\)', lambda m: m.group(1), retval
@@ -127,7 +163,11 @@ def _generate_stub_for_meth(cls, name, printer):
retval = re.sub("NoneType", "None", retval)
return retval
- argspec = inspect_formatargspec(*spec, formatannotation=_formatannotation)
+ argspec = inspect_formatargspec(
+ *spec,
+ formatannotation=_formatannotation,
+ formatreturns=lambda val: f"-> {_formatannotation(val)}",
+ )
func_text = textwrap.dedent(
"""\