diff options
Diffstat (limited to 'lib/sqlalchemy/sql/base.py')
-rw-r--r-- | lib/sqlalchemy/sql/base.py | 49 |
1 files changed, 39 insertions, 10 deletions
diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index 187651435..30d589258 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -16,6 +16,7 @@ import itertools from itertools import zip_longest import operator import re +import typing from . import roles from . import visitors @@ -29,6 +30,7 @@ from .. import exc from .. import util from ..util import HasMemoized from ..util import hybridmethod +from ..util import typing as compat_typing try: from sqlalchemy.cyextension.util import prefix_anon_map # noqa @@ -42,6 +44,10 @@ type_api = None NO_ARG = util.symbol("NO_ARG") +# if I use sqlalchemy.util.typing, which has the exact same +# symbols, mypy reports: "error: _Fn? not callable" +_Fn = typing.TypeVar("_Fn", bound=typing.Callable) + class Immutable: """mark a ClauseElement as 'immutable' when expressions are cloned.""" @@ -101,7 +107,16 @@ def _select_iterables(elements): ) -def _generative(fn): +_Self = typing.TypeVar("_Self", bound="_GenerativeType") +_Args = compat_typing.ParamSpec("_Args") + + +class _GenerativeType(compat_typing.Protocol): + def _generate(self: "_Self") -> "_Self": + ... + + +def _generative(fn: _Fn) -> _Fn: """non-caching _generative() decorator. This is basically the legacy decorator that copies the object and @@ -110,14 +125,14 @@ def _generative(fn): """ @util.decorator - def _generative(fn, self, *args, **kw): + def _generative( + fn: _Fn, self: _Self, *args: _Args.args, **kw: _Args.kwargs + ) -> _Self: """Mark a method as generative.""" self = self._generate() x = fn(self, *args, **kw) - assert ( - x is None or x is self - ), "generative methods must return None or self" + assert x is self, "generative methods must return self" return self decorated = _generative(fn) @@ -788,6 +803,9 @@ class ExecutableOption(HasCopyInternals): return c +SelfExecutable = typing.TypeVar("SelfExecutable", bound="Executable") + + class Executable(roles.StatementRole, Generative): """Mark a :class:`_expression.ClauseElement` as supporting execution. @@ -824,7 +842,7 @@ class Executable(roles.StatementRole, Generative): return self.__visit_name__ @_generative - def options(self, *options): + def options(self: SelfExecutable, *options) -> SelfExecutable: """Apply options to this statement. In the general sense, options are any kind of Python object @@ -857,9 +875,12 @@ class Executable(roles.StatementRole, Generative): coercions.expect(roles.ExecutableOptionRole, opt) for opt in options ) + return self @_generative - def _set_compile_options(self, compile_options): + def _set_compile_options( + self: SelfExecutable, compile_options + ) -> SelfExecutable: """Assign the compile options to a new value. :param compile_options: appropriate CacheableOptions structure @@ -867,15 +888,21 @@ class Executable(roles.StatementRole, Generative): """ self._compile_options = compile_options + return self @_generative - def _update_compile_options(self, options): + def _update_compile_options( + self: SelfExecutable, options + ) -> SelfExecutable: """update the _compile_options with new keys.""" self._compile_options += options + return self @_generative - def _add_context_option(self, callable_, cache_args): + def _add_context_option( + self: SelfExecutable, callable_, cache_args + ) -> SelfExecutable: """Add a context option to this statement. These are callable functions that will @@ -887,9 +914,10 @@ class Executable(roles.StatementRole, Generative): """ self._with_context_options += ((callable_, cache_args),) + return self @_generative - def execution_options(self, **kw): + def execution_options(self: SelfExecutable, **kw) -> SelfExecutable: """Set non-SQL options for the statement which take effect during execution. @@ -1004,6 +1032,7 @@ class Executable(roles.StatementRole, Generative): "on Connection.execution_options(), not per statement." ) self._execution_options = self._execution_options.union(kw) + return self def get_execution_options(self): """Get the non-SQL options which will take effect during execution. |