summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/annotation.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/sql/annotation.py')
-rw-r--r--lib/sqlalchemy/sql/annotation.py69
1 files changed, 54 insertions, 15 deletions
diff --git a/lib/sqlalchemy/sql/annotation.py b/lib/sqlalchemy/sql/annotation.py
index a0264845e..0d995ec8a 100644
--- a/lib/sqlalchemy/sql/annotation.py
+++ b/lib/sqlalchemy/sql/annotation.py
@@ -12,12 +12,32 @@ associations.
"""
from . import operators
+from .base import HasCacheKey
+from .visitors import InternalTraversal
from .. import util
-class SupportsCloneAnnotations(object):
+class SupportsAnnotations(object):
+ @util.memoized_property
+ def _annotation_traversals(self):
+ return [
+ (
+ key,
+ InternalTraversal.dp_has_cache_key
+ if isinstance(value, HasCacheKey)
+ else InternalTraversal.dp_plain_obj,
+ )
+ for key, value in self._annotations.items()
+ ]
+
+
+class SupportsCloneAnnotations(SupportsAnnotations):
_annotations = util.immutabledict()
+ _traverse_internals = [
+ ("_annotations", InternalTraversal.dp_annotations_state)
+ ]
+
def _annotate(self, values):
"""return a copy of this ClauseElement with annotations
updated by the given dictionary.
@@ -25,6 +45,7 @@ class SupportsCloneAnnotations(object):
"""
new = self._clone()
new._annotations = new._annotations.union(values)
+ new.__dict__.pop("_annotation_traversals", None)
return new
def _with_annotations(self, values):
@@ -34,6 +55,7 @@ class SupportsCloneAnnotations(object):
"""
new = self._clone()
new._annotations = util.immutabledict(values)
+ new.__dict__.pop("_annotation_traversals", None)
return new
def _deannotate(self, values=None, clone=False):
@@ -49,12 +71,13 @@ class SupportsCloneAnnotations(object):
# the expression for a deep deannotation
new = self._clone()
new._annotations = {}
+ new.__dict__.pop("_annotation_traversals", None)
return new
else:
return self
-class SupportsWrappingAnnotations(object):
+class SupportsWrappingAnnotations(SupportsAnnotations):
def _annotate(self, values):
"""return a copy of this ClauseElement with annotations
updated by the given dictionary.
@@ -123,6 +146,7 @@ class Annotated(object):
def __init__(self, element, values):
self.__dict__ = element.__dict__.copy()
+ self.__dict__.pop("_annotation_traversals", None)
self.__element = element
self._annotations = values
self._hash = hash(element)
@@ -135,6 +159,7 @@ class Annotated(object):
def _with_annotations(self, values):
clone = self.__class__.__new__(self.__class__)
clone.__dict__ = self.__dict__.copy()
+ clone.__dict__.pop("_annotation_traversals", None)
clone._annotations = values
return clone
@@ -192,7 +217,17 @@ def _deep_annotate(element, annotations, exclude=None):
"""
- def clone(elem):
+ # annotated objects hack the __hash__() method so if we want to
+ # uniquely process them we have to use id()
+
+ cloned_ids = {}
+
+ def clone(elem, **kw):
+ id_ = id(elem)
+
+ if id_ in cloned_ids:
+ return cloned_ids[id_]
+
if (
exclude
and hasattr(elem, "proxy_set")
@@ -204,6 +239,7 @@ def _deep_annotate(element, annotations, exclude=None):
else:
newelem = elem
newelem._copy_internals(clone=clone)
+ cloned_ids[id_] = newelem
return newelem
if element is not None:
@@ -214,23 +250,21 @@ def _deep_annotate(element, annotations, exclude=None):
def _deep_deannotate(element, values=None):
"""Deep copy the given element, removing annotations."""
- cloned = util.column_dict()
+ cloned = {}
- def clone(elem):
- # if a values dict is given,
- # the elem must be cloned each time it appears,
- # as there may be different annotations in source
- # elements that are remaining. if totally
- # removing all annotations, can assume the same
- # slate...
- if values or elem not in cloned:
+ def clone(elem, **kw):
+ if values:
+ key = id(elem)
+ else:
+ key = elem
+
+ if key not in cloned:
newelem = elem._deannotate(values=values, clone=True)
newelem._copy_internals(clone=clone)
- if not values:
- cloned[elem] = newelem
+ cloned[key] = newelem
return newelem
else:
- return cloned[elem]
+ return cloned[key]
if element is not None:
element = clone(element)
@@ -268,6 +302,11 @@ def _new_annotation_type(cls, base_cls):
"Annotated%s" % cls.__name__, (base_cls, cls), {}
)
globals()["Annotated%s" % cls.__name__] = anno_cls
+
+ if "_traverse_internals" in cls.__dict__:
+ anno_cls._traverse_internals = list(cls._traverse_internals) + [
+ ("_annotations", InternalTraversal.dp_annotations_state)
+ ]
return anno_cls