diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2005-12-17 02:49:47 +0000 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2005-12-17 02:49:47 +0000 |
commit | 40964f68a143ab211bfd903dcc6733bf1c77906a (patch) | |
tree | e6210800dcdfedf7c9de56b0a35a3643fb21a941 /lib/sqlalchemy/ansisql.py | |
parent | 5b4c585078a38da03f6cf8b5958022911cb611e3 (diff) | |
download | sqlalchemy-40964f68a143ab211bfd903dcc6733bf1c77906a.tar.gz |
refactoring of execution path, defaults, and treatment of different paramstyles
Diffstat (limited to 'lib/sqlalchemy/ansisql.py')
-rw-r--r-- | lib/sqlalchemy/ansisql.py | 111 |
1 files changed, 67 insertions, 44 deletions
diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py index cd1d3a0b0..e4bcdd077 100644 --- a/lib/sqlalchemy/ansisql.py +++ b/lib/sqlalchemy/ansisql.py @@ -37,8 +37,8 @@ class ANSISQLEngine(sqlalchemy.engine.SQLEngine): def schemadropper(self, proxy, **params): return ANSISchemaDropper(proxy, **params) - def compiler(self, statement, bindparams, **kwargs): - return ANSICompiler(self, statement, bindparams, **kwargs) + def compiler(self, statement, parameters, **kwargs): + return ANSICompiler(self, statement, parameters, **kwargs) def connect_args(self): return ([],{}) @@ -47,8 +47,20 @@ class ANSISQLEngine(sqlalchemy.engine.SQLEngine): return None class ANSICompiler(sql.Compiled): - def __init__(self, engine, statement, bindparams, typemap=None, paramstyle=None,**kwargs): - sql.Compiled.__init__(self, engine, statement, bindparams) + """default implementation of Compiled, which compiles ClauseElements into ANSI-compliant SQL strings.""" + def __init__(self, engine, statement, parameters=None, typemap=None, **kwargs): + """constructs a new ANSICompiler object. + + engine - SQLEngine to compile against + + statement - ClauseElement to be compiled + + parameters - optional dictionary indicating a set of bind parameters + specified with this Compiled object. These parameters are the "default" + key/value pairs when the Compiled is executed, and also may affect the + actual compilation, as in the case of an INSERT where the actual columns + inserted will correspond to the keys present in the parameters.""" + sql.Compiled.__init__(self, engine, statement, parameters) self.binds = {} self.froms = {} self.wheres = {} @@ -57,37 +69,18 @@ class ANSICompiler(sql.Compiled): self.typemap = typemap or {} self.isinsert = False - if paramstyle is None: - db = self.engine.dbapi() - if db is not None: - paramstyle = db.paramstyle - else: - paramstyle = 'named' - - if paramstyle == 'named': - self.bindtemplate = ':%s' - self.positional=False - elif paramstyle =='pyformat': - self.bindtemplate = "%%(%s)s" - self.positional=False - else: - # for positional, use pyformat until the end - self.bindtemplate = "%%(%s)s" - self.positional=True - self.paramstyle=paramstyle - def after_compile(self): - if self.positional: + if self.engine.positional: self.positiontup = [] match = r'%\(([\w_]+)\)s' params = re.finditer(match, self.strings[self.statement]) for p in params: self.positiontup.append(p.group(1)) - if self.paramstyle=='qmark': + if self.engine.paramstyle=='qmark': self.strings[self.statement] = re.sub(match, '?', self.strings[self.statement]) - elif self.paramstyle=='format': + elif self.engine.paramstyle=='format': self.strings[self.statement] = re.sub(match, '%s', self.strings[self.statement]) - elif self.paramstyle=='numeric': + elif self.engine.paramstyle=='numeric': i = 0 def getnum(x): i += 1 @@ -116,14 +109,22 @@ class ANSICompiler(sql.Compiled): for an executemany style of call, this method should be called for each element in the list of parameter groups that will ultimately be executed. """ - d = {} - if self.bindparams is not None: - bindparams = self.bindparams.copy() + if self.parameters is not None: + bindparams = self.parameters.copy() else: bindparams = {} bindparams.update(params) - # TODO: cant we make "d" an ordereddict and add params in - # positional order + + if self.engine.positional: + d = OrderedDict() + for k in self.positiontup: + b = self.binds[k] + d[k] = b.typeprocess(b.value) + else: + d = {} + for b in self.binds.values(): + d[b.key] = b.typeprocess(b.value) + for key, value in bindparams.iteritems(): try: b = self.binds[key] @@ -131,11 +132,9 @@ class ANSICompiler(sql.Compiled): continue d[b.key] = b.typeprocess(value) - for b in self.binds.values(): - d.setdefault(b.key, b.typeprocess(b.value)) - - if self.positional: - return [d[key] for key in self.positiontup] + return d + if self.engine.positional: + return d.values() else: return d @@ -145,7 +144,8 @@ class ANSICompiler(sql.Compiled): same dictionary. For a positional paramstyle, the given parameters are assumed to be in list format and are converted back to a dictionary. """ - if self.positional: +# return parameters + if self.engine.positional: p = {} for i in range(0, len(self.positiontup)): p[self.positiontup[i]] = parameters[i] @@ -237,7 +237,7 @@ class ANSICompiler(sql.Compiled): self.strings[bindparam] = self.bindparam_string(key) def bindparam_string(self, name): - return self.bindtemplate % name + return self.engine.bindtemplate % name def visit_alias(self, alias): self.froms[alias] = self.get_from_text(alias.selectable) + " AS " + alias.name @@ -265,7 +265,7 @@ class ANSICompiler(sql.Compiled): text = "SELECT " if select.distinct: text += "DISTINCT " - text += collist + " \nFROM " + text += collist whereclause = select.whereclause @@ -282,8 +282,10 @@ class ANSICompiler(sql.Compiled): t = self.get_from_text(f) if t is not None: froms.append(t) - - text += string.join(froms, ', ') + + if len(froms): + text += " \nFROM " + text += string.join(froms, ', ') if whereclause is not None: t = self.get_str(whereclause) @@ -333,10 +335,31 @@ class ANSICompiler(sql.Compiled): self.froms[join] = (self.get_from_text(join.left) + " JOIN " + righttext + " ON " + self.get_str(join.onclause)) self.strings[join] = self.froms[join] + + def visit_insert_column_default(self, column, default): + """called when visiting an Insert statement, for each column in the table that + contains a ColumnDefault object.""" + self.parameters.setdefault(column.key, None) + + def visit_insert_sequence(self, column, sequence): + """called when visiting an Insert statement, for each column in the table that + contains a Sequence object.""" + pass def visit_insert(self, insert_stmt): + # set up a call for the defaults and sequences inside the table + class DefaultVisitor(schema.SchemaVisitor): + def visit_column_default(s, cd): + self.visit_insert_column_default(c, cd) + def visit_sequence(s, seq): + self.visit_insert_sequence(c, seq) + vis = DefaultVisitor() + for c in insert_stmt.table.c: + if self.parameters.get(c.key, None) is None and c.default is not None: + c.default.accept_visitor(vis) + self.isinsert = True - colparams = insert_stmt.get_colparams(self.bindparams) + colparams = insert_stmt.get_colparams(self.parameters) for c in colparams: b = c[1] self.binds[b.key] = b @@ -348,7 +371,7 @@ class ANSICompiler(sql.Compiled): self.strings[insert_stmt] = text def visit_update(self, update_stmt): - colparams = update_stmt.get_colparams(self.bindparams) + colparams = update_stmt.get_colparams(self.parameters) def create_param(p): if isinstance(p, sql.BindParamClause): self.binds[p.key] = p |