summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorOlly Cope <olly@ollycope.com>2017-01-12 11:06:17 +0000
committerOlly Cope <olly@ollycope.com>2017-01-12 11:06:17 +0000
commitdb98555d8fcd2949263a068f7ee2e8c7be4b0716 (patch)
tree506a097e0d39e490394cf389382fb478a19e6257
parent72b013dea006c100ca224802f1b4725bf9742784 (diff)
downloadyoyo-db98555d8fcd2949263a068f7ee2e8c7be4b0716.tar.gz
Add a __transactional__ flag for migration files
Adding ``__transaction__ = False`` in the migration module disables transaction support for the migration, causing it to run in autocommit mode. This is necessary to support certain statements in PostgreSQL (eg "CREATE TABLE") that raise errors if run in a transaction block.
-rwxr-xr-xREADME.rst21
-rw-r--r--yoyo/backends.py33
-rwxr-xr-xyoyo/migrations.py118
-rw-r--r--yoyo/tests/__init__.py9
-rw-r--r--yoyo/tests/test_backends.py21
5 files changed, 158 insertions, 44 deletions
diff --git a/README.rst b/README.rst
index 43074e0..14d62c7 100755
--- a/README.rst
+++ b/README.rst
@@ -216,6 +216,27 @@ rollbacks happen. For example::
step("UPDATE employees SET tax_code='B' WHERE pay_grade >= 6")
step("UPDATE employees SET tax_code='A' WHERE pay_grade >= 8")
+Disabling transactions
+~~~~~~~~~~~~~~~~~~~~~~
+
+In PostgreSQL it is an error to run certain statements inside a transaction
+block. These include:
+
+.. code::sql
+
+ CREATE TABLE <foo>
+ ALTER TYPE <enum> ADD ...
+
+Migrations containing such statements should set
+``__transactional__ = False``, eg:
+
+.. code::python
+
+ __transactional__ = False
+
+ step("CREATE DATABASE mydb", "DROP DATABASE mydb")
+
+Note that this feature is implemented for the PostgreSQL backend only.
Post-apply hook
---------------
diff --git a/yoyo/backends.py b/yoyo/backends.py
index bb0b31d..5219751 100644
--- a/yoyo/backends.py
+++ b/yoyo/backends.py
@@ -13,6 +13,7 @@
# limitations under the License.
from datetime import datetime
+from contextlib import contextmanager
from importlib import import_module
from itertools import count
from logging import getLogger
@@ -215,6 +216,15 @@ class DatabaseBackend(object):
"""
self.execute("ROLLBACK TO SAVEPOINT {}".format(id))
+ @contextmanager
+ def disable_transactions(self):
+ """
+ Disable the connection's transaction support, for example by
+ setting the isolation mode to 'autocommit'
+ """
+ self.rollback()
+ yield
+
def execute(self, stmt, args=tuple()):
"""
Create a new cursor, execute a single statement and return the cursor
@@ -294,8 +304,7 @@ class DatabaseBackend(object):
return
for m in migrations:
try:
- with self.transaction():
- self.apply_one(m, force=force)
+ self.apply_one(m, force=force)
except exceptions.BadMigration:
continue
@@ -304,16 +313,14 @@ class DatabaseBackend(object):
Run any post-apply migrations present in ``migrations``
"""
for m in migrations.post_apply:
- with self.transaction():
- self.apply_one(m, mark=False, force=force)
+ self.apply_one(m, mark=False, force=force)
def rollback_migrations(self, migrations, force=False):
if not migrations:
return
for m in migrations:
try:
- with self.transaction():
- self.rollback_one(m, force)
+ self.rollback_one(m, force)
except exceptions.BadMigration:
continue
@@ -340,7 +347,8 @@ class DatabaseBackend(object):
logger.info("Applying %s", migration.id)
migration.process_steps(self, 'apply', force=force)
if mark:
- self.mark_one(migration)
+ with self.transaction():
+ self.mark_one(migration)
def rollback_one(self, migration, force=False):
"""
@@ -348,7 +356,8 @@ class DatabaseBackend(object):
"""
logger.info("Rolling back %s", migration.id)
migration.process_steps(self, 'rollback', force=force)
- self.unmark_one(migration)
+ with self.transaction():
+ self.unmark_one(migration)
def unmark_one(self, migration):
sql = self._with_placeholders(self.delete_migration_sql.format(self))
@@ -441,3 +450,11 @@ class PostgresqlBackend(DatabaseBackend):
connargs.append('host=%s' % dburi.hostname)
connargs.append('dbname=%s' % dburi.database)
return self.driver.connect(' '.join(connargs))
+
+ @contextmanager
+ def disable_transactions(self):
+ with super(PostgresqlBackend, self).disable_transactions():
+ saved = self.connection.autocommit
+ self.connection.autocommit = True
+ yield
+ self.connection.autocommit = saved
diff --git a/yoyo/migrations.py b/yoyo/migrations.py
index 8aab47c..9ade1ae 100755
--- a/yoyo/migrations.py
+++ b/yoyo/migrations.py
@@ -38,6 +38,7 @@ class Migration(object):
self.path = path
self.steps = None
self.source = None
+ self.use_transactions = True
self._depends = None
self.__all_migrations[id] = self
@@ -62,7 +63,7 @@ class Migration(object):
self.source = source = f.read()
migration_code = compile(source, f.name, 'exec')
- collector = StepCollector()
+ collector = StepCollector(migration=self)
ns = {'step': collector.add_step,
'group': collector.add_step_group,
'transaction': collector.add_step_group,
@@ -78,12 +79,13 @@ class Migration(object):
depends = [depends]
self._depends = {self.__all_migrations.get(id, None)
for id in depends}
+ self.use_transactions = ns.get('__transactional__', True)
if None in self._depends:
raise exceptions.BadMigration(
"Could not resolve dependencies in {}".format(self.path))
self.ns = ns
self.source = source
- self.steps = list(collector.steps)
+ self.steps = collector.create_steps(self.use_transactions)
def process_steps(self, backend, direction, force=False):
@@ -96,24 +98,31 @@ class Migration(object):
steps = reversed(steps)
executed_steps = []
- for step in steps:
- try:
- getattr(step, direction)(backend, force)
- executed_steps.append(step)
- except backend.DatabaseError:
- exc_info = sys.exc_info()
-
- if not backend.has_transactional_ddl:
- # Any DDL statements that have been executed have been
- # committed. Go through the rollback steps to undo these
- # inasmuch is possible.
- try:
- for step in reversed(executed_steps):
- getattr(step, reverse)(backend)
- except backend.DatabaseError:
- logger.exception('Could not %s step %s',
- direction, step.id)
- reraise(exc_info[0], exc_info[1], exc_info[2])
+ if self.use_transactions:
+ transaction = backend.transaction
+ else:
+ transaction = backend.disable_transactions
+
+ with transaction():
+ for step in steps:
+ try:
+ getattr(step, direction)(backend, force)
+ executed_steps.append(step)
+ except backend.DatabaseError:
+ exc_info = sys.exc_info()
+
+ if not (backend.has_transactional_ddl or
+ not self.use_transactions):
+ # Any DDL statements that have been executed have been
+ # committed. Go through the rollback steps to undo
+ # these inasmuch is possible.
+ try:
+ for step in reversed(executed_steps):
+ getattr(step, reverse)(backend)
+ except backend.DatabaseError:
+ logger.exception('Could not %s step %s',
+ direction, step.id)
+ reraise(exc_info[0], exc_info[1], exc_info[2])
class PostApplyHookMigration(Migration):
@@ -140,9 +149,9 @@ class StepBase(object):
class TransactionWrapper(StepBase):
"""
- A ``Transaction`` object causes a step to be run within a
- single database transaction. Nested transactions are implemented via
- savepoints.
+ A :class:~`yoyo.migrations.TransactionWrapper` object causes a step to be
+ run within a single database transaction. Nested transactions are
+ implemented via savepoints.
"""
def __init__(self, step, ignore_errors=None):
@@ -169,6 +178,34 @@ class TransactionWrapper(StepBase):
self.apply(backend, force, 'rollback')
+class Transactionless(StepBase):
+ """
+ A :class:~`yoyo.migrations.TransactionWrapper` object causes a step to be
+ run outside of a database transaction.
+ """
+
+ def __init__(self, step, ignore_errors=None):
+ assert ignore_errors in (None, 'all', 'apply', 'rollback')
+ self.step = step
+ self.ignore_errors = ignore_errors
+
+ def __repr__(self):
+ return '<TransactionWrapper {!r}>'.format(self.step)
+
+ def apply(self, backend, force=False, direction='apply'):
+ try:
+ getattr(self.step, direction)(backend, force)
+ except backend.DatabaseError:
+ if force or self.ignore_errors in (direction, 'all'):
+ logger.exception("Ignored error in %r", self.step)
+ return
+ else:
+ raise
+
+ def rollback(self, backend, force=False):
+ self.apply(backend, force, 'rollback')
+
+
class MigrationStep(StepBase):
"""
Model a single migration.
@@ -374,7 +411,8 @@ class StepCollector(object):
Each call to step/transaction updates the StepCollector's ``steps`` list.
"""
- def __init__(self):
+ def __init__(self, migration):
+ self.migration = migration
self.steps = OrderedDict()
self.step_id = count(0)
@@ -384,16 +422,20 @@ class StepCollector(object):
to the list of steps.
Return the transaction-wrapped step.
"""
- t = MigrationStep(next(self.step_id), apply, rollback)
- t = TransactionWrapper(t, ignore_errors)
- self.steps[t] = 1
- return t
+ def do_add(use_transactions):
+ wrapper = (TransactionWrapper
+ if use_transactions
+ else Transactionless)
+ t = MigrationStep(next(self.step_id), apply, rollback)
+ t = wrapper(t, ignore_errors)
+ return t
+ self.steps[do_add] = 1
+ return do_add
def add_step_group(self, *args, **kwargs):
"""
Create a ``StepGroup`` group of steps.
"""
- ignore_errors = kwargs.pop('ignore_errors', None)
if 'steps' in kwargs:
if args:
raise ValueError("steps cannot be called with both keyword "
@@ -406,9 +448,21 @@ class StepCollector(object):
for s in steps:
del self.steps[s]
- group = TransactionWrapper(StepGroup(steps), ignore_errors)
- self.steps[group] = 1
- return group
+ def do_add(use_transactions):
+ ignore_errors = kwargs.pop('ignore_errors', None)
+ wrapper = (TransactionWrapper
+ if use_transactions
+ else Transactionless)
+
+ group = StepGroup([create_step(use_transactions)
+ for create_step in steps])
+ return wrapper(group, ignore_errors)
+
+ self.steps[do_add] = 1
+ return do_add
+
+ def create_steps(self, use_transactions):
+ return [create_step(use_transactions) for create_step in self.steps]
def _get_collector(depth=2):
diff --git a/yoyo/tests/__init__.py b/yoyo/tests/__init__.py
index e2cd489..03021cb 100644
--- a/yoyo/tests/__init__.py
+++ b/yoyo/tests/__init__.py
@@ -29,12 +29,13 @@ config = get_configparser()
config.read([config_file])
-def get_test_dburis():
- return [dburi for _, dburi in config.items('DEFAULT')]
+def get_test_dburis(only=frozenset(), exclude=frozenset()):
+ return [dburi for name, dburi in config.items('DEFAULT')
+ if (only and name in only) or (not only and name not in exclude)]
-def get_test_backends():
- return [get_backend(dburi) for dburi in get_test_dburis()]
+def get_test_backends(only=frozenset(), exclude=frozenset()):
+ return [get_backend(dburi) for dburi in get_test_dburis(only, exclude)]
class MigrationsContextManager(object):
diff --git a/yoyo/tests/test_backends.py b/yoyo/tests/test_backends.py
index 4b9850f..89851d9 100644
--- a/yoyo/tests/test_backends.py
+++ b/yoyo/tests/test_backends.py
@@ -2,6 +2,7 @@ import pytest
from yoyo import backends
from yoyo.tests import get_test_backends
+from yoyo.tests import with_migrations
class TestTransactionHandling(object):
@@ -91,3 +92,23 @@ class TestTransactionHandling(object):
count_b = backend.execute("SELECT COUNT(1) FROM _yoyo_b")\
.fetchall()[0][0]
assert count_b == 0
+
+ @with_migrations(a="""
+ __transactional__ = False
+ step('CREATE DATABASE yoyo_test_tmp',
+ 'DROP DATABASE yoyo_test_tmp',
+ )
+ """)
+ def test_statements_requiring_no_transaction(self, tmpdir):
+ """
+ PostgreSQL will error if certain statements (eg CREATE DATABASE)
+ are run within a transaction block.
+
+ As far as I know this behavior is PostgreSQL specific. We can't run
+ this test in sqlite as it does not support CREATE DATABASE.
+ """
+ from yoyo import read_migrations
+ for backend in get_test_backends(exclude={'sqlite'}):
+ migrations = read_migrations(tmpdir)
+ backend.apply_migrations(migrations)
+ backend.rollback_migrations(migrations)