summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/dialects/postgresql
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/dialects/postgresql')
-rw-r--r--lib/sqlalchemy/dialects/postgresql/base.py56
-rw-r--r--lib/sqlalchemy/dialects/postgresql/dml.py4
2 files changed, 49 insertions, 11 deletions
diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py
index b436b934f..169b792f5 100644
--- a/lib/sqlalchemy/dialects/postgresql/base.py
+++ b/lib/sqlalchemy/dialects/postgresql/base.py
@@ -862,6 +862,7 @@ import re
import datetime as dt
+from sqlalchemy.sql import elements
from ... import sql, schema, exc, util
from ...engine import default, reflection
from ...sql import compiler, expression
@@ -1499,17 +1500,52 @@ class PGCompiler(compiler.SQLCompiler):
target_text = self._on_conflict_target(on_conflict, **kw)
action_set_ops = []
- for k, v in clause.update_values_to_set:
- key_text = (
- self.preparer.quote(k)
- if isinstance(k, util.string_types)
- else self.process(k, use_schema=False)
- )
- value_text = self.process(
- v,
- use_schema=False
+
+ set_parameters = dict(clause.update_values_to_set)
+ # create a list of column assignment clauses as tuples
+ cols = self.statement.table.c
+ for c in cols:
+ col_key = c.key
+ if col_key in set_parameters:
+ value = set_parameters.pop(col_key)
+ if elements._is_literal(value):
+ value = elements.BindParameter(
+ None, value, type_=c.type
+ )
+
+ else:
+ if isinstance(value, elements.BindParameter) and \
+ value.type._isnull:
+ value = value._clone()
+ value.type = c.type
+ value_text = self.process(value.self_group(), use_schema=False)
+
+ key_text = (
+ self.preparer.quote(col_key)
+ )
+ action_set_ops.append('%s = %s' % (key_text, value_text))
+
+ # check for names that don't match columns
+ if set_parameters:
+ util.warn(
+ "Additional column names not matching "
+ "any column keys in table '%s': %s" % (
+ self.statement.table.name,
+ (", ".join("'%s'" % c for c in set_parameters))
+ )
)
- action_set_ops.append('%s = %s' % (key_text, value_text))
+ for k, v in set_parameters.items():
+ key_text = (
+ self.preparer.quote(k)
+ if isinstance(k, util.string_types)
+ else self.process(k, use_schema=False)
+ )
+ value_text = self.process(
+ elements._literal_as_binds(v),
+ use_schema=False
+ )
+ action_set_ops.append('%s = %s' % (key_text, value_text))
+
action_text = ', '.join(action_set_ops)
if clause.update_whereclause is not None:
action_text += ' WHERE %s' % \
diff --git a/lib/sqlalchemy/dialects/postgresql/dml.py b/lib/sqlalchemy/dialects/postgresql/dml.py
index df53fa8a2..bfdfbfa36 100644
--- a/lib/sqlalchemy/dialects/postgresql/dml.py
+++ b/lib/sqlalchemy/dialects/postgresql/dml.py
@@ -70,6 +70,8 @@ class Insert(StandardInsert):
Required argument. A dictionary or other mapping object
with column names as keys and expressions or literals as values,
specifying the ``SET`` actions to take.
+ If the target :class:`.Column` specifies a ".key" attribute distinct
+ from the column name, that key should be used.
.. warning:: This dictionary does **not** take into account
Python-specified default UPDATE values or generation functions,
@@ -205,7 +207,7 @@ class OnConflictDoUpdate(OnConflictClause):
if (not isinstance(set_, dict) or not set_):
raise ValueError("set parameter must be a non-empty dictionary")
self.update_values_to_set = [
- (key, _literal_as_binds(value))
+ (key, value)
for key, value in set_.items()
]
self.update_whereclause = where