diff options
Diffstat (limited to 'lib/sqlalchemy/attributes.py')
-rw-r--r-- | lib/sqlalchemy/attributes.py | 55 |
1 files changed, 45 insertions, 10 deletions
diff --git a/lib/sqlalchemy/attributes.py b/lib/sqlalchemy/attributes.py index 84a1d58fb..7fd9686e3 100644 --- a/lib/sqlalchemy/attributes.py +++ b/lib/sqlalchemy/attributes.py @@ -13,13 +13,28 @@ class InstrumentedAttribute(object): PASSIVE_NORESULT = object() - def __init__(self, manager, key, uselist, callable_, typecallable, trackparent=False, extension=None, **kwargs): + def __init__(self, manager, key, uselist, callable_, typecallable, trackparent=False, extension=None, copy_function=None, compare_function=None, **kwargs): self.manager = manager self.key = key self.uselist = uselist self.callable_ = callable_ self.typecallable= typecallable self.trackparent = trackparent + if copy_function is None: + self._check_mutable_modified = False + if uselist: + self._copyfunc = lambda x: [y for y in x] + else: + # scalar values are assumed to be immutable unless a copy function + # is passed + self._copyfunc = lambda x: x + else: + self._check_mutable_modified = True + self._copyfunc = copy_function + if compare_function is None: + self._compare_function = lambda x,y: x == y + else: + self._compare_function = compare_function self.extensions = util.to_list(extension or []) def __set__(self, obj, value): @@ -31,6 +46,23 @@ class InstrumentedAttribute(object): return self return self.get(obj) + def is_equal(self, x, y): + return self._compare_function(x, y) + def copy(self, value): + return self._copyfunc(value) + + def check_mutable_modified(self, obj): + if self._check_mutable_modified: + h = self.get_history(obj, passive=True) + if h is not None and h.is_modified(): + obj._state['modified'] = True + return True + else: + return False + else: + return False + + def hasparent(self, item, optimistic=False): """return the boolean value of a "hasparent" flag attached to the given item. @@ -490,16 +522,14 @@ class CommittedState(object): if obj.__dict__.has_key(attr.key): value = obj.__dict__[attr.key] if value is not False: - if attr.uselist: - self.data[attr.key] = [x for x in value] - # not tracking parent on lazy-loaded instances at the moment. - # its not needed since they will be "optimistically" tested + self.data[attr.key] = attr.copy(value) + + # not tracking parent on lazy-loaded instances at the moment. + # its not needed since they will be "optimistically" tested + #if attr.uselist: #if attr.trackparent: # [attr.sethasparent(x, True) for x in self.data[attr.key] if x is not None] - else: - self.data[attr.key] = value - # not tracking parent on lazy-loaded instances at the moment. - # its not needed since they will be "optimistically" tested + #else: #if attr.trackparent and value is not None: # attr.sethasparent(value, True) @@ -550,7 +580,7 @@ class AttributeHistory(object): if a not in self._unchanged_items: self._deleted_items.append(a) else: - if current is original: + if attr.is_equal(current, original): self._unchanged_items = [current] self._added_items = [] self._deleted_items = [] @@ -564,6 +594,8 @@ class AttributeHistory(object): #print "key", attr.key, "orig", original, "current", current, "added", self._added_items, "unchanged", self._unchanged_items, "deleted", self._deleted_items def __iter__(self): return iter(self._current) + def is_modified(self): + return len(self._deleted_items) > 0 or len(self._added_items) > 0 def added_items(self): return self._added_items def unchanged_items(self): @@ -622,6 +654,9 @@ class AttributeManager(object): yield value def is_modified(self, object): + for attr in self.managed_attributes(object.__class__): + if attr.check_mutable_modified(object): + return True return object._state.get('modified', False) def init_attr(self, obj): |