diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2005-07-01 02:43:15 +0000 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2005-07-01 02:43:15 +0000 |
commit | b2f0d64fa8c06b5662ce6831dc3fe1588397c76b (patch) | |
tree | e37ba5e716c999f91b287b46583de7beab4f24d0 /lib/sqlalchemy/ansisql.py | |
parent | 76ed6f7ab6823d0906286026a40e6a3fca7ada27 (diff) | |
download | sqlalchemy-b2f0d64fa8c06b5662ce6831dc3fe1588397c76b.tar.gz |
Initial revision
Diffstat (limited to 'lib/sqlalchemy/ansisql.py')
-rw-r--r-- | lib/sqlalchemy/ansisql.py | 244 |
1 files changed, 244 insertions, 0 deletions
diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py new file mode 100644 index 000000000..bd63aa9d7 --- /dev/null +++ b/lib/sqlalchemy/ansisql.py @@ -0,0 +1,244 @@ +"""defines ANSI SQL operations.""" + +import sqlalchemy.schema as schema + +from sqlalchemy.schema import * +import sqlalchemy.sql as sql +import sqlalchemy.engine +from sqlalchemy.sql import * +from sqlalchemy.util import * +import string + +def engine(**params): + return ANSISQLEngine(**params) + +class ANSISQLEngine(sqlalchemy.engine.SQLEngine): + + def tableimpl(self, table): + return ANSISQLTableImpl(table) + + def schemagenerator(self, proxy, **params): + return ANSISchemaGenerator(proxy, **params) + + def schemadropper(self, proxy, **params): + return ANSISchemaDropper(proxy, **params) + + def connect_args(self): + return ([],{}) + + def dbapi(self): + return object() + + def compile(self, statement, bindparams): + compiler = ANSICompiler(statement, bindparams) + + statement.accept_visitor(compiler) + return compiler + +class ANSICompiler(sql.Compiled): + def __init__(self, parent, bindparams): + self.binds = {} + self.bindparams = bindparams + self.parent = parent + self.froms = {} + self.wheres = {} + self.strings = {} + + def get_from_text(self, obj): + return self.froms[obj] + + def get_str(self, obj): + return self.strings[obj] + + def get_whereclause(self, obj): + return self.wheres.get(obj, None) + + def get_params(self, **params): + d = {} + for key, value in params.iteritems(): + try: + b = self.binds[key] + except KeyError: + raise "No such bind param in statement '%s': %s" % (str(self), key) + d[b.key] = value + + for b in self.binds.values(): + if not d.has_key(b.key): + d[b.key] = b.value + + return d + + def visit_column(self, column): + if column.table.name is None: + self.strings[column] = column.name + else: + self.strings[column] = "%s.%s" % (column.table.name, column.name) + + def visit_fromclause(self, fromclause): + self.froms[fromclause] = fromclause.from_name + + def visit_textclause(self, textclause): + if textclause.parens and len(textclause.text): + self.strings[textclause] = "(" + textclause.text + ")" + else: + self.strings[textclause] = textclause.text + + def visit_compound(self, compound): + if compound.operator is None: + sep = " " + else: + sep = " " + compound.operator + " " + + if compound.parens: + self.strings[compound] = "(" + string.join([self.get_str(c) for c in compound.clauses], sep) + ")" + else: + self.strings[compound] = string.join([self.get_str(c) for c in compound.clauses], sep) + + def visit_clauselist(self, list): + self.strings[list] = string.join([self.get_str(c) for c in list.clauses], ', ') + + def visit_binary(self, binary): + + if binary.parens: + self.strings[binary] = "(" + self.get_str(binary.left) + " " + str(binary.operator) + " " + self.get_str(binary.right) + ")" + else: + self.strings[binary] = self.get_str(binary.left) + " " + str(binary.operator) + " " + self.get_str(binary.right) + + def visit_bindparam(self, bindparam): + self.binds[bindparam.shortname] = bindparam + + count = 1 + key = bindparam.key + + while self.binds.setdefault(key, bindparam) is not bindparam: + key = "%s_%d" % (bindparam.key, count) + count += 1 + + self.strings[bindparam] = ":" + key + + def visit_alias(self, alias): + self.froms[alias] = self.get_from_text(alias.selectable) + " " + alias.name + + def visit_select(self, select): + inner_columns = [] + + for c in select._raw_columns: + for co in c.columns: + inner_columns.append(co) + + if select.use_labels: + collist = string.join(["%s AS %s" % (c.fullname, c.label) for c in inner_columns], ', ') + else: + collist = string.join([c.fullname for c in inner_columns], ', ') + + text = "SELECT " + collist + " FROM " + + whereclause = select.whereclause + + froms = [] + for f in select.froms.values(): + + # special thingy used by oracle to redefine a join + w = self.get_whereclause(f) + if w is not None: + # TODO: move this more into the oracle module + whereclause = sql.and_(w, whereclause) + self.visit_compound(whereclause) + + t = self.get_from_text(f) + if t is not None: + froms.append(t) + + text += string.join(froms, ', ') + + if whereclause is not None: + t = self.get_str(whereclause) + if t: + text += " WHERE " + t + + for tup in select._clauses: + text += " " + tup[0] + " " + self.get_str(tup[1]) + + self.strings[select] = text + self.froms[select] = "(" + text + ")" + + + def visit_table(self, table): + self.froms[table] = table.name + + def visit_join(self, join): + if join.isouter: + self.froms[join] = ("(" + self.get_from_text(join.left) + " LEFT OUTER JOIN " + self.get_from_text(join.right) + + " ON " + self.get_str(join.onclause) + ")") + else: + self.froms[join] = ("(" + self.get_from_text(join.left) + " JOIN " + self.get_from_text(join.right) + + " ON " + self.get_str(join.onclause) + ")") + + def visit_insert(self, insert_stmt): + colparams = insert_stmt.get_colparams(self.bindparams) + + for c in colparams: + b = c[1] + self.binds[b.key] = b + self.binds[b.shortname] = b + + text = ("INSERT INTO " + insert_stmt.table.name + " (" + string.join([c[0].name for c in colparams], ', ') + ")" + + " VALUES (" + string.join([":" + c[1].key for c in colparams], ', ') + ")") + + self.strings[insert_stmt] = text + + def visit_update(self, update_stmt): + colparams = update_stmt.get_colparams(self.bindparams) + + for c in colparams: + b = c[1] + self.binds[b.key] = b + self.binds[b.shortname] = b + + text = "UPDATE " + update_stmt.table.name + " SET " + string.join(["%s=:%s" % (c[0].name, c[1].key) for c in colparams], ', ') + + if update_stmt.whereclause: + text += " WHERE " + self.get_str(update_stmt.whereclause) + + self.strings[update_stmt] = text + + def __str__(self): + return self.get_str(self.parent) + + + +class ANSISQLTableImpl(sql.TableImpl): + """Selectable implementation that gets attached to a schema.Table object.""" + + def __init__(self, table): + sql.TableImpl.__init__(self) + self.table = table + self.id = self.table.name + + def get_from_text(self): + return self.table.name + +class ANSISchemaGenerator(sqlalchemy.engine.SchemaIterator): + + def visit_table(self, table): + self.append("\nCREATE TABLE " + table.name + "(") + + separator = "\n" + + for column in table.columns: + self.append(separator) + separator = ", \n" + self.append("\t" + column._get_specification()) + + self.append("\n)\n\n") + self.execute() + + def visit_column(self, column): + pass + +class ANSISchemaDropper(sqlalchemy.engine.SchemaIterator): + def visit_table(self, table): + self.append("\nDROP TABLE " + table.name) + self.execute() + + |