summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/ansisql.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2006-03-05 20:31:44 +0000
committerMike Bayer <mike_mp@zzzcomputing.com>2006-03-05 20:31:44 +0000
commit9c4f3c0480f54e08b3aa2800ed76e89f957f8131 (patch)
treee7cad83cbd55ff0e2a3f4103160e7e8fed6b6a2c /lib/sqlalchemy/ansisql.py
parentc1d0c2dffc0eedfa63de5b90addb70bfd3a81540 (diff)
downloadsqlalchemy-9c4f3c0480f54e08b3aa2800ed76e89f957f8131.tar.gz
got column onupdate working
improvement to Function so that they can more easily be called standalone without having to throw them into a select().
Diffstat (limited to 'lib/sqlalchemy/ansisql.py')
-rw-r--r--lib/sqlalchemy/ansisql.py62
1 files changed, 55 insertions, 7 deletions
diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py
index 7c0002aa5..7b39d5358 100644
--- a/lib/sqlalchemy/ansisql.py
+++ b/lib/sqlalchemy/ansisql.py
@@ -15,6 +15,18 @@ from sqlalchemy.sql import *
from sqlalchemy.util import *
import string, re
+ANSI_FUNCS = HashSet([
+'CURRENT_TIME',
+'CURRENT_TIMESTAMP',
+'CURRENT_DATE',
+'LOCAL_TIME',
+'LOCAL_TIMESTAMP',
+'CURRENT_USER',
+'SESSION_USER',
+'USER'
+])
+
+
def engine(**params):
return ANSISQLEngine(**params)
@@ -57,6 +69,7 @@ class ANSICompiler(sql.Compiled):
self.select_stack = []
self.typemap = typemap or {}
self.isinsert = False
+ self.isupdate = False
self.bindtemplate = ":%s"
if engine is not None:
self.paramstyle = engine.paramstyle
@@ -89,7 +102,7 @@ class ANSICompiler(sql.Compiled):
self.strings[self.statement] = re.sub(match, getnum, self.strings[self.statement])
def get_from_text(self, obj):
- return self.froms[obj]
+ return self.froms.get(obj, None)
def get_str(self, obj):
return self.strings[obj]
@@ -158,6 +171,11 @@ class ANSICompiler(sql.Compiled):
else:
return parameters
+ def default_from(self):
+ """called when a SELECT statement has no froms, and no FROM clause is to be appended.
+ gives Oracle a chance to tack on a "FROM DUAL" to the string output. """
+ return ""
+
def visit_label(self, label):
if len(self.select_stack):
self.typemap.setdefault(label.name.lower(), label.obj.type)
@@ -211,7 +229,12 @@ class ANSICompiler(sql.Compiled):
self.strings[list] = string.join([self.get_str(c) for c in list.clauses], ', ')
def visit_function(self, func):
- self.strings[func] = func.name + "(" + string.join([self.get_str(c) for c in func.clauses], ', ') + ")"
+ if len(self.select_stack):
+ self.typemap.setdefault(func.name, func.type)
+ if func.name.upper() in ANSI_FUNCS and not len(func.clauses):
+ self.strings[func] = func.name
+ else:
+ self.strings[func] = func.name + "(" + string.join([self.get_str(c) for c in func.clauses], ', ') + ")"
def visit_compound_select(self, cs):
text = string.join([self.get_str(c) for c in cs.selects], " " + cs.keyword + " ")
@@ -325,7 +348,9 @@ class ANSICompiler(sql.Compiled):
if len(froms):
text += " \nFROM "
text += string.join(froms, ', ')
-
+ else:
+ text += self.default_from()
+
if whereclause is not None:
t = self.get_str(whereclause)
if t:
@@ -384,21 +409,33 @@ class ANSICompiler(sql.Compiled):
def visit_insert_column_default(self, column, default):
"""called when visiting an Insert statement, for each column in the table that
- contains a ColumnDefault object."""
+ contains a ColumnDefault object. adds a blank 'placeholder' parameter so the
+ Insert gets compiled with this column's name in its column and VALUES clauses."""
+ self.parameters.setdefault(column.key, None)
+
+ def visit_update_column_default(self, column, default):
+ """called when visiting an Update statement, for each column in the table that
+ contains a ColumnDefault object as an onupdate. adds a blank 'placeholder' parameter so the
+ Update gets compiled with this column's name as one of its SET clauses."""
self.parameters.setdefault(column.key, None)
def visit_insert_sequence(self, column, sequence):
"""called when visiting an Insert statement, for each column in the table that
- contains a Sequence object."""
+ contains a Sequence object. Overridden by compilers that support sequences to place
+ a blank 'placeholder' parameter, so the Insert gets compiled with this column's
+ name in its column and VALUES clauses."""
pass
def visit_insert_column(self, column):
"""called when visiting an Insert statement, for each column in the table
- that is a NULL insert into the table"""
+ that is a NULL insert into the table. Overridden by compilers who disallow
+ NULL columns being set in an Insert where there is a default value on the column
+ (i.e. postgres), to remove the column from the parameter list."""
pass
def visit_insert(self, insert_stmt):
- # set up a call for the defaults and sequences inside the table
+ # scan the table's columns for defaults that have to be pre-set for an INSERT
+ # add these columns to the parameter list via visit_insert_XXX methods
class DefaultVisitor(schema.SchemaVisitor):
def visit_column(s, c):
self.visit_insert_column(c)
@@ -424,6 +461,17 @@ class ANSICompiler(sql.Compiled):
self.strings[insert_stmt] = text
def visit_update(self, update_stmt):
+ # scan the table's columns for onupdates that have to be pre-set for an UPDATE
+ # add these columns to the parameter list via visit_update_XXX methods
+ class OnUpdateVisitor(schema.SchemaVisitor):
+ def visit_column_onupdate(s, cd):
+ self.visit_update_column_default(c, cd)
+ vis = OnUpdateVisitor()
+ for c in update_stmt.table.c:
+ if (isinstance(c, schema.SchemaItem) and (self.parameters is None or self.parameters.get(c.key, None) is None)):
+ c.accept_schema_visitor(vis)
+
+ self.isupdate = True
colparams = self._get_colparams(update_stmt)
def create_param(p):
if isinstance(p, sql.BindParamClause):