diff options
Diffstat (limited to 'tools/generate_proxy_methods.py')
-rw-r--r-- | tools/generate_proxy_methods.py | 93 |
1 files changed, 55 insertions, 38 deletions
diff --git a/tools/generate_proxy_methods.py b/tools/generate_proxy_methods.py index ffc470972..91a891882 100644 --- a/tools/generate_proxy_methods.py +++ b/tools/generate_proxy_methods.py @@ -31,6 +31,12 @@ A similar approach is used in Alembic where a dynamic approach towards creating alembic "ops" was enhanced to generate a .pyi stubs file statically for consumption by typing tools. +Note that the usual OO approach of having a common interface class with +concrete subtypes doesn't really solve any problems here; the concrete subtypes +must still list out all methods, arguments, typing annotations, and docstrings, +all of which is copied by this script rather than requiring it all be +typed by hand. + .. versionadded:: 2.0 """ @@ -43,9 +49,7 @@ import inspect import os from pathlib import Path import re -import shlex import shutil -import subprocess import sys from tempfile import NamedTemporaryFile import textwrap @@ -61,6 +65,7 @@ from typing import TypeVar from sqlalchemy import util from sqlalchemy.util import compat from sqlalchemy.util import langhelpers +from sqlalchemy.util.langhelpers import console_scripts from sqlalchemy.util.langhelpers import format_argspec_plus from sqlalchemy.util.langhelpers import inject_docstring_text @@ -122,6 +127,47 @@ def create_proxy_methods( return decorate +def _grab_overloads(fn): + """grab @overload entries for a function, assuming black-formatted + code ;) so that we can do a simple regex + + """ + + # functions that use @util.deprecated and whatnot will have a string + # generated fn. we can look at __wrapped__ but these functions don't + # have any overloads in any case right now so skip + if fn.__code__.co_filename == "<string>": + return [] + + with open(fn.__code__.co_filename) as f: + lines = [l for i, l in zip(range(fn.__code__.co_firstlineno), f)] + + lines.reverse() + + output = [] + + current_ov = [] + for line in lines[1:]: + current_ov.append(line) + outside_block_match = re.match(r"^\w", line) + if outside_block_match: + current_ov[:] = [] + break + + fn_match = re.match(rf"^ (?:async )?def (.*)\($", line) + if fn_match and fn_match.group(1) != fn.__name__: + current_ov[:] = [] + break + + ov_match = re.match(r"^ @overload$", line) + if ov_match: + output.append("".join(reversed(current_ov))) + current_ov[:] = [] + + output.reverse() + return output + + def process_class( buf: TextIO, target_cls: Type[Any], @@ -145,6 +191,12 @@ def process_class( def instrument(buf: TextIO, name: str, clslevel: bool = False) -> None: fn = getattr(target_cls, name) + + overloads = _grab_overloads(fn) + + for overload in overloads: + buf.write(overload) + spec = compat.inspect_getfullargspec(fn) iscoroutine = inspect.iscoroutinefunction(fn) @@ -311,7 +363,7 @@ def process_module(modname: str, filename: str) -> str: "\n # code within this block is " "**programmatically, \n" " # statically generated** by" - " tools/generate_proxy_methods.py\n\n" + f" tools/{os.path.basename(__file__)}\n\n" ) process_class(buf, *args) @@ -323,41 +375,6 @@ def process_module(modname: str, filename: str) -> str: return buf.name -def console_scripts( - path: str, options: dict, ignore_output: bool = False -) -> None: - - entrypoint_name = options["entrypoint"] - - for entry in compat.importlib_metadata_get("console_scripts"): - if entry.name == entrypoint_name: - impl = entry - break - else: - raise Exception( - f"Could not find entrypoint console_scripts.{entrypoint_name}" - ) - cmdline_options_str = options.get("options", "") - cmdline_options_list = shlex.split(cmdline_options_str, posix=is_posix) + [ - path - ] - - kw = {} - if ignore_output: - kw["stdout"] = kw["stderr"] = subprocess.DEVNULL - - subprocess.run( - [ - sys.executable, - "-c", - "import %s; %s.%s()" % (impl.module, impl.module, impl.attr), - ] - + cmdline_options_list, - cwd=Path(__file__).parent.parent, - **kw, - ) - - def run_module(modname, stdout): sys.stderr.write(f"importing module {modname}\n") |