diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2006-08-12 17:28:15 +0000 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2006-08-12 17:28:15 +0000 |
commit | 9e8fad2abcce364253352f042836bf58ce8f4f81 (patch) | |
tree | 5058c15280a2e56d454670deeb7a53dd8b6b1f67 /lib/sqlalchemy/ansisql.py | |
parent | fb88b031d916ea91ce9af760a67ea27e00113c14 (diff) | |
download | sqlalchemy-9e8fad2abcce364253352f042836bf58ce8f4f81.tar.gz |
quoting facilities set up so that database-specific quoting can be
turned on for individual table, schema, and column identifiers when
used in all queries/creates/drops. Enabled via "quote=True" in
Table or Column, as well as "quote_schema=True" in Table. Thanks to
Aaron Spike for his excellent efforts. [ticket:155]
Diffstat (limited to 'lib/sqlalchemy/ansisql.py')
-rw-r--r-- | lib/sqlalchemy/ansisql.py | 109 |
1 files changed, 90 insertions, 19 deletions
diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py index b85f67d47..e1791324d 100644 --- a/lib/sqlalchemy/ansisql.py +++ b/lib/sqlalchemy/ansisql.py @@ -42,6 +42,11 @@ class ANSIDialect(default.DefaultDialect): def compiler(self, statement, parameters, **kwargs): return ANSICompiler(self, statement, parameters, **kwargs) + def preparer(self): + """return an IdenfifierPreparer. + + This object is used to format table and column names including proper quoting and case conventions.""" + return ANSIIdentifierPreparer() class ANSICompiler(sql.Compiled): """default implementation of Compiled, which compiles ClauseElements into ANSI-compliant SQL strings.""" @@ -69,6 +74,7 @@ class ANSICompiler(sql.Compiled): self.bindtemplate = ":%s" self.paramstyle = dialect.paramstyle self.positional = dialect.positional + self.preparer = dialect.preparer() def after_compile(self): # this re will search for params like :param @@ -170,19 +176,18 @@ class ANSICompiler(sql.Compiled): # for this column which is used to translate result set values self.typemap.setdefault(column.name.lower(), column.type) if column.table is None or not column.table.named_with_column(): - self.strings[column] = column.name + self.strings[column] = self.preparer.format_column(column) else: if column.table.oid_column is column: n = self.dialect.oid_column_name() if n is not None: - self.strings[column] = "%s.%s" % (column.table.name, n) + self.strings[column] = "%s.%s" % (self.preparer.format_table(column.table, use_schema=False), n) elif len(column.table.primary_key) != 0: - self.strings[column] = "%s.%s" % (column.table.name, column.table.primary_key[0].name) + self.strings[column] = self.preparer.format_column_with_table(column.table.primary_key[0]) else: self.strings[column] = None else: - self.strings[column] = "%s.%s" % (column.table.name, column.name) - + self.strings[column] = self.preparer.format_column_with_table(column) def visit_fromclause(self, fromclause): self.froms[fromclause] = fromclause.from_name @@ -427,7 +432,7 @@ class ANSICompiler(sql.Compiled): return " OFFSET " + str(select.offset) def visit_table(self, table): - self.froms[table] = table.fullname + self.froms[table] = self.preparer.format_table(table) self.strings[table] = "" def visit_join(self, join): @@ -501,7 +506,7 @@ class ANSICompiler(sql.Compiled): else: return self.get_str(p) - text = ("INSERT INTO " + insert_stmt.table.fullname + " (" + string.join([c[0].name for c in colparams], ', ') + ")" + + text = ("INSERT INTO " + self.preparer.format_table(insert_stmt.table) + " (" + string.join([self.preparer.format_column(c[0]) for c in colparams], ', ') + ")" + " VALUES (" + string.join([create_param(c[1]) for c in colparams], ', ') + ")") self.strings[insert_stmt] = text @@ -532,7 +537,7 @@ class ANSICompiler(sql.Compiled): else: return self.get_str(p) - text = "UPDATE " + update_stmt.table.fullname + " SET " + string.join(["%s=%s" % (c[0].name, create_param(c[1])) for c in colparams], ', ') + text = "UPDATE " + self.preparer.format_table(update_stmt.table) + " SET " + string.join(["%s=%s" % (self.preparer.format_column(c[0]), create_param(c[1])) for c in colparams], ', ') if update_stmt.whereclause: text += " WHERE " + self.get_str(update_stmt.whereclause) @@ -596,7 +601,7 @@ class ANSICompiler(sql.Compiled): return values def visit_delete(self, delete_stmt): - text = "DELETE FROM " + delete_stmt.table.fullname + text = "DELETE FROM " + self.preparer.format_table(delete_stmt.table) if delete_stmt.whereclause: text += " WHERE " + self.get_str(delete_stmt.whereclause) @@ -612,6 +617,8 @@ class ANSISchemaGenerator(engine.SchemaIterator): super(ANSISchemaGenerator, self).__init__(engine, proxy, **params) self.checkfirst = checkfirst self.connection = connection + self.preparer = self.engine.dialect.preparer() + def get_column_specification(self, column, first_pk=False): raise NotImplementedError() @@ -622,7 +629,7 @@ class ANSISchemaGenerator(engine.SchemaIterator): if self.checkfirst and self.engine.dialect.has_table(self.connection, table.name): return - self.append("\nCREATE TABLE " + table.fullname + " (") + self.append("\nCREATE TABLE " + self.preparer.format_table(table) + " (") separator = "\n" @@ -665,16 +672,16 @@ class ANSISchemaGenerator(engine.SchemaIterator): if len(constraint) == 0: return self.append(", \n") - self.append("\tPRIMARY KEY (%s)" % string.join([c.name for c in constraint],', ')) + self.append("\tPRIMARY KEY (%s)" % string.join([self.preparer.format_column(c) for c in constraint],', ')) def visit_foreign_key_constraint(self, constraint): self.append(", \n\t ") if constraint.name is not None: self.append("CONSTRAINT %s " % constraint.name) self.append("FOREIGN KEY(%s) REFERENCES %s (%s)" % ( - string.join([f.parent.name for f in constraint.elements], ', '), - list(constraint.elements)[0].column.table.fullname, - string.join([f.column.name for f in constraint.elements], ', ') + string.join([self.preparer.format_column(f.parent) for f in constraint.elements], ', '), + self.preparer.format_table(list(constraint.elements)[0].column.table), + string.join([self.preparer.format_column(f.column) for f in constraint.elements], ', ') )) if constraint.ondelete is not None: self.append(" ON DELETE %s" % constraint.ondelete) @@ -689,16 +696,16 @@ class ANSISchemaGenerator(engine.SchemaIterator): if index.unique: self.append('UNIQUE ') self.append('INDEX %s ON %s (%s)' \ - % (index.name, index.table.fullname, - string.join([c.name for c in index.columns], ', '))) + % (index.name, self.preparer.format_table(index.table), + string.join([self.preparer.format_column(c) for c in index.columns], ', '))) self.execute() - class ANSISchemaDropper(engine.SchemaIterator): def __init__(self, engine, proxy, connection=None, checkfirst=False, **params): super(ANSISchemaDropper, self).__init__(engine, proxy, **params) self.checkfirst = checkfirst self.connection = connection + self.preparer = self.engine.dialect.preparer() def visit_index(self, index): self.append("\nDROP INDEX " + index.name) @@ -709,9 +716,73 @@ class ANSISchemaDropper(engine.SchemaIterator): # no need to drop them individually if self.checkfirst and not self.engine.dialect.has_table(self.connection, table.name): return - self.append("\nDROP TABLE " + table.fullname) + self.append("\nDROP TABLE " + self.preparer.format_table(table)) self.execute() - class ANSIDefaultRunner(engine.DefaultRunner): pass + +class ANSIIdentifierPreparer(object): + """Transforms identifiers into ANSI-Compliant delimited identifiers where required""" + def __init__(self, initial_quote='"', final_quote=None, omit_schema=False): + """Constructs a new ANSIIdentifierPreparer object. + + initial_quote - Character that begins a delimited identifier + final_quote - Caracter that ends a delimited identifier. defaults to initial_quote. + + omit_schema - prevent prepending schema name. useful for databases that do not support schemae + """ + self.initial_quote = initial_quote + self.final_quote = final_quote or self.initial_quote + self.omit_schema = omit_schema + + def _escape_identifier(self, value): + return value.replace('"', '""') + + def _quote_identifier(self, value): + return self.initial_quote + self._escape_identifier(value) + self.final_quote + + def _fold_identifier_case(self, value): + return value + # ANSI SQL calls for the case of all unquoted identifiers to be folded to UPPER. + # some tests would need to be rewritten if this is done. + #return value.upper() + + def _prepare_table(self, table, use_schema=False): + names = [] + if table.quote: + names.append(self._quote_identifier(table.name)) + else: + names.append(self._fold_identifier_case(table.name)) + + if not self.omit_schema and use_schema and table.schema: + if table.quote_schema: + names.insert(0, self._quote_identifier(table.schema)) + else: + names.insert(0, self._fold_identifier_case(table.schema)) + + return ".".join(names) + + def _prepare_column(self, column, use_table=True, **kwargs): + names = [] + if column.quote: + names.append(self._quote_identifier(column.name)) + else: + names.append(self._fold_identifier_case(column.name)) + + if use_table: + names.insert(0, self._prepare_table(column.table, **kwargs)) + + return ".".join(names) + + def format_table(self, table, use_schema=True): + """Prepare a quoted table and schema name""" + return self._prepare_table(table, use_schema=use_schema) + + def format_column(self, column): + """Prepare a quoted column name""" + return self._prepare_column(column, use_table=False) + + def format_column_with_table(self, column): + """Prepare a quoted column name with table name""" + return self._prepare_column(column) |