diff options
author | Olly Cope <olly@ollycope.com> | 2021-05-23 11:38:54 +0000 |
---|---|---|
committer | Olly Cope <olly@ollycope.com> | 2021-05-23 11:38:54 +0000 |
commit | cec1c9cb184fa468c82cea3d5951dadf26be43a6 (patch) | |
tree | 160b0bfe08b642eca2ece90c60b8b51abffdb8f6 | |
parent | 62bcd9849eea8b4faf71b00e85d3d5bbe1a03e42 (diff) | |
download | yoyo-cec1c9cb184fa468c82cea3d5951dadf26be43a6.tar.gz |
Fix access to StepCollector from imported yoyo.migrations.step
Fixes https://todo.sr.ht/~olly/yoyo/79
The old logic relied on a string comparison of the module path. This broke when
Python's importlib switched to using absolute paths.
It's safer to simply look for a ``__yoyo_collector__`` attribute in the caller
frame's namespace, which is injected by Migration.load. I also chose to rename
this attribute from 'collector' to __yoyo_collector__ to reduce
the risk of collision with a user defined variable.
-rwxr-xr-x | yoyo/migrations.py | 28 |
1 files changed, 13 insertions, 15 deletions
diff --git a/yoyo/migrations.py b/yoyo/migrations.py index c95ed32..a34d262 100755 --- a/yoyo/migrations.py +++ b/yoyo/migrations.py @@ -25,7 +25,6 @@ from logging import getLogger from typing import Dict from typing import Iterable from typing import List -from typing import MutableMapping from typing import Tuple import hashlib import importlib.util @@ -35,7 +34,6 @@ import sys import inspect import types import textwrap -import weakref import pkg_resources import sqlparse @@ -48,8 +46,6 @@ default_migration_table = "_yoyo_migration" hash_function = hashlib.sha256 -_collectors: MutableMapping[str, "StepCollector"] = (weakref.WeakValueDictionary()) - def _is_migration_file(path): """ @@ -174,7 +170,6 @@ class Migration(object): return collector = StepCollector(migration=self) - _collectors[self.path] = collector with open(self.path, "r") as f: self.source = f.read() @@ -187,7 +182,7 @@ class Migration(object): self.module.step = collector.add_step # type: ignore self.module.group = collector.add_step_group # type: ignore self.module.transaction = collector.add_step_group # type: ignore - self.module.collector = collector # type: ignore + self.module.__yoyo_collector__ = collector # type: ignore if self.is_raw_sql(): directives, leading_comment, statements = read_sql_migration(self.path) _, _, rollback_statements = read_sql_migration( @@ -199,7 +194,7 @@ class Migration(object): ) for s, r in statements_with_rollback: - self.module.collector.add_step(s, r) # type: ignore + collector.add_step(s, r) self.module.__doc__ = leading_comment setattr( self.module, @@ -644,14 +639,17 @@ class StepCollector(object): return [create_step(use_transactions) for create_step in self.steps] -def _get_collector(depth=2): - for stackframe in reversed(inspect.stack()): - path = stackframe.frame.f_code.co_filename - if path in _collectors: - return _collectors[path] - raise AssertionError( - "Expected to be called in the context of a migration module import" - ) +def _get_collector(): + frame = inspect.currentframe() + try: + while frame is not None: + if "__yoyo_collector__" in frame.f_globals: + return frame.f_globals["__yoyo_collector__"] + frame = frame.f_back + except KeyError: + raise AssertionError( + "Expected to be called in the context of a migration module import" + ) def step(*args, **kwargs): |