diff options
Diffstat (limited to 'lib/sqlalchemy/orm/context.py')
-rw-r--r-- | lib/sqlalchemy/orm/context.py | 173 |
1 files changed, 141 insertions, 32 deletions
diff --git a/lib/sqlalchemy/orm/context.py b/lib/sqlalchemy/orm/context.py index dc96f8c3c..f8c7ba714 100644 --- a/lib/sqlalchemy/orm/context.py +++ b/lib/sqlalchemy/orm/context.py @@ -73,6 +73,7 @@ if TYPE_CHECKING: from .query import Query from .session import _BindArguments from .session import Session + from ..engine import Result from ..engine.interfaces import _CoreSingleExecuteParams from ..engine.interfaces import _ExecuteOptionsParameter from ..sql._typing import _ColumnsClauseArgument @@ -203,15 +204,19 @@ _orm_load_exec_options = util.immutabledict( class AbstractORMCompileState(CompileState): + is_dml_returning = False + @classmethod def create_for_statement( cls, statement: Union[Select, FromStatement], compiler: Optional[SQLCompiler], **kw: Any, - ) -> ORMCompileState: + ) -> AbstractORMCompileState: """Create a context for a statement given a :class:`.Compiler`. + This method is always invoked in the context of SQLCompiler.process(). + For a Select object, this would be invoked from SQLCompiler.visit_select(). For the special FromStatement object used by Query to indicate "Query.from_statement()", this is called by @@ -233,6 +238,28 @@ class AbstractORMCompileState(CompileState): raise NotImplementedError() @classmethod + def orm_execute_statement( + cls, + session, + statement, + params, + execution_options, + bind_arguments, + conn, + ) -> Result: + result = conn.execute( + statement, params or {}, execution_options=execution_options + ) + return cls.orm_setup_cursor_result( + session, + statement, + params, + execution_options, + bind_arguments, + result, + ) + + @classmethod def orm_setup_cursor_result( cls, session, @@ -309,6 +336,17 @@ class ORMCompileState(AbstractORMCompileState): def __init__(self, *arg, **kw): raise NotImplementedError() + if TYPE_CHECKING: + + @classmethod + def create_for_statement( + cls, + statement: Union[Select, FromStatement], + compiler: Optional[SQLCompiler], + **kw: Any, + ) -> ORMCompileState: + ... + def _append_dedupe_col_collection(self, obj, col_collection): dedupe = self.dedupe_columns if obj not in dedupe: @@ -333,26 +371,6 @@ class ORMCompileState(AbstractORMCompileState): return SelectState._column_naming_convention(label_style) @classmethod - def create_for_statement( - cls, - statement: Union[Select, FromStatement], - compiler: Optional[SQLCompiler], - **kw: Any, - ) -> ORMCompileState: - """Create a context for a statement given a :class:`.Compiler`. - - This method is always invoked in the context of SQLCompiler.process(). - - For a Select object, this would be invoked from - SQLCompiler.visit_select(). For the special FromStatement object used - by Query to indicate "Query.from_statement()", this is called by - FromStatement._compiler_dispatch() that would be called by - SQLCompiler.process(). - - """ - raise NotImplementedError() - - @classmethod def get_column_descriptions(cls, statement): return _column_descriptions(statement) @@ -518,6 +536,49 @@ class ORMCompileState(AbstractORMCompileState): ) +class DMLReturningColFilter: + """an adapter used for the DML RETURNING case. + + Has a subset of the interface used by + :class:`.ORMAdapter` and is used for :class:`._QueryEntity` + instances to set up their columns as used in RETURNING for a + DML statement. + + """ + + __slots__ = ("mapper", "columns", "__weakref__") + + def __init__(self, target_mapper, immediate_dml_mapper): + if ( + immediate_dml_mapper is not None + and target_mapper.local_table + is not immediate_dml_mapper.local_table + ): + # joined inh, or in theory other kinds of multi-table mappings + self.mapper = immediate_dml_mapper + else: + # single inh, normal mappings, etc. + self.mapper = target_mapper + self.columns = self.columns = util.WeakPopulateDict( + self.adapt_check_present # type: ignore + ) + + def __call__(self, col, as_filter): + for cc in sql_util._find_columns(col): + c2 = self.adapt_check_present(cc) + if c2 is not None: + return col + else: + return None + + def adapt_check_present(self, col): + mapper = self.mapper + prop = mapper._columntoproperty.get(col, None) + if prop is None: + return None + return mapper.local_table.c.corresponding_column(col) + + @sql.base.CompileState.plugin_for("orm", "orm_from_statement") class ORMFromStatementCompileState(ORMCompileState): _from_obj_alias = None @@ -525,7 +586,7 @@ class ORMFromStatementCompileState(ORMCompileState): statement_container: FromStatement requested_statement: Union[SelectBase, TextClause, UpdateBase] - dml_table: _DMLTableElement + dml_table: Optional[_DMLTableElement] = None _has_orm_entities = False multi_row_eager_loaders = False @@ -541,7 +602,7 @@ class ORMFromStatementCompileState(ORMCompileState): statement_container: Union[Select, FromStatement], compiler: Optional[SQLCompiler], **kw: Any, - ) -> ORMCompileState: + ) -> ORMFromStatementCompileState: if compiler is not None: toplevel = not compiler.stack @@ -565,6 +626,7 @@ class ORMFromStatementCompileState(ORMCompileState): if statement.is_dml: self.dml_table = statement.table + self.is_dml_returning = True self._entities = [] self._polymorphic_adapters = {} @@ -674,6 +736,18 @@ class ORMFromStatementCompileState(ORMCompileState): def _get_current_adapter(self): return None + def setup_dml_returning_compile_state(self, dml_mapper): + """used by BulkORMInsert (and Update / Delete?) to set up a handler + for RETURNING to return ORM objects and expressions + + """ + target_mapper = self.statement._propagate_attrs.get( + "plugin_subject", None + ) + adapter = DMLReturningColFilter(target_mapper, dml_mapper) + for entity in self._entities: + entity.setup_dml_returning_compile_state(self, adapter) + class FromStatement(GroupedElement, Generative, TypedReturnsRows[_TP]): """Core construct that represents a load of ORM objects from various @@ -813,7 +887,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState): statement: Union[Select, FromStatement], compiler: Optional[SQLCompiler], **kw: Any, - ) -> ORMCompileState: + ) -> ORMSelectCompileState: """compiler hook, we arrive here from compiler.visit_select() only.""" self = cls.__new__(cls) @@ -2312,6 +2386,13 @@ class _QueryEntity: def setup_compile_state(self, compile_state: ORMCompileState) -> None: raise NotImplementedError() + def setup_dml_returning_compile_state( + self, + compile_state: ORMCompileState, + adapter: DMLReturningColFilter, + ) -> None: + raise NotImplementedError() + def row_processor(self, context, result): raise NotImplementedError() @@ -2509,8 +2590,24 @@ class _MapperEntity(_QueryEntity): return _instance, self._label_name, self._extra_entities - def setup_compile_state(self, compile_state): + def setup_dml_returning_compile_state( + self, + compile_state: ORMCompileState, + adapter: DMLReturningColFilter, + ) -> None: + loading._setup_entity_query( + compile_state, + self.mapper, + self, + self.path, + adapter, + compile_state.primary_columns, + with_polymorphic=self._with_polymorphic_mappers, + only_load_props=compile_state.compile_options._only_load_props, + polymorphic_discriminator=self._polymorphic_discriminator, + ) + def setup_compile_state(self, compile_state): adapter = self._get_entity_clauses(compile_state) single_table_crit = self.mapper._single_table_criterion @@ -2536,7 +2633,6 @@ class _MapperEntity(_QueryEntity): only_load_props=compile_state.compile_options._only_load_props, polymorphic_discriminator=self._polymorphic_discriminator, ) - compile_state._fallback_from_clauses.append(self.selectable) @@ -2743,9 +2839,7 @@ class _ColumnEntity(_QueryEntity): getter, label_name, extra_entities = self._row_processor if self.translate_raw_column: extra_entities += ( - result.context.invoked_statement._raw_columns[ - self.raw_column_index - ], + context.query._raw_columns[self.raw_column_index], ) return getter, label_name, extra_entities @@ -2781,9 +2875,7 @@ class _ColumnEntity(_QueryEntity): if self.translate_raw_column: extra_entities = self._extra_entities + ( - result.context.invoked_statement._raw_columns[ - self.raw_column_index - ], + context.query._raw_columns[self.raw_column_index], ) return getter, self._label_name, extra_entities else: @@ -2843,6 +2935,8 @@ class _RawColumnEntity(_ColumnEntity): current_adapter = compile_state._get_current_adapter() if current_adapter: column = current_adapter(self.column, False) + if column is None: + return else: column = self.column @@ -2944,10 +3038,25 @@ class _ORMColumnEntity(_ColumnEntity): self.entity_zero ) and entity.common_parent(self.entity_zero) + def setup_dml_returning_compile_state( + self, + compile_state: ORMCompileState, + adapter: DMLReturningColFilter, + ) -> None: + self._fetch_column = self.column + column = adapter(self.column, False) + if column is not None: + compile_state.dedupe_columns.add(column) + compile_state.primary_columns.append(column) + def setup_compile_state(self, compile_state): current_adapter = compile_state._get_current_adapter() if current_adapter: column = current_adapter(self.column, False) + if column is None: + assert compile_state.is_dml_returning + self._fetch_column = self.column + return else: column = self.column |