summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/ansisql.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2006-08-12 17:28:15 +0000
committerMike Bayer <mike_mp@zzzcomputing.com>2006-08-12 17:28:15 +0000
commit9e8fad2abcce364253352f042836bf58ce8f4f81 (patch)
tree5058c15280a2e56d454670deeb7a53dd8b6b1f67 /lib/sqlalchemy/ansisql.py
parentfb88b031d916ea91ce9af760a67ea27e00113c14 (diff)
downloadsqlalchemy-9e8fad2abcce364253352f042836bf58ce8f4f81.tar.gz
quoting facilities set up so that database-specific quoting can be
turned on for individual table, schema, and column identifiers when used in all queries/creates/drops. Enabled via "quote=True" in Table or Column, as well as "quote_schema=True" in Table. Thanks to Aaron Spike for his excellent efforts. [ticket:155]
Diffstat (limited to 'lib/sqlalchemy/ansisql.py')
-rw-r--r--lib/sqlalchemy/ansisql.py109
1 files changed, 90 insertions, 19 deletions
diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py
index b85f67d47..e1791324d 100644
--- a/lib/sqlalchemy/ansisql.py
+++ b/lib/sqlalchemy/ansisql.py
@@ -42,6 +42,11 @@ class ANSIDialect(default.DefaultDialect):
def compiler(self, statement, parameters, **kwargs):
return ANSICompiler(self, statement, parameters, **kwargs)
+ def preparer(self):
+ """return an IdenfifierPreparer.
+
+ This object is used to format table and column names including proper quoting and case conventions."""
+ return ANSIIdentifierPreparer()
class ANSICompiler(sql.Compiled):
"""default implementation of Compiled, which compiles ClauseElements into ANSI-compliant SQL strings."""
@@ -69,6 +74,7 @@ class ANSICompiler(sql.Compiled):
self.bindtemplate = ":%s"
self.paramstyle = dialect.paramstyle
self.positional = dialect.positional
+ self.preparer = dialect.preparer()
def after_compile(self):
# this re will search for params like :param
@@ -170,19 +176,18 @@ class ANSICompiler(sql.Compiled):
# for this column which is used to translate result set values
self.typemap.setdefault(column.name.lower(), column.type)
if column.table is None or not column.table.named_with_column():
- self.strings[column] = column.name
+ self.strings[column] = self.preparer.format_column(column)
else:
if column.table.oid_column is column:
n = self.dialect.oid_column_name()
if n is not None:
- self.strings[column] = "%s.%s" % (column.table.name, n)
+ self.strings[column] = "%s.%s" % (self.preparer.format_table(column.table, use_schema=False), n)
elif len(column.table.primary_key) != 0:
- self.strings[column] = "%s.%s" % (column.table.name, column.table.primary_key[0].name)
+ self.strings[column] = self.preparer.format_column_with_table(column.table.primary_key[0])
else:
self.strings[column] = None
else:
- self.strings[column] = "%s.%s" % (column.table.name, column.name)
-
+ self.strings[column] = self.preparer.format_column_with_table(column)
def visit_fromclause(self, fromclause):
self.froms[fromclause] = fromclause.from_name
@@ -427,7 +432,7 @@ class ANSICompiler(sql.Compiled):
return " OFFSET " + str(select.offset)
def visit_table(self, table):
- self.froms[table] = table.fullname
+ self.froms[table] = self.preparer.format_table(table)
self.strings[table] = ""
def visit_join(self, join):
@@ -501,7 +506,7 @@ class ANSICompiler(sql.Compiled):
else:
return self.get_str(p)
- text = ("INSERT INTO " + insert_stmt.table.fullname + " (" + string.join([c[0].name for c in colparams], ', ') + ")" +
+ text = ("INSERT INTO " + self.preparer.format_table(insert_stmt.table) + " (" + string.join([self.preparer.format_column(c[0]) for c in colparams], ', ') + ")" +
" VALUES (" + string.join([create_param(c[1]) for c in colparams], ', ') + ")")
self.strings[insert_stmt] = text
@@ -532,7 +537,7 @@ class ANSICompiler(sql.Compiled):
else:
return self.get_str(p)
- text = "UPDATE " + update_stmt.table.fullname + " SET " + string.join(["%s=%s" % (c[0].name, create_param(c[1])) for c in colparams], ', ')
+ text = "UPDATE " + self.preparer.format_table(update_stmt.table) + " SET " + string.join(["%s=%s" % (self.preparer.format_column(c[0]), create_param(c[1])) for c in colparams], ', ')
if update_stmt.whereclause:
text += " WHERE " + self.get_str(update_stmt.whereclause)
@@ -596,7 +601,7 @@ class ANSICompiler(sql.Compiled):
return values
def visit_delete(self, delete_stmt):
- text = "DELETE FROM " + delete_stmt.table.fullname
+ text = "DELETE FROM " + self.preparer.format_table(delete_stmt.table)
if delete_stmt.whereclause:
text += " WHERE " + self.get_str(delete_stmt.whereclause)
@@ -612,6 +617,8 @@ class ANSISchemaGenerator(engine.SchemaIterator):
super(ANSISchemaGenerator, self).__init__(engine, proxy, **params)
self.checkfirst = checkfirst
self.connection = connection
+ self.preparer = self.engine.dialect.preparer()
+
def get_column_specification(self, column, first_pk=False):
raise NotImplementedError()
@@ -622,7 +629,7 @@ class ANSISchemaGenerator(engine.SchemaIterator):
if self.checkfirst and self.engine.dialect.has_table(self.connection, table.name):
return
- self.append("\nCREATE TABLE " + table.fullname + " (")
+ self.append("\nCREATE TABLE " + self.preparer.format_table(table) + " (")
separator = "\n"
@@ -665,16 +672,16 @@ class ANSISchemaGenerator(engine.SchemaIterator):
if len(constraint) == 0:
return
self.append(", \n")
- self.append("\tPRIMARY KEY (%s)" % string.join([c.name for c in constraint],', '))
+ self.append("\tPRIMARY KEY (%s)" % string.join([self.preparer.format_column(c) for c in constraint],', '))
def visit_foreign_key_constraint(self, constraint):
self.append(", \n\t ")
if constraint.name is not None:
self.append("CONSTRAINT %s " % constraint.name)
self.append("FOREIGN KEY(%s) REFERENCES %s (%s)" % (
- string.join([f.parent.name for f in constraint.elements], ', '),
- list(constraint.elements)[0].column.table.fullname,
- string.join([f.column.name for f in constraint.elements], ', ')
+ string.join([self.preparer.format_column(f.parent) for f in constraint.elements], ', '),
+ self.preparer.format_table(list(constraint.elements)[0].column.table),
+ string.join([self.preparer.format_column(f.column) for f in constraint.elements], ', ')
))
if constraint.ondelete is not None:
self.append(" ON DELETE %s" % constraint.ondelete)
@@ -689,16 +696,16 @@ class ANSISchemaGenerator(engine.SchemaIterator):
if index.unique:
self.append('UNIQUE ')
self.append('INDEX %s ON %s (%s)' \
- % (index.name, index.table.fullname,
- string.join([c.name for c in index.columns], ', ')))
+ % (index.name, self.preparer.format_table(index.table),
+ string.join([self.preparer.format_column(c) for c in index.columns], ', ')))
self.execute()
-
class ANSISchemaDropper(engine.SchemaIterator):
def __init__(self, engine, proxy, connection=None, checkfirst=False, **params):
super(ANSISchemaDropper, self).__init__(engine, proxy, **params)
self.checkfirst = checkfirst
self.connection = connection
+ self.preparer = self.engine.dialect.preparer()
def visit_index(self, index):
self.append("\nDROP INDEX " + index.name)
@@ -709,9 +716,73 @@ class ANSISchemaDropper(engine.SchemaIterator):
# no need to drop them individually
if self.checkfirst and not self.engine.dialect.has_table(self.connection, table.name):
return
- self.append("\nDROP TABLE " + table.fullname)
+ self.append("\nDROP TABLE " + self.preparer.format_table(table))
self.execute()
-
class ANSIDefaultRunner(engine.DefaultRunner):
pass
+
+class ANSIIdentifierPreparer(object):
+ """Transforms identifiers into ANSI-Compliant delimited identifiers where required"""
+ def __init__(self, initial_quote='"', final_quote=None, omit_schema=False):
+ """Constructs a new ANSIIdentifierPreparer object.
+
+ initial_quote - Character that begins a delimited identifier
+ final_quote - Caracter that ends a delimited identifier. defaults to initial_quote.
+
+ omit_schema - prevent prepending schema name. useful for databases that do not support schemae
+ """
+ self.initial_quote = initial_quote
+ self.final_quote = final_quote or self.initial_quote
+ self.omit_schema = omit_schema
+
+ def _escape_identifier(self, value):
+ return value.replace('"', '""')
+
+ def _quote_identifier(self, value):
+ return self.initial_quote + self._escape_identifier(value) + self.final_quote
+
+ def _fold_identifier_case(self, value):
+ return value
+ # ANSI SQL calls for the case of all unquoted identifiers to be folded to UPPER.
+ # some tests would need to be rewritten if this is done.
+ #return value.upper()
+
+ def _prepare_table(self, table, use_schema=False):
+ names = []
+ if table.quote:
+ names.append(self._quote_identifier(table.name))
+ else:
+ names.append(self._fold_identifier_case(table.name))
+
+ if not self.omit_schema and use_schema and table.schema:
+ if table.quote_schema:
+ names.insert(0, self._quote_identifier(table.schema))
+ else:
+ names.insert(0, self._fold_identifier_case(table.schema))
+
+ return ".".join(names)
+
+ def _prepare_column(self, column, use_table=True, **kwargs):
+ names = []
+ if column.quote:
+ names.append(self._quote_identifier(column.name))
+ else:
+ names.append(self._fold_identifier_case(column.name))
+
+ if use_table:
+ names.insert(0, self._prepare_table(column.table, **kwargs))
+
+ return ".".join(names)
+
+ def format_table(self, table, use_schema=True):
+ """Prepare a quoted table and schema name"""
+ return self._prepare_table(table, use_schema=use_schema)
+
+ def format_column(self, column):
+ """Prepare a quoted column name"""
+ return self._prepare_column(column, use_table=False)
+
+ def format_column_with_table(self, column):
+ """Prepare a quoted column name with table name"""
+ return self._prepare_column(column)