summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--CHANGES2
-rw-r--r--lib/sqlalchemy/ext/selectresults.py2
-rw-r--r--lib/sqlalchemy/orm/query.py13
-rw-r--r--test/orm/selectresults.py42
4 files changed, 55 insertions, 4 deletions
diff --git a/CHANGES b/CHANGES
index 085a5979e..9ae530ace 100644
--- a/CHANGES
+++ b/CHANGES
@@ -49,6 +49,8 @@ so far will convert this to "TIME[STAMP] (WITH|WITHOUT) TIME ZONE",
so that control over timezone presence is more controllable (psycopg2
returns datetimes with tzinfo's if available, which can create confusion
against datetimes that dont).
+- fix to using query.count() with distinct, **kwargs with SelectResults
+count() [ticket:287]
0.2.7
- quoting facilities set up so that database-specific quoting can be
diff --git a/lib/sqlalchemy/ext/selectresults.py b/lib/sqlalchemy/ext/selectresults.py
index 79d56ec67..a35cdfa7e 100644
--- a/lib/sqlalchemy/ext/selectresults.py
+++ b/lib/sqlalchemy/ext/selectresults.py
@@ -28,7 +28,7 @@ class SelectResults(object):
def count(self):
"""executes the SQL count() function against the SelectResults criterion."""
- return self._query.count(self._clause)
+ return self._query.count(self._clause, **self._ops)
def _col_aggregate(self, col, func):
"""executes func() function against the given column
diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py
index 1e9d40c75..29cc56761 100644
--- a/lib/sqlalchemy/orm/query.py
+++ b/lib/sqlalchemy/orm/query.py
@@ -232,7 +232,10 @@ class Query(object):
return self._select_statement(statement, params=params)
def count(self, whereclause=None, params=None, **kwargs):
- s = self.table.count(whereclause)
+ if self._nestable(**kwargs):
+ s = self.table.select(whereclause, **kwargs).alias('getcount').count()
+ else:
+ s = self.table.count(whereclause)
return self.session.scalar(self.mapper, s, params=params)
def select_statement(self, statement, **params):
@@ -302,14 +305,18 @@ class Query(object):
return self.instances(statement, params=params, **kwargs)
def _should_nest(self, **kwargs):
- """returns True if the given statement options indicate that we should "nest" the
+ """return True if the given statement options indicate that we should "nest" the
generated query as a subquery inside of a larger eager-loading query. this is used
with keywords like distinct, limit and offset and the mapper defines eager loads."""
return (
self.mapper.has_eager()
- and (kwargs.has_key('limit') or kwargs.has_key('offset') or kwargs.get('distinct', False))
+ and self._nestable(**kwargs)
)
+ def _nestable(self, **kwargs):
+ """return true if the given statement options imply it should be nested."""
+ return (kwargs.has_key('limit') or kwargs.has_key('offset') or kwargs.get('distinct', False))
+
def compile(self, whereclause = None, **kwargs):
order_by = kwargs.pop('order_by', False)
from_obj = kwargs.pop('from_obj', [])
diff --git a/test/orm/selectresults.py b/test/orm/selectresults.py
index 3f5bcff92..c4b1d6a56 100644
--- a/test/orm/selectresults.py
+++ b/test/orm/selectresults.py
@@ -79,6 +79,48 @@ class SelectResultsTest(PersistTest):
def test_offset(self):
assert len(list(self.res.limit(10))) == 10
+class Obj1(object):
+ pass
+class Obj2(object):
+ pass
+
+class SelectResultsTest2(PersistTest):
+ def setUpAll(self):
+ self.install_threadlocal()
+ global metadata, table1, table2
+ metadata = BoundMetaData(testbase.db)
+ table1 = Table('Table1', metadata,
+ Column('id', Integer, primary_key=True),
+ )
+ table2 = Table('Table2', metadata,
+ Column('t1id', Integer, ForeignKey("Table1.id"), primary_key=True),
+ Column('num', Integer, primary_key=True),
+ )
+ assign_mapper(Obj1, table1, extension=SelectResultsExt())
+ assign_mapper(Obj2, table2, extension=SelectResultsExt())
+ metadata.create_all()
+ table1.insert().execute({'id':1},{'id':2},{'id':3},{'id':4})
+ table2.insert().execute({'num':1,'t1id':1},{'num':2,'t1id':1},{'num':3,'t1id':1},\
+{'num':4,'t1id':2},{'num':5,'t1id':2},{'num':6,'t1id':3})
+
+ def setUp(self):
+ self.query = Obj1.mapper.query()
+ #self.orig = self.query.select_whereclause()
+ #self.res = self.query.select()
+
+ def tearDownAll(self):
+ metadata.drop_all()
+ self.uninstall_threadlocal()
+
+ def test_distinctcount(self):
+ res = self.query.select()
+ assert res.count() == 4
+ res = self.query.select(and_(table1.c.id==table2.c.t1id,table2.c.t1id==1))
+ assert res.count() == 3
+ res = self.query.select(and_(table1.c.id==table2.c.t1id,table2.c.t1id==1), distinct=True)
+ self.assertEqual(res.count(), 1)
+
+
if __name__ == "__main__":
testbase.main()