diff options
Diffstat (limited to 'lib/sqlalchemy/sql/annotation.py')
-rw-r--r-- | lib/sqlalchemy/sql/annotation.py | 69 |
1 files changed, 54 insertions, 15 deletions
diff --git a/lib/sqlalchemy/sql/annotation.py b/lib/sqlalchemy/sql/annotation.py index a0264845e..0d995ec8a 100644 --- a/lib/sqlalchemy/sql/annotation.py +++ b/lib/sqlalchemy/sql/annotation.py @@ -12,12 +12,32 @@ associations. """ from . import operators +from .base import HasCacheKey +from .visitors import InternalTraversal from .. import util -class SupportsCloneAnnotations(object): +class SupportsAnnotations(object): + @util.memoized_property + def _annotation_traversals(self): + return [ + ( + key, + InternalTraversal.dp_has_cache_key + if isinstance(value, HasCacheKey) + else InternalTraversal.dp_plain_obj, + ) + for key, value in self._annotations.items() + ] + + +class SupportsCloneAnnotations(SupportsAnnotations): _annotations = util.immutabledict() + _traverse_internals = [ + ("_annotations", InternalTraversal.dp_annotations_state) + ] + def _annotate(self, values): """return a copy of this ClauseElement with annotations updated by the given dictionary. @@ -25,6 +45,7 @@ class SupportsCloneAnnotations(object): """ new = self._clone() new._annotations = new._annotations.union(values) + new.__dict__.pop("_annotation_traversals", None) return new def _with_annotations(self, values): @@ -34,6 +55,7 @@ class SupportsCloneAnnotations(object): """ new = self._clone() new._annotations = util.immutabledict(values) + new.__dict__.pop("_annotation_traversals", None) return new def _deannotate(self, values=None, clone=False): @@ -49,12 +71,13 @@ class SupportsCloneAnnotations(object): # the expression for a deep deannotation new = self._clone() new._annotations = {} + new.__dict__.pop("_annotation_traversals", None) return new else: return self -class SupportsWrappingAnnotations(object): +class SupportsWrappingAnnotations(SupportsAnnotations): def _annotate(self, values): """return a copy of this ClauseElement with annotations updated by the given dictionary. @@ -123,6 +146,7 @@ class Annotated(object): def __init__(self, element, values): self.__dict__ = element.__dict__.copy() + self.__dict__.pop("_annotation_traversals", None) self.__element = element self._annotations = values self._hash = hash(element) @@ -135,6 +159,7 @@ class Annotated(object): def _with_annotations(self, values): clone = self.__class__.__new__(self.__class__) clone.__dict__ = self.__dict__.copy() + clone.__dict__.pop("_annotation_traversals", None) clone._annotations = values return clone @@ -192,7 +217,17 @@ def _deep_annotate(element, annotations, exclude=None): """ - def clone(elem): + # annotated objects hack the __hash__() method so if we want to + # uniquely process them we have to use id() + + cloned_ids = {} + + def clone(elem, **kw): + id_ = id(elem) + + if id_ in cloned_ids: + return cloned_ids[id_] + if ( exclude and hasattr(elem, "proxy_set") @@ -204,6 +239,7 @@ def _deep_annotate(element, annotations, exclude=None): else: newelem = elem newelem._copy_internals(clone=clone) + cloned_ids[id_] = newelem return newelem if element is not None: @@ -214,23 +250,21 @@ def _deep_annotate(element, annotations, exclude=None): def _deep_deannotate(element, values=None): """Deep copy the given element, removing annotations.""" - cloned = util.column_dict() + cloned = {} - def clone(elem): - # if a values dict is given, - # the elem must be cloned each time it appears, - # as there may be different annotations in source - # elements that are remaining. if totally - # removing all annotations, can assume the same - # slate... - if values or elem not in cloned: + def clone(elem, **kw): + if values: + key = id(elem) + else: + key = elem + + if key not in cloned: newelem = elem._deannotate(values=values, clone=True) newelem._copy_internals(clone=clone) - if not values: - cloned[elem] = newelem + cloned[key] = newelem return newelem else: - return cloned[elem] + return cloned[key] if element is not None: element = clone(element) @@ -268,6 +302,11 @@ def _new_annotation_type(cls, base_cls): "Annotated%s" % cls.__name__, (base_cls, cls), {} ) globals()["Annotated%s" % cls.__name__] = anno_cls + + if "_traverse_internals" in cls.__dict__: + anno_cls._traverse_internals = list(cls._traverse_internals) + [ + ("_annotations", InternalTraversal.dp_annotations_state) + ] return anno_cls |