summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorOlly Cope <olly@ollycope.com>2021-05-23 11:38:54 +0000
committerOlly Cope <olly@ollycope.com>2021-05-23 11:38:54 +0000
commitcec1c9cb184fa468c82cea3d5951dadf26be43a6 (patch)
tree160b0bfe08b642eca2ece90c60b8b51abffdb8f6
parent62bcd9849eea8b4faf71b00e85d3d5bbe1a03e42 (diff)
downloadyoyo-cec1c9cb184fa468c82cea3d5951dadf26be43a6.tar.gz
Fix access to StepCollector from imported yoyo.migrations.step
Fixes https://todo.sr.ht/~olly/yoyo/79 The old logic relied on a string comparison of the module path. This broke when Python's importlib switched to using absolute paths. It's safer to simply look for a ``__yoyo_collector__`` attribute in the caller frame's namespace, which is injected by Migration.load. I also chose to rename this attribute from 'collector' to __yoyo_collector__ to reduce the risk of collision with a user defined variable.
-rwxr-xr-xyoyo/migrations.py28
1 files changed, 13 insertions, 15 deletions
diff --git a/yoyo/migrations.py b/yoyo/migrations.py
index c95ed32..a34d262 100755
--- a/yoyo/migrations.py
+++ b/yoyo/migrations.py
@@ -25,7 +25,6 @@ from logging import getLogger
from typing import Dict
from typing import Iterable
from typing import List
-from typing import MutableMapping
from typing import Tuple
import hashlib
import importlib.util
@@ -35,7 +34,6 @@ import sys
import inspect
import types
import textwrap
-import weakref
import pkg_resources
import sqlparse
@@ -48,8 +46,6 @@ default_migration_table = "_yoyo_migration"
hash_function = hashlib.sha256
-_collectors: MutableMapping[str, "StepCollector"] = (weakref.WeakValueDictionary())
-
def _is_migration_file(path):
"""
@@ -174,7 +170,6 @@ class Migration(object):
return
collector = StepCollector(migration=self)
- _collectors[self.path] = collector
with open(self.path, "r") as f:
self.source = f.read()
@@ -187,7 +182,7 @@ class Migration(object):
self.module.step = collector.add_step # type: ignore
self.module.group = collector.add_step_group # type: ignore
self.module.transaction = collector.add_step_group # type: ignore
- self.module.collector = collector # type: ignore
+ self.module.__yoyo_collector__ = collector # type: ignore
if self.is_raw_sql():
directives, leading_comment, statements = read_sql_migration(self.path)
_, _, rollback_statements = read_sql_migration(
@@ -199,7 +194,7 @@ class Migration(object):
)
for s, r in statements_with_rollback:
- self.module.collector.add_step(s, r) # type: ignore
+ collector.add_step(s, r)
self.module.__doc__ = leading_comment
setattr(
self.module,
@@ -644,14 +639,17 @@ class StepCollector(object):
return [create_step(use_transactions) for create_step in self.steps]
-def _get_collector(depth=2):
- for stackframe in reversed(inspect.stack()):
- path = stackframe.frame.f_code.co_filename
- if path in _collectors:
- return _collectors[path]
- raise AssertionError(
- "Expected to be called in the context of a migration module import"
- )
+def _get_collector():
+ frame = inspect.currentframe()
+ try:
+ while frame is not None:
+ if "__yoyo_collector__" in frame.f_globals:
+ return frame.f_globals["__yoyo_collector__"]
+ frame = frame.f_back
+ except KeyError:
+ raise AssertionError(
+ "Expected to be called in the context of a migration module import"
+ )
def step(*args, **kwargs):