summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2008-09-08 03:51:47 +0000
committerMike Bayer <mike_mp@zzzcomputing.com>2008-09-08 03:51:47 +0000
commit58c5bb7fc104da26cd1797d9680a810a3b79ab0a (patch)
treedb54cd7731a7a670b616136ff034b6a0f0b5d1b6
parentcc0dcca7b4bf3ad05630c92012be715c6e515aaf (diff)
downloadsqlalchemy-58c5bb7fc104da26cd1797d9680a810a3b79ab0a.tar.gz
- Added func.min(), func.max(), func.sum() as "generic functions",
which basically allows for their return type to be determined automatically. Helps with dates on SQLite, decimal types, others. [ticket:1160] - added decimal.Decimal as an "auto-detect" type; bind parameters and generic functions will set their type to Numeric when a Decimal is used.
-rw-r--r--CHANGES9
-rw-r--r--lib/sqlalchemy/sql/functions.py17
-rw-r--r--lib/sqlalchemy/types.py1
-rw-r--r--test/sql/functions.py19
4 files changed, 39 insertions, 7 deletions
diff --git a/CHANGES b/CHANGES
index 31dd9e711..fc7bc65da 100644
--- a/CHANGES
+++ b/CHANGES
@@ -149,6 +149,15 @@ CHANGES
[ticket:1068]. This feature is on hold pending further
development.
+ - Added func.min(), func.max(), func.sum() as "generic functions",
+ which basically allows for their return type to be determined
+ automatically. Helps with dates on SQLite, decimal types,
+ others. [ticket:1160]
+
+ - added decimal.Decimal as an "auto-detect" type; bind parameters
+ and generic functions will set their type to Numeric when a
+ Decimal is used.
+
- mysql
- The 'length' argument to MSInteger, MSBigInteger, MSTinyInteger,
MSSmallInteger and MSYear has been renamed to 'display_width'.
diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py
index 7fce3b95b..c7a0f142d 100644
--- a/lib/sqlalchemy/sql/functions.py
+++ b/lib/sqlalchemy/sql/functions.py
@@ -36,12 +36,25 @@ class AnsiFunction(GenericFunction):
def __init__(self, **kwargs):
GenericFunction.__init__(self, **kwargs)
-
-class coalesce(GenericFunction):
+class ReturnTypeFromArgs(GenericFunction):
+ """Define a function whose return type is the same as its arguments."""
+
def __init__(self, *args, **kwargs):
kwargs.setdefault('type_', _type_from_args(args))
GenericFunction.__init__(self, args=args, **kwargs)
+class coalesce(ReturnTypeFromArgs):
+ pass
+
+class max(ReturnTypeFromArgs):
+ pass
+
+class min(ReturnTypeFromArgs):
+ pass
+
+class sum(ReturnTypeFromArgs):
+ pass
+
class now(GenericFunction):
__return_type__ = sqltypes.DateTime
diff --git a/lib/sqlalchemy/types.py b/lib/sqlalchemy/types.py
index 3690ed3ca..4958e4812 100644
--- a/lib/sqlalchemy/types.py
+++ b/lib/sqlalchemy/types.py
@@ -625,6 +625,7 @@ type_map = {
unicode : NCHAR,
int : Integer,
float : Numeric,
+ _python_Decimal : Numeric,
dt.date : Date,
dt.datetime : DateTime,
dt.time : Time,
diff --git a/test/sql/functions.py b/test/sql/functions.py
index 27e87eceb..ac9b7e329 100644
--- a/test/sql/functions.py
+++ b/test/sql/functions.py
@@ -10,6 +10,7 @@ from sqlalchemy import types as sqltypes
from testlib import *
from sqlalchemy.sql.functions import GenericFunction
from testlib.testing import eq_
+from decimal import Decimal as _python_Decimal
from sqlalchemy.databases import *
@@ -90,13 +91,21 @@ class CompileTest(TestBase, AssertsCompiledSQL):
except TypeError:
assert True
- def test_typing(self):
- assert isinstance(func.coalesce(datetime.date(2007, 10, 5), datetime.date(2005, 10, 15)).type, sqltypes.Date)
-
- assert isinstance(func.coalesce(None, datetime.date(2005, 10, 15)).type, sqltypes.Date)
-
+ def test_return_type_detection(self):
+
+ for fn in [func.coalesce, func.max, func.min, func.sum]:
+ for args, type_ in [
+ ((datetime.date(2007, 10, 5), datetime.date(2005, 10, 15)), sqltypes.Date),
+ ((3, 5), sqltypes.Integer),
+ ((_python_Decimal(3), _python_Decimal(5)), sqltypes.Numeric),
+ (("foo", "bar"), sqltypes.String),
+ ((datetime.datetime(2007, 10, 5, 8, 3, 34), datetime.datetime(2005, 10, 15, 14, 45, 33)), sqltypes.DateTime)
+ ]:
+ assert isinstance(fn(*args).type, type_), "%s / %s" % (fn(), type_)
+
assert isinstance(func.concat("foo", "bar").type, sqltypes.String)
+
def test_assorted(self):
table1 = table('mytable',
column('myid', Integer),