summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r--lib/sqlalchemy/orm/__init__.py8
-rw-r--r--lib/sqlalchemy/orm/properties.py69
-rw-r--r--lib/sqlalchemy/orm/relationships.py447
-rw-r--r--lib/sqlalchemy/orm/strategies.py4
-rw-r--r--lib/sqlalchemy/orm/util.py13
-rw-r--r--lib/sqlalchemy/sql/expression.py24
-rw-r--r--lib/sqlalchemy/sql/util.py76
7 files changed, 294 insertions, 347 deletions
diff --git a/lib/sqlalchemy/orm/__init__.py b/lib/sqlalchemy/orm/__init__.py
index 9fd969e3b..13bd18f08 100644
--- a/lib/sqlalchemy/orm/__init__.py
+++ b/lib/sqlalchemy/orm/__init__.py
@@ -44,6 +44,11 @@ from sqlalchemy.orm.properties import (
PropertyLoader,
SynonymProperty,
)
+from sqlalchemy.orm.relationships import (
+ foreign,
+ remote,
+ remote_foreign
+)
from sqlalchemy.orm import mapper as mapperlib
from sqlalchemy.orm.mapper import reconstructor, validates
from sqlalchemy.orm import strategies
@@ -81,6 +86,7 @@ __all__ = (
'dynamic_loader',
'eagerload',
'eagerload_all',
+ 'foreign',
'immediateload',
'join',
'joinedload',
@@ -96,6 +102,8 @@ __all__ = (
'reconstructor',
'relationship',
'relation',
+ 'remote',
+ 'remote_foreign',
'scoped_session',
'sessionmaker',
'subqueryload',
diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py
index 77da33b5f..38237b2d4 100644
--- a/lib/sqlalchemy/orm/properties.py
+++ b/lib/sqlalchemy/orm/properties.py
@@ -16,7 +16,7 @@ from sqlalchemy.sql.util import ClauseAdapter, criterion_as_pairs, \
join_condition, _shallow_annotate
from sqlalchemy.sql import operators, expression, visitors
from sqlalchemy.orm import attributes, dependency, mapper, \
- object_mapper, strategies, configure_mappers
+ object_mapper, strategies, configure_mappers, relationships
from sqlalchemy.orm.util import CascadeOptions, _class_to_mapper, \
_orm_annotate, _orm_deannotate
@@ -915,13 +915,12 @@ class RelationshipProperty(StrategizedProperty):
self._check_conflicts()
self._process_dependent_arguments()
self._setup_join_conditions()
- self._extra_determine_direction()
+ self._check_cascade_settings()
self._post_init()
self._generate_backref()
super(RelationshipProperty, self).do_init()
def _setup_join_conditions(self):
- import relationships
self._join_condition = jc = relationships.JoinCondition(
parent_selectable=self.parent.mapped_table,
child_selectable=self.mapper.mapped_table,
@@ -946,8 +945,8 @@ class RelationshipProperty(StrategizedProperty):
self.local_remote_pairs = jc.local_remote_pairs
self.remote_side = jc.remote_columns
self.synchronize_pairs = jc.synchronize_pairs
- self.secondary_synchronize_pairs = jc.secondary_synchronize_pairs
self._calculated_foreign_keys = jc.foreign_key_columns
+ self.secondary_synchronize_pairs = jc.secondary_synchronize_pairs
def _check_conflicts(self):
"""Test that this relationship is legal, warn about
@@ -1035,7 +1034,7 @@ class RelationshipProperty(StrategizedProperty):
(self.key, self.parent.class_)
)
- def _extra_determine_direction(self):
+ def _check_cascade_settings(self):
if self.cascade.delete_orphan and not self.single_parent \
and (self.direction is MANYTOMANY or self.direction
is MANYTOONE):
@@ -1064,7 +1063,6 @@ class RelationshipProperty(StrategizedProperty):
return False
return True
-
def _generate_backref(self):
if not self.is_primary():
return
@@ -1083,13 +1081,15 @@ class RelationshipProperty(StrategizedProperty):
pj = kwargs.pop('primaryjoin', self.secondaryjoin)
sj = kwargs.pop('secondaryjoin', self.primaryjoin)
else:
- pj = kwargs.pop('primaryjoin', self.primaryjoin)
+ pj = kwargs.pop('primaryjoin',
+ self._join_condition.primaryjoin_reverse_remote)
sj = kwargs.pop('secondaryjoin', None)
if sj:
raise sa_exc.InvalidRequestError(
"Can't assign 'secondaryjoin' on a backref against "
"a non-secondary relationship."
)
+
foreign_keys = kwargs.pop('foreign_keys',
self._user_defined_foreign_keys)
parent = self.parent.primary_mapper()
@@ -1112,21 +1112,6 @@ class RelationshipProperty(StrategizedProperty):
self._add_reverse_property(self.back_populates)
def _post_init(self):
- self.logger.info('%s setup primary join %s', self,
- self.primaryjoin)
- self.logger.info('%s setup secondary join %s', self,
- self.secondaryjoin)
- self.logger.info('%s synchronize pairs [%s]', self,
- ','.join('(%s => %s)' % (l, r) for (l, r) in
- self.synchronize_pairs))
- self.logger.info('%s secondary synchronize pairs [%s]', self,
- ','.join('(%s => %s)' % (l, r) for (l, r) in
- self.secondary_synchronize_pairs or []))
- self.logger.info('%s local/remote pairs [%s]', self,
- ','.join('(%s / %s)' % (l, r) for (l, r) in
- self.local_remote_pairs))
- self.logger.info('%s relationship direction %s', self,
- self.direction)
if self.uselist is None:
self.uselist = self.direction is not MANYTOONE
if not self.viewonly:
@@ -1141,46 +1126,6 @@ class RelationshipProperty(StrategizedProperty):
strategy = self._get_strategy(strategies.LazyLoader)
return strategy.use_get
- def _refers_to_parent_table(self):
- alt = self._alt_refers_to_parent_table()
- pt = self.parent.mapped_table
- mt = self.mapper.mapped_table
- for c, f in self.synchronize_pairs:
- if (
- pt.is_derived_from(c.table) and \
- pt.is_derived_from(f.table) and \
- mt.is_derived_from(c.table) and \
- mt.is_derived_from(f.table)
- ):
- assert alt
- return True
- else:
- assert not alt
- return False
-
- def _alt_refers_to_parent_table(self):
- pt = self.parent.mapped_table
- mt = self.mapper.mapped_table
- result = [False]
- def visit_binary(binary):
- c, f = binary.left, binary.right
- if (
- isinstance(c, expression.ColumnClause) and \
- isinstance(f, expression.ColumnClause) and \
- pt.is_derived_from(c.table) and \
- pt.is_derived_from(f.table) and \
- mt.is_derived_from(c.table) and \
- mt.is_derived_from(f.table)
- ):
- result[0] = True
-
- visitors.traverse(
- self.primaryjoin,
- {},
- {"binary":visit_binary}
- )
- return result[0]
-
@util.memoized_property
def _is_self_referential(self):
return self.mapper.common_parent(self.parent)
diff --git a/lib/sqlalchemy/orm/relationships.py b/lib/sqlalchemy/orm/relationships.py
index 723d45295..d8c2659b6 100644
--- a/lib/sqlalchemy/orm/relationships.py
+++ b/lib/sqlalchemy/orm/relationships.py
@@ -15,7 +15,7 @@ and `secondaryjoin` aspects of :func:`.relationship`.
from sqlalchemy import sql, util, log, exc as sa_exc, schema
from sqlalchemy.sql.util import ClauseAdapter, criterion_as_pairs, \
- join_condition, _shallow_annotate
+ join_condition, _shallow_annotate, visit_binary_product
from sqlalchemy.sql import operators, expression, visitors
from sqlalchemy.orm.interfaces import MANYTOMANY, MANYTOONE, ONETOMANY
@@ -78,7 +78,34 @@ class JoinCondition(object):
self._determine_joins()
self._annotate_fks()
self._annotate_remote()
+ self._annotate_local()
self._determine_direction()
+ self._setup_pairs()
+ self._check_foreign_cols(self.primaryjoin, True)
+ if self.secondaryjoin is not None:
+ self._check_foreign_cols(self.secondaryjoin, False)
+ self._check_remote_side()
+ self._log_joins()
+
+ def _log_joins(self):
+ if self.prop is None:
+ return
+ log = self.prop.logger
+ log.info('%s setup primary join %s', self,
+ self.primaryjoin)
+ log.info('%s setup secondary join %s', self,
+ self.secondaryjoin)
+ log.info('%s synchronize pairs [%s]', self,
+ ','.join('(%s => %s)' % (l, r) for (l, r) in
+ self.synchronize_pairs))
+ log.info('%s secondary synchronize pairs [%s]', self,
+ ','.join('(%s => %s)' % (l, r) for (l, r) in
+ self.secondary_synchronize_pairs or []))
+ log.info('%s local/remote pairs [%s]', self,
+ ','.join('(%s / %s)' % (l, r) for (l, r) in
+ self.local_remote_pairs))
+ log.info('%s relationship direction %s', self,
+ self.direction)
def _determine_joins(self):
"""Determine the 'primaryjoin' and 'secondaryjoin' attributes,
@@ -128,28 +155,60 @@ class JoinCondition(object):
"'secondaryjoin' is needed as well."
% self.prop)
+ @util.memoized_property
+ def primaryjoin_reverse_remote(self):
+ def replace(element):
+ if "remote" in element._annotations:
+ v = element._annotations.copy()
+ del v['remote']
+ v['local'] = True
+ return element._with_annotations(v)
+ elif "local" in element._annotations:
+ v = element._annotations.copy()
+ del v['local']
+ v['remote'] = True
+ return element._with_annotations(v)
+ return visitors.replacement_traverse(self.primaryjoin, {}, replace)
+
+ def _has_annotation(self, clause, annotation):
+ for col in visitors.iterate(clause, {}):
+ if annotation in col._annotations:
+ return True
+ else:
+ return False
+
def _annotate_fks(self):
+ if self._has_annotation(self.primaryjoin, "foreign"):
+ return
+
+ if self.consider_as_foreign_keys:
+ self._annotate_from_fk_list()
+ else:
+ self._annotate_present_fks()
+
+ def _annotate_from_fk_list(self):
+ def check_fk(col):
+ if col in self.consider_as_foreign_keys:
+ return col._annotate({"foreign":True})
+ self.primaryjoin = visitors.replacement_traverse(
+ self.primaryjoin,
+ {},
+ check_fk
+ )
+ if self.secondaryjoin is not None:
+ self.secondaryjoin = visitors.replacement_traverse(
+ self.secondaryjoin,
+ {},
+ check_fk
+ )
+
+ def _annotate_present_fks(self):
if self.secondary is not None:
secondarycols = util.column_set(self.secondary.c)
else:
secondarycols = set()
- def col_is(a, b):
- return a.compare(b)
-
def is_foreign(a, b):
- if self.consider_as_foreign_keys:
- if a in self.consider_as_foreign_keys and (
- col_is(a, b) or
- b not in self.consider_as_foreign_keys
- ):
- return a
- elif b in self.consider_as_foreign_keys and (
- col_is(a, b) or
- a not in self.consider_as_foreign_keys
- ):
- return b
-
if isinstance(a, schema.Column) and \
isinstance(b, schema.Column):
if a.references(b):
@@ -163,19 +222,6 @@ class JoinCondition(object):
elif b in secondarycols and a not in secondarycols:
return b
- def _annotate_fk(binary, left, right):
- can_be_synced = self.can_be_synced_fn(left)
- left = left._annotate({
- #"equated":binary.operator is operators.eq,
- "can_be_synced":can_be_synced and \
- binary.operator is operators.eq
- })
- right = right._annotate({
- #"equated":binary.operator is operators.eq,
- "referent":True
- })
- return left, right
-
def visit_binary(binary):
if not isinstance(binary.left, sql.ColumnElement) or \
not isinstance(binary.right, sql.ColumnElement):
@@ -185,20 +231,12 @@ class JoinCondition(object):
"foreign" not in binary.right._annotations:
col = is_foreign(binary.left, binary.right)
if col is not None:
- if col is binary.left:
+ if col.compare(binary.left):
binary.left = binary.left._annotate(
{"foreign":True})
- elif col is binary.right:
+ elif col.compare(binary.right):
binary.right = binary.right._annotate(
{"foreign":True})
- # TODO: when the two cols are the same.
-
- if "foreign" in binary.left._annotations:
- binary.left, binary.right = _annotate_fk(
- binary, binary.left, binary.right)
- if "foreign" in binary.right._annotations:
- binary.right, binary.left = _annotate_fk(
- binary, binary.right, binary.left)
self.primaryjoin = visitors.cloned_traverse(
self.primaryjoin,
@@ -211,11 +249,6 @@ class JoinCondition(object):
{},
{"binary":visit_binary}
)
- self._check_foreign_cols(
- self.primaryjoin, True)
- if self.secondaryjoin is not None:
- self._check_foreign_cols(
- self.secondaryjoin, False)
def _refers_to_parent_table(self):
pt = self.parent_selectable
@@ -241,18 +274,14 @@ class JoinCondition(object):
return result[0]
def _annotate_remote(self):
- parentcols = util.column_set(self.parent_selectable.c)
+ if self._has_annotation(self.primaryjoin, "remote"):
+ return
- for col in visitors.iterate(self.primaryjoin, {}):
- if "remote" in col._annotations:
- has_remote_annotations = True
- break
- else:
- has_remote_annotations = False
+ parentcols = util.column_set(self.parent_selectable.c)
def _annotate_selfref(fn):
def visit_binary(binary):
- equated = binary.left is binary.right
+ equated = binary.left.compare(binary.right)
if isinstance(binary.left, sql.ColumnElement) and \
isinstance(binary.right, sql.ColumnElement):
# assume one to many - FKs are "remote"
@@ -267,44 +296,72 @@ class JoinCondition(object):
self.primaryjoin, {},
{"binary":visit_binary})
- if not has_remote_annotations:
+ if self.secondary is not None:
+ def repl(element):
+ if self.secondary.c.contains_column(element):
+ return element._annotate({"remote":True})
+ self.primaryjoin = visitors.replacement_traverse(
+ self.primaryjoin, {}, repl)
+ self.secondaryjoin = visitors.replacement_traverse(
+ self.secondaryjoin, {}, repl)
+ elif self._local_remote_pairs or self._remote_side:
+
if self._local_remote_pairs:
- raise NotImplementedError()
- elif self._remote_side:
- if self._refers_to_parent_table():
- _annotate_selfref(lambda col:col in self._remote_side)
- else:
- def repl(element):
- if element in self._remote_side:
- return element._annotate({"remote":True})
- self.primaryjoin = visitors.replacement_traverse(
- self.primaryjoin, {}, repl)
- elif self.secondary is not None:
- def repl(element):
- if self.secondary.c.contains_column(element):
- return element._annotate({"remote":True})
- self.primaryjoin = visitors.replacement_traverse(
- self.primaryjoin, {}, repl)
- self.secondaryjoin = visitors.replacement_traverse(
- self.secondaryjoin, {}, repl)
- elif self._refers_to_parent_table():
- _annotate_selfref(lambda col:"foreign" in col._annotations)
+ if self._remote_side:
+ raise sa_exc.ArgumentError(
+ "remote_side argument is redundant "
+ "against more detailed _local_remote_side "
+ "argument.")
+
+ remote_side = [r for (l, r) in self._local_remote_pairs]
+ else:
+ remote_side = self._remote_side
+
+ if self._refers_to_parent_table():
+ _annotate_selfref(lambda col:col in remote_side)
else:
def repl(element):
- if self.child_selectable.c.contains_column(element):
+ if element in remote_side:
return element._annotate({"remote":True})
-
self.primaryjoin = visitors.replacement_traverse(
self.primaryjoin, {}, repl)
+ elif self._refers_to_parent_table():
+ _annotate_selfref(lambda col:"foreign" in col._annotations)
+ else:
+ def repl(element):
+ if self.child_selectable.c.contains_column(element):
+ return element._annotate({"remote":True})
+
+ self.primaryjoin = visitors.replacement_traverse(
+ self.primaryjoin, {}, repl)
+
+ def _annotate_local(self):
+ if self._has_annotation(self.primaryjoin, "local"):
+ return
+
+ parentcols = util.column_set(self.parent_selectable.c)
+
+ if self._local_remote_pairs:
+ local_side = util.column_set([l for (l, r)
+ in self._local_remote_pairs])
+ else:
+ local_side = util.column_set(self.parent_selectable.c)
def locals_(elem):
if "remote" not in elem._annotations and \
- elem in parentcols:
+ elem in local_side:
return elem._annotate({"local":True})
self.primaryjoin = visitors.replacement_traverse(
self.primaryjoin, {}, locals_
)
+ def _check_remote_side(self):
+ if not self.local_remote_pairs:
+ raise sa_exc.ArgumentError('Relationship %s could '
+ 'not determine any local/remote column '
+ 'pairs from remote side argument %r'
+ % (self.prop, self._remote_side))
+
def _check_foreign_cols(self, join_condition, primary):
"""Check the foreign key columns collected and emit error messages."""
@@ -315,11 +372,10 @@ class JoinCondition(object):
has_foreign = bool(foreign_cols)
- if self.support_sync:
- for col in foreign_cols:
- if col._annotations.get("can_be_synced"):
- can_sync = True
- break
+ if primary:
+ can_sync = bool(self.synchronize_pairs)
+ else:
+ can_sync = bool(self.secondary_synchronize_pairs)
if self.support_sync and can_sync or \
(not self.support_sync and has_foreign):
@@ -407,6 +463,44 @@ class JoinCondition(object):
"key columns are present in neither the parent "
"nor the child's mapped tables" % self.prop)
+ def _setup_pairs(self):
+ sync_pairs = []
+ lrp = util.OrderedSet([])
+ secondary_sync_pairs = []
+
+ def go(joincond, collection):
+ def visit_binary(binary, left, right):
+ if "remote" in right._annotations and \
+ "remote" not in left._annotations and \
+ self.can_be_synced_fn(left):
+ lrp.add((left, right))
+ elif "remote" in left._annotations and \
+ "remote" not in right._annotations and \
+ self.can_be_synced_fn(right):
+ lrp.add((right, left))
+ if binary.operator is operators.eq:
+ # and \
+ #binary.left.compare(left) and \
+ #binary.right.compare(right):
+ if "foreign" in right._annotations:
+ collection.append((left, right))
+ elif "foreign" in left._annotations:
+ collection.append((right, left))
+ visit_binary_product(visit_binary, joincond)
+
+ for joincond, collection in [
+ (self.primaryjoin, sync_pairs),
+ (self.secondaryjoin, secondary_sync_pairs)
+ ]:
+ if joincond is None:
+ continue
+ go(joincond, collection)
+
+ self.local_remote_pairs = list(lrp)
+ self.synchronize_pairs = sync_pairs
+ self.secondary_synchronize_pairs = secondary_sync_pairs
+
+
@util.memoized_property
def remote_columns(self):
return self._gather_join_annotations("remote")
@@ -416,38 +510,6 @@ class JoinCondition(object):
return self._gather_join_annotations("local")
@util.memoized_property
- def synchronize_pairs(self):
- parentcols = util.column_set(self.parent_selectable.c)
- targetcols = util.column_set(self.child_selectable.c)
- result = []
- for l, r in self.local_remote_pairs:
- if self.secondary is not None:
- if "foreign" in r._annotations and \
- l in parentcols:
- result.append((l, r))
- elif "foreign" in r._annotations and \
- "can_be_synced" in r._annotations:
- result.append((l, r))
- elif "foreign" in l._annotations and \
- "can_be_synced" in l._annotations:
- result.append((r, l))
- return result
-
- @util.memoized_property
- def secondary_synchronize_pairs(self):
- parentcols = util.column_set(self.parent_selectable.c)
- targetcols = util.column_set(self.child_selectable.c)
- result = []
- if self.secondary is None:
- return result
-
- for l, r in self.local_remote_pairs:
- if "foreign" in l._annotations and \
- r in targetcols:
- result.append((l, r))
- return result
-
- @util.memoized_property
def foreign_key_columns(self):
return self._gather_join_annotations("foreign")
@@ -470,24 +532,6 @@ class JoinCondition(object):
if annotation.issubset(col._annotations)
])
- @util.memoized_property
- def local_remote_pairs(self):
- lrp = util.OrderedSet()
- def visit_binary(binary):
- if "remote" in binary.right._annotations and \
- "remote" not in binary.left._annotations and \
- isinstance(binary.left, expression.ColumnClause) and \
- self.can_be_synced_fn(binary.left):
- lrp.add((binary.left, binary.right))
- elif "remote" in binary.left._annotations and \
- "remote" not in binary.right._annotations and \
- isinstance(binary.right, expression.ColumnClause) and \
- self.can_be_synced_fn(binary.right):
- lrp.add((binary.right, binary.left))
- visitors.traverse(self.primaryjoin, {}, {"binary":visit_binary})
- if self.secondaryjoin is not None:
- visitors.traverse(self.secondaryjoin, {}, {"binary":visit_binary})
- return list(lrp)
def join_targets(self, source_selectable,
dest_selectable,
@@ -604,147 +648,6 @@ def _create_lazy_clause(cls, prop, reverse_direction=False):
return lazywhere, bind_to_col, equated_columns
-def _determine_synchronize_pairs(self):
- """Resolve 'primary'/foreign' column pairs from the primaryjoin
- and secondaryjoin arguments.
-
- """
- if self.local_remote_pairs:
- if not self._user_defined_foreign_keys:
- raise sa_exc.ArgumentError(
- "foreign_keys argument is "
- "required with _local_remote_pairs argument")
- self.synchronize_pairs = []
- for l, r in self.local_remote_pairs:
- if r in self._user_defined_foreign_keys:
- self.synchronize_pairs.append((l, r))
- elif l in self._user_defined_foreign_keys:
- self.synchronize_pairs.append((r, l))
- else:
- self.synchronize_pairs = self._sync_pairs_from_join(
- self.primaryjoin,
- True)
-
- self._calculated_foreign_keys = util.column_set(
- r for (l, r) in
- self.synchronize_pairs)
-
- if self.secondaryjoin is not None:
- self.secondary_synchronize_pairs = self._sync_pairs_from_join(
- self.secondaryjoin,
- False)
- self._calculated_foreign_keys.update(
- r for (l, r) in
- self.secondary_synchronize_pairs)
- else:
- self.secondary_synchronize_pairs = None
-
-
-def _determine_local_remote_pairs(self):
- """Determine pairs of columns representing "local" to
- "remote", where "local" columns are on the parent mapper,
- "remote" are on the target mapper.
-
- These pairs are used on the load side only to generate
- lazy loading clauses.
-
- """
- if not self.local_remote_pairs and not self.remote_side:
- # the most common, trivial case. Derive
- # local/remote pairs from the synchronize pairs.
- eq_pairs = util.unique_list(
- self.synchronize_pairs +
- (self.secondary_synchronize_pairs or []))
- if self.direction is MANYTOONE:
- self.local_remote_pairs = [(r, l) for l, r in eq_pairs]
- else:
- self.local_remote_pairs = eq_pairs
-
- # "remote_side" specified, derive from the primaryjoin
- # plus remote_side, similarly to how synchronize_pairs
- # were determined.
- elif self.remote_side:
- if self.local_remote_pairs:
- raise sa_exc.ArgumentError('remote_side argument is '
- 'redundant against more detailed '
- '_local_remote_side argument.')
- if self.direction is MANYTOONE:
- self.local_remote_pairs = [(r, l) for (l, r) in
- criterion_as_pairs(self.primaryjoin,
- consider_as_referenced_keys=self.remote_side,
- any_operator=True)]
-
- else:
- self.local_remote_pairs = \
- criterion_as_pairs(self.primaryjoin,
- consider_as_foreign_keys=self.remote_side,
- any_operator=True)
- if not self.local_remote_pairs:
- raise sa_exc.ArgumentError('Relationship %s could '
- 'not determine any local/remote column '
- 'pairs from remote side argument %r'
- % (self, self.remote_side))
- # else local_remote_pairs were sent explcitly via
- # ._local_remote_pairs.
-
- # create local_side/remote_side accessors
- self.local_side = util.ordered_column_set(
- l for l, r in self.local_remote_pairs)
- self.remote_side = util.ordered_column_set(
- r for l, r in self.local_remote_pairs)
-
- # check that the non-foreign key column in the local/remote
- # collection is mapped. The foreign key
- # which the individual mapped column references directly may
- # itself be in a non-mapped table; see
- # test.orm.test_relationships.ViewOnlyComplexJoin.test_basic
- # for an example of this.
- if self.direction is ONETOMANY:
- for col in self.local_side:
- if not self._columns_are_mapped(col):
- raise sa_exc.ArgumentError(
- "Local column '%s' is not "
- "part of mapping %s. Specify remote_side "
- "argument to indicate which column lazy join "
- "condition should compare against." % (col,
- self.parent))
- elif self.direction is MANYTOONE:
- for col in self.remote_side:
- if not self._columns_are_mapped(col):
- raise sa_exc.ArgumentError(
- "Remote column '%s' is not "
- "part of mapping %s. Specify remote_side "
- "argument to indicate which column lazy join "
- "condition should bind." % (col, self.mapper))
-
- count = [0]
- def clone(elem):
- if set(['local', 'remote']).intersection(elem._annotations):
- return None
- elif elem in self.local_side and elem in self.remote_side:
- # TODO: OK this still sucks. this is basically,
- # refuse, refuse, refuse the temptation to guess!
- # but crap we really have to guess don't we. we
- # might want to traverse here with cloned_traverse
- # so we can see the binary exprs and do it at that
- # level....
- if count[0] % 2 == 0:
- elem = elem._annotate({'local':True})
- else:
- elem = elem._annotate({'remote':True})
- count[0] += 1
- elif elem in self.local_side:
- elem = elem._annotate({'local':True})
- elif elem in self.remote_side:
- elem = elem._annotate({'remote':True})
- else:
- elem = None
- return elem
-
- self.primaryjoin = visitors.replacement_traverse(
- self.primaryjoin, {}, clone
- )
-
def _criterion_exists(self, criterion=None, **kwargs):
if getattr(self, '_of_type', None):
diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py
index 5f4b182d0..320234281 100644
--- a/lib/sqlalchemy/orm/strategies.py
+++ b/lib/sqlalchemy/orm/strategies.py
@@ -785,6 +785,7 @@ class SubqueryLoader(AbstractRelationshipLoader):
leftmost_mapper, leftmost_prop = \
subq_mapper, \
subq_mapper._props[subq_path[1]]
+ # TODO: local cols might not be unique here
leftmost_cols, remote_cols = self._local_remote_columns(leftmost_prop)
leftmost_attr = [
@@ -846,6 +847,7 @@ class SubqueryLoader(AbstractRelationshipLoader):
# self.parent is more specific than subq_path[-2]
parent_alias = mapperutil.AliasedClass(self.parent)
+ # TODO: local cols might not be unique here
local_cols, remote_cols = \
self._local_remote_columns(self.parent_property)
@@ -885,6 +887,7 @@ class SubqueryLoader(AbstractRelationshipLoader):
if prop.secondary is None:
return zip(*prop.local_remote_pairs)
else:
+ # TODO: this isn't going to work for readonly....
return \
[p[0] for p in prop.synchronize_pairs],\
[
@@ -930,6 +933,7 @@ class SubqueryLoader(AbstractRelationshipLoader):
if ('subquery', reduced_path) not in context.attributes:
return None, None, None
+ # TODO: local_cols might not be unique here
local_cols, remote_cols = self._local_remote_columns(self.parent_property)
q = context.attributes[('subquery', reduced_path)]
diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py
index 0cd5b0594..f17f675f4 100644
--- a/lib/sqlalchemy/orm/util.py
+++ b/lib/sqlalchemy/orm/util.py
@@ -366,7 +366,18 @@ def _orm_annotate(element, exclude=None):
"""
return sql_util._deep_annotate(element, {'_orm_adapt':True}, exclude)
-_orm_deannotate = sql_util._deep_deannotate
+def _orm_deannotate(element):
+ """Remove annotations that link a column to a particular mapping.
+
+ Note this doesn't affect "remote" and "foreign" annotations
+ passed by the :func:`.orm.foreign` and :func:`.orm.remote`
+ annotators.
+
+ """
+
+ return sql_util._deep_deannotate(element,
+ values=("_orm_adapt", "parententity")
+ )
class _ORMJoin(expression.Join):
"""Extend Join to support ORM constructs as input."""
diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py
index 72099a5f5..ebf4de9a2 100644
--- a/lib/sqlalchemy/sql/expression.py
+++ b/lib/sqlalchemy/sql/expression.py
@@ -1576,18 +1576,30 @@ class ClauseElement(Visitable):
return id(self)
def _annotate(self, values):
- """return a copy of this ClauseElement with the given annotations
- dictionary.
+ """return a copy of this ClauseElement with annotations
+ updated by the given dictionary.
"""
return sqlutil.Annotated(self, values)
- def _deannotate(self):
- """return a copy of this ClauseElement with an empty annotations
- dictionary.
+ def _with_annotations(self, values):
+ """return a copy of this ClauseElement with annotations
+ replaced by the given dictionary.
"""
- return self._clone()
+ return sqlutil.Annotated(self, values)
+
+ def _deannotate(self, values=None):
+ """return a copy of this :class:`.ClauseElement` with annotations
+ removed.
+
+ :param values: optional tuple of individual values
+ to remove.
+
+ """
+ # since we have no annotations we return
+ # self
+ return self
def unique_params(self, *optionaldict, **kwargs):
"""Return a copy with :func:`bindparam()` elments replaced.
diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py
index f0509c16f..9a45a5777 100644
--- a/lib/sqlalchemy/sql/util.py
+++ b/lib/sqlalchemy/sql/util.py
@@ -62,6 +62,61 @@ def find_join_source(clauses, join_to):
else:
return None, None
+
+def visit_binary_product(fn, expr):
+ """Produce a traversal of the given expression, delivering
+ column comparisons to the given function.
+
+ The function is of the form::
+
+ def my_fn(binary, left, right)
+
+ For each binary expression located which has a
+ comparison operator, the product of "left" and
+ "right" will be delivered to that function,
+ in terms of that binary.
+
+ Hence an expression like::
+
+ and_(
+ (a + b) == q + func.sum(e + f),
+ j == r
+ )
+
+ would have the traversal::
+
+ a <eq> q
+ a <eq> e
+ a <eq> f
+ b <eq> q
+ b <eq> e
+ b <eq> f
+ j <eq> r
+
+ That is, every combination of "left" and
+ "right" that doesn't further contain
+ a binary comparison is passed as pairs.
+
+ """
+ stack = []
+ def visit(element):
+ if element.__visit_name__ == 'binary' and \
+ operators.is_comparison(element.operator):
+ stack.insert(0, element)
+ for l in visit(element.left):
+ for r in visit(element.right):
+ fn(stack[0], l, r)
+ stack.pop(0)
+ for elem in element.get_children():
+ visit(elem)
+ else:
+ if isinstance(element, expression.ColumnClause):
+ yield element
+ for elem in element.get_children():
+ for e in visit(elem):
+ yield e
+ list(visit(expr))
+
def find_tables(clause, check_columns=False,
include_aliases=False, include_joins=False,
include_selects=False, include_crud=False):
@@ -357,13 +412,22 @@ class Annotated(object):
def _annotate(self, values):
_values = self._annotations.copy()
_values.update(values)
+ return self._with_annotations(_values)
+
+ def _with_annotations(self, values):
clone = self.__class__.__new__(self.__class__)
clone.__dict__ = self.__dict__.copy()
- clone._annotations = _values
+ clone._annotations = values
return clone
- def _deannotate(self):
- return self.__element
+ def _deannotate(self, values=None):
+ if values is None:
+ return self.__element
+ else:
+ _values = self._annotations.copy()
+ for v in values:
+ _values.pop(v, None)
+ return self._with_annotations(_values)
def _compiler_dispatch(self, visitor, **kw):
return self.__element.__class__._compiler_dispatch(self, visitor, **kw)
@@ -426,11 +490,11 @@ def _deep_annotate(element, annotations, exclude=None):
element = clone(element)
return element
-def _deep_deannotate(element):
- """Deep copy the given element, removing all annotations."""
+def _deep_deannotate(element, values=None):
+ """Deep copy the given element, removing annotations."""
def clone(elem):
- elem = elem._deannotate()
+ elem = elem._deannotate(values=values)
elem._copy_internals(clone=clone)
return elem