summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2014-10-04 11:51:52 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2014-10-04 11:51:52 -0400
commitbe2541736d886eefa6bdbae5581536abba198736 (patch)
tree957aa28ecb3cd8956ddfa8dc8ab2b7195e666312
parent4da020dae324cb871074e302f4840e8731988be0 (diff)
parent76c06aa65345b47af38a0a1d20638dfbc890b640 (diff)
downloadsqlalchemy-be2541736d886eefa6bdbae5581536abba198736.tar.gz
Merge remote-tracking branch 'origin/pr/134' into pr134
-rw-r--r--lib/sqlalchemy/__init__.py1
-rw-r--r--lib/sqlalchemy/sql/__init__.py1
-rw-r--r--lib/sqlalchemy/sql/compiler.py6
-rw-r--r--lib/sqlalchemy/sql/elements.py85
-rw-r--r--lib/sqlalchemy/sql/expression.py4
-rw-r--r--lib/sqlalchemy/sql/functions.py24
-rw-r--r--test/sql/test_compiler.py91
-rw-r--r--test/sql/test_generative.py5
8 files changed, 215 insertions, 2 deletions
diff --git a/lib/sqlalchemy/__init__.py b/lib/sqlalchemy/__init__.py
index 853566172..d184e1fbf 100644
--- a/lib/sqlalchemy/__init__.py
+++ b/lib/sqlalchemy/__init__.py
@@ -25,6 +25,7 @@ from .sql import (
extract,
false,
func,
+ funcfilter,
insert,
intersect,
intersect_all,
diff --git a/lib/sqlalchemy/sql/__init__.py b/lib/sqlalchemy/sql/__init__.py
index 4d013859c..351e08d0b 100644
--- a/lib/sqlalchemy/sql/__init__.py
+++ b/lib/sqlalchemy/sql/__init__.py
@@ -38,6 +38,7 @@ from .expression import (
false,
False_,
func,
+ funcfilter,
insert,
intersect,
intersect_all,
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 18b4d4cfc..86f00d944 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -746,6 +746,12 @@ class SQLCompiler(Compiled):
)
)
+ def visit_funcfilter(self, funcfilter, **kwargs):
+ return "%s FILTER (WHERE %s)" % (
+ funcfilter.func._compiler_dispatch(self, **kwargs),
+ funcfilter.criterion._compiler_dispatch(self, **kwargs)
+ )
+
def visit_extract(self, extract, **kwargs):
field = self.extract_map.get(extract.field, extract.field)
return "EXTRACT(%s FROM %s)" % (
diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py
index 8ec0aa700..53838358d 100644
--- a/lib/sqlalchemy/sql/elements.py
+++ b/lib/sqlalchemy/sql/elements.py
@@ -2888,6 +2888,91 @@ class Over(ColumnElement):
))
+class FunctionFilter(ColumnElement):
+ """Represent a function FILTER clause.
+
+ This is a special operator against aggregate and window functions,
+ which controls which rows are passed to it.
+ It's supported only by certain database backends.
+
+ """
+ __visit_name__ = 'funcfilter'
+
+ criterion = None
+
+ def __init__(self, func, *criterion):
+ """Produce a :class:`.FunctionFilter` object against a function.
+
+ Used against aggregate and window functions,
+ for database backends that support the "FILTER" clause.
+
+ E.g.::
+
+ from sqlalchemy import funcfilter
+ funcfilter(func.count(1), MyClass.name == 'some name')
+
+ Would produce "COUNT(1) FILTER (WHERE myclass.name = 'some name')".
+
+ This function is also available from the :data:`~.expression.func`
+ construct itself via the :meth:`.FunctionElement.filter` method.
+
+ """
+ self.func = func
+ self.filter(*criterion)
+
+ def filter(self, *criterion):
+ for criterion in list(criterion):
+ criterion = _expression_literal_as_text(criterion)
+
+ if self.criterion is not None:
+ self.criterion = self.criterion & criterion
+ else:
+ self.criterion = criterion
+
+ return self
+
+ def over(self, partition_by=None, order_by=None):
+ """Produce an OVER clause against this filtered function.
+
+ Used against aggregate or so-called "window" functions,
+ for database backends that support window functions.
+
+ The expression::
+
+ func.rank().filter(MyClass.y > 5).over(order_by='x')
+
+ is shorthand for::
+
+ from sqlalchemy import over, funcfilter
+ over(funcfilter(func.rank(), MyClass.y > 5), order_by='x')
+
+ See :func:`~.expression.over` for a full description.
+
+ """
+ return Over(self, partition_by=partition_by, order_by=order_by)
+
+ @util.memoized_property
+ def type(self):
+ return self.func.type
+
+ def get_children(self, **kwargs):
+ return [c for c in
+ (self.func, self.criterion)
+ if c is not None]
+
+ def _copy_internals(self, clone=_clone, **kw):
+ self.func = clone(self.func, **kw)
+ if self.criterion is not None:
+ self.criterion = clone(self.criterion, **kw)
+
+ @property
+ def _from_objects(self):
+ return list(itertools.chain(
+ *[c._from_objects for c in (self.func, self.criterion)
+ if c is not None]
+ ))
+
+
class Label(ColumnElement):
"""Represents a column label (AS).
diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py
index d96f048b9..2e10b7370 100644
--- a/lib/sqlalchemy/sql/expression.py
+++ b/lib/sqlalchemy/sql/expression.py
@@ -36,7 +36,7 @@ from .elements import ClauseElement, ColumnElement,\
True_, False_, BinaryExpression, Tuple, TypeClause, Extract, \
Grouping, not_, \
collate, literal_column, between,\
- literal, outparam, type_coerce, ClauseList
+ literal, outparam, type_coerce, ClauseList, FunctionFilter
from .elements import SavepointClause, RollbackToSavepointClause, \
ReleaseSavepointClause
@@ -97,6 +97,8 @@ outerjoin = public_factory(Join._create_outerjoin, ".expression.outerjoin")
insert = public_factory(Insert, ".expression.insert")
update = public_factory(Update, ".expression.update")
delete = public_factory(Delete, ".expression.delete")
+funcfilter = public_factory(
+ FunctionFilter, ".expression.funcfilter")
# internal functions still being called from tests and the ORM,
diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py
index 7efb1e916..a07eca8c6 100644
--- a/lib/sqlalchemy/sql/functions.py
+++ b/lib/sqlalchemy/sql/functions.py
@@ -12,7 +12,7 @@ from . import sqltypes, schema
from .base import Executable, ColumnCollection
from .elements import ClauseList, Cast, Extract, _literal_as_binds, \
literal_column, _type_from_args, ColumnElement, _clone,\
- Over, BindParameter
+ Over, BindParameter, FunctionFilter
from .selectable import FromClause, Select, Alias
from . import operators
@@ -116,6 +116,28 @@ class FunctionElement(Executable, ColumnElement, FromClause):
"""
return Over(self, partition_by=partition_by, order_by=order_by)
+ def filter(self, *criterion):
+ """Produce a FILTER clause against this function.
+
+ Used against aggregate and window functions,
+ for database backends that support the "FILTER" clause.
+
+ The expression::
+
+ func.count(1).filter(True)
+
+ is shorthand for::
+
+ from sqlalchemy import funcfilter
+ funcfilter(func.count(1), True)
+
+ See :func:`~.expression.funcfilter` for a full description.
+
+ """
+ if not criterion:
+ return self
+ return FunctionFilter(self, *criterion)
+
@property
def _from_objects(self):
return self.clauses._from_objects
diff --git a/test/sql/test_compiler.py b/test/sql/test_compiler.py
index 3e6b87351..ed13e8455 100644
--- a/test/sql/test_compiler.py
+++ b/test/sql/test_compiler.py
@@ -2190,6 +2190,97 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL):
"(ORDER BY mytable.myid + :myid_1) AS anon_1 FROM mytable"
)
+ def test_funcfilter(self):
+ self.assert_compile(
+ func.count(1).filter(),
+ "count(:param_1)"
+ )
+ self.assert_compile(
+ func.count(1).filter(
+ table1.c.name != None
+ ),
+ "count(:param_1) FILTER (WHERE mytable.name IS NOT NULL)"
+ )
+ self.assert_compile(
+ func.count(1).filter(
+ table1.c.name == None,
+ table1.c.myid > 0
+ ),
+ "count(:param_1) FILTER (WHERE mytable.name IS NULL AND "
+ "mytable.myid > :myid_1)"
+ )
+
+ self.assert_compile(
+ select([func.count(1).filter(
+ table1.c.description != None
+ ).label('foo')]),
+ "SELECT count(:param_1) FILTER (WHERE mytable.description "
+ "IS NOT NULL) AS foo FROM mytable"
+ )
+
+ # test from_obj generation.
+ # from func:
+ self.assert_compile(
+ select([
+ func.max(table1.c.name).filter(
+ literal_column('description') != None
+ )
+ ]),
+ "SELECT max(mytable.name) FILTER (WHERE description "
+ "IS NOT NULL) AS anon_1 FROM mytable"
+ )
+ # from criterion:
+ self.assert_compile(
+ select([
+ func.count(1).filter(
+ table1.c.name == 'name'
+ )
+ ]),
+ "SELECT count(:param_1) FILTER (WHERE mytable.name = :name_1) "
+ "AS anon_1 FROM mytable"
+ )
+
+ # test chaining:
+ self.assert_compile(
+ select([
+ func.count(1).filter(
+ table1.c.name == 'name'
+ ).filter(
+ table1.c.description == 'description'
+ )
+ ]),
+ "SELECT count(:param_1) FILTER (WHERE "
+ "mytable.name = :name_1 AND mytable.description = :description_1) "
+ "AS anon_1 FROM mytable"
+ )
+
+ # test filtered windowing:
+ self.assert_compile(
+ select([
+ func.rank().filter(
+ table1.c.name > 'foo'
+ ).over(
+ order_by=table1.c.name
+ )
+ ]),
+ "SELECT rank() FILTER (WHERE mytable.name > :name_1) "
+ "OVER (ORDER BY mytable.name) AS anon_1 FROM mytable"
+ )
+
+ self.assert_compile(
+ select([
+ func.rank().filter(
+ table1.c.name > 'foo'
+ ).over(
+ order_by=table1.c.name,
+ partition_by=['description']
+ )
+ ]),
+ "SELECT rank() FILTER (WHERE mytable.name > :name_1) "
+ "OVER (PARTITION BY mytable.description ORDER BY mytable.name) "
+ "AS anon_1 FROM mytable"
+ )
+
def test_date_between(self):
import datetime
table = Table('dt', metadata,
diff --git a/test/sql/test_generative.py b/test/sql/test_generative.py
index 013ba8082..6044cecb0 100644
--- a/test/sql/test_generative.py
+++ b/test/sql/test_generative.py
@@ -539,6 +539,11 @@ class ClauseTest(fixtures.TestBase, AssertsCompiledSQL):
expr2 = CloningVisitor().traverse(expr)
assert str(expr) == str(expr2)
+ def test_funcfilter(self):
+ expr = func.count(1).filter(t1.c.col1 > 1)
+ expr2 = CloningVisitor().traverse(expr)
+ assert str(expr) == str(expr2)
+
def test_adapt_union(self):
u = union(
t1.select().where(t1.c.col1 == 4),