diff options
Diffstat (limited to 'lib/sqlalchemy/sql')
-rw-r--r-- | lib/sqlalchemy/sql/coercions.py | 41 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/elements.py | 49 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/lambdas.py | 980 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/selectable.py | 32 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/traversals.py | 68 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/visitors.py | 34 |
6 files changed, 880 insertions, 324 deletions
diff --git a/lib/sqlalchemy/sql/coercions.py b/lib/sqlalchemy/sql/coercions.py index 588c485ae..fa0f9c435 100644 --- a/lib/sqlalchemy/sql/coercions.py +++ b/lib/sqlalchemy/sql/coercions.py @@ -7,10 +7,13 @@ import numbers import re +import types from . import operators from . import roles from . import visitors +from .base import Options +from .traversals import HasCacheKey from .visitors import Visitable from .. import exc from .. import inspection @@ -33,11 +36,36 @@ def _is_literal(element): of a SQL expression construct. """ + return not isinstance( - element, (Visitable, schema.SchemaEventTarget) + element, (Visitable, schema.SchemaEventTarget), ) and not hasattr(element, "__clause_element__") +def _deep_is_literal(element): + """Return whether or not the element is a "literal" in the context + of a SQL expression construct. + + does a deeper more esoteric check than _is_literal. is used + for lambda elements that have to distinguish values that would + be bound vs. not without any context. + + """ + + return ( + not isinstance( + element, + (Visitable, schema.SchemaEventTarget, HasCacheKey, Options,), + ) + and not hasattr(element, "__clause_element__") + and ( + not isinstance(element, type) + or not issubclass(element, HasCacheKey) + ) + and not isinstance(element, types.FunctionType) + ) + + def _document_text_coercion(paramname, meth_rst, param_rst): return util.add_parameter_text( paramname, @@ -711,9 +739,16 @@ class StatementImpl(_NoTextCoercion, RoleImpl): class CoerceTextStatementImpl(_CoerceLiterals, RoleImpl): __slots__ = () - def _literal_coercion(self, element, **kw): + def _dont_literal_coercion(self, element, **kw): if callable(element) and hasattr(element, "__code__"): - return lambdas.StatementLambdaElement(element, self._role_class) + return lambdas.StatementLambdaElement( + element, + self._role_class, + additional_cache_criteria=kw.get( + "additional_cache_criteria", () + ), + tracked=kw["tra"], + ) else: return super(CoerceTextStatementImpl, self)._literal_coercion( element, **kw diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index ca73a4392..8a506446d 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -32,8 +32,8 @@ from .base import NO_ARG from .base import PARSE_AUTOCOMMIT from .base import SingletonConstant from .coercions import _document_text_coercion -from .traversals import _copy_internals from .traversals import _get_children +from .traversals import HasCopyInternals from .traversals import MemoizedHasCacheKey from .traversals import NO_CACHE from .visitors import cloned_traverse @@ -182,6 +182,7 @@ class ClauseElement( roles.SQLRole, SupportsWrappingAnnotations, MemoizedHasCacheKey, + HasCopyInternals, Traversible, ): """Base class for elements of a programmatically constructed SQL @@ -372,35 +373,6 @@ class ClauseElement( """ return traversals.compare(self, other, **kw) - def _copy_internals(self, omit_attrs=(), **kw): - """Reassign internal elements to be clones of themselves. - - Called during a copy-and-traverse operation on newly - shallow-copied elements to create a deep copy. - - The given clone function should be used, which may be applying - additional transformations to the element (i.e. replacement - traversal, cloned traversal, annotations). - - """ - - try: - traverse_internals = self._traverse_internals - except AttributeError: - # user-defined classes may not have a _traverse_internals - return - - for attrname, obj, meth in _copy_internals.run_generated_dispatch( - self, traverse_internals, "_generated_copy_internals_traversal" - ): - if attrname in omit_attrs: - continue - - if obj is not None: - result = meth(self, attrname, obj, **kw) - if result is not None: - setattr(self, attrname, result) - def get_children(self, omit_attrs=(), **kw): r"""Return immediate child :class:`.visitors.Traversible` elements of this :class:`.visitors.Traversible`. @@ -535,8 +507,6 @@ class ClauseElement( else: elem_cache_key = None - cache_hit = False - if elem_cache_key: cache_key, extracted_params = elem_cache_key key = ( @@ -549,6 +519,7 @@ class ClauseElement( compiled_sql = compiled_cache.get(key) if compiled_sql is None: + cache_hit = dialect.CACHE_MISS compiled_sql = self._compiler( dialect, cache_key=elem_cache_key, @@ -559,7 +530,7 @@ class ClauseElement( ) compiled_cache[key] = compiled_sql else: - cache_hit = True + cache_hit = dialect.CACHE_HIT else: extracted_params = None compiled_sql = self._compiler( @@ -570,6 +541,11 @@ class ClauseElement( schema_translate_map=schema_translate_map, **kw ) + cache_hit = ( + dialect.CACHING_DISABLED + if compiled_cache is None + else dialect.NO_CACHE_KEY + ) return compiled_sql, extracted_params, cache_hit @@ -1343,10 +1319,7 @@ class BindParameter(roles.InElementRole, ColumnElement): if required is NO_ARG: required = value is NO_ARG and callable_ is None if value is NO_ARG: - self._value_required_for_cache = False value = None - else: - self._value_required_for_cache = True if quote is not None: key = quoted_name(key, quote) @@ -1412,6 +1385,7 @@ class BindParameter(roles.InElementRole, ColumnElement): """Return a copy of this :class:`.BindParameter` with the given value set. """ + cloned = self._clone(maintain_key=maintain_key) cloned.value = value cloned.callable = None @@ -1465,7 +1439,8 @@ class BindParameter(roles.InElementRole, ColumnElement): anon_map[idself] = id_ = str(anon_map.index) anon_map.index += 1 - bindparams.append(self) + if bindparams is not None: + bindparams.append(self) return ( id_, diff --git a/lib/sqlalchemy/sql/lambdas.py b/lib/sqlalchemy/sql/lambdas.py index 792411189..327003902 100644 --- a/lib/sqlalchemy/sql/lambdas.py +++ b/lib/sqlalchemy/sql/lambdas.py @@ -8,6 +8,7 @@ import itertools import operator import sys +import types import weakref from . import coercions @@ -17,29 +18,23 @@ from . import schema from . import traversals from . import type_api from . import visitors +from .base import _clone from .operators import ColumnOperators from .. import exc from .. import inspection from .. import util from ..util import collections_abc -_trackers = weakref.WeakKeyDictionary() +_closure_per_cache_key = util.LRUCache(1000) -_TRACKERS = 0 -_STALE_CHECK = 1 -_REAL_FN = 2 -_EXPR = 3 -_IS_SEQUENCE = 4 -_PROPAGATE_ATTRS = 5 - - -def lambda_stmt(lmb): +def lambda_stmt(lmb, **opts): """Produce a SQL statement that is cached as a lambda. - This SQL statement will only be constructed if element has not been - compiled yet. The approach is used to save on Python function overhead - when constructing statements that will be cached. + The Python code object within the lambda is scanned for both Python + literals that will become bound parameters as well as closure variables + that refer to Core or ORM constructs that may vary. The lambda itself + will be invoked only once per particular set of constructs detected. E.g.:: @@ -60,7 +55,8 @@ def lambda_stmt(lmb): """ - return coercions.expect(roles.CoerceTextStatementRole, lmb) + + return StatementLambdaElement(lmb, roles.CoerceTextStatementRole, **opts) class LambdaElement(elements.ClauseElement): @@ -87,64 +83,108 @@ class LambdaElement(elements.ClauseElement): _is_lambda_element = True - _resolved_bindparams = () - _traverse_internals = [ ("_resolved", visitors.InternalTraversal.dp_clauseelement) ] + _transforms = () + + parent_lambda = None + def __repr__(self): return "%s(%r)" % (self.__class__.__name__, self.fn.__code__) def __init__(self, fn, role, apply_propagate_attrs=None, **kw): self.fn = fn self.role = role - self.parent_lambda = None + self.tracker_key = (fn.__code__,) if apply_propagate_attrs is None and ( role is roles.CoerceTextStatementRole ): apply_propagate_attrs = self - if fn.__code__ not in _trackers: - rec = self._initialize_var_trackers( - role, apply_propagate_attrs, kw + rec = self._retrieve_tracker_rec(fn, apply_propagate_attrs, kw) + + if apply_propagate_attrs is not None: + propagate_attrs = rec.propagate_attrs + if propagate_attrs: + apply_propagate_attrs._propagate_attrs = propagate_attrs + + def _retrieve_tracker_rec(self, fn, apply_propagate_attrs, kw): + lambda_cache = kw.get("lambda_cache", _closure_per_cache_key) + + tracker_key = self.tracker_key + + fn = self.fn + closure = fn.__closure__ + + tracker = AnalyzedCode.get( + fn, + self, + kw, + track_bound_values=kw.get("track_bound_values", True), + enable_tracking=kw.get("enable_tracking", True), + track_on=kw.get("track_on", None), + ) + + self._resolved_bindparams = bindparams = [] + + anon_map = traversals.anon_map() + cache_key = tuple( + [ + getter(closure, kw, anon_map, bindparams) + for getter in tracker.closure_trackers + ] + ) + if self.parent_lambda is not None: + cache_key = self.parent_lambda.closure_cache_key + cache_key + + self.closure_cache_key = cache_key + + try: + rec = lambda_cache[tracker_key + cache_key] + except KeyError: + rec = None + + if rec is None: + rec = AnalyzedFunction( + tracker, self, apply_propagate_attrs, kw, fn ) + rec.closure_bindparams = bindparams + lambda_cache[tracker_key + cache_key] = rec else: - rec = _trackers[self.fn.__code__] - closure = fn.__closure__ + bindparams[:] = [ + orig_bind._with_value(new_bind.value, maintain_key=True) + for orig_bind, new_bind in zip( + rec.closure_bindparams, bindparams + ) + ] + + if self.parent_lambda is not None: + bindparams[:0] = self.parent_lambda._resolved_bindparams - # check if the objects fixed inside the lambda that we've cached - # have been changed. This can apply to things like mappers that - # were recreated in test suites. if so, re-initialize. - # - # this is a small performance hit on every use for a not very - # common situation, however it's very hard to debug if the - # condition does occur. - for idx, obj in rec[_STALE_CHECK]: - if closure[idx].cell_contents is not obj: - rec = self._initialize_var_trackers( - role, apply_propagate_attrs, kw - ) - break self._rec = rec - if apply_propagate_attrs is not None: - propagate_attrs = rec[_PROPAGATE_ATTRS] - if propagate_attrs: - apply_propagate_attrs._propagate_attrs = propagate_attrs + lambda_element = self + while lambda_element is not None: + rec = lambda_element._rec + if rec.bindparam_trackers: + tracker_instrumented_fn = rec.tracker_instrumented_fn + for tracker in rec.bindparam_trackers: + tracker( + lambda_element.fn, tracker_instrumented_fn, bindparams + ) + lambda_element = lambda_element.parent_lambda - if rec[_TRACKERS]: - self._resolved_bindparams = bindparams = [] - for tracker in rec[_TRACKERS]: - tracker(self.fn, bindparams) + return rec def __getattr__(self, key): - return getattr(self._rec[_EXPR], key) + return getattr(self._rec.expected_expr, key) @property def _is_sequence(self): - return self._rec[_IS_SEQUENCE] + return self._rec.is_sequence @property def _select_iterable(self): @@ -169,8 +209,7 @@ class LambdaElement(elements.ClauseElement): def _param_dict(self): return {b.key: b.value for b in self._resolved_bindparams} - @util.memoized_property - def _resolved(self): + def _setup_binds_for_tracked_expr(self, expr): bindparam_lookup = {b.key: b for b in self._resolved_bindparams} def replace(thing): @@ -179,17 +218,11 @@ class LambdaElement(elements.ClauseElement): and thing.key in bindparam_lookup ): bind = bindparam_lookup[thing.key] - # TODO: consider - # if we should clone the bindparam here, re-cache the new - # version, etc. also we make an assumption about "expanding" - # in this case. if thing.expanding: bind.expanding = True return bind - expr = self._rec[_EXPR] - - if self._rec[_IS_SEQUENCE]: + if self._rec.is_sequence: expr = [ visitors.replacement_traverse(sub_expr, {}, replace) for sub_expr in expr @@ -199,9 +232,39 @@ class LambdaElement(elements.ClauseElement): return expr + def _copy_internals( + self, clone=_clone, deferred_copy_internals=None, **kw + ): + # TODO: this needs A LOT of tests + self._resolved = clone( + self._resolved, + deferred_copy_internals=deferred_copy_internals, + **kw + ) + + @util.memoized_property + def _resolved(self): + expr = self._rec.expected_expr + + if self._resolved_bindparams: + expr = self._setup_binds_for_tracked_expr(expr) + + return expr + def _gen_cache_key(self, anon_map, bindparams): - cache_key = (self.fn.__code__, self.__class__) + cache_key = ( + self.fn.__code__, + self.__class__, + ) + self.closure_cache_key + + parent = self.parent_lambda + while parent is not None: + cache_key = ( + (parent.fn.__code__,) + parent.closure_cache_key + cache_key + ) + + parent = parent.parent_lambda if self._resolved_bindparams: bindparams.extend(self._resolved_bindparams) @@ -211,101 +274,51 @@ class LambdaElement(elements.ClauseElement): def _invoke_user_fn(self, fn, *arg): return fn() - def _initialize_var_trackers(self, role, apply_propagate_attrs, coerce_kw): - fn = self.fn - # track objects referenced inside of lambdas, create bindparams - # ahead of time for literal values. If bindparams are produced, - # then rewrite the function globals and closure as necessary so that - # it refers to the bindparams, then invoke the function - new_closure = {} - new_globals = fn.__globals__.copy() - tracker_collection = [] - check_closure_for_stale = [] +class DeferredLambdaElement(LambdaElement): + """A LambdaElement where the lambda accepts arguments and is + invoked within the compile phase with special context. - for name in fn.__code__.co_names: - if name not in new_globals: - continue + This lambda doesn't normally produce its real SQL expression outside of the + compile phase. It is passed a fixed set of initial arguments + so that it can generate a sample expression. - bound_value = _roll_down_to_literal(new_globals[name]) + """ - if coercions._is_literal(bound_value): - new_globals[name] = bind = PyWrapper(name, bound_value) - tracker_collection.append(_globals_tracker(name, bind)) + def __init__(self, fn, role, lambda_args=(), **kw): + self.lambda_args = lambda_args + self.coerce_kw = kw + super(DeferredLambdaElement, self).__init__(fn, role, **kw) - if fn.__closure__: - for closure_index, (fv, cell) in enumerate( - zip(fn.__code__.co_freevars, fn.__closure__) - ): + def _invoke_user_fn(self, fn, *arg): + return fn(*self.lambda_args) - bound_value = _roll_down_to_literal(cell.cell_contents) + def _resolve_with_args(self, *lambda_args): + tracker_fn = self._rec.tracker_instrumented_fn + expr = tracker_fn(*lambda_args) - if coercions._is_literal(bound_value): - new_closure[fv] = bind = PyWrapper(fv, bound_value) - tracker_collection.append( - _closure_tracker(fv, bind, closure_index) - ) - else: - new_closure[fv] = cell.cell_contents - # for normal cell contents, add them to a list that - # we can compare later when we get new lambdas. if - # any identities have changed, then we will recalculate - # the whole lambda and run it again. - check_closure_for_stale.append( - (closure_index, cell.cell_contents) - ) + expr = coercions.expect(self.role, expr, **self.coerce_kw) - if tracker_collection: - new_fn = _rewrite_code_obj( - fn, - [new_closure[name] for name in fn.__code__.co_freevars], - new_globals, - ) - expr = self._invoke_user_fn(new_fn) + if self._resolved_bindparams: + expr = self._setup_binds_for_tracked_expr(expr) - else: - new_fn = fn - expr = self._invoke_user_fn(new_fn) - tracker_collection = [] + # TODO: TEST TEST TEST, this is very out there + for deferred_copy_internals in self._transforms: + expr = deferred_copy_internals(expr) - if self.parent_lambda is None: - if isinstance(expr, collections_abc.Sequence): - expected_expr = [ - coercions.expect( - role, - sub_expr, - apply_propagate_attrs=apply_propagate_attrs, - **coerce_kw - ) - for sub_expr in expr - ] - is_sequence = True - else: - expected_expr = coercions.expect( - role, - expr, - apply_propagate_attrs=apply_propagate_attrs, - **coerce_kw - ) - is_sequence = False - else: - expected_expr = expr - is_sequence = False + return expr - if apply_propagate_attrs is not None: - propagate_attrs = apply_propagate_attrs._propagate_attrs - else: - propagate_attrs = util.immutabledict() - - rec = _trackers[self.fn.__code__] = ( - tracker_collection, - check_closure_for_stale, - new_fn, - expected_expr, - is_sequence, - propagate_attrs, + def _copy_internals( + self, clone=_clone, deferred_copy_internals=None, **kw + ): + super(DeferredLambdaElement, self)._copy_internals( + clone=clone, deferred_copy_internals=deferred_copy_internals, **kw ) - return rec + + # TODO: A LOT A LOT of tests. for _resolve_with_args, we don't know + # our expression yet. so hold onto the replacement + if deferred_copy_internals: + self._transforms += (deferred_copy_internals,) class StatementLambdaElement(roles.AllowsLambdaRole, LambdaElement): @@ -334,13 +347,38 @@ class StatementLambdaElement(roles.AllowsLambdaRole, LambdaElement): """ + def __init__(self, fn, parent_lambda, **kw): + self._default_kw = default_kw = {} + global_track_bound_values = kw.pop("global_track_bound_values", None) + if global_track_bound_values is not None: + default_kw["track_bound_values"] = global_track_bound_values + kw["track_bound_values"] = global_track_bound_values + + if "lambda_cache" in kw: + default_kw["lambda_cache"] = kw["lambda_cache"] + + super(StatementLambdaElement, self).__init__(fn, parent_lambda, **kw) + def __add__(self, other): - return LinkedLambdaElement(other, parent_lambda=self) + return LinkedLambdaElement( + other, parent_lambda=self, **self._default_kw + ) + + def add_criteria(self, other, **kw): + if self._default_kw: + if kw: + default_kw = self._default_kw.copy() + default_kw.update(kw) + kw = default_kw + else: + kw = self._default_kw + + return LinkedLambdaElement(other, parent_lambda=self, **kw) def _execute_on_connection( self, connection, multiparams, params, execution_options ): - if self._rec[_EXPR].supports_execution: + if self._rec.expected_expr.supports_execution: return connection._execute_clauseelement( self, multiparams, params, execution_options ) @@ -349,93 +387,579 @@ class StatementLambdaElement(roles.AllowsLambdaRole, LambdaElement): @property def _with_options(self): - return self._rec[_EXPR]._with_options + return self._rec.expected_expr._with_options @property def _effective_plugin_target(self): - return self._rec[_EXPR]._effective_plugin_target + return self._rec.expected_expr._effective_plugin_target @property def _is_future(self): - return self._rec[_EXPR]._is_future + return self._rec.expected_expr._is_future @property def _execution_options(self): - return self._rec[_EXPR]._execution_options + return self._rec.expected_expr._execution_options + + def spoil(self): + """Return a new :class:`.StatementLambdaElement` that will run + all lambdas unconditionally each time. + + """ + return NullLambdaStatement(self.fn()) + + +class NullLambdaStatement(roles.AllowsLambdaRole, elements.ClauseElement): + """Provides the :class:`.StatementLambdaElement` API but does not + cache or analyze lambdas. + + the lambdas are instead invoked immediately. + + The intended use is to isolate issues that may arise when using + lambda statements. + + """ + + __visit_name__ = "lambda_element" + + _is_lambda_element = True + + _traverse_internals = [ + ("_resolved", visitors.InternalTraversal.dp_clauseelement) + ] + + def __init__(self, statement): + self._resolved = statement + self._propagate_attrs = statement._propagate_attrs + + def __getattr__(self, key): + return getattr(self._resolved, key) + + def __add__(self, other): + statement = other(self._resolved) + + return NullLambdaStatement(statement) + + def add_criteria(self, other, **kw): + statement = other(self._resolved) + + return NullLambdaStatement(statement) + + def _execute_on_connection( + self, connection, multiparams, params, execution_options + ): + if self._resolved.supports_execution: + return connection._execute_clauseelement( + self, multiparams, params, execution_options + ) + else: + raise exc.ObjectNotExecutableError(self) class LinkedLambdaElement(StatementLambdaElement): + """Represent subsequent links of a :class:`.StatementLambdaElement`.""" + + role = None + def __init__(self, fn, parent_lambda, **kw): + self._default_kw = parent_lambda._default_kw + self.fn = fn self.parent_lambda = parent_lambda - role = None - apply_propagate_attrs = self + self.tracker_key = parent_lambda.tracker_key + (fn.__code__,) + self._retrieve_tracker_rec(fn, self, kw) + self._propagate_attrs = parent_lambda._propagate_attrs + + def _invoke_user_fn(self, fn, *arg): + return fn(self.parent_lambda._resolved) + + +class AnalyzedCode(object): + __slots__ = ( + "track_closure_variables", + "track_bound_values", + "bindparam_trackers", + "closure_trackers", + "build_py_wrappers", + ) + _fns = weakref.WeakKeyDictionary() + + @classmethod + def get(cls, fn, lambda_element, lambda_kw, **kw): + try: + # TODO: validate kw haven't changed? + return cls._fns[fn.__code__] + except KeyError: + pass + cls._fns[fn.__code__] = analyzed = AnalyzedCode( + fn, lambda_element, lambda_kw, **kw + ) + return analyzed + + def __init__( + self, + fn, + lambda_element, + lambda_kw, + track_bound_values=True, + enable_tracking=True, + track_on=None, + ): + closure = fn.__closure__ + + self.track_closure_variables = not track_on + + self.track_bound_values = track_bound_values + + # a list of callables generated from _bound_parameter_getter_* + # functions. Each of these uses a PyWrapper object to retrieve + # a parameter value + self.bindparam_trackers = [] + + # a list of callables generated from _cache_key_getter_* functions + # these callables work to generate a cache key for the lambda + # based on what's inside its closure variables. + self.closure_trackers = [] - if fn.__code__ not in _trackers: - rec = self._initialize_var_trackers( - role, apply_propagate_attrs, kw + self.build_py_wrappers = [] + + if enable_tracking: + if track_on: + self._init_track_on(track_on) + + self._init_globals(fn) + + if closure: + self._init_closure(fn) + + self._setup_additional_closure_trackers(fn, lambda_element, lambda_kw) + + def _init_track_on(self, track_on): + self.closure_trackers.extend( + self._cache_key_getter_track_on(idx, elem) + for idx, elem in enumerate(track_on) + ) + + def _init_globals(self, fn): + build_py_wrappers = self.build_py_wrappers + bindparam_trackers = self.bindparam_trackers + track_bound_values = self.track_bound_values + + for name in fn.__code__.co_names: + if name not in fn.__globals__: + continue + + _bound_value = self._roll_down_to_literal(fn.__globals__[name]) + + if coercions._deep_is_literal(_bound_value): + build_py_wrappers.append((name, None)) + if track_bound_values: + bindparam_trackers.append( + self._bound_parameter_getter_func_globals(name) + ) + + def _init_closure(self, fn): + build_py_wrappers = self.build_py_wrappers + closure = fn.__closure__ + + track_bound_values = self.track_bound_values + track_closure_variables = self.track_closure_variables + bindparam_trackers = self.bindparam_trackers + closure_trackers = self.closure_trackers + + for closure_index, (fv, cell) in enumerate( + zip(fn.__code__.co_freevars, closure) + ): + _bound_value = self._roll_down_to_literal(cell.cell_contents) + + if coercions._deep_is_literal(_bound_value): + build_py_wrappers.append((fv, closure_index)) + if track_bound_values: + bindparam_trackers.append( + self._bound_parameter_getter_func_closure( + fv, closure_index + ) + ) + else: + # for normal cell contents, add them to a list that + # we can compare later when we get new lambdas. if + # any identities have changed, then we will + # recalculate the whole lambda and run it again. + + if track_closure_variables: + closure_trackers.append( + self._cache_key_getter_closure_variable( + closure_index, cell.cell_contents + ) + ) + + def _setup_additional_closure_trackers( + self, fn, lambda_element, lambda_kw + ): + # an additional step is to actually run the function, then + # go through the PyWrapper objects that were set up to catch a bound + # parameter. then if they *didn't* make a param, oh they're another + # object in the closure we have to track for our cache key. so + # create trackers to catch those. + + analyzed_function = AnalyzedFunction( + self, lambda_element, None, lambda_kw, fn, + ) + + closure_trackers = self.closure_trackers + + for pywrapper in analyzed_function.closure_pywrappers: + if not pywrapper._sa__has_param: + closure_trackers.append( + self._cache_key_getter_tracked_literal(pywrapper) + ) + + @classmethod + def _roll_down_to_literal(cls, element): + is_clause_element = hasattr(element, "__clause_element__") + + if is_clause_element: + while not isinstance( + element, (elements.ClauseElement, schema.SchemaItem) + ): + try: + element = element.__clause_element__() + except AttributeError: + break + + if not is_clause_element: + insp = inspection.inspect(element, raiseerr=False) + if insp is not None: + try: + return insp.__clause_element__() + except AttributeError: + return insp + + # TODO: should we coerce consts None/True/False here? + return element + else: + return element + + def _bound_parameter_getter_func_globals(self, name): + """Return a getter that will extend a list of bound parameters + with new entries from the ``__globals__`` collection of a particular + lambda. + + """ + + def extract_parameter_value( + current_fn, tracker_instrumented_fn, result + ): + wrapper = tracker_instrumented_fn.__globals__[name] + object.__getattribute__(wrapper, "_extract_bound_parameters")( + current_fn.__globals__[name], result + ) + + return extract_parameter_value + + def _bound_parameter_getter_func_closure(self, name, closure_index): + """Return a getter that will extend a list of bound parameters + with new entries from the ``__closure__`` collection of a particular + lambda. + + """ + + def extract_parameter_value( + current_fn, tracker_instrumented_fn, result + ): + wrapper = tracker_instrumented_fn.__closure__[ + closure_index + ].cell_contents + object.__getattribute__(wrapper, "_extract_bound_parameters")( + current_fn.__closure__[closure_index].cell_contents, result + ) + + return extract_parameter_value + + def _cache_key_getter_track_on(self, idx, elem): + """Return a getter that will extend a cache key with new entries + from the "track_on" parameter passed to a :class:`.LambdaElement`. + + """ + if isinstance(elem, traversals.HasCacheKey): + + def get(closure, kw, anon_map, bindparams): + return kw["track_on"][idx]._gen_cache_key(anon_map, bindparams) + + else: + + def get(closure, kw, anon_map, bindparams): + return kw["track_on"][idx] + + return get + + def _cache_key_getter_closure_variable(self, idx, cell_contents): + """Return a getter that will extend a cache key with new entries + from the ``__closure__`` collection of a particular lambda. + + """ + + if isinstance(cell_contents, traversals.HasCacheKey): + + def get(closure, kw, anon_map, bindparams): + return closure[idx].cell_contents._gen_cache_key( + anon_map, bindparams + ) + + elif isinstance(cell_contents, types.FunctionType): + + def get(closure, kw, anon_map, bindparams): + return closure[idx].cell_contents.__code__ + + elif cell_contents.__hash__ is None: + # this covers dict, etc. + def get(closure, kw, anon_map, bindparams): + return () + + else: + + def get(closure, kw, anon_map, bindparams): + return closure[idx].cell_contents + + return get + + def _cache_key_getter_tracked_literal(self, pytracker): + """Return a getter that will extend a cache key with new entries + from the ``__closure__`` collection of a particular lambda. + + this getter differs from _cache_key_getter_closure_variable + in that these are detected after the function is run, and PyWrapper + objects have recorded that a particular literal value is in fact + not being interpreted as a bound parameter. + + """ + + elem = pytracker._sa__to_evaluate + closure_index = pytracker._sa__closure_index + + if isinstance(elem, set): + raise exc.ArgumentError( + "Can't create a cache key for lambda closure variable " + '"%s" because it\'s a set. try using a list' + % pytracker._sa__name ) + + elif isinstance(elem, list): + + def get(closure, kw, anon_map, bindparams): + return tuple( + elem._gen_cache_key(anon_map, bindparams) + for elem in closure[closure_index].cell_contents + ) + + elif elem.__hash__ is None: + # this covers dict, etc. + def get(closure, kw, anon_map, bindparams): + return () + else: - rec = _trackers[self.fn.__code__] + def get(closure, kw, anon_map, bindparams): + return closure[closure_index].cell_contents + + return get + + +class AnalyzedFunction(object): + __slots__ = ( + "analyzed_code", + "fn", + "closure_pywrappers", + "tracker_instrumented_fn", + "expr", + "bindparam_trackers", + "expected_expr", + "is_sequence", + "propagate_attrs", + "closure_bindparams", + ) + + def __init__( + self, analyzed_code, lambda_element, apply_propagate_attrs, kw, fn, + ): + self.analyzed_code = analyzed_code + self.fn = fn + + self.bindparam_trackers = analyzed_code.bindparam_trackers + + self._instrument_and_run_function(lambda_element) + + self._coerce_expression(lambda_element, apply_propagate_attrs, kw) + + def _instrument_and_run_function(self, lambda_element): + analyzed_code = self.analyzed_code + + fn = self.fn + self.closure_pywrappers = closure_pywrappers = [] + + build_py_wrappers = analyzed_code.build_py_wrappers + + if not build_py_wrappers: + self.tracker_instrumented_fn = tracker_instrumented_fn = fn + self.expr = lambda_element._invoke_user_fn(tracker_instrumented_fn) + else: + track_closure_variables = analyzed_code.track_closure_variables closure = fn.__closure__ - # check if objects referred to by the lambda have changed and - # re-scan the lambda if so. see comments for this same section in - # LambdaElement. - for idx, obj in rec[_STALE_CHECK]: - if closure[idx].cell_contents is not obj: - rec = self._initialize_var_trackers( - role, apply_propagate_attrs, kw + # will form the __closure__ of the function when we rebuild it + if closure: + new_closure = { + fv: cell.cell_contents + for fv, cell in zip(fn.__code__.co_freevars, closure) + } + else: + new_closure = {} + + # will form the __globals__ of the function when we rebuild it + new_globals = fn.__globals__.copy() + + for name, closure_index in build_py_wrappers: + if closure_index is not None: + value = closure[closure_index].cell_contents + new_closure[name] = bind = PyWrapper( + name, value, closure_index=closure_index ) - break + if track_closure_variables: + closure_pywrappers.append(bind) + else: + value = fn.__globals__[name] + new_globals[name] = bind = PyWrapper(name, value) + + # rewrite the original fn. things that look like they will + # become bound parameters are wrapped in a PyWrapper. + self.tracker_instrumented_fn = ( + tracker_instrumented_fn + ) = self._rewrite_code_obj( + fn, + [new_closure[name] for name in fn.__code__.co_freevars], + new_globals, + ) - self._rec = rec + # now invoke the function. This will give us a new SQL + # expression, but all the places that there would be a bound + # parameter, the PyWrapper in its place will give us a bind + # with a predictable name we can match up later. - self._propagate_attrs = parent_lambda._propagate_attrs + # additionally, each PyWrapper will log that it did in fact + # create a parameter, otherwise, it's some kind of Python + # object in the closure and we want to track that, to make + # sure it doesn't change to somehting else, or if it does, + # that we create a different tracked function with that + # variable. + self.expr = lambda_element._invoke_user_fn(tracker_instrumented_fn) - self._resolved_bindparams = bindparams = [] - rec = self._rec - while True: - if rec[_TRACKERS]: - for tracker in rec[_TRACKERS]: - tracker(self.fn, bindparams) - if self.parent_lambda is not None: - self = self.parent_lambda - rec = self._rec + def _coerce_expression(self, lambda_element, apply_propagate_attrs, kw): + """Run the tracker-generated expression through coercion rules. + + After the user-defined lambda has been invoked to produce a statement + for re-use, run it through coercion rules to both check that it's the + correct type of object and also to coerce it to its useful form. + + """ + + parent_lambda = lambda_element.parent_lambda + expr = self.expr + + if parent_lambda is None: + if isinstance(expr, collections_abc.Sequence): + self.expected_expr = [ + coercions.expect( + lambda_element.role, + sub_expr, + apply_propagate_attrs=apply_propagate_attrs, + **kw + ) + for sub_expr in expr + ] + self.is_sequence = True else: - break + self.expected_expr = coercions.expect( + lambda_element.role, + expr, + apply_propagate_attrs=apply_propagate_attrs, + **kw + ) + self.is_sequence = False + else: + self.expected_expr = expr + self.is_sequence = False - def _invoke_user_fn(self, fn, *arg): - return fn(self.parent_lambda._rec[_EXPR]) + if apply_propagate_attrs is not None: + self.propagate_attrs = apply_propagate_attrs._propagate_attrs + else: + self.propagate_attrs = util.EMPTY_DICT - def _gen_cache_key(self, anon_map, bindparams): - if self._resolved_bindparams: - bindparams.extend(self._resolved_bindparams) + def _rewrite_code_obj(self, f, cell_values, globals_): + """Return a copy of f, with a new closure and new globals - cache_key = (self.fn.__code__, self.__class__) + yes it works in pypy :P - parent = self.parent_lambda - while parent is not None: - cache_key = (parent.fn.__code__,) + cache_key - parent = parent.parent_lambda + """ - return cache_key + argrange = range(len(cell_values)) + + code = "def make_cells():\n" + if cell_values: + code += " (%s) = (%s)\n" % ( + ", ".join("i%d" % i for i in argrange), + ", ".join("o%d" % i for i in argrange), + ) + code += " def closure():\n" + code += " return %s\n" % ", ".join("i%d" % i for i in argrange) + code += " return closure.__closure__" + vars_ = {"o%d" % i: cell_values[i] for i in argrange} + exec(code, vars_, vars_) + closure = vars_["make_cells"]() + + func = type(f)( + f.__code__, globals_, f.__name__, f.__defaults__, closure + ) + if sys.version_info >= (3,): + func.__annotations__ = f.__annotations__ + func.__kwdefaults__ = f.__kwdefaults__ + func.__doc__ = f.__doc__ + func.__module__ = f.__module__ + + return func class PyWrapper(ColumnOperators): - def __init__(self, name, to_evaluate, getter=None): + """A wrapper object that is injected into the ``__globals__`` and + ``__closure__`` of a Python function. + + When the function is instrumented with :class:`.PyWrapper` objects, it is + then invoked just once in order to set up the wrappers. We look through + all the :class:`.PyWrapper` objects we made to find the ones that generated + a :class:`.BindParameter` object, e.g. the expression system interpreted + something as a literal. Those positions in the globals/closure are then + ones that we will look at, each time a new lambda comes in that refers to + the same ``__code__`` object. In this way, we keep a single version of + the SQL expression that this lambda produced, without calling upon the + Python function that created it more than once, unless its other closure + variables have changed. The expression is then transformed to have the + new bound values embedded into it. + + """ + + def __init__(self, name, to_evaluate, closure_index=None, getter=None): self._name = name self._to_evaluate = to_evaluate self._param = None + self._has_param = False self._bind_paths = {} self._getter = getter + self._closure_index = closure_index def __call__(self, *arg, **kw): elem = object.__getattribute__(self, "_to_evaluate") value = elem(*arg, **kw) - if coercions._is_literal(value) and not isinstance( + if coercions._deep_is_literal(value) and not isinstance( # TODO: coverage where an ORM option or similar is here value, traversals.HasCacheKey, @@ -481,8 +1005,8 @@ class PyWrapper(ColumnOperators): if param is None: name = object.__getattribute__(self, "_name") self._param = param = elements.BindParameter(name, unique=True) + self._has_param = True param.type = type_api._resolve_value_to_type(to_evaluate) - return param._with_value(to_evaluate, maintain_key=True) def __getattribute__(self, key): @@ -497,7 +1021,15 @@ class PyWrapper(ColumnOperators): else: return self._sa__add_getter(key, operator.attrgetter) + def __iter__(self): + elem = object.__getattribute__(self, "_to_evaluate") + return iter(elem) + def __getitem__(self, key): + elem = object.__getattribute__(self, "_to_evaluate") + if not hasattr(elem, "__getitem__"): + raise AttributeError("__getitem__") + if isinstance(key, PyWrapper): # TODO: coverage raise exc.InvalidRequestError( @@ -518,90 +1050,14 @@ class PyWrapper(ColumnOperators): elem = object.__getattribute__(self, "_to_evaluate") value = getter(elem) - if coercions._is_literal(value): - wrapper = PyWrapper(key, value, getter) + if coercions._deep_is_literal(value): + wrapper = PyWrapper(key, value, getter=getter) bind_paths[bind_path_key] = wrapper return wrapper else: return value -def _roll_down_to_literal(element): - is_clause_element = hasattr(element, "__clause_element__") - - if is_clause_element: - while not isinstance( - element, (elements.ClauseElement, schema.SchemaItem) - ): - try: - element = element.__clause_element__() - except AttributeError: - break - - if not is_clause_element: - insp = inspection.inspect(element, raiseerr=False) - if insp is not None: - try: - return insp.__clause_element__() - except AttributeError: - return insp - - # TODO: should we coerce consts None/True/False here? - return element - else: - return element - - -def _globals_tracker(name, wrapper): - def extract_parameter_value(current_fn, result): - object.__getattribute__(wrapper, "_extract_bound_parameters")( - current_fn.__globals__[name], result - ) - - return extract_parameter_value - - -def _closure_tracker(name, wrapper, closure_index): - def extract_parameter_value(current_fn, result): - object.__getattribute__(wrapper, "_extract_bound_parameters")( - current_fn.__closure__[closure_index].cell_contents, result - ) - - return extract_parameter_value - - -def _rewrite_code_obj(f, cell_values, globals_): - """Return a copy of f, with a new closure and new globals - - yes it works in pypy :P - - """ - - argrange = range(len(cell_values)) - - code = "def make_cells():\n" - if cell_values: - code += " (%s) = (%s)\n" % ( - ", ".join("i%d" % i for i in argrange), - ", ".join("o%d" % i for i in argrange), - ) - code += " def closure():\n" - code += " return %s\n" % ", ".join("i%d" % i for i in argrange) - code += " return closure.__closure__" - vars_ = {"o%d" % i: cell_values[i] for i in argrange} - exec(code, vars_, vars_) - closure = vars_["make_cells"]() - - func = type(f)(f.__code__, globals_, f.__name__, f.__defaults__, closure) - if sys.version_info >= (3,): - func.__annotations__ = f.__annotations__ - func.__kwdefaults__ = f.__kwdefaults__ - func.__doc__ = f.__doc__ - func.__module__ = f.__module__ - - return func - - @inspection._inspects(LambdaElement) def insp(lmb): return inspection.inspect(lmb._resolved) diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 1155c273b..d67b61743 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -2426,6 +2426,7 @@ class SelectBase( """ _is_select_statement = True + is_select = True def _generate_fromclause_column_proxies(self, fromclause): # type: (FromClause) -> None @@ -3867,19 +3868,19 @@ class Select( [ ("_raw_columns", InternalTraversal.dp_clauseelement_list), ("_from_obj", InternalTraversal.dp_clauseelement_list), - ("_where_criteria", InternalTraversal.dp_clauseelement_list), - ("_having_criteria", InternalTraversal.dp_clauseelement_list), - ("_order_by_clauses", InternalTraversal.dp_clauseelement_list,), - ("_group_by_clauses", InternalTraversal.dp_clauseelement_list,), + ("_where_criteria", InternalTraversal.dp_clauseelement_tuple), + ("_having_criteria", InternalTraversal.dp_clauseelement_tuple), + ("_order_by_clauses", InternalTraversal.dp_clauseelement_tuple,), + ("_group_by_clauses", InternalTraversal.dp_clauseelement_tuple,), ("_setup_joins", InternalTraversal.dp_setup_join_tuple,), ("_legacy_setup_joins", InternalTraversal.dp_setup_join_tuple,), - ("_correlate", InternalTraversal.dp_clauseelement_list), - ("_correlate_except", InternalTraversal.dp_clauseelement_list,), + ("_correlate", InternalTraversal.dp_clauseelement_tuple), + ("_correlate_except", InternalTraversal.dp_clauseelement_tuple,), ("_limit_clause", InternalTraversal.dp_clauseelement), ("_offset_clause", InternalTraversal.dp_clauseelement), ("_for_update_arg", InternalTraversal.dp_clauseelement), ("_distinct", InternalTraversal.dp_boolean), - ("_distinct_on", InternalTraversal.dp_clauseelement_list), + ("_distinct_on", InternalTraversal.dp_clauseelement_tuple), ("_label_style", InternalTraversal.dp_plain_obj), ("_is_future", InternalTraversal.dp_boolean), ] @@ -4345,7 +4346,7 @@ class Select( @_generative def join(self, target, onclause=None, isouter=False, full=False): - r"""Create a SQL JOIN against this :class:`_expresson.Select` + r"""Create a SQL JOIN against this :class:`_expression.Select` object's criterion and apply generatively, returning the newly resulting :class:`_expression.Select`. @@ -4474,7 +4475,7 @@ class Select( # they've become. This allows us to ensure the same cloned from # is used when other items such as columns are "cloned" - all_the_froms = list( + all_the_froms = set( itertools.chain( _from_objects(*self._raw_columns), _from_objects(*self._where_criteria), @@ -4490,10 +4491,15 @@ class Select( new_froms = {f: clone(f, **kw) for f in all_the_froms} # 2. copy FROM collections, adding in joins that we've created. - self._from_obj = tuple(clone(f, **kw) for f in self._from_obj) + tuple( - f for f in new_froms.values() if isinstance(f, Join) + existing_from_obj = [clone(f, **kw) for f in self._from_obj] + add_froms = ( + set(f for f in new_froms.values() if isinstance(f, Join)) + .difference(all_the_froms) + .difference(existing_from_obj) ) + self._from_obj = tuple(existing_from_obj) + tuple(add_froms) + # 3. clone everything else, making sure we use columns # corresponding to the froms we just made. def replace(obj, **kw): @@ -4687,6 +4693,7 @@ class Select( """ + assert isinstance(self._where_criteria, tuple) self._where_criteria += ( coercions.expect(roles.WhereHavingRole, whereclause), ) @@ -5371,6 +5378,9 @@ class TextualSelect(SelectBase): _is_textual = True + is_text = True + is_select = True + def __init__(self, text, columns, positional=False): self.element = text # convert for ORM attributes->columns, etc diff --git a/lib/sqlalchemy/sql/traversals.py b/lib/sqlalchemy/sql/traversals.py index f41480a94..cb38df6af 100644 --- a/lib/sqlalchemy/sql/traversals.py +++ b/lib/sqlalchemy/sql/traversals.py @@ -190,7 +190,10 @@ class HasCacheKey(object): # statements, not so much, but they usually won't have # annotations. result += self._annotations_cache_key - elif meth is InternalTraversal.dp_clauseelement_list: + elif ( + meth is InternalTraversal.dp_clauseelement_list + or meth is InternalTraversal.dp_clauseelement_tuple + ): result += ( attrname, tuple( @@ -390,6 +393,7 @@ class _CacheKey(ExtendedInternalTraversal): visit_has_cache_key = visit_clauseelement = CALL_GEN_CACHE_KEY visit_clauseelement_list = InternalTraversal.dp_clauseelement_list visit_annotations_key = InternalTraversal.dp_annotations_key + visit_clauseelement_tuple = InternalTraversal.dp_clauseelement_tuple visit_string = ( visit_boolean @@ -451,6 +455,8 @@ class _CacheKey(ExtendedInternalTraversal): tuple(elem._gen_cache_key(anon_map, bindparams) for elem in obj), ) + visit_executable_options = visit_has_cache_key_list + def visit_inspectable_list( self, attrname, obj, parent, anon_map, bindparams ): @@ -682,6 +688,41 @@ class _CacheKey(ExtendedInternalTraversal): _cache_key_traversal_visitor = _CacheKey() +class HasCopyInternals(object): + def _clone(self, **kw): + raise NotImplementedError() + + def _copy_internals(self, omit_attrs=(), **kw): + """Reassign internal elements to be clones of themselves. + + Called during a copy-and-traverse operation on newly + shallow-copied elements to create a deep copy. + + The given clone function should be used, which may be applying + additional transformations to the element (i.e. replacement + traversal, cloned traversal, annotations). + + """ + + try: + traverse_internals = self._traverse_internals + except AttributeError: + # user-defined classes may not have a _traverse_internals + return + + for attrname, obj, meth in _copy_internals.run_generated_dispatch( + self, traverse_internals, "_generated_copy_internals_traversal" + ): + if attrname in omit_attrs: + continue + + if obj is not None: + + result = meth(attrname, self, obj, **kw) + if result is not None: + setattr(self, attrname, result) + + class _CopyInternals(InternalTraversal): """Generate a _copy_internals internal traversal dispatch for classes with a _traverse_internals collection.""" @@ -696,6 +737,16 @@ class _CopyInternals(InternalTraversal): ): return [clone(clause, **kw) for clause in element] + def visit_clauseelement_tuple( + self, attrname, parent, element, clone=_clone, **kw + ): + return tuple([clone(clause, **kw) for clause in element]) + + def visit_executable_options( + self, attrname, parent, element, clone=_clone, **kw + ): + return tuple([clone(clause, **kw) for clause in element]) + def visit_clauseelement_unordered_set( self, attrname, parent, element, clone=_clone, **kw ): @@ -817,6 +868,9 @@ class _GetChildren(InternalTraversal): def visit_clauseelement_list(self, element, **kw): return element + def visit_clauseelement_tuple(self, element, **kw): + return element + def visit_clauseelement_tuples(self, element, **kw): return itertools.chain.from_iterable(element) @@ -840,8 +894,8 @@ class _GetChildren(InternalTraversal): if not isinstance(target, str): yield _flatten_clauseelement(target) - # if onclause is not None and not isinstance(onclause, str): - # yield _flatten_clauseelement(onclause) + if onclause is not None and not isinstance(onclause, str): + yield _flatten_clauseelement(onclause) def visit_dml_ordered_values(self, element, **kw): for k, v in element: @@ -1015,6 +1069,8 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots): ): return COMPARE_FAILED + visit_executable_options = visit_has_cache_key_list + def visit_clauseelement( self, attrname, left_parent, left, right_parent, right, **kw ): @@ -1057,6 +1113,12 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots): for l, r in util.zip_longest(left, right, fillvalue=None): self.stack.append((l, r)) + def visit_clauseelement_tuple( + self, attrname, left_parent, left, right_parent, right, **kw + ): + for l, r in util.zip_longest(left, right, fillvalue=None): + self.stack.append((l, r)) + def _compare_unordered_sequences(self, seq1, seq2, **kw): if seq1 is None: return seq2 is None diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py index 56d3c93b3..5cb3cba70 100644 --- a/lib/sqlalchemy/sql/visitors.py +++ b/lib/sqlalchemy/sql/visitors.py @@ -257,7 +257,7 @@ class InternalTraversal(util.with_metaclass(_InternalTraversalType, object)): """ - dp_clauseelement_tuples = symbol("CT") + dp_clauseelement_tuples = symbol("CTS") """Visit a list of tuples which contain :class:`_expression.ClauseElement` objects. @@ -268,6 +268,13 @@ class InternalTraversal(util.with_metaclass(_InternalTraversalType, object)): """ + dp_clauseelement_tuple = symbol("CT") + """Visit a tuple of :class:`_expression.ClauseElement` objects. + + """ + + dp_executable_options = symbol("EO") + dp_fromclause_ordered_set = symbol("CO") """Visit an ordered set of :class:`_expression.FromClause` objects. """ @@ -712,6 +719,9 @@ def cloned_traverse(obj, opts, visitors): cloned = {} stop_on = set(opts.get("stop_on", [])) + def deferred_copy_internals(obj): + return cloned_traverse(obj, opts, visitors) + def clone(elem, **kw): if elem in stop_on: return elem @@ -732,7 +742,7 @@ def cloned_traverse(obj, opts, visitors): return cloned[id(elem)] if obj is not None: - obj = clone(obj) + obj = clone(obj, deferred_copy_internals=deferred_copy_internals) clone = None # remove gc cycles return obj @@ -764,6 +774,9 @@ def replacement_traverse(obj, opts, replace): cloned = {} stop_on = {id(x) for x in opts.get("stop_on", [])} + def deferred_copy_internals(obj): + return replacement_traverse(obj, opts, replace) + def clone(elem, **kw): if ( id(elem) in stop_on @@ -776,19 +789,24 @@ def replacement_traverse(obj, opts, replace): stop_on.add(id(newelem)) return newelem else: - - if elem not in cloned: + # base "already seen" on id(), not hash, so that we don't + # replace an Annotated element with its non-annotated one, and + # vice versa + id_elem = id(elem) + if id_elem not in cloned: if "replace" in kw: newelem = kw["replace"](elem) if newelem is not None: - cloned[elem] = newelem + cloned[id_elem] = newelem return newelem - cloned[elem] = newelem = elem._clone() + cloned[id_elem] = newelem = elem._clone() newelem._copy_internals(clone=clone, **kw) - return cloned[elem] + return cloned[id_elem] if obj is not None: - obj = clone(obj, **opts) + obj = clone( + obj, deferred_copy_internals=deferred_copy_internals, **opts + ) clone = None # remove gc cycles return obj |