summaryrefslogtreecommitdiff
path: root/tools/generate_proxy_methods.py
diff options
context:
space:
mode:
Diffstat (limited to 'tools/generate_proxy_methods.py')
-rw-r--r--tools/generate_proxy_methods.py93
1 files changed, 55 insertions, 38 deletions
diff --git a/tools/generate_proxy_methods.py b/tools/generate_proxy_methods.py
index ffc470972..91a891882 100644
--- a/tools/generate_proxy_methods.py
+++ b/tools/generate_proxy_methods.py
@@ -31,6 +31,12 @@ A similar approach is used in Alembic where a dynamic approach towards creating
alembic "ops" was enhanced to generate a .pyi stubs file statically for
consumption by typing tools.
+Note that the usual OO approach of having a common interface class with
+concrete subtypes doesn't really solve any problems here; the concrete subtypes
+must still list out all methods, arguments, typing annotations, and docstrings,
+all of which is copied by this script rather than requiring it all be
+typed by hand.
+
.. versionadded:: 2.0
"""
@@ -43,9 +49,7 @@ import inspect
import os
from pathlib import Path
import re
-import shlex
import shutil
-import subprocess
import sys
from tempfile import NamedTemporaryFile
import textwrap
@@ -61,6 +65,7 @@ from typing import TypeVar
from sqlalchemy import util
from sqlalchemy.util import compat
from sqlalchemy.util import langhelpers
+from sqlalchemy.util.langhelpers import console_scripts
from sqlalchemy.util.langhelpers import format_argspec_plus
from sqlalchemy.util.langhelpers import inject_docstring_text
@@ -122,6 +127,47 @@ def create_proxy_methods(
return decorate
+def _grab_overloads(fn):
+ """grab @overload entries for a function, assuming black-formatted
+ code ;) so that we can do a simple regex
+
+ """
+
+ # functions that use @util.deprecated and whatnot will have a string
+ # generated fn. we can look at __wrapped__ but these functions don't
+ # have any overloads in any case right now so skip
+ if fn.__code__.co_filename == "<string>":
+ return []
+
+ with open(fn.__code__.co_filename) as f:
+ lines = [l for i, l in zip(range(fn.__code__.co_firstlineno), f)]
+
+ lines.reverse()
+
+ output = []
+
+ current_ov = []
+ for line in lines[1:]:
+ current_ov.append(line)
+ outside_block_match = re.match(r"^\w", line)
+ if outside_block_match:
+ current_ov[:] = []
+ break
+
+ fn_match = re.match(rf"^ (?:async )?def (.*)\($", line)
+ if fn_match and fn_match.group(1) != fn.__name__:
+ current_ov[:] = []
+ break
+
+ ov_match = re.match(r"^ @overload$", line)
+ if ov_match:
+ output.append("".join(reversed(current_ov)))
+ current_ov[:] = []
+
+ output.reverse()
+ return output
+
+
def process_class(
buf: TextIO,
target_cls: Type[Any],
@@ -145,6 +191,12 @@ def process_class(
def instrument(buf: TextIO, name: str, clslevel: bool = False) -> None:
fn = getattr(target_cls, name)
+
+ overloads = _grab_overloads(fn)
+
+ for overload in overloads:
+ buf.write(overload)
+
spec = compat.inspect_getfullargspec(fn)
iscoroutine = inspect.iscoroutinefunction(fn)
@@ -311,7 +363,7 @@ def process_module(modname: str, filename: str) -> str:
"\n # code within this block is "
"**programmatically, \n"
" # statically generated** by"
- " tools/generate_proxy_methods.py\n\n"
+ f" tools/{os.path.basename(__file__)}\n\n"
)
process_class(buf, *args)
@@ -323,41 +375,6 @@ def process_module(modname: str, filename: str) -> str:
return buf.name
-def console_scripts(
- path: str, options: dict, ignore_output: bool = False
-) -> None:
-
- entrypoint_name = options["entrypoint"]
-
- for entry in compat.importlib_metadata_get("console_scripts"):
- if entry.name == entrypoint_name:
- impl = entry
- break
- else:
- raise Exception(
- f"Could not find entrypoint console_scripts.{entrypoint_name}"
- )
- cmdline_options_str = options.get("options", "")
- cmdline_options_list = shlex.split(cmdline_options_str, posix=is_posix) + [
- path
- ]
-
- kw = {}
- if ignore_output:
- kw["stdout"] = kw["stderr"] = subprocess.DEVNULL
-
- subprocess.run(
- [
- sys.executable,
- "-c",
- "import %s; %s.%s()" % (impl.module, impl.module, impl.attr),
- ]
- + cmdline_options_list,
- cwd=Path(__file__).parent.parent,
- **kw,
- )
-
-
def run_module(modname, stdout):
sys.stderr.write(f"importing module {modname}\n")