summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2019-08-29 14:45:23 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2019-11-04 13:22:43 -0500
commit29330ec1596f12462c501a65404ff52005b16b6c (patch)
treebe20b85ae3939cdbc4f790fadd4f4372421891d4 /lib/sqlalchemy/sql
parentdb47859dca999b9d1679b513fe855e408d7d07c4 (diff)
downloadsqlalchemy-29330ec1596f12462c501a65404ff52005b16b6c.tar.gz
Add anonymizing context to cache keys, comparison; convert traversal
Created new visitor system called "internal traversal" that applies a data driven approach to the concept of a class that defines its own traversal steps, in contrast to the existing style of traversal now known as "external traversal" where the visitor class defines the traversal, i.e. the SQLCompiler. The internal traversal system now implements get_children(), _copy_internals(), compare() and _cache_key() for most Core elements. Core elements with special needs like Select still implement some of these methods directly however most of these methods are no longer explicitly implemented. The data-driven system is also applied to ORM elements that take part in SQL expressions so that these objects, like mappers, aliasedclass, query options, etc. can all participate in the cache key process. Still not considered is that this approach to defining traversibility will be used to create some kind of generic introspection system that works across Core / ORM. It's also not clear if real statement caching using the _cache_key() method is feasible, if it is shown that running _cache_key() is nearly as expensive as compiling in any case. Because it is data driven, it is more straightforward to optimize using inlined code, as is the case now, as well as potentially using C code to speed it up. In addition, the caching sytem now accommodates for anonymous name labels, which is essential so that constructs which have anonymous labels can be cacheable, that is, their position within a statement in relation to other anonymous names causes them to generate an integer counter relative to that construct which will be the same every time. Gathering of bound parameters from any cache key generation is also now required as there is no use case for a cache key that does not extract bound parameter values. Applies-to: #4639 Change-Id: I0660584def8627cad566719ee98d3be045db4b8d
Diffstat (limited to 'lib/sqlalchemy/sql')
-rw-r--r--lib/sqlalchemy/sql/annotation.py69
-rw-r--r--lib/sqlalchemy/sql/base.py36
-rw-r--r--lib/sqlalchemy/sql/clause_compare.py334
-rw-r--r--lib/sqlalchemy/sql/compiler.py29
-rw-r--r--lib/sqlalchemy/sql/default_comparator.py3
-rw-r--r--lib/sqlalchemy/sql/elements.py515
-rw-r--r--lib/sqlalchemy/sql/expression.py2
-rw-r--r--lib/sqlalchemy/sql/functions.py70
-rw-r--r--lib/sqlalchemy/sql/schema.py18
-rw-r--r--lib/sqlalchemy/sql/selectable.py396
-rw-r--r--lib/sqlalchemy/sql/traversals.py768
-rw-r--r--lib/sqlalchemy/sql/type_api.py17
-rw-r--r--lib/sqlalchemy/sql/util.py2
-rw-r--r--lib/sqlalchemy/sql/visitors.py447
14 files changed, 1646 insertions, 1060 deletions
diff --git a/lib/sqlalchemy/sql/annotation.py b/lib/sqlalchemy/sql/annotation.py
index a0264845e..0d995ec8a 100644
--- a/lib/sqlalchemy/sql/annotation.py
+++ b/lib/sqlalchemy/sql/annotation.py
@@ -12,12 +12,32 @@ associations.
"""
from . import operators
+from .base import HasCacheKey
+from .visitors import InternalTraversal
from .. import util
-class SupportsCloneAnnotations(object):
+class SupportsAnnotations(object):
+ @util.memoized_property
+ def _annotation_traversals(self):
+ return [
+ (
+ key,
+ InternalTraversal.dp_has_cache_key
+ if isinstance(value, HasCacheKey)
+ else InternalTraversal.dp_plain_obj,
+ )
+ for key, value in self._annotations.items()
+ ]
+
+
+class SupportsCloneAnnotations(SupportsAnnotations):
_annotations = util.immutabledict()
+ _traverse_internals = [
+ ("_annotations", InternalTraversal.dp_annotations_state)
+ ]
+
def _annotate(self, values):
"""return a copy of this ClauseElement with annotations
updated by the given dictionary.
@@ -25,6 +45,7 @@ class SupportsCloneAnnotations(object):
"""
new = self._clone()
new._annotations = new._annotations.union(values)
+ new.__dict__.pop("_annotation_traversals", None)
return new
def _with_annotations(self, values):
@@ -34,6 +55,7 @@ class SupportsCloneAnnotations(object):
"""
new = self._clone()
new._annotations = util.immutabledict(values)
+ new.__dict__.pop("_annotation_traversals", None)
return new
def _deannotate(self, values=None, clone=False):
@@ -49,12 +71,13 @@ class SupportsCloneAnnotations(object):
# the expression for a deep deannotation
new = self._clone()
new._annotations = {}
+ new.__dict__.pop("_annotation_traversals", None)
return new
else:
return self
-class SupportsWrappingAnnotations(object):
+class SupportsWrappingAnnotations(SupportsAnnotations):
def _annotate(self, values):
"""return a copy of this ClauseElement with annotations
updated by the given dictionary.
@@ -123,6 +146,7 @@ class Annotated(object):
def __init__(self, element, values):
self.__dict__ = element.__dict__.copy()
+ self.__dict__.pop("_annotation_traversals", None)
self.__element = element
self._annotations = values
self._hash = hash(element)
@@ -135,6 +159,7 @@ class Annotated(object):
def _with_annotations(self, values):
clone = self.__class__.__new__(self.__class__)
clone.__dict__ = self.__dict__.copy()
+ clone.__dict__.pop("_annotation_traversals", None)
clone._annotations = values
return clone
@@ -192,7 +217,17 @@ def _deep_annotate(element, annotations, exclude=None):
"""
- def clone(elem):
+ # annotated objects hack the __hash__() method so if we want to
+ # uniquely process them we have to use id()
+
+ cloned_ids = {}
+
+ def clone(elem, **kw):
+ id_ = id(elem)
+
+ if id_ in cloned_ids:
+ return cloned_ids[id_]
+
if (
exclude
and hasattr(elem, "proxy_set")
@@ -204,6 +239,7 @@ def _deep_annotate(element, annotations, exclude=None):
else:
newelem = elem
newelem._copy_internals(clone=clone)
+ cloned_ids[id_] = newelem
return newelem
if element is not None:
@@ -214,23 +250,21 @@ def _deep_annotate(element, annotations, exclude=None):
def _deep_deannotate(element, values=None):
"""Deep copy the given element, removing annotations."""
- cloned = util.column_dict()
+ cloned = {}
- def clone(elem):
- # if a values dict is given,
- # the elem must be cloned each time it appears,
- # as there may be different annotations in source
- # elements that are remaining. if totally
- # removing all annotations, can assume the same
- # slate...
- if values or elem not in cloned:
+ def clone(elem, **kw):
+ if values:
+ key = id(elem)
+ else:
+ key = elem
+
+ if key not in cloned:
newelem = elem._deannotate(values=values, clone=True)
newelem._copy_internals(clone=clone)
- if not values:
- cloned[elem] = newelem
+ cloned[key] = newelem
return newelem
else:
- return cloned[elem]
+ return cloned[key]
if element is not None:
element = clone(element)
@@ -268,6 +302,11 @@ def _new_annotation_type(cls, base_cls):
"Annotated%s" % cls.__name__, (base_cls, cls), {}
)
globals()["Annotated%s" % cls.__name__] = anno_cls
+
+ if "_traverse_internals" in cls.__dict__:
+ anno_cls._traverse_internals = list(cls._traverse_internals) + [
+ ("_annotations", InternalTraversal.dp_annotations_state)
+ ]
return anno_cls
diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py
index 7e9199bfa..d11a3a313 100644
--- a/lib/sqlalchemy/sql/base.py
+++ b/lib/sqlalchemy/sql/base.py
@@ -14,6 +14,7 @@ import itertools
import operator
import re
+from .traversals import HasCacheKey # noqa
from .visitors import ClauseVisitor
from .. import exc
from .. import util
@@ -38,18 +39,41 @@ class Immutable(object):
def _clone(self):
return self
+ def _copy_internals(self, **kw):
+ pass
+
+
+class HasMemoized(object):
+ def _reset_memoizations(self):
+ self._memoized_property.expire_instance(self)
+
+ def _reset_exported(self):
+ self._memoized_property.expire_instance(self)
+
+ def _copy_internals(self, **kw):
+ super(HasMemoized, self)._copy_internals(**kw)
+ self._reset_memoizations()
+
def _from_objects(*elements):
return itertools.chain(*[element._from_objects for element in elements])
def _generative(fn):
+ """non-caching _generative() decorator.
+
+ This is basically the legacy decorator that copies the object and
+ runs a method on the new copy.
+
+ """
+
@util.decorator
- def _generative(fn, *args, **kw):
+ def _generative(fn, self, *args, **kw):
"""Mark a method as generative."""
- self = args[0]._generate()
- fn(self, *args[1:], **kw)
+ self = self._generate()
+ x = fn(self, *args, **kw)
+ assert x is None, "generative methods must have no return value"
return self
decorated = _generative(fn)
@@ -357,10 +381,8 @@ class DialectKWArgs(object):
class Generative(object):
- """Allow a ClauseElement to generate itself via the
- @_generative decorator.
-
- """
+ """Provide a method-chaining pattern in conjunction with the
+ @_generative decorator."""
def _generate(self):
s = self.__class__.__new__(self.__class__)
diff --git a/lib/sqlalchemy/sql/clause_compare.py b/lib/sqlalchemy/sql/clause_compare.py
deleted file mode 100644
index 30a90348c..000000000
--- a/lib/sqlalchemy/sql/clause_compare.py
+++ /dev/null
@@ -1,334 +0,0 @@
-from collections import deque
-
-from . import operators
-from .. import util
-
-
-SKIP_TRAVERSE = util.symbol("skip_traverse")
-
-
-def compare(obj1, obj2, **kw):
- if kw.get("use_proxies", False):
- strategy = ColIdentityComparatorStrategy()
- else:
- strategy = StructureComparatorStrategy()
-
- return strategy.compare(obj1, obj2, **kw)
-
-
-class StructureComparatorStrategy(object):
- __slots__ = "compare_stack", "cache"
-
- def __init__(self):
- self.compare_stack = deque()
- self.cache = set()
-
- def compare(self, obj1, obj2, **kw):
- stack = self.compare_stack
- cache = self.cache
-
- stack.append((obj1, obj2))
-
- while stack:
- left, right = stack.popleft()
-
- if left is right:
- continue
- elif left is None or right is None:
- # we know they are different so no match
- return False
- elif (left, right) in cache:
- continue
- cache.add((left, right))
-
- visit_name = left.__visit_name__
-
- # we're not exactly looking for identical types, because
- # there are things like Column and AnnotatedColumn. So the
- # visit_name has to at least match up
- if visit_name != right.__visit_name__:
- return False
-
- meth = getattr(self, "compare_%s" % visit_name, None)
-
- if meth:
- comparison = meth(left, right, **kw)
- if comparison is False:
- return False
- elif comparison is SKIP_TRAVERSE:
- continue
-
- for c1, c2 in util.zip_longest(
- left.get_children(column_collections=False),
- right.get_children(column_collections=False),
- fillvalue=None,
- ):
- if c1 is None or c2 is None:
- # collections are different sizes, comparison fails
- return False
- stack.append((c1, c2))
-
- return True
-
- def compare_inner(self, obj1, obj2, **kw):
- stack = self.compare_stack
- try:
- self.compare_stack = deque()
- return self.compare(obj1, obj2, **kw)
- finally:
- self.compare_stack = stack
-
- def _compare_unordered_sequences(self, seq1, seq2, **kw):
- if seq1 is None:
- return seq2 is None
-
- completed = set()
- for clause in seq1:
- for other_clause in set(seq2).difference(completed):
- if self.compare_inner(clause, other_clause, **kw):
- completed.add(other_clause)
- break
- return len(completed) == len(seq1) == len(seq2)
-
- def compare_bindparam(self, left, right, **kw):
- # note the ".key" is often generated from id(self) so can't
- # be compared, as far as determining structure.
- return (
- left.type._compare_type_affinity(right.type)
- and left.value == right.value
- and left.callable == right.callable
- and left._orig_key == right._orig_key
- )
-
- def compare_clauselist(self, left, right, **kw):
- if left.operator is right.operator:
- if operators.is_associative(left.operator):
- if self._compare_unordered_sequences(
- left.clauses, right.clauses
- ):
- return SKIP_TRAVERSE
- else:
- return False
- else:
- # normal ordered traversal
- return True
- else:
- return False
-
- def compare_unary(self, left, right, **kw):
- if left.operator:
- disp = self._get_operator_dispatch(
- left.operator, "unary", "operator"
- )
- if disp is not None:
- result = disp(left, right, left.operator, **kw)
- if result is not True:
- return result
- elif left.modifier:
- disp = self._get_operator_dispatch(
- left.modifier, "unary", "modifier"
- )
- if disp is not None:
- result = disp(left, right, left.operator, **kw)
- if result is not True:
- return result
- return (
- left.operator == right.operator and left.modifier == right.modifier
- )
-
- def compare_binary(self, left, right, **kw):
- disp = self._get_operator_dispatch(left.operator, "binary", None)
- if disp:
- result = disp(left, right, left.operator, **kw)
- if result is not True:
- return result
-
- if left.operator == right.operator:
- if operators.is_commutative(left.operator):
- if (
- compare(left.left, right.left, **kw)
- and compare(left.right, right.right, **kw)
- ) or (
- compare(left.left, right.right, **kw)
- and compare(left.right, right.left, **kw)
- ):
- return SKIP_TRAVERSE
- else:
- return False
- else:
- return True
- else:
- return False
-
- def _get_operator_dispatch(self, operator_, qualifier1, qualifier2):
- # used by compare_binary, compare_unary
- attrname = "visit_%s_%s%s" % (
- operator_.__name__,
- qualifier1,
- "_" + qualifier2 if qualifier2 else "",
- )
- return getattr(self, attrname, None)
-
- def visit_function_as_comparison_op_binary(
- self, left, right, operator, **kw
- ):
- return (
- left.left_index == right.left_index
- and left.right_index == right.right_index
- )
-
- def compare_function(self, left, right, **kw):
- return left.name == right.name
-
- def compare_column(self, left, right, **kw):
- if left.table is not None:
- self.compare_stack.appendleft((left.table, right.table))
- return (
- left.key == right.key
- and left.name == right.name
- and (
- left.type._compare_type_affinity(right.type)
- if left.type is not None
- else right.type is None
- )
- and left.is_literal == right.is_literal
- )
-
- def compare_collation(self, left, right, **kw):
- return left.collation == right.collation
-
- def compare_type_coerce(self, left, right, **kw):
- return left.type._compare_type_affinity(right.type)
-
- @util.dependencies("sqlalchemy.sql.elements")
- def compare_alias(self, elements, left, right, **kw):
- return (
- left.name == right.name
- if not isinstance(left.name, elements._anonymous_label)
- else isinstance(right.name, elements._anonymous_label)
- )
-
- def compare_cte(self, elements, left, right, **kw):
- raise NotImplementedError("TODO")
-
- def compare_extract(self, left, right, **kw):
- return left.field == right.field
-
- def compare_textual_label_reference(self, left, right, **kw):
- return left.element == right.element
-
- def compare_slice(self, left, right, **kw):
- return (
- left.start == right.start
- and left.stop == right.stop
- and left.step == right.step
- )
-
- def compare_over(self, left, right, **kw):
- return left.range_ == right.range_ and left.rows == right.rows
-
- @util.dependencies("sqlalchemy.sql.elements")
- def compare_label(self, elements, left, right, **kw):
- return left._type._compare_type_affinity(right._type) and (
- left.name == right.name
- if not isinstance(left.name, elements._anonymous_label)
- else isinstance(right.name, elements._anonymous_label)
- )
-
- def compare_typeclause(self, left, right, **kw):
- return left.type._compare_type_affinity(right.type)
-
- def compare_join(self, left, right, **kw):
- return left.isouter == right.isouter and left.full == right.full
-
- def compare_table(self, left, right, **kw):
- if left.name != right.name:
- return False
-
- self.compare_stack.extendleft(
- util.zip_longest(left.columns, right.columns)
- )
-
- def compare_compound_select(self, left, right, **kw):
-
- if not self._compare_unordered_sequences(
- left.selects, right.selects, **kw
- ):
- return False
-
- if left.keyword != right.keyword:
- return False
-
- if left._for_update_arg != right._for_update_arg:
- return False
-
- if not self.compare_inner(
- left._order_by_clause, right._order_by_clause, **kw
- ):
- return False
-
- if not self.compare_inner(
- left._group_by_clause, right._group_by_clause, **kw
- ):
- return False
-
- return SKIP_TRAVERSE
-
- def compare_select(self, left, right, **kw):
- if not self._compare_unordered_sequences(
- left._correlate, right._correlate
- ):
- return False
- if not self._compare_unordered_sequences(
- left._correlate_except, right._correlate_except
- ):
- return False
-
- if not self._compare_unordered_sequences(
- left._from_obj, right._from_obj
- ):
- return False
-
- if left._for_update_arg != right._for_update_arg:
- return False
-
- return True
-
- def compare_textual_select(self, left, right, **kw):
- self.compare_stack.extendleft(
- util.zip_longest(left.column_args, right.column_args)
- )
- return left.positional == right.positional
-
-
-class ColIdentityComparatorStrategy(StructureComparatorStrategy):
- def compare_column_element(
- self, left, right, use_proxies=True, equivalents=(), **kw
- ):
- """Compare ColumnElements using proxies and equivalent collections.
-
- This is a comparison strategy specific to the ORM.
- """
-
- to_compare = (right,)
- if equivalents and right in equivalents:
- to_compare = equivalents[right].union(to_compare)
-
- for oth in to_compare:
- if use_proxies and left.shares_lineage(oth):
- return True
- elif hash(left) == hash(right):
- return True
- else:
- return False
-
- def compare_column(self, left, right, **kw):
- return self.compare_column_element(left, right, **kw)
-
- def compare_label(self, left, right, **kw):
- return self.compare_column_element(left, right, **kw)
-
- def compare_table(self, left, right, **kw):
- # tables compare on identity, since it's not really feasible to
- # compare them column by column with the above rules
- return left is right
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 5ecec7d6c..546fffc6c 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -434,6 +434,27 @@ class _CompileLabel(elements.ColumnElement):
return self
+class prefix_anon_map(dict):
+ """A map that creates new keys for missing key access.
+
+ Considers keys of the form "<ident> <name>" to produce
+ new symbols "<name>_<index>", where "index" is an incrementing integer
+ corresponding to <name>.
+
+ Inlines the approach taken by :class:`sqlalchemy.util.PopulateDict` which
+ is otherwise usually used for this type of operation.
+
+ """
+
+ def __missing__(self, key):
+ (ident, derived) = key.split(" ", 1)
+ anonymous_counter = self.get(derived, 1)
+ self[derived] = anonymous_counter + 1
+ value = derived + "_" + str(anonymous_counter)
+ self[key] = value
+ return value
+
+
class SQLCompiler(Compiled):
"""Default implementation of :class:`.Compiled`.
@@ -574,7 +595,7 @@ class SQLCompiler(Compiled):
# a map which tracks "anonymous" identifiers that are created on
# the fly here
- self.anon_map = util.PopulateDict(self._process_anon)
+ self.anon_map = prefix_anon_map()
# a map which tracks "truncated" names based on
# dialect.label_length or dialect.max_identifier_length
@@ -1712,12 +1733,6 @@ class SQLCompiler(Compiled):
def _anonymize(self, name):
return name % self.anon_map
- def _process_anon(self, key):
- (ident, derived) = key.split(" ", 1)
- anonymous_counter = self.anon_map.get(derived, 1)
- self.anon_map[derived] = anonymous_counter + 1
- return derived + "_" + str(anonymous_counter)
-
def bindparam_string(
self,
name,
diff --git a/lib/sqlalchemy/sql/default_comparator.py b/lib/sqlalchemy/sql/default_comparator.py
index 918f7524e..c0baa8555 100644
--- a/lib/sqlalchemy/sql/default_comparator.py
+++ b/lib/sqlalchemy/sql/default_comparator.py
@@ -178,6 +178,9 @@ def _unsupported_impl(expr, op, *arg, **kw):
def _inv_impl(expr, op, **kw):
"""See :meth:`.ColumnOperators.__inv__`."""
+
+ # undocumented element currently used by the ORM for
+ # relationship.contains()
if hasattr(expr, "negation_clause"):
return expr.negation_clause
else:
diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py
index e6f57b8d1..ba615bc3f 100644
--- a/lib/sqlalchemy/sql/elements.py
+++ b/lib/sqlalchemy/sql/elements.py
@@ -16,23 +16,29 @@ import itertools
import operator
import re
-from . import clause_compare
from . import coercions
from . import operators
from . import roles
+from . import traversals
from . import type_api
from .annotation import Annotated
from .annotation import SupportsWrappingAnnotations
from .base import _clone
from .base import _generative
from .base import Executable
+from .base import HasCacheKey
+from .base import HasMemoized
from .base import Immutable
from .base import NO_ARG
from .base import PARSE_AUTOCOMMIT
from .coercions import _document_text_coercion
+from .traversals import _copy_internals
+from .traversals import _get_children
+from .traversals import NO_CACHE
from .visitors import cloned_traverse
+from .visitors import InternalTraversal
from .visitors import traverse
-from .visitors import Visitable
+from .visitors import Traversible
from .. import exc
from .. import inspection
from .. import util
@@ -162,7 +168,9 @@ def not_(clause):
@inspection._self_inspects
-class ClauseElement(roles.SQLRole, SupportsWrappingAnnotations, Visitable):
+class ClauseElement(
+ roles.SQLRole, SupportsWrappingAnnotations, HasCacheKey, Traversible
+):
"""Base class for elements of a programmatically constructed SQL
expression.
@@ -190,6 +198,13 @@ class ClauseElement(roles.SQLRole, SupportsWrappingAnnotations, Visitable):
_order_by_label_element = None
+ @property
+ def _cache_key_traversal(self):
+ try:
+ return self._traverse_internals
+ except AttributeError:
+ return NO_CACHE
+
def _clone(self):
"""Create a shallow copy of this ClauseElement.
@@ -221,28 +236,6 @@ class ClauseElement(roles.SQLRole, SupportsWrappingAnnotations, Visitable):
"""
return self
- def _cache_key(self, **kw):
- """return an optional cache key.
-
- The cache key is a tuple which can contain any series of
- objects that are hashable and also identifies
- this object uniquely within the presence of a larger SQL expression
- or statement, for the purposes of caching the resulting query.
-
- The cache key should be based on the SQL compiled structure that would
- ultimately be produced. That is, two structures that are composed in
- exactly the same way should produce the same cache key; any difference
- in the strucures that would affect the SQL string or the type handlers
- should result in a different cache key.
-
- If a structure cannot produce a useful cache key, it should raise
- NotImplementedError, which will result in the entire structure
- for which it's part of not being useful as a cache key.
-
-
- """
- raise NotImplementedError()
-
@property
def _constructor(self):
"""return the 'constructor' for this ClauseElement.
@@ -336,9 +329,9 @@ class ClauseElement(roles.SQLRole, SupportsWrappingAnnotations, Visitable):
(see :class:`.ColumnElement`)
"""
- return clause_compare.compare(self, other, **kw)
+ return traversals.compare(self, other, **kw)
- def _copy_internals(self, clone=_clone, **kw):
+ def _copy_internals(self, **kw):
"""Reassign internal elements to be clones of themselves.
Called during a copy-and-traverse operation on newly
@@ -349,21 +342,46 @@ class ClauseElement(roles.SQLRole, SupportsWrappingAnnotations, Visitable):
traversal, cloned traversal, annotations).
"""
- pass
- def get_children(self, **kwargs):
- r"""Return immediate child elements of this :class:`.ClauseElement`.
+ try:
+ traverse_internals = self._traverse_internals
+ except AttributeError:
+ return
+
+ for attrname, obj, meth in _copy_internals.run_generated_dispatch(
+ self, traverse_internals, "_generated_copy_internals_traversal"
+ ):
+ if obj is not None:
+ result = meth(self, obj, **kw)
+ if result is not None:
+ setattr(self, attrname, result)
+
+ def get_children(self, omit_attrs=None, **kw):
+ r"""Return immediate child :class:`.Traversible` elements of this
+ :class:`.Traversible`.
This is used for visit traversal.
- \**kwargs may contain flags that change the collection that is
+ \**kw may contain flags that change the collection that is
returned, for example to return a subset of items in order to
cut down on larger traversals, or to return child items from a
different context (such as schema-level collections instead of
clause-level).
"""
- return []
+ result = []
+ try:
+ traverse_internals = self._traverse_internals
+ except AttributeError:
+ return result
+
+ for attrname, obj, meth in _get_children.run_generated_dispatch(
+ self, traverse_internals, "_generated_get_children_traversal"
+ ):
+ if obj is None or omit_attrs and attrname in omit_attrs:
+ continue
+ result.extend(meth(obj, **kw))
+ return result
def self_group(self, against=None):
# type: (Optional[Any]) -> ClauseElement
@@ -501,6 +519,8 @@ class ClauseElement(roles.SQLRole, SupportsWrappingAnnotations, Visitable):
return or_(self, other)
def __invert__(self):
+ # undocumented element currently used by the ORM for
+ # relationship.contains()
if hasattr(self, "negation_clause"):
return self.negation_clause
else:
@@ -508,9 +528,7 @@ class ClauseElement(roles.SQLRole, SupportsWrappingAnnotations, Visitable):
def _negate(self):
return UnaryExpression(
- self.self_group(against=operators.inv),
- operator=operators.inv,
- negate=None,
+ self.self_group(against=operators.inv), operator=operators.inv
)
def __bool__(self):
@@ -731,9 +749,6 @@ class ColumnElement(
else:
return comparator_factory(self)
- def _cache_key(self, **kw):
- raise NotImplementedError(self.__class__)
-
def __getattr__(self, key):
try:
return getattr(self.comparator, key)
@@ -969,6 +984,13 @@ class BindParameter(roles.InElementRole, ColumnElement):
__visit_name__ = "bindparam"
+ _traverse_internals = [
+ ("key", InternalTraversal.dp_anon_name),
+ ("type", InternalTraversal.dp_type),
+ ("callable", InternalTraversal.dp_plain_dict),
+ ("value", InternalTraversal.dp_plain_obj),
+ ]
+
_is_crud = False
_expanding_in_types = ()
@@ -1321,26 +1343,19 @@ class BindParameter(roles.InElementRole, ColumnElement):
)
return c
- def _cache_key(self, bindparams=None, **kw):
- if bindparams is None:
- # even though _cache_key is a private method, we would like to
- # be super paranoid about this point. You can't include the
- # "value" or "callable" in the cache key, because the value is
- # not part of the structure of a statement and is likely to
- # change every time. However you cannot *throw it away* either,
- # because you can't invoke the statement without the parameter
- # values that were explicitly placed. So require that they
- # are collected here to make sure this happens.
- if self._value_required_for_cache:
- raise NotImplementedError(
- "bindparams collection argument required for _cache_key "
- "implementation. Bound parameter cache keys are not safe "
- "to use without accommodating for the value or callable "
- "within the parameter itself."
- )
- else:
- bindparams.append(self)
- return (BindParameter, self.type._cache_key, self._orig_key)
+ def _gen_cache_key(self, anon_map, bindparams):
+ if self in anon_map:
+ return (anon_map[self], self.__class__)
+
+ id_ = anon_map[self]
+ bindparams.append(self)
+
+ return (
+ id_,
+ self.__class__,
+ self.type._gen_cache_key,
+ traversals._resolve_name_for_compare(self, self.key, anon_map),
+ )
def _convert_to_unique(self):
if not self.unique:
@@ -1377,12 +1392,11 @@ class TypeClause(ClauseElement):
__visit_name__ = "typeclause"
+ _traverse_internals = [("type", InternalTraversal.dp_type)]
+
def __init__(self, type_):
self.type = type_
- def _cache_key(self, **kw):
- return (TypeClause, self.type._cache_key)
-
class TextClause(
roles.DDLConstraintColumnRole,
@@ -1419,6 +1433,11 @@ class TextClause(
__visit_name__ = "textclause"
+ _traverse_internals = [
+ ("_bindparams", InternalTraversal.dp_string_clauseelement_dict),
+ ("text", InternalTraversal.dp_string),
+ ]
+
_is_text_clause = True
_is_textual = True
@@ -1861,19 +1880,6 @@ class TextClause(
else:
return self
- def _copy_internals(self, clone=_clone, **kw):
- self._bindparams = dict(
- (b.key, clone(b, **kw)) for b in self._bindparams.values()
- )
-
- def get_children(self, **kwargs):
- return list(self._bindparams.values())
-
- def _cache_key(self, **kw):
- return (self.text,) + tuple(
- bind._cache_key for bind in self._bindparams.values()
- )
-
class Null(roles.ConstExprRole, ColumnElement):
"""Represent the NULL keyword in a SQL statement.
@@ -1885,6 +1891,8 @@ class Null(roles.ConstExprRole, ColumnElement):
__visit_name__ = "null"
+ _traverse_internals = []
+
@util.memoized_property
def type(self):
return type_api.NULLTYPE
@@ -1895,9 +1903,6 @@ class Null(roles.ConstExprRole, ColumnElement):
return Null()
- def _cache_key(self, **kw):
- return (Null,)
-
class False_(roles.ConstExprRole, ColumnElement):
"""Represent the ``false`` keyword, or equivalent, in a SQL statement.
@@ -1908,6 +1913,7 @@ class False_(roles.ConstExprRole, ColumnElement):
"""
__visit_name__ = "false"
+ _traverse_internals = []
@util.memoized_property
def type(self):
@@ -1954,9 +1960,6 @@ class False_(roles.ConstExprRole, ColumnElement):
return False_()
- def _cache_key(self, **kw):
- return (False_,)
-
class True_(roles.ConstExprRole, ColumnElement):
"""Represent the ``true`` keyword, or equivalent, in a SQL statement.
@@ -1968,6 +1971,8 @@ class True_(roles.ConstExprRole, ColumnElement):
__visit_name__ = "true"
+ _traverse_internals = []
+
@util.memoized_property
def type(self):
return type_api.BOOLEANTYPE
@@ -2020,9 +2025,6 @@ class True_(roles.ConstExprRole, ColumnElement):
return True_()
- def _cache_key(self, **kw):
- return (True_,)
-
class ClauseList(
roles.InElementRole,
@@ -2038,6 +2040,11 @@ class ClauseList(
__visit_name__ = "clauselist"
+ _traverse_internals = [
+ ("clauses", InternalTraversal.dp_clauseelement_list),
+ ("operator", InternalTraversal.dp_operator),
+ ]
+
def __init__(self, *clauses, **kwargs):
self.operator = kwargs.pop("operator", operators.comma_op)
self.group = kwargs.pop("group", True)
@@ -2082,17 +2089,6 @@ class ClauseList(
coercions.expect(self._text_converter_role, clause)
)
- def _copy_internals(self, clone=_clone, **kw):
- self.clauses = [clone(clause, **kw) for clause in self.clauses]
-
- def get_children(self, **kwargs):
- return self.clauses
-
- def _cache_key(self, **kw):
- return (ClauseList, self.operator) + tuple(
- clause._cache_key(**kw) for clause in self.clauses
- )
-
@property
def _from_objects(self):
return list(itertools.chain(*[c._from_objects for c in self.clauses]))
@@ -2115,11 +2111,6 @@ class BooleanClauseList(ClauseList, ColumnElement):
"BooleanClauseList has a private constructor"
)
- def _cache_key(self, **kw):
- return (BooleanClauseList, self.operator) + tuple(
- clause._cache_key(**kw) for clause in self.clauses
- )
-
@classmethod
def _construct(cls, operator, continue_on, skip_on, *clauses, **kw):
convert_clauses = []
@@ -2250,6 +2241,8 @@ or_ = BooleanClauseList.or_
class Tuple(ClauseList, ColumnElement):
"""Represent a SQL tuple."""
+ _traverse_internals = ClauseList._traverse_internals + []
+
def __init__(self, *clauses, **kw):
"""Return a :class:`.Tuple`.
@@ -2289,11 +2282,6 @@ class Tuple(ClauseList, ColumnElement):
def _select_iterable(self):
return (self,)
- def _cache_key(self, **kw):
- return (Tuple,) + tuple(
- clause._cache_key(**kw) for clause in self.clauses
- )
-
def _bind_param(self, operator, obj, type_=None):
return Tuple(
*[
@@ -2339,6 +2327,12 @@ class Case(ColumnElement):
__visit_name__ = "case"
+ _traverse_internals = [
+ ("value", InternalTraversal.dp_clauseelement),
+ ("whens", InternalTraversal.dp_clauseelement_tuples),
+ ("else_", InternalTraversal.dp_clauseelement),
+ ]
+
def __init__(self, whens, value=None, else_=None):
r"""Produce a ``CASE`` expression.
@@ -2501,40 +2495,6 @@ class Case(ColumnElement):
else:
self.else_ = None
- def _copy_internals(self, clone=_clone, **kw):
- if self.value is not None:
- self.value = clone(self.value, **kw)
- self.whens = [(clone(x, **kw), clone(y, **kw)) for x, y in self.whens]
- if self.else_ is not None:
- self.else_ = clone(self.else_, **kw)
-
- def get_children(self, **kwargs):
- if self.value is not None:
- yield self.value
- for x, y in self.whens:
- yield x
- yield y
- if self.else_ is not None:
- yield self.else_
-
- def _cache_key(self, **kw):
- return (
- (
- Case,
- self.value._cache_key(**kw)
- if self.value is not None
- else None,
- )
- + tuple(
- (x._cache_key(**kw), y._cache_key(**kw)) for x, y in self.whens
- )
- + (
- self.else_._cache_key(**kw)
- if self.else_ is not None
- else None,
- )
- )
-
@property
def _from_objects(self):
return list(
@@ -2603,6 +2563,11 @@ class Cast(WrapsColumnExpression, ColumnElement):
__visit_name__ = "cast"
+ _traverse_internals = [
+ ("clause", InternalTraversal.dp_clauseelement),
+ ("typeclause", InternalTraversal.dp_clauseelement),
+ ]
+
def __init__(self, expression, type_):
r"""Produce a ``CAST`` expression.
@@ -2662,20 +2627,6 @@ class Cast(WrapsColumnExpression, ColumnElement):
)
self.typeclause = TypeClause(self.type)
- def _copy_internals(self, clone=_clone, **kw):
- self.clause = clone(self.clause, **kw)
- self.typeclause = clone(self.typeclause, **kw)
-
- def get_children(self, **kwargs):
- return self.clause, self.typeclause
-
- def _cache_key(self, **kw):
- return (
- Cast,
- self.clause._cache_key(**kw),
- self.typeclause._cache_key(**kw),
- )
-
@property
def _from_objects(self):
return self.clause._from_objects
@@ -2685,7 +2636,7 @@ class Cast(WrapsColumnExpression, ColumnElement):
return self.clause
-class TypeCoerce(WrapsColumnExpression, ColumnElement):
+class TypeCoerce(HasMemoized, WrapsColumnExpression, ColumnElement):
"""Represent a Python-side type-coercion wrapper.
:class:`.TypeCoerce` supplies the :func:`.expression.type_coerce`
@@ -2705,6 +2656,13 @@ class TypeCoerce(WrapsColumnExpression, ColumnElement):
__visit_name__ = "type_coerce"
+ _traverse_internals = [
+ ("clause", InternalTraversal.dp_clauseelement),
+ ("type", InternalTraversal.dp_type),
+ ]
+
+ _memoized_property = util.group_expirable_memoized_property()
+
def __init__(self, expression, type_):
r"""Associate a SQL expression with a particular type, without rendering
``CAST``.
@@ -2773,21 +2731,11 @@ class TypeCoerce(WrapsColumnExpression, ColumnElement):
roles.ExpressionElementRole, expression, type_=self.type
)
- def _copy_internals(self, clone=_clone, **kw):
- self.clause = clone(self.clause, **kw)
- self.__dict__.pop("typed_expression", None)
-
- def get_children(self, **kwargs):
- return (self.clause,)
-
- def _cache_key(self, **kw):
- return (TypeCoerce, self.type._cache_key, self.clause._cache_key(**kw))
-
@property
def _from_objects(self):
return self.clause._from_objects
- @util.memoized_property
+ @_memoized_property
def typed_expression(self):
if isinstance(self.clause, BindParameter):
bp = self.clause._clone()
@@ -2806,6 +2754,11 @@ class Extract(ColumnElement):
__visit_name__ = "extract"
+ _traverse_internals = [
+ ("expr", InternalTraversal.dp_clauseelement),
+ ("field", InternalTraversal.dp_string),
+ ]
+
def __init__(self, field, expr, **kwargs):
"""Return a :class:`.Extract` construct.
@@ -2818,15 +2771,6 @@ class Extract(ColumnElement):
self.field = field
self.expr = coercions.expect(roles.ExpressionElementRole, expr)
- def _copy_internals(self, clone=_clone, **kw):
- self.expr = clone(self.expr, **kw)
-
- def get_children(self, **kwargs):
- return (self.expr,)
-
- def _cache_key(self, **kw):
- return (Extract, self.field, self.expr._cache_key(**kw))
-
@property
def _from_objects(self):
return self.expr._from_objects
@@ -2847,18 +2791,11 @@ class _label_reference(ColumnElement):
__visit_name__ = "label_reference"
+ _traverse_internals = [("element", InternalTraversal.dp_clauseelement)]
+
def __init__(self, element):
self.element = element
- def _copy_internals(self, clone=_clone, **kw):
- self.element = clone(self.element, **kw)
-
- def _cache_key(self, **kw):
- return (_label_reference, self.element._cache_key(**kw))
-
- def get_children(self, **kwargs):
- return [self.element]
-
@property
def _from_objects(self):
return ()
@@ -2867,6 +2804,8 @@ class _label_reference(ColumnElement):
class _textual_label_reference(ColumnElement):
__visit_name__ = "textual_label_reference"
+ _traverse_internals = [("element", InternalTraversal.dp_string)]
+
def __init__(self, element):
self.element = element
@@ -2874,9 +2813,6 @@ class _textual_label_reference(ColumnElement):
def _text_clause(self):
return TextClause._create_text(self.element)
- def _cache_key(self, **kw):
- return (_textual_label_reference, self.element)
-
class UnaryExpression(ColumnElement):
"""Define a 'unary' expression.
@@ -2894,13 +2830,18 @@ class UnaryExpression(ColumnElement):
__visit_name__ = "unary"
+ _traverse_internals = [
+ ("element", InternalTraversal.dp_clauseelement),
+ ("operator", InternalTraversal.dp_operator),
+ ("modifier", InternalTraversal.dp_operator),
+ ]
+
def __init__(
self,
element,
operator=None,
modifier=None,
type_=None,
- negate=None,
wraps_column_expression=False,
):
self.operator = operator
@@ -2909,7 +2850,6 @@ class UnaryExpression(ColumnElement):
against=self.operator or self.modifier
)
self.type = type_api.to_instance(type_)
- self.negate = negate
self.wraps_column_expression = wraps_column_expression
@classmethod
@@ -3135,37 +3075,13 @@ class UnaryExpression(ColumnElement):
def _from_objects(self):
return self.element._from_objects
- def _copy_internals(self, clone=_clone, **kw):
- self.element = clone(self.element, **kw)
-
- def _cache_key(self, **kw):
- return (
- UnaryExpression,
- self.element._cache_key(**kw),
- self.operator,
- self.modifier,
- )
-
- def get_children(self, **kwargs):
- return (self.element,)
-
def _negate(self):
- if self.negate is not None:
- return UnaryExpression(
- self.element,
- operator=self.negate,
- negate=self.operator,
- modifier=self.modifier,
- type_=self.type,
- wraps_column_expression=self.wraps_column_expression,
- )
- elif self.type._type_affinity is type_api.BOOLEANTYPE._type_affinity:
+ if self.type._type_affinity is type_api.BOOLEANTYPE._type_affinity:
return UnaryExpression(
self.self_group(against=operators.inv),
operator=operators.inv,
type_=type_api.BOOLEANTYPE,
wraps_column_expression=self.wraps_column_expression,
- negate=None,
)
else:
return ClauseElement._negate(self)
@@ -3286,15 +3202,6 @@ class AsBoolean(WrapsColumnExpression, UnaryExpression):
# type: (Optional[Any]) -> ClauseElement
return self
- def _cache_key(self, **kw):
- return (
- self.element._cache_key(**kw),
- self.type._cache_key,
- self.operator,
- self.negate,
- self.modifier,
- )
-
def _negate(self):
if isinstance(self.element, (True_, False_)):
return self.element._negate()
@@ -3318,6 +3225,14 @@ class BinaryExpression(ColumnElement):
__visit_name__ = "binary"
+ _traverse_internals = [
+ ("left", InternalTraversal.dp_clauseelement),
+ ("right", InternalTraversal.dp_clauseelement),
+ ("operator", InternalTraversal.dp_operator),
+ ("negate", InternalTraversal.dp_operator),
+ ("modifiers", InternalTraversal.dp_plain_dict),
+ ]
+
_is_implicitly_boolean = True
"""Indicates that any database will know this is a boolean expression
even if the database does not have an explicit boolean datatype.
@@ -3360,20 +3275,6 @@ class BinaryExpression(ColumnElement):
def _from_objects(self):
return self.left._from_objects + self.right._from_objects
- def _copy_internals(self, clone=_clone, **kw):
- self.left = clone(self.left, **kw)
- self.right = clone(self.right, **kw)
-
- def get_children(self, **kwargs):
- return self.left, self.right
-
- def _cache_key(self, **kw):
- return (
- BinaryExpression,
- self.left._cache_key(**kw),
- self.right._cache_key(**kw),
- )
-
def self_group(self, against=None):
# type: (Optional[Any]) -> ClauseElement
@@ -3406,6 +3307,12 @@ class Slice(ColumnElement):
__visit_name__ = "slice"
+ _traverse_internals = [
+ ("start", InternalTraversal.dp_plain_obj),
+ ("stop", InternalTraversal.dp_plain_obj),
+ ("step", InternalTraversal.dp_plain_obj),
+ ]
+
def __init__(self, start, stop, step):
self.start = start
self.stop = stop
@@ -3417,9 +3324,6 @@ class Slice(ColumnElement):
assert against is operator.getitem
return self
- def _cache_key(self, **kw):
- return (Slice, self.start, self.stop, self.step)
-
class IndexExpression(BinaryExpression):
"""Represent the class of expressions that are like an "index" operation.
@@ -3444,6 +3348,11 @@ class GroupedElement(ClauseElement):
class Grouping(GroupedElement, ColumnElement):
"""Represent a grouping within a column expression"""
+ _traverse_internals = [
+ ("element", InternalTraversal.dp_clauseelement),
+ ("type", InternalTraversal.dp_type),
+ ]
+
def __init__(self, element):
self.element = element
self.type = getattr(element, "type", type_api.NULLTYPE)
@@ -3460,15 +3369,6 @@ class Grouping(GroupedElement, ColumnElement):
def _label(self):
return getattr(self.element, "_label", None) or self.anon_label
- def _copy_internals(self, clone=_clone, **kw):
- self.element = clone(self.element, **kw)
-
- def get_children(self, **kwargs):
- return (self.element,)
-
- def _cache_key(self, **kw):
- return (Grouping, self.element._cache_key(**kw))
-
@property
def _from_objects(self):
return self.element._from_objects
@@ -3501,6 +3401,14 @@ class Over(ColumnElement):
__visit_name__ = "over"
+ _traverse_internals = [
+ ("element", InternalTraversal.dp_clauseelement),
+ ("order_by", InternalTraversal.dp_clauseelement),
+ ("partition_by", InternalTraversal.dp_clauseelement),
+ ("range_", InternalTraversal.dp_plain_obj),
+ ("rows", InternalTraversal.dp_plain_obj),
+ ]
+
order_by = None
partition_by = None
@@ -3667,30 +3575,6 @@ class Over(ColumnElement):
def type(self):
return self.element.type
- def get_children(self, **kwargs):
- return [
- c
- for c in (self.element, self.partition_by, self.order_by)
- if c is not None
- ]
-
- def _cache_key(self, **kw):
- return (
- (Over,)
- + tuple(
- e._cache_key(**kw) if e is not None else None
- for e in (self.element, self.partition_by, self.order_by)
- )
- + (self.range_, self.rows)
- )
-
- def _copy_internals(self, clone=_clone, **kw):
- self.element = clone(self.element, **kw)
- if self.partition_by is not None:
- self.partition_by = clone(self.partition_by, **kw)
- if self.order_by is not None:
- self.order_by = clone(self.order_by, **kw)
-
@property
def _from_objects(self):
return list(
@@ -3723,6 +3607,11 @@ class WithinGroup(ColumnElement):
__visit_name__ = "withingroup"
+ _traverse_internals = [
+ ("element", InternalTraversal.dp_clauseelement),
+ ("order_by", InternalTraversal.dp_clauseelement),
+ ]
+
order_by = None
def __init__(self, element, *order_by):
@@ -3791,25 +3680,6 @@ class WithinGroup(ColumnElement):
else:
return self.element.type
- def get_children(self, **kwargs):
- return [c for c in (self.element, self.order_by) if c is not None]
-
- def _cache_key(self, **kw):
- return (
- WithinGroup,
- self.element._cache_key(**kw)
- if self.element is not None
- else None,
- self.order_by._cache_key(**kw)
- if self.order_by is not None
- else None,
- )
-
- def _copy_internals(self, clone=_clone, **kw):
- self.element = clone(self.element, **kw)
- if self.order_by is not None:
- self.order_by = clone(self.order_by, **kw)
-
@property
def _from_objects(self):
return list(
@@ -3845,6 +3715,11 @@ class FunctionFilter(ColumnElement):
__visit_name__ = "funcfilter"
+ _traverse_internals = [
+ ("func", InternalTraversal.dp_clauseelement),
+ ("criterion", InternalTraversal.dp_clauseelement),
+ ]
+
criterion = None
def __init__(self, func, *criterion):
@@ -3932,23 +3807,6 @@ class FunctionFilter(ColumnElement):
def type(self):
return self.func.type
- def get_children(self, **kwargs):
- return [c for c in (self.func, self.criterion) if c is not None]
-
- def _copy_internals(self, clone=_clone, **kw):
- self.func = clone(self.func, **kw)
- if self.criterion is not None:
- self.criterion = clone(self.criterion, **kw)
-
- def _cache_key(self, **kw):
- return (
- FunctionFilter,
- self.func._cache_key(**kw),
- self.criterion._cache_key(**kw)
- if self.criterion is not None
- else None,
- )
-
@property
def _from_objects(self):
return list(
@@ -3962,7 +3820,7 @@ class FunctionFilter(ColumnElement):
)
-class Label(roles.LabeledColumnExprRole, ColumnElement):
+class Label(HasMemoized, roles.LabeledColumnExprRole, ColumnElement):
"""Represents a column label (AS).
Represent a label, as typically applied to any column-level
@@ -3972,6 +3830,14 @@ class Label(roles.LabeledColumnExprRole, ColumnElement):
__visit_name__ = "label"
+ _traverse_internals = [
+ ("name", InternalTraversal.dp_anon_name),
+ ("_type", InternalTraversal.dp_type),
+ ("_element", InternalTraversal.dp_clauseelement),
+ ]
+
+ _memoized_property = util.group_expirable_memoized_property()
+
def __init__(self, name, element, type_=None):
"""Return a :class:`Label` object for the
given :class:`.ColumnElement`.
@@ -4010,14 +3876,11 @@ class Label(roles.LabeledColumnExprRole, ColumnElement):
def __reduce__(self):
return self.__class__, (self.name, self._element, self._type)
- def _cache_key(self, **kw):
- return (Label, self.element._cache_key(**kw), self._resolve_label)
-
@util.memoized_property
def _is_implicitly_boolean(self):
return self.element._is_implicitly_boolean
- @util.memoized_property
+ @_memoized_property
def _allow_label_resolve(self):
return self.element._allow_label_resolve
@@ -4031,7 +3894,7 @@ class Label(roles.LabeledColumnExprRole, ColumnElement):
self._type or getattr(self._element, "type", None)
)
- @util.memoized_property
+ @_memoized_property
def element(self):
return self._element.self_group(against=operators.as_)
@@ -4057,13 +3920,9 @@ class Label(roles.LabeledColumnExprRole, ColumnElement):
def foreign_keys(self):
return self.element.foreign_keys
- def get_children(self, **kwargs):
- return (self.element,)
-
def _copy_internals(self, clone=_clone, anonymize_labels=False, **kw):
+ self._reset_memoizations()
self._element = clone(self._element, **kw)
- self.__dict__.pop("element", None)
- self.__dict__.pop("_allow_label_resolve", None)
if anonymize_labels:
self.name = self._resolve_label = _anonymous_label(
"%%(%d %s)s"
@@ -4124,6 +3983,13 @@ class ColumnClause(roles.LabeledColumnExprRole, Immutable, ColumnElement):
__visit_name__ = "column"
+ _traverse_internals = [
+ ("name", InternalTraversal.dp_string),
+ ("type", InternalTraversal.dp_type),
+ ("table", InternalTraversal.dp_clauseelement),
+ ("is_literal", InternalTraversal.dp_boolean),
+ ]
+
onupdate = default = server_default = server_onupdate = None
_is_multiparam_column = False
@@ -4254,14 +4120,6 @@ class ColumnClause(roles.LabeledColumnExprRole, Immutable, ColumnElement):
table = property(_get_table, _set_table)
- def _cache_key(self, **kw):
- return (
- self.name,
- self.table.name if self.table is not None else None,
- self.is_literal,
- self.type._cache_key,
- )
-
@_memoized_property
def _from_objects(self):
t = self.table
@@ -4395,12 +4253,11 @@ class ColumnClause(roles.LabeledColumnExprRole, Immutable, ColumnElement):
class CollationClause(ColumnElement):
__visit_name__ = "collation"
+ _traverse_internals = [("collation", InternalTraversal.dp_string)]
+
def __init__(self, collation):
self.collation = collation
- def _cache_key(self, **kw):
- return (CollationClause, self.collation)
-
class _IdentifiedClause(Executable, ClauseElement):
diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py
index 7ce822669..08e69f075 100644
--- a/lib/sqlalchemy/sql/expression.py
+++ b/lib/sqlalchemy/sql/expression.py
@@ -86,7 +86,6 @@ __all__ = [
from .base import _from_objects # noqa
from .base import ColumnCollection # noqa
from .base import Executable # noqa
-from .base import Generative # noqa
from .base import PARSE_AUTOCOMMIT # noqa
from .dml import Delete # noqa
from .dml import Insert # noqa
@@ -242,7 +241,6 @@ _UnaryExpression = UnaryExpression
_Case = Case
_Tuple = Tuple
_Over = Over
-_Generative = Generative
_TypeClause = TypeClause
_Extract = Extract
_Exists = Exists
diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py
index cbc8e539f..96e64dc28 100644
--- a/lib/sqlalchemy/sql/functions.py
+++ b/lib/sqlalchemy/sql/functions.py
@@ -17,7 +17,6 @@ from . import sqltypes
from . import util as sqlutil
from .base import ColumnCollection
from .base import Executable
-from .elements import _clone
from .elements import _type_from_args
from .elements import BinaryExpression
from .elements import BindParameter
@@ -33,7 +32,8 @@ from .elements import WithinGroup
from .selectable import Alias
from .selectable import FromClause
from .selectable import Select
-from .visitors import VisitableType
+from .visitors import InternalTraversal
+from .visitors import TraversibleType
from .. import util
@@ -78,10 +78,14 @@ class FunctionElement(Executable, ColumnElement, FromClause):
"""
+ _traverse_internals = [("clause_expr", InternalTraversal.dp_clauseelement)]
+
packagenames = ()
_has_args = False
+ _memoized_property = FromClause._memoized_property
+
def __init__(self, *clauses, **kwargs):
r"""Construct a :class:`.FunctionElement`.
@@ -136,7 +140,7 @@ class FunctionElement(Executable, ColumnElement, FromClause):
col = self.label(None)
return ColumnCollection(columns=[(col.key, col)])
- @util.memoized_property
+ @_memoized_property
def clauses(self):
"""Return the underlying :class:`.ClauseList` which contains
the arguments for this :class:`.FunctionElement`.
@@ -283,17 +287,6 @@ class FunctionElement(Executable, ColumnElement, FromClause):
def _from_objects(self):
return self.clauses._from_objects
- def get_children(self, **kwargs):
- return (self.clause_expr,)
-
- def _cache_key(self, **kw):
- return (FunctionElement, self.clause_expr._cache_key(**kw))
-
- def _copy_internals(self, clone=_clone, **kw):
- self.clause_expr = clone(self.clause_expr, **kw)
- self._reset_exported()
- FunctionElement.clauses._reset(self)
-
def within_group_type(self, within_group):
"""For types that define their return type as based on the criteria
within a WITHIN GROUP (ORDER BY) expression, called by the
@@ -404,6 +397,13 @@ class FunctionElement(Executable, ColumnElement, FromClause):
class FunctionAsBinary(BinaryExpression):
+ _traverse_internals = [
+ ("sql_function", InternalTraversal.dp_clauseelement),
+ ("left_index", InternalTraversal.dp_plain_obj),
+ ("right_index", InternalTraversal.dp_plain_obj),
+ ("modifiers", InternalTraversal.dp_plain_dict),
+ ]
+
def __init__(self, fn, left_index, right_index):
self.sql_function = fn
self.left_index = left_index
@@ -431,20 +431,6 @@ class FunctionAsBinary(BinaryExpression):
def right(self, value):
self.sql_function.clauses.clauses[self.right_index - 1] = value
- def _copy_internals(self, clone=_clone, **kw):
- self.sql_function = clone(self.sql_function, **kw)
-
- def get_children(self, **kw):
- yield self.sql_function
-
- def _cache_key(self, **kw):
- return (
- FunctionAsBinary,
- self.sql_function._cache_key(**kw),
- self.left_index,
- self.right_index,
- )
-
class _FunctionGenerator(object):
"""Generate SQL function expressions.
@@ -606,6 +592,12 @@ class Function(FunctionElement):
__visit_name__ = "function"
+ _traverse_internals = FunctionElement._traverse_internals + [
+ ("packagenames", InternalTraversal.dp_plain_obj),
+ ("name", InternalTraversal.dp_string),
+ ("type", InternalTraversal.dp_type),
+ ]
+
def __init__(self, name, *clauses, **kw):
"""Construct a :class:`.Function`.
@@ -630,15 +622,8 @@ class Function(FunctionElement):
unique=True,
)
- def _cache_key(self, **kw):
- return (
- (Function,) + tuple(self.packagenames)
- if self.packagenames
- else () + (self.name, self.clause_expr._cache_key(**kw))
- )
-
-class _GenericMeta(VisitableType):
+class _GenericMeta(TraversibleType):
def __init__(cls, clsname, bases, clsdict):
if annotation.Annotated not in cls.__mro__:
cls.name = name = clsdict.get("name", clsname)
@@ -764,6 +749,10 @@ class next_value(GenericFunction):
type = sqltypes.Integer()
name = "next_value"
+ _traverse_internals = [
+ ("sequence", InternalTraversal.dp_named_ddl_element)
+ ]
+
def __init__(self, seq, **kw):
assert isinstance(
seq, schema.Sequence
@@ -771,21 +760,12 @@ class next_value(GenericFunction):
self._bind = kw.get("bind", None)
self.sequence = seq
- def _cache_key(self, **kw):
- return (next_value, self.sequence.name)
-
def compare(self, other, **kw):
return (
isinstance(other, next_value)
and self.sequence.name == other.sequence.name
)
- def get_children(self, **kwargs):
- return []
-
- def _copy_internals(self, **kw):
- pass
-
@property
def _from_objects(self):
return []
diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py
index 4e8f4a397..ee7dc61ce 100644
--- a/lib/sqlalchemy/sql/schema.py
+++ b/lib/sqlalchemy/sql/schema.py
@@ -50,6 +50,7 @@ from .elements import ColumnElement
from .elements import quoted_name
from .elements import TextClause
from .selectable import TableClause
+from .visitors import InternalTraversal
from .. import event
from .. import exc
from .. import inspection
@@ -425,6 +426,21 @@ class Table(DialectKWArgs, SchemaItem, TableClause):
__visit_name__ = "table"
+ _traverse_internals = TableClause._traverse_internals + [
+ ("schema", InternalTraversal.dp_string)
+ ]
+
+ def _gen_cache_key(self, anon_map, bindparams):
+ return (self,)
+
+ @util.deprecated_params(
+ useexisting=(
+ "0.7",
+ "The :paramref:`.Table.useexisting` parameter is deprecated and "
+ "will be removed in a future release. Please use "
+ ":paramref:`.Table.extend_existing`.",
+ )
+ )
def __new__(cls, *args, **kw):
if not args:
# python3k pickle seems to call this
@@ -763,6 +779,8 @@ class Table(DialectKWArgs, SchemaItem, TableClause):
def get_children(
self, column_collections=True, schema_visitor=False, **kw
):
+ # TODO: consider that we probably don't need column_collections=True
+ # at all, it does not seem to impact anything
if not schema_visitor:
return TableClause.get_children(
self, column_collections=column_collections, **kw
diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py
index 6a7413fc0..4b3844eec 100644
--- a/lib/sqlalchemy/sql/selectable.py
+++ b/lib/sqlalchemy/sql/selectable.py
@@ -31,6 +31,7 @@ from .base import ColumnSet
from .base import DedupeColumnCollection
from .base import Executable
from .base import Generative
+from .base import HasMemoized
from .base import Immutable
from .coercions import _document_text_coercion
from .elements import _anonymous_label
@@ -39,11 +40,13 @@ from .elements import and_
from .elements import BindParameter
from .elements import ClauseElement
from .elements import ClauseList
+from .elements import ColumnClause
from .elements import GroupedElement
from .elements import Grouping
from .elements import literal_column
from .elements import True_
from .elements import UnaryExpression
+from .visitors import InternalTraversal
from .. import exc
from .. import util
@@ -201,6 +204,8 @@ class Selectable(ReturnsRows):
class HasPrefixes(object):
_prefixes = ()
+ _traverse_internals = [("_prefixes", InternalTraversal.dp_prefix_sequence)]
+
@_generative
@_document_text_coercion(
"expr",
@@ -252,6 +257,8 @@ class HasPrefixes(object):
class HasSuffixes(object):
_suffixes = ()
+ _traverse_internals = [("_suffixes", InternalTraversal.dp_prefix_sequence)]
+
@_generative
@_document_text_coercion(
"expr",
@@ -295,7 +302,7 @@ class HasSuffixes(object):
)
-class FromClause(roles.AnonymizedFromClauseRole, Selectable):
+class FromClause(HasMemoized, roles.AnonymizedFromClauseRole, Selectable):
"""Represent an element that can be used within the ``FROM``
clause of a ``SELECT`` statement.
@@ -529,11 +536,6 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable):
"""
return getattr(self, "name", self.__class__.__name__ + " object")
- def _reset_exported(self):
- """delete memoized collections when a FromClause is cloned."""
-
- self._memoized_property.expire_instance(self)
-
def _generate_fromclause_column_proxies(self, fromclause):
fromclause._columns._populate_separate_keys(
col._make_proxy(fromclause) for col in self.c
@@ -668,6 +670,14 @@ class Join(FromClause):
__visit_name__ = "join"
+ _traverse_internals = [
+ ("left", InternalTraversal.dp_clauseelement),
+ ("right", InternalTraversal.dp_clauseelement),
+ ("onclause", InternalTraversal.dp_clauseelement),
+ ("isouter", InternalTraversal.dp_boolean),
+ ("full", InternalTraversal.dp_boolean),
+ ]
+
_is_join = True
def __init__(self, left, right, onclause=None, isouter=False, full=False):
@@ -805,25 +815,6 @@ class Join(FromClause):
self.left._refresh_for_new_column(column)
self.right._refresh_for_new_column(column)
- def _copy_internals(self, clone=_clone, **kw):
- self._reset_exported()
- self.left = clone(self.left, **kw)
- self.right = clone(self.right, **kw)
- self.onclause = clone(self.onclause, **kw)
-
- def get_children(self, **kwargs):
- return self.left, self.right, self.onclause
-
- def _cache_key(self, **kw):
- return (
- Join,
- self.isouter,
- self.full,
- self.left._cache_key(**kw),
- self.right._cache_key(**kw),
- self.onclause._cache_key(**kw),
- )
-
def _match_primaries(self, left, right):
if isinstance(left, Join):
left_right = left.right
@@ -1175,6 +1166,11 @@ class AliasedReturnsRows(FromClause):
_is_from_container = True
named_with_column = True
+ _traverse_internals = [
+ ("element", InternalTraversal.dp_clauseelement),
+ ("name", InternalTraversal.dp_anon_name),
+ ]
+
def __init__(self, *arg, **kw):
raise NotImplementedError(
"The %s class is not intended to be constructed "
@@ -1243,18 +1239,13 @@ class AliasedReturnsRows(FromClause):
def _copy_internals(self, clone=_clone, **kw):
element = clone(self.element, **kw)
+
+ # the element clone is usually against a Table that returns the
+ # same object. don't reset exported .c. collections and other
+ # memoized details if nothing changed
if element is not self.element:
self._reset_exported()
- self.element = element
-
- def get_children(self, column_collections=True, **kw):
- if column_collections:
- for c in self.c:
- yield c
- yield self.element
-
- def _cache_key(self, **kw):
- return (self.__class__, self.element._cache_key(**kw), self._orig_name)
+ self.element = element
@property
def _from_objects(self):
@@ -1396,6 +1387,11 @@ class TableSample(AliasedReturnsRows):
__visit_name__ = "tablesample"
+ _traverse_internals = AliasedReturnsRows._traverse_internals + [
+ ("sampling", InternalTraversal.dp_clauseelement),
+ ("seed", InternalTraversal.dp_clauseelement),
+ ]
+
@classmethod
def _factory(cls, selectable, sampling, name=None, seed=None):
"""Return a :class:`.TableSample` object.
@@ -1466,6 +1462,16 @@ class CTE(Generative, HasSuffixes, AliasedReturnsRows):
__visit_name__ = "cte"
+ _traverse_internals = (
+ AliasedReturnsRows._traverse_internals
+ + [
+ ("_cte_alias", InternalTraversal.dp_clauseelement),
+ ("_restates", InternalTraversal.dp_clauseelement_unordered_set),
+ ("recursive", InternalTraversal.dp_boolean),
+ ]
+ + HasSuffixes._traverse_internals
+ )
+
@classmethod
def _factory(cls, selectable, name=None, recursive=False):
r"""Return a new :class:`.CTE`, or Common Table Expression instance.
@@ -1495,15 +1501,13 @@ class CTE(Generative, HasSuffixes, AliasedReturnsRows):
def _copy_internals(self, clone=_clone, **kw):
super(CTE, self)._copy_internals(clone, **kw)
+ # TODO: I don't like that we can't use the traversal data here
if self._cte_alias is not None:
self._cte_alias = clone(self._cte_alias, **kw)
self._restates = frozenset(
[clone(elem, **kw) for elem in self._restates]
)
- def _cache_key(self, *arg, **kw):
- raise NotImplementedError("TODO")
-
def alias(self, name=None, flat=False):
"""Return an :class:`.Alias` of this :class:`.CTE`.
@@ -1764,6 +1768,8 @@ class Subquery(AliasedReturnsRows):
class FromGrouping(GroupedElement, FromClause):
"""Represent a grouping of a FROM clause"""
+ _traverse_internals = [("element", InternalTraversal.dp_clauseelement)]
+
def __init__(self, element):
self.element = coercions.expect(roles.FromClauseRole, element)
@@ -1792,15 +1798,6 @@ class FromGrouping(GroupedElement, FromClause):
def _hide_froms(self):
return self.element._hide_froms
- def get_children(self, **kwargs):
- return (self.element,)
-
- def _copy_internals(self, clone=_clone, **kw):
- self.element = clone(self.element, **kw)
-
- def _cache_key(self, **kw):
- return (FromGrouping, self.element._cache_key(**kw))
-
@property
def _from_objects(self):
return self.element._from_objects
@@ -1843,6 +1840,14 @@ class TableClause(Immutable, FromClause):
__visit_name__ = "table"
+ _traverse_internals = [
+ (
+ "columns",
+ InternalTraversal.dp_fromclause_canonical_column_collection,
+ ),
+ ("name", InternalTraversal.dp_string),
+ ]
+
named_with_column = True
implicit_returning = False
@@ -1895,17 +1900,6 @@ class TableClause(Immutable, FromClause):
self._columns.add(c)
c.table = self
- def get_children(self, column_collections=True, **kwargs):
- if column_collections:
- return [c for c in self.c]
- else:
- return []
-
- def _cache_key(self, **kw):
- return (TableClause, self.name) + tuple(
- col._cache_key(**kw) for col in self._columns
- )
-
@util.dependencies("sqlalchemy.sql.dml")
def insert(self, dml, values=None, inline=False, **kwargs):
"""Generate an :func:`.insert` construct against this
@@ -1965,6 +1959,13 @@ class TableClause(Immutable, FromClause):
class ForUpdateArg(ClauseElement):
+ _traverse_internals = [
+ ("of", InternalTraversal.dp_clauseelement_list),
+ ("nowait", InternalTraversal.dp_boolean),
+ ("read", InternalTraversal.dp_boolean),
+ ("skip_locked", InternalTraversal.dp_boolean),
+ ]
+
@classmethod
def parse_legacy_select(self, arg):
"""Parse the for_update argument of :func:`.select`.
@@ -2029,19 +2030,6 @@ class ForUpdateArg(ClauseElement):
def __hash__(self):
return id(self)
- def _copy_internals(self, clone=_clone, **kw):
- if self.of is not None:
- self.of = [clone(col, **kw) for col in self.of]
-
- def _cache_key(self, **kw):
- return (
- ForUpdateArg,
- self.nowait,
- self.read,
- self.skip_locked,
- self.of._cache_key(**kw) if self.of is not None else None,
- )
-
def __init__(
self,
nowait=False,
@@ -2074,6 +2062,7 @@ class SelectBase(
roles.DMLSelectRole,
roles.CompoundElementRole,
roles.InElementRole,
+ HasMemoized,
HasCTE,
Executable,
SupportsCloneAnnotations,
@@ -2092,9 +2081,6 @@ class SelectBase(
_memoized_property = util.group_expirable_memoized_property()
- def _reset_memoizations(self):
- self._memoized_property.expire_instance(self)
-
def _generate_fromclause_column_proxies(self, fromclause):
# type: (FromClause)
raise NotImplementedError()
@@ -2339,6 +2325,7 @@ class SelectStatementGrouping(GroupedElement, SelectBase):
"""
__visit_name__ = "grouping"
+ _traverse_internals = [("element", InternalTraversal.dp_clauseelement)]
_is_select_container = True
@@ -2350,9 +2337,6 @@ class SelectStatementGrouping(GroupedElement, SelectBase):
def select_statement(self):
return self.element
- def get_children(self, **kwargs):
- return (self.element,)
-
def self_group(self, against=None):
# type: (Optional[Any]) -> FromClause
return self
@@ -2377,12 +2361,6 @@ class SelectStatementGrouping(GroupedElement, SelectBase):
"""
return self.element.selected_columns
- def _copy_internals(self, clone=_clone, **kw):
- self.element = clone(self.element, **kw)
-
- def _cache_key(self, **kw):
- return (SelectStatementGrouping, self.element._cache_key(**kw))
-
@property
def _from_objects(self):
return self.element._from_objects
@@ -2758,9 +2736,6 @@ class GenerativeSelect(DeprecatedSelectBaseGenerations, SelectBase):
def _label_resolve_dict(self):
raise NotImplementedError()
- def _copy_internals(self, clone=_clone, **kw):
- raise NotImplementedError()
-
class CompoundSelect(GenerativeSelect):
"""Forms the basis of ``UNION``, ``UNION ALL``, and other
@@ -2785,6 +2760,16 @@ class CompoundSelect(GenerativeSelect):
__visit_name__ = "compound_select"
+ _traverse_internals = [
+ ("selects", InternalTraversal.dp_clauseelement_list),
+ ("_limit_clause", InternalTraversal.dp_clauseelement),
+ ("_offset_clause", InternalTraversal.dp_clauseelement),
+ ("_order_by_clause", InternalTraversal.dp_clauseelement),
+ ("_group_by_clause", InternalTraversal.dp_clauseelement),
+ ("_for_update_arg", InternalTraversal.dp_clauseelement),
+ ("keyword", InternalTraversal.dp_string),
+ ] + SupportsCloneAnnotations._traverse_internals
+
UNION = util.symbol("UNION")
UNION_ALL = util.symbol("UNION ALL")
EXCEPT = util.symbol("EXCEPT")
@@ -3004,47 +2989,6 @@ class CompoundSelect(GenerativeSelect):
"""
return self.selects[0].selected_columns
- def _copy_internals(self, clone=_clone, **kw):
- self._reset_memoizations()
- self.selects = [clone(s, **kw) for s in self.selects]
- if hasattr(self, "_col_map"):
- del self._col_map
- for attr in (
- "_limit_clause",
- "_offset_clause",
- "_order_by_clause",
- "_group_by_clause",
- "_for_update_arg",
- ):
- if getattr(self, attr) is not None:
- setattr(self, attr, clone(getattr(self, attr), **kw))
-
- def get_children(self, **kwargs):
- return [self._order_by_clause, self._group_by_clause] + list(
- self.selects
- )
-
- def _cache_key(self, **kw):
- return (
- (CompoundSelect, self.keyword)
- + tuple(stmt._cache_key(**kw) for stmt in self.selects)
- + (
- self._order_by_clause._cache_key(**kw)
- if self._order_by_clause is not None
- else None,
- )
- + (
- self._group_by_clause._cache_key(**kw)
- if self._group_by_clause is not None
- else None,
- )
- + (
- self._for_update_arg._cache_key(**kw)
- if self._for_update_arg is not None
- else None,
- )
- )
-
def bind(self):
if self._bind:
return self._bind
@@ -3193,11 +3137,35 @@ class Select(
_hints = util.immutabledict()
_statement_hints = ()
_distinct = False
- _from_cloned = None
+ _distinct_on = ()
_correlate = ()
_correlate_except = None
_memoized_property = SelectBase._memoized_property
+ _traverse_internals = (
+ [
+ ("_from_obj", InternalTraversal.dp_fromclause_ordered_set),
+ ("_raw_columns", InternalTraversal.dp_clauseelement_list),
+ ("_whereclause", InternalTraversal.dp_clauseelement),
+ ("_having", InternalTraversal.dp_clauseelement),
+ ("_order_by_clause", InternalTraversal.dp_clauseelement_list),
+ ("_group_by_clause", InternalTraversal.dp_clauseelement_list),
+ ("_correlate", InternalTraversal.dp_clauseelement_unordered_set),
+ (
+ "_correlate_except",
+ InternalTraversal.dp_clauseelement_unordered_set,
+ ),
+ ("_for_update_arg", InternalTraversal.dp_clauseelement),
+ ("_statement_hints", InternalTraversal.dp_statement_hint_list),
+ ("_hints", InternalTraversal.dp_table_hint_list),
+ ("_distinct", InternalTraversal.dp_boolean),
+ ("_distinct_on", InternalTraversal.dp_clauseelement_list),
+ ]
+ + HasPrefixes._traverse_internals
+ + HasSuffixes._traverse_internals
+ + SupportsCloneAnnotations._traverse_internals
+ )
+
@util.deprecated_params(
autocommit=(
"0.6",
@@ -3416,13 +3384,14 @@ class Select(
"""
self._auto_correlate = correlate
if distinct is not False:
- if distinct is True:
- self._distinct = True
- else:
- self._distinct = [
- coercions.expect(roles.WhereHavingRole, e)
- for e in util.to_list(distinct)
- ]
+ self._distinct = True
+ if not isinstance(distinct, bool):
+ self._distinct_on = tuple(
+ [
+ coercions.expect(roles.WhereHavingRole, e)
+ for e in util.to_list(distinct)
+ ]
+ )
if from_obj is not None:
self._from_obj = util.OrderedSet(
@@ -3472,15 +3441,17 @@ class Select(
GenerativeSelect.__init__(self, **kwargs)
+ # @_memoized_property
@property
def _froms(self):
- # would love to cache this,
- # but there's just enough edge cases, particularly now that
- # declarative encourages construction of SQL expressions
- # without tables present, to just regen this each time.
+ # current roadblock to caching is two tests that test that the
+ # SELECT can be compiled to a string, then a Table is created against
+ # columns, then it can be compiled again and works. this is somewhat
+ # valid as people make select() against declarative class where
+ # columns don't have their Table yet and perhaps some operations
+ # call upon _froms and cache it too soon.
froms = []
seen = set()
- translate = self._from_cloned
for item in itertools.chain(
_from_objects(*self._raw_columns),
@@ -3493,8 +3464,6 @@ class Select(
raise exc.InvalidRequestError(
"select() construct refers to itself as a FROM"
)
- if translate and item in translate:
- item = translate[item]
if not seen.intersection(item._cloned_set):
froms.append(item)
seen.update(item._cloned_set)
@@ -3518,15 +3487,6 @@ class Select(
itertools.chain(*[_expand_cloned(f._hide_froms) for f in froms])
)
if toremove:
- # if we're maintaining clones of froms,
- # add the copies out to the toremove list. only include
- # clones that are lexical equivalents.
- if self._from_cloned:
- toremove.update(
- self._from_cloned[f]
- for f in toremove.intersection(self._from_cloned)
- if self._from_cloned[f]._is_lexical_equivalent(f)
- )
# filter out to FROM clauses not in the list,
# using a list to maintain ordering
froms = [f for f in froms if f not in toremove]
@@ -3707,7 +3667,6 @@ class Select(
return False
def _copy_internals(self, clone=_clone, **kw):
-
# Select() object has been cloned and probably adapted by the
# given clone function. Apply the cloning function to internal
# objects
@@ -3719,37 +3678,42 @@ class Select(
# as of 0.7.4 we also put the current version of _froms, which
# gets cleared on each generation. previously we were "baking"
# _froms into self._from_obj.
- self._from_cloned = from_cloned = dict(
- (f, clone(f, **kw)) for f in self._from_obj.union(self._froms)
- )
- # 3. update persistent _from_obj with the cloned versions.
- self._from_obj = util.OrderedSet(
- from_cloned[f] for f in self._from_obj
+ all_the_froms = list(
+ itertools.chain(
+ _from_objects(*self._raw_columns),
+ _from_objects(self._whereclause)
+ if self._whereclause is not None
+ else (),
+ )
)
+ new_froms = {f: clone(f, **kw) for f in all_the_froms}
+ # copy FROM collections
- # the _correlate collection is done separately, what can happen
- # here is the same item is _correlate as in _from_obj but the
- # _correlate version has an annotation on it - (specifically
- # RelationshipProperty.Comparator._criterion_exists() does
- # this). Also keep _correlate liberally open with its previous
- # contents, as this set is used for matching, not rendering.
- self._correlate = set(clone(f) for f in self._correlate).union(
- self._correlate
- )
+ self._from_obj = util.OrderedSet(
+ clone(f, **kw) for f in self._from_obj
+ ).union(f for f in new_froms.values() if isinstance(f, Join))
- # do something similar for _correlate_except - this is a more
- # unusual case but same idea applies
+ self._correlate = set(clone(f) for f in self._correlate)
if self._correlate_except:
self._correlate_except = set(
clone(f) for f in self._correlate_except
- ).union(self._correlate_except)
+ )
# 4. clone other things. The difficulty here is that Column
- # objects are not actually cloned, and refer to their original
- # .table, resulting in the wrong "from" parent after a clone
- # operation. Hence _from_cloned and _from_obj supersede what is
- # present here.
+ # objects are usually not altered by a straight clone because they
+ # are dependent on the FROM cloning we just did above in order to
+ # be targeted correctly, or a new FROM we have might be a JOIN
+ # object which doesn't have its own columns. so give the cloner a
+ # hint.
+ def replace(obj, **kw):
+ if isinstance(obj, ColumnClause) and obj.table in new_froms:
+ newelem = new_froms[obj.table].corresponding_column(obj)
+ return newelem
+
+ kw["replace"] = replace
+
+ # TODO: I'd still like to try to leverage the traversal data
self._raw_columns = [clone(c, **kw) for c in self._raw_columns]
for attr in (
"_limit_clause",
@@ -3763,67 +3727,12 @@ class Select(
if getattr(self, attr) is not None:
setattr(self, attr, clone(getattr(self, attr), **kw))
- # erase _froms collection,
- # etc.
self._reset_memoizations()
def get_children(self, **kwargs):
- """return child elements as per the ClauseElement specification."""
-
- return (
- self._raw_columns
- + list(self._froms)
- + [
- x
- for x in (
- self._whereclause,
- self._having,
- self._order_by_clause,
- self._group_by_clause,
- )
- if x is not None
- ]
- )
-
- def _cache_key(self, **kw):
- return (
- (Select,)
- + ("raw_columns",)
- + tuple(elem._cache_key(**kw) for elem in self._raw_columns)
- + ("elements",)
- + tuple(
- elem._cache_key(**kw) if elem is not None else None
- for elem in (
- self._whereclause,
- self._having,
- self._order_by_clause,
- self._group_by_clause,
- )
- )
- + ("from_obj",)
- + tuple(elem._cache_key(**kw) for elem in self._from_obj)
- + ("correlate",)
- + tuple(
- elem._cache_key(**kw)
- for elem in (
- self._correlate if self._correlate is not None else ()
- )
- )
- + ("correlate_except",)
- + tuple(
- elem._cache_key(**kw)
- for elem in (
- self._correlate_except
- if self._correlate_except is not None
- else ()
- )
- )
- + ("for_update",),
- (
- self._for_update_arg._cache_key(**kw)
- if self._for_update_arg is not None
- else None,
- ),
+ # TODO: define "get_children" traversal items separately?
+ return self._froms + super(Select, self).get_children(
+ omit_attrs=["_from_obj", "_correlate", "_correlate_except"]
)
@_generative
@@ -3987,10 +3896,8 @@ class Select(
"""
if expr:
expr = [coercions.expect(roles.ByOfRole, e) for e in expr]
- if isinstance(self._distinct, list):
- self._distinct = self._distinct + expr
- else:
- self._distinct = expr
+ self._distinct = True
+ self._distinct_on = self._distinct_on + tuple(expr)
else:
self._distinct = True
@@ -4489,6 +4396,11 @@ class TextualSelect(SelectBase):
__visit_name__ = "textual_select"
+ _traverse_internals = [
+ ("element", InternalTraversal.dp_clauseelement),
+ ("column_args", InternalTraversal.dp_clauseelement_list),
+ ] + SupportsCloneAnnotations._traverse_internals
+
_is_textual = True
def __init__(self, text, columns, positional=False):
@@ -4534,18 +4446,6 @@ class TextualSelect(SelectBase):
c._make_proxy(fromclause) for c in self.column_args
)
- def _copy_internals(self, clone=_clone, **kw):
- self._reset_memoizations()
- self.element = clone(self.element, **kw)
-
- def get_children(self, **kw):
- return [self.element]
-
- def _cache_key(self, **kw):
- return (TextualSelect, self.element._cache_key(**kw)) + tuple(
- col._cache_key(**kw) for col in self.column_args
- )
-
def _scalar_type(self):
return self.column_args[0].type
diff --git a/lib/sqlalchemy/sql/traversals.py b/lib/sqlalchemy/sql/traversals.py
new file mode 100644
index 000000000..c0782ce48
--- /dev/null
+++ b/lib/sqlalchemy/sql/traversals.py
@@ -0,0 +1,768 @@
+from collections import deque
+from collections import namedtuple
+
+from . import operators
+from .visitors import ExtendedInternalTraversal
+from .visitors import InternalTraversal
+from .. import inspect
+from .. import util
+
+SKIP_TRAVERSE = util.symbol("skip_traverse")
+COMPARE_FAILED = False
+COMPARE_SUCCEEDED = True
+NO_CACHE = util.symbol("no_cache")
+
+
+def compare(obj1, obj2, **kw):
+ if kw.get("use_proxies", False):
+ strategy = ColIdentityComparatorStrategy()
+ else:
+ strategy = TraversalComparatorStrategy()
+
+ return strategy.compare(obj1, obj2, **kw)
+
+
+class HasCacheKey(object):
+ _cache_key_traversal = NO_CACHE
+
+ def _gen_cache_key(self, anon_map, bindparams):
+ """return an optional cache key.
+
+ The cache key is a tuple which can contain any series of
+ objects that are hashable and also identifies
+ this object uniquely within the presence of a larger SQL expression
+ or statement, for the purposes of caching the resulting query.
+
+ The cache key should be based on the SQL compiled structure that would
+ ultimately be produced. That is, two structures that are composed in
+ exactly the same way should produce the same cache key; any difference
+ in the strucures that would affect the SQL string or the type handlers
+ should result in a different cache key.
+
+ If a structure cannot produce a useful cache key, it should raise
+ NotImplementedError, which will result in the entire structure
+ for which it's part of not being useful as a cache key.
+
+
+ """
+
+ if self in anon_map:
+ return (anon_map[self], self.__class__)
+
+ id_ = anon_map[self]
+
+ if self._cache_key_traversal is NO_CACHE:
+ anon_map[NO_CACHE] = True
+ return None
+
+ result = (id_, self.__class__)
+
+ for attrname, obj, meth in _cache_key_traversal.run_generated_dispatch(
+ self, self._cache_key_traversal, "_generated_cache_key_traversal"
+ ):
+ if obj is not None:
+ result += meth(attrname, obj, self, anon_map, bindparams)
+ return result
+
+ def _generate_cache_key(self):
+ """return a cache key.
+
+ The cache key is a tuple which can contain any series of
+ objects that are hashable and also identifies
+ this object uniquely within the presence of a larger SQL expression
+ or statement, for the purposes of caching the resulting query.
+
+ The cache key should be based on the SQL compiled structure that would
+ ultimately be produced. That is, two structures that are composed in
+ exactly the same way should produce the same cache key; any difference
+ in the strucures that would affect the SQL string or the type handlers
+ should result in a different cache key.
+
+ The cache key returned by this method is an instance of
+ :class:`.CacheKey`, which consists of a tuple representing the
+ cache key, as well as a list of :class:`.BindParameter` objects
+ which are extracted from the expression. While two expressions
+ that produce identical cache key tuples will themselves generate
+ identical SQL strings, the list of :class:`.BindParameter` objects
+ indicates the bound values which may have different values in
+ each one; these bound parameters must be consulted in order to
+ execute the statement with the correct parameters.
+
+ a :class:`.ClauseElement` structure that does not implement
+ a :meth:`._gen_cache_key` method and does not implement a
+ :attr:`.traverse_internals` attribute will not be cacheable; when
+ such an element is embedded into a larger structure, this method
+ will return None, indicating no cache key is available.
+
+ """
+ bindparams = []
+
+ _anon_map = anon_map()
+ key = self._gen_cache_key(_anon_map, bindparams)
+ if NO_CACHE in _anon_map:
+ return None
+ else:
+ return CacheKey(key, bindparams)
+
+
+class CacheKey(namedtuple("CacheKey", ["key", "bindparams"])):
+ def __hash__(self):
+ return hash(self.key)
+
+ def __eq__(self, other):
+ return self.key == other.key
+
+
+def _clone(element, **kw):
+ return element._clone()
+
+
+class _CacheKey(ExtendedInternalTraversal):
+ def visit_has_cache_key(self, attrname, obj, parent, anon_map, bindparams):
+ return (attrname, obj._gen_cache_key(anon_map, bindparams))
+
+ def visit_inspectable(self, attrname, obj, parent, anon_map, bindparams):
+ return self.visit_has_cache_key(
+ attrname, inspect(obj), parent, anon_map, bindparams
+ )
+
+ def visit_clauseelement(self, attrname, obj, parent, anon_map, bindparams):
+ return (attrname, obj._gen_cache_key(anon_map, bindparams))
+
+ def visit_multi(self, attrname, obj, parent, anon_map, bindparams):
+ return (
+ attrname,
+ obj._gen_cache_key(anon_map, bindparams)
+ if isinstance(obj, HasCacheKey)
+ else obj,
+ )
+
+ def visit_multi_list(self, attrname, obj, parent, anon_map, bindparams):
+ return (
+ attrname,
+ tuple(
+ elem._gen_cache_key(anon_map, bindparams)
+ if isinstance(elem, HasCacheKey)
+ else elem
+ for elem in obj
+ ),
+ )
+
+ def visit_has_cache_key_tuples(
+ self, attrname, obj, parent, anon_map, bindparams
+ ):
+ return (
+ attrname,
+ tuple(
+ tuple(
+ elem._gen_cache_key(anon_map, bindparams)
+ for elem in tup_elem
+ )
+ for tup_elem in obj
+ ),
+ )
+
+ def visit_has_cache_key_list(
+ self, attrname, obj, parent, anon_map, bindparams
+ ):
+ return (
+ attrname,
+ tuple(elem._gen_cache_key(anon_map, bindparams) for elem in obj),
+ )
+
+ def visit_inspectable_list(
+ self, attrname, obj, parent, anon_map, bindparams
+ ):
+ return self.visit_has_cache_key_list(
+ attrname, [inspect(o) for o in obj], parent, anon_map, bindparams
+ )
+
+ def visit_clauseelement_list(
+ self, attrname, obj, parent, anon_map, bindparams
+ ):
+ return (
+ attrname,
+ tuple(elem._gen_cache_key(anon_map, bindparams) for elem in obj),
+ )
+
+ def visit_clauseelement_tuples(
+ self, attrname, obj, parent, anon_map, bindparams
+ ):
+ return self.visit_has_cache_key_tuples(
+ attrname, obj, parent, anon_map, bindparams
+ )
+
+ def visit_anon_name(self, attrname, obj, parent, anon_map, bindparams):
+ from . import elements
+
+ name = obj
+ if isinstance(name, elements._anonymous_label):
+ name = name.apply_map(anon_map)
+
+ return (attrname, name)
+
+ def visit_fromclause_ordered_set(
+ self, attrname, obj, parent, anon_map, bindparams
+ ):
+ return (
+ attrname,
+ tuple(elem._gen_cache_key(anon_map, bindparams) for elem in obj),
+ )
+
+ def visit_clauseelement_unordered_set(
+ self, attrname, obj, parent, anon_map, bindparams
+ ):
+ cache_keys = [
+ elem._gen_cache_key(anon_map, bindparams) for elem in obj
+ ]
+ return (
+ attrname,
+ tuple(
+ sorted(cache_keys)
+ ), # cache keys all start with (id_, class)
+ )
+
+ def visit_named_ddl_element(
+ self, attrname, obj, parent, anon_map, bindparams
+ ):
+ return (attrname, obj.name)
+
+ def visit_prefix_sequence(
+ self, attrname, obj, parent, anon_map, bindparams
+ ):
+ return (
+ attrname,
+ tuple(
+ (clause._gen_cache_key(anon_map, bindparams), strval)
+ for clause, strval in obj
+ ),
+ )
+
+ def visit_statement_hint_list(
+ self, attrname, obj, parent, anon_map, bindparams
+ ):
+ return (attrname, obj)
+
+ def visit_table_hint_list(
+ self, attrname, obj, parent, anon_map, bindparams
+ ):
+ return (
+ attrname,
+ tuple(
+ (
+ clause._gen_cache_key(anon_map, bindparams),
+ dialect_name,
+ text,
+ )
+ for (clause, dialect_name), text in obj.items()
+ ),
+ )
+
+ def visit_type(self, attrname, obj, parent, anon_map, bindparams):
+ return (attrname, obj._gen_cache_key)
+
+ def visit_plain_dict(self, attrname, obj, parent, anon_map, bindparams):
+ return (attrname, tuple((key, obj[key]) for key in sorted(obj)))
+
+ def visit_string_clauseelement_dict(
+ self, attrname, obj, parent, anon_map, bindparams
+ ):
+ return (
+ attrname,
+ tuple(
+ (key, obj[key]._gen_cache_key(anon_map, bindparams))
+ for key in sorted(obj)
+ ),
+ )
+
+ def visit_string_multi_dict(
+ self, attrname, obj, parent, anon_map, bindparams
+ ):
+ return (
+ attrname,
+ tuple(
+ (
+ key,
+ value._gen_cache_key(anon_map, bindparams)
+ if isinstance(value, HasCacheKey)
+ else value,
+ )
+ for key, value in [(key, obj[key]) for key in sorted(obj)]
+ ),
+ )
+
+ def visit_string(self, attrname, obj, parent, anon_map, bindparams):
+ return (attrname, obj)
+
+ def visit_boolean(self, attrname, obj, parent, anon_map, bindparams):
+ return (attrname, obj)
+
+ def visit_operator(self, attrname, obj, parent, anon_map, bindparams):
+ return (attrname, obj)
+
+ def visit_plain_obj(self, attrname, obj, parent, anon_map, bindparams):
+ return (attrname, obj)
+
+ def visit_fromclause_canonical_column_collection(
+ self, attrname, obj, parent, anon_map, bindparams
+ ):
+ return (
+ attrname,
+ tuple(col._gen_cache_key(anon_map, bindparams) for col in obj),
+ )
+
+ def visit_annotations_state(
+ self, attrname, obj, parent, anon_map, bindparams
+ ):
+ return (
+ attrname,
+ tuple(
+ (
+ key,
+ self.dispatch(sym)(
+ key, obj[key], obj, anon_map, bindparams
+ ),
+ )
+ for key, sym in parent._annotation_traversals
+ ),
+ )
+
+ def visit_unknown_structure(
+ self, attrname, obj, parent, anon_map, bindparams
+ ):
+ anon_map[NO_CACHE] = True
+ return ()
+
+
+_cache_key_traversal = _CacheKey()
+
+
+class _CopyInternals(InternalTraversal):
+ """Generate a _copy_internals internal traversal dispatch for classes
+ with a _traverse_internals collection."""
+
+ def visit_clauseelement(self, parent, element, clone=_clone, **kw):
+ return clone(element, **kw)
+
+ def visit_clauseelement_list(self, parent, element, clone=_clone, **kw):
+ return [clone(clause, **kw) for clause in element]
+
+ def visit_clauseelement_tuples(self, parent, element, clone=_clone, **kw):
+ return [
+ tuple(clone(tup_elem, **kw) for tup_elem in elem)
+ for elem in element
+ ]
+
+ def visit_string_clauseelement_dict(
+ self, parent, element, clone=_clone, **kw
+ ):
+ return dict(
+ (key, clone(value, **kw)) for key, value in element.items()
+ )
+
+
+_copy_internals = _CopyInternals()
+
+
+class _GetChildren(InternalTraversal):
+ """Generate a _children_traversal internal traversal dispatch for classes
+ with a _traverse_internals collection."""
+
+ def visit_has_cache_key(self, element, **kw):
+ return (element,)
+
+ def visit_clauseelement(self, element, **kw):
+ return (element,)
+
+ def visit_clauseelement_list(self, element, **kw):
+ return tuple(element)
+
+ def visit_clauseelement_tuples(self, element, **kw):
+ tup = ()
+ for elem in element:
+ tup += elem
+ return tup
+
+ def visit_fromclause_canonical_column_collection(self, element, **kw):
+ if kw.get("column_collections", False):
+ return tuple(element)
+ else:
+ return ()
+
+ def visit_string_clauseelement_dict(self, element, **kw):
+ return tuple(element.values())
+
+ def visit_fromclause_ordered_set(self, element, **kw):
+ return tuple(element)
+
+ def visit_clauseelement_unordered_set(self, element, **kw):
+ return tuple(element)
+
+
+_get_children = _GetChildren()
+
+
+@util.dependencies("sqlalchemy.sql.elements")
+def _resolve_name_for_compare(elements, element, name, anon_map, **kw):
+ if isinstance(name, elements._anonymous_label):
+ name = name.apply_map(anon_map)
+
+ return name
+
+
+class anon_map(dict):
+ """A map that creates new keys for missing key access.
+
+ Produces an incrementing sequence given a series of unique keys.
+
+ This is similar to the compiler prefix_anon_map class although simpler.
+
+ Inlines the approach taken by :class:`sqlalchemy.util.PopulateDict` which
+ is otherwise usually used for this type of operation.
+
+ """
+
+ def __init__(self):
+ self.index = 0
+
+ def __missing__(self, key):
+ self[key] = val = str(self.index)
+ self.index += 1
+ return val
+
+
+class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots):
+ __slots__ = "stack", "cache", "anon_map"
+
+ def __init__(self):
+ self.stack = deque()
+ self.cache = set()
+
+ def _memoized_attr_anon_map(self):
+ return (anon_map(), anon_map())
+
+ def compare(self, obj1, obj2, **kw):
+ stack = self.stack
+ cache = self.cache
+
+ compare_annotations = kw.get("compare_annotations", False)
+
+ stack.append((obj1, obj2))
+
+ while stack:
+ left, right = stack.popleft()
+
+ if left is right:
+ continue
+ elif left is None or right is None:
+ # we know they are different so no match
+ return False
+ elif (left, right) in cache:
+ continue
+ cache.add((left, right))
+
+ visit_name = left.__visit_name__
+ if visit_name != right.__visit_name__:
+ return False
+
+ meth = getattr(self, "compare_%s" % visit_name, None)
+
+ if meth:
+ attributes_compared = meth(left, right, **kw)
+ if attributes_compared is COMPARE_FAILED:
+ return False
+ elif attributes_compared is SKIP_TRAVERSE:
+ continue
+
+ # attributes_compared is returned as a list of attribute
+ # names that were "handled" by the comparison method above.
+ # remaining attribute names in the _traverse_internals
+ # will be compared.
+ else:
+ attributes_compared = ()
+
+ for (
+ (left_attrname, left_visit_sym),
+ (right_attrname, right_visit_sym),
+ ) in util.zip_longest(
+ left._traverse_internals,
+ right._traverse_internals,
+ fillvalue=(None, None),
+ ):
+ if (
+ left_attrname != right_attrname
+ or left_visit_sym is not right_visit_sym
+ ):
+ if not compare_annotations and (
+ (
+ left_visit_sym
+ is InternalTraversal.dp_annotations_state,
+ )
+ or (
+ right_visit_sym
+ is InternalTraversal.dp_annotations_state,
+ )
+ ):
+ continue
+
+ return False
+ elif left_attrname in attributes_compared:
+ continue
+
+ dispatch = self.dispatch(left_visit_sym)
+ left_child = getattr(left, left_attrname)
+ right_child = getattr(right, right_attrname)
+ if left_child is None:
+ if right_child is not None:
+ return False
+ else:
+ continue
+
+ comparison = dispatch(
+ left, left_child, right, right_child, **kw
+ )
+ if comparison is COMPARE_FAILED:
+ return False
+
+ return True
+
+ def compare_inner(self, obj1, obj2, **kw):
+ comparator = self.__class__()
+ return comparator.compare(obj1, obj2, **kw)
+
+ def visit_has_cache_key(
+ self, left_parent, left, right_parent, right, **kw
+ ):
+ if left._gen_cache_key(self.anon_map[0], []) != right._gen_cache_key(
+ self.anon_map[1], []
+ ):
+ return COMPARE_FAILED
+
+ def visit_clauseelement(
+ self, left_parent, left, right_parent, right, **kw
+ ):
+ self.stack.append((left, right))
+
+ def visit_fromclause_canonical_column_collection(
+ self, left_parent, left, right_parent, right, **kw
+ ):
+ for lcol, rcol in util.zip_longest(left, right, fillvalue=None):
+ self.stack.append((lcol, rcol))
+
+ def visit_fromclause_derived_column_collection(
+ self, left_parent, left, right_parent, right, **kw
+ ):
+ pass
+
+ def visit_string_clauseelement_dict(
+ self, left_parent, left, right_parent, right, **kw
+ ):
+ for lstr, rstr in util.zip_longest(
+ sorted(left), sorted(right), fillvalue=None
+ ):
+ if lstr != rstr:
+ return COMPARE_FAILED
+ self.stack.append((left[lstr], right[rstr]))
+
+ def visit_annotations_state(
+ self, left_parent, left, right_parent, right, **kw
+ ):
+ if not kw.get("compare_annotations", False):
+ return
+
+ for (lstr, lmeth), (rstr, rmeth) in util.zip_longest(
+ left_parent._annotation_traversals,
+ right_parent._annotation_traversals,
+ fillvalue=(None, None),
+ ):
+ if lstr != rstr or (lmeth is not rmeth):
+ return COMPARE_FAILED
+
+ dispatch = self.dispatch(lmeth)
+ left_child = left[lstr]
+ right_child = right[rstr]
+ if left_child is None:
+ if right_child is not None:
+ return False
+ else:
+ continue
+
+ comparison = dispatch(None, left_child, None, right_child, **kw)
+ if comparison is COMPARE_FAILED:
+ return comparison
+
+ def visit_clauseelement_tuples(
+ self, left_parent, left, right_parent, right, **kw
+ ):
+ for ltup, rtup in util.zip_longest(left, right, fillvalue=None):
+ if ltup is None or rtup is None:
+ return COMPARE_FAILED
+
+ for l, r in util.zip_longest(ltup, rtup, fillvalue=None):
+ self.stack.append((l, r))
+
+ def visit_clauseelement_list(
+ self, left_parent, left, right_parent, right, **kw
+ ):
+ for l, r in util.zip_longest(left, right, fillvalue=None):
+ self.stack.append((l, r))
+
+ def _compare_unordered_sequences(self, seq1, seq2, **kw):
+ if seq1 is None:
+ return seq2 is None
+
+ completed = set()
+ for clause in seq1:
+ for other_clause in set(seq2).difference(completed):
+ if self.compare_inner(clause, other_clause, **kw):
+ completed.add(other_clause)
+ break
+ return len(completed) == len(seq1) == len(seq2)
+
+ def visit_clauseelement_unordered_set(
+ self, left_parent, left, right_parent, right, **kw
+ ):
+ return self._compare_unordered_sequences(left, right, **kw)
+
+ def visit_fromclause_ordered_set(
+ self, left_parent, left, right_parent, right, **kw
+ ):
+ for l, r in util.zip_longest(left, right, fillvalue=None):
+ self.stack.append((l, r))
+
+ def visit_string(self, left_parent, left, right_parent, right, **kw):
+ return left == right
+
+ def visit_anon_name(self, left_parent, left, right_parent, right, **kw):
+ return _resolve_name_for_compare(
+ left_parent, left, self.anon_map[0], **kw
+ ) == _resolve_name_for_compare(
+ right_parent, right, self.anon_map[1], **kw
+ )
+
+ def visit_boolean(self, left_parent, left, right_parent, right, **kw):
+ return left == right
+
+ def visit_operator(self, left_parent, left, right_parent, right, **kw):
+ return left is right
+
+ def visit_type(self, left_parent, left, right_parent, right, **kw):
+ return left._compare_type_affinity(right)
+
+ def visit_plain_dict(self, left_parent, left, right_parent, right, **kw):
+ return left == right
+
+ def visit_plain_obj(self, left_parent, left, right_parent, right, **kw):
+ return left == right
+
+ def visit_named_ddl_element(
+ self, left_parent, left, right_parent, right, **kw
+ ):
+ if left is None:
+ if right is not None:
+ return COMPARE_FAILED
+
+ return left.name == right.name
+
+ def visit_prefix_sequence(
+ self, left_parent, left, right_parent, right, **kw
+ ):
+ for (l_clause, l_str), (r_clause, r_str) in util.zip_longest(
+ left, right, fillvalue=(None, None)
+ ):
+ if l_str != r_str:
+ return COMPARE_FAILED
+ else:
+ self.stack.append((l_clause, r_clause))
+
+ def visit_table_hint_list(
+ self, left_parent, left, right_parent, right, **kw
+ ):
+ left_keys = sorted(left, key=lambda elem: (elem[0].fullname, elem[1]))
+ right_keys = sorted(
+ right, key=lambda elem: (elem[0].fullname, elem[1])
+ )
+ for (ltable, ldialect), (rtable, rdialect) in util.zip_longest(
+ left_keys, right_keys, fillvalue=(None, None)
+ ):
+ if ldialect != rdialect:
+ return COMPARE_FAILED
+ elif left[(ltable, ldialect)] != right[(rtable, rdialect)]:
+ return COMPARE_FAILED
+ else:
+ self.stack.append((ltable, rtable))
+
+ def visit_statement_hint_list(
+ self, left_parent, left, right_parent, right, **kw
+ ):
+ return left == right
+
+ def visit_unknown_structure(
+ self, left_parent, left, right_parent, right, **kw
+ ):
+ raise NotImplementedError()
+
+ def compare_clauselist(self, left, right, **kw):
+ if left.operator is right.operator:
+ if operators.is_associative(left.operator):
+ if self._compare_unordered_sequences(
+ left.clauses, right.clauses, **kw
+ ):
+ return ["operator", "clauses"]
+ else:
+ return COMPARE_FAILED
+ else:
+ return ["operator"]
+ else:
+ return COMPARE_FAILED
+
+ def compare_binary(self, left, right, **kw):
+ if left.operator == right.operator:
+ if operators.is_commutative(left.operator):
+ if (
+ compare(left.left, right.left, **kw)
+ and compare(left.right, right.right, **kw)
+ ) or (
+ compare(left.left, right.right, **kw)
+ and compare(left.right, right.left, **kw)
+ ):
+ return ["operator", "negate", "left", "right"]
+ else:
+ return COMPARE_FAILED
+ else:
+ return ["operator", "negate"]
+ else:
+ return COMPARE_FAILED
+
+
+class ColIdentityComparatorStrategy(TraversalComparatorStrategy):
+ def compare_column_element(
+ self, left, right, use_proxies=True, equivalents=(), **kw
+ ):
+ """Compare ColumnElements using proxies and equivalent collections.
+
+ This is a comparison strategy specific to the ORM.
+ """
+
+ to_compare = (right,)
+ if equivalents and right in equivalents:
+ to_compare = equivalents[right].union(to_compare)
+
+ for oth in to_compare:
+ if use_proxies and left.shares_lineage(oth):
+ return SKIP_TRAVERSE
+ elif hash(left) == hash(right):
+ return SKIP_TRAVERSE
+ else:
+ return COMPARE_FAILED
+
+ def compare_column(self, left, right, **kw):
+ return self.compare_column_element(left, right, **kw)
+
+ def compare_label(self, left, right, **kw):
+ return self.compare_column_element(left, right, **kw)
+
+ def compare_table(self, left, right, **kw):
+ # tables compare on identity, since it's not really feasible to
+ # compare them column by column with the above rules
+ return SKIP_TRAVERSE if left is right else COMPARE_FAILED
diff --git a/lib/sqlalchemy/sql/type_api.py b/lib/sqlalchemy/sql/type_api.py
index 9c5f5dd47..d09bb28bb 100644
--- a/lib/sqlalchemy/sql/type_api.py
+++ b/lib/sqlalchemy/sql/type_api.py
@@ -12,8 +12,8 @@
from . import operators
from .base import SchemaEventTarget
-from .visitors import Visitable
-from .visitors import VisitableType
+from .visitors import Traversible
+from .visitors import TraversibleType
from .. import exc
from .. import util
@@ -28,7 +28,7 @@ INDEXABLE = None
_resolve_value_to_type = None
-class TypeEngine(Visitable):
+class TypeEngine(Traversible):
"""The ultimate base class for all SQL datatypes.
Common subclasses of :class:`.TypeEngine` include
@@ -535,8 +535,13 @@ class TypeEngine(Visitable):
return dialect.type_descriptor(self)
@util.memoized_property
- def _cache_key(self):
- return util.constructor_key(self, self.__class__)
+ def _gen_cache_key(self):
+ names = util.get_cls_kwargs(self.__class__)
+ return (self.__class__,) + tuple(
+ (k, self.__dict__[k])
+ for k in names
+ if k in self.__dict__ and not k.startswith("_")
+ )
def adapt(self, cls, **kw):
"""Produce an "adapted" form of this type, given an "impl" class
@@ -617,7 +622,7 @@ class TypeEngine(Visitable):
return util.generic_repr(self)
-class VisitableCheckKWArg(util.EnsureKWArgType, VisitableType):
+class VisitableCheckKWArg(util.EnsureKWArgType, TraversibleType):
pass
diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py
index e109852a2..8539f4845 100644
--- a/lib/sqlalchemy/sql/util.py
+++ b/lib/sqlalchemy/sql/util.py
@@ -734,7 +734,7 @@ def criterion_as_pairs(
return pairs
-class ClauseAdapter(visitors.ReplacingCloningVisitor):
+class ClauseAdapter(visitors.ReplacingExternalTraversal):
"""Clones and modifies clauses based on column correspondence.
E.g.::
diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py
index 7b2ac285a..8c06eb8af 100644
--- a/lib/sqlalchemy/sql/visitors.py
+++ b/lib/sqlalchemy/sql/visitors.py
@@ -28,14 +28,10 @@ import operator
from .. import exc
from .. import util
-
+from ..util import langhelpers
+from ..util import symbol
__all__ = [
- "VisitableType",
- "Visitable",
- "ClauseVisitor",
- "CloningVisitor",
- "ReplacingCloningVisitor",
"iterate",
"iterate_depthfirst",
"traverse_using",
@@ -43,85 +39,382 @@ __all__ = [
"traverse_depthfirst",
"cloned_traverse",
"replacement_traverse",
+ "Traversible",
+ "TraversibleType",
+ "ExternalTraversal",
+ "InternalTraversal",
]
-class VisitableType(type):
- """Metaclass which assigns a ``_compiler_dispatch`` method to classes
- having a ``__visit_name__`` attribute.
+def _generate_compiler_dispatch(cls):
+ """Generate a _compiler_dispatch() external traversal on classes with a
+ __visit_name__ attribute.
+
+ """
+ visit_name = cls.__visit_name__
+
+ if isinstance(visit_name, util.compat.string_types):
+ # There is an optimization opportunity here because the
+ # the string name of the class's __visit_name__ is known at
+ # this early stage (import time) so it can be pre-constructed.
+ getter = operator.attrgetter("visit_%s" % visit_name)
+
+ def _compiler_dispatch(self, visitor, **kw):
+ try:
+ meth = getter(visitor)
+ except AttributeError:
+ raise exc.UnsupportedCompilationError(visitor, cls)
+ else:
+ return meth(self, **kw)
+
+ else:
+ # The optimization opportunity is lost for this case because the
+ # __visit_name__ is not yet a string. As a result, the visit
+ # string has to be recalculated with each compilation.
+ def _compiler_dispatch(self, visitor, **kw):
+ visit_attr = "visit_%s" % self.__visit_name__
+ try:
+ meth = getattr(visitor, visit_attr)
+ except AttributeError:
+ raise exc.UnsupportedCompilationError(visitor, cls)
+ else:
+ return meth(self, **kw)
+
+ _compiler_dispatch.__doc__ = """Look for an attribute named "visit_"
+ + self.__visit_name__ on the visitor, and call it with the same
+ kw params.
+ """
+ cls._compiler_dispatch = _compiler_dispatch
+
+
+class TraversibleType(type):
+ """Metaclass which assigns dispatch attributes to various kinds of
+ "visitable" classes.
- The ``_compiler_dispatch`` attribute becomes an instance method which
- looks approximately like the following::
+ Attributes include:
- def _compiler_dispatch (self, visitor, **kw):
- '''Look for an attribute named "visit_" + self.__visit_name__
- on the visitor, and call it with the same kw params.'''
- visit_attr = 'visit_%s' % self.__visit_name__
- return getattr(visitor, visit_attr)(self, **kw)
+ * The ``_compiler_dispatch`` method, corresponding to ``__visit_name__``.
+ This is called "external traversal" because the caller of each visit()
+ method is responsible for sub-traversing the inner elements of each
+ object. This is appropriate for string compilers and other traversals
+ that need to call upon the inner elements in a specific pattern.
- Classes having no ``__visit_name__`` attribute will remain unaffected.
+ * internal traversal collections ``_children_traversal``,
+ ``_cache_key_traversal``, ``_copy_internals_traversal``, generated from
+ an optional ``_traverse_internals`` collection of symbols which comes
+ from the :class:`.InternalTraversal` list of symbols. This is called
+ "internal traversal" MARKMARK
"""
def __init__(cls, clsname, bases, clsdict):
- if clsname != "Visitable" and hasattr(cls, "__visit_name__"):
- _generate_dispatch(cls)
+ if clsname != "Traversible":
+ if "__visit_name__" in clsdict:
+ _generate_compiler_dispatch(cls)
+
+ super(TraversibleType, cls).__init__(clsname, bases, clsdict)
- super(VisitableType, cls).__init__(clsname, bases, clsdict)
+class Traversible(util.with_metaclass(TraversibleType)):
+ """Base class for visitable objects, applies the
+ :class:`.visitors.TraversibleType` metaclass.
-def _generate_dispatch(cls):
- """Return an optimized visit dispatch function for the cls
- for use by the compiler.
"""
- if "__visit_name__" in cls.__dict__:
- visit_name = cls.__visit_name__
- if isinstance(visit_name, util.compat.string_types):
- # There is an optimization opportunity here because the
- # the string name of the class's __visit_name__ is known at
- # this early stage (import time) so it can be pre-constructed.
- getter = operator.attrgetter("visit_%s" % visit_name)
- def _compiler_dispatch(self, visitor, **kw):
- try:
- meth = getter(visitor)
- except AttributeError:
- raise exc.UnsupportedCompilationError(visitor, cls)
- else:
- return meth(self, **kw)
+class _InternalTraversalType(type):
+ def __init__(cls, clsname, bases, clsdict):
+ if cls.__name__ in ("InternalTraversal", "ExtendedInternalTraversal"):
+ lookup = {}
+ for key, sym in clsdict.items():
+ if key.startswith("dp_"):
+ visit_key = key.replace("dp_", "visit_")
+ sym_name = sym.name
+ assert sym_name not in lookup, sym_name
+ lookup[sym] = lookup[sym_name] = visit_key
+ if hasattr(cls, "_dispatch_lookup"):
+ lookup.update(cls._dispatch_lookup)
+ cls._dispatch_lookup = lookup
+
+ super(_InternalTraversalType, cls).__init__(clsname, bases, clsdict)
+
+
+def _generate_dispatcher(visitor, internal_dispatch, method_name):
+ names = []
+ for attrname, visit_sym in internal_dispatch:
+ meth = visitor.dispatch(visit_sym)
+ if meth:
+ visit_name = ExtendedInternalTraversal._dispatch_lookup[visit_sym]
+ names.append((attrname, visit_name))
+
+ code = (
+ (" return [\n")
+ + (
+ ", \n".join(
+ " (%r, self.%s, visitor.%s)"
+ % (attrname, attrname, visit_name)
+ for attrname, visit_name in names
+ )
+ )
+ + ("\n ]\n")
+ )
+ meth_text = ("def %s(self, visitor):\n" % method_name) + code + "\n"
+ # print(meth_text)
+ return langhelpers._exec_code_in_env(meth_text, {}, method_name)
- else:
- # The optimization opportunity is lost for this case because the
- # __visit_name__ is not yet a string. As a result, the visit
- # string has to be recalculated with each compilation.
- def _compiler_dispatch(self, visitor, **kw):
- visit_attr = "visit_%s" % self.__visit_name__
- try:
- meth = getattr(visitor, visit_attr)
- except AttributeError:
- raise exc.UnsupportedCompilationError(visitor, cls)
- else:
- return meth(self, **kw)
-
- _compiler_dispatch.__doc__ = """Look for an attribute named "visit_" + self.__visit_name__
- on the visitor, and call it with the same kw params.
- """
- cls._compiler_dispatch = _compiler_dispatch
-
-
-class Visitable(util.with_metaclass(VisitableType, object)):
- """Base class for visitable objects, applies the
- :class:`.visitors.VisitableType` metaclass.
- The :class:`.Visitable` class is essentially at the base of the
- :class:`.ClauseElement` hierarchy.
+class InternalTraversal(util.with_metaclass(_InternalTraversalType, object)):
+ r"""Defines visitor symbols used for internal traversal.
+
+ The :class:`.InternalTraversal` class is used in two ways. One is that
+ it can serve as the superclass for an object that implements the
+ various visit methods of the class. The other is that the symbols
+ themselves of :class:`.InternalTraversal` are used within
+ the ``_traverse_internals`` collection. Such as, the :class:`.Case`
+ object defines ``_travserse_internals`` as ::
+
+ _traverse_internals = [
+ ("value", InternalTraversal.dp_clauseelement),
+ ("whens", InternalTraversal.dp_clauseelement_tuples),
+ ("else_", InternalTraversal.dp_clauseelement),
+ ]
+
+ Above, the :class:`.Case` class indicates its internal state as the
+ attribtues named ``value``, ``whens``, and ``else\_``. They each
+ link to an :class:`.InternalTraversal` method which indicates the type
+ of datastructure referred towards.
+
+ Using the ``_traverse_internals`` structure, objects of type
+ :class:`.InternalTraversible` will have the following methods automatically
+ implemented:
+
+ * :meth:`.Traversible.get_children`
+
+ * :meth:`.Traversible._copy_internals`
+
+ * :meth:`.Traversible._gen_cache_key`
+
+ Subclasses can also implement these methods directly, particularly for the
+ :meth:`.Traversible._copy_internals` method, when special steps
+ are needed.
+
+ .. versionadded:: 1.4
"""
+ def dispatch(self, visit_symbol):
+ """Given a method from :class:`.InternalTraversal`, return the
+ corresponding method on a subclass.
-class ClauseVisitor(object):
- """Base class for visitor objects which can traverse using
+ """
+ name = self._dispatch_lookup[visit_symbol]
+ return getattr(self, name, None)
+
+ def run_generated_dispatch(
+ self, target, internal_dispatch, generate_dispatcher_name
+ ):
+ try:
+ dispatcher = target.__class__.__dict__[generate_dispatcher_name]
+ except KeyError:
+ dispatcher = _generate_dispatcher(
+ self, internal_dispatch, generate_dispatcher_name
+ )
+ setattr(target.__class__, generate_dispatcher_name, dispatcher)
+ return dispatcher(target, self)
+
+ dp_has_cache_key = symbol("HC")
+ """Visit a :class:`.HasCacheKey` object."""
+
+ dp_clauseelement = symbol("CE")
+ """Visit a :class:`.ClauseElement` object."""
+
+ dp_fromclause_canonical_column_collection = symbol("FC")
+ """Visit a :class:`.FromClause` object in the context of the
+ ``columns`` attribute.
+
+ The column collection is "canonical", meaning it is the originally
+ defined location of the :class:`.ColumnClause` objects. Right now
+ this means that the object being visited is a :class:`.TableClause`
+ or :class:`.Table` object only.
+
+ """
+
+ dp_clauseelement_tuples = symbol("CT")
+ """Visit a list of tuples which contain :class:`.ClauseElement`
+ objects.
+
+ """
+
+ dp_clauseelement_list = symbol("CL")
+ """Visit a list of :class:`.ClauseElement` objects.
+
+ """
+
+ dp_clauseelement_unordered_set = symbol("CU")
+ """Visit an unordered set of :class:`.ClauseElement` objects. """
+
+ dp_fromclause_ordered_set = symbol("CO")
+ """Visit an ordered set of :class:`.FromClause` objects. """
+
+ dp_string = symbol("S")
+ """Visit a plain string value.
+
+ Examples include table and column names, bound parameter keys, special
+ keywords such as "UNION", "UNION ALL".
+
+ The string value is considered to be significant for cache key
+ generation.
+
+ """
+
+ dp_anon_name = symbol("AN")
+ """Visit a potentially "anonymized" string value.
+
+ The string value is considered to be significant for cache key
+ generation.
+
+ """
+
+ dp_boolean = symbol("B")
+ """Visit a boolean value.
+
+ The boolean value is considered to be significant for cache key
+ generation.
+
+ """
+
+ dp_operator = symbol("O")
+ """Visit an operator.
+
+ The operator is a function from the :mod:`sqlalchemy.sql.operators`
+ module.
+
+ The operator value is considered to be significant for cache key
+ generation.
+
+ """
+
+ dp_type = symbol("T")
+ """Visit a :class:`.TypeEngine` object
+
+ The type object is considered to be significant for cache key
+ generation.
+
+ """
+
+ dp_plain_dict = symbol("PD")
+ """Visit a dictionary with string keys.
+
+ The keys of the dictionary should be strings, the values should
+ be immutable and hashable. The dictionary is considered to be
+ significant for cache key generation.
+
+ """
+
+ dp_string_clauseelement_dict = symbol("CD")
+ """Visit a dictionary of string keys to :class:`.ClauseElement`
+ objects.
+
+ """
+
+ dp_string_multi_dict = symbol("MD")
+ """Visit a dictionary of string keys to values which may either be
+ plain immutable/hashable or :class:`.HasCacheKey` objects.
+
+ """
+
+ dp_plain_obj = symbol("PO")
+ """Visit a plain python object.
+
+ The value should be immutable and hashable, such as an integer.
+ The value is considered to be significant for cache key generation.
+
+ """
+
+ dp_annotations_state = symbol("A")
+ """Visit the state of the :class:`.Annotatated` version of an object.
+
+ """
+
+ dp_named_ddl_element = symbol("DD")
+ """Visit a simple named DDL element.
+
+ The current object used by this method is the :class:`.Sequence`.
+
+ The object is only considered to be important for cache key generation
+ as far as its name, but not any other aspects of it.
+
+ """
+
+ dp_prefix_sequence = symbol("PS")
+ """Visit the sequence represented by :class:`.HasPrefixes`
+ or :class:`.HasSuffixes`.
+
+ """
+
+ dp_table_hint_list = symbol("TH")
+ """Visit the ``_hints`` collection of a :class:`.Select` object.
+
+ """
+
+ dp_statement_hint_list = symbol("SH")
+ """Visit the ``_statement_hints`` collection of a :class:`.Select`
+ object.
+
+ """
+
+ dp_unknown_structure = symbol("UK")
+ """Visit an unknown structure.
+
+ """
+
+
+class ExtendedInternalTraversal(InternalTraversal):
+ """defines additional symbols that are useful in caching applications.
+
+ Traversals for :class:`.ClauseElement` objects only need to use
+ those symbols present in :class:`.InternalTraversal`. However, for
+ additional caching use cases within the ORM, symbols dealing with the
+ :class:`.HasCacheKey` class are added here.
+
+ """
+
+ dp_ignore = symbol("IG")
+ """Specify an object that should be ignored entirely.
+
+ This currently applies function call argument caching where some
+ arguments should not be considered to be part of a cache key.
+
+ """
+
+ dp_inspectable = symbol("IS")
+ """Visit an inspectable object where the return value is a HasCacheKey`
+ object."""
+
+ dp_multi = symbol("M")
+ """Visit an object that may be a :class:`.HasCacheKey` or may be a
+ plain hashable object."""
+
+ dp_multi_list = symbol("MT")
+ """Visit a tuple containing elements that may be :class:`.HasCacheKey` or
+ may be a plain hashable object."""
+
+ dp_has_cache_key_tuples = symbol("HT")
+ """Visit a list of tuples which contain :class:`.HasCacheKey`
+ objects.
+
+ """
+
+ dp_has_cache_key_list = symbol("HL")
+ """Visit a list of :class:`.HasCacheKey` objects."""
+
+ dp_inspectable_list = symbol("IL")
+ """Visit a list of inspectable objects which upon inspection are
+ HasCacheKey objects."""
+
+
+class ExternalTraversal(object):
+ """Base class for visitor objects which can traverse externally using
the :func:`.visitors.traverse` function.
Direct usage of the :func:`.visitors.traverse` function is usually
@@ -178,7 +471,7 @@ class ClauseVisitor(object):
return self
-class CloningVisitor(ClauseVisitor):
+class CloningExternalTraversal(ExternalTraversal):
"""Base class for visitor objects which can traverse using
the :func:`.visitors.cloned_traverse` function.
@@ -203,7 +496,7 @@ class CloningVisitor(ClauseVisitor):
)
-class ReplacingCloningVisitor(CloningVisitor):
+class ReplacingExternalTraversal(CloningExternalTraversal):
"""Base class for visitor objects which can traverse using
the :func:`.visitors.replacement_traverse` function.
@@ -233,6 +526,14 @@ class ReplacingCloningVisitor(CloningVisitor):
return replacement_traverse(obj, self.__traverse_options__, replace)
+# backwards compatibility
+Visitable = Traversible
+VisitableType = TraversibleType
+ClauseVisitor = ExternalTraversal
+CloningVisitor = CloningExternalTraversal
+ReplacingCloningVisitor = ReplacingExternalTraversal
+
+
def iterate(obj, opts):
r"""traverse the given expression structure, returning an iterator.
@@ -405,11 +706,18 @@ def cloned_traverse(obj, opts, visitors):
cloned = {}
stop_on = set(opts.get("stop_on", []))
- def clone(elem):
+ def clone(elem, **kw):
if elem in stop_on:
return elem
else:
if id(elem) not in cloned:
+
+ if "replace" in kw:
+ newelem = kw["replace"](elem)
+ if newelem is not None:
+ cloned[id(elem)] = newelem
+ return newelem
+
cloned[id(elem)] = newelem = elem._clone()
newelem._copy_internals(clone=clone)
meth = visitors.get(newelem.__visit_name__, None)
@@ -461,7 +769,14 @@ def replacement_traverse(obj, opts, replace):
stop_on.add(id(newelem))
return newelem
else:
+
if elem not in cloned:
+ if "replace" in kw:
+ newelem = kw["replace"](elem)
+ if newelem is not None:
+ cloned[elem] = newelem
+ return newelem
+
cloned[elem] = newelem = elem._clone()
newelem._copy_internals(clone=clone, **kw)
return cloned[elem]