summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--doc/build/changelog/unreleased_20/9228.rst9
-rw-r--r--lib/sqlalchemy/orm/mapper.py20
-rw-r--r--lib/sqlalchemy/orm/persistence.py13
-rw-r--r--test/orm/inheritance/test_basic.py6
-rw-r--r--test/orm/test_versioning.py162
5 files changed, 81 insertions, 129 deletions
diff --git a/doc/build/changelog/unreleased_20/9228.rst b/doc/build/changelog/unreleased_20/9228.rst
new file mode 100644
index 000000000..7e96c2461
--- /dev/null
+++ b/doc/build/changelog/unreleased_20/9228.rst
@@ -0,0 +1,9 @@
+.. change::
+ :tags: orm, bug, regression
+ :tickets: 9228
+
+ Fixed regression where using the :paramref:`_orm.Mapper.version_id_col`
+ feature with a regular Python-side incrementing column would fail to work
+ for SQLite and other databases that don't support "rowcount" with
+ "RETURNING", as "RETURNING" would be assumed for such columns even though
+ that's not what actually takes place.
diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py
index a3b209e4a..bb7e470ff 100644
--- a/lib/sqlalchemy/orm/mapper.py
+++ b/lib/sqlalchemy/orm/mapper.py
@@ -83,6 +83,7 @@ from ..sql import util as sql_util
from ..sql import visitors
from ..sql.cache_key import MemoizedHasCacheKey
from ..sql.elements import KeyedColumnElement
+from ..sql.schema import Column
from ..sql.schema import Table
from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
from ..util import HasMemoized
@@ -112,7 +113,6 @@ if TYPE_CHECKING:
from ..sql.base import ReadOnlyColumnCollection
from ..sql.elements import ColumnClause
from ..sql.elements import ColumnElement
- from ..sql.schema import Column
from ..sql.selectable import FromClause
from ..util import OrderedSet
@@ -2523,6 +2523,24 @@ class Mapper(
return from_obj
@HasMemoized.memoized_attribute
+ def _version_id_has_server_side_value(self) -> bool:
+ vid_col = self.version_id_col
+
+ if vid_col is None:
+ return False
+
+ elif not isinstance(vid_col, Column):
+ return True
+ else:
+ return vid_col.server_default is not None or (
+ vid_col.default is not None
+ and (
+ not vid_col.default.is_scalar
+ and not vid_col.default.is_callable
+ )
+ )
+
+ @HasMemoized.memoized_attribute
def _single_table_criterion(self):
if self.single and self.inherits and self.polymorphic_on is not None:
return self.polymorphic_on._annotate(
diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py
index cc7e321b4..b8368001b 100644
--- a/lib/sqlalchemy/orm/persistence.py
+++ b/lib/sqlalchemy/orm/persistence.py
@@ -801,7 +801,7 @@ def _emit_update_statements(
)
return_defaults = True
- if mapper.version_id_col is not None:
+ if mapper._version_id_has_server_side_value:
statement = statement.return_defaults(mapper.version_id_col)
return_defaults = True
@@ -1268,13 +1268,16 @@ def _emit_post_update_statements(
stmt = table.update().where(clauses)
- if mapper.version_id_col is not None:
- stmt = stmt.return_defaults(mapper.version_id_col)
-
return stmt
statement = base_mapper._memo(("post_update", table), update_stmt)
+ if mapper._version_id_has_server_side_value:
+ statement = statement.return_defaults(mapper.version_id_col)
+ return_defaults = True
+ else:
+ return_defaults = False
+
# execute each UPDATE in the order according to the original
# list of states to guarantee row access order, but
# also group them into common (connection, cols) sets
@@ -1290,7 +1293,7 @@ def _emit_post_update_statements(
assert_singlerow = (
connection.dialect.supports_sane_rowcount
- if mapper.version_id_col is None
+ if not return_defaults
else connection.dialect.supports_sane_rowcount_returning
)
assert_multirow = (
diff --git a/test/orm/inheritance/test_basic.py b/test/orm/inheritance/test_basic.py
index 37368f3ad..02f352786 100644
--- a/test/orm/inheritance/test_basic.py
+++ b/test/orm/inheritance/test_basic.py
@@ -2101,8 +2101,7 @@ class VersioningTest(fixtures.MappedTest):
Column("parent", Integer, ForeignKey("base.id")),
)
- @testing.emits_warning(r".*updated rowcount")
- @testing.requires.sane_rowcount_w_returning
+ @testing.requires.sane_rowcount
def test_save_update(self):
subtable, base, stuff = (
self.tables.subtable,
@@ -2170,8 +2169,7 @@ class VersioningTest(fixtures.MappedTest):
s2.subdata = "sess2 subdata"
sess2.flush()
- @testing.emits_warning(r".*(update|delete)d rowcount")
- @testing.requires.sane_rowcount_w_returning
+ @testing.requires.sane_rowcount
def test_delete(self):
subtable, base = self.tables.subtable, self.tables.base
diff --git a/test/orm/test_versioning.py b/test/orm/test_versioning.py
index 1a1801311..f6b9f18fc 100644
--- a/test/orm/test_versioning.py
+++ b/test/orm/test_versioning.py
@@ -140,9 +140,7 @@ class NullVersionIdTest(fixtures.MappedTest):
f1.value = "f1rev2"
- with conditional_sane_rowcount_warnings(
- update=True, only_returning=True
- ):
+ with conditional_sane_rowcount_warnings(update=True):
f1.version_id = None
assert_raises_message(
sa.orm.exc.FlushError,
@@ -209,24 +207,20 @@ class VersioningTest(fixtures.MappedTest):
s1.commit()
f1.value = "f1rev2"
- with conditional_sane_rowcount_warnings(
- update=True, only_returning=True
- ):
+ with conditional_sane_rowcount_warnings(update=True):
s1.commit()
s2 = fixture_session()
f1_s = s2.get(Foo, f1.id)
f1_s.value = "f1rev3"
- with conditional_sane_rowcount_warnings(
- update=True, only_returning=True
- ):
+ with conditional_sane_rowcount_warnings(update=True):
s2.commit()
f1.value = "f1rev3mine"
# Only dialects with a sane rowcount can detect the
# StaleDataError
- if testing.db.dialect.supports_sane_rowcount_returning:
+ if testing.db.dialect.supports_sane_rowcount:
assert_raises_message(
sa.orm.exc.StaleDataError,
r"UPDATE statement on table 'version_table' expected "
@@ -235,9 +229,7 @@ class VersioningTest(fixtures.MappedTest):
),
s1.rollback()
else:
- with conditional_sane_rowcount_warnings(
- update=True, only_returning=True
- ):
+ with conditional_sane_rowcount_warnings(update=True):
s1.commit()
# new in 0.5 ! don't need to close the session
@@ -245,9 +237,7 @@ class VersioningTest(fixtures.MappedTest):
f2 = s1.get(Foo, f2.id)
f1_s.value = "f1rev4"
- with conditional_sane_rowcount_warnings(
- update=True, only_returning=True
- ):
+ with conditional_sane_rowcount_warnings(update=True):
s2.commit()
s1.delete(f1)
@@ -275,9 +265,7 @@ class VersioningTest(fixtures.MappedTest):
f1.value = "f1rev2"
f2.value = "f2rev2"
- with conditional_sane_rowcount_warnings(
- update=True, only_returning=True
- ):
+ with conditional_sane_rowcount_warnings(update=True):
s1.commit()
eq_(
@@ -306,9 +294,7 @@ class VersioningTest(fixtures.MappedTest):
s1.add_all((f1, f2))
s1.commit()
- with conditional_sane_rowcount_warnings(
- update=True, only_returning=True
- ):
+ with conditional_sane_rowcount_warnings(update=True):
s1.bulk_update_mappings(
Foo,
[
@@ -340,9 +326,7 @@ class VersioningTest(fixtures.MappedTest):
s1.commit()
eq_(f1.version_id, 1)
f1.version_id = 2
- with conditional_sane_rowcount_warnings(
- update=True, only_returning=True
- ):
+ with conditional_sane_rowcount_warnings(update=True):
s1.commit()
eq_(f1.version_id, 2)
@@ -350,9 +334,7 @@ class VersioningTest(fixtures.MappedTest):
# is honored
f1.version_id = 4
f1.value = "something new"
- with conditional_sane_rowcount_warnings(
- update=True, only_returning=True
- ):
+ with conditional_sane_rowcount_warnings(update=True):
s1.commit()
eq_(f1.version_id, 4)
@@ -377,9 +359,7 @@ class VersioningTest(fixtures.MappedTest):
s2 = fixture_session()
f1s2 = s2.get(Foo, f1s1.id)
f1s2.value = "f1 new value"
- with conditional_sane_rowcount_warnings(
- update=True, only_returning=True
- ):
+ with conditional_sane_rowcount_warnings(update=True):
s2.commit()
# load, version is wrong
@@ -465,9 +445,7 @@ class VersioningTest(fixtures.MappedTest):
s1.commit()
f1s1.value = "f2 value"
- with conditional_sane_rowcount_warnings(
- update=True, only_returning=True
- ):
+ with conditional_sane_rowcount_warnings(update=True):
s1.flush()
eq_(f1s1.version_id, 2)
@@ -532,17 +510,13 @@ class VersioningTest(fixtures.MappedTest):
s1.commit()
f1.value = "f2"
- with conditional_sane_rowcount_warnings(
- update=True, only_returning=True
- ):
+ with conditional_sane_rowcount_warnings(update=True):
s1.commit()
f2 = Foo(id=f1.id, value="f3")
f3 = s1.merge(f2)
assert f3 is f1
- with conditional_sane_rowcount_warnings(
- update=True, only_returning=True
- ):
+ with conditional_sane_rowcount_warnings(update=True):
s1.commit()
eq_(f3.version_id, 3)
@@ -555,17 +529,13 @@ class VersioningTest(fixtures.MappedTest):
s1.commit()
f1.value = "f2"
- with conditional_sane_rowcount_warnings(
- update=True, only_returning=True
- ):
+ with conditional_sane_rowcount_warnings(update=True):
s1.commit()
f2 = Foo(id=f1.id, value="f3", version_id=2)
f3 = s1.merge(f2)
assert f3 is f1
- with conditional_sane_rowcount_warnings(
- update=True, only_returning=True
- ):
+ with conditional_sane_rowcount_warnings(update=True):
s1.commit()
eq_(f3.version_id, 3)
@@ -578,9 +548,7 @@ class VersioningTest(fixtures.MappedTest):
s1.commit()
f1.value = "f2"
- with conditional_sane_rowcount_warnings(
- update=True, only_returning=True
- ):
+ with conditional_sane_rowcount_warnings(update=True):
s1.commit()
f2 = Foo(id=f1.id, value="f3", version_id=1)
@@ -603,9 +571,7 @@ class VersioningTest(fixtures.MappedTest):
s1.commit()
f1.value = "f2"
- with conditional_sane_rowcount_warnings(
- update=True, only_returning=True
- ):
+ with conditional_sane_rowcount_warnings(update=True):
s1.commit()
f2 = Foo(id=f1.id, value="f3", version_id=1)
@@ -670,9 +636,7 @@ class VersionOnPostUpdateTest(fixtures.MappedTest):
s, n1, n2 = self._fixture(o2m=True, post_update=False)
n1.related.append(n2)
- with conditional_sane_rowcount_warnings(
- update=True, only_returning=True
- ):
+ with conditional_sane_rowcount_warnings(update=True):
s.flush()
eq_(n1.version_id, 1)
@@ -682,9 +646,7 @@ class VersionOnPostUpdateTest(fixtures.MappedTest):
s, n1, n2 = self._fixture(o2m=False, post_update=False)
n1.related = n2
- with conditional_sane_rowcount_warnings(
- update=True, only_returning=True
- ):
+ with conditional_sane_rowcount_warnings(update=True):
s.flush()
eq_(n1.version_id, 2)
@@ -694,9 +656,7 @@ class VersionOnPostUpdateTest(fixtures.MappedTest):
s, n1, n2 = self._fixture(o2m=True, post_update=True)
n1.related.append(n2)
- with conditional_sane_rowcount_warnings(
- update=True, only_returning=True
- ):
+ with conditional_sane_rowcount_warnings(update=True):
s.flush()
eq_(n1.version_id, 1)
@@ -706,9 +666,7 @@ class VersionOnPostUpdateTest(fixtures.MappedTest):
s, n1, n2 = self._fixture(o2m=False, post_update=True)
n1.related = n2
- with conditional_sane_rowcount_warnings(
- update=True, only_returning=True
- ):
+ with conditional_sane_rowcount_warnings(update=True):
s.flush()
eq_(n1.version_id, 2)
@@ -719,9 +677,7 @@ class VersionOnPostUpdateTest(fixtures.MappedTest):
n1.related.append(n2)
s.add_all([n1, n2])
- with conditional_sane_rowcount_warnings(
- update=True, only_returning=True
- ):
+ with conditional_sane_rowcount_warnings(update=True):
s.flush()
eq_(n1.version_id, 1)
@@ -732,15 +688,13 @@ class VersionOnPostUpdateTest(fixtures.MappedTest):
n1.related = n2
s.add_all([n1, n2])
- with conditional_sane_rowcount_warnings(
- update=True, only_returning=True
- ):
+ with conditional_sane_rowcount_warnings(update=True):
s.flush()
eq_(n1.version_id, 1)
eq_(n2.version_id, 1)
- @testing.requires.sane_rowcount_w_returning
+ @testing.requires.sane_rowcount
def test_o2m_post_update_version_assert(self):
Node = self.classes.Node
s, n1, n2 = self._fixture(o2m=True, post_update=True)
@@ -782,7 +736,7 @@ class VersionOnPostUpdateTest(fixtures.MappedTest):
):
s.flush()
- @testing.requires.sane_rowcount_w_returning
+ @testing.requires.sane_rowcount
def test_m2o_post_update_version_assert(self):
Node = self.classes.Node
@@ -944,9 +898,7 @@ class ColumnTypeTest(fixtures.MappedTest):
s1.commit()
f1.value = "f1rev2"
- with conditional_sane_rowcount_warnings(
- update=True, only_returning=True
- ):
+ with conditional_sane_rowcount_warnings(update=True):
s1.commit()
@@ -1007,9 +959,7 @@ class RowSwitchTest(fixtures.MappedTest):
p = session.query(P).first()
session.delete(p)
session.add(P(id="P1", data="really a row-switch"))
- with conditional_sane_rowcount_warnings(
- update=True, only_returning=True
- ):
+ with conditional_sane_rowcount_warnings(update=True):
session.commit()
def test_child_row_switch(self):
@@ -1028,9 +978,7 @@ class RowSwitchTest(fixtures.MappedTest):
p = session.query(P).first()
p.c = C(data="child row-switch")
- with conditional_sane_rowcount_warnings(
- update=True, only_returning=True
- ):
+ with conditional_sane_rowcount_warnings(update=True):
session.commit()
@@ -1096,9 +1044,7 @@ class AlternateGeneratorTest(fixtures.MappedTest):
p = session.query(P).first()
session.delete(p)
session.add(P(id="P1", data="really a row-switch"))
- with conditional_sane_rowcount_warnings(
- update=True, only_returning=True
- ):
+ with conditional_sane_rowcount_warnings(update=True):
session.commit()
def test_child_row_switch_one(self):
@@ -1117,12 +1063,10 @@ class AlternateGeneratorTest(fixtures.MappedTest):
p = session.query(P).first()
p.c = C(data="child row-switch")
- with conditional_sane_rowcount_warnings(
- update=True, only_returning=True
- ):
+ with conditional_sane_rowcount_warnings(update=True):
session.commit()
- @testing.requires.sane_rowcount_w_returning
+ @testing.requires.sane_rowcount
def test_child_row_switch_two(self):
P = self.classes.P
@@ -1206,9 +1150,7 @@ class PlainInheritanceTest(fixtures.MappedTest):
s.commit()
s1.sub_data = "s2"
- with conditional_sane_rowcount_warnings(
- update=True, only_returning=True
- ):
+ with conditional_sane_rowcount_warnings(update=True):
s.commit()
eq_(s1.version_id, 2)
@@ -1799,14 +1741,12 @@ class ManualVersionTest(fixtures.MappedTest):
a1.vid = 2
a1.data = "d2"
- with conditional_sane_rowcount_warnings(
- update=True, only_returning=True
- ):
+ with conditional_sane_rowcount_warnings(update=True):
sess.commit()
eq_(a1.vid, 2)
- @testing.requires.sane_rowcount_w_returning
+ @testing.requires.sane_rowcount
def test_update_concurrent_check(self):
sess = fixture_session()
a1 = self.classes.A()
@@ -1833,18 +1773,14 @@ class ManualVersionTest(fixtures.MappedTest):
# change the data and UPDATE without
# incrementing version id
a1.data = "d2"
- with conditional_sane_rowcount_warnings(
- update=True, only_returning=True
- ):
+ with conditional_sane_rowcount_warnings(update=True):
sess.commit()
eq_(a1.vid, 1)
a1.data = "d3"
a1.vid = 2
- with conditional_sane_rowcount_warnings(
- update=True, only_returning=True
- ):
+ with conditional_sane_rowcount_warnings(update=True):
sess.commit()
eq_(a1.vid, 2)
@@ -1907,18 +1843,14 @@ class ManualInheritanceVersionTest(fixtures.MappedTest):
# change col on subtable only without
# incrementing version id
b1.b_data = "bd2"
- with conditional_sane_rowcount_warnings(
- update=True, only_returning=True
- ):
+ with conditional_sane_rowcount_warnings(update=True):
sess.commit()
eq_(b1.vid, 1)
b1.b_data = "d3"
b1.vid = 2
- with conditional_sane_rowcount_warnings(
- update=True, only_returning=True
- ):
+ with conditional_sane_rowcount_warnings(update=True):
sess.commit()
eq_(b1.vid, 2)
@@ -1990,9 +1922,7 @@ class VersioningMappedSelectTest(fixtures.MappedTest):
f1.value = "f1rev2"
f2.value = "f2rev2"
- with conditional_sane_rowcount_warnings(
- update=True, only_returning=True
- ):
+ with conditional_sane_rowcount_warnings(update=True):
s1.commit()
eq_(
@@ -2015,9 +1945,7 @@ class VersioningMappedSelectTest(fixtures.MappedTest):
f1.version_id = 2
f2.value = "f2rev2"
f2.version_id = 2
- with conditional_sane_rowcount_warnings(
- update=True, only_returning=True
- ):
+ with conditional_sane_rowcount_warnings(update=True):
s1.flush()
eq_(
@@ -2057,9 +1985,7 @@ class VersioningMappedSelectTest(fixtures.MappedTest):
s1.expire_all()
- with conditional_sane_rowcount_warnings(
- update=True, only_returning=True
- ):
+ with conditional_sane_rowcount_warnings(update=True):
f1.value = "f2"
f1.version_id = 2
s1.flush()
@@ -2109,9 +2035,7 @@ class QuotedBindVersioningTest(fixtures.MappedTest):
fixture_session.commit()
f1.value = "v2"
- with conditional_sane_rowcount_warnings(
- update=True, only_returning=True
- ):
+ with conditional_sane_rowcount_warnings(update=True):
fixture_session.commit()
eq_(f1.version, 2)