diff options
Diffstat (limited to 'lib/sqlalchemy/ansisql.py')
-rw-r--r-- | lib/sqlalchemy/ansisql.py | 77 |
1 files changed, 53 insertions, 24 deletions
diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py index 050e605eb..a75263d91 100644 --- a/lib/sqlalchemy/ansisql.py +++ b/lib/sqlalchemy/ansisql.py @@ -132,6 +132,11 @@ class ANSICompiler(sql.Compiled): # a dictionary of select columns labels mapped to their "generated" label self.column_labels = {} + # a dictionary of ClauseElement subclasses to counters, which are used to + # generate truncated identifier names or "anonymous" identifiers such as + # for aliases + self.generated_ids = {} + # True if this compiled represents an INSERT self.isinsert = False @@ -242,24 +247,27 @@ class ANSICompiler(sql.Compiled): return "" def visit_label(self, label): - labelname = label.name - if len(labelname) >= self.dialect.max_identifier_length(): - labelname = labelname[0:self.dialect.max_identifier_length() - 6] + "_" + hex(random.randint(0, 65535))[2:] + labelname = self._truncated_identifier("colident", label.name) if len(self.select_stack): self.typemap.setdefault(labelname.lower(), label.obj.type) if isinstance(label.obj, sql._ColumnClause): - self.column_labels[label.obj._label] = labelname.lower() + self.column_labels[label.obj._label] = labelname self.strings[label] = self.strings[label.obj] + " AS " + self.preparer.format_label(label, labelname) def visit_column(self, column): - if len(self.select_stack): - # if we are within a visit to a Select, set up the "typemap" - # for this column which is used to translate result set values - self.typemap.setdefault(column.name.lower(), column.type) - self.column_labels.setdefault(column._label, column.name.lower()) + # there is actually somewhat of a ruleset when you would *not* necessarily + # want to truncate a column identifier, if its mapped to the name of a + # physical column. but thats very hard to identify at this point, and + # the identifier length should be greater than the id lengths of any physical + # columns so should not matter. + if not column.is_literal: + name = self._truncated_identifier("colident", column.name) + else: + name = column.name + if column.table is None or not column.table.named_with_column(): - self.strings[column] = self.preparer.format_column(column) + self.strings[column] = self.preparer.format_column(column, name=name) else: if column.table.oid_column is column: n = self.dialect.oid_column_name(column) @@ -270,7 +278,13 @@ class ANSICompiler(sql.Compiled): else: self.strings[column] = None else: - self.strings[column] = self.preparer.format_column_with_table(column) + self.strings[column] = self.preparer.format_column_with_table(column, column_name=name) + + if len(self.select_stack): + # if we are within a visit to a Select, set up the "typemap" + # for this column which is used to translate result set values + self.typemap.setdefault(name.lower(), column.type) + self.column_labels.setdefault(column._label, name.lower()) def visit_fromclause(self, fromclause): self.froms[fromclause] = fromclause.name @@ -394,11 +408,23 @@ class ANSICompiler(sql.Compiled): bind_name = bindparam.key if len(bind_name) >= self.dialect.max_identifier_length(): - bind_name = bind_name[0:self.dialect.max_identifier_length() - 6] + "_" + hex(random.randint(0, 65535))[2:] + bind_name = self._truncated_identifier("bindparam", bind_name) # add to bind_names for translation self.bind_names[bindparam] = bind_name return bind_name - + + def _truncated_identifier(self, ident_class, name): + if (ident_class, name) in self.generated_ids: + return self.generated_ids[(ident_class, name)] + if len(name) >= self.dialect.max_identifier_length(): + counter = self.generated_ids.get(ident_class, 1) + truncname = name[0:self.dialect.max_identifier_length() - 6] + "_" + hex(counter)[2:] + self.generated_ids[ident_class] = counter + 1 + else: + truncname = name + self.generated_ids[(ident_class, name)] = truncname + return truncname + def bindparam_string(self, name): return self.bindtemplate % name @@ -1043,30 +1069,33 @@ class ANSIIdentifierPreparer(object): def format_alias(self, alias): return self.__generic_obj_format(alias, alias.name) - def format_table(self, table, use_schema=True): + def format_table(self, table, use_schema=True, name=None): """Prepare a quoted table and schema name.""" - result = self.__generic_obj_format(table, table.name) + if name is None: + name = table.name + result = self.__generic_obj_format(table, name) if use_schema and getattr(table, "schema", None): result = self.__generic_obj_format(table, table.schema) + "." + result return result - def format_column(self, column, use_table=False): + def format_column(self, column, use_table=False, name=None): """Prepare a quoted column name.""" - + if name is None: + name = column.name if not getattr(column, 'is_literal', False): if use_table: - return self.format_table(column.table, use_schema=False) + "." + self.__generic_obj_format(column, column.name) + return self.format_table(column.table, use_schema=False) + "." + self.__generic_obj_format(column, name) else: - return self.__generic_obj_format(column, column.name) + return self.__generic_obj_format(column, name) else: # literal textual elements get stuck into ColumnClause alot, which shouldnt get quoted if use_table: - return self.format_table(column.table, use_schema=False) + "." + column.name + return self.format_table(column.table, use_schema=False) + "." + name else: - return column.name + return name - def format_column_with_table(self, column): + def format_column_with_table(self, column, column_name=None): """Prepare a quoted column name with table name.""" - - return self.format_column(column, use_table=True) + + return self.format_column(column, use_table=True, name=column_name) |