summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--CHANGES6
-rw-r--r--lib/sqlalchemy/sql/functions.py12
-rw-r--r--test/sql/functions.py9
3 files changed, 25 insertions, 2 deletions
diff --git a/CHANGES b/CHANGES
index b10e0a1cd..100626f1a 100644
--- a/CHANGES
+++ b/CHANGES
@@ -50,7 +50,11 @@ CHANGES
should reduce the probability of "Attribute x was
not replaced during compile" warnings. (this generally
applies to SQLA hackers, like Elixir devs).
-
+
+- sql
+ - func.count() with no arguments renders as COUNT(*),
+ equivalent to func.count(text('*')).
+
- ext
- Class-bound attributes sent as arguments to
relation()'s remote_side and foreign_keys parameters
diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py
index 7303bd0c6..7fce3b95b 100644
--- a/lib/sqlalchemy/sql/functions.py
+++ b/lib/sqlalchemy/sql/functions.py
@@ -1,6 +1,6 @@
from sqlalchemy import types as sqltypes
from sqlalchemy.sql.expression import (
- ClauseList, _FigureVisitName, _Function, _literal_as_binds,
+ ClauseList, _FigureVisitName, _Function, _literal_as_binds, text
)
from sqlalchemy.sql import operators
@@ -61,6 +61,16 @@ class random(GenericFunction):
kwargs.setdefault('type_', None)
GenericFunction.__init__(self, args=args, **kwargs)
+class count(GenericFunction):
+ """The ANSI COUNT aggregate function. With no arguments, emits COUNT *."""
+
+ __return_type__ = sqltypes.Integer
+
+ def __init__(self, expression=None, **kwargs):
+ if expression is None:
+ expression = text('*')
+ GenericFunction.__init__(self, args=(expression,), **kwargs)
+
class current_date(AnsiFunction):
__return_type__ = sqltypes.Date
diff --git a/test/sql/functions.py b/test/sql/functions.py
index 6754d6d42..681d6a557 100644
--- a/test/sql/functions.py
+++ b/test/sql/functions.py
@@ -8,6 +8,7 @@ from sqlalchemy.engine import default
from sqlalchemy import types as sqltypes
from testlib import *
from sqlalchemy.sql.functions import GenericFunction
+from testlib.testing import eq_
from sqlalchemy.databases import *
# every dialect in databases.__all__ is expected to pass these tests.
@@ -68,6 +69,14 @@ class CompileTest(TestBase, AssertsCompiledSQL):
]:
self.assert_compile(func.random(), ret, dialect=dialect)
+ def test_generic_count(self):
+ assert isinstance(func.count().type, sqltypes.Integer)
+
+ self.assert_compile(func.count(), 'count(*)')
+ self.assert_compile(func.count(1), 'count(:param_1)')
+ c = column('abc')
+ self.assert_compile(func.count(c), 'count(abc)')
+
def test_constructor(self):
try:
func.current_timestamp('somearg')