diff options
author | mike bayer <mike_mp@zzzcomputing.com> | 2017-01-13 15:47:00 -0500 |
---|---|---|
committer | Gerrit Code Review <gerrit@awstats.zzzcomputing.com> | 2017-01-13 15:47:00 -0500 |
commit | 9ef1913ed64764d3097e12f022b8ce2f3b84ae04 (patch) | |
tree | 93ba545703a324eb18ec3a401f6937a2c2e1d044 /lib/sqlalchemy/dialects/postgresql/base.py | |
parent | 2c13aa097b3588a25173eb297eee08afd18f88d6 (diff) | |
parent | afd78a37dafe8e84e23bccfb570bd758797e2142 (diff) | |
download | sqlalchemy-9ef1913ed64764d3097e12f022b8ce2f3b84ae04.tar.gz |
Merge "Use full column->type processing for ON CONFLICT SET clause"
Diffstat (limited to 'lib/sqlalchemy/dialects/postgresql/base.py')
-rw-r--r-- | lib/sqlalchemy/dialects/postgresql/base.py | 56 |
1 files changed, 46 insertions, 10 deletions
diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index 4090da563..44e12f1ca 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -862,6 +862,7 @@ import re import datetime as dt +from sqlalchemy.sql import elements from ... import sql, schema, exc, util from ...engine import default, reflection from ...sql import compiler, expression @@ -1499,17 +1500,52 @@ class PGCompiler(compiler.SQLCompiler): target_text = self._on_conflict_target(on_conflict, **kw) action_set_ops = [] - for k, v in clause.update_values_to_set: - key_text = ( - self.preparer.quote(k) - if isinstance(k, util.string_types) - else self.process(k, use_schema=False) - ) - value_text = self.process( - v, - use_schema=False + + set_parameters = dict(clause.update_values_to_set) + # create a list of column assignment clauses as tuples + cols = self.statement.table.c + for c in cols: + col_key = c.key + if col_key in set_parameters: + value = set_parameters.pop(col_key) + if elements._is_literal(value): + value = elements.BindParameter( + None, value, type_=c.type + ) + + else: + if isinstance(value, elements.BindParameter) and \ + value.type._isnull: + value = value._clone() + value.type = c.type + value_text = self.process(value.self_group(), use_schema=False) + + key_text = ( + self.preparer.quote(col_key) + ) + action_set_ops.append('%s = %s' % (key_text, value_text)) + + # check for names that don't match columns + if set_parameters: + util.warn( + "Additional column names not matching " + "any column keys in table '%s': %s" % ( + self.statement.table.name, + (", ".join("'%s'" % c for c in set_parameters)) + ) ) - action_set_ops.append('%s = %s' % (key_text, value_text)) + for k, v in set_parameters.items(): + key_text = ( + self.preparer.quote(k) + if isinstance(k, util.string_types) + else self.process(k, use_schema=False) + ) + value_text = self.process( + elements._literal_as_binds(v), + use_schema=False + ) + action_set_ops.append('%s = %s' % (key_text, value_text)) + action_text = ', '.join(action_set_ops) if clause.update_whereclause is not None: action_text += ' WHERE %s' % \ |