diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2022-03-08 17:14:41 -0500 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2022-03-13 15:29:20 -0400 |
commit | 769fa67d842035dd852ab8b6a26ea3f110a51131 (patch) | |
tree | 5c121caca336071091c6f5ea4c54743c92d6458a /lib/sqlalchemy/sql/cache_key.py | |
parent | 77fc8216a74e6b2d0efc6591c6c735687bd10002 (diff) | |
download | sqlalchemy-769fa67d842035dd852ab8b6a26ea3f110a51131.tar.gz |
pep-484: sqlalchemy.sql pass one
sqlalchemy.sql will require many passes to get all
modules even gradually typed. Will have to pick and
choose what modules can be strictly typed vs. which
can be gradual.
in this patch, emphasis is on visitors.py, cache_key.py,
annotations.py for strict typing, compiler.py is on gradual
typing but has much more structure, in particular where it
connects with the outside world.
The work within compiler.py also reached back out to
engine/cursor.py , default.py quite a bit.
References: #6810
Change-Id: I6e8a29f6013fd216e43d45091bc193f8be0368fd
Diffstat (limited to 'lib/sqlalchemy/sql/cache_key.py')
-rw-r--r-- | lib/sqlalchemy/sql/cache_key.py | 354 |
1 files changed, 271 insertions, 83 deletions
diff --git a/lib/sqlalchemy/sql/cache_key.py b/lib/sqlalchemy/sql/cache_key.py index ff659b77d..fca58f98e 100644 --- a/lib/sqlalchemy/sql/cache_key.py +++ b/lib/sqlalchemy/sql/cache_key.py @@ -11,21 +11,41 @@ import enum from itertools import zip_longest import typing from typing import Any -from typing import Callable +from typing import cast +from typing import Dict +from typing import Iterator +from typing import List from typing import NamedTuple +from typing import Optional +from typing import Sequence +from typing import Tuple +from typing import Type from typing import Union from .visitors import anon_map -from .visitors import ExtendedInternalTraversal +from .visitors import HasTraversalDispatch +from .visitors import HasTraverseInternals from .visitors import InternalTraversal +from .visitors import prefix_anon_map from .. import util from ..inspection import inspect from ..util import HasMemoized from ..util.typing import Literal - +from ..util.typing import Protocol if typing.TYPE_CHECKING: from .elements import BindParameter + from .elements import ClauseElement + from .visitors import _TraverseInternalsType + from ..engine.base import _CompiledCacheType + from ..engine.interfaces import _CoreSingleExecuteParams + + +class _CacheKeyTraversalDispatchType(Protocol): + def __call__( + s, self: HasCacheKey, visitor: _CacheKeyTraversal + ) -> CacheKey: + ... class CacheConst(enum.Enum): @@ -70,7 +90,9 @@ class HasCacheKey: __slots__ = () - _cache_key_traversal = NO_CACHE + _cache_key_traversal: Union[ + _TraverseInternalsType, Literal[CacheConst.NO_CACHE] + ] = NO_CACHE _is_has_cache_key = True @@ -83,7 +105,7 @@ class HasCacheKey: """ - inherit_cache = None + inherit_cache: Optional[bool] = None """Indicate if this :class:`.HasCacheKey` instance should make use of the cache key generation scheme used by its immediate superclass. @@ -106,8 +128,12 @@ class HasCacheKey: __slots__ = () + _generated_cache_key_traversal: Any + @classmethod - def _generate_cache_attrs(cls): + def _generate_cache_attrs( + cls, + ) -> Union[_CacheKeyTraversalDispatchType, Literal[CacheConst.NO_CACHE]]: """generate cache key dispatcher for a new class. This sets the _generated_cache_key_traversal attribute once called @@ -121,8 +147,11 @@ class HasCacheKey: _cache_key_traversal = getattr(cls, "_cache_key_traversal", None) if _cache_key_traversal is None: try: - # this would be HasTraverseInternals - _cache_key_traversal = cls._traverse_internals + # check for _traverse_internals, which is part of + # HasTraverseInternals + _cache_key_traversal = cast( + "Type[HasTraverseInternals]", cls + )._traverse_internals except AttributeError: cls._generated_cache_key_traversal = NO_CACHE return NO_CACHE @@ -138,7 +167,9 @@ class HasCacheKey: # more complicated, so for the moment this is a little less # efficient on startup but simpler. return _cache_key_traversal_visitor.generate_dispatch( - cls, _cache_key_traversal, "_generated_cache_key_traversal" + cls, + _cache_key_traversal, + "_generated_cache_key_traversal", ) else: _cache_key_traversal = cls.__dict__.get( @@ -170,11 +201,15 @@ class HasCacheKey: return NO_CACHE return _cache_key_traversal_visitor.generate_dispatch( - cls, _cache_key_traversal, "_generated_cache_key_traversal" + cls, + _cache_key_traversal, + "_generated_cache_key_traversal", ) @util.preload_module("sqlalchemy.sql.elements") - def _gen_cache_key(self, anon_map, bindparams): + def _gen_cache_key( + self, anon_map: anon_map, bindparams: List[BindParameter[Any]] + ) -> Optional[Tuple[Any, ...]]: """return an optional cache key. The cache key is a tuple which can contain any series of @@ -202,15 +237,15 @@ class HasCacheKey: dispatcher: Union[ Literal[CacheConst.NO_CACHE], - Callable[[HasCacheKey, "_CacheKeyTraversal"], "CacheKey"], + _CacheKeyTraversalDispatchType, ] try: dispatcher = cls.__dict__["_generated_cache_key_traversal"] except KeyError: - # most of the dispatchers are generated up front - # in sqlalchemy/sql/__init__.py -> - # traversals.py-> _preconfigure_traversals(). + # traversals.py -> _preconfigure_traversals() + # may be used to run these ahead of time, but + # is not enabled right now. # this block will generate any remaining dispatchers. dispatcher = cls._generate_cache_attrs() @@ -218,7 +253,7 @@ class HasCacheKey: anon_map[NO_CACHE] = True return None - result = (id_, cls) + result: Tuple[Any, ...] = (id_, cls) # inline of _cache_key_traversal_visitor.run_generated_dispatch() @@ -268,7 +303,7 @@ class HasCacheKey: # Columns, this should be long lived. For select() # statements, not so much, but they usually won't have # annotations. - result += self._annotations_cache_key + result += self._annotations_cache_key # type: ignore elif ( meth is InternalTraversal.dp_clauseelement_list or meth is InternalTraversal.dp_clauseelement_tuple @@ -290,7 +325,7 @@ class HasCacheKey: ) return result - def _generate_cache_key(self): + def _generate_cache_key(self) -> Optional[CacheKey]: """return a cache key. The cache key is a tuple which can contain any series of @@ -322,32 +357,40 @@ class HasCacheKey: """ - bindparams = [] + bindparams: List[BindParameter[Any]] = [] _anon_map = anon_map() key = self._gen_cache_key(_anon_map, bindparams) if NO_CACHE in _anon_map: return None else: + assert key is not None return CacheKey(key, bindparams) @classmethod - def _generate_cache_key_for_object(cls, obj): - bindparams = [] + def _generate_cache_key_for_object( + cls, obj: HasCacheKey + ) -> Optional[CacheKey]: + bindparams: List[BindParameter[Any]] = [] _anon_map = anon_map() key = obj._gen_cache_key(_anon_map, bindparams) if NO_CACHE in _anon_map: return None else: + assert key is not None return CacheKey(key, bindparams) +class HasCacheKeyTraverse(HasTraverseInternals, HasCacheKey): + pass + + class MemoizedHasCacheKey(HasCacheKey, HasMemoized): __slots__ = () @HasMemoized.memoized_instancemethod - def _generate_cache_key(self): + def _generate_cache_key(self) -> Optional[CacheKey]: return HasCacheKey._generate_cache_key(self) @@ -362,14 +405,22 @@ class CacheKey(NamedTuple): """ key: Tuple[Any, ...] - bindparams: Sequence[BindParameter] + bindparams: Sequence[BindParameter[Any]] - def __hash__(self): + # can't set __hash__ attribute because it interferes + # with namedtuple + # can't use "if not TYPE_CHECKING" because mypy rejects it + # inside of a NamedTuple + def __hash__(self) -> Optional[int]: # type: ignore """CacheKey itself is not hashable - hash the .key portion""" - return None - def to_offline_string(self, statement_cache, statement, parameters): + def to_offline_string( + self, + statement_cache: _CompiledCacheType, + statement: ClauseElement, + parameters: _CoreSingleExecuteParams, + ) -> str: """Generate an "offline string" form of this :class:`.CacheKey` The "offline string" is basically the string SQL for the @@ -400,21 +451,21 @@ class CacheKey(NamedTuple): return repr((sql_str, param_tuple)) - def __eq__(self, other): - return self.key == other.key + def __eq__(self, other: Any) -> bool: + return bool(self.key == other.key) @classmethod - def _diff_tuples(cls, left, right): + def _diff_tuples(cls, left: CacheKey, right: CacheKey) -> str: ck1 = CacheKey(left, []) ck2 = CacheKey(right, []) return ck1._diff(ck2) - def _whats_different(self, other): + def _whats_different(self, other: CacheKey) -> Iterator[str]: k1 = self.key k2 = other.key - stack = [] + stack: List[int] = [] pickup_index = 0 while True: s1, s2 = k1, k2 @@ -440,11 +491,11 @@ class CacheKey(NamedTuple): pickup_index = stack.pop(-1) break - def _diff(self, other): + def _diff(self, other: CacheKey) -> str: return ", ".join(self._whats_different(other)) - def __str__(self): - stack = [self.key] + def __str__(self) -> str: + stack: List[Union[Tuple[Any, ...], HasCacheKey]] = [self.key] output = [] sentinel = object() @@ -473,15 +524,15 @@ class CacheKey(NamedTuple): return "CacheKey(key=%s)" % ("\n".join(output),) - def _generate_param_dict(self): + def _generate_param_dict(self) -> Dict[str, Any]: """used for testing""" - from .compiler import prefix_anon_map - _anon_map = prefix_anon_map() return {b.key % _anon_map: b.effective_value for b in self.bindparams} - def _apply_params_to_element(self, original_cache_key, target_element): + def _apply_params_to_element( + self, original_cache_key: CacheKey, target_element: ClauseElement + ) -> ClauseElement: translate = { k.key: v.value for k, v in zip(original_cache_key.bindparams, self.bindparams) @@ -490,7 +541,7 @@ class CacheKey(NamedTuple): return target_element.params(translate) -class _CacheKeyTraversal(ExtendedInternalTraversal): +class _CacheKeyTraversal(HasTraversalDispatch): # very common elements are inlined into the main _get_cache_key() method # to produce a dramatic savings in Python function call overhead @@ -512,17 +563,43 @@ class _CacheKeyTraversal(ExtendedInternalTraversal): visit_propagate_attrs = PROPAGATE_ATTRS def visit_with_context_options( - self, attrname, obj, parent, anon_map, bindparams - ): + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: return tuple((fn.__code__, c_key) for fn, c_key in obj) - def visit_inspectable(self, attrname, obj, parent, anon_map, bindparams): + def visit_inspectable( + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: return (attrname, inspect(obj)._gen_cache_key(anon_map, bindparams)) - def visit_string_list(self, attrname, obj, parent, anon_map, bindparams): + def visit_string_list( + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: return tuple(obj) - def visit_multi(self, attrname, obj, parent, anon_map, bindparams): + def visit_multi( + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: return ( attrname, obj._gen_cache_key(anon_map, bindparams) @@ -530,7 +607,14 @@ class _CacheKeyTraversal(ExtendedInternalTraversal): else obj, ) - def visit_multi_list(self, attrname, obj, parent, anon_map, bindparams): + def visit_multi_list( + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: return ( attrname, tuple( @@ -542,8 +626,13 @@ class _CacheKeyTraversal(ExtendedInternalTraversal): ) def visit_has_cache_key_tuples( - self, attrname, obj, parent, anon_map, bindparams - ): + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: if not obj: return () return ( @@ -558,8 +647,13 @@ class _CacheKeyTraversal(ExtendedInternalTraversal): ) def visit_has_cache_key_list( - self, attrname, obj, parent, anon_map, bindparams - ): + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: if not obj: return () return ( @@ -568,8 +662,13 @@ class _CacheKeyTraversal(ExtendedInternalTraversal): ) def visit_executable_options( - self, attrname, obj, parent, anon_map, bindparams - ): + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: if not obj: return () return ( @@ -582,22 +681,37 @@ class _CacheKeyTraversal(ExtendedInternalTraversal): ) def visit_inspectable_list( - self, attrname, obj, parent, anon_map, bindparams - ): + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: return self.visit_has_cache_key_list( attrname, [inspect(o) for o in obj], parent, anon_map, bindparams ) def visit_clauseelement_tuples( - self, attrname, obj, parent, anon_map, bindparams - ): + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: return self.visit_has_cache_key_tuples( attrname, obj, parent, anon_map, bindparams ) def visit_fromclause_ordered_set( - self, attrname, obj, parent, anon_map, bindparams - ): + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: if not obj: return () return ( @@ -606,8 +720,13 @@ class _CacheKeyTraversal(ExtendedInternalTraversal): ) def visit_clauseelement_unordered_set( - self, attrname, obj, parent, anon_map, bindparams - ): + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: if not obj: return () cache_keys = [ @@ -621,13 +740,23 @@ class _CacheKeyTraversal(ExtendedInternalTraversal): ) def visit_named_ddl_element( - self, attrname, obj, parent, anon_map, bindparams - ): + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: return (attrname, obj.name) def visit_prefix_sequence( - self, attrname, obj, parent, anon_map, bindparams - ): + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: if not obj: return () @@ -642,8 +771,13 @@ class _CacheKeyTraversal(ExtendedInternalTraversal): ) def visit_setup_join_tuple( - self, attrname, obj, parent, anon_map, bindparams - ): + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: return tuple( ( target._gen_cache_key(anon_map, bindparams), @@ -659,8 +793,13 @@ class _CacheKeyTraversal(ExtendedInternalTraversal): ) def visit_table_hint_list( - self, attrname, obj, parent, anon_map, bindparams - ): + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: if not obj: return () @@ -678,12 +817,24 @@ class _CacheKeyTraversal(ExtendedInternalTraversal): ), ) - def visit_plain_dict(self, attrname, obj, parent, anon_map, bindparams): + def visit_plain_dict( + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: return (attrname, tuple([(key, obj[key]) for key in sorted(obj)])) def visit_dialect_options( - self, attrname, obj, parent, anon_map, bindparams - ): + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: return ( attrname, tuple( @@ -701,8 +852,13 @@ class _CacheKeyTraversal(ExtendedInternalTraversal): ) def visit_string_clauseelement_dict( - self, attrname, obj, parent, anon_map, bindparams - ): + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: return ( attrname, tuple( @@ -712,8 +868,13 @@ class _CacheKeyTraversal(ExtendedInternalTraversal): ) def visit_string_multi_dict( - self, attrname, obj, parent, anon_map, bindparams - ): + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: return ( attrname, tuple( @@ -728,8 +889,13 @@ class _CacheKeyTraversal(ExtendedInternalTraversal): ) def visit_fromclause_canonical_column_collection( - self, attrname, obj, parent, anon_map, bindparams - ): + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: # inlining into the internals of ColumnCollection return ( attrname, @@ -740,14 +906,24 @@ class _CacheKeyTraversal(ExtendedInternalTraversal): ) def visit_unknown_structure( - self, attrname, obj, parent, anon_map, bindparams - ): + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: anon_map[NO_CACHE] = True return () def visit_dml_ordered_values( - self, attrname, obj, parent, anon_map, bindparams - ): + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: return ( attrname, tuple( @@ -761,7 +937,14 @@ class _CacheKeyTraversal(ExtendedInternalTraversal): ), ) - def visit_dml_values(self, attrname, obj, parent, anon_map, bindparams): + def visit_dml_values( + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: # in py37 we can assume two dictionaries created in the same # insert ordering will retain that sorting return ( @@ -778,8 +961,13 @@ class _CacheKeyTraversal(ExtendedInternalTraversal): ) def visit_dml_multi_values( - self, attrname, obj, parent, anon_map, bindparams - ): + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: # multivalues are simply not cacheable right now anon_map[NO_CACHE] = True return () |