summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r--lib/sqlalchemy/engine/cursor.py116
-rw-r--r--lib/sqlalchemy/engine/default.py7
-rw-r--r--lib/sqlalchemy/engine/interfaces.py25
-rw-r--r--lib/sqlalchemy/ext/associationproxy.py35
-rw-r--r--lib/sqlalchemy/ext/declarative/extensions.py36
-rw-r--r--lib/sqlalchemy/ext/hybrid.py6
-rw-r--r--lib/sqlalchemy/orm/_orm_constructors.py86
-rw-r--r--lib/sqlalchemy/orm/_typing.py52
-rw-r--r--lib/sqlalchemy/orm/attributes.py74
-rw-r--r--lib/sqlalchemy/orm/base.py81
-rw-r--r--lib/sqlalchemy/orm/clsregistry.py177
-rw-r--r--lib/sqlalchemy/orm/collections.py30
-rw-r--r--lib/sqlalchemy/orm/context.py100
-rw-r--r--lib/sqlalchemy/orm/decl_api.py229
-rw-r--r--lib/sqlalchemy/orm/decl_base.py414
-rw-r--r--lib/sqlalchemy/orm/descriptor_props.py335
-rw-r--r--lib/sqlalchemy/orm/dynamic.py80
-rw-r--r--lib/sqlalchemy/orm/events.py1
-rw-r--r--lib/sqlalchemy/orm/exc.py34
-rw-r--r--lib/sqlalchemy/orm/identity.py40
-rw-r--r--lib/sqlalchemy/orm/instrumentation.py30
-rw-r--r--lib/sqlalchemy/orm/interfaces.py115
-rw-r--r--lib/sqlalchemy/orm/loading.py18
-rw-r--r--lib/sqlalchemy/orm/mapped_collection.py1
-rw-r--r--lib/sqlalchemy/orm/mapper.py80
-rw-r--r--lib/sqlalchemy/orm/path_registry.py126
-rw-r--r--lib/sqlalchemy/orm/properties.py197
-rw-r--r--lib/sqlalchemy/orm/query.py643
-rw-r--r--lib/sqlalchemy/orm/relationships.py1004
-rw-r--r--lib/sqlalchemy/orm/session.py7
-rw-r--r--lib/sqlalchemy/orm/state.py18
-rw-r--r--lib/sqlalchemy/orm/state_changes.py19
-rw-r--r--lib/sqlalchemy/orm/strategies.py89
-rw-r--r--lib/sqlalchemy/orm/strategy_options.py514
-rw-r--r--lib/sqlalchemy/orm/sync.py16
-rw-r--r--lib/sqlalchemy/orm/util.py70
-rw-r--r--lib/sqlalchemy/sql/_typing.py10
-rw-r--r--lib/sqlalchemy/sql/annotation.py18
-rw-r--r--lib/sqlalchemy/sql/base.py22
-rw-r--r--lib/sqlalchemy/sql/coercions.py11
-rw-r--r--lib/sqlalchemy/sql/elements.py18
-rw-r--r--lib/sqlalchemy/sql/selectable.py66
-rw-r--r--lib/sqlalchemy/sql/traversals.py37
-rw-r--r--lib/sqlalchemy/sql/util.py44
-rw-r--r--lib/sqlalchemy/sql/visitors.py46
-rw-r--r--lib/sqlalchemy/util/_collections.py4
-rw-r--r--lib/sqlalchemy/util/compat.py9
-rw-r--r--lib/sqlalchemy/util/langhelpers.py32
-rw-r--r--lib/sqlalchemy/util/preloaded.py4
-rw-r--r--lib/sqlalchemy/util/topological.py35
-rw-r--r--lib/sqlalchemy/util/typing.py77
51 files changed, 3643 insertions, 1695 deletions
diff --git a/lib/sqlalchemy/engine/cursor.py b/lib/sqlalchemy/engine/cursor.py
index f4e22df2d..d5f0d8126 100644
--- a/lib/sqlalchemy/engine/cursor.py
+++ b/lib/sqlalchemy/engine/cursor.py
@@ -21,6 +21,7 @@ from typing import ClassVar
from typing import Dict
from typing import Iterator
from typing import List
+from typing import NoReturn
from typing import Optional
from typing import Sequence
from typing import Tuple
@@ -53,7 +54,11 @@ _UNPICKLED = util.symbol("unpickled")
if typing.TYPE_CHECKING:
+ from .base import Connection
+ from .default import DefaultExecutionContext
from .interfaces import _DBAPICursorDescription
+ from .interfaces import DBAPICursor
+ from .interfaces import Dialect
from .interfaces import ExecutionContext
from .result import _KeyIndexType
from .result import _KeyMapRecType
@@ -61,6 +66,7 @@ if typing.TYPE_CHECKING:
from .result import _ProcessorsType
from ..sql.type_api import _ResultProcessorType
+
_T = TypeVar("_T", bound=Any)
# metadata entry tuple indexes.
@@ -235,7 +241,7 @@ class CursorResultMetaData(ResultMetaData):
) = context.result_column_struct
num_ctx_cols = len(result_columns)
else:
- result_columns = (
+ result_columns = ( # type: ignore
cols_are_ordered
) = (
num_ctx_cols
@@ -776,25 +782,53 @@ class ResultFetchStrategy:
alternate_cursor_description: Optional[_DBAPICursorDescription] = None
- def soft_close(self, result, dbapi_cursor):
+ def soft_close(
+ self, result: CursorResult[Any], dbapi_cursor: Optional[DBAPICursor]
+ ) -> None:
raise NotImplementedError()
- def hard_close(self, result, dbapi_cursor):
+ def hard_close(
+ self, result: CursorResult[Any], dbapi_cursor: Optional[DBAPICursor]
+ ) -> None:
raise NotImplementedError()
- def yield_per(self, result, dbapi_cursor, num):
+ def yield_per(
+ self,
+ result: CursorResult[Any],
+ dbapi_cursor: Optional[DBAPICursor],
+ num: int,
+ ) -> None:
return
- def fetchone(self, result, dbapi_cursor, hard_close=False):
+ def fetchone(
+ self,
+ result: CursorResult[Any],
+ dbapi_cursor: DBAPICursor,
+ hard_close: bool = False,
+ ) -> Any:
raise NotImplementedError()
- def fetchmany(self, result, dbapi_cursor, size=None):
+ def fetchmany(
+ self,
+ result: CursorResult[Any],
+ dbapi_cursor: DBAPICursor,
+ size: Optional[int] = None,
+ ) -> Any:
raise NotImplementedError()
- def fetchall(self, result):
+ def fetchall(
+ self,
+ result: CursorResult[Any],
+ dbapi_cursor: DBAPICursor,
+ ) -> Any:
raise NotImplementedError()
- def handle_exception(self, result, dbapi_cursor, err):
+ def handle_exception(
+ self,
+ result: CursorResult[Any],
+ dbapi_cursor: Optional[DBAPICursor],
+ err: BaseException,
+ ) -> NoReturn:
raise err
@@ -882,18 +916,32 @@ class CursorFetchStrategy(ResultFetchStrategy):
__slots__ = ()
- def soft_close(self, result, dbapi_cursor):
+ def soft_close(
+ self, result: CursorResult[Any], dbapi_cursor: Optional[DBAPICursor]
+ ) -> None:
result.cursor_strategy = _NO_CURSOR_DQL
- def hard_close(self, result, dbapi_cursor):
+ def hard_close(
+ self, result: CursorResult[Any], dbapi_cursor: Optional[DBAPICursor]
+ ) -> None:
result.cursor_strategy = _NO_CURSOR_DQL
- def handle_exception(self, result, dbapi_cursor, err):
+ def handle_exception(
+ self,
+ result: CursorResult[Any],
+ dbapi_cursor: Optional[DBAPICursor],
+ err: BaseException,
+ ) -> NoReturn:
result.connection._handle_dbapi_exception(
err, None, None, dbapi_cursor, result.context
)
- def yield_per(self, result, dbapi_cursor, num):
+ def yield_per(
+ self,
+ result: CursorResult[Any],
+ dbapi_cursor: Optional[DBAPICursor],
+ num: int,
+ ) -> None:
result.cursor_strategy = BufferedRowCursorFetchStrategy(
dbapi_cursor,
{"max_row_buffer": num},
@@ -901,7 +949,12 @@ class CursorFetchStrategy(ResultFetchStrategy):
growth_factor=0,
)
- def fetchone(self, result, dbapi_cursor, hard_close=False):
+ def fetchone(
+ self,
+ result: CursorResult[Any],
+ dbapi_cursor: DBAPICursor,
+ hard_close: bool = False,
+ ) -> Any:
try:
row = dbapi_cursor.fetchone()
if row is None:
@@ -910,7 +963,12 @@ class CursorFetchStrategy(ResultFetchStrategy):
except BaseException as e:
self.handle_exception(result, dbapi_cursor, e)
- def fetchmany(self, result, dbapi_cursor, size=None):
+ def fetchmany(
+ self,
+ result: CursorResult[Any],
+ dbapi_cursor: DBAPICursor,
+ size: Optional[int] = None,
+ ) -> Any:
try:
if size is None:
l = dbapi_cursor.fetchmany()
@@ -923,7 +981,11 @@ class CursorFetchStrategy(ResultFetchStrategy):
except BaseException as e:
self.handle_exception(result, dbapi_cursor, e)
- def fetchall(self, result, dbapi_cursor):
+ def fetchall(
+ self,
+ result: CursorResult[Any],
+ dbapi_cursor: DBAPICursor,
+ ) -> Any:
try:
rows = dbapi_cursor.fetchall()
result._soft_close()
@@ -1163,6 +1225,9 @@ class _NoResultMetaData(ResultMetaData):
_NO_RESULT_METADATA = _NoResultMetaData()
+SelfCursorResult = TypeVar("SelfCursorResult", bound="CursorResult[Any]")
+
+
class CursorResult(Result[_T]):
"""A Result that is representing state from a DBAPI cursor.
@@ -1199,7 +1264,17 @@ class CursorResult(Result[_T]):
closed: bool = False
_is_cursor = True
- def __init__(self, context, cursor_strategy, cursor_description):
+ context: DefaultExecutionContext
+ dialect: Dialect
+ cursor_strategy: ResultFetchStrategy
+ connection: Connection
+
+ def __init__(
+ self,
+ context: DefaultExecutionContext,
+ cursor_strategy: ResultFetchStrategy,
+ cursor_description: Optional[_DBAPICursorDescription],
+ ):
self.context = context
self.dialect = context.dialect
self.cursor = context.cursor
@@ -1333,7 +1408,7 @@ class CursorResult(Result[_T]):
if not self._soft_closed:
cursor = self.cursor
- self.cursor = None
+ self.cursor = None # type: ignore
self.connection._safe_close_cursor(cursor)
self._soft_closed = True
@@ -1605,7 +1680,7 @@ class CursorResult(Result[_T]):
return self.dialect.supports_sane_multi_rowcount
@util.memoized_property
- def rowcount(self):
+ def rowcount(self) -> int:
"""Return the 'rowcount' for this result.
The 'rowcount' reports the number of rows *matched*
@@ -1655,6 +1730,7 @@ class CursorResult(Result[_T]):
return self.context.rowcount
except BaseException as e:
self.cursor_strategy.handle_exception(self, self.cursor, e)
+ raise # not called
@property
def lastrowid(self):
@@ -1749,7 +1825,7 @@ class CursorResult(Result[_T]):
)
return merged_result
- def close(self):
+ def close(self) -> Any:
"""Close this :class:`_engine.CursorResult`.
This closes out the underlying DBAPI cursor corresponding to the
@@ -1772,7 +1848,7 @@ class CursorResult(Result[_T]):
self._soft_close(hard=True)
@_generative
- def yield_per(self, num):
+ def yield_per(self: SelfCursorResult, num: int) -> SelfCursorResult:
self._yield_per = num
self.cursor_strategy.yield_per(self, self.cursor, num)
return self
diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py
index 6094ad0fb..fc114efa3 100644
--- a/lib/sqlalchemy/engine/default.py
+++ b/lib/sqlalchemy/engine/default.py
@@ -64,6 +64,7 @@ if typing.TYPE_CHECKING:
from .base import Engine
from .interfaces import _CoreMultiExecuteParams
from .interfaces import _CoreSingleExecuteParams
+ from .interfaces import _DBAPICursorDescription
from .interfaces import _DBAPIMultiExecuteParams
from .interfaces import _ExecuteOptions
from .interfaces import _IsolationLevel
@@ -1285,8 +1286,8 @@ class DefaultExecutionContext(ExecutionContext):
def handle_dbapi_exception(self, e):
pass
- @property
- def rowcount(self):
+ @util.non_memoized_property
+ def rowcount(self) -> int:
return self.cursor.rowcount
def supports_sane_rowcount(self):
@@ -1304,7 +1305,7 @@ class DefaultExecutionContext(ExecutionContext):
strategy = _cursor.BufferedRowCursorFetchStrategy(
self.cursor, self.execution_options
)
- cursor_description = (
+ cursor_description: _DBAPICursorDescription = (
strategy.alternate_cursor_description
or self.cursor.description
)
diff --git a/lib/sqlalchemy/engine/interfaces.py b/lib/sqlalchemy/engine/interfaces.py
index 641024603..e5414b70f 100644
--- a/lib/sqlalchemy/engine/interfaces.py
+++ b/lib/sqlalchemy/engine/interfaces.py
@@ -133,17 +133,7 @@ class DBAPICursor(Protocol):
@property
def description(
self,
- ) -> Sequence[
- Tuple[
- str,
- "DBAPIType",
- Optional[int],
- Optional[int],
- Optional[int],
- Optional[int],
- Optional[bool],
- ]
- ]:
+ ) -> _DBAPICursorDescription:
"""The description attribute of the Cursor.
.. seealso::
@@ -217,7 +207,15 @@ _DBAPIMultiExecuteParams = Union[
_DBAPIAnyExecuteParams = Union[
_DBAPIMultiExecuteParams, _DBAPISingleExecuteParams
]
-_DBAPICursorDescription = Tuple[str, Any, Any, Any, Any, Any, Any]
+_DBAPICursorDescription = Tuple[
+ str,
+ "DBAPIType",
+ Optional[int],
+ Optional[int],
+ Optional[int],
+ Optional[int],
+ Optional[bool],
+]
_AnySingleExecuteParams = _DBAPISingleExecuteParams
_AnyMultiExecuteParams = _DBAPIMultiExecuteParams
@@ -2297,6 +2295,9 @@ class ExecutionContext:
"""
+ engine: Engine
+ """engine which the Connection is associated with"""
+
connection: Connection
"""Connection object which can be freely used by default value
generators to execute SQL. This Connection should reference the
diff --git a/lib/sqlalchemy/ext/associationproxy.py b/lib/sqlalchemy/ext/associationproxy.py
index 420ba5c8c..7db95eac9 100644
--- a/lib/sqlalchemy/ext/associationproxy.py
+++ b/lib/sqlalchemy/ext/associationproxy.py
@@ -53,7 +53,6 @@ from ..orm import ORMDescriptor
from ..orm.base import SQLORMOperations
from ..sql import operators
from ..sql import or_
-from ..sql.elements import SQLCoreOperations
from ..util.typing import Literal
from ..util.typing import Protocol
from ..util.typing import Self
@@ -64,8 +63,10 @@ if typing.TYPE_CHECKING:
from ..orm.interfaces import MapperProperty
from ..orm.interfaces import PropComparator
from ..orm.mapper import Mapper
+ from ..sql._typing import _ColumnExpressionArgument
from ..sql._typing import _InfoType
+
_T = TypeVar("_T", bound=Any)
_T_co = TypeVar("_T_co", bound=Any, covariant=True)
_T_con = TypeVar("_T_con", bound=Any, contravariant=True)
@@ -631,7 +632,9 @@ class AssociationProxyInstance(SQLORMOperations[_T]):
@property
def _comparator(self) -> PropComparator[Any]:
- return self._get_property().comparator
+ return getattr( # type: ignore
+ self.owning_class, self.target_collection
+ ).comparator
def __clause_element__(self) -> NoReturn:
raise NotImplementedError(
@@ -957,7 +960,9 @@ class AssociationProxyInstance(SQLORMOperations[_T]):
proxy.setter = setter
def _criterion_exists(
- self, criterion: Optional[SQLCoreOperations[Any]] = None, **kwargs: Any
+ self,
+ criterion: Optional[_ColumnExpressionArgument[bool]] = None,
+ **kwargs: Any,
) -> ColumnElement[bool]:
is_has = kwargs.pop("is_has", None)
@@ -969,8 +974,8 @@ class AssociationProxyInstance(SQLORMOperations[_T]):
return self._comparator._criterion_exists(inner)
if self._target_is_object:
- prop = getattr(self.target_class, self.value_attr)
- value_expr = prop._criterion_exists(criterion, **kwargs)
+ attr = getattr(self.target_class, self.value_attr)
+ value_expr = attr.comparator._criterion_exists(criterion, **kwargs)
else:
if kwargs:
raise exc.ArgumentError(
@@ -988,8 +993,10 @@ class AssociationProxyInstance(SQLORMOperations[_T]):
return self._comparator._criterion_exists(value_expr)
def any(
- self, criterion: Optional[SQLCoreOperations[Any]] = None, **kwargs: Any
- ) -> SQLCoreOperations[Any]:
+ self,
+ criterion: Optional[_ColumnExpressionArgument[bool]] = None,
+ **kwargs: Any,
+ ) -> ColumnElement[bool]:
"""Produce a proxied 'any' expression using EXISTS.
This expression will be a composed product
@@ -1010,8 +1017,10 @@ class AssociationProxyInstance(SQLORMOperations[_T]):
)
def has(
- self, criterion: Optional[SQLCoreOperations[Any]] = None, **kwargs: Any
- ) -> SQLCoreOperations[Any]:
+ self,
+ criterion: Optional[_ColumnExpressionArgument[bool]] = None,
+ **kwargs: Any,
+ ) -> ColumnElement[bool]:
"""Produce a proxied 'has' expression using EXISTS.
This expression will be a composed product
@@ -1069,12 +1078,16 @@ class AmbiguousAssociationProxyInstance(AssociationProxyInstance[_T]):
self._ambiguous()
def any(
- self, criterion: Optional[SQLCoreOperations[Any]] = None, **kwargs: Any
+ self,
+ criterion: Optional[_ColumnExpressionArgument[bool]] = None,
+ **kwargs: Any,
) -> NoReturn:
self._ambiguous()
def has(
- self, criterion: Optional[SQLCoreOperations[Any]] = None, **kwargs: Any
+ self,
+ criterion: Optional[_ColumnExpressionArgument[bool]] = None,
+ **kwargs: Any,
) -> NoReturn:
self._ambiguous()
diff --git a/lib/sqlalchemy/ext/declarative/extensions.py b/lib/sqlalchemy/ext/declarative/extensions.py
index 9faf2ed51..22fa83c58 100644
--- a/lib/sqlalchemy/ext/declarative/extensions.py
+++ b/lib/sqlalchemy/ext/declarative/extensions.py
@@ -8,7 +8,10 @@
"""Public API functions and helpers for declarative."""
+from __future__ import annotations
+from typing import Callable
+from typing import TYPE_CHECKING
from ... import inspection
from ...orm import exc as orm_exc
@@ -20,6 +23,10 @@ from ...orm.util import polymorphic_union
from ...schema import Table
from ...util import OrderedDict
+if TYPE_CHECKING:
+ from ...engine.reflection import Inspector
+ from ...sql.schema import MetaData
+
class ConcreteBase:
"""A helper class for 'concrete' declarative mappings.
@@ -380,31 +387,36 @@ class DeferredReflection:
mapper = thingy.cls.__mapper__
metadata = mapper.class_.metadata
for rel in mapper._props.values():
+
if (
isinstance(rel, relationships.Relationship)
- and rel.secondary is not None
+ and rel._init_args.secondary._is_populated()
):
- if isinstance(rel.secondary, Table):
- cls._reflect_table(rel.secondary, insp)
- elif isinstance(rel.secondary, str):
+
+ secondary_arg = rel._init_args.secondary
+
+ if isinstance(secondary_arg.argument, Table):
+ cls._reflect_table(secondary_arg.argument, insp)
+ elif isinstance(secondary_arg.argument, str):
_, resolve_arg = _resolver(rel.parent.class_, rel)
- rel.secondary = resolve_arg(rel.secondary)
- rel.secondary._resolvers += (
+ resolver = resolve_arg(
+ secondary_arg.argument, True
+ )
+ resolver._resolvers += (
cls._sa_deferred_table_resolver(
insp, metadata
),
)
- # controversy! do we resolve it here? or leave
- # it deferred? I think doing it here is necessary
- # so the connection does not leak.
- rel.secondary = rel.secondary()
+ secondary_arg.argument = resolver()
@classmethod
- def _sa_deferred_table_resolver(cls, inspector, metadata):
- def _resolve(key):
+ def _sa_deferred_table_resolver(
+ cls, inspector: Inspector, metadata: MetaData
+ ) -> Callable[[str], Table]:
+ def _resolve(key: str) -> Table:
t1 = Table(key, metadata)
cls._reflect_table(t1, inspector)
return t1
diff --git a/lib/sqlalchemy/ext/hybrid.py b/lib/sqlalchemy/ext/hybrid.py
index ea558495b..accfa8949 100644
--- a/lib/sqlalchemy/ext/hybrid.py
+++ b/lib/sqlalchemy/ext/hybrid.py
@@ -1305,8 +1305,8 @@ class Comparator(interfaces.PropComparator[_T]):
return ret_expr
@util.non_memoized_property
- def property(self) -> Optional[interfaces.MapperProperty[_T]]:
- return None
+ def property(self) -> interfaces.MapperProperty[_T]:
+ raise NotImplementedError()
def adapt_to_entity(
self, adapt_to_entity: AliasedInsp[Any]
@@ -1344,7 +1344,7 @@ class ExprComparator(Comparator[_T]):
return [(self.expression, value)]
@util.non_memoized_property
- def property(self) -> Optional[MapperProperty[_T]]:
+ def property(self) -> MapperProperty[_T]:
# this accessor is not normally used, however is accessed by things
# like ORM synonyms if the hybrid is used in this context; the
# .property attribute is not necessarily accessible
diff --git a/lib/sqlalchemy/orm/_orm_constructors.py b/lib/sqlalchemy/orm/_orm_constructors.py
index 560db9817..18a18bd80 100644
--- a/lib/sqlalchemy/orm/_orm_constructors.py
+++ b/lib/sqlalchemy/orm/_orm_constructors.py
@@ -4,7 +4,6 @@
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: allow-untyped-defs, allow-untyped-calls
from __future__ import annotations
@@ -12,6 +11,8 @@ import typing
from typing import Any
from typing import Callable
from typing import Collection
+from typing import Iterable
+from typing import NoReturn
from typing import Optional
from typing import overload
from typing import Type
@@ -45,6 +46,7 @@ from ..util.typing import Literal
if TYPE_CHECKING:
from ._typing import _EntityType
from ._typing import _ORMColumnExprArgument
+ from .descriptor_props import _CC
from .descriptor_props import _CompositeAttrType
from .interfaces import PropComparator
from .mapper import Mapper
@@ -54,14 +56,19 @@ if TYPE_CHECKING:
from .relationships import _ORMColCollectionArgument
from .relationships import _ORMOrderByArgument
from .relationships import _RelationshipJoinConditionArgument
+ from .session import _SessionBind
from ..sql._typing import _ColumnExpressionArgument
+ from ..sql._typing import _FromClauseArgument
from ..sql._typing import _InfoType
+ from ..sql._typing import _OnClauseArgument
from ..sql._typing import _TypeEngineArgument
+ from ..sql.elements import ColumnElement
from ..sql.schema import _ServerDefaultType
from ..sql.schema import FetchedValue
from ..sql.selectable import Alias
from ..sql.selectable import Subquery
+
_T = typing.TypeVar("_T")
@@ -424,10 +431,10 @@ def column_property(
@overload
def composite(
- class_: Type[_T],
+ class_: Type[_CC],
*attrs: _CompositeAttrType[Any],
**kwargs: Any,
-) -> Composite[_T]:
+) -> Composite[_CC]:
...
@@ -680,7 +687,7 @@ def with_loader_criteria(
def relationship(
argument: Optional[_RelationshipArgumentType[Any]] = None,
- secondary: Optional[FromClause] = None,
+ secondary: Optional[Union[FromClause, str]] = None,
*,
uselist: Optional[bool] = None,
collection_class: Optional[
@@ -696,14 +703,14 @@ def relationship(
cascade: str = "save-update, merge",
viewonly: bool = False,
lazy: _LazyLoadArgumentType = "select",
- passive_deletes: bool = False,
+ passive_deletes: Union[Literal["all"], bool] = False,
passive_updates: bool = True,
active_history: bool = False,
enable_typechecks: bool = True,
foreign_keys: Optional[_ORMColCollectionArgument] = None,
remote_side: Optional[_ORMColCollectionArgument] = None,
join_depth: Optional[int] = None,
- comparator_factory: Optional[Type[PropComparator[Any]]] = None,
+ comparator_factory: Optional[Type[Relationship.Comparator[Any]]] = None,
single_parent: bool = False,
innerjoin: bool = False,
distinct_target_key: Optional[bool] = None,
@@ -1660,10 +1667,19 @@ def synonym(
than can be achieved with synonyms.
"""
- return Synonym(name, map_column, descriptor, comparator_factory, doc, info)
+ return Synonym(
+ name,
+ map_column=map_column,
+ descriptor=descriptor,
+ comparator_factory=comparator_factory,
+ doc=doc,
+ info=info,
+ )
-def create_session(bind=None, **kwargs):
+def create_session(
+ bind: Optional[_SessionBind] = None, **kwargs: Any
+) -> Session:
r"""Create a new :class:`.Session`
with no automation enabled by default.
@@ -1699,7 +1715,7 @@ def create_session(bind=None, **kwargs):
return Session(bind=bind, **kwargs)
-def _mapper_fn(*arg, **kw):
+def _mapper_fn(*arg: Any, **kw: Any) -> NoReturn:
"""Placeholder for the now-removed ``mapper()`` function.
Classical mappings should be performed using the
@@ -1726,7 +1742,9 @@ def _mapper_fn(*arg, **kw):
)
-def dynamic_loader(argument, **kw):
+def dynamic_loader(
+ argument: Optional[_RelationshipArgumentType[Any]] = None, **kw: Any
+) -> Relationship[Any]:
"""Construct a dynamically-loading mapper property.
This is essentially the same as
@@ -1746,7 +1764,7 @@ def dynamic_loader(argument, **kw):
return relationship(argument, **kw)
-def backref(name, **kwargs):
+def backref(name: str, **kwargs: Any) -> _ORMBackrefArgument:
"""Create a back reference with explicit keyword arguments, which are the
same arguments one can send to :func:`relationship`.
@@ -1765,7 +1783,11 @@ def backref(name, **kwargs):
return (name, kwargs)
-def deferred(*columns, **kw):
+def deferred(
+ column: _ORMColumnExprArgument[_T],
+ *additional_columns: _ORMColumnExprArgument[Any],
+ **kw: Any,
+) -> ColumnProperty[_T]:
r"""Indicate a column-based mapped attribute that by default will
not load unless accessed.
@@ -1791,7 +1813,8 @@ def deferred(*columns, **kw):
:ref:`deferred`
"""
- return ColumnProperty(deferred=True, *columns, **kw)
+ kw["deferred"] = True
+ return ColumnProperty(column, *additional_columns, **kw)
def query_expression(
@@ -1824,7 +1847,7 @@ def query_expression(
return prop
-def clear_mappers():
+def clear_mappers() -> None:
"""Remove all mappers from all classes.
.. versionchanged:: 1.4 This function now locates all
@@ -2003,16 +2026,16 @@ def aliased(
def with_polymorphic(
- base,
- classes,
- selectable=False,
- flat=False,
- polymorphic_on=None,
- aliased=False,
- adapt_on_names=False,
- innerjoin=False,
- _use_mapper_path=False,
-):
+ base: Union[_O, Mapper[_O]],
+ classes: Iterable[Type[Any]],
+ selectable: Union[Literal[False, None], FromClause] = False,
+ flat: bool = False,
+ polymorphic_on: Optional[ColumnElement[Any]] = None,
+ aliased: bool = False,
+ innerjoin: bool = False,
+ adapt_on_names: bool = False,
+ _use_mapper_path: bool = False,
+) -> AliasedClass[_O]:
"""Produce an :class:`.AliasedClass` construct which specifies
columns for descendant mappers of the given base.
@@ -2096,7 +2119,13 @@ def with_polymorphic(
)
-def join(left, right, onclause=None, isouter=False, full=False):
+def join(
+ left: _FromClauseArgument,
+ right: _FromClauseArgument,
+ onclause: Optional[_OnClauseArgument] = None,
+ isouter: bool = False,
+ full: bool = False,
+) -> _ORMJoin:
r"""Produce an inner join between left and right clauses.
:func:`_orm.join` is an extension to the core join interface
@@ -2135,7 +2164,12 @@ def join(left, right, onclause=None, isouter=False, full=False):
return _ORMJoin(left, right, onclause, isouter, full)
-def outerjoin(left, right, onclause=None, full=False):
+def outerjoin(
+ left: _FromClauseArgument,
+ right: _FromClauseArgument,
+ onclause: Optional[_OnClauseArgument] = None,
+ full: bool = False,
+) -> _ORMJoin:
"""Produce a left outer join between left and right clauses.
This is the "outer join" version of the :func:`_orm.join` function,
diff --git a/lib/sqlalchemy/orm/_typing.py b/lib/sqlalchemy/orm/_typing.py
index 29d82340a..0e624afe2 100644
--- a/lib/sqlalchemy/orm/_typing.py
+++ b/lib/sqlalchemy/orm/_typing.py
@@ -2,8 +2,8 @@ from __future__ import annotations
import operator
from typing import Any
-from typing import Callable
from typing import Dict
+from typing import Mapping
from typing import Optional
from typing import Tuple
from typing import Type
@@ -20,9 +20,12 @@ from ..util.typing import TypeGuard
if TYPE_CHECKING:
from .attributes import AttributeImpl
from .attributes import CollectionAttributeImpl
+ from .attributes import HasCollectionAdapter
+ from .attributes import QueryableAttribute
from .base import PassiveFlag
from .decl_api import registry as _registry_type
from .descriptor_props import _CompositeClassProto
+ from .interfaces import InspectionAttr
from .interfaces import MapperProperty
from .interfaces import UserDefinedOption
from .mapper import Mapper
@@ -30,11 +33,14 @@ if TYPE_CHECKING:
from .state import InstanceState
from .util import AliasedClass
from .util import AliasedInsp
+ from ..sql._typing import _CE
from ..sql.base import ExecutableOption
_T = TypeVar("_T", bound=Any)
+_T_co = TypeVar("_T_co", bound=Any, covariant=True)
+
# I would have preferred this were bound=object however it seems
# to not travel in all situations when defined in that way.
_O = TypeVar("_O", bound=Any)
@@ -42,6 +48,12 @@ _O = TypeVar("_O", bound=Any)
"""
+_OO = TypeVar("_OO", bound=object)
+"""The 'ORM mapped object, that's definitely object' type.
+
+"""
+
+
if TYPE_CHECKING:
_RegistryType = _registry_type
@@ -54,6 +66,7 @@ _EntityType = Union[
]
+_ClassDict = Mapping[str, Any]
_InstanceDict = Dict[str, Any]
_IdentityKeyType = Tuple[Type[_T], Tuple[Any, ...], Optional[Any]]
@@ -64,10 +77,19 @@ _ORMColumnExprArgument = Union[
roles.ExpressionElementRole[_T],
]
-# somehow Protocol didn't want to work for this one
-_ORMAdapterProto = Callable[
- [_ORMColumnExprArgument[_T], Optional[str]], _ORMColumnExprArgument[_T]
-]
+
+_ORMCOLEXPR = TypeVar("_ORMCOLEXPR", bound=ColumnElement[Any])
+
+
+class _ORMAdapterProto(Protocol):
+ """protocol for the :class:`.AliasedInsp._orm_adapt_element` method
+ which is a synonym for :class:`.AliasedInsp._adapt_element`.
+
+
+ """
+
+ def __call__(self, obj: _CE, key: Optional[str] = None) -> _CE:
+ ...
class _LoaderCallable(Protocol):
@@ -96,6 +118,16 @@ if TYPE_CHECKING:
def insp_is_aliased_class(obj: Any) -> TypeGuard[AliasedInsp[Any]]:
...
+ def insp_is_attribute(
+ obj: InspectionAttr,
+ ) -> TypeGuard[QueryableAttribute[Any]]:
+ ...
+
+ def attr_is_internal_proxy(
+ obj: InspectionAttr,
+ ) -> TypeGuard[QueryableAttribute[Any]]:
+ ...
+
def prop_is_relationship(
prop: MapperProperty[Any],
) -> TypeGuard[Relationship[Any]]:
@@ -106,9 +138,19 @@ if TYPE_CHECKING:
) -> TypeGuard[CollectionAttributeImpl]:
...
+ def is_has_collection_adapter(
+ impl: AttributeImpl,
+ ) -> TypeGuard[HasCollectionAdapter]:
+ ...
+
else:
insp_is_mapper_property = operator.attrgetter("is_property")
insp_is_mapper = operator.attrgetter("is_mapper")
insp_is_aliased_class = operator.attrgetter("is_aliased_class")
+ insp_is_attribute = operator.attrgetter("is_attribute")
+ attr_is_internal_proxy = operator.attrgetter("_is_internal_proxy")
is_collection_impl = operator.attrgetter("collection")
prop_is_relationship = operator.attrgetter("_is_relationship")
+ is_has_collection_adapter = operator.attrgetter(
+ "_is_has_collection_adapter"
+ )
diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py
index 9aeaeaa27..b5faa7cbf 100644
--- a/lib/sqlalchemy/orm/attributes.py
+++ b/lib/sqlalchemy/orm/attributes.py
@@ -117,7 +117,9 @@ class NoKey(str):
pass
-_AllPendingType = List[Tuple[Optional["InstanceState[Any]"], Optional[object]]]
+_AllPendingType = Sequence[
+ Tuple[Optional["InstanceState[Any]"], Optional[object]]
+]
NO_KEY = NoKey("no name")
@@ -798,6 +800,8 @@ class AttributeImpl:
supports_population: bool
dynamic: bool
+ _is_has_collection_adapter = False
+
_replace_token: AttributeEventToken
_remove_token: AttributeEventToken
_append_token: AttributeEventToken
@@ -1140,7 +1144,7 @@ class AttributeImpl:
state: InstanceState[Any],
dict_: _InstanceDict,
value: Any,
- initiator: Optional[AttributeEventToken],
+ initiator: Optional[AttributeEventToken] = None,
passive: PassiveFlag = PASSIVE_OFF,
check_old: Any = None,
pop: bool = False,
@@ -1236,7 +1240,7 @@ class ScalarAttributeImpl(AttributeImpl):
state: InstanceState[Any],
dict_: Dict[str, Any],
value: Any,
- initiator: Optional[AttributeEventToken],
+ initiator: Optional[AttributeEventToken] = None,
passive: PassiveFlag = PASSIVE_OFF,
check_old: Optional[object] = None,
pop: bool = False,
@@ -1402,7 +1406,7 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl):
state: InstanceState[Any],
dict_: _InstanceDict,
value: Any,
- initiator: Optional[AttributeEventToken],
+ initiator: Optional[AttributeEventToken] = None,
passive: PassiveFlag = PASSIVE_OFF,
check_old: Any = None,
pop: bool = False,
@@ -1494,6 +1498,9 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl):
class HasCollectionAdapter:
__slots__ = ()
+ collection: bool
+ _is_has_collection_adapter = True
+
def _dispose_previous_collection(
self,
state: InstanceState[Any],
@@ -1508,7 +1515,7 @@ class HasCollectionAdapter:
self,
state: InstanceState[Any],
dict_: _InstanceDict,
- user_data: Optional[_AdaptedCollectionProtocol] = None,
+ user_data: Literal[None] = ...,
passive: Literal[PassiveFlag.PASSIVE_OFF] = ...,
) -> CollectionAdapter:
...
@@ -1518,8 +1525,18 @@ class HasCollectionAdapter:
self,
state: InstanceState[Any],
dict_: _InstanceDict,
- user_data: Optional[_AdaptedCollectionProtocol] = None,
- passive: PassiveFlag = PASSIVE_OFF,
+ user_data: _AdaptedCollectionProtocol = ...,
+ passive: PassiveFlag = ...,
+ ) -> CollectionAdapter:
+ ...
+
+ @overload
+ def get_collection(
+ self,
+ state: InstanceState[Any],
+ dict_: _InstanceDict,
+ user_data: Optional[_AdaptedCollectionProtocol] = ...,
+ passive: PassiveFlag = ...,
) -> Union[
Literal[LoaderCallableStatus.PASSIVE_NO_RESULT], CollectionAdapter
]:
@@ -1530,12 +1547,25 @@ class HasCollectionAdapter:
state: InstanceState[Any],
dict_: _InstanceDict,
user_data: Optional[_AdaptedCollectionProtocol] = None,
- passive: PassiveFlag = PASSIVE_OFF,
+ passive: PassiveFlag = PassiveFlag.PASSIVE_OFF,
) -> Union[
Literal[LoaderCallableStatus.PASSIVE_NO_RESULT], CollectionAdapter
]:
raise NotImplementedError()
+ def set(
+ self,
+ state: InstanceState[Any],
+ dict_: _InstanceDict,
+ value: Any,
+ initiator: Optional[AttributeEventToken] = None,
+ passive: PassiveFlag = PassiveFlag.PASSIVE_OFF,
+ check_old: Any = None,
+ pop: bool = False,
+ _adapt: bool = True,
+ ) -> None:
+ raise NotImplementedError()
+
if TYPE_CHECKING:
@@ -1790,7 +1820,9 @@ class CollectionAttributeImpl(HasCollectionAdapter, AttributeImpl):
initiator: Optional[AttributeEventToken],
passive: PassiveFlag = PASSIVE_OFF,
) -> None:
- collection = self.get_collection(state, dict_, passive=passive)
+ collection = self.get_collection(
+ state, dict_, user_data=None, passive=passive
+ )
if collection is PASSIVE_NO_RESULT:
value = self.fire_append_event(state, dict_, value, initiator)
assert (
@@ -1810,7 +1842,9 @@ class CollectionAttributeImpl(HasCollectionAdapter, AttributeImpl):
initiator: Optional[AttributeEventToken],
passive: PassiveFlag = PASSIVE_OFF,
) -> None:
- collection = self.get_collection(state, state.dict, passive=passive)
+ collection = self.get_collection(
+ state, state.dict, user_data=None, passive=passive
+ )
if collection is PASSIVE_NO_RESULT:
self.fire_remove_event(state, dict_, value, initiator)
assert (
@@ -1844,7 +1878,7 @@ class CollectionAttributeImpl(HasCollectionAdapter, AttributeImpl):
dict_: _InstanceDict,
value: Any,
initiator: Optional[AttributeEventToken] = None,
- passive: PassiveFlag = PASSIVE_OFF,
+ passive: PassiveFlag = PassiveFlag.PASSIVE_OFF,
check_old: Any = None,
pop: bool = False,
_adapt: bool = True,
@@ -1963,7 +1997,7 @@ class CollectionAttributeImpl(HasCollectionAdapter, AttributeImpl):
self,
state: InstanceState[Any],
dict_: _InstanceDict,
- user_data: Optional[_AdaptedCollectionProtocol] = None,
+ user_data: Literal[None] = ...,
passive: Literal[PassiveFlag.PASSIVE_OFF] = ...,
) -> CollectionAdapter:
...
@@ -1973,7 +2007,17 @@ class CollectionAttributeImpl(HasCollectionAdapter, AttributeImpl):
self,
state: InstanceState[Any],
dict_: _InstanceDict,
- user_data: Optional[_AdaptedCollectionProtocol] = None,
+ user_data: _AdaptedCollectionProtocol = ...,
+ passive: PassiveFlag = ...,
+ ) -> CollectionAdapter:
+ ...
+
+ @overload
+ def get_collection(
+ self,
+ state: InstanceState[Any],
+ dict_: _InstanceDict,
+ user_data: Optional[_AdaptedCollectionProtocol] = ...,
passive: PassiveFlag = PASSIVE_OFF,
) -> Union[
Literal[LoaderCallableStatus.PASSIVE_NO_RESULT], CollectionAdapter
@@ -2490,7 +2534,7 @@ def register_attribute_impl(
impl_class: Optional[Type[AttributeImpl]] = None,
backref: Optional[str] = None,
**kw: Any,
-) -> InstrumentedAttribute[Any]:
+) -> QueryableAttribute[Any]:
manager = manager_of_class(class_)
if uselist:
@@ -2599,7 +2643,7 @@ def init_state_collection(
attr._dispose_previous_collection(state, old, old_collection, False)
user_data = attr._default_value(state, dict_)
- adapter = attr.get_collection(state, dict_, user_data)
+ adapter: CollectionAdapter = attr.get_collection(state, dict_, user_data)
adapter._reset_empty()
return adapter
diff --git a/lib/sqlalchemy/orm/base.py b/lib/sqlalchemy/orm/base.py
index 0ace9b1cb..63f873fd0 100644
--- a/lib/sqlalchemy/orm/base.py
+++ b/lib/sqlalchemy/orm/base.py
@@ -18,6 +18,7 @@ from typing import Any
from typing import Callable
from typing import Dict
from typing import Generic
+from typing import no_type_check
from typing import Optional
from typing import overload
from typing import Type
@@ -35,17 +36,20 @@ from ..sql.elements import SQLCoreOperations
from ..util import FastIntFlag
from ..util.langhelpers import TypingOnly
from ..util.typing import Literal
-from ..util.typing import Self
if typing.TYPE_CHECKING:
+ from ._typing import _EntityType
from ._typing import _ExternalEntityType
from ._typing import _InternalEntityType
from .attributes import InstrumentedAttribute
from .instrumentation import ClassManager
+ from .interfaces import PropComparator
from .mapper import Mapper
from .state import InstanceState
from .util import AliasedClass
+ from ..sql._typing import _ColumnExpressionArgument
from ..sql._typing import _InfoType
+ from ..sql.elements import ColumnElement
_T = TypeVar("_T", bound=Any)
@@ -191,35 +195,34 @@ EXT_CONTINUE = util.symbol("EXT_CONTINUE")
EXT_STOP = util.symbol("EXT_STOP")
EXT_SKIP = util.symbol("EXT_SKIP")
-ONETOMANY = util.symbol(
- "ONETOMANY",
+
+class RelationshipDirection(Enum):
+ ONETOMANY = 1
"""Indicates the one-to-many direction for a :func:`_orm.relationship`.
This symbol is typically used by the internals but may be exposed within
certain API features.
- """,
-)
+ """
-MANYTOONE = util.symbol(
- "MANYTOONE",
+ MANYTOONE = 2
"""Indicates the many-to-one direction for a :func:`_orm.relationship`.
This symbol is typically used by the internals but may be exposed within
certain API features.
- """,
-)
+ """
-MANYTOMANY = util.symbol(
- "MANYTOMANY",
+ MANYTOMANY = 3
"""Indicates the many-to-many direction for a :func:`_orm.relationship`.
This symbol is typically used by the internals but may be exposed within
certain API features.
- """,
-)
+ """
+
+
+ONETOMANY, MANYTOONE, MANYTOMANY = tuple(RelationshipDirection)
class InspectionAttrExtensionType(Enum):
@@ -249,7 +252,7 @@ _DEFER_FOR_STATE = util.symbol("DEFER_FOR_STATE")
_RAISE_FOR_STATE = util.symbol("RAISE_FOR_STATE")
-_F = TypeVar("_F", bound=Callable)
+_F = TypeVar("_F", bound=Callable[..., Any])
_Self = TypeVar("_Self")
@@ -397,29 +400,34 @@ def _inspect_mapped_object(instance: _T) -> Optional[InstanceState[_T]]:
return None
-def _class_to_mapper(class_or_mapper: Union[Mapper[_T], _T]) -> Mapper[_T]:
+def _class_to_mapper(
+ class_or_mapper: Union[Mapper[_T], Type[_T]]
+) -> Mapper[_T]:
+ # can't get mypy to see an overload for this
insp = inspection.inspect(class_or_mapper, False)
if insp is not None:
- return insp.mapper
+ return insp.mapper # type: ignore
else:
+ assert isinstance(class_or_mapper, type)
raise exc.UnmappedClassError(class_or_mapper)
def _mapper_or_none(
- entity: Union[_T, _InternalEntityType[_T]]
+ entity: Union[Type[_T], _InternalEntityType[_T]]
) -> Optional[Mapper[_T]]:
"""Return the :class:`_orm.Mapper` for the given class or None if the
class is not mapped.
"""
+ # can't get mypy to see an overload for this
insp = inspection.inspect(entity, False)
if insp is not None:
- return insp.mapper
+ return insp.mapper # type: ignore
else:
return None
-def _is_mapped_class(entity):
+def _is_mapped_class(entity: Any) -> bool:
"""Return True if the given object is a mapped class,
:class:`_orm.Mapper`, or :class:`.AliasedClass`.
"""
@@ -432,20 +440,13 @@ def _is_mapped_class(entity):
)
-def _orm_columns(entity):
- insp = inspection.inspect(entity, False)
- if hasattr(insp, "selectable") and hasattr(insp.selectable, "c"):
- return [c for c in insp.selectable.c]
- else:
- return [entity]
-
-
-def _is_aliased_class(entity):
+def _is_aliased_class(entity: Any) -> bool:
insp = inspection.inspect(entity, False)
return insp is not None and getattr(insp, "is_aliased_class", False)
-def _entity_descriptor(entity, key):
+@no_type_check
+def _entity_descriptor(entity: _EntityType[Any], key: str) -> Any:
"""Return a class attribute given an entity and string name.
May return :class:`.InstrumentedAttribute` or user-defined
@@ -651,16 +652,26 @@ class SQLORMOperations(SQLCoreOperations[_T], TypingOnly):
if typing.TYPE_CHECKING:
- def of_type(self, class_):
+ def of_type(self, class_: _EntityType[Any]) -> PropComparator[_T]:
...
- def and_(self, *criteria):
+ def and_(
+ self, *criteria: _ColumnExpressionArgument[bool]
+ ) -> PropComparator[bool]:
...
- def any(self, criterion=None, **kwargs): # noqa: A001
+ def any( # noqa: A001
+ self,
+ criterion: Optional[_ColumnExpressionArgument[bool]] = None,
+ **kwargs: Any,
+ ) -> ColumnElement[bool]:
...
- def has(self, criterion=None, **kwargs):
+ def has(
+ self,
+ criterion: Optional[_ColumnExpressionArgument[bool]] = None,
+ **kwargs: Any,
+ ) -> ColumnElement[bool]:
...
@@ -673,7 +684,9 @@ class ORMDescriptor(Generic[_T], TypingOnly):
if typing.TYPE_CHECKING:
@overload
- def __get__(self: Self, instance: Any, owner: Literal[None]) -> Self:
+ def __get__(
+ self, instance: Any, owner: Literal[None]
+ ) -> ORMDescriptor[_T]:
...
@overload
diff --git a/lib/sqlalchemy/orm/clsregistry.py b/lib/sqlalchemy/orm/clsregistry.py
index 473468c6c..b3fcd29ea 100644
--- a/lib/sqlalchemy/orm/clsregistry.py
+++ b/lib/sqlalchemy/orm/clsregistry.py
@@ -4,7 +4,6 @@
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: ignore-errors
"""Routines to handle the string class registry used by declarative.
@@ -16,7 +15,22 @@ This system allows specification of classes and expressions used in
from __future__ import annotations
import re
+from typing import Any
+from typing import Callable
+from typing import cast
+from typing import Dict
+from typing import Generator
+from typing import Iterable
+from typing import List
+from typing import Mapping
from typing import MutableMapping
+from typing import NoReturn
+from typing import Optional
+from typing import Set
+from typing import Tuple
+from typing import Type
+from typing import TYPE_CHECKING
+from typing import TypeVar
from typing import Union
import weakref
@@ -29,6 +43,14 @@ from .. import exc
from .. import inspection
from .. import util
from ..sql.schema import _get_table_key
+from ..util.typing import CallableReference
+
+if TYPE_CHECKING:
+ from .relationships import Relationship
+ from ..sql.schema import MetaData
+ from ..sql.schema import Table
+
+_T = TypeVar("_T", bound=Any)
_ClsRegistryType = MutableMapping[str, Union[type, "ClsRegistryToken"]]
@@ -36,10 +58,12 @@ _ClsRegistryType = MutableMapping[str, Union[type, "ClsRegistryToken"]]
# the _decl_class_registry, which is usually weak referencing.
# the internal registries here link to classes with weakrefs and remove
# themselves when all references to contained classes are removed.
-_registries = set()
+_registries: Set[ClsRegistryToken] = set()
-def add_class(classname, cls, decl_class_registry):
+def add_class(
+ classname: str, cls: Type[_T], decl_class_registry: _ClsRegistryType
+) -> None:
"""Add a class to the _decl_class_registry associated with the
given declarative class.
@@ -49,13 +73,15 @@ def add_class(classname, cls, decl_class_registry):
existing = decl_class_registry[classname]
if not isinstance(existing, _MultipleClassMarker):
existing = decl_class_registry[classname] = _MultipleClassMarker(
- [cls, existing]
+ [cls, cast("Type[Any]", existing)]
)
else:
decl_class_registry[classname] = cls
try:
- root_module = decl_class_registry["_sa_module_registry"]
+ root_module = cast(
+ _ModuleMarker, decl_class_registry["_sa_module_registry"]
+ )
except KeyError:
decl_class_registry[
"_sa_module_registry"
@@ -79,7 +105,9 @@ def add_class(classname, cls, decl_class_registry):
module.add_class(classname, cls)
-def remove_class(classname, cls, decl_class_registry):
+def remove_class(
+ classname: str, cls: Type[Any], decl_class_registry: _ClsRegistryType
+) -> None:
if classname in decl_class_registry:
existing = decl_class_registry[classname]
if isinstance(existing, _MultipleClassMarker):
@@ -88,7 +116,9 @@ def remove_class(classname, cls, decl_class_registry):
del decl_class_registry[classname]
try:
- root_module = decl_class_registry["_sa_module_registry"]
+ root_module = cast(
+ _ModuleMarker, decl_class_registry["_sa_module_registry"]
+ )
except KeyError:
return
@@ -102,7 +132,11 @@ def remove_class(classname, cls, decl_class_registry):
module.remove_class(classname, cls)
-def _key_is_empty(key, decl_class_registry, test):
+def _key_is_empty(
+ key: str,
+ decl_class_registry: _ClsRegistryType,
+ test: Callable[[Any], bool],
+) -> bool:
"""test if a key is empty of a certain object.
used for unit tests against the registry to see if garbage collection
@@ -124,6 +158,8 @@ def _key_is_empty(key, decl_class_registry, test):
for sub_thing in thing.contents:
if test(sub_thing):
return False
+ else:
+ raise NotImplementedError("unknown codepath")
else:
return not test(thing)
@@ -142,20 +178,27 @@ class _MultipleClassMarker(ClsRegistryToken):
__slots__ = "on_remove", "contents", "__weakref__"
- def __init__(self, classes, on_remove=None):
+ contents: Set[weakref.ref[Type[Any]]]
+ on_remove: CallableReference[Optional[Callable[[], None]]]
+
+ def __init__(
+ self,
+ classes: Iterable[Type[Any]],
+ on_remove: Optional[Callable[[], None]] = None,
+ ):
self.on_remove = on_remove
self.contents = set(
[weakref.ref(item, self._remove_item) for item in classes]
)
_registries.add(self)
- def remove_item(self, cls):
+ def remove_item(self, cls: Type[Any]) -> None:
self._remove_item(weakref.ref(cls))
- def __iter__(self):
+ def __iter__(self) -> Generator[Optional[Type[Any]], None, None]:
return (ref() for ref in self.contents)
- def attempt_get(self, path, key):
+ def attempt_get(self, path: List[str], key: str) -> Type[Any]:
if len(self.contents) > 1:
raise exc.InvalidRequestError(
'Multiple classes found for path "%s" '
@@ -170,14 +213,14 @@ class _MultipleClassMarker(ClsRegistryToken):
raise NameError(key)
return cls
- def _remove_item(self, ref):
+ def _remove_item(self, ref: weakref.ref[Type[Any]]) -> None:
self.contents.discard(ref)
if not self.contents:
_registries.discard(self)
if self.on_remove:
self.on_remove()
- def add_item(self, item):
+ def add_item(self, item: Type[Any]) -> None:
# protect against class registration race condition against
# asynchronous garbage collection calling _remove_item,
# [ticket:3208]
@@ -206,7 +249,12 @@ class _ModuleMarker(ClsRegistryToken):
__slots__ = "parent", "name", "contents", "mod_ns", "path", "__weakref__"
- def __init__(self, name, parent):
+ parent: Optional[_ModuleMarker]
+ contents: Dict[str, Union[_ModuleMarker, _MultipleClassMarker]]
+ mod_ns: _ModNS
+ path: List[str]
+
+ def __init__(self, name: str, parent: Optional[_ModuleMarker]):
self.parent = parent
self.name = name
self.contents = {}
@@ -217,51 +265,53 @@ class _ModuleMarker(ClsRegistryToken):
self.path = []
_registries.add(self)
- def __contains__(self, name):
+ def __contains__(self, name: str) -> bool:
return name in self.contents
- def __getitem__(self, name):
+ def __getitem__(self, name: str) -> ClsRegistryToken:
return self.contents[name]
- def _remove_item(self, name):
+ def _remove_item(self, name: str) -> None:
self.contents.pop(name, None)
if not self.contents and self.parent is not None:
self.parent._remove_item(self.name)
_registries.discard(self)
- def resolve_attr(self, key):
- return getattr(self.mod_ns, key)
+ def resolve_attr(self, key: str) -> Union[_ModNS, Type[Any]]:
+ return self.mod_ns.__getattr__(key)
- def get_module(self, name):
+ def get_module(self, name: str) -> _ModuleMarker:
if name not in self.contents:
marker = _ModuleMarker(name, self)
self.contents[name] = marker
else:
- marker = self.contents[name]
+ marker = cast(_ModuleMarker, self.contents[name])
return marker
- def add_class(self, name, cls):
+ def add_class(self, name: str, cls: Type[Any]) -> None:
if name in self.contents:
- existing = self.contents[name]
+ existing = cast(_MultipleClassMarker, self.contents[name])
existing.add_item(cls)
else:
existing = self.contents[name] = _MultipleClassMarker(
[cls], on_remove=lambda: self._remove_item(name)
)
- def remove_class(self, name, cls):
+ def remove_class(self, name: str, cls: Type[Any]) -> None:
if name in self.contents:
- existing = self.contents[name]
+ existing = cast(_MultipleClassMarker, self.contents[name])
existing.remove_item(cls)
class _ModNS:
__slots__ = ("__parent",)
- def __init__(self, parent):
+ __parent: _ModuleMarker
+
+ def __init__(self, parent: _ModuleMarker):
self.__parent = parent
- def __getattr__(self, key):
+ def __getattr__(self, key: str) -> Union[_ModNS, Type[Any]]:
try:
value = self.__parent.contents[key]
except KeyError:
@@ -282,10 +332,12 @@ class _ModNS:
class _GetColumns:
__slots__ = ("cls",)
- def __init__(self, cls):
+ cls: Type[Any]
+
+ def __init__(self, cls: Type[Any]):
self.cls = cls
- def __getattr__(self, key):
+ def __getattr__(self, key: str) -> Any:
mp = class_mapper(self.cls, configure=False)
if mp:
if key not in mp.all_orm_descriptors:
@@ -296,6 +348,7 @@ class _GetColumns:
desc = mp.all_orm_descriptors[key]
if desc.extension_type is interfaces.NotExtension.NOT_EXTENSION:
+ assert isinstance(desc, attributes.QueryableAttribute)
prop = desc.property
if isinstance(prop, Synonym):
key = prop.name
@@ -316,15 +369,18 @@ inspection._inspects(_GetColumns)(
class _GetTable:
__slots__ = "key", "metadata"
- def __init__(self, key, metadata):
+ key: str
+ metadata: MetaData
+
+ def __init__(self, key: str, metadata: MetaData):
self.key = key
self.metadata = metadata
- def __getattr__(self, key):
+ def __getattr__(self, key: str) -> Table:
return self.metadata.tables[_get_table_key(key, self.key)]
-def _determine_container(key, value):
+def _determine_container(key: str, value: Any) -> _GetColumns:
if isinstance(value, _MultipleClassMarker):
value = value.attempt_get([], key)
return _GetColumns(value)
@@ -341,7 +397,21 @@ class _class_resolver:
"favor_tables",
)
- def __init__(self, cls, prop, fallback, arg, favor_tables=False):
+ cls: Type[Any]
+ prop: Relationship[Any]
+ fallback: Mapping[str, Any]
+ arg: str
+ favor_tables: bool
+ _resolvers: Tuple[Callable[[str], Any], ...]
+
+ def __init__(
+ self,
+ cls: Type[Any],
+ prop: Relationship[Any],
+ fallback: Mapping[str, Any],
+ arg: str,
+ favor_tables: bool = False,
+ ):
self.cls = cls
self.prop = prop
self.arg = arg
@@ -350,11 +420,12 @@ class _class_resolver:
self._resolvers = ()
self.favor_tables = favor_tables
- def _access_cls(self, key):
+ def _access_cls(self, key: str) -> Any:
cls = self.cls
manager = attributes.manager_of_class(cls)
decl_base = manager.registry
+ assert decl_base is not None
decl_class_registry = decl_base._class_registry
metadata = decl_base.metadata
@@ -362,7 +433,7 @@ class _class_resolver:
if key in metadata.tables:
return metadata.tables[key]
elif key in metadata._schemas:
- return _GetTable(key, cls.metadata)
+ return _GetTable(key, getattr(cls, "metadata", metadata))
if key in decl_class_registry:
return _determine_container(key, decl_class_registry[key])
@@ -371,13 +442,14 @@ class _class_resolver:
if key in metadata.tables:
return metadata.tables[key]
elif key in metadata._schemas:
- return _GetTable(key, cls.metadata)
+ return _GetTable(key, getattr(cls, "metadata", metadata))
- if (
- "_sa_module_registry" in decl_class_registry
- and key in decl_class_registry["_sa_module_registry"]
+ if "_sa_module_registry" in decl_class_registry and key in cast(
+ _ModuleMarker, decl_class_registry["_sa_module_registry"]
):
- registry = decl_class_registry["_sa_module_registry"]
+ registry = cast(
+ _ModuleMarker, decl_class_registry["_sa_module_registry"]
+ )
return registry.resolve_attr(key)
elif self._resolvers:
for resolv in self._resolvers:
@@ -387,7 +459,7 @@ class _class_resolver:
return self.fallback[key]
- def _raise_for_name(self, name, err):
+ def _raise_for_name(self, name: str, err: Exception) -> NoReturn:
generic_match = re.match(r"(.+)\[(.+)\]", name)
if generic_match:
@@ -409,7 +481,7 @@ class _class_resolver:
% (self.prop.parent, self.arg, name, self.cls)
) from err
- def _resolve_name(self):
+ def _resolve_name(self) -> Union[Table, Type[Any], _ModNS]:
name = self.arg
d = self._dict
rval = None
@@ -427,9 +499,11 @@ class _class_resolver:
if isinstance(rval, _GetColumns):
return rval.cls
else:
+ if TYPE_CHECKING:
+ assert isinstance(rval, (type, Table, _ModNS))
return rval
- def __call__(self):
+ def __call__(self) -> Any:
try:
x = eval(self.arg, globals(), self._dict)
@@ -441,10 +515,15 @@ class _class_resolver:
self._raise_for_name(n.args[0], n)
-_fallback_dict = None
+_fallback_dict: Mapping[str, Any] = None # type: ignore
-def _resolver(cls, prop):
+def _resolver(
+ cls: Type[Any], prop: Relationship[Any]
+) -> Tuple[
+ Callable[[str], Callable[[], Union[Type[Any], Table, _ModNS]]],
+ Callable[[str, bool], _class_resolver],
+]:
global _fallback_dict
@@ -456,12 +535,14 @@ def _resolver(cls, prop):
{"foreign": foreign, "remote": remote}
)
- def resolve_arg(arg, favor_tables=False):
+ def resolve_arg(arg: str, favor_tables: bool = False) -> _class_resolver:
return _class_resolver(
cls, prop, _fallback_dict, arg, favor_tables=favor_tables
)
- def resolve_name(arg):
+ def resolve_name(
+ arg: str,
+ ) -> Callable[[], Union[Type[Any], Table, _ModNS]]:
return _class_resolver(cls, prop, _fallback_dict, arg)._resolve_name
return resolve_name, resolve_arg
diff --git a/lib/sqlalchemy/orm/collections.py b/lib/sqlalchemy/orm/collections.py
index da0da0fcf..78fe89d05 100644
--- a/lib/sqlalchemy/orm/collections.py
+++ b/lib/sqlalchemy/orm/collections.py
@@ -115,6 +115,7 @@ from typing import Collection
from typing import Dict
from typing import Iterable
from typing import List
+from typing import NoReturn
from typing import Optional
from typing import Set
from typing import Tuple
@@ -130,6 +131,7 @@ from ..util.compat import inspect_getfullargspec
from ..util.typing import Protocol
if typing.TYPE_CHECKING:
+ from .attributes import AttributeEventToken
from .attributes import CollectionAttributeImpl
from .mapped_collection import attribute_mapped_collection
from .mapped_collection import column_mapped_collection
@@ -500,7 +502,7 @@ class CollectionAdapter:
self.invalidated = False
self.empty = False
- def _warn_invalidated(self):
+ def _warn_invalidated(self) -> None:
util.warn("This collection has been invalidated.")
@property
@@ -509,7 +511,7 @@ class CollectionAdapter:
return self._data()
@property
- def _referenced_by_owner(self):
+ def _referenced_by_owner(self) -> bool:
"""return True if the owner state still refers to this collection.
This will return False within a bulk replace operation,
@@ -521,7 +523,9 @@ class CollectionAdapter:
def bulk_appender(self):
return self._data()._sa_appender
- def append_with_event(self, item, initiator=None):
+ def append_with_event(
+ self, item: Any, initiator: Optional[AttributeEventToken] = None
+ ) -> None:
"""Add an entity to the collection, firing mutation events."""
self._data()._sa_appender(item, _sa_initiator=initiator)
@@ -533,7 +537,7 @@ class CollectionAdapter:
self.empty = True
self.owner_state._empty_collections[self._key] = user_data
- def _reset_empty(self):
+ def _reset_empty(self) -> None:
assert (
self.empty
), "This collection adapter is not in the 'empty' state"
@@ -542,20 +546,20 @@ class CollectionAdapter:
self._key
] = self.owner_state._empty_collections.pop(self._key)
- def _refuse_empty(self):
+ def _refuse_empty(self) -> NoReturn:
raise sa_exc.InvalidRequestError(
"This is a special 'empty' collection which cannot accommodate "
"internal mutation operations"
)
- def append_without_event(self, item):
+ def append_without_event(self, item: Any) -> None:
"""Add or restore an entity to the collection, firing no events."""
if self.empty:
self._refuse_empty()
self._data()._sa_appender(item, _sa_initiator=False)
- def append_multiple_without_event(self, items):
+ def append_multiple_without_event(self, items: Iterable[Any]) -> None:
"""Add or restore an entity to the collection, firing no events."""
if self.empty:
self._refuse_empty()
@@ -566,17 +570,21 @@ class CollectionAdapter:
def bulk_remover(self):
return self._data()._sa_remover
- def remove_with_event(self, item, initiator=None):
+ def remove_with_event(
+ self, item: Any, initiator: Optional[AttributeEventToken] = None
+ ) -> None:
"""Remove an entity from the collection, firing mutation events."""
self._data()._sa_remover(item, _sa_initiator=initiator)
- def remove_without_event(self, item):
+ def remove_without_event(self, item: Any) -> None:
"""Remove an entity from the collection, firing no events."""
if self.empty:
self._refuse_empty()
self._data()._sa_remover(item, _sa_initiator=False)
- def clear_with_event(self, initiator=None):
+ def clear_with_event(
+ self, initiator: Optional[AttributeEventToken] = None
+ ) -> None:
"""Empty the collection, firing a mutation event for each entity."""
if self.empty:
@@ -585,7 +593,7 @@ class CollectionAdapter:
for item in list(self):
remover(item, _sa_initiator=initiator)
- def clear_without_event(self):
+ def clear_without_event(self) -> None:
"""Empty the collection, firing no events."""
if self.empty:
diff --git a/lib/sqlalchemy/orm/context.py b/lib/sqlalchemy/orm/context.py
index 28fea2f9b..58556bb58 100644
--- a/lib/sqlalchemy/orm/context.py
+++ b/lib/sqlalchemy/orm/context.py
@@ -12,6 +12,7 @@ import itertools
from typing import Any
from typing import cast
from typing import Dict
+from typing import Iterable
from typing import List
from typing import Optional
from typing import Set
@@ -43,6 +44,7 @@ from ..sql import expression
from ..sql import roles
from ..sql import util as sql_util
from ..sql import visitors
+from ..sql._typing import _TP
from ..sql._typing import is_dml
from ..sql._typing import is_insert_update
from ..sql._typing import is_select_base
@@ -55,22 +57,32 @@ from ..sql.base import Options
from ..sql.dml import UpdateBase
from ..sql.elements import GroupedElement
from ..sql.elements import TextClause
-from ..sql.selectable import ExecutableReturnsRows
from ..sql.selectable import LABEL_STYLE_DISAMBIGUATE_ONLY
from ..sql.selectable import LABEL_STYLE_NONE
from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
from ..sql.selectable import Select
from ..sql.selectable import SelectLabelStyle
from ..sql.selectable import SelectState
+from ..sql.selectable import TypedReturnsRows
from ..sql.visitors import InternalTraversal
if TYPE_CHECKING:
from ._typing import _InternalEntityType
+ from .loading import PostLoad
from .mapper import Mapper
from .query import Query
+ from .session import _BindArguments
+ from .session import Session
+ from ..engine.interfaces import _CoreSingleExecuteParams
+ from ..engine.interfaces import _ExecuteOptionsParameter
+ from ..sql._typing import _ColumnsClauseArgument
+ from ..sql.compiler import SQLCompiler
from ..sql.dml import _DMLTableElement
from ..sql.elements import ColumnElement
+ from ..sql.selectable import _JoinTargetElement
from ..sql.selectable import _LabelConventionCallable
+ from ..sql.selectable import _SetupJoinsElement
+ from ..sql.selectable import ExecutableReturnsRows
from ..sql.selectable import SelectBase
from ..sql.type_api import TypeEngine
@@ -80,7 +92,7 @@ _path_registry = PathRegistry.root
_EMPTY_DICT = util.immutabledict()
-LABEL_STYLE_LEGACY_ORM = util.symbol("LABEL_STYLE_LEGACY_ORM")
+LABEL_STYLE_LEGACY_ORM = SelectLabelStyle.LABEL_STYLE_LEGACY_ORM
class QueryContext:
@@ -109,6 +121,10 @@ class QueryContext:
"loaders_require_uniquing",
)
+ runid: int
+ post_load_paths: Dict[PathRegistry, PostLoad]
+ compile_state: ORMCompileState
+
class default_load_options(Options):
_only_return_tuples = False
_populate_existing = False
@@ -123,13 +139,16 @@ class QueryContext:
def __init__(
self,
- compile_state,
- statement,
- params,
- session,
- load_options,
- execution_options=None,
- bind_arguments=None,
+ compile_state: CompileState,
+ statement: Union[Select[Any], FromStatement[Any]],
+ params: _CoreSingleExecuteParams,
+ session: Session,
+ load_options: Union[
+ Type[QueryContext.default_load_options],
+ QueryContext.default_load_options,
+ ],
+ execution_options: Optional[_ExecuteOptionsParameter] = None,
+ bind_arguments: Optional[_BindArguments] = None,
):
self.load_options = load_options
self.execution_options = execution_options or _EMPTY_DICT
@@ -220,8 +239,8 @@ class ORMCompileState(CompileState):
attributes: Dict[Any, Any]
global_attributes: Dict[Any, Any]
- statement: Union[Select, FromStatement]
- select_statement: Union[Select, FromStatement]
+ statement: Union[Select[Any], FromStatement[Any]]
+ select_statement: Union[Select[Any], FromStatement[Any]]
_entities: List[_QueryEntity]
_polymorphic_adapters: Dict[_InternalEntityType, ORMAdapter]
compile_options: Union[
@@ -238,6 +257,7 @@ class ORMCompileState(CompileState):
Tuple[Any, ...]
]
current_path: PathRegistry = _path_registry
+ _has_mapper_entities = False
def __init__(self, *arg, **kw):
raise NotImplementedError()
@@ -266,7 +286,12 @@ class ORMCompileState(CompileState):
return SelectState._column_naming_convention(label_style)
@classmethod
- def create_for_statement(cls, statement_container, compiler, **kw):
+ def create_for_statement(
+ cls,
+ statement: Union[Select, FromStatement],
+ compiler: Optional[SQLCompiler],
+ **kw: Any,
+ ) -> ORMCompileState:
"""Create a context for a statement given a :class:`.Compiler`.
This method is always invoked in the context of SQLCompiler.process().
@@ -443,7 +468,12 @@ class ORMFromStatementCompileState(ORMCompileState):
eager_joins = _EMPTY_DICT
@classmethod
- def create_for_statement(cls, statement_container, compiler, **kw):
+ def create_for_statement(
+ cls,
+ statement_container: Union[Select, FromStatement],
+ compiler: Optional[SQLCompiler],
+ **kw: Any,
+ ) -> ORMCompileState:
if compiler is not None:
toplevel = not compiler.stack
@@ -577,7 +607,7 @@ class ORMFromStatementCompileState(ORMCompileState):
return None
-class FromStatement(GroupedElement, Generative, ExecutableReturnsRows):
+class FromStatement(GroupedElement, Generative, TypedReturnsRows[_TP]):
"""Core construct that represents a load of ORM objects from various
:class:`.ReturnsRows` and other classes including:
@@ -595,7 +625,7 @@ class FromStatement(GroupedElement, Generative, ExecutableReturnsRows):
_for_update_arg = None
- element: Union[SelectBase, TextClause, UpdateBase]
+ element: Union[ExecutableReturnsRows, TextClause]
_traverse_internals = [
("_raw_columns", InternalTraversal.dp_clauseelement_list),
@@ -606,7 +636,11 @@ class FromStatement(GroupedElement, Generative, ExecutableReturnsRows):
("_compile_options", InternalTraversal.dp_has_cache_key)
]
- def __init__(self, entities, element):
+ def __init__(
+ self,
+ entities: Iterable[_ColumnsClauseArgument[Any]],
+ element: Union[ExecutableReturnsRows, TextClause],
+ ):
self._raw_columns = [
coercions.expect(
roles.ColumnsClauseRole,
@@ -701,7 +735,12 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
_having_criteria = ()
@classmethod
- def create_for_statement(cls, statement, compiler, **kw):
+ def create_for_statement(
+ cls,
+ statement: Union[Select, FromStatement],
+ compiler: Optional[SQLCompiler],
+ **kw: Any,
+ ) -> ORMCompileState:
"""compiler hook, we arrive here from compiler.visit_select() only."""
self = cls.__new__(cls)
@@ -1073,9 +1112,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
)
@classmethod
- @util.preload_module("sqlalchemy.orm.query")
def from_statement(cls, statement, from_statement):
- query = util.preloaded.orm_query
from_statement = coercions.expect(
roles.ReturnsRowsRole,
@@ -1083,7 +1120,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
apply_propagate_attrs=statement,
)
- stmt = query.FromStatement(statement._raw_columns, from_statement)
+ stmt = FromStatement(statement._raw_columns, from_statement)
stmt.__dict__.update(
_with_options=statement._with_options,
@@ -2114,7 +2151,9 @@ def _column_descriptions(
return d
-def _legacy_filter_by_entity_zero(query_or_augmented_select):
+def _legacy_filter_by_entity_zero(
+ query_or_augmented_select: Union[Query[Any], Select[Any]]
+) -> Optional[_InternalEntityType[Any]]:
self = query_or_augmented_select
if self._setup_joins:
_last_joined_entity = self._last_joined_entity
@@ -2127,7 +2166,9 @@ def _legacy_filter_by_entity_zero(query_or_augmented_select):
return _entity_from_pre_ent_zero(self)
-def _entity_from_pre_ent_zero(query_or_augmented_select):
+def _entity_from_pre_ent_zero(
+ query_or_augmented_select: Union[Query[Any], Select[Any]]
+) -> Optional[_InternalEntityType[Any]]:
self = query_or_augmented_select
if not self._raw_columns:
return None
@@ -2144,13 +2185,19 @@ def _entity_from_pre_ent_zero(query_or_augmented_select):
return ent
-def _determine_last_joined_entity(setup_joins, entity_zero=None):
+def _determine_last_joined_entity(
+ setup_joins: Tuple[_SetupJoinsElement, ...],
+ entity_zero: Optional[_InternalEntityType[Any]] = None,
+) -> Optional[Union[_InternalEntityType[Any], _JoinTargetElement]]:
if not setup_joins:
return None
(target, onclause, from_, flags) = setup_joins[-1]
- if isinstance(target, interfaces.PropComparator):
+ if isinstance(
+ target,
+ attributes.QueryableAttribute,
+ ):
return target.entity
else:
return target
@@ -2161,6 +2208,8 @@ class _QueryEntity:
__slots__ = ()
+ supports_single_entity: bool
+
_non_hashable_value = False
_null_column_type = False
use_id_for_hash = False
@@ -2173,6 +2222,9 @@ class _QueryEntity:
def setup_compile_state(self, compile_state: ORMCompileState) -> None:
raise NotImplementedError()
+ def row_processor(self, context, result):
+ raise NotImplementedError()
+
@classmethod
def to_compile_state(
cls, compile_state, entities, entities_collection, is_current_entities
diff --git a/lib/sqlalchemy/orm/decl_api.py b/lib/sqlalchemy/orm/decl_api.py
index fbe35f92a..1c343b04c 100644
--- a/lib/sqlalchemy/orm/decl_api.py
+++ b/lib/sqlalchemy/orm/decl_api.py
@@ -4,6 +4,7 @@
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
"""Public API functions and helpers for declarative."""
from __future__ import annotations
@@ -14,16 +15,21 @@ import typing
from typing import Any
from typing import Callable
from typing import ClassVar
+from typing import Dict
+from typing import FrozenSet
+from typing import Iterator
from typing import Mapping
from typing import Optional
+from typing import overload
+from typing import Set
from typing import Type
+from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
import weakref
from . import attributes
from . import clsregistry
-from . import exc as orm_exc
from . import instrumentation
from . import interfaces
from . import mapperlib
@@ -38,24 +44,40 @@ from .decl_base import _del_attribute
from .decl_base import _mapper
from .descriptor_props import Synonym as _orm_synonym
from .mapper import Mapper
+from .state import InstanceState
from .. import exc
from .. import inspection
from .. import util
from ..sql.elements import SQLCoreOperations
from ..sql.schema import MetaData
from ..sql.selectable import FromClause
-from ..sql.type_api import TypeEngine
from ..util import hybridmethod
from ..util import hybridproperty
from ..util import typing as compat_typing
+from ..util.typing import CallableReference
+from ..util.typing import Literal
+if TYPE_CHECKING:
+ from ._typing import _O
+ from ._typing import _RegistryType
+ from .descriptor_props import Synonym
+ from .instrumentation import ClassManager
+ from .interfaces import MapperProperty
+ from ..sql._typing import _TypeEngineArgument
_T = TypeVar("_T", bound=Any)
-_TypeAnnotationMapType = Mapping[Type, Union[Type[TypeEngine], TypeEngine]]
+# it's not clear how to have Annotated, Union objects etc. as keys here
+# from a typing perspective so just leave it open ended for now
+_TypeAnnotationMapType = Mapping[Any, "_TypeEngineArgument[Any]"]
+_MutableTypeAnnotationMapType = Dict[Any, "_TypeEngineArgument[Any]"]
+
+_DeclaredAttrDecorated = Callable[
+ ..., Union[Mapped[_T], SQLCoreOperations[_T]]
+]
-def has_inherited_table(cls):
+def has_inherited_table(cls: Type[_O]) -> bool:
"""Given a class, return True if any of the classes it inherits from has a
mapped table, otherwise return False.
@@ -75,13 +97,13 @@ def has_inherited_table(cls):
class _DynamicAttributesType(type):
- def __setattr__(cls, key, value):
+ def __setattr__(cls, key: str, value: Any) -> None:
if "__mapper__" in cls.__dict__:
_add_attribute(cls, key, value)
else:
type.__setattr__(cls, key, value)
- def __delattr__(cls, key):
+ def __delattr__(cls, key: str) -> None:
if "__mapper__" in cls.__dict__:
_del_attribute(cls, key)
else:
@@ -89,7 +111,7 @@ class _DynamicAttributesType(type):
class DeclarativeAttributeIntercept(
- _DynamicAttributesType, inspection.Inspectable["Mapper[Any]"]
+ _DynamicAttributesType, inspection.Inspectable[Mapper[Any]]
):
"""Metaclass that may be used in conjunction with the
:class:`_orm.DeclarativeBase` class to support addition of class
@@ -99,10 +121,10 @@ class DeclarativeAttributeIntercept(
class DeclarativeMeta(
- _DynamicAttributesType, inspection.Inspectable["Mapper[Any]"]
+ _DynamicAttributesType, inspection.Inspectable[Mapper[Any]]
):
metadata: MetaData
- registry: "RegistryType"
+ registry: RegistryType
def __init__(
cls, classname: Any, bases: Any, dict_: Any, **kw: Any
@@ -130,7 +152,9 @@ class DeclarativeMeta(
type.__init__(cls, classname, bases, dict_)
-def synonym_for(name, map_column=False):
+def synonym_for(
+ name: str, map_column: bool = False
+) -> Callable[[Callable[..., Any]], Synonym[Any]]:
"""Decorator that produces an :func:`_orm.synonym`
attribute in conjunction with a Python descriptor.
@@ -164,7 +188,7 @@ def synonym_for(name, map_column=False):
"""
- def decorate(fn):
+ def decorate(fn: Callable[..., Any]) -> Synonym[Any]:
return _orm_synonym(name, map_column=map_column, descriptor=fn)
return decorate
@@ -255,16 +279,16 @@ class declared_attr(interfaces._MappedAttribute[_T]):
if typing.TYPE_CHECKING:
- def __set__(self, instance, value):
+ def __set__(self, instance: Any, value: Any) -> None:
...
- def __delete__(self, instance: Any):
+ def __delete__(self, instance: Any) -> None:
...
def __init__(
self,
- fn: Callable[..., Union[Mapped[_T], SQLCoreOperations[_T]]],
- cascading=False,
+ fn: _DeclaredAttrDecorated[_T],
+ cascading: bool = False,
):
self.fget = fn
self._cascading = cascading
@@ -273,10 +297,28 @@ class declared_attr(interfaces._MappedAttribute[_T]):
def _collect_return_annotation(self) -> Optional[Type[Any]]:
return util.get_annotations(self.fget).get("return")
- def __get__(self, instance, owner) -> InstrumentedAttribute[_T]:
+ # this is the Mapped[] API where at class descriptor get time we want
+ # the type checker to see InstrumentedAttribute[_T]. However the
+ # callable function prior to mapping in fact calls the given
+ # declarative function that does not return InstrumentedAttribute
+ @overload
+ def __get__(self, instance: None, owner: Any) -> InstrumentedAttribute[_T]:
+ ...
+
+ @overload
+ def __get__(self, instance: object, owner: Any) -> _T:
+ ...
+
+ def __get__(
+ self, instance: Optional[object], owner: Any
+ ) -> Union[InstrumentedAttribute[_T], _T]:
# the declared_attr needs to make use of a cache that exists
# for the span of the declarative scan_attributes() phase.
# to achieve this we look at the class manager that's configured.
+
+ # note this method should not be called outside of the declarative
+ # setup phase
+
cls = owner
manager = attributes.opt_manager_of_class(cls)
if manager is None:
@@ -287,30 +329,33 @@ class declared_attr(interfaces._MappedAttribute[_T]):
"Unmanaged access of declarative attribute %s from "
"non-mapped class %s" % (self.fget.__name__, cls.__name__)
)
- return self.fget(cls)
+ return self.fget(cls) # type: ignore
elif manager.is_mapped:
# the class is mapped, which means we're outside of the declarative
# scan setup, just run the function.
- return self.fget(cls)
+ return self.fget(cls) # type: ignore
# here, we are inside of the declarative scan. use the registry
# that is tracking the values of these attributes.
declarative_scan = manager.declarative_scan()
+
+ # assert that we are in fact in the declarative scan
assert declarative_scan is not None
+
reg = declarative_scan.declared_attr_reg
if self in reg:
- return reg[self]
+ return reg[self] # type: ignore
else:
reg[self] = obj = self.fget(cls)
- return obj
+ return obj # type: ignore
@hybridmethod
- def _stateful(cls, **kw):
+ def _stateful(cls, **kw: Any) -> _stateful_declared_attr[_T]:
return _stateful_declared_attr(**kw)
@hybridproperty
- def cascading(cls):
+ def cascading(cls) -> _stateful_declared_attr[_T]:
"""Mark a :class:`.declared_attr` as cascading.
This is a special-use modifier which indicates that a column
@@ -372,20 +417,23 @@ class declared_attr(interfaces._MappedAttribute[_T]):
return cls._stateful(cascading=True)
-class _stateful_declared_attr(declared_attr):
- def __init__(self, **kw):
+class _stateful_declared_attr(declared_attr[_T]):
+ kw: Dict[str, Any]
+
+ def __init__(self, **kw: Any):
self.kw = kw
- def _stateful(self, **kw):
+ @hybridmethod
+ def _stateful(self, **kw: Any) -> _stateful_declared_attr[_T]:
new_kw = self.kw.copy()
new_kw.update(kw)
return _stateful_declared_attr(**new_kw)
- def __call__(self, fn):
+ def __call__(self, fn: _DeclaredAttrDecorated[_T]) -> declared_attr[_T]:
return declared_attr(fn, **self.kw)
-def declarative_mixin(cls):
+def declarative_mixin(cls: Type[_T]) -> Type[_T]:
"""Mark a class as providing the feature of "declarative mixin".
E.g.::
@@ -427,9 +475,9 @@ def declarative_mixin(cls):
return cls
-def _setup_declarative_base(cls):
+def _setup_declarative_base(cls: Type[Any]) -> None:
if "metadata" in cls.__dict__:
- metadata = cls.metadata
+ metadata = cls.metadata # type: ignore
else:
metadata = None
@@ -457,15 +505,15 @@ def _setup_declarative_base(cls):
reg = registry(
metadata=metadata, type_annotation_map=type_annotation_map
)
- cls.registry = reg
+ cls.registry = reg # type: ignore
- cls._sa_registry = reg
+ cls._sa_registry = reg # type: ignore
if "metadata" not in cls.__dict__:
- cls.metadata = cls.registry.metadata
+ cls.metadata = cls.registry.metadata # type: ignore
-class DeclarativeBaseNoMeta(inspection.Inspectable["Mapper"]):
+class DeclarativeBaseNoMeta(inspection.Inspectable[Mapper[Any]]):
"""Same as :class:`_orm.DeclarativeBase`, but does not use a metaclass
to intercept new attributes.
@@ -477,10 +525,10 @@ class DeclarativeBaseNoMeta(inspection.Inspectable["Mapper"]):
"""
- registry: ClassVar["registry"]
- _sa_registry: ClassVar["registry"]
+ registry: ClassVar[_RegistryType]
+ _sa_registry: ClassVar[_RegistryType]
metadata: ClassVar[MetaData]
- __mapper__: ClassVar[Mapper]
+ __mapper__: ClassVar[Mapper[Any]]
__table__: Optional[FromClause]
if typing.TYPE_CHECKING:
@@ -496,7 +544,7 @@ class DeclarativeBaseNoMeta(inspection.Inspectable["Mapper"]):
class DeclarativeBase(
- inspection.Inspectable["InstanceState"],
+ inspection.Inspectable[InstanceState[Any]],
metaclass=DeclarativeAttributeIntercept,
):
"""Base class used for declarative class definitions.
@@ -557,10 +605,10 @@ class DeclarativeBase(
"""
- registry: ClassVar["registry"]
- _sa_registry: ClassVar["registry"]
+ registry: ClassVar[_RegistryType]
+ _sa_registry: ClassVar[_RegistryType]
metadata: ClassVar[MetaData]
- __mapper__: ClassVar[Mapper]
+ __mapper__: ClassVar[Mapper[Any]]
__table__: Optional[FromClause]
if typing.TYPE_CHECKING:
@@ -572,10 +620,12 @@ class DeclarativeBase(
if DeclarativeBase in cls.__bases__:
_setup_declarative_base(cls)
else:
- cls._sa_registry.map_declaratively(cls)
+ _as_declarative(cls._sa_registry, cls, cls.__dict__)
-def add_mapped_attribute(target, key, attr):
+def add_mapped_attribute(
+ target: Type[_O], key: str, attr: MapperProperty[Any]
+) -> None:
"""Add a new mapped attribute to an ORM mapped class.
E.g.::
@@ -593,14 +643,15 @@ def add_mapped_attribute(target, key, attr):
def declarative_base(
+ *,
metadata: Optional[MetaData] = None,
- mapper=None,
- cls=object,
- name="Base",
+ mapper: Optional[Callable[..., Mapper[Any]]] = None,
+ cls: Type[Any] = object,
+ name: str = "Base",
class_registry: Optional[clsregistry._ClsRegistryType] = None,
type_annotation_map: Optional[_TypeAnnotationMapType] = None,
constructor: Callable[..., None] = _declarative_constructor,
- metaclass=DeclarativeMeta,
+ metaclass: Type[Any] = DeclarativeMeta,
) -> Any:
r"""Construct a base class for declarative class definitions.
@@ -736,8 +787,19 @@ class registry:
"""
+ _class_registry: clsregistry._ClsRegistryType
+ _managers: weakref.WeakKeyDictionary[ClassManager[Any], Literal[True]]
+ _non_primary_mappers: weakref.WeakKeyDictionary[Mapper[Any], Literal[True]]
+ metadata: MetaData
+ constructor: CallableReference[Callable[..., None]]
+ type_annotation_map: _MutableTypeAnnotationMapType
+ _dependents: Set[_RegistryType]
+ _dependencies: Set[_RegistryType]
+ _new_mappers: bool
+
def __init__(
self,
+ *,
metadata: Optional[MetaData] = None,
class_registry: Optional[clsregistry._ClsRegistryType] = None,
type_annotation_map: Optional[_TypeAnnotationMapType] = None,
@@ -799,9 +861,7 @@ class registry:
def update_type_annotation_map(
self,
- type_annotation_map: Mapping[
- Type, Union[Type[TypeEngine], TypeEngine]
- ],
+ type_annotation_map: _TypeAnnotationMapType,
) -> None:
"""update the :paramref:`_orm.registry.type_annotation_map` with new
values."""
@@ -817,20 +877,20 @@ class registry:
)
@property
- def mappers(self):
+ def mappers(self) -> FrozenSet[Mapper[Any]]:
"""read only collection of all :class:`_orm.Mapper` objects."""
return frozenset(manager.mapper for manager in self._managers).union(
self._non_primary_mappers
)
- def _set_depends_on(self, registry):
+ def _set_depends_on(self, registry: RegistryType) -> None:
if registry is self:
return
registry._dependents.add(self)
self._dependencies.add(registry)
- def _flag_new_mapper(self, mapper):
+ def _flag_new_mapper(self, mapper: Mapper[Any]) -> None:
mapper._ready_for_configure = True
if self._new_mappers:
return
@@ -839,7 +899,9 @@ class registry:
reg._new_mappers = True
@classmethod
- def _recurse_with_dependents(cls, registries):
+ def _recurse_with_dependents(
+ cls, registries: Set[RegistryType]
+ ) -> Iterator[RegistryType]:
todo = registries
done = set()
while todo:
@@ -856,7 +918,9 @@ class registry:
todo.update(reg._dependents.difference(done))
@classmethod
- def _recurse_with_dependencies(cls, registries):
+ def _recurse_with_dependencies(
+ cls, registries: Set[RegistryType]
+ ) -> Iterator[RegistryType]:
todo = registries
done = set()
while todo:
@@ -873,7 +937,7 @@ class registry:
# them before
todo.update(reg._dependencies.difference(done))
- def _mappers_to_configure(self):
+ def _mappers_to_configure(self) -> Iterator[Mapper[Any]]:
return itertools.chain(
(
manager.mapper
@@ -889,13 +953,13 @@ class registry:
),
)
- def _add_non_primary_mapper(self, np_mapper):
+ def _add_non_primary_mapper(self, np_mapper: Mapper[Any]) -> None:
self._non_primary_mappers[np_mapper] = True
- def _dispose_cls(self, cls):
+ def _dispose_cls(self, cls: Type[_O]) -> None:
clsregistry.remove_class(cls.__name__, cls, self._class_registry)
- def _add_manager(self, manager):
+ def _add_manager(self, manager: ClassManager[Any]) -> None:
self._managers[manager] = True
if manager.is_mapped:
raise exc.ArgumentError(
@@ -905,7 +969,7 @@ class registry:
assert manager.registry is None
manager.registry = self
- def configure(self, cascade=False):
+ def configure(self, cascade: bool = False) -> None:
"""Configure all as-yet unconfigured mappers in this
:class:`_orm.registry`.
@@ -946,7 +1010,7 @@ class registry:
"""
mapperlib._configure_registries({self}, cascade=cascade)
- def dispose(self, cascade=False):
+ def dispose(self, cascade: bool = False) -> None:
"""Dispose of all mappers in this :class:`_orm.registry`.
After invocation, all the classes that were mapped within this registry
@@ -972,7 +1036,7 @@ class registry:
mapperlib._dispose_registries({self}, cascade=cascade)
- def _dispose_manager_and_mapper(self, manager):
+ def _dispose_manager_and_mapper(self, manager: ClassManager[Any]) -> None:
if "mapper" in manager.__dict__:
mapper = manager.mapper
@@ -984,11 +1048,11 @@ class registry:
def generate_base(
self,
- mapper=None,
- cls=object,
- name="Base",
- metaclass=DeclarativeMeta,
- ):
+ mapper: Optional[Callable[..., Mapper[Any]]] = None,
+ cls: Type[Any] = object,
+ name: str = "Base",
+ metaclass: Type[Any] = DeclarativeMeta,
+ ) -> Any:
"""Generate a declarative base class.
Classes that inherit from the returned class object will be
@@ -1070,7 +1134,7 @@ class registry:
if hasattr(cls, "__class_getitem__"):
- def __class_getitem__(cls, key):
+ def __class_getitem__(cls: Type[_T], key: str) -> Type[_T]:
# allow generic classes in py3.9+
return cls
@@ -1078,7 +1142,7 @@ class registry:
return metaclass(name, bases, class_dict)
- def mapped(self, cls):
+ def mapped(self, cls: Type[_O]) -> Type[_O]:
"""Class decorator that will apply the Declarative mapping process
to a given class.
@@ -1114,7 +1178,7 @@ class registry:
_as_declarative(self, cls, cls.__dict__)
return cls
- def as_declarative_base(self, **kw):
+ def as_declarative_base(self, **kw: Any) -> Callable[[Type[_T]], Type[_T]]:
"""
Class decorator which will invoke
:meth:`_orm.registry.generate_base`
@@ -1142,14 +1206,14 @@ class registry:
"""
- def decorate(cls):
+ def decorate(cls: Type[_T]) -> Type[_T]:
kw["cls"] = cls
kw["name"] = cls.__name__
- return self.generate_base(**kw)
+ return self.generate_base(**kw) # type: ignore
return decorate
- def map_declaratively(self, cls):
+ def map_declaratively(self, cls: Type[_O]) -> Mapper[_O]:
"""Map a class declaratively.
In this form of mapping, the class is scanned for mapping information,
@@ -1194,9 +1258,15 @@ class registry:
:meth:`_orm.registry.map_imperatively`
"""
- return _as_declarative(self, cls, cls.__dict__)
+ _as_declarative(self, cls, cls.__dict__)
+ return cls.__mapper__ # type: ignore
- def map_imperatively(self, class_, local_table=None, **kw):
+ def map_imperatively(
+ self,
+ class_: Type[_O],
+ local_table: Optional[FromClause] = None,
+ **kw: Any,
+ ) -> Mapper[_O]:
r"""Map a class imperatively.
In this form of mapping, the class is not scanned for any mapping
@@ -1251,7 +1321,7 @@ class registry:
RegistryType = registry
-def as_declarative(**kw):
+def as_declarative(**kw: Any) -> Callable[[Type[_T]], Type[_T]]:
"""
Class decorator which will adapt a given class into a
:func:`_orm.declarative_base`.
@@ -1292,14 +1362,9 @@ def as_declarative(**kw):
@inspection._inspects(
DeclarativeMeta, DeclarativeBase, DeclarativeAttributeIntercept
)
-def _inspect_decl_meta(cls: Type[Any]) -> Mapper[Any]:
- mp: Mapper[Any] = _inspect_mapped_class(cls)
+def _inspect_decl_meta(cls: Type[Any]) -> Optional[Mapper[Any]]:
+ mp: Optional[Mapper[Any]] = _inspect_mapped_class(cls)
if mp is None:
if _DeferredMapperConfig.has_cls(cls):
_DeferredMapperConfig.raise_unmapped_for_cls(cls)
- raise orm_exc.UnmappedClassError(
- cls,
- msg="Class %s has a deferred mapping on it. It is not yet "
- "usable as a mapped class." % orm_exc._safe_cls_name(cls),
- )
return mp
diff --git a/lib/sqlalchemy/orm/decl_base.py b/lib/sqlalchemy/orm/decl_base.py
index b1f81cb6b..c3faac36c 100644
--- a/lib/sqlalchemy/orm/decl_base.py
+++ b/lib/sqlalchemy/orm/decl_base.py
@@ -4,16 +4,26 @@
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
"""Internal implementation for declarative."""
from __future__ import annotations
import collections
from typing import Any
+from typing import Callable
+from typing import cast
from typing import Dict
+from typing import Iterable
+from typing import List
+from typing import Mapping
+from typing import NoReturn
+from typing import Optional
from typing import Tuple
from typing import Type
from typing import TYPE_CHECKING
+from typing import TypeVar
+from typing import Union
import weakref
from . import attributes
@@ -21,6 +31,8 @@ from . import clsregistry
from . import exc as orm_exc
from . import instrumentation
from . import mapperlib
+from ._typing import _O
+from ._typing import attr_is_internal_proxy
from .attributes import InstrumentedAttribute
from .attributes import QueryableAttribute
from .base import _is_mapped_class
@@ -32,6 +44,7 @@ from .interfaces import _MappedAttribute
from .interfaces import _MapsColumns
from .interfaces import MapperProperty
from .mapper import Mapper as mapper
+from .mapper import Mapper
from .properties import ColumnProperty
from .properties import MappedColumn
from .util import _is_mapped_annotation
@@ -43,12 +56,41 @@ from ..sql import expression
from ..sql.schema import Column
from ..sql.schema import Table
from ..util import topological
+from ..util.typing import Protocol
if TYPE_CHECKING:
+ from ._typing import _ClassDict
from ._typing import _RegistryType
+ from .decl_api import declared_attr
+ from .instrumentation import ClassManager
+ from ..sql.schema import MetaData
+ from ..sql.selectable import FromClause
+
+_T = TypeVar("_T", bound=Any)
+
+_MapperKwArgs = Mapping[str, Any]
+
+_TableArgsType = Union[Tuple[Any, ...], Dict[str, Any]]
-def _declared_mapping_info(cls):
+class _DeclMappedClassProtocol(Protocol[_O]):
+ metadata: MetaData
+ __mapper__: Mapper[_O]
+ __table__: Table
+ __tablename__: str
+ __mapper_args__: Mapping[str, Any]
+ __table_args__: Optional[_TableArgsType]
+
+ def __declare_first__(self) -> None:
+ pass
+
+ def __declare_last__(self) -> None:
+ pass
+
+
+def _declared_mapping_info(
+ cls: Type[Any],
+) -> Optional[Union[_DeferredMapperConfig, Mapper[Any]]]:
# deferred mapping
if _DeferredMapperConfig.has_cls(cls):
return _DeferredMapperConfig.config_for_cls(cls)
@@ -59,13 +101,15 @@ def _declared_mapping_info(cls):
return None
-def _resolve_for_abstract_or_classical(cls):
+def _resolve_for_abstract_or_classical(cls: Type[Any]) -> Optional[Type[Any]]:
if cls is object:
return None
+ sup: Optional[Type[Any]]
+
if cls.__dict__.get("__abstract__", False):
- for sup in cls.__bases__:
- sup = _resolve_for_abstract_or_classical(sup)
+ for base_ in cls.__bases__:
+ sup = _resolve_for_abstract_or_classical(base_)
if sup is not None:
return sup
else:
@@ -79,7 +123,9 @@ def _resolve_for_abstract_or_classical(cls):
return cls
-def _get_immediate_cls_attr(cls, attrname, strict=False):
+def _get_immediate_cls_attr(
+ cls: Type[Any], attrname: str, strict: bool = False
+) -> Optional[Any]:
"""return an attribute of the class that is either present directly
on the class, e.g. not on a superclass, or is from a superclass but
this superclass is a non-mapped mixin, that is, not a descendant of
@@ -102,7 +148,7 @@ def _get_immediate_cls_attr(cls, attrname, strict=False):
return getattr(cls, attrname)
for base in cls.__mro__[1:]:
- _is_classicial_inherits = _dive_for_cls_manager(base)
+ _is_classicial_inherits = _dive_for_cls_manager(base) is not None
if attrname in base.__dict__ and (
base is cls
@@ -116,33 +162,37 @@ def _get_immediate_cls_attr(cls, attrname, strict=False):
return None
-def _dive_for_cls_manager(cls):
+def _dive_for_cls_manager(cls: Type[_O]) -> Optional[ClassManager[_O]]:
# because the class manager registration is pluggable,
# we need to do the search for every class in the hierarchy,
# rather than just a simple "cls._sa_class_manager"
- # python 2 old style class
- if not hasattr(cls, "__mro__"):
- return None
-
for base in cls.__mro__:
- manager = attributes.opt_manager_of_class(base)
+ manager: Optional[ClassManager[_O]] = attributes.opt_manager_of_class(
+ base
+ )
if manager:
return manager
return None
-def _as_declarative(registry, cls, dict_):
+def _as_declarative(
+ registry: _RegistryType, cls: Type[Any], dict_: _ClassDict
+) -> Optional[_MapperConfig]:
# declarative scans the class for attributes. no table or mapper
# args passed separately.
-
return _MapperConfig.setup_mapping(registry, cls, dict_, None, {})
-def _mapper(registry, cls, table, mapper_kw):
+def _mapper(
+ registry: _RegistryType,
+ cls: Type[_O],
+ table: Optional[FromClause],
+ mapper_kw: _MapperKwArgs,
+) -> Mapper[_O]:
_ImperativeMapperConfig(registry, cls, table, mapper_kw)
- return cls.__mapper__
+ return cast("_DeclMappedClassProtocol[_O]", cls).__mapper__
@util.preload_module("sqlalchemy.orm.decl_api")
@@ -152,7 +202,9 @@ def _is_declarative_props(obj: Any) -> bool:
return isinstance(obj, (declared_attr, util.classproperty))
-def _check_declared_props_nocascade(obj, name, cls):
+def _check_declared_props_nocascade(
+ obj: Any, name: str, cls: Type[_O]
+) -> bool:
if _is_declarative_props(obj):
if getattr(obj, "_cascading", False):
util.warn(
@@ -174,8 +226,20 @@ class _MapperConfig:
"__weakref__",
)
+ cls: Type[Any]
+ classname: str
+ properties: util.OrderedDict[str, MapperProperty[Any]]
+ declared_attr_reg: Dict[declared_attr[Any], Any]
+
@classmethod
- def setup_mapping(cls, registry, cls_, dict_, table, mapper_kw):
+ def setup_mapping(
+ cls,
+ registry: _RegistryType,
+ cls_: Type[_O],
+ dict_: _ClassDict,
+ table: Optional[FromClause],
+ mapper_kw: _MapperKwArgs,
+ ) -> Optional[_MapperConfig]:
manager = attributes.opt_manager_of_class(cls)
if manager and manager.class_ is cls_:
raise exc.InvalidRequestError(
@@ -183,24 +247,26 @@ class _MapperConfig:
)
if cls_.__dict__.get("__abstract__", False):
- return
+ return None
defer_map = _get_immediate_cls_attr(
cls_, "_sa_decl_prepare_nocascade", strict=True
) or hasattr(cls_, "_sa_decl_prepare")
if defer_map:
- cfg_cls = _DeferredMapperConfig
+ return _DeferredMapperConfig(
+ registry, cls_, dict_, table, mapper_kw
+ )
else:
- cfg_cls = _ClassScanMapperConfig
-
- return cfg_cls(registry, cls_, dict_, table, mapper_kw)
+ return _ClassScanMapperConfig(
+ registry, cls_, dict_, table, mapper_kw
+ )
def __init__(
self,
registry: _RegistryType,
cls_: Type[Any],
- mapper_kw: Dict[str, Any],
+ mapper_kw: _MapperKwArgs,
):
self.cls = util.assert_arg_type(cls_, type, "cls_")
self.classname = cls_.__name__
@@ -224,13 +290,16 @@ class _MapperConfig:
"Mapper." % self.cls
)
- def set_cls_attribute(self, attrname, value):
+ def set_cls_attribute(self, attrname: str, value: _T) -> _T:
manager = instrumentation.manager_of_class(self.cls)
manager.install_member(attrname, value)
return value
- def _early_mapping(self, mapper_kw):
+ def map(self, mapper_kw: _MapperKwArgs = ...) -> Mapper[Any]:
+ raise NotImplementedError()
+
+ def _early_mapping(self, mapper_kw: _MapperKwArgs) -> None:
self.map(mapper_kw)
@@ -239,10 +308,10 @@ class _ImperativeMapperConfig(_MapperConfig):
def __init__(
self,
- registry,
- cls_,
- table,
- mapper_kw,
+ registry: _RegistryType,
+ cls_: Type[_O],
+ table: Optional[FromClause],
+ mapper_kw: _MapperKwArgs,
):
super(_ImperativeMapperConfig, self).__init__(
registry, cls_, mapper_kw
@@ -260,7 +329,7 @@ class _ImperativeMapperConfig(_MapperConfig):
self._early_mapping(mapper_kw)
- def map(self, mapper_kw=util.EMPTY_DICT):
+ def map(self, mapper_kw: _MapperKwArgs = util.EMPTY_DICT) -> Mapper[Any]:
mapper_cls = mapper
return self.set_cls_attribute(
@@ -268,7 +337,7 @@ class _ImperativeMapperConfig(_MapperConfig):
mapper_cls(self.cls, self.local_table, **mapper_kw),
)
- def _setup_inheritance(self, mapper_kw):
+ def _setup_inheritance(self, mapper_kw: _MapperKwArgs) -> None:
cls = self.cls
inherits = mapper_kw.get("inherits", None)
@@ -277,8 +346,8 @@ class _ImperativeMapperConfig(_MapperConfig):
# since we search for classical mappings now, search for
# multiple mapped bases as well and raise an error.
inherits_search = []
- for c in cls.__bases__:
- c = _resolve_for_abstract_or_classical(c)
+ for base_ in cls.__bases__:
+ c = _resolve_for_abstract_or_classical(base_)
if c is None:
continue
if _declared_mapping_info(
@@ -318,13 +387,30 @@ class _ClassScanMapperConfig(_MapperConfig):
"inherits",
)
+ registry: _RegistryType
+ clsdict_view: _ClassDict
+ collected_annotations: Dict[str, Tuple[Any, bool]]
+ collected_attributes: Dict[str, Any]
+ local_table: Optional[FromClause]
+ persist_selectable: Optional[FromClause]
+ declared_columns: util.OrderedSet[Column[Any]]
+ column_copies: Dict[
+ Union[MappedColumn[Any], Column[Any]],
+ Union[MappedColumn[Any], Column[Any]],
+ ]
+ tablename: Optional[str]
+ mapper_args: Mapping[str, Any]
+ table_args: Optional[_TableArgsType]
+ mapper_args_fn: Optional[Callable[[], Dict[str, Any]]]
+ inherits: Optional[Type[Any]]
+
def __init__(
self,
- registry,
- cls_,
- dict_,
- table,
- mapper_kw,
+ registry: _RegistryType,
+ cls_: Type[_O],
+ dict_: _ClassDict,
+ table: Optional[FromClause],
+ mapper_kw: _MapperKwArgs,
):
# grab class dict before the instrumentation manager has been added.
@@ -337,7 +423,7 @@ class _ClassScanMapperConfig(_MapperConfig):
self.persist_selectable = None
self.collected_attributes = {}
- self.collected_annotations: Dict[str, Tuple[Any, bool]] = {}
+ self.collected_annotations = {}
self.declared_columns = util.OrderedSet()
self.column_copies = {}
@@ -360,31 +446,37 @@ class _ClassScanMapperConfig(_MapperConfig):
self._early_mapping(mapper_kw)
- def _setup_declared_events(self):
+ def _setup_declared_events(self) -> None:
if _get_immediate_cls_attr(self.cls, "__declare_last__"):
@event.listens_for(mapper, "after_configured")
- def after_configured():
- self.cls.__declare_last__()
+ def after_configured() -> None:
+ cast(
+ "_DeclMappedClassProtocol[Any]", self.cls
+ ).__declare_last__()
if _get_immediate_cls_attr(self.cls, "__declare_first__"):
@event.listens_for(mapper, "before_configured")
- def before_configured():
- self.cls.__declare_first__()
-
- def _cls_attr_override_checker(self, cls):
+ def before_configured() -> None:
+ cast(
+ "_DeclMappedClassProtocol[Any]", self.cls
+ ).__declare_first__()
+
+ def _cls_attr_override_checker(
+ self, cls: Type[_O]
+ ) -> Callable[[str, Any], bool]:
"""Produce a function that checks if a class has overridden an
attribute, taking SQLAlchemy-enabled dataclass fields into account.
"""
sa_dataclass_metadata_key = _get_immediate_cls_attr(
- cls, "__sa_dataclass_metadata_key__", None
+ cls, "__sa_dataclass_metadata_key__"
)
if sa_dataclass_metadata_key is None:
- def attribute_is_overridden(key, obj):
+ def attribute_is_overridden(key: str, obj: Any) -> bool:
return getattr(cls, key) is not obj
else:
@@ -402,7 +494,7 @@ class _ClassScanMapperConfig(_MapperConfig):
absent = object()
- def attribute_is_overridden(key, obj):
+ def attribute_is_overridden(key: str, obj: Any) -> bool:
if _is_declarative_props(obj):
obj = obj.fget
@@ -457,13 +549,15 @@ class _ClassScanMapperConfig(_MapperConfig):
]
)
- def _cls_attr_resolver(self, cls):
+ def _cls_attr_resolver(
+ self, cls: Type[Any]
+ ) -> Callable[[], Iterable[Tuple[str, Any, Any, bool]]]:
"""produce a function to iterate the "attributes" of a class,
adjusting for SQLAlchemy fields embedded in dataclass fields.
"""
- sa_dataclass_metadata_key = _get_immediate_cls_attr(
- cls, "__sa_dataclass_metadata_key__", None
+ sa_dataclass_metadata_key: Optional[str] = _get_immediate_cls_attr(
+ cls, "__sa_dataclass_metadata_key__"
)
cls_annotations = util.get_annotations(cls)
@@ -477,7 +571,9 @@ class _ClassScanMapperConfig(_MapperConfig):
)
if sa_dataclass_metadata_key is None:
- def local_attributes_for_class():
+ def local_attributes_for_class() -> Iterable[
+ Tuple[str, Any, Any, bool]
+ ]:
return (
(
name,
@@ -493,12 +589,16 @@ class _ClassScanMapperConfig(_MapperConfig):
field.name: field for field in util.local_dataclass_fields(cls)
}
- def local_attributes_for_class():
+ fixed_sa_dataclass_metadata_key = sa_dataclass_metadata_key
+
+ def local_attributes_for_class() -> Iterable[
+ Tuple[str, Any, Any, bool]
+ ]:
for name in names:
field = dataclass_fields.get(name, None)
if field and sa_dataclass_metadata_key in field.metadata:
yield field.name, _as_dc_declaredattr(
- field.metadata, sa_dataclass_metadata_key
+ field.metadata, fixed_sa_dataclass_metadata_key
), cls_annotations.get(field.name), True
else:
yield name, cls_vars.get(name), cls_annotations.get(
@@ -507,14 +607,17 @@ class _ClassScanMapperConfig(_MapperConfig):
return local_attributes_for_class
- def _scan_attributes(self):
+ def _scan_attributes(self) -> None:
cls = self.cls
+ cls_as_Decl = cast("_DeclMappedClassProtocol[Any]", cls)
+
clsdict_view = self.clsdict_view
collected_attributes = self.collected_attributes
column_copies = self.column_copies
mapper_args_fn = None
table_args = inherited_table_args = None
+
tablename = None
fixed_table = "__table__" in clsdict_view
@@ -555,21 +658,23 @@ class _ClassScanMapperConfig(_MapperConfig):
# make a copy of it so a class-level dictionary
# is not overwritten when we update column-based
# arguments.
- def mapper_args_fn():
- return dict(cls.__mapper_args__)
+ def _mapper_args_fn() -> Dict[str, Any]:
+ return dict(cls_as_Decl.__mapper_args__)
+
+ mapper_args_fn = _mapper_args_fn
elif name == "__tablename__":
check_decl = _check_declared_props_nocascade(
obj, name, cls
)
if not tablename and (not class_mapped or check_decl):
- tablename = cls.__tablename__
+ tablename = cls_as_Decl.__tablename__
elif name == "__table_args__":
check_decl = _check_declared_props_nocascade(
obj, name, cls
)
if not table_args and (not class_mapped or check_decl):
- table_args = cls.__table_args__
+ table_args = cls_as_Decl.__table_args__
if not isinstance(
table_args, (tuple, dict, type(None))
):
@@ -657,9 +762,10 @@ class _ClassScanMapperConfig(_MapperConfig):
# or similar. note there is no known case that
# produces nested proxies, so we are only
# looking one level deep right now.
+
if (
isinstance(ret, InspectionAttr)
- and ret._is_internal_proxy
+ and attr_is_internal_proxy(ret)
and not isinstance(
ret.original_property, MapperProperty
)
@@ -669,6 +775,7 @@ class _ClassScanMapperConfig(_MapperConfig):
collected_attributes[name] = column_copies[
obj
] = ret
+
if (
isinstance(ret, (Column, MapperProperty))
and ret.doc is None
@@ -737,7 +844,9 @@ class _ClassScanMapperConfig(_MapperConfig):
self.tablename = tablename
self.mapper_args_fn = mapper_args_fn
- def _warn_for_decl_attributes(self, cls, key, c):
+ def _warn_for_decl_attributes(
+ self, cls: Type[Any], key: str, c: Any
+ ) -> None:
if isinstance(c, expression.ColumnClause):
util.warn(
f"Attribute '{key}' on class {cls} appears to "
@@ -746,8 +855,12 @@ class _ClassScanMapperConfig(_MapperConfig):
)
def _produce_column_copies(
- self, attributes_for_class, attribute_is_overridden
- ):
+ self,
+ attributes_for_class: Callable[
+ [], Iterable[Tuple[str, Any, Any, bool]]
+ ],
+ attribute_is_overridden: Callable[[str, Any], bool],
+ ) -> None:
cls = self.cls
dict_ = self.clsdict_view
collected_attributes = self.collected_attributes
@@ -763,7 +876,8 @@ class _ClassScanMapperConfig(_MapperConfig):
continue
elif name not in dict_ and not (
"__table__" in dict_
- and (obj.name or name) in dict_["__table__"].c
+ and (getattr(obj, "name", None) or name)
+ in dict_["__table__"].c
):
if obj.foreign_keys:
for fk in obj.foreign_keys:
@@ -786,7 +900,7 @@ class _ClassScanMapperConfig(_MapperConfig):
setattr(cls, name, copy_)
- def _extract_mappable_attributes(self):
+ def _extract_mappable_attributes(self) -> None:
cls = self.cls
collected_attributes = self.collected_attributes
@@ -858,17 +972,19 @@ class _ClassScanMapperConfig(_MapperConfig):
"declarative base class."
)
elif isinstance(value, Column):
- _undefer_column_name(k, self.column_copies.get(value, value))
+ _undefer_column_name(
+ k, self.column_copies.get(value, value) # type: ignore
+ )
elif isinstance(value, _IntrospectsAnnotations):
annotation, is_dataclass = self.collected_annotations.get(
- k, (None, None)
+ k, (None, False)
)
value.declarative_scan(
self.registry, cls, k, annotation, is_dataclass
)
our_stuff[k] = value
- def _extract_declared_columns(self):
+ def _extract_declared_columns(self) -> None:
our_stuff = self.properties
# extract columns from the class dict
@@ -914,8 +1030,10 @@ class _ClassScanMapperConfig(_MapperConfig):
% (self.classname, name, (", ".join(sorted(keys))))
)
- def _setup_table(self, table=None):
+ def _setup_table(self, table: Optional[FromClause] = None) -> None:
cls = self.cls
+ cls_as_Decl = cast("_DeclMappedClassProtocol[Any]", cls)
+
tablename = self.tablename
table_args = self.table_args
clsdict_view = self.clsdict_view
@@ -925,13 +1043,18 @@ class _ClassScanMapperConfig(_MapperConfig):
if "__table__" not in clsdict_view and table is None:
if hasattr(cls, "__table_cls__"):
- table_cls = util.unbound_method_to_callable(cls.__table_cls__)
+ table_cls = cast(
+ Type[Table],
+ util.unbound_method_to_callable(cls.__table_cls__), # type: ignore # noqa: E501
+ )
else:
table_cls = Table
if tablename is not None:
- args, table_kw = (), {}
+ args: Tuple[Any, ...] = ()
+ table_kw: Dict[str, Any] = {}
+
if table_args:
if isinstance(table_args, dict):
table_kw = table_args
@@ -960,7 +1083,7 @@ class _ClassScanMapperConfig(_MapperConfig):
)
else:
if table is None:
- table = cls.__table__
+ table = cls_as_Decl.__table__
if declared_columns:
for c in declared_columns:
if not table.c.contains_column(c):
@@ -968,15 +1091,16 @@ class _ClassScanMapperConfig(_MapperConfig):
"Can't add additional column %r when "
"specifying __table__" % c.key
)
+
self.local_table = table
- def _metadata_for_cls(self, manager):
+ def _metadata_for_cls(self, manager: ClassManager[Any]) -> MetaData:
if hasattr(self.cls, "metadata"):
- return self.cls.metadata
+ return cast("_DeclMappedClassProtocol[Any]", self.cls).metadata
else:
return manager.registry.metadata
- def _setup_inheritance(self, mapper_kw):
+ def _setup_inheritance(self, mapper_kw: _MapperKwArgs) -> None:
table = self.local_table
cls = self.cls
table_args = self.table_args
@@ -988,8 +1112,8 @@ class _ClassScanMapperConfig(_MapperConfig):
# since we search for classical mappings now, search for
# multiple mapped bases as well and raise an error.
inherits_search = []
- for c in cls.__bases__:
- c = _resolve_for_abstract_or_classical(c)
+ for base_ in cls.__bases__:
+ c = _resolve_for_abstract_or_classical(base_)
if c is None:
continue
if _declared_mapping_info(
@@ -1024,9 +1148,12 @@ class _ClassScanMapperConfig(_MapperConfig):
"table-mapped class." % cls
)
elif self.inherits:
- inherited_mapper = _declared_mapping_info(self.inherits)
- inherited_table = inherited_mapper.local_table
- inherited_persist_selectable = inherited_mapper.persist_selectable
+ inherited_mapper_or_config = _declared_mapping_info(self.inherits)
+ assert inherited_mapper_or_config is not None
+ inherited_table = inherited_mapper_or_config.local_table
+ inherited_persist_selectable = (
+ inherited_mapper_or_config.persist_selectable
+ )
if table is None:
# single table inheritance.
@@ -1036,29 +1163,44 @@ class _ClassScanMapperConfig(_MapperConfig):
"Can't place __table_args__ on an inherited class "
"with no table."
)
+
# add any columns declared here to the inherited table.
- for c in declared_columns:
- if c.name in inherited_table.c:
- if inherited_table.c[c.name] is c:
+ if declared_columns and not isinstance(inherited_table, Table):
+ raise exc.ArgumentError(
+ f"Can't declare columns on single-table-inherited "
+ f"subclass {self.cls}; superclass {self.inherits} "
+ "is not mapped to a Table"
+ )
+
+ for col in declared_columns:
+ assert inherited_table is not None
+ if col.name in inherited_table.c:
+ if inherited_table.c[col.name] is col:
continue
raise exc.ArgumentError(
"Column '%s' on class %s conflicts with "
"existing column '%s'"
- % (c, cls, inherited_table.c[c.name])
+ % (col, cls, inherited_table.c[col.name])
)
- if c.primary_key:
+ if col.primary_key:
raise exc.ArgumentError(
"Can't place primary key columns on an inherited "
"class with no table."
)
- inherited_table.append_column(c)
+
+ if TYPE_CHECKING:
+ assert isinstance(inherited_table, Table)
+
+ inherited_table.append_column(col)
if (
inherited_persist_selectable is not None
and inherited_persist_selectable is not inherited_table
):
- inherited_persist_selectable._refresh_for_new_column(c)
+ inherited_persist_selectable._refresh_for_new_column(
+ col
+ )
- def _prepare_mapper_arguments(self, mapper_kw):
+ def _prepare_mapper_arguments(self, mapper_kw: _MapperKwArgs) -> None:
properties = self.properties
if self.mapper_args_fn:
@@ -1100,6 +1242,7 @@ class _ClassScanMapperConfig(_MapperConfig):
# not mapped on the parent class, to avoid
# mapping columns specific to sibling/nephew classes
inherited_mapper = _declared_mapping_info(self.inherits)
+ assert isinstance(inherited_mapper, Mapper)
inherited_table = inherited_mapper.local_table
if "exclude_properties" not in mapper_args:
@@ -1133,11 +1276,14 @@ class _ClassScanMapperConfig(_MapperConfig):
result_mapper_args["properties"] = properties
self.mapper_args = result_mapper_args
- def map(self, mapper_kw=util.EMPTY_DICT):
+ def map(self, mapper_kw: _MapperKwArgs = util.EMPTY_DICT) -> Mapper[Any]:
self._prepare_mapper_arguments(mapper_kw)
if hasattr(self.cls, "__mapper_cls__"):
- mapper_cls = util.unbound_method_to_callable(
- self.cls.__mapper_cls__
+ mapper_cls = cast(
+ "Type[Mapper[Any]]",
+ util.unbound_method_to_callable(
+ self.cls.__mapper_cls__ # type: ignore
+ ),
)
else:
mapper_cls = mapper
@@ -1149,7 +1295,9 @@ class _ClassScanMapperConfig(_MapperConfig):
@util.preload_module("sqlalchemy.orm.decl_api")
-def _as_dc_declaredattr(field_metadata, sa_dataclass_metadata_key):
+def _as_dc_declaredattr(
+ field_metadata: Mapping[str, Any], sa_dataclass_metadata_key: str
+) -> Any:
# wrap lambdas inside dataclass fields inside an ad-hoc declared_attr.
# we can't write it because field.metadata is immutable :( so we have
# to go through extra trouble to compare these
@@ -1162,46 +1310,55 @@ def _as_dc_declaredattr(field_metadata, sa_dataclass_metadata_key):
class _DeferredMapperConfig(_ClassScanMapperConfig):
- _configs = util.OrderedDict()
+ _cls: weakref.ref[Type[Any]]
+
+ _configs: util.OrderedDict[
+ weakref.ref[Type[Any]], _DeferredMapperConfig
+ ] = util.OrderedDict()
- def _early_mapping(self, mapper_kw):
+ def _early_mapping(self, mapper_kw: _MapperKwArgs) -> None:
pass
- @property
- def cls(self):
- return self._cls()
+ # mypy disallows plain property override of variable
+ @property # type: ignore
+ def cls(self) -> Type[Any]: # type: ignore
+ return self._cls() # type: ignore
@cls.setter
- def cls(self, class_):
+ def cls(self, class_: Type[Any]) -> None:
self._cls = weakref.ref(class_, self._remove_config_cls)
self._configs[self._cls] = self
@classmethod
- def _remove_config_cls(cls, ref):
+ def _remove_config_cls(cls, ref: weakref.ref[Type[Any]]) -> None:
cls._configs.pop(ref, None)
@classmethod
- def has_cls(cls, class_):
+ def has_cls(cls, class_: Type[Any]) -> bool:
# 2.6 fails on weakref if class_ is an old style class
return isinstance(class_, type) and weakref.ref(class_) in cls._configs
@classmethod
- def raise_unmapped_for_cls(cls, class_):
+ def raise_unmapped_for_cls(cls, class_: Type[Any]) -> NoReturn:
if hasattr(class_, "_sa_raise_deferred_config"):
- class_._sa_raise_deferred_config()
+ class_._sa_raise_deferred_config() # type: ignore
raise orm_exc.UnmappedClassError(
class_,
- msg="Class %s has a deferred mapping on it. It is not yet "
- "usable as a mapped class." % orm_exc._safe_cls_name(class_),
+ msg=(
+ f"Class {orm_exc._safe_cls_name(class_)} has a deferred "
+ "mapping on it. It is not yet usable as a mapped class."
+ ),
)
@classmethod
- def config_for_cls(cls, class_):
+ def config_for_cls(cls, class_: Type[Any]) -> _DeferredMapperConfig:
return cls._configs[weakref.ref(class_)]
@classmethod
- def classes_for_base(cls, base_cls, sort=True):
+ def classes_for_base(
+ cls, base_cls: Type[Any], sort: bool = True
+ ) -> List[_DeferredMapperConfig]:
classes_for_base = [
m
for m, cls_ in [(m, m.cls) for m in cls._configs.values()]
@@ -1213,7 +1370,7 @@ class _DeferredMapperConfig(_ClassScanMapperConfig):
all_m_by_cls = dict((m.cls, m) for m in classes_for_base)
- tuples = []
+ tuples: List[Tuple[_DeferredMapperConfig, _DeferredMapperConfig]] = []
for m_cls in all_m_by_cls:
tuples.extend(
(all_m_by_cls[base_cls], all_m_by_cls[m_cls])
@@ -1222,12 +1379,14 @@ class _DeferredMapperConfig(_ClassScanMapperConfig):
)
return list(topological.sort(tuples, classes_for_base))
- def map(self, mapper_kw=util.EMPTY_DICT):
+ def map(self, mapper_kw: _MapperKwArgs = util.EMPTY_DICT) -> Mapper[Any]:
self._configs.pop(self._cls, None)
return super(_DeferredMapperConfig, self).map(mapper_kw)
-def _add_attribute(cls, key, value):
+def _add_attribute(
+ cls: Type[Any], key: str, value: MapperProperty[Any]
+) -> None:
"""add an attribute to an existing declarative class.
This runs through the logic to determine MapperProperty,
@@ -1236,39 +1395,44 @@ def _add_attribute(cls, key, value):
"""
if "__mapper__" in cls.__dict__:
+ mapped_cls = cast("_DeclMappedClassProtocol[Any]", cls)
if isinstance(value, Column):
_undefer_column_name(key, value)
- cls.__table__.append_column(value, replace_existing=True)
- cls.__mapper__.add_property(key, value)
+ # TODO: raise for this is not a Table
+ mapped_cls.__table__.append_column(value, replace_existing=True)
+ mapped_cls.__mapper__.add_property(key, value)
elif isinstance(value, _MapsColumns):
mp = value.mapper_property_to_assign
for col in value.columns_to_assign:
_undefer_column_name(key, col)
- cls.__table__.append_column(col, replace_existing=True)
+ # TODO: raise for this is not a Table
+ mapped_cls.__table__.append_column(col, replace_existing=True)
if not mp:
- cls.__mapper__.add_property(key, col)
+ mapped_cls.__mapper__.add_property(key, col)
if mp:
- cls.__mapper__.add_property(key, mp)
+ mapped_cls.__mapper__.add_property(key, mp)
elif isinstance(value, MapperProperty):
- cls.__mapper__.add_property(key, value)
+ mapped_cls.__mapper__.add_property(key, value)
elif isinstance(value, QueryableAttribute) and value.key != key:
# detect a QueryableAttribute that's already mapped being
# assigned elsewhere in userland, turn into a synonym()
value = Synonym(value.key)
- cls.__mapper__.add_property(key, value)
+ mapped_cls.__mapper__.add_property(key, value)
else:
type.__setattr__(cls, key, value)
- cls.__mapper__._expire_memoizations()
+ mapped_cls.__mapper__._expire_memoizations()
else:
type.__setattr__(cls, key, value)
-def _del_attribute(cls, key):
+def _del_attribute(cls: Type[Any], key: str) -> None:
if (
"__mapper__" in cls.__dict__
and key in cls.__dict__
- and not cls.__mapper__._dispose_called
+ and not cast(
+ "_DeclMappedClassProtocol[Any]", cls
+ ).__mapper__._dispose_called
):
value = cls.__dict__[key]
if isinstance(
@@ -1279,12 +1443,14 @@ def _del_attribute(cls, key):
)
else:
type.__delattr__(cls, key)
- cls.__mapper__._expire_memoizations()
+ cast(
+ "_DeclMappedClassProtocol[Any]", cls
+ ).__mapper__._expire_memoizations()
else:
type.__delattr__(cls, key)
-def _declarative_constructor(self, **kwargs):
+def _declarative_constructor(self: Any, **kwargs: Any) -> None:
"""A simple constructor that allows initialization from kwargs.
Sets attributes on the constructed instance using the names and
@@ -1306,7 +1472,7 @@ def _declarative_constructor(self, **kwargs):
_declarative_constructor.__name__ = "__init__"
-def _undefer_column_name(key, column):
+def _undefer_column_name(key: str, column: Column[Any]) -> None:
if column.key is None:
column.key = key
if column.name is None:
diff --git a/lib/sqlalchemy/orm/descriptor_props.py b/lib/sqlalchemy/orm/descriptor_props.py
index 5975c30db..8c89f96aa 100644
--- a/lib/sqlalchemy/orm/descriptor_props.py
+++ b/lib/sqlalchemy/orm/descriptor_props.py
@@ -20,15 +20,21 @@ import typing
from typing import Any
from typing import Callable
from typing import List
+from typing import NoReturn
from typing import Optional
+from typing import Sequence
from typing import Tuple
from typing import Type
+from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
from . import attributes
from . import util as orm_util
+from .base import LoaderCallableStatus
from .base import Mapped
+from .base import PassiveFlag
+from .base import SQLORMOperations
from .interfaces import _IntrospectsAnnotations
from .interfaces import _MapsColumns
from .interfaces import MapperProperty
@@ -41,20 +47,41 @@ from .. import schema
from .. import sql
from .. import util
from ..sql import expression
-from ..sql import operators
+from ..sql.elements import BindParameter
from ..util.typing import Protocol
if typing.TYPE_CHECKING:
+ from ._typing import _InstanceDict
+ from ._typing import _RegistryType
+ from .attributes import History
from .attributes import InstrumentedAttribute
+ from .attributes import QueryableAttribute
+ from .context import ORMCompileState
+ from .mapper import Mapper
+ from .properties import ColumnProperty
from .properties import MappedColumn
+ from .state import InstanceState
+ from ..engine.base import Connection
+ from ..engine.row import Row
+ from ..sql._typing import _DMLColumnArgument
from ..sql._typing import _InfoType
+ from ..sql.elements import ClauseList
+ from ..sql.elements import ColumnElement
from ..sql.schema import Column
+ from ..sql.selectable import Select
+ from ..util.typing import _AnnotationScanType
+ from ..util.typing import CallableReference
+ from ..util.typing import DescriptorReference
+ from ..util.typing import RODescriptorReference
_T = TypeVar("_T", bound=Any)
_PT = TypeVar("_PT", bound=Any)
class _CompositeClassProto(Protocol):
+ def __init__(self, *args: Any):
+ ...
+
def __composite_values__(self) -> Tuple[Any, ...]:
...
@@ -63,32 +90,43 @@ class DescriptorProperty(MapperProperty[_T]):
""":class:`.MapperProperty` which proxies access to a
user-defined descriptor."""
- doc = None
+ doc: Optional[str] = None
uses_objects = False
_links_to_entity = False
- def instrument_class(self, mapper):
+ descriptor: DescriptorReference[Any]
+
+ def get_history(
+ self,
+ state: InstanceState[Any],
+ dict_: _InstanceDict,
+ passive: PassiveFlag = PassiveFlag.PASSIVE_OFF,
+ ) -> History:
+ raise NotImplementedError()
+
+ def instrument_class(self, mapper: Mapper[Any]) -> None:
prop = self
- class _ProxyImpl:
+ class _ProxyImpl(attributes.AttributeImpl):
accepts_scalar_loader = False
load_on_unexpire = True
collection = False
@property
- def uses_objects(self):
+ def uses_objects(self) -> bool: # type: ignore
return prop.uses_objects
- def __init__(self, key):
+ def __init__(self, key: str):
self.key = key
- if hasattr(prop, "get_history"):
-
- def get_history(
- self, state, dict_, passive=attributes.PASSIVE_OFF
- ):
- return prop.get_history(state, dict_, passive)
+ def get_history(
+ self,
+ state: InstanceState[Any],
+ dict_: _InstanceDict,
+ passive: PassiveFlag = PassiveFlag.PASSIVE_OFF,
+ ) -> History:
+ return prop.get_history(state, dict_, passive)
if self.descriptor is None:
desc = getattr(mapper.class_, self.key, None)
@@ -97,13 +135,13 @@ class DescriptorProperty(MapperProperty[_T]):
if self.descriptor is None:
- def fset(obj, value):
+ def fset(obj: Any, value: Any) -> None:
setattr(obj, self.name, value)
- def fdel(obj):
+ def fdel(obj: Any) -> None:
delattr(obj, self.name)
- def fget(obj):
+ def fget(obj: Any) -> Any:
return getattr(obj, self.name)
self.descriptor = property(fget=fget, fset=fset, fdel=fdel)
@@ -129,8 +167,11 @@ _CompositeAttrType = Union[
]
+_CC = TypeVar("_CC", bound=_CompositeClassProto)
+
+
class Composite(
- _MapsColumns[_T], _IntrospectsAnnotations, DescriptorProperty[_T]
+ _MapsColumns[_CC], _IntrospectsAnnotations, DescriptorProperty[_CC]
):
"""Defines a "composite" mapped attribute, representing a collection
of columns as one attribute.
@@ -148,19 +189,25 @@ class Composite(
"""
- composite_class: Union[
- Type[_CompositeClassProto], Callable[..., Type[_CompositeClassProto]]
+ composite_class: Union[Type[_CC], Callable[..., _CC]]
+ attrs: Tuple[_CompositeAttrType[Any], ...]
+
+ _generated_composite_accessor: CallableReference[
+ Optional[Callable[[_CC], Tuple[Any, ...]]]
]
- attrs: Tuple[_CompositeAttrType, ...]
+
+ comparator_factory: Type[Comparator[_CC]]
def __init__(
self,
- class_: Union[None, _CompositeClassProto, _CompositeAttrType] = None,
- *attrs: _CompositeAttrType,
+ class_: Union[
+ None, Type[_CC], Callable[..., _CC], _CompositeAttrType[Any]
+ ] = None,
+ *attrs: _CompositeAttrType[Any],
active_history: bool = False,
deferred: bool = False,
group: Optional[str] = None,
- comparator_factory: Optional[Type[Comparator]] = None,
+ comparator_factory: Optional[Type[Comparator[_CC]]] = None,
info: Optional[_InfoType] = None,
):
super().__init__()
@@ -170,7 +217,7 @@ class Composite(
# will initialize within declarative_scan
self.composite_class = None # type: ignore
else:
- self.composite_class = class_
+ self.composite_class = class_ # type: ignore
self.attrs = attrs
self.active_history = active_history
@@ -183,18 +230,16 @@ class Composite(
)
self._generated_composite_accessor = None
if info is not None:
- self.info = info
+ self.info.update(info)
util.set_creation_order(self)
self._create_descriptor()
- def instrument_class(self, mapper):
+ def instrument_class(self, mapper: Mapper[Any]) -> None:
super().instrument_class(mapper)
self._setup_event_handlers()
- def _composite_values_from_instance(
- self, value: _CompositeClassProto
- ) -> Tuple[Any, ...]:
+ def _composite_values_from_instance(self, value: _CC) -> Tuple[Any, ...]:
if self._generated_composite_accessor:
return self._generated_composite_accessor(value)
else:
@@ -209,7 +254,7 @@ class Composite(
else:
return accessor()
- def do_init(self):
+ def do_init(self) -> None:
"""Initialization which occurs after the :class:`.Composite`
has been associated with its parent mapper.
@@ -218,13 +263,13 @@ class Composite(
_COMPOSITE_FGET = object()
- def _create_descriptor(self):
+ def _create_descriptor(self) -> None:
"""Create the Python descriptor that will serve as
the access point on instances of the mapped class.
"""
- def fget(instance):
+ def fget(instance: Any) -> Any:
dict_ = attributes.instance_dict(instance)
state = attributes.instance_state(instance)
@@ -251,11 +296,11 @@ class Composite(
return dict_.get(self.key, None)
- def fset(instance, value):
+ def fset(instance: Any, value: Any) -> None:
dict_ = attributes.instance_dict(instance)
state = attributes.instance_state(instance)
attr = state.manager[self.key]
- previous = dict_.get(self.key, attributes.NO_VALUE)
+ previous = dict_.get(self.key, LoaderCallableStatus.NO_VALUE)
for fn in attr.dispatch.set:
value = fn(state, value, previous, attr.impl)
dict_[self.key] = value
@@ -269,10 +314,10 @@ class Composite(
):
setattr(instance, key, value)
- def fdel(instance):
+ def fdel(instance: Any) -> None:
state = attributes.instance_state(instance)
dict_ = attributes.instance_dict(instance)
- previous = dict_.pop(self.key, attributes.NO_VALUE)
+ previous = dict_.pop(self.key, LoaderCallableStatus.NO_VALUE)
attr = state.manager[self.key]
attr.dispatch.remove(state, previous, attr.impl)
for key in self._attribute_keys:
@@ -282,8 +327,13 @@ class Composite(
@util.preload_module("sqlalchemy.orm.properties")
def declarative_scan(
- self, registry, cls, key, annotation, is_dataclass_field
- ):
+ self,
+ registry: _RegistryType,
+ cls: Type[Any],
+ key: str,
+ annotation: Optional[_AnnotationScanType],
+ is_dataclass_field: bool,
+ ) -> None:
MappedColumn = util.preloaded.orm_properties.MappedColumn
argument = _extract_mapped_subtype(
@@ -310,7 +360,9 @@ class Composite(
@util.preload_module("sqlalchemy.orm.properties")
@util.preload_module("sqlalchemy.orm.decl_base")
- def _setup_for_dataclass(self, registry, cls, key):
+ def _setup_for_dataclass(
+ self, registry: _RegistryType, cls: Type[Any], key: str
+ ) -> None:
MappedColumn = util.preloaded.orm_properties.MappedColumn
decl_base = util.preloaded.orm_decl_base
@@ -341,12 +393,12 @@ class Composite(
self._generated_composite_accessor = getter
@util.memoized_property
- def _comparable_elements(self):
+ def _comparable_elements(self) -> Sequence[QueryableAttribute[Any]]:
return [getattr(self.parent.class_, prop.key) for prop in self.props]
@util.memoized_property
@util.preload_module("orm.properties")
- def props(self):
+ def props(self) -> Sequence[MapperProperty[Any]]:
props = []
MappedColumn = util.preloaded.orm_properties.MappedColumn
@@ -360,17 +412,20 @@ class Composite(
elif isinstance(attr, attributes.InstrumentedAttribute):
prop = attr.property
else:
+ prop = None
+
+ if not isinstance(prop, MapperProperty):
raise sa_exc.ArgumentError(
"Composite expects Column objects or mapped "
- "attributes/attribute names as arguments, got: %r"
- % (attr,)
+ f"attributes/attribute names as arguments, got: {attr!r}"
)
+
props.append(prop)
return props
- @property
+ @util.non_memoized_property
@util.preload_module("orm.properties")
- def columns(self):
+ def columns(self) -> Sequence[Column[Any]]:
MappedColumn = util.preloaded.orm_properties.MappedColumn
return [
a.column if isinstance(a, MappedColumn) else a
@@ -379,32 +434,46 @@ class Composite(
]
@property
- def mapper_property_to_assign(self) -> Optional["MapperProperty[_T]"]:
+ def mapper_property_to_assign(self) -> Optional[MapperProperty[_CC]]:
return self
@property
- def columns_to_assign(self) -> List[schema.Column]:
+ def columns_to_assign(self) -> List[schema.Column[Any]]:
return [c for c in self.columns if c.table is None]
- def _setup_arguments_on_columns(self):
+ @util.preload_module("orm.properties")
+ def _setup_arguments_on_columns(self) -> None:
"""Propagate configuration arguments made on this composite
to the target columns, for those that apply.
"""
+ ColumnProperty = util.preloaded.orm_properties.ColumnProperty
+
for prop in self.props:
- prop.active_history = self.active_history
+ if not isinstance(prop, ColumnProperty):
+ continue
+ else:
+ cprop = prop
+
+ cprop.active_history = self.active_history
if self.deferred:
- prop.deferred = self.deferred
- prop.strategy_key = (("deferred", True), ("instrument", True))
- prop.group = self.group
+ cprop.deferred = self.deferred
+ cprop.strategy_key = (("deferred", True), ("instrument", True))
+ cprop.group = self.group
- def _setup_event_handlers(self):
+ def _setup_event_handlers(self) -> None:
"""Establish events that populate/expire the composite attribute."""
- def load_handler(state, context):
+ def load_handler(
+ state: InstanceState[Any], context: ORMCompileState
+ ) -> None:
_load_refresh_handler(state, context, None, is_refresh=False)
- def refresh_handler(state, context, to_load):
+ def refresh_handler(
+ state: InstanceState[Any],
+ context: ORMCompileState,
+ to_load: Optional[Sequence[str]],
+ ) -> None:
# note this corresponds to sqlalchemy.ext.mutable load_attrs()
if not to_load or (
@@ -412,7 +481,12 @@ class Composite(
).intersection(to_load):
_load_refresh_handler(state, context, to_load, is_refresh=True)
- def _load_refresh_handler(state, context, to_load, is_refresh):
+ def _load_refresh_handler(
+ state: InstanceState[Any],
+ context: ORMCompileState,
+ to_load: Optional[Sequence[str]],
+ is_refresh: bool,
+ ) -> None:
dict_ = state.dict
# if context indicates we are coming from the
@@ -440,11 +514,17 @@ class Composite(
*[state.dict[key] for key in self._attribute_keys]
)
- def expire_handler(state, keys):
+ def expire_handler(
+ state: InstanceState[Any], keys: Optional[Sequence[str]]
+ ) -> None:
if keys is None or set(self._attribute_keys).intersection(keys):
state.dict.pop(self.key, None)
- def insert_update_handler(mapper, connection, state):
+ def insert_update_handler(
+ mapper: Mapper[Any],
+ connection: Connection,
+ state: InstanceState[Any],
+ ) -> None:
"""After an insert or update, some columns may be expired due
to server side defaults, or re-populated due to client side
defaults. Pop out the composite value here so that it
@@ -473,14 +553,19 @@ class Composite(
# TODO: need a deserialize hook here
@util.memoized_property
- def _attribute_keys(self):
+ def _attribute_keys(self) -> Sequence[str]:
return [prop.key for prop in self.props]
- def get_history(self, state, dict_, passive=attributes.PASSIVE_OFF):
+ def get_history(
+ self,
+ state: InstanceState[Any],
+ dict_: _InstanceDict,
+ passive: PassiveFlag = PassiveFlag.PASSIVE_OFF,
+ ) -> History:
"""Provided for userland code that uses attributes.get_history()."""
- added = []
- deleted = []
+ added: List[Any] = []
+ deleted: List[Any] = []
has_history = False
for prop in self.props:
@@ -508,16 +593,27 @@ class Composite(
else:
return attributes.History((), [self.composite_class(*added)], ())
- def _comparator_factory(self, mapper):
+ def _comparator_factory(
+ self, mapper: Mapper[Any]
+ ) -> Composite.Comparator[_CC]:
return self.comparator_factory(self, mapper)
- class CompositeBundle(orm_util.Bundle):
- def __init__(self, property_, expr):
+ class CompositeBundle(orm_util.Bundle[_T]):
+ def __init__(
+ self,
+ property_: Composite[_T],
+ expr: ClauseList,
+ ):
self.property = property_
super().__init__(property_.key, *expr)
- def create_row_processor(self, query, procs, labels):
- def proc(row):
+ def create_row_processor(
+ self,
+ query: Select[Any],
+ procs: Sequence[Callable[[Row[Any]], Any]],
+ labels: Sequence[str],
+ ) -> Callable[[Row[Any]], Any]:
+ def proc(row: Row[Any]) -> Any:
return self.property.composite_class(
*[proc(row) for proc in procs]
)
@@ -546,17 +642,19 @@ class Composite(
# https://github.com/python/mypy/issues/4266
__hash__ = None # type: ignore
+ prop: RODescriptorReference[Composite[_PT]]
+
@util.memoized_property
- def clauses(self):
+ def clauses(self) -> ClauseList:
return expression.ClauseList(
group=False, *self._comparable_elements
)
- def __clause_element__(self):
+ def __clause_element__(self) -> Composite.CompositeBundle[_PT]:
return self.expression
@util.memoized_property
- def expression(self):
+ def expression(self) -> Composite.CompositeBundle[_PT]:
clauses = self.clauses._annotate(
{
"parententity": self._parententity,
@@ -566,13 +664,19 @@ class Composite(
)
return Composite.CompositeBundle(self.prop, clauses)
- def _bulk_update_tuples(self, value):
- if isinstance(value, sql.elements.BindParameter):
+ def _bulk_update_tuples(
+ self, value: Any
+ ) -> Sequence[Tuple[_DMLColumnArgument, Any]]:
+ if isinstance(value, BindParameter):
value = value.value
+ values: Sequence[Any]
+
if value is None:
values = [None for key in self.prop._attribute_keys]
- elif isinstance(value, self.prop.composite_class):
+ elif isinstance(self.prop.composite_class, type) and isinstance(
+ value, self.prop.composite_class
+ ):
values = self.prop._composite_values_from_instance(value)
else:
raise sa_exc.ArgumentError(
@@ -580,10 +684,10 @@ class Composite(
% (self.prop, value)
)
- return zip(self._comparable_elements, values)
+ return list(zip(self._comparable_elements, values))
@util.memoized_property
- def _comparable_elements(self):
+ def _comparable_elements(self) -> Sequence[QueryableAttribute[Any]]:
if self._adapt_to_entity:
return [
getattr(self._adapt_to_entity.entity, prop.key)
@@ -592,7 +696,8 @@ class Composite(
else:
return self.prop._comparable_elements
- def __eq__(self, other):
+ def __eq__(self, other: Any) -> ColumnElement[bool]: # type: ignore[override] # noqa: E501
+ values: Sequence[Any]
if other is None:
values = [None] * len(self.prop._comparable_elements)
else:
@@ -601,13 +706,14 @@ class Composite(
a == b for a, b in zip(self.prop._comparable_elements, values)
]
if self._adapt_to_entity:
+ assert self.adapter is not None
comparisons = [self.adapter(x) for x in comparisons]
return sql.and_(*comparisons)
- def __ne__(self, other):
+ def __ne__(self, other: Any) -> ColumnElement[bool]: # type: ignore[override] # noqa: E501
return sql.not_(self.__eq__(other))
- def __str__(self):
+ def __str__(self) -> str:
return str(self.parent.class_.__name__) + "." + self.key
@@ -628,20 +734,24 @@ class ConcreteInheritedProperty(DescriptorProperty[_T]):
"""
- def _comparator_factory(self, mapper):
+ def _comparator_factory(
+ self, mapper: Mapper[Any]
+ ) -> Type[PropComparator[_T]]:
+
comparator_callable = None
for m in self.parent.iterate_to_root():
p = m._props[self.key]
- if not isinstance(p, ConcreteInheritedProperty):
+ if getattr(p, "comparator_factory", None) is not None:
comparator_callable = p.comparator_factory
break
- return comparator_callable
+ assert comparator_callable is not None
+ return comparator_callable(p, mapper) # type: ignore
- def __init__(self):
+ def __init__(self) -> None:
super().__init__()
- def warn():
+ def warn() -> NoReturn:
raise AttributeError(
"Concrete %s does not implement "
"attribute %r at the instance level. Add "
@@ -650,13 +760,13 @@ class ConcreteInheritedProperty(DescriptorProperty[_T]):
)
class NoninheritedConcreteProp:
- def __set__(s, obj, value):
+ def __set__(s: Any, obj: Any, value: Any) -> NoReturn:
warn()
- def __delete__(s, obj):
+ def __delete__(s: Any, obj: Any) -> NoReturn:
warn()
- def __get__(s, obj, owner):
+ def __get__(s: Any, obj: Any, owner: Any) -> Any:
if obj is None:
return self.descriptor
warn()
@@ -682,14 +792,16 @@ class Synonym(DescriptorProperty[_T]):
"""
+ comparator_factory: Optional[Type[PropComparator[_T]]]
+
def __init__(
self,
- name,
- map_column=None,
- descriptor=None,
- comparator_factory=None,
- doc=None,
- info=None,
+ name: str,
+ map_column: Optional[bool] = None,
+ descriptor: Optional[Any] = None,
+ comparator_factory: Optional[Type[PropComparator[_T]]] = None,
+ info: Optional[_InfoType] = None,
+ doc: Optional[str] = None,
):
super().__init__()
@@ -697,21 +809,30 @@ class Synonym(DescriptorProperty[_T]):
self.map_column = map_column
self.descriptor = descriptor
self.comparator_factory = comparator_factory
- self.doc = doc or (descriptor and descriptor.__doc__) or None
+ if doc:
+ self.doc = doc
+ elif descriptor and descriptor.__doc__:
+ self.doc = descriptor.__doc__
+ else:
+ self.doc = None
if info:
- self.info = info
+ self.info.update(info)
util.set_creation_order(self)
- @property
- def uses_objects(self):
- return getattr(self.parent.class_, self.name).impl.uses_objects
+ if not TYPE_CHECKING:
+
+ @property
+ def uses_objects(self) -> bool:
+ return getattr(self.parent.class_, self.name).impl.uses_objects
# TODO: when initialized, check _proxied_object,
# emit a warning if its not a column-based property
@util.memoized_property
- def _proxied_object(self):
+ def _proxied_object(
+ self,
+ ) -> Union[MapperProperty[_T], SQLORMOperations[_T]]:
attr = getattr(self.parent.class_, self.name)
if not hasattr(attr, "property") or not isinstance(
attr.property, MapperProperty
@@ -720,7 +841,8 @@ class Synonym(DescriptorProperty[_T]):
# hybrid or association proxy
if isinstance(attr, attributes.QueryableAttribute):
return attr.comparator
- elif isinstance(attr, operators.ColumnOperators):
+ elif isinstance(attr, SQLORMOperations):
+ # assocaition proxy comes here
return attr
raise sa_exc.InvalidRequestError(
@@ -730,7 +852,7 @@ class Synonym(DescriptorProperty[_T]):
)
return attr.property
- def _comparator_factory(self, mapper):
+ def _comparator_factory(self, mapper: Mapper[Any]) -> SQLORMOperations[_T]:
prop = self._proxied_object
if isinstance(prop, MapperProperty):
@@ -742,12 +864,17 @@ class Synonym(DescriptorProperty[_T]):
else:
return prop
- def get_history(self, *arg, **kw):
- attr = getattr(self.parent.class_, self.name)
- return attr.impl.get_history(*arg, **kw)
+ def get_history(
+ self,
+ state: InstanceState[Any],
+ dict_: _InstanceDict,
+ passive: PassiveFlag = PassiveFlag.PASSIVE_OFF,
+ ) -> History:
+ attr: QueryableAttribute[Any] = getattr(self.parent.class_, self.name)
+ return attr.impl.get_history(state, dict_, passive=passive)
@util.preload_module("sqlalchemy.orm.properties")
- def set_parent(self, parent, init):
+ def set_parent(self, parent: Mapper[Any], init: bool) -> None:
properties = util.preloaded.orm_properties
if self.map_column:
@@ -776,7 +903,7 @@ class Synonym(DescriptorProperty[_T]):
"%r for column %r"
% (self.key, self.name, self.name, self.key)
)
- p = properties.ColumnProperty(
+ p: ColumnProperty[Any] = properties.ColumnProperty(
parent.persist_selectable.c[self.key]
)
parent._configure_property(self.name, p, init=init, setparent=True)
diff --git a/lib/sqlalchemy/orm/dynamic.py b/lib/sqlalchemy/orm/dynamic.py
index 1b4f573b5..084ba969f 100644
--- a/lib/sqlalchemy/orm/dynamic.py
+++ b/lib/sqlalchemy/orm/dynamic.py
@@ -16,6 +16,12 @@ basic add/delete mutation.
from __future__ import annotations
+from typing import Any
+from typing import Optional
+from typing import overload
+from typing import TYPE_CHECKING
+from typing import Union
+
from . import attributes
from . import exc as orm_exc
from . import interfaces
@@ -23,17 +29,27 @@ from . import relationships
from . import strategies
from . import util as orm_util
from .base import object_mapper
+from .base import PassiveFlag
from .query import Query
from .session import object_session
from .. import exc
from .. import log
from .. import util
from ..engine import result
+from ..util.typing import Literal
+
+if TYPE_CHECKING:
+ from ._typing import _InstanceDict
+ from .attributes import _AdaptedCollectionProtocol
+ from .attributes import AttributeEventToken
+ from .attributes import CollectionAdapter
+ from .base import LoaderCallableStatus
+ from .state import InstanceState
@log.class_logger
@relationships.Relationship.strategy_for(lazy="dynamic")
-class DynaLoader(strategies.AbstractRelationshipLoader):
+class DynaLoader(strategies.AbstractRelationshipLoader, log.Identified):
def init_class_attribute(self, mapper):
self.is_class_level = True
if not self.uselist:
@@ -106,13 +122,47 @@ class DynamicAttributeImpl(
else:
return self.query_class(self, state)
+ @overload
def get_collection(
self,
- state,
- dict_,
- user_data=None,
- passive=attributes.PASSIVE_NO_INITIALIZE,
- ):
+ state: InstanceState[Any],
+ dict_: _InstanceDict,
+ user_data: Literal[None] = ...,
+ passive: Literal[PassiveFlag.PASSIVE_OFF] = ...,
+ ) -> CollectionAdapter:
+ ...
+
+ @overload
+ def get_collection(
+ self,
+ state: InstanceState[Any],
+ dict_: _InstanceDict,
+ user_data: _AdaptedCollectionProtocol = ...,
+ passive: PassiveFlag = ...,
+ ) -> CollectionAdapter:
+ ...
+
+ @overload
+ def get_collection(
+ self,
+ state: InstanceState[Any],
+ dict_: _InstanceDict,
+ user_data: Optional[_AdaptedCollectionProtocol] = ...,
+ passive: PassiveFlag = ...,
+ ) -> Union[
+ Literal[LoaderCallableStatus.PASSIVE_NO_RESULT], CollectionAdapter
+ ]:
+ ...
+
+ def get_collection(
+ self,
+ state: InstanceState[Any],
+ dict_: _InstanceDict,
+ user_data: Optional[_AdaptedCollectionProtocol] = None,
+ passive: PassiveFlag = PassiveFlag.PASSIVE_OFF,
+ ) -> Union[
+ Literal[LoaderCallableStatus.PASSIVE_NO_RESULT], CollectionAdapter
+ ]:
if not passive & attributes.SQL_OK:
data = self._get_collection_history(state, passive).added_items
else:
@@ -170,15 +220,15 @@ class DynamicAttributeImpl(
def set(
self,
- state,
- dict_,
- value,
- initiator=None,
- passive=attributes.PASSIVE_OFF,
- check_old=None,
- pop=False,
- _adapt=True,
- ):
+ state: InstanceState[Any],
+ dict_: _InstanceDict,
+ value: Any,
+ initiator: Optional[AttributeEventToken] = None,
+ passive: PassiveFlag = PassiveFlag.PASSIVE_OFF,
+ check_old: Any = None,
+ pop: bool = False,
+ _adapt: bool = True,
+ ) -> None:
if initiator and initiator.parent_token is self.parent_token:
return
diff --git a/lib/sqlalchemy/orm/events.py b/lib/sqlalchemy/orm/events.py
index 331c224ee..726ea79b5 100644
--- a/lib/sqlalchemy/orm/events.py
+++ b/lib/sqlalchemy/orm/events.py
@@ -4,6 +4,7 @@
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
+# mypy: ignore-errors
"""ORM event interfaces.
diff --git a/lib/sqlalchemy/orm/exc.py b/lib/sqlalchemy/orm/exc.py
index f157919ab..57e5fe8c6 100644
--- a/lib/sqlalchemy/orm/exc.py
+++ b/lib/sqlalchemy/orm/exc.py
@@ -4,7 +4,6 @@
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: ignore-errors
"""SQLAlchemy ORM exceptions."""
@@ -12,13 +11,22 @@ from __future__ import annotations
from typing import Any
from typing import Optional
+from typing import Tuple
from typing import Type
+from typing import TYPE_CHECKING
+from typing import TypeVar
from .. import exc as sa_exc
from .. import util
from ..exc import MultipleResultsFound # noqa
from ..exc import NoResultFound # noqa
+if TYPE_CHECKING:
+ from .interfaces import LoaderStrategy
+ from .interfaces import MapperProperty
+ from .state import InstanceState
+
+_T = TypeVar("_T", bound=Any)
NO_STATE = (AttributeError, KeyError)
"""Exception types that may be raised by instrumentation implementations."""
@@ -100,14 +108,14 @@ class UnmappedInstanceError(UnmappedError):
)
UnmappedError.__init__(self, msg)
- def __reduce__(self):
+ def __reduce__(self) -> Any:
return self.__class__, (None, self.args[0])
class UnmappedClassError(UnmappedError):
"""An mapping operation was requested for an unknown class."""
- def __init__(self, cls: Type[object], msg: Optional[str] = None):
+ def __init__(self, cls: Type[_T], msg: Optional[str] = None):
if not msg:
msg = _default_unmapped(cls)
UnmappedError.__init__(self, msg)
@@ -137,7 +145,7 @@ class ObjectDeletedError(sa_exc.InvalidRequestError):
"""
@util.preload_module("sqlalchemy.orm.base")
- def __init__(self, state, msg=None):
+ def __init__(self, state: InstanceState[Any], msg: Optional[str] = None):
base = util.preloaded.orm_base
if not msg:
@@ -148,7 +156,7 @@ class ObjectDeletedError(sa_exc.InvalidRequestError):
sa_exc.InvalidRequestError.__init__(self, msg)
- def __reduce__(self):
+ def __reduce__(self) -> Any:
return self.__class__, (None, self.args[0])
@@ -161,11 +169,11 @@ class LoaderStrategyException(sa_exc.InvalidRequestError):
def __init__(
self,
- applied_to_property_type,
- requesting_property,
- applies_to,
- actual_strategy_type,
- strategy_key,
+ applied_to_property_type: Type[Any],
+ requesting_property: MapperProperty[Any],
+ applies_to: Optional[Type[MapperProperty[Any]]],
+ actual_strategy_type: Optional[Type[LoaderStrategy]],
+ strategy_key: Tuple[Any, ...],
):
if actual_strategy_type is None:
sa_exc.InvalidRequestError.__init__(
@@ -174,6 +182,7 @@ class LoaderStrategyException(sa_exc.InvalidRequestError):
% (strategy_key, requesting_property),
)
else:
+ assert applies_to is not None
sa_exc.InvalidRequestError.__init__(
self,
'Can\'t apply "%s" strategy to property "%s", '
@@ -188,7 +197,8 @@ class LoaderStrategyException(sa_exc.InvalidRequestError):
)
-def _safe_cls_name(cls):
+def _safe_cls_name(cls: Type[Any]) -> str:
+ cls_name: Optional[str]
try:
cls_name = ".".join((cls.__module__, cls.__name__))
except AttributeError:
@@ -199,7 +209,7 @@ def _safe_cls_name(cls):
@util.preload_module("sqlalchemy.orm.base")
-def _default_unmapped(cls) -> Optional[str]:
+def _default_unmapped(cls: Type[Any]) -> Optional[str]:
base = util.preloaded.orm_base
try:
diff --git a/lib/sqlalchemy/orm/identity.py b/lib/sqlalchemy/orm/identity.py
index d13265c56..63b131a78 100644
--- a/lib/sqlalchemy/orm/identity.py
+++ b/lib/sqlalchemy/orm/identity.py
@@ -8,6 +8,7 @@
from __future__ import annotations
from typing import Any
+from typing import cast
from typing import Dict
from typing import Iterable
from typing import Iterator
@@ -15,6 +16,7 @@ from typing import List
from typing import NoReturn
from typing import Optional
from typing import Set
+from typing import Tuple
from typing import TYPE_CHECKING
from typing import TypeVar
import weakref
@@ -66,7 +68,7 @@ class IdentityMap:
) -> Optional[_O]:
raise NotImplementedError()
- def keys(self):
+ def keys(self) -> Iterable[_IdentityKeyType[Any]]:
return self._dict.keys()
def values(self) -> Iterable[object]:
@@ -117,10 +119,10 @@ class IdentityMap:
class WeakInstanceDict(IdentityMap):
- _dict: Dict[Optional[_IdentityKeyType[Any]], InstanceState[Any]]
+ _dict: Dict[_IdentityKeyType[Any], InstanceState[Any]]
def __getitem__(self, key: _IdentityKeyType[_O]) -> _O:
- state = self._dict[key]
+ state = cast("InstanceState[_O]", self._dict[key])
o = state.obj()
if o is None:
raise KeyError(key)
@@ -140,6 +142,8 @@ class WeakInstanceDict(IdentityMap):
def contains_state(self, state: InstanceState[Any]) -> bool:
if state.key in self._dict:
+ if TYPE_CHECKING:
+ assert state.key is not None
try:
return self._dict[state.key] is state
except KeyError:
@@ -150,15 +154,16 @@ class WeakInstanceDict(IdentityMap):
def replace(
self, state: InstanceState[Any]
) -> Optional[InstanceState[Any]]:
+ assert state.key is not None
if state.key in self._dict:
try:
- existing = self._dict[state.key]
+ existing = existing_non_none = self._dict[state.key]
except KeyError:
# catch gc removed the key after we just checked for it
existing = None
else:
- if existing is not state:
- self._manage_removed_state(existing)
+ if existing_non_none is not state:
+ self._manage_removed_state(existing_non_none)
else:
return None
else:
@@ -170,6 +175,7 @@ class WeakInstanceDict(IdentityMap):
def add(self, state: InstanceState[Any]) -> bool:
key = state.key
+ assert key is not None
# inline of self.__contains__
if key in self._dict:
try:
@@ -206,7 +212,7 @@ class WeakInstanceDict(IdentityMap):
if key not in self._dict:
return default
try:
- state = self._dict[key]
+ state = cast("InstanceState[_O]", self._dict[key])
except KeyError:
# catch gc removed the key after we just checked for it
return default
@@ -216,13 +222,15 @@ class WeakInstanceDict(IdentityMap):
return default
return o
- def items(self) -> List[InstanceState[Any]]:
+ def items(self) -> List[Tuple[_IdentityKeyType[Any], InstanceState[Any]]]:
values = self.all_states()
result = []
for state in values:
value = state.obj()
+ key = state.key
+ assert key is not None
if value is not None:
- result.append((state.key, value))
+ result.append((key, value))
return result
def values(self) -> List[object]:
@@ -244,28 +252,32 @@ class WeakInstanceDict(IdentityMap):
def _fast_discard(self, state: InstanceState[Any]) -> None:
# used by InstanceState for state being
# GC'ed, inlines _managed_removed_state
+ key = state.key
+ assert key is not None
try:
- st = self._dict[state.key]
+ st = self._dict[key]
except KeyError:
# catch gc removed the key after we just checked for it
pass
else:
if st is state:
- self._dict.pop(state.key, None)
+ self._dict.pop(key, None)
def discard(self, state: InstanceState[Any]) -> None:
self.safe_discard(state)
def safe_discard(self, state: InstanceState[Any]) -> None:
- if state.key in self._dict:
+ key = state.key
+ if key in self._dict:
+ assert key is not None
try:
- st = self._dict[state.key]
+ st = self._dict[key]
except KeyError:
# catch gc removed the key after we just checked for it
pass
else:
if st is state:
- self._dict.pop(state.key, None)
+ self._dict.pop(key, None)
self._manage_removed_state(state)
diff --git a/lib/sqlalchemy/orm/instrumentation.py b/lib/sqlalchemy/orm/instrumentation.py
index 85b85215e..4fa61b7ce 100644
--- a/lib/sqlalchemy/orm/instrumentation.py
+++ b/lib/sqlalchemy/orm/instrumentation.py
@@ -66,7 +66,7 @@ from ..util.typing import Protocol
if TYPE_CHECKING:
from ._typing import _RegistryType
from .attributes import AttributeImpl
- from .attributes import InstrumentedAttribute
+ from .attributes import QueryableAttribute
from .collections import _AdaptedCollectionProtocol
from .collections import _CollectionFactoryType
from .decl_base import _MapperConfig
@@ -96,7 +96,7 @@ class _ManagerFactory(Protocol):
class ClassManager(
HasMemoized,
- Dict[str, "InstrumentedAttribute[Any]"],
+ Dict[str, "QueryableAttribute[Any]"],
Generic[_O],
EventTarget,
):
@@ -117,7 +117,14 @@ class ClassManager(
factory: Optional[_ManagerFactory]
declarative_scan: Optional[weakref.ref[_MapperConfig]] = None
- registry: Optional[_RegistryType] = None
+
+ registry: _RegistryType
+
+ if not TYPE_CHECKING:
+ # starts as None during setup
+ registry = None
+
+ class_: Type[_O]
_bases: List[ClassManager[Any]]
@@ -312,7 +319,7 @@ class ClassManager(
else:
return default
- def _attr_has_impl(self, key):
+ def _attr_has_impl(self, key: str) -> bool:
"""Return True if the given attribute is fully initialized.
i.e. has an impl.
@@ -366,7 +373,12 @@ class ClassManager(
def dict_getter(self):
return _default_dict_getter
- def instrument_attribute(self, key, inst, propagated=False):
+ def instrument_attribute(
+ self,
+ key: str,
+ inst: QueryableAttribute[Any],
+ propagated: bool = False,
+ ) -> None:
if propagated:
if key in self.local_attrs:
return # don't override local attr with inherited attr
@@ -429,7 +441,7 @@ class ClassManager(
delattr(self.class_, self.MANAGER_ATTR)
def install_descriptor(
- self, key: str, inst: InstrumentedAttribute[Any]
+ self, key: str, inst: QueryableAttribute[Any]
) -> None:
if key in (self.STATE_ATTR, self.MANAGER_ATTR):
raise KeyError(
@@ -490,7 +502,11 @@ class ClassManager(
# InstanceState management
def new_instance(self, state: Optional[InstanceState[_O]] = None) -> _O:
- instance = self.class_.__new__(self.class_)
+ # here, we would prefer _O to be bound to "object"
+ # so that mypy sees that __new__ is present. currently
+ # it's bound to Any as there were other problems not having
+ # it that way but these can be revisited
+ instance = self.class_.__new__(self.class_) # type: ignore
if state is None:
state = self._state_constructor(instance, self)
self._state_setter(instance, state)
diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py
index c9c54c1b0..b5569ce06 100644
--- a/lib/sqlalchemy/orm/interfaces.py
+++ b/lib/sqlalchemy/orm/interfaces.py
@@ -4,7 +4,6 @@
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: allow-untyped-defs, allow-untyped-calls
"""
@@ -33,6 +32,7 @@ from typing import Sequence
from typing import Set
from typing import Tuple
from typing import Type
+from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
@@ -48,6 +48,7 @@ from .base import MANYTOMANY as MANYTOMANY # noqa: F401
from .base import MANYTOONE as MANYTOONE # noqa: F401
from .base import NotExtension as NotExtension # noqa: F401
from .base import ONETOMANY as ONETOMANY # noqa: F401
+from .base import RelationshipDirection as RelationshipDirection # noqa: F401
from .base import SQLORMOperations
from .. import ColumnElement
from .. import inspection
@@ -59,7 +60,7 @@ from ..sql.base import ExecutableOption
from ..sql.cache_key import HasCacheKey
from ..sql.schema import Column
from ..sql.type_api import TypeEngine
-from ..util.typing import DescriptorReference
+from ..util.typing import RODescriptorReference
from ..util.typing import TypedDict
if typing.TYPE_CHECKING:
@@ -75,13 +76,11 @@ if typing.TYPE_CHECKING:
from .loading import _PopulatorDict
from .mapper import Mapper
from .path_registry import AbstractEntityRegistry
- from .path_registry import PathRegistry
from .query import Query
from .session import Session
from .state import InstanceState
from .strategy_options import _LoadElement
from .util import AliasedInsp
- from .util import CascadeOptions
from .util import ORMAdapter
from ..engine.result import Result
from ..sql._typing import _ColumnExpressionArgument
@@ -89,8 +88,10 @@ if typing.TYPE_CHECKING:
from ..sql._typing import _DMLColumnArgument
from ..sql._typing import _InfoType
from ..sql.operators import OperatorType
- from ..sql.util import ColumnAdapter
from ..sql.visitors import _TraverseInternalsType
+ from ..util.typing import _AnnotationScanType
+
+_StrategyKey = Tuple[Any, ...]
_T = TypeVar("_T", bound=Any)
@@ -104,7 +105,9 @@ class ORMStatementRole(roles.StatementRole):
)
-class ORMColumnsClauseRole(roles.TypedColumnsClauseRole[_T]):
+class ORMColumnsClauseRole(
+ roles.ColumnsClauseRole, roles.TypedColumnsClauseRole[_T]
+):
__slots__ = ()
_role_name = "ORM mapped entity, aliased entity, or Column expression"
@@ -137,8 +140,8 @@ class _IntrospectsAnnotations:
registry: RegistryType,
cls: Type[Any],
key: str,
- annotation: Optional[Type[Any]],
- is_dataclass_field: Optional[bool],
+ annotation: Optional[_AnnotationScanType],
+ is_dataclass_field: bool,
) -> None:
"""Perform class-specific initializaton at early declarative scanning
time.
@@ -199,6 +202,7 @@ class MapperProperty(
"parent",
"key",
"info",
+ "doc",
)
_cache_key_traversal: _TraverseInternalsType = [
@@ -206,14 +210,8 @@ class MapperProperty(
("key", visitors.ExtendedInternalTraversal.dp_string),
]
- cascade: Optional[CascadeOptions] = None
- """The set of 'cascade' attribute names.
-
- This collection is checked before the 'cascade_iterator' method is called.
-
- The collection typically only applies to a Relationship.
-
- """
+ if not TYPE_CHECKING:
+ cascade = None
is_property = True
"""Part of the InspectionAttr interface; states this object is a
@@ -240,6 +238,9 @@ class MapperProperty(
"""
+ doc: Optional[str]
+ """optional documentation string"""
+
def _memoized_attr_info(self) -> _InfoType:
"""Info dictionary associated with the object, allowing user-defined
data to be associated with this :class:`.InspectionAttr`.
@@ -268,8 +269,8 @@ class MapperProperty(
self,
context: ORMCompileState,
query_entity: _MapperEntity,
- path: PathRegistry,
- adapter: Optional[ColumnAdapter],
+ path: AbstractEntityRegistry,
+ adapter: Optional[ORMAdapter],
**kwargs: Any,
) -> None:
"""Called by Query for the purposes of constructing a SQL statement.
@@ -284,10 +285,10 @@ class MapperProperty(
self,
context: ORMCompileState,
query_entity: _MapperEntity,
- path: PathRegistry,
+ path: AbstractEntityRegistry,
mapper: Mapper[Any],
result: Result[Any],
- adapter: Optional[ColumnAdapter],
+ adapter: Optional[ORMAdapter],
populators: _PopulatorDict,
) -> None:
"""Produce row processing functions and append to the given
@@ -421,7 +422,7 @@ class MapperProperty(
dest_state: InstanceState[Any],
dest_dict: _InstanceDict,
load: bool,
- _recursive: Set[InstanceState[Any]],
+ _recursive: Dict[Any, object],
_resolve_conflict_map: Dict[_IdentityKeyType[Any], object],
) -> None:
"""Merge the attribute represented by this ``MapperProperty``
@@ -526,7 +527,7 @@ class PropComparator(SQLORMOperations[_T]):
_parententity: _InternalEntityType[Any]
_adapt_to_entity: Optional[AliasedInsp[Any]]
- prop: DescriptorReference[MapperProperty[_T]]
+ prop: RODescriptorReference[MapperProperty[_T]]
def __init__(
self,
@@ -539,7 +540,7 @@ class PropComparator(SQLORMOperations[_T]):
self._adapt_to_entity = adapt_to_entity
@util.non_memoized_property
- def property(self) -> Optional[MapperProperty[_T]]:
+ def property(self) -> MapperProperty[_T]:
"""Return the :class:`.MapperProperty` associated with this
:class:`.PropComparator`.
@@ -589,7 +590,7 @@ class PropComparator(SQLORMOperations[_T]):
return self.prop.comparator._criterion_exists(criterion, **kwargs)
@util.ro_non_memoized_property
- def adapter(self) -> Optional[_ORMAdapterProto[_T]]:
+ def adapter(self) -> Optional[_ORMAdapterProto]:
"""Produce a callable that adapts column expressions
to suit an aliased version of this comparator.
@@ -597,7 +598,7 @@ class PropComparator(SQLORMOperations[_T]):
if self._adapt_to_entity is None:
return None
else:
- return self._adapt_to_entity._adapt_element
+ return self._adapt_to_entity._orm_adapt_element
@util.ro_non_memoized_property
def info(self) -> _InfoType:
@@ -631,7 +632,7 @@ class PropComparator(SQLORMOperations[_T]):
) -> ColumnElement[Any]:
...
- def of_type(self, class_: _EntityType[Any]) -> PropComparator[_T]:
+ def of_type(self, class_: _EntityType[_T]) -> PropComparator[_T]:
r"""Redefine this object in terms of a polymorphic subclass,
:func:`_orm.with_polymorphic` construct, or :func:`_orm.aliased`
construct.
@@ -763,9 +764,9 @@ class StrategizedProperty(MapperProperty[_T]):
inherit_cache = True
strategy_wildcard_key: ClassVar[str]
- strategy_key: Tuple[Any, ...]
+ strategy_key: _StrategyKey
- _strategies: Dict[Tuple[Any, ...], LoaderStrategy]
+ _strategies: Dict[_StrategyKey, LoaderStrategy]
def _memoized_attr__wildcard_token(self) -> Tuple[str]:
return (
@@ -808,7 +809,7 @@ class StrategizedProperty(MapperProperty[_T]):
return load
- def _get_strategy(self, key: Tuple[Any, ...]) -> LoaderStrategy:
+ def _get_strategy(self, key: _StrategyKey) -> LoaderStrategy:
try:
return self._strategies[key]
except KeyError:
@@ -822,7 +823,14 @@ class StrategizedProperty(MapperProperty[_T]):
self._strategies[key] = strategy = cls(self, key)
return strategy
- def setup(self, context, query_entity, path, adapter, **kwargs):
+ def setup(
+ self,
+ context: ORMCompileState,
+ query_entity: _MapperEntity,
+ path: AbstractEntityRegistry,
+ adapter: Optional[ORMAdapter],
+ **kwargs: Any,
+ ) -> None:
loader = self._get_context_loader(context, path)
if loader and loader.strategy:
strat = self._get_strategy(loader.strategy)
@@ -833,8 +841,15 @@ class StrategizedProperty(MapperProperty[_T]):
)
def create_row_processor(
- self, context, query_entity, path, mapper, result, adapter, populators
- ):
+ self,
+ context: ORMCompileState,
+ query_entity: _MapperEntity,
+ path: AbstractEntityRegistry,
+ mapper: Mapper[Any],
+ result: Result[Any],
+ adapter: Optional[ORMAdapter],
+ populators: _PopulatorDict,
+ ) -> None:
loader = self._get_context_loader(context, path)
if loader and loader.strategy:
strat = self._get_strategy(loader.strategy)
@@ -851,11 +866,11 @@ class StrategizedProperty(MapperProperty[_T]):
populators,
)
- def do_init(self):
+ def do_init(self) -> None:
self._strategies = {}
self.strategy = self._get_strategy(self.strategy_key)
- def post_instrument_class(self, mapper):
+ def post_instrument_class(self, mapper: Mapper[Any]) -> None:
if (
not self.parent.non_primary
and not mapper.class_manager._attr_has_impl(self.key)
@@ -863,7 +878,7 @@ class StrategizedProperty(MapperProperty[_T]):
self.strategy.init_class_attribute(mapper)
_all_strategies: collections.defaultdict[
- Type[Any], Dict[Tuple[Any, ...], Type[LoaderStrategy]]
+ Type[MapperProperty[Any]], Dict[_StrategyKey, Type[LoaderStrategy]]
] = collections.defaultdict(dict)
@classmethod
@@ -888,6 +903,8 @@ class StrategizedProperty(MapperProperty[_T]):
for prop_cls in cls.__mro__:
if prop_cls in cls._all_strategies:
+ if TYPE_CHECKING:
+ assert issubclass(prop_cls, MapperProperty)
strategies = cls._all_strategies[prop_cls]
try:
return strategies[key]
@@ -976,8 +993,8 @@ class CompileStateOption(HasCacheKey, ORMOption):
_is_compile_state = True
- def process_compile_state(self, compile_state):
- """Apply a modification to a given :class:`.CompileState`.
+ def process_compile_state(self, compile_state: ORMCompileState) -> None:
+ """Apply a modification to a given :class:`.ORMCompileState`.
This method is part of the implementation of a particular
:class:`.CompileStateOption` and is only invoked internally
@@ -986,9 +1003,11 @@ class CompileStateOption(HasCacheKey, ORMOption):
"""
def process_compile_state_replaced_entities(
- self, compile_state, mapper_entities
- ):
- """Apply a modification to a given :class:`.CompileState`,
+ self,
+ compile_state: ORMCompileState,
+ mapper_entities: Sequence[_MapperEntity],
+ ) -> None:
+ """Apply a modification to a given :class:`.ORMCompileState`,
given entities that were replaced by with_only_columns() or
with_entities().
@@ -1011,8 +1030,10 @@ class LoaderOption(CompileStateOption):
__slots__ = ()
def process_compile_state_replaced_entities(
- self, compile_state, mapper_entities
- ):
+ self,
+ compile_state: ORMCompileState,
+ mapper_entities: Sequence[_MapperEntity],
+ ) -> None:
self.process_compile_state(compile_state)
@@ -1028,7 +1049,7 @@ class CriteriaOption(CompileStateOption):
_is_criteria_option = True
- def get_global_criteria(self, attributes):
+ def get_global_criteria(self, attributes: Dict[str, Any]) -> None:
"""update additional entity criteria options in the given
attributes dictionary.
@@ -1054,7 +1075,7 @@ class UserDefinedOption(ORMOption):
"""
- def __init__(self, payload=None):
+ def __init__(self, payload: Optional[Any] = None):
self.payload = payload
@@ -1132,10 +1153,10 @@ class LoaderStrategy:
"strategy_opts",
)
- _strategy_keys: ClassVar[List[Tuple[Any, ...]]]
+ _strategy_keys: ClassVar[List[_StrategyKey]]
def __init__(
- self, parent: MapperProperty[Any], strategy_key: Tuple[Any, ...]
+ self, parent: MapperProperty[Any], strategy_key: _StrategyKey
):
self.parent_property = parent
self.is_class_level = False
@@ -1186,5 +1207,5 @@ class LoaderStrategy:
"""
- def __str__(self):
+ def __str__(self) -> str:
return str(self.parent_property)
diff --git a/lib/sqlalchemy/orm/loading.py b/lib/sqlalchemy/orm/loading.py
index 75887367e..1a5ea5fe6 100644
--- a/lib/sqlalchemy/orm/loading.py
+++ b/lib/sqlalchemy/orm/loading.py
@@ -54,11 +54,15 @@ from ..sql.selectable import SelectState
if TYPE_CHECKING:
from ._typing import _IdentityKeyType
from .base import LoaderCallableStatus
+ from .context import QueryContext
from .interfaces import ORMOption
from .mapper import Mapper
+ from .query import Query
from .session import Session
from .state import InstanceState
+ from ..engine.cursor import CursorResult
from ..engine.interfaces import _ExecuteOptions
+ from ..engine.result import Result
from ..sql import Select
_T = TypeVar("_T", bound=Any)
@@ -69,7 +73,7 @@ _new_runid = util.counter()
_PopulatorDict = Dict[str, List[Tuple[str, Any]]]
-def instances(cursor, context):
+def instances(cursor: CursorResult[Any], context: QueryContext) -> Result[Any]:
"""Return a :class:`.Result` given an ORM query context.
:param cursor: a :class:`.CursorResult`, generated by a statement
@@ -152,7 +156,7 @@ def instances(cursor, context):
unique_filters = [
_no_unique
if context.yield_per
- else _not_hashable(ent.column.type)
+ else _not_hashable(ent.column.type) # type: ignore
if (not ent.use_id_for_hash and ent._non_hashable_value)
else id
if ent.use_id_for_hash
@@ -164,7 +168,7 @@ def instances(cursor, context):
labels, extra, _unique_filters=unique_filters
)
- def chunks(size):
+ def chunks(size): # type: ignore
while True:
yield_per = size
@@ -302,7 +306,11 @@ def merge_frozen_result(session, statement, frozen_result, load=True):
"is superseded by the :func:`_orm.merge_frozen_result` function.",
)
@util.preload_module("sqlalchemy.orm.context")
-def merge_result(query, iterator, load=True):
+def merge_result(
+ query: Query[Any],
+ iterator: Union[FrozenResult, Iterable[Sequence[Any]], Iterable[object]],
+ load: bool = True,
+) -> Union[FrozenResult, Iterable[Any]]:
"""Merge a result into the given :class:`.Query` object's Session.
See :meth:`_orm.Query.merge_result` for top-level documentation on this
@@ -375,7 +383,7 @@ def merge_result(query, iterator, load=True):
result.append(keyed_tuple(newrow))
if frozen_result:
- return frozen_result.with_data(result)
+ return frozen_result.with_new_rows(result)
else:
return iter(result)
finally:
diff --git a/lib/sqlalchemy/orm/mapped_collection.py b/lib/sqlalchemy/orm/mapped_collection.py
index 4324a000d..d1057ca5f 100644
--- a/lib/sqlalchemy/orm/mapped_collection.py
+++ b/lib/sqlalchemy/orm/mapped_collection.py
@@ -4,6 +4,7 @@
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
+# mypy: ignore-errors
from __future__ import annotations
diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py
index 337a7178b..2d3bceb92 100644
--- a/lib/sqlalchemy/orm/mapper.py
+++ b/lib/sqlalchemy/orm/mapper.py
@@ -80,6 +80,7 @@ from ..sql import roles
from ..sql import util as sql_util
from ..sql import visitors
from ..sql.cache_key import MemoizedHasCacheKey
+from ..sql.elements import KeyedColumnElement
from ..sql.schema import Table
from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
from ..util import HasMemoized
@@ -108,7 +109,6 @@ if TYPE_CHECKING:
from ..sql.base import ReadOnlyColumnCollection
from ..sql.elements import ColumnClause
from ..sql.elements import ColumnElement
- from ..sql.elements import KeyedColumnElement
from ..sql.schema import Column
from ..sql.selectable import FromClause
from ..sql.util import ColumnAdapter
@@ -182,6 +182,7 @@ class Mapper(
dispatch: dispatcher[Mapper[_O]]
_dispose_called = False
+ _configure_failed: Any = False
_ready_for_configure = False
@util.deprecated_params(
@@ -710,8 +711,11 @@ class Mapper(
self.batch = batch
self.eager_defaults = eager_defaults
self.column_prefix = column_prefix
- self.polymorphic_on = (
- coercions.expect(
+
+ # interim - polymorphic_on is further refined in
+ # _configure_polymorphic_setter
+ self.polymorphic_on = ( # type: ignore
+ coercions.expect( # type: ignore
roles.ColumnArgumentOrKeyRole,
polymorphic_on,
argname="polymorphic_on",
@@ -1832,12 +1836,22 @@ class Mapper(
)
@util.preload_module("sqlalchemy.orm.descriptor_props")
- def _configure_property(self, key, prop, init=True, setparent=True):
+ def _configure_property(
+ self,
+ key: str,
+ prop_arg: Union[KeyedColumnElement[Any], MapperProperty[Any]],
+ init: bool = True,
+ setparent: bool = True,
+ ) -> MapperProperty[Any]:
descriptor_props = util.preloaded.orm_descriptor_props
- self._log("_configure_property(%s, %s)", key, prop.__class__.__name__)
+ self._log(
+ "_configure_property(%s, %s)", key, prop_arg.__class__.__name__
+ )
- if not isinstance(prop, MapperProperty):
- prop = self._property_from_column(key, prop)
+ if not isinstance(prop_arg, MapperProperty):
+ prop = self._property_from_column(key, prop_arg)
+ else:
+ prop = prop_arg
if isinstance(prop, properties.ColumnProperty):
col = self.persist_selectable.corresponding_column(prop.columns[0])
@@ -1950,18 +1964,23 @@ class Mapper(
if self.configured:
self._expire_memoizations()
+ return prop
+
@util.preload_module("sqlalchemy.orm.descriptor_props")
- def _property_from_column(self, key, prop):
+ def _property_from_column(
+ self,
+ key: str,
+ prop_arg: Union[KeyedColumnElement[Any], MapperProperty[Any]],
+ ) -> MapperProperty[Any]:
"""generate/update a :class:`.ColumnProperty` given a
:class:`_schema.Column` object."""
descriptor_props = util.preloaded.orm_descriptor_props
# we were passed a Column or a list of Columns;
# generate a properties.ColumnProperty
- columns = util.to_list(prop)
+ columns = util.to_list(prop_arg)
column = columns[0]
- assert isinstance(column, expression.ColumnElement)
- prop = self._props.get(key, None)
+ prop = self._props.get(key)
if isinstance(prop, properties.ColumnProperty):
if (
@@ -2033,11 +2052,11 @@ class Mapper(
"columns get mapped." % (key, self, column.key, prop)
)
- def _check_configure(self):
+ def _check_configure(self) -> None:
if self.registry._new_mappers:
_configure_registries({self.registry}, cascade=True)
- def _post_configure_properties(self):
+ def _post_configure_properties(self) -> None:
"""Call the ``init()`` method on all ``MapperProperties``
attached to this mapper.
@@ -2068,7 +2087,9 @@ class Mapper(
for key, value in dict_of_properties.items():
self.add_property(key, value)
- def add_property(self, key, prop):
+ def add_property(
+ self, key: str, prop: Union[Column[Any], MapperProperty[Any]]
+ ) -> None:
"""Add an individual MapperProperty to this mapper.
If the mapper has not been configured yet, just adds the
@@ -2077,15 +2098,16 @@ class Mapper(
the given MapperProperty is configured immediately.
"""
+ prop = self._configure_property(key, prop, init=self.configured)
+ assert isinstance(prop, MapperProperty)
self._init_properties[key] = prop
- self._configure_property(key, prop, init=self.configured)
- def _expire_memoizations(self):
+ def _expire_memoizations(self) -> None:
for mapper in self.iterate_to_root():
mapper._reset_memoizations()
@property
- def _log_desc(self):
+ def _log_desc(self) -> str:
return (
"("
+ self.class_.__name__
@@ -2099,16 +2121,16 @@ class Mapper(
+ ")"
)
- def _log(self, msg, *args):
+ def _log(self, msg: str, *args: Any) -> None:
self.logger.info("%s " + msg, *((self._log_desc,) + args))
- def _log_debug(self, msg, *args):
+ def _log_debug(self, msg: str, *args: Any) -> None:
self.logger.debug("%s " + msg, *((self._log_desc,) + args))
- def __repr__(self):
+ def __repr__(self) -> str:
return "<Mapper at 0x%x; %s>" % (id(self), self.class_.__name__)
- def __str__(self):
+ def __str__(self) -> str:
return "Mapper[%s%s(%s)]" % (
self.class_.__name__,
self.non_primary and " (non-primary)" or "",
@@ -2155,7 +2177,9 @@ class Mapper(
"Mapper '%s' has no property '%s'" % (self, key)
) from err
- def get_property_by_column(self, column):
+ def get_property_by_column(
+ self, column: ColumnElement[_T]
+ ) -> MapperProperty[_T]:
"""Given a :class:`_schema.Column` object, return the
:class:`.MapperProperty` which maps this column."""
@@ -2795,7 +2819,7 @@ class Mapper(
return result
- def _is_userland_descriptor(self, assigned_name, obj):
+ def _is_userland_descriptor(self, assigned_name: str, obj: Any) -> bool:
if isinstance(
obj,
(
@@ -3603,7 +3627,9 @@ def configure_mappers():
_configure_registries(_all_registries(), cascade=True)
-def _configure_registries(registries, cascade):
+def _configure_registries(
+ registries: Set[_RegistryType], cascade: bool
+) -> None:
for reg in registries:
if reg._new_mappers:
break
@@ -3637,7 +3663,9 @@ def _configure_registries(registries, cascade):
@util.preload_module("sqlalchemy.orm.decl_api")
-def _do_configure_registries(registries, cascade):
+def _do_configure_registries(
+ registries: Set[_RegistryType], cascade: bool
+) -> None:
registry = util.preloaded.orm_decl_api.registry
@@ -3688,7 +3716,7 @@ def _do_configure_registries(registries, cascade):
@util.preload_module("sqlalchemy.orm.decl_api")
-def _dispose_registries(registries, cascade):
+def _dispose_registries(registries: Set[_RegistryType], cascade: bool) -> None:
registry = util.preloaded.orm_decl_api.registry
diff --git a/lib/sqlalchemy/orm/path_registry.py b/lib/sqlalchemy/orm/path_registry.py
index 361cea975..36c14a672 100644
--- a/lib/sqlalchemy/orm/path_registry.py
+++ b/lib/sqlalchemy/orm/path_registry.py
@@ -42,6 +42,7 @@ if TYPE_CHECKING:
from ..sql.cache_key import _CacheKeyTraversalType
from ..sql.elements import BindParameter
from ..sql.visitors import anon_map
+ from ..util.typing import _LiteralStar
from ..util.typing import TypeGuard
def is_root(path: PathRegistry) -> TypeGuard[RootRegistry]:
@@ -80,7 +81,7 @@ def _unreduce_path(path: _SerializedPath) -> PathRegistry:
return PathRegistry.deserialize(path)
-_WILDCARD_TOKEN = "*"
+_WILDCARD_TOKEN: _LiteralStar = "*"
_DEFAULT_TOKEN = "_sa_default"
@@ -115,6 +116,7 @@ class PathRegistry(HasCacheKey):
is_token = False
is_root = False
has_entity = False
+ is_property = False
is_entity = False
path: _PathRepresentation
@@ -175,7 +177,40 @@ class PathRegistry(HasCacheKey):
def __hash__(self) -> int:
return id(self)
- def __getitem__(self, key: Any) -> PathRegistry:
+ @overload
+ def __getitem__(self, entity: str) -> TokenRegistry:
+ ...
+
+ @overload
+ def __getitem__(self, entity: int) -> _PathElementType:
+ ...
+
+ @overload
+ def __getitem__(self, entity: slice) -> _PathRepresentation:
+ ...
+
+ @overload
+ def __getitem__(
+ self, entity: _InternalEntityType[Any]
+ ) -> AbstractEntityRegistry:
+ ...
+
+ @overload
+ def __getitem__(self, entity: MapperProperty[Any]) -> PropRegistry:
+ ...
+
+ def __getitem__(
+ self,
+ entity: Union[
+ str, int, slice, _InternalEntityType[Any], MapperProperty[Any]
+ ],
+ ) -> Union[
+ TokenRegistry,
+ _PathElementType,
+ _PathRepresentation,
+ PropRegistry,
+ AbstractEntityRegistry,
+ ]:
raise NotImplementedError()
# TODO: what are we using this for?
@@ -343,18 +378,8 @@ class RootRegistry(CreatesToken):
is_root = True
is_unnatural = False
- @overload
- def __getitem__(self, entity: str) -> TokenRegistry:
- ...
-
- @overload
- def __getitem__(
- self, entity: _InternalEntityType[Any]
- ) -> AbstractEntityRegistry:
- ...
-
- def __getitem__(
- self, entity: Union[str, _InternalEntityType[Any]]
+ def _getitem(
+ self, entity: Any
) -> Union[TokenRegistry, AbstractEntityRegistry]:
if entity in PathToken._intern:
if TYPE_CHECKING:
@@ -368,6 +393,9 @@ class RootRegistry(CreatesToken):
f"invalid argument for RootRegistry.__getitem__: {entity}"
)
+ if not TYPE_CHECKING:
+ __getitem__ = _getitem
+
PathRegistry.root = RootRegistry()
@@ -441,12 +469,15 @@ class TokenRegistry(PathRegistry):
else:
yield self
- def __getitem__(self, entity: Any) -> Any:
+ def _getitem(self, entity: Any) -> Any:
try:
return self.path[entity]
except TypeError as err:
raise IndexError(f"{entity}") from err
+ if not TYPE_CHECKING:
+ __getitem__ = _getitem
+
class PropRegistry(PathRegistry):
__slots__ = (
@@ -463,6 +494,7 @@ class PropRegistry(PathRegistry):
"is_unnatural",
)
inherit_cache = True
+ is_property = True
prop: MapperProperty[Any]
mapper: Optional[Mapper[Any]]
@@ -557,21 +589,7 @@ class PropRegistry(PathRegistry):
assert self.entity is not None
return self[self.entity]
- @overload
- def __getitem__(self, entity: slice) -> _PathRepresentation:
- ...
-
- @overload
- def __getitem__(self, entity: int) -> _PathElementType:
- ...
-
- @overload
- def __getitem__(
- self, entity: _InternalEntityType[Any]
- ) -> AbstractEntityRegistry:
- ...
-
- def __getitem__(
+ def _getitem(
self, entity: Union[int, slice, _InternalEntityType[Any]]
) -> Union[AbstractEntityRegistry, _PathElementType, _PathRepresentation]:
if isinstance(entity, (int, slice)):
@@ -579,6 +597,9 @@ class PropRegistry(PathRegistry):
else:
return SlotsEntityRegistry(self, entity)
+ if not TYPE_CHECKING:
+ __getitem__ = _getitem
+
class AbstractEntityRegistry(CreatesToken):
__slots__ = (
@@ -643,6 +664,10 @@ class AbstractEntityRegistry(CreatesToken):
self.natural_path = self.path
@property
+ def root_entity(self) -> _InternalEntityType[Any]:
+ return cast("_InternalEntityType[Any]", self.path[0])
+
+ @property
def entity_path(self) -> PathRegistry:
return self
@@ -653,23 +678,7 @@ class AbstractEntityRegistry(CreatesToken):
def __bool__(self) -> bool:
return True
- @overload
- def __getitem__(self, entity: MapperProperty[Any]) -> PropRegistry:
- ...
-
- @overload
- def __getitem__(self, entity: str) -> TokenRegistry:
- ...
-
- @overload
- def __getitem__(self, entity: int) -> _PathElementType:
- ...
-
- @overload
- def __getitem__(self, entity: slice) -> _PathRepresentation:
- ...
-
- def __getitem__(
+ def _getitem(
self, entity: Any
) -> Union[_PathElementType, _PathRepresentation, PathRegistry]:
if isinstance(entity, (int, slice)):
@@ -679,6 +688,9 @@ class AbstractEntityRegistry(CreatesToken):
else:
return PropRegistry(self, entity)
+ if not TYPE_CHECKING:
+ __getitem__ = _getitem
+
class SlotsEntityRegistry(AbstractEntityRegistry):
# for aliased class, return lightweight, no-cycles created
@@ -715,10 +727,28 @@ class CachingEntityRegistry(AbstractEntityRegistry):
def pop(self, key: Any, default: Any) -> Any:
return self._cache.pop(key, default)
- def __getitem__(self, entity: Any) -> Any:
+ def _getitem(self, entity: Any) -> Any:
if isinstance(entity, (int, slice)):
return self.path[entity]
elif isinstance(entity, PathToken):
return TokenRegistry(self, entity)
else:
return self._cache[entity]
+
+ if not TYPE_CHECKING:
+ __getitem__ = _getitem
+
+
+if TYPE_CHECKING:
+
+ def path_is_entity(
+ path: PathRegistry,
+ ) -> TypeGuard[AbstractEntityRegistry]:
+ ...
+
+ def path_is_property(path: PathRegistry) -> TypeGuard[PropRegistry]:
+ ...
+
+else:
+ path_is_entity = operator.attrgetter("is_entity")
+ path_is_property = operator.attrgetter("is_property")
diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py
index 0ca0559b4..911617d6d 100644
--- a/lib/sqlalchemy/orm/properties.py
+++ b/lib/sqlalchemy/orm/properties.py
@@ -16,8 +16,10 @@ from __future__ import annotations
from typing import Any
from typing import cast
+from typing import Dict
from typing import List
from typing import Optional
+from typing import Sequence
from typing import Set
from typing import Type
from typing import TYPE_CHECKING
@@ -25,7 +27,6 @@ from typing import TypeVar
from . import attributes
from . import strategy_options
-from .base import SQLCoreOperations
from .descriptor_props import Composite
from .descriptor_props import ConcreteInheritedProperty
from .descriptor_props import Synonym
@@ -44,20 +45,34 @@ from .. import util
from ..sql import coercions
from ..sql import roles
from ..sql import sqltypes
+from ..sql.elements import SQLCoreOperations
from ..sql.schema import Column
from ..sql.schema import SchemaConst
from ..util.typing import de_optionalize_union_types
from ..util.typing import de_stringify_annotation
from ..util.typing import is_fwd_ref
from ..util.typing import NoneType
+from ..util.typing import Self
if TYPE_CHECKING:
+ from ._typing import _IdentityKeyType
+ from ._typing import _InstanceDict
from ._typing import _ORMColumnExprArgument
+ from ._typing import _RegistryType
+ from .mapper import Mapper
+ from .session import Session
+ from .state import _InstallLoaderCallableProto
+ from .state import InstanceState
from ..sql._typing import _InfoType
- from ..sql.elements import KeyedColumnElement
+ from ..sql.elements import ColumnElement
+ from ..sql.elements import NamedColumn
+ from ..sql.operators import OperatorType
+ from ..util.typing import _AnnotationScanType
+ from ..util.typing import RODescriptorReference
_T = TypeVar("_T", bound=Any)
_PT = TypeVar("_PT", bound=Any)
+_NC = TypeVar("_NC", bound="NamedColumn[Any]")
__all__ = [
"ColumnProperty",
@@ -85,11 +100,15 @@ class ColumnProperty(
inherit_cache = True
_links_to_entity = False
- columns: List[KeyedColumnElement[Any]]
- _orig_columns: List[KeyedColumnElement[Any]]
+ columns: List[NamedColumn[Any]]
+ _orig_columns: List[NamedColumn[Any]]
_is_polymorphic_discriminator: bool
+ _mapped_by_synonym: Optional[str]
+
+ comparator_factory: Type[PropComparator[_T]]
+
__slots__ = (
"_orig_columns",
"columns",
@@ -100,7 +119,6 @@ class ColumnProperty(
"descriptor",
"active_history",
"expire_on_flush",
- "doc",
"_creation_order",
"_is_polymorphic_discriminator",
"_mapped_by_synonym",
@@ -117,7 +135,7 @@ class ColumnProperty(
group: Optional[str] = None,
deferred: bool = False,
raiseload: bool = False,
- comparator_factory: Optional[Type[PropComparator]] = None,
+ comparator_factory: Optional[Type[PropComparator[_T]]] = None,
descriptor: Optional[Any] = None,
active_history: bool = False,
expire_on_flush: bool = True,
@@ -150,7 +168,7 @@ class ColumnProperty(
self.expire_on_flush = expire_on_flush
if info is not None:
- self.info = info
+ self.info.update(info)
if doc is not None:
self.doc = doc
@@ -173,8 +191,13 @@ class ColumnProperty(
self.strategy_key += (("raiseload", True),)
def declarative_scan(
- self, registry, cls, key, annotation, is_dataclass_field
- ):
+ self,
+ registry: _RegistryType,
+ cls: Type[Any],
+ key: str,
+ annotation: Optional[_AnnotationScanType],
+ is_dataclass_field: bool,
+ ) -> None:
column = self.columns[0]
if column.key is None:
column.key = key
@@ -186,20 +209,23 @@ class ColumnProperty(
return self
@property
- def columns_to_assign(self) -> List[Column]:
+ def columns_to_assign(self) -> List[Column[Any]]:
+ # mypy doesn't care about the isinstance here
return [
- c
+ c # type: ignore
for c in self.columns
if isinstance(c, Column) and c.table is None
]
- def _memoized_attr__renders_in_subqueries(self):
+ def _memoized_attr__renders_in_subqueries(self) -> bool:
return ("deferred", True) not in self.strategy_key or (
- self not in self.parent._readonly_props
+ self not in self.parent._readonly_props # type: ignore
)
@util.preload_module("sqlalchemy.orm.state", "sqlalchemy.orm.strategies")
- def _memoized_attr__deferred_column_loader(self):
+ def _memoized_attr__deferred_column_loader(
+ self,
+ ) -> _InstallLoaderCallableProto[Any]:
state = util.preloaded.orm_state
strategies = util.preloaded.orm_strategies
return state.InstanceState._instance_level_callable_processor(
@@ -209,7 +235,9 @@ class ColumnProperty(
)
@util.preload_module("sqlalchemy.orm.state", "sqlalchemy.orm.strategies")
- def _memoized_attr__raise_column_loader(self):
+ def _memoized_attr__raise_column_loader(
+ self,
+ ) -> _InstallLoaderCallableProto[Any]:
state = util.preloaded.orm_state
strategies = util.preloaded.orm_strategies
return state.InstanceState._instance_level_callable_processor(
@@ -218,7 +246,7 @@ class ColumnProperty(
self.key,
)
- def __clause_element__(self):
+ def __clause_element__(self) -> roles.ColumnsClauseRole:
"""Allow the ColumnProperty to work in expression before it is turned
into an instrumented attribute.
"""
@@ -226,7 +254,7 @@ class ColumnProperty(
return self.expression
@property
- def expression(self):
+ def expression(self) -> roles.ColumnsClauseRole:
"""Return the primary column or expression for this ColumnProperty.
E.g.::
@@ -247,7 +275,7 @@ class ColumnProperty(
"""
return self.columns[0]
- def instrument_class(self, mapper):
+ def instrument_class(self, mapper: Mapper[Any]) -> None:
if not self.instrument:
return
@@ -259,7 +287,7 @@ class ColumnProperty(
doc=self.doc,
)
- def do_init(self):
+ def do_init(self) -> None:
super().do_init()
if len(self.columns) > 1 and set(self.parent.primary_key).issuperset(
@@ -275,32 +303,25 @@ class ColumnProperty(
% (self.parent, self.columns[1], self.columns[0], self.key)
)
- def copy(self):
+ def copy(self) -> ColumnProperty[_T]:
return ColumnProperty(
+ *self.columns,
deferred=self.deferred,
group=self.group,
active_history=self.active_history,
- *self.columns,
- )
-
- def _getcommitted(
- self, state, dict_, column, passive=attributes.PASSIVE_OFF
- ):
- return state.get_impl(self.key).get_committed_value(
- state, dict_, passive=passive
)
def merge(
self,
- session,
- source_state,
- source_dict,
- dest_state,
- dest_dict,
- load,
- _recursive,
- _resolve_conflict_map,
- ):
+ session: Session,
+ source_state: InstanceState[Any],
+ source_dict: _InstanceDict,
+ dest_state: InstanceState[Any],
+ dest_dict: _InstanceDict,
+ load: bool,
+ _recursive: Dict[Any, object],
+ _resolve_conflict_map: Dict[_IdentityKeyType[Any], object],
+ ) -> None:
if not self.instrument:
return
elif self.key in source_dict:
@@ -335,9 +356,13 @@ class ColumnProperty(
"""
- __slots__ = "__clause_element__", "info", "expressions"
+ if not TYPE_CHECKING:
+ # prevent pylance from being clever about slots
+ __slots__ = "__clause_element__", "info", "expressions"
+
+ prop: RODescriptorReference[ColumnProperty[_PT]]
- def _orm_annotate_column(self, column):
+ def _orm_annotate_column(self, column: _NC) -> _NC:
"""annotate and possibly adapt a column to be returned
as the mapped-attribute exposed version of the column.
@@ -351,7 +376,7 @@ class ColumnProperty(
"""
pe = self._parententity
- annotations = {
+ annotations: Dict[str, Any] = {
"entity_namespace": pe,
"parententity": pe,
"parentmapper": pe,
@@ -377,22 +402,29 @@ class ColumnProperty(
{"compile_state_plugin": "orm", "plugin_subject": pe}
)
- def _memoized_method___clause_element__(self):
+ if TYPE_CHECKING:
+
+ def __clause_element__(self) -> NamedColumn[_PT]:
+ ...
+
+ def _memoized_method___clause_element__(
+ self,
+ ) -> NamedColumn[_PT]:
if self.adapter:
return self.adapter(self.prop.columns[0], self.prop.key)
else:
return self._orm_annotate_column(self.prop.columns[0])
- def _memoized_attr_info(self):
+ def _memoized_attr_info(self) -> _InfoType:
"""The .info dictionary for this attribute."""
ce = self.__clause_element__()
try:
- return ce.info
+ return ce.info # type: ignore
except AttributeError:
return self.prop.info
- def _memoized_attr_expressions(self):
+ def _memoized_attr_expressions(self) -> Sequence[NamedColumn[Any]]:
"""The full sequence of columns referenced by this
attribute, adjusted for any aliasing in progress.
@@ -409,21 +441,25 @@ class ColumnProperty(
self._orm_annotate_column(col) for col in self.prop.columns
]
- def _fallback_getattr(self, key):
+ def _fallback_getattr(self, key: str) -> Any:
"""proxy attribute access down to the mapped column.
this allows user-defined comparison methods to be accessed.
"""
return getattr(self.__clause_element__(), key)
- def operate(self, op, *other, **kwargs):
- return op(self.__clause_element__(), *other, **kwargs)
+ def operate(
+ self, op: OperatorType, *other: Any, **kwargs: Any
+ ) -> ColumnElement[Any]:
+ return op(self.__clause_element__(), *other, **kwargs) # type: ignore[return-value] # noqa: E501
- def reverse_operate(self, op, other, **kwargs):
+ def reverse_operate(
+ self, op: OperatorType, other: Any, **kwargs: Any
+ ) -> ColumnElement[Any]:
col = self.__clause_element__()
- return op(col._bind_param(op, other), col, **kwargs)
+ return op(col._bind_param(op, other), col, **kwargs) # type: ignore[return-value] # noqa: E501
- def __str__(self):
+ def __str__(self) -> str:
if not self.parent or not self.key:
return object.__repr__(self)
return str(self.parent.class_.__name__) + "." + self.key
@@ -460,7 +496,7 @@ class MappedColumn(
column: Column[_T]
foreign_keys: Optional[Set[ForeignKey]]
- def __init__(self, *arg, **kw):
+ def __init__(self, *arg: Any, **kw: Any):
self.deferred = kw.pop("deferred", False)
self.column = cast("Column[_T]", Column(*arg, **kw))
self.foreign_keys = self.column.foreign_keys
@@ -470,8 +506,8 @@ class MappedColumn(
)
util.set_creation_order(self)
- def _copy(self, **kw):
- new = self.__class__.__new__(self.__class__)
+ def _copy(self: Self, **kw: Any) -> Self:
+ new = cast(Self, self.__class__.__new__(self.__class__))
new.column = self.column._copy(**kw)
new.deferred = self.deferred
new.foreign_keys = new.column.foreign_keys
@@ -487,22 +523,31 @@ class MappedColumn(
return None
@property
- def columns_to_assign(self) -> List[Column]:
+ def columns_to_assign(self) -> List[Column[Any]]:
return [self.column]
- def __clause_element__(self):
+ def __clause_element__(self) -> Column[_T]:
return self.column
- def operate(self, op, *other, **kwargs):
- return op(self.__clause_element__(), *other, **kwargs)
+ def operate(
+ self, op: OperatorType, *other: Any, **kwargs: Any
+ ) -> ColumnElement[Any]:
+ return op(self.__clause_element__(), *other, **kwargs) # type: ignore[return-value] # noqa: E501
- def reverse_operate(self, op, other, **kwargs):
+ def reverse_operate(
+ self, op: OperatorType, other: Any, **kwargs: Any
+ ) -> ColumnElement[Any]:
col = self.__clause_element__()
- return op(col._bind_param(op, other), col, **kwargs)
+ return op(col._bind_param(op, other), col, **kwargs) # type: ignore[return-value] # noqa: E501
def declarative_scan(
- self, registry, cls, key, annotation, is_dataclass_field
- ):
+ self,
+ registry: _RegistryType,
+ cls: Type[Any],
+ key: str,
+ annotation: Optional[_AnnotationScanType],
+ is_dataclass_field: bool,
+ ) -> None:
column = self.column
if column.key is None:
column.key = key
@@ -526,38 +571,48 @@ class MappedColumn(
@util.preload_module("sqlalchemy.orm.decl_base")
def declarative_scan_for_composite(
- self, registry, cls, key, param_name, param_annotation
- ):
+ self,
+ registry: _RegistryType,
+ cls: Type[Any],
+ key: str,
+ param_name: str,
+ param_annotation: _AnnotationScanType,
+ ) -> None:
decl_base = util.preloaded.orm_decl_base
decl_base._undefer_column_name(param_name, self.column)
self._init_column_for_annotation(cls, registry, param_annotation)
- def _init_column_for_annotation(self, cls, registry, argument):
+ def _init_column_for_annotation(
+ self,
+ cls: Type[Any],
+ registry: _RegistryType,
+ argument: _AnnotationScanType,
+ ) -> None:
sqltype = self.column.type
nullable = False
if hasattr(argument, "__origin__"):
- nullable = NoneType in argument.__args__
+ nullable = NoneType in argument.__args__ # type: ignore
if not self._has_nullable:
self.column.nullable = nullable
if sqltype._isnull and not self.column.foreign_keys:
- sqltype = None
+ new_sqltype = None
our_type = de_optionalize_union_types(argument)
if is_fwd_ref(our_type):
our_type = de_stringify_annotation(cls, our_type)
if registry.type_annotation_map:
- sqltype = registry.type_annotation_map.get(our_type)
- if sqltype is None:
- sqltype = sqltypes._type_map_get(our_type)
+ new_sqltype = registry.type_annotation_map.get(our_type)
+ if new_sqltype is None:
+ new_sqltype = sqltypes._type_map_get(our_type) # type: ignore
- if sqltype is None:
+ if new_sqltype is None:
raise sa_exc.ArgumentError(
f"Could not locate SQLAlchemy Core "
f"type for Python type: {our_type}"
)
- self.column.type = sqltype
+ self.column.type = new_sqltype # type: ignore
diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py
index a60a167ac..419891708 100644
--- a/lib/sqlalchemy/orm/query.py
+++ b/lib/sqlalchemy/orm/query.py
@@ -23,18 +23,23 @@ from __future__ import annotations
import collections.abc as collections_abc
import operator
from typing import Any
+from typing import Callable
+from typing import cast
+from typing import Dict
from typing import Generic
from typing import Iterable
from typing import List
+from typing import Mapping
from typing import Optional
from typing import overload
from typing import Sequence
from typing import Tuple
+from typing import Type
from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
-from . import exc as orm_exc
+from . import attributes
from . import interfaces
from . import loading
from . import util as orm_util
@@ -44,7 +49,6 @@ from .context import _column_descriptions
from .context import _determine_last_joined_entity
from .context import _legacy_filter_by_entity_zero
from .context import FromStatement
-from .context import LABEL_STYLE_LEGACY_ORM
from .context import ORMCompileState
from .context import QueryContext
from .interfaces import ORMColumnDescription
@@ -60,6 +64,8 @@ from .. import sql
from .. import util
from ..engine import Result
from ..engine import Row
+from ..event import dispatcher
+from ..event import EventTarget
from ..sql import coercions
from ..sql import expression
from ..sql import roles
@@ -71,8 +77,10 @@ from ..sql._typing import _TP
from ..sql.annotation import SupportsCloneAnnotations
from ..sql.base import _entity_namespace_key
from ..sql.base import _generative
+from ..sql.base import _NoArg
from ..sql.base import Executable
from ..sql.base import Generative
+from ..sql.elements import BooleanClauseList
from ..sql.expression import Exists
from ..sql.selectable import _MemoizedSelectEntities
from ..sql.selectable import _SelectFromElements
@@ -81,17 +89,31 @@ from ..sql.selectable import HasHints
from ..sql.selectable import HasPrefixes
from ..sql.selectable import HasSuffixes
from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
+from ..sql.selectable import SelectLabelStyle
from ..util.typing import Literal
+from ..util.typing import Self
if TYPE_CHECKING:
from ._typing import _EntityType
+ from ._typing import _ExternalEntityType
+ from ._typing import _InternalEntityType
+ from .mapper import Mapper
+ from .path_registry import PathRegistry
+ from .session import _PKIdentityArgument
from .session import Session
+ from .state import InstanceState
+ from ..engine.cursor import CursorResult
+ from ..engine.interfaces import _ImmutableExecuteOptions
+ from ..engine.result import FrozenResult
from ..engine.result import ScalarResult
from ..sql._typing import _ColumnExpressionArgument
from ..sql._typing import _ColumnsClauseArgument
+ from ..sql._typing import _DMLColumnArgument
+ from ..sql._typing import _JoinTargetArgument
from ..sql._typing import _MAYBE_ENTITY
from ..sql._typing import _no_kw
from ..sql._typing import _NOT_ENTITY
+ from ..sql._typing import _OnClauseArgument
from ..sql._typing import _PropagateAttrsType
from ..sql._typing import _T0
from ..sql._typing import _T1
@@ -102,18 +124,25 @@ if TYPE_CHECKING:
from ..sql._typing import _T6
from ..sql._typing import _T7
from ..sql._typing import _TypedColumnClauseArgument as _TCCA
- from ..sql.roles import TypedColumnsClauseRole
+ from ..sql.base import CacheableOptions
+ from ..sql.base import ExecutableOption
+ from ..sql.elements import ColumnElement
+ from ..sql.elements import Label
+ from ..sql.selectable import _JoinTargetElement
from ..sql.selectable import _SetupJoinsElement
from ..sql.selectable import Alias
+ from ..sql.selectable import CTE
from ..sql.selectable import ExecutableReturnsRows
+ from ..sql.selectable import FromClause
from ..sql.selectable import ScalarSelect
from ..sql.selectable import Subquery
+
__all__ = ["Query", "QueryContext"]
_T = TypeVar("_T", bound=Any)
-SelfQuery = TypeVar("SelfQuery", bound="Query")
+SelfQuery = TypeVar("SelfQuery", bound="Query[Any]")
@inspection._self_inspects
@@ -124,6 +153,7 @@ class Query(
HasPrefixes,
HasSuffixes,
HasHints,
+ EventTarget,
log.Identified,
Generative,
Executable,
@@ -150,40 +180,47 @@ class Query(
"""
# elements that are in Core and can be cached in the same way
- _where_criteria = ()
- _having_criteria = ()
+ _where_criteria: Tuple[ColumnElement[Any], ...] = ()
+ _having_criteria: Tuple[ColumnElement[Any], ...] = ()
- _order_by_clauses = ()
- _group_by_clauses = ()
- _limit_clause = None
- _offset_clause = None
+ _order_by_clauses: Tuple[ColumnElement[Any], ...] = ()
+ _group_by_clauses: Tuple[ColumnElement[Any], ...] = ()
+ _limit_clause: Optional[ColumnElement[Any]] = None
+ _offset_clause: Optional[ColumnElement[Any]] = None
- _distinct = False
- _distinct_on = ()
+ _distinct: bool = False
+ _distinct_on: Tuple[ColumnElement[Any], ...] = ()
- _for_update_arg = None
- _correlate = ()
- _auto_correlate = True
- _from_obj = ()
+ _for_update_arg: Optional[ForUpdateArg] = None
+ _correlate: Tuple[FromClause, ...] = ()
+ _auto_correlate: bool = True
+ _from_obj: Tuple[FromClause, ...] = ()
_setup_joins: Tuple[_SetupJoinsElement, ...] = ()
- _label_style = LABEL_STYLE_LEGACY_ORM
+ _label_style: SelectLabelStyle = SelectLabelStyle.LABEL_STYLE_LEGACY_ORM
_memoized_select_entities = ()
- _compile_options = ORMCompileState.default_compile_options
+ _compile_options: Union[
+ Type[CacheableOptions], CacheableOptions
+ ] = ORMCompileState.default_compile_options
+ _with_options: Tuple[ExecutableOption, ...]
load_options = QueryContext.default_load_options + {
"_legacy_uniquing": True
}
- _params = util.EMPTY_DICT
+ _params: util.immutabledict[str, Any] = util.EMPTY_DICT
# local Query builder state, not needed for
# compilation or execution
_enable_assertions = True
- _statement = None
+ _statement: Optional[ExecutableReturnsRows] = None
+
+ session: Session
+
+ dispatch: dispatcher[Query[_T]]
# mirrors that of ClauseElement, used to propagate the "orm"
# plugin as well as the "subject" of the plugin, e.g. the mapper
@@ -224,14 +261,23 @@ class Query(
"""
- self.session = session
+ # session is usually present. There's one case in subqueryloader
+ # where it stores a Query without a Session and also there are tests
+ # for the query(Entity).with_session(session) API which is likely in
+ # some old recipes, however these are legacy as select() can now be
+ # used.
+ self.session = session # type: ignore
self._set_entities(entities)
- def _set_propagate_attrs(self, values):
- self._propagate_attrs = util.immutabledict(values)
+ def _set_propagate_attrs(
+ self: SelfQuery, values: Mapping[str, Any]
+ ) -> SelfQuery:
+ self._propagate_attrs = util.immutabledict(values) # type: ignore
return self
- def _set_entities(self, entities):
+ def _set_entities(
+ self, entities: Iterable[_ColumnsClauseArgument[Any]]
+ ) -> None:
self._raw_columns = [
coercions.expect(
roles.ColumnsClauseRole,
@@ -242,15 +288,7 @@ class Query(
for ent in util.to_list(entities)
]
- @overload
- def tuples(self: Query[Row[_TP]]) -> Query[_TP]:
- ...
-
- @overload
def tuples(self: Query[_O]) -> Query[Tuple[_O]]:
- ...
-
- def tuples(self) -> Query[Any]:
"""return a tuple-typed form of this :class:`.Query`.
This method invokes the :meth:`.Query.only_return_tuples`
@@ -270,29 +308,27 @@ class Query(
.. versionadded:: 2.0
"""
- return self.only_return_tuples(True)
+ return self.only_return_tuples(True) # type: ignore
- def _entity_from_pre_ent_zero(self):
+ def _entity_from_pre_ent_zero(self) -> Optional[_InternalEntityType[Any]]:
if not self._raw_columns:
return None
ent = self._raw_columns[0]
if "parententity" in ent._annotations:
- return ent._annotations["parententity"]
- elif isinstance(ent, ORMColumnsClauseRole):
- return ent.entity
+ return ent._annotations["parententity"] # type: ignore
elif "bundle" in ent._annotations:
- return ent._annotations["bundle"]
+ return ent._annotations["bundle"] # type: ignore
else:
# label, other SQL expression
for element in visitors.iterate(ent):
if "parententity" in element._annotations:
- return element._annotations["parententity"]
+ return element._annotations["parententity"] # type: ignore # noqa: E501
else:
return None
- def _only_full_mapper_zero(self, methname):
+ def _only_full_mapper_zero(self, methname: str) -> Mapper[Any]:
if (
len(self._raw_columns) != 1
or "parententity" not in self._raw_columns[0]._annotations
@@ -303,9 +339,11 @@ class Query(
"a single mapped class." % methname
)
- return self._raw_columns[0]._annotations["parententity"]
+ return self._raw_columns[0]._annotations["parententity"] # type: ignore # noqa: E501
- def _set_select_from(self, obj, set_base_alias):
+ def _set_select_from(
+ self, obj: Iterable[_FromClauseArgument], set_base_alias: bool
+ ) -> None:
fa = [
coercions.expect(
roles.StrictFromClauseRole,
@@ -320,19 +358,22 @@ class Query(
self._from_obj = tuple(fa)
@_generative
- def _set_lazyload_from(self: SelfQuery, state) -> SelfQuery:
+ def _set_lazyload_from(
+ self: SelfQuery, state: InstanceState[Any]
+ ) -> SelfQuery:
self.load_options += {"_lazy_loaded_from": state}
return self
- def _get_condition(self):
- return self._no_criterion_condition(
- "get", order_by=False, distinct=False
- )
+ def _get_condition(self) -> None:
+ """used by legacy BakedQuery"""
+ self._no_criterion_condition("get", order_by=False, distinct=False)
- def _get_existing_condition(self):
+ def _get_existing_condition(self) -> None:
self._no_criterion_assertion("get", order_by=False, distinct=False)
- def _no_criterion_assertion(self, meth, order_by=True, distinct=True):
+ def _no_criterion_assertion(
+ self, meth: str, order_by: bool = True, distinct: bool = True
+ ) -> None:
if not self._enable_assertions:
return
if (
@@ -351,7 +392,9 @@ class Query(
"Query with existing criterion. " % meth
)
- def _no_criterion_condition(self, meth, order_by=True, distinct=True):
+ def _no_criterion_condition(
+ self, meth: str, order_by: bool = True, distinct: bool = True
+ ) -> None:
self._no_criterion_assertion(meth, order_by, distinct)
self._from_obj = self._setup_joins = ()
@@ -362,7 +405,7 @@ class Query(
self._order_by_clauses = self._group_by_clauses = ()
- def _no_clauseelement_condition(self, meth):
+ def _no_clauseelement_condition(self, meth: str) -> None:
if not self._enable_assertions:
return
if self._order_by_clauses:
@@ -372,7 +415,7 @@ class Query(
)
self._no_criterion_condition(meth)
- def _no_statement_condition(self, meth):
+ def _no_statement_condition(self, meth: str) -> None:
if not self._enable_assertions:
return
if self._statement is not None:
@@ -384,7 +427,7 @@ class Query(
% meth
)
- def _no_limit_offset(self, meth):
+ def _no_limit_offset(self, meth: str) -> None:
if not self._enable_assertions:
return
if self._limit_clause is not None or self._offset_clause is not None:
@@ -395,21 +438,21 @@ class Query(
)
@property
- def _has_row_limiting_clause(self):
+ def _has_row_limiting_clause(self) -> bool:
return (
self._limit_clause is not None or self._offset_clause is not None
)
def _get_options(
- self,
- populate_existing=None,
- version_check=None,
- only_load_props=None,
- refresh_state=None,
- identity_token=None,
- ):
- load_options = {}
- compile_options = {}
+ self: SelfQuery,
+ populate_existing: Optional[bool] = None,
+ version_check: Optional[bool] = None,
+ only_load_props: Optional[Sequence[str]] = None,
+ refresh_state: Optional[InstanceState[Any]] = None,
+ identity_token: Optional[Any] = None,
+ ) -> SelfQuery:
+ load_options: Dict[str, Any] = {}
+ compile_options: Dict[str, Any] = {}
if version_check:
load_options["_version_check"] = version_check
@@ -430,11 +473,18 @@ class Query(
return self
- def _clone(self):
- return self._generate()
+ def _clone(self: Self, **kw: Any) -> Self:
+ return self._generate() # type: ignore
+
+ def _get_select_statement_only(self) -> Select[_T]:
+ if self._statement is not None:
+ raise sa_exc.InvalidRequestError(
+ "Can't call this method on a Query that uses from_statement()"
+ )
+ return cast("Select[_T]", self.statement)
@property
- def statement(self):
+ def statement(self) -> Union[Select[_T], FromStatement[_T]]:
"""The full SELECT statement represented by this Query.
The statement by default will not have disambiguating labels
@@ -474,14 +524,15 @@ class Query(
return stmt
- def _final_statement(self, legacy_query_style=True):
+ def _final_statement(self, legacy_query_style: bool = True) -> Select[Any]:
"""Return the 'final' SELECT statement for this :class:`.Query`.
+ This is used by the testing suite only and is fairly inefficient.
+
This is the Core-only select() that will be rendered by a complete
compilation of this query, and is what .statement used to return
in 1.3.
- This method creates a complete compile state so is fairly expensive.
"""
@@ -489,9 +540,11 @@ class Query(
return q._compile_state(
use_legacy_query_style=legacy_query_style
- ).statement
+ ).statement # type: ignore
- def _statement_20(self, for_statement=False, use_legacy_query_style=True):
+ def _statement_20(
+ self, for_statement: bool = False, use_legacy_query_style: bool = True
+ ) -> Union[Select[_T], FromStatement[_T]]:
# TODO: this event needs to be deprecated, as it currently applies
# only to ORM query and occurs at this spot that is now more
# or less an artificial spot
@@ -500,7 +553,7 @@ class Query(
new_query = fn(self)
if new_query is not None and new_query is not self:
self = new_query
- if not fn._bake_ok:
+ if not fn._bake_ok: # type: ignore
self._compile_options += {"_bake_ok": False}
compile_options = self._compile_options
@@ -509,6 +562,8 @@ class Query(
"_use_legacy_query_style": use_legacy_query_style,
}
+ stmt: Union[Select[_T], FromStatement[_T]]
+
if self._statement is not None:
stmt = FromStatement(self._raw_columns, self._statement)
stmt.__dict__.update(
@@ -541,10 +596,10 @@ class Query(
def subquery(
self,
- name=None,
- with_labels=False,
- reduce_columns=False,
- ):
+ name: Optional[str] = None,
+ with_labels: bool = False,
+ reduce_columns: bool = False,
+ ) -> Subquery:
"""Return the full SELECT statement represented by
this :class:`_query.Query`, embedded within an
:class:`_expression.Alias`.
@@ -571,13 +626,21 @@ class Query(
if with_labels:
q = q.set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL)
- q = q.statement
+ stmt = q._get_select_statement_only()
+
+ if TYPE_CHECKING:
+ assert isinstance(stmt, Select)
if reduce_columns:
- q = q.reduce_columns()
- return q.alias(name=name)
+ stmt = stmt.reduce_columns()
+ return stmt.subquery(name=name)
- def cte(self, name=None, recursive=False, nesting=False):
+ def cte(
+ self,
+ name: Optional[str] = None,
+ recursive: bool = False,
+ nesting: bool = False,
+ ) -> CTE:
r"""Return the full SELECT statement represented by this
:class:`_query.Query` represented as a common table expression (CTE).
@@ -632,11 +695,13 @@ class Query(
:meth:`_expression.HasCTE.cte`
"""
- return self.enable_eagerloads(False).statement.cte(
- name=name, recursive=recursive, nesting=nesting
+ return (
+ self.enable_eagerloads(False)
+ ._get_select_statement_only()
+ .cte(name=name, recursive=recursive, nesting=nesting)
)
- def label(self, name):
+ def label(self, name: Optional[str]) -> Label[Any]:
"""Return the full SELECT statement represented by this
:class:`_query.Query`, converted
to a scalar subquery with a label of the given name.
@@ -645,7 +710,11 @@ class Query(
"""
- return self.enable_eagerloads(False).statement.label(name)
+ return (
+ self.enable_eagerloads(False)
+ ._get_select_statement_only()
+ .label(name)
+ )
@overload
def as_scalar(
@@ -704,10 +773,14 @@ class Query(
"""
- return self.enable_eagerloads(False).statement.scalar_subquery()
+ return (
+ self.enable_eagerloads(False)
+ ._get_select_statement_only()
+ .scalar_subquery()
+ )
@property
- def selectable(self):
+ def selectable(self) -> Union[Select[_T], FromStatement[_T]]:
"""Return the :class:`_expression.Select` object emitted by this
:class:`_query.Query`.
@@ -718,7 +791,7 @@ class Query(
"""
return self.__clause_element__()
- def __clause_element__(self):
+ def __clause_element__(self) -> Union[Select[_T], FromStatement[_T]]:
return (
self._with_compile_options(
_enable_eagerloads=False, _render_for_subquery=True
@@ -759,7 +832,7 @@ class Query(
return self
@property
- def is_single_entity(self):
+ def is_single_entity(self) -> bool:
"""Indicates if this :class:`_query.Query`
returns tuples or single entities.
@@ -785,7 +858,7 @@ class Query(
)
@_generative
- def enable_eagerloads(self: SelfQuery, value) -> SelfQuery:
+ def enable_eagerloads(self: SelfQuery, value: bool) -> SelfQuery:
"""Control whether or not eager joins and subqueries are
rendered.
@@ -804,7 +877,7 @@ class Query(
return self
@_generative
- def _with_compile_options(self: SelfQuery, **opt) -> SelfQuery:
+ def _with_compile_options(self: SelfQuery, **opt: Any) -> SelfQuery:
self._compile_options += opt
return self
@@ -813,13 +886,15 @@ class Query(
alternative="Use set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL) "
"instead.",
)
- def with_labels(self):
- return self.set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL)
+ def with_labels(self: SelfQuery) -> SelfQuery:
+ return self.set_label_style(
+ SelectLabelStyle.LABEL_STYLE_TABLENAME_PLUS_COL
+ )
apply_labels = with_labels
@property
- def get_label_style(self):
+ def get_label_style(self) -> SelectLabelStyle:
"""
Retrieve the current label style.
@@ -828,7 +903,7 @@ class Query(
"""
return self._label_style
- def set_label_style(self, style):
+ def set_label_style(self: SelfQuery, style: SelectLabelStyle) -> SelfQuery:
"""Apply column labels to the return value of Query.statement.
Indicates that this Query's `statement` accessor should return
@@ -864,7 +939,7 @@ class Query(
return self
@_generative
- def enable_assertions(self: SelfQuery, value) -> SelfQuery:
+ def enable_assertions(self: SelfQuery, value: bool) -> SelfQuery:
"""Control whether assertions are generated.
When set to False, the returned Query will
@@ -887,7 +962,7 @@ class Query(
return self
@property
- def whereclause(self):
+ def whereclause(self) -> Optional[ColumnElement[bool]]:
"""A readonly attribute which returns the current WHERE criterion for
this Query.
@@ -895,12 +970,12 @@ class Query(
criterion has been established.
"""
- return sql.elements.BooleanClauseList._construct_for_whereclause(
+ return BooleanClauseList._construct_for_whereclause(
self._where_criteria
)
@_generative
- def _with_current_path(self: SelfQuery, path) -> SelfQuery:
+ def _with_current_path(self: SelfQuery, path: PathRegistry) -> SelfQuery:
"""indicate that this query applies to objects loaded
within a certain path.
@@ -913,7 +988,7 @@ class Query(
return self
@_generative
- def yield_per(self: SelfQuery, count) -> SelfQuery:
+ def yield_per(self: SelfQuery, count: int) -> SelfQuery:
r"""Yield only ``count`` rows at a time.
The purpose of this method is when fetching very large result sets
@@ -938,7 +1013,7 @@ class Query(
":meth:`_orm.Query.get`",
alternative="The method is now available as :meth:`_orm.Session.get`",
)
- def get(self, ident):
+ def get(self, ident: _PKIdentityArgument) -> Optional[Any]:
"""Return an instance based on the given primary key identifier,
or ``None`` if not found.
@@ -1022,7 +1097,12 @@ class Query(
# it
return self._get_impl(ident, loading.load_on_pk_identity)
- def _get_impl(self, primary_key_identity, db_load_fn, identity_token=None):
+ def _get_impl(
+ self,
+ primary_key_identity: _PKIdentityArgument,
+ db_load_fn: Callable[..., Any],
+ identity_token: Optional[Any] = None,
+ ) -> Optional[Any]:
mapper = self._only_full_mapper_zero("get")
return self.session._get_impl(
mapper,
@@ -1036,7 +1116,7 @@ class Query(
)
@property
- def lazy_loaded_from(self):
+ def lazy_loaded_from(self) -> Optional[InstanceState[Any]]:
"""An :class:`.InstanceState` that is using this :class:`_query.Query`
for a lazy load operation.
@@ -1050,14 +1130,17 @@ class Query(
:attr:`.ORMExecuteState.lazy_loaded_from`
"""
- return self.load_options._lazy_loaded_from
+ return self.load_options._lazy_loaded_from # type: ignore
@property
- def _current_path(self):
- return self._compile_options._current_path
+ def _current_path(self) -> PathRegistry:
+ return self._compile_options._current_path # type: ignore
@_generative
- def correlate(self: SelfQuery, *fromclauses) -> SelfQuery:
+ def correlate(
+ self: SelfQuery,
+ *fromclauses: Union[Literal[None, False], _FromClauseArgument],
+ ) -> SelfQuery:
"""Return a :class:`.Query` construct which will correlate the given
FROM clauses to that of an enclosing :class:`.Query` or
:func:`~.expression.select`.
@@ -1082,13 +1165,13 @@ class Query(
if fromclauses and fromclauses[0] in {None, False}:
self._correlate = ()
else:
- self._correlate = set(self._correlate).union(
+ self._correlate = self._correlate + tuple(
coercions.expect(roles.FromClauseRole, f) for f in fromclauses
)
return self
@_generative
- def autoflush(self: SelfQuery, setting) -> SelfQuery:
+ def autoflush(self: SelfQuery, setting: bool) -> SelfQuery:
"""Return a Query with a specific 'autoflush' setting.
As of SQLAlchemy 1.4, the :meth:`_orm.Query.autoflush` method
@@ -1116,7 +1199,7 @@ class Query(
return self
@_generative
- def _with_invoke_all_eagers(self: SelfQuery, value) -> SelfQuery:
+ def _with_invoke_all_eagers(self: SelfQuery, value: bool) -> SelfQuery:
"""Set the 'invoke all eagers' flag which causes joined- and
subquery loaders to traverse into already-loaded related objects
and collections.
@@ -1132,7 +1215,14 @@ class Query(
alternative="Use the :func:`_orm.with_parent` standalone construct.",
)
@util.preload_module("sqlalchemy.orm.relationships")
- def with_parent(self, instance, property=None, from_entity=None): # noqa
+ def with_parent(
+ self: SelfQuery,
+ instance: object,
+ property: Optional[ # noqa: A002
+ attributes.QueryableAttribute[Any]
+ ] = None,
+ from_entity: Optional[_ExternalEntityType[Any]] = None,
+ ) -> SelfQuery:
"""Add filtering criterion that relates the given instance
to a child object or collection, using its attribute state
as well as an established :func:`_orm.relationship()`
@@ -1150,7 +1240,7 @@ class Query(
An instance which has some :func:`_orm.relationship`.
:param property:
- String property name, or class-bound attribute, which indicates
+ Class bound attribute which indicates
what relationship from the instance should be used to reconcile the
parent/child relationship.
@@ -1172,21 +1262,27 @@ class Query(
for prop in mapper.iterate_properties:
if (
isinstance(prop, relationships.Relationship)
- and prop.mapper is entity_zero.mapper
+ and prop.mapper is entity_zero.mapper # type: ignore
):
- property = prop # noqa
+ property = prop # type: ignore # noqa: A001
break
else:
raise sa_exc.InvalidRequestError(
"Could not locate a property which relates instances "
"of class '%s' to instances of class '%s'"
% (
- entity_zero.mapper.class_.__name__,
+ entity_zero.mapper.class_.__name__, # type: ignore
instance.__class__.__name__,
)
)
- return self.filter(with_parent(instance, property, entity_zero.entity))
+ return self.filter(
+ with_parent(
+ instance,
+ property, # type: ignore
+ entity_zero.entity, # type: ignore
+ )
+ )
@_generative
def add_entity(
@@ -1211,7 +1307,7 @@ class Query(
return self
@_generative
- def with_session(self: SelfQuery, session) -> SelfQuery:
+ def with_session(self: SelfQuery, session: Session) -> SelfQuery:
"""Return a :class:`_query.Query` that will use the given
:class:`.Session`.
@@ -1237,7 +1333,9 @@ class Query(
self.session = session
return self
- def _legacy_from_self(self, *entities):
+ def _legacy_from_self(
+ self: SelfQuery, *entities: _ColumnsClauseArgument[Any]
+ ) -> SelfQuery:
# used for query.count() as well as for the same
# function in BakedQuery, as well as some old tests in test_baked.py.
@@ -1255,13 +1353,13 @@ class Query(
return q
@_generative
- def _set_enable_single_crit(self: SelfQuery, val) -> SelfQuery:
+ def _set_enable_single_crit(self: SelfQuery, val: bool) -> SelfQuery:
self._compile_options += {"_enable_single_crit": val}
return self
@_generative
def _from_selectable(
- self: SelfQuery, fromclause, set_entity_from=True
+ self: SelfQuery, fromclause: FromClause, set_entity_from: bool = True
) -> SelfQuery:
for attr in (
"_where_criteria",
@@ -1292,7 +1390,7 @@ class Query(
"is deprecated and will be removed in a "
"future release. Please use :meth:`_query.Query.with_entities`",
)
- def values(self, *columns):
+ def values(self, *columns: _ColumnsClauseArgument[Any]) -> Iterable[Any]:
"""Return an iterator yielding result tuples corresponding
to the given list of columns
@@ -1304,7 +1402,7 @@ class Query(
q._set_entities(columns)
if not q.load_options._yield_per:
q.load_options += {"_yield_per": 10}
- return iter(q)
+ return iter(q) # type: ignore
_values = values
@@ -1315,25 +1413,24 @@ class Query(
"future release. Please use :meth:`_query.Query.with_entities` "
"in combination with :meth:`_query.Query.scalar`",
)
- def value(self, column):
+ def value(self, column: _ColumnExpressionArgument[Any]) -> Any:
"""Return a scalar result corresponding to the given
column expression.
"""
try:
- return next(self.values(column))[0]
+ return next(self.values(column))[0] # type: ignore
except StopIteration:
return None
@overload
- def with_entities(
- self, _entity: _EntityType[_O], **kwargs: Any
- ) -> Query[_O]:
+ def with_entities(self, _entity: _EntityType[_O]) -> Query[_O]:
...
@overload
def with_entities(
- self, _colexpr: TypedColumnsClauseRole[_T]
+ self,
+ _colexpr: roles.TypedColumnsClauseRole[_T],
) -> RowReturningQuery[Tuple[_T]]:
...
@@ -1418,14 +1515,14 @@ class Query(
@overload
def with_entities(
- self: SelfQuery, *entities: _ColumnsClauseArgument[Any]
- ) -> SelfQuery:
+ self, *entities: _ColumnsClauseArgument[Any]
+ ) -> Query[Any]:
...
@_generative
def with_entities(
- self: SelfQuery, *entities: _ColumnsClauseArgument[Any], **__kw: Any
- ) -> SelfQuery:
+ self, *entities: _ColumnsClauseArgument[Any], **__kw: Any
+ ) -> Query[Any]:
r"""Return a new :class:`_query.Query`
replacing the SELECT list with the
given entities.
@@ -1451,12 +1548,18 @@ class Query(
"""
if __kw:
raise _no_kw()
- _MemoizedSelectEntities._generate_for_statement(self)
+
+ # Query has all the same fields as Select for this operation
+ # this could in theory be based on a protocol but not sure if it's
+ # worth it
+ _MemoizedSelectEntities._generate_for_statement(self) # type: ignore
self._set_entities(entities)
return self
@_generative
- def add_columns(self, *column: _ColumnExpressionArgument) -> Query[Any]:
+ def add_columns(
+ self, *column: _ColumnExpressionArgument[Any]
+ ) -> Query[Any]:
"""Add one or more column expressions to the list
of result columns to be returned."""
@@ -1479,7 +1582,7 @@ class Query(
"is deprecated and will be removed in a "
"future release. Please use :meth:`_query.Query.add_columns`",
)
- def add_column(self, column) -> Query[Any]:
+ def add_column(self, column: _ColumnExpressionArgument[Any]) -> Query[Any]:
"""Add a column expression to the list of result columns to be
returned.
@@ -1487,7 +1590,7 @@ class Query(
return self.add_columns(column)
@_generative
- def options(self: SelfQuery, *args) -> SelfQuery:
+ def options(self: SelfQuery, *args: ExecutableOption) -> SelfQuery:
"""Return a new :class:`_query.Query` object,
applying the given list of
mapper options.
@@ -1505,18 +1608,21 @@ class Query(
opts = tuple(util.flatten_iterator(args))
if self._compile_options._current_path:
+ # opting for lower method overhead for the checks
for opt in opts:
- if opt._is_legacy_option:
- opt.process_query_conditionally(self)
+ if not opt._is_core and opt._is_legacy_option: # type: ignore
+ opt.process_query_conditionally(self) # type: ignore
else:
for opt in opts:
- if opt._is_legacy_option:
- opt.process_query(self)
+ if not opt._is_core and opt._is_legacy_option: # type: ignore
+ opt.process_query(self) # type: ignore
self._with_options += opts
return self
- def with_transformation(self, fn):
+ def with_transformation(
+ self, fn: Callable[[Query[Any]], Query[Any]]
+ ) -> Query[Any]:
"""Return a new :class:`_query.Query` object transformed by
the given function.
@@ -1535,7 +1641,7 @@ class Query(
"""
return fn(self)
- def get_execution_options(self):
+ def get_execution_options(self) -> _ImmutableExecuteOptions:
"""Get the non-SQL options which will take effect during execution.
.. versionadded:: 1.3
@@ -1547,7 +1653,7 @@ class Query(
return self._execution_options
@_generative
- def execution_options(self: SelfQuery, **kwargs) -> SelfQuery:
+ def execution_options(self: SelfQuery, **kwargs: Any) -> SelfQuery:
"""Set non-SQL options which take effect during execution.
Options allowed here include all of those accepted by
@@ -1596,11 +1702,17 @@ class Query(
@_generative
def with_for_update(
self: SelfQuery,
- read=False,
- nowait=False,
- of=None,
- skip_locked=False,
- key_share=False,
+ *,
+ nowait: bool = False,
+ read: bool = False,
+ of: Optional[
+ Union[
+ _ColumnExpressionArgument[Any],
+ Sequence[_ColumnExpressionArgument[Any]],
+ ]
+ ] = None,
+ skip_locked: bool = False,
+ key_share: bool = False,
) -> SelfQuery:
"""return a new :class:`_query.Query`
with the specified options for the
@@ -1659,7 +1771,9 @@ class Query(
return self
@_generative
- def params(self: SelfQuery, *args, **kwargs) -> SelfQuery:
+ def params(
+ self: SelfQuery, __params: Optional[Dict[str, Any]] = None, **kw: Any
+ ) -> SelfQuery:
r"""Add values for bind parameters which may have been
specified in filter().
@@ -1669,17 +1783,14 @@ class Query(
contain unicode keys in which case \**kwargs cannot be used.
"""
- if len(args) == 1:
- kwargs.update(args[0])
- elif len(args) > 0:
- raise sa_exc.ArgumentError(
- "params() takes zero or one positional argument, "
- "which is a dictionary."
- )
- self._params = self._params.union(kwargs)
+ if __params:
+ kw.update(__params)
+ self._params = self._params.union(kw)
return self
- def where(self: SelfQuery, *criterion) -> SelfQuery:
+ def where(
+ self: SelfQuery, *criterion: _ColumnExpressionArgument[bool]
+ ) -> SelfQuery:
"""A synonym for :meth:`.Query.filter`.
.. versionadded:: 1.4
@@ -1716,16 +1827,18 @@ class Query(
:meth:`_query.Query.filter_by` - filter on keyword expressions.
"""
- for criterion in list(criterion):
- criterion = coercions.expect(
- roles.WhereHavingRole, criterion, apply_propagate_attrs=self
+ for crit in list(criterion):
+ crit = coercions.expect(
+ roles.WhereHavingRole, crit, apply_propagate_attrs=self
)
- self._where_criteria += (criterion,)
+ self._where_criteria += (crit,)
return self
@util.memoized_property
- def _last_joined_entity(self):
+ def _last_joined_entity(
+ self,
+ ) -> Optional[Union[_InternalEntityType[Any], _JoinTargetElement]]:
if self._setup_joins:
return _determine_last_joined_entity(
self._setup_joins,
@@ -1733,7 +1846,7 @@ class Query(
else:
return None
- def _filter_by_zero(self):
+ def _filter_by_zero(self) -> Any:
"""for the filter_by() method, return the target entity for which
we will attempt to derive an expression from based on string name.
@@ -1800,13 +1913,6 @@ class Query(
"""
from_entity = self._filter_by_zero()
- if from_entity is None:
- raise sa_exc.InvalidRequestError(
- "Can't use filter_by when the first entity '%s' of a query "
- "is not a mapped class. Please use the filter method instead, "
- "or change the order of the entities in the query"
- % self._query_entity_zero()
- )
clauses = [
_entity_namespace_key(from_entity, key) == value
@@ -1815,9 +1921,12 @@ class Query(
return self.filter(*clauses)
@_generative
- @_assertions(_no_statement_condition, _no_limit_offset)
def order_by(
- self: SelfQuery, *clauses: _ColumnExpressionArgument[Any]
+ self: SelfQuery,
+ __first: Union[
+ Literal[None, False, _NoArg.NO_ARG], _ColumnExpressionArgument[Any]
+ ] = _NoArg.NO_ARG,
+ *clauses: _ColumnExpressionArgument[Any],
) -> SelfQuery:
"""Apply one or more ORDER BY criteria to the query and return
the newly resulting :class:`_query.Query`.
@@ -1844,20 +1953,27 @@ class Query(
"""
- if len(clauses) == 1 and (clauses[0] is None or clauses[0] is False):
+ for assertion in (self._no_statement_condition, self._no_limit_offset):
+ assertion("order_by")
+
+ if not clauses and (__first is None or __first is False):
self._order_by_clauses = ()
- else:
+ elif __first is not _NoArg.NO_ARG:
criterion = tuple(
coercions.expect(roles.OrderByRole, clause)
- for clause in clauses
+ for clause in (__first,) + clauses
)
self._order_by_clauses += criterion
+
return self
@_generative
- @_assertions(_no_statement_condition, _no_limit_offset)
def group_by(
- self: SelfQuery, *clauses: _ColumnExpressionArgument[Any]
+ self: SelfQuery,
+ __first: Union[
+ Literal[None, False, _NoArg.NO_ARG], _ColumnExpressionArgument[Any]
+ ] = _NoArg.NO_ARG,
+ *clauses: _ColumnExpressionArgument[Any],
) -> SelfQuery:
"""Apply one or more GROUP BY criterion to the query and return
the newly resulting :class:`_query.Query`.
@@ -1878,12 +1994,15 @@ class Query(
"""
- if len(clauses) == 1 and (clauses[0] is None or clauses[0] is False):
+ for assertion in (self._no_statement_condition, self._no_limit_offset):
+ assertion("group_by")
+
+ if not clauses and (__first is None or __first is False):
self._group_by_clauses = ()
- else:
+ elif __first is not _NoArg.NO_ARG:
criterion = tuple(
coercions.expect(roles.GroupByRole, clause)
- for clause in clauses
+ for clause in (__first,) + clauses
)
self._group_by_clauses += criterion
return self
@@ -1916,8 +2035,9 @@ class Query(
self._having_criteria += (having_criteria,)
return self
- def _set_op(self, expr_fn, *q):
- return self._from_selectable(expr_fn(*([self] + list(q))).subquery())
+ def _set_op(self: SelfQuery, expr_fn: Any, *q: Query[Any]) -> SelfQuery:
+ list_of_queries = (self,) + q
+ return self._from_selectable(expr_fn(*(list_of_queries)).subquery())
def union(self: SelfQuery, *q: Query[Any]) -> SelfQuery:
"""Produce a UNION of this Query against one or more queries.
@@ -2006,7 +2126,12 @@ class Query(
@_generative
@_assertions(_no_statement_condition, _no_limit_offset)
def join(
- self: SelfQuery, target, onclause=None, *, isouter=False, full=False
+ self: SelfQuery,
+ target: _JoinTargetArgument,
+ onclause: Optional[_OnClauseArgument] = None,
+ *,
+ isouter: bool = False,
+ full: bool = False,
) -> SelfQuery:
r"""Create a SQL JOIN against this :class:`_query.Query`
object's criterion
@@ -2193,20 +2318,23 @@ class Query(
"""
- target = coercions.expect(
+ join_target = coercions.expect(
roles.JoinTargetRole,
target,
apply_propagate_attrs=self,
legacy=True,
)
if onclause is not None:
- onclause = coercions.expect(
+ onclause_element = coercions.expect(
roles.OnClauseRole, onclause, legacy=True
)
+ else:
+ onclause_element = None
+
self._setup_joins += (
(
- target,
- onclause,
+ join_target,
+ onclause_element,
None,
{
"isouter": isouter,
@@ -2218,7 +2346,13 @@ class Query(
self.__dict__.pop("_last_joined_entity", None)
return self
- def outerjoin(self, target, onclause=None, *, full=False):
+ def outerjoin(
+ self: SelfQuery,
+ target: _JoinTargetArgument,
+ onclause: Optional[_OnClauseArgument] = None,
+ *,
+ full: bool = False,
+ ) -> SelfQuery:
"""Create a left outer join against this ``Query`` object's criterion
and apply generatively, returning the newly resulting ``Query``.
@@ -2295,7 +2429,7 @@ class Query(
self._set_select_from(from_obj, False)
return self
- def __getitem__(self, item):
+ def __getitem__(self, item: Any) -> Any:
return orm_util._getitem(
self,
item,
@@ -2303,7 +2437,11 @@ class Query(
@_generative
@_assertions(_no_statement_condition)
- def slice(self: SelfQuery, start, stop) -> SelfQuery:
+ def slice(
+ self: SelfQuery,
+ start: int,
+ stop: int,
+ ) -> SelfQuery:
"""Computes the "slice" of the :class:`_query.Query` represented by
the given indices and returns the resulting :class:`_query.Query`.
@@ -2341,7 +2479,9 @@ class Query(
@_generative
@_assertions(_no_statement_condition)
- def limit(self: SelfQuery, limit) -> SelfQuery:
+ def limit(
+ self: SelfQuery, limit: Union[int, _ColumnExpressionArgument[int]]
+ ) -> SelfQuery:
"""Apply a ``LIMIT`` to the query and return the newly resulting
``Query``.
@@ -2351,7 +2491,9 @@ class Query(
@_generative
@_assertions(_no_statement_condition)
- def offset(self: SelfQuery, offset) -> SelfQuery:
+ def offset(
+ self: SelfQuery, offset: Union[int, _ColumnExpressionArgument[int]]
+ ) -> SelfQuery:
"""Apply an ``OFFSET`` to the query and return the newly resulting
``Query``.
@@ -2361,7 +2503,9 @@ class Query(
@_generative
@_assertions(_no_statement_condition)
- def distinct(self: SelfQuery, *expr) -> SelfQuery:
+ def distinct(
+ self: SelfQuery, *expr: _ColumnExpressionArgument[Any]
+ ) -> SelfQuery:
r"""Apply a ``DISTINCT`` to the query and return the newly resulting
``Query``.
@@ -2415,7 +2559,7 @@ class Query(
:ref:`faq_query_deduplicating`
"""
- return self._iter().all()
+ return self._iter().all() # type: ignore
@_generative
@_assertions(_no_clauseelement_condition)
@@ -2462,9 +2606,9 @@ class Query(
"""
# replicates limit(1) behavior
if self._statement is not None:
- return self._iter().first()
+ return self._iter().first() # type: ignore
else:
- return self.limit(1)._iter().first()
+ return self.limit(1)._iter().first() # type: ignore
def one_or_none(self) -> Optional[_T]:
"""Return at most one result or raise an exception.
@@ -2490,7 +2634,7 @@ class Query(
:meth:`_query.Query.one`
"""
- return self._iter().one_or_none()
+ return self._iter().one_or_none() # type: ignore
def one(self) -> _T:
"""Return exactly one result or raise an exception.
@@ -2537,18 +2681,18 @@ class Query(
if not isinstance(ret, collections_abc.Sequence):
return ret
return ret[0]
- except orm_exc.NoResultFound:
+ except sa_exc.NoResultFound:
return None
def __iter__(self) -> Iterable[_T]:
- return self._iter().__iter__()
+ return self._iter().__iter__() # type: ignore
def _iter(self) -> Union[ScalarResult[_T], Result[_T]]:
# new style execution.
params = self._params
statement = self._statement_20()
- result = self.session.execute(
+ result: Union[ScalarResult[_T], Result[_T]] = self.session.execute(
statement,
params,
execution_options={"_sa_orm_load_options": self.load_options},
@@ -2556,7 +2700,7 @@ class Query(
# legacy: automatically set scalars, unique
if result._attributes.get("is_single_entity", False):
- result = result.scalars()
+ result = cast("Result[_T]", result).scalars()
if (
result._attributes.get("filtered", False)
@@ -2580,7 +2724,7 @@ class Query(
return str(statement.compile(bind))
- def _get_bind_args(self, statement, fn, **kw):
+ def _get_bind_args(self, statement: Any, fn: Any, **kw: Any) -> Any:
return fn(clause=statement, **kw)
@property
@@ -2634,7 +2778,11 @@ class Query(
return _column_descriptions(self, legacy=True)
- def instances(self, result_proxy: Result, context=None) -> Any:
+ def instances(
+ self,
+ result_proxy: CursorResult[Any],
+ context: Optional[QueryContext] = None,
+ ) -> Any:
"""Return an ORM result given a :class:`_engine.CursorResult` and
:class:`.QueryContext`.
@@ -2661,7 +2809,7 @@ class Query(
# legacy: automatically set scalars, unique
if result._attributes.get("is_single_entity", False):
- result = result.scalars()
+ result = result.scalars() # type: ignore
if result._attributes.get("filtered", False):
result = result.unique()
@@ -2675,7 +2823,13 @@ class Query(
":func:`_orm.merge_frozen_result` function.",
enable_warnings=False, # warnings occur via loading.merge_result
)
- def merge_result(self, iterator, load=True):
+ def merge_result(
+ self,
+ iterator: Union[
+ FrozenResult[Any], Iterable[Sequence[Any]], Iterable[object]
+ ],
+ load: bool = True,
+ ) -> Union[FrozenResult[Any], Iterable[Any]]:
"""Merge a result into this :class:`_query.Query` object's Session.
Given an iterator returned by a :class:`_query.Query`
@@ -2743,7 +2897,8 @@ class Query(
self.enable_eagerloads(False)
.add_columns(sql.literal_column("1"))
.set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL)
- .statement.with_only_columns(1)
+ ._get_select_statement_only()
+ .with_only_columns(1)
)
ezero = self._entity_from_pre_ent_zero()
@@ -2752,7 +2907,7 @@ class Query(
return sql.exists(inner)
- def count(self):
+ def count(self) -> int:
r"""Return a count of rows this the SQL formed by this :class:`Query`
would return.
@@ -2806,9 +2961,11 @@ class Query(
"""
col = sql.func.count(sql.literal_column("*"))
- return self._legacy_from_self(col).enable_eagerloads(False).scalar()
+ return ( # type: ignore
+ self._legacy_from_self(col).enable_eagerloads(False).scalar()
+ )
- def delete(self, synchronize_session="evaluate"):
+ def delete(self, synchronize_session: str = "evaluate") -> int:
r"""Perform a DELETE with an arbitrary WHERE clause.
Deletes rows matched by this query from the database.
@@ -2850,20 +3007,28 @@ class Query(
self = bulk_del.query
- delete_ = sql.delete(*self._raw_columns)
+ delete_ = sql.delete(*self._raw_columns) # type: ignore
delete_._where_criteria = self._where_criteria
- result = self.session.execute(
- delete_,
- self._params,
- execution_options={"synchronize_session": synchronize_session},
+ result: CursorResult[Any] = cast(
+ "CursorResult[Any]",
+ self.session.execute(
+ delete_,
+ self._params,
+ execution_options={"synchronize_session": synchronize_session},
+ ),
)
- bulk_del.result = result
+ bulk_del.result = result # type: ignore
self.session.dispatch.after_bulk_delete(bulk_del)
result.close()
return result.rowcount
- def update(self, values, synchronize_session="evaluate", update_args=None):
+ def update(
+ self,
+ values: Dict[_DMLColumnArgument, Any],
+ synchronize_session: str = "evaluate",
+ update_args: Optional[Dict[Any, Any]] = None,
+ ) -> int:
r"""Perform an UPDATE with an arbitrary WHERE clause.
Updates rows matched by this query in the database.
@@ -2926,28 +3091,33 @@ class Query(
bulk_ud.query = new_query
self = bulk_ud.query
- upd = sql.update(*self._raw_columns)
+ upd = sql.update(*self._raw_columns) # type: ignore
ppo = update_args.pop("preserve_parameter_order", False)
if ppo:
- upd = upd.ordered_values(*values)
+ upd = upd.ordered_values(*values) # type: ignore
else:
upd = upd.values(values)
if update_args:
upd = upd.with_dialect_options(**update_args)
upd._where_criteria = self._where_criteria
- result = self.session.execute(
- upd,
- self._params,
- execution_options={"synchronize_session": synchronize_session},
+ result: CursorResult[Any] = cast(
+ "CursorResult[Any]",
+ self.session.execute(
+ upd,
+ self._params,
+ execution_options={"synchronize_session": synchronize_session},
+ ),
)
- bulk_ud.result = result
+ bulk_ud.result = result # type: ignore
self.session.dispatch.after_bulk_update(bulk_ud)
result.close()
return result.rowcount
- def _compile_state(self, for_statement=False, **kw):
+ def _compile_state(
+ self, for_statement: bool = False, **kw: Any
+ ) -> ORMCompileState:
"""Create an out-of-compiler ORMCompileState object.
The ORMCompileState object is normally created directly as a result
@@ -2971,13 +3141,14 @@ class Query(
# ORMSelectCompileState. We could also base this on
# query._statement is not None as we have the ORM Query here
# however this is the more general path.
- compile_state_cls = ORMCompileState._get_plugin_class_for_plugin(
- stmt, "orm"
+ compile_state_cls = cast(
+ ORMCompileState,
+ ORMCompileState._get_plugin_class_for_plugin(stmt, "orm"),
)
return compile_state_cls.create_for_statement(stmt, None)
- def _compile_context(self, for_statement=False):
+ def _compile_context(self, for_statement: bool = False) -> QueryContext:
compile_state = self._compile_state(for_statement=for_statement)
context = QueryContext(
compile_state,
@@ -3006,7 +3177,7 @@ class AliasOption(interfaces.LoaderOption):
"""
- def process_compile_state(self, compile_state: ORMCompileState):
+ def process_compile_state(self, compile_state: ORMCompileState) -> None:
pass
@@ -3017,12 +3188,12 @@ class BulkUD:
"""
- def __init__(self, query):
+ def __init__(self, query: Query[Any]):
self.query = query.enable_eagerloads(False)
self._validate_query_state()
self.mapper = self.query._entity_from_pre_ent_zero()
- def _validate_query_state(self):
+ def _validate_query_state(self) -> None:
for attr, methname, notset, op in (
("_limit_clause", "limit()", None, operator.is_),
("_offset_clause", "offset()", None, operator.is_),
@@ -3049,14 +3220,19 @@ class BulkUD:
)
@property
- def session(self):
+ def session(self) -> Session:
return self.query.session
class BulkUpdate(BulkUD):
"""BulkUD which handles UPDATEs."""
- def __init__(self, query, values, update_kwargs):
+ def __init__(
+ self,
+ query: Query[Any],
+ values: Dict[_DMLColumnArgument, Any],
+ update_kwargs: Optional[Dict[Any, Any]],
+ ):
super(BulkUpdate, self).__init__(query)
self.values = values
self.update_kwargs = update_kwargs
@@ -3067,4 +3243,7 @@ class BulkDelete(BulkUD):
class RowReturningQuery(Query[Row[_TP]]):
- pass
+ if TYPE_CHECKING:
+
+ def tuples(self) -> Query[_TP]: # type: ignore
+ ...
diff --git a/lib/sqlalchemy/orm/relationships.py b/lib/sqlalchemy/orm/relationships.py
index 8273775ae..1186f0f54 100644
--- a/lib/sqlalchemy/orm/relationships.py
+++ b/lib/sqlalchemy/orm/relationships.py
@@ -17,13 +17,23 @@ from __future__ import annotations
import collections
from collections import abc
+import dataclasses
import re
import typing
from typing import Any
from typing import Callable
+from typing import cast
+from typing import Collection
from typing import Dict
+from typing import Generic
+from typing import Iterable
+from typing import Iterator
+from typing import List
+from typing import NamedTuple
+from typing import NoReturn
from typing import Optional
from typing import Sequence
+from typing import Set
from typing import Tuple
from typing import Type
from typing import TypeVar
@@ -32,14 +42,19 @@ import weakref
from . import attributes
from . import strategy_options
+from ._typing import insp_is_aliased_class
+from ._typing import is_has_collection_adapter
from .base import _is_mapped_class
from .base import class_mapper
+from .base import LoaderCallableStatus
+from .base import PassiveFlag
from .base import state_str
from .interfaces import _IntrospectsAnnotations
from .interfaces import MANYTOMANY
from .interfaces import MANYTOONE
from .interfaces import ONETOMANY
from .interfaces import PropComparator
+from .interfaces import RelationshipDirection
from .interfaces import StrategizedProperty
from .util import _extract_mapped_subtype
from .util import _orm_annotate
@@ -60,6 +75,7 @@ from ..sql import visitors
from ..sql._typing import _ColumnExpressionArgument
from ..sql._typing import _HasClauseElement
from ..sql.elements import ColumnClause
+from ..sql.elements import ColumnElement
from ..sql.util import _deep_deannotate
from ..sql.util import _shallow_annotate
from ..sql.util import adapt_criterion_to_null
@@ -71,15 +87,42 @@ from ..util.typing import Literal
if typing.TYPE_CHECKING:
from ._typing import _EntityType
+ from ._typing import _ExternalEntityType
+ from ._typing import _IdentityKeyType
+ from ._typing import _InstanceDict
from ._typing import _InternalEntityType
+ from ._typing import _O
+ from ._typing import _RegistryType
+ from .clsregistry import _class_resolver
+ from .clsregistry import _ModNS
+ from .dependency import DependencyProcessor
from .mapper import Mapper
+ from .query import Query
+ from .session import Session
+ from .state import InstanceState
+ from .strategies import LazyLoader
from .util import AliasedClass
from .util import AliasedInsp
- from ..sql.elements import ColumnElement
+ from ..sql._typing import _CoreAdapterProto
+ from ..sql._typing import _EquivalentColumnMap
+ from ..sql._typing import _InfoType
+ from ..sql.annotation import _AnnotationDict
+ from ..sql.elements import BinaryExpression
+ from ..sql.elements import BindParameter
+ from ..sql.elements import ClauseElement
+ from ..sql.schema import Table
+ from ..sql.selectable import FromClause
+ from ..util.typing import _AnnotationScanType
+ from ..util.typing import RODescriptorReference
_T = TypeVar("_T", bound=Any)
+_T1 = TypeVar("_T1", bound=Any)
+_T2 = TypeVar("_T2", bound=Any)
+
_PT = TypeVar("_PT", bound=Any)
+_PT2 = TypeVar("_PT2", bound=Any)
+
_RelationshipArgumentType = Union[
str,
@@ -111,7 +154,10 @@ _RelationshipJoinConditionArgument = Union[
str, _ColumnExpressionArgument[bool]
]
_ORMOrderByArgument = Union[
- Literal[False], str, _ColumnExpressionArgument[Any]
+ Literal[False],
+ str,
+ _ColumnExpressionArgument[Any],
+ Iterable[Union[str, _ColumnExpressionArgument[Any]]],
]
_ORMBackrefArgument = Union[str, Tuple[str, Dict[str, Any]]]
_ORMColCollectionArgument = Union[
@@ -120,7 +166,19 @@ _ORMColCollectionArgument = Union[
]
-def remote(expr):
+_CEA = TypeVar("_CEA", bound=_ColumnExpressionArgument[Any])
+
+_CE = TypeVar("_CE", bound="ColumnElement[Any]")
+
+
+_ColumnPairIterable = Iterable[Tuple[ColumnElement[Any], ColumnElement[Any]]]
+
+_ColumnPairs = Sequence[Tuple[ColumnElement[Any], ColumnElement[Any]]]
+
+_MutableColumnPairs = List[Tuple[ColumnElement[Any], ColumnElement[Any]]]
+
+
+def remote(expr: _CEA) -> _CEA:
"""Annotate a portion of a primaryjoin expression
with a 'remote' annotation.
@@ -134,12 +192,12 @@ def remote(expr):
:func:`.foreign`
"""
- return _annotate_columns(
+ return _annotate_columns( # type: ignore
coercions.expect(roles.ColumnArgumentRole, expr), {"remote": True}
)
-def foreign(expr):
+def foreign(expr: _CEA) -> _CEA:
"""Annotate a portion of a primaryjoin expression
with a 'foreign' annotation.
@@ -154,11 +212,71 @@ def foreign(expr):
"""
- return _annotate_columns(
+ return _annotate_columns( # type: ignore
coercions.expect(roles.ColumnArgumentRole, expr), {"foreign": True}
)
+@dataclasses.dataclass
+class _RelationshipArg(Generic[_T1, _T2]):
+ """stores a user-defined parameter value that must be resolved and
+ parsed later at mapper configuration time.
+
+ """
+
+ __slots__ = "name", "argument", "resolved"
+ name: str
+ argument: _T1
+ resolved: Optional[_T2]
+
+ def _is_populated(self) -> bool:
+ return self.argument is not None
+
+ def _resolve_against_registry(
+ self, clsregistry_resolver: Callable[[str, bool], _class_resolver]
+ ) -> None:
+ attr_value = self.argument
+
+ if isinstance(attr_value, str):
+ self.resolved = clsregistry_resolver(
+ attr_value, self.name == "secondary"
+ )()
+ elif callable(attr_value) and not _is_mapped_class(attr_value):
+ self.resolved = attr_value()
+ else:
+ self.resolved = attr_value
+
+
+class _RelationshipArgs(NamedTuple):
+ """stores user-passed parameters that are resolved at mapper configuration
+ time.
+
+ """
+
+ secondary: _RelationshipArg[
+ Optional[Union[FromClause, str]],
+ Optional[FromClause],
+ ]
+ primaryjoin: _RelationshipArg[
+ Optional[_RelationshipJoinConditionArgument],
+ Optional[ColumnElement[Any]],
+ ]
+ secondaryjoin: _RelationshipArg[
+ Optional[_RelationshipJoinConditionArgument],
+ Optional[ColumnElement[Any]],
+ ]
+ order_by: _RelationshipArg[
+ _ORMOrderByArgument,
+ Union[Literal[None, False], Tuple[ColumnElement[Any], ...]],
+ ]
+ foreign_keys: _RelationshipArg[
+ Optional[_ORMColCollectionArgument], Set[ColumnElement[Any]]
+ ]
+ remote_side: _RelationshipArg[
+ Optional[_ORMColCollectionArgument], Set[ColumnElement[Any]]
+ ]
+
+
@log.class_logger
class Relationship(
_IntrospectsAnnotations, StrategizedProperty[_T], log.Identified
@@ -184,6 +302,10 @@ class Relationship(
_links_to_entity = True
_is_relationship = True
+ _overlaps: Sequence[str]
+
+ _lazy_strategy: LazyLoader
+
_persistence_only = dict(
passive_deletes=False,
passive_updates=True,
@@ -192,56 +314,87 @@ class Relationship(
cascade_backrefs=False,
)
- _dependency_processor = None
+ _dependency_processor: Optional[DependencyProcessor] = None
+
+ primaryjoin: ColumnElement[bool]
+ secondaryjoin: Optional[ColumnElement[bool]]
+ secondary: Optional[FromClause]
+ _join_condition: JoinCondition
+ order_by: Union[Literal[False], Tuple[ColumnElement[Any], ...]]
+
+ _user_defined_foreign_keys: Set[ColumnElement[Any]]
+ _calculated_foreign_keys: Set[ColumnElement[Any]]
+
+ remote_side: Set[ColumnElement[Any]]
+ local_columns: Set[ColumnElement[Any]]
+
+ synchronize_pairs: _ColumnPairs
+ secondary_synchronize_pairs: Optional[_ColumnPairs]
+
+ local_remote_pairs: Optional[_ColumnPairs]
+
+ direction: RelationshipDirection
+
+ _init_args: _RelationshipArgs
def __init__(
self,
argument: Optional[_RelationshipArgumentType[_T]] = None,
- secondary=None,
+ secondary: Optional[Union[FromClause, str]] = None,
*,
- uselist=None,
- collection_class=None,
- primaryjoin=None,
- secondaryjoin=None,
- back_populates=None,
- order_by=False,
- backref=None,
- cascade_backrefs=False,
- overlaps=None,
- post_update=False,
- cascade="save-update, merge",
- viewonly=False,
+ uselist: Optional[bool] = None,
+ collection_class: Optional[
+ Union[Type[Collection[Any]], Callable[[], Collection[Any]]]
+ ] = None,
+ primaryjoin: Optional[_RelationshipJoinConditionArgument] = None,
+ secondaryjoin: Optional[_RelationshipJoinConditionArgument] = None,
+ back_populates: Optional[str] = None,
+ order_by: _ORMOrderByArgument = False,
+ backref: Optional[_ORMBackrefArgument] = None,
+ overlaps: Optional[str] = None,
+ post_update: bool = False,
+ cascade: str = "save-update, merge",
+ viewonly: bool = False,
lazy: _LazyLoadArgumentType = "select",
- passive_deletes=False,
- passive_updates=True,
- active_history=False,
- enable_typechecks=True,
- foreign_keys=None,
- remote_side=None,
- join_depth=None,
- comparator_factory=None,
- single_parent=False,
- innerjoin=False,
- distinct_target_key=None,
- load_on_pending=False,
- query_class=None,
- info=None,
- omit_join=None,
- sync_backref=None,
- doc=None,
- bake_queries=True,
- _local_remote_pairs=None,
- _legacy_inactive_history_style=False,
+ passive_deletes: Union[Literal["all"], bool] = False,
+ passive_updates: bool = True,
+ active_history: bool = False,
+ enable_typechecks: bool = True,
+ foreign_keys: Optional[_ORMColCollectionArgument] = None,
+ remote_side: Optional[_ORMColCollectionArgument] = None,
+ join_depth: Optional[int] = None,
+ comparator_factory: Optional[
+ Type[Relationship.Comparator[Any]]
+ ] = None,
+ single_parent: bool = False,
+ innerjoin: bool = False,
+ distinct_target_key: Optional[bool] = None,
+ load_on_pending: bool = False,
+ query_class: Optional[Type[Query[Any]]] = None,
+ info: Optional[_InfoType] = None,
+ omit_join: Literal[None, False] = None,
+ sync_backref: Optional[bool] = None,
+ doc: Optional[str] = None,
+ bake_queries: Literal[True] = True,
+ cascade_backrefs: Literal[False] = False,
+ _local_remote_pairs: Optional[_ColumnPairs] = None,
+ _legacy_inactive_history_style: bool = False,
):
super(Relationship, self).__init__()
self.uselist = uselist
self.argument = argument
- self.secondary = secondary
- self.primaryjoin = primaryjoin
- self.secondaryjoin = secondaryjoin
+
+ self._init_args = _RelationshipArgs(
+ _RelationshipArg("secondary", secondary, None),
+ _RelationshipArg("primaryjoin", primaryjoin, None),
+ _RelationshipArg("secondaryjoin", secondaryjoin, None),
+ _RelationshipArg("order_by", order_by, None),
+ _RelationshipArg("foreign_keys", foreign_keys, None),
+ _RelationshipArg("remote_side", remote_side, None),
+ )
+
self.post_update = post_update
- self.direction = None
self.viewonly = viewonly
if viewonly:
self._warn_for_persistence_only_flags(
@@ -258,7 +411,6 @@ class Relationship(
self.sync_backref = sync_backref
self.lazy = lazy
self.single_parent = single_parent
- self._user_defined_foreign_keys = foreign_keys
self.collection_class = collection_class
self.passive_deletes = passive_deletes
@@ -269,7 +421,6 @@ class Relationship(
)
self.passive_updates = passive_updates
- self.remote_side = remote_side
self.enable_typechecks = enable_typechecks
self.query_class = query_class
self.innerjoin = innerjoin
@@ -292,23 +443,22 @@ class Relationship(
self.local_remote_pairs = _local_remote_pairs
self.load_on_pending = load_on_pending
self.comparator_factory = comparator_factory or Relationship.Comparator
- self.comparator = self.comparator_factory(self, None)
util.set_creation_order(self)
if info is not None:
- self.info = info
+ self.info.update(info)
self.strategy_key = (("lazy", self.lazy),)
- self._reverse_property = set()
+ self._reverse_property: Set[Relationship[Any]] = set()
+
if overlaps:
- self._overlaps = set(re.split(r"\s*,\s*", overlaps))
+ self._overlaps = set(re.split(r"\s*,\s*", overlaps)) # type: ignore # noqa: E501
else:
self._overlaps = ()
- self.cascade = cascade
-
- self.order_by = order_by
+ # mypy ignoring the @property setter
+ self.cascade = cascade # type: ignore
self.back_populates = back_populates
@@ -322,7 +472,7 @@ class Relationship(
else:
self.backref = backref
- def _warn_for_persistence_only_flags(self, **kw):
+ def _warn_for_persistence_only_flags(self, **kw: Any) -> None:
for k, v in kw.items():
if v != self._persistence_only[k]:
# we are warning here rather than warn deprecated as this is a
@@ -340,7 +490,7 @@ class Relationship(
"in a future release." % (k,)
)
- def instrument_class(self, mapper):
+ def instrument_class(self, mapper: Mapper[Any]) -> None:
attributes.register_descriptor(
mapper.class_,
self.key,
@@ -378,13 +528,16 @@ class Relationship(
"_extra_criteria",
)
+ prop: RODescriptorReference[Relationship[_PT]]
+ _of_type: Optional[_EntityType[_PT]]
+
def __init__(
self,
- prop,
- parentmapper,
- adapt_to_entity=None,
- of_type=None,
- extra_criteria=(),
+ prop: Relationship[_PT],
+ parentmapper: _InternalEntityType[Any],
+ adapt_to_entity: Optional[AliasedInsp[Any]] = None,
+ of_type: Optional[_EntityType[_PT]] = None,
+ extra_criteria: Tuple[ColumnElement[bool], ...] = (),
):
"""Construction of :class:`.Relationship.Comparator`
is internal to the ORM's attribute mechanics.
@@ -399,15 +552,17 @@ class Relationship(
self._of_type = None
self._extra_criteria = extra_criteria
- def adapt_to_entity(self, adapt_to_entity):
+ def adapt_to_entity(
+ self, adapt_to_entity: AliasedInsp[Any]
+ ) -> Relationship.Comparator[Any]:
return self.__class__(
- self.property,
+ self.prop,
self._parententity,
adapt_to_entity=adapt_to_entity,
of_type=self._of_type,
)
- entity: _InternalEntityType
+ entity: _InternalEntityType[_PT]
"""The target entity referred to by this
:class:`.Relationship.Comparator`.
@@ -419,7 +574,7 @@ class Relationship(
"""
- mapper: Mapper[Any]
+ mapper: Mapper[_PT]
"""The target :class:`_orm.Mapper` referred to by this
:class:`.Relationship.Comparator`.
@@ -428,22 +583,22 @@ class Relationship(
"""
- def _memoized_attr_entity(self) -> _InternalEntityType:
+ def _memoized_attr_entity(self) -> _InternalEntityType[_PT]:
if self._of_type:
- return inspect(self._of_type)
+ return inspect(self._of_type) # type: ignore
else:
return self.prop.entity
- def _memoized_attr_mapper(self) -> Mapper[Any]:
+ def _memoized_attr_mapper(self) -> Mapper[_PT]:
return self.entity.mapper
- def _source_selectable(self):
+ def _source_selectable(self) -> FromClause:
if self._adapt_to_entity:
return self._adapt_to_entity.selectable
else:
return self.property.parent._with_polymorphic_selectable
- def __clause_element__(self):
+ def __clause_element__(self) -> ColumnElement[bool]:
adapt_from = self._source_selectable()
if self._of_type:
of_type_entity = inspect(self._of_type)
@@ -457,7 +612,7 @@ class Relationship(
dest,
secondary,
target_adapter,
- ) = self.property._create_joins(
+ ) = self.prop._create_joins(
source_selectable=adapt_from,
source_polymorphic=True,
of_type_entity=of_type_entity,
@@ -469,7 +624,7 @@ class Relationship(
else:
return pj
- def of_type(self, cls):
+ def of_type(self, class_: _EntityType[_PT]) -> PropComparator[_PT]:
r"""Redefine this object in terms of a polymorphic subclass.
See :meth:`.PropComparator.of_type` for an example.
@@ -477,16 +632,16 @@ class Relationship(
"""
return Relationship.Comparator(
- self.property,
+ self.prop,
self._parententity,
adapt_to_entity=self._adapt_to_entity,
- of_type=cls,
+ of_type=class_,
extra_criteria=self._extra_criteria,
)
def and_(
self, *criteria: _ColumnExpressionArgument[bool]
- ) -> PropComparator[bool]:
+ ) -> PropComparator[Any]:
"""Add AND criteria.
See :meth:`.PropComparator.and_` for an example.
@@ -500,14 +655,14 @@ class Relationship(
)
return Relationship.Comparator(
- self.property,
+ self.prop,
self._parententity,
adapt_to_entity=self._adapt_to_entity,
of_type=self._of_type,
extra_criteria=self._extra_criteria + exprs,
)
- def in_(self, other):
+ def in_(self, other: Any) -> NoReturn:
"""Produce an IN clause - this is not implemented
for :func:`_orm.relationship`-based attributes at this time.
@@ -522,7 +677,7 @@ class Relationship(
# https://github.com/python/mypy/issues/4266
__hash__ = None # type: ignore
- def __eq__(self, other):
+ def __eq__(self, other: Any) -> ColumnElement[bool]: # type: ignore[override] # noqa: E501
"""Implement the ``==`` operator.
In a many-to-one context, such as::
@@ -559,7 +714,7 @@ class Relationship(
or many-to-many context produce a NOT EXISTS clause.
"""
- if isinstance(other, (util.NoneType, expression.Null)):
+ if other is None or isinstance(other, expression.Null):
if self.property.direction in [ONETOMANY, MANYTOMANY]:
return ~self._criterion_exists()
else:
@@ -585,8 +740,18 @@ class Relationship(
criterion: Optional[_ColumnExpressionArgument[bool]] = None,
**kwargs: Any,
) -> Exists:
+
+ where_criteria = (
+ coercions.expect(roles.WhereHavingRole, criterion)
+ if criterion is not None
+ else None
+ )
+
if getattr(self, "_of_type", None):
- info = inspect(self._of_type)
+ info: Optional[_InternalEntityType[Any]] = inspect(
+ self._of_type
+ )
+ assert info is not None
target_mapper, to_selectable, is_aliased_class = (
info.mapper,
info.selectable,
@@ -597,10 +762,10 @@ class Relationship(
single_crit = target_mapper._single_table_criterion
if single_crit is not None:
- if criterion is not None:
- criterion = single_crit & criterion
+ if where_criteria is not None:
+ where_criteria = single_crit & where_criteria
else:
- criterion = single_crit
+ where_criteria = single_crit
else:
is_aliased_class = False
to_selectable = None
@@ -624,10 +789,10 @@ class Relationship(
for k in kwargs:
crit = getattr(self.property.mapper.class_, k) == kwargs[k]
- if criterion is None:
- criterion = crit
+ if where_criteria is None:
+ where_criteria = crit
else:
- criterion = criterion & crit
+ where_criteria = where_criteria & crit
# annotate the *local* side of the join condition, in the case
# of pj + sj this is the full primaryjoin, in the case of just
@@ -638,24 +803,24 @@ class Relationship(
j = _orm_annotate(pj, exclude=self.property.remote_side)
if (
- criterion is not None
+ where_criteria is not None
and target_adapter
and not is_aliased_class
):
# limit this adapter to annotated only?
- criterion = target_adapter.traverse(criterion)
+ where_criteria = target_adapter.traverse(where_criteria)
# only have the "joined left side" of what we
# return be subject to Query adaption. The right
# side of it is used for an exists() subquery and
# should not correlate or otherwise reach out
# to anything in the enclosing query.
- if criterion is not None:
- criterion = criterion._annotate(
+ if where_criteria is not None:
+ where_criteria = where_criteria._annotate(
{"no_replacement_traverse": True}
)
- crit = j & sql.True_._ifnone(criterion)
+ crit = j & sql.True_._ifnone(where_criteria)
if secondary is not None:
ex = (
@@ -673,7 +838,11 @@ class Relationship(
)
return ex
- def any(self, criterion=None, **kwargs):
+ def any(
+ self,
+ criterion: Optional[_ColumnExpressionArgument[bool]] = None,
+ **kwargs: Any,
+ ) -> ColumnElement[bool]:
"""Produce an expression that tests a collection against
particular criterion, using EXISTS.
@@ -722,7 +891,11 @@ class Relationship(
return self._criterion_exists(criterion, **kwargs)
- def has(self, criterion=None, **kwargs):
+ def has(
+ self,
+ criterion: Optional[_ColumnExpressionArgument[bool]] = None,
+ **kwargs: Any,
+ ) -> ColumnElement[bool]:
"""Produce an expression that tests a scalar reference against
particular criterion, using EXISTS.
@@ -756,7 +929,9 @@ class Relationship(
)
return self._criterion_exists(criterion, **kwargs)
- def contains(self, other, **kwargs):
+ def contains(
+ self, other: _ColumnExpressionArgument[Any], **kwargs: Any
+ ) -> ColumnElement[bool]:
"""Return a simple expression that tests a collection for
containment of a particular item.
@@ -815,38 +990,45 @@ class Relationship(
kwargs may be ignored by this operator but are required for API
conformance.
"""
- if not self.property.uselist:
+ if not self.prop.uselist:
raise sa_exc.InvalidRequestError(
"'contains' not implemented for scalar "
"attributes. Use =="
)
- clause = self.property._optimized_compare(
+
+ clause = self.prop._optimized_compare(
other, adapt_source=self.adapter
)
- if self.property.secondaryjoin is not None:
+ if self.prop.secondaryjoin is not None:
clause.negation_clause = self.__negated_contains_or_equals(
other
)
return clause
- def __negated_contains_or_equals(self, other):
- if self.property.direction == MANYTOONE:
+ def __negated_contains_or_equals(
+ self, other: Any
+ ) -> ColumnElement[bool]:
+ if self.prop.direction == MANYTOONE:
state = attributes.instance_state(other)
- def state_bindparam(local_col, state, remote_col):
+ def state_bindparam(
+ local_col: ColumnElement[Any],
+ state: InstanceState[Any],
+ remote_col: ColumnElement[Any],
+ ) -> BindParameter[Any]:
dict_ = state.dict
return sql.bindparam(
local_col.key,
type_=local_col.type,
unique=True,
- callable_=self.property._get_attr_w_warn_on_none(
- self.property.mapper, state, dict_, remote_col
+ callable_=self.prop._get_attr_w_warn_on_none(
+ self.prop.mapper, state, dict_, remote_col
),
)
- def adapt(col):
+ def adapt(col: _CE) -> _CE:
if self.adapter:
return self.adapter(col)
else:
@@ -876,7 +1058,7 @@ class Relationship(
return ~self._criterion_exists(criterion)
- def __ne__(self, other):
+ def __ne__(self, other: Any) -> ColumnElement[bool]: # type: ignore[override] # noqa: E501
"""Implement the ``!=`` operator.
In a many-to-one context, such as::
@@ -915,7 +1097,7 @@ class Relationship(
or many-to-many context produce an EXISTS clause.
"""
- if isinstance(other, (util.NoneType, expression.Null)):
+ if other is None or isinstance(other, expression.Null):
if self.property.direction == MANYTOONE:
return _orm_annotate(
~self.property._optimized_compare(
@@ -934,12 +1116,10 @@ class Relationship(
else:
return _orm_annotate(self.__negated_contains_or_equals(other))
- def _memoized_attr_property(self):
+ def _memoized_attr_property(self) -> Relationship[_PT]:
self.prop.parent._check_configure()
return self.prop
- comparator: Comparator[_T]
-
def _with_parent(
self,
instance: object,
@@ -947,10 +1127,11 @@ class Relationship(
from_entity: Optional[_EntityType[Any]] = None,
) -> ColumnElement[bool]:
assert instance is not None
- adapt_source = None
+ adapt_source: Optional[_CoreAdapterProto] = None
if from_entity is not None:
- insp = inspect(from_entity)
- if insp.is_aliased_class:
+ insp: Optional[_InternalEntityType[Any]] = inspect(from_entity)
+ assert insp is not None
+ if insp_is_aliased_class(insp):
adapt_source = insp._adapter.adapt_clause
return self._optimized_compare(
instance,
@@ -961,11 +1142,11 @@ class Relationship(
def _optimized_compare(
self,
- state,
- value_is_parent=False,
- adapt_source=None,
- alias_secondary=True,
- ):
+ state: Any,
+ value_is_parent: bool = False,
+ adapt_source: Optional[_CoreAdapterProto] = None,
+ alias_secondary: bool = True,
+ ) -> ColumnElement[bool]:
if state is not None:
try:
state = inspect(state)
@@ -1005,7 +1186,7 @@ class Relationship(
dict_ = attributes.instance_dict(state.obj())
- def visit_bindparam(bindparam):
+ def visit_bindparam(bindparam: BindParameter[Any]) -> None:
if bindparam._identifying_key in bind_to_col:
bindparam.callable = self._get_attr_w_warn_on_none(
mapper,
@@ -1027,7 +1208,13 @@ class Relationship(
criterion = adapt_source(criterion)
return criterion
- def _get_attr_w_warn_on_none(self, mapper, state, dict_, column):
+ def _get_attr_w_warn_on_none(
+ self,
+ mapper: Mapper[Any],
+ state: InstanceState[Any],
+ dict_: _InstanceDict,
+ column: ColumnElement[Any],
+ ) -> Callable[[], Any]:
"""Create the callable that is used in a many-to-one expression.
E.g.::
@@ -1077,9 +1264,14 @@ class Relationship(
# this feature was added explicitly for use in this method.
state._track_last_known_value(prop.key)
- def _go():
- last_known = to_return = state._last_known_values[prop.key]
- existing_is_available = last_known is not attributes.NO_VALUE
+ lkv_fixed = state._last_known_values
+
+ def _go() -> Any:
+ assert lkv_fixed is not None
+ last_known = to_return = lkv_fixed[prop.key]
+ existing_is_available = (
+ last_known is not LoaderCallableStatus.NO_VALUE
+ )
# we support that the value may have changed. so here we
# try to get the most recent value including re-fetching.
@@ -1089,19 +1281,19 @@ class Relationship(
state,
dict_,
column,
- passive=attributes.PASSIVE_OFF
+ passive=PassiveFlag.PASSIVE_OFF
if state.persistent
- else attributes.PASSIVE_NO_FETCH ^ attributes.INIT_OK,
+ else PassiveFlag.PASSIVE_NO_FETCH ^ PassiveFlag.INIT_OK,
)
- if current_value is attributes.NEVER_SET:
+ if current_value is LoaderCallableStatus.NEVER_SET:
if not existing_is_available:
raise sa_exc.InvalidRequestError(
"Can't resolve value for column %s on object "
"%s; no value has been set for this column"
% (column, state_str(state))
)
- elif current_value is attributes.PASSIVE_NO_RESULT:
+ elif current_value is LoaderCallableStatus.PASSIVE_NO_RESULT:
if not existing_is_available:
raise sa_exc.InvalidRequestError(
"Can't resolve value for column %s on object "
@@ -1121,7 +1313,11 @@ class Relationship(
return _go
- def _lazy_none_clause(self, reverse_direction=False, adapt_source=None):
+ def _lazy_none_clause(
+ self,
+ reverse_direction: bool = False,
+ adapt_source: Optional[_CoreAdapterProto] = None,
+ ) -> ColumnElement[bool]:
if not reverse_direction:
criterion, bind_to_col = (
self._lazy_strategy._lazywhere,
@@ -1139,20 +1335,20 @@ class Relationship(
criterion = adapt_source(criterion)
return criterion
- def __str__(self):
+ def __str__(self) -> str:
return str(self.parent.class_.__name__) + "." + self.key
def merge(
self,
- session,
- source_state,
- source_dict,
- dest_state,
- dest_dict,
- load,
- _recursive,
- _resolve_conflict_map,
- ):
+ session: Session,
+ source_state: InstanceState[Any],
+ source_dict: _InstanceDict,
+ dest_state: InstanceState[Any],
+ dest_dict: _InstanceDict,
+ load: bool,
+ _recursive: Dict[Any, object],
+ _resolve_conflict_map: Dict[_IdentityKeyType[Any], object],
+ ) -> None:
if load:
for r in self._reverse_property:
@@ -1167,6 +1363,8 @@ class Relationship(
if self.uselist:
impl = source_state.get_impl(self.key)
+
+ assert is_has_collection_adapter(impl)
instances_iterable = impl.get_collection(source_state, source_dict)
# if this is a CollectionAttributeImpl, then empty should
@@ -1204,9 +1402,9 @@ class Relationship(
for c in dest_list:
coll.append_without_event(c)
else:
- dest_state.get_impl(self.key).set(
- dest_state, dest_dict, dest_list, _adapt=False
- )
+ dest_impl = dest_state.get_impl(self.key)
+ assert is_has_collection_adapter(dest_impl)
+ dest_impl.set(dest_state, dest_dict, dest_list, _adapt=False)
else:
current = source_dict[self.key]
if current is not None:
@@ -1231,8 +1429,12 @@ class Relationship(
)
def _value_as_iterable(
- self, state, dict_, key, passive=attributes.PASSIVE_OFF
- ):
+ self,
+ state: InstanceState[_O],
+ dict_: _InstanceDict,
+ key: str,
+ passive: PassiveFlag = PassiveFlag.PASSIVE_OFF,
+ ) -> Sequence[Tuple[InstanceState[_O], _O]]:
"""Return a list of tuples (state, obj) for the given
key.
@@ -1241,9 +1443,9 @@ class Relationship(
impl = state.manager[key].impl
x = impl.get(state, dict_, passive=passive)
- if x is attributes.PASSIVE_NO_RESULT or x is None:
+ if x is LoaderCallableStatus.PASSIVE_NO_RESULT or x is None:
return []
- elif hasattr(impl, "get_collection"):
+ elif is_has_collection_adapter(impl):
return [
(attributes.instance_state(o), o)
for o in impl.get_collection(state, dict_, x, passive=passive)
@@ -1252,19 +1454,23 @@ class Relationship(
return [(attributes.instance_state(x), x)]
def cascade_iterator(
- self, type_, state, dict_, visited_states, halt_on=None
- ):
+ self,
+ type_: str,
+ state: InstanceState[Any],
+ dict_: _InstanceDict,
+ visited_states: Set[InstanceState[Any]],
+ halt_on: Optional[Callable[[InstanceState[Any]], bool]] = None,
+ ) -> Iterator[Tuple[Any, Mapper[Any], InstanceState[Any], _InstanceDict]]:
# assert type_ in self._cascade
# only actively lazy load on the 'delete' cascade
if type_ != "delete" or self.passive_deletes:
- passive = attributes.PASSIVE_NO_INITIALIZE
+ passive = PassiveFlag.PASSIVE_NO_INITIALIZE
else:
- passive = attributes.PASSIVE_OFF
+ passive = PassiveFlag.PASSIVE_OFF
if type_ == "save-update":
tuples = state.manager[self.key].impl.get_all_pending(state, dict_)
-
else:
tuples = self._value_as_iterable(
state, dict_, self.key, passive=passive
@@ -1285,6 +1491,7 @@ class Relationship(
# see [ticket:2229]
continue
+ assert instance_state is not None
instance_dict = attributes.instance_dict(c)
if halt_on and halt_on(instance_state):
@@ -1308,14 +1515,16 @@ class Relationship(
yield c, instance_mapper, instance_state, instance_dict
@property
- def _effective_sync_backref(self):
+ def _effective_sync_backref(self) -> bool:
if self.viewonly:
return False
else:
return self.sync_backref is not False
@staticmethod
- def _check_sync_backref(rel_a, rel_b):
+ def _check_sync_backref(
+ rel_a: Relationship[Any], rel_b: Relationship[Any]
+ ) -> None:
if rel_a.viewonly and rel_b.sync_backref:
raise sa_exc.InvalidRequestError(
"Relationship %s cannot specify sync_backref=True since %s "
@@ -1328,7 +1537,7 @@ class Relationship(
):
rel_b.sync_backref = False
- def _add_reverse_property(self, key):
+ def _add_reverse_property(self, key: str) -> None:
other = self.mapper.get_property(key, _configure_mappers=False)
if not isinstance(other, Relationship):
raise sa_exc.InvalidRequestError(
@@ -1361,7 +1570,8 @@ class Relationship(
)
if (
- self.direction in (ONETOMANY, MANYTOONE)
+ other._configure_started
+ and self.direction in (ONETOMANY, MANYTOONE)
and self.direction == other.direction
):
raise sa_exc.ArgumentError(
@@ -1372,7 +1582,7 @@ class Relationship(
)
@util.memoized_property
- def entity(self) -> Union["Mapper", "AliasedInsp"]:
+ def entity(self) -> _InternalEntityType[_T]:
"""Return the target mapped entity, which is an inspect() of the
class or aliased class that is referred towards.
@@ -1388,7 +1598,7 @@ class Relationship(
"""
return self.entity.mapper
- def do_init(self):
+ def do_init(self) -> None:
self._check_conflicts()
self._process_dependent_arguments()
self._setup_entity()
@@ -1399,14 +1609,16 @@ class Relationship(
self._generate_backref()
self._join_condition._warn_for_conflicting_sync_targets()
super(Relationship, self).do_init()
- self._lazy_strategy = self._get_strategy((("lazy", "select"),))
+ self._lazy_strategy = cast(
+ "LazyLoader", self._get_strategy((("lazy", "select"),))
+ )
- def _setup_registry_dependencies(self):
+ def _setup_registry_dependencies(self) -> None:
self.parent.mapper.registry._set_depends_on(
self.entity.mapper.registry
)
- def _process_dependent_arguments(self):
+ def _process_dependent_arguments(self) -> None:
"""Convert incoming configuration arguments to their
proper form.
@@ -1417,78 +1629,80 @@ class Relationship(
# accept callables for other attributes which may require
# deferred initialization. This technique is used
# by declarative "string configs" and some recipes.
+ init_args = self._init_args
+
for attr in (
"order_by",
"primaryjoin",
"secondaryjoin",
"secondary",
- "_user_defined_foreign_keys",
+ "foreign_keys",
"remote_side",
):
- attr_value = getattr(self, attr)
-
- if isinstance(attr_value, str):
- setattr(
- self,
- attr,
- self._clsregistry_resolve_arg(
- attr_value, favor_tables=attr == "secondary"
- )(),
- )
- elif callable(attr_value) and not _is_mapped_class(attr_value):
- setattr(self, attr, attr_value())
+
+ rel_arg = getattr(init_args, attr)
+
+ rel_arg._resolve_against_registry(self._clsregistry_resolvers[1])
# remove "annotations" which are present if mapped class
# descriptors are used to create the join expression.
for attr in "primaryjoin", "secondaryjoin":
- val = getattr(self, attr)
+ rel_arg = getattr(init_args, attr)
+ val = rel_arg.resolved
if val is not None:
- setattr(
- self,
- attr,
- _orm_deannotate(
- coercions.expect(
- roles.ColumnArgumentRole, val, argname=attr
- )
- ),
+ rel_arg.resolved = _orm_deannotate(
+ coercions.expect(
+ roles.ColumnArgumentRole, val, argname=attr
+ )
)
- if self.secondary is not None and _is_mapped_class(self.secondary):
+ secondary = init_args.secondary.resolved
+ if secondary is not None and _is_mapped_class(secondary):
raise sa_exc.ArgumentError(
"secondary argument %s passed to to relationship() %s must "
"be a Table object or other FROM clause; can't send a mapped "
"class directly as rows in 'secondary' are persisted "
"independently of a class that is mapped "
- "to that same table." % (self.secondary, self)
+ "to that same table." % (secondary, self)
)
# ensure expressions in self.order_by, foreign_keys,
# remote_side are all columns, not strings.
- if self.order_by is not False and self.order_by is not None:
+ if (
+ init_args.order_by.resolved is not False
+ and init_args.order_by.resolved is not None
+ ):
self.order_by = tuple(
coercions.expect(
roles.ColumnArgumentRole, x, argname="order_by"
)
- for x in util.to_list(self.order_by)
+ for x in util.to_list(init_args.order_by.resolved)
)
+ else:
+ self.order_by = False
self._user_defined_foreign_keys = util.column_set(
coercions.expect(
roles.ColumnArgumentRole, x, argname="foreign_keys"
)
- for x in util.to_column_set(self._user_defined_foreign_keys)
+ for x in util.to_column_set(init_args.foreign_keys.resolved)
)
self.remote_side = util.column_set(
coercions.expect(
roles.ColumnArgumentRole, x, argname="remote_side"
)
- for x in util.to_column_set(self.remote_side)
+ for x in util.to_column_set(init_args.remote_side.resolved)
)
def declarative_scan(
- self, registry, cls, key, annotation, is_dataclass_field
- ):
+ self,
+ registry: _RegistryType,
+ cls: Type[Any],
+ key: str,
+ annotation: Optional[_AnnotationScanType],
+ is_dataclass_field: bool,
+ ) -> None:
argument = _extract_mapped_subtype(
annotation,
cls,
@@ -1502,17 +1716,19 @@ class Relationship(
if hasattr(argument, "__origin__"):
- collection_class = argument.__origin__
+ collection_class = argument.__origin__ # type: ignore
if issubclass(collection_class, abc.Collection):
if self.collection_class is None:
self.collection_class = collection_class
else:
self.uselist = False
- if argument.__args__:
- if issubclass(argument.__origin__, typing.Mapping):
- type_arg = argument.__args__[1]
+ if argument.__args__: # type: ignore
+ if issubclass(
+ argument.__origin__, typing.Mapping # type: ignore
+ ):
+ type_arg = argument.__args__[1] # type: ignore
else:
- type_arg = argument.__args__[0]
+ type_arg = argument.__args__[0] # type: ignore
if hasattr(type_arg, "__forward_arg__"):
str_argument = type_arg.__forward_arg__
argument = str_argument
@@ -1523,12 +1739,12 @@ class Relationship(
f"Generic alias {argument} requires an argument"
)
elif hasattr(argument, "__forward_arg__"):
- argument = argument.__forward_arg__
+ argument = argument.__forward_arg__ # type: ignore
self.argument = argument
@util.preload_module("sqlalchemy.orm.mapper")
- def _setup_entity(self, __argument=None):
+ def _setup_entity(self, __argument: Any = None) -> None:
if "entity" in self.__dict__:
return
@@ -1539,42 +1755,51 @@ class Relationship(
else:
argument = self.argument
+ resolved_argument: _ExternalEntityType[Any]
+
if isinstance(argument, str):
- argument = self._clsregistry_resolve_name(argument)()
+ # we might want to cleanup clsregistry API to make this
+ # more straightforward
+ resolved_argument = cast(
+ "_ExternalEntityType[Any]",
+ self._clsregistry_resolve_name(argument)(),
+ )
elif callable(argument) and not isinstance(
argument, (type, mapperlib.Mapper)
):
- argument = argument()
+ resolved_argument = argument()
else:
- argument = argument
+ resolved_argument = argument
- if isinstance(argument, type):
- entity = class_mapper(argument, configure=False)
+ entity: _InternalEntityType[Any]
+
+ if isinstance(resolved_argument, type):
+ entity = class_mapper(resolved_argument, configure=False)
else:
try:
- entity = inspect(argument)
+ entity = inspect(resolved_argument)
except sa_exc.NoInspectionAvailable:
- entity = None
+ entity = None # type: ignore
if not hasattr(entity, "mapper"):
raise sa_exc.ArgumentError(
"relationship '%s' expects "
"a class or a mapper argument (received: %s)"
- % (self.key, type(argument))
+ % (self.key, type(resolved_argument))
)
self.entity = entity # type: ignore
self.target = self.entity.persist_selectable
- def _setup_join_conditions(self):
+ def _setup_join_conditions(self) -> None:
self._join_condition = jc = JoinCondition(
parent_persist_selectable=self.parent.persist_selectable,
child_persist_selectable=self.entity.persist_selectable,
parent_local_selectable=self.parent.local_table,
child_local_selectable=self.entity.local_table,
- primaryjoin=self.primaryjoin,
- secondary=self.secondary,
- secondaryjoin=self.secondaryjoin,
+ primaryjoin=self._init_args.primaryjoin.resolved,
+ secondary=self._init_args.secondary.resolved,
+ secondaryjoin=self._init_args.secondaryjoin.resolved,
parent_equivalents=self.parent._equivalent_columns,
child_equivalents=self.mapper._equivalent_columns,
consider_as_foreign_keys=self._user_defined_foreign_keys,
@@ -1587,6 +1812,7 @@ class Relationship(
)
self.primaryjoin = jc.primaryjoin
self.secondaryjoin = jc.secondaryjoin
+ self.secondary = jc.secondary
self.direction = jc.direction
self.local_remote_pairs = jc.local_remote_pairs
self.remote_side = jc.remote_columns
@@ -1596,21 +1822,30 @@ class Relationship(
self.secondary_synchronize_pairs = jc.secondary_synchronize_pairs
@property
- def _clsregistry_resolve_arg(self):
+ def _clsregistry_resolve_arg(
+ self,
+ ) -> Callable[[str, bool], _class_resolver]:
return self._clsregistry_resolvers[1]
@property
- def _clsregistry_resolve_name(self):
+ def _clsregistry_resolve_name(
+ self,
+ ) -> Callable[[str], Callable[[], Union[Type[Any], Table, _ModNS]]]:
return self._clsregistry_resolvers[0]
@util.memoized_property
@util.preload_module("sqlalchemy.orm.clsregistry")
- def _clsregistry_resolvers(self):
+ def _clsregistry_resolvers(
+ self,
+ ) -> Tuple[
+ Callable[[str], Callable[[], Union[Type[Any], Table, _ModNS]]],
+ Callable[[str, bool], _class_resolver],
+ ]:
_resolver = util.preloaded.orm_clsregistry._resolver
return _resolver(self.parent.class_, self)
- def _check_conflicts(self):
+ def _check_conflicts(self) -> None:
"""Test that this relationship is legal, warn about
inheritance conflicts."""
if self.parent.non_primary and not class_mapper(
@@ -1637,10 +1872,10 @@ class Relationship(
return self._cascade
@cascade.setter
- def cascade(self, cascade: Union[str, CascadeOptions]):
+ def cascade(self, cascade: Union[str, CascadeOptions]) -> None:
self._set_cascade(cascade)
- def _set_cascade(self, cascade_arg: Union[str, CascadeOptions]):
+ def _set_cascade(self, cascade_arg: Union[str, CascadeOptions]) -> None:
cascade = CascadeOptions(cascade_arg)
if self.viewonly:
@@ -1655,7 +1890,7 @@ class Relationship(
if self._dependency_processor:
self._dependency_processor.cascade = cascade
- def _check_cascade_settings(self, cascade):
+ def _check_cascade_settings(self, cascade: CascadeOptions) -> None:
if (
cascade.delete_orphan
and not self.single_parent
@@ -1699,7 +1934,7 @@ class Relationship(
(self.key, self.parent.class_)
)
- def _persists_for(self, mapper):
+ def _persists_for(self, mapper: Mapper[Any]) -> bool:
"""Return True if this property will persist values on behalf
of the given mapper.
@@ -1710,16 +1945,15 @@ class Relationship(
and mapper.relationships[self.key] is self
)
- def _columns_are_mapped(self, *cols):
+ def _columns_are_mapped(self, *cols: ColumnElement[Any]) -> bool:
"""Return True if all columns in the given collection are
mapped by the tables referenced by this :class:`.Relationship`.
"""
+
+ secondary = self._init_args.secondary.resolved
for c in cols:
- if (
- self.secondary is not None
- and self.secondary.c.contains_column(c)
- ):
+ if secondary is not None and secondary.c.contains_column(c):
continue
if not self.parent.persist_selectable.c.contains_column(
c
@@ -1727,13 +1961,14 @@ class Relationship(
return False
return True
- def _generate_backref(self):
+ def _generate_backref(self) -> None:
"""Interpret the 'backref' instruction to create a
:func:`_orm.relationship` complementary to this one."""
if self.parent.non_primary:
return
if self.backref is not None and not self.back_populates:
+ kwargs: Dict[str, Any]
if isinstance(self.backref, str):
backref_key, kwargs = self.backref, {}
else:
@@ -1805,7 +2040,7 @@ class Relationship(
self._add_reverse_property(self.back_populates)
@util.preload_module("sqlalchemy.orm.dependency")
- def _post_init(self):
+ def _post_init(self) -> None:
dependency = util.preloaded.orm_dependency
if self.uselist is None:
@@ -1816,7 +2051,7 @@ class Relationship(
)(self)
@util.memoized_property
- def _use_get(self):
+ def _use_get(self) -> bool:
"""memoize the 'use_get' attribute of this RelationshipLoader's
lazyloader."""
@@ -1824,18 +2059,25 @@ class Relationship(
return strategy.use_get
@util.memoized_property
- def _is_self_referential(self):
+ def _is_self_referential(self) -> bool:
return self.mapper.common_parent(self.parent)
def _create_joins(
self,
- source_polymorphic=False,
- source_selectable=None,
- dest_selectable=None,
- of_type_entity=None,
- alias_secondary=False,
- extra_criteria=(),
- ):
+ source_polymorphic: bool = False,
+ source_selectable: Optional[FromClause] = None,
+ dest_selectable: Optional[FromClause] = None,
+ of_type_entity: Optional[_InternalEntityType[Any]] = None,
+ alias_secondary: bool = False,
+ extra_criteria: Tuple[ColumnElement[bool], ...] = (),
+ ) -> Tuple[
+ ColumnElement[bool],
+ Optional[ColumnElement[bool]],
+ FromClause,
+ FromClause,
+ Optional[FromClause],
+ Optional[ClauseAdapter],
+ ]:
aliased = False
@@ -1905,38 +2147,56 @@ class Relationship(
)
-def _annotate_columns(element, annotations):
- def clone(elem):
+def _annotate_columns(element: _CE, annotations: _AnnotationDict) -> _CE:
+ def clone(elem: _CE) -> _CE:
if isinstance(elem, expression.ColumnClause):
- elem = elem._annotate(annotations.copy())
+ elem = elem._annotate(annotations.copy()) # type: ignore
elem._copy_internals(clone=clone)
return elem
if element is not None:
element = clone(element)
- clone = None # remove gc cycles
+ clone = None # type: ignore # remove gc cycles
return element
class JoinCondition:
+
+ primaryjoin_initial: Optional[ColumnElement[bool]]
+ primaryjoin: ColumnElement[bool]
+ secondaryjoin: Optional[ColumnElement[bool]]
+ secondary: Optional[FromClause]
+ prop: Relationship[Any]
+
+ synchronize_pairs: _ColumnPairs
+ secondary_synchronize_pairs: _ColumnPairs
+ direction: RelationshipDirection
+
+ parent_persist_selectable: FromClause
+ child_persist_selectable: FromClause
+ parent_local_selectable: FromClause
+ child_local_selectable: FromClause
+
+ _local_remote_pairs: Optional[_ColumnPairs]
+
def __init__(
self,
- parent_persist_selectable,
- child_persist_selectable,
- parent_local_selectable,
- child_local_selectable,
- primaryjoin=None,
- secondary=None,
- secondaryjoin=None,
- parent_equivalents=None,
- child_equivalents=None,
- consider_as_foreign_keys=None,
- local_remote_pairs=None,
- remote_side=None,
- self_referential=False,
- prop=None,
- support_sync=True,
- can_be_synced_fn=lambda *c: True,
+ parent_persist_selectable: FromClause,
+ child_persist_selectable: FromClause,
+ parent_local_selectable: FromClause,
+ child_local_selectable: FromClause,
+ primaryjoin: Optional[ColumnElement[bool]] = None,
+ secondary: Optional[FromClause] = None,
+ secondaryjoin: Optional[ColumnElement[bool]] = None,
+ parent_equivalents: Optional[_EquivalentColumnMap] = None,
+ child_equivalents: Optional[_EquivalentColumnMap] = None,
+ consider_as_foreign_keys: Any = None,
+ local_remote_pairs: Optional[_ColumnPairs] = None,
+ remote_side: Any = None,
+ self_referential: Any = False,
+ prop: Optional[Relationship[Any]] = None,
+ support_sync: bool = True,
+ can_be_synced_fn: Callable[..., bool] = lambda *c: True,
):
self.parent_persist_selectable = parent_persist_selectable
self.parent_local_selectable = parent_local_selectable
@@ -1944,7 +2204,7 @@ class JoinCondition:
self.child_local_selectable = child_local_selectable
self.parent_equivalents = parent_equivalents
self.child_equivalents = child_equivalents
- self.primaryjoin = primaryjoin
+ self.primaryjoin_initial = primaryjoin
self.secondaryjoin = secondaryjoin
self.secondary = secondary
self.consider_as_foreign_keys = consider_as_foreign_keys
@@ -1954,7 +2214,10 @@ class JoinCondition:
self.self_referential = self_referential
self.support_sync = support_sync
self.can_be_synced_fn = can_be_synced_fn
+
self._determine_joins()
+ assert self.primaryjoin is not None
+
self._sanitize_joins()
self._annotate_fks()
self._annotate_remote()
@@ -1968,7 +2231,7 @@ class JoinCondition:
self._check_remote_side()
self._log_joins()
- def _log_joins(self):
+ def _log_joins(self) -> None:
if self.prop is None:
return
log = self.prop.logger
@@ -2008,7 +2271,7 @@ class JoinCondition:
)
log.info("%s relationship direction %s", self.prop, self.direction)
- def _sanitize_joins(self):
+ def _sanitize_joins(self) -> None:
"""remove the parententity annotation from our join conditions which
can leak in here based on some declarative patterns and maybe others.
@@ -2026,7 +2289,7 @@ class JoinCondition:
self.secondaryjoin, values=("parententity", "proxy_key")
)
- def _determine_joins(self):
+ def _determine_joins(self) -> None:
"""Determine the 'primaryjoin' and 'secondaryjoin' attributes,
if not passed to the constructor already.
@@ -2056,21 +2319,25 @@ class JoinCondition:
a_subset=self.child_local_selectable,
consider_as_foreign_keys=consider_as_foreign_keys,
)
- if self.primaryjoin is None:
+ if self.primaryjoin_initial is None:
self.primaryjoin = join_condition(
self.parent_persist_selectable,
self.secondary,
a_subset=self.parent_local_selectable,
consider_as_foreign_keys=consider_as_foreign_keys,
)
+ else:
+ self.primaryjoin = self.primaryjoin_initial
else:
- if self.primaryjoin is None:
+ if self.primaryjoin_initial is None:
self.primaryjoin = join_condition(
self.parent_persist_selectable,
self.child_persist_selectable,
a_subset=self.parent_local_selectable,
consider_as_foreign_keys=consider_as_foreign_keys,
)
+ else:
+ self.primaryjoin = self.primaryjoin_initial
except sa_exc.NoForeignKeysError as nfe:
if self.secondary is not None:
raise sa_exc.NoForeignKeysError(
@@ -2118,15 +2385,16 @@ class JoinCondition:
) from afe
@property
- def primaryjoin_minus_local(self):
+ def primaryjoin_minus_local(self) -> ColumnElement[bool]:
return _deep_deannotate(self.primaryjoin, values=("local", "remote"))
@property
- def secondaryjoin_minus_local(self):
+ def secondaryjoin_minus_local(self) -> ColumnElement[bool]:
+ assert self.secondaryjoin is not None
return _deep_deannotate(self.secondaryjoin, values=("local", "remote"))
@util.memoized_property
- def primaryjoin_reverse_remote(self):
+ def primaryjoin_reverse_remote(self) -> ColumnElement[bool]:
"""Return the primaryjoin condition suitable for the
"reverse" direction.
@@ -2138,7 +2406,7 @@ class JoinCondition:
"""
if self._has_remote_annotations:
- def replace(element):
+ def replace(element: _CE, **kw: Any) -> Optional[_CE]:
if "remote" in element._annotations:
v = dict(element._annotations)
del v["remote"]
@@ -2150,6 +2418,8 @@ class JoinCondition:
v["remote"] = True
return element._with_annotations(v)
+ return None
+
return visitors.replacement_traverse(self.primaryjoin, {}, replace)
else:
if self._has_foreign_annotations:
@@ -2160,7 +2430,7 @@ class JoinCondition:
else:
return _deep_deannotate(self.primaryjoin)
- def _has_annotation(self, clause, annotation):
+ def _has_annotation(self, clause: ClauseElement, annotation: str) -> bool:
for col in visitors.iterate(clause, {}):
if annotation in col._annotations:
return True
@@ -2168,14 +2438,14 @@ class JoinCondition:
return False
@util.memoized_property
- def _has_foreign_annotations(self):
+ def _has_foreign_annotations(self) -> bool:
return self._has_annotation(self.primaryjoin, "foreign")
@util.memoized_property
- def _has_remote_annotations(self):
+ def _has_remote_annotations(self) -> bool:
return self._has_annotation(self.primaryjoin, "remote")
- def _annotate_fks(self):
+ def _annotate_fks(self) -> None:
"""Annotate the primaryjoin and secondaryjoin
structures with 'foreign' annotations marking columns
considered as foreign.
@@ -2189,10 +2459,11 @@ class JoinCondition:
else:
self._annotate_present_fks()
- def _annotate_from_fk_list(self):
- def check_fk(col):
- if col in self.consider_as_foreign_keys:
- return col._annotate({"foreign": True})
+ def _annotate_from_fk_list(self) -> None:
+ def check_fk(element: _CE, **kw: Any) -> Optional[_CE]:
+ if element in self.consider_as_foreign_keys:
+ return element._annotate({"foreign": True})
+ return None
self.primaryjoin = visitors.replacement_traverse(
self.primaryjoin, {}, check_fk
@@ -2202,13 +2473,15 @@ class JoinCondition:
self.secondaryjoin, {}, check_fk
)
- def _annotate_present_fks(self):
+ def _annotate_present_fks(self) -> None:
if self.secondary is not None:
secondarycols = util.column_set(self.secondary.c)
else:
secondarycols = set()
- def is_foreign(a, b):
+ def is_foreign(
+ a: ColumnElement[Any], b: ColumnElement[Any]
+ ) -> Optional[ColumnElement[Any]]:
if isinstance(a, schema.Column) and isinstance(b, schema.Column):
if a.references(b):
return a
@@ -2221,7 +2494,9 @@ class JoinCondition:
elif b in secondarycols and a not in secondarycols:
return b
- def visit_binary(binary):
+ return None
+
+ def visit_binary(binary: BinaryExpression[Any]) -> None:
if not isinstance(
binary.left, sql.ColumnElement
) or not isinstance(binary.right, sql.ColumnElement):
@@ -2248,16 +2523,17 @@ class JoinCondition:
self.secondaryjoin, {}, {"binary": visit_binary}
)
- def _refers_to_parent_table(self):
+ def _refers_to_parent_table(self) -> bool:
"""Return True if the join condition contains column
comparisons where both columns are in both tables.
"""
pt = self.parent_persist_selectable
mt = self.child_persist_selectable
- result = [False]
+ result = False
- def visit_binary(binary):
+ def visit_binary(binary: BinaryExpression[Any]) -> None:
+ nonlocal result
c, f = binary.left, binary.right
if (
isinstance(c, expression.ColumnClause)
@@ -2267,19 +2543,19 @@ class JoinCondition:
and mt.is_derived_from(c.table)
and mt.is_derived_from(f.table)
):
- result[0] = True
+ result = True
visitors.traverse(self.primaryjoin, {}, {"binary": visit_binary})
- return result[0]
+ return result
- def _tables_overlap(self):
+ def _tables_overlap(self) -> bool:
"""Return True if parent/child tables have some overlap."""
return selectables_overlap(
self.parent_persist_selectable, self.child_persist_selectable
)
- def _annotate_remote(self):
+ def _annotate_remote(self) -> None:
"""Annotate the primaryjoin and secondaryjoin
structures with 'remote' annotations marking columns
considered as part of the 'remote' side.
@@ -2301,30 +2577,38 @@ class JoinCondition:
else:
self._annotate_remote_distinct_selectables()
- def _annotate_remote_secondary(self):
+ def _annotate_remote_secondary(self) -> None:
"""annotate 'remote' in primaryjoin, secondaryjoin
when 'secondary' is present.
"""
- def repl(element):
- if self.secondary.c.contains_column(element):
+ assert self.secondary is not None
+ fixed_secondary = self.secondary
+
+ def repl(element: _CE, **kw: Any) -> Optional[_CE]:
+ if fixed_secondary.c.contains_column(element):
return element._annotate({"remote": True})
+ return None
self.primaryjoin = visitors.replacement_traverse(
self.primaryjoin, {}, repl
)
+
+ assert self.secondaryjoin is not None
self.secondaryjoin = visitors.replacement_traverse(
self.secondaryjoin, {}, repl
)
- def _annotate_selfref(self, fn, remote_side_given):
+ def _annotate_selfref(
+ self, fn: Callable[[ColumnElement[Any]], bool], remote_side_given: bool
+ ) -> None:
"""annotate 'remote' in primaryjoin, secondaryjoin
when the relationship is detected as self-referential.
"""
- def visit_binary(binary):
+ def visit_binary(binary: BinaryExpression[Any]) -> None:
equated = binary.left.compare(binary.right)
if isinstance(binary.left, expression.ColumnClause) and isinstance(
binary.right, expression.ColumnClause
@@ -2341,7 +2625,7 @@ class JoinCondition:
self.primaryjoin, {}, {"binary": visit_binary}
)
- def _annotate_remote_from_args(self):
+ def _annotate_remote_from_args(self) -> None:
"""annotate 'remote' in primaryjoin, secondaryjoin
when the 'remote_side' or '_local_remote_pairs'
arguments are used.
@@ -2363,17 +2647,18 @@ class JoinCondition:
self._annotate_selfref(lambda col: col in remote_side, True)
else:
- def repl(element):
+ def repl(element: _CE, **kw: Any) -> Optional[_CE]:
# use set() to avoid generating ``__eq__()`` expressions
# against each element
if element in set(remote_side):
return element._annotate({"remote": True})
+ return None
self.primaryjoin = visitors.replacement_traverse(
self.primaryjoin, {}, repl
)
- def _annotate_remote_with_overlap(self):
+ def _annotate_remote_with_overlap(self) -> None:
"""annotate 'remote' in primaryjoin, secondaryjoin
when the parent/child tables have some set of
tables in common, though is not a fully self-referential
@@ -2381,7 +2666,7 @@ class JoinCondition:
"""
- def visit_binary(binary):
+ def visit_binary(binary: BinaryExpression[Any]) -> None:
binary.left, binary.right = proc_left_right(
binary.left, binary.right
)
@@ -2393,7 +2678,9 @@ class JoinCondition:
self.prop is not None and self.prop.mapper is not self.prop.parent
)
- def proc_left_right(left, right):
+ def proc_left_right(
+ left: ColumnElement[Any], right: ColumnElement[Any]
+ ) -> Tuple[ColumnElement[Any], ColumnElement[Any]]:
if isinstance(left, expression.ColumnClause) and isinstance(
right, expression.ColumnClause
):
@@ -2420,32 +2707,33 @@ class JoinCondition:
self.primaryjoin, {}, {"binary": visit_binary}
)
- def _annotate_remote_distinct_selectables(self):
+ def _annotate_remote_distinct_selectables(self) -> None:
"""annotate 'remote' in primaryjoin, secondaryjoin
when the parent/child tables are entirely
separate.
"""
- def repl(element):
+ def repl(element: _CE, **kw: Any) -> Optional[_CE]:
if self.child_persist_selectable.c.contains_column(element) and (
not self.parent_local_selectable.c.contains_column(element)
or self.child_local_selectable.c.contains_column(element)
):
return element._annotate({"remote": True})
+ return None
self.primaryjoin = visitors.replacement_traverse(
self.primaryjoin, {}, repl
)
- def _warn_non_column_elements(self):
+ def _warn_non_column_elements(self) -> None:
util.warn(
"Non-simple column elements in primary "
"join condition for property %s - consider using "
"remote() annotations to mark the remote side." % self.prop
)
- def _annotate_local(self):
+ def _annotate_local(self) -> None:
"""Annotate the primaryjoin and secondaryjoin
structures with 'local' annotations.
@@ -2466,29 +2754,31 @@ class JoinCondition:
else:
local_side = util.column_set(self.parent_persist_selectable.c)
- def locals_(elem):
- if "remote" not in elem._annotations and elem in local_side:
- return elem._annotate({"local": True})
+ def locals_(element: _CE, **kw: Any) -> Optional[_CE]:
+ if "remote" not in element._annotations and element in local_side:
+ return element._annotate({"local": True})
+ return None
self.primaryjoin = visitors.replacement_traverse(
self.primaryjoin, {}, locals_
)
- def _annotate_parentmapper(self):
+ def _annotate_parentmapper(self) -> None:
if self.prop is None:
return
- def parentmappers_(elem):
- if "remote" in elem._annotations:
- return elem._annotate({"parentmapper": self.prop.mapper})
- elif "local" in elem._annotations:
- return elem._annotate({"parentmapper": self.prop.parent})
+ def parentmappers_(element: _CE, **kw: Any) -> Optional[_CE]:
+ if "remote" in element._annotations:
+ return element._annotate({"parentmapper": self.prop.mapper})
+ elif "local" in element._annotations:
+ return element._annotate({"parentmapper": self.prop.parent})
+ return None
self.primaryjoin = visitors.replacement_traverse(
self.primaryjoin, {}, parentmappers_
)
- def _check_remote_side(self):
+ def _check_remote_side(self) -> None:
if not self.local_remote_pairs:
raise sa_exc.ArgumentError(
"Relationship %s could "
@@ -2501,7 +2791,9 @@ class JoinCondition:
"the relationship." % (self.prop,)
)
- def _check_foreign_cols(self, join_condition, primary):
+ def _check_foreign_cols(
+ self, join_condition: ColumnElement[bool], primary: bool
+ ) -> None:
"""Check the foreign key columns collected and emit error
messages."""
@@ -2567,7 +2859,7 @@ class JoinCondition:
)
raise sa_exc.ArgumentError(err)
- def _determine_direction(self):
+ def _determine_direction(self) -> None:
"""Determine if this relationship is one to many, many to one,
many to many.
@@ -2651,7 +2943,9 @@ class JoinCondition:
"nor the child's mapped tables" % self.prop
)
- def _deannotate_pairs(self, collection):
+ def _deannotate_pairs(
+ self, collection: _ColumnPairIterable
+ ) -> _MutableColumnPairs:
"""provide deannotation for the various lists of
pairs, so that using them in hashes doesn't incur
high-overhead __eq__() comparisons against
@@ -2660,13 +2954,22 @@ class JoinCondition:
"""
return [(x._deannotate(), y._deannotate()) for x, y in collection]
- def _setup_pairs(self):
- sync_pairs = []
- lrp = util.OrderedSet([])
- secondary_sync_pairs = []
-
- def go(joincond, collection):
- def visit_binary(binary, left, right):
+ def _setup_pairs(self) -> None:
+ sync_pairs: _MutableColumnPairs = []
+ lrp: util.OrderedSet[
+ Tuple[ColumnElement[Any], ColumnElement[Any]]
+ ] = util.OrderedSet([])
+ secondary_sync_pairs: _MutableColumnPairs = []
+
+ def go(
+ joincond: ColumnElement[bool],
+ collection: _MutableColumnPairs,
+ ) -> None:
+ def visit_binary(
+ binary: BinaryExpression[Any],
+ left: ColumnElement[Any],
+ right: ColumnElement[Any],
+ ) -> None:
if (
"remote" in right._annotations
and "remote" not in left._annotations
@@ -2703,9 +3006,12 @@ class JoinCondition:
secondary_sync_pairs
)
- _track_overlapping_sync_targets = weakref.WeakKeyDictionary()
+ _track_overlapping_sync_targets: weakref.WeakKeyDictionary[
+ ColumnElement[Any],
+ weakref.WeakKeyDictionary[Relationship[Any], ColumnElement[Any]],
+ ] = weakref.WeakKeyDictionary()
- def _warn_for_conflicting_sync_targets(self):
+ def _warn_for_conflicting_sync_targets(self) -> None:
if not self.support_sync:
return
@@ -2793,18 +3099,20 @@ class JoinCondition:
self._track_overlapping_sync_targets[to_][self.prop] = from_
@util.memoized_property
- def remote_columns(self):
+ def remote_columns(self) -> Set[ColumnElement[Any]]:
return self._gather_join_annotations("remote")
@util.memoized_property
- def local_columns(self):
+ def local_columns(self) -> Set[ColumnElement[Any]]:
return self._gather_join_annotations("local")
@util.memoized_property
- def foreign_key_columns(self):
+ def foreign_key_columns(self) -> Set[ColumnElement[Any]]:
return self._gather_join_annotations("foreign")
- def _gather_join_annotations(self, annotation):
+ def _gather_join_annotations(
+ self, annotation: str
+ ) -> Set[ColumnElement[Any]]:
s = set(
self._gather_columns_with_annotation(self.primaryjoin, annotation)
)
@@ -2816,24 +3124,32 @@ class JoinCondition:
)
return {x._deannotate() for x in s}
- def _gather_columns_with_annotation(self, clause, *annotation):
- annotation = set(annotation)
+ def _gather_columns_with_annotation(
+ self, clause: ColumnElement[Any], *annotation: Iterable[str]
+ ) -> Set[ColumnElement[Any]]:
+ annotation_set = set(annotation)
return set(
[
- col
+ cast(ColumnElement[Any], col)
for col in visitors.iterate(clause, {})
- if annotation.issubset(col._annotations)
+ if annotation_set.issubset(col._annotations)
]
)
def join_targets(
self,
- source_selectable,
- dest_selectable,
- aliased,
- single_crit=None,
- extra_criteria=(),
- ):
+ source_selectable: Optional[FromClause],
+ dest_selectable: FromClause,
+ aliased: bool,
+ single_crit: Optional[ColumnElement[bool]] = None,
+ extra_criteria: Tuple[ColumnElement[bool], ...] = (),
+ ) -> Tuple[
+ ColumnElement[bool],
+ Optional[ColumnElement[bool]],
+ Optional[FromClause],
+ Optional[ClauseAdapter],
+ FromClause,
+ ]:
"""Given a source and destination selectable, create a
join between them.
@@ -2923,9 +3239,15 @@ class JoinCondition:
dest_selectable,
)
- def create_lazy_clause(self, reverse_direction=False):
- binds = util.column_dict()
- equated_columns = util.column_dict()
+ def create_lazy_clause(
+ self, reverse_direction: bool = False
+ ) -> Tuple[
+ ColumnElement[bool],
+ Dict[str, ColumnElement[Any]],
+ Dict[ColumnElement[Any], ColumnElement[Any]],
+ ]:
+ binds: Dict[ColumnElement[Any], BindParameter[Any]] = {}
+ equated_columns: Dict[ColumnElement[Any], ColumnElement[Any]] = {}
has_secondary = self.secondaryjoin is not None
@@ -2941,21 +3263,23 @@ class JoinCondition:
for l, r in self.local_remote_pairs:
equated_columns[l] = r
- def col_to_bind(col):
+ def col_to_bind(
+ element: ColumnElement[Any], **kw: Any
+ ) -> Optional[BindParameter[Any]]:
if (
- (not reverse_direction and "local" in col._annotations)
+ (not reverse_direction and "local" in element._annotations)
or reverse_direction
and (
- (has_secondary and col in lookup)
- or (not has_secondary and "remote" in col._annotations)
+ (has_secondary and element in lookup)
+ or (not has_secondary and "remote" in element._annotations)
)
):
- if col not in binds:
- binds[col] = sql.bindparam(
- None, None, type_=col.type, unique=True
+ if element not in binds:
+ binds[element] = sql.bindparam(
+ None, None, type_=element.type, unique=True
)
- return binds[col]
+ return binds[element]
return None
lazywhere = self.primaryjoin
@@ -2982,8 +3306,8 @@ class _ColInAnnotations:
__slots__ = ("name",)
- def __init__(self, name):
+ def __init__(self, name: str):
self.name = name
- def __call__(self, c):
+ def __call__(self, c: ClauseElement) -> bool:
return self.name in c._annotations
diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py
index b5491248b..d72e78c9e 100644
--- a/lib/sqlalchemy/orm/session.py
+++ b/lib/sqlalchemy/orm/session.py
@@ -118,6 +118,7 @@ if typing.TYPE_CHECKING:
from ..sql._typing import _T7
from ..sql._typing import _TypedColumnClauseArgument as _TCCA
from ..sql.base import Executable
+ from ..sql.base import ExecutableOption
from ..sql.elements import ClauseElement
from ..sql.roles import TypedColumnsClauseRole
from ..sql.selectable import TypedReturnsRows
@@ -765,7 +766,7 @@ class SessionTransaction(_StateChange, TransactionalContext):
self.session.dispatch.after_transaction_create(self.session, self)
def _raise_for_prerequisite_state(
- self, operation_name: str, state: SessionTransactionState
+ self, operation_name: str, state: _StateChangeState
) -> NoReturn:
if state is SessionTransactionState.DEACTIVE:
if self._rollback_exception:
@@ -3183,7 +3184,7 @@ class Session(_SessionClassMethods, EventTarget):
primary_key_identity: _PKIdentityArgument,
db_load_fn: Callable[..., _O],
*,
- options: Optional[Sequence[ORMOption]] = None,
+ options: Optional[Sequence[ExecutableOption]] = None,
populate_existing: bool = False,
with_for_update: Optional[ForUpdateArg] = None,
identity_token: Optional[Any] = None,
@@ -3377,7 +3378,7 @@ class Session(_SessionClassMethods, EventTarget):
*,
options: Optional[Sequence[ORMOption]] = None,
load: bool,
- _recursive: Dict[InstanceState[Any], object],
+ _recursive: Dict[Any, object],
_resolve_conflict_map: Dict[_IdentityKeyType[Any], object],
) -> _O:
mapper: Mapper[_O] = _state_mapper(state)
diff --git a/lib/sqlalchemy/orm/state.py b/lib/sqlalchemy/orm/state.py
index cb8b1f4aa..af9f48706 100644
--- a/lib/sqlalchemy/orm/state.py
+++ b/lib/sqlalchemy/orm/state.py
@@ -82,6 +82,22 @@ class _InstanceDictProto(Protocol):
...
+class _InstallLoaderCallableProto(Protocol[_O]):
+ """used at result loading time to install a _LoaderCallable callable
+ upon a specific InstanceState, which will be used to populate an
+ attribute when that attribute is accessed.
+
+ Concrete examples are per-instance deferred column loaders and
+ relationship lazy loaders.
+
+ """
+
+ def __call__(
+ self, state: InstanceState[_O], dict_: _InstanceDict, row: Row[Any]
+ ) -> None:
+ ...
+
+
@inspection._self_inspects
class InstanceState(interfaces.InspectionAttrInfo, Generic[_O]):
"""tracks state information at the instance level.
@@ -658,7 +674,7 @@ class InstanceState(interfaces.InspectionAttrInfo, Generic[_O]):
@classmethod
def _instance_level_callable_processor(
cls, manager: ClassManager[_O], fn: _LoaderCallable, key: Any
- ) -> Callable[[InstanceState[_O], _InstanceDict, Row[Any]], None]:
+ ) -> _InstallLoaderCallableProto[_O]:
impl = manager[key].impl
if is_collection_impl(impl):
fixed_impl = impl
diff --git a/lib/sqlalchemy/orm/state_changes.py b/lib/sqlalchemy/orm/state_changes.py
index b7bf96558..764b5dfa6 100644
--- a/lib/sqlalchemy/orm/state_changes.py
+++ b/lib/sqlalchemy/orm/state_changes.py
@@ -4,6 +4,7 @@
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
"""State tracking utilities used by :class:`_orm.Session`.
"""
@@ -14,6 +15,9 @@ import contextlib
from enum import Enum
from typing import Any
from typing import Callable
+from typing import cast
+from typing import Iterator
+from typing import NoReturn
from typing import Optional
from typing import Tuple
from typing import TypeVar
@@ -48,9 +52,11 @@ class _StateChange:
_next_state: _StateChangeState = _StateChangeStates.ANY
_state: _StateChangeState = _StateChangeStates.NO_CHANGE
- _current_fn: Optional[Callable] = None
+ _current_fn: Optional[Callable[..., Any]] = None
- def _raise_for_prerequisite_state(self, operation_name, state):
+ def _raise_for_prerequisite_state(
+ self, operation_name: str, state: _StateChangeState
+ ) -> NoReturn:
raise sa_exc.IllegalStateChangeError(
f"Can't run operation '{operation_name}()' when Session "
f"is in state {state!r}"
@@ -80,16 +86,19 @@ class _StateChange:
prerequisite_states is not _StateChangeStates.ANY
)
+ prerequisite_state_collection = cast(
+ "Tuple[_StateChangeState, ...]", prerequisite_states
+ )
expect_state_change = moves_to is not _StateChangeStates.NO_CHANGE
@util.decorator
- def _go(fn, self, *arg, **kw):
+ def _go(fn: _F, self: Any, *arg: Any, **kw: Any) -> Any:
current_state = self._state
if (
has_prerequisite_states
- and current_state not in prerequisite_states
+ and current_state not in prerequisite_state_collection
):
self._raise_for_prerequisite_state(fn.__name__, current_state)
@@ -159,7 +168,7 @@ class _StateChange:
return _go
@contextlib.contextmanager
- def _expect_state(self, expected: _StateChangeState):
+ def _expect_state(self, expected: _StateChangeState) -> Iterator[Any]:
"""called within a method that changes states.
method must also use the ``@declare_states()`` decorator.
diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py
index 0ba22e7a7..5dc80e4f2 100644
--- a/lib/sqlalchemy/orm/strategies.py
+++ b/lib/sqlalchemy/orm/strategies.py
@@ -14,6 +14,10 @@ from __future__ import annotations
import collections
import itertools
+from typing import Any
+from typing import Dict
+from typing import Tuple
+from typing import TYPE_CHECKING
from . import attributes
from . import exc as orm_exc
@@ -28,7 +32,9 @@ from . import util as orm_util
from .base import _DEFER_FOR_STATE
from .base import _RAISE_FOR_STATE
from .base import _SET_DEFERRED_EXPIRED
+from .base import LoaderCallableStatus
from .base import PASSIVE_OFF
+from .base import PassiveFlag
from .context import _column_descriptions
from .context import ORMCompileState
from .context import ORMSelectCompileState
@@ -50,6 +56,10 @@ from ..sql import visitors
from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
from ..sql.selectable import Select
+if TYPE_CHECKING:
+ from .relationships import Relationship
+ from ..sql.elements import ColumnElement
+
def _register_attribute(
prop,
@@ -486,10 +496,10 @@ class DeferredColumnLoader(LoaderStrategy):
def _load_for_state(self, state, passive):
if not state.key:
- return attributes.ATTR_EMPTY
+ return LoaderCallableStatus.ATTR_EMPTY
- if not passive & attributes.SQL_OK:
- return attributes.PASSIVE_NO_RESULT
+ if not passive & PassiveFlag.SQL_OK:
+ return LoaderCallableStatus.PASSIVE_NO_RESULT
localparent = state.manager.mapper
@@ -522,7 +532,7 @@ class DeferredColumnLoader(LoaderStrategy):
state.mapper, state, set(group), PASSIVE_OFF
)
- return attributes.ATTR_WAS_SET
+ return LoaderCallableStatus.ATTR_WAS_SET
def _invoke_raise_load(self, state, passive, lazy):
raise sa_exc.InvalidRequestError(
@@ -626,7 +636,9 @@ class NoLoader(AbstractRelationshipLoader):
@relationships.Relationship.strategy_for(lazy="raise")
@relationships.Relationship.strategy_for(lazy="raise_on_sql")
@relationships.Relationship.strategy_for(lazy="baked_select")
-class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots):
+class LazyLoader(
+ AbstractRelationshipLoader, util.MemoizedSlots, log.Identified
+):
"""Provide loading behavior for a :class:`.Relationship`
with "lazy=True", that is loads when first accessed.
@@ -648,7 +660,16 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots):
"_raise_on_sql",
)
- def __init__(self, parent, strategy_key):
+ _lazywhere: ColumnElement[bool]
+ _bind_to_col: Dict[str, ColumnElement[Any]]
+ _rev_lazywhere: ColumnElement[bool]
+ _rev_bind_to_col: Dict[str, ColumnElement[Any]]
+
+ parent_property: Relationship[Any]
+
+ def __init__(
+ self, parent: Relationship[Any], strategy_key: Tuple[Any, ...]
+ ):
super(LazyLoader, self).__init__(parent, strategy_key)
self._raise_always = self.strategy_opts["lazy"] == "raise"
self._raise_on_sql = self.strategy_opts["lazy"] == "raise_on_sql"
@@ -786,13 +807,13 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots):
o = state.obj() # strong ref
dict_ = attributes.instance_dict(o)
- if passive & attributes.INIT_OK:
- passive ^= attributes.INIT_OK
+ if passive & PassiveFlag.INIT_OK:
+ passive ^= PassiveFlag.INIT_OK
params = {}
for key, ident, value in param_keys:
if ident is not None:
- if passive and passive & attributes.LOAD_AGAINST_COMMITTED:
+ if passive and passive & PassiveFlag.LOAD_AGAINST_COMMITTED:
value = mapper._get_committed_state_attr_by_column(
state, dict_, ident, passive
)
@@ -818,23 +839,23 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots):
)
or not state.session_id
):
- return attributes.ATTR_EMPTY
+ return LoaderCallableStatus.ATTR_EMPTY
pending = not state.key
primary_key_identity = None
use_get = self.use_get and (not loadopt or not loadopt._extra_criteria)
- if (not passive & attributes.SQL_OK and not use_get) or (
+ if (not passive & PassiveFlag.SQL_OK and not use_get) or (
not passive & attributes.NON_PERSISTENT_OK and pending
):
- return attributes.PASSIVE_NO_RESULT
+ return LoaderCallableStatus.PASSIVE_NO_RESULT
if (
# we were given lazy="raise"
self._raise_always
# the no_raise history-related flag was not passed
- and not passive & attributes.NO_RAISE
+ and not passive & PassiveFlag.NO_RAISE
and (
# if we are use_get and related_object_ok is disabled,
# which means we are at most looking in the identity map
@@ -842,7 +863,7 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots):
# PASSIVE_NO_RESULT, don't raise. This is also a
# history-related flag
not use_get
- or passive & attributes.RELATED_OBJECT_OK
+ or passive & PassiveFlag.RELATED_OBJECT_OK
)
):
@@ -850,8 +871,8 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots):
session = _state_session(state)
if not session:
- if passive & attributes.NO_RAISE:
- return attributes.PASSIVE_NO_RESULT
+ if passive & PassiveFlag.NO_RAISE:
+ return LoaderCallableStatus.PASSIVE_NO_RESULT
raise orm_exc.DetachedInstanceError(
"Parent instance %s is not bound to a Session; "
@@ -865,19 +886,19 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots):
primary_key_identity = self._get_ident_for_use_get(
session, state, passive
)
- if attributes.PASSIVE_NO_RESULT in primary_key_identity:
- return attributes.PASSIVE_NO_RESULT
- elif attributes.NEVER_SET in primary_key_identity:
- return attributes.NEVER_SET
+ if LoaderCallableStatus.PASSIVE_NO_RESULT in primary_key_identity:
+ return LoaderCallableStatus.PASSIVE_NO_RESULT
+ elif LoaderCallableStatus.NEVER_SET in primary_key_identity:
+ return LoaderCallableStatus.NEVER_SET
if _none_set.issuperset(primary_key_identity):
return None
if (
self.key in state.dict
- and not passive & attributes.DEFERRED_HISTORY_LOAD
+ and not passive & PassiveFlag.DEFERRED_HISTORY_LOAD
):
- return attributes.ATTR_WAS_SET
+ return LoaderCallableStatus.ATTR_WAS_SET
# look for this identity in the identity map. Delegate to the
# Query class in use, as it may have special rules for how it
@@ -892,15 +913,15 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots):
)
if instance is not None:
- if instance is attributes.PASSIVE_CLASS_MISMATCH:
+ if instance is LoaderCallableStatus.PASSIVE_CLASS_MISMATCH:
return None
else:
return instance
elif (
- not passive & attributes.SQL_OK
- or not passive & attributes.RELATED_OBJECT_OK
+ not passive & PassiveFlag.SQL_OK
+ or not passive & PassiveFlag.RELATED_OBJECT_OK
):
- return attributes.PASSIVE_NO_RESULT
+ return LoaderCallableStatus.PASSIVE_NO_RESULT
return self._emit_lazyload(
session,
@@ -914,7 +935,7 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots):
def _get_ident_for_use_get(self, session, state, passive):
instance_mapper = state.manager.mapper
- if passive & attributes.LOAD_AGAINST_COMMITTED:
+ if passive & PassiveFlag.LOAD_AGAINST_COMMITTED:
get_attr = instance_mapper._get_committed_state_attr_by_column
else:
get_attr = instance_mapper._get_state_attr_by_column
@@ -985,7 +1006,7 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots):
stmt._compile_options += {"_current_path": effective_path}
if use_get:
- if self._raise_on_sql and not passive & attributes.NO_RAISE:
+ if self._raise_on_sql and not passive & PassiveFlag.NO_RAISE:
self._invoke_raise_load(state, passive, "raise_on_sql")
return loading.load_on_pk_identity(
@@ -1022,9 +1043,9 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots):
if (
self.key in state.dict
- and not passive & attributes.DEFERRED_HISTORY_LOAD
+ and not passive & PassiveFlag.DEFERRED_HISTORY_LOAD
):
- return attributes.ATTR_WAS_SET
+ return LoaderCallableStatus.ATTR_WAS_SET
if pending:
if util.has_intersection(orm_util._none_set, params.values()):
@@ -1033,7 +1054,7 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots):
elif util.has_intersection(orm_util._never_set, params.values()):
return None
- if self._raise_on_sql and not passive & attributes.NO_RAISE:
+ if self._raise_on_sql and not passive & PassiveFlag.NO_RAISE:
self._invoke_raise_load(state, passive, "raise_on_sql")
stmt._where_criteria = (lazy_clause,)
@@ -1246,9 +1267,9 @@ class ImmediateLoader(PostLoader):
# "use get" load. the "_RELATED" part means it may return
# instance even if its expired, since this is a mutually-recursive
# load operation.
- flags = attributes.PASSIVE_NO_FETCH_RELATED | attributes.NO_RAISE
+ flags = attributes.PASSIVE_NO_FETCH_RELATED | PassiveFlag.NO_RAISE
else:
- flags = attributes.PASSIVE_OFF | attributes.NO_RAISE
+ flags = attributes.PASSIVE_OFF | PassiveFlag.NO_RAISE
populators["delayed"].append((self.key, load_immediate))
@@ -2840,7 +2861,7 @@ class SelectInLoader(PostLoader, util.MemoizedSlots):
# if the loaded parent objects do not have the foreign key
# to the related item loaded, then degrade into the joined
# version of selectinload
- if attributes.PASSIVE_NO_RESULT in related_ident:
+ if LoaderCallableStatus.PASSIVE_NO_RESULT in related_ident:
query_info = self._fallback_query_info
break
diff --git a/lib/sqlalchemy/orm/strategy_options.py b/lib/sqlalchemy/orm/strategy_options.py
index 63679dd27..7aed6dd7b 100644
--- a/lib/sqlalchemy/orm/strategy_options.py
+++ b/lib/sqlalchemy/orm/strategy_options.py
@@ -3,6 +3,7 @@
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
+# mypy: allow-untyped-defs, allow-untyped-calls
"""
@@ -12,18 +13,30 @@ from __future__ import annotations
import typing
from typing import Any
+from typing import Callable
from typing import cast
-from typing import Mapping
-from typing import NoReturn
+from typing import Dict
+from typing import Iterable
from typing import Optional
+from typing import overload
+from typing import Sequence
from typing import Tuple
+from typing import Type
+from typing import TypeVar
from typing import Union
from . import util as orm_util
+from ._typing import insp_is_aliased_class
+from ._typing import insp_is_attribute
+from ._typing import insp_is_mapper
+from ._typing import insp_is_mapper_property
+from .attributes import QueryableAttribute
from .base import InspectionAttr
from .interfaces import LoaderOption
from .path_registry import _DEFAULT_TOKEN
from .path_registry import _WILDCARD_TOKEN
+from .path_registry import AbstractEntityRegistry
+from .path_registry import path_is_property
from .path_registry import PathRegistry
from .path_registry import TokenRegistry
from .util import _orm_full_deannotate
@@ -38,14 +51,37 @@ from ..sql import roles
from ..sql import traversals
from ..sql import visitors
from ..sql.base import _generative
+from ..util.typing import Final
+from ..util.typing import Literal
-_RELATIONSHIP_TOKEN = "relationship"
-_COLUMN_TOKEN = "column"
+_RELATIONSHIP_TOKEN: Final[Literal["relationship"]] = "relationship"
+_COLUMN_TOKEN: Final[Literal["column"]] = "column"
+
+_FN = TypeVar("_FN", bound="Callable[..., Any]")
if typing.TYPE_CHECKING:
+ from ._typing import _EntityType
+ from ._typing import _InternalEntityType
+ from .context import _MapperEntity
+ from .context import ORMCompileState
+ from .context import QueryContext
+ from .interfaces import _StrategyKey
+ from .interfaces import MapperProperty
from .mapper import Mapper
+ from .path_registry import _PathRepresentation
+ from ..sql._typing import _ColumnExpressionArgument
+ from ..sql._typing import _FromClauseArgument
+ from ..sql.cache_key import _CacheKeyTraversalType
+ from ..sql.cache_key import CacheKey
+
+Self_AbstractLoad = TypeVar("Self_AbstractLoad", bound="_AbstractLoad")
+
+_AttrType = Union[str, "QueryableAttribute[Any]"]
-Self_AbstractLoad = typing.TypeVar("Self_AbstractLoad", bound="_AbstractLoad")
+_WildcardKeyType = Literal["relationship", "column"]
+_StrategySpec = Dict[str, Any]
+_OptsType = Dict[str, Any]
+_AttrGroupType = Tuple[_AttrType, ...]
class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption):
@@ -54,7 +90,12 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption):
_is_strategy_option = True
propagate_to_loaders: bool
- def contains_eager(self, attr, alias=None, _is_chain=False):
+ def contains_eager(
+ self: Self_AbstractLoad,
+ attr: _AttrType,
+ alias: Optional[_FromClauseArgument] = None,
+ _is_chain: bool = False,
+ ) -> Self_AbstractLoad:
r"""Indicate that the given attribute should be eagerly loaded from
columns stated manually in the query.
@@ -94,9 +135,7 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption):
"""
if alias is not None:
if not isinstance(alias, str):
- info = inspect(alias)
- alias = info.selectable
-
+ coerced_alias = coercions.expect(roles.FromClauseRole, alias)
else:
util.warn_deprecated(
"Passing a string name for the 'alias' argument to "
@@ -105,21 +144,28 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption):
"sqlalchemy.orm.aliased() construct.",
version="1.4",
)
+ coerced_alias = alias
elif getattr(attr, "_of_type", None):
- ot = inspect(attr._of_type)
- alias = ot.selectable
+ assert isinstance(attr, QueryableAttribute)
+ ot: Optional[_InternalEntityType[Any]] = inspect(attr._of_type)
+ assert ot is not None
+ coerced_alias = ot.selectable
+ else:
+ coerced_alias = None
cloned = self._set_relationship_strategy(
attr,
{"lazy": "joined"},
propagate_to_loaders=False,
- opts={"eager_from_alias": alias},
+ opts={"eager_from_alias": coerced_alias},
_reconcile_to_other=True if _is_chain else None,
)
return cloned
- def load_only(self, *attrs):
+ def load_only(
+ self: Self_AbstractLoad, *attrs: _AttrType
+ ) -> Self_AbstractLoad:
"""Indicate that for a particular entity, only the given list
of column-based attribute names should be loaded; all others will be
deferred.
@@ -159,11 +205,17 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption):
{"deferred": False, "instrument": True},
)
cloned = cloned._set_column_strategy(
- "*", {"deferred": True, "instrument": True}, {"undefer_pks": True}
+ ("*",),
+ {"deferred": True, "instrument": True},
+ {"undefer_pks": True},
)
return cloned
- def joinedload(self, attr, innerjoin=None):
+ def joinedload(
+ self: Self_AbstractLoad,
+ attr: _AttrType,
+ innerjoin: Optional[bool] = None,
+ ) -> Self_AbstractLoad:
"""Indicate that the given attribute should be loaded using joined
eager loading.
@@ -258,7 +310,9 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption):
)
return loader
- def subqueryload(self, attr):
+ def subqueryload(
+ self: Self_AbstractLoad, attr: _AttrType
+ ) -> Self_AbstractLoad:
"""Indicate that the given attribute should be loaded using
subquery eager loading.
@@ -289,7 +343,9 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption):
"""
return self._set_relationship_strategy(attr, {"lazy": "subquery"})
- def selectinload(self, attr):
+ def selectinload(
+ self: Self_AbstractLoad, attr: _AttrType
+ ) -> Self_AbstractLoad:
"""Indicate that the given attribute should be loaded using
SELECT IN eager loading.
@@ -321,7 +377,9 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption):
"""
return self._set_relationship_strategy(attr, {"lazy": "selectin"})
- def lazyload(self, attr):
+ def lazyload(
+ self: Self_AbstractLoad, attr: _AttrType
+ ) -> Self_AbstractLoad:
"""Indicate that the given attribute should be loaded using "lazy"
loading.
@@ -337,7 +395,9 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption):
"""
return self._set_relationship_strategy(attr, {"lazy": "select"})
- def immediateload(self, attr):
+ def immediateload(
+ self: Self_AbstractLoad, attr: _AttrType
+ ) -> Self_AbstractLoad:
"""Indicate that the given attribute should be loaded using
an immediate load with a per-attribute SELECT statement.
@@ -361,7 +421,7 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption):
loader = self._set_relationship_strategy(attr, {"lazy": "immediate"})
return loader
- def noload(self, attr):
+ def noload(self: Self_AbstractLoad, attr: _AttrType) -> Self_AbstractLoad:
"""Indicate that the given relationship attribute should remain
unloaded.
@@ -387,7 +447,9 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption):
return self._set_relationship_strategy(attr, {"lazy": "noload"})
- def raiseload(self, attr, sql_only=False):
+ def raiseload(
+ self: Self_AbstractLoad, attr: _AttrType, sql_only: bool = False
+ ) -> Self_AbstractLoad:
"""Indicate that the given attribute should raise an error if accessed.
A relationship attribute configured with :func:`_orm.raiseload` will
@@ -428,7 +490,9 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption):
attr, {"lazy": "raise_on_sql" if sql_only else "raise"}
)
- def defaultload(self, attr):
+ def defaultload(
+ self: Self_AbstractLoad, attr: _AttrType
+ ) -> Self_AbstractLoad:
"""Indicate an attribute should load using its default loader style.
This method is used to link to other loader options further into
@@ -463,7 +527,9 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption):
"""
return self._set_relationship_strategy(attr, None)
- def defer(self, key, raiseload=False):
+ def defer(
+ self: Self_AbstractLoad, key: _AttrType, raiseload: bool = False
+ ) -> Self_AbstractLoad:
r"""Indicate that the given column-oriented attribute should be
deferred, e.g. not loaded until accessed.
@@ -524,7 +590,7 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption):
strategy["raiseload"] = True
return self._set_column_strategy((key,), strategy)
- def undefer(self, key):
+ def undefer(self: Self_AbstractLoad, key: _AttrType) -> Self_AbstractLoad:
r"""Indicate that the given column-oriented attribute should be
undeferred, e.g. specified within the SELECT statement of the entity
as a whole.
@@ -538,7 +604,9 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption):
Examples::
# undefer two columns
- session.query(MyClass).options(undefer("col1"), undefer("col2"))
+ session.query(MyClass).options(
+ undefer(MyClass.col1), undefer(MyClass.col2)
+ )
# undefer all columns specific to a single class using Load + *
session.query(MyClass, MyOtherClass).options(
@@ -546,7 +614,7 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption):
# undefer a column on a related object
session.query(MyClass).options(
- defaultload(MyClass.items).undefer('text'))
+ defaultload(MyClass.items).undefer(MyClass.text))
:param key: Attribute to be undeferred.
@@ -563,7 +631,7 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption):
(key,), {"deferred": False, "instrument": True}
)
- def undefer_group(self, name):
+ def undefer_group(self: Self_AbstractLoad, name: str) -> Self_AbstractLoad:
"""Indicate that columns within the given deferred group name should be
undeferred.
@@ -591,10 +659,14 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption):
"""
return self._set_column_strategy(
- _WILDCARD_TOKEN, None, {f"undefer_group_{name}": True}
+ (_WILDCARD_TOKEN,), None, {f"undefer_group_{name}": True}
)
- def with_expression(self, key, expression):
+ def with_expression(
+ self: Self_AbstractLoad,
+ key: _AttrType,
+ expression: _ColumnExpressionArgument[Any],
+ ) -> Self_AbstractLoad:
r"""Apply an ad-hoc SQL expression to a "deferred expression"
attribute.
@@ -626,15 +698,17 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption):
"""
- expression = coercions.expect(
- roles.LabeledColumnExprRole, _orm_full_deannotate(expression)
+ expression = _orm_full_deannotate(
+ coercions.expect(roles.LabeledColumnExprRole, expression)
)
return self._set_column_strategy(
(key,), {"query_expression": True}, opts={"expression": expression}
)
- def selectin_polymorphic(self, classes):
+ def selectin_polymorphic(
+ self: Self_AbstractLoad, classes: Iterable[Type[Any]]
+ ) -> Self_AbstractLoad:
"""Indicate an eager load should take place for all attributes
specific to a subclass.
@@ -659,25 +733,37 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption):
)
return self
- def _coerce_strat(self, strategy):
+ @overload
+ def _coerce_strat(self, strategy: _StrategySpec) -> _StrategyKey:
+ ...
+
+ @overload
+ def _coerce_strat(self, strategy: Literal[None]) -> None:
+ ...
+
+ def _coerce_strat(
+ self, strategy: Optional[_StrategySpec]
+ ) -> Optional[_StrategyKey]:
if strategy is not None:
- strategy = tuple(sorted(strategy.items()))
- return strategy
+ strategy_key = tuple(sorted(strategy.items()))
+ else:
+ strategy_key = None
+ return strategy_key
@_generative
def _set_relationship_strategy(
self: Self_AbstractLoad,
- attr,
- strategy,
- propagate_to_loaders=True,
- opts=None,
- _reconcile_to_other=None,
+ attr: _AttrType,
+ strategy: Optional[_StrategySpec],
+ propagate_to_loaders: bool = True,
+ opts: Optional[_OptsType] = None,
+ _reconcile_to_other: Optional[bool] = None,
) -> Self_AbstractLoad:
- strategy = self._coerce_strat(strategy)
+ strategy_key = self._coerce_strat(strategy)
self._clone_for_bind_strategy(
(attr,),
- strategy,
+ strategy_key,
_RELATIONSHIP_TOKEN,
opts=opts,
propagate_to_loaders=propagate_to_loaders,
@@ -687,13 +773,16 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption):
@_generative
def _set_column_strategy(
- self: Self_AbstractLoad, attrs, strategy, opts=None
+ self: Self_AbstractLoad,
+ attrs: Tuple[_AttrType, ...],
+ strategy: Optional[_StrategySpec],
+ opts: Optional[_OptsType] = None,
) -> Self_AbstractLoad:
- strategy = self._coerce_strat(strategy)
+ strategy_key = self._coerce_strat(strategy)
self._clone_for_bind_strategy(
attrs,
- strategy,
+ strategy_key,
_COLUMN_TOKEN,
opts=opts,
attr_group=attrs,
@@ -702,12 +791,15 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption):
@_generative
def _set_generic_strategy(
- self: Self_AbstractLoad, attrs, strategy, _reconcile_to_other=None
+ self: Self_AbstractLoad,
+ attrs: Tuple[_AttrType, ...],
+ strategy: _StrategySpec,
+ _reconcile_to_other: Optional[bool] = None,
) -> Self_AbstractLoad:
- strategy = self._coerce_strat(strategy)
+ strategy_key = self._coerce_strat(strategy)
self._clone_for_bind_strategy(
attrs,
- strategy,
+ strategy_key,
None,
propagate_to_loaders=True,
reconcile_to_other=_reconcile_to_other,
@@ -716,14 +808,14 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption):
@_generative
def _set_class_strategy(
- self: Self_AbstractLoad, strategy, opts
+ self: Self_AbstractLoad, strategy: _StrategySpec, opts: _OptsType
) -> Self_AbstractLoad:
- strategy = self._coerce_strat(strategy)
+ strategy_key = self._coerce_strat(strategy)
- self._clone_for_bind_strategy(None, strategy, None, opts=opts)
+ self._clone_for_bind_strategy(None, strategy_key, None, opts=opts)
return self
- def _apply_to_parent(self, parent):
+ def _apply_to_parent(self, parent: Load) -> None:
"""apply this :class:`_orm._AbstractLoad` object as a sub-option o
a :class:`_orm.Load` object.
@@ -732,7 +824,9 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption):
"""
raise NotImplementedError()
- def options(self: Self_AbstractLoad, *opts) -> NoReturn:
+ def options(
+ self: Self_AbstractLoad, *opts: _AbstractLoad
+ ) -> Self_AbstractLoad:
r"""Apply a series of options as sub-options to this
:class:`_orm._AbstractLoad` object.
@@ -742,20 +836,22 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption):
raise NotImplementedError()
def _clone_for_bind_strategy(
- self,
- attrs,
- strategy,
- wildcard_key,
- opts=None,
- attr_group=None,
- propagate_to_loaders=True,
- reconcile_to_other=None,
- ):
+ self: Self_AbstractLoad,
+ attrs: Optional[Tuple[_AttrType, ...]],
+ strategy: Optional[_StrategyKey],
+ wildcard_key: Optional[_WildcardKeyType],
+ opts: Optional[_OptsType] = None,
+ attr_group: Optional[_AttrGroupType] = None,
+ propagate_to_loaders: bool = True,
+ reconcile_to_other: Optional[bool] = None,
+ ) -> Self_AbstractLoad:
raise NotImplementedError()
def process_compile_state_replaced_entities(
- self, compile_state, mapper_entities
- ):
+ self,
+ compile_state: ORMCompileState,
+ mapper_entities: Sequence[_MapperEntity],
+ ) -> None:
if not compile_state.compile_options._enable_eagerloads:
return
@@ -768,7 +864,7 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption):
not bool(compile_state.current_path),
)
- def process_compile_state(self, compile_state):
+ def process_compile_state(self, compile_state: ORMCompileState) -> None:
if not compile_state.compile_options._enable_eagerloads:
return
@@ -779,12 +875,22 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption):
and not compile_state.compile_options._for_refresh_state,
)
- def _process(self, compile_state, mapper_entities, raiseerr):
+ def _process(
+ self,
+ compile_state: ORMCompileState,
+ mapper_entities: Sequence[_MapperEntity],
+ raiseerr: bool,
+ ) -> None:
"""implemented by subclasses"""
raise NotImplementedError()
@classmethod
- def _chop_path(cls, to_chop, path, debug=False):
+ def _chop_path(
+ cls,
+ to_chop: _PathRepresentation,
+ path: PathRegistry,
+ debug: bool = False,
+ ) -> Optional[_PathRepresentation]:
i = -1
for i, (c_token, p_token) in enumerate(zip(to_chop, path.path)):
@@ -793,7 +899,7 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption):
return to_chop
elif (
c_token != f"{_RELATIONSHIP_TOKEN}:{_WILDCARD_TOKEN}"
- and c_token != p_token.key
+ and c_token != p_token.key # type: ignore
):
return None
@@ -801,9 +907,9 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption):
continue
elif (
isinstance(c_token, InspectionAttr)
- and c_token.is_mapper
+ and insp_is_mapper(c_token)
and (
- (p_token.is_mapper and c_token.isa(p_token))
+ (insp_is_mapper(p_token) and c_token.isa(p_token))
or (
# a too-liberal check here to allow a path like
# A->A.bs->B->B.cs->C->C.ds, natural path, to chop
@@ -827,10 +933,9 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption):
# test_of_type.py->test_all_subq_query
#
i >= 2
- and p_token.is_aliased_class
+ and insp_is_aliased_class(p_token)
and p_token._is_with_polymorphic
and c_token in p_token.with_polymorphic_mappers
- # and (breakpoint() or True)
)
)
):
@@ -841,7 +946,7 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption):
return to_chop[i + 1 :]
-SelfLoad = typing.TypeVar("SelfLoad", bound="Load")
+SelfLoad = TypeVar("SelfLoad", bound="Load")
class Load(_AbstractLoad):
@@ -903,28 +1008,28 @@ class Load(_AbstractLoad):
_cache_key_traversal = None
path: PathRegistry
- context: Tuple["_LoadElement", ...]
+ context: Tuple[_LoadElement, ...]
- def __init__(self, entity):
- insp = cast(Union["Mapper", AliasedInsp], inspect(entity))
+ def __init__(self, entity: _EntityType[Any]):
+ insp = cast("Union[Mapper[Any], AliasedInsp[Any]]", inspect(entity))
insp._post_inspect
self.path = insp._path_registry
self.context = ()
self.propagate_to_loaders = False
- def __str__(self):
+ def __str__(self) -> str:
return f"Load({self.path[0]})"
@classmethod
- def _construct_for_existing_path(cls, path):
+ def _construct_for_existing_path(cls, path: PathRegistry) -> Load:
load = cls.__new__(cls)
load.path = path
load.context = ()
load.propagate_to_loaders = False
return load
- def _adjust_for_extra_criteria(self, context):
+ def _adjust_for_extra_criteria(self, context: QueryContext) -> Load:
"""Apply the current bound parameters in a QueryContext to all
occurrences "extra_criteria" stored within this ``Load`` object,
returning a new instance of this ``Load`` object.
@@ -932,10 +1037,10 @@ class Load(_AbstractLoad):
"""
orig_query = context.compile_state.select_statement
- orig_cache_key = None
- replacement_cache_key = None
+ orig_cache_key: Optional[CacheKey] = None
+ replacement_cache_key: Optional[CacheKey] = None
- def process(opt):
+ def process(opt: _LoadElement) -> _LoadElement:
if not opt._extra_criteria:
return opt
@@ -948,6 +1053,9 @@ class Load(_AbstractLoad):
orig_cache_key = orig_query._generate_cache_key()
replacement_cache_key = context.query._generate_cache_key()
+ assert orig_cache_key is not None
+ assert replacement_cache_key is not None
+
opt._extra_criteria = tuple(
replacement_cache_key._apply_params_to_element(
orig_cache_key, crit
@@ -975,12 +1083,22 @@ class Load(_AbstractLoad):
ezero = None
for ent in mapper_entities:
ezero = ent.entity_zero
- if ezero and orm_util._entity_corresponds_to(ezero, path[0]):
+ if ezero and orm_util._entity_corresponds_to(
+ # technically this can be a token also, but this is
+ # safe to pass to _entity_corresponds_to()
+ ezero,
+ cast("_InternalEntityType[Any]", path[0]),
+ ):
return ezero
return None
- def _process(self, compile_state, mapper_entities, raiseerr):
+ def _process(
+ self,
+ compile_state: ORMCompileState,
+ mapper_entities: Sequence[_MapperEntity],
+ raiseerr: bool,
+ ) -> None:
reconciled_lead_entity = self._reconcile_query_entities_with_us(
mapper_entities, raiseerr
@@ -995,7 +1113,7 @@ class Load(_AbstractLoad):
raiseerr,
)
- def _apply_to_parent(self, parent):
+ def _apply_to_parent(self, parent: Load) -> None:
"""apply this :class:`_orm.Load` object as a sub-option of another
:class:`_orm.Load` object.
@@ -1007,7 +1125,8 @@ class Load(_AbstractLoad):
assert cloned.propagate_to_loaders == self.propagate_to_loaders
if not orm_util._entity_corresponds_to_use_path_impl(
- parent.path[-1], cloned.path[0]
+ cast("_InternalEntityType[Any]", parent.path[-1]),
+ cast("_InternalEntityType[Any]", cloned.path[0]),
):
raise sa_exc.ArgumentError(
f'Attribute "{cloned.path[1]}" does not link '
@@ -1025,7 +1144,7 @@ class Load(_AbstractLoad):
parent.context += cloned.context
@_generative
- def options(self: SelfLoad, *opts) -> SelfLoad:
+ def options(self: SelfLoad, *opts: _AbstractLoad) -> SelfLoad:
r"""Apply a series of options as sub-options to this
:class:`_orm.Load`
object.
@@ -1062,38 +1181,36 @@ class Load(_AbstractLoad):
return self
def _clone_for_bind_strategy(
- self,
- attrs,
- strategy,
- wildcard_key,
- opts=None,
- attr_group=None,
- propagate_to_loaders=True,
- reconcile_to_other=None,
- ) -> None:
+ self: SelfLoad,
+ attrs: Optional[Tuple[_AttrType, ...]],
+ strategy: Optional[_StrategyKey],
+ wildcard_key: Optional[_WildcardKeyType],
+ opts: Optional[_OptsType] = None,
+ attr_group: Optional[_AttrGroupType] = None,
+ propagate_to_loaders: bool = True,
+ reconcile_to_other: Optional[bool] = None,
+ ) -> SelfLoad:
# for individual strategy that needs to propagate, set the whole
# Load container to also propagate, so that it shows up in
# InstanceState.load_options
if propagate_to_loaders:
self.propagate_to_loaders = True
- if not self.path.has_entity:
- if self.path.is_token:
+ if self.path.is_token:
+ raise sa_exc.ArgumentError(
+ "Wildcard token cannot be followed by another entity"
+ )
+
+ elif path_is_property(self.path):
+ # re-use the lookup which will raise a nicely formatted
+ # LoaderStrategyException
+ if strategy:
+ self.path.prop._strategy_lookup(self.path.prop, strategy[0])
+ else:
raise sa_exc.ArgumentError(
- "Wildcard token cannot be followed by another entity"
+ f"Mapped attribute '{self.path.prop}' does not "
+ "refer to a mapped entity"
)
- else:
- # re-use the lookup which will raise a nicely formatted
- # LoaderStrategyException
- if strategy:
- self.path.prop._strategy_lookup(
- self.path.prop, strategy[0]
- )
- else:
- raise sa_exc.ArgumentError(
- f"Mapped attribute '{self.path.prop}' does not "
- "refer to a mapped entity"
- )
if attrs is None:
load_element = _ClassStrategyLoad.create(
@@ -1140,6 +1257,7 @@ class Load(_AbstractLoad):
if wildcard_key is _RELATIONSHIP_TOKEN:
self.path = load_element.path
self.context += (load_element,)
+ return self
def __getstate__(self):
d = self._shallow_to_dict()
@@ -1151,7 +1269,7 @@ class Load(_AbstractLoad):
self._shallow_from_dict(state)
-SelfWildcardLoad = typing.TypeVar("SelfWildcardLoad", bound="_WildcardLoad")
+SelfWildcardLoad = TypeVar("SelfWildcardLoad", bound="_WildcardLoad")
class _WildcardLoad(_AbstractLoad):
@@ -1167,14 +1285,14 @@ class _WildcardLoad(_AbstractLoad):
visitors.ExtendedInternalTraversal.dp_string_multi_dict,
),
]
- cache_key_traversal = None
+ cache_key_traversal: _CacheKeyTraversalType = None
strategy: Optional[Tuple[Any, ...]]
- local_opts: Mapping[str, Any]
+ local_opts: _OptsType
path: Tuple[str, ...]
propagate_to_loaders = False
- def __init__(self):
+ def __init__(self) -> None:
self.path = ()
self.strategy = None
self.local_opts = util.EMPTY_DICT
@@ -1189,6 +1307,7 @@ class _WildcardLoad(_AbstractLoad):
propagate_to_loaders=True,
reconcile_to_other=None,
):
+ assert attrs is not None
attr = attrs[0]
assert (
wildcard_key
@@ -1203,10 +1322,12 @@ class _WildcardLoad(_AbstractLoad):
if opts:
self.local_opts = util.immutabledict(opts)
- def options(self: SelfWildcardLoad, *opts) -> SelfWildcardLoad:
+ def options(
+ self: SelfWildcardLoad, *opts: _AbstractLoad
+ ) -> SelfWildcardLoad:
raise NotImplementedError("Star option does not support sub-options")
- def _apply_to_parent(self, parent):
+ def _apply_to_parent(self, parent: Load) -> None:
"""apply this :class:`_orm._WildcardLoad` object as a sub-option of
a :class:`_orm.Load` object.
@@ -1215,12 +1336,11 @@ class _WildcardLoad(_AbstractLoad):
it may be used as the sub-option of a :class:`_orm.Load` object.
"""
-
attr = self.path[0]
if attr.endswith(_DEFAULT_TOKEN):
attr = f"{attr.split(':')[0]}:{_WILDCARD_TOKEN}"
- effective_path = parent.path.token(attr)
+ effective_path = cast(AbstractEntityRegistry, parent.path).token(attr)
assert effective_path.is_token
@@ -1244,20 +1364,21 @@ class _WildcardLoad(_AbstractLoad):
entities = [ent.entity_zero for ent in mapper_entities]
current_path = compile_state.current_path
- start_path = self.path
+ start_path: _PathRepresentation = self.path
# TODO: chop_path already occurs in loader.process_compile_state()
# so we will seek to simplify this
if current_path:
- start_path = self._chop_path(start_path, current_path)
- if not start_path:
+ new_path = self._chop_path(start_path, current_path)
+ if not new_path:
return
+ start_path = new_path
# start_path is a single-token tuple
assert start_path and len(start_path) == 1
token = start_path[0]
-
+ assert isinstance(token, str)
entity = self._find_entity_basestring(entities, token, raiseerr)
if not entity:
@@ -1270,6 +1391,7 @@ class _WildcardLoad(_AbstractLoad):
# we just located, then go through the rest of our path
# tokens and populate into the Load().
+ assert isinstance(token, str)
loader = _TokenStrategyLoad.create(
path_element._path_registry,
token,
@@ -1291,7 +1413,12 @@ class _WildcardLoad(_AbstractLoad):
return loader
- def _find_entity_basestring(self, entities, token, raiseerr):
+ def _find_entity_basestring(
+ self,
+ entities: Iterable[_InternalEntityType[Any]],
+ token: str,
+ raiseerr: bool,
+ ) -> Optional[_InternalEntityType[Any]]:
if token.endswith(f":{_WILDCARD_TOKEN}"):
if len(list(entities)) != 1:
if raiseerr:
@@ -1324,11 +1451,11 @@ class _WildcardLoad(_AbstractLoad):
else:
return None
- def __getstate__(self):
+ def __getstate__(self) -> Dict[str, Any]:
d = self._shallow_to_dict()
return d
- def __setstate__(self, state):
+ def __setstate__(self, state: Dict[str, Any]) -> None:
self._shallow_from_dict(state)
@@ -1372,38 +1499,38 @@ class _LoadElement(
_extra_criteria: Tuple[Any, ...]
_reconcile_to_other: Optional[bool]
- strategy: Tuple[Any, ...]
+ strategy: Optional[_StrategyKey]
path: PathRegistry
propagate_to_loaders: bool
- local_opts: Mapping[str, Any]
+ local_opts: util.immutabledict[str, Any]
is_token_strategy: bool
is_class_strategy: bool
- def __hash__(self):
+ def __hash__(self) -> int:
return id(self)
def __eq__(self, other):
return traversals.compare(self, other)
@property
- def is_opts_only(self):
+ def is_opts_only(self) -> bool:
return bool(self.local_opts and self.strategy is None)
- def _clone(self):
+ def _clone(self, **kw: Any) -> _LoadElement:
cls = self.__class__
s = cls.__new__(cls)
self._shallow_copy_to(s)
return s
- def __getstate__(self):
+ def __getstate__(self) -> Dict[str, Any]:
d = self._shallow_to_dict()
d["path"] = self.path.serialize()
return d
- def __setstate__(self, state):
+ def __setstate__(self, state: Dict[str, Any]) -> None:
state["path"] = PathRegistry.deserialize(state["path"])
self._shallow_from_dict(state)
@@ -1437,8 +1564,8 @@ class _LoadElement(
)
def _adjust_effective_path_for_current_path(
- self, effective_path, current_path
- ):
+ self, effective_path: PathRegistry, current_path: PathRegistry
+ ) -> Optional[PathRegistry]:
"""receives the 'current_path' entry from an :class:`.ORMCompileState`
instance, which is set during lazy loads and secondary loader strategy
loads, and adjusts the given path to be relative to the
@@ -1456,7 +1583,7 @@ class _LoadElement(
"""
- chopped_start_path = Load._chop_path(effective_path, current_path)
+ chopped_start_path = Load._chop_path(effective_path.path, current_path)
if not chopped_start_path:
return None
@@ -1523,16 +1650,16 @@ class _LoadElement(
@classmethod
def create(
cls,
- path,
- attr,
- strategy,
- wildcard_key,
- local_opts,
- propagate_to_loaders,
- raiseerr=True,
- attr_group=None,
- reconcile_to_other=None,
- ):
+ path: PathRegistry,
+ attr: Optional[_AttrType],
+ strategy: Optional[_StrategyKey],
+ wildcard_key: Optional[_WildcardKeyType],
+ local_opts: Optional[_OptsType],
+ propagate_to_loaders: bool,
+ raiseerr: bool = True,
+ attr_group: Optional[_AttrGroupType] = None,
+ reconcile_to_other: Optional[bool] = None,
+ ) -> _LoadElement:
"""Create a new :class:`._LoadElement` object."""
opt = cls.__new__(cls)
@@ -1554,14 +1681,14 @@ class _LoadElement(
path = opt._init_path(path, attr, wildcard_key, attr_group, raiseerr)
if not path:
- return None
+ return None # type: ignore
assert opt.is_token_strategy == path.is_token
opt.path = path
return opt
- def __init__(self, path, strategy, local_opts, propagate_to_loaders):
+ def __init__(self) -> None:
raise NotImplementedError()
def _prepend_path_from(self, parent):
@@ -1580,7 +1707,8 @@ class _LoadElement(
assert cloned.is_class_strategy == self.is_class_strategy
if not orm_util._entity_corresponds_to_use_path_impl(
- parent.path[-1], cloned.path[0]
+ cast("_InternalEntityType[Any]", parent.path[-1]),
+ cast("_InternalEntityType[Any]", cloned.path[0]),
):
raise sa_exc.ArgumentError(
f'Attribute "{cloned.path[1]}" does not link '
@@ -1592,7 +1720,9 @@ class _LoadElement(
return cloned
@staticmethod
- def _reconcile(replacement, existing):
+ def _reconcile(
+ replacement: _LoadElement, existing: _LoadElement
+ ) -> _LoadElement:
"""define behavior for when two Load objects are to be put into
the context.attributes under the same key.
@@ -1670,7 +1800,7 @@ class _AttributeStrategyLoad(_LoadElement):
),
]
- _of_type: Union["Mapper", AliasedInsp, None]
+ _of_type: Union["Mapper[Any]", "AliasedInsp[Any]", None]
_path_with_polymorphic_path: Optional[PathRegistry]
is_class_strategy = False
@@ -1812,7 +1942,7 @@ class _AttributeStrategyLoad(_LoadElement):
pwpi = inspect(
orm_util.AliasedInsp._with_polymorphic_factory(
pwpi.mapper.base_mapper,
- pwpi.mapper,
+ (pwpi.mapper,),
aliased=True,
_use_mapper_path=True,
)
@@ -1820,11 +1950,12 @@ class _AttributeStrategyLoad(_LoadElement):
start_path = self._path_with_polymorphic_path
if current_path:
- start_path = self._adjust_effective_path_for_current_path(
+ new_path = self._adjust_effective_path_for_current_path(
start_path, current_path
)
- if start_path is None:
+ if new_path is None:
return
+ start_path = new_path
key = ("path_with_polymorphic", start_path.natural_path)
if key in context:
@@ -1872,6 +2003,7 @@ class _AttributeStrategyLoad(_LoadElement):
effective_path = self.path
if current_path:
+ assert effective_path is not None
effective_path = self._adjust_effective_path_for_current_path(
effective_path, current_path
)
@@ -1985,11 +2117,12 @@ class _TokenStrategyLoad(_LoadElement):
)
if current_path:
- effective_path = self._adjust_effective_path_for_current_path(
+ new_effective_path = self._adjust_effective_path_for_current_path(
effective_path, current_path
)
- if effective_path is None:
+ if new_effective_path is None:
return []
+ effective_path = new_effective_path
# for a wildcard token, expand out the path we set
# to encompass everything from the query entity on
@@ -2048,19 +2181,25 @@ class _ClassStrategyLoad(_LoadElement):
effective_path = self.path
if current_path:
- effective_path = self._adjust_effective_path_for_current_path(
+ new_effective_path = self._adjust_effective_path_for_current_path(
effective_path, current_path
)
- if effective_path is None:
+ if new_effective_path is None:
return []
+ effective_path = new_effective_path
- return [("loader", cast(PathRegistry, effective_path).natural_path)]
+ return [("loader", effective_path.natural_path)]
-def _generate_from_keys(meth, keys, chained, kw) -> _AbstractLoad:
-
- lead_element = None
+def _generate_from_keys(
+ meth: Callable[..., _AbstractLoad],
+ keys: Tuple[_AttrType, ...],
+ chained: bool,
+ kw: Any,
+) -> _AbstractLoad:
+ lead_element: Optional[_AbstractLoad] = None
+ attr: Any
for is_default, _keys in (True, keys[0:-1]), (False, keys[-1:]):
for attr in _keys:
if isinstance(attr, str):
@@ -2116,7 +2255,9 @@ def _generate_from_keys(meth, keys, chained, kw) -> _AbstractLoad:
return lead_element
-def _parse_attr_argument(attr):
+def _parse_attr_argument(
+ attr: _AttrType,
+) -> Tuple[InspectionAttr, _InternalEntityType[Any], MapperProperty[Any]]:
"""parse an attribute or wildcard argument to produce an
:class:`._AbstractLoad` instance.
@@ -2126,16 +2267,21 @@ def _parse_attr_argument(attr):
"""
try:
- insp = inspect(attr)
+ # TODO: need to figure out this None thing being returned by
+ # inspect(), it should not have None as an option in most cases
+ # if at all
+ insp: InspectionAttr = inspect(attr) # type: ignore
except sa_exc.NoInspectionAvailable as err:
raise sa_exc.ArgumentError(
"expected ORM mapped attribute for loader strategy argument"
) from err
- if insp.is_property:
+ lead_entity: _InternalEntityType[Any]
+
+ if insp_is_mapper_property(insp):
lead_entity = insp.parent
prop = insp
- elif insp.is_attribute:
+ elif insp_is_attribute(insp):
lead_entity = insp.parent
prop = insp.prop
else:
@@ -2146,7 +2292,7 @@ def _parse_attr_argument(attr):
return insp, lead_entity, prop
-def loader_unbound_fn(fn):
+def loader_unbound_fn(fn: _FN) -> _FN:
"""decorator that applies docstrings between standalone loader functions
and the loader methods on :class:`._AbstractLoad`.
@@ -2169,12 +2315,12 @@ See :func:`_orm.{fn.__name__}` for usage examples.
@loader_unbound_fn
-def contains_eager(*keys, **kw) -> _AbstractLoad:
+def contains_eager(*keys: _AttrType, **kw: Any) -> _AbstractLoad:
return _generate_from_keys(Load.contains_eager, keys, True, kw)
@loader_unbound_fn
-def load_only(*attrs) -> _AbstractLoad:
+def load_only(*attrs: _AttrType) -> _AbstractLoad:
# TODO: attrs against different classes. we likely have to
# add some extra state to Load of some kind
_, lead_element, _ = _parse_attr_argument(attrs[0])
@@ -2182,47 +2328,47 @@ def load_only(*attrs) -> _AbstractLoad:
@loader_unbound_fn
-def joinedload(*keys, **kw) -> _AbstractLoad:
+def joinedload(*keys: _AttrType, **kw: Any) -> _AbstractLoad:
return _generate_from_keys(Load.joinedload, keys, False, kw)
@loader_unbound_fn
-def subqueryload(*keys) -> _AbstractLoad:
+def subqueryload(*keys: _AttrType) -> _AbstractLoad:
return _generate_from_keys(Load.subqueryload, keys, False, {})
@loader_unbound_fn
-def selectinload(*keys) -> _AbstractLoad:
+def selectinload(*keys: _AttrType) -> _AbstractLoad:
return _generate_from_keys(Load.selectinload, keys, False, {})
@loader_unbound_fn
-def lazyload(*keys) -> _AbstractLoad:
+def lazyload(*keys: _AttrType) -> _AbstractLoad:
return _generate_from_keys(Load.lazyload, keys, False, {})
@loader_unbound_fn
-def immediateload(*keys) -> _AbstractLoad:
+def immediateload(*keys: _AttrType) -> _AbstractLoad:
return _generate_from_keys(Load.immediateload, keys, False, {})
@loader_unbound_fn
-def noload(*keys) -> _AbstractLoad:
+def noload(*keys: _AttrType) -> _AbstractLoad:
return _generate_from_keys(Load.noload, keys, False, {})
@loader_unbound_fn
-def raiseload(*keys, **kw) -> _AbstractLoad:
+def raiseload(*keys: _AttrType, **kw: Any) -> _AbstractLoad:
return _generate_from_keys(Load.raiseload, keys, False, kw)
@loader_unbound_fn
-def defaultload(*keys) -> _AbstractLoad:
+def defaultload(*keys: _AttrType) -> _AbstractLoad:
return _generate_from_keys(Load.defaultload, keys, False, {})
@loader_unbound_fn
-def defer(key, *addl_attrs, **kw) -> _AbstractLoad:
+def defer(key: _AttrType, *addl_attrs: _AttrType, **kw: Any) -> _AbstractLoad:
if addl_attrs:
util.warn_deprecated(
"The *addl_attrs on orm.defer is deprecated. Please use "
@@ -2234,7 +2380,7 @@ def defer(key, *addl_attrs, **kw) -> _AbstractLoad:
@loader_unbound_fn
-def undefer(key, *addl_attrs) -> _AbstractLoad:
+def undefer(key: _AttrType, *addl_attrs: _AttrType) -> _AbstractLoad:
if addl_attrs:
util.warn_deprecated(
"The *addl_attrs on orm.undefer is deprecated. Please use "
@@ -2246,19 +2392,23 @@ def undefer(key, *addl_attrs) -> _AbstractLoad:
@loader_unbound_fn
-def undefer_group(name) -> _AbstractLoad:
+def undefer_group(name: str) -> _AbstractLoad:
element = _WildcardLoad()
return element.undefer_group(name)
@loader_unbound_fn
-def with_expression(key, expression) -> _AbstractLoad:
+def with_expression(
+ key: _AttrType, expression: _ColumnExpressionArgument[Any]
+) -> _AbstractLoad:
return _generate_from_keys(
Load.with_expression, (key,), False, {"expression": expression}
)
@loader_unbound_fn
-def selectin_polymorphic(base_cls, classes) -> _AbstractLoad:
+def selectin_polymorphic(
+ base_cls: _EntityType[Any], classes: Iterable[Type[Any]]
+) -> _AbstractLoad:
ul = Load(base_cls)
return ul.selectin_polymorphic(classes)
diff --git a/lib/sqlalchemy/orm/sync.py b/lib/sqlalchemy/orm/sync.py
index 4f63e241b..4f1eeb39b 100644
--- a/lib/sqlalchemy/orm/sync.py
+++ b/lib/sqlalchemy/orm/sync.py
@@ -4,7 +4,7 @@
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: ignore-errors
+# mypy: allow-untyped-defs, allow-untyped-calls
"""private module containing functions used for copying data
@@ -14,9 +14,9 @@ between instances based on join conditions.
from __future__ import annotations
-from . import attributes
from . import exc
from . import util as orm_util
+from .base import PassiveFlag
def populate(
@@ -36,7 +36,7 @@ def populate(
# inline of source_mapper._get_state_attr_by_column
prop = source_mapper._columntoproperty[l]
value = source.manager[prop.key].impl.get(
- source, source_dict, attributes.PASSIVE_OFF
+ source, source_dict, PassiveFlag.PASSIVE_OFF
)
except exc.UnmappedColumnError as err:
_raise_col_to_prop(False, source_mapper, l, dest_mapper, r, err)
@@ -74,8 +74,8 @@ def bulk_populate_inherit_keys(source_dict, source_mapper, synchronize_pairs):
try:
prop = source_mapper._columntoproperty[r]
source_dict[prop.key] = value
- except exc.UnmappedColumnError:
- _raise_col_to_prop(True, source_mapper, l, source_mapper, r)
+ except exc.UnmappedColumnError as err:
+ _raise_col_to_prop(True, source_mapper, l, source_mapper, r, err)
def clear(dest, dest_mapper, synchronize_pairs):
@@ -103,7 +103,7 @@ def update(source, source_mapper, dest, old_prefix, synchronize_pairs):
source.obj(), l
)
value = source_mapper._get_state_attr_by_column(
- source, source.dict, l, passive=attributes.PASSIVE_OFF
+ source, source.dict, l, passive=PassiveFlag.PASSIVE_OFF
)
except exc.UnmappedColumnError as err:
_raise_col_to_prop(False, source_mapper, l, None, r, err)
@@ -115,7 +115,7 @@ def populate_dict(source, source_mapper, dict_, synchronize_pairs):
for l, r in synchronize_pairs:
try:
value = source_mapper._get_state_attr_by_column(
- source, source.dict, l, passive=attributes.PASSIVE_OFF
+ source, source.dict, l, passive=PassiveFlag.PASSIVE_OFF
)
except exc.UnmappedColumnError as err:
_raise_col_to_prop(False, source_mapper, l, None, r, err)
@@ -134,7 +134,7 @@ def source_modified(uowcommit, source, source_mapper, synchronize_pairs):
except exc.UnmappedColumnError as err:
_raise_col_to_prop(False, source_mapper, l, None, r, err)
history = uowcommit.get_attribute_history(
- source, prop.key, attributes.PASSIVE_NO_INITIALIZE
+ source, prop.key, PassiveFlag.PASSIVE_NO_INITIALIZE
)
if bool(history.deleted):
return True
diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py
index 4da0b7773..c50cc5bac 100644
--- a/lib/sqlalchemy/orm/util.py
+++ b/lib/sqlalchemy/orm/util.py
@@ -12,6 +12,7 @@ import re
import types
import typing
from typing import Any
+from typing import Callable
from typing import cast
from typing import Dict
from typing import FrozenSet
@@ -82,24 +83,29 @@ if typing.TYPE_CHECKING:
from ._typing import _EntityType
from ._typing import _IdentityKeyType
from ._typing import _InternalEntityType
- from ._typing import _ORMColumnExprArgument
+ from ._typing import _ORMCOLEXPR
from .context import _MapperEntity
from .context import ORMCompileState
from .mapper import Mapper
+ from .query import Query
from .relationships import Relationship
from ..engine import Row
from ..engine import RowMapping
+ from ..sql._typing import _CE
from ..sql._typing import _ColumnExpressionArgument
from ..sql._typing import _EquivalentColumnMap
from ..sql._typing import _FromClauseArgument
from ..sql._typing import _OnClauseArgument
from ..sql._typing import _PropagateAttrsType
+ from ..sql.annotation import _SA
from ..sql.base import ReadOnlyColumnCollection
from ..sql.elements import BindParameter
from ..sql.selectable import _ColumnsClauseElement
from ..sql.selectable import Alias
+ from ..sql.selectable import Select
from ..sql.selectable import Subquery
from ..sql.visitors import anon_map
+ from ..util.typing import _AnnotationScanType
_T = TypeVar("_T", bound=Any)
@@ -144,9 +150,11 @@ class CascadeOptions(FrozenSet[str]):
expunge: bool
delete_orphan: bool
- def __new__(cls, value_list):
+ def __new__(
+ cls, value_list: Optional[Union[Iterable[str], str]]
+ ) -> CascadeOptions:
if isinstance(value_list, str) or value_list is None:
- return cls.from_string(value_list)
+ return cls.from_string(value_list) # type: ignore
values = set(value_list)
if values.difference(cls._allowed_cascades):
raise sa_exc.ArgumentError(
@@ -864,7 +872,7 @@ class AliasedInsp(
def _with_polymorphic_factory(
cls,
base: Union[_O, Mapper[_O]],
- classes: Iterable[Type[Any]],
+ classes: Iterable[_EntityType[Any]],
selectable: Union[Literal[False, None], FromClause] = False,
flat: bool = False,
polymorphic_on: Optional[ColumnElement[Any]] = None,
@@ -1011,23 +1019,40 @@ class AliasedInsp(
)._aliased_insp
def _adapt_element(
- self, elem: _ORMColumnExprArgument[_T], key: Optional[str] = None
- ) -> _ORMColumnExprArgument[_T]:
- assert isinstance(elem, ColumnElement)
+ self, expr: _ORMCOLEXPR, key: Optional[str] = None
+ ) -> _ORMCOLEXPR:
+ assert isinstance(expr, ColumnElement)
d: Dict[str, Any] = {
"parententity": self,
"parentmapper": self.mapper,
}
if key:
d["proxy_key"] = key
+
+ # IMO mypy should see this one also as returning the same type
+ # we put into it, but it's not
return (
- self._adapter.traverse(elem)
+ self._adapter.traverse(expr) # type: ignore
._annotate(d)
._set_propagate_attrs(
{"compile_state_plugin": "orm", "plugin_subject": self}
)
)
+ if TYPE_CHECKING:
+ # establish compatibility with the _ORMAdapterProto protocol,
+ # which in turn is compatible with _CoreAdapterProto.
+
+ def _orm_adapt_element(
+ self,
+ obj: _CE,
+ key: Optional[str] = None,
+ ) -> _CE:
+ ...
+
+ else:
+ _orm_adapt_element = _adapt_element
+
def _entity_for_mapper(self, mapper):
self_poly = self.with_polymorphic_mappers
if mapper in self_poly:
@@ -1469,7 +1494,12 @@ class Bundle(
cloned.name = name
return cloned
- def create_row_processor(self, query, procs, labels):
+ def create_row_processor(
+ self,
+ query: Select[Any],
+ procs: Sequence[Callable[[Row[Any]], Any]],
+ labels: Sequence[str],
+ ) -> Callable[[Row[Any]], Any]:
"""Produce the "row processing" function for this :class:`.Bundle`.
May be overridden by subclasses.
@@ -1481,13 +1511,13 @@ class Bundle(
"""
keyed_tuple = result_tuple(labels, [() for l in labels])
- def proc(row):
+ def proc(row: Row[Any]) -> Any:
return keyed_tuple([proc(row) for proc in procs])
return proc
-def _orm_annotate(element, exclude=None):
+def _orm_annotate(element: _SA, exclude: Optional[Any] = None) -> _SA:
"""Deep copy the given ClauseElement, annotating each element with the
"_orm_adapt" flag.
@@ -1497,7 +1527,7 @@ def _orm_annotate(element, exclude=None):
return sql_util._deep_annotate(element, {"_orm_adapt": True}, exclude)
-def _orm_deannotate(element):
+def _orm_deannotate(element: _SA) -> _SA:
"""Remove annotations that link a column to a particular mapping.
Note this doesn't affect "remote" and "foreign" annotations
@@ -1511,7 +1541,7 @@ def _orm_deannotate(element):
)
-def _orm_full_deannotate(element):
+def _orm_full_deannotate(element: _SA) -> _SA:
return sql_util._deep_deannotate(element)
@@ -1560,13 +1590,15 @@ class _ORMJoin(expression.Join):
on_selectable = prop.parent.selectable
else:
prop = None
+ on_selectable = None
if prop:
left_selectable = left_info.selectable
-
+ adapt_from: Optional[FromClause]
if sql_util.clause_is_present(on_selectable, left_selectable):
adapt_from = on_selectable
else:
+ assert isinstance(left_selectable, FromClause)
adapt_from = left_selectable
(
@@ -1855,7 +1887,7 @@ def _entity_isa(given: _InternalEntityType[Any], mapper: Mapper[Any]) -> bool:
return given.isa(mapper)
-def _getitem(iterable_query, item):
+def _getitem(iterable_query: Query[Any], item: Any) -> Any:
"""calculate __getitem__ in terms of an iterable query object
that also has a slice() method.
@@ -1881,17 +1913,15 @@ def _getitem(iterable_query, item):
isinstance(stop, int) and stop < 0
):
_no_negative_indexes()
- return list(iterable_query)[item]
res = iterable_query.slice(start, stop)
if step is not None:
- return list(res)[None : None : item.step]
+ return list(res)[None : None : item.step] # type: ignore
else:
- return list(res)
+ return list(res) # type: ignore
else:
if item == -1:
_no_negative_indexes()
- return list(iterable_query)[-1]
else:
return list(iterable_query[item : item + 1])[0]
@@ -1933,7 +1963,7 @@ def _cleanup_mapped_str_annotation(annotation: str) -> str:
def _extract_mapped_subtype(
- raw_annotation: Union[type, str],
+ raw_annotation: Optional[_AnnotationScanType],
cls: type,
key: str,
attr_cls: Type[Any],
diff --git a/lib/sqlalchemy/sql/_typing.py b/lib/sqlalchemy/sql/_typing.py
index f49a6d3ec..ed1bd2832 100644
--- a/lib/sqlalchemy/sql/_typing.py
+++ b/lib/sqlalchemy/sql/_typing.py
@@ -61,6 +61,9 @@ if TYPE_CHECKING:
_T = TypeVar("_T", bound=Any)
+_CE = TypeVar("_CE", bound="ColumnElement[Any]")
+
+
class _HasClauseElement(Protocol):
"""indicates a class that has a __clause_element__() method"""
@@ -68,6 +71,13 @@ class _HasClauseElement(Protocol):
...
+class _CoreAdapterProto(Protocol):
+ """protocol for the ClauseAdapter/ColumnAdapter.traverse() method."""
+
+ def __call__(self, obj: _CE) -> _CE:
+ ...
+
+
# match column types that are not ORM entities
_NOT_ENTITY = TypeVar(
"_NOT_ENTITY",
diff --git a/lib/sqlalchemy/sql/annotation.py b/lib/sqlalchemy/sql/annotation.py
index fa36c09fc..56d88bc2f 100644
--- a/lib/sqlalchemy/sql/annotation.py
+++ b/lib/sqlalchemy/sql/annotation.py
@@ -454,9 +454,23 @@ def _deep_annotate(
return element
+@overload
+def _deep_deannotate(
+ element: Literal[None], values: Optional[Sequence[str]] = None
+) -> Literal[None]:
+ ...
+
+
+@overload
def _deep_deannotate(
element: _SA, values: Optional[Sequence[str]] = None
) -> _SA:
+ ...
+
+
+def _deep_deannotate(
+ element: Optional[_SA], values: Optional[Sequence[str]] = None
+) -> Optional[_SA]:
"""Deep copy the given element, removing annotations."""
cloned: Dict[Any, SupportsAnnotations] = {}
@@ -482,9 +496,7 @@ def _deep_deannotate(
return element
-def _shallow_annotate(
- element: SupportsAnnotations, annotations: _AnnotationDict
-) -> SupportsAnnotations:
+def _shallow_annotate(element: _SA, annotations: _AnnotationDict) -> _SA:
"""Annotate the given ClauseElement and copy its internals so that
internal objects refer to the new annotated object.
diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py
index 248b48a25..f5a9c10c0 100644
--- a/lib/sqlalchemy/sql/base.py
+++ b/lib/sqlalchemy/sql/base.py
@@ -750,6 +750,17 @@ class _MetaOptions(type):
o1.__dict__.update(other)
return o1
+ if TYPE_CHECKING:
+
+ def __getattr__(self, key: str) -> Any:
+ ...
+
+ def __setattr__(self, key: str, value: Any) -> None:
+ ...
+
+ def __delattr__(self, key: str) -> None:
+ ...
+
class Options(metaclass=_MetaOptions):
"""A cacheable option dictionary with defaults."""
@@ -904,6 +915,17 @@ class Options(metaclass=_MetaOptions):
else:
return existing_options, exec_options
+ if TYPE_CHECKING:
+
+ def __getattr__(self, key: str) -> Any:
+ ...
+
+ def __setattr__(self, key: str, value: Any) -> None:
+ ...
+
+ def __delattr__(self, key: str) -> None:
+ ...
+
class CacheableOptions(Options, HasCacheKey):
__slots__ = ()
diff --git a/lib/sqlalchemy/sql/coercions.py b/lib/sqlalchemy/sql/coercions.py
index eef5cf211..501188b12 100644
--- a/lib/sqlalchemy/sql/coercions.py
+++ b/lib/sqlalchemy/sql/coercions.py
@@ -56,6 +56,7 @@ if typing.TYPE_CHECKING:
from .elements import ColumnClause
from .elements import ColumnElement
from .elements import DQLDMLClauseElement
+ from .elements import NamedColumn
from .elements import SQLCoreOperations
from .schema import Column
from .selectable import _ColumnsClauseElement
@@ -199,6 +200,15 @@ def expect(
@overload
def expect(
+ role: Type[roles.LabeledColumnExprRole[Any]],
+ element: _ColumnExpressionArgument[_T],
+ **kw: Any,
+) -> NamedColumn[_T]:
+ ...
+
+
+@overload
+def expect(
role: Union[
Type[roles.ExpressionElementRole[Any]],
Type[roles.LimitOffsetRole],
@@ -217,6 +227,7 @@ def expect(
Type[roles.LimitOffsetRole],
Type[roles.WhereHavingRole],
Type[roles.OnClauseRole],
+ Type[roles.ColumnArgumentRole],
],
element: Any,
**kw: Any,
diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py
index 41b7f6392..61c5379d8 100644
--- a/lib/sqlalchemy/sql/elements.py
+++ b/lib/sqlalchemy/sql/elements.py
@@ -503,7 +503,7 @@ class ClauseElement(
def params(
self: SelfClauseElement,
- __optionaldict: Optional[Dict[str, Any]] = None,
+ __optionaldict: Optional[Mapping[str, Any]] = None,
**kwargs: Any,
) -> SelfClauseElement:
"""Return a copy with :func:`_expression.bindparam` elements
@@ -525,7 +525,7 @@ class ClauseElement(
def _replace_params(
self: SelfClauseElement,
unique: bool,
- optionaldict: Optional[Dict[str, Any]],
+ optionaldict: Optional[Mapping[str, Any]],
kwargs: Dict[str, Any],
) -> SelfClauseElement:
@@ -545,7 +545,7 @@ class ClauseElement(
{"bindparam": visit_bindparam},
)
- def compare(self, other, **kw):
+ def compare(self, other: ClauseElement, **kw: Any) -> bool:
r"""Compare this :class:`_expression.ClauseElement` to
the given :class:`_expression.ClauseElement`.
@@ -2516,7 +2516,9 @@ class True_(SingletonConstant, roles.ConstExprRole[bool], ColumnElement[bool]):
return False_._singleton
@classmethod
- def _ifnone(cls, other):
+ def _ifnone(
+ cls, other: Optional[ColumnElement[Any]]
+ ) -> ColumnElement[Any]:
if other is None:
return cls._instance()
else:
@@ -4226,7 +4228,13 @@ class NamedColumn(KeyedColumnElement[_T]):
) -> Optional[str]:
return name
- def _bind_param(self, operator, obj, type_=None, expanding=False):
+ def _bind_param(
+ self,
+ operator: OperatorType,
+ obj: Any,
+ type_: Optional[TypeEngine[_T]] = None,
+ expanding: bool = False,
+ ) -> BindParameter[_T]:
return BindParameter(
self.key,
obj,
diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py
index d0b0f1476..fd98f17e3 100644
--- a/lib/sqlalchemy/sql/selectable.py
+++ b/lib/sqlalchemy/sql/selectable.py
@@ -64,6 +64,7 @@ from .base import _EntityNamespace
from .base import _expand_cloned
from .base import _from_objects
from .base import _generative
+from .base import _NoArg
from .base import _select_iterables
from .base import CacheableOptions
from .base import ColumnCollection
@@ -131,6 +132,7 @@ if TYPE_CHECKING:
from .dml import Insert
from .dml import Update
from .elements import KeyedColumnElement
+ from .elements import Label
from .elements import NamedColumn
from .elements import TextClause
from .functions import Function
@@ -212,7 +214,7 @@ class ReturnsRows(roles.ReturnsRowsRole, DQLDMLClauseElement):
"""
raise NotImplementedError()
- def is_derived_from(self, fromclause: FromClause) -> bool:
+ def is_derived_from(self, fromclause: Optional[FromClause]) -> bool:
"""Return ``True`` if this :class:`.ReturnsRows` is
'derived' from the given :class:`.FromClause`.
@@ -778,7 +780,7 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable):
"""
return TableSample._construct(self, sampling, name, seed)
- def is_derived_from(self, fromclause: FromClause) -> bool:
+ def is_derived_from(self, fromclause: Optional[FromClause]) -> bool:
"""Return ``True`` if this :class:`_expression.FromClause` is
'derived' from the given ``FromClause``.
@@ -1128,11 +1130,14 @@ class SelectLabelStyle(Enum):
"""
+ LABEL_STYLE_LEGACY_ORM = 3
+
(
LABEL_STYLE_NONE,
LABEL_STYLE_TABLENAME_PLUS_COL,
LABEL_STYLE_DISAMBIGUATE_ONLY,
+ _,
) = list(SelectLabelStyle)
LABEL_STYLE_DEFAULT = LABEL_STYLE_DISAMBIGUATE_ONLY
@@ -1231,7 +1236,7 @@ class Join(roles.DMLTableRole, FromClause):
id(self.right),
)
- def is_derived_from(self, fromclause: FromClause) -> bool:
+ def is_derived_from(self, fromclause: Optional[FromClause]) -> bool:
return (
# use hash() to ensure direct comparison to annotated works
# as well
@@ -1635,7 +1640,7 @@ class AliasedReturnsRows(NoInit, NamedFromClause):
"""Legacy for dialects that are referring to Alias.original."""
return self.element
- def is_derived_from(self, fromclause: FromClause) -> bool:
+ def is_derived_from(self, fromclause: Optional[FromClause]) -> bool:
if fromclause in self._cloned_set:
return True
return self.element.is_derived_from(fromclause)
@@ -2840,7 +2845,7 @@ class FromGrouping(GroupedElement, FromClause):
def foreign_keys(self):
return self.element.foreign_keys
- def is_derived_from(self, fromclause: FromClause) -> bool:
+ def is_derived_from(self, fromclause: Optional[FromClause]) -> bool:
return self.element.is_derived_from(fromclause)
def alias(
@@ -3080,11 +3085,17 @@ class ForUpdateArg(ClauseElement):
def __init__(
self,
- nowait=False,
- read=False,
- of=None,
- skip_locked=False,
- key_share=False,
+ *,
+ nowait: bool = False,
+ read: bool = False,
+ of: Optional[
+ Union[
+ _ColumnExpressionArgument[Any],
+ Sequence[_ColumnExpressionArgument[Any]],
+ ]
+ ] = None,
+ skip_locked: bool = False,
+ key_share: bool = False,
):
"""Represents arguments specified to
:meth:`_expression.Select.for_update`.
@@ -3455,7 +3466,7 @@ class SelectBase(
return ScalarSelect(self)
- def label(self, name):
+ def label(self, name: Optional[str]) -> Label[Any]:
"""Return a 'scalar' representation of this selectable, embedded as a
subquery with a label.
@@ -3667,6 +3678,7 @@ class GenerativeSelect(SelectBase, Generative):
@_generative
def with_for_update(
self: SelfGenerativeSelect,
+ *,
nowait: bool = False,
read: bool = False,
of: Optional[
@@ -4064,7 +4076,11 @@ class GenerativeSelect(SelectBase, Generative):
@_generative
def order_by(
- self: SelfGenerativeSelect, *clauses: _ColumnExpressionArgument[Any]
+ self: SelfGenerativeSelect,
+ __first: Union[
+ Literal[None, _NoArg.NO_ARG], _ColumnExpressionArgument[Any]
+ ] = _NoArg.NO_ARG,
+ *clauses: _ColumnExpressionArgument[Any],
) -> SelfGenerativeSelect:
r"""Return a new selectable with the given list of ORDER BY
criteria applied.
@@ -4092,18 +4108,22 @@ class GenerativeSelect(SelectBase, Generative):
"""
- if len(clauses) == 1 and clauses[0] is None:
+ if not clauses and __first is None:
self._order_by_clauses = ()
- else:
+ elif __first is not _NoArg.NO_ARG:
self._order_by_clauses += tuple(
coercions.expect(roles.OrderByRole, clause)
- for clause in clauses
+ for clause in (__first,) + clauses
)
return self
@_generative
def group_by(
- self: SelfGenerativeSelect, *clauses: _ColumnExpressionArgument[Any]
+ self: SelfGenerativeSelect,
+ __first: Union[
+ Literal[None, _NoArg.NO_ARG], _ColumnExpressionArgument[Any]
+ ] = _NoArg.NO_ARG,
+ *clauses: _ColumnExpressionArgument[Any],
) -> SelfGenerativeSelect:
r"""Return a new selectable with the given list of GROUP BY
criterion applied.
@@ -4128,12 +4148,12 @@ class GenerativeSelect(SelectBase, Generative):
"""
- if len(clauses) == 1 and clauses[0] is None:
+ if not clauses and __first is None:
self._group_by_clauses = ()
- else:
+ elif __first is not _NoArg.NO_ARG:
self._group_by_clauses += tuple(
coercions.expect(roles.GroupByRole, clause)
- for clause in clauses
+ for clause in (__first,) + clauses
)
return self
@@ -4257,7 +4277,7 @@ class CompoundSelect(HasCompileState, GenerativeSelect, ExecutableReturnsRows):
) -> GroupedElement:
return SelectStatementGrouping(self)
- def is_derived_from(self, fromclause: FromClause) -> bool:
+ def is_derived_from(self, fromclause: Optional[FromClause]) -> bool:
for s in self.selects:
if s.is_derived_from(fromclause):
return True
@@ -4959,7 +4979,7 @@ class Select(
_raw_columns: List[_ColumnsClauseElement]
- _distinct = False
+ _distinct: bool = False
_distinct_on: Tuple[ColumnElement[Any], ...] = ()
_correlate: Tuple[FromClause, ...] = ()
_correlate_except: Optional[Tuple[FromClause, ...]] = None
@@ -5478,8 +5498,8 @@ class Select(
return iter(self._all_selected_columns)
- def is_derived_from(self, fromclause: FromClause) -> bool:
- if self in fromclause._cloned_set:
+ def is_derived_from(self, fromclause: Optional[FromClause]) -> bool:
+ if fromclause is not None and self in fromclause._cloned_set:
return True
for f in self._iterate_from_elements():
diff --git a/lib/sqlalchemy/sql/traversals.py b/lib/sqlalchemy/sql/traversals.py
index aceed99a5..94e635740 100644
--- a/lib/sqlalchemy/sql/traversals.py
+++ b/lib/sqlalchemy/sql/traversals.py
@@ -19,6 +19,7 @@ from typing import Callable
from typing import Deque
from typing import Dict
from typing import Iterable
+from typing import Optional
from typing import Set
from typing import Tuple
from typing import Type
@@ -39,7 +40,7 @@ COMPARE_FAILED = False
COMPARE_SUCCEEDED = True
-def compare(obj1, obj2, **kw):
+def compare(obj1: Any, obj2: Any, **kw: Any) -> bool:
strategy: TraversalComparatorStrategy
if kw.get("use_proxies", False):
strategy = ColIdentityComparatorStrategy()
@@ -49,7 +50,7 @@ def compare(obj1, obj2, **kw):
return strategy.compare(obj1, obj2, **kw)
-def _preconfigure_traversals(target_hierarchy):
+def _preconfigure_traversals(target_hierarchy: Type[Any]) -> None:
for cls in util.walk_subclasses(target_hierarchy):
if hasattr(cls, "_generate_cache_attrs") and hasattr(
cls, "_traverse_internals"
@@ -482,14 +483,22 @@ class TraversalComparatorStrategy(HasTraversalDispatch, util.MemoizedSlots):
def __init__(self):
self.stack: Deque[
- Tuple[ExternallyTraversible, ExternallyTraversible]
+ Tuple[
+ Optional[ExternallyTraversible],
+ Optional[ExternallyTraversible],
+ ]
] = deque()
self.cache = set()
def _memoized_attr_anon_map(self):
return (anon_map(), anon_map())
- def compare(self, obj1, obj2, **kw):
+ def compare(
+ self,
+ obj1: ExternallyTraversible,
+ obj2: ExternallyTraversible,
+ **kw: Any,
+ ) -> bool:
stack = self.stack
cache = self.cache
@@ -551,6 +560,10 @@ class TraversalComparatorStrategy(HasTraversalDispatch, util.MemoizedSlots):
elif left_attrname in attributes_compared:
continue
+ assert left_visit_sym is not None
+ assert left_attrname is not None
+ assert right_attrname is not None
+
dispatch = self.dispatch(left_visit_sym)
assert dispatch, (
f"{self.__class__} has no dispatch for "
@@ -595,6 +608,14 @@ class TraversalComparatorStrategy(HasTraversalDispatch, util.MemoizedSlots):
self, attrname, left_parent, left, right_parent, right, **kw
):
for l, r in zip_longest(left, right, fillvalue=None):
+ if l is None:
+ if r is not None:
+ return COMPARE_FAILED
+ else:
+ continue
+ elif r is None:
+ return COMPARE_FAILED
+
if l._gen_cache_key(self.anon_map[0], []) != r._gen_cache_key(
self.anon_map[1], []
):
@@ -604,6 +625,14 @@ class TraversalComparatorStrategy(HasTraversalDispatch, util.MemoizedSlots):
self, attrname, left_parent, left, right_parent, right, **kw
):
for l, r in zip_longest(left, right, fillvalue=None):
+ if l is None:
+ if r is not None:
+ return COMPARE_FAILED
+ else:
+ continue
+ elif r is None:
+ return COMPARE_FAILED
+
if (
l._gen_cache_key(self.anon_map[0], [])
if l._is_has_cache_key
diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py
index 262689128..390e23952 100644
--- a/lib/sqlalchemy/sql/util.py
+++ b/lib/sqlalchemy/sql/util.py
@@ -73,6 +73,7 @@ if typing.TYPE_CHECKING:
from ._typing import _ColumnExpressionArgument
from ._typing import _EquivalentColumnMap
from ._typing import _TypeEngineArgument
+ from .elements import BinaryExpression
from .elements import TextClause
from .selectable import _JoinTargetElement
from .selectable import _SelectIterable
@@ -86,8 +87,15 @@ if typing.TYPE_CHECKING:
from ..engine.interfaces import _CoreSingleExecuteParams
from ..engine.row import Row
+_CE = TypeVar("_CE", bound="ColumnElement[Any]")
-def join_condition(a, b, a_subset=None, consider_as_foreign_keys=None):
+
+def join_condition(
+ a: FromClause,
+ b: FromClause,
+ a_subset: Optional[FromClause] = None,
+ consider_as_foreign_keys: Optional[AbstractSet[ColumnClause[Any]]] = None,
+) -> ColumnElement[bool]:
"""Create a join condition between two tables or selectables.
e.g.::
@@ -118,7 +126,9 @@ def join_condition(a, b, a_subset=None, consider_as_foreign_keys=None):
)
-def find_join_source(clauses, join_to):
+def find_join_source(
+ clauses: List[FromClause], join_to: FromClause
+) -> List[int]:
"""Given a list of FROM clauses and a selectable,
return the first index and element from the list of
clauses which can be joined against the selectable. returns
@@ -144,7 +154,9 @@ def find_join_source(clauses, join_to):
return idx
-def find_left_clause_that_matches_given(clauses, join_from):
+def find_left_clause_that_matches_given(
+ clauses: Sequence[FromClause], join_from: FromClause
+) -> List[int]:
"""Given a list of FROM clauses and a selectable,
return the indexes from the list of
clauses which is derived from the selectable.
@@ -243,7 +255,12 @@ def find_left_clause_to_join_from(
return idx
-def visit_binary_product(fn, expr):
+def visit_binary_product(
+ fn: Callable[
+ [BinaryExpression[Any], ColumnElement[Any], ColumnElement[Any]], None
+ ],
+ expr: ColumnElement[Any],
+) -> None:
"""Produce a traversal of the given expression, delivering
column comparisons to the given function.
@@ -278,19 +295,19 @@ def visit_binary_product(fn, expr):
a binary comparison is passed as pairs.
"""
- stack: List[ClauseElement] = []
+ stack: List[BinaryExpression[Any]] = []
- def visit(element):
+ def visit(element: ClauseElement) -> Iterator[ColumnElement[Any]]:
if isinstance(element, ScalarSelect):
# we don't want to dig into correlated subqueries,
# those are just column elements by themselves
yield element
elif element.__visit_name__ == "binary" and operators.is_comparison(
- element.operator
+ element.operator # type: ignore
):
- stack.insert(0, element)
- for l in visit(element.left):
- for r in visit(element.right):
+ stack.insert(0, element) # type: ignore
+ for l in visit(element.left): # type: ignore
+ for r in visit(element.right): # type: ignore
fn(stack[0], l, r)
stack.pop(0)
for elem in element.get_children():
@@ -502,7 +519,7 @@ def extract_first_column_annotation(column, annotation_name):
return None
-def selectables_overlap(left, right):
+def selectables_overlap(left: FromClause, right: FromClause) -> bool:
"""Return True if left/right have some overlapping selectable"""
return bool(
@@ -701,7 +718,7 @@ class _repr_params(_repr_base):
return "[%s]" % (", ".join(trunc(value) for value in params))
-def adapt_criterion_to_null(crit, nulls):
+def adapt_criterion_to_null(crit: _CE, nulls: Collection[Any]) -> _CE:
"""given criterion containing bind params, convert selected elements
to IS NULL.
@@ -922,9 +939,6 @@ def criterion_as_pairs(
return pairs
-_CE = TypeVar("_CE", bound="ClauseElement")
-
-
class ClauseAdapter(visitors.ReplacingExternalTraversal):
"""Clones and modifies clauses based on column correspondence.
diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py
index 217e2d2ab..b550f8f28 100644
--- a/lib/sqlalchemy/sql/visitors.py
+++ b/lib/sqlalchemy/sql/visitors.py
@@ -21,7 +21,6 @@ from typing import Any
from typing import Callable
from typing import cast
from typing import ClassVar
-from typing import Collection
from typing import Dict
from typing import Iterable
from typing import Iterator
@@ -31,6 +30,7 @@ from typing import Optional
from typing import overload
from typing import Tuple
from typing import Type
+from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
@@ -42,6 +42,10 @@ from ..util.typing import Literal
from ..util.typing import Protocol
from ..util.typing import Self
+if TYPE_CHECKING:
+ from .annotation import _AnnotationDict
+ from .elements import ColumnElement
+
if typing.TYPE_CHECKING or not HAS_CYEXTENSION:
from ._py_util import prefix_anon_map as prefix_anon_map
from ._py_util import cache_anon_map as anon_map
@@ -590,13 +594,23 @@ _dispatch_lookup = HasTraversalDispatch._dispatch_lookup
_generate_traversal_dispatch()
+SelfExternallyTraversible = TypeVar(
+ "SelfExternallyTraversible", bound="ExternallyTraversible"
+)
+
+
class ExternallyTraversible(HasTraverseInternals, Visitable):
__slots__ = ()
- _annotations: Collection[Any] = ()
+ _annotations: Mapping[Any, Any] = util.EMPTY_DICT
if typing.TYPE_CHECKING:
+ def _annotate(
+ self: SelfExternallyTraversible, values: _AnnotationDict
+ ) -> SelfExternallyTraversible:
+ ...
+
def get_children(
self, *, omit_attrs: Tuple[str, ...] = (), **kw: Any
) -> Iterable[ExternallyTraversible]:
@@ -624,6 +638,7 @@ class ExternallyTraversible(HasTraverseInternals, Visitable):
_ET = TypeVar("_ET", bound=ExternallyTraversible)
+_CE = TypeVar("_CE", bound="ColumnElement[Any]")
_TraverseCallableType = Callable[[_ET], None]
@@ -633,10 +648,8 @@ class _CloneCallableType(Protocol):
...
-class _TraverseTransformCallableType(Protocol):
- def __call__(
- self, element: ExternallyTraversible, **kw: Any
- ) -> Optional[ExternallyTraversible]:
+class _TraverseTransformCallableType(Protocol[_ET]):
+ def __call__(self, element: _ET, **kw: Any) -> Optional[_ET]:
...
@@ -1074,16 +1087,25 @@ def cloned_traverse(
def replacement_traverse(
obj: Literal[None],
opts: Mapping[str, Any],
- replace: _TraverseTransformCallableType,
+ replace: _TraverseTransformCallableType[Any],
) -> None:
...
@overload
def replacement_traverse(
+ obj: _CE,
+ opts: Mapping[str, Any],
+ replace: _TraverseTransformCallableType[Any],
+) -> _CE:
+ ...
+
+
+@overload
+def replacement_traverse(
obj: ExternallyTraversible,
opts: Mapping[str, Any],
- replace: _TraverseTransformCallableType,
+ replace: _TraverseTransformCallableType[Any],
) -> ExternallyTraversible:
...
@@ -1091,7 +1113,7 @@ def replacement_traverse(
def replacement_traverse(
obj: Optional[ExternallyTraversible],
opts: Mapping[str, Any],
- replace: _TraverseTransformCallableType,
+ replace: _TraverseTransformCallableType[Any],
) -> Optional[ExternallyTraversible]:
"""Clone the given expression structure, allowing element
replacement by a given replacement function.
@@ -1134,7 +1156,7 @@ def replacement_traverse(
newelem = replace(elem)
if newelem is not None:
stop_on.add(id(newelem))
- return newelem
+ return newelem # type: ignore
else:
# base "already seen" on id(), not hash, so that we don't
# replace an Annotated element with its non-annotated one, and
@@ -1145,11 +1167,11 @@ def replacement_traverse(
newelem = kw["replace"](elem)
if newelem is not None:
cloned[id_elem] = newelem
- return newelem
+ return newelem # type: ignore
cloned[id_elem] = newelem = elem._clone(**kw)
newelem._copy_internals(clone=clone, **kw)
- return cloned[id_elem]
+ return cloned[id_elem] # type: ignore
if obj is not None:
obj = clone(
diff --git a/lib/sqlalchemy/util/_collections.py b/lib/sqlalchemy/util/_collections.py
index 7150dedcf..54be2e4e5 100644
--- a/lib/sqlalchemy/util/_collections.py
+++ b/lib/sqlalchemy/util/_collections.py
@@ -71,7 +71,7 @@ _T_co = TypeVar("_T_co", covariant=True)
EMPTY_SET: FrozenSet[Any] = frozenset()
-def merge_lists_w_ordering(a, b):
+def merge_lists_w_ordering(a: List[Any], b: List[Any]) -> List[Any]:
"""merge two lists, maintaining ordering as much as possible.
this is to reconcile vars(cls) with cls.__annotations__.
@@ -450,7 +450,7 @@ def to_set(x):
return x
-def to_column_set(x):
+def to_column_set(x: Any) -> Set[Any]:
if x is None:
return column_set()
if not isinstance(x, column_set):
diff --git a/lib/sqlalchemy/util/compat.py b/lib/sqlalchemy/util/compat.py
index 24fa0f3e3..adbbf143f 100644
--- a/lib/sqlalchemy/util/compat.py
+++ b/lib/sqlalchemy/util/compat.py
@@ -20,11 +20,14 @@ import typing
from typing import Any
from typing import Callable
from typing import Dict
+from typing import Iterable
from typing import List
from typing import Mapping
from typing import Optional
from typing import Sequence
+from typing import Set
from typing import Tuple
+from typing import Type
py311 = sys.version_info >= (3, 11)
@@ -225,7 +228,7 @@ def inspect_formatargspec(
return result
-def dataclass_fields(cls):
+def dataclass_fields(cls: Type[Any]) -> Iterable[dataclasses.Field[Any]]:
"""Return a sequence of all dataclasses.Field objects associated
with a class."""
@@ -235,12 +238,12 @@ def dataclass_fields(cls):
return []
-def local_dataclass_fields(cls):
+def local_dataclass_fields(cls: Type[Any]) -> Iterable[dataclasses.Field[Any]]:
"""Return a sequence of all dataclasses.Field objects associated with
a class, excluding those that originate from a superclass."""
if dataclasses.is_dataclass(cls):
- super_fields = set()
+ super_fields: Set[dataclasses.Field[Any]] = set()
for sup in cls.__bases__:
super_fields.update(dataclass_fields(sup))
return [f for f in dataclasses.fields(cls) if f not in super_fields]
diff --git a/lib/sqlalchemy/util/langhelpers.py b/lib/sqlalchemy/util/langhelpers.py
index 24c66bfa4..e54f33475 100644
--- a/lib/sqlalchemy/util/langhelpers.py
+++ b/lib/sqlalchemy/util/langhelpers.py
@@ -266,13 +266,31 @@ def decorator(target: Callable[..., Any]) -> Callable[[_Fn], _Fn]:
metadata: Dict[str, Optional[str]] = dict(target=targ_name, fn=fn_name)
metadata.update(format_argspec_plus(spec, grouped=False))
metadata["name"] = fn.__name__
- code = (
- """\
+
+ # look for __ positional arguments. This is a convention in
+ # SQLAlchemy that arguments should be passed positionally
+ # rather than as keyword
+ # arguments. note that apply_pos doesn't currently work in all cases
+ # such as when a kw-only indicator "*" is present, which is why
+ # we limit the use of this to just that case we can detect. As we add
+ # more kinds of methods that use @decorator, things may have to
+ # be further improved in this area
+ if "__" in repr(spec[0]):
+ code = (
+ """\
+def %(name)s%(grouped_args)s:
+ return %(target)s(%(fn)s, %(apply_pos)s)
+"""
+ % metadata
+ )
+ else:
+ code = (
+ """\
def %(name)s%(grouped_args)s:
return %(target)s(%(fn)s, %(apply_kw)s)
"""
- % metadata
- )
+ % metadata
+ )
env.update({targ_name: target, fn_name: fn, "__name__": fn.__module__})
decorated = cast(
@@ -1235,10 +1253,10 @@ class HasMemoized:
return result
@classmethod
- def memoized_instancemethod(cls, fn: Any) -> Any:
+ def memoized_instancemethod(cls, fn: _F) -> _F:
"""Decorate a method memoize its return value."""
- def oneshot(self, *args, **kw):
+ def oneshot(self: Any, *args: Any, **kw: Any) -> Any:
result = fn(self, *args, **kw)
def memo(*a, **kw):
@@ -1250,7 +1268,7 @@ class HasMemoized:
self._memoized_keys |= {fn.__name__}
return result
- return update_wrapper(oneshot, fn)
+ return update_wrapper(oneshot, fn) # type: ignore
if TYPE_CHECKING:
diff --git a/lib/sqlalchemy/util/preloaded.py b/lib/sqlalchemy/util/preloaded.py
index fce3cd3b0..67394c9a3 100644
--- a/lib/sqlalchemy/util/preloaded.py
+++ b/lib/sqlalchemy/util/preloaded.py
@@ -25,8 +25,12 @@ _FN = TypeVar("_FN", bound=Callable[..., Any])
if TYPE_CHECKING:
from sqlalchemy.engine import default as engine_default # noqa
+ from sqlalchemy.orm import clsregistry as orm_clsregistry # noqa
+ from sqlalchemy.orm import decl_api as orm_decl_api # noqa
+ from sqlalchemy.orm import properties as orm_properties # noqa
from sqlalchemy.orm import relationships as orm_relationships # noqa
from sqlalchemy.orm import session as orm_session # noqa
+ from sqlalchemy.orm import state as orm_state # noqa
from sqlalchemy.orm import util as orm_util # noqa
from sqlalchemy.sql import dml as sql_dml # noqa
from sqlalchemy.sql import functions as sql_functions # noqa
diff --git a/lib/sqlalchemy/util/topological.py b/lib/sqlalchemy/util/topological.py
index 37297103e..24e478b57 100644
--- a/lib/sqlalchemy/util/topological.py
+++ b/lib/sqlalchemy/util/topological.py
@@ -4,21 +4,33 @@
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: allow-untyped-defs, allow-untyped-calls
"""Topological sorting algorithms."""
from __future__ import annotations
+from typing import Any
+from typing import DefaultDict
+from typing import Iterable
+from typing import Iterator
+from typing import Sequence
+from typing import Set
+from typing import Tuple
+from typing import TypeVar
+
from .. import util
from ..exc import CircularDependencyError
+_T = TypeVar("_T", bound=Any)
+
__all__ = ["sort", "sort_as_subsets", "find_cycles"]
-def sort_as_subsets(tuples, allitems):
+def sort_as_subsets(
+ tuples: Iterable[Tuple[_T, _T]], allitems: Iterable[_T]
+) -> Iterator[Sequence[_T]]:
- edges = util.defaultdict(set)
+ edges: DefaultDict[_T, Set[_T]] = util.defaultdict(set)
for parent, child in tuples:
edges[child].add(parent)
@@ -43,7 +55,11 @@ def sort_as_subsets(tuples, allitems):
yield output
-def sort(tuples, allitems, deterministic_order=True):
+def sort(
+ tuples: Iterable[Tuple[_T, _T]],
+ allitems: Iterable[_T],
+ deterministic_order: bool = True,
+) -> Iterator[_T]:
"""sort the given list of items by dependency.
'tuples' is a list of tuples representing a partial ordering.
@@ -59,11 +75,14 @@ def sort(tuples, allitems, deterministic_order=True):
yield s
-def find_cycles(tuples, allitems):
+def find_cycles(
+ tuples: Iterable[Tuple[_T, _T]],
+ allitems: Iterable[_T],
+) -> Set[_T]:
# adapted from:
# https://neopythonic.blogspot.com/2009/01/detecting-cycles-in-directed-graph.html
- edges = util.defaultdict(set)
+ edges: DefaultDict[_T, Set[_T]] = util.defaultdict(set)
for parent, child in tuples:
edges[parent].add(child)
nodes_to_test = set(edges)
@@ -99,5 +118,5 @@ def find_cycles(tuples, allitems):
return output
-def _gen_edges(edges):
- return set([(right, left) for left in edges for right in edges[left]])
+def _gen_edges(edges: DefaultDict[_T, Set[_T]]) -> Set[Tuple[_T, _T]]:
+ return {(right, left) for left in edges for right in edges[left]}
diff --git a/lib/sqlalchemy/util/typing.py b/lib/sqlalchemy/util/typing.py
index ebcae28a7..44e26f609 100644
--- a/lib/sqlalchemy/util/typing.py
+++ b/lib/sqlalchemy/util/typing.py
@@ -11,7 +11,9 @@ from typing import Dict
from typing import ForwardRef
from typing import Generic
from typing import Iterable
+from typing import NoReturn
from typing import Optional
+from typing import overload
from typing import Tuple
from typing import Type
from typing import TypeVar
@@ -33,7 +35,7 @@ Self = TypeVar("Self", bound=Any)
if compat.py310:
# why they took until py310 to put this in stdlib is beyond me,
# I've been wanting it since py27
- from types import NoneType
+ from types import NoneType as NoneType
else:
NoneType = type(None) # type: ignore
@@ -68,6 +70,8 @@ else:
# copied from TypeShed, required in order to implement
# MutableMapping.update()
+_AnnotationScanType = Union[Type[Any], str]
+
class SupportsKeysAndGetItem(Protocol[_KT, _VT_co]):
def keys(self) -> Iterable[_KT]:
@@ -90,9 +94,9 @@ else:
def de_stringify_annotation(
cls: Type[Any],
- annotation: Union[str, Type[Any]],
+ annotation: _AnnotationScanType,
str_cleanup_fn: Optional[Callable[[str], str]] = None,
-) -> Union[str, Type[Any]]:
+) -> Type[Any]:
"""Resolve annotations that may be string based into real objects.
This is particularly important if a module defines "from __future__ import
@@ -125,20 +129,32 @@ def de_stringify_annotation(
annotation = eval(annotation, base_globals, None)
except NameError:
pass
- return annotation
+ return annotation # type: ignore
-def is_fwd_ref(type_):
+def is_fwd_ref(type_: _AnnotationScanType) -> bool:
return isinstance(type_, ForwardRef)
-def de_optionalize_union_types(type_):
+@overload
+def de_optionalize_union_types(type_: str) -> str:
+ ...
+
+
+@overload
+def de_optionalize_union_types(type_: Type[Any]) -> Type[Any]:
+ ...
+
+
+def de_optionalize_union_types(
+ type_: _AnnotationScanType,
+) -> _AnnotationScanType:
"""Given a type, filter out ``Union`` types that include ``NoneType``
to not include the ``NoneType``.
"""
if is_optional(type_):
- typ = set(type_.__args__)
+ typ = set(type_.__args__) # type: ignore
typ.discard(NoneType)
@@ -148,14 +164,14 @@ def de_optionalize_union_types(type_):
return type_
-def make_union_type(*types):
+def make_union_type(*types: _AnnotationScanType) -> Type[Any]:
"""Make a Union type.
This is needed by :func:`.de_optionalize_union_types` which removes
``NoneType`` from a ``Union``.
"""
- return cast(Any, Union).__getitem__(types)
+ return cast(Any, Union).__getitem__(types) # type: ignore
def expand_unions(
@@ -251,4 +267,47 @@ class DescriptorReference(Generic[_DESC]):
...
+_DESC_co = TypeVar("_DESC_co", bound=DescriptorProto, covariant=True)
+
+
+class RODescriptorReference(Generic[_DESC_co]):
+ """a descriptor that refers to a descriptor.
+
+ same as :class:`.DescriptorReference` but is read-only, so that subclasses
+ can define a subtype as the generically contained element
+
+ """
+
+ def __get__(self, instance: object, owner: Any) -> _DESC_co:
+ ...
+
+ def __set__(self, instance: Any, value: Any) -> NoReturn:
+ ...
+
+ def __delete__(self, instance: Any) -> NoReturn:
+ ...
+
+
+_FN = TypeVar("_FN", bound=Optional[Callable[..., Any]])
+
+
+class CallableReference(Generic[_FN]):
+ """a descriptor that refers to a callable.
+
+ works around mypy's limitation of not allowing callables assigned
+ as instance variables
+
+
+ """
+
+ def __get__(self, instance: object, owner: Any) -> _FN:
+ ...
+
+ def __set__(self, instance: Any, value: _FN) -> None:
+ ...
+
+ def __delete__(self, instance: Any) -> None:
+ ...
+
+
# $def ro_descriptor_reference(fn: Callable[])