diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2006-03-05 20:31:44 +0000 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2006-03-05 20:31:44 +0000 |
commit | 9c4f3c0480f54e08b3aa2800ed76e89f957f8131 (patch) | |
tree | e7cad83cbd55ff0e2a3f4103160e7e8fed6b6a2c /lib/sqlalchemy/ansisql.py | |
parent | c1d0c2dffc0eedfa63de5b90addb70bfd3a81540 (diff) | |
download | sqlalchemy-9c4f3c0480f54e08b3aa2800ed76e89f957f8131.tar.gz |
got column onupdate working
improvement to Function so that they can more easily be called standalone without having to throw them into a select().
Diffstat (limited to 'lib/sqlalchemy/ansisql.py')
-rw-r--r-- | lib/sqlalchemy/ansisql.py | 62 |
1 files changed, 55 insertions, 7 deletions
diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py index 7c0002aa5..7b39d5358 100644 --- a/lib/sqlalchemy/ansisql.py +++ b/lib/sqlalchemy/ansisql.py @@ -15,6 +15,18 @@ from sqlalchemy.sql import * from sqlalchemy.util import * import string, re +ANSI_FUNCS = HashSet([ +'CURRENT_TIME', +'CURRENT_TIMESTAMP', +'CURRENT_DATE', +'LOCAL_TIME', +'LOCAL_TIMESTAMP', +'CURRENT_USER', +'SESSION_USER', +'USER' +]) + + def engine(**params): return ANSISQLEngine(**params) @@ -57,6 +69,7 @@ class ANSICompiler(sql.Compiled): self.select_stack = [] self.typemap = typemap or {} self.isinsert = False + self.isupdate = False self.bindtemplate = ":%s" if engine is not None: self.paramstyle = engine.paramstyle @@ -89,7 +102,7 @@ class ANSICompiler(sql.Compiled): self.strings[self.statement] = re.sub(match, getnum, self.strings[self.statement]) def get_from_text(self, obj): - return self.froms[obj] + return self.froms.get(obj, None) def get_str(self, obj): return self.strings[obj] @@ -158,6 +171,11 @@ class ANSICompiler(sql.Compiled): else: return parameters + def default_from(self): + """called when a SELECT statement has no froms, and no FROM clause is to be appended. + gives Oracle a chance to tack on a "FROM DUAL" to the string output. """ + return "" + def visit_label(self, label): if len(self.select_stack): self.typemap.setdefault(label.name.lower(), label.obj.type) @@ -211,7 +229,12 @@ class ANSICompiler(sql.Compiled): self.strings[list] = string.join([self.get_str(c) for c in list.clauses], ', ') def visit_function(self, func): - self.strings[func] = func.name + "(" + string.join([self.get_str(c) for c in func.clauses], ', ') + ")" + if len(self.select_stack): + self.typemap.setdefault(func.name, func.type) + if func.name.upper() in ANSI_FUNCS and not len(func.clauses): + self.strings[func] = func.name + else: + self.strings[func] = func.name + "(" + string.join([self.get_str(c) for c in func.clauses], ', ') + ")" def visit_compound_select(self, cs): text = string.join([self.get_str(c) for c in cs.selects], " " + cs.keyword + " ") @@ -325,7 +348,9 @@ class ANSICompiler(sql.Compiled): if len(froms): text += " \nFROM " text += string.join(froms, ', ') - + else: + text += self.default_from() + if whereclause is not None: t = self.get_str(whereclause) if t: @@ -384,21 +409,33 @@ class ANSICompiler(sql.Compiled): def visit_insert_column_default(self, column, default): """called when visiting an Insert statement, for each column in the table that - contains a ColumnDefault object.""" + contains a ColumnDefault object. adds a blank 'placeholder' parameter so the + Insert gets compiled with this column's name in its column and VALUES clauses.""" + self.parameters.setdefault(column.key, None) + + def visit_update_column_default(self, column, default): + """called when visiting an Update statement, for each column in the table that + contains a ColumnDefault object as an onupdate. adds a blank 'placeholder' parameter so the + Update gets compiled with this column's name as one of its SET clauses.""" self.parameters.setdefault(column.key, None) def visit_insert_sequence(self, column, sequence): """called when visiting an Insert statement, for each column in the table that - contains a Sequence object.""" + contains a Sequence object. Overridden by compilers that support sequences to place + a blank 'placeholder' parameter, so the Insert gets compiled with this column's + name in its column and VALUES clauses.""" pass def visit_insert_column(self, column): """called when visiting an Insert statement, for each column in the table - that is a NULL insert into the table""" + that is a NULL insert into the table. Overridden by compilers who disallow + NULL columns being set in an Insert where there is a default value on the column + (i.e. postgres), to remove the column from the parameter list.""" pass def visit_insert(self, insert_stmt): - # set up a call for the defaults and sequences inside the table + # scan the table's columns for defaults that have to be pre-set for an INSERT + # add these columns to the parameter list via visit_insert_XXX methods class DefaultVisitor(schema.SchemaVisitor): def visit_column(s, c): self.visit_insert_column(c) @@ -424,6 +461,17 @@ class ANSICompiler(sql.Compiled): self.strings[insert_stmt] = text def visit_update(self, update_stmt): + # scan the table's columns for onupdates that have to be pre-set for an UPDATE + # add these columns to the parameter list via visit_update_XXX methods + class OnUpdateVisitor(schema.SchemaVisitor): + def visit_column_onupdate(s, cd): + self.visit_update_column_default(c, cd) + vis = OnUpdateVisitor() + for c in update_stmt.table.c: + if (isinstance(c, schema.SchemaItem) and (self.parameters is None or self.parameters.get(c.key, None) is None)): + c.accept_schema_visitor(vis) + + self.isupdate = True colparams = self._get_colparams(update_stmt) def create_param(p): if isinstance(p, sql.BindParamClause): |