diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2014-11-15 19:18:02 -0500 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2014-11-15 19:18:02 -0500 |
commit | 755bea9ab281e929301ab1c554635a1d330dcde8 (patch) | |
tree | b38ca55f1bf234158d31351dc39767b97be76bc7 | |
parent | 8839aaba20fbb94982f0697e9edc921cb21c8727 (diff) | |
download | alembic-755bea9ab281e929301ab1c554635a1d330dcde8.tar.gz |
- working out how to ensure upgrades/downgrades can always proceed along a single
branch, or all branches
-rw-r--r-- | alembic/revision.py | 84 | ||||
-rw-r--r-- | tests/test_revision.py | 110 |
2 files changed, 126 insertions, 68 deletions
diff --git a/alembic/revision.py b/alembic/revision.py index d082296..f4935e8 100644 --- a/alembic/revision.py +++ b/alembic/revision.py @@ -160,10 +160,7 @@ class RevisionMap(object): """ current_heads = self.heads if branch_name: - current_heads = [ - h for h in current_heads - if self._shares_lineage(h, branch_name) - ] + current_heads = self.filter_for_lineage(current_heads, branch_name) if len(current_heads) > 1: raise MultipleHeads("Multiple heads are present") @@ -189,10 +186,13 @@ class RevisionMap(object): full revision. """ - resolved_id, branch_name = self._resolve_revision_number(id_) - return tuple( - self._revision_for_ident(rev_id, branch_name) - for rev_id in resolved_id) + if isinstance(id_, (list, tuple)): + return sum([self.get_revisions(id_elem) for id_elem in id_], ()) + else: + resolved_id, branch_name = self._resolve_revision_number(id_) + return tuple( + self._revision_for_ident(rev_id, branch_name) + for rev_id in resolved_id) def get_revision(self, id_): """Return the :class:`.Revision` instance with the given rev id. @@ -243,9 +243,7 @@ class RevisionMap(object): revs = [x for x in self._revision_map if x and x.startswith(resolved_id)] if branch_rev: - revs = [ - x for x in revs if - self._shares_lineage(x, branch_rev)] + revs = self.filter_for_lineage(revs, check_branch) if not revs: raise ResolutionError("No such revision '%s'" % resolved_id) elif len(revs) > 1: @@ -267,28 +265,24 @@ class RevisionMap(object): return revision def filter_for_lineage(self, targets, check_against): - if not isinstance(check_against, Revision): - check_against = self.get_revisions(check_against) + id_, branch_name = self._resolve_revision_number(check_against) return [ tg for tg in targets - if self._shares_lineage(tg, check_against)] + if self._shares_lineage(tg, branch_name or id_[0])] - def _shares_lineage(self, target, test_against_revs): - if not test_against_revs: + def _shares_lineage(self, target, test_against_rev): + if not test_against_rev: return True if not isinstance(target, Revision): target = self._revision_for_ident(target) - test_against_revs = [ - self._revision_for_ident(tr) - if not isinstance(tr, Revision) else tr - for tr in util.to_tuple(test_against_revs, ()) - ] + if not isinstance(test_against_rev, Revision): + test_against_rev = self._revision_for_ident(test_against_rev) return bool( self._get_descendant_nodes([target]) .union(self._get_ancestor_nodes([target])) - .intersection(test_against_revs) + .intersection([test_against_rev]) ) def _resolve_revision_number(self, id_): @@ -311,6 +305,7 @@ class RevisionMap(object): elif id_ == 'base' or id_ is None: return (), branch_name else: + assert isinstance(id_, compat.string_types) return util.to_tuple(id_, default=None), branch_name def iterate_revisions(self, upper, lower): @@ -366,8 +361,8 @@ class RevisionMap(object): tg for tg in targets if tg is not target ): raise RevisionError( - "Requested base revision %s overlaps with " - "other requested base revisions" % target.revision) + "Requested revision %s overlaps with " + "other requested revisions" % target.revision) total_descendants.update(descendants) return total_descendants @@ -387,8 +382,8 @@ class RevisionMap(object): tg for tg in targets if tg is not target ): raise RevisionError( - "Requested head revision %s overlaps with " - "other requested head revisions" % target.revision) + "Requested revision %s overlaps with " + "other requested revisions" % target.revision) total_ancestors.update(ancestors) return total_ancestors @@ -399,12 +394,34 @@ class RevisionMap(object): across branches as a whole. """ - lowers = self.get_revisions(lower) - if not lowers: # lower of None or (), we go to the bases. - lowers = self.get_revisions(self._get_base_revisions(lower)) - inclusive = True + + requested_lowers = self.get_revisions(lower) + + # some complexity to accommodate an iteration where some + # branches are starting from nothing, and others are starting + # from a given point. Additionally, if the bottom branch + # is specified using a branch identifier, then we limit operations + # to just that branch. + + limit_to_lower_branch = \ + isinstance(lower, compat.string_types) and '@' in lower + + if not limit_to_lower_branch or not requested_lowers: + if not requested_lowers and limit_to_lower_branch: + base_lowers = self.get_revisions( + self._get_base_revisions(lower)) + lowers = base_lowers + else: + base_lowers = set(self.get_revisions(self.bases)) + base_lowers.difference_update( + self._get_ancestor_nodes(requested_lowers)) + lowers = base_lowers.union(requested_lowers) + else: + base_lowers = set() + lowers = requested_lowers uppers = self.get_revisions(upper) + total_space = set( rev.revision for rev in self._get_ancestor_nodes(uppers) @@ -453,8 +470,11 @@ class RevisionMap(object): and downrev in total_space ]) - if inclusive or rev not in lowers: - yield rev + if not inclusive and rev in requested_lowers: + continue + #if rev in base_lowers + #if inclusive or rev not in lowers: + yield rev class Revision(object): diff --git a/tests/test_revision.py b/tests/test_revision.py index 3468741..8e47f64 100644 --- a/tests/test_revision.py +++ b/tests/test_revision.py @@ -63,7 +63,7 @@ class APITest(TestBase): ) assert_raises_message( MultipleHeads, - "Multiple heads are present; please use current_heads()", + "Multiple heads are present", map_.get_revision, 'head' ) @@ -80,10 +80,13 @@ class APITest(TestBase): class DownIterateTest(TestBase): - def _assert_iteration(self, upper, lower, assertion, inclusive=True): + def _assert_iteration( + self, upper, lower, assertion, inclusive=True, map_=None): + if map_ is None: + map_ = self.map eq_( [rev.revision for rev in - self.map._iterate_revisions(upper, lower, inclusive=inclusive)], + map_._iterate_revisions(upper, lower, inclusive=inclusive)], assertion ) @@ -107,7 +110,7 @@ class DiamondTest(DownIterateTest): ) -class NamedBranchTest(TestBase): +class NamedBranchTest(DownIterateTest): def test_dupe_branch_collection(self): fn = lambda: [ Revision('a', ()), @@ -124,7 +127,7 @@ class NamedBranchTest(TestBase): ) def setUp(self): - self.map_ = RevisionMap(lambda: [ + self.map = RevisionMap(lambda: [ Revision('a', (), branch_names='abranch'), Revision('b', ('a',)), Revision('somelongername', ('b',)), @@ -135,62 +138,84 @@ class NamedBranchTest(TestBase): Revision('f', ('someothername',)), ]) + def test_iterate_head_to_named_base(self): + self._assert_iteration( + "heads", "ebranch@base", + ['f', 'someothername', 'e', 'd'] + ) + + self._assert_iteration( + "heads", "abranch@base", + ['c', 'somelongername', 'b', 'a'] + ) + + def test_iterate_head_to_version_specific_base(self): + self._assert_iteration( + "heads", "e@base", + ['f', 'someothername', 'e', 'd'] + ) + + self._assert_iteration( + "heads", "c@base", + ['c', 'somelongername', 'b', 'a'] + ) + def test_partial_id_resolve(self): - eq_(self.map_.get_revision("ebranch@some").revision, "someothername") - eq_(self.map_.get_revision("abranch@some").revision, "somelongername") + eq_(self.map.get_revision("ebranch@some").revision, "someothername") + eq_(self.map.get_revision("abranch@some").revision, "somelongername") def test_branch_at_heads(self): assert_raises_message( RevisionError, "Branch name given with 'heads' makes no sense", - self.map_.get_revision, "abranch@heads" + self.map.get_revision, "abranch@heads" ) def test_branch_at_syntax(self): - eq_(self.map_.get_revision("abranch@head").revision, 'c') - eq_(self.map_.get_revision("abranch@base"), None) - eq_(self.map_.get_revision("ebranch@head").revision, 'f') - eq_(self.map_.get_revision("abranch@base"), None) - eq_(self.map_.get_revision("ebranch@d").revision, 'd') + eq_(self.map.get_revision("abranch@head").revision, 'c') + eq_(self.map.get_revision("abranch@base"), None) + eq_(self.map.get_revision("ebranch@head").revision, 'f') + eq_(self.map.get_revision("abranch@base"), None) + eq_(self.map.get_revision("ebranch@d").revision, 'd') def test_branch_at_self(self): - eq_(self.map_.get_revision("ebranch@ebranch").revision, 'e') + eq_(self.map.get_revision("ebranch@ebranch").revision, 'e') def test_retrieve_branch_revision(self): - eq_(self.map_.get_revision("abranch").revision, 'a') - eq_(self.map_.get_revision("ebranch").revision, 'e') + eq_(self.map.get_revision("abranch").revision, 'a') + eq_(self.map.get_revision("ebranch").revision, 'e') def test_rev_not_in_branch(self): assert_raises_message( RevisionError, "Revision b is not a member of branch 'ebranch'", - self.map_.get_revision, "ebranch@b" + self.map.get_revision, "ebranch@b" ) assert_raises_message( RevisionError, "Revision d is not a member of branch 'abranch'", - self.map_.get_revision, "abranch@d" + self.map.get_revision, "abranch@d" ) def test_no_revision_exists(self): assert_raises_message( RevisionError, "No such revision 'q'", - self.map_.get_revision, "abranch@q" + self.map.get_revision, "abranch@q" ) def test_not_actually_a_branch(self): - eq_(self.map_.get_revision("e@d").revision, "d") + eq_(self.map.get_revision("e@d").revision, "d") def test_not_actually_a_branch_partial_resolution(self): - eq_(self.map_.get_revision("someoth@d").revision, "d") + eq_(self.map.get_revision("someoth@d").revision, "d") def test_no_such_branch(self): assert_raises_message( RevisionError, "No such branch: 'x'", - self.map_.get_revision, "x@d" + self.map.get_revision, "x@d" ) @@ -316,6 +341,14 @@ class BranchTravellingTest(DownIterateTest): ] ) + def test_three_branches_end_in_single_branch(self): + + self._assert_iteration( + ["merge", "fe1b1"], "a3", + ['merge', 'e2b1', 'e2b2', 'db2', 'cb2', 'b2', + 'fe1b1', 'e1b1', 'db1', 'cb1', 'b1', 'a3'] + ) + def test_two_branches_to_root(self): # here we want 'a3' as a "stop" branch point, but *not* @@ -368,16 +401,6 @@ class BranchTravellingTest(DownIterateTest): ] # noqa ) - def test_three_branches_end_in_single_branch(self): - - # in this case, both "a3" and "db1" are stop points - self._assert_iteration( - ["merge", "fe1b1"], "e1b1", - ['merge', - 'fe1b1', 'e1b1', # fe1b1 branch - ] # noqa - ) - def test_three_branches_end_multiple_bases(self): # in this case, both "a3" and "db1" are stop points @@ -411,8 +434,8 @@ class BranchTravellingTest(DownIterateTest): # db1 is an ancestor of fe1b1 assert_raises_message( RevisionError, - "Requested head revision fe1b1 overlaps " - "with other requested head revisions", + "Requested revision fe1b1 overlaps " + "with other requested revisions", list, self.map._iterate_revisions(["db1", "b2", "fe1b1"], ()) ) @@ -505,8 +528,23 @@ class MultipleBaseTest(DownIterateTest): def test_detect_invalid_base_selection(self): assert_raises_message( RevisionError, - "Requested base revision a2 overlaps with " - "other requested base revisions", + "Requested revision b2 overlaps with " + "other requested revisions", list, self.map._iterate_revisions(["c2"], ["a2", "b2"]) ) + + def test_heads_to_revs_plus_base_exclusive(self): + self._assert_iteration( + "heads", ["c2", "base"], + [ + 'b1a', 'a1a', + 'b1b', 'a1b', + 'mergeb3d2', + 'b3', 'a3', 'base3', + 'd2', + 'base1' + ], + inclusive=False + ) + |