author     CaselIT <cfederico87@gmail.com>           2023-03-17 00:50:53 +0100
committer  Federico Caselli <cfederico87@gmail.com>  2023-05-04 22:53:09 +0200
commit     2aba0ada168d0047d54c7a08b0ffdde3102b716b (patch)
tree       6334971e00debf16b208de8b12d882a79b75d902 /tools
parent     e17e59ee2be160fff35b38b08d68766a971b3069 (diff)
download   alembic-2aba0ada168d0047d54c7a08b0ffdde3102b716b.tar.gz
Add Operations and BatchOperations stub methods
Updated the stub generator script to also add stub method definitions for the :class:`.Operations` class and the :class:`.BatchOperations` class obtained from :meth:`.Operations.batch_alter_table`. Repaired the return signatures of :class:`.Operations` methods that mostly return ``None`` and were erroneously annotated as ``Optional[Table]`` in many cases.

Fixes: #1093
Change-Id: I98d38dd5a1e719b4dbbc1003746ec28f26c27808
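With this change the generator writes the Operations / BatchOperations stubs directly into alembic/operations/base.py between "# START STUB FUNCTIONS" / "# END STUB FUNCTIONS" marker comments (see PyFileInfo in the diff below), rather than only producing separate .pyi files. The following is a minimal, self-contained sketch of that marker-splicing idea; the demo source and the single add_column stub line are illustrative placeholders, not the project's actual output.

import re
import textwrap

START = r"^(\s*)# START STUB FUNCTIONS: (\w+)"
END = r"^\s*# END STUB FUNCTIONS: (\w+)"


def splice_stubs(source: str, name: str, generated: str) -> str:
    """Replace the region between the START/END markers for ``name``."""
    before: list[str] = []
    after: list[str] = []
    indent = ""
    state = "before"
    for line in source.splitlines(keepends=True):
        m_start = re.match(START, line)
        m_end = re.match(END, line)
        if m_start and m_start.group(2) == name:
            indent = m_start.group(1)
            before.append(line)  # keep the START marker itself
            state = "stubs"
        elif m_end and m_end.group(1) == name:
            state = "after"  # the END marker is appended to `after` below
        if state == "before":
            before.append(line)
        elif state == "after":
            after.append(line)
    assert state == "after", f"markers for {name!r} not found"
    body = textwrap.indent(generated.rstrip() + "\n", indent)
    return "".join(before) + body + "".join(after)


demo_source = (
    "class Operations:\n"
    "    # START STUB FUNCTIONS: op_cls\n"
    "    # END STUB FUNCTIONS: op_cls\n"
)
# The stub body would normally come from the generator; one hand-written
# line stands in for it here.
print(splice_stubs(demo_source, "op_cls", "def add_column(self) -> None: ..."))

The script is now invoked per target via the new ``--name`` option (choices ``op``, ``context``, ``batch_op``, ``op_cls``, or ``all``), which replaces the previous ``--file`` option, as shown at the end of the diff.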
Diffstat (limited to 'tools')
-rw-r--r--  tools/write_pyi.py   303
1 file changed, 217 insertions, 86 deletions
diff --git a/tools/write_pyi.py b/tools/write_pyi.py
index fa79c49..7d24870 100644
--- a/tools/write_pyi.py
+++ b/tools/write_pyi.py
@@ -1,4 +1,9 @@
+from __future__ import annotations
+
+from abc import abstractmethod
from argparse import ArgumentParser
+from dataclasses import dataclass
+from dataclasses import field
from pathlib import Path
import re
import sys
@@ -6,11 +11,14 @@ from tempfile import NamedTemporaryFile
import textwrap
import typing
-from mako.pygen import PythonPrinter
+from alembic.autogenerate.api import AutogenContext
+from alembic.ddl.impl import DefaultImpl
+from alembic.runtime.migration import MigrationInfo
sys.path.append(str(Path(__file__).parent.parent))
if True: # avoid flake/zimports messing with the order
+ from alembic.operations.base import BatchOperations
from alembic.operations.base import Operations
from alembic.runtime.environment import EnvironmentContext
from alembic.runtime.migration import MigrationContext
@@ -20,18 +28,12 @@ if True: # avoid flake/zimports messing with the order
from alembic.operations import ops
import sqlalchemy as sa
-IGNORE_ITEMS = {
- "op": {"context", "create_module_class_proxy"},
- "context": {
- "create_module_class_proxy",
- "get_impl",
- "requires_connection",
- },
-}
+
TRIM_MODULE = [
"alembic.runtime.migration.",
"alembic.operations.base.",
"alembic.operations.ops.",
+ "alembic.autogenerate.api.",
"sqlalchemy.engine.base.",
"sqlalchemy.engine.url.",
"sqlalchemy.sql.schema.",
@@ -41,58 +43,38 @@ TRIM_MODULE = [
"sqlalchemy.sql.functions.",
"sqlalchemy.sql.dml.",
]
-CONTEXT_MANAGERS = {"op": ["batch_alter_table"]}
-ADDITIONAL_ENV = {"MigrationContext": MigrationContext}
+ADDITIONAL_ENV = {
+ "MigrationContext": MigrationContext,
+ "AutogenContext": AutogenContext,
+ "DefaultImpl": DefaultImpl,
+ "MigrationInfo": MigrationInfo,
+}
def generate_pyi_for_proxy(
- cls: type,
- progname: str,
- source_path: Path,
- destination_path: Path,
- ignore_output: bool,
- file_key: str,
+ file_info: FileInfo, destination_path: Path, ignore_output: bool
):
- ignore_items = IGNORE_ITEMS.get(file_key, set())
- context_managers = CONTEXT_MANAGERS.get(file_key, [])
if sys.version_info < (3, 11):
raise RuntimeError(
"This script must be run with Python 3.11 or higher"
)
+ progname = Path(sys.argv[0]).as_posix()
# When using an absolute path on windows, this will generate the correct
# relative path that shall be written to the top comment of the pyi file.
if Path(progname).is_absolute():
progname = Path(progname).relative_to(Path().cwd()).as_posix()
- imports = []
- read_imports = False
- with open(source_path) as read_file:
- for line in read_file:
- if line.startswith("# ### this file stubs are generated by"):
- read_imports = True
- elif line.startswith("### end imports ###"):
- read_imports = False
- break
- elif read_imports:
- imports.append(line.rstrip())
+ file_info.read_file()
+ cls = file_info.target
with open(destination_path, "w") as buf:
- printer = PythonPrinter(buf)
-
- printer.writeline(
- f"# ### this file stubs are generated by {progname} "
- "- do not edit ###"
- )
- for line in imports:
- buf.write(line + "\n")
- printer.writeline("### end imports ###")
- buf.write("\n\n")
+ file_info.write_before(buf, progname)
module = sys.modules[cls.__module__]
env = {
**typing.__dict__,
- **sa.sql.schema.__dict__,
+ **sa.schema.__dict__,
**sa.__dict__,
**sa.types.__dict__,
**ADDITIONAL_ENV,
@@ -101,39 +83,43 @@ def generate_pyi_for_proxy(
}
for name in dir(cls):
- if name.startswith("_") or name in ignore_items:
+ if name.startswith("_") or name in file_info.ignore_items:
continue
meth = getattr(cls, name, None)
if callable(meth):
# If there are overloads, generate only those
# Do not generate the base implementation to avoid mypy errors
overloads = typing.get_overloads(meth)
+ is_context_manager = name in file_info.context_managers
if overloads:
# use enumerate so we can generate docs on the
# last overload
for i, ovl in enumerate(overloads, 1):
- _generate_stub_for_meth(
+ text = _generate_stub_for_meth(
ovl,
cls,
- printer,
+ file_info,
env,
- is_context_manager=name in context_managers,
+ is_context_manager=is_context_manager,
is_overload=True,
base_method=meth,
gen_docs=(i == len(overloads)),
)
+ file_info.write(buf, text)
else:
- _generate_stub_for_meth(
+ text = _generate_stub_for_meth(
meth,
cls,
- printer,
+ file_info,
env,
- is_context_manager=name in context_managers,
+ is_context_manager=is_context_manager,
)
+ file_info.write(buf, text)
else:
- _generate_stub_for_attr(cls, name, printer, env)
+ text = _generate_stub_for_attr(cls, name, env)
+ file_info.write(buf, text)
- printer.close()
+ file_info.write_after(buf)
console_scripts(
str(destination_path),
@@ -150,7 +136,7 @@ def generate_pyi_for_proxy(
)
-def _generate_stub_for_attr(cls, name, printer, env):
+def _generate_stub_for_attr(cls, name, env):
try:
annotations = typing.get_type_hints(cls, env)
except NameError:
@@ -158,13 +144,13 @@ def _generate_stub_for_attr(cls, name, printer, env):
type_ = annotations.get(name, "Any")
if isinstance(type_, str) and type_[0] in "'\"":
type_ = type_[1:-1]
- printer.writeline(f"{name}: {type_}")
+ return f"{name}: {type_}"
def _generate_stub_for_meth(
fn,
cls,
- printer,
+ file_info,
env,
is_context_manager,
is_overload=False,
@@ -185,7 +171,8 @@ def _generate_stub_for_meth(
name_args = spec[0]
assert name_args[0:1] == ["self"] or name_args[0:1] == ["cls"]
- name_args[0:1] = []
+ if file_info.RemoveFirstArg:
+ name_args[0:1] = []
def _formatannotation(annotation, base_module=None):
if getattr(annotation, "__module__", None) == "typing":
@@ -219,8 +206,14 @@ def _generate_stub_for_meth(
fn_doc = base_method.__doc__ if base_method else fn.__doc__
has_docs = gen_docs and fn_doc is not None
- string_prefix = "r" if chr(92) in fn_doc else ""
- docs = f'{string_prefix}"""' + f"{fn_doc}" + '"""' if has_docs else ""
+ string_prefix = "r" if has_docs and chr(92) in fn_doc else ""
+ if has_docs:
+ noqua = " # noqa: E501" if file_info.docs_noqa_E501 else ""
+ docs = f'{string_prefix}"""{fn_doc}"""{noqua}'
+ else:
+ docs = ""
+
+ suffix = "..." if file_info.AddEllipsis and docs else ""
func_text = textwrap.dedent(
f"""
@@ -228,61 +221,199 @@ def _generate_stub_for_meth(
{contextmanager}
def {name}{argspec}: {"..." if not docs else ""}
{docs}
+ {suffix}
"""
)
- printer.write_indented_block(func_text)
+ return func_text
-def run_file(
- source_path: Path, cls_to_generate: type, stdout: bool, file_key: str
-):
- progname = Path(sys.argv[0]).as_posix()
+def run_file(finfo: FileInfo, stdout: bool):
if not stdout:
generate_pyi_for_proxy(
- cls_to_generate,
- progname,
- source_path=source_path,
- destination_path=source_path,
- ignore_output=False,
- file_key=file_key,
+ finfo, destination_path=finfo.path, ignore_output=False
)
else:
- with NamedTemporaryFile(delete=False, suffix=".pyi") as f:
+ with NamedTemporaryFile(delete=False, suffix=finfo.path.suffix) as f:
f.close()
f_path = Path(f.name)
generate_pyi_for_proxy(
- cls_to_generate,
- progname,
- source_path=source_path,
- destination_path=f_path,
- ignore_output=True,
- file_key=file_key,
+ finfo, destination_path=f_path, ignore_output=True
)
sys.stdout.write(f_path.read_text())
f_path.unlink()
def main(args):
- location = Path(__file__).parent.parent / "alembic"
- if args.file in {"all", "op"}:
- run_file(location / "op.pyi", Operations, args.stdout, "op")
- if args.file in {"all", "context"}:
- run_file(
- location / "context.pyi",
- EnvironmentContext,
- args.stdout,
- "context",
+ for case in cases:
+ if args.name == "all" or args.name == case.name:
+ run_file(case, args.stdout)
+
+
+@dataclass
+class FileInfo:
+ RemoveFirstArg: typing.ClassVar[bool]
+ AddEllipsis: typing.ClassVar[bool]
+
+ name: str
+ path: Path
+ target: type
+ ignore_items: set[str] = field(default_factory=set)
+ context_managers: set[str] = field(default_factory=set)
+ docs_noqa_E501: bool = field(default=False)
+
+ @abstractmethod
+ def read_file(self):
+ pass
+
+ @abstractmethod
+ def write_before(self, out: typing.IO[str], progname: str):
+ pass
+
+ @abstractmethod
+ def write(self, out: typing.IO[str], text: str):
+ pass
+
+ def write_after(self, out: typing.IO[str]):
+ pass
+
+
+@dataclass
+class StubFileInfo(FileInfo):
+ RemoveFirstArg = True
+ AddEllipsis = False
+ imports: list[str] = field(init=False)
+
+ def read_file(self):
+ imports = []
+ read_imports = False
+ with open(self.path) as read_file:
+ for line in read_file:
+ if line.startswith("# ### this file stubs are generated by"):
+ read_imports = True
+ elif line.startswith("### end imports ###"):
+ read_imports = False
+ break
+ elif read_imports:
+ imports.append(line.rstrip())
+ self.imports = imports
+
+ def write_before(self, out: typing.IO[str], progname: str):
+ self.write(
+ out,
+ f"# ### this file stubs are generated by {progname} "
+ "- do not edit ###",
+ )
+ for line in self.imports:
+ self.write(out, line)
+ self.write(out, "### end imports ###\n")
+
+ def write(self, out: typing.IO[str], text: str):
+ out.write(text)
+ out.write("\n")
+
+
+@dataclass
+class PyFileInfo(FileInfo):
+ RemoveFirstArg = False
+ AddEllipsis = True
+ indent: str = field(init=False)
+ before: list[str] = field(init=False)
+ after: list[str] = field(init=False)
+
+ def read_file(self):
+ self.before = []
+ self.after = []
+ state = "before"
+ start_text = rf"^(\s*)# START STUB FUNCTIONS: {self.name}"
+ end_text = rf"^\s*# END STUB FUNCTIONS: {self.name}"
+ with open(self.path) as read_file:
+ for line in read_file:
+ if m := re.match(start_text, line):
+ assert state == "before"
+ self.indent = m.group(1)
+ self.before.append(line)
+ state = "stubs"
+ elif m := re.match(end_text, line):
+ assert state == "stubs"
+ state = "after"
+ if state == "before":
+ self.before.append(line)
+ if state == "after":
+ self.after.append(line)
+ assert state == "after", state
+
+ def write_before(self, out: typing.IO[str], progname: str):
+ out.writelines(self.before)
+ self.write(
+ out, f"# ### the following stubs are generated by {progname} ###"
)
+ self.write(out, "# ### do not edit ###")
+
+ def write(self, out: typing.IO[str], text: str):
+ out.write(textwrap.indent(text, self.indent))
+ out.write("\n")
+
+ def write_after(self, out: typing.IO[str]):
+ out.writelines(self.after)
+location = Path(__file__).parent.parent / "alembic"
+
+cls_ignore = {
+ "batch_alter_table",
+ "context",
+ "create_module_class_proxy",
+ "f",
+ "get_bind",
+ "get_context",
+ "implementation_for",
+ "inline_literal",
+ "invoke",
+ "register_operation",
+}
+
+cases = [
+ StubFileInfo(
+ "op",
+ location / "op.pyi",
+ Operations,
+ ignore_items={"context", "create_module_class_proxy"},
+ context_managers={"batch_alter_table"},
+ ),
+ StubFileInfo(
+ "context",
+ location / "context.pyi",
+ EnvironmentContext,
+ ignore_items={
+ "create_module_class_proxy",
+ "get_impl",
+ "requires_connection",
+ },
+ ),
+ PyFileInfo(
+ "batch_op",
+ location / "operations/base.py",
+ BatchOperations,
+ ignore_items=cls_ignore,
+ docs_noqa_E501=True,
+ ),
+ PyFileInfo(
+ "op_cls",
+ location / "operations/base.py",
+ Operations,
+ ignore_items=cls_ignore,
+ docs_noqa_E501=True,
+ ),
+]
+
if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument(
- "--file",
- choices={"op", "context", "all"},
+ "--name",
+ choices=[fi.name for fi in cases] + ["all"],
default="all",
- help="Which file to generate. Default is to regenerate all files",
+ help="Which name to generate. Default is to regenerate all names",
)
parser.add_argument(
"--stdout",