summaryrefslogtreecommitdiff
path: root/yoyo
diff options
context:
space:
mode:
authorOlly Cope <olly@ollycope.com>2022-08-29 12:50:50 +0000
committerOlly Cope <olly@ollycope.com>2022-08-29 12:50:50 +0000
commit18fd8dbb678ffb818f57904dcc5b39011ab2ce38 (patch)
tree24dedf26ccd647a9b455c4cb739fd21946e370cb /yoyo
parent6dc7bcefcdcc93a5f00e837d4f34651d8f170390 (diff)
parent8eced158d43573528e12b95893f527d684a3196b (diff)
downloadyoyo-18fd8dbb678ffb818f57904dcc5b39011ab2ce38.tar.gz
merge new toposort algorithm
Diffstat (limited to 'yoyo')
-rwxr-xr-xyoyo/migrations.py83
-rw-r--r--yoyo/tests/test_migrations.py113
-rw-r--r--yoyo/tests/test_topologicalsort.py77
-rw-r--r--yoyo/topologicalsort.py79
4 files changed, 180 insertions, 172 deletions
diff --git a/yoyo/migrations.py b/yoyo/migrations.py
index a34d262..2c067ee 100755
--- a/yoyo/migrations.py
+++ b/yoyo/migrations.py
@@ -15,7 +15,6 @@
from collections import Counter
from collections import OrderedDict
from collections import abc
-from collections import defaultdict
from copy import copy
from glob import glob
from itertools import chain
@@ -719,72 +718,18 @@ def heads(migration_list):
return heads
-def topological_sort(
- migration_list: Iterable[Migration],
-) -> Iterable[Migration]:
-
- # Make a copy of migration_list. It's probably an iterator.
- migration_list = list(migration_list)
-
- # Track graph edges in two parallel data structures.
- # Use OrderedDict so that we can traverse edges in order
- # and keep the sort stable
- forward_edges = defaultdict(
- OrderedDict
- ) # type: Dict[Migration, Dict[Migration, int]]
- backward_edges = defaultdict(
- OrderedDict
- ) # type: Dict[Migration, Dict[Migration, int]]
-
- def sort_by_stability_order(
- items, ordering={m: index for index, m in enumerate(migration_list)}
- ):
- return sorted(
- (item for item in items if item in ordering), key=ordering.get
- )
+def topological_sort(migrations: Iterable[Migration]) -> Iterable[Migration]:
+ from yoyo.topologicalsort import topological_sort as topological_sort_impl
+ from yoyo.topologicalsort import CycleError
- for m in migration_list:
- for n in sort_by_stability_order(m.depends):
- forward_edges[n][m] = 1
- backward_edges[m][n] = 1
-
- def check_cycles(item):
- stack: List[Tuple[Migration, List[Migration]]] = [(item, [])]
- while stack:
- n, path = stack.pop()
- if n in path:
- raise exceptions.BadMigration(
- "Circular dependencies among these migrations {}".format(
- ", ".join(m.id for m in path + [n])
- )
- )
- stack.extend((f, path + [n]) for f in forward_edges[n])
-
- seen = set()
- for item in migration_list:
-
- if item in seen:
- continue
-
- check_cycles(item)
-
- # if item is in a dedepency graph, go back to the root node
- while backward_edges[item]:
- item = next(iter(backward_edges[item]))
-
- # is item at the start of a dependency graph?
- if forward_edges[item]:
- stack = [item]
- while stack:
- m = stack.pop()
- yield m
- seen.add(m)
- for child in list(reversed(list(forward_edges[m]))):
- if all(
- dependency in seen
- for dependency in backward_edges[child]
- ):
- stack.append(child)
- else:
- yield item
- seen.add(item)
+ migration_list = list(migrations)
+ all_migrations = set(migration_list)
+ dependency_graph = {m: (m.depends & all_migrations) for m in migration_list}
+ try:
+ return topological_sort_impl(migration_list, dependency_graph)
+ except CycleError as e:
+ raise exceptions.BadMigration(
+ "Circular dependencies among these migrations {}".format(
+ ", ".join(m.id for m in e.args[1])
+ )
+ )
diff --git a/yoyo/tests/test_migrations.py b/yoyo/tests/test_migrations.py
index bdfab00..09f94ba 100644
--- a/yoyo/tests/test_migrations.py
+++ b/yoyo/tests/test_migrations.py
@@ -16,7 +16,6 @@ from datetime import datetime
from datetime import timedelta
from mock import Mock, patch
import io
-import itertools
import os
import pytest
@@ -28,7 +27,7 @@ from yoyo import ancestors, descendants
from yoyo.tests import migrations_dir
from yoyo.tests import tempdir
-from yoyo.migrations import topological_sort, MigrationList
+from yoyo.migrations import MigrationList
from yoyo.scripts import newmigration
@@ -269,79 +268,6 @@ def test_grouped_migrations_can_be_rolled_back(backend):
backend.rollback_migrations(read_migrations(t1))
-class TestTopologicalSort(object):
- def check(self, nodes, edges, expected_order):
- migrations = self.get_mock_migrations(nodes, edges)
- output_order = "".join(m.id for m in topological_sort(migrations))
- assert output_order == expected_order
-
- def get_mock_migrations(self, nodes="ABCD", edges=[]):
- class MockMigration(Mock):
- def __repr__(self):
- return "<MockMigration {}>".format(self.id)
-
- migrations = {n: MockMigration(id=n, depends=set()) for n in nodes}
- for g in edges:
- for edge in zip(g, g[1:]):
- migrations[edge[1]].depends.add(migrations[edge[0]])
-
- return [migrations[n] for n in nodes]
-
- def test_it_keeps_stable_order(self):
-
- for s in itertools.permutations("ABCD"):
- self.check(s, {}, "".join(s))
-
- def test_it_sorts_topologically(self):
-
- # Single group at start
- self.check("ABCD", {"AB"}, "ABCD")
- self.check("BACD", {"AB"}, "ABCD")
-
- # Single group in middle start
- self.check("CABD", {"AB"}, "CABD")
- self.check("CBAD", {"AB"}, "CABD")
-
- # Extended group
- self.check("ABCD", {"AB", "AD"}, "ABDC")
- self.check("DBCA", {"AB", "AD"}, "ADBC")
-
- # Non-connected groups
- self.check("ABCDEF", {"CB", "ED"}, "ACBEDF")
- self.check("ADEBCF", {"CB", "ED"}, "AEDCBF")
- self.check("ADEFBC", {"CB", "ED"}, "AEDFCB")
- self.check("DBAFEC", {"CB", "ED"}, "EDCBAF")
-
- def test_it_discards_missing_dependencies(self):
- A, B, C, D = self.get_mock_migrations()
- C.depends.add(Mock())
- assert list(topological_sort([A, B, C, D])) == [A, B, C, D]
-
- def test_it_catches_cycles(self):
- A, B, C, D = self.get_mock_migrations()
- C.depends.add(C)
- with pytest.raises(exceptions.BadMigration):
- self.check("ABCD", {"AA"}, "")
- with pytest.raises(exceptions.BadMigration):
- self.check("ABCD", {"AB", "BA"}, "")
- with pytest.raises(exceptions.BadMigration):
- self.check("ABCD", {"AB", "BC", "CB"}, "")
- with pytest.raises(exceptions.BadMigration):
- self.check("ABCD", {"AB", "BC", "CA"}, "")
-
- def test_it_handles_multiple_edges_to_the_same_node(self):
- self.check("ABCD", {"AB", "AC", "AD"}, "ABCD")
- self.check("DCBA", {"AB", "AC", "AD"}, "ADCB")
-
- def test_it_handles_multiple_edges_to_the_same_node2(self):
- # A --> B
- # | ^
- # v |
- # C --- +
- for input_order in itertools.permutations("ABC"):
- self.check(input_order, {"AB", "AC", "CB"}, "ACB")
-
-
class TestMigrationList(object):
def test_can_create_empty(self):
m = MigrationList()
@@ -420,9 +346,7 @@ class TestReadMigrations(object):
The yoyo new command creates temporary files in the migrations directory.
These shouldn't be picked up by yoyo apply etc
"""
- with migrations_dir(
- **{newmigration.tempfile_prefix + "test": ""}
- ) as tmpdir:
+ with migrations_dir(**{newmigration.tempfile_prefix + "test": ""}) as tmpdir:
assert len(read_migrations(tmpdir)) == 0
def test_it_loads_post_apply_scripts(self):
@@ -441,12 +365,8 @@ class TestReadMigrations(object):
m.load()
assert len(m.steps) == 1
- def test_it_does_not_add_duplicate_steps_with_imported_symbols(
- self, tmpdir
- ):
- with migrations_dir(
- a="from yoyo import step; step('SELECT 1')"
- ) as tmpdir:
+ def test_it_does_not_add_duplicate_steps_with_imported_symbols(self, tmpdir):
+ with migrations_dir(a="from yoyo import step; step('SELECT 1')") as tmpdir:
m = read_migrations(tmpdir)[0]
m.load()
assert len(m.steps) == 1
@@ -535,9 +455,7 @@ class TestReadMigrations(object):
def test_it_sets_depends_for_sql_migrations(self):
def check(sql, expected):
- with migrations_dir(
- **{"1.sql": "", "2.sql": "", "3.sql": sql}
- ) as tmp:
+ with migrations_dir(**{"1.sql": "", "2.sql": "", "3.sql": sql}) as tmp:
migration = read_migrations(tmp)[-1]
migration.load()
@@ -638,27 +556,20 @@ class TestLogging(object):
"from _yoyo_log "
"ORDER BY id DESC LIMIT 1"
)
- return {
- d[0]: value
- for d, value in zip(cursor.description, cursor.fetchone())
- }
+ return {d[0]: value for d, value in zip(cursor.description, cursor.fetchone())}
def get_log_count(self, backend):
return backend.execute("SELECT count(1) FROM _yoyo_log").fetchone()[0]
def test_it_logs_apply_and_rollback(self, backend):
- with migrations_dir(
- a='step("CREATE TABLE yoyo_test (id INT)")'
- ) as tmpdir:
+ with migrations_dir(a='step("CREATE TABLE yoyo_test (id INT)")') as tmpdir:
migrations = read_migrations(tmpdir)
backend.apply_migrations(migrations)
assert self.get_log_count(backend) == 1
logged = self.get_last_log_entry(backend)
assert logged["migration_id"] == "a"
assert logged["operation"] == "apply"
- assert logged["created_at_utc"] >= datetime.utcnow() - timedelta(
- seconds=3
- )
+ assert logged["created_at_utc"] >= datetime.utcnow() - timedelta(seconds=3)
apply_time = logged["created_at_utc"]
backend.rollback_migrations(migrations)
@@ -669,18 +580,14 @@ class TestLogging(object):
assert logged["created_at_utc"] >= apply_time
def test_it_logs_mark_and_unmark(self, backend):
- with migrations_dir(
- a='step("CREATE TABLE yoyo_test (id INT)")'
- ) as tmpdir:
+ with migrations_dir(a='step("CREATE TABLE yoyo_test (id INT)")') as tmpdir:
migrations = read_migrations(tmpdir)
backend.mark_migrations(migrations)
assert self.get_log_count(backend) == 1
logged = self.get_last_log_entry(backend)
assert logged["migration_id"] == "a"
assert logged["operation"] == "mark"
- assert logged["created_at_utc"] >= datetime.utcnow() - timedelta(
- seconds=3
- )
+ assert logged["created_at_utc"] >= datetime.utcnow() - timedelta(seconds=3)
marked_time = logged["created_at_utc"]
backend.unmark_migrations(migrations)
diff --git a/yoyo/tests/test_topologicalsort.py b/yoyo/tests/test_topologicalsort.py
new file mode 100644
index 0000000..be308d2
--- /dev/null
+++ b/yoyo/tests/test_topologicalsort.py
@@ -0,0 +1,77 @@
+import itertools
+
+import pytest
+
+from yoyo import topologicalsort
+
+
+class TestTopologicalSort(object):
+ def check(self, nodes, edges, expected):
+ deps = {}
+ edges = edges.split() if edges else []
+ for a, b in edges:
+ deps.setdefault(a, set()).add(b)
+ output = list(topologicalsort.topological_sort(nodes, deps))
+ for a, b in edges:
+ try:
+ assert output.index(a) > output.index(b)
+ except ValueError:
+ pass
+ assert output == list(expected)
+
+ def test_it_keeps_stable_order(self):
+ for s in map(str, itertools.permutations("ABCD")):
+ self.check(s, "", s)
+
+ def test_it_sorts_topologically(self):
+
+ # Single group at start
+ self.check("ABCD", "BA", "ABCD")
+ self.check("BACD", "BA", "ABCD")
+
+ # Single group in middle start
+ self.check("CABD", "BA", "CABD")
+ self.check("CBAD", "BA", "CABD")
+
+ # Extended group
+ self.check("ABCD", "BA DA", "ABCD")
+ self.check("DBCA", "BA DA", "CADB")
+
+ # Non-connected groups
+ self.check("ABCDEF", "BC DE", "ACBEDF")
+ self.check("ADEBCF", "BC DE", "AEDCBF")
+ self.check("ADEFBC", "BC DE", "AEDFCB")
+ self.check("DBAFEC", "BC DE", "AFEDCB")
+
+ def test_it_discards_missing_dependencies(self):
+ self.check("ABCD", "CX XY", "ABCD")
+
+ def test_it_catches_cycles(self):
+ with pytest.raises(topologicalsort.CycleError):
+ self.check("ABCD", "AA", "")
+ with pytest.raises(topologicalsort.CycleError):
+ self.check("ABCD", "AB BA", "")
+ with pytest.raises(topologicalsort.CycleError):
+ self.check("ABCD", "AB BC CB", "")
+ with pytest.raises(topologicalsort.CycleError):
+ self.check("ABCD", "AB BC CA", "")
+
+ def test_it_handles_multiple_edges_to_the_same_node(self):
+ self.check("ABCD", "BA CA DA", "ABCD")
+ self.check("DCBA", "BA CA DA", "ADCB")
+
+ def test_it_handles_multiple_edges_to_the_same_node2(self):
+ # A --> B
+ # | ^
+ # v |
+ # C --- +
+ for input_order in itertools.permutations("ABC"):
+ self.check(input_order, "BA CA BC", "ACB")
+
+ def test_it_doesnt_modify_order_unnecessarily(self):
+ """
+ Test for issue raised in
+
+ https://lists.sr.ht/~olly/yoyo/%3C09c43045fdf14024a0f2e905408ea41f%40atos.net%3E
+ """
+ self.check("ABC", "CA", "ABC")
diff --git a/yoyo/topologicalsort.py b/yoyo/topologicalsort.py
new file mode 100644
index 0000000..e74bf4a
--- /dev/null
+++ b/yoyo/topologicalsort.py
@@ -0,0 +1,79 @@
+from typing import Set
+from typing import Dict
+from typing import Mapping
+from typing import TypeVar
+from typing import Iterable
+from typing import Collection
+from collections import defaultdict
+from heapq import heappop
+from heapq import heappush
+
+
+class CycleError(ValueError):
+ """
+ Raised when cycles exist in the input graph.
+
+ The second element in the args attribute of instances will contain the
+ sequence of nodes in which the cycle lies.
+ """
+
+
+T = TypeVar("T")
+
+
+def topological_sort(
+ items: Iterable[T], dependency_graph: Mapping[T, Collection[T]]
+) -> Iterable[T]:
+
+ # Tag each item with its input order
+ pqueue = list(enumerate(items))
+ ordering = {item: ix for ix, item in pqueue}
+ seen_since_last_change = 0
+ output: Set[T] = set()
+
+ # Map blockers to the list of items they block
+ blocked_on: Dict[T, Set[T]] = defaultdict(set)
+ blocked: Set[T] = set()
+
+ while pqueue:
+ if seen_since_last_change == len(pqueue) + len(blocked):
+ raise_cycle_error(ordering, pqueue, blocked_on)
+
+ _, n = heappop(pqueue)
+
+ blockers = {
+ d for d in dependency_graph.get(n, []) if d not in output and d in ordering
+ }
+ if not blockers:
+ seen_since_last_change = 0
+ output.add(n)
+ if n in blocked:
+ blocked.remove(n)
+ yield n
+ for b in blocked_on.pop(n, []):
+ if not any(b in other for other in blocked_on.values()):
+ heappush(pqueue, (ordering[b], b))
+ else:
+ if n in blocked:
+ seen_since_last_change += 1
+ else:
+ seen_since_last_change = 0
+ blocked.add(n)
+ for b in blockers:
+ blocked_on[b].add(n)
+ if blocked_on:
+ raise_cycle_error(ordering, pqueue, blocked_on)
+
+
+def raise_cycle_error(ordering, pqueue, blocked_on):
+ bad = next((item for item in blocked_on if item not in ordering), None)
+ if bad:
+ raise ValueError(f"Dependency graph contains a non-existent node {bad!r}")
+ unresolved = {n for _, n in pqueue}
+ unresolved.update(*blocked_on.values())
+ if unresolved:
+ raise CycleError(
+ f"Dependency graph loop detected among {unresolved!r}",
+ list(sorted(unresolved, key=ordering.get)),
+ )
+ raise AssertionError("raise_cycle_error called but no unresovled nodes exist")