summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/sql.py')
-rw-r--r--lib/sqlalchemy/sql.py67
1 files changed, 49 insertions, 18 deletions
diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py
index 5edfe6f1f..fc0346b85 100644
--- a/lib/sqlalchemy/sql.py
+++ b/lib/sqlalchemy/sql.py
@@ -1,4 +1,3 @@
-# sql.py
# Copyright (C) 2005,2006 Michael Bayer mike_mp@zzzcomputing.com
#
# This module is part of SQLAlchemy and is released under
@@ -13,7 +12,7 @@ from exceptions import *
import string, re, random
types = __import__('types')
-__all__ = ['text', 'table', 'column', 'func', 'select', 'update', 'insert', 'delete', 'join', 'and_', 'or_', 'not_', 'between_', 'cast', 'union', 'union_all', 'null', 'desc', 'asc', 'outerjoin', 'alias', 'subquery', 'literal', 'bindparam', 'exists']
+__all__ = ['text', 'table', 'column', 'func', 'select', 'update', 'insert', 'delete', 'join', 'and_', 'or_', 'not_', 'between_', 'case', 'cast', 'union', 'union_all', 'null', 'desc', 'asc', 'outerjoin', 'alias', 'subquery', 'literal', 'bindparam', 'exists']
def desc(column):
"""returns a descending ORDER BY clause element, e.g.:
@@ -132,6 +131,17 @@ def between_(ctest, cleft, cright):
""" returns BETWEEN predicate clause (clausetest BETWEEN clauseleft AND clauseright) """
return BooleanExpression(ctest, and_(cleft, cright), 'BETWEEN')
between = between_
+
+def case(whens, value=None, else_=None):
+ """ SQL CASE statement -- whens are a sequence of pairs to be translated into "when / then" clauses;
+ optional [value] for simple case statements, and [else_] for case defaults """
+ whenlist = [CompoundClause(None, 'WHEN', c, 'THEN', r) for (c,r) in whens]
+ if else_:
+ whenlist.append(CompoundClause(None, 'ELSE', else_))
+ cc = CalculatedClause(None, 'CASE', value, *whenlist + ['END'])
+ for c in cc.clauses:
+ c.parens = False
+ return cc
def cast(clause, totype, **kwargs):
""" returns CAST function CAST(clause AS totype)
@@ -295,6 +305,7 @@ class ClauseVisitor(object):
def visit_join(self, join):pass
def visit_null(self, null):pass
def visit_clauselist(self, list):pass
+ def visit_calculatedclause(self, calcclause):pass
def visit_function(self, func):pass
def visit_label(self, label):pass
def visit_typeclause(self, typeclause):pass
@@ -831,9 +842,42 @@ class CompoundClause(ClauseList):
return self.operator == other.operator
else:
return False
+
+class CalculatedClause(ClauseList, ColumnElement):
+ """ describes a calculated SQL expression that has a type, like CASE. extends ColumnElement to
+ provide column-level comparison operators. """
+ def __init__(self, name, *clauses, **kwargs):
+ self.name = name
+ self.type = sqltypes.to_instance(kwargs.get('type', None))
+ self._engine = kwargs.get('engine', None)
+ ClauseList.__init__(self, *clauses)
+ key = property(lambda self:self.name or "_calc_")
+ def _process_from_dict(self, data, asfrom):
+ super(CalculatedClause, self)._process_from_dict(data, asfrom)
+ # this helps a Select object get the engine from us
+ data.setdefault(self, self)
+ def copy_container(self):
+ clauses = [clause.copy_container() for clause in self.clauses]
+ return CalculatedClause(type=self.type, engine=self._engine, *clauses)
+ def accept_visitor(self, visitor):
+ for c in self.clauses:
+ c.accept_visitor(visitor)
+ visitor.visit_calculatedclause(self)
+ def _bind_param(self, obj):
+ return BindParamClause(self.name, obj, type=self.type)
+ def select(self):
+ return select([self])
+ def scalar(self):
+ return select([self]).scalar()
+ def execute(self):
+ return select([self]).execute()
+ def _compare_type(self, obj):
+ return self.type
+
-class Function(ClauseList, ColumnElement):
- """describes a SQL function. extends ClauseList to provide comparison operators."""
+class Function(CalculatedClause):
+ """describes a SQL function. extends CalculatedClause turn the "clauselist" into function
+ arguments, also adds a "packagenames" argument"""
def __init__(self, name, *clauses, **kwargs):
self.name = name
self.type = sqltypes.to_instance(kwargs.get('type', None))
@@ -848,10 +892,6 @@ class Function(ClauseList, ColumnElement):
else:
clause = BindParamClause(self.name, clause, shortname=self.name, type=None)
self.clauses.append(clause)
- def _process_from_dict(self, data, asfrom):
- super(Function, self)._process_from_dict(data, asfrom)
- # this helps a Select object get the engine from us
- data.setdefault(self, self)
def copy_container(self):
clauses = [clause.copy_container() for clause in self.clauses]
return Function(self.name, type=self.type, packagenames=self.packagenames, engine=self._engine, *clauses)
@@ -859,16 +899,7 @@ class Function(ClauseList, ColumnElement):
for c in self.clauses:
c.accept_visitor(visitor)
visitor.visit_function(self)
- def _bind_param(self, obj):
- return BindParamClause(self.name, obj, shortname=self.name, type=self.type)
- def select(self):
- return select([self])
- def scalar(self):
- return select([self]).scalar()
- def execute(self):
- return select([self]).execute()
- def _compare_type(self, obj):
- return self.type
+
class FunctionGenerator(object):
"""generates Function objects based on getattr calls"""