diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2010-04-06 01:23:54 -0400 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2010-04-06 01:23:54 -0400 |
commit | 9df8afc600cd69a87ece009beefa0108bb49b256 (patch) | |
tree | 61f07be937743be302b5307119e441fc26d10a98 /lib/sqlalchemy/orm/unitofwork.py | |
parent | 4071156acdd5929c8c8a2c9556fc466ba7581eca (diff) | |
download | sqlalchemy-9df8afc600cd69a87ece009beefa0108bb49b256.tar.gz |
- cleanup, factoring, had some heisenbugs. more test coverage
will be needed overall as missing dependency rules lead
to subtle bugs pretty easily
Diffstat (limited to 'lib/sqlalchemy/orm/unitofwork.py')
-rw-r--r-- | lib/sqlalchemy/orm/unitofwork.py | 92 |
1 files changed, 60 insertions, 32 deletions
diff --git a/lib/sqlalchemy/orm/unitofwork.py b/lib/sqlalchemy/orm/unitofwork.py index 0028202c5..088c71b6e 100644 --- a/lib/sqlalchemy/orm/unitofwork.py +++ b/lib/sqlalchemy/orm/unitofwork.py @@ -60,15 +60,19 @@ class UOWEventHandler(interfaces.AttributeExtension): sess.expunge(item) def set(self, state, newvalue, oldvalue, initiator): - # process "save_update" cascade rules for when an instance is attached to another instance + # process "save_update" cascade rules for when an instance + # is attached to another instance if oldvalue is newvalue: return newvalue sess = _state_session(state) if sess: prop = _state_mapper(state).get_property(self.key) - if newvalue is not None and prop.cascade.save_update and newvalue not in sess: + if newvalue is not None and \ + prop.cascade.save_update and \ + newvalue not in sess: sess.add(newvalue) - if prop.cascade.delete_orphan and oldvalue in sess.new and \ + if prop.cascade.delete_orphan and \ + oldvalue in sess.new and \ prop.mapper._is_orphan(attributes.instance_state(oldvalue)): sess.expunge(oldvalue) return newvalue @@ -99,7 +103,8 @@ class UOWTransaction(object): return bool(self.states) def is_deleted(self, state): - """return true if the given state is marked as deleted within this UOWTransaction.""" + """return true if the given state is marked as deleted + within this UOWTransaction.""" return state in self.states and self.states[state][0] def remove_state_actions(self, state): @@ -130,8 +135,6 @@ class UOWTransaction(object): return history.as_state() def register_object(self, state, isdelete=False, listonly=False): - - # if object is not in the overall session, do nothing if not self.session._contains_state(state): return @@ -139,7 +142,7 @@ class UOWTransaction(object): mapper = _state_mapper(state) if mapper not in self.mappers: - mapper.per_mapper_flush_actions(self) + mapper._per_mapper_flush_actions(self) self.mappers[mapper].add(state) self.states[state] = (isdelete, listonly) @@ -199,11 +202,9 @@ class UOWTransaction(object): # the per-state actions for those per-mapper actions # that were broken up. for edge in list(self.dependencies): - if None in edge: - self.dependencies.remove(edge) - elif cycles.issuperset(edge): - self.dependencies.remove(edge) - elif edge[0].disabled or edge[1].disabled: + if None in edge or\ + cycles.issuperset(edge) or \ + edge[0].disabled or edge[1].disabled: self.dependencies.remove(edge) elif edge[0] in cycles: self.dependencies.remove(edge) @@ -220,14 +221,24 @@ class UOWTransaction(object): ] ).difference(cycles) - # execute actions + # execute if cycles: - for set_ in topological.sort_as_subsets(self.dependencies, postsort_actions): + for set_ in topological.sort_as_subsets( + self.dependencies, + postsort_actions): while set_: n = set_.pop() n.execute_aggregate(self, set_) else: - for rec in topological.sort(self.dependencies, postsort_actions): + r = list(topological.sort( + self.dependencies, + postsort_actions)) + print "-----------" + print self.dependencies + print r + for rec in topological.sort( + self.dependencies, + postsort_actions): rec.execute(self) @@ -254,7 +265,9 @@ class PreSortRec(object): if key in uow.presort_actions: return uow.presort_actions[key] else: - uow.presort_actions[key] = ret = object.__new__(cls) + uow.presort_actions[key] = \ + ret = \ + object.__new__(cls) return ret class PostSortRec(object): @@ -265,7 +278,9 @@ class PostSortRec(object): if key in uow.postsort_actions: return uow.postsort_actions[key] else: - uow.postsort_actions[key] = ret = object.__new__(cls) + uow.postsort_actions[key] = \ + ret = \ + object.__new__(cls) return ret def execute_aggregate(self, uow, recs): @@ -351,7 +366,7 @@ class SaveUpdateAll(PostSortRec): ) def per_state_flush_actions(self, uow): - for rec in self.mapper.per_state_flush_actions( + for rec in self.mapper._per_state_flush_actions( uow, uow.states_for_mapper_hierarchy(self.mapper, False, False), False): @@ -369,7 +384,7 @@ class DeleteAll(PostSortRec): ) def per_state_flush_actions(self, uow): - for rec in self.mapper.per_state_flush_actions( + for rec in self.mapper._per_state_flush_actions( uow, uow.states_for_mapper_hierarchy(self.mapper, True, False), True): @@ -396,26 +411,27 @@ class ProcessState(PostSortRec): ) class SaveUpdateState(PostSortRec): - def __init__(self, uow, state): + def __init__(self, uow, state, mapper): self.state = state - + self.mapper = mapper + def execute(self, uow): - mapper = self.state.manager.mapper.base_mapper - mapper._save_obj( + self.mapper._save_obj( [self.state], uow ) def execute_aggregate(self, uow, recs): cls_ = self.__class__ - # TODO: have 'mapper' be present on SaveUpdateState already - mapper = self.state.manager.mapper.base_mapper - + mapper = self.mapper our_recs = [r for r in recs if r.__class__ is cls_ and - r.state.manager.mapper.base_mapper is mapper] + r.mapper is mapper] recs.difference_update(our_recs) - mapper._save_obj([self.state] + [r.state for r in our_recs], uow) + mapper._save_obj( + [self.state] + + [r.state for r in our_recs], + uow) def __repr__(self): return "%s(%s)" % ( @@ -424,17 +440,29 @@ class SaveUpdateState(PostSortRec): ) class DeleteState(PostSortRec): - def __init__(self, uow, state): + def __init__(self, uow, state, mapper): self.state = state - + self.mapper = mapper + def execute(self, uow): - mapper = self.state.manager.mapper.base_mapper if uow.states[self.state][0]: - mapper._delete_obj( + self.mapper._delete_obj( [self.state], uow ) + def execute_aggregate(self, uow, recs): + cls_ = self.__class__ + mapper = self.mapper + our_recs = [r for r in recs + if r.__class__ is cls_ and + r.mapper is mapper] + recs.difference_update(our_recs) + states = [self.state] + [r.state for r in our_recs] + mapper._delete_obj( + [s for s in states if uow.states[s][0]], + uow) + def __repr__(self): return "%s(%s)" % ( self.__class__.__name__, |