diff options
Diffstat (limited to 'tools')
-rw-r--r-- | tools/write_pyi.py | 38 |
1 files changed, 18 insertions, 20 deletions
diff --git a/tools/write_pyi.py b/tools/write_pyi.py index ec928cc..cf42d1b 100644 --- a/tools/write_pyi.py +++ b/tools/write_pyi.py @@ -38,6 +38,7 @@ TRIM_MODULE = [ "sqlalchemy.sql.functions.", "sqlalchemy.sql.dml.", ] +CONTEXT_MANAGERS = {"op": ["batch_alter_table"]} def generate_pyi_for_proxy( @@ -46,8 +47,10 @@ def generate_pyi_for_proxy( source_path: Path, destination_path: Path, ignore_output: bool, - ignore_items: set, + file_key: str, ): + ignore_items = IGNORE_ITEMS.get(file_key, set()) + context_managers = CONTEXT_MANAGERS.get(file_key, []) if sys.version_info < (3, 9): raise RuntimeError("This script must be run with Python 3.9 or higher") @@ -93,7 +96,9 @@ def generate_pyi_for_proxy( continue meth = getattr(cls, name, None) if callable(meth): - _generate_stub_for_meth(cls, name, printer, env) + _generate_stub_for_meth( + cls, name, printer, env, name in context_managers + ) else: _generate_stub_for_attr(cls, name, printer, env) @@ -125,7 +130,7 @@ def _generate_stub_for_attr(cls, name, printer, env): printer.writeline(f"{name}: {type_}") -def _generate_stub_for_meth(cls, name, printer, env): +def _generate_stub_for_meth(cls, name, printer, env, is_context_manager): fn = getattr(cls, name) while hasattr(fn, "__wrapped__"): @@ -168,24 +173,19 @@ def _generate_stub_for_meth(cls, name, printer, env): formatannotation=_formatannotation, formatreturns=lambda val: f"-> {_formatannotation(val)}", ) - + contextmanager = "@contextmanager" if is_context_manager else "" func_text = textwrap.dedent( - """\ - def %(name)s%(argspec)s: - '''%(doc)s''' + f""" + {contextmanager} + def {name}{argspec}: + '''{fn.__doc__}''' """ - % { - "name": name, - "argspec": argspec, - "doc": fn.__doc__, - } ) - printer.write_indented_block(func_text) def run_file( - source_path: Path, cls_to_generate: type, stdout: bool, ignore_items: set + source_path: Path, cls_to_generate: type, stdout: bool, file_key: str ): progname = Path(sys.argv[0]).as_posix() if not stdout: @@ -195,7 +195,7 @@ def run_file( source_path=source_path, destination_path=source_path, ignore_output=False, - ignore_items=ignore_items, + file_key=file_key, ) else: with NamedTemporaryFile(delete=False, suffix=".pyi") as f: @@ -207,7 +207,7 @@ def run_file( source_path=source_path, destination_path=f_path, ignore_output=True, - ignore_items=ignore_items, + file_key=file_key, ) sys.stdout.write(f_path.read_text()) f_path.unlink() @@ -216,15 +216,13 @@ def run_file( def main(args): location = Path(__file__).parent.parent / "alembic" if args.file in {"all", "op"}: - run_file( - location / "op.pyi", Operations, args.stdout, IGNORE_ITEMS["op"] - ) + run_file(location / "op.pyi", Operations, args.stdout, "op") if args.file in {"all", "context"}: run_file( location / "context.pyi", EnvironmentContext, args.stdout, - IGNORE_ITEMS["context"], + "context", ) |