diff options
Diffstat (limited to 'lib/sqlalchemy/orm/persistence.py')
-rw-r--r-- | lib/sqlalchemy/orm/persistence.py | 189 |
1 files changed, 170 insertions, 19 deletions
diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index 7cd66513b..59a0a3d81 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -1805,6 +1805,8 @@ _EMPTY_DICT = util.immutabledict() class BulkUDCompileState(CompileState): class default_update_options(Options): _synchronize_session = "evaluate" + _is_delete_using = False + _is_update_from = False _autoflush = True _subject_mapper = None _resolved_values = _EMPTY_DICT @@ -1815,7 +1817,15 @@ class BulkUDCompileState(CompileState): _refresh_identity_token = None @classmethod - def can_use_returning(cls, dialect: Dialect, mapper: Mapper[Any]) -> bool: + def can_use_returning( + cls, + dialect: Dialect, + mapper: Mapper[Any], + *, + is_multitable: bool = False, + is_update_from: bool = False, + is_delete_using: bool = False, + ) -> bool: raise NotImplementedError() @classmethod @@ -1836,7 +1846,7 @@ class BulkUDCompileState(CompileState): execution_options, ) = BulkUDCompileState.default_update_options.from_execution_options( "_sa_orm_update_options", - {"synchronize_session"}, + {"synchronize_session", "is_delete_using", "is_update_from"}, execution_options, statement._execution_options, ) @@ -1863,7 +1873,11 @@ class BulkUDCompileState(CompileState): session._autoflush() statement = statement._annotate( - {"synchronize_session": update_options._synchronize_session} + { + "synchronize_session": update_options._synchronize_session, + "is_delete_using": update_options._is_delete_using, + "is_update_from": update_options._is_update_from, + } ) # this stage of the execution is called before the do_orm_execute event @@ -1964,6 +1978,56 @@ class BulkUDCompileState(CompileState): return return_crit @classmethod + def _interpret_returning_rows(cls, mapper, rows): + """translate from local inherited table columns to base mapper + primary key columns. + + Joined inheritance mappers always establish the primary key in terms of + the base table. When we UPDATE a sub-table, we can only get + RETURNING for the sub-table's columns. + + Here, we create a lookup from the local sub table's primary key + columns to the base table PK columns so that we can get identity + key values from RETURNING that's against the joined inheritance + sub-table. + + the complexity here is to support more than one level deep of + inheritance, where we have to link columns to each other across + the inheritance hierarchy. + + """ + + if mapper.local_table is not mapper.base_mapper.local_table: + return rows + + # this starts as a mapping of + # local_pk_col: local_pk_col. + # we will then iteratively rewrite the "value" of the dict with + # each successive superclass column + local_pk_to_base_pk = {pk: pk for pk in mapper.local_table.primary_key} + + for mp in mapper.iterate_to_root(): + if mp.inherits is None: + break + elif mp.local_table is mp.inherits.local_table: + continue + + t_to_e = dict(mp._table_to_equated[mp.inherits.local_table]) + col_to_col = {sub_pk: super_pk for super_pk, sub_pk in t_to_e[mp]} + for pk, super_ in local_pk_to_base_pk.items(): + local_pk_to_base_pk[pk] = col_to_col[super_] + + lookup = { + local_pk_to_base_pk[lpk]: idx + for idx, lpk in enumerate(mapper.local_table.primary_key) + } + primary_key_convert = [ + lookup[bpk] for bpk in mapper.base_mapper.primary_key + ] + + return [tuple(row[idx] for idx in primary_key_convert) for row in rows] + + @classmethod def _do_pre_synchronize_evaluate( cls, session, @@ -2111,8 +2175,12 @@ class BulkUDCompileState(CompileState): def skip_for_returning(orm_context: ORMExecuteState) -> Any: bind = orm_context.session.get_bind(**orm_context.bind_arguments) - - if cls.can_use_returning(bind.dialect, mapper): + if cls.can_use_returning( + bind.dialect, + mapper, + is_update_from=update_options._is_update_from, + is_delete_using=update_options._is_delete_using, + ): return _result.null_result() else: return None @@ -2300,25 +2368,60 @@ class BulkORMUpdate(ORMDMLState, UpdateDMLState, BulkUDCompileState): # if we are against a lambda statement we might not be the # topmost object that received per-execute annotations + # do this first as we need to determine if there is + # UPDATE..FROM + + UpdateDMLState.__init__(self, new_stmt, compiler, **kw) + if compiler._annotations.get( "synchronize_session", None - ) == "fetch" and self.can_use_returning(compiler.dialect, mapper): + ) == "fetch" and self.can_use_returning( + compiler.dialect, mapper, is_multitable=self.is_multitable + ): if new_stmt._returning: raise sa_exc.InvalidRequestError( "Can't use synchronize_session='fetch' " "with explicit returning()" ) - new_stmt = new_stmt.returning(*mapper.primary_key) - - UpdateDMLState.__init__(self, new_stmt, compiler, **kw) + self.statement = self.statement.returning( + *mapper.local_table.primary_key + ) return self @classmethod - def can_use_returning(cls, dialect: Dialect, mapper: Mapper[Any]) -> bool: - return ( + def can_use_returning( + cls, + dialect: Dialect, + mapper: Mapper[Any], + *, + is_multitable: bool = False, + is_update_from: bool = False, + is_delete_using: bool = False, + ) -> bool: + + # normal answer for "should we use RETURNING" at all. + normal_answer = ( dialect.update_returning and mapper.local_table.implicit_returning ) + if not normal_answer: + return False + + # these workarounds are currently hypothetical for UPDATE, + # unlike DELETE where they impact MariaDB + if is_update_from: + return dialect.update_returning_multifrom + + elif is_multitable and not dialect.update_returning_multifrom: + raise sa_exc.CompileError( + f'Dialect "{dialect.name}" does not support RETURNING ' + "with UPDATE..FROM; for synchronize_session='fetch', " + "please add the additional execution option " + "'is_update_from=True' to the statement to indicate that " + "a separate SELECT should be used for this backend." + ) + + return True @classmethod def _get_crud_kv_pairs(cls, statement, kv_iterator): @@ -2429,9 +2532,11 @@ class BulkORMUpdate(ORMDMLState, UpdateDMLState, BulkUDCompileState): evaluated_keys = list(update_options._value_evaluators.keys()) if result.returns_rows: + rows = cls._interpret_returning_rows(target_mapper, result.all()) + matched_rows = [ tuple(row) + (update_options._refresh_identity_token,) - for row in result.all() + for row in rows ] else: matched_rows = update_options._matched_rows @@ -2500,20 +2605,64 @@ class BulkORMDelete(ORMDMLState, DeleteDMLState, BulkUDCompileState): if new_crit: statement = statement.where(*new_crit) + # do this first as we need to determine if there is + # DELETE..FROM + DeleteDMLState.__init__(self, statement, compiler, **kw) + if compiler._annotations.get( "synchronize_session", None - ) == "fetch" and self.can_use_returning(compiler.dialect, mapper): - statement = statement.returning(*mapper.primary_key) - - DeleteDMLState.__init__(self, statement, compiler, **kw) + ) == "fetch" and self.can_use_returning( + compiler.dialect, + mapper, + is_multitable=self.is_multitable, + is_delete_using=compiler._annotations.get( + "is_delete_using", False + ), + ): + self.statement = statement.returning(*statement.table.primary_key) return self @classmethod - def can_use_returning(cls, dialect: Dialect, mapper: Mapper[Any]) -> bool: - return ( + def can_use_returning( + cls, + dialect: Dialect, + mapper: Mapper[Any], + *, + is_multitable: bool = False, + is_update_from: bool = False, + is_delete_using: bool = False, + ) -> bool: + + # normal answer for "should we use RETURNING" at all. + normal_answer = ( dialect.delete_returning and mapper.local_table.implicit_returning ) + if not normal_answer: + return False + + # now get into special workarounds because MariaDB supports + # DELETE...RETURNING but not DELETE...USING...RETURNING. + if is_delete_using: + # is_delete_using hint was passed. use + # additional dialect feature (True for PG, False for MariaDB) + return dialect.delete_returning_multifrom + + elif is_multitable and not dialect.delete_returning_multifrom: + # is_delete_using hint was not passed, but we determined + # at compile time that this is in fact a DELETE..USING. + # it's too late to continue since we did not pre-SELECT. + # raise that we need that hint up front. + + raise sa_exc.CompileError( + f'Dialect "{dialect.name}" does not support RETURNING ' + "with DELETE..USING; for synchronize_session='fetch', " + "please add the additional execution option " + "'is_delete_using=True' to the statement to indicate that " + "a separate SELECT should be used for this backend." + ) + + return True @classmethod def _do_post_synchronize_evaluate(cls, session, result, update_options): @@ -2530,9 +2679,11 @@ class BulkORMDelete(ORMDMLState, DeleteDMLState, BulkUDCompileState): target_mapper = update_options._subject_mapper if result.returns_rows: + rows = cls._interpret_returning_rows(target_mapper, result.all()) + matched_rows = [ tuple(row) + (update_options._refresh_identity_token,) - for row in result.all() + for row in rows ] else: matched_rows = update_options._matched_rows |