diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2012-12-08 14:25:42 -0500 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2012-12-08 14:25:42 -0500 |
commit | 927b9859834096dd77182f935ff611351407f0dc (patch) | |
tree | d73e3495677628a8394f47a6db7c396d1aea97f9 /lib/sqlalchemy/sql/compiler.py | |
parent | 1ee4736beaadeb9053f8886503b64ee04fa4b557 (diff) | |
download | sqlalchemy-927b9859834096dd77182f935ff611351407f0dc.tar.gz |
- multivalued inserts, [ticket:2623]
- update "not supported" messages for empty inserts, mutlivalue inserts
- rework the ValuesBase approach for multiple value sets so that stmt.parameters
does store a list for multiple values; the _has_multiple_parameters flag now indicates
which of the two modes the statement is within. it now raises exceptions if a subsequent
call to values() attempts to call a ValuesBase with one mode in the style of the other
mode; that is, you can't switch a single- or multi- valued ValuesBase to the other mode,
and also if a multiple value is passed simultaneously with a kwargs set.
Added tests for these error conditions
- Calling values() multiple times in multivalue mode now extends the parameter list to
include the new parameter sets.
- add error/test if multiple *args were passed to ValuesBase.values()
- rework the compiler approach for multivalue inserts, back to where
_get_colparams() returns the same list of (column, value) as before, thereby
maintaining the identical number of append() and other calls when multivalue
is not enabled. In the case of multivalue, it makes a last-minute switch to return
a list of lists instead of the single list. As it constructs the additional lists, the inline
defaults and other calculated default parameters of the first parameter
set are copied into the newly generated lists so that these features continue
to function for a multivalue insert. Multivalue inserts now add no additional
function calls to the compilation for regular insert constructs.
- parameter lists for multivalue inserts now includes an integer index for all
parameter sets.
- add detailed documentation for ValuesBase.values(), including careful wording
to describe the difference between multiple values and an executemany() call.
- add a test for multivalue insert + returning - it works !
- remove the very old/never used "postgresql_returning"/"firebird_returning" flags.
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 175 |
1 files changed, 105 insertions, 70 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index b7dc03414..1e8bc3760 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -1275,19 +1275,26 @@ class SQLCompiler(engine.Compiled): def visit_insert(self, insert_stmt, **kw): self.isinsert = True - cols, params = self._get_colparams(insert_stmt) + colparams = self._get_colparams(insert_stmt) - if not cols and \ + if not colparams and \ not self.dialect.supports_default_values and \ not self.dialect.supports_empty_insert: - raise exc.CompileError("The version of %s you are using does " - "not support empty inserts." % + raise exc.CompileError("The '%s' dialect with current database " + "version settings does not support empty " + "inserts." % self.dialect.name) - if insert_stmt.multi_parameters and not self.dialect.supports_multirow_insert: - raise exc.CompileError("The version of %s you are using does " - "not support multirow inserts." % + if insert_stmt._has_multi_parameters: + if not self.dialect.supports_multirow_insert: + raise exc.CompileError("The '%s' dialect with current database " + "version settings does not support " + "in-place multirow inserts." % self.dialect.name) + colparams_single = colparams[0] + else: + colparams_single = colparams + preparer = self.preparer supports_default_values = self.dialect.supports_default_values @@ -1318,9 +1325,9 @@ class SQLCompiler(engine.Compiled): text += table_text - if cols or not supports_default_values: - text += " (%s)" % ', '.join([preparer.format_column(c) - for c in cols]) + if colparams_single or not supports_default_values: + text += " (%s)" % ', '.join([preparer.format_column(c[0]) + for c in colparams_single]) if self.returning or insert_stmt._returning: self.returning = self.returning or insert_stmt._returning @@ -1330,14 +1337,20 @@ class SQLCompiler(engine.Compiled): if self.returning_precedes_values: text += " " + returning_clause - if not cols and supports_default_values: + if not colparams and supports_default_values: text += " DEFAULT VALUES" + elif insert_stmt._has_multi_parameters: + text += " VALUES %s" % ( + ", ".join( + "(%s)" % ( + ', '.join(c[1] for c in colparam_set) + ) + for colparam_set in colparams + ) + ) else: - values = [] - for row in params: - values.append('(%s)' % ', '.join(row)) - text += " VALUES %s" % \ - ', '.join(values) + text += " VALUES (%s)" % \ + ', '.join([c[1] for c in colparams]) if self.returning and not self.returning_precedes_values: text += " " + returning_clause @@ -1381,7 +1394,7 @@ class SQLCompiler(engine.Compiled): extra_froms = update_stmt._extra_froms - cols, params = self._get_colparams(update_stmt, extra_froms) + colparams = self._get_colparams(update_stmt, extra_froms) text = "UPDATE " @@ -1414,13 +1427,10 @@ class SQLCompiler(engine.Compiled): text += ' SET ' include_table = extra_froms and \ self.render_table_with_column_in_update_from - colparams = [] - if params: - colparams = zip(cols, params[0]) text += ', '.join( - c._compiler_dispatch(self, + c[0]._compiler_dispatch(self, include_table=include_table) + - '=' + p for c, p in colparams + '=' + c[1] for c in colparams ) if update_stmt._returning: @@ -1480,12 +1490,19 @@ class SQLCompiler(engine.Compiled): # no parameters in the statement, no parameters in the # compiled params - return binds for all columns if self.column_keys is None and stmt.parameters is None: - values = [self._create_crud_bind_param(c, None, required=True) - for c in stmt.table.columns] - return list(stmt.table.columns), [values] + return [ + (c, self._create_crud_bind_param(c, + None, required=True)) + for c in stmt.table.columns + ] required = object() + if stmt._has_multi_parameters: + stmt_parameters = stmt.parameters[0] + else: + stmt_parameters = stmt.parameters + # if we have statement parameters - set defaults in the # compiled params if self.column_keys is None: @@ -1493,15 +1510,14 @@ class SQLCompiler(engine.Compiled): else: parameters = dict((sql._column_as_key(key), required) for key in self.column_keys - if not stmt.parameters or - key not in stmt.parameters) + if not stmt_parameters or + key not in stmt_parameters) # create a list of column assignment clauses as tuples - columns = [] values = [] - if stmt.parameters is not None: - for k, v in stmt.parameters.iteritems(): + if stmt_parameters is not None: + for k, v in stmt_parameters.iteritems(): colkey = sql._column_as_key(k) if colkey is not None: parameters.setdefault(colkey, v) @@ -1514,8 +1530,7 @@ class SQLCompiler(engine.Compiled): else: v = self.process(v.self_group()) - columns.append(k) - values.append(v) + values.append((k, v)) need_pks = self.isinsert and \ not self.inline and \ @@ -1530,10 +1545,10 @@ class SQLCompiler(engine.Compiled): check_columns = {} # special logic that only occurs for multi-table UPDATE # statements - if extra_tables and stmt.parameters: + if extra_tables and stmt_parameters: normalized_params = dict( (sql._clause_element_as_expr(c), param) - for c, param in stmt.parameters.items() + for c, param in stmt_parameters.items() ) assert self.isupdate affected_tables = set() @@ -1549,8 +1564,7 @@ class SQLCompiler(engine.Compiled): else: self.postfetch.append(c) value = self.process(value.self_group()) - columns.append(c) - values.append(value) + values.append((c, value)) # determine tables which are actually # to be updated - process onupdate and # server_onupdate for these @@ -1560,12 +1574,14 @@ class SQLCompiler(engine.Compiled): continue elif c.onupdate is not None and not c.onupdate.is_sequence: if c.onupdate.is_clause_element: - columns.apppend(c) - values.append(self.process(c.onupdate.arg.self_group())) + values.append( + (c, self.process(c.onupdate.arg.self_group())) + ) self.postfetch.append(c) else: - columns.append(c) - values.append(self._create_crud_bind_param(c, None)) + values.append( + (c, self._create_crud_bind_param(c, None)) + ) self.prefetch.append(c) elif c.server_onupdate is not None: self.postfetch.append(c) @@ -1578,15 +1594,18 @@ class SQLCompiler(engine.Compiled): value = parameters.pop(c.key) if sql._is_literal(value): value = self._create_crud_bind_param( - c, value, required=value is required) + c, value, required=value is required, + name=c.key + if not stmt._has_multi_parameters + else "%s_0" % c.key + ) elif c.primary_key and implicit_returning: self.returning.append(c) value = self.process(value.self_group()) else: self.postfetch.append(c) value = self.process(value.self_group()) - columns.append(c) - values.append(value) + values.append((c, value)) elif self.isinsert: if c.primary_key and \ @@ -1604,16 +1623,18 @@ class SQLCompiler(engine.Compiled): (not c.default.optional or \ not self.dialect.sequences_optional): proc = self.process(c.default) - columns.append(c) - values.append(proc) + values.append((c, proc)) self.returning.append(c) elif c.default.is_clause_element: - columns.append(c) - values.append(self.process(c.default.arg.self_group())) + values.append( + (c, + self.process(c.default.arg.self_group())) + ) self.returning.append(c) else: - columns.append(c) - values.append(self._create_crud_bind_param(c, None)) + values.append( + (c, self._create_crud_bind_param(c, None)) + ) self.prefetch.append(c) else: self.returning.append(c) @@ -1624,8 +1645,10 @@ class SQLCompiler(engine.Compiled): self.dialect.preexecute_autoincrement_sequences ): - columns.append(c) - values.append(self._create_crud_bind_param(c, None)) + values.append( + (c, self._create_crud_bind_param(c, None)) + ) + self.prefetch.append(c) elif c.default is not None: @@ -1634,20 +1657,21 @@ class SQLCompiler(engine.Compiled): (not c.default.optional or \ not self.dialect.sequences_optional): proc = self.process(c.default) - columns.append(c) - values.append(proc) + values.append((c, proc)) if not c.primary_key: self.postfetch.append(c) elif c.default.is_clause_element: - columns.append(c) - values.append(self.process(c.default.arg.self_group())) + values.append( + (c, self.process(c.default.arg.self_group())) + ) if not c.primary_key: # dont add primary key column to postfetch self.postfetch.append(c) else: - columns.append(c) - values.append(self._create_crud_bind_param(c, None)) + values.append( + (c, self._create_crud_bind_param(c, None)) + ) self.prefetch.append(c) elif c.server_default is not None: if not c.primary_key: @@ -1656,17 +1680,19 @@ class SQLCompiler(engine.Compiled): elif self.isupdate: if c.onupdate is not None and not c.onupdate.is_sequence: if c.onupdate.is_clause_element: - columns.append(c) - values.append(self.process(c.onupdate.arg.self_group())) + values.append( + (c, self.process(c.onupdate.arg.self_group())) + ) self.postfetch.append(c) else: - columns.append(c) - values.append(self._create_crud_bind_param(c, None)) + values.append( + (c, self._create_crud_bind_param(c, None)) + ) self.prefetch.append(c) elif c.server_onupdate is not None: self.postfetch.append(c) - if parameters and stmt.parameters: + if parameters and stmt_parameters: check = set(parameters).intersection( sql._column_as_key(k) for k in stmt.parameters ).difference(check_columns) @@ -1676,17 +1702,26 @@ class SQLCompiler(engine.Compiled): (", ".join(check)) ) - if values: + if stmt._has_multi_parameters: + values_0 = values values = [values] - for i, row in enumerate(stmt.multi_parameters): - r = [] - for c in columns: - r.append(self._create_crud_bind_param(c, row[c.key], - name=c.key + str(i))) - values.append(r) + values.extend( + [ + ( + c, + self._create_crud_bind_param( + c, row[c.key], + name="%s_%d" % (c.key, i + 1) + ) + if c.key in row else param + ) + for (c, param) in values_0 + ] + for i, row in enumerate(stmt.parameters[1:]) + ) - return columns, values + return values def visit_delete(self, delete_stmt, **kw): self.stack.append({'from': set([delete_stmt.table])}) |