diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2006-09-04 01:56:31 +0000 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2006-09-04 01:56:31 +0000 |
commit | f5454c89ea82966075e58458b44fe2279d70a361 (patch) | |
tree | b0bd68f4c3a6ca187289c6b0fab59c59b4fcf52a /lib/sqlalchemy/ansisql.py | |
parent | c0d89919ec871c69bbfa32ef83417be61be2291b (diff) | |
download | sqlalchemy-f5454c89ea82966075e58458b44fe2279d70a361.tar.gz |
simplification to quoting to just cache strings per-dialect, added quoting for alias and label names
fixes [ticket:294]
Diffstat (limited to 'lib/sqlalchemy/ansisql.py')
-rw-r--r-- | lib/sqlalchemy/ansisql.py | 126 |
1 files changed, 52 insertions, 74 deletions
diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py index d65e8ad33..d053f7389 100644 --- a/lib/sqlalchemy/ansisql.py +++ b/lib/sqlalchemy/ansisql.py @@ -32,10 +32,11 @@ def create_engine(): return engine.ComposedSQLEngine(None, ANSIDialect()) class ANSIDialect(default.DefaultDialect): - def __init__(self, **kwargs): + def __init__(self, cache_identifiers=True, **kwargs): super(ANSIDialect,self).__init__(**kwargs) self.identifier_preparer = self.preparer() - + self.cache_identifiers = cache_identifiers + def connect_args(self): return ([],{}) @@ -158,7 +159,7 @@ class ANSICompiler(sql.Compiled): def visit_label(self, label): if len(self.select_stack): self.typemap.setdefault(label.name.lower(), label.obj.type) - self.strings[label] = self.strings[label.obj] + " AS " + label.name + self.strings[label] = self.strings[label.obj] + " AS " + self.preparer.format_label(label) def visit_column(self, column): if len(self.select_stack): @@ -289,7 +290,7 @@ class ANSICompiler(sql.Compiled): return self.bindtemplate % name def visit_alias(self, alias): - self.froms[alias] = self.get_from_text(alias.original) + " AS " + alias.name + self.froms[alias] = self.get_from_text(alias.original) + " AS " + self.preparer.format_alias(alias) self.strings[alias] = self.get_str(alias.original) def visit_select(self, select): @@ -717,7 +718,7 @@ class ANSISchemaDropper(engine.SchemaIterator): class ANSIDefaultRunner(engine.DefaultRunner): pass -class ANSIIdentifierPreparer(schema.SchemaVisitor): +class ANSIIdentifierPreparer(object): """handles quoting and case-folding of identifiers based on options""" def __init__(self, dialect, initial_quote='"', final_quote=None, omit_schema=False): """Constructs a new ANSIIdentifierPreparer object. @@ -731,8 +732,7 @@ class ANSIIdentifierPreparer(schema.SchemaVisitor): self.initial_quote = initial_quote self.final_quote = final_quote or self.initial_quote self.omit_schema = omit_schema - self.__strings = weakref.WeakKeyDictionary() - self.__visited = weakref.WeakKeyDictionary() + self.__strings = {} def _escape_identifier(self, value): """escape an identifier. @@ -771,68 +771,24 @@ class ANSIIdentifierPreparer(schema.SchemaVisitor): or bool(len([x for x in str(value) if x not in self._legal_characters()])) \ or (case_sensitive and value.lower() != value) - def visit_table(self, table): - if table in self.__visited: - return - - if table.quote or self._requires_quotes(table.name, table.case_sensitive): - tablestring = self._quote_identifier(table.name) - else: - tablestring = table.name - - if table.schema: - if table.quote_schema or self._requires_quotes(table.schema, table.case_sensitive_schema): - schemastring = self._quote_identifier(table.schema) - else: - schemastring = table.schema - else: - schemastring = None - - self.__strings[table] = (tablestring, schemastring) - - def visit_column(self, column): - if column in self.__visited: - return - if column.quote or self._requires_quotes(column.name, column.case_sensitive): - self.__strings[column] = self._quote_identifier(column.name) - else: - self.__strings[column] = column.name - - def visit_sequence(self, sequence): - if sequence in self.__visited: - return - if sequence.quote or self._requires_quotes(sequence.name, sequence.case_sensitive): - self.__strings[sequence] = self._quote_identifier(sequence.name) - else: - self.__strings[sequence] = sequence.name - - def __analyze_identifiers(self, obj): - """insure that each object we encounter is analyzed only once for its lifetime.""" - if obj in self.__visited: - return - if isinstance(obj, schema.SchemaItem): - obj.accept_schema_visitor(self) - self.__visited[obj] = True - - def __prepare_sequence(self, sequence): - self.__analyze_identifiers(sequence) - return self.__strings.get(sequence, sequence.name) - - def __prepare_table(self, table, use_schema=False): - self.__analyze_identifiers(table) - tablename = self.__strings.get(table, (table.name, None))[0] - if not self.omit_schema and use_schema and self.__strings.get(table, (None,None))[1] is not None: - return self.__strings[table][1] + "." + tablename - else: - return tablename - - def __prepare_column(self, column, use_table=True, **kwargs): - self.__analyze_identifiers(column) - if use_table: - return self.__prepare_table(column.table, **kwargs) + "." + self.__strings.get(column, column.name) + def __generic_obj_format(self, obj, ident): + if getattr(obj, 'quote', False): + return self._quote_identifier(ident) + if self.dialect.cache_identifiers: + try: + return self.__strings[ident] + except KeyError: + if self._requires_quotes(ident, getattr(obj, 'case_sensitive', ident == ident.lower())): + self.__strings[ident] = self._quote_identifier(ident) + else: + self.__strings[ident] = ident + return self.__strings[ident] else: - return self.__strings.get(column, column.name) - + if self._requires_quotes(ident, getattr(obj, 'case_sensitive', ident == ident.lower())): + return self._quote_identifier(ident) + else: + return ident + def should_quote(self, object): return object.quote or self._requires_quotes(object.name, object.case_sensitive) @@ -840,16 +796,38 @@ class ANSIIdentifierPreparer(schema.SchemaVisitor): return object.quote or self._requires_quotes(object.name, object.case_sensitive) def format_sequence(self, sequence): - return self.__prepare_sequence(sequence) + return self.__generic_obj_format(sequence, sequence.name) + + def format_label(self, label): + return self.__generic_obj_format(label, label.name) + + def format_alias(self, alias): + return self.__generic_obj_format(alias, alias.name) def format_table(self, table, use_schema=True): """Prepare a quoted table and schema name""" - return self.__prepare_table(table, use_schema=use_schema) + result = self.__generic_obj_format(table, table.name) + if use_schema and getattr(table, "schema", None): + result = self.__generic_obj_format(table, table.schema) + "." + result + return result - def format_column(self, column): + def format_column(self, column, use_table=False): """Prepare a quoted column name """ - return self.__prepare_column(column, use_table=False) - + # TODO: isinstance alert ! get ColumnClause and Column to better + # differentiate themselves + if isinstance(column, schema.SchemaItem): + if use_table: + return self.format_table(column.table, use_schema=False) + "." + self.__generic_obj_format(column, column.name) + else: + return self.__generic_obj_format(column, column.name) + else: + # literal textual elements get stuck into ColumnClause alot, which shouldnt get quoted + if use_table: + return column.table.name + "." + column.name + else: + return column.name + def format_column_with_table(self, column): """Prepare a quoted column name with table name""" - return self.__prepare_column(column) + return self.format_column(column, use_table=True) + |