summaryrefslogtreecommitdiff
path: root/alembic/operations/batch.py
diff options
context:
space:
mode:
Diffstat (limited to 'alembic/operations/batch.py')
-rw-r--r--alembic/operations/batch.py287
1 files changed, 287 insertions, 0 deletions
diff --git a/alembic/operations/batch.py b/alembic/operations/batch.py
new file mode 100644
index 0000000..726df78
--- /dev/null
+++ b/alembic/operations/batch.py
@@ -0,0 +1,287 @@
+from sqlalchemy import Table, MetaData, Index, select, Column, \
+ ForeignKeyConstraint, cast
+from sqlalchemy import types as sqltypes
+from sqlalchemy import schema as sql_schema
+from sqlalchemy.util import OrderedDict
+from .. import util
+from ..util.sqla_compat import _columns_for_constraint, _is_type_bound
+
+
+class BatchOperationsImpl(object):
+ def __init__(self, operations, table_name, schema, recreate,
+ copy_from, table_args, table_kwargs,
+ reflect_args, reflect_kwargs, naming_convention):
+ if not util.sqla_08:
+ raise NotImplementedError(
+ "batch mode requires SQLAlchemy 0.8 or greater.")
+ self.operations = operations
+ self.table_name = table_name
+ self.schema = schema
+ if recreate not in ('auto', 'always', 'never'):
+ raise ValueError(
+ "recreate may be one of 'auto', 'always', or 'never'.")
+ self.recreate = recreate
+ self.copy_from = copy_from
+ self.table_args = table_args
+ self.table_kwargs = table_kwargs
+ self.reflect_args = reflect_args
+ self.reflect_kwargs = reflect_kwargs
+ self.naming_convention = naming_convention
+ self.batch = []
+
+ @property
+ def dialect(self):
+ return self.operations.impl.dialect
+
+ @property
+ def impl(self):
+ return self.operations.impl
+
+ def _should_recreate(self):
+ if self.recreate == 'auto':
+ return self.operations.impl.requires_recreate_in_batch(self)
+ elif self.recreate == 'always':
+ return True
+ else:
+ return False
+
+ def flush(self):
+ should_recreate = self._should_recreate()
+
+ if not should_recreate:
+ for opname, arg, kw in self.batch:
+ fn = getattr(self.operations.impl, opname)
+ fn(*arg, **kw)
+ else:
+ if self.naming_convention:
+ m1 = MetaData(naming_convention=self.naming_convention)
+ else:
+ m1 = MetaData()
+
+ if self.copy_from is not None:
+ existing_table = self.copy_from
+ else:
+ existing_table = Table(
+ self.table_name, m1,
+ schema=self.schema,
+ autoload=True,
+ autoload_with=self.operations.get_bind(),
+ *self.reflect_args, **self.reflect_kwargs)
+
+ batch_impl = ApplyBatchImpl(
+ existing_table, self.table_args, self.table_kwargs)
+ for opname, arg, kw in self.batch:
+ fn = getattr(batch_impl, opname)
+ fn(*arg, **kw)
+
+ batch_impl._create(self.impl)
+
+ def alter_column(self, *arg, **kw):
+ self.batch.append(("alter_column", arg, kw))
+
+ def add_column(self, *arg, **kw):
+ self.batch.append(("add_column", arg, kw))
+
+ def drop_column(self, *arg, **kw):
+ self.batch.append(("drop_column", arg, kw))
+
+ def add_constraint(self, const):
+ self.batch.append(("add_constraint", (const,), {}))
+
+ def drop_constraint(self, const):
+ self.batch.append(("drop_constraint", (const, ), {}))
+
+ def rename_table(self, *arg, **kw):
+ self.batch.append(("rename_table", arg, kw))
+
+ def create_index(self, idx):
+ self.batch.append(("create_index", (idx,), {}))
+
+ def drop_index(self, idx):
+ self.batch.append(("drop_index", (idx,), {}))
+
+ def create_table(self, table):
+ raise NotImplementedError("Can't create table in batch mode")
+
+ def drop_table(self, table):
+ raise NotImplementedError("Can't drop table in batch mode")
+
+
+class ApplyBatchImpl(object):
+ def __init__(self, table, table_args, table_kwargs):
+ self.table = table # this is a Table object
+ self.table_args = table_args
+ self.table_kwargs = table_kwargs
+ self.new_table = None
+ self.column_transfers = OrderedDict(
+ (c.name, {'expr': c}) for c in self.table.c
+ )
+ self._grab_table_elements()
+
+ def _grab_table_elements(self):
+ schema = self.table.schema
+ self.columns = OrderedDict()
+ for c in self.table.c:
+ c_copy = c.copy(schema=schema)
+ c_copy.unique = c_copy.index = False
+ self.columns[c.name] = c_copy
+ self.named_constraints = {}
+ self.unnamed_constraints = []
+ self.indexes = {}
+ for const in self.table.constraints:
+ if _is_type_bound(const):
+ continue
+ if const.name:
+ self.named_constraints[const.name] = const
+ else:
+ self.unnamed_constraints.append(const)
+
+ for idx in self.table.indexes:
+ self.indexes[idx.name] = idx
+
+ def _transfer_elements_to_new_table(self):
+ assert self.new_table is None, "Can only create new table once"
+
+ m = MetaData()
+ schema = self.table.schema
+ self.new_table = new_table = Table(
+ '_alembic_batch_temp', m,
+ *(list(self.columns.values()) + list(self.table_args)),
+ schema=schema,
+ **self.table_kwargs)
+
+ for const in list(self.named_constraints.values()) + \
+ self.unnamed_constraints:
+
+ const_columns = set([
+ c.key for c in _columns_for_constraint(const)])
+
+ if not const_columns.issubset(self.column_transfers):
+ continue
+ const_copy = const.copy(schema=schema, target_table=new_table)
+ if isinstance(const, ForeignKeyConstraint):
+ self._setup_referent(m, const)
+ new_table.append_constraint(const_copy)
+
+ for index in self.indexes.values():
+ Index(index.name,
+ unique=index.unique,
+ *[new_table.c[col] for col in index.columns.keys()],
+ **index.kwargs)
+
+ def _setup_referent(self, metadata, constraint):
+ spec = constraint.elements[0]._get_colspec()
+ parts = spec.split(".")
+ tname = parts[-2]
+ if len(parts) == 3:
+ referent_schema = parts[0]
+ else:
+ referent_schema = None
+ if tname != '_alembic_batch_temp':
+ key = sql_schema._get_table_key(tname, referent_schema)
+ if key in metadata.tables:
+ t = metadata.tables[key]
+ for elem in constraint.elements:
+ colname = elem._get_colspec().split(".")[-1]
+ if not t.c.contains_column(colname):
+ t.append_column(
+ Column(colname, sqltypes.NULLTYPE)
+ )
+ else:
+ Table(
+ tname, metadata,
+ *[Column(n, sqltypes.NULLTYPE) for n in
+ [elem._get_colspec().split(".")[-1]
+ for elem in constraint.elements]],
+ schema=referent_schema)
+
+ def _create(self, op_impl):
+ self._transfer_elements_to_new_table()
+
+ op_impl.prep_table_for_batch(self.table)
+ op_impl.create_table(self.new_table)
+
+ try:
+ op_impl._exec(
+ self.new_table.insert(inline=True).from_select(
+ list(k for k, transfer in
+ self.column_transfers.items() if 'expr' in transfer),
+ select([
+ transfer['expr']
+ for transfer in self.column_transfers.values()
+ if 'expr' in transfer
+ ])
+ )
+ )
+ op_impl.drop_table(self.table)
+ except:
+ op_impl.drop_table(self.new_table)
+ raise
+ else:
+ op_impl.rename_table(
+ "_alembic_batch_temp",
+ self.table.name,
+ schema=self.table.schema
+ )
+
+ def alter_column(self, table_name, column_name,
+ nullable=None,
+ server_default=False,
+ name=None,
+ type_=None,
+ autoincrement=None,
+ **kw
+ ):
+ existing = self.columns[column_name]
+ existing_transfer = self.column_transfers[column_name]
+ if name is not None and name != column_name:
+ # note that we don't change '.key' - we keep referring
+ # to the renamed column by its old key in _create(). neat!
+ existing.name = name
+ existing_transfer["name"] = name
+
+ if type_ is not None:
+ type_ = sqltypes.to_instance(type_)
+ existing.type = type_
+ existing_transfer["expr"] = cast(existing_transfer["expr"], type_)
+ if nullable is not None:
+ existing.nullable = nullable
+ if server_default is not False:
+ existing.server_default = server_default
+ if autoincrement is not None:
+ existing.autoincrement = bool(autoincrement)
+
+ def add_column(self, table_name, column, **kw):
+ # we copy the column because operations.add_column()
+ # gives us a Column that is part of a Table already.
+ self.columns[column.name] = column.copy(schema=self.table.schema)
+ self.column_transfers[column.name] = {}
+
+ def drop_column(self, table_name, column, **kw):
+ del self.columns[column.name]
+ del self.column_transfers[column.name]
+
+ def add_constraint(self, const):
+ if not const.name:
+ raise ValueError("Constraint must have a name")
+ self.named_constraints[const.name] = const
+
+ def drop_constraint(self, const):
+ if not const.name:
+ raise ValueError("Constraint must have a name")
+ try:
+ del self.named_constraints[const.name]
+ except KeyError:
+ raise ValueError("No such constraint: '%s'" % const.name)
+
+ def create_index(self, idx):
+ self.indexes[idx.name] = idx
+
+ def drop_index(self, idx):
+ try:
+ del self.indexes[idx.name]
+ except KeyError:
+ raise ValueError("No such index: '%s'" % idx.name)
+
+ def rename_table(self, *arg, **kw):
+ raise NotImplementedError("TODO")