summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/ansisql.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2006-04-28 23:31:59 +0000
committerMike Bayer <mike_mp@zzzcomputing.com>2006-04-28 23:31:59 +0000
commit409acb42086bb82bdfc784dafef5a6fa50afd0e0 (patch)
tree0aac110858e7ce921196b4a9794cce267ea226f8 /lib/sqlalchemy/ansisql.py
parentb0fff23df8e00de7c783254f1dea51831b0ca6de (diff)
downloadsqlalchemy-409acb42086bb82bdfc784dafef5a6fa50afd0e0.tar.gz
fix for [ticket:169], moves the creation of "default" parameters more accurately
where theyre supposed to be
Diffstat (limited to 'lib/sqlalchemy/ansisql.py')
-rw-r--r--lib/sqlalchemy/ansisql.py37
1 files changed, 21 insertions, 16 deletions
diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py
index 7aefea0ba..a344e017c 100644
--- a/lib/sqlalchemy/ansisql.py
+++ b/lib/sqlalchemy/ansisql.py
@@ -146,7 +146,6 @@ class ANSICompiler(sql.Compiled):
continue
d.set_parameter(b.key, value, b)
- #print "FROM", params, "TO", d
return d
def get_named_params(self, parameters):
@@ -425,26 +424,26 @@ class ANSICompiler(sql.Compiled):
" ON " + self.get_str(join.onclause))
self.strings[join] = self.froms[join]
- def visit_insert_column_default(self, column, default):
+ def visit_insert_column_default(self, column, default, parameters):
"""called when visiting an Insert statement, for each column in the table that
contains a ColumnDefault object. adds a blank 'placeholder' parameter so the
Insert gets compiled with this column's name in its column and VALUES clauses."""
- self.parameters.setdefault(column.key, None)
+ parameters.setdefault(column.key, None)
- def visit_update_column_default(self, column, default):
+ def visit_update_column_default(self, column, default, parameters):
"""called when visiting an Update statement, for each column in the table that
contains a ColumnDefault object as an onupdate. adds a blank 'placeholder' parameter so the
Update gets compiled with this column's name as one of its SET clauses."""
- self.parameters.setdefault(column.key, None)
+ parameters.setdefault(column.key, None)
- def visit_insert_sequence(self, column, sequence):
+ def visit_insert_sequence(self, column, sequence, parameters):
"""called when visiting an Insert statement, for each column in the table that
contains a Sequence object. Overridden by compilers that support sequences to place
a blank 'placeholder' parameter, so the Insert gets compiled with this column's
name in its column and VALUES clauses."""
pass
- def visit_insert_column(self, column):
+ def visit_insert_column(self, column, parameters):
"""called when visiting an Insert statement, for each column in the table
that is a NULL insert into the table. Overridden by compilers who disallow
NULL columns being set in an Insert where there is a default value on the column
@@ -454,25 +453,27 @@ class ANSICompiler(sql.Compiled):
def visit_insert(self, insert_stmt):
# scan the table's columns for defaults that have to be pre-set for an INSERT
# add these columns to the parameter list via visit_insert_XXX methods
+ default_params = {}
class DefaultVisitor(schema.SchemaVisitor):
def visit_column(s, c):
- self.visit_insert_column(c)
+ self.visit_insert_column(c, default_params)
def visit_column_default(s, cd):
- self.visit_insert_column_default(c, cd)
+ self.visit_insert_column_default(c, cd, default_params)
def visit_sequence(s, seq):
- self.visit_insert_sequence(c, seq)
+ self.visit_insert_sequence(c, seq, default_params)
vis = DefaultVisitor()
for c in insert_stmt.table.c:
if (isinstance(c, schema.SchemaItem) and (self.parameters is None or self.parameters.get(c.key, None) is None)):
c.accept_schema_visitor(vis)
self.isinsert = True
- colparams = self._get_colparams(insert_stmt)
+ colparams = self._get_colparams(insert_stmt, default_params)
def create_param(p):
if isinstance(p, sql.BindParamClause):
self.binds[p.key] = p
- self.binds[p.shortname] = p
+ if p.shortname is not None:
+ self.binds[p.shortname] = p
return self.bindparam_string(p.key)
else:
p.accept_visitor(self)
@@ -483,22 +484,23 @@ class ANSICompiler(sql.Compiled):
text = ("INSERT INTO " + insert_stmt.table.fullname + " (" + string.join([c[0].name for c in colparams], ', ') + ")" +
" VALUES (" + string.join([create_param(c[1]) for c in colparams], ', ') + ")")
-
+
self.strings[insert_stmt] = text
def visit_update(self, update_stmt):
# scan the table's columns for onupdates that have to be pre-set for an UPDATE
# add these columns to the parameter list via visit_update_XXX methods
+ default_params = {}
class OnUpdateVisitor(schema.SchemaVisitor):
def visit_column_onupdate(s, cd):
- self.visit_update_column_default(c, cd)
+ self.visit_update_column_default(c, cd, default_params)
vis = OnUpdateVisitor()
for c in update_stmt.table.c:
if (isinstance(c, schema.SchemaItem) and (self.parameters is None or self.parameters.get(c.key, None) is None)):
c.accept_schema_visitor(vis)
self.isupdate = True
- colparams = self._get_colparams(update_stmt)
+ colparams = self._get_colparams(update_stmt, default_params)
def create_param(p):
if isinstance(p, sql.BindParamClause):
self.binds[p.key] = p
@@ -519,7 +521,7 @@ class ANSICompiler(sql.Compiled):
self.strings[update_stmt] = text
- def _get_colparams(self, stmt):
+ def _get_colparams(self, stmt, default_params):
"""determines the VALUES or SET clause for an INSERT or UPDATE
clause based on the arguments specified to this ANSICompiler object
(i.e., the execute() or compile() method clause object):
@@ -550,6 +552,9 @@ class ANSICompiler(sql.Compiled):
for k, v in stmt.parameters.iteritems():
parameters.setdefault(k, v)
+ for k, v in default_params.iteritems():
+ parameters.setdefault(k, v)
+
# now go thru compiled params, get the Column object for each key
d = {}
for key, value in parameters.iteritems():