summaryrefslogtreecommitdiff
path: root/yoyo/migrations.py
diff options
context:
space:
mode:
Diffstat (limited to 'yoyo/migrations.py')
-rwxr-xr-xyoyo/migrations.py31
1 files changed, 14 insertions, 17 deletions
diff --git a/yoyo/migrations.py b/yoyo/migrations.py
index 5b30049..79d371d 100755
--- a/yoyo/migrations.py
+++ b/yoyo/migrations.py
@@ -21,10 +21,7 @@ from itertools import chain
from itertools import count
from itertools import zip_longest
from logging import getLogger
-from typing import Dict
-from typing import Iterable
-from typing import List
-from typing import Tuple
+import typing as t
import hashlib
import importlib.util
import os
@@ -72,7 +69,7 @@ def get_migration_hash(migration_id):
# eg: "-- depends: 1 2"
-DirectivesType = Dict[str, str]
+DirectivesType = t.Dict[str, str]
LeadingCommentType = str
@@ -81,7 +78,7 @@ SqlType = str
def parse_metadata_from_sql_comments(
s: str,
-) -> Tuple[DirectivesType, LeadingCommentType, SqlType]:
+) -> t.Tuple[DirectivesType, LeadingCommentType, SqlType]:
directive_names = ["transactional", "depends"]
comment_or_empty = re.compile(r"^(\s*|\s*--.*)$").match
directive_pattern = re.compile(
@@ -90,7 +87,7 @@ def parse_metadata_from_sql_comments(
lineending = re.search(r"\n|\r\n|\r", s + "\n").group(0) # type: ignore
lines = iter(s.split(lineending))
- directives = {} # type: DirectivesType
+ directives: DirectivesType = {}
leading_comments = []
sql = []
for line in lines:
@@ -117,8 +114,8 @@ def parse_metadata_from_sql_comments(
def read_sql_migration(
path: str,
-) -> Tuple[DirectivesType, LeadingCommentType, List[str]]:
- directives = {} # type: DirectivesType
+) -> t.Tuple[DirectivesType, LeadingCommentType, t.List[str]]:
+ directives: DirectivesType = {}
leading_comment = ""
statements = []
if os.path.exists(path):
@@ -137,7 +134,7 @@ def read_sql_migration(
class Migration(object):
- __all_migrations = {} # type: Dict[str, "Migration"]
+ __all_migrations: t.Dict[str, "Migration"] = {}
def __init__(self, id, path, source_dir):
self.id = id
@@ -451,7 +448,7 @@ class StepGroup(MigrationStep):
item.rollback(backend, force)
-def _expand_sources(sources) -> Iterable[Tuple[str, List[str]]]:
+def _expand_sources(sources) -> t.Iterable[t.Tuple[str, t.List[str]]]:
package_match = re.compile(r"^package:([^\s\/:]+):(.*)$").match
for source in sources:
mo = package_match(source)
@@ -482,7 +479,7 @@ def read_migrations(*sources):
"""
Return a ``MigrationList`` containing all migrations from ``sources``.
"""
- migrations = OrderedDict() # type: Dict[str, MigrationList]
+ migrations: t.Dict[str, MigrationList] = OrderedDict()
for source, paths in _expand_sources(sources):
for path in paths:
@@ -526,7 +523,7 @@ class MigrationList(abc.MutableSequence):
return "{}({})".format(self.__class__.__name__, repr(self.items))
def check_conflicts(self):
- c = Counter() # type: Dict[str, int]
+ c: t.Dict[str, int] = Counter()
for item in self:
c[item.id] += 1
if c[item.id] > 1:
@@ -601,9 +598,9 @@ class StepCollector(object):
def do_add(use_transactions):
wrapper = TransactionWrapper if use_transactions else Transactionless
- t = MigrationStep(next(self.step_id), apply, rollback) # type: StepBase
- t = wrapper(t, ignore_errors)
- return t
+ step: StepBase = MigrationStep(next(self.step_id), apply, rollback)
+ step = wrapper(step, ignore_errors)
+ return step
self.steps[do_add] = 1
return do_add
@@ -721,7 +718,7 @@ def heads(migration_list):
return heads
-def topological_sort(migrations: Iterable[Migration]) -> Iterable[Migration]:
+def topological_sort(migrations: t.Iterable[Migration]) -> t.Iterable[Migration]:
from yoyo.topologicalsort import topological_sort as topological_sort_impl
from yoyo.topologicalsort import CycleError