diff options
-rw-r--r-- | lib/sqlalchemy/orm/__init__.py | 8 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/properties.py | 69 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/relationships.py | 447 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/strategies.py | 4 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/util.py | 13 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/expression.py | 24 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/util.py | 76 | ||||
-rw-r--r-- | test/orm/test_rel_fn.py | 246 | ||||
-rw-r--r-- | test/orm/test_relationships.py | 79 | ||||
-rw-r--r-- | test/sql/test_generative.py | 86 | ||||
-rw-r--r-- | test/sql/test_selectable.py | 4 |
11 files changed, 597 insertions, 459 deletions
diff --git a/lib/sqlalchemy/orm/__init__.py b/lib/sqlalchemy/orm/__init__.py index 9fd969e3b..13bd18f08 100644 --- a/lib/sqlalchemy/orm/__init__.py +++ b/lib/sqlalchemy/orm/__init__.py @@ -44,6 +44,11 @@ from sqlalchemy.orm.properties import ( PropertyLoader, SynonymProperty, ) +from sqlalchemy.orm.relationships import ( + foreign, + remote, + remote_foreign +) from sqlalchemy.orm import mapper as mapperlib from sqlalchemy.orm.mapper import reconstructor, validates from sqlalchemy.orm import strategies @@ -81,6 +86,7 @@ __all__ = ( 'dynamic_loader', 'eagerload', 'eagerload_all', + 'foreign', 'immediateload', 'join', 'joinedload', @@ -96,6 +102,8 @@ __all__ = ( 'reconstructor', 'relationship', 'relation', + 'remote', + 'remote_foreign', 'scoped_session', 'sessionmaker', 'subqueryload', diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index 77da33b5f..38237b2d4 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -16,7 +16,7 @@ from sqlalchemy.sql.util import ClauseAdapter, criterion_as_pairs, \ join_condition, _shallow_annotate from sqlalchemy.sql import operators, expression, visitors from sqlalchemy.orm import attributes, dependency, mapper, \ - object_mapper, strategies, configure_mappers + object_mapper, strategies, configure_mappers, relationships from sqlalchemy.orm.util import CascadeOptions, _class_to_mapper, \ _orm_annotate, _orm_deannotate @@ -915,13 +915,12 @@ class RelationshipProperty(StrategizedProperty): self._check_conflicts() self._process_dependent_arguments() self._setup_join_conditions() - self._extra_determine_direction() + self._check_cascade_settings() self._post_init() self._generate_backref() super(RelationshipProperty, self).do_init() def _setup_join_conditions(self): - import relationships self._join_condition = jc = relationships.JoinCondition( parent_selectable=self.parent.mapped_table, child_selectable=self.mapper.mapped_table, @@ -946,8 +945,8 @@ class RelationshipProperty(StrategizedProperty): self.local_remote_pairs = jc.local_remote_pairs self.remote_side = jc.remote_columns self.synchronize_pairs = jc.synchronize_pairs - self.secondary_synchronize_pairs = jc.secondary_synchronize_pairs self._calculated_foreign_keys = jc.foreign_key_columns + self.secondary_synchronize_pairs = jc.secondary_synchronize_pairs def _check_conflicts(self): """Test that this relationship is legal, warn about @@ -1035,7 +1034,7 @@ class RelationshipProperty(StrategizedProperty): (self.key, self.parent.class_) ) - def _extra_determine_direction(self): + def _check_cascade_settings(self): if self.cascade.delete_orphan and not self.single_parent \ and (self.direction is MANYTOMANY or self.direction is MANYTOONE): @@ -1064,7 +1063,6 @@ class RelationshipProperty(StrategizedProperty): return False return True - def _generate_backref(self): if not self.is_primary(): return @@ -1083,13 +1081,15 @@ class RelationshipProperty(StrategizedProperty): pj = kwargs.pop('primaryjoin', self.secondaryjoin) sj = kwargs.pop('secondaryjoin', self.primaryjoin) else: - pj = kwargs.pop('primaryjoin', self.primaryjoin) + pj = kwargs.pop('primaryjoin', + self._join_condition.primaryjoin_reverse_remote) sj = kwargs.pop('secondaryjoin', None) if sj: raise sa_exc.InvalidRequestError( "Can't assign 'secondaryjoin' on a backref against " "a non-secondary relationship." ) + foreign_keys = kwargs.pop('foreign_keys', self._user_defined_foreign_keys) parent = self.parent.primary_mapper() @@ -1112,21 +1112,6 @@ class RelationshipProperty(StrategizedProperty): self._add_reverse_property(self.back_populates) def _post_init(self): - self.logger.info('%s setup primary join %s', self, - self.primaryjoin) - self.logger.info('%s setup secondary join %s', self, - self.secondaryjoin) - self.logger.info('%s synchronize pairs [%s]', self, - ','.join('(%s => %s)' % (l, r) for (l, r) in - self.synchronize_pairs)) - self.logger.info('%s secondary synchronize pairs [%s]', self, - ','.join('(%s => %s)' % (l, r) for (l, r) in - self.secondary_synchronize_pairs or [])) - self.logger.info('%s local/remote pairs [%s]', self, - ','.join('(%s / %s)' % (l, r) for (l, r) in - self.local_remote_pairs)) - self.logger.info('%s relationship direction %s', self, - self.direction) if self.uselist is None: self.uselist = self.direction is not MANYTOONE if not self.viewonly: @@ -1141,46 +1126,6 @@ class RelationshipProperty(StrategizedProperty): strategy = self._get_strategy(strategies.LazyLoader) return strategy.use_get - def _refers_to_parent_table(self): - alt = self._alt_refers_to_parent_table() - pt = self.parent.mapped_table - mt = self.mapper.mapped_table - for c, f in self.synchronize_pairs: - if ( - pt.is_derived_from(c.table) and \ - pt.is_derived_from(f.table) and \ - mt.is_derived_from(c.table) and \ - mt.is_derived_from(f.table) - ): - assert alt - return True - else: - assert not alt - return False - - def _alt_refers_to_parent_table(self): - pt = self.parent.mapped_table - mt = self.mapper.mapped_table - result = [False] - def visit_binary(binary): - c, f = binary.left, binary.right - if ( - isinstance(c, expression.ColumnClause) and \ - isinstance(f, expression.ColumnClause) and \ - pt.is_derived_from(c.table) and \ - pt.is_derived_from(f.table) and \ - mt.is_derived_from(c.table) and \ - mt.is_derived_from(f.table) - ): - result[0] = True - - visitors.traverse( - self.primaryjoin, - {}, - {"binary":visit_binary} - ) - return result[0] - @util.memoized_property def _is_self_referential(self): return self.mapper.common_parent(self.parent) diff --git a/lib/sqlalchemy/orm/relationships.py b/lib/sqlalchemy/orm/relationships.py index 723d45295..d8c2659b6 100644 --- a/lib/sqlalchemy/orm/relationships.py +++ b/lib/sqlalchemy/orm/relationships.py @@ -15,7 +15,7 @@ and `secondaryjoin` aspects of :func:`.relationship`. from sqlalchemy import sql, util, log, exc as sa_exc, schema from sqlalchemy.sql.util import ClauseAdapter, criterion_as_pairs, \ - join_condition, _shallow_annotate + join_condition, _shallow_annotate, visit_binary_product from sqlalchemy.sql import operators, expression, visitors from sqlalchemy.orm.interfaces import MANYTOMANY, MANYTOONE, ONETOMANY @@ -78,7 +78,34 @@ class JoinCondition(object): self._determine_joins() self._annotate_fks() self._annotate_remote() + self._annotate_local() self._determine_direction() + self._setup_pairs() + self._check_foreign_cols(self.primaryjoin, True) + if self.secondaryjoin is not None: + self._check_foreign_cols(self.secondaryjoin, False) + self._check_remote_side() + self._log_joins() + + def _log_joins(self): + if self.prop is None: + return + log = self.prop.logger + log.info('%s setup primary join %s', self, + self.primaryjoin) + log.info('%s setup secondary join %s', self, + self.secondaryjoin) + log.info('%s synchronize pairs [%s]', self, + ','.join('(%s => %s)' % (l, r) for (l, r) in + self.synchronize_pairs)) + log.info('%s secondary synchronize pairs [%s]', self, + ','.join('(%s => %s)' % (l, r) for (l, r) in + self.secondary_synchronize_pairs or [])) + log.info('%s local/remote pairs [%s]', self, + ','.join('(%s / %s)' % (l, r) for (l, r) in + self.local_remote_pairs)) + log.info('%s relationship direction %s', self, + self.direction) def _determine_joins(self): """Determine the 'primaryjoin' and 'secondaryjoin' attributes, @@ -128,28 +155,60 @@ class JoinCondition(object): "'secondaryjoin' is needed as well." % self.prop) + @util.memoized_property + def primaryjoin_reverse_remote(self): + def replace(element): + if "remote" in element._annotations: + v = element._annotations.copy() + del v['remote'] + v['local'] = True + return element._with_annotations(v) + elif "local" in element._annotations: + v = element._annotations.copy() + del v['local'] + v['remote'] = True + return element._with_annotations(v) + return visitors.replacement_traverse(self.primaryjoin, {}, replace) + + def _has_annotation(self, clause, annotation): + for col in visitors.iterate(clause, {}): + if annotation in col._annotations: + return True + else: + return False + def _annotate_fks(self): + if self._has_annotation(self.primaryjoin, "foreign"): + return + + if self.consider_as_foreign_keys: + self._annotate_from_fk_list() + else: + self._annotate_present_fks() + + def _annotate_from_fk_list(self): + def check_fk(col): + if col in self.consider_as_foreign_keys: + return col._annotate({"foreign":True}) + self.primaryjoin = visitors.replacement_traverse( + self.primaryjoin, + {}, + check_fk + ) + if self.secondaryjoin is not None: + self.secondaryjoin = visitors.replacement_traverse( + self.secondaryjoin, + {}, + check_fk + ) + + def _annotate_present_fks(self): if self.secondary is not None: secondarycols = util.column_set(self.secondary.c) else: secondarycols = set() - def col_is(a, b): - return a.compare(b) - def is_foreign(a, b): - if self.consider_as_foreign_keys: - if a in self.consider_as_foreign_keys and ( - col_is(a, b) or - b not in self.consider_as_foreign_keys - ): - return a - elif b in self.consider_as_foreign_keys and ( - col_is(a, b) or - a not in self.consider_as_foreign_keys - ): - return b - if isinstance(a, schema.Column) and \ isinstance(b, schema.Column): if a.references(b): @@ -163,19 +222,6 @@ class JoinCondition(object): elif b in secondarycols and a not in secondarycols: return b - def _annotate_fk(binary, left, right): - can_be_synced = self.can_be_synced_fn(left) - left = left._annotate({ - #"equated":binary.operator is operators.eq, - "can_be_synced":can_be_synced and \ - binary.operator is operators.eq - }) - right = right._annotate({ - #"equated":binary.operator is operators.eq, - "referent":True - }) - return left, right - def visit_binary(binary): if not isinstance(binary.left, sql.ColumnElement) or \ not isinstance(binary.right, sql.ColumnElement): @@ -185,20 +231,12 @@ class JoinCondition(object): "foreign" not in binary.right._annotations: col = is_foreign(binary.left, binary.right) if col is not None: - if col is binary.left: + if col.compare(binary.left): binary.left = binary.left._annotate( {"foreign":True}) - elif col is binary.right: + elif col.compare(binary.right): binary.right = binary.right._annotate( {"foreign":True}) - # TODO: when the two cols are the same. - - if "foreign" in binary.left._annotations: - binary.left, binary.right = _annotate_fk( - binary, binary.left, binary.right) - if "foreign" in binary.right._annotations: - binary.right, binary.left = _annotate_fk( - binary, binary.right, binary.left) self.primaryjoin = visitors.cloned_traverse( self.primaryjoin, @@ -211,11 +249,6 @@ class JoinCondition(object): {}, {"binary":visit_binary} ) - self._check_foreign_cols( - self.primaryjoin, True) - if self.secondaryjoin is not None: - self._check_foreign_cols( - self.secondaryjoin, False) def _refers_to_parent_table(self): pt = self.parent_selectable @@ -241,18 +274,14 @@ class JoinCondition(object): return result[0] def _annotate_remote(self): - parentcols = util.column_set(self.parent_selectable.c) + if self._has_annotation(self.primaryjoin, "remote"): + return - for col in visitors.iterate(self.primaryjoin, {}): - if "remote" in col._annotations: - has_remote_annotations = True - break - else: - has_remote_annotations = False + parentcols = util.column_set(self.parent_selectable.c) def _annotate_selfref(fn): def visit_binary(binary): - equated = binary.left is binary.right + equated = binary.left.compare(binary.right) if isinstance(binary.left, sql.ColumnElement) and \ isinstance(binary.right, sql.ColumnElement): # assume one to many - FKs are "remote" @@ -267,44 +296,72 @@ class JoinCondition(object): self.primaryjoin, {}, {"binary":visit_binary}) - if not has_remote_annotations: + if self.secondary is not None: + def repl(element): + if self.secondary.c.contains_column(element): + return element._annotate({"remote":True}) + self.primaryjoin = visitors.replacement_traverse( + self.primaryjoin, {}, repl) + self.secondaryjoin = visitors.replacement_traverse( + self.secondaryjoin, {}, repl) + elif self._local_remote_pairs or self._remote_side: + if self._local_remote_pairs: - raise NotImplementedError() - elif self._remote_side: - if self._refers_to_parent_table(): - _annotate_selfref(lambda col:col in self._remote_side) - else: - def repl(element): - if element in self._remote_side: - return element._annotate({"remote":True}) - self.primaryjoin = visitors.replacement_traverse( - self.primaryjoin, {}, repl) - elif self.secondary is not None: - def repl(element): - if self.secondary.c.contains_column(element): - return element._annotate({"remote":True}) - self.primaryjoin = visitors.replacement_traverse( - self.primaryjoin, {}, repl) - self.secondaryjoin = visitors.replacement_traverse( - self.secondaryjoin, {}, repl) - elif self._refers_to_parent_table(): - _annotate_selfref(lambda col:"foreign" in col._annotations) + if self._remote_side: + raise sa_exc.ArgumentError( + "remote_side argument is redundant " + "against more detailed _local_remote_side " + "argument.") + + remote_side = [r for (l, r) in self._local_remote_pairs] + else: + remote_side = self._remote_side + + if self._refers_to_parent_table(): + _annotate_selfref(lambda col:col in remote_side) else: def repl(element): - if self.child_selectable.c.contains_column(element): + if element in remote_side: return element._annotate({"remote":True}) - self.primaryjoin = visitors.replacement_traverse( self.primaryjoin, {}, repl) + elif self._refers_to_parent_table(): + _annotate_selfref(lambda col:"foreign" in col._annotations) + else: + def repl(element): + if self.child_selectable.c.contains_column(element): + return element._annotate({"remote":True}) + + self.primaryjoin = visitors.replacement_traverse( + self.primaryjoin, {}, repl) + + def _annotate_local(self): + if self._has_annotation(self.primaryjoin, "local"): + return + + parentcols = util.column_set(self.parent_selectable.c) + + if self._local_remote_pairs: + local_side = util.column_set([l for (l, r) + in self._local_remote_pairs]) + else: + local_side = util.column_set(self.parent_selectable.c) def locals_(elem): if "remote" not in elem._annotations and \ - elem in parentcols: + elem in local_side: return elem._annotate({"local":True}) self.primaryjoin = visitors.replacement_traverse( self.primaryjoin, {}, locals_ ) + def _check_remote_side(self): + if not self.local_remote_pairs: + raise sa_exc.ArgumentError('Relationship %s could ' + 'not determine any local/remote column ' + 'pairs from remote side argument %r' + % (self.prop, self._remote_side)) + def _check_foreign_cols(self, join_condition, primary): """Check the foreign key columns collected and emit error messages.""" @@ -315,11 +372,10 @@ class JoinCondition(object): has_foreign = bool(foreign_cols) - if self.support_sync: - for col in foreign_cols: - if col._annotations.get("can_be_synced"): - can_sync = True - break + if primary: + can_sync = bool(self.synchronize_pairs) + else: + can_sync = bool(self.secondary_synchronize_pairs) if self.support_sync and can_sync or \ (not self.support_sync and has_foreign): @@ -407,6 +463,44 @@ class JoinCondition(object): "key columns are present in neither the parent " "nor the child's mapped tables" % self.prop) + def _setup_pairs(self): + sync_pairs = [] + lrp = util.OrderedSet([]) + secondary_sync_pairs = [] + + def go(joincond, collection): + def visit_binary(binary, left, right): + if "remote" in right._annotations and \ + "remote" not in left._annotations and \ + self.can_be_synced_fn(left): + lrp.add((left, right)) + elif "remote" in left._annotations and \ + "remote" not in right._annotations and \ + self.can_be_synced_fn(right): + lrp.add((right, left)) + if binary.operator is operators.eq: + # and \ + #binary.left.compare(left) and \ + #binary.right.compare(right): + if "foreign" in right._annotations: + collection.append((left, right)) + elif "foreign" in left._annotations: + collection.append((right, left)) + visit_binary_product(visit_binary, joincond) + + for joincond, collection in [ + (self.primaryjoin, sync_pairs), + (self.secondaryjoin, secondary_sync_pairs) + ]: + if joincond is None: + continue + go(joincond, collection) + + self.local_remote_pairs = list(lrp) + self.synchronize_pairs = sync_pairs + self.secondary_synchronize_pairs = secondary_sync_pairs + + @util.memoized_property def remote_columns(self): return self._gather_join_annotations("remote") @@ -416,38 +510,6 @@ class JoinCondition(object): return self._gather_join_annotations("local") @util.memoized_property - def synchronize_pairs(self): - parentcols = util.column_set(self.parent_selectable.c) - targetcols = util.column_set(self.child_selectable.c) - result = [] - for l, r in self.local_remote_pairs: - if self.secondary is not None: - if "foreign" in r._annotations and \ - l in parentcols: - result.append((l, r)) - elif "foreign" in r._annotations and \ - "can_be_synced" in r._annotations: - result.append((l, r)) - elif "foreign" in l._annotations and \ - "can_be_synced" in l._annotations: - result.append((r, l)) - return result - - @util.memoized_property - def secondary_synchronize_pairs(self): - parentcols = util.column_set(self.parent_selectable.c) - targetcols = util.column_set(self.child_selectable.c) - result = [] - if self.secondary is None: - return result - - for l, r in self.local_remote_pairs: - if "foreign" in l._annotations and \ - r in targetcols: - result.append((l, r)) - return result - - @util.memoized_property def foreign_key_columns(self): return self._gather_join_annotations("foreign") @@ -470,24 +532,6 @@ class JoinCondition(object): if annotation.issubset(col._annotations) ]) - @util.memoized_property - def local_remote_pairs(self): - lrp = util.OrderedSet() - def visit_binary(binary): - if "remote" in binary.right._annotations and \ - "remote" not in binary.left._annotations and \ - isinstance(binary.left, expression.ColumnClause) and \ - self.can_be_synced_fn(binary.left): - lrp.add((binary.left, binary.right)) - elif "remote" in binary.left._annotations and \ - "remote" not in binary.right._annotations and \ - isinstance(binary.right, expression.ColumnClause) and \ - self.can_be_synced_fn(binary.right): - lrp.add((binary.right, binary.left)) - visitors.traverse(self.primaryjoin, {}, {"binary":visit_binary}) - if self.secondaryjoin is not None: - visitors.traverse(self.secondaryjoin, {}, {"binary":visit_binary}) - return list(lrp) def join_targets(self, source_selectable, dest_selectable, @@ -604,147 +648,6 @@ def _create_lazy_clause(cls, prop, reverse_direction=False): return lazywhere, bind_to_col, equated_columns -def _determine_synchronize_pairs(self): - """Resolve 'primary'/foreign' column pairs from the primaryjoin - and secondaryjoin arguments. - - """ - if self.local_remote_pairs: - if not self._user_defined_foreign_keys: - raise sa_exc.ArgumentError( - "foreign_keys argument is " - "required with _local_remote_pairs argument") - self.synchronize_pairs = [] - for l, r in self.local_remote_pairs: - if r in self._user_defined_foreign_keys: - self.synchronize_pairs.append((l, r)) - elif l in self._user_defined_foreign_keys: - self.synchronize_pairs.append((r, l)) - else: - self.synchronize_pairs = self._sync_pairs_from_join( - self.primaryjoin, - True) - - self._calculated_foreign_keys = util.column_set( - r for (l, r) in - self.synchronize_pairs) - - if self.secondaryjoin is not None: - self.secondary_synchronize_pairs = self._sync_pairs_from_join( - self.secondaryjoin, - False) - self._calculated_foreign_keys.update( - r for (l, r) in - self.secondary_synchronize_pairs) - else: - self.secondary_synchronize_pairs = None - - -def _determine_local_remote_pairs(self): - """Determine pairs of columns representing "local" to - "remote", where "local" columns are on the parent mapper, - "remote" are on the target mapper. - - These pairs are used on the load side only to generate - lazy loading clauses. - - """ - if not self.local_remote_pairs and not self.remote_side: - # the most common, trivial case. Derive - # local/remote pairs from the synchronize pairs. - eq_pairs = util.unique_list( - self.synchronize_pairs + - (self.secondary_synchronize_pairs or [])) - if self.direction is MANYTOONE: - self.local_remote_pairs = [(r, l) for l, r in eq_pairs] - else: - self.local_remote_pairs = eq_pairs - - # "remote_side" specified, derive from the primaryjoin - # plus remote_side, similarly to how synchronize_pairs - # were determined. - elif self.remote_side: - if self.local_remote_pairs: - raise sa_exc.ArgumentError('remote_side argument is ' - 'redundant against more detailed ' - '_local_remote_side argument.') - if self.direction is MANYTOONE: - self.local_remote_pairs = [(r, l) for (l, r) in - criterion_as_pairs(self.primaryjoin, - consider_as_referenced_keys=self.remote_side, - any_operator=True)] - - else: - self.local_remote_pairs = \ - criterion_as_pairs(self.primaryjoin, - consider_as_foreign_keys=self.remote_side, - any_operator=True) - if not self.local_remote_pairs: - raise sa_exc.ArgumentError('Relationship %s could ' - 'not determine any local/remote column ' - 'pairs from remote side argument %r' - % (self, self.remote_side)) - # else local_remote_pairs were sent explcitly via - # ._local_remote_pairs. - - # create local_side/remote_side accessors - self.local_side = util.ordered_column_set( - l for l, r in self.local_remote_pairs) - self.remote_side = util.ordered_column_set( - r for l, r in self.local_remote_pairs) - - # check that the non-foreign key column in the local/remote - # collection is mapped. The foreign key - # which the individual mapped column references directly may - # itself be in a non-mapped table; see - # test.orm.test_relationships.ViewOnlyComplexJoin.test_basic - # for an example of this. - if self.direction is ONETOMANY: - for col in self.local_side: - if not self._columns_are_mapped(col): - raise sa_exc.ArgumentError( - "Local column '%s' is not " - "part of mapping %s. Specify remote_side " - "argument to indicate which column lazy join " - "condition should compare against." % (col, - self.parent)) - elif self.direction is MANYTOONE: - for col in self.remote_side: - if not self._columns_are_mapped(col): - raise sa_exc.ArgumentError( - "Remote column '%s' is not " - "part of mapping %s. Specify remote_side " - "argument to indicate which column lazy join " - "condition should bind." % (col, self.mapper)) - - count = [0] - def clone(elem): - if set(['local', 'remote']).intersection(elem._annotations): - return None - elif elem in self.local_side and elem in self.remote_side: - # TODO: OK this still sucks. this is basically, - # refuse, refuse, refuse the temptation to guess! - # but crap we really have to guess don't we. we - # might want to traverse here with cloned_traverse - # so we can see the binary exprs and do it at that - # level.... - if count[0] % 2 == 0: - elem = elem._annotate({'local':True}) - else: - elem = elem._annotate({'remote':True}) - count[0] += 1 - elif elem in self.local_side: - elem = elem._annotate({'local':True}) - elif elem in self.remote_side: - elem = elem._annotate({'remote':True}) - else: - elem = None - return elem - - self.primaryjoin = visitors.replacement_traverse( - self.primaryjoin, {}, clone - ) - def _criterion_exists(self, criterion=None, **kwargs): if getattr(self, '_of_type', None): diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index 5f4b182d0..320234281 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -785,6 +785,7 @@ class SubqueryLoader(AbstractRelationshipLoader): leftmost_mapper, leftmost_prop = \ subq_mapper, \ subq_mapper._props[subq_path[1]] + # TODO: local cols might not be unique here leftmost_cols, remote_cols = self._local_remote_columns(leftmost_prop) leftmost_attr = [ @@ -846,6 +847,7 @@ class SubqueryLoader(AbstractRelationshipLoader): # self.parent is more specific than subq_path[-2] parent_alias = mapperutil.AliasedClass(self.parent) + # TODO: local cols might not be unique here local_cols, remote_cols = \ self._local_remote_columns(self.parent_property) @@ -885,6 +887,7 @@ class SubqueryLoader(AbstractRelationshipLoader): if prop.secondary is None: return zip(*prop.local_remote_pairs) else: + # TODO: this isn't going to work for readonly.... return \ [p[0] for p in prop.synchronize_pairs],\ [ @@ -930,6 +933,7 @@ class SubqueryLoader(AbstractRelationshipLoader): if ('subquery', reduced_path) not in context.attributes: return None, None, None + # TODO: local_cols might not be unique here local_cols, remote_cols = self._local_remote_columns(self.parent_property) q = context.attributes[('subquery', reduced_path)] diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index 0cd5b0594..f17f675f4 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -366,7 +366,18 @@ def _orm_annotate(element, exclude=None): """ return sql_util._deep_annotate(element, {'_orm_adapt':True}, exclude) -_orm_deannotate = sql_util._deep_deannotate +def _orm_deannotate(element): + """Remove annotations that link a column to a particular mapping. + + Note this doesn't affect "remote" and "foreign" annotations + passed by the :func:`.orm.foreign` and :func:`.orm.remote` + annotators. + + """ + + return sql_util._deep_deannotate(element, + values=("_orm_adapt", "parententity") + ) class _ORMJoin(expression.Join): """Extend Join to support ORM constructs as input.""" diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index 72099a5f5..ebf4de9a2 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -1576,18 +1576,30 @@ class ClauseElement(Visitable): return id(self) def _annotate(self, values): - """return a copy of this ClauseElement with the given annotations - dictionary. + """return a copy of this ClauseElement with annotations + updated by the given dictionary. """ return sqlutil.Annotated(self, values) - def _deannotate(self): - """return a copy of this ClauseElement with an empty annotations - dictionary. + def _with_annotations(self, values): + """return a copy of this ClauseElement with annotations + replaced by the given dictionary. """ - return self._clone() + return sqlutil.Annotated(self, values) + + def _deannotate(self, values=None): + """return a copy of this :class:`.ClauseElement` with annotations + removed. + + :param values: optional tuple of individual values + to remove. + + """ + # since we have no annotations we return + # self + return self def unique_params(self, *optionaldict, **kwargs): """Return a copy with :func:`bindparam()` elments replaced. diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index f0509c16f..9a45a5777 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -62,6 +62,61 @@ def find_join_source(clauses, join_to): else: return None, None + +def visit_binary_product(fn, expr): + """Produce a traversal of the given expression, delivering + column comparisons to the given function. + + The function is of the form:: + + def my_fn(binary, left, right) + + For each binary expression located which has a + comparison operator, the product of "left" and + "right" will be delivered to that function, + in terms of that binary. + + Hence an expression like:: + + and_( + (a + b) == q + func.sum(e + f), + j == r + ) + + would have the traversal:: + + a <eq> q + a <eq> e + a <eq> f + b <eq> q + b <eq> e + b <eq> f + j <eq> r + + That is, every combination of "left" and + "right" that doesn't further contain + a binary comparison is passed as pairs. + + """ + stack = [] + def visit(element): + if element.__visit_name__ == 'binary' and \ + operators.is_comparison(element.operator): + stack.insert(0, element) + for l in visit(element.left): + for r in visit(element.right): + fn(stack[0], l, r) + stack.pop(0) + for elem in element.get_children(): + visit(elem) + else: + if isinstance(element, expression.ColumnClause): + yield element + for elem in element.get_children(): + for e in visit(elem): + yield e + list(visit(expr)) + def find_tables(clause, check_columns=False, include_aliases=False, include_joins=False, include_selects=False, include_crud=False): @@ -357,13 +412,22 @@ class Annotated(object): def _annotate(self, values): _values = self._annotations.copy() _values.update(values) + return self._with_annotations(_values) + + def _with_annotations(self, values): clone = self.__class__.__new__(self.__class__) clone.__dict__ = self.__dict__.copy() - clone._annotations = _values + clone._annotations = values return clone - def _deannotate(self): - return self.__element + def _deannotate(self, values=None): + if values is None: + return self.__element + else: + _values = self._annotations.copy() + for v in values: + _values.pop(v, None) + return self._with_annotations(_values) def _compiler_dispatch(self, visitor, **kw): return self.__element.__class__._compiler_dispatch(self, visitor, **kw) @@ -426,11 +490,11 @@ def _deep_annotate(element, annotations, exclude=None): element = clone(element) return element -def _deep_deannotate(element): - """Deep copy the given element, removing all annotations.""" +def _deep_deannotate(element, values=None): + """Deep copy the given element, removing annotations.""" def clone(elem): - elem = elem._deannotate() + elem = elem._deannotate(values=values) elem._copy_internals(clone=clone) return elem diff --git a/test/orm/test_rel_fn.py b/test/orm/test_rel_fn.py index d3d346bba..346cb90c1 100644 --- a/test/orm/test_rel_fn.py +++ b/test/orm/test_rel_fn.py @@ -6,9 +6,8 @@ from sqlalchemy import MetaData, Table, Column, ForeignKey, Integer, \ select, ForeignKeyConstraint, exc from sqlalchemy.orm.interfaces import ONETOMANY, MANYTOONE, MANYTOMANY -class JoinCondTest(fixtures.TestBase, AssertsCompiledSQL): - __dialect__ = 'default' +class _JoinFixtures(object): @classmethod def setup_class(cls): m = MetaData() @@ -36,6 +35,28 @@ class JoinCondTest(fixtures.TestBase, AssertsCompiledSQL): ['composite_selfref.id', 'composite_selfref.group_id'] ) ) + cls.m2mleft = Table('m2mlft', m, + Column('id', Integer, primary_key=True), + ) + cls.m2mright = Table('m2mrgt', m, + Column('id', Integer, primary_key=True), + ) + cls.m2msecondary = Table('m2msecondary', m, + Column('lid', Integer, ForeignKey('m2mlft.id'), primary_key=True), + Column('rid', Integer, ForeignKey('m2mrgt.id'), primary_key=True), + ) + + def _join_fixture_m2m_selfref(self, **kw): + return relationships.JoinCondition( + self.m2mleft, + self.m2mright, + self.m2mleft, + self.m2mright, + secondary=self.m2msecondary, + primaryjoin=self.m2mleft.c.id==self.m2msecondary.c.lid, + secondaryjoin=self.m2mright.c.id==self.m2msecondary.c.rid, + **kw + ) def _join_fixture_o2m(self, **kw): return relationships.JoinCondition( @@ -120,6 +141,7 @@ class JoinCondTest(fixtures.TestBase, AssertsCompiledSQL): **kw ) +class ColumnCollectionsTest(_JoinFixtures, fixtures.TestBase, AssertsCompiledSQL): def test_determine_remote_columns_compound_1(self): joincond = self._join_fixture_compound_expression_1( support_sync=False) @@ -133,7 +155,25 @@ class JoinCondTest(fixtures.TestBase, AssertsCompiledSQL): support_sync=False) eq_( joincond.local_remote_pairs, - [] + [ + (self.left.c.x, self.right.c.x), + (self.left.c.x, self.right.c.y), + (self.left.c.y, self.right.c.x), + (self.left.c.y, self.right.c.y) + ] + ) + + def test_determine_local_remote_compound_2(self): + joincond = self._join_fixture_compound_expression_2( + support_sync=False) + eq_( + joincond.local_remote_pairs, + [ + (self.left.c.x, self.right.c.x), + (self.left.c.x, self.right.c.y), + (self.left.c.y, self.right.c.x), + (self.left.c.y, self.right.c.y) + ] ) def test_err_local_remote_compound_1(self): @@ -160,14 +200,71 @@ class JoinCondTest(fixtures.TestBase, AssertsCompiledSQL): set([self.right.c.x, self.right.c.y]) ) - def test_determine_local_remote_compound_2(self): - joincond = self._join_fixture_compound_expression_2( - support_sync=False) + + def test_determine_remote_columns_o2m(self): + joincond = self._join_fixture_o2m() + eq_( + joincond.remote_columns, + set([self.right.c.lid]) + ) + + def test_determine_remote_columns_o2m_selfref(self): + joincond = self._join_fixture_o2m_selfref() + eq_( + joincond.remote_columns, + set([self.selfref.c.sid]) + ) + + def test_determine_remote_columns_o2m_composite_selfref(self): + joincond = self._join_fixture_o2m_composite_selfref() + eq_( + joincond.remote_columns, + set([self.composite_selfref.c.parent_id, + self.composite_selfref.c.group_id]) + ) + + def test_determine_remote_columns_m2o_composite_selfref(self): + joincond = self._join_fixture_m2o_composite_selfref() + eq_( + joincond.remote_columns, + set([self.composite_selfref.c.id, + self.composite_selfref.c.group_id]) + ) + + def test_determine_remote_columns_m2o(self): + joincond = self._join_fixture_m2o() + eq_( + joincond.remote_columns, + set([self.left.c.id]) + ) + + def test_determine_local_remote_pairs_o2m(self): + joincond = self._join_fixture_o2m() eq_( joincond.local_remote_pairs, - [] + [(self.left.c.id, self.right.c.lid)] + ) + + def test_determine_synchronize_pairs_m2m_selfref(self): + joincond = self._join_fixture_m2m_selfref() + eq_( + joincond.synchronize_pairs, + [(self.m2mleft.c.id, self.m2msecondary.c.lid)] + ) + eq_( + joincond.secondary_synchronize_pairs, + [(self.m2mright.c.id, self.m2msecondary.c.rid)] ) + def test_determine_remote_columns_m2o_selfref(self): + joincond = self._join_fixture_m2o_selfref() + eq_( + joincond.remote_columns, + set([self.selfref.c.id]) + ) + + +class DirectionTest(_JoinFixtures, fixtures.TestBase, AssertsCompiledSQL): def test_determine_direction_compound_2(self): joincond = self._join_fixture_compound_expression_2( support_sync=False) @@ -176,60 +273,46 @@ class JoinCondTest(fixtures.TestBase, AssertsCompiledSQL): ONETOMANY ) - def test_determine_join_o2m(self): - joincond = self._join_fixture_o2m() - self.assert_compile( - joincond.primaryjoin, - "lft.id = rgt.lid" - ) - def test_determine_direction_o2m(self): joincond = self._join_fixture_o2m() is_(joincond.direction, ONETOMANY) - def test_determine_remote_columns_o2m(self): - joincond = self._join_fixture_o2m() - eq_( - joincond.remote_columns, - set([self.right.c.lid]) - ) - - def test_determine_join_o2m_selfref(self): - joincond = self._join_fixture_o2m_selfref() - self.assert_compile( - joincond.primaryjoin, - "selfref.id = selfref.sid" - ) - def test_determine_direction_o2m_selfref(self): joincond = self._join_fixture_o2m_selfref() is_(joincond.direction, ONETOMANY) - def test_determine_remote_columns_o2m_selfref(self): - joincond = self._join_fixture_o2m_selfref() - eq_( - joincond.remote_columns, - set([self.selfref.c.sid]) - ) + def test_determine_direction_m2o_selfref(self): + joincond = self._join_fixture_m2o_selfref() + is_(joincond.direction, MANYTOONE) - def test_join_targets_o2m_selfref(self): - joincond = self._join_fixture_o2m_selfref() - left = select([joincond.parent_selectable]).alias('pj') - pj, sj, sec, adapter = joincond.join_targets( - left, - joincond.child_selectable, - True) + def test_determine_direction_o2m_composite_selfref(self): + joincond = self._join_fixture_o2m_composite_selfref() + is_(joincond.direction, ONETOMANY) + + def test_determine_direction_m2o_composite_selfref(self): + joincond = self._join_fixture_m2o_composite_selfref() + is_(joincond.direction, MANYTOONE) + + def test_determine_direction_m2o(self): + joincond = self._join_fixture_m2o() + is_(joincond.direction, MANYTOONE) + + +class DetermineJoinTest(_JoinFixtures, fixtures.TestBase, AssertsCompiledSQL): + __dialect__ = 'default' + + def test_determine_join_o2m(self): + joincond = self._join_fixture_o2m() self.assert_compile( - pj, "pj.id = selfref.sid" + joincond.primaryjoin, + "lft.id = rgt.lid" ) - right = select([joincond.child_selectable]).alias('pj') - pj, sj, sec, adapter = joincond.join_targets( - joincond.parent_selectable, - right, - True) + def test_determine_join_o2m_selfref(self): + joincond = self._join_fixture_o2m_selfref() self.assert_compile( - pj, "selfref.id = pj.sid" + joincond.primaryjoin, + "selfref.id = selfref.sid" ) def test_determine_join_m2o_selfref(self): @@ -239,17 +322,6 @@ class JoinCondTest(fixtures.TestBase, AssertsCompiledSQL): "selfref.id = selfref.sid" ) - def test_determine_direction_m2o_selfref(self): - joincond = self._join_fixture_m2o_selfref() - is_(joincond.direction, MANYTOONE) - - def test_determine_remote_columns_m2o_selfref(self): - joincond = self._join_fixture_m2o_selfref() - eq_( - joincond.remote_columns, - set([self.selfref.c.id]) - ) - def test_determine_join_o2m_composite_selfref(self): joincond = self._join_fixture_o2m_composite_selfref() self.assert_compile( @@ -258,18 +330,6 @@ class JoinCondTest(fixtures.TestBase, AssertsCompiledSQL): "AND composite_selfref.id = composite_selfref.parent_id" ) - def test_determine_direction_o2m_composite_selfref(self): - joincond = self._join_fixture_o2m_composite_selfref() - is_(joincond.direction, ONETOMANY) - - def test_determine_remote_columns_o2m_composite_selfref(self): - joincond = self._join_fixture_o2m_composite_selfref() - eq_( - joincond.remote_columns, - set([self.composite_selfref.c.parent_id, - self.composite_selfref.c.group_id]) - ) - def test_determine_join_m2o_composite_selfref(self): joincond = self._join_fixture_m2o_composite_selfref() self.assert_compile( @@ -278,17 +338,7 @@ class JoinCondTest(fixtures.TestBase, AssertsCompiledSQL): "AND composite_selfref.id = composite_selfref.parent_id" ) - def test_determine_direction_m2o_composite_selfref(self): - joincond = self._join_fixture_m2o_composite_selfref() - is_(joincond.direction, MANYTOONE) - def test_determine_remote_columns_m2o_composite_selfref(self): - joincond = self._join_fixture_m2o_composite_selfref() - eq_( - joincond.remote_columns, - set([self.composite_selfref.c.id, - self.composite_selfref.c.group_id]) - ) def test_determine_join_m2o(self): joincond = self._join_fixture_m2o() @@ -297,24 +347,30 @@ class JoinCondTest(fixtures.TestBase, AssertsCompiledSQL): "lft.id = rgt.lid" ) - def test_determine_direction_m2o(self): - joincond = self._join_fixture_m2o() - is_(joincond.direction, MANYTOONE) +class AdaptedJoinTest(_JoinFixtures, fixtures.TestBase, AssertsCompiledSQL): + __dialect__ = 'default' - def test_determine_remote_columns_m2o(self): - joincond = self._join_fixture_m2o() - eq_( - joincond.remote_columns, - set([self.left.c.id]) + def test_join_targets_o2m_selfref(self): + joincond = self._join_fixture_o2m_selfref() + left = select([joincond.parent_selectable]).alias('pj') + pj, sj, sec, adapter = joincond.join_targets( + left, + joincond.child_selectable, + True) + self.assert_compile( + pj, "pj.id = selfref.sid" ) - def test_determine_local_remote_pairs_o2m(self): - joincond = self._join_fixture_o2m() - eq_( - joincond.local_remote_pairs, - [(self.left.c.id, self.right.c.lid)] + right = select([joincond.child_selectable]).alias('pj') + pj, sj, sec, adapter = joincond.join_targets( + joincond.parent_selectable, + right, + True) + self.assert_compile( + pj, "selfref.id = pj.sid" ) + def test_join_targets_o2m_plain(self): joincond = self._join_fixture_o2m() pj, sj, sec, adapter = joincond.join_targets( @@ -347,6 +403,8 @@ class JoinCondTest(fixtures.TestBase, AssertsCompiledSQL): pj, "lft.id = pj.lid" ) +class LazyClauseTest(_JoinFixtures, fixtures.TestBase, AssertsCompiledSQL): + def _test_lazy_clause_o2m(self): joincond = self._join_fixture_o2m() self.assert_compile( diff --git a/test/orm/test_relationships.py b/test/orm/test_relationships.py index 0a02cbf9a..d2dcbe312 100644 --- a/test/orm/test_relationships.py +++ b/test/orm/test_relationships.py @@ -7,8 +7,9 @@ from test.lib.schema import Table, Column from sqlalchemy.orm import mapper, relationship, relation, \ backref, create_session, configure_mappers, \ clear_mappers, sessionmaker, attributes,\ - Session, composite, column_property -from test.lib.testing import eq_, startswith_, AssertsCompiledSQL + Session, composite, column_property, foreign +from sqlalchemy.orm.interfaces import ONETOMANY, MANYTOONE, MANYTOMANY +from test.lib.testing import eq_, startswith_, AssertsCompiledSQL, is_ from test.lib import fixtures from test.orm import _fixtures @@ -141,12 +142,12 @@ class CompositeSelfRefFKTest(fixtures.MappedTest): Table('company_t', metadata, Column('company_id', Integer, primary_key=True, test_needs_autoincrement=True), - Column('name', sa.Unicode(30))) + Column('name', String(30))) Table('employee_t', metadata, Column('company_id', Integer, primary_key=True), Column('emp_id', Integer, primary_key=True), - Column('name', sa.Unicode(30)), + Column('name', String(30)), Column('reports_to_id', Integer), sa.ForeignKeyConstraint( ['company_id'], @@ -158,7 +159,8 @@ class CompositeSelfRefFKTest(fixtures.MappedTest): @classmethod def setup_classes(cls): class Company(cls.Basic): - pass + def __init__(self, name): + self.name = name class Employee(cls.Basic): def __init__(self, name, company, emp_id, reports_to=None): @@ -248,11 +250,16 @@ class CompositeSelfRefFKTest(fixtures.MappedTest): self._test() def _test(self): + sess = Session() + self._setup_data(sess) + self._test_lazy_relations(sess) + self._test_join_aliasing(sess) + + def _setup_data(self, sess): Employee, Company = self.classes.Employee, self.classes.Company - sess = create_session() - c1 = Company() - c2 = Company() + c1 = Company('c1') + c2 = Company('c2') e1 = Employee(u'emp1', c1, 1) e2 = Employee(u'emp2', c1, 2, e1) @@ -263,10 +270,17 @@ class CompositeSelfRefFKTest(fixtures.MappedTest): e7 = Employee(u'emp7', c2, 3, e5) sess.add_all((c1, c2)) - sess.flush() - sess.expunge_all() + sess.commit() + sess.close() + + def _test_lazy_relations(self, sess): + Employee, Company = self.classes.Employee, self.classes.Company + + c1 = sess.query(Company).filter_by(name='c1').one() + c2 = sess.query(Company).filter_by(name='c2').one() + e1 = sess.query(Employee).filter_by(name='emp1').one() + e5 = sess.query(Employee).filter_by(name='emp5').one() - test_c1 = sess.query(Company).get(c1.company_id) test_e1 = sess.query(Employee).get([c1.company_id, e1.emp_id]) assert test_e1.name == 'emp1', test_e1.name test_e5 = sess.query(Employee).get([c2.company_id, e5.emp_id]) @@ -277,6 +291,16 @@ class CompositeSelfRefFKTest(fixtures.MappedTest): assert sess.query(Employee).\ get([c2.company_id, 3]).reports_to.name == 'emp5' + def _test_join_aliasing(self, sess): + Employee, Company = self.classes.Employee, self.classes.Company + eq_( + [n for n, in sess.query(Employee.name).\ + join(Employee.reports_to, aliased=True).\ + filter_by(name='emp5').\ + reset_joinpoint().\ + order_by(Employee.name)], + ['emp6', 'emp7'] + ) class CompositeJoinPartialFK(fixtures.MappedTest, AssertsCompiledSQL): __dialect__ = 'default' @@ -839,7 +863,6 @@ class AmbiguousJoinInterpretedAsSelfRef(fixtures.MappedTest): def test_mapping(self): Subscriber, Address = self.classes.Subscriber, self.classes.Address - from sqlalchemy.orm.interfaces import ONETOMANY, MANYTOONE sess = create_session() assert Subscriber.addresses.property.direction is ONETOMANY assert Address.customer.property.direction is MANYTOONE @@ -1733,21 +1756,45 @@ class ExplicitLocalRemoteTest(fixtures.MappedTest): class T2(cls.Comparable): pass - def test_onetomany_funcfk(self): + def test_onetomany_funcfk_oldstyle(self): T2, T1, t2, t1 = (self.classes.T2, self.classes.T1, self.tables.t2, self.tables.t1) - # use a function within join condition. but specifying - # local_remote_pairs overrides all parsing of the join condition. + # old _local_remote_pairs mapper(T1, t1, properties={ 't2s':relationship(T2, primaryjoin=t1.c.id==sa.func.lower(t2.c.t1id), _local_remote_pairs=[(t1.c.id, t2.c.t1id)], - foreign_keys=[t2.c.t1id])}) + foreign_keys=[t2.c.t1id] + ) + }) + mapper(T2, t2) + self._test_onetomany() + + def test_onetomany_funcfk_annotated(self): + T2, T1, t2, t1 = (self.classes.T2, + self.classes.T1, + self.tables.t2, + self.tables.t1) + + # use annotation + mapper(T1, t1, properties={ + 't2s':relationship(T2, + primaryjoin=t1.c.id== + foreign(sa.func.lower(t2.c.t1id)), + )}) mapper(T2, t2) + self._test_onetomany() + def _test_onetomany(self): + T2, T1, t2, t1 = (self.classes.T2, + self.classes.T1, + self.tables.t2, + self.tables.t1) + is_(T1.t2s.property.direction, ONETOMANY) + eq_(T1.t2s.property.local_remote_pairs, [(t1.c.id, t2.c.t1id)]) sess = create_session() a1 = T1(id='number1', data='a1') a2 = T1(id='number2', data='a2') diff --git a/test/sql/test_generative.py b/test/sql/test_generative.py index f9333dbf5..d4f324dd7 100644 --- a/test/sql/test_generative.py +++ b/test/sql/test_generative.py @@ -1,5 +1,5 @@ from sqlalchemy import * -from sqlalchemy.sql import table, column, ClauseElement +from sqlalchemy.sql import table, column, ClauseElement, operators from sqlalchemy.sql.expression import _clone, _from_objects from test.lib import * from sqlalchemy.sql.visitors import * @@ -166,6 +166,90 @@ class TraversalTest(fixtures.TestBase, AssertsExecutionResults): s = set(ClauseVisitor().iterate(bin)) assert set(ClauseVisitor().iterate(bin)) == set([foo, bar, bin]) +class BinaryEndpointTraversalTest(fixtures.TestBase): + """test the special binary product visit""" + + def _assert_traversal(self, expr, expected): + canary = [] + def visit(binary, l, r): + canary.append((binary.operator, l, r)) + print binary.operator, l, r + sql_util.visit_binary_product(visit, expr) + eq_( + canary, expected + ) + + def test_basic(self): + a, b = column("a"), column("b") + self._assert_traversal( + a == b, + [ + (operators.eq, a, b) + ] + ) + + def test_with_tuples(self): + a, b, c, d, b1, b1a, b1b, e, f = ( + column("a"), + column("b"), + column("c"), + column("d"), + column("b1"), + column("b1a"), + column("b1b"), + column("e"), + column("f") + ) + expr = tuple_( + a, b, b1==tuple_(b1a, b1b == d), c + ) > tuple_( + func.go(e + f) + ) + self._assert_traversal( + expr, + [ + (operators.gt, a, e), + (operators.gt, a, f), + (operators.gt, b, e), + (operators.gt, b, f), + (operators.eq, b1, b1a), + (operators.eq, b1b, d), + (operators.gt, c, e), + (operators.gt, c, f) + ] + ) + + def test_composed(self): + a, b, e, f, q, j, r = ( + column("a"), + column("b"), + column("e"), + column("f"), + column("q"), + column("j"), + column("r"), + ) + expr = and_( + (a + b) == q + func.sum(e + f), + and_( + j == r, + f == q + ) + ) + self._assert_traversal( + expr, + [ + (operators.eq, a, q), + (operators.eq, a, e), + (operators.eq, a, f), + (operators.eq, b, q), + (operators.eq, b, e), + (operators.eq, b, f), + (operators.eq, j, r), + (operators.eq, f, q), + ] + ) + class ClauseTest(fixtures.TestBase, AssertsCompiledSQL): """test copy-in-place behavior of various ClauseElements.""" diff --git a/test/sql/test_selectable.py b/test/sql/test_selectable.py index 6d85f7c4f..4f1f39014 100644 --- a/test/sql/test_selectable.py +++ b/test/sql/test_selectable.py @@ -1151,5 +1151,7 @@ class AnnotationsTest(fixtures.TestBase): assert b2.left is not bin.left assert b3.left is not b2.left is not bin.left assert b4.left is bin.left # since column is immutable - assert b4.right is not bin.right is not b2.right is not b3.right + assert b4.right is bin.right + assert b2.right is not bin.right + assert b3.right is b4.right is bin.right |