summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/crud.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2022-03-20 16:39:36 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2022-03-24 16:57:30 -0400
commit6f02d5edd88fe2475629438b0730181a2b00c5fe (patch)
treebbf9e9f3e8a2363659be35d59a7749c7fe35ef7c /lib/sqlalchemy/sql/crud.py
parentc565c470517e1cc70a7f33d1ad3d3256935f1121 (diff)
downloadsqlalchemy-6f02d5edd88fe2475629438b0730181a2b00c5fe.tar.gz
pep484 - SQL internals
non-strict checking for mostly internal or semi-internal code Change-Id: Ib91b47f1a8ccc15e666b94bad1ce78c4ab15b0ec
Diffstat (limited to 'lib/sqlalchemy/sql/crud.py')
-rw-r--r--lib/sqlalchemy/sql/crud.py305
1 files changed, 244 insertions, 61 deletions
diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py
index 4292aa916..533a2f6cd 100644
--- a/lib/sqlalchemy/sql/crud.py
+++ b/lib/sqlalchemy/sql/crud.py
@@ -13,13 +13,44 @@ from __future__ import annotations
import functools
import operator
+from typing import Any
+from typing import Callable
+from typing import cast
+from typing import Dict
+from typing import List
+from typing import MutableMapping
+from typing import NamedTuple
+from typing import Optional
+from typing import overload
+from typing import Tuple
+from typing import TYPE_CHECKING
+from typing import Union
from . import coercions
from . import dml
from . import elements
from . import roles
+from .schema import default_is_clause_element
+from .schema import default_is_sequence
from .. import exc
from .. import util
+from ..util.typing import Literal
+
+if TYPE_CHECKING:
+ from .compiler import _BindNameForColProtocol
+ from .compiler import SQLCompiler
+ from .dml import DMLState
+ from .dml import Insert
+ from .dml import Update
+ from .dml import UpdateDMLState
+ from .dml import ValuesBase
+ from .elements import ClauseElement
+ from .elements import ColumnClause
+ from .elements import ColumnElement
+ from .elements import TextClause
+ from .schema import _SQLExprDefault
+ from .schema import Column
+ from .selectable import TableClause
REQUIRED = util.symbol(
"REQUIRED",
@@ -36,7 +67,27 @@ values present.
)
-def _get_crud_params(compiler, stmt, compile_state, **kw):
+class _CrudParams(NamedTuple):
+ single_params: List[
+ Tuple[ColumnClause[Any], str, Optional[Union[str, _SQLExprDefault]]]
+ ]
+ all_multi_params: List[
+ List[
+ Tuple[
+ ColumnClause[Any],
+ str,
+ str,
+ ]
+ ]
+ ]
+
+
+def _get_crud_params(
+ compiler: SQLCompiler,
+ stmt: ValuesBase,
+ compile_state: DMLState,
+ **kw: Any,
+) -> _CrudParams:
"""create a set of tuples representing column/string pairs for use
in an INSERT or UPDATE statement.
@@ -59,24 +110,32 @@ def _get_crud_params(compiler, stmt, compile_state, **kw):
_column_as_key,
_getattr_col_key,
_col_bind_name,
- ) = getters = _key_getters_for_crud_column(compiler, stmt, compile_state)
+ ) = _key_getters_for_crud_column(compiler, stmt, compile_state)
- compiler._key_getters_for_crud_column = getters
+ compiler._get_bind_name_for_col = _col_bind_name
# no parameters in the statement, no parameters in the
# compiled params - return binds for all columns
if compiler.column_keys is None and compile_state._no_parameters:
- return [
- (
- c,
- compiler.preparer.format_column(c),
- _create_bind_param(compiler, c, None, required=True),
- )
- for c in stmt.table.columns
- ]
+ return _CrudParams(
+ [
+ (
+ c,
+ compiler.preparer.format_column(c),
+ _create_bind_param(compiler, c, None, required=True),
+ )
+ for c in stmt.table.columns
+ ],
+ [],
+ )
+
+ stmt_parameter_tuples: Optional[List[Any]]
+ spd: Optional[MutableMapping[str, Any]]
if compile_state._has_multi_parameters:
- spd = compile_state._multi_parameters[0]
+ mp = compile_state._multi_parameters
+ assert mp is not None
+ spd = mp[0]
stmt_parameter_tuples = list(spd.items())
elif compile_state._ordered_values:
spd = compile_state._dict_parameters
@@ -92,6 +151,7 @@ def _get_crud_params(compiler, stmt, compile_state, **kw):
if compiler.column_keys is None:
parameters = {}
elif stmt_parameter_tuples:
+ assert spd is not None
parameters = dict(
(_column_as_key(key), REQUIRED)
for key in compiler.column_keys
@@ -103,7 +163,9 @@ def _get_crud_params(compiler, stmt, compile_state, **kw):
)
# create a list of column assignment clauses as tuples
- values = []
+ values: List[
+ Tuple[ColumnClause[Any], str, Optional[Union[str, _SQLExprDefault]]]
+ ] = []
if stmt_parameter_tuples is not None:
_get_stmt_parameter_tuples_params(
@@ -116,11 +178,11 @@ def _get_crud_params(compiler, stmt, compile_state, **kw):
kw,
)
- check_columns = {}
+ check_columns: Dict[str, ColumnClause[Any]] = {}
# special logic that only occurs for multi-table UPDATE
# statements
- if compile_state.isupdate and compile_state.is_multitable:
+ if dml.isupdate(compile_state) and compile_state.is_multitable:
_get_update_multitable_params(
compiler,
stmt,
@@ -134,6 +196,10 @@ def _get_crud_params(compiler, stmt, compile_state, **kw):
)
if compile_state.isinsert and stmt._select_names:
+ # is an insert from select, is not a multiparams
+
+ assert not compile_state._has_multi_parameters
+
_scan_insert_from_select_cols(
compiler,
stmt,
@@ -173,14 +239,17 @@ def _get_crud_params(compiler, stmt, compile_state, **kw):
)
if compile_state._has_multi_parameters:
- values = _extend_values_for_multiparams(
+ # is a multiparams, is not an insert from a select
+ assert not stmt._select_names
+ multi_extended_values = _extend_values_for_multiparams(
compiler,
stmt,
compile_state,
- values,
- _column_as_key,
+ cast("List[Tuple[ColumnClause[Any], str, str]]", values),
+ cast("Callable[..., str]", _column_as_key),
kw,
)
+ return _CrudParams(values, multi_extended_values)
elif (
not values
and compiler.for_executemany
@@ -198,12 +267,41 @@ def _get_crud_params(compiler, stmt, compile_state, **kw):
)
]
- return values
+ return _CrudParams(values, [])
+@overload
def _create_bind_param(
- compiler, col, value, process=True, required=False, name=None, **kw
-):
+ compiler: SQLCompiler,
+ col: ColumnElement[Any],
+ value: Any,
+ process: Literal[True] = ...,
+ required: bool = False,
+ name: Optional[str] = None,
+ **kw: Any,
+) -> str:
+ ...
+
+
+@overload
+def _create_bind_param(
+ compiler: SQLCompiler,
+ col: ColumnElement[Any],
+ value: Any,
+ **kw: Any,
+) -> str:
+ ...
+
+
+def _create_bind_param(
+ compiler: SQLCompiler,
+ col: ColumnElement[Any],
+ value: Any,
+ process: bool = True,
+ required: bool = False,
+ name: Optional[str] = None,
+ **kw: Any,
+) -> Union[str, elements.BindParameter[Any]]:
if name is None:
name = col.key
bindparam = elements.BindParameter(
@@ -211,8 +309,9 @@ def _create_bind_param(
)
bindparam._is_crud = True
if process:
- bindparam = bindparam._compiler_dispatch(compiler, **kw)
- return bindparam
+ return bindparam._compiler_dispatch(compiler, **kw)
+ else:
+ return bindparam
def _handle_values_anonymous_param(compiler, col, value, name, **kw):
@@ -253,8 +352,14 @@ def _handle_values_anonymous_param(compiler, col, value, name, **kw):
return value._compiler_dispatch(compiler, **kw)
-def _key_getters_for_crud_column(compiler, stmt, compile_state):
- if compile_state.isupdate and compile_state._extra_froms:
+def _key_getters_for_crud_column(
+ compiler: SQLCompiler, stmt: ValuesBase, compile_state: DMLState
+) -> Tuple[
+ Callable[[Union[str, Column[Any]]], Union[str, Tuple[str, str]]],
+ Callable[[Column[Any]], Union[str, Tuple[str, str]]],
+ _BindNameForColProtocol,
+]:
+ if dml.isupdate(compile_state) and compile_state._extra_froms:
# when extra tables are present, refer to the columns
# in those extra tables as table-qualified, including in
# dictionaries and when rendering bind param names.
@@ -267,30 +372,36 @@ def _key_getters_for_crud_column(compiler, stmt, compile_state):
coercions.expect_as_key, roles.DMLColumnRole
)
- def _column_as_key(key):
+ def _column_as_key(
+ key: Union[ColumnClause[Any], str]
+ ) -> Union[str, Tuple[str, str]]:
str_key = c_key_role(key)
- if hasattr(key, "table") and key.table in _et:
- return (key.table.name, str_key)
+ if hasattr(key, "table") and key.table in _et: # type: ignore
+ return (key.table.name, str_key) # type: ignore
else:
- return str_key
+ return str_key # type: ignore
- def _getattr_col_key(col):
+ def _getattr_col_key(
+ col: ColumnClause[Any],
+ ) -> Union[str, Tuple[str, str]]:
if col.table in _et:
- return (col.table.name, col.key)
+ return (col.table.name, col.key) # type: ignore
else:
return col.key
- def _col_bind_name(col):
+ def _col_bind_name(col: ColumnClause[Any]) -> str:
if col.table in _et:
+ if TYPE_CHECKING:
+ assert isinstance(col.table, TableClause)
return "%s_%s" % (col.table.name, col.key)
else:
return col.key
else:
- _column_as_key = functools.partial(
+ _column_as_key = functools.partial( # type: ignore
coercions.expect_as_key, roles.DMLColumnRole
)
- _getattr_col_key = _col_bind_name = operator.attrgetter("key")
+ _getattr_col_key = _col_bind_name = operator.attrgetter("key") # type: ignore # noqa E501
return _column_as_key, _getattr_col_key, _col_bind_name
@@ -321,7 +432,7 @@ def _scan_insert_from_select_cols(
compiler.stack[-1]["insert_from_select"] = stmt.select
- add_select_cols = []
+ add_select_cols: List[Tuple[ColumnClause[Any], str, _SQLExprDefault]] = []
if stmt.include_insert_from_select_defaults:
col_set = set(cols)
for col in stmt.table.columns:
@@ -707,16 +818,22 @@ def _append_param_insert_hasdefault(
)
-def _append_param_insert_select_hasdefault(compiler, stmt, c, values, kw):
+def _append_param_insert_select_hasdefault(
+ compiler: SQLCompiler,
+ stmt: ValuesBase,
+ c: ColumnClause[Any],
+ values: List[Tuple[ColumnClause[Any], str, _SQLExprDefault]],
+ kw: Dict[str, Any],
+) -> None:
- if c.default.is_sequence:
+ if default_is_sequence(c.default):
if compiler.dialect.supports_sequences and (
not c.default.optional or not compiler.dialect.sequences_optional
):
values.append(
(c, compiler.preparer.format_column(c), c.default.next_value())
)
- elif c.default.is_clause_element:
+ elif default_is_clause_element(c.default):
values.append(
(c, compiler.preparer.format_column(c), c.default.arg.self_group())
)
@@ -777,28 +894,76 @@ def _append_param_update(
compiler.returning.append(c)
+@overload
def _create_insert_prefetch_bind_param(
- compiler, c, process=True, name=None, **kw
-):
+ compiler: SQLCompiler,
+ c: ColumnElement[Any],
+ process: Literal[True] = ...,
+ **kw: Any,
+) -> str:
+ ...
+
+
+@overload
+def _create_insert_prefetch_bind_param(
+ compiler: SQLCompiler,
+ c: ColumnElement[Any],
+ process: Literal[False],
+ **kw: Any,
+) -> elements.BindParameter[Any]:
+ ...
+
+
+def _create_insert_prefetch_bind_param(
+ compiler: SQLCompiler,
+ c: ColumnElement[Any],
+ process: bool = True,
+ name: Optional[str] = None,
+ **kw: Any,
+) -> Union[elements.BindParameter[Any], str]:
param = _create_bind_param(
compiler, c, None, process=process, name=name, **kw
)
- compiler.insert_prefetch.append(c)
+ compiler.insert_prefetch.append(c) # type: ignore
return param
+@overload
def _create_update_prefetch_bind_param(
- compiler, c, process=True, name=None, **kw
-):
+ compiler: SQLCompiler,
+ c: ColumnElement[Any],
+ process: Literal[True] = ...,
+ **kw: Any,
+) -> str:
+ ...
+
+
+@overload
+def _create_update_prefetch_bind_param(
+ compiler: SQLCompiler,
+ c: ColumnElement[Any],
+ process: Literal[False],
+ **kw: Any,
+) -> elements.BindParameter[Any]:
+ ...
+
+
+def _create_update_prefetch_bind_param(
+ compiler: SQLCompiler,
+ c: ColumnElement[Any],
+ process: bool = True,
+ name: Optional[str] = None,
+ **kw: Any,
+) -> Union[elements.BindParameter[Any], str]:
param = _create_bind_param(
compiler, c, None, process=process, name=name, **kw
)
- compiler.update_prefetch.append(c)
+ compiler.update_prefetch.append(c) # type: ignore
return param
-class _multiparam_column(elements.ColumnElement):
+class _multiparam_column(elements.ColumnElement[Any]):
_is_multiparam_column = True
def __init__(self, original, index):
@@ -822,14 +987,20 @@ class _multiparam_column(elements.ColumnElement):
)
-def _process_multiparam_default_bind(compiler, stmt, c, index, kw):
+def _process_multiparam_default_bind(
+ compiler: SQLCompiler,
+ stmt: ValuesBase,
+ c: ColumnClause[Any],
+ index: int,
+ kw: Dict[str, Any],
+) -> str:
if not c.default:
raise exc.CompileError(
"INSERT value for column %s is explicitly rendered as a bound"
"parameter in the VALUES clause; "
"a Python-side value or SQL expression is required" % c
)
- elif c.default.is_clause_element:
+ elif default_is_clause_element(c.default):
return compiler.process(c.default.arg.self_group(), **kw)
elif c.default.is_sequence:
# these conditions would have been established
@@ -844,9 +1015,13 @@ def _process_multiparam_default_bind(compiler, stmt, c, index, kw):
else:
col = _multiparam_column(c, index)
if isinstance(stmt, dml.Insert):
- return _create_insert_prefetch_bind_param(compiler, col, **kw)
+ return _create_insert_prefetch_bind_param(
+ compiler, col, process=True, **kw
+ )
else:
- return _create_update_prefetch_bind_param(compiler, col, **kw)
+ return _create_update_prefetch_bind_param(
+ compiler, col, process=True, **kw
+ )
def _get_update_multitable_params(
@@ -926,18 +1101,26 @@ def _get_update_multitable_params(
def _extend_values_for_multiparams(
- compiler,
- stmt,
- compile_state,
- values,
- _column_as_key,
- kw,
-):
- values_0 = values
- values = [values]
-
- for i, row in enumerate(compile_state._multi_parameters[1:]):
- extension = []
+ compiler: SQLCompiler,
+ stmt: ValuesBase,
+ compile_state: DMLState,
+ initial_values: List[Tuple[ColumnClause[Any], str, str]],
+ _column_as_key: Callable[..., str],
+ kw: Dict[str, Any],
+) -> List[List[Tuple[ColumnClause[Any], str, str]]]:
+ values_0 = initial_values
+ values = [initial_values]
+
+ mp = compile_state._multi_parameters
+ assert mp is not None
+ for i, row in enumerate(mp[1:]):
+ extension: List[
+ Tuple[
+ ColumnClause[Any],
+ str,
+ str,
+ ]
+ ] = []
row = {_column_as_key(key): v for key, v in row.items()}