diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2006-10-15 00:07:06 +0000 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2006-10-15 00:07:06 +0000 |
commit | 7e5e985c0e17a2d300f9aa8633c3610db600f2e2 (patch) | |
tree | 553780288c3fc75697d1558187c85f09a9cb42ed /lib/sqlalchemy/ansisql.py | |
parent | 6b40f50b87a03172d77abf0e50f42b565f416645 (diff) | |
download | sqlalchemy-7e5e985c0e17a2d300f9aa8633c3610db600f2e2.tar.gz |
- ForeignKey(Constraint) supports "use_alter=True", to create/drop a foreign key
via ALTER. this allows circular foreign key relationships to be set up.
Diffstat (limited to 'lib/sqlalchemy/ansisql.py')
-rw-r--r-- | lib/sqlalchemy/ansisql.py | 58 |
1 files changed, 47 insertions, 11 deletions
diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py index 208b2f603..b6923c7da 100644 --- a/lib/sqlalchemy/ansisql.py +++ b/lib/sqlalchemy/ansisql.py @@ -606,8 +606,20 @@ class ANSICompiler(sql.Compiled): def __str__(self): return self.get_str(self.statement) - -class ANSISchemaGenerator(engine.SchemaIterator): +class ANSISchemaBase(engine.SchemaIterator): + def find_alterables(self, tables): + alterables = [] + class FindAlterables(schema.SchemaVisitor): + def visit_foreign_key_constraint(self, constraint): + if constraint.use_alter and constraint.table in tables: + alterables.append(constraint) + findalterables = FindAlterables() + for table in tables: + for c in table.constraints: + c.accept_schema_visitor(findalterables) + return alterables + +class ANSISchemaGenerator(ANSISchemaBase): def __init__(self, engine, proxy, connection, checkfirst=False, tables=None, **kwargs): super(ANSISchemaGenerator, self).__init__(engine, proxy, **kwargs) self.checkfirst = checkfirst @@ -620,11 +632,13 @@ class ANSISchemaGenerator(engine.SchemaIterator): raise NotImplementedError() def visit_metadata(self, metadata): - for table in metadata.table_iterator(reverse=False, tables=self.tables): - if self.checkfirst and self.dialect.has_table(self.connection, table.name): - continue + collection = [t for t in metadata.table_iterator(reverse=False, tables=self.tables) if (not self.checkfirst or not self.dialect.has_table(self.connection, t.name))] + for table in collection: table.accept_schema_visitor(self, traverse=False) - + if self.supports_alter(): + for alterable in self.find_alterables(collection): + self.add_foreignkey(alterable) + def visit_table(self, table): for column in table.columns: if column.default is not None: @@ -687,9 +701,22 @@ class ANSISchemaGenerator(engine.SchemaIterator): if constraint.name is not None: self.append("%s " % constraint.name) self.append("(%s)" % (string.join([self.preparer.format_column(c) for c in constraint],', '))) - + + def supports_alter(self): + return True + def visit_foreign_key_constraint(self, constraint): + if constraint.use_alter and self.supports_alter(): + return self.append(", \n\t ") + self.define_foreign_key(constraint) + + def add_foreignkey(self, constraint): + self.append("ALTER TABLE %s ADD " % self.preparer.format_table(constraint.table)) + self.define_foreign_key(constraint) + self.execute() + + def define_foreign_key(self, constraint): if constraint.name is not None: self.append("CONSTRAINT %s " % constraint.name) self.append("FOREIGN KEY(%s) REFERENCES %s (%s)" % ( @@ -721,7 +748,7 @@ class ANSISchemaGenerator(engine.SchemaIterator): string.join([self.preparer.format_column(c) for c in index.columns], ', '))) self.execute() -class ANSISchemaDropper(engine.SchemaIterator): +class ANSISchemaDropper(ANSISchemaBase): def __init__(self, engine, proxy, connection, checkfirst=False, tables=None, **kwargs): super(ANSISchemaDropper, self).__init__(engine, proxy, **kwargs) self.checkfirst = checkfirst @@ -731,14 +758,23 @@ class ANSISchemaDropper(engine.SchemaIterator): self.dialect = self.engine.dialect def visit_metadata(self, metadata): - for table in metadata.table_iterator(reverse=True, tables=self.tables): - if self.checkfirst and not self.dialect.has_table(self.connection, table.name): - continue + collection = [t for t in metadata.table_iterator(reverse=True, tables=self.tables) if (not self.checkfirst or self.dialect.has_table(self.connection, t.name))] + if self.supports_alter(): + for alterable in self.find_alterables(collection): + self.drop_foreignkey(alterable) + for table in collection: table.accept_schema_visitor(self, traverse=False) + def supports_alter(self): + return True + def visit_index(self, index): self.append("\nDROP INDEX " + index.name) self.execute() + + def drop_foreignkey(self, constraint): + self.append("ALTER TABLE %s DROP CONSTRAINT %s" % (self.preparer.format_table(constraint.table), constraint.name)) + self.execute() def visit_table(self, table): for column in table.columns: |