diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2011-11-22 18:46:45 -0500 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2011-11-22 18:46:45 -0500 |
commit | 9c896906c7e4130ea11cf913dd50d29a9a3e1fa7 (patch) | |
tree | 288933c49c5980776d9b751a48bbb7957acf988a /lib/sqlalchemy/sql/compiler.py | |
parent | e7b612a69e8b2ec29306d88e08b999dcf79a4822 (diff) | |
download | sqlalchemy-9c896906c7e4130ea11cf913dd50d29a9a3e1fa7.tar.gz |
also add support for onupdate as we'd like this to fire off if an UPDATE actually
happens on the table
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 44 |
1 files changed, 31 insertions, 13 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 4b1b9bd5d..7aee5da81 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -1140,23 +1140,41 @@ class SQLCompiler(engine.Compiled): # special logic that only occurs for multi-table UPDATE # statements if extra_tables and stmt.parameters: + assert self.isupdate + affected_tables = set() for t in extra_tables: for c in t.c: if c in stmt.parameters: + affected_tables.add(t) check_columns[c.key] = c - - for c in check_columns.values(): - value = stmt.parameters[c] - if sql._is_literal(value): - value = self._create_crud_bind_param( - c, value, required=value is required) - elif c.primary_key and implicit_returning: - self.returning.append(c) - value = self.process(value.self_group()) - else: - self.postfetch.append(c) - value = self.process(value.self_group()) - values.append((c, value)) + value = stmt.parameters[c] + if sql._is_literal(value): + value = self._create_crud_bind_param( + c, value, required=value is required) + else: + self.postfetch.append(c) + value = self.process(value.self_group()) + values.append((c, value)) + # determine tables which are actually + # to be updated - process onupdate and + # server_onupdate for these + for t in affected_tables: + for c in t.c: + if c in stmt.parameters: + continue + elif c.onupdate is not None and not c.onupdate.is_sequence: + if c.onupdate.is_clause_element: + values.append( + (c, self.process(c.onupdate.arg.self_group())) + ) + self.postfetch.append(c) + else: + values.append( + (c, self._create_crud_bind_param(c, None)) + ) + self.prefetch.append(c) + elif c.server_onupdate is not None: + self.postfetch.append(c) # iterating through columns at the top to maintain ordering. # otherwise we might iterate through individual sets of |