summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--doc/build/changelog/unreleased_20/9240.rst22
-rw-r--r--lib/sqlalchemy/orm/decl_base.py6
-rw-r--r--lib/sqlalchemy/orm/mapper.py57
-rw-r--r--test/orm/declarative/test_basic.py246
4 files changed, 325 insertions, 6 deletions
diff --git a/doc/build/changelog/unreleased_20/9240.rst b/doc/build/changelog/unreleased_20/9240.rst
new file mode 100644
index 000000000..23e807f62
--- /dev/null
+++ b/doc/build/changelog/unreleased_20/9240.rst
@@ -0,0 +1,22 @@
+.. change::
+ :tags: bug, orm
+ :tickets: 9240
+
+ Repaired ORM Declarative mappings to allow for the
+ :paramref:`_orm.Mapper.primary_key` parameter to be specified within
+ ``__mapper_args__`` when using :func:`_orm.mapped_column`. Despite this
+ usage being directly in the 2.0 documentation, the :class:`_orm.Mapper` was
+ not accepting the :func:`_orm.mapped_column` construct in this context. Ths
+ feature was already working for the :paramref:`_orm.Mapper.version_id_col`
+ and :paramref:`_orm.Mapper.polymorphic_on` parameters.
+
+ As part of this change, the ``__mapper_args__`` attribute may be specified
+ without using :func:`_orm.declared_attr` on a non-mapped mixin class,
+ including a ``"primary_key"`` entry that refers to :class:`_schema.Column`
+ or :func:`_orm.mapped_column` objects locally present on the mixin;
+ Declarative will also translate these columns into the correct ones for a
+ particular mapped class. This again was working already for the
+ :paramref:`_orm.Mapper.version_id_col` and
+ :paramref:`_orm.Mapper.polymorphic_on` parameters. Additionally,
+ elements within ``"primary_key"`` may be indicated as string names of
+ existing mapped properties.
diff --git a/lib/sqlalchemy/orm/decl_base.py b/lib/sqlalchemy/orm/decl_base.py
index a858f12cb..37fa964b8 100644
--- a/lib/sqlalchemy/orm/decl_base.py
+++ b/lib/sqlalchemy/orm/decl_base.py
@@ -1721,6 +1721,12 @@ class _ClassScanMapperConfig(_MapperConfig):
v = mapper_args[k]
mapper_args[k] = self.column_copies.get(v, v)
+ if "primary_key" in mapper_args:
+ mapper_args["primary_key"] = [
+ self.column_copies.get(v, v)
+ for v in util.to_list(mapper_args["primary_key"])
+ ]
+
if "inherits" in mapper_args:
inherits_arg = mapper_args["inherits"]
if isinstance(inherits_arg, Mapper):
diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py
index a3b209e4a..660c61691 100644
--- a/lib/sqlalchemy/orm/mapper.py
+++ b/lib/sqlalchemy/orm/mapper.py
@@ -83,6 +83,7 @@ from ..sql import util as sql_util
from ..sql import visitors
from ..sql.cache_key import MemoizedHasCacheKey
from ..sql.elements import KeyedColumnElement
+from ..sql.schema import Column
from ..sql.schema import Table
from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
from ..util import HasMemoized
@@ -112,7 +113,6 @@ if TYPE_CHECKING:
from ..sql.base import ReadOnlyColumnCollection
from ..sql.elements import ColumnClause
from ..sql.elements import ColumnElement
- from ..sql.schema import Column
from ..sql.selectable import FromClause
from ..util import OrderedSet
@@ -650,11 +650,15 @@ class Mapper(
:ref:`orm_mapping_classes_toplevel`
:param primary_key: A list of :class:`_schema.Column`
- objects which define
+ objects, or alternatively string names of attribute names which
+ refer to :class:`_schema.Column`, which define
the primary key to be used against this mapper's selectable unit.
This is normally simply the primary key of the ``local_table``, but
can be overridden here.
+ .. versionchanged:: 2.0.2 :paramref:`_orm.Mapper.primary_key`
+ arguments may be indicated as string attribute names as well.
+
.. seealso::
:ref:`mapper_primary_key` - background and example use
@@ -1557,6 +1561,29 @@ class Mapper(
self.__dict__.pop("_configure_failed", None)
+ def _str_arg_to_mapped_col(self, argname: str, key: str) -> Column[Any]:
+ try:
+ prop = self._props[key]
+ except KeyError as err:
+ raise sa_exc.ArgumentError(
+ f"Can't determine {argname} column '{key}' - "
+ "no attribute is mapped to this name."
+ ) from err
+ try:
+ expr = prop.expression
+ except AttributeError as ae:
+ raise sa_exc.ArgumentError(
+ f"Can't determine {argname} column '{key}'; "
+ "property does not refer to a single mapped Column"
+ ) from ae
+ if not isinstance(expr, Column):
+ raise sa_exc.ArgumentError(
+ f"Can't determine {argname} column '{key}'; "
+ "property does not refer to a single "
+ "mapped Column"
+ )
+ return expr
+
def _configure_pks(self) -> None:
self.tables = sql_util.find_tables(self.persist_selectable)
@@ -1585,10 +1612,28 @@ class Mapper(
all_cols
)
+ if self._primary_key_argument:
+
+ coerced_pk_arg = [
+ self._str_arg_to_mapped_col("primary_key", c)
+ if isinstance(c, str)
+ else c
+ for c in (
+ coercions.expect( # type: ignore
+ roles.DDLConstraintColumnRole,
+ coerce_pk,
+ argname="primary_key",
+ )
+ for coerce_pk in self._primary_key_argument
+ )
+ ]
+ else:
+ coerced_pk_arg = None
+
# if explicit PK argument sent, add those columns to the
# primary key mappings
- if self._primary_key_argument:
- for k in self._primary_key_argument:
+ if coerced_pk_arg:
+ for k in coerced_pk_arg:
if k.table not in self._pks_by_table:
self._pks_by_table[k.table] = util.OrderedSet()
self._pks_by_table[k.table].add(k)
@@ -1625,12 +1670,12 @@ class Mapper(
# determine primary key from argument or persist_selectable pks
primary_key: Collection[ColumnElement[Any]]
- if self._primary_key_argument:
+ if coerced_pk_arg:
primary_key = [
cc if cc is not None else c
for cc, c in (
(self.persist_selectable.corresponding_column(c), c)
- for c in self._primary_key_argument
+ for c in coerced_pk_arg
)
]
else:
diff --git a/test/orm/declarative/test_basic.py b/test/orm/declarative/test_basic.py
index e2108f888..45f0d4200 100644
--- a/test/orm/declarative/test_basic.py
+++ b/test/orm/declarative/test_basic.py
@@ -30,6 +30,7 @@ from sqlalchemy.orm import deferred
from sqlalchemy.orm import descriptor_props
from sqlalchemy.orm import exc as orm_exc
from sqlalchemy.orm import joinedload
+from sqlalchemy.orm import Mapped
from sqlalchemy.orm import mapped_column
from sqlalchemy.orm import MappedColumn
from sqlalchemy.orm import Mapper
@@ -188,6 +189,251 @@ class DeclarativeBaseSetupsTest(fixtures.TestBase):
):
Base.__init__(fs, x=5)
+ @testing.variation("argument", ["version_id_col", "polymorphic_on"])
+ @testing.variation("column_type", ["anno", "non_anno", "plain_column"])
+ def test_mapped_column_version_poly_arg(
+ self, decl_base, column_type, argument
+ ):
+ """test #9240"""
+
+ if column_type.anno:
+
+ class A(decl_base):
+ __tablename__ = "a"
+
+ a: Mapped[int] = mapped_column(primary_key=True)
+ b: Mapped[int] = mapped_column()
+ c: Mapped[str] = mapped_column()
+
+ if argument.version_id_col:
+ __mapper_args__ = {"version_id_col": b}
+ elif argument.polymorphic_on:
+ __mapper_args__ = {"polymorphic_on": c}
+ else:
+ argument.fail()
+
+ elif column_type.non_anno:
+
+ class A(decl_base):
+ __tablename__ = "a"
+
+ a = mapped_column(Integer, primary_key=True)
+ b = mapped_column(Integer)
+ c = mapped_column(String)
+
+ if argument.version_id_col:
+ __mapper_args__ = {"version_id_col": b}
+ elif argument.polymorphic_on:
+ __mapper_args__ = {"polymorphic_on": c}
+ else:
+ argument.fail()
+
+ elif column_type.plain_column:
+
+ class A(decl_base):
+ __tablename__ = "a"
+
+ a = Column(Integer, primary_key=True)
+ b = Column(Integer)
+ c = Column(String)
+
+ if argument.version_id_col:
+ __mapper_args__ = {"version_id_col": b}
+ elif argument.polymorphic_on:
+ __mapper_args__ = {"polymorphic_on": c}
+ else:
+ argument.fail()
+
+ else:
+ column_type.fail()
+
+ if argument.version_id_col:
+ assert A.__mapper__.version_id_col is A.__table__.c.b
+ elif argument.polymorphic_on:
+ assert A.__mapper__.polymorphic_on is A.__table__.c.c
+ else:
+ argument.fail()
+
+ @testing.variation(
+ "pk_type", ["single", "tuple", "list", "single_str", "list_str"]
+ )
+ @testing.variation("column_type", ["anno", "non_anno", "plain_column"])
+ def test_mapped_column_pk_arg(self, decl_base, column_type, pk_type):
+ """test #9240"""
+
+ if column_type.anno:
+
+ class A(decl_base):
+ __tablename__ = "a"
+
+ a: Mapped[int] = mapped_column()
+ b: Mapped[int] = mapped_column()
+
+ if pk_type.single:
+ __mapper_args__ = {"primary_key": a}
+ elif pk_type.tuple:
+ __mapper_args__ = {"primary_key": (a, b)}
+ elif pk_type.list:
+ __mapper_args__ = {"primary_key": [a, b]}
+ elif pk_type.single_str:
+ __mapper_args__ = {"primary_key": "a"}
+ elif pk_type.list_str:
+ __mapper_args__ = {"primary_key": ["a", "b"]}
+ else:
+ pk_type.fail()
+
+ elif column_type.non_anno:
+
+ class A(decl_base):
+ __tablename__ = "a"
+
+ a = mapped_column(Integer)
+ b = mapped_column(Integer)
+
+ if pk_type.single:
+ __mapper_args__ = {"primary_key": a}
+ elif pk_type.tuple:
+ __mapper_args__ = {"primary_key": (a, b)}
+ elif pk_type.list:
+ __mapper_args__ = {"primary_key": [a, b]}
+ elif pk_type.single_str:
+ __mapper_args__ = {"primary_key": "a"}
+ elif pk_type.list_str:
+ __mapper_args__ = {"primary_key": ["a", "b"]}
+ else:
+ pk_type.fail()
+
+ elif column_type.plain_column:
+
+ class A(decl_base):
+ __tablename__ = "a"
+
+ a = Column(Integer)
+ b = Column(Integer)
+
+ if pk_type.single:
+ __mapper_args__ = {"primary_key": a}
+ elif pk_type.tuple:
+ __mapper_args__ = {"primary_key": (a, b)}
+ elif pk_type.list:
+ __mapper_args__ = {"primary_key": [a, b]}
+ elif pk_type.single_str:
+ __mapper_args__ = {"primary_key": "a"}
+ elif pk_type.list_str:
+ __mapper_args__ = {"primary_key": ["a", "b"]}
+ else:
+ pk_type.fail()
+
+ else:
+ column_type.fail()
+
+ if pk_type.single or pk_type.single_str:
+ assert A.__mapper__.primary_key[0] is A.__table__.c.a
+ else:
+ assert A.__mapper__.primary_key[0] is A.__table__.c.a
+ assert A.__mapper__.primary_key[1] is A.__table__.c.b
+
+ def test_mapper_pk_arg_degradation_no_col(self, decl_base):
+
+ with expect_raises_message(
+ exc.ArgumentError,
+ "Can't determine primary_key column 'q' - no attribute is "
+ "mapped to this name.",
+ ):
+
+ class A(decl_base):
+ __tablename__ = "a"
+
+ a: Mapped[int] = mapped_column()
+ b: Mapped[int] = mapped_column()
+
+ __mapper_args__ = {"primary_key": "q"}
+
+ @testing.variation("proptype", ["relationship", "colprop"])
+ def test_mapper_pk_arg_degradation_is_not_a_col(self, decl_base, proptype):
+
+ with expect_raises_message(
+ exc.ArgumentError,
+ "Can't determine primary_key column 'b'; property does "
+ "not refer to a single mapped Column",
+ ):
+
+ class A(decl_base):
+ __tablename__ = "a"
+
+ a: Mapped[int] = mapped_column(Integer)
+
+ if proptype.colprop:
+ b: Mapped[int] = column_property(a + 5)
+ elif proptype.relationship:
+ b = relationship("B")
+ else:
+ proptype.fail()
+
+ __mapper_args__ = {"primary_key": "b"}
+
+ @testing.variation(
+ "argument", ["version_id_col", "polymorphic_on", "primary_key"]
+ )
+ @testing.variation("argtype", ["callable", "fixed"])
+ @testing.variation("column_type", ["mapped_column", "plain_column"])
+ def test_mapped_column_pk_arg_via_mixin(
+ self, decl_base, argtype, column_type, argument
+ ):
+ """test #9240"""
+
+ class Mixin:
+ if column_type.mapped_column:
+ a: Mapped[int] = mapped_column()
+ b: Mapped[int] = mapped_column()
+ c: Mapped[str] = mapped_column()
+ elif column_type.plain_column:
+ a = Column(Integer)
+ b = Column(Integer)
+ c = Column(String)
+ else:
+ column_type.fail()
+
+ if argtype.callable:
+
+ @declared_attr.directive
+ @classmethod
+ def __mapper_args__(cls):
+ if argument.primary_key:
+ return {"primary_key": [cls.a, cls.b]}
+ elif argument.version_id_col:
+ return {"version_id_col": cls.b, "primary_key": cls.a}
+ elif argument.polymorphic_on:
+ return {"polymorphic_on": cls.c, "primary_key": cls.a}
+ else:
+ argument.fail()
+
+ elif argtype.fixed:
+ if argument.primary_key:
+ __mapper_args__ = {"primary_key": [a, b]}
+ elif argument.version_id_col:
+ __mapper_args__ = {"primary_key": a, "version_id_col": b}
+ elif argument.polymorphic_on:
+ __mapper_args__ = {"primary_key": a, "polymorphic_on": c}
+ else:
+ argument.fail()
+
+ else:
+ argtype.fail()
+
+ class A(Mixin, decl_base):
+ __tablename__ = "a"
+
+ if argument.primary_key:
+ assert A.__mapper__.primary_key[0] is A.__table__.c.a
+ assert A.__mapper__.primary_key[1] is A.__table__.c.b
+ elif argument.version_id_col:
+ assert A.__mapper__.version_id_col is A.__table__.c.b
+ elif argument.polymorphic_on:
+ assert A.__mapper__.polymorphic_on is A.__table__.c.c
+ else:
+ argtype.fail()
+
def test_dispose_attrs(self):
reg = registry()