diff options
Diffstat (limited to 'lib/sqlalchemy/engine.py')
-rw-r--r-- | lib/sqlalchemy/engine.py | 46 |
1 files changed, 44 insertions, 2 deletions
diff --git a/lib/sqlalchemy/engine.py b/lib/sqlalchemy/engine.py index 7d158cb7e..3703169fa 100644 --- a/lib/sqlalchemy/engine.py +++ b/lib/sqlalchemy/engine.py @@ -135,6 +135,12 @@ class DefaultRunner(schema.SchemaVisitor): else: return None + def get_column_onupdate(self, column): + if column.onupdate is not None: + return column.onupdate.accept_schema_visitor(self) + else: + return None + def visit_passive_default(self, default): """passive defaults by definition return None on the app side, and are post-fetched to get the DB-side value""" @@ -147,7 +153,15 @@ class DefaultRunner(schema.SchemaVisitor): def exec_default_sql(self, default): c = sql.select([default.arg], engine=self.engine).compile() return self.proxy(str(c), c.get_params()).fetchone()[0] - + + def visit_column_onupdate(self, onupdate): + if isinstance(onupdate.arg, sql.ClauseElement): + return self.exec_default_sql(onupdate) + elif callable(onupdate.arg): + return onupdate.arg() + else: + return onupdate.arg + def visit_column_default(self, default): if isinstance(default.arg, sql.ClauseElement): return self.exec_default_sql(default) @@ -245,6 +259,13 @@ class SQLEngine(schema.SchemaEngine): typeobj = typeobj() return typeobj + def _func(self): + class FunctionGateway(object): + def __getattr__(s, name): + return lambda *c, **kwargs: sql.Function(name, engine=self, *c, **kwargs) + return FunctionGateway() + func = property(_func) + def text(self, text, *args, **kwargs): """returns a sql.text() object for performing literal queries.""" return sql.text(text, engine=self, *args, **kwargs) @@ -426,6 +447,15 @@ class SQLEngine(schema.SchemaEngine): self.context.tcount = None def _process_defaults(self, proxy, compiled, parameters, **kwargs): + """INSERT and UPDATE statements, when compiled, may have additional columns added to their + VALUES and SET lists corresponding to column defaults/onupdates that are present on the + Table object (i.e. ColumnDefault, Sequence, PassiveDefault). This method pre-execs those + DefaultGenerator objects that require pre-execution and sets their values within the + parameter list, and flags the thread-local state about + PassiveDefault objects that may require post-fetching the row after it is inserted/updated. + This method relies upon logic within the ANSISQLCompiler in its visit_insert and + visit_update methods that add the appropriate column clauses to the statement when its + being compiled, so that these parameters can be bound to the statement.""" if compiled is None: return if getattr(compiled, "isinsert", False): if isinstance(parameters, list): @@ -454,7 +484,19 @@ class SQLEngine(schema.SchemaEngine): self.context.last_inserted_ids = None else: self.context.last_inserted_ids = last_inserted_ids - + elif getattr(compiled, 'isupdate', False): + if isinstance(parameters, list): + plist = parameters + else: + plist = [parameters] + drunner = self.defaultrunner(proxy) + for param in plist: + for c in compiled.statement.table.c: + if c.onupdate is not None and (not param.has_key(c.name) or param[c.name] is None): + value = drunner.get_column_onupdate(c) + if value is not None: + param[c.name] = value + def lastrow_has_defaults(self): return self.context.lastrow_has_defaults |