summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/base.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/sql/base.py')
-rw-r--r--lib/sqlalchemy/sql/base.py49
1 files changed, 39 insertions, 10 deletions
diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py
index 187651435..30d589258 100644
--- a/lib/sqlalchemy/sql/base.py
+++ b/lib/sqlalchemy/sql/base.py
@@ -16,6 +16,7 @@ import itertools
from itertools import zip_longest
import operator
import re
+import typing
from . import roles
from . import visitors
@@ -29,6 +30,7 @@ from .. import exc
from .. import util
from ..util import HasMemoized
from ..util import hybridmethod
+from ..util import typing as compat_typing
try:
from sqlalchemy.cyextension.util import prefix_anon_map # noqa
@@ -42,6 +44,10 @@ type_api = None
NO_ARG = util.symbol("NO_ARG")
+# if I use sqlalchemy.util.typing, which has the exact same
+# symbols, mypy reports: "error: _Fn? not callable"
+_Fn = typing.TypeVar("_Fn", bound=typing.Callable)
+
class Immutable:
"""mark a ClauseElement as 'immutable' when expressions are cloned."""
@@ -101,7 +107,16 @@ def _select_iterables(elements):
)
-def _generative(fn):
+_Self = typing.TypeVar("_Self", bound="_GenerativeType")
+_Args = compat_typing.ParamSpec("_Args")
+
+
+class _GenerativeType(compat_typing.Protocol):
+ def _generate(self: "_Self") -> "_Self":
+ ...
+
+
+def _generative(fn: _Fn) -> _Fn:
"""non-caching _generative() decorator.
This is basically the legacy decorator that copies the object and
@@ -110,14 +125,14 @@ def _generative(fn):
"""
@util.decorator
- def _generative(fn, self, *args, **kw):
+ def _generative(
+ fn: _Fn, self: _Self, *args: _Args.args, **kw: _Args.kwargs
+ ) -> _Self:
"""Mark a method as generative."""
self = self._generate()
x = fn(self, *args, **kw)
- assert (
- x is None or x is self
- ), "generative methods must return None or self"
+ assert x is self, "generative methods must return self"
return self
decorated = _generative(fn)
@@ -788,6 +803,9 @@ class ExecutableOption(HasCopyInternals):
return c
+SelfExecutable = typing.TypeVar("SelfExecutable", bound="Executable")
+
+
class Executable(roles.StatementRole, Generative):
"""Mark a :class:`_expression.ClauseElement` as supporting execution.
@@ -824,7 +842,7 @@ class Executable(roles.StatementRole, Generative):
return self.__visit_name__
@_generative
- def options(self, *options):
+ def options(self: SelfExecutable, *options) -> SelfExecutable:
"""Apply options to this statement.
In the general sense, options are any kind of Python object
@@ -857,9 +875,12 @@ class Executable(roles.StatementRole, Generative):
coercions.expect(roles.ExecutableOptionRole, opt)
for opt in options
)
+ return self
@_generative
- def _set_compile_options(self, compile_options):
+ def _set_compile_options(
+ self: SelfExecutable, compile_options
+ ) -> SelfExecutable:
"""Assign the compile options to a new value.
:param compile_options: appropriate CacheableOptions structure
@@ -867,15 +888,21 @@ class Executable(roles.StatementRole, Generative):
"""
self._compile_options = compile_options
+ return self
@_generative
- def _update_compile_options(self, options):
+ def _update_compile_options(
+ self: SelfExecutable, options
+ ) -> SelfExecutable:
"""update the _compile_options with new keys."""
self._compile_options += options
+ return self
@_generative
- def _add_context_option(self, callable_, cache_args):
+ def _add_context_option(
+ self: SelfExecutable, callable_, cache_args
+ ) -> SelfExecutable:
"""Add a context option to this statement.
These are callable functions that will
@@ -887,9 +914,10 @@ class Executable(roles.StatementRole, Generative):
"""
self._with_context_options += ((callable_, cache_args),)
+ return self
@_generative
- def execution_options(self, **kw):
+ def execution_options(self: SelfExecutable, **kw) -> SelfExecutable:
"""Set non-SQL options for the statement which take effect during
execution.
@@ -1004,6 +1032,7 @@ class Executable(roles.StatementRole, Generative):
"on Connection.execution_options(), not per statement."
)
self._execution_options = self._execution_options.union(kw)
+ return self
def get_execution_options(self):
"""Get the non-SQL options which will take effect during execution.