summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/ansisql.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2006-05-25 14:20:23 +0000
committerMike Bayer <mike_mp@zzzcomputing.com>2006-05-25 14:20:23 +0000
commitbb79e2e871d0a4585164c1a6ed626d96d0231975 (patch)
tree6d457ba6c36c408b45db24ec3c29e147fe7504ff /lib/sqlalchemy/ansisql.py
parent4fc3a0648699c2b441251ba4e1d37a9107bd1986 (diff)
downloadsqlalchemy-bb79e2e871d0a4585164c1a6ed626d96d0231975.tar.gz
merged 0.2 branch into trunk; 0.1 now in sqlalchemy/branches/rel_0_1
Diffstat (limited to 'lib/sqlalchemy/ansisql.py')
-rw-r--r--lib/sqlalchemy/ansisql.py100
1 files changed, 53 insertions, 47 deletions
diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py
index df3f8fa59..6956c5379 100644
--- a/lib/sqlalchemy/ansisql.py
+++ b/lib/sqlalchemy/ansisql.py
@@ -4,18 +4,14 @@
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
-"""defines ANSI SQL operations."""
+"""defines ANSI SQL operations. Contains default implementations for the abstract objects
+in the sql module."""
-import sqlalchemy.schema as schema
-
-from sqlalchemy.schema import *
-import sqlalchemy.sql as sql
-import sqlalchemy.engine
-from sqlalchemy.sql import *
-from sqlalchemy.util import *
+from sqlalchemy import schema, sql, engine, util
+import sqlalchemy.engine.default as default
import string, re
-ANSI_FUNCS = HashSet([
+ANSI_FUNCS = util.HashSet([
'CURRENT_TIME',
'CURRENT_TIMESTAMP',
'CURRENT_DATE',
@@ -27,32 +23,32 @@ ANSI_FUNCS = HashSet([
])
-def engine(**params):
- return ANSISQLEngine(**params)
-
-class ANSISQLEngine(sqlalchemy.engine.SQLEngine):
-
- def schemagenerator(self, **params):
- return ANSISchemaGenerator(self, **params)
-
- def schemadropper(self, **params):
- return ANSISchemaDropper(self, **params)
-
- def compiler(self, statement, parameters, **kwargs):
- return ANSICompiler(statement, parameters, engine=self, **kwargs)
+def create_engine():
+ return engine.ComposedSQLEngine(None, ANSIDialect())
+class ANSIDialect(default.DefaultDialect):
def connect_args(self):
return ([],{})
def dbapi(self):
return None
+ def schemagenerator(self, *args, **params):
+ return ANSISchemaGenerator(*args, **params)
+
+ def schemadropper(self, *args, **params):
+ return ANSISchemaDropper(*args, **params)
+
+ def compiler(self, statement, parameters, **kwargs):
+ return ANSICompiler(self, statement, parameters, **kwargs)
+
+
class ANSICompiler(sql.Compiled):
"""default implementation of Compiled, which compiles ClauseElements into ANSI-compliant SQL strings."""
- def __init__(self, statement, parameters=None, typemap=None, engine=None, positional=None, paramstyle=None, **kwargs):
+ def __init__(self, dialect, statement, parameters=None, **kwargs):
"""constructs a new ANSICompiler object.
- engine - SQLEngine to compile against
+ dialect - Dialect to be used
statement - ClauseElement to be compiled
@@ -61,22 +57,18 @@ class ANSICompiler(sql.Compiled):
key/value pairs when the Compiled is executed, and also may affect the
actual compilation, as in the case of an INSERT where the actual columns
inserted will correspond to the keys present in the parameters."""
- sql.Compiled.__init__(self, statement, parameters, engine=engine)
+ sql.Compiled.__init__(self, dialect, statement, parameters, **kwargs)
self.binds = {}
self.froms = {}
self.wheres = {}
self.strings = {}
self.select_stack = []
- self.typemap = typemap or {}
+ self.typemap = {}
self.isinsert = False
self.isupdate = False
self.bindtemplate = ":%s"
- if engine is not None:
- self.paramstyle = engine.paramstyle
- self.positional = engine.positional
- else:
- self.positional = False
- self.paramstyle = 'named'
+ self.paramstyle = dialect.paramstyle
+ self.positional = dialect.positional
def after_compile(self):
# this re will search for params like :param
@@ -130,7 +122,7 @@ class ANSICompiler(sql.Compiled):
bindparams = {}
bindparams.update(params)
- d = sql.ClauseParameters(self.engine)
+ d = sql.ClauseParameters(self.dialect)
if self.positional:
for k in self.positiontup:
b = self.binds[k]
@@ -177,10 +169,19 @@ class ANSICompiler(sql.Compiled):
# if we are within a visit to a Select, set up the "typemap"
# for this column which is used to translate result set values
self.typemap.setdefault(column.key.lower(), column.type)
- if column.table is None or column.table.name is None:
+ if column.table is None or not column.table.named_with_column():
self.strings[column] = column.name
else:
- self.strings[column] = "%s.%s" % (column.table.name, column.name)
+ if column.table.oid_column is column:
+ n = self.dialect.oid_column_name()
+ if n is not None:
+ self.strings[column] = "%s.%s" % (column.table.name, n)
+ elif len(column.table.primary_key) != 0:
+ self.strings[column] = "%s.%s" % (column.table.name, column.table.primary_key[0].name)
+ else:
+ self.strings[column] = None
+ else:
+ self.strings[column] = "%s.%s" % (column.table.name, column.name)
def visit_fromclause(self, fromclause):
@@ -190,7 +191,7 @@ class ANSICompiler(sql.Compiled):
self.strings[index] = index.name
def visit_typeclause(self, typeclause):
- self.strings[typeclause] = typeclause.type.engine_impl(self.engine).get_col_spec()
+ self.strings[typeclause] = typeclause.type.dialect_impl(self.dialect).get_col_spec()
def visit_textclause(self, textclause):
if textclause.parens and len(textclause.text):
@@ -218,9 +219,9 @@ class ANSICompiler(sql.Compiled):
def visit_clauselist(self, list):
if list.parens:
- self.strings[list] = "(" + string.join([self.get_str(c) for c in list.clauses], ', ') + ")"
+ self.strings[list] = "(" + string.join([s for s in [self.get_str(c) for c in list.clauses] if s is not None], ', ') + ")"
else:
- self.strings[list] = string.join([self.get_str(c) for c in list.clauses], ', ')
+ self.strings[list] = string.join([s for s in [self.get_str(c) for c in list.clauses] if s is not None], ', ')
def apply_function_parens(self, func):
return func.name.upper() not in ANSI_FUNCS or len(func.clauses) > 0
@@ -294,7 +295,7 @@ class ANSICompiler(sql.Compiled):
# the actual list of columns to print in the SELECT column list.
# its an ordered dictionary to insure that the actual labeled column name
# is unique.
- inner_columns = OrderedDict()
+ inner_columns = util.OrderedDict()
self.select_stack.append(select)
for c in select._raw_columns:
@@ -314,7 +315,7 @@ class ANSICompiler(sql.Compiled):
# SQLite doesnt like selecting from a subquery where the column
# names look like table.colname, so add a label synonomous with
# the column name
- l = co.label(co.text)
+ l = co.label(co.name)
l.accept_visitor(self)
inner_columns[self.get_str(l.obj)] = l
else:
@@ -385,7 +386,7 @@ class ANSICompiler(sql.Compiled):
order_by = self.get_str(select.order_by_clause)
if order_by:
text += " ORDER BY " + order_by
-
+
text += self.visit_select_postclauses(select)
if select.for_update:
@@ -545,7 +546,7 @@ class ANSICompiler(sql.Compiled):
# case one: no parameters in the statement, no parameters in the
# compiled params - just return binds for all the table columns
if self.parameters is None and stmt.parameters is None:
- return [(c, bindparam(c.name, type=c.type)) for c in stmt.table.columns]
+ return [(c, sql.bindparam(c.name, type=c.type)) for c in stmt.table.columns]
# if we have statement parameters - set defaults in the
# compiled params
@@ -578,7 +579,7 @@ class ANSICompiler(sql.Compiled):
if d.has_key(c):
value = d[c]
if sql._is_literal(value):
- value = bindparam(c.name, value, type=c.type)
+ value = sql.bindparam(c.name, value, type=c.type)
values.append((c, value))
return values
@@ -594,7 +595,7 @@ class ANSICompiler(sql.Compiled):
return self.get_str(self.statement)
-class ANSISchemaGenerator(sqlalchemy.engine.SchemaIterator):
+class ANSISchemaGenerator(engine.SchemaIterator):
def get_column_specification(self, column, override_pk=False, first_pk=False):
raise NotImplementedError()
@@ -631,10 +632,15 @@ class ANSISchemaGenerator(sqlalchemy.engine.SchemaIterator):
if isinstance(column.default.arg, str):
return repr(column.default.arg)
else:
- return str(column.default.arg.compile(self.engine))
+ return str(self._compile(column.default.arg, None))
else:
return None
+ def _compile(self, tocompile, parameters):
+ compiler = self.engine.dialect.compiler(tocompile, parameters)
+ compiler.compile()
+ return compiler
+
def visit_column(self, column):
pass
@@ -648,7 +654,7 @@ class ANSISchemaGenerator(sqlalchemy.engine.SchemaIterator):
self.execute()
-class ANSISchemaDropper(sqlalchemy.engine.SchemaIterator):
+class ANSISchemaDropper(engine.SchemaIterator):
def visit_index(self, index):
self.append("\nDROP INDEX " + index.name)
self.execute()
@@ -660,5 +666,5 @@ class ANSISchemaDropper(sqlalchemy.engine.SchemaIterator):
self.execute()
-class ANSIDefaultRunner(sqlalchemy.engine.DefaultRunner):
+class ANSIDefaultRunner(engine.DefaultRunner):
pass