diff options
Diffstat (limited to 'lib/sqlalchemy/engine')
-rw-r--r-- | lib/sqlalchemy/engine/cursor.py | 116 | ||||
-rw-r--r-- | lib/sqlalchemy/engine/default.py | 7 | ||||
-rw-r--r-- | lib/sqlalchemy/engine/interfaces.py | 25 |
3 files changed, 113 insertions, 35 deletions
diff --git a/lib/sqlalchemy/engine/cursor.py b/lib/sqlalchemy/engine/cursor.py index f4e22df2d..d5f0d8126 100644 --- a/lib/sqlalchemy/engine/cursor.py +++ b/lib/sqlalchemy/engine/cursor.py @@ -21,6 +21,7 @@ from typing import ClassVar from typing import Dict from typing import Iterator from typing import List +from typing import NoReturn from typing import Optional from typing import Sequence from typing import Tuple @@ -53,7 +54,11 @@ _UNPICKLED = util.symbol("unpickled") if typing.TYPE_CHECKING: + from .base import Connection + from .default import DefaultExecutionContext from .interfaces import _DBAPICursorDescription + from .interfaces import DBAPICursor + from .interfaces import Dialect from .interfaces import ExecutionContext from .result import _KeyIndexType from .result import _KeyMapRecType @@ -61,6 +66,7 @@ if typing.TYPE_CHECKING: from .result import _ProcessorsType from ..sql.type_api import _ResultProcessorType + _T = TypeVar("_T", bound=Any) # metadata entry tuple indexes. @@ -235,7 +241,7 @@ class CursorResultMetaData(ResultMetaData): ) = context.result_column_struct num_ctx_cols = len(result_columns) else: - result_columns = ( + result_columns = ( # type: ignore cols_are_ordered ) = ( num_ctx_cols @@ -776,25 +782,53 @@ class ResultFetchStrategy: alternate_cursor_description: Optional[_DBAPICursorDescription] = None - def soft_close(self, result, dbapi_cursor): + def soft_close( + self, result: CursorResult[Any], dbapi_cursor: Optional[DBAPICursor] + ) -> None: raise NotImplementedError() - def hard_close(self, result, dbapi_cursor): + def hard_close( + self, result: CursorResult[Any], dbapi_cursor: Optional[DBAPICursor] + ) -> None: raise NotImplementedError() - def yield_per(self, result, dbapi_cursor, num): + def yield_per( + self, + result: CursorResult[Any], + dbapi_cursor: Optional[DBAPICursor], + num: int, + ) -> None: return - def fetchone(self, result, dbapi_cursor, hard_close=False): + def fetchone( + self, + result: CursorResult[Any], + dbapi_cursor: DBAPICursor, + hard_close: bool = False, + ) -> Any: raise NotImplementedError() - def fetchmany(self, result, dbapi_cursor, size=None): + def fetchmany( + self, + result: CursorResult[Any], + dbapi_cursor: DBAPICursor, + size: Optional[int] = None, + ) -> Any: raise NotImplementedError() - def fetchall(self, result): + def fetchall( + self, + result: CursorResult[Any], + dbapi_cursor: DBAPICursor, + ) -> Any: raise NotImplementedError() - def handle_exception(self, result, dbapi_cursor, err): + def handle_exception( + self, + result: CursorResult[Any], + dbapi_cursor: Optional[DBAPICursor], + err: BaseException, + ) -> NoReturn: raise err @@ -882,18 +916,32 @@ class CursorFetchStrategy(ResultFetchStrategy): __slots__ = () - def soft_close(self, result, dbapi_cursor): + def soft_close( + self, result: CursorResult[Any], dbapi_cursor: Optional[DBAPICursor] + ) -> None: result.cursor_strategy = _NO_CURSOR_DQL - def hard_close(self, result, dbapi_cursor): + def hard_close( + self, result: CursorResult[Any], dbapi_cursor: Optional[DBAPICursor] + ) -> None: result.cursor_strategy = _NO_CURSOR_DQL - def handle_exception(self, result, dbapi_cursor, err): + def handle_exception( + self, + result: CursorResult[Any], + dbapi_cursor: Optional[DBAPICursor], + err: BaseException, + ) -> NoReturn: result.connection._handle_dbapi_exception( err, None, None, dbapi_cursor, result.context ) - def yield_per(self, result, dbapi_cursor, num): + def yield_per( + self, + result: CursorResult[Any], + dbapi_cursor: Optional[DBAPICursor], + num: int, + ) -> None: result.cursor_strategy = BufferedRowCursorFetchStrategy( dbapi_cursor, {"max_row_buffer": num}, @@ -901,7 +949,12 @@ class CursorFetchStrategy(ResultFetchStrategy): growth_factor=0, ) - def fetchone(self, result, dbapi_cursor, hard_close=False): + def fetchone( + self, + result: CursorResult[Any], + dbapi_cursor: DBAPICursor, + hard_close: bool = False, + ) -> Any: try: row = dbapi_cursor.fetchone() if row is None: @@ -910,7 +963,12 @@ class CursorFetchStrategy(ResultFetchStrategy): except BaseException as e: self.handle_exception(result, dbapi_cursor, e) - def fetchmany(self, result, dbapi_cursor, size=None): + def fetchmany( + self, + result: CursorResult[Any], + dbapi_cursor: DBAPICursor, + size: Optional[int] = None, + ) -> Any: try: if size is None: l = dbapi_cursor.fetchmany() @@ -923,7 +981,11 @@ class CursorFetchStrategy(ResultFetchStrategy): except BaseException as e: self.handle_exception(result, dbapi_cursor, e) - def fetchall(self, result, dbapi_cursor): + def fetchall( + self, + result: CursorResult[Any], + dbapi_cursor: DBAPICursor, + ) -> Any: try: rows = dbapi_cursor.fetchall() result._soft_close() @@ -1163,6 +1225,9 @@ class _NoResultMetaData(ResultMetaData): _NO_RESULT_METADATA = _NoResultMetaData() +SelfCursorResult = TypeVar("SelfCursorResult", bound="CursorResult[Any]") + + class CursorResult(Result[_T]): """A Result that is representing state from a DBAPI cursor. @@ -1199,7 +1264,17 @@ class CursorResult(Result[_T]): closed: bool = False _is_cursor = True - def __init__(self, context, cursor_strategy, cursor_description): + context: DefaultExecutionContext + dialect: Dialect + cursor_strategy: ResultFetchStrategy + connection: Connection + + def __init__( + self, + context: DefaultExecutionContext, + cursor_strategy: ResultFetchStrategy, + cursor_description: Optional[_DBAPICursorDescription], + ): self.context = context self.dialect = context.dialect self.cursor = context.cursor @@ -1333,7 +1408,7 @@ class CursorResult(Result[_T]): if not self._soft_closed: cursor = self.cursor - self.cursor = None + self.cursor = None # type: ignore self.connection._safe_close_cursor(cursor) self._soft_closed = True @@ -1605,7 +1680,7 @@ class CursorResult(Result[_T]): return self.dialect.supports_sane_multi_rowcount @util.memoized_property - def rowcount(self): + def rowcount(self) -> int: """Return the 'rowcount' for this result. The 'rowcount' reports the number of rows *matched* @@ -1655,6 +1730,7 @@ class CursorResult(Result[_T]): return self.context.rowcount except BaseException as e: self.cursor_strategy.handle_exception(self, self.cursor, e) + raise # not called @property def lastrowid(self): @@ -1749,7 +1825,7 @@ class CursorResult(Result[_T]): ) return merged_result - def close(self): + def close(self) -> Any: """Close this :class:`_engine.CursorResult`. This closes out the underlying DBAPI cursor corresponding to the @@ -1772,7 +1848,7 @@ class CursorResult(Result[_T]): self._soft_close(hard=True) @_generative - def yield_per(self, num): + def yield_per(self: SelfCursorResult, num: int) -> SelfCursorResult: self._yield_per = num self.cursor_strategy.yield_per(self, self.cursor, num) return self diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index 6094ad0fb..fc114efa3 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -64,6 +64,7 @@ if typing.TYPE_CHECKING: from .base import Engine from .interfaces import _CoreMultiExecuteParams from .interfaces import _CoreSingleExecuteParams + from .interfaces import _DBAPICursorDescription from .interfaces import _DBAPIMultiExecuteParams from .interfaces import _ExecuteOptions from .interfaces import _IsolationLevel @@ -1285,8 +1286,8 @@ class DefaultExecutionContext(ExecutionContext): def handle_dbapi_exception(self, e): pass - @property - def rowcount(self): + @util.non_memoized_property + def rowcount(self) -> int: return self.cursor.rowcount def supports_sane_rowcount(self): @@ -1304,7 +1305,7 @@ class DefaultExecutionContext(ExecutionContext): strategy = _cursor.BufferedRowCursorFetchStrategy( self.cursor, self.execution_options ) - cursor_description = ( + cursor_description: _DBAPICursorDescription = ( strategy.alternate_cursor_description or self.cursor.description ) diff --git a/lib/sqlalchemy/engine/interfaces.py b/lib/sqlalchemy/engine/interfaces.py index 641024603..e5414b70f 100644 --- a/lib/sqlalchemy/engine/interfaces.py +++ b/lib/sqlalchemy/engine/interfaces.py @@ -133,17 +133,7 @@ class DBAPICursor(Protocol): @property def description( self, - ) -> Sequence[ - Tuple[ - str, - "DBAPIType", - Optional[int], - Optional[int], - Optional[int], - Optional[int], - Optional[bool], - ] - ]: + ) -> _DBAPICursorDescription: """The description attribute of the Cursor. .. seealso:: @@ -217,7 +207,15 @@ _DBAPIMultiExecuteParams = Union[ _DBAPIAnyExecuteParams = Union[ _DBAPIMultiExecuteParams, _DBAPISingleExecuteParams ] -_DBAPICursorDescription = Tuple[str, Any, Any, Any, Any, Any, Any] +_DBAPICursorDescription = Tuple[ + str, + "DBAPIType", + Optional[int], + Optional[int], + Optional[int], + Optional[int], + Optional[bool], +] _AnySingleExecuteParams = _DBAPISingleExecuteParams _AnyMultiExecuteParams = _DBAPIMultiExecuteParams @@ -2297,6 +2295,9 @@ class ExecutionContext: """ + engine: Engine + """engine which the Connection is associated with""" + connection: Connection """Connection object which can be freely used by default value generators to execute SQL. This Connection should reference the |