summaryrefslogtreecommitdiff
path: root/alembic/script
diff options
context:
space:
mode:
authorSimon Bowly <simon.bowly@gmail.com>2021-02-02 22:11:15 -0500
committerMike Bayer <mike_mp@zzzcomputing.com>2021-04-26 15:21:11 -0400
commit58a170a6c6bfa5d0460c038d63b66d74a5a2c830 (patch)
tree2aa827e8d5cd52f9d3306eabb315992776159d47 /alembic/script
parentf94765f82d8770c234a2c1850116d33000fb5d75 (diff)
downloadalembic-58a170a6c6bfa5d0460c038d63b66d74a5a2c830.tar.gz
New downgrade algorithm to fix branch behaviour
The algorithm used for calculating downgrades/upgrades/iterating revisions has been rewritten, to resolve ongoing issues of branches not being handled consistently particularly within downgrade operations, as well as for overall clarity and maintainability. This change includes that a deprecation warning is emitted if an ambiguous command such as "downgrade -1" when multiple heads are present is given. In particular, the change implements a long-requested use case of allowing downgrades of a single branch to a branchpoint. Huge thanks to Simon Bowly for their impressive efforts in successfully tackling this very difficult problem. Topological algorithm written by Mike which retains the behavior of walking along a single branch as long as possible before switching to another branch due to cross-dependencies. Fixes: #765 Fixes: #464 Fixes: #603 Fixes: #660 Closes: #790 Pull-request: https://github.com/sqlalchemy/alembic/pull/790 Pull-request-sha: 60752399336b883d26f38d61a8337dd6fdd48087 Change-Id: I8ab81d241df52a6bf0a474ccee35b16aecd706ac
Diffstat (limited to 'alembic/script')
-rw-r--r--alembic/script/revision.py799
1 files changed, 537 insertions, 262 deletions
diff --git a/alembic/script/revision.py b/alembic/script/revision.py
index d131912..4db0a7b 100644
--- a/alembic/script/revision.py
+++ b/alembic/script/revision.py
@@ -143,7 +143,9 @@ class RevisionMap(object):
initial collection.
"""
- map_ = {}
+ # Ordering required for some tests to pass (but not required in
+ # general)
+ map_ = sqlautil.OrderedDict()
heads = sqlautil.OrderedSet()
_real_heads = sqlautil.OrderedSet()
@@ -455,6 +457,25 @@ class RevisionMap(object):
return sum([self.get_revisions(id_elem) for id_elem in id_], ())
else:
resolved_id, branch_label = self._resolve_revision_number(id_)
+ if len(resolved_id) == 1:
+ try:
+ rint = int(resolved_id[0])
+ if rint < 0:
+ # branch@-n -> walk down from heads
+ select_heads = self.get_revisions("heads")
+ if branch_label is not None:
+ select_heads = [
+ head
+ for head in select_heads
+ if branch_label in head.branch_labels
+ ]
+ return tuple(
+ self._walk(head, steps=rint)
+ for head in select_heads
+ )
+ except ValueError:
+ # couldn't resolve as integer
+ pass
return tuple(
self._revision_for_ident(rev_id, branch_label)
for rev_id in resolved_id
@@ -657,99 +678,6 @@ class RevisionMap(object):
else:
return util.to_tuple(id_, default=None), branch_label
- def _relative_iterate(
- self,
- destination,
- source,
- is_upwards,
- implicit_base,
- inclusive,
- assert_relative_length,
- ):
- if isinstance(destination, compat.string_types):
- match = _relative_destination.match(destination)
- if not match:
- return None
- else:
- return None
-
- relative = int(match.group(3))
- symbol = match.group(2)
- branch_label = match.group(1)
-
- reldelta = 1 if inclusive and not symbol else 0
-
- if is_upwards:
- if branch_label:
- from_ = "%s@head" % branch_label
- elif symbol:
- if symbol.startswith("head"):
- from_ = symbol
- else:
- from_ = "%s@head" % symbol
- else:
- from_ = "head"
- to_ = source
- else:
- if branch_label:
- to_ = "%s@base" % branch_label
- elif symbol:
- to_ = "%s@base" % symbol
- else:
- to_ = "base"
- from_ = source
-
- revs = list(
- self._iterate_revisions(
- from_, to_, inclusive=inclusive, implicit_base=implicit_base
- )
- )
-
- if symbol:
- if branch_label:
- symbol_rev = self.get_revision(
- "%s@%s" % (branch_label, symbol)
- )
- else:
- symbol_rev = self.get_revision(symbol)
- if symbol.startswith("head"):
- index = 0
- elif symbol == "base":
- index = len(revs) - 1
- else:
- range_ = compat.range(len(revs) - 1, 0, -1)
- for index in range_:
- if symbol_rev.revision == revs[index].revision:
- break
- else:
- index = 0
- else:
- index = 0
- if is_upwards:
- revs = revs[index - relative - reldelta :]
- if (
- not index
- and assert_relative_length
- and len(revs) < abs(relative - reldelta)
- ):
- raise RevisionError(
- "Relative revision %s didn't "
- "produce %d migrations" % (destination, abs(relative))
- )
- else:
- revs = revs[0 : index - relative + reldelta]
- if (
- not index
- and assert_relative_length
- and len(revs) != abs(relative) + reldelta
- ):
- raise RevisionError(
- "Relative revision %s didn't "
- "produce %d migrations" % (destination, abs(relative))
- )
-
- return iter(revs)
-
def iterate_revisions(
self,
upper,
@@ -770,37 +698,22 @@ class RevisionMap(object):
The iterator yields :class:`.Revision` objects.
"""
+ if select_for_downgrade:
+ fn = self._collect_downgrade_revisions
+ else:
+ fn = self._collect_upgrade_revisions
- relative_upper = self._relative_iterate(
- upper,
- lower,
- True,
- implicit_base,
- inclusive,
- assert_relative_length,
- )
- if relative_upper:
- return relative_upper
-
- relative_lower = self._relative_iterate(
- lower,
- upper,
- False,
- implicit_base,
- inclusive,
- assert_relative_length,
- )
- if relative_lower:
- return relative_lower
-
- return self._iterate_revisions(
+ revisions, heads = fn(
upper,
lower,
inclusive=inclusive,
implicit_base=implicit_base,
- select_for_downgrade=select_for_downgrade,
+ assert_relative_length=assert_relative_length,
)
+ for node in self._topological_sort(revisions, heads):
+ yield self.get_revision(node)
+
def _get_descendant_nodes(
self,
targets,
@@ -870,7 +783,14 @@ class RevisionMap(object):
if rev in seen:
continue
seen.add(rev)
- todo.extend(map_[rev_id] for rev_id in fn(rev))
+ # Check for map errors before collecting.
+ for rev_id in fn(rev):
+ next_rev = map_[rev_id]
+ if next_rev.revision != rev_id:
+ raise RevisionError(
+ "Dependency resolution failed; broken map"
+ )
+ todo.append(next_rev)
yield rev
if check:
overlaps = per_target.intersection(targets).difference(
@@ -886,176 +806,531 @@ class RevisionMap(object):
)
)
- def _iterate_revisions(
- self,
- upper,
- lower,
- inclusive=True,
- implicit_base=False,
- select_for_downgrade=False,
- ):
- """iterate revisions from upper to lower.
-
- The traversal is depth-first within branches, and breadth-first
- across branches as a whole.
+ def _topological_sort(self, revisions, heads):
+ """Yield revision ids of a collection of Revision objects in
+ topological sorted order (i.e. revisions always come after their
+ down_revisions and dependencies). Uses the order of keys in
+ _revision_map to sort.
"""
- requested_lowers = self.get_revisions(lower)
+ id_to_rev = self._revision_map
- # some complexity to accommodate an iteration where some
- # branches are starting from nothing, and others are starting
- # from a given point. Additionally, if the bottom branch
- # is specified using a branch identifier, then we limit operations
- # to just that branch.
+ def get_ancestors(rev_id):
+ return {
+ r.revision
+ for r in self._get_ancestor_nodes([id_to_rev[rev_id]])
+ }
- limit_to_lower_branch = isinstance(
- lower, compat.string_types
- ) and lower.endswith("@base")
+ todo = {d.revision for d in revisions}
- uppers = util.dedupe_tuple(self.get_revisions(upper))
+ # Use revision map (ordered dict) key order to pre-sort.
+ inserted_order = list(self._revision_map)
- if not uppers and not requested_lowers:
- return
+ current_heads = list(
+ sorted(
+ {d.revision for d in heads if d.revision in todo},
+ key=inserted_order.index,
+ )
+ )
- upper_ancestors = set(self._get_ancestor_nodes(uppers, check=True))
+ ancestors_by_idx = [get_ancestors(rev_id) for rev_id in current_heads]
- if limit_to_lower_branch:
- lowers = self.get_revisions(self._get_base_revisions(lower))
- elif implicit_base and requested_lowers:
- lower_ancestors = set(self._get_ancestor_nodes(requested_lowers))
- lower_descendants = set(
- self._get_descendant_nodes(requested_lowers)
- )
- base_lowers = set()
- candidate_lowers = upper_ancestors.difference(
- lower_ancestors
- ).difference(lower_descendants)
- for rev in candidate_lowers:
- # note: the use of _normalized_down_revisions as opposed
- # to _all_down_revisions repairs
- # an issue related to looking at a revision in isolation
- # when updating the alembic_version table (issue #789).
- # however, while it seems likely that using
- # _normalized_down_revisions within traversal is more correct
- # than _all_down_revisions, we don't yet have any case to
- # show that it actually makes a difference.
- for downrev in rev._normalized_down_revisions:
- if self._revision_map[downrev] in candidate_lowers:
- break
+ output = []
+
+ current_candidate_idx = 0
+ while current_heads:
+
+ candidate = current_heads[current_candidate_idx]
+
+ for check_head_index, ancestors in enumerate(ancestors_by_idx):
+ # scan all the heads. see if we can continue walking
+ # down the current branch indicated by current_candidate_idx.
+ if (
+ check_head_index != current_candidate_idx
+ and candidate in ancestors
+ ):
+ current_candidate_idx = check_head_index
+ # nope, another head is dependent on us, they have
+ # to be traversed first
+ break
+ else:
+ # yup, we can emit
+ if candidate in todo:
+ output.append(candidate)
+ todo.remove(candidate)
+
+ # now update the heads with our ancestors.
+
+ candidate_rev = id_to_rev[candidate]
+
+ # immediate ancestor nodes
+ heads_to_add = [
+ r
+ for r in candidate_rev._normalized_down_revisions
+ if r in todo and r not in current_heads
+ ]
+
+ if not heads_to_add:
+ # no ancestors, so remove this head from the list
+ del current_heads[current_candidate_idx]
+ del ancestors_by_idx[current_candidate_idx]
+ current_candidate_idx = max(current_candidate_idx - 1, 0)
else:
- base_lowers.add(rev)
- lowers = base_lowers.union(requested_lowers)
- elif implicit_base:
- base_lowers = set(self.get_revisions(self._real_bases))
- lowers = base_lowers.union(requested_lowers)
- elif not requested_lowers:
- lowers = set(self.get_revisions(self._real_bases))
- else:
- lowers = requested_lowers
- # represents all nodes we will produce
- total_space = set(
- rev.revision for rev in upper_ancestors
- ).intersection(
- rev.revision
- for rev in self._get_descendant_nodes(
- lowers,
- check=True,
- omit_immediate_dependencies=(
- select_for_downgrade and requested_lowers
+ if (
+ not candidate_rev._normalized_resolved_dependencies
+ and len(candidate_rev._versioned_down_revisions) == 1
+ ):
+ current_heads[current_candidate_idx] = heads_to_add[0]
+
+ # for plain movement down a revision line without
+ # any mergepoints, branchpoints, or deps, we
+ # can update the ancestors collection directly
+ # by popping out the candidate we just emitted
+ ancestors_by_idx[current_candidate_idx].discard(
+ candidate
+ )
+
+ else:
+ # otherwise recalculate it again, things get
+ # complicated otherwise. This can possibly be
+ # improved to not run the whole ancestor thing
+ # each time but it was getting complicated
+ current_heads[current_candidate_idx] = heads_to_add[0]
+ current_heads.extend(heads_to_add[1:])
+ ancestors_by_idx[
+ current_candidate_idx
+ ] = get_ancestors(heads_to_add[0])
+ ancestors_by_idx.extend(
+ get_ancestors(head) for head in heads_to_add[1:]
+ )
+
+ assert not todo
+ return output
+
+ def _walk(self, start, steps, branch_label=None, no_overwalk=True):
+ """
+ Walk the requested number of :steps up (steps > 0) or down (steps < 0)
+ the revision tree.
+
+ :branch_label is used to select branches only when walking up.
+
+ If the walk goes past the boundaries of the tree and :no_overwalk is
+ True, None is returned, otherwise the walk terminates early.
+
+ A RevisionError is raised if there is no unambiguous revision to
+ walk to.
+ """
+
+ if isinstance(start, compat.string_types):
+ start = self.get_revision(start)
+
+ for _ in range(abs(steps)):
+ if steps > 0:
+ # Walk up
+ children = [
+ rev
+ for rev in self.get_revisions(
+ self.bases if start is None else start.nextrev
+ )
+ ]
+ if branch_label:
+ children = self.filter_for_lineage(children, branch_label)
+ else:
+ # Walk down
+ if start == "base":
+ children = tuple()
+ else:
+ children = self.get_revisions(
+ self.heads if start is None else start.down_revision
+ )
+ if not children:
+ children = ("base",)
+ if not children:
+ # This will return an invalid result if no_overwalk, otherwise
+ # further steps will stay where we are.
+ return None if no_overwalk else start
+ elif len(children) > 1:
+ raise RevisionError("Ambiguous walk")
+ start = children[0]
+
+ return start
+
+ def _parse_downgrade_target(
+ self, current_revisions, target, assert_relative_length
+ ):
+ """
+ Parse downgrade command syntax :target to retrieve the target revision
+ and branch label (if any) given the :current_revisons stamp of the
+ database.
+
+ Returns a tuple (branch_label, target_revision) where branch_label
+ is a string from the command specifying the branch to consider (or
+ None if no branch given), and target_revision is a Revision object
+ which the command refers to. target_revsions is None if the command
+ refers to 'base'. The target may be specified in absolute form, or
+ relative to :current_revisions.
+ """
+ if target is None:
+ return None, None
+ assert isinstance(
+ target, compat.string_types
+ ), "Expected downgrade target in string form"
+ match = _relative_destination.match(target)
+ if match:
+ branch_label, symbol, relative = match.groups()
+ rel_int = int(relative)
+ if rel_int >= 0:
+ if symbol is None:
+ # Downgrading to current + n is not valid.
+ raise RevisionError(
+ "Relative revision %s didn't "
+ "produce %d migrations" % (relative, abs(rel_int))
+ )
+ # Find target revision relative to given symbol.
+ rev = self._walk(
+ symbol,
+ rel_int,
+ branch_label,
+ no_overwalk=assert_relative_length,
+ )
+ if rev is None:
+ raise RevisionError("Walked too far")
+ return branch_label, rev
+ else:
+ relative_revision = symbol is None
+ if relative_revision:
+ # Find target revision relative to current state.
+ if branch_label:
+ symbol = self.filter_for_lineage(
+ util.to_tuple(current_revisions), branch_label
+ )
+ assert len(symbol) == 1
+ symbol = symbol[0]
+ else:
+ current_revisions = util.to_tuple(current_revisions)
+ if not current_revisions:
+ raise RevisionError(
+ "Relative revision %s didn't "
+ "produce %d migrations"
+ % (relative, abs(rel_int))
+ )
+ # Have to check uniques here for duplicate rows test.
+ if len(set(current_revisions)) > 1:
+ util.warn(
+ "downgrade -1 from multiple heads is "
+ "ambiguous; "
+ "this usage will be disallowed in a future "
+ "release."
+ )
+ symbol = current_revisions[0]
+ # Restrict iteration to just the selected branch when
+ # ambiguous branches are involved.
+ branch_label = symbol
+ # Walk down the tree to find downgrade target.
+ rev = self._walk(
+ start=self.get_revision(symbol)
+ if branch_label is None
+ else self.get_revision("%s@%s" % (branch_label, symbol)),
+ steps=rel_int,
+ no_overwalk=assert_relative_length,
+ )
+ if rev is None:
+ if relative_revision:
+ raise RevisionError(
+ "Relative revision %s didn't "
+ "produce %d migrations" % (relative, abs(rel_int))
+ )
+ else:
+ raise RevisionError("Walked too far")
+ return branch_label, rev
+
+ # No relative destination given, revision specified is absolute.
+ branch_label, _, symbol = target.rpartition("@")
+ if not branch_label:
+ branch_label is None
+ return branch_label, self.get_revision(symbol)
+
+ def _parse_upgrade_target(
+ self, current_revisions, target, assert_relative_length
+ ):
+ """
+ Parse upgrade command syntax :target to retrieve the target revision
+ and given the :current_revisons stamp of the database.
+
+ Returns a tuple of Revision objects which should be iterated/upgraded
+ to. The target may be specified in absolute form, or relative to
+ :current_revisions.
+ """
+ if isinstance(target, compat.string_types):
+ match = _relative_destination.match(target)
+ else:
+ match = None
+
+ if not match:
+ # No relative destination, target is absolute.
+ return self.get_revisions(target)
+
+ current_revisions = util.to_tuple(current_revisions)
+
+ branch_label, symbol, relative = match.groups()
+ relative_str = relative
+ relative = int(relative)
+ if relative > 0:
+ if symbol is None:
+ if not current_revisions:
+ current_revisions = (None,)
+ # Try to filter to a single target (avoid ambiguous branches).
+ start_revs = current_revisions
+ if branch_label:
+ start_revs = self.filter_for_lineage(
+ self.get_revisions(current_revisions), branch_label
+ )
+ if not start_revs:
+ # The requested branch is not a head, so we need to
+ # backtrack to find a branchpoint.
+ active_on_branch = self.filter_for_lineage(
+ self._get_ancestor_nodes(
+ self.get_revisions(current_revisions)
+ ),
+ branch_label,
+ )
+ # Find the tips of this set of revisions (revisions
+ # without children within the set).
+ start_revs = tuple(
+ {rev.revision for rev in active_on_branch}
+ - {
+ down
+ for rev in active_on_branch
+ for down in rev._normalized_down_revisions
+ }
+ )
+ if not start_revs:
+ # We must need to go right back to base to find
+ # a starting point for this branch.
+ start_revs = (None,)
+ if len(start_revs) > 1:
+ raise RevisionError(
+ "Ambiguous upgrade from multiple current revisions"
+ )
+ # Walk up from unique target revision.
+ rev = self._walk(
+ start=start_revs[0],
+ steps=relative,
+ branch_label=branch_label,
+ no_overwalk=assert_relative_length,
+ )
+ if rev is None:
+ raise RevisionError(
+ "Relative revision %s didn't "
+ "produce %d migrations" % (relative_str, abs(relative))
+ )
+ return (rev,)
+ else:
+ # Walk is relative to a given revision, not the current state.
+ return (
+ self._walk(
+ start=self.get_revision(symbol),
+ steps=relative,
+ branch_label=branch_label,
+ no_overwalk=assert_relative_length,
+ ),
+ )
+ else:
+ if symbol is None:
+ # Upgrading to current - n is not valid.
+ raise RevisionError(
+ "Relative revision %s didn't "
+ "produce %d migrations" % (relative, abs(relative))
+ )
+ return (
+ self._walk(
+ start=self.get_revision(symbol)
+ if branch_label is None
+ else self.get_revision("%s@%s" % (branch_label, symbol)),
+ steps=relative,
+ no_overwalk=assert_relative_length,
),
)
+
+ def _collect_downgrade_revisions(
+ self, upper, target, inclusive, implicit_base, assert_relative_length
+ ):
+ """
+ Compute the set of current revisions specified by :upper, and the
+ downgrade target specified by :target. Return all dependents of target
+ which are currently active.
+
+ :inclusive=True includes the target revision in the set
+ """
+
+ branch_label, target_revision = self._parse_downgrade_target(
+ current_revisions=upper,
+ target=target,
+ assert_relative_length=assert_relative_length,
)
+ if target_revision == "base":
+ target_revision = None
+ assert target_revision is None or isinstance(target_revision, Revision)
+
+ # Find candidates to drop.
+ if target_revision is None:
+ # Downgrading back to base: find all tree roots.
+ roots = [
+ rev
+ for rev in self._revision_map.values()
+ if rev is not None and rev.down_revision is None
+ ]
+ elif inclusive:
+ # inclusive implies target revision should also be dropped
+ roots = [target_revision]
+ else:
+ # Downgrading to fixed target: find all direct children.
+ roots = list(self.get_revisions(target_revision.nextrev))
- if not total_space:
- # no nodes. determine if this is an invalid range
- # or not.
- start_from = set(requested_lowers)
- start_from.update(
- self._get_ancestor_nodes(
- list(start_from), include_dependencies=True
+ if branch_label and len(roots) > 1:
+ # Need to filter roots.
+ ancestors = {
+ rev.revision
+ for rev in self._get_ancestor_nodes(
+ [self._resolve_branch(branch_label)],
+ include_dependencies=False,
+ )
+ }
+ # Intersection gives the root revisions we are trying to
+ # rollback with the downgrade.
+ roots = list(
+ self.get_revisions(
+ {rev.revision for rev in roots}.intersection(ancestors)
)
)
- # determine all the current branch points represented
- # by requested_lowers
- start_from = self._filter_into_branch_heads(start_from)
+ # Ensure we didn't throw everything away.
+ if len(roots) == 0:
+ raise RevisionError(
+ "Not a valid downgrade target from current heads"
+ )
- # if the requested start is one of those branch points,
- # then just return empty set
- if start_from.intersection(upper_ancestors):
- return
- else:
- # otherwise, they requested nodes out of
- # order
- raise RangeNotAncestorError(lower, upper)
-
- # organize branch points to be consumed separately from
- # member nodes
- branch_todo = set(
- rev
- for rev in (self._revision_map[rev] for rev in total_space)
- if rev._is_real_branch_point
- and len(total_space.intersection(rev._all_nextrev)) > 1
+ heads = self.get_revisions(upper)
+
+ # Aim is to drop :branch_revision; to do so we also need to drop its
+ # descendents and anything dependent on it.
+ downgrade_revisions = set(
+ self._get_descendant_nodes(
+ roots,
+ include_dependencies=True,
+ omit_immediate_dependencies=False,
+ )
+ )
+ active_revisions = set(
+ self._get_ancestor_nodes(heads, include_dependencies=True)
)
+ # Emit revisions to drop in reverse topological sorted order.
+ downgrade_revisions.intersection_update(active_revisions)
+
+ if implicit_base:
+ # Wind other branches back to base.
+ downgrade_revisions.update(
+ active_revisions.difference(self._get_ancestor_nodes(roots))
+ )
+
+ if not downgrade_revisions:
+ # Empty intersection: target revs are not present.
+ raise RangeNotAncestorError("Nothing to drop", upper)
- # it's not possible for any "uppers" to be in branch_todo,
- # because the ._all_nextrev of those nodes is not in total_space
- # assert not branch_todo.intersection(uppers)
+ return downgrade_revisions, heads
- todo = collections.deque(
- r for r in uppers if r.revision in total_space
+ def _collect_upgrade_revisions(
+ self, upper, lower, inclusive, implicit_base, assert_relative_length
+ ):
+ """
+ Compute the set of required revisions specified by :upper, and the
+ current set of active revisions specified by :lower. Find the
+ difference between the two to compute the required upgrades.
+
+ :inclusive=True includes the current/lower revisions in the set
+
+ :implicit_base=False only returns revisions which are downstream
+ of the current/lower revisions. Dependencies from branches with
+ different bases will not be included.
+ """
+ targets = self._parse_upgrade_target(
+ current_revisions=lower,
+ target=upper,
+ assert_relative_length=assert_relative_length,
)
- # iterate for total_space being emptied out
- total_space_modified = True
- while total_space:
+ assert targets is not None
+ assert type(targets) is tuple, "targets should be a tuple"
+
+ # Handled named bases (e.g. branch@... -> heads should only produce
+ # targets on the given branch)
+ if isinstance(lower, compat.string_types) and "@" in lower:
+ branch, _, _ = lower.partition("@")
+ branch_rev = self.get_revision(branch)
+ if branch_rev is not None and branch_rev.revision == branch:
+ # A revision was used as a label; get its branch instead
+ assert len(branch_rev.branch_labels) == 1
+ branch = next(iter(branch_rev.branch_labels))
+ targets = {
+ need for need in targets if branch in need.branch_labels
+ }
+
+ required_node_set = set(
+ self._get_ancestor_nodes(
+ targets, check=True, include_dependencies=True
+ )
+ ).union(targets)
- if not total_space_modified:
- raise RevisionError(
- "Dependency resolution failed; iteration can't proceed"
- )
- total_space_modified = False
- # when everything non-branch pending is consumed,
- # add to the todo any branch nodes that have no
- # descendants left in the queue
- if not todo:
- todo.extendleft(
- sorted(
- (
- rev
- for rev in branch_todo
- if not rev._all_nextrev.intersection(total_space)
- ),
- # favor "revisioned" branch points before
- # dependent ones
- key=lambda rev: 0 if rev.is_branch_point else 1,
- )
- )
- branch_todo.difference_update(todo)
- # iterate nodes that are in the immediate todo
- while todo:
- rev = todo.popleft()
- total_space.remove(rev.revision)
- total_space_modified = True
-
- # do depth first for elements within branches,
- # don't consume any actual branch nodes
- todo.extendleft(
- [
- self._revision_map[downrev]
- for downrev in reversed(rev._normalized_down_revisions)
- if self._revision_map[downrev] not in branch_todo
- and downrev in total_space
- ]
- )
+ current_revisions = self.get_revisions(lower)
+ if not implicit_base and any(
+ rev not in required_node_set
+ for rev in current_revisions
+ if rev is not None
+ ):
+ raise RangeNotAncestorError(lower, upper)
+ assert (
+ type(current_revisions) is tuple
+ ), "current_revisions should be a tuple"
+
+ # Special case where lower = a relative value (get_revisions can't
+ # find it)
+ if current_revisions and current_revisions[0] is None:
+ _, rev = self._parse_downgrade_target(
+ current_revisions=upper,
+ target=lower,
+ assert_relative_length=assert_relative_length,
+ )
+ if rev == "base":
+ current_revisions = tuple()
+ lower = None
+ else:
+ current_revisions = (rev,)
+ lower = rev.revision
- if not inclusive and rev in requested_lowers:
- continue
- yield rev
+ current_node_set = set(
+ self._get_ancestor_nodes(
+ current_revisions, check=True, include_dependencies=True
+ )
+ ).union(current_revisions)
+
+ needs = required_node_set.difference(current_node_set)
+
+ # Include the lower revision (=current_revisions?) in the iteration
+ if inclusive:
+ needs.update(self.get_revisions(lower))
+ # By default, base is implicit as we want all dependencies returned.
+ # Base is also implicit if lower = base
+ # implicit_base=False -> only return direct downstreams of
+ # current_revisions
+ if current_revisions and not implicit_base:
+ lower_descendents = self._get_descendant_nodes(
+ current_revisions, check=True, include_dependencies=False
+ )
+ needs.intersection_update(lower_descendents)
- assert not branch_todo
+ return needs, targets
class Revision(object):