diff options
-rw-r--r-- | lib/sqlalchemy/dialects/oracle/base.py | 8 | ||||
-rw-r--r-- | lib/sqlalchemy/dialects/postgresql/base.py | 8 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/query.py | 4 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 4 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/selectable.py | 18 | ||||
-rw-r--r-- | test/dialect/mysql/test_compiler.py | 15 | ||||
-rw-r--r-- | test/dialect/postgresql/test_compiler.py | 55 | ||||
-rw-r--r-- | test/dialect/test_oracle.py | 43 | ||||
-rw-r--r-- | test/orm/test_lockmode.py | 8 | ||||
-rw-r--r-- | test/sql/test_compiler.py | 78 | ||||
-rw-r--r-- | test/sql/test_selectable.py | 54 |
11 files changed, 203 insertions, 92 deletions
diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py index a3c31b7cc..ba69c3d1f 100644 --- a/lib/sqlalchemy/dialects/oracle/base.py +++ b/lib/sqlalchemy/dialects/oracle/base.py @@ -666,15 +666,15 @@ class OracleCompiler(compiler.SQLCompiler): tmp = ' FOR UPDATE' - if select._for_update_arg.nowait: - tmp += " NOWAIT" - if select._for_update_arg.of: tmp += ' OF ' + ', '.join( - self._process(elem) for elem in + self.process(elem) for elem in select._for_update_arg.of ) + if select._for_update_arg.nowait: + tmp += " NOWAIT" + return tmp diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index 091fdeda2..69b0fb040 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -1020,17 +1020,17 @@ class PGCompiler(compiler.SQLCompiler): else: tmp = " FOR UPDATE" - if select._for_update_arg.nowait: - tmp += " NOWAIT" - if select._for_update_arg.of: # TODO: assuming simplistic c.table here tables = set(c.table for c in select._for_update_arg.of) tmp += " OF " + ", ".join( - self.process(table, asfrom=True) + self.process(table, ashint=True) for table in tables ) + if select._for_update_arg.nowait: + tmp += " NOWAIT" + return tmp def returning_clause(self, stmt, returning_cols): diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index f0d9a47d6..173ad038e 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -1124,10 +1124,10 @@ class Query(object): self._execution_options = self._execution_options.union(kwargs) @_generative() - def with_lockmode(self, mode, of=None): + def with_lockmode(self, mode): """Return a new Query object with the specified locking mode. - .. deprecated:: 0.9.0b2 superseded by :meth:`.Query.for_update`. + .. deprecated:: 0.9.0b2 superseded by :meth:`.Query.with_for_update`. :param mode: a string representing the desired locking mode. A corresponding :meth:`~sqlalchemy.orm.query.LockmodeArgs` object diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 0fc99897e..3ba3957d6 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -1513,9 +1513,11 @@ class SQLCompiler(Compiled): text += self.order_by_clause(select, order_by_select=order_by_select, **kwargs) + if select._limit is not None or select._offset is not None: text += self.limit_clause(select) - if select._for_update_arg: + + if select._for_update_arg is not None: text += self.for_update_clause(select) if self.ctes and \ diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index e49c10001..01c803f3b 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -1151,7 +1151,7 @@ class TableClause(Immutable, FromClause): return [self] -class ForUpdateArg(object): +class ForUpdateArg(ClauseElement): @classmethod def parse_legacy_select(self, arg): @@ -1185,6 +1185,8 @@ class ForUpdateArg(object): read = True elif arg == 'read_nowait': read = nowait = True + elif arg is not True: + raise exc.ArgumentError("Unknown for_update argument: %r" % arg) return ForUpdateArg(read=read, nowait=nowait) @@ -1195,9 +1197,13 @@ class ForUpdateArg(object): elif self.read and self.nowait: return "read_nowait" elif self.nowait: - return "update_nowait" + return "nowait" else: - return "update" + return True + + def _copy_internals(self, clone=_clone, **kw): + if self.of is not None: + self.of = [clone(col, **kw) for col in self.of] def __init__(self, nowait=False, read=False, of=None): """Represents arguments specified to :meth:`.Select.for_update`. @@ -1208,7 +1214,7 @@ class ForUpdateArg(object): self.nowait = nowait self.read = read if of is not None: - self.of = [_only_column_elements(of, "of") + self.of = [_only_column_elements(elem, "of") for elem in util.to_list(of)] else: self.of = None @@ -1770,7 +1776,7 @@ class CompoundSelect(SelectBase): self.selects = [clone(s, **kw) for s in self.selects] if hasattr(self, '_col_map'): del self._col_map - for attr in ('_order_by_clause', '_group_by_clause'): + for attr in ('_order_by_clause', '_group_by_clause', '_for_update_arg'): if getattr(self, attr) is not None: setattr(self, attr, clone(getattr(self, attr), **kw)) @@ -2255,7 +2261,7 @@ class Select(HasPrefixes, SelectBase): # present here. self._raw_columns = [clone(c, **kw) for c in self._raw_columns] for attr in '_whereclause', '_having', '_order_by_clause', \ - '_group_by_clause': + '_group_by_clause', '_for_update_arg': if getattr(self, attr) is not None: setattr(self, attr, clone(getattr(self, attr), **kw)) diff --git a/test/dialect/mysql/test_compiler.py b/test/dialect/mysql/test_compiler.py index a50c6a901..46e8bfb82 100644 --- a/test/dialect/mysql/test_compiler.py +++ b/test/dialect/mysql/test_compiler.py @@ -6,6 +6,7 @@ from sqlalchemy import sql, exc, schema, types as sqltypes from sqlalchemy.dialects.mysql import base as mysql from sqlalchemy.testing import fixtures, AssertsCompiledSQL from sqlalchemy import testing +from sqlalchemy.sql import table, column class CompileTest(fixtures.TestBase, AssertsCompiledSQL): @@ -131,6 +132,20 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): schema.CreateTable(t2).compile, dialect=mysql.dialect() ) + def test_for_update(self): + table1 = table('mytable', + column('myid'), column('name'), column('description')) + + self.assert_compile( + table1.select(table1.c.myid == 7).with_for_update(), + "SELECT mytable.myid, mytable.name, mytable.description " + "FROM mytable WHERE mytable.myid = %s FOR UPDATE") + + self.assert_compile( + table1.select(table1.c.myid == 7).with_for_update(read=True), + "SELECT mytable.myid, mytable.name, mytable.description " + "FROM mytable WHERE mytable.myid = %s LOCK IN SHARE MODE") + class SQLTest(fixtures.TestBase, AssertsCompiledSQL): """Tests MySQL-dialect specific compilation.""" diff --git a/test/dialect/postgresql/test_compiler.py b/test/dialect/postgresql/test_compiler.py index 76fd9d907..05963e51c 100644 --- a/test/dialect/postgresql/test_compiler.py +++ b/test/dialect/postgresql/test_compiler.py @@ -249,6 +249,61 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): 'SUBSTRING(%(substring_1)s FROM %(substring_2)s)') + def test_for_update(self): + table1 = table('mytable', + column('myid'), column('name'), column('description')) + + self.assert_compile( + table1.select(table1.c.myid == 7).with_for_update(), + "SELECT mytable.myid, mytable.name, mytable.description " + "FROM mytable WHERE mytable.myid = %(myid_1)s FOR UPDATE") + + self.assert_compile( + table1.select(table1.c.myid == 7).with_for_update(nowait=True), + "SELECT mytable.myid, mytable.name, mytable.description " + "FROM mytable WHERE mytable.myid = %(myid_1)s FOR UPDATE NOWAIT") + + self.assert_compile( + table1.select(table1.c.myid == 7).with_for_update(read=True), + "SELECT mytable.myid, mytable.name, mytable.description " + "FROM mytable WHERE mytable.myid = %(myid_1)s FOR SHARE") + + self.assert_compile( + table1.select(table1.c.myid == 7). + with_for_update(read=True, nowait=True), + "SELECT mytable.myid, mytable.name, mytable.description " + "FROM mytable WHERE mytable.myid = %(myid_1)s FOR SHARE NOWAIT") + + self.assert_compile( + table1.select(table1.c.myid == 7). + with_for_update(of=table1.c.myid), + "SELECT mytable.myid, mytable.name, mytable.description " + "FROM mytable WHERE mytable.myid = %(myid_1)s " + "FOR UPDATE OF mytable") + + self.assert_compile( + table1.select(table1.c.myid == 7). + with_for_update(read=True, nowait=True, of=table1.c.myid), + "SELECT mytable.myid, mytable.name, mytable.description " + "FROM mytable WHERE mytable.myid = %(myid_1)s " + "FOR SHARE OF mytable NOWAIT") + + self.assert_compile( + table1.select(table1.c.myid == 7). + with_for_update(read=True, nowait=True, + of=[table1.c.myid, table1.c.name]), + "SELECT mytable.myid, mytable.name, mytable.description " + "FROM mytable WHERE mytable.myid = %(myid_1)s " + "FOR SHARE OF mytable NOWAIT") + + ta = table1.alias() + self.assert_compile( + ta.select(ta.c.myid == 7). + with_for_update(of=[ta.c.myid, ta.c.name]), + "SELECT mytable_1.myid, mytable_1.name, mytable_1.description " + "FROM mytable AS mytable_1 " + "WHERE mytable_1.myid = %(myid_1)s FOR UPDATE OF mytable_1" + ) def test_reserved_words(self): diff --git a/test/dialect/test_oracle.py b/test/dialect/test_oracle.py index 185bfb883..3af57c50b 100644 --- a/test/dialect/test_oracle.py +++ b/test/dialect/test_oracle.py @@ -217,6 +217,49 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): ':ROWNUM_1) WHERE ora_rn > :ora_rn_1 FOR ' 'UPDATE') + def test_for_update(self): + table1 = table('mytable', + column('myid'), column('name'), column('description')) + + self.assert_compile( + table1.select(table1.c.myid == 7).with_for_update(), + "SELECT mytable.myid, mytable.name, mytable.description " + "FROM mytable WHERE mytable.myid = :myid_1 FOR UPDATE") + + self.assert_compile( + table1.select(table1.c.myid == 7).with_for_update(of=table1.c.myid), + "SELECT mytable.myid, mytable.name, mytable.description " + "FROM mytable WHERE mytable.myid = :myid_1 FOR UPDATE OF mytable.myid") + + self.assert_compile( + table1.select(table1.c.myid == 7).with_for_update(nowait=True), + "SELECT mytable.myid, mytable.name, mytable.description " + "FROM mytable WHERE mytable.myid = :myid_1 FOR UPDATE NOWAIT") + + self.assert_compile( + table1.select(table1.c.myid == 7). + with_for_update(nowait=True, of=table1.c.myid), + "SELECT mytable.myid, mytable.name, mytable.description " + "FROM mytable WHERE mytable.myid = :myid_1 " + "FOR UPDATE OF mytable.myid NOWAIT") + + self.assert_compile( + table1.select(table1.c.myid == 7). + with_for_update(nowait=True, of=[table1.c.myid, table1.c.name]), + "SELECT mytable.myid, mytable.name, mytable.description " + "FROM mytable WHERE mytable.myid = :myid_1 FOR UPDATE OF " + "mytable.myid, mytable.name NOWAIT") + + ta = table1.alias() + self.assert_compile( + ta.select(ta.c.myid == 7). + with_for_update(of=[ta.c.myid, ta.c.name]), + "SELECT mytable_1.myid, mytable_1.name, mytable_1.description " + "FROM mytable mytable_1 " + "WHERE mytable_1.myid = :myid_1 FOR UPDATE OF " + "mytable_1.myid, mytable_1.name" + ) + def test_limit_preserves_typing_information(self): class MyType(TypeDecorator): impl = Integer diff --git a/test/orm/test_lockmode.py b/test/orm/test_lockmode.py index f9950c261..3a8379be9 100644 --- a/test/orm/test_lockmode.py +++ b/test/orm/test_lockmode.py @@ -76,7 +76,7 @@ class LockModeTest(_fixtures.FixtureTest, AssertsCompiledSQL): def test_postgres_update_of(self): User = self.classes.User sess = Session() - self.assert_compile(sess.query(User.id).with_lockmode('update', of=User.id), + self.assert_compile(sess.query(User.id).for_update(of=User.id), "SELECT users.id AS users_id FROM users FOR UPDATE OF users", dialect=postgresql.dialect() ) @@ -84,8 +84,8 @@ class LockModeTest(_fixtures.FixtureTest, AssertsCompiledSQL): def test_postgres_update_of_list(self): User = self.classes.User sess = Session() - self.assert_compile(sess.query(User.id).with_lockmode('update', of=[User.id, User.id, User.id]), - "SELECT users.id AS users_id FROM users FOR UPDATE OF users, users, users", + self.assert_compile(sess.query(User.id).for_update(of=[User.id, User.id, User.id]), + "SELECT users.id AS users_id FROM users FOR UPDATE OF users", dialect=postgresql.dialect() ) @@ -93,7 +93,7 @@ class LockModeTest(_fixtures.FixtureTest, AssertsCompiledSQL): def test_postgres_update_nowait(self): User = self.classes.User sess = Session() - self.assert_compile(sess.query(User.id).with_lockmode('update_nowait'), + self.assert_compile(sess.query(User.id).for_updatewith_lockmode('update_nowait'), "SELECT users.id AS users_id FROM users FOR UPDATE NOWAIT", dialect=postgresql.dialect() ) diff --git a/test/sql/test_compiler.py b/test/sql/test_compiler.py index 26cd30026..f1f852ddc 100644 --- a/test/sql/test_compiler.py +++ b/test/sql/test_compiler.py @@ -1045,86 +1045,22 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL): def test_for_update(self): self.assert_compile( - table1.select(table1.c.myid == 7, for_update=True), + table1.select(table1.c.myid == 7).with_for_update(), "SELECT mytable.myid, mytable.name, mytable.description " "FROM mytable WHERE mytable.myid = :myid_1 FOR UPDATE") - self.assert_compile( - table1.select(table1.c.myid == 7, for_update=False), - "SELECT mytable.myid, mytable.name, mytable.description " - "FROM mytable WHERE mytable.myid = :myid_1") - # not supported by dialect, should just use update self.assert_compile( - table1.select(table1.c.myid == 7, for_update='nowait'), - "SELECT mytable.myid, mytable.name, mytable.description " - "FROM mytable WHERE mytable.myid = :myid_1 FOR UPDATE") - - # unknown lock mode - self.assert_compile( - table1.select(table1.c.myid == 7, for_update='unknown_mode'), + table1.select(table1.c.myid == 7).with_for_update(nowait=True), "SELECT mytable.myid, mytable.name, mytable.description " "FROM mytable WHERE mytable.myid = :myid_1 FOR UPDATE") - # ----- mysql - - self.assert_compile( - table1.select(table1.c.myid == 7, for_update=True), - "SELECT mytable.myid, mytable.name, mytable.description " - "FROM mytable WHERE mytable.myid = %s FOR UPDATE", - dialect=mysql.dialect()) - - self.assert_compile( - table1.select(table1.c.myid == 7, for_update="read"), - "SELECT mytable.myid, mytable.name, mytable.description " - "FROM mytable WHERE mytable.myid = %s LOCK IN SHARE MODE", - dialect=mysql.dialect()) - - # ----- oracle - - self.assert_compile( - table1.select(table1.c.myid == 7, for_update=True), - "SELECT mytable.myid, mytable.name, mytable.description " - "FROM mytable WHERE mytable.myid = :myid_1 FOR UPDATE", - dialect=oracle.dialect()) - - self.assert_compile( - table1.select(table1.c.myid == 7, for_update="nowait"), - "SELECT mytable.myid, mytable.name, mytable.description " - "FROM mytable WHERE mytable.myid = :myid_1 FOR UPDATE NOWAIT", - dialect=oracle.dialect()) - - # ----- postgresql - - self.assert_compile( - table1.select(table1.c.myid == 7, for_update=True), - "SELECT mytable.myid, mytable.name, mytable.description " - "FROM mytable WHERE mytable.myid = %(myid_1)s FOR UPDATE", - dialect=postgresql.dialect()) - - self.assert_compile( - table1.select(table1.c.myid == 7, for_update="nowait"), - "SELECT mytable.myid, mytable.name, mytable.description " - "FROM mytable WHERE mytable.myid = %(myid_1)s FOR UPDATE NOWAIT", - dialect=postgresql.dialect()) - - self.assert_compile( - table1.select(table1.c.myid == 7, for_update="read"), - "SELECT mytable.myid, mytable.name, mytable.description " - "FROM mytable WHERE mytable.myid = %(myid_1)s FOR SHARE", - dialect=postgresql.dialect()) - - self.assert_compile( - table1.select(table1.c.myid == 7, for_update="read_nowait"), - "SELECT mytable.myid, mytable.name, mytable.description " - "FROM mytable WHERE mytable.myid = %(myid_1)s FOR SHARE NOWAIT", - dialect=postgresql.dialect()) + assert_raises_message( + exc.ArgumentError, + "Unknown for_update argument: 'unknown_mode'", + table1.select, table1.c.myid == 7, for_update='unknown_mode' + ) - self.assert_compile( - table1.select(table1.c.myid == 7).with_for_update(of=table1.c.myid), - "SELECT mytable.myid, mytable.name, mytable.description " - "FROM mytable WHERE mytable.myid = %(myid_1)s FOR UPDATE OF mytable", - dialect=postgresql.dialect()) def test_alias(self): # test the alias for a table1. column names stay the same, diff --git a/test/sql/test_selectable.py b/test/sql/test_selectable.py index 0fc7a0ed0..66cdd87c2 100644 --- a/test/sql/test_selectable.py +++ b/test/sql/test_selectable.py @@ -1903,3 +1903,57 @@ class WithLabelsTest(fixtures.TestBase): ['t1_x', 't2_x'] ) self._assert_result_keys(sel, ['t1_a', 't2_b']) + +class ForUpdateTest(fixtures.TestBase, AssertsCompiledSQL): + __dialect__ = "default" + + def _assert_legacy(self, leg, read=False, nowait=False): + t = table('t', column('c')) + s1 = select([t], for_update=leg) + if leg is False: + assert s1._for_update_arg is None + assert s1.for_update is None + else: + eq_( + s1._for_update_arg.read, read + ) + eq_( + s1._for_update_arg.nowait, nowait + ) + eq_(s1.for_update, leg) + + def test_false_legacy(self): + self._assert_legacy(False) + + def test_plain_true_legacy(self): + self._assert_legacy(True) + + def test_read_legacy(self): + self._assert_legacy("read", read=True) + + def test_nowait_legacy(self): + self._assert_legacy("nowait", nowait=True) + + def test_read_nowait_legacy(self): + self._assert_legacy("read_nowait", read=True, nowait=True) + + def test_basic_clone(self): + t = table('t', column('c')) + s = select([t]).with_for_update(read=True, of=t.c.c) + s2 = visitors.ReplacingCloningVisitor().traverse(s) + assert s2._for_update_arg is not s._for_update_arg + eq_(s2._for_update_arg.read, True) + eq_(s2._for_update_arg.of, [t.c.c]) + self.assert_compile(s2, + "SELECT t.c FROM t FOR SHARE OF t", + dialect="postgresql") + + def test_adapt(self): + t = table('t', column('c')) + s = select([t]).with_for_update(read=True, of=t.c.c) + a = t.alias() + s2 = sql_util.ClauseAdapter(a).traverse(s) + eq_(s2._for_update_arg.of, [a.c.c]) + self.assert_compile(s2, + "SELECT t_1.c FROM t AS t_1 FOR SHARE OF t_1", + dialect="postgresql") |