diff options
Diffstat (limited to 'lib/sqlalchemy/ansisql.py')
-rw-r--r-- | lib/sqlalchemy/ansisql.py | 79 |
1 files changed, 54 insertions, 25 deletions
diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py index 2b0d7d17e..208b2f603 100644 --- a/lib/sqlalchemy/ansisql.py +++ b/lib/sqlalchemy/ansisql.py @@ -7,7 +7,7 @@ """defines ANSI SQL operations. Contains default implementations for the abstract objects in the sql module.""" -from sqlalchemy import schema, sql, engine, util +from sqlalchemy import schema, sql, engine, util, sql_util import sqlalchemy.engine.default as default import string, re, sets, weakref @@ -28,9 +28,6 @@ RESERVED_WORDS = util.Set(['all', 'analyse', 'analyze', 'and', 'any', 'array', ' LEGAL_CHARACTERS = util.Set(string.ascii_lowercase + string.ascii_uppercase + string.digits + '_$') ILLEGAL_INITIAL_CHARACTERS = util.Set(string.digits + '$') -def create_engine(): - return engine.ComposedSQLEngine(None, ANSIDialect()) - class ANSIDialect(default.DefaultDialect): def __init__(self, cache_identifiers=True, **kwargs): super(ANSIDialect,self).__init__(**kwargs) @@ -174,7 +171,7 @@ class ANSICompiler(sql.Compiled): if n is not None: self.strings[column] = "%s.%s" % (self.preparer.format_table(column.table, use_schema=False), n) elif len(column.table.primary_key) != 0: - self.strings[column] = self.preparer.format_column_with_table(column.table.primary_key[0]) + self.strings[column] = self.preparer.format_column_with_table(list(column.table.primary_key)[0]) else: self.strings[column] = None else: @@ -611,22 +608,30 @@ class ANSICompiler(sql.Compiled): class ANSISchemaGenerator(engine.SchemaIterator): - def __init__(self, engine, proxy, connection=None, checkfirst=False, **params): - super(ANSISchemaGenerator, self).__init__(engine, proxy, **params) + def __init__(self, engine, proxy, connection, checkfirst=False, tables=None, **kwargs): + super(ANSISchemaGenerator, self).__init__(engine, proxy, **kwargs) self.checkfirst = checkfirst + self.tables = tables and util.Set(tables) or None self.connection = connection self.preparer = self.engine.dialect.preparer() - + self.dialect = self.engine.dialect + def get_column_specification(self, column, first_pk=False): raise NotImplementedError() - - def visit_table(self, table): - # the single whitespace before the "(" is significant - # as its MySQL's method of indicating a table name and not a reserved word. - # feel free to localize this logic to the mysql module - if self.checkfirst and self.engine.dialect.has_table(self.connection, table.name): - return + + def visit_metadata(self, metadata): + for table in metadata.table_iterator(reverse=False, tables=self.tables): + if self.checkfirst and self.dialect.has_table(self.connection, table.name): + continue + table.accept_schema_visitor(self, traverse=False) + def visit_table(self, table): + for column in table.columns: + if column.default is not None: + column.default.accept_schema_visitor(self, traverse=False) + #if column.onupdate is not None: + # column.onupdate.accept_schema_visitor(visitor, traverse=False) + self.append("\nCREATE TABLE " + self.preparer.format_table(table) + " (") separator = "\n" @@ -639,15 +644,17 @@ class ANSISchemaGenerator(engine.SchemaIterator): self.append("\t" + self.get_column_specification(column, first_pk=column.primary_key and not first_pk)) if column.primary_key: first_pk = True - + for constraint in column.constraints: + constraint.accept_schema_visitor(self, traverse=False) + for constraint in table.constraints: - constraint.accept_schema_visitor(self) + constraint.accept_schema_visitor(self, traverse=False) self.append("\n)%s\n\n" % self.post_create_table(table)) - self.execute() + self.execute() if hasattr(table, 'indexes'): for index in table.indexes: - self.visit_index(index) + index.accept_schema_visitor(self, traverse=False) def post_create_table(self, table): return '' @@ -662,10 +669,17 @@ class ANSISchemaGenerator(engine.SchemaIterator): return None def _compile(self, tocompile, parameters): + """compile the given string/parameters using this SchemaGenerator's dialect.""" compiler = self.engine.dialect.compiler(tocompile, parameters) compiler.compile() return compiler + def visit_check_constraint(self, constraint): + self.append(", \n\t") + if constraint.name is not None: + self.append("CONSTRAINT %s " % constraint.name) + self.append(" CHECK (%s)" % constraint.sqltext) + def visit_primary_key_constraint(self, constraint): if len(constraint) == 0: return @@ -688,6 +702,13 @@ class ANSISchemaGenerator(engine.SchemaIterator): if constraint.onupdate is not None: self.append(" ON UPDATE %s" % constraint.onupdate) + def visit_unique_constraint(self, constraint): + self.append(", \n\t") + if constraint.name is not None: + self.append("CONSTRAINT %s " % constraint.name) + self.append(" UNIQUE ") + self.append("(%s)" % (string.join([self.preparer.format_column(c) for c in constraint],', '))) + def visit_column(self, column): pass @@ -701,21 +722,29 @@ class ANSISchemaGenerator(engine.SchemaIterator): self.execute() class ANSISchemaDropper(engine.SchemaIterator): - def __init__(self, engine, proxy, connection=None, checkfirst=False, **params): - super(ANSISchemaDropper, self).__init__(engine, proxy, **params) + def __init__(self, engine, proxy, connection, checkfirst=False, tables=None, **kwargs): + super(ANSISchemaDropper, self).__init__(engine, proxy, **kwargs) self.checkfirst = checkfirst + self.tables = tables self.connection = connection self.preparer = self.engine.dialect.preparer() + self.dialect = self.engine.dialect + + def visit_metadata(self, metadata): + for table in metadata.table_iterator(reverse=True, tables=self.tables): + if self.checkfirst and not self.dialect.has_table(self.connection, table.name): + continue + table.accept_schema_visitor(self, traverse=False) def visit_index(self, index): self.append("\nDROP INDEX " + index.name) self.execute() def visit_table(self, table): - # NOTE: indexes on the table will be automatically dropped, so - # no need to drop them individually - if self.checkfirst and not self.engine.dialect.has_table(self.connection, table.name): - return + for column in table.columns: + if column.default is not None: + column.default.accept_schema_visitor(self, traverse=False) + self.append("\nDROP TABLE " + self.preparer.format_table(table)) self.execute() |