diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2008-11-09 16:06:05 +0000 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2008-11-09 16:06:05 +0000 |
commit | 043379efa5d61626c9a8ab42b15c7687c6e6a0fd (patch) | |
tree | a0d1012d6644f5d4ac227353ceaa5c1faffaa880 | |
parent | 3f8914b4b28f309467b96f2903388e69cf8c2b2d (diff) | |
download | sqlalchemy-043379efa5d61626c9a8ab42b15c7687c6e6a0fd.tar.gz |
- Query.count() has been enhanced to do the "right
thing" in a wider variety of cases. It can now
count multiple-entity queries, as well as
column-based queries. Note that this means if you
say query(A, B).count() without any joining
criterion, it's going to count the cartesian
product of A*B. Any query which is against
column-based entities will automatically issue
"SELECT count(1) FROM (SELECT...)" so that the
real rowcount is returned, meaning a query such as
query(func.count(A.name)).count() will return a value of
one, since that query would return one row.
-rw-r--r-- | CHANGES | 19 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/query.py | 50 | ||||
-rw-r--r-- | test/orm/query.py | 49 |
3 files changed, 100 insertions, 18 deletions
@@ -6,15 +6,30 @@ CHANGES ======= 0.5.0rc4 ======== +- features +- orm + - Query.count() has been enhanced to do the "right + thing" in a wider variety of cases. It can now + count multiple-entity queries, as well as + column-based queries. Note that this means if you + say query(A, B).count() without any joining + criterion, it's going to count the cartesian + product of A*B. Any query which is against + column-based entities will automatically issue + "SELECT count(1) FROM (SELECT...)" so that the + real rowcount is returned, meaning a query such as + query(func.count(A.name)).count() will return a value of + one, since that query would return one row. + - bugfixes and behavioral changes - general: - global "propigate"->"propagate" change. - orm - - Query.count() and Query.get() return a more informative + - Query.get() returns a more informative error message when executed against multiple entities. [ticket:1220] - + - access - Added support for Currency type. diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 39e3db43c..81250706b 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -1246,28 +1246,54 @@ class Query(object): kwargs.get('distinct', False)) def count(self): - """Apply this query's criterion to a SELECT COUNT statement.""" - + """Apply this query's criterion to a SELECT COUNT statement. + + If column expressions or LIMIT/OFFSET/DISTINCT are present, + the query "SELECT count(1) FROM (SELECT ...)" is issued, + so that the result matches the total number of rows + this query would return. For mapped entities, + the primary key columns of each is written to the + columns clause of the nested SELECT statement. + + For a Query which is only against mapped entities, + a simpler "SELECT count(1) FROM table1, table2, ... + WHERE criterion" is issued. + + """ + should_nest = [self._should_nest_selectable] + def ent_cols(ent): + if isinstance(ent, _MapperEntity): + return ent.mapper.primary_key + else: + should_nest[0] = True + return [ent.column] + return self._col_aggregate(sql.literal_column('1'), sql.func.count, - nested_cols=list(self._only_mapper_zero( - "Can't issue count() for multiple types of objects or columns. " - " Construct the Query against a single element as the thing to be counted, " - "or for an actual row count use Query(func.count(somecolumn)) or " - "query.values(func.count(somecolumn)) instead.").primary_key)) + nested_cols=chain(*[ent_cols(ent) for ent in self._entities]), + should_nest = should_nest[0] + ) - def _col_aggregate(self, col, func, nested_cols=None): + def _col_aggregate(self, col, func, nested_cols=None, should_nest=False): context = QueryContext(self) + for entity in self._entities: + entity.setup_context(self, context) + + if context.from_clause: + from_obj = [context.from_clause] + else: + from_obj = context.froms + self._adjust_for_single_inheritance(context) whereclause = context.whereclause - from_obj = self.__mapper_zero_from_obj() - - if self._should_nest_selectable: + if should_nest: if not nested_cols: nested_cols = [col] - s = sql.select(nested_cols, whereclause, from_obj=from_obj, **self._select_args) + else: + nested_cols = list(nested_cols) + s = sql.select(nested_cols, whereclause, from_obj=from_obj, use_labels=True, **self._select_args) s = s.alias() s = sql.select([func(s.corresponding_column(col) or col)]).select_from(s) else: diff --git a/test/orm/query.py b/test/orm/query.py index c90707342..3e2f327c3 100644 --- a/test/orm/query.py +++ b/test/orm/query.py @@ -240,8 +240,6 @@ class InvalidGenerationsTest(QueryTest): s = create_session() q = s.query(User, Address) - self.assertRaises(sa_exc.InvalidRequestError, q.count) - self.assertRaises(sa_exc.InvalidRequestError, q.get, 5) def test_from_statement(self): @@ -779,10 +777,53 @@ class AggregateTest(QueryTest): class CountTest(QueryTest): def test_basic(self): - assert 4 == create_session().query(User).count() + s = create_session() + + eq_(s.query(User).count(), 4) + + eq_(s.query(User).filter(users.c.name.endswith('ed')).count(), 2) + + def test_multiple_entity(self): + s = create_session() + q = s.query(User, Address) + eq_(q.count(), 20) # cartesian product + + q = s.query(User, Address).join(User.addresses) + eq_(q.count(), 5) + + def test_nested(self): + s = create_session() + q = s.query(User, Address).limit(2) + eq_(q.count(), 2) - assert 2 == create_session().query(User).filter(users.c.name.endswith('ed')).count() + q = s.query(User, Address).limit(100) + eq_(q.count(), 20) + q = s.query(User, Address).join(User.addresses).limit(100) + eq_(q.count(), 5) + + def test_cols(self): + """test that column-based queries always nest.""" + + s = create_session() + + q = s.query(func.count(distinct(User.name))) + eq_(q.count(), 1) + + q = s.query(func.count(distinct(User.name))).distinct() + eq_(q.count(), 1) + + q = s.query(User.name) + eq_(q.count(), 4) + + q = s.query(User.name, Address) + eq_(q.count(), 20) + + q = s.query(Address.user_id) + eq_(q.count(), 5) + eq_(q.distinct().count(), 3) + + class DistinctTest(QueryTest): def test_basic(self): assert [User(id=7), User(id=8), User(id=9),User(id=10)] == create_session().query(User).distinct().all() |