diff options
Diffstat (limited to 'tools/generate_proxy_methods.py')
-rw-r--r-- | tools/generate_proxy_methods.py | 81 |
1 files changed, 32 insertions, 49 deletions
diff --git a/tools/generate_proxy_methods.py b/tools/generate_proxy_methods.py index c21db9d60..cc039d4d6 100644 --- a/tools/generate_proxy_methods.py +++ b/tools/generate_proxy_methods.py @@ -40,16 +40,16 @@ typed by hand. .. versionadded:: 2.0 """ +# mypy: ignore-errors + from __future__ import annotations -from argparse import ArgumentParser import collections import importlib import inspect import os from pathlib import Path import re -import shutil import sys from tempfile import NamedTemporaryFile import textwrap @@ -65,9 +65,9 @@ from typing import TypeVar from sqlalchemy import util from sqlalchemy.util import compat from sqlalchemy.util import langhelpers -from sqlalchemy.util.langhelpers import console_scripts from sqlalchemy.util.langhelpers import format_argspec_plus from sqlalchemy.util.langhelpers import inject_docstring_text +from sqlalchemy.util.tool_support import code_writer_cmd is_posix = os.name == "posix" @@ -340,7 +340,7 @@ def process_class( instrument(buf, prop, clslevel=True) -def process_module(modname: str, filename: str) -> str: +def process_module(modname: str, filename: str, cmd: code_writer_cmd) -> str: class_entries = classes[modname] @@ -348,7 +348,9 @@ def process_module(modname: str, filename: str) -> str: # current working directory, so that black / zimports use # local pyproject.toml with NamedTemporaryFile( - mode="w", delete=False, suffix=".py", dir=Path(filename).parent + mode="w", + delete=False, + suffix=".py", ) as buf, open(filename) as orig_py: in_block = False @@ -358,7 +360,7 @@ def process_module(modname: str, filename: str) -> str: if m: current_clsname = m.group(1) args = class_entries[current_clsname] - sys.stderr.write( + cmd.write_status( f"Generating attributes for class {current_clsname}\n" ) in_block = True @@ -379,39 +381,21 @@ def process_module(modname: str, filename: str) -> str: return buf.name -def run_module(modname, stdout): +def run_module(modname: str, cmd: code_writer_cmd) -> None: - sys.stderr.write(f"importing module {modname}\n") + cmd.write_status(f"importing module {modname}\n") mod = importlib.import_module(modname) - filename = destination_path = mod.__file__ - assert filename is not None - - tempfile = process_module(modname, filename) - - ignore_output = stdout - - console_scripts( - str(tempfile), - {"entrypoint": "zimports"}, - ignore_output=ignore_output, - ) + destination_path = mod.__file__ + assert destination_path is not None - console_scripts( - str(tempfile), - {"entrypoint": "black"}, - ignore_output=ignore_output, - ) + tempfile = process_module(modname, destination_path, cmd) - if stdout: - with open(tempfile) as tf: - print(tf.read()) - os.unlink(tempfile) - else: - sys.stderr.write(f"Writing {destination_path}...\n") - shutil.move(tempfile, destination_path) + cmd.run_zimports(tempfile) + cmd.run_black(tempfile) + cmd.write_output_file_from_tempfile(tempfile, destination_path) -def main(args): +def main(cmd: code_writer_cmd) -> None: from sqlalchemy import util from sqlalchemy.util import langhelpers @@ -420,8 +404,8 @@ def main(args): ) = create_proxy_methods for entry in entries: - if args.module in {"all", entry}: - run_module(entry, args.stdout) + if cmd.args.module in {"all", entry}: + run_module(entry, cmd) entries = [ @@ -432,17 +416,16 @@ entries = [ ] if __name__ == "__main__": - parser = ArgumentParser() - parser.add_argument( - "--module", - choices=entries + ["all"], - default="all", - help="Which file to generate. Default is to regenerate all files", - ) - parser.add_argument( - "--stdout", - action="store_true", - help="Write to stdout instead of saving to file", - ) - args = parser.parse_args() - main(args) + + cmd = code_writer_cmd(__file__) + + with cmd.add_arguments() as parser: + parser.add_argument( + "--module", + choices=entries + ["all"], + default="all", + help="Which file to generate. Default is to regenerate all files", + ) + + with cmd.run_program(): + main(cmd) |