diff options
Diffstat (limited to 'lib/sqlalchemy/orm/relationships.py')
-rw-r--r-- | lib/sqlalchemy/orm/relationships.py | 72 |
1 files changed, 46 insertions, 26 deletions
diff --git a/lib/sqlalchemy/orm/relationships.py b/lib/sqlalchemy/orm/relationships.py index 311fba478..0d9ee87b3 100644 --- a/lib/sqlalchemy/orm/relationships.py +++ b/lib/sqlalchemy/orm/relationships.py @@ -12,7 +12,7 @@ SQL annotation and aliasing behavior focused on the `primaryjoin` and `secondaryjoin` aspects of :func:`.relationship`. """ - +from __future__ import absolute_import from .. import sql, util, exc as sa_exc, schema, log from .util import CascadeOptions, _orm_annotate, _orm_deannotate @@ -27,6 +27,7 @@ from ..sql import operators, expression, visitors from .interfaces import MANYTOMANY, MANYTOONE, ONETOMANY, StrategizedProperty, PropComparator from ..inspection import inspect from . import mapper as mapperlib +import collections def remote(expr): """Annotate a portion of a primaryjoin expression @@ -2391,16 +2392,38 @@ class JoinCondition(object): if onetomany_fk and manytoone_fk: # fks on both sides. test for overlap of local/remote - # with foreign key - self_equated = self.remote_columns.intersection( - self.local_columns - ) - onetomany_local = self.remote_columns.\ - intersection(self.foreign_key_columns).\ - difference(self_equated) - manytoone_local = self.local_columns.\ - intersection(self.foreign_key_columns).\ - difference(self_equated) + # with foreign key. + # we will gather columns directly from their annotations + # without deannotating, so that we can distinguish on a column + # that refers to itself. + + # 1. columns that are both remote and FK suggest + # onetomany. + onetomany_local = self._gather_columns_with_annotation( + self.primaryjoin, "remote", "foreign") + + # 2. columns that are FK but are not remote (e.g. local) + # suggest manytoone. + manytoone_local = set([c for c in + self._gather_columns_with_annotation( + self.primaryjoin, + "foreign") + if "remote" not in c._annotations]) + + # 3. if both collections are present, remove columns that + # refer to themselves. This is for the case of + # and_(Me.id == Me.remote_id, Me.version == Me.version) + if onetomany_local and manytoone_local: + self_equated = self.remote_columns.intersection( + self.local_columns + ) + onetomany_local = onetomany_local.difference(self_equated) + manytoone_local = manytoone_local.difference(self_equated) + + # at this point, if only one or the other collection is + # present, we know the direction, otherwise it's still + # ambiguous. + if onetomany_local and not manytoone_local: self.direction = ONETOMANY elif manytoone_local and not onetomany_local: @@ -2585,46 +2608,40 @@ class JoinCondition(object): def create_lazy_clause(self, reverse_direction=False): binds = util.column_dict() - lookup = util.column_dict() + lookup = collections.defaultdict(list) equated_columns = util.column_dict() - being_replaced = set() if reverse_direction and self.secondaryjoin is None: for l, r in self.local_remote_pairs: - _list = lookup.setdefault(r, []) - _list.append((r, l)) + lookup[r].append((r, l)) equated_columns[l] = r else: # replace all "local side" columns, which is # anything that isn't marked "remote" - being_replaced.update(self.local_columns) for l, r in self.local_remote_pairs: - _list = lookup.setdefault(l, []) - _list.append((l, r)) + lookup[l].append((l, r)) equated_columns[r] = l def col_to_bind(col): - if col in being_replaced or col in lookup: + if (reverse_direction and col in lookup) or \ + (not reverse_direction and "local" in col._annotations): if col in lookup: for tobind, equated in lookup[col]: if equated in binds: return None - else: - assert not reverse_direction if col not in binds: binds[col] = sql.bindparam( None, None, type_=col.type, unique=True) return binds[col] return None - lazywhere = self.deannotated_primaryjoin - - if self.deannotated_secondaryjoin is None or not reverse_direction: + lazywhere = self.primaryjoin + if self.secondaryjoin is None or not reverse_direction: lazywhere = visitors.replacement_traverse( lazywhere, {}, col_to_bind) - if self.deannotated_secondaryjoin is not None: - secondaryjoin = self.deannotated_secondaryjoin + if self.secondaryjoin is not None: + secondaryjoin = self.secondaryjoin if reverse_direction: secondaryjoin = visitors.replacement_traverse( secondaryjoin, {}, col_to_bind) @@ -2632,6 +2649,9 @@ class JoinCondition(object): bind_to_col = dict((binds[col].key, col) for col in binds) + # this is probably not necessary + lazywhere = _deep_deannotate(lazywhere) + return lazywhere, bind_to_col, equated_columns class _ColInAnnotations(object): |