summaryrefslogtreecommitdiff
path: root/test/orm/entity.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2008-05-09 16:34:10 +0000
committerMike Bayer <mike_mp@zzzcomputing.com>2008-05-09 16:34:10 +0000
commit4a6afd469fad170868554bf28578849bf3dfd5dd (patch)
treeb396edc33d567ae19dd244e87137296450467725 /test/orm/entity.py
parent46b7c9dc57a38d5b9e44a4723dad2ad8ec57baca (diff)
downloadsqlalchemy-4a6afd469fad170868554bf28578849bf3dfd5dd.tar.gz
r4695 merged to trunk; trunk now becomes 0.5.
0.4 development continues at /sqlalchemy/branches/rel_0_4
Diffstat (limited to 'test/orm/entity.py')
-rw-r--r--test/orm/entity.py127
1 files changed, 90 insertions, 37 deletions
diff --git a/test/orm/entity.py b/test/orm/entity.py
index 760f8fce9..d9c9e4002 100644
--- a/test/orm/entity.py
+++ b/test/orm/entity.py
@@ -1,19 +1,18 @@
import testenv; testenv.configure_for_tests()
from sqlalchemy import *
from sqlalchemy.orm import *
-from sqlalchemy.ext.sessioncontext import SessionContext
from testlib import *
from testlib.tables import *
+from testlib import fixtures
class EntityTest(TestBase, AssertsExecutionResults):
"""tests mappers that are constructed based on "entity names", which allows the same class
to have multiple primary mappers """
- @testing.uses_deprecated('SessionContext')
def setUpAll(self):
global user1, user2, address1, address2, metadata, ctx
metadata = MetaData(testing.db)
- ctx = SessionContext(create_session)
+ ctx = scoped_session(create_session)
user1 = Table('user1', metadata,
Column('user_id', Integer, Sequence('user1_id_seq', optional=True),
@@ -45,28 +44,31 @@ class EntityTest(TestBase, AssertsExecutionResults):
def tearDownAll(self):
metadata.drop_all()
def tearDown(self):
- ctx.current.clear()
+ ctx.clear()
clear_mappers()
for t in metadata.table_iterator(reverse=True):
t.delete().execute()
- @testing.uses_deprecated('SessionContextExt')
def testbasic(self):
"""tests a pair of one-to-many mapper structures, establishing that both
parent and child objects honor the "entity_name" attribute attached to the object
instances."""
- class User(object):pass
- class Address(object):pass
+ class User(object):
+ def __init__(self, **kw):
+ pass
+ class Address(object):
+ def __init__(self, **kw):
+ pass
- a1mapper = mapper(Address, address1, entity_name='address1', extension=ctx.mapper_extension)
- a2mapper = mapper(Address, address2, entity_name='address2', extension=ctx.mapper_extension)
+ a1mapper = mapper(Address, address1, entity_name='address1', extension=ctx.extension)
+ a2mapper = mapper(Address, address2, entity_name='address2', extension=ctx.extension)
u1mapper = mapper(User, user1, entity_name='user1', properties ={
'addresses':relation(a1mapper)
- }, extension=ctx.mapper_extension)
+ }, extension=ctx.extension)
u2mapper =mapper(User, user2, entity_name='user2', properties={
'addresses':relation(a2mapper)
- }, extension=ctx.mapper_extension)
-
+ }, extension=ctx.extension)
+
u1 = User(_sa_entity_name='user1')
u1.name = 'this is user 1'
a1 = Address(_sa_entity_name='address1')
@@ -79,22 +81,22 @@ class EntityTest(TestBase, AssertsExecutionResults):
a2.email='a2@foo.com'
u2.addresses.append(a2)
- ctx.current.flush()
+ ctx.flush()
assert user1.select().execute().fetchall() == [(u1.user_id, u1.name)]
assert user2.select().execute().fetchall() == [(u2.user_id, u2.name)]
assert address1.select().execute().fetchall() == [(a1.address_id, u1.user_id, 'a1@foo.com')]
assert address2.select().execute().fetchall() == [(a1.address_id, u2.user_id, 'a2@foo.com')]
- ctx.current.clear()
- u1list = ctx.current.query(User, entity_name='user1').all()
- u2list = ctx.current.query(User, entity_name='user2').all()
+ ctx.clear()
+ u1list = ctx.query(User, entity_name='user1').all()
+ u2list = ctx.query(User, entity_name='user2').all()
assert len(u1list) == len(u2list) == 1
assert u1list[0] is not u2list[0]
assert len(u1list[0].addresses) == len(u2list[0].addresses) == 1
- u1 = ctx.current.query(User, entity_name='user1').first()
- ctx.current.refresh(u1)
- ctx.current.expire(u1)
+ u1 = ctx.query(User, entity_name='user1').first()
+ ctx.refresh(u1)
+ ctx.expire(u1)
def testcascade(self):
@@ -142,18 +144,24 @@ class EntityTest(TestBase, AssertsExecutionResults):
def testpolymorphic(self):
"""tests that entity_name can be used to have two kinds of relations on the same class."""
- class User(object):pass
- class Address1(object):pass
- class Address2(object):pass
+ class User(object):
+ def __init__(self, **kw):
+ pass
+ class Address1(object):
+ def __init__(self, **kw):
+ pass
+ class Address2(object):
+ def __init__(self, **kw):
+ pass
- a1mapper = mapper(Address1, address1, extension=ctx.mapper_extension)
- a2mapper = mapper(Address2, address2, extension=ctx.mapper_extension)
+ a1mapper = mapper(Address1, address1, extension=ctx.extension)
+ a2mapper = mapper(Address2, address2, extension=ctx.extension)
u1mapper = mapper(User, user1, entity_name='user1', properties ={
'addresses':relation(a1mapper)
- }, extension=ctx.mapper_extension)
+ }, extension=ctx.extension)
u2mapper =mapper(User, user2, entity_name='user2', properties={
'addresses':relation(a2mapper)
- }, extension=ctx.mapper_extension)
+ }, extension=ctx.extension)
u1 = User(_sa_entity_name='user1')
u1.name = 'this is user 1'
@@ -167,15 +175,15 @@ class EntityTest(TestBase, AssertsExecutionResults):
a2.email='a2@foo.com'
u2.addresses.append(a2)
- ctx.current.flush()
+ ctx.flush()
assert user1.select().execute().fetchall() == [(u1.user_id, u1.name)]
assert user2.select().execute().fetchall() == [(u2.user_id, u2.name)]
assert address1.select().execute().fetchall() == [(a1.address_id, u1.user_id, 'a1@foo.com')]
assert address2.select().execute().fetchall() == [(a1.address_id, u2.user_id, 'a2@foo.com')]
- ctx.current.clear()
- u1list = ctx.current.query(User, entity_name='user1').all()
- u2list = ctx.current.query(User, entity_name='user2').all()
+ ctx.clear()
+ u1list = ctx.query(User, entity_name='user1').all()
+ u2list = ctx.query(User, entity_name='user2').all()
assert len(u1list) == len(u2list) == 1
assert u1list[0] is not u2list[0]
assert len(u1list[0].addresses) == len(u2list[0].addresses) == 1
@@ -186,13 +194,15 @@ class EntityTest(TestBase, AssertsExecutionResults):
def testpolymorphic_deferred(self):
"""test that deferred columns load properly using entity names"""
- class User(object):pass
+ class User(object):
+ def __init__(self, **kwargs):
+ pass
u1mapper = mapper(User, user1, entity_name='user1', properties ={
'name':deferred(user1.c.name)
- }, extension=ctx.mapper_extension)
+ }, extension=ctx.extension)
u2mapper =mapper(User, user2, entity_name='user2', properties={
'name':deferred(user2.c.name)
- }, extension=ctx.mapper_extension)
+ }, extension=ctx.extension)
u1 = User(_sa_entity_name='user1')
u1.name = 'this is user 1'
@@ -200,13 +210,13 @@ class EntityTest(TestBase, AssertsExecutionResults):
u2 = User(_sa_entity_name='user2')
u2.name='this is user 2'
- ctx.current.flush()
+ ctx.flush()
assert user1.select().execute().fetchall() == [(u1.user_id, u1.name)]
assert user2.select().execute().fetchall() == [(u2.user_id, u2.name)]
- ctx.current.clear()
- u1list = ctx.current.query(User, entity_name='user1').all()
- u2list = ctx.current.query(User, entity_name='user2').all()
+ ctx.clear()
+ u1list = ctx.query(User, entity_name='user1').all()
+ u2list = ctx.query(User, entity_name='user2').all()
assert len(u1list) == len(u2list) == 1
assert u1list[0] is not u2list[0]
# the deferred column load requires that setup_loader() check that the correct DeferredColumnLoader
@@ -214,6 +224,49 @@ class EntityTest(TestBase, AssertsExecutionResults):
assert u1list[0].name == 'this is user 1'
assert u2list[0].name == 'this is user 2'
+class SelfReferentialTest(ORMTest):
+ def define_tables(self, metadata):
+ global nodes
+
+ nodes = Table('nodes', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('parent_id', Integer, ForeignKey('nodes.id')),
+ Column('data', String(50)),
+ Column('type', String(50)),
+ )
+
+ # fails inconsistently. entity name needs deterministic
+ # instrumentation.
+ def dont_test_relation(self):
+ class Node(fixtures.Base):
+ pass
+
+ foonodes = nodes.select().where(nodes.c.type=='foo').alias()
+ barnodes = nodes.select().where(nodes.c.type=='bar').alias()
+
+ # TODO: the order of instrumentation here is not deterministic;
+ # therefore the test fails sporadically since "Node.data" references
+ # different mappers at different times
+ m1 = mapper(Node, nodes)
+ m2 = mapper(Node, foonodes, entity_name='foo')
+ m3 = mapper(Node, barnodes, entity_name='bar')
+
+ m1.add_property('foonodes', relation(m2, primaryjoin=nodes.c.id==foonodes.c.parent_id,
+ backref=backref('foo_parent', remote_side=nodes.c.id, primaryjoin=nodes.c.id==foonodes.c.parent_id)))
+ m1.add_property('barnodes', relation(m3, primaryjoin=nodes.c.id==barnodes.c.parent_id,
+ backref=backref('bar_parent', remote_side=nodes.c.id, primaryjoin=nodes.c.id==barnodes.c.parent_id)))
+
+ sess = create_session()
+
+ n1 = Node(data='n1', type='bat')
+ n1.foonodes.append(Node(data='n2', type='foo'))
+ Node(data='n3', type='bar', bar_parent=n1)
+ sess.save(n1)
+ sess.flush()
+ sess.clear()
+
+ self.assertEquals(sess.query(Node, entity_name="bar").one(), Node(data='n3'))
+ self.assertEquals(sess.query(Node).filter(Node.data=='n1').one(), Node(data='n1', foonodes=[Node(data='n2')], barnodes=[Node(data='n3')]))
if __name__ == "__main__":
testenv.main()