summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/ansisql.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/ansisql.py')
-rw-r--r--lib/sqlalchemy/ansisql.py79
1 files changed, 54 insertions, 25 deletions
diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py
index 2b0d7d17e..208b2f603 100644
--- a/lib/sqlalchemy/ansisql.py
+++ b/lib/sqlalchemy/ansisql.py
@@ -7,7 +7,7 @@
"""defines ANSI SQL operations. Contains default implementations for the abstract objects
in the sql module."""
-from sqlalchemy import schema, sql, engine, util
+from sqlalchemy import schema, sql, engine, util, sql_util
import sqlalchemy.engine.default as default
import string, re, sets, weakref
@@ -28,9 +28,6 @@ RESERVED_WORDS = util.Set(['all', 'analyse', 'analyze', 'and', 'any', 'array', '
LEGAL_CHARACTERS = util.Set(string.ascii_lowercase + string.ascii_uppercase + string.digits + '_$')
ILLEGAL_INITIAL_CHARACTERS = util.Set(string.digits + '$')
-def create_engine():
- return engine.ComposedSQLEngine(None, ANSIDialect())
-
class ANSIDialect(default.DefaultDialect):
def __init__(self, cache_identifiers=True, **kwargs):
super(ANSIDialect,self).__init__(**kwargs)
@@ -174,7 +171,7 @@ class ANSICompiler(sql.Compiled):
if n is not None:
self.strings[column] = "%s.%s" % (self.preparer.format_table(column.table, use_schema=False), n)
elif len(column.table.primary_key) != 0:
- self.strings[column] = self.preparer.format_column_with_table(column.table.primary_key[0])
+ self.strings[column] = self.preparer.format_column_with_table(list(column.table.primary_key)[0])
else:
self.strings[column] = None
else:
@@ -611,22 +608,30 @@ class ANSICompiler(sql.Compiled):
class ANSISchemaGenerator(engine.SchemaIterator):
- def __init__(self, engine, proxy, connection=None, checkfirst=False, **params):
- super(ANSISchemaGenerator, self).__init__(engine, proxy, **params)
+ def __init__(self, engine, proxy, connection, checkfirst=False, tables=None, **kwargs):
+ super(ANSISchemaGenerator, self).__init__(engine, proxy, **kwargs)
self.checkfirst = checkfirst
+ self.tables = tables and util.Set(tables) or None
self.connection = connection
self.preparer = self.engine.dialect.preparer()
-
+ self.dialect = self.engine.dialect
+
def get_column_specification(self, column, first_pk=False):
raise NotImplementedError()
-
- def visit_table(self, table):
- # the single whitespace before the "(" is significant
- # as its MySQL's method of indicating a table name and not a reserved word.
- # feel free to localize this logic to the mysql module
- if self.checkfirst and self.engine.dialect.has_table(self.connection, table.name):
- return
+
+ def visit_metadata(self, metadata):
+ for table in metadata.table_iterator(reverse=False, tables=self.tables):
+ if self.checkfirst and self.dialect.has_table(self.connection, table.name):
+ continue
+ table.accept_schema_visitor(self, traverse=False)
+ def visit_table(self, table):
+ for column in table.columns:
+ if column.default is not None:
+ column.default.accept_schema_visitor(self, traverse=False)
+ #if column.onupdate is not None:
+ # column.onupdate.accept_schema_visitor(visitor, traverse=False)
+
self.append("\nCREATE TABLE " + self.preparer.format_table(table) + " (")
separator = "\n"
@@ -639,15 +644,17 @@ class ANSISchemaGenerator(engine.SchemaIterator):
self.append("\t" + self.get_column_specification(column, first_pk=column.primary_key and not first_pk))
if column.primary_key:
first_pk = True
-
+ for constraint in column.constraints:
+ constraint.accept_schema_visitor(self, traverse=False)
+
for constraint in table.constraints:
- constraint.accept_schema_visitor(self)
+ constraint.accept_schema_visitor(self, traverse=False)
self.append("\n)%s\n\n" % self.post_create_table(table))
- self.execute()
+ self.execute()
if hasattr(table, 'indexes'):
for index in table.indexes:
- self.visit_index(index)
+ index.accept_schema_visitor(self, traverse=False)
def post_create_table(self, table):
return ''
@@ -662,10 +669,17 @@ class ANSISchemaGenerator(engine.SchemaIterator):
return None
def _compile(self, tocompile, parameters):
+ """compile the given string/parameters using this SchemaGenerator's dialect."""
compiler = self.engine.dialect.compiler(tocompile, parameters)
compiler.compile()
return compiler
+ def visit_check_constraint(self, constraint):
+ self.append(", \n\t")
+ if constraint.name is not None:
+ self.append("CONSTRAINT %s " % constraint.name)
+ self.append(" CHECK (%s)" % constraint.sqltext)
+
def visit_primary_key_constraint(self, constraint):
if len(constraint) == 0:
return
@@ -688,6 +702,13 @@ class ANSISchemaGenerator(engine.SchemaIterator):
if constraint.onupdate is not None:
self.append(" ON UPDATE %s" % constraint.onupdate)
+ def visit_unique_constraint(self, constraint):
+ self.append(", \n\t")
+ if constraint.name is not None:
+ self.append("CONSTRAINT %s " % constraint.name)
+ self.append(" UNIQUE ")
+ self.append("(%s)" % (string.join([self.preparer.format_column(c) for c in constraint],', ')))
+
def visit_column(self, column):
pass
@@ -701,21 +722,29 @@ class ANSISchemaGenerator(engine.SchemaIterator):
self.execute()
class ANSISchemaDropper(engine.SchemaIterator):
- def __init__(self, engine, proxy, connection=None, checkfirst=False, **params):
- super(ANSISchemaDropper, self).__init__(engine, proxy, **params)
+ def __init__(self, engine, proxy, connection, checkfirst=False, tables=None, **kwargs):
+ super(ANSISchemaDropper, self).__init__(engine, proxy, **kwargs)
self.checkfirst = checkfirst
+ self.tables = tables
self.connection = connection
self.preparer = self.engine.dialect.preparer()
+ self.dialect = self.engine.dialect
+
+ def visit_metadata(self, metadata):
+ for table in metadata.table_iterator(reverse=True, tables=self.tables):
+ if self.checkfirst and not self.dialect.has_table(self.connection, table.name):
+ continue
+ table.accept_schema_visitor(self, traverse=False)
def visit_index(self, index):
self.append("\nDROP INDEX " + index.name)
self.execute()
def visit_table(self, table):
- # NOTE: indexes on the table will be automatically dropped, so
- # no need to drop them individually
- if self.checkfirst and not self.engine.dialect.has_table(self.connection, table.name):
- return
+ for column in table.columns:
+ if column.default is not None:
+ column.default.accept_schema_visitor(self, traverse=False)
+
self.append("\nDROP TABLE " + self.preparer.format_table(table))
self.execute()