diff options
Diffstat (limited to 'lib/sqlalchemy/sql.py')
-rw-r--r-- | lib/sqlalchemy/sql.py | 53 |
1 files changed, 50 insertions, 3 deletions
diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index 273af5415..9c8d5db08 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -1549,8 +1549,55 @@ class Join(FromClause): def _group_parenthesized(self): return True - def select(self, whereclauses = None, **params): - return select([self.left, self.right], whereclauses, from_obj=[self], **params) + def _get_folded_equivalents(self, equivs=None): + if equivs is None: + equivs = util.Set() + class LocateEquivs(ClauseVisitor): + def visit_binary(self, binary): + if binary.operator == '=' and binary.left.name == binary.right.name: + equivs.add(binary.right) + equivs.add(binary.left) + self.onclause.accept_visitor(LocateEquivs()) + collist = [] + if isinstance(self.left, Join): + left = self.left._get_folded_equivalents(equivs) + else: + left = list(self.left.columns) + if isinstance(self.right, Join): + right = self.right._get_folded_equivalents(equivs) + else: + right = list(self.right.columns) + used = util.Set() + for c in left + right: + if c in equivs: + if c.name not in used: + collist.append(c) + used.add(c.name) + else: + collist.append(c) + return collist + + def select(self, whereclause = None, fold_equivalents=False, **kwargs): + """Create a ``Select`` from this ``Join``. + + whereclause + the WHERE criterion that will be sent to the ``select()`` function + + fold_equivalents + based on the join criterion of this ``Join``, do not include equivalent + columns in the column list of the resulting select. this will recursively + apply to any joins directly nested by this one as well. + + **kwargs + all other kwargs are sent to the underlying ``select()`` function + + """ + if fold_equivalents: + collist = self._get_folded_equivalents() + else: + collist = [self.left, self.right] + + return select(collist, whereclause, from_obj=[self], **kwargs) def accept_visitor(self, visitor): self.left.accept_visitor(visitor) @@ -1912,7 +1959,7 @@ class Select(_SelectBaseMixin, FromClause): self.offset = offset self.for_update = for_update self.is_compound = False - + # indicates that this select statement should not expand its columns # into the column clause of an enclosing select, and should instead # act like a single scalar column |