summaryrefslogtreecommitdiff
path: root/alembic/script
diff options
context:
space:
mode:
Diffstat (limited to 'alembic/script')
-rw-r--r--alembic/script/base.py129
-rw-r--r--alembic/script/revision.py176
-rw-r--r--alembic/script/write_hooks.py4
3 files changed, 188 insertions, 121 deletions
diff --git a/alembic/script/base.py b/alembic/script/base.py
index ef0fd52..ccbf86c 100644
--- a/alembic/script/base.py
+++ b/alembic/script/base.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
from contextlib import contextmanager
import datetime
import os
@@ -21,11 +23,13 @@ from . import revision
from . import write_hooks
from .. import util
from ..runtime import migration
+from ..util import not_none
if TYPE_CHECKING:
from ..config import Config
from ..runtime.migration import RevisionStep
from ..runtime.migration import StampStep
+ from ..script.revision import Revision
try:
from dateutil import tz
@@ -112,7 +116,7 @@ class ScriptDirectory:
else:
return (os.path.abspath(os.path.join(self.dir, "versions")),)
- def _load_revisions(self) -> Iterator["Script"]:
+ def _load_revisions(self) -> Iterator[Script]:
if self.version_locations:
paths = [
vers
@@ -139,7 +143,7 @@ class ScriptDirectory:
yield script
@classmethod
- def from_config(cls, config: "Config") -> "ScriptDirectory":
+ def from_config(cls, config: Config) -> ScriptDirectory:
"""Produce a new :class:`.ScriptDirectory` given a :class:`.Config`
instance.
@@ -152,14 +156,16 @@ class ScriptDirectory:
raise util.CommandError(
"No 'script_location' key " "found in configuration."
)
- truncate_slug_length = cast(
- Optional[int], config.get_main_option("truncate_slug_length")
- )
- if truncate_slug_length is not None:
- truncate_slug_length = int(truncate_slug_length)
+ truncate_slug_length: Optional[int]
+ tsl = config.get_main_option("truncate_slug_length")
+ if tsl is not None:
+ truncate_slug_length = int(tsl)
+ else:
+ truncate_slug_length = None
- version_locations = config.get_main_option("version_locations")
- if version_locations:
+ version_locations_str = config.get_main_option("version_locations")
+ version_locations: Optional[List[str]]
+ if version_locations_str:
version_path_separator = config.get_main_option(
"version_path_separator"
)
@@ -173,7 +179,9 @@ class ScriptDirectory:
}
try:
- split_char = split_on_path[version_path_separator]
+ split_char: Optional[str] = split_on_path[
+ version_path_separator
+ ]
except KeyError as ke:
raise ValueError(
"'%s' is not a valid value for "
@@ -183,17 +191,15 @@ class ScriptDirectory:
else:
if split_char is None:
# legacy behaviour for backwards compatibility
- vl = _split_on_space_comma.split(
- cast(str, version_locations)
+ version_locations = _split_on_space_comma.split(
+ version_locations_str
)
- version_locations: List[str] = vl # type: ignore[no-redef]
else:
- vl = [
- x
- for x in cast(str, version_locations).split(split_char)
- if x
+ version_locations = [
+ x for x in version_locations_str.split(split_char) if x
]
- version_locations: List[str] = vl # type: ignore[no-redef]
+ else:
+ version_locations = None
prepend_sys_path = config.get_main_option("prepend_sys_path")
if prepend_sys_path:
@@ -209,7 +215,7 @@ class ScriptDirectory:
truncate_slug_length=truncate_slug_length,
sourceless=config.get_main_option("sourceless") == "true",
output_encoding=config.get_main_option("output_encoding", "utf-8"),
- version_locations=cast("Optional[List[str]]", version_locations),
+ version_locations=version_locations,
timezone=config.get_main_option("timezone"),
hook_config=config.get_section("post_write_hooks", {}),
)
@@ -262,7 +268,7 @@ class ScriptDirectory:
def walk_revisions(
self, base: str = "base", head: str = "heads"
- ) -> Iterator["Script"]:
+ ) -> Iterator[Script]:
"""Iterate through all revisions.
:param base: the base revision, or "base" to start from the
@@ -279,25 +285,26 @@ class ScriptDirectory:
):
yield cast(Script, rev)
- def get_revisions(self, id_: _RevIdType) -> Tuple["Script", ...]:
+ def get_revisions(self, id_: _RevIdType) -> Tuple[Optional[Script], ...]:
"""Return the :class:`.Script` instance with the given rev identifier,
symbolic name, or sequence of identifiers.
"""
with self._catch_revision_errors():
return cast(
- "Tuple[Script, ...]", self.revision_map.get_revisions(id_)
+ Tuple[Optional[Script], ...],
+ self.revision_map.get_revisions(id_),
)
- def get_all_current(self, id_: Tuple[str, ...]) -> Set["Script"]:
+ def get_all_current(self, id_: Tuple[str, ...]) -> Set[Optional[Script]]:
with self._catch_revision_errors():
top_revs = cast(
- "Set[Script]",
+ Set[Optional[Script]],
set(self.revision_map.get_revisions(id_)),
)
top_revs.update(
cast(
- "Iterator[Script]",
+ Iterator[Script],
self.revision_map._get_ancestor_nodes(
list(top_revs), include_dependencies=True
),
@@ -306,7 +313,7 @@ class ScriptDirectory:
top_revs = self.revision_map._filter_into_branch_heads(top_revs)
return top_revs
- def get_revision(self, id_: str) -> "Script":
+ def get_revision(self, id_: str) -> Optional[Script]:
"""Return the :class:`.Script` instance with the given rev id.
.. seealso::
@@ -316,7 +323,7 @@ class ScriptDirectory:
"""
with self._catch_revision_errors():
- return cast(Script, self.revision_map.get_revision(id_))
+ return cast(Optional[Script], self.revision_map.get_revision(id_))
def as_revision_number(
self, id_: Optional[str]
@@ -335,7 +342,12 @@ class ScriptDirectory:
else:
return rev[0]
- def iterate_revisions(self, upper, lower):
+ def iterate_revisions(
+ self,
+ upper: Union[str, Tuple[str, ...], None],
+ lower: Union[str, Tuple[str, ...], None],
+ **kw: Any,
+ ) -> Iterator[Script]:
"""Iterate through script revisions, starting at the given
upper revision identifier and ending at the lower.
@@ -351,9 +363,12 @@ class ScriptDirectory:
:meth:`.RevisionMap.iterate_revisions`
"""
- return self.revision_map.iterate_revisions(upper, lower)
+ return cast(
+ Iterator[Script],
+ self.revision_map.iterate_revisions(upper, lower, **kw),
+ )
- def get_current_head(self):
+ def get_current_head(self) -> Optional[str]:
"""Return the current head revision.
If the script directory has multiple heads
@@ -423,36 +438,36 @@ class ScriptDirectory:
def _upgrade_revs(
self, destination: str, current_rev: str
- ) -> List["RevisionStep"]:
+ ) -> List[RevisionStep]:
with self._catch_revision_errors(
ancestor="Destination %(end)s is not a valid upgrade "
"target from current head(s)",
end=destination,
):
- revs = self.revision_map.iterate_revisions(
+ revs = self.iterate_revisions(
destination, current_rev, implicit_base=True
)
return [
migration.MigrationStep.upgrade_from_script(
- self.revision_map, cast(Script, script)
+ self.revision_map, script
)
for script in reversed(list(revs))
]
def _downgrade_revs(
self, destination: str, current_rev: Optional[str]
- ) -> List["RevisionStep"]:
+ ) -> List[RevisionStep]:
with self._catch_revision_errors(
ancestor="Destination %(end)s is not a valid downgrade "
"target from current head(s)",
end=destination,
):
- revs = self.revision_map.iterate_revisions(
+ revs = self.iterate_revisions(
current_rev, destination, select_for_downgrade=True
)
return [
migration.MigrationStep.downgrade_from_script(
- self.revision_map, cast(Script, script)
+ self.revision_map, script
)
for script in revs
]
@@ -472,12 +487,14 @@ class ScriptDirectory:
if not revision:
revision = "base"
- filtered_heads: List["Script"] = []
+ filtered_heads: List[Script] = []
for rev in util.to_tuple(revision):
if rev:
filtered_heads.extend(
self.revision_map.filter_for_lineage(
- heads_revs, rev, include_dependencies=True
+ cast(Sequence[Script], heads_revs),
+ rev,
+ include_dependencies=True,
)
)
filtered_heads = util.unique_list(filtered_heads)
@@ -573,7 +590,7 @@ class ScriptDirectory:
src,
dest,
self.output_encoding,
- **kw
+ **kw,
)
def _copy_file(self, src: str, dest: str) -> None:
@@ -621,8 +638,8 @@ class ScriptDirectory:
branch_labels: Optional[str] = None,
version_path: Optional[str] = None,
depends_on: Optional[_RevIdType] = None,
- **kw: Any
- ) -> Optional["Script"]:
+ **kw: Any,
+ ) -> Optional[Script]:
"""Generate a new revision file.
This runs the ``script.py.mako`` template, given
@@ -656,7 +673,12 @@ class ScriptDirectory:
"or perform a merge."
)
):
- heads = self.revision_map.get_revisions(head)
+ heads = cast(
+ Tuple[Optional["Revision"], ...],
+ self.revision_map.get_revisions(head),
+ )
+ for h in heads:
+ assert h != "base"
if len(set(heads)) != len(heads):
raise util.CommandError("Duplicate head revisions specified")
@@ -702,17 +724,20 @@ class ScriptDirectory:
% head_.revision
)
+ resolved_depends_on: Optional[List[str]]
if depends_on:
with self._catch_revision_errors():
- depends_on = [
+ resolved_depends_on = [
dep
if dep in rev.branch_labels # maintain branch labels
else rev.revision # resolve partial revision identifiers
for rev, dep in [
- (self.revision_map.get_revision(dep), dep)
+ (not_none(self.revision_map.get_revision(dep)), dep)
for dep in util.to_list(depends_on)
]
]
+ else:
+ resolved_depends_on = None
self._generate_template(
os.path.join(self.dir, "script.py.mako"),
@@ -722,13 +747,11 @@ class ScriptDirectory:
tuple(h.revision if h is not None else None for h in heads)
),
branch_labels=util.to_tuple(branch_labels),
- depends_on=revision.tuple_rev_as_scalar(
- cast("Optional[List[str]]", depends_on)
- ),
+ depends_on=revision.tuple_rev_as_scalar(resolved_depends_on),
create_date=create_date,
comma=util.format_as_comma,
message=message if message is not None else ("empty message"),
- **kw
+ **kw,
)
post_write_hooks = self.hook_config
@@ -801,13 +824,13 @@ class Script(revision.Revision):
),
)
- module: ModuleType = None # type: ignore[assignment]
+ module: ModuleType
"""The Python module representing the actual script itself."""
- path: str = None # type: ignore[assignment]
+ path: str
"""Filesystem path of the script."""
- _db_current_indicator = None
+ _db_current_indicator: Optional[bool] = None
"""Utility variable which when set will cause string output to indicate
this is a "current" version in some database"""
@@ -939,7 +962,7 @@ class Script(revision.Revision):
@classmethod
def _from_path(
cls, scriptdir: ScriptDirectory, path: str
- ) -> Optional["Script"]:
+ ) -> Optional[Script]:
dir_, filename = os.path.split(path)
return cls._from_filename(scriptdir, dir_, filename)
@@ -969,7 +992,7 @@ class Script(revision.Revision):
@classmethod
def _from_filename(
cls, scriptdir: ScriptDirectory, dir_: str, filename: str
- ) -> Optional["Script"]:
+ ) -> Optional[Script]:
if scriptdir.sourceless:
py_match = _sourceless_rev_file.match(filename)
else:
diff --git a/alembic/script/revision.py b/alembic/script/revision.py
index 2bfb7f9..335314f 100644
--- a/alembic/script/revision.py
+++ b/alembic/script/revision.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import collections
import re
from typing import Any
@@ -11,6 +13,7 @@ from typing import Iterable
from typing import Iterator
from typing import List
from typing import Optional
+from typing import overload
from typing import Sequence
from typing import Set
from typing import Tuple
@@ -21,6 +24,7 @@ from typing import Union
from sqlalchemy import util as sqlautil
from .. import util
+from ..util import not_none
if TYPE_CHECKING:
from typing import Literal
@@ -439,7 +443,7 @@ class RevisionMap:
"Revision %s referenced from %s is not present"
% (downrev, revision)
)
- cast("Revision", map_[downrev]).add_nextrev(revision)
+ not_none(map_[downrev]).add_nextrev(revision)
self._normalize_depends_on(revisions, map_)
@@ -502,8 +506,8 @@ class RevisionMap:
return self.filter_for_lineage(self.bases, identifier)
def get_revisions(
- self, id_: Union[str, Collection[str], None]
- ) -> Tuple["Revision", ...]:
+ self, id_: Union[str, Collection[Optional[str]], None]
+ ) -> Tuple[Optional[_RevisionOrBase], ...]:
"""Return the :class:`.Revision` instances with the given rev id
or identifiers.
@@ -537,7 +541,8 @@ class RevisionMap:
select_heads = tuple(
head
for head in select_heads
- if branch_label in head.branch_labels
+ if branch_label
+ in is_revision(head).branch_labels
)
return tuple(
self._walk(head, steps=rint)
@@ -551,7 +556,7 @@ class RevisionMap:
for rev_id in resolved_id
)
- def get_revision(self, id_: Optional[str]) -> "Revision":
+ def get_revision(self, id_: Optional[str]) -> Optional[Revision]:
"""Return the :class:`.Revision` instance with the given rev id.
If a symbolic name such as "head" or "base" is given, resolves
@@ -568,12 +573,11 @@ class RevisionMap:
resolved_id, branch_label = self._resolve_revision_number(id_)
if len(resolved_id) > 1:
raise MultipleHeads(resolved_id, id_)
- elif resolved_id:
- resolved_id = resolved_id[0] # type:ignore[assignment]
- return self._revision_for_ident(cast(str, resolved_id), branch_label)
+ resolved: Union[str, Tuple[()]] = resolved_id[0] if resolved_id else ()
+ return self._revision_for_ident(resolved, branch_label)
- def _resolve_branch(self, branch_label: str) -> "Revision":
+ def _resolve_branch(self, branch_label: str) -> Optional[Revision]:
try:
branch_rev = self._revision_map[branch_label]
except KeyError:
@@ -587,25 +591,28 @@ class RevisionMap:
else:
return nonbranch_rev
else:
- return cast("Revision", branch_rev)
+ return branch_rev
def _revision_for_ident(
- self, resolved_id: str, check_branch: Optional[str] = None
- ) -> "Revision":
- branch_rev: Optional["Revision"]
+ self,
+ resolved_id: Union[str, Tuple[()]],
+ check_branch: Optional[str] = None,
+ ) -> Optional[Revision]:
+ branch_rev: Optional[Revision]
if check_branch:
branch_rev = self._resolve_branch(check_branch)
else:
branch_rev = None
- revision: Union["Revision", "Literal[False]"]
+ revision: Union[Optional[Revision], "Literal[False]"]
try:
- revision = cast("Revision", self._revision_map[resolved_id])
+ revision = self._revision_map[resolved_id]
except KeyError:
# break out to avoid misleading py3k stack traces
revision = False
revs: Sequence[str]
if revision is False:
+ assert resolved_id
# do a partial lookup
revs = [
x
@@ -637,11 +644,11 @@ class RevisionMap:
resolved_id,
)
else:
- revision = cast("Revision", self._revision_map[revs[0]])
+ revision = self._revision_map[revs[0]]
- revision = cast("Revision", revision)
if check_branch and revision is not None:
assert branch_rev is not None
+ assert resolved_id
if not self._shares_lineage(
revision.revision, branch_rev.revision
):
@@ -653,11 +660,12 @@ class RevisionMap:
return revision
def _filter_into_branch_heads(
- self, targets: Set["Script"]
- ) -> Set["Script"]:
+ self, targets: Set[Optional[Script]]
+ ) -> Set[Optional[Script]]:
targets = set(targets)
for rev in list(targets):
+ assert rev
if targets.intersection(
self._get_descendant_nodes([rev], include_dependencies=False)
).difference([rev]):
@@ -695,9 +703,11 @@ class RevisionMap:
if not test_against_revs:
return True
if not isinstance(target, Revision):
- target = self._revision_for_ident(target)
+ resolved_target = not_none(self._revision_for_ident(target))
+ else:
+ resolved_target = target
- test_against_revs = [
+ resolved_test_against_revs = [
self._revision_for_ident(test_against_rev)
if not isinstance(test_against_rev, Revision)
else test_against_rev
@@ -709,15 +719,17 @@ class RevisionMap:
return bool(
set(
self._get_descendant_nodes(
- [target], include_dependencies=include_dependencies
+ [resolved_target],
+ include_dependencies=include_dependencies,
)
)
.union(
self._get_ancestor_nodes(
- [target], include_dependencies=include_dependencies
+ [resolved_target],
+ include_dependencies=include_dependencies,
)
)
- .intersection(test_against_revs)
+ .intersection(resolved_test_against_revs)
)
def _resolve_revision_number(
@@ -768,7 +780,7 @@ class RevisionMap:
inclusive: bool = False,
assert_relative_length: bool = True,
select_for_downgrade: bool = False,
- ) -> Iterator["Revision"]:
+ ) -> Iterator[Revision]:
"""Iterate through script revisions, starting at the given
upper revision identifier and ending at the lower.
@@ -795,11 +807,11 @@ class RevisionMap:
)
for node in self._topological_sort(revisions, heads):
- yield self.get_revision(node)
+ yield not_none(self.get_revision(node))
def _get_descendant_nodes(
self,
- targets: Collection["Revision"],
+ targets: Collection[Revision],
map_: Optional[_RevisionMapType] = None,
check: bool = False,
omit_immediate_dependencies: bool = False,
@@ -830,11 +842,11 @@ class RevisionMap:
def _get_ancestor_nodes(
self,
- targets: Collection["Revision"],
+ targets: Collection[Optional[_RevisionOrBase]],
map_: Optional[_RevisionMapType] = None,
check: bool = False,
include_dependencies: bool = True,
- ) -> Iterator["Revision"]:
+ ) -> Iterator[Revision]:
if include_dependencies:
@@ -853,17 +865,17 @@ class RevisionMap:
def _iterate_related_revisions(
self,
fn: Callable,
- targets: Collection["Revision"],
+ targets: Collection[Optional[_RevisionOrBase]],
map_: Optional[_RevisionMapType],
check: bool = False,
- ) -> Iterator["Revision"]:
+ ) -> Iterator[Revision]:
if map_ is None:
map_ = self._revision_map
seen = set()
- todo: Deque["Revision"] = collections.deque()
- for target in targets:
-
+ todo: Deque[Revision] = collections.deque()
+ for target_for in targets:
+ target = is_revision(target_for)
todo.append(target)
if check:
per_target = set()
@@ -902,7 +914,7 @@ class RevisionMap:
def _topological_sort(
self,
- revisions: Collection["Revision"],
+ revisions: Collection[Revision],
heads: Any,
) -> List[str]:
"""Yield revision ids of a collection of Revision objects in
@@ -1007,11 +1019,11 @@ class RevisionMap:
def _walk(
self,
- start: Optional[Union[str, "Revision"]],
+ start: Optional[Union[str, Revision]],
steps: int,
branch_label: Optional[str] = None,
no_overwalk: bool = True,
- ) -> "Revision":
+ ) -> Optional[_RevisionOrBase]:
"""
Walk the requested number of :steps up (steps > 0) or down (steps < 0)
the revision tree.
@@ -1030,20 +1042,21 @@ class RevisionMap:
else:
initial = start
- children: Sequence[_RevisionOrBase]
+ children: Sequence[Optional[_RevisionOrBase]]
for _ in range(abs(steps)):
if steps > 0:
+ assert initial != "base"
# Walk up
- children = [
- rev
+ walk_up = [
+ is_revision(rev)
for rev in self.get_revisions(
- self.bases
- if initial is None
- else cast("Revision", initial).nextrev
+ self.bases if initial is None else initial.nextrev
)
]
if branch_label:
- children = self.filter_for_lineage(children, branch_label)
+ children = self.filter_for_lineage(walk_up, branch_label)
+ else:
+ children = walk_up
else:
# Walk down
if initial == "base":
@@ -1055,17 +1068,17 @@ class RevisionMap:
else initial.down_revision
)
if not children:
- children = cast("Tuple[Literal['base']]", ("base",))
+ children = ("base",)
if not children:
# This will return an invalid result if no_overwalk, otherwise
# further steps will stay where we are.
ret = None if no_overwalk else initial
- return ret # type:ignore[return-value]
+ return ret
elif len(children) > 1:
raise RevisionError("Ambiguous walk")
initial = children[0]
- return cast("Revision", initial)
+ return initial
def _parse_downgrade_target(
self,
@@ -1170,7 +1183,7 @@ class RevisionMap:
current_revisions: _RevisionIdentifierType,
target: _RevisionIdentifierType,
assert_relative_length: bool,
- ) -> Tuple["Revision", ...]:
+ ) -> Tuple[Optional[_RevisionOrBase], ...]:
"""
Parse upgrade command syntax :target to retrieve the target revision
and given the :current_revisons stamp of the database.
@@ -1188,26 +1201,27 @@ class RevisionMap:
# No relative destination, target is absolute.
return self.get_revisions(target)
- current_revisions = util.to_tuple(current_revisions)
+ current_revisions_tup: Union[str, Collection[Optional[str]], None]
+ current_revisions_tup = util.to_tuple(current_revisions)
branch_label, symbol, relative_str = match.groups()
relative = int(relative_str)
if relative > 0:
if symbol is None:
- if not current_revisions:
- current_revisions = (None,)
+ if not current_revisions_tup:
+ current_revisions_tup = (None,)
# Try to filter to a single target (avoid ambiguous branches).
- start_revs = current_revisions
+ start_revs = current_revisions_tup
if branch_label:
start_revs = self.filter_for_lineage(
- self.get_revisions(current_revisions), branch_label
+ self.get_revisions(current_revisions_tup), branch_label
)
if not start_revs:
# The requested branch is not a head, so we need to
# backtrack to find a branchpoint.
active_on_branch = self.filter_for_lineage(
self._get_ancestor_nodes(
- self.get_revisions(current_revisions)
+ self.get_revisions(current_revisions_tup)
),
branch_label,
)
@@ -1294,6 +1308,7 @@ class RevisionMap:
target_revision = None
assert target_revision is None or isinstance(target_revision, Revision)
+ roots: List[Revision]
# Find candidates to drop.
if target_revision is None:
# Downgrading back to base: find all tree roots.
@@ -1307,7 +1322,10 @@ class RevisionMap:
roots = [target_revision]
else:
# Downgrading to fixed target: find all direct children.
- roots = list(self.get_revisions(target_revision.nextrev))
+ roots = [
+ is_revision(rev)
+ for rev in self.get_revisions(target_revision.nextrev)
+ ]
if branch_label and len(roots) > 1:
# Need to filter roots.
@@ -1320,11 +1338,12 @@ class RevisionMap:
}
# Intersection gives the root revisions we are trying to
# rollback with the downgrade.
- roots = list(
- self.get_revisions(
+ roots = [
+ is_revision(rev)
+ for rev in self.get_revisions(
{rev.revision for rev in roots}.intersection(ancestors)
)
- )
+ ]
# Ensure we didn't throw everything away when filtering branches.
if len(roots) == 0:
@@ -1374,7 +1393,7 @@ class RevisionMap:
inclusive: bool,
implicit_base: bool,
assert_relative_length: bool,
- ) -> Tuple[Set["Revision"], Tuple[Optional[_RevisionOrBase]]]:
+ ) -> Tuple[Set[Revision], Tuple[Optional[_RevisionOrBase]]]:
"""
Compute the set of required revisions specified by :upper, and the
current set of active revisions specified by :lower. Find the
@@ -1386,11 +1405,14 @@ class RevisionMap:
of the current/lower revisions. Dependencies from branches with
different bases will not be included.
"""
- targets: Collection["Revision"] = self._parse_upgrade_target(
- current_revisions=lower,
- target=upper,
- assert_relative_length=assert_relative_length,
- )
+ targets: Collection[Revision] = [
+ is_revision(rev)
+ for rev in self._parse_upgrade_target(
+ current_revisions=lower,
+ target=upper,
+ assert_relative_length=assert_relative_length,
+ )
+ ]
# assert type(targets) is tuple, "targets should be a tuple"
@@ -1432,6 +1454,7 @@ class RevisionMap:
target=lower,
assert_relative_length=assert_relative_length,
)
+ assert rev
if rev == "base":
current_revisions = tuple()
lower = None
@@ -1449,14 +1472,16 @@ class RevisionMap:
# Include the lower revision (=current_revisions?) in the iteration
if inclusive:
- needs.update(self.get_revisions(lower))
+ needs.update(is_revision(rev) for rev in self.get_revisions(lower))
# By default, base is implicit as we want all dependencies returned.
# Base is also implicit if lower = base
# implicit_base=False -> only return direct downstreams of
# current_revisions
if current_revisions and not implicit_base:
lower_descendents = self._get_descendant_nodes(
- current_revisions, check=True, include_dependencies=False
+ [is_revision(rev) for rev in current_revisions],
+ check=True,
+ include_dependencies=False,
)
needs.intersection_update(lower_descendents)
@@ -1545,7 +1570,7 @@ class Revision:
args.append("branch_labels=%r" % (self.branch_labels,))
return "%s(%s)" % (self.__class__.__name__, ", ".join(args))
- def add_nextrev(self, revision: "Revision") -> None:
+ def add_nextrev(self, revision: Revision) -> None:
self._all_nextrev = self._all_nextrev.union([revision.revision])
if self.revision in revision._versioned_down_revisions:
self.nextrev = self.nextrev.union([revision.revision])
@@ -1630,12 +1655,29 @@ class Revision:
return len(self._versioned_down_revisions) > 1
+@overload
def tuple_rev_as_scalar(
rev: Optional[Sequence[str]],
) -> Optional[Union[str, Sequence[str]]]:
+ ...
+
+
+@overload
+def tuple_rev_as_scalar(
+ rev: Optional[Sequence[Optional[str]]],
+) -> Optional[Union[Optional[str], Sequence[Optional[str]]]]:
+ ...
+
+
+def tuple_rev_as_scalar(rev):
if not rev:
return None
elif len(rev) == 1:
return rev[0]
else:
return rev
+
+
+def is_revision(rev: Any) -> Revision:
+ assert isinstance(rev, Revision)
+ return rev
diff --git a/alembic/script/write_hooks.py b/alembic/script/write_hooks.py
index 0cc9bb8..8bc7ac1 100644
--- a/alembic/script/write_hooks.py
+++ b/alembic/script/write_hooks.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import shlex
import subprocess
import sys
@@ -14,7 +16,7 @@ from ..util import compat
REVISION_SCRIPT_TOKEN = "REVISION_SCRIPT_FILENAME"
-_registry = {}
+_registry: dict = {}
def register(name: str) -> Callable: