summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/attributes.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/attributes.py')
-rw-r--r--lib/sqlalchemy/attributes.py55
1 files changed, 45 insertions, 10 deletions
diff --git a/lib/sqlalchemy/attributes.py b/lib/sqlalchemy/attributes.py
index 84a1d58fb..7fd9686e3 100644
--- a/lib/sqlalchemy/attributes.py
+++ b/lib/sqlalchemy/attributes.py
@@ -13,13 +13,28 @@ class InstrumentedAttribute(object):
PASSIVE_NORESULT = object()
- def __init__(self, manager, key, uselist, callable_, typecallable, trackparent=False, extension=None, **kwargs):
+ def __init__(self, manager, key, uselist, callable_, typecallable, trackparent=False, extension=None, copy_function=None, compare_function=None, **kwargs):
self.manager = manager
self.key = key
self.uselist = uselist
self.callable_ = callable_
self.typecallable= typecallable
self.trackparent = trackparent
+ if copy_function is None:
+ self._check_mutable_modified = False
+ if uselist:
+ self._copyfunc = lambda x: [y for y in x]
+ else:
+ # scalar values are assumed to be immutable unless a copy function
+ # is passed
+ self._copyfunc = lambda x: x
+ else:
+ self._check_mutable_modified = True
+ self._copyfunc = copy_function
+ if compare_function is None:
+ self._compare_function = lambda x,y: x == y
+ else:
+ self._compare_function = compare_function
self.extensions = util.to_list(extension or [])
def __set__(self, obj, value):
@@ -31,6 +46,23 @@ class InstrumentedAttribute(object):
return self
return self.get(obj)
+ def is_equal(self, x, y):
+ return self._compare_function(x, y)
+ def copy(self, value):
+ return self._copyfunc(value)
+
+ def check_mutable_modified(self, obj):
+ if self._check_mutable_modified:
+ h = self.get_history(obj, passive=True)
+ if h is not None and h.is_modified():
+ obj._state['modified'] = True
+ return True
+ else:
+ return False
+ else:
+ return False
+
+
def hasparent(self, item, optimistic=False):
"""return the boolean value of a "hasparent" flag attached to the given item.
@@ -490,16 +522,14 @@ class CommittedState(object):
if obj.__dict__.has_key(attr.key):
value = obj.__dict__[attr.key]
if value is not False:
- if attr.uselist:
- self.data[attr.key] = [x for x in value]
- # not tracking parent on lazy-loaded instances at the moment.
- # its not needed since they will be "optimistically" tested
+ self.data[attr.key] = attr.copy(value)
+
+ # not tracking parent on lazy-loaded instances at the moment.
+ # its not needed since they will be "optimistically" tested
+ #if attr.uselist:
#if attr.trackparent:
# [attr.sethasparent(x, True) for x in self.data[attr.key] if x is not None]
- else:
- self.data[attr.key] = value
- # not tracking parent on lazy-loaded instances at the moment.
- # its not needed since they will be "optimistically" tested
+ #else:
#if attr.trackparent and value is not None:
# attr.sethasparent(value, True)
@@ -550,7 +580,7 @@ class AttributeHistory(object):
if a not in self._unchanged_items:
self._deleted_items.append(a)
else:
- if current is original:
+ if attr.is_equal(current, original):
self._unchanged_items = [current]
self._added_items = []
self._deleted_items = []
@@ -564,6 +594,8 @@ class AttributeHistory(object):
#print "key", attr.key, "orig", original, "current", current, "added", self._added_items, "unchanged", self._unchanged_items, "deleted", self._deleted_items
def __iter__(self):
return iter(self._current)
+ def is_modified(self):
+ return len(self._deleted_items) > 0 or len(self._added_items) > 0
def added_items(self):
return self._added_items
def unchanged_items(self):
@@ -622,6 +654,9 @@ class AttributeManager(object):
yield value
def is_modified(self, object):
+ for attr in self.managed_attributes(object.__class__):
+ if attr.check_mutable_modified(object):
+ return True
return object._state.get('modified', False)
def init_attr(self, obj):