summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authormike bayer <mike_mp@zzzcomputing.com>2023-03-27 23:33:02 +0000
committerGerrit Code Review <gerrit@bbpush.zzzcomputing.com>2023-03-27 23:33:02 +0000
commite3a225aedbbb058122ecca466935b87405cf0132 (patch)
treea8f3d0f4cc8e4afb0b618bdde1b6d42cde7dd7d1
parent348d76072c108e996baf59900fc45468f48c4cce (diff)
parent24dd3d8c90876a05377d04910819dcd5d25aed4e (diff)
downloadsqlalchemy-e3a225aedbbb058122ecca466935b87405cf0132.tar.gz
Merge "support DeclarativeBase for versioned history example" into main
-rw-r--r--doc/build/changelog/unreleased_20/vers_fixes.rst7
-rw-r--r--examples/versioned_history/__init__.py3
-rw-r--r--examples/versioned_history/history_meta.py39
-rw-r--r--examples/versioned_history/test_versioning.py55
-rw-r--r--test/base/test_examples.py10
5 files changed, 95 insertions, 19 deletions
diff --git a/doc/build/changelog/unreleased_20/vers_fixes.rst b/doc/build/changelog/unreleased_20/vers_fixes.rst
new file mode 100644
index 000000000..d4f641151
--- /dev/null
+++ b/doc/build/changelog/unreleased_20/vers_fixes.rst
@@ -0,0 +1,7 @@
+.. change::
+ :tags: bug, examples
+
+ Fixed issue in "versioned history" example where using a declarative base
+ that is derived from :class:`_orm.DeclarativeBase` would fail to be mapped.
+ Additionally, repaired the given test suite so that the documented
+ instructions for running the example using Python unittest now work again.
diff --git a/examples/versioned_history/__init__.py b/examples/versioned_history/__init__.py
index 14dbea5c0..0593881e2 100644
--- a/examples/versioned_history/__init__.py
+++ b/examples/versioned_history/__init__.py
@@ -16,7 +16,8 @@ A fragment of example usage, using declarative::
from history_meta import Versioned, versioned_session
- Base = declarative_base()
+ class Base(DeclarativeBase):
+ pass
class SomeClass(Versioned, Base):
__tablename__ = 'sometable'
diff --git a/examples/versioned_history/history_meta.py b/examples/versioned_history/history_meta.py
index 1176a5dff..cc3ef2b0a 100644
--- a/examples/versioned_history/history_meta.py
+++ b/examples/versioned_history/history_meta.py
@@ -6,6 +6,7 @@ from sqlalchemy import Column
from sqlalchemy import DateTime
from sqlalchemy import event
from sqlalchemy import ForeignKeyConstraint
+from sqlalchemy import inspect
from sqlalchemy import Integer
from sqlalchemy import PrimaryKeyConstraint
from sqlalchemy import util
@@ -174,22 +175,22 @@ def _history_mapper(local_mapper):
else:
bases = local_mapper.base_mapper.class_.__bases__
- versioned_cls = type.__new__(
- type,
+ versioned_cls = type(
"%sHistory" % cls.__name__,
bases,
- {"_history_mapper_configured": True},
+ {
+ "_history_mapper_configured": True,
+ "__table__": history_table,
+ "__mapper_args__": dict(
+ inherits=super_history_mapper,
+ polymorphic_identity=local_mapper.polymorphic_identity,
+ polymorphic_on=polymorphic_on,
+ properties=properties,
+ ),
+ },
)
- m = local_mapper.registry.map_imperatively(
- versioned_cls,
- history_table,
- inherits=super_history_mapper,
- polymorphic_identity=local_mapper.polymorphic_identity,
- polymorphic_on=polymorphic_on,
- properties=properties,
- )
- cls.__history_mapper__ = m
+ cls.__history_mapper__ = versioned_cls.__mapper__
class Versioned:
@@ -201,9 +202,17 @@ class Versioned:
are used for new rows even for rows that have been deleted."""
def __init_subclass__(cls) -> None:
- @event.listens_for(cls, "after_mapper_constructed")
- def _mapper_constructed(mapper, class_):
- _history_mapper(mapper)
+ insp = inspect(cls, raiseerr=False)
+
+ if insp is not None:
+ _history_mapper(insp)
+ else:
+
+ @event.listens_for(cls, "after_mapper_constructed")
+ def _mapper_constructed(mapper, class_):
+ _history_mapper(mapper)
+
+ super().__init_subclass__()
def versioned_objects(iter_):
diff --git a/examples/versioned_history/test_versioning.py b/examples/versioned_history/test_versioning.py
index 8f9635592..9caadc043 100644
--- a/examples/versioned_history/test_versioning.py
+++ b/examples/versioned_history/test_versioning.py
@@ -8,12 +8,15 @@ from sqlalchemy import Boolean
from sqlalchemy import Column
from sqlalchemy import create_engine
from sqlalchemy import ForeignKey
+from sqlalchemy import inspect
from sqlalchemy import Integer
+from sqlalchemy import join
from sqlalchemy import select
from sqlalchemy import String
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import clear_mappers
from sqlalchemy.orm import column_property
+from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.orm import deferred
from sqlalchemy.orm import exc as orm_exc
from sqlalchemy.orm import relationship
@@ -21,6 +24,8 @@ from sqlalchemy.orm import Session
from sqlalchemy.testing import assert_raises
from sqlalchemy.testing import AssertsCompiledSQL
from sqlalchemy.testing import eq_
+from sqlalchemy.testing import eq_ignore_whitespace
+from sqlalchemy.testing import is_
from sqlalchemy.testing import ne_
from sqlalchemy.testing.entities import ComparableEntity
from .history_meta import Versioned
@@ -37,7 +42,7 @@ class TestVersioning(AssertsCompiledSQL):
self.engine = engine = create_engine("sqlite://")
self.session = Session(engine)
- self.Base = declarative_base()
+ self.make_base()
versioned_session(self.session)
def tearDown(self):
@@ -45,6 +50,9 @@ class TestVersioning(AssertsCompiledSQL):
clear_mappers()
self.Base.metadata.drop_all(self.engine)
+ def make_base(self):
+ self.Base = declarative_base()
+
def create_tables(self):
self.Base.metadata.create_all(self.engine)
@@ -120,6 +128,37 @@ class TestVersioning(AssertsCompiledSQL):
],
)
+ def test_discussion_9546(self):
+ class ThingExternal(Versioned, self.Base):
+ __tablename__ = "things_external"
+ id = Column(Integer, primary_key=True)
+ external_attribute = Column(String)
+
+ class ThingLocal(Versioned, self.Base):
+ __tablename__ = "things_local"
+ id = Column(
+ Integer, ForeignKey(ThingExternal.id), primary_key=True
+ )
+ internal_attribute = Column(String)
+
+ is_(ThingExternal.__table__, inspect(ThingExternal).local_table)
+
+ class Thing(self.Base):
+ __table__ = join(ThingExternal, ThingLocal)
+ id = column_property(ThingExternal.id, ThingLocal.id)
+ version = column_property(
+ ThingExternal.version, ThingLocal.version
+ )
+
+ eq_ignore_whitespace(
+ str(select(Thing)),
+ "SELECT things_external.id, things_local.id AS id_1, "
+ "things_external.external_attribute, things_external.version, "
+ "things_local.version AS version_1, "
+ "things_local.internal_attribute FROM things_external "
+ "JOIN things_local ON things_external.id = things_local.id",
+ )
+
def test_w_mapper_versioning(self):
class SomeClass(Versioned, self.Base, ComparableEntity):
__tablename__ = "sometable"
@@ -750,7 +789,19 @@ class TestVersioning(AssertsCompiledSQL):
sess.commit()
-class TestVersioningUnittest(unittest.TestCase, TestVersioning):
+class TestVersioningNewBase(TestVersioning):
+ def make_base(self):
+ class Base(DeclarativeBase):
+ pass
+
+ self.Base = Base
+
+
+class TestVersioningUnittest(TestVersioning, unittest.TestCase):
+ pass
+
+
+class TestVersioningNewBaseUnittest(TestVersioningNewBase, unittest.TestCase):
pass
diff --git a/test/base/test_examples.py b/test/base/test_examples.py
index 50f0c01f2..4baddfb10 100644
--- a/test/base/test_examples.py
+++ b/test/base/test_examples.py
@@ -15,9 +15,17 @@ test_versioning = __import__(
).versioned_history.test_versioning
-class VersionedRowsTest(
+class VersionedRowsTestLegacyBase(
test_versioning.TestVersioning,
fixtures.RemoveORMEventsGlobally,
fixtures.TestBase,
):
pass
+
+
+class VersionedRowsTestNewBase(
+ test_versioning.TestVersioningNewBase,
+ fixtures.RemoveORMEventsGlobally,
+ fixtures.TestBase,
+):
+ pass