diff options
Diffstat (limited to 'lib/sqlalchemy/ansisql.py')
-rw-r--r-- | lib/sqlalchemy/ansisql.py | 48 |
1 files changed, 21 insertions, 27 deletions
diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py index a75263d91..03053b998 100644 --- a/lib/sqlalchemy/ansisql.py +++ b/lib/sqlalchemy/ansisql.py @@ -49,14 +49,11 @@ class ANSIDialect(default.DefaultDialect): def create_connect_args(self): return ([],{}) - def dbapi(self): - return None + def schemagenerator(self, *args, **kwargs): + return ANSISchemaGenerator(self, *args, **kwargs) - def schemagenerator(self, *args, **params): - return ANSISchemaGenerator(*args, **params) - - def schemadropper(self, *args, **params): - return ANSISchemaDropper(*args, **params) + def schemadropper(self, *args, **kwargs): + return ANSISchemaDropper(self, *args, **kwargs) def compiler(self, statement, parameters, **kwargs): return ANSICompiler(self, statement, parameters, **kwargs) @@ -97,6 +94,9 @@ class ANSICompiler(sql.Compiled): sql.Compiled.__init__(self, dialect, statement, parameters, **kwargs) + # if we are insert/update. set to true when we visit an INSERT or UPDATE + self.isinsert = self.isupdate = False + # a dictionary of bind parameter keys to _BindParamClause instances. self.binds = {} @@ -789,13 +789,12 @@ class ANSISchemaBase(engine.SchemaIterator): return alterables class ANSISchemaGenerator(ANSISchemaBase): - def __init__(self, engine, proxy, connection, checkfirst=False, tables=None, **kwargs): - super(ANSISchemaGenerator, self).__init__(engine, proxy, **kwargs) + def __init__(self, dialect, connection, checkfirst=False, tables=None, **kwargs): + super(ANSISchemaGenerator, self).__init__(connection, **kwargs) self.checkfirst = checkfirst self.tables = tables and util.Set(tables) or None - self.connection = connection - self.preparer = self.engine.dialect.preparer() - self.dialect = self.engine.dialect + self.preparer = dialect.preparer() + self.dialect = dialect def get_column_specification(self, column, first_pk=False): raise NotImplementedError() @@ -804,7 +803,7 @@ class ANSISchemaGenerator(ANSISchemaBase): collection = [t for t in metadata.table_iterator(reverse=False, tables=self.tables) if (not self.checkfirst or not self.dialect.has_table(self.connection, t.name, schema=t.schema))] for table in collection: table.accept_visitor(self) - if self.supports_alter(): + if self.dialect.supports_alter(): for alterable in self.find_alterables(collection): self.add_foreignkey(alterable) @@ -857,7 +856,7 @@ class ANSISchemaGenerator(ANSISchemaBase): def _compile(self, tocompile, parameters): """compile the given string/parameters using this SchemaGenerator's dialect.""" - compiler = self.engine.dialect.compiler(tocompile, parameters) + compiler = self.dialect.compiler(tocompile, parameters) compiler.compile() return compiler @@ -880,11 +879,8 @@ class ANSISchemaGenerator(ANSISchemaBase): self.append("PRIMARY KEY ") self.append("(%s)" % (string.join([self.preparer.format_column(c) for c in constraint],', '))) - def supports_alter(self): - return True - def visit_foreign_key_constraint(self, constraint): - if constraint.use_alter and self.supports_alter(): + if constraint.use_alter and self.dialect.supports_alter(): return self.append(", \n\t ") self.define_foreign_key(constraint) @@ -927,25 +923,21 @@ class ANSISchemaGenerator(ANSISchemaBase): self.execute() class ANSISchemaDropper(ANSISchemaBase): - def __init__(self, engine, proxy, connection, checkfirst=False, tables=None, **kwargs): - super(ANSISchemaDropper, self).__init__(engine, proxy, **kwargs) + def __init__(self, dialect, connection, checkfirst=False, tables=None, **kwargs): + super(ANSISchemaDropper, self).__init__(connection, **kwargs) self.checkfirst = checkfirst self.tables = tables - self.connection = connection - self.preparer = self.engine.dialect.preparer() - self.dialect = self.engine.dialect + self.preparer = dialect.preparer() + self.dialect = dialect def visit_metadata(self, metadata): collection = [t for t in metadata.table_iterator(reverse=True, tables=self.tables) if (not self.checkfirst or self.dialect.has_table(self.connection, t.name, schema=t.schema))] - if self.supports_alter(): + if self.dialect.supports_alter(): for alterable in self.find_alterables(collection): self.drop_foreignkey(alterable) for table in collection: table.accept_visitor(self) - def supports_alter(self): - return True - def visit_index(self, index): self.append("\nDROP INDEX " + index.name) self.execute() @@ -1099,3 +1091,5 @@ class ANSIIdentifierPreparer(object): """Prepare a quoted column name with table name.""" return self.format_column(column, use_table=True, name=column_name) + +dialect = ANSIDialect |