diff options
-rw-r--r-- | CHANGES | 10 | ||||
-rw-r--r-- | lib/sqlalchemy/databases/mysql.py | 9 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/interfaces.py | 80 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/mapper.py | 16 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/query.py | 22 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/strategies.py | 83 | ||||
-rw-r--r-- | test/orm/assorted_eager.py | 78 | ||||
-rw-r--r-- | test/orm/eager_relations.py | 46 |
8 files changed, 222 insertions, 122 deletions
@@ -28,10 +28,14 @@ CHANGES synonym/deferred - fixed clear_mappers() behavior to better clean up after itself - + +- behavior of query.options() is now fully based on paths, i.e. an option + such as eagerload_all('x.y.z.y.x') will apply eagerloading to only + those paths, i.e. and not 'x.y.x'; eagerload('children.children') applies + only to exactly two-levels deep, etc. [ticket:777] + - Made access dao dection more reliable [ticket:828] - - + 0.4.0 ----- diff --git a/lib/sqlalchemy/databases/mysql.py b/lib/sqlalchemy/databases/mysql.py index de417f3a5..72bbea07f 100644 --- a/lib/sqlalchemy/databases/mysql.py +++ b/lib/sqlalchemy/databases/mysql.py @@ -1423,15 +1423,6 @@ class MySQLDialect(default.DefaultDialect): def type_descriptor(self, typeobj): return sqltypes.adapt_type(typeobj, colspecs) - def compiler(self, statement, bindparams, **kwargs): - return MySQLCompiler(statement, bindparams, dialect=self, **kwargs) - - def schemagenerator(self, *args, **kwargs): - return MySQLSchemaGenerator(self, *args, **kwargs) - - def schemadropper(self, *args, **kwargs): - return MySQLSchemaDropper(self, *args, **kwargs) - def do_executemany(self, cursor, statement, parameters, context=None): rowcount = cursor.executemany(statement, parameters) if context is not None: diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index acbb11505..7cfabb61a 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -10,7 +10,7 @@ from sqlalchemy.sql import expression __all__ = ['EXT_CONTINUE', 'EXT_STOP', 'EXT_PASS', 'MapperExtension', 'MapperProperty', 'PropComparator', 'StrategizedProperty', - 'LoaderStack', 'build_path', 'MapperOption', + 'build_path', 'MapperOption', 'ExtensionOption', 'SynonymProperty', 'PropertyOption', 'AttributeExtension', 'StrategizedOption', 'LoaderStrategy' ] @@ -467,7 +467,8 @@ class StrategizedProperty(MapperProperty): """ def _get_context_strategy(self, context): - return self._get_strategy(context.attributes.get(("loaderstrategy", self), self.strategy.__class__)) + path = context.path + return self._get_strategy(context.attributes.get(("loaderstrategy", path), self.strategy.__class__)) def _get_strategy(self, cls): try: @@ -499,35 +500,6 @@ def build_path(mapper, key, prev=None): else: return (mapper.base_mapper, key) -class LoaderStack(object): - """a stack object used during load operations to track the - current position among a chain of mappers to eager loaders.""" - - def __init__(self): - self.__stack = [] - - def push_property(self, key): - self.__stack.append(key) - return tuple(self.__stack) - - def push_mapper(self, mapper): - self.__stack.append(mapper.base_mapper) - return tuple(self.__stack) - - def pop(self): - self.__stack.pop() - - def snapshot(self): - """return an 'snapshot' of this stack. - - this is a tuple form of the stack which can be used as a hash key. - """ - - return tuple(self.__stack) - - def __str__(self): - return "->".join([str(s) for s in self.__stack]) - class MapperOption(object): """Describe a modification to a Query.""" @@ -582,26 +554,35 @@ class PropertyOption(MapperOption): self.key = key def process_query(self, query): - self.process_query_property(query, self._get_properties(query)) + if self._should_log_debug: + self.logger.debug("applying option to Query, property key '%s'" % self.key) + paths = self._get_paths(query) + if paths: + self.process_query_property(query, paths) - def process_query_property(self, query, properties): + def process_query_property(self, query, paths): pass - def _get_properties(self, query): - try: - l = self.__prop - except AttributeError: - l = [] - mapper = query.mapper - for token in self.key.split('.'): - prop = mapper.get_property(token, resolve_synonyms=True) - l.append(prop) - mapper = getattr(prop, 'mapper', None) - self.__prop = l + def _get_paths(self, query): + path = None + l = [] + current_path = list(query._current_path) + + mapper = query.mapper + for token in self.key.split('.'): + if current_path and token == current_path[1]: + current_path = current_path[2:] + continue + prop = mapper.get_property(token, resolve_synonyms=True, raiseerr=False) + if prop is None: + return [] + path = build_path(mapper, prop.key, path) + l.append(path) + mapper = getattr(prop, 'mapper', None) return l PropertyOption.logger = logging.class_logger(PropertyOption) - +PropertyOption._should_log_debug = logging.is_debug_enabled(PropertyOption.logger) class AttributeExtension(object): """An abstract class which specifies `append`, `delete`, and `set` @@ -626,13 +607,12 @@ class StrategizedOption(PropertyOption): def is_chained(self): return False - def process_query_property(self, query, properties): - self.logger.debug("applying option to Query, property key '%s'" % self.key) + def process_query_property(self, query, paths): if self.is_chained(): - for prop in properties: - query._attributes[("loaderstrategy", prop)] = self.get_strategy_class() + for path in paths: + query._attributes[("loaderstrategy", path)] = self.get_strategy_class() else: - query._attributes[("loaderstrategy", properties[-1])] = self.get_strategy_class() + query._attributes[("loaderstrategy", paths[-1])] = self.get_strategy_class() def get_strategy_class(self): raise NotImplementedError() diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 7dcbb25e1..2f3c36515 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -1477,7 +1477,7 @@ class Mapper(object): def populate_instance(self, selectcontext, instance, row, ispostselect=None, isnew=False, **flags): """populate an instance from a result row.""" - snapshot = selectcontext.stack.push_mapper(self) + snapshot = selectcontext.path + (self,) # retrieve a set of "row population" functions derived from the MapperProperties attached # to this Mapper. These are keyed in the select context based primarily off the # "snapshot" of the stack, which represents a path from the lead mapper in the query to this one, @@ -1492,14 +1492,14 @@ class Mapper(object): existing_populators = [] post_processors = [] for prop in self.__props.values(): - (newpop, existingpop, post_proc) = prop.create_row_processor(selectcontext, self, row) + (newpop, existingpop, post_proc) = selectcontext.exec_with_path(self, prop.key, prop.create_row_processor, selectcontext, self, row) if newpop is not None: - new_populators.append(newpop) + new_populators.append((prop.key, newpop)) if existingpop is not None: - existing_populators.append(existingpop) + existing_populators.append((prop.key, existingpop)) if post_proc is not None: post_processors.append(post_proc) - + poly_select_loader = self._get_poly_select_loader(selectcontext, row) if poly_select_loader is not None: post_processors.append(poly_select_loader) @@ -1512,10 +1512,8 @@ class Mapper(object): else: populators = existing_populators - for p in populators: - p(instance, row, ispostselect=ispostselect, isnew=isnew, **flags) - - selectcontext.stack.pop() + for (key, populator) in populators: + selectcontext.exec_with_path(self, key, populator, instance, row, ispostselect=ispostselect, isnew=isnew, **flags) if self.non_primary: selectcontext.attributes[('populating_mapper', instance)] = self diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index e5534e22c..bec05a43f 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -9,7 +9,6 @@ from sqlalchemy.sql import util as sql_util from sqlalchemy.sql import expression, visitors from sqlalchemy.orm import mapper, object_mapper from sqlalchemy.orm import util as mapperutil -from sqlalchemy.orm.interfaces import LoaderStack import operator __all__ = ['Query', 'QueryContext'] @@ -48,6 +47,7 @@ class Query(object): self._autoflush = True self._eager_loaders = util.Set([x for x in self.mapper._eager_loaders]) self._attributes = {} + self._current_path = () def _clone(self): q = Query.__new__(Query) @@ -64,6 +64,11 @@ class Query(object): primary_key_columns = property(lambda s:s.select_mapper.primary_key) session = property(_get_session) + def _with_current_path(self, path): + q = self._clone() + q._current_path = path + return q + def get(self, ident, **kwargs): """Return an instance of the object based on the given identifier, or None if not found. @@ -870,7 +875,7 @@ class Query(object): # TODO: doing this off the select_mapper. if its the polymorphic mapper, then # it has no relations() on it. should we compile those too into the query ? (i.e. eagerloads) for value in self.select_mapper.iterate_properties: - value.setup(context) + context.exec_with_path(self.select_mapper, value.key, value.setup, context) # additional entities/columns, add those to selection criterion for tup in self._entities: @@ -878,7 +883,7 @@ class Query(object): clauses = self._get_entity_clauses(tup) if isinstance(m, mapper.Mapper): for value in m.iterate_properties: - value.setup(context, parentclauses=clauses) + context.exec_with_path(self.select_mapper, value.key, value.setup, context, parentclauses=clauses) elif isinstance(m, sql.ColumnElement): if clauses is not None: m = clauses.adapt_clause(m) @@ -1158,9 +1163,16 @@ class QueryContext(object): self.populate_existing = query._populate_existing self.version_check = query._version_check self.identity_map = {} - self.stack = LoaderStack() + self.path = () self.options = query._with_options self.attributes = query._attributes.copy() - + + def exec_with_path(self, mapper, propkey, func, *args, **kwargs): + oldpath = self.path + self.path += (mapper.base_mapper, propkey) + try: + return func(*args, **kwargs) + finally: + self.path = oldpath diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index e5a757bf4..757021cd0 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -313,7 +313,7 @@ class LazyLoader(AbstractRelationLoader): bindparam.value = mapper.get_attr_by_column(instance, bind_to_col[bindparam.key]) return Visitor().traverse(criterion, clone=True) - def setup_loader(self, instance, options=None): + def setup_loader(self, instance, options=None, path=None): if not mapper.has_mapper(instance): return None else: @@ -342,6 +342,8 @@ class LazyLoader(AbstractRelationLoader): # if we have a simple straight-primary key load, use mapper.get() # to possibly save a DB round trip q = session.query(self.mapper).autoflush(False) + if path: + q = q._with_current_path(path) if self.use_get: params = {} for col, bind in self.lazybinds.iteritems(): @@ -387,7 +389,8 @@ class LazyLoader(AbstractRelationLoader): self.logger.debug("set instance-level lazy loader on %s" % mapperutil.attribute_str(instance, self.key)) # we are not the primary manager for this attribute on this class - set up a per-instance lazyloader, # which will override the class-level behavior - self._init_instance_attribute(instance, callable_=self.setup_loader(instance, selectcontext.options)) + + self._init_instance_attribute(instance, callable_=self.setup_loader(instance, selectcontext.options, selectcontext.query._current_path + selectcontext.path)) return (new_execute, None, None) else: def new_execute(instance, row, ispostselect, **flags): @@ -479,24 +482,19 @@ class EagerLoader(AbstractRelationLoader): def setup_query(self, context, parentclauses=None, parentmapper=None, **kwargs): """Add a left outer join to the statement thats being constructed.""" - # build a path as we setup the query. the format of this path - # matches that of interfaces.LoaderStack, and will be used in the - # row-loading phase to match up AliasedClause objects with the current - # LoaderStack position. - if parentclauses: - path = parentclauses.path + (self.parent.base_mapper, self.key) - else: - path = (self.parent.base_mapper, self.key) + path = context.path - if self.join_depth: - if len(path) / 2 > self.join_depth: - return - else: - if self.mapper.base_mapper in path: - return + # check for join_depth or basic recursion, + # if the current path was not explicitly stated as + # a desired "loaderstrategy" (i.e. via query.options()) + if ("loaderstrategy", path) not in context.attributes: + if self.join_depth: + if len(path) / 2 > self.join_depth: + return + else: + if self.mapper.base_mapper in path: + return - #print "CREATING EAGER PATH FOR", "->".join([str(s) for s in path]) - if parentmapper is None: localparent = context.mapper else: @@ -550,7 +548,7 @@ class EagerLoader(AbstractRelationLoader): statement.append_from(statement._outerjoin) for value in self.select_mapper.iterate_properties: - value.setup(context, parentclauses=clauses, parentmapper=self.select_mapper) + context.exec_with_path(self.select_mapper, value.key, value.setup, context, parentclauses=clauses, parentmapper=self.select_mapper) def _create_row_decorator(self, selectcontext, row, path): """Create a *row decorating* function that will apply eager @@ -562,19 +560,10 @@ class EagerLoader(AbstractRelationLoader): #print "creating row decorator for path ", "->".join([str(s) for s in path]) - # check for a user-defined decorator in the SelectContext (which was set up by the contains_eager() option) - if ("eager_row_processor", self.parent_property) in selectcontext.attributes: - # custom row decoration function, placed in the selectcontext by the - # contains_eager() mapper option - decorator = selectcontext.attributes[("eager_row_processor", self.parent_property)] - # key was present, but no decorator; therefore just use the row as is + if ("eager_row_processor", path) in selectcontext.attributes: + decorator = selectcontext.attributes[("eager_row_processor", path)] if decorator is None: decorator = lambda row: row - # check for an AliasedClauses row decorator that was set up by query._compile_context(). - # a further refactoring (described in [ticket:777]) will simplify this so that the - # contains_eager() option generates the same key as this one - elif ("eager_row_processor", path) in selectcontext.attributes: - decorator = selectcontext.attributes[("eager_row_processor", path)] else: if self._should_log_debug: self.logger.debug("Could not locate aliased clauses for key: " + str(path)) @@ -593,15 +582,12 @@ class EagerLoader(AbstractRelationLoader): return None def create_row_processor(self, selectcontext, mapper, row): - path = selectcontext.stack.push_property(self.key) - row_decorator = self._create_row_decorator(selectcontext, row, path) + row_decorator = self._create_row_decorator(selectcontext, row, selectcontext.path) if row_decorator is not None: def execute(instance, row, isnew, **flags): decorated_row = row_decorator(row) - selectcontext.stack.push_property(self.key) - if not self.uselist: if self._should_log_debug: self.logger.debug("eagerload scalar instance on %s" % mapperutil.attribute_str(instance, self.key)) @@ -633,9 +619,7 @@ class EagerLoader(AbstractRelationLoader): self.logger.debug("eagerload list instance on %s" % mapperutil.attribute_str(instance, self.key)) self.select_mapper._instance(selectcontext, decorated_row, result_list) - selectcontext.stack.pop() - selectcontext.stack.pop() if self._should_log_debug: self.logger.debug("Returning eager instance loader for %s" % str(self)) @@ -644,7 +628,6 @@ class EagerLoader(AbstractRelationLoader): else: if self._should_log_debug: self.logger.debug("eager loader %s degrading to lazy loader" % str(self)) - selectcontext.stack.pop() return self.parent_property._get_strategy(LazyLoader).create_row_processor(selectcontext, mapper, row) @@ -662,13 +645,19 @@ class EagerLazyOption(StrategizedOption): def is_chained(self): return not self.lazy and self.chained - def process_query_property(self, query, properties): + def process_query_property(self, query, paths): if self.lazy: - if properties[-1] in query._eager_loaders: - query._eager_loaders = query._eager_loaders.difference(util.Set([properties[-1]])) + if paths[-1] in query._eager_loaders: + query._eager_loaders = query._eager_loaders.difference(util.Set([paths[-1]])) else: - query._eager_loaders = query._eager_loaders.union(util.Set(properties)) - super(EagerLazyOption, self).process_query_property(query, properties) + if not self.chained: + paths = [paths[-1]] + res = util.Set() + for path in paths: + if len(path) - len(query._current_path) == 2: + res.add(path) + query._eager_loaders = query._eager_loaders.union(res) + super(EagerLazyOption, self).process_query_property(query, paths) def get_strategy_class(self): if self.lazy: @@ -702,17 +691,19 @@ class RowDecorateOption(PropertyOption): self.decorator = decorator self.alias = alias - def process_query_property(self, query, properties): + def process_query_property(self, query, paths): if self.alias is not None and self.decorator is None: if isinstance(self.alias, basestring): - self.alias = properties[-1].target.alias(self.alias) + (mapper, propname) = paths[-1] + prop = mapper.get_property(propname, resolve_synonyms=True) + self.alias = prop.target.alias(self.alias) def decorate(row): d = {} - for c in properties[-1].target.columns: + for c in prop.target.columns: d[c] = row[self.alias.corresponding_column(c)] return d self.decorator = decorate - query._attributes[("eager_row_processor", properties[-1])] = self.decorator + query._attributes[("eager_row_processor", paths[-1])] = self.decorator RowDecorateOption.logger = logging.class_logger(RowDecorateOption) diff --git a/test/orm/assorted_eager.py b/test/orm/assorted_eager.py index dab3d5de5..ad95a0a49 100644 --- a/test/orm/assorted_eager.py +++ b/test/orm/assorted_eager.py @@ -6,6 +6,7 @@ from sqlalchemy import * from sqlalchemy.orm import * from sqlalchemy.ext.sessioncontext import SessionContext from testlib import * +from testlib import fixtures class EagerTest(AssertMixin): def setUpAll(self): @@ -777,6 +778,83 @@ class EagerTest8(ORMTest): for t in session.query(cls.mapper).limit(10).offset(0).list(): print t.id, t.title, t.props_cnt + +class EagerTest9(ORMTest): + """test the usage of query options to eagerly load specific paths. + + this relies upon the 'path' construct used by PropertyOption to relate + LoaderStrategies to specific paths, as well as the path state maintained + throughout the query setup/mapper instances process. + """ + + def define_tables(self, metadata): + global accounts_table, transactions_table, entries_table + accounts_table = Table('accounts', metadata, + Column('account_id', Integer, primary_key=True), + Column('name', String(40)), + ) + transactions_table = Table('transactions', metadata, + Column('transaction_id', Integer, primary_key=True), + Column('name', String(40)), + ) + entries_table = Table('entries', metadata, + Column('entry_id', Integer, primary_key=True), + Column('name', String(40)), + Column('account_id', Integer, ForeignKey(accounts_table.c.account_id)), + Column('transaction_id', Integer, ForeignKey(transactions_table.c.transaction_id)), + ) + + def test_eagerload_on_path(self): + class Account(fixtures.Base): + pass + + class Transaction(fixtures.Base): + pass + + class Entry(fixtures.Base): + pass + + mapper(Account, accounts_table) + mapper(Transaction, transactions_table) + mapper(Entry, entries_table, properties = dict( + account = relation(Account, uselist=False, backref=backref('entries', lazy=True)), + transaction = relation(Transaction, uselist=False, backref=backref('entries', lazy=False)), + )) + + session = create_session() + + tx1 = Transaction(name='tx1') + tx2 = Transaction(name='tx2') + + acc1 = Account(name='acc1') + ent11 = Entry(name='ent11', account=acc1, transaction=tx1) + ent12 = Entry(name='ent12', account=acc1, transaction=tx2) + + acc2 = Account(name='acc2') + ent21 = Entry(name='ent21', account=acc2, transaction=tx1) + ent22 = Entry(name='ent22', account=acc2, transaction=tx2) + + session.save(acc1) + session.flush() + session.clear() + + def go(): + # load just the first Account. eager loading will actually load all objects saved thus far, + # but will not eagerly load the "accounts" off the immediate "entries"; only the + # "accounts" off the entries->transaction->entries + acc = session.query(Account).options(eagerload_all('entries.transaction.entries.account')).first() + + # no sql occurs + assert acc.name == 'acc1' + assert acc.entries[0].transaction.entries[0].account.name == 'acc1' + assert acc.entries[0].transaction.entries[1].account.name == 'acc2' + + # lazyload triggers but no sql occurs because many-to-one uses cached query.get() + for e in acc.entries: + assert e.account is acc + + self.assert_sql_count(testbase.db, go, 1) + if __name__ == "__main__": diff --git a/test/orm/eager_relations.py b/test/orm/eager_relations.py index b851a1856..4e18dda20 100644 --- a/test/orm/eager_relations.py +++ b/test/orm/eager_relations.py @@ -497,6 +497,52 @@ class SelfReferentialEagerTest(ORMTest): ]) == d self.assert_sql_count(testbase.db, go, 1) + def test_options(self): + class Node(Base): + def append(self, node): + self.children.append(node) + + mapper(Node, nodes, properties={ + 'children':relation(Node, lazy=True) + }) + sess = create_session() + n1 = Node(data='n1') + n1.append(Node(data='n11')) + n1.append(Node(data='n12')) + n1.append(Node(data='n13')) + n1.children[1].append(Node(data='n121')) + n1.children[1].append(Node(data='n122')) + n1.children[1].append(Node(data='n123')) + sess.save(n1) + sess.flush() + sess.clear() + def go(): + d = sess.query(Node).filter_by(data='n1').options(eagerload('children.children')).first() + assert Node(data='n1', children=[ + Node(data='n11'), + Node(data='n12', children=[ + Node(data='n121'), + Node(data='n122'), + Node(data='n123') + ]), + Node(data='n13') + ]) == d + self.assert_sql_count(testbase.db, go, 2) + + def go(): + d = sess.query(Node).filter_by(data='n1').options(eagerload('children.children')).first() + + # test that the query isn't wrapping the initial query for eager loading. + # testing only sqlite for now since the query text is slightly different on other + # dialects + if testing.against('sqlite'): + self.assert_sql(testbase.db, go, [ + ( + "SELECT nodes.id AS nodes_id, nodes.parent_id AS nodes_parent_id, nodes.data AS nodes_data FROM nodes WHERE nodes.data = :nodes_data ORDER BY nodes.oid LIMIT 1 OFFSET 0", + {'nodes_data': 'n1'} + ), + ]) + def test_no_depth(self): class Node(Base): def append(self, node): |