summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/ansisql.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2005-12-17 02:49:47 +0000
committerMike Bayer <mike_mp@zzzcomputing.com>2005-12-17 02:49:47 +0000
commit40964f68a143ab211bfd903dcc6733bf1c77906a (patch)
treee6210800dcdfedf7c9de56b0a35a3643fb21a941 /lib/sqlalchemy/ansisql.py
parent5b4c585078a38da03f6cf8b5958022911cb611e3 (diff)
downloadsqlalchemy-40964f68a143ab211bfd903dcc6733bf1c77906a.tar.gz
refactoring of execution path, defaults, and treatment of different paramstyles
Diffstat (limited to 'lib/sqlalchemy/ansisql.py')
-rw-r--r--lib/sqlalchemy/ansisql.py111
1 files changed, 67 insertions, 44 deletions
diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py
index cd1d3a0b0..e4bcdd077 100644
--- a/lib/sqlalchemy/ansisql.py
+++ b/lib/sqlalchemy/ansisql.py
@@ -37,8 +37,8 @@ class ANSISQLEngine(sqlalchemy.engine.SQLEngine):
def schemadropper(self, proxy, **params):
return ANSISchemaDropper(proxy, **params)
- def compiler(self, statement, bindparams, **kwargs):
- return ANSICompiler(self, statement, bindparams, **kwargs)
+ def compiler(self, statement, parameters, **kwargs):
+ return ANSICompiler(self, statement, parameters, **kwargs)
def connect_args(self):
return ([],{})
@@ -47,8 +47,20 @@ class ANSISQLEngine(sqlalchemy.engine.SQLEngine):
return None
class ANSICompiler(sql.Compiled):
- def __init__(self, engine, statement, bindparams, typemap=None, paramstyle=None,**kwargs):
- sql.Compiled.__init__(self, engine, statement, bindparams)
+ """default implementation of Compiled, which compiles ClauseElements into ANSI-compliant SQL strings."""
+ def __init__(self, engine, statement, parameters=None, typemap=None, **kwargs):
+ """constructs a new ANSICompiler object.
+
+ engine - SQLEngine to compile against
+
+ statement - ClauseElement to be compiled
+
+ parameters - optional dictionary indicating a set of bind parameters
+ specified with this Compiled object. These parameters are the "default"
+ key/value pairs when the Compiled is executed, and also may affect the
+ actual compilation, as in the case of an INSERT where the actual columns
+ inserted will correspond to the keys present in the parameters."""
+ sql.Compiled.__init__(self, engine, statement, parameters)
self.binds = {}
self.froms = {}
self.wheres = {}
@@ -57,37 +69,18 @@ class ANSICompiler(sql.Compiled):
self.typemap = typemap or {}
self.isinsert = False
- if paramstyle is None:
- db = self.engine.dbapi()
- if db is not None:
- paramstyle = db.paramstyle
- else:
- paramstyle = 'named'
-
- if paramstyle == 'named':
- self.bindtemplate = ':%s'
- self.positional=False
- elif paramstyle =='pyformat':
- self.bindtemplate = "%%(%s)s"
- self.positional=False
- else:
- # for positional, use pyformat until the end
- self.bindtemplate = "%%(%s)s"
- self.positional=True
- self.paramstyle=paramstyle
-
def after_compile(self):
- if self.positional:
+ if self.engine.positional:
self.positiontup = []
match = r'%\(([\w_]+)\)s'
params = re.finditer(match, self.strings[self.statement])
for p in params:
self.positiontup.append(p.group(1))
- if self.paramstyle=='qmark':
+ if self.engine.paramstyle=='qmark':
self.strings[self.statement] = re.sub(match, '?', self.strings[self.statement])
- elif self.paramstyle=='format':
+ elif self.engine.paramstyle=='format':
self.strings[self.statement] = re.sub(match, '%s', self.strings[self.statement])
- elif self.paramstyle=='numeric':
+ elif self.engine.paramstyle=='numeric':
i = 0
def getnum(x):
i += 1
@@ -116,14 +109,22 @@ class ANSICompiler(sql.Compiled):
for an executemany style of call, this method should be called for each element
in the list of parameter groups that will ultimately be executed.
"""
- d = {}
- if self.bindparams is not None:
- bindparams = self.bindparams.copy()
+ if self.parameters is not None:
+ bindparams = self.parameters.copy()
else:
bindparams = {}
bindparams.update(params)
- # TODO: cant we make "d" an ordereddict and add params in
- # positional order
+
+ if self.engine.positional:
+ d = OrderedDict()
+ for k in self.positiontup:
+ b = self.binds[k]
+ d[k] = b.typeprocess(b.value)
+ else:
+ d = {}
+ for b in self.binds.values():
+ d[b.key] = b.typeprocess(b.value)
+
for key, value in bindparams.iteritems():
try:
b = self.binds[key]
@@ -131,11 +132,9 @@ class ANSICompiler(sql.Compiled):
continue
d[b.key] = b.typeprocess(value)
- for b in self.binds.values():
- d.setdefault(b.key, b.typeprocess(b.value))
-
- if self.positional:
- return [d[key] for key in self.positiontup]
+ return d
+ if self.engine.positional:
+ return d.values()
else:
return d
@@ -145,7 +144,8 @@ class ANSICompiler(sql.Compiled):
same dictionary. For a positional paramstyle, the given parameters are
assumed to be in list format and are converted back to a dictionary.
"""
- if self.positional:
+# return parameters
+ if self.engine.positional:
p = {}
for i in range(0, len(self.positiontup)):
p[self.positiontup[i]] = parameters[i]
@@ -237,7 +237,7 @@ class ANSICompiler(sql.Compiled):
self.strings[bindparam] = self.bindparam_string(key)
def bindparam_string(self, name):
- return self.bindtemplate % name
+ return self.engine.bindtemplate % name
def visit_alias(self, alias):
self.froms[alias] = self.get_from_text(alias.selectable) + " AS " + alias.name
@@ -265,7 +265,7 @@ class ANSICompiler(sql.Compiled):
text = "SELECT "
if select.distinct:
text += "DISTINCT "
- text += collist + " \nFROM "
+ text += collist
whereclause = select.whereclause
@@ -282,8 +282,10 @@ class ANSICompiler(sql.Compiled):
t = self.get_from_text(f)
if t is not None:
froms.append(t)
-
- text += string.join(froms, ', ')
+
+ if len(froms):
+ text += " \nFROM "
+ text += string.join(froms, ', ')
if whereclause is not None:
t = self.get_str(whereclause)
@@ -333,10 +335,31 @@ class ANSICompiler(sql.Compiled):
self.froms[join] = (self.get_from_text(join.left) + " JOIN " + righttext +
" ON " + self.get_str(join.onclause))
self.strings[join] = self.froms[join]
+
+ def visit_insert_column_default(self, column, default):
+ """called when visiting an Insert statement, for each column in the table that
+ contains a ColumnDefault object."""
+ self.parameters.setdefault(column.key, None)
+
+ def visit_insert_sequence(self, column, sequence):
+ """called when visiting an Insert statement, for each column in the table that
+ contains a Sequence object."""
+ pass
def visit_insert(self, insert_stmt):
+ # set up a call for the defaults and sequences inside the table
+ class DefaultVisitor(schema.SchemaVisitor):
+ def visit_column_default(s, cd):
+ self.visit_insert_column_default(c, cd)
+ def visit_sequence(s, seq):
+ self.visit_insert_sequence(c, seq)
+ vis = DefaultVisitor()
+ for c in insert_stmt.table.c:
+ if self.parameters.get(c.key, None) is None and c.default is not None:
+ c.default.accept_visitor(vis)
+
self.isinsert = True
- colparams = insert_stmt.get_colparams(self.bindparams)
+ colparams = insert_stmt.get_colparams(self.parameters)
for c in colparams:
b = c[1]
self.binds[b.key] = b
@@ -348,7 +371,7 @@ class ANSICompiler(sql.Compiled):
self.strings[insert_stmt] = text
def visit_update(self, update_stmt):
- colparams = update_stmt.get_colparams(self.bindparams)
+ colparams = update_stmt.get_colparams(self.parameters)
def create_param(p):
if isinstance(p, sql.BindParamClause):
self.binds[p.key] = p