diff options
Diffstat (limited to 'lib/sqlalchemy/sql/expression.py')
-rw-r--r-- | lib/sqlalchemy/sql/expression.py | 106 |
1 files changed, 54 insertions, 52 deletions
diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index c7ab34272..b3200a7eb 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -863,6 +863,16 @@ class ClauseElement(object): raise NotImplementedError(repr(self)) + def _aggregate_hide_froms(self, **modifiers): + """Return a list of ``FROM`` clause elements which this ``ClauseElement`` replaces, taking into account + previous ClauseElements which this ClauseElement is a clone of.""" + + s = self + while s is not None: + for h in s._hide_froms(**modifiers): + yield h + s = getattr(s, '_is_clone_of', None) + def _hide_froms(self, **modifiers): """Return a list of ``FROM`` clause elements which this ``ClauseElement`` replaces.""" @@ -2203,11 +2213,10 @@ class Join(FromClause): else: equivs[x] = util.Set([y]) - class BinaryVisitor(visitors.ClauseVisitor): - def visit_binary(self, binary): - if binary.operator == operators.eq and isinstance(binary.left, schema.Column) and isinstance(binary.right, schema.Column): - add_equiv(binary.left, binary.right) - BinaryVisitor().traverse(self.onclause) + def visit_binary(binary): + if binary.operator == operators.eq and isinstance(binary.left, schema.Column) and isinstance(binary.right, schema.Column): + add_equiv(binary.left, binary.right) + visitors.traverse(self.onclause, visit_binary=visit_binary) for col in pkcol: for fk in col.foreign_keys: @@ -2719,8 +2728,8 @@ class _SelectBaseMixin(object): self._offset = offset self._bind = bind - self.append_order_by(*util.to_list(order_by, [])) - self.append_group_by(*util.to_list(group_by, [])) + self._order_by_clause = ClauseList(*util.to_list(order_by, [])) + self._group_by_clause = ClauseList(*util.to_list(group_by, [])) def as_scalar(self): """return a 'scalar' representation of this selectable, which can be used @@ -2967,30 +2976,41 @@ class Select(_SelectBaseMixin, FromClause): # usually called via a generative method, create a copy of each collection # by default - self._raw_columns = [] self.__correlate = util.Set() - self._froms = util.OrderedSet() - self._whereclause = None self._having = None self._prefixes = [] - if columns is not None: - for c in columns: - self.append_column(c, _copy_collection=False) - - if from_obj is not None: - for f in from_obj: - self.append_from(f, _copy_collection=False) + if columns: + self._raw_columns = [ + isinstance(c, _ScalarSelect) and c.self_group(against=operators.comma_op) or c + for c in + [_literal_as_column(c) for c in columns] + ] + else: + self._raw_columns = [] + + if from_obj: + self._froms = util.Set([ + _is_literal(f) and _TextFromClause(f) or f + for f in from_obj + ]) + else: + self._froms = util.Set() - if whereclause is not None: - self.append_whereclause(whereclause) + if whereclause: + self._whereclause = _literal_as_text(whereclause) + else: + self._whereclause = None - if having is not None: - self.append_having(having) + if having: + self._having = _literal_as_text(having) + else: + self._having = None - if prefixes is not None: - for p in prefixes: - self.append_prefix(p, _copy_collection=False) + if prefixes: + self._prefixes = [_literal_as_text(p) for p in prefixes] + else: + self._prefixes = [] _SelectBaseMixin.__init__(self, **kwargs) @@ -3003,48 +3023,30 @@ class Select(_SelectBaseMixin, FromClause): correlating. """ - froms = util.OrderedSet() + froms = util.Set() hide_froms = util.Set() for col in self._raw_columns: - for f in col._hide_froms(): - hide_froms.add(f) - while hasattr(f, '_is_clone_of'): - hide_froms.add(f._is_clone_of) - f = f._is_clone_of - for f in col._get_from_objects(): - froms.add(f) + hide_froms.update(col._aggregate_hide_froms()) + froms.update(col._get_from_objects()) if self._whereclause is not None: - for f in self._whereclause._get_from_objects(is_where=True): - froms.add(f) + froms.update(self._whereclause._get_from_objects(is_where=True)) - for elem in self._froms: - froms.add(elem) - for f in elem._get_from_objects(): - froms.add(f) - - for elem in froms: - for f in elem._hide_froms(): - hide_froms.add(f) - while hasattr(f, '_is_clone_of'): - hide_froms.add(f._is_clone_of) - f = f._is_clone_of + if self._froms: + froms.update(self._froms) + for elem in self._froms: + hide_froms.update(elem._aggregate_hide_froms()) froms = froms.difference(hide_froms) if len(froms) > 1: corr = self.__correlate if self._should_correlate and existing_froms is not None: - corr = existing_froms.union(corr) - - for f in list(corr): - while hasattr(f, '_is_clone_of'): - corr.add(f._is_clone_of) - f = f._is_clone_of + corr.update(existing_froms) f = froms.difference(corr) - if len(f) == 0: + if not f: raise exceptions.InvalidRequestError("Select statement '%s' is overcorrelated; returned no 'from' clauses" % str(self.__dont_correlate())) return f else: |