author | Mike Bayer <mike_mp@zzzcomputing.com> | 2022-04-05 19:00:19 -0400
---|---|---
committer | mike bayer <mike_mp@zzzcomputing.com> | 2022-04-12 02:09:42 +0000
commit | 98eae4e181cb2d1bbc67ec834bfad29dcba7f461 (patch) |
tree | fc000c3113a4a198b4ddd6bc81fe291dc9ef1ffb /tools/generate_proxy_methods.py |
parent | 15ef11e0ede82e44fb07f31b63d3db0712d8bf48 (diff) |
download | sqlalchemy-98eae4e181cb2d1bbc67ec834bfad29dcba7f461.tar.gz |
use code generation for scoped_session
our decorator thing generates code in any case,
so point it at the file itself to generate real code
for the blocks rather than doing things dynamically.
this will allow typing tools to have no problem
whatsoever and we also reduce import time overhead.
file size will be a lot bigger though, shrugs.
syntax / dupe method / etc. checking will be accomplished
by our existing linting / typing / formatting tools.
As we are using "from __future__ import annotations",
we also no longer have to apply quotes to generated
annotations.
Change-Id: I20962cb65bda63ff0fb67357ab346e9b1ef4f108
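To make the commit message concrete, here is a rough sketch of the kind of block the tool writes between the START/END markers of a proxy class. The method name, signature, and docstring text are illustrative stand-ins rather than output copied from the generated sqlalchemy.orm.scoping module; the shape follows the templates in the script below.

```python
# sketch of generated output, based on the non-classlevel template in
# instrument(); "close" and its docstring text are illustrative assumptions
class scoped_session:
    # START PROXY METHODS scoped_session

    # code within this block is **programmatically,
    # statically generated** by tools/generate_proxy_methods.py

    def close(self) -> None:
        r"""Close this Session.

        .. container:: class_bases

            Proxied for the :class:`_orm.Session` class on
            behalf of the :class:`_orm.scoping.scoped_session` class.

        """  # noqa: E501

        return self._proxied.close()

    # END PROXY METHODS scoped_session
```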
Diffstat (limited to 'tools/generate_proxy_methods.py')
-rw-r--r-- | tools/generate_proxy_methods.py | 413
1 file changed, 413 insertions, 0 deletions
diff --git a/tools/generate_proxy_methods.py b/tools/generate_proxy_methods.py
new file mode 100644
index 000000000..eec4d878a
--- /dev/null
+++ b/tools/generate_proxy_methods.py
@@ -0,0 +1,413 @@
"""Generate static proxy code for SQLAlchemy classes that proxy other
objects.

This tool is run at source code authoring / commit time whenever we add new
methods to engines/connections/sessions that need to be generically proxied by
scoped_session or asyncio.  The generated code is part of what's committed
to source just as though we typed it all by hand.

The original "proxy" class was scoped_session.  Then with asyncio, all the
asyncio objects are essentially "proxy" objects as well; while all the methods
that are "async" needed to be written by hand, there's lots of other attributes
and methods that are proxied exactly.

To eliminate redundancy, all of these classes made use of the
@langhelpers.create_proxy_methods() decorator which at runtime would read a
selected list of methods and attributes from the proxied class and generate new
methods and property descriptors on the proxying class; this way the proxy
would have all the same method signatures / attributes / docstrings consumed
by Sphinx and look just like the proxied class.

Then mypy and typing came along, which don't care about runtime generated code
and never will.  So this script takes that same
@langhelpers.create_proxy_methods() decorator, keeps its public interface just
as is, and uses it to generate all the code and docs in those proxy classes
statically, as though we sat there and spent seven hours typing it all by hand.
The runtime code generation part is removed from ``create_proxy_methods()``.
Now we have static code that is perfectly consumable by all the typing tools
and we also reduce import time a bit.

A similar technique is used in Alembic, where a dynamic approach towards
creating alembic "ops" was enhanced to generate a .pyi stubs file statically
for consumption by typing tools.

.. versionadded:: 2.0

"""
from __future__ import annotations

from argparse import ArgumentParser
import collections
import importlib
import inspect
import os
from pathlib import Path
import re
import shlex
import shutil
import subprocess
import sys
from tempfile import NamedTemporaryFile
import textwrap
from typing import Any
from typing import Callable
from typing import Dict
from typing import Iterable
from typing import TextIO
from typing import Tuple
from typing import Type
from typing import TypeVar

from sqlalchemy import util
from sqlalchemy.util import compat
from sqlalchemy.util import langhelpers
from sqlalchemy.util.langhelpers import format_argspec_plus
from sqlalchemy.util.langhelpers import inject_docstring_text

is_posix = os.name == "posix"


sys.path.append(str(Path(__file__).parent.parent))


class _repr_sym:
    __slots__ = ("sym",)

    def __init__(self, sym: str):
        self.sym = sym

    def __repr__(self) -> str:
        return self.sym


classes: collections.defaultdict[
    str, Dict[str, Tuple[Any, ...]]
] = collections.defaultdict(dict)

_T = TypeVar("_T", bound="Any")


def create_proxy_methods(
    target_cls: Type[Any],
    target_cls_sphinx_name: str,
    proxy_cls_sphinx_name: str,
    classmethods: Iterable[str] = (),
    methods: Iterable[str] = (),
    attributes: Iterable[str] = (),
) -> Callable[[Type[_T]], Type[_T]]:
    """A class decorator that will copy attributes to a proxy class.

    The class to be instrumented must define a single accessor "_proxied".

    """

    def decorate(cls: Type[_T]) -> Type[_T]:
        # collect the class as a separate step.  since the decorator
        # is called as a result of imports, the order in which classes
        # are collected (like in asyncio) can't be well controlled.  however,
        # the proxies (specifically asyncio session and asyncio
        # scoped_session) have to be generated in dependency order, so run
        # them in order in a second step.
        classes[cls.__module__][cls.__name__] = (
            target_cls,
            target_cls_sphinx_name,
            proxy_cls_sphinx_name,
            classmethods,
            methods,
            attributes,
            cls,
        )
        return cls

    return decorate


def process_class(
    buf: TextIO,
    target_cls: Type[Any],
    target_cls_sphinx_name: str,
    proxy_cls_sphinx_name: str,
    classmethods: Iterable[str],
    methods: Iterable[str],
    attributes: Iterable[str],
    cls: Type[Any],
):

    sphinx_symbol_match = re.match(r":class:`(.+)`", target_cls_sphinx_name)
    if not sphinx_symbol_match:
        raise Exception(
            f"Couldn't match sphinx class identifier from "
            f"target_cls_sphinx_name {target_cls_sphinx_name!r}. Currently "
            'this program expects the form ":class:`_<prefix>.<clsname>`"'
        )

    sphinx_symbol = sphinx_symbol_match.group(1)

    def instrument(buf: TextIO, name: str, clslevel: bool = False) -> None:
        fn = getattr(target_cls, name)
        spec = compat.inspect_getfullargspec(fn)

        iscoroutine = inspect.iscoroutinefunction(fn)

        if spec.defaults:
            new_defaults = tuple(
                _repr_sym("util.EMPTY_DICT") if df is util.EMPTY_DICT else df
                for df in spec.defaults
            )
            elem = list(spec)
            elem[3] = tuple(new_defaults)
            spec = compat.FullArgSpec(*elem)

        caller_argspec = format_argspec_plus(spec, grouped=False)

        metadata = {
            "name": fn.__name__,
            "async": "async " if iscoroutine else "",
            "await": "await " if iscoroutine else "",
            "apply_pos_proxied": caller_argspec["apply_pos_proxied"],
            "target_cls_name": target_cls.__name__,
            "apply_kw_proxied": caller_argspec["apply_kw_proxied"],
            "grouped_args": caller_argspec["grouped_args"],
            "self_arg": caller_argspec["self_arg"],
            "doc": textwrap.indent(
                inject_docstring_text(
                    fn.__doc__,
                    textwrap.indent(
                        ".. container:: class_bases\n\n"
                        f"    Proxied for the {target_cls_sphinx_name} "
                        "class on\n"
                        f"    behalf of the {proxy_cls_sphinx_name} "
                        "class.",
                        "    ",
                    ),
                    1,
                ),
                "        ",
            ).lstrip(),
        }

        if clslevel:
            code = (
                "@classmethod\n"
                "%(async)sdef %(name)s%(grouped_args)s:\n"
                '    r"""%(doc)s\n    """  # noqa: E501\n\n'
                "    return %(await)s%(target_cls_name)s.%(name)s(%(apply_kw_proxied)s)\n\n"  # noqa: E501
                % metadata
            )
        else:
            code = (
                "%(async)sdef %(name)s%(grouped_args)s:\n"
                '    r"""%(doc)s\n    """  # noqa: E501\n\n'
                "    return %(await)s%(self_arg)s._proxied.%(name)s(%(apply_kw_proxied)s)\n\n"  # noqa: E501
                % metadata
            )

        buf.write(textwrap.indent(code, "    "))

    def makeprop(buf: TextIO, name: str) -> None:
        attr = target_cls.__dict__.get(name, None)

        return_type = target_cls.__annotations__.get(name, "Any")
        assert isinstance(return_type, str), (
            "expected string annotations, is from __future__ "
            "import annotations set up?"
        )

        if attr is not None:
            if isinstance(attr, property):
                readonly = attr.fset is None
            elif isinstance(attr, langhelpers.generic_fn_descriptor):
                readonly = True
            else:
                readonly = not hasattr(attr, "__set__")
            doc = textwrap.indent(
                inject_docstring_text(
                    attr.__doc__,
                    textwrap.indent(
                        ".. container:: class_bases\n\n"
                        f"    Proxied for the {target_cls_sphinx_name} "
                        "class\n"
                        f"    on behalf of the {proxy_cls_sphinx_name} "
                        "class.",
                        "    ",
                    ),
                    1,
                ),
                "        ",
            ).lstrip()
        else:
            readonly = False
            doc = (
                f"Proxy for the :attr:`{sphinx_symbol}.{name}` "
                "attribute\n"
                f"    on behalf of the {proxy_cls_sphinx_name} "
                "class.\n"
            )

        code = (
            "@property\n"
            "def %(name)s(self) -> %(return_type)s:\n"
            '    r"""%(doc)s\n    """  # noqa: E501\n\n'
            "    return self._proxied.%(name)s\n\n"
        ) % {"name": name, "doc": doc, "return_type": return_type}

        if not readonly:
            code += (
                "@%(name)s.setter\n"
                "def %(name)s(self, attr: %(return_type)s) -> None:\n"
                "    self._proxied.%(name)s = attr\n\n"
            ) % {"name": name, "doc": doc, "return_type": return_type}

        buf.write(textwrap.indent(code, "    "))

    for meth in methods:
        instrument(buf, meth)

    for prop in attributes:
        makeprop(buf, prop)

    for prop in classmethods:
        instrument(buf, prop, clslevel=True)
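
# illustrative sketch (not part of the committed file): for a writable
# attribute such as Session.autoflush, makeprop() above would emit roughly
# the following into the proxy class body -- the return annotation is pulled
# from target_cls.__annotations__ (here assumed to be "bool"), and the
# setter appears only when the attribute is not read-only:
#
#     @property
#     def autoflush(self) -> bool:
#         r"""Proxy for the :attr:`_orm.Session.autoflush` attribute
#             on behalf of the :class:`_orm.scoping.scoped_session` class.
#         """  # noqa: E501
#
#         return self._proxied.autoflush
#
#     @autoflush.setter
#     def autoflush(self, attr: bool) -> None:
#         self._proxied.autoflush = attr
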
def process_module(modname: str, filename: str) -> str:

    class_entries = classes[modname]

    # use tempfile in same path as the module, or at least in the
    # current working directory, so that black / zimports use
    # local pyproject.toml
    with NamedTemporaryFile(
        mode="w", delete=False, suffix=".py", dir=Path(filename).parent
    ) as buf, open(filename) as orig_py:

        in_block = False
        current_clsname = None
        for line in orig_py:
            m = re.match(r"    # START PROXY METHODS (.+)$", line)
            if m:
                current_clsname = m.group(1)
                args = class_entries[current_clsname]
                sys.stderr.write(
                    f"Generating attributes for class {current_clsname}\n"
                )
                in_block = True
                buf.write(line)
                buf.write(
                    "\n    # code within this block is "
                    "**programmatically, \n"
                    "    # statically generated** by"
                    " tools/generate_proxy_methods.py\n\n"
                )

                process_class(buf, *args)
            if line.startswith(f"    # END PROXY METHODS {current_clsname}"):
                in_block = False

            if not in_block:
                buf.write(line)
    return buf.name


def console_scripts(
    path: str, options: dict, ignore_output: bool = False
) -> None:

    entrypoint_name = options["entrypoint"]

    for entry in compat.importlib_metadata_get("console_scripts"):
        if entry.name == entrypoint_name:
            impl = entry
            break
    else:
        raise Exception(
            f"Could not find entrypoint console_scripts.{entrypoint_name}"
        )
    cmdline_options_str = options.get("options", "")
    cmdline_options_list = shlex.split(cmdline_options_str, posix=is_posix) + [
        path
    ]

    kw = {}
    if ignore_output:
        kw["stdout"] = kw["stderr"] = subprocess.DEVNULL

    subprocess.run(
        [
            sys.executable,
            "-c",
            "import %s; %s.%s()" % (impl.module, impl.module, impl.attr),
        ]
        + cmdline_options_list,
        cwd=Path(__file__).parent.parent,
        **kw,
    )


def run_module(modname, stdout):

    sys.stderr.write(f"importing module {modname}\n")
    mod = importlib.import_module(modname)
    filename = destination_path = mod.__file__
    assert filename is not None

    tempfile = process_module(modname, filename)

    ignore_output = stdout

    console_scripts(
        str(tempfile),
        {"entrypoint": "zimports"},
        ignore_output=ignore_output,
    )

    console_scripts(
        str(tempfile),
        {"entrypoint": "black"},
        ignore_output=ignore_output,
    )

    if stdout:
        with open(tempfile) as tf:
            print(tf.read())
        os.unlink(tempfile)
    else:
        sys.stderr.write(f"Writing {destination_path}...\n")
        shutil.move(tempfile, destination_path)


def main(args):
    from sqlalchemy import util
    from sqlalchemy.util import langhelpers

    util.create_proxy_methods = (
        langhelpers.create_proxy_methods
    ) = create_proxy_methods

    for entry in entries:
        if args.module in {"all", entry}:
            run_module(entry, args.stdout)


entries = [
    "sqlalchemy.orm.scoping",
    "sqlalchemy.ext.asyncio.engine",
    "sqlalchemy.ext.asyncio.session",
    "sqlalchemy.ext.asyncio.scoping",
]

if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument(
        "--module",
        choices=entries + ["all"],
        default="all",
        help="Which file to generate.  Default is to regenerate all files",
    )
    parser.add_argument(
        "--stdout",
        action="store_true",
        help="Write to stdout instead of saving to file",
    )
    args = parser.parse_args()
    main(args)
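For context, a minimal sketch of how a proxy class opts into this generator: it applies the (now collection-only) decorator and carries the START/END marker comments that process_module() rewrites in place. This mirrors the expected scoped_session usage, but the classmethod, method, and attribute lists here are abbreviated illustrations, not the real lists from sqlalchemy/orm/scoping.py.

```python
# sketch, assuming the decorator interface defined above; lists abbreviated
from sqlalchemy import util
from sqlalchemy.orm.session import Session


@util.create_proxy_methods(
    Session,
    ":class:`_orm.Session`",
    ":class:`_orm.scoping.scoped_session`",
    classmethods=["close_all", "object_session"],
    methods=["close", "commit", "execute"],
    attributes=["autoflush", "bind", "dirty"],
)
class scoped_session:
    """Provides scoped management of :class:`.Session` objects."""

    _proxied: Session

    # START PROXY METHODS scoped_session
    # (rewritten in place by tools/generate_proxy_methods.py)
    # END PROXY METHODS scoped_session
```

Running `python tools/generate_proxy_methods.py --module sqlalchemy.orm.scoping` then imports the module, regenerates the marked block into a temp file alongside it, runs the zimports and black console-script entrypoints over that file, and moves the result over the original module, exactly as run_module() above does.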