diff options
Diffstat (limited to 'alembic/operations/batch.py')
-rw-r--r-- | alembic/operations/batch.py | 287 |
1 files changed, 287 insertions, 0 deletions
diff --git a/alembic/operations/batch.py b/alembic/operations/batch.py new file mode 100644 index 0000000..726df78 --- /dev/null +++ b/alembic/operations/batch.py @@ -0,0 +1,287 @@ +from sqlalchemy import Table, MetaData, Index, select, Column, \ + ForeignKeyConstraint, cast +from sqlalchemy import types as sqltypes +from sqlalchemy import schema as sql_schema +from sqlalchemy.util import OrderedDict +from .. import util +from ..util.sqla_compat import _columns_for_constraint, _is_type_bound + + +class BatchOperationsImpl(object): + def __init__(self, operations, table_name, schema, recreate, + copy_from, table_args, table_kwargs, + reflect_args, reflect_kwargs, naming_convention): + if not util.sqla_08: + raise NotImplementedError( + "batch mode requires SQLAlchemy 0.8 or greater.") + self.operations = operations + self.table_name = table_name + self.schema = schema + if recreate not in ('auto', 'always', 'never'): + raise ValueError( + "recreate may be one of 'auto', 'always', or 'never'.") + self.recreate = recreate + self.copy_from = copy_from + self.table_args = table_args + self.table_kwargs = table_kwargs + self.reflect_args = reflect_args + self.reflect_kwargs = reflect_kwargs + self.naming_convention = naming_convention + self.batch = [] + + @property + def dialect(self): + return self.operations.impl.dialect + + @property + def impl(self): + return self.operations.impl + + def _should_recreate(self): + if self.recreate == 'auto': + return self.operations.impl.requires_recreate_in_batch(self) + elif self.recreate == 'always': + return True + else: + return False + + def flush(self): + should_recreate = self._should_recreate() + + if not should_recreate: + for opname, arg, kw in self.batch: + fn = getattr(self.operations.impl, opname) + fn(*arg, **kw) + else: + if self.naming_convention: + m1 = MetaData(naming_convention=self.naming_convention) + else: + m1 = MetaData() + + if self.copy_from is not None: + existing_table = self.copy_from + else: + existing_table = Table( + self.table_name, m1, + schema=self.schema, + autoload=True, + autoload_with=self.operations.get_bind(), + *self.reflect_args, **self.reflect_kwargs) + + batch_impl = ApplyBatchImpl( + existing_table, self.table_args, self.table_kwargs) + for opname, arg, kw in self.batch: + fn = getattr(batch_impl, opname) + fn(*arg, **kw) + + batch_impl._create(self.impl) + + def alter_column(self, *arg, **kw): + self.batch.append(("alter_column", arg, kw)) + + def add_column(self, *arg, **kw): + self.batch.append(("add_column", arg, kw)) + + def drop_column(self, *arg, **kw): + self.batch.append(("drop_column", arg, kw)) + + def add_constraint(self, const): + self.batch.append(("add_constraint", (const,), {})) + + def drop_constraint(self, const): + self.batch.append(("drop_constraint", (const, ), {})) + + def rename_table(self, *arg, **kw): + self.batch.append(("rename_table", arg, kw)) + + def create_index(self, idx): + self.batch.append(("create_index", (idx,), {})) + + def drop_index(self, idx): + self.batch.append(("drop_index", (idx,), {})) + + def create_table(self, table): + raise NotImplementedError("Can't create table in batch mode") + + def drop_table(self, table): + raise NotImplementedError("Can't drop table in batch mode") + + +class ApplyBatchImpl(object): + def __init__(self, table, table_args, table_kwargs): + self.table = table # this is a Table object + self.table_args = table_args + self.table_kwargs = table_kwargs + self.new_table = None + self.column_transfers = OrderedDict( + (c.name, {'expr': c}) for c in self.table.c + ) + self._grab_table_elements() + + def _grab_table_elements(self): + schema = self.table.schema + self.columns = OrderedDict() + for c in self.table.c: + c_copy = c.copy(schema=schema) + c_copy.unique = c_copy.index = False + self.columns[c.name] = c_copy + self.named_constraints = {} + self.unnamed_constraints = [] + self.indexes = {} + for const in self.table.constraints: + if _is_type_bound(const): + continue + if const.name: + self.named_constraints[const.name] = const + else: + self.unnamed_constraints.append(const) + + for idx in self.table.indexes: + self.indexes[idx.name] = idx + + def _transfer_elements_to_new_table(self): + assert self.new_table is None, "Can only create new table once" + + m = MetaData() + schema = self.table.schema + self.new_table = new_table = Table( + '_alembic_batch_temp', m, + *(list(self.columns.values()) + list(self.table_args)), + schema=schema, + **self.table_kwargs) + + for const in list(self.named_constraints.values()) + \ + self.unnamed_constraints: + + const_columns = set([ + c.key for c in _columns_for_constraint(const)]) + + if not const_columns.issubset(self.column_transfers): + continue + const_copy = const.copy(schema=schema, target_table=new_table) + if isinstance(const, ForeignKeyConstraint): + self._setup_referent(m, const) + new_table.append_constraint(const_copy) + + for index in self.indexes.values(): + Index(index.name, + unique=index.unique, + *[new_table.c[col] for col in index.columns.keys()], + **index.kwargs) + + def _setup_referent(self, metadata, constraint): + spec = constraint.elements[0]._get_colspec() + parts = spec.split(".") + tname = parts[-2] + if len(parts) == 3: + referent_schema = parts[0] + else: + referent_schema = None + if tname != '_alembic_batch_temp': + key = sql_schema._get_table_key(tname, referent_schema) + if key in metadata.tables: + t = metadata.tables[key] + for elem in constraint.elements: + colname = elem._get_colspec().split(".")[-1] + if not t.c.contains_column(colname): + t.append_column( + Column(colname, sqltypes.NULLTYPE) + ) + else: + Table( + tname, metadata, + *[Column(n, sqltypes.NULLTYPE) for n in + [elem._get_colspec().split(".")[-1] + for elem in constraint.elements]], + schema=referent_schema) + + def _create(self, op_impl): + self._transfer_elements_to_new_table() + + op_impl.prep_table_for_batch(self.table) + op_impl.create_table(self.new_table) + + try: + op_impl._exec( + self.new_table.insert(inline=True).from_select( + list(k for k, transfer in + self.column_transfers.items() if 'expr' in transfer), + select([ + transfer['expr'] + for transfer in self.column_transfers.values() + if 'expr' in transfer + ]) + ) + ) + op_impl.drop_table(self.table) + except: + op_impl.drop_table(self.new_table) + raise + else: + op_impl.rename_table( + "_alembic_batch_temp", + self.table.name, + schema=self.table.schema + ) + + def alter_column(self, table_name, column_name, + nullable=None, + server_default=False, + name=None, + type_=None, + autoincrement=None, + **kw + ): + existing = self.columns[column_name] + existing_transfer = self.column_transfers[column_name] + if name is not None and name != column_name: + # note that we don't change '.key' - we keep referring + # to the renamed column by its old key in _create(). neat! + existing.name = name + existing_transfer["name"] = name + + if type_ is not None: + type_ = sqltypes.to_instance(type_) + existing.type = type_ + existing_transfer["expr"] = cast(existing_transfer["expr"], type_) + if nullable is not None: + existing.nullable = nullable + if server_default is not False: + existing.server_default = server_default + if autoincrement is not None: + existing.autoincrement = bool(autoincrement) + + def add_column(self, table_name, column, **kw): + # we copy the column because operations.add_column() + # gives us a Column that is part of a Table already. + self.columns[column.name] = column.copy(schema=self.table.schema) + self.column_transfers[column.name] = {} + + def drop_column(self, table_name, column, **kw): + del self.columns[column.name] + del self.column_transfers[column.name] + + def add_constraint(self, const): + if not const.name: + raise ValueError("Constraint must have a name") + self.named_constraints[const.name] = const + + def drop_constraint(self, const): + if not const.name: + raise ValueError("Constraint must have a name") + try: + del self.named_constraints[const.name] + except KeyError: + raise ValueError("No such constraint: '%s'" % const.name) + + def create_index(self, idx): + self.indexes[idx.name] = idx + + def drop_index(self, idx): + try: + del self.indexes[idx.name] + except KeyError: + raise ValueError("No such index: '%s'" % idx.name) + + def rename_table(self, *arg, **kw): + raise NotImplementedError("TODO") |