summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/orm/context.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/orm/context.py')
-rw-r--r--lib/sqlalchemy/orm/context.py100
1 files changed, 76 insertions, 24 deletions
diff --git a/lib/sqlalchemy/orm/context.py b/lib/sqlalchemy/orm/context.py
index 28fea2f9b..58556bb58 100644
--- a/lib/sqlalchemy/orm/context.py
+++ b/lib/sqlalchemy/orm/context.py
@@ -12,6 +12,7 @@ import itertools
from typing import Any
from typing import cast
from typing import Dict
+from typing import Iterable
from typing import List
from typing import Optional
from typing import Set
@@ -43,6 +44,7 @@ from ..sql import expression
from ..sql import roles
from ..sql import util as sql_util
from ..sql import visitors
+from ..sql._typing import _TP
from ..sql._typing import is_dml
from ..sql._typing import is_insert_update
from ..sql._typing import is_select_base
@@ -55,22 +57,32 @@ from ..sql.base import Options
from ..sql.dml import UpdateBase
from ..sql.elements import GroupedElement
from ..sql.elements import TextClause
-from ..sql.selectable import ExecutableReturnsRows
from ..sql.selectable import LABEL_STYLE_DISAMBIGUATE_ONLY
from ..sql.selectable import LABEL_STYLE_NONE
from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
from ..sql.selectable import Select
from ..sql.selectable import SelectLabelStyle
from ..sql.selectable import SelectState
+from ..sql.selectable import TypedReturnsRows
from ..sql.visitors import InternalTraversal
if TYPE_CHECKING:
from ._typing import _InternalEntityType
+ from .loading import PostLoad
from .mapper import Mapper
from .query import Query
+ from .session import _BindArguments
+ from .session import Session
+ from ..engine.interfaces import _CoreSingleExecuteParams
+ from ..engine.interfaces import _ExecuteOptionsParameter
+ from ..sql._typing import _ColumnsClauseArgument
+ from ..sql.compiler import SQLCompiler
from ..sql.dml import _DMLTableElement
from ..sql.elements import ColumnElement
+ from ..sql.selectable import _JoinTargetElement
from ..sql.selectable import _LabelConventionCallable
+ from ..sql.selectable import _SetupJoinsElement
+ from ..sql.selectable import ExecutableReturnsRows
from ..sql.selectable import SelectBase
from ..sql.type_api import TypeEngine
@@ -80,7 +92,7 @@ _path_registry = PathRegistry.root
_EMPTY_DICT = util.immutabledict()
-LABEL_STYLE_LEGACY_ORM = util.symbol("LABEL_STYLE_LEGACY_ORM")
+LABEL_STYLE_LEGACY_ORM = SelectLabelStyle.LABEL_STYLE_LEGACY_ORM
class QueryContext:
@@ -109,6 +121,10 @@ class QueryContext:
"loaders_require_uniquing",
)
+ runid: int
+ post_load_paths: Dict[PathRegistry, PostLoad]
+ compile_state: ORMCompileState
+
class default_load_options(Options):
_only_return_tuples = False
_populate_existing = False
@@ -123,13 +139,16 @@ class QueryContext:
def __init__(
self,
- compile_state,
- statement,
- params,
- session,
- load_options,
- execution_options=None,
- bind_arguments=None,
+ compile_state: CompileState,
+ statement: Union[Select[Any], FromStatement[Any]],
+ params: _CoreSingleExecuteParams,
+ session: Session,
+ load_options: Union[
+ Type[QueryContext.default_load_options],
+ QueryContext.default_load_options,
+ ],
+ execution_options: Optional[_ExecuteOptionsParameter] = None,
+ bind_arguments: Optional[_BindArguments] = None,
):
self.load_options = load_options
self.execution_options = execution_options or _EMPTY_DICT
@@ -220,8 +239,8 @@ class ORMCompileState(CompileState):
attributes: Dict[Any, Any]
global_attributes: Dict[Any, Any]
- statement: Union[Select, FromStatement]
- select_statement: Union[Select, FromStatement]
+ statement: Union[Select[Any], FromStatement[Any]]
+ select_statement: Union[Select[Any], FromStatement[Any]]
_entities: List[_QueryEntity]
_polymorphic_adapters: Dict[_InternalEntityType, ORMAdapter]
compile_options: Union[
@@ -238,6 +257,7 @@ class ORMCompileState(CompileState):
Tuple[Any, ...]
]
current_path: PathRegistry = _path_registry
+ _has_mapper_entities = False
def __init__(self, *arg, **kw):
raise NotImplementedError()
@@ -266,7 +286,12 @@ class ORMCompileState(CompileState):
return SelectState._column_naming_convention(label_style)
@classmethod
- def create_for_statement(cls, statement_container, compiler, **kw):
+ def create_for_statement(
+ cls,
+ statement: Union[Select, FromStatement],
+ compiler: Optional[SQLCompiler],
+ **kw: Any,
+ ) -> ORMCompileState:
"""Create a context for a statement given a :class:`.Compiler`.
This method is always invoked in the context of SQLCompiler.process().
@@ -443,7 +468,12 @@ class ORMFromStatementCompileState(ORMCompileState):
eager_joins = _EMPTY_DICT
@classmethod
- def create_for_statement(cls, statement_container, compiler, **kw):
+ def create_for_statement(
+ cls,
+ statement_container: Union[Select, FromStatement],
+ compiler: Optional[SQLCompiler],
+ **kw: Any,
+ ) -> ORMCompileState:
if compiler is not None:
toplevel = not compiler.stack
@@ -577,7 +607,7 @@ class ORMFromStatementCompileState(ORMCompileState):
return None
-class FromStatement(GroupedElement, Generative, ExecutableReturnsRows):
+class FromStatement(GroupedElement, Generative, TypedReturnsRows[_TP]):
"""Core construct that represents a load of ORM objects from various
:class:`.ReturnsRows` and other classes including:
@@ -595,7 +625,7 @@ class FromStatement(GroupedElement, Generative, ExecutableReturnsRows):
_for_update_arg = None
- element: Union[SelectBase, TextClause, UpdateBase]
+ element: Union[ExecutableReturnsRows, TextClause]
_traverse_internals = [
("_raw_columns", InternalTraversal.dp_clauseelement_list),
@@ -606,7 +636,11 @@ class FromStatement(GroupedElement, Generative, ExecutableReturnsRows):
("_compile_options", InternalTraversal.dp_has_cache_key)
]
- def __init__(self, entities, element):
+ def __init__(
+ self,
+ entities: Iterable[_ColumnsClauseArgument[Any]],
+ element: Union[ExecutableReturnsRows, TextClause],
+ ):
self._raw_columns = [
coercions.expect(
roles.ColumnsClauseRole,
@@ -701,7 +735,12 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
_having_criteria = ()
@classmethod
- def create_for_statement(cls, statement, compiler, **kw):
+ def create_for_statement(
+ cls,
+ statement: Union[Select, FromStatement],
+ compiler: Optional[SQLCompiler],
+ **kw: Any,
+ ) -> ORMCompileState:
"""compiler hook, we arrive here from compiler.visit_select() only."""
self = cls.__new__(cls)
@@ -1073,9 +1112,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
)
@classmethod
- @util.preload_module("sqlalchemy.orm.query")
def from_statement(cls, statement, from_statement):
- query = util.preloaded.orm_query
from_statement = coercions.expect(
roles.ReturnsRowsRole,
@@ -1083,7 +1120,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
apply_propagate_attrs=statement,
)
- stmt = query.FromStatement(statement._raw_columns, from_statement)
+ stmt = FromStatement(statement._raw_columns, from_statement)
stmt.__dict__.update(
_with_options=statement._with_options,
@@ -2114,7 +2151,9 @@ def _column_descriptions(
return d
-def _legacy_filter_by_entity_zero(query_or_augmented_select):
+def _legacy_filter_by_entity_zero(
+ query_or_augmented_select: Union[Query[Any], Select[Any]]
+) -> Optional[_InternalEntityType[Any]]:
self = query_or_augmented_select
if self._setup_joins:
_last_joined_entity = self._last_joined_entity
@@ -2127,7 +2166,9 @@ def _legacy_filter_by_entity_zero(query_or_augmented_select):
return _entity_from_pre_ent_zero(self)
-def _entity_from_pre_ent_zero(query_or_augmented_select):
+def _entity_from_pre_ent_zero(
+ query_or_augmented_select: Union[Query[Any], Select[Any]]
+) -> Optional[_InternalEntityType[Any]]:
self = query_or_augmented_select
if not self._raw_columns:
return None
@@ -2144,13 +2185,19 @@ def _entity_from_pre_ent_zero(query_or_augmented_select):
return ent
-def _determine_last_joined_entity(setup_joins, entity_zero=None):
+def _determine_last_joined_entity(
+ setup_joins: Tuple[_SetupJoinsElement, ...],
+ entity_zero: Optional[_InternalEntityType[Any]] = None,
+) -> Optional[Union[_InternalEntityType[Any], _JoinTargetElement]]:
if not setup_joins:
return None
(target, onclause, from_, flags) = setup_joins[-1]
- if isinstance(target, interfaces.PropComparator):
+ if isinstance(
+ target,
+ attributes.QueryableAttribute,
+ ):
return target.entity
else:
return target
@@ -2161,6 +2208,8 @@ class _QueryEntity:
__slots__ = ()
+ supports_single_entity: bool
+
_non_hashable_value = False
_null_column_type = False
use_id_for_hash = False
@@ -2173,6 +2222,9 @@ class _QueryEntity:
def setup_compile_state(self, compile_state: ORMCompileState) -> None:
raise NotImplementedError()
+ def row_processor(self, context, result):
+ raise NotImplementedError()
+
@classmethod
def to_compile_state(
cls, compile_state, entities, entities_collection, is_current_entities