diff options
Diffstat (limited to 'lib/sqlalchemy/orm/descriptor_props.py')
-rw-r--r-- | lib/sqlalchemy/orm/descriptor_props.py | 75 |
1 files changed, 54 insertions, 21 deletions
diff --git a/lib/sqlalchemy/orm/descriptor_props.py b/lib/sqlalchemy/orm/descriptor_props.py index a366a9534..d67319700 100644 --- a/lib/sqlalchemy/orm/descriptor_props.py +++ b/lib/sqlalchemy/orm/descriptor_props.py @@ -28,6 +28,7 @@ from typing import Type from typing import TYPE_CHECKING from typing import TypeVar from typing import Union +import weakref from . import attributes from . import util as orm_util @@ -48,7 +49,6 @@ from .. import sql from .. import util from ..sql import expression from ..sql.elements import BindParameter -from ..util.typing import Protocol if typing.TYPE_CHECKING: from ._typing import _InstanceDict @@ -78,14 +78,6 @@ _T = TypeVar("_T", bound=Any) _PT = TypeVar("_PT", bound=Any) -class _CompositeClassProto(Protocol): - def __init__(self, *args: Any): - ... - - def __composite_values__(self) -> Tuple[Any, ...]: - ... - - class DescriptorProperty(MapperProperty[_T]): """:class:`.MapperProperty` which proxies access to a user-defined descriptor.""" @@ -167,7 +159,12 @@ _CompositeAttrType = Union[ ] -_CC = TypeVar("_CC", bound=_CompositeClassProto) +_CC = TypeVar("_CC", bound=Any) + + +_composite_getters: weakref.WeakKeyDictionary[ + Type[Any], Callable[[Any], Tuple[Any, ...]] +] = weakref.WeakKeyDictionary() class Composite( @@ -236,6 +233,7 @@ class Composite( util.set_creation_order(self) self._create_descriptor() + self._init_accessor() def instrument_class(self, mapper: Mapper[Any]) -> None: super().instrument_class(mapper) @@ -254,7 +252,7 @@ class Composite( " method; can't get state" ) from ae else: - return accessor() + return accessor() # type: ignore def do_init(self) -> None: """Initialization which occurs after the :class:`.Composite` @@ -337,6 +335,7 @@ class Composite( extracted_mapped_annotation: Optional[_AnnotationScanType], is_dataclass_field: bool, ) -> None: + MappedColumn = util.preloaded.orm_properties.MappedColumn if ( self.composite_class is None and extracted_mapped_annotation is None @@ -347,14 +346,57 @@ class Composite( if isinstance(argument, str) or hasattr( argument, "__forward_arg__" ): + str_arg = ( + argument.__forward_arg__ + if hasattr(argument, "__forward_arg__") + else str(argument) + ) raise sa_exc.ArgumentError( f"Can't use forward ref {argument} for composite " - f"class argument" + f"class argument; set up the type as Mapped[{str_arg}]" ) self.composite_class = argument if is_dataclass(self.composite_class): self._setup_for_dataclass(registry, cls, key) + else: + for attr in self.attrs: + if ( + isinstance(attr, (MappedColumn, schema.Column)) + and attr.name is None + ): + raise sa_exc.ArgumentError( + "Composite class column arguments must be named " + "unless a dataclass is used" + ) + self._init_accessor() + + def _init_accessor(self) -> None: + if is_dataclass(self.composite_class) and not hasattr( + self.composite_class, "__composite_values__" + ): + insp = inspect.signature(self.composite_class) + getter = operator.attrgetter( + *[p.name for p in insp.parameters.values()] + ) + if len(insp.parameters) == 1: + self._generated_composite_accessor = lambda obj: (getter(obj),) + else: + self._generated_composite_accessor = getter + + if ( + self.composite_class is not None + and isinstance(self.composite_class, type) + and self.composite_class not in _composite_getters + ): + if self._generated_composite_accessor is not None: + _composite_getters[ + self.composite_class + ] = self._generated_composite_accessor + elif hasattr(self.composite_class, "__composite_values__"): + _composite_getters[ + self.composite_class + ] = lambda obj: obj.__composite_values__() # type: ignore @util.preload_module("sqlalchemy.orm.properties") @util.preload_module("sqlalchemy.orm.decl_base") @@ -388,15 +430,6 @@ class Composite( elif isinstance(attr, schema.Column): decl_base._undefer_column_name(param.name, attr) - if not hasattr(self.composite_class, "__composite_values__"): - getter = operator.attrgetter( - *[p.name for p in insp.parameters.values()] - ) - if len(insp.parameters) == 1: - self._generated_composite_accessor = lambda obj: (getter(obj),) - else: - self._generated_composite_accessor = getter - @util.memoized_property def _comparable_elements(self) -> Sequence[QueryableAttribute[Any]]: return [getattr(self.parent.class_, prop.key) for prop in self.props] |