summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2010-12-11 03:10:17 -0500
committerMike Bayer <mike_mp@zzzcomputing.com>2010-12-11 03:10:17 -0500
commitb88c54f95be3e3bc2e0923181d56862fa3fda9fa (patch)
treeeeb695bd9927895537d25a4402bad51b116c6964
parent4fa0ea584248726b164008cb8e4257d207b03af9 (diff)
parent9c0755640c5f1d45596ff7234d2d42f1c92d09e0 (diff)
downloadsqlalchemy-b88c54f95be3e3bc2e0923181d56862fa3fda9fa.tar.gz
do the mercurial dance (re-merge what I just merged...)
-rw-r--r--lib/sqlalchemy/engine/base.py49
-rw-r--r--lib/sqlalchemy/engine/default.py49
-rw-r--r--lib/sqlalchemy/ext/horizontal_shard.py4
-rw-r--r--lib/sqlalchemy/orm/mapper.py157
-rw-r--r--lib/sqlalchemy/orm/session.py3
-rw-r--r--lib/sqlalchemy/orm/unitofwork.py3
-rw-r--r--test/orm/test_unitofworkv2.py70
-rw-r--r--test/sql/test_query.py19
8 files changed, 226 insertions, 128 deletions
diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py
index a335b1d17..00aaca01e 100644
--- a/lib/sqlalchemy/engine/base.py
+++ b/lib/sqlalchemy/engine/base.py
@@ -629,24 +629,6 @@ class ExecutionContext(object):
raise NotImplementedError()
- def last_inserted_params(self):
- """Return a dictionary of the full parameter dictionary for the last
- compiled INSERT statement.
-
- Includes any ColumnDefaults or Sequences that were pre-executed.
- """
-
- raise NotImplementedError()
-
- def last_updated_params(self):
- """Return a dictionary of the full parameter dictionary for the last
- compiled UPDATE statement.
-
- Includes any ColumnDefaults that were pre-executed.
- """
-
- raise NotImplementedError()
-
def lastrow_has_defaults(self):
"""Return True if the last INSERT or UPDATE row contained
inlined or database-side defaults.
@@ -2467,6 +2449,7 @@ class ResultProxy(object):
did not explicitly specify returning().
"""
+
if not self.context.isinsert:
raise exc.InvalidRequestError(
"Statement is not an insert() expression construct.")
@@ -2475,31 +2458,33 @@ class ResultProxy(object):
"Can't call inserted_primary_key when returning() "
"is used.")
- return self.context._inserted_primary_key
+ return self.context.inserted_primary_key
@util.deprecated("0.6", "Use :attr:`.ResultProxy.inserted_primary_key`")
def last_inserted_ids(self):
"""Return the primary key for the row just inserted."""
return self.inserted_primary_key
-
+
def last_updated_params(self):
- """Return ``last_updated_params()`` from the underlying
- ExecutionContext.
-
- See ExecutionContext for details.
+ """Return the collection of updated parameters from this
+ execution.
+
"""
-
- return self.context.last_updated_params()
+ if self.context.executemany:
+ return self.context.compiled_parameters
+ else:
+ return self.context.compiled_parameters[0]
def last_inserted_params(self):
- """Return ``last_inserted_params()`` from the underlying
- ExecutionContext.
-
- See ExecutionContext for details.
+ """Return the collection of inserted parameters from this
+ execution.
+
"""
-
- return self.context.last_inserted_params()
+ if self.context.executemany:
+ return self.context.compiled_parameters
+ else:
+ return self.context.compiled_parameters[0]
def lastrow_has_defaults(self):
"""Return ``lastrow_has_defaults()`` from the underlying
diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py
index 0717a8fef..21603b258 100644
--- a/lib/sqlalchemy/engine/default.py
+++ b/lib/sqlalchemy/engine/default.py
@@ -400,7 +400,9 @@ class DefaultExecutionContext(base.ExecutionContext):
self.cursor = self.create_cursor()
if self.isinsert or self.isupdate:
self.__process_defaults()
-
+ self.postfetch_cols = self.compiled.postfetch
+ self.prefetch_cols = self.compiled.prefetch
+
processors = dict(
(key, value) for key, value in
( (compiled.bind_names[bindparam],
@@ -541,7 +543,8 @@ class DefaultExecutionContext(base.ExecutionContext):
"""
conn = self._connection
- if isinstance(stmt, unicode) and not self.dialect.supports_unicode_statements:
+ if isinstance(stmt, unicode) and \
+ not self.dialect.supports_unicode_statements:
stmt = stmt.encode(self.dialect.encoding)
if self.dialect.positional:
@@ -614,13 +617,14 @@ class DefaultExecutionContext(base.ExecutionContext):
def post_insert(self):
if self.dialect.postfetch_lastrowid and \
- (not len(self._inserted_primary_key) or \
- None in self._inserted_primary_key):
+ (not len(self.inserted_primary_key) or \
+ None in self.inserted_primary_key):
table = self.compiled.statement.table
lastrowid = self.get_lastrowid()
- self._inserted_primary_key = [c is table._autoincrement_column and lastrowid or v
- for c, v in zip(table.primary_key, self._inserted_primary_key)
+ self.inserted_primary_key = [
+ c is table._autoincrement_column and lastrowid or v
+ for c, v in zip(table.primary_key, self.inserted_primary_key)
]
def _fetch_implicit_returning(self, resultproxy):
@@ -628,22 +632,17 @@ class DefaultExecutionContext(base.ExecutionContext):
row = resultproxy.fetchone()
ipk = []
- for c, v in zip(table.primary_key, self._inserted_primary_key):
+ for c, v in zip(table.primary_key, self.inserted_primary_key):
if v is not None:
ipk.append(v)
else:
ipk.append(row[c])
- self._inserted_primary_key = ipk
-
- def last_inserted_params(self):
- return self._last_inserted_params
-
- def last_updated_params(self):
- return self._last_updated_params
-
+ self.inserted_primary_key = ipk
+
def lastrow_has_defaults(self):
- return hasattr(self, 'postfetch_cols') and len(self.postfetch_cols)
+ return (self.isinsert or self.isupdate) and \
+ bool(self.postfetch_cols)
def set_input_sizes(self, translate=None, exclude_types=None):
"""Given a cursor and ClauseParameters, call the appropriate
@@ -714,7 +713,7 @@ class DefaultExecutionContext(base.ExecutionContext):
return None
else:
return self._exec_default(column.onupdate)
-
+
def __process_defaults(self):
"""Generate default values for compiled insert/update statements,
and generate inserted_primary_key collection.
@@ -745,7 +744,6 @@ class DefaultExecutionContext(base.ExecutionContext):
if val is not None:
param[c.key] = val
del self.current_parameters
-
else:
self.current_parameters = compiled_parameters = \
self.compiled_parameters[0]
@@ -759,18 +757,13 @@ class DefaultExecutionContext(base.ExecutionContext):
if val is not None:
compiled_parameters[c.key] = val
del self.current_parameters
-
+
if self.isinsert:
- self._inserted_primary_key = [
- compiled_parameters.get(c.key, None)
- for c in self.compiled.\
+ self.inserted_primary_key = [
+ self.compiled_parameters[0].get(c.key, None)
+ for c in self.compiled.\
statement.table.primary_key
- ]
- self._last_inserted_params = compiled_parameters
- else:
- self._last_updated_params = compiled_parameters
+ ]
- self.postfetch_cols = self.compiled.postfetch
- self.prefetch_cols = self.compiled.prefetch
DefaultDialect.execution_ctx_cls = DefaultExecutionContext
diff --git a/lib/sqlalchemy/ext/horizontal_shard.py b/lib/sqlalchemy/ext/horizontal_shard.py
index 78e3f5953..e48cb9fcb 100644
--- a/lib/sqlalchemy/ext/horizontal_shard.py
+++ b/lib/sqlalchemy/ext/horizontal_shard.py
@@ -50,12 +50,12 @@ class ShardedSession(Session):
self.id_chooser = id_chooser
self.query_chooser = query_chooser
self.__binds = {}
- self._mapper_flush_opts = {'connection_callable':self.connection}
+ self.connection_callable = self.connection
self._query_cls = ShardedQuery
if shards is not None:
for k in shards:
self.bind_shard(k, shards[k])
-
+
def connection(self, mapper=None, instance=None, shard_id=None, **kwargs):
if shard_id is None:
shard_id = self.shard_chooser(mapper, instance)
diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py
index a4662770e..20242c97c 100644
--- a/lib/sqlalchemy/orm/mapper.py
+++ b/lib/sqlalchemy/orm/mapper.py
@@ -1468,9 +1468,9 @@ class Mapper(object):
# if session has a connection callable,
# organize individual states with the connection
# to use for update
- if 'connection_callable' in uowtransaction.mapper_flush_opts:
+ if uowtransaction.session.connection_callable:
connection_callable = \
- uowtransaction.mapper_flush_opts['connection_callable']
+ uowtransaction.session.connection_callable
else:
connection = uowtransaction.transaction.connection(self)
connection_callable = None
@@ -1550,15 +1550,10 @@ class Mapper(object):
of objects.
This is called within the context of a UOWTransaction during a
- flush operation.
+ flush operation, given a list of states to be flushed. The
+ base mapper in an inheritance hierarchy handles the inserts/
+ updates for all descendant mappers.
- `_save_obj` issues SQL statements not just for instances mapped
- directly by this mapper, but for instances mapped by all
- inheriting mappers as well. This is to maintain proper insert
- ordering among a polymorphic chain of instances. Therefore
- _save_obj is typically called only on a *base mapper*, or a
- mapper which does not inherit from any other mapper.
-
"""
# if batch=false, call _save_obj separately for each object
@@ -1572,9 +1567,9 @@ class Mapper(object):
# if session has a connection callable,
# organize individual states with the connection
# to use for insert/update
- if 'connection_callable' in uowtransaction.mapper_flush_opts:
+ if uowtransaction.session.connection_callable:
connection_callable = \
- uowtransaction.mapper_flush_opts['connection_callable']
+ uowtransaction.session.connection_callable
else:
connection = uowtransaction.transaction.connection(self)
connection_callable = None
@@ -1592,6 +1587,7 @@ class Mapper(object):
instance_key = state.key or mapper._identity_key_from_state(state)
row_switch = None
+
# call before_XXX extensions
if not has_identity:
mapper.dispatch.on_before_insert(mapper, conn, state)
@@ -1652,9 +1648,9 @@ class Mapper(object):
params = {}
value_params = {}
- hasdata = False
-
+
if isinsert:
+ has_all_pks = True
for col in mapper._cols_by_table[table]:
if col is mapper.version_id_col:
params[col.key] = \
@@ -1668,19 +1664,21 @@ class Mapper(object):
value = prop.get_col_value(col, value)
if value is None:
- if col.default is None and \
- col.server_default is None and \
- col not in pks:
-
+ if col in pks:
+ has_all_pks = False
+ elif col.default is None and \
+ col.server_default is None:
params[col.key] = value
+
elif isinstance(value, sql.ClauseElement):
value_params[col] = value
else:
params[col.key] = value
insert.append((state, state_dict, params, mapper,
- connection, value_params))
+ connection, value_params, has_all_pks))
else:
+ hasdata = False
for col in mapper._cols_by_table[table]:
if col is mapper.version_id_col:
params[col._label] = \
@@ -1762,7 +1760,8 @@ class Mapper(object):
else:
hasdata = True
elif col in pks:
- value = state.manager[prop.key].impl.get(state, state_dict)
+ value = state.manager[prop.key].\
+ impl.get(state, state_dict)
if prop.get_col_value:
value = prop.get_col_value(col, value)
params[col._label] = value
@@ -1803,11 +1802,16 @@ class Mapper(object):
else:
c = cached_connections[connection].\
execute(statement, params)
-
- mapper._postfetch(uowtransaction, table,
- state, state_dict, c,
- c.last_updated_params(), value_params)
-
+
+ mapper._postfetch(
+ uowtransaction,
+ table,
+ state,
+ state_dict,
+ c.context.prefetch_cols,
+ c.context.postfetch_cols,
+ c.context.compiled_parameters[0],
+ value_params)
rows += c.rowcount
if connection.dialect.supports_sane_rowcount:
@@ -1826,37 +1830,71 @@ class Mapper(object):
if insert:
statement = self._memo(('insert', table), table.insert)
- for state, state_dict, params, mapper, \
- connection, value_params in insert:
+ for (connection, pkeys, hasvalue, has_all_pks), \
+ records in groupby(insert,
+ lambda rec: (rec[4],
+ rec[2].keys(),
+ bool(rec[5]),
+ rec[6])
+ ):
+ if has_all_pks and not hasvalue:
+ records = list(records)
+ multiparams = [rec[2] for rec in records]
+ c = cached_connections[connection].\
+ execute(statement, multiparams)
+
+ for (state, state_dict, params, mapper,
+ conn, value_params, has_all_pks), \
+ last_inserted_params in \
+ zip(records, c.context.compiled_parameters):
+ mapper._postfetch(
+ uowtransaction,
+ table,
+ state,
+ state_dict,
+ c.context.prefetch_cols,
+ c.context.postfetch_cols,
+ last_inserted_params,
+ value_params)
+
+ else:
+ for state, state_dict, params, mapper, \
+ connection, value_params, \
+ has_all_pks in records:
- if value_params:
- c = connection.execute(
+ if value_params:
+ result = connection.execute(
statement.values(value_params),
params)
- else:
- c = cached_connections[connection].\
- execute(statement, params)
+ else:
+ result = cached_connections[connection].\
+ execute(statement, params)
- primary_key = c.inserted_primary_key
-
- if primary_key is not None:
- # set primary key attributes
- for pk, col in zip(primary_key,
- mapper._pks_by_table[table]):
- # TODO: make sure this inlined code is OK
- # with composites
- prop = mapper._columntoproperty[col]
- if state_dict.get(prop.key) is None:
- # TODO: would rather say:
- #state_dict[prop.key] = pk
- mapper._set_state_attr_by_column(state,
- state_dict,
- col, pk)
-
- mapper._postfetch(uowtransaction, table,
- state, state_dict, c,
- c.last_inserted_params(),
- value_params)
+ primary_key = result.context.inserted_primary_key
+
+ if primary_key is not None:
+ # set primary key attributes
+ for pk, col in zip(primary_key,
+ mapper._pks_by_table[table]):
+ prop = mapper._columntoproperty[col]
+ if state_dict.get(prop.key) is None:
+ # TODO: would rather say:
+ #state_dict[prop.key] = pk
+ mapper._set_state_attr_by_column(
+ state,
+ state_dict,
+ col, pk)
+
+ mapper._postfetch(
+ uowtransaction,
+ table,
+ state,
+ state_dict,
+ result.context.prefetch_cols,
+ result.context.postfetch_cols,
+ result.context.compiled_parameters[0],
+ value_params)
+
for state, state_dict, mapper, connection, has_identity, \
instance_key, row_switch in tups:
@@ -1883,18 +1921,15 @@ class Mapper(object):
mapper.dispatch.on_after_update(mapper, connection, state)
def _postfetch(self, uowtransaction, table,
- state, dict_, resultproxy,
- params, value_params):
+ state, dict_, prefetch_cols, postfetch_cols,
+ params, value_params):
"""During a flush, expire attributes in need of newly
persisted database state."""
- postfetch_cols = resultproxy.postfetch_cols()
- generated_cols = list(resultproxy.prefetch_cols())
-
if self.version_id_col is not None:
- generated_cols.append(self.version_id_col)
+ prefetch_cols = list(prefetch_cols) + [self.version_id_col]
- for c in generated_cols:
+ for c in prefetch_cols:
if c.key in params and c in self._columntoproperty:
self._set_state_attr_by_column(state, dict_, c, params[c.key])
@@ -1937,9 +1972,9 @@ class Mapper(object):
flush operation.
"""
- if 'connection_callable' in uowtransaction.mapper_flush_opts:
+ if uowtransaction.session.connection_callable:
connection_callable = \
- uowtransaction.mapper_flush_opts['connection_callable']
+ uowtransaction.session.connection_callable
else:
connection = uowtransaction.transaction.connection(self)
connection_callable = None
diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py
index 3517eab2b..30a84bf1a 100644
--- a/lib/sqlalchemy/orm/session.py
+++ b/lib/sqlalchemy/orm/session.py
@@ -511,7 +511,6 @@ class Session(object):
self._enable_transaction_accounting = _enable_transaction_accounting
self.twophase = twophase
self._query_cls = query_cls
- self._mapper_flush_opts = {}
if extension:
for ext in util.to_list(extension):
@@ -530,6 +529,8 @@ class Session(object):
dispatch = event.dispatcher(SessionEvents)
+ connection_callable = None
+
def begin(self, subtransactions=False, nested=False):
"""Begin a transaction on this Session.
diff --git a/lib/sqlalchemy/orm/unitofwork.py b/lib/sqlalchemy/orm/unitofwork.py
index 875ce634b..d9d64fe39 100644
--- a/lib/sqlalchemy/orm/unitofwork.py
+++ b/lib/sqlalchemy/orm/unitofwork.py
@@ -76,7 +76,6 @@ class UOWEventHandler(interfaces.AttributeExtension):
class UOWTransaction(object):
def __init__(self, session):
self.session = session
- self.mapper_flush_opts = session._mapper_flush_opts
# dictionary used by external actors to
# store arbitrary state information.
@@ -316,7 +315,7 @@ class UOWTransaction(object):
postsort_actions):
rec.execute(self)
-
+
def finalize_flush_changes(self):
"""mark processed objects as clean / deleted after a successful flush().
diff --git a/test/orm/test_unitofworkv2.py b/test/orm/test_unitofworkv2.py
index 73a884e0c..766addc05 100644
--- a/test/orm/test_unitofworkv2.py
+++ b/test/orm/test_unitofworkv2.py
@@ -4,7 +4,8 @@ from test.lib.schema import Table, Column
from sqlalchemy import Integer, String, ForeignKey, func
from test.orm import _fixtures, _base
from sqlalchemy.orm import mapper, relationship, backref, \
- create_session, unitofwork, attributes
+ create_session, unitofwork, attributes,\
+ Session
from test.lib.assertsql import AllOf, CompiledSQL
from test.orm._fixtures import keywords, addresses, Base, Keyword, \
@@ -776,5 +777,72 @@ class RowswitchAccountingTest(_base.MappedTest):
sess.flush()
+class BatchInsertsTest(_base.MappedTest, testing.AssertsExecutionResults):
+ @classmethod
+ def define_tables(cls, metadata):
+ Table('t', metadata,
+ Column('id', Integer, primary_key=True,
+ test_needs_autoincrement=True),
+ Column('data', String(50)),
+ Column('def_', String(50), server_default='def1')
+ )
+ @testing.resolve_artifact_names
+ def test_batch_interaction(self):
+ """test batching groups same-structured, primary
+ key present statements together.
+
+ """
+ class T(Base):
+ pass
+ mapper(T, t)
+ sess = Session()
+ sess.add_all([
+ T(data='t1'),
+ T(data='t2'),
+ T(id=3, data='t3'),
+ T(id=4, data='t4'),
+ T(id=5, data='t5'),
+ T(id=6, data=func.lower('t6')),
+ T(id=7, data='t7'),
+ T(id=8, data='t8'),
+ T(id=9, data='t9', def_='def2'),
+ T(id=10, data='t10', def_='def3'),
+ T(id=11, data='t11'),
+ ])
+ self.assert_sql_execution(
+ testing.db,
+ sess.flush,
+ CompiledSQL(
+ "INSERT INTO t (data) VALUES (:data)",
+ {'data': 't1'}
+ ),
+ CompiledSQL(
+ "INSERT INTO t (data) VALUES (:data)",
+ {'data': 't2'}
+ ),
+ CompiledSQL(
+ "INSERT INTO t (id, data) VALUES (:id, :data)",
+ [{'data': 't3', 'id': 3},
+ {'data': 't4', 'id': 4},
+ {'data': 't5', 'id': 5}]
+ ),
+ CompiledSQL(
+ "INSERT INTO t (id, data) VALUES (:id, lower(:lower_1))",
+ {'lower_1': 't6', 'id': 6}
+ ),
+ CompiledSQL(
+ "INSERT INTO t (id, data) VALUES (:id, :data)",
+ [{'data': 't7', 'id': 7}, {'data': 't8', 'id': 8}]
+ ),
+ CompiledSQL(
+ "INSERT INTO t (id, data, def_) VALUES (:id, :data, :def_)",
+ [{'data': 't9', 'id': 9, 'def_':'def2'},
+ {'data': 't10', 'id': 10, 'def_':'def3'}]
+ ),
+ CompiledSQL(
+ "INSERT INTO t (id, data) VALUES (:id, :data)",
+ {'data': 't11', 'id': 11}
+ ),
+ )
diff --git a/test/sql/test_query.py b/test/sql/test_query.py
index f59b34076..e14f5301e 100644
--- a/test/sql/test_query.py
+++ b/test/sql/test_query.py
@@ -654,7 +654,24 @@ class QueryTest(TestBase):
getattr(result, meth),
)
trans.rollback()
-
+
+ def test_no_inserted_pk_on_non_insert(self):
+ result = testing.db.execute("select * from query_users")
+ assert_raises_message(
+ exc.InvalidRequestError,
+ r"Statement is not an insert\(\) expression construct.",
+ getattr, result, 'inserted_primary_key'
+ )
+
+ @testing.requires.returning
+ def test_no_inserted_pk_on_returning(self):
+ result = testing.db.execute(users.insert().returning(users.c.user_id, users.c.user_name))
+ assert_raises_message(
+ exc.InvalidRequestError,
+ r"Can't call inserted_primary_key when returning\(\) is used.",
+ getattr, result, 'inserted_primary_key'
+ )
+
def test_fetchone_til_end(self):
result = testing.db.execute("select * from query_users")
eq_(result.fetchone(), None)