summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorOlly Cope <olly@ollycope.com>2018-01-29 13:44:26 +0000
committerOlly Cope <olly@ollycope.com>2018-01-29 13:44:26 +0000
commite8f58c64929220a8c9c9fc78132b50aff94ec906 (patch)
tree377e8cc773fcb2c53c5db8a6989f1f93abcbfa1c
parent464d4c134c5d1eb30abe2f56648ea3d9ba217f2b (diff)
downloadyoyo-e8f58c64929220a8c9c9fc78132b50aff94ec906.tar.gz
rewrite locking mechanism to use a lock table and polling
PostgreSQL's table locking commands are always scoped to a single transaction, so are not useful for locking across migration runs. This new implementation relies on the database enforcing consistency over a unique key to prevent concurrent runs. This is less elegant, but should work across any backend database that enforces key integrity.
-rw-r--r--yoyo/backends.py87
-rwxr-xr-xyoyo/exceptions.py6
-rwxr-xr-xyoyo/scripts/migrate.py21
-rw-r--r--yoyo/tests/test_backends.py57
-rw-r--r--yoyo/tests/test_cli_script.py25
5 files changed, 153 insertions, 43 deletions
diff --git a/yoyo/backends.py b/yoyo/backends.py
index e1a32ee..f32b2f3 100644
--- a/yoyo/backends.py
+++ b/yoyo/backends.py
@@ -17,6 +17,8 @@ from contextlib import contextmanager
from importlib import import_module
from itertools import count
from logging import getLogger
+import os
+import time
from . import exceptions, utils
from .migrations import topological_sort
@@ -104,11 +106,19 @@ class DatabaseBackend(object):
driver_module = None
connection = None
- create_table_sql = """
+ lock_table = '_yoyo_lock'
+ create_migration_table_sql = """
CREATE TABLE {table_name} (
id VARCHAR(255) NOT NULL PRIMARY KEY,
ctime TIMESTAMP
)"""
+ create_lock_table_sql = """
+ CREATE TABLE {table_name} (
+ locked INT DEFAULT 1,
+ ctime TIMESTAMP,
+ pid INT NOT NULL,
+ PRIMARY KEY (locked)
+ )"""
list_tables_sql = "SELECT table_name FROM information_schema.tables"
is_applied_sql = "SELECT COUNT(1) FROM {0.migration_table} WHERE id=?"
insert_migration_sql = ("INSERT INTO {0.migration_table} (id, ctime) "
@@ -125,7 +135,7 @@ class DatabaseBackend(object):
self.DatabaseError = self.driver.DatabaseError
self._connection = self.connect(dburi)
self.migration_table = migration_table
- self.create_migrations_table()
+ self.create_tables()
self.has_transactional_ddl = self._check_transactional_ddl()
def _load_driver_module(self):
@@ -226,11 +236,48 @@ class DatabaseBackend(object):
yield
@contextmanager
- def lock_migration_table(self):
+ def lock(self, timeout=10):
"""
- Lock `migrations_table` to prevent concurrent migrations.
+ Create a lock to prevent concurrent migrations.
+
+ :param timeout: duration in seconds before raising a LockTimeout error.
"""
- yield
+
+ pid = os.getpid()
+ started = time.time()
+ while True:
+ try:
+ with self.transaction():
+ self.execute("INSERT INTO {} (locked, ctime, pid) "
+ "VALUES (1, ?, ?)".format(self.lock_table),
+ (datetime.utcnow(), pid))
+ except self.DatabaseError:
+ if timeout and time.time() > started + timeout:
+ cursor = self.execute("SELECT pid FROM {}"
+ .format(self.lock_table))
+ row = cursor.fetchone()
+ if row:
+ raise exceptions.LockTimeout(
+ "Process {} has locked this database for "
+ "migrations (run yoyo break-lock to "
+ "unconditionally remove locks)"
+ .format(row[0]))
+ else:
+ raise
+ time.sleep(0.1)
+ else:
+ break
+ try:
+ yield
+ finally:
+ with self.transaction():
+ self.execute("DELETE FROM {} WHERE pid=?"
+ .format(self.lock_table),
+ (pid,))
+
+ def break_lock(self):
+ with self.transaction():
+ self.execute("DELETE FROM {}" .format(self.lock_table))
def execute(self, stmt, args=tuple()):
"""
@@ -241,21 +288,21 @@ class DatabaseBackend(object):
cursor.execute(self._with_placeholders(stmt), args)
return cursor
- def create_migrations_table(self):
+ def create_tables(self):
"""
- Create the migrations table if it does not already exist.
+ Create the migrations and lock tables if they do not already exist.
"""
- sql = self.create_table_sql.format(table_name=self.migration_table)
- try:
- with self.transaction():
- self.get_applied_migration_ids()
- table_exists = True
- except self.DatabaseError:
- table_exists = False
+ statements = [
+ self.create_migration_table_sql.format(table_name=self.migration_table),
+ self.create_lock_table_sql.format(table_name=self.lock_table)
+ ]
- if not table_exists:
- with self.transaction():
- self.execute(sql)
+ for stmt in statements:
+ try:
+ with self.transaction():
+ self.execute(stmt)
+ except self.DatabaseError:
+ pass
def _with_placeholders(self, sql):
placeholder_gen = {'qmark': '?',
@@ -464,9 +511,3 @@ class PostgresqlBackend(DatabaseBackend):
self.connection.autocommit = True
yield
self.connection.autocommit = saved
-
- @contextmanager
- def lock_migration_table(self):
- self.execute('LOCK TABLE {table_name} IN ACCESS EXCLUSIVE MODE'
- .format(table_name=self.migration_table))
- yield
diff --git a/yoyo/exceptions.py b/yoyo/exceptions.py
index 606c51c..03ddd8f 100755
--- a/yoyo/exceptions.py
+++ b/yoyo/exceptions.py
@@ -29,3 +29,9 @@ class MigrationConflict(Exception):
"""
The migration id conflicts with another migration
"""
+
+
+class LockTimeout(Exception):
+ """
+ Timeout was reached while acquiring the migration lock
+ """
diff --git a/yoyo/scripts/migrate.py b/yoyo/scripts/migrate.py
index bb3117c..2e19700 100755
--- a/yoyo/scripts/migrate.py
+++ b/yoyo/scripts/migrate.py
@@ -96,6 +96,12 @@ def install_argparsers(global_parser, subparsers):
help="Unmark applied migrations, without rolling them back")
parser_unmark.set_defaults(func=unmark, command_name='unmark')
+ parser_break_lock = subparsers.add_parser(
+ 'break_lock',
+ parents=[global_parser],
+ help="Break migration locks")
+ parser_break_lock.set_defaults(func=break_lock, command_name='break-lock')
+
def get_migrations(args, backend):
@@ -160,14 +166,14 @@ def get_migrations(args, backend):
def apply(args, config):
backend = get_backend(args, config)
- with backend.lock_migration_table():
+ with backend.lock():
migrations = get_migrations(args, backend)
backend.apply_migrations(migrations, args.force)
def reapply(args, config):
backend = get_backend(args, config)
- with backend.lock_migration_table():
+ with backend.lock():
migrations = get_migrations(args, backend)
backend.rollback_migrations(migrations, args.force)
backend.apply_migrations(migrations, args.force)
@@ -175,25 +181,30 @@ def reapply(args, config):
def rollback(args, config):
backend = get_backend(args, config)
- with backend.lock_migration_table():
+ with backend.lock():
migrations = get_migrations(args, backend)
backend.rollback_migrations(migrations, args.force)
def mark(args, config):
backend = get_backend(args, config)
- with backend.lock_migration_table():
+ with backend.lock():
migrations = get_migrations(args, backend)
backend.mark_migrations(migrations)
def unmark(args, config):
backend = get_backend(args, config)
- with backend.lock_migration_table():
+ with backend.lock():
migrations = get_migrations(args, backend)
backend.unmark_migrations(migrations)
+def break_lock(args, config):
+ backend = get_backend(args, config)
+ backend.break_lock()
+
+
def prompt_migrations(backend, migrations, direction):
"""
Iterate through the list of migrations and prompt the user to
diff --git a/yoyo/tests/test_backends.py b/yoyo/tests/test_backends.py
index ddd94f7..bb7477f 100644
--- a/yoyo/tests/test_backends.py
+++ b/yoyo/tests/test_backends.py
@@ -2,7 +2,9 @@ import pytest
from threading import Thread
import time
-from yoyo import backends, read_migrations
+from yoyo import backends
+from yoyo import read_migrations
+from yoyo import exceptions
from yoyo.tests import get_test_backends
from yoyo.tests import with_migrations
@@ -96,17 +98,42 @@ class TestTransactionHandling(object):
backend.apply_migrations(migrations)
backend.rollback_migrations(migrations)
- @with_migrations(a="""
- steps = [
- step("SELECT pg_sleep(10)"),
- ]
- """)
- def test_lock_migration_table(self, tmpdir):
- backend = get_test_backends(only={'postgresql'})[0]
- migrations = read_migrations(tmpdir)
- t = Thread(target=backend.apply_migrations, args=(migrations,))
- t.start()
- # give a chance to start, but wake up in the middle of applying
- time.sleep(1)
- assert backend.get_applied_migration_ids() == ['a']
- t.join()
+ def test_lock(self, backend):
+ """
+ Test that :meth:`~yoyo.backends.DatabaseBackend.lock`
+ acquires an exclusive lock
+ """
+ lock_duration = 0.2
+
+ def do_something_with_lock():
+ with backend.lock():
+ time.sleep(lock_duration)
+
+ thread = Thread(target=do_something_with_lock)
+ t = time.time()
+ thread.start()
+ # Give the thread time to acquire the lock, but not enough
+ # to complete
+ time.sleep(lock_duration * 0.2)
+ with backend.lock():
+ assert time.time() > t + lock_duration
+
+ thread.join()
+
+ def test_lock_times_out(self, backend):
+
+ def do_something_with_lock():
+ with backend.lock():
+ time.sleep(lock_duration)
+
+ lock_duration = 2
+ thread = Thread(target=do_something_with_lock)
+ thread.start()
+ # Give the thread time to acquire the lock, but not enough
+ # to complete
+ time.sleep(lock_duration * 0.1)
+ with pytest.raises(exceptions.LockTimeout):
+ with backend.lock(timeout=lock_duration * 0.1):
+ assert False, "Execution should never reach this point"
+
+ thread.join()
diff --git a/yoyo/tests/test_cli_script.py b/yoyo/tests/test_cli_script.py
index 6bcef48..8fc6457 100644
--- a/yoyo/tests/test_cli_script.py
+++ b/yoyo/tests/test_cli_script.py
@@ -210,6 +210,31 @@ class TestYoyoScript(TestInteractiveScript):
assert self.prompt.call_count == 0
assert self.confirm.call_count == 0
+ def test_concurrent_instances_do_not_conflict(self, backend):
+ import threading
+ from functools import partial
+
+ with with_migrations(m1=('import time\n'
+ 'step(lambda conn: time.sleep(0.1))\n'
+ 'step("INSERT INTO _yoyo_t VALUES (\'A\')")')
+ ) as tmpdir:
+ assert '_yoyo_t' in backend.list_tables()
+ backend.rollback()
+ backend.execute("SELECT * FROM _yoyo_t")
+ run_migrations = partial(
+ main,
+ ['apply', '-b', tmpdir, '--database', str(backend.uri)])
+ threads = [threading.Thread(target=run_migrations)
+ for ix in range(20)]
+ for t in threads:
+ t.start()
+ for t in threads:
+ t.join()
+
+ # Exactly one instance of the migration script should have succeeded
+ cursor = backend.execute('SELECT COUNT(1) from "_yoyo_t"')
+ assert cursor.fetchone()[0] == 1
+
class TestArgParsing(TestInteractiveScript):