summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/ansisql.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/ansisql.py')
-rw-r--r--lib/sqlalchemy/ansisql.py65
1 files changed, 49 insertions, 16 deletions
diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py
index 031c63328..f4b0852e6 100644
--- a/lib/sqlalchemy/ansisql.py
+++ b/lib/sqlalchemy/ansisql.py
@@ -9,7 +9,7 @@ in the sql module."""
from sqlalchemy import schema, sql, engine, util
import sqlalchemy.engine.default as default
-import string, re, sets
+import string, re, sets, weakref
ANSI_FUNCS = sets.ImmutableSet([
'CURRENT_TIME',
@@ -27,6 +27,10 @@ def create_engine():
return engine.ComposedSQLEngine(None, ANSIDialect())
class ANSIDialect(default.DefaultDialect):
+ def __init__(self, **kwargs):
+ super(ANSIDialect,self).__init__(**kwargs)
+ self._identifier_cache = weakref.WeakKeyDictionary()
+
def connect_args(self):
return ([],{})
@@ -46,7 +50,7 @@ class ANSIDialect(default.DefaultDialect):
"""return an IdenfifierPreparer.
This object is used to format table and column names including proper quoting and case conventions."""
- return ANSIIdentifierPreparer()
+ return ANSIIdentifierPreparer(self)
class ANSICompiler(sql.Compiled):
"""default implementation of Compiled, which compiles ClauseElements into ANSI-compliant SQL strings."""
@@ -77,6 +81,7 @@ class ANSICompiler(sql.Compiled):
self.positiontup = []
self.preparer = dialect.preparer()
+
def after_compile(self):
# this re will search for params like :param
# it has a negative lookbehind for an extra ':' so that it doesnt match
@@ -704,8 +709,8 @@ class ANSIDefaultRunner(engine.DefaultRunner):
pass
class ANSIIdentifierPreparer(schema.SchemaVisitor):
- """Transforms identifiers of SchemaItems into ANSI-Compliant delimited identifiers where required"""
- def __init__(self, initial_quote='"', final_quote=None, omit_schema=False):
+ """handles quoting and case-folding of identifiers based on options"""
+ def __init__(self, dialect, initial_quote='"', final_quote=None, omit_schema=False):
"""Constructs a new ANSIIdentifierPreparer object.
initial_quote - Character that begins a delimited identifier
@@ -713,12 +718,12 @@ class ANSIIdentifierPreparer(schema.SchemaVisitor):
omit_schema - prevent prepending schema name. useful for databases that do not support schemae
"""
+ self.dialect = dialect
self.initial_quote = initial_quote
self.final_quote = final_quote or self.initial_quote
self.omit_schema = omit_schema
self.strings = {}
self.__visited = util.Set()
-
def _escape_identifier(self, value):
"""escape an identifier.
@@ -740,31 +745,59 @@ class ANSIIdentifierPreparer(schema.SchemaVisitor):
# some tests would need to be rewritten if this is done.
#return value.upper()
- def _requires_quotes(self, value):
+ def _requires_quotes(self, value, natural_case):
"""return true if the given identifier requires quoting."""
return False
-
+
+ def __requires_quotes_cached(self, value, natural_case):
+ try:
+ return self.dialect._identifier_cache[(value, natural_case)]
+ except KeyError:
+ result = self._requires_quotes(value, natural_case)
+ self.dialect._identifier_cache[(value, natural_case)] = result
+ return result
+
def visit_table(self, table):
if table in self.__visited:
return
- if table.quote or self._requires_quotes(table.name):
+
+ # cache the results within the dialect, weakly keyed to the table
+ try:
+ (self.strings[table], self.strings[(table, 'schema')]) = self.dialect._identifier_cache[table]
+ return
+ except KeyError:
+ pass
+
+ if table.quote or self._requires_quotes(table.name, table.natural_case):
self.strings[table] = self._quote_identifier(table.name)
else:
- self.strings[table] = table.name # TODO: case folding ?
+ self.strings[table] = table.name
if table.schema:
- if table.quote_schema or self._requires_quotes(table.quote_schema):
+ if table.quote_schema or self._requires_quotes(table.schema, table.natural_case_schema):
self.strings[(table, 'schema')] = self._quote_identifier(table.schema)
else:
- self.strings[(table, 'schema')] = table.schema # TODO: case folding ?
-
+ self.strings[(table, 'schema')] = table.schema
+ else:
+ self.strings[(table,'schema')] = None
+ self.dialect._identifier_cache[table] = (self.strings[table], self.strings[(table, 'schema')])
+
def visit_column(self, column):
if column in self.__visited:
return
- if column.quote or self._requires_quotes(column.name):
+
+ # cache the results within the dialect, weakly keyed to the column
+ try:
+ self.strings[column] = self.dialect._identifier_cache[column]
+ return
+ except KeyError:
+ pass
+
+ if column.quote or self._requires_quotes(column.name, column.natural_case):
self.strings[column] = self._quote_identifier(column.name)
else:
- self.strings[column] = column.name # TODO: case folding ?
-
+ self.strings[column] = column.name
+ self.dialect._identifier_cache[column] = self.strings[column]
+
def __start_visit(self, obj):
if obj in self.__visited:
return
@@ -774,7 +807,7 @@ class ANSIIdentifierPreparer(schema.SchemaVisitor):
def __prepare_table(self, table, use_schema=False):
self.__start_visit(table)
- if not self.omit_schema and use_schema and (table, 'schema') in self.strings:
+ if not self.omit_schema and use_schema and self.strings.get((table, 'schema'), None) is not None:
return self.strings[(table, 'schema')] + "." + self.strings.get(table, table.name)
else:
return self.strings.get(table, table.name)