summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/ext
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/ext')
-rw-r--r--lib/sqlalchemy/ext/asyncio/engine.py19
-rw-r--r--lib/sqlalchemy/ext/asyncio/result.py22
-rw-r--r--lib/sqlalchemy/ext/asyncio/session.py4
3 files changed, 28 insertions, 17 deletions
diff --git a/lib/sqlalchemy/ext/asyncio/engine.py b/lib/sqlalchemy/ext/asyncio/engine.py
index 9bbc04e77..fcf3b974d 100644
--- a/lib/sqlalchemy/ext/asyncio/engine.py
+++ b/lib/sqlalchemy/ext/asyncio/engine.py
@@ -7,6 +7,7 @@
from . import exc as async_exc
from .base import ProxyComparable
from .base import StartableContext
+from .result import _ensure_sync_result
from .result import AsyncResult
from ... import exc
from ... import inspection
@@ -381,15 +382,8 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable):
execution_options,
_require_await=True,
)
- if result.context._is_server_side:
- raise async_exc.AsyncMethodRequired(
- "Can't use the connection.exec_driver_sql() method with a "
- "server-side cursor."
- "Use the connection.stream() method for an async "
- "streaming result set."
- )
- return result
+ return await _ensure_sync_result(result, self.exec_driver_sql)
async def stream(
self,
@@ -462,14 +456,7 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable):
execution_options,
_require_await=True,
)
- if result.context._is_server_side:
- raise async_exc.AsyncMethodRequired(
- "Can't use the connection.execute() method with a "
- "server-side cursor."
- "Use the connection.stream() method for an async "
- "streaming result set."
- )
- return result
+ return await _ensure_sync_result(result, self.execute)
async def scalar(
self,
diff --git a/lib/sqlalchemy/ext/asyncio/result.py b/lib/sqlalchemy/ext/asyncio/result.py
index 81ef9915c..62e4a9a0e 100644
--- a/lib/sqlalchemy/ext/asyncio/result.py
+++ b/lib/sqlalchemy/ext/asyncio/result.py
@@ -7,6 +7,7 @@
import operator
+from . import exc as async_exc
from ...engine.result import _NO_ROW
from ...engine.result import FilterResult
from ...engine.result import FrozenResult
@@ -646,3 +647,24 @@ class AsyncMappingResult(AsyncCommon):
"""
return await greenlet_spawn(self._only_one_row, True, True, False)
+
+
+async def _ensure_sync_result(result, calling_method):
+ if not result._is_cursor:
+ cursor_result = getattr(result, "raw", None)
+ else:
+ cursor_result = result
+ if cursor_result and cursor_result.context._is_server_side:
+ await greenlet_spawn(cursor_result.close)
+ raise async_exc.AsyncMethodRequired(
+ "Can't use the %s.%s() method with a "
+ "server-side cursor. "
+ "Use the %s.stream() method for an async "
+ "streaming result set."
+ % (
+ calling_method.__self__.__class__.__name__,
+ calling_method.__name__,
+ calling_method.__self__.__class__.__name__,
+ )
+ )
+ return result
diff --git a/lib/sqlalchemy/ext/asyncio/session.py b/lib/sqlalchemy/ext/asyncio/session.py
index 0840a0d7d..22de2cab1 100644
--- a/lib/sqlalchemy/ext/asyncio/session.py
+++ b/lib/sqlalchemy/ext/asyncio/session.py
@@ -8,6 +8,7 @@ from . import engine
from . import result as _result
from .base import ReversibleProxy
from .base import StartableContext
+from .result import _ensure_sync_result
from ... import util
from ...orm import object_session
from ...orm import Session
@@ -208,7 +209,7 @@ class AsyncSession(ReversibleProxy):
else:
execution_options = _EXECUTE_OPTIONS
- return await greenlet_spawn(
+ result = await greenlet_spawn(
self.sync_session.execute,
statement,
params=params,
@@ -216,6 +217,7 @@ class AsyncSession(ReversibleProxy):
bind_arguments=bind_arguments,
**kw,
)
+ return await _ensure_sync_result(result, self.execute)
async def scalar(
self,