diff options
-rw-r--r-- | doc/build/changelog/unreleased_20/9240.rst | 22 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/decl_base.py | 6 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/mapper.py | 57 | ||||
-rw-r--r-- | test/orm/declarative/test_basic.py | 246 |
4 files changed, 325 insertions, 6 deletions
diff --git a/doc/build/changelog/unreleased_20/9240.rst b/doc/build/changelog/unreleased_20/9240.rst new file mode 100644 index 000000000..23e807f62 --- /dev/null +++ b/doc/build/changelog/unreleased_20/9240.rst @@ -0,0 +1,22 @@ +.. change:: + :tags: bug, orm + :tickets: 9240 + + Repaired ORM Declarative mappings to allow for the + :paramref:`_orm.Mapper.primary_key` parameter to be specified within + ``__mapper_args__`` when using :func:`_orm.mapped_column`. Despite this + usage being directly in the 2.0 documentation, the :class:`_orm.Mapper` was + not accepting the :func:`_orm.mapped_column` construct in this context. Ths + feature was already working for the :paramref:`_orm.Mapper.version_id_col` + and :paramref:`_orm.Mapper.polymorphic_on` parameters. + + As part of this change, the ``__mapper_args__`` attribute may be specified + without using :func:`_orm.declared_attr` on a non-mapped mixin class, + including a ``"primary_key"`` entry that refers to :class:`_schema.Column` + or :func:`_orm.mapped_column` objects locally present on the mixin; + Declarative will also translate these columns into the correct ones for a + particular mapped class. This again was working already for the + :paramref:`_orm.Mapper.version_id_col` and + :paramref:`_orm.Mapper.polymorphic_on` parameters. Additionally, + elements within ``"primary_key"`` may be indicated as string names of + existing mapped properties. diff --git a/lib/sqlalchemy/orm/decl_base.py b/lib/sqlalchemy/orm/decl_base.py index a858f12cb..37fa964b8 100644 --- a/lib/sqlalchemy/orm/decl_base.py +++ b/lib/sqlalchemy/orm/decl_base.py @@ -1721,6 +1721,12 @@ class _ClassScanMapperConfig(_MapperConfig): v = mapper_args[k] mapper_args[k] = self.column_copies.get(v, v) + if "primary_key" in mapper_args: + mapper_args["primary_key"] = [ + self.column_copies.get(v, v) + for v in util.to_list(mapper_args["primary_key"]) + ] + if "inherits" in mapper_args: inherits_arg = mapper_args["inherits"] if isinstance(inherits_arg, Mapper): diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index a3b209e4a..660c61691 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -83,6 +83,7 @@ from ..sql import util as sql_util from ..sql import visitors from ..sql.cache_key import MemoizedHasCacheKey from ..sql.elements import KeyedColumnElement +from ..sql.schema import Column from ..sql.schema import Table from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL from ..util import HasMemoized @@ -112,7 +113,6 @@ if TYPE_CHECKING: from ..sql.base import ReadOnlyColumnCollection from ..sql.elements import ColumnClause from ..sql.elements import ColumnElement - from ..sql.schema import Column from ..sql.selectable import FromClause from ..util import OrderedSet @@ -650,11 +650,15 @@ class Mapper( :ref:`orm_mapping_classes_toplevel` :param primary_key: A list of :class:`_schema.Column` - objects which define + objects, or alternatively string names of attribute names which + refer to :class:`_schema.Column`, which define the primary key to be used against this mapper's selectable unit. This is normally simply the primary key of the ``local_table``, but can be overridden here. + .. versionchanged:: 2.0.2 :paramref:`_orm.Mapper.primary_key` + arguments may be indicated as string attribute names as well. + .. seealso:: :ref:`mapper_primary_key` - background and example use @@ -1557,6 +1561,29 @@ class Mapper( self.__dict__.pop("_configure_failed", None) + def _str_arg_to_mapped_col(self, argname: str, key: str) -> Column[Any]: + try: + prop = self._props[key] + except KeyError as err: + raise sa_exc.ArgumentError( + f"Can't determine {argname} column '{key}' - " + "no attribute is mapped to this name." + ) from err + try: + expr = prop.expression + except AttributeError as ae: + raise sa_exc.ArgumentError( + f"Can't determine {argname} column '{key}'; " + "property does not refer to a single mapped Column" + ) from ae + if not isinstance(expr, Column): + raise sa_exc.ArgumentError( + f"Can't determine {argname} column '{key}'; " + "property does not refer to a single " + "mapped Column" + ) + return expr + def _configure_pks(self) -> None: self.tables = sql_util.find_tables(self.persist_selectable) @@ -1585,10 +1612,28 @@ class Mapper( all_cols ) + if self._primary_key_argument: + + coerced_pk_arg = [ + self._str_arg_to_mapped_col("primary_key", c) + if isinstance(c, str) + else c + for c in ( + coercions.expect( # type: ignore + roles.DDLConstraintColumnRole, + coerce_pk, + argname="primary_key", + ) + for coerce_pk in self._primary_key_argument + ) + ] + else: + coerced_pk_arg = None + # if explicit PK argument sent, add those columns to the # primary key mappings - if self._primary_key_argument: - for k in self._primary_key_argument: + if coerced_pk_arg: + for k in coerced_pk_arg: if k.table not in self._pks_by_table: self._pks_by_table[k.table] = util.OrderedSet() self._pks_by_table[k.table].add(k) @@ -1625,12 +1670,12 @@ class Mapper( # determine primary key from argument or persist_selectable pks primary_key: Collection[ColumnElement[Any]] - if self._primary_key_argument: + if coerced_pk_arg: primary_key = [ cc if cc is not None else c for cc, c in ( (self.persist_selectable.corresponding_column(c), c) - for c in self._primary_key_argument + for c in coerced_pk_arg ) ] else: diff --git a/test/orm/declarative/test_basic.py b/test/orm/declarative/test_basic.py index e2108f888..45f0d4200 100644 --- a/test/orm/declarative/test_basic.py +++ b/test/orm/declarative/test_basic.py @@ -30,6 +30,7 @@ from sqlalchemy.orm import deferred from sqlalchemy.orm import descriptor_props from sqlalchemy.orm import exc as orm_exc from sqlalchemy.orm import joinedload +from sqlalchemy.orm import Mapped from sqlalchemy.orm import mapped_column from sqlalchemy.orm import MappedColumn from sqlalchemy.orm import Mapper @@ -188,6 +189,251 @@ class DeclarativeBaseSetupsTest(fixtures.TestBase): ): Base.__init__(fs, x=5) + @testing.variation("argument", ["version_id_col", "polymorphic_on"]) + @testing.variation("column_type", ["anno", "non_anno", "plain_column"]) + def test_mapped_column_version_poly_arg( + self, decl_base, column_type, argument + ): + """test #9240""" + + if column_type.anno: + + class A(decl_base): + __tablename__ = "a" + + a: Mapped[int] = mapped_column(primary_key=True) + b: Mapped[int] = mapped_column() + c: Mapped[str] = mapped_column() + + if argument.version_id_col: + __mapper_args__ = {"version_id_col": b} + elif argument.polymorphic_on: + __mapper_args__ = {"polymorphic_on": c} + else: + argument.fail() + + elif column_type.non_anno: + + class A(decl_base): + __tablename__ = "a" + + a = mapped_column(Integer, primary_key=True) + b = mapped_column(Integer) + c = mapped_column(String) + + if argument.version_id_col: + __mapper_args__ = {"version_id_col": b} + elif argument.polymorphic_on: + __mapper_args__ = {"polymorphic_on": c} + else: + argument.fail() + + elif column_type.plain_column: + + class A(decl_base): + __tablename__ = "a" + + a = Column(Integer, primary_key=True) + b = Column(Integer) + c = Column(String) + + if argument.version_id_col: + __mapper_args__ = {"version_id_col": b} + elif argument.polymorphic_on: + __mapper_args__ = {"polymorphic_on": c} + else: + argument.fail() + + else: + column_type.fail() + + if argument.version_id_col: + assert A.__mapper__.version_id_col is A.__table__.c.b + elif argument.polymorphic_on: + assert A.__mapper__.polymorphic_on is A.__table__.c.c + else: + argument.fail() + + @testing.variation( + "pk_type", ["single", "tuple", "list", "single_str", "list_str"] + ) + @testing.variation("column_type", ["anno", "non_anno", "plain_column"]) + def test_mapped_column_pk_arg(self, decl_base, column_type, pk_type): + """test #9240""" + + if column_type.anno: + + class A(decl_base): + __tablename__ = "a" + + a: Mapped[int] = mapped_column() + b: Mapped[int] = mapped_column() + + if pk_type.single: + __mapper_args__ = {"primary_key": a} + elif pk_type.tuple: + __mapper_args__ = {"primary_key": (a, b)} + elif pk_type.list: + __mapper_args__ = {"primary_key": [a, b]} + elif pk_type.single_str: + __mapper_args__ = {"primary_key": "a"} + elif pk_type.list_str: + __mapper_args__ = {"primary_key": ["a", "b"]} + else: + pk_type.fail() + + elif column_type.non_anno: + + class A(decl_base): + __tablename__ = "a" + + a = mapped_column(Integer) + b = mapped_column(Integer) + + if pk_type.single: + __mapper_args__ = {"primary_key": a} + elif pk_type.tuple: + __mapper_args__ = {"primary_key": (a, b)} + elif pk_type.list: + __mapper_args__ = {"primary_key": [a, b]} + elif pk_type.single_str: + __mapper_args__ = {"primary_key": "a"} + elif pk_type.list_str: + __mapper_args__ = {"primary_key": ["a", "b"]} + else: + pk_type.fail() + + elif column_type.plain_column: + + class A(decl_base): + __tablename__ = "a" + + a = Column(Integer) + b = Column(Integer) + + if pk_type.single: + __mapper_args__ = {"primary_key": a} + elif pk_type.tuple: + __mapper_args__ = {"primary_key": (a, b)} + elif pk_type.list: + __mapper_args__ = {"primary_key": [a, b]} + elif pk_type.single_str: + __mapper_args__ = {"primary_key": "a"} + elif pk_type.list_str: + __mapper_args__ = {"primary_key": ["a", "b"]} + else: + pk_type.fail() + + else: + column_type.fail() + + if pk_type.single or pk_type.single_str: + assert A.__mapper__.primary_key[0] is A.__table__.c.a + else: + assert A.__mapper__.primary_key[0] is A.__table__.c.a + assert A.__mapper__.primary_key[1] is A.__table__.c.b + + def test_mapper_pk_arg_degradation_no_col(self, decl_base): + + with expect_raises_message( + exc.ArgumentError, + "Can't determine primary_key column 'q' - no attribute is " + "mapped to this name.", + ): + + class A(decl_base): + __tablename__ = "a" + + a: Mapped[int] = mapped_column() + b: Mapped[int] = mapped_column() + + __mapper_args__ = {"primary_key": "q"} + + @testing.variation("proptype", ["relationship", "colprop"]) + def test_mapper_pk_arg_degradation_is_not_a_col(self, decl_base, proptype): + + with expect_raises_message( + exc.ArgumentError, + "Can't determine primary_key column 'b'; property does " + "not refer to a single mapped Column", + ): + + class A(decl_base): + __tablename__ = "a" + + a: Mapped[int] = mapped_column(Integer) + + if proptype.colprop: + b: Mapped[int] = column_property(a + 5) + elif proptype.relationship: + b = relationship("B") + else: + proptype.fail() + + __mapper_args__ = {"primary_key": "b"} + + @testing.variation( + "argument", ["version_id_col", "polymorphic_on", "primary_key"] + ) + @testing.variation("argtype", ["callable", "fixed"]) + @testing.variation("column_type", ["mapped_column", "plain_column"]) + def test_mapped_column_pk_arg_via_mixin( + self, decl_base, argtype, column_type, argument + ): + """test #9240""" + + class Mixin: + if column_type.mapped_column: + a: Mapped[int] = mapped_column() + b: Mapped[int] = mapped_column() + c: Mapped[str] = mapped_column() + elif column_type.plain_column: + a = Column(Integer) + b = Column(Integer) + c = Column(String) + else: + column_type.fail() + + if argtype.callable: + + @declared_attr.directive + @classmethod + def __mapper_args__(cls): + if argument.primary_key: + return {"primary_key": [cls.a, cls.b]} + elif argument.version_id_col: + return {"version_id_col": cls.b, "primary_key": cls.a} + elif argument.polymorphic_on: + return {"polymorphic_on": cls.c, "primary_key": cls.a} + else: + argument.fail() + + elif argtype.fixed: + if argument.primary_key: + __mapper_args__ = {"primary_key": [a, b]} + elif argument.version_id_col: + __mapper_args__ = {"primary_key": a, "version_id_col": b} + elif argument.polymorphic_on: + __mapper_args__ = {"primary_key": a, "polymorphic_on": c} + else: + argument.fail() + + else: + argtype.fail() + + class A(Mixin, decl_base): + __tablename__ = "a" + + if argument.primary_key: + assert A.__mapper__.primary_key[0] is A.__table__.c.a + assert A.__mapper__.primary_key[1] is A.__table__.c.b + elif argument.version_id_col: + assert A.__mapper__.version_id_col is A.__table__.c.b + elif argument.polymorphic_on: + assert A.__mapper__.polymorphic_on is A.__table__.c.c + else: + argtype.fail() + def test_dispose_attrs(self): reg = registry() |