diff options
-rw-r--r-- | alembic/script/base.py | 73 | ||||
-rw-r--r-- | alembic/templates/async/alembic.ini.mako | 5 | ||||
-rw-r--r-- | alembic/templates/generic/alembic.ini.mako | 5 | ||||
-rw-r--r-- | alembic/templates/multidb/alembic.ini.mako | 5 | ||||
-rw-r--r-- | docs/build/tutorial.rst | 10 | ||||
-rw-r--r-- | docs/build/unreleased/760.rst | 8 | ||||
-rw-r--r-- | tests/test_script_consumption.py | 333 | ||||
-rw-r--r-- | tests/test_script_production.py | 2 |
8 files changed, 414 insertions, 27 deletions
diff --git a/alembic/script/base.py b/alembic/script/base.py index 3c09cef..b6858b5 100644 --- a/alembic/script/base.py +++ b/alembic/script/base.py @@ -80,6 +80,7 @@ class ScriptDirectory: output_encoding: str = "utf-8", timezone: Optional[str] = None, hook_config: Optional[Dict[str, str]] = None, + recursive_version_locations: bool = False, ) -> None: self.dir = dir self.file_template = file_template @@ -90,6 +91,7 @@ class ScriptDirectory: self.revision_map = revision.RevisionMap(self._load_revisions) self.timezone = timezone self.hook_config = hook_config + self.recursive_version_locations = recursive_version_locations if not os.access(dir, os.F_OK): raise util.CommandError( @@ -128,16 +130,19 @@ class ScriptDirectory: dupes = set() for vers in paths: - for file_ in Script._list_py_dir(self, vers): - path = os.path.realpath(os.path.join(vers, file_)) - if path in dupes: + for file_path in Script._list_py_dir(self, vers): + real_path = os.path.realpath(file_path) + if real_path in dupes: util.warn( "File %s loaded twice! ignoring. Please ensure " - "version_locations is unique." % path + "version_locations is unique." % real_path ) continue - dupes.add(path) - script = Script._from_filename(self, vers, file_) + dupes.add(real_path) + + filename = os.path.basename(real_path) + dir_name = os.path.dirname(real_path) + script = Script._from_filename(self, dir_name, filename) if script is None: continue yield script @@ -207,6 +212,7 @@ class ScriptDirectory: _split_on_space_comma_colon.split(prepend_sys_path) ) + rvl = config.get_main_option("recursive_version_locations") == "true" return ScriptDirectory( util.coerce_resource_to_filename(script_location), file_template=config.get_main_option( @@ -218,6 +224,7 @@ class ScriptDirectory: version_locations=version_locations, timezone=config.get_main_option("timezone"), hook_config=config.get_section("post_write_hooks", {}), + recursive_version_locations=rvl, ) @contextmanager @@ -959,26 +966,40 @@ class Script(revision.Revision): @classmethod def _list_py_dir(cls, scriptdir: ScriptDirectory, path: str) -> List[str]: - if scriptdir.sourceless: - # read files in version path, e.g. pyc or pyo files - # in the immediate path - paths = os.listdir(path) - - names = {fname.split(".")[0] for fname in paths} - - # look for __pycache__ - if os.path.exists(os.path.join(path, "__pycache__")): - # add all files from __pycache__ whose filename is not - # already in the names we got from the version directory. - # add as relative paths including __pycache__ token - paths.extend( - os.path.join("__pycache__", pyc) - for pyc in os.listdir(os.path.join(path, "__pycache__")) - if pyc.split(".")[0] not in names - ) - return paths - else: - return os.listdir(path) + paths = [] + for root, dirs, files in os.walk(path, topdown=True): + if root.endswith("__pycache__"): + # a special case - we may include these files + # if a `sourceless` option is specified + continue + + for filename in sorted(files): + paths.append(os.path.join(root, filename)) + + if scriptdir.sourceless: + # look for __pycache__ + py_cache_path = os.path.join(root, "__pycache__") + if os.path.exists(py_cache_path): + # add all files from __pycache__ whose filename is not + # already in the names we got from the version directory. + # add as relative paths including __pycache__ token + names = {filename.split(".")[0] for filename in files} + paths.extend( + os.path.join(py_cache_path, pyc) + for pyc in os.listdir(py_cache_path) + if pyc.split(".")[0] not in names + ) + + if not scriptdir.recursive_version_locations: + break + + # the real script order is defined by revision, + # but it may be undefined if there are many files with a same + # `down_revision`, for a better user experience (ex. debugging), + # we use a deterministic order + dirs.sort() + + return paths @classmethod def _from_filename( diff --git a/alembic/templates/async/alembic.ini.mako b/alembic/templates/async/alembic.ini.mako index 5268e7c..64c7b6b 100644 --- a/alembic/templates/async/alembic.ini.mako +++ b/alembic/templates/async/alembic.ini.mako @@ -49,6 +49,11 @@ prepend_sys_path = . # version_path_separator = space version_path_separator = os # Use os.pathsep. Default configuration used for new projects. +# set to 'true' to search source files recursively +# in each "version_locations" directory +# new in Alembic version 1.10 +# recursive_version_locations = false + # the output encoding used when revision files # are written from script.py.mako # output_encoding = utf-8 diff --git a/alembic/templates/generic/alembic.ini.mako b/alembic/templates/generic/alembic.ini.mako index 8aa47b1..f541b17 100644 --- a/alembic/templates/generic/alembic.ini.mako +++ b/alembic/templates/generic/alembic.ini.mako @@ -51,6 +51,11 @@ prepend_sys_path = . # version_path_separator = space version_path_separator = os # Use os.pathsep. Default configuration used for new projects. +# set to 'true' to search source files recursively +# in each "version_locations" directory +# new in Alembic version 1.10 +# recursive_version_locations = false + # the output encoding used when revision files # are written from script.py.mako # output_encoding = utf-8 diff --git a/alembic/templates/multidb/alembic.ini.mako b/alembic/templates/multidb/alembic.ini.mako index 5adef39..4230fe1 100644 --- a/alembic/templates/multidb/alembic.ini.mako +++ b/alembic/templates/multidb/alembic.ini.mako @@ -51,6 +51,11 @@ prepend_sys_path = . # version_path_separator = space version_path_separator = os # Use os.pathsep. Default configuration used for new projects. +# set to 'true' to search source files recursively +# in each "version_locations" directory +# new in Alembic version 1.10 +# recursive_version_locations = false + # the output encoding used when revision files # are written from script.py.mako # output_encoding = utf-8 diff --git a/docs/build/tutorial.rst b/docs/build/tutorial.rst index 8540be1..8823a91 100644 --- a/docs/build/tutorial.rst +++ b/docs/build/tutorial.rst @@ -177,6 +177,11 @@ The file generated with the "generic" configuration looks like:: # version_path_separator = space version_path_separator = os # Use os.pathsep. Default configuration used for new projects. + # set to 'true' to search source files recursively + # in each "version_locations" directory + # new in Alembic version 1.10 + # recursive_version_locations = false + # the output encoding used when revision files # are written from script.py.mako # output_encoding = utf-8 @@ -332,6 +337,11 @@ This file contains the following features: It should be defined if multiple ``version_locations`` is used. See :ref:`multiple_bases` for examples. +* ``recursive_version_locations`` - when set to 'true', revision files + are searched recursively in each "version_locations" directory. + + .. versionadded:: 1.10 + * ``output_encoding`` - the encoding to use when Alembic writes the ``script.py.mako`` file into a new migration file. Defaults to ``'utf-8'``. diff --git a/docs/build/unreleased/760.rst b/docs/build/unreleased/760.rst new file mode 100644 index 0000000..5f46e10 --- /dev/null +++ b/docs/build/unreleased/760.rst @@ -0,0 +1,8 @@ +.. change:: + :tags: feature, revisioning + :tickets: 760 + + Recursive traversal of revision files in a particular revision directory is + now supported, by indicating ``recursive_version_locations = true`` in + alembic.ini. Pull request courtesy ostr00000. + diff --git a/tests/test_script_consumption.py b/tests/test_script_consumption.py index fa84d7e..a107b80 100644 --- a/tests/test_script_consumption.py +++ b/tests/test_script_consumption.py @@ -1,7 +1,12 @@ +from __future__ import annotations + from contextlib import contextmanager import os import re +import shutil import textwrap +from typing import Dict +from typing import List import sqlalchemy as sa from sqlalchemy import pool @@ -9,18 +14,24 @@ from sqlalchemy import pool from alembic import command from alembic import testing from alembic import util +from alembic.config import Config from alembic.environment import EnvironmentContext from alembic.script import Script from alembic.script import ScriptDirectory from alembic.testing import assert_raises_message +from alembic.testing import assertions from alembic.testing import config from alembic.testing import eq_ +from alembic.testing import expect_raises_message from alembic.testing import mock +from alembic.testing.env import _get_staging_directory +from alembic.testing.env import _multi_dir_testing_config from alembic.testing.env import _no_sql_testing_config from alembic.testing.env import _sqlite_file_db from alembic.testing.env import _sqlite_testing_config from alembic.testing.env import clear_staging_env from alembic.testing.env import env_file_fixture +from alembic.testing.env import multi_heads_fixture from alembic.testing.env import staging_env from alembic.testing.env import three_rev_fixture from alembic.testing.env import write_script @@ -900,3 +911,325 @@ class SourcelessNeedsFlagTest(TestBase): self.cfg.set_main_option("sourceless", "true") script = ScriptDirectory.from_config(self.cfg) eq_(script.get_heads(), [a]) + + +class RecursiveScriptDirectoryTest(TestBase): + """test recursive version directory consumption for #760""" + + rev: List[str] + org_script_dir: ScriptDirectory + cfg: Config + _script_by_name: Dict[str, Script] + _name_by_revision: Dict[str, str] + + def _setup_revision_files( + self, listing, destination=".", version_path="scripts/versions" + ): + for elem in listing: + if isinstance(elem, str): + if destination != ".": + script = self._script_by_name[elem] + target_file = self._get_moved_path( + elem, destination, version_path + ) + os.makedirs(os.path.dirname(target_file), exist_ok=True) + shutil.move(script.path, target_file) + else: + dest, files = elem + if dest == "delete": + for fname in files: + revision_to_remove = self._script_by_name[fname] + os.remove(revision_to_remove.path) + else: + self._setup_revision_files( + files, os.path.join(destination, dest), version_path + ) + + def _get_moved_path( + self, + elem: str, + destination_dir: str = "", + version_path="scripts/versions", + ): + script = self._script_by_name[elem] + file_name = os.path.basename(script.path) + target_file = os.path.join( + _get_staging_directory(), version_path, destination_dir, file_name + ) + target_file = os.path.realpath(target_file) + return target_file + + def _assert_setup(self, *elements): + sd = ScriptDirectory.from_config(self.cfg) + + _new_rev_to_script = { + self._name_by_revision[r.revision]: r for r in sd.walk_revisions() + } + + for revname, directory, version_path in elements: + eq_( + _new_rev_to_script[revname].path, + self._get_moved_path(revname, directory, version_path), + ) + + eq_(len(_new_rev_to_script), len(elements)) + + revs_to_check = { + self._script_by_name[rev].revision for rev, _, _ in elements + } + + # topological order check + for rev_id in revs_to_check: + new_script = sd.get_revision(rev_id) + assertions.is_not_(new_script, None) + + old_revisions = { + r.revision: r + for r in self.org_script_dir.revision_map.iterate_revisions( + rev_id, + "base", + inclusive=True, + assert_relative_length=False, + ) + } + new_revisions = { + r.revision: r + for r in sd.revision_map.iterate_revisions( + rev_id, + "base", + inclusive=True, + assert_relative_length=False, + ) + } + + eq_(len(old_revisions), len(new_revisions)) + + for common_rev_id in set(old_revisions.keys()).union( + new_revisions.keys() + ): + old_rev = old_revisions[common_rev_id] + new_rev = new_revisions[common_rev_id] + + eq_(old_rev.revision, new_rev.revision) + eq_(old_rev.down_revision, new_rev.down_revision) + eq_(old_rev.dependencies, new_rev.dependencies) + + def _setup_for_fixture(self, revs): + self.rev = revs + + self.org_script_dir = ScriptDirectory.from_config(self.cfg) + rev_to_script = { + script.revision: script + for script in self.org_script_dir.walk_revisions() + } + self._script_by_name = { + f"r{i}": rev_to_script[revnum] for i, revnum in enumerate(self.rev) + } + self._name_by_revision = { + v.revision: k for k, v in self._script_by_name.items() + } + + @testing.fixture + def non_recursive_fixture(self): + self.env = staging_env() + self.cfg = _sqlite_testing_config() + + ids = [util.rev_id() for i in range(5)] + + script = ScriptDirectory.from_config(self.cfg) + script.generate_revision( + ids[0], "revision a", refresh=True, head="base" + ) + script.generate_revision( + ids[1], "revision b", refresh=True, head=ids[0] + ) + script.generate_revision( + ids[2], "revision c", refresh=True, head=ids[1] + ) + script.generate_revision( + ids[3], "revision d", refresh=True, head="base" + ) + script.generate_revision( + ids[4], "revision e", refresh=True, head=ids[3] + ) + + self._setup_for_fixture(ids) + + yield + + clear_staging_env() + + @testing.fixture + def single_base_fixture(self): + self.env = staging_env() + self.cfg = _sqlite_testing_config() + self.cfg.set_main_option("recursive_version_locations", "true") + + revs = list(three_rev_fixture(self.cfg)) + revs.extend(multi_heads_fixture(self.cfg, *revs[0:3])) + + self._setup_for_fixture(revs) + + yield + + clear_staging_env() + + @testing.fixture + def multi_base_fixture(self): + + self.env = staging_env() + self.cfg = _multi_dir_testing_config() + self.cfg.set_main_option("recursive_version_locations", "true") + + script0 = command.revision( + self.cfg, + message="x", + head="base", + version_path=os.path.join(_get_staging_directory(), "model1"), + ) + assert isinstance(script0, Script) + script1 = command.revision( + self.cfg, + message="y", + head="base", + version_path=os.path.join(_get_staging_directory(), "model2"), + ) + assert isinstance(script1, Script) + script2 = command.revision( + self.cfg, message="y2", head=script1.revision + ) + assert isinstance(script2, Script) + + self.org_script_dir = ScriptDirectory.from_config(self.cfg) + + rev_to_script = { + script0.revision: script0, + script1.revision: script1, + script2.revision: script2, + } + + self._setup_for_fixture(rev_to_script) + + yield + + clear_staging_env() + + def test_ignore_for_non_recursive(self, non_recursive_fixture): + """test traversal is non-recursive when the feature is not enabled + (subdirectories are ignored). + + """ + + self._setup_revision_files( + [ + "r0", + "r1", + ("dir_1", ["r2", "r3"]), + ("dir_2", ["r4"]), + ] + ) + + vl = "scripts/versions" + + self._assert_setup( + ("r0", "", vl), + ("r1", "", vl), + ) + + def test_flat_structure(self, single_base_fixture): + assert len(self.rev) == 6 + + def test_flat_and_dir_structure(self, single_base_fixture): + self._setup_revision_files( + [ + "r1", + ("dir_1", ["r0", "r2"]), + ("dir_2", ["r4"]), + ("dir_3", ["r5"]), + ] + ) + + vl = "scripts/versions" + + self._assert_setup( + ("r0", "dir_1", vl), + ("r1", "", vl), + ("r2", "dir_1", vl), + ("r3", "", vl), + ("r4", "dir_2", vl), + ("r5", "dir_3", vl), + ) + + def test_nested_dir_structure(self, single_base_fixture): + self._setup_revision_files( + [ + ( + "dir_1", + ["r0", ("nested_1", ["r1", "r2"]), ("nested_2", ["r3"])], + ), + ("dir_2", ["r4"]), + ("dir_3", [("nested_3", ["r5"])]), + ] + ) + + vl = "scripts/versions" + + self._assert_setup( + ("r0", "dir_1", vl), + ("r1", "dir_1/nested_1", vl), + ("r2", "dir_1/nested_1", vl), + ("r3", "dir_1/nested_2", vl), + ("r4", "dir_2", vl), + ("r5", "dir_3/nested_3", vl), + ) + + def test_dir_structure_with_missing_file(self, single_base_fixture): + sd = ScriptDirectory.from_config(self.cfg) + + revision_to_remove = self._script_by_name["r1"] + self._setup_revision_files( + [ + ("delete", ["r1"]), + ("dir_1", ["r0", "r2"]), + ("dir_2", ["r4"]), + ("dir_3", ["r5"]), + ] + ) + + with expect_raises_message(KeyError, revision_to_remove.revision): + list(sd.walk_revisions()) + + def test_multiple_dir_recursive(self, multi_base_fixture): + self._setup_revision_files( + [ + ("dir_0", ["r0"]), + ], + version_path="model1", + ) + self._setup_revision_files( + [ + ("dir_1", ["r1", ("nested", ["r2"])]), + ], + version_path="model2", + ) + self._assert_setup( + ("r0", "dir_0", "model1"), + ("r1", "dir_1", "model2"), + ("r2", "dir_1/nested", "model2"), + ) + + def test_multiple_dir_recursive_change_version_dir( + self, multi_base_fixture + ): + self._setup_revision_files( + [ + ("dir_0", ["r0"]), + ("dir_1", ["r1", ("nested", ["r2"])]), + ], + version_path="model1", + ) + self._assert_setup( + ("r0", "dir_0", "model1"), + ("r1", "dir_1", "model1"), + ("r2", "dir_1/nested", "model1"), + ) diff --git a/tests/test_script_production.py b/tests/test_script_production.py index 5c4cd9e..151b3b8 100644 --- a/tests/test_script_production.py +++ b/tests/test_script_production.py @@ -710,7 +710,7 @@ class ImportsTest(TestBase): context.configure( connection=connection, target_metadata=target_metadata, - **kw + **kw, ) with context.begin_transaction(): context.run_migrations() |