Diffstat (limited to 'tools/generate_proxy_methods.py')
-rw-r--r--  tools/generate_proxy_methods.py  413
1 file changed, 413 insertions, 0 deletions
diff --git a/tools/generate_proxy_methods.py b/tools/generate_proxy_methods.py
new file mode 100644
index 000000000..eec4d878a
--- /dev/null
+++ b/tools/generate_proxy_methods.py
@@ -0,0 +1,413 @@
+"""Generate static proxy code for SQLAlchemy classes that proxy other
+objects.
+
+This tool is run at source code authoring / commit time whenever we add new
+methods to engines/connections/sessions that need to be generically proxied by
+scoped_session or asyncio. The generated code is part of what's committed
+to source just as though we typed it all by hand.
+
+The original "proxy" class was scoped_session. Then with asyncio, all the
+asyncio objects are essentially "proxy" objects as well; while all the methods
+that are "async" needed to be written by hand, there are many other attributes
+and methods that are proxied exactly.
+
+To eliminate redundancy, all of these classes made use of the
+@langhelpers.create_proxy_methods() decorator, which at runtime would read a
+selected list of methods and attributes from the proxied class and generate new
+method and property descriptors on the proxying class; this way the proxy
+would have the same method signatures / attributes / docstrings consumed
+by Sphinx and look just like the proxied class.
+
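+For illustration, a proxying class applies the decorator roughly like the
+following sketch (illustrative only; the real invocations live in the modules
+listed in ``entries`` at the bottom of this script)::
+
+    @create_proxy_methods(
+        Session,
+        ":class:`_orm.Session`",
+        ":class:`_asyncio.AsyncSession`",
+        classmethods=["object_session", "identity_key"],
+        methods=["add", "add_all", "expire", "expunge", "is_modified"],
+        attributes=["dirty", "deleted", "new", "identity_map", "autoflush"],
+    )
+    class AsyncSession:
+        ...
+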
+Then mypy and typing came along, which don't care about runtime-generated code
+and never will. So this script takes that same
+@langhelpers.create_proxy_methods() decorator, keeps its public interface just
+as is, and uses it to generate all the code and docs in those proxy classes
+statically, as though we sat there and spent seven hours typing it all by hand.
+The runtime code generation part is removed from ``create_proxy_methods()``.
+Now we have static code that is perfectly consumable by all the typing tools
+and we also reduce import time a bit.
+
+A similar approach is used in Alembic, where the dynamic approach to creating
+Alembic "ops" was enhanced to statically generate a .pyi stub file for
+consumption by typing tools.
+
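+The tool is invoked with, for example::
+
+    python tools/generate_proxy_methods.py --module sqlalchemy.ext.asyncio.session
+
+or with ``--module all`` (the default) to regenerate every file; ``--stdout``
+prints the result rather than rewriting the module in place.
+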
+.. versionadded:: 2.0
+
+"""
+from __future__ import annotations
+
+from argparse import ArgumentParser
+import collections
+import importlib
+import inspect
+import os
+from pathlib import Path
+import re
+import shlex
+import shutil
+import subprocess
+import sys
+from tempfile import NamedTemporaryFile
+import textwrap
+from typing import Any
+from typing import Callable
+from typing import Dict
+from typing import Iterable
+from typing import TextIO
+from typing import Tuple
+from typing import Type
+from typing import TypeVar
+
+from sqlalchemy import util
+from sqlalchemy.util import compat
+from sqlalchemy.util import langhelpers
+from sqlalchemy.util.langhelpers import format_argspec_plus
+from sqlalchemy.util.langhelpers import inject_docstring_text
+
+is_posix = os.name == "posix"
+
+
+sys.path.append(str(Path(__file__).parent.parent))
+
+
+class _repr_sym:
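+    """A repr() wrapper that renders as a bare symbol, so that default values
+    such as util.EMPTY_DICT appear in generated code as source text rather
+    than as their repr."""
+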
+ __slots__ = ("sym",)
+
+ def __init__(self, sym: str):
+ self.sym = sym
+
+ def __repr__(self) -> str:
+ return self.sym
+
+
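+# registry of {module name: {class name: create_proxy_methods() arguments}},
+# populated by the decorator below as the target modules are imported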
+classes: collections.defaultdict[
+ str, Dict[str, Tuple[Any, ...]]
+] = collections.defaultdict(dict)
+
+_T = TypeVar("_T", bound="Any")
+
+
+def create_proxy_methods(
+ target_cls: Type[Any],
+ target_cls_sphinx_name: str,
+ proxy_cls_sphinx_name: str,
+ classmethods: Iterable[str] = (),
+ methods: Iterable[str] = (),
+ attributes: Iterable[str] = (),
+) -> Callable[[Type[_T]], Type[_T]]:
+ """A class decorator that will copy attributes to a proxy class.
+
+ The class to be instrumented must define a single accessor "_proxied".
+
+ """
+
+ def decorate(cls: Type[_T]) -> Type[_T]:
+ # collect the class as a separate step. since the decorator
+ # is called as a result of imports, the order in which classes
+ # are collected (like in asyncio) can't be well controlled. however,
+ # the proxies (specifically asyncio session and asyncio scoped_session)
+ # have to be generated in dependency order, so run them in order in a
+ # second step.
+ classes[cls.__module__][cls.__name__] = (
+ target_cls,
+ target_cls_sphinx_name,
+ proxy_cls_sphinx_name,
+ classmethods,
+ methods,
+ attributes,
+ cls,
+ )
+ return cls
+
+ return decorate
+
+
+def process_class(
+ buf: TextIO,
+ target_cls: Type[Any],
+ target_cls_sphinx_name: str,
+ proxy_cls_sphinx_name: str,
+ classmethods: Iterable[str],
+ methods: Iterable[str],
+ attributes: Iterable[str],
+ cls: Type[Any],
+):
+
+ sphinx_symbol_match = re.match(r":class:`(.+)`", target_cls_sphinx_name)
+ if not sphinx_symbol_match:
+ raise Exception(
+ f"Couldn't match sphinx class identifier from "
+            f"target_cls_sphinx_name {target_cls_sphinx_name!r}. Currently "
+ 'this program expects the form ":class:`_<prefix>.<clsname>`"'
+ )
+
+ sphinx_symbol = sphinx_symbol_match.group(1)
+
+ def instrument(buf: TextIO, name: str, clslevel: bool = False) -> None:
+ fn = getattr(target_cls, name)
+ spec = compat.inspect_getfullargspec(fn)
+
+ iscoroutine = inspect.iscoroutinefunction(fn)
+
+ if spec.defaults:
+ new_defaults = tuple(
+ _repr_sym("util.EMPTY_DICT") if df is util.EMPTY_DICT else df
+ for df in spec.defaults
+ )
+ elem = list(spec)
+ elem[3] = tuple(new_defaults)
+ spec = compat.FullArgSpec(*elem)
+
+ caller_argspec = format_argspec_plus(spec, grouped=False)
+
+ metadata = {
+ "name": fn.__name__,
+ "async": "async " if iscoroutine else "",
+ "await": "await " if iscoroutine else "",
+ "apply_pos_proxied": caller_argspec["apply_pos_proxied"],
+ "target_cls_name": target_cls.__name__,
+ "apply_kw_proxied": caller_argspec["apply_kw_proxied"],
+ "grouped_args": caller_argspec["grouped_args"],
+ "self_arg": caller_argspec["self_arg"],
+ "doc": textwrap.indent(
+ inject_docstring_text(
+ fn.__doc__,
+ textwrap.indent(
+ ".. container:: class_bases\n\n"
+ f" Proxied for the {target_cls_sphinx_name} "
+ "class on \n"
+ f" behalf of the {proxy_cls_sphinx_name} "
+ "class.",
+ " ",
+ ),
+ 1,
+ ),
+ " ",
+ ).lstrip(),
+ }
+
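+        # render the method from one of the two templates below; e.g. for an
+        # async instance method named "close" with no arguments, the
+        # non-classmethod template produces approximately (illustrative):
+        #
+        #     async def close(self):
+        #         r"""..."""
+        #
+        #         return await self._proxied.close()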
+ if clslevel:
+ code = (
+ "@classmethod\n"
+ "%(async)sdef %(name)s%(grouped_args)s:\n"
+ ' r"""%(doc)s\n """ # noqa: E501\n\n'
+ " return %(await)s%(target_cls_name)s.%(name)s(%(apply_kw_proxied)s)\n\n" # noqa: E501
+ % metadata
+ )
+ else:
+ code = (
+ "%(async)sdef %(name)s%(grouped_args)s:\n"
+ ' r"""%(doc)s\n """ # noqa: E501\n\n'
+ " return %(await)s%(self_arg)s._proxied.%(name)s(%(apply_kw_proxied)s)\n\n" # noqa: E501
+ % metadata
+ )
+
+ buf.write(textwrap.indent(code, " "))
+
+ def makeprop(buf: TextIO, name: str) -> None:
+ attr = target_cls.__dict__.get(name, None)
+
+ return_type = target_cls.__annotations__.get(name, "Any")
+ assert isinstance(return_type, str), (
+ "expected string annotations, is from __future__ "
+ "import annotations set up?"
+ )
+
+ if attr is not None:
+ if isinstance(attr, property):
+ readonly = attr.fset is None
+ elif isinstance(attr, langhelpers.generic_fn_descriptor):
+ readonly = True
+ else:
+ readonly = not hasattr(attr, "__set__")
+ doc = textwrap.indent(
+ inject_docstring_text(
+ attr.__doc__,
+ textwrap.indent(
+ ".. container:: class_bases\n\n"
+ f" Proxied for the {target_cls_sphinx_name} "
+ "class \n"
+ f" on behalf of the {proxy_cls_sphinx_name} "
+ "class.",
+ " ",
+ ),
+ 1,
+ ),
+ " ",
+ ).lstrip()
+ else:
+ readonly = False
+ doc = (
+ f"Proxy for the :attr:`{sphinx_symbol}.{name}` "
+ "attribute \n"
+ f" on behalf of the {proxy_cls_sphinx_name} "
+ "class.\n"
+ )
+
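+        # e.g. for a hypothetical attribute "info" annotated as "Any", the
+        # getter renders roughly as:
+        #
+        #     @property
+        #     def info(self) -> Any:
+        #         r"""..."""
+        #
+        #         return self._proxied.info
+        #
+        # with a corresponding setter appended when the attribute is writable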
+ code = (
+ "@property\n"
+ "def %(name)s(self) -> %(return_type)s:\n"
+ ' r"""%(doc)s\n """ # noqa: E501\n\n'
+ " return self._proxied.%(name)s\n\n"
+ ) % {"name": name, "doc": doc, "return_type": return_type}
+
+ if not readonly:
+ code += (
+ "@%(name)s.setter\n"
+ "def %(name)s(self, attr: %(return_type)s) -> None:\n"
+ " self._proxied.%(name)s = attr\n\n"
+ ) % {"name": name, "doc": doc, "return_type": return_type}
+
+ buf.write(textwrap.indent(code, " "))
+
+ for meth in methods:
+ instrument(buf, meth)
+
+ for prop in attributes:
+ makeprop(buf, prop)
+
+ for prop in classmethods:
+ instrument(buf, prop, clslevel=True)
+
+
+def process_module(modname: str, filename: str) -> str:
+
+ class_entries = classes[modname]
+
+    # use a tempfile in the same path as the module, or at least in the
+    # current working directory, so that black / zimports pick up the
+    # local pyproject.toml
+ with NamedTemporaryFile(
+ mode="w", delete=False, suffix=".py", dir=Path(filename).parent
+ ) as buf, open(filename) as orig_py:
+
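+        # scan the original module line by line; blocks delimited by
+        # "# START PROXY METHODS <ClassName>" / "# END PROXY METHODS
+        # <ClassName>" marker comments are replaced with freshly generated
+        # code, while everything outside the markers is copied through
+        # unchanged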
+ in_block = False
+ current_clsname = None
+ for line in orig_py:
+            m = re.match(r"    # START PROXY METHODS (.+)$", line)
+ if m:
+ current_clsname = m.group(1)
+ args = class_entries[current_clsname]
+ sys.stderr.write(
+ f"Generating attributes for class {current_clsname}\n"
+ )
+ in_block = True
+ buf.write(line)
+ buf.write(
+ "\n # code within this block is "
+ "**programmatically, \n"
+ " # statically generated** by"
+ " tools/generate_proxy_methods.py\n\n"
+ )
+
+ process_class(buf, *args)
+            if line.startswith(f"    # END PROXY METHODS {current_clsname}"):
+ in_block = False
+
+ if not in_block:
+ buf.write(line)
+ return buf.name
+
+
+def console_scripts(
+ path: str, options: dict, ignore_output: bool = False
+) -> None:
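+    """Run a console_scripts entrypoint, e.g. zimports or black, against
+    the given path in a subprocess."""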
+
+ entrypoint_name = options["entrypoint"]
+
+ for entry in compat.importlib_metadata_get("console_scripts"):
+ if entry.name == entrypoint_name:
+ impl = entry
+ break
+ else:
+ raise Exception(
+ f"Could not find entrypoint console_scripts.{entrypoint_name}"
+ )
+ cmdline_options_str = options.get("options", "")
+ cmdline_options_list = shlex.split(cmdline_options_str, posix=is_posix) + [
+ path
+ ]
+
+ kw = {}
+ if ignore_output:
+ kw["stdout"] = kw["stderr"] = subprocess.DEVNULL
+
+ subprocess.run(
+ [
+ sys.executable,
+ "-c",
+ "import %s; %s.%s()" % (impl.module, impl.module, impl.attr),
+ ]
+ + cmdline_options_list,
+ cwd=Path(__file__).parent.parent,
+ **kw,
+ )
+
+
+def run_module(modname: str, stdout: bool) -> None:
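+    """Import the target module, regenerate its proxy method blocks, run the
+    formatters, then print the result or write it back in place."""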
+
+ sys.stderr.write(f"importing module {modname}\n")
+ mod = importlib.import_module(modname)
+ filename = destination_path = mod.__file__
+ assert filename is not None
+
+ tempfile = process_module(modname, filename)
+
+ ignore_output = stdout
+
+ console_scripts(
+ str(tempfile),
+ {"entrypoint": "zimports"},
+ ignore_output=ignore_output,
+ )
+
+ console_scripts(
+ str(tempfile),
+ {"entrypoint": "black"},
+ ignore_output=ignore_output,
+ )
+
+ if stdout:
+ with open(tempfile) as tf:
+ print(tf.read())
+ os.unlink(tempfile)
+ else:
+ sys.stderr.write(f"Writing {destination_path}...\n")
+ shutil.move(tempfile, destination_path)
+
+
+def main(args):
+ from sqlalchemy import util
+ from sqlalchemy.util import langhelpers
+
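+    # swap in this script's collecting decorator before the target modules are
+    # imported, so that each @create_proxy_methods() call registers its
+    # arguments in the ``classes`` registry instead of doing anything at
+    # runtime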
+ util.create_proxy_methods = (
+ langhelpers.create_proxy_methods
+ ) = create_proxy_methods
+
+ for entry in entries:
+ if args.module in {"all", entry}:
+ run_module(entry, args.stdout)
+
+
+entries = [
+ "sqlalchemy.orm.scoping",
+ "sqlalchemy.ext.asyncio.engine",
+ "sqlalchemy.ext.asyncio.session",
+ "sqlalchemy.ext.asyncio.scoping",
+]
+
+if __name__ == "__main__":
+ parser = ArgumentParser()
+ parser.add_argument(
+ "--module",
+ choices=entries + ["all"],
+ default="all",
+ help="Which file to generate. Default is to regenerate all files",
+ )
+ parser.add_argument(
+ "--stdout",
+ action="store_true",
+ help="Write to stdout instead of saving to file",
+ )
+ args = parser.parse_args()
+ main(args)