diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2023-01-19 12:09:29 -0500 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2023-01-19 17:04:59 -0500 |
commit | be0831fea83247451628bc6643d5b130c63f6011 (patch) | |
tree | 97036e708f073e91d7de6fc490cba06d14c1b142 /lib/sqlalchemy/sql/lambdas.py | |
parent | e82a5f19e1606500ad4bf6a456c2558d74df24bf (diff) | |
download | sqlalchemy-be0831fea83247451628bc6643d5b130c63f6011.tar.gz |
implement basic typing for lambda elements
These weren't working at all, so fixed things up and
added a test suite. Keeping things very basic with Any
returns etc. as having more specific return types
starts making it too cumbersome to write end-user code.
Corrected the type passed for "lambda statements" so that a plain lambda is
accepted by mypy, pyright, others without any errors about argument types.
Additionally implemented typing for more of the public API for lambda
statements and ensured :class:`.StatementLambdaElement` is part of the
:class:`.Executable` hierarchy so it's typed as accepted by
:meth:`_engine.Connection.execute`.
Fixes: #9120
Change-Id: Ia7fa34e5b6e43fba02c8f94ccc256f3a68a1f445
Diffstat (limited to 'lib/sqlalchemy/sql/lambdas.py')
-rw-r--r-- | lib/sqlalchemy/sql/lambdas.py | 104 |
1 files changed, 73 insertions, 31 deletions
diff --git a/lib/sqlalchemy/sql/lambdas.py b/lib/sqlalchemy/sql/lambdas.py index b153ba999..d737b1bcb 100644 --- a/lib/sqlalchemy/sql/lambdas.py +++ b/lib/sqlalchemy/sql/lambdas.py @@ -18,13 +18,13 @@ from types import CodeType from typing import Any from typing import Callable from typing import cast -from typing import Iterable from typing import List from typing import MutableMapping from typing import Optional from typing import Tuple from typing import Type from typing import TYPE_CHECKING +from typing import TypeVar from typing import Union import weakref @@ -43,7 +43,6 @@ from .. import exc from .. import inspection from .. import util from ..util.typing import Literal -from ..util.typing import Protocol from ..util.typing import Self if TYPE_CHECKING: @@ -60,12 +59,14 @@ _BoundParameterGetter = Callable[..., Any] _closure_per_cache_key: _LambdaCacheType = util.LRUCache(1000) -class _LambdaType(Protocol): - __code__: CodeType - __closure__: Iterable[Tuple[Any, Any]] +_LambdaType = Callable[[], Any] - def __call__(self, *arg: Any, **kw: Any) -> ClauseElement: - ... +_AnyLambdaType = Callable[..., Any] + +_StmtLambdaType = Callable[[], Any] + +_E = TypeVar("_E", bound=Executable) +_StmtLambdaElementType = Callable[[_E], Any] class LambdaOptions(Options): @@ -78,7 +79,7 @@ class LambdaOptions(Options): def lambda_stmt( - lmb: _LambdaType, + lmb: _StmtLambdaType, enable_tracking: bool = True, track_closure_variables: bool = True, track_on: Optional[object] = None, @@ -185,7 +186,7 @@ class LambdaElement(elements.ClauseElement): closure_cache_key: Union[Tuple[Any, ...], Literal[CacheConst.NO_CACHE]] role: Type[SQLRole] _rec: Union[AnalyzedFunction, NonAnalyzedFunction] - fn: _LambdaType + fn: _AnyLambdaType tracker_key: Tuple[CodeType, ...] def __repr__(self): @@ -416,8 +417,8 @@ class LambdaElement(elements.ClauseElement): bindparams.extend(self._resolved_bindparams) return cache_key - def _invoke_user_fn(self, fn: _LambdaType, *arg: Any) -> ClauseElement: - return fn() + def _invoke_user_fn(self, fn: _AnyLambdaType, *arg: Any) -> ClauseElement: + return fn() # type: ignore[no-any-return] class DeferredLambdaElement(LambdaElement): @@ -494,7 +495,9 @@ class DeferredLambdaElement(LambdaElement): self._transforms += (deferred_copy_internals,) -class StatementLambdaElement(roles.AllowsLambdaRole, LambdaElement): +class StatementLambdaElement( + roles.AllowsLambdaRole, LambdaElement, Executable +): """Represent a composable SQL statement as a :class:`_sql.LambdaElement`. The :class:`_sql.StatementLambdaElement` is constructed using the @@ -520,17 +523,30 @@ class StatementLambdaElement(roles.AllowsLambdaRole, LambdaElement): """ - def __add__(self, other): + if TYPE_CHECKING: + + def __init__( + self, + fn: _StmtLambdaType, + role: Type[SQLRole], + opts: Union[Type[LambdaOptions], LambdaOptions] = LambdaOptions, + apply_propagate_attrs: Optional[ClauseElement] = None, + ): + ... + + def __add__( + self, other: _StmtLambdaElementType[Any] + ) -> StatementLambdaElement: return self.add_criteria(other) def add_criteria( self, - other, - enable_tracking=True, - track_on=None, - track_closure_variables=True, - track_bound_values=True, - ): + other: _StmtLambdaElementType[Any], + enable_tracking: bool = True, + track_on: Optional[Any] = None, + track_closure_variables: bool = True, + track_bound_values: bool = True, + ) -> StatementLambdaElement: """Add new criteria to this :class:`_sql.StatementLambdaElement`. E.g.:: @@ -588,24 +604,50 @@ class StatementLambdaElement(roles.AllowsLambdaRole, LambdaElement): raise exc.ObjectNotExecutableError(self) @property + def _proxied(self) -> Any: + return self._rec_expected_expr + + @property def _with_options(self): - if TYPE_CHECKING: - assert isinstance(self._rec.expected_expr, Executable) - return self._rec.expected_expr._with_options + return self._proxied._with_options @property def _effective_plugin_target(self): - if TYPE_CHECKING: - assert isinstance(self._rec.expected_expr, Executable) - return self._rec.expected_expr._effective_plugin_target + return self._proxied._effective_plugin_target @property def _execution_options(self): - if TYPE_CHECKING: - assert isinstance(self._rec.expected_expr, Executable) - return self._rec.expected_expr._execution_options + return self._proxied._execution_options + + @property + def _all_selected_columns(self): + return self._proxied._all_selected_columns + + @property + def is_select(self): + return self._proxied.is_select + + @property + def is_update(self): + return self._proxied.is_update + + @property + def is_insert(self): + return self._proxied.is_insert - def spoil(self): + @property + def is_text(self): + return self._proxied.is_text + + @property + def is_delete(self): + return self._proxied.is_delete + + @property + def is_dml(self): + return self._proxied.is_dml + + def spoil(self) -> NullLambdaStatement: """Return a new :class:`.StatementLambdaElement` that will run all lambdas unconditionally each time. @@ -667,12 +709,12 @@ class LinkedLambdaElement(StatementLambdaElement): def __init__( self, - fn: _LambdaType, + fn: _StmtLambdaElementType[Any], parent_lambda: StatementLambdaElement, opts: Union[Type[LambdaOptions], LambdaOptions], ): self.opts = opts - self.fn = fn + self.fn = fn # type: ignore[assignment] self.parent_lambda = parent_lambda self.tracker_key = parent_lambda.tracker_key + (fn.__code__,) |