diff options
Diffstat (limited to 'yoyo/migrations.py')
-rwxr-xr-x | yoyo/migrations.py | 31 |
1 files changed, 14 insertions, 17 deletions
diff --git a/yoyo/migrations.py b/yoyo/migrations.py index 5b30049..79d371d 100755 --- a/yoyo/migrations.py +++ b/yoyo/migrations.py @@ -21,10 +21,7 @@ from itertools import chain from itertools import count from itertools import zip_longest from logging import getLogger -from typing import Dict -from typing import Iterable -from typing import List -from typing import Tuple +import typing as t import hashlib import importlib.util import os @@ -72,7 +69,7 @@ def get_migration_hash(migration_id): # eg: "-- depends: 1 2" -DirectivesType = Dict[str, str] +DirectivesType = t.Dict[str, str] LeadingCommentType = str @@ -81,7 +78,7 @@ SqlType = str def parse_metadata_from_sql_comments( s: str, -) -> Tuple[DirectivesType, LeadingCommentType, SqlType]: +) -> t.Tuple[DirectivesType, LeadingCommentType, SqlType]: directive_names = ["transactional", "depends"] comment_or_empty = re.compile(r"^(\s*|\s*--.*)$").match directive_pattern = re.compile( @@ -90,7 +87,7 @@ def parse_metadata_from_sql_comments( lineending = re.search(r"\n|\r\n|\r", s + "\n").group(0) # type: ignore lines = iter(s.split(lineending)) - directives = {} # type: DirectivesType + directives: DirectivesType = {} leading_comments = [] sql = [] for line in lines: @@ -117,8 +114,8 @@ def parse_metadata_from_sql_comments( def read_sql_migration( path: str, -) -> Tuple[DirectivesType, LeadingCommentType, List[str]]: - directives = {} # type: DirectivesType +) -> t.Tuple[DirectivesType, LeadingCommentType, t.List[str]]: + directives: DirectivesType = {} leading_comment = "" statements = [] if os.path.exists(path): @@ -137,7 +134,7 @@ def read_sql_migration( class Migration(object): - __all_migrations = {} # type: Dict[str, "Migration"] + __all_migrations: t.Dict[str, "Migration"] = {} def __init__(self, id, path, source_dir): self.id = id @@ -451,7 +448,7 @@ class StepGroup(MigrationStep): item.rollback(backend, force) -def _expand_sources(sources) -> Iterable[Tuple[str, List[str]]]: +def _expand_sources(sources) -> t.Iterable[t.Tuple[str, t.List[str]]]: package_match = re.compile(r"^package:([^\s\/:]+):(.*)$").match for source in sources: mo = package_match(source) @@ -482,7 +479,7 @@ def read_migrations(*sources): """ Return a ``MigrationList`` containing all migrations from ``sources``. """ - migrations = OrderedDict() # type: Dict[str, MigrationList] + migrations: t.Dict[str, MigrationList] = OrderedDict() for source, paths in _expand_sources(sources): for path in paths: @@ -526,7 +523,7 @@ class MigrationList(abc.MutableSequence): return "{}({})".format(self.__class__.__name__, repr(self.items)) def check_conflicts(self): - c = Counter() # type: Dict[str, int] + c: t.Dict[str, int] = Counter() for item in self: c[item.id] += 1 if c[item.id] > 1: @@ -601,9 +598,9 @@ class StepCollector(object): def do_add(use_transactions): wrapper = TransactionWrapper if use_transactions else Transactionless - t = MigrationStep(next(self.step_id), apply, rollback) # type: StepBase - t = wrapper(t, ignore_errors) - return t + step: StepBase = MigrationStep(next(self.step_id), apply, rollback) + step = wrapper(step, ignore_errors) + return step self.steps[do_add] = 1 return do_add @@ -721,7 +718,7 @@ def heads(migration_list): return heads -def topological_sort(migrations: Iterable[Migration]) -> Iterable[Migration]: +def topological_sort(migrations: t.Iterable[Migration]) -> t.Iterable[Migration]: from yoyo.topologicalsort import topological_sort as topological_sort_impl from yoyo.topologicalsort import CycleError |