summaryrefslogtreecommitdiff
path: root/tools/generate_proxy_methods.py
diff options
context:
space:
mode:
Diffstat (limited to 'tools/generate_proxy_methods.py')
-rw-r--r--tools/generate_proxy_methods.py81
1 files changed, 32 insertions, 49 deletions
diff --git a/tools/generate_proxy_methods.py b/tools/generate_proxy_methods.py
index c21db9d60..cc039d4d6 100644
--- a/tools/generate_proxy_methods.py
+++ b/tools/generate_proxy_methods.py
@@ -40,16 +40,16 @@ typed by hand.
.. versionadded:: 2.0
"""
+# mypy: ignore-errors
+
from __future__ import annotations
-from argparse import ArgumentParser
import collections
import importlib
import inspect
import os
from pathlib import Path
import re
-import shutil
import sys
from tempfile import NamedTemporaryFile
import textwrap
@@ -65,9 +65,9 @@ from typing import TypeVar
from sqlalchemy import util
from sqlalchemy.util import compat
from sqlalchemy.util import langhelpers
-from sqlalchemy.util.langhelpers import console_scripts
from sqlalchemy.util.langhelpers import format_argspec_plus
from sqlalchemy.util.langhelpers import inject_docstring_text
+from sqlalchemy.util.tool_support import code_writer_cmd
is_posix = os.name == "posix"
@@ -340,7 +340,7 @@ def process_class(
instrument(buf, prop, clslevel=True)
-def process_module(modname: str, filename: str) -> str:
+def process_module(modname: str, filename: str, cmd: code_writer_cmd) -> str:
class_entries = classes[modname]
@@ -348,7 +348,9 @@ def process_module(modname: str, filename: str) -> str:
# current working directory, so that black / zimports use
# local pyproject.toml
with NamedTemporaryFile(
- mode="w", delete=False, suffix=".py", dir=Path(filename).parent
+ mode="w",
+ delete=False,
+ suffix=".py",
) as buf, open(filename) as orig_py:
in_block = False
@@ -358,7 +360,7 @@ def process_module(modname: str, filename: str) -> str:
if m:
current_clsname = m.group(1)
args = class_entries[current_clsname]
- sys.stderr.write(
+ cmd.write_status(
f"Generating attributes for class {current_clsname}\n"
)
in_block = True
@@ -379,39 +381,21 @@ def process_module(modname: str, filename: str) -> str:
return buf.name
-def run_module(modname, stdout):
+def run_module(modname: str, cmd: code_writer_cmd) -> None:
- sys.stderr.write(f"importing module {modname}\n")
+ cmd.write_status(f"importing module {modname}\n")
mod = importlib.import_module(modname)
- filename = destination_path = mod.__file__
- assert filename is not None
-
- tempfile = process_module(modname, filename)
-
- ignore_output = stdout
-
- console_scripts(
- str(tempfile),
- {"entrypoint": "zimports"},
- ignore_output=ignore_output,
- )
+ destination_path = mod.__file__
+ assert destination_path is not None
- console_scripts(
- str(tempfile),
- {"entrypoint": "black"},
- ignore_output=ignore_output,
- )
+ tempfile = process_module(modname, destination_path, cmd)
- if stdout:
- with open(tempfile) as tf:
- print(tf.read())
- os.unlink(tempfile)
- else:
- sys.stderr.write(f"Writing {destination_path}...\n")
- shutil.move(tempfile, destination_path)
+ cmd.run_zimports(tempfile)
+ cmd.run_black(tempfile)
+ cmd.write_output_file_from_tempfile(tempfile, destination_path)
-def main(args):
+def main(cmd: code_writer_cmd) -> None:
from sqlalchemy import util
from sqlalchemy.util import langhelpers
@@ -420,8 +404,8 @@ def main(args):
) = create_proxy_methods
for entry in entries:
- if args.module in {"all", entry}:
- run_module(entry, args.stdout)
+ if cmd.args.module in {"all", entry}:
+ run_module(entry, cmd)
entries = [
@@ -432,17 +416,16 @@ entries = [
]
if __name__ == "__main__":
- parser = ArgumentParser()
- parser.add_argument(
- "--module",
- choices=entries + ["all"],
- default="all",
- help="Which file to generate. Default is to regenerate all files",
- )
- parser.add_argument(
- "--stdout",
- action="store_true",
- help="Write to stdout instead of saving to file",
- )
- args = parser.parse_args()
- main(args)
+
+ cmd = code_writer_cmd(__file__)
+
+ with cmd.add_arguments() as parser:
+ parser.add_argument(
+ "--module",
+ choices=entries + ["all"],
+ default="all",
+ help="Which file to generate. Default is to regenerate all files",
+ )
+
+ with cmd.run_program():
+ main(cmd)