summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/compiler.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2011-11-22 18:46:45 -0500
committerMike Bayer <mike_mp@zzzcomputing.com>2011-11-22 18:46:45 -0500
commit9c896906c7e4130ea11cf913dd50d29a9a3e1fa7 (patch)
tree288933c49c5980776d9b751a48bbb7957acf988a /lib/sqlalchemy/sql/compiler.py
parente7b612a69e8b2ec29306d88e08b999dcf79a4822 (diff)
downloadsqlalchemy-9c896906c7e4130ea11cf913dd50d29a9a3e1fa7.tar.gz
also add support for onupdate as we'd like this to fire off if an UPDATE actually
happens on the table
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r--lib/sqlalchemy/sql/compiler.py44
1 files changed, 31 insertions, 13 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 4b1b9bd5d..7aee5da81 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -1140,23 +1140,41 @@ class SQLCompiler(engine.Compiled):
# special logic that only occurs for multi-table UPDATE
# statements
if extra_tables and stmt.parameters:
+ assert self.isupdate
+ affected_tables = set()
for t in extra_tables:
for c in t.c:
if c in stmt.parameters:
+ affected_tables.add(t)
check_columns[c.key] = c
-
- for c in check_columns.values():
- value = stmt.parameters[c]
- if sql._is_literal(value):
- value = self._create_crud_bind_param(
- c, value, required=value is required)
- elif c.primary_key and implicit_returning:
- self.returning.append(c)
- value = self.process(value.self_group())
- else:
- self.postfetch.append(c)
- value = self.process(value.self_group())
- values.append((c, value))
+ value = stmt.parameters[c]
+ if sql._is_literal(value):
+ value = self._create_crud_bind_param(
+ c, value, required=value is required)
+ else:
+ self.postfetch.append(c)
+ value = self.process(value.self_group())
+ values.append((c, value))
+ # determine tables which are actually
+ # to be updated - process onupdate and
+ # server_onupdate for these
+ for t in affected_tables:
+ for c in t.c:
+ if c in stmt.parameters:
+ continue
+ elif c.onupdate is not None and not c.onupdate.is_sequence:
+ if c.onupdate.is_clause_element:
+ values.append(
+ (c, self.process(c.onupdate.arg.self_group()))
+ )
+ self.postfetch.append(c)
+ else:
+ values.append(
+ (c, self._create_crud_bind_param(c, None))
+ )
+ self.prefetch.append(c)
+ elif c.server_onupdate is not None:
+ self.postfetch.append(c)
# iterating through columns at the top to maintain ordering.
# otherwise we might iterate through individual sets of