summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy
diff options
context:
space:
mode:
authorIdan Kamara <idankk86@gmail.com>2012-12-05 23:45:49 +0200
committerIdan Kamara <idankk86@gmail.com>2012-12-05 23:45:49 +0200
commit51839352a4a9d4b87bdca6c148ec0fd847b8630b (patch)
tree07a1a851af7f30d7feb19936c9939d1c9cb768d8 /lib/sqlalchemy
parent20dd73f575f3d10d53cd68b09bb1c42049fc4211 (diff)
downloadsqlalchemy-51839352a4a9d4b87bdca6c148ec0fd847b8630b.tar.gz
compiler: adjust _get_colparams to return the columns and parameters in separate lists
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r--lib/sqlalchemy/sql/compiler.py99
1 files changed, 49 insertions, 50 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 102b44a7e..6f7f1dadd 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -1275,9 +1275,9 @@ class SQLCompiler(engine.Compiled):
def visit_insert(self, insert_stmt, **kw):
self.isinsert = True
- colparams = self._get_colparams(insert_stmt)
+ cols, params = self._get_colparams(insert_stmt)
- if not colparams and \
+ if not cols and \
not self.dialect.supports_default_values and \
not self.dialect.supports_empty_insert:
raise exc.CompileError("The version of %s you are using does "
@@ -1313,9 +1313,9 @@ class SQLCompiler(engine.Compiled):
text += table_text
- if colparams or not supports_default_values:
- text += " (%s)" % ', '.join([preparer.format_column(c[0])
- for c in colparams])
+ if cols or not supports_default_values:
+ text += " (%s)" % ', '.join([preparer.format_column(c)
+ for c in cols])
if self.returning or insert_stmt._returning:
self.returning = self.returning or insert_stmt._returning
@@ -1325,11 +1325,11 @@ class SQLCompiler(engine.Compiled):
if self.returning_precedes_values:
text += " " + returning_clause
- if not colparams and supports_default_values:
+ if not cols and supports_default_values:
text += " DEFAULT VALUES"
else:
text += " VALUES (%s)" % \
- ', '.join([c[1] for c in colparams])
+ ', '.join(params[0])
if self.returning and not self.returning_precedes_values:
text += " " + returning_clause
@@ -1373,7 +1373,7 @@ class SQLCompiler(engine.Compiled):
extra_froms = update_stmt._extra_froms
- colparams = self._get_colparams(update_stmt, extra_froms)
+ cols, params = self._get_colparams(update_stmt, extra_froms)
text = "UPDATE "
@@ -1406,10 +1406,13 @@ class SQLCompiler(engine.Compiled):
text += ' SET '
include_table = extra_froms and \
self.render_table_with_column_in_update_from
+ colparams = []
+ if params:
+ colparams = zip(cols, params[0])
text += ', '.join(
- c[0]._compiler_dispatch(self,
+ c._compiler_dispatch(self,
include_table=include_table) +
- '=' + c[1] for c in colparams
+ '=' + p for c, p in colparams
)
if update_stmt._returning:
@@ -1467,11 +1470,9 @@ class SQLCompiler(engine.Compiled):
# no parameters in the statement, no parameters in the
# compiled params - return binds for all columns
if self.column_keys is None and stmt.parameters is None:
- return [
- (c, self._create_crud_bind_param(c,
- None, required=True))
- for c in stmt.table.columns
- ]
+ values = [self._create_crud_bind_param(c, None, required=True)
+ for c in stmt.table.columns]
+ return list(stmt.table.columns), [values]
required = object()
@@ -1486,6 +1487,7 @@ class SQLCompiler(engine.Compiled):
key not in stmt.parameters)
# create a list of column assignment clauses as tuples
+ columns = []
values = []
if stmt.parameters is not None:
@@ -1502,7 +1504,8 @@ class SQLCompiler(engine.Compiled):
else:
v = self.process(v.self_group())
- values.append((k, v))
+ columns.append(k)
+ values.append(v)
need_pks = self.isinsert and \
not self.inline and \
@@ -1536,7 +1539,8 @@ class SQLCompiler(engine.Compiled):
else:
self.postfetch.append(c)
value = self.process(value.self_group())
- values.append((c, value))
+ columns.append(c)
+ values.append(value)
# determine tables which are actually
# to be updated - process onupdate and
# server_onupdate for these
@@ -1546,14 +1550,12 @@ class SQLCompiler(engine.Compiled):
continue
elif c.onupdate is not None and not c.onupdate.is_sequence:
if c.onupdate.is_clause_element:
- values.append(
- (c, self.process(c.onupdate.arg.self_group()))
- )
+ columns.apppend(c)
+ values.append(self.process(c.onupdate.arg.self_group()))
self.postfetch.append(c)
else:
- values.append(
- (c, self._create_crud_bind_param(c, None))
- )
+ columns.append(c)
+ values.append(self._create_crud_bind_param(c, None))
self.prefetch.append(c)
elif c.server_onupdate is not None:
self.postfetch.append(c)
@@ -1573,7 +1575,8 @@ class SQLCompiler(engine.Compiled):
else:
self.postfetch.append(c)
value = self.process(value.self_group())
- values.append((c, value))
+ columns.append(c)
+ values.append(value)
elif self.isinsert:
if c.primary_key and \
@@ -1591,18 +1594,16 @@ class SQLCompiler(engine.Compiled):
(not c.default.optional or \
not self.dialect.sequences_optional):
proc = self.process(c.default)
- values.append((c, proc))
+ columns.append(c)
+ values.append(proc)
self.returning.append(c)
elif c.default.is_clause_element:
- values.append(
- (c,
- self.process(c.default.arg.self_group()))
- )
+ columns.append(c)
+ values.append(self.process(c.default.arg.self_group()))
self.returning.append(c)
else:
- values.append(
- (c, self._create_crud_bind_param(c, None))
- )
+ columns.append(c)
+ values.append(self._create_crud_bind_param(c, None))
self.prefetch.append(c)
else:
self.returning.append(c)
@@ -1613,10 +1614,8 @@ class SQLCompiler(engine.Compiled):
self.dialect.preexecute_autoincrement_sequences
):
- values.append(
- (c, self._create_crud_bind_param(c, None))
- )
-
+ columns.append(c)
+ values.append(self._create_crud_bind_param(c, None))
self.prefetch.append(c)
elif c.default is not None:
@@ -1625,21 +1624,20 @@ class SQLCompiler(engine.Compiled):
(not c.default.optional or \
not self.dialect.sequences_optional):
proc = self.process(c.default)
- values.append((c, proc))
+ columns.append(c)
+ values.append(proc)
if not c.primary_key:
self.postfetch.append(c)
elif c.default.is_clause_element:
- values.append(
- (c, self.process(c.default.arg.self_group()))
- )
+ columns.append(c)
+ values.append(self.process(c.default.arg.self_group()))
if not c.primary_key:
# dont add primary key column to postfetch
self.postfetch.append(c)
else:
- values.append(
- (c, self._create_crud_bind_param(c, None))
- )
+ columns.append(c)
+ values.append(self._create_crud_bind_param(c, None))
self.prefetch.append(c)
elif c.server_default is not None:
if not c.primary_key:
@@ -1648,14 +1646,12 @@ class SQLCompiler(engine.Compiled):
elif self.isupdate:
if c.onupdate is not None and not c.onupdate.is_sequence:
if c.onupdate.is_clause_element:
- values.append(
- (c, self.process(c.onupdate.arg.self_group()))
- )
+ columns.append(c)
+ values.append(self.process(c.onupdate.arg.self_group()))
self.postfetch.append(c)
else:
- values.append(
- (c, self._create_crud_bind_param(c, None))
- )
+ columns.append(c)
+ values.append(self._create_crud_bind_param(c, None))
self.prefetch.append(c)
elif c.server_onupdate is not None:
self.postfetch.append(c)
@@ -1670,7 +1666,10 @@ class SQLCompiler(engine.Compiled):
(", ".join(check))
)
- return values
+ if values:
+ values = [values]
+
+ return columns, values
def visit_delete(self, delete_stmt, **kw):
self.stack.append({'from': set([delete_stmt.table])})