summaryrefslogtreecommitdiff
path: root/alembic/script
diff options
context:
space:
mode:
authorCaselIT <cfederico87@gmail.com>2021-04-18 15:44:50 +0200
committerMike Bayer <mike_mp@zzzcomputing.com>2021-08-11 15:04:56 -0400
commit6aad68605f510e8b51f42efa812e02b3831d6e33 (patch)
treecc0e98b8ad8245add8692d8e4910faf57abf7ae3 /alembic/script
parent3bf6a326c0a11e4f05c94008709d6b0b8e9e051a (diff)
downloadalembic-6aad68605f510e8b51f42efa812e02b3831d6e33.tar.gz
Add pep-484 type annotations
pep-484 type annotations have been added throughout the library. This should be helpful in providing Mypy and IDE support, however there is not full support for Alembic's dynamically modified "op" namespace as of yet; a future release will likely modify the approach used for importing this namespace to be better compatible with pep-484 capabilities. Type originally created using MonkeyType Add types extracted with the MonkeyType https://github.com/instagram/MonkeyType library by running the unit tests using ``monkeytype run -m pytest tests``, then ``monkeytype apply <module>`` (see below for further details). USed MonkeyType version 20.5 on Python 3.8, since newer version have issues After applying the types, the new imports are placed in a ``TYPE_CHECKING`` guard and all type definition of non base types are deferred by using the string notation. NOTE: since to apply the types MonkeType need to import the module, also the test ones, the patch below mocks the setup done by pytest so that the tests could be correctly imported diff --git a/alembic/testing/__init__.py b/alembic/testing/__init__.py index bdd1746..b1090c7 100644 Change-Id: Iff93628f4b43c740848871ce077a118db5e75d41 --- a/alembic/testing/__init__.py +++ b/alembic/testing/__init__.py @@ -9,6 +9,12 @@ from sqlalchemy.testing.config import combinations from sqlalchemy.testing.config import fixture from sqlalchemy.testing.config import requirements as requires +from sqlalchemy.testing.plugin.pytestplugin import PytestFixtureFunctions +from sqlalchemy.testing.plugin.plugin_base import _setup_requirements + +config._fixture_functions = PytestFixtureFunctions() +_setup_requirements("tests.requirements:DefaultRequirements") + from alembic import util from .assertions import assert_raises from .assertions import assert_raises_message Currently I'm using this branch of the sqlalchemy stubs: https://github.com/sqlalchemy/sqlalchemy2-stubs/tree/alembic_updates Change-Id: I8fd0700aab1913f395302626b8b84fea60334abd
Diffstat (limited to 'alembic/script')
-rw-r--r--alembic/script/base.py257
-rw-r--r--alembic/script/revision.py409
-rw-r--r--alembic/script/write_hooks.py15
3 files changed, 444 insertions, 237 deletions
diff --git a/alembic/script/base.py b/alembic/script/base.py
index d0500c4..ef0fd52 100644
--- a/alembic/script/base.py
+++ b/alembic/script/base.py
@@ -4,16 +4,35 @@ import os
import re
import shutil
import sys
+from types import ModuleType
+from typing import Any
+from typing import cast
+from typing import Dict
+from typing import Iterator
+from typing import List
+from typing import Optional
+from typing import Sequence
+from typing import Set
+from typing import Tuple
+from typing import TYPE_CHECKING
+from typing import Union
from . import revision
from . import write_hooks
from .. import util
from ..runtime import migration
+if TYPE_CHECKING:
+ from ..config import Config
+ from ..runtime.migration import RevisionStep
+ from ..runtime.migration import StampStep
+
try:
from dateutil import tz
except ImportError:
- tz = None # noqa
+ tz = None # type: ignore[assignment]
+
+_RevIdType = Union[str, Sequence[str]]
_sourceless_rev_file = re.compile(r"(?!\.\#|__init__)(.*\.py)(c|o)?$")
_only_source_rev_file = re.compile(r"(?!\.\#|__init__)(.*\.py)$")
@@ -49,15 +68,15 @@ class ScriptDirectory:
def __init__(
self,
- dir, # noqa
- file_template=_default_file_template,
- truncate_slug_length=40,
- version_locations=None,
- sourceless=False,
- output_encoding="utf-8",
- timezone=None,
- hook_config=None,
- ):
+ dir: str, # noqa
+ file_template: str = _default_file_template,
+ truncate_slug_length: Optional[int] = 40,
+ version_locations: Optional[List[str]] = None,
+ sourceless: bool = False,
+ output_encoding: str = "utf-8",
+ timezone: Optional[str] = None,
+ hook_config: Optional[Dict[str, str]] = None,
+ ) -> None:
self.dir = dir
self.file_template = file_template
self.version_locations = version_locations
@@ -76,7 +95,7 @@ class ScriptDirectory:
)
@property
- def versions(self):
+ def versions(self) -> str:
loc = self._version_locations
if len(loc) > 1:
raise util.CommandError("Multiple version_locations present")
@@ -93,7 +112,7 @@ class ScriptDirectory:
else:
return (os.path.abspath(os.path.join(self.dir, "versions")),)
- def _load_revisions(self):
+ def _load_revisions(self) -> Iterator["Script"]:
if self.version_locations:
paths = [
vers
@@ -120,7 +139,7 @@ class ScriptDirectory:
yield script
@classmethod
- def from_config(cls, config):
+ def from_config(cls, config: "Config") -> "ScriptDirectory":
"""Produce a new :class:`.ScriptDirectory` given a :class:`.Config`
instance.
@@ -133,7 +152,9 @@ class ScriptDirectory:
raise util.CommandError(
"No 'script_location' key " "found in configuration."
)
- truncate_slug_length = config.get_main_option("truncate_slug_length")
+ truncate_slug_length = cast(
+ Optional[int], config.get_main_option("truncate_slug_length")
+ )
if truncate_slug_length is not None:
truncate_slug_length = int(truncate_slug_length)
@@ -162,13 +183,17 @@ class ScriptDirectory:
else:
if split_char is None:
# legacy behaviour for backwards compatibility
- version_locations = _split_on_space_comma.split(
- version_locations
+ vl = _split_on_space_comma.split(
+ cast(str, version_locations)
)
+ version_locations: List[str] = vl # type: ignore[no-redef]
else:
- version_locations = [
- x for x in version_locations.split(split_char) if x
+ vl = [
+ x
+ for x in cast(str, version_locations).split(split_char)
+ if x
]
+ version_locations: List[str] = vl # type: ignore[no-redef]
prepend_sys_path = config.get_main_option("prepend_sys_path")
if prepend_sys_path:
@@ -184,7 +209,7 @@ class ScriptDirectory:
truncate_slug_length=truncate_slug_length,
sourceless=config.get_main_option("sourceless") == "true",
output_encoding=config.get_main_option("output_encoding", "utf-8"),
- version_locations=version_locations,
+ version_locations=cast("Optional[List[str]]", version_locations),
timezone=config.get_main_option("timezone"),
hook_config=config.get_section("post_write_hooks", {}),
)
@@ -192,19 +217,19 @@ class ScriptDirectory:
@contextmanager
def _catch_revision_errors(
self,
- ancestor=None,
- multiple_heads=None,
- start=None,
- end=None,
- resolution=None,
- ):
+ ancestor: Optional[str] = None,
+ multiple_heads: Optional[str] = None,
+ start: Optional[str] = None,
+ end: Optional[str] = None,
+ resolution: Optional[str] = None,
+ ) -> Iterator[None]:
try:
yield
except revision.RangeNotAncestorError as rna:
if start is None:
- start = rna.lower
+ start = cast(Any, rna.lower)
if end is None:
- end = rna.upper
+ end = cast(Any, rna.upper)
if not ancestor:
ancestor = (
"Requested range %(start)s:%(end)s does not refer to "
@@ -235,7 +260,9 @@ class ScriptDirectory:
except revision.RevisionError as err:
raise util.CommandError(err.args[0]) from err
- def walk_revisions(self, base="base", head="heads"):
+ def walk_revisions(
+ self, base: str = "base", head: str = "heads"
+ ) -> Iterator["Script"]:
"""Iterate through all revisions.
:param base: the base revision, or "base" to start from the
@@ -250,28 +277,36 @@ class ScriptDirectory:
for rev in self.revision_map.iterate_revisions(
head, base, inclusive=True, assert_relative_length=False
):
- yield rev
+ yield cast(Script, rev)
- def get_revisions(self, id_):
+ def get_revisions(self, id_: _RevIdType) -> Tuple["Script", ...]:
"""Return the :class:`.Script` instance with the given rev identifier,
symbolic name, or sequence of identifiers.
"""
with self._catch_revision_errors():
- return self.revision_map.get_revisions(id_)
+ return cast(
+ "Tuple[Script, ...]", self.revision_map.get_revisions(id_)
+ )
- def get_all_current(self, id_):
+ def get_all_current(self, id_: Tuple[str, ...]) -> Set["Script"]:
with self._catch_revision_errors():
- top_revs = set(self.revision_map.get_revisions(id_))
+ top_revs = cast(
+ "Set[Script]",
+ set(self.revision_map.get_revisions(id_)),
+ )
top_revs.update(
- self.revision_map._get_ancestor_nodes(
- list(top_revs), include_dependencies=True
+ cast(
+ "Iterator[Script]",
+ self.revision_map._get_ancestor_nodes(
+ list(top_revs), include_dependencies=True
+ ),
)
)
top_revs = self.revision_map._filter_into_branch_heads(top_revs)
return top_revs
- def get_revision(self, id_):
+ def get_revision(self, id_: str) -> "Script":
"""Return the :class:`.Script` instance with the given rev id.
.. seealso::
@@ -281,9 +316,11 @@ class ScriptDirectory:
"""
with self._catch_revision_errors():
- return self.revision_map.get_revision(id_)
+ return cast(Script, self.revision_map.get_revision(id_))
- def as_revision_number(self, id_):
+ def as_revision_number(
+ self, id_: Optional[str]
+ ) -> Optional[Union[str, Tuple[str, ...]]]:
"""Convert a symbolic revision, i.e. 'head' or 'base', into
an actual revision number."""
@@ -340,7 +377,7 @@ class ScriptDirectory:
):
return self.revision_map.get_current_head()
- def get_heads(self):
+ def get_heads(self) -> List[str]:
"""Return all "versioned head" revisions as strings.
This is normally a list of length one,
@@ -353,7 +390,7 @@ class ScriptDirectory:
"""
return list(self.revision_map.heads)
- def get_base(self):
+ def get_base(self) -> Optional[str]:
"""Return the "base" revision as a string.
This is the revision number of the script that
@@ -375,7 +412,7 @@ class ScriptDirectory:
else:
return None
- def get_bases(self):
+ def get_bases(self) -> List[str]:
"""return all "base" revisions as strings.
This is the revision number of all scripts that
@@ -384,7 +421,9 @@ class ScriptDirectory:
"""
return list(self.revision_map.bases)
- def _upgrade_revs(self, destination, current_rev):
+ def _upgrade_revs(
+ self, destination: str, current_rev: str
+ ) -> List["RevisionStep"]:
with self._catch_revision_errors(
ancestor="Destination %(end)s is not a valid upgrade "
"target from current head(s)",
@@ -393,15 +432,16 @@ class ScriptDirectory:
revs = self.revision_map.iterate_revisions(
destination, current_rev, implicit_base=True
)
- revs = list(revs)
return [
migration.MigrationStep.upgrade_from_script(
- self.revision_map, script
+ self.revision_map, cast(Script, script)
)
for script in reversed(list(revs))
]
- def _downgrade_revs(self, destination, current_rev):
+ def _downgrade_revs(
+ self, destination: str, current_rev: Optional[str]
+ ) -> List["RevisionStep"]:
with self._catch_revision_errors(
ancestor="Destination %(end)s is not a valid downgrade "
"target from current head(s)",
@@ -412,30 +452,32 @@ class ScriptDirectory:
)
return [
migration.MigrationStep.downgrade_from_script(
- self.revision_map, script
+ self.revision_map, cast(Script, script)
)
for script in revs
]
- def _stamp_revs(self, revision, heads):
+ def _stamp_revs(
+ self, revision: _RevIdType, heads: _RevIdType
+ ) -> List["StampStep"]:
with self._catch_revision_errors(
multiple_heads="Multiple heads are present; please specify a "
"single target revision"
):
- heads = self.get_revisions(heads)
+ heads_revs = self.get_revisions(heads)
steps = []
if not revision:
revision = "base"
- filtered_heads = []
+ filtered_heads: List["Script"] = []
for rev in util.to_tuple(revision):
if rev:
filtered_heads.extend(
self.revision_map.filter_for_lineage(
- heads, rev, include_dependencies=True
+ heads_revs, rev, include_dependencies=True
)
)
filtered_heads = util.unique_list(filtered_heads)
@@ -509,7 +551,7 @@ class ScriptDirectory:
return steps
- def run_env(self):
+ def run_env(self) -> None:
"""Run the script environment.
This basically runs the ``env.py`` script present
@@ -524,7 +566,7 @@ class ScriptDirectory:
def env_py_location(self):
return os.path.abspath(os.path.join(self.dir, "env.py"))
- def _generate_template(self, src, dest, **kw):
+ def _generate_template(self, src: str, dest: str, **kw: Any) -> None:
util.status(
"Generating %s" % os.path.abspath(dest),
util.template_to_file,
@@ -534,17 +576,17 @@ class ScriptDirectory:
**kw
)
- def _copy_file(self, src, dest):
+ def _copy_file(self, src: str, dest: str) -> None:
util.status(
"Generating %s" % os.path.abspath(dest), shutil.copy, src, dest
)
- def _ensure_directory(self, path):
+ def _ensure_directory(self, path: str) -> None:
path = os.path.abspath(path)
if not os.path.exists(path):
util.status("Creating directory %s" % path, os.makedirs, path)
- def _generate_create_date(self):
+ def _generate_create_date(self) -> "datetime.datetime":
if self.timezone is not None:
if tz is None:
raise util.CommandError(
@@ -571,16 +613,16 @@ class ScriptDirectory:
def generate_revision(
self,
- revid,
- message,
- head=None,
- refresh=False,
- splice=False,
- branch_labels=None,
- version_path=None,
- depends_on=None,
- **kw
- ):
+ revid: str,
+ message: Optional[str],
+ head: Optional[str] = None,
+ refresh: bool = False,
+ splice: Optional[bool] = False,
+ branch_labels: Optional[str] = None,
+ version_path: Optional[str] = None,
+ depends_on: Optional[_RevIdType] = None,
+ **kw: Any
+ ) -> Optional["Script"]:
"""Generate a new revision file.
This runs the ``script.py.mako`` template, given
@@ -623,9 +665,10 @@ class ScriptDirectory:
if version_path is None:
if len(self._version_locations) > 1:
- for head in heads:
- if head is not None:
- version_path = os.path.dirname(head.path)
+ for head_ in heads:
+ if head_ is not None:
+ assert isinstance(head_, Script)
+ version_path = os.path.dirname(head_.path)
break
else:
raise util.CommandError(
@@ -651,12 +694,12 @@ class ScriptDirectory:
path = self._rev_path(version_path, revid, message, create_date)
if not splice:
- for head in heads:
- if head is not None and not head.is_head:
+ for head_ in heads:
+ if head_ is not None and not head_.is_head:
raise util.CommandError(
"Revision %s is not a head revision; please specify "
"--splice to create a new branch from this revision"
- % head.revision
+ % head_.revision
)
if depends_on:
@@ -679,7 +722,9 @@ class ScriptDirectory:
tuple(h.revision if h is not None else None for h in heads)
),
branch_labels=util.to_tuple(branch_labels),
- depends_on=revision.tuple_rev_as_scalar(depends_on),
+ depends_on=revision.tuple_rev_as_scalar(
+ cast("Optional[List[str]]", depends_on)
+ ),
create_date=create_date,
comma=util.format_as_comma,
message=message if message is not None else ("empty message"),
@@ -694,6 +739,8 @@ class ScriptDirectory:
script = Script._from_path(self, path)
except revision.RevisionError as err:
raise util.CommandError(err.args[0]) from err
+ if script is None:
+ return None
if branch_labels and not script.branch_labels:
raise util.CommandError(
"Version %s specified branch_labels %s, however the "
@@ -702,11 +749,16 @@ class ScriptDirectory:
"'branch_labels' section?"
% (script.revision, branch_labels, script.path)
)
-
self.revision_map.add_revision(script)
return script
- def _rev_path(self, path, rev_id, message, create_date):
+ def _rev_path(
+ self,
+ path: str,
+ rev_id: str,
+ message: Optional[str],
+ create_date: "datetime.datetime",
+ ) -> str:
slug = "_".join(_slug_re.findall(message or "")).lower()
if len(slug) > self.truncate_slug_length:
slug = slug[: self.truncate_slug_length].rsplit("_", 1)[0] + "_"
@@ -735,12 +787,12 @@ class Script(revision.Revision):
"""
- def __init__(self, module, rev_id, path):
+ def __init__(self, module: ModuleType, rev_id: str, path: str):
self.module = module
self.path = path
super(Script, self).__init__(
rev_id,
- module.down_revision,
+ module.down_revision, # type: ignore[attr-defined]
branch_labels=util.to_tuple(
getattr(module, "branch_labels", None), default=()
),
@@ -749,10 +801,10 @@ class Script(revision.Revision):
),
)
- module = None
+ module: ModuleType = None # type: ignore[assignment]
"""The Python module representing the actual script itself."""
- path = None
+ path: str = None # type: ignore[assignment]
"""Filesystem path of the script."""
_db_current_indicator = None
@@ -760,25 +812,27 @@ class Script(revision.Revision):
this is a "current" version in some database"""
@property
- def doc(self):
+ def doc(self) -> str:
"""Return the docstring given in the script."""
return re.split("\n\n", self.longdoc)[0]
@property
- def longdoc(self):
+ def longdoc(self) -> str:
"""Return the docstring given in the script."""
doc = self.module.__doc__
if doc:
if hasattr(self.module, "_alembic_source_encoding"):
- doc = doc.decode(self.module._alembic_source_encoding)
- return doc.strip()
+ doc = doc.decode( # type: ignore[attr-defined]
+ self.module._alembic_source_encoding # type: ignore[attr-defined] # noqa
+ )
+ return doc.strip() # type: ignore[union-attr]
else:
return ""
@property
- def log_entry(self):
+ def log_entry(self) -> str:
entry = "Rev: %s%s%s%s%s\n" % (
self.revision,
" (head)" if self.is_head else "",
@@ -825,12 +879,12 @@ class Script(revision.Revision):
def _head_only(
self,
- include_branches=False,
- include_doc=False,
- include_parents=False,
- tree_indicators=True,
- head_indicators=True,
- ):
+ include_branches: bool = False,
+ include_doc: bool = False,
+ include_parents: bool = False,
+ tree_indicators: bool = True,
+ head_indicators: bool = True,
+ ) -> str:
text = self.revision
if include_parents:
if self.dependencies:
@@ -841,6 +895,7 @@ class Script(revision.Revision):
)
else:
text = "%s -> %s" % (self._format_down_revision(), text)
+ assert text is not None
if include_branches and self.branch_labels:
text += " (%s)" % util.format_as_comma(self.branch_labels)
if head_indicators or tree_indicators:
@@ -862,12 +917,12 @@ class Script(revision.Revision):
def cmd_format(
self,
- verbose,
- include_branches=False,
- include_doc=False,
- include_parents=False,
- tree_indicators=True,
- ):
+ verbose: bool,
+ include_branches: bool = False,
+ include_doc: bool = False,
+ include_parents: bool = False,
+ tree_indicators: bool = True,
+ ) -> str:
if verbose:
return self.log_entry
else:
@@ -875,19 +930,21 @@ class Script(revision.Revision):
include_branches, include_doc, include_parents, tree_indicators
)
- def _format_down_revision(self):
+ def _format_down_revision(self) -> str:
if not self.down_revision:
return "<base>"
else:
return util.format_as_comma(self._versioned_down_revisions)
@classmethod
- def _from_path(cls, scriptdir, path):
+ def _from_path(
+ cls, scriptdir: ScriptDirectory, path: str
+ ) -> Optional["Script"]:
dir_, filename = os.path.split(path)
return cls._from_filename(scriptdir, dir_, filename)
@classmethod
- def _list_py_dir(cls, scriptdir, path):
+ def _list_py_dir(cls, scriptdir: ScriptDirectory, path: str) -> List[str]:
if scriptdir.sourceless:
# read files in version path, e.g. pyc or pyo files
# in the immediate path
@@ -910,7 +967,9 @@ class Script(revision.Revision):
return os.listdir(path)
@classmethod
- def _from_filename(cls, scriptdir, dir_, filename):
+ def _from_filename(
+ cls, scriptdir: ScriptDirectory, dir_: str, filename: str
+ ) -> Optional["Script"]:
if scriptdir.sourceless:
py_match = _sourceless_rev_file.match(filename)
else:
diff --git a/alembic/script/revision.py b/alembic/script/revision.py
index bdae805..eccb98e 100644
--- a/alembic/script/revision.py
+++ b/alembic/script/revision.py
@@ -1,11 +1,40 @@
import collections
import re
+from typing import Any
+from typing import Callable
+from typing import cast
+from typing import Collection
+from typing import Deque
+from typing import Dict
+from typing import FrozenSet
+from typing import Iterator
+from typing import List
+from typing import Optional
+from typing import Sequence
+from typing import Set
+from typing import Tuple
+from typing import TYPE_CHECKING
+from typing import TypeVar
+from typing import Union
from sqlalchemy import util as sqlautil
from .. import util
from ..util import compat
+if TYPE_CHECKING:
+ from typing import Literal
+
+ from .base import Script
+
+_RevIdType = Union[str, Sequence[str]]
+_RevisionIdentifierType = Union[str, Tuple[str, ...], None]
+_RevisionOrStr = Union["Revision", str]
+_RevisionOrBase = Union["Revision", "Literal['base']"]
+_InterimRevisionMapType = Dict[str, "Revision"]
+_RevisionMapType = Dict[Union[None, str, Tuple[()]], Optional["Revision"]]
+_T = TypeVar("_T", bound=Union[str, "Revision"])
+
_relative_destination = re.compile(r"(?:(.+?)@)?(\w+)?((?:\+|-)\d+)")
_revision_illegal_chars = ["@", "-", "+"]
@@ -15,7 +44,9 @@ class RevisionError(Exception):
class RangeNotAncestorError(RevisionError):
- def __init__(self, lower, upper):
+ def __init__(
+ self, lower: _RevisionIdentifierType, upper: _RevisionIdentifierType
+ ) -> None:
self.lower = lower
self.upper = upper
super(RangeNotAncestorError, self).__init__(
@@ -25,7 +56,7 @@ class RangeNotAncestorError(RevisionError):
class MultipleHeads(RevisionError):
- def __init__(self, heads, argument):
+ def __init__(self, heads: Sequence[str], argument: Optional[str]) -> None:
self.heads = heads
self.argument = argument
super(MultipleHeads, self).__init__(
@@ -35,7 +66,7 @@ class MultipleHeads(RevisionError):
class ResolutionError(RevisionError):
- def __init__(self, message, argument):
+ def __init__(self, message: str, argument: str) -> None:
super(ResolutionError, self).__init__(message)
self.argument = argument
@@ -43,7 +74,7 @@ class ResolutionError(RevisionError):
class CycleDetected(RevisionError):
kind = "Cycle"
- def __init__(self, revisions):
+ def __init__(self, revisions: Sequence[str]) -> None:
self.revisions = revisions
super(CycleDetected, self).__init__(
"%s is detected in revisions (%s)"
@@ -54,21 +85,21 @@ class CycleDetected(RevisionError):
class DependencyCycleDetected(CycleDetected):
kind = "Dependency cycle"
- def __init__(self, revisions):
+ def __init__(self, revisions: Sequence[str]) -> None:
super(DependencyCycleDetected, self).__init__(revisions)
class LoopDetected(CycleDetected):
kind = "Self-loop"
- def __init__(self, revision):
+ def __init__(self, revision: str) -> None:
super(LoopDetected, self).__init__([revision])
class DependencyLoopDetected(DependencyCycleDetected, LoopDetected):
kind = "Dependency self-loop"
- def __init__(self, revision):
+ def __init__(self, revision: Sequence[str]) -> None:
super(DependencyLoopDetected, self).__init__(revision)
@@ -81,7 +112,7 @@ class RevisionMap:
"""
- def __init__(self, generator):
+ def __init__(self, generator: Callable[[], Iterator["Revision"]]) -> None:
"""Construct a new :class:`.RevisionMap`.
:param generator: a zero-arg callable that will generate an iterable
@@ -92,7 +123,7 @@ class RevisionMap:
self._generator = generator
@util.memoized_property
- def heads(self):
+ def heads(self) -> Tuple[str, ...]:
"""All "head" revisions as strings.
This is normally a tuple of length one,
@@ -105,7 +136,7 @@ class RevisionMap:
return self.heads
@util.memoized_property
- def bases(self):
+ def bases(self) -> Tuple[str, ...]:
"""All "base" revisions as strings.
These are revisions that have a ``down_revision`` of None,
@@ -118,7 +149,7 @@ class RevisionMap:
return self.bases
@util.memoized_property
- def _real_heads(self):
+ def _real_heads(self) -> Tuple[str, ...]:
"""All "real" head revisions as strings.
:return: a tuple of string revision numbers.
@@ -128,7 +159,7 @@ class RevisionMap:
return self._real_heads
@util.memoized_property
- def _real_bases(self):
+ def _real_bases(self) -> Tuple[str, ...]:
"""All "real" base revisions as strings.
:return: a tuple of string revision numbers.
@@ -138,19 +169,19 @@ class RevisionMap:
return self._real_bases
@util.memoized_property
- def _revision_map(self):
+ def _revision_map(self) -> _RevisionMapType:
"""memoized attribute, initializes the revision map from the
initial collection.
"""
# Ordering required for some tests to pass (but not required in
# general)
- map_ = sqlautil.OrderedDict()
+ map_: _InterimRevisionMapType = sqlautil.OrderedDict()
- heads = sqlautil.OrderedSet()
- _real_heads = sqlautil.OrderedSet()
- bases = ()
- _real_bases = ()
+ heads: Set["Revision"] = sqlautil.OrderedSet()
+ _real_heads: Set["Revision"] = sqlautil.OrderedSet()
+ bases: Tuple["Revision", ...] = ()
+ _real_bases: Tuple["Revision", ...] = ()
has_branch_labels = set()
all_revisions = set()
@@ -176,11 +207,13 @@ class RevisionMap:
# add the branch_labels to the map_. We'll need these
# to resolve the dependencies.
rev_map = map_.copy()
- self._map_branch_labels(has_branch_labels, map_)
+ self._map_branch_labels(
+ has_branch_labels, cast(_RevisionMapType, map_)
+ )
# resolve dependency names from branch labels and symbolic
# names
- self._add_depends_on(all_revisions, map_)
+ self._add_depends_on(all_revisions, cast(_RevisionMapType, map_))
for rev in map_.values():
for downrev in rev._all_down_revisions:
@@ -198,32 +231,44 @@ class RevisionMap:
# once the map has downrevisions populated, the dependencies
# can be further refined to include only those which are not
# already ancestors
- self._normalize_depends_on(all_revisions, map_)
+ self._normalize_depends_on(all_revisions, cast(_RevisionMapType, map_))
self._detect_cycles(rev_map, heads, bases, _real_heads, _real_bases)
- map_[None] = map_[()] = None
+ revision_map: _RevisionMapType = dict(map_.items())
+ revision_map[None] = revision_map[()] = None
self.heads = tuple(rev.revision for rev in heads)
self._real_heads = tuple(rev.revision for rev in _real_heads)
self.bases = tuple(rev.revision for rev in bases)
self._real_bases = tuple(rev.revision for rev in _real_bases)
- self._add_branches(has_branch_labels, map_)
- return map_
+ self._add_branches(has_branch_labels, revision_map)
+ return revision_map
- def _detect_cycles(self, rev_map, heads, bases, _real_heads, _real_bases):
+ def _detect_cycles(
+ self,
+ rev_map: _InterimRevisionMapType,
+ heads: Set["Revision"],
+ bases: Tuple["Revision", ...],
+ _real_heads: Set["Revision"],
+ _real_bases: Tuple["Revision", ...],
+ ) -> None:
if not rev_map:
return
if not heads or not bases:
- raise CycleDetected(rev_map.keys())
+ raise CycleDetected(list(rev_map))
total_space = {
rev.revision
for rev in self._iterate_related_revisions(
- lambda r: r._versioned_down_revisions, heads, map_=rev_map
+ lambda r: r._versioned_down_revisions,
+ heads,
+ map_=cast(_RevisionMapType, rev_map),
)
}.intersection(
rev.revision
for rev in self._iterate_related_revisions(
- lambda r: r.nextrev, bases, map_=rev_map
+ lambda r: r.nextrev,
+ bases,
+ map_=cast(_RevisionMapType, rev_map),
)
)
deleted_revs = set(rev_map.keys()) - total_space
@@ -231,39 +276,50 @@ class RevisionMap:
raise CycleDetected(sorted(deleted_revs))
if not _real_heads or not _real_bases:
- raise DependencyCycleDetected(rev_map.keys())
+ raise DependencyCycleDetected(list(rev_map))
total_space = {
rev.revision
for rev in self._iterate_related_revisions(
- lambda r: r._all_down_revisions, _real_heads, map_=rev_map
+ lambda r: r._all_down_revisions,
+ _real_heads,
+ map_=cast(_RevisionMapType, rev_map),
)
}.intersection(
rev.revision
for rev in self._iterate_related_revisions(
- lambda r: r._all_nextrev, _real_bases, map_=rev_map
+ lambda r: r._all_nextrev,
+ _real_bases,
+ map_=cast(_RevisionMapType, rev_map),
)
)
deleted_revs = set(rev_map.keys()) - total_space
if deleted_revs:
raise DependencyCycleDetected(sorted(deleted_revs))
- def _map_branch_labels(self, revisions, map_):
+ def _map_branch_labels(
+ self, revisions: Collection["Revision"], map_: _RevisionMapType
+ ) -> None:
for revision in revisions:
if revision.branch_labels:
+ assert revision._orig_branch_labels is not None
for branch_label in revision._orig_branch_labels:
if branch_label in map_:
+ map_rev = map_[branch_label]
+ assert map_rev is not None
raise RevisionError(
"Branch name '%s' in revision %s already "
"used by revision %s"
% (
branch_label,
revision.revision,
- map_[branch_label].revision,
+ map_rev.revision,
)
)
map_[branch_label] = revision
- def _add_branches(self, revisions, map_):
+ def _add_branches(
+ self, revisions: Collection["Revision"], map_: _RevisionMapType
+ ) -> None:
for revision in revisions:
if revision.branch_labels:
revision.branch_labels.update(revision.branch_labels)
@@ -285,7 +341,9 @@ class RevisionMap:
else:
break
- def _add_depends_on(self, revisions, map_):
+ def _add_depends_on(
+ self, revisions: Collection["Revision"], map_: _RevisionMapType
+ ) -> None:
"""Resolve the 'dependencies' for each revision in a collection
in terms of actual revision ids, as opposed to branch labels or other
symbolic names.
@@ -301,12 +359,14 @@ class RevisionMap:
map_[dep] for dep in util.to_tuple(revision.dependencies)
]
revision._resolved_dependencies = tuple(
- [d.revision for d in deps]
+ [d.revision for d in deps if d is not None]
)
else:
revision._resolved_dependencies = ()
- def _normalize_depends_on(self, revisions, map_):
+ def _normalize_depends_on(
+ self, revisions: Collection["Revision"], map_: _RevisionMapType
+ ) -> None:
"""Create a collection of "dependencies" that omits dependencies
that are already ancestor nodes for each revision in a given
collection.
@@ -327,7 +387,9 @@ class RevisionMap:
if revision._resolved_dependencies:
normalized_resolved = set(revision._resolved_dependencies)
for rev in self._get_ancestor_nodes(
- [revision], include_dependencies=False, map_=map_
+ [revision],
+ include_dependencies=False,
+ map_=cast(_RevisionMapType, map_),
):
if rev is revision:
continue
@@ -342,7 +404,9 @@ class RevisionMap:
else:
revision._normalized_resolved_dependencies = ()
- def add_revision(self, revision, _replace=False):
+ def add_revision(
+ self, revision: "Revision", _replace: bool = False
+ ) -> None:
"""add a single revision to an existing map.
This method is for single-revision use cases, it's not
@@ -375,7 +439,7 @@ class RevisionMap:
"Revision %s referenced from %s is not present"
% (downrev, revision)
)
- map_[downrev].add_nextrev(revision)
+ cast("Revision", map_[downrev]).add_nextrev(revision)
self._normalize_depends_on(revisions, map_)
@@ -398,7 +462,9 @@ class RevisionMap:
)
) + (revision.revision,)
- def get_current_head(self, branch_label=None):
+ def get_current_head(
+ self, branch_label: Optional[str] = None
+ ) -> Optional[str]:
"""Return the current head revision.
If the script directory has multiple heads
@@ -416,7 +482,7 @@ class RevisionMap:
:meth:`.ScriptDirectory.get_heads`
"""
- current_heads = self.heads
+ current_heads: Sequence[str] = self.heads
if branch_label:
current_heads = self.filter_for_lineage(
current_heads, branch_label
@@ -432,10 +498,12 @@ class RevisionMap:
else:
return None
- def _get_base_revisions(self, identifier):
+ def _get_base_revisions(self, identifier: str) -> Tuple[str, ...]:
return self.filter_for_lineage(self.bases, identifier)
- def get_revisions(self, id_):
+ def get_revisions(
+ self, id_: Union[str, Collection[str], None]
+ ) -> Tuple["Revision", ...]:
"""Return the :class:`.Revision` instances with the given rev id
or identifiers.
@@ -456,7 +524,9 @@ class RevisionMap:
if isinstance(id_, (list, tuple, set, frozenset)):
return sum([self.get_revisions(id_elem) for id_elem in id_], ())
else:
- resolved_id, branch_label = self._resolve_revision_number(id_)
+ resolved_id, branch_label = self._resolve_revision_number(
+ id_ # type:ignore [arg-type]
+ )
if len(resolved_id) == 1:
try:
rint = int(resolved_id[0])
@@ -464,11 +534,11 @@ class RevisionMap:
# branch@-n -> walk down from heads
select_heads = self.get_revisions("heads")
if branch_label is not None:
- select_heads = [
+ select_heads = tuple(
head
for head in select_heads
if branch_label in head.branch_labels
- ]
+ )
return tuple(
self._walk(head, steps=rint)
for head in select_heads
@@ -481,7 +551,7 @@ class RevisionMap:
for rev_id in resolved_id
)
- def get_revision(self, id_):
+ def get_revision(self, id_: Optional[str]) -> "Revision":
"""Return the :class:`.Revision` instance with the given rev id.
If a symbolic name such as "head" or "base" is given, resolves
@@ -499,11 +569,11 @@ class RevisionMap:
if len(resolved_id) > 1:
raise MultipleHeads(resolved_id, id_)
elif resolved_id:
- resolved_id = resolved_id[0]
+ resolved_id = resolved_id[0] # type:ignore[assignment]
- return self._revision_for_ident(resolved_id, branch_label)
+ return self._revision_for_ident(cast(str, resolved_id), branch_label)
- def _resolve_branch(self, branch_label):
+ def _resolve_branch(self, branch_label: str) -> "Revision":
try:
branch_rev = self._revision_map[branch_label]
except KeyError:
@@ -517,19 +587,24 @@ class RevisionMap:
else:
return nonbranch_rev
else:
- return branch_rev
+ return cast("Revision", branch_rev)
- def _revision_for_ident(self, resolved_id, check_branch=None):
+ def _revision_for_ident(
+ self, resolved_id: str, check_branch: Optional[str] = None
+ ) -> "Revision":
+ branch_rev: Optional["Revision"]
if check_branch:
branch_rev = self._resolve_branch(check_branch)
else:
branch_rev = None
+ revision: Union["Revision", "Literal[False]"]
try:
- revision = self._revision_map[resolved_id]
+ revision = cast("Revision", self._revision_map[resolved_id])
except KeyError:
# break out to avoid misleading py3k stack traces
revision = False
+ revs: Sequence[str]
if revision is False:
# do a partial lookup
revs = [
@@ -562,9 +637,11 @@ class RevisionMap:
resolved_id,
)
else:
- revision = self._revision_map[revs[0]]
+ revision = cast("Revision", self._revision_map[revs[0]])
+ revision = cast("Revision", revision)
if check_branch and revision is not None:
+ assert branch_rev is not None
if not self._shares_lineage(
revision.revision, branch_rev.revision
):
@@ -575,7 +652,9 @@ class RevisionMap:
)
return revision
- def _filter_into_branch_heads(self, targets):
+ def _filter_into_branch_heads(
+ self, targets: Set["Script"]
+ ) -> Set["Script"]:
targets = set(targets)
for rev in list(targets):
@@ -586,8 +665,11 @@ class RevisionMap:
return targets
def filter_for_lineage(
- self, targets, check_against, include_dependencies=False
- ):
+ self,
+ targets: Sequence[_T],
+ check_against: Optional[str],
+ include_dependencies: bool = False,
+ ) -> Tuple[_T, ...]:
id_, branch_label = self._resolve_revision_number(check_against)
shares = []
@@ -596,17 +678,20 @@ class RevisionMap:
if id_:
shares.extend(id_)
- return [
+ return tuple(
tg
for tg in targets
if self._shares_lineage(
tg, shares, include_dependencies=include_dependencies
)
- ]
+ )
def _shares_lineage(
- self, target, test_against_revs, include_dependencies=False
- ):
+ self,
+ target: _RevisionOrStr,
+ test_against_revs: Sequence[_RevisionOrStr],
+ include_dependencies: bool = False,
+ ) -> bool:
if not test_against_revs:
return True
if not isinstance(target, Revision):
@@ -635,7 +720,10 @@ class RevisionMap:
.intersection(test_against_revs)
)
- def _resolve_revision_number(self, id_):
+ def _resolve_revision_number(
+ self, id_: Optional[str]
+ ) -> Tuple[Tuple[str, ...], Optional[str]]:
+ branch_label: Optional[str]
if isinstance(id_, compat.string_types) and "@" in id_:
branch_label, id_ = id_.split("@", 1)
@@ -678,13 +766,13 @@ class RevisionMap:
def iterate_revisions(
self,
- upper,
- lower,
- implicit_base=False,
- inclusive=False,
- assert_relative_length=True,
- select_for_downgrade=False,
- ):
+ upper: _RevisionIdentifierType,
+ lower: _RevisionIdentifierType,
+ implicit_base: bool = False,
+ inclusive: bool = False,
+ assert_relative_length: bool = True,
+ select_for_downgrade: bool = False,
+ ) -> Iterator["Revision"]:
"""Iterate through script revisions, starting at the given
upper revision identifier and ending at the lower.
@@ -696,6 +784,7 @@ class RevisionMap:
The iterator yields :class:`.Revision` objects.
"""
+ fn: Callable
if select_for_downgrade:
fn = self._collect_downgrade_revisions
else:
@@ -714,12 +803,12 @@ class RevisionMap:
def _get_descendant_nodes(
self,
- targets,
- map_=None,
- check=False,
- omit_immediate_dependencies=False,
- include_dependencies=True,
- ):
+ targets: Collection["Revision"],
+ map_: Optional[_RevisionMapType] = None,
+ check: bool = False,
+ omit_immediate_dependencies: bool = False,
+ include_dependencies: bool = True,
+ ) -> Iterator[Any]:
if omit_immediate_dependencies:
@@ -744,8 +833,12 @@ class RevisionMap:
)
def _get_ancestor_nodes(
- self, targets, map_=None, check=False, include_dependencies=True
- ):
+ self,
+ targets: Collection["Revision"],
+ map_: Optional[_RevisionMapType] = None,
+ check: bool = False,
+ include_dependencies: bool = True,
+ ) -> Iterator["Revision"]:
if include_dependencies:
@@ -761,12 +854,18 @@ class RevisionMap:
fn, targets, map_=map_, check=check
)
- def _iterate_related_revisions(self, fn, targets, map_, check=False):
+ def _iterate_related_revisions(
+ self,
+ fn: Callable,
+ targets: Collection["Revision"],
+ map_: Optional[_RevisionMapType],
+ check: bool = False,
+ ) -> Iterator["Revision"]:
if map_ is None:
map_ = self._revision_map
seen = set()
- todo = collections.deque()
+ todo: Deque["Revision"] = collections.deque()
for target in targets:
todo.append(target)
@@ -784,6 +883,7 @@ class RevisionMap:
# Check for map errors before collecting.
for rev_id in fn(rev):
next_rev = map_[rev_id]
+ assert next_rev is not None
if next_rev.revision != rev_id:
raise RevisionError(
"Dependency resolution failed; broken map"
@@ -804,7 +904,11 @@ class RevisionMap:
)
)
- def _topological_sort(self, revisions, heads):
+ def _topological_sort(
+ self,
+ revisions: Collection["Revision"],
+ heads: Any,
+ ) -> List[str]:
"""Yield revision ids of a collection of Revision objects in
topological sorted order (i.e. revisions always come after their
down_revisions and dependencies). Uses the order of keys in
@@ -860,6 +964,7 @@ class RevisionMap:
# now update the heads with our ancestors.
candidate_rev = id_to_rev[candidate]
+ assert candidate_rev is not None
heads_to_add = [
r
@@ -873,7 +978,6 @@ class RevisionMap:
del ancestors_by_idx[current_candidate_idx]
current_candidate_idx = max(current_candidate_idx - 1, 0)
else:
-
if (
not candidate_rev._normalized_resolved_dependencies
and len(candidate_rev._versioned_down_revisions) == 1
@@ -905,7 +1009,13 @@ class RevisionMap:
assert not todo
return output
- def _walk(self, start, steps, branch_label=None, no_overwalk=True):
+ def _walk(
+ self,
+ start: Optional[Union[str, "Revision"]],
+ steps: int,
+ branch_label: Optional[str] = None,
+ no_overwalk: bool = True,
+ ) -> "Revision":
"""
Walk the requested number of :steps up (steps > 0) or down (steps < 0)
the revision tree.
@@ -918,44 +1028,55 @@ class RevisionMap:
A RevisionError is raised if there is no unambiguous revision to
walk to.
"""
-
+ initial: Optional[_RevisionOrBase]
if isinstance(start, compat.string_types):
- start = self.get_revision(start)
+ initial = self.get_revision(start)
+ else:
+ initial = start
+ children: Sequence[_RevisionOrBase]
for _ in range(abs(steps)):
if steps > 0:
# Walk up
children = [
rev
for rev in self.get_revisions(
- self.bases if start is None else start.nextrev
+ self.bases
+ if initial is None
+ else cast("Revision", initial).nextrev
)
]
if branch_label:
children = self.filter_for_lineage(children, branch_label)
else:
# Walk down
- if start == "base":
- children = tuple()
+ if initial == "base":
+ children = ()
else:
children = self.get_revisions(
- self.heads if start is None else start.down_revision
+ self.heads
+ if initial is None
+ else initial.down_revision
)
if not children:
- children = ("base",)
+ children = cast("Tuple[Literal['base']]", ("base",))
if not children:
# This will return an invalid result if no_overwalk, otherwise
# further steps will stay where we are.
- return None if no_overwalk else start
+ ret = None if no_overwalk else initial
+ return ret # type:ignore[return-value]
elif len(children) > 1:
raise RevisionError("Ambiguous walk")
- start = children[0]
+ initial = children[0]
- return start
+ return cast("Revision", initial)
def _parse_downgrade_target(
- self, current_revisions, target, assert_relative_length
- ):
+ self,
+ current_revisions: _RevisionIdentifierType,
+ target: _RevisionIdentifierType,
+ assert_relative_length: bool,
+ ) -> Tuple[Optional[str], Optional[_RevisionOrBase]]:
"""
Parse downgrade command syntax :target to retrieve the target revision
and branch label (if any) given the :current_revisons stamp of the
@@ -999,11 +1120,11 @@ class RevisionMap:
if relative_revision:
# Find target revision relative to current state.
if branch_label:
- symbol = self.filter_for_lineage(
+ symbol_list = self.filter_for_lineage(
util.to_tuple(current_revisions), branch_label
)
- assert len(symbol) == 1
- symbol = symbol[0]
+ assert len(symbol_list) == 1
+ symbol = symbol_list[0]
else:
current_revisions = util.to_tuple(current_revisions)
if not current_revisions:
@@ -1045,12 +1166,15 @@ class RevisionMap:
# No relative destination given, revision specified is absolute.
branch_label, _, symbol = target.rpartition("@")
if not branch_label:
- branch_label = None
+ branch_label = None # type:ignore[assignment]
return branch_label, self.get_revision(symbol)
def _parse_upgrade_target(
- self, current_revisions, target, assert_relative_length
- ):
+ self,
+ current_revisions: _RevisionIdentifierType,
+ target: _RevisionIdentifierType,
+ assert_relative_length: bool,
+ ) -> Tuple["Revision", ...]:
"""
Parse upgrade command syntax :target to retrieve the target revision
and given the :current_revisons stamp of the database.
@@ -1070,9 +1194,8 @@ class RevisionMap:
current_revisions = util.to_tuple(current_revisions)
- branch_label, symbol, relative = match.groups()
- relative_str = relative
- relative = int(relative)
+ branch_label, symbol, relative_str = match.groups()
+ relative = int(relative_str)
if relative > 0:
if symbol is None:
if not current_revisions:
@@ -1151,8 +1274,13 @@ class RevisionMap:
)
def _collect_downgrade_revisions(
- self, upper, target, inclusive, implicit_base, assert_relative_length
- ):
+ self,
+ upper: _RevisionIdentifierType,
+ target: _RevisionIdentifierType,
+ inclusive: bool,
+ implicit_base: bool,
+ assert_relative_length: bool,
+ ) -> Any:
"""
Compute the set of current revisions specified by :upper, and the
downgrade target specified by :target. Return all dependents of target
@@ -1244,8 +1372,13 @@ class RevisionMap:
return downgrade_revisions, heads
def _collect_upgrade_revisions(
- self, upper, lower, inclusive, implicit_base, assert_relative_length
- ):
+ self,
+ upper: _RevisionIdentifierType,
+ lower: _RevisionIdentifierType,
+ inclusive: bool,
+ implicit_base: bool,
+ assert_relative_length: bool,
+ ) -> Tuple[Set["Revision"], Tuple[Optional[_RevisionOrBase]]]:
"""
Compute the set of required revisions specified by :upper, and the
current set of active revisions specified by :lower. Find the
@@ -1257,14 +1390,13 @@ class RevisionMap:
of the current/lower revisions. Dependencies from branches with
different bases will not be included.
"""
- targets = self._parse_upgrade_target(
+ targets: Collection["Revision"] = self._parse_upgrade_target(
current_revisions=lower,
target=upper,
assert_relative_length=assert_relative_length,
)
- assert targets is not None
- assert type(targets) is tuple, "targets should be a tuple"
+ # assert type(targets) is tuple, "targets should be a tuple"
# Handled named bases (e.g. branch@... -> heads should only produce
# targets on the given branch)
@@ -1332,7 +1464,7 @@ class RevisionMap:
)
needs.intersection_update(lower_descendents)
- return needs, targets
+ return needs, tuple(targets) # type:ignore[return-value]
class Revision:
@@ -1346,15 +1478,15 @@ class Revision:
"""
- nextrev = frozenset()
+ nextrev: FrozenSet[str] = frozenset()
"""following revisions, based on down_revision only."""
- _all_nextrev = frozenset()
+ _all_nextrev: FrozenSet[str] = frozenset()
- revision = None
+ revision: str = None # type: ignore[assignment]
"""The string revision number."""
- down_revision = None
+ down_revision: Optional[_RevIdType] = None
"""The ``down_revision`` identifier(s) within the migration script.
Note that the total set of "down" revisions is
@@ -1362,7 +1494,7 @@ class Revision:
"""
- dependencies = None
+ dependencies: Optional[_RevIdType] = None
"""Additional revisions which this revision is dependent on.
From a migration standpoint, these dependencies are added to the
@@ -1372,12 +1504,15 @@ class Revision:
"""
- branch_labels = None
+ branch_labels: Set[str] = None # type: ignore[assignment]
"""Optional string/tuple of symbolic names to apply to this
revision's branch"""
+ _resolved_dependencies: Tuple[str, ...]
+ _normalized_resolved_dependencies: Tuple[str, ...]
+
@classmethod
- def verify_rev_id(cls, revision):
+ def verify_rev_id(cls, revision: str) -> None:
illegal_chars = set(revision).intersection(_revision_illegal_chars)
if illegal_chars:
raise RevisionError(
@@ -1386,8 +1521,12 @@ class Revision:
)
def __init__(
- self, revision, down_revision, dependencies=None, branch_labels=None
- ):
+ self,
+ revision: str,
+ down_revision: Optional[Union[str, Tuple[str, ...]]],
+ dependencies: Optional[Tuple[str, ...]] = None,
+ branch_labels: Optional[Tuple[str, ...]] = None,
+ ) -> None:
if down_revision and revision in util.to_tuple(down_revision):
raise LoopDetected(revision)
elif dependencies is not None and revision in util.to_tuple(
@@ -1402,7 +1541,7 @@ class Revision:
self._orig_branch_labels = util.to_tuple(branch_labels, default=())
self.branch_labels = set(self._orig_branch_labels)
- def __repr__(self):
+ def __repr__(self) -> str:
args = [repr(self.revision), repr(self.down_revision)]
if self.dependencies:
args.append("dependencies=%r" % (self.dependencies,))
@@ -1410,20 +1549,20 @@ class Revision:
args.append("branch_labels=%r" % (self.branch_labels,))
return "%s(%s)" % (self.__class__.__name__, ", ".join(args))
- def add_nextrev(self, revision):
+ def add_nextrev(self, revision: "Revision") -> None:
self._all_nextrev = self._all_nextrev.union([revision.revision])
if self.revision in revision._versioned_down_revisions:
self.nextrev = self.nextrev.union([revision.revision])
@property
- def _all_down_revisions(self):
+ def _all_down_revisions(self) -> Tuple[str, ...]:
return util.dedupe_tuple(
util.to_tuple(self.down_revision, default=())
+ self._resolved_dependencies
)
@property
- def _normalized_down_revisions(self):
+ def _normalized_down_revisions(self) -> Tuple[str, ...]:
"""return immediate down revisions for a rev, omitting dependencies
that are still dependencies of ancestors.
@@ -1434,11 +1573,11 @@ class Revision:
)
@property
- def _versioned_down_revisions(self):
+ def _versioned_down_revisions(self) -> Tuple[str, ...]:
return util.to_tuple(self.down_revision, default=())
@property
- def is_head(self):
+ def is_head(self) -> bool:
"""Return True if this :class:`.Revision` is a 'head' revision.
This is determined based on whether any other :class:`.Script`
@@ -1449,17 +1588,17 @@ class Revision:
return not bool(self.nextrev)
@property
- def _is_real_head(self):
+ def _is_real_head(self) -> bool:
return not bool(self._all_nextrev)
@property
- def is_base(self):
+ def is_base(self) -> bool:
"""Return True if this :class:`.Revision` is a 'base' revision."""
return self.down_revision is None
@property
- def _is_real_base(self):
+ def _is_real_base(self) -> bool:
"""Return True if this :class:`.Revision` is a "real" base revision,
e.g. that it has no dependencies either."""
@@ -1469,7 +1608,7 @@ class Revision:
return self.down_revision is None and self.dependencies is None
@property
- def is_branch_point(self):
+ def is_branch_point(self) -> bool:
"""Return True if this :class:`.Script` is a branch point.
A branchpoint is defined as a :class:`.Script` which is referred
@@ -1481,7 +1620,7 @@ class Revision:
return len(self.nextrev) > 1
@property
- def _is_real_branch_point(self):
+ def _is_real_branch_point(self) -> bool:
"""Return True if this :class:`.Script` is a 'real' branch point,
taking into account dependencies as well.
@@ -1489,13 +1628,15 @@ class Revision:
return len(self._all_nextrev) > 1
@property
- def is_merge_point(self):
+ def is_merge_point(self) -> bool:
"""Return True if this :class:`.Script` is a merge point."""
return len(self._versioned_down_revisions) > 1
-def tuple_rev_as_scalar(rev):
+def tuple_rev_as_scalar(
+ rev: Optional[Sequence[str]],
+) -> Optional[Union[str, Sequence[str]]]:
if not rev:
return None
elif len(rev) == 1:
diff --git a/alembic/script/write_hooks.py b/alembic/script/write_hooks.py
index 8cd3dcc..8f9e35e 100644
--- a/alembic/script/write_hooks.py
+++ b/alembic/script/write_hooks.py
@@ -1,6 +1,11 @@
import shlex
import subprocess
import sys
+from typing import Any
+from typing import Callable
+from typing import Dict
+from typing import List
+from typing import Union
from .. import util
from ..util import compat
@@ -11,7 +16,7 @@ REVISION_SCRIPT_TOKEN = "REVISION_SCRIPT_FILENAME"
_registry = {}
-def register(name):
+def register(name: str) -> Callable:
"""A function decorator that will register that function as a write hook.
See the documentation linked below for an example.
@@ -31,7 +36,9 @@ def register(name):
return decorate
-def _invoke(name, revision, options):
+def _invoke(
+ name: str, revision: str, options: Dict[str, Union[str, int]]
+) -> Any:
"""Invokes the formatter registered for the given name.
:param name: The name of a formatter in the registry
@@ -50,7 +57,7 @@ def _invoke(name, revision, options):
return hook(revision, options)
-def _run_hooks(path, hook_config):
+def _run_hooks(path: str, hook_config: Dict[str, str]) -> None:
"""Invoke hooks for a generated revision."""
from .base import _split_on_space_comma
@@ -83,7 +90,7 @@ def _run_hooks(path, hook_config):
)
-def _parse_cmdline_options(cmdline_options_str, path):
+def _parse_cmdline_options(cmdline_options_str: str, path: str) -> List[str]:
"""Parse options from a string into a list.
Also substitutes the revision script token with the actual filename of