summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--lib/sqlalchemy/dialects/oracle/base.py8
-rw-r--r--lib/sqlalchemy/dialects/postgresql/base.py8
-rw-r--r--lib/sqlalchemy/orm/query.py4
-rw-r--r--lib/sqlalchemy/sql/compiler.py4
-rw-r--r--lib/sqlalchemy/sql/selectable.py18
-rw-r--r--test/dialect/mysql/test_compiler.py15
-rw-r--r--test/dialect/postgresql/test_compiler.py55
-rw-r--r--test/dialect/test_oracle.py43
-rw-r--r--test/orm/test_lockmode.py8
-rw-r--r--test/sql/test_compiler.py78
-rw-r--r--test/sql/test_selectable.py54
11 files changed, 203 insertions, 92 deletions
diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py
index a3c31b7cc..ba69c3d1f 100644
--- a/lib/sqlalchemy/dialects/oracle/base.py
+++ b/lib/sqlalchemy/dialects/oracle/base.py
@@ -666,15 +666,15 @@ class OracleCompiler(compiler.SQLCompiler):
tmp = ' FOR UPDATE'
- if select._for_update_arg.nowait:
- tmp += " NOWAIT"
-
if select._for_update_arg.of:
tmp += ' OF ' + ', '.join(
- self._process(elem) for elem in
+ self.process(elem) for elem in
select._for_update_arg.of
)
+ if select._for_update_arg.nowait:
+ tmp += " NOWAIT"
+
return tmp
diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py
index 091fdeda2..69b0fb040 100644
--- a/lib/sqlalchemy/dialects/postgresql/base.py
+++ b/lib/sqlalchemy/dialects/postgresql/base.py
@@ -1020,17 +1020,17 @@ class PGCompiler(compiler.SQLCompiler):
else:
tmp = " FOR UPDATE"
- if select._for_update_arg.nowait:
- tmp += " NOWAIT"
-
if select._for_update_arg.of:
# TODO: assuming simplistic c.table here
tables = set(c.table for c in select._for_update_arg.of)
tmp += " OF " + ", ".join(
- self.process(table, asfrom=True)
+ self.process(table, ashint=True)
for table in tables
)
+ if select._for_update_arg.nowait:
+ tmp += " NOWAIT"
+
return tmp
def returning_clause(self, stmt, returning_cols):
diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py
index f0d9a47d6..173ad038e 100644
--- a/lib/sqlalchemy/orm/query.py
+++ b/lib/sqlalchemy/orm/query.py
@@ -1124,10 +1124,10 @@ class Query(object):
self._execution_options = self._execution_options.union(kwargs)
@_generative()
- def with_lockmode(self, mode, of=None):
+ def with_lockmode(self, mode):
"""Return a new Query object with the specified locking mode.
- .. deprecated:: 0.9.0b2 superseded by :meth:`.Query.for_update`.
+ .. deprecated:: 0.9.0b2 superseded by :meth:`.Query.with_for_update`.
:param mode: a string representing the desired locking mode. A
corresponding :meth:`~sqlalchemy.orm.query.LockmodeArgs` object
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 0fc99897e..3ba3957d6 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -1513,9 +1513,11 @@ class SQLCompiler(Compiled):
text += self.order_by_clause(select,
order_by_select=order_by_select, **kwargs)
+
if select._limit is not None or select._offset is not None:
text += self.limit_clause(select)
- if select._for_update_arg:
+
+ if select._for_update_arg is not None:
text += self.for_update_clause(select)
if self.ctes and \
diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py
index e49c10001..01c803f3b 100644
--- a/lib/sqlalchemy/sql/selectable.py
+++ b/lib/sqlalchemy/sql/selectable.py
@@ -1151,7 +1151,7 @@ class TableClause(Immutable, FromClause):
return [self]
-class ForUpdateArg(object):
+class ForUpdateArg(ClauseElement):
@classmethod
def parse_legacy_select(self, arg):
@@ -1185,6 +1185,8 @@ class ForUpdateArg(object):
read = True
elif arg == 'read_nowait':
read = nowait = True
+ elif arg is not True:
+ raise exc.ArgumentError("Unknown for_update argument: %r" % arg)
return ForUpdateArg(read=read, nowait=nowait)
@@ -1195,9 +1197,13 @@ class ForUpdateArg(object):
elif self.read and self.nowait:
return "read_nowait"
elif self.nowait:
- return "update_nowait"
+ return "nowait"
else:
- return "update"
+ return True
+
+ def _copy_internals(self, clone=_clone, **kw):
+ if self.of is not None:
+ self.of = [clone(col, **kw) for col in self.of]
def __init__(self, nowait=False, read=False, of=None):
"""Represents arguments specified to :meth:`.Select.for_update`.
@@ -1208,7 +1214,7 @@ class ForUpdateArg(object):
self.nowait = nowait
self.read = read
if of is not None:
- self.of = [_only_column_elements(of, "of")
+ self.of = [_only_column_elements(elem, "of")
for elem in util.to_list(of)]
else:
self.of = None
@@ -1770,7 +1776,7 @@ class CompoundSelect(SelectBase):
self.selects = [clone(s, **kw) for s in self.selects]
if hasattr(self, '_col_map'):
del self._col_map
- for attr in ('_order_by_clause', '_group_by_clause'):
+ for attr in ('_order_by_clause', '_group_by_clause', '_for_update_arg'):
if getattr(self, attr) is not None:
setattr(self, attr, clone(getattr(self, attr), **kw))
@@ -2255,7 +2261,7 @@ class Select(HasPrefixes, SelectBase):
# present here.
self._raw_columns = [clone(c, **kw) for c in self._raw_columns]
for attr in '_whereclause', '_having', '_order_by_clause', \
- '_group_by_clause':
+ '_group_by_clause', '_for_update_arg':
if getattr(self, attr) is not None:
setattr(self, attr, clone(getattr(self, attr), **kw))
diff --git a/test/dialect/mysql/test_compiler.py b/test/dialect/mysql/test_compiler.py
index a50c6a901..46e8bfb82 100644
--- a/test/dialect/mysql/test_compiler.py
+++ b/test/dialect/mysql/test_compiler.py
@@ -6,6 +6,7 @@ from sqlalchemy import sql, exc, schema, types as sqltypes
from sqlalchemy.dialects.mysql import base as mysql
from sqlalchemy.testing import fixtures, AssertsCompiledSQL
from sqlalchemy import testing
+from sqlalchemy.sql import table, column
class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
@@ -131,6 +132,20 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
schema.CreateTable(t2).compile, dialect=mysql.dialect()
)
+ def test_for_update(self):
+ table1 = table('mytable',
+ column('myid'), column('name'), column('description'))
+
+ self.assert_compile(
+ table1.select(table1.c.myid == 7).with_for_update(),
+ "SELECT mytable.myid, mytable.name, mytable.description "
+ "FROM mytable WHERE mytable.myid = %s FOR UPDATE")
+
+ self.assert_compile(
+ table1.select(table1.c.myid == 7).with_for_update(read=True),
+ "SELECT mytable.myid, mytable.name, mytable.description "
+ "FROM mytable WHERE mytable.myid = %s LOCK IN SHARE MODE")
+
class SQLTest(fixtures.TestBase, AssertsCompiledSQL):
"""Tests MySQL-dialect specific compilation."""
diff --git a/test/dialect/postgresql/test_compiler.py b/test/dialect/postgresql/test_compiler.py
index 76fd9d907..05963e51c 100644
--- a/test/dialect/postgresql/test_compiler.py
+++ b/test/dialect/postgresql/test_compiler.py
@@ -249,6 +249,61 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
'SUBSTRING(%(substring_1)s FROM %(substring_2)s)')
+ def test_for_update(self):
+ table1 = table('mytable',
+ column('myid'), column('name'), column('description'))
+
+ self.assert_compile(
+ table1.select(table1.c.myid == 7).with_for_update(),
+ "SELECT mytable.myid, mytable.name, mytable.description "
+ "FROM mytable WHERE mytable.myid = %(myid_1)s FOR UPDATE")
+
+ self.assert_compile(
+ table1.select(table1.c.myid == 7).with_for_update(nowait=True),
+ "SELECT mytable.myid, mytable.name, mytable.description "
+ "FROM mytable WHERE mytable.myid = %(myid_1)s FOR UPDATE NOWAIT")
+
+ self.assert_compile(
+ table1.select(table1.c.myid == 7).with_for_update(read=True),
+ "SELECT mytable.myid, mytable.name, mytable.description "
+ "FROM mytable WHERE mytable.myid = %(myid_1)s FOR SHARE")
+
+ self.assert_compile(
+ table1.select(table1.c.myid == 7).
+ with_for_update(read=True, nowait=True),
+ "SELECT mytable.myid, mytable.name, mytable.description "
+ "FROM mytable WHERE mytable.myid = %(myid_1)s FOR SHARE NOWAIT")
+
+ self.assert_compile(
+ table1.select(table1.c.myid == 7).
+ with_for_update(of=table1.c.myid),
+ "SELECT mytable.myid, mytable.name, mytable.description "
+ "FROM mytable WHERE mytable.myid = %(myid_1)s "
+ "FOR UPDATE OF mytable")
+
+ self.assert_compile(
+ table1.select(table1.c.myid == 7).
+ with_for_update(read=True, nowait=True, of=table1.c.myid),
+ "SELECT mytable.myid, mytable.name, mytable.description "
+ "FROM mytable WHERE mytable.myid = %(myid_1)s "
+ "FOR SHARE OF mytable NOWAIT")
+
+ self.assert_compile(
+ table1.select(table1.c.myid == 7).
+ with_for_update(read=True, nowait=True,
+ of=[table1.c.myid, table1.c.name]),
+ "SELECT mytable.myid, mytable.name, mytable.description "
+ "FROM mytable WHERE mytable.myid = %(myid_1)s "
+ "FOR SHARE OF mytable NOWAIT")
+
+ ta = table1.alias()
+ self.assert_compile(
+ ta.select(ta.c.myid == 7).
+ with_for_update(of=[ta.c.myid, ta.c.name]),
+ "SELECT mytable_1.myid, mytable_1.name, mytable_1.description "
+ "FROM mytable AS mytable_1 "
+ "WHERE mytable_1.myid = %(myid_1)s FOR UPDATE OF mytable_1"
+ )
def test_reserved_words(self):
diff --git a/test/dialect/test_oracle.py b/test/dialect/test_oracle.py
index 185bfb883..3af57c50b 100644
--- a/test/dialect/test_oracle.py
+++ b/test/dialect/test_oracle.py
@@ -217,6 +217,49 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
':ROWNUM_1) WHERE ora_rn > :ora_rn_1 FOR '
'UPDATE')
+ def test_for_update(self):
+ table1 = table('mytable',
+ column('myid'), column('name'), column('description'))
+
+ self.assert_compile(
+ table1.select(table1.c.myid == 7).with_for_update(),
+ "SELECT mytable.myid, mytable.name, mytable.description "
+ "FROM mytable WHERE mytable.myid = :myid_1 FOR UPDATE")
+
+ self.assert_compile(
+ table1.select(table1.c.myid == 7).with_for_update(of=table1.c.myid),
+ "SELECT mytable.myid, mytable.name, mytable.description "
+ "FROM mytable WHERE mytable.myid = :myid_1 FOR UPDATE OF mytable.myid")
+
+ self.assert_compile(
+ table1.select(table1.c.myid == 7).with_for_update(nowait=True),
+ "SELECT mytable.myid, mytable.name, mytable.description "
+ "FROM mytable WHERE mytable.myid = :myid_1 FOR UPDATE NOWAIT")
+
+ self.assert_compile(
+ table1.select(table1.c.myid == 7).
+ with_for_update(nowait=True, of=table1.c.myid),
+ "SELECT mytable.myid, mytable.name, mytable.description "
+ "FROM mytable WHERE mytable.myid = :myid_1 "
+ "FOR UPDATE OF mytable.myid NOWAIT")
+
+ self.assert_compile(
+ table1.select(table1.c.myid == 7).
+ with_for_update(nowait=True, of=[table1.c.myid, table1.c.name]),
+ "SELECT mytable.myid, mytable.name, mytable.description "
+ "FROM mytable WHERE mytable.myid = :myid_1 FOR UPDATE OF "
+ "mytable.myid, mytable.name NOWAIT")
+
+ ta = table1.alias()
+ self.assert_compile(
+ ta.select(ta.c.myid == 7).
+ with_for_update(of=[ta.c.myid, ta.c.name]),
+ "SELECT mytable_1.myid, mytable_1.name, mytable_1.description "
+ "FROM mytable mytable_1 "
+ "WHERE mytable_1.myid = :myid_1 FOR UPDATE OF "
+ "mytable_1.myid, mytable_1.name"
+ )
+
def test_limit_preserves_typing_information(self):
class MyType(TypeDecorator):
impl = Integer
diff --git a/test/orm/test_lockmode.py b/test/orm/test_lockmode.py
index f9950c261..3a8379be9 100644
--- a/test/orm/test_lockmode.py
+++ b/test/orm/test_lockmode.py
@@ -76,7 +76,7 @@ class LockModeTest(_fixtures.FixtureTest, AssertsCompiledSQL):
def test_postgres_update_of(self):
User = self.classes.User
sess = Session()
- self.assert_compile(sess.query(User.id).with_lockmode('update', of=User.id),
+ self.assert_compile(sess.query(User.id).for_update(of=User.id),
"SELECT users.id AS users_id FROM users FOR UPDATE OF users",
dialect=postgresql.dialect()
)
@@ -84,8 +84,8 @@ class LockModeTest(_fixtures.FixtureTest, AssertsCompiledSQL):
def test_postgres_update_of_list(self):
User = self.classes.User
sess = Session()
- self.assert_compile(sess.query(User.id).with_lockmode('update', of=[User.id, User.id, User.id]),
- "SELECT users.id AS users_id FROM users FOR UPDATE OF users, users, users",
+ self.assert_compile(sess.query(User.id).for_update(of=[User.id, User.id, User.id]),
+ "SELECT users.id AS users_id FROM users FOR UPDATE OF users",
dialect=postgresql.dialect()
)
@@ -93,7 +93,7 @@ class LockModeTest(_fixtures.FixtureTest, AssertsCompiledSQL):
def test_postgres_update_nowait(self):
User = self.classes.User
sess = Session()
- self.assert_compile(sess.query(User.id).with_lockmode('update_nowait'),
+ self.assert_compile(sess.query(User.id).for_updatewith_lockmode('update_nowait'),
"SELECT users.id AS users_id FROM users FOR UPDATE NOWAIT",
dialect=postgresql.dialect()
)
diff --git a/test/sql/test_compiler.py b/test/sql/test_compiler.py
index 26cd30026..f1f852ddc 100644
--- a/test/sql/test_compiler.py
+++ b/test/sql/test_compiler.py
@@ -1045,86 +1045,22 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL):
def test_for_update(self):
self.assert_compile(
- table1.select(table1.c.myid == 7, for_update=True),
+ table1.select(table1.c.myid == 7).with_for_update(),
"SELECT mytable.myid, mytable.name, mytable.description "
"FROM mytable WHERE mytable.myid = :myid_1 FOR UPDATE")
- self.assert_compile(
- table1.select(table1.c.myid == 7, for_update=False),
- "SELECT mytable.myid, mytable.name, mytable.description "
- "FROM mytable WHERE mytable.myid = :myid_1")
-
# not supported by dialect, should just use update
self.assert_compile(
- table1.select(table1.c.myid == 7, for_update='nowait'),
- "SELECT mytable.myid, mytable.name, mytable.description "
- "FROM mytable WHERE mytable.myid = :myid_1 FOR UPDATE")
-
- # unknown lock mode
- self.assert_compile(
- table1.select(table1.c.myid == 7, for_update='unknown_mode'),
+ table1.select(table1.c.myid == 7).with_for_update(nowait=True),
"SELECT mytable.myid, mytable.name, mytable.description "
"FROM mytable WHERE mytable.myid = :myid_1 FOR UPDATE")
- # ----- mysql
-
- self.assert_compile(
- table1.select(table1.c.myid == 7, for_update=True),
- "SELECT mytable.myid, mytable.name, mytable.description "
- "FROM mytable WHERE mytable.myid = %s FOR UPDATE",
- dialect=mysql.dialect())
-
- self.assert_compile(
- table1.select(table1.c.myid == 7, for_update="read"),
- "SELECT mytable.myid, mytable.name, mytable.description "
- "FROM mytable WHERE mytable.myid = %s LOCK IN SHARE MODE",
- dialect=mysql.dialect())
-
- # ----- oracle
-
- self.assert_compile(
- table1.select(table1.c.myid == 7, for_update=True),
- "SELECT mytable.myid, mytable.name, mytable.description "
- "FROM mytable WHERE mytable.myid = :myid_1 FOR UPDATE",
- dialect=oracle.dialect())
-
- self.assert_compile(
- table1.select(table1.c.myid == 7, for_update="nowait"),
- "SELECT mytable.myid, mytable.name, mytable.description "
- "FROM mytable WHERE mytable.myid = :myid_1 FOR UPDATE NOWAIT",
- dialect=oracle.dialect())
-
- # ----- postgresql
-
- self.assert_compile(
- table1.select(table1.c.myid == 7, for_update=True),
- "SELECT mytable.myid, mytable.name, mytable.description "
- "FROM mytable WHERE mytable.myid = %(myid_1)s FOR UPDATE",
- dialect=postgresql.dialect())
-
- self.assert_compile(
- table1.select(table1.c.myid == 7, for_update="nowait"),
- "SELECT mytable.myid, mytable.name, mytable.description "
- "FROM mytable WHERE mytable.myid = %(myid_1)s FOR UPDATE NOWAIT",
- dialect=postgresql.dialect())
-
- self.assert_compile(
- table1.select(table1.c.myid == 7, for_update="read"),
- "SELECT mytable.myid, mytable.name, mytable.description "
- "FROM mytable WHERE mytable.myid = %(myid_1)s FOR SHARE",
- dialect=postgresql.dialect())
-
- self.assert_compile(
- table1.select(table1.c.myid == 7, for_update="read_nowait"),
- "SELECT mytable.myid, mytable.name, mytable.description "
- "FROM mytable WHERE mytable.myid = %(myid_1)s FOR SHARE NOWAIT",
- dialect=postgresql.dialect())
+ assert_raises_message(
+ exc.ArgumentError,
+ "Unknown for_update argument: 'unknown_mode'",
+ table1.select, table1.c.myid == 7, for_update='unknown_mode'
+ )
- self.assert_compile(
- table1.select(table1.c.myid == 7).with_for_update(of=table1.c.myid),
- "SELECT mytable.myid, mytable.name, mytable.description "
- "FROM mytable WHERE mytable.myid = %(myid_1)s FOR UPDATE OF mytable",
- dialect=postgresql.dialect())
def test_alias(self):
# test the alias for a table1. column names stay the same,
diff --git a/test/sql/test_selectable.py b/test/sql/test_selectable.py
index 0fc7a0ed0..66cdd87c2 100644
--- a/test/sql/test_selectable.py
+++ b/test/sql/test_selectable.py
@@ -1903,3 +1903,57 @@ class WithLabelsTest(fixtures.TestBase):
['t1_x', 't2_x']
)
self._assert_result_keys(sel, ['t1_a', 't2_b'])
+
+class ForUpdateTest(fixtures.TestBase, AssertsCompiledSQL):
+ __dialect__ = "default"
+
+ def _assert_legacy(self, leg, read=False, nowait=False):
+ t = table('t', column('c'))
+ s1 = select([t], for_update=leg)
+ if leg is False:
+ assert s1._for_update_arg is None
+ assert s1.for_update is None
+ else:
+ eq_(
+ s1._for_update_arg.read, read
+ )
+ eq_(
+ s1._for_update_arg.nowait, nowait
+ )
+ eq_(s1.for_update, leg)
+
+ def test_false_legacy(self):
+ self._assert_legacy(False)
+
+ def test_plain_true_legacy(self):
+ self._assert_legacy(True)
+
+ def test_read_legacy(self):
+ self._assert_legacy("read", read=True)
+
+ def test_nowait_legacy(self):
+ self._assert_legacy("nowait", nowait=True)
+
+ def test_read_nowait_legacy(self):
+ self._assert_legacy("read_nowait", read=True, nowait=True)
+
+ def test_basic_clone(self):
+ t = table('t', column('c'))
+ s = select([t]).with_for_update(read=True, of=t.c.c)
+ s2 = visitors.ReplacingCloningVisitor().traverse(s)
+ assert s2._for_update_arg is not s._for_update_arg
+ eq_(s2._for_update_arg.read, True)
+ eq_(s2._for_update_arg.of, [t.c.c])
+ self.assert_compile(s2,
+ "SELECT t.c FROM t FOR SHARE OF t",
+ dialect="postgresql")
+
+ def test_adapt(self):
+ t = table('t', column('c'))
+ s = select([t]).with_for_update(read=True, of=t.c.c)
+ a = t.alias()
+ s2 = sql_util.ClauseAdapter(a).traverse(s)
+ eq_(s2._for_update_arg.of, [a.c.c])
+ self.assert_compile(s2,
+ "SELECT t_1.c FROM t AS t_1 FOR SHARE OF t_1",
+ dialect="postgresql")