diff options
Diffstat (limited to 'test/orm/inheritance/abc_inheritance.py')
-rw-r--r-- | test/orm/inheritance/abc_inheritance.py | 166 |
1 files changed, 166 insertions, 0 deletions
diff --git a/test/orm/inheritance/abc_inheritance.py b/test/orm/inheritance/abc_inheritance.py new file mode 100644 index 000000000..3b35b3713 --- /dev/null +++ b/test/orm/inheritance/abc_inheritance.py @@ -0,0 +1,166 @@ +import testbase +from sqlalchemy import * +from sqlalchemy.orm import * +from sqlalchemy.orm.sync import ONETOMANY, MANYTOONE +from testlib import * + +def produce_test(parent, child, direction): + """produce a testcase for A->B->C inheritance with a self-referential + relationship between two of the classes, using either one-to-many or + many-to-one.""" + class ABCTest(ORMTest): + def define_tables(self, meta): + global ta, tb, tc + ta = ["a", meta] + ta.append(Column('id', Integer, primary_key=True)), + ta.append(Column('a_data', String(30))) + if "a"== parent and direction == MANYTOONE: + ta.append(Column('child_id', Integer, ForeignKey("%s.id" % child, use_alter=True, name="foo"))) + elif "a" == child and direction == ONETOMANY: + ta.append(Column('parent_id', Integer, ForeignKey("%s.id" % parent, use_alter=True, name="foo"))) + ta = Table(*ta) + + tb = ["b", meta] + tb.append(Column('id', Integer, ForeignKey("a.id"), primary_key=True, )) + + tb.append(Column('b_data', String(30))) + + if "b"== parent and direction == MANYTOONE: + tb.append(Column('child_id', Integer, ForeignKey("%s.id" % child, use_alter=True, name="foo"))) + elif "b" == child and direction == ONETOMANY: + tb.append(Column('parent_id', Integer, ForeignKey("%s.id" % parent, use_alter=True, name="foo"))) + tb = Table(*tb) + + tc = ["c", meta] + tc.append(Column('id', Integer, ForeignKey("b.id"), primary_key=True, )) + + tc.append(Column('c_data', String(30))) + + if "c"== parent and direction == MANYTOONE: + tc.append(Column('child_id', Integer, ForeignKey("%s.id" % child, use_alter=True, name="foo"))) + elif "c" == child and direction == ONETOMANY: + tc.append(Column('parent_id', Integer, ForeignKey("%s.id" % parent, use_alter=True, name="foo"))) + tc = Table(*tc) + + def tearDown(self): + if direction == MANYTOONE: + parent_table = {"a":ta, "b":tb, "c": tc}[parent] + parent_table.update(values={parent_table.c.child_id:None}).execute() + elif direction == ONETOMANY: + child_table = {"a":ta, "b":tb, "c": tc}[child] + child_table.update(values={child_table.c.parent_id:None}).execute() + super(ABCTest, self).tearDown() + + def test_roundtrip(self): + parent_table = {"a":ta, "b":tb, "c": tc}[parent] + child_table = {"a":ta, "b":tb, "c": tc}[child] + + remote_side = None + + if direction == MANYTOONE: + foreign_keys = [parent_table.c.child_id] + elif direction == ONETOMANY: + foreign_keys = [child_table.c.parent_id] + + atob = ta.c.id==tb.c.id + btoc = tc.c.id==tb.c.id + + if direction == ONETOMANY: + relationjoin = parent_table.c.id==child_table.c.parent_id + elif direction == MANYTOONE: + relationjoin = parent_table.c.child_id==child_table.c.id + if parent is child: + remote_side = [child_table.c.id] + + abcjoin = polymorphic_union( + {"a":ta.select(tb.c.id==None, from_obj=[ta.outerjoin(tb, onclause=atob)]), + "b":ta.join(tb, onclause=atob).outerjoin(tc, onclause=btoc).select(tc.c.id==None, fold_equivalents=True), + "c":tc.join(tb, onclause=btoc).join(ta, onclause=atob) + },"type", "abcjoin" + ) + + bcjoin = polymorphic_union( + { + "b":ta.join(tb, onclause=atob).outerjoin(tc, onclause=btoc).select(tc.c.id==None, fold_equivalents=True), + "c":tc.join(tb, onclause=btoc).join(ta, onclause=atob) + },"type", "bcjoin" + ) + class A(object): + def __init__(self, name): + self.a_data = name + class B(A):pass + class C(B):pass + + mapper(A, ta, polymorphic_on=abcjoin.c.type, select_table=abcjoin, polymorphic_identity="a") + mapper(B, tb, polymorphic_on=bcjoin.c.type, select_table=bcjoin, polymorphic_identity="b", inherits=A, inherit_condition=atob) + mapper(C, tc, polymorphic_identity="c", inherits=B, inherit_condition=btoc) + + parent_mapper = class_mapper({ta:A, tb:B, tc:C}[parent_table]) + child_mapper = class_mapper({ta:A, tb:B, tc:C}[child_table]) + + parent_class = parent_mapper.class_ + child_class = child_mapper.class_ + + parent_mapper.add_property("collection", relation(child_mapper, primaryjoin=relationjoin, foreign_keys=foreign_keys, remote_side=remote_side, uselist=True)) + + sess = create_session() + + parent_obj = parent_class('parent1') + child_obj = child_class('child1') + somea = A('somea') + someb = B('someb') + somec = C('somec') + print "APPENDING", parent.__class__.__name__ , "TO", child.__class__.__name__ + sess.save(parent_obj) + parent_obj.collection.append(child_obj) + if direction == ONETOMANY: + child2 = child_class('child2') + parent_obj.collection.append(child2) + sess.save(child2) + elif direction == MANYTOONE: + parent2 = parent_class('parent2') + parent2.collection.append(child_obj) + sess.save(parent2) + sess.save(somea) + sess.save(someb) + sess.save(somec) + sess.flush() + sess.clear() + + # assert result via direct get() of parent object + result = sess.query(parent_class).get(parent_obj.id) + assert result.id == parent_obj.id + assert result.collection[0].id == child_obj.id + if direction == ONETOMANY: + assert result.collection[1].id == child2.id + elif direction == MANYTOONE: + result2 = sess.query(parent_class).get(parent2.id) + assert result2.id == parent2.id + assert result2.collection[0].id == child_obj.id + + sess.clear() + + # assert result via polymorphic load of parent object + result = sess.query(A).get_by(id=parent_obj.id) + assert result.id == parent_obj.id + assert result.collection[0].id == child_obj.id + if direction == ONETOMANY: + assert result.collection[1].id == child2.id + elif direction == MANYTOONE: + result2 = sess.query(A).get_by(id=parent2.id) + assert result2.id == parent2.id + assert result2.collection[0].id == child_obj.id + + ABCTest.__name__ = "Test%sTo%s%s" % (parent, child, (direction is ONETOMANY and "O2M" or "M2O")) + return ABCTest + +# test all combinations of polymorphic a/b/c related to another of a/b/c +for parent in ["a", "b", "c"]: + for child in ["a", "b", "c"]: + for direction in [ONETOMANY, MANYTOONE]: + testclass = produce_test(parent, child, direction) + exec("%s = testclass" % testclass.__name__) + + +if __name__ == "__main__": + testbase.main() |