diff options
-rw-r--r-- | yoyo/backends/base.py | 22 | ||||
-rwxr-xr-x | yoyo/migrations.py | 16 | ||||
-rw-r--r-- | yoyo/tests/test_backends.py | 8 |
3 files changed, 17 insertions, 29 deletions
diff --git a/yoyo/backends/base.py b/yoyo/backends/base.py index 4f95a66..1f3c167 100644 --- a/yoyo/backends/base.py +++ b/yoyo/backends/base.py @@ -45,9 +45,9 @@ class TransactionManager: when the context manager block closes """ - def __init__(self, backend): + def __init__(self, backend, rollback_on_exit=False): self.backend = backend - self._rollback = False + self.rollback_on_exit = rollback_on_exit def __enter__(self): self._do_begin() @@ -58,18 +58,11 @@ class TransactionManager: self._do_rollback() return None - if self._rollback: + if self.rollback_on_exit: self._do_rollback() else: self._do_commit() - def rollback(self): - """ - Flag that the transaction will be rolled back when the with statement - exits - """ - self._rollback = True - def _do_begin(self): """ Instruct the backend to begin a transaction @@ -238,9 +231,8 @@ class DatabaseBackend: table_name_quoted = self.quote_identifier(table_name) sql = self.create_test_table_sql.format(table_name_quoted=table_name_quoted) try: - with self.transaction() as t: + with self.transaction(rollback_on_exit=True): self.execute(sql) - t.rollback() except self.DatabaseError: return False @@ -263,12 +255,12 @@ class DatabaseBackend: ) return [row[0] for row in cursor.fetchall()] - def transaction(self): + def transaction(self, rollback_on_exit=False): if not self._in_transaction: - return TransactionManager(self) + return TransactionManager(self, rollback_on_exit=rollback_on_exit) else: - return SavepointTransactionManager(self) + return SavepointTransactionManager(self, rollback_on_exit=rollback_on_exit) def cursor(self): return self.connection.cursor() diff --git a/yoyo/migrations.py b/yoyo/migrations.py index 5450d0a..5b30049 100755 --- a/yoyo/migrations.py +++ b/yoyo/migrations.py @@ -312,16 +312,14 @@ class TransactionWrapper(StepBase): return "<TransactionWrapper {!r}>".format(self.step) def apply(self, backend, force=False, direction="apply"): - with backend.transaction() as transaction: - try: + try: + with backend.transaction(): getattr(self.step, direction)(backend, force) - except backend.DatabaseError: - if force or self.ignore_errors in (direction, "all"): - logger.exception("Ignored error in %r", self.step) - transaction.rollback() - return - else: - raise + except backend.DatabaseError: + if force or self.ignore_errors in (direction, "all"): + logger.exception("Ignored error in %r", self.step) + else: + raise def rollback(self, backend, force=False): self.apply(backend, force, "rollback") diff --git a/yoyo/tests/test_backends.py b/yoyo/tests/test_backends.py index 8d99009..e4b9727 100644 --- a/yoyo/tests/test_backends.py +++ b/yoyo/tests/test_backends.py @@ -44,11 +44,10 @@ class TestTransactionHandling(object): with backend.transaction(): backend.execute("INSERT INTO yoyo_t values ('A')") - with backend.transaction() as trans: + with backend.transaction(rollback_on_exit=True): backend.execute("INSERT INTO yoyo_t values ('B')") - trans.rollback() - with backend.transaction() as trans: + with backend.transaction(): backend.execute("INSERT INTO yoyo_t values ('C')") with backend.transaction(): @@ -95,12 +94,11 @@ class TestTransactionHandling(object): if backend.has_transactional_ddl: return - with backend.transaction() as trans: + with backend.transaction(rollback_on_exit=True): backend.execute("CREATE TABLE yoyo_a (id INT)") # implicit commit backend.execute("INSERT INTO yoyo_a VALUES (1)") backend.execute("CREATE TABLE yoyo_b (id INT)") # implicit commit backend.execute("INSERT INTO yoyo_b VALUES (1)") - trans.rollback() count_a = backend.execute("SELECT COUNT(1) FROM yoyo_a").fetchall()[0][0] assert count_a == 1 |