summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--doc/build/changelog/changelog_09.rst31
-rw-r--r--doc/build/core/dml.rst8
-rw-r--r--doc/build/core/types.rst2
-rw-r--r--doc/build/glossary.rst38
-rw-r--r--doc/build/orm/exceptions.rst1
-rw-r--r--doc/build/orm/mapper_config.rst196
-rw-r--r--doc/build/static/docs.css15
-rw-r--r--lib/sqlalchemy/engine/base.py6
-rw-r--r--lib/sqlalchemy/engine/default.py7
-rw-r--r--lib/sqlalchemy/engine/result.py10
-rw-r--r--lib/sqlalchemy/orm/mapper.py105
-rw-r--r--lib/sqlalchemy/orm/persistence.py125
-rw-r--r--lib/sqlalchemy/sql/compiler.py59
-rw-r--r--lib/sqlalchemy/sql/dml.py202
-rw-r--r--lib/sqlalchemy/sql/expression.py3
-rw-r--r--test/orm/test_unitofwork.py6
-rw-r--r--test/orm/test_versioning.py195
-rw-r--r--test/sql/test_returning.py122
18 files changed, 896 insertions, 235 deletions
diff --git a/doc/build/changelog/changelog_09.rst b/doc/build/changelog/changelog_09.rst
index 5d248acec..be7a7fc7c 100644
--- a/doc/build/changelog/changelog_09.rst
+++ b/doc/build/changelog/changelog_09.rst
@@ -7,6 +7,37 @@
:version: 0.9.0
.. change::
+ :tags: feature, orm
+ :tickets: 2793
+
+ The ``version_id_generator`` parameter of ``Mapper`` can now be specified
+ to rely upon server generated version identifiers, using triggers
+ or other database-provided versioning features, by passing the value
+ ``False``. The ORM will use RETURNING when available to immediately
+ load the new version identifier, else it will emit a second SELECT.
+
+ .. change::
+ :tags: feature, orm
+ :tickets: 2793
+
+ The ``eager_defaults`` flag of :class:`.Mapper` will now allow the
+ newly generated default values to be fetched using an inline
+ RETURNING clause, rather than a second SELECT statement, for backends
+ that support RETURNING.
+
+ .. change::
+ :tags: feature, core
+ :tickets: 2793
+
+ Added a new variant to :meth:`.ValuesBase.returning` called
+ :meth:`.ValuesBase.return_defaults`; this allows arbitrary columns
+ to be added to the RETURNING clause of the statement without interfering
+ with the compilers usual "implicit returning" feature, which is used to
+ efficiently fetch newly generated primary key values. For supporting
+ backends, a dictionary of all fetched values is present at
+ :attr:`.ResultProxy.returned_defaults`.
+
+ .. change::
:tags: feature
Added a new flag ``system=True`` to :class:`.Column`, which marks
diff --git a/doc/build/core/dml.rst b/doc/build/core/dml.rst
index d2901c204..892d85921 100644
--- a/doc/build/core/dml.rst
+++ b/doc/build/core/dml.rst
@@ -26,13 +26,11 @@ constructs build on the intermediary :class:`.ValuesBase`.
:members:
:inherited-members:
-
-.. autoclass:: UpdateBase
+.. autoclass:: sqlalchemy.sql.expression.UpdateBase
:members:
-
-.. autoclass:: ValuesBase
- :members:
+.. autoclass:: sqlalchemy.sql.expression.ValuesBase
+ :members:
diff --git a/doc/build/core/types.rst b/doc/build/core/types.rst
index ccbba5d24..a40363135 100644
--- a/doc/build/core/types.rst
+++ b/doc/build/core/types.rst
@@ -343,6 +343,8 @@ many decimal places. Here's a recipe that rounds them down::
value = value.quantize(self.quantize)
return value
+.. _custom_guid_type:
+
Backend-agnostic GUID Type
^^^^^^^^^^^^^^^^^^^^^^^^^^
diff --git a/doc/build/glossary.rst b/doc/build/glossary.rst
index 7c497df76..7fe4c16bd 100644
--- a/doc/build/glossary.rst
+++ b/doc/build/glossary.rst
@@ -521,3 +521,41 @@ Glossary
http://en.wikipedia.org/wiki/Durability_(database_systems)
+ RETURNING
+ This is a non-SQL standard clause provided in various forms by
+ certain backends, which provides the service of returning a result
+ set upon execution of an INSERT, UPDATE or DELETE statement. Any set
+ of columns from the matched rows can be returned, as though they were
+ produced from a SELECT statement.
+
+ The RETURNING clause provides both a dramatic performance boost to
+ common update/select scenarios, including retrieval of inline- or
+ default- generated primary key values and defaults at the moment they
+ were created, as well as a way to get at server-generated
+ default values in an atomic way.
+
+ An example of RETURNING, idiomatic to Postgresql, looks like::
+
+ INSERT INTO user_account (name) VALUES ('new name') RETURNING id, timestamp
+
+ Above, the INSERT statement will provide upon execution a result set
+ which includes the values of the columns ``user_account.id`` and
+ ``user_account.timestamp``, which above should have been generated as default
+ values as they are not included otherwise (but note any series of columns
+ or SQL expressions can be placed into RETURNING, not just default-value columns).
+
+ The backends that currently support
+ RETURNING or a similar construct are Postgresql, SQL Server, Oracle,
+ and Firebird. The Postgresql and Firebird implementations are generally
+ full featured, whereas the implementations of SQL Server and Oracle
+ have caveats. On SQL Server, the clause is known as "OUTPUT INSERTED"
+ for INSERT and UPDATE statements and "OUTPUT DELETED" for DELETE statements;
+ the key caveat is that triggers are not supported in conjunction with this
+ keyword. On Oracle, it is known as "RETURNING...INTO", and requires that the
+ value be placed into an OUT paramter, meaning not only is the syntax awkward,
+ but it can also only be used for one row at a time.
+
+ SQLAlchemy's :meth:`.UpdateBase.returning` system provides a layer of abstraction
+ on top of the RETURNING systems of these backends to provide a consistent
+ interface for returning columns. The ORM also includes many optimizations
+ that make use of RETURNING when available.
diff --git a/doc/build/orm/exceptions.rst b/doc/build/orm/exceptions.rst
index 7a760e26f..f95b26eed 100644
--- a/doc/build/orm/exceptions.rst
+++ b/doc/build/orm/exceptions.rst
@@ -2,5 +2,4 @@ ORM Exceptions
==============
.. automodule:: sqlalchemy.orm.exc
-
:members: \ No newline at end of file
diff --git a/doc/build/orm/mapper_config.rst b/doc/build/orm/mapper_config.rst
index d3dc1cf6b..88b256ae2 100644
--- a/doc/build/orm/mapper_config.rst
+++ b/doc/build/orm/mapper_config.rst
@@ -1115,6 +1115,202 @@ of these events.
.. autofunction:: reconstructor
+
+.. _mapper_version_counter:
+
+Configuring a Version Counter
+=============================
+
+The :class:`.Mapper` supports management of a :term:`version id column`, which
+is a single table column that increments or otherwise updates its value
+each time an ``UPDATE`` to the mapped table occurs. This value is checked each
+time the ORM emits an ``UPDATE`` or ``DELETE`` against the row to ensure that
+the value held in memory matches the database value.
+
+The purpose of this feature is to detect when two concurrent transactions
+are modifying the same row at roughly the same time, or alternatively to provide
+a guard against the usage of a "stale" row in a system that might be re-using
+data from a previous transaction without refreshing (e.g. if one sets ``expire_on_commit=False``
+with a :class:`.Session`, it is possible to re-use the data from a previous
+transaction).
+
+.. topic:: Concurrent transaction updates
+
+ When detecting concurrent updates within transactions, it is typically the
+ case that the database's transaction isolation level is below the level of
+ :term:`repeatable read`; otherwise, the transaction will not be exposed
+ to a new row value created by a concurrent update which conflicts with
+ the locally updated value. In this case, the SQLAlchemy versioning
+ feature will typically not be useful for in-transaction conflict detection,
+ though it still can be used for cross-transaction staleness detection.
+
+ The database that enforces repeatable reads will typically either have locked the
+ target row against a concurrent update, or is employing some form
+ of multi version concurrency control such that it will emit an error
+ when the transaction is committed. SQLAlchemy's version_id_col is an alternative
+ which allows version tracking to occur for specific tables within a transaction
+ that otherwise might not have this isolation level set.
+
+ .. seealso::
+
+ `Repeatable Read Isolation Level <http://www.postgresql.org/docs/9.1/static/transaction-iso.html#XACT-REPEATABLE-READ>`_ - Postgresql's implementation of repeatable read, including a description of the error condition.
+
+Simple Version Counting
+-----------------------
+
+The most straightforward way to track versions is to add an integer column
+to the mapped table, then establish it as the ``version_id_col`` within the
+mapper options::
+
+ class User(Base):
+ __tablename__ = 'user'
+
+ id = Column(Integer, primary_key=True)
+ version_id = Column(Integer, nullable=False)
+ name = Column(String(50), nullable=False)
+
+ __mapper_args__ = {
+ "version_id_col": version_id
+ }
+
+Above, the ``User`` mapping tracks integer versions using the column
+``version_id``. When an object of type ``User`` is first flushed, the
+``version_id`` column will be given a value of "1". Then, an UPDATE
+of the table later on will always be emitted in a manner similar to the
+following::
+
+ UPDATE user SET version_id=:version_id, name=:name
+ WHERE user.id = :user_id AND user.version_id = :user_version_id
+ {"name": "new name", "version_id": 2, "user_id": 1, "user_version_id": 1}
+
+The above UPDATE statement is updating the row that not only matches
+``user.id = 1``, it also is requiring that ``user.version_id = 1``, where "1"
+is the last version identifier we've been known to use on this object.
+If a transaction elsewhere has modifed the row independently, this version id
+will no longer match, and the UPDATE statement will report that no rows matched;
+this is the condition that SQLAlchemy tests, that exactly one row matched our
+UPDATE (or DELETE) statement. If zero rows match, that indicates our version
+of the data is stale, and a :class:`.StaleDataError` is raised.
+
+.. _custom_version_counter:
+
+Custom Version Counters / Types
+-------------------------------
+
+Other kinds of values or counters can be used for versioning. Common types include
+dates and GUIDs. When using an alternate type or counter scheme, SQLAlchemy
+provides a hook for this scheme using the ``version_id_generator`` argument,
+which accepts a version generation callable. This callable is passed the value of the current
+known version, and is expected to return the subsequent version.
+
+For example, if we wanted to track the versioning of our ``User`` class
+using a randomly generated GUID, we could do this (note that some backends
+support a native GUID type, but we illustrate here using a simple string)::
+
+ import uuid
+
+ class User(Base):
+ __tablename__ = 'user'
+
+ id = Column(Integer, primary_key=True)
+ version_uuid = Column(String(32))
+ name = Column(String(50), nullable=False)
+
+ __mapper_args__ = {
+ 'version_id_col':version_uuid,
+ 'version_id_generator':lambda version: uuid.uuid4().hex
+ }
+
+The persistence engine will call upon ``uuid.uuid4()`` each time a
+``User`` object is subject to an INSERT or an UPDATE. In this case, our
+version generation function can disregard the incoming value of ``version``,
+as the ``uuid4()`` function
+generates identifiers without any prerequisite value. If we were using
+a sequential versioning scheme such as numeric or a special character system,
+we could make use of the given ``version`` in order to help determine the
+subsequent value.
+
+.. seealso::
+
+ :ref:`custom_guid_type`
+
+.. _server_side_version_counter:
+
+Server Side Version Counters
+-----------------------------
+
+The ``version_id_generator`` can also be configured to rely upon a value
+that is generated by the database. In this case, the database would need
+some means of generating new identifiers when a row is subject to an INSERT
+as well as with an UPDATE. For the UPDATE case, typically an update trigger
+is needed, unless the database in question supports some other native
+version identifier. The Postgresql database in particular supports a system
+column called `xmin <http://www.postgresql.org/docs/9.1/static/ddl-system-columns.html>`_
+which provides UPDATE versioning. We can make use
+of the Postgresql ``xmin`` column to version our ``User``
+class as follows::
+
+ class User(Base):
+ __tablename__ = 'user'
+
+ id = Column(Integer, primary_key=True)
+ name = Column(String(50), nullable=False)
+ xmin = Column("xmin", Integer, system=True)
+
+ __mapper_args__ = {
+ 'version_id_col': xmin,
+ 'version_id_generator': False
+ }
+
+With the above mapping, the ORM will rely upon the ``xmin`` column for
+automatically providing the new value of the version id counter.
+
+.. topic:: creating tables that refer to system columns
+
+ In the above scenario, as ``xmin`` is a system column provided by Postgresql,
+ we use the ``system=True`` argument to mark it as a system-provided
+ column, omitted from the ``CREATE TABLE`` statement.
+
+
+The ORM typically does not actively fetch the values of database-generated
+values when it emits an INSERT or UPDATE, instead leaving these columns as
+"expired" and to be fetched when they are next accessed. However, when a
+server side version column is used, the ORM needs to actively fetch the newly
+generated value. This is so that the version counter is set up *before*
+any concurrent transaction may update it again. This fetching is also
+best done simultaneously within the INSERT or UPDATE statement using :term:`RETURNING`,
+otherwise if emitting a SELECT statement afterwards, there is still a potential
+race condition where the version counter may change before it can be fetched.
+
+When the target database supports RETURNING, an INSERT statement for our ``User`` class will look
+like this::
+
+ INSERT INTO "user" (name) VALUES (%(name)s) RETURNING "user".id, "user".xmin
+ {'name': 'ed'}
+
+Where above, the ORM can acquire any newly generated primary key values along
+with server-generated version identifiers in one statement. When the backend
+does not support RETURNING, an additional SELECT must be emitted for **every**
+INSERT, which is much less efficient, and also introduces the possibility of
+missed version counters::
+
+ INSERT INTO "user" (name) VALUES (%(name)s) RETURNING "user".id, "user".version_id
+ {'name': 'ed'}
+
+ SELECT "user".version_id AS user_version_id FROM "user" where
+ "user".id = :param_1
+ {"param_1": 1}
+
+It is *strongly recommended* that server side version counters only be used
+when absolutely necessary and only on backends that support :term:`RETURNING`,
+e.g. Postgresql, Oracle, SQL Server (though SQL Server has
+`major caveats <http://blogs.msdn.com/b/sqlprogrammability/archive/2008/07/11/update-with-output-clause-triggers-and-sqlmoreresults.aspx>`_ when triggers are used), Firebird.
+
+.. versionadded:: 0.9.0
+
+ Support for server side version identifier tracking.
+
+
Class Mapping API
=================
diff --git a/doc/build/static/docs.css b/doc/build/static/docs.css
index 09269487b..191a2041c 100644
--- a/doc/build/static/docs.css
+++ b/doc/build/static/docs.css
@@ -343,6 +343,21 @@ div.admonition, div.topic, .deprecated, .versionadded, .versionchanged {
box-shadow: 2px 2px 3px #DFDFDF;
}
+
+div.sidebar {
+ background-color: #FFFFEE;
+ border: 1px solid #DDDDBB;
+ float: right;
+ margin: 0 0 0.5em 1em;
+ padding: 7px 7px 0;
+ width: 40%;
+ font-size:.9em;
+}
+
+p.sidebar-title {
+ font-weight: bold;
+}
+
/* grrr sphinx changing your document structures, removing classes.... */
.versionadded .versionmodified,
diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py
index 83fa34f2c..257eaa18a 100644
--- a/lib/sqlalchemy/engine/base.py
+++ b/lib/sqlalchemy/engine/base.py
@@ -898,6 +898,12 @@ class Connection(Connectable):
elif not context._is_explicit_returning:
result.close(_autoclose_connection=False)
result._metadata = None
+ elif context.isupdate:
+ if context._is_implicit_returning:
+ context._fetch_implicit_update_returning(result)
+ result.close(_autoclose_connection=False)
+ result._metadata = None
+
elif result._metadata is None:
# no results, get rowcount
# (which requires open cursor on some drivers
diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py
index 017dfa902..90c7f5993 100644
--- a/lib/sqlalchemy/engine/default.py
+++ b/lib/sqlalchemy/engine/default.py
@@ -396,6 +396,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
statement = None
postfetch_cols = None
prefetch_cols = None
+ returning_cols = None
_is_implicit_returning = False
_is_explicit_returning = False
@@ -492,6 +493,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
if self.isinsert or self.isupdate:
self.postfetch_cols = self.compiled.postfetch
self.prefetch_cols = self.compiled.prefetch
+ self.returning_cols = self.compiled.returning
self.__process_defaults()
processors = compiled._bind_processors
@@ -750,6 +752,11 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
ipk.append(row[c])
self.inserted_primary_key = ipk
+ self.returned_defaults = row
+
+ def _fetch_implicit_update_returning(self, resultproxy):
+ row = resultproxy.fetchone()
+ self.returned_defaults = row
def lastrow_has_defaults(self):
return (self.isinsert or self.isupdate) and \
diff --git a/lib/sqlalchemy/engine/result.py b/lib/sqlalchemy/engine/result.py
index 91f3c2275..0e2316573 100644
--- a/lib/sqlalchemy/engine/result.py
+++ b/lib/sqlalchemy/engine/result.py
@@ -621,6 +621,16 @@ class ResultProxy(object):
else:
return self.context.compiled_parameters[0]
+ @property
+ def returned_defaults(self):
+ """Return the values of default columns that were fetched using
+ the ``returned_defaults`` feature.
+
+ .. versionadded:: 0.9.0
+
+ """
+ return self.context.returned_defaults
+
def lastrow_has_defaults(self):
"""Return ``lastrow_has_defaults()`` from the underlying
:class:`.ExecutionContext`.
diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py
index fc80c4404..30b5ffc79 100644
--- a/lib/sqlalchemy/orm/mapper.py
+++ b/lib/sqlalchemy/orm/mapper.py
@@ -208,6 +208,22 @@ class Mapper(_InspectionAttr):
See the section :ref:`concrete_inheritance` for an example.
+ :param eager_defaults: if True, the ORM will immediately fetch the
+ value of server-generated default values after an INSERT or UPDATE,
+ rather than leaving them as expired to be fetched on next access.
+ This can be used for event schemes where the server-generated values
+ are needed immediately before the flush completes. By default,
+ this scheme will emit an individual ``SELECT`` statement per row
+ inserted or updated, which note can add significant performance
+ overhead. However, if the
+ target database supports :term:`RETURNING`, the default values will be
+ returned inline with the INSERT or UPDATE statement, which can
+ greatly enhance performance for an application that needs frequent
+ access to just-generated server defaults.
+
+ .. versionchanged:: 0.9.0 The ``eager_defaults`` option can now
+ make use of :term:`RETURNING` for backends which support it.
+
:param exclude_properties: A list or set of string column names to
be excluded from mapping.
@@ -391,9 +407,9 @@ class Mapper(_InspectionAttr):
thus persisting the value to the ``discriminator`` column
in the database.
- See also:
+ .. seealso::
- :ref:`inheritance_toplevel`
+ :ref:`inheritance_toplevel`
:param polymorphic_identity: Specifies the value which
identifies this particular class as returned by the
@@ -419,34 +435,44 @@ class Mapper(_InspectionAttr):
can be overridden here.
:param version_id_col: A :class:`.Column`
- that will be used to keep a running version id of mapped entities
- in the database. This is used during save operations to ensure that
- no other thread or process has updated the instance during the
- lifetime of the entity, else a
+ that will be used to keep a running version id of rows
+ in the table. This is used to detect concurrent updates or
+ the presence of stale data in a flush. The methodology is to
+ detect if an UPDATE statement does not match the last known
+ version id, a
:class:`~sqlalchemy.orm.exc.StaleDataError` exception is
- thrown. By default the column must be of :class:`.Integer` type,
- unless ``version_id_generator`` specifies a new generation
- algorithm.
+ thrown.
+ By default, the column must be of :class:`.Integer` type,
+ unless ``version_id_generator`` specifies an alternative version
+ generator.
- :param version_id_generator: A callable which defines the algorithm
- used to generate new version ids. Defaults to an integer
- generator. Can be replaced with one that generates timestamps,
- uuids, etc. e.g.::
+ .. seealso::
- import uuid
+ :ref:`mapper_version_counter` - discussion of version counting
+ and rationale.
- class MyClass(Base):
- __tablename__ = 'mytable'
- id = Column(Integer, primary_key=True)
- version_uuid = Column(String(32))
+ :param version_id_generator: Define how new version ids should
+ be generated. Defaults to ``None``, which indicates that
+ a simple integer counting scheme be employed. To provide a custom
+ versioning scheme, provide a callable function of the form::
- __mapper_args__ = {
- 'version_id_col':version_uuid,
- 'version_id_generator':lambda version:uuid.uuid4().hex
- }
+ def generate_version(version):
+ return next_version
+
+ Alternatively, server-side versioning functions such as triggers
+ may be used as well, by specifying the value ``False``.
+ Please see :ref:`server_side_version_counter` for a discussion
+ of important points when using this option.
+
+ .. versionadded:: 0.9.0 ``version_id_generator`` supports server-side
+ version number generation.
+
+ .. seealso::
+
+ :ref:`custom_version_counter`
+
+ :ref:`server_side_version_counter`
- The callable receives the current version identifier as its
- single argument.
:param with_polymorphic: A tuple in the form ``(<classes>,
<selectable>)`` indicating the default style of "polymorphic"
@@ -458,13 +484,9 @@ class Mapper(_InspectionAttr):
indicates a selectable that will be used to query for multiple
classes.
- See also:
-
- :ref:`concrete_inheritance` - typically uses ``with_polymorphic``
- to specify a UNION statement to select from.
+ .. seealso::
- :ref:`with_polymorphic` - usage example of the related
- :meth:`.Query.with_polymorphic` method
+ :ref:`with_polymorphic` - discussion of polymorphic querying techniques.
"""
@@ -481,9 +503,19 @@ class Mapper(_InspectionAttr):
self.order_by = order_by
self.always_refresh = always_refresh
- self.version_id_col = version_id_col
- self.version_id_generator = version_id_generator or \
- (lambda x: (x or 0) + 1)
+
+ if isinstance(version_id_col, MapperProperty):
+ self.version_id_prop = version_id_col
+ self.version_id_col = None
+ else:
+ self.version_id_col = version_id_col
+ if version_id_generator is False:
+ self.version_id_generator = False
+ elif version_id_generator is None:
+ self.version_id_generator = lambda x: (x or 0) + 1
+ else:
+ self.version_id_generator = version_id_generator
+
self.concrete = concrete
self.single = False
self.inherits = inherits
@@ -1406,6 +1438,13 @@ class Mapper(_InspectionAttr):
_validate_polymorphic_identity = None
@_memoized_configured_property
+ def _version_id_prop(self):
+ if self.version_id_col is not None:
+ return self._columntoproperty[self.version_id_col]
+ else:
+ return None
+
+ @_memoized_configured_property
def _acceptable_polymorphic_identities(self):
identities = set()
diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py
index 14970ef25..042186179 100644
--- a/lib/sqlalchemy/orm/persistence.py
+++ b/lib/sqlalchemy/orm/persistence.py
@@ -61,7 +61,7 @@ def save_obj(base_mapper, states, uowtransaction, single=False):
if insert:
_emit_insert_statements(base_mapper, uowtransaction,
cached_connections,
- table, insert)
+ mapper, table, insert)
_finalize_insert_update_commands(base_mapper, uowtransaction,
states_to_insert, states_to_update)
@@ -246,9 +246,12 @@ def _collect_insert_commands(base_mapper, uowtransaction, table,
value_params = {}
has_all_pks = True
+ has_all_defaults = True
for col in mapper._cols_by_table[table]:
if col is mapper.version_id_col:
- params[col.key] = mapper.version_id_generator(None)
+ if mapper.version_id_generator is not False:
+ val = mapper.version_id_generator(None)
+ params[col.key] = val
else:
# pull straight from the dict for
# pending objects
@@ -261,6 +264,9 @@ def _collect_insert_commands(base_mapper, uowtransaction, table,
elif col.default is None and \
col.server_default is None:
params[col.key] = value
+ elif col.server_default is not None and \
+ mapper.base_mapper.eager_defaults:
+ has_all_defaults = False
elif isinstance(value, sql.ClauseElement):
value_params[col] = value
@@ -268,7 +274,8 @@ def _collect_insert_commands(base_mapper, uowtransaction, table,
params[col.key] = value
insert.append((state, state_dict, params, mapper,
- connection, value_params, has_all_pks))
+ connection, value_params, has_all_pks,
+ has_all_defaults))
return insert
@@ -315,19 +322,20 @@ def _collect_update_commands(base_mapper, uowtransaction,
params[col.key] = history.added[0]
hasdata = True
else:
- params[col.key] = mapper.version_id_generator(
- params[col._label])
-
- # HACK: check for history, in case the
- # history is only
- # in a different table than the one
- # where the version_id_col is.
- for prop in mapper._columntoproperty.values():
- history = attributes.get_state_history(
- state, prop.key,
- attributes.PASSIVE_NO_INITIALIZE)
- if history.added:
- hasdata = True
+ if mapper.version_id_generator is not False:
+ val = mapper.version_id_generator(params[col._label])
+ params[col.key] = val
+
+ # HACK: check for history, in case the
+ # history is only
+ # in a different table than the one
+ # where the version_id_col is.
+ for prop in mapper._columntoproperty.values():
+ history = attributes.get_state_history(
+ state, prop.key,
+ attributes.PASSIVE_NO_INITIALIZE)
+ if history.added:
+ hasdata = True
else:
prop = mapper._columntoproperty[col]
history = attributes.get_state_history(
@@ -478,7 +486,13 @@ def _emit_update_statements(base_mapper, uowtransaction,
sql.bindparam(mapper.version_id_col._label,
type_=mapper.version_id_col.type))
- return table.update(clause)
+ stmt = table.update(clause)
+ if mapper.base_mapper.eager_defaults:
+ stmt = stmt.return_defaults()
+ elif mapper.version_id_col is not None:
+ stmt = stmt.return_defaults(mapper.version_id_col)
+
+ return stmt
statement = base_mapper._memo(('update', table), update_stmt)
@@ -500,8 +514,7 @@ def _emit_update_statements(base_mapper, uowtransaction,
table,
state,
state_dict,
- c.context.prefetch_cols,
- c.context.postfetch_cols,
+ c,
c.context.compiled_parameters[0],
value_params)
rows += c.rowcount
@@ -521,44 +534,55 @@ def _emit_update_statements(base_mapper, uowtransaction,
def _emit_insert_statements(base_mapper, uowtransaction,
- cached_connections, table, insert):
+ cached_connections, mapper, table, insert):
"""Emit INSERT statements corresponding to value lists collected
by _collect_insert_commands()."""
statement = base_mapper._memo(('insert', table), table.insert)
- for (connection, pkeys, hasvalue, has_all_pks), \
+ for (connection, pkeys, hasvalue, has_all_pks, has_all_defaults), \
records in groupby(insert,
lambda rec: (rec[4],
list(rec[2].keys()),
bool(rec[5]),
- rec[6])
+ rec[6], rec[7])
):
- if has_all_pks and not hasvalue:
+ if \
+ (
+ has_all_defaults
+ or not base_mapper.eager_defaults
+ or not connection.dialect.implicit_returning
+ ) and has_all_pks and not hasvalue:
+
records = list(records)
multiparams = [rec[2] for rec in records]
+
c = cached_connections[connection].\
execute(statement, multiparams)
- for (state, state_dict, params, mapper,
- conn, value_params, has_all_pks), \
+ for (state, state_dict, params, mapper_rec,
+ conn, value_params, has_all_pks, has_all_defaults), \
last_inserted_params in \
zip(records, c.context.compiled_parameters):
_postfetch(
- mapper,
+ mapper_rec,
uowtransaction,
table,
state,
state_dict,
- c.context.prefetch_cols,
- c.context.postfetch_cols,
+ c,
last_inserted_params,
value_params)
else:
- for state, state_dict, params, mapper, \
+ if not has_all_defaults and base_mapper.eager_defaults:
+ statement = statement.return_defaults()
+ elif mapper.version_id_col is not None:
+ statement = statement.return_defaults(mapper.version_id_col)
+
+ for state, state_dict, params, mapper_rec, \
connection, value_params, \
- has_all_pks in records:
+ has_all_pks, has_all_defaults in records:
if value_params:
result = connection.execute(
@@ -574,23 +598,22 @@ def _emit_insert_statements(base_mapper, uowtransaction,
# set primary key attributes
for pk, col in zip(primary_key,
mapper._pks_by_table[table]):
- prop = mapper._columntoproperty[col]
+ prop = mapper_rec._columntoproperty[col]
if state_dict.get(prop.key) is None:
# TODO: would rather say:
#state_dict[prop.key] = pk
- mapper._set_state_attr_by_column(
+ mapper_rec._set_state_attr_by_column(
state,
state_dict,
col, pk)
_postfetch(
- mapper,
+ mapper_rec,
uowtransaction,
table,
state,
state_dict,
- result.context.prefetch_cols,
- result.context.postfetch_cols,
+ result,
result.context.compiled_parameters[0],
value_params)
@@ -699,14 +722,25 @@ def _finalize_insert_update_commands(base_mapper, uowtransaction,
if readonly:
state._expire_attributes(state.dict, readonly)
- # if eager_defaults option is enabled,
- # refresh whatever has been expired.
+ # if eager_defaults option is enabled, load
+ # all expired cols. Else if we have a version_id_col, make sure
+ # it isn't expired.
+ toload_now = []
+
if base_mapper.eager_defaults and state.unloaded:
+ toload_now.extend(state.unloaded)
+ elif mapper.version_id_col is not None and \
+ mapper.version_id_generator is False:
+ prop = mapper._columntoproperty[mapper.version_id_col]
+ if prop.key in state.unloaded:
+ toload_now.extend([prop.key])
+
+ if toload_now:
state.key = base_mapper._identity_key_from_state(state)
loading.load_on_ident(
uowtransaction.session.query(base_mapper),
state.key, refresh_state=state,
- only_load_props=state.unloaded)
+ only_load_props=toload_now)
# call after_XXX extensions
if not has_identity:
@@ -716,15 +750,26 @@ def _finalize_insert_update_commands(base_mapper, uowtransaction,
def _postfetch(mapper, uowtransaction, table,
- state, dict_, prefetch_cols, postfetch_cols,
- params, value_params):
+ state, dict_, result, params, value_params):
"""Expire attributes in need of newly persisted database state,
after an INSERT or UPDATE statement has proceeded for that
state."""
+ prefetch_cols = result.context.prefetch_cols
+ postfetch_cols = result.context.postfetch_cols
+ returning_cols = result.context.returning_cols
+
if mapper.version_id_col is not None:
prefetch_cols = list(prefetch_cols) + [mapper.version_id_col]
+ if returning_cols:
+ row = result.context.returned_defaults
+ if row is not None:
+ for col in returning_cols:
+ if col.primary_key:
+ continue
+ mapper._set_state_attr_by_column(state, dict_, col, row[col])
+
for c in prefetch_cols:
if c.key in params and c in mapper._columntoproperty:
mapper._set_state_attr_by_column(state, dict_, c, params[c.key])
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 6370b1227..5d05cbc29 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -1761,11 +1761,12 @@ class SQLCompiler(Compiled):
'=' + c[1] for c in colparams
)
- if update_stmt._returning:
- self.returning = update_stmt._returning
+ if self.returning or update_stmt._returning:
+ if not self.returning:
+ self.returning = update_stmt._returning
if self.returning_precedes_values:
text += " " + self.returning_clause(
- update_stmt, update_stmt._returning)
+ update_stmt, self.returning)
if extra_froms:
extra_from_text = self.update_from_clause(
@@ -1785,7 +1786,7 @@ class SQLCompiler(Compiled):
if self.returning and not self.returning_precedes_values:
text += " " + self.returning_clause(
- update_stmt, update_stmt._returning)
+ update_stmt, self.returning)
self.stack.pop(-1)
@@ -1866,6 +1867,19 @@ class SQLCompiler(Compiled):
self.dialect.implicit_returning and \
stmt.table.implicit_returning
+ if self.isinsert:
+ implicit_return_defaults = implicit_returning and stmt._return_defaults
+ elif self.isupdate:
+ implicit_return_defaults = self.dialect.implicit_returning and \
+ stmt.table.implicit_returning and \
+ stmt._return_defaults
+
+ if implicit_return_defaults:
+ if stmt._return_defaults is True:
+ implicit_return_defaults = set(stmt.table.c)
+ else:
+ implicit_return_defaults = set(stmt._return_defaults)
+
postfetch_lastrowid = need_pks and self.dialect.postfetch_lastrowid
check_columns = {}
@@ -1928,6 +1942,10 @@ class SQLCompiler(Compiled):
elif c.primary_key and implicit_returning:
self.returning.append(c)
value = self.process(value.self_group())
+ elif implicit_return_defaults and \
+ c in implicit_return_defaults:
+ self.returning.append(c)
+ value = self.process(value.self_group())
else:
self.postfetch.append(c)
value = self.process(value.self_group())
@@ -1984,14 +2002,20 @@ class SQLCompiler(Compiled):
not self.dialect.sequences_optional):
proc = self.process(c.default)
values.append((c, proc))
- if not c.primary_key:
+ if implicit_return_defaults and \
+ c in implicit_return_defaults:
+ self.returning.append(c)
+ elif not c.primary_key:
self.postfetch.append(c)
elif c.default.is_clause_element:
values.append(
(c, self.process(c.default.arg.self_group()))
)
- if not c.primary_key:
+ if implicit_return_defaults and \
+ c in implicit_return_defaults:
+ self.returning.append(c)
+ elif not c.primary_key:
# dont add primary key column to postfetch
self.postfetch.append(c)
else:
@@ -2000,8 +2024,14 @@ class SQLCompiler(Compiled):
)
self.prefetch.append(c)
elif c.server_default is not None:
- if not c.primary_key:
+ if implicit_return_defaults and \
+ c in implicit_return_defaults:
+ self.returning.append(c)
+ elif not c.primary_key:
self.postfetch.append(c)
+ elif implicit_return_defaults and \
+ c in implicit_return_defaults:
+ self.returning.append(c)
elif self.isupdate:
if c.onupdate is not None and not c.onupdate.is_sequence:
@@ -2009,14 +2039,25 @@ class SQLCompiler(Compiled):
values.append(
(c, self.process(c.onupdate.arg.self_group()))
)
- self.postfetch.append(c)
+ if implicit_return_defaults and \
+ c in implicit_return_defaults:
+ self.returning.append(c)
+ else:
+ self.postfetch.append(c)
else:
values.append(
(c, self._create_crud_bind_param(c, None))
)
self.prefetch.append(c)
elif c.server_onupdate is not None:
- self.postfetch.append(c)
+ if implicit_return_defaults and \
+ c in implicit_return_defaults:
+ self.returning.append(c)
+ else:
+ self.postfetch.append(c)
+ elif implicit_return_defaults and \
+ c in implicit_return_defaults:
+ self.returning.append(c)
if parameters and stmt_parameters:
check = set(parameters).intersection(
diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py
index cbebf7d55..abbd05efe 100644
--- a/lib/sqlalchemy/sql/dml.py
+++ b/lib/sqlalchemy/sql/dml.py
@@ -104,9 +104,14 @@ class UpdateBase(HasPrefixes, Executable, ClauseElement):
read the documentation notes for the database in use in
order to determine the availability of RETURNING.
+ .. seealso::
+
+ :meth:`.ValuesBase.return_defaults`
+
"""
self._returning = cols
+
@_generative
def with_hint(self, text, selectable=None, dialect_name="*"):
"""Add a table hint for a single table to this
@@ -303,6 +308,58 @@ class ValuesBase(UpdateBase):
else:
self.parameters.update(kwargs)
+ @_generative
+ def return_defaults(self, *cols):
+ """If available, make use of a RETURNING clause for the purpose
+ of fetching server-side expressions and defaults.
+
+ When used against a backend that supports RETURNING, all column
+ values generated by SQL expression or server-side-default will be added
+ to any existing RETURNING clause, excluding one that is specified
+ by the :meth:`.UpdateBase.returning` method. The column values
+ will then be available on the result using the
+ :meth:`.ResultProxy.server_returned_defaults` method as a
+ dictionary, referring to values keyed to the :meth:`.Column` object
+ as well as its ``.key``.
+
+ This method differs from :meth:`.UpdateBase.returning` in these ways:
+
+ 1. It is compatible with any backend. Backends that don't support
+ RETURNING will skip the usage of the feature, rather than raising
+ an exception. The return value of :attr:`.ResultProxy.returned_defaults`
+ will be ``None``
+
+ 2. It is compatible with the existing logic to fetch auto-generated
+ primary key values, also known as "implicit returning". Backends that
+ support RETURNING will automatically make use of RETURNING in order
+ to fetch the value of newly generated primary keys; while the
+ :meth:`.UpdateBase.returning` method circumvents this behavior,
+ :meth:`.UpdateBase.return_defaults` leaves it intact.
+
+ 3. :meth:`.UpdateBase.returning` leaves the cursor's rows ready for
+ fetching using methods like :meth:`.ResultProxy.fetchone`, whereas
+ :meth:`.ValuesBase.return_defaults` fetches the row internally.
+ While all DBAPI backends observed so far seem to only support
+ RETURNING with single-row executions,
+ technically :meth:`.UpdateBase.returning` would support a backend
+ that can deliver multiple RETURNING rows as well. However
+ :meth:`.ValuesBase.return_defaults` is single-row by definition.
+
+ :param cols: optional list of column key names or :class:`.Column`
+ objects. If omitted, all column expressions evaulated on the server
+ are added to the returning list.
+
+ .. versionadded:: 0.9.0
+
+ .. seealso::
+
+ :meth:`.UpdateBase.returning`
+
+ :meth:`.ResultProxy.returned_defaults`
+
+ """
+ self._return_defaults = cols or True
+
class Insert(ValuesBase):
"""Represent an INSERT construct.
@@ -326,52 +383,15 @@ class Insert(ValuesBase):
bind=None,
prefixes=None,
returning=None,
+ return_defaults=False,
**kwargs):
- """Construct an :class:`.Insert` object.
-
- Similar functionality is available via the
- :meth:`~.TableClause.insert` method on
- :class:`~.schema.Table`.
-
- :param table: :class:`.TableClause` which is the subject of the insert.
-
- :param values: collection of values to be inserted; see
- :meth:`.Insert.values` for a description of allowed formats here.
- Can be omitted entirely; a :class:`.Insert` construct will also
- dynamically render the VALUES clause at execution time based on
- the parameters passed to :meth:`.Connection.execute`.
-
- :param inline: if True, SQL defaults will be compiled 'inline' into the
- statement and not pre-executed.
-
- If both `values` and compile-time bind parameters are present, the
- compile-time bind parameters override the information specified
- within `values` on a per-key basis.
-
- The keys within `values` can be either :class:`~sqlalchemy.schema.Column`
- objects or their string identifiers. Each key may reference one of:
-
- * a literal data value (i.e. string, number, etc.);
- * a Column object;
- * a SELECT statement.
-
- If a ``SELECT`` statement is specified which references this
- ``INSERT`` statement's table, the statement will be correlated
- against the ``INSERT`` statement.
-
- .. seealso::
-
- :ref:`coretutorial_insert_expressions` - SQL Expression Tutorial
-
- :ref:`inserts_and_updates` - SQL Expression Tutorial
-
- """
ValuesBase.__init__(self, table, values, prefixes)
self._bind = bind
self.select = None
self.inline = inline
self._returning = returning
self.kwargs = kwargs
+ self._return_defaults = return_defaults
def get_children(self, **kwargs):
if self.select is not None:
@@ -446,109 +466,8 @@ class Update(ValuesBase):
bind=None,
prefixes=None,
returning=None,
+ return_defaults=False,
**kwargs):
- """Construct an :class:`.Update` object.
-
- E.g.::
-
- from sqlalchemy import update
-
- stmt = update(users).where(users.c.id==5).\\
- values(name='user #5')
-
- Similar functionality is available via the
- :meth:`~.TableClause.update` method on
- :class:`.Table`::
-
- stmt = users.update().\\
- where(users.c.id==5).\\
- values(name='user #5')
-
- :param table: A :class:`.Table` object representing the database
- table to be updated.
-
- :param whereclause: Optional SQL expression describing the ``WHERE``
- condition of the ``UPDATE`` statement. Modern applications
- may prefer to use the generative :meth:`~Update.where()`
- method to specify the ``WHERE`` clause.
-
- The WHERE clause can refer to multiple tables.
- For databases which support this, an ``UPDATE FROM`` clause will
- be generated, or on MySQL, a multi-table update. The statement
- will fail on databases that don't have support for multi-table
- update statements. A SQL-standard method of referring to
- additional tables in the WHERE clause is to use a correlated
- subquery::
-
- users.update().values(name='ed').where(
- users.c.name==select([addresses.c.email_address]).\\
- where(addresses.c.user_id==users.c.id).\\
- as_scalar()
- )
-
- .. versionchanged:: 0.7.4
- The WHERE clause can refer to multiple tables.
-
- :param values:
- Optional dictionary which specifies the ``SET`` conditions of the
- ``UPDATE``. If left as ``None``, the ``SET``
- conditions are determined from those parameters passed to the
- statement during the execution and/or compilation of the
- statement. When compiled standalone without any parameters,
- the ``SET`` clause generates for all columns.
-
- Modern applications may prefer to use the generative
- :meth:`.Update.values` method to set the values of the
- UPDATE statement.
-
- :param inline:
- if True, SQL defaults present on :class:`.Column` objects via
- the ``default`` keyword will be compiled 'inline' into the statement
- and not pre-executed. This means that their values will not
- be available in the dictionary returned from
- :meth:`.ResultProxy.last_updated_params`.
-
- If both ``values`` and compile-time bind parameters are present, the
- compile-time bind parameters override the information specified
- within ``values`` on a per-key basis.
-
- The keys within ``values`` can be either :class:`.Column`
- objects or their string identifiers (specifically the "key" of the
- :class:`.Column`, normally but not necessarily equivalent to
- its "name"). Normally, the
- :class:`.Column` objects used here are expected to be
- part of the target :class:`.Table` that is the table
- to be updated. However when using MySQL, a multiple-table
- UPDATE statement can refer to columns from any of
- the tables referred to in the WHERE clause.
-
- The values referred to in ``values`` are typically:
-
- * a literal data value (i.e. string, number, etc.)
- * a SQL expression, such as a related :class:`.Column`,
- a scalar-returning :func:`.select` construct,
- etc.
-
- When combining :func:`.select` constructs within the values
- clause of an :func:`.update` construct,
- the subquery represented by the :func:`.select` should be
- *correlated* to the parent table, that is, providing criterion
- which links the table inside the subquery to the outer table
- being updated::
-
- users.update().values(
- name=select([addresses.c.email_address]).\\
- where(addresses.c.user_id==users.c.id).\\
- as_scalar()
- )
-
- .. seealso::
-
- :ref:`inserts_and_updates` - SQL Expression
- Language Tutorial
-
-
- """
ValuesBase.__init__(self, table, values, prefixes)
self._bind = bind
self._returning = returning
@@ -558,6 +477,7 @@ class Update(ValuesBase):
self._whereclause = None
self.inline = inline
self.kwargs = kwargs
+ self._return_defaults = return_defaults
def get_children(self, **kwargs):
diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py
index bbbe0b235..01091bc0a 100644
--- a/lib/sqlalchemy/sql/expression.py
+++ b/lib/sqlalchemy/sql/expression.py
@@ -49,7 +49,7 @@ from .selectable import Alias, Join, Select, Selectable, TableClause, \
subquery, HasPrefixes, Exists, ScalarSelect
-from .dml import Insert, Update, Delete
+from .dml import Insert, Update, Delete, UpdateBase, ValuesBase
# factory functions - these pull class-bound constructors and classmethods
# from SQL elements and selectables into public functions. This allows
@@ -101,6 +101,7 @@ from .elements import _literal_as_text, _clause_element_as_expr,\
from .selectable import _interpret_as_from
+
# old names for compatibility
_Executable = Executable
_BindParamClause = BindParameter
diff --git a/test/orm/test_unitofwork.py b/test/orm/test_unitofwork.py
index 4b9799d47..cbd981120 100644
--- a/test/orm/test_unitofwork.py
+++ b/test/orm/test_unitofwork.py
@@ -860,7 +860,11 @@ class DefaultTest(fixtures.MappedTest):
session = create_session()
session.add(h1)
- session.flush()
+
+ if testing.db.dialect.implicit_returning:
+ self.sql_count_(1, session.flush)
+ else:
+ self.sql_count_(2, session.flush)
self.sql_count_(0, lambda: eq_(h1.hoho, hohoval))
diff --git a/test/orm/test_versioning.py b/test/orm/test_versioning.py
index abb08c536..d8d92830f 100644
--- a/test/orm/test_versioning.py
+++ b/test/orm/test_versioning.py
@@ -11,7 +11,7 @@ from sqlalchemy.orm import mapper, relationship, Session, \
from sqlalchemy.testing import eq_, ne_, assert_raises, assert_raises_message
from sqlalchemy.testing import fixtures
from test.orm import _fixtures
-from sqlalchemy.testing import fixtures
+from sqlalchemy.testing.assertsql import AllOf, CompiledSQL
_uuids = [
@@ -461,12 +461,12 @@ class AlternateGeneratorTest(fixtures.MappedTest):
cls.classes.P)
mapper(P, p, version_id_col=p.c.version_id,
- version_id_generator=lambda x:make_uuid(),
+ version_id_generator=lambda x: make_uuid(),
properties={
- 'c':relationship(C, uselist=False, cascade='all, delete-orphan')
+ 'c': relationship(C, uselist=False, cascade='all, delete-orphan')
})
mapper(C, c, version_id_col=c.c.version_id,
- version_id_generator=lambda x:make_uuid(),
+ version_id_generator=lambda x: make_uuid(),
)
@testing.emits_warning_on('+zxjdbc', r'.*does not support updated rowcount')
@@ -643,3 +643,190 @@ class InheritanceTwoVersionIdsTest(fixtures.MappedTest):
mapper,
Sub, sub, inherits=Base,
version_id_col=sub.c.version_id)
+
+
+class ServerVersioningTest(fixtures.MappedTest):
+ run_define_tables = 'each'
+
+ @classmethod
+ def define_tables(cls, metadata):
+ from sqlalchemy.sql import ColumnElement
+ from sqlalchemy.ext.compiler import compiles
+ import itertools
+
+ counter = itertools.count(1)
+
+ class IncDefault(ColumnElement):
+ pass
+
+ @compiles(IncDefault)
+ def compile(element, compiler, **kw):
+ # cache the counter value on the statement
+ # itself so the assertsql system gets the same
+ # value when it compiles the statement a second time
+ stmt = compiler.statement
+ if hasattr(stmt, "_counter"):
+ return stmt._counter
+ else:
+ stmt._counter = str(counter.next())
+ return stmt._counter
+
+ Table('version_table', metadata,
+ Column('id', Integer, primary_key=True,
+ test_needs_autoincrement=True),
+ Column('version_id', Integer, nullable=False,
+ default=IncDefault(), onupdate=IncDefault()),
+ Column('value', String(40), nullable=False))
+
+ @classmethod
+ def setup_classes(cls):
+ class Foo(cls.Basic):
+ pass
+ class Bar(cls.Basic):
+ pass
+
+ def _fixture(self, expire_on_commit=True):
+ Foo, version_table = self.classes.Foo, self.tables.version_table
+
+ mapper(Foo, version_table,
+ version_id_col=version_table.c.version_id,
+ version_id_generator=False,
+ )
+
+ s1 = Session(expire_on_commit=expire_on_commit)
+ return s1
+
+ def test_insert_col(self):
+ sess = self._fixture()
+
+ f1 = self.classes.Foo(value='f1')
+ sess.add(f1)
+
+ statements = [
+ # note that the assertsql tests the rule against
+ # "default" - on a "returning" backend, the statement
+ # includes "RETURNING"
+ CompiledSQL(
+ "INSERT INTO version_table (version_id, value) "
+ "VALUES (1, :value)",
+ lambda ctx: [{'value': 'f1'}]
+ )
+ ]
+ if not testing.db.dialect.implicit_returning:
+ # DBs without implicit returning, we must immediately
+ # SELECT for the new version id
+ statements.append(
+ CompiledSQL(
+ "SELECT version_table.version_id AS version_table_version_id "
+ "FROM version_table WHERE version_table.id = :param_1",
+ lambda ctx: [{"param_1": 1}]
+ )
+ )
+ self.assert_sql_execution(testing.db, sess.flush, *statements)
+
+ def test_update_col(self):
+ sess = self._fixture()
+
+ f1 = self.classes.Foo(value='f1')
+ sess.add(f1)
+ sess.flush()
+
+ f1.value = 'f2'
+
+ statements = [
+ # note that the assertsql tests the rule against
+ # "default" - on a "returning" backend, the statement
+ # includes "RETURNING"
+ CompiledSQL(
+ "UPDATE version_table SET version_id=2, value=:value "
+ "WHERE version_table.id = :version_table_id AND "
+ "version_table.version_id = :version_table_version_id",
+ lambda ctx: [{"version_table_id": 1,
+ "version_table_version_id": 1, "value": "f2"}]
+ )
+ ]
+ if not testing.db.dialect.implicit_returning:
+ # DBs without implicit returning, we must immediately
+ # SELECT for the new version id
+ statements.append(
+ CompiledSQL(
+ "SELECT version_table.version_id AS version_table_version_id "
+ "FROM version_table WHERE version_table.id = :param_1",
+ lambda ctx: [{"param_1": 1}]
+ )
+ )
+ self.assert_sql_execution(testing.db, sess.flush, *statements)
+
+
+ def test_delete_col(self):
+ sess = self._fixture()
+
+ f1 = self.classes.Foo(value='f1')
+ sess.add(f1)
+ sess.flush()
+
+ sess.delete(f1)
+
+ statements = [
+ # note that the assertsql tests the rule against
+ # "default" - on a "returning" backend, the statement
+ # includes "RETURNING"
+ CompiledSQL(
+ "DELETE FROM version_table "
+ "WHERE version_table.id = :id AND "
+ "version_table.version_id = :version_id",
+ lambda ctx: [{"id": 1, "version_id": 1}]
+ )
+ ]
+ self.assert_sql_execution(testing.db, sess.flush, *statements)
+
+ def test_concurrent_mod_err_expire_on_commit(self):
+ sess = self._fixture()
+
+ f1 = self.classes.Foo(value='f1')
+ sess.add(f1)
+ sess.commit()
+
+ f1.value
+
+ s2 = Session()
+ f2 = s2.query(self.classes.Foo).first()
+ f2.value = 'f2'
+ s2.commit()
+
+ f1.value = 'f3'
+
+ assert_raises_message(
+ orm.exc.StaleDataError,
+ r"UPDATE statement on table 'version_table' expected to "
+ r"update 1 row\(s\); 0 were matched.",
+ sess.commit
+ )
+
+ def test_concurrent_mod_err_noexpire_on_commit(self):
+ sess = self._fixture(expire_on_commit=False)
+
+ f1 = self.classes.Foo(value='f1')
+ sess.add(f1)
+ sess.commit()
+
+ # here, we're not expired overall, so no load occurs and we
+ # stay without a version id, unless we've emitted
+ # a SELECT for it within the flush.
+ f1.value
+
+ s2 = Session(expire_on_commit=False)
+ f2 = s2.query(self.classes.Foo).first()
+ f2.value = 'f2'
+ s2.commit()
+
+ f1.value = 'f3'
+
+ assert_raises_message(
+ orm.exc.StaleDataError,
+ r"UPDATE statement on table 'version_table' expected to "
+ r"update 1 row\(s\); 0 were matched.",
+ sess.commit
+ )
+
+
diff --git a/test/sql/test_returning.py b/test/sql/test_returning.py
index 6a42b0625..179d2d261 100644
--- a/test/sql/test_returning.py
+++ b/test/sql/test_returning.py
@@ -6,6 +6,7 @@ from sqlalchemy.types import TypeDecorator
from sqlalchemy.testing import fixtures, AssertsExecutionResults, engines, \
assert_raises_message
from sqlalchemy import exc as sa_exc
+import itertools
class ReturningTest(fixtures.TestBase, AssertsExecutionResults):
__requires__ = 'returning',
@@ -184,6 +185,127 @@ class KeyReturningTest(fixtures.TestBase, AssertsExecutionResults):
assert row[table.c.foo_id] == row['id'] == 1
+class ReturnDefaultsTest(fixtures.TablesTest):
+ __requires__ = ('returning', )
+ run_define_tables = 'each'
+
+ @classmethod
+ def define_tables(cls, metadata):
+ from sqlalchemy.sql import ColumnElement
+ from sqlalchemy.ext.compiler import compiles
+
+ counter = itertools.count()
+
+ class IncDefault(ColumnElement):
+ pass
+
+ @compiles(IncDefault)
+ def compile(element, compiler, **kw):
+ return str(counter.next())
+
+ Table("t1", metadata,
+ Column("id", Integer, primary_key=True, test_needs_autoincrement=True),
+ Column("data", String(50)),
+ Column("insdef", Integer, default=IncDefault()),
+ Column("upddef", Integer, onupdate=IncDefault())
+ )
+
+ def test_chained_insert_pk(self):
+ t1 = self.tables.t1
+ result = testing.db.execute(
+ t1.insert().values(upddef=1).return_defaults(t1.c.insdef)
+ )
+ eq_(
+ dict(result.returned_defaults),
+ {"id": 1, "insdef": 0}
+ )
+
+ def test_arg_insert_pk(self):
+ t1 = self.tables.t1
+ result = testing.db.execute(
+ t1.insert(return_defaults=[t1.c.insdef]).values(upddef=1)
+ )
+ eq_(
+ dict(result.returned_defaults),
+ {"id": 1, "insdef": 0}
+ )
+
+ def test_chained_update_pk(self):
+ t1 = self.tables.t1
+ testing.db.execute(
+ t1.insert().values(upddef=1)
+ )
+ result = testing.db.execute(t1.update().values(data='d1').
+ return_defaults(t1.c.upddef))
+ eq_(
+ dict(result.returned_defaults),
+ {"upddef": 1}
+ )
+
+ def test_arg_update_pk(self):
+ t1 = self.tables.t1
+ testing.db.execute(
+ t1.insert().values(upddef=1)
+ )
+ result = testing.db.execute(t1.update(return_defaults=[t1.c.upddef]).
+ values(data='d1'))
+ eq_(
+ dict(result.returned_defaults),
+ {"upddef": 1}
+ )
+
+ def test_insert_non_default(self):
+ """test that a column not marked at all as a
+ default works with this feature."""
+
+ t1 = self.tables.t1
+ result = testing.db.execute(
+ t1.insert().values(upddef=1).return_defaults(t1.c.data)
+ )
+ eq_(
+ dict(result.returned_defaults),
+ {"id": 1, "data": None}
+ )
+
+ def test_update_non_default(self):
+ """test that a column not marked at all as a
+ default works with this feature."""
+
+ t1 = self.tables.t1
+ testing.db.execute(
+ t1.insert().values(upddef=1)
+ )
+ result = testing.db.execute(t1.update().
+ values(upddef=2).return_defaults(t1.c.data))
+ eq_(
+ dict(result.returned_defaults),
+ {"data": None}
+ )
+
+ def test_insert_non_default_plus_default(self):
+ t1 = self.tables.t1
+ result = testing.db.execute(
+ t1.insert().values(upddef=1).return_defaults(
+ t1.c.data, t1.c.insdef)
+ )
+ eq_(
+ dict(result.returned_defaults),
+ {"id": 1, "data": None, "insdef": 0}
+ )
+
+ def test_update_non_default_plus_default(self):
+ t1 = self.tables.t1
+ testing.db.execute(
+ t1.insert().values(upddef=1)
+ )
+ result = testing.db.execute(t1.update().
+ values(insdef=2).return_defaults(
+ t1.c.data, t1.c.upddef))
+ eq_(
+ dict(result.returned_defaults),
+ {"data": None, 'upddef': 1}
+ )
+
class ImplicitReturningFlag(fixtures.TestBase):
def test_flag_turned_off(self):
e = engines.testing_engine(options={'implicit_returning':False})