diff options
Diffstat (limited to 'lib/sqlalchemy/ansisql.py')
-rw-r--r-- | lib/sqlalchemy/ansisql.py | 34 |
1 files changed, 26 insertions, 8 deletions
diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py index 78017bc91..5d01e275c 100644 --- a/lib/sqlalchemy/ansisql.py +++ b/lib/sqlalchemy/ansisql.py @@ -602,7 +602,7 @@ class ANSICompiler(sql.Compiled): class ANSISchemaGenerator(engine.SchemaIterator): - def get_column_specification(self, column, override_pk=False, first_pk=False): + def get_column_specification(self, column, first_pk=False): raise NotImplementedError() def visit_table(self, table): @@ -614,19 +614,17 @@ class ANSISchemaGenerator(engine.SchemaIterator): separator = "\n" # if only one primary key, specify it along with the column - pks = table.primary_key first_pk = False for column in table.columns: self.append(separator) separator = ", \n" - self.append("\t" + self.get_column_specification(column, override_pk=len(pks)>1, first_pk=column.primary_key and not first_pk)) + self.append("\t" + self.get_column_specification(column, first_pk=column.primary_key and not first_pk)) if column.primary_key: first_pk = True - # if multiple primary keys, specify it at the bottom - if len(pks) > 1: - self.append(", \n") - self.append("\tPRIMARY KEY (%s)" % string.join([c.name for c in pks],', ')) - + + for constraint in table.constraints: + constraint.accept_schema_visitor(self) + self.append("\n)%s\n\n" % self.post_create_table(table)) self.execute() if hasattr(table, 'indexes'): @@ -650,6 +648,26 @@ class ANSISchemaGenerator(engine.SchemaIterator): compiler.compile() return compiler + def visit_primary_key_constraint(self, constraint): + if len(constraint) == 0: + return + self.append(", \n") + self.append("\tPRIMARY KEY (%s)" % string.join([c.name for c in constraint],', ')) + + def visit_foreign_key_constraint(self, constraint): + self.append(", \n\t ") + if constraint.name is not None: + self.append("CONSTRAINT %s " % constraint.name) + self.append("FOREIGN KEY(%s) REFERENCES %s (%s)" % ( + string.join([f.parent.name for f in constraint.elements], ', '), + list(constraint.elements)[0].column.table.name, + string.join([f.column.name for f in constraint.elements], ', ') + )) + if constraint.ondelete is not None: + self.append(" ON DELETE %s" % constraint.ondelete) + if constraint.onupdate is not None: + self.append(" ON UPDATE %s" % constraint.onupdate) + def visit_column(self, column): pass |