diff options
author | CaselIT <cfederico87@gmail.com> | 2023-03-17 00:50:53 +0100 |
---|---|---|
committer | Federico Caselli <cfederico87@gmail.com> | 2023-05-04 22:53:09 +0200 |
commit | 2aba0ada168d0047d54c7a08b0ffdde3102b716b (patch) | |
tree | 6334971e00debf16b208de8b12d882a79b75d902 /tools | |
parent | e17e59ee2be160fff35b38b08d68766a971b3069 (diff) | |
download | alembic-2aba0ada168d0047d54c7a08b0ffdde3102b716b.tar.gz |
Add Operations and BatchOperations stub methods
Updated stub generator script to also add stubs method definitions
for the :class:`.Operations` class and the :class:`.BatchOperations`
class obtained from :meth:`.Operations.batch_alter_table`.
Repaired the return signatures for :class:`.Operations` that mostly
return ``None``, and were erroneously referring to ``Optional[Table]``
in many cases.
Fixes: #1093
Change-Id: I98d38dd5a1e719b4dbbc1003746ec28f26c27808
Diffstat (limited to 'tools')
-rw-r--r-- | tools/write_pyi.py | 303 |
1 files changed, 217 insertions, 86 deletions
diff --git a/tools/write_pyi.py b/tools/write_pyi.py index fa79c49..7d24870 100644 --- a/tools/write_pyi.py +++ b/tools/write_pyi.py @@ -1,4 +1,9 @@ +from __future__ import annotations + +from abc import abstractmethod from argparse import ArgumentParser +from dataclasses import dataclass +from dataclasses import field from pathlib import Path import re import sys @@ -6,11 +11,14 @@ from tempfile import NamedTemporaryFile import textwrap import typing -from mako.pygen import PythonPrinter +from alembic.autogenerate.api import AutogenContext +from alembic.ddl.impl import DefaultImpl +from alembic.runtime.migration import MigrationInfo sys.path.append(str(Path(__file__).parent.parent)) if True: # avoid flake/zimports messing with the order + from alembic.operations.base import BatchOperations from alembic.operations.base import Operations from alembic.runtime.environment import EnvironmentContext from alembic.runtime.migration import MigrationContext @@ -20,18 +28,12 @@ if True: # avoid flake/zimports messing with the order from alembic.operations import ops import sqlalchemy as sa -IGNORE_ITEMS = { - "op": {"context", "create_module_class_proxy"}, - "context": { - "create_module_class_proxy", - "get_impl", - "requires_connection", - }, -} + TRIM_MODULE = [ "alembic.runtime.migration.", "alembic.operations.base.", "alembic.operations.ops.", + "alembic.autogenerate.api.", "sqlalchemy.engine.base.", "sqlalchemy.engine.url.", "sqlalchemy.sql.schema.", @@ -41,58 +43,38 @@ TRIM_MODULE = [ "sqlalchemy.sql.functions.", "sqlalchemy.sql.dml.", ] -CONTEXT_MANAGERS = {"op": ["batch_alter_table"]} -ADDITIONAL_ENV = {"MigrationContext": MigrationContext} +ADDITIONAL_ENV = { + "MigrationContext": MigrationContext, + "AutogenContext": AutogenContext, + "DefaultImpl": DefaultImpl, + "MigrationInfo": MigrationInfo, +} def generate_pyi_for_proxy( - cls: type, - progname: str, - source_path: Path, - destination_path: Path, - ignore_output: bool, - file_key: str, + file_info: FileInfo, destination_path: Path, ignore_output: bool ): - ignore_items = IGNORE_ITEMS.get(file_key, set()) - context_managers = CONTEXT_MANAGERS.get(file_key, []) if sys.version_info < (3, 11): raise RuntimeError( "This script must be run with Python 3.11 or higher" ) + progname = Path(sys.argv[0]).as_posix() # When using an absolute path on windows, this will generate the correct # relative path that shall be written to the top comment of the pyi file. if Path(progname).is_absolute(): progname = Path(progname).relative_to(Path().cwd()).as_posix() - imports = [] - read_imports = False - with open(source_path) as read_file: - for line in read_file: - if line.startswith("# ### this file stubs are generated by"): - read_imports = True - elif line.startswith("### end imports ###"): - read_imports = False - break - elif read_imports: - imports.append(line.rstrip()) + file_info.read_file() + cls = file_info.target with open(destination_path, "w") as buf: - printer = PythonPrinter(buf) - - printer.writeline( - f"# ### this file stubs are generated by {progname} " - "- do not edit ###" - ) - for line in imports: - buf.write(line + "\n") - printer.writeline("### end imports ###") - buf.write("\n\n") + file_info.write_before(buf, progname) module = sys.modules[cls.__module__] env = { **typing.__dict__, - **sa.sql.schema.__dict__, + **sa.schema.__dict__, **sa.__dict__, **sa.types.__dict__, **ADDITIONAL_ENV, @@ -101,39 +83,43 @@ def generate_pyi_for_proxy( } for name in dir(cls): - if name.startswith("_") or name in ignore_items: + if name.startswith("_") or name in file_info.ignore_items: continue meth = getattr(cls, name, None) if callable(meth): # If there are overloads, generate only those # Do not generate the base implementation to avoid mypy errors overloads = typing.get_overloads(meth) + is_context_manager = name in file_info.context_managers if overloads: # use enumerate so we can generate docs on the # last overload for i, ovl in enumerate(overloads, 1): - _generate_stub_for_meth( + text = _generate_stub_for_meth( ovl, cls, - printer, + file_info, env, - is_context_manager=name in context_managers, + is_context_manager=is_context_manager, is_overload=True, base_method=meth, gen_docs=(i == len(overloads)), ) + file_info.write(buf, text) else: - _generate_stub_for_meth( + text = _generate_stub_for_meth( meth, cls, - printer, + file_info, env, - is_context_manager=name in context_managers, + is_context_manager=is_context_manager, ) + file_info.write(buf, text) else: - _generate_stub_for_attr(cls, name, printer, env) + text = _generate_stub_for_attr(cls, name, env) + file_info.write(buf, text) - printer.close() + file_info.write_after(buf) console_scripts( str(destination_path), @@ -150,7 +136,7 @@ def generate_pyi_for_proxy( ) -def _generate_stub_for_attr(cls, name, printer, env): +def _generate_stub_for_attr(cls, name, env): try: annotations = typing.get_type_hints(cls, env) except NameError: @@ -158,13 +144,13 @@ def _generate_stub_for_attr(cls, name, printer, env): type_ = annotations.get(name, "Any") if isinstance(type_, str) and type_[0] in "'\"": type_ = type_[1:-1] - printer.writeline(f"{name}: {type_}") + return f"{name}: {type_}" def _generate_stub_for_meth( fn, cls, - printer, + file_info, env, is_context_manager, is_overload=False, @@ -185,7 +171,8 @@ def _generate_stub_for_meth( name_args = spec[0] assert name_args[0:1] == ["self"] or name_args[0:1] == ["cls"] - name_args[0:1] = [] + if file_info.RemoveFirstArg: + name_args[0:1] = [] def _formatannotation(annotation, base_module=None): if getattr(annotation, "__module__", None) == "typing": @@ -219,8 +206,14 @@ def _generate_stub_for_meth( fn_doc = base_method.__doc__ if base_method else fn.__doc__ has_docs = gen_docs and fn_doc is not None - string_prefix = "r" if chr(92) in fn_doc else "" - docs = f'{string_prefix}"""' + f"{fn_doc}" + '"""' if has_docs else "" + string_prefix = "r" if has_docs and chr(92) in fn_doc else "" + if has_docs: + noqua = " # noqa: E501" if file_info.docs_noqa_E501 else "" + docs = f'{string_prefix}"""{fn_doc}"""{noqua}' + else: + docs = "" + + suffix = "..." if file_info.AddEllipsis and docs else "" func_text = textwrap.dedent( f""" @@ -228,61 +221,199 @@ def _generate_stub_for_meth( {contextmanager} def {name}{argspec}: {"..." if not docs else ""} {docs} + {suffix} """ ) - printer.write_indented_block(func_text) + return func_text -def run_file( - source_path: Path, cls_to_generate: type, stdout: bool, file_key: str -): - progname = Path(sys.argv[0]).as_posix() +def run_file(finfo: FileInfo, stdout: bool): if not stdout: generate_pyi_for_proxy( - cls_to_generate, - progname, - source_path=source_path, - destination_path=source_path, - ignore_output=False, - file_key=file_key, + finfo, destination_path=finfo.path, ignore_output=False ) else: - with NamedTemporaryFile(delete=False, suffix=".pyi") as f: + with NamedTemporaryFile(delete=False, suffix=finfo.path.suffix) as f: f.close() f_path = Path(f.name) generate_pyi_for_proxy( - cls_to_generate, - progname, - source_path=source_path, - destination_path=f_path, - ignore_output=True, - file_key=file_key, + finfo, destination_path=f_path, ignore_output=True ) sys.stdout.write(f_path.read_text()) f_path.unlink() def main(args): - location = Path(__file__).parent.parent / "alembic" - if args.file in {"all", "op"}: - run_file(location / "op.pyi", Operations, args.stdout, "op") - if args.file in {"all", "context"}: - run_file( - location / "context.pyi", - EnvironmentContext, - args.stdout, - "context", + for case in cases: + if args.name == "all" or args.name == case.name: + run_file(case, args.stdout) + + +@dataclass +class FileInfo: + RemoveFirstArg: typing.ClassVar[bool] + AddEllipsis: typing.ClassVar[bool] + + name: str + path: Path + target: type + ignore_items: set[str] = field(default_factory=set) + context_managers: set[str] = field(default_factory=set) + docs_noqa_E501: bool = field(default=False) + + @abstractmethod + def read_file(self): + pass + + @abstractmethod + def write_before(self, out: typing.IO[str], progname: str): + pass + + @abstractmethod + def write(self, out: typing.IO[str], text: str): + pass + + def write_after(self, out: typing.IO[str]): + pass + + +@dataclass +class StubFileInfo(FileInfo): + RemoveFirstArg = True + AddEllipsis = False + imports: list[str] = field(init=False) + + def read_file(self): + imports = [] + read_imports = False + with open(self.path) as read_file: + for line in read_file: + if line.startswith("# ### this file stubs are generated by"): + read_imports = True + elif line.startswith("### end imports ###"): + read_imports = False + break + elif read_imports: + imports.append(line.rstrip()) + self.imports = imports + + def write_before(self, out: typing.IO[str], progname: str): + self.write( + out, + f"# ### this file stubs are generated by {progname} " + "- do not edit ###", + ) + for line in self.imports: + self.write(out, line) + self.write(out, "### end imports ###\n") + + def write(self, out: typing.IO[str], text: str): + out.write(text) + out.write("\n") + + +@dataclass +class PyFileInfo(FileInfo): + RemoveFirstArg = False + AddEllipsis = True + indent: str = field(init=False) + before: list[str] = field(init=False) + after: list[str] = field(init=False) + + def read_file(self): + self.before = [] + self.after = [] + state = "before" + start_text = rf"^(\s*)# START STUB FUNCTIONS: {self.name}" + end_text = rf"^\s*# END STUB FUNCTIONS: {self.name}" + with open(self.path) as read_file: + for line in read_file: + if m := re.match(start_text, line): + assert state == "before" + self.indent = m.group(1) + self.before.append(line) + state = "stubs" + elif m := re.match(end_text, line): + assert state == "stubs" + state = "after" + if state == "before": + self.before.append(line) + if state == "after": + self.after.append(line) + assert state == "after", state + + def write_before(self, out: typing.IO[str], progname: str): + out.writelines(self.before) + self.write( + out, f"# ### the following stubs are generated by {progname} ###" ) + self.write(out, "# ### do not edit ###") + + def write(self, out: typing.IO[str], text: str): + out.write(textwrap.indent(text, self.indent)) + out.write("\n") + + def write_after(self, out: typing.IO[str]): + out.writelines(self.after) +location = Path(__file__).parent.parent / "alembic" + +cls_ignore = { + "batch_alter_table", + "context", + "create_module_class_proxy", + "f", + "get_bind", + "get_context", + "implementation_for", + "inline_literal", + "invoke", + "register_operation", +} + +cases = [ + StubFileInfo( + "op", + location / "op.pyi", + Operations, + ignore_items={"context", "create_module_class_proxy"}, + context_managers={"batch_alter_table"}, + ), + StubFileInfo( + "context", + location / "context.pyi", + EnvironmentContext, + ignore_items={ + "create_module_class_proxy", + "get_impl", + "requires_connection", + }, + ), + PyFileInfo( + "batch_op", + location / "operations/base.py", + BatchOperations, + ignore_items=cls_ignore, + docs_noqa_E501=True, + ), + PyFileInfo( + "op_cls", + location / "operations/base.py", + Operations, + ignore_items=cls_ignore, + docs_noqa_E501=True, + ), +] + if __name__ == "__main__": parser = ArgumentParser() parser.add_argument( - "--file", - choices={"op", "context", "all"}, + "--name", + choices=[fi.name for fi in cases] + ["all"], default="all", - help="Which file to generate. Default is to regenerate all files", + help="Which name to generate. Default is to regenerate all names", ) parser.add_argument( "--stdout", |