diff options
author | Olly Cope <olly@ollycope.com> | 2022-08-29 12:50:50 +0000 |
---|---|---|
committer | Olly Cope <olly@ollycope.com> | 2022-08-29 12:50:50 +0000 |
commit | 18fd8dbb678ffb818f57904dcc5b39011ab2ce38 (patch) | |
tree | 24dedf26ccd647a9b455c4cb739fd21946e370cb /yoyo | |
parent | 6dc7bcefcdcc93a5f00e837d4f34651d8f170390 (diff) | |
parent | 8eced158d43573528e12b95893f527d684a3196b (diff) | |
download | yoyo-18fd8dbb678ffb818f57904dcc5b39011ab2ce38.tar.gz |
merge new toposort algorithm
Diffstat (limited to 'yoyo')
-rwxr-xr-x | yoyo/migrations.py | 83 | ||||
-rw-r--r-- | yoyo/tests/test_migrations.py | 113 | ||||
-rw-r--r-- | yoyo/tests/test_topologicalsort.py | 77 | ||||
-rw-r--r-- | yoyo/topologicalsort.py | 79 |
4 files changed, 180 insertions, 172 deletions
diff --git a/yoyo/migrations.py b/yoyo/migrations.py index a34d262..2c067ee 100755 --- a/yoyo/migrations.py +++ b/yoyo/migrations.py @@ -15,7 +15,6 @@ from collections import Counter from collections import OrderedDict from collections import abc -from collections import defaultdict from copy import copy from glob import glob from itertools import chain @@ -719,72 +718,18 @@ def heads(migration_list): return heads -def topological_sort( - migration_list: Iterable[Migration], -) -> Iterable[Migration]: - - # Make a copy of migration_list. It's probably an iterator. - migration_list = list(migration_list) - - # Track graph edges in two parallel data structures. - # Use OrderedDict so that we can traverse edges in order - # and keep the sort stable - forward_edges = defaultdict( - OrderedDict - ) # type: Dict[Migration, Dict[Migration, int]] - backward_edges = defaultdict( - OrderedDict - ) # type: Dict[Migration, Dict[Migration, int]] - - def sort_by_stability_order( - items, ordering={m: index for index, m in enumerate(migration_list)} - ): - return sorted( - (item for item in items if item in ordering), key=ordering.get - ) +def topological_sort(migrations: Iterable[Migration]) -> Iterable[Migration]: + from yoyo.topologicalsort import topological_sort as topological_sort_impl + from yoyo.topologicalsort import CycleError - for m in migration_list: - for n in sort_by_stability_order(m.depends): - forward_edges[n][m] = 1 - backward_edges[m][n] = 1 - - def check_cycles(item): - stack: List[Tuple[Migration, List[Migration]]] = [(item, [])] - while stack: - n, path = stack.pop() - if n in path: - raise exceptions.BadMigration( - "Circular dependencies among these migrations {}".format( - ", ".join(m.id for m in path + [n]) - ) - ) - stack.extend((f, path + [n]) for f in forward_edges[n]) - - seen = set() - for item in migration_list: - - if item in seen: - continue - - check_cycles(item) - - # if item is in a dedepency graph, go back to the root node - while backward_edges[item]: - item = next(iter(backward_edges[item])) - - # is item at the start of a dependency graph? - if forward_edges[item]: - stack = [item] - while stack: - m = stack.pop() - yield m - seen.add(m) - for child in list(reversed(list(forward_edges[m]))): - if all( - dependency in seen - for dependency in backward_edges[child] - ): - stack.append(child) - else: - yield item - seen.add(item) + migration_list = list(migrations) + all_migrations = set(migration_list) + dependency_graph = {m: (m.depends & all_migrations) for m in migration_list} + try: + return topological_sort_impl(migration_list, dependency_graph) + except CycleError as e: + raise exceptions.BadMigration( + "Circular dependencies among these migrations {}".format( + ", ".join(m.id for m in e.args[1]) + ) + ) diff --git a/yoyo/tests/test_migrations.py b/yoyo/tests/test_migrations.py index bdfab00..09f94ba 100644 --- a/yoyo/tests/test_migrations.py +++ b/yoyo/tests/test_migrations.py @@ -16,7 +16,6 @@ from datetime import datetime from datetime import timedelta from mock import Mock, patch import io -import itertools import os import pytest @@ -28,7 +27,7 @@ from yoyo import ancestors, descendants from yoyo.tests import migrations_dir from yoyo.tests import tempdir -from yoyo.migrations import topological_sort, MigrationList +from yoyo.migrations import MigrationList from yoyo.scripts import newmigration @@ -269,79 +268,6 @@ def test_grouped_migrations_can_be_rolled_back(backend): backend.rollback_migrations(read_migrations(t1)) -class TestTopologicalSort(object): - def check(self, nodes, edges, expected_order): - migrations = self.get_mock_migrations(nodes, edges) - output_order = "".join(m.id for m in topological_sort(migrations)) - assert output_order == expected_order - - def get_mock_migrations(self, nodes="ABCD", edges=[]): - class MockMigration(Mock): - def __repr__(self): - return "<MockMigration {}>".format(self.id) - - migrations = {n: MockMigration(id=n, depends=set()) for n in nodes} - for g in edges: - for edge in zip(g, g[1:]): - migrations[edge[1]].depends.add(migrations[edge[0]]) - - return [migrations[n] for n in nodes] - - def test_it_keeps_stable_order(self): - - for s in itertools.permutations("ABCD"): - self.check(s, {}, "".join(s)) - - def test_it_sorts_topologically(self): - - # Single group at start - self.check("ABCD", {"AB"}, "ABCD") - self.check("BACD", {"AB"}, "ABCD") - - # Single group in middle start - self.check("CABD", {"AB"}, "CABD") - self.check("CBAD", {"AB"}, "CABD") - - # Extended group - self.check("ABCD", {"AB", "AD"}, "ABDC") - self.check("DBCA", {"AB", "AD"}, "ADBC") - - # Non-connected groups - self.check("ABCDEF", {"CB", "ED"}, "ACBEDF") - self.check("ADEBCF", {"CB", "ED"}, "AEDCBF") - self.check("ADEFBC", {"CB", "ED"}, "AEDFCB") - self.check("DBAFEC", {"CB", "ED"}, "EDCBAF") - - def test_it_discards_missing_dependencies(self): - A, B, C, D = self.get_mock_migrations() - C.depends.add(Mock()) - assert list(topological_sort([A, B, C, D])) == [A, B, C, D] - - def test_it_catches_cycles(self): - A, B, C, D = self.get_mock_migrations() - C.depends.add(C) - with pytest.raises(exceptions.BadMigration): - self.check("ABCD", {"AA"}, "") - with pytest.raises(exceptions.BadMigration): - self.check("ABCD", {"AB", "BA"}, "") - with pytest.raises(exceptions.BadMigration): - self.check("ABCD", {"AB", "BC", "CB"}, "") - with pytest.raises(exceptions.BadMigration): - self.check("ABCD", {"AB", "BC", "CA"}, "") - - def test_it_handles_multiple_edges_to_the_same_node(self): - self.check("ABCD", {"AB", "AC", "AD"}, "ABCD") - self.check("DCBA", {"AB", "AC", "AD"}, "ADCB") - - def test_it_handles_multiple_edges_to_the_same_node2(self): - # A --> B - # | ^ - # v | - # C --- + - for input_order in itertools.permutations("ABC"): - self.check(input_order, {"AB", "AC", "CB"}, "ACB") - - class TestMigrationList(object): def test_can_create_empty(self): m = MigrationList() @@ -420,9 +346,7 @@ class TestReadMigrations(object): The yoyo new command creates temporary files in the migrations directory. These shouldn't be picked up by yoyo apply etc """ - with migrations_dir( - **{newmigration.tempfile_prefix + "test": ""} - ) as tmpdir: + with migrations_dir(**{newmigration.tempfile_prefix + "test": ""}) as tmpdir: assert len(read_migrations(tmpdir)) == 0 def test_it_loads_post_apply_scripts(self): @@ -441,12 +365,8 @@ class TestReadMigrations(object): m.load() assert len(m.steps) == 1 - def test_it_does_not_add_duplicate_steps_with_imported_symbols( - self, tmpdir - ): - with migrations_dir( - a="from yoyo import step; step('SELECT 1')" - ) as tmpdir: + def test_it_does_not_add_duplicate_steps_with_imported_symbols(self, tmpdir): + with migrations_dir(a="from yoyo import step; step('SELECT 1')") as tmpdir: m = read_migrations(tmpdir)[0] m.load() assert len(m.steps) == 1 @@ -535,9 +455,7 @@ class TestReadMigrations(object): def test_it_sets_depends_for_sql_migrations(self): def check(sql, expected): - with migrations_dir( - **{"1.sql": "", "2.sql": "", "3.sql": sql} - ) as tmp: + with migrations_dir(**{"1.sql": "", "2.sql": "", "3.sql": sql}) as tmp: migration = read_migrations(tmp)[-1] migration.load() @@ -638,27 +556,20 @@ class TestLogging(object): "from _yoyo_log " "ORDER BY id DESC LIMIT 1" ) - return { - d[0]: value - for d, value in zip(cursor.description, cursor.fetchone()) - } + return {d[0]: value for d, value in zip(cursor.description, cursor.fetchone())} def get_log_count(self, backend): return backend.execute("SELECT count(1) FROM _yoyo_log").fetchone()[0] def test_it_logs_apply_and_rollback(self, backend): - with migrations_dir( - a='step("CREATE TABLE yoyo_test (id INT)")' - ) as tmpdir: + with migrations_dir(a='step("CREATE TABLE yoyo_test (id INT)")') as tmpdir: migrations = read_migrations(tmpdir) backend.apply_migrations(migrations) assert self.get_log_count(backend) == 1 logged = self.get_last_log_entry(backend) assert logged["migration_id"] == "a" assert logged["operation"] == "apply" - assert logged["created_at_utc"] >= datetime.utcnow() - timedelta( - seconds=3 - ) + assert logged["created_at_utc"] >= datetime.utcnow() - timedelta(seconds=3) apply_time = logged["created_at_utc"] backend.rollback_migrations(migrations) @@ -669,18 +580,14 @@ class TestLogging(object): assert logged["created_at_utc"] >= apply_time def test_it_logs_mark_and_unmark(self, backend): - with migrations_dir( - a='step("CREATE TABLE yoyo_test (id INT)")' - ) as tmpdir: + with migrations_dir(a='step("CREATE TABLE yoyo_test (id INT)")') as tmpdir: migrations = read_migrations(tmpdir) backend.mark_migrations(migrations) assert self.get_log_count(backend) == 1 logged = self.get_last_log_entry(backend) assert logged["migration_id"] == "a" assert logged["operation"] == "mark" - assert logged["created_at_utc"] >= datetime.utcnow() - timedelta( - seconds=3 - ) + assert logged["created_at_utc"] >= datetime.utcnow() - timedelta(seconds=3) marked_time = logged["created_at_utc"] backend.unmark_migrations(migrations) diff --git a/yoyo/tests/test_topologicalsort.py b/yoyo/tests/test_topologicalsort.py new file mode 100644 index 0000000..be308d2 --- /dev/null +++ b/yoyo/tests/test_topologicalsort.py @@ -0,0 +1,77 @@ +import itertools + +import pytest + +from yoyo import topologicalsort + + +class TestTopologicalSort(object): + def check(self, nodes, edges, expected): + deps = {} + edges = edges.split() if edges else [] + for a, b in edges: + deps.setdefault(a, set()).add(b) + output = list(topologicalsort.topological_sort(nodes, deps)) + for a, b in edges: + try: + assert output.index(a) > output.index(b) + except ValueError: + pass + assert output == list(expected) + + def test_it_keeps_stable_order(self): + for s in map(str, itertools.permutations("ABCD")): + self.check(s, "", s) + + def test_it_sorts_topologically(self): + + # Single group at start + self.check("ABCD", "BA", "ABCD") + self.check("BACD", "BA", "ABCD") + + # Single group in middle start + self.check("CABD", "BA", "CABD") + self.check("CBAD", "BA", "CABD") + + # Extended group + self.check("ABCD", "BA DA", "ABCD") + self.check("DBCA", "BA DA", "CADB") + + # Non-connected groups + self.check("ABCDEF", "BC DE", "ACBEDF") + self.check("ADEBCF", "BC DE", "AEDCBF") + self.check("ADEFBC", "BC DE", "AEDFCB") + self.check("DBAFEC", "BC DE", "AFEDCB") + + def test_it_discards_missing_dependencies(self): + self.check("ABCD", "CX XY", "ABCD") + + def test_it_catches_cycles(self): + with pytest.raises(topologicalsort.CycleError): + self.check("ABCD", "AA", "") + with pytest.raises(topologicalsort.CycleError): + self.check("ABCD", "AB BA", "") + with pytest.raises(topologicalsort.CycleError): + self.check("ABCD", "AB BC CB", "") + with pytest.raises(topologicalsort.CycleError): + self.check("ABCD", "AB BC CA", "") + + def test_it_handles_multiple_edges_to_the_same_node(self): + self.check("ABCD", "BA CA DA", "ABCD") + self.check("DCBA", "BA CA DA", "ADCB") + + def test_it_handles_multiple_edges_to_the_same_node2(self): + # A --> B + # | ^ + # v | + # C --- + + for input_order in itertools.permutations("ABC"): + self.check(input_order, "BA CA BC", "ACB") + + def test_it_doesnt_modify_order_unnecessarily(self): + """ + Test for issue raised in + + https://lists.sr.ht/~olly/yoyo/%3C09c43045fdf14024a0f2e905408ea41f%40atos.net%3E + """ + self.check("ABC", "CA", "ABC") diff --git a/yoyo/topologicalsort.py b/yoyo/topologicalsort.py new file mode 100644 index 0000000..e74bf4a --- /dev/null +++ b/yoyo/topologicalsort.py @@ -0,0 +1,79 @@ +from typing import Set +from typing import Dict +from typing import Mapping +from typing import TypeVar +from typing import Iterable +from typing import Collection +from collections import defaultdict +from heapq import heappop +from heapq import heappush + + +class CycleError(ValueError): + """ + Raised when cycles exist in the input graph. + + The second element in the args attribute of instances will contain the + sequence of nodes in which the cycle lies. + """ + + +T = TypeVar("T") + + +def topological_sort( + items: Iterable[T], dependency_graph: Mapping[T, Collection[T]] +) -> Iterable[T]: + + # Tag each item with its input order + pqueue = list(enumerate(items)) + ordering = {item: ix for ix, item in pqueue} + seen_since_last_change = 0 + output: Set[T] = set() + + # Map blockers to the list of items they block + blocked_on: Dict[T, Set[T]] = defaultdict(set) + blocked: Set[T] = set() + + while pqueue: + if seen_since_last_change == len(pqueue) + len(blocked): + raise_cycle_error(ordering, pqueue, blocked_on) + + _, n = heappop(pqueue) + + blockers = { + d for d in dependency_graph.get(n, []) if d not in output and d in ordering + } + if not blockers: + seen_since_last_change = 0 + output.add(n) + if n in blocked: + blocked.remove(n) + yield n + for b in blocked_on.pop(n, []): + if not any(b in other for other in blocked_on.values()): + heappush(pqueue, (ordering[b], b)) + else: + if n in blocked: + seen_since_last_change += 1 + else: + seen_since_last_change = 0 + blocked.add(n) + for b in blockers: + blocked_on[b].add(n) + if blocked_on: + raise_cycle_error(ordering, pqueue, blocked_on) + + +def raise_cycle_error(ordering, pqueue, blocked_on): + bad = next((item for item in blocked_on if item not in ordering), None) + if bad: + raise ValueError(f"Dependency graph contains a non-existent node {bad!r}") + unresolved = {n for _, n in pqueue} + unresolved.update(*blocked_on.values()) + if unresolved: + raise CycleError( + f"Dependency graph loop detected among {unresolved!r}", + list(sorted(unresolved, key=ordering.get)), + ) + raise AssertionError("raise_cycle_error called but no unresovled nodes exist") |