diff options
Diffstat (limited to 'lib/sqlalchemy/sql/dml.py')
-rw-r--r-- | lib/sqlalchemy/sql/dml.py | 32 |
1 files changed, 26 insertions, 6 deletions
diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py index 955eb4109..2ed3be9cb 100644 --- a/lib/sqlalchemy/sql/dml.py +++ b/lib/sqlalchemy/sql/dml.py @@ -119,6 +119,8 @@ class DMLState(CompileState): _ordered_values: Optional[List[Tuple[_DMLColumnElement, Any]]] = None _parameter_ordering: Optional[List[_DMLColumnElement]] = None _has_multi_parameters = False + _primary_table: FromClause + _supports_implicit_returning = True isupdate = False isdelete = False @@ -182,11 +184,14 @@ class DMLState(CompileState): for k, v in kv_iterator ] - def _make_extra_froms(self, statement: DMLWhereBase) -> List[FromClause]: + def _make_extra_froms( + self, statement: DMLWhereBase + ) -> Tuple[FromClause, List[FromClause]]: froms: List[FromClause] = [] all_tables = list(sql_util.tables_from_leftmost(statement.table)) - seen = {all_tables[0]} + primary_table = all_tables[0] + seen = {primary_table} for crit in statement._where_criteria: for item in _from_objects(crit): @@ -195,7 +200,7 @@ class DMLState(CompileState): seen.update(item._cloned_set) froms.extend(all_tables[1:]) - return froms + return primary_table, froms def _process_multi_values(self, statement: ValuesBase) -> None: if not statement._supports_multi_parameters: @@ -286,8 +291,18 @@ class InsertDMLState(DMLState): include_table_with_column_exprs = False - def __init__(self, statement: Insert, compiler: SQLCompiler, **kw: Any): + def __init__( + self, + statement: Insert, + compiler: SQLCompiler, + disable_implicit_returning: bool = False, + **kw: Any, + ): self.statement = statement + self._primary_table = statement.table + + if disable_implicit_returning: + self._supports_implicit_returning = False self.isinsert = True if statement._select_names: @@ -306,6 +321,7 @@ class UpdateDMLState(DMLState): def __init__(self, statement: Update, compiler: SQLCompiler, **kw: Any): self.statement = statement + self.isupdate = True if statement._ordered_values is not None: self._process_ordered_values(statement) @@ -313,7 +329,9 @@ class UpdateDMLState(DMLState): self._process_values(statement) elif statement._multi_values: self._process_multi_values(statement) - self._extra_froms = ef = self._make_extra_froms(statement) + t, ef = self._make_extra_froms(statement) + self._primary_table = t + self._extra_froms = ef self.is_multitable = mt = ef @@ -330,7 +348,9 @@ class DeleteDMLState(DMLState): self.statement = statement self.isdelete = True - self._extra_froms = self._make_extra_froms(statement) + t, ef = self._make_extra_froms(statement) + self._primary_table = t + self._extra_froms = ef SelfUpdateBase = typing.TypeVar("SelfUpdateBase", bound="UpdateBase") |