Diffstat (limited to 'tools/generate_proxy_methods.py')
-rw-r--r-- | tools/generate_proxy_methods.py | 413 |
1 file changed, 413 insertions, 0 deletions
diff --git a/tools/generate_proxy_methods.py b/tools/generate_proxy_methods.py
new file mode 100644
index 000000000..eec4d878a
--- /dev/null
+++ b/tools/generate_proxy_methods.py
@@ -0,0 +1,413 @@
"""Generate static proxy code for SQLAlchemy classes that proxy other
objects.

This tool is run at source code authoring / commit time whenever we add new
methods to engines/connections/sessions that need to be generically proxied
by scoped_session or asyncio. The generated code is part of what's committed
to source just as though we typed it all by hand.

The original "proxy" class was scoped_session. Then with asyncio, all the
asyncio objects are essentially "proxy" objects as well; while all the
methods that are "async" needed to be written by hand, there are lots of
other attributes and methods that are proxied exactly.

To eliminate redundancy, all of these classes made use of the
@langhelpers.create_proxy_methods() decorator, which at runtime would read a
selected list of methods and attributes from the proxied class and generate
new methods and property descriptors on the proxying class; this way the
proxy would have all the same method signatures / attributes / docstrings
consumed by Sphinx and look just like the proxied class.

Then mypy and typing came along, which don't care about runtime-generated
code and never will. So this script takes that same
@langhelpers.create_proxy_methods() decorator, keeps its public interface
just as is, and uses it to generate all the code and docs in those proxy
classes statically, as though we sat there and spent seven hours typing it
all by hand. The runtime code generation part is removed from
``create_proxy_methods()``. Now we have static code that is perfectly
consumable by all the typing tools and we also reduce import time a bit.

A similar approach is used in Alembic, where a dynamic approach towards
creating alembic "ops" was enhanced to generate a .pyi stubs file statically
for consumption by typing tools.

.. versionadded:: 2.0

"""
from __future__ import annotations

from argparse import ArgumentParser
import collections
import importlib
import inspect
import os
from pathlib import Path
import re
import shlex
import shutil
import subprocess
import sys
from tempfile import NamedTemporaryFile
import textwrap
from typing import Any
from typing import Callable
from typing import Dict
from typing import Iterable
from typing import TextIO
from typing import Tuple
from typing import Type
from typing import TypeVar

from sqlalchemy import util
from sqlalchemy.util import compat
from sqlalchemy.util import langhelpers
from sqlalchemy.util.langhelpers import format_argspec_plus
from sqlalchemy.util.langhelpers import inject_docstring_text

is_posix = os.name == "posix"


sys.path.append(str(Path(__file__).parent.parent))


class _repr_sym:
    __slots__ = ("sym",)

    def __init__(self, sym: str):
        self.sym = sym

    def __repr__(self) -> str:
        return self.sym


classes: collections.defaultdict[
    str, Dict[str, Tuple[Any, ...]]
] = collections.defaultdict(dict)

_T = TypeVar("_T", bound="Any")


def create_proxy_methods(
    target_cls: Type[Any],
    target_cls_sphinx_name: str,
    proxy_cls_sphinx_name: str,
    classmethods: Iterable[str] = (),
    methods: Iterable[str] = (),
    attributes: Iterable[str] = (),
) -> Callable[[Type[_T]], Type[_T]]:
    """A class decorator that will copy attributes to a proxy class.

    The class to be instrumented must define a single accessor "_proxied".

    """

    def decorate(cls: Type[_T]) -> Type[_T]:
        # collect the class as a separate step. since the decorator
        # is called as a result of imports, the order in which classes
        # are collected (like in asyncio) can't be well controlled. however,
        # the proxies (specifically asyncio session and asyncio
        # scoped_session) have to be generated in dependency order, so run
        # them in order in a second step.
        classes[cls.__module__][cls.__name__] = (
            target_cls,
            target_cls_sphinx_name,
            proxy_cls_sphinx_name,
            classmethods,
            methods,
            attributes,
            cls,
        )
        return cls

    return decorate
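
# Usage sketch (illustrative only; the class and member names below are
# hypothetical, not part of SQLAlchemy's API). A proxy class opts in by
# applying the decorator and defining the single "_proxied" accessor:
#
#     @create_proxy_methods(
#         SomeTarget,
#         ":class:`_pkg.SomeTarget`",
#         ":class:`_pkg.SomeProxy`",
#         methods=["connect", "close"],
#         attributes=["closed"],
#     )
#     class SomeProxy:
#         _proxied: SomeTarget
#
# Under this tool, the decorator only records the class in ``classes``; the
# actual method / property source is rendered later by process_class().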


def process_class(
    buf: TextIO,
    target_cls: Type[Any],
    target_cls_sphinx_name: str,
    proxy_cls_sphinx_name: str,
    classmethods: Iterable[str],
    methods: Iterable[str],
    attributes: Iterable[str],
    cls: Type[Any],
):

    sphinx_symbol_match = re.match(r":class:`(.+)`", target_cls_sphinx_name)
    if not sphinx_symbol_match:
        raise Exception(
            f"Couldn't match sphinx class identifier from "
            f"target_cls_sphinx_name {target_cls_sphinx_name!r}. Currently "
            'this program expects the form ":class:`_<prefix>.<clsname>`"'
        )

    sphinx_symbol = sphinx_symbol_match.group(1)

    def instrument(buf: TextIO, name: str, clslevel: bool = False) -> None:
        fn = getattr(target_cls, name)
        spec = compat.inspect_getfullargspec(fn)

        iscoroutine = inspect.iscoroutinefunction(fn)

        if spec.defaults:
            new_defaults = tuple(
                _repr_sym("util.EMPTY_DICT") if df is util.EMPTY_DICT else df
                for df in spec.defaults
            )
            elem = list(spec)
            elem[3] = tuple(new_defaults)
            spec = compat.FullArgSpec(*elem)

        caller_argspec = format_argspec_plus(spec, grouped=False)

        metadata = {
            "name": fn.__name__,
            "async": "async " if iscoroutine else "",
            "await": "await " if iscoroutine else "",
            "apply_pos_proxied": caller_argspec["apply_pos_proxied"],
            "target_cls_name": target_cls.__name__,
            "apply_kw_proxied": caller_argspec["apply_kw_proxied"],
            "grouped_args": caller_argspec["grouped_args"],
            "self_arg": caller_argspec["self_arg"],
            "doc": textwrap.indent(
                inject_docstring_text(
                    fn.__doc__,
                    textwrap.indent(
                        ".. container:: class_bases\n\n"
                        f"    Proxied for the {target_cls_sphinx_name} "
                        "class on\n"
                        f"    behalf of the {proxy_cls_sphinx_name} "
                        "class.",
                        "    ",
                    ),
                    1,
                ),
                "        ",
            ).lstrip(),
        }

        if clslevel:
            code = (
                "@classmethod\n"
                "%(async)sdef %(name)s%(grouped_args)s:\n"
                '    r"""%(doc)s\n    """  # noqa: E501\n\n'
                "    return %(await)s%(target_cls_name)s.%(name)s(%(apply_kw_proxied)s)\n\n"  # noqa: E501
                % metadata
            )
        else:
            code = (
                "%(async)sdef %(name)s%(grouped_args)s:\n"
                '    r"""%(doc)s\n    """  # noqa: E501\n\n'
                "    return %(await)s%(self_arg)s._proxied.%(name)s(%(apply_kw_proxied)s)\n\n"  # noqa: E501
                % metadata
            )

        buf.write(textwrap.indent(code, "    "))

    def makeprop(buf: TextIO, name: str) -> None:
        attr = target_cls.__dict__.get(name, None)

        return_type = target_cls.__annotations__.get(name, "Any")
        assert isinstance(return_type, str), (
            "expected string annotations, is from __future__ "
            "import annotations set up?"
        )

        if attr is not None:
            if isinstance(attr, property):
                readonly = attr.fset is None
            elif isinstance(attr, langhelpers.generic_fn_descriptor):
                readonly = True
            else:
                readonly = not hasattr(attr, "__set__")
            doc = textwrap.indent(
                inject_docstring_text(
                    attr.__doc__,
                    textwrap.indent(
                        ".. container:: class_bases\n\n"
                        f"    Proxied for the {target_cls_sphinx_name} "
                        "class\n"
                        f"    on behalf of the {proxy_cls_sphinx_name} "
                        "class.",
                        "    ",
                    ),
                    1,
                ),
                "        ",
            ).lstrip()
        else:
            readonly = False
            doc = (
                f"Proxy for the :attr:`{sphinx_symbol}.{name}` "
                "attribute\n"
                f"        on behalf of the {proxy_cls_sphinx_name} "
                "class.\n"
            )

        code = (
            "@property\n"
            "def %(name)s(self) -> %(return_type)s:\n"
            '    r"""%(doc)s\n    """  # noqa: E501\n\n'
            "    return self._proxied.%(name)s\n\n"
        ) % {"name": name, "doc": doc, "return_type": return_type}

        if not readonly:
            code += (
                "@%(name)s.setter\n"
                "def %(name)s(self, attr: %(return_type)s) -> None:\n"
                "    self._proxied.%(name)s = attr\n\n"
            ) % {"name": name, "doc": doc, "return_type": return_type}

        buf.write(textwrap.indent(code, "    "))

    for meth in methods:
        instrument(buf, meth)

    for prop in attributes:
        makeprop(buf, prop)

    for prop in classmethods:
        instrument(buf, prop, clslevel=True)
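
# For orientation, the instance-method template above renders roughly the
# following into the proxy class body (a sketch; "begin" is a hypothetical
# proxied method name):
#
#     def begin(self):
#         r"""<proxied docstring, with the class_bases container injected>
#         """  # noqa: E501
#
#         return self._proxied.begin()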


def process_module(modname: str, filename: str) -> str:

    class_entries = classes[modname]

    # use tempfile in same path as the module, or at least in the
    # current working directory, so that black / zimports use
    # local pyproject.toml
    with NamedTemporaryFile(
        mode="w", delete=False, suffix=".py", dir=Path(filename).parent
    ) as buf, open(filename) as orig_py:

        in_block = False
        current_clsname = None
        for line in orig_py:
            m = re.match(r"    # START PROXY METHODS (.+)$", line)
            if m:
                current_clsname = m.group(1)
                args = class_entries[current_clsname]
                sys.stderr.write(
                    f"Generating attributes for class {current_clsname}\n"
                )
                in_block = True
                buf.write(line)
                buf.write(
                    "\n    # code within this block is "
                    "**programmatically, \n"
                    "    # statically generated** by"
                    " tools/generate_proxy_methods.py\n\n"
                )

                process_class(buf, *args)
            if line.startswith(f"    # END PROXY METHODS {current_clsname}"):
                in_block = False

            if not in_block:
                buf.write(line)
    return buf.name


def console_scripts(
    path: str, options: dict, ignore_output: bool = False
) -> None:

    entrypoint_name = options["entrypoint"]

    for entry in compat.importlib_metadata_get("console_scripts"):
        if entry.name == entrypoint_name:
            impl = entry
            break
    else:
        raise Exception(
            f"Could not find entrypoint console_scripts.{entrypoint_name}"
        )
    cmdline_options_str = options.get("options", "")
    cmdline_options_list = shlex.split(cmdline_options_str, posix=is_posix) + [
        path
    ]

    kw = {}
    if ignore_output:
        kw["stdout"] = kw["stderr"] = subprocess.DEVNULL

    subprocess.run(
        [
            sys.executable,
            "-c",
            "import %s; %s.%s()" % (impl.module, impl.module, impl.attr),
        ]
        + cmdline_options_list,
        cwd=Path(__file__).parent.parent,
        **kw,
    )


def run_module(modname, stdout):

    sys.stderr.write(f"importing module {modname}\n")
    mod = importlib.import_module(modname)
    filename = destination_path = mod.__file__
    assert filename is not None

    tempfile = process_module(modname, filename)

    ignore_output = stdout

    console_scripts(
        str(tempfile),
        {"entrypoint": "zimports"},
        ignore_output=ignore_output,
    )

    console_scripts(
        str(tempfile),
        {"entrypoint": "black"},
        ignore_output=ignore_output,
    )

    if stdout:
        with open(tempfile) as tf:
            print(tf.read())
        os.unlink(tempfile)
    else:
        sys.stderr.write(f"Writing {destination_path}...\n")
        shutil.move(tempfile, destination_path)
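
# The modules being rewritten delimit their generated sections with the
# comment markers that process_module() scans for; a sketch of the
# convention (the class name is hypothetical):
#
#     class SomeProxy:
#         # START PROXY METHODS SomeProxy
#         ... generated methods and properties are (re)written here ...
#         # END PROXY METHODS SomeProxy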


def main(args):
    from sqlalchemy import util
    from sqlalchemy.util import langhelpers

    util.create_proxy_methods = (
        langhelpers.create_proxy_methods
    ) = create_proxy_methods

    for entry in entries:
        if args.module in {"all", entry}:
            run_module(entry, args.stdout)


entries = [
    "sqlalchemy.orm.scoping",
    "sqlalchemy.ext.asyncio.engine",
    "sqlalchemy.ext.asyncio.session",
    "sqlalchemy.ext.asyncio.scoping",
]

if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument(
        "--module",
        choices=entries + ["all"],
        default="all",
        help="Which file to generate. Default is to regenerate all files",
    )
    parser.add_argument(
        "--stdout",
        action="store_true",
        help="Write to stdout instead of saving to file",
    )
    args = parser.parse_args()
    main(args)
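
# Example invocations (assumes the "zimports" and "black" console scripts
# are installed in the active environment, since run_module() shells out to
# both):
#
#     python tools/generate_proxy_methods.py
#     python tools/generate_proxy_methods.py --module sqlalchemy.orm.scoping --stdout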