diff options
Diffstat (limited to 'lib/sqlalchemy/ansisql.py')
| -rw-r--r-- | lib/sqlalchemy/ansisql.py | 39 |
1 files changed, 23 insertions, 16 deletions
diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py index f96bf7abe..19cde3862 100644 --- a/lib/sqlalchemy/ansisql.py +++ b/lib/sqlalchemy/ansisql.py @@ -10,7 +10,7 @@ Contains default implementations for the abstract objects in the sql module. """ -from sqlalchemy import schema, sql, engine, util, sql_util +from sqlalchemy import schema, sql, engine, util, sql_util, exceptions from sqlalchemy.engine import default import string, re, sets, weakref @@ -353,20 +353,27 @@ class ANSICompiler(sql.Compiled): def visit_bindparam(self, bindparam): if bindparam.shortname != bindparam.key: self.binds.setdefault(bindparam.shortname, bindparam) - count = 1 - key = bindparam.key - - # redefine the generated name of the bind param in the case - # that we have multiple conflicting bind parameters. - while self.binds.setdefault(key, bindparam) is not bindparam: - # ensure the name doesn't expand the length of the string - # in case we're at the edge of max identifier length - tag = "_%d" % count - key = bindparam.key[0 : len(bindparam.key) - len(tag)] + tag - count += 1 - bindparam.key = key - self.strings[bindparam] = self.bindparam_string(key) - + if bindparam.unique: + count = 1 + key = bindparam.key + + # redefine the generated name of the bind param in the case + # that we have multiple conflicting bind parameters. + while self.binds.setdefault(key, bindparam) is not bindparam: + # ensure the name doesn't expand the length of the string + # in case we're at the edge of max identifier length + tag = "_%d" % count + key = bindparam.key[0 : len(bindparam.key) - len(tag)] + tag + count += 1 + bindparam.key = key + self.strings[bindparam] = self.bindparam_string(key) + else: + existing = self.binds.get(bindparam.key) + if existing is not None and existing.unique: + raise exceptions.CompileError("Bind parameter '%s' conflicts with unique bind parameter of the same name" % bindparam.key) + self.strings[bindparam] = self.bindparam_string(bindparam.key) + self.binds[bindparam.key] = bindparam + def bindparam_string(self, name): return self.bindtemplate % name @@ -702,7 +709,7 @@ class ANSICompiler(sql.Compiled): if parameters.has_key(c): value = parameters[c] if sql._is_literal(value): - value = sql.bindparam(c.key, value, type=c.type) + value = sql.bindparam(c.key, value, type=c.type, unique=True) values.append((c, value)) return values |
