summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--lib/sqlalchemy/orm/util.py126
-rw-r--r--lib/sqlalchemy/sql/expression.py16
-rw-r--r--lib/sqlalchemy/sql/util.py2
-rw-r--r--test/orm/test_joins.py103
4 files changed, 167 insertions, 80 deletions
diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py
index dd24d8bf4..9472f8698 100644
--- a/lib/sqlalchemy/orm/util.py
+++ b/lib/sqlalchemy/orm/util.py
@@ -874,70 +874,82 @@ class _ORMJoin(expression.Join):
isouter=False, join_to_left=True):
adapt_from = None
- if hasattr(left, '_orm_mappers'):
- left_mapper = left._orm_mappers[1]
- else:
- info = inspection.inspect(left)
- left_mapper = getattr(info, 'mapper', None)
-
left_info = inspection.inspect(left)
- left_selectable = left_info.selectable
- info = inspection.inspect(right)
- right_mapper = getattr(info, 'mapper', None)
- right = info.selectable
- right_is_aliased = getattr(info, 'is_aliased_class', False)
+ if hasattr(left, '_orm_infos'):
+ left_orm_info = left._orm_infos[1]
+ else:
+ #if isinstance(left, expression.Join):
+ # info = inspection.inspect(left.right)
+ #else:
+ # info = inspection.inspect(left)
+ left_orm_info = left_info
- if right_is_aliased:
- adapt_to = right
+ right_info = inspection.inspect(right)
+
+ if getattr(right_info, 'is_aliased_class', False):
+ adapt_to = right_info.selectable
else:
adapt_to = None
- if left_mapper or right_mapper:
- self._orm_mappers = (left_mapper, right_mapper)
-
- if isinstance(onclause, basestring):
- prop = left_mapper.get_property(onclause)
- on_selectable = prop.parent.selectable
- elif isinstance(onclause, attributes.QueryableAttribute):
- on_selectable = onclause.comparator._source_selectable()
- #if adapt_from is None:
- # adapt_from = onclause.comparator._source_selectable()
- prop = onclause.property
- elif isinstance(onclause, MapperProperty):
- prop = onclause
- on_selectable = prop.parent.selectable
+# import pdb
+# pdb.set_trace()
+ self._orm_infos = (left_orm_info, right_info)
+
+ if isinstance(onclause, basestring):
+ onclause = getattr(left_orm_info.entity, onclause)
+
+ if isinstance(onclause, attributes.QueryableAttribute):
+ on_selectable = onclause.comparator._source_selectable()
+ prop = onclause.property
+ elif isinstance(onclause, MapperProperty):
+ prop = onclause
+ on_selectable = prop.parent.selectable
+ else:
+ prop = None
+
+ if prop:
+ #import pdb
+ #pdb.set_trace()
+ if sql_util.clause_is_present(on_selectable, left_info.selectable):
+ adapt_from = on_selectable
+ else:
+ adapt_from = left_info.selectable
+# import pdb
+# pdb.set_trace()
+ #adapt_from = left_orm_info.selectable
+ #adapt_from = left_info.selectable
+# adapt_from = None
+# if adapt_from is None:
+# _derived = []
+# for s in expression._from_objects(left_info.selectable):
+# if s == on_selectable:
+# adapt_from = s
+# break
+# elif s.is_derived_from(on_selectable):
+# _derived.append(s)
+# else:
+# if _derived:
+# adapt_from = _derived[0]
+
+ #if adapt_from is None:
+# adapt_from = left_info.selectable
+
+ #adapt_from = None
+ pj, sj, source, dest, \
+ secondary, target_adapter = prop._create_joins(
+ source_selectable=adapt_from,
+ dest_selectable=adapt_to,
+ source_polymorphic=True,
+ dest_polymorphic=True,
+ of_type=right_info.mapper)
+
+ if sj is not None:
+ left = sql.join(left, secondary, pj, isouter)
+ onclause = sj
else:
- prop = None
-
- if prop:
- import pdb
- pdb.set_trace()
- _derived = []
- for s in expression._from_objects(left_selectable):
- if s == on_selectable:
- adapt_from = s
- break
- elif s.is_derived_from(on_selectable):
- _derived.append(s)
- else:
- if _derived:
- adapt_from = _derived[0]
-
- pj, sj, source, dest, \
- secondary, target_adapter = prop._create_joins(
- source_selectable=adapt_from,
- dest_selectable=adapt_to,
- source_polymorphic=True,
- dest_polymorphic=True,
- of_type=right_mapper)
-
- if sj is not None:
- left = sql.join(left, secondary, pj, isouter)
- onclause = sj
- else:
- onclause = pj
- self._target_adapter = target_adapter
+ onclause = pj
+ self._target_adapter = target_adapter
expression.Join.__init__(self, left, right, onclause, isouter)
diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py
index d2e644ce2..92b8aea98 100644
--- a/lib/sqlalchemy/sql/expression.py
+++ b/lib/sqlalchemy/sql/expression.py
@@ -3909,8 +3909,14 @@ class Join(FromClause):
def is_derived_from(self, fromclause):
return fromclause is self or \
- self.left.is_derived_from(fromclause) or\
- self.right.is_derived_from(fromclause)
+ self.left.is_derived_from(fromclause) or \
+ self.right.is_derived_from(fromclause) or \
+ (
+ isinstance(fromclause, Join) and
+ self.left.is_derived_from(fromclause.left) and
+ self.right.is_derived_from(fromclause.right) and
+ self.onclause.compare(fromclause.onclause)
+ )
def self_group(self, against=None):
return FromGrouping(self)
@@ -3947,6 +3953,12 @@ class Join(FromClause):
def get_children(self, **kwargs):
return self.left, self.right, self.onclause
+ def compare(self, other):
+ return isinstance(other, Join) and \
+ self.left.compare(other.left) and \
+ self.right.compare(other.right) and \
+ self.onclause.compare(other.onclause)
+
def _match_primaries(self, left, right):
if isinstance(left, Join):
left_right = left.right
diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py
index 520c90f99..4aa2d7496 100644
--- a/lib/sqlalchemy/sql/util.py
+++ b/lib/sqlalchemy/sql/util.py
@@ -203,7 +203,7 @@ def clause_is_present(clause, search):
stack = [search]
while stack:
elem = stack.pop()
- if clause is elem:
+ if clause == elem: # use == here so that Annotated's compare
return True
elif isinstance(elem, expression.Join):
stack.extend((elem.left, elem.right))
diff --git a/test/orm/test_joins.py b/test/orm/test_joins.py
index 629c55ce5..2bf0d8d92 100644
--- a/test/orm/test_joins.py
+++ b/test/orm/test_joins.py
@@ -215,7 +215,7 @@ class InheritedJoinTest(fixtures.MappedTest, AssertsCompiledSQL):
, use_default_dialect = True
)
- def test_prop_with_polymorphic(self):
+ def test_prop_with_polymorphic_1(self):
Person, Manager, Paperwork = (self.classes.Person,
self.classes.Manager,
self.classes.Paperwork)
@@ -238,6 +238,13 @@ class InheritedJoinTest(fixtures.MappedTest, AssertsCompiledSQL):
, use_default_dialect=True
)
+ def test_prop_with_polymorphic_2(self):
+ Person, Manager, Paperwork = (self.classes.Person,
+ self.classes.Manager,
+ self.classes.Paperwork)
+
+ sess = create_session()
+
self.assert_compile(
sess.query(Person).with_polymorphic(Manager).
join('paperwork', aliased=True).
@@ -1928,34 +1935,50 @@ class SelfReferentialTest(fixtures.MappedTest, AssertsCompiledSQL):
use_default_dialect=True
)
- def test_explicit_join(self):
+ def test_explicit_join_1(self):
Node = self.classes.Node
-
- sess = create_session()
-
n1 = aliased(Node)
n2 = aliased(Node)
self.assert_compile(
join(Node, n1, 'children').join(n2, 'children'),
- "nodes JOIN nodes AS nodes_1 ON nodes.id = nodes_1.parent_id JOIN nodes AS nodes_2 ON nodes_1.id = nodes_2.parent_id",
+ "nodes JOIN nodes AS nodes_1 ON nodes.id = nodes_1.parent_id "
+ "JOIN nodes AS nodes_2 ON nodes_1.id = nodes_2.parent_id",
use_default_dialect=True
)
+ def test_explicit_join_2(self):
+ Node = self.classes.Node
+ n1 = aliased(Node)
+ n2 = aliased(Node)
+
self.assert_compile(
join(Node, n1, Node.children).join(n2, n1.children),
- "nodes JOIN nodes AS nodes_1 ON nodes.id = nodes_1.parent_id JOIN nodes AS nodes_2 ON nodes_1.id = nodes_2.parent_id",
+ "nodes JOIN nodes AS nodes_1 ON nodes.id = nodes_1.parent_id "
+ "JOIN nodes AS nodes_2 ON nodes_1.id = nodes_2.parent_id",
use_default_dialect=True
)
+ def test_explicit_join_3(self):
+ Node = self.classes.Node
+ n1 = aliased(Node)
+ n2 = aliased(Node)
+
# the join_to_left=False here is unfortunate. the default on this flag should
# be False.
self.assert_compile(
join(Node, n1, Node.children).join(n2, Node.children, join_to_left=False),
- "nodes JOIN nodes AS nodes_1 ON nodes.id = nodes_1.parent_id JOIN nodes AS nodes_2 ON nodes.id = nodes_2.parent_id",
+ "nodes JOIN nodes AS nodes_1 ON nodes.id = nodes_1.parent_id "
+ "JOIN nodes AS nodes_2 ON nodes.id = nodes_2.parent_id",
use_default_dialect=True
)
+ def test_explicit_join_4(self):
+ Node = self.classes.Node
+ sess = create_session()
+ n1 = aliased(Node)
+ n2 = aliased(Node)
+
self.assert_compile(
sess.query(Node).join(n1, Node.children).join(n2, n1.children),
"SELECT nodes.id AS nodes_id, nodes.parent_id AS nodes_parent_id, nodes.data AS "
@@ -1964,6 +1987,12 @@ class SelfReferentialTest(fixtures.MappedTest, AssertsCompiledSQL):
use_default_dialect=True
)
+ def test_explicit_join_5(self):
+ Node = self.classes.Node
+ sess = create_session()
+ n1 = aliased(Node)
+ n2 = aliased(Node)
+
self.assert_compile(
sess.query(Node).join(n1, Node.children).join(n2, Node.children),
"SELECT nodes.id AS nodes_id, nodes.parent_id AS nodes_parent_id, nodes.data AS "
@@ -1972,25 +2001,59 @@ class SelfReferentialTest(fixtures.MappedTest, AssertsCompiledSQL):
use_default_dialect=True
)
- node = sess.query(Node).select_from(join(Node, n1, 'children')).filter(n1.data=='n122').first()
- assert node.data=='n12'
+ def test_explicit_join_6(self):
+ Node = self.classes.Node
+ sess = create_session()
+ n1 = aliased(Node)
- node = sess.query(Node).select_from(join(Node, n1, 'children').join(n2, 'children')).\
- filter(n2.data=='n122').first()
- assert node.data=='n1'
+ node = sess.query(Node).select_from(join(Node, n1, 'children')).\
+ filter(n1.data == 'n122').first()
+ assert node.data == 'n12'
+
+ def test_explicit_join_7(self):
+ Node = self.classes.Node
+ sess = create_session()
+ n1 = aliased(Node)
+ n2 = aliased(Node)
+
+ node = sess.query(Node).select_from(
+ join(Node, n1, 'children').join(n2, 'children')).\
+ filter(n2.data == 'n122').first()
+ assert node.data == 'n1'
+
+ def test_explicit_join_8(self):
+ Node = self.classes.Node
+ sess = create_session()
+ n1 = aliased(Node)
+ n2 = aliased(Node)
# mix explicit and named onclauses
- node = sess.query(Node).select_from(join(Node, n1, Node.id==n1.parent_id).join(n2, 'children')).\
- filter(n2.data=='n122').first()
- assert node.data=='n1'
+ node = sess.query(Node).select_from(
+ join(Node, n1, Node.id == n1.parent_id).join(n2, 'children')).\
+ filter(n2.data == 'n122').first()
+ assert node.data == 'n1'
+
+ def test_explicit_join_9(self):
+ Node = self.classes.Node
+ sess = create_session()
+ n1 = aliased(Node)
+ n2 = aliased(Node)
node = sess.query(Node).select_from(join(Node, n1, 'parent').join(n2, 'parent')).\
- filter(and_(Node.data=='n122', n1.data=='n12', n2.data=='n1')).first()
+ filter(and_(Node.data == 'n122', n1.data == 'n12', n2.data == 'n1')).first()
assert node.data == 'n122'
+ def test_explicit_join_10(self):
+ Node = self.classes.Node
+ sess = create_session()
+ n1 = aliased(Node)
+ n2 = aliased(Node)
+
eq_(
list(sess.query(Node).select_from(join(Node, n1, 'parent').join(n2, 'parent')).\
- filter(and_(Node.data=='n122', n1.data=='n12', n2.data=='n1')).values(Node.data, n1.data, n2.data)),
+ filter(and_(Node.data == 'n122',
+ n1.data == 'n12',
+ n2.data == 'n1')).values(Node.data, n1.data, n2.data)),
[('n122', 'n12', 'n1')])
def test_join_to_nonaliased(self):
@@ -2040,8 +2103,8 @@ class SelfReferentialTest(fixtures.MappedTest, AssertsCompiledSQL):
sess.query(Node, parent, grandparent).\
join(parent, Node.parent).\
join(grandparent, parent.parent).\
- filter(Node.data=='n122').filter(parent.data=='n12').\
- filter(grandparent.data=='n1').from_self().first(),
+ filter(Node.data == 'n122').filter(parent.data == 'n12').\
+ filter(grandparent.data == 'n1').from_self().first(),
(Node(data='n122'), Node(data='n12'), Node(data='n1'))
)