blob: 223c7d9031ca642b4cc89a2447779d14d08893ee (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
|
from asyncio import current_task
import sqlalchemy as sa
from sqlalchemy import func
from sqlalchemy import select
from sqlalchemy.ext.asyncio import async_scoped_session
from sqlalchemy.ext.asyncio import AsyncSession as _AsyncSession
from sqlalchemy.testing import async_test
from sqlalchemy.testing import eq_
from sqlalchemy.testing import is_
from .test_session_py3k import AsyncFixture
class AsyncScopedSessionTest(AsyncFixture):
    """Tests for ``async_scoped_session`` scoped to the current asyncio task."""

    @async_test
    async def test_basic(self, async_engine):
        """Check registry identity and proxied session methods.

        Two calls to the scoped factory from the same task must return the
        same session object, bound to ``async_engine``.  The proxied ``add``,
        ``flush``, ``connection`` and ``delete`` calls then insert and remove
        a ``User`` row inside a transaction opened with ``begin()``, with the
        row count checked after each step.
        """
        AsyncSession = async_scoped_session(
            sa.orm.sessionmaker(async_engine, class_=_AsyncSession),
            # key the registry on the running asyncio task
            scopefunc=current_task,
        )

        try:
            some_async_session = AsyncSession()
            some_other_async_session = AsyncSession()

            # same task -> same registry entry -> same session object
            is_(some_async_session, some_other_async_session)
            is_(some_async_session.bind, async_engine)

            User = self.classes.User

            async with AsyncSession.begin():
                user_name = "scoped_async_session_u1"
                u1 = User(name=user_name)

                AsyncSession.add(u1)

                await AsyncSession.flush()

                # count through the session's own connection, i.e. inside
                # the same transaction the flush wrote into
                conn = await AsyncSession.connection()
                stmt = select(func.count(User.id)).where(
                    User.name == user_name
                )
                eq_(await conn.scalar(stmt), 1)

                await AsyncSession.delete(u1)
                await AsyncSession.flush()
                eq_(await conn.scalar(stmt), 0)
        finally:
            # Release the task-scoped session even if an assertion failed.
            # Unlike scoped_session.remove(), async_scoped_session.remove()
            # is a coroutine (per the SQLAlchemy docs it awaits the session's
            # close() and clears the registry), so it must be awaited.
            await AsyncSession.remove()
|