summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/crud.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/sql/crud.py')
-rw-r--r--lib/sqlalchemy/sql/crud.py27
1 files changed, 17 insertions, 10 deletions
diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py
index 552f61b4a..881ea9fcd 100644
--- a/lib/sqlalchemy/sql/crud.py
+++ b/lib/sqlalchemy/sql/crud.py
@@ -9,14 +9,16 @@
within INSERT and UPDATE statements.
"""
+import functools
import operator
+from . import coercions
from . import dml
from . import elements
+from . import roles
from .. import exc
from .. import util
-
REQUIRED = util.symbol(
"REQUIRED",
"""
@@ -174,7 +176,7 @@ def _get_crud_params(compiler, stmt, **kw):
if check:
raise exc.CompileError(
"Unconsumed column names: %s"
- % (", ".join("%s" % c for c in check))
+ % (", ".join("%s" % (c,) for c in check))
)
if stmt._has_multi_parameters:
@@ -207,8 +209,12 @@ def _key_getters_for_crud_column(compiler, stmt):
# statement.
_et = set(stmt._extra_froms)
+ c_key_role = functools.partial(
+ coercions.expect_as_key, roles.DMLColumnRole
+ )
+
def _column_as_key(key):
- str_key = elements._column_as_key(key)
+ str_key = c_key_role(key)
if hasattr(key, "table") and key.table in _et:
return (key.table.name, str_key)
else:
@@ -227,7 +233,9 @@ def _key_getters_for_crud_column(compiler, stmt):
return col.key
else:
- _column_as_key = elements._column_as_key
+ _column_as_key = functools.partial(
+ coercions.expect_as_key, roles.DMLColumnRole
+ )
_getattr_col_key = _col_bind_name = operator.attrgetter("key")
return _column_as_key, _getattr_col_key, _col_bind_name
@@ -386,7 +394,7 @@ def _append_param_parameter(
kw,
):
value = parameters.pop(col_key)
- if elements._is_literal(value):
+ if coercions._is_literal(value):
value = _create_bind_param(
compiler,
c,
@@ -633,9 +641,8 @@ def _get_multitable_params(
values,
kw,
):
-
normalized_params = dict(
- (elements._clause_element_as_expr(c), param)
+ (coercions.expect(roles.DMLColumnRole, c), param)
for c, param in stmt_parameters.items()
)
affected_tables = set()
@@ -645,7 +652,7 @@ def _get_multitable_params(
affected_tables.add(t)
check_columns[_getattr_col_key(c)] = c
value = normalized_params[c]
- if elements._is_literal(value):
+ if coercions._is_literal(value):
value = _create_bind_param(
compiler,
c,
@@ -697,7 +704,7 @@ def _extend_values_for_multiparams(compiler, stmt, values, kw):
if col in row or col.key in row:
key = col if col in row else col.key
- if elements._is_literal(row[key]):
+ if coercions._is_literal(row[key]):
new_param = _create_bind_param(
compiler,
col,
@@ -730,7 +737,7 @@ def _get_stmt_parameters_params(
# a non-Column expression on the left side;
# add it to values() in an "as-is" state,
# coercing right side to bound param
- if elements._is_literal(v):
+ if coercions._is_literal(v):
v = compiler.process(
elements.BindParameter(None, v, type_=k.type), **kw
)