summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/orm/descriptor_props.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/orm/descriptor_props.py')
-rw-r--r--lib/sqlalchemy/orm/descriptor_props.py335
1 files changed, 231 insertions, 104 deletions
diff --git a/lib/sqlalchemy/orm/descriptor_props.py b/lib/sqlalchemy/orm/descriptor_props.py
index 5975c30db..8c89f96aa 100644
--- a/lib/sqlalchemy/orm/descriptor_props.py
+++ b/lib/sqlalchemy/orm/descriptor_props.py
@@ -20,15 +20,21 @@ import typing
from typing import Any
from typing import Callable
from typing import List
+from typing import NoReturn
from typing import Optional
+from typing import Sequence
from typing import Tuple
from typing import Type
+from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
from . import attributes
from . import util as orm_util
+from .base import LoaderCallableStatus
from .base import Mapped
+from .base import PassiveFlag
+from .base import SQLORMOperations
from .interfaces import _IntrospectsAnnotations
from .interfaces import _MapsColumns
from .interfaces import MapperProperty
@@ -41,20 +47,41 @@ from .. import schema
from .. import sql
from .. import util
from ..sql import expression
-from ..sql import operators
+from ..sql.elements import BindParameter
from ..util.typing import Protocol
if typing.TYPE_CHECKING:
+ from ._typing import _InstanceDict
+ from ._typing import _RegistryType
+ from .attributes import History
from .attributes import InstrumentedAttribute
+ from .attributes import QueryableAttribute
+ from .context import ORMCompileState
+ from .mapper import Mapper
+ from .properties import ColumnProperty
from .properties import MappedColumn
+ from .state import InstanceState
+ from ..engine.base import Connection
+ from ..engine.row import Row
+ from ..sql._typing import _DMLColumnArgument
from ..sql._typing import _InfoType
+ from ..sql.elements import ClauseList
+ from ..sql.elements import ColumnElement
from ..sql.schema import Column
+ from ..sql.selectable import Select
+ from ..util.typing import _AnnotationScanType
+ from ..util.typing import CallableReference
+ from ..util.typing import DescriptorReference
+ from ..util.typing import RODescriptorReference
_T = TypeVar("_T", bound=Any)
_PT = TypeVar("_PT", bound=Any)
class _CompositeClassProto(Protocol):
+ def __init__(self, *args: Any):
+ ...
+
def __composite_values__(self) -> Tuple[Any, ...]:
...
@@ -63,32 +90,43 @@ class DescriptorProperty(MapperProperty[_T]):
""":class:`.MapperProperty` which proxies access to a
user-defined descriptor."""
- doc = None
+ doc: Optional[str] = None
uses_objects = False
_links_to_entity = False
- def instrument_class(self, mapper):
+ descriptor: DescriptorReference[Any]
+
+ def get_history(
+ self,
+ state: InstanceState[Any],
+ dict_: _InstanceDict,
+ passive: PassiveFlag = PassiveFlag.PASSIVE_OFF,
+ ) -> History:
+ raise NotImplementedError()
+
+ def instrument_class(self, mapper: Mapper[Any]) -> None:
prop = self
- class _ProxyImpl:
+ class _ProxyImpl(attributes.AttributeImpl):
accepts_scalar_loader = False
load_on_unexpire = True
collection = False
@property
- def uses_objects(self):
+ def uses_objects(self) -> bool: # type: ignore
return prop.uses_objects
- def __init__(self, key):
+ def __init__(self, key: str):
self.key = key
- if hasattr(prop, "get_history"):
-
- def get_history(
- self, state, dict_, passive=attributes.PASSIVE_OFF
- ):
- return prop.get_history(state, dict_, passive)
+ def get_history(
+ self,
+ state: InstanceState[Any],
+ dict_: _InstanceDict,
+ passive: PassiveFlag = PassiveFlag.PASSIVE_OFF,
+ ) -> History:
+ return prop.get_history(state, dict_, passive)
if self.descriptor is None:
desc = getattr(mapper.class_, self.key, None)
@@ -97,13 +135,13 @@ class DescriptorProperty(MapperProperty[_T]):
if self.descriptor is None:
- def fset(obj, value):
+ def fset(obj: Any, value: Any) -> None:
setattr(obj, self.name, value)
- def fdel(obj):
+ def fdel(obj: Any) -> None:
delattr(obj, self.name)
- def fget(obj):
+ def fget(obj: Any) -> Any:
return getattr(obj, self.name)
self.descriptor = property(fget=fget, fset=fset, fdel=fdel)
@@ -129,8 +167,11 @@ _CompositeAttrType = Union[
]
+_CC = TypeVar("_CC", bound=_CompositeClassProto)
+
+
class Composite(
- _MapsColumns[_T], _IntrospectsAnnotations, DescriptorProperty[_T]
+ _MapsColumns[_CC], _IntrospectsAnnotations, DescriptorProperty[_CC]
):
"""Defines a "composite" mapped attribute, representing a collection
of columns as one attribute.
@@ -148,19 +189,25 @@ class Composite(
"""
- composite_class: Union[
- Type[_CompositeClassProto], Callable[..., Type[_CompositeClassProto]]
+ composite_class: Union[Type[_CC], Callable[..., _CC]]
+ attrs: Tuple[_CompositeAttrType[Any], ...]
+
+ _generated_composite_accessor: CallableReference[
+ Optional[Callable[[_CC], Tuple[Any, ...]]]
]
- attrs: Tuple[_CompositeAttrType, ...]
+
+ comparator_factory: Type[Comparator[_CC]]
def __init__(
self,
- class_: Union[None, _CompositeClassProto, _CompositeAttrType] = None,
- *attrs: _CompositeAttrType,
+ class_: Union[
+ None, Type[_CC], Callable[..., _CC], _CompositeAttrType[Any]
+ ] = None,
+ *attrs: _CompositeAttrType[Any],
active_history: bool = False,
deferred: bool = False,
group: Optional[str] = None,
- comparator_factory: Optional[Type[Comparator]] = None,
+ comparator_factory: Optional[Type[Comparator[_CC]]] = None,
info: Optional[_InfoType] = None,
):
super().__init__()
@@ -170,7 +217,7 @@ class Composite(
# will initialize within declarative_scan
self.composite_class = None # type: ignore
else:
- self.composite_class = class_
+ self.composite_class = class_ # type: ignore
self.attrs = attrs
self.active_history = active_history
@@ -183,18 +230,16 @@ class Composite(
)
self._generated_composite_accessor = None
if info is not None:
- self.info = info
+ self.info.update(info)
util.set_creation_order(self)
self._create_descriptor()
- def instrument_class(self, mapper):
+ def instrument_class(self, mapper: Mapper[Any]) -> None:
super().instrument_class(mapper)
self._setup_event_handlers()
- def _composite_values_from_instance(
- self, value: _CompositeClassProto
- ) -> Tuple[Any, ...]:
+ def _composite_values_from_instance(self, value: _CC) -> Tuple[Any, ...]:
if self._generated_composite_accessor:
return self._generated_composite_accessor(value)
else:
@@ -209,7 +254,7 @@ class Composite(
else:
return accessor()
- def do_init(self):
+ def do_init(self) -> None:
"""Initialization which occurs after the :class:`.Composite`
has been associated with its parent mapper.
@@ -218,13 +263,13 @@ class Composite(
_COMPOSITE_FGET = object()
- def _create_descriptor(self):
+ def _create_descriptor(self) -> None:
"""Create the Python descriptor that will serve as
the access point on instances of the mapped class.
"""
- def fget(instance):
+ def fget(instance: Any) -> Any:
dict_ = attributes.instance_dict(instance)
state = attributes.instance_state(instance)
@@ -251,11 +296,11 @@ class Composite(
return dict_.get(self.key, None)
- def fset(instance, value):
+ def fset(instance: Any, value: Any) -> None:
dict_ = attributes.instance_dict(instance)
state = attributes.instance_state(instance)
attr = state.manager[self.key]
- previous = dict_.get(self.key, attributes.NO_VALUE)
+ previous = dict_.get(self.key, LoaderCallableStatus.NO_VALUE)
for fn in attr.dispatch.set:
value = fn(state, value, previous, attr.impl)
dict_[self.key] = value
@@ -269,10 +314,10 @@ class Composite(
):
setattr(instance, key, value)
- def fdel(instance):
+ def fdel(instance: Any) -> None:
state = attributes.instance_state(instance)
dict_ = attributes.instance_dict(instance)
- previous = dict_.pop(self.key, attributes.NO_VALUE)
+ previous = dict_.pop(self.key, LoaderCallableStatus.NO_VALUE)
attr = state.manager[self.key]
attr.dispatch.remove(state, previous, attr.impl)
for key in self._attribute_keys:
@@ -282,8 +327,13 @@ class Composite(
@util.preload_module("sqlalchemy.orm.properties")
def declarative_scan(
- self, registry, cls, key, annotation, is_dataclass_field
- ):
+ self,
+ registry: _RegistryType,
+ cls: Type[Any],
+ key: str,
+ annotation: Optional[_AnnotationScanType],
+ is_dataclass_field: bool,
+ ) -> None:
MappedColumn = util.preloaded.orm_properties.MappedColumn
argument = _extract_mapped_subtype(
@@ -310,7 +360,9 @@ class Composite(
@util.preload_module("sqlalchemy.orm.properties")
@util.preload_module("sqlalchemy.orm.decl_base")
- def _setup_for_dataclass(self, registry, cls, key):
+ def _setup_for_dataclass(
+ self, registry: _RegistryType, cls: Type[Any], key: str
+ ) -> None:
MappedColumn = util.preloaded.orm_properties.MappedColumn
decl_base = util.preloaded.orm_decl_base
@@ -341,12 +393,12 @@ class Composite(
self._generated_composite_accessor = getter
@util.memoized_property
- def _comparable_elements(self):
+ def _comparable_elements(self) -> Sequence[QueryableAttribute[Any]]:
return [getattr(self.parent.class_, prop.key) for prop in self.props]
@util.memoized_property
@util.preload_module("orm.properties")
- def props(self):
+ def props(self) -> Sequence[MapperProperty[Any]]:
props = []
MappedColumn = util.preloaded.orm_properties.MappedColumn
@@ -360,17 +412,20 @@ class Composite(
elif isinstance(attr, attributes.InstrumentedAttribute):
prop = attr.property
else:
+ prop = None
+
+ if not isinstance(prop, MapperProperty):
raise sa_exc.ArgumentError(
"Composite expects Column objects or mapped "
- "attributes/attribute names as arguments, got: %r"
- % (attr,)
+ f"attributes/attribute names as arguments, got: {attr!r}"
)
+
props.append(prop)
return props
- @property
+ @util.non_memoized_property
@util.preload_module("orm.properties")
- def columns(self):
+ def columns(self) -> Sequence[Column[Any]]:
MappedColumn = util.preloaded.orm_properties.MappedColumn
return [
a.column if isinstance(a, MappedColumn) else a
@@ -379,32 +434,46 @@ class Composite(
]
@property
- def mapper_property_to_assign(self) -> Optional["MapperProperty[_T]"]:
+ def mapper_property_to_assign(self) -> Optional[MapperProperty[_CC]]:
return self
@property
- def columns_to_assign(self) -> List[schema.Column]:
+ def columns_to_assign(self) -> List[schema.Column[Any]]:
return [c for c in self.columns if c.table is None]
- def _setup_arguments_on_columns(self):
+ @util.preload_module("orm.properties")
+ def _setup_arguments_on_columns(self) -> None:
"""Propagate configuration arguments made on this composite
to the target columns, for those that apply.
"""
+ ColumnProperty = util.preloaded.orm_properties.ColumnProperty
+
for prop in self.props:
- prop.active_history = self.active_history
+ if not isinstance(prop, ColumnProperty):
+ continue
+ else:
+ cprop = prop
+
+ cprop.active_history = self.active_history
if self.deferred:
- prop.deferred = self.deferred
- prop.strategy_key = (("deferred", True), ("instrument", True))
- prop.group = self.group
+ cprop.deferred = self.deferred
+ cprop.strategy_key = (("deferred", True), ("instrument", True))
+ cprop.group = self.group
- def _setup_event_handlers(self):
+ def _setup_event_handlers(self) -> None:
"""Establish events that populate/expire the composite attribute."""
- def load_handler(state, context):
+ def load_handler(
+ state: InstanceState[Any], context: ORMCompileState
+ ) -> None:
_load_refresh_handler(state, context, None, is_refresh=False)
- def refresh_handler(state, context, to_load):
+ def refresh_handler(
+ state: InstanceState[Any],
+ context: ORMCompileState,
+ to_load: Optional[Sequence[str]],
+ ) -> None:
# note this corresponds to sqlalchemy.ext.mutable load_attrs()
if not to_load or (
@@ -412,7 +481,12 @@ class Composite(
).intersection(to_load):
_load_refresh_handler(state, context, to_load, is_refresh=True)
- def _load_refresh_handler(state, context, to_load, is_refresh):
+ def _load_refresh_handler(
+ state: InstanceState[Any],
+ context: ORMCompileState,
+ to_load: Optional[Sequence[str]],
+ is_refresh: bool,
+ ) -> None:
dict_ = state.dict
# if context indicates we are coming from the
@@ -440,11 +514,17 @@ class Composite(
*[state.dict[key] for key in self._attribute_keys]
)
- def expire_handler(state, keys):
+ def expire_handler(
+ state: InstanceState[Any], keys: Optional[Sequence[str]]
+ ) -> None:
if keys is None or set(self._attribute_keys).intersection(keys):
state.dict.pop(self.key, None)
- def insert_update_handler(mapper, connection, state):
+ def insert_update_handler(
+ mapper: Mapper[Any],
+ connection: Connection,
+ state: InstanceState[Any],
+ ) -> None:
"""After an insert or update, some columns may be expired due
to server side defaults, or re-populated due to client side
defaults. Pop out the composite value here so that it
@@ -473,14 +553,19 @@ class Composite(
# TODO: need a deserialize hook here
@util.memoized_property
- def _attribute_keys(self):
+ def _attribute_keys(self) -> Sequence[str]:
return [prop.key for prop in self.props]
- def get_history(self, state, dict_, passive=attributes.PASSIVE_OFF):
+ def get_history(
+ self,
+ state: InstanceState[Any],
+ dict_: _InstanceDict,
+ passive: PassiveFlag = PassiveFlag.PASSIVE_OFF,
+ ) -> History:
"""Provided for userland code that uses attributes.get_history()."""
- added = []
- deleted = []
+ added: List[Any] = []
+ deleted: List[Any] = []
has_history = False
for prop in self.props:
@@ -508,16 +593,27 @@ class Composite(
else:
return attributes.History((), [self.composite_class(*added)], ())
- def _comparator_factory(self, mapper):
+ def _comparator_factory(
+ self, mapper: Mapper[Any]
+ ) -> Composite.Comparator[_CC]:
return self.comparator_factory(self, mapper)
- class CompositeBundle(orm_util.Bundle):
- def __init__(self, property_, expr):
+ class CompositeBundle(orm_util.Bundle[_T]):
+ def __init__(
+ self,
+ property_: Composite[_T],
+ expr: ClauseList,
+ ):
self.property = property_
super().__init__(property_.key, *expr)
- def create_row_processor(self, query, procs, labels):
- def proc(row):
+ def create_row_processor(
+ self,
+ query: Select[Any],
+ procs: Sequence[Callable[[Row[Any]], Any]],
+ labels: Sequence[str],
+ ) -> Callable[[Row[Any]], Any]:
+ def proc(row: Row[Any]) -> Any:
return self.property.composite_class(
*[proc(row) for proc in procs]
)
@@ -546,17 +642,19 @@ class Composite(
# https://github.com/python/mypy/issues/4266
__hash__ = None # type: ignore
+ prop: RODescriptorReference[Composite[_PT]]
+
@util.memoized_property
- def clauses(self):
+ def clauses(self) -> ClauseList:
return expression.ClauseList(
group=False, *self._comparable_elements
)
- def __clause_element__(self):
+ def __clause_element__(self) -> Composite.CompositeBundle[_PT]:
return self.expression
@util.memoized_property
- def expression(self):
+ def expression(self) -> Composite.CompositeBundle[_PT]:
clauses = self.clauses._annotate(
{
"parententity": self._parententity,
@@ -566,13 +664,19 @@ class Composite(
)
return Composite.CompositeBundle(self.prop, clauses)
- def _bulk_update_tuples(self, value):
- if isinstance(value, sql.elements.BindParameter):
+ def _bulk_update_tuples(
+ self, value: Any
+ ) -> Sequence[Tuple[_DMLColumnArgument, Any]]:
+ if isinstance(value, BindParameter):
value = value.value
+ values: Sequence[Any]
+
if value is None:
values = [None for key in self.prop._attribute_keys]
- elif isinstance(value, self.prop.composite_class):
+ elif isinstance(self.prop.composite_class, type) and isinstance(
+ value, self.prop.composite_class
+ ):
values = self.prop._composite_values_from_instance(value)
else:
raise sa_exc.ArgumentError(
@@ -580,10 +684,10 @@ class Composite(
% (self.prop, value)
)
- return zip(self._comparable_elements, values)
+ return list(zip(self._comparable_elements, values))
@util.memoized_property
- def _comparable_elements(self):
+ def _comparable_elements(self) -> Sequence[QueryableAttribute[Any]]:
if self._adapt_to_entity:
return [
getattr(self._adapt_to_entity.entity, prop.key)
@@ -592,7 +696,8 @@ class Composite(
else:
return self.prop._comparable_elements
- def __eq__(self, other):
+ def __eq__(self, other: Any) -> ColumnElement[bool]: # type: ignore[override] # noqa: E501
+ values: Sequence[Any]
if other is None:
values = [None] * len(self.prop._comparable_elements)
else:
@@ -601,13 +706,14 @@ class Composite(
a == b for a, b in zip(self.prop._comparable_elements, values)
]
if self._adapt_to_entity:
+ assert self.adapter is not None
comparisons = [self.adapter(x) for x in comparisons]
return sql.and_(*comparisons)
- def __ne__(self, other):
+ def __ne__(self, other: Any) -> ColumnElement[bool]: # type: ignore[override] # noqa: E501
return sql.not_(self.__eq__(other))
- def __str__(self):
+ def __str__(self) -> str:
return str(self.parent.class_.__name__) + "." + self.key
@@ -628,20 +734,24 @@ class ConcreteInheritedProperty(DescriptorProperty[_T]):
"""
- def _comparator_factory(self, mapper):
+ def _comparator_factory(
+ self, mapper: Mapper[Any]
+ ) -> Type[PropComparator[_T]]:
+
comparator_callable = None
for m in self.parent.iterate_to_root():
p = m._props[self.key]
- if not isinstance(p, ConcreteInheritedProperty):
+ if getattr(p, "comparator_factory", None) is not None:
comparator_callable = p.comparator_factory
break
- return comparator_callable
+ assert comparator_callable is not None
+ return comparator_callable(p, mapper) # type: ignore
- def __init__(self):
+ def __init__(self) -> None:
super().__init__()
- def warn():
+ def warn() -> NoReturn:
raise AttributeError(
"Concrete %s does not implement "
"attribute %r at the instance level. Add "
@@ -650,13 +760,13 @@ class ConcreteInheritedProperty(DescriptorProperty[_T]):
)
class NoninheritedConcreteProp:
- def __set__(s, obj, value):
+ def __set__(s: Any, obj: Any, value: Any) -> NoReturn:
warn()
- def __delete__(s, obj):
+ def __delete__(s: Any, obj: Any) -> NoReturn:
warn()
- def __get__(s, obj, owner):
+ def __get__(s: Any, obj: Any, owner: Any) -> Any:
if obj is None:
return self.descriptor
warn()
@@ -682,14 +792,16 @@ class Synonym(DescriptorProperty[_T]):
"""
+ comparator_factory: Optional[Type[PropComparator[_T]]]
+
def __init__(
self,
- name,
- map_column=None,
- descriptor=None,
- comparator_factory=None,
- doc=None,
- info=None,
+ name: str,
+ map_column: Optional[bool] = None,
+ descriptor: Optional[Any] = None,
+ comparator_factory: Optional[Type[PropComparator[_T]]] = None,
+ info: Optional[_InfoType] = None,
+ doc: Optional[str] = None,
):
super().__init__()
@@ -697,21 +809,30 @@ class Synonym(DescriptorProperty[_T]):
self.map_column = map_column
self.descriptor = descriptor
self.comparator_factory = comparator_factory
- self.doc = doc or (descriptor and descriptor.__doc__) or None
+ if doc:
+ self.doc = doc
+ elif descriptor and descriptor.__doc__:
+ self.doc = descriptor.__doc__
+ else:
+ self.doc = None
if info:
- self.info = info
+ self.info.update(info)
util.set_creation_order(self)
- @property
- def uses_objects(self):
- return getattr(self.parent.class_, self.name).impl.uses_objects
+ if not TYPE_CHECKING:
+
+ @property
+ def uses_objects(self) -> bool:
+ return getattr(self.parent.class_, self.name).impl.uses_objects
# TODO: when initialized, check _proxied_object,
# emit a warning if its not a column-based property
@util.memoized_property
- def _proxied_object(self):
+ def _proxied_object(
+ self,
+ ) -> Union[MapperProperty[_T], SQLORMOperations[_T]]:
attr = getattr(self.parent.class_, self.name)
if not hasattr(attr, "property") or not isinstance(
attr.property, MapperProperty
@@ -720,7 +841,8 @@ class Synonym(DescriptorProperty[_T]):
# hybrid or association proxy
if isinstance(attr, attributes.QueryableAttribute):
return attr.comparator
- elif isinstance(attr, operators.ColumnOperators):
+ elif isinstance(attr, SQLORMOperations):
+ # assocaition proxy comes here
return attr
raise sa_exc.InvalidRequestError(
@@ -730,7 +852,7 @@ class Synonym(DescriptorProperty[_T]):
)
return attr.property
- def _comparator_factory(self, mapper):
+ def _comparator_factory(self, mapper: Mapper[Any]) -> SQLORMOperations[_T]:
prop = self._proxied_object
if isinstance(prop, MapperProperty):
@@ -742,12 +864,17 @@ class Synonym(DescriptorProperty[_T]):
else:
return prop
- def get_history(self, *arg, **kw):
- attr = getattr(self.parent.class_, self.name)
- return attr.impl.get_history(*arg, **kw)
+ def get_history(
+ self,
+ state: InstanceState[Any],
+ dict_: _InstanceDict,
+ passive: PassiveFlag = PassiveFlag.PASSIVE_OFF,
+ ) -> History:
+ attr: QueryableAttribute[Any] = getattr(self.parent.class_, self.name)
+ return attr.impl.get_history(state, dict_, passive=passive)
@util.preload_module("sqlalchemy.orm.properties")
- def set_parent(self, parent, init):
+ def set_parent(self, parent: Mapper[Any], init: bool) -> None:
properties = util.preloaded.orm_properties
if self.map_column:
@@ -776,7 +903,7 @@ class Synonym(DescriptorProperty[_T]):
"%r for column %r"
% (self.key, self.name, self.name, self.key)
)
- p = properties.ColumnProperty(
+ p: ColumnProperty[Any] = properties.ColumnProperty(
parent.persist_selectable.c[self.key]
)
parent._configure_property(self.name, p, init=init, setparent=True)