diff options
Diffstat (limited to 'alembic/ddl')
-rw-r--r-- | alembic/ddl/mysql.py | 113 |
1 files changed, 88 insertions, 25 deletions
diff --git a/alembic/ddl/mysql.py b/alembic/ddl/mysql.py index 69b32b2..3d954f6 100644 --- a/alembic/ddl/mysql.py +++ b/alembic/ddl/mysql.py @@ -6,7 +6,8 @@ from ..compat import string_types from .. import util from .impl import DefaultImpl from .base import ColumnNullable, ColumnName, ColumnDefault, \ - ColumnType, AlterColumn, format_column_name + ColumnType, AlterColumn, format_column_name, \ + format_server_default from .base import alter_table class MySQLImpl(DefaultImpl): @@ -26,22 +27,49 @@ class MySQLImpl(DefaultImpl): existing_nullable=None, existing_autoincrement=None ): - self._exec( - MySQLAlterColumn( - table_name, column_name, - schema=schema, - newname=name if name is not None else column_name, - nullable=nullable if nullable is not None else - existing_nullable - if existing_nullable is not None - else True, - type_=type_ if type_ is not None else existing_type, - default=server_default if server_default is not False - else existing_server_default, - autoincrement=autoincrement if autoincrement is not None - else existing_autoincrement + if name is not None: + self._exec( + MySQLChangeColumn( + table_name, column_name, + schema=schema, + newname=name, + nullable=nullable if nullable is not None else + existing_nullable + if existing_nullable is not None + else True, + type_=type_ if type_ is not None else existing_type, + default=server_default if server_default is not False + else existing_server_default, + autoincrement=autoincrement if autoincrement is not None + else existing_autoincrement + ) + ) + elif nullable is not None or \ + type_ is not None or \ + autoincrement is not None: + self._exec( + MySQLModifyColumn( + table_name, column_name, + schema=schema, + newname=name if name is not None else column_name, + nullable=nullable if nullable is not None else + existing_nullable + if existing_nullable is not None + else True, + type_=type_ if type_ is not None else existing_type, + default=server_default if server_default is not False + else existing_server_default, + autoincrement=autoincrement if autoincrement is not None + else existing_autoincrement + ) + ) + elif server_default is not False: + self._exec( + MySQLAlterDefault( + table_name, column_name, server_default, + schema=schema, + ) ) - ) def correct_for_autogen_constraints(self, conn_unique_constraints, conn_indexes, metadata_unique_constraints, @@ -53,7 +81,14 @@ class MySQLImpl(DefaultImpl): conn_indexes.remove(idx) -class MySQLAlterColumn(AlterColumn): +class MySQLAlterDefault(AlterColumn): + def __init__(self, name, column_name, default, schema=None): + super(AlterColumn, self).__init__(name, schema=schema) + self.column_name = column_name + self.default = default + + +class MySQLChangeColumn(AlterColumn): def __init__(self, name, column_name, schema=None, newname=None, type_=None, @@ -68,12 +103,16 @@ class MySQLAlterColumn(AlterColumn): self.autoincrement = autoincrement if type_ is None: raise util.CommandError( - "All MySQL ALTER COLUMN operations " + "All MySQL CHANGE/MODIFY COLUMN operations " "require the existing type." ) self.type_ = sqltypes.to_instance(type_) +class MySQLModifyColumn(MySQLChangeColumn): + pass + + @compiles(ColumnNullable, 'mysql') @compiles(ColumnName, 'mysql') @compiles(ColumnDefault, 'mysql') @@ -84,14 +123,39 @@ def _mysql_doesnt_support_individual(element, compiler, **kw): ) -@compiles(MySQLAlterColumn, "mysql") -def _mysql_alter_column(element, compiler, **kw): - return "%s CHANGE %s %s" % ( +@compiles(MySQLAlterDefault, "mysql") +def _mysql_alter_default(element, compiler, **kw): + return "%s ALTER COLUMN %s %s" % ( + alter_table(compiler, element.table_name, element.schema), + format_column_name(compiler, element.column_name), + "SET DEFAULT %s" % format_server_default(compiler, element.default) + if element.default is not None + else "DROP DEFAULT" + ) + +@compiles(MySQLModifyColumn, "mysql") +def _mysql_modify_column(element, compiler, **kw): + return "%s MODIFY %s %s" % ( + alter_table(compiler, element.table_name, element.schema), + format_column_name(compiler, element.column_name), + _mysql_colspec( + compiler, + nullable=element.nullable, + server_default=element.default, + type_=element.type_, + autoincrement=element.autoincrement + ), + ) + + +@compiles(MySQLChangeColumn, "mysql") +def _mysql_change_column(element, compiler, **kw): + return "%s CHANGE %s %s %s" % ( alter_table(compiler, element.table_name, element.schema), format_column_name(compiler, element.column_name), + format_column_name(compiler, element.newname), _mysql_colspec( compiler, - name=element.newname, nullable=element.nullable, server_default=element.default, type_=element.type_, @@ -105,10 +169,9 @@ def _render_value(compiler, expr): else: return compiler.sql_compiler.process(expr) -def _mysql_colspec(compiler, name, nullable, server_default, type_, +def _mysql_colspec(compiler, nullable, server_default, type_, autoincrement): - spec = "%s %s %s" % ( - format_column_name(compiler, name), + spec = "%s %s" % ( compiler.dialect.type_compiler.process(type_), "NULL" if nullable else "NOT NULL" ) |