diff options
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 98 |
1 files changed, 43 insertions, 55 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 59eb3cdb3..59964178c 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -421,9 +421,9 @@ class DefaultCompiler(engine.Compiled, visitors.ClauseVisitor): anonname = ANONYMOUS_LABEL.sub(self._process_anon, name) - if len(anonname) > self.dialect.max_identifier_length(): + if len(anonname) > self.dialect.max_identifier_length: counter = self.generated_ids.get(ident_class, 1) - truncname = name[0:self.dialect.max_identifier_length() - 6] + "_" + hex(counter)[2:] + truncname = name[0:self.dialect.max_identifier_length - 6] + "_" + hex(counter)[2:] self.generated_ids[ident_class] = counter + 1 else: truncname = anonname @@ -515,7 +515,6 @@ class DefaultCompiler(engine.Compiled, visitors.ClauseVisitor): l = co.label(labelname) inner_columns.add(self.process(l)) else: - self.traverse(co) inner_columns.add(self.process(co)) else: l = self.label_select_column(select, co) @@ -620,20 +619,16 @@ class DefaultCompiler(engine.Compiled, visitors.ClauseVisitor): # for inserts, this includes Python-side defaults, columns with sequences for dialects # that support sequences, and primary key columns for dialects that explicitly insert # pre-generated primary key values - required_cols = util.Set() - class DefaultVisitor(schema.SchemaVisitor): - def visit_column(s, cd): - if c.primary_key and self.uses_sequences_for_inserts(): - required_cols.add(c) - def visit_column_default(s, cd): - required_cols.add(c) - def visit_sequence(s, seq): - if self.uses_sequences_for_inserts(): - required_cols.add(c) - vis = DefaultVisitor() - for c in insert_stmt.table.c: - if (isinstance(c, schema.SchemaItem) and (self.parameters is None or self.parameters.get(c.key, None) is None)): - vis.traverse(c) + required_cols = [ + c for c in insert_stmt.table.c + if \ + isinstance(c, schema.SchemaItem) and \ + (self.parameters is None or self.parameters.get(c.key, None) is None) and \ + ( + ((c.primary_key or isinstance(c.default, schema.Sequence)) and self.uses_sequences_for_inserts()) or + isinstance(c.default, schema.ColumnDefault) + ) + ] self.isinsert = True colparams = self._get_colparams(insert_stmt, required_cols) @@ -646,14 +641,12 @@ class DefaultCompiler(engine.Compiled, visitors.ClauseVisitor): # search for columns who will be required to have an explicit bound value. # for updates, this includes Python-side "onupdate" defaults. - required_cols = util.Set() - class OnUpdateVisitor(schema.SchemaVisitor): - def visit_column_onupdate(s, cd): - required_cols.add(c) - vis = OnUpdateVisitor() - for c in update_stmt.table.c: - if (isinstance(c, schema.SchemaItem) and (self.parameters is None or self.parameters.get(c.key, None) is None)): - vis.traverse(c) + required_cols = [c for c in update_stmt.table.c + if + isinstance(c, schema.SchemaItem) and \ + (self.parameters is None or self.parameters.get(c.key, None) is None) and + isinstance(c.onupdate, schema.ColumnDefault) + ] self.isupdate = True colparams = self._get_colparams(update_stmt, required_cols) @@ -681,11 +674,6 @@ class DefaultCompiler(engine.Compiled, visitors.ClauseVisitor): self.binds[col.key] = bindparam return self.bindparam_string(self._truncate_bindparam(bindparam)) - def create_clause_param(col, value): - self.traverse(value) - self.inline_params.add(col) - return self.process(value) - self.inline_params = util.Set() def to_col(key): @@ -704,25 +692,28 @@ class DefaultCompiler(engine.Compiled, visitors.ClauseVisitor): if self.parameters is None: parameters = {} else: - parameters = dict([(to_col(k), v) for k, v in self.parameters.iteritems()]) + parameters = dict([(getattr(k, 'key', k), v) for k, v in self.parameters.iteritems()]) if stmt.parameters is not None: for k, v in stmt.parameters.iteritems(): - parameters.setdefault(to_col(k), v) + parameters.setdefault(getattr(k, 'key', k), v) for col in required_cols: - parameters.setdefault(col, None) + parameters.setdefault(col.key, None) # create a list of column assignment clauses as tuples values = [] for c in stmt.table.columns: - if c in parameters: - value = parameters[c] - if sql._is_literal(value): - value = create_bind_param(c, value) - else: - value = create_clause_param(c, value) - values.append((c, value)) + if c.key in parameters: + value = parameters[c.key] + else: + continue + if sql._is_literal(value): + value = create_bind_param(c, value) + else: + self.inline_params.add(c) + value = self.process(value) + values.append((c, value)) return values @@ -778,7 +769,7 @@ class SchemaGenerator(DDLBase): collection = [t for t in metadata.table_iterator(reverse=False, tables=self.tables) if (not self.checkfirst or not self.dialect.has_table(self.connection, t.name, schema=t.schema))] for table in collection: self.traverse_single(table) - if self.dialect.supports_alter(): + if self.dialect.supports_alter: for alterable in self.find_alterables(collection): self.add_foreignkey(alterable) @@ -853,7 +844,7 @@ class SchemaGenerator(DDLBase): self.append("(%s)" % ', '.join([self.preparer.format_column(c) for c in constraint])) def visit_foreign_key_constraint(self, constraint): - if constraint.use_alter and self.dialect.supports_alter(): + if constraint.use_alter and self.dialect.supports_alter: return self.append(", \n\t ") self.define_foreign_key(constraint) @@ -909,7 +900,7 @@ class SchemaDropper(DDLBase): def visit_metadata(self, metadata): collection = [t for t in metadata.table_iterator(reverse=True, tables=self.tables) if (not self.checkfirst or self.dialect.has_table(self.connection, t.name, schema=t.schema))] - if self.dialect.supports_alter(): + if self.dialect.supports_alter: for alterable in self.find_alterables(collection): self.drop_foreignkey(alterable) for table in collection: @@ -936,6 +927,12 @@ class SchemaDropper(DDLBase): class IdentifierPreparer(object): """Handle quoting and case-folding of identifiers based on options.""" + reserved_words = RESERVED_WORDS + + legal_characters = LEGAL_CHARACTERS + + illegal_initial_characters = ILLEGAL_INITIAL_CHARACTERS + def __init__(self, dialect, initial_quote='"', final_quote=None, omit_schema=False): """Construct a new ``IdentifierPreparer`` object. @@ -995,21 +992,12 @@ class IdentifierPreparer(object): # some tests would need to be rewritten if this is done. #return value.upper() - def _reserved_words(self): - return RESERVED_WORDS - - def _legal_characters(self): - return LEGAL_CHARACTERS - - def _illegal_initial_characters(self): - return ILLEGAL_INITIAL_CHARACTERS - def _requires_quotes(self, value): """Return True if the given identifier requires quoting.""" return \ - value in self._reserved_words() \ - or (value[0] in self._illegal_initial_characters()) \ - or bool(len([x for x in unicode(value) if x not in self._legal_characters()])) \ + value in self.reserved_words \ + or (value[0] in self.illegal_initial_characters) \ + or bool(len([x for x in unicode(value) if x not in self.legal_characters])) \ or (value.lower() != value) def __generic_obj_format(self, obj, ident): |