summaryrefslogtreecommitdiff
path: root/alembic/script/revision.py
diff options
context:
space:
mode:
Diffstat (limited to 'alembic/script/revision.py')
-rw-r--r--alembic/script/revision.py806
1 files changed, 806 insertions, 0 deletions
diff --git a/alembic/script/revision.py b/alembic/script/revision.py
new file mode 100644
index 0000000..e9958b1
--- /dev/null
+++ b/alembic/script/revision.py
@@ -0,0 +1,806 @@
+import re
+import collections
+
+from .. import util
+from sqlalchemy import util as sqlautil
+from ..util import compat
+
+_relative_destination = re.compile(r'(?:(.+?)@)?(\w+)?((?:\+|-)\d+)')
+
+
+class RevisionError(Exception):
+ pass
+
+
+class RangeNotAncestorError(RevisionError):
+ def __init__(self, lower, upper):
+ self.lower = lower
+ self.upper = upper
+ super(RangeNotAncestorError, self).__init__(
+ "Revision %s is not an ancestor of revision %s" %
+ (lower or "base", upper or "base")
+ )
+
+
+class MultipleHeads(RevisionError):
+ def __init__(self, heads, argument):
+ self.heads = heads
+ self.argument = argument
+ super(MultipleHeads, self).__init__(
+ "Multiple heads are present for given argument '%s'; "
+ "%s" % (argument, ", ".join(heads))
+ )
+
+
+class ResolutionError(RevisionError):
+ pass
+
+
+class RevisionMap(object):
+ """Maintains a map of :class:`.Revision` objects.
+
+ :class:`.RevisionMap` is used by :class:`.ScriptDirectory` to maintain
+ and traverse the collection of :class:`.Script` objects, which are
+ themselves instances of :class:`.Revision`.
+
+ """
+
+ def __init__(self, generator):
+ """Construct a new :class:`.RevisionMap`.
+
+ :param generator: a zero-arg callable that will generate an iterable
+ of :class:`.Revision` instances to be used. These are typically
+ :class:`.Script` subclasses within regular Alembic use.
+
+ """
+ self._generator = generator
+
+ @util.memoized_property
+ def heads(self):
+ """All "head" revisions as strings.
+
+ This is normally a tuple of length one,
+ unless unmerged branches are present.
+
+ :return: a tuple of string revision numbers.
+
+ """
+ self._revision_map
+ return self.heads
+
+ @util.memoized_property
+ def bases(self):
+ """All "base" revisions as strings.
+
+ These are revisions that have a ``down_revision`` of None,
+ or empty tuple.
+
+ :return: a tuple of string revision numbers.
+
+ """
+ self._revision_map
+ return self.bases
+
+ @util.memoized_property
+ def _real_heads(self):
+ """All "real" head revisions as strings.
+
+ :return: a tuple of string revision numbers.
+
+ """
+ self._revision_map
+ return self._real_heads
+
+ @util.memoized_property
+ def _real_bases(self):
+ """All "real" base revisions as strings.
+
+ :return: a tuple of string revision numbers.
+
+ """
+ self._revision_map
+ return self._real_bases
+
+ @util.memoized_property
+ def _revision_map(self):
+ """memoized attribute, initializes the revision map from the
+ initial collection.
+
+ """
+ map_ = {}
+
+ heads = sqlautil.OrderedSet()
+ _real_heads = sqlautil.OrderedSet()
+ self.bases = ()
+ self._real_bases = ()
+
+ has_branch_labels = set()
+ for revision in self._generator():
+
+ if revision.revision in map_:
+ util.warn("Revision %s is present more than once" %
+ revision.revision)
+ map_[revision.revision] = revision
+ if revision.branch_labels:
+ has_branch_labels.add(revision)
+ heads.add(revision.revision)
+ _real_heads.add(revision.revision)
+ if revision.is_base:
+ self.bases += (revision.revision, )
+ if revision._is_real_base:
+ self._real_bases += (revision.revision, )
+
+ for rev in map_.values():
+ for downrev in rev._all_down_revisions:
+ if downrev not in map_:
+ util.warn("Revision %s referenced from %s is not present"
+ % (downrev, rev))
+ down_revision = map_[downrev]
+ down_revision.add_nextrev(rev)
+ if downrev in rev._versioned_down_revisions:
+ heads.discard(downrev)
+ _real_heads.discard(downrev)
+
+ map_[None] = map_[()] = None
+ self.heads = tuple(heads)
+ self._real_heads = tuple(_real_heads)
+
+ for revision in has_branch_labels:
+ self._add_branches(revision, map_)
+ return map_
+
+ def _add_branches(self, revision, map_):
+ if revision.branch_labels:
+ for branch_label in revision._orig_branch_labels:
+ if branch_label in map_:
+ raise RevisionError(
+ "Branch name '%s' in revision %s already "
+ "used by revision %s" %
+ (branch_label, revision.revision,
+ map_[branch_label].revision)
+ )
+ map_[branch_label] = revision
+ revision.branch_labels.update(revision.branch_labels)
+ for node in self._get_descendant_nodes(
+ [revision], map_, include_dependencies=False):
+ node.branch_labels.update(revision.branch_labels)
+
+ parent = node
+ while parent and \
+ not parent._is_real_branch_point and not parent.is_merge_point:
+
+ parent.branch_labels.update(revision.branch_labels)
+ if parent.down_revision:
+ parent = map_[parent.down_revision]
+ else:
+ break
+
+ def add_revision(self, revision, _replace=False):
+ """add a single revision to an existing map.
+
+ This method is for single-revision use cases, it's not
+ appropriate for fully populating an entire revision map.
+
+ """
+ map_ = self._revision_map
+ if not _replace and revision.revision in map_:
+ util.warn("Revision %s is present more than once" %
+ revision.revision)
+ elif _replace and revision.revision not in map_:
+ raise Exception("revision %s not in map" % revision.revision)
+
+ map_[revision.revision] = revision
+ self._add_branches(revision, map_)
+ if revision.is_base:
+ self.bases += (revision.revision, )
+ if revision._is_real_base:
+ self._real_bases += (revision.revision, )
+ for downrev in revision._all_down_revisions:
+ if downrev not in map_:
+ util.warn(
+ "Revision %s referenced from %s is not present"
+ % (downrev, revision)
+ )
+ map_[downrev].add_nextrev(revision)
+ if revision._is_real_head:
+ self._real_heads = tuple(
+ head for head in self._real_heads
+ if head not in
+ set(revision._all_down_revisions).union([revision.revision])
+ ) + (revision.revision,)
+ if revision.is_head:
+ self.heads = tuple(
+ head for head in self.heads
+ if head not in
+ set(revision._versioned_down_revisions).union([revision.revision])
+ ) + (revision.revision,)
+
+ def get_current_head(self, branch_label=None):
+ """Return the current head revision.
+
+ If the script directory has multiple heads
+ due to branching, an error is raised;
+ :meth:`.ScriptDirectory.get_heads` should be
+ preferred.
+
+ :param branch_label: optional branch name which will limit the
+ heads considered to those which include that branch_label.
+
+ :return: a string revision number.
+
+ .. seealso::
+
+ :meth:`.ScriptDirectory.get_heads`
+
+ """
+ current_heads = self.heads
+ if branch_label:
+ current_heads = self.filter_for_lineage(current_heads, branch_label)
+ if len(current_heads) > 1:
+ raise MultipleHeads(
+ current_heads,
+ "%s@head" % branch_label if branch_label else "head")
+
+ if current_heads:
+ return current_heads[0]
+ else:
+ return None
+
+ def _get_base_revisions(self, identifier):
+ return self.filter_for_lineage(self.bases, identifier)
+
+ def get_revisions(self, id_):
+ """Return the :class:`.Revision` instances with the given rev id
+ or identifiers.
+
+ May be given a single identifier, a sequence of identifiers, or the
+ special symbols "head" or "base". The result is a tuple of one
+ or more identifiers, or an empty tuple in the case of "base".
+
+ In the cases where 'head', 'heads' is requested and the
+ revision map is empty, returns an empty tuple.
+
+ Supports partial identifiers, where the given identifier
+ is matched against all identifiers that start with the given
+ characters; if there is exactly one match, that determines the
+ full revision.
+
+ """
+ if isinstance(id_, (list, tuple, set, frozenset)):
+ return sum([self.get_revisions(id_elem) for id_elem in id_], ())
+ else:
+ resolved_id, branch_label = self._resolve_revision_number(id_)
+ return tuple(
+ self._revision_for_ident(rev_id, branch_label)
+ for rev_id in resolved_id)
+
+ def get_revision(self, id_):
+ """Return the :class:`.Revision` instance with the given rev id.
+
+ If a symbolic name such as "head" or "base" is given, resolves
+ the identifier into the current head or base revision. If the symbolic
+ name refers to multiples, :class:`.MultipleHeads` is raised.
+
+ Supports partial identifiers, where the given identifier
+ is matched against all identifiers that start with the given
+ characters; if there is exactly one match, that determines the
+ full revision.
+
+ """
+
+ resolved_id, branch_label = self._resolve_revision_number(id_)
+ if len(resolved_id) > 1:
+ raise MultipleHeads(resolved_id, id_)
+ elif resolved_id:
+ resolved_id = resolved_id[0]
+
+ return self._revision_for_ident(resolved_id, branch_label)
+
+ def _resolve_branch(self, branch_label):
+ try:
+ branch_rev = self._revision_map[branch_label]
+ except KeyError:
+ try:
+ nonbranch_rev = self._revision_for_ident(branch_label)
+ except ResolutionError:
+ raise ResolutionError("No such branch: '%s'" % branch_label)
+ else:
+ return nonbranch_rev
+ else:
+ return branch_rev
+
+ def _revision_for_ident(self, resolved_id, check_branch=None):
+ if check_branch:
+ branch_rev = self._resolve_branch(check_branch)
+ else:
+ branch_rev = None
+
+ try:
+ revision = self._revision_map[resolved_id]
+ except KeyError:
+ # do a partial lookup
+ revs = [x for x in self._revision_map
+ if x and x.startswith(resolved_id)]
+ if branch_rev:
+ revs = self.filter_for_lineage(revs, check_branch)
+ if not revs:
+ raise ResolutionError(
+ "No such revision or branch '%s'" % resolved_id)
+ elif len(revs) > 1:
+ raise ResolutionError(
+ "Multiple revisions start "
+ "with '%s': %s..." % (
+ resolved_id,
+ ", ".join("'%s'" % r for r in revs[0:3])
+ ))
+ else:
+ revision = self._revision_map[revs[0]]
+
+ if check_branch and revision is not None:
+ if not self._shares_lineage(
+ revision.revision, branch_rev.revision):
+ raise ResolutionError(
+ "Revision %s is not a member of branch '%s'" %
+ (revision.revision, check_branch))
+ return revision
+
+ def filter_for_lineage(
+ self, targets, check_against, include_dependencies=False):
+ id_, branch_label = self._resolve_revision_number(check_against)
+
+ shares = []
+ if branch_label:
+ shares.append(branch_label)
+ if id_:
+ shares.extend(id_)
+
+ return [
+ tg for tg in targets
+ if self._shares_lineage(
+ tg, shares, include_dependencies=include_dependencies)]
+
+ def _shares_lineage(
+ self, target, test_against_revs, include_dependencies=False):
+ if not test_against_revs:
+ return True
+ if not isinstance(target, Revision):
+ target = self._revision_for_ident(target)
+
+ test_against_revs = [
+ self._revision_for_ident(test_against_rev)
+ if not isinstance(test_against_rev, Revision)
+ else test_against_rev
+ for test_against_rev
+ in util.to_tuple(test_against_revs, default=())
+ ]
+
+ return bool(
+ set(self._get_descendant_nodes([target],
+ include_dependencies=include_dependencies))
+ .union(self._get_ancestor_nodes([target],
+ include_dependencies=include_dependencies))
+ .intersection(test_against_revs)
+ )
+
+ def _resolve_revision_number(self, id_):
+ if isinstance(id_, compat.string_types) and "@" in id_:
+ branch_label, id_ = id_.split('@', 1)
+ else:
+ branch_label = None
+
+ # ensure map is loaded
+ self._revision_map
+ if id_ == 'heads':
+ if branch_label:
+ return self.filter_for_lineage(
+ self.heads, branch_label), branch_label
+ else:
+ return self._real_heads, branch_label
+ elif id_ == 'head':
+ current_head = self.get_current_head(branch_label)
+ if current_head:
+ return (current_head, ), branch_label
+ else:
+ return (), branch_label
+ elif id_ == 'base' or id_ is None:
+ return (), branch_label
+ else:
+ return util.to_tuple(id_, default=None), branch_label
+
+ def _relative_iterate(
+ self, destination, source, is_upwards,
+ implicit_base, inclusive, assert_relative_length):
+ if isinstance(destination, compat.string_types):
+ match = _relative_destination.match(destination)
+ if not match:
+ return None
+ else:
+ return None
+
+ relative = int(match.group(3))
+ symbol = match.group(2)
+ branch_label = match.group(1)
+
+ reldelta = 1 if inclusive and not symbol else 0
+
+ if is_upwards:
+ if branch_label:
+ from_ = "%s@head" % branch_label
+ elif symbol:
+ if symbol.startswith("head"):
+ from_ = symbol
+ else:
+ from_ = "%s@head" % symbol
+ else:
+ from_ = "head"
+ to_ = source
+ else:
+ if branch_label:
+ to_ = "%s@base" % branch_label
+ elif symbol:
+ to_ = "%s@base" % symbol
+ else:
+ to_ = "base"
+ from_ = source
+
+ revs = list(
+ self._iterate_revisions(
+ from_, to_,
+ inclusive=inclusive, implicit_base=implicit_base))
+
+ if symbol:
+ if branch_label:
+ symbol_rev = self.get_revision(
+ "%s@%s" % (branch_label, symbol))
+ else:
+ symbol_rev = self.get_revision(symbol)
+ if symbol.startswith("head"):
+ index = 0
+ elif symbol == "base":
+ index = len(revs) - 1
+ else:
+ range_ = compat.range(len(revs) - 1, 0, -1)
+ for index in range_:
+ if symbol_rev.revision == revs[index].revision:
+ break
+ else:
+ index = 0
+ else:
+ index = 0
+ if is_upwards:
+ revs = revs[index - relative - reldelta:]
+ if not index and assert_relative_length and \
+ len(revs) < abs(relative - reldelta):
+ raise RevisionError(
+ "Relative revision %s didn't "
+ "produce %d migrations" % (destination, abs(relative)))
+ else:
+ revs = revs[0:index - relative + reldelta]
+ if not index and assert_relative_length and \
+ len(revs) != abs(relative) + reldelta:
+ raise RevisionError(
+ "Relative revision %s didn't "
+ "produce %d migrations" % (destination, abs(relative)))
+
+ return iter(revs)
+
+ def iterate_revisions(
+ self, upper, lower, implicit_base=False, inclusive=False,
+ assert_relative_length=True):
+ """Iterate through script revisions, starting at the given
+ upper revision identifier and ending at the lower.
+
+ The traversal uses strictly the `down_revision`
+ marker inside each migration script, so
+ it is a requirement that upper >= lower,
+ else you'll get nothing back.
+
+ The iterator yields :class:`.Revision` objects.
+
+ """
+
+ relative_upper = self._relative_iterate(
+ upper, lower, True, implicit_base,
+ inclusive, assert_relative_length
+ )
+ if relative_upper:
+ return relative_upper
+
+ relative_lower = self._relative_iterate(
+ lower, upper, False, implicit_base,
+ inclusive, assert_relative_length
+ )
+ if relative_lower:
+ return relative_lower
+
+ return self._iterate_revisions(
+ upper, lower, inclusive=inclusive, implicit_base=implicit_base)
+
+ def _get_descendant_nodes(
+ self, targets, map_=None, check=False, include_dependencies=True):
+
+ if include_dependencies:
+ fn = lambda rev: rev._all_nextrev
+ else:
+ fn = lambda rev: rev.nextrev
+
+ return self._iterate_related_revisions(
+ fn, targets, map_=map_, check=check
+ )
+
+ def _get_ancestor_nodes(
+ self, targets, map_=None, check=False, include_dependencies=True):
+
+ if include_dependencies:
+ fn = lambda rev: rev._all_down_revisions
+ else:
+ fn = lambda rev: rev._versioned_down_revisions
+
+ return self._iterate_related_revisions(
+ fn, targets, map_=map_, check=check
+ )
+
+ def _iterate_related_revisions(self, fn, targets, map_, check=False):
+ if map_ is None:
+ map_ = self._revision_map
+
+ todo = collections.deque()
+ for target in targets:
+ todo.append(target)
+ if check:
+ per_target = set()
+ while todo:
+ rev = todo.pop()
+ todo.extend(
+ map_[rev_id] for rev_id in fn(rev))
+ if check:
+ per_target.add(rev)
+ yield rev
+ if check and per_target.intersection(targets).difference([target]):
+ raise RevisionError(
+ "Requested revision %s overlaps with "
+ "other requested revisions" % target.revision)
+
+ def _iterate_revisions(
+ self, upper, lower, inclusive=True, implicit_base=False):
+ """iterate revisions from upper to lower.
+
+ The traversal is depth-first within branches, and breadth-first
+ across branches as a whole.
+
+ """
+
+ requested_lowers = self.get_revisions(lower)
+
+ # some complexity to accommodate an iteration where some
+ # branches are starting from nothing, and others are starting
+ # from a given point. Additionally, if the bottom branch
+ # is specified using a branch identifier, then we limit operations
+ # to just that branch.
+
+ limit_to_lower_branch = \
+ isinstance(lower, compat.string_types) and lower.endswith('@base')
+
+ uppers = self.get_revisions(upper)
+ if not uppers and not requested_lowers:
+ raise StopIteration()
+
+ upper_ancestors = set(self._get_ancestor_nodes(uppers, check=True))
+
+ if limit_to_lower_branch:
+ lowers = self.get_revisions(self._get_base_revisions(lower))
+ elif implicit_base and requested_lowers:
+ lower_ancestors = set(
+ self._get_ancestor_nodes(requested_lowers)
+ )
+ lower_descendants = set(
+ self._get_descendant_nodes(requested_lowers)
+ )
+ base_lowers = set()
+ candidate_lowers = upper_ancestors.\
+ difference(lower_ancestors).\
+ difference(lower_descendants)
+ for rev in candidate_lowers:
+ for downrev in rev._all_down_revisions:
+ if self._revision_map[downrev] in candidate_lowers:
+ break
+ else:
+ base_lowers.add(rev)
+ lowers = base_lowers.union(requested_lowers)
+ elif implicit_base:
+ base_lowers = set(self.get_revisions(self._real_bases))
+ lowers = base_lowers.union(requested_lowers)
+ elif not requested_lowers:
+ lowers = set(self.get_revisions(self._real_bases))
+ else:
+ lowers = requested_lowers
+
+ # represents all nodes we will produce
+ total_space = set(
+ rev.revision for rev in upper_ancestors).intersection(
+ rev.revision for rev
+ in self._get_descendant_nodes(lowers, check=True)
+ )
+
+ if not total_space:
+ raise RangeNotAncestorError(lower, upper)
+
+ # organize branch points to be consumed separately from
+ # member nodes
+ branch_todo = set(
+ rev for rev in
+ (self._revision_map[rev] for rev in total_space)
+ if rev._is_real_branch_point and
+ len(total_space.intersection(rev._all_nextrev)) > 1
+ )
+
+ # it's not possible for any "uppers" to be in branch_todo,
+ # because the ._all_nextrev of those nodes is not in total_space
+ #assert not branch_todo.intersection(uppers)
+
+ todo = collections.deque(
+ r for r in uppers if r.revision in total_space)
+
+ # iterate for total_space being emptied out
+ while total_space:
+ # when everything non-branch pending is consumed,
+ # add to the todo any branch nodes that have no
+ # descendants left in the queue
+ if not todo:
+ todo.extendleft(
+ sorted(
+ (
+ rev for rev in branch_todo
+ if not rev._all_nextrev.intersection(total_space)
+ ),
+ # favor "revisioned" branch points before
+ # dependent ones
+ key=lambda rev: 0 if rev.is_branch_point else 1
+ )
+ )
+ branch_todo.difference_update(todo)
+ # iterate nodes that are in the immediate todo
+ while todo:
+ rev = todo.popleft()
+ total_space.remove(rev.revision)
+
+ # do depth first for elements within branches,
+ # don't consume any actual branch nodes
+ todo.extendleft([
+ self._revision_map[downrev]
+ for downrev in reversed(rev._all_down_revisions)
+ if self._revision_map[downrev] not in branch_todo
+ and downrev in total_space])
+
+ if not inclusive and rev in requested_lowers:
+ continue
+ yield rev
+
+ assert not branch_todo
+
+
+class Revision(object):
+ """Base class for revisioned objects.
+
+ The :class:`.Revision` class is the base of the more public-facing
+ :class:`.Script` object, which represents a migration script.
+ The mechanics of revision management and traversal are encapsulated
+ within :class:`.Revision`, while :class:`.Script` applies this logic
+ to Python files in a version directory.
+
+ """
+ nextrev = frozenset()
+ """following revisions, based on down_revision only."""
+
+ _all_nextrev = frozenset()
+
+ revision = None
+ """The string revision number."""
+
+ down_revision = None
+ """The ``down_revision`` identifier(s) within the migration script.
+
+ Note that the total set of "down" revisions is
+ down_revision + dependencies.
+
+ """
+
+ dependencies = None
+ """Additional revisions which this revision is dependent on.
+
+ From a migration standpoint, these dependencies are added to the
+ down_revision to form the full iteration. However, the separation
+ of down_revision from "dependencies" is to assist in navigating
+ a history that contains many branches, typically a multi-root scenario.
+
+ """
+
+ branch_labels = None
+ """Optional string/tuple of symbolic names to apply to this
+ revision's branch"""
+
+ def __init__(
+ self, revision, down_revision,
+ dependencies=None, branch_labels=None):
+ self.revision = revision
+ self.down_revision = tuple_rev_as_scalar(down_revision)
+ self.dependencies = tuple_rev_as_scalar(dependencies)
+ self._orig_branch_labels = util.to_tuple(branch_labels, default=())
+ self.branch_labels = set(self._orig_branch_labels)
+
+ def add_nextrev(self, revision):
+ self._all_nextrev = self._all_nextrev.union([revision.revision])
+ if self.revision in revision._versioned_down_revisions:
+ self.nextrev = self.nextrev.union([revision.revision])
+
+ @property
+ def _all_down_revisions(self):
+ return util.to_tuple(self.down_revision, default=()) + \
+ util.to_tuple(self.dependencies, default=())
+
+ @property
+ def _versioned_down_revisions(self):
+ return util.to_tuple(self.down_revision, default=())
+
+ @property
+ def is_head(self):
+ """Return True if this :class:`.Revision` is a 'head' revision.
+
+ This is determined based on whether any other :class:`.Script`
+ within the :class:`.ScriptDirectory` refers to this
+ :class:`.Script`. Multiple heads can be present.
+
+ """
+ return not bool(self.nextrev)
+
+ @property
+ def _is_real_head(self):
+ return not bool(self._all_nextrev)
+
+ @property
+ def is_base(self):
+ """Return True if this :class:`.Revision` is a 'base' revision."""
+
+ return self.down_revision is None
+
+ @property
+ def _is_real_base(self):
+ """Return True if this :class:`.Revision` is a "real" base revision,
+ e.g. that it has no dependencies either."""
+
+ return self.down_revision is None and self.dependencies is None
+
+ @property
+ def is_branch_point(self):
+ """Return True if this :class:`.Script` is a branch point.
+
+ A branchpoint is defined as a :class:`.Script` which is referred
+ to by more than one succeeding :class:`.Script`, that is more
+ than one :class:`.Script` has a `down_revision` identifier pointing
+ here.
+
+ """
+ return len(self.nextrev) > 1
+
+ @property
+ def _is_real_branch_point(self):
+ """Return True if this :class:`.Script` is a 'real' branch point,
+ taking into account dependencies as well.
+
+ """
+ return len(self._all_nextrev) > 1
+
+ @property
+ def is_merge_point(self):
+ """Return True if this :class:`.Script` is a merge point."""
+
+ return len(self._versioned_down_revisions) > 1
+
+
+def tuple_rev_as_scalar(rev):
+ if not rev:
+ return None
+ elif len(rev) == 1:
+ return rev[0]
+ else:
+ return rev