summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2007-04-25 17:49:26 +0000
committerMike Bayer <mike_mp@zzzcomputing.com>2007-04-25 17:49:26 +0000
commit2f547222da46b38df07dde08e60bc5efbb0afd79 (patch)
treed97a3d7a640df27b9e27bbc4ae9e6dec239480c7
parent3a0f65b4343f5e338b1425351d190f0f752caee1 (diff)
downloadsqlalchemy-2f547222da46b38df07dde08e60bc5efbb0afd79.tar.gz
- added generative versions of aggregates, i.e. sum(), avg(), etc.
to query. used via query.apply_max(), apply_sum(), etc. #552
-rw-r--r--CHANGES3
-rw-r--r--lib/sqlalchemy/orm/query.py46
-rw-r--r--test/orm/generative.py4
3 files changed, 52 insertions, 1 deletions
diff --git a/CHANGES b/CHANGES
index 450ef3279..fea91b7f5 100644
--- a/CHANGES
+++ b/CHANGES
@@ -80,6 +80,9 @@
takes optional string "property" to isolate the desired relation.
also adds static Query.query_from_parent(instance, property)
version. [ticket:541]
+ - added generative versions of aggregates, i.e. sum(), avg(), etc.
+ to query. used via query.apply_max(), apply_sum(), etc.
+ #552
- corresponding to label/bindparam name generataion, eager loaders
generate deterministic names for the aliases they create using
md5 hashes.
diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py
index 9eec1bc0e..c43b9a946 100644
--- a/lib/sqlalchemy/orm/query.py
+++ b/lib/sqlalchemy/orm/query.py
@@ -43,6 +43,8 @@ class Query(object):
self._offset = kwargs.pop('offset', None)
self._limit = kwargs.pop('limit', None)
self._criterion = None
+ self._col = None
+ self._func = None
self._joinpoint = self.mapper
self._from_obj = [self.table]
@@ -71,6 +73,8 @@ class Query(object):
q._from_obj = list(self._from_obj)
q._joinpoint = self._joinpoint
q._criterion = self._criterion
+ q._col = self._col
+ q._func = self._func
return q
def _get_session(self):
@@ -318,7 +322,6 @@ class Query(object):
"""Given a ``WHERE`` criterion, create a ``SELECT`` statement,
execute and return the resulting instances.
"""
-
statement = self.compile(whereclause, **kwargs)
return self._select_statement(statement, params=params)
@@ -611,6 +614,41 @@ class Query(object):
raise exceptions.InvalidRequestError("Can't locate property named '%s'" % key)
return [keys, p]
+ def _generative_col_aggregate(self, col, func):
+ """apply the given aggregate function to the query and return the newly
+ resulting ``Query``.
+ """
+ if self._col is not None or self._func is not None:
+ raise exceptions.InvalidRequestError("Query already contains an aggregate column or function")
+ q = self._clone()
+ q._col = col
+ q._func = func
+ return q
+
+ def apply_min(self, col):
+ """apply the SQL ``min()`` function against the given column to the
+ query and return the newly resulting ``Query``.
+ """
+ return self._generative_col_aggregate(col, sql.func.min)
+
+ def apply_max(self, col):
+ """apply the SQL ``max()`` function against the given column to the
+ query and return the newly resulting ``Query``.
+ """
+ return self._generative_col_aggregate(col, sql.func.max)
+
+ def apply_sum(self, col):
+ """apply the SQL ``sum()`` function against the given column to the
+ query and return the newly resulting ``Query``.
+ """
+ return self._generative_col_aggregate(col, sql.func.sum)
+
+ def apply_avg(self, col):
+ """apply the SQL ``avg()`` function against the given column to the
+ query and return the newly resulting ``Query``.
+ """
+ return self._generative_col_aggregate(col, sql.func.avg)
+
def _col_aggregate(self, col, func):
"""Execute ``func()`` function against the given column.
@@ -767,6 +805,12 @@ class Query(object):
"""
return list(self)
+
+ def scalar(self):
+ if self._col is None or self._func is None:
+ return self[0]
+ else:
+ return self._col_aggregate(self._col, self._func)
def __iter__(self):
return iter(self.select_whereclause())
diff --git a/test/orm/generative.py b/test/orm/generative.py
index 6cda21964..512f04ae9 100644
--- a/test/orm/generative.py
+++ b/test/orm/generative.py
@@ -59,6 +59,7 @@ class GenerativeQueryTest(PersistTest):
assert self.query.count() == 100
assert self.query.filter(foo.c.bar<30).min(foo.c.bar) == 0
assert self.query.filter(foo.c.bar<30).max(foo.c.bar) == 29
+ assert self.query.filter(foo.c.bar<30).apply_max(foo.c.bar).scalar() == 29
@testbase.unsupported('mysql')
def test_aggregate_1(self):
@@ -73,6 +74,9 @@ class GenerativeQueryTest(PersistTest):
def test_aggregate_2_int(self):
assert int(self.res.filter(foo.c.bar<30).avg(foo.c.bar)) == 14
+ def test_aggregate_3(self):
+ assert self.res.filter(foo.c.bar<30).apply_avg(foo.c.bar).scalar() == 14.5
+
def test_filter(self):
assert self.query.count() == 100
assert self.query.filter(Foo.c.bar < 30).count() == 30