diff options
Diffstat (limited to 'lib/sqlalchemy/orm/properties.py')
-rw-r--r-- | lib/sqlalchemy/orm/properties.py | 197 |
1 files changed, 126 insertions, 71 deletions
diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index 0ca0559b4..911617d6d 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -16,8 +16,10 @@ from __future__ import annotations from typing import Any from typing import cast +from typing import Dict from typing import List from typing import Optional +from typing import Sequence from typing import Set from typing import Type from typing import TYPE_CHECKING @@ -25,7 +27,6 @@ from typing import TypeVar from . import attributes from . import strategy_options -from .base import SQLCoreOperations from .descriptor_props import Composite from .descriptor_props import ConcreteInheritedProperty from .descriptor_props import Synonym @@ -44,20 +45,34 @@ from .. import util from ..sql import coercions from ..sql import roles from ..sql import sqltypes +from ..sql.elements import SQLCoreOperations from ..sql.schema import Column from ..sql.schema import SchemaConst from ..util.typing import de_optionalize_union_types from ..util.typing import de_stringify_annotation from ..util.typing import is_fwd_ref from ..util.typing import NoneType +from ..util.typing import Self if TYPE_CHECKING: + from ._typing import _IdentityKeyType + from ._typing import _InstanceDict from ._typing import _ORMColumnExprArgument + from ._typing import _RegistryType + from .mapper import Mapper + from .session import Session + from .state import _InstallLoaderCallableProto + from .state import InstanceState from ..sql._typing import _InfoType - from ..sql.elements import KeyedColumnElement + from ..sql.elements import ColumnElement + from ..sql.elements import NamedColumn + from ..sql.operators import OperatorType + from ..util.typing import _AnnotationScanType + from ..util.typing import RODescriptorReference _T = TypeVar("_T", bound=Any) _PT = TypeVar("_PT", bound=Any) +_NC = TypeVar("_NC", bound="NamedColumn[Any]") __all__ = [ "ColumnProperty", @@ -85,11 +100,15 @@ class ColumnProperty( inherit_cache = True _links_to_entity = False - columns: List[KeyedColumnElement[Any]] - _orig_columns: List[KeyedColumnElement[Any]] + columns: List[NamedColumn[Any]] + _orig_columns: List[NamedColumn[Any]] _is_polymorphic_discriminator: bool + _mapped_by_synonym: Optional[str] + + comparator_factory: Type[PropComparator[_T]] + __slots__ = ( "_orig_columns", "columns", @@ -100,7 +119,6 @@ class ColumnProperty( "descriptor", "active_history", "expire_on_flush", - "doc", "_creation_order", "_is_polymorphic_discriminator", "_mapped_by_synonym", @@ -117,7 +135,7 @@ class ColumnProperty( group: Optional[str] = None, deferred: bool = False, raiseload: bool = False, - comparator_factory: Optional[Type[PropComparator]] = None, + comparator_factory: Optional[Type[PropComparator[_T]]] = None, descriptor: Optional[Any] = None, active_history: bool = False, expire_on_flush: bool = True, @@ -150,7 +168,7 @@ class ColumnProperty( self.expire_on_flush = expire_on_flush if info is not None: - self.info = info + self.info.update(info) if doc is not None: self.doc = doc @@ -173,8 +191,13 @@ class ColumnProperty( self.strategy_key += (("raiseload", True),) def declarative_scan( - self, registry, cls, key, annotation, is_dataclass_field - ): + self, + registry: _RegistryType, + cls: Type[Any], + key: str, + annotation: Optional[_AnnotationScanType], + is_dataclass_field: bool, + ) -> None: column = self.columns[0] if column.key is None: column.key = key @@ -186,20 +209,23 @@ class ColumnProperty( return self @property - def columns_to_assign(self) -> List[Column]: + def columns_to_assign(self) -> List[Column[Any]]: + # mypy doesn't care about the isinstance here return [ - c + c # type: ignore for c in self.columns if isinstance(c, Column) and c.table is None ] - def _memoized_attr__renders_in_subqueries(self): + def _memoized_attr__renders_in_subqueries(self) -> bool: return ("deferred", True) not in self.strategy_key or ( - self not in self.parent._readonly_props + self not in self.parent._readonly_props # type: ignore ) @util.preload_module("sqlalchemy.orm.state", "sqlalchemy.orm.strategies") - def _memoized_attr__deferred_column_loader(self): + def _memoized_attr__deferred_column_loader( + self, + ) -> _InstallLoaderCallableProto[Any]: state = util.preloaded.orm_state strategies = util.preloaded.orm_strategies return state.InstanceState._instance_level_callable_processor( @@ -209,7 +235,9 @@ class ColumnProperty( ) @util.preload_module("sqlalchemy.orm.state", "sqlalchemy.orm.strategies") - def _memoized_attr__raise_column_loader(self): + def _memoized_attr__raise_column_loader( + self, + ) -> _InstallLoaderCallableProto[Any]: state = util.preloaded.orm_state strategies = util.preloaded.orm_strategies return state.InstanceState._instance_level_callable_processor( @@ -218,7 +246,7 @@ class ColumnProperty( self.key, ) - def __clause_element__(self): + def __clause_element__(self) -> roles.ColumnsClauseRole: """Allow the ColumnProperty to work in expression before it is turned into an instrumented attribute. """ @@ -226,7 +254,7 @@ class ColumnProperty( return self.expression @property - def expression(self): + def expression(self) -> roles.ColumnsClauseRole: """Return the primary column or expression for this ColumnProperty. E.g.:: @@ -247,7 +275,7 @@ class ColumnProperty( """ return self.columns[0] - def instrument_class(self, mapper): + def instrument_class(self, mapper: Mapper[Any]) -> None: if not self.instrument: return @@ -259,7 +287,7 @@ class ColumnProperty( doc=self.doc, ) - def do_init(self): + def do_init(self) -> None: super().do_init() if len(self.columns) > 1 and set(self.parent.primary_key).issuperset( @@ -275,32 +303,25 @@ class ColumnProperty( % (self.parent, self.columns[1], self.columns[0], self.key) ) - def copy(self): + def copy(self) -> ColumnProperty[_T]: return ColumnProperty( + *self.columns, deferred=self.deferred, group=self.group, active_history=self.active_history, - *self.columns, - ) - - def _getcommitted( - self, state, dict_, column, passive=attributes.PASSIVE_OFF - ): - return state.get_impl(self.key).get_committed_value( - state, dict_, passive=passive ) def merge( self, - session, - source_state, - source_dict, - dest_state, - dest_dict, - load, - _recursive, - _resolve_conflict_map, - ): + session: Session, + source_state: InstanceState[Any], + source_dict: _InstanceDict, + dest_state: InstanceState[Any], + dest_dict: _InstanceDict, + load: bool, + _recursive: Dict[Any, object], + _resolve_conflict_map: Dict[_IdentityKeyType[Any], object], + ) -> None: if not self.instrument: return elif self.key in source_dict: @@ -335,9 +356,13 @@ class ColumnProperty( """ - __slots__ = "__clause_element__", "info", "expressions" + if not TYPE_CHECKING: + # prevent pylance from being clever about slots + __slots__ = "__clause_element__", "info", "expressions" + + prop: RODescriptorReference[ColumnProperty[_PT]] - def _orm_annotate_column(self, column): + def _orm_annotate_column(self, column: _NC) -> _NC: """annotate and possibly adapt a column to be returned as the mapped-attribute exposed version of the column. @@ -351,7 +376,7 @@ class ColumnProperty( """ pe = self._parententity - annotations = { + annotations: Dict[str, Any] = { "entity_namespace": pe, "parententity": pe, "parentmapper": pe, @@ -377,22 +402,29 @@ class ColumnProperty( {"compile_state_plugin": "orm", "plugin_subject": pe} ) - def _memoized_method___clause_element__(self): + if TYPE_CHECKING: + + def __clause_element__(self) -> NamedColumn[_PT]: + ... + + def _memoized_method___clause_element__( + self, + ) -> NamedColumn[_PT]: if self.adapter: return self.adapter(self.prop.columns[0], self.prop.key) else: return self._orm_annotate_column(self.prop.columns[0]) - def _memoized_attr_info(self): + def _memoized_attr_info(self) -> _InfoType: """The .info dictionary for this attribute.""" ce = self.__clause_element__() try: - return ce.info + return ce.info # type: ignore except AttributeError: return self.prop.info - def _memoized_attr_expressions(self): + def _memoized_attr_expressions(self) -> Sequence[NamedColumn[Any]]: """The full sequence of columns referenced by this attribute, adjusted for any aliasing in progress. @@ -409,21 +441,25 @@ class ColumnProperty( self._orm_annotate_column(col) for col in self.prop.columns ] - def _fallback_getattr(self, key): + def _fallback_getattr(self, key: str) -> Any: """proxy attribute access down to the mapped column. this allows user-defined comparison methods to be accessed. """ return getattr(self.__clause_element__(), key) - def operate(self, op, *other, **kwargs): - return op(self.__clause_element__(), *other, **kwargs) + def operate( + self, op: OperatorType, *other: Any, **kwargs: Any + ) -> ColumnElement[Any]: + return op(self.__clause_element__(), *other, **kwargs) # type: ignore[return-value] # noqa: E501 - def reverse_operate(self, op, other, **kwargs): + def reverse_operate( + self, op: OperatorType, other: Any, **kwargs: Any + ) -> ColumnElement[Any]: col = self.__clause_element__() - return op(col._bind_param(op, other), col, **kwargs) + return op(col._bind_param(op, other), col, **kwargs) # type: ignore[return-value] # noqa: E501 - def __str__(self): + def __str__(self) -> str: if not self.parent or not self.key: return object.__repr__(self) return str(self.parent.class_.__name__) + "." + self.key @@ -460,7 +496,7 @@ class MappedColumn( column: Column[_T] foreign_keys: Optional[Set[ForeignKey]] - def __init__(self, *arg, **kw): + def __init__(self, *arg: Any, **kw: Any): self.deferred = kw.pop("deferred", False) self.column = cast("Column[_T]", Column(*arg, **kw)) self.foreign_keys = self.column.foreign_keys @@ -470,8 +506,8 @@ class MappedColumn( ) util.set_creation_order(self) - def _copy(self, **kw): - new = self.__class__.__new__(self.__class__) + def _copy(self: Self, **kw: Any) -> Self: + new = cast(Self, self.__class__.__new__(self.__class__)) new.column = self.column._copy(**kw) new.deferred = self.deferred new.foreign_keys = new.column.foreign_keys @@ -487,22 +523,31 @@ class MappedColumn( return None @property - def columns_to_assign(self) -> List[Column]: + def columns_to_assign(self) -> List[Column[Any]]: return [self.column] - def __clause_element__(self): + def __clause_element__(self) -> Column[_T]: return self.column - def operate(self, op, *other, **kwargs): - return op(self.__clause_element__(), *other, **kwargs) + def operate( + self, op: OperatorType, *other: Any, **kwargs: Any + ) -> ColumnElement[Any]: + return op(self.__clause_element__(), *other, **kwargs) # type: ignore[return-value] # noqa: E501 - def reverse_operate(self, op, other, **kwargs): + def reverse_operate( + self, op: OperatorType, other: Any, **kwargs: Any + ) -> ColumnElement[Any]: col = self.__clause_element__() - return op(col._bind_param(op, other), col, **kwargs) + return op(col._bind_param(op, other), col, **kwargs) # type: ignore[return-value] # noqa: E501 def declarative_scan( - self, registry, cls, key, annotation, is_dataclass_field - ): + self, + registry: _RegistryType, + cls: Type[Any], + key: str, + annotation: Optional[_AnnotationScanType], + is_dataclass_field: bool, + ) -> None: column = self.column if column.key is None: column.key = key @@ -526,38 +571,48 @@ class MappedColumn( @util.preload_module("sqlalchemy.orm.decl_base") def declarative_scan_for_composite( - self, registry, cls, key, param_name, param_annotation - ): + self, + registry: _RegistryType, + cls: Type[Any], + key: str, + param_name: str, + param_annotation: _AnnotationScanType, + ) -> None: decl_base = util.preloaded.orm_decl_base decl_base._undefer_column_name(param_name, self.column) self._init_column_for_annotation(cls, registry, param_annotation) - def _init_column_for_annotation(self, cls, registry, argument): + def _init_column_for_annotation( + self, + cls: Type[Any], + registry: _RegistryType, + argument: _AnnotationScanType, + ) -> None: sqltype = self.column.type nullable = False if hasattr(argument, "__origin__"): - nullable = NoneType in argument.__args__ + nullable = NoneType in argument.__args__ # type: ignore if not self._has_nullable: self.column.nullable = nullable if sqltype._isnull and not self.column.foreign_keys: - sqltype = None + new_sqltype = None our_type = de_optionalize_union_types(argument) if is_fwd_ref(our_type): our_type = de_stringify_annotation(cls, our_type) if registry.type_annotation_map: - sqltype = registry.type_annotation_map.get(our_type) - if sqltype is None: - sqltype = sqltypes._type_map_get(our_type) + new_sqltype = registry.type_annotation_map.get(our_type) + if new_sqltype is None: + new_sqltype = sqltypes._type_map_get(our_type) # type: ignore - if sqltype is None: + if new_sqltype is None: raise sa_exc.ArgumentError( f"Could not locate SQLAlchemy Core " f"type for Python type: {our_type}" ) - self.column.type = sqltype + self.column.type = new_sqltype # type: ignore |