diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2005-08-07 00:42:55 +0000 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2005-08-07 00:42:55 +0000 |
commit | 2eaaa50b465197d497e0b437d37c339c01b4f3c8 (patch) | |
tree | d18079cb2d5182db79110b9a36e3f80237d23750 /lib/sqlalchemy/sql.py | |
parent | 66a86fe2e3be52f27a142aafcea2798911f7cc42 (diff) | |
download | sqlalchemy-2eaaa50b465197d497e0b437d37c339c01b4f3c8.tar.gz |
Diffstat (limited to 'lib/sqlalchemy/sql.py')
-rw-r--r-- | lib/sqlalchemy/sql.py | 81 |
1 files changed, 42 insertions, 39 deletions
diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index a8b75b875..00333d4c8 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -103,12 +103,8 @@ def or_(*clauses): def exists(*args, **params): s = select(*args, **params) - return BinaryClause(TextClause("EXISTS"), s, '') + return BinaryClause(TextClause("EXISTS"), s, None) -def in_(*args, **params): - s = select(*args, **params) - return BinaryClause(TextClause("IN"), s, '') - def union(*selects, **params): return _compound_select('UNION', *selects, **params) @@ -121,7 +117,7 @@ def subquery(alias, *args, **params): def bindparam(key, value = None): return BindParamClause(key, value) -def textclause(text): +def text(text): return TextClause(text) def sequence(): @@ -142,6 +138,9 @@ def _compound_select(keyword, *selects, **params): return s +def _is_literal(element): + return not isinstance(element, ClauseElement) and not isinstance(element, schema.SchemaItem) + class ClauseVisitor(schema.SchemaVisitor): """builds upon SchemaVisitor to define the visiting of SQL statement elements in addition to Schema elements.""" @@ -327,14 +326,13 @@ class CompoundClause(ClauseElement): return CompoundClause(self.operator, *clauses) def append(self, clause): - if type(clause) == str: - clause = TextClause(clause) + if _is_literal(clause): + clause = TextClause(str(clause)) elif isinstance(clause, CompoundClause): clause.parens = True - self.clauses.append(clause) self.fromobj += clause._get_from_objects() - + def accept_visitor(self, visitor): for c in self.clauses: c.accept_visitor(visitor) @@ -364,8 +362,6 @@ class BinaryClause(ClauseElement): def __init__(self, left, right, operator): self.left = left self.right = right - if isinstance(right, Select): - right._set_from_objects([]) self.operator = operator self.parens = False @@ -391,7 +387,6 @@ class Selectable(FromClause): c = property(lambda self: self.columns) def accept_visitor(self, visitor): - print repr(self.__class__) raise NotImplementedError() def select(self, whereclauses = None, **params): @@ -414,19 +409,16 @@ class Join(Selectable): def hash_key(self): return "Join(%s, %s, %s, %s)" % (repr(self.left.hash_key()), repr(self.right.hash_key()), repr(self.onclause.hash_key()), repr(self.isouter)) - - def add_join(self, join): - pass - + def select(self, whereclauses = None, **params): return select([self.left, self.right], and_(self.onclause, whereclauses), **params) - + def accept_visitor(self, visitor): self.left.accept_visitor(visitor) self.right.accept_visitor(visitor) self.onclause.accept_visitor(visitor) visitor.visit_join(self) - + def _engine(self): return self.left._engine() or self.right._engine() @@ -434,7 +426,7 @@ class Join(Selectable): m = {} for x in self.onclause._get_from_objects(): m[x.id] = x - result = [self] + [FromClause(from_key = c.id) for c in self.left._get_from_objects() + self.right._get_from_objects()] + result = [self] + [FromClause(from_key = c.id) for c in self.left._get_from_objects() + self.right._get_from_objects()] for x in result: m[x.id] = x result = m.values() @@ -493,7 +485,7 @@ class ColumnSelectable(Selectable): return [self.column.table] def _compare(self, operator, obj): - if not isinstance(obj, ClauseElement) and not isinstance(obj, schema.Column): + if _is_literal(obj): if self.column.table.name is None: obj = BindParamClause(self.name, obj, shortname = self.name) else: @@ -516,12 +508,18 @@ class ColumnSelectable(Selectable): def __gt__(self, other): return self._compare('>', other) - def __ge__(self, other): + def __ge__(self, other): return self._compare('>=', other) - + def like(self, other): return self._compare('LIKE', other) - + + def in_(self, *other): + if _is_literal(other[0]): + return self._compare('IN', CompoundClause(',', other)) + else: + return self._compare('IN', union(*other)) + def startswith(self, other): return self._compare('LIKE', str(other) + "%") @@ -578,6 +576,10 @@ class Select(Selectable): self.whereclause = whereclause self.engine = engine + # indicates if this select statement is a subquery inside of a WHERE clause + # note this is different from a subquery inside the FROM list + self.issubquery = False + self._text = None self._raw_columns = [] self._clauses = [] @@ -598,14 +600,14 @@ class Select(Selectable): self.order_by(*order_by) def append_column(self, column): - if type(column) == str: - column = ColumnClause(column, self) + if _is_literal(column): + column = ColumnClause(str(column), self) self._raw_columns.append(column) for f in column._get_from_objects(): self.froms.setdefault(f.id, f) - + for co in column.columns: if self.use_labels: co._make_proxy(self, name = co.label) @@ -615,18 +617,21 @@ class Select(Selectable): def set_whereclause(self, whereclause): if type(whereclause) == str: self.whereclause = TextClause(whereclause) - - for f in self.whereclause._get_from_objects(): - self.froms.setdefault(f.id, f) class CorrelatedVisitor(ClauseVisitor): def visit_select(s, select): for f in self.froms.keys(): select.clear_from(f) + select.issubquery = True self.whereclause.accept_visitor(CorrelatedVisitor()) + + for f in self.whereclause._get_from_objects(): + self.froms.setdefault(f.id, f) + def clear_from(self, id): self.append_from(FromClause(from_name = None, from_key = id)) + def append_from(self, fromclause): if type(fromclause) == str: fromclause = FromClause(from_name = fromclause) @@ -658,8 +663,6 @@ class Select(Selectable): return engine.compile(self, bindparams) def accept_visitor(self, visitor): -# for c in self._raw_columns: -# c.accept_visitor(visitor) for f in self.froms.values(): f.accept_visitor(visitor) if self.whereclause is not None: @@ -689,11 +692,11 @@ class Select(Selectable): return None - def _set_from_objects(self, obj): - self._from_obj = obj - def _get_from_objects(self): - return getattr(self, '_from_obj', [self]) + if self.issubquery: + return [] + else: + return [self] class UpdateBase(ClauseElement): @@ -709,8 +712,8 @@ class UpdateBase(ClauseElement): for key in parameters.keys(): value = parameters[key] if isinstance(value, Select): - value.append_from(FromClause(from_key=self.table.id)) - elif not isinstance(value, schema.Column) and not isinstance(value, ClauseElement): + value.clear_from(self.table.id) + elif _is_literal(value): try: col = self.table.c[key] parameters[key] = bindparam(col.name, value) @@ -747,7 +750,7 @@ class UpdateBase(ClauseElement): for c in self.table.columns: if d.has_key(c): value = d[c] - if not isinstance(value, schema.Column) and not isinstance(value, ClauseElement): + if _is_literal(value): value = bindparam(c.name, value) values.append((c, value)) return values |