summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2010-12-03 18:45:42 -0500
committerMike Bayer <mike_mp@zzzcomputing.com>2010-12-03 18:45:42 -0500
commitaf75fdf60fd3498ab3c5757e81a5d6b5e52f590d (patch)
tree159b5d33f1349b837e330b260ec8df0c9736ab3b
parentd30678d18de7828f03f8179d7980cab2e66c18bc (diff)
downloadsqlalchemy-af75fdf60fd3498ab3c5757e81a5d6b5e52f590d.tar.gz
- added strictness to the optimized load, [ticket:1992]
-rw-r--r--lib/sqlalchemy/orm/mapper.py13
-rw-r--r--test/orm/inheritance/test_basic.py116
-rw-r--r--test/orm/test_expire.py43
3 files changed, 154 insertions, 18 deletions
diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py
index 1a0f3ad2f..2be249594 100644
--- a/lib/sqlalchemy/orm/mapper.py
+++ b/lib/sqlalchemy/orm/mapper.py
@@ -1291,10 +1291,10 @@ class Mapper(object):
"""
props = self._props
- tables = set(chain(*
- (sqlutil.find_tables(props[key].columns[0],
- check_columns=True)
- for key in attribute_names)
+ tables = set(chain(
+ *[sqlutil.find_tables(c, check_columns=True)
+ for key in attribute_names
+ for c in props[key].columns]
))
if self.base_mapper.local_table in tables:
@@ -1313,7 +1313,7 @@ class Mapper(object):
leftval = self._get_committed_state_attr_by_column(
state, state.dict,
leftcol, passive=True)
- if leftval is attributes.PASSIVE_NO_RESULT:
+ if leftval is attributes.PASSIVE_NO_RESULT or leftval is None:
raise ColumnsNotAvailable()
binary.left = sql.bindparam(None, leftval,
type_=binary.right.type)
@@ -1321,7 +1321,7 @@ class Mapper(object):
rightval = self._get_committed_state_attr_by_column(
state, state.dict,
rightcol, passive=True)
- if rightval is attributes.PASSIVE_NO_RESULT:
+ if rightval is attributes.PASSIVE_NO_RESULT or rightval is None:
raise ColumnsNotAvailable()
binary.right = sql.bindparam(None, rightval,
type_=binary.right.type)
@@ -2451,6 +2451,7 @@ def _load_scalar_attributes(state, attribute_names):
has_key = state.has_identity
result = False
+
if mapper.inherits and not mapper.concrete:
statement = mapper._optimized_get_statement(state, attribute_names)
if statement is not None:
diff --git a/test/orm/inheritance/test_basic.py b/test/orm/inheritance/test_basic.py
index 5892b3c89..544aa1abe 100644
--- a/test/orm/inheritance/test_basic.py
+++ b/test/orm/inheritance/test_basic.py
@@ -3,7 +3,8 @@ from test.lib.testing import eq_, assert_raises, assert_raises_message
from sqlalchemy import *
from sqlalchemy import exc as sa_exc, util
from sqlalchemy.orm import *
-from sqlalchemy.orm import exc as orm_exc
+from sqlalchemy.orm import exc as orm_exc, attributes
+from test.lib.assertsql import AllOf, CompiledSQL
from test.lib import testing, engines
from sqlalchemy.util import function_named
@@ -1171,22 +1172,29 @@ class OptimizedLoadTest(_base.MappedTest):
@classmethod
def define_tables(cls, metadata):
- global base, sub, with_comp
- base = Table('base', metadata,
+ Table('base', metadata,
Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('data', String(50)),
- Column('type', String(50))
+ Column('type', String(50)),
+ Column('counter', Integer, server_default="1")
)
- sub = Table('sub', metadata,
+ Table('sub', metadata,
Column('id', Integer, ForeignKey('base.id'), primary_key=True),
- Column('sub', String(50))
+ Column('sub', String(50)),
+ Column('counter', Integer, server_default="1"),
+ Column('counter2', Integer, server_default="1")
+ )
+ Table('subsub', metadata,
+ Column('id', Integer, ForeignKey('sub.id'), primary_key=True),
+ Column('counter2', Integer, server_default="1")
)
- with_comp = Table('with_comp', metadata,
+ Table('with_comp', metadata,
Column('id', Integer, ForeignKey('base.id'), primary_key=True),
Column('a', String(10)),
Column('b', String(10))
)
+ @testing.resolve_artifact_names
def test_optimized_passes(self):
""""test that the 'optimized load' routine doesn't crash when
a column in the join condition is not available."""
@@ -1216,6 +1224,7 @@ class OptimizedLoadTest(_base.MappedTest):
s1 = sess.query(Base).first()
assert s1.sub == 's1sub'
+ @testing.resolve_artifact_names
def test_column_expression(self):
class Base(_base.BasicEntity):
pass
@@ -1233,6 +1242,7 @@ class OptimizedLoadTest(_base.MappedTest):
s1 = sess.query(Base).first()
assert s1.concat == 's1sub|s1sub'
+ @testing.resolve_artifact_names
def test_column_expression_joined(self):
class Base(_base.ComparableEntity):
pass
@@ -1262,6 +1272,7 @@ class OptimizedLoadTest(_base.MappedTest):
]
)
+ @testing.resolve_artifact_names
def test_composite_column_joined(self):
class Base(_base.ComparableEntity):
pass
@@ -1290,6 +1301,97 @@ class OptimizedLoadTest(_base.MappedTest):
assert s2test.comp
eq_(s1test.comp, Comp('ham', 'cheese'))
eq_(s2test.comp, Comp('bacon', 'eggs'))
+
+ @testing.resolve_artifact_names
+ def test_load_expired_on_pending(self):
+ class Base(_base.ComparableEntity):
+ pass
+ class Sub(Base):
+ pass
+ mapper(Base, base, polymorphic_on=base.c.type, polymorphic_identity='base')
+ mapper(Sub, sub, inherits=Base, polymorphic_identity='sub')
+ sess = Session()
+ s1 = Sub(data='s1')
+ sess.add(s1)
+ self.assert_sql_execution(
+ testing.db,
+ sess.flush,
+ CompiledSQL(
+ "INSERT INTO base (data, type) VALUES (:data, :type)",
+ [{'data':'s1','type':'sub'}]
+ ),
+ CompiledSQL(
+ "SELECT sub.counter AS sub_counter, base.counter AS "
+ "base_counter FROM base JOIN sub ON base.id = "
+ "sub.id WHERE base.id = :param_1",
+ lambda ctx:{'param_1':s1.id}
+ ),
+ CompiledSQL(
+ "INSERT INTO sub (id, sub) VALUES (:id, :sub)",
+ lambda ctx:{'id':s1.id, 'sub':None}
+ ),
+ )
+
+ @testing.resolve_artifact_names
+ def test_dont_generate_on_none(self):
+ class Base(_base.ComparableEntity):
+ pass
+ class Sub(Base):
+ pass
+ mapper(Base, base, polymorphic_on=base.c.type, polymorphic_identity='base')
+ m = mapper(Sub, sub, inherits=Base, polymorphic_identity='sub')
+
+ s1 = Sub()
+ assert m._optimized_get_statement(attributes.instance_state(s1), ['counter2']) is None
+
+ # loads s1.id as None
+ eq_(s1.id, None)
+
+ # this now will come up with a value of None for id - should reject
+ assert m._optimized_get_statement(attributes.instance_state(s1), ['counter2']) is None
+
+ s1.id = 1
+ attributes.instance_state(s1).commit_all(s1.__dict__, None)
+ assert m._optimized_get_statement(attributes.instance_state(s1), ['counter2']) is not None
+
+ @testing.resolve_artifact_names
+ def test_load_expired_on_pending_twolevel(self):
+ class Base(_base.ComparableEntity):
+ pass
+ class Sub(Base):
+ pass
+ class SubSub(Sub):
+ pass
+
+ mapper(Base, base, polymorphic_on=base.c.type, polymorphic_identity='base')
+ mapper(Sub, sub, inherits=Base, polymorphic_identity='sub')
+ mapper(SubSub, subsub, inherits=Sub, polymorphic_identity='subsub')
+ sess = Session()
+ s1 = SubSub(data='s1', counter=1)
+ sess.add(s1)
+ self.assert_sql_execution(
+ testing.db,
+ sess.flush,
+ CompiledSQL(
+ "INSERT INTO base (data, type, counter) VALUES (:data, :type, :counter)",
+ [{'data':'s1','type':'subsub','counter':1}]
+ ),
+ CompiledSQL(
+ "INSERT INTO sub (id, sub, counter) VALUES (:id, :sub, :counter)",
+ lambda ctx:[{'counter': 1, 'sub': None, 'id': s1.id}]
+ ),
+ CompiledSQL(
+ "SELECT subsub.counter2 AS subsub_counter2, sub.counter2 AS "
+ "sub_counter2 FROM base JOIN sub ON base.id = sub.id JOIN "
+ "subsub ON sub.id = subsub.id WHERE base.id = :param_1",
+ lambda ctx:{u'param_1': s1.id}
+ ),
+ CompiledSQL(
+ "INSERT INTO subsub (id) VALUES (:id)",
+ lambda ctx:{'id':s1.id}
+ ),
+ )
+
class PKDiscriminatorTest(_base.MappedTest):
diff --git a/test/orm/test_expire.py b/test/orm/test_expire.py
index 93a3e198b..bbfdaa172 100644
--- a/test/orm/test_expire.py
+++ b/test/orm/test_expire.py
@@ -9,7 +9,7 @@ from test.lib.schema import Table
from test.lib.schema import Column
from sqlalchemy.orm import mapper, relationship, create_session, \
attributes, deferred, exc as orm_exc, defer, undefer,\
- strategies, state, lazyload, backref
+ strategies, state, lazyload, backref, Session
from test.orm import _base, _fixtures
@@ -240,7 +240,7 @@ class ExpireTest(_fixtures.FixtureTest):
assert 'name' not in u.__dict__
sess.add(u)
assert_raises(sa_exc.InvalidRequestError, getattr, u, 'name')
-
+
@testing.resolve_artifact_names
def test_expire_preserves_changes(self):
@@ -886,7 +886,8 @@ class PolymorphicExpireTest(_base.MappedTest):
Column('person_id', Integer, primary_key=True,
test_needs_autoincrement=True),
Column('name', String(50)),
- Column('type', String(30)))
+ Column('type', String(30)),
+ )
engineers = Table('engineers', metadata,
Column('person_id', Integer, ForeignKey('people.person_id'),
@@ -913,11 +914,15 @@ class PolymorphicExpireTest(_base.MappedTest):
{'person_id':2, 'status':'new engineer'},
{'person_id':3, 'status':'old engineer'},
)
-
+
+ @classmethod
@testing.resolve_artifact_names
- def test_poly_deferred(self):
+ def setup_mappers(cls):
mapper(Person, people, polymorphic_on=people.c.type, polymorphic_identity='person')
mapper(Engineer, engineers, inherits=Person, polymorphic_identity='engineer')
+
+ @testing.resolve_artifact_names
+ def test_poly_deferred(self):
sess = create_session()
[p1, e1, e2] = sess.query(Person).order_by(people.c.person_id).all()
@@ -953,6 +958,34 @@ class PolymorphicExpireTest(_base.MappedTest):
self.assert_sql_count(testing.db, go, 2)
eq_(Engineer.name.get_history(e1), (['new engineer name'],(), ['engineer1']))
+ @testing.resolve_artifact_names
+ def test_no_instance_key(self):
+
+ sess = create_session()
+ e1 = sess.query(Engineer).get(2)
+
+ sess.expire(e1, attribute_names=['name'])
+ sess.expunge(e1)
+ attributes.instance_state(e1).key = None
+ assert 'name' not in e1.__dict__
+ sess.add(e1)
+ assert e1.name == 'engineer1'
+
+ @testing.resolve_artifact_names
+ def test_no_instance_key(self):
+ # same as test_no_instance_key, but the PK columns
+ # are absent. ensure an error is raised.
+ sess = create_session()
+ e1 = sess.query(Engineer).get(2)
+
+ sess.expire(e1, attribute_names=['name', 'person_id'])
+ sess.expunge(e1)
+ attributes.instance_state(e1).key = None
+ assert 'name' not in e1.__dict__
+ sess.add(e1)
+ assert_raises(sa_exc.InvalidRequestError, getattr, e1, 'name')
+
+
class ExpiredPendingTest(_fixtures.FixtureTest):
run_define_tables = 'once'
run_setup_classes = 'once'