summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorOlly Cope <olly@ollycope.com>2022-09-01 11:49:44 +0000
committerOlly Cope <olly@ollycope.com>2022-09-01 11:49:44 +0000
commit290150383e2a12628ccc72c4f4fed252a36d0f40 (patch)
tree3c44235d1d6378031f0ac95b2e9b70e434c875a9
parent67160446adfedc322b632ff92bfe8a1f843d881a (diff)
downloadyoyo-290150383e2a12628ccc72c4f4fed252a36d0f40.tar.gz
Reformat with Black
-rw-r--r--yoyo/config.py9
-rw-r--r--yoyo/internalmigrations/__init__.py4
-rw-r--r--yoyo/internalmigrations/v2.py3
-rwxr-xr-xyoyo/migrations.py12
-rwxr-xr-xyoyo/scripts/main.py19
-rwxr-xr-xyoyo/scripts/migrate.py43
-rw-r--r--yoyo/tests/conftest.py4
-rw-r--r--yoyo/tests/test_backends.py25
-rw-r--r--yoyo/tests/test_cli_script.py54
-rw-r--r--yoyo/tests/test_config.py6
-rw-r--r--yoyo/tests/test_connections.py8
-rw-r--r--yoyo/tests/test_internalmigrations.py4
-rwxr-xr-xyoyo/utils.py13
13 files changed, 68 insertions, 136 deletions
diff --git a/yoyo/config.py b/yoyo/config.py
index a7b2806..9010cac 100644
--- a/yoyo/config.py
+++ b/yoyo/config.py
@@ -60,8 +60,7 @@ class CustomInterpolation(configparser.BasicInterpolation):
def get_interpolation_defaults(path: Optional[str] = None):
parser = configparser.ConfigParser()
defaults = {
- parser.optionxform(k): v.replace("%", "%%")
- for k, v in os.environ.items()
+ parser.optionxform(k): v.replace("%", "%%") for k, v in os.environ.items()
}
if path:
defaults["here"] = os.path.dirname(os.path.abspath(path))
@@ -80,9 +79,7 @@ def update_argparser_defaults(parser, defaults):
arguments the parser has configured.
"""
known_args = {action.dest for action in parser._actions}
- parser.set_defaults(
- **{k: v for k, v in defaults.items() if k in known_args}
- )
+ parser.set_defaults(**{k: v for k, v in defaults.items() if k in known_args})
def read_config(src: Optional[str]) -> ConfigParser:
@@ -179,7 +176,7 @@ def find_includes(
return result[INHERIT], result[INCLUDE]
-def merge_configs(configs: List[ConfigParser],) -> ConfigParser:
+def merge_configs(configs: List[ConfigParser]) -> ConfigParser:
def merge(c1, c2):
c1.read_dict(c2)
return c1
diff --git a/yoyo/internalmigrations/__init__.py b/yoyo/internalmigrations/__init__.py
index 65db528..5bcb0b5 100644
--- a/yoyo/internalmigrations/__init__.py
+++ b/yoyo/internalmigrations/__init__.py
@@ -67,8 +67,6 @@ def mark_schema_version(backend, version):
if version < USE_VERSION_TABLE_FROM:
return
backend.execute(
- "INSERT INTO {0.version_table_quoted} VALUES (:version, :when)".format(
- backend
- ),
+ "INSERT INTO {0.version_table_quoted} VALUES (:version, :when)".format(backend),
{"version": version, "when": datetime.utcnow()},
)
diff --git a/yoyo/internalmigrations/v2.py b/yoyo/internalmigrations/v2.py
index 0a4848d..e81c9e3 100644
--- a/yoyo/internalmigrations/v2.py
+++ b/yoyo/internalmigrations/v2.py
@@ -21,8 +21,7 @@ def upgrade(backend):
backend.get_log_data(),
operation="apply",
comment=(
- "this log entry created automatically by an "
- "internal schema upgrade"
+ "this log entry created automatically by an internal schema upgrade"
),
created_at_utc=created_at,
migration_hash=migration_hash,
diff --git a/yoyo/migrations.py b/yoyo/migrations.py
index b4d1a7d..5450d0a 100755
--- a/yoyo/migrations.py
+++ b/yoyo/migrations.py
@@ -125,9 +125,11 @@ def read_sql_migration(
with open(path, "r", encoding="UTF-8") as f:
statements = sqlparse.split(f.read())
if statements:
- (directives, leading_comment, sql,) = parse_metadata_from_sql_comments(
- statements[0]
- )
+ (
+ directives,
+ leading_comment,
+ sql,
+ ) = parse_metadata_from_sql_comments(statements[0])
statements[0] = sql
statements = [s for s in statements if s.strip()]
return directives, leading_comment, statements
@@ -495,7 +497,9 @@ def read_migrations(*sources):
migration_class = PostApplyHookMigration
migration = migration_class(
- os.path.splitext(os.path.basename(path))[0], path, source_dir=source,
+ os.path.splitext(os.path.basename(path))[0],
+ path,
+ source_dir=source,
)
ml = migrations.setdefault(source, MigrationList())
if migration_class is PostApplyHookMigration:
diff --git a/yoyo/scripts/main.py b/yoyo/scripts/main.py
index b3c33cf..e6bae2b 100755
--- a/yoyo/scripts/main.py
+++ b/yoyo/scripts/main.py
@@ -68,9 +68,7 @@ def parse_args(argv=None):
# Read the config file and create a dictionary of defaults for argparser
config = read_config(
- (global_args.config or find_config())
- if global_args.use_config_file
- else None
+ (global_args.config or find_config()) if global_args.use_config_file else None
)
defaults = {}
@@ -116,8 +114,7 @@ def make_argparser():
dest="verbosity",
action="count",
default=min_verbosity,
- help="Verbose output. Use multiple times "
- "to increase level of verbosity",
+ help="Verbose output. Use multiple times to increase level of verbosity",
)
global_parser.add_argument(
"-b",
@@ -180,9 +177,7 @@ def upgrade_legacy_config(args, config, sources):
legacy_config = read_config(path)
- def transfer_setting(
- oldname, newname, transform=None, section="DEFAULT"
- ):
+ def transfer_setting(oldname, newname, transform=None, section="DEFAULT"):
try:
config.get(section, newname)
except configparser.NoOptionError:
@@ -205,9 +200,7 @@ def upgrade_legacy_config(args, config, sources):
config_path = args.config or CONFIG_FILENAME
if not args.batch_mode:
if utils.confirm(
- "Move legacy configuration in {!r} to {!r}?".format(
- path, config_path
- )
+ "Move legacy configuration in {!r} to {!r}?".format(path, config_path)
):
save_config(config, config_path)
try:
@@ -231,9 +224,7 @@ def upgrade_legacy_config(args, config, sources):
)
try:
- args.database = args.database or legacy_config.get(
- "DEFAULT", "dburi"
- )
+ args.database = args.database or legacy_config.get("DEFAULT", "dburi")
except configparser.NoOptionError:
pass
try:
diff --git a/yoyo/scripts/migrate.py b/yoyo/scripts/migrate.py
index b7c79f3..15de7b0 100755
--- a/yoyo/scripts/migrate.py
+++ b/yoyo/scripts/migrate.py
@@ -70,17 +70,12 @@ def install_argparsers(global_parser, subparsers):
filter_parser.add_argument(
"-r",
"--revision",
- help=(
- "Apply/rollback migration with id REVISION and all its "
- "dependencies"
- ),
+ help=("Apply/rollback migration with id REVISION and all its dependencies"),
metavar="REVISION",
)
# Options related to applying/rolling back migrations
- apply_parser = argparse.ArgumentParser(
- add_help=False, parents=[filter_parser]
- )
+ apply_parser = argparse.ArgumentParser(add_help=False, parents=[filter_parser])
apply_parser.add_argument(
"-a",
"--all",
@@ -94,8 +89,7 @@ def install_argparsers(global_parser, subparsers):
"--force",
dest="force",
action="store_true",
- help="Force apply/rollback of steps even if "
- "previous steps have failed",
+ help="Force apply/rollback of steps even if previous steps have failed",
)
parser_apply = subparsers.add_parser(
@@ -119,7 +113,10 @@ def install_argparsers(global_parser, subparsers):
)
parser_develop.set_defaults(func=develop, command_name="develop")
parser_develop.add_argument(
- "-n", type=int, help="Act on the last N migrations", default=1,
+ "-n",
+ type=int,
+ help="Act on the last N migrations",
+ default=1,
)
parser_list = subparsers.add_parser(
@@ -183,9 +180,7 @@ def migrations_to_revision(migrations, revision, direction):
targets = [m for m in migrations if revision in m.id]
if len(targets) == 0:
- raise InvalidArgument(
- "'{}' doesn't match any revisions.".format(revision)
- )
+ raise InvalidArgument("'{}' doesn't match any revisions.".format(revision))
if len(targets) > 1:
raise InvalidArgument(
"'{}' matches multiple revisions. "
@@ -233,12 +228,7 @@ def get_migrations(args, backend, direction=None):
if not args.batch_mode and not args.revision:
migrations = prompt_migrations(backend, migrations, args.command_name)
- if (
- args.batch_mode
- and not args.revision
- and not args.all
- and args.func is rollback
- ):
+ if args.batch_mode and not args.revision and not args.all and args.func is rollback:
if len(migrations) > 1:
warnings.warn(
"Only rolling back a single migration."
@@ -258,9 +248,7 @@ def get_migrations(args, backend, direction=None):
print(" [{m.id}]".format(m=m))
prompt = "{} {} to {}".format(
args.command_name.title(),
- utils.plural(
- len(migrations), "this migration", "these %d migrations"
- ),
+ utils.plural(len(migrations), "this migration", "these %d migrations"),
dburi,
)
if not utils.confirm(prompt, default="y"):
@@ -385,15 +373,11 @@ def prompt_migrations(backend, migrations, direction):
choice = "n" if is_applied else "y"
else:
choice = "y" if is_applied else "n"
- options = "".join(
- o.upper() if o == choice else o.lower() for o in "ynvdaqjk?"
- )
+ options = "".join(o.upper() if o == choice else o.lower() for o in "ynvdaqjk?")
print("")
print("[%s]" % (mig.migration.id,))
- response = utils.prompt(
- "Shall I %s this migration?" % (direction,), options
- )
+ response = utils.prompt("Shall I %s this migration?" % (direction,), options)
if response == "?":
print("")
@@ -403,8 +387,7 @@ def prompt_migrations(backend, migrations, direction):
print("v: view this migration in full")
print("")
print(
- "d: %s the selected migrations, skipping any remaining"
- % (direction,)
+ "d: %s the selected migrations, skipping any remaining" % (direction,)
)
print("a: %s all the remaining migrations" % (direction,))
print("q: cancel without making any changes")
diff --git a/yoyo/tests/conftest.py b/yoyo/tests/conftest.py
index f92e59e..1a8ce51 100644
--- a/yoyo/tests/conftest.py
+++ b/yoyo/tests/conftest.py
@@ -15,9 +15,7 @@ def _backend(dburi):
with backend.transaction():
if backend.__class__ is backends.MySQLBackend:
backend.execute(
- "CREATE TABLE yoyo_t "
- "(id CHAR(1) primary key) "
- "ENGINE=InnoDB"
+ "CREATE TABLE yoyo_t (id CHAR(1) primary key) ENGINE=InnoDB"
)
else:
backend.execute("CREATE TABLE yoyo_t " "(id CHAR(1) primary key)")
diff --git a/yoyo/tests/test_backends.py b/yoyo/tests/test_backends.py
index abf8d47..c812f3d 100644
--- a/yoyo/tests/test_backends.py
+++ b/yoyo/tests/test_backends.py
@@ -101,14 +101,10 @@ class TestTransactionHandling(object):
backend.execute("INSERT INTO yoyo_b VALUES (1)")
trans.rollback()
- count_a = backend.execute("SELECT COUNT(1) FROM yoyo_a").fetchall()[0][
- 0
- ]
+ count_a = backend.execute("SELECT COUNT(1) FROM yoyo_a").fetchall()[0][0]
assert count_a == 1
- count_b = backend.execute("SELECT COUNT(1) FROM yoyo_b").fetchall()[0][
- 0
- ]
+ count_b = backend.execute("SELECT COUNT(1) FROM yoyo_b").fetchall()[0][0]
assert count_b == 0
def test_statements_requiring_no_transaction(self):
@@ -183,19 +179,14 @@ class TestConcurrency(object):
return lock_sleep
def skip_if_not_concurrency_safe(self, backend):
- if (
- "sqlite" in backend.uri.scheme
- and backend.uri.database == ":memory:"
- ):
+ if "sqlite" in backend.uri.scheme and backend.uri.database == ":memory:":
pytest.skip(
"Concurrency tests not supported for SQLite "
"in-memory databases, which cannot be shared "
"between threads"
)
if backend.driver.threadsafety < 1:
- pytest.skip(
- "Concurrency tests not supported for non-threadsafe backends"
- )
+ pytest.skip("Concurrency tests not supported for non-threadsafe backends")
def test_lock(self, dburi):
"""
@@ -310,9 +301,7 @@ class TestInitConnection(object):
finally:
with backend.transaction():
- backend.execute(
- "ALTER DATABASE {} RESET SEARCH_PATH".format(dbname)
- )
+ backend.execute("ALTER DATABASE {} RESET SEARCH_PATH".format(dbname))
backend.execute("DROP SCHEMA custom_schema CASCADE")
def test_postgresql_migrations_can_change_schema_search_path(self):
@@ -324,9 +313,7 @@ class TestInitConnection(object):
pytest.skip("PostgreSQL backend not available")
backend = get_backend(dburi)
with migrations_dir(
- **{
- "1.sql": "SELECT pg_catalog.set_config('search_path', '', false)"
- }
+ **{"1.sql": "SELECT pg_catalog.set_config('search_path', '', false)"}
) as tmpdir:
migrations = read_migrations(tmpdir)
backend.apply_migrations(migrations)
diff --git a/yoyo/tests/test_cli_script.py b/yoyo/tests/test_cli_script.py
index b385e1c..8199ee1 100644
--- a/yoyo/tests/test_cli_script.py
+++ b/yoyo/tests/test_cli_script.py
@@ -39,9 +39,9 @@ from yoyo.scripts import newmigration
def is_tmpfile(p, directory=None):
- return (
- p.startswith(directory) if directory else True
- ) and os.path.basename(p).startswith(newmigration.tempfile_prefix)
+ return (p.startswith(directory) if directory else True) and os.path.basename(
+ p
+ ).startswith(newmigration.tempfile_prefix)
class TestInteractiveScript(object):
@@ -127,9 +127,9 @@ class TestYoyoScript(TestInteractiveScript):
def test_it_prompts_password(self, tmpdir):
dburi = "sqlite://user@/:memory"
- with patch(
- "yoyo.scripts.main.getpass", return_value="fish"
- ) as getpass, patch("yoyo.connections.get_backend") as get_backend:
+ with patch("yoyo.scripts.main.getpass", return_value="fish") as getpass, patch(
+ "yoyo.connections.get_backend"
+ ) as get_backend:
main(["apply", str(tmpdir), "--database", dburi, "--prompt-password"])
assert getpass.call_count == 1
assert get_backend.call_args == call(
@@ -189,9 +189,7 @@ class TestYoyoScript(TestInteractiveScript):
self.confirm.return_value = True
main(["apply", str(tmpdir)])
- prompts = [
- args[0].lower() for args, kwargs in self.confirm.call_args_list
- ]
+ prompts = [args[0].lower() for args, kwargs in self.confirm.call_args_list]
assert len(prompts) == 2
assert prompts[0].startswith("move legacy configuration")
assert prompts[1].startswith("delete legacy configuration")
@@ -243,9 +241,7 @@ class TestYoyoScript(TestInteractiveScript):
run_migrations = partial(
main, ["apply", "-b", tmpdir, "--database", str(backend.uri)]
)
- threads = [
- threading.Thread(target=run_migrations) for ix in range(20)
- ]
+ threads = [threading.Thread(target=run_migrations) for ix in range(20)]
for t in threads:
t.start()
for t in threads:
@@ -266,16 +262,14 @@ class TestYoyoScript(TestInteractiveScript):
)
backend.commit()
main(["break-lock", "--database", dburi])
- lock_count = backend.execute(
- "SELECT COUNT(1) FROM yoyo_lock"
- ).fetchone()[0]
+ lock_count = backend.execute("SELECT COUNT(1) FROM yoyo_lock").fetchone()[0]
assert lock_count == 0
def test_it_prompts_password_on_break_lock(self):
dburi = "sqlite://user@/:memory"
- with patch(
- "yoyo.scripts.main.getpass", return_value="fish"
- ) as getpass, patch("yoyo.connections.get_backend") as get_backend:
+ with patch("yoyo.scripts.main.getpass", return_value="fish") as getpass, patch(
+ "yoyo.connections.get_backend"
+ ) as get_backend:
main(["break-lock", "--database", dburi, "--prompt-password"])
assert getpass.call_count == 1
assert get_backend.call_args == call(
@@ -331,9 +325,7 @@ class TestMarkCommand(TestInteractiveScript):
backend = get_backend(self.dburi)
backend.apply_migrations(migrations[:1])
- with patch(
- "yoyo.scripts.migrate.prompt_migrations"
- ) as prompt_migrations:
+ with patch("yoyo.scripts.migrate.prompt_migrations") as prompt_migrations:
main(["mark", tmpdir, "--database", self.dburi])
_, prompted, _ = prompt_migrations.call_args[0]
prompted = [m.id for m in prompted]
@@ -373,9 +365,7 @@ class TestUnmarkCommand(TestInteractiveScript):
backend.apply_migrations(migrations[:2])
assert len(backend.get_applied_migration_hashes()) == 2
- with patch(
- "yoyo.scripts.migrate.prompt_migrations"
- ) as prompt_migrations:
+ with patch("yoyo.scripts.migrate.prompt_migrations") as prompt_migrations:
main(["unmark", tmpdir, "--database", self.dburi])
_, prompted, _ = prompt_migrations.call_args[0]
prompted = [m.id for m in prompted]
@@ -454,9 +444,7 @@ class TestNewMigration(TestInteractiveScript):
# default to $VISUAL
with patch("os.environ", {"EDITOR": "ed", "VISUAL": "visualed"}):
main(["new", str(tmpdir), "--database", dburi_sqlite3])
- assert self.subprocess.call.call_args == call(
- ["visualed", tms.Unicode()]
- )
+ assert self.subprocess.call.call_args == call(["visualed", tms.Unicode()])
# fallback to $EDITOR
with patch("os.environ", {"EDITOR": "ed"}):
@@ -490,9 +478,7 @@ class TestNewMigration(TestInteractiveScript):
self.subprocess.call = write_migration
main(["new", str(tmpdir), "--database", dburi_sqlite3])
- prompts = [
- args[0].lower() for args, kwargs in self.prompt.call_args_list
- ]
+ prompts = [args[0].lower() for args, kwargs in self.prompt.call_args_list]
assert "retry editing?" in prompts[0]
def test_it_defaults_docstring_to_message(self, tmpdir):
@@ -508,18 +494,14 @@ class TestNewMigration(TestInteractiveScript):
]
)
names = [n for n in sorted(os.listdir(tmpdir)) if n.endswith(".py")]
- with open(
- os.path.join(str(tmpdir), names[0]), "r", encoding="utf-8"
- ) as f:
+ with open(os.path.join(str(tmpdir), names[0]), "r", encoding="utf-8") as f:
assert "your ad here" in f.read()
def test_it_calls_post_create_command(self, tmpdir):
self.writeconfig(post_create_command="/bin/ls -l {} {}")
with freezegun.freeze_time("2001-1-1"):
main(["new", "-b", str(tmpdir), "--database", dburi_sqlite3])
- is_filename = tms.Str(
- lambda s: os.path.basename(s).startswith("20010101_01_")
- )
+ is_filename = tms.Str(lambda s: os.path.basename(s).startswith("20010101_01_"))
assert self.subprocess.call.call_args == call(
["/bin/ls", "-l", is_filename, is_filename]
)
diff --git a/yoyo/tests/test_config.py b/yoyo/tests/test_config.py
index d62b697..3c03c4c 100644
--- a/yoyo/tests/test_config.py
+++ b/yoyo/tests/test_config.py
@@ -133,12 +133,14 @@ class TestInheritance:
def test_it_raises_on_not_found(self):
with pytest.raises(FileNotFoundError):
_test_files(
- {"a.ini": "[DEFAULT]\n%inherit = b.ini\n"}, {"DEFAULT": {}},
+ {"a.ini": "[DEFAULT]\n%inherit = b.ini\n"},
+ {"DEFAULT": {}},
)
def test_it_ignores_not_found(self):
_test_files(
- {"a.ini": "[DEFAULT]\n%inherit = ?b.ini\n"}, {"DEFAULT": {}},
+ {"a.ini": "[DEFAULT]\n%inherit = ?b.ini\n"},
+ {"DEFAULT": {}},
)
def test_it_traps_circular_references(self):
diff --git a/yoyo/tests/test_connections.py b/yoyo/tests/test_connections.py
index 3877181..ec4b70b 100644
--- a/yoyo/tests/test_connections.py
+++ b/yoyo/tests/test_connections.py
@@ -67,16 +67,14 @@ class TestParseURI:
def test_passwords_with_slashes_dont_break_netloc(self):
parsed = parse_uri("postgresql://user:a%2Fb@localhost:5432/db")
- assert parsed.netloc == 'user:a%2Fb@localhost:5432'
+ assert parsed.netloc == "user:a%2Fb@localhost:5432"
assert parsed.port == 5432
- assert parsed.password == 'a/b'
+ assert parsed.password == "a/b"
@patch(
"yoyo.backends.get_dbapi_module",
- return_value=MagicMock(
- DatabaseError=MockDatabaseError, paramstyle="qmark"
- ),
+ return_value=MagicMock(DatabaseError=MockDatabaseError, paramstyle="qmark"),
)
def test_connections(get_dbapi_module):
diff --git a/yoyo/tests/test_internalmigrations.py b/yoyo/tests/test_internalmigrations.py
index 25128e6..9e6f880 100644
--- a/yoyo/tests/test_internalmigrations.py
+++ b/yoyo/tests/test_internalmigrations.py
@@ -57,9 +57,7 @@ def test_v3_preserves_history_when_upgrading(backend):
cursor = backend.execute(
"SELECT migration_hash, migration_id, applied_at_utc "
- "FROM {0.migration_table_quoted} order by applied_at_utc".format(
- backend
- )
+ "FROM {0.migration_table_quoted} order by applied_at_utc".format(backend)
)
applied = list(cursor.fetchall())
assert applied == [
diff --git a/yoyo/utils.py b/yoyo/utils.py
index a769cb9..f08314d 100755
--- a/yoyo/utils.py
+++ b/yoyo/utils.py
@@ -45,7 +45,7 @@ else:
saved_attributes = termios.tcgetattr(fd)
try:
attributes = termios.tcgetattr(fd) # get a fresh copy!
- attributes[3] = (attributes[3] & ~(termios.ICANON | termios.ECHO)) # type: ignore # noqa: E501
+ attributes[3] = attributes[3] & ~(termios.ICANON | termios.ECHO) # type: ignore # noqa: E501
attributes[6][termios.VMIN] = 1 # type: ignore
attributes[6][termios.VTIME] = 0 # type: ignore
termios.tcsetattr(fd, termios.TCSANOW, attributes)
@@ -68,8 +68,7 @@ def prompt(prompt, options):
ch = getch()
if ch == os.linesep:
ch = (
- [o.lower() for o in options if "A" <= o <= "Z"]
- + list(options.lower())
+ [o.lower() for o in options if "A" <= o <= "Z"] + list(options.lower())
)[0]
print(ch)
if ch.lower() not in options.lower():
@@ -175,9 +174,7 @@ def change_param_style(target_style, sql, bind_parameters):
r"(?=\W|$)"
)
- transformed_sql = pattern.sub(
- lambda match: param_gen(match.group(1)), sql
- )
+ transformed_sql = pattern.sub(lambda match: param_gen(match.group(1)), sql)
if positional:
positional_params = []
for match in pattern.finditer(sql):
@@ -192,6 +189,4 @@ def unidecode(s: str) -> str:
Return ``s`` with unicode diacritics removed.
"""
combining = unicodedata.combining
- return "".join(
- c for c in unicodedata.normalize("NFD", s) if not combining(c)
- )
+ return "".join(c for c in unicodedata.normalize("NFD", s) if not combining(c))