summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2015-12-14 12:00:21 -0500
committerMike Bayer <mike_mp@zzzcomputing.com>2015-12-14 12:00:21 -0500
commitc59ea573e12e94760e2e8770ebf5e9690e9c0105 (patch)
tree3ebf15131d17d122f598e79a4d7320cf7c9e3b8a
parent7d96ad4d535dc02a8ab1384df1db94dea2a045b5 (diff)
downloadsqlalchemy-c59ea573e12e94760e2e8770ebf5e9690e9c0105.tar.gz
ref #3609 wip
-rw-r--r--lib/sqlalchemy/orm/mapper.py14
-rw-r--r--lib/sqlalchemy/orm/persistence.py27
-rw-r--r--lib/sqlalchemy/sql/crud.py1
-rw-r--r--lib/sqlalchemy/testing/assertsql.py16
-rw-r--r--test/orm/test_unitofworkv2.py165
5 files changed, 209 insertions, 14 deletions
diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py
index 5ade4b966..95aa14a26 100644
--- a/lib/sqlalchemy/orm/mapper.py
+++ b/lib/sqlalchemy/orm/mapper.py
@@ -1970,12 +1970,24 @@ class Mapper(InspectionAttr):
(
table,
frozenset([
- col for col in columns
+ col.key for col in columns
if col.server_default is not None])
)
for table, columns in self._cols_by_table.items()
)
+ @_memoized_configured_property
+ def _server_onupdate_default_cols(self):
+ return dict(
+ (
+ table,
+ frozenset([
+ col.key for col in columns
+ if col.server_onupdate is not None])
+ )
+ for table, columns in self._cols_by_table.items()
+ )
+
@property
def selectable(self):
"""The :func:`.select` construct this :class:`.Mapper` selects from
diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py
index 768c1146a..9b631d2b0 100644
--- a/lib/sqlalchemy/orm/persistence.py
+++ b/lib/sqlalchemy/orm/persistence.py
@@ -397,7 +397,7 @@ def _collect_insert_commands(
if mapper.base_mapper.eager_defaults:
has_all_defaults = mapper._server_default_cols[table].\
- issubset(params)
+ issubset(set(params).union(value_params))
else:
has_all_defaults = True
else:
@@ -448,6 +448,7 @@ def _collect_update_commands(
set(propkey_to_col).intersection(state_dict).difference(
mapper._pk_keys_by_table[table])
)
+ has_all_defaults = True
else:
params = {}
for propkey in set(propkey_to_col).intersection(
@@ -463,6 +464,12 @@ def _collect_update_commands(
value, state.committed_state[propkey]) is not True:
params[col.key] = value
+ if mapper.base_mapper.eager_defaults:
+ has_all_defaults = mapper._server_onupdate_default_cols[table].\
+ issubset(set(params).union(value_params))
+ else:
+ has_all_defaults = True
+
if update_version_id is not None and \
mapper.version_id_col in mapper._cols_by_table[table]:
@@ -529,7 +536,7 @@ def _collect_update_commands(
params.update(pk_params)
yield (
state, state_dict, params, mapper,
- connection, value_params)
+ connection, value_params, has_all_defaults)
def _collect_post_update_commands(base_mapper, uowtransaction, table,
@@ -628,14 +635,16 @@ def _emit_update_statements(base_mapper, uowtransaction,
statement = base_mapper._memo(('update', table), update_stmt)
- for (connection, paramkeys, hasvalue), \
+ for (connection, paramkeys, hasvalue, has_all_defaults), \
records in groupby(
update,
lambda rec: (
rec[4], # connection
set(rec[2]), # set of parameter keys
- bool(rec[5]))): # whether or not we have "value" parameters
-
+ bool(rec[5]), # whether or not we have "value" parameters
+ rec[6] # has_all_defaults
+ )
+ ):
rows = 0
records = list(records)
@@ -645,11 +654,11 @@ def _emit_update_statements(base_mapper, uowtransaction,
assert_singlerow = connection.dialect.supports_sane_rowcount
assert_multirow = assert_singlerow and \
connection.dialect.supports_sane_multi_rowcount
- allow_multirow = not needs_version_id
+ allow_multirow = has_all_defaults and not needs_version_id
if hasvalue:
for state, state_dict, params, mapper, \
- connection, value_params in records:
+ connection, value_params, has_all_defaults in records:
c = connection.execute(
statement.values(value_params),
params)
@@ -669,7 +678,7 @@ def _emit_update_statements(base_mapper, uowtransaction,
if not allow_multirow:
check_rowcount = assert_singlerow
for state, state_dict, params, mapper, \
- connection, value_params in records:
+ connection, value_params, has_all_defaults in records:
c = cached_connections[connection].\
execute(statement, params)
@@ -699,7 +708,7 @@ def _emit_update_statements(base_mapper, uowtransaction,
rows += c.rowcount
for state, state_dict, params, mapper, \
- connection, value_params in records:
+ connection, value_params, has_all_defaults in records:
if bookkeeping:
_postfetch(
mapper,
diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py
index 18b96018d..c5495ccde 100644
--- a/lib/sqlalchemy/sql/crud.py
+++ b/lib/sqlalchemy/sql/crud.py
@@ -493,6 +493,7 @@ def _append_param_update(
else:
compiler.postfetch.append(c)
elif implicit_return_defaults and \
+ stmt._return_defaults is not True and \
c in implicit_return_defaults:
compiler.returning.append(c)
diff --git a/lib/sqlalchemy/testing/assertsql.py b/lib/sqlalchemy/testing/assertsql.py
index 243493607..ce01e9b81 100644
--- a/lib/sqlalchemy/testing/assertsql.py
+++ b/lib/sqlalchemy/testing/assertsql.py
@@ -13,6 +13,7 @@ import contextlib
from .. import event
from sqlalchemy.schema import _DDLCompiles
from sqlalchemy.engine.util import _distill_params
+from sqlalchemy.engine import url
class AssertRule(object):
@@ -58,16 +59,25 @@ class CursorSQL(SQLMatchRule):
class CompiledSQL(SQLMatchRule):
- def __init__(self, statement, params=None):
+ def __init__(self, statement, params=None, dialect='default'):
self.statement = statement
self.params = params
+ self.dialect = dialect
def _compare_sql(self, execute_observed, received_statement):
stmt = re.sub(r'[\n\t]', '', self.statement)
return received_statement == stmt
def _compile_dialect(self, execute_observed):
- return DefaultDialect()
+ if self.dialect == 'default':
+ return DefaultDialect(implicit_returning=True)
+ else:
+ # ugh
+ if self.dialect == 'postgresql':
+ params = {'implicit_returning': True}
+ else:
+ params = {}
+ return url.URL(self.dialect).get_dialect()(**params)
def _received_statement(self, execute_observed):
"""reconstruct the statement and params in terms
@@ -159,7 +169,7 @@ class CompiledSQL(SQLMatchRule):
'Testing for compiled statement %r partial params %r, '
'received %%(received_statement)r with params '
'%%(received_parameters)r' % (
- self.statement, expected_params
+ self.statement.replace('%', '%%'), expected_params
)
)
diff --git a/test/orm/test_unitofworkv2.py b/test/orm/test_unitofworkv2.py
index 09240dfdb..126b95082 100644
--- a/test/orm/test_unitofworkv2.py
+++ b/test/orm/test_unitofworkv2.py
@@ -5,7 +5,7 @@ from sqlalchemy.testing.schema import Table, Column
from test.orm import _fixtures
from sqlalchemy import exc, util
from sqlalchemy.testing import fixtures, config
-from sqlalchemy import Integer, String, ForeignKey, func, literal
+from sqlalchemy import Integer, String, ForeignKey, func, literal, FetchedValue
from sqlalchemy.orm import mapper, relationship, backref, \
create_session, unitofwork, attributes,\
Session, exc as orm_exc
@@ -1848,6 +1848,169 @@ class NoAttrEventInFlushTest(fixtures.MappedTest):
eq_(t1.returning_val, 5)
+class EagerDefaultsTest(fixtures.MappedTest):
+ __backend__ = True
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ 'test', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('foo', Integer, server_default="3")
+ )
+
+ Table(
+ 'test2', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('foo', Integer),
+ Column('bar', Integer, server_onupdate=FetchedValue())
+ )
+
+ @classmethod
+ def setup_classes(cls):
+ class Thing(cls.Basic):
+ pass
+
+ class Thing2(cls.Basic):
+ pass
+
+ @classmethod
+ def setup_mappers(cls):
+ Thing = cls.classes.Thing
+
+ mapper(Thing, cls.tables.test, eager_defaults=True)
+
+ Thing2 = cls.classes.Thing2
+
+ mapper(Thing2, cls.tables.test2, eager_defaults=True)
+
+ def test_insert_defaults_present(self):
+ Thing = self.classes.Thing
+ s = Session()
+
+ t1, t2 = (
+ Thing(id=1, foo=5),
+ Thing(id=2, foo=10)
+ )
+
+ s.add_all([t1, t2])
+
+ self.assert_sql_execution(
+ testing.db,
+ s.flush,
+ CompiledSQL(
+ "INSERT INTO test (id, foo) VALUES (:id, :foo)",
+ [{'foo': 5, 'id': 1}, {'foo': 10, 'id': 2}]
+ ),
+ )
+
+ def go():
+ eq_(t1.foo, 5)
+ eq_(t2.foo, 10)
+
+ self.assert_sql_count(testing.db, go, 0)
+
+ def test_insert_defaults_non_present(self):
+ Thing = self.classes.Thing
+ s = Session()
+
+ t1, t2 = (
+ Thing(id=1),
+ Thing(id=2)
+ )
+
+ s.add_all([t1, t2])
+
+ self.assert_sql_execution(
+ testing.db,
+ s.commit,
+ CompiledSQL(
+ "INSERT INTO test (id) VALUES (%(id)s) RETURNING test.foo",
+ [{'id': 1}],
+ dialect='postgresql'
+ ),
+ CompiledSQL(
+ "INSERT INTO test (id) VALUES (%(id)s) RETURNING test.foo",
+ [{'id': 2}],
+ dialect='postgresql'
+ ),
+ )
+
+ def test_update_defaults_nonpresent(self):
+ Thing2 = self.classes.Thing2
+ s = Session()
+
+ t1, t2, t3, t4 = (
+ Thing2(id=1, foo=1, bar=2),
+ Thing2(id=2, foo=2, bar=3),
+ Thing2(id=3, foo=3, bar=4),
+ Thing2(id=4, foo=4, bar=5)
+ )
+
+ s.add_all([t1, t2, t3, t4])
+ s.flush()
+
+ t1.foo = 5
+ t2.foo = 6
+ t2.bar = 10
+ t3.foo = 7
+ t4.foo = 8
+ t4.bar = 12
+
+ self.assert_sql_execution(
+ testing.db,
+ s.commit,
+ CompiledSQL(
+ "UPDATE test2 SET foo=%(foo)s WHERE test2.id = %(test2_id)s "
+ "RETURNING test2.bar",
+ [{'foo': 5, 'test2_id': 1}],
+ dialect='postgresql'
+ ),
+ CompiledSQL(
+ "UPDATE test2 SET foo=%(foo)s, bar=%(bar)s "
+ "WHERE test2.id = %(test2_id)s",
+ [{'foo': 6, 'bar': 10, 'test2_id': 2}],
+ dialect='postgresql'
+ ),
+ CompiledSQL(
+ "UPDATE test2 SET foo=%(foo)s WHERE test2.id = %(test2_id)s "
+ "RETURNING test2.bar",
+ [{'foo': 7, 'test2_id': 3}],
+ dialect='postgresql'
+ ),
+ CompiledSQL(
+ "UPDATE test2 SET foo=%(foo)s, bar=%(bar)s "
+ "WHERE test2.id = %(test2_id)s",
+ [{'foo': 8, 'bar': 12, 'test2_id': 4}],
+ dialect='postgresql'
+ ),
+ )
+
+ def test_update_defaults_present(self):
+ Thing2 = self.classes.Thing2
+ s = Session()
+
+ t1, t2 = (
+ Thing2(id=1, foo=1, bar=2),
+ Thing2(id=2, foo=2, bar=3)
+ )
+
+ s.add_all([t1, t2])
+ s.flush()
+
+ t1.bar = 5
+ t2.bar = 10
+
+ self.assert_sql_execution(
+ testing.db,
+ s.commit,
+ CompiledSQL(
+ "UPDATE test2 SET bar=%(bar)s WHERE test2.id = %(test2_id)s",
+ [{'bar': 5, 'test2_id': 1}, {'bar': 10, 'test2_id': 2}],
+ dialect='postgresql'
+ )
+ )
+
class TypeWoBoolTest(fixtures.MappedTest, testing.AssertsExecutionResults):
"""test support for custom datatypes that return a non-__bool__ value
when compared via __eq__(), eg. ticket 3469"""