diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2008-11-09 16:06:05 +0000 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2008-11-09 16:06:05 +0000 |
commit | 043379efa5d61626c9a8ab42b15c7687c6e6a0fd (patch) | |
tree | a0d1012d6644f5d4ac227353ceaa5c1faffaa880 /lib/sqlalchemy | |
parent | 3f8914b4b28f309467b96f2903388e69cf8c2b2d (diff) | |
download | sqlalchemy-043379efa5d61626c9a8ab42b15c7687c6e6a0fd.tar.gz |
- Query.count() has been enhanced to do the "right
thing" in a wider variety of cases. It can now
count multiple-entity queries, as well as
column-based queries. Note that this means if you
say query(A, B).count() without any joining
criterion, it's going to count the cartesian
product of A*B. Any query which is against
column-based entities will automatically issue
"SELECT count(1) FROM (SELECT...)" so that the
real rowcount is returned, meaning a query such as
query(func.count(A.name)).count() will return a value of
one, since that query would return one row.
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r-- | lib/sqlalchemy/orm/query.py | 50 |
1 files changed, 38 insertions, 12 deletions
diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 39e3db43c..81250706b 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -1246,28 +1246,54 @@ class Query(object): kwargs.get('distinct', False)) def count(self): - """Apply this query's criterion to a SELECT COUNT statement.""" - + """Apply this query's criterion to a SELECT COUNT statement. + + If column expressions or LIMIT/OFFSET/DISTINCT are present, + the query "SELECT count(1) FROM (SELECT ...)" is issued, + so that the result matches the total number of rows + this query would return. For mapped entities, + the primary key columns of each is written to the + columns clause of the nested SELECT statement. + + For a Query which is only against mapped entities, + a simpler "SELECT count(1) FROM table1, table2, ... + WHERE criterion" is issued. + + """ + should_nest = [self._should_nest_selectable] + def ent_cols(ent): + if isinstance(ent, _MapperEntity): + return ent.mapper.primary_key + else: + should_nest[0] = True + return [ent.column] + return self._col_aggregate(sql.literal_column('1'), sql.func.count, - nested_cols=list(self._only_mapper_zero( - "Can't issue count() for multiple types of objects or columns. " - " Construct the Query against a single element as the thing to be counted, " - "or for an actual row count use Query(func.count(somecolumn)) or " - "query.values(func.count(somecolumn)) instead.").primary_key)) + nested_cols=chain(*[ent_cols(ent) for ent in self._entities]), + should_nest = should_nest[0] + ) - def _col_aggregate(self, col, func, nested_cols=None): + def _col_aggregate(self, col, func, nested_cols=None, should_nest=False): context = QueryContext(self) + for entity in self._entities: + entity.setup_context(self, context) + + if context.from_clause: + from_obj = [context.from_clause] + else: + from_obj = context.froms + self._adjust_for_single_inheritance(context) whereclause = context.whereclause - from_obj = self.__mapper_zero_from_obj() - - if self._should_nest_selectable: + if should_nest: if not nested_cols: nested_cols = [col] - s = sql.select(nested_cols, whereclause, from_obj=from_obj, **self._select_args) + else: + nested_cols = list(nested_cols) + s = sql.select(nested_cols, whereclause, from_obj=from_obj, use_labels=True, **self._select_args) s = s.alias() s = sql.select([func(s.corresponding_column(col) or col)]).select_from(s) else: |