diff options
Diffstat (limited to 'lib/sqlalchemy/orm/descriptor_props.py')
-rw-r--r-- | lib/sqlalchemy/orm/descriptor_props.py | 335 |
1 files changed, 231 insertions, 104 deletions
diff --git a/lib/sqlalchemy/orm/descriptor_props.py b/lib/sqlalchemy/orm/descriptor_props.py index 5975c30db..8c89f96aa 100644 --- a/lib/sqlalchemy/orm/descriptor_props.py +++ b/lib/sqlalchemy/orm/descriptor_props.py @@ -20,15 +20,21 @@ import typing from typing import Any from typing import Callable from typing import List +from typing import NoReturn from typing import Optional +from typing import Sequence from typing import Tuple from typing import Type +from typing import TYPE_CHECKING from typing import TypeVar from typing import Union from . import attributes from . import util as orm_util +from .base import LoaderCallableStatus from .base import Mapped +from .base import PassiveFlag +from .base import SQLORMOperations from .interfaces import _IntrospectsAnnotations from .interfaces import _MapsColumns from .interfaces import MapperProperty @@ -41,20 +47,41 @@ from .. import schema from .. import sql from .. import util from ..sql import expression -from ..sql import operators +from ..sql.elements import BindParameter from ..util.typing import Protocol if typing.TYPE_CHECKING: + from ._typing import _InstanceDict + from ._typing import _RegistryType + from .attributes import History from .attributes import InstrumentedAttribute + from .attributes import QueryableAttribute + from .context import ORMCompileState + from .mapper import Mapper + from .properties import ColumnProperty from .properties import MappedColumn + from .state import InstanceState + from ..engine.base import Connection + from ..engine.row import Row + from ..sql._typing import _DMLColumnArgument from ..sql._typing import _InfoType + from ..sql.elements import ClauseList + from ..sql.elements import ColumnElement from ..sql.schema import Column + from ..sql.selectable import Select + from ..util.typing import _AnnotationScanType + from ..util.typing import CallableReference + from ..util.typing import DescriptorReference + from ..util.typing import RODescriptorReference _T = TypeVar("_T", bound=Any) _PT = TypeVar("_PT", bound=Any) class _CompositeClassProto(Protocol): + def __init__(self, *args: Any): + ... + def __composite_values__(self) -> Tuple[Any, ...]: ... @@ -63,32 +90,43 @@ class DescriptorProperty(MapperProperty[_T]): """:class:`.MapperProperty` which proxies access to a user-defined descriptor.""" - doc = None + doc: Optional[str] = None uses_objects = False _links_to_entity = False - def instrument_class(self, mapper): + descriptor: DescriptorReference[Any] + + def get_history( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + passive: PassiveFlag = PassiveFlag.PASSIVE_OFF, + ) -> History: + raise NotImplementedError() + + def instrument_class(self, mapper: Mapper[Any]) -> None: prop = self - class _ProxyImpl: + class _ProxyImpl(attributes.AttributeImpl): accepts_scalar_loader = False load_on_unexpire = True collection = False @property - def uses_objects(self): + def uses_objects(self) -> bool: # type: ignore return prop.uses_objects - def __init__(self, key): + def __init__(self, key: str): self.key = key - if hasattr(prop, "get_history"): - - def get_history( - self, state, dict_, passive=attributes.PASSIVE_OFF - ): - return prop.get_history(state, dict_, passive) + def get_history( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + passive: PassiveFlag = PassiveFlag.PASSIVE_OFF, + ) -> History: + return prop.get_history(state, dict_, passive) if self.descriptor is None: desc = getattr(mapper.class_, self.key, None) @@ -97,13 +135,13 @@ class DescriptorProperty(MapperProperty[_T]): if self.descriptor is None: - def fset(obj, value): + def fset(obj: Any, value: Any) -> None: setattr(obj, self.name, value) - def fdel(obj): + def fdel(obj: Any) -> None: delattr(obj, self.name) - def fget(obj): + def fget(obj: Any) -> Any: return getattr(obj, self.name) self.descriptor = property(fget=fget, fset=fset, fdel=fdel) @@ -129,8 +167,11 @@ _CompositeAttrType = Union[ ] +_CC = TypeVar("_CC", bound=_CompositeClassProto) + + class Composite( - _MapsColumns[_T], _IntrospectsAnnotations, DescriptorProperty[_T] + _MapsColumns[_CC], _IntrospectsAnnotations, DescriptorProperty[_CC] ): """Defines a "composite" mapped attribute, representing a collection of columns as one attribute. @@ -148,19 +189,25 @@ class Composite( """ - composite_class: Union[ - Type[_CompositeClassProto], Callable[..., Type[_CompositeClassProto]] + composite_class: Union[Type[_CC], Callable[..., _CC]] + attrs: Tuple[_CompositeAttrType[Any], ...] + + _generated_composite_accessor: CallableReference[ + Optional[Callable[[_CC], Tuple[Any, ...]]] ] - attrs: Tuple[_CompositeAttrType, ...] + + comparator_factory: Type[Comparator[_CC]] def __init__( self, - class_: Union[None, _CompositeClassProto, _CompositeAttrType] = None, - *attrs: _CompositeAttrType, + class_: Union[ + None, Type[_CC], Callable[..., _CC], _CompositeAttrType[Any] + ] = None, + *attrs: _CompositeAttrType[Any], active_history: bool = False, deferred: bool = False, group: Optional[str] = None, - comparator_factory: Optional[Type[Comparator]] = None, + comparator_factory: Optional[Type[Comparator[_CC]]] = None, info: Optional[_InfoType] = None, ): super().__init__() @@ -170,7 +217,7 @@ class Composite( # will initialize within declarative_scan self.composite_class = None # type: ignore else: - self.composite_class = class_ + self.composite_class = class_ # type: ignore self.attrs = attrs self.active_history = active_history @@ -183,18 +230,16 @@ class Composite( ) self._generated_composite_accessor = None if info is not None: - self.info = info + self.info.update(info) util.set_creation_order(self) self._create_descriptor() - def instrument_class(self, mapper): + def instrument_class(self, mapper: Mapper[Any]) -> None: super().instrument_class(mapper) self._setup_event_handlers() - def _composite_values_from_instance( - self, value: _CompositeClassProto - ) -> Tuple[Any, ...]: + def _composite_values_from_instance(self, value: _CC) -> Tuple[Any, ...]: if self._generated_composite_accessor: return self._generated_composite_accessor(value) else: @@ -209,7 +254,7 @@ class Composite( else: return accessor() - def do_init(self): + def do_init(self) -> None: """Initialization which occurs after the :class:`.Composite` has been associated with its parent mapper. @@ -218,13 +263,13 @@ class Composite( _COMPOSITE_FGET = object() - def _create_descriptor(self): + def _create_descriptor(self) -> None: """Create the Python descriptor that will serve as the access point on instances of the mapped class. """ - def fget(instance): + def fget(instance: Any) -> Any: dict_ = attributes.instance_dict(instance) state = attributes.instance_state(instance) @@ -251,11 +296,11 @@ class Composite( return dict_.get(self.key, None) - def fset(instance, value): + def fset(instance: Any, value: Any) -> None: dict_ = attributes.instance_dict(instance) state = attributes.instance_state(instance) attr = state.manager[self.key] - previous = dict_.get(self.key, attributes.NO_VALUE) + previous = dict_.get(self.key, LoaderCallableStatus.NO_VALUE) for fn in attr.dispatch.set: value = fn(state, value, previous, attr.impl) dict_[self.key] = value @@ -269,10 +314,10 @@ class Composite( ): setattr(instance, key, value) - def fdel(instance): + def fdel(instance: Any) -> None: state = attributes.instance_state(instance) dict_ = attributes.instance_dict(instance) - previous = dict_.pop(self.key, attributes.NO_VALUE) + previous = dict_.pop(self.key, LoaderCallableStatus.NO_VALUE) attr = state.manager[self.key] attr.dispatch.remove(state, previous, attr.impl) for key in self._attribute_keys: @@ -282,8 +327,13 @@ class Composite( @util.preload_module("sqlalchemy.orm.properties") def declarative_scan( - self, registry, cls, key, annotation, is_dataclass_field - ): + self, + registry: _RegistryType, + cls: Type[Any], + key: str, + annotation: Optional[_AnnotationScanType], + is_dataclass_field: bool, + ) -> None: MappedColumn = util.preloaded.orm_properties.MappedColumn argument = _extract_mapped_subtype( @@ -310,7 +360,9 @@ class Composite( @util.preload_module("sqlalchemy.orm.properties") @util.preload_module("sqlalchemy.orm.decl_base") - def _setup_for_dataclass(self, registry, cls, key): + def _setup_for_dataclass( + self, registry: _RegistryType, cls: Type[Any], key: str + ) -> None: MappedColumn = util.preloaded.orm_properties.MappedColumn decl_base = util.preloaded.orm_decl_base @@ -341,12 +393,12 @@ class Composite( self._generated_composite_accessor = getter @util.memoized_property - def _comparable_elements(self): + def _comparable_elements(self) -> Sequence[QueryableAttribute[Any]]: return [getattr(self.parent.class_, prop.key) for prop in self.props] @util.memoized_property @util.preload_module("orm.properties") - def props(self): + def props(self) -> Sequence[MapperProperty[Any]]: props = [] MappedColumn = util.preloaded.orm_properties.MappedColumn @@ -360,17 +412,20 @@ class Composite( elif isinstance(attr, attributes.InstrumentedAttribute): prop = attr.property else: + prop = None + + if not isinstance(prop, MapperProperty): raise sa_exc.ArgumentError( "Composite expects Column objects or mapped " - "attributes/attribute names as arguments, got: %r" - % (attr,) + f"attributes/attribute names as arguments, got: {attr!r}" ) + props.append(prop) return props - @property + @util.non_memoized_property @util.preload_module("orm.properties") - def columns(self): + def columns(self) -> Sequence[Column[Any]]: MappedColumn = util.preloaded.orm_properties.MappedColumn return [ a.column if isinstance(a, MappedColumn) else a @@ -379,32 +434,46 @@ class Composite( ] @property - def mapper_property_to_assign(self) -> Optional["MapperProperty[_T]"]: + def mapper_property_to_assign(self) -> Optional[MapperProperty[_CC]]: return self @property - def columns_to_assign(self) -> List[schema.Column]: + def columns_to_assign(self) -> List[schema.Column[Any]]: return [c for c in self.columns if c.table is None] - def _setup_arguments_on_columns(self): + @util.preload_module("orm.properties") + def _setup_arguments_on_columns(self) -> None: """Propagate configuration arguments made on this composite to the target columns, for those that apply. """ + ColumnProperty = util.preloaded.orm_properties.ColumnProperty + for prop in self.props: - prop.active_history = self.active_history + if not isinstance(prop, ColumnProperty): + continue + else: + cprop = prop + + cprop.active_history = self.active_history if self.deferred: - prop.deferred = self.deferred - prop.strategy_key = (("deferred", True), ("instrument", True)) - prop.group = self.group + cprop.deferred = self.deferred + cprop.strategy_key = (("deferred", True), ("instrument", True)) + cprop.group = self.group - def _setup_event_handlers(self): + def _setup_event_handlers(self) -> None: """Establish events that populate/expire the composite attribute.""" - def load_handler(state, context): + def load_handler( + state: InstanceState[Any], context: ORMCompileState + ) -> None: _load_refresh_handler(state, context, None, is_refresh=False) - def refresh_handler(state, context, to_load): + def refresh_handler( + state: InstanceState[Any], + context: ORMCompileState, + to_load: Optional[Sequence[str]], + ) -> None: # note this corresponds to sqlalchemy.ext.mutable load_attrs() if not to_load or ( @@ -412,7 +481,12 @@ class Composite( ).intersection(to_load): _load_refresh_handler(state, context, to_load, is_refresh=True) - def _load_refresh_handler(state, context, to_load, is_refresh): + def _load_refresh_handler( + state: InstanceState[Any], + context: ORMCompileState, + to_load: Optional[Sequence[str]], + is_refresh: bool, + ) -> None: dict_ = state.dict # if context indicates we are coming from the @@ -440,11 +514,17 @@ class Composite( *[state.dict[key] for key in self._attribute_keys] ) - def expire_handler(state, keys): + def expire_handler( + state: InstanceState[Any], keys: Optional[Sequence[str]] + ) -> None: if keys is None or set(self._attribute_keys).intersection(keys): state.dict.pop(self.key, None) - def insert_update_handler(mapper, connection, state): + def insert_update_handler( + mapper: Mapper[Any], + connection: Connection, + state: InstanceState[Any], + ) -> None: """After an insert or update, some columns may be expired due to server side defaults, or re-populated due to client side defaults. Pop out the composite value here so that it @@ -473,14 +553,19 @@ class Composite( # TODO: need a deserialize hook here @util.memoized_property - def _attribute_keys(self): + def _attribute_keys(self) -> Sequence[str]: return [prop.key for prop in self.props] - def get_history(self, state, dict_, passive=attributes.PASSIVE_OFF): + def get_history( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + passive: PassiveFlag = PassiveFlag.PASSIVE_OFF, + ) -> History: """Provided for userland code that uses attributes.get_history().""" - added = [] - deleted = [] + added: List[Any] = [] + deleted: List[Any] = [] has_history = False for prop in self.props: @@ -508,16 +593,27 @@ class Composite( else: return attributes.History((), [self.composite_class(*added)], ()) - def _comparator_factory(self, mapper): + def _comparator_factory( + self, mapper: Mapper[Any] + ) -> Composite.Comparator[_CC]: return self.comparator_factory(self, mapper) - class CompositeBundle(orm_util.Bundle): - def __init__(self, property_, expr): + class CompositeBundle(orm_util.Bundle[_T]): + def __init__( + self, + property_: Composite[_T], + expr: ClauseList, + ): self.property = property_ super().__init__(property_.key, *expr) - def create_row_processor(self, query, procs, labels): - def proc(row): + def create_row_processor( + self, + query: Select[Any], + procs: Sequence[Callable[[Row[Any]], Any]], + labels: Sequence[str], + ) -> Callable[[Row[Any]], Any]: + def proc(row: Row[Any]) -> Any: return self.property.composite_class( *[proc(row) for proc in procs] ) @@ -546,17 +642,19 @@ class Composite( # https://github.com/python/mypy/issues/4266 __hash__ = None # type: ignore + prop: RODescriptorReference[Composite[_PT]] + @util.memoized_property - def clauses(self): + def clauses(self) -> ClauseList: return expression.ClauseList( group=False, *self._comparable_elements ) - def __clause_element__(self): + def __clause_element__(self) -> Composite.CompositeBundle[_PT]: return self.expression @util.memoized_property - def expression(self): + def expression(self) -> Composite.CompositeBundle[_PT]: clauses = self.clauses._annotate( { "parententity": self._parententity, @@ -566,13 +664,19 @@ class Composite( ) return Composite.CompositeBundle(self.prop, clauses) - def _bulk_update_tuples(self, value): - if isinstance(value, sql.elements.BindParameter): + def _bulk_update_tuples( + self, value: Any + ) -> Sequence[Tuple[_DMLColumnArgument, Any]]: + if isinstance(value, BindParameter): value = value.value + values: Sequence[Any] + if value is None: values = [None for key in self.prop._attribute_keys] - elif isinstance(value, self.prop.composite_class): + elif isinstance(self.prop.composite_class, type) and isinstance( + value, self.prop.composite_class + ): values = self.prop._composite_values_from_instance(value) else: raise sa_exc.ArgumentError( @@ -580,10 +684,10 @@ class Composite( % (self.prop, value) ) - return zip(self._comparable_elements, values) + return list(zip(self._comparable_elements, values)) @util.memoized_property - def _comparable_elements(self): + def _comparable_elements(self) -> Sequence[QueryableAttribute[Any]]: if self._adapt_to_entity: return [ getattr(self._adapt_to_entity.entity, prop.key) @@ -592,7 +696,8 @@ class Composite( else: return self.prop._comparable_elements - def __eq__(self, other): + def __eq__(self, other: Any) -> ColumnElement[bool]: # type: ignore[override] # noqa: E501 + values: Sequence[Any] if other is None: values = [None] * len(self.prop._comparable_elements) else: @@ -601,13 +706,14 @@ class Composite( a == b for a, b in zip(self.prop._comparable_elements, values) ] if self._adapt_to_entity: + assert self.adapter is not None comparisons = [self.adapter(x) for x in comparisons] return sql.and_(*comparisons) - def __ne__(self, other): + def __ne__(self, other: Any) -> ColumnElement[bool]: # type: ignore[override] # noqa: E501 return sql.not_(self.__eq__(other)) - def __str__(self): + def __str__(self) -> str: return str(self.parent.class_.__name__) + "." + self.key @@ -628,20 +734,24 @@ class ConcreteInheritedProperty(DescriptorProperty[_T]): """ - def _comparator_factory(self, mapper): + def _comparator_factory( + self, mapper: Mapper[Any] + ) -> Type[PropComparator[_T]]: + comparator_callable = None for m in self.parent.iterate_to_root(): p = m._props[self.key] - if not isinstance(p, ConcreteInheritedProperty): + if getattr(p, "comparator_factory", None) is not None: comparator_callable = p.comparator_factory break - return comparator_callable + assert comparator_callable is not None + return comparator_callable(p, mapper) # type: ignore - def __init__(self): + def __init__(self) -> None: super().__init__() - def warn(): + def warn() -> NoReturn: raise AttributeError( "Concrete %s does not implement " "attribute %r at the instance level. Add " @@ -650,13 +760,13 @@ class ConcreteInheritedProperty(DescriptorProperty[_T]): ) class NoninheritedConcreteProp: - def __set__(s, obj, value): + def __set__(s: Any, obj: Any, value: Any) -> NoReturn: warn() - def __delete__(s, obj): + def __delete__(s: Any, obj: Any) -> NoReturn: warn() - def __get__(s, obj, owner): + def __get__(s: Any, obj: Any, owner: Any) -> Any: if obj is None: return self.descriptor warn() @@ -682,14 +792,16 @@ class Synonym(DescriptorProperty[_T]): """ + comparator_factory: Optional[Type[PropComparator[_T]]] + def __init__( self, - name, - map_column=None, - descriptor=None, - comparator_factory=None, - doc=None, - info=None, + name: str, + map_column: Optional[bool] = None, + descriptor: Optional[Any] = None, + comparator_factory: Optional[Type[PropComparator[_T]]] = None, + info: Optional[_InfoType] = None, + doc: Optional[str] = None, ): super().__init__() @@ -697,21 +809,30 @@ class Synonym(DescriptorProperty[_T]): self.map_column = map_column self.descriptor = descriptor self.comparator_factory = comparator_factory - self.doc = doc or (descriptor and descriptor.__doc__) or None + if doc: + self.doc = doc + elif descriptor and descriptor.__doc__: + self.doc = descriptor.__doc__ + else: + self.doc = None if info: - self.info = info + self.info.update(info) util.set_creation_order(self) - @property - def uses_objects(self): - return getattr(self.parent.class_, self.name).impl.uses_objects + if not TYPE_CHECKING: + + @property + def uses_objects(self) -> bool: + return getattr(self.parent.class_, self.name).impl.uses_objects # TODO: when initialized, check _proxied_object, # emit a warning if its not a column-based property @util.memoized_property - def _proxied_object(self): + def _proxied_object( + self, + ) -> Union[MapperProperty[_T], SQLORMOperations[_T]]: attr = getattr(self.parent.class_, self.name) if not hasattr(attr, "property") or not isinstance( attr.property, MapperProperty @@ -720,7 +841,8 @@ class Synonym(DescriptorProperty[_T]): # hybrid or association proxy if isinstance(attr, attributes.QueryableAttribute): return attr.comparator - elif isinstance(attr, operators.ColumnOperators): + elif isinstance(attr, SQLORMOperations): + # assocaition proxy comes here return attr raise sa_exc.InvalidRequestError( @@ -730,7 +852,7 @@ class Synonym(DescriptorProperty[_T]): ) return attr.property - def _comparator_factory(self, mapper): + def _comparator_factory(self, mapper: Mapper[Any]) -> SQLORMOperations[_T]: prop = self._proxied_object if isinstance(prop, MapperProperty): @@ -742,12 +864,17 @@ class Synonym(DescriptorProperty[_T]): else: return prop - def get_history(self, *arg, **kw): - attr = getattr(self.parent.class_, self.name) - return attr.impl.get_history(*arg, **kw) + def get_history( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + passive: PassiveFlag = PassiveFlag.PASSIVE_OFF, + ) -> History: + attr: QueryableAttribute[Any] = getattr(self.parent.class_, self.name) + return attr.impl.get_history(state, dict_, passive=passive) @util.preload_module("sqlalchemy.orm.properties") - def set_parent(self, parent, init): + def set_parent(self, parent: Mapper[Any], init: bool) -> None: properties = util.preloaded.orm_properties if self.map_column: @@ -776,7 +903,7 @@ class Synonym(DescriptorProperty[_T]): "%r for column %r" % (self.key, self.name, self.name, self.key) ) - p = properties.ColumnProperty( + p: ColumnProperty[Any] = properties.ColumnProperty( parent.persist_selectable.c[self.key] ) parent._configure_property(self.name, p, init=init, setparent=True) |