summaryrefslogtreecommitdiff
path: root/test/orm/inheritance/basic.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/orm/inheritance/basic.py')
-rw-r--r--test/orm/inheritance/basic.py409
1 files changed, 409 insertions, 0 deletions
diff --git a/test/orm/inheritance/basic.py b/test/orm/inheritance/basic.py
new file mode 100644
index 000000000..be623e1b8
--- /dev/null
+++ b/test/orm/inheritance/basic.py
@@ -0,0 +1,409 @@
+import testbase
+from sqlalchemy import *
+from sqlalchemy.orm import *
+from testlib import *
+
+
+class O2MTest(ORMTest):
+ """deals with inheritance and one-to-many relationships"""
+ def define_tables(self, metadata):
+ global foo, bar, blub
+ # the 'data' columns are to appease SQLite which cant handle a blank INSERT
+ foo = Table('foo', metadata,
+ Column('id', Integer, Sequence('foo_seq'), primary_key=True),
+ Column('data', String(20)))
+
+ bar = Table('bar', metadata,
+ Column('id', Integer, ForeignKey('foo.id'), primary_key=True),
+ Column('data', String(20)))
+
+ blub = Table('blub', metadata,
+ Column('id', Integer, ForeignKey('bar.id'), primary_key=True),
+ Column('foo_id', Integer, ForeignKey('foo.id'), nullable=False),
+ Column('data', String(20)))
+
+ def testbasic(self):
+ class Foo(object):
+ def __init__(self, data=None):
+ self.data = data
+ def __repr__(self):
+ return "Foo id %d, data %s" % (self.id, self.data)
+ mapper(Foo, foo)
+
+ class Bar(Foo):
+ def __repr__(self):
+ return "Bar id %d, data %s" % (self.id, self.data)
+
+ mapper(Bar, bar, inherits=Foo)
+
+ class Blub(Bar):
+ def __repr__(self):
+ return "Blub id %d, data %s" % (self.id, self.data)
+
+ mapper(Blub, blub, inherits=Bar, properties={
+ 'parent_foo':relation(Foo)
+ })
+
+ sess = create_session()
+ b1 = Blub("blub #1")
+ b2 = Blub("blub #2")
+ f = Foo("foo #1")
+ sess.save(b1)
+ sess.save(b2)
+ sess.save(f)
+ b1.parent_foo = f
+ b2.parent_foo = f
+ sess.flush()
+ compare = repr(b1) + repr(b2) + repr(b1.parent_foo) + repr(b2.parent_foo)
+ sess.clear()
+ l = sess.query(Blub).select()
+ result = repr(l[0]) + repr(l[1]) + repr(l[0].parent_foo) + repr(l[1].parent_foo)
+ print result
+ self.assert_(compare == result)
+ self.assert_(l[0].parent_foo.data == 'foo #1' and l[1].parent_foo.data == 'foo #1')
+
+class GetTest(ORMTest):
+ def define_tables(self, metadata):
+ global foo, bar, blub
+ foo = Table('foo', metadata,
+ Column('id', Integer, Sequence('foo_seq'), primary_key=True),
+ Column('type', String(30)),
+ Column('data', String(20)))
+
+ bar = Table('bar', metadata,
+ Column('id', Integer, ForeignKey('foo.id'), primary_key=True),
+ Column('data', String(20)))
+
+ blub = Table('blub', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('foo_id', Integer, ForeignKey('foo.id')),
+ Column('bar_id', Integer, ForeignKey('bar.id')),
+ Column('data', String(20)))
+
+ def create_test(polymorphic):
+ def test_get(self):
+ class Foo(object):
+ pass
+
+ class Bar(Foo):
+ pass
+
+ class Blub(Bar):
+ pass
+
+ if polymorphic:
+ mapper(Foo, foo, polymorphic_on=foo.c.type, polymorphic_identity='foo')
+ mapper(Bar, bar, inherits=Foo, polymorphic_identity='bar')
+ mapper(Blub, blub, inherits=Bar, polymorphic_identity='blub')
+ else:
+ mapper(Foo, foo)
+ mapper(Bar, bar, inherits=Foo)
+ mapper(Blub, blub, inherits=Bar)
+
+ sess = create_session()
+ f = Foo()
+ b = Bar()
+ bl = Blub()
+ sess.save(f)
+ sess.save(b)
+ sess.save(bl)
+ sess.flush()
+
+ if polymorphic:
+ def go():
+ assert sess.query(Foo).get(f.id) == f
+ assert sess.query(Foo).get(b.id) == b
+ assert sess.query(Foo).get(bl.id) == bl
+ assert sess.query(Bar).get(b.id) == b
+ assert sess.query(Bar).get(bl.id) == bl
+ assert sess.query(Blub).get(bl.id) == bl
+
+ self.assert_sql_count(testbase.db, go, 0)
+ else:
+ # this is testing the 'wrong' behavior of using get()
+ # polymorphically with mappers that are not configured to be
+ # polymorphic. the important part being that get() always
+ # returns an instance of the query's type.
+ def go():
+ assert sess.query(Foo).get(f.id) == f
+
+ bb = sess.query(Foo).get(b.id)
+ assert isinstance(b, Foo) and bb.id==b.id
+
+ bll = sess.query(Foo).get(bl.id)
+ assert isinstance(bll, Foo) and bll.id==bl.id
+
+ assert sess.query(Bar).get(b.id) == b
+
+ bll = sess.query(Bar).get(bl.id)
+ assert isinstance(bll, Bar) and bll.id == bl.id
+
+ assert sess.query(Blub).get(bl.id) == bl
+
+ self.assert_sql_count(testbase.db, go, 3)
+
+ return test_get
+
+ test_get_polymorphic = create_test(True)
+ test_get_nonpolymorphic = create_test(False)
+
+
+class ConstructionTest(ORMTest):
+ def define_tables(self, metadata):
+ global content_type, content, product
+ content_type = Table('content_type', metadata,
+ Column('id', Integer, primary_key=True)
+ )
+ content = Table('content', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('content_type_id', Integer, ForeignKey('content_type.id')),
+ Column('type', String(30))
+ )
+ product = Table('product', metadata,
+ Column('id', Integer, ForeignKey('content.id'), primary_key=True)
+ )
+
+ def testbasic(self):
+ class ContentType(object): pass
+ class Content(object): pass
+ class Product(Content): pass
+
+ content_types = mapper(ContentType, content_type)
+ contents = mapper(Content, content, properties={
+ 'content_type':relation(content_types)
+ }, polymorphic_identity='contents')
+
+ products = mapper(Product, product, inherits=contents, polymorphic_identity='products')
+
+ try:
+ compile_mappers()
+ assert False
+ except exceptions.ArgumentError, e:
+ assert str(e) == "Mapper 'Mapper|Content|content' specifies a polymorphic_identity of 'contents', but no mapper in it's hierarchy specifies the 'polymorphic_on' column argument"
+
+ def testbackref(self):
+ """tests adding a property to the superclass mapper"""
+ class ContentType(object): pass
+ class Content(object): pass
+ class Product(Content): pass
+
+ contents = mapper(Content, content, polymorphic_on=content.c.type, polymorphic_identity='content')
+ products = mapper(Product, product, inherits=contents, polymorphic_identity='product')
+ content_types = mapper(ContentType, content_type, properties={
+ 'content':relation(contents, backref='contenttype')
+ })
+ p = Product()
+ p.contenttype = ContentType()
+ # TODO: assertion ??
+
+class EagerLazyTest(ORMTest):
+ """tests eager load/lazy load of child items off inheritance mappers, tests that
+ LazyLoader constructs the right query condition."""
+ def define_tables(self, metadata):
+ global foo, bar, bar_foo
+ foo = Table('foo', metadata, Column('id', Integer, Sequence('foo_seq'), primary_key=True),
+ Column('data', String(30)))
+ bar = Table('bar', metadata, Column('id', Integer, ForeignKey('foo.id'), primary_key=True),
+ Column('data', String(30)))
+
+ bar_foo = Table('bar_foo', metadata,
+ Column('bar_id', Integer, ForeignKey('bar.id')),
+ Column('foo_id', Integer, ForeignKey('foo.id'))
+ )
+
+ def testbasic(self):
+ class Foo(object): pass
+ class Bar(Foo): pass
+
+ foos = mapper(Foo, foo)
+ bars = mapper(Bar, bar, inherits=foos)
+ bars.add_property('lazy', relation(foos, bar_foo, lazy=True))
+ bars.add_property('eager', relation(foos, bar_foo, lazy=False))
+
+ foo.insert().execute(data='foo1')
+ bar.insert().execute(id=1, data='bar1')
+
+ foo.insert().execute(data='foo2')
+ bar.insert().execute(id=2, data='bar2')
+
+ foo.insert().execute(data='foo3') #3
+ foo.insert().execute(data='foo4') #4
+
+ bar_foo.insert().execute(bar_id=1, foo_id=3)
+ bar_foo.insert().execute(bar_id=2, foo_id=4)
+
+ sess = create_session()
+ q = sess.query(Bar)
+ self.assert_(len(q.selectfirst().lazy) == 1)
+ self.assert_(len(q.selectfirst().eager) == 1)
+
+
+class FlushTest(ORMTest):
+ """test dependency sorting among inheriting mappers"""
+ def define_tables(self, metadata):
+ global users, roles, user_roles, admins
+ users = Table('users', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('email', String(128)),
+ Column('password', String(16)),
+ )
+
+ roles = Table('role', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('description', String(32))
+ )
+
+ user_roles = Table('user_role', metadata,
+ Column('user_id', Integer, ForeignKey('users.id'), primary_key=True),
+ Column('role_id', Integer, ForeignKey('role.id'), primary_key=True)
+ )
+
+ admins = Table('admin', metadata,
+ Column('admin_id', Integer, primary_key=True),
+ Column('user_id', Integer, ForeignKey('users.id'))
+ )
+
+ def testone(self):
+ class User(object):pass
+ class Role(object):pass
+ class Admin(User):pass
+ role_mapper = mapper(Role, roles)
+ user_mapper = mapper(User, users, properties = {
+ 'roles' : relation(Role, secondary=user_roles, lazy=False, private=False)
+ }
+ )
+ admin_mapper = mapper(Admin, admins, inherits=user_mapper)
+ sess = create_session()
+ adminrole = Role('admin')
+ sess.save(adminrole)
+ sess.flush()
+
+ # create an Admin, and append a Role. the dependency processors
+ # corresponding to the "roles" attribute for the Admin mapper and the User mapper
+ # have to ensure that two dependency processors dont fire off and insert the
+ # many to many row twice.
+ a = Admin()
+ a.roles.append(adminrole)
+ a.password = 'admin'
+ sess.save(a)
+ sess.flush()
+
+ assert user_roles.count().scalar() == 1
+
+ def testtwo(self):
+ class User(object):
+ def __init__(self, email=None, password=None):
+ self.email = email
+ self.password = password
+
+ class Role(object):
+ def __init__(self, description=None):
+ self.description = description
+
+ class Admin(User):pass
+
+ role_mapper = mapper(Role, roles)
+ user_mapper = mapper(User, users, properties = {
+ 'roles' : relation(Role, secondary=user_roles, lazy=False, private=False)
+ }
+ )
+
+ admin_mapper = mapper(Admin, admins, inherits=user_mapper)
+
+ # create roles
+ adminrole = Role('admin')
+
+ sess = create_session()
+ sess.save(adminrole)
+ sess.flush()
+
+ # create admin user
+ a = Admin(email='tim', password='admin')
+ a.roles.append(adminrole)
+ sess.save(a)
+ sess.flush()
+
+ a.password = 'sadmin'
+ sess.flush()
+ assert user_roles.count().scalar() == 1
+
+class DistinctPKTest(ORMTest):
+ """test the construction of mapper.primary_key when an inheriting relationship
+ joins on a column other than primary key column."""
+ keep_data = True
+
+ def define_tables(self, metadata):
+ global person_table, employee_table, Person, Employee
+
+ person_table = Table("persons", metadata,
+ Column("id", Integer, primary_key=True),
+ Column("name", String(80)),
+ )
+
+ employee_table = Table("employees", metadata,
+ Column("id", Integer, primary_key=True),
+ Column("salary", Integer),
+ Column("person_id", Integer, ForeignKey("persons.id")),
+ )
+
+ class Person(object):
+ def __init__(self, name):
+ self.name = name
+
+ class Employee(Person): pass
+
+ import warnings
+ warnings.filterwarnings("error", r".*On mapper.*distinct primary key")
+
+ def insert_data(self):
+ person_insert = person_table.insert()
+ person_insert.execute(id=1, name='alice')
+ person_insert.execute(id=2, name='bob')
+
+ employee_insert = employee_table.insert()
+ employee_insert.execute(id=2, salary=250, person_id=1) # alice
+ employee_insert.execute(id=3, salary=200, person_id=2) # bob
+
+ def test_implicit(self):
+ person_mapper = mapper(Person, person_table)
+ mapper(Employee, employee_table, inherits=person_mapper)
+ assert list(class_mapper(Employee).primary_key) == [person_table.c.id]
+
+ def test_explicit_props(self):
+ person_mapper = mapper(Person, person_table)
+ mapper(Employee, employee_table, inherits=person_mapper, properties={'pid':person_table.c.id, 'eid':employee_table.c.id})
+ self._do_test(True)
+
+ def test_explicit_composite_pk(self):
+ person_mapper = mapper(Person, person_table)
+ mapper(Employee, employee_table, inherits=person_mapper, primary_key=[person_table.c.id, employee_table.c.id])
+ try:
+ self._do_test(True)
+ assert False
+ except RuntimeWarning, e:
+ assert str(e) == "On mapper Mapper|Employee|employees, primary key column 'employees.id' is being combined with distinct primary key column 'persons.id' in attribute 'id'. Use explicit properties to give each column its own mapped attribute name."
+
+ def test_explicit_pk(self):
+ person_mapper = mapper(Person, person_table)
+ mapper(Employee, employee_table, inherits=person_mapper, primary_key=[person_table.c.id])
+ self._do_test(False)
+
+ def _do_test(self, composite):
+ session = create_session()
+ query = session.query(Employee)
+
+ if composite:
+ alice1 = query.get([1,2])
+ bob = query.get([2,3])
+ alice2 = query.get([1,2])
+ else:
+ alice1 = query.get(1)
+ bob = query.get(2)
+ alice2 = query.get(1)
+
+ assert alice1.name == alice2.name == 'alice'
+ assert bob.name == 'bob'
+
+
+if __name__ == "__main__":
+ testbase.main()