summaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
authorostr00000 <ostr00000@gmail.com>2023-02-27 18:18:19 -0500
committerMike Bayer <mike_mp@zzzcomputing.com>2023-02-28 19:04:59 -0500
commita55f79a109008e79f740ab40e6a9fc597785cbbe (patch)
tree5b346c4e46767ceb0ddd732edf7d85c173542771 /tests
parentb53ec0004a08c40a25a4dbf047c51cd140971a9c (diff)
downloadalembic-a55f79a109008e79f740ab40e6a9fc597785cbbe.tar.gz
add recursive_version_locations option for searching revision files
Recursive traversal of revision files in a particular revision directory is now supported, by indicating ``recursive_version_locations = true`` in alembic.ini. Pull request courtesy ostr00000. Fixes: #760 Closes: #1182 Pull-request: https://github.com/sqlalchemy/alembic/pull/1182 Pull-request-sha: ecb0da48b459abd3f5e95390ec7030a7e3fcbc6d Change-Id: I711ca2dbd35fb9a2acdbfd374bcac13043b0d129
Diffstat (limited to 'tests')
-rw-r--r--tests/test_script_consumption.py333
-rw-r--r--tests/test_script_production.py2
2 files changed, 334 insertions, 1 deletions
diff --git a/tests/test_script_consumption.py b/tests/test_script_consumption.py
index fa84d7e..a107b80 100644
--- a/tests/test_script_consumption.py
+++ b/tests/test_script_consumption.py
@@ -1,7 +1,12 @@
+from __future__ import annotations
+
from contextlib import contextmanager
import os
import re
+import shutil
import textwrap
+from typing import Dict
+from typing import List
import sqlalchemy as sa
from sqlalchemy import pool
@@ -9,18 +14,24 @@ from sqlalchemy import pool
from alembic import command
from alembic import testing
from alembic import util
+from alembic.config import Config
from alembic.environment import EnvironmentContext
from alembic.script import Script
from alembic.script import ScriptDirectory
from alembic.testing import assert_raises_message
+from alembic.testing import assertions
from alembic.testing import config
from alembic.testing import eq_
+from alembic.testing import expect_raises_message
from alembic.testing import mock
+from alembic.testing.env import _get_staging_directory
+from alembic.testing.env import _multi_dir_testing_config
from alembic.testing.env import _no_sql_testing_config
from alembic.testing.env import _sqlite_file_db
from alembic.testing.env import _sqlite_testing_config
from alembic.testing.env import clear_staging_env
from alembic.testing.env import env_file_fixture
+from alembic.testing.env import multi_heads_fixture
from alembic.testing.env import staging_env
from alembic.testing.env import three_rev_fixture
from alembic.testing.env import write_script
@@ -900,3 +911,325 @@ class SourcelessNeedsFlagTest(TestBase):
self.cfg.set_main_option("sourceless", "true")
script = ScriptDirectory.from_config(self.cfg)
eq_(script.get_heads(), [a])
+
+
+class RecursiveScriptDirectoryTest(TestBase):
+ """test recursive version directory consumption for #760"""
+
+ rev: List[str]
+ org_script_dir: ScriptDirectory
+ cfg: Config
+ _script_by_name: Dict[str, Script]
+ _name_by_revision: Dict[str, str]
+
+ def _setup_revision_files(
+ self, listing, destination=".", version_path="scripts/versions"
+ ):
+ for elem in listing:
+ if isinstance(elem, str):
+ if destination != ".":
+ script = self._script_by_name[elem]
+ target_file = self._get_moved_path(
+ elem, destination, version_path
+ )
+ os.makedirs(os.path.dirname(target_file), exist_ok=True)
+ shutil.move(script.path, target_file)
+ else:
+ dest, files = elem
+ if dest == "delete":
+ for fname in files:
+ revision_to_remove = self._script_by_name[fname]
+ os.remove(revision_to_remove.path)
+ else:
+ self._setup_revision_files(
+ files, os.path.join(destination, dest), version_path
+ )
+
+ def _get_moved_path(
+ self,
+ elem: str,
+ destination_dir: str = "",
+ version_path="scripts/versions",
+ ):
+ script = self._script_by_name[elem]
+ file_name = os.path.basename(script.path)
+ target_file = os.path.join(
+ _get_staging_directory(), version_path, destination_dir, file_name
+ )
+ target_file = os.path.realpath(target_file)
+ return target_file
+
+ def _assert_setup(self, *elements):
+ sd = ScriptDirectory.from_config(self.cfg)
+
+ _new_rev_to_script = {
+ self._name_by_revision[r.revision]: r for r in sd.walk_revisions()
+ }
+
+ for revname, directory, version_path in elements:
+ eq_(
+ _new_rev_to_script[revname].path,
+ self._get_moved_path(revname, directory, version_path),
+ )
+
+ eq_(len(_new_rev_to_script), len(elements))
+
+ revs_to_check = {
+ self._script_by_name[rev].revision for rev, _, _ in elements
+ }
+
+ # topological order check
+ for rev_id in revs_to_check:
+ new_script = sd.get_revision(rev_id)
+ assertions.is_not_(new_script, None)
+
+ old_revisions = {
+ r.revision: r
+ for r in self.org_script_dir.revision_map.iterate_revisions(
+ rev_id,
+ "base",
+ inclusive=True,
+ assert_relative_length=False,
+ )
+ }
+ new_revisions = {
+ r.revision: r
+ for r in sd.revision_map.iterate_revisions(
+ rev_id,
+ "base",
+ inclusive=True,
+ assert_relative_length=False,
+ )
+ }
+
+ eq_(len(old_revisions), len(new_revisions))
+
+ for common_rev_id in set(old_revisions.keys()).union(
+ new_revisions.keys()
+ ):
+ old_rev = old_revisions[common_rev_id]
+ new_rev = new_revisions[common_rev_id]
+
+ eq_(old_rev.revision, new_rev.revision)
+ eq_(old_rev.down_revision, new_rev.down_revision)
+ eq_(old_rev.dependencies, new_rev.dependencies)
+
+ def _setup_for_fixture(self, revs):
+ self.rev = revs
+
+ self.org_script_dir = ScriptDirectory.from_config(self.cfg)
+ rev_to_script = {
+ script.revision: script
+ for script in self.org_script_dir.walk_revisions()
+ }
+ self._script_by_name = {
+ f"r{i}": rev_to_script[revnum] for i, revnum in enumerate(self.rev)
+ }
+ self._name_by_revision = {
+ v.revision: k for k, v in self._script_by_name.items()
+ }
+
+ @testing.fixture
+ def non_recursive_fixture(self):
+ self.env = staging_env()
+ self.cfg = _sqlite_testing_config()
+
+ ids = [util.rev_id() for i in range(5)]
+
+ script = ScriptDirectory.from_config(self.cfg)
+ script.generate_revision(
+ ids[0], "revision a", refresh=True, head="base"
+ )
+ script.generate_revision(
+ ids[1], "revision b", refresh=True, head=ids[0]
+ )
+ script.generate_revision(
+ ids[2], "revision c", refresh=True, head=ids[1]
+ )
+ script.generate_revision(
+ ids[3], "revision d", refresh=True, head="base"
+ )
+ script.generate_revision(
+ ids[4], "revision e", refresh=True, head=ids[3]
+ )
+
+ self._setup_for_fixture(ids)
+
+ yield
+
+ clear_staging_env()
+
+ @testing.fixture
+ def single_base_fixture(self):
+ self.env = staging_env()
+ self.cfg = _sqlite_testing_config()
+ self.cfg.set_main_option("recursive_version_locations", "true")
+
+ revs = list(three_rev_fixture(self.cfg))
+ revs.extend(multi_heads_fixture(self.cfg, *revs[0:3]))
+
+ self._setup_for_fixture(revs)
+
+ yield
+
+ clear_staging_env()
+
+ @testing.fixture
+ def multi_base_fixture(self):
+
+ self.env = staging_env()
+ self.cfg = _multi_dir_testing_config()
+ self.cfg.set_main_option("recursive_version_locations", "true")
+
+ script0 = command.revision(
+ self.cfg,
+ message="x",
+ head="base",
+ version_path=os.path.join(_get_staging_directory(), "model1"),
+ )
+ assert isinstance(script0, Script)
+ script1 = command.revision(
+ self.cfg,
+ message="y",
+ head="base",
+ version_path=os.path.join(_get_staging_directory(), "model2"),
+ )
+ assert isinstance(script1, Script)
+ script2 = command.revision(
+ self.cfg, message="y2", head=script1.revision
+ )
+ assert isinstance(script2, Script)
+
+ self.org_script_dir = ScriptDirectory.from_config(self.cfg)
+
+ rev_to_script = {
+ script0.revision: script0,
+ script1.revision: script1,
+ script2.revision: script2,
+ }
+
+ self._setup_for_fixture(rev_to_script)
+
+ yield
+
+ clear_staging_env()
+
+ def test_ignore_for_non_recursive(self, non_recursive_fixture):
+ """test traversal is non-recursive when the feature is not enabled
+ (subdirectories are ignored).
+
+ """
+
+ self._setup_revision_files(
+ [
+ "r0",
+ "r1",
+ ("dir_1", ["r2", "r3"]),
+ ("dir_2", ["r4"]),
+ ]
+ )
+
+ vl = "scripts/versions"
+
+ self._assert_setup(
+ ("r0", "", vl),
+ ("r1", "", vl),
+ )
+
+ def test_flat_structure(self, single_base_fixture):
+ assert len(self.rev) == 6
+
+ def test_flat_and_dir_structure(self, single_base_fixture):
+ self._setup_revision_files(
+ [
+ "r1",
+ ("dir_1", ["r0", "r2"]),
+ ("dir_2", ["r4"]),
+ ("dir_3", ["r5"]),
+ ]
+ )
+
+ vl = "scripts/versions"
+
+ self._assert_setup(
+ ("r0", "dir_1", vl),
+ ("r1", "", vl),
+ ("r2", "dir_1", vl),
+ ("r3", "", vl),
+ ("r4", "dir_2", vl),
+ ("r5", "dir_3", vl),
+ )
+
+ def test_nested_dir_structure(self, single_base_fixture):
+ self._setup_revision_files(
+ [
+ (
+ "dir_1",
+ ["r0", ("nested_1", ["r1", "r2"]), ("nested_2", ["r3"])],
+ ),
+ ("dir_2", ["r4"]),
+ ("dir_3", [("nested_3", ["r5"])]),
+ ]
+ )
+
+ vl = "scripts/versions"
+
+ self._assert_setup(
+ ("r0", "dir_1", vl),
+ ("r1", "dir_1/nested_1", vl),
+ ("r2", "dir_1/nested_1", vl),
+ ("r3", "dir_1/nested_2", vl),
+ ("r4", "dir_2", vl),
+ ("r5", "dir_3/nested_3", vl),
+ )
+
+ def test_dir_structure_with_missing_file(self, single_base_fixture):
+ sd = ScriptDirectory.from_config(self.cfg)
+
+ revision_to_remove = self._script_by_name["r1"]
+ self._setup_revision_files(
+ [
+ ("delete", ["r1"]),
+ ("dir_1", ["r0", "r2"]),
+ ("dir_2", ["r4"]),
+ ("dir_3", ["r5"]),
+ ]
+ )
+
+ with expect_raises_message(KeyError, revision_to_remove.revision):
+ list(sd.walk_revisions())
+
+ def test_multiple_dir_recursive(self, multi_base_fixture):
+ self._setup_revision_files(
+ [
+ ("dir_0", ["r0"]),
+ ],
+ version_path="model1",
+ )
+ self._setup_revision_files(
+ [
+ ("dir_1", ["r1", ("nested", ["r2"])]),
+ ],
+ version_path="model2",
+ )
+ self._assert_setup(
+ ("r0", "dir_0", "model1"),
+ ("r1", "dir_1", "model2"),
+ ("r2", "dir_1/nested", "model2"),
+ )
+
+ def test_multiple_dir_recursive_change_version_dir(
+ self, multi_base_fixture
+ ):
+ self._setup_revision_files(
+ [
+ ("dir_0", ["r0"]),
+ ("dir_1", ["r1", ("nested", ["r2"])]),
+ ],
+ version_path="model1",
+ )
+ self._assert_setup(
+ ("r0", "dir_0", "model1"),
+ ("r1", "dir_1", "model1"),
+ ("r2", "dir_1/nested", "model1"),
+ )
diff --git a/tests/test_script_production.py b/tests/test_script_production.py
index bedf545..bddea5f 100644
--- a/tests/test_script_production.py
+++ b/tests/test_script_production.py
@@ -707,7 +707,7 @@ class ImportsTest(TestBase):
context.configure(
connection=connection,
target_metadata=target_metadata,
- **kw
+ **kw,
)
with context.begin_transaction():
context.run_migrations()