summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r--lib/sqlalchemy/orm/_typing.py8
-rw-r--r--lib/sqlalchemy/orm/descriptor_props.py75
-rw-r--r--lib/sqlalchemy/orm/properties.py4
-rw-r--r--lib/sqlalchemy/orm/session.py12
-rw-r--r--lib/sqlalchemy/sql/sqltypes.py2
5 files changed, 74 insertions, 27 deletions
diff --git a/lib/sqlalchemy/orm/_typing.py b/lib/sqlalchemy/orm/_typing.py
index 0e624afe2..ed04c96c7 100644
--- a/lib/sqlalchemy/orm/_typing.py
+++ b/lib/sqlalchemy/orm/_typing.py
@@ -24,7 +24,6 @@ if TYPE_CHECKING:
from .attributes import QueryableAttribute
from .base import PassiveFlag
from .decl_api import registry as _registry_type
- from .descriptor_props import _CompositeClassProto
from .interfaces import InspectionAttr
from .interfaces import MapperProperty
from .interfaces import UserDefinedOption
@@ -103,8 +102,11 @@ def is_user_defined_option(
return not opt._is_core and opt._is_user_defined # type: ignore
-def is_composite_class(obj: Any) -> TypeGuard[_CompositeClassProto]:
- return hasattr(obj, "__composite_values__")
+def is_composite_class(obj: Any) -> bool:
+ # inlining is_dataclass(obj)
+ return hasattr(obj, "__composite_values__") or hasattr(
+ obj, "__dataclass_fields__"
+ )
if TYPE_CHECKING:
diff --git a/lib/sqlalchemy/orm/descriptor_props.py b/lib/sqlalchemy/orm/descriptor_props.py
index a366a9534..d67319700 100644
--- a/lib/sqlalchemy/orm/descriptor_props.py
+++ b/lib/sqlalchemy/orm/descriptor_props.py
@@ -28,6 +28,7 @@ from typing import Type
from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
+import weakref
from . import attributes
from . import util as orm_util
@@ -48,7 +49,6 @@ from .. import sql
from .. import util
from ..sql import expression
from ..sql.elements import BindParameter
-from ..util.typing import Protocol
if typing.TYPE_CHECKING:
from ._typing import _InstanceDict
@@ -78,14 +78,6 @@ _T = TypeVar("_T", bound=Any)
_PT = TypeVar("_PT", bound=Any)
-class _CompositeClassProto(Protocol):
- def __init__(self, *args: Any):
- ...
-
- def __composite_values__(self) -> Tuple[Any, ...]:
- ...
-
-
class DescriptorProperty(MapperProperty[_T]):
""":class:`.MapperProperty` which proxies access to a
user-defined descriptor."""
@@ -167,7 +159,12 @@ _CompositeAttrType = Union[
]
-_CC = TypeVar("_CC", bound=_CompositeClassProto)
+_CC = TypeVar("_CC", bound=Any)
+
+
+_composite_getters: weakref.WeakKeyDictionary[
+ Type[Any], Callable[[Any], Tuple[Any, ...]]
+] = weakref.WeakKeyDictionary()
class Composite(
@@ -236,6 +233,7 @@ class Composite(
util.set_creation_order(self)
self._create_descriptor()
+ self._init_accessor()
def instrument_class(self, mapper: Mapper[Any]) -> None:
super().instrument_class(mapper)
@@ -254,7 +252,7 @@ class Composite(
" method; can't get state"
) from ae
else:
- return accessor()
+ return accessor() # type: ignore
def do_init(self) -> None:
"""Initialization which occurs after the :class:`.Composite`
@@ -337,6 +335,7 @@ class Composite(
extracted_mapped_annotation: Optional[_AnnotationScanType],
is_dataclass_field: bool,
) -> None:
+ MappedColumn = util.preloaded.orm_properties.MappedColumn
if (
self.composite_class is None
and extracted_mapped_annotation is None
@@ -347,14 +346,57 @@ class Composite(
if isinstance(argument, str) or hasattr(
argument, "__forward_arg__"
):
+ str_arg = (
+ argument.__forward_arg__
+ if hasattr(argument, "__forward_arg__")
+ else str(argument)
+ )
raise sa_exc.ArgumentError(
f"Can't use forward ref {argument} for composite "
- f"class argument"
+ f"class argument; set up the type as Mapped[{str_arg}]"
)
self.composite_class = argument
if is_dataclass(self.composite_class):
self._setup_for_dataclass(registry, cls, key)
+ else:
+ for attr in self.attrs:
+ if (
+ isinstance(attr, (MappedColumn, schema.Column))
+ and attr.name is None
+ ):
+ raise sa_exc.ArgumentError(
+ "Composite class column arguments must be named "
+ "unless a dataclass is used"
+ )
+ self._init_accessor()
+
+ def _init_accessor(self) -> None:
+ if is_dataclass(self.composite_class) and not hasattr(
+ self.composite_class, "__composite_values__"
+ ):
+ insp = inspect.signature(self.composite_class)
+ getter = operator.attrgetter(
+ *[p.name for p in insp.parameters.values()]
+ )
+ if len(insp.parameters) == 1:
+ self._generated_composite_accessor = lambda obj: (getter(obj),)
+ else:
+ self._generated_composite_accessor = getter
+
+ if (
+ self.composite_class is not None
+ and isinstance(self.composite_class, type)
+ and self.composite_class not in _composite_getters
+ ):
+ if self._generated_composite_accessor is not None:
+ _composite_getters[
+ self.composite_class
+ ] = self._generated_composite_accessor
+ elif hasattr(self.composite_class, "__composite_values__"):
+ _composite_getters[
+ self.composite_class
+ ] = lambda obj: obj.__composite_values__() # type: ignore
@util.preload_module("sqlalchemy.orm.properties")
@util.preload_module("sqlalchemy.orm.decl_base")
@@ -388,15 +430,6 @@ class Composite(
elif isinstance(attr, schema.Column):
decl_base._undefer_column_name(param.name, attr)
- if not hasattr(self.composite_class, "__composite_values__"):
- getter = operator.attrgetter(
- *[p.name for p in insp.parameters.values()]
- )
- if len(insp.parameters) == 1:
- self._generated_composite_accessor = lambda obj: (getter(obj),)
- else:
- self._generated_composite_accessor = getter
-
@util.memoized_property
def _comparable_elements(self) -> Sequence[QueryableAttribute[Any]]:
return [getattr(self.parent.class_, prop.key) for prop in self.props]
diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py
index d77d6e63c..064422293 100644
--- a/lib/sqlalchemy/orm/properties.py
+++ b/lib/sqlalchemy/orm/properties.py
@@ -537,6 +537,10 @@ class MappedColumn(
return new
@property
+ def name(self) -> str:
+ return self.column.name
+
+ @property
def mapper_property_to_assign(self) -> Optional["MapperProperty[_T]"]:
if self.deferred:
return ColumnProperty(
diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py
index 788821b98..ec6f41b28 100644
--- a/lib/sqlalchemy/orm/session.py
+++ b/lib/sqlalchemy/orm/session.py
@@ -35,6 +35,7 @@ import weakref
from . import attributes
from . import context
+from . import descriptor_props
from . import exc
from . import identity
from . import loading
@@ -3193,8 +3194,15 @@ class Session(_SessionClassMethods, EventTarget):
) -> Optional[_O]:
# convert composite types to individual args
- if is_composite_class(primary_key_identity):
- primary_key_identity = primary_key_identity.__composite_values__()
+ if (
+ is_composite_class(primary_key_identity)
+ and type(primary_key_identity)
+ in descriptor_props._composite_getters
+ ):
+ getter = descriptor_props._composite_getters[
+ type(primary_key_identity)
+ ]
+ primary_key_identity = getter(primary_key_identity)
mapper: Optional[Mapper[_O]] = inspect(entity)
diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py
index faa0c794c..32f0813f5 100644
--- a/lib/sqlalchemy/sql/sqltypes.py
+++ b/lib/sqlalchemy/sql/sqltypes.py
@@ -3359,7 +3359,7 @@ class Uuid(TypeEngine[_UUID_RETURN]):
__visit_name__ = "uuid"
- collation = None
+ collation: Optional[str] = None
@overload
def __init__(