summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/expression.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/sql/expression.py')
-rw-r--r--lib/sqlalchemy/sql/expression.py106
1 files changed, 54 insertions, 52 deletions
diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py
index c7ab34272..b3200a7eb 100644
--- a/lib/sqlalchemy/sql/expression.py
+++ b/lib/sqlalchemy/sql/expression.py
@@ -863,6 +863,16 @@ class ClauseElement(object):
raise NotImplementedError(repr(self))
+ def _aggregate_hide_froms(self, **modifiers):
+ """Return a list of ``FROM`` clause elements which this ``ClauseElement`` replaces, taking into account
+ previous ClauseElements which this ClauseElement is a clone of."""
+
+ s = self
+ while s is not None:
+ for h in s._hide_froms(**modifiers):
+ yield h
+ s = getattr(s, '_is_clone_of', None)
+
def _hide_froms(self, **modifiers):
"""Return a list of ``FROM`` clause elements which this ``ClauseElement`` replaces."""
@@ -2203,11 +2213,10 @@ class Join(FromClause):
else:
equivs[x] = util.Set([y])
- class BinaryVisitor(visitors.ClauseVisitor):
- def visit_binary(self, binary):
- if binary.operator == operators.eq and isinstance(binary.left, schema.Column) and isinstance(binary.right, schema.Column):
- add_equiv(binary.left, binary.right)
- BinaryVisitor().traverse(self.onclause)
+ def visit_binary(binary):
+ if binary.operator == operators.eq and isinstance(binary.left, schema.Column) and isinstance(binary.right, schema.Column):
+ add_equiv(binary.left, binary.right)
+ visitors.traverse(self.onclause, visit_binary=visit_binary)
for col in pkcol:
for fk in col.foreign_keys:
@@ -2719,8 +2728,8 @@ class _SelectBaseMixin(object):
self._offset = offset
self._bind = bind
- self.append_order_by(*util.to_list(order_by, []))
- self.append_group_by(*util.to_list(group_by, []))
+ self._order_by_clause = ClauseList(*util.to_list(order_by, []))
+ self._group_by_clause = ClauseList(*util.to_list(group_by, []))
def as_scalar(self):
"""return a 'scalar' representation of this selectable, which can be used
@@ -2967,30 +2976,41 @@ class Select(_SelectBaseMixin, FromClause):
# usually called via a generative method, create a copy of each collection
# by default
- self._raw_columns = []
self.__correlate = util.Set()
- self._froms = util.OrderedSet()
- self._whereclause = None
self._having = None
self._prefixes = []
- if columns is not None:
- for c in columns:
- self.append_column(c, _copy_collection=False)
-
- if from_obj is not None:
- for f in from_obj:
- self.append_from(f, _copy_collection=False)
+ if columns:
+ self._raw_columns = [
+ isinstance(c, _ScalarSelect) and c.self_group(against=operators.comma_op) or c
+ for c in
+ [_literal_as_column(c) for c in columns]
+ ]
+ else:
+ self._raw_columns = []
+
+ if from_obj:
+ self._froms = util.Set([
+ _is_literal(f) and _TextFromClause(f) or f
+ for f in from_obj
+ ])
+ else:
+ self._froms = util.Set()
- if whereclause is not None:
- self.append_whereclause(whereclause)
+ if whereclause:
+ self._whereclause = _literal_as_text(whereclause)
+ else:
+ self._whereclause = None
- if having is not None:
- self.append_having(having)
+ if having:
+ self._having = _literal_as_text(having)
+ else:
+ self._having = None
- if prefixes is not None:
- for p in prefixes:
- self.append_prefix(p, _copy_collection=False)
+ if prefixes:
+ self._prefixes = [_literal_as_text(p) for p in prefixes]
+ else:
+ self._prefixes = []
_SelectBaseMixin.__init__(self, **kwargs)
@@ -3003,48 +3023,30 @@ class Select(_SelectBaseMixin, FromClause):
correlating.
"""
- froms = util.OrderedSet()
+ froms = util.Set()
hide_froms = util.Set()
for col in self._raw_columns:
- for f in col._hide_froms():
- hide_froms.add(f)
- while hasattr(f, '_is_clone_of'):
- hide_froms.add(f._is_clone_of)
- f = f._is_clone_of
- for f in col._get_from_objects():
- froms.add(f)
+ hide_froms.update(col._aggregate_hide_froms())
+ froms.update(col._get_from_objects())
if self._whereclause is not None:
- for f in self._whereclause._get_from_objects(is_where=True):
- froms.add(f)
+ froms.update(self._whereclause._get_from_objects(is_where=True))
- for elem in self._froms:
- froms.add(elem)
- for f in elem._get_from_objects():
- froms.add(f)
-
- for elem in froms:
- for f in elem._hide_froms():
- hide_froms.add(f)
- while hasattr(f, '_is_clone_of'):
- hide_froms.add(f._is_clone_of)
- f = f._is_clone_of
+ if self._froms:
+ froms.update(self._froms)
+ for elem in self._froms:
+ hide_froms.update(elem._aggregate_hide_froms())
froms = froms.difference(hide_froms)
if len(froms) > 1:
corr = self.__correlate
if self._should_correlate and existing_froms is not None:
- corr = existing_froms.union(corr)
-
- for f in list(corr):
- while hasattr(f, '_is_clone_of'):
- corr.add(f._is_clone_of)
- f = f._is_clone_of
+ corr.update(existing_froms)
f = froms.difference(corr)
- if len(f) == 0:
+ if not f:
raise exceptions.InvalidRequestError("Select statement '%s' is overcorrelated; returned no 'from' clauses" % str(self.__dont_correlate()))
return f
else: