summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/ansisql.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/ansisql.py')
-rw-r--r--lib/sqlalchemy/ansisql.py54
1 files changed, 35 insertions, 19 deletions
diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py
index 1b600a4a8..7c0002aa5 100644
--- a/lib/sqlalchemy/ansisql.py
+++ b/lib/sqlalchemy/ansisql.py
@@ -27,7 +27,7 @@ class ANSISQLEngine(sqlalchemy.engine.SQLEngine):
return ANSISchemaDropper(self, **params)
def compiler(self, statement, parameters, **kwargs):
- return ANSICompiler(self, statement, parameters, **kwargs)
+ return ANSICompiler(statement, parameters, engine=self, **kwargs)
def connect_args(self):
return ([],{})
@@ -37,7 +37,7 @@ class ANSISQLEngine(sqlalchemy.engine.SQLEngine):
class ANSICompiler(sql.Compiled):
"""default implementation of Compiled, which compiles ClauseElements into ANSI-compliant SQL strings."""
- def __init__(self, engine, statement, parameters=None, typemap=None, **kwargs):
+ def __init__(self, statement, parameters=None, typemap=None, engine=None, positional=None, paramstyle=None, **kwargs):
"""constructs a new ANSICompiler object.
engine - SQLEngine to compile against
@@ -49,7 +49,7 @@ class ANSICompiler(sql.Compiled):
key/value pairs when the Compiled is executed, and also may affect the
actual compilation, as in the case of an INSERT where the actual columns
inserted will correspond to the keys present in the parameters."""
- sql.Compiled.__init__(self, engine, statement, parameters)
+ sql.Compiled.__init__(self, statement, parameters, engine=engine)
self.binds = {}
self.froms = {}
self.wheres = {}
@@ -57,19 +57,31 @@ class ANSICompiler(sql.Compiled):
self.select_stack = []
self.typemap = typemap or {}
self.isinsert = False
+ self.bindtemplate = ":%s"
+ if engine is not None:
+ self.paramstyle = engine.paramstyle
+ self.positional = engine.positional
+ else:
+ self.positional = False
+ self.paramstyle = 'named'
def after_compile(self):
- if self.engine.positional:
+ # this re will search for params like :param
+ # it has a negative lookbehind for an extra ':' so that it doesnt match
+ # postgres '::text' tokens
+ match = r'(?<!:):([\w_]+)'
+ if self.paramstyle=='pyformat':
+ self.strings[self.statement] = re.sub(match, lambda m:'%(' + m.group(1) +')s', self.strings[self.statement])
+ elif self.positional:
self.positiontup = []
- match = r'%\(([\w_]+)\)s'
params = re.finditer(match, self.strings[self.statement])
for p in params:
self.positiontup.append(p.group(1))
- if self.engine.paramstyle=='qmark':
+ if self.paramstyle=='qmark':
self.strings[self.statement] = re.sub(match, '?', self.strings[self.statement])
- elif self.engine.paramstyle=='format':
+ elif self.paramstyle=='format':
self.strings[self.statement] = re.sub(match, '%s', self.strings[self.statement])
- elif self.engine.paramstyle=='numeric':
+ elif self.paramstyle=='numeric':
i = [0]
def getnum(x):
i[0] += 1
@@ -104,28 +116,33 @@ class ANSICompiler(sql.Compiled):
bindparams = {}
bindparams.update(params)
- if self.engine.positional:
+ if self.positional:
d = OrderedDict()
for k in self.positiontup:
b = self.binds[k]
- d[k] = b.typeprocess(b.value, self.engine)
+ if self.engine is not None:
+ d[k] = b.typeprocess(b.value, self.engine)
+ else:
+ d[k] = b.value
else:
d = {}
for b in self.binds.values():
- d[b.key] = b.typeprocess(b.value, self.engine)
+ if self.engine is not None:
+ d[b.key] = b.typeprocess(b.value, self.engine)
+ else:
+ d[b.key] = b.value
for key, value in bindparams.iteritems():
try:
b = self.binds[key]
except KeyError:
continue
- d[b.key] = b.typeprocess(value, self.engine)
+ if self.engine is not None:
+ d[b.key] = b.typeprocess(value, self.engine)
+ else:
+ d[b.key] = value
return d
- if self.engine.positional:
- return d.values()
- else:
- return d
def get_named_params(self, parameters):
"""given the results of the get_params method, returns the parameters
@@ -133,8 +150,7 @@ class ANSICompiler(sql.Compiled):
same dictionary. For a positional paramstyle, the given parameters are
assumed to be in list format and are converted back to a dictionary.
"""
-# return parameters
- if self.engine.positional:
+ if self.positional:
p = {}
for i in range(0, len(self.positiontup)):
p[self.positiontup[i]] = parameters[i]
@@ -231,7 +247,7 @@ class ANSICompiler(sql.Compiled):
self.strings[bindparam] = self.bindparam_string(key)
def bindparam_string(self, name):
- return self.engine.bindtemplate % name
+ return self.bindtemplate % name
def visit_alias(self, alias):
self.froms[alias] = self.get_from_text(alias.original) + " AS " + alias.name