summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--CHANGES10
-rw-r--r--lib/sqlalchemy/databases/mysql.py9
-rw-r--r--lib/sqlalchemy/orm/interfaces.py80
-rw-r--r--lib/sqlalchemy/orm/mapper.py16
-rw-r--r--lib/sqlalchemy/orm/query.py22
-rw-r--r--lib/sqlalchemy/orm/strategies.py83
-rw-r--r--test/orm/assorted_eager.py78
-rw-r--r--test/orm/eager_relations.py46
8 files changed, 222 insertions, 122 deletions
diff --git a/CHANGES b/CHANGES
index 5e4c52b02..66d6385b9 100644
--- a/CHANGES
+++ b/CHANGES
@@ -28,10 +28,14 @@ CHANGES
synonym/deferred
- fixed clear_mappers() behavior to better clean up after itself
-
+
+- behavior of query.options() is now fully based on paths, i.e. an option
+ such as eagerload_all('x.y.z.y.x') will apply eagerloading to only
+ those paths, i.e. and not 'x.y.x'; eagerload('children.children') applies
+ only to exactly two-levels deep, etc. [ticket:777]
+
- Made access dao dection more reliable [ticket:828]
-
-
+
0.4.0
-----
diff --git a/lib/sqlalchemy/databases/mysql.py b/lib/sqlalchemy/databases/mysql.py
index de417f3a5..72bbea07f 100644
--- a/lib/sqlalchemy/databases/mysql.py
+++ b/lib/sqlalchemy/databases/mysql.py
@@ -1423,15 +1423,6 @@ class MySQLDialect(default.DefaultDialect):
def type_descriptor(self, typeobj):
return sqltypes.adapt_type(typeobj, colspecs)
- def compiler(self, statement, bindparams, **kwargs):
- return MySQLCompiler(statement, bindparams, dialect=self, **kwargs)
-
- def schemagenerator(self, *args, **kwargs):
- return MySQLSchemaGenerator(self, *args, **kwargs)
-
- def schemadropper(self, *args, **kwargs):
- return MySQLSchemaDropper(self, *args, **kwargs)
-
def do_executemany(self, cursor, statement, parameters, context=None):
rowcount = cursor.executemany(statement, parameters)
if context is not None:
diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py
index acbb11505..7cfabb61a 100644
--- a/lib/sqlalchemy/orm/interfaces.py
+++ b/lib/sqlalchemy/orm/interfaces.py
@@ -10,7 +10,7 @@ from sqlalchemy.sql import expression
__all__ = ['EXT_CONTINUE', 'EXT_STOP', 'EXT_PASS', 'MapperExtension',
'MapperProperty', 'PropComparator', 'StrategizedProperty',
- 'LoaderStack', 'build_path', 'MapperOption',
+ 'build_path', 'MapperOption',
'ExtensionOption', 'SynonymProperty', 'PropertyOption',
'AttributeExtension', 'StrategizedOption', 'LoaderStrategy' ]
@@ -467,7 +467,8 @@ class StrategizedProperty(MapperProperty):
"""
def _get_context_strategy(self, context):
- return self._get_strategy(context.attributes.get(("loaderstrategy", self), self.strategy.__class__))
+ path = context.path
+ return self._get_strategy(context.attributes.get(("loaderstrategy", path), self.strategy.__class__))
def _get_strategy(self, cls):
try:
@@ -499,35 +500,6 @@ def build_path(mapper, key, prev=None):
else:
return (mapper.base_mapper, key)
-class LoaderStack(object):
- """a stack object used during load operations to track the
- current position among a chain of mappers to eager loaders."""
-
- def __init__(self):
- self.__stack = []
-
- def push_property(self, key):
- self.__stack.append(key)
- return tuple(self.__stack)
-
- def push_mapper(self, mapper):
- self.__stack.append(mapper.base_mapper)
- return tuple(self.__stack)
-
- def pop(self):
- self.__stack.pop()
-
- def snapshot(self):
- """return an 'snapshot' of this stack.
-
- this is a tuple form of the stack which can be used as a hash key.
- """
-
- return tuple(self.__stack)
-
- def __str__(self):
- return "->".join([str(s) for s in self.__stack])
-
class MapperOption(object):
"""Describe a modification to a Query."""
@@ -582,26 +554,35 @@ class PropertyOption(MapperOption):
self.key = key
def process_query(self, query):
- self.process_query_property(query, self._get_properties(query))
+ if self._should_log_debug:
+ self.logger.debug("applying option to Query, property key '%s'" % self.key)
+ paths = self._get_paths(query)
+ if paths:
+ self.process_query_property(query, paths)
- def process_query_property(self, query, properties):
+ def process_query_property(self, query, paths):
pass
- def _get_properties(self, query):
- try:
- l = self.__prop
- except AttributeError:
- l = []
- mapper = query.mapper
- for token in self.key.split('.'):
- prop = mapper.get_property(token, resolve_synonyms=True)
- l.append(prop)
- mapper = getattr(prop, 'mapper', None)
- self.__prop = l
+ def _get_paths(self, query):
+ path = None
+ l = []
+ current_path = list(query._current_path)
+
+ mapper = query.mapper
+ for token in self.key.split('.'):
+ if current_path and token == current_path[1]:
+ current_path = current_path[2:]
+ continue
+ prop = mapper.get_property(token, resolve_synonyms=True, raiseerr=False)
+ if prop is None:
+ return []
+ path = build_path(mapper, prop.key, path)
+ l.append(path)
+ mapper = getattr(prop, 'mapper', None)
return l
PropertyOption.logger = logging.class_logger(PropertyOption)
-
+PropertyOption._should_log_debug = logging.is_debug_enabled(PropertyOption.logger)
class AttributeExtension(object):
"""An abstract class which specifies `append`, `delete`, and `set`
@@ -626,13 +607,12 @@ class StrategizedOption(PropertyOption):
def is_chained(self):
return False
- def process_query_property(self, query, properties):
- self.logger.debug("applying option to Query, property key '%s'" % self.key)
+ def process_query_property(self, query, paths):
if self.is_chained():
- for prop in properties:
- query._attributes[("loaderstrategy", prop)] = self.get_strategy_class()
+ for path in paths:
+ query._attributes[("loaderstrategy", path)] = self.get_strategy_class()
else:
- query._attributes[("loaderstrategy", properties[-1])] = self.get_strategy_class()
+ query._attributes[("loaderstrategy", paths[-1])] = self.get_strategy_class()
def get_strategy_class(self):
raise NotImplementedError()
diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py
index 7dcbb25e1..2f3c36515 100644
--- a/lib/sqlalchemy/orm/mapper.py
+++ b/lib/sqlalchemy/orm/mapper.py
@@ -1477,7 +1477,7 @@ class Mapper(object):
def populate_instance(self, selectcontext, instance, row, ispostselect=None, isnew=False, **flags):
"""populate an instance from a result row."""
- snapshot = selectcontext.stack.push_mapper(self)
+ snapshot = selectcontext.path + (self,)
# retrieve a set of "row population" functions derived from the MapperProperties attached
# to this Mapper. These are keyed in the select context based primarily off the
# "snapshot" of the stack, which represents a path from the lead mapper in the query to this one,
@@ -1492,14 +1492,14 @@ class Mapper(object):
existing_populators = []
post_processors = []
for prop in self.__props.values():
- (newpop, existingpop, post_proc) = prop.create_row_processor(selectcontext, self, row)
+ (newpop, existingpop, post_proc) = selectcontext.exec_with_path(self, prop.key, prop.create_row_processor, selectcontext, self, row)
if newpop is not None:
- new_populators.append(newpop)
+ new_populators.append((prop.key, newpop))
if existingpop is not None:
- existing_populators.append(existingpop)
+ existing_populators.append((prop.key, existingpop))
if post_proc is not None:
post_processors.append(post_proc)
-
+
poly_select_loader = self._get_poly_select_loader(selectcontext, row)
if poly_select_loader is not None:
post_processors.append(poly_select_loader)
@@ -1512,10 +1512,8 @@ class Mapper(object):
else:
populators = existing_populators
- for p in populators:
- p(instance, row, ispostselect=ispostselect, isnew=isnew, **flags)
-
- selectcontext.stack.pop()
+ for (key, populator) in populators:
+ selectcontext.exec_with_path(self, key, populator, instance, row, ispostselect=ispostselect, isnew=isnew, **flags)
if self.non_primary:
selectcontext.attributes[('populating_mapper', instance)] = self
diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py
index e5534e22c..bec05a43f 100644
--- a/lib/sqlalchemy/orm/query.py
+++ b/lib/sqlalchemy/orm/query.py
@@ -9,7 +9,6 @@ from sqlalchemy.sql import util as sql_util
from sqlalchemy.sql import expression, visitors
from sqlalchemy.orm import mapper, object_mapper
from sqlalchemy.orm import util as mapperutil
-from sqlalchemy.orm.interfaces import LoaderStack
import operator
__all__ = ['Query', 'QueryContext']
@@ -48,6 +47,7 @@ class Query(object):
self._autoflush = True
self._eager_loaders = util.Set([x for x in self.mapper._eager_loaders])
self._attributes = {}
+ self._current_path = ()
def _clone(self):
q = Query.__new__(Query)
@@ -64,6 +64,11 @@ class Query(object):
primary_key_columns = property(lambda s:s.select_mapper.primary_key)
session = property(_get_session)
+ def _with_current_path(self, path):
+ q = self._clone()
+ q._current_path = path
+ return q
+
def get(self, ident, **kwargs):
"""Return an instance of the object based on the given
identifier, or None if not found.
@@ -870,7 +875,7 @@ class Query(object):
# TODO: doing this off the select_mapper. if its the polymorphic mapper, then
# it has no relations() on it. should we compile those too into the query ? (i.e. eagerloads)
for value in self.select_mapper.iterate_properties:
- value.setup(context)
+ context.exec_with_path(self.select_mapper, value.key, value.setup, context)
# additional entities/columns, add those to selection criterion
for tup in self._entities:
@@ -878,7 +883,7 @@ class Query(object):
clauses = self._get_entity_clauses(tup)
if isinstance(m, mapper.Mapper):
for value in m.iterate_properties:
- value.setup(context, parentclauses=clauses)
+ context.exec_with_path(self.select_mapper, value.key, value.setup, context, parentclauses=clauses)
elif isinstance(m, sql.ColumnElement):
if clauses is not None:
m = clauses.adapt_clause(m)
@@ -1158,9 +1163,16 @@ class QueryContext(object):
self.populate_existing = query._populate_existing
self.version_check = query._version_check
self.identity_map = {}
- self.stack = LoaderStack()
+ self.path = ()
self.options = query._with_options
self.attributes = query._attributes.copy()
-
+
+ def exec_with_path(self, mapper, propkey, func, *args, **kwargs):
+ oldpath = self.path
+ self.path += (mapper.base_mapper, propkey)
+ try:
+ return func(*args, **kwargs)
+ finally:
+ self.path = oldpath
diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py
index e5a757bf4..757021cd0 100644
--- a/lib/sqlalchemy/orm/strategies.py
+++ b/lib/sqlalchemy/orm/strategies.py
@@ -313,7 +313,7 @@ class LazyLoader(AbstractRelationLoader):
bindparam.value = mapper.get_attr_by_column(instance, bind_to_col[bindparam.key])
return Visitor().traverse(criterion, clone=True)
- def setup_loader(self, instance, options=None):
+ def setup_loader(self, instance, options=None, path=None):
if not mapper.has_mapper(instance):
return None
else:
@@ -342,6 +342,8 @@ class LazyLoader(AbstractRelationLoader):
# if we have a simple straight-primary key load, use mapper.get()
# to possibly save a DB round trip
q = session.query(self.mapper).autoflush(False)
+ if path:
+ q = q._with_current_path(path)
if self.use_get:
params = {}
for col, bind in self.lazybinds.iteritems():
@@ -387,7 +389,8 @@ class LazyLoader(AbstractRelationLoader):
self.logger.debug("set instance-level lazy loader on %s" % mapperutil.attribute_str(instance, self.key))
# we are not the primary manager for this attribute on this class - set up a per-instance lazyloader,
# which will override the class-level behavior
- self._init_instance_attribute(instance, callable_=self.setup_loader(instance, selectcontext.options))
+
+ self._init_instance_attribute(instance, callable_=self.setup_loader(instance, selectcontext.options, selectcontext.query._current_path + selectcontext.path))
return (new_execute, None, None)
else:
def new_execute(instance, row, ispostselect, **flags):
@@ -479,24 +482,19 @@ class EagerLoader(AbstractRelationLoader):
def setup_query(self, context, parentclauses=None, parentmapper=None, **kwargs):
"""Add a left outer join to the statement thats being constructed."""
- # build a path as we setup the query. the format of this path
- # matches that of interfaces.LoaderStack, and will be used in the
- # row-loading phase to match up AliasedClause objects with the current
- # LoaderStack position.
- if parentclauses:
- path = parentclauses.path + (self.parent.base_mapper, self.key)
- else:
- path = (self.parent.base_mapper, self.key)
+ path = context.path
- if self.join_depth:
- if len(path) / 2 > self.join_depth:
- return
- else:
- if self.mapper.base_mapper in path:
- return
+ # check for join_depth or basic recursion,
+ # if the current path was not explicitly stated as
+ # a desired "loaderstrategy" (i.e. via query.options())
+ if ("loaderstrategy", path) not in context.attributes:
+ if self.join_depth:
+ if len(path) / 2 > self.join_depth:
+ return
+ else:
+ if self.mapper.base_mapper in path:
+ return
- #print "CREATING EAGER PATH FOR", "->".join([str(s) for s in path])
-
if parentmapper is None:
localparent = context.mapper
else:
@@ -550,7 +548,7 @@ class EagerLoader(AbstractRelationLoader):
statement.append_from(statement._outerjoin)
for value in self.select_mapper.iterate_properties:
- value.setup(context, parentclauses=clauses, parentmapper=self.select_mapper)
+ context.exec_with_path(self.select_mapper, value.key, value.setup, context, parentclauses=clauses, parentmapper=self.select_mapper)
def _create_row_decorator(self, selectcontext, row, path):
"""Create a *row decorating* function that will apply eager
@@ -562,19 +560,10 @@ class EagerLoader(AbstractRelationLoader):
#print "creating row decorator for path ", "->".join([str(s) for s in path])
- # check for a user-defined decorator in the SelectContext (which was set up by the contains_eager() option)
- if ("eager_row_processor", self.parent_property) in selectcontext.attributes:
- # custom row decoration function, placed in the selectcontext by the
- # contains_eager() mapper option
- decorator = selectcontext.attributes[("eager_row_processor", self.parent_property)]
- # key was present, but no decorator; therefore just use the row as is
+ if ("eager_row_processor", path) in selectcontext.attributes:
+ decorator = selectcontext.attributes[("eager_row_processor", path)]
if decorator is None:
decorator = lambda row: row
- # check for an AliasedClauses row decorator that was set up by query._compile_context().
- # a further refactoring (described in [ticket:777]) will simplify this so that the
- # contains_eager() option generates the same key as this one
- elif ("eager_row_processor", path) in selectcontext.attributes:
- decorator = selectcontext.attributes[("eager_row_processor", path)]
else:
if self._should_log_debug:
self.logger.debug("Could not locate aliased clauses for key: " + str(path))
@@ -593,15 +582,12 @@ class EagerLoader(AbstractRelationLoader):
return None
def create_row_processor(self, selectcontext, mapper, row):
- path = selectcontext.stack.push_property(self.key)
- row_decorator = self._create_row_decorator(selectcontext, row, path)
+ row_decorator = self._create_row_decorator(selectcontext, row, selectcontext.path)
if row_decorator is not None:
def execute(instance, row, isnew, **flags):
decorated_row = row_decorator(row)
- selectcontext.stack.push_property(self.key)
-
if not self.uselist:
if self._should_log_debug:
self.logger.debug("eagerload scalar instance on %s" % mapperutil.attribute_str(instance, self.key))
@@ -633,9 +619,7 @@ class EagerLoader(AbstractRelationLoader):
self.logger.debug("eagerload list instance on %s" % mapperutil.attribute_str(instance, self.key))
self.select_mapper._instance(selectcontext, decorated_row, result_list)
- selectcontext.stack.pop()
- selectcontext.stack.pop()
if self._should_log_debug:
self.logger.debug("Returning eager instance loader for %s" % str(self))
@@ -644,7 +628,6 @@ class EagerLoader(AbstractRelationLoader):
else:
if self._should_log_debug:
self.logger.debug("eager loader %s degrading to lazy loader" % str(self))
- selectcontext.stack.pop()
return self.parent_property._get_strategy(LazyLoader).create_row_processor(selectcontext, mapper, row)
@@ -662,13 +645,19 @@ class EagerLazyOption(StrategizedOption):
def is_chained(self):
return not self.lazy and self.chained
- def process_query_property(self, query, properties):
+ def process_query_property(self, query, paths):
if self.lazy:
- if properties[-1] in query._eager_loaders:
- query._eager_loaders = query._eager_loaders.difference(util.Set([properties[-1]]))
+ if paths[-1] in query._eager_loaders:
+ query._eager_loaders = query._eager_loaders.difference(util.Set([paths[-1]]))
else:
- query._eager_loaders = query._eager_loaders.union(util.Set(properties))
- super(EagerLazyOption, self).process_query_property(query, properties)
+ if not self.chained:
+ paths = [paths[-1]]
+ res = util.Set()
+ for path in paths:
+ if len(path) - len(query._current_path) == 2:
+ res.add(path)
+ query._eager_loaders = query._eager_loaders.union(res)
+ super(EagerLazyOption, self).process_query_property(query, paths)
def get_strategy_class(self):
if self.lazy:
@@ -702,17 +691,19 @@ class RowDecorateOption(PropertyOption):
self.decorator = decorator
self.alias = alias
- def process_query_property(self, query, properties):
+ def process_query_property(self, query, paths):
if self.alias is not None and self.decorator is None:
if isinstance(self.alias, basestring):
- self.alias = properties[-1].target.alias(self.alias)
+ (mapper, propname) = paths[-1]
+ prop = mapper.get_property(propname, resolve_synonyms=True)
+ self.alias = prop.target.alias(self.alias)
def decorate(row):
d = {}
- for c in properties[-1].target.columns:
+ for c in prop.target.columns:
d[c] = row[self.alias.corresponding_column(c)]
return d
self.decorator = decorate
- query._attributes[("eager_row_processor", properties[-1])] = self.decorator
+ query._attributes[("eager_row_processor", paths[-1])] = self.decorator
RowDecorateOption.logger = logging.class_logger(RowDecorateOption)
diff --git a/test/orm/assorted_eager.py b/test/orm/assorted_eager.py
index dab3d5de5..ad95a0a49 100644
--- a/test/orm/assorted_eager.py
+++ b/test/orm/assorted_eager.py
@@ -6,6 +6,7 @@ from sqlalchemy import *
from sqlalchemy.orm import *
from sqlalchemy.ext.sessioncontext import SessionContext
from testlib import *
+from testlib import fixtures
class EagerTest(AssertMixin):
def setUpAll(self):
@@ -777,6 +778,83 @@ class EagerTest8(ORMTest):
for t in session.query(cls.mapper).limit(10).offset(0).list():
print t.id, t.title, t.props_cnt
+
+class EagerTest9(ORMTest):
+ """test the usage of query options to eagerly load specific paths.
+
+ this relies upon the 'path' construct used by PropertyOption to relate
+ LoaderStrategies to specific paths, as well as the path state maintained
+ throughout the query setup/mapper instances process.
+ """
+
+ def define_tables(self, metadata):
+ global accounts_table, transactions_table, entries_table
+ accounts_table = Table('accounts', metadata,
+ Column('account_id', Integer, primary_key=True),
+ Column('name', String(40)),
+ )
+ transactions_table = Table('transactions', metadata,
+ Column('transaction_id', Integer, primary_key=True),
+ Column('name', String(40)),
+ )
+ entries_table = Table('entries', metadata,
+ Column('entry_id', Integer, primary_key=True),
+ Column('name', String(40)),
+ Column('account_id', Integer, ForeignKey(accounts_table.c.account_id)),
+ Column('transaction_id', Integer, ForeignKey(transactions_table.c.transaction_id)),
+ )
+
+ def test_eagerload_on_path(self):
+ class Account(fixtures.Base):
+ pass
+
+ class Transaction(fixtures.Base):
+ pass
+
+ class Entry(fixtures.Base):
+ pass
+
+ mapper(Account, accounts_table)
+ mapper(Transaction, transactions_table)
+ mapper(Entry, entries_table, properties = dict(
+ account = relation(Account, uselist=False, backref=backref('entries', lazy=True)),
+ transaction = relation(Transaction, uselist=False, backref=backref('entries', lazy=False)),
+ ))
+
+ session = create_session()
+
+ tx1 = Transaction(name='tx1')
+ tx2 = Transaction(name='tx2')
+
+ acc1 = Account(name='acc1')
+ ent11 = Entry(name='ent11', account=acc1, transaction=tx1)
+ ent12 = Entry(name='ent12', account=acc1, transaction=tx2)
+
+ acc2 = Account(name='acc2')
+ ent21 = Entry(name='ent21', account=acc2, transaction=tx1)
+ ent22 = Entry(name='ent22', account=acc2, transaction=tx2)
+
+ session.save(acc1)
+ session.flush()
+ session.clear()
+
+ def go():
+ # load just the first Account. eager loading will actually load all objects saved thus far,
+ # but will not eagerly load the "accounts" off the immediate "entries"; only the
+ # "accounts" off the entries->transaction->entries
+ acc = session.query(Account).options(eagerload_all('entries.transaction.entries.account')).first()
+
+ # no sql occurs
+ assert acc.name == 'acc1'
+ assert acc.entries[0].transaction.entries[0].account.name == 'acc1'
+ assert acc.entries[0].transaction.entries[1].account.name == 'acc2'
+
+ # lazyload triggers but no sql occurs because many-to-one uses cached query.get()
+ for e in acc.entries:
+ assert e.account is acc
+
+ self.assert_sql_count(testbase.db, go, 1)
+
if __name__ == "__main__":
diff --git a/test/orm/eager_relations.py b/test/orm/eager_relations.py
index b851a1856..4e18dda20 100644
--- a/test/orm/eager_relations.py
+++ b/test/orm/eager_relations.py
@@ -497,6 +497,52 @@ class SelfReferentialEagerTest(ORMTest):
]) == d
self.assert_sql_count(testbase.db, go, 1)
+ def test_options(self):
+ class Node(Base):
+ def append(self, node):
+ self.children.append(node)
+
+ mapper(Node, nodes, properties={
+ 'children':relation(Node, lazy=True)
+ })
+ sess = create_session()
+ n1 = Node(data='n1')
+ n1.append(Node(data='n11'))
+ n1.append(Node(data='n12'))
+ n1.append(Node(data='n13'))
+ n1.children[1].append(Node(data='n121'))
+ n1.children[1].append(Node(data='n122'))
+ n1.children[1].append(Node(data='n123'))
+ sess.save(n1)
+ sess.flush()
+ sess.clear()
+ def go():
+ d = sess.query(Node).filter_by(data='n1').options(eagerload('children.children')).first()
+ assert Node(data='n1', children=[
+ Node(data='n11'),
+ Node(data='n12', children=[
+ Node(data='n121'),
+ Node(data='n122'),
+ Node(data='n123')
+ ]),
+ Node(data='n13')
+ ]) == d
+ self.assert_sql_count(testbase.db, go, 2)
+
+ def go():
+ d = sess.query(Node).filter_by(data='n1').options(eagerload('children.children')).first()
+
+ # test that the query isn't wrapping the initial query for eager loading.
+ # testing only sqlite for now since the query text is slightly different on other
+ # dialects
+ if testing.against('sqlite'):
+ self.assert_sql(testbase.db, go, [
+ (
+ "SELECT nodes.id AS nodes_id, nodes.parent_id AS nodes_parent_id, nodes.data AS nodes_data FROM nodes WHERE nodes.data = :nodes_data ORDER BY nodes.oid LIMIT 1 OFFSET 0",
+ {'nodes_data': 'n1'}
+ ),
+ ])
+
def test_no_depth(self):
class Node(Base):
def append(self, node):