diff options
Diffstat (limited to 'lib/sqlalchemy/ansisql.py')
-rw-r--r-- | lib/sqlalchemy/ansisql.py | 65 |
1 files changed, 49 insertions, 16 deletions
diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py index 031c63328..f4b0852e6 100644 --- a/lib/sqlalchemy/ansisql.py +++ b/lib/sqlalchemy/ansisql.py @@ -9,7 +9,7 @@ in the sql module.""" from sqlalchemy import schema, sql, engine, util import sqlalchemy.engine.default as default -import string, re, sets +import string, re, sets, weakref ANSI_FUNCS = sets.ImmutableSet([ 'CURRENT_TIME', @@ -27,6 +27,10 @@ def create_engine(): return engine.ComposedSQLEngine(None, ANSIDialect()) class ANSIDialect(default.DefaultDialect): + def __init__(self, **kwargs): + super(ANSIDialect,self).__init__(**kwargs) + self._identifier_cache = weakref.WeakKeyDictionary() + def connect_args(self): return ([],{}) @@ -46,7 +50,7 @@ class ANSIDialect(default.DefaultDialect): """return an IdenfifierPreparer. This object is used to format table and column names including proper quoting and case conventions.""" - return ANSIIdentifierPreparer() + return ANSIIdentifierPreparer(self) class ANSICompiler(sql.Compiled): """default implementation of Compiled, which compiles ClauseElements into ANSI-compliant SQL strings.""" @@ -77,6 +81,7 @@ class ANSICompiler(sql.Compiled): self.positiontup = [] self.preparer = dialect.preparer() + def after_compile(self): # this re will search for params like :param # it has a negative lookbehind for an extra ':' so that it doesnt match @@ -704,8 +709,8 @@ class ANSIDefaultRunner(engine.DefaultRunner): pass class ANSIIdentifierPreparer(schema.SchemaVisitor): - """Transforms identifiers of SchemaItems into ANSI-Compliant delimited identifiers where required""" - def __init__(self, initial_quote='"', final_quote=None, omit_schema=False): + """handles quoting and case-folding of identifiers based on options""" + def __init__(self, dialect, initial_quote='"', final_quote=None, omit_schema=False): """Constructs a new ANSIIdentifierPreparer object. initial_quote - Character that begins a delimited identifier @@ -713,12 +718,12 @@ class ANSIIdentifierPreparer(schema.SchemaVisitor): omit_schema - prevent prepending schema name. useful for databases that do not support schemae """ + self.dialect = dialect self.initial_quote = initial_quote self.final_quote = final_quote or self.initial_quote self.omit_schema = omit_schema self.strings = {} self.__visited = util.Set() - def _escape_identifier(self, value): """escape an identifier. @@ -740,31 +745,59 @@ class ANSIIdentifierPreparer(schema.SchemaVisitor): # some tests would need to be rewritten if this is done. #return value.upper() - def _requires_quotes(self, value): + def _requires_quotes(self, value, natural_case): """return true if the given identifier requires quoting.""" return False - + + def __requires_quotes_cached(self, value, natural_case): + try: + return self.dialect._identifier_cache[(value, natural_case)] + except KeyError: + result = self._requires_quotes(value, natural_case) + self.dialect._identifier_cache[(value, natural_case)] = result + return result + def visit_table(self, table): if table in self.__visited: return - if table.quote or self._requires_quotes(table.name): + + # cache the results within the dialect, weakly keyed to the table + try: + (self.strings[table], self.strings[(table, 'schema')]) = self.dialect._identifier_cache[table] + return + except KeyError: + pass + + if table.quote or self._requires_quotes(table.name, table.natural_case): self.strings[table] = self._quote_identifier(table.name) else: - self.strings[table] = table.name # TODO: case folding ? + self.strings[table] = table.name if table.schema: - if table.quote_schema or self._requires_quotes(table.quote_schema): + if table.quote_schema or self._requires_quotes(table.schema, table.natural_case_schema): self.strings[(table, 'schema')] = self._quote_identifier(table.schema) else: - self.strings[(table, 'schema')] = table.schema # TODO: case folding ? - + self.strings[(table, 'schema')] = table.schema + else: + self.strings[(table,'schema')] = None + self.dialect._identifier_cache[table] = (self.strings[table], self.strings[(table, 'schema')]) + def visit_column(self, column): if column in self.__visited: return - if column.quote or self._requires_quotes(column.name): + + # cache the results within the dialect, weakly keyed to the column + try: + self.strings[column] = self.dialect._identifier_cache[column] + return + except KeyError: + pass + + if column.quote or self._requires_quotes(column.name, column.natural_case): self.strings[column] = self._quote_identifier(column.name) else: - self.strings[column] = column.name # TODO: case folding ? - + self.strings[column] = column.name + self.dialect._identifier_cache[column] = self.strings[column] + def __start_visit(self, obj): if obj in self.__visited: return @@ -774,7 +807,7 @@ class ANSIIdentifierPreparer(schema.SchemaVisitor): def __prepare_table(self, table, use_schema=False): self.__start_visit(table) - if not self.omit_schema and use_schema and (table, 'schema') in self.strings: + if not self.omit_schema and use_schema and self.strings.get((table, 'schema'), None) is not None: return self.strings[(table, 'schema')] + "." + self.strings.get(table, table.name) else: return self.strings.get(table, table.name) |