summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/orm/attributes.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/orm/attributes.py')
-rw-r--r--lib/sqlalchemy/orm/attributes.py23
1 files changed, 20 insertions, 3 deletions
diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py
index 66197ba0e..459a52539 100644
--- a/lib/sqlalchemy/orm/attributes.py
+++ b/lib/sqlalchemy/orm/attributes.py
@@ -863,6 +863,16 @@ class CollectionAttributeImpl(AttributeImpl):
self.copy = copy_function
self.collection_factory = typecallable
+ if hasattr(self.collection_factory, "_sa_linker"):
+
+ @event.listens_for(self, "init_collection")
+ def link(target, collection, collection_adapter):
+ collection._sa_linker(collection_adapter)
+
+ @event.listens_for(self, "dispose_collection")
+ def unlink(target, collection, collection_adapter):
+ collection._sa_linker(None)
+
def __copy(self, item):
return [y for y in collections.collection_adapter(item)]
@@ -955,9 +965,14 @@ class CollectionAttributeImpl(AttributeImpl):
return user_data
def _initialize_collection(self, state):
- return state.manager.initialize_collection(
+
+ adapter, collection = state.manager.initialize_collection(
self.key, state, self.collection_factory)
+ self.dispatch.init_collection(state, collection, adapter)
+
+ return adapter, collection
+
def append(self, state, dict_, value, initiator, passive=PASSIVE_OFF):
collection = self.get_collection(state, dict_, passive=passive)
if collection is PASSIVE_NO_RESULT:
@@ -1026,12 +1041,14 @@ class CollectionAttributeImpl(AttributeImpl):
# place a copy of "old" in state.committed_state
state._modified_event(dict_, self, old, True)
- old_collection = getattr(old, '_sa_adapter')
+ old_collection = old._sa_adapter
dict_[self.key] = user_data
collections.bulk_replace(new_values, old_collection, new_collection)
- old_collection.unlink(old)
+
+ del old._sa_adapter
+ self.dispatch.dispose_collection(state, old, old_collection)
def _invalidate_collection(self, collection):
adapter = getattr(collection, '_sa_adapter')