import sqlalchemy as sa
from sqlalchemy import ForeignKey
from sqlalchemy import func
from sqlalchemy import Integer
from sqlalchemy import testing
from sqlalchemy.orm import relationship
from sqlalchemy.testing import config
from sqlalchemy.testing import eq_
from sqlalchemy.testing import fixtures
from sqlalchemy.testing.fixtures import fixture_session
from sqlalchemy.testing.provision import normalize_sequence
from sqlalchemy.testing.schema import Column
from sqlalchemy.testing.schema import Table
from test.orm import _fixtures


class GenerativeQueryTest(fixtures.MappedTest):
    """Generative ``Query`` methods against a single 100-row table.

    The ``foo`` fixture holds one row per ``bar`` value 0..99, with
    ``range`` set to ``bar % 10``.  ``run_inserts = "once"`` and
    ``run_deletes = None`` are set, so the tests below only read data.
    """

    run_inserts = "once"
    run_deletes = None

    @classmethod
    def define_tables(cls, metadata):
        Table(
            "foo",
            metadata,
            Column(
                "id",
                Integer,
                normalize_sequence(config, sa.Sequence("foo_id_seq")),
                primary_key=True,
            ),
            Column("bar", Integer),
            Column("range", Integer),
        )

    @classmethod
    def fixtures(cls):
        # bar takes every value 0..99; range is bar's last digit.
        rows = tuple((i, i % 10) for i in range(100))
        foo_data = (("bar", "range"),) + rows
        return dict(foo=foo_data)

    @classmethod
    def setup_mappers(cls):
        foo = cls.tables.foo

        class Foo(cls.Basic):
            pass

        cls.mapper_registry.map_imperatively(Foo, foo)

    def test_selectby(self):
        """filter_by() combined with ascending / descending order_by()."""
        Foo = self.classes.Foo
        res = fixture_session().query(Foo).filter_by(range=5)
        # rows with range == 5 have bar in 5, 15, ..., 95
        assert res.order_by(Foo.bar)[0].bar == 5
        assert res.order_by(sa.desc(Foo.bar))[0].bar == 95

    def test_slice(self):
        """Query indexing / slicing matches slicing the full result list."""
        Foo = self.classes.Foo

        sess = fixture_session()
        query = sess.query(Foo).order_by(Foo.id)
        orig = query.all()

        assert query[1] == orig[1]

        assert list(query[10:20]) == orig[10:20]
        assert list(query[10:]) == orig[10:]
        assert list(query[:10]) == orig[:10]
        assert list(query[5:5]) == orig[5:5]
        assert list(query[10:40:3]) == orig[10:40:3]

        # negative slices and indexes are deprecated and are tested
        # in test_query.py and test_deprecations.py

        assert query[10:20][5] == orig[10:20][5]

    def test_aggregate(self):
        """count(), min(), max() and with_entities() aggregates."""
        foo, Foo = self.tables.foo, self.classes.Foo

        sess = fixture_session()
        query = sess.query(Foo)
        assert query.count() == 100
        assert sess.query(func.min(foo.c.bar)).filter(
            foo.c.bar < 30
        ).one() == (0,)

        assert sess.query(func.max(foo.c.bar)).filter(
            foo.c.bar < 30
        ).one() == (29,)

        eq_(
            query.filter(foo.c.bar < 30)
            .with_entities(sa.func.max(foo.c.bar))
            .scalar(),
            29,
        )

    @testing.fails_if(
        lambda: testing.against("mysql+mysqldb")
        and testing.db.dialect.dbapi.version_info[:4] == (1, 2, 1, "gamma"),
        "unknown incompatibility",
    )
    def test_aggregate_1(self):
        """sum() over bar < 30, i.e. 0 + 1 + ... + 29 == 435."""
        foo = self.tables.foo

        query = fixture_session().query(func.sum(foo.c.bar))
        assert query.filter(foo.c.bar < 30).one() == (435,)

    @testing.fails_on(
        "mssql",
        "AVG produces an average as the original column type on mssql.",
    )
    def test_aggregate_2(self):
        """avg() over bar < 30 selected as the query's only entity."""
        foo = self.tables.foo

        query = fixture_session().query(func.avg(foo.c.bar))
        avg = query.filter(foo.c.bar < 30).one()[0]
        eq_(float(round(avg, 1)), 14.5)

    @testing.fails_on(
        "mssql",
        "AVG produces an average as the original column type on mssql.",
    )
    def test_aggregate_3(self):
        """avg() via with_entities(), by table column and mapped attribute."""
        foo, Foo = self.tables.foo, self.classes.Foo

        query = fixture_session().query(Foo)

        # aggregate expressed against the Table column
        avg_f = (
            query.filter(foo.c.bar < 30)
            .with_entities(sa.func.avg(foo.c.bar))
            .scalar()
        )
        eq_(float(round(avg_f, 1)), 14.5)

        # same aggregate expressed against the mapped attribute Foo.bar;
        # previously this repeated the table-column query verbatim
        avg_o = (
            query.filter(Foo.bar < 30)
            .with_entities(sa.func.avg(Foo.bar))
            .scalar()
        )
        eq_(float(round(avg_o, 1)), 14.5)

    def test_filter(self):
        """Successive filter() calls are ANDed together."""
        Foo = self.classes.Foo

        query = fixture_session().query(Foo)
        assert query.count() == 100
        assert query.filter(Foo.bar < 30).count() == 30
        res2 = query.filter(Foo.bar < 30).filter(Foo.bar > 10)
        # 11..29 inclusive
        assert res2.count() == 19

    def test_order_by(self):
        """Ascending and descending order_by()."""
        Foo = self.classes.Foo

        query = fixture_session().query(Foo)
        assert query.order_by(Foo.bar)[0].bar == 0
        assert query.order_by(sa.desc(Foo.bar))[0].bar == 99

    def test_offset_order_by(self):
        """offset() applied to an ordered query skips the first rows."""
        Foo = self.classes.Foo

        query = fixture_session().query(Foo)
        assert list(query.order_by(Foo.bar).offset(10))[0].bar == 10

    def test_offset(self):
        """limit() alone, and limit() combined with offset()."""
        Foo = self.classes.Foo

        query = fixture_session().query(Foo)
        assert len(list(query.limit(10))) == 10
        # skipping 95 of 100 rows leaves fewer rows than the LIMIT
        assert len(list(query.order_by(Foo.bar).limit(10).offset(95))) == 5


