diff options
Diffstat (limited to 'test/orm/inheritance/test_basic.py')
-rw-r--r-- | test/orm/inheritance/test_basic.py | 163 |
1 files changed, 145 insertions, 18 deletions
diff --git a/test/orm/inheritance/test_basic.py b/test/orm/inheritance/test_basic.py index 7b991a618..c9aa5fc9b 100644 --- a/test/orm/inheritance/test_basic.py +++ b/test/orm/inheritance/test_basic.py @@ -1,7 +1,7 @@ import warnings from test.lib.testing import eq_, assert_raises, assert_raises_message from sqlalchemy import * -from sqlalchemy import exc as sa_exc, util +from sqlalchemy import exc as sa_exc, util, event from sqlalchemy.orm import * from sqlalchemy.orm import exc as orm_exc, attributes from test.lib.assertsql import AllOf, CompiledSQL @@ -86,62 +86,189 @@ class PolymorphicOnNotLocalTest(fixtures.MappedTest): Column('y', String(10)), Column('xid', ForeignKey('t1.id'))) - def test_non_col_polymorphic_on(self): - class InterfaceBase(object): + @classmethod + def setup_classes(cls): + class Parent(cls.Comparable): + pass + class Child(Parent): pass + def test_non_col_polymorphic_on(self): + Parent = self.classes.Parent + t2 = self.tables.t2 assert_raises_message( sa_exc.ArgumentError, - "Column-based expression object expected " - "for argument 'polymorphic_on'; got: " - "'im not a column', type", + "Can't determine polymorphic_on " + "value 'im not a column' - no " + "attribute is mapped to this name.", mapper, - InterfaceBase, t2, polymorphic_on="im not a column" + Parent, t2, polymorphic_on="im not a column" ) - def test_bad_polymorphic_on(self): + def test_polymorphic_on_non_expr_prop(self): t2, t1 = self.tables.t2, self.tables.t1 + Parent = self.classes.Parent - class InterfaceBase(object): - pass + t1t2_join = select([t1.c.x], from_obj=[t1.join(t2)]).alias() + def go(): + interface_m = mapper(Parent, t2, + polymorphic_on=lambda:"hi", + polymorphic_identity=0) + assert_raises_message( + sa_exc.ArgumentError, + "Only direct column-mapped property or " + "SQL expression can be passed for polymorphic_on", + go + ) + + def test_polymorphic_on_not_present_col(self): + t2, t1 = self.tables.t2, self.tables.t1 + Parent = self.classes.Parent t1t2_join = select([t1.c.x], from_obj=[t1.join(t2)]).alias() def go(): - interface_m = mapper(InterfaceBase, t2, + t1t2_join_2 = select([t1.c.q], from_obj=[t1.join(t2)]).alias() + interface_m = mapper(Parent, t2, polymorphic_on=t1t2_join.c.x, + with_polymorphic=('*', t1t2_join_2), polymorphic_identity=0) - assert_raises_message( sa_exc.InvalidRequestError, "Could not map polymorphic_on column 'x' to the mapped table - " "polymorphic loads will not function properly", go ) - clear_mappers() + def test_polymorphic_on_only_in_with_poly(self): + t2, t1 = self.tables.t2, self.tables.t1 + Parent = self.classes.Parent + t1t2_join = select([t1.c.x], from_obj=[t1.join(t2)]).alias() # if its in the with_polymorphic, then its OK - interface_m = mapper(InterfaceBase, t2, + mapper(Parent, t2, polymorphic_on=t1t2_join.c.x, with_polymorphic=('*', t1t2_join), polymorphic_identity=0) - configure_mappers() - clear_mappers() + def test_polymorpic_on_not_in_with_poly(self): + t2, t1 = self.tables.t2, self.tables.t1 + Parent = self.classes.Parent + + t1t2_join = select([t1.c.x], from_obj=[t1.join(t2)]).alias() # if with_polymorphic, but its not present, not OK def go(): t1t2_join_2 = select([t1.c.q], from_obj=[t1.join(t2)]).alias() - interface_m = mapper(InterfaceBase, t2, + interface_m = mapper(Parent, t2, polymorphic_on=t1t2_join.c.x, with_polymorphic=('*', t1t2_join_2), polymorphic_identity=0) assert_raises_message( sa_exc.InvalidRequestError, - "Could not map polymorphic_on column 'x' to the mapped table - " + "Could not map polymorphic_on column 'x' " + "to the mapped table - " "polymorphic loads will not function properly", go ) + def test_polymorphic_on_expr_explicit_map(self): + t2, t1 = self.tables.t2, self.tables.t1 + Parent, Child = self.classes.Parent, self.classes.Child + expr = case([ + (t1.c.x=="p", "parent"), + (t1.c.x=="c", "child"), + ],else_ = t1.c.x) + mapper(Parent, t1, properties={ + "discriminator":column_property(expr) + }, polymorphic_identity="parent", + polymorphic_on=expr) + mapper(Child, t2, inherits=Parent, + polymorphic_identity="child") + + self._roundtrip() + + def test_polymorphic_on_expr_implicit_map(self): + t2, t1 = self.tables.t2, self.tables.t1 + Parent, Child = self.classes.Parent, self.classes.Child + expr = case([ + (t1.c.x=="p", "parent"), + (t1.c.x=="c", "child"), + ],else_ = t1.c.x).label("foo") + mapper(Parent, t1, polymorphic_identity="parent", + polymorphic_on=expr) + mapper(Child, t2, inherits=Parent, polymorphic_identity="child") + + self._roundtrip() + + def test_polymorphic_on_column_prop(self): + t2, t1 = self.tables.t2, self.tables.t1 + Parent, Child = self.classes.Parent, self.classes.Child + expr = case([ + (t1.c.x=="p", "parent"), + (t1.c.x=="c", "child"), + ],else_ = t1.c.x) + cprop = column_property(expr) + mapper(Parent, t1, properties={ + "discriminator":cprop + }, polymorphic_identity="parent", + polymorphic_on=cprop) + mapper(Child, t2, inherits=Parent, + polymorphic_identity="child") + + self._roundtrip() + + def test_polymorphic_on_column_str_prop(self): + t2, t1 = self.tables.t2, self.tables.t1 + Parent, Child = self.classes.Parent, self.classes.Child + expr = case([ + (t1.c.x=="p", "parent"), + (t1.c.x=="c", "child"), + ],else_ = t1.c.x) + cprop = column_property(expr) + mapper(Parent, t1, properties={ + "discriminator":cprop + }, polymorphic_identity="parent", + polymorphic_on="discriminator") + mapper(Child, t2, inherits=Parent, + polymorphic_identity="child") + + self._roundtrip() + + def test_polymorphic_on_synonym(self): + t2, t1 = self.tables.t2, self.tables.t1 + Parent, Child = self.classes.Parent, self.classes.Child + cprop = column_property(t1.c.x) + assert_raises_message( + sa_exc.ArgumentError, + "Only direct column-mapped property or " + "SQL expression can be passed for polymorphic_on", + mapper, Parent, t1, properties={ + "discriminator":cprop, + "discrim_syn":synonym(cprop) + }, polymorphic_identity="parent", + polymorphic_on="discrim_syn") + + def _roundtrip(self, set_event=True): + Parent, Child = self.classes.Parent, self.classes.Child + + if set_event: + @event.listens_for(Parent, "init", propagate=True) + def set_identity(instance, *arg, **kw): + instance.x = object_mapper(instance).polymorphic_identity + + s = Session(testing.db) + s.add_all([ + Parent(q="p1"), + Child(q="c1", y="c1"), + Parent(q="p2"), + ]) + s.commit() + s.close() + + eq_( + [type(t) for t in s.query(Parent).order_by(Parent.id)], + [Parent, Child, Parent] + ) + class FalseDiscriminatorTest(fixtures.MappedTest): @classmethod |