diff options
Diffstat (limited to 'lib/sqlalchemy/engine/cursor.py')
-rw-r--r-- | lib/sqlalchemy/engine/cursor.py | 116 |
1 files changed, 96 insertions, 20 deletions
diff --git a/lib/sqlalchemy/engine/cursor.py b/lib/sqlalchemy/engine/cursor.py index f4e22df2d..d5f0d8126 100644 --- a/lib/sqlalchemy/engine/cursor.py +++ b/lib/sqlalchemy/engine/cursor.py @@ -21,6 +21,7 @@ from typing import ClassVar from typing import Dict from typing import Iterator from typing import List +from typing import NoReturn from typing import Optional from typing import Sequence from typing import Tuple @@ -53,7 +54,11 @@ _UNPICKLED = util.symbol("unpickled") if typing.TYPE_CHECKING: + from .base import Connection + from .default import DefaultExecutionContext from .interfaces import _DBAPICursorDescription + from .interfaces import DBAPICursor + from .interfaces import Dialect from .interfaces import ExecutionContext from .result import _KeyIndexType from .result import _KeyMapRecType @@ -61,6 +66,7 @@ if typing.TYPE_CHECKING: from .result import _ProcessorsType from ..sql.type_api import _ResultProcessorType + _T = TypeVar("_T", bound=Any) # metadata entry tuple indexes. @@ -235,7 +241,7 @@ class CursorResultMetaData(ResultMetaData): ) = context.result_column_struct num_ctx_cols = len(result_columns) else: - result_columns = ( + result_columns = ( # type: ignore cols_are_ordered ) = ( num_ctx_cols @@ -776,25 +782,53 @@ class ResultFetchStrategy: alternate_cursor_description: Optional[_DBAPICursorDescription] = None - def soft_close(self, result, dbapi_cursor): + def soft_close( + self, result: CursorResult[Any], dbapi_cursor: Optional[DBAPICursor] + ) -> None: raise NotImplementedError() - def hard_close(self, result, dbapi_cursor): + def hard_close( + self, result: CursorResult[Any], dbapi_cursor: Optional[DBAPICursor] + ) -> None: raise NotImplementedError() - def yield_per(self, result, dbapi_cursor, num): + def yield_per( + self, + result: CursorResult[Any], + dbapi_cursor: Optional[DBAPICursor], + num: int, + ) -> None: return - def fetchone(self, result, dbapi_cursor, hard_close=False): + def fetchone( + self, + result: CursorResult[Any], + dbapi_cursor: DBAPICursor, + hard_close: bool = False, + ) -> Any: raise NotImplementedError() - def fetchmany(self, result, dbapi_cursor, size=None): + def fetchmany( + self, + result: CursorResult[Any], + dbapi_cursor: DBAPICursor, + size: Optional[int] = None, + ) -> Any: raise NotImplementedError() - def fetchall(self, result): + def fetchall( + self, + result: CursorResult[Any], + dbapi_cursor: DBAPICursor, + ) -> Any: raise NotImplementedError() - def handle_exception(self, result, dbapi_cursor, err): + def handle_exception( + self, + result: CursorResult[Any], + dbapi_cursor: Optional[DBAPICursor], + err: BaseException, + ) -> NoReturn: raise err @@ -882,18 +916,32 @@ class CursorFetchStrategy(ResultFetchStrategy): __slots__ = () - def soft_close(self, result, dbapi_cursor): + def soft_close( + self, result: CursorResult[Any], dbapi_cursor: Optional[DBAPICursor] + ) -> None: result.cursor_strategy = _NO_CURSOR_DQL - def hard_close(self, result, dbapi_cursor): + def hard_close( + self, result: CursorResult[Any], dbapi_cursor: Optional[DBAPICursor] + ) -> None: result.cursor_strategy = _NO_CURSOR_DQL - def handle_exception(self, result, dbapi_cursor, err): + def handle_exception( + self, + result: CursorResult[Any], + dbapi_cursor: Optional[DBAPICursor], + err: BaseException, + ) -> NoReturn: result.connection._handle_dbapi_exception( err, None, None, dbapi_cursor, result.context ) - def yield_per(self, result, dbapi_cursor, num): + def yield_per( + self, + result: CursorResult[Any], + dbapi_cursor: Optional[DBAPICursor], + num: int, + ) -> None: result.cursor_strategy = BufferedRowCursorFetchStrategy( dbapi_cursor, {"max_row_buffer": num}, @@ -901,7 +949,12 @@ class CursorFetchStrategy(ResultFetchStrategy): growth_factor=0, ) - def fetchone(self, result, dbapi_cursor, hard_close=False): + def fetchone( + self, + result: CursorResult[Any], + dbapi_cursor: DBAPICursor, + hard_close: bool = False, + ) -> Any: try: row = dbapi_cursor.fetchone() if row is None: @@ -910,7 +963,12 @@ class CursorFetchStrategy(ResultFetchStrategy): except BaseException as e: self.handle_exception(result, dbapi_cursor, e) - def fetchmany(self, result, dbapi_cursor, size=None): + def fetchmany( + self, + result: CursorResult[Any], + dbapi_cursor: DBAPICursor, + size: Optional[int] = None, + ) -> Any: try: if size is None: l = dbapi_cursor.fetchmany() @@ -923,7 +981,11 @@ class CursorFetchStrategy(ResultFetchStrategy): except BaseException as e: self.handle_exception(result, dbapi_cursor, e) - def fetchall(self, result, dbapi_cursor): + def fetchall( + self, + result: CursorResult[Any], + dbapi_cursor: DBAPICursor, + ) -> Any: try: rows = dbapi_cursor.fetchall() result._soft_close() @@ -1163,6 +1225,9 @@ class _NoResultMetaData(ResultMetaData): _NO_RESULT_METADATA = _NoResultMetaData() +SelfCursorResult = TypeVar("SelfCursorResult", bound="CursorResult[Any]") + + class CursorResult(Result[_T]): """A Result that is representing state from a DBAPI cursor. @@ -1199,7 +1264,17 @@ class CursorResult(Result[_T]): closed: bool = False _is_cursor = True - def __init__(self, context, cursor_strategy, cursor_description): + context: DefaultExecutionContext + dialect: Dialect + cursor_strategy: ResultFetchStrategy + connection: Connection + + def __init__( + self, + context: DefaultExecutionContext, + cursor_strategy: ResultFetchStrategy, + cursor_description: Optional[_DBAPICursorDescription], + ): self.context = context self.dialect = context.dialect self.cursor = context.cursor @@ -1333,7 +1408,7 @@ class CursorResult(Result[_T]): if not self._soft_closed: cursor = self.cursor - self.cursor = None + self.cursor = None # type: ignore self.connection._safe_close_cursor(cursor) self._soft_closed = True @@ -1605,7 +1680,7 @@ class CursorResult(Result[_T]): return self.dialect.supports_sane_multi_rowcount @util.memoized_property - def rowcount(self): + def rowcount(self) -> int: """Return the 'rowcount' for this result. The 'rowcount' reports the number of rows *matched* @@ -1655,6 +1730,7 @@ class CursorResult(Result[_T]): return self.context.rowcount except BaseException as e: self.cursor_strategy.handle_exception(self, self.cursor, e) + raise # not called @property def lastrowid(self): @@ -1749,7 +1825,7 @@ class CursorResult(Result[_T]): ) return merged_result - def close(self): + def close(self) -> Any: """Close this :class:`_engine.CursorResult`. This closes out the underlying DBAPI cursor corresponding to the @@ -1772,7 +1848,7 @@ class CursorResult(Result[_T]): self._soft_close(hard=True) @_generative - def yield_per(self, num): + def yield_per(self: SelfCursorResult, num: int) -> SelfCursorResult: self._yield_per = num self.cursor_strategy.yield_per(self, self.cursor, num) return self |