summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/orm/unitofwork.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2010-04-06 01:23:54 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2010-04-06 01:23:54 -0400
commit9df8afc600cd69a87ece009beefa0108bb49b256 (patch)
tree61f07be937743be302b5307119e441fc26d10a98 /lib/sqlalchemy/orm/unitofwork.py
parent4071156acdd5929c8c8a2c9556fc466ba7581eca (diff)
downloadsqlalchemy-9df8afc600cd69a87ece009beefa0108bb49b256.tar.gz
- cleanup, factoring, had some heisenbugs. more test coverage
will be needed overall as missing dependency rules lead to subtle bugs pretty easily
Diffstat (limited to 'lib/sqlalchemy/orm/unitofwork.py')
-rw-r--r--lib/sqlalchemy/orm/unitofwork.py92
1 files changed, 60 insertions, 32 deletions
diff --git a/lib/sqlalchemy/orm/unitofwork.py b/lib/sqlalchemy/orm/unitofwork.py
index 0028202c5..088c71b6e 100644
--- a/lib/sqlalchemy/orm/unitofwork.py
+++ b/lib/sqlalchemy/orm/unitofwork.py
@@ -60,15 +60,19 @@ class UOWEventHandler(interfaces.AttributeExtension):
sess.expunge(item)
def set(self, state, newvalue, oldvalue, initiator):
- # process "save_update" cascade rules for when an instance is attached to another instance
+ # process "save_update" cascade rules for when an instance
+ # is attached to another instance
if oldvalue is newvalue:
return newvalue
sess = _state_session(state)
if sess:
prop = _state_mapper(state).get_property(self.key)
- if newvalue is not None and prop.cascade.save_update and newvalue not in sess:
+ if newvalue is not None and \
+ prop.cascade.save_update and \
+ newvalue not in sess:
sess.add(newvalue)
- if prop.cascade.delete_orphan and oldvalue in sess.new and \
+ if prop.cascade.delete_orphan and \
+ oldvalue in sess.new and \
prop.mapper._is_orphan(attributes.instance_state(oldvalue)):
sess.expunge(oldvalue)
return newvalue
@@ -99,7 +103,8 @@ class UOWTransaction(object):
return bool(self.states)
def is_deleted(self, state):
- """return true if the given state is marked as deleted within this UOWTransaction."""
+ """return true if the given state is marked as deleted
+ within this UOWTransaction."""
return state in self.states and self.states[state][0]
def remove_state_actions(self, state):
@@ -130,8 +135,6 @@ class UOWTransaction(object):
return history.as_state()
def register_object(self, state, isdelete=False, listonly=False):
-
- # if object is not in the overall session, do nothing
if not self.session._contains_state(state):
return
@@ -139,7 +142,7 @@ class UOWTransaction(object):
mapper = _state_mapper(state)
if mapper not in self.mappers:
- mapper.per_mapper_flush_actions(self)
+ mapper._per_mapper_flush_actions(self)
self.mappers[mapper].add(state)
self.states[state] = (isdelete, listonly)
@@ -199,11 +202,9 @@ class UOWTransaction(object):
# the per-state actions for those per-mapper actions
# that were broken up.
for edge in list(self.dependencies):
- if None in edge:
- self.dependencies.remove(edge)
- elif cycles.issuperset(edge):
- self.dependencies.remove(edge)
- elif edge[0].disabled or edge[1].disabled:
+ if None in edge or\
+ cycles.issuperset(edge) or \
+ edge[0].disabled or edge[1].disabled:
self.dependencies.remove(edge)
elif edge[0] in cycles:
self.dependencies.remove(edge)
@@ -220,14 +221,24 @@ class UOWTransaction(object):
]
).difference(cycles)
- # execute actions
+ # execute
if cycles:
- for set_ in topological.sort_as_subsets(self.dependencies, postsort_actions):
+ for set_ in topological.sort_as_subsets(
+ self.dependencies,
+ postsort_actions):
while set_:
n = set_.pop()
n.execute_aggregate(self, set_)
else:
- for rec in topological.sort(self.dependencies, postsort_actions):
+ r = list(topological.sort(
+ self.dependencies,
+ postsort_actions))
+ print "-----------"
+ print self.dependencies
+ print r
+ for rec in topological.sort(
+ self.dependencies,
+ postsort_actions):
rec.execute(self)
@@ -254,7 +265,9 @@ class PreSortRec(object):
if key in uow.presort_actions:
return uow.presort_actions[key]
else:
- uow.presort_actions[key] = ret = object.__new__(cls)
+ uow.presort_actions[key] = \
+ ret = \
+ object.__new__(cls)
return ret
class PostSortRec(object):
@@ -265,7 +278,9 @@ class PostSortRec(object):
if key in uow.postsort_actions:
return uow.postsort_actions[key]
else:
- uow.postsort_actions[key] = ret = object.__new__(cls)
+ uow.postsort_actions[key] = \
+ ret = \
+ object.__new__(cls)
return ret
def execute_aggregate(self, uow, recs):
@@ -351,7 +366,7 @@ class SaveUpdateAll(PostSortRec):
)
def per_state_flush_actions(self, uow):
- for rec in self.mapper.per_state_flush_actions(
+ for rec in self.mapper._per_state_flush_actions(
uow,
uow.states_for_mapper_hierarchy(self.mapper, False, False),
False):
@@ -369,7 +384,7 @@ class DeleteAll(PostSortRec):
)
def per_state_flush_actions(self, uow):
- for rec in self.mapper.per_state_flush_actions(
+ for rec in self.mapper._per_state_flush_actions(
uow,
uow.states_for_mapper_hierarchy(self.mapper, True, False),
True):
@@ -396,26 +411,27 @@ class ProcessState(PostSortRec):
)
class SaveUpdateState(PostSortRec):
- def __init__(self, uow, state):
+ def __init__(self, uow, state, mapper):
self.state = state
-
+ self.mapper = mapper
+
def execute(self, uow):
- mapper = self.state.manager.mapper.base_mapper
- mapper._save_obj(
+ self.mapper._save_obj(
[self.state],
uow
)
def execute_aggregate(self, uow, recs):
cls_ = self.__class__
- # TODO: have 'mapper' be present on SaveUpdateState already
- mapper = self.state.manager.mapper.base_mapper
-
+ mapper = self.mapper
our_recs = [r for r in recs
if r.__class__ is cls_ and
- r.state.manager.mapper.base_mapper is mapper]
+ r.mapper is mapper]
recs.difference_update(our_recs)
- mapper._save_obj([self.state] + [r.state for r in our_recs], uow)
+ mapper._save_obj(
+ [self.state] +
+ [r.state for r in our_recs],
+ uow)
def __repr__(self):
return "%s(%s)" % (
@@ -424,17 +440,29 @@ class SaveUpdateState(PostSortRec):
)
class DeleteState(PostSortRec):
- def __init__(self, uow, state):
+ def __init__(self, uow, state, mapper):
self.state = state
-
+ self.mapper = mapper
+
def execute(self, uow):
- mapper = self.state.manager.mapper.base_mapper
if uow.states[self.state][0]:
- mapper._delete_obj(
+ self.mapper._delete_obj(
[self.state],
uow
)
+ def execute_aggregate(self, uow, recs):
+ cls_ = self.__class__
+ mapper = self.mapper
+ our_recs = [r for r in recs
+ if r.__class__ is cls_ and
+ r.mapper is mapper]
+ recs.difference_update(our_recs)
+ states = [self.state] + [r.state for r in our_recs]
+ mapper._delete_obj(
+ [s for s in states if uow.states[s][0]],
+ uow)
+
def __repr__(self):
return "%s(%s)" % (
self.__class__.__name__,