summaryrefslogtreecommitdiff
path: root/tools/sync_test_files.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2023-01-18 12:45:42 -0500
committerMike Bayer <mike_mp@zzzcomputing.com>2023-01-18 15:07:55 -0500
commitcd96ffe287e26651f8dce4f688bf87af1e423f06 (patch)
tree123197d983714caa9ae225d330fcf3743ba11b52 /tools/sync_test_files.py
parentf91a25cd8191f026dd43c0a2475cda8a56d65c19 (diff)
downloadsqlalchemy-cd96ffe287e26651f8dce4f688bf87af1e423f06.tar.gz
refactor code generation tools , include --check command
in particular it looks like CI was not picking up on the "git diff" oriented commands, which were failing to run due to pathing issues. As we were setting cwd for black/zimports relative to sqlalchemy library, and tox installs it in the venv, black/zimports would fail to run from tox, and since these are subprocess.run we didn't pick up the failure. This overall locks down how zimports/black are run so that we are definitely from the source root, by using the location of tools/ to determine the root. Fixes: #8892 Change-Id: I7c54b747edd5a80e0c699b8456febf66d8b62375
Diffstat (limited to 'tools/sync_test_files.py')
-rw-r--r--tools/sync_test_files.py38
1 files changed, 22 insertions, 16 deletions
diff --git a/tools/sync_test_files.py b/tools/sync_test_files.py
index 4ef15374a..4afa2dc8e 100644
--- a/tools/sync_test_files.py
+++ b/tools/sync_test_files.py
@@ -5,8 +5,11 @@
from __future__ import annotations
-from argparse import ArgumentParser
from pathlib import Path
+from typing import Any
+from typing import Iterable
+
+from sqlalchemy.util.tool_support import code_writer_cmd
header = '''\
"""This file is automatically generated from the file
@@ -22,27 +25,27 @@ from __future__ import annotations
home = Path(__file__).parent.parent
this_file = Path(__file__).relative_to(home).as_posix()
-remove_str = '# anno only: '
+remove_str = "# anno only: "
-def run_operation(name: str, source: str, dest: str):
- print("Running", name, "...", end="", flush=True)
- source_data = Path(source).read_text().replace(remove_str, '')
- dest_data = header.format(source=source, this_file=this_file) + source_data
+def run_operation(
+ name: str, source: str, dest: str, cmd: code_writer_cmd
+) -> None:
- Path(dest).write_text(dest_data)
+ source_data = Path(source).read_text().replace(remove_str, "")
+ dest_data = header.format(source=source, this_file=this_file) + source_data
- print(".. done")
+ cmd.write_output_file_from_text(dest_data, dest)
-def main(file: str):
+def main(file: str, cmd: code_writer_cmd) -> None:
if file == "all":
- operations = files.items()
+ operations: Iterable[Any] = files.items()
else:
operations = [(file, files[file])]
for name, info in operations:
- run_operation(name, info["source"], info["dest"])
+ run_operation(name, info["source"], info["dest"], cmd)
files = {
@@ -53,8 +56,11 @@ files = {
}
if __name__ == "__main__":
- parser = ArgumentParser()
- parser.add_argument("--file", choices=list(files) + ["all"], default="all")
-
- args = parser.parse_args()
- main(args.file)
+ cmd = code_writer_cmd(__file__)
+ with cmd.add_arguments() as parser:
+ parser.add_argument(
+ "--file", choices=list(files) + ["all"], default="all"
+ )
+
+ with cmd.run_program():
+ main(cmd.args.file, cmd)