summaryrefslogtreecommitdiff
path: root/tools/generate_sql_functions.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2023-01-20 15:17:44 -0500
committerMike Bayer <mike_mp@zzzcomputing.com>2023-01-23 11:17:38 -0500
commit6499098e36497d15d5972696983ce0ae4cc99409 (patch)
tree776264381311877c540f26656221ff637a0c5f8e /tools/generate_sql_functions.py
parent2f91dd79310657814ad28b6ef64f91fff7a007c9 (diff)
downloadsqlalchemy-6499098e36497d15d5972696983ce0ae4cc99409.tar.gz
generate stubs for func known functions
Added typing for the built-in generic functions that are available from the :data:`_sql.func` namespace, which accept a particular set of arguments and return a particular type, such as for :class:`_sql.count`, :class:`_sql.current_timestamp`, etc. Fixes: #9129 Change-Id: I1a2e0dcca3048c77e84dc786843a7df05c457dfa
Diffstat (limited to 'tools/generate_sql_functions.py')
-rw-r--r--tools/generate_sql_functions.py160
1 files changed, 160 insertions, 0 deletions
diff --git a/tools/generate_sql_functions.py b/tools/generate_sql_functions.py
new file mode 100644
index 000000000..d207c62bc
--- /dev/null
+++ b/tools/generate_sql_functions.py
@@ -0,0 +1,160 @@
+"""Generate inline stubs for generic functions on func
+
+"""
+# mypy: ignore-errors
+
+from __future__ import annotations
+
+from decimal import Decimal
+import inspect
+import re
+from tempfile import NamedTemporaryFile
+import textwrap
+from typing import Any
+
+from sqlalchemy.sql.functions import _registry
+from sqlalchemy.types import TypeEngine
+from sqlalchemy.util.tool_support import code_writer_cmd
+
+
+def _fns_in_deterministic_order():
+ reg = _registry["_default"]
+ for key in sorted(reg):
+ yield key, reg[key]
+
+
+def process_functions(filename: str, cmd: code_writer_cmd) -> str:
+
+ with NamedTemporaryFile(
+ mode="w",
+ delete=False,
+ suffix=".py",
+ ) as buf, open(filename) as orig_py:
+ indent = ""
+ in_block = False
+
+ for line in orig_py:
+ m = re.match(
+ r"^( *)# START GENERATED FUNCTION ACCESSORS",
+ line,
+ )
+ if m:
+ in_block = True
+ buf.write(line)
+ indent = m.group(1)
+ buf.write(
+ textwrap.indent(
+ """
+# code within this block is **programmatically,
+# statically generated** by tools/generate_sql_functions.py
+""",
+ indent,
+ )
+ )
+
+ builtins = set(dir(__builtins__))
+ for key, fn_class in _fns_in_deterministic_order():
+ is_reserved_word = key in builtins
+
+ guess_its_generic = bool(fn_class.__parameters__)
+
+ buf.write(
+ textwrap.indent(
+ f"""
+@property
+def {key}(self) -> Type[{fn_class.__name__}{
+ '[Any]' if guess_its_generic else ''
+}]:{
+ ' # noqa: A001' if is_reserved_word else ''
+}
+ ...
+
+""",
+ indent,
+ )
+ )
+
+ m = re.match(
+ r"^( *)# START GENERATED FUNCTION TYPING TESTS",
+ line,
+ )
+ if m:
+ in_block = True
+ buf.write(line)
+ indent = m.group(1)
+
+ buf.write(
+ textwrap.indent(
+ """
+# code within this block is **programmatically,
+# statically generated** by tools/generate_sql_functions.py
+""",
+ indent,
+ )
+ )
+
+ count = 0
+ for key, fn_class in _fns_in_deterministic_order():
+ if hasattr(fn_class, "type") and isinstance(
+ fn_class.type, TypeEngine
+ ):
+ python_type = fn_class.type.python_type
+
+ # TODO: numeric types don't seem to be coming out
+ # at the moment, because Numeric is typed generically
+ # in that it can return Decimal or float. We would need
+ # to further break out Numeric / Float into types
+ # that type out as returning an exact Decimal or float
+ if python_type is Decimal:
+ python_type = Any
+ python_expr = f"{python_type.__name__}"
+ else:
+ python_expr = rf"Tuple\[.*{python_type.__name__}\]"
+ argspec = inspect.getfullargspec(fn_class)
+ args = ", ".join(
+ 'column("x")' for elem in argspec.args[1:]
+ )
+ count += 1
+
+ buf.write(
+ textwrap.indent(
+ rf"""
+stmt{count} = select(func.{key}({args}))
+
+# EXPECTED_RE_TYPE: .*Select\[{python_expr}\]
+reveal_type(stmt{count})
+
+""",
+ indent,
+ )
+ )
+
+ if in_block and line.startswith(
+ f"{indent}# END GENERATED FUNCTION"
+ ):
+ in_block = False
+
+ if not in_block:
+ buf.write(line)
+ return buf.name
+
+
+def main(cmd: code_writer_cmd) -> None:
+ for path in [functions_py, test_functions_py]:
+ destination_path = path
+ tempfile = process_functions(destination_path, cmd)
+ cmd.run_zimports(tempfile)
+ cmd.run_black(tempfile)
+ cmd.write_output_file_from_tempfile(tempfile, destination_path)
+
+
+functions_py = "lib/sqlalchemy/sql/functions.py"
+test_functions_py = "test/ext/mypy/plain_files/functions.py"
+
+
+if __name__ == "__main__":
+
+ cmd = code_writer_cmd(__file__)
+
+ with cmd.run_program():
+ main(cmd)