diff options
Diffstat (limited to 'lib/sqlalchemy/ansisql.py')
-rw-r--r-- | lib/sqlalchemy/ansisql.py | 54 |
1 files changed, 35 insertions, 19 deletions
diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py index 1b600a4a8..7c0002aa5 100644 --- a/lib/sqlalchemy/ansisql.py +++ b/lib/sqlalchemy/ansisql.py @@ -27,7 +27,7 @@ class ANSISQLEngine(sqlalchemy.engine.SQLEngine): return ANSISchemaDropper(self, **params) def compiler(self, statement, parameters, **kwargs): - return ANSICompiler(self, statement, parameters, **kwargs) + return ANSICompiler(statement, parameters, engine=self, **kwargs) def connect_args(self): return ([],{}) @@ -37,7 +37,7 @@ class ANSISQLEngine(sqlalchemy.engine.SQLEngine): class ANSICompiler(sql.Compiled): """default implementation of Compiled, which compiles ClauseElements into ANSI-compliant SQL strings.""" - def __init__(self, engine, statement, parameters=None, typemap=None, **kwargs): + def __init__(self, statement, parameters=None, typemap=None, engine=None, positional=None, paramstyle=None, **kwargs): """constructs a new ANSICompiler object. engine - SQLEngine to compile against @@ -49,7 +49,7 @@ class ANSICompiler(sql.Compiled): key/value pairs when the Compiled is executed, and also may affect the actual compilation, as in the case of an INSERT where the actual columns inserted will correspond to the keys present in the parameters.""" - sql.Compiled.__init__(self, engine, statement, parameters) + sql.Compiled.__init__(self, statement, parameters, engine=engine) self.binds = {} self.froms = {} self.wheres = {} @@ -57,19 +57,31 @@ class ANSICompiler(sql.Compiled): self.select_stack = [] self.typemap = typemap or {} self.isinsert = False + self.bindtemplate = ":%s" + if engine is not None: + self.paramstyle = engine.paramstyle + self.positional = engine.positional + else: + self.positional = False + self.paramstyle = 'named' def after_compile(self): - if self.engine.positional: + # this re will search for params like :param + # it has a negative lookbehind for an extra ':' so that it doesnt match + # postgres '::text' tokens + match = r'(?<!:):([\w_]+)' + if self.paramstyle=='pyformat': + self.strings[self.statement] = re.sub(match, lambda m:'%(' + m.group(1) +')s', self.strings[self.statement]) + elif self.positional: self.positiontup = [] - match = r'%\(([\w_]+)\)s' params = re.finditer(match, self.strings[self.statement]) for p in params: self.positiontup.append(p.group(1)) - if self.engine.paramstyle=='qmark': + if self.paramstyle=='qmark': self.strings[self.statement] = re.sub(match, '?', self.strings[self.statement]) - elif self.engine.paramstyle=='format': + elif self.paramstyle=='format': self.strings[self.statement] = re.sub(match, '%s', self.strings[self.statement]) - elif self.engine.paramstyle=='numeric': + elif self.paramstyle=='numeric': i = [0] def getnum(x): i[0] += 1 @@ -104,28 +116,33 @@ class ANSICompiler(sql.Compiled): bindparams = {} bindparams.update(params) - if self.engine.positional: + if self.positional: d = OrderedDict() for k in self.positiontup: b = self.binds[k] - d[k] = b.typeprocess(b.value, self.engine) + if self.engine is not None: + d[k] = b.typeprocess(b.value, self.engine) + else: + d[k] = b.value else: d = {} for b in self.binds.values(): - d[b.key] = b.typeprocess(b.value, self.engine) + if self.engine is not None: + d[b.key] = b.typeprocess(b.value, self.engine) + else: + d[b.key] = b.value for key, value in bindparams.iteritems(): try: b = self.binds[key] except KeyError: continue - d[b.key] = b.typeprocess(value, self.engine) + if self.engine is not None: + d[b.key] = b.typeprocess(value, self.engine) + else: + d[b.key] = value return d - if self.engine.positional: - return d.values() - else: - return d def get_named_params(self, parameters): """given the results of the get_params method, returns the parameters @@ -133,8 +150,7 @@ class ANSICompiler(sql.Compiled): same dictionary. For a positional paramstyle, the given parameters are assumed to be in list format and are converted back to a dictionary. """ -# return parameters - if self.engine.positional: + if self.positional: p = {} for i in range(0, len(self.positiontup)): p[self.positiontup[i]] = parameters[i] @@ -231,7 +247,7 @@ class ANSICompiler(sql.Compiled): self.strings[bindparam] = self.bindparam_string(key) def bindparam_string(self, name): - return self.engine.bindtemplate % name + return self.bindtemplate % name def visit_alias(self, alias): self.froms[alias] = self.get_from_text(alias.original) + " AS " + alias.name |