summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author olly <olly@ollycope.com> 2015-04-16 18:18:46 +0000
committer olly <olly@ollycope.com> 2015-04-16 18:18:46 +0000
commit df5e7ce1645c4772b7cd70e323bc0d1d9bb88e22 (patch)
tree 64318eac2632862e6cd48c04f27b562d3050f329
parent 6edba388d78966279f40d197d2e6f53d59905b22 (diff)
download yoyo-df5e7ce1645c4772b7cd70e323bc0d1d9bb88e22.tar.gz
Add topological sort method
This is to allow us to sort migrations based on their declared dependencies
-rwxr-xr-x yoyo/migrations.py 59
-rw-r--r-- yoyo/tests/test_migrations.py 38
2 files changed, 96 insertions, 1 deletions
diff --git a/yoyo/migrations.py b/yoyo/migrations.py
index c800b03..3799c68 100755
--- a/yoyo/migrations.py
+++ b/yoyo/migrations.py
@@ -1,5 +1,6 @@
+from collections import defaultdict, OrderedDict
from datetime import datetime
-from itertools import count
+from itertools import chain, count
from logging import getLogger
import os
import sys
@@ -483,3 +484,59 @@ def step(*args, **kwargs):
def transaction(*args, **kwargs):
    """
    Forward the call to the ``transaction`` method of the step collector
    registered in ``_step_collectors`` under the calling frame's filename.

    All positional and keyword arguments are passed through unchanged and
    the collector's return value is returned.
    """
    # Index 1 of the stack is the frame that called this function; the
    # lookup must stay in this function body so that the depth is right.
    caller_frame = inspect.stack()[1][0]
    caller_file = inspect.getframeinfo(caller_frame).filename
    collector = _step_collectors[caller_file]
    return collector.transaction(*args, **kwargs)
+
+
def topological_sort(migration_list):
    """
    Sort migrations so that each one comes after those it depends on.

    Only migrations that take part in a dependency relationship (they
    depend on, or are depended on by, another migration in
    ``migration_list``) are reordered, and these are placed first. All
    other migrations follow in their original relative order. Dependencies
    that are not themselves present in ``migration_list`` are ignored.

    :param migration_list: iterable of hashable migration objects, each
                           with a ``depends`` collection (``path`` is read
                           only when reporting a cycle)
    :return: a new list with the sorted migrations
    :raises exceptions.BadMigration: if the dependencies contain a cycle
    """
    # Materialise the input: it is traversed several times below and may be
    # a one-shot iterator.
    migration_list = list(migration_list)

    # The sorted list, initially empty
    L = []

    valid_migrations = set(migration_list)

    # Track graph edges in two parallel data structures.
    # Use OrderedDict so that we can traverse edges in order
    # and keep the sort stable
    forward_edges = defaultdict(OrderedDict)
    backward_edges = defaultdict(OrderedDict)

    for m in migration_list:
        for n in m.depends:
            if n not in valid_migrations:
                continue
            forward_edges[n][m] = 1
            backward_edges[m][n] = 1

    # Only toposort the migrations forming part of the dependency graph
    to_toposort = set(chain(forward_edges, backward_edges))

    # The same migrations, de-duplicated and in their original order.
    # Iterating the set directly would make the result order arbitrary.
    graph_nodes = list(OrderedDict.fromkeys(
        m for m in migration_list if m in to_toposort))

    # Starting migrations: those with no dependencies.
    # S is used as a stack, so it is filled in reverse to make the earliest
    # migration the first one popped.
    S = [m for m in reversed(graph_nodes)
         if not any(n in valid_migrations for n in m.depends)]

    while S:
        n = S.pop()
        L.append(n)

        # Dependents of N that become free once N has been emitted
        ready = []

        # for each node M with an edge E from N to M
        # (iterate over a snapshot: the edges are deleted inside the loop and
        # an OrderedDict must not be mutated while it is being iterated)
        for m in list(forward_edges[n]):

            # remove edge E from the graph
            del forward_edges[n][m]
            del backward_edges[m][n]

            # If M has no other incoming edges, it qualifies as a starting node
            if not backward_edges[m]:
                ready.append(m)

        # Push in reverse so that siblings are popped in their original order
        S.extend(reversed(ready))

    if any(forward_edges.values()):
        # Whatever could not be emitted is part of a cycle or depends on one
        emitted = set(L)
        raise exceptions.BadMigration(
            "Circular dependencies among these migrations {}".format(
                ', '.join(m.path for m in graph_nodes if m not in emitted)))

    # Return the toposorted migrations followed by the remainder of migrations
    # in their original order
    return L + [m for m in migration_list if m not in to_toposort]
diff --git a/yoyo/tests/test_migrations.py b/yoyo/tests/test_migrations.py
index e798a48..ad98938 100644
--- a/yoyo/tests/test_migrations.py
+++ b/yoyo/tests/test_migrations.py
@@ -1,8 +1,12 @@
+import pytest
+from mock import Mock
+
from yoyo.connections import connect
from yoyo import read_migrations
from yoyo import exceptions
from yoyo.tests import with_migrations, dburi
+from yoyo.migrations import topological_sort
@with_migrations(
@@ -161,3 +165,37 @@ def test_migrations_can_import_step_and_transaction(tmpdir):
cursor = conn.cursor()
cursor.execute("SELECT id FROM test")
assert cursor.fetchall() == [(1,)]
+
+
class TestTopologicalSort(object):
    """Tests for ``yoyo.migrations.topological_sort``."""

    def get_mock_migrations(self):
        """Return four mock migrations (m1..m4) with no dependencies."""
        return [Mock(path='m1', depends=set()), Mock(path='m2', depends=set()),
                Mock(path='m3', depends=set()), Mock(path='m4', depends=set())]

    def test_it_keeps_stable_order(self):
        # Without dependencies the input order must come back unchanged
        m1, m2, m3, m4 = self.get_mock_migrations()
        assert list(topological_sort([m1, m2, m3, m4])) == [m1, m2, m3, m4]
        assert list(topological_sort([m4, m3, m2, m1])) == [m4, m3, m2, m1]

    def test_it_sorts_topologically(self):
        m1, m2, m3, m4 = self.get_mock_migrations()
        m3.depends.add(m4)
        assert list(topological_sort([m1, m2, m3, m4])) == [m4, m3, m1, m2]

    def test_it_brings_depended_upon_migrations_to_the_front(self):
        m1, m2, m3, m4 = self.get_mock_migrations()
        m1.depends.add(m4)
        assert list(topological_sort([m1, m2, m3, m4])) == [m4, m1, m2, m3]

    def test_it_discards_missing_dependencies(self):
        # A dependency that is not in the list being sorted is ignored
        m1, m2, m3, m4 = self.get_mock_migrations()
        m3.depends.add(Mock())
        assert list(topological_sort([m1, m2, m3, m4])) == [m1, m2, m3, m4]

    def test_it_catches_cycles(self):
        # A migration depending on itself is the smallest possible cycle
        m1, m2, m3, m4 = self.get_mock_migrations()
        m3.depends.add(m3)
        with pytest.raises(exceptions.BadMigration):
            list(topological_sort([m1, m2, m3, m4]))