diff options
author | mike bayer <mike_mp@zzzcomputing.com> | 2022-05-16 02:32:44 +0000 |
---|---|---|
committer | Gerrit Code Review <gerrit@ci3.zzzcomputing.com> | 2022-05-16 02:32:44 +0000 |
commit | 5d080d17464712d33c0215d12513e529d848ee8c (patch) | |
tree | eec56f3138a48f55f2585a64f01b4fd9c14451b7 /lib/sqlalchemy/orm/context.py | |
parent | c4dad3695f4ab9fef3a4cb05893492afbec811f7 (diff) | |
parent | 18a73fb1d1c267842ead5dacd05a49f4344d8b22 (diff) | |
download | sqlalchemy-5d080d17464712d33c0215d12513e529d848ee8c.tar.gz |
Merge "revenge of pep 484" into main
Diffstat (limited to 'lib/sqlalchemy/orm/context.py')
-rw-r--r-- | lib/sqlalchemy/orm/context.py | 100 |
1 files changed, 76 insertions, 24 deletions
diff --git a/lib/sqlalchemy/orm/context.py b/lib/sqlalchemy/orm/context.py index 28fea2f9b..58556bb58 100644 --- a/lib/sqlalchemy/orm/context.py +++ b/lib/sqlalchemy/orm/context.py @@ -12,6 +12,7 @@ import itertools from typing import Any from typing import cast from typing import Dict +from typing import Iterable from typing import List from typing import Optional from typing import Set @@ -43,6 +44,7 @@ from ..sql import expression from ..sql import roles from ..sql import util as sql_util from ..sql import visitors +from ..sql._typing import _TP from ..sql._typing import is_dml from ..sql._typing import is_insert_update from ..sql._typing import is_select_base @@ -55,22 +57,32 @@ from ..sql.base import Options from ..sql.dml import UpdateBase from ..sql.elements import GroupedElement from ..sql.elements import TextClause -from ..sql.selectable import ExecutableReturnsRows from ..sql.selectable import LABEL_STYLE_DISAMBIGUATE_ONLY from ..sql.selectable import LABEL_STYLE_NONE from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL from ..sql.selectable import Select from ..sql.selectable import SelectLabelStyle from ..sql.selectable import SelectState +from ..sql.selectable import TypedReturnsRows from ..sql.visitors import InternalTraversal if TYPE_CHECKING: from ._typing import _InternalEntityType + from .loading import PostLoad from .mapper import Mapper from .query import Query + from .session import _BindArguments + from .session import Session + from ..engine.interfaces import _CoreSingleExecuteParams + from ..engine.interfaces import _ExecuteOptionsParameter + from ..sql._typing import _ColumnsClauseArgument + from ..sql.compiler import SQLCompiler from ..sql.dml import _DMLTableElement from ..sql.elements import ColumnElement + from ..sql.selectable import _JoinTargetElement from ..sql.selectable import _LabelConventionCallable + from ..sql.selectable import _SetupJoinsElement + from ..sql.selectable import ExecutableReturnsRows from ..sql.selectable import SelectBase from ..sql.type_api import TypeEngine @@ -80,7 +92,7 @@ _path_registry = PathRegistry.root _EMPTY_DICT = util.immutabledict() -LABEL_STYLE_LEGACY_ORM = util.symbol("LABEL_STYLE_LEGACY_ORM") +LABEL_STYLE_LEGACY_ORM = SelectLabelStyle.LABEL_STYLE_LEGACY_ORM class QueryContext: @@ -109,6 +121,10 @@ class QueryContext: "loaders_require_uniquing", ) + runid: int + post_load_paths: Dict[PathRegistry, PostLoad] + compile_state: ORMCompileState + class default_load_options(Options): _only_return_tuples = False _populate_existing = False @@ -123,13 +139,16 @@ class QueryContext: def __init__( self, - compile_state, - statement, - params, - session, - load_options, - execution_options=None, - bind_arguments=None, + compile_state: CompileState, + statement: Union[Select[Any], FromStatement[Any]], + params: _CoreSingleExecuteParams, + session: Session, + load_options: Union[ + Type[QueryContext.default_load_options], + QueryContext.default_load_options, + ], + execution_options: Optional[_ExecuteOptionsParameter] = None, + bind_arguments: Optional[_BindArguments] = None, ): self.load_options = load_options self.execution_options = execution_options or _EMPTY_DICT @@ -220,8 +239,8 @@ class ORMCompileState(CompileState): attributes: Dict[Any, Any] global_attributes: Dict[Any, Any] - statement: Union[Select, FromStatement] - select_statement: Union[Select, FromStatement] + statement: Union[Select[Any], FromStatement[Any]] + select_statement: Union[Select[Any], FromStatement[Any]] _entities: List[_QueryEntity] _polymorphic_adapters: Dict[_InternalEntityType, ORMAdapter] compile_options: Union[ @@ -238,6 +257,7 @@ class ORMCompileState(CompileState): Tuple[Any, ...] ] current_path: PathRegistry = _path_registry + _has_mapper_entities = False def __init__(self, *arg, **kw): raise NotImplementedError() @@ -266,7 +286,12 @@ class ORMCompileState(CompileState): return SelectState._column_naming_convention(label_style) @classmethod - def create_for_statement(cls, statement_container, compiler, **kw): + def create_for_statement( + cls, + statement: Union[Select, FromStatement], + compiler: Optional[SQLCompiler], + **kw: Any, + ) -> ORMCompileState: """Create a context for a statement given a :class:`.Compiler`. This method is always invoked in the context of SQLCompiler.process(). @@ -443,7 +468,12 @@ class ORMFromStatementCompileState(ORMCompileState): eager_joins = _EMPTY_DICT @classmethod - def create_for_statement(cls, statement_container, compiler, **kw): + def create_for_statement( + cls, + statement_container: Union[Select, FromStatement], + compiler: Optional[SQLCompiler], + **kw: Any, + ) -> ORMCompileState: if compiler is not None: toplevel = not compiler.stack @@ -577,7 +607,7 @@ class ORMFromStatementCompileState(ORMCompileState): return None -class FromStatement(GroupedElement, Generative, ExecutableReturnsRows): +class FromStatement(GroupedElement, Generative, TypedReturnsRows[_TP]): """Core construct that represents a load of ORM objects from various :class:`.ReturnsRows` and other classes including: @@ -595,7 +625,7 @@ class FromStatement(GroupedElement, Generative, ExecutableReturnsRows): _for_update_arg = None - element: Union[SelectBase, TextClause, UpdateBase] + element: Union[ExecutableReturnsRows, TextClause] _traverse_internals = [ ("_raw_columns", InternalTraversal.dp_clauseelement_list), @@ -606,7 +636,11 @@ class FromStatement(GroupedElement, Generative, ExecutableReturnsRows): ("_compile_options", InternalTraversal.dp_has_cache_key) ] - def __init__(self, entities, element): + def __init__( + self, + entities: Iterable[_ColumnsClauseArgument[Any]], + element: Union[ExecutableReturnsRows, TextClause], + ): self._raw_columns = [ coercions.expect( roles.ColumnsClauseRole, @@ -701,7 +735,12 @@ class ORMSelectCompileState(ORMCompileState, SelectState): _having_criteria = () @classmethod - def create_for_statement(cls, statement, compiler, **kw): + def create_for_statement( + cls, + statement: Union[Select, FromStatement], + compiler: Optional[SQLCompiler], + **kw: Any, + ) -> ORMCompileState: """compiler hook, we arrive here from compiler.visit_select() only.""" self = cls.__new__(cls) @@ -1073,9 +1112,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState): ) @classmethod - @util.preload_module("sqlalchemy.orm.query") def from_statement(cls, statement, from_statement): - query = util.preloaded.orm_query from_statement = coercions.expect( roles.ReturnsRowsRole, @@ -1083,7 +1120,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState): apply_propagate_attrs=statement, ) - stmt = query.FromStatement(statement._raw_columns, from_statement) + stmt = FromStatement(statement._raw_columns, from_statement) stmt.__dict__.update( _with_options=statement._with_options, @@ -2114,7 +2151,9 @@ def _column_descriptions( return d -def _legacy_filter_by_entity_zero(query_or_augmented_select): +def _legacy_filter_by_entity_zero( + query_or_augmented_select: Union[Query[Any], Select[Any]] +) -> Optional[_InternalEntityType[Any]]: self = query_or_augmented_select if self._setup_joins: _last_joined_entity = self._last_joined_entity @@ -2127,7 +2166,9 @@ def _legacy_filter_by_entity_zero(query_or_augmented_select): return _entity_from_pre_ent_zero(self) -def _entity_from_pre_ent_zero(query_or_augmented_select): +def _entity_from_pre_ent_zero( + query_or_augmented_select: Union[Query[Any], Select[Any]] +) -> Optional[_InternalEntityType[Any]]: self = query_or_augmented_select if not self._raw_columns: return None @@ -2144,13 +2185,19 @@ def _entity_from_pre_ent_zero(query_or_augmented_select): return ent -def _determine_last_joined_entity(setup_joins, entity_zero=None): +def _determine_last_joined_entity( + setup_joins: Tuple[_SetupJoinsElement, ...], + entity_zero: Optional[_InternalEntityType[Any]] = None, +) -> Optional[Union[_InternalEntityType[Any], _JoinTargetElement]]: if not setup_joins: return None (target, onclause, from_, flags) = setup_joins[-1] - if isinstance(target, interfaces.PropComparator): + if isinstance( + target, + attributes.QueryableAttribute, + ): return target.entity else: return target @@ -2161,6 +2208,8 @@ class _QueryEntity: __slots__ = () + supports_single_entity: bool + _non_hashable_value = False _null_column_type = False use_id_for_hash = False @@ -2173,6 +2222,9 @@ class _QueryEntity: def setup_compile_state(self, compile_state: ORMCompileState) -> None: raise NotImplementedError() + def row_processor(self, context, result): + raise NotImplementedError() + @classmethod def to_compile_state( cls, compile_state, entities, entities_collection, is_current_entities |