summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/sql')
-rw-r--r--lib/sqlalchemy/sql/coercions.py41
-rw-r--r--lib/sqlalchemy/sql/elements.py49
-rw-r--r--lib/sqlalchemy/sql/lambdas.py980
-rw-r--r--lib/sqlalchemy/sql/selectable.py32
-rw-r--r--lib/sqlalchemy/sql/traversals.py68
-rw-r--r--lib/sqlalchemy/sql/visitors.py34
6 files changed, 880 insertions, 324 deletions
diff --git a/lib/sqlalchemy/sql/coercions.py b/lib/sqlalchemy/sql/coercions.py
index 588c485ae..fa0f9c435 100644
--- a/lib/sqlalchemy/sql/coercions.py
+++ b/lib/sqlalchemy/sql/coercions.py
@@ -7,10 +7,13 @@
import numbers
import re
+import types
from . import operators
from . import roles
from . import visitors
+from .base import Options
+from .traversals import HasCacheKey
from .visitors import Visitable
from .. import exc
from .. import inspection
@@ -33,11 +36,36 @@ def _is_literal(element):
of a SQL expression construct.
"""
+
return not isinstance(
- element, (Visitable, schema.SchemaEventTarget)
+ element, (Visitable, schema.SchemaEventTarget),
) and not hasattr(element, "__clause_element__")
+def _deep_is_literal(element):
+ """Return whether or not the element is a "literal" in the context
+ of a SQL expression construct.
+
+ does a deeper more esoteric check than _is_literal. is used
+ for lambda elements that have to distinguish values that would
+ be bound vs. not without any context.
+
+ """
+
+ return (
+ not isinstance(
+ element,
+ (Visitable, schema.SchemaEventTarget, HasCacheKey, Options,),
+ )
+ and not hasattr(element, "__clause_element__")
+ and (
+ not isinstance(element, type)
+ or not issubclass(element, HasCacheKey)
+ )
+ and not isinstance(element, types.FunctionType)
+ )
+
+
def _document_text_coercion(paramname, meth_rst, param_rst):
return util.add_parameter_text(
paramname,
@@ -711,9 +739,16 @@ class StatementImpl(_NoTextCoercion, RoleImpl):
class CoerceTextStatementImpl(_CoerceLiterals, RoleImpl):
__slots__ = ()
- def _literal_coercion(self, element, **kw):
+ def _dont_literal_coercion(self, element, **kw):
if callable(element) and hasattr(element, "__code__"):
- return lambdas.StatementLambdaElement(element, self._role_class)
+ return lambdas.StatementLambdaElement(
+ element,
+ self._role_class,
+ additional_cache_criteria=kw.get(
+ "additional_cache_criteria", ()
+ ),
+ tracked=kw["tra"],
+ )
else:
return super(CoerceTextStatementImpl, self)._literal_coercion(
element, **kw
diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py
index ca73a4392..8a506446d 100644
--- a/lib/sqlalchemy/sql/elements.py
+++ b/lib/sqlalchemy/sql/elements.py
@@ -32,8 +32,8 @@ from .base import NO_ARG
from .base import PARSE_AUTOCOMMIT
from .base import SingletonConstant
from .coercions import _document_text_coercion
-from .traversals import _copy_internals
from .traversals import _get_children
+from .traversals import HasCopyInternals
from .traversals import MemoizedHasCacheKey
from .traversals import NO_CACHE
from .visitors import cloned_traverse
@@ -182,6 +182,7 @@ class ClauseElement(
roles.SQLRole,
SupportsWrappingAnnotations,
MemoizedHasCacheKey,
+ HasCopyInternals,
Traversible,
):
"""Base class for elements of a programmatically constructed SQL
@@ -372,35 +373,6 @@ class ClauseElement(
"""
return traversals.compare(self, other, **kw)
- def _copy_internals(self, omit_attrs=(), **kw):
- """Reassign internal elements to be clones of themselves.
-
- Called during a copy-and-traverse operation on newly
- shallow-copied elements to create a deep copy.
-
- The given clone function should be used, which may be applying
- additional transformations to the element (i.e. replacement
- traversal, cloned traversal, annotations).
-
- """
-
- try:
- traverse_internals = self._traverse_internals
- except AttributeError:
- # user-defined classes may not have a _traverse_internals
- return
-
- for attrname, obj, meth in _copy_internals.run_generated_dispatch(
- self, traverse_internals, "_generated_copy_internals_traversal"
- ):
- if attrname in omit_attrs:
- continue
-
- if obj is not None:
- result = meth(self, attrname, obj, **kw)
- if result is not None:
- setattr(self, attrname, result)
-
def get_children(self, omit_attrs=(), **kw):
r"""Return immediate child :class:`.visitors.Traversible`
elements of this :class:`.visitors.Traversible`.
@@ -535,8 +507,6 @@ class ClauseElement(
else:
elem_cache_key = None
- cache_hit = False
-
if elem_cache_key:
cache_key, extracted_params = elem_cache_key
key = (
@@ -549,6 +519,7 @@ class ClauseElement(
compiled_sql = compiled_cache.get(key)
if compiled_sql is None:
+ cache_hit = dialect.CACHE_MISS
compiled_sql = self._compiler(
dialect,
cache_key=elem_cache_key,
@@ -559,7 +530,7 @@ class ClauseElement(
)
compiled_cache[key] = compiled_sql
else:
- cache_hit = True
+ cache_hit = dialect.CACHE_HIT
else:
extracted_params = None
compiled_sql = self._compiler(
@@ -570,6 +541,11 @@ class ClauseElement(
schema_translate_map=schema_translate_map,
**kw
)
+ cache_hit = (
+ dialect.CACHING_DISABLED
+ if compiled_cache is None
+ else dialect.NO_CACHE_KEY
+ )
return compiled_sql, extracted_params, cache_hit
@@ -1343,10 +1319,7 @@ class BindParameter(roles.InElementRole, ColumnElement):
if required is NO_ARG:
required = value is NO_ARG and callable_ is None
if value is NO_ARG:
- self._value_required_for_cache = False
value = None
- else:
- self._value_required_for_cache = True
if quote is not None:
key = quoted_name(key, quote)
@@ -1412,6 +1385,7 @@ class BindParameter(roles.InElementRole, ColumnElement):
"""Return a copy of this :class:`.BindParameter` with the given value
set.
"""
+
cloned = self._clone(maintain_key=maintain_key)
cloned.value = value
cloned.callable = None
@@ -1465,7 +1439,8 @@ class BindParameter(roles.InElementRole, ColumnElement):
anon_map[idself] = id_ = str(anon_map.index)
anon_map.index += 1
- bindparams.append(self)
+ if bindparams is not None:
+ bindparams.append(self)
return (
id_,
diff --git a/lib/sqlalchemy/sql/lambdas.py b/lib/sqlalchemy/sql/lambdas.py
index 792411189..327003902 100644
--- a/lib/sqlalchemy/sql/lambdas.py
+++ b/lib/sqlalchemy/sql/lambdas.py
@@ -8,6 +8,7 @@
import itertools
import operator
import sys
+import types
import weakref
from . import coercions
@@ -17,29 +18,23 @@ from . import schema
from . import traversals
from . import type_api
from . import visitors
+from .base import _clone
from .operators import ColumnOperators
from .. import exc
from .. import inspection
from .. import util
from ..util import collections_abc
-_trackers = weakref.WeakKeyDictionary()
+_closure_per_cache_key = util.LRUCache(1000)
-_TRACKERS = 0
-_STALE_CHECK = 1
-_REAL_FN = 2
-_EXPR = 3
-_IS_SEQUENCE = 4
-_PROPAGATE_ATTRS = 5
-
-
-def lambda_stmt(lmb):
+def lambda_stmt(lmb, **opts):
"""Produce a SQL statement that is cached as a lambda.
- This SQL statement will only be constructed if element has not been
- compiled yet. The approach is used to save on Python function overhead
- when constructing statements that will be cached.
+ The Python code object within the lambda is scanned for both Python
+ literals that will become bound parameters as well as closure variables
+ that refer to Core or ORM constructs that may vary. The lambda itself
+ will be invoked only once per particular set of constructs detected.
E.g.::
@@ -60,7 +55,8 @@ def lambda_stmt(lmb):
"""
- return coercions.expect(roles.CoerceTextStatementRole, lmb)
+
+ return StatementLambdaElement(lmb, roles.CoerceTextStatementRole, **opts)
class LambdaElement(elements.ClauseElement):
@@ -87,64 +83,108 @@ class LambdaElement(elements.ClauseElement):
_is_lambda_element = True
- _resolved_bindparams = ()
-
_traverse_internals = [
("_resolved", visitors.InternalTraversal.dp_clauseelement)
]
+ _transforms = ()
+
+ parent_lambda = None
+
def __repr__(self):
return "%s(%r)" % (self.__class__.__name__, self.fn.__code__)
def __init__(self, fn, role, apply_propagate_attrs=None, **kw):
self.fn = fn
self.role = role
- self.parent_lambda = None
+ self.tracker_key = (fn.__code__,)
if apply_propagate_attrs is None and (
role is roles.CoerceTextStatementRole
):
apply_propagate_attrs = self
- if fn.__code__ not in _trackers:
- rec = self._initialize_var_trackers(
- role, apply_propagate_attrs, kw
+ rec = self._retrieve_tracker_rec(fn, apply_propagate_attrs, kw)
+
+ if apply_propagate_attrs is not None:
+ propagate_attrs = rec.propagate_attrs
+ if propagate_attrs:
+ apply_propagate_attrs._propagate_attrs = propagate_attrs
+
+ def _retrieve_tracker_rec(self, fn, apply_propagate_attrs, kw):
+ lambda_cache = kw.get("lambda_cache", _closure_per_cache_key)
+
+ tracker_key = self.tracker_key
+
+ fn = self.fn
+ closure = fn.__closure__
+
+ tracker = AnalyzedCode.get(
+ fn,
+ self,
+ kw,
+ track_bound_values=kw.get("track_bound_values", True),
+ enable_tracking=kw.get("enable_tracking", True),
+ track_on=kw.get("track_on", None),
+ )
+
+ self._resolved_bindparams = bindparams = []
+
+ anon_map = traversals.anon_map()
+ cache_key = tuple(
+ [
+ getter(closure, kw, anon_map, bindparams)
+ for getter in tracker.closure_trackers
+ ]
+ )
+ if self.parent_lambda is not None:
+ cache_key = self.parent_lambda.closure_cache_key + cache_key
+
+ self.closure_cache_key = cache_key
+
+ try:
+ rec = lambda_cache[tracker_key + cache_key]
+ except KeyError:
+ rec = None
+
+ if rec is None:
+ rec = AnalyzedFunction(
+ tracker, self, apply_propagate_attrs, kw, fn
)
+ rec.closure_bindparams = bindparams
+ lambda_cache[tracker_key + cache_key] = rec
else:
- rec = _trackers[self.fn.__code__]
- closure = fn.__closure__
+ bindparams[:] = [
+ orig_bind._with_value(new_bind.value, maintain_key=True)
+ for orig_bind, new_bind in zip(
+ rec.closure_bindparams, bindparams
+ )
+ ]
+
+ if self.parent_lambda is not None:
+ bindparams[:0] = self.parent_lambda._resolved_bindparams
- # check if the objects fixed inside the lambda that we've cached
- # have been changed. This can apply to things like mappers that
- # were recreated in test suites. if so, re-initialize.
- #
- # this is a small performance hit on every use for a not very
- # common situation, however it's very hard to debug if the
- # condition does occur.
- for idx, obj in rec[_STALE_CHECK]:
- if closure[idx].cell_contents is not obj:
- rec = self._initialize_var_trackers(
- role, apply_propagate_attrs, kw
- )
- break
self._rec = rec
- if apply_propagate_attrs is not None:
- propagate_attrs = rec[_PROPAGATE_ATTRS]
- if propagate_attrs:
- apply_propagate_attrs._propagate_attrs = propagate_attrs
+ lambda_element = self
+ while lambda_element is not None:
+ rec = lambda_element._rec
+ if rec.bindparam_trackers:
+ tracker_instrumented_fn = rec.tracker_instrumented_fn
+ for tracker in rec.bindparam_trackers:
+ tracker(
+ lambda_element.fn, tracker_instrumented_fn, bindparams
+ )
+ lambda_element = lambda_element.parent_lambda
- if rec[_TRACKERS]:
- self._resolved_bindparams = bindparams = []
- for tracker in rec[_TRACKERS]:
- tracker(self.fn, bindparams)
+ return rec
def __getattr__(self, key):
- return getattr(self._rec[_EXPR], key)
+ return getattr(self._rec.expected_expr, key)
@property
def _is_sequence(self):
- return self._rec[_IS_SEQUENCE]
+ return self._rec.is_sequence
@property
def _select_iterable(self):
@@ -169,8 +209,7 @@ class LambdaElement(elements.ClauseElement):
def _param_dict(self):
return {b.key: b.value for b in self._resolved_bindparams}
- @util.memoized_property
- def _resolved(self):
+ def _setup_binds_for_tracked_expr(self, expr):
bindparam_lookup = {b.key: b for b in self._resolved_bindparams}
def replace(thing):
@@ -179,17 +218,11 @@ class LambdaElement(elements.ClauseElement):
and thing.key in bindparam_lookup
):
bind = bindparam_lookup[thing.key]
- # TODO: consider
- # if we should clone the bindparam here, re-cache the new
- # version, etc. also we make an assumption about "expanding"
- # in this case.
if thing.expanding:
bind.expanding = True
return bind
- expr = self._rec[_EXPR]
-
- if self._rec[_IS_SEQUENCE]:
+ if self._rec.is_sequence:
expr = [
visitors.replacement_traverse(sub_expr, {}, replace)
for sub_expr in expr
@@ -199,9 +232,39 @@ class LambdaElement(elements.ClauseElement):
return expr
+ def _copy_internals(
+ self, clone=_clone, deferred_copy_internals=None, **kw
+ ):
+ # TODO: this needs A LOT of tests
+ self._resolved = clone(
+ self._resolved,
+ deferred_copy_internals=deferred_copy_internals,
+ **kw
+ )
+
+ @util.memoized_property
+ def _resolved(self):
+ expr = self._rec.expected_expr
+
+ if self._resolved_bindparams:
+ expr = self._setup_binds_for_tracked_expr(expr)
+
+ return expr
+
def _gen_cache_key(self, anon_map, bindparams):
- cache_key = (self.fn.__code__, self.__class__)
+ cache_key = (
+ self.fn.__code__,
+ self.__class__,
+ ) + self.closure_cache_key
+
+ parent = self.parent_lambda
+ while parent is not None:
+ cache_key = (
+ (parent.fn.__code__,) + parent.closure_cache_key + cache_key
+ )
+
+ parent = parent.parent_lambda
if self._resolved_bindparams:
bindparams.extend(self._resolved_bindparams)
@@ -211,101 +274,51 @@ class LambdaElement(elements.ClauseElement):
def _invoke_user_fn(self, fn, *arg):
return fn()
- def _initialize_var_trackers(self, role, apply_propagate_attrs, coerce_kw):
- fn = self.fn
- # track objects referenced inside of lambdas, create bindparams
- # ahead of time for literal values. If bindparams are produced,
- # then rewrite the function globals and closure as necessary so that
- # it refers to the bindparams, then invoke the function
- new_closure = {}
- new_globals = fn.__globals__.copy()
- tracker_collection = []
- check_closure_for_stale = []
+class DeferredLambdaElement(LambdaElement):
+ """A LambdaElement where the lambda accepts arguments and is
+ invoked within the compile phase with special context.
- for name in fn.__code__.co_names:
- if name not in new_globals:
- continue
+ This lambda doesn't normally produce its real SQL expression outside of the
+ compile phase. It is passed a fixed set of initial arguments
+ so that it can generate a sample expression.
- bound_value = _roll_down_to_literal(new_globals[name])
+ """
- if coercions._is_literal(bound_value):
- new_globals[name] = bind = PyWrapper(name, bound_value)
- tracker_collection.append(_globals_tracker(name, bind))
+ def __init__(self, fn, role, lambda_args=(), **kw):
+ self.lambda_args = lambda_args
+ self.coerce_kw = kw
+ super(DeferredLambdaElement, self).__init__(fn, role, **kw)
- if fn.__closure__:
- for closure_index, (fv, cell) in enumerate(
- zip(fn.__code__.co_freevars, fn.__closure__)
- ):
+ def _invoke_user_fn(self, fn, *arg):
+ return fn(*self.lambda_args)
- bound_value = _roll_down_to_literal(cell.cell_contents)
+ def _resolve_with_args(self, *lambda_args):
+ tracker_fn = self._rec.tracker_instrumented_fn
+ expr = tracker_fn(*lambda_args)
- if coercions._is_literal(bound_value):
- new_closure[fv] = bind = PyWrapper(fv, bound_value)
- tracker_collection.append(
- _closure_tracker(fv, bind, closure_index)
- )
- else:
- new_closure[fv] = cell.cell_contents
- # for normal cell contents, add them to a list that
- # we can compare later when we get new lambdas. if
- # any identities have changed, then we will recalculate
- # the whole lambda and run it again.
- check_closure_for_stale.append(
- (closure_index, cell.cell_contents)
- )
+ expr = coercions.expect(self.role, expr, **self.coerce_kw)
- if tracker_collection:
- new_fn = _rewrite_code_obj(
- fn,
- [new_closure[name] for name in fn.__code__.co_freevars],
- new_globals,
- )
- expr = self._invoke_user_fn(new_fn)
+ if self._resolved_bindparams:
+ expr = self._setup_binds_for_tracked_expr(expr)
- else:
- new_fn = fn
- expr = self._invoke_user_fn(new_fn)
- tracker_collection = []
+ # TODO: TEST TEST TEST, this is very out there
+ for deferred_copy_internals in self._transforms:
+ expr = deferred_copy_internals(expr)
- if self.parent_lambda is None:
- if isinstance(expr, collections_abc.Sequence):
- expected_expr = [
- coercions.expect(
- role,
- sub_expr,
- apply_propagate_attrs=apply_propagate_attrs,
- **coerce_kw
- )
- for sub_expr in expr
- ]
- is_sequence = True
- else:
- expected_expr = coercions.expect(
- role,
- expr,
- apply_propagate_attrs=apply_propagate_attrs,
- **coerce_kw
- )
- is_sequence = False
- else:
- expected_expr = expr
- is_sequence = False
+ return expr
- if apply_propagate_attrs is not None:
- propagate_attrs = apply_propagate_attrs._propagate_attrs
- else:
- propagate_attrs = util.immutabledict()
-
- rec = _trackers[self.fn.__code__] = (
- tracker_collection,
- check_closure_for_stale,
- new_fn,
- expected_expr,
- is_sequence,
- propagate_attrs,
+ def _copy_internals(
+ self, clone=_clone, deferred_copy_internals=None, **kw
+ ):
+ super(DeferredLambdaElement, self)._copy_internals(
+ clone=clone, deferred_copy_internals=deferred_copy_internals, **kw
)
- return rec
+
+ # TODO: A LOT A LOT of tests. for _resolve_with_args, we don't know
+ # our expression yet. so hold onto the replacement
+ if deferred_copy_internals:
+ self._transforms += (deferred_copy_internals,)
class StatementLambdaElement(roles.AllowsLambdaRole, LambdaElement):
@@ -334,13 +347,38 @@ class StatementLambdaElement(roles.AllowsLambdaRole, LambdaElement):
"""
+ def __init__(self, fn, parent_lambda, **kw):
+ self._default_kw = default_kw = {}
+ global_track_bound_values = kw.pop("global_track_bound_values", None)
+ if global_track_bound_values is not None:
+ default_kw["track_bound_values"] = global_track_bound_values
+ kw["track_bound_values"] = global_track_bound_values
+
+ if "lambda_cache" in kw:
+ default_kw["lambda_cache"] = kw["lambda_cache"]
+
+ super(StatementLambdaElement, self).__init__(fn, parent_lambda, **kw)
+
def __add__(self, other):
- return LinkedLambdaElement(other, parent_lambda=self)
+ return LinkedLambdaElement(
+ other, parent_lambda=self, **self._default_kw
+ )
+
+ def add_criteria(self, other, **kw):
+ if self._default_kw:
+ if kw:
+ default_kw = self._default_kw.copy()
+ default_kw.update(kw)
+ kw = default_kw
+ else:
+ kw = self._default_kw
+
+ return LinkedLambdaElement(other, parent_lambda=self, **kw)
def _execute_on_connection(
self, connection, multiparams, params, execution_options
):
- if self._rec[_EXPR].supports_execution:
+ if self._rec.expected_expr.supports_execution:
return connection._execute_clauseelement(
self, multiparams, params, execution_options
)
@@ -349,93 +387,579 @@ class StatementLambdaElement(roles.AllowsLambdaRole, LambdaElement):
@property
def _with_options(self):
- return self._rec[_EXPR]._with_options
+ return self._rec.expected_expr._with_options
@property
def _effective_plugin_target(self):
- return self._rec[_EXPR]._effective_plugin_target
+ return self._rec.expected_expr._effective_plugin_target
@property
def _is_future(self):
- return self._rec[_EXPR]._is_future
+ return self._rec.expected_expr._is_future
@property
def _execution_options(self):
- return self._rec[_EXPR]._execution_options
+ return self._rec.expected_expr._execution_options
+
+ def spoil(self):
+ """Return a new :class:`.StatementLambdaElement` that will run
+ all lambdas unconditionally each time.
+
+ """
+ return NullLambdaStatement(self.fn())
+
+
+class NullLambdaStatement(roles.AllowsLambdaRole, elements.ClauseElement):
+ """Provides the :class:`.StatementLambdaElement` API but does not
+ cache or analyze lambdas.
+
+ the lambdas are instead invoked immediately.
+
+ The intended use is to isolate issues that may arise when using
+ lambda statements.
+
+ """
+
+ __visit_name__ = "lambda_element"
+
+ _is_lambda_element = True
+
+ _traverse_internals = [
+ ("_resolved", visitors.InternalTraversal.dp_clauseelement)
+ ]
+
+ def __init__(self, statement):
+ self._resolved = statement
+ self._propagate_attrs = statement._propagate_attrs
+
+ def __getattr__(self, key):
+ return getattr(self._resolved, key)
+
+ def __add__(self, other):
+ statement = other(self._resolved)
+
+ return NullLambdaStatement(statement)
+
+ def add_criteria(self, other, **kw):
+ statement = other(self._resolved)
+
+ return NullLambdaStatement(statement)
+
+ def _execute_on_connection(
+ self, connection, multiparams, params, execution_options
+ ):
+ if self._resolved.supports_execution:
+ return connection._execute_clauseelement(
+ self, multiparams, params, execution_options
+ )
+ else:
+ raise exc.ObjectNotExecutableError(self)
class LinkedLambdaElement(StatementLambdaElement):
+ """Represent subsequent links of a :class:`.StatementLambdaElement`."""
+
+ role = None
+
def __init__(self, fn, parent_lambda, **kw):
+ self._default_kw = parent_lambda._default_kw
+
self.fn = fn
self.parent_lambda = parent_lambda
- role = None
- apply_propagate_attrs = self
+ self.tracker_key = parent_lambda.tracker_key + (fn.__code__,)
+ self._retrieve_tracker_rec(fn, self, kw)
+ self._propagate_attrs = parent_lambda._propagate_attrs
+
+ def _invoke_user_fn(self, fn, *arg):
+ return fn(self.parent_lambda._resolved)
+
+
+class AnalyzedCode(object):
+ __slots__ = (
+ "track_closure_variables",
+ "track_bound_values",
+ "bindparam_trackers",
+ "closure_trackers",
+ "build_py_wrappers",
+ )
+ _fns = weakref.WeakKeyDictionary()
+
+ @classmethod
+ def get(cls, fn, lambda_element, lambda_kw, **kw):
+ try:
+ # TODO: validate kw haven't changed?
+ return cls._fns[fn.__code__]
+ except KeyError:
+ pass
+ cls._fns[fn.__code__] = analyzed = AnalyzedCode(
+ fn, lambda_element, lambda_kw, **kw
+ )
+ return analyzed
+
+ def __init__(
+ self,
+ fn,
+ lambda_element,
+ lambda_kw,
+ track_bound_values=True,
+ enable_tracking=True,
+ track_on=None,
+ ):
+ closure = fn.__closure__
+
+ self.track_closure_variables = not track_on
+
+ self.track_bound_values = track_bound_values
+
+ # a list of callables generated from _bound_parameter_getter_*
+ # functions. Each of these uses a PyWrapper object to retrieve
+ # a parameter value
+ self.bindparam_trackers = []
+
+ # a list of callables generated from _cache_key_getter_* functions
+ # these callables work to generate a cache key for the lambda
+ # based on what's inside its closure variables.
+ self.closure_trackers = []
- if fn.__code__ not in _trackers:
- rec = self._initialize_var_trackers(
- role, apply_propagate_attrs, kw
+ self.build_py_wrappers = []
+
+ if enable_tracking:
+ if track_on:
+ self._init_track_on(track_on)
+
+ self._init_globals(fn)
+
+ if closure:
+ self._init_closure(fn)
+
+ self._setup_additional_closure_trackers(fn, lambda_element, lambda_kw)
+
+ def _init_track_on(self, track_on):
+ self.closure_trackers.extend(
+ self._cache_key_getter_track_on(idx, elem)
+ for idx, elem in enumerate(track_on)
+ )
+
+ def _init_globals(self, fn):
+ build_py_wrappers = self.build_py_wrappers
+ bindparam_trackers = self.bindparam_trackers
+ track_bound_values = self.track_bound_values
+
+ for name in fn.__code__.co_names:
+ if name not in fn.__globals__:
+ continue
+
+ _bound_value = self._roll_down_to_literal(fn.__globals__[name])
+
+ if coercions._deep_is_literal(_bound_value):
+ build_py_wrappers.append((name, None))
+ if track_bound_values:
+ bindparam_trackers.append(
+ self._bound_parameter_getter_func_globals(name)
+ )
+
+ def _init_closure(self, fn):
+ build_py_wrappers = self.build_py_wrappers
+ closure = fn.__closure__
+
+ track_bound_values = self.track_bound_values
+ track_closure_variables = self.track_closure_variables
+ bindparam_trackers = self.bindparam_trackers
+ closure_trackers = self.closure_trackers
+
+ for closure_index, (fv, cell) in enumerate(
+ zip(fn.__code__.co_freevars, closure)
+ ):
+ _bound_value = self._roll_down_to_literal(cell.cell_contents)
+
+ if coercions._deep_is_literal(_bound_value):
+ build_py_wrappers.append((fv, closure_index))
+ if track_bound_values:
+ bindparam_trackers.append(
+ self._bound_parameter_getter_func_closure(
+ fv, closure_index
+ )
+ )
+ else:
+ # for normal cell contents, add them to a list that
+ # we can compare later when we get new lambdas. if
+ # any identities have changed, then we will
+ # recalculate the whole lambda and run it again.
+
+ if track_closure_variables:
+ closure_trackers.append(
+ self._cache_key_getter_closure_variable(
+ closure_index, cell.cell_contents
+ )
+ )
+
+ def _setup_additional_closure_trackers(
+ self, fn, lambda_element, lambda_kw
+ ):
+ # an additional step is to actually run the function, then
+ # go through the PyWrapper objects that were set up to catch a bound
+ # parameter. then if they *didn't* make a param, oh they're another
+ # object in the closure we have to track for our cache key. so
+ # create trackers to catch those.
+
+ analyzed_function = AnalyzedFunction(
+ self, lambda_element, None, lambda_kw, fn,
+ )
+
+ closure_trackers = self.closure_trackers
+
+ for pywrapper in analyzed_function.closure_pywrappers:
+ if not pywrapper._sa__has_param:
+ closure_trackers.append(
+ self._cache_key_getter_tracked_literal(pywrapper)
+ )
+
+ @classmethod
+ def _roll_down_to_literal(cls, element):
+ is_clause_element = hasattr(element, "__clause_element__")
+
+ if is_clause_element:
+ while not isinstance(
+ element, (elements.ClauseElement, schema.SchemaItem)
+ ):
+ try:
+ element = element.__clause_element__()
+ except AttributeError:
+ break
+
+ if not is_clause_element:
+ insp = inspection.inspect(element, raiseerr=False)
+ if insp is not None:
+ try:
+ return insp.__clause_element__()
+ except AttributeError:
+ return insp
+
+ # TODO: should we coerce consts None/True/False here?
+ return element
+ else:
+ return element
+
+ def _bound_parameter_getter_func_globals(self, name):
+ """Return a getter that will extend a list of bound parameters
+ with new entries from the ``__globals__`` collection of a particular
+ lambda.
+
+ """
+
+ def extract_parameter_value(
+ current_fn, tracker_instrumented_fn, result
+ ):
+ wrapper = tracker_instrumented_fn.__globals__[name]
+ object.__getattribute__(wrapper, "_extract_bound_parameters")(
+ current_fn.__globals__[name], result
+ )
+
+ return extract_parameter_value
+
+ def _bound_parameter_getter_func_closure(self, name, closure_index):
+ """Return a getter that will extend a list of bound parameters
+ with new entries from the ``__closure__`` collection of a particular
+ lambda.
+
+ """
+
+ def extract_parameter_value(
+ current_fn, tracker_instrumented_fn, result
+ ):
+ wrapper = tracker_instrumented_fn.__closure__[
+ closure_index
+ ].cell_contents
+ object.__getattribute__(wrapper, "_extract_bound_parameters")(
+ current_fn.__closure__[closure_index].cell_contents, result
+ )
+
+ return extract_parameter_value
+
+ def _cache_key_getter_track_on(self, idx, elem):
+ """Return a getter that will extend a cache key with new entries
+ from the "track_on" parameter passed to a :class:`.LambdaElement`.
+
+ """
+ if isinstance(elem, traversals.HasCacheKey):
+
+ def get(closure, kw, anon_map, bindparams):
+ return kw["track_on"][idx]._gen_cache_key(anon_map, bindparams)
+
+ else:
+
+ def get(closure, kw, anon_map, bindparams):
+ return kw["track_on"][idx]
+
+ return get
+
+ def _cache_key_getter_closure_variable(self, idx, cell_contents):
+ """Return a getter that will extend a cache key with new entries
+ from the ``__closure__`` collection of a particular lambda.
+
+ """
+
+ if isinstance(cell_contents, traversals.HasCacheKey):
+
+ def get(closure, kw, anon_map, bindparams):
+ return closure[idx].cell_contents._gen_cache_key(
+ anon_map, bindparams
+ )
+
+ elif isinstance(cell_contents, types.FunctionType):
+
+ def get(closure, kw, anon_map, bindparams):
+ return closure[idx].cell_contents.__code__
+
+ elif cell_contents.__hash__ is None:
+ # this covers dict, etc.
+ def get(closure, kw, anon_map, bindparams):
+ return ()
+
+ else:
+
+ def get(closure, kw, anon_map, bindparams):
+ return closure[idx].cell_contents
+
+ return get
+
+ def _cache_key_getter_tracked_literal(self, pytracker):
+ """Return a getter that will extend a cache key with new entries
+ from the ``__closure__`` collection of a particular lambda.
+
+ this getter differs from _cache_key_getter_closure_variable
+ in that these are detected after the function is run, and PyWrapper
+ objects have recorded that a particular literal value is in fact
+ not being interpreted as a bound parameter.
+
+ """
+
+ elem = pytracker._sa__to_evaluate
+ closure_index = pytracker._sa__closure_index
+
+ if isinstance(elem, set):
+ raise exc.ArgumentError(
+ "Can't create a cache key for lambda closure variable "
+ '"%s" because it\'s a set. try using a list'
+ % pytracker._sa__name
)
+
+ elif isinstance(elem, list):
+
+ def get(closure, kw, anon_map, bindparams):
+ return tuple(
+ elem._gen_cache_key(anon_map, bindparams)
+ for elem in closure[closure_index].cell_contents
+ )
+
+ elif elem.__hash__ is None:
+ # this covers dict, etc.
+ def get(closure, kw, anon_map, bindparams):
+ return ()
+
else:
- rec = _trackers[self.fn.__code__]
+ def get(closure, kw, anon_map, bindparams):
+ return closure[closure_index].cell_contents
+
+ return get
+
+
+class AnalyzedFunction(object):
+ __slots__ = (
+ "analyzed_code",
+ "fn",
+ "closure_pywrappers",
+ "tracker_instrumented_fn",
+ "expr",
+ "bindparam_trackers",
+ "expected_expr",
+ "is_sequence",
+ "propagate_attrs",
+ "closure_bindparams",
+ )
+
+ def __init__(
+ self, analyzed_code, lambda_element, apply_propagate_attrs, kw, fn,
+ ):
+ self.analyzed_code = analyzed_code
+ self.fn = fn
+
+ self.bindparam_trackers = analyzed_code.bindparam_trackers
+
+ self._instrument_and_run_function(lambda_element)
+
+ self._coerce_expression(lambda_element, apply_propagate_attrs, kw)
+
+ def _instrument_and_run_function(self, lambda_element):
+ analyzed_code = self.analyzed_code
+
+ fn = self.fn
+ self.closure_pywrappers = closure_pywrappers = []
+
+ build_py_wrappers = analyzed_code.build_py_wrappers
+
+ if not build_py_wrappers:
+ self.tracker_instrumented_fn = tracker_instrumented_fn = fn
+ self.expr = lambda_element._invoke_user_fn(tracker_instrumented_fn)
+ else:
+ track_closure_variables = analyzed_code.track_closure_variables
closure = fn.__closure__
- # check if objects referred to by the lambda have changed and
- # re-scan the lambda if so. see comments for this same section in
- # LambdaElement.
- for idx, obj in rec[_STALE_CHECK]:
- if closure[idx].cell_contents is not obj:
- rec = self._initialize_var_trackers(
- role, apply_propagate_attrs, kw
+ # will form the __closure__ of the function when we rebuild it
+ if closure:
+ new_closure = {
+ fv: cell.cell_contents
+ for fv, cell in zip(fn.__code__.co_freevars, closure)
+ }
+ else:
+ new_closure = {}
+
+ # will form the __globals__ of the function when we rebuild it
+ new_globals = fn.__globals__.copy()
+
+ for name, closure_index in build_py_wrappers:
+ if closure_index is not None:
+ value = closure[closure_index].cell_contents
+ new_closure[name] = bind = PyWrapper(
+ name, value, closure_index=closure_index
)
- break
+ if track_closure_variables:
+ closure_pywrappers.append(bind)
+ else:
+ value = fn.__globals__[name]
+ new_globals[name] = bind = PyWrapper(name, value)
+
+ # rewrite the original fn. things that look like they will
+ # become bound parameters are wrapped in a PyWrapper.
+ self.tracker_instrumented_fn = (
+ tracker_instrumented_fn
+ ) = self._rewrite_code_obj(
+ fn,
+ [new_closure[name] for name in fn.__code__.co_freevars],
+ new_globals,
+ )
- self._rec = rec
+ # now invoke the function. This will give us a new SQL
+ # expression, but all the places that there would be a bound
+ # parameter, the PyWrapper in its place will give us a bind
+ # with a predictable name we can match up later.
- self._propagate_attrs = parent_lambda._propagate_attrs
+ # additionally, each PyWrapper will log that it did in fact
+ # create a parameter, otherwise, it's some kind of Python
+ # object in the closure and we want to track that, to make
+ # sure it doesn't change to somehting else, or if it does,
+ # that we create a different tracked function with that
+ # variable.
+ self.expr = lambda_element._invoke_user_fn(tracker_instrumented_fn)
- self._resolved_bindparams = bindparams = []
- rec = self._rec
- while True:
- if rec[_TRACKERS]:
- for tracker in rec[_TRACKERS]:
- tracker(self.fn, bindparams)
- if self.parent_lambda is not None:
- self = self.parent_lambda
- rec = self._rec
+ def _coerce_expression(self, lambda_element, apply_propagate_attrs, kw):
+ """Run the tracker-generated expression through coercion rules.
+
+ After the user-defined lambda has been invoked to produce a statement
+ for re-use, run it through coercion rules to both check that it's the
+ correct type of object and also to coerce it to its useful form.
+
+ """
+
+ parent_lambda = lambda_element.parent_lambda
+ expr = self.expr
+
+ if parent_lambda is None:
+ if isinstance(expr, collections_abc.Sequence):
+ self.expected_expr = [
+ coercions.expect(
+ lambda_element.role,
+ sub_expr,
+ apply_propagate_attrs=apply_propagate_attrs,
+ **kw
+ )
+ for sub_expr in expr
+ ]
+ self.is_sequence = True
else:
- break
+ self.expected_expr = coercions.expect(
+ lambda_element.role,
+ expr,
+ apply_propagate_attrs=apply_propagate_attrs,
+ **kw
+ )
+ self.is_sequence = False
+ else:
+ self.expected_expr = expr
+ self.is_sequence = False
- def _invoke_user_fn(self, fn, *arg):
- return fn(self.parent_lambda._rec[_EXPR])
+ if apply_propagate_attrs is not None:
+ self.propagate_attrs = apply_propagate_attrs._propagate_attrs
+ else:
+ self.propagate_attrs = util.EMPTY_DICT
- def _gen_cache_key(self, anon_map, bindparams):
- if self._resolved_bindparams:
- bindparams.extend(self._resolved_bindparams)
+ def _rewrite_code_obj(self, f, cell_values, globals_):
+ """Return a copy of f, with a new closure and new globals
- cache_key = (self.fn.__code__, self.__class__)
+ yes it works in pypy :P
- parent = self.parent_lambda
- while parent is not None:
- cache_key = (parent.fn.__code__,) + cache_key
- parent = parent.parent_lambda
+ """
- return cache_key
+ argrange = range(len(cell_values))
+
+ code = "def make_cells():\n"
+ if cell_values:
+ code += " (%s) = (%s)\n" % (
+ ", ".join("i%d" % i for i in argrange),
+ ", ".join("o%d" % i for i in argrange),
+ )
+ code += " def closure():\n"
+ code += " return %s\n" % ", ".join("i%d" % i for i in argrange)
+ code += " return closure.__closure__"
+ vars_ = {"o%d" % i: cell_values[i] for i in argrange}
+ exec(code, vars_, vars_)
+ closure = vars_["make_cells"]()
+
+ func = type(f)(
+ f.__code__, globals_, f.__name__, f.__defaults__, closure
+ )
+ if sys.version_info >= (3,):
+ func.__annotations__ = f.__annotations__
+ func.__kwdefaults__ = f.__kwdefaults__
+ func.__doc__ = f.__doc__
+ func.__module__ = f.__module__
+
+ return func
class PyWrapper(ColumnOperators):
- def __init__(self, name, to_evaluate, getter=None):
+ """A wrapper object that is injected into the ``__globals__`` and
+ ``__closure__`` of a Python function.
+
+ When the function is instrumented with :class:`.PyWrapper` objects, it is
+ then invoked just once in order to set up the wrappers. We look through
+ all the :class:`.PyWrapper` objects we made to find the ones that generated
+ a :class:`.BindParameter` object, e.g. the expression system interpreted
+ something as a literal. Those positions in the globals/closure are then
+ ones that we will look at, each time a new lambda comes in that refers to
+ the same ``__code__`` object. In this way, we keep a single version of
+ the SQL expression that this lambda produced, without calling upon the
+ Python function that created it more than once, unless its other closure
+ variables have changed. The expression is then transformed to have the
+ new bound values embedded into it.
+
+ """
+
+ def __init__(self, name, to_evaluate, closure_index=None, getter=None):
self._name = name
self._to_evaluate = to_evaluate
self._param = None
+ self._has_param = False
self._bind_paths = {}
self._getter = getter
+ self._closure_index = closure_index
def __call__(self, *arg, **kw):
elem = object.__getattribute__(self, "_to_evaluate")
value = elem(*arg, **kw)
- if coercions._is_literal(value) and not isinstance(
+ if coercions._deep_is_literal(value) and not isinstance(
# TODO: coverage where an ORM option or similar is here
value,
traversals.HasCacheKey,
@@ -481,8 +1005,8 @@ class PyWrapper(ColumnOperators):
if param is None:
name = object.__getattribute__(self, "_name")
self._param = param = elements.BindParameter(name, unique=True)
+ self._has_param = True
param.type = type_api._resolve_value_to_type(to_evaluate)
-
return param._with_value(to_evaluate, maintain_key=True)
def __getattribute__(self, key):
@@ -497,7 +1021,15 @@ class PyWrapper(ColumnOperators):
else:
return self._sa__add_getter(key, operator.attrgetter)
+ def __iter__(self):
+ elem = object.__getattribute__(self, "_to_evaluate")
+ return iter(elem)
+
def __getitem__(self, key):
+ elem = object.__getattribute__(self, "_to_evaluate")
+ if not hasattr(elem, "__getitem__"):
+ raise AttributeError("__getitem__")
+
if isinstance(key, PyWrapper):
# TODO: coverage
raise exc.InvalidRequestError(
@@ -518,90 +1050,14 @@ class PyWrapper(ColumnOperators):
elem = object.__getattribute__(self, "_to_evaluate")
value = getter(elem)
- if coercions._is_literal(value):
- wrapper = PyWrapper(key, value, getter)
+ if coercions._deep_is_literal(value):
+ wrapper = PyWrapper(key, value, getter=getter)
bind_paths[bind_path_key] = wrapper
return wrapper
else:
return value
-def _roll_down_to_literal(element):
- is_clause_element = hasattr(element, "__clause_element__")
-
- if is_clause_element:
- while not isinstance(
- element, (elements.ClauseElement, schema.SchemaItem)
- ):
- try:
- element = element.__clause_element__()
- except AttributeError:
- break
-
- if not is_clause_element:
- insp = inspection.inspect(element, raiseerr=False)
- if insp is not None:
- try:
- return insp.__clause_element__()
- except AttributeError:
- return insp
-
- # TODO: should we coerce consts None/True/False here?
- return element
- else:
- return element
-
-
-def _globals_tracker(name, wrapper):
- def extract_parameter_value(current_fn, result):
- object.__getattribute__(wrapper, "_extract_bound_parameters")(
- current_fn.__globals__[name], result
- )
-
- return extract_parameter_value
-
-
-def _closure_tracker(name, wrapper, closure_index):
- def extract_parameter_value(current_fn, result):
- object.__getattribute__(wrapper, "_extract_bound_parameters")(
- current_fn.__closure__[closure_index].cell_contents, result
- )
-
- return extract_parameter_value
-
-
-def _rewrite_code_obj(f, cell_values, globals_):
- """Return a copy of f, with a new closure and new globals
-
- yes it works in pypy :P
-
- """
-
- argrange = range(len(cell_values))
-
- code = "def make_cells():\n"
- if cell_values:
- code += " (%s) = (%s)\n" % (
- ", ".join("i%d" % i for i in argrange),
- ", ".join("o%d" % i for i in argrange),
- )
- code += " def closure():\n"
- code += " return %s\n" % ", ".join("i%d" % i for i in argrange)
- code += " return closure.__closure__"
- vars_ = {"o%d" % i: cell_values[i] for i in argrange}
- exec(code, vars_, vars_)
- closure = vars_["make_cells"]()
-
- func = type(f)(f.__code__, globals_, f.__name__, f.__defaults__, closure)
- if sys.version_info >= (3,):
- func.__annotations__ = f.__annotations__
- func.__kwdefaults__ = f.__kwdefaults__
- func.__doc__ = f.__doc__
- func.__module__ = f.__module__
-
- return func
-
-
@inspection._inspects(LambdaElement)
def insp(lmb):
return inspection.inspect(lmb._resolved)
diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py
index 1155c273b..d67b61743 100644
--- a/lib/sqlalchemy/sql/selectable.py
+++ b/lib/sqlalchemy/sql/selectable.py
@@ -2426,6 +2426,7 @@ class SelectBase(
"""
_is_select_statement = True
+ is_select = True
def _generate_fromclause_column_proxies(self, fromclause):
# type: (FromClause) -> None
@@ -3867,19 +3868,19 @@ class Select(
[
("_raw_columns", InternalTraversal.dp_clauseelement_list),
("_from_obj", InternalTraversal.dp_clauseelement_list),
- ("_where_criteria", InternalTraversal.dp_clauseelement_list),
- ("_having_criteria", InternalTraversal.dp_clauseelement_list),
- ("_order_by_clauses", InternalTraversal.dp_clauseelement_list,),
- ("_group_by_clauses", InternalTraversal.dp_clauseelement_list,),
+ ("_where_criteria", InternalTraversal.dp_clauseelement_tuple),
+ ("_having_criteria", InternalTraversal.dp_clauseelement_tuple),
+ ("_order_by_clauses", InternalTraversal.dp_clauseelement_tuple,),
+ ("_group_by_clauses", InternalTraversal.dp_clauseelement_tuple,),
("_setup_joins", InternalTraversal.dp_setup_join_tuple,),
("_legacy_setup_joins", InternalTraversal.dp_setup_join_tuple,),
- ("_correlate", InternalTraversal.dp_clauseelement_list),
- ("_correlate_except", InternalTraversal.dp_clauseelement_list,),
+ ("_correlate", InternalTraversal.dp_clauseelement_tuple),
+ ("_correlate_except", InternalTraversal.dp_clauseelement_tuple,),
("_limit_clause", InternalTraversal.dp_clauseelement),
("_offset_clause", InternalTraversal.dp_clauseelement),
("_for_update_arg", InternalTraversal.dp_clauseelement),
("_distinct", InternalTraversal.dp_boolean),
- ("_distinct_on", InternalTraversal.dp_clauseelement_list),
+ ("_distinct_on", InternalTraversal.dp_clauseelement_tuple),
("_label_style", InternalTraversal.dp_plain_obj),
("_is_future", InternalTraversal.dp_boolean),
]
@@ -4345,7 +4346,7 @@ class Select(
@_generative
def join(self, target, onclause=None, isouter=False, full=False):
- r"""Create a SQL JOIN against this :class:`_expresson.Select`
+ r"""Create a SQL JOIN against this :class:`_expression.Select`
object's criterion
and apply generatively, returning the newly resulting
:class:`_expression.Select`.
@@ -4474,7 +4475,7 @@ class Select(
# they've become. This allows us to ensure the same cloned from
# is used when other items such as columns are "cloned"
- all_the_froms = list(
+ all_the_froms = set(
itertools.chain(
_from_objects(*self._raw_columns),
_from_objects(*self._where_criteria),
@@ -4490,10 +4491,15 @@ class Select(
new_froms = {f: clone(f, **kw) for f in all_the_froms}
# 2. copy FROM collections, adding in joins that we've created.
- self._from_obj = tuple(clone(f, **kw) for f in self._from_obj) + tuple(
- f for f in new_froms.values() if isinstance(f, Join)
+ existing_from_obj = [clone(f, **kw) for f in self._from_obj]
+ add_froms = (
+ set(f for f in new_froms.values() if isinstance(f, Join))
+ .difference(all_the_froms)
+ .difference(existing_from_obj)
)
+ self._from_obj = tuple(existing_from_obj) + tuple(add_froms)
+
# 3. clone everything else, making sure we use columns
# corresponding to the froms we just made.
def replace(obj, **kw):
@@ -4687,6 +4693,7 @@ class Select(
"""
+ assert isinstance(self._where_criteria, tuple)
self._where_criteria += (
coercions.expect(roles.WhereHavingRole, whereclause),
)
@@ -5371,6 +5378,9 @@ class TextualSelect(SelectBase):
_is_textual = True
+ is_text = True
+ is_select = True
+
def __init__(self, text, columns, positional=False):
self.element = text
# convert for ORM attributes->columns, etc
diff --git a/lib/sqlalchemy/sql/traversals.py b/lib/sqlalchemy/sql/traversals.py
index f41480a94..cb38df6af 100644
--- a/lib/sqlalchemy/sql/traversals.py
+++ b/lib/sqlalchemy/sql/traversals.py
@@ -190,7 +190,10 @@ class HasCacheKey(object):
# statements, not so much, but they usually won't have
# annotations.
result += self._annotations_cache_key
- elif meth is InternalTraversal.dp_clauseelement_list:
+ elif (
+ meth is InternalTraversal.dp_clauseelement_list
+ or meth is InternalTraversal.dp_clauseelement_tuple
+ ):
result += (
attrname,
tuple(
@@ -390,6 +393,7 @@ class _CacheKey(ExtendedInternalTraversal):
visit_has_cache_key = visit_clauseelement = CALL_GEN_CACHE_KEY
visit_clauseelement_list = InternalTraversal.dp_clauseelement_list
visit_annotations_key = InternalTraversal.dp_annotations_key
+ visit_clauseelement_tuple = InternalTraversal.dp_clauseelement_tuple
visit_string = (
visit_boolean
@@ -451,6 +455,8 @@ class _CacheKey(ExtendedInternalTraversal):
tuple(elem._gen_cache_key(anon_map, bindparams) for elem in obj),
)
+ visit_executable_options = visit_has_cache_key_list
+
def visit_inspectable_list(
self, attrname, obj, parent, anon_map, bindparams
):
@@ -682,6 +688,41 @@ class _CacheKey(ExtendedInternalTraversal):
_cache_key_traversal_visitor = _CacheKey()
+class HasCopyInternals(object):
+ def _clone(self, **kw):
+ raise NotImplementedError()
+
+ def _copy_internals(self, omit_attrs=(), **kw):
+ """Reassign internal elements to be clones of themselves.
+
+ Called during a copy-and-traverse operation on newly
+ shallow-copied elements to create a deep copy.
+
+ The given clone function should be used, which may be applying
+ additional transformations to the element (i.e. replacement
+ traversal, cloned traversal, annotations).
+
+ """
+
+ try:
+ traverse_internals = self._traverse_internals
+ except AttributeError:
+ # user-defined classes may not have a _traverse_internals
+ return
+
+ for attrname, obj, meth in _copy_internals.run_generated_dispatch(
+ self, traverse_internals, "_generated_copy_internals_traversal"
+ ):
+ if attrname in omit_attrs:
+ continue
+
+ if obj is not None:
+
+ result = meth(attrname, self, obj, **kw)
+ if result is not None:
+ setattr(self, attrname, result)
+
+
class _CopyInternals(InternalTraversal):
"""Generate a _copy_internals internal traversal dispatch for classes
with a _traverse_internals collection."""
@@ -696,6 +737,16 @@ class _CopyInternals(InternalTraversal):
):
return [clone(clause, **kw) for clause in element]
+ def visit_clauseelement_tuple(
+ self, attrname, parent, element, clone=_clone, **kw
+ ):
+ return tuple([clone(clause, **kw) for clause in element])
+
+ def visit_executable_options(
+ self, attrname, parent, element, clone=_clone, **kw
+ ):
+ return tuple([clone(clause, **kw) for clause in element])
+
def visit_clauseelement_unordered_set(
self, attrname, parent, element, clone=_clone, **kw
):
@@ -817,6 +868,9 @@ class _GetChildren(InternalTraversal):
def visit_clauseelement_list(self, element, **kw):
return element
+ def visit_clauseelement_tuple(self, element, **kw):
+ return element
+
def visit_clauseelement_tuples(self, element, **kw):
return itertools.chain.from_iterable(element)
@@ -840,8 +894,8 @@ class _GetChildren(InternalTraversal):
if not isinstance(target, str):
yield _flatten_clauseelement(target)
- # if onclause is not None and not isinstance(onclause, str):
- # yield _flatten_clauseelement(onclause)
+ if onclause is not None and not isinstance(onclause, str):
+ yield _flatten_clauseelement(onclause)
def visit_dml_ordered_values(self, element, **kw):
for k, v in element:
@@ -1015,6 +1069,8 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots):
):
return COMPARE_FAILED
+ visit_executable_options = visit_has_cache_key_list
+
def visit_clauseelement(
self, attrname, left_parent, left, right_parent, right, **kw
):
@@ -1057,6 +1113,12 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots):
for l, r in util.zip_longest(left, right, fillvalue=None):
self.stack.append((l, r))
+ def visit_clauseelement_tuple(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
+ for l, r in util.zip_longest(left, right, fillvalue=None):
+ self.stack.append((l, r))
+
def _compare_unordered_sequences(self, seq1, seq2, **kw):
if seq1 is None:
return seq2 is None
diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py
index 56d3c93b3..5cb3cba70 100644
--- a/lib/sqlalchemy/sql/visitors.py
+++ b/lib/sqlalchemy/sql/visitors.py
@@ -257,7 +257,7 @@ class InternalTraversal(util.with_metaclass(_InternalTraversalType, object)):
"""
- dp_clauseelement_tuples = symbol("CT")
+ dp_clauseelement_tuples = symbol("CTS")
"""Visit a list of tuples which contain :class:`_expression.ClauseElement`
objects.
@@ -268,6 +268,13 @@ class InternalTraversal(util.with_metaclass(_InternalTraversalType, object)):
"""
+ dp_clauseelement_tuple = symbol("CT")
+ """Visit a tuple of :class:`_expression.ClauseElement` objects.
+
+ """
+
+ dp_executable_options = symbol("EO")
+
dp_fromclause_ordered_set = symbol("CO")
"""Visit an ordered set of :class:`_expression.FromClause` objects. """
@@ -712,6 +719,9 @@ def cloned_traverse(obj, opts, visitors):
cloned = {}
stop_on = set(opts.get("stop_on", []))
+ def deferred_copy_internals(obj):
+ return cloned_traverse(obj, opts, visitors)
+
def clone(elem, **kw):
if elem in stop_on:
return elem
@@ -732,7 +742,7 @@ def cloned_traverse(obj, opts, visitors):
return cloned[id(elem)]
if obj is not None:
- obj = clone(obj)
+ obj = clone(obj, deferred_copy_internals=deferred_copy_internals)
clone = None # remove gc cycles
return obj
@@ -764,6 +774,9 @@ def replacement_traverse(obj, opts, replace):
cloned = {}
stop_on = {id(x) for x in opts.get("stop_on", [])}
+ def deferred_copy_internals(obj):
+ return replacement_traverse(obj, opts, replace)
+
def clone(elem, **kw):
if (
id(elem) in stop_on
@@ -776,19 +789,24 @@ def replacement_traverse(obj, opts, replace):
stop_on.add(id(newelem))
return newelem
else:
-
- if elem not in cloned:
+ # base "already seen" on id(), not hash, so that we don't
+ # replace an Annotated element with its non-annotated one, and
+ # vice versa
+ id_elem = id(elem)
+ if id_elem not in cloned:
if "replace" in kw:
newelem = kw["replace"](elem)
if newelem is not None:
- cloned[elem] = newelem
+ cloned[id_elem] = newelem
return newelem
- cloned[elem] = newelem = elem._clone()
+ cloned[id_elem] = newelem = elem._clone()
newelem._copy_internals(clone=clone, **kw)
- return cloned[elem]
+ return cloned[id_elem]
if obj is not None:
- obj = clone(obj, **opts)
+ obj = clone(
+ obj, deferred_copy_internals=deferred_copy_internals, **opts
+ )
clone = None # remove gc cycles
return obj