diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2022-03-20 16:39:36 -0400 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2022-03-24 16:57:30 -0400 |
commit | 6f02d5edd88fe2475629438b0730181a2b00c5fe (patch) | |
tree | bbf9e9f3e8a2363659be35d59a7749c7fe35ef7c /lib/sqlalchemy/sql/crud.py | |
parent | c565c470517e1cc70a7f33d1ad3d3256935f1121 (diff) | |
download | sqlalchemy-6f02d5edd88fe2475629438b0730181a2b00c5fe.tar.gz |
pep484 - SQL internals
non-strict checking for mostly internal or semi-internal
code
Change-Id: Ib91b47f1a8ccc15e666b94bad1ce78c4ab15b0ec
Diffstat (limited to 'lib/sqlalchemy/sql/crud.py')
-rw-r--r-- | lib/sqlalchemy/sql/crud.py | 305 |
1 files changed, 244 insertions, 61 deletions
diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py index 4292aa916..533a2f6cd 100644 --- a/lib/sqlalchemy/sql/crud.py +++ b/lib/sqlalchemy/sql/crud.py @@ -13,13 +13,44 @@ from __future__ import annotations import functools import operator +from typing import Any +from typing import Callable +from typing import cast +from typing import Dict +from typing import List +from typing import MutableMapping +from typing import NamedTuple +from typing import Optional +from typing import overload +from typing import Tuple +from typing import TYPE_CHECKING +from typing import Union from . import coercions from . import dml from . import elements from . import roles +from .schema import default_is_clause_element +from .schema import default_is_sequence from .. import exc from .. import util +from ..util.typing import Literal + +if TYPE_CHECKING: + from .compiler import _BindNameForColProtocol + from .compiler import SQLCompiler + from .dml import DMLState + from .dml import Insert + from .dml import Update + from .dml import UpdateDMLState + from .dml import ValuesBase + from .elements import ClauseElement + from .elements import ColumnClause + from .elements import ColumnElement + from .elements import TextClause + from .schema import _SQLExprDefault + from .schema import Column + from .selectable import TableClause REQUIRED = util.symbol( "REQUIRED", @@ -36,7 +67,27 @@ values present. ) -def _get_crud_params(compiler, stmt, compile_state, **kw): +class _CrudParams(NamedTuple): + single_params: List[ + Tuple[ColumnClause[Any], str, Optional[Union[str, _SQLExprDefault]]] + ] + all_multi_params: List[ + List[ + Tuple[ + ColumnClause[Any], + str, + str, + ] + ] + ] + + +def _get_crud_params( + compiler: SQLCompiler, + stmt: ValuesBase, + compile_state: DMLState, + **kw: Any, +) -> _CrudParams: """create a set of tuples representing column/string pairs for use in an INSERT or UPDATE statement. @@ -59,24 +110,32 @@ def _get_crud_params(compiler, stmt, compile_state, **kw): _column_as_key, _getattr_col_key, _col_bind_name, - ) = getters = _key_getters_for_crud_column(compiler, stmt, compile_state) + ) = _key_getters_for_crud_column(compiler, stmt, compile_state) - compiler._key_getters_for_crud_column = getters + compiler._get_bind_name_for_col = _col_bind_name # no parameters in the statement, no parameters in the # compiled params - return binds for all columns if compiler.column_keys is None and compile_state._no_parameters: - return [ - ( - c, - compiler.preparer.format_column(c), - _create_bind_param(compiler, c, None, required=True), - ) - for c in stmt.table.columns - ] + return _CrudParams( + [ + ( + c, + compiler.preparer.format_column(c), + _create_bind_param(compiler, c, None, required=True), + ) + for c in stmt.table.columns + ], + [], + ) + + stmt_parameter_tuples: Optional[List[Any]] + spd: Optional[MutableMapping[str, Any]] if compile_state._has_multi_parameters: - spd = compile_state._multi_parameters[0] + mp = compile_state._multi_parameters + assert mp is not None + spd = mp[0] stmt_parameter_tuples = list(spd.items()) elif compile_state._ordered_values: spd = compile_state._dict_parameters @@ -92,6 +151,7 @@ def _get_crud_params(compiler, stmt, compile_state, **kw): if compiler.column_keys is None: parameters = {} elif stmt_parameter_tuples: + assert spd is not None parameters = dict( (_column_as_key(key), REQUIRED) for key in compiler.column_keys @@ -103,7 +163,9 @@ def _get_crud_params(compiler, stmt, compile_state, **kw): ) # create a list of column assignment clauses as tuples - values = [] + values: List[ + Tuple[ColumnClause[Any], str, Optional[Union[str, _SQLExprDefault]]] + ] = [] if stmt_parameter_tuples is not None: _get_stmt_parameter_tuples_params( @@ -116,11 +178,11 @@ def _get_crud_params(compiler, stmt, compile_state, **kw): kw, ) - check_columns = {} + check_columns: Dict[str, ColumnClause[Any]] = {} # special logic that only occurs for multi-table UPDATE # statements - if compile_state.isupdate and compile_state.is_multitable: + if dml.isupdate(compile_state) and compile_state.is_multitable: _get_update_multitable_params( compiler, stmt, @@ -134,6 +196,10 @@ def _get_crud_params(compiler, stmt, compile_state, **kw): ) if compile_state.isinsert and stmt._select_names: + # is an insert from select, is not a multiparams + + assert not compile_state._has_multi_parameters + _scan_insert_from_select_cols( compiler, stmt, @@ -173,14 +239,17 @@ def _get_crud_params(compiler, stmt, compile_state, **kw): ) if compile_state._has_multi_parameters: - values = _extend_values_for_multiparams( + # is a multiparams, is not an insert from a select + assert not stmt._select_names + multi_extended_values = _extend_values_for_multiparams( compiler, stmt, compile_state, - values, - _column_as_key, + cast("List[Tuple[ColumnClause[Any], str, str]]", values), + cast("Callable[..., str]", _column_as_key), kw, ) + return _CrudParams(values, multi_extended_values) elif ( not values and compiler.for_executemany @@ -198,12 +267,41 @@ def _get_crud_params(compiler, stmt, compile_state, **kw): ) ] - return values + return _CrudParams(values, []) +@overload def _create_bind_param( - compiler, col, value, process=True, required=False, name=None, **kw -): + compiler: SQLCompiler, + col: ColumnElement[Any], + value: Any, + process: Literal[True] = ..., + required: bool = False, + name: Optional[str] = None, + **kw: Any, +) -> str: + ... + + +@overload +def _create_bind_param( + compiler: SQLCompiler, + col: ColumnElement[Any], + value: Any, + **kw: Any, +) -> str: + ... + + +def _create_bind_param( + compiler: SQLCompiler, + col: ColumnElement[Any], + value: Any, + process: bool = True, + required: bool = False, + name: Optional[str] = None, + **kw: Any, +) -> Union[str, elements.BindParameter[Any]]: if name is None: name = col.key bindparam = elements.BindParameter( @@ -211,8 +309,9 @@ def _create_bind_param( ) bindparam._is_crud = True if process: - bindparam = bindparam._compiler_dispatch(compiler, **kw) - return bindparam + return bindparam._compiler_dispatch(compiler, **kw) + else: + return bindparam def _handle_values_anonymous_param(compiler, col, value, name, **kw): @@ -253,8 +352,14 @@ def _handle_values_anonymous_param(compiler, col, value, name, **kw): return value._compiler_dispatch(compiler, **kw) -def _key_getters_for_crud_column(compiler, stmt, compile_state): - if compile_state.isupdate and compile_state._extra_froms: +def _key_getters_for_crud_column( + compiler: SQLCompiler, stmt: ValuesBase, compile_state: DMLState +) -> Tuple[ + Callable[[Union[str, Column[Any]]], Union[str, Tuple[str, str]]], + Callable[[Column[Any]], Union[str, Tuple[str, str]]], + _BindNameForColProtocol, +]: + if dml.isupdate(compile_state) and compile_state._extra_froms: # when extra tables are present, refer to the columns # in those extra tables as table-qualified, including in # dictionaries and when rendering bind param names. @@ -267,30 +372,36 @@ def _key_getters_for_crud_column(compiler, stmt, compile_state): coercions.expect_as_key, roles.DMLColumnRole ) - def _column_as_key(key): + def _column_as_key( + key: Union[ColumnClause[Any], str] + ) -> Union[str, Tuple[str, str]]: str_key = c_key_role(key) - if hasattr(key, "table") and key.table in _et: - return (key.table.name, str_key) + if hasattr(key, "table") and key.table in _et: # type: ignore + return (key.table.name, str_key) # type: ignore else: - return str_key + return str_key # type: ignore - def _getattr_col_key(col): + def _getattr_col_key( + col: ColumnClause[Any], + ) -> Union[str, Tuple[str, str]]: if col.table in _et: - return (col.table.name, col.key) + return (col.table.name, col.key) # type: ignore else: return col.key - def _col_bind_name(col): + def _col_bind_name(col: ColumnClause[Any]) -> str: if col.table in _et: + if TYPE_CHECKING: + assert isinstance(col.table, TableClause) return "%s_%s" % (col.table.name, col.key) else: return col.key else: - _column_as_key = functools.partial( + _column_as_key = functools.partial( # type: ignore coercions.expect_as_key, roles.DMLColumnRole ) - _getattr_col_key = _col_bind_name = operator.attrgetter("key") + _getattr_col_key = _col_bind_name = operator.attrgetter("key") # type: ignore # noqa E501 return _column_as_key, _getattr_col_key, _col_bind_name @@ -321,7 +432,7 @@ def _scan_insert_from_select_cols( compiler.stack[-1]["insert_from_select"] = stmt.select - add_select_cols = [] + add_select_cols: List[Tuple[ColumnClause[Any], str, _SQLExprDefault]] = [] if stmt.include_insert_from_select_defaults: col_set = set(cols) for col in stmt.table.columns: @@ -707,16 +818,22 @@ def _append_param_insert_hasdefault( ) -def _append_param_insert_select_hasdefault(compiler, stmt, c, values, kw): +def _append_param_insert_select_hasdefault( + compiler: SQLCompiler, + stmt: ValuesBase, + c: ColumnClause[Any], + values: List[Tuple[ColumnClause[Any], str, _SQLExprDefault]], + kw: Dict[str, Any], +) -> None: - if c.default.is_sequence: + if default_is_sequence(c.default): if compiler.dialect.supports_sequences and ( not c.default.optional or not compiler.dialect.sequences_optional ): values.append( (c, compiler.preparer.format_column(c), c.default.next_value()) ) - elif c.default.is_clause_element: + elif default_is_clause_element(c.default): values.append( (c, compiler.preparer.format_column(c), c.default.arg.self_group()) ) @@ -777,28 +894,76 @@ def _append_param_update( compiler.returning.append(c) +@overload def _create_insert_prefetch_bind_param( - compiler, c, process=True, name=None, **kw -): + compiler: SQLCompiler, + c: ColumnElement[Any], + process: Literal[True] = ..., + **kw: Any, +) -> str: + ... + + +@overload +def _create_insert_prefetch_bind_param( + compiler: SQLCompiler, + c: ColumnElement[Any], + process: Literal[False], + **kw: Any, +) -> elements.BindParameter[Any]: + ... + + +def _create_insert_prefetch_bind_param( + compiler: SQLCompiler, + c: ColumnElement[Any], + process: bool = True, + name: Optional[str] = None, + **kw: Any, +) -> Union[elements.BindParameter[Any], str]: param = _create_bind_param( compiler, c, None, process=process, name=name, **kw ) - compiler.insert_prefetch.append(c) + compiler.insert_prefetch.append(c) # type: ignore return param +@overload def _create_update_prefetch_bind_param( - compiler, c, process=True, name=None, **kw -): + compiler: SQLCompiler, + c: ColumnElement[Any], + process: Literal[True] = ..., + **kw: Any, +) -> str: + ... + + +@overload +def _create_update_prefetch_bind_param( + compiler: SQLCompiler, + c: ColumnElement[Any], + process: Literal[False], + **kw: Any, +) -> elements.BindParameter[Any]: + ... + + +def _create_update_prefetch_bind_param( + compiler: SQLCompiler, + c: ColumnElement[Any], + process: bool = True, + name: Optional[str] = None, + **kw: Any, +) -> Union[elements.BindParameter[Any], str]: param = _create_bind_param( compiler, c, None, process=process, name=name, **kw ) - compiler.update_prefetch.append(c) + compiler.update_prefetch.append(c) # type: ignore return param -class _multiparam_column(elements.ColumnElement): +class _multiparam_column(elements.ColumnElement[Any]): _is_multiparam_column = True def __init__(self, original, index): @@ -822,14 +987,20 @@ class _multiparam_column(elements.ColumnElement): ) -def _process_multiparam_default_bind(compiler, stmt, c, index, kw): +def _process_multiparam_default_bind( + compiler: SQLCompiler, + stmt: ValuesBase, + c: ColumnClause[Any], + index: int, + kw: Dict[str, Any], +) -> str: if not c.default: raise exc.CompileError( "INSERT value for column %s is explicitly rendered as a bound" "parameter in the VALUES clause; " "a Python-side value or SQL expression is required" % c ) - elif c.default.is_clause_element: + elif default_is_clause_element(c.default): return compiler.process(c.default.arg.self_group(), **kw) elif c.default.is_sequence: # these conditions would have been established @@ -844,9 +1015,13 @@ def _process_multiparam_default_bind(compiler, stmt, c, index, kw): else: col = _multiparam_column(c, index) if isinstance(stmt, dml.Insert): - return _create_insert_prefetch_bind_param(compiler, col, **kw) + return _create_insert_prefetch_bind_param( + compiler, col, process=True, **kw + ) else: - return _create_update_prefetch_bind_param(compiler, col, **kw) + return _create_update_prefetch_bind_param( + compiler, col, process=True, **kw + ) def _get_update_multitable_params( @@ -926,18 +1101,26 @@ def _get_update_multitable_params( def _extend_values_for_multiparams( - compiler, - stmt, - compile_state, - values, - _column_as_key, - kw, -): - values_0 = values - values = [values] - - for i, row in enumerate(compile_state._multi_parameters[1:]): - extension = [] + compiler: SQLCompiler, + stmt: ValuesBase, + compile_state: DMLState, + initial_values: List[Tuple[ColumnClause[Any], str, str]], + _column_as_key: Callable[..., str], + kw: Dict[str, Any], +) -> List[List[Tuple[ColumnClause[Any], str, str]]]: + values_0 = initial_values + values = [initial_values] + + mp = compile_state._multi_parameters + assert mp is not None + for i, row in enumerate(mp[1:]): + extension: List[ + Tuple[ + ColumnClause[Any], + str, + str, + ] + ] = [] row = {_column_as_key(key): v for key, v in row.items()} |