summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--CHANGES19
-rw-r--r--lib/sqlalchemy/orm/query.py50
-rw-r--r--test/orm/query.py49
3 files changed, 100 insertions, 18 deletions
diff --git a/CHANGES b/CHANGES
index 41f509c5f..c59eb826a 100644
--- a/CHANGES
+++ b/CHANGES
@@ -6,15 +6,30 @@ CHANGES
=======
0.5.0rc4
========
+- features
+- orm
+ - Query.count() has been enhanced to do the "right
+ thing" in a wider variety of cases. It can now
+ count multiple-entity queries, as well as
+ column-based queries. Note that this means if you
+ say query(A, B).count() without any joining
+ criterion, it's going to count the cartesian
+ product of A*B. Any query which is against
+ column-based entities will automatically issue
+ "SELECT count(1) FROM (SELECT...)" so that the
+ real rowcount is returned, meaning a query such as
+ query(func.count(A.name)).count() will return a value of
+ one, since that query would return one row.
+
- bugfixes and behavioral changes
- general:
- global "propigate"->"propagate" change.
- orm
- - Query.count() and Query.get() return a more informative
+ - Query.get() returns a more informative
error message when executed against multiple entities.
[ticket:1220]
-
+
- access
- Added support for Currency type.
diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py
index 39e3db43c..81250706b 100644
--- a/lib/sqlalchemy/orm/query.py
+++ b/lib/sqlalchemy/orm/query.py
@@ -1246,28 +1246,54 @@ class Query(object):
kwargs.get('distinct', False))
def count(self):
- """Apply this query's criterion to a SELECT COUNT statement."""
-
+ """Apply this query's criterion to a SELECT COUNT statement.
+
+ If column expressions or LIMIT/OFFSET/DISTINCT are present,
+ the query "SELECT count(1) FROM (SELECT ...)" is issued,
+ so that the result matches the total number of rows
+ this query would return. For mapped entities,
+ the primary key columns of each is written to the
+ columns clause of the nested SELECT statement.
+
+ For a Query which is only against mapped entities,
+ a simpler "SELECT count(1) FROM table1, table2, ...
+ WHERE criterion" is issued.
+
+ """
+ should_nest = [self._should_nest_selectable]
+ def ent_cols(ent):
+ if isinstance(ent, _MapperEntity):
+ return ent.mapper.primary_key
+ else:
+ should_nest[0] = True
+ return [ent.column]
+
return self._col_aggregate(sql.literal_column('1'), sql.func.count,
- nested_cols=list(self._only_mapper_zero(
- "Can't issue count() for multiple types of objects or columns. "
- " Construct the Query against a single element as the thing to be counted, "
- "or for an actual row count use Query(func.count(somecolumn)) or "
- "query.values(func.count(somecolumn)) instead.").primary_key))
+ nested_cols=chain(*[ent_cols(ent) for ent in self._entities]),
+ should_nest = should_nest[0]
+ )
- def _col_aggregate(self, col, func, nested_cols=None):
+ def _col_aggregate(self, col, func, nested_cols=None, should_nest=False):
context = QueryContext(self)
+ for entity in self._entities:
+ entity.setup_context(self, context)
+
+ if context.from_clause:
+ from_obj = [context.from_clause]
+ else:
+ from_obj = context.froms
+
self._adjust_for_single_inheritance(context)
whereclause = context.whereclause
- from_obj = self.__mapper_zero_from_obj()
-
- if self._should_nest_selectable:
+ if should_nest:
if not nested_cols:
nested_cols = [col]
- s = sql.select(nested_cols, whereclause, from_obj=from_obj, **self._select_args)
+ else:
+ nested_cols = list(nested_cols)
+ s = sql.select(nested_cols, whereclause, from_obj=from_obj, use_labels=True, **self._select_args)
s = s.alias()
s = sql.select([func(s.corresponding_column(col) or col)]).select_from(s)
else:
diff --git a/test/orm/query.py b/test/orm/query.py
index c90707342..3e2f327c3 100644
--- a/test/orm/query.py
+++ b/test/orm/query.py
@@ -240,8 +240,6 @@ class InvalidGenerationsTest(QueryTest):
s = create_session()
q = s.query(User, Address)
- self.assertRaises(sa_exc.InvalidRequestError, q.count)
-
self.assertRaises(sa_exc.InvalidRequestError, q.get, 5)
def test_from_statement(self):
@@ -779,10 +777,53 @@ class AggregateTest(QueryTest):
class CountTest(QueryTest):
def test_basic(self):
- assert 4 == create_session().query(User).count()
+ s = create_session()
+
+ eq_(s.query(User).count(), 4)
+
+ eq_(s.query(User).filter(users.c.name.endswith('ed')).count(), 2)
+
+ def test_multiple_entity(self):
+ s = create_session()
+ q = s.query(User, Address)
+ eq_(q.count(), 20) # cartesian product
+
+ q = s.query(User, Address).join(User.addresses)
+ eq_(q.count(), 5)
+
+ def test_nested(self):
+ s = create_session()
+ q = s.query(User, Address).limit(2)
+ eq_(q.count(), 2)
- assert 2 == create_session().query(User).filter(users.c.name.endswith('ed')).count()
+ q = s.query(User, Address).limit(100)
+ eq_(q.count(), 20)
+ q = s.query(User, Address).join(User.addresses).limit(100)
+ eq_(q.count(), 5)
+
+ def test_cols(self):
+ """test that column-based queries always nest."""
+
+ s = create_session()
+
+ q = s.query(func.count(distinct(User.name)))
+ eq_(q.count(), 1)
+
+ q = s.query(func.count(distinct(User.name))).distinct()
+ eq_(q.count(), 1)
+
+ q = s.query(User.name)
+ eq_(q.count(), 4)
+
+ q = s.query(User.name, Address)
+ eq_(q.count(), 20)
+
+ q = s.query(Address.user_id)
+ eq_(q.count(), 5)
+ eq_(q.distinct().count(), 3)
+
+
class DistinctTest(QueryTest):
def test_basic(self):
assert [User(id=7), User(id=8), User(id=9),User(id=10)] == create_session().query(User).distinct().all()