diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2010-03-13 12:28:50 -0500 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2010-03-13 12:28:50 -0500 |
commit | 3290ac23df9eed8a61324eb68f062a7de29e549d (patch) | |
tree | 60804a7b5612beaf706dc424c6ceb43605b73fc0 /test/orm/inheritance/test_basic.py | |
parent | 10659a005cbfc8ded6b905ff822d490f3961f3b7 (diff) | |
download | sqlalchemy-3290ac23df9eed8a61324eb68f062a7de29e549d.tar.gz |
- query.get() now returns None if queried for an identifier
that is present in the identity map with a different class
than the one requested, i.e. when using polymorphic loading.
[ticket:1727]
Diffstat (limited to 'test/orm/inheritance/test_basic.py')
-rw-r--r-- | test/orm/inheritance/test_basic.py | 219 |
1 files changed, 117 insertions, 102 deletions
diff --git a/test/orm/inheritance/test_basic.py b/test/orm/inheritance/test_basic.py index aed7cf5ef..ce773a7bc 100644 --- a/test/orm/inheritance/test_basic.py +++ b/test/orm/inheritance/test_basic.py @@ -28,7 +28,7 @@ class O2MTest(_base.MappedTest): Column('foo_id', Integer, ForeignKey('foo.id'), nullable=False), Column('data', String(20))) - def testbasic(self): + def test_basic(self): class Foo(object): def __init__(self, data=None): self.data = data @@ -279,78 +279,88 @@ class GetTest(_base.MappedTest): Column('foo_id', Integer, ForeignKey('foo.id')), Column('bar_id', Integer, ForeignKey('bar.id')), Column('data', String(20))) + + @classmethod + def setup_classes(cls): + class Foo(_base.BasicEntity): + pass - def _create_test(polymorphic, name): - def test_get(self): - class Foo(object): - pass - - class Bar(Foo): - pass - - class Blub(Bar): - pass - - if polymorphic: - mapper(Foo, foo, polymorphic_on=foo.c.type, polymorphic_identity='foo') - mapper(Bar, bar, inherits=Foo, polymorphic_identity='bar') - mapper(Blub, blub, inherits=Bar, polymorphic_identity='blub') - else: - mapper(Foo, foo) - mapper(Bar, bar, inherits=Foo) - mapper(Blub, blub, inherits=Bar) - - sess = create_session() - f = Foo() - b = Bar() - bl = Blub() - sess.add(f) - sess.add(b) - sess.add(bl) - sess.flush() + class Bar(Foo): + pass - if polymorphic: - def go(): - assert sess.query(Foo).get(f.id) == f - assert sess.query(Foo).get(b.id) == b - assert sess.query(Foo).get(bl.id) == bl - assert sess.query(Bar).get(b.id) == b - assert sess.query(Bar).get(bl.id) == bl - assert sess.query(Blub).get(bl.id) == bl + class Blub(Bar): + pass - self.assert_sql_count(testing.db, go, 0) - else: - # this is testing the 'wrong' behavior of using get() - # polymorphically with mappers that are not configured to be - # polymorphic. the important part being that get() always - # returns an instance of the query's type. - def go(): - assert sess.query(Foo).get(f.id) == f + def test_get_polymorphic(self): + self._do_get_test(True) + + def test_get_nonpolymorphic(self): + self._do_get_test(False) - bb = sess.query(Foo).get(b.id) - assert isinstance(b, Foo) and bb.id==b.id + @testing.resolve_artifact_names + def _do_get_test(self, polymorphic): + if polymorphic: + mapper(Foo, foo, polymorphic_on=foo.c.type, polymorphic_identity='foo') + mapper(Bar, bar, inherits=Foo, polymorphic_identity='bar') + mapper(Blub, blub, inherits=Bar, polymorphic_identity='blub') + else: + mapper(Foo, foo) + mapper(Bar, bar, inherits=Foo) + mapper(Blub, blub, inherits=Bar) - bll = sess.query(Foo).get(bl.id) - assert isinstance(bll, Foo) and bll.id==bl.id + sess = create_session() + f = Foo() + b = Bar() + bl = Blub() + sess.add(f) + sess.add(b) + sess.add(bl) + sess.flush() + + if polymorphic: + def go(): + assert sess.query(Foo).get(f.id) is f + assert sess.query(Foo).get(b.id) is b + assert sess.query(Foo).get(bl.id) is bl + assert sess.query(Bar).get(b.id) is b + assert sess.query(Bar).get(bl.id) is bl + assert sess.query(Blub).get(bl.id) is bl + + # test class mismatches - item is present + # in the identity map but we requested a subclass + assert sess.query(Blub).get(f.id) is None + assert sess.query(Blub).get(b.id) is None + assert sess.query(Bar).get(f.id) is None + + self.assert_sql_count(testing.db, go, 0) + else: + # this is testing the 'wrong' behavior of using get() + # polymorphically with mappers that are not configured to be + # polymorphic. the important part being that get() always + # returns an instance of the query's type. + def go(): + assert sess.query(Foo).get(f.id) is f - assert sess.query(Bar).get(b.id) == b + bb = sess.query(Foo).get(b.id) + assert isinstance(b, Foo) and bb.id==b.id - bll = sess.query(Bar).get(bl.id) - assert isinstance(bll, Bar) and bll.id == bl.id + bll = sess.query(Foo).get(bl.id) + assert isinstance(bll, Foo) and bll.id==bl.id - assert sess.query(Blub).get(bl.id) == bl + assert sess.query(Bar).get(b.id) is b - self.assert_sql_count(testing.db, go, 3) + bll = sess.query(Bar).get(bl.id) + assert isinstance(bll, Bar) and bll.id == bl.id - test_get = function_named(test_get, name) - return test_get + assert sess.query(Blub).get(bl.id) is bl + + self.assert_sql_count(testing.db, go, 3) - test_get_polymorphic = _create_test(True, 'test_get_polymorphic') - test_get_nonpolymorphic = _create_test(False, 'test_get_nonpolymorphic') class EagerLazyTest(_base.MappedTest): """tests eager load/lazy load of child items off inheritance mappers, tests that LazyLoader constructs the right query condition.""" + @classmethod def define_tables(cls, metadata): global foo, bar, bar_foo @@ -367,7 +377,7 @@ class EagerLazyTest(_base.MappedTest): ) @testing.fails_on('maxdb', 'FIXME: unknown') - def testbasic(self): + def test_basic(self): class Foo(object): pass class Bar(Foo): pass @@ -394,7 +404,8 @@ class EagerLazyTest(_base.MappedTest): self.assert_(len(q.first().eager) == 1) class EagerTargetingTest(_base.MappedTest): - """test a scenario where joined table inheritance might be confused as an eagerly loaded joined table.""" + """test a scenario where joined table inheritance might be + confused as an eagerly loaded joined table.""" @classmethod def define_tables(cls, metadata): @@ -450,31 +461,32 @@ class EagerTargetingTest(_base.MappedTest): class FlushTest(_base.MappedTest): """test dependency sorting among inheriting mappers""" + @classmethod def define_tables(cls, metadata): - global users, roles, user_roles, admins - users = Table('users', metadata, + Table('users', metadata, Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('email', String(128)), Column('password', String(16)), ) - roles = Table('role', metadata, + Table('roles', metadata, Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('description', String(32)) ) - user_roles = Table('user_role', metadata, + Table('user_roles', metadata, Column('user_id', Integer, ForeignKey('users.id'), primary_key=True), - Column('role_id', Integer, ForeignKey('role.id'), primary_key=True) + Column('role_id', Integer, ForeignKey('roles.id'), primary_key=True) ) - admins = Table('admin', metadata, + Table('admins', metadata, Column('admin_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('user_id', Integer, ForeignKey('users.id')) ) - def testone(self): + @testing.resolve_artifact_names + def test_one(self): class User(object):pass class Role(object):pass class Admin(User):pass @@ -501,7 +513,8 @@ class FlushTest(_base.MappedTest): assert user_roles.count().scalar() == 1 - def testtwo(self): + @testing.resolve_artifact_names + def test_two(self): class User(object): def __init__(self, email=None, password=None): self.email = email @@ -541,34 +554,24 @@ class FlushTest(_base.MappedTest): class VersioningTest(_base.MappedTest): @classmethod def define_tables(cls, metadata): - global base, subtable, stuff - base = Table('base', metadata, + Table('base', metadata, Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('version_id', Integer, nullable=False), Column('value', String(40)), Column('discriminator', Integer, nullable=False) ) - subtable = Table('subtable', metadata, + Table('subtable', metadata, Column('id', None, ForeignKey('base.id'), primary_key=True), Column('subdata', String(50)) ) - stuff = Table('stuff', metadata, + Table('stuff', metadata, Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('parent', Integer, ForeignKey('base.id')) ) - def setup(self): - super(VersioningTest, self).setup() - if not testing.db.dialect.supports_sane_rowcount: - self._warnings_filters = warnings.filters[:] - warnings.filterwarnings('ignore', category=sa_exc.SAWarning) - - def teardown(self): - super(VersioningTest, self).teardown() - if not testing.db.dialect.supports_sane_rowcount: - warnings.filters[:] = self._warnings_filters - + @testing.emits_warning(r".*updated rowcount") @engines.close_open_connections + @testing.resolve_artifact_names def test_save_update(self): class Base(_fixtures.Base): pass @@ -577,7 +580,10 @@ class VersioningTest(_base.MappedTest): class Stuff(Base): pass mapper(Stuff, stuff) - mapper(Base, base, polymorphic_on=base.c.discriminator, version_id_col=base.c.version_id, polymorphic_identity=1, properties={ + mapper(Base, base, + polymorphic_on=base.c.discriminator, + version_id_col=base.c.version_id, + polymorphic_identity=1, properties={ 'stuff':relation(Stuff) }) mapper(Sub, subtable, inherits=Base, polymorphic_identity=2) @@ -599,17 +605,14 @@ class VersioningTest(_base.MappedTest): sess.flush() - try: - sess2.query(Base).with_lockmode('read').get(s1.id) - assert False - except orm_exc.ConcurrentModificationError, e: - assert True + assert_raises(orm_exc.ConcurrentModificationError, + sess2.query(Base).with_lockmode('read').get, + s1.id) - try: + if not testing.db.dialect.supports_sane_rowcount: sess2.flush() - assert not testing.db.dialect.supports_sane_rowcount - except orm_exc.ConcurrentModificationError, e: - assert True + else: + assert_raises(orm_exc.ConcurrentModificationError, sess2.flush) sess2.refresh(s2) if testing.db.dialect.supports_sane_rowcount: @@ -617,13 +620,17 @@ class VersioningTest(_base.MappedTest): s2.subdata = 'sess2 subdata' sess2.flush() + @testing.emits_warning(r".*updated rowcount") + @testing.resolve_artifact_names def test_delete(self): class Base(_fixtures.Base): pass class Sub(Base): pass - mapper(Base, base, polymorphic_on=base.c.discriminator, version_id_col=base.c.version_id, polymorphic_identity=1) + mapper(Base, base, + polymorphic_on=base.c.discriminator, + version_id_col=base.c.version_id, polymorphic_identity=1) mapper(Sub, subtable, inherits=Base, polymorphic_identity=2) sess = create_session() @@ -697,17 +704,24 @@ class DistinctPKTest(_base.MappedTest): def test_explicit_props(self): person_mapper = mapper(Person, person_table) - mapper(Employee, employee_table, inherits=person_mapper, properties={'pid':person_table.c.id, 'eid':employee_table.c.id}) + mapper(Employee, employee_table, inherits=person_mapper, + properties={'pid':person_table.c.id, + 'eid':employee_table.c.id}) self._do_test(True) def test_explicit_composite_pk(self): person_mapper = mapper(Person, person_table) - try: - mapper(Employee, employee_table, inherits=person_mapper, primary_key=[person_table.c.id, employee_table.c.id]) - self._do_test(True) - assert False - except sa_exc.SAWarning, e: - assert str(e) == "On mapper Mapper|Employee|employees, primary key column 'employees.id' is being combined with distinct primary key column 'persons.id' in attribute 'id'. Use explicit properties to give each column its own mapped attribute name.", str(e) + mapper(Employee, employee_table, + inherits=person_mapper, + primary_key=[person_table.c.id, employee_table.c.id]) + assert_raises_message(sa_exc.SAWarning, + r"On mapper Mapper\|Employee\|employees, " + "primary key column 'employees.id' is being " + "combined with distinct primary key column 'persons.id' " + "in attribute 'id'. Use explicit properties to give " + "each column its own mapped attribute name.", + self._do_test, True + ) def test_explicit_pk(self): person_mapper = mapper(Person, person_table) @@ -1242,6 +1256,7 @@ class DeleteOrphanTest(_base.MappedTest): s1 = SubClass(data='s1') sess.add(s1) assert_raises_message(orm_exc.FlushError, - "is not attached to any parent 'Parent' instance via that classes' 'related' attribute", sess.flush) + r"is not attached to any parent 'Parent' instance via " + "that classes' 'related' attribute", sess.flush) |