From c7d92dcb1d3f3d007812f9a273735c1071c7bbdc Mon Sep 17 00:00:00 2001 From: Olly Cope Date: Wed, 14 Dec 2022 14:01:47 +0000 Subject: Replace py36-compatible type annotation comments with newer syntax --- yoyo/backends/base.py | 2 +- yoyo/config.py | 28 ++++++++++++---------------- yoyo/internalmigrations/v2.py | 4 ++-- yoyo/migrations.py | 31 ++++++++++++++----------------- 4 files changed, 29 insertions(+), 36 deletions(-) diff --git a/yoyo/backends/base.py b/yoyo/backends/base.py index 8ca956e..100849a 100644 --- a/yoyo/backends/base.py +++ b/yoyo/backends/base.py @@ -107,7 +107,7 @@ class SavepointTransactionManager(TransactionManager): class DatabaseBackend: - driver_module = "" # type: str + driver_module = "" log_table = "_yoyo_log" lock_table = "yoyo_lock" diff --git a/yoyo/config.py b/yoyo/config.py index 0569a12..9517874 100644 --- a/yoyo/config.py +++ b/yoyo/config.py @@ -18,15 +18,11 @@ Handle config file and argument parsing from collections import deque from configparser import ConfigParser from pathlib import Path -from typing import Dict -from typing import List -from typing import Optional -from typing import Tuple -from typing import Union import configparser import functools import itertools import os +import typing as t CONFIG_FILENAME = "yoyo.ini" CONFIG_EDITOR_KEY = "editor" @@ -44,7 +40,7 @@ class CircularReferenceError(configparser.Error): class CustomInterpolation(configparser.BasicInterpolation): - defaults = {} # type: Dict[str, str] + defaults: t.Dict[str, str] = {} def __init__(self, defaults): self.defaults = defaults or {} @@ -57,7 +53,7 @@ class CustomInterpolation(configparser.BasicInterpolation): ) -def get_interpolation_defaults(path: Optional[str] = None): +def get_interpolation_defaults(path: t.Optional[str] = None): parser = configparser.ConfigParser() defaults = { parser.optionxform(k): v.replace("%", "%%") for k, v in os.environ.items() @@ -82,7 +78,7 @@ def update_argparser_defaults(parser, defaults): parser.set_defaults(**{k: v for k, v in defaults.items() if k in known_args}) -def read_config(src: Optional[str]) -> ConfigParser: +def read_config(src: t.Optional[str]) -> ConfigParser: """ Read the configuration file at ``src`` and construct a ConfigParser instance. @@ -95,9 +91,9 @@ def read_config(src: Optional[str]) -> ConfigParser: config = _read_config(path) config_files = {path: config} merge_paths = deque([path]) - to_process = [ - ((), path, config) - ] # type: List[Tuple[Union[Tuple, Tuple[Path]], Path, ConfigParser]] + to_process: t.List[t.Tuple[t.Union[t.List[Path]], Path, ConfigParser]] = [ + ([], path, config) + ] while to_process: ancestors, path, config = to_process.pop() inherits, includes = find_includes(path, config) @@ -108,7 +104,7 @@ def read_config(src: Optional[str]) -> ConfigParser: ) config = _read_config(p) config_files[p] = config - to_process.append((ancestors + (path,), p, config)) + to_process.append((ancestors + [path], p, config)) merge_paths.extendleft(inherits) merge_paths.extend(includes) @@ -120,7 +116,7 @@ def read_config(src: Optional[str]) -> ConfigParser: return merged -def _make_path(s: str, basepath: Optional[Path] = None) -> Path: +def _make_path(s: str, basepath: t.Optional[Path] = None) -> Path: """ Return a fully resolved Path. Raises FileNotFoundError if the path does not exist. @@ -152,9 +148,9 @@ def _read_config(path: Path) -> ConfigParser: def find_includes( basepath: Path, config: ConfigParser -) -> Tuple[List[Path], List[Path]]: +) -> t.Tuple[t.List[Path], t.List[Path]]: - result = {INCLUDE: [], INHERIT: []} # type: Dict[str, List[Path]] + result: t.Dict[str, t.List[Path]] = {INCLUDE: [], INHERIT: []} for key in [INHERIT, INCLUDE]: try: paths = config["DEFAULT"][key].split() @@ -177,7 +173,7 @@ def find_includes( return result[INHERIT], result[INCLUDE] -def merge_configs(configs: List[ConfigParser]) -> ConfigParser: +def merge_configs(configs: t.List[ConfigParser]) -> ConfigParser: def merge(c1, c2): c1.read_dict(c2) return c1 diff --git a/yoyo/internalmigrations/v2.py b/yoyo/internalmigrations/v2.py index e81c9e3..1887307 100644 --- a/yoyo/internalmigrations/v2.py +++ b/yoyo/internalmigrations/v2.py @@ -13,8 +13,8 @@ def upgrade(backend): cursor = backend.execute( "SELECT id, ctime FROM {}".format(backend.migration_table_quoted) ) - migration_id = "" # type: str - created_at = datetime(1970, 1, 1) # type: datetime + migration_id = "" + created_at = datetime(1970, 1, 1) for migration_id, created_at in iter(cursor.fetchone, None): # type: ignore migration_hash = get_migration_hash(migration_id) log_data = dict( diff --git a/yoyo/migrations.py b/yoyo/migrations.py index 5b30049..79d371d 100755 --- a/yoyo/migrations.py +++ b/yoyo/migrations.py @@ -21,10 +21,7 @@ from itertools import chain from itertools import count from itertools import zip_longest from logging import getLogger -from typing import Dict -from typing import Iterable -from typing import List -from typing import Tuple +import typing as t import hashlib import importlib.util import os @@ -72,7 +69,7 @@ def get_migration_hash(migration_id): # eg: "-- depends: 1 2" -DirectivesType = Dict[str, str] +DirectivesType = t.Dict[str, str] LeadingCommentType = str @@ -81,7 +78,7 @@ SqlType = str def parse_metadata_from_sql_comments( s: str, -) -> Tuple[DirectivesType, LeadingCommentType, SqlType]: +) -> t.Tuple[DirectivesType, LeadingCommentType, SqlType]: directive_names = ["transactional", "depends"] comment_or_empty = re.compile(r"^(\s*|\s*--.*)$").match directive_pattern = re.compile( @@ -90,7 +87,7 @@ def parse_metadata_from_sql_comments( lineending = re.search(r"\n|\r\n|\r", s + "\n").group(0) # type: ignore lines = iter(s.split(lineending)) - directives = {} # type: DirectivesType + directives: DirectivesType = {} leading_comments = [] sql = [] for line in lines: @@ -117,8 +114,8 @@ def parse_metadata_from_sql_comments( def read_sql_migration( path: str, -) -> Tuple[DirectivesType, LeadingCommentType, List[str]]: - directives = {} # type: DirectivesType +) -> t.Tuple[DirectivesType, LeadingCommentType, t.List[str]]: + directives: DirectivesType = {} leading_comment = "" statements = [] if os.path.exists(path): @@ -137,7 +134,7 @@ def read_sql_migration( class Migration(object): - __all_migrations = {} # type: Dict[str, "Migration"] + __all_migrations: t.Dict[str, "Migration"] = {} def __init__(self, id, path, source_dir): self.id = id @@ -451,7 +448,7 @@ class StepGroup(MigrationStep): item.rollback(backend, force) -def _expand_sources(sources) -> Iterable[Tuple[str, List[str]]]: +def _expand_sources(sources) -> t.Iterable[t.Tuple[str, t.List[str]]]: package_match = re.compile(r"^package:([^\s\/:]+):(.*)$").match for source in sources: mo = package_match(source) @@ -482,7 +479,7 @@ def read_migrations(*sources): """ Return a ``MigrationList`` containing all migrations from ``sources``. """ - migrations = OrderedDict() # type: Dict[str, MigrationList] + migrations: t.Dict[str, MigrationList] = OrderedDict() for source, paths in _expand_sources(sources): for path in paths: @@ -526,7 +523,7 @@ class MigrationList(abc.MutableSequence): return "{}({})".format(self.__class__.__name__, repr(self.items)) def check_conflicts(self): - c = Counter() # type: Dict[str, int] + c: t.Dict[str, int] = Counter() for item in self: c[item.id] += 1 if c[item.id] > 1: @@ -601,9 +598,9 @@ class StepCollector(object): def do_add(use_transactions): wrapper = TransactionWrapper if use_transactions else Transactionless - t = MigrationStep(next(self.step_id), apply, rollback) # type: StepBase - t = wrapper(t, ignore_errors) - return t + step: StepBase = MigrationStep(next(self.step_id), apply, rollback) + step = wrapper(step, ignore_errors) + return step self.steps[do_add] = 1 return do_add @@ -721,7 +718,7 @@ def heads(migration_list): return heads -def topological_sort(migrations: Iterable[Migration]) -> Iterable[Migration]: +def topological_sort(migrations: t.Iterable[Migration]) -> t.Iterable[Migration]: from yoyo.topologicalsort import topological_sort as topological_sort_impl from yoyo.topologicalsort import CycleError -- cgit v1.2.1