diff options
Diffstat (limited to 'lib/sqlalchemy/sql/expression.py')
-rw-r--r-- | lib/sqlalchemy/sql/expression.py | 32 |
1 files changed, 24 insertions, 8 deletions
diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index f836d7eaf..bc0497bea 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -3500,9 +3500,14 @@ class _Exists(_UnaryExpression): def select(self, whereclause=None, **params): return select([self], whereclause, **params) - def correlate(self, fromclause): + def correlate(self, *fromclause): e = self._clone() - e.element = self.element.correlate(fromclause).self_group() + e.element = self.element.correlate(*fromclause).self_group() + return e + + def correlate_except(self, *fromclause): + e = self._clone() + e.element = self.element.correlate_except(*fromclause).self_group() return e def select_from(self, clause): @@ -4708,7 +4713,8 @@ class Select(_SelectBase): _hints = util.immutabledict() _distinct = False _from_cloned = None - + _correlate = () + _correlate_except = () _memoized_property = _SelectBase._memoized_property def __init__(self, @@ -4750,7 +4756,6 @@ class Select(_SelectBase): for e in util.to_list(distinct) ] - self._correlate = set() if from_obj is not None: self._from_obj = util.OrderedSet( _literal_as_text(f) @@ -4837,10 +4842,13 @@ class Select(_SelectBase): # using a list to maintain ordering froms = [f for f in froms if f not in toremove] - if len(froms) > 1 or self._correlate: + if len(froms) > 1 or self._correlate or self._correlate_except: if self._correlate: froms = [f for f in froms if f not in _cloned_intersection(froms, self._correlate)] + if self._correlate_except: + froms = [f for f in froms if f in _cloned_intersection(froms, + self._correlate_except)] if self._should_correlate and existing_froms: froms = [f for f in froms if f not in _cloned_intersection(froms, existing_froms)] @@ -5198,16 +5206,24 @@ class Select(_SelectBase): """ self._should_correlate = False if fromclauses and fromclauses[0] is None: - self._correlate = set() + self._correlate = () + else: + self._correlate = set(self._correlate).union(fromclauses) + + @_generative + def correlate_except(self, *fromclauses): + self._should_correlate = False + if fromclauses and fromclauses[0] is None: + self._correlate_except = () else: - self._correlate = self._correlate.union(fromclauses) + self._correlate_except = set(self._correlate_except).union(fromclauses) def append_correlation(self, fromclause): """append the given correlation expression to this select() construct.""" self._should_correlate = False - self._correlate = self._correlate.union([fromclause]) + self._correlate = set(self._correlate).union([fromclause]) def append_column(self, column): """append the given column expression to the columns clause of this |