diff options
Diffstat (limited to 'lib/sqlalchemy/sql/crud.py')
-rw-r--r-- | lib/sqlalchemy/sql/crud.py | 32 |
1 files changed, 18 insertions, 14 deletions
diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py index e4408cd31..29d7b45d7 100644 --- a/lib/sqlalchemy/sql/crud.py +++ b/lib/sqlalchemy/sql/crud.py @@ -22,6 +22,7 @@ from typing import MutableMapping from typing import NamedTuple from typing import Optional from typing import overload +from typing import Sequence from typing import Tuple from typing import TYPE_CHECKING from typing import Union @@ -30,8 +31,10 @@ from . import coercions from . import dml from . import elements from . import roles +from .elements import ColumnClause from .schema import default_is_clause_element from .schema import default_is_sequence +from .selectable import TableClause from .. import exc from .. import util from ..util.typing import Literal @@ -41,16 +44,9 @@ if TYPE_CHECKING: from .compiler import SQLCompiler from .dml import _DMLColumnElement from .dml import DMLState - from .dml import Insert - from .dml import Update - from .dml import UpdateDMLState from .dml import ValuesBase - from .elements import ClauseElement - from .elements import ColumnClause from .elements import ColumnElement - from .elements import TextClause from .schema import _SQLExprDefault - from .schema import Column from .selectable import TableClause REQUIRED = util.symbol( @@ -68,12 +64,20 @@ values present. ) +def _as_dml_column(c: ColumnElement[Any]) -> ColumnClause[Any]: + if not isinstance(c, ColumnClause): + raise exc.CompileError( + f"Can't create DML statement against column expression {c!r}" + ) + return c + + class _CrudParams(NamedTuple): - single_params: List[ - Tuple[ColumnClause[Any], str, Optional[Union[str, _SQLExprDefault]]] + single_params: Sequence[ + Tuple[ColumnElement[Any], str, Optional[Union[str, _SQLExprDefault]]] ] all_multi_params: List[ - List[ + Sequence[ Tuple[ ColumnClause[Any], str, @@ -274,7 +278,7 @@ def _get_crud_params( compiler, stmt, compile_state, - cast("List[Tuple[ColumnClause[Any], str, str]]", values), + cast("Sequence[Tuple[ColumnClause[Any], str, str]]", values), cast("Callable[..., str]", _column_as_key), kw, ) @@ -290,7 +294,7 @@ def _get_crud_params( # insert_executemany_returning mode :) values = [ ( - stmt.table.columns[0], + _as_dml_column(stmt.table.columns[0]), compiler.preparer.format_column(stmt.table.columns[0]), "DEFAULT", ) @@ -1135,10 +1139,10 @@ def _extend_values_for_multiparams( compiler: SQLCompiler, stmt: ValuesBase, compile_state: DMLState, - initial_values: List[Tuple[ColumnClause[Any], str, str]], + initial_values: Sequence[Tuple[ColumnClause[Any], str, str]], _column_as_key: Callable[..., str], kw: Dict[str, Any], -) -> List[List[Tuple[ColumnClause[Any], str, str]]]: +) -> List[Sequence[Tuple[ColumnClause[Any], str, str]]]: values_0 = initial_values values = [initial_values] |