summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
Diffstat (limited to 'test')
-rw-r--r--test/ext/asyncio/test_session_py3k.py50
1 files changed, 46 insertions, 4 deletions
diff --git a/test/ext/asyncio/test_session_py3k.py b/test/ext/asyncio/test_session_py3k.py
index a0aaf7ee0..459d95ea6 100644
--- a/test/ext/asyncio/test_session_py3k.py
+++ b/test/ext/asyncio/test_session_py3k.py
@@ -14,11 +14,13 @@ from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.ext.asyncio.base import ReversibleProxy
from sqlalchemy.orm import relationship
from sqlalchemy.orm import selectinload
+from sqlalchemy.orm import Session
from sqlalchemy.orm import sessionmaker
from sqlalchemy.testing import async_test
from sqlalchemy.testing import engines
from sqlalchemy.testing import eq_
from sqlalchemy.testing import is_
+from sqlalchemy.testing import is_true
from sqlalchemy.testing import mock
from .test_engine_py3k import AsyncFixture as _AsyncFixture
from ...orm import _fixtures
@@ -722,8 +724,6 @@ class AsyncProxyTest(AsyncFixture):
is_(inspect(u3).async_session, None)
def test_inspect_session_no_asyncio_used(self):
- from sqlalchemy.orm import Session
-
User = self.classes.User
s1 = Session(testing.db)
@@ -732,8 +732,6 @@ class AsyncProxyTest(AsyncFixture):
is_(inspect(u1).async_session, None)
def test_inspect_session_no_asyncio_imported(self):
- from sqlalchemy.orm import Session
-
with mock.patch("sqlalchemy.orm.state._async_provider", None):
User = self.classes.User
@@ -756,3 +754,47 @@ class AsyncProxyTest(AsyncFixture):
del async_session
eq_(len(ReversibleProxy._proxy_objects), 0)
+
+
+class _MySession(Session):
+ pass
+
+
+class _MyAS(AsyncSession):
+ sync_session_class = _MySession
+
+
+class OverrideSyncSession(AsyncFixture):
+ def test_default(self, async_engine):
+ ass = AsyncSession(async_engine)
+
+ is_true(isinstance(ass.sync_session, Session))
+ is_(ass.sync_session.__class__, Session)
+ is_(ass.sync_session_class, Session)
+
+ def test_init_class(self, async_engine):
+ ass = AsyncSession(async_engine, sync_session_class=_MySession)
+
+ is_true(isinstance(ass.sync_session, _MySession))
+ is_(ass.sync_session_class, _MySession)
+
+ def test_init_sessionmaker(self, async_engine):
+ sm = sessionmaker(
+ async_engine, class_=AsyncSession, sync_session_class=_MySession
+ )
+ ass = sm()
+
+ is_true(isinstance(ass.sync_session, _MySession))
+ is_(ass.sync_session_class, _MySession)
+
+ def test_subclass(self, async_engine):
+ ass = _MyAS(async_engine)
+
+ is_true(isinstance(ass.sync_session, _MySession))
+ is_(ass.sync_session_class, _MySession)
+
+ def test_subclass_override(self, async_engine):
+ ass = _MyAS(async_engine, sync_session_class=Session)
+
+ is_true(not isinstance(ass.sync_session, _MySession))
+ is_(ass.sync_session_class, Session)