summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--lib/sqlalchemy/orm/__init__.py8
-rw-r--r--lib/sqlalchemy/orm/properties.py69
-rw-r--r--lib/sqlalchemy/orm/relationships.py447
-rw-r--r--lib/sqlalchemy/orm/strategies.py4
-rw-r--r--lib/sqlalchemy/orm/util.py13
-rw-r--r--lib/sqlalchemy/sql/expression.py24
-rw-r--r--lib/sqlalchemy/sql/util.py76
-rw-r--r--test/orm/test_rel_fn.py246
-rw-r--r--test/orm/test_relationships.py79
-rw-r--r--test/sql/test_generative.py86
-rw-r--r--test/sql/test_selectable.py4
11 files changed, 597 insertions, 459 deletions
diff --git a/lib/sqlalchemy/orm/__init__.py b/lib/sqlalchemy/orm/__init__.py
index 9fd969e3b..13bd18f08 100644
--- a/lib/sqlalchemy/orm/__init__.py
+++ b/lib/sqlalchemy/orm/__init__.py
@@ -44,6 +44,11 @@ from sqlalchemy.orm.properties import (
PropertyLoader,
SynonymProperty,
)
+from sqlalchemy.orm.relationships import (
+ foreign,
+ remote,
+ remote_foreign
+)
from sqlalchemy.orm import mapper as mapperlib
from sqlalchemy.orm.mapper import reconstructor, validates
from sqlalchemy.orm import strategies
@@ -81,6 +86,7 @@ __all__ = (
'dynamic_loader',
'eagerload',
'eagerload_all',
+ 'foreign',
'immediateload',
'join',
'joinedload',
@@ -96,6 +102,8 @@ __all__ = (
'reconstructor',
'relationship',
'relation',
+ 'remote',
+ 'remote_foreign',
'scoped_session',
'sessionmaker',
'subqueryload',
diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py
index 77da33b5f..38237b2d4 100644
--- a/lib/sqlalchemy/orm/properties.py
+++ b/lib/sqlalchemy/orm/properties.py
@@ -16,7 +16,7 @@ from sqlalchemy.sql.util import ClauseAdapter, criterion_as_pairs, \
join_condition, _shallow_annotate
from sqlalchemy.sql import operators, expression, visitors
from sqlalchemy.orm import attributes, dependency, mapper, \
- object_mapper, strategies, configure_mappers
+ object_mapper, strategies, configure_mappers, relationships
from sqlalchemy.orm.util import CascadeOptions, _class_to_mapper, \
_orm_annotate, _orm_deannotate
@@ -915,13 +915,12 @@ class RelationshipProperty(StrategizedProperty):
self._check_conflicts()
self._process_dependent_arguments()
self._setup_join_conditions()
- self._extra_determine_direction()
+ self._check_cascade_settings()
self._post_init()
self._generate_backref()
super(RelationshipProperty, self).do_init()
def _setup_join_conditions(self):
- import relationships
self._join_condition = jc = relationships.JoinCondition(
parent_selectable=self.parent.mapped_table,
child_selectable=self.mapper.mapped_table,
@@ -946,8 +945,8 @@ class RelationshipProperty(StrategizedProperty):
self.local_remote_pairs = jc.local_remote_pairs
self.remote_side = jc.remote_columns
self.synchronize_pairs = jc.synchronize_pairs
- self.secondary_synchronize_pairs = jc.secondary_synchronize_pairs
self._calculated_foreign_keys = jc.foreign_key_columns
+ self.secondary_synchronize_pairs = jc.secondary_synchronize_pairs
def _check_conflicts(self):
"""Test that this relationship is legal, warn about
@@ -1035,7 +1034,7 @@ class RelationshipProperty(StrategizedProperty):
(self.key, self.parent.class_)
)
- def _extra_determine_direction(self):
+ def _check_cascade_settings(self):
if self.cascade.delete_orphan and not self.single_parent \
and (self.direction is MANYTOMANY or self.direction
is MANYTOONE):
@@ -1064,7 +1063,6 @@ class RelationshipProperty(StrategizedProperty):
return False
return True
-
def _generate_backref(self):
if not self.is_primary():
return
@@ -1083,13 +1081,15 @@ class RelationshipProperty(StrategizedProperty):
pj = kwargs.pop('primaryjoin', self.secondaryjoin)
sj = kwargs.pop('secondaryjoin', self.primaryjoin)
else:
- pj = kwargs.pop('primaryjoin', self.primaryjoin)
+ pj = kwargs.pop('primaryjoin',
+ self._join_condition.primaryjoin_reverse_remote)
sj = kwargs.pop('secondaryjoin', None)
if sj:
raise sa_exc.InvalidRequestError(
"Can't assign 'secondaryjoin' on a backref against "
"a non-secondary relationship."
)
+
foreign_keys = kwargs.pop('foreign_keys',
self._user_defined_foreign_keys)
parent = self.parent.primary_mapper()
@@ -1112,21 +1112,6 @@ class RelationshipProperty(StrategizedProperty):
self._add_reverse_property(self.back_populates)
def _post_init(self):
- self.logger.info('%s setup primary join %s', self,
- self.primaryjoin)
- self.logger.info('%s setup secondary join %s', self,
- self.secondaryjoin)
- self.logger.info('%s synchronize pairs [%s]', self,
- ','.join('(%s => %s)' % (l, r) for (l, r) in
- self.synchronize_pairs))
- self.logger.info('%s secondary synchronize pairs [%s]', self,
- ','.join('(%s => %s)' % (l, r) for (l, r) in
- self.secondary_synchronize_pairs or []))
- self.logger.info('%s local/remote pairs [%s]', self,
- ','.join('(%s / %s)' % (l, r) for (l, r) in
- self.local_remote_pairs))
- self.logger.info('%s relationship direction %s', self,
- self.direction)
if self.uselist is None:
self.uselist = self.direction is not MANYTOONE
if not self.viewonly:
@@ -1141,46 +1126,6 @@ class RelationshipProperty(StrategizedProperty):
strategy = self._get_strategy(strategies.LazyLoader)
return strategy.use_get
- def _refers_to_parent_table(self):
- alt = self._alt_refers_to_parent_table()
- pt = self.parent.mapped_table
- mt = self.mapper.mapped_table
- for c, f in self.synchronize_pairs:
- if (
- pt.is_derived_from(c.table) and \
- pt.is_derived_from(f.table) and \
- mt.is_derived_from(c.table) and \
- mt.is_derived_from(f.table)
- ):
- assert alt
- return True
- else:
- assert not alt
- return False
-
- def _alt_refers_to_parent_table(self):
- pt = self.parent.mapped_table
- mt = self.mapper.mapped_table
- result = [False]
- def visit_binary(binary):
- c, f = binary.left, binary.right
- if (
- isinstance(c, expression.ColumnClause) and \
- isinstance(f, expression.ColumnClause) and \
- pt.is_derived_from(c.table) and \
- pt.is_derived_from(f.table) and \
- mt.is_derived_from(c.table) and \
- mt.is_derived_from(f.table)
- ):
- result[0] = True
-
- visitors.traverse(
- self.primaryjoin,
- {},
- {"binary":visit_binary}
- )
- return result[0]
-
@util.memoized_property
def _is_self_referential(self):
return self.mapper.common_parent(self.parent)
diff --git a/lib/sqlalchemy/orm/relationships.py b/lib/sqlalchemy/orm/relationships.py
index 723d45295..d8c2659b6 100644
--- a/lib/sqlalchemy/orm/relationships.py
+++ b/lib/sqlalchemy/orm/relationships.py
@@ -15,7 +15,7 @@ and `secondaryjoin` aspects of :func:`.relationship`.
from sqlalchemy import sql, util, log, exc as sa_exc, schema
from sqlalchemy.sql.util import ClauseAdapter, criterion_as_pairs, \
- join_condition, _shallow_annotate
+ join_condition, _shallow_annotate, visit_binary_product
from sqlalchemy.sql import operators, expression, visitors
from sqlalchemy.orm.interfaces import MANYTOMANY, MANYTOONE, ONETOMANY
@@ -78,7 +78,34 @@ class JoinCondition(object):
self._determine_joins()
self._annotate_fks()
self._annotate_remote()
+ self._annotate_local()
self._determine_direction()
+ self._setup_pairs()
+ self._check_foreign_cols(self.primaryjoin, True)
+ if self.secondaryjoin is not None:
+ self._check_foreign_cols(self.secondaryjoin, False)
+ self._check_remote_side()
+ self._log_joins()
+
+ def _log_joins(self):
+ if self.prop is None:
+ return
+ log = self.prop.logger
+ log.info('%s setup primary join %s', self,
+ self.primaryjoin)
+ log.info('%s setup secondary join %s', self,
+ self.secondaryjoin)
+ log.info('%s synchronize pairs [%s]', self,
+ ','.join('(%s => %s)' % (l, r) for (l, r) in
+ self.synchronize_pairs))
+ log.info('%s secondary synchronize pairs [%s]', self,
+ ','.join('(%s => %s)' % (l, r) for (l, r) in
+ self.secondary_synchronize_pairs or []))
+ log.info('%s local/remote pairs [%s]', self,
+ ','.join('(%s / %s)' % (l, r) for (l, r) in
+ self.local_remote_pairs))
+ log.info('%s relationship direction %s', self,
+ self.direction)
def _determine_joins(self):
"""Determine the 'primaryjoin' and 'secondaryjoin' attributes,
@@ -128,28 +155,60 @@ class JoinCondition(object):
"'secondaryjoin' is needed as well."
% self.prop)
+ @util.memoized_property
+ def primaryjoin_reverse_remote(self):
+ def replace(element):
+ if "remote" in element._annotations:
+ v = element._annotations.copy()
+ del v['remote']
+ v['local'] = True
+ return element._with_annotations(v)
+ elif "local" in element._annotations:
+ v = element._annotations.copy()
+ del v['local']
+ v['remote'] = True
+ return element._with_annotations(v)
+ return visitors.replacement_traverse(self.primaryjoin, {}, replace)
+
+ def _has_annotation(self, clause, annotation):
+ for col in visitors.iterate(clause, {}):
+ if annotation in col._annotations:
+ return True
+ else:
+ return False
+
def _annotate_fks(self):
+ if self._has_annotation(self.primaryjoin, "foreign"):
+ return
+
+ if self.consider_as_foreign_keys:
+ self._annotate_from_fk_list()
+ else:
+ self._annotate_present_fks()
+
+ def _annotate_from_fk_list(self):
+ def check_fk(col):
+ if col in self.consider_as_foreign_keys:
+ return col._annotate({"foreign":True})
+ self.primaryjoin = visitors.replacement_traverse(
+ self.primaryjoin,
+ {},
+ check_fk
+ )
+ if self.secondaryjoin is not None:
+ self.secondaryjoin = visitors.replacement_traverse(
+ self.secondaryjoin,
+ {},
+ check_fk
+ )
+
+ def _annotate_present_fks(self):
if self.secondary is not None:
secondarycols = util.column_set(self.secondary.c)
else:
secondarycols = set()
- def col_is(a, b):
- return a.compare(b)
-
def is_foreign(a, b):
- if self.consider_as_foreign_keys:
- if a in self.consider_as_foreign_keys and (
- col_is(a, b) or
- b not in self.consider_as_foreign_keys
- ):
- return a
- elif b in self.consider_as_foreign_keys and (
- col_is(a, b) or
- a not in self.consider_as_foreign_keys
- ):
- return b
-
if isinstance(a, schema.Column) and \
isinstance(b, schema.Column):
if a.references(b):
@@ -163,19 +222,6 @@ class JoinCondition(object):
elif b in secondarycols and a not in secondarycols:
return b
- def _annotate_fk(binary, left, right):
- can_be_synced = self.can_be_synced_fn(left)
- left = left._annotate({
- #"equated":binary.operator is operators.eq,
- "can_be_synced":can_be_synced and \
- binary.operator is operators.eq
- })
- right = right._annotate({
- #"equated":binary.operator is operators.eq,
- "referent":True
- })
- return left, right
-
def visit_binary(binary):
if not isinstance(binary.left, sql.ColumnElement) or \
not isinstance(binary.right, sql.ColumnElement):
@@ -185,20 +231,12 @@ class JoinCondition(object):
"foreign" not in binary.right._annotations:
col = is_foreign(binary.left, binary.right)
if col is not None:
- if col is binary.left:
+ if col.compare(binary.left):
binary.left = binary.left._annotate(
{"foreign":True})
- elif col is binary.right:
+ elif col.compare(binary.right):
binary.right = binary.right._annotate(
{"foreign":True})
- # TODO: when the two cols are the same.
-
- if "foreign" in binary.left._annotations:
- binary.left, binary.right = _annotate_fk(
- binary, binary.left, binary.right)
- if "foreign" in binary.right._annotations:
- binary.right, binary.left = _annotate_fk(
- binary, binary.right, binary.left)
self.primaryjoin = visitors.cloned_traverse(
self.primaryjoin,
@@ -211,11 +249,6 @@ class JoinCondition(object):
{},
{"binary":visit_binary}
)
- self._check_foreign_cols(
- self.primaryjoin, True)
- if self.secondaryjoin is not None:
- self._check_foreign_cols(
- self.secondaryjoin, False)
def _refers_to_parent_table(self):
pt = self.parent_selectable
@@ -241,18 +274,14 @@ class JoinCondition(object):
return result[0]
def _annotate_remote(self):
- parentcols = util.column_set(self.parent_selectable.c)
+ if self._has_annotation(self.primaryjoin, "remote"):
+ return
- for col in visitors.iterate(self.primaryjoin, {}):
- if "remote" in col._annotations:
- has_remote_annotations = True
- break
- else:
- has_remote_annotations = False
+ parentcols = util.column_set(self.parent_selectable.c)
def _annotate_selfref(fn):
def visit_binary(binary):
- equated = binary.left is binary.right
+ equated = binary.left.compare(binary.right)
if isinstance(binary.left, sql.ColumnElement) and \
isinstance(binary.right, sql.ColumnElement):
# assume one to many - FKs are "remote"
@@ -267,44 +296,72 @@ class JoinCondition(object):
self.primaryjoin, {},
{"binary":visit_binary})
- if not has_remote_annotations:
+ if self.secondary is not None:
+ def repl(element):
+ if self.secondary.c.contains_column(element):
+ return element._annotate({"remote":True})
+ self.primaryjoin = visitors.replacement_traverse(
+ self.primaryjoin, {}, repl)
+ self.secondaryjoin = visitors.replacement_traverse(
+ self.secondaryjoin, {}, repl)
+ elif self._local_remote_pairs or self._remote_side:
+
if self._local_remote_pairs:
- raise NotImplementedError()
- elif self._remote_side:
- if self._refers_to_parent_table():
- _annotate_selfref(lambda col:col in self._remote_side)
- else:
- def repl(element):
- if element in self._remote_side:
- return element._annotate({"remote":True})
- self.primaryjoin = visitors.replacement_traverse(
- self.primaryjoin, {}, repl)
- elif self.secondary is not None:
- def repl(element):
- if self.secondary.c.contains_column(element):
- return element._annotate({"remote":True})
- self.primaryjoin = visitors.replacement_traverse(
- self.primaryjoin, {}, repl)
- self.secondaryjoin = visitors.replacement_traverse(
- self.secondaryjoin, {}, repl)
- elif self._refers_to_parent_table():
- _annotate_selfref(lambda col:"foreign" in col._annotations)
+ if self._remote_side:
+ raise sa_exc.ArgumentError(
+ "remote_side argument is redundant "
+ "against more detailed _local_remote_side "
+ "argument.")
+
+ remote_side = [r for (l, r) in self._local_remote_pairs]
+ else:
+ remote_side = self._remote_side
+
+ if self._refers_to_parent_table():
+ _annotate_selfref(lambda col:col in remote_side)
else:
def repl(element):
- if self.child_selectable.c.contains_column(element):
+ if element in remote_side:
return element._annotate({"remote":True})
-
self.primaryjoin = visitors.replacement_traverse(
self.primaryjoin, {}, repl)
+ elif self._refers_to_parent_table():
+ _annotate_selfref(lambda col:"foreign" in col._annotations)
+ else:
+ def repl(element):
+ if self.child_selectable.c.contains_column(element):
+ return element._annotate({"remote":True})
+
+ self.primaryjoin = visitors.replacement_traverse(
+ self.primaryjoin, {}, repl)
+
+ def _annotate_local(self):
+ if self._has_annotation(self.primaryjoin, "local"):
+ return
+
+ parentcols = util.column_set(self.parent_selectable.c)
+
+ if self._local_remote_pairs:
+ local_side = util.column_set([l for (l, r)
+ in self._local_remote_pairs])
+ else:
+ local_side = util.column_set(self.parent_selectable.c)
def locals_(elem):
if "remote" not in elem._annotations and \
- elem in parentcols:
+ elem in local_side:
return elem._annotate({"local":True})
self.primaryjoin = visitors.replacement_traverse(
self.primaryjoin, {}, locals_
)
+ def _check_remote_side(self):
+ if not self.local_remote_pairs:
+ raise sa_exc.ArgumentError('Relationship %s could '
+ 'not determine any local/remote column '
+ 'pairs from remote side argument %r'
+ % (self.prop, self._remote_side))
+
def _check_foreign_cols(self, join_condition, primary):
"""Check the foreign key columns collected and emit error messages."""
@@ -315,11 +372,10 @@ class JoinCondition(object):
has_foreign = bool(foreign_cols)
- if self.support_sync:
- for col in foreign_cols:
- if col._annotations.get("can_be_synced"):
- can_sync = True
- break
+ if primary:
+ can_sync = bool(self.synchronize_pairs)
+ else:
+ can_sync = bool(self.secondary_synchronize_pairs)
if self.support_sync and can_sync or \
(not self.support_sync and has_foreign):
@@ -407,6 +463,44 @@ class JoinCondition(object):
"key columns are present in neither the parent "
"nor the child's mapped tables" % self.prop)
+ def _setup_pairs(self):
+ sync_pairs = []
+ lrp = util.OrderedSet([])
+ secondary_sync_pairs = []
+
+ def go(joincond, collection):
+ def visit_binary(binary, left, right):
+ if "remote" in right._annotations and \
+ "remote" not in left._annotations and \
+ self.can_be_synced_fn(left):
+ lrp.add((left, right))
+ elif "remote" in left._annotations and \
+ "remote" not in right._annotations and \
+ self.can_be_synced_fn(right):
+ lrp.add((right, left))
+ if binary.operator is operators.eq:
+ # and \
+ #binary.left.compare(left) and \
+ #binary.right.compare(right):
+ if "foreign" in right._annotations:
+ collection.append((left, right))
+ elif "foreign" in left._annotations:
+ collection.append((right, left))
+ visit_binary_product(visit_binary, joincond)
+
+ for joincond, collection in [
+ (self.primaryjoin, sync_pairs),
+ (self.secondaryjoin, secondary_sync_pairs)
+ ]:
+ if joincond is None:
+ continue
+ go(joincond, collection)
+
+ self.local_remote_pairs = list(lrp)
+ self.synchronize_pairs = sync_pairs
+ self.secondary_synchronize_pairs = secondary_sync_pairs
+
+
@util.memoized_property
def remote_columns(self):
return self._gather_join_annotations("remote")
@@ -416,38 +510,6 @@ class JoinCondition(object):
return self._gather_join_annotations("local")
@util.memoized_property
- def synchronize_pairs(self):
- parentcols = util.column_set(self.parent_selectable.c)
- targetcols = util.column_set(self.child_selectable.c)
- result = []
- for l, r in self.local_remote_pairs:
- if self.secondary is not None:
- if "foreign" in r._annotations and \
- l in parentcols:
- result.append((l, r))
- elif "foreign" in r._annotations and \
- "can_be_synced" in r._annotations:
- result.append((l, r))
- elif "foreign" in l._annotations and \
- "can_be_synced" in l._annotations:
- result.append((r, l))
- return result
-
- @util.memoized_property
- def secondary_synchronize_pairs(self):
- parentcols = util.column_set(self.parent_selectable.c)
- targetcols = util.column_set(self.child_selectable.c)
- result = []
- if self.secondary is None:
- return result
-
- for l, r in self.local_remote_pairs:
- if "foreign" in l._annotations and \
- r in targetcols:
- result.append((l, r))
- return result
-
- @util.memoized_property
def foreign_key_columns(self):
return self._gather_join_annotations("foreign")
@@ -470,24 +532,6 @@ class JoinCondition(object):
if annotation.issubset(col._annotations)
])
- @util.memoized_property
- def local_remote_pairs(self):
- lrp = util.OrderedSet()
- def visit_binary(binary):
- if "remote" in binary.right._annotations and \
- "remote" not in binary.left._annotations and \
- isinstance(binary.left, expression.ColumnClause) and \
- self.can_be_synced_fn(binary.left):
- lrp.add((binary.left, binary.right))
- elif "remote" in binary.left._annotations and \
- "remote" not in binary.right._annotations and \
- isinstance(binary.right, expression.ColumnClause) and \
- self.can_be_synced_fn(binary.right):
- lrp.add((binary.right, binary.left))
- visitors.traverse(self.primaryjoin, {}, {"binary":visit_binary})
- if self.secondaryjoin is not None:
- visitors.traverse(self.secondaryjoin, {}, {"binary":visit_binary})
- return list(lrp)
def join_targets(self, source_selectable,
dest_selectable,
@@ -604,147 +648,6 @@ def _create_lazy_clause(cls, prop, reverse_direction=False):
return lazywhere, bind_to_col, equated_columns
-def _determine_synchronize_pairs(self):
- """Resolve 'primary'/foreign' column pairs from the primaryjoin
- and secondaryjoin arguments.
-
- """
- if self.local_remote_pairs:
- if not self._user_defined_foreign_keys:
- raise sa_exc.ArgumentError(
- "foreign_keys argument is "
- "required with _local_remote_pairs argument")
- self.synchronize_pairs = []
- for l, r in self.local_remote_pairs:
- if r in self._user_defined_foreign_keys:
- self.synchronize_pairs.append((l, r))
- elif l in self._user_defined_foreign_keys:
- self.synchronize_pairs.append((r, l))
- else:
- self.synchronize_pairs = self._sync_pairs_from_join(
- self.primaryjoin,
- True)
-
- self._calculated_foreign_keys = util.column_set(
- r for (l, r) in
- self.synchronize_pairs)
-
- if self.secondaryjoin is not None:
- self.secondary_synchronize_pairs = self._sync_pairs_from_join(
- self.secondaryjoin,
- False)
- self._calculated_foreign_keys.update(
- r for (l, r) in
- self.secondary_synchronize_pairs)
- else:
- self.secondary_synchronize_pairs = None
-
-
-def _determine_local_remote_pairs(self):
- """Determine pairs of columns representing "local" to
- "remote", where "local" columns are on the parent mapper,
- "remote" are on the target mapper.
-
- These pairs are used on the load side only to generate
- lazy loading clauses.
-
- """
- if not self.local_remote_pairs and not self.remote_side:
- # the most common, trivial case. Derive
- # local/remote pairs from the synchronize pairs.
- eq_pairs = util.unique_list(
- self.synchronize_pairs +
- (self.secondary_synchronize_pairs or []))
- if self.direction is MANYTOONE:
- self.local_remote_pairs = [(r, l) for l, r in eq_pairs]
- else:
- self.local_remote_pairs = eq_pairs
-
- # "remote_side" specified, derive from the primaryjoin
- # plus remote_side, similarly to how synchronize_pairs
- # were determined.
- elif self.remote_side:
- if self.local_remote_pairs:
- raise sa_exc.ArgumentError('remote_side argument is '
- 'redundant against more detailed '
- '_local_remote_side argument.')
- if self.direction is MANYTOONE:
- self.local_remote_pairs = [(r, l) for (l, r) in
- criterion_as_pairs(self.primaryjoin,
- consider_as_referenced_keys=self.remote_side,
- any_operator=True)]
-
- else:
- self.local_remote_pairs = \
- criterion_as_pairs(self.primaryjoin,
- consider_as_foreign_keys=self.remote_side,
- any_operator=True)
- if not self.local_remote_pairs:
- raise sa_exc.ArgumentError('Relationship %s could '
- 'not determine any local/remote column '
- 'pairs from remote side argument %r'
- % (self, self.remote_side))
- # else local_remote_pairs were sent explcitly via
- # ._local_remote_pairs.
-
- # create local_side/remote_side accessors
- self.local_side = util.ordered_column_set(
- l for l, r in self.local_remote_pairs)
- self.remote_side = util.ordered_column_set(
- r for l, r in self.local_remote_pairs)
-
- # check that the non-foreign key column in the local/remote
- # collection is mapped. The foreign key
- # which the individual mapped column references directly may
- # itself be in a non-mapped table; see
- # test.orm.test_relationships.ViewOnlyComplexJoin.test_basic
- # for an example of this.
- if self.direction is ONETOMANY:
- for col in self.local_side:
- if not self._columns_are_mapped(col):
- raise sa_exc.ArgumentError(
- "Local column '%s' is not "
- "part of mapping %s. Specify remote_side "
- "argument to indicate which column lazy join "
- "condition should compare against." % (col,
- self.parent))
- elif self.direction is MANYTOONE:
- for col in self.remote_side:
- if not self._columns_are_mapped(col):
- raise sa_exc.ArgumentError(
- "Remote column '%s' is not "
- "part of mapping %s. Specify remote_side "
- "argument to indicate which column lazy join "
- "condition should bind." % (col, self.mapper))
-
- count = [0]
- def clone(elem):
- if set(['local', 'remote']).intersection(elem._annotations):
- return None
- elif elem in self.local_side and elem in self.remote_side:
- # TODO: OK this still sucks. this is basically,
- # refuse, refuse, refuse the temptation to guess!
- # but crap we really have to guess don't we. we
- # might want to traverse here with cloned_traverse
- # so we can see the binary exprs and do it at that
- # level....
- if count[0] % 2 == 0:
- elem = elem._annotate({'local':True})
- else:
- elem = elem._annotate({'remote':True})
- count[0] += 1
- elif elem in self.local_side:
- elem = elem._annotate({'local':True})
- elif elem in self.remote_side:
- elem = elem._annotate({'remote':True})
- else:
- elem = None
- return elem
-
- self.primaryjoin = visitors.replacement_traverse(
- self.primaryjoin, {}, clone
- )
-
def _criterion_exists(self, criterion=None, **kwargs):
if getattr(self, '_of_type', None):
diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py
index 5f4b182d0..320234281 100644
--- a/lib/sqlalchemy/orm/strategies.py
+++ b/lib/sqlalchemy/orm/strategies.py
@@ -785,6 +785,7 @@ class SubqueryLoader(AbstractRelationshipLoader):
leftmost_mapper, leftmost_prop = \
subq_mapper, \
subq_mapper._props[subq_path[1]]
+ # TODO: local cols might not be unique here
leftmost_cols, remote_cols = self._local_remote_columns(leftmost_prop)
leftmost_attr = [
@@ -846,6 +847,7 @@ class SubqueryLoader(AbstractRelationshipLoader):
# self.parent is more specific than subq_path[-2]
parent_alias = mapperutil.AliasedClass(self.parent)
+ # TODO: local cols might not be unique here
local_cols, remote_cols = \
self._local_remote_columns(self.parent_property)
@@ -885,6 +887,7 @@ class SubqueryLoader(AbstractRelationshipLoader):
if prop.secondary is None:
return zip(*prop.local_remote_pairs)
else:
+ # TODO: this isn't going to work for readonly....
return \
[p[0] for p in prop.synchronize_pairs],\
[
@@ -930,6 +933,7 @@ class SubqueryLoader(AbstractRelationshipLoader):
if ('subquery', reduced_path) not in context.attributes:
return None, None, None
+ # TODO: local_cols might not be unique here
local_cols, remote_cols = self._local_remote_columns(self.parent_property)
q = context.attributes[('subquery', reduced_path)]
diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py
index 0cd5b0594..f17f675f4 100644
--- a/lib/sqlalchemy/orm/util.py
+++ b/lib/sqlalchemy/orm/util.py
@@ -366,7 +366,18 @@ def _orm_annotate(element, exclude=None):
"""
return sql_util._deep_annotate(element, {'_orm_adapt':True}, exclude)
-_orm_deannotate = sql_util._deep_deannotate
+def _orm_deannotate(element):
+ """Remove annotations that link a column to a particular mapping.
+
+ Note this doesn't affect "remote" and "foreign" annotations
+ passed by the :func:`.orm.foreign` and :func:`.orm.remote`
+ annotators.
+
+ """
+
+ return sql_util._deep_deannotate(element,
+ values=("_orm_adapt", "parententity")
+ )
class _ORMJoin(expression.Join):
"""Extend Join to support ORM constructs as input."""
diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py
index 72099a5f5..ebf4de9a2 100644
--- a/lib/sqlalchemy/sql/expression.py
+++ b/lib/sqlalchemy/sql/expression.py
@@ -1576,18 +1576,30 @@ class ClauseElement(Visitable):
return id(self)
def _annotate(self, values):
- """return a copy of this ClauseElement with the given annotations
- dictionary.
+ """return a copy of this ClauseElement with annotations
+ updated by the given dictionary.
"""
return sqlutil.Annotated(self, values)
- def _deannotate(self):
- """return a copy of this ClauseElement with an empty annotations
- dictionary.
+ def _with_annotations(self, values):
+ """return a copy of this ClauseElement with annotations
+ replaced by the given dictionary.
"""
- return self._clone()
+ return sqlutil.Annotated(self, values)
+
+ def _deannotate(self, values=None):
+ """return a copy of this :class:`.ClauseElement` with annotations
+ removed.
+
+ :param values: optional tuple of individual values
+ to remove.
+
+ """
+ # since we have no annotations we return
+ # self
+ return self
def unique_params(self, *optionaldict, **kwargs):
"""Return a copy with :func:`bindparam()` elments replaced.
diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py
index f0509c16f..9a45a5777 100644
--- a/lib/sqlalchemy/sql/util.py
+++ b/lib/sqlalchemy/sql/util.py
@@ -62,6 +62,61 @@ def find_join_source(clauses, join_to):
else:
return None, None
+
+def visit_binary_product(fn, expr):
+ """Produce a traversal of the given expression, delivering
+ column comparisons to the given function.
+
+ The function is of the form::
+
+ def my_fn(binary, left, right)
+
+ For each binary expression located which has a
+ comparison operator, the product of "left" and
+ "right" will be delivered to that function,
+ in terms of that binary.
+
+ Hence an expression like::
+
+ and_(
+ (a + b) == q + func.sum(e + f),
+ j == r
+ )
+
+ would have the traversal::
+
+ a <eq> q
+ a <eq> e
+ a <eq> f
+ b <eq> q
+ b <eq> e
+ b <eq> f
+ j <eq> r
+
+ That is, every combination of "left" and
+ "right" that doesn't further contain
+ a binary comparison is passed as pairs.
+
+ """
+ stack = []
+ def visit(element):
+ if element.__visit_name__ == 'binary' and \
+ operators.is_comparison(element.operator):
+ stack.insert(0, element)
+ for l in visit(element.left):
+ for r in visit(element.right):
+ fn(stack[0], l, r)
+ stack.pop(0)
+ for elem in element.get_children():
+ visit(elem)
+ else:
+ if isinstance(element, expression.ColumnClause):
+ yield element
+ for elem in element.get_children():
+ for e in visit(elem):
+ yield e
+ list(visit(expr))
+
def find_tables(clause, check_columns=False,
include_aliases=False, include_joins=False,
include_selects=False, include_crud=False):
@@ -357,13 +412,22 @@ class Annotated(object):
def _annotate(self, values):
_values = self._annotations.copy()
_values.update(values)
+ return self._with_annotations(_values)
+
+ def _with_annotations(self, values):
clone = self.__class__.__new__(self.__class__)
clone.__dict__ = self.__dict__.copy()
- clone._annotations = _values
+ clone._annotations = values
return clone
- def _deannotate(self):
- return self.__element
+ def _deannotate(self, values=None):
+ if values is None:
+ return self.__element
+ else:
+ _values = self._annotations.copy()
+ for v in values:
+ _values.pop(v, None)
+ return self._with_annotations(_values)
def _compiler_dispatch(self, visitor, **kw):
return self.__element.__class__._compiler_dispatch(self, visitor, **kw)
@@ -426,11 +490,11 @@ def _deep_annotate(element, annotations, exclude=None):
element = clone(element)
return element
-def _deep_deannotate(element):
- """Deep copy the given element, removing all annotations."""
+def _deep_deannotate(element, values=None):
+ """Deep copy the given element, removing annotations."""
def clone(elem):
- elem = elem._deannotate()
+ elem = elem._deannotate(values=values)
elem._copy_internals(clone=clone)
return elem
diff --git a/test/orm/test_rel_fn.py b/test/orm/test_rel_fn.py
index d3d346bba..346cb90c1 100644
--- a/test/orm/test_rel_fn.py
+++ b/test/orm/test_rel_fn.py
@@ -6,9 +6,8 @@ from sqlalchemy import MetaData, Table, Column, ForeignKey, Integer, \
select, ForeignKeyConstraint, exc
from sqlalchemy.orm.interfaces import ONETOMANY, MANYTOONE, MANYTOMANY
-class JoinCondTest(fixtures.TestBase, AssertsCompiledSQL):
- __dialect__ = 'default'
+class _JoinFixtures(object):
@classmethod
def setup_class(cls):
m = MetaData()
@@ -36,6 +35,28 @@ class JoinCondTest(fixtures.TestBase, AssertsCompiledSQL):
['composite_selfref.id', 'composite_selfref.group_id']
)
)
+ cls.m2mleft = Table('m2mlft', m,
+ Column('id', Integer, primary_key=True),
+ )
+ cls.m2mright = Table('m2mrgt', m,
+ Column('id', Integer, primary_key=True),
+ )
+ cls.m2msecondary = Table('m2msecondary', m,
+ Column('lid', Integer, ForeignKey('m2mlft.id'), primary_key=True),
+ Column('rid', Integer, ForeignKey('m2mrgt.id'), primary_key=True),
+ )
+
+ def _join_fixture_m2m_selfref(self, **kw):
+ return relationships.JoinCondition(
+ self.m2mleft,
+ self.m2mright,
+ self.m2mleft,
+ self.m2mright,
+ secondary=self.m2msecondary,
+ primaryjoin=self.m2mleft.c.id==self.m2msecondary.c.lid,
+ secondaryjoin=self.m2mright.c.id==self.m2msecondary.c.rid,
+ **kw
+ )
def _join_fixture_o2m(self, **kw):
return relationships.JoinCondition(
@@ -120,6 +141,7 @@ class JoinCondTest(fixtures.TestBase, AssertsCompiledSQL):
**kw
)
+class ColumnCollectionsTest(_JoinFixtures, fixtures.TestBase, AssertsCompiledSQL):
def test_determine_remote_columns_compound_1(self):
joincond = self._join_fixture_compound_expression_1(
support_sync=False)
@@ -133,7 +155,25 @@ class JoinCondTest(fixtures.TestBase, AssertsCompiledSQL):
support_sync=False)
eq_(
joincond.local_remote_pairs,
- []
+ [
+ (self.left.c.x, self.right.c.x),
+ (self.left.c.x, self.right.c.y),
+ (self.left.c.y, self.right.c.x),
+ (self.left.c.y, self.right.c.y)
+ ]
+ )
+
+ def test_determine_local_remote_compound_2(self):
+ joincond = self._join_fixture_compound_expression_2(
+ support_sync=False)
+ eq_(
+ joincond.local_remote_pairs,
+ [
+ (self.left.c.x, self.right.c.x),
+ (self.left.c.x, self.right.c.y),
+ (self.left.c.y, self.right.c.x),
+ (self.left.c.y, self.right.c.y)
+ ]
)
def test_err_local_remote_compound_1(self):
@@ -160,14 +200,71 @@ class JoinCondTest(fixtures.TestBase, AssertsCompiledSQL):
set([self.right.c.x, self.right.c.y])
)
- def test_determine_local_remote_compound_2(self):
- joincond = self._join_fixture_compound_expression_2(
- support_sync=False)
+
+ def test_determine_remote_columns_o2m(self):
+ joincond = self._join_fixture_o2m()
+ eq_(
+ joincond.remote_columns,
+ set([self.right.c.lid])
+ )
+
+ def test_determine_remote_columns_o2m_selfref(self):
+ joincond = self._join_fixture_o2m_selfref()
+ eq_(
+ joincond.remote_columns,
+ set([self.selfref.c.sid])
+ )
+
+ def test_determine_remote_columns_o2m_composite_selfref(self):
+ joincond = self._join_fixture_o2m_composite_selfref()
+ eq_(
+ joincond.remote_columns,
+ set([self.composite_selfref.c.parent_id,
+ self.composite_selfref.c.group_id])
+ )
+
+ def test_determine_remote_columns_m2o_composite_selfref(self):
+ joincond = self._join_fixture_m2o_composite_selfref()
+ eq_(
+ joincond.remote_columns,
+ set([self.composite_selfref.c.id,
+ self.composite_selfref.c.group_id])
+ )
+
+ def test_determine_remote_columns_m2o(self):
+ joincond = self._join_fixture_m2o()
+ eq_(
+ joincond.remote_columns,
+ set([self.left.c.id])
+ )
+
+ def test_determine_local_remote_pairs_o2m(self):
+ joincond = self._join_fixture_o2m()
eq_(
joincond.local_remote_pairs,
- []
+ [(self.left.c.id, self.right.c.lid)]
+ )
+
+ def test_determine_synchronize_pairs_m2m_selfref(self):
+ joincond = self._join_fixture_m2m_selfref()
+ eq_(
+ joincond.synchronize_pairs,
+ [(self.m2mleft.c.id, self.m2msecondary.c.lid)]
+ )
+ eq_(
+ joincond.secondary_synchronize_pairs,
+ [(self.m2mright.c.id, self.m2msecondary.c.rid)]
)
+ def test_determine_remote_columns_m2o_selfref(self):
+ joincond = self._join_fixture_m2o_selfref()
+ eq_(
+ joincond.remote_columns,
+ set([self.selfref.c.id])
+ )
+
+
+class DirectionTest(_JoinFixtures, fixtures.TestBase, AssertsCompiledSQL):
def test_determine_direction_compound_2(self):
joincond = self._join_fixture_compound_expression_2(
support_sync=False)
@@ -176,60 +273,46 @@ class JoinCondTest(fixtures.TestBase, AssertsCompiledSQL):
ONETOMANY
)
- def test_determine_join_o2m(self):
- joincond = self._join_fixture_o2m()
- self.assert_compile(
- joincond.primaryjoin,
- "lft.id = rgt.lid"
- )
-
def test_determine_direction_o2m(self):
joincond = self._join_fixture_o2m()
is_(joincond.direction, ONETOMANY)
- def test_determine_remote_columns_o2m(self):
- joincond = self._join_fixture_o2m()
- eq_(
- joincond.remote_columns,
- set([self.right.c.lid])
- )
-
- def test_determine_join_o2m_selfref(self):
- joincond = self._join_fixture_o2m_selfref()
- self.assert_compile(
- joincond.primaryjoin,
- "selfref.id = selfref.sid"
- )
-
def test_determine_direction_o2m_selfref(self):
joincond = self._join_fixture_o2m_selfref()
is_(joincond.direction, ONETOMANY)
- def test_determine_remote_columns_o2m_selfref(self):
- joincond = self._join_fixture_o2m_selfref()
- eq_(
- joincond.remote_columns,
- set([self.selfref.c.sid])
- )
+ def test_determine_direction_m2o_selfref(self):
+ joincond = self._join_fixture_m2o_selfref()
+ is_(joincond.direction, MANYTOONE)
- def test_join_targets_o2m_selfref(self):
- joincond = self._join_fixture_o2m_selfref()
- left = select([joincond.parent_selectable]).alias('pj')
- pj, sj, sec, adapter = joincond.join_targets(
- left,
- joincond.child_selectable,
- True)
+ def test_determine_direction_o2m_composite_selfref(self):
+ joincond = self._join_fixture_o2m_composite_selfref()
+ is_(joincond.direction, ONETOMANY)
+
+ def test_determine_direction_m2o_composite_selfref(self):
+ joincond = self._join_fixture_m2o_composite_selfref()
+ is_(joincond.direction, MANYTOONE)
+
+ def test_determine_direction_m2o(self):
+ joincond = self._join_fixture_m2o()
+ is_(joincond.direction, MANYTOONE)
+
+
+class DetermineJoinTest(_JoinFixtures, fixtures.TestBase, AssertsCompiledSQL):
+ __dialect__ = 'default'
+
+ def test_determine_join_o2m(self):
+ joincond = self._join_fixture_o2m()
self.assert_compile(
- pj, "pj.id = selfref.sid"
+ joincond.primaryjoin,
+ "lft.id = rgt.lid"
)
- right = select([joincond.child_selectable]).alias('pj')
- pj, sj, sec, adapter = joincond.join_targets(
- joincond.parent_selectable,
- right,
- True)
+ def test_determine_join_o2m_selfref(self):
+ joincond = self._join_fixture_o2m_selfref()
self.assert_compile(
- pj, "selfref.id = pj.sid"
+ joincond.primaryjoin,
+ "selfref.id = selfref.sid"
)
def test_determine_join_m2o_selfref(self):
@@ -239,17 +322,6 @@ class JoinCondTest(fixtures.TestBase, AssertsCompiledSQL):
"selfref.id = selfref.sid"
)
- def test_determine_direction_m2o_selfref(self):
- joincond = self._join_fixture_m2o_selfref()
- is_(joincond.direction, MANYTOONE)
-
- def test_determine_remote_columns_m2o_selfref(self):
- joincond = self._join_fixture_m2o_selfref()
- eq_(
- joincond.remote_columns,
- set([self.selfref.c.id])
- )
-
def test_determine_join_o2m_composite_selfref(self):
joincond = self._join_fixture_o2m_composite_selfref()
self.assert_compile(
@@ -258,18 +330,6 @@ class JoinCondTest(fixtures.TestBase, AssertsCompiledSQL):
"AND composite_selfref.id = composite_selfref.parent_id"
)
- def test_determine_direction_o2m_composite_selfref(self):
- joincond = self._join_fixture_o2m_composite_selfref()
- is_(joincond.direction, ONETOMANY)
-
- def test_determine_remote_columns_o2m_composite_selfref(self):
- joincond = self._join_fixture_o2m_composite_selfref()
- eq_(
- joincond.remote_columns,
- set([self.composite_selfref.c.parent_id,
- self.composite_selfref.c.group_id])
- )
-
def test_determine_join_m2o_composite_selfref(self):
joincond = self._join_fixture_m2o_composite_selfref()
self.assert_compile(
@@ -278,17 +338,7 @@ class JoinCondTest(fixtures.TestBase, AssertsCompiledSQL):
"AND composite_selfref.id = composite_selfref.parent_id"
)
- def test_determine_direction_m2o_composite_selfref(self):
- joincond = self._join_fixture_m2o_composite_selfref()
- is_(joincond.direction, MANYTOONE)
- def test_determine_remote_columns_m2o_composite_selfref(self):
- joincond = self._join_fixture_m2o_composite_selfref()
- eq_(
- joincond.remote_columns,
- set([self.composite_selfref.c.id,
- self.composite_selfref.c.group_id])
- )
def test_determine_join_m2o(self):
joincond = self._join_fixture_m2o()
@@ -297,24 +347,30 @@ class JoinCondTest(fixtures.TestBase, AssertsCompiledSQL):
"lft.id = rgt.lid"
)
- def test_determine_direction_m2o(self):
- joincond = self._join_fixture_m2o()
- is_(joincond.direction, MANYTOONE)
+class AdaptedJoinTest(_JoinFixtures, fixtures.TestBase, AssertsCompiledSQL):
+ __dialect__ = 'default'
- def test_determine_remote_columns_m2o(self):
- joincond = self._join_fixture_m2o()
- eq_(
- joincond.remote_columns,
- set([self.left.c.id])
+ def test_join_targets_o2m_selfref(self):
+ joincond = self._join_fixture_o2m_selfref()
+ left = select([joincond.parent_selectable]).alias('pj')
+ pj, sj, sec, adapter = joincond.join_targets(
+ left,
+ joincond.child_selectable,
+ True)
+ self.assert_compile(
+ pj, "pj.id = selfref.sid"
)
- def test_determine_local_remote_pairs_o2m(self):
- joincond = self._join_fixture_o2m()
- eq_(
- joincond.local_remote_pairs,
- [(self.left.c.id, self.right.c.lid)]
+ right = select([joincond.child_selectable]).alias('pj')
+ pj, sj, sec, adapter = joincond.join_targets(
+ joincond.parent_selectable,
+ right,
+ True)
+ self.assert_compile(
+ pj, "selfref.id = pj.sid"
)
+
def test_join_targets_o2m_plain(self):
joincond = self._join_fixture_o2m()
pj, sj, sec, adapter = joincond.join_targets(
@@ -347,6 +403,8 @@ class JoinCondTest(fixtures.TestBase, AssertsCompiledSQL):
pj, "lft.id = pj.lid"
)
+class LazyClauseTest(_JoinFixtures, fixtures.TestBase, AssertsCompiledSQL):
+
def _test_lazy_clause_o2m(self):
joincond = self._join_fixture_o2m()
self.assert_compile(
diff --git a/test/orm/test_relationships.py b/test/orm/test_relationships.py
index 0a02cbf9a..d2dcbe312 100644
--- a/test/orm/test_relationships.py
+++ b/test/orm/test_relationships.py
@@ -7,8 +7,9 @@ from test.lib.schema import Table, Column
from sqlalchemy.orm import mapper, relationship, relation, \
backref, create_session, configure_mappers, \
clear_mappers, sessionmaker, attributes,\
- Session, composite, column_property
-from test.lib.testing import eq_, startswith_, AssertsCompiledSQL
+ Session, composite, column_property, foreign
+from sqlalchemy.orm.interfaces import ONETOMANY, MANYTOONE, MANYTOMANY
+from test.lib.testing import eq_, startswith_, AssertsCompiledSQL, is_
from test.lib import fixtures
from test.orm import _fixtures
@@ -141,12 +142,12 @@ class CompositeSelfRefFKTest(fixtures.MappedTest):
Table('company_t', metadata,
Column('company_id', Integer, primary_key=True,
test_needs_autoincrement=True),
- Column('name', sa.Unicode(30)))
+ Column('name', String(30)))
Table('employee_t', metadata,
Column('company_id', Integer, primary_key=True),
Column('emp_id', Integer, primary_key=True),
- Column('name', sa.Unicode(30)),
+ Column('name', String(30)),
Column('reports_to_id', Integer),
sa.ForeignKeyConstraint(
['company_id'],
@@ -158,7 +159,8 @@ class CompositeSelfRefFKTest(fixtures.MappedTest):
@classmethod
def setup_classes(cls):
class Company(cls.Basic):
- pass
+ def __init__(self, name):
+ self.name = name
class Employee(cls.Basic):
def __init__(self, name, company, emp_id, reports_to=None):
@@ -248,11 +250,16 @@ class CompositeSelfRefFKTest(fixtures.MappedTest):
self._test()
def _test(self):
+ sess = Session()
+ self._setup_data(sess)
+ self._test_lazy_relations(sess)
+ self._test_join_aliasing(sess)
+
+ def _setup_data(self, sess):
Employee, Company = self.classes.Employee, self.classes.Company
- sess = create_session()
- c1 = Company()
- c2 = Company()
+ c1 = Company('c1')
+ c2 = Company('c2')
e1 = Employee(u'emp1', c1, 1)
e2 = Employee(u'emp2', c1, 2, e1)
@@ -263,10 +270,17 @@ class CompositeSelfRefFKTest(fixtures.MappedTest):
e7 = Employee(u'emp7', c2, 3, e5)
sess.add_all((c1, c2))
- sess.flush()
- sess.expunge_all()
+ sess.commit()
+ sess.close()
+
+ def _test_lazy_relations(self, sess):
+ Employee, Company = self.classes.Employee, self.classes.Company
+
+ c1 = sess.query(Company).filter_by(name='c1').one()
+ c2 = sess.query(Company).filter_by(name='c2').one()
+ e1 = sess.query(Employee).filter_by(name='emp1').one()
+ e5 = sess.query(Employee).filter_by(name='emp5').one()
- test_c1 = sess.query(Company).get(c1.company_id)
test_e1 = sess.query(Employee).get([c1.company_id, e1.emp_id])
assert test_e1.name == 'emp1', test_e1.name
test_e5 = sess.query(Employee).get([c2.company_id, e5.emp_id])
@@ -277,6 +291,16 @@ class CompositeSelfRefFKTest(fixtures.MappedTest):
assert sess.query(Employee).\
get([c2.company_id, 3]).reports_to.name == 'emp5'
+ def _test_join_aliasing(self, sess):
+ Employee, Company = self.classes.Employee, self.classes.Company
+ eq_(
+ [n for n, in sess.query(Employee.name).\
+ join(Employee.reports_to, aliased=True).\
+ filter_by(name='emp5').\
+ reset_joinpoint().\
+ order_by(Employee.name)],
+ ['emp6', 'emp7']
+ )
class CompositeJoinPartialFK(fixtures.MappedTest, AssertsCompiledSQL):
__dialect__ = 'default'
@@ -839,7 +863,6 @@ class AmbiguousJoinInterpretedAsSelfRef(fixtures.MappedTest):
def test_mapping(self):
Subscriber, Address = self.classes.Subscriber, self.classes.Address
- from sqlalchemy.orm.interfaces import ONETOMANY, MANYTOONE
sess = create_session()
assert Subscriber.addresses.property.direction is ONETOMANY
assert Address.customer.property.direction is MANYTOONE
@@ -1733,21 +1756,45 @@ class ExplicitLocalRemoteTest(fixtures.MappedTest):
class T2(cls.Comparable):
pass
- def test_onetomany_funcfk(self):
+ def test_onetomany_funcfk_oldstyle(self):
T2, T1, t2, t1 = (self.classes.T2,
self.classes.T1,
self.tables.t2,
self.tables.t1)
- # use a function within join condition. but specifying
- # local_remote_pairs overrides all parsing of the join condition.
+ # old _local_remote_pairs
mapper(T1, t1, properties={
't2s':relationship(T2,
primaryjoin=t1.c.id==sa.func.lower(t2.c.t1id),
_local_remote_pairs=[(t1.c.id, t2.c.t1id)],
- foreign_keys=[t2.c.t1id])})
+ foreign_keys=[t2.c.t1id]
+ )
+ })
+ mapper(T2, t2)
+ self._test_onetomany()
+
+ def test_onetomany_funcfk_annotated(self):
+ T2, T1, t2, t1 = (self.classes.T2,
+ self.classes.T1,
+ self.tables.t2,
+ self.tables.t1)
+
+ # use annotation
+ mapper(T1, t1, properties={
+ 't2s':relationship(T2,
+ primaryjoin=t1.c.id==
+ foreign(sa.func.lower(t2.c.t1id)),
+ )})
mapper(T2, t2)
+ self._test_onetomany()
+ def _test_onetomany(self):
+ T2, T1, t2, t1 = (self.classes.T2,
+ self.classes.T1,
+ self.tables.t2,
+ self.tables.t1)
+ is_(T1.t2s.property.direction, ONETOMANY)
+ eq_(T1.t2s.property.local_remote_pairs, [(t1.c.id, t2.c.t1id)])
sess = create_session()
a1 = T1(id='number1', data='a1')
a2 = T1(id='number2', data='a2')
diff --git a/test/sql/test_generative.py b/test/sql/test_generative.py
index f9333dbf5..d4f324dd7 100644
--- a/test/sql/test_generative.py
+++ b/test/sql/test_generative.py
@@ -1,5 +1,5 @@
from sqlalchemy import *
-from sqlalchemy.sql import table, column, ClauseElement
+from sqlalchemy.sql import table, column, ClauseElement, operators
from sqlalchemy.sql.expression import _clone, _from_objects
from test.lib import *
from sqlalchemy.sql.visitors import *
@@ -166,6 +166,90 @@ class TraversalTest(fixtures.TestBase, AssertsExecutionResults):
s = set(ClauseVisitor().iterate(bin))
assert set(ClauseVisitor().iterate(bin)) == set([foo, bar, bin])
+class BinaryEndpointTraversalTest(fixtures.TestBase):
+ """test the special binary product visit"""
+
+ def _assert_traversal(self, expr, expected):
+ canary = []
+ def visit(binary, l, r):
+ canary.append((binary.operator, l, r))
+ print binary.operator, l, r
+ sql_util.visit_binary_product(visit, expr)
+ eq_(
+ canary, expected
+ )
+
+ def test_basic(self):
+ a, b = column("a"), column("b")
+ self._assert_traversal(
+ a == b,
+ [
+ (operators.eq, a, b)
+ ]
+ )
+
+ def test_with_tuples(self):
+ a, b, c, d, b1, b1a, b1b, e, f = (
+ column("a"),
+ column("b"),
+ column("c"),
+ column("d"),
+ column("b1"),
+ column("b1a"),
+ column("b1b"),
+ column("e"),
+ column("f")
+ )
+ expr = tuple_(
+ a, b, b1==tuple_(b1a, b1b == d), c
+ ) > tuple_(
+ func.go(e + f)
+ )
+ self._assert_traversal(
+ expr,
+ [
+ (operators.gt, a, e),
+ (operators.gt, a, f),
+ (operators.gt, b, e),
+ (operators.gt, b, f),
+ (operators.eq, b1, b1a),
+ (operators.eq, b1b, d),
+ (operators.gt, c, e),
+ (operators.gt, c, f)
+ ]
+ )
+
+ def test_composed(self):
+ a, b, e, f, q, j, r = (
+ column("a"),
+ column("b"),
+ column("e"),
+ column("f"),
+ column("q"),
+ column("j"),
+ column("r"),
+ )
+ expr = and_(
+ (a + b) == q + func.sum(e + f),
+ and_(
+ j == r,
+ f == q
+ )
+ )
+ self._assert_traversal(
+ expr,
+ [
+ (operators.eq, a, q),
+ (operators.eq, a, e),
+ (operators.eq, a, f),
+ (operators.eq, b, q),
+ (operators.eq, b, e),
+ (operators.eq, b, f),
+ (operators.eq, j, r),
+ (operators.eq, f, q),
+ ]
+ )
+
class ClauseTest(fixtures.TestBase, AssertsCompiledSQL):
"""test copy-in-place behavior of various ClauseElements."""
diff --git a/test/sql/test_selectable.py b/test/sql/test_selectable.py
index 6d85f7c4f..4f1f39014 100644
--- a/test/sql/test_selectable.py
+++ b/test/sql/test_selectable.py
@@ -1151,5 +1151,7 @@ class AnnotationsTest(fixtures.TestBase):
assert b2.left is not bin.left
assert b3.left is not b2.left is not bin.left
assert b4.left is bin.left # since column is immutable
- assert b4.right is not bin.right is not b2.right is not b3.right
+ assert b4.right is bin.right
+ assert b2.right is not bin.right
+ assert b3.right is b4.right is bin.right