summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/compiler.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2012-12-08 14:25:42 -0500
committerMike Bayer <mike_mp@zzzcomputing.com>2012-12-08 14:25:42 -0500
commit927b9859834096dd77182f935ff611351407f0dc (patch)
treed73e3495677628a8394f47a6db7c396d1aea97f9 /lib/sqlalchemy/sql/compiler.py
parent1ee4736beaadeb9053f8886503b64ee04fa4b557 (diff)
downloadsqlalchemy-927b9859834096dd77182f935ff611351407f0dc.tar.gz
- multivalued inserts, [ticket:2623]
- update "not supported" messages for empty inserts, mutlivalue inserts - rework the ValuesBase approach for multiple value sets so that stmt.parameters does store a list for multiple values; the _has_multiple_parameters flag now indicates which of the two modes the statement is within. it now raises exceptions if a subsequent call to values() attempts to call a ValuesBase with one mode in the style of the other mode; that is, you can't switch a single- or multi- valued ValuesBase to the other mode, and also if a multiple value is passed simultaneously with a kwargs set. Added tests for these error conditions - Calling values() multiple times in multivalue mode now extends the parameter list to include the new parameter sets. - add error/test if multiple *args were passed to ValuesBase.values() - rework the compiler approach for multivalue inserts, back to where _get_colparams() returns the same list of (column, value) as before, thereby maintaining the identical number of append() and other calls when multivalue is not enabled. In the case of multivalue, it makes a last-minute switch to return a list of lists instead of the single list. As it constructs the additional lists, the inline defaults and other calculated default parameters of the first parameter set are copied into the newly generated lists so that these features continue to function for a multivalue insert. Multivalue inserts now add no additional function calls to the compilation for regular insert constructs. - parameter lists for multivalue inserts now includes an integer index for all parameter sets. - add detailed documentation for ValuesBase.values(), including careful wording to describe the difference between multiple values and an executemany() call. - add a test for multivalue insert + returning - it works ! - remove the very old/never used "postgresql_returning"/"firebird_returning" flags.
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r--lib/sqlalchemy/sql/compiler.py175
1 files changed, 105 insertions, 70 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index b7dc03414..1e8bc3760 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -1275,19 +1275,26 @@ class SQLCompiler(engine.Compiled):
def visit_insert(self, insert_stmt, **kw):
self.isinsert = True
- cols, params = self._get_colparams(insert_stmt)
+ colparams = self._get_colparams(insert_stmt)
- if not cols and \
+ if not colparams and \
not self.dialect.supports_default_values and \
not self.dialect.supports_empty_insert:
- raise exc.CompileError("The version of %s you are using does "
- "not support empty inserts." %
+ raise exc.CompileError("The '%s' dialect with current database "
+ "version settings does not support empty "
+ "inserts." %
self.dialect.name)
- if insert_stmt.multi_parameters and not self.dialect.supports_multirow_insert:
- raise exc.CompileError("The version of %s you are using does "
- "not support multirow inserts." %
+ if insert_stmt._has_multi_parameters:
+ if not self.dialect.supports_multirow_insert:
+ raise exc.CompileError("The '%s' dialect with current database "
+ "version settings does not support "
+ "in-place multirow inserts." %
self.dialect.name)
+ colparams_single = colparams[0]
+ else:
+ colparams_single = colparams
+
preparer = self.preparer
supports_default_values = self.dialect.supports_default_values
@@ -1318,9 +1325,9 @@ class SQLCompiler(engine.Compiled):
text += table_text
- if cols or not supports_default_values:
- text += " (%s)" % ', '.join([preparer.format_column(c)
- for c in cols])
+ if colparams_single or not supports_default_values:
+ text += " (%s)" % ', '.join([preparer.format_column(c[0])
+ for c in colparams_single])
if self.returning or insert_stmt._returning:
self.returning = self.returning or insert_stmt._returning
@@ -1330,14 +1337,20 @@ class SQLCompiler(engine.Compiled):
if self.returning_precedes_values:
text += " " + returning_clause
- if not cols and supports_default_values:
+ if not colparams and supports_default_values:
text += " DEFAULT VALUES"
+ elif insert_stmt._has_multi_parameters:
+ text += " VALUES %s" % (
+ ", ".join(
+ "(%s)" % (
+ ', '.join(c[1] for c in colparam_set)
+ )
+ for colparam_set in colparams
+ )
+ )
else:
- values = []
- for row in params:
- values.append('(%s)' % ', '.join(row))
- text += " VALUES %s" % \
- ', '.join(values)
+ text += " VALUES (%s)" % \
+ ', '.join([c[1] for c in colparams])
if self.returning and not self.returning_precedes_values:
text += " " + returning_clause
@@ -1381,7 +1394,7 @@ class SQLCompiler(engine.Compiled):
extra_froms = update_stmt._extra_froms
- cols, params = self._get_colparams(update_stmt, extra_froms)
+ colparams = self._get_colparams(update_stmt, extra_froms)
text = "UPDATE "
@@ -1414,13 +1427,10 @@ class SQLCompiler(engine.Compiled):
text += ' SET '
include_table = extra_froms and \
self.render_table_with_column_in_update_from
- colparams = []
- if params:
- colparams = zip(cols, params[0])
text += ', '.join(
- c._compiler_dispatch(self,
+ c[0]._compiler_dispatch(self,
include_table=include_table) +
- '=' + p for c, p in colparams
+ '=' + c[1] for c in colparams
)
if update_stmt._returning:
@@ -1480,12 +1490,19 @@ class SQLCompiler(engine.Compiled):
# no parameters in the statement, no parameters in the
# compiled params - return binds for all columns
if self.column_keys is None and stmt.parameters is None:
- values = [self._create_crud_bind_param(c, None, required=True)
- for c in stmt.table.columns]
- return list(stmt.table.columns), [values]
+ return [
+ (c, self._create_crud_bind_param(c,
+ None, required=True))
+ for c in stmt.table.columns
+ ]
required = object()
+ if stmt._has_multi_parameters:
+ stmt_parameters = stmt.parameters[0]
+ else:
+ stmt_parameters = stmt.parameters
+
# if we have statement parameters - set defaults in the
# compiled params
if self.column_keys is None:
@@ -1493,15 +1510,14 @@ class SQLCompiler(engine.Compiled):
else:
parameters = dict((sql._column_as_key(key), required)
for key in self.column_keys
- if not stmt.parameters or
- key not in stmt.parameters)
+ if not stmt_parameters or
+ key not in stmt_parameters)
# create a list of column assignment clauses as tuples
- columns = []
values = []
- if stmt.parameters is not None:
- for k, v in stmt.parameters.iteritems():
+ if stmt_parameters is not None:
+ for k, v in stmt_parameters.iteritems():
colkey = sql._column_as_key(k)
if colkey is not None:
parameters.setdefault(colkey, v)
@@ -1514,8 +1530,7 @@ class SQLCompiler(engine.Compiled):
else:
v = self.process(v.self_group())
- columns.append(k)
- values.append(v)
+ values.append((k, v))
need_pks = self.isinsert and \
not self.inline and \
@@ -1530,10 +1545,10 @@ class SQLCompiler(engine.Compiled):
check_columns = {}
# special logic that only occurs for multi-table UPDATE
# statements
- if extra_tables and stmt.parameters:
+ if extra_tables and stmt_parameters:
normalized_params = dict(
(sql._clause_element_as_expr(c), param)
- for c, param in stmt.parameters.items()
+ for c, param in stmt_parameters.items()
)
assert self.isupdate
affected_tables = set()
@@ -1549,8 +1564,7 @@ class SQLCompiler(engine.Compiled):
else:
self.postfetch.append(c)
value = self.process(value.self_group())
- columns.append(c)
- values.append(value)
+ values.append((c, value))
# determine tables which are actually
# to be updated - process onupdate and
# server_onupdate for these
@@ -1560,12 +1574,14 @@ class SQLCompiler(engine.Compiled):
continue
elif c.onupdate is not None and not c.onupdate.is_sequence:
if c.onupdate.is_clause_element:
- columns.apppend(c)
- values.append(self.process(c.onupdate.arg.self_group()))
+ values.append(
+ (c, self.process(c.onupdate.arg.self_group()))
+ )
self.postfetch.append(c)
else:
- columns.append(c)
- values.append(self._create_crud_bind_param(c, None))
+ values.append(
+ (c, self._create_crud_bind_param(c, None))
+ )
self.prefetch.append(c)
elif c.server_onupdate is not None:
self.postfetch.append(c)
@@ -1578,15 +1594,18 @@ class SQLCompiler(engine.Compiled):
value = parameters.pop(c.key)
if sql._is_literal(value):
value = self._create_crud_bind_param(
- c, value, required=value is required)
+ c, value, required=value is required,
+ name=c.key
+ if not stmt._has_multi_parameters
+ else "%s_0" % c.key
+ )
elif c.primary_key and implicit_returning:
self.returning.append(c)
value = self.process(value.self_group())
else:
self.postfetch.append(c)
value = self.process(value.self_group())
- columns.append(c)
- values.append(value)
+ values.append((c, value))
elif self.isinsert:
if c.primary_key and \
@@ -1604,16 +1623,18 @@ class SQLCompiler(engine.Compiled):
(not c.default.optional or \
not self.dialect.sequences_optional):
proc = self.process(c.default)
- columns.append(c)
- values.append(proc)
+ values.append((c, proc))
self.returning.append(c)
elif c.default.is_clause_element:
- columns.append(c)
- values.append(self.process(c.default.arg.self_group()))
+ values.append(
+ (c,
+ self.process(c.default.arg.self_group()))
+ )
self.returning.append(c)
else:
- columns.append(c)
- values.append(self._create_crud_bind_param(c, None))
+ values.append(
+ (c, self._create_crud_bind_param(c, None))
+ )
self.prefetch.append(c)
else:
self.returning.append(c)
@@ -1624,8 +1645,10 @@ class SQLCompiler(engine.Compiled):
self.dialect.preexecute_autoincrement_sequences
):
- columns.append(c)
- values.append(self._create_crud_bind_param(c, None))
+ values.append(
+ (c, self._create_crud_bind_param(c, None))
+ )
+
self.prefetch.append(c)
elif c.default is not None:
@@ -1634,20 +1657,21 @@ class SQLCompiler(engine.Compiled):
(not c.default.optional or \
not self.dialect.sequences_optional):
proc = self.process(c.default)
- columns.append(c)
- values.append(proc)
+ values.append((c, proc))
if not c.primary_key:
self.postfetch.append(c)
elif c.default.is_clause_element:
- columns.append(c)
- values.append(self.process(c.default.arg.self_group()))
+ values.append(
+ (c, self.process(c.default.arg.self_group()))
+ )
if not c.primary_key:
# dont add primary key column to postfetch
self.postfetch.append(c)
else:
- columns.append(c)
- values.append(self._create_crud_bind_param(c, None))
+ values.append(
+ (c, self._create_crud_bind_param(c, None))
+ )
self.prefetch.append(c)
elif c.server_default is not None:
if not c.primary_key:
@@ -1656,17 +1680,19 @@ class SQLCompiler(engine.Compiled):
elif self.isupdate:
if c.onupdate is not None and not c.onupdate.is_sequence:
if c.onupdate.is_clause_element:
- columns.append(c)
- values.append(self.process(c.onupdate.arg.self_group()))
+ values.append(
+ (c, self.process(c.onupdate.arg.self_group()))
+ )
self.postfetch.append(c)
else:
- columns.append(c)
- values.append(self._create_crud_bind_param(c, None))
+ values.append(
+ (c, self._create_crud_bind_param(c, None))
+ )
self.prefetch.append(c)
elif c.server_onupdate is not None:
self.postfetch.append(c)
- if parameters and stmt.parameters:
+ if parameters and stmt_parameters:
check = set(parameters).intersection(
sql._column_as_key(k) for k in stmt.parameters
).difference(check_columns)
@@ -1676,17 +1702,26 @@ class SQLCompiler(engine.Compiled):
(", ".join(check))
)
- if values:
+ if stmt._has_multi_parameters:
+ values_0 = values
values = [values]
- for i, row in enumerate(stmt.multi_parameters):
- r = []
- for c in columns:
- r.append(self._create_crud_bind_param(c, row[c.key],
- name=c.key + str(i)))
- values.append(r)
+ values.extend(
+ [
+ (
+ c,
+ self._create_crud_bind_param(
+ c, row[c.key],
+ name="%s_%d" % (c.key, i + 1)
+ )
+ if c.key in row else param
+ )
+ for (c, param) in values_0
+ ]
+ for i, row in enumerate(stmt.parameters[1:])
+ )
- return columns, values
+ return values
def visit_delete(self, delete_stmt, **kw):
self.stack.append({'from': set([delete_stmt.table])})