diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2023-01-20 15:17:44 -0500 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2023-01-23 11:17:38 -0500 |
commit | 6499098e36497d15d5972696983ce0ae4cc99409 (patch) | |
tree | 776264381311877c540f26656221ff637a0c5f8e /tools/generate_sql_functions.py | |
parent | 2f91dd79310657814ad28b6ef64f91fff7a007c9 (diff) | |
download | sqlalchemy-6499098e36497d15d5972696983ce0ae4cc99409.tar.gz |
generate stubs for func known functions
Added typing for the built-in generic functions that are available from the
:data:`_sql.func` namespace, which accept a particular set of arguments and
return a particular type, such as for :class:`_sql.count`,
:class:`_sql.current_timestamp`, etc.
Fixes: #9129
Change-Id: I1a2e0dcca3048c77e84dc786843a7df05c457dfa
Diffstat (limited to 'tools/generate_sql_functions.py')
-rw-r--r-- | tools/generate_sql_functions.py | 160 |
1 files changed, 160 insertions, 0 deletions
diff --git a/tools/generate_sql_functions.py b/tools/generate_sql_functions.py new file mode 100644 index 000000000..d207c62bc --- /dev/null +++ b/tools/generate_sql_functions.py @@ -0,0 +1,160 @@ +"""Generate inline stubs for generic functions on func + +""" +# mypy: ignore-errors + +from __future__ import annotations + +from decimal import Decimal +import inspect +import re +from tempfile import NamedTemporaryFile +import textwrap +from typing import Any + +from sqlalchemy.sql.functions import _registry +from sqlalchemy.types import TypeEngine +from sqlalchemy.util.tool_support import code_writer_cmd + + +def _fns_in_deterministic_order(): + reg = _registry["_default"] + for key in sorted(reg): + yield key, reg[key] + + +def process_functions(filename: str, cmd: code_writer_cmd) -> str: + + with NamedTemporaryFile( + mode="w", + delete=False, + suffix=".py", + ) as buf, open(filename) as orig_py: + indent = "" + in_block = False + + for line in orig_py: + m = re.match( + r"^( *)# START GENERATED FUNCTION ACCESSORS", + line, + ) + if m: + in_block = True + buf.write(line) + indent = m.group(1) + buf.write( + textwrap.indent( + """ +# code within this block is **programmatically, +# statically generated** by tools/generate_sql_functions.py +""", + indent, + ) + ) + + builtins = set(dir(__builtins__)) + for key, fn_class in _fns_in_deterministic_order(): + is_reserved_word = key in builtins + + guess_its_generic = bool(fn_class.__parameters__) + + buf.write( + textwrap.indent( + f""" +@property +def {key}(self) -> Type[{fn_class.__name__}{ + '[Any]' if guess_its_generic else '' +}]:{ + ' # noqa: A001' if is_reserved_word else '' +} + ... + +""", + indent, + ) + ) + + m = re.match( + r"^( *)# START GENERATED FUNCTION TYPING TESTS", + line, + ) + if m: + in_block = True + buf.write(line) + indent = m.group(1) + + buf.write( + textwrap.indent( + """ +# code within this block is **programmatically, +# statically generated** by tools/generate_sql_functions.py +""", + indent, + ) + ) + + count = 0 + for key, fn_class in _fns_in_deterministic_order(): + if hasattr(fn_class, "type") and isinstance( + fn_class.type, TypeEngine + ): + python_type = fn_class.type.python_type + + # TODO: numeric types don't seem to be coming out + # at the moment, because Numeric is typed generically + # in that it can return Decimal or float. We would need + # to further break out Numeric / Float into types + # that type out as returning an exact Decimal or float + if python_type is Decimal: + python_type = Any + python_expr = f"{python_type.__name__}" + else: + python_expr = rf"Tuple\[.*{python_type.__name__}\]" + argspec = inspect.getfullargspec(fn_class) + args = ", ".join( + 'column("x")' for elem in argspec.args[1:] + ) + count += 1 + + buf.write( + textwrap.indent( + rf""" +stmt{count} = select(func.{key}({args})) + +# EXPECTED_RE_TYPE: .*Select\[{python_expr}\] +reveal_type(stmt{count}) + +""", + indent, + ) + ) + + if in_block and line.startswith( + f"{indent}# END GENERATED FUNCTION" + ): + in_block = False + + if not in_block: + buf.write(line) + return buf.name + + +def main(cmd: code_writer_cmd) -> None: + for path in [functions_py, test_functions_py]: + destination_path = path + tempfile = process_functions(destination_path, cmd) + cmd.run_zimports(tempfile) + cmd.run_black(tempfile) + cmd.write_output_file_from_tempfile(tempfile, destination_path) + + +functions_py = "lib/sqlalchemy/sql/functions.py" +test_functions_py = "test/ext/mypy/plain_files/functions.py" + + +if __name__ == "__main__": + + cmd = code_writer_cmd(__file__) + + with cmd.run_program(): + main(cmd) |