summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/lambdas.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2023-01-19 12:09:29 -0500
committerMike Bayer <mike_mp@zzzcomputing.com>2023-01-19 17:04:59 -0500
commitbe0831fea83247451628bc6643d5b130c63f6011 (patch)
tree97036e708f073e91d7de6fc490cba06d14c1b142 /lib/sqlalchemy/sql/lambdas.py
parente82a5f19e1606500ad4bf6a456c2558d74df24bf (diff)
downloadsqlalchemy-be0831fea83247451628bc6643d5b130c63f6011.tar.gz
implement basic typing for lambda elements
These weren't working at all, so fixed things up and added a test suite. Keeping things very basic with Any returns etc. as having more specific return types starts making it too cumbersome to write end-user code. Corrected the type passed for "lambda statements" so that a plain lambda is accepted by mypy, pyright, others without any errors about argument types. Additionally implemented typing for more of the public API for lambda statements and ensured :class:`.StatementLambdaElement` is part of the :class:`.Executable` hierarchy so it's typed as accepted by :meth:`_engine.Connection.execute`. Fixes: #9120 Change-Id: Ia7fa34e5b6e43fba02c8f94ccc256f3a68a1f445
Diffstat (limited to 'lib/sqlalchemy/sql/lambdas.py')
-rw-r--r--lib/sqlalchemy/sql/lambdas.py104
1 files changed, 73 insertions, 31 deletions
diff --git a/lib/sqlalchemy/sql/lambdas.py b/lib/sqlalchemy/sql/lambdas.py
index b153ba999..d737b1bcb 100644
--- a/lib/sqlalchemy/sql/lambdas.py
+++ b/lib/sqlalchemy/sql/lambdas.py
@@ -18,13 +18,13 @@ from types import CodeType
from typing import Any
from typing import Callable
from typing import cast
-from typing import Iterable
from typing import List
from typing import MutableMapping
from typing import Optional
from typing import Tuple
from typing import Type
from typing import TYPE_CHECKING
+from typing import TypeVar
from typing import Union
import weakref
@@ -43,7 +43,6 @@ from .. import exc
from .. import inspection
from .. import util
from ..util.typing import Literal
-from ..util.typing import Protocol
from ..util.typing import Self
if TYPE_CHECKING:
@@ -60,12 +59,14 @@ _BoundParameterGetter = Callable[..., Any]
_closure_per_cache_key: _LambdaCacheType = util.LRUCache(1000)
-class _LambdaType(Protocol):
- __code__: CodeType
- __closure__: Iterable[Tuple[Any, Any]]
+_LambdaType = Callable[[], Any]
- def __call__(self, *arg: Any, **kw: Any) -> ClauseElement:
- ...
+_AnyLambdaType = Callable[..., Any]
+
+_StmtLambdaType = Callable[[], Any]
+
+_E = TypeVar("_E", bound=Executable)
+_StmtLambdaElementType = Callable[[_E], Any]
class LambdaOptions(Options):
@@ -78,7 +79,7 @@ class LambdaOptions(Options):
def lambda_stmt(
- lmb: _LambdaType,
+ lmb: _StmtLambdaType,
enable_tracking: bool = True,
track_closure_variables: bool = True,
track_on: Optional[object] = None,
@@ -185,7 +186,7 @@ class LambdaElement(elements.ClauseElement):
closure_cache_key: Union[Tuple[Any, ...], Literal[CacheConst.NO_CACHE]]
role: Type[SQLRole]
_rec: Union[AnalyzedFunction, NonAnalyzedFunction]
- fn: _LambdaType
+ fn: _AnyLambdaType
tracker_key: Tuple[CodeType, ...]
def __repr__(self):
@@ -416,8 +417,8 @@ class LambdaElement(elements.ClauseElement):
bindparams.extend(self._resolved_bindparams)
return cache_key
- def _invoke_user_fn(self, fn: _LambdaType, *arg: Any) -> ClauseElement:
- return fn()
+ def _invoke_user_fn(self, fn: _AnyLambdaType, *arg: Any) -> ClauseElement:
+ return fn() # type: ignore[no-any-return]
class DeferredLambdaElement(LambdaElement):
@@ -494,7 +495,9 @@ class DeferredLambdaElement(LambdaElement):
self._transforms += (deferred_copy_internals,)
-class StatementLambdaElement(roles.AllowsLambdaRole, LambdaElement):
+class StatementLambdaElement(
+ roles.AllowsLambdaRole, LambdaElement, Executable
+):
"""Represent a composable SQL statement as a :class:`_sql.LambdaElement`.
The :class:`_sql.StatementLambdaElement` is constructed using the
@@ -520,17 +523,30 @@ class StatementLambdaElement(roles.AllowsLambdaRole, LambdaElement):
"""
- def __add__(self, other):
+ if TYPE_CHECKING:
+
+ def __init__(
+ self,
+ fn: _StmtLambdaType,
+ role: Type[SQLRole],
+ opts: Union[Type[LambdaOptions], LambdaOptions] = LambdaOptions,
+ apply_propagate_attrs: Optional[ClauseElement] = None,
+ ):
+ ...
+
+ def __add__(
+ self, other: _StmtLambdaElementType[Any]
+ ) -> StatementLambdaElement:
return self.add_criteria(other)
def add_criteria(
self,
- other,
- enable_tracking=True,
- track_on=None,
- track_closure_variables=True,
- track_bound_values=True,
- ):
+ other: _StmtLambdaElementType[Any],
+ enable_tracking: bool = True,
+ track_on: Optional[Any] = None,
+ track_closure_variables: bool = True,
+ track_bound_values: bool = True,
+ ) -> StatementLambdaElement:
"""Add new criteria to this :class:`_sql.StatementLambdaElement`.
E.g.::
@@ -588,24 +604,50 @@ class StatementLambdaElement(roles.AllowsLambdaRole, LambdaElement):
raise exc.ObjectNotExecutableError(self)
@property
+ def _proxied(self) -> Any:
+ return self._rec_expected_expr
+
+ @property
def _with_options(self):
- if TYPE_CHECKING:
- assert isinstance(self._rec.expected_expr, Executable)
- return self._rec.expected_expr._with_options
+ return self._proxied._with_options
@property
def _effective_plugin_target(self):
- if TYPE_CHECKING:
- assert isinstance(self._rec.expected_expr, Executable)
- return self._rec.expected_expr._effective_plugin_target
+ return self._proxied._effective_plugin_target
@property
def _execution_options(self):
- if TYPE_CHECKING:
- assert isinstance(self._rec.expected_expr, Executable)
- return self._rec.expected_expr._execution_options
+ return self._proxied._execution_options
+
+ @property
+ def _all_selected_columns(self):
+ return self._proxied._all_selected_columns
+
+ @property
+ def is_select(self):
+ return self._proxied.is_select
+
+ @property
+ def is_update(self):
+ return self._proxied.is_update
+
+ @property
+ def is_insert(self):
+ return self._proxied.is_insert
- def spoil(self):
+ @property
+ def is_text(self):
+ return self._proxied.is_text
+
+ @property
+ def is_delete(self):
+ return self._proxied.is_delete
+
+ @property
+ def is_dml(self):
+ return self._proxied.is_dml
+
+ def spoil(self) -> NullLambdaStatement:
"""Return a new :class:`.StatementLambdaElement` that will run
all lambdas unconditionally each time.
@@ -667,12 +709,12 @@ class LinkedLambdaElement(StatementLambdaElement):
def __init__(
self,
- fn: _LambdaType,
+ fn: _StmtLambdaElementType[Any],
parent_lambda: StatementLambdaElement,
opts: Union[Type[LambdaOptions], LambdaOptions],
):
self.opts = opts
- self.fn = fn
+ self.fn = fn # type: ignore[assignment]
self.parent_lambda = parent_lambda
self.tracker_key = parent_lambda.tracker_key + (fn.__code__,)