summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/engine/cursor.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/engine/cursor.py')
-rw-r--r--lib/sqlalchemy/engine/cursor.py116
1 files changed, 96 insertions, 20 deletions
diff --git a/lib/sqlalchemy/engine/cursor.py b/lib/sqlalchemy/engine/cursor.py
index f4e22df2d..d5f0d8126 100644
--- a/lib/sqlalchemy/engine/cursor.py
+++ b/lib/sqlalchemy/engine/cursor.py
@@ -21,6 +21,7 @@ from typing import ClassVar
from typing import Dict
from typing import Iterator
from typing import List
+from typing import NoReturn
from typing import Optional
from typing import Sequence
from typing import Tuple
@@ -53,7 +54,11 @@ _UNPICKLED = util.symbol("unpickled")
if typing.TYPE_CHECKING:
+ from .base import Connection
+ from .default import DefaultExecutionContext
from .interfaces import _DBAPICursorDescription
+ from .interfaces import DBAPICursor
+ from .interfaces import Dialect
from .interfaces import ExecutionContext
from .result import _KeyIndexType
from .result import _KeyMapRecType
@@ -61,6 +66,7 @@ if typing.TYPE_CHECKING:
from .result import _ProcessorsType
from ..sql.type_api import _ResultProcessorType
+
_T = TypeVar("_T", bound=Any)
# metadata entry tuple indexes.
@@ -235,7 +241,7 @@ class CursorResultMetaData(ResultMetaData):
) = context.result_column_struct
num_ctx_cols = len(result_columns)
else:
- result_columns = (
+ result_columns = ( # type: ignore
cols_are_ordered
) = (
num_ctx_cols
@@ -776,25 +782,53 @@ class ResultFetchStrategy:
alternate_cursor_description: Optional[_DBAPICursorDescription] = None
- def soft_close(self, result, dbapi_cursor):
+ def soft_close(
+ self, result: CursorResult[Any], dbapi_cursor: Optional[DBAPICursor]
+ ) -> None:
raise NotImplementedError()
- def hard_close(self, result, dbapi_cursor):
+ def hard_close(
+ self, result: CursorResult[Any], dbapi_cursor: Optional[DBAPICursor]
+ ) -> None:
raise NotImplementedError()
- def yield_per(self, result, dbapi_cursor, num):
+ def yield_per(
+ self,
+ result: CursorResult[Any],
+ dbapi_cursor: Optional[DBAPICursor],
+ num: int,
+ ) -> None:
return
- def fetchone(self, result, dbapi_cursor, hard_close=False):
+ def fetchone(
+ self,
+ result: CursorResult[Any],
+ dbapi_cursor: DBAPICursor,
+ hard_close: bool = False,
+ ) -> Any:
raise NotImplementedError()
- def fetchmany(self, result, dbapi_cursor, size=None):
+ def fetchmany(
+ self,
+ result: CursorResult[Any],
+ dbapi_cursor: DBAPICursor,
+ size: Optional[int] = None,
+ ) -> Any:
raise NotImplementedError()
- def fetchall(self, result):
+ def fetchall(
+ self,
+ result: CursorResult[Any],
+ dbapi_cursor: DBAPICursor,
+ ) -> Any:
raise NotImplementedError()
- def handle_exception(self, result, dbapi_cursor, err):
+ def handle_exception(
+ self,
+ result: CursorResult[Any],
+ dbapi_cursor: Optional[DBAPICursor],
+ err: BaseException,
+ ) -> NoReturn:
raise err
@@ -882,18 +916,32 @@ class CursorFetchStrategy(ResultFetchStrategy):
__slots__ = ()
- def soft_close(self, result, dbapi_cursor):
+ def soft_close(
+ self, result: CursorResult[Any], dbapi_cursor: Optional[DBAPICursor]
+ ) -> None:
result.cursor_strategy = _NO_CURSOR_DQL
- def hard_close(self, result, dbapi_cursor):
+ def hard_close(
+ self, result: CursorResult[Any], dbapi_cursor: Optional[DBAPICursor]
+ ) -> None:
result.cursor_strategy = _NO_CURSOR_DQL
- def handle_exception(self, result, dbapi_cursor, err):
+ def handle_exception(
+ self,
+ result: CursorResult[Any],
+ dbapi_cursor: Optional[DBAPICursor],
+ err: BaseException,
+ ) -> NoReturn:
result.connection._handle_dbapi_exception(
err, None, None, dbapi_cursor, result.context
)
- def yield_per(self, result, dbapi_cursor, num):
+ def yield_per(
+ self,
+ result: CursorResult[Any],
+ dbapi_cursor: Optional[DBAPICursor],
+ num: int,
+ ) -> None:
result.cursor_strategy = BufferedRowCursorFetchStrategy(
dbapi_cursor,
{"max_row_buffer": num},
@@ -901,7 +949,12 @@ class CursorFetchStrategy(ResultFetchStrategy):
growth_factor=0,
)
- def fetchone(self, result, dbapi_cursor, hard_close=False):
+ def fetchone(
+ self,
+ result: CursorResult[Any],
+ dbapi_cursor: DBAPICursor,
+ hard_close: bool = False,
+ ) -> Any:
try:
row = dbapi_cursor.fetchone()
if row is None:
@@ -910,7 +963,12 @@ class CursorFetchStrategy(ResultFetchStrategy):
except BaseException as e:
self.handle_exception(result, dbapi_cursor, e)
- def fetchmany(self, result, dbapi_cursor, size=None):
+ def fetchmany(
+ self,
+ result: CursorResult[Any],
+ dbapi_cursor: DBAPICursor,
+ size: Optional[int] = None,
+ ) -> Any:
try:
if size is None:
l = dbapi_cursor.fetchmany()
@@ -923,7 +981,11 @@ class CursorFetchStrategy(ResultFetchStrategy):
except BaseException as e:
self.handle_exception(result, dbapi_cursor, e)
- def fetchall(self, result, dbapi_cursor):
+ def fetchall(
+ self,
+ result: CursorResult[Any],
+ dbapi_cursor: DBAPICursor,
+ ) -> Any:
try:
rows = dbapi_cursor.fetchall()
result._soft_close()
@@ -1163,6 +1225,9 @@ class _NoResultMetaData(ResultMetaData):
_NO_RESULT_METADATA = _NoResultMetaData()
+SelfCursorResult = TypeVar("SelfCursorResult", bound="CursorResult[Any]")
+
+
class CursorResult(Result[_T]):
"""A Result that is representing state from a DBAPI cursor.
@@ -1199,7 +1264,17 @@ class CursorResult(Result[_T]):
closed: bool = False
_is_cursor = True
- def __init__(self, context, cursor_strategy, cursor_description):
+ context: DefaultExecutionContext
+ dialect: Dialect
+ cursor_strategy: ResultFetchStrategy
+ connection: Connection
+
+ def __init__(
+ self,
+ context: DefaultExecutionContext,
+ cursor_strategy: ResultFetchStrategy,
+ cursor_description: Optional[_DBAPICursorDescription],
+ ):
self.context = context
self.dialect = context.dialect
self.cursor = context.cursor
@@ -1333,7 +1408,7 @@ class CursorResult(Result[_T]):
if not self._soft_closed:
cursor = self.cursor
- self.cursor = None
+ self.cursor = None # type: ignore
self.connection._safe_close_cursor(cursor)
self._soft_closed = True
@@ -1605,7 +1680,7 @@ class CursorResult(Result[_T]):
return self.dialect.supports_sane_multi_rowcount
@util.memoized_property
- def rowcount(self):
+ def rowcount(self) -> int:
"""Return the 'rowcount' for this result.
The 'rowcount' reports the number of rows *matched*
@@ -1655,6 +1730,7 @@ class CursorResult(Result[_T]):
return self.context.rowcount
except BaseException as e:
self.cursor_strategy.handle_exception(self, self.cursor, e)
+ raise # not called
@property
def lastrowid(self):
@@ -1749,7 +1825,7 @@ class CursorResult(Result[_T]):
)
return merged_result
- def close(self):
+ def close(self) -> Any:
"""Close this :class:`_engine.CursorResult`.
This closes out the underlying DBAPI cursor corresponding to the
@@ -1772,7 +1848,7 @@ class CursorResult(Result[_T]):
self._soft_close(hard=True)
@_generative
- def yield_per(self, num):
+ def yield_per(self: SelfCursorResult, num: int) -> SelfCursorResult:
self._yield_per = num
self.cursor_strategy.yield_per(self, self.cursor, num)
return self