diff options
-rw-r--r-- | doc/build/changelog/migration_14.rst | 53 | ||||
-rw-r--r-- | doc/build/changelog/unreleased_14/1653.rst | 8 | ||||
-rw-r--r-- | doc/build/changelog/unreleased_14/orm_update_delete.rst | 18 | ||||
-rw-r--r-- | examples/dogpile_caching/caching_query.py | 5 | ||||
-rw-r--r-- | lib/sqlalchemy/dialects/mssql/base.py | 12 | ||||
-rw-r--r-- | lib/sqlalchemy/dialects/postgresql/base.py | 11 | ||||
-rw-r--r-- | lib/sqlalchemy/engine/default.py | 1 | ||||
-rw-r--r-- | lib/sqlalchemy/engine/result.py | 4 | ||||
-rw-r--r-- | lib/sqlalchemy/ext/horizontal_shard.py | 1 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/context.py | 41 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/evaluator.py | 28 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/mapper.py | 19 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/persistence.py | 250 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/session.py | 72 | ||||
-rw-r--r-- | lib/sqlalchemy/testing/requirements.py | 22 | ||||
-rw-r--r-- | test/aaa_profiling/test_orm.py | 3 | ||||
-rw-r--r-- | test/engine/test_reflection.py | 1 | ||||
-rw-r--r-- | test/ext/test_horizontal_shard.py | 62 | ||||
-rw-r--r-- | test/orm/test_evaluator.py | 50 | ||||
-rw-r--r-- | test/orm/test_update_delete.py | 128 | ||||
-rw-r--r-- | test/sql/test_returning.py | 33 |
21 files changed, 698 insertions, 124 deletions
diff --git a/doc/build/changelog/migration_14.rst b/doc/build/changelog/migration_14.rst index 6702f86f0..97b94087f 100644 --- a/doc/build/changelog/migration_14.rst +++ b/doc/build/changelog/migration_14.rst @@ -349,7 +349,7 @@ New Result object ----------------- The ``ResultProxy`` object has been replaced with the 2.0 -style -:class:`.Result` object discussed at :ref:`change_result_20_core`. This result object +:class:`_result.Result` object discussed at :ref:`change_result_20_core`. This result object is fully compatible with ``ResultProxy`` and includes many new features, that are now applied to both Core and ORM results equally, including methods such as: @@ -366,7 +366,7 @@ such as: When using Core, the object returned is an instance of :class:`.CursorResult`, which continues to feature the same API features as ``ResultProxy`` regarding -inserted primary keys, defaults, rowcounts, etc. For ORM, a :class:`.Result` +inserted primary keys, defaults, rowcounts, etc. For ORM, a :class:`_result.Result` subclass will be returned that performs translation of Core rows into ORM rows, and then allows all the same operations to take place. @@ -594,6 +594,55 @@ as was present previously. :ticket:`4826` +.. 
_change_orm_update_returning_14: + +ORM Bulk Update and Delete use RETURNING for "fetch" strategy when available +---------------------------------------------------------------------------- + +An ORM bulk update or delete that uses the "fetch" strategy:: + + sess.query(User).filter(User.age > 29).update( + {"age": User.age - 10}, synchronize_session="fetch" + ) + +Will now use RETURNING if the backend database supports it; this currently +includes PostgreSQL and SQL Server (the Oracle dialect does not support RETURNING +of multiple rows):: + + UPDATE users SET age_int=(users.age_int - %(age_int_1)s) WHERE users.age_int > %(age_int_2)s RETURNING users.id + [generated in 0.00060s] {'age_int_1': 10, 'age_int_2': 29} + Col ('id',) + Row (2,) + Row (4,) + +For backends that do not support RETURNING of multiple rows, the previous approach +of emitting SELECT for the primary keys beforehand is still used:: + + SELECT users.id FROM users WHERE users.age_int > %(age_int_1)s + [generated in 0.00043s] {'age_int_1': 29} + Col ('id',) + Row (2,) + Row (4,) + UPDATE users SET age_int=(users.age_int - %(age_int_1)s) WHERE users.age_int > %(age_int_2)s + [generated in 0.00102s] {'age_int_1': 10, 'age_int_2': 29} + +One of the intricate challenges of this change is to support cases such as the +horizontal sharding extension, where a single bulk update or delete may be +multiplexed among backends some of which support RETURNING and some don't. The +new 1.4 execution architecture supports this case so that the "fetch" strategy +can be left intact with a graceful degrade to using a SELECT, rather than having +to add a new "returning" strategy that would not be backend-agnostic. 
+ +As part of this change, the "fetch" strategy is also made much more efficient +in that it will no longer expire the objects located which match the rows, +for Python expressions used in the SET clause which can be evaluated in +Python; these are instead assigned +directly onto the object in the same way as the "evaluate" strategy. Only +for SQL expressions that can't be evaluated does it fall back to expiring +the attributes. The "evaluate" strategy has also been enhanced to fall back +to "expire" for a value that cannot be evaluated. + + Behavioral Changes - ORM ======================== diff --git a/doc/build/changelog/unreleased_14/1653.rst b/doc/build/changelog/unreleased_14/1653.rst new file mode 100644 index 000000000..e35121216 --- /dev/null +++ b/doc/build/changelog/unreleased_14/1653.rst @@ -0,0 +1,8 @@ +.. change:: + :tags: usecase, orm + :tickets: 1653 + + The evaluator that takes place within the ORM bulk update and delete for + synchronize_session="evaluate" now supports the IN and NOT IN operators. + Tuple IN is also supported. + diff --git a/doc/build/changelog/unreleased_14/orm_update_delete.rst b/doc/build/changelog/unreleased_14/orm_update_delete.rst new file mode 100644 index 000000000..e16ab62ca --- /dev/null +++ b/doc/build/changelog/unreleased_14/orm_update_delete.rst @@ -0,0 +1,18 @@ +.. change:: + :tags: orm, performance + + The bulk update and delete methods :meth:`.Query.update` and + :meth:`.Query.delete`, as well as their 2.0-style counterparts, now make + use of RETURNING when the "fetch" strategy is used in order to fetch the + list of affected primary key identities, rather than emitting a separate + SELECT, when the backend in use supports RETURNING. Additionally, the + "fetch" strategy will in ordinary cases not expire the attributes that have + been updated, and will instead apply the updated values directly in the + same way that the "evaluate" strategy does, to avoid having to refresh the + object. 
The "evaluate" strategy will also fall back to expiring + attributes that were updated to a SQL expression that was unevaluable in + Python. + + .. seealso:: + + :ref:`change_orm_update_returning_14`
\ No newline at end of file diff --git a/examples/dogpile_caching/caching_query.py b/examples/dogpile_caching/caching_query.py index 54f712a11..f99447361 100644 --- a/examples/dogpile_caching/caching_query.py +++ b/examples/dogpile_caching/caching_query.py @@ -130,11 +130,14 @@ class FromCache(UserDefinedOption): self.expiration_time = expiration_time self.ignore_expiration = ignore_expiration + def _gen_cache_key(self, anon_map, bindparams): + return None + def _generate_cache_key(self, statement, parameters, orm_cache): statement_cache_key = statement._generate_cache_key() key = statement_cache_key.to_offline_string( - orm_cache._statement_cache, parameters + orm_cache._statement_cache, statement, parameters ) + repr(self.cache_key) # print("here's our key...%s" % key) diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index 5aaecf23a..4b211bde7 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -2402,6 +2402,9 @@ class MSDialect(default.DefaultDialect): max_identifier_length = 128 schema_name = "dbo" + implicit_returning = True + full_returning = True + colspecs = { sqltypes.DateTime: _MSDateTime, sqltypes.Date: _MSDate, @@ -2567,11 +2570,10 @@ class MSDialect(default.DefaultDialect): "features may not function properly." 
% ".".join(str(x) for x in self.server_version_info) ) - if ( - self.server_version_info >= MS_2005_VERSION - and "implicit_returning" not in self.__dict__ - ): - self.implicit_returning = True + + if self.server_version_info < MS_2005_VERSION: + self.implicit_returning = self.full_returning = False + if self.server_version_info >= MS_2008_VERSION: self.supports_multivalues_insert = True if self.deprecate_large_types is None: diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index f3e775354..c2d9af4d2 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -2510,6 +2510,9 @@ class PGDialect(default.DefaultDialect): inspector = PGInspector isolation_level = None + implicit_returning = True + full_returning = True + construct_arguments = [ ( schema.Index, @@ -2555,10 +2558,10 @@ class PGDialect(default.DefaultDialect): def initialize(self, connection): super(PGDialect, self).initialize(connection) - self.implicit_returning = self.server_version_info > ( - 8, - 2, - ) and self.__dict__.get("implicit_returning", True) + + if self.server_version_info <= (8, 2): + self.full_returning = self.implicit_returning = False + self.supports_native_enum = self.server_version_info >= (8, 3) if not self.supports_native_enum: self.colspecs = self.colspecs.copy() diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index 4d516e97c..1a8dbb4cd 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -67,6 +67,7 @@ class DefaultDialect(interfaces.Dialect): preexecute_autoincrement_sequences = False postfetch_lastrowid = True implicit_returning = False + full_returning = False cte_follows_insert = False diff --git a/lib/sqlalchemy/engine/result.py b/lib/sqlalchemy/engine/result.py index b29bc22d4..ead52a3f8 100644 --- a/lib/sqlalchemy/engine/result.py +++ b/lib/sqlalchemy/engine/result.py @@ -1259,6 +1259,10 @@ class 
IteratorResult(Result): return list(itertools.islice(self.iterator, 0, size)) +def null_result(): + return IteratorResult(SimpleResultMetaData([]), iter([])) + + class ChunkedIteratorResult(IteratorResult): """An :class:`.IteratorResult` that works from an iterator-producing callable. diff --git a/lib/sqlalchemy/ext/horizontal_shard.py b/lib/sqlalchemy/ext/horizontal_shard.py index 0983807cb..9d7266d1a 100644 --- a/lib/sqlalchemy/ext/horizontal_shard.py +++ b/lib/sqlalchemy/ext/horizontal_shard.py @@ -220,7 +220,6 @@ def execute_and_instances(orm_context): update_options = active_options = orm_context.update_delete_options session = orm_context.session - # orm_query = orm_context.orm_query def iter_for_shard(shard_id, load_options, update_options): execution_options = dict(orm_context.local_execution_options) diff --git a/lib/sqlalchemy/orm/context.py b/lib/sqlalchemy/orm/context.py index f380229e1..77237f089 100644 --- a/lib/sqlalchemy/orm/context.py +++ b/lib/sqlalchemy/orm/context.py @@ -193,8 +193,17 @@ class ORMCompileState(CompileState): @classmethod def orm_pre_session_exec( - cls, session, statement, params, execution_options, bind_arguments + cls, + session, + statement, + params, + execution_options, + bind_arguments, + is_reentrant_invoke, ): + if is_reentrant_invoke: + return statement, execution_options + load_options = execution_options.get( "_sa_orm_load_options", QueryContext.default_load_options ) @@ -220,7 +229,7 @@ class ORMCompileState(CompileState): if load_options._autoflush: session._autoflush() - return execution_options + return statement, execution_options @classmethod def orm_setup_cursor_result( @@ -2259,9 +2268,20 @@ class _ColumnEntity(_QueryEntity): ) if _entity: - _ORMColumnEntity( - compile_state, column, _entity, parent_bundle=parent_bundle - ) + if "identity_token" in column._annotations: + _IdentityTokenEntity( + compile_state, + column, + _entity, + parent_bundle=parent_bundle, + ) + else: + _ORMColumnEntity( + compile_state, + 
column, + _entity, + parent_bundle=parent_bundle, + ) else: _RawColumnEntity( compile_state, column, parent_bundle=parent_bundle @@ -2462,3 +2482,14 @@ class _ORMColumnEntity(_ColumnEntity): compile_state.primary_columns.append(column) self._fetch_column = column + + +class _IdentityTokenEntity(_ORMColumnEntity): + def setup_compile_state(self, compile_state): + pass + + def row_processor(self, context, result): + def getter(row): + return context.load_options._refresh_identity_token + + return getter, self._label_name, self._extra_entities diff --git a/lib/sqlalchemy/orm/evaluator.py b/lib/sqlalchemy/orm/evaluator.py index 51bc8e426..caa9ffe10 100644 --- a/lib/sqlalchemy/orm/evaluator.py +++ b/lib/sqlalchemy/orm/evaluator.py @@ -35,6 +35,10 @@ _straight_ops = set( ) ) +_extended_ops = { + operators.in_op: (lambda a, b: a in b), + operators.notin_op: (lambda a, b: a not in b), +} _notimplemented_ops = set( getattr(operators, op) @@ -43,9 +47,8 @@ _notimplemented_ops = set( "notlike_op", "ilike_op", "notilike_op", + "startswith_op", "between_op", - "in_op", - "notin_op", "endswith_op", "concat_op", ) @@ -136,6 +139,17 @@ class EvaluatorCompiler(object): return False return True + elif clause.operator is operators.comma_op: + + def evaluate(obj): + values = [] + for sub_evaluate in evaluators: + value = sub_evaluate(obj) + if value is None: + return None + values.append(value) + return tuple(values) + else: raise UnevaluatableError( "Cannot evaluate clauselist with operator %s" % clause.operator @@ -158,6 +172,16 @@ class EvaluatorCompiler(object): def evaluate(obj): return eval_left(obj) != eval_right(obj) + elif operator in _extended_ops: + + def evaluate(obj): + left_val = eval_left(obj) + right_val = eval_right(obj) + if left_val is None or right_val is None: + return None + + return _extended_ops[operator](left_val, right_val) + elif operator in _straight_ops: def evaluate(obj): diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 
bec6da74d..ef0e9a49b 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -2240,7 +2240,6 @@ class Mapper( "entity_namespace": self, "parententity": self, "parentmapper": self, - "compile_state_plugin": "orm", } if self.persist_selectable is not self.local_table: # joined table inheritance, with polymorphic selectable, @@ -2250,7 +2249,6 @@ class Mapper( "entity_namespace": self, "parententity": self, "parentmapper": self, - "compile_state_plugin": "orm", } )._set_propagate_attrs( {"compile_state_plugin": "orm", "plugin_subject": self} @@ -2260,6 +2258,23 @@ class Mapper( {"compile_state_plugin": "orm", "plugin_subject": self} ) + @util.memoized_property + def select_identity_token(self): + return ( + expression.null() + ._annotate( + { + "entity_namespace": self, + "parententity": self, + "parentmapper": self, + "identity_token": True, + } + ) + ._set_propagate_attrs( + {"compile_state_plugin": "orm", "plugin_subject": self} + ) + ) + @property def selectable(self): """The :func:`_expression.select` construct this diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index 8393eaf74..bd8efe77f 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -28,6 +28,7 @@ from .. import exc as sa_exc from .. import future from .. import sql from .. 
import util +from ..engine import result as _result from ..future import select as future_select from ..sql import coercions from ..sql import expression @@ -1672,8 +1673,17 @@ class BulkUDCompileState(CompileState): @classmethod def orm_pre_session_exec( - cls, session, statement, params, execution_options, bind_arguments + cls, + session, + statement, + params, + execution_options, + bind_arguments, + is_reentrant_invoke, ): + if is_reentrant_invoke: + return statement, execution_options + sync = execution_options.get("synchronize_session", None) if sync is None: sync = statement._execution_options.get( @@ -1706,6 +1716,17 @@ class BulkUDCompileState(CompileState): if update_options._autoflush: session._autoflush() + statement = statement._annotate( + {"synchronize_session": update_options._synchronize_session} + ) + + # this stage of the execution is called before the do_orm_execute event + # hook. meaning for an extension like horizontal sharding, this step + # happens before the extension splits out into multiple backends and + # runs only once. if we do pre_sync_fetch, we execute a SELECT + # statement, which the horizontal sharding extension splits amongst the + # shards and combines the results together. + if update_options._synchronize_session == "evaluate": update_options = cls._do_pre_synchronize_evaluate( session, @@ -1725,19 +1746,31 @@ class BulkUDCompileState(CompileState): update_options, ) - return util.immutabledict(execution_options).union( - dict(_sa_orm_update_options=update_options) + return ( + statement, + util.immutabledict(execution_options).union( + dict(_sa_orm_update_options=update_options) + ), ) @classmethod def orm_setup_cursor_result( cls, session, statement, execution_options, bind_arguments, result ): + + # this stage of the execution is called after the + # do_orm_execute event hook. 
meaning for an extension like + # horizontal sharding, this step happens *within* the horizontal + # sharding event handler which calls session.execute() re-entrantly + # and will occur for each backend individually. + # the sharding extension then returns its own merged result from the + # individual ones we return here. + update_options = execution_options["_sa_orm_update_options"] if update_options._synchronize_session == "evaluate": - cls._do_post_synchronize_evaluate(session, update_options) + cls._do_post_synchronize_evaluate(session, result, update_options) elif update_options._synchronize_session == "fetch": - cls._do_post_synchronize_fetch(session, update_options) + cls._do_post_synchronize_fetch(session, result, update_options) return result @@ -1767,18 +1800,6 @@ class BulkUDCompileState(CompileState): def eval_condition(obj): return True - # TODO: something more robust for this conditional - if statement.__visit_name__ == "update": - resolved_values = cls._get_resolved_values(mapper, statement) - value_evaluators = {} - resolved_keys_as_propnames = cls._resolved_keys_as_propnames( - mapper, resolved_values - ) - for key, value in resolved_keys_as_propnames: - value_evaluators[key] = evaluator_compiler.process( - coercions.expect(roles.ExpressionElementRole, value) - ) - except evaluator.UnevaluatableError as err: util.raise_( sa_exc.InvalidRequestError( @@ -1789,13 +1810,35 @@ class BulkUDCompileState(CompileState): from_=err, ) - # TODO: detect when the where clause is a trivial primary key match + if statement.__visit_name__ == "update": + resolved_values = cls._get_resolved_values(mapper, statement) + value_evaluators = {} + resolved_keys_as_propnames = cls._resolved_keys_as_propnames( + mapper, resolved_values + ) + for key, value in resolved_keys_as_propnames: + try: + _evaluator = evaluator_compiler.process( + coercions.expect(roles.ExpressionElementRole, value) + ) + except evaluator.UnevaluatableError: + pass + else: + value_evaluators[key] = 
_evaluator + + # TODO: detect when the where clause is a trivial primary key match. matched_objects = [ obj for (cls, pk, identity_token,), obj in session.identity_map.items() if issubclass(cls, target_cls) and eval_condition(obj) - and identity_token == update_options._refresh_identity_token + and ( + update_options._refresh_identity_token is None + # TODO: coverage for the case where horizontal sharding + # invokes an update() or delete() given an explicit identity + # token up front + or identity_token == update_options._refresh_identity_token + ) ] return update_options + { "_matched_objects": matched_objects, "_resolved_keys_as_propnames": resolved_keys_as_propnames, } @@ -1868,29 +1911,56 @@ class BulkUDCompileState(CompileState): ): mapper = update_options._subject_mapper - if mapper: - primary_table = mapper.local_table - else: - primary_table = statement._raw_columns[0] - - # note this creates a Select() *without* the ORM plugin. - # we don't want that here. - select_stmt = future_select(*primary_table.primary_key) + select_stmt = future_select( + *(mapper.primary_key + (mapper.select_identity_token,)) + ) select_stmt._where_criteria = statement._where_criteria - matched_rows = session.execute( - select_stmt, params, execution_options, bind_arguments - ).fetchall() + def skip_for_full_returning(orm_context): + bind = orm_context.session.get_bind(**orm_context.bind_arguments) + if bind.dialect.full_returning: + return _result.null_result() + else: + return None + + result = session.execute( + select_stmt, + params, + execution_options, + bind_arguments, + _add_event=skip_for_full_returning, + ) + matched_rows = result.fetchall() + + value_evaluators = _EMPTY_DICT if statement.__visit_name__ == "update": + target_cls = mapper.class_ + evaluator_compiler = evaluator.EvaluatorCompiler(target_cls) resolved_values = cls._get_resolved_values(mapper, statement) resolved_keys_as_propnames = cls._resolved_keys_as_propnames( mapper, resolved_values ) + + resolved_keys_as_propnames = cls._resolved_keys_as_propnames( + mapper, 
resolved_values + ) + value_evaluators = {} + for key, value in resolved_keys_as_propnames: + try: + _evaluator = evaluator_compiler.process( + coercions.expect(roles.ExpressionElementRole, value) + ) + except evaluator.UnevaluatableError: + pass + else: + value_evaluators[key] = _evaluator + else: resolved_keys_as_propnames = _EMPTY_DICT return update_options + { + "_value_evaluators": value_evaluators, "_matched_rows": matched_rows, "_resolved_keys_as_propnames": resolved_keys_as_propnames, } @@ -1925,15 +1995,23 @@ class BulkORMUpdate(UpdateDMLState, BulkUDCompileState): elif statement._values: new_stmt._values = self._resolved_values + if ( + statement._annotations.get("synchronize_session", None) == "fetch" + and compiler.dialect.full_returning + ): + new_stmt = new_stmt.returning(*mapper.primary_key) + UpdateDMLState.__init__(self, new_stmt, compiler, **kw) return self @classmethod - def _do_post_synchronize_evaluate(cls, session, update_options): + def _do_post_synchronize_evaluate(cls, session, result, update_options): states = set() evaluated_keys = list(update_options._value_evaluators.keys()) + values = update_options._resolved_keys_as_propnames + attrib = set(k for k, v in values) for obj in update_options._matched_objects: state, dict_ = ( @@ -1941,9 +2019,15 @@ class BulkORMUpdate(UpdateDMLState, BulkUDCompileState): attributes.instance_dict(obj), ) - assert ( - state.identity_token == update_options._refresh_identity_token - ) + # the evaluated states were gathered across all identity tokens. + # however the post_sync events are called per identity token, + # so filter. 
+ if ( + update_options._refresh_identity_token is not None + and state.identity_token + != update_options._refresh_identity_token + ): + continue # only evaluate unmodified attributes to_evaluate = state.unmodified.intersection(evaluated_keys) @@ -1954,38 +2038,64 @@ class BulkORMUpdate(UpdateDMLState, BulkUDCompileState): state._commit(dict_, list(to_evaluate)) - # expire attributes with pending changes - # (there was no autoflush, so they are overwritten) - state._expire_attributes( - dict_, set(evaluated_keys).difference(to_evaluate) - ) + to_expire = attrib.intersection(dict_).difference(to_evaluate) + if to_expire: + state._expire_attributes(dict_, to_expire) + states.add(state) session._register_altered(states) @classmethod - def _do_post_synchronize_fetch(cls, session, update_options): + def _do_post_synchronize_fetch(cls, session, result, update_options): target_mapper = update_options._subject_mapper - states = set( - [ - attributes.instance_state(session.identity_map[identity_key]) - for identity_key in [ - target_mapper.identity_key_from_primary_key( - list(primary_key), - identity_token=update_options._refresh_identity_token, - ) - for primary_key in update_options._matched_rows + states = set() + evaluated_keys = list(update_options._value_evaluators.keys()) + + if result.returns_rows: + matched_rows = [ + tuple(row) + (update_options._refresh_identity_token,) + for row in result.all() + ] + else: + matched_rows = update_options._matched_rows + + objs = [ + session.identity_map[identity_key] + for identity_key in [ + target_mapper.identity_key_from_primary_key( + list(primary_key), identity_token=identity_token, + ) + for primary_key, identity_token in [ + (row[0:-1], row[-1]) for row in matched_rows ] - if identity_key in session.identity_map + if update_options._refresh_identity_token is None + or identity_token == update_options._refresh_identity_token ] - ) + if identity_key in session.identity_map + ] values = 
update_options._resolved_keys_as_propnames attrib = set(k for k, v in values) - for state in states: - to_expire = attrib.intersection(state.dict) + + for obj in objs: + state, dict_ = ( + attributes.instance_state(obj), + attributes.instance_dict(obj), + ) + + to_evaluate = state.unmodified.intersection(evaluated_keys) + for key in to_evaluate: + dict_[key] = update_options._value_evaluators[key](obj) + state.manager.dispatch.refresh(state, None, to_evaluate) + + state._commit(dict_, list(to_evaluate)) + + to_expire = attrib.intersection(dict_).difference(to_evaluate) if to_expire: - session._expire_state(state, to_expire) + state._expire_attributes(dict_, to_expire) + + states.add(state) session._register_altered(states) @@ -1995,14 +2105,24 @@ class BulkORMDelete(DeleteDMLState, BulkUDCompileState): def create_for_statement(cls, statement, compiler, **kw): self = cls.__new__(cls) - self.mapper = statement.table._annotations.get("parentmapper", None) + self.mapper = mapper = statement.table._annotations.get( + "parentmapper", None + ) + + if ( + mapper + and statement._annotations.get("synchronize_session", None) + == "fetch" + and compiler.dialect.full_returning + ): + statement = statement.returning(*mapper.primary_key) DeleteDMLState.__init__(self, statement, compiler, **kw) return self @classmethod - def _do_post_synchronize_evaluate(cls, session, update_options): + def _do_post_synchronize_evaluate(cls, session, result, update_options): session._remove_newly_deleted( [ @@ -2012,15 +2132,25 @@ class BulkORMDelete(DeleteDMLState, BulkUDCompileState): ) @classmethod - def _do_post_synchronize_fetch(cls, session, update_options): + def _do_post_synchronize_fetch(cls, session, result, update_options): target_mapper = update_options._subject_mapper - for primary_key in update_options._matched_rows: + if result.returns_rows: + matched_rows = [ + tuple(row) + (update_options._refresh_identity_token,) + for row in result.all() + ] + else: + matched_rows = 
update_options._matched_rows + + for row in matched_rows: + primary_key = row[0:-1] + identity_token = row[-1] + # TODO: inline this and call remove_newly_deleted # once identity_key = target_mapper.identity_key_from_primary_key( - list(primary_key), - identity_token=update_options._refresh_identity_token, + list(primary_key), identity_token=identity_token, ) if identity_key in session.identity_map: session._remove_newly_deleted( diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index 5ad8bcf2f..a398da793 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -116,6 +116,8 @@ class ORMExecuteState(util.MemoizedSlots): "_merged_execution_options", "bind_arguments", "_compile_state_cls", + "_starting_event_idx", + "_events_todo", ) def __init__( @@ -126,6 +128,7 @@ class ORMExecuteState(util.MemoizedSlots): execution_options, bind_arguments, compile_state_cls, + events_todo, ): self.session = session self.statement = statement @@ -133,6 +136,10 @@ class ORMExecuteState(util.MemoizedSlots): self._execution_options = execution_options self.bind_arguments = bind_arguments self._compile_state_cls = compile_state_cls + self._events_todo = list(events_todo) + + def _remaining_events(self): + return self._events_todo[self._starting_event_idx + 1 :] def invoke_statement( self, @@ -200,7 +207,11 @@ class ORMExecuteState(util.MemoizedSlots): _execution_options = self._execution_options return self.session.execute( - statement, _params, _execution_options, _bind_arguments + statement, + _params, + _execution_options, + _bind_arguments, + _parent_execute_state=self, ) @property @@ -1376,6 +1387,8 @@ class Session(_SessionClassMethods): params=None, execution_options=util.immutabledict(), bind_arguments=None, + _parent_execute_state=None, + _add_event=None, **kw ): r"""Execute a SQL expression construct or string statement within @@ -1521,8 +1534,16 @@ class Session(_SessionClassMethods): compile_state_cls = None if 
compile_state_cls is not None: - execution_options = compile_state_cls.orm_pre_session_exec( - self, statement, params, execution_options, bind_arguments + ( + statement, + execution_options, + ) = compile_state_cls.orm_pre_session_exec( + self, + statement, + params, + execution_options, + bind_arguments, + _parent_execute_state is not None, ) else: bind_arguments.setdefault("clause", statement) @@ -1531,22 +1552,28 @@ class Session(_SessionClassMethods): execution_options, {"future_result": True} ) - if self.dispatch.do_orm_execute: - # run this event whether or not we are in ORM mode - skip_events = bind_arguments.get("_sa_skip_events", False) - if not skip_events: - orm_exec_state = ORMExecuteState( - self, - statement, - params, - execution_options, - bind_arguments, - compile_state_cls, - ) - for fn in self.dispatch.do_orm_execute: - result = fn(orm_exec_state) - if result: - return result + if _parent_execute_state: + events_todo = _parent_execute_state._remaining_events() + else: + events_todo = self.dispatch.do_orm_execute + if _add_event: + events_todo = list(events_todo) + [_add_event] + + if events_todo: + orm_exec_state = ORMExecuteState( + self, + statement, + params, + execution_options, + bind_arguments, + compile_state_cls, + events_todo, + ) + for idx, fn in enumerate(events_todo): + orm_exec_state._starting_event_idx = idx + result = fn(orm_exec_state) + if result: + return result bind = self.get_bind(**bind_arguments) @@ -1729,7 +1756,12 @@ class Session(_SessionClassMethods): self._add_bind(table, bind) def get_bind( - self, mapper=None, clause=None, bind=None, _sa_skip_events=None + self, + mapper=None, + clause=None, + bind=None, + _sa_skip_events=None, + _sa_skip_for_implicit_returning=False, ): """Return a "bind" to which this :class:`.Session` is bound. 
diff --git a/lib/sqlalchemy/testing/requirements.py b/lib/sqlalchemy/testing/requirements.py index 2d51e7c9b..163276ca9 100644 --- a/lib/sqlalchemy/testing/requirements.py +++ b/lib/sqlalchemy/testing/requirements.py @@ -312,12 +312,30 @@ class SuiteRequirements(Requirements): return exclusions.open() @property + def full_returning(self): + """target platform supports RETURNING completely, including + multiple rows returned. + + """ + + return exclusions.only_if( + lambda config: config.db.dialect.full_returning, + "%(database)s %(does_support)s 'RETURNING of multiple rows'", + ) + + @property def returning(self): - """target platform supports RETURNING.""" + """target platform supports RETURNING for at least one row. + + .. seealso:: + + :attr:`.Requirements.full_returning` + + """ return exclusions.only_if( lambda config: config.db.dialect.implicit_returning, - "%(database)s %(does_support)s 'returning'", + "%(database)s %(does_support)s 'RETURNING of a single row'", ) @property diff --git a/test/aaa_profiling/test_orm.py b/test/aaa_profiling/test_orm.py index 8f06220e2..13e92f5c4 100644 --- a/test/aaa_profiling/test_orm.py +++ b/test/aaa_profiling/test_orm.py @@ -878,6 +878,7 @@ class JoinedEagerLoadTest(NoCache, fixtures.MappedTest): {}, exec_opts, bind_arguments, + is_reentrant_invoke=False, ) r = sess.connection().execute( @@ -888,7 +889,7 @@ class JoinedEagerLoadTest(NoCache, fixtures.MappedTest): r.context.compiled.compile_state = compile_state obj = ORMCompileState.orm_setup_cursor_result( - sess, compile_state.statement, exec_opts, {}, r + sess, compile_state.statement, exec_opts, {}, r, ) list(obj) sess.close() diff --git a/test/engine/test_reflection.py b/test/engine/test_reflection.py index 0fea029fe..f1b54cb8f 100644 --- a/test/engine/test_reflection.py +++ b/test/engine/test_reflection.py @@ -1197,7 +1197,6 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): @testing.provide_metadata def test_reflect_all(self): existing = 
inspect(testing.db).get_table_names() - names = ["rt_%s" % name for name in ("a", "b", "c", "d", "e")] nameset = set(names) for name in names: diff --git a/test/ext/test_horizontal_shard.py b/test/ext/test_horizontal_shard.py index c0029fbb6..9855cd5ab 100644 --- a/test/ext/test_horizontal_shard.py +++ b/test/ext/test_horizontal_shard.py @@ -35,8 +35,6 @@ from sqlalchemy.testing import provision from sqlalchemy.testing.engines import testing_engine from sqlalchemy.testing.engines import testing_reaper -# TODO: ShardTest can be turned into a base for further subclasses - class ShardTest(object): __skip_if__ = (lambda: util.win32,) @@ -47,9 +45,9 @@ class ShardTest(object): def setUp(self): global db1, db2, db3, db4, weather_locations, weather_reports - db1, db2, db3, db4 = self._init_dbs() + db1, db2, db3, db4 = self._dbs = self._init_dbs() - meta = MetaData() + meta = self.metadata = MetaData() ids = Table("ids", meta, Column("nextid", Integer, nullable=False)) def id_generator(ctx): @@ -578,9 +576,11 @@ class ShardTest(object): temps = sess.execute(future_select(Report)).scalars().all() eq_(set(t.temperature for t in temps), {80.0, 75.0, 85.0}) + # MARKMARK + # omitting the criteria so that the UPDATE affects three out of + # four shards sess.execute( update(Report) - .filter(Report.temperature >= 80) .values({"temperature": Report.temperature + 6},) .execution_options(synchronize_session="fetch") ) @@ -590,11 +590,11 @@ class ShardTest(object): row.temperature for row in sess.execute(future_select(Report.temperature)) ), - {86.0, 75.0, 91.0}, + {86.0, 81.0, 91.0}, ) # test synchronize session as well - eq_(set(t.temperature for t in temps), {86.0, 75.0, 91.0}) + eq_(set(t.temperature for t in temps), {86.0, 81.0, 91.0}) def test_bulk_delete_future_synchronize_evaluate(self): sess = self._fixture_data() @@ -711,9 +711,8 @@ class TableNameConventionShardTest(ShardTest, fixtures.TestBase): This used to be called "AttachedFileShardTest" but I didn't see any ATTACH 
going on. - The approach taken by this test is awkward and I wouldn't recommend using - this pattern in a real situation. I'm not sure of the history of this test - but it likely predates when we knew how to use real ATTACH in SQLite. + A more modern approach here would be to use the schema_translate_map + option. """ @@ -742,6 +741,49 @@ class TableNameConventionShardTest(ShardTest, fixtures.TestBase): return db1, db2, db3, db4 +class MultipleDialectShardTest(ShardTest, fixtures.TestBase): + __only_on__ = "postgresql" + + schema = "changeme" + + def _init_dbs(self): + e1 = testing_engine("sqlite://") + with e1.connect() as conn: + for i in [1, 3]: + conn.exec_driver_sql( + 'ATTACH DATABASE "shard%s_%s.db" AS shard%s' + % (i, provision.FOLLOWER_IDENT, i) + ) + + e2 = testing_engine() + with e2.connect() as conn: + for i in [2, 4]: + conn.exec_driver_sql( + "CREATE SCHEMA IF NOT EXISTS shard%s" % (i,) + ) + + db1 = e1.execution_options(schema_translate_map={"changeme": "shard1"}) + db2 = e2.execution_options(schema_translate_map={"changeme": "shard2"}) + db3 = e1.execution_options(schema_translate_map={"changeme": "shard3"}) + db4 = e2.execution_options(schema_translate_map={"changeme": "shard4"}) + + self.sqlite_engine = e1 + self.postgresql_engine = e2 + return db1, db2, db3, db4 + + def teardown(self): + clear_mappers() + + self.sqlite_engine.connect().invalidate() + for i in [1, 3]: + os.remove("shard%d_%s.db" % (i, provision.FOLLOWER_IDENT)) + + with self.postgresql_engine.connect() as conn: + self.metadata.drop_all(conn) + for i in [2, 4]: + conn.exec_driver_sql("DROP SCHEMA shard%s CASCADE" % (i,)) + + class SelectinloadRegressionTest(fixtures.DeclarativeMappedTest): """test #4175 """ diff --git a/test/orm/test_evaluator.py b/test/orm/test_evaluator.py index 5bc054486..20577d8e6 100644 --- a/test/orm/test_evaluator.py +++ b/test/orm/test_evaluator.py @@ -8,6 +8,7 @@ from sqlalchemy import Integer from sqlalchemy import not_ from sqlalchemy import or_ from 
sqlalchemy import String +from sqlalchemy import tuple_ from sqlalchemy.orm import evaluator from sqlalchemy.orm import mapper from sqlalchemy.orm import relationship @@ -19,7 +20,6 @@ from sqlalchemy.testing import is_ from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table - compiler = evaluator.EvaluatorCompiler() @@ -191,6 +191,54 @@ class EvaluateTest(fixtures.MappedTest): ], ) + def test_in(self): + User = self.classes.User + + eval_eq( + User.name.in_(["foo", "bar"]), + testcases=[ + (User(id=1, name="foo"), True), + (User(id=2, name="bat"), False), + (User(id=1, name="bar"), True), + (User(id=1, name=None), None), + ], + ) + + eval_eq( + User.name.notin_(["foo", "bar"]), + testcases=[ + (User(id=1, name="foo"), False), + (User(id=2, name="bat"), True), + (User(id=1, name="bar"), False), + (User(id=1, name=None), None), + ], + ) + + def test_in_tuples(self): + User = self.classes.User + + eval_eq( + tuple_(User.id, User.name).in_([(1, "foo"), (2, "bar")]), + testcases=[ + (User(id=1, name="foo"), True), + (User(id=2, name="bat"), False), + (User(id=1, name="bar"), False), + (User(id=2, name="bar"), True), + (User(id=1, name=None), None), + ], + ) + + eval_eq( + tuple_(User.id, User.name).notin_([(1, "foo"), (2, "bar")]), + testcases=[ + (User(id=1, name="foo"), False), + (User(id=2, name="bat"), True), + (User(id=1, name="bar"), True), + (User(id=2, name="bar"), False), + (User(id=1, name=None), None), + ], + ) + def test_null_propagation(self): User = self.classes.User diff --git a/test/orm/test_update_delete.py b/test/orm/test_update_delete.py index 12a8417ba..310b17047 100644 --- a/test/orm/test_update_delete.py +++ b/test/orm/test_update_delete.py @@ -23,6 +23,9 @@ from sqlalchemy.testing import assert_raises from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures +from sqlalchemy.testing import in_ +from sqlalchemy.testing import not_in_ +from 
sqlalchemy.testing.assertsql import CompiledSQL from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table @@ -144,6 +147,50 @@ class UpdateDeleteTest(fixtures.MappedTest): q.delete, ) + def test_update_w_unevaluatable_value_evaluate(self): + """test that the "evaluate" strategy falls back to 'expire' for an + update SET that is not evaluable in Python.""" + + User = self.classes.User + + s = Session() + + jill = s.query(User).filter(User.name == "jill").one() + + s.execute( + update(User) + .filter(User.name == "jill") + .values({"name": User.name + User.name}), + execution_options={"synchronize_session": "evaluate"}, + ) + + eq_(jill.name, "jilljill") + + def test_update_w_unevaluatable_value_fetch(self): + """test that the "fetch" strategy falls back to 'expire' for an + update SET that is not evaluable in Python. + + Prior to 1.4 the "fetch" strategy used expire for everything + but now it tries to evaluate a SET clause to avoid a round + trip. + + """ + + User = self.classes.User + + s = Session() + + jill = s.query(User).filter(User.name == "jill").one() + + s.execute( + update(User) + .filter(User.name == "jill") + .values({"name": User.name + User.name}), + execution_options={"synchronize_session": "fetch"}, + ) + + eq_(jill.name, "jilljill") + def test_evaluate_clauseelement(self): User = self.classes.User @@ -479,6 +526,87 @@ class UpdateDeleteTest(fixtures.MappedTest): list(zip([25, 37, 29, 27])), ) + def test_update_fetch_returning(self): + User = self.classes.User + + sess = Session() + + john, jack, jill, jane = sess.query(User).order_by(User.id).all() + + with self.sql_execution_asserter() as asserter: + sess.query(User).filter(User.age > 29).update( + {"age": User.age - 10}, synchronize_session="fetch" + ) + + # these are simple values, these are now evaluated even with + # the "fetch" strategy, new in 1.4, so there is no expiry + eq_([john.age, jack.age, jill.age, jane.age], [25, 37, 29, 27]) + + if 
testing.db.dialect.full_returning: + asserter.assert_( + CompiledSQL( + "UPDATE users SET age_int=(users.age_int - %(age_int_1)s) " + "WHERE users.age_int > %(age_int_2)s RETURNING users.id", + [{"age_int_1": 10, "age_int_2": 29}], + dialect="postgresql", + ), + ) + else: + asserter.assert_( + CompiledSQL( + "SELECT users.id FROM users " + "WHERE users.age_int > :age_int_1", + [{"age_int_1": 29}], + ), + CompiledSQL( + "UPDATE users SET age_int=(users.age_int - :age_int_1) " + "WHERE users.age_int > :age_int_2", + [{"age_int_1": 10, "age_int_2": 29}], + ), + ) + + def test_delete_fetch_returning(self): + User = self.classes.User + + sess = Session() + + john, jack, jill, jane = sess.query(User).order_by(User.id).all() + + in_(john, sess) + in_(jack, sess) + + with self.sql_execution_asserter() as asserter: + sess.query(User).filter(User.age > 29).delete( + synchronize_session="fetch" + ) + + if testing.db.dialect.full_returning: + asserter.assert_( + CompiledSQL( + "DELETE FROM users WHERE users.age_int > %(age_int_1)s " + "RETURNING users.id", + [{"age_int_1": 29}], + dialect="postgresql", + ), + ) + else: + asserter.assert_( + CompiledSQL( + "SELECT users.id FROM users " + "WHERE users.age_int > :age_int_1", + [{"age_int_1": 29}], + ), + CompiledSQL( + "DELETE FROM users WHERE users.age_int > :age_int_1", + [{"age_int_1": 29}], + ), + ) + + in_(john, sess) + not_in_(jack, sess) + in_(jill, sess) + not_in_(jane, sess) + def test_update_without_load(self): User = self.classes.User diff --git a/test/sql/test_returning.py b/test/sql/test_returning.py index f856c15a4..90c21ed45 100644 --- a/test/sql/test_returning.py +++ b/test/sql/test_returning.py @@ -129,6 +129,32 @@ class ReturningTest(fixtures.TestBase, AssertsExecutionResults): ) eq_(result2.fetchall(), [(1, True), (2, False)]) + @testing.requires.full_returning + def test_update_full_returning(self, connection): + connection.execute( + table.insert(), + [{"persons": 5, "full": False}, {"persons": 3, "full": 
False}], + ) + + result = connection.execute( + table.update(table.c.persons > 2) + .values(full=True) + .returning(table.c.id, table.c.full) + ) + eq_(result.fetchall(), [(1, True), (2, True)]) + + @testing.requires.full_returning + def test_delete_full_returning(self, connection): + connection.execute( + table.insert(), + [{"persons": 5, "full": False}, {"persons": 3, "full": False}], + ) + + result = connection.execute( + table.delete().returning(table.c.id, table.c.full) + ) + eq_(result.fetchall(), [(1, False), (2, False)]) + def test_insert_returning(self, connection): result = connection.execute( table.insert().returning(table.c.id), {"persons": 1, "full": False} @@ -474,13 +500,6 @@ class ImplicitReturningFlag(fixtures.TestBase): testing.requires.returning(go)() e = engines.testing_engine() - # starts as False. This is because all of Firebird, - # PostgreSQL, Oracle, SQL Server started supporting RETURNING - # as of a certain version, and the flag is not set until - # version detection occurs. If some DB comes along that has - # RETURNING in all cases, this test can be adjusted. - assert e.dialect.implicit_returning is False - # version detection on connect sets it c = e.connect() c.close() |