summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/orm/context.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/orm/context.py')
-rw-r--r--lib/sqlalchemy/orm/context.py173
1 files changed, 141 insertions, 32 deletions
diff --git a/lib/sqlalchemy/orm/context.py b/lib/sqlalchemy/orm/context.py
index dc96f8c3c..f8c7ba714 100644
--- a/lib/sqlalchemy/orm/context.py
+++ b/lib/sqlalchemy/orm/context.py
@@ -73,6 +73,7 @@ if TYPE_CHECKING:
from .query import Query
from .session import _BindArguments
from .session import Session
+ from ..engine import Result
from ..engine.interfaces import _CoreSingleExecuteParams
from ..engine.interfaces import _ExecuteOptionsParameter
from ..sql._typing import _ColumnsClauseArgument
@@ -203,15 +204,19 @@ _orm_load_exec_options = util.immutabledict(
class AbstractORMCompileState(CompileState):
+ is_dml_returning = False
+
@classmethod
def create_for_statement(
cls,
statement: Union[Select, FromStatement],
compiler: Optional[SQLCompiler],
**kw: Any,
- ) -> ORMCompileState:
+ ) -> AbstractORMCompileState:
"""Create a context for a statement given a :class:`.Compiler`.
+
This method is always invoked in the context of SQLCompiler.process().
+
For a Select object, this would be invoked from
SQLCompiler.visit_select(). For the special FromStatement object used
by Query to indicate "Query.from_statement()", this is called by
@@ -233,6 +238,28 @@ class AbstractORMCompileState(CompileState):
raise NotImplementedError()
@classmethod
+ def orm_execute_statement(
+ cls,
+ session,
+ statement,
+ params,
+ execution_options,
+ bind_arguments,
+ conn,
+ ) -> Result:
+ result = conn.execute(
+ statement, params or {}, execution_options=execution_options
+ )
+ return cls.orm_setup_cursor_result(
+ session,
+ statement,
+ params,
+ execution_options,
+ bind_arguments,
+ result,
+ )
+
+ @classmethod
def orm_setup_cursor_result(
cls,
session,
@@ -309,6 +336,17 @@ class ORMCompileState(AbstractORMCompileState):
def __init__(self, *arg, **kw):
raise NotImplementedError()
+ if TYPE_CHECKING:
+
+ @classmethod
+ def create_for_statement(
+ cls,
+ statement: Union[Select, FromStatement],
+ compiler: Optional[SQLCompiler],
+ **kw: Any,
+ ) -> ORMCompileState:
+ ...
+
def _append_dedupe_col_collection(self, obj, col_collection):
dedupe = self.dedupe_columns
if obj not in dedupe:
@@ -333,26 +371,6 @@ class ORMCompileState(AbstractORMCompileState):
return SelectState._column_naming_convention(label_style)
@classmethod
- def create_for_statement(
- cls,
- statement: Union[Select, FromStatement],
- compiler: Optional[SQLCompiler],
- **kw: Any,
- ) -> ORMCompileState:
- """Create a context for a statement given a :class:`.Compiler`.
-
- This method is always invoked in the context of SQLCompiler.process().
-
- For a Select object, this would be invoked from
- SQLCompiler.visit_select(). For the special FromStatement object used
- by Query to indicate "Query.from_statement()", this is called by
- FromStatement._compiler_dispatch() that would be called by
- SQLCompiler.process().
-
- """
- raise NotImplementedError()
-
- @classmethod
def get_column_descriptions(cls, statement):
return _column_descriptions(statement)
@@ -518,6 +536,49 @@ class ORMCompileState(AbstractORMCompileState):
)
+class DMLReturningColFilter:
+ """an adapter used for the DML RETURNING case.
+
+ Has a subset of the interface used by
+ :class:`.ORMAdapter` and is used for :class:`._QueryEntity`
+ instances to set up their columns as used in RETURNING for a
+ DML statement.
+
+ """
+
+ __slots__ = ("mapper", "columns", "__weakref__")
+
+ def __init__(self, target_mapper, immediate_dml_mapper):
+ if (
+ immediate_dml_mapper is not None
+ and target_mapper.local_table
+ is not immediate_dml_mapper.local_table
+ ):
+ # joined inh, or in theory other kinds of multi-table mappings
+ self.mapper = immediate_dml_mapper
+ else:
+ # single inh, normal mappings, etc.
+ self.mapper = target_mapper
+ self.columns = self.columns = util.WeakPopulateDict(
+ self.adapt_check_present # type: ignore
+ )
+
+ def __call__(self, col, as_filter):
+ for cc in sql_util._find_columns(col):
+ c2 = self.adapt_check_present(cc)
+ if c2 is not None:
+ return col
+ else:
+ return None
+
+ def adapt_check_present(self, col):
+ mapper = self.mapper
+ prop = mapper._columntoproperty.get(col, None)
+ if prop is None:
+ return None
+ return mapper.local_table.c.corresponding_column(col)
+
+
@sql.base.CompileState.plugin_for("orm", "orm_from_statement")
class ORMFromStatementCompileState(ORMCompileState):
_from_obj_alias = None
@@ -525,7 +586,7 @@ class ORMFromStatementCompileState(ORMCompileState):
statement_container: FromStatement
requested_statement: Union[SelectBase, TextClause, UpdateBase]
- dml_table: _DMLTableElement
+ dml_table: Optional[_DMLTableElement] = None
_has_orm_entities = False
multi_row_eager_loaders = False
@@ -541,7 +602,7 @@ class ORMFromStatementCompileState(ORMCompileState):
statement_container: Union[Select, FromStatement],
compiler: Optional[SQLCompiler],
**kw: Any,
- ) -> ORMCompileState:
+ ) -> ORMFromStatementCompileState:
if compiler is not None:
toplevel = not compiler.stack
@@ -565,6 +626,7 @@ class ORMFromStatementCompileState(ORMCompileState):
if statement.is_dml:
self.dml_table = statement.table
+ self.is_dml_returning = True
self._entities = []
self._polymorphic_adapters = {}
@@ -674,6 +736,18 @@ class ORMFromStatementCompileState(ORMCompileState):
def _get_current_adapter(self):
return None
+ def setup_dml_returning_compile_state(self, dml_mapper):
+ """used by BulkORMInsert (and Update / Delete?) to set up a handler
+ for RETURNING to return ORM objects and expressions
+
+ """
+ target_mapper = self.statement._propagate_attrs.get(
+ "plugin_subject", None
+ )
+ adapter = DMLReturningColFilter(target_mapper, dml_mapper)
+ for entity in self._entities:
+ entity.setup_dml_returning_compile_state(self, adapter)
+
class FromStatement(GroupedElement, Generative, TypedReturnsRows[_TP]):
"""Core construct that represents a load of ORM objects from various
@@ -813,7 +887,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
statement: Union[Select, FromStatement],
compiler: Optional[SQLCompiler],
**kw: Any,
- ) -> ORMCompileState:
+ ) -> ORMSelectCompileState:
"""compiler hook, we arrive here from compiler.visit_select() only."""
self = cls.__new__(cls)
@@ -2312,6 +2386,13 @@ class _QueryEntity:
def setup_compile_state(self, compile_state: ORMCompileState) -> None:
raise NotImplementedError()
+ def setup_dml_returning_compile_state(
+ self,
+ compile_state: ORMCompileState,
+ adapter: DMLReturningColFilter,
+ ) -> None:
+ raise NotImplementedError()
+
def row_processor(self, context, result):
raise NotImplementedError()
@@ -2509,8 +2590,24 @@ class _MapperEntity(_QueryEntity):
return _instance, self._label_name, self._extra_entities
- def setup_compile_state(self, compile_state):
+ def setup_dml_returning_compile_state(
+ self,
+ compile_state: ORMCompileState,
+ adapter: DMLReturningColFilter,
+ ) -> None:
+ loading._setup_entity_query(
+ compile_state,
+ self.mapper,
+ self,
+ self.path,
+ adapter,
+ compile_state.primary_columns,
+ with_polymorphic=self._with_polymorphic_mappers,
+ only_load_props=compile_state.compile_options._only_load_props,
+ polymorphic_discriminator=self._polymorphic_discriminator,
+ )
+ def setup_compile_state(self, compile_state):
adapter = self._get_entity_clauses(compile_state)
single_table_crit = self.mapper._single_table_criterion
@@ -2536,7 +2633,6 @@ class _MapperEntity(_QueryEntity):
only_load_props=compile_state.compile_options._only_load_props,
polymorphic_discriminator=self._polymorphic_discriminator,
)
-
compile_state._fallback_from_clauses.append(self.selectable)
@@ -2743,9 +2839,7 @@ class _ColumnEntity(_QueryEntity):
getter, label_name, extra_entities = self._row_processor
if self.translate_raw_column:
extra_entities += (
- result.context.invoked_statement._raw_columns[
- self.raw_column_index
- ],
+ context.query._raw_columns[self.raw_column_index],
)
return getter, label_name, extra_entities
@@ -2781,9 +2875,7 @@ class _ColumnEntity(_QueryEntity):
if self.translate_raw_column:
extra_entities = self._extra_entities + (
- result.context.invoked_statement._raw_columns[
- self.raw_column_index
- ],
+ context.query._raw_columns[self.raw_column_index],
)
return getter, self._label_name, extra_entities
else:
@@ -2843,6 +2935,8 @@ class _RawColumnEntity(_ColumnEntity):
current_adapter = compile_state._get_current_adapter()
if current_adapter:
column = current_adapter(self.column, False)
+ if column is None:
+ return
else:
column = self.column
@@ -2944,10 +3038,25 @@ class _ORMColumnEntity(_ColumnEntity):
self.entity_zero
) and entity.common_parent(self.entity_zero)
+ def setup_dml_returning_compile_state(
+ self,
+ compile_state: ORMCompileState,
+ adapter: DMLReturningColFilter,
+ ) -> None:
+ self._fetch_column = self.column
+ column = adapter(self.column, False)
+ if column is not None:
+ compile_state.dedupe_columns.add(column)
+ compile_state.primary_columns.append(column)
+
def setup_compile_state(self, compile_state):
current_adapter = compile_state._get_current_adapter()
if current_adapter:
column = current_adapter(self.column, False)
+ if column is None:
+ assert compile_state.is_dml_returning
+ self._fetch_column = self.column
+ return
else:
column = self.column