summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/cache_key.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/sql/cache_key.py')
-rw-r--r--lib/sqlalchemy/sql/cache_key.py354
1 files changed, 271 insertions, 83 deletions
diff --git a/lib/sqlalchemy/sql/cache_key.py b/lib/sqlalchemy/sql/cache_key.py
index ff659b77d..fca58f98e 100644
--- a/lib/sqlalchemy/sql/cache_key.py
+++ b/lib/sqlalchemy/sql/cache_key.py
@@ -11,21 +11,41 @@ import enum
from itertools import zip_longest
import typing
from typing import Any
-from typing import Callable
+from typing import cast
+from typing import Dict
+from typing import Iterator
+from typing import List
from typing import NamedTuple
+from typing import Optional
+from typing import Sequence
+from typing import Tuple
+from typing import Type
from typing import Union
from .visitors import anon_map
-from .visitors import ExtendedInternalTraversal
+from .visitors import HasTraversalDispatch
+from .visitors import HasTraverseInternals
from .visitors import InternalTraversal
+from .visitors import prefix_anon_map
from .. import util
from ..inspection import inspect
from ..util import HasMemoized
from ..util.typing import Literal
-
+from ..util.typing import Protocol
if typing.TYPE_CHECKING:
from .elements import BindParameter
+ from .elements import ClauseElement
+ from .visitors import _TraverseInternalsType
+ from ..engine.base import _CompiledCacheType
+ from ..engine.interfaces import _CoreSingleExecuteParams
+
+
+class _CacheKeyTraversalDispatchType(Protocol):
+ def __call__(
+ s, self: HasCacheKey, visitor: _CacheKeyTraversal
+ ) -> CacheKey:
+ ...
class CacheConst(enum.Enum):
@@ -70,7 +90,9 @@ class HasCacheKey:
__slots__ = ()
- _cache_key_traversal = NO_CACHE
+ _cache_key_traversal: Union[
+ _TraverseInternalsType, Literal[CacheConst.NO_CACHE]
+ ] = NO_CACHE
_is_has_cache_key = True
@@ -83,7 +105,7 @@ class HasCacheKey:
"""
- inherit_cache = None
+ inherit_cache: Optional[bool] = None
"""Indicate if this :class:`.HasCacheKey` instance should make use of the
cache key generation scheme used by its immediate superclass.
@@ -106,8 +128,12 @@ class HasCacheKey:
__slots__ = ()
+ _generated_cache_key_traversal: Any
+
@classmethod
- def _generate_cache_attrs(cls):
+ def _generate_cache_attrs(
+ cls,
+ ) -> Union[_CacheKeyTraversalDispatchType, Literal[CacheConst.NO_CACHE]]:
"""generate cache key dispatcher for a new class.
This sets the _generated_cache_key_traversal attribute once called
@@ -121,8 +147,11 @@ class HasCacheKey:
_cache_key_traversal = getattr(cls, "_cache_key_traversal", None)
if _cache_key_traversal is None:
try:
- # this would be HasTraverseInternals
- _cache_key_traversal = cls._traverse_internals
+ # check for _traverse_internals, which is part of
+ # HasTraverseInternals
+ _cache_key_traversal = cast(
+ "Type[HasTraverseInternals]", cls
+ )._traverse_internals
except AttributeError:
cls._generated_cache_key_traversal = NO_CACHE
return NO_CACHE
@@ -138,7 +167,9 @@ class HasCacheKey:
# more complicated, so for the moment this is a little less
# efficient on startup but simpler.
return _cache_key_traversal_visitor.generate_dispatch(
- cls, _cache_key_traversal, "_generated_cache_key_traversal"
+ cls,
+ _cache_key_traversal,
+ "_generated_cache_key_traversal",
)
else:
_cache_key_traversal = cls.__dict__.get(
@@ -170,11 +201,15 @@ class HasCacheKey:
return NO_CACHE
return _cache_key_traversal_visitor.generate_dispatch(
- cls, _cache_key_traversal, "_generated_cache_key_traversal"
+ cls,
+ _cache_key_traversal,
+ "_generated_cache_key_traversal",
)
@util.preload_module("sqlalchemy.sql.elements")
- def _gen_cache_key(self, anon_map, bindparams):
+ def _gen_cache_key(
+ self, anon_map: anon_map, bindparams: List[BindParameter[Any]]
+ ) -> Optional[Tuple[Any, ...]]:
"""return an optional cache key.
The cache key is a tuple which can contain any series of
@@ -202,15 +237,15 @@ class HasCacheKey:
dispatcher: Union[
Literal[CacheConst.NO_CACHE],
- Callable[[HasCacheKey, "_CacheKeyTraversal"], "CacheKey"],
+ _CacheKeyTraversalDispatchType,
]
try:
dispatcher = cls.__dict__["_generated_cache_key_traversal"]
except KeyError:
- # most of the dispatchers are generated up front
- # in sqlalchemy/sql/__init__.py ->
- # traversals.py-> _preconfigure_traversals().
+ # traversals.py -> _preconfigure_traversals()
+ # may be used to run these ahead of time, but
+ # is not enabled right now.
# this block will generate any remaining dispatchers.
dispatcher = cls._generate_cache_attrs()
@@ -218,7 +253,7 @@ class HasCacheKey:
anon_map[NO_CACHE] = True
return None
- result = (id_, cls)
+ result: Tuple[Any, ...] = (id_, cls)
# inline of _cache_key_traversal_visitor.run_generated_dispatch()
@@ -268,7 +303,7 @@ class HasCacheKey:
# Columns, this should be long lived. For select()
# statements, not so much, but they usually won't have
# annotations.
- result += self._annotations_cache_key
+ result += self._annotations_cache_key # type: ignore
elif (
meth is InternalTraversal.dp_clauseelement_list
or meth is InternalTraversal.dp_clauseelement_tuple
@@ -290,7 +325,7 @@ class HasCacheKey:
)
return result
- def _generate_cache_key(self):
+ def _generate_cache_key(self) -> Optional[CacheKey]:
"""return a cache key.
The cache key is a tuple which can contain any series of
@@ -322,32 +357,40 @@ class HasCacheKey:
"""
- bindparams = []
+ bindparams: List[BindParameter[Any]] = []
_anon_map = anon_map()
key = self._gen_cache_key(_anon_map, bindparams)
if NO_CACHE in _anon_map:
return None
else:
+ assert key is not None
return CacheKey(key, bindparams)
@classmethod
- def _generate_cache_key_for_object(cls, obj):
- bindparams = []
+ def _generate_cache_key_for_object(
+ cls, obj: HasCacheKey
+ ) -> Optional[CacheKey]:
+ bindparams: List[BindParameter[Any]] = []
_anon_map = anon_map()
key = obj._gen_cache_key(_anon_map, bindparams)
if NO_CACHE in _anon_map:
return None
else:
+ assert key is not None
return CacheKey(key, bindparams)
+class HasCacheKeyTraverse(HasTraverseInternals, HasCacheKey):
+ pass
+
+
class MemoizedHasCacheKey(HasCacheKey, HasMemoized):
__slots__ = ()
@HasMemoized.memoized_instancemethod
- def _generate_cache_key(self):
+ def _generate_cache_key(self) -> Optional[CacheKey]:
return HasCacheKey._generate_cache_key(self)
@@ -362,14 +405,22 @@ class CacheKey(NamedTuple):
"""
key: Tuple[Any, ...]
- bindparams: Sequence[BindParameter]
+ bindparams: Sequence[BindParameter[Any]]
- def __hash__(self):
+ # can't set __hash__ attribute because it interferes
+ # with namedtuple
+ # can't use "if not TYPE_CHECKING" because mypy rejects it
+ # inside of a NamedTuple
+ def __hash__(self) -> Optional[int]: # type: ignore
"""CacheKey itself is not hashable - hash the .key portion"""
-
return None
- def to_offline_string(self, statement_cache, statement, parameters):
+ def to_offline_string(
+ self,
+ statement_cache: _CompiledCacheType,
+ statement: ClauseElement,
+ parameters: _CoreSingleExecuteParams,
+ ) -> str:
"""Generate an "offline string" form of this :class:`.CacheKey`
The "offline string" is basically the string SQL for the
@@ -400,21 +451,21 @@ class CacheKey(NamedTuple):
return repr((sql_str, param_tuple))
- def __eq__(self, other):
- return self.key == other.key
+ def __eq__(self, other: Any) -> bool:
+ return bool(self.key == other.key)
@classmethod
- def _diff_tuples(cls, left, right):
+ def _diff_tuples(cls, left: CacheKey, right: CacheKey) -> str:
ck1 = CacheKey(left, [])
ck2 = CacheKey(right, [])
return ck1._diff(ck2)
- def _whats_different(self, other):
+ def _whats_different(self, other: CacheKey) -> Iterator[str]:
k1 = self.key
k2 = other.key
- stack = []
+ stack: List[int] = []
pickup_index = 0
while True:
s1, s2 = k1, k2
@@ -440,11 +491,11 @@ class CacheKey(NamedTuple):
pickup_index = stack.pop(-1)
break
- def _diff(self, other):
+ def _diff(self, other: CacheKey) -> str:
return ", ".join(self._whats_different(other))
- def __str__(self):
- stack = [self.key]
+ def __str__(self) -> str:
+ stack: List[Union[Tuple[Any, ...], HasCacheKey]] = [self.key]
output = []
sentinel = object()
@@ -473,15 +524,15 @@ class CacheKey(NamedTuple):
return "CacheKey(key=%s)" % ("\n".join(output),)
- def _generate_param_dict(self):
+ def _generate_param_dict(self) -> Dict[str, Any]:
"""used for testing"""
- from .compiler import prefix_anon_map
-
_anon_map = prefix_anon_map()
return {b.key % _anon_map: b.effective_value for b in self.bindparams}
- def _apply_params_to_element(self, original_cache_key, target_element):
+ def _apply_params_to_element(
+ self, original_cache_key: CacheKey, target_element: ClauseElement
+ ) -> ClauseElement:
translate = {
k.key: v.value
for k, v in zip(original_cache_key.bindparams, self.bindparams)
@@ -490,7 +541,7 @@ class CacheKey(NamedTuple):
return target_element.params(translate)
-class _CacheKeyTraversal(ExtendedInternalTraversal):
+class _CacheKeyTraversal(HasTraversalDispatch):
# very common elements are inlined into the main _get_cache_key() method
# to produce a dramatic savings in Python function call overhead
@@ -512,17 +563,43 @@ class _CacheKeyTraversal(ExtendedInternalTraversal):
visit_propagate_attrs = PROPAGATE_ATTRS
def visit_with_context_options(
- self, attrname, obj, parent, anon_map, bindparams
- ):
+ self,
+ attrname: str,
+ obj: Any,
+ parent: Any,
+ anon_map: anon_map,
+ bindparams: List[BindParameter[Any]],
+ ) -> Tuple[Any, ...]:
return tuple((fn.__code__, c_key) for fn, c_key in obj)
- def visit_inspectable(self, attrname, obj, parent, anon_map, bindparams):
+ def visit_inspectable(
+ self,
+ attrname: str,
+ obj: Any,
+ parent: Any,
+ anon_map: anon_map,
+ bindparams: List[BindParameter[Any]],
+ ) -> Tuple[Any, ...]:
return (attrname, inspect(obj)._gen_cache_key(anon_map, bindparams))
- def visit_string_list(self, attrname, obj, parent, anon_map, bindparams):
+ def visit_string_list(
+ self,
+ attrname: str,
+ obj: Any,
+ parent: Any,
+ anon_map: anon_map,
+ bindparams: List[BindParameter[Any]],
+ ) -> Tuple[Any, ...]:
return tuple(obj)
- def visit_multi(self, attrname, obj, parent, anon_map, bindparams):
+ def visit_multi(
+ self,
+ attrname: str,
+ obj: Any,
+ parent: Any,
+ anon_map: anon_map,
+ bindparams: List[BindParameter[Any]],
+ ) -> Tuple[Any, ...]:
return (
attrname,
obj._gen_cache_key(anon_map, bindparams)
@@ -530,7 +607,14 @@ class _CacheKeyTraversal(ExtendedInternalTraversal):
else obj,
)
- def visit_multi_list(self, attrname, obj, parent, anon_map, bindparams):
+ def visit_multi_list(
+ self,
+ attrname: str,
+ obj: Any,
+ parent: Any,
+ anon_map: anon_map,
+ bindparams: List[BindParameter[Any]],
+ ) -> Tuple[Any, ...]:
return (
attrname,
tuple(
@@ -542,8 +626,13 @@ class _CacheKeyTraversal(ExtendedInternalTraversal):
)
def visit_has_cache_key_tuples(
- self, attrname, obj, parent, anon_map, bindparams
- ):
+ self,
+ attrname: str,
+ obj: Any,
+ parent: Any,
+ anon_map: anon_map,
+ bindparams: List[BindParameter[Any]],
+ ) -> Tuple[Any, ...]:
if not obj:
return ()
return (
@@ -558,8 +647,13 @@ class _CacheKeyTraversal(ExtendedInternalTraversal):
)
def visit_has_cache_key_list(
- self, attrname, obj, parent, anon_map, bindparams
- ):
+ self,
+ attrname: str,
+ obj: Any,
+ parent: Any,
+ anon_map: anon_map,
+ bindparams: List[BindParameter[Any]],
+ ) -> Tuple[Any, ...]:
if not obj:
return ()
return (
@@ -568,8 +662,13 @@ class _CacheKeyTraversal(ExtendedInternalTraversal):
)
def visit_executable_options(
- self, attrname, obj, parent, anon_map, bindparams
- ):
+ self,
+ attrname: str,
+ obj: Any,
+ parent: Any,
+ anon_map: anon_map,
+ bindparams: List[BindParameter[Any]],
+ ) -> Tuple[Any, ...]:
if not obj:
return ()
return (
@@ -582,22 +681,37 @@ class _CacheKeyTraversal(ExtendedInternalTraversal):
)
def visit_inspectable_list(
- self, attrname, obj, parent, anon_map, bindparams
- ):
+ self,
+ attrname: str,
+ obj: Any,
+ parent: Any,
+ anon_map: anon_map,
+ bindparams: List[BindParameter[Any]],
+ ) -> Tuple[Any, ...]:
return self.visit_has_cache_key_list(
attrname, [inspect(o) for o in obj], parent, anon_map, bindparams
)
def visit_clauseelement_tuples(
- self, attrname, obj, parent, anon_map, bindparams
- ):
+ self,
+ attrname: str,
+ obj: Any,
+ parent: Any,
+ anon_map: anon_map,
+ bindparams: List[BindParameter[Any]],
+ ) -> Tuple[Any, ...]:
return self.visit_has_cache_key_tuples(
attrname, obj, parent, anon_map, bindparams
)
def visit_fromclause_ordered_set(
- self, attrname, obj, parent, anon_map, bindparams
- ):
+ self,
+ attrname: str,
+ obj: Any,
+ parent: Any,
+ anon_map: anon_map,
+ bindparams: List[BindParameter[Any]],
+ ) -> Tuple[Any, ...]:
if not obj:
return ()
return (
@@ -606,8 +720,13 @@ class _CacheKeyTraversal(ExtendedInternalTraversal):
)
def visit_clauseelement_unordered_set(
- self, attrname, obj, parent, anon_map, bindparams
- ):
+ self,
+ attrname: str,
+ obj: Any,
+ parent: Any,
+ anon_map: anon_map,
+ bindparams: List[BindParameter[Any]],
+ ) -> Tuple[Any, ...]:
if not obj:
return ()
cache_keys = [
@@ -621,13 +740,23 @@ class _CacheKeyTraversal(ExtendedInternalTraversal):
)
def visit_named_ddl_element(
- self, attrname, obj, parent, anon_map, bindparams
- ):
+ self,
+ attrname: str,
+ obj: Any,
+ parent: Any,
+ anon_map: anon_map,
+ bindparams: List[BindParameter[Any]],
+ ) -> Tuple[Any, ...]:
return (attrname, obj.name)
def visit_prefix_sequence(
- self, attrname, obj, parent, anon_map, bindparams
- ):
+ self,
+ attrname: str,
+ obj: Any,
+ parent: Any,
+ anon_map: anon_map,
+ bindparams: List[BindParameter[Any]],
+ ) -> Tuple[Any, ...]:
if not obj:
return ()
@@ -642,8 +771,13 @@ class _CacheKeyTraversal(ExtendedInternalTraversal):
)
def visit_setup_join_tuple(
- self, attrname, obj, parent, anon_map, bindparams
- ):
+ self,
+ attrname: str,
+ obj: Any,
+ parent: Any,
+ anon_map: anon_map,
+ bindparams: List[BindParameter[Any]],
+ ) -> Tuple[Any, ...]:
return tuple(
(
target._gen_cache_key(anon_map, bindparams),
@@ -659,8 +793,13 @@ class _CacheKeyTraversal(ExtendedInternalTraversal):
)
def visit_table_hint_list(
- self, attrname, obj, parent, anon_map, bindparams
- ):
+ self,
+ attrname: str,
+ obj: Any,
+ parent: Any,
+ anon_map: anon_map,
+ bindparams: List[BindParameter[Any]],
+ ) -> Tuple[Any, ...]:
if not obj:
return ()
@@ -678,12 +817,24 @@ class _CacheKeyTraversal(ExtendedInternalTraversal):
),
)
- def visit_plain_dict(self, attrname, obj, parent, anon_map, bindparams):
+ def visit_plain_dict(
+ self,
+ attrname: str,
+ obj: Any,
+ parent: Any,
+ anon_map: anon_map,
+ bindparams: List[BindParameter[Any]],
+ ) -> Tuple[Any, ...]:
return (attrname, tuple([(key, obj[key]) for key in sorted(obj)]))
def visit_dialect_options(
- self, attrname, obj, parent, anon_map, bindparams
- ):
+ self,
+ attrname: str,
+ obj: Any,
+ parent: Any,
+ anon_map: anon_map,
+ bindparams: List[BindParameter[Any]],
+ ) -> Tuple[Any, ...]:
return (
attrname,
tuple(
@@ -701,8 +852,13 @@ class _CacheKeyTraversal(ExtendedInternalTraversal):
)
def visit_string_clauseelement_dict(
- self, attrname, obj, parent, anon_map, bindparams
- ):
+ self,
+ attrname: str,
+ obj: Any,
+ parent: Any,
+ anon_map: anon_map,
+ bindparams: List[BindParameter[Any]],
+ ) -> Tuple[Any, ...]:
return (
attrname,
tuple(
@@ -712,8 +868,13 @@ class _CacheKeyTraversal(ExtendedInternalTraversal):
)
def visit_string_multi_dict(
- self, attrname, obj, parent, anon_map, bindparams
- ):
+ self,
+ attrname: str,
+ obj: Any,
+ parent: Any,
+ anon_map: anon_map,
+ bindparams: List[BindParameter[Any]],
+ ) -> Tuple[Any, ...]:
return (
attrname,
tuple(
@@ -728,8 +889,13 @@ class _CacheKeyTraversal(ExtendedInternalTraversal):
)
def visit_fromclause_canonical_column_collection(
- self, attrname, obj, parent, anon_map, bindparams
- ):
+ self,
+ attrname: str,
+ obj: Any,
+ parent: Any,
+ anon_map: anon_map,
+ bindparams: List[BindParameter[Any]],
+ ) -> Tuple[Any, ...]:
# inlining into the internals of ColumnCollection
return (
attrname,
@@ -740,14 +906,24 @@ class _CacheKeyTraversal(ExtendedInternalTraversal):
)
def visit_unknown_structure(
- self, attrname, obj, parent, anon_map, bindparams
- ):
+ self,
+ attrname: str,
+ obj: Any,
+ parent: Any,
+ anon_map: anon_map,
+ bindparams: List[BindParameter[Any]],
+ ) -> Tuple[Any, ...]:
anon_map[NO_CACHE] = True
return ()
def visit_dml_ordered_values(
- self, attrname, obj, parent, anon_map, bindparams
- ):
+ self,
+ attrname: str,
+ obj: Any,
+ parent: Any,
+ anon_map: anon_map,
+ bindparams: List[BindParameter[Any]],
+ ) -> Tuple[Any, ...]:
return (
attrname,
tuple(
@@ -761,7 +937,14 @@ class _CacheKeyTraversal(ExtendedInternalTraversal):
),
)
- def visit_dml_values(self, attrname, obj, parent, anon_map, bindparams):
+ def visit_dml_values(
+ self,
+ attrname: str,
+ obj: Any,
+ parent: Any,
+ anon_map: anon_map,
+ bindparams: List[BindParameter[Any]],
+ ) -> Tuple[Any, ...]:
# in py37 we can assume two dictionaries created in the same
# insert ordering will retain that sorting
return (
@@ -778,8 +961,13 @@ class _CacheKeyTraversal(ExtendedInternalTraversal):
)
def visit_dml_multi_values(
- self, attrname, obj, parent, anon_map, bindparams
- ):
+ self,
+ attrname: str,
+ obj: Any,
+ parent: Any,
+ anon_map: anon_map,
+ bindparams: List[BindParameter[Any]],
+ ) -> Tuple[Any, ...]:
# multivalues are simply not cacheable right now
anon_map[NO_CACHE] = True
return ()