diff options
Diffstat (limited to 'lib/sqlalchemy/sql/crud.py')
-rw-r--r-- | lib/sqlalchemy/sql/crud.py | 27 |
1 files changed, 17 insertions, 10 deletions
diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py index 552f61b4a..881ea9fcd 100644 --- a/lib/sqlalchemy/sql/crud.py +++ b/lib/sqlalchemy/sql/crud.py @@ -9,14 +9,16 @@ within INSERT and UPDATE statements. """ +import functools import operator +from . import coercions from . import dml from . import elements +from . import roles from .. import exc from .. import util - REQUIRED = util.symbol( "REQUIRED", """ @@ -174,7 +176,7 @@ def _get_crud_params(compiler, stmt, **kw): if check: raise exc.CompileError( "Unconsumed column names: %s" - % (", ".join("%s" % c for c in check)) + % (", ".join("%s" % (c,) for c in check)) ) if stmt._has_multi_parameters: @@ -207,8 +209,12 @@ def _key_getters_for_crud_column(compiler, stmt): # statement. _et = set(stmt._extra_froms) + c_key_role = functools.partial( + coercions.expect_as_key, roles.DMLColumnRole + ) + def _column_as_key(key): - str_key = elements._column_as_key(key) + str_key = c_key_role(key) if hasattr(key, "table") and key.table in _et: return (key.table.name, str_key) else: @@ -227,7 +233,9 @@ def _key_getters_for_crud_column(compiler, stmt): return col.key else: - _column_as_key = elements._column_as_key + _column_as_key = functools.partial( + coercions.expect_as_key, roles.DMLColumnRole + ) _getattr_col_key = _col_bind_name = operator.attrgetter("key") return _column_as_key, _getattr_col_key, _col_bind_name @@ -386,7 +394,7 @@ def _append_param_parameter( kw, ): value = parameters.pop(col_key) - if elements._is_literal(value): + if coercions._is_literal(value): value = _create_bind_param( compiler, c, @@ -633,9 +641,8 @@ def _get_multitable_params( values, kw, ): - normalized_params = dict( - (elements._clause_element_as_expr(c), param) + (coercions.expect(roles.DMLColumnRole, c), param) for c, param in stmt_parameters.items() ) affected_tables = set() @@ -645,7 +652,7 @@ def _get_multitable_params( affected_tables.add(t) check_columns[_getattr_col_key(c)] = c value = normalized_params[c] - if elements._is_literal(value): + if coercions._is_literal(value): value = _create_bind_param( compiler, c, @@ -697,7 +704,7 @@ def _extend_values_for_multiparams(compiler, stmt, values, kw): if col in row or col.key in row: key = col if col in row else col.key - if elements._is_literal(row[key]): + if coercions._is_literal(row[key]): new_param = _create_bind_param( compiler, col, @@ -730,7 +737,7 @@ def _get_stmt_parameters_params( # a non-Column expression on the left side; # add it to values() in an "as-is" state, # coercing right side to bound param - if elements._is_literal(v): + if coercions._is_literal(v): v = compiler.process( elements.BindParameter(None, v, type_=k.type), **kw ) |