class GenerativeTest2(fixtures.MappedTest):
    """count() with and without distinct() across an implicit join."""

    @classmethod
    def define_tables(cls, metadata):
        Table("table1", metadata, Column("id", Integer, primary_key=True))
        Table(
            "table2",
            metadata,
            Column("t1id", Integer, ForeignKey("table1.id"), primary_key=True),
            Column("num", Integer, primary_key=True),
        )

    @classmethod
    def setup_mappers(cls):
        table2, table1 = cls.tables.table2, cls.tables.table1

        class Obj1(cls.Basic):
            pass

        class Obj2(cls.Basic):
            pass

        cls.mapper_registry.map_imperatively(Obj1, table1)
        cls.mapper_registry.map_imperatively(Obj2, table2)

    @classmethod
    def fixtures(cls):
        return dict(
            table1=(("id",), (1,), (2,), (3,), (4,)),
            table2=(
                ("num", "t1id"),
                (1, 1),
                (2, 1),
                (3, 1),
                (4, 2),
                (5, 2),
                (6, 3),
            ),
        )

    def test_distinct_count(self):
        """Three table2 rows reference table1.id == 1; distinct() folds them."""
        table2, Obj1, table1 = (
            self.tables.table2,
            self.classes.Obj1,
            self.tables.table1,
        )

        query = fixture_session().query(Obj1)
        eq_(query.count(), 4)

        res = query.filter(
            sa.and_(table1.c.id == table2.c.t1id, table2.c.t1id == 1)
        )
        eq_(res.count(), 3)
        res = query.filter(
            sa.and_(table1.c.id == table2.c.t1id, table2.c.t1id == 1)
        ).distinct()
        eq_(res.count(), 1)


