diff options
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r-- | lib/sqlalchemy/orm/_typing.py | 8 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/descriptor_props.py | 75 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/properties.py | 4 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/session.py | 12 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/sqltypes.py | 2 |
5 files changed, 74 insertions, 27 deletions
diff --git a/lib/sqlalchemy/orm/_typing.py b/lib/sqlalchemy/orm/_typing.py index 0e624afe2..ed04c96c7 100644 --- a/lib/sqlalchemy/orm/_typing.py +++ b/lib/sqlalchemy/orm/_typing.py @@ -24,7 +24,6 @@ if TYPE_CHECKING: from .attributes import QueryableAttribute from .base import PassiveFlag from .decl_api import registry as _registry_type - from .descriptor_props import _CompositeClassProto from .interfaces import InspectionAttr from .interfaces import MapperProperty from .interfaces import UserDefinedOption @@ -103,8 +102,11 @@ def is_user_defined_option( return not opt._is_core and opt._is_user_defined # type: ignore -def is_composite_class(obj: Any) -> TypeGuard[_CompositeClassProto]: - return hasattr(obj, "__composite_values__") +def is_composite_class(obj: Any) -> bool: + # inlining is_dataclass(obj) + return hasattr(obj, "__composite_values__") or hasattr( + obj, "__dataclass_fields__" + ) if TYPE_CHECKING: diff --git a/lib/sqlalchemy/orm/descriptor_props.py b/lib/sqlalchemy/orm/descriptor_props.py index a366a9534..d67319700 100644 --- a/lib/sqlalchemy/orm/descriptor_props.py +++ b/lib/sqlalchemy/orm/descriptor_props.py @@ -28,6 +28,7 @@ from typing import Type from typing import TYPE_CHECKING from typing import TypeVar from typing import Union +import weakref from . import attributes from . import util as orm_util @@ -48,7 +49,6 @@ from .. import sql from .. import util from ..sql import expression from ..sql.elements import BindParameter -from ..util.typing import Protocol if typing.TYPE_CHECKING: from ._typing import _InstanceDict @@ -78,14 +78,6 @@ _T = TypeVar("_T", bound=Any) _PT = TypeVar("_PT", bound=Any) -class _CompositeClassProto(Protocol): - def __init__(self, *args: Any): - ... - - def __composite_values__(self) -> Tuple[Any, ...]: - ... - - class DescriptorProperty(MapperProperty[_T]): """:class:`.MapperProperty` which proxies access to a user-defined descriptor.""" @@ -167,7 +159,12 @@ _CompositeAttrType = Union[ ] -_CC = TypeVar("_CC", bound=_CompositeClassProto) +_CC = TypeVar("_CC", bound=Any) + + +_composite_getters: weakref.WeakKeyDictionary[ + Type[Any], Callable[[Any], Tuple[Any, ...]] +] = weakref.WeakKeyDictionary() class Composite( @@ -236,6 +233,7 @@ class Composite( util.set_creation_order(self) self._create_descriptor() + self._init_accessor() def instrument_class(self, mapper: Mapper[Any]) -> None: super().instrument_class(mapper) @@ -254,7 +252,7 @@ class Composite( " method; can't get state" ) from ae else: - return accessor() + return accessor() # type: ignore def do_init(self) -> None: """Initialization which occurs after the :class:`.Composite` @@ -337,6 +335,7 @@ class Composite( extracted_mapped_annotation: Optional[_AnnotationScanType], is_dataclass_field: bool, ) -> None: + MappedColumn = util.preloaded.orm_properties.MappedColumn if ( self.composite_class is None and extracted_mapped_annotation is None @@ -347,14 +346,57 @@ class Composite( if isinstance(argument, str) or hasattr( argument, "__forward_arg__" ): + str_arg = ( + argument.__forward_arg__ + if hasattr(argument, "__forward_arg__") + else str(argument) + ) raise sa_exc.ArgumentError( f"Can't use forward ref {argument} for composite " - f"class argument" + f"class argument; set up the type as Mapped[{str_arg}]" ) self.composite_class = argument if is_dataclass(self.composite_class): self._setup_for_dataclass(registry, cls, key) + else: + for attr in self.attrs: + if ( + isinstance(attr, (MappedColumn, schema.Column)) + and attr.name is None + ): + raise sa_exc.ArgumentError( + "Composite class column arguments must be named " + "unless a dataclass is used" + ) + self._init_accessor() + + def _init_accessor(self) -> None: + if is_dataclass(self.composite_class) and not hasattr( + self.composite_class, "__composite_values__" + ): + insp = inspect.signature(self.composite_class) + getter = operator.attrgetter( + *[p.name for p in insp.parameters.values()] + ) + if len(insp.parameters) == 1: + self._generated_composite_accessor = lambda obj: (getter(obj),) + else: + self._generated_composite_accessor = getter + + if ( + self.composite_class is not None + and isinstance(self.composite_class, type) + and self.composite_class not in _composite_getters + ): + if self._generated_composite_accessor is not None: + _composite_getters[ + self.composite_class + ] = self._generated_composite_accessor + elif hasattr(self.composite_class, "__composite_values__"): + _composite_getters[ + self.composite_class + ] = lambda obj: obj.__composite_values__() # type: ignore @util.preload_module("sqlalchemy.orm.properties") @util.preload_module("sqlalchemy.orm.decl_base") @@ -388,15 +430,6 @@ class Composite( elif isinstance(attr, schema.Column): decl_base._undefer_column_name(param.name, attr) - if not hasattr(self.composite_class, "__composite_values__"): - getter = operator.attrgetter( - *[p.name for p in insp.parameters.values()] - ) - if len(insp.parameters) == 1: - self._generated_composite_accessor = lambda obj: (getter(obj),) - else: - self._generated_composite_accessor = getter - @util.memoized_property def _comparable_elements(self) -> Sequence[QueryableAttribute[Any]]: return [getattr(self.parent.class_, prop.key) for prop in self.props] diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index d77d6e63c..064422293 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -537,6 +537,10 @@ class MappedColumn( return new @property + def name(self) -> str: + return self.column.name + + @property def mapper_property_to_assign(self) -> Optional["MapperProperty[_T]"]: if self.deferred: return ColumnProperty( diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index 788821b98..ec6f41b28 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -35,6 +35,7 @@ import weakref from . import attributes from . import context +from . import descriptor_props from . import exc from . import identity from . import loading @@ -3193,8 +3194,15 @@ class Session(_SessionClassMethods, EventTarget): ) -> Optional[_O]: # convert composite types to individual args - if is_composite_class(primary_key_identity): - primary_key_identity = primary_key_identity.__composite_values__() + if ( + is_composite_class(primary_key_identity) + and type(primary_key_identity) + in descriptor_props._composite_getters + ): + getter = descriptor_props._composite_getters[ + type(primary_key_identity) + ] + primary_key_identity = getter(primary_key_identity) mapper: Optional[Mapper[_O]] = inspect(entity) diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py index faa0c794c..32f0813f5 100644 --- a/lib/sqlalchemy/sql/sqltypes.py +++ b/lib/sqlalchemy/sql/sqltypes.py @@ -3359,7 +3359,7 @@ class Uuid(TypeEngine[_UUID_RETURN]): __visit_name__ = "uuid" - collation = None + collation: Optional[str] = None @overload def __init__( |