summary refs log tree commit diff
diff options
context:
space:
mode:
author Mike Bayer <mike_mp@zzzcomputing.com> 2010-03-19 15:30:48 -0400
committer Mike Bayer <mike_mp@zzzcomputing.com> 2010-03-19 15:30:48 -0400
commit c6fbff56a38e23bfde3bd8d3982c4eb1e944be03 (patch)
tree 5ccd7fcf0a0fe9be51aebbc5abde8feab87084fe
parent 5f15e5569c89cc39918752d54520abb89b760a18 (diff)
download sqlalchemy-c6fbff56a38e23bfde3bd8d3982c4eb1e944be03.tar.gz
- join() will now simulate a NATURAL JOIN by default. Meaning,
if the left side is a join, it will attempt to join the right side to the rightmost side of the left first, and not raise any exceptions about ambiguous join conditions if successful even if there are further join targets across the rest of the left. [ticket:1714]
-rw-r--r-- CHANGES 7
-rw-r--r-- lib/sqlalchemy/sql/expression.py 8
-rw-r--r-- lib/sqlalchemy/sql/util.py 78
-rw-r--r-- test/sql/test_select.py 76
-rw-r--r-- test/sql/test_selectable.py 108
5 files changed, 215 insertions, 62 deletions
diff --git a/CHANGES b/CHANGES
index 1e37432bd..abf9e775f 100644
--- a/CHANGES
+++ b/CHANGES
@@ -142,6 +142,13 @@ CHANGES
version_id_col was in use. [ticket:1692]
- sql
+ - join() will now simulate a NATURAL JOIN by default. Meaning,
+ if the left side is a join, it will attempt to join the right
+ side to the rightmost side of the left first, and not raise
+ any exceptions about ambiguous join conditions if successful
+ even if there are further join targets across the rest of
+ the left. [ticket:1714]
+
- The most common result processors conversion function were
moved to the new "processors" module. Dialect authors are
encouraged to use those functions whenever they correspond
diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py
index 9bc127291..1e02ba96a 100644
--- a/lib/sqlalchemy/sql/expression.py
+++ b/lib/sqlalchemy/sql/expression.py
@@ -2853,11 +2853,15 @@ class Join(FromClause):
def get_children(self, **kwargs):
return self.left, self.right, self.onclause
- def _match_primaries(self, primary, secondary):
+ def _match_primaries(self, left, right):
global sql_util
if not sql_util:
from sqlalchemy.sql import util as sql_util
- return sql_util.join_condition(primary, secondary)
+ if isinstance(left, Join):
+ left_right = left.right
+ else:
+ left_right = None
+ return sql_util.join_condition(left, right, a_subset=left_right)
def select(self, whereclause=None, fold_equivalents=False, **kwargs):
"""Create a :class:`Select` from this :class:`Join`.
diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py
index 1b90a457f..74651a9d1 100644
--- a/lib/sqlalchemy/sql/util.py
+++ b/lib/sqlalchemy/sql/util.py
@@ -131,49 +131,77 @@ def adapt_criterion_to_null(crit, nulls):
return visitors.cloned_traverse(crit, {}, {'binary':visit_binary})
-def join_condition(a, b, ignore_nonexistent_tables=False):
- """create a join condition between two tables.
+def join_condition(a, b, ignore_nonexistent_tables=False, a_subset=None):
+ """create a join condition between two tables or selectables.
- ignore_nonexistent_tables=True allows a join condition to be
- determined between two tables which may contain references to
- other not-yet-defined tables. In general the NoSuchTableError
- raised is only required if the user is trying to join selectables
- across multiple MetaData objects (which is an extremely rare use
- case).
+ e.g.::
+
+ join_condition(tablea, tableb)
+
+ would produce an expression along the lines of::
+
+ tablea.c.id==tableb.c.tablea_id
+
+ The join is determined based on the foreign key relationships
+ between the two selectables. If there are multiple ways
+ to join, or no way to join, an error is raised.
+
+ :param ignore_nonexistent_tables: This flag will cause the
+ function to silently skip over foreign key resolution errors
+ due to nonexistent tables - the assumption is that these
+ tables have not yet been defined within an initialization process
+ and are not significant to the operation.
+
+ :param a_subset: An optional expression that is a sub-component
+ of ``a``. An attempt will be made to join to just this sub-component
+ first before looking at the full ``a`` construct, and if found
+ will be successful even if there are other ways to join to ``a``.
+ This allows the "right side" of a join to be passed thereby
+ providing a "natural join".
"""
crit = []
constraints = set()
- for fk in b.foreign_keys:
- try:
- col = fk.get_referent(a)
- except exc.NoReferencedTableError:
- if ignore_nonexistent_tables:
- continue
- else:
- raise
-
- if col is not None:
- crit.append(col == fk.parent)
- constraints.add(fk.constraint)
- if a is not b:
- for fk in a.foreign_keys:
+
+ for left in (a_subset, a):
+ if left is None:
+ continue
+ for fk in b.foreign_keys:
try:
- col = fk.get_referent(b)
+ col = fk.get_referent(left)
except exc.NoReferencedTableError:
if ignore_nonexistent_tables:
continue
else:
raise
-
+
if col is not None:
crit.append(col == fk.parent)
constraints.add(fk.constraint)
+ if left is not b:
+ for fk in left.foreign_keys:
+ try:
+ col = fk.get_referent(b)
+ except exc.NoReferencedTableError:
+ if ignore_nonexistent_tables:
+ continue
+ else:
+ raise
+ if col is not None:
+ crit.append(col == fk.parent)
+ constraints.add(fk.constraint)
+ if crit:
+ break
+
if len(crit) == 0:
+ if isinstance(b, expression._FromGrouping):
+ hint = " Perhaps you meant to convert the right side to a subquery using alias()?"
+ else:
+ hint = ""
raise exc.ArgumentError(
"Can't find any foreign key relationships "
- "between '%s' and '%s'" % (a.description, b.description))
+ "between '%s' and '%s'.%s" % (a.description, b.description, hint))
elif len(constraints) > 1:
raise exc.ArgumentError(
"Can't determine join between '%s' and '%s'; "
diff --git a/test/sql/test_select.py b/test/sql/test_select.py
index c97685dcb..d6a3804be 100644
--- a/test/sql/test_select.py
+++ b/test/sql/test_select.py
@@ -77,10 +77,14 @@ class SelectTest(TestBase, AssertsCompiledSQL):
assert not hasattr(table1.alias().c.myid, 'c')
def test_table_select(self):
- self.assert_compile(table1.select(), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable")
-
- self.assert_compile(select([table1, table2]), "SELECT mytable.myid, mytable.name, mytable.description, myothertable.otherid, \
-myothertable.othername FROM mytable, myothertable")
+ self.assert_compile(table1.select(),
+ "SELECT mytable.myid, mytable.name, "
+ "mytable.description FROM mytable")
+
+ self.assert_compile(select([table1, table2]),
+ "SELECT mytable.myid, mytable.name, mytable.description, "
+ "myothertable.otherid, myothertable.othername FROM mytable, "
+ "myothertable")
def test_invalid_col_argument(self):
assert_raises(exc.ArgumentError, select, table1)
@@ -97,13 +101,15 @@ myothertable.othername FROM mytable, myothertable")
s.c.myid == 7
)
,
- "SELECT myid, name, description FROM (SELECT mytable.myid AS myid, mytable.name AS name, mytable.description AS description FROM mytable "\
+ "SELECT myid, name, description FROM (SELECT mytable.myid AS myid, "
+ "mytable.name AS name, mytable.description AS description FROM mytable "
"WHERE mytable.name = :name_1) WHERE myid = :myid_1")
sq = select([table1])
self.assert_compile(
sq.select(),
- "SELECT myid, name, description FROM (SELECT mytable.myid AS myid, mytable.name AS name, mytable.description AS description FROM mytable)"
+ "SELECT myid, name, description FROM (SELECT mytable.myid AS myid, "
+ "mytable.name AS name, mytable.description AS description FROM mytable)"
)
sq = select(
@@ -112,8 +118,9 @@ myothertable.othername FROM mytable, myothertable")
self.assert_compile(
sq.select(sq.c.myid == 7),
- "SELECT sq.myid, sq.name, sq.description FROM \
-(SELECT mytable.myid AS myid, mytable.name AS name, mytable.description AS description FROM mytable) AS sq WHERE sq.myid = :myid_1"
+ "SELECT sq.myid, sq.name, sq.description FROM "
+ "(SELECT mytable.myid AS myid, mytable.name AS name, "
+ "mytable.description AS description FROM mytable) AS sq WHERE sq.myid = :myid_1"
)
sq = select(
@@ -1112,8 +1119,9 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") A
def test_joins(self):
self.assert_compile(
join(table2, table1, table1.c.myid == table2.c.otherid).select(),
- "SELECT myothertable.otherid, myothertable.othername, mytable.myid, mytable.name, \
-mytable.description FROM myothertable JOIN mytable ON mytable.myid = myothertable.otherid"
+ "SELECT myothertable.otherid, myothertable.othername, "
+ "mytable.myid, mytable.name, mytable.description FROM "
+ "myothertable JOIN mytable ON mytable.myid = myothertable.otherid"
)
self.assert_compile(
@@ -1121,34 +1129,51 @@ mytable.description FROM myothertable JOIN mytable ON mytable.myid = myothertabl
[table1],
from_obj = [join(table1, table2, table1.c.myid == table2.c.otherid)]
),
- "SELECT mytable.myid, mytable.name, mytable.description FROM mytable JOIN myothertable ON mytable.myid = myothertable.otherid")
+ "SELECT mytable.myid, mytable.name, mytable.description FROM "
+ "mytable JOIN myothertable ON mytable.myid = myothertable.otherid")
self.assert_compile(
select(
- [join(join(table1, table2, table1.c.myid == table2.c.otherid), table3, table1.c.myid == table3.c.userid)]
+ [join(join(table1, table2, table1.c.myid == table2.c.otherid),
+ table3, table1.c.myid == table3.c.userid)]
),
- "SELECT mytable.myid, mytable.name, mytable.description, myothertable.otherid, myothertable.othername, thirdtable.userid, thirdtable.otherstuff FROM mytable JOIN myothertable ON mytable.myid = myothertable.otherid JOIN thirdtable ON mytable.myid = thirdtable.userid"
+ "SELECT mytable.myid, mytable.name, mytable.description, "
+ "myothertable.otherid, myothertable.othername, thirdtable.userid, "
+ "thirdtable.otherstuff FROM mytable JOIN myothertable ON mytable.myid ="
+ " myothertable.otherid JOIN thirdtable ON mytable.myid = thirdtable.userid"
)
self.assert_compile(
join(users, addresses, users.c.user_id==addresses.c.user_id).select(),
- "SELECT users.user_id, users.user_name, users.password, addresses.address_id, addresses.user_id, addresses.street, addresses.city, addresses.state, addresses.zip FROM users JOIN addresses ON users.user_id = addresses.user_id"
+ "SELECT users.user_id, users.user_name, users.password, "
+ "addresses.address_id, addresses.user_id, addresses.street, "
+ "addresses.city, addresses.state, addresses.zip FROM users JOIN addresses "
+ "ON users.user_id = addresses.user_id"
)
self.assert_compile(
select([table1, table2, table3],
- from_obj = [join(table1, table2, table1.c.myid == table2.c.otherid).outerjoin(table3, table1.c.myid==table3.c.userid)]
-
- #from_obj = [outerjoin(join(table, table2, table1.c.myid == table2.c.otherid), table3, table1.c.myid==table3.c.userid)]
+ from_obj = [join(table1, table2, table1.c.myid == table2.c.otherid).
+ outerjoin(table3, table1.c.myid==table3.c.userid)]
)
- ,"SELECT mytable.myid, mytable.name, mytable.description, myothertable.otherid, myothertable.othername, thirdtable.userid, thirdtable.otherstuff FROM mytable JOIN myothertable ON mytable.myid = myothertable.otherid LEFT OUTER JOIN thirdtable ON mytable.myid = thirdtable.userid"
+ ,"SELECT mytable.myid, mytable.name, mytable.description, "
+ "myothertable.otherid, myothertable.othername, thirdtable.userid,"
+ " thirdtable.otherstuff FROM mytable JOIN myothertable ON mytable.myid "
+ "= myothertable.otherid LEFT OUTER JOIN thirdtable ON mytable.myid ="
+ " thirdtable.userid"
)
self.assert_compile(
select([table1, table2, table3],
- from_obj = [outerjoin(table1, join(table2, table3, table2.c.otherid == table3.c.userid), table1.c.myid==table2.c.otherid)]
+ from_obj = [outerjoin(table1,
+ join(table2, table3, table2.c.otherid == table3.c.userid),
+ table1.c.myid==table2.c.otherid)]
)
- ,"SELECT mytable.myid, mytable.name, mytable.description, myothertable.otherid, myothertable.othername, thirdtable.userid, thirdtable.otherstuff FROM mytable LEFT OUTER JOIN (myothertable JOIN thirdtable ON myothertable.otherid = thirdtable.userid) ON mytable.myid = myothertable.otherid"
+ ,"SELECT mytable.myid, mytable.name, mytable.description, "
+ "myothertable.otherid, myothertable.othername, thirdtable.userid,"
+ " thirdtable.otherstuff FROM mytable LEFT OUTER JOIN (myothertable "
+ "JOIN thirdtable ON myothertable.otherid = thirdtable.userid) ON "
+ "mytable.myid = myothertable.otherid"
)
query = select(
@@ -1162,11 +1187,12 @@ mytable.description FROM myothertable JOIN mytable ON mytable.myid = myothertabl
from_obj = [ outerjoin(table1, table2, table1.c.myid == table2.c.otherid) ]
)
self.assert_compile(query,
- "SELECT mytable.myid, mytable.name, mytable.description, myothertable.otherid, myothertable.othername \
-FROM mytable LEFT OUTER JOIN myothertable ON mytable.myid = myothertable.otherid \
-WHERE mytable.name = :name_1 OR mytable.myid = :myid_1 OR \
-myothertable.othername != :othername_1 OR \
-EXISTS (select yay from foo where boo = lar)",
+ "SELECT mytable.myid, mytable.name, mytable.description, "
+ "myothertable.otherid, myothertable.othername "
+ "FROM mytable LEFT OUTER JOIN myothertable ON mytable.myid = "
+ "myothertable.otherid WHERE mytable.name = :name_1 OR "
+ "mytable.myid = :myid_1 OR myothertable.othername != :othername_1 "
+ "OR EXISTS (select yay from foo where boo = lar)",
)
def test_compound_selects(self):
diff --git a/test/sql/test_selectable.py b/test/sql/test_selectable.py
index 78455e6d6..13f629e28 100644
--- a/test/sql/test_selectable.py
+++ b/test/sql/test_selectable.py
@@ -242,6 +242,76 @@ class SelectableTest(TestBase, AssertsExecutionResults):
s = select([t2, t3], use_labels=True)
assert_raises(exc.NoReferencedTableError, s.join, t1)
+
+ def test_join_condition(self):
+ m = MetaData()
+ t1 = Table('t1', m, Column('id', Integer))
+ t2 = Table('t2', m, Column('id', Integer), Column('t1id', ForeignKey('t1.id')))
+ t3 = Table('t3', m, Column('id', Integer),
+ Column('t1id', ForeignKey('t1.id')),
+ Column('t2id', ForeignKey('t2.id')))
+ t4 = Table('t4', m, Column('id', Integer), Column('t2id', ForeignKey('t2.id')))
+
+ t1t2 = t1.join(t2)
+ t2t3 = t2.join(t3)
+
+ for left, right, a_subset, expected in [
+ (t1, t2, None, t1.c.id==t2.c.t1id),
+ (t1t2, t3, t2, t1t2.c.t2_id==t3.c.t2id),
+ (t2t3, t1, t3, t1.c.id==t3.c.t1id),
+ (t2t3, t4, None, t2t3.c.t2_id==t4.c.t2id),
+ (t2t3, t4, t3, t2t3.c.t2_id==t4.c.t2id),
+ (t2t3.join(t1), t4, None, t2t3.c.t2_id==t4.c.t2id),
+ (t2t3.join(t1), t4, t1, t2t3.c.t2_id==t4.c.t2id),
+ (t1t2, t2t3, t2, t1t2.c.t2_id==t2t3.c.t3_t2id),
+ ]:
+ assert expected.compare(
+ sql_util.join_condition(left, right, a_subset=a_subset)
+ )
+
+ # these are ambiguous, or have no joins
+ for left, right, a_subset in [
+ (t1t2, t3, None),
+ (t2t3, t1, None),
+ (t1, t4, None),
+ (t1t2, t2t3, None),
+ ]:
+ assert_raises(
+ exc.ArgumentError,
+ sql_util.join_condition,
+ left, right, a_subset=a_subset
+ )
+
+ als = t2t3.alias()
+ # test join's behavior, including natural
+ for left, right, expected in [
+ (t1, t2, t1.c.id==t2.c.t1id),
+ (t1t2, t3, t1t2.c.t2_id==t3.c.t2id),
+ (t2t3, t1, t1.c.id==t3.c.t1id),
+ (t2t3, t4, t2t3.c.t2_id==t4.c.t2id),
+ (t2t3, t4, t2t3.c.t2_id==t4.c.t2id),
+ (t2t3.join(t1), t4, t2t3.c.t2_id==t4.c.t2id),
+ (t2t3.join(t1), t4, t2t3.c.t2_id==t4.c.t2id),
+ (t1t2, als, t1t2.c.t2_id==als.c.t3_t2id)
+ ]:
+ assert expected.compare(
+ left.join(right).onclause
+ )
+
+ # TODO: this raises due to right side being "grouped",
+ # and no longer has FKs. Did we want to make
+ # _FromGrouping friendlier ?
+ assert_raises_message(
+ exc.ArgumentError,
+ r"Perhaps you meant to convert the right side to a subquery using alias\(\)\?",
+ t1t2.join, t2t3
+ )
+
+ assert_raises_message(
+ exc.ArgumentError,
+ r"Perhaps you meant to convert the right side to a subquery using alias\(\)\?",
+ t1t2.join, t2t3.select(use_labels=True)
+ )
class PrimaryKeyTest(TestBase, AssertsExecutionResults):
def test_join_pk_collapse_implicit(self):
@@ -287,8 +357,12 @@ class PrimaryKeyTest(TestBase, AssertsExecutionResults):
def test_init_doesnt_blowitaway(self):
meta = MetaData()
- a = Table('a', meta, Column('id', Integer, primary_key=True), Column('x', Integer))
- b = Table('b', meta, Column('id', Integer, ForeignKey('a.id'), primary_key=True), Column('x', Integer))
+ a = Table('a', meta,
+ Column('id', Integer, primary_key=True),
+ Column('x', Integer))
+ b = Table('b', meta,
+ Column('id', Integer, ForeignKey('a.id'), primary_key=True),
+ Column('x', Integer))
j = a.join(b)
assert list(j.primary_key) == [a.c.id]
@@ -298,8 +372,12 @@ class PrimaryKeyTest(TestBase, AssertsExecutionResults):
def test_non_column_clause(self):
meta = MetaData()
- a = Table('a', meta, Column('id', Integer, primary_key=True), Column('x', Integer))
- b = Table('b', meta, Column('id', Integer, ForeignKey('a.id'), primary_key=True), Column('x', Integer, primary_key=True))
+ a = Table('a', meta,
+ Column('id', Integer, primary_key=True),
+ Column('x', Integer))
+ b = Table('b', meta,
+ Column('id', Integer, ForeignKey('a.id'), primary_key=True),
+ Column('x', Integer, primary_key=True))
j = a.join(b, and_(a.c.id==b.c.id, b.c.x==5))
assert str(j) == "a JOIN b ON a.id = b.id AND b.x = :x_1", str(j)
@@ -343,7 +421,9 @@ class ReduceTest(TestBase, AssertsExecutionResults):
eq_(
- util.column_set(sql_util.reduce_columns([t1.c.t1id, t1.c.t1data, t2.c.t2id, t2.c.t2data, t3.c.t3id, t3.c.t3data])),
+ util.column_set(sql_util.reduce_columns([
+ t1.c.t1id, t1.c.t1data, t2.c.t2id,
+ t2.c.t2data, t3.c.t3id, t3.c.t3data])),
util.column_set([t1.c.t1id, t1.c.t1data, t2.c.t2data, t3.c.t3data])
)
@@ -386,9 +466,13 @@ class ReduceTest(TestBase, AssertsExecutionResults):
Column('manager_name', String(50))
)
- pjoin = people.outerjoin(engineers).outerjoin(managers).select(use_labels=True).alias('pjoin')
+ pjoin = people.outerjoin(engineers).\
+ outerjoin(managers).select(use_labels=True).\
+ alias('pjoin')
eq_(
- util.column_set(sql_util.reduce_columns([pjoin.c.people_person_id, pjoin.c.engineers_person_id, pjoin.c.managers_person_id])),
+ util.column_set(sql_util.reduce_columns([
+ pjoin.c.people_person_id, pjoin.c.engineers_person_id,
+ pjoin.c.managers_person_id])),
util.column_set([pjoin.c.people_person_id])
)
@@ -412,7 +496,9 @@ class ReduceTest(TestBase, AssertsExecutionResults):
}, None, 'item_join')
eq_(
- util.column_set(sql_util.reduce_columns([item_join.c.id, item_join.c.dummy, item_join.c.child_name])),
+ util.column_set(sql_util.reduce_columns([
+ item_join.c.id, item_join.c.dummy, item_join.c.child_name
+ ])),
util.column_set([item_join.c.id, item_join.c.dummy, item_join.c.child_name])
)
@@ -426,7 +512,8 @@ class ReduceTest(TestBase, AssertsExecutionResults):
Column('page_id', Integer, ForeignKey('page.id'), primary_key=True),
)
classified_page_table = Table('classified_page', metadata,
- Column('magazine_page_id', Integer, ForeignKey('magazine_page.page_id'), primary_key=True),
+ Column('magazine_page_id', Integer,
+ ForeignKey('magazine_page.page_id'), primary_key=True),
)
# this is essentially the union formed by the ORM's polymorphic_union function.
@@ -472,7 +559,8 @@ class ReduceTest(TestBase, AssertsExecutionResults):
).alias('pjoin')
eq_(
- util.column_set(sql_util.reduce_columns([pjoin.c.id, pjoin.c.page_id, pjoin.c.magazine_page_id])),
+ util.column_set(sql_util.reduce_columns([
+ pjoin.c.id, pjoin.c.page_id, pjoin.c.magazine_page_id])),
util.column_set([pjoin.c.id])
)