summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2006-05-15 23:47:07 +0000
committerMike Bayer <mike_mp@zzzcomputing.com>2006-05-15 23:47:07 +0000
commit79555fb434660bc4b317b5d384120deaa2dc9b60 (patch)
tree74ea4df7ea4230d914ba98d9ed3adb7f6e9a6f2a
parent31e001580b5114d5410ba4f9ee0c42f1448a0efd (diff)
downloadsqlalchemy-79555fb434660bc4b317b5d384120deaa2dc9b60.tar.gz
rick morrison's CASE statement + unit test
-rw-r--r--lib/sqlalchemy/ansisql.py8
-rw-r--r--lib/sqlalchemy/sql.py67
-rw-r--r--test/alltests.py1
-rw-r--r--test/case_statement.py59
4 files changed, 116 insertions, 19 deletions
diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py
index 602da58ee..df3f8fa59 100644
--- a/lib/sqlalchemy/ansisql.py
+++ b/lib/sqlalchemy/ansisql.py
@@ -224,7 +224,13 @@ class ANSICompiler(sql.Compiled):
def apply_function_parens(self, func):
return func.name.upper() not in ANSI_FUNCS or len(func.clauses) > 0
-
+
+ def visit_calculatedclause(self, list):
+ if list.parens:
+ self.strings[list] = "(" + string.join([self.get_str(c) for c in list.clauses], ' ') + ")"
+ else:
+ self.strings[list] = string.join([self.get_str(c) for c in list.clauses], ' ')
+
def visit_function(self, func):
if len(self.select_stack):
self.typemap.setdefault(func.name, func.type)
diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py
index 5edfe6f1f..fc0346b85 100644
--- a/lib/sqlalchemy/sql.py
+++ b/lib/sqlalchemy/sql.py
@@ -1,4 +1,3 @@
-# sql.py
# Copyright (C) 2005,2006 Michael Bayer mike_mp@zzzcomputing.com
#
# This module is part of SQLAlchemy and is released under
@@ -13,7 +12,7 @@ from exceptions import *
import string, re, random
types = __import__('types')
-__all__ = ['text', 'table', 'column', 'func', 'select', 'update', 'insert', 'delete', 'join', 'and_', 'or_', 'not_', 'between_', 'cast', 'union', 'union_all', 'null', 'desc', 'asc', 'outerjoin', 'alias', 'subquery', 'literal', 'bindparam', 'exists']
+__all__ = ['text', 'table', 'column', 'func', 'select', 'update', 'insert', 'delete', 'join', 'and_', 'or_', 'not_', 'between_', 'case', 'cast', 'union', 'union_all', 'null', 'desc', 'asc', 'outerjoin', 'alias', 'subquery', 'literal', 'bindparam', 'exists']
def desc(column):
"""returns a descending ORDER BY clause element, e.g.:
@@ -132,6 +131,17 @@ def between_(ctest, cleft, cright):
""" returns BETWEEN predicate clause (clausetest BETWEEN clauseleft AND clauseright) """
return BooleanExpression(ctest, and_(cleft, cright), 'BETWEEN')
between = between_
+
+def case(whens, value=None, else_=None):
+ """ SQL CASE statement -- whens are a sequence of pairs to be translated into "when / then" clauses;
+ optional [value] for simple case statements, and [else_] for case defaults """
+ whenlist = [CompoundClause(None, 'WHEN', c, 'THEN', r) for (c,r) in whens]
+ if else_:
+ whenlist.append(CompoundClause(None, 'ELSE', else_))
+ cc = CalculatedClause(None, 'CASE', value, *whenlist + ['END'])
+ for c in cc.clauses:
+ c.parens = False
+ return cc
def cast(clause, totype, **kwargs):
""" returns CAST function CAST(clause AS totype)
@@ -295,6 +305,7 @@ class ClauseVisitor(object):
def visit_join(self, join):pass
def visit_null(self, null):pass
def visit_clauselist(self, list):pass
+ def visit_calculatedclause(self, calcclause):pass
def visit_function(self, func):pass
def visit_label(self, label):pass
def visit_typeclause(self, typeclause):pass
@@ -831,9 +842,42 @@ class CompoundClause(ClauseList):
return self.operator == other.operator
else:
return False
+
+class CalculatedClause(ClauseList, ColumnElement):
+ """ describes a calculated SQL expression that has a type, like CASE. extends ColumnElement to
+ provide column-level comparison operators. """
+ def __init__(self, name, *clauses, **kwargs):
+ self.name = name
+ self.type = sqltypes.to_instance(kwargs.get('type', None))
+ self._engine = kwargs.get('engine', None)
+ ClauseList.__init__(self, *clauses)
+ key = property(lambda self:self.name or "_calc_")
+ def _process_from_dict(self, data, asfrom):
+ super(CalculatedClause, self)._process_from_dict(data, asfrom)
+ # this helps a Select object get the engine from us
+ data.setdefault(self, self)
+ def copy_container(self):
+ clauses = [clause.copy_container() for clause in self.clauses]
+ return CalculatedClause(type=self.type, engine=self._engine, *clauses)
+ def accept_visitor(self, visitor):
+ for c in self.clauses:
+ c.accept_visitor(visitor)
+ visitor.visit_calculatedclause(self)
+ def _bind_param(self, obj):
+ return BindParamClause(self.name, obj, type=self.type)
+ def select(self):
+ return select([self])
+ def scalar(self):
+ return select([self]).scalar()
+ def execute(self):
+ return select([self]).execute()
+ def _compare_type(self, obj):
+ return self.type
+
-class Function(ClauseList, ColumnElement):
- """describes a SQL function. extends ClauseList to provide comparison operators."""
+class Function(CalculatedClause):
+ """describes a SQL function. extends CalculatedClause turn the "clauselist" into function
+ arguments, also adds a "packagenames" argument"""
def __init__(self, name, *clauses, **kwargs):
self.name = name
self.type = sqltypes.to_instance(kwargs.get('type', None))
@@ -848,10 +892,6 @@ class Function(ClauseList, ColumnElement):
else:
clause = BindParamClause(self.name, clause, shortname=self.name, type=None)
self.clauses.append(clause)
- def _process_from_dict(self, data, asfrom):
- super(Function, self)._process_from_dict(data, asfrom)
- # this helps a Select object get the engine from us
- data.setdefault(self, self)
def copy_container(self):
clauses = [clause.copy_container() for clause in self.clauses]
return Function(self.name, type=self.type, packagenames=self.packagenames, engine=self._engine, *clauses)
@@ -859,16 +899,7 @@ class Function(ClauseList, ColumnElement):
for c in self.clauses:
c.accept_visitor(visitor)
visitor.visit_function(self)
- def _bind_param(self, obj):
- return BindParamClause(self.name, obj, shortname=self.name, type=self.type)
- def select(self):
- return select([self])
- def scalar(self):
- return select([self]).scalar()
- def execute(self):
- return select([self]).execute()
- def _compare_type(self, obj):
- return self.type
+
class FunctionGenerator(object):
"""generates Function objects based on getattr calls"""
diff --git a/test/alltests.py b/test/alltests.py
index 4e9c73c2c..3595edd7e 100644
--- a/test/alltests.py
+++ b/test/alltests.py
@@ -24,6 +24,7 @@ def suite():
# SQL syntax
'select',
'selectable',
+ 'case_statement',
# assorted round-trip tests
'query',
diff --git a/test/case_statement.py b/test/case_statement.py
new file mode 100644
index 000000000..fc0e919a5
--- /dev/null
+++ b/test/case_statement.py
@@ -0,0 +1,59 @@
+import sys
+import testbase
+from sqlalchemy import *
+
+
+class CaseTest(testbase.PersistTest):
+
+ def setUpAll(self):
+ global info_table
+ info_table = Table('infos', testbase.db,
+ Column('pk', Integer, primary_key=True),
+ Column('info', String))
+
+ info_table.create()
+
+ info_table.insert().execute(
+ {'pk':1, 'info':'pk_1_data'},
+ {'pk':2, 'info':'pk_2_data'},
+ {'pk':3, 'info':'pk_3_data'},
+ {'pk':4, 'info':'pk_4_data'},
+ {'pk':5, 'info':'pk_5_data'})
+ def tearDownAll(self):
+ info_table.drop()
+
+ def testcase(self):
+ inner = select([case([[info_table.c.pk < 3, literal('lessthan3', type=String)],
+ [info_table.c.pk >= 3, literal('gt3', type=String)]]).label('x'),
+ info_table.c.pk, info_table.c.info], from_obj=[info_table]).alias('q_inner')
+
+ inner_result = inner.execute().fetchall()
+
+ # Outputs:
+ # lessthan3 1 pk_1_data
+ # lessthan3 2 pk_2_data
+ # gt3 3 pk_3_data
+ # gt3 4 pk_4_data
+ # gt3 5 pk_5_data
+ assert inner_result == [
+ ('lessthan3', 1, 'pk_1_data'),
+ ('lessthan3', 2, 'pk_2_data'),
+ ('gt3', 3, 'pk_3_data'),
+ ('gt3', 4, 'pk_4_data'),
+ ('gt3', 5, 'pk_5_data'),
+ ]
+
+ outer = select([inner])
+
+ outer_result = outer.execute().fetchall()
+
+ assert outer_result == [
+ ('lessthan3', 1, 'pk_1_data'),
+ ('lessthan3', 2, 'pk_2_data'),
+ ('gt3', 3, 'pk_3_data'),
+ ('gt3', 4, 'pk_4_data'),
+ ('gt3', 5, 'pk_5_data'),
+ ]
+
+if __name__ == "__main__":
+ testbase.main()