diff options
Diffstat (limited to 'lib/sqlalchemy/engine/cursor.py')
-rw-r--r-- | lib/sqlalchemy/engine/cursor.py | 390 |
1 files changed, 336 insertions, 54 deletions
diff --git a/lib/sqlalchemy/engine/cursor.py b/lib/sqlalchemy/engine/cursor.py index 8840b5916..07e782296 100644 --- a/lib/sqlalchemy/engine/cursor.py +++ b/lib/sqlalchemy/engine/cursor.py @@ -23,12 +23,14 @@ from typing import Iterator from typing import List from typing import NoReturn from typing import Optional +from typing import overload from typing import Sequence from typing import Tuple from typing import TYPE_CHECKING from typing import TypeVar from typing import Union +from .result import IteratorResult from .result import MergedResult from .result import Result from .result import ResultMetaData @@ -62,36 +64,80 @@ if typing.TYPE_CHECKING: from .interfaces import ExecutionContext from .result import _KeyIndexType from .result import _KeyMapRecType + from .result import _KeyMapType from .result import _KeyType from .result import _ProcessorsType + from .result import _TupleGetterType from ..sql.type_api import _ResultProcessorType _T = TypeVar("_T", bound=Any) + # metadata entry tuple indexes. # using raw tuple is faster than namedtuple. -MD_INDEX: Literal[0] = 0 # integer index in cursor.description -MD_RESULT_MAP_INDEX: Literal[ - 1 -] = 1 # integer index in compiled._result_columns -MD_OBJECTS: Literal[ - 2 -] = 2 # other string keys and ColumnElement obj that can match -MD_LOOKUP_KEY: Literal[ - 3 -] = 3 # string key we usually expect for key-based lookup -MD_RENDERED_NAME: Literal[4] = 4 # name that is usually in cursor.description -MD_PROCESSOR: Literal[5] = 5 # callable to process a result value into a row -MD_UNTRANSLATED: Literal[6] = 6 # raw name from cursor.description +# these match up to the positions in +# _CursorKeyMapRecType +MD_INDEX: Literal[0] = 0 +"""integer index in cursor.description + +""" + +MD_RESULT_MAP_INDEX: Literal[1] = 1 +"""integer index in compiled._result_columns""" + +MD_OBJECTS: Literal[2] = 2 +"""other string keys and ColumnElement obj that can match. + +This comes from compiler.RM_OBJECTS / compiler.ResultColumnsEntry.objects + +""" + +MD_LOOKUP_KEY: Literal[3] = 3 +"""string key we usually expect for key-based lookup + +this comes from compiler.RM_NAME / compiler.ResultColumnsEntry.name +""" + + +MD_RENDERED_NAME: Literal[4] = 4 +"""name that is usually in cursor.description + +this comes from compiler.RENDERED_NAME / compiler.ResultColumnsEntry.keyname +""" + + +MD_PROCESSOR: Literal[5] = 5 +"""callable to process a result value into a row""" + +MD_UNTRANSLATED: Literal[6] = 6 +"""raw name from cursor.description""" _CursorKeyMapRecType = Tuple[ - int, int, List[Any], str, str, Optional["_ResultProcessorType"], str + Optional[int], # MD_INDEX, None means the record is ambiguously named + int, # MD_RESULT_MAP_INDEX + List[Any], # MD_OBJECTS + str, # MD_LOOKUP_KEY + str, # MD_RENDERED_NAME + Optional["_ResultProcessorType"], # MD_PROCESSOR + Optional[str], # MD_UNTRANSLATED ] _CursorKeyMapType = Dict["_KeyType", _CursorKeyMapRecType] +# same as _CursorKeyMapRecType except the MD_INDEX value is definitely +# not None +_NonAmbigCursorKeyMapRecType = Tuple[ + int, + int, + List[Any], + str, + str, + Optional["_ResultProcessorType"], + str, +] + class CursorResultMetaData(ResultMetaData): """Result metadata for DBAPI cursors.""" @@ -127,38 +173,112 @@ class CursorResultMetaData(ResultMetaData): extra=[self._keymap[key][MD_OBJECTS] for key in self._keys], ) - def _reduce(self, keys: Sequence[_KeyIndexType]) -> ResultMetaData: - recs = cast( - "List[_CursorKeyMapRecType]", list(self._metadata_for_keys(keys)) + def _make_new_metadata( + self, + *, + unpickled: bool, + processors: _ProcessorsType, + keys: Sequence[str], + keymap: _KeyMapType, + tuplefilter: Optional[_TupleGetterType], + translated_indexes: Optional[List[int]], + safe_for_cache: bool, + keymap_by_result_column_idx: Any, + ) -> CursorResultMetaData: + new_obj = self.__class__.__new__(self.__class__) + new_obj._unpickled = unpickled + new_obj._processors = processors + new_obj._keys = keys + new_obj._keymap = keymap + new_obj._tuplefilter = tuplefilter + new_obj._translated_indexes = translated_indexes + new_obj._safe_for_cache = safe_for_cache + new_obj._keymap_by_result_column_idx = keymap_by_result_column_idx + return new_obj + + def _remove_processors(self) -> CursorResultMetaData: + assert not self._tuplefilter + return self._make_new_metadata( + unpickled=self._unpickled, + processors=[None] * len(self._processors), + tuplefilter=None, + translated_indexes=None, + keymap={ + key: value[0:5] + (None,) + value[6:] + for key, value in self._keymap.items() + }, + keys=self._keys, + safe_for_cache=self._safe_for_cache, + keymap_by_result_column_idx=self._keymap_by_result_column_idx, ) + def _splice_horizontally( + self, other: CursorResultMetaData + ) -> CursorResultMetaData: + + assert not self._tuplefilter + + keymap = self._keymap.copy() + offset = len(self._keys) + keymap.update( + { + key: ( + # int index should be None for ambiguous key + value[0] + offset + if value[0] is not None and key not in keymap + else None, + value[1] + offset, + *value[2:], + ) + for key, value in other._keymap.items() + } + ) + + return self._make_new_metadata( + unpickled=self._unpickled, + processors=self._processors + other._processors, # type: ignore + tuplefilter=None, + translated_indexes=None, + keys=self._keys + other._keys, # type: ignore + keymap=keymap, + safe_for_cache=self._safe_for_cache, + keymap_by_result_column_idx={ + metadata_entry[MD_RESULT_MAP_INDEX]: metadata_entry + for metadata_entry in keymap.values() + }, + ) + + def _reduce(self, keys: Sequence[_KeyIndexType]) -> ResultMetaData: + recs = list(self._metadata_for_keys(keys)) + indexes = [rec[MD_INDEX] for rec in recs] new_keys: List[str] = [rec[MD_LOOKUP_KEY] for rec in recs] if self._translated_indexes: indexes = [self._translated_indexes[idx] for idx in indexes] tup = tuplegetter(*indexes) - - new_metadata = self.__class__.__new__(self.__class__) - new_metadata._unpickled = self._unpickled - new_metadata._processors = self._processors - new_metadata._keys = new_keys - new_metadata._tuplefilter = tup - new_metadata._translated_indexes = indexes - new_recs = [(index,) + rec[1:] for index, rec in enumerate(recs)] - new_metadata._keymap = {rec[MD_LOOKUP_KEY]: rec for rec in new_recs} + keymap: _KeyMapType = {rec[MD_LOOKUP_KEY]: rec for rec in new_recs} # TODO: need unit test for: # result = connection.execute("raw sql, no columns").scalars() # without the "or ()" it's failing because MD_OBJECTS is None - new_metadata._keymap.update( + keymap.update( (e, new_rec) for new_rec in new_recs for e in new_rec[MD_OBJECTS] or () ) - return new_metadata + return self._make_new_metadata( + unpickled=self._unpickled, + processors=self._processors, + keys=new_keys, + tuplefilter=tup, + translated_indexes=indexes, + keymap=keymap, + safe_for_cache=self._safe_for_cache, + keymap_by_result_column_idx=self._keymap_by_result_column_idx, + ) def _adapt_to_context(self, context: ExecutionContext) -> ResultMetaData: """When using a cached Compiled construct that has a _result_map, @@ -168,6 +288,7 @@ class CursorResultMetaData(ResultMetaData): as matched to those of the cached statement. """ + if not context.compiled or not context.compiled._result_columns: return self @@ -189,7 +310,6 @@ class CursorResultMetaData(ResultMetaData): # make a copy and add the columns from the invoked statement # to the result map. - md = self.__class__.__new__(self.__class__) keymap_by_position = self._keymap_by_result_column_idx @@ -201,26 +321,26 @@ class CursorResultMetaData(ResultMetaData): for metadata_entry in self._keymap.values() } - md._keymap = compat.dict_union( - self._keymap, - { - new: keymap_by_position[idx] - for idx, new in enumerate( - invoked_statement._all_selected_columns - ) - if idx in keymap_by_position - }, - ) - - md._unpickled = self._unpickled - md._processors = self._processors assert not self._tuplefilter - md._tuplefilter = None - md._translated_indexes = None - md._keys = self._keys - md._keymap_by_result_column_idx = self._keymap_by_result_column_idx - md._safe_for_cache = self._safe_for_cache - return md + return self._make_new_metadata( + keymap=compat.dict_union( + self._keymap, + { + new: keymap_by_position[idx] + for idx, new in enumerate( + invoked_statement._all_selected_columns + ) + if idx in keymap_by_position + }, + ), + unpickled=self._unpickled, + processors=self._processors, + tuplefilter=None, + translated_indexes=None, + keys=self._keys, + safe_for_cache=self._safe_for_cache, + keymap_by_result_column_idx=self._keymap_by_result_column_idx, + ) def __init__( self, @@ -683,7 +803,27 @@ class CursorResultMetaData(ResultMetaData): untranslated, ) - def _key_fallback(self, key, err, raiseerr=True): + @overload + def _key_fallback( + self, key: Any, err: Exception, raiseerr: Literal[True] = ... + ) -> NoReturn: + ... + + @overload + def _key_fallback( + self, key: Any, err: Exception, raiseerr: Literal[False] = ... + ) -> None: + ... + + @overload + def _key_fallback( + self, key: Any, err: Exception, raiseerr: bool = ... + ) -> Optional[NoReturn]: + ... + + def _key_fallback( + self, key: Any, err: Exception, raiseerr: bool = True + ) -> Optional[NoReturn]: if raiseerr: if self._unpickled and isinstance(key, elements.ColumnElement): @@ -714,9 +854,9 @@ class CursorResultMetaData(ResultMetaData): try: rec = self._keymap[key] except KeyError as ke: - rec = self._key_fallback(key, ke, raiseerr) - if rec is None: - return None + x = self._key_fallback(key, ke, raiseerr) + assert x is None + return None index = rec[0] @@ -734,7 +874,7 @@ class CursorResultMetaData(ResultMetaData): def _metadata_for_keys( self, keys: Sequence[Any] - ) -> Iterator[_CursorKeyMapRecType]: + ) -> Iterator[_NonAmbigCursorKeyMapRecType]: for key in keys: if int in key.__class__.__mro__: key = self._keys[key] @@ -750,7 +890,7 @@ class CursorResultMetaData(ResultMetaData): if index is None: self._raise_for_ambiguous_column_name(rec) - yield rec + yield cast(_NonAmbigCursorKeyMapRecType, rec) def __getstate__(self): return { @@ -1237,6 +1377,12 @@ _NO_RESULT_METADATA = _NoResultMetaData() SelfCursorResult = TypeVar("SelfCursorResult", bound="CursorResult[Any]") +def null_dml_result() -> IteratorResult[Any]: + it: IteratorResult[Any] = IteratorResult(_NoResultMetaData(), iter([])) + it._soft_close() + return it + + class CursorResult(Result[_T]): """A Result that is representing state from a DBAPI cursor. @@ -1586,6 +1732,142 @@ class CursorResult(Result[_T]): """ return self.context.returned_default_rows + def splice_horizontally(self, other): + """Return a new :class:`.CursorResult` that "horizontally splices" + together the rows of this :class:`.CursorResult` with that of another + :class:`.CursorResult`. + + .. tip:: This method is for the benefit of the SQLAlchemy ORM and is + not intended for general use. + + "horizontally splices" means that for each row in the first and second + result sets, a new row that concatenates the two rows together is + produced, which then becomes the new row. The incoming + :class:`.CursorResult` must have the identical number of rows. It is + typically expected that the two result sets come from the same sort + order as well, as the result rows are spliced together based on their + position in the result. + + The expected use case here is so that multiple INSERT..RETURNING + statements against different tables can produce a single result + that looks like a JOIN of those two tables. + + E.g.:: + + r1 = connection.execute( + users.insert().returning(users.c.user_name, users.c.user_id), + user_values + ) + + r2 = connection.execute( + addresses.insert().returning( + addresses.c.address_id, + addresses.c.address, + addresses.c.user_id, + ), + address_values + ) + + rows = r1.splice_horizontally(r2).all() + assert ( + rows == + [ + ("john", 1, 1, "foo@bar.com", 1), + ("jack", 2, 2, "bar@bat.com", 2), + ] + ) + + .. versionadded:: 2.0 + + .. seealso:: + + :meth:`.CursorResult.splice_vertically` + + + """ + + clone = self._generate() + total_rows = [ + tuple(r1) + tuple(r2) + for r1, r2 in zip( + list(self._raw_row_iterator()), + list(other._raw_row_iterator()), + ) + ] + + clone._metadata = clone._metadata._splice_horizontally(other._metadata) + + clone.cursor_strategy = FullyBufferedCursorFetchStrategy( + None, + initial_buffer=total_rows, + ) + clone._reset_memoizations() + return clone + + def splice_vertically(self, other): + """Return a new :class:`.CursorResult` that "vertically splices", + i.e. "extends", the rows of this :class:`.CursorResult` with that of + another :class:`.CursorResult`. + + .. tip:: This method is for the benefit of the SQLAlchemy ORM and is + not intended for general use. + + "vertically splices" means the rows of the given result are appended to + the rows of this cursor result. The incoming :class:`.CursorResult` + must have rows that represent the identical list of columns in the + identical order as they are in this :class:`.CursorResult`. + + .. versionadded:: 2.0 + + .. seealso:: + + :ref:`.CursorResult.splice_horizontally` + + """ + clone = self._generate() + total_rows = list(self._raw_row_iterator()) + list( + other._raw_row_iterator() + ) + + clone.cursor_strategy = FullyBufferedCursorFetchStrategy( + None, + initial_buffer=total_rows, + ) + clone._reset_memoizations() + return clone + + def _rewind(self, rows): + """rewind this result back to the given rowset. + + this is used internally for the case where an :class:`.Insert` + construct combines the use of + :meth:`.Insert.return_defaults` along with the + "supplemental columns" feature. + + """ + + if self._echo: + self.context.connection._log_debug( + "CursorResult rewound %d row(s)", len(rows) + ) + + # the rows given are expected to be Row objects, so we + # have to clear out processors which have already run on these + # rows + self._metadata = cast( + CursorResultMetaData, self._metadata + )._remove_processors() + + self.cursor_strategy = FullyBufferedCursorFetchStrategy( + None, + # TODO: if these are Row objects, can we save on not having to + # re-make new Row objects out of them a second time? is that + # what's actually happening right now? maybe look into this + initial_buffer=rows, + ) + self._reset_memoizations() + return self + @property def returned_defaults(self): """Return the values of default columns that were fetched using |