summaryrefslogtreecommitdiff
path: root/test/orm/inheritance/test_basic.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2009-06-10 21:18:24 +0000
committerMike Bayer <mike_mp@zzzcomputing.com>2009-06-10 21:18:24 +0000
commit45cec095b4904ba71425d2fe18c143982dd08f43 (patch)
treeaf5e540fdcbf1cb2a3337157d69d4b40be010fa8 /test/orm/inheritance/test_basic.py
parent698a3c1ac665e7cd2ef8d5ad3ebf51b7fe6661f4 (diff)
downloadsqlalchemy-45cec095b4904ba71425d2fe18c143982dd08f43.tar.gz
- unit tests have been migrated from unittest to nose.
See README.unittests for information on how to run the tests. [ticket:970]
Diffstat (limited to 'test/orm/inheritance/test_basic.py')
-rw-r--r--test/orm/inheritance/test_basic.py1027
1 files changed, 1027 insertions, 0 deletions
diff --git a/test/orm/inheritance/test_basic.py b/test/orm/inheritance/test_basic.py
new file mode 100644
index 000000000..fc4aae17d
--- /dev/null
+++ b/test/orm/inheritance/test_basic.py
@@ -0,0 +1,1027 @@
+from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message
+from sqlalchemy import *
+from sqlalchemy import exc as sa_exc, util
+from sqlalchemy.orm import *
+from sqlalchemy.orm import exc as orm_exc
+
+from sqlalchemy.test import testing, engines
+from sqlalchemy.util import function_named
+from test.orm import _base, _fixtures
+
+class O2MTest(_base.MappedTest):
+ """deals with inheritance and one-to-many relationships"""
+ @classmethod
+ def define_tables(cls, metadata):
+ global foo, bar, blub
+ foo = Table('foo', metadata,
+ Column('id', Integer, Sequence('foo_seq', optional=True),
+ primary_key=True),
+ Column('data', String(20)))
+
+ bar = Table('bar', metadata,
+ Column('id', Integer, ForeignKey('foo.id'), primary_key=True),
+ Column('data', String(20)))
+
+ blub = Table('blub', metadata,
+ Column('id', Integer, ForeignKey('bar.id'), primary_key=True),
+ Column('foo_id', Integer, ForeignKey('foo.id'), nullable=False),
+ Column('data', String(20)))
+
+ def testbasic(self):
+ class Foo(object):
+ def __init__(self, data=None):
+ self.data = data
+ def __repr__(self):
+ return "Foo id %d, data %s" % (self.id, self.data)
+ mapper(Foo, foo)
+
+ class Bar(Foo):
+ def __repr__(self):
+ return "Bar id %d, data %s" % (self.id, self.data)
+
+ mapper(Bar, bar, inherits=Foo)
+
+ class Blub(Bar):
+ def __repr__(self):
+ return "Blub id %d, data %s" % (self.id, self.data)
+
+ mapper(Blub, blub, inherits=Bar, properties={
+ 'parent_foo':relation(Foo)
+ })
+
+ sess = create_session()
+ b1 = Blub("blub #1")
+ b2 = Blub("blub #2")
+ f = Foo("foo #1")
+ sess.add(b1)
+ sess.add(b2)
+ sess.add(f)
+ b1.parent_foo = f
+ b2.parent_foo = f
+ sess.flush()
+ compare = ','.join([repr(b1), repr(b2), repr(b1.parent_foo), repr(b2.parent_foo)])
+ sess.expunge_all()
+ l = sess.query(Blub).all()
+ result = ','.join([repr(l[0]), repr(l[1]), repr(l[0].parent_foo), repr(l[1].parent_foo)])
+ print compare
+ print result
+ self.assert_(compare == result)
+ self.assert_(l[0].parent_foo.data == 'foo #1' and l[1].parent_foo.data == 'foo #1')
+
+class FalseDiscriminatorTest(_base.MappedTest):
+ @classmethod
+ def define_tables(cls, metadata):
+ global t1
+ t1 = Table('t1', metadata, Column('id', Integer, primary_key=True), Column('type', Integer, nullable=False))
+
+ def test_false_discriminator(self):
+ class Foo(object):pass
+ class Bar(Foo):pass
+ mapper(Foo, t1, polymorphic_on=t1.c.type, polymorphic_identity=1)
+ mapper(Bar, inherits=Foo, polymorphic_identity=0)
+ sess = create_session()
+ f1 = Bar()
+ sess.add(f1)
+ sess.flush()
+ assert f1.type == 0
+ sess.expunge_all()
+ assert isinstance(sess.query(Foo).one(), Bar)
+
+class PolymorphicSynonymTest(_base.MappedTest):
+ @classmethod
+ def define_tables(cls, metadata):
+ global t1, t2
+ t1 = Table('t1', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('type', String(10), nullable=False),
+ Column('info', String(255)))
+ t2 = Table('t2', metadata,
+ Column('id', Integer, ForeignKey('t1.id'), primary_key=True),
+ Column('data', String(10), nullable=False))
+
+ def test_polymorphic_synonym(self):
+ class T1(_fixtures.Base):
+ def info(self):
+ return "THE INFO IS:" + self._info
+ def _set_info(self, x):
+ self._info = x
+ info = property(info, _set_info)
+
+ class T2(T1):pass
+
+ mapper(T1, t1, polymorphic_on=t1.c.type, polymorphic_identity='t1', properties={
+ 'info':synonym('_info', map_column=True)
+ })
+ mapper(T2, t2, inherits=T1, polymorphic_identity='t2')
+ sess = create_session()
+ at1 = T1(info='at1')
+ at2 = T2(info='at2', data='t2 data')
+ sess.add(at1)
+ sess.add(at2)
+ sess.flush()
+ sess.expunge_all()
+ eq_(sess.query(T2).filter(T2.info=='at2').one(), at2)
+ eq_(at2.info, "THE INFO IS:at2")
+
+
+class CascadeTest(_base.MappedTest):
+ """that cascades on polymorphic relations continue
+ cascading along the path of the instance's mapper, not
+ the base mapper."""
+
+ @classmethod
+ def define_tables(cls, metadata):
+ global t1, t2, t3, t4
+ t1= Table('t1', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('data', String(30))
+ )
+
+ t2 = Table('t2', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('t1id', Integer, ForeignKey('t1.id')),
+ Column('type', String(30)),
+ Column('data', String(30))
+ )
+ t3 = Table('t3', metadata,
+ Column('id', Integer, ForeignKey('t2.id'), primary_key=True),
+ Column('moredata', String(30)))
+
+ t4 = Table('t4', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('t3id', Integer, ForeignKey('t3.id')),
+ Column('data', String(30)))
+
+ def test_cascade(self):
+ class T1(_fixtures.Base):
+ pass
+ class T2(_fixtures.Base):
+ pass
+ class T3(T2):
+ pass
+ class T4(_fixtures.Base):
+ pass
+
+ mapper(T1, t1, properties={
+ 't2s':relation(T2, cascade="all")
+ })
+ mapper(T2, t2, polymorphic_on=t2.c.type, polymorphic_identity='t2')
+ mapper(T3, t3, inherits=T2, polymorphic_identity='t3', properties={
+ 't4s':relation(T4, cascade="all")
+ })
+ mapper(T4, t4)
+
+ sess = create_session()
+ t1_1 = T1(data='t1')
+
+ t3_1 = T3(data ='t3', moredata='t3')
+ t2_1 = T2(data='t2')
+
+ t1_1.t2s.append(t2_1)
+ t1_1.t2s.append(t3_1)
+
+ t4_1 = T4(data='t4')
+ t3_1.t4s.append(t4_1)
+
+ sess.add(t1_1)
+
+
+ assert t4_1 in sess.new
+ sess.flush()
+
+ sess.delete(t1_1)
+ assert t4_1 in sess.deleted
+ sess.flush()
+
+class GetTest(_base.MappedTest):
+ @classmethod
+ def define_tables(cls, metadata):
+ global foo, bar, blub
+ foo = Table('foo', metadata,
+ Column('id', Integer, Sequence('foo_seq', optional=True),
+ primary_key=True),
+ Column('type', String(30)),
+ Column('data', String(20)))
+
+ bar = Table('bar', metadata,
+ Column('id', Integer, ForeignKey('foo.id'), primary_key=True),
+ Column('data', String(20)))
+
+ blub = Table('blub', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('foo_id', Integer, ForeignKey('foo.id')),
+ Column('bar_id', Integer, ForeignKey('bar.id')),
+ Column('data', String(20)))
+
+ def _create_test(polymorphic, name):
+ def test_get(self):
+ class Foo(object):
+ pass
+
+ class Bar(Foo):
+ pass
+
+ class Blub(Bar):
+ pass
+
+ if polymorphic:
+ mapper(Foo, foo, polymorphic_on=foo.c.type, polymorphic_identity='foo')
+ mapper(Bar, bar, inherits=Foo, polymorphic_identity='bar')
+ mapper(Blub, blub, inherits=Bar, polymorphic_identity='blub')
+ else:
+ mapper(Foo, foo)
+ mapper(Bar, bar, inherits=Foo)
+ mapper(Blub, blub, inherits=Bar)
+
+ sess = create_session()
+ f = Foo()
+ b = Bar()
+ bl = Blub()
+ sess.add(f)
+ sess.add(b)
+ sess.add(bl)
+ sess.flush()
+
+ if polymorphic:
+ def go():
+ assert sess.query(Foo).get(f.id) == f
+ assert sess.query(Foo).get(b.id) == b
+ assert sess.query(Foo).get(bl.id) == bl
+ assert sess.query(Bar).get(b.id) == b
+ assert sess.query(Bar).get(bl.id) == bl
+ assert sess.query(Blub).get(bl.id) == bl
+
+ self.assert_sql_count(testing.db, go, 0)
+ else:
+ # this is testing the 'wrong' behavior of using get()
+ # polymorphically with mappers that are not configured to be
+ # polymorphic. the important part being that get() always
+ # returns an instance of the query's type.
+ def go():
+ assert sess.query(Foo).get(f.id) == f
+
+ bb = sess.query(Foo).get(b.id)
+ assert isinstance(b, Foo) and bb.id==b.id
+
+ bll = sess.query(Foo).get(bl.id)
+ assert isinstance(bll, Foo) and bll.id==bl.id
+
+ assert sess.query(Bar).get(b.id) == b
+
+ bll = sess.query(Bar).get(bl.id)
+ assert isinstance(bll, Bar) and bll.id == bl.id
+
+ assert sess.query(Blub).get(bl.id) == bl
+
+ self.assert_sql_count(testing.db, go, 3)
+
+ test_get = function_named(test_get, name)
+ return test_get
+
+ test_get_polymorphic = _create_test(True, 'test_get_polymorphic')
+ test_get_nonpolymorphic = _create_test(False, 'test_get_nonpolymorphic')
+
+class EagerLazyTest(_base.MappedTest):
+ """tests eager load/lazy load of child items off inheritance mappers, tests that
+ LazyLoader constructs the right query condition."""
+ @classmethod
+ def define_tables(cls, metadata):
+ global foo, bar, bar_foo
+ foo = Table('foo', metadata,
+ Column('id', Integer, Sequence('foo_seq', optional=True),
+ primary_key=True),
+ Column('data', String(30)))
+ bar = Table('bar', metadata,
+ Column('id', Integer, ForeignKey('foo.id'), primary_key=True),
+ Column('data', String(30)))
+
+ bar_foo = Table('bar_foo', metadata,
+ Column('bar_id', Integer, ForeignKey('bar.id')),
+ Column('foo_id', Integer, ForeignKey('foo.id'))
+ )
+
+ @testing.fails_on('maxdb', 'FIXME: unknown')
+ def testbasic(self):
+ class Foo(object): pass
+ class Bar(Foo): pass
+
+ foos = mapper(Foo, foo)
+ bars = mapper(Bar, bar, inherits=foos)
+ bars.add_property('lazy', relation(foos, bar_foo, lazy=True))
+ bars.add_property('eager', relation(foos, bar_foo, lazy=False))
+
+ foo.insert().execute(data='foo1')
+ bar.insert().execute(id=1, data='bar1')
+
+ foo.insert().execute(data='foo2')
+ bar.insert().execute(id=2, data='bar2')
+
+ foo.insert().execute(data='foo3') #3
+ foo.insert().execute(data='foo4') #4
+
+ bar_foo.insert().execute(bar_id=1, foo_id=3)
+ bar_foo.insert().execute(bar_id=2, foo_id=4)
+
+ sess = create_session()
+ q = sess.query(Bar)
+ self.assert_(len(q.first().lazy) == 1)
+ self.assert_(len(q.first().eager) == 1)
+
+
+class FlushTest(_base.MappedTest):
+ """test dependency sorting among inheriting mappers"""
+ @classmethod
+ def define_tables(cls, metadata):
+ global users, roles, user_roles, admins
+ users = Table('users', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('email', String(128)),
+ Column('password', String(16)),
+ )
+
+ roles = Table('role', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('description', String(32))
+ )
+
+ user_roles = Table('user_role', metadata,
+ Column('user_id', Integer, ForeignKey('users.id'), primary_key=True),
+ Column('role_id', Integer, ForeignKey('role.id'), primary_key=True)
+ )
+
+ admins = Table('admin', metadata,
+ Column('admin_id', Integer, primary_key=True),
+ Column('user_id', Integer, ForeignKey('users.id'))
+ )
+
+ def testone(self):
+ class User(object):pass
+ class Role(object):pass
+ class Admin(User):pass
+ role_mapper = mapper(Role, roles)
+ user_mapper = mapper(User, users, properties = {
+ 'roles' : relation(Role, secondary=user_roles, lazy=False)
+ }
+ )
+ admin_mapper = mapper(Admin, admins, inherits=user_mapper)
+ sess = create_session()
+ adminrole = Role()
+ sess.add(adminrole)
+ sess.flush()
+
+ # create an Admin, and append a Role. the dependency processors
+ # corresponding to the "roles" attribute for the Admin mapper and the User mapper
+ # have to ensure that two dependency processors dont fire off and insert the
+ # many to many row twice.
+ a = Admin()
+ a.roles.append(adminrole)
+ a.password = 'admin'
+ sess.add(a)
+ sess.flush()
+
+ assert user_roles.count().scalar() == 1
+
+ def testtwo(self):
+ class User(object):
+ def __init__(self, email=None, password=None):
+ self.email = email
+ self.password = password
+
+ class Role(object):
+ def __init__(self, description=None):
+ self.description = description
+
+ class Admin(User):pass
+
+ role_mapper = mapper(Role, roles)
+ user_mapper = mapper(User, users, properties = {
+ 'roles' : relation(Role, secondary=user_roles, lazy=False)
+ }
+ )
+
+ admin_mapper = mapper(Admin, admins, inherits=user_mapper)
+
+ # create roles
+ adminrole = Role('admin')
+
+ sess = create_session()
+ sess.add(adminrole)
+ sess.flush()
+
+ # create admin user
+ a = Admin(email='tim', password='admin')
+ a.roles.append(adminrole)
+ sess.add(a)
+ sess.flush()
+
+ a.password = 'sadmin'
+ sess.flush()
+ assert user_roles.count().scalar() == 1
+
+class VersioningTest(_base.MappedTest):
+ @classmethod
+ def define_tables(cls, metadata):
+ global base, subtable, stuff
+ base = Table('base', metadata,
+ Column('id', Integer, Sequence('version_test_seq', optional=True), primary_key=True ),
+ Column('version_id', Integer, nullable=False),
+ Column('value', String(40)),
+ Column('discriminator', Integer, nullable=False)
+ )
+ subtable = Table('subtable', metadata,
+ Column('id', None, ForeignKey('base.id'), primary_key=True),
+ Column('subdata', String(50))
+ )
+ stuff = Table('stuff', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('parent', Integer, ForeignKey('base.id'))
+ )
+
+ @testing.fails_on('mssql', 'FIXME: the flush still happens with the concurrency issue.')
+ @engines.close_open_connections
+ def test_save_update(self):
+ class Base(_fixtures.Base):
+ pass
+ class Sub(Base):
+ pass
+ class Stuff(Base):
+ pass
+ mapper(Stuff, stuff)
+ mapper(Base, base, polymorphic_on=base.c.discriminator, version_id_col=base.c.version_id, polymorphic_identity=1, properties={
+ 'stuff':relation(Stuff)
+ })
+ mapper(Sub, subtable, inherits=Base, polymorphic_identity=2)
+
+ sess = create_session()
+
+ b1 = Base(value='b1')
+ s1 = Sub(value='sub1', subdata='some subdata')
+ sess.add(b1)
+ sess.add(s1)
+
+ sess.flush()
+
+ sess2 = create_session()
+ s2 = sess2.query(Base).get(s1.id)
+ s2.subdata = 'sess2 subdata'
+
+ s1.subdata = 'sess1 subdata'
+
+ sess.flush()
+
+ try:
+ sess2.query(Base).with_lockmode('read').get(s1.id)
+ assert False
+ except orm_exc.ConcurrentModificationError, e:
+ assert True
+
+ try:
+ sess2.flush()
+ assert False
+ except orm_exc.ConcurrentModificationError, e:
+ assert True
+
+ sess2.refresh(s2)
+ assert s2.subdata == 'sess1 subdata'
+ s2.subdata = 'sess2 subdata'
+ sess2.flush()
+
+ @testing.fails_on('mssql', 'FIXME: the flush still happens with the concurrency issue.')
+ def test_delete(self):
+ class Base(_fixtures.Base):
+ pass
+ class Sub(Base):
+ pass
+
+ mapper(Base, base, polymorphic_on=base.c.discriminator, version_id_col=base.c.version_id, polymorphic_identity=1)
+ mapper(Sub, subtable, inherits=Base, polymorphic_identity=2)
+
+ sess = create_session()
+
+ b1 = Base(value='b1')
+ s1 = Sub(value='sub1', subdata='some subdata')
+ s2 = Sub(value='sub2', subdata='some other subdata')
+ sess.add(b1)
+ sess.add(s1)
+ sess.add(s2)
+
+ sess.flush()
+
+ sess2 = create_session()
+ s3 = sess2.query(Base).get(s1.id)
+ sess2.delete(s3)
+ sess2.flush()
+
+ s2.subdata = 'some new subdata'
+ sess.flush()
+
+ try:
+ s1.subdata = 'some new subdata'
+ sess.flush()
+ assert False
+ except orm_exc.ConcurrentModificationError, e:
+ assert True
+
+class DistinctPKTest(_base.MappedTest):
+ """test the construction of mapper.primary_key when an inheriting relationship
+ joins on a column other than primary key column."""
+
+ run_inserts = 'once'
+ run_deletes = None
+
+ @classmethod
+ def define_tables(cls, metadata):
+ global person_table, employee_table, Person, Employee
+
+ person_table = Table("persons", metadata,
+ Column("id", Integer, primary_key=True),
+ Column("name", String(80)),
+ )
+
+ employee_table = Table("employees", metadata,
+ Column("id", Integer, primary_key=True),
+ Column("salary", Integer),
+ Column("person_id", Integer, ForeignKey("persons.id")),
+ )
+
+ class Person(object):
+ def __init__(self, name):
+ self.name = name
+
+ class Employee(Person): pass
+
+ @classmethod
+ def insert_data(cls):
+ person_insert = person_table.insert()
+ person_insert.execute(id=1, name='alice')
+ person_insert.execute(id=2, name='bob')
+
+ employee_insert = employee_table.insert()
+ employee_insert.execute(id=2, salary=250, person_id=1) # alice
+ employee_insert.execute(id=3, salary=200, person_id=2) # bob
+
+ def test_implicit(self):
+ person_mapper = mapper(Person, person_table)
+ mapper(Employee, employee_table, inherits=person_mapper)
+ assert list(class_mapper(Employee).primary_key) == [person_table.c.id]
+
+ def test_explicit_props(self):
+ person_mapper = mapper(Person, person_table)
+ mapper(Employee, employee_table, inherits=person_mapper, properties={'pid':person_table.c.id, 'eid':employee_table.c.id})
+ self._do_test(True)
+
+ def test_explicit_composite_pk(self):
+ person_mapper = mapper(Person, person_table)
+ try:
+ mapper(Employee, employee_table, inherits=person_mapper, primary_key=[person_table.c.id, employee_table.c.id])
+ self._do_test(True)
+ assert False
+ except sa_exc.SAWarning, e:
+ assert str(e) == "On mapper Mapper|Employee|employees, primary key column 'employees.id' is being combined with distinct primary key column 'persons.id' in attribute 'id'. Use explicit properties to give each column its own mapped attribute name.", str(e)
+
+ def test_explicit_pk(self):
+ person_mapper = mapper(Person, person_table)
+ mapper(Employee, employee_table, inherits=person_mapper, primary_key=[person_table.c.id])
+ self._do_test(False)
+
+ def _do_test(self, composite):
+ session = create_session()
+ query = session.query(Employee)
+
+ if composite:
+ alice1 = query.get([1,2])
+ bob = query.get([2,3])
+ alice2 = query.get([1,2])
+ else:
+ alice1 = query.get(1)
+ bob = query.get(2)
+ alice2 = query.get(1)
+
+ assert alice1.name == alice2.name == 'alice'
+ assert bob.name == 'bob'
+
+class SyncCompileTest(_base.MappedTest):
+ """test that syncrules compile properly on custom inherit conds"""
+ @classmethod
+ def define_tables(cls, metadata):
+ global _a_table, _b_table, _c_table
+
+ _a_table = Table('a', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('data1', String(128))
+ )
+
+ _b_table = Table('b', metadata,
+ Column('a_id', Integer, ForeignKey('a.id'), primary_key=True),
+ Column('data2', String(128))
+ )
+
+ _c_table = Table('c', metadata,
+ # Column('a_id', Integer, ForeignKey('b.a_id'), primary_key=True), #works
+ Column('b_a_id', Integer, ForeignKey('b.a_id'), primary_key=True),
+ Column('data3', String(128))
+ )
+
+ def test_joins(self):
+ for j1 in (None, _b_table.c.a_id==_a_table.c.id, _a_table.c.id==_b_table.c.a_id):
+ for j2 in (None, _b_table.c.a_id==_c_table.c.b_a_id, _c_table.c.b_a_id==_b_table.c.a_id):
+ self._do_test(j1, j2)
+ for t in reversed(_a_table.metadata.sorted_tables):
+ t.delete().execute().close()
+
+ def _do_test(self, j1, j2):
+ class A(object):
+ def __init__(self, **kwargs):
+ for key, value in kwargs.items():
+ setattr(self, key, value)
+
+ class B(A):
+ pass
+
+ class C(B):
+ pass
+
+ mapper(A, _a_table)
+ mapper(B, _b_table, inherits=A,
+ inherit_condition=j1
+ )
+ mapper(C, _c_table, inherits=B,
+ inherit_condition=j2
+ )
+
+ session = create_session()
+
+ a = A(data1='a1')
+ session.add(a)
+
+ b = B(data1='b1', data2='b2')
+ session.add(b)
+
+ c = C(data1='c1', data2='c2', data3='c3')
+ session.add(c)
+
+ session.flush()
+ session.expunge_all()
+
+ assert len(session.query(A).all()) == 3
+ assert len(session.query(B).all()) == 2
+ assert len(session.query(C).all()) == 1
+
+class OverrideColKeyTest(_base.MappedTest):
+ """test overriding of column attributes."""
+
+ @classmethod
+ def define_tables(cls, metadata):
+ global base, subtable
+
+ base = Table('base', metadata,
+ Column('base_id', Integer, primary_key=True),
+ Column('data', String(255)),
+ Column('sqlite_fixer', String(10))
+ )
+
+ subtable = Table('subtable', metadata,
+ Column('base_id', Integer, ForeignKey('base.base_id'), primary_key=True),
+ Column('subdata', String(255))
+ )
+
+ def test_plain(self):
+ # control case
+ class Base(object):
+ pass
+ class Sub(Base):
+ pass
+
+ mapper(Base, base)
+ mapper(Sub, subtable, inherits=Base)
+
+ # Sub gets a "base_id" property using the "base_id"
+ # column of both tables.
+ eq_(
+ class_mapper(Sub).get_property('base_id').columns,
+ [base.c.base_id, subtable.c.base_id]
+ )
+
+ def test_override_explicit(self):
+ # this pattern is what you see when using declarative
+ # in particular, here we do a "manual" version of
+ # what we'd like the mapper to do.
+
+ class Base(object):
+ pass
+ class Sub(Base):
+ pass
+
+ mapper(Base, base, properties={
+ 'id':base.c.base_id
+ })
+ mapper(Sub, subtable, inherits=Base, properties={
+ # this is the manual way to do it, is not really
+ # possible in declarative
+ 'id':[base.c.base_id, subtable.c.base_id]
+ })
+
+ eq_(
+ class_mapper(Sub).get_property('id').columns,
+ [base.c.base_id, subtable.c.base_id]
+ )
+
+ s1 = Sub()
+ s1.id = 10
+ sess = create_session()
+ sess.add(s1)
+ sess.flush()
+ assert sess.query(Sub).get(10) is s1
+
+ def test_override_onlyinparent(self):
+ class Base(object):
+ pass
+ class Sub(Base):
+ pass
+
+ mapper(Base, base, properties={
+ 'id':base.c.base_id
+ })
+ mapper(Sub, subtable, inherits=Base)
+
+ eq_(
+ class_mapper(Sub).get_property('id').columns,
+ [base.c.base_id]
+ )
+
+ eq_(
+ class_mapper(Sub).get_property('base_id').columns,
+ [subtable.c.base_id]
+ )
+
+ s1 = Sub()
+ s1.id = 10
+
+ s2 = Sub()
+ s2.base_id = 15
+
+ sess = create_session()
+ sess.add_all([s1, s2])
+ sess.flush()
+
+ # s1 gets '10'
+ assert sess.query(Sub).get(10) is s1
+
+ # s2 gets a new id, base_id is overwritten by the ultimate
+ # PK col
+ assert s2.id == s2.base_id != 15
+
+ def test_override_implicit(self):
+ # this is how the pattern looks intuitively when
+ # using declarative.
+ # fixed as part of [ticket:1111]
+
+ class Base(object):
+ pass
+ class Sub(Base):
+ pass
+
+ mapper(Base, base, properties={
+ 'id':base.c.base_id
+ })
+ mapper(Sub, subtable, inherits=Base, properties={
+ 'id':subtable.c.base_id
+ })
+
+ # Sub mapper compilation needs to detect that "base.c.base_id"
+ # is renamed in the inherited mapper as "id", even though
+ # it has its own "id" property. Sub's "id" property
+ # gets joined normally with the extra column.
+
+ eq_(
+ class_mapper(Sub).get_property('id').columns,
+ [base.c.base_id, subtable.c.base_id]
+ )
+
+ s1 = Sub()
+ s1.id = 10
+ sess = create_session()
+ sess.add(s1)
+ sess.flush()
+ assert sess.query(Sub).get(10) is s1
+
+ def test_plain_descriptor(self):
+ """test that descriptors prevent inheritance from propigating properties to subclasses."""
+
+ class Base(object):
+ pass
+ class Sub(Base):
+ @property
+ def data(self):
+ return "im the data"
+
+ mapper(Base, base)
+ mapper(Sub, subtable, inherits=Base)
+
+ s1 = Sub()
+ sess = create_session()
+ sess.add(s1)
+ sess.flush()
+ assert sess.query(Sub).one().data == "im the data"
+
+ def test_custom_descriptor(self):
+ """test that descriptors prevent inheritance from propigating properties to subclasses."""
+
+ class MyDesc(object):
+ def __get__(self, instance, owner):
+ if instance is None:
+ return self
+ return "im the data"
+
+ class Base(object):
+ pass
+ class Sub(Base):
+ data = MyDesc()
+
+ mapper(Base, base)
+ mapper(Sub, subtable, inherits=Base)
+
+ s1 = Sub()
+ sess = create_session()
+ sess.add(s1)
+ sess.flush()
+ assert sess.query(Sub).one().data == "im the data"
+
+ def test_sub_columns_over_base_descriptors(self):
+ class Base(object):
+ @property
+ def subdata(self):
+ return "this is base"
+
+ class Sub(Base):
+ pass
+
+ mapper(Base, base)
+ mapper(Sub, subtable, inherits=Base)
+
+ sess = create_session()
+ b1 = Base()
+ assert b1.subdata == "this is base"
+ s1 = Sub()
+ s1.subdata = "this is sub"
+ assert s1.subdata == "this is sub"
+
+ sess.add_all([s1, b1])
+ sess.flush()
+ sess.expunge_all()
+
+ assert sess.query(Base).get(b1.base_id).subdata == "this is base"
+ assert sess.query(Sub).get(s1.base_id).subdata == "this is sub"
+
+ def test_base_descriptors_over_base_cols(self):
+ class Base(object):
+ @property
+ def data(self):
+ return "this is base"
+
+ class Sub(Base):
+ pass
+
+ mapper(Base, base)
+ mapper(Sub, subtable, inherits=Base)
+
+ sess = create_session()
+ b1 = Base()
+ assert b1.data == "this is base"
+ s1 = Sub()
+ assert s1.data == "this is base"
+
+ sess.add_all([s1, b1])
+ sess.flush()
+ sess.expunge_all()
+
+ assert sess.query(Base).get(b1.base_id).data == "this is base"
+ assert sess.query(Sub).get(s1.base_id).data == "this is base"
+
+class OptimizedLoadTest(_base.MappedTest):
+ """test that the 'optimized load' routine doesn't crash when
+ a column in the join condition is not available.
+
+ """
+ @classmethod
+ def define_tables(cls, metadata):
+ global base, sub
+ base = Table('base', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('data', String(50)),
+ Column('type', String(50))
+ )
+ sub = Table('sub', metadata,
+ Column('id', Integer, ForeignKey('base.id'), primary_key=True),
+ Column('sub', String(50))
+ )
+
+ def test_optimized_passes(self):
+ class Base(object):
+ pass
+ class Sub(Base):
+ pass
+
+ mapper(Base, base, polymorphic_on=base.c.type, polymorphic_identity='base')
+
+ # redefine Sub's "id" to favor the "id" col in the subtable.
+ # "id" is also part of the primary join condition
+ mapper(Sub, sub, inherits=Base, polymorphic_identity='sub', properties={'id':sub.c.id})
+ sess = create_session()
+ s1 = Sub()
+ s1.data = 's1data'
+ s1.sub = 's1sub'
+ sess.add(s1)
+ sess.flush()
+ sess.expunge_all()
+
+ # load s1 via Base. s1.id won't populate since it's relative to
+ # the "sub" table. The optimized load kicks in and tries to
+ # generate on the primary join, but cannot since "id" is itself unloaded.
+ # the optimized load needs to return "None" so regular full-row loading proceeds
+ s1 = sess.query(Base).get(s1.id)
+ assert s1.sub == 's1sub'
+
+class PKDiscriminatorTest(_base.MappedTest):
+ @classmethod
+ def define_tables(cls, metadata):
+ parents = Table('parents', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('name', String(60)))
+
+ children = Table('children', metadata,
+ Column('id', Integer, ForeignKey('parents.id'), primary_key=True),
+ Column('type', Integer,primary_key=True),
+ Column('name', String(60)))
+
+ @testing.resolve_artifact_names
+ def test_pk_as_discriminator(self):
+ class Parent(object):
+ def __init__(self, name=None):
+ self.name = name
+
+ class Child(object):
+ def __init__(self, name=None):
+ self.name = name
+
+ class A(Child):
+ pass
+
+ mapper(Parent, parents, properties={
+ 'children': relation(Child, backref='parent'),
+ })
+ mapper(Child, children, polymorphic_on=children.c.type,
+ polymorphic_identity=1)
+
+ mapper(A, inherits=Child, polymorphic_identity=2)
+
+ s = create_session()
+ p = Parent('p1')
+ a = A('a1')
+ p.children.append(a)
+ s.add(p)
+ s.flush()
+
+ assert a.id
+ assert a.type == 2
+
+
+class DeleteOrphanTest(_base.MappedTest):
+ @classmethod
+ def define_tables(cls, metadata):
+ global single, parent
+ single = Table('single', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('type', String(50), nullable=False),
+ Column('data', String(50)),
+ Column('parent_id', Integer, ForeignKey('parent.id'), nullable=False),
+ )
+
+ parent = Table('parent', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('data', String(50))
+ )
+
+ def test_orphan_message(self):
+ class Base(_fixtures.Base):
+ pass
+
+ class SubClass(Base):
+ pass
+
+ class Parent(_fixtures.Base):
+ pass
+
+ mapper(Base, single, polymorphic_on=single.c.type, polymorphic_identity='base')
+ mapper(SubClass, inherits=Base, polymorphic_identity='sub')
+ mapper(Parent, parent, properties={
+ 'related':relation(Base, cascade="all, delete-orphan")
+ })
+
+ sess = create_session()
+ s1 = SubClass(data='s1')
+ sess.add(s1)
+ assert_raises_message(orm_exc.FlushError,
+ "is not attached to any parent 'Parent' instance via that classes' 'related' attribute", sess.flush)
+
+