summaryrefslogtreecommitdiff
path: root/test/orm/selectresults.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/orm/selectresults.py')
-rw-r--r--test/orm/selectresults.py84
1 files changed, 84 insertions, 0 deletions
diff --git a/test/orm/selectresults.py b/test/orm/selectresults.py
new file mode 100644
index 000000000..3f5bcff92
--- /dev/null
+++ b/test/orm/selectresults.py
@@ -0,0 +1,84 @@
+from testbase import PersistTest
+import testbase
+
+from sqlalchemy import *
+
+from sqlalchemy.ext.selectresults import SelectResultsExt
+
+class Foo(object):
+ pass
+
+class SelectResultsTest(PersistTest):
+ def setUpAll(self):
+ self.install_threadlocal()
+ global foo
+ foo = Table('foo', testbase.db,
+ Column('id', Integer, Sequence('foo_id_seq'), primary_key=True),
+ Column('bar', Integer),
+ Column('range', Integer))
+
+ assign_mapper(Foo, foo, extension=SelectResultsExt())
+ foo.create()
+ for i in range(100):
+ Foo(bar=i, range=i%10)
+ objectstore.flush()
+
+ def setUp(self):
+ self.query = Foo.mapper.query()
+ self.orig = self.query.select_whereclause()
+ self.res = self.query.select()
+
+ def tearDownAll(self):
+ global foo
+ foo.drop()
+ self.uninstall_threadlocal()
+
+ def test_selectby(self):
+ res = self.query.select_by(range=5)
+ assert res.order_by([Foo.c.bar])[0].bar == 5
+ assert res.order_by([desc(Foo.c.bar)])[0].bar == 95
+
+ def test_slice(self):
+ assert self.res[1] == self.orig[1]
+ assert list(self.res[10:20]) == self.orig[10:20]
+ assert list(self.res[10:]) == self.orig[10:]
+ assert list(self.res[:10]) == self.orig[:10]
+ assert list(self.res[:10]) == self.orig[:10]
+ assert list(self.res[10:40:3]) == self.orig[10:40:3]
+ assert list(self.res[-5:]) == self.orig[-5:]
+ assert self.res[10:20][5] == self.orig[10:20][5]
+
+ def test_aggregate(self):
+ assert self.res.count() == 100
+ assert self.res.filter(foo.c.bar<30).min(foo.c.bar) == 0
+ assert self.res.filter(foo.c.bar<30).max(foo.c.bar) == 29
+
+ @testbase.unsupported('mysql')
+ def test_aggregate_1(self):
+ # this one fails in mysql as the result comes back as a string
+ assert self.res.filter(foo.c.bar<30).sum(foo.c.bar) == 435
+
+ @testbase.unsupported('postgres', 'mysql', 'firebird')
+ def test_aggregate_2(self):
+ # this one fails with postgres, the floating point comparison fails
+ assert self.res.filter(foo.c.bar<30).avg(foo.c.bar) == 14.5
+
+ def test_filter(self):
+ assert self.res.count() == 100
+ assert self.res.filter(Foo.c.bar < 30).count() == 30
+ res2 = self.res.filter(Foo.c.bar < 30).filter(Foo.c.bar > 10)
+ assert res2.count() == 19
+
+ def test_order_by(self):
+ assert self.res.order_by([Foo.c.bar])[0].bar == 0
+ assert self.res.order_by([desc(Foo.c.bar)])[0].bar == 99
+
+ def test_offset(self):
+ assert list(self.res.order_by([Foo.c.bar]).offset(10))[0].bar == 10
+
+ def test_offset(self):
+ assert len(list(self.res.limit(10))) == 10
+
+
+if __name__ == "__main__":
+ testbase.main()