diff options
Diffstat (limited to 'alembic/runtime/migration.py')
-rw-r--r-- | alembic/runtime/migration.py | 798 |
1 files changed, 798 insertions, 0 deletions
diff --git a/alembic/runtime/migration.py b/alembic/runtime/migration.py new file mode 100644 index 0000000..84a3c7f --- /dev/null +++ b/alembic/runtime/migration.py @@ -0,0 +1,798 @@ +import logging +import sys +from contextlib import contextmanager + +from sqlalchemy import MetaData, Table, Column, String, literal_column +from sqlalchemy.engine.strategies import MockEngineStrategy +from sqlalchemy.engine import url as sqla_url + +from ..util.compat import callable, EncodedIO +from .. import ddl, util + +log = logging.getLogger(__name__) + + +class MigrationContext(object): + + """Represent the database state made available to a migration + script. + + :class:`.MigrationContext` is the front end to an actual + database connection, or alternatively a string output + stream given a particular database dialect, + from an Alembic perspective. + + When inside the ``env.py`` script, the :class:`.MigrationContext` + is available via the + :meth:`.EnvironmentContext.get_context` method, + which is available at ``alembic.context``:: + + # from within env.py script + from alembic import context + migration_context = context.get_context() + + For usage outside of an ``env.py`` script, such as for + utility routines that want to check the current version + in the database, the :meth:`.MigrationContext.configure` + method to create new :class:`.MigrationContext` objects. + For example, to get at the current revision in the + database using :meth:`.MigrationContext.get_current_revision`:: + + # in any application, outside of an env.py script + from alembic.migration import MigrationContext + from sqlalchemy import create_engine + + engine = create_engine("postgresql://mydatabase") + conn = engine.connect() + + context = MigrationContext.configure(conn) + current_rev = context.get_current_revision() + + The above context can also be used to produce + Alembic migration operations with an :class:`.Operations` + instance:: + + # in any application, outside of the normal Alembic environment + from alembic.operations import Operations + op = Operations(context) + op.alter_column("mytable", "somecolumn", nullable=True) + + """ + + def __init__(self, dialect, connection, opts, environment_context=None): + self.environment_context = environment_context + self.opts = opts + self.dialect = dialect + self.script = opts.get('script') + as_sql = opts.get('as_sql', False) + transactional_ddl = opts.get("transactional_ddl") + + self._transaction_per_migration = opts.get( + "transaction_per_migration", False) + + if as_sql: + self.connection = self._stdout_connection(connection) + assert self.connection is not None + else: + self.connection = connection + self._migrations_fn = opts.get('fn') + self.as_sql = as_sql + + if "output_encoding" in opts: + self.output_buffer = EncodedIO( + opts.get("output_buffer") or sys.stdout, + opts['output_encoding'] + ) + else: + self.output_buffer = opts.get("output_buffer", sys.stdout) + + self._user_compare_type = opts.get('compare_type', False) + self._user_compare_server_default = opts.get( + 'compare_server_default', + False) + self.version_table = version_table = opts.get( + 'version_table', 'alembic_version') + self.version_table_schema = version_table_schema = \ + opts.get('version_table_schema', None) + self._version = Table( + version_table, MetaData(), + Column('version_num', String(32), nullable=False), + schema=version_table_schema) + + self._start_from_rev = opts.get("starting_rev") + self.impl = ddl.DefaultImpl.get_by_dialect(dialect)( + dialect, self.connection, self.as_sql, + transactional_ddl, + self.output_buffer, + opts + ) + log.info("Context impl %s.", self.impl.__class__.__name__) + if self.as_sql: + log.info("Generating static SQL") + log.info("Will assume %s DDL.", + "transactional" if self.impl.transactional_ddl + else "non-transactional") + + @classmethod + def configure(cls, + connection=None, + url=None, + dialect_name=None, + environment_context=None, + opts=None, + ): + """Create a new :class:`.MigrationContext`. + + This is a factory method usually called + by :meth:`.EnvironmentContext.configure`. + + :param connection: a :class:`~sqlalchemy.engine.Connection` + to use for SQL execution in "online" mode. When present, + is also used to determine the type of dialect in use. + :param url: a string database url, or a + :class:`sqlalchemy.engine.url.URL` object. + The type of dialect to be used will be derived from this if + ``connection`` is not passed. + :param dialect_name: string name of a dialect, such as + "postgresql", "mssql", etc. The type of dialect to be used will be + derived from this if ``connection`` and ``url`` are not passed. + :param opts: dictionary of options. Most other options + accepted by :meth:`.EnvironmentContext.configure` are passed via + this dictionary. + + """ + if opts is None: + opts = {} + + if connection: + dialect = connection.dialect + elif url: + url = sqla_url.make_url(url) + dialect = url.get_dialect()() + elif dialect_name: + url = sqla_url.make_url("%s://" % dialect_name) + dialect = url.get_dialect()() + else: + raise Exception("Connection, url, or dialect_name is required.") + + return MigrationContext(dialect, connection, opts, environment_context) + + def begin_transaction(self, _per_migration=False): + transaction_now = _per_migration == self._transaction_per_migration + + if not transaction_now: + @contextmanager + def do_nothing(): + yield + return do_nothing() + + elif not self.impl.transactional_ddl: + @contextmanager + def do_nothing(): + yield + return do_nothing() + elif self.as_sql: + @contextmanager + def begin_commit(): + self.impl.emit_begin() + yield + self.impl.emit_commit() + return begin_commit() + else: + return self.bind.begin() + + def get_current_revision(self): + """Return the current revision, usually that which is present + in the ``alembic_version`` table in the database. + + This method intends to be used only for a migration stream that + does not contain unmerged branches in the target database; + if there are multiple branches present, an exception is raised. + The :meth:`.MigrationContext.get_current_heads` should be preferred + over this method going forward in order to be compatible with + branch migration support. + + If this :class:`.MigrationContext` was configured in "offline" + mode, that is with ``as_sql=True``, the ``starting_rev`` + parameter is returned instead, if any. + + """ + heads = self.get_current_heads() + if len(heads) == 0: + return None + elif len(heads) > 1: + raise util.CommandError( + "Version table '%s' has more than one head present; " + "please use get_current_heads()" % self.version_table) + else: + return heads[0] + + def get_current_heads(self): + """Return a tuple of the current 'head versions' that are represented + in the target database. + + For a migration stream without branches, this will be a single + value, synonymous with that of + :meth:`.MigrationContext.get_current_revision`. However when multiple + unmerged branches exist within the target database, the returned tuple + will contain a value for each head. + + If this :class:`.MigrationContext` was configured in "offline" + mode, that is with ``as_sql=True``, the ``starting_rev`` + parameter is returned in a one-length tuple. + + If no version table is present, or if there are no revisions + present, an empty tuple is returned. + + .. versionadded:: 0.7.0 + + """ + if self.as_sql: + start_from_rev = self._start_from_rev + if start_from_rev is not None and self.script: + start_from_rev = \ + self.script.get_revision(start_from_rev).revision + + return util.to_tuple(start_from_rev, default=()) + else: + if self._start_from_rev: + raise util.CommandError( + "Can't specify current_rev to context " + "when using a database connection") + if not self._has_version_table(): + return () + return tuple( + row[0] for row in self.connection.execute(self._version.select()) + ) + + def _ensure_version_table(self): + self._version.create(self.connection, checkfirst=True) + + def _has_version_table(self): + return self.connection.dialect.has_table( + self.connection, self.version_table, self.version_table_schema) + + def stamp(self, script_directory, revision): + """Stamp the version table with a specific revision. + + This method calculates those branches to which the given revision + can apply, and updates those branches as though they were migrated + towards that revision (either up or down). If no current branches + include the revision, it is added as a new branch head. + + .. versionadded:: 0.7.0 + + """ + heads = self.get_current_heads() + if not self.as_sql and not heads: + self._ensure_version_table() + head_maintainer = HeadMaintainer(self, heads) + for step in script_directory._stamp_revs(revision, heads): + head_maintainer.update_to_step(step) + + def run_migrations(self, **kw): + """Run the migration scripts established for this + :class:`.MigrationContext`, if any. + + The commands in :mod:`alembic.command` will set up a function + that is ultimately passed to the :class:`.MigrationContext` + as the ``fn`` argument. This function represents the "work" + that will be done when :meth:`.MigrationContext.run_migrations` + is called, typically from within the ``env.py`` script of the + migration environment. The "work function" then provides an iterable + of version callables and other version information which + in the case of the ``upgrade`` or ``downgrade`` commands are the + list of version scripts to invoke. Other commands yield nothing, + in the case that a command wants to run some other operation + against the database such as the ``current`` or ``stamp`` commands. + + :param \**kw: keyword arguments here will be passed to each + migration callable, that is the ``upgrade()`` or ``downgrade()`` + method within revision scripts. + + """ + self.impl.start_migrations() + + heads = self.get_current_heads() + if not self.as_sql and not heads: + self._ensure_version_table() + + head_maintainer = HeadMaintainer(self, heads) + + for step in self._migrations_fn(heads, self): + with self.begin_transaction(_per_migration=True): + if self.as_sql and not head_maintainer.heads: + # for offline mode, include a CREATE TABLE from + # the base + self._version.create(self.connection) + log.info("Running %s", step) + if self.as_sql: + self.impl.static_output("-- Running %s" % (step.short_log,)) + step.migration_fn(**kw) + + # previously, we wouldn't stamp per migration + # if we were in a transaction, however given the more + # complex model that involves any number of inserts + # and row-targeted updates and deletes, it's simpler for now + # just to run the operations on every version + head_maintainer.update_to_step(step) + + if self.as_sql and not head_maintainer.heads: + self._version.drop(self.connection) + + def execute(self, sql, execution_options=None): + """Execute a SQL construct or string statement. + + The underlying execution mechanics are used, that is + if this is "offline mode" the SQL is written to the + output buffer, otherwise the SQL is emitted on + the current SQLAlchemy connection. + + """ + self.impl._exec(sql, execution_options) + + def _stdout_connection(self, connection): + def dump(construct, *multiparams, **params): + self.impl._exec(construct) + + return MockEngineStrategy.MockConnection(self.dialect, dump) + + @property + def bind(self): + """Return the current "bind". + + In online mode, this is an instance of + :class:`sqlalchemy.engine.Connection`, and is suitable + for ad-hoc execution of any kind of usage described + in :ref:`sqlexpression_toplevel` as well as + for usage with the :meth:`sqlalchemy.schema.Table.create` + and :meth:`sqlalchemy.schema.MetaData.create_all` methods + of :class:`~sqlalchemy.schema.Table`, + :class:`~sqlalchemy.schema.MetaData`. + + Note that when "standard output" mode is enabled, + this bind will be a "mock" connection handler that cannot + return results and is only appropriate for a very limited + subset of commands. + + """ + return self.connection + + @property + def config(self): + """Return the :class:`.Config` used by the current environment, if any. + + .. versionadded:: 0.6.6 + + """ + if self.environment_context: + return self.environment_context.config + else: + return None + + def _compare_type(self, inspector_column, metadata_column): + if self._user_compare_type is False: + return False + + if callable(self._user_compare_type): + user_value = self._user_compare_type( + self, + inspector_column, + metadata_column, + inspector_column.type, + metadata_column.type + ) + if user_value is not None: + return user_value + + return self.impl.compare_type( + inspector_column, + metadata_column) + + def _compare_server_default(self, inspector_column, + metadata_column, + rendered_metadata_default, + rendered_column_default): + + if self._user_compare_server_default is False: + return False + + if callable(self._user_compare_server_default): + user_value = self._user_compare_server_default( + self, + inspector_column, + metadata_column, + rendered_column_default, + metadata_column.server_default, + rendered_metadata_default + ) + if user_value is not None: + return user_value + + return self.impl.compare_server_default( + inspector_column, + metadata_column, + rendered_metadata_default, + rendered_column_default) + + +class HeadMaintainer(object): + def __init__(self, context, heads): + self.context = context + self.heads = set(heads) + + def _insert_version(self, version): + assert version not in self.heads + self.heads.add(version) + + self.context.impl._exec( + self.context._version.insert(). + values( + version_num=literal_column("'%s'" % version) + ) + ) + + def _delete_version(self, version): + self.heads.remove(version) + + ret = self.context.impl._exec( + self.context._version.delete().where( + self.context._version.c.version_num == + literal_column("'%s'" % version))) + if not self.context.as_sql and ret.rowcount != 1: + raise util.CommandError( + "Online migration expected to match one " + "row when deleting '%s' in '%s'; " + "%d found" + % (version, + self.context.version_table, ret.rowcount)) + + def _update_version(self, from_, to_): + assert to_ not in self.heads + self.heads.remove(from_) + self.heads.add(to_) + + ret = self.context.impl._exec( + self.context._version.update(). + values(version_num=literal_column("'%s'" % to_)).where( + self.context._version.c.version_num + == literal_column("'%s'" % from_)) + ) + if not self.context.as_sql and ret.rowcount != 1: + raise util.CommandError( + "Online migration expected to match one " + "row when updating '%s' to '%s' in '%s'; " + "%d found" + % (from_, to_, self.context.version_table, ret.rowcount)) + + def update_to_step(self, step): + if step.should_delete_branch(self.heads): + vers = step.delete_version_num + log.debug("branch delete %s", vers) + self._delete_version(vers) + elif step.should_create_branch(self.heads): + vers = step.insert_version_num + log.debug("new branch insert %s", vers) + self._insert_version(vers) + elif step.should_merge_branches(self.heads): + # delete revs, update from rev, update to rev + (delete_revs, update_from_rev, + update_to_rev) = step.merge_branch_idents(self.heads) + log.debug( + "merge, delete %s, update %s to %s", + delete_revs, update_from_rev, update_to_rev) + for delrev in delete_revs: + self._delete_version(delrev) + self._update_version(update_from_rev, update_to_rev) + elif step.should_unmerge_branches(self.heads): + (update_from_rev, update_to_rev, + insert_revs) = step.unmerge_branch_idents(self.heads) + log.debug( + "unmerge, insert %s, update %s to %s", + insert_revs, update_from_rev, update_to_rev) + for insrev in insert_revs: + self._insert_version(insrev) + self._update_version(update_from_rev, update_to_rev) + else: + from_, to_ = step.update_version_num(self.heads) + log.debug("update %s to %s", from_, to_) + self._update_version(from_, to_) + + +class MigrationStep(object): + @property + def name(self): + return self.migration_fn.__name__ + + @classmethod + def upgrade_from_script(cls, revision_map, script): + return RevisionStep(revision_map, script, True) + + @classmethod + def downgrade_from_script(cls, revision_map, script): + return RevisionStep(revision_map, script, False) + + @property + def is_downgrade(self): + return not self.is_upgrade + + @property + def short_log(self): + return "%s %s -> %s" % ( + self.name, + util.format_as_comma(self.from_revisions), + util.format_as_comma(self.to_revisions) + ) + + def __str__(self): + if self.doc: + return "%s %s -> %s, %s" % ( + self.name, + util.format_as_comma(self.from_revisions), + util.format_as_comma(self.to_revisions), + self.doc + ) + else: + return self.short_log + + +class RevisionStep(MigrationStep): + def __init__(self, revision_map, revision, is_upgrade): + self.revision_map = revision_map + self.revision = revision + self.is_upgrade = is_upgrade + if is_upgrade: + self.migration_fn = revision.module.upgrade + else: + self.migration_fn = revision.module.downgrade + + def __eq__(self, other): + return isinstance(other, RevisionStep) and \ + other.revision == self.revision and \ + self.is_upgrade == other.is_upgrade + + @property + def doc(self): + return self.revision.doc + + @property + def from_revisions(self): + if self.is_upgrade: + return self.revision._all_down_revisions + else: + return (self.revision.revision, ) + + @property + def to_revisions(self): + if self.is_upgrade: + return (self.revision.revision, ) + else: + return self.revision._all_down_revisions + + @property + def _has_scalar_down_revision(self): + return len(self.revision._all_down_revisions) == 1 + + def should_delete_branch(self, heads): + if not self.is_downgrade: + return False + + if self.revision.revision not in heads: + return False + + downrevs = self.revision._all_down_revisions + if not downrevs: + # is a base + return True + elif len(downrevs) == 1: + downrev = self.revision_map.get_revision(downrevs[0]) + + if not downrev._is_real_branch_point: + return False + + descendants = set( + r.revision for r in self.revision_map._get_descendant_nodes( + self.revision_map.get_revisions(downrev._all_nextrev), + check=False + ) + ) + + # the downrev is a branchpoint, and other members or descendants + # of the branch are still in heads; so delete this branch. + # the reason this occurs is because traversal tries to stay + # fully on one branch down to the branchpoint before starting + # the other; so if we have a->b->(c1->d1->e1, c2->d2->e2), + # on a downgrade from the top we may go e1, d1, c1, now heads + # are at c1 and e2, with the current method, we don't know that + # "e2" is important unless we get all descendants of c1/c2 + + if len(descendants.intersection(heads).difference( + [self.revision.revision])): + + # TODO: this doesn't work; make sure tests are here to ensure + # this fails + #if len(downrev._all_nextrev.intersection(heads).difference( + # [self.revision.revision])): + + return True + else: + return False + else: + # is a merge point + return False + + def merge_branch_idents(self, heads): + other_heads = set(heads).difference(self.from_revisions) + + if other_heads: + ancestors = set( + r.revision for r in + self.revision_map._get_ancestor_nodes( + self.revision_map.get_revisions(other_heads), + check=False + ) + ) + from_revisions = list( + set(self.from_revisions).difference(ancestors)) + else: + from_revisions = list(self.from_revisions) + + return ( + # delete revs, update from rev, update to rev + list(from_revisions[0:-1]), from_revisions[-1], + self.to_revisions[0] + ) + + def unmerge_branch_idents(self, heads): + other_heads = set(heads).difference([self.revision.revision]) + if other_heads: + ancestors = set( + r.revision for r in + self.revision_map._get_ancestor_nodes( + self.revision_map.get_revisions(other_heads), + check=False + ) + ) + to_revisions = list(set(self.to_revisions).difference(ancestors)) + else: + to_revisions = self.to_revisions + + return ( + # update from rev, update to rev, insert revs + self.from_revisions[0], to_revisions[-1], + to_revisions[0:-1] + ) + + def should_create_branch(self, heads): + if not self.is_upgrade: + return False + + downrevs = self.revision._all_down_revisions + + if not downrevs: + # is a base + return True + else: + # none of our downrevs are present, so... + # we have to insert our version. This is true whether + # or not there is only one downrev, or multiple (in the latter + # case, we're a merge point.) + if not heads.intersection(downrevs): + return True + else: + return False + + def should_merge_branches(self, heads): + if not self.is_upgrade: + return False + + downrevs = self.revision._all_down_revisions + + if len(downrevs) > 1 and \ + len(heads.intersection(downrevs)) > 1: + return True + + return False + + def should_unmerge_branches(self, heads): + if not self.is_downgrade: + return False + + downrevs = self.revision._all_down_revisions + + if self.revision.revision in heads and len(downrevs) > 1: + return True + + return False + + def update_version_num(self, heads): + if not self._has_scalar_down_revision: + downrev = heads.intersection(self.revision._all_down_revisions) + assert len(downrev) == 1, \ + "Can't do an UPDATE because downrevision is ambiguous" + down_revision = list(downrev)[0] + else: + down_revision = self.revision._all_down_revisions[0] + + if self.is_upgrade: + return down_revision, self.revision.revision + else: + return self.revision.revision, down_revision + + @property + def delete_version_num(self): + return self.revision.revision + + @property + def insert_version_num(self): + return self.revision.revision + + +class StampStep(MigrationStep): + def __init__(self, from_, to_, is_upgrade, branch_move): + self.from_ = util.to_tuple(from_, default=()) + self.to_ = util.to_tuple(to_, default=()) + self.is_upgrade = is_upgrade + self.branch_move = branch_move + self.migration_fn = self.stamp_revision + + doc = None + + def stamp_revision(self, **kw): + return None + + def __eq__(self, other): + return isinstance(other, StampStep) and \ + other.from_revisions == self.revisions and \ + other.to_revisions == self.to_revisions and \ + other.branch_move == self.branch_move and \ + self.is_upgrade == other.is_upgrade + + @property + def from_revisions(self): + return self.from_ + + @property + def to_revisions(self): + return self.to_ + + @property + def delete_version_num(self): + assert len(self.from_) == 1 + return self.from_[0] + + @property + def insert_version_num(self): + assert len(self.to_) == 1 + return self.to_[0] + + def update_version_num(self, heads): + assert len(self.from_) == 1 + assert len(self.to_) == 1 + return self.from_[0], self.to_[0] + + def merge_branch_idents(self, heads): + return ( + # delete revs, update from rev, update to rev + list(self.from_[0:-1]), self.from_[-1], + self.to_[0] + ) + + def unmerge_branch_idents(self, heads): + return ( + # update from rev, update to rev, insert revs + self.from_[0], self.to_[-1], + list(self.to_[0:-1]) + ) + + def should_delete_branch(self, heads): + return self.is_downgrade and self.branch_move + + def should_create_branch(self, heads): + return self.is_upgrade and self.branch_move + + def should_merge_branches(self, heads): + return len(self.from_) > 1 + + def should_unmerge_branches(self, heads): + return len(self.to_) > 1 |