summaryrefslogtreecommitdiff
path: root/alembic/script/base.py
diff options
context:
space:
mode:
Diffstat (limited to 'alembic/script/base.py')
-rw-r--r--alembic/script/base.py73
1 files changed, 47 insertions, 26 deletions
diff --git a/alembic/script/base.py b/alembic/script/base.py
index 3c09cef..b6858b5 100644
--- a/alembic/script/base.py
+++ b/alembic/script/base.py
@@ -80,6 +80,7 @@ class ScriptDirectory:
output_encoding: str = "utf-8",
timezone: Optional[str] = None,
hook_config: Optional[Dict[str, str]] = None,
+ recursive_version_locations: bool = False,
) -> None:
self.dir = dir
self.file_template = file_template
@@ -90,6 +91,7 @@ class ScriptDirectory:
self.revision_map = revision.RevisionMap(self._load_revisions)
self.timezone = timezone
self.hook_config = hook_config
+ self.recursive_version_locations = recursive_version_locations
if not os.access(dir, os.F_OK):
raise util.CommandError(
@@ -128,16 +130,19 @@ class ScriptDirectory:
dupes = set()
for vers in paths:
- for file_ in Script._list_py_dir(self, vers):
- path = os.path.realpath(os.path.join(vers, file_))
- if path in dupes:
+ for file_path in Script._list_py_dir(self, vers):
+ real_path = os.path.realpath(file_path)
+ if real_path in dupes:
util.warn(
"File %s loaded twice! ignoring. Please ensure "
- "version_locations is unique." % path
+ "version_locations is unique." % real_path
)
continue
- dupes.add(path)
- script = Script._from_filename(self, vers, file_)
+ dupes.add(real_path)
+
+ filename = os.path.basename(real_path)
+ dir_name = os.path.dirname(real_path)
+ script = Script._from_filename(self, dir_name, filename)
if script is None:
continue
yield script
@@ -207,6 +212,7 @@ class ScriptDirectory:
_split_on_space_comma_colon.split(prepend_sys_path)
)
+ rvl = config.get_main_option("recursive_version_locations") == "true"
return ScriptDirectory(
util.coerce_resource_to_filename(script_location),
file_template=config.get_main_option(
@@ -218,6 +224,7 @@ class ScriptDirectory:
version_locations=version_locations,
timezone=config.get_main_option("timezone"),
hook_config=config.get_section("post_write_hooks", {}),
+ recursive_version_locations=rvl,
)
@contextmanager
@@ -959,26 +966,40 @@ class Script(revision.Revision):
@classmethod
def _list_py_dir(cls, scriptdir: ScriptDirectory, path: str) -> List[str]:
- if scriptdir.sourceless:
- # read files in version path, e.g. pyc or pyo files
- # in the immediate path
- paths = os.listdir(path)
-
- names = {fname.split(".")[0] for fname in paths}
-
- # look for __pycache__
- if os.path.exists(os.path.join(path, "__pycache__")):
- # add all files from __pycache__ whose filename is not
- # already in the names we got from the version directory.
- # add as relative paths including __pycache__ token
- paths.extend(
- os.path.join("__pycache__", pyc)
- for pyc in os.listdir(os.path.join(path, "__pycache__"))
- if pyc.split(".")[0] not in names
- )
- return paths
- else:
- return os.listdir(path)
+ paths = []
+ for root, dirs, files in os.walk(path, topdown=True):
+ if root.endswith("__pycache__"):
+ # a special case - we may include these files
+ # if a `sourceless` option is specified
+ continue
+
+ for filename in sorted(files):
+ paths.append(os.path.join(root, filename))
+
+ if scriptdir.sourceless:
+ # look for __pycache__
+ py_cache_path = os.path.join(root, "__pycache__")
+ if os.path.exists(py_cache_path):
+ # add all files from __pycache__ whose filename is not
+ # already in the names we got from the version directory.
+ # add as relative paths including __pycache__ token
+ names = {filename.split(".")[0] for filename in files}
+ paths.extend(
+ os.path.join(py_cache_path, pyc)
+ for pyc in os.listdir(py_cache_path)
+ if pyc.split(".")[0] not in names
+ )
+
+ if not scriptdir.recursive_version_locations:
+ break
+
+ # the real script order is defined by revision,
+ # but it may be undefined if there are many files with a same
+ # `down_revision`, for a better user experience (ex. debugging),
+ # we use a deterministic order
+ dirs.sort()
+
+ return paths
@classmethod
def _from_filename(