summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/engine.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/engine.py')
-rw-r--r--lib/sqlalchemy/engine.py46
1 files changed, 44 insertions, 2 deletions
diff --git a/lib/sqlalchemy/engine.py b/lib/sqlalchemy/engine.py
index 7d158cb7e..3703169fa 100644
--- a/lib/sqlalchemy/engine.py
+++ b/lib/sqlalchemy/engine.py
@@ -135,6 +135,12 @@ class DefaultRunner(schema.SchemaVisitor):
else:
return None
+ def get_column_onupdate(self, column):
+ if column.onupdate is not None:
+ return column.onupdate.accept_schema_visitor(self)
+ else:
+ return None
+
def visit_passive_default(self, default):
"""passive defaults by definition return None on the app side,
and are post-fetched to get the DB-side value"""
@@ -147,7 +153,15 @@ class DefaultRunner(schema.SchemaVisitor):
def exec_default_sql(self, default):
c = sql.select([default.arg], engine=self.engine).compile()
return self.proxy(str(c), c.get_params()).fetchone()[0]
-
+
+ def visit_column_onupdate(self, onupdate):
+ if isinstance(onupdate.arg, sql.ClauseElement):
+ return self.exec_default_sql(onupdate)
+ elif callable(onupdate.arg):
+ return onupdate.arg()
+ else:
+ return onupdate.arg
+
def visit_column_default(self, default):
if isinstance(default.arg, sql.ClauseElement):
return self.exec_default_sql(default)
@@ -245,6 +259,13 @@ class SQLEngine(schema.SchemaEngine):
typeobj = typeobj()
return typeobj
+ def _func(self):
+ class FunctionGateway(object):
+ def __getattr__(s, name):
+ return lambda *c, **kwargs: sql.Function(name, engine=self, *c, **kwargs)
+ return FunctionGateway()
+ func = property(_func)
+
def text(self, text, *args, **kwargs):
"""returns a sql.text() object for performing literal queries."""
return sql.text(text, engine=self, *args, **kwargs)
@@ -426,6 +447,15 @@ class SQLEngine(schema.SchemaEngine):
self.context.tcount = None
def _process_defaults(self, proxy, compiled, parameters, **kwargs):
+ """INSERT and UPDATE statements, when compiled, may have additional columns added to their
+ VALUES and SET lists corresponding to column defaults/onupdates that are present on the
+ Table object (i.e. ColumnDefault, Sequence, PassiveDefault). This method pre-execs those
+ DefaultGenerator objects that require pre-execution and sets their values within the
+ parameter list, and flags the thread-local state about
+ PassiveDefault objects that may require post-fetching the row after it is inserted/updated.
+ This method relies upon logic within the ANSISQLCompiler in its visit_insert and
+ visit_update methods that add the appropriate column clauses to the statement when its
+ being compiled, so that these parameters can be bound to the statement."""
if compiled is None: return
if getattr(compiled, "isinsert", False):
if isinstance(parameters, list):
@@ -454,7 +484,19 @@ class SQLEngine(schema.SchemaEngine):
self.context.last_inserted_ids = None
else:
self.context.last_inserted_ids = last_inserted_ids
-
+ elif getattr(compiled, 'isupdate', False):
+ if isinstance(parameters, list):
+ plist = parameters
+ else:
+ plist = [parameters]
+ drunner = self.defaultrunner(proxy)
+ for param in plist:
+ for c in compiled.statement.table.c:
+ if c.onupdate is not None and (not param.has_key(c.name) or param[c.name] is None):
+ value = drunner.get_column_onupdate(c)
+ if value is not None:
+ param[c.name] = value
+
def lastrow_has_defaults(self):
return self.context.lastrow_has_defaults