summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/crud.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/sql/crud.py')
-rw-r--r--lib/sqlalchemy/sql/crud.py32
1 files changed, 18 insertions, 14 deletions
diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py
index e4408cd31..29d7b45d7 100644
--- a/lib/sqlalchemy/sql/crud.py
+++ b/lib/sqlalchemy/sql/crud.py
@@ -22,6 +22,7 @@ from typing import MutableMapping
from typing import NamedTuple
from typing import Optional
from typing import overload
+from typing import Sequence
from typing import Tuple
from typing import TYPE_CHECKING
from typing import Union
@@ -30,8 +31,10 @@ from . import coercions
from . import dml
from . import elements
from . import roles
+from .elements import ColumnClause
from .schema import default_is_clause_element
from .schema import default_is_sequence
+from .selectable import TableClause
from .. import exc
from .. import util
from ..util.typing import Literal
@@ -41,16 +44,9 @@ if TYPE_CHECKING:
from .compiler import SQLCompiler
from .dml import _DMLColumnElement
from .dml import DMLState
- from .dml import Insert
- from .dml import Update
- from .dml import UpdateDMLState
from .dml import ValuesBase
- from .elements import ClauseElement
- from .elements import ColumnClause
from .elements import ColumnElement
- from .elements import TextClause
from .schema import _SQLExprDefault
- from .schema import Column
from .selectable import TableClause
REQUIRED = util.symbol(
@@ -68,12 +64,20 @@ values present.
)
+def _as_dml_column(c: ColumnElement[Any]) -> ColumnClause[Any]:
+ if not isinstance(c, ColumnClause):
+ raise exc.CompileError(
+ f"Can't create DML statement against column expression {c!r}"
+ )
+ return c
+
+
class _CrudParams(NamedTuple):
- single_params: List[
- Tuple[ColumnClause[Any], str, Optional[Union[str, _SQLExprDefault]]]
+ single_params: Sequence[
+ Tuple[ColumnElement[Any], str, Optional[Union[str, _SQLExprDefault]]]
]
all_multi_params: List[
- List[
+ Sequence[
Tuple[
ColumnClause[Any],
str,
@@ -274,7 +278,7 @@ def _get_crud_params(
compiler,
stmt,
compile_state,
- cast("List[Tuple[ColumnClause[Any], str, str]]", values),
+ cast("Sequence[Tuple[ColumnClause[Any], str, str]]", values),
cast("Callable[..., str]", _column_as_key),
kw,
)
@@ -290,7 +294,7 @@ def _get_crud_params(
# insert_executemany_returning mode :)
values = [
(
- stmt.table.columns[0],
+ _as_dml_column(stmt.table.columns[0]),
compiler.preparer.format_column(stmt.table.columns[0]),
"DEFAULT",
)
@@ -1135,10 +1139,10 @@ def _extend_values_for_multiparams(
compiler: SQLCompiler,
stmt: ValuesBase,
compile_state: DMLState,
- initial_values: List[Tuple[ColumnClause[Any], str, str]],
+ initial_values: Sequence[Tuple[ColumnClause[Any], str, str]],
_column_as_key: Callable[..., str],
kw: Dict[str, Any],
-) -> List[List[Tuple[ColumnClause[Any], str, str]]]:
+) -> List[Sequence[Tuple[ColumnClause[Any], str, str]]]:
values_0 = initial_values
values = [initial_values]