diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2015-12-14 12:00:21 -0500 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2015-12-14 12:00:21 -0500 |
commit | c59ea573e12e94760e2e8770ebf5e9690e9c0105 (patch) | |
tree | 3ebf15131d17d122f598e79a4d7320cf7c9e3b8a | |
parent | 7d96ad4d535dc02a8ab1384df1db94dea2a045b5 (diff) | |
download | sqlalchemy-c59ea573e12e94760e2e8770ebf5e9690e9c0105.tar.gz |
ref #3609 wip
-rw-r--r-- | lib/sqlalchemy/orm/mapper.py | 14 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/persistence.py | 27 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/crud.py | 1 | ||||
-rw-r--r-- | lib/sqlalchemy/testing/assertsql.py | 16 | ||||
-rw-r--r-- | test/orm/test_unitofworkv2.py | 165 |
5 files changed, 209 insertions, 14 deletions
diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 5ade4b966..95aa14a26 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -1970,12 +1970,24 @@ class Mapper(InspectionAttr): ( table, frozenset([ - col for col in columns + col.key for col in columns if col.server_default is not None]) ) for table, columns in self._cols_by_table.items() ) + @_memoized_configured_property + def _server_onupdate_default_cols(self): + return dict( + ( + table, + frozenset([ + col.key for col in columns + if col.server_onupdate is not None]) + ) + for table, columns in self._cols_by_table.items() + ) + @property def selectable(self): """The :func:`.select` construct this :class:`.Mapper` selects from diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index 768c1146a..9b631d2b0 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -397,7 +397,7 @@ def _collect_insert_commands( if mapper.base_mapper.eager_defaults: has_all_defaults = mapper._server_default_cols[table].\ - issubset(params) + issubset(set(params).union(value_params)) else: has_all_defaults = True else: @@ -448,6 +448,7 @@ def _collect_update_commands( set(propkey_to_col).intersection(state_dict).difference( mapper._pk_keys_by_table[table]) ) + has_all_defaults = True else: params = {} for propkey in set(propkey_to_col).intersection( @@ -463,6 +464,12 @@ def _collect_update_commands( value, state.committed_state[propkey]) is not True: params[col.key] = value + if mapper.base_mapper.eager_defaults: + has_all_defaults = mapper._server_onupdate_default_cols[table].\ + issubset(set(params).union(value_params)) + else: + has_all_defaults = True + if update_version_id is not None and \ mapper.version_id_col in mapper._cols_by_table[table]: @@ -529,7 +536,7 @@ def _collect_update_commands( params.update(pk_params) yield ( state, state_dict, params, mapper, - connection, value_params) + connection, value_params, has_all_defaults) def _collect_post_update_commands(base_mapper, uowtransaction, table, @@ -628,14 +635,16 @@ def _emit_update_statements(base_mapper, uowtransaction, statement = base_mapper._memo(('update', table), update_stmt) - for (connection, paramkeys, hasvalue), \ + for (connection, paramkeys, hasvalue, has_all_defaults), \ records in groupby( update, lambda rec: ( rec[4], # connection set(rec[2]), # set of parameter keys - bool(rec[5]))): # whether or not we have "value" parameters - + bool(rec[5]), # whether or not we have "value" parameters + rec[6] # has_all_defaults + ) + ): rows = 0 records = list(records) @@ -645,11 +654,11 @@ def _emit_update_statements(base_mapper, uowtransaction, assert_singlerow = connection.dialect.supports_sane_rowcount assert_multirow = assert_singlerow and \ connection.dialect.supports_sane_multi_rowcount - allow_multirow = not needs_version_id + allow_multirow = has_all_defaults and not needs_version_id if hasvalue: for state, state_dict, params, mapper, \ - connection, value_params in records: + connection, value_params, has_all_defaults in records: c = connection.execute( statement.values(value_params), params) @@ -669,7 +678,7 @@ def _emit_update_statements(base_mapper, uowtransaction, if not allow_multirow: check_rowcount = assert_singlerow for state, state_dict, params, mapper, \ - connection, value_params in records: + connection, value_params, has_all_defaults in records: c = cached_connections[connection].\ execute(statement, params) @@ -699,7 +708,7 @@ def _emit_update_statements(base_mapper, uowtransaction, rows += c.rowcount for state, state_dict, params, mapper, \ - connection, value_params in records: + connection, value_params, has_all_defaults in records: if bookkeeping: _postfetch( mapper, diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py index 18b96018d..c5495ccde 100644 --- a/lib/sqlalchemy/sql/crud.py +++ b/lib/sqlalchemy/sql/crud.py @@ -493,6 +493,7 @@ def _append_param_update( else: compiler.postfetch.append(c) elif implicit_return_defaults and \ + stmt._return_defaults is not True and \ c in implicit_return_defaults: compiler.returning.append(c) diff --git a/lib/sqlalchemy/testing/assertsql.py b/lib/sqlalchemy/testing/assertsql.py index 243493607..ce01e9b81 100644 --- a/lib/sqlalchemy/testing/assertsql.py +++ b/lib/sqlalchemy/testing/assertsql.py @@ -13,6 +13,7 @@ import contextlib from .. import event from sqlalchemy.schema import _DDLCompiles from sqlalchemy.engine.util import _distill_params +from sqlalchemy.engine import url class AssertRule(object): @@ -58,16 +59,25 @@ class CursorSQL(SQLMatchRule): class CompiledSQL(SQLMatchRule): - def __init__(self, statement, params=None): + def __init__(self, statement, params=None, dialect='default'): self.statement = statement self.params = params + self.dialect = dialect def _compare_sql(self, execute_observed, received_statement): stmt = re.sub(r'[\n\t]', '', self.statement) return received_statement == stmt def _compile_dialect(self, execute_observed): - return DefaultDialect() + if self.dialect == 'default': + return DefaultDialect(implicit_returning=True) + else: + # ugh + if self.dialect == 'postgresql': + params = {'implicit_returning': True} + else: + params = {} + return url.URL(self.dialect).get_dialect()(**params) def _received_statement(self, execute_observed): """reconstruct the statement and params in terms @@ -159,7 +169,7 @@ class CompiledSQL(SQLMatchRule): 'Testing for compiled statement %r partial params %r, ' 'received %%(received_statement)r with params ' '%%(received_parameters)r' % ( - self.statement, expected_params + self.statement.replace('%', '%%'), expected_params ) ) diff --git a/test/orm/test_unitofworkv2.py b/test/orm/test_unitofworkv2.py index 09240dfdb..126b95082 100644 --- a/test/orm/test_unitofworkv2.py +++ b/test/orm/test_unitofworkv2.py @@ -5,7 +5,7 @@ from sqlalchemy.testing.schema import Table, Column from test.orm import _fixtures from sqlalchemy import exc, util from sqlalchemy.testing import fixtures, config -from sqlalchemy import Integer, String, ForeignKey, func, literal +from sqlalchemy import Integer, String, ForeignKey, func, literal, FetchedValue from sqlalchemy.orm import mapper, relationship, backref, \ create_session, unitofwork, attributes,\ Session, exc as orm_exc @@ -1848,6 +1848,169 @@ class NoAttrEventInFlushTest(fixtures.MappedTest): eq_(t1.returning_val, 5) +class EagerDefaultsTest(fixtures.MappedTest): + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table( + 'test', metadata, + Column('id', Integer, primary_key=True), + Column('foo', Integer, server_default="3") + ) + + Table( + 'test2', metadata, + Column('id', Integer, primary_key=True), + Column('foo', Integer), + Column('bar', Integer, server_onupdate=FetchedValue()) + ) + + @classmethod + def setup_classes(cls): + class Thing(cls.Basic): + pass + + class Thing2(cls.Basic): + pass + + @classmethod + def setup_mappers(cls): + Thing = cls.classes.Thing + + mapper(Thing, cls.tables.test, eager_defaults=True) + + Thing2 = cls.classes.Thing2 + + mapper(Thing2, cls.tables.test2, eager_defaults=True) + + def test_insert_defaults_present(self): + Thing = self.classes.Thing + s = Session() + + t1, t2 = ( + Thing(id=1, foo=5), + Thing(id=2, foo=10) + ) + + s.add_all([t1, t2]) + + self.assert_sql_execution( + testing.db, + s.flush, + CompiledSQL( + "INSERT INTO test (id, foo) VALUES (:id, :foo)", + [{'foo': 5, 'id': 1}, {'foo': 10, 'id': 2}] + ), + ) + + def go(): + eq_(t1.foo, 5) + eq_(t2.foo, 10) + + self.assert_sql_count(testing.db, go, 0) + + def test_insert_defaults_non_present(self): + Thing = self.classes.Thing + s = Session() + + t1, t2 = ( + Thing(id=1), + Thing(id=2) + ) + + s.add_all([t1, t2]) + + self.assert_sql_execution( + testing.db, + s.commit, + CompiledSQL( + "INSERT INTO test (id) VALUES (%(id)s) RETURNING test.foo", + [{'id': 1}], + dialect='postgresql' + ), + CompiledSQL( + "INSERT INTO test (id) VALUES (%(id)s) RETURNING test.foo", + [{'id': 2}], + dialect='postgresql' + ), + ) + + def test_update_defaults_nonpresent(self): + Thing2 = self.classes.Thing2 + s = Session() + + t1, t2, t3, t4 = ( + Thing2(id=1, foo=1, bar=2), + Thing2(id=2, foo=2, bar=3), + Thing2(id=3, foo=3, bar=4), + Thing2(id=4, foo=4, bar=5) + ) + + s.add_all([t1, t2, t3, t4]) + s.flush() + + t1.foo = 5 + t2.foo = 6 + t2.bar = 10 + t3.foo = 7 + t4.foo = 8 + t4.bar = 12 + + self.assert_sql_execution( + testing.db, + s.commit, + CompiledSQL( + "UPDATE test2 SET foo=%(foo)s WHERE test2.id = %(test2_id)s " + "RETURNING test2.bar", + [{'foo': 5, 'test2_id': 1}], + dialect='postgresql' + ), + CompiledSQL( + "UPDATE test2 SET foo=%(foo)s, bar=%(bar)s " + "WHERE test2.id = %(test2_id)s", + [{'foo': 6, 'bar': 10, 'test2_id': 2}], + dialect='postgresql' + ), + CompiledSQL( + "UPDATE test2 SET foo=%(foo)s WHERE test2.id = %(test2_id)s " + "RETURNING test2.bar", + [{'foo': 7, 'test2_id': 3}], + dialect='postgresql' + ), + CompiledSQL( + "UPDATE test2 SET foo=%(foo)s, bar=%(bar)s " + "WHERE test2.id = %(test2_id)s", + [{'foo': 8, 'bar': 12, 'test2_id': 4}], + dialect='postgresql' + ), + ) + + def test_update_defaults_present(self): + Thing2 = self.classes.Thing2 + s = Session() + + t1, t2 = ( + Thing2(id=1, foo=1, bar=2), + Thing2(id=2, foo=2, bar=3) + ) + + s.add_all([t1, t2]) + s.flush() + + t1.bar = 5 + t2.bar = 10 + + self.assert_sql_execution( + testing.db, + s.commit, + CompiledSQL( + "UPDATE test2 SET bar=%(bar)s WHERE test2.id = %(test2_id)s", + [{'bar': 5, 'test2_id': 1}, {'bar': 10, 'test2_id': 2}], + dialect='postgresql' + ) + ) + class TypeWoBoolTest(fixtures.MappedTest, testing.AssertsExecutionResults): """test support for custom datatypes that return a non-__bool__ value when compared via __eq__(), eg. ticket 3469""" |