class RelationshipsTest(_fixtures.FixtureTest):
    """Query.join / Query.outerjoin / select_from along User->Order->Address.

    Mappers are set up once with ``User.orders`` and ``Order.addresses``
    relationships; fixture rows are inserted once and not deleted.
    """

    run_setup_mappers = "once"
    run_inserts = "once"
    run_deletes = None

    @classmethod
    def setup_mappers(cls):
        addresses, Order, User, Address, orders, users = (
            cls.tables.addresses,
            cls.classes.Order,
            cls.classes.User,
            cls.classes.Address,
            cls.tables.orders,
            cls.tables.users,
        )

        cls.mapper_registry.map_imperatively(
            User,
            users,
            properties={
                "orders": relationship(
                    cls.mapper_registry.map_imperatively(
                        Order,
                        orders,
                        properties={
                            "addresses": relationship(
                                cls.mapper_registry.map_imperatively(
                                    Address, addresses
                                )
                            )
                        },
                    )
                )
            },
        )

    def test_join(self):
        """Query.join"""

        User, Address = self.classes.User, self.classes.Address
        Order = self.classes.Order

        session = fixture_session()
        # inner joins: previously this used outerjoin(), leaving
        # Query.join itself untested.  With Address.id == 1 filtered,
        # the NULL-extended rows an outer join adds are discarded anyway.
        q = (
            session.query(User)
            .join(User.orders)
            .join(Order.addresses)
            .filter(Address.id == 1)
        )
        eq_([User(id=7)], q.all())

    def test_outer_join(self):
        """Query.outerjoin"""

        Order, User, Address = (
            self.classes.Order,
            self.classes.User,
            self.classes.Address,
        )

        session = fixture_session()
        # keep users with no orders (NULL-extended) or an order at address 1
        q = (
            session.query(User)
            .outerjoin(User.orders)
            .outerjoin(Order.addresses)
            .filter(sa.or_(Order.id.is_(None), Address.id == 1))
        )
        eq_({User(id=7), User(id=8), User(id=10)}, set(q.all()))

    def test_outer_join_count(self):
        """test the join and outerjoin functions on Query"""

        Order, User, Address = (
            self.classes.Order,
            self.classes.User,
            self.classes.Address,
        )

        session = fixture_session()

        q = (
            session.query(User)
            .outerjoin(User.orders)
            .outerjoin(Order.addresses)
            .filter(sa.or_(Order.id.is_(None), Address.id == 1))
        )
        eq_(q.count(), 4)

    def test_from(self):
        """select_from() an explicit Core outer join of the three tables."""
        users, Order, User, Address, orders, addresses = (
            self.tables.users,
            self.classes.Order,
            self.classes.User,
            self.classes.Address,
            self.tables.orders,
            self.tables.addresses,
        )

        session = fixture_session()

        sel = users.outerjoin(orders).outerjoin(
            addresses, orders.c.address_id == addresses.c.id
        )
        q = (
            session.query(User)
            .select_from(sel)
            .filter(sa.or_(Order.id.is_(None), Address.id == 1))
        )
        eq_({User(id=7), User(id=8), User(id=10)}, set(q.all()))


class CaseSensitiveTest(fixtures.MappedTest):
    """Same distinct-count scenario as GenerativeTest2, with mixed-case
    table and column names."""

    @classmethod
    def define_tables(cls, metadata):
        Table("Table1", metadata, Column("ID", Integer, primary_key=True))
        Table(
            "Table2",
            metadata,
            Column("T1ID", Integer, ForeignKey("Table1.ID"), primary_key=True),
            Column("NUM", Integer, primary_key=True),
        )

    @classmethod
    def setup_mappers(cls):
        Table2, Table1 = cls.tables.Table2, cls.tables.Table1

        class Obj1(cls.Basic):
            pass

        class Obj2(cls.Basic):
            pass

        cls.mapper_registry.map_imperatively(Obj1, Table1)
        cls.mapper_registry.map_imperatively(Obj2, Table2)

    @classmethod
    def fixtures(cls):
        return dict(
            Table1=(("ID",), (1,), (2,), (3,), (4,)),
            Table2=(
                ("NUM", "T1ID"),
                (1, 1),
                (2, 1),
                (3, 1),
                (4, 2),
                (5, 2),
                (6, 3),
            ),
        )

    def test_distinct_count(self):
        """Three Table2 rows reference Table1.ID == 1; distinct() folds them."""
        Table2, Obj1, Table1 = (
            self.tables.Table2,
            self.classes.Obj1,
            self.tables.Table1,
        )

        q = fixture_session().query(Obj1)
        # eq_ (as in GenerativeTest2) reports both values on failure
        eq_(q.count(), 4)
        res = q.filter(
            sa.and_(Table1.c.ID == Table2.c.T1ID, Table2.c.T1ID == 1)
        )
        eq_(res.count(), 3)
        res = q.filter(
            sa.and_(Table1.c.ID == Table2.c.T1ID, Table2.c.T1ID == 1)
        ).distinct()
        eq_(res.count(), 1)