summaryrefslogtreecommitdiff
path: root/tools
diff options
context:
space:
mode:
authorCaselIT <cfederico87@gmail.com>2022-06-04 05:59:23 -0400
committerCaselIT <cfederico87@gmail.com>2022-06-04 12:05:46 +0200
commite28ee4ed42ac57f727a934a0916075168d87fcf3 (patch)
tree34a4bb825e7bca1e4d639e532e9d2f05a82030d8 /tools
parent0dd2cc8ddfb7a48b693c12f20d68964a78912e59 (diff)
downloadalembic-e28ee4ed42ac57f727a934a0916075168d87fcf3.tar.gz
Annotate batch_alter_table
Fixes: #975 Closes: #1032 Pull-request: https://github.com/sqlalchemy/alembic/pull/1032 Pull-request-sha: a111d4f446e861bd01a3cea7ebd1c18a2446601a Change-Id: Idb8e1c8b6577204a64cde195f094830cdbba68ce
Diffstat (limited to 'tools')
-rw-r--r--tools/write_pyi.py38
1 files changed, 18 insertions, 20 deletions
diff --git a/tools/write_pyi.py b/tools/write_pyi.py
index ec928cc..cf42d1b 100644
--- a/tools/write_pyi.py
+++ b/tools/write_pyi.py
@@ -38,6 +38,7 @@ TRIM_MODULE = [
"sqlalchemy.sql.functions.",
"sqlalchemy.sql.dml.",
]
+CONTEXT_MANAGERS = {"op": ["batch_alter_table"]}
def generate_pyi_for_proxy(
@@ -46,8 +47,10 @@ def generate_pyi_for_proxy(
source_path: Path,
destination_path: Path,
ignore_output: bool,
- ignore_items: set,
+ file_key: str,
):
+ ignore_items = IGNORE_ITEMS.get(file_key, set())
+ context_managers = CONTEXT_MANAGERS.get(file_key, [])
if sys.version_info < (3, 9):
raise RuntimeError("This script must be run with Python 3.9 or higher")
@@ -93,7 +96,9 @@ def generate_pyi_for_proxy(
continue
meth = getattr(cls, name, None)
if callable(meth):
- _generate_stub_for_meth(cls, name, printer, env)
+ _generate_stub_for_meth(
+ cls, name, printer, env, name in context_managers
+ )
else:
_generate_stub_for_attr(cls, name, printer, env)
@@ -125,7 +130,7 @@ def _generate_stub_for_attr(cls, name, printer, env):
printer.writeline(f"{name}: {type_}")
-def _generate_stub_for_meth(cls, name, printer, env):
+def _generate_stub_for_meth(cls, name, printer, env, is_context_manager):
fn = getattr(cls, name)
while hasattr(fn, "__wrapped__"):
@@ -168,24 +173,19 @@ def _generate_stub_for_meth(cls, name, printer, env):
formatannotation=_formatannotation,
formatreturns=lambda val: f"-> {_formatannotation(val)}",
)
-
+ contextmanager = "@contextmanager" if is_context_manager else ""
func_text = textwrap.dedent(
- """\
- def %(name)s%(argspec)s:
- '''%(doc)s'''
+ f"""
+ {contextmanager}
+ def {name}{argspec}:
+ '''{fn.__doc__}'''
"""
- % {
- "name": name,
- "argspec": argspec,
- "doc": fn.__doc__,
- }
)
-
printer.write_indented_block(func_text)
def run_file(
- source_path: Path, cls_to_generate: type, stdout: bool, ignore_items: set
+ source_path: Path, cls_to_generate: type, stdout: bool, file_key: str
):
progname = Path(sys.argv[0]).as_posix()
if not stdout:
@@ -195,7 +195,7 @@ def run_file(
source_path=source_path,
destination_path=source_path,
ignore_output=False,
- ignore_items=ignore_items,
+ file_key=file_key,
)
else:
with NamedTemporaryFile(delete=False, suffix=".pyi") as f:
@@ -207,7 +207,7 @@ def run_file(
source_path=source_path,
destination_path=f_path,
ignore_output=True,
- ignore_items=ignore_items,
+ file_key=file_key,
)
sys.stdout.write(f_path.read_text())
f_path.unlink()
@@ -216,15 +216,13 @@ def run_file(
def main(args):
location = Path(__file__).parent.parent / "alembic"
if args.file in {"all", "op"}:
- run_file(
- location / "op.pyi", Operations, args.stdout, IGNORE_ITEMS["op"]
- )
+ run_file(location / "op.pyi", Operations, args.stdout, "op")
if args.file in {"all", "context"}:
run_file(
location / "context.pyi",
EnvironmentContext,
args.stdout,
- IGNORE_ITEMS["context"],
+ "context",
)