diff options
-rw-r--r-- | CHANGES | 6 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/functions.py | 12 | ||||
-rw-r--r-- | test/sql/functions.py | 9 |
3 files changed, 25 insertions, 2 deletions
@@ -50,7 +50,11 @@ CHANGES should reduce the probability of "Attribute x was not replaced during compile" warnings. (this generally applies to SQLA hackers, like Elixir devs). - + +- sql + - func.count() with no arguments renders as COUNT(*), + equivalent to func.count(text('*')). + - ext - Class-bound attributes sent as arguments to relation()'s remote_side and foreign_keys parameters diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py index 7303bd0c6..7fce3b95b 100644 --- a/lib/sqlalchemy/sql/functions.py +++ b/lib/sqlalchemy/sql/functions.py @@ -1,6 +1,6 @@ from sqlalchemy import types as sqltypes from sqlalchemy.sql.expression import ( - ClauseList, _FigureVisitName, _Function, _literal_as_binds, + ClauseList, _FigureVisitName, _Function, _literal_as_binds, text ) from sqlalchemy.sql import operators @@ -61,6 +61,16 @@ class random(GenericFunction): kwargs.setdefault('type_', None) GenericFunction.__init__(self, args=args, **kwargs) +class count(GenericFunction): + """The ANSI COUNT aggregate function. With no arguments, emits COUNT *.""" + + __return_type__ = sqltypes.Integer + + def __init__(self, expression=None, **kwargs): + if expression is None: + expression = text('*') + GenericFunction.__init__(self, args=(expression,), **kwargs) + class current_date(AnsiFunction): __return_type__ = sqltypes.Date diff --git a/test/sql/functions.py b/test/sql/functions.py index 6754d6d42..681d6a557 100644 --- a/test/sql/functions.py +++ b/test/sql/functions.py @@ -8,6 +8,7 @@ from sqlalchemy.engine import default from sqlalchemy import types as sqltypes from testlib import * from sqlalchemy.sql.functions import GenericFunction +from testlib.testing import eq_ from sqlalchemy.databases import * # every dialect in databases.__all__ is expected to pass these tests. @@ -68,6 +69,14 @@ class CompileTest(TestBase, AssertsCompiledSQL): ]: self.assert_compile(func.random(), ret, dialect=dialect) + def test_generic_count(self): + assert isinstance(func.count().type, sqltypes.Integer) + + self.assert_compile(func.count(), 'count(*)') + self.assert_compile(func.count(1), 'count(:param_1)') + c = column('abc') + self.assert_compile(func.count(c), 'count(abc)') + def test_constructor(self): try: func.current_timestamp('somearg') |