diff options
Diffstat (limited to 'test/orm/entity.py')
-rw-r--r-- | test/orm/entity.py | 170 |
1 files changed, 170 insertions, 0 deletions
diff --git a/test/orm/entity.py b/test/orm/entity.py new file mode 100644 index 000000000..f425c5c8e --- /dev/null +++ b/test/orm/entity.py @@ -0,0 +1,170 @@ +from testbase import PersistTest, AssertMixin +import unittest +from sqlalchemy import * +import testbase +from sqlalchemy.ext.sessioncontext import SessionContext + +from tables import * +import tables + +class EntityTest(AssertMixin): + """tests mappers that are constructed based on "entity names", which allows the same class + to have multiple primary mappers """ + def setUpAll(self): + global user1, user2, address1, address2, metadata, ctx + metadata = BoundMetaData(testbase.db) + ctx = SessionContext(create_session) + + user1 = Table('user1', metadata, + Column('user_id', Integer, Sequence('user1_id_seq'), primary_key=True), + Column('name', String(60), nullable=False) + ) + user2 = Table('user2', metadata, + Column('user_id', Integer, Sequence('user2_id_seq'), primary_key=True), + Column('name', String(60), nullable=False) + ) + address1 = Table('address1', metadata, + Column('address_id', Integer, Sequence('address1_id_seq'), primary_key=True), + Column('user_id', Integer, ForeignKey(user1.c.user_id), nullable=False), + Column('email', String(100), nullable=False) + ) + address2 = Table('address2', metadata, + Column('address_id', Integer, Sequence('address2_id_seq'), primary_key=True), + Column('user_id', Integer, ForeignKey(user2.c.user_id), nullable=False), + Column('email', String(100), nullable=False) + ) + metadata.create_all() + def tearDownAll(self): + metadata.drop_all() + def tearDown(self): + clear_mappers() + for t in metadata.table_iterator(reverse=True): + t.delete().execute() + + def testbasic(self): + """tests a pair of one-to-many mapper structures, establishing that both + parent and child objects honor the "entity_name" attribute attached to the object + instances.""" + class User(object):pass + class Address(object):pass + + a1mapper = mapper(Address, address1, entity_name='address1', extension=ctx.mapper_extension) + a2mapper = mapper(Address, address2, entity_name='address2', extension=ctx.mapper_extension) + u1mapper = mapper(User, user1, entity_name='user1', properties ={ + 'addresses':relation(a1mapper) + }, extension=ctx.mapper_extension) + u2mapper =mapper(User, user2, entity_name='user2', properties={ + 'addresses':relation(a2mapper) + }, extension=ctx.mapper_extension) + + u1 = User(_sa_entity_name='user1') + u1.name = 'this is user 1' + a1 = Address(_sa_entity_name='address1') + a1.email='a1@foo.com' + u1.addresses.append(a1) + + u2 = User(_sa_entity_name='user2') + u2.name='this is user 2' + a2 = Address(_sa_entity_name='address2') + a2.email='a2@foo.com' + u2.addresses.append(a2) + + ctx.current.flush() + assert user1.select().execute().fetchall() == [(u1.user_id, u1.name)] + assert user2.select().execute().fetchall() == [(u2.user_id, u2.name)] + assert address1.select().execute().fetchall() == [(u1.user_id, a1.user_id, 'a1@foo.com')] + assert address2.select().execute().fetchall() == [(u2.user_id, a2.user_id, 'a2@foo.com')] + + ctx.current.clear() + u1list = ctx.current.query(User, entity_name='user1').select() + u2list = ctx.current.query(User, entity_name='user2').select() + assert len(u1list) == len(u2list) == 1 + assert u1list[0] is not u2list[0] + assert len(u1list[0].addresses) == len(u2list[0].addresses) == 1 + + def testcascade(self): + """same as testbasic but relies on session cascading""" + class User(object):pass + class Address(object):pass + + a1mapper = mapper(Address, address1, entity_name='address1') + a2mapper = mapper(Address, address2, entity_name='address2') + u1mapper = mapper(User, user1, entity_name='user1', properties ={ + 'addresses':relation(a1mapper) + }) + u2mapper =mapper(User, user2, entity_name='user2', properties={ + 'addresses':relation(a2mapper) + }) + + sess = create_session() + u1 = User() + u1.name = 'this is user 1' + sess.save(u1, entity_name='user1') + a1 = Address() + a1.email='a1@foo.com' + u1.addresses.append(a1) + + u2 = User() + u2.name='this is user 2' + a2 = Address() + a2.email='a2@foo.com' + u2.addresses.append(a2) + sess.save(u2, entity_name='user2') + + sess.flush() + assert user1.select().execute().fetchall() == [(u1.user_id, u1.name)] + assert user2.select().execute().fetchall() == [(u2.user_id, u2.name)] + assert address1.select().execute().fetchall() == [(u1.user_id, a1.user_id, 'a1@foo.com')] + assert address2.select().execute().fetchall() == [(u2.user_id, a2.user_id, 'a2@foo.com')] + + sess.clear() + u1list = sess.query(User, entity_name='user1').select() + u2list = sess.query(User, entity_name='user2').select() + assert len(u1list) == len(u2list) == 1 + assert u1list[0] is not u2list[0] + assert len(u1list[0].addresses) == len(u2list[0].addresses) == 1 + + def testpolymorphic(self): + """tests that entity_name can be used to have two kinds of relations on the same class.""" + class User(object):pass + class Address1(object):pass + class Address2(object):pass + + a1mapper = mapper(Address1, address1, extension=ctx.mapper_extension) + a2mapper = mapper(Address2, address2, extension=ctx.mapper_extension) + u1mapper = mapper(User, user1, entity_name='user1', properties ={ + 'addresses':relation(a1mapper) + }, extension=ctx.mapper_extension) + u2mapper =mapper(User, user2, entity_name='user2', properties={ + 'addresses':relation(a2mapper) + }, extension=ctx.mapper_extension) + + u1 = User(_sa_entity_name='user1') + u1.name = 'this is user 1' + a1 = Address1() + a1.email='a1@foo.com' + u1.addresses.append(a1) + + u2 = User(_sa_entity_name='user2') + u2.name='this is user 2' + a2 = Address2() + a2.email='a2@foo.com' + u2.addresses.append(a2) + + ctx.current.flush() + assert user1.select().execute().fetchall() == [(u1.user_id, u1.name)] + assert user2.select().execute().fetchall() == [(u2.user_id, u2.name)] + assert address1.select().execute().fetchall() == [(u1.user_id, a1.user_id, 'a1@foo.com')] + assert address2.select().execute().fetchall() == [(u2.user_id, a2.user_id, 'a2@foo.com')] + + ctx.current.clear() + u1list = ctx.current.query(User, entity_name='user1').select() + u2list = ctx.current.query(User, entity_name='user2').select() + assert len(u1list) == len(u2list) == 1 + assert u1list[0] is not u2list[0] + assert len(u1list[0].addresses) == len(u2list[0].addresses) == 1 + assert isinstance(u1list[0].addresses[0], Address1) + assert isinstance(u2list[0].addresses[0], Address2) + +if __name__ == "__main__": + testbase.main() |