summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/compiler.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r--lib/sqlalchemy/sql/compiler.py175
1 files changed, 105 insertions, 70 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index b7dc03414..1e8bc3760 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -1275,19 +1275,26 @@ class SQLCompiler(engine.Compiled):
def visit_insert(self, insert_stmt, **kw):
self.isinsert = True
- cols, params = self._get_colparams(insert_stmt)
+ colparams = self._get_colparams(insert_stmt)
- if not cols and \
+ if not colparams and \
not self.dialect.supports_default_values and \
not self.dialect.supports_empty_insert:
- raise exc.CompileError("The version of %s you are using does "
- "not support empty inserts." %
+ raise exc.CompileError("The '%s' dialect with current database "
+ "version settings does not support empty "
+ "inserts." %
self.dialect.name)
- if insert_stmt.multi_parameters and not self.dialect.supports_multirow_insert:
- raise exc.CompileError("The version of %s you are using does "
- "not support multirow inserts." %
+ if insert_stmt._has_multi_parameters:
+ if not self.dialect.supports_multirow_insert:
+ raise exc.CompileError("The '%s' dialect with current database "
+ "version settings does not support "
+ "in-place multirow inserts." %
self.dialect.name)
+ colparams_single = colparams[0]
+ else:
+ colparams_single = colparams
+
preparer = self.preparer
supports_default_values = self.dialect.supports_default_values
@@ -1318,9 +1325,9 @@ class SQLCompiler(engine.Compiled):
text += table_text
- if cols or not supports_default_values:
- text += " (%s)" % ', '.join([preparer.format_column(c)
- for c in cols])
+ if colparams_single or not supports_default_values:
+ text += " (%s)" % ', '.join([preparer.format_column(c[0])
+ for c in colparams_single])
if self.returning or insert_stmt._returning:
self.returning = self.returning or insert_stmt._returning
@@ -1330,14 +1337,20 @@ class SQLCompiler(engine.Compiled):
if self.returning_precedes_values:
text += " " + returning_clause
- if not cols and supports_default_values:
+ if not colparams and supports_default_values:
text += " DEFAULT VALUES"
+ elif insert_stmt._has_multi_parameters:
+ text += " VALUES %s" % (
+ ", ".join(
+ "(%s)" % (
+ ', '.join(c[1] for c in colparam_set)
+ )
+ for colparam_set in colparams
+ )
+ )
else:
- values = []
- for row in params:
- values.append('(%s)' % ', '.join(row))
- text += " VALUES %s" % \
- ', '.join(values)
+ text += " VALUES (%s)" % \
+ ', '.join([c[1] for c in colparams])
if self.returning and not self.returning_precedes_values:
text += " " + returning_clause
@@ -1381,7 +1394,7 @@ class SQLCompiler(engine.Compiled):
extra_froms = update_stmt._extra_froms
- cols, params = self._get_colparams(update_stmt, extra_froms)
+ colparams = self._get_colparams(update_stmt, extra_froms)
text = "UPDATE "
@@ -1414,13 +1427,10 @@ class SQLCompiler(engine.Compiled):
text += ' SET '
include_table = extra_froms and \
self.render_table_with_column_in_update_from
- colparams = []
- if params:
- colparams = zip(cols, params[0])
text += ', '.join(
- c._compiler_dispatch(self,
+ c[0]._compiler_dispatch(self,
include_table=include_table) +
- '=' + p for c, p in colparams
+ '=' + c[1] for c in colparams
)
if update_stmt._returning:
@@ -1480,12 +1490,19 @@ class SQLCompiler(engine.Compiled):
# no parameters in the statement, no parameters in the
# compiled params - return binds for all columns
if self.column_keys is None and stmt.parameters is None:
- values = [self._create_crud_bind_param(c, None, required=True)
- for c in stmt.table.columns]
- return list(stmt.table.columns), [values]
+ return [
+ (c, self._create_crud_bind_param(c,
+ None, required=True))
+ for c in stmt.table.columns
+ ]
required = object()
+ if stmt._has_multi_parameters:
+ stmt_parameters = stmt.parameters[0]
+ else:
+ stmt_parameters = stmt.parameters
+
# if we have statement parameters - set defaults in the
# compiled params
if self.column_keys is None:
@@ -1493,15 +1510,14 @@ class SQLCompiler(engine.Compiled):
else:
parameters = dict((sql._column_as_key(key), required)
for key in self.column_keys
- if not stmt.parameters or
- key not in stmt.parameters)
+ if not stmt_parameters or
+ key not in stmt_parameters)
# create a list of column assignment clauses as tuples
- columns = []
values = []
- if stmt.parameters is not None:
- for k, v in stmt.parameters.iteritems():
+ if stmt_parameters is not None:
+ for k, v in stmt_parameters.iteritems():
colkey = sql._column_as_key(k)
if colkey is not None:
parameters.setdefault(colkey, v)
@@ -1514,8 +1530,7 @@ class SQLCompiler(engine.Compiled):
else:
v = self.process(v.self_group())
- columns.append(k)
- values.append(v)
+ values.append((k, v))
need_pks = self.isinsert and \
not self.inline and \
@@ -1530,10 +1545,10 @@ class SQLCompiler(engine.Compiled):
check_columns = {}
# special logic that only occurs for multi-table UPDATE
# statements
- if extra_tables and stmt.parameters:
+ if extra_tables and stmt_parameters:
normalized_params = dict(
(sql._clause_element_as_expr(c), param)
- for c, param in stmt.parameters.items()
+ for c, param in stmt_parameters.items()
)
assert self.isupdate
affected_tables = set()
@@ -1549,8 +1564,7 @@ class SQLCompiler(engine.Compiled):
else:
self.postfetch.append(c)
value = self.process(value.self_group())
- columns.append(c)
- values.append(value)
+ values.append((c, value))
# determine tables which are actually
# to be updated - process onupdate and
# server_onupdate for these
@@ -1560,12 +1574,14 @@ class SQLCompiler(engine.Compiled):
continue
elif c.onupdate is not None and not c.onupdate.is_sequence:
if c.onupdate.is_clause_element:
- columns.apppend(c)
- values.append(self.process(c.onupdate.arg.self_group()))
+ values.append(
+ (c, self.process(c.onupdate.arg.self_group()))
+ )
self.postfetch.append(c)
else:
- columns.append(c)
- values.append(self._create_crud_bind_param(c, None))
+ values.append(
+ (c, self._create_crud_bind_param(c, None))
+ )
self.prefetch.append(c)
elif c.server_onupdate is not None:
self.postfetch.append(c)
@@ -1578,15 +1594,18 @@ class SQLCompiler(engine.Compiled):
value = parameters.pop(c.key)
if sql._is_literal(value):
value = self._create_crud_bind_param(
- c, value, required=value is required)
+ c, value, required=value is required,
+ name=c.key
+ if not stmt._has_multi_parameters
+ else "%s_0" % c.key
+ )
elif c.primary_key and implicit_returning:
self.returning.append(c)
value = self.process(value.self_group())
else:
self.postfetch.append(c)
value = self.process(value.self_group())
- columns.append(c)
- values.append(value)
+ values.append((c, value))
elif self.isinsert:
if c.primary_key and \
@@ -1604,16 +1623,18 @@ class SQLCompiler(engine.Compiled):
(not c.default.optional or \
not self.dialect.sequences_optional):
proc = self.process(c.default)
- columns.append(c)
- values.append(proc)
+ values.append((c, proc))
self.returning.append(c)
elif c.default.is_clause_element:
- columns.append(c)
- values.append(self.process(c.default.arg.self_group()))
+ values.append(
+ (c,
+ self.process(c.default.arg.self_group()))
+ )
self.returning.append(c)
else:
- columns.append(c)
- values.append(self._create_crud_bind_param(c, None))
+ values.append(
+ (c, self._create_crud_bind_param(c, None))
+ )
self.prefetch.append(c)
else:
self.returning.append(c)
@@ -1624,8 +1645,10 @@ class SQLCompiler(engine.Compiled):
self.dialect.preexecute_autoincrement_sequences
):
- columns.append(c)
- values.append(self._create_crud_bind_param(c, None))
+ values.append(
+ (c, self._create_crud_bind_param(c, None))
+ )
+
self.prefetch.append(c)
elif c.default is not None:
@@ -1634,20 +1657,21 @@ class SQLCompiler(engine.Compiled):
(not c.default.optional or \
not self.dialect.sequences_optional):
proc = self.process(c.default)
- columns.append(c)
- values.append(proc)
+ values.append((c, proc))
if not c.primary_key:
self.postfetch.append(c)
elif c.default.is_clause_element:
- columns.append(c)
- values.append(self.process(c.default.arg.self_group()))
+ values.append(
+ (c, self.process(c.default.arg.self_group()))
+ )
if not c.primary_key:
# dont add primary key column to postfetch
self.postfetch.append(c)
else:
- columns.append(c)
- values.append(self._create_crud_bind_param(c, None))
+ values.append(
+ (c, self._create_crud_bind_param(c, None))
+ )
self.prefetch.append(c)
elif c.server_default is not None:
if not c.primary_key:
@@ -1656,17 +1680,19 @@ class SQLCompiler(engine.Compiled):
elif self.isupdate:
if c.onupdate is not None and not c.onupdate.is_sequence:
if c.onupdate.is_clause_element:
- columns.append(c)
- values.append(self.process(c.onupdate.arg.self_group()))
+ values.append(
+ (c, self.process(c.onupdate.arg.self_group()))
+ )
self.postfetch.append(c)
else:
- columns.append(c)
- values.append(self._create_crud_bind_param(c, None))
+ values.append(
+ (c, self._create_crud_bind_param(c, None))
+ )
self.prefetch.append(c)
elif c.server_onupdate is not None:
self.postfetch.append(c)
- if parameters and stmt.parameters:
+ if parameters and stmt_parameters:
check = set(parameters).intersection(
sql._column_as_key(k) for k in stmt.parameters
).difference(check_columns)
@@ -1676,17 +1702,26 @@ class SQLCompiler(engine.Compiled):
(", ".join(check))
)
- if values:
+ if stmt._has_multi_parameters:
+ values_0 = values
values = [values]
- for i, row in enumerate(stmt.multi_parameters):
- r = []
- for c in columns:
- r.append(self._create_crud_bind_param(c, row[c.key],
- name=c.key + str(i)))
- values.append(r)
+ values.extend(
+ [
+ (
+ c,
+ self._create_crud_bind_param(
+ c, row[c.key],
+ name="%s_%d" % (c.key, i + 1)
+ )
+ if c.key in row else param
+ )
+ for (c, param) in values_0
+ ]
+ for i, row in enumerate(stmt.parameters[1:])
+ )
- return columns, values
+ return values
def visit_delete(self, delete_stmt, **kw):
self.stack.append({'from': set([delete_stmt.table])})