diff options
Diffstat (limited to 'lib/sqlalchemy/sql/cache_key.py')
-rw-r--r-- | lib/sqlalchemy/sql/cache_key.py | 354 |
1 files changed, 271 insertions, 83 deletions
diff --git a/lib/sqlalchemy/sql/cache_key.py b/lib/sqlalchemy/sql/cache_key.py index ff659b77d..fca58f98e 100644 --- a/lib/sqlalchemy/sql/cache_key.py +++ b/lib/sqlalchemy/sql/cache_key.py @@ -11,21 +11,41 @@ import enum from itertools import zip_longest import typing from typing import Any -from typing import Callable +from typing import cast +from typing import Dict +from typing import Iterator +from typing import List from typing import NamedTuple +from typing import Optional +from typing import Sequence +from typing import Tuple +from typing import Type from typing import Union from .visitors import anon_map -from .visitors import ExtendedInternalTraversal +from .visitors import HasTraversalDispatch +from .visitors import HasTraverseInternals from .visitors import InternalTraversal +from .visitors import prefix_anon_map from .. import util from ..inspection import inspect from ..util import HasMemoized from ..util.typing import Literal - +from ..util.typing import Protocol if typing.TYPE_CHECKING: from .elements import BindParameter + from .elements import ClauseElement + from .visitors import _TraverseInternalsType + from ..engine.base import _CompiledCacheType + from ..engine.interfaces import _CoreSingleExecuteParams + + +class _CacheKeyTraversalDispatchType(Protocol): + def __call__( + s, self: HasCacheKey, visitor: _CacheKeyTraversal + ) -> CacheKey: + ... class CacheConst(enum.Enum): @@ -70,7 +90,9 @@ class HasCacheKey: __slots__ = () - _cache_key_traversal = NO_CACHE + _cache_key_traversal: Union[ + _TraverseInternalsType, Literal[CacheConst.NO_CACHE] + ] = NO_CACHE _is_has_cache_key = True @@ -83,7 +105,7 @@ class HasCacheKey: """ - inherit_cache = None + inherit_cache: Optional[bool] = None """Indicate if this :class:`.HasCacheKey` instance should make use of the cache key generation scheme used by its immediate superclass. @@ -106,8 +128,12 @@ class HasCacheKey: __slots__ = () + _generated_cache_key_traversal: Any + @classmethod - def _generate_cache_attrs(cls): + def _generate_cache_attrs( + cls, + ) -> Union[_CacheKeyTraversalDispatchType, Literal[CacheConst.NO_CACHE]]: """generate cache key dispatcher for a new class. This sets the _generated_cache_key_traversal attribute once called @@ -121,8 +147,11 @@ class HasCacheKey: _cache_key_traversal = getattr(cls, "_cache_key_traversal", None) if _cache_key_traversal is None: try: - # this would be HasTraverseInternals - _cache_key_traversal = cls._traverse_internals + # check for _traverse_internals, which is part of + # HasTraverseInternals + _cache_key_traversal = cast( + "Type[HasTraverseInternals]", cls + )._traverse_internals except AttributeError: cls._generated_cache_key_traversal = NO_CACHE return NO_CACHE @@ -138,7 +167,9 @@ class HasCacheKey: # more complicated, so for the moment this is a little less # efficient on startup but simpler. return _cache_key_traversal_visitor.generate_dispatch( - cls, _cache_key_traversal, "_generated_cache_key_traversal" + cls, + _cache_key_traversal, + "_generated_cache_key_traversal", ) else: _cache_key_traversal = cls.__dict__.get( @@ -170,11 +201,15 @@ class HasCacheKey: return NO_CACHE return _cache_key_traversal_visitor.generate_dispatch( - cls, _cache_key_traversal, "_generated_cache_key_traversal" + cls, + _cache_key_traversal, + "_generated_cache_key_traversal", ) @util.preload_module("sqlalchemy.sql.elements") - def _gen_cache_key(self, anon_map, bindparams): + def _gen_cache_key( + self, anon_map: anon_map, bindparams: List[BindParameter[Any]] + ) -> Optional[Tuple[Any, ...]]: """return an optional cache key. The cache key is a tuple which can contain any series of @@ -202,15 +237,15 @@ class HasCacheKey: dispatcher: Union[ Literal[CacheConst.NO_CACHE], - Callable[[HasCacheKey, "_CacheKeyTraversal"], "CacheKey"], + _CacheKeyTraversalDispatchType, ] try: dispatcher = cls.__dict__["_generated_cache_key_traversal"] except KeyError: - # most of the dispatchers are generated up front - # in sqlalchemy/sql/__init__.py -> - # traversals.py-> _preconfigure_traversals(). + # traversals.py -> _preconfigure_traversals() + # may be used to run these ahead of time, but + # is not enabled right now. # this block will generate any remaining dispatchers. dispatcher = cls._generate_cache_attrs() @@ -218,7 +253,7 @@ class HasCacheKey: anon_map[NO_CACHE] = True return None - result = (id_, cls) + result: Tuple[Any, ...] = (id_, cls) # inline of _cache_key_traversal_visitor.run_generated_dispatch() @@ -268,7 +303,7 @@ class HasCacheKey: # Columns, this should be long lived. For select() # statements, not so much, but they usually won't have # annotations. - result += self._annotations_cache_key + result += self._annotations_cache_key # type: ignore elif ( meth is InternalTraversal.dp_clauseelement_list or meth is InternalTraversal.dp_clauseelement_tuple @@ -290,7 +325,7 @@ class HasCacheKey: ) return result - def _generate_cache_key(self): + def _generate_cache_key(self) -> Optional[CacheKey]: """return a cache key. The cache key is a tuple which can contain any series of @@ -322,32 +357,40 @@ class HasCacheKey: """ - bindparams = [] + bindparams: List[BindParameter[Any]] = [] _anon_map = anon_map() key = self._gen_cache_key(_anon_map, bindparams) if NO_CACHE in _anon_map: return None else: + assert key is not None return CacheKey(key, bindparams) @classmethod - def _generate_cache_key_for_object(cls, obj): - bindparams = [] + def _generate_cache_key_for_object( + cls, obj: HasCacheKey + ) -> Optional[CacheKey]: + bindparams: List[BindParameter[Any]] = [] _anon_map = anon_map() key = obj._gen_cache_key(_anon_map, bindparams) if NO_CACHE in _anon_map: return None else: + assert key is not None return CacheKey(key, bindparams) +class HasCacheKeyTraverse(HasTraverseInternals, HasCacheKey): + pass + + class MemoizedHasCacheKey(HasCacheKey, HasMemoized): __slots__ = () @HasMemoized.memoized_instancemethod - def _generate_cache_key(self): + def _generate_cache_key(self) -> Optional[CacheKey]: return HasCacheKey._generate_cache_key(self) @@ -362,14 +405,22 @@ class CacheKey(NamedTuple): """ key: Tuple[Any, ...] - bindparams: Sequence[BindParameter] + bindparams: Sequence[BindParameter[Any]] - def __hash__(self): + # can't set __hash__ attribute because it interferes + # with namedtuple + # can't use "if not TYPE_CHECKING" because mypy rejects it + # inside of a NamedTuple + def __hash__(self) -> Optional[int]: # type: ignore """CacheKey itself is not hashable - hash the .key portion""" - return None - def to_offline_string(self, statement_cache, statement, parameters): + def to_offline_string( + self, + statement_cache: _CompiledCacheType, + statement: ClauseElement, + parameters: _CoreSingleExecuteParams, + ) -> str: """Generate an "offline string" form of this :class:`.CacheKey` The "offline string" is basically the string SQL for the @@ -400,21 +451,21 @@ class CacheKey(NamedTuple): return repr((sql_str, param_tuple)) - def __eq__(self, other): - return self.key == other.key + def __eq__(self, other: Any) -> bool: + return bool(self.key == other.key) @classmethod - def _diff_tuples(cls, left, right): + def _diff_tuples(cls, left: CacheKey, right: CacheKey) -> str: ck1 = CacheKey(left, []) ck2 = CacheKey(right, []) return ck1._diff(ck2) - def _whats_different(self, other): + def _whats_different(self, other: CacheKey) -> Iterator[str]: k1 = self.key k2 = other.key - stack = [] + stack: List[int] = [] pickup_index = 0 while True: s1, s2 = k1, k2 @@ -440,11 +491,11 @@ class CacheKey(NamedTuple): pickup_index = stack.pop(-1) break - def _diff(self, other): + def _diff(self, other: CacheKey) -> str: return ", ".join(self._whats_different(other)) - def __str__(self): - stack = [self.key] + def __str__(self) -> str: + stack: List[Union[Tuple[Any, ...], HasCacheKey]] = [self.key] output = [] sentinel = object() @@ -473,15 +524,15 @@ class CacheKey(NamedTuple): return "CacheKey(key=%s)" % ("\n".join(output),) - def _generate_param_dict(self): + def _generate_param_dict(self) -> Dict[str, Any]: """used for testing""" - from .compiler import prefix_anon_map - _anon_map = prefix_anon_map() return {b.key % _anon_map: b.effective_value for b in self.bindparams} - def _apply_params_to_element(self, original_cache_key, target_element): + def _apply_params_to_element( + self, original_cache_key: CacheKey, target_element: ClauseElement + ) -> ClauseElement: translate = { k.key: v.value for k, v in zip(original_cache_key.bindparams, self.bindparams) @@ -490,7 +541,7 @@ class CacheKey(NamedTuple): return target_element.params(translate) -class _CacheKeyTraversal(ExtendedInternalTraversal): +class _CacheKeyTraversal(HasTraversalDispatch): # very common elements are inlined into the main _get_cache_key() method # to produce a dramatic savings in Python function call overhead @@ -512,17 +563,43 @@ class _CacheKeyTraversal(ExtendedInternalTraversal): visit_propagate_attrs = PROPAGATE_ATTRS def visit_with_context_options( - self, attrname, obj, parent, anon_map, bindparams - ): + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: return tuple((fn.__code__, c_key) for fn, c_key in obj) - def visit_inspectable(self, attrname, obj, parent, anon_map, bindparams): + def visit_inspectable( + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: return (attrname, inspect(obj)._gen_cache_key(anon_map, bindparams)) - def visit_string_list(self, attrname, obj, parent, anon_map, bindparams): + def visit_string_list( + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: return tuple(obj) - def visit_multi(self, attrname, obj, parent, anon_map, bindparams): + def visit_multi( + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: return ( attrname, obj._gen_cache_key(anon_map, bindparams) @@ -530,7 +607,14 @@ class _CacheKeyTraversal(ExtendedInternalTraversal): else obj, ) - def visit_multi_list(self, attrname, obj, parent, anon_map, bindparams): + def visit_multi_list( + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: return ( attrname, tuple( @@ -542,8 +626,13 @@ class _CacheKeyTraversal(ExtendedInternalTraversal): ) def visit_has_cache_key_tuples( - self, attrname, obj, parent, anon_map, bindparams - ): + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: if not obj: return () return ( @@ -558,8 +647,13 @@ class _CacheKeyTraversal(ExtendedInternalTraversal): ) def visit_has_cache_key_list( - self, attrname, obj, parent, anon_map, bindparams - ): + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: if not obj: return () return ( @@ -568,8 +662,13 @@ class _CacheKeyTraversal(ExtendedInternalTraversal): ) def visit_executable_options( - self, attrname, obj, parent, anon_map, bindparams - ): + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: if not obj: return () return ( @@ -582,22 +681,37 @@ class _CacheKeyTraversal(ExtendedInternalTraversal): ) def visit_inspectable_list( - self, attrname, obj, parent, anon_map, bindparams - ): + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: return self.visit_has_cache_key_list( attrname, [inspect(o) for o in obj], parent, anon_map, bindparams ) def visit_clauseelement_tuples( - self, attrname, obj, parent, anon_map, bindparams - ): + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: return self.visit_has_cache_key_tuples( attrname, obj, parent, anon_map, bindparams ) def visit_fromclause_ordered_set( - self, attrname, obj, parent, anon_map, bindparams - ): + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: if not obj: return () return ( @@ -606,8 +720,13 @@ class _CacheKeyTraversal(ExtendedInternalTraversal): ) def visit_clauseelement_unordered_set( - self, attrname, obj, parent, anon_map, bindparams - ): + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: if not obj: return () cache_keys = [ @@ -621,13 +740,23 @@ class _CacheKeyTraversal(ExtendedInternalTraversal): ) def visit_named_ddl_element( - self, attrname, obj, parent, anon_map, bindparams - ): + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: return (attrname, obj.name) def visit_prefix_sequence( - self, attrname, obj, parent, anon_map, bindparams - ): + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: if not obj: return () @@ -642,8 +771,13 @@ class _CacheKeyTraversal(ExtendedInternalTraversal): ) def visit_setup_join_tuple( - self, attrname, obj, parent, anon_map, bindparams - ): + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: return tuple( ( target._gen_cache_key(anon_map, bindparams), @@ -659,8 +793,13 @@ class _CacheKeyTraversal(ExtendedInternalTraversal): ) def visit_table_hint_list( - self, attrname, obj, parent, anon_map, bindparams - ): + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: if not obj: return () @@ -678,12 +817,24 @@ class _CacheKeyTraversal(ExtendedInternalTraversal): ), ) - def visit_plain_dict(self, attrname, obj, parent, anon_map, bindparams): + def visit_plain_dict( + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: return (attrname, tuple([(key, obj[key]) for key in sorted(obj)])) def visit_dialect_options( - self, attrname, obj, parent, anon_map, bindparams - ): + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: return ( attrname, tuple( @@ -701,8 +852,13 @@ class _CacheKeyTraversal(ExtendedInternalTraversal): ) def visit_string_clauseelement_dict( - self, attrname, obj, parent, anon_map, bindparams - ): + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: return ( attrname, tuple( @@ -712,8 +868,13 @@ class _CacheKeyTraversal(ExtendedInternalTraversal): ) def visit_string_multi_dict( - self, attrname, obj, parent, anon_map, bindparams - ): + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: return ( attrname, tuple( @@ -728,8 +889,13 @@ class _CacheKeyTraversal(ExtendedInternalTraversal): ) def visit_fromclause_canonical_column_collection( - self, attrname, obj, parent, anon_map, bindparams - ): + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: # inlining into the internals of ColumnCollection return ( attrname, @@ -740,14 +906,24 @@ class _CacheKeyTraversal(ExtendedInternalTraversal): ) def visit_unknown_structure( - self, attrname, obj, parent, anon_map, bindparams - ): + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: anon_map[NO_CACHE] = True return () def visit_dml_ordered_values( - self, attrname, obj, parent, anon_map, bindparams - ): + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: return ( attrname, tuple( @@ -761,7 +937,14 @@ class _CacheKeyTraversal(ExtendedInternalTraversal): ), ) - def visit_dml_values(self, attrname, obj, parent, anon_map, bindparams): + def visit_dml_values( + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: # in py37 we can assume two dictionaries created in the same # insert ordering will retain that sorting return ( @@ -778,8 +961,13 @@ class _CacheKeyTraversal(ExtendedInternalTraversal): ) def visit_dml_multi_values( - self, attrname, obj, parent, anon_map, bindparams - ): + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: # multivalues are simply not cacheable right now anon_map[NO_CACHE] = True return () |