summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/ansisql.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2005-07-01 02:43:15 +0000
committerMike Bayer <mike_mp@zzzcomputing.com>2005-07-01 02:43:15 +0000
commitb2f0d64fa8c06b5662ce6831dc3fe1588397c76b (patch)
treee37ba5e716c999f91b287b46583de7beab4f24d0 /lib/sqlalchemy/ansisql.py
parent76ed6f7ab6823d0906286026a40e6a3fca7ada27 (diff)
downloadsqlalchemy-b2f0d64fa8c06b5662ce6831dc3fe1588397c76b.tar.gz
Initial revision
Diffstat (limited to 'lib/sqlalchemy/ansisql.py')
-rw-r--r--lib/sqlalchemy/ansisql.py244
1 files changed, 244 insertions, 0 deletions
diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py
new file mode 100644
index 000000000..bd63aa9d7
--- /dev/null
+++ b/lib/sqlalchemy/ansisql.py
@@ -0,0 +1,244 @@
+"""defines ANSI SQL operations."""
+
+import sqlalchemy.schema as schema
+
+from sqlalchemy.schema import *
+import sqlalchemy.sql as sql
+import sqlalchemy.engine
+from sqlalchemy.sql import *
+from sqlalchemy.util import *
+import string
+
+def engine(**params):
+ return ANSISQLEngine(**params)
+
+class ANSISQLEngine(sqlalchemy.engine.SQLEngine):
+
+ def tableimpl(self, table):
+ return ANSISQLTableImpl(table)
+
+ def schemagenerator(self, proxy, **params):
+ return ANSISchemaGenerator(proxy, **params)
+
+ def schemadropper(self, proxy, **params):
+ return ANSISchemaDropper(proxy, **params)
+
+ def connect_args(self):
+ return ([],{})
+
+ def dbapi(self):
+ return object()
+
+ def compile(self, statement, bindparams):
+ compiler = ANSICompiler(statement, bindparams)
+
+ statement.accept_visitor(compiler)
+ return compiler
+
+class ANSICompiler(sql.Compiled):
+ def __init__(self, parent, bindparams):
+ self.binds = {}
+ self.bindparams = bindparams
+ self.parent = parent
+ self.froms = {}
+ self.wheres = {}
+ self.strings = {}
+
+ def get_from_text(self, obj):
+ return self.froms[obj]
+
+ def get_str(self, obj):
+ return self.strings[obj]
+
+ def get_whereclause(self, obj):
+ return self.wheres.get(obj, None)
+
+ def get_params(self, **params):
+ d = {}
+ for key, value in params.iteritems():
+ try:
+ b = self.binds[key]
+ except KeyError:
+ raise "No such bind param in statement '%s': %s" % (str(self), key)
+ d[b.key] = value
+
+ for b in self.binds.values():
+ if not d.has_key(b.key):
+ d[b.key] = b.value
+
+ return d
+
+ def visit_column(self, column):
+ if column.table.name is None:
+ self.strings[column] = column.name
+ else:
+ self.strings[column] = "%s.%s" % (column.table.name, column.name)
+
+ def visit_fromclause(self, fromclause):
+ self.froms[fromclause] = fromclause.from_name
+
+ def visit_textclause(self, textclause):
+ if textclause.parens and len(textclause.text):
+ self.strings[textclause] = "(" + textclause.text + ")"
+ else:
+ self.strings[textclause] = textclause.text
+
+ def visit_compound(self, compound):
+ if compound.operator is None:
+ sep = " "
+ else:
+ sep = " " + compound.operator + " "
+
+ if compound.parens:
+ self.strings[compound] = "(" + string.join([self.get_str(c) for c in compound.clauses], sep) + ")"
+ else:
+ self.strings[compound] = string.join([self.get_str(c) for c in compound.clauses], sep)
+
+ def visit_clauselist(self, list):
+ self.strings[list] = string.join([self.get_str(c) for c in list.clauses], ', ')
+
+ def visit_binary(self, binary):
+
+ if binary.parens:
+ self.strings[binary] = "(" + self.get_str(binary.left) + " " + str(binary.operator) + " " + self.get_str(binary.right) + ")"
+ else:
+ self.strings[binary] = self.get_str(binary.left) + " " + str(binary.operator) + " " + self.get_str(binary.right)
+
+ def visit_bindparam(self, bindparam):
+ self.binds[bindparam.shortname] = bindparam
+
+ count = 1
+ key = bindparam.key
+
+ while self.binds.setdefault(key, bindparam) is not bindparam:
+ key = "%s_%d" % (bindparam.key, count)
+ count += 1
+
+ self.strings[bindparam] = ":" + key
+
+ def visit_alias(self, alias):
+ self.froms[alias] = self.get_from_text(alias.selectable) + " " + alias.name
+
+ def visit_select(self, select):
+ inner_columns = []
+
+ for c in select._raw_columns:
+ for co in c.columns:
+ inner_columns.append(co)
+
+ if select.use_labels:
+ collist = string.join(["%s AS %s" % (c.fullname, c.label) for c in inner_columns], ', ')
+ else:
+ collist = string.join([c.fullname for c in inner_columns], ', ')
+
+ text = "SELECT " + collist + " FROM "
+
+ whereclause = select.whereclause
+
+ froms = []
+ for f in select.froms.values():
+
+ # special thingy used by oracle to redefine a join
+ w = self.get_whereclause(f)
+ if w is not None:
+ # TODO: move this more into the oracle module
+ whereclause = sql.and_(w, whereclause)
+ self.visit_compound(whereclause)
+
+ t = self.get_from_text(f)
+ if t is not None:
+ froms.append(t)
+
+ text += string.join(froms, ', ')
+
+ if whereclause is not None:
+ t = self.get_str(whereclause)
+ if t:
+ text += " WHERE " + t
+
+ for tup in select._clauses:
+ text += " " + tup[0] + " " + self.get_str(tup[1])
+
+ self.strings[select] = text
+ self.froms[select] = "(" + text + ")"
+
+
+ def visit_table(self, table):
+ self.froms[table] = table.name
+
+ def visit_join(self, join):
+ if join.isouter:
+ self.froms[join] = ("(" + self.get_from_text(join.left) + " LEFT OUTER JOIN " + self.get_from_text(join.right) +
+ " ON " + self.get_str(join.onclause) + ")")
+ else:
+ self.froms[join] = ("(" + self.get_from_text(join.left) + " JOIN " + self.get_from_text(join.right) +
+ " ON " + self.get_str(join.onclause) + ")")
+
+ def visit_insert(self, insert_stmt):
+ colparams = insert_stmt.get_colparams(self.bindparams)
+
+ for c in colparams:
+ b = c[1]
+ self.binds[b.key] = b
+ self.binds[b.shortname] = b
+
+ text = ("INSERT INTO " + insert_stmt.table.name + " (" + string.join([c[0].name for c in colparams], ', ') + ")" +
+ " VALUES (" + string.join([":" + c[1].key for c in colparams], ', ') + ")")
+
+ self.strings[insert_stmt] = text
+
+ def visit_update(self, update_stmt):
+ colparams = update_stmt.get_colparams(self.bindparams)
+
+ for c in colparams:
+ b = c[1]
+ self.binds[b.key] = b
+ self.binds[b.shortname] = b
+
+ text = "UPDATE " + update_stmt.table.name + " SET " + string.join(["%s=:%s" % (c[0].name, c[1].key) for c in colparams], ', ')
+
+ if update_stmt.whereclause:
+ text += " WHERE " + self.get_str(update_stmt.whereclause)
+
+ self.strings[update_stmt] = text
+
+ def __str__(self):
+ return self.get_str(self.parent)
+
+
+
+class ANSISQLTableImpl(sql.TableImpl):
+ """Selectable implementation that gets attached to a schema.Table object."""
+
+ def __init__(self, table):
+ sql.TableImpl.__init__(self)
+ self.table = table
+ self.id = self.table.name
+
+ def get_from_text(self):
+ return self.table.name
+
+class ANSISchemaGenerator(sqlalchemy.engine.SchemaIterator):
+
+ def visit_table(self, table):
+ self.append("\nCREATE TABLE " + table.name + "(")
+
+ separator = "\n"
+
+ for column in table.columns:
+ self.append(separator)
+ separator = ", \n"
+ self.append("\t" + column._get_specification())
+
+ self.append("\n)\n\n")
+ self.execute()
+
+ def visit_column(self, column):
+ pass
+
+class ANSISchemaDropper(sqlalchemy.engine.SchemaIterator):
+ def visit_table(self, table):
+ self.append("\nDROP TABLE " + table.name)
+ self.execute()
+
+