summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/sql.py')
-rw-r--r--lib/sqlalchemy/sql.py81
1 files changed, 42 insertions, 39 deletions
diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py
index a8b75b875..00333d4c8 100644
--- a/lib/sqlalchemy/sql.py
+++ b/lib/sqlalchemy/sql.py
@@ -103,12 +103,8 @@ def or_(*clauses):
def exists(*args, **params):
s = select(*args, **params)
- return BinaryClause(TextClause("EXISTS"), s, '')
+ return BinaryClause(TextClause("EXISTS"), s, None)
-def in_(*args, **params):
- s = select(*args, **params)
- return BinaryClause(TextClause("IN"), s, '')
-
def union(*selects, **params):
return _compound_select('UNION', *selects, **params)
@@ -121,7 +117,7 @@ def subquery(alias, *args, **params):
def bindparam(key, value = None):
return BindParamClause(key, value)
-def textclause(text):
+def text(text):
return TextClause(text)
def sequence():
@@ -142,6 +138,9 @@ def _compound_select(keyword, *selects, **params):
return s
+def _is_literal(element):
+ return not isinstance(element, ClauseElement) and not isinstance(element, schema.SchemaItem)
+
class ClauseVisitor(schema.SchemaVisitor):
"""builds upon SchemaVisitor to define the visiting of SQL statement elements in
addition to Schema elements."""
@@ -327,14 +326,13 @@ class CompoundClause(ClauseElement):
return CompoundClause(self.operator, *clauses)
def append(self, clause):
- if type(clause) == str:
- clause = TextClause(clause)
+ if _is_literal(clause):
+ clause = TextClause(str(clause))
elif isinstance(clause, CompoundClause):
clause.parens = True
-
self.clauses.append(clause)
self.fromobj += clause._get_from_objects()
-
+
def accept_visitor(self, visitor):
for c in self.clauses:
c.accept_visitor(visitor)
@@ -364,8 +362,6 @@ class BinaryClause(ClauseElement):
def __init__(self, left, right, operator):
self.left = left
self.right = right
- if isinstance(right, Select):
- right._set_from_objects([])
self.operator = operator
self.parens = False
@@ -391,7 +387,6 @@ class Selectable(FromClause):
c = property(lambda self: self.columns)
def accept_visitor(self, visitor):
- print repr(self.__class__)
raise NotImplementedError()
def select(self, whereclauses = None, **params):
@@ -414,19 +409,16 @@ class Join(Selectable):
def hash_key(self):
return "Join(%s, %s, %s, %s)" % (repr(self.left.hash_key()), repr(self.right.hash_key()), repr(self.onclause.hash_key()), repr(self.isouter))
-
- def add_join(self, join):
- pass
-
+
def select(self, whereclauses = None, **params):
return select([self.left, self.right], and_(self.onclause, whereclauses), **params)
-
+
def accept_visitor(self, visitor):
self.left.accept_visitor(visitor)
self.right.accept_visitor(visitor)
self.onclause.accept_visitor(visitor)
visitor.visit_join(self)
-
+
def _engine(self):
return self.left._engine() or self.right._engine()
@@ -434,7 +426,7 @@ class Join(Selectable):
m = {}
for x in self.onclause._get_from_objects():
m[x.id] = x
- result = [self] + [FromClause(from_key = c.id) for c in self.left._get_from_objects() + self.right._get_from_objects()]
+ result = [self] + [FromClause(from_key = c.id) for c in self.left._get_from_objects() + self.right._get_from_objects()]
for x in result:
m[x.id] = x
result = m.values()
@@ -493,7 +485,7 @@ class ColumnSelectable(Selectable):
return [self.column.table]
def _compare(self, operator, obj):
- if not isinstance(obj, ClauseElement) and not isinstance(obj, schema.Column):
+ if _is_literal(obj):
if self.column.table.name is None:
obj = BindParamClause(self.name, obj, shortname = self.name)
else:
@@ -516,12 +508,18 @@ class ColumnSelectable(Selectable):
def __gt__(self, other):
return self._compare('>', other)
- def __ge__(self, other):
+ def __ge__(self, other):
return self._compare('>=', other)
-
+
def like(self, other):
return self._compare('LIKE', other)
-
+
+ def in_(self, *other):
+ if _is_literal(other[0]):
+ return self._compare('IN', CompoundClause(',', other))
+ else:
+ return self._compare('IN', union(*other))
+
def startswith(self, other):
return self._compare('LIKE', str(other) + "%")
@@ -578,6 +576,10 @@ class Select(Selectable):
self.whereclause = whereclause
self.engine = engine
+ # indicates if this select statement is a subquery inside of a WHERE clause
+ # note this is different from a subquery inside the FROM list
+ self.issubquery = False
+
self._text = None
self._raw_columns = []
self._clauses = []
@@ -598,14 +600,14 @@ class Select(Selectable):
self.order_by(*order_by)
def append_column(self, column):
- if type(column) == str:
- column = ColumnClause(column, self)
+ if _is_literal(column):
+ column = ColumnClause(str(column), self)
self._raw_columns.append(column)
for f in column._get_from_objects():
self.froms.setdefault(f.id, f)
-
+
for co in column.columns:
if self.use_labels:
co._make_proxy(self, name = co.label)
@@ -615,18 +617,21 @@ class Select(Selectable):
def set_whereclause(self, whereclause):
if type(whereclause) == str:
self.whereclause = TextClause(whereclause)
-
- for f in self.whereclause._get_from_objects():
- self.froms.setdefault(f.id, f)
class CorrelatedVisitor(ClauseVisitor):
def visit_select(s, select):
for f in self.froms.keys():
select.clear_from(f)
+ select.issubquery = True
self.whereclause.accept_visitor(CorrelatedVisitor())
+
+ for f in self.whereclause._get_from_objects():
+ self.froms.setdefault(f.id, f)
+
def clear_from(self, id):
self.append_from(FromClause(from_name = None, from_key = id))
+
def append_from(self, fromclause):
if type(fromclause) == str:
fromclause = FromClause(from_name = fromclause)
@@ -658,8 +663,6 @@ class Select(Selectable):
return engine.compile(self, bindparams)
def accept_visitor(self, visitor):
-# for c in self._raw_columns:
-# c.accept_visitor(visitor)
for f in self.froms.values():
f.accept_visitor(visitor)
if self.whereclause is not None:
@@ -689,11 +692,11 @@ class Select(Selectable):
return None
- def _set_from_objects(self, obj):
- self._from_obj = obj
-
def _get_from_objects(self):
- return getattr(self, '_from_obj', [self])
+ if self.issubquery:
+ return []
+ else:
+ return [self]
class UpdateBase(ClauseElement):
@@ -709,8 +712,8 @@ class UpdateBase(ClauseElement):
for key in parameters.keys():
value = parameters[key]
if isinstance(value, Select):
- value.append_from(FromClause(from_key=self.table.id))
- elif not isinstance(value, schema.Column) and not isinstance(value, ClauseElement):
+ value.clear_from(self.table.id)
+ elif _is_literal(value):
try:
col = self.table.c[key]
parameters[key] = bindparam(col.name, value)
@@ -747,7 +750,7 @@ class UpdateBase(ClauseElement):
for c in self.table.columns:
if d.has_key(c):
value = d[c]
- if not isinstance(value, schema.Column) and not isinstance(value, ClauseElement):
+ if _is_literal(value):
value = bindparam(c.name, value)
values.append((c, value))
return values