diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2012-02-09 21:16:53 -0500 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2012-02-09 21:16:53 -0500 |
commit | bc45fa350a02da5f24d866078abed471cd98f15b (patch) | |
tree | 2607af2197e003fdc735c020207d4f234d718fee /lib/sqlalchemy | |
parent | 91f4109dc3ec49686ba2393eb6b7bd9bb5b95fb3 (diff) | |
download | sqlalchemy-bc45fa350a02da5f24d866078abed471cd98f15b.tar.gz |
- got m2m, local_remote_pairs, etc. working
- using new traversal that returns the product of both sides
of a binary, starting to work with (a+b) == (c+d) types of joins.
primaryjoins on functions working
- annotations working, including reversing local/remote when
doing backref
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r-- | lib/sqlalchemy/orm/__init__.py | 8 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/properties.py | 69 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/relationships.py | 447 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/strategies.py | 4 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/util.py | 13 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/expression.py | 24 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/util.py | 76 |
7 files changed, 294 insertions, 347 deletions
diff --git a/lib/sqlalchemy/orm/__init__.py b/lib/sqlalchemy/orm/__init__.py index 9fd969e3b..13bd18f08 100644 --- a/lib/sqlalchemy/orm/__init__.py +++ b/lib/sqlalchemy/orm/__init__.py @@ -44,6 +44,11 @@ from sqlalchemy.orm.properties import ( PropertyLoader, SynonymProperty, ) +from sqlalchemy.orm.relationships import ( + foreign, + remote, + remote_foreign +) from sqlalchemy.orm import mapper as mapperlib from sqlalchemy.orm.mapper import reconstructor, validates from sqlalchemy.orm import strategies @@ -81,6 +86,7 @@ __all__ = ( 'dynamic_loader', 'eagerload', 'eagerload_all', + 'foreign', 'immediateload', 'join', 'joinedload', @@ -96,6 +102,8 @@ __all__ = ( 'reconstructor', 'relationship', 'relation', + 'remote', + 'remote_foreign', 'scoped_session', 'sessionmaker', 'subqueryload', diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index 77da33b5f..38237b2d4 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -16,7 +16,7 @@ from sqlalchemy.sql.util import ClauseAdapter, criterion_as_pairs, \ join_condition, _shallow_annotate from sqlalchemy.sql import operators, expression, visitors from sqlalchemy.orm import attributes, dependency, mapper, \ - object_mapper, strategies, configure_mappers + object_mapper, strategies, configure_mappers, relationships from sqlalchemy.orm.util import CascadeOptions, _class_to_mapper, \ _orm_annotate, _orm_deannotate @@ -915,13 +915,12 @@ class RelationshipProperty(StrategizedProperty): self._check_conflicts() self._process_dependent_arguments() self._setup_join_conditions() - self._extra_determine_direction() + self._check_cascade_settings() self._post_init() self._generate_backref() super(RelationshipProperty, self).do_init() def _setup_join_conditions(self): - import relationships self._join_condition = jc = relationships.JoinCondition( parent_selectable=self.parent.mapped_table, child_selectable=self.mapper.mapped_table, @@ -946,8 +945,8 @@ class RelationshipProperty(StrategizedProperty): self.local_remote_pairs = jc.local_remote_pairs self.remote_side = jc.remote_columns self.synchronize_pairs = jc.synchronize_pairs - self.secondary_synchronize_pairs = jc.secondary_synchronize_pairs self._calculated_foreign_keys = jc.foreign_key_columns + self.secondary_synchronize_pairs = jc.secondary_synchronize_pairs def _check_conflicts(self): """Test that this relationship is legal, warn about @@ -1035,7 +1034,7 @@ class RelationshipProperty(StrategizedProperty): (self.key, self.parent.class_) ) - def _extra_determine_direction(self): + def _check_cascade_settings(self): if self.cascade.delete_orphan and not self.single_parent \ and (self.direction is MANYTOMANY or self.direction is MANYTOONE): @@ -1064,7 +1063,6 @@ class RelationshipProperty(StrategizedProperty): return False return True - def _generate_backref(self): if not self.is_primary(): return @@ -1083,13 +1081,15 @@ class RelationshipProperty(StrategizedProperty): pj = kwargs.pop('primaryjoin', self.secondaryjoin) sj = kwargs.pop('secondaryjoin', self.primaryjoin) else: - pj = kwargs.pop('primaryjoin', self.primaryjoin) + pj = kwargs.pop('primaryjoin', + self._join_condition.primaryjoin_reverse_remote) sj = kwargs.pop('secondaryjoin', None) if sj: raise sa_exc.InvalidRequestError( "Can't assign 'secondaryjoin' on a backref against " "a non-secondary relationship." ) + foreign_keys = kwargs.pop('foreign_keys', self._user_defined_foreign_keys) parent = self.parent.primary_mapper() @@ -1112,21 +1112,6 @@ class RelationshipProperty(StrategizedProperty): self._add_reverse_property(self.back_populates) def _post_init(self): - self.logger.info('%s setup primary join %s', self, - self.primaryjoin) - self.logger.info('%s setup secondary join %s', self, - self.secondaryjoin) - self.logger.info('%s synchronize pairs [%s]', self, - ','.join('(%s => %s)' % (l, r) for (l, r) in - self.synchronize_pairs)) - self.logger.info('%s secondary synchronize pairs [%s]', self, - ','.join('(%s => %s)' % (l, r) for (l, r) in - self.secondary_synchronize_pairs or [])) - self.logger.info('%s local/remote pairs [%s]', self, - ','.join('(%s / %s)' % (l, r) for (l, r) in - self.local_remote_pairs)) - self.logger.info('%s relationship direction %s', self, - self.direction) if self.uselist is None: self.uselist = self.direction is not MANYTOONE if not self.viewonly: @@ -1141,46 +1126,6 @@ class RelationshipProperty(StrategizedProperty): strategy = self._get_strategy(strategies.LazyLoader) return strategy.use_get - def _refers_to_parent_table(self): - alt = self._alt_refers_to_parent_table() - pt = self.parent.mapped_table - mt = self.mapper.mapped_table - for c, f in self.synchronize_pairs: - if ( - pt.is_derived_from(c.table) and \ - pt.is_derived_from(f.table) and \ - mt.is_derived_from(c.table) and \ - mt.is_derived_from(f.table) - ): - assert alt - return True - else: - assert not alt - return False - - def _alt_refers_to_parent_table(self): - pt = self.parent.mapped_table - mt = self.mapper.mapped_table - result = [False] - def visit_binary(binary): - c, f = binary.left, binary.right - if ( - isinstance(c, expression.ColumnClause) and \ - isinstance(f, expression.ColumnClause) and \ - pt.is_derived_from(c.table) and \ - pt.is_derived_from(f.table) and \ - mt.is_derived_from(c.table) and \ - mt.is_derived_from(f.table) - ): - result[0] = True - - visitors.traverse( - self.primaryjoin, - {}, - {"binary":visit_binary} - ) - return result[0] - @util.memoized_property def _is_self_referential(self): return self.mapper.common_parent(self.parent) diff --git a/lib/sqlalchemy/orm/relationships.py b/lib/sqlalchemy/orm/relationships.py index 723d45295..d8c2659b6 100644 --- a/lib/sqlalchemy/orm/relationships.py +++ b/lib/sqlalchemy/orm/relationships.py @@ -15,7 +15,7 @@ and `secondaryjoin` aspects of :func:`.relationship`. from sqlalchemy import sql, util, log, exc as sa_exc, schema from sqlalchemy.sql.util import ClauseAdapter, criterion_as_pairs, \ - join_condition, _shallow_annotate + join_condition, _shallow_annotate, visit_binary_product from sqlalchemy.sql import operators, expression, visitors from sqlalchemy.orm.interfaces import MANYTOMANY, MANYTOONE, ONETOMANY @@ -78,7 +78,34 @@ class JoinCondition(object): self._determine_joins() self._annotate_fks() self._annotate_remote() + self._annotate_local() self._determine_direction() + self._setup_pairs() + self._check_foreign_cols(self.primaryjoin, True) + if self.secondaryjoin is not None: + self._check_foreign_cols(self.secondaryjoin, False) + self._check_remote_side() + self._log_joins() + + def _log_joins(self): + if self.prop is None: + return + log = self.prop.logger + log.info('%s setup primary join %s', self, + self.primaryjoin) + log.info('%s setup secondary join %s', self, + self.secondaryjoin) + log.info('%s synchronize pairs [%s]', self, + ','.join('(%s => %s)' % (l, r) for (l, r) in + self.synchronize_pairs)) + log.info('%s secondary synchronize pairs [%s]', self, + ','.join('(%s => %s)' % (l, r) for (l, r) in + self.secondary_synchronize_pairs or [])) + log.info('%s local/remote pairs [%s]', self, + ','.join('(%s / %s)' % (l, r) for (l, r) in + self.local_remote_pairs)) + log.info('%s relationship direction %s', self, + self.direction) def _determine_joins(self): """Determine the 'primaryjoin' and 'secondaryjoin' attributes, @@ -128,28 +155,60 @@ class JoinCondition(object): "'secondaryjoin' is needed as well." % self.prop) + @util.memoized_property + def primaryjoin_reverse_remote(self): + def replace(element): + if "remote" in element._annotations: + v = element._annotations.copy() + del v['remote'] + v['local'] = True + return element._with_annotations(v) + elif "local" in element._annotations: + v = element._annotations.copy() + del v['local'] + v['remote'] = True + return element._with_annotations(v) + return visitors.replacement_traverse(self.primaryjoin, {}, replace) + + def _has_annotation(self, clause, annotation): + for col in visitors.iterate(clause, {}): + if annotation in col._annotations: + return True + else: + return False + def _annotate_fks(self): + if self._has_annotation(self.primaryjoin, "foreign"): + return + + if self.consider_as_foreign_keys: + self._annotate_from_fk_list() + else: + self._annotate_present_fks() + + def _annotate_from_fk_list(self): + def check_fk(col): + if col in self.consider_as_foreign_keys: + return col._annotate({"foreign":True}) + self.primaryjoin = visitors.replacement_traverse( + self.primaryjoin, + {}, + check_fk + ) + if self.secondaryjoin is not None: + self.secondaryjoin = visitors.replacement_traverse( + self.secondaryjoin, + {}, + check_fk + ) + + def _annotate_present_fks(self): if self.secondary is not None: secondarycols = util.column_set(self.secondary.c) else: secondarycols = set() - def col_is(a, b): - return a.compare(b) - def is_foreign(a, b): - if self.consider_as_foreign_keys: - if a in self.consider_as_foreign_keys and ( - col_is(a, b) or - b not in self.consider_as_foreign_keys - ): - return a - elif b in self.consider_as_foreign_keys and ( - col_is(a, b) or - a not in self.consider_as_foreign_keys - ): - return b - if isinstance(a, schema.Column) and \ isinstance(b, schema.Column): if a.references(b): @@ -163,19 +222,6 @@ class JoinCondition(object): elif b in secondarycols and a not in secondarycols: return b - def _annotate_fk(binary, left, right): - can_be_synced = self.can_be_synced_fn(left) - left = left._annotate({ - #"equated":binary.operator is operators.eq, - "can_be_synced":can_be_synced and \ - binary.operator is operators.eq - }) - right = right._annotate({ - #"equated":binary.operator is operators.eq, - "referent":True - }) - return left, right - def visit_binary(binary): if not isinstance(binary.left, sql.ColumnElement) or \ not isinstance(binary.right, sql.ColumnElement): @@ -185,20 +231,12 @@ class JoinCondition(object): "foreign" not in binary.right._annotations: col = is_foreign(binary.left, binary.right) if col is not None: - if col is binary.left: + if col.compare(binary.left): binary.left = binary.left._annotate( {"foreign":True}) - elif col is binary.right: + elif col.compare(binary.right): binary.right = binary.right._annotate( {"foreign":True}) - # TODO: when the two cols are the same. - - if "foreign" in binary.left._annotations: - binary.left, binary.right = _annotate_fk( - binary, binary.left, binary.right) - if "foreign" in binary.right._annotations: - binary.right, binary.left = _annotate_fk( - binary, binary.right, binary.left) self.primaryjoin = visitors.cloned_traverse( self.primaryjoin, @@ -211,11 +249,6 @@ class JoinCondition(object): {}, {"binary":visit_binary} ) - self._check_foreign_cols( - self.primaryjoin, True) - if self.secondaryjoin is not None: - self._check_foreign_cols( - self.secondaryjoin, False) def _refers_to_parent_table(self): pt = self.parent_selectable @@ -241,18 +274,14 @@ class JoinCondition(object): return result[0] def _annotate_remote(self): - parentcols = util.column_set(self.parent_selectable.c) + if self._has_annotation(self.primaryjoin, "remote"): + return - for col in visitors.iterate(self.primaryjoin, {}): - if "remote" in col._annotations: - has_remote_annotations = True - break - else: - has_remote_annotations = False + parentcols = util.column_set(self.parent_selectable.c) def _annotate_selfref(fn): def visit_binary(binary): - equated = binary.left is binary.right + equated = binary.left.compare(binary.right) if isinstance(binary.left, sql.ColumnElement) and \ isinstance(binary.right, sql.ColumnElement): # assume one to many - FKs are "remote" @@ -267,44 +296,72 @@ class JoinCondition(object): self.primaryjoin, {}, {"binary":visit_binary}) - if not has_remote_annotations: + if self.secondary is not None: + def repl(element): + if self.secondary.c.contains_column(element): + return element._annotate({"remote":True}) + self.primaryjoin = visitors.replacement_traverse( + self.primaryjoin, {}, repl) + self.secondaryjoin = visitors.replacement_traverse( + self.secondaryjoin, {}, repl) + elif self._local_remote_pairs or self._remote_side: + if self._local_remote_pairs: - raise NotImplementedError() - elif self._remote_side: - if self._refers_to_parent_table(): - _annotate_selfref(lambda col:col in self._remote_side) - else: - def repl(element): - if element in self._remote_side: - return element._annotate({"remote":True}) - self.primaryjoin = visitors.replacement_traverse( - self.primaryjoin, {}, repl) - elif self.secondary is not None: - def repl(element): - if self.secondary.c.contains_column(element): - return element._annotate({"remote":True}) - self.primaryjoin = visitors.replacement_traverse( - self.primaryjoin, {}, repl) - self.secondaryjoin = visitors.replacement_traverse( - self.secondaryjoin, {}, repl) - elif self._refers_to_parent_table(): - _annotate_selfref(lambda col:"foreign" in col._annotations) + if self._remote_side: + raise sa_exc.ArgumentError( + "remote_side argument is redundant " + "against more detailed _local_remote_side " + "argument.") + + remote_side = [r for (l, r) in self._local_remote_pairs] + else: + remote_side = self._remote_side + + if self._refers_to_parent_table(): + _annotate_selfref(lambda col:col in remote_side) else: def repl(element): - if self.child_selectable.c.contains_column(element): + if element in remote_side: return element._annotate({"remote":True}) - self.primaryjoin = visitors.replacement_traverse( self.primaryjoin, {}, repl) + elif self._refers_to_parent_table(): + _annotate_selfref(lambda col:"foreign" in col._annotations) + else: + def repl(element): + if self.child_selectable.c.contains_column(element): + return element._annotate({"remote":True}) + + self.primaryjoin = visitors.replacement_traverse( + self.primaryjoin, {}, repl) + + def _annotate_local(self): + if self._has_annotation(self.primaryjoin, "local"): + return + + parentcols = util.column_set(self.parent_selectable.c) + + if self._local_remote_pairs: + local_side = util.column_set([l for (l, r) + in self._local_remote_pairs]) + else: + local_side = util.column_set(self.parent_selectable.c) def locals_(elem): if "remote" not in elem._annotations and \ - elem in parentcols: + elem in local_side: return elem._annotate({"local":True}) self.primaryjoin = visitors.replacement_traverse( self.primaryjoin, {}, locals_ ) + def _check_remote_side(self): + if not self.local_remote_pairs: + raise sa_exc.ArgumentError('Relationship %s could ' + 'not determine any local/remote column ' + 'pairs from remote side argument %r' + % (self.prop, self._remote_side)) + def _check_foreign_cols(self, join_condition, primary): """Check the foreign key columns collected and emit error messages.""" @@ -315,11 +372,10 @@ class JoinCondition(object): has_foreign = bool(foreign_cols) - if self.support_sync: - for col in foreign_cols: - if col._annotations.get("can_be_synced"): - can_sync = True - break + if primary: + can_sync = bool(self.synchronize_pairs) + else: + can_sync = bool(self.secondary_synchronize_pairs) if self.support_sync and can_sync or \ (not self.support_sync and has_foreign): @@ -407,6 +463,44 @@ class JoinCondition(object): "key columns are present in neither the parent " "nor the child's mapped tables" % self.prop) + def _setup_pairs(self): + sync_pairs = [] + lrp = util.OrderedSet([]) + secondary_sync_pairs = [] + + def go(joincond, collection): + def visit_binary(binary, left, right): + if "remote" in right._annotations and \ + "remote" not in left._annotations and \ + self.can_be_synced_fn(left): + lrp.add((left, right)) + elif "remote" in left._annotations and \ + "remote" not in right._annotations and \ + self.can_be_synced_fn(right): + lrp.add((right, left)) + if binary.operator is operators.eq: + # and \ + #binary.left.compare(left) and \ + #binary.right.compare(right): + if "foreign" in right._annotations: + collection.append((left, right)) + elif "foreign" in left._annotations: + collection.append((right, left)) + visit_binary_product(visit_binary, joincond) + + for joincond, collection in [ + (self.primaryjoin, sync_pairs), + (self.secondaryjoin, secondary_sync_pairs) + ]: + if joincond is None: + continue + go(joincond, collection) + + self.local_remote_pairs = list(lrp) + self.synchronize_pairs = sync_pairs + self.secondary_synchronize_pairs = secondary_sync_pairs + + @util.memoized_property def remote_columns(self): return self._gather_join_annotations("remote") @@ -416,38 +510,6 @@ class JoinCondition(object): return self._gather_join_annotations("local") @util.memoized_property - def synchronize_pairs(self): - parentcols = util.column_set(self.parent_selectable.c) - targetcols = util.column_set(self.child_selectable.c) - result = [] - for l, r in self.local_remote_pairs: - if self.secondary is not None: - if "foreign" in r._annotations and \ - l in parentcols: - result.append((l, r)) - elif "foreign" in r._annotations and \ - "can_be_synced" in r._annotations: - result.append((l, r)) - elif "foreign" in l._annotations and \ - "can_be_synced" in l._annotations: - result.append((r, l)) - return result - - @util.memoized_property - def secondary_synchronize_pairs(self): - parentcols = util.column_set(self.parent_selectable.c) - targetcols = util.column_set(self.child_selectable.c) - result = [] - if self.secondary is None: - return result - - for l, r in self.local_remote_pairs: - if "foreign" in l._annotations and \ - r in targetcols: - result.append((l, r)) - return result - - @util.memoized_property def foreign_key_columns(self): return self._gather_join_annotations("foreign") @@ -470,24 +532,6 @@ class JoinCondition(object): if annotation.issubset(col._annotations) ]) - @util.memoized_property - def local_remote_pairs(self): - lrp = util.OrderedSet() - def visit_binary(binary): - if "remote" in binary.right._annotations and \ - "remote" not in binary.left._annotations and \ - isinstance(binary.left, expression.ColumnClause) and \ - self.can_be_synced_fn(binary.left): - lrp.add((binary.left, binary.right)) - elif "remote" in binary.left._annotations and \ - "remote" not in binary.right._annotations and \ - isinstance(binary.right, expression.ColumnClause) and \ - self.can_be_synced_fn(binary.right): - lrp.add((binary.right, binary.left)) - visitors.traverse(self.primaryjoin, {}, {"binary":visit_binary}) - if self.secondaryjoin is not None: - visitors.traverse(self.secondaryjoin, {}, {"binary":visit_binary}) - return list(lrp) def join_targets(self, source_selectable, dest_selectable, @@ -604,147 +648,6 @@ def _create_lazy_clause(cls, prop, reverse_direction=False): return lazywhere, bind_to_col, equated_columns -def _determine_synchronize_pairs(self): - """Resolve 'primary'/foreign' column pairs from the primaryjoin - and secondaryjoin arguments. - - """ - if self.local_remote_pairs: - if not self._user_defined_foreign_keys: - raise sa_exc.ArgumentError( - "foreign_keys argument is " - "required with _local_remote_pairs argument") - self.synchronize_pairs = [] - for l, r in self.local_remote_pairs: - if r in self._user_defined_foreign_keys: - self.synchronize_pairs.append((l, r)) - elif l in self._user_defined_foreign_keys: - self.synchronize_pairs.append((r, l)) - else: - self.synchronize_pairs = self._sync_pairs_from_join( - self.primaryjoin, - True) - - self._calculated_foreign_keys = util.column_set( - r for (l, r) in - self.synchronize_pairs) - - if self.secondaryjoin is not None: - self.secondary_synchronize_pairs = self._sync_pairs_from_join( - self.secondaryjoin, - False) - self._calculated_foreign_keys.update( - r for (l, r) in - self.secondary_synchronize_pairs) - else: - self.secondary_synchronize_pairs = None - - -def _determine_local_remote_pairs(self): - """Determine pairs of columns representing "local" to - "remote", where "local" columns are on the parent mapper, - "remote" are on the target mapper. - - These pairs are used on the load side only to generate - lazy loading clauses. - - """ - if not self.local_remote_pairs and not self.remote_side: - # the most common, trivial case. Derive - # local/remote pairs from the synchronize pairs. - eq_pairs = util.unique_list( - self.synchronize_pairs + - (self.secondary_synchronize_pairs or [])) - if self.direction is MANYTOONE: - self.local_remote_pairs = [(r, l) for l, r in eq_pairs] - else: - self.local_remote_pairs = eq_pairs - - # "remote_side" specified, derive from the primaryjoin - # plus remote_side, similarly to how synchronize_pairs - # were determined. - elif self.remote_side: - if self.local_remote_pairs: - raise sa_exc.ArgumentError('remote_side argument is ' - 'redundant against more detailed ' - '_local_remote_side argument.') - if self.direction is MANYTOONE: - self.local_remote_pairs = [(r, l) for (l, r) in - criterion_as_pairs(self.primaryjoin, - consider_as_referenced_keys=self.remote_side, - any_operator=True)] - - else: - self.local_remote_pairs = \ - criterion_as_pairs(self.primaryjoin, - consider_as_foreign_keys=self.remote_side, - any_operator=True) - if not self.local_remote_pairs: - raise sa_exc.ArgumentError('Relationship %s could ' - 'not determine any local/remote column ' - 'pairs from remote side argument %r' - % (self, self.remote_side)) - # else local_remote_pairs were sent explcitly via - # ._local_remote_pairs. - - # create local_side/remote_side accessors - self.local_side = util.ordered_column_set( - l for l, r in self.local_remote_pairs) - self.remote_side = util.ordered_column_set( - r for l, r in self.local_remote_pairs) - - # check that the non-foreign key column in the local/remote - # collection is mapped. The foreign key - # which the individual mapped column references directly may - # itself be in a non-mapped table; see - # test.orm.test_relationships.ViewOnlyComplexJoin.test_basic - # for an example of this. - if self.direction is ONETOMANY: - for col in self.local_side: - if not self._columns_are_mapped(col): - raise sa_exc.ArgumentError( - "Local column '%s' is not " - "part of mapping %s. Specify remote_side " - "argument to indicate which column lazy join " - "condition should compare against." % (col, - self.parent)) - elif self.direction is MANYTOONE: - for col in self.remote_side: - if not self._columns_are_mapped(col): - raise sa_exc.ArgumentError( - "Remote column '%s' is not " - "part of mapping %s. Specify remote_side " - "argument to indicate which column lazy join " - "condition should bind." % (col, self.mapper)) - - count = [0] - def clone(elem): - if set(['local', 'remote']).intersection(elem._annotations): - return None - elif elem in self.local_side and elem in self.remote_side: - # TODO: OK this still sucks. this is basically, - # refuse, refuse, refuse the temptation to guess! - # but crap we really have to guess don't we. we - # might want to traverse here with cloned_traverse - # so we can see the binary exprs and do it at that - # level.... - if count[0] % 2 == 0: - elem = elem._annotate({'local':True}) - else: - elem = elem._annotate({'remote':True}) - count[0] += 1 - elif elem in self.local_side: - elem = elem._annotate({'local':True}) - elif elem in self.remote_side: - elem = elem._annotate({'remote':True}) - else: - elem = None - return elem - - self.primaryjoin = visitors.replacement_traverse( - self.primaryjoin, {}, clone - ) - def _criterion_exists(self, criterion=None, **kwargs): if getattr(self, '_of_type', None): diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index 5f4b182d0..320234281 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -785,6 +785,7 @@ class SubqueryLoader(AbstractRelationshipLoader): leftmost_mapper, leftmost_prop = \ subq_mapper, \ subq_mapper._props[subq_path[1]] + # TODO: local cols might not be unique here leftmost_cols, remote_cols = self._local_remote_columns(leftmost_prop) leftmost_attr = [ @@ -846,6 +847,7 @@ class SubqueryLoader(AbstractRelationshipLoader): # self.parent is more specific than subq_path[-2] parent_alias = mapperutil.AliasedClass(self.parent) + # TODO: local cols might not be unique here local_cols, remote_cols = \ self._local_remote_columns(self.parent_property) @@ -885,6 +887,7 @@ class SubqueryLoader(AbstractRelationshipLoader): if prop.secondary is None: return zip(*prop.local_remote_pairs) else: + # TODO: this isn't going to work for readonly.... return \ [p[0] for p in prop.synchronize_pairs],\ [ @@ -930,6 +933,7 @@ class SubqueryLoader(AbstractRelationshipLoader): if ('subquery', reduced_path) not in context.attributes: return None, None, None + # TODO: local_cols might not be unique here local_cols, remote_cols = self._local_remote_columns(self.parent_property) q = context.attributes[('subquery', reduced_path)] diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index 0cd5b0594..f17f675f4 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -366,7 +366,18 @@ def _orm_annotate(element, exclude=None): """ return sql_util._deep_annotate(element, {'_orm_adapt':True}, exclude) -_orm_deannotate = sql_util._deep_deannotate +def _orm_deannotate(element): + """Remove annotations that link a column to a particular mapping. + + Note this doesn't affect "remote" and "foreign" annotations + passed by the :func:`.orm.foreign` and :func:`.orm.remote` + annotators. + + """ + + return sql_util._deep_deannotate(element, + values=("_orm_adapt", "parententity") + ) class _ORMJoin(expression.Join): """Extend Join to support ORM constructs as input.""" diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index 72099a5f5..ebf4de9a2 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -1576,18 +1576,30 @@ class ClauseElement(Visitable): return id(self) def _annotate(self, values): - """return a copy of this ClauseElement with the given annotations - dictionary. + """return a copy of this ClauseElement with annotations + updated by the given dictionary. """ return sqlutil.Annotated(self, values) - def _deannotate(self): - """return a copy of this ClauseElement with an empty annotations - dictionary. + def _with_annotations(self, values): + """return a copy of this ClauseElement with annotations + replaced by the given dictionary. """ - return self._clone() + return sqlutil.Annotated(self, values) + + def _deannotate(self, values=None): + """return a copy of this :class:`.ClauseElement` with annotations + removed. + + :param values: optional tuple of individual values + to remove. + + """ + # since we have no annotations we return + # self + return self def unique_params(self, *optionaldict, **kwargs): """Return a copy with :func:`bindparam()` elments replaced. diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index f0509c16f..9a45a5777 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -62,6 +62,61 @@ def find_join_source(clauses, join_to): else: return None, None + +def visit_binary_product(fn, expr): + """Produce a traversal of the given expression, delivering + column comparisons to the given function. + + The function is of the form:: + + def my_fn(binary, left, right) + + For each binary expression located which has a + comparison operator, the product of "left" and + "right" will be delivered to that function, + in terms of that binary. + + Hence an expression like:: + + and_( + (a + b) == q + func.sum(e + f), + j == r + ) + + would have the traversal:: + + a <eq> q + a <eq> e + a <eq> f + b <eq> q + b <eq> e + b <eq> f + j <eq> r + + That is, every combination of "left" and + "right" that doesn't further contain + a binary comparison is passed as pairs. + + """ + stack = [] + def visit(element): + if element.__visit_name__ == 'binary' and \ + operators.is_comparison(element.operator): + stack.insert(0, element) + for l in visit(element.left): + for r in visit(element.right): + fn(stack[0], l, r) + stack.pop(0) + for elem in element.get_children(): + visit(elem) + else: + if isinstance(element, expression.ColumnClause): + yield element + for elem in element.get_children(): + for e in visit(elem): + yield e + list(visit(expr)) + def find_tables(clause, check_columns=False, include_aliases=False, include_joins=False, include_selects=False, include_crud=False): @@ -357,13 +412,22 @@ class Annotated(object): def _annotate(self, values): _values = self._annotations.copy() _values.update(values) + return self._with_annotations(_values) + + def _with_annotations(self, values): clone = self.__class__.__new__(self.__class__) clone.__dict__ = self.__dict__.copy() - clone._annotations = _values + clone._annotations = values return clone - def _deannotate(self): - return self.__element + def _deannotate(self, values=None): + if values is None: + return self.__element + else: + _values = self._annotations.copy() + for v in values: + _values.pop(v, None) + return self._with_annotations(_values) def _compiler_dispatch(self, visitor, **kw): return self.__element.__class__._compiler_dispatch(self, visitor, **kw) @@ -426,11 +490,11 @@ def _deep_annotate(element, annotations, exclude=None): element = clone(element) return element -def _deep_deannotate(element): - """Deep copy the given element, removing all annotations.""" +def _deep_deannotate(element, values=None): + """Deep copy the given element, removing annotations.""" def clone(elem): - elem = elem._deannotate() + elem = elem._deannotate(values=values) elem._copy_internals(clone=clone) return elem |