From cfe92fac6794515d3aa3b995e288b11d5c9437fa Mon Sep 17 00:00:00 2001 From: CaselIT Date: Thu, 21 Apr 2022 23:23:00 +0200 Subject: Various typing related updates Change-Id: I778b63b1c438f31964d841576f0dd54ae1a5fadc --- tools/write_pyi.py | 56 ++++++++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 48 insertions(+), 8 deletions(-) (limited to 'tools') diff --git a/tools/write_pyi.py b/tools/write_pyi.py index 60728a8..ec928cc 100644 --- a/tools/write_pyi.py +++ b/tools/write_pyi.py @@ -4,6 +4,7 @@ import re import sys from tempfile import NamedTemporaryFile import textwrap +import typing from mako.pygen import PythonPrinter @@ -15,6 +16,8 @@ if True: # avoid flake/zimports messing with the order from alembic.script.write_hooks import console_scripts from alembic.util.compat import inspect_formatargspec from alembic.util.compat import inspect_getfullargspec + from alembic.operations import ops + import sqlalchemy as sa IGNORE_ITEMS = { "op": {"context", "create_module_class_proxy"}, @@ -24,6 +27,17 @@ IGNORE_ITEMS = { "requires_connection", }, } +TRIM_MODULE = [ + "alembic.runtime.migration.", + "alembic.operations.ops.", + "sqlalchemy.engine.base.", + "sqlalchemy.sql.schema.", + "sqlalchemy.sql.selectable.", + "sqlalchemy.sql.elements.", + "sqlalchemy.sql.type_api.", + "sqlalchemy.sql.functions.", + "sqlalchemy.sql.dml.", +] def generate_pyi_for_proxy( @@ -66,14 +80,22 @@ def generate_pyi_for_proxy( printer.writeline("### end imports ###") buf.write("\n\n") + module = sys.modules[cls.__module__] + env = { + **sa.__dict__, + **sa.types.__dict__, + **ops.__dict__, + **module.__dict__, + } + for name in dir(cls): if name.startswith("_") or name in ignore_items: continue - meth = getattr(cls, name) + meth = getattr(cls, name, None) if callable(meth): - _generate_stub_for_meth(cls, name, printer) + _generate_stub_for_meth(cls, name, printer, env) else: - _generate_stub_for_attr(cls, name, printer) + _generate_stub_for_attr(cls, name, printer, env) printer.close() @@ -92,18 +114,29 @@ def generate_pyi_for_proxy( ) -def _generate_stub_for_attr(cls, name, printer): - type_ = cls.__annotations__.get(name, "Any") +def _generate_stub_for_attr(cls, name, printer, env): + try: + annotations = typing.get_type_hints(cls, env) + except NameError as e: + annotations = cls.__annotations__ + type_ = annotations.get(name, "Any") + if isinstance(type_, str) and type_[0] in "'\"": + type_ = type_[1:-1] printer.writeline(f"{name}: {type_}") -def _generate_stub_for_meth(cls, name, printer): +def _generate_stub_for_meth(cls, name, printer, env): fn = getattr(cls, name) while hasattr(fn, "__wrapped__"): fn = fn.__wrapped__ spec = inspect_getfullargspec(fn) + try: + annotations = typing.get_type_hints(fn, env) + spec.annotations.update(annotations) + except NameError as e: + pass name_args = spec[0] assert name_args[0:1] == ["self"] or name_args[0:1] == ["cls"] @@ -119,7 +152,10 @@ def _generate_stub_for_meth(cls, name, printer): else: retval = annotation.__module__ + "." + annotation.__qualname__ else: - retval = repr(annotation) + retval = annotation + + for trim in TRIM_MODULE: + retval = retval.replace(trim, "") retval = re.sub( r'ForwardRef\(([\'"].+?[\'"])\)', lambda m: m.group(1), retval @@ -127,7 +163,11 @@ def _generate_stub_for_meth(cls, name, printer): retval = re.sub("NoneType", "None", retval) return retval - argspec = inspect_formatargspec(*spec, formatannotation=_formatannotation) + argspec = inspect_formatargspec( + *spec, + formatannotation=_formatannotation, + formatreturns=lambda val: f"-> {_formatannotation(val)}", + ) func_text = textwrap.dedent( """\ -- cgit v1.2.1