diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2005-07-23 20:06:57 +0000 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2005-07-23 20:06:57 +0000 |
commit | 0baed4225dd43885fdf2b0f94e6ea85b9f421e64 (patch) | |
tree | d38703f2a4465a569529b322cd3a7d8776e7aa41 | |
parent | 5a77af7c24d6d6c52b16859c36bb433428fe93ce (diff) | |
download | sqlalchemy-0baed4225dd43885fdf2b0f94e6ea85b9f421e64.tar.gz |
-rw-r--r-- | lib/sqlalchemy/ansisql.py | 20 | ||||
-rw-r--r-- | lib/sqlalchemy/databases/sqlite.py | 10 | ||||
-rw-r--r-- | lib/sqlalchemy/sql.py | 5 |
3 files changed, 22 insertions, 13 deletions
diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py index 0d93ea518..3f6cbb835 100644 --- a/lib/sqlalchemy/ansisql.py +++ b/lib/sqlalchemy/ansisql.py @@ -55,7 +55,7 @@ class ANSISQLEngine(sqlalchemy.engine.SQLEngine): class ANSICompiler(sql.Compiled): def __init__(self, parent, bindparams): self.binds = {} - self.bindparams = bindparams + self._bindparams = bindparams self.parent = parent self.froms = {} self.wheres = {} @@ -71,6 +71,8 @@ class ANSICompiler(sql.Compiled): return self.wheres.get(obj, None) def get_params(self, **params): + """returns the bind params for this compiled object, with values overridden by + those given in the **params dictionary""" d = {} for key, value in params.iteritems(): try: @@ -80,8 +82,7 @@ class ANSICompiler(sql.Compiled): d[b.key] = value for b in self.binds.values(): - if not d.has_key(b.key): - d[b.key] = b.value + d.setdefault(b.key, b.value) return d @@ -166,7 +167,7 @@ class ANSICompiler(sql.Compiled): if t is not None: froms.append(t) - text += string.join(froms, ', ') + text += string.join(froms, ', ') if whereclause is not None: t = self.get_str(whereclause) @@ -182,18 +183,17 @@ class ANSICompiler(sql.Compiled): def visit_table(self, table): self.froms[table] = table.name - + def visit_join(self, join): if join.isouter: self.froms[join] = (self.get_from_text(join.left) + " LEFT OUTER JOIN " + self.get_from_text(join.right) + " ON " + self.get_str(join.onclause)) else: - self.froms[join] = (self.get_from_text(join.left) + " JOIN " + self.get_from_text(join.right) + + self.froms[join] = (self.get_from_text(join.left) + " JOIN " + self.get_from_text(join.right) + " ON " + self.get_str(join.onclause)) - - + def visit_insert(self, insert_stmt): - colparams = insert_stmt.get_colparams(self.bindparams) + colparams = insert_stmt.get_colparams(self._bindparams) for c in colparams: b = c[1] @@ -206,7 +206,7 @@ class ANSICompiler(sql.Compiled): self.strings[insert_stmt] = text def visit_update(self, update_stmt): - colparams = update_stmt.get_colparams(self.bindparams) + colparams = update_stmt.get_colparams(self._bindparams) for c in colparams: b = c[1] diff --git a/lib/sqlalchemy/databases/sqlite.py b/lib/sqlalchemy/databases/sqlite.py index 6a1b58da9..fffda0916 100644 --- a/lib/sqlalchemy/databases/sqlite.py +++ b/lib/sqlalchemy/databases/sqlite.py @@ -51,6 +51,12 @@ class SQLiteSQLEngine(ansisql.ANSISQLEngine): def connect_args(self): return ([self.filename], self.opts) + + def compile(self, statement, bindparams): + compiler = SQLiteCompiler(statement, bindparams) + + statement.accept_visitor(compiler) + return compiler def dbapi(self): return sqlite @@ -61,6 +67,10 @@ class SQLiteSQLEngine(ansisql.ANSISQLEngine): def reflecttable(self, table): raise NotImplementedError() +class SQLiteCompiler(ansisql.ANSICompiler): + def visit_insert(self, insert): + ansisql.ANSICompiler.visit_insert(self, insert) + class SQLiteColumnImpl(sql.ColumnSelectable): def _get_specification(self): coltype = self.column.type diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index 29de0d34f..8f3e51fbd 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -136,13 +136,12 @@ class ClauseElement(object): c = self.compile(e, bindparams = params) # TODO: do pre-execute right here, for sequences, if the compiled object # defines it - # TODO: why do we send the params twice, once to compile, once to c.get_params - return e.execute(str(c), c.get_params(**params), echo = getattr(self, 'echo', None)) + return e.execute(str(c), c.get_params(), echo = getattr(self, 'echo', None)) def result(self, **params): e = self._engine() c = self.compile(e, bindparams = params) - return e.result(str(c), c.get_params(**params)) + return e.result(str(c), c.binds) class ColumnClause(ClauseElement): """represents a column clause element in a SQL statement.""" |