summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--CHANGES3
-rw-r--r--lib/sqlalchemy/orm/mapper.py12
-rw-r--r--lib/sqlalchemy/orm/query.py4
-rw-r--r--test/orm/objectstore.py42
4 files changed, 53 insertions, 8 deletions
diff --git a/CHANGES b/CHANGES
index 7cb9b2cb5..b0d75ea72 100644
--- a/CHANGES
+++ b/CHANGES
@@ -10,6 +10,9 @@ including "with_lockmode" function to get a Query copy that has
a default locking mode. Will translate "read"/"update"
arguments into a for_update argument on the select side.
[ticket:292]
+- implemented "version check" logic in Query/Mapper, used
+when version_id_col is in effect and query.with_lockmode()
+is used to get() an instance thats already loaded
0.2.8
- cleanup on connection methods + documentation. custom DBAPI
diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py
index 4197401f7..837f17a33 100644
--- a/lib/sqlalchemy/orm/mapper.py
+++ b/lib/sqlalchemy/orm/mapper.py
@@ -610,6 +610,7 @@ class Mapper(object):
limit = kwargs.get('limit', None)
offset = kwargs.get('offset', None)
populate_existing = kwargs.get('populate_existing', False)
+ version_check = kwargs.get('version_check', False)
result = util.UniqueAppender([])
if mappers:
@@ -624,7 +625,7 @@ class Mapper(object):
row = cursor.fetchone()
if row is None:
break
- self._instance(session, row, imap, result, populate_existing=populate_existing)
+ self._instance(session, row, imap, result, populate_existing=populate_existing, version_check=version_check)
i = 0
for m in mappers:
m._instance(session, row, imap, otherresults[i])
@@ -838,7 +839,7 @@ class Mapper(object):
rows += c.cursor.rowcount
if c.supports_sane_rowcount() and rows != len(update):
- raise exceptions.FlushError("ConcurrencyError - updated rowcount %d does not match number of objects updated %d" % (rows, len(update)))
+ raise exceptions.ConcurrentModificationError("Updated rowcount %d does not match number of objects updated %d" % (rows, len(update)))
if len(insert):
statement = table.insert()
@@ -932,7 +933,7 @@ class Mapper(object):
statement = table.delete(clause)
c = connection.execute(statement, delete)
if c.supports_sane_rowcount() and c.rowcount != len(delete):
- raise exceptions.FlushError("ConcurrencyError - updated rowcount %d does not match number of objects updated %d" % (c.cursor.rowcount, len(delete)))
+ raise exceptions.ConcurrentModificationError("Updated rowcount %d does not match number of objects updated %d" % (c.cursor.rowcount, len(delete)))
[self.extension.after_delete(self, connection, obj) for obj in deleted_objects]
@@ -972,7 +973,7 @@ class Mapper(object):
def get_select_mapper(self):
return self.__surrogate_mapper or self
- def _instance(self, session, row, imap, result = None, populate_existing = False):
+ def _instance(self, session, row, imap, result = None, populate_existing = False, version_check=False):
"""pulls an object instance from the given row and appends it to the given result
list. if the instance already exists in the given identity map, its not added. in
either case, executes all the property loaders on the instance to also process extra
@@ -994,6 +995,9 @@ class Mapper(object):
if session.has_key(identitykey):
instance = session._get(identitykey)
isnew = False
+ if version_check and self.version_id_col is not None and self._getattrbycolumn(instance, self.version_id_col) != row[self.version_id_col]:
+ raise exceptions.ConcurrentModificationError("Instance '%s' version of %s does not match %s" % (instance, self._getattrbycolumn(instance, self.version_id_col), row[self.version_id_col]))
+
if populate_existing or session.is_expired(instance, unexpire=True):
if not imap.has_key(identitykey):
imap[identitykey] = instance
diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py
index 052d048cb..13092b44d 100644
--- a/lib/sqlalchemy/orm/query.py
+++ b/lib/sqlalchemy/orm/query.py
@@ -278,7 +278,7 @@ class Query(object):
def _get(self, key, ident=None, reload=False, lockmode=None):
lockmode = lockmode or self.lockmode
- if not reload and not self.always_refresh and lockmode == None:
+ if not reload and not self.always_refresh and lockmode is None:
try:
return self.session._get(key)
except KeyError:
@@ -301,7 +301,7 @@ class Query(object):
i += 1
try:
statement = self.compile(self._get_clause, lockmode=lockmode)
- return self._select_statement(statement, params=params, populate_existing=reload)[0]
+ return self._select_statement(statement, params=params, populate_existing=reload, version_check=(lockmode is not None))[0]
except IndexError:
return None
diff --git a/test/orm/objectstore.py b/test/orm/objectstore.py
index 4c35fee65..ec47749bd 100644
--- a/test/orm/objectstore.py
+++ b/test/orm/objectstore.py
@@ -150,7 +150,7 @@ class VersioningTest(SessionTest):
# a concurrent session has modified this, should throw
# an exception
s.flush()
- except exceptions.SQLAlchemyError, e:
+ except exceptions.ConcurrentModificationError, e:
#print e
success = True
assert success
@@ -166,10 +166,48 @@ class VersioningTest(SessionTest):
success = False
try:
s.flush()
- except exceptions.SQLAlchemyError, e:
+ except exceptions.ConcurrentModificationError, e:
#print e
success = True
assert success
+ def testversioncheck(self):
+ """test that query.with_lockmode performs a 'version check' on an already loaded instance"""
+ s1 = create_session()
+ class Foo(object):pass
+ assign_mapper(Foo, version_table, version_id_col=version_table.c.version_id)
+ f1s1 =Foo(value='f1', _sa_session=s1)
+ s1.flush()
+ s2 = create_session()
+ f1s2 = s2.query(Foo).get(f1s1.id)
+ f1s2.value='f1 new value'
+ s2.flush()
+ try:
+ # load, version is wrong
+ s1.query(Foo).with_lockmode('read').get(f1s1.id)
+ assert False
+ except exceptions.ConcurrentModificationError, e:
+ assert True
+ # reload it
+ s1.query(Foo).load(f1s1.id)
+ # now assert version OK
+ s1.query(Foo).with_lockmode('read').get(f1s1.id)
+
+ # assert brand new load is OK too
+ s1.clear()
+ s1.query(Foo).with_lockmode('read').get(f1s1.id)
+
+ def testnoversioncheck(self):
+ """test that query.with_lockmode works OK when the mapper has no version id col"""
+ s1 = create_session()
+ class Foo(object):pass
+ assign_mapper(Foo, version_table)
+ f1s1 =Foo(value='f1', _sa_session=s1)
+ f1s1.version_id=0
+ s1.flush()
+ s2 = create_session()
+ f1s2 = s2.query(Foo).with_lockmode('read').get(f1s1.id)
+ assert f1s2.id == f1s1.id
+ assert f1s2.value == f1s1.value
class UnicodeTest(SessionTest):
def setUpAll(self):