diff options
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 175 |
1 files changed, 105 insertions, 70 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index b7dc03414..1e8bc3760 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -1275,19 +1275,26 @@ class SQLCompiler(engine.Compiled): def visit_insert(self, insert_stmt, **kw): self.isinsert = True - cols, params = self._get_colparams(insert_stmt) + colparams = self._get_colparams(insert_stmt) - if not cols and \ + if not colparams and \ not self.dialect.supports_default_values and \ not self.dialect.supports_empty_insert: - raise exc.CompileError("The version of %s you are using does " - "not support empty inserts." % + raise exc.CompileError("The '%s' dialect with current database " + "version settings does not support empty " + "inserts." % self.dialect.name) - if insert_stmt.multi_parameters and not self.dialect.supports_multirow_insert: - raise exc.CompileError("The version of %s you are using does " - "not support multirow inserts." % + if insert_stmt._has_multi_parameters: + if not self.dialect.supports_multirow_insert: + raise exc.CompileError("The '%s' dialect with current database " + "version settings does not support " + "in-place multirow inserts." % self.dialect.name) + colparams_single = colparams[0] + else: + colparams_single = colparams + preparer = self.preparer supports_default_values = self.dialect.supports_default_values @@ -1318,9 +1325,9 @@ class SQLCompiler(engine.Compiled): text += table_text - if cols or not supports_default_values: - text += " (%s)" % ', '.join([preparer.format_column(c) - for c in cols]) + if colparams_single or not supports_default_values: + text += " (%s)" % ', '.join([preparer.format_column(c[0]) + for c in colparams_single]) if self.returning or insert_stmt._returning: self.returning = self.returning or insert_stmt._returning @@ -1330,14 +1337,20 @@ class SQLCompiler(engine.Compiled): if self.returning_precedes_values: text += " " + returning_clause - if not cols and supports_default_values: + if not colparams and supports_default_values: text += " DEFAULT VALUES" + elif insert_stmt._has_multi_parameters: + text += " VALUES %s" % ( + ", ".join( + "(%s)" % ( + ', '.join(c[1] for c in colparam_set) + ) + for colparam_set in colparams + ) + ) else: - values = [] - for row in params: - values.append('(%s)' % ', '.join(row)) - text += " VALUES %s" % \ - ', '.join(values) + text += " VALUES (%s)" % \ + ', '.join([c[1] for c in colparams]) if self.returning and not self.returning_precedes_values: text += " " + returning_clause @@ -1381,7 +1394,7 @@ class SQLCompiler(engine.Compiled): extra_froms = update_stmt._extra_froms - cols, params = self._get_colparams(update_stmt, extra_froms) + colparams = self._get_colparams(update_stmt, extra_froms) text = "UPDATE " @@ -1414,13 +1427,10 @@ class SQLCompiler(engine.Compiled): text += ' SET ' include_table = extra_froms and \ self.render_table_with_column_in_update_from - colparams = [] - if params: - colparams = zip(cols, params[0]) text += ', '.join( - c._compiler_dispatch(self, + c[0]._compiler_dispatch(self, include_table=include_table) + - '=' + p for c, p in colparams + '=' + c[1] for c in colparams ) if update_stmt._returning: @@ -1480,12 +1490,19 @@ class SQLCompiler(engine.Compiled): # no parameters in the statement, no parameters in the # compiled params - return binds for all columns if self.column_keys is None and stmt.parameters is None: - values = [self._create_crud_bind_param(c, None, required=True) - for c in stmt.table.columns] - return list(stmt.table.columns), [values] + return [ + (c, self._create_crud_bind_param(c, + None, required=True)) + for c in stmt.table.columns + ] required = object() + if stmt._has_multi_parameters: + stmt_parameters = stmt.parameters[0] + else: + stmt_parameters = stmt.parameters + # if we have statement parameters - set defaults in the # compiled params if self.column_keys is None: @@ -1493,15 +1510,14 @@ class SQLCompiler(engine.Compiled): else: parameters = dict((sql._column_as_key(key), required) for key in self.column_keys - if not stmt.parameters or - key not in stmt.parameters) + if not stmt_parameters or + key not in stmt_parameters) # create a list of column assignment clauses as tuples - columns = [] values = [] - if stmt.parameters is not None: - for k, v in stmt.parameters.iteritems(): + if stmt_parameters is not None: + for k, v in stmt_parameters.iteritems(): colkey = sql._column_as_key(k) if colkey is not None: parameters.setdefault(colkey, v) @@ -1514,8 +1530,7 @@ class SQLCompiler(engine.Compiled): else: v = self.process(v.self_group()) - columns.append(k) - values.append(v) + values.append((k, v)) need_pks = self.isinsert and \ not self.inline and \ @@ -1530,10 +1545,10 @@ class SQLCompiler(engine.Compiled): check_columns = {} # special logic that only occurs for multi-table UPDATE # statements - if extra_tables and stmt.parameters: + if extra_tables and stmt_parameters: normalized_params = dict( (sql._clause_element_as_expr(c), param) - for c, param in stmt.parameters.items() + for c, param in stmt_parameters.items() ) assert self.isupdate affected_tables = set() @@ -1549,8 +1564,7 @@ class SQLCompiler(engine.Compiled): else: self.postfetch.append(c) value = self.process(value.self_group()) - columns.append(c) - values.append(value) + values.append((c, value)) # determine tables which are actually # to be updated - process onupdate and # server_onupdate for these @@ -1560,12 +1574,14 @@ class SQLCompiler(engine.Compiled): continue elif c.onupdate is not None and not c.onupdate.is_sequence: if c.onupdate.is_clause_element: - columns.apppend(c) - values.append(self.process(c.onupdate.arg.self_group())) + values.append( + (c, self.process(c.onupdate.arg.self_group())) + ) self.postfetch.append(c) else: - columns.append(c) - values.append(self._create_crud_bind_param(c, None)) + values.append( + (c, self._create_crud_bind_param(c, None)) + ) self.prefetch.append(c) elif c.server_onupdate is not None: self.postfetch.append(c) @@ -1578,15 +1594,18 @@ class SQLCompiler(engine.Compiled): value = parameters.pop(c.key) if sql._is_literal(value): value = self._create_crud_bind_param( - c, value, required=value is required) + c, value, required=value is required, + name=c.key + if not stmt._has_multi_parameters + else "%s_0" % c.key + ) elif c.primary_key and implicit_returning: self.returning.append(c) value = self.process(value.self_group()) else: self.postfetch.append(c) value = self.process(value.self_group()) - columns.append(c) - values.append(value) + values.append((c, value)) elif self.isinsert: if c.primary_key and \ @@ -1604,16 +1623,18 @@ class SQLCompiler(engine.Compiled): (not c.default.optional or \ not self.dialect.sequences_optional): proc = self.process(c.default) - columns.append(c) - values.append(proc) + values.append((c, proc)) self.returning.append(c) elif c.default.is_clause_element: - columns.append(c) - values.append(self.process(c.default.arg.self_group())) + values.append( + (c, + self.process(c.default.arg.self_group())) + ) self.returning.append(c) else: - columns.append(c) - values.append(self._create_crud_bind_param(c, None)) + values.append( + (c, self._create_crud_bind_param(c, None)) + ) self.prefetch.append(c) else: self.returning.append(c) @@ -1624,8 +1645,10 @@ class SQLCompiler(engine.Compiled): self.dialect.preexecute_autoincrement_sequences ): - columns.append(c) - values.append(self._create_crud_bind_param(c, None)) + values.append( + (c, self._create_crud_bind_param(c, None)) + ) + self.prefetch.append(c) elif c.default is not None: @@ -1634,20 +1657,21 @@ class SQLCompiler(engine.Compiled): (not c.default.optional or \ not self.dialect.sequences_optional): proc = self.process(c.default) - columns.append(c) - values.append(proc) + values.append((c, proc)) if not c.primary_key: self.postfetch.append(c) elif c.default.is_clause_element: - columns.append(c) - values.append(self.process(c.default.arg.self_group())) + values.append( + (c, self.process(c.default.arg.self_group())) + ) if not c.primary_key: # dont add primary key column to postfetch self.postfetch.append(c) else: - columns.append(c) - values.append(self._create_crud_bind_param(c, None)) + values.append( + (c, self._create_crud_bind_param(c, None)) + ) self.prefetch.append(c) elif c.server_default is not None: if not c.primary_key: @@ -1656,17 +1680,19 @@ class SQLCompiler(engine.Compiled): elif self.isupdate: if c.onupdate is not None and not c.onupdate.is_sequence: if c.onupdate.is_clause_element: - columns.append(c) - values.append(self.process(c.onupdate.arg.self_group())) + values.append( + (c, self.process(c.onupdate.arg.self_group())) + ) self.postfetch.append(c) else: - columns.append(c) - values.append(self._create_crud_bind_param(c, None)) + values.append( + (c, self._create_crud_bind_param(c, None)) + ) self.prefetch.append(c) elif c.server_onupdate is not None: self.postfetch.append(c) - if parameters and stmt.parameters: + if parameters and stmt_parameters: check = set(parameters).intersection( sql._column_as_key(k) for k in stmt.parameters ).difference(check_columns) @@ -1676,17 +1702,26 @@ class SQLCompiler(engine.Compiled): (", ".join(check)) ) - if values: + if stmt._has_multi_parameters: + values_0 = values values = [values] - for i, row in enumerate(stmt.multi_parameters): - r = [] - for c in columns: - r.append(self._create_crud_bind_param(c, row[c.key], - name=c.key + str(i))) - values.append(r) + values.extend( + [ + ( + c, + self._create_crud_bind_param( + c, row[c.key], + name="%s_%d" % (c.key, i + 1) + ) + if c.key in row else param + ) + for (c, param) in values_0 + ] + for i, row in enumerate(stmt.parameters[1:]) + ) - return columns, values + return values def visit_delete(self, delete_stmt, **kw): self.stack.append({'from': set([delete_stmt.table])}) |