summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCaselIT <cfederico87@gmail.com>2021-04-18 15:44:50 +0200
committerMike Bayer <mike_mp@zzzcomputing.com>2021-08-11 15:04:56 -0400
commit6aad68605f510e8b51f42efa812e02b3831d6e33 (patch)
treecc0e98b8ad8245add8692d8e4910faf57abf7ae3
parent3bf6a326c0a11e4f05c94008709d6b0b8e9e051a (diff)
downloadalembic-6aad68605f510e8b51f42efa812e02b3831d6e33.tar.gz
Add pep-484 type annotations
pep-484 type annotations have been added throughout the library. This should be helpful in providing Mypy and IDE support, however there is not full support for Alembic's dynamically modified "op" namespace as of yet; a future release will likely modify the approach used for importing this namespace to be better compatible with pep-484 capabilities. Type originally created using MonkeyType Add types extracted with the MonkeyType https://github.com/instagram/MonkeyType library by running the unit tests using ``monkeytype run -m pytest tests``, then ``monkeytype apply <module>`` (see below for further details). USed MonkeyType version 20.5 on Python 3.8, since newer version have issues After applying the types, the new imports are placed in a ``TYPE_CHECKING`` guard and all type definition of non base types are deferred by using the string notation. NOTE: since to apply the types MonkeType need to import the module, also the test ones, the patch below mocks the setup done by pytest so that the tests could be correctly imported diff --git a/alembic/testing/__init__.py b/alembic/testing/__init__.py index bdd1746..b1090c7 100644 Change-Id: Iff93628f4b43c740848871ce077a118db5e75d41 --- a/alembic/testing/__init__.py +++ b/alembic/testing/__init__.py @@ -9,6 +9,12 @@ from sqlalchemy.testing.config import combinations from sqlalchemy.testing.config import fixture from sqlalchemy.testing.config import requirements as requires +from sqlalchemy.testing.plugin.pytestplugin import PytestFixtureFunctions +from sqlalchemy.testing.plugin.plugin_base import _setup_requirements + +config._fixture_functions = PytestFixtureFunctions() +_setup_requirements("tests.requirements:DefaultRequirements") + from alembic import util from .assertions import assert_raises from .assertions import assert_raises_message Currently I'm using this branch of the sqlalchemy stubs: https://github.com/sqlalchemy/sqlalchemy2-stubs/tree/alembic_updates Change-Id: I8fd0700aab1913f395302626b8b84fea60334abd
-rw-r--r--alembic/__init__.py5
-rw-r--r--alembic/autogenerate/api.py161
-rw-r--r--alembic/autogenerate/compare.py341
-rw-r--r--alembic/autogenerate/render.py301
-rw-r--r--alembic/autogenerate/rewriter.py92
-rw-r--r--alembic/command.py95
-rw-r--r--alembic/config.py73
-rw-r--r--alembic/ddl/base.py163
-rw-r--r--alembic/ddl/impl.py239
-rw-r--r--alembic/ddl/mssql.py123
-rw-r--r--alembic/ddl/mysql.py118
-rw-r--r--alembic/ddl/oracle.py57
-rw-r--r--alembic/ddl/postgresql.py202
-rw-r--r--alembic/ddl/sqlite.py65
-rw-r--r--alembic/environment.py1
-rw-r--r--alembic/migration.py1
-rw-r--r--alembic/operations/base.py59
-rw-r--r--alembic/operations/batch.py168
-rw-r--r--alembic/operations/ops.py822
-rw-r--r--alembic/operations/schemaobj.py124
-rw-r--r--alembic/operations/toimpl.py55
-rw-r--r--alembic/runtime/environment.py128
-rw-r--r--alembic/runtime/migration.py327
-rw-r--r--alembic/script/base.py257
-rw-r--r--alembic/script/revision.py409
-rw-r--r--alembic/script/write_hooks.py15
-rw-r--r--alembic/testing/assertions.py6
-rw-r--r--alembic/testing/fixtures.py6
-rw-r--r--alembic/testing/requirements.py8
-rw-r--r--alembic/testing/suite/_autogen_fixtures.py5
-rw-r--r--alembic/util/compat.py42
-rw-r--r--alembic/util/editor.py20
-rw-r--r--alembic/util/langhelpers.py66
-rw-r--r--alembic/util/messaging.py23
-rw-r--r--alembic/util/pyfiles.py27
-rw-r--r--alembic/util/sqla_compat.py149
-rw-r--r--docs/build/unreleased/py3_typing.rst8
-rw-r--r--setup.cfg9
-rw-r--r--tests/test_revision.py32
-rw-r--r--tox.ini13
40 files changed, 3302 insertions, 1513 deletions
diff --git a/alembic/__init__.py b/alembic/__init__.py
index 0820de0..023fd06 100644
--- a/alembic/__init__.py
+++ b/alembic/__init__.py
@@ -2,10 +2,5 @@ import sys
from . import context
from . import op
-from .runtime import environment
-from .runtime import migration
__version__ = "1.7.0"
-
-sys.modules["alembic.migration"] = migration
-sys.modules["alembic.environment"] = environment
diff --git a/alembic/autogenerate/api.py b/alembic/autogenerate/api.py
index 4c156c4..3b23dcd 100644
--- a/alembic/autogenerate/api.py
+++ b/alembic/autogenerate/api.py
@@ -2,6 +2,15 @@
automatically."""
import contextlib
+from typing import Any
+from typing import Callable
+from typing import Dict
+from typing import Iterator
+from typing import Optional
+from typing import Set
+from typing import Tuple
+from typing import TYPE_CHECKING
+from typing import Union
from sqlalchemy import inspect
@@ -10,8 +19,26 @@ from . import render
from .. import util
from ..operations import ops
-
-def compare_metadata(context, metadata):
+if TYPE_CHECKING:
+ from sqlalchemy.engine import Connection
+ from sqlalchemy.engine import Dialect
+ from sqlalchemy.engine import Inspector
+ from sqlalchemy.sql.schema import Column
+ from sqlalchemy.sql.schema import ForeignKeyConstraint
+ from sqlalchemy.sql.schema import Index
+ from sqlalchemy.sql.schema import MetaData
+ from sqlalchemy.sql.schema import Table
+ from sqlalchemy.sql.schema import UniqueConstraint
+
+ from alembic.config import Config
+ from alembic.operations.ops import MigrationScript
+ from alembic.operations.ops import UpgradeOps
+ from alembic.runtime.migration import MigrationContext
+ from alembic.script.base import Script
+ from alembic.script.base import ScriptDirectory
+
+
+def compare_metadata(context: "MigrationContext", metadata: "MetaData") -> Any:
"""Compare a database schema to that given in a
:class:`~sqlalchemy.schema.MetaData` instance.
@@ -106,7 +133,9 @@ def compare_metadata(context, metadata):
return migration_script.upgrade_ops.as_diffs()
-def produce_migrations(context, metadata):
+def produce_migrations(
+ context: "MigrationContext", metadata: "MetaData"
+) -> "MigrationScript":
"""Produce a :class:`.MigrationScript` structure based on schema
comparison.
@@ -136,14 +165,14 @@ def produce_migrations(context, metadata):
def render_python_code(
- up_or_down_op,
- sqlalchemy_module_prefix="sa.",
- alembic_module_prefix="op.",
- render_as_batch=False,
- imports=(),
- render_item=None,
- migration_context=None,
-):
+ up_or_down_op: "UpgradeOps",
+ sqlalchemy_module_prefix: str = "sa.",
+ alembic_module_prefix: str = "op.",
+ render_as_batch: bool = False,
+ imports: Tuple[str, ...] = (),
+ render_item: None = None,
+ migration_context: Optional["MigrationContext"] = None,
+) -> str:
"""Render Python code given an :class:`.UpgradeOps` or
:class:`.DowngradeOps` object.
@@ -173,7 +202,9 @@ def render_python_code(
)
-def _render_migration_diffs(context, template_args):
+def _render_migration_diffs(
+ context: "MigrationContext", template_args: Dict[Any, Any]
+) -> None:
"""legacy, used by test_autogen_composition at the moment"""
autogen_context = AutogenContext(context)
@@ -196,7 +227,7 @@ class AutogenContext:
"""Maintains configuration and state that's specific to an
autogenerate operation."""
- metadata = None
+ metadata: Optional["MetaData"] = None
"""The :class:`~sqlalchemy.schema.MetaData` object
representing the destination.
@@ -214,7 +245,7 @@ class AutogenContext:
"""
- connection = None
+ connection: Optional["Connection"] = None
"""The :class:`~sqlalchemy.engine.base.Connection` object currently
connected to the database backend being compared.
@@ -223,7 +254,7 @@ class AutogenContext:
"""
- dialect = None
+ dialect: Optional["Dialect"] = None
"""The :class:`~sqlalchemy.engine.Dialect` object currently in use.
This is normally obtained from the
@@ -231,7 +262,7 @@ class AutogenContext:
"""
- imports = None
+ imports: Set[str] = None # type: ignore[assignment]
"""A ``set()`` which contains string Python import directives.
The directives are to be rendered into the ``${imports}`` section
@@ -245,12 +276,16 @@ class AutogenContext:
"""
- migration_context = None
+ migration_context: "MigrationContext" = None # type: ignore[assignment]
"""The :class:`.MigrationContext` established by the ``env.py`` script."""
def __init__(
- self, migration_context, metadata=None, opts=None, autogenerate=True
- ):
+ self,
+ migration_context: "MigrationContext",
+ metadata: Optional["MetaData"] = None,
+ opts: Optional[dict] = None,
+ autogenerate: bool = True,
+ ) -> None:
if (
autogenerate
@@ -301,20 +336,25 @@ class AutogenContext:
self.dialect = self.migration_context.dialect
self.imports = set()
- self.opts = opts
- self._has_batch = False
+ self.opts: Dict[str, Any] = opts
+ self._has_batch: bool = False
@util.memoized_property
- def inspector(self):
+ def inspector(self) -> "Inspector":
return inspect(self.connection)
@contextlib.contextmanager
- def _within_batch(self):
+ def _within_batch(self) -> Iterator[None]:
self._has_batch = True
yield
self._has_batch = False
- def run_name_filters(self, name, type_, parent_names):
+ def run_name_filters(
+ self,
+ name: Optional[str],
+ type_: str,
+ parent_names: Dict[str, Optional[str]],
+ ) -> bool:
"""Run the context's name filters and return True if the targets
should be part of the autogenerate operation.
@@ -348,7 +388,22 @@ class AutogenContext:
else:
return True
- def run_object_filters(self, object_, name, type_, reflected, compare_to):
+ def run_object_filters(
+ self,
+ object_: Union[
+ "Table",
+ "Index",
+ "Column",
+ "UniqueConstraint",
+ "ForeignKeyConstraint",
+ ],
+ name: Optional[str],
+ type_: str,
+ reflected: bool,
+ compare_to: Optional[
+ Union["Table", "Index", "Column", "UniqueConstraint"]
+ ],
+ ) -> bool:
"""Run the context's object filters and return True if the targets
should be part of the autogenerate operation.
@@ -414,11 +469,11 @@ class RevisionContext:
def __init__(
self,
- config,
- script_directory,
- command_args,
- process_revision_directives=None,
- ):
+ config: "Config",
+ script_directory: "ScriptDirectory",
+ command_args: Dict[str, Any],
+ process_revision_directives: Optional[Callable] = None,
+ ) -> None:
self.config = config
self.script_directory = script_directory
self.command_args = command_args
@@ -429,10 +484,10 @@ class RevisionContext:
}
self.generated_revisions = [self._default_revision()]
- def _to_script(self, migration_script):
- template_args = {}
- for k, v in self.template_args.items():
- template_args.setdefault(k, v)
+ def _to_script(
+ self, migration_script: "MigrationScript"
+ ) -> Optional["Script"]:
+ template_args: Dict[str, Any] = self.template_args.copy()
if getattr(migration_script, "_needs_render", False):
autogen_context = self._last_autogen_context
@@ -446,6 +501,7 @@ class RevisionContext:
autogen_context, migration_script, template_args
)
+ assert migration_script.rev_id is not None
return self.script_directory.generate_revision(
migration_script.rev_id,
migration_script.message,
@@ -458,13 +514,22 @@ class RevisionContext:
**template_args
)
- def run_autogenerate(self, rev, migration_context):
+ def run_autogenerate(
+ self, rev: tuple, migration_context: "MigrationContext"
+ ):
self._run_environment(rev, migration_context, True)
- def run_no_autogenerate(self, rev, migration_context):
+ def run_no_autogenerate(
+ self, rev: tuple, migration_context: "MigrationContext"
+ ):
self._run_environment(rev, migration_context, False)
- def _run_environment(self, rev, migration_context, autogenerate):
+ def _run_environment(
+ self,
+ rev: tuple,
+ migration_context: "MigrationContext",
+ autogenerate: bool,
+ ):
if autogenerate:
if self.command_args["sql"]:
raise util.CommandError(
@@ -493,9 +558,10 @@ class RevisionContext:
ops.DowngradeOps([], downgrade_token=downgrade_token)
)
- self._last_autogen_context = autogen_context = AutogenContext(
+ autogen_context = AutogenContext(
migration_context, autogenerate=autogenerate
)
+ self._last_autogen_context: AutogenContext = autogen_context
if autogenerate:
compare._populate_migration_script(
@@ -514,20 +580,21 @@ class RevisionContext:
for migration_script in self.generated_revisions:
migration_script._needs_render = True
- def _default_revision(self):
+ def _default_revision(self) -> "MigrationScript":
+ command_args: Dict[str, Any] = self.command_args
op = ops.MigrationScript(
- rev_id=self.command_args["rev_id"] or util.rev_id(),
- message=self.command_args["message"],
+ rev_id=command_args["rev_id"] or util.rev_id(),
+ message=command_args["message"],
upgrade_ops=ops.UpgradeOps([]),
downgrade_ops=ops.DowngradeOps([]),
- head=self.command_args["head"],
- splice=self.command_args["splice"],
- branch_label=self.command_args["branch_label"],
- version_path=self.command_args["version_path"],
- depends_on=self.command_args["depends_on"],
+ head=command_args["head"],
+ splice=command_args["splice"],
+ branch_label=command_args["branch_label"],
+ version_path=command_args["version_path"],
+ depends_on=command_args["depends_on"],
)
return op
- def generate_scripts(self):
+ def generate_scripts(self) -> Iterator[Optional["Script"]]:
for generated_revision in self.generated_revisions:
yield self._to_script(generated_revision)
diff --git a/alembic/autogenerate/compare.py b/alembic/autogenerate/compare.py
index dbb0706..528b17a 100644
--- a/alembic/autogenerate/compare.py
+++ b/alembic/autogenerate/compare.py
@@ -1,6 +1,16 @@
import contextlib
import logging
import re
+from typing import Any
+from typing import cast
+from typing import Dict
+from typing import Iterator
+from typing import List
+from typing import Optional
+from typing import Set
+from typing import Tuple
+from typing import TYPE_CHECKING
+from typing import Union
from sqlalchemy import event
from sqlalchemy import inspect
@@ -14,10 +24,29 @@ from .. import util
from ..operations import ops
from ..util import sqla_compat
+if TYPE_CHECKING:
+ from typing import Literal
+
+ from sqlalchemy.engine.reflection import Inspector
+ from sqlalchemy.sql.elements import quoted_name
+ from sqlalchemy.sql.schema import Column
+ from sqlalchemy.sql.schema import ForeignKeyConstraint
+ from sqlalchemy.sql.schema import Index
+ from sqlalchemy.sql.schema import Table
+ from sqlalchemy.sql.schema import UniqueConstraint
+
+ from alembic.autogenerate.api import AutogenContext
+ from alembic.operations.ops import AlterColumnOp
+ from alembic.operations.ops import MigrationScript
+ from alembic.operations.ops import ModifyTableOps
+ from alembic.operations.ops import UpgradeOps
+
log = logging.getLogger(__name__)
-def _populate_migration_script(autogen_context, migration_script):
+def _populate_migration_script(
+ autogen_context: "AutogenContext", migration_script: "MigrationScript"
+) -> None:
upgrade_ops = migration_script.upgrade_ops_list[-1]
downgrade_ops = migration_script.downgrade_ops_list[-1]
@@ -28,14 +57,18 @@ def _populate_migration_script(autogen_context, migration_script):
comparators = util.Dispatcher(uselist=True)
-def _produce_net_changes(autogen_context, upgrade_ops):
+def _produce_net_changes(
+ autogen_context: "AutogenContext", upgrade_ops: "UpgradeOps"
+) -> None:
connection = autogen_context.connection
+ assert connection is not None
include_schemas = autogen_context.opts.get("include_schemas", False)
- inspector = inspect(connection)
+ inspector: "Inspector" = inspect(connection)
default_schema = connection.dialect.default_schema_name
+ schemas: Set[Optional[str]]
if include_schemas:
schemas = set(inspector.get_schema_names())
# replace default schema name with None
@@ -44,22 +77,27 @@ def _produce_net_changes(autogen_context, upgrade_ops):
schemas.discard(default_schema)
schemas.add(None)
else:
- schemas = [None]
+ schemas = {None}
schemas = {
s for s in schemas if autogen_context.run_name_filters(s, "schema", {})
}
+ assert autogen_context.dialect is not None
comparators.dispatch("schema", autogen_context.dialect.name)(
autogen_context, upgrade_ops, schemas
)
@comparators.dispatch_for("schema")
-def _autogen_for_tables(autogen_context, upgrade_ops, schemas):
+def _autogen_for_tables(
+ autogen_context: "AutogenContext",
+ upgrade_ops: "UpgradeOps",
+ schemas: Union[Set[None], Set[Optional[str]]],
+) -> None:
inspector = autogen_context.inspector
- conn_table_names = set()
+ conn_table_names: Set[Tuple[Optional[str], str]] = set()
version_table_schema = (
autogen_context.migration_context.version_table_schema
@@ -95,12 +133,12 @@ def _autogen_for_tables(autogen_context, upgrade_ops, schemas):
def _compare_tables(
- conn_table_names,
- metadata_table_names,
- inspector,
- upgrade_ops,
- autogen_context,
-):
+ conn_table_names: "set",
+ metadata_table_names: "set",
+ inspector: "Inspector",
+ upgrade_ops: "UpgradeOps",
+ autogen_context: "AutogenContext",
+) -> None:
default_schema = inspector.bind.dialect.default_schema_name
@@ -239,7 +277,7 @@ def _compare_tables(
upgrade_ops.ops.append(modify_table_ops)
-def _make_index(params, conn_table):
+def _make_index(params: Dict[str, Any], conn_table: "Table") -> "Index":
ix = sa_schema.Index(
params["name"],
*[conn_table.c[cname] for cname in params["column_names"]],
@@ -251,7 +289,9 @@ def _make_index(params, conn_table):
return ix
-def _make_unique_constraint(params, conn_table):
+def _make_unique_constraint(
+ params: Dict[str, Any], conn_table: "Table"
+) -> "UniqueConstraint":
uq = sa_schema.UniqueConstraint(
*[conn_table.c[cname] for cname in params["column_names"]],
name=params["name"]
@@ -262,7 +302,9 @@ def _make_unique_constraint(params, conn_table):
return uq
-def _make_foreign_key(params, conn_table):
+def _make_foreign_key(
+ params: Dict[str, Any], conn_table: "Table"
+) -> "ForeignKeyConstraint":
tname = params["referred_table"]
if params["referred_schema"]:
tname = "%s.%s" % (params["referred_schema"], tname)
@@ -285,14 +327,14 @@ def _make_foreign_key(params, conn_table):
@contextlib.contextmanager
def _compare_columns(
- schema,
- tname,
- conn_table,
- metadata_table,
- modify_table_ops,
- autogen_context,
- inspector,
-):
+ schema: Optional[str],
+ tname: Union["quoted_name", str],
+ conn_table: "Table",
+ metadata_table: "Table",
+ modify_table_ops: "ModifyTableOps",
+ autogen_context: "AutogenContext",
+ inspector: "Inspector",
+) -> Iterator[None]:
name = "%s.%s" % (schema, tname) if schema else tname
metadata_col_names = OrderedSet(
c.name for c in metadata_table.c if not c.system
@@ -357,7 +399,9 @@ def _compare_columns(
class _constraint_sig:
- def md_name_to_sql_name(self, context):
+ const: Union["UniqueConstraint", "ForeignKeyConstraint", "Index"]
+
+ def md_name_to_sql_name(self, context: "AutogenContext") -> Optional[str]:
return sqla_compat._get_constraint_final_name(
self.const, context.dialect
)
@@ -368,7 +412,7 @@ class _constraint_sig:
def __ne__(self, other):
return self.const != other.const
- def __hash__(self):
+ def __hash__(self) -> int:
return hash(self.const)
@@ -376,37 +420,39 @@ class _uq_constraint_sig(_constraint_sig):
is_index = False
is_unique = True
- def __init__(self, const):
+ def __init__(self, const: "UniqueConstraint") -> None:
self.const = const
self.name = const.name
self.sig = tuple(sorted([col.name for col in const.columns]))
@property
- def column_names(self):
+ def column_names(self) -> List[str]:
return [col.name for col in self.const.columns]
class _ix_constraint_sig(_constraint_sig):
is_index = True
- def __init__(self, const):
+ def __init__(self, const: "Index") -> None:
self.const = const
self.name = const.name
self.sig = tuple(sorted([col.name for col in const.columns]))
self.is_unique = bool(const.unique)
- def md_name_to_sql_name(self, context):
+ def md_name_to_sql_name(self, context: "AutogenContext") -> Optional[str]:
return sqla_compat._get_constraint_final_name(
self.const, context.dialect
)
@property
- def column_names(self):
+ def column_names(self) -> Union[List["quoted_name"], List[None]]:
return sqla_compat._get_index_column_names(self.const)
class _fk_constraint_sig(_constraint_sig):
- def __init__(self, const, include_options=False):
+ def __init__(
+ self, const: "ForeignKeyConstraint", include_options: bool = False
+ ) -> None:
self.const = const
self.name = const.name
@@ -423,7 +469,7 @@ class _fk_constraint_sig(_constraint_sig):
initially,
) = _fk_spec(const)
- self.sig = (
+ self.sig: Tuple[Any, ...] = (
self.source_schema,
self.source_table,
tuple(self.source_columns),
@@ -450,8 +496,13 @@ class _fk_constraint_sig(_constraint_sig):
@comparators.dispatch_for("table")
def _compare_indexes_and_uniques(
- autogen_context, modify_ops, schema, tname, conn_table, metadata_table
-):
+ autogen_context: "AutogenContext",
+ modify_ops: "ModifyTableOps",
+ schema: Optional[str],
+ tname: Union["quoted_name", str],
+ conn_table: Optional["Table"],
+ metadata_table: Optional["Table"],
+) -> None:
inspector = autogen_context.inspector
is_create_table = conn_table is None
@@ -469,7 +520,7 @@ def _compare_indexes_and_uniques(
metadata_unique_constraints = set()
metadata_indexes = set()
- conn_uniques = conn_indexes = frozenset()
+ conn_uniques = conn_indexes = frozenset() # type:ignore[var-annotated]
supports_unique_constraints = False
@@ -479,7 +530,7 @@ def _compare_indexes_and_uniques(
# 1b. ... and from connection, if the table exists
if hasattr(inspector, "get_unique_constraints"):
try:
- conn_uniques = inspector.get_unique_constraints(
+ conn_uniques = inspector.get_unique_constraints( # type:ignore[assignment] # noqa
tname, schema=schema
)
supports_unique_constraints = True
@@ -491,7 +542,7 @@ def _compare_indexes_and_uniques(
# not being present
pass
else:
- conn_uniques = [
+ conn_uniques = [ # type:ignore[assignment]
uq
for uq in conn_uniques
if autogen_context.run_name_filters(
@@ -504,11 +555,13 @@ def _compare_indexes_and_uniques(
if uq.get("duplicates_index"):
unique_constraints_duplicate_unique_indexes = True
try:
- conn_indexes = inspector.get_indexes(tname, schema=schema)
+ conn_indexes = inspector.get_indexes( # type:ignore[assignment]
+ tname, schema=schema
+ )
except NotImplementedError:
pass
else:
- conn_indexes = [
+ conn_indexes = [ # type:ignore[assignment]
ix
for ix in conn_indexes
if autogen_context.run_name_filters(
@@ -522,14 +575,16 @@ def _compare_indexes_and_uniques(
# into schema objects
if is_drop_table:
# for DROP TABLE uniques are inline, don't need them
- conn_uniques = set()
+ conn_uniques = set() # type:ignore[assignment]
else:
- conn_uniques = set(
+ conn_uniques = set( # type:ignore[assignment]
_make_unique_constraint(uq_def, conn_table)
for uq_def in conn_uniques
)
- conn_indexes = set(_make_index(ix, conn_table) for ix in conn_indexes)
+ conn_indexes = set( # type:ignore[assignment]
+ _make_index(ix, conn_table) for ix in conn_indexes
+ )
# 2a. if the dialect dupes unique indexes as unique constraints
# (mysql and oracle), correct for that
@@ -557,31 +612,39 @@ def _compare_indexes_and_uniques(
# _constraint_sig() objects provide a consistent facade over both
# Index and UniqueConstraint so we can easily work with them
# interchangeably
- metadata_unique_constraints = set(
+ metadata_unique_constraints_sig = set(
_uq_constraint_sig(uq) for uq in metadata_unique_constraints
)
- metadata_indexes = set(_ix_constraint_sig(ix) for ix in metadata_indexes)
+ metadata_indexes_sig = set(
+ _ix_constraint_sig(ix) for ix in metadata_indexes
+ )
conn_unique_constraints = set(
_uq_constraint_sig(uq) for uq in conn_uniques
)
- conn_indexes = set(_ix_constraint_sig(ix) for ix in conn_indexes)
+ conn_indexes_sig = set(_ix_constraint_sig(ix) for ix in conn_indexes)
# 5. index things by name, for those objects that have names
metadata_names = dict(
- (c.md_name_to_sql_name(autogen_context), c)
- for c in metadata_unique_constraints.union(metadata_indexes)
+ (cast(str, c.md_name_to_sql_name(autogen_context)), c)
+ for c in metadata_unique_constraints_sig.union(
+ metadata_indexes_sig # type:ignore[arg-type]
+ )
if isinstance(c, _ix_constraint_sig)
or sqla_compat._constraint_is_named(c.const, autogen_context.dialect)
)
conn_uniques_by_name = dict((c.name, c) for c in conn_unique_constraints)
- conn_indexes_by_name = dict((c.name, c) for c in conn_indexes)
+ conn_indexes_by_name: Dict[Optional[str], _ix_constraint_sig] = dict(
+ (c.name, c) for c in conn_indexes_sig
+ )
conn_names = dict(
(c.name, c)
- for c in conn_unique_constraints.union(conn_indexes)
+ for c in conn_unique_constraints.union(
+ conn_indexes_sig # type:ignore[arg-type]
+ )
if c.name is not None
)
@@ -596,12 +659,12 @@ def _compare_indexes_and_uniques(
# constraints.
conn_uniques_by_sig = dict((uq.sig, uq) for uq in conn_unique_constraints)
metadata_uniques_by_sig = dict(
- (uq.sig, uq) for uq in metadata_unique_constraints
+ (uq.sig, uq) for uq in metadata_unique_constraints_sig
)
- metadata_indexes_by_sig = dict((ix.sig, ix) for ix in metadata_indexes)
+ metadata_indexes_by_sig = dict((ix.sig, ix) for ix in metadata_indexes_sig)
unnamed_metadata_uniques = dict(
(uq.sig, uq)
- for uq in metadata_unique_constraints
+ for uq in metadata_unique_constraints_sig
if not sqla_compat._constraint_is_named(
uq.const, autogen_context.dialect
)
@@ -709,7 +772,9 @@ def _compare_indexes_and_uniques(
)
for removed_name in sorted(set(conn_names).difference(metadata_names)):
- conn_obj = conn_names[removed_name]
+ conn_obj: Union[_ix_constraint_sig, _uq_constraint_sig] = conn_names[
+ removed_name
+ ]
if not conn_obj.is_index and conn_obj.sig in unnamed_metadata_uniques:
continue
elif removed_name in doubled_constraints:
@@ -831,14 +896,14 @@ def _correct_for_uq_duplicates_uix(
@comparators.dispatch_for("column")
def _compare_nullable(
- autogen_context,
- alter_column_op,
- schema,
- tname,
- cname,
- conn_col,
- metadata_col,
-):
+ autogen_context: "AutogenContext",
+ alter_column_op: "AlterColumnOp",
+ schema: Optional[str],
+ tname: Union["quoted_name", str],
+ cname: Union["quoted_name", str],
+ conn_col: "Column",
+ metadata_col: "Column",
+) -> None:
metadata_col_nullable = metadata_col.nullable
conn_col_nullable = conn_col.nullable
@@ -873,14 +938,14 @@ def _compare_nullable(
@comparators.dispatch_for("column")
def _setup_autoincrement(
- autogen_context,
- alter_column_op,
- schema,
- tname,
- cname,
- conn_col,
- metadata_col,
-):
+ autogen_context: "AutogenContext",
+ alter_column_op: "AlterColumnOp",
+ schema: Optional[str],
+ tname: Union["quoted_name", str],
+ cname: "quoted_name",
+ conn_col: "Column",
+ metadata_col: "Column",
+) -> None:
if metadata_col.table._autoincrement_column is metadata_col:
alter_column_op.kw["autoincrement"] = True
@@ -892,14 +957,14 @@ def _setup_autoincrement(
@comparators.dispatch_for("column")
def _compare_type(
- autogen_context,
- alter_column_op,
- schema,
- tname,
- cname,
- conn_col,
- metadata_col,
-):
+ autogen_context: "AutogenContext",
+ alter_column_op: "AlterColumnOp",
+ schema: Optional[str],
+ tname: Union["quoted_name", str],
+ cname: Union["quoted_name", str],
+ conn_col: "Column",
+ metadata_col: "Column",
+) -> None:
conn_type = conn_col.type
alter_column_op.existing_type = conn_type
@@ -935,8 +1000,10 @@ def _compare_type(
def _render_server_default_for_compare(
- metadata_default, metadata_col, autogen_context
-):
+ metadata_default: Optional[Any],
+ metadata_col: "Column",
+ autogen_context: "AutogenContext",
+) -> Optional[str]:
rendered = _user_defined_render(
"server_default", metadata_default, autogen_context
)
@@ -963,7 +1030,7 @@ def _render_server_default_for_compare(
return None
-def _normalize_computed_default(sqltext):
+def _normalize_computed_default(sqltext: str) -> str:
"""we want to warn if a computed sql expression has changed. however
we don't want false positives and the warning is not that critical.
so filter out most forms of variability from the SQL text.
@@ -974,16 +1041,16 @@ def _normalize_computed_default(sqltext):
def _compare_computed_default(
- autogen_context,
- alter_column_op,
- schema,
- tname,
- cname,
- conn_col,
- metadata_col,
-):
+ autogen_context: "AutogenContext",
+ alter_column_op: "AlterColumnOp",
+ schema: Optional[str],
+ tname: "str",
+ cname: "str",
+ conn_col: "Column",
+ metadata_col: "Column",
+) -> None:
rendered_metadata_default = str(
- metadata_col.server_default.sqltext.compile(
+ cast(sa_schema.Computed, metadata_col.server_default).sqltext.compile(
dialect=autogen_context.dialect,
compile_kwargs={"literal_binds": True},
)
@@ -1017,7 +1084,7 @@ def _compare_computed_default(
_warn_computed_not_supported(tname, cname)
-def _warn_computed_not_supported(tname, cname):
+def _warn_computed_not_supported(tname: str, cname: str) -> None:
util.warn("Computed default on %s.%s cannot be modified" % (tname, cname))
@@ -1040,14 +1107,14 @@ def _compare_identity_default(
@comparators.dispatch_for("column")
def _compare_server_default(
- autogen_context,
- alter_column_op,
- schema,
- tname,
- cname,
- conn_col,
- metadata_col,
-):
+ autogen_context: "AutogenContext",
+ alter_column_op: "AlterColumnOp",
+ schema: Optional[str],
+ tname: Union["quoted_name", str],
+ cname: Union["quoted_name", str],
+ conn_col: "Column",
+ metadata_col: "Column",
+) -> Optional[bool]:
metadata_default = metadata_col.server_default
conn_col_default = conn_col.server_default
@@ -1065,14 +1132,16 @@ def _compare_server_default(
return False
else:
- return _compare_computed_default(
- autogen_context,
- alter_column_op,
- schema,
- tname,
- cname,
- conn_col,
- metadata_col,
+ return (
+ _compare_computed_default( # type:ignore[func-returns-value]
+ autogen_context,
+ alter_column_op,
+ schema,
+ tname,
+ cname,
+ conn_col,
+ metadata_col,
+ )
)
if sqla_compat._server_default_is_computed(conn_col_default):
_warn_computed_not_supported(tname, cname)
@@ -1107,7 +1176,7 @@ def _compare_server_default(
)
rendered_conn_default = (
- conn_col_default.arg.text if conn_col_default else None
+ cast(Any, conn_col_default).arg.text if conn_col_default else None
)
alter_column_op.existing_server_default = conn_col_default
@@ -1122,20 +1191,23 @@ def _compare_server_default(
alter_column_op.modify_server_default = metadata_default
log.info("Detected server default on column '%s.%s'", tname, cname)
+ return None
+
@comparators.dispatch_for("column")
def _compare_column_comment(
- autogen_context,
- alter_column_op,
- schema,
- tname,
- cname,
- conn_col,
- metadata_col,
-):
-
+ autogen_context: "AutogenContext",
+ alter_column_op: "AlterColumnOp",
+ schema: Optional[str],
+ tname: Union["quoted_name", str],
+ cname: "quoted_name",
+ conn_col: "Column",
+ metadata_col: "Column",
+) -> Optional["Literal[False]"]:
+
+ assert autogen_context.dialect is not None
if not autogen_context.dialect.supports_comments:
- return
+ return None
metadata_comment = metadata_col.comment
conn_col_comment = conn_col.comment
@@ -1148,16 +1220,18 @@ def _compare_column_comment(
alter_column_op.modify_comment = metadata_comment
log.info("Detected column comment '%s.%s'", tname, cname)
+ return None
+
@comparators.dispatch_for("table")
def _compare_foreign_keys(
- autogen_context,
- modify_table_ops,
- schema,
- tname,
- conn_table,
- metadata_table,
-):
+ autogen_context: "AutogenContext",
+ modify_table_ops: "ModifyTableOps",
+ schema: Optional[str],
+ tname: Union["quoted_name", str],
+ conn_table: Optional["Table"],
+ metadata_table: Optional["Table"],
+) -> None:
# if we're doing CREATE TABLE, all FKs are created
# inline within the table def
@@ -1181,7 +1255,7 @@ def _compare_foreign_keys(
)
]
- backend_reflects_fk_options = conn_fks and "options" in conn_fks[0]
+ backend_reflects_fk_options = bool(conn_fks and "options" in conn_fks[0])
conn_fks = set(_make_foreign_key(const, conn_table) for const in conn_fks)
@@ -1268,14 +1342,15 @@ def _compare_foreign_keys(
@comparators.dispatch_for("table")
def _compare_table_comment(
- autogen_context,
- modify_table_ops,
- schema,
- tname,
- conn_table,
- metadata_table,
-):
-
+ autogen_context: "AutogenContext",
+ modify_table_ops: "ModifyTableOps",
+ schema: Optional[str],
+ tname: Union["quoted_name", str],
+ conn_table: Optional["Table"],
+ metadata_table: Optional["Table"],
+) -> None:
+
+ assert autogen_context.dialect is not None
if not autogen_context.dialect.supports_comments:
return
diff --git a/alembic/autogenerate/render.py b/alembic/autogenerate/render.py
index 490d65c..90d49e5 100644
--- a/alembic/autogenerate/render.py
+++ b/alembic/autogenerate/render.py
@@ -1,11 +1,20 @@
from collections import OrderedDict
from io import StringIO
import re
+from typing import Any
+from typing import cast
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import TYPE_CHECKING
+from typing import Union
from mako.pygen import PythonPrinter
from sqlalchemy import schema as sa_schema
from sqlalchemy import sql
from sqlalchemy import types as sqltypes
+from sqlalchemy.sql.elements import conv
from .. import util
from ..operations import ops
@@ -13,34 +22,59 @@ from ..util import compat
from ..util import sqla_compat
from ..util.compat import string_types
+if TYPE_CHECKING:
+ from typing import Literal
+
+ from sqlalchemy.sql.elements import ColumnElement
+ from sqlalchemy.sql.elements import quoted_name
+ from sqlalchemy.sql.elements import TextClause
+ from sqlalchemy.sql.schema import CheckConstraint
+ from sqlalchemy.sql.schema import Column
+ from sqlalchemy.sql.schema import Constraint
+ from sqlalchemy.sql.schema import DefaultClause
+ from sqlalchemy.sql.schema import FetchedValue
+ from sqlalchemy.sql.schema import ForeignKey
+ from sqlalchemy.sql.schema import ForeignKeyConstraint
+ from sqlalchemy.sql.schema import Index
+ from sqlalchemy.sql.schema import MetaData
+ from sqlalchemy.sql.schema import PrimaryKeyConstraint
+ from sqlalchemy.sql.schema import UniqueConstraint
+ from sqlalchemy.sql.sqltypes import ARRAY
+ from sqlalchemy.sql.type_api import TypeEngine
+ from sqlalchemy.sql.type_api import Variant
+
+ from alembic.autogenerate.api import AutogenContext
+ from alembic.config import Config
+ from alembic.operations.ops import MigrationScript
+ from alembic.operations.ops import ModifyTableOps
+ from alembic.util.sqla_compat import Computed
+ from alembic.util.sqla_compat import Identity
-MAX_PYTHON_ARGS = 255
-
-try:
- from sqlalchemy.sql.naming import conv
-
- def _render_gen_name(autogen_context, name):
- if isinstance(name, conv):
- return _f_name(_alembic_autogenerate_prefix(autogen_context), name)
- else:
- return name
+MAX_PYTHON_ARGS = 255
-except ImportError:
- def _render_gen_name(autogen_context, name):
+def _render_gen_name(
+ autogen_context: "AutogenContext",
+ name: Optional[Union["quoted_name", str]],
+) -> Optional[Union["quoted_name", str, "_f_name"]]:
+ if isinstance(name, conv):
+ return _f_name(_alembic_autogenerate_prefix(autogen_context), name)
+ else:
return name
-def _indent(text):
+def _indent(text: str) -> str:
text = re.compile(r"^", re.M).sub(" ", text).strip()
text = re.compile(r" +$", re.M).sub("", text)
return text
def _render_python_into_templatevars(
- autogen_context, migration_script, template_args
-):
+ autogen_context: "AutogenContext",
+ migration_script: "MigrationScript",
+ template_args: Dict[str, Union[str, "Config"]],
+) -> None:
imports = autogen_context.imports
for upgrade_ops, downgrade_ops in zip(
@@ -58,7 +92,10 @@ def _render_python_into_templatevars(
default_renderers = renderers = util.Dispatcher()
-def _render_cmd_body(op_container, autogen_context):
+def _render_cmd_body(
+ op_container: "ops.OpContainer",
+ autogen_context: "AutogenContext",
+) -> str:
buf = StringIO()
printer = PythonPrinter(buf)
@@ -70,7 +107,7 @@ def _render_cmd_body(op_container, autogen_context):
has_lines = False
for op in op_container.ops:
lines = render_op(autogen_context, op)
- has_lines = has_lines or lines
+ has_lines = has_lines or bool(lines)
for line in lines:
printer.writeline(line)
@@ -83,18 +120,24 @@ def _render_cmd_body(op_container, autogen_context):
return buf.getvalue()
-def render_op(autogen_context, op):
+def render_op(
+ autogen_context: "AutogenContext", op: "ops.MigrateOperation"
+) -> List[str]:
renderer = renderers.dispatch(op)
lines = util.to_list(renderer(autogen_context, op))
return lines
-def render_op_text(autogen_context, op):
+def render_op_text(
+ autogen_context: "AutogenContext", op: "ops.MigrateOperation"
+) -> str:
return "\n".join(render_op(autogen_context, op))
@renderers.dispatch_for(ops.ModifyTableOps)
-def _render_modify_table(autogen_context, op):
+def _render_modify_table(
+ autogen_context: "AutogenContext", op: "ModifyTableOps"
+) -> List[str]:
opts = autogen_context.opts
render_as_batch = opts.get("render_as_batch", False)
@@ -121,7 +164,9 @@ def _render_modify_table(autogen_context, op):
@renderers.dispatch_for(ops.CreateTableCommentOp)
-def _render_create_table_comment(autogen_context, op):
+def _render_create_table_comment(
+ autogen_context: "AutogenContext", op: "ops.CreateTableCommentOp"
+) -> str:
templ = (
"{prefix}create_table_comment(\n"
@@ -144,7 +189,9 @@ def _render_create_table_comment(autogen_context, op):
@renderers.dispatch_for(ops.DropTableCommentOp)
-def _render_drop_table_comment(autogen_context, op):
+def _render_drop_table_comment(
+ autogen_context: "AutogenContext", op: "ops.DropTableCommentOp"
+) -> str:
templ = (
"{prefix}drop_table_comment(\n"
@@ -165,7 +212,9 @@ def _render_drop_table_comment(autogen_context, op):
@renderers.dispatch_for(ops.CreateTableOp)
-def _add_table(autogen_context, op):
+def _add_table(
+ autogen_context: "AutogenContext", op: "ops.CreateTableOp"
+) -> str:
table = op.to_table()
args = [
@@ -188,14 +237,14 @@ def _add_table(autogen_context, op):
)
if len(args) > MAX_PYTHON_ARGS:
- args = "*[" + ",\n".join(args) + "]"
+ args_str = "*[" + ",\n".join(args) + "]"
else:
- args = ",\n".join(args)
+ args_str = ",\n".join(args)
text = "%(prefix)screate_table(%(tablename)r,\n%(args)s" % {
"tablename": _ident(op.table_name),
"prefix": _alembic_autogenerate_prefix(autogen_context),
- "args": args,
+ "args": args_str,
}
if op.schema:
text += ",\nschema=%r" % _ident(op.schema)
@@ -215,7 +264,9 @@ def _add_table(autogen_context, op):
@renderers.dispatch_for(ops.DropTableOp)
-def _drop_table(autogen_context, op):
+def _drop_table(
+ autogen_context: "AutogenContext", op: "ops.DropTableOp"
+) -> str:
text = "%(prefix)sdrop_table(%(tname)r" % {
"prefix": _alembic_autogenerate_prefix(autogen_context),
"tname": _ident(op.table_name),
@@ -227,7 +278,9 @@ def _drop_table(autogen_context, op):
@renderers.dispatch_for(ops.CreateIndexOp)
-def _add_index(autogen_context, op):
+def _add_index(
+ autogen_context: "AutogenContext", op: "ops.CreateIndexOp"
+) -> str:
index = op.to_index()
has_batch = autogen_context._has_batch
@@ -243,6 +296,7 @@ def _add_index(autogen_context, op):
"unique=%(unique)r%(schema)s%(kwargs)s)"
)
+ assert index.table is not None
text = tmpl % {
"prefix": _alembic_autogenerate_prefix(autogen_context),
"name": _render_gen_name(autogen_context, index.name),
@@ -271,7 +325,9 @@ def _add_index(autogen_context, op):
@renderers.dispatch_for(ops.DropIndexOp)
-def _drop_index(autogen_context, op):
+def _drop_index(
+ autogen_context: "AutogenContext", op: "ops.DropIndexOp"
+) -> str:
index = op.to_index()
has_batch = autogen_context._has_batch
@@ -306,12 +362,16 @@ def _drop_index(autogen_context, op):
@renderers.dispatch_for(ops.CreateUniqueConstraintOp)
-def _add_unique_constraint(autogen_context, op):
+def _add_unique_constraint(
+ autogen_context: "AutogenContext", op: "ops.CreateUniqueConstraintOp"
+) -> List[str]:
return [_uq_constraint(op.to_constraint(), autogen_context, True)]
@renderers.dispatch_for(ops.CreateForeignKeyOp)
-def _add_fk_constraint(autogen_context, op):
+def _add_fk_constraint(
+ autogen_context: "AutogenContext", op: "ops.CreateForeignKeyOp"
+) -> str:
args = [repr(_render_gen_name(autogen_context, op.constraint_name))]
if not autogen_context._has_batch:
@@ -358,7 +418,9 @@ def _add_check_constraint(constraint, autogen_context):
@renderers.dispatch_for(ops.DropConstraintOp)
-def _drop_constraint(autogen_context, op):
+def _drop_constraint(
+ autogen_context: "AutogenContext", op: "ops.DropConstraintOp"
+) -> str:
if autogen_context._has_batch:
template = "%(prefix)sdrop_constraint" "(%(name)r, type_=%(type)r)"
@@ -379,7 +441,9 @@ def _drop_constraint(autogen_context, op):
@renderers.dispatch_for(ops.AddColumnOp)
-def _add_column(autogen_context, op):
+def _add_column(
+ autogen_context: "AutogenContext", op: "ops.AddColumnOp"
+) -> str:
schema, tname, column = op.schema, op.table_name, op.column
if autogen_context._has_batch:
@@ -399,7 +463,9 @@ def _add_column(autogen_context, op):
@renderers.dispatch_for(ops.DropColumnOp)
-def _drop_column(autogen_context, op):
+def _drop_column(
+ autogen_context: "AutogenContext", op: "ops.DropColumnOp"
+) -> str:
schema, tname, column_name = op.schema, op.table_name, op.column_name
@@ -421,7 +487,9 @@ def _drop_column(autogen_context, op):
@renderers.dispatch_for(ops.AlterColumnOp)
-def _alter_column(autogen_context, op):
+def _alter_column(
+ autogen_context: "AutogenContext", op: "ops.AlterColumnOp"
+) -> str:
tname = op.table_name
cname = op.column_name
@@ -481,15 +549,15 @@ def _alter_column(autogen_context, op):
class _f_name:
- def __init__(self, prefix, name):
+ def __init__(self, prefix: str, name: conv) -> None:
self.prefix = prefix
self.name = name
- def __repr__(self):
+ def __repr__(self) -> str:
return "%sf(%r)" % (self.prefix, _ident(self.name))
-def _ident(name):
+def _ident(name: Optional[Union["quoted_name", str]]) -> Optional[str]:
"""produce a __repr__() object for a string identifier that may
use quoted_name() in SQLAlchemy 0.9 and greater.
@@ -506,8 +574,11 @@ def _ident(name):
def _render_potential_expr(
- value, autogen_context, wrap_in_text=True, is_server_default=False
-):
+ value: Any,
+ autogen_context: "AutogenContext",
+ wrap_in_text: bool = True,
+ is_server_default: bool = False,
+) -> str:
if isinstance(value, sql.ClauseElement):
if wrap_in_text:
@@ -526,7 +597,9 @@ def _render_potential_expr(
return repr(value)
-def _get_index_rendered_expressions(idx, autogen_context):
+def _get_index_rendered_expressions(
+ idx: "Index", autogen_context: "AutogenContext"
+) -> List[str]:
return [
repr(_ident(getattr(exp, "name", None)))
if isinstance(exp, sa_schema.Column)
@@ -535,8 +608,12 @@ def _get_index_rendered_expressions(idx, autogen_context):
]
-def _uq_constraint(constraint, autogen_context, alter):
- opts = []
+def _uq_constraint(
+ constraint: "UniqueConstraint",
+ autogen_context: "AutogenContext",
+ alter: bool,
+) -> str:
+ opts: List[Tuple[str, Any]] = []
has_batch = autogen_context._has_batch
@@ -578,18 +655,20 @@ def _user_autogenerate_prefix(autogen_context, target):
return prefix
-def _sqlalchemy_autogenerate_prefix(autogen_context):
+def _sqlalchemy_autogenerate_prefix(autogen_context: "AutogenContext") -> str:
return autogen_context.opts["sqlalchemy_module_prefix"] or ""
-def _alembic_autogenerate_prefix(autogen_context):
+def _alembic_autogenerate_prefix(autogen_context: "AutogenContext") -> str:
if autogen_context._has_batch:
return "batch_op."
else:
return autogen_context.opts["alembic_module_prefix"] or ""
-def _user_defined_render(type_, object_, autogen_context):
+def _user_defined_render(
+ type_: str, object_: Any, autogen_context: "AutogenContext"
+) -> Union[str, "Literal[False]"]:
if "render_item" in autogen_context.opts:
render = autogen_context.opts["render_item"]
if render:
@@ -599,17 +678,17 @@ def _user_defined_render(type_, object_, autogen_context):
return False
-def _render_column(column, autogen_context):
+def _render_column(column: "Column", autogen_context: "AutogenContext") -> str:
rendered = _user_defined_render("column", column, autogen_context)
if rendered is not False:
return rendered
- args = []
- opts = []
+ args: List[str] = []
+ opts: List[Tuple[str, Any]] = []
if column.server_default:
- rendered = _render_server_default(
+ rendered = _render_server_default( # type:ignore[assignment]
column.server_default, autogen_context
)
if rendered:
@@ -655,21 +734,29 @@ def _render_column(column, autogen_context):
}
-def _should_render_server_default_positionally(server_default):
+def _should_render_server_default_positionally(
+ server_default: Union["Computed", "DefaultClause"]
+) -> bool:
return sqla_compat._server_default_is_computed(
server_default
) or sqla_compat._server_default_is_identity(server_default)
-def _render_server_default(default, autogen_context, repr_=True):
+def _render_server_default(
+ default: Optional[
+ Union["FetchedValue", str, "TextClause", "ColumnElement"]
+ ],
+ autogen_context: "AutogenContext",
+ repr_: bool = True,
+) -> Optional[str]:
rendered = _user_defined_render("server_default", default, autogen_context)
if rendered is not False:
return rendered
if sqla_compat._server_default_is_computed(default):
- return _render_computed(default, autogen_context)
+ return _render_computed(cast("Computed", default), autogen_context)
elif sqla_compat._server_default_is_identity(default):
- return _render_identity(default, autogen_context)
+ return _render_identity(cast("Identity", default), autogen_context)
elif isinstance(default, sa_schema.DefaultClause):
if isinstance(default.arg, compat.string_types):
default = default.arg
@@ -681,10 +768,12 @@ def _render_server_default(default, autogen_context, repr_=True):
if isinstance(default, string_types) and repr_:
default = repr(re.sub(r"^'|'$", "", default))
- return default
+ return cast(str, default)
-def _render_computed(computed, autogen_context):
+def _render_computed(
+ computed: "Computed", autogen_context: "AutogenContext"
+) -> str:
text = _render_potential_expr(
computed.sqltext, autogen_context, wrap_in_text=False
)
@@ -699,7 +788,9 @@ def _render_computed(computed, autogen_context):
}
-def _render_identity(identity, autogen_context):
+def _render_identity(
+ identity: "Identity", autogen_context: "AutogenContext"
+) -> str:
# always=None means something different than always=False
kwargs = OrderedDict(always=identity.always)
if identity.on_null is not None:
@@ -712,7 +803,7 @@ def _render_identity(identity, autogen_context):
}
-def _get_identity_options(identity_options):
+def _get_identity_options(identity_options: "Identity") -> OrderedDict:
kwargs = OrderedDict()
for attr in sqla_compat._identity_options_attrs:
value = getattr(identity_options, attr, None)
@@ -721,7 +812,7 @@ def _get_identity_options(identity_options):
return kwargs
-def _repr_type(type_, autogen_context):
+def _repr_type(type_: "TypeEngine", autogen_context: "AutogenContext") -> str:
rendered = _user_defined_render("type", type_, autogen_context)
if rendered is not False:
return rendered
@@ -736,7 +827,9 @@ def _repr_type(type_, autogen_context):
mod = type(type_).__module__
imports = autogen_context.imports
if mod.startswith("sqlalchemy.dialects"):
- dname = re.match(r"sqlalchemy\.dialects\.(\w+)", mod).group(1)
+ match = re.match(r"sqlalchemy\.dialects\.(\w+)", mod)
+ assert match is not None
+ dname = match.group(1)
if imports is not None:
imports.add("from sqlalchemy.dialects import %s" % dname)
if impl_rt:
@@ -759,14 +852,22 @@ def _repr_type(type_, autogen_context):
return "%s%r" % (prefix, type_)
-def _render_ARRAY_type(type_, autogen_context):
- return _render_type_w_subtype(
- type_, autogen_context, "item_type", r"(.+?\()"
+def _render_ARRAY_type(
+ type_: "ARRAY", autogen_context: "AutogenContext"
+) -> str:
+ return cast(
+ str,
+ _render_type_w_subtype(
+ type_, autogen_context, "item_type", r"(.+?\()"
+ ),
)
-def _render_Variant_type(type_, autogen_context):
+def _render_Variant_type(
+ type_: "Variant", autogen_context: "AutogenContext"
+) -> str:
base = _repr_type(type_.impl, autogen_context)
+ assert base is not None and base is not False
for dialect in sorted(type_.mapping):
typ = type_.mapping[dialect]
base += ".with_variant(%s, %r)" % (
@@ -777,8 +878,12 @@ def _render_Variant_type(type_, autogen_context):
def _render_type_w_subtype(
- type_, autogen_context, attrname, regexp, prefix=None
-):
+ type_: "TypeEngine",
+ autogen_context: "AutogenContext",
+ attrname: str,
+ regexp: str,
+ prefix: Optional[str] = None,
+) -> Union[Optional[str], "Literal[False]"]:
outer_repr = repr(type_)
inner_type = getattr(type_, attrname, None)
if inner_type is None:
@@ -795,7 +900,9 @@ def _render_type_w_subtype(
mod = type(type_).__module__
if mod.startswith("sqlalchemy.dialects"):
- dname = re.match(r"sqlalchemy\.dialects\.(\w+)", mod).group(1)
+ match = re.match(r"sqlalchemy\.dialects\.(\w+)", mod)
+ assert match is not None
+ dname = match.group(1)
return "%s.%s" % (dname, outer_type)
elif mod.startswith("sqlalchemy"):
prefix = _sqlalchemy_autogenerate_prefix(autogen_context)
@@ -807,7 +914,11 @@ def _render_type_w_subtype(
_constraint_renderers = util.Dispatcher()
-def _render_constraint(constraint, autogen_context, namespace_metadata):
+def _render_constraint(
+ constraint: "Constraint",
+ autogen_context: "AutogenContext",
+ namespace_metadata: Optional["MetaData"],
+) -> Optional[str]:
try:
renderer = _constraint_renderers.dispatch(constraint)
except ValueError:
@@ -818,7 +929,11 @@ def _render_constraint(constraint, autogen_context, namespace_metadata):
@_constraint_renderers.dispatch_for(sa_schema.PrimaryKeyConstraint)
-def _render_primary_key(constraint, autogen_context, namespace_metadata):
+def _render_primary_key(
+ constraint: "PrimaryKeyConstraint",
+ autogen_context: "AutogenContext",
+ namespace_metadata: Optional["MetaData"],
+) -> Optional[str]:
rendered = _user_defined_render("primary_key", constraint, autogen_context)
if rendered is not False:
return rendered
@@ -840,12 +955,16 @@ def _render_primary_key(constraint, autogen_context, namespace_metadata):
}
-def _fk_colspec(fk, metadata_schema, namespace_metadata):
+def _fk_colspec(
+ fk: "ForeignKey",
+ metadata_schema: Optional[str],
+ namespace_metadata: "MetaData",
+) -> str:
"""Implement a 'safe' version of ForeignKey._get_colspec() that
won't fail if the remote table can't be resolved.
"""
- colspec = fk._get_colspec()
+ colspec = fk._get_colspec() # type:ignore[attr-defined]
tokens = colspec.split(".")
tname, colname = tokens[-2:]
@@ -873,7 +992,9 @@ def _fk_colspec(fk, metadata_schema, namespace_metadata):
return colspec
-def _populate_render_fk_opts(constraint, opts):
+def _populate_render_fk_opts(
+ constraint: "ForeignKeyConstraint", opts: List[Tuple[str, str]]
+) -> None:
if constraint.onupdate:
opts.append(("onupdate", repr(constraint.onupdate)))
@@ -888,7 +1009,11 @@ def _populate_render_fk_opts(constraint, opts):
@_constraint_renderers.dispatch_for(sa_schema.ForeignKeyConstraint)
-def _render_foreign_key(constraint, autogen_context, namespace_metadata):
+def _render_foreign_key(
+ constraint: "ForeignKeyConstraint",
+ autogen_context: "AutogenContext",
+ namespace_metadata: "MetaData",
+) -> Optional[str]:
rendered = _user_defined_render("foreign_key", constraint, autogen_context)
if rendered is not False:
return rendered
@@ -908,7 +1033,8 @@ def _render_foreign_key(constraint, autogen_context, namespace_metadata):
% {
"prefix": _sqlalchemy_autogenerate_prefix(autogen_context),
"cols": ", ".join(
- "%r" % _ident(f.parent.name) for f in constraint.elements
+ "%r" % _ident(cast("Column", f.parent).name)
+ for f in constraint.elements
),
"refcols": ", ".join(
repr(_fk_colspec(f, apply_metadata_schema, namespace_metadata))
@@ -922,7 +1048,11 @@ def _render_foreign_key(constraint, autogen_context, namespace_metadata):
@_constraint_renderers.dispatch_for(sa_schema.UniqueConstraint)
-def _render_unique_constraint(constraint, autogen_context, namespace_metadata):
+def _render_unique_constraint(
+ constraint: "UniqueConstraint",
+ autogen_context: "AutogenContext",
+ namespace_metadata: Optional["MetaData"],
+) -> str:
rendered = _user_defined_render("unique", constraint, autogen_context)
if rendered is not False:
return rendered
@@ -931,7 +1061,11 @@ def _render_unique_constraint(constraint, autogen_context, namespace_metadata):
@_constraint_renderers.dispatch_for(sa_schema.CheckConstraint)
-def _render_check_constraint(constraint, autogen_context, namespace_metadata):
+def _render_check_constraint(
+ constraint: "CheckConstraint",
+ autogen_context: "AutogenContext",
+ namespace_metadata: Optional["MetaData"],
+) -> Optional[str]:
rendered = _user_defined_render("check", constraint, autogen_context)
if rendered is not False:
return rendered
@@ -941,9 +1075,14 @@ def _render_check_constraint(constraint, autogen_context, namespace_metadata):
# ideally SQLAlchemy would give us more of a first class
# way to detect this.
if (
- constraint._create_rule
- and hasattr(constraint._create_rule, "target")
- and isinstance(constraint._create_rule.target, sqltypes.TypeEngine)
+ constraint._create_rule # type:ignore[attr-defined]
+ and hasattr(
+ constraint._create_rule, "target" # type:ignore[attr-defined]
+ )
+ and isinstance(
+ constraint._create_rule.target, # type:ignore[attr-defined]
+ sqltypes.TypeEngine,
+ )
):
return None
opts = []
@@ -963,7 +1102,9 @@ def _render_check_constraint(constraint, autogen_context, namespace_metadata):
@renderers.dispatch_for(ops.ExecuteSQLOp)
-def _execute_sql(autogen_context, op):
+def _execute_sql(
+ autogen_context: "AutogenContext", op: "ops.ExecuteSQLOp"
+) -> str:
if not isinstance(op.sqltext, string_types):
raise NotImplementedError(
"Autogenerate rendering of SQL Expression language constructs "
diff --git a/alembic/autogenerate/rewriter.py b/alembic/autogenerate/rewriter.py
index ba9a06d..0fdd398 100644
--- a/alembic/autogenerate/rewriter.py
+++ b/alembic/autogenerate/rewriter.py
@@ -1,6 +1,25 @@
+from typing import Any
+from typing import Callable
+from typing import Iterator
+from typing import List
+from typing import Type
+from typing import TYPE_CHECKING
+from typing import Union
+
from alembic import util
from alembic.operations import ops
+if TYPE_CHECKING:
+ from alembic.operations.ops import AddColumnOp
+ from alembic.operations.ops import AlterColumnOp
+ from alembic.operations.ops import CreateTableOp
+ from alembic.operations.ops import MigrateOperation
+ from alembic.operations.ops import MigrationScript
+ from alembic.operations.ops import ModifyTableOps
+ from alembic.operations.ops import OpContainer
+ from alembic.runtime.migration import MigrationContext
+ from alembic.script.revision import Revision
+
class Rewriter:
"""A helper object that allows easy 'rewriting' of ops streams.
@@ -32,10 +51,10 @@ class Rewriter:
_chained = None
- def __init__(self):
+ def __init__(self) -> None:
self.dispatch = util.Dispatcher()
- def chain(self, other):
+ def chain(self, other: "Rewriter") -> "Rewriter":
"""Produce a "chain" of this :class:`.Rewriter` to another.
This allows two rewriters to operate serially on a stream,
@@ -70,7 +89,16 @@ class Rewriter:
wr._chained = other
return wr
- def rewrites(self, operator):
+ def rewrites(
+ self,
+ operator: Union[
+ Type["AddColumnOp"],
+ Type["MigrateOperation"],
+ Type["AlterColumnOp"],
+ Type["CreateTableOp"],
+ Type["ModifyTableOps"],
+ ],
+ ) -> Callable:
"""Register a function as rewriter for a given type.
The function should receive three arguments, which are
@@ -85,7 +113,12 @@ class Rewriter:
"""
return self.dispatch.dispatch_for(operator)
- def _rewrite(self, context, revision, directive):
+ def _rewrite(
+ self,
+ context: "MigrationContext",
+ revision: "Revision",
+ directive: "MigrateOperation",
+ ) -> Iterator["MigrateOperation"]:
try:
_rewriter = self.dispatch.dispatch(directive)
except ValueError:
@@ -96,20 +129,30 @@ class Rewriter:
yield directive
else:
for r_directive in util.to_list(
- _rewriter(context, revision, directive)
+ _rewriter(context, revision, directive), []
):
r_directive._mutations = r_directive._mutations.union(
[self]
)
yield r_directive
- def __call__(self, context, revision, directives):
+ def __call__(
+ self,
+ context: "MigrationContext",
+ revision: "Revision",
+ directives: List["MigrationScript"],
+ ) -> None:
self.process_revision_directives(context, revision, directives)
if self._chained:
self._chained(context, revision, directives)
@_traverse.dispatch_for(ops.MigrationScript)
- def _traverse_script(self, context, revision, directive):
+ def _traverse_script(
+ self,
+ context: "MigrationContext",
+ revision: "Revision",
+ directive: "MigrationScript",
+ ) -> None:
upgrade_ops_list = []
for upgrade_ops in directive.upgrade_ops_list:
ret = self._traverse_for(context, revision, upgrade_ops)
@@ -131,26 +174,51 @@ class Rewriter:
directive.downgrade_ops = downgrade_ops_list
@_traverse.dispatch_for(ops.OpContainer)
- def _traverse_op_container(self, context, revision, directive):
+ def _traverse_op_container(
+ self,
+ context: "MigrationContext",
+ revision: "Revision",
+ directive: "OpContainer",
+ ) -> None:
self._traverse_list(context, revision, directive.ops)
@_traverse.dispatch_for(ops.MigrateOperation)
- def _traverse_any_directive(self, context, revision, directive):
+ def _traverse_any_directive(
+ self,
+ context: "MigrationContext",
+ revision: "Revision",
+ directive: "MigrateOperation",
+ ) -> None:
pass
- def _traverse_for(self, context, revision, directive):
+ def _traverse_for(
+ self,
+ context: "MigrationContext",
+ revision: "Revision",
+ directive: "MigrateOperation",
+ ) -> Any:
directives = list(self._rewrite(context, revision, directive))
for directive in directives:
traverser = self._traverse.dispatch(directive)
traverser(self, context, revision, directive)
return directives
- def _traverse_list(self, context, revision, directives):
+ def _traverse_list(
+ self,
+ context: "MigrationContext",
+ revision: "Revision",
+ directives: Any,
+ ) -> None:
dest = []
for directive in directives:
dest.extend(self._traverse_for(context, revision, directive))
directives[:] = dest
- def process_revision_directives(self, context, revision, directives):
+ def process_revision_directives(
+ self,
+ context: "MigrationContext",
+ revision: "Revision",
+ directives: List["MigrationScript"],
+ ) -> None:
self._traverse_list(context, revision, directives)
diff --git a/alembic/command.py b/alembic/command.py
index ada458d..1e79460 100644
--- a/alembic/command.py
+++ b/alembic/command.py
@@ -1,10 +1,20 @@
import os
+from typing import Callable
+from typing import cast
+from typing import List
+from typing import Optional
+from typing import TYPE_CHECKING
+from typing import Union
from . import autogenerate as autogen
from . import util
from .runtime.environment import EnvironmentContext
from .script import ScriptDirectory
+if TYPE_CHECKING:
+ from alembic.config import Config
+ from alembic.script.base import Script
+
def list_templates(config):
"""List available templates.
@@ -25,7 +35,12 @@ def list_templates(config):
config.print_stdout("\n alembic init --template generic ./scripts")
-def init(config, directory, template="generic", package=False):
+def init(
+ config: "Config",
+ directory: str,
+ template: str = "generic",
+ package: bool = False,
+) -> None:
"""Initialize a new scripts directory.
:param config: a :class:`.Config` object.
@@ -71,8 +86,8 @@ def init(config, directory, template="generic", package=False):
for file_ in os.listdir(template_dir):
file_path = os.path.join(template_dir, file_)
if file_ == "alembic.ini.mako":
- config_file = os.path.abspath(config.config_file_name)
- if os.access(config_file, os.F_OK):
+ config_file = os.path.abspath(cast(str, config.config_file_name))
+ if os.access(cast(str, config_file), os.F_OK):
util.msg("File %s already exists, skipping" % config_file)
else:
script._generate_template(
@@ -88,7 +103,7 @@ def init(config, directory, template="generic", package=False):
os.path.join(os.path.abspath(versions), "__init__.py"),
]:
file_ = util.status("Adding %s" % path, open, path, "w")
- file_.close()
+ file_.close() # type:ignore[attr-defined]
util.msg(
"Please edit configuration/connection/logging "
@@ -97,18 +112,18 @@ def init(config, directory, template="generic", package=False):
def revision(
- config,
- message=None,
- autogenerate=False,
- sql=False,
- head="head",
- splice=False,
- branch_label=None,
- version_path=None,
- rev_id=None,
- depends_on=None,
- process_revision_directives=None,
-):
+ config: "Config",
+ message: Optional[str] = None,
+ autogenerate: bool = False,
+ sql: bool = False,
+ head: str = "head",
+ splice: bool = False,
+ branch_label: Optional[str] = None,
+ version_path: Optional[str] = None,
+ rev_id: Optional[str] = None,
+ depends_on: Optional[str] = None,
+ process_revision_directives: Callable = None,
+) -> Union[Optional["Script"], List[Optional["Script"]]]:
"""Create a new revision file.
:param config: a :class:`.Config` object.
@@ -223,7 +238,13 @@ def revision(
return scripts
-def merge(config, revisions, message=None, branch_label=None, rev_id=None):
+def merge(
+ config: "Config",
+ revisions: str,
+ message: str = None,
+ branch_label: str = None,
+ rev_id: str = None,
+) -> Optional["Script"]:
"""Merge two revisions together. Creates a new migration file.
:param config: a :class:`.Config` instance
@@ -243,7 +264,7 @@ def merge(config, revisions, message=None, branch_label=None, rev_id=None):
script = ScriptDirectory.from_config(config)
template_args = {
- "config": config # Let templates use config for
+ "config": "config" # Let templates use config for
# e.g. multiple databases
}
return script.generate_revision(
@@ -252,11 +273,16 @@ def merge(config, revisions, message=None, branch_label=None, rev_id=None):
refresh=True,
head=revisions,
branch_labels=branch_label,
- **template_args
+ **template_args # type:ignore[arg-type]
)
-def upgrade(config, revision, sql=False, tag=None):
+def upgrade(
+ config: "Config",
+ revision: str,
+ sql: bool = False,
+ tag: Optional[str] = None,
+) -> None:
"""Upgrade to a later version.
:param config: a :class:`.Config` instance.
@@ -294,7 +320,12 @@ def upgrade(config, revision, sql=False, tag=None):
script.run_env()
-def downgrade(config, revision, sql=False, tag=None):
+def downgrade(
+ config: "Config",
+ revision: str,
+ sql: bool = False,
+ tag: Optional[str] = None,
+) -> None:
"""Revert to a previous version.
:param config: a :class:`.Config` instance.
@@ -360,7 +391,12 @@ def show(config, rev):
config.print_stdout(sc.log_entry)
-def history(config, rev_range=None, verbose=False, indicate_current=False):
+def history(
+ config: "Config",
+ rev_range: Optional[str] = None,
+ verbose: bool = False,
+ indicate_current: bool = False,
+) -> None:
"""List changeset scripts in chronological order.
:param config: a :class:`.Config` instance.
@@ -372,7 +408,8 @@ def history(config, rev_range=None, verbose=False, indicate_current=False):
:param indicate_current: indicate current revision.
"""
-
+ base: Optional[str]
+ head: Optional[str]
script = ScriptDirectory.from_config(config)
if rev_range is not None:
if ":" not in rev_range:
@@ -478,7 +515,7 @@ def branches(config, verbose=False):
)
-def current(config, verbose=False):
+def current(config: "Config", verbose: bool = False) -> None:
"""Display the current revision for a database.
:param config: a :class:`.Config` instance.
@@ -506,7 +543,13 @@ def current(config, verbose=False):
script.run_env()
-def stamp(config, revision, sql=False, tag=None, purge=False):
+def stamp(
+ config: "Config",
+ revision: str,
+ sql: bool = False,
+ tag: Optional[str] = None,
+ purge: bool = False,
+) -> None:
"""'stamp' the revision table with the given revision; don't
run any migrations.
@@ -570,7 +613,7 @@ def stamp(config, revision, sql=False, tag=None, purge=False):
script.run_env()
-def edit(config, rev):
+def edit(config: "Config", rev: str) -> None:
"""Edit revision script(s) using $EDITOR.
:param config: a :class:`.Config` instance.
diff --git a/alembic/config.py b/alembic/config.py
index b8b465d..dbcd106 100644
--- a/alembic/config.py
+++ b/alembic/config.py
@@ -1,8 +1,13 @@
from argparse import ArgumentParser
+from argparse import Namespace
from configparser import ConfigParser
import inspect
import os
import sys
+from typing import Dict
+from typing import Optional
+from typing import overload
+from typing import TextIO
from . import __version__
from . import command
@@ -86,14 +91,14 @@ class Config:
def __init__(
self,
- file_=None,
- ini_section="alembic",
- output_buffer=None,
- stdout=sys.stdout,
- cmd_opts=None,
- config_args=util.immutabledict(),
- attributes=None,
- ):
+ file_: Optional[str] = None,
+ ini_section: str = "alembic",
+ output_buffer: Optional[TextIO] = None,
+ stdout: TextIO = sys.stdout,
+ cmd_opts: Optional[Namespace] = None,
+ config_args: util.immutabledict = util.immutabledict(),
+ attributes: dict = None,
+ ) -> None:
"""Construct a new :class:`.Config`"""
self.config_file_name = file_
self.config_ini_section = ini_section
@@ -104,7 +109,7 @@ class Config:
if attributes:
self.attributes.update(attributes)
- cmd_opts = None
+ cmd_opts: Optional[Namespace] = None
"""The command-line options passed to the ``alembic`` script.
Within an ``env.py`` script this can be accessed via the
@@ -116,10 +121,10 @@ class Config:
"""
- config_file_name = None
+ config_file_name: Optional[str] = None
"""Filesystem path to the .ini file in use."""
- config_ini_section = None
+ config_ini_section: str = None # type:ignore[assignment]
"""Name of the config file section to read basic configuration
from. Defaults to ``alembic``, that is the ``[alembic]`` section
of the .ini file. This value is modified using the ``-n/--name``
@@ -147,7 +152,7 @@ class Config:
"""
return {}
- def print_stdout(self, text, *arg):
+ def print_stdout(self, text: str, *arg) -> None:
"""Render a message to standard out.
When :meth:`.Config.print_stdout` is called with additional args
@@ -191,7 +196,7 @@ class Config:
file_config.add_section(self.config_ini_section)
return file_config
- def get_template_directory(self):
+ def get_template_directory(self) -> str:
"""Return the directory where Alembic setup templates are found.
This method is used by the alembic ``init`` and ``list_templates``
@@ -203,7 +208,19 @@ class Config:
package_dir = os.path.abspath(os.path.dirname(alembic.__file__))
return os.path.join(package_dir, "templates")
- def get_section(self, name, default=None):
+ @overload
+ def get_section(
+ self, name: str, default: Dict[str, str]
+ ) -> Dict[str, str]:
+ ...
+
+ @overload
+ def get_section(
+ self, name: str, default: Optional[Dict[str, str]] = ...
+ ) -> Optional[Dict[str, str]]:
+ ...
+
+ def get_section(self, name: str, default=None):
"""Return all the configuration options from a given .ini file section
as a dictionary.
@@ -213,7 +230,7 @@ class Config:
return dict(self.file_config.items(name))
- def set_main_option(self, name, value):
+ def set_main_option(self, name: str, value: str) -> None:
"""Set an option programmatically within the 'main' section.
This overrides whatever was in the .ini file.
@@ -230,10 +247,10 @@ class Config:
"""
self.set_section_option(self.config_ini_section, name, value)
- def remove_main_option(self, name):
+ def remove_main_option(self, name: str) -> None:
self.file_config.remove_option(self.config_ini_section, name)
- def set_section_option(self, section, name, value):
+ def set_section_option(self, section: str, name: str, value: str) -> None:
"""Set an option programmatically within the given section.
The section is created if it doesn't exist already.
@@ -257,7 +274,9 @@ class Config:
self.file_config.add_section(section)
self.file_config.set(section, name, value)
- def get_section_option(self, section, name, default=None):
+ def get_section_option(
+ self, section: str, name: str, default: Optional[str] = None
+ ) -> Optional[str]:
"""Return an option from the given section of the .ini file."""
if not self.file_config.has_section(section):
raise util.CommandError(
@@ -269,6 +288,16 @@ class Config:
else:
return default
+ @overload
+ def get_main_option(self, name: str, default: str) -> str:
+ ...
+
+ @overload
+ def get_main_option(
+ self, name: str, default: Optional[str] = None
+ ) -> Optional[str]:
+ ...
+
def get_main_option(self, name, default=None):
"""Return an option from the 'main' section of the .ini file.
@@ -281,10 +310,10 @@ class Config:
class CommandLine:
- def __init__(self, prog=None):
+ def __init__(self, prog: Optional[str] = None) -> None:
self._generate_args(prog)
- def _generate_args(self, prog):
+ def _generate_args(self, prog: Optional[str]) -> None:
def add_options(fn, parser, positional, kwargs):
kwargs_opts = {
"template": (
@@ -515,7 +544,7 @@ class CommandLine:
else:
help_text.append(line.strip())
else:
- help_text = ""
+ help_text = []
subparser = subparsers.add_parser(
fn.__name__, help=" ".join(help_text)
)
@@ -523,7 +552,7 @@ class CommandLine:
subparser.set_defaults(cmd=(fn, positional, kwarg))
self.parser = parser
- def run_cmd(self, config, options):
+ def run_cmd(self, config: Config, options: Namespace) -> None:
fn, positional, kwarg = options.cmd
try:
diff --git a/alembic/ddl/base.py b/alembic/ddl/base.py
index da81c72..022dc24 100644
--- a/alembic/ddl/base.py
+++ b/alembic/ddl/base.py
@@ -1,4 +1,7 @@
import functools
+from typing import Optional
+from typing import TYPE_CHECKING
+from typing import Union
from sqlalchemy import exc
from sqlalchemy import Integer
@@ -14,6 +17,20 @@ from ..util.sqla_compat import _fk_spec # noqa
from ..util.sqla_compat import _is_type_bound # noqa
from ..util.sqla_compat import _table_for_constraint # noqa
+if TYPE_CHECKING:
+ from sqlalchemy.sql.compiler import Compiled
+ from sqlalchemy.sql.compiler import DDLCompiler
+ from sqlalchemy.sql.elements import TextClause
+ from sqlalchemy.sql.functions import Function
+ from sqlalchemy.sql.schema import FetchedValue
+ from sqlalchemy.sql.type_api import TypeEngine
+
+ from .impl import DefaultImpl
+ from ..util.sqla_compat import Computed
+ from ..util.sqla_compat import Identity
+
+_ServerDefault = Union["TextClause", "FetchedValue", "Function", str]
+
class AlterTable(DDLElement):
@@ -24,13 +41,22 @@ class AlterTable(DDLElement):
"""
- def __init__(self, table_name, schema=None):
+ def __init__(
+ self,
+ table_name: str,
+ schema: Optional[Union["quoted_name", str]] = None,
+ ) -> None:
self.table_name = table_name
self.schema = schema
class RenameTable(AlterTable):
- def __init__(self, old_table_name, new_table_name, schema=None):
+ def __init__(
+ self,
+ old_table_name: str,
+ new_table_name: Union["quoted_name", str],
+ schema: Optional[Union["quoted_name", str]] = None,
+ ) -> None:
super(RenameTable, self).__init__(old_table_name, schema=schema)
self.new_table_name = new_table_name
@@ -38,14 +64,14 @@ class RenameTable(AlterTable):
class AlterColumn(AlterTable):
def __init__(
self,
- name,
- column_name,
- schema=None,
- existing_type=None,
- existing_nullable=None,
- existing_server_default=None,
- existing_comment=None,
- ):
+ name: str,
+ column_name: str,
+ schema: Optional[str] = None,
+ existing_type: Optional["TypeEngine"] = None,
+ existing_nullable: Optional[bool] = None,
+ existing_server_default: Optional[_ServerDefault] = None,
+ existing_comment: Optional[str] = None,
+ ) -> None:
super(AlterColumn, self).__init__(name, schema=schema)
self.column_name = column_name
self.existing_type = (
@@ -59,62 +85,94 @@ class AlterColumn(AlterTable):
class ColumnNullable(AlterColumn):
- def __init__(self, name, column_name, nullable, **kw):
+ def __init__(
+ self, name: str, column_name: str, nullable: bool, **kw
+ ) -> None:
super(ColumnNullable, self).__init__(name, column_name, **kw)
self.nullable = nullable
class ColumnType(AlterColumn):
- def __init__(self, name, column_name, type_, **kw):
+ def __init__(
+ self, name: str, column_name: str, type_: "TypeEngine", **kw
+ ) -> None:
super(ColumnType, self).__init__(name, column_name, **kw)
self.type_ = sqltypes.to_instance(type_)
class ColumnName(AlterColumn):
- def __init__(self, name, column_name, newname, **kw):
+ def __init__(
+ self, name: str, column_name: str, newname: str, **kw
+ ) -> None:
super(ColumnName, self).__init__(name, column_name, **kw)
self.newname = newname
class ColumnDefault(AlterColumn):
- def __init__(self, name, column_name, default, **kw):
+ def __init__(
+ self,
+ name: str,
+ column_name: str,
+ default: Optional[_ServerDefault],
+ **kw
+ ) -> None:
super(ColumnDefault, self).__init__(name, column_name, **kw)
self.default = default
class ComputedColumnDefault(AlterColumn):
- def __init__(self, name, column_name, default, **kw):
+ def __init__(
+ self, name: str, column_name: str, default: Optional["Computed"], **kw
+ ) -> None:
super(ComputedColumnDefault, self).__init__(name, column_name, **kw)
self.default = default
class IdentityColumnDefault(AlterColumn):
- def __init__(self, name, column_name, default, impl, **kw):
+ def __init__(
+ self,
+ name: str,
+ column_name: str,
+ default: Optional["Identity"],
+ impl: "DefaultImpl",
+ **kw
+ ) -> None:
super(IdentityColumnDefault, self).__init__(name, column_name, **kw)
self.default = default
self.impl = impl
class AddColumn(AlterTable):
- def __init__(self, name, column, schema=None):
+ def __init__(
+ self,
+ name: str,
+ column: "Column",
+ schema: Optional[Union["quoted_name", str]] = None,
+ ) -> None:
super(AddColumn, self).__init__(name, schema=schema)
self.column = column
class DropColumn(AlterTable):
- def __init__(self, name, column, schema=None):
+ def __init__(
+ self, name: str, column: "Column", schema: Optional[str] = None
+ ) -> None:
super(DropColumn, self).__init__(name, schema=schema)
self.column = column
class ColumnComment(AlterColumn):
- def __init__(self, name, column_name, comment, **kw):
+ def __init__(
+ self, name: str, column_name: str, comment: Optional[str], **kw
+ ) -> None:
super(ColumnComment, self).__init__(name, column_name, **kw)
self.comment = comment
@compiles(RenameTable)
-def visit_rename_table(element, compiler, **kw):
+def visit_rename_table(
+ element: "RenameTable", compiler: "DDLCompiler", **kw
+) -> str:
return "%s RENAME TO %s" % (
alter_table(compiler, element.table_name, element.schema),
format_table_name(compiler, element.new_table_name, element.schema),
@@ -122,7 +180,9 @@ def visit_rename_table(element, compiler, **kw):
@compiles(AddColumn)
-def visit_add_column(element, compiler, **kw):
+def visit_add_column(
+ element: "AddColumn", compiler: "DDLCompiler", **kw
+) -> str:
return "%s %s" % (
alter_table(compiler, element.table_name, element.schema),
add_column(compiler, element.column, **kw),
@@ -130,7 +190,9 @@ def visit_add_column(element, compiler, **kw):
@compiles(DropColumn)
-def visit_drop_column(element, compiler, **kw):
+def visit_drop_column(
+ element: "DropColumn", compiler: "DDLCompiler", **kw
+) -> str:
return "%s %s" % (
alter_table(compiler, element.table_name, element.schema),
drop_column(compiler, element.column.name, **kw),
@@ -138,7 +200,9 @@ def visit_drop_column(element, compiler, **kw):
@compiles(ColumnNullable)
-def visit_column_nullable(element, compiler, **kw):
+def visit_column_nullable(
+ element: "ColumnNullable", compiler: "DDLCompiler", **kw
+) -> str:
return "%s %s %s" % (
alter_table(compiler, element.table_name, element.schema),
alter_column(compiler, element.column_name),
@@ -147,7 +211,9 @@ def visit_column_nullable(element, compiler, **kw):
@compiles(ColumnType)
-def visit_column_type(element, compiler, **kw):
+def visit_column_type(
+ element: "ColumnType", compiler: "DDLCompiler", **kw
+) -> str:
return "%s %s %s" % (
alter_table(compiler, element.table_name, element.schema),
alter_column(compiler, element.column_name),
@@ -156,7 +222,9 @@ def visit_column_type(element, compiler, **kw):
@compiles(ColumnName)
-def visit_column_name(element, compiler, **kw):
+def visit_column_name(
+ element: "ColumnName", compiler: "DDLCompiler", **kw
+) -> str:
return "%s RENAME %s TO %s" % (
alter_table(compiler, element.table_name, element.schema),
format_column_name(compiler, element.column_name),
@@ -165,7 +233,9 @@ def visit_column_name(element, compiler, **kw):
@compiles(ColumnDefault)
-def visit_column_default(element, compiler, **kw):
+def visit_column_default(
+ element: "ColumnDefault", compiler: "DDLCompiler", **kw
+) -> str:
return "%s %s %s" % (
alter_table(compiler, element.table_name, element.schema),
alter_column(compiler, element.column_name),
@@ -176,7 +246,9 @@ def visit_column_default(element, compiler, **kw):
@compiles(ComputedColumnDefault)
-def visit_computed_column(element, compiler, **kw):
+def visit_computed_column(
+ element: "ComputedColumnDefault", compiler: "DDLCompiler", **kw
+):
raise exc.CompileError(
'Adding or removing a "computed" construct, e.g. GENERATED '
"ALWAYS AS, to or from an existing column is not supported."
@@ -184,7 +256,9 @@ def visit_computed_column(element, compiler, **kw):
@compiles(IdentityColumnDefault)
-def visit_identity_column(element, compiler, **kw):
+def visit_identity_column(
+ element: "IdentityColumnDefault", compiler: "DDLCompiler", **kw
+):
raise exc.CompileError(
'Adding, removing or modifying an "identity" construct, '
"e.g. GENERATED AS IDENTITY, to or from an existing "
@@ -192,7 +266,9 @@ def visit_identity_column(element, compiler, **kw):
)
-def quote_dotted(name, quote):
+def quote_dotted(
+ name: Union["quoted_name", str], quote: functools.partial
+) -> Union["quoted_name", str]:
"""quote the elements of a dotted name"""
if isinstance(name, quoted_name):
@@ -201,7 +277,11 @@ def quote_dotted(name, quote):
return result
-def format_table_name(compiler, name, schema):
+def format_table_name(
+ compiler: "Compiled",
+ name: Union["quoted_name", str],
+ schema: Optional[Union["quoted_name", str]],
+) -> Union["quoted_name", str]:
quote = functools.partial(compiler.preparer.quote)
if schema:
return quote_dotted(schema, quote) + "." + quote(name)
@@ -209,33 +289,42 @@ def format_table_name(compiler, name, schema):
return quote(name)
-def format_column_name(compiler, name):
+def format_column_name(
+ compiler: "DDLCompiler", name: Optional[Union["quoted_name", str]]
+) -> Union["quoted_name", str]:
return compiler.preparer.quote(name)
-def format_server_default(compiler, default):
+def format_server_default(
+ compiler: "DDLCompiler",
+ default: Optional[_ServerDefault],
+) -> str:
return compiler.get_column_default_string(
Column("x", Integer, server_default=default)
)
-def format_type(compiler, type_):
+def format_type(compiler: "DDLCompiler", type_: "TypeEngine") -> str:
return compiler.dialect.type_compiler.process(type_)
-def alter_table(compiler, name, schema):
+def alter_table(
+ compiler: "DDLCompiler",
+ name: str,
+ schema: Optional[str],
+) -> str:
return "ALTER TABLE %s" % format_table_name(compiler, name, schema)
-def drop_column(compiler, name):
+def drop_column(compiler: "DDLCompiler", name: str, **kw) -> str:
return "DROP COLUMN %s" % format_column_name(compiler, name)
-def alter_column(compiler, name):
+def alter_column(compiler: "DDLCompiler", name: str) -> str:
return "ALTER COLUMN %s" % format_column_name(compiler, name)
-def add_column(compiler, column, **kw):
+def add_column(compiler: "DDLCompiler", column: "Column", **kw) -> str:
text = "ADD COLUMN %s" % compiler.get_column_specification(column, **kw)
const = " ".join(
diff --git a/alembic/ddl/impl.py b/alembic/ddl/impl.py
index 710509c..2ca316c 100644
--- a/alembic/ddl/impl.py
+++ b/alembic/ddl/impl.py
@@ -1,5 +1,16 @@
from collections import namedtuple
import re
+from typing import Any
+from typing import Callable
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Sequence
+from typing import Set
+from typing import Tuple
+from typing import Type
+from typing import TYPE_CHECKING
+from typing import Union
from sqlalchemy import cast
from sqlalchemy import schema
@@ -11,16 +22,49 @@ from ..util import sqla_compat
from ..util.compat import string_types
from ..util.compat import text_type
+if TYPE_CHECKING:
+ from io import StringIO
+ from typing import Literal
+
+ from sqlalchemy.engine import Connection
+ from sqlalchemy.engine import Dialect
+ from sqlalchemy.engine.cursor import CursorResult
+ from sqlalchemy.engine.cursor import LegacyCursorResult
+ from sqlalchemy.engine.reflection import Inspector
+ from sqlalchemy.sql.dml import Update
+ from sqlalchemy.sql.elements import ClauseElement
+ from sqlalchemy.sql.elements import ColumnElement
+ from sqlalchemy.sql.elements import quoted_name
+ from sqlalchemy.sql.elements import TextClause
+ from sqlalchemy.sql.schema import Column
+ from sqlalchemy.sql.schema import Constraint
+ from sqlalchemy.sql.schema import ForeignKeyConstraint
+ from sqlalchemy.sql.schema import Index
+ from sqlalchemy.sql.schema import Table
+ from sqlalchemy.sql.schema import UniqueConstraint
+ from sqlalchemy.sql.selectable import TableClause
+ from sqlalchemy.sql.type_api import TypeEngine
+
+ from .base import _ServerDefault
+ from ..autogenerate.api import AutogenContext
+ from ..operations.batch import ApplyBatchImpl
+ from ..operations.batch import BatchOperationsImpl
+
class ImplMeta(type):
- def __init__(cls, classname, bases, dict_):
+ def __init__(
+ cls,
+ classname: str,
+ bases: Tuple[Type["DefaultImpl"]],
+ dict_: Dict[str, Any],
+ ):
newtype = type.__init__(cls, classname, bases, dict_)
if "__dialect__" in dict_:
_impls[dict_["__dialect__"]] = cls
return newtype
-_impls = {}
+_impls: dict = {}
Params = namedtuple("Params", ["token0", "tokens", "args", "kwargs"])
@@ -43,27 +87,27 @@ class DefaultImpl(metaclass=ImplMeta):
transactional_ddl = False
command_terminator = ";"
- type_synonyms = ({"NUMERIC", "DECIMAL"},)
- type_arg_extract = ()
+ type_synonyms: Tuple[Set[str], ...] = ({"NUMERIC", "DECIMAL"},)
+ type_arg_extract: Sequence[str] = ()
# on_null is known to be supported only by oracle
- identity_attrs_ignore = ("on_null",)
+ identity_attrs_ignore: Tuple[str, ...] = ("on_null",)
def __init__(
self,
- dialect,
- connection,
- as_sql,
- transactional_ddl,
- output_buffer,
- context_opts,
- ):
+ dialect: "Dialect",
+ connection: Optional["Connection"],
+ as_sql: bool,
+ transactional_ddl: Optional[bool],
+ output_buffer: Optional["StringIO"],
+ context_opts: Dict[str, Any],
+ ) -> None:
self.dialect = dialect
self.connection = connection
self.as_sql = as_sql
self.literal_binds = context_opts.get("literal_binds", False)
self.output_buffer = output_buffer
- self.memo = {}
+ self.memo: dict = {}
self.context_opts = context_opts
if transactional_ddl is not None:
self.transactional_ddl = transactional_ddl
@@ -75,14 +119,17 @@ class DefaultImpl(metaclass=ImplMeta):
)
@classmethod
- def get_by_dialect(cls, dialect):
+ def get_by_dialect(cls, dialect: "Dialect") -> Any:
return _impls[dialect.name]
- def static_output(self, text):
+ def static_output(self, text: str) -> None:
+ assert self.output_buffer is not None
self.output_buffer.write(text_type(text + "\n\n"))
self.output_buffer.flush()
- def requires_recreate_in_batch(self, batch_op):
+ def requires_recreate_in_batch(
+ self, batch_op: "BatchOperationsImpl"
+ ) -> bool:
"""Return True if the given :class:`.BatchOperationsImpl`
would need the table to be recreated and copied in order to
proceed.
@@ -93,7 +140,9 @@ class DefaultImpl(metaclass=ImplMeta):
"""
return False
- def prep_table_for_batch(self, batch_impl, table):
+ def prep_table_for_batch(
+ self, batch_impl: "ApplyBatchImpl", table: "Table"
+ ) -> None:
"""perform any operations needed on a table before a new
one is created to replace it in batch mode.
@@ -103,16 +152,16 @@ class DefaultImpl(metaclass=ImplMeta):
"""
@property
- def bind(self):
+ def bind(self) -> Optional["Connection"]:
return self.connection
def _exec(
self,
- construct,
- execution_options=None,
- multiparams=(),
- params=util.immutabledict(),
- ):
+ construct: Union["ClauseElement", str],
+ execution_options: None = None,
+ multiparams: Sequence[dict] = (),
+ params: Dict[str, int] = util.immutabledict(),
+ ) -> Optional[Union["LegacyCursorResult", "CursorResult"]]:
if isinstance(construct, string_types):
construct = text(construct)
if self.as_sql:
@@ -135,35 +184,43 @@ class DefaultImpl(metaclass=ImplMeta):
.strip()
+ self.command_terminator
)
+ return None
else:
conn = self.connection
+ assert conn is not None
if execution_options:
conn = conn.execution_options(**execution_options)
if params:
+ assert isinstance(multiparams, tuple)
multiparams += (params,)
return conn.execute(construct, multiparams)
- def execute(self, sql, execution_options=None):
+ def execute(
+ self,
+ sql: Union["Update", "TextClause", str],
+ execution_options: None = None,
+ ) -> None:
self._exec(sql, execution_options)
def alter_column(
self,
- table_name,
- column_name,
- nullable=None,
- server_default=False,
- name=None,
- type_=None,
- schema=None,
- autoincrement=None,
- comment=False,
- existing_comment=None,
- existing_type=None,
- existing_server_default=None,
- existing_nullable=None,
- existing_autoincrement=None,
- ):
+ table_name: str,
+ column_name: str,
+ nullable: Optional[bool] = None,
+ server_default: Union["_ServerDefault", "Literal[False]"] = False,
+ name: Optional[str] = None,
+ type_: Optional["TypeEngine"] = None,
+ schema: Optional[str] = None,
+ autoincrement: Optional[bool] = None,
+ comment: Optional[Union[str, "Literal[False]"]] = False,
+ existing_comment: Optional[str] = None,
+ existing_type: Optional["TypeEngine"] = None,
+ existing_server_default: Optional["_ServerDefault"] = None,
+ existing_nullable: Optional[bool] = None,
+ existing_autoincrement: Optional[bool] = None,
+ **kw: Any
+ ) -> None:
if autoincrement is not None or existing_autoincrement is not None:
util.warn(
"autoincrement and existing_autoincrement "
@@ -185,6 +242,13 @@ class DefaultImpl(metaclass=ImplMeta):
)
if server_default is not False:
kw = {}
+ cls_: Type[
+ Union[
+ base.ComputedColumnDefault,
+ base.IdentityColumnDefault,
+ base.ColumnDefault,
+ ]
+ ]
if sqla_compat._server_default_is_computed(
server_default, existing_server_default
):
@@ -200,7 +264,7 @@ class DefaultImpl(metaclass=ImplMeta):
cls_(
table_name,
column_name,
- server_default,
+ server_default, # type:ignore[arg-type]
schema=schema,
existing_type=existing_type,
existing_server_default=existing_server_default,
@@ -251,25 +315,41 @@ class DefaultImpl(metaclass=ImplMeta):
)
)
- def add_column(self, table_name, column, schema=None):
+ def add_column(
+ self,
+ table_name: str,
+ column: "Column",
+ schema: Optional[Union[str, "quoted_name"]] = None,
+ ) -> None:
self._exec(base.AddColumn(table_name, column, schema=schema))
- def drop_column(self, table_name, column, schema=None, **kw):
+ def drop_column(
+ self,
+ table_name: str,
+ column: "Column",
+ schema: Optional[str] = None,
+ **kw
+ ) -> None:
self._exec(base.DropColumn(table_name, column, schema=schema))
- def add_constraint(self, const):
+ def add_constraint(self, const: Any) -> None:
if const._create_rule is None or const._create_rule(self):
self._exec(schema.AddConstraint(const))
- def drop_constraint(self, const):
+ def drop_constraint(self, const: "Constraint") -> None:
self._exec(schema.DropConstraint(const))
- def rename_table(self, old_table_name, new_table_name, schema=None):
+ def rename_table(
+ self,
+ old_table_name: str,
+ new_table_name: Union[str, "quoted_name"],
+ schema: Optional[Union[str, "quoted_name"]] = None,
+ ) -> None:
self._exec(
base.RenameTable(old_table_name, new_table_name, schema=schema)
)
- def create_table(self, table):
+ def create_table(self, table: "Table") -> None:
table.dispatch.before_create(
table, self.connection, checkfirst=False, _ddl_runner=self
)
@@ -292,25 +372,30 @@ class DefaultImpl(metaclass=ImplMeta):
if comment and with_comment:
self.create_column_comment(column)
- def drop_table(self, table):
+ def drop_table(self, table: "Table") -> None:
self._exec(schema.DropTable(table))
- def create_index(self, index):
+ def create_index(self, index: "Index") -> None:
self._exec(schema.CreateIndex(index))
- def create_table_comment(self, table):
+ def create_table_comment(self, table: "Table") -> None:
self._exec(schema.SetTableComment(table))
- def drop_table_comment(self, table):
+ def drop_table_comment(self, table: "Table") -> None:
self._exec(schema.DropTableComment(table))
- def create_column_comment(self, column):
+ def create_column_comment(self, column: "ColumnElement") -> None:
self._exec(schema.SetColumnComment(column))
- def drop_index(self, index):
+ def drop_index(self, index: "Index") -> None:
self._exec(schema.DropIndex(index))
- def bulk_insert(self, table, rows, multiinsert=True):
+ def bulk_insert(
+ self,
+ table: Union["TableClause", "Table"],
+ rows: List[dict],
+ multiinsert: bool = True,
+ ) -> None:
if not isinstance(rows, list):
raise TypeError("List expected")
elif rows and not isinstance(rows[0], dict):
@@ -349,7 +434,7 @@ class DefaultImpl(metaclass=ImplMeta):
sqla_compat._insert_inline(table).values(**row)
)
- def _tokenize_column_type(self, column):
+ def _tokenize_column_type(self, column: "Column") -> Params:
definition = self.dialect.type_compiler.process(column.type).lower()
# tokenize the SQLAlchemy-generated version of a type, so that
@@ -387,7 +472,9 @@ class DefaultImpl(metaclass=ImplMeta):
return params
- def _column_types_match(self, inspector_params, metadata_params):
+ def _column_types_match(
+ self, inspector_params: "Params", metadata_params: "Params"
+ ) -> bool:
if inspector_params.token0 == metadata_params.token0:
return True
@@ -407,7 +494,9 @@ class DefaultImpl(metaclass=ImplMeta):
return True
return False
- def _column_args_match(self, inspected_params, meta_params):
+ def _column_args_match(
+ self, inspected_params: "Params", meta_params: "Params"
+ ) -> bool:
"""We want to compare column parameters. However, we only want
to compare parameters that are set. If they both have `collation`,
we want to make sure they are the same. However, if only one
@@ -438,7 +527,9 @@ class DefaultImpl(metaclass=ImplMeta):
return True
- def compare_type(self, inspector_column, metadata_column):
+ def compare_type(
+ self, inspector_column: "Column", metadata_column: "Column"
+ ) -> bool:
"""Returns True if there ARE differences between the types of the two
columns. Takes impl.type_synonyms into account between retrospected
and metadata types
@@ -463,11 +554,11 @@ class DefaultImpl(metaclass=ImplMeta):
def correct_for_autogen_constraints(
self,
- conn_uniques,
- conn_indexes,
- metadata_unique_constraints,
- metadata_indexes,
- ):
+ conn_uniques: Union[Set["UniqueConstraint"]],
+ conn_indexes: Union[Set["Index"]],
+ metadata_unique_constraints: Set["UniqueConstraint"],
+ metadata_indexes: Set["Index"],
+ ) -> None:
pass
def cast_for_batch_migrate(self, existing, existing_transfer, new_type):
@@ -476,7 +567,9 @@ class DefaultImpl(metaclass=ImplMeta):
existing_transfer["expr"], new_type
)
- def render_ddl_sql_expr(self, expr, is_server_default=False, **kw):
+ def render_ddl_sql_expr(
+ self, expr: "ClauseElement", is_server_default: bool = False, **kw
+ ) -> str:
"""Render a SQL expression that is typically a server default,
index expression, etc.
@@ -489,10 +582,16 @@ class DefaultImpl(metaclass=ImplMeta):
)
return text_type(expr.compile(dialect=self.dialect, **compile_kw))
- def _compat_autogen_column_reflect(self, inspector):
+ def _compat_autogen_column_reflect(
+ self, inspector: "Inspector"
+ ) -> Callable:
return self.autogen_column_reflect
- def correct_for_autogen_foreignkeys(self, conn_fks, metadata_fks):
+ def correct_for_autogen_foreignkeys(
+ self,
+ conn_fks: Set["ForeignKeyConstraint"],
+ metadata_fks: Set["ForeignKeyConstraint"],
+ ) -> None:
pass
def autogen_column_reflect(self, inspector, table, column_info):
@@ -504,7 +603,7 @@ class DefaultImpl(metaclass=ImplMeta):
"""
- def start_migrations(self):
+ def start_migrations(self) -> None:
"""A hook called when :meth:`.EnvironmentContext.run_migrations`
is called.
@@ -512,7 +611,7 @@ class DefaultImpl(metaclass=ImplMeta):
"""
- def emit_begin(self):
+ def emit_begin(self) -> None:
"""Emit the string ``BEGIN``, or the backend-specific
equivalent, on the current connection context.
@@ -522,7 +621,7 @@ class DefaultImpl(metaclass=ImplMeta):
"""
self.static_output("BEGIN" + self.command_terminator)
- def emit_commit(self):
+ def emit_commit(self) -> None:
"""Emit the string ``COMMIT``, or the backend-specific
equivalent, on the current connection context.
@@ -532,7 +631,9 @@ class DefaultImpl(metaclass=ImplMeta):
"""
self.static_output("COMMIT" + self.command_terminator)
- def render_type(self, type_obj, autogen_context):
+ def render_type(
+ self, type_obj: "TypeEngine", autogen_context: "AutogenContext"
+ ) -> Union[str, "Literal[False]"]:
return False
def _compare_identity_default(self, metadata_identity, inspector_identity):
diff --git a/alembic/ddl/mssql.py b/alembic/ddl/mssql.py
index 8a99ee6..9e1ef76 100644
--- a/alembic/ddl/mssql.py
+++ b/alembic/ddl/mssql.py
@@ -1,9 +1,15 @@
+from typing import Any
+from typing import List
+from typing import Optional
+from typing import TYPE_CHECKING
+from typing import Union
+
from sqlalchemy import types as sqltypes
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.schema import Column
from sqlalchemy.schema import CreateIndex
-from sqlalchemy.sql.expression import ClauseElement
-from sqlalchemy.sql.expression import Executable
+from sqlalchemy.sql.base import Executable
+from sqlalchemy.sql.elements import ClauseElement
from .base import AddColumn
from .base import alter_column
@@ -21,6 +27,20 @@ from .impl import DefaultImpl
from .. import util
from ..util import sqla_compat
+if TYPE_CHECKING:
+ from typing import Literal
+
+ from sqlalchemy.dialects.mssql.base import MSDDLCompiler
+ from sqlalchemy.dialects.mssql.base import MSSQLCompiler
+ from sqlalchemy.engine.cursor import CursorResult
+ from sqlalchemy.engine.cursor import LegacyCursorResult
+ from sqlalchemy.sql.schema import Index
+ from sqlalchemy.sql.schema import Table
+ from sqlalchemy.sql.selectable import TableClause
+ from sqlalchemy.sql.type_api import TypeEngine
+
+ from .base import _ServerDefault
+
class MSSQLImpl(DefaultImpl):
__dialect__ = "mssql"
@@ -40,40 +60,44 @@ class MSSQLImpl(DefaultImpl):
"order",
)
- def __init__(self, *arg, **kw):
+ def __init__(self, *arg, **kw) -> None:
super(MSSQLImpl, self).__init__(*arg, **kw)
self.batch_separator = self.context_opts.get(
"mssql_batch_separator", self.batch_separator
)
- def _exec(self, construct, *args, **kw):
+ def _exec(
+ self, construct: Any, *args, **kw
+ ) -> Optional[Union["LegacyCursorResult", "CursorResult"]]:
result = super(MSSQLImpl, self)._exec(construct, *args, **kw)
if self.as_sql and self.batch_separator:
self.static_output(self.batch_separator)
return result
- def emit_begin(self):
+ def emit_begin(self) -> None:
self.static_output("BEGIN TRANSACTION" + self.command_terminator)
- def emit_commit(self):
+ def emit_commit(self) -> None:
super(MSSQLImpl, self).emit_commit()
if self.as_sql and self.batch_separator:
self.static_output(self.batch_separator)
- def alter_column(
+ def alter_column( # type:ignore[override]
self,
- table_name,
- column_name,
- nullable=None,
- server_default=False,
- name=None,
- type_=None,
- schema=None,
- existing_type=None,
- existing_server_default=None,
- existing_nullable=None,
- **kw
- ):
+ table_name: str,
+ column_name: str,
+ nullable: Optional[bool] = None,
+ server_default: Optional[
+ Union["_ServerDefault", "Literal[False]"]
+ ] = False,
+ name: Optional[str] = None,
+ type_: Optional["TypeEngine"] = None,
+ schema: Optional[str] = None,
+ existing_type: Optional["TypeEngine"] = None,
+ existing_server_default: Optional["_ServerDefault"] = None,
+ existing_nullable: Optional[bool] = None,
+ **kw: Any
+ ) -> None:
if nullable is not None:
if existing_type is None:
@@ -138,17 +162,20 @@ class MSSQLImpl(DefaultImpl):
table_name, column_name, schema=schema, name=name
)
- def create_index(self, index):
+ def create_index(self, index: "Index") -> None:
# this likely defaults to None if not present, so get()
# should normally not return the default value. being
# defensive in any case
mssql_include = index.kwargs.get("mssql_include", None) or ()
+ assert index.table is not None
for col in mssql_include:
if col not in index.table.c:
index.table.append_column(Column(col, sqltypes.NullType))
self._exec(CreateIndex(index))
- def bulk_insert(self, table, rows, **kw):
+ def bulk_insert( # type:ignore[override]
+ self, table: Union["TableClause", "Table"], rows: List[dict], **kw: Any
+ ) -> None:
if self.as_sql:
self._exec(
"SET IDENTITY_INSERT %s ON"
@@ -162,7 +189,13 @@ class MSSQLImpl(DefaultImpl):
else:
super(MSSQLImpl, self).bulk_insert(table, rows, **kw)
- def drop_column(self, table_name, column, schema=None, **kw):
+ def drop_column(
+ self,
+ table_name: str,
+ column: "Column",
+ schema: Optional[str] = None,
+ **kw
+ ) -> None:
drop_default = kw.pop("mssql_drop_default", False)
if drop_default:
self._exec(
@@ -222,7 +255,13 @@ class MSSQLImpl(DefaultImpl):
class _ExecDropConstraint(Executable, ClauseElement):
- def __init__(self, tname, colname, type_, schema):
+ def __init__(
+ self,
+ tname: str,
+ colname: Union["Column", str],
+ type_: str,
+ schema: Optional[str],
+ ) -> None:
self.tname = tname
self.colname = colname
self.type_ = type_
@@ -230,14 +269,18 @@ class _ExecDropConstraint(Executable, ClauseElement):
class _ExecDropFKConstraint(Executable, ClauseElement):
- def __init__(self, tname, colname, schema):
+ def __init__(
+ self, tname: str, colname: "Column", schema: Optional[str]
+ ) -> None:
self.tname = tname
self.colname = colname
self.schema = schema
@compiles(_ExecDropConstraint, "mssql")
-def _exec_drop_col_constraint(element, compiler, **kw):
+def _exec_drop_col_constraint(
+ element: "_ExecDropConstraint", compiler: "MSSQLCompiler", **kw
+) -> str:
schema, tname, colname, type_ = (
element.schema,
element.tname,
@@ -261,7 +304,9 @@ exec('alter table %(tname_quoted)s drop constraint ' + @const_name)""" % {
@compiles(_ExecDropFKConstraint, "mssql")
-def _exec_drop_col_fk_constraint(element, compiler, **kw):
+def _exec_drop_col_fk_constraint(
+ element: "_ExecDropFKConstraint", compiler: "MSSQLCompiler", **kw
+) -> str:
schema, tname, colname = element.schema, element.tname, element.colname
return """declare @const_name varchar(256)
@@ -279,19 +324,23 @@ exec('alter table %(tname_quoted)s drop constraint ' + @const_name)""" % {
@compiles(AddColumn, "mssql")
-def visit_add_column(element, compiler, **kw):
+def visit_add_column(
+ element: "AddColumn", compiler: "MSDDLCompiler", **kw
+) -> str:
return "%s %s" % (
alter_table(compiler, element.table_name, element.schema),
mssql_add_column(compiler, element.column, **kw),
)
-def mssql_add_column(compiler, column, **kw):
+def mssql_add_column(compiler: "MSDDLCompiler", column: "Column", **kw) -> str:
return "ADD %s" % compiler.get_column_specification(column, **kw)
@compiles(ColumnNullable, "mssql")
-def visit_column_nullable(element, compiler, **kw):
+def visit_column_nullable(
+ element: "ColumnNullable", compiler: "MSDDLCompiler", **kw
+) -> str:
return "%s %s %s %s" % (
alter_table(compiler, element.table_name, element.schema),
alter_column(compiler, element.column_name),
@@ -301,7 +350,9 @@ def visit_column_nullable(element, compiler, **kw):
@compiles(ColumnDefault, "mssql")
-def visit_column_default(element, compiler, **kw):
+def visit_column_default(
+ element: "ColumnDefault", compiler: "MSDDLCompiler", **kw
+) -> str:
# TODO: there can also be a named constraint
# with ADD CONSTRAINT here
return "%s ADD DEFAULT %s FOR %s" % (
@@ -312,7 +363,9 @@ def visit_column_default(element, compiler, **kw):
@compiles(ColumnName, "mssql")
-def visit_rename_column(element, compiler, **kw):
+def visit_rename_column(
+ element: "ColumnName", compiler: "MSDDLCompiler", **kw
+) -> str:
return "EXEC sp_rename '%s.%s', %s, 'COLUMN'" % (
format_table_name(compiler, element.table_name, element.schema),
format_column_name(compiler, element.column_name),
@@ -321,7 +374,9 @@ def visit_rename_column(element, compiler, **kw):
@compiles(ColumnType, "mssql")
-def visit_column_type(element, compiler, **kw):
+def visit_column_type(
+ element: "ColumnType", compiler: "MSDDLCompiler", **kw
+) -> str:
return "%s %s %s" % (
alter_table(compiler, element.table_name, element.schema),
alter_column(compiler, element.column_name),
@@ -330,7 +385,9 @@ def visit_column_type(element, compiler, **kw):
@compiles(RenameTable, "mssql")
-def visit_rename_table(element, compiler, **kw):
+def visit_rename_table(
+ element: "RenameTable", compiler: "MSDDLCompiler", **kw
+) -> str:
return "EXEC sp_rename '%s', %s" % (
format_table_name(compiler, element.table_name, element.schema),
format_table_name(compiler, element.new_table_name, None),
diff --git a/alembic/ddl/mysql.py b/alembic/ddl/mysql.py
index 4761f75..9489560 100644
--- a/alembic/ddl/mysql.py
+++ b/alembic/ddl/mysql.py
@@ -1,4 +1,8 @@
import re
+from typing import Any
+from typing import Optional
+from typing import TYPE_CHECKING
+from typing import Union
from sqlalchemy import schema
from sqlalchemy import types as sqltypes
@@ -19,6 +23,16 @@ from ..util import sqla_compat
from ..util.sqla_compat import _is_mariadb
from ..util.sqla_compat import _is_type_bound
+if TYPE_CHECKING:
+ from typing import Literal
+
+ from sqlalchemy.dialects.mysql.base import MySQLDDLCompiler
+ from sqlalchemy.sql.ddl import DropConstraint
+ from sqlalchemy.sql.schema import Constraint
+ from sqlalchemy.sql.type_api import TypeEngine
+
+ from .base import _ServerDefault
+
class MySQLImpl(DefaultImpl):
__dialect__ = "mysql"
@@ -27,24 +41,24 @@ class MySQLImpl(DefaultImpl):
type_synonyms = DefaultImpl.type_synonyms + ({"BOOL", "TINYINT"},)
type_arg_extract = [r"character set ([\w\-_]+)", r"collate ([\w\-_]+)"]
- def alter_column(
+ def alter_column( # type:ignore[override]
self,
- table_name,
- column_name,
- nullable=None,
- server_default=False,
- name=None,
- type_=None,
- schema=None,
- existing_type=None,
- existing_server_default=None,
- existing_nullable=None,
- autoincrement=None,
- existing_autoincrement=None,
- comment=False,
- existing_comment=None,
- **kw
- ):
+ table_name: str,
+ column_name: str,
+ nullable: Optional[bool] = None,
+ server_default: Union["_ServerDefault", "Literal[False]"] = False,
+ name: Optional[str] = None,
+ type_: Optional["TypeEngine"] = None,
+ schema: Optional[str] = None,
+ existing_type: Optional["TypeEngine"] = None,
+ existing_server_default: Optional["_ServerDefault"] = None,
+ existing_nullable: Optional[bool] = None,
+ autoincrement: Optional[bool] = None,
+ existing_autoincrement: Optional[bool] = None,
+ comment: Optional[Union[str, "Literal[False]"]] = False,
+ existing_comment: Optional[str] = None,
+ **kw: Any
+ ) -> None:
if sqla_compat._server_default_is_identity(
server_default, existing_server_default
) or sqla_compat._server_default_is_computed(
@@ -126,16 +140,24 @@ class MySQLImpl(DefaultImpl):
)
)
- def drop_constraint(self, const):
+ def drop_constraint(
+ self,
+ const: "Constraint",
+ ) -> None:
if isinstance(const, schema.CheckConstraint) and _is_type_bound(const):
return
super(MySQLImpl, self).drop_constraint(const)
- def _is_mysql_allowed_functional_default(self, type_, server_default):
+ def _is_mysql_allowed_functional_default(
+ self,
+ type_: Optional["TypeEngine"],
+ server_default: Union["_ServerDefault", "Literal[False]"],
+ ) -> bool:
return (
type_ is not None
- and type_._type_affinity is sqltypes.DateTime
+ and type_._type_affinity # type:ignore[attr-defined]
+ is sqltypes.DateTime
and server_default is not None
)
@@ -268,7 +290,13 @@ class MariaDBImpl(MySQLImpl):
class MySQLAlterDefault(AlterColumn):
- def __init__(self, name, column_name, default, schema=None):
+ def __init__(
+ self,
+ name: str,
+ column_name: str,
+ default: "_ServerDefault",
+ schema: Optional[str] = None,
+ ) -> None:
super(AlterColumn, self).__init__(name, schema=schema)
self.column_name = column_name
self.default = default
@@ -277,16 +305,16 @@ class MySQLAlterDefault(AlterColumn):
class MySQLChangeColumn(AlterColumn):
def __init__(
self,
- name,
- column_name,
- schema=None,
- newname=None,
- type_=None,
- nullable=None,
- default=False,
- autoincrement=None,
- comment=False,
- ):
+ name: str,
+ column_name: str,
+ schema: Optional[str] = None,
+ newname: Optional[str] = None,
+ type_: Optional["TypeEngine"] = None,
+ nullable: Optional[bool] = None,
+ default: Optional[Union["_ServerDefault", "Literal[False]"]] = False,
+ autoincrement: Optional[bool] = None,
+ comment: Optional[Union[str, "Literal[False]"]] = False,
+ ) -> None:
super(AlterColumn, self).__init__(name, schema=schema)
self.column_name = column_name
self.nullable = nullable
@@ -318,7 +346,9 @@ def _mysql_doesnt_support_individual(element, compiler, **kw):
@compiles(MySQLAlterDefault, "mysql", "mariadb")
-def _mysql_alter_default(element, compiler, **kw):
+def _mysql_alter_default(
+ element: "MySQLAlterDefault", compiler: "MySQLDDLCompiler", **kw
+) -> str:
return "%s ALTER COLUMN %s %s" % (
alter_table(compiler, element.table_name, element.schema),
format_column_name(compiler, element.column_name),
@@ -329,7 +359,9 @@ def _mysql_alter_default(element, compiler, **kw):
@compiles(MySQLModifyColumn, "mysql", "mariadb")
-def _mysql_modify_column(element, compiler, **kw):
+def _mysql_modify_column(
+ element: "MySQLModifyColumn", compiler: "MySQLDDLCompiler", **kw
+) -> str:
return "%s MODIFY %s %s" % (
alter_table(compiler, element.table_name, element.schema),
format_column_name(compiler, element.column_name),
@@ -345,7 +377,9 @@ def _mysql_modify_column(element, compiler, **kw):
@compiles(MySQLChangeColumn, "mysql", "mariadb")
-def _mysql_change_column(element, compiler, **kw):
+def _mysql_change_column(
+ element: "MySQLChangeColumn", compiler: "MySQLDDLCompiler", **kw
+) -> str:
return "%s CHANGE %s %s %s" % (
alter_table(compiler, element.table_name, element.schema),
format_column_name(compiler, element.column_name),
@@ -362,8 +396,13 @@ def _mysql_change_column(element, compiler, **kw):
def _mysql_colspec(
- compiler, nullable, server_default, type_, autoincrement, comment
-):
+ compiler: "MySQLDDLCompiler",
+ nullable: Optional[bool],
+ server_default: Optional[Union["_ServerDefault", "Literal[False]"]],
+ type_: "TypeEngine",
+ autoincrement: Optional[bool],
+ comment: Optional[Union[str, "Literal[False]"]],
+) -> str:
spec = "%s %s" % (
compiler.dialect.type_compiler.process(type_),
"NULL" if nullable else "NOT NULL",
@@ -381,7 +420,9 @@ def _mysql_colspec(
@compiles(schema.DropConstraint, "mysql", "mariadb")
-def _mysql_drop_constraint(element, compiler, **kw):
+def _mysql_drop_constraint(
+ element: "DropConstraint", compiler: "MySQLDDLCompiler", **kw
+) -> str:
"""Redefine SQLAlchemy's drop constraint to
raise errors for invalid constraint type."""
@@ -394,7 +435,8 @@ def _mysql_drop_constraint(element, compiler, **kw):
schema.UniqueConstraint,
),
):
- return compiler.visit_drop_constraint(element, **kw)
+ assert not kw
+ return compiler.visit_drop_constraint(element)
elif isinstance(constraint, schema.CheckConstraint):
# note that SQLAlchemy as of 1.2 does not yet support
# DROP CONSTRAINT for MySQL/MariaDB, so we implement fully
diff --git a/alembic/ddl/oracle.py b/alembic/ddl/oracle.py
index 90f93d2..915edb8 100644
--- a/alembic/ddl/oracle.py
+++ b/alembic/ddl/oracle.py
@@ -1,3 +1,8 @@
+from typing import Any
+from typing import Optional
+from typing import TYPE_CHECKING
+from typing import Union
+
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql import sqltypes
@@ -16,6 +21,12 @@ from .base import IdentityColumnDefault
from .base import RenameTable
from .impl import DefaultImpl
+if TYPE_CHECKING:
+ from sqlalchemy.dialects.oracle.base import OracleDDLCompiler
+ from sqlalchemy.engine.cursor import CursorResult
+ from sqlalchemy.engine.cursor import LegacyCursorResult
+ from sqlalchemy.sql.schema import Column
+
class OracleImpl(DefaultImpl):
__dialect__ = "oracle"
@@ -28,27 +39,31 @@ class OracleImpl(DefaultImpl):
)
identity_attrs_ignore = ()
- def __init__(self, *arg, **kw):
+ def __init__(self, *arg, **kw) -> None:
super(OracleImpl, self).__init__(*arg, **kw)
self.batch_separator = self.context_opts.get(
"oracle_batch_separator", self.batch_separator
)
- def _exec(self, construct, *args, **kw):
+ def _exec(
+ self, construct: Any, *args, **kw
+ ) -> Optional[Union["LegacyCursorResult", "CursorResult"]]:
result = super(OracleImpl, self)._exec(construct, *args, **kw)
if self.as_sql and self.batch_separator:
self.static_output(self.batch_separator)
return result
- def emit_begin(self):
+ def emit_begin(self) -> None:
self._exec("SET TRANSACTION READ WRITE")
- def emit_commit(self):
+ def emit_commit(self) -> None:
self._exec("COMMIT")
@compiles(AddColumn, "oracle")
-def visit_add_column(element, compiler, **kw):
+def visit_add_column(
+ element: "AddColumn", compiler: "OracleDDLCompiler", **kw
+) -> str:
return "%s %s" % (
alter_table(compiler, element.table_name, element.schema),
add_column(compiler, element.column, **kw),
@@ -56,7 +71,9 @@ def visit_add_column(element, compiler, **kw):
@compiles(ColumnNullable, "oracle")
-def visit_column_nullable(element, compiler, **kw):
+def visit_column_nullable(
+ element: "ColumnNullable", compiler: "OracleDDLCompiler", **kw
+) -> str:
return "%s %s %s" % (
alter_table(compiler, element.table_name, element.schema),
alter_column(compiler, element.column_name),
@@ -65,7 +82,9 @@ def visit_column_nullable(element, compiler, **kw):
@compiles(ColumnType, "oracle")
-def visit_column_type(element, compiler, **kw):
+def visit_column_type(
+ element: "ColumnType", compiler: "OracleDDLCompiler", **kw
+) -> str:
return "%s %s %s" % (
alter_table(compiler, element.table_name, element.schema),
alter_column(compiler, element.column_name),
@@ -74,7 +93,9 @@ def visit_column_type(element, compiler, **kw):
@compiles(ColumnName, "oracle")
-def visit_column_name(element, compiler, **kw):
+def visit_column_name(
+ element: "ColumnName", compiler: "OracleDDLCompiler", **kw
+) -> str:
return "%s RENAME COLUMN %s TO %s" % (
alter_table(compiler, element.table_name, element.schema),
format_column_name(compiler, element.column_name),
@@ -83,7 +104,9 @@ def visit_column_name(element, compiler, **kw):
@compiles(ColumnDefault, "oracle")
-def visit_column_default(element, compiler, **kw):
+def visit_column_default(
+ element: "ColumnDefault", compiler: "OracleDDLCompiler", **kw
+) -> str:
return "%s %s %s" % (
alter_table(compiler, element.table_name, element.schema),
alter_column(compiler, element.column_name),
@@ -94,7 +117,9 @@ def visit_column_default(element, compiler, **kw):
@compiles(ColumnComment, "oracle")
-def visit_column_comment(element, compiler, **kw):
+def visit_column_comment(
+ element: "ColumnComment", compiler: "OracleDDLCompiler", **kw
+) -> str:
ddl = "COMMENT ON COLUMN {table_name}.{column_name} IS {comment}"
comment = compiler.sql_compiler.render_literal_value(
@@ -110,23 +135,27 @@ def visit_column_comment(element, compiler, **kw):
@compiles(RenameTable, "oracle")
-def visit_rename_table(element, compiler, **kw):
+def visit_rename_table(
+ element: "RenameTable", compiler: "OracleDDLCompiler", **kw
+) -> str:
return "%s RENAME TO %s" % (
alter_table(compiler, element.table_name, element.schema),
format_table_name(compiler, element.new_table_name, None),
)
-def alter_column(compiler, name):
+def alter_column(compiler: "OracleDDLCompiler", name: str) -> str:
return "MODIFY %s" % format_column_name(compiler, name)
-def add_column(compiler, column, **kw):
+def add_column(compiler: "OracleDDLCompiler", column: "Column", **kw) -> str:
return "ADD %s" % compiler.get_column_specification(column, **kw)
@compiles(IdentityColumnDefault, "oracle")
-def visit_identity_column(element, compiler, **kw):
+def visit_identity_column(
+ element: "IdentityColumnDefault", compiler: "OracleDDLCompiler", **kw
+):
text = "%s %s " % (
alter_table(compiler, element.table_name, element.schema),
alter_column(compiler, element.column_name),
diff --git a/alembic/ddl/postgresql.py b/alembic/ddl/postgresql.py
index 7468f08..c894649 100644
--- a/alembic/ddl/postgresql.py
+++ b/alembic/ddl/postgresql.py
@@ -1,5 +1,13 @@
import logging
import re
+from typing import Any
+from typing import cast
+from typing import List
+from typing import Optional
+from typing import Sequence
+from typing import Tuple
+from typing import TYPE_CHECKING
+from typing import Union
from sqlalchemy import Column
from sqlalchemy import Numeric
@@ -8,8 +16,8 @@ from sqlalchemy import types as sqltypes
from sqlalchemy.dialects.postgresql import BIGINT
from sqlalchemy.dialects.postgresql import ExcludeConstraint
from sqlalchemy.dialects.postgresql import INTEGER
-from sqlalchemy.sql.expression import ColumnClause
-from sqlalchemy.sql.expression import UnaryExpression
+from sqlalchemy.sql.elements import ColumnClause
+from sqlalchemy.sql.elements import UnaryExpression
from sqlalchemy.types import NULLTYPE
from .base import alter_column
@@ -32,6 +40,25 @@ from ..operations.base import Operations
from ..util import compat
from ..util import sqla_compat
+if TYPE_CHECKING:
+ from typing import Literal
+
+ from sqlalchemy.dialects.postgresql.array import ARRAY
+ from sqlalchemy.dialects.postgresql.base import PGDDLCompiler
+ from sqlalchemy.dialects.postgresql.hstore import HSTORE
+ from sqlalchemy.dialects.postgresql.json import JSON
+ from sqlalchemy.dialects.postgresql.json import JSONB
+ from sqlalchemy.sql.elements import BinaryExpression
+ from sqlalchemy.sql.elements import quoted_name
+ from sqlalchemy.sql.schema import MetaData
+ from sqlalchemy.sql.schema import Table
+ from sqlalchemy.sql.type_api import TypeEngine
+
+ from .base import _ServerDefault
+ from ..autogenerate.api import AutogenContext
+ from ..autogenerate.render import _f_name
+ from ..runtime.migration import MigrationContext
+
log = logging.getLogger(__name__)
@@ -94,22 +121,22 @@ class PostgresqlImpl(DefaultImpl):
)
)
- def alter_column(
+ def alter_column( # type:ignore[override]
self,
- table_name,
- column_name,
- nullable=None,
- server_default=False,
- name=None,
- type_=None,
- schema=None,
- autoincrement=None,
- existing_type=None,
- existing_server_default=None,
- existing_nullable=None,
- existing_autoincrement=None,
- **kw
- ):
+ table_name: str,
+ column_name: str,
+ nullable: Optional[bool] = None,
+ server_default: Union["_ServerDefault", "Literal[False]"] = False,
+ name: Optional[str] = None,
+ type_: Optional["TypeEngine"] = None,
+ schema: Optional[str] = None,
+ autoincrement: Optional[bool] = None,
+ existing_type: Optional["TypeEngine"] = None,
+ existing_server_default: Optional["_ServerDefault"] = None,
+ existing_nullable: Optional[bool] = None,
+ existing_autoincrement: Optional[bool] = None,
+ **kw: Any
+ ) -> None:
using = kw.pop("postgresql_using", None)
@@ -218,7 +245,9 @@ class PostgresqlImpl(DefaultImpl):
)
metadata_indexes.discard(idx)
- def render_type(self, type_, autogen_context):
+ def render_type(
+ self, type_: "TypeEngine", autogen_context: "AutogenContext"
+ ) -> Union[str, "Literal[False]"]:
mod = type(type_).__module__
if not mod.startswith("sqlalchemy.dialects.postgresql"):
return False
@@ -229,29 +258,51 @@ class PostgresqlImpl(DefaultImpl):
return False
- def _render_HSTORE_type(self, type_, autogen_context):
- return render._render_type_w_subtype(
- type_, autogen_context, "text_type", r"(.+?\(.*text_type=)"
+ def _render_HSTORE_type(
+ self, type_: "HSTORE", autogen_context: "AutogenContext"
+ ) -> str:
+ return cast(
+ str,
+ render._render_type_w_subtype(
+ type_, autogen_context, "text_type", r"(.+?\(.*text_type=)"
+ ),
)
- def _render_ARRAY_type(self, type_, autogen_context):
- return render._render_type_w_subtype(
- type_, autogen_context, "item_type", r"(.+?\()"
+ def _render_ARRAY_type(
+ self, type_: "ARRAY", autogen_context: "AutogenContext"
+ ) -> str:
+ return cast(
+ str,
+ render._render_type_w_subtype(
+ type_, autogen_context, "item_type", r"(.+?\()"
+ ),
)
- def _render_JSON_type(self, type_, autogen_context):
- return render._render_type_w_subtype(
- type_, autogen_context, "astext_type", r"(.+?\(.*astext_type=)"
+ def _render_JSON_type(
+ self, type_: "JSON", autogen_context: "AutogenContext"
+ ) -> str:
+ return cast(
+ str,
+ render._render_type_w_subtype(
+ type_, autogen_context, "astext_type", r"(.+?\(.*astext_type=)"
+ ),
)
- def _render_JSONB_type(self, type_, autogen_context):
- return render._render_type_w_subtype(
- type_, autogen_context, "astext_type", r"(.+?\(.*astext_type=)"
+ def _render_JSONB_type(
+ self, type_: "JSONB", autogen_context: "AutogenContext"
+ ) -> str:
+ return cast(
+ str,
+ render._render_type_w_subtype(
+ type_, autogen_context, "astext_type", r"(.+?\(.*astext_type=)"
+ ),
)
class PostgresqlColumnType(AlterColumn):
- def __init__(self, name, column_name, type_, **kw):
+ def __init__(
+ self, name: str, column_name: str, type_: "TypeEngine", **kw
+ ) -> None:
using = kw.pop("using", None)
super(PostgresqlColumnType, self).__init__(name, column_name, **kw)
self.type_ = sqltypes.to_instance(type_)
@@ -259,7 +310,9 @@ class PostgresqlColumnType(AlterColumn):
@compiles(RenameTable, "postgresql")
-def visit_rename_table(element, compiler, **kw):
+def visit_rename_table(
+ element: RenameTable, compiler: "PGDDLCompiler", **kw
+) -> str:
return "%s RENAME TO %s" % (
alter_table(compiler, element.table_name, element.schema),
format_table_name(compiler, element.new_table_name, None),
@@ -267,7 +320,9 @@ def visit_rename_table(element, compiler, **kw):
@compiles(PostgresqlColumnType, "postgresql")
-def visit_column_type(element, compiler, **kw):
+def visit_column_type(
+ element: PostgresqlColumnType, compiler: "PGDDLCompiler", **kw
+) -> str:
return "%s %s %s %s" % (
alter_table(compiler, element.table_name, element.schema),
alter_column(compiler, element.column_name),
@@ -277,7 +332,9 @@ def visit_column_type(element, compiler, **kw):
@compiles(ColumnComment, "postgresql")
-def visit_column_comment(element, compiler, **kw):
+def visit_column_comment(
+ element: "ColumnComment", compiler: "PGDDLCompiler", **kw
+) -> str:
ddl = "COMMENT ON COLUMN {table_name}.{column_name} IS {comment}"
comment = (
compiler.sql_compiler.render_literal_value(
@@ -297,7 +354,9 @@ def visit_column_comment(element, compiler, **kw):
@compiles(IdentityColumnDefault, "postgresql")
-def visit_identity_column(element, compiler, **kw):
+def visit_identity_column(
+ element: "IdentityColumnDefault", compiler: "PGDDLCompiler", **kw
+):
text = "%s %s " % (
alter_table(compiler, element.table_name, element.schema),
alter_column(compiler, element.column_name),
@@ -341,14 +400,17 @@ class CreateExcludeConstraintOp(ops.AddConstraintOp):
def __init__(
self,
- constraint_name,
- table_name,
- elements,
- where=None,
- schema=None,
- _orig_constraint=None,
+ constraint_name: Optional[str],
+ table_name: Union[str, "quoted_name"],
+ elements: Union[
+ Sequence[Tuple[str, str]],
+ Sequence[Tuple["ColumnClause", str]],
+ ],
+ where: Optional[Union["BinaryExpression", str]] = None,
+ schema: Optional[str] = None,
+ _orig_constraint: Optional["ExcludeConstraint"] = None,
**kw
- ):
+ ) -> None:
self.constraint_name = constraint_name
self.table_name = table_name
self.elements = elements
@@ -358,13 +420,18 @@ class CreateExcludeConstraintOp(ops.AddConstraintOp):
self.kw = kw
@classmethod
- def from_constraint(cls, constraint):
+ def from_constraint( # type:ignore[override]
+ cls, constraint: "ExcludeConstraint"
+ ) -> "CreateExcludeConstraintOp":
constraint_table = sqla_compat._table_for_constraint(constraint)
return cls(
constraint.name,
constraint_table.name,
- [(expr, op) for expr, name, op in constraint._render_exprs],
+ [
+ (expr, op)
+ for expr, name, op in constraint._render_exprs # type:ignore[attr-defined] # noqa
+ ],
where=constraint.where,
schema=constraint_table.schema,
_orig_constraint=constraint,
@@ -373,7 +440,9 @@ class CreateExcludeConstraintOp(ops.AddConstraintOp):
using=constraint.using,
)
- def to_constraint(self, migration_context=None):
+ def to_constraint(
+ self, migration_context: Optional["MigrationContext"] = None
+ ) -> "ExcludeConstraint":
if self._orig_constraint is not None:
return self._orig_constraint
schema_obj = schemaobj.SchemaObjects(migration_context)
@@ -384,15 +453,24 @@ class CreateExcludeConstraintOp(ops.AddConstraintOp):
where=self.where,
**self.kw
)
- for expr, name, oper in excl._render_exprs:
+ for (
+ expr,
+ name,
+ oper,
+ ) in excl._render_exprs: # type:ignore[attr-defined]
t.append_column(Column(name, NULLTYPE))
t.append_constraint(excl)
return excl
@classmethod
def create_exclude_constraint(
- cls, operations, constraint_name, table_name, *elements, **kw
- ):
+ cls,
+ operations: "Operations",
+ constraint_name: str,
+ table_name: str,
+ *elements: Any,
+ **kw: Any
+ ) -> Optional["Table"]:
"""Issue an alter to create an EXCLUDE constraint using the
current migration context.
@@ -453,14 +531,18 @@ class CreateExcludeConstraintOp(ops.AddConstraintOp):
@render.renderers.dispatch_for(CreateExcludeConstraintOp)
-def _add_exclude_constraint(autogen_context, op):
+def _add_exclude_constraint(
+ autogen_context: "AutogenContext", op: "CreateExcludeConstraintOp"
+) -> str:
return _exclude_constraint(op.to_constraint(), autogen_context, alter=True)
@render._constraint_renderers.dispatch_for(ExcludeConstraint)
def _render_inline_exclude_constraint(
- constraint, autogen_context, namespace_metadata
-):
+ constraint: "ExcludeConstraint",
+ autogen_context: "AutogenContext",
+ namespace_metadata: "MetaData",
+) -> str:
rendered = render._user_defined_render(
"exclude", constraint, autogen_context
)
@@ -470,7 +552,7 @@ def _render_inline_exclude_constraint(
return _exclude_constraint(constraint, autogen_context, False)
-def _postgresql_autogenerate_prefix(autogen_context):
+def _postgresql_autogenerate_prefix(autogen_context: "AutogenContext") -> str:
imports = autogen_context.imports
if imports is not None:
@@ -478,8 +560,12 @@ def _postgresql_autogenerate_prefix(autogen_context):
return "postgresql."
-def _exclude_constraint(constraint, autogen_context, alter):
- opts = []
+def _exclude_constraint(
+ constraint: "ExcludeConstraint",
+ autogen_context: "AutogenContext",
+ alter: bool,
+) -> str:
+ opts: List[Tuple[str, Union[quoted_name, str, _f_name, None]]] = []
has_batch = autogen_context._has_batch
@@ -509,7 +595,7 @@ def _exclude_constraint(constraint, autogen_context, alter):
_render_potential_column(sqltext, autogen_context),
opstring,
)
- for sqltext, name, opstring in constraint._render_exprs
+ for sqltext, name, opstring in constraint._render_exprs # type:ignore[attr-defined] # noqa
]
)
if constraint.where is not None:
@@ -528,7 +614,7 @@ def _exclude_constraint(constraint, autogen_context, alter):
args = [
"(%s, %r)"
% (_render_potential_column(sqltext, autogen_context), opstring)
- for sqltext, name, opstring in constraint._render_exprs
+ for sqltext, name, opstring in constraint._render_exprs # type:ignore[attr-defined] # noqa
]
if constraint.where is not None:
args.append(
@@ -544,7 +630,9 @@ def _exclude_constraint(constraint, autogen_context, alter):
}
-def _render_potential_column(value, autogen_context):
+def _render_potential_column(
+ value: Union["ColumnClause", "Column"], autogen_context: "AutogenContext"
+) -> str:
if isinstance(value, ColumnClause):
template = "%(prefix)scolumn(%(name)r)"
diff --git a/alembic/ddl/sqlite.py b/alembic/ddl/sqlite.py
index cb790ea..2f4ed77 100644
--- a/alembic/ddl/sqlite.py
+++ b/alembic/ddl/sqlite.py
@@ -1,4 +1,9 @@
import re
+from typing import Any
+from typing import Dict
+from typing import Optional
+from typing import TYPE_CHECKING
+from typing import Union
from sqlalchemy import cast
from sqlalchemy import JSON
@@ -8,6 +13,17 @@ from sqlalchemy import sql
from .impl import DefaultImpl
from .. import util
+if TYPE_CHECKING:
+ from sqlalchemy.engine.reflection import Inspector
+ from sqlalchemy.sql.elements import Cast
+ from sqlalchemy.sql.elements import ClauseElement
+ from sqlalchemy.sql.schema import Column
+ from sqlalchemy.sql.schema import Constraint
+ from sqlalchemy.sql.schema import Table
+ from sqlalchemy.sql.type_api import TypeEngine
+
+ from ..operations.batch import BatchOperationsImpl
+
class SQLiteImpl(DefaultImpl):
__dialect__ = "sqlite"
@@ -17,7 +33,9 @@ class SQLiteImpl(DefaultImpl):
see: http://bugs.python.org/issue10740
"""
- def requires_recreate_in_batch(self, batch_op):
+ def requires_recreate_in_batch(
+ self, batch_op: "BatchOperationsImpl"
+ ) -> bool:
"""Return True if the given :class:`.BatchOperationsImpl`
would need the table to be recreated and copied in order to
proceed.
@@ -44,16 +62,16 @@ class SQLiteImpl(DefaultImpl):
else:
return False
- def add_constraint(self, const):
+ def add_constraint(self, const: "Constraint"):
# attempt to distinguish between an
# auto-gen constraint and an explicit one
- if const._create_rule is None:
+ if const._create_rule is None: # type:ignore[attr-defined]
raise NotImplementedError(
"No support for ALTER of constraints in SQLite dialect"
"Please refer to the batch mode feature which allows for "
"SQLite migrations using a copy-and-move strategy."
)
- elif const._create_rule(self):
+ elif const._create_rule(self): # type:ignore[attr-defined]
util.warn(
"Skipping unsupported ALTER for "
"creation of implicit constraint"
@@ -61,8 +79,8 @@ class SQLiteImpl(DefaultImpl):
"SQLite migrations using a copy-and-move strategy."
)
- def drop_constraint(self, const):
- if const._create_rule is None:
+ def drop_constraint(self, const: "Constraint"):
+ if const._create_rule is None: # type:ignore[attr-defined]
raise NotImplementedError(
"No support for ALTER of constraints in SQLite dialect"
"Please refer to the batch mode feature which allows for "
@@ -71,11 +89,11 @@ class SQLiteImpl(DefaultImpl):
def compare_server_default(
self,
- inspector_column,
- metadata_column,
- rendered_metadata_default,
- rendered_inspector_default,
- ):
+ inspector_column: "Column",
+ metadata_column: "Column",
+ rendered_metadata_default: Optional[str],
+ rendered_inspector_default: Optional[str],
+ ) -> bool:
if rendered_metadata_default is not None:
rendered_metadata_default = re.sub(
@@ -93,7 +111,9 @@ class SQLiteImpl(DefaultImpl):
return rendered_inspector_default != rendered_metadata_default
- def _guess_if_default_is_unparenthesized_sql_expr(self, expr):
+ def _guess_if_default_is_unparenthesized_sql_expr(
+ self, expr: Optional[str]
+ ) -> bool:
"""Determine if a server default is a SQL expression or a constant.
There are too many assertions that expect server defaults to round-trip
@@ -112,7 +132,12 @@ class SQLiteImpl(DefaultImpl):
else:
return True
- def autogen_column_reflect(self, inspector, table, column_info):
+ def autogen_column_reflect(
+ self,
+ inspector: "Inspector",
+ table: "Table",
+ column_info: Dict[str, Any],
+ ) -> None:
# SQLite expression defaults require parenthesis when sent
# as DDL
if self._guess_if_default_is_unparenthesized_sql_expr(
@@ -120,7 +145,9 @@ class SQLiteImpl(DefaultImpl):
):
column_info["default"] = "(%s)" % (column_info["default"],)
- def render_ddl_sql_expr(self, expr, is_server_default=False, **kw):
+ def render_ddl_sql_expr(
+ self, expr: "ClauseElement", is_server_default: bool = False, **kw
+ ) -> str:
# SQLite expression defaults require parenthesis when sent
# as DDL
str_expr = super(SQLiteImpl, self).render_ddl_sql_expr(
@@ -134,9 +161,15 @@ class SQLiteImpl(DefaultImpl):
str_expr = "(%s)" % (str_expr,)
return str_expr
- def cast_for_batch_migrate(self, existing, existing_transfer, new_type):
+ def cast_for_batch_migrate(
+ self,
+ existing: "Column",
+ existing_transfer: Dict[str, Union["TypeEngine", "Cast"]],
+ new_type: "TypeEngine",
+ ) -> None:
if (
- existing.type._type_affinity is not new_type._type_affinity
+ existing.type._type_affinity # type:ignore[attr-defined]
+ is not new_type._type_affinity # type:ignore[attr-defined]
and not isinstance(new_type, JSON)
):
existing_transfer["expr"] = cast(
diff --git a/alembic/environment.py b/alembic/environment.py
new file mode 100644
index 0000000..adfc93e
--- /dev/null
+++ b/alembic/environment.py
@@ -0,0 +1 @@
+from .runtime.environment import * # noqa
diff --git a/alembic/migration.py b/alembic/migration.py
new file mode 100644
index 0000000..02626e2
--- /dev/null
+++ b/alembic/migration.py
@@ -0,0 +1 @@
+from .runtime.migration import * # noqa
diff --git a/alembic/operations/base.py b/alembic/operations/base.py
index cd14080..d4ec7b1 100644
--- a/alembic/operations/base.py
+++ b/alembic/operations/base.py
@@ -1,5 +1,13 @@
from contextlib import contextmanager
import textwrap
+from typing import Any
+from typing import Callable
+from typing import Iterator
+from typing import Optional
+from typing import TYPE_CHECKING
+from typing import Union
+
+from sqlalchemy.sql.elements import conv
from . import batch
from . import schemaobj
@@ -8,12 +16,15 @@ from ..util import sqla_compat
from ..util.compat import inspect_formatargspec
from ..util.compat import inspect_getargspec
-__all__ = ("Operations", "BatchOperations")
+if TYPE_CHECKING:
+ from sqlalchemy.engine import Connection
-try:
- from sqlalchemy.sql.naming import conv
-except:
- conv = None
+ from .batch import BatchOperationsImpl
+ from .ops import MigrateOperation
+ from ..runtime.migration import MigrationContext
+ from ..util.sqla_compat import _literal_bindparam
+
+__all__ = ("Operations", "BatchOperations")
class Operations(util.ModuleClsProxy):
@@ -49,7 +60,11 @@ class Operations(util.ModuleClsProxy):
_to_impl = util.Dispatcher()
- def __init__(self, migration_context, impl=None):
+ def __init__(
+ self,
+ migration_context: "MigrationContext",
+ impl: Optional["BatchOperationsImpl"] = None,
+ ) -> None:
"""Construct a new :class:`.Operations`
:param migration_context: a :class:`.MigrationContext`
@@ -65,7 +80,9 @@ class Operations(util.ModuleClsProxy):
self.schema_obj = schemaobj.SchemaObjects(migration_context)
@classmethod
- def register_operation(cls, name, sourcename=None):
+ def register_operation(
+ cls, name: str, sourcename: Optional[str] = None
+ ) -> Callable:
"""Register a new operation for this class.
This method is normally used to add new operations
@@ -142,7 +159,7 @@ class Operations(util.ModuleClsProxy):
return register
@classmethod
- def implementation_for(cls, op_cls):
+ def implementation_for(cls, op_cls: Any) -> Callable:
"""Register an implementation for a given :class:`.MigrateOperation`.
This is part of the operation extensibility API.
@@ -161,7 +178,9 @@ class Operations(util.ModuleClsProxy):
@classmethod
@contextmanager
- def context(cls, migration_context):
+ def context(
+ cls, migration_context: "MigrationContext"
+ ) -> Iterator["Operations"]:
op = Operations(migration_context)
op._install_proxy()
yield op
@@ -342,7 +361,7 @@ class Operations(util.ModuleClsProxy):
return self.migration_context
- def invoke(self, operation):
+ def invoke(self, operation: "MigrateOperation") -> Any:
"""Given a :class:`.MigrateOperation`, invoke it in terms of
this :class:`.Operations` instance.
@@ -352,7 +371,7 @@ class Operations(util.ModuleClsProxy):
)
return fn(self, operation)
- def f(self, name):
+ def f(self, name: str) -> "conv":
"""Indicate a string name that has already had a naming convention
applied to it.
@@ -385,20 +404,14 @@ class Operations(util.ModuleClsProxy):
CONSTRAINT ck_bool_t_x CHECK (x in (1, 0)))
The function is rendered in the output of autogenerate when
- a particular constraint name is already converted, for SQLAlchemy
- version **0.9.4 and greater only**. Even though ``naming_convention``
- was introduced in 0.9.2, the string disambiguation service is new
- as of 0.9.4.
+ a particular constraint name is already converted.
"""
- if conv:
- return conv(name)
- else:
- raise NotImplementedError(
- "op.f() feature requires SQLAlchemy 0.9.4 or greater."
- )
+ return conv(name)
- def inline_literal(self, value, type_=None):
+ def inline_literal(
+ self, value: Union[str, int], type_: None = None
+ ) -> "_literal_bindparam":
r"""Produce an 'inline literal' expression, suitable for
using in an INSERT, UPDATE, or DELETE statement.
@@ -442,7 +455,7 @@ class Operations(util.ModuleClsProxy):
"""
return sqla_compat._literal_bindparam(None, value, type_=type_)
- def get_bind(self):
+ def get_bind(self) -> "Connection":
"""Return the current 'bind'.
Under normal circumstances, this is the
diff --git a/alembic/operations/batch.py b/alembic/operations/batch.py
index 656b868..ee1fe05 100644
--- a/alembic/operations/batch.py
+++ b/alembic/operations/batch.py
@@ -1,3 +1,12 @@
+from typing import Any
+from typing import cast
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import TYPE_CHECKING
+from typing import Union
+
from sqlalchemy import CheckConstraint
from sqlalchemy import Column
from sqlalchemy import ForeignKeyConstraint
@@ -21,6 +30,18 @@ from ..util.sqla_compat import _is_type_bound
from ..util.sqla_compat import _remove_column_from_collection
from ..util.sqla_compat import _select
+if TYPE_CHECKING:
+ from typing import Literal
+
+ from sqlalchemy.engine import Dialect
+ from sqlalchemy.sql.elements import ColumnClause
+ from sqlalchemy.sql.elements import quoted_name
+ from sqlalchemy.sql.functions import Function
+ from sqlalchemy.sql.schema import Constraint
+ from sqlalchemy.sql.type_api import TypeEngine
+
+ from ..ddl.impl import DefaultImpl
+
class BatchOperationsImpl:
def __init__(
@@ -61,14 +82,14 @@ class BatchOperationsImpl:
self.batch = []
@property
- def dialect(self):
+ def dialect(self) -> "Dialect":
return self.operations.impl.dialect
@property
- def impl(self):
+ def impl(self) -> "DefaultImpl":
return self.operations.impl
- def _should_recreate(self):
+ def _should_recreate(self) -> bool:
if self.recreate == "auto":
return self.operations.impl.requires_recreate_in_batch(self)
elif self.recreate == "always":
@@ -76,7 +97,7 @@ class BatchOperationsImpl:
else:
return False
- def flush(self):
+ def flush(self) -> None:
should_recreate = self._should_recreate()
with _ensure_scope_for_ddl(self.impl.connection):
@@ -118,10 +139,10 @@ class BatchOperationsImpl:
batch_impl._create(self.impl)
- def alter_column(self, *arg, **kw):
+ def alter_column(self, *arg, **kw) -> None:
self.batch.append(("alter_column", arg, kw))
- def add_column(self, *arg, **kw):
+ def add_column(self, *arg, **kw) -> None:
if (
"insert_before" in kw or "insert_after" in kw
) and not self._should_recreate():
@@ -131,22 +152,22 @@ class BatchOperationsImpl:
)
self.batch.append(("add_column", arg, kw))
- def drop_column(self, *arg, **kw):
+ def drop_column(self, *arg, **kw) -> None:
self.batch.append(("drop_column", arg, kw))
- def add_constraint(self, const):
+ def add_constraint(self, const: "Constraint") -> None:
self.batch.append(("add_constraint", (const,), {}))
- def drop_constraint(self, const):
+ def drop_constraint(self, const: "Constraint") -> None:
self.batch.append(("drop_constraint", (const,), {}))
def rename_table(self, *arg, **kw):
self.batch.append(("rename_table", arg, kw))
- def create_index(self, idx):
+ def create_index(self, idx: "Index") -> None:
self.batch.append(("create_index", (idx,), {}))
- def drop_index(self, idx):
+ def drop_index(self, idx: "Index") -> None:
self.batch.append(("drop_index", (idx,), {}))
def create_table_comment(self, table):
@@ -168,22 +189,24 @@ class BatchOperationsImpl:
class ApplyBatchImpl:
def __init__(
self,
- impl,
- table,
- table_args,
- table_kwargs,
- reflected,
- partial_reordering=(),
- ):
+ impl: "DefaultImpl",
+ table: "Table",
+ table_args: tuple,
+ table_kwargs: Dict[str, Any],
+ reflected: bool,
+ partial_reordering: tuple = (),
+ ) -> None:
self.impl = impl
self.table = table # this is a Table object
self.table_args = table_args
self.table_kwargs = table_kwargs
self.temp_table_name = self._calc_temp_name(table.name)
- self.new_table = None
+ self.new_table: Optional[Table] = None
self.partial_reordering = partial_reordering # tuple of tuples
- self.add_col_ordering = () # tuple of tuples
+ self.add_col_ordering: Tuple[
+ Tuple[str, str], ...
+ ] = () # tuple of tuples
self.column_transfers = OrderedDict(
(c.name, {"expr": c}) for c in self.table.c
@@ -194,12 +217,12 @@ class ApplyBatchImpl:
self._grab_table_elements()
@classmethod
- def _calc_temp_name(cls, tablename):
+ def _calc_temp_name(cls, tablename: "quoted_name") -> str:
return ("_alembic_tmp_%s" % tablename)[0:50]
- def _grab_table_elements(self):
+ def _grab_table_elements(self) -> None:
schema = self.table.schema
- self.columns = OrderedDict()
+ self.columns: Dict[str, "Column"] = OrderedDict()
for c in self.table.c:
c_copy = _copy(c, schema=schema)
c_copy.unique = c_copy.index = False
@@ -208,11 +231,11 @@ class ApplyBatchImpl:
if isinstance(c.type, SchemaEventTarget):
assert c_copy.type is not c.type
self.columns[c.name] = c_copy
- self.named_constraints = {}
+ self.named_constraints: Dict[str, "Constraint"] = {}
self.unnamed_constraints = []
self.col_named_constraints = {}
- self.indexes = {}
- self.new_indexes = {}
+ self.indexes: Dict[str, "Index"] = {}
+ self.new_indexes: Dict[str, "Index"] = {}
for const in self.table.constraints:
if _is_type_bound(const):
@@ -238,7 +261,7 @@ class ApplyBatchImpl:
for k in self.table.kwargs:
self.table_kwargs.setdefault(k, self.table.kwargs[k])
- def _adjust_self_columns_for_partial_reordering(self):
+ def _adjust_self_columns_for_partial_reordering(self) -> None:
pairs = set()
col_by_idx = list(self.columns)
@@ -258,17 +281,17 @@ class ApplyBatchImpl:
# this can happen if some columns were dropped and not removed
# from existing_ordering. this should be prevented already, but
# conservatively making sure this didn't happen
- pairs = [p for p in pairs if p[0] != p[1]]
+ pairs_list = [p for p in pairs if p[0] != p[1]]
sorted_ = list(
- topological.sort(pairs, col_by_idx, deterministic_order=True)
+ topological.sort(pairs_list, col_by_idx, deterministic_order=True)
)
self.columns = OrderedDict((k, self.columns[k]) for k in sorted_)
self.column_transfers = OrderedDict(
(k, self.column_transfers[k]) for k in sorted_
)
- def _transfer_elements_to_new_table(self):
+ def _transfer_elements_to_new_table(self) -> None:
assert self.new_table is None, "Can only create new table once"
m = MetaData()
@@ -296,6 +319,7 @@ class ApplyBatchImpl:
if not const_columns.issubset(self.column_transfers):
continue
+ const_copy: "Constraint"
if isinstance(const, ForeignKeyConstraint):
if _fk_is_self_referential(const):
# for self-referential constraint, refer to the
@@ -320,8 +344,9 @@ class ApplyBatchImpl:
self._setup_referent(m, const)
new_table.append_constraint(const_copy)
- def _gather_indexes_from_both_tables(self):
- idx = []
+ def _gather_indexes_from_both_tables(self) -> List["Index"]:
+ assert self.new_table is not None
+ idx: List[Index] = []
idx.extend(self.indexes.values())
for index in self.new_indexes.values():
idx.append(
@@ -334,8 +359,12 @@ class ApplyBatchImpl:
)
return idx
- def _setup_referent(self, metadata, constraint):
- spec = constraint.elements[0]._get_colspec()
+ def _setup_referent(
+ self, metadata: "MetaData", constraint: "ForeignKeyConstraint"
+ ) -> None:
+ spec = constraint.elements[
+ 0
+ ]._get_colspec() # type:ignore[attr-defined]
parts = spec.split(".")
tname = parts[-2]
if len(parts) == 3:
@@ -345,10 +374,14 @@ class ApplyBatchImpl:
if tname != self.temp_table_name:
key = sql_schema._get_table_key(tname, referent_schema)
+
+ def colspec(elem: Any):
+ return elem._get_colspec()
+
if key in metadata.tables:
t = metadata.tables[key]
for elem in constraint.elements:
- colname = elem._get_colspec().split(".")[-1]
+ colname = colspec(elem).split(".")[-1]
if colname not in t.c:
t.append_column(Column(colname, sqltypes.NULLTYPE))
else:
@@ -358,17 +391,18 @@ class ApplyBatchImpl:
*[
Column(n, sqltypes.NULLTYPE)
for n in [
- elem._get_colspec().split(".")[-1]
+ colspec(elem).split(".")[-1]
for elem in constraint.elements
]
],
schema=referent_schema
)
- def _create(self, op_impl):
+ def _create(self, op_impl: "DefaultImpl") -> None:
self._transfer_elements_to_new_table()
op_impl.prep_table_for_batch(self, self.table)
+ assert self.new_table is not None
op_impl.create_table(self.new_table)
try:
@@ -405,18 +439,18 @@ class ApplyBatchImpl:
def alter_column(
self,
- table_name,
- column_name,
- nullable=None,
- server_default=False,
- name=None,
- type_=None,
- autoincrement=None,
- comment=False,
+ table_name: str,
+ column_name: str,
+ nullable: Optional[bool] = None,
+ server_default: Optional[Union["Function", str, bool]] = False,
+ name: Optional[str] = None,
+ type_: Optional["TypeEngine"] = None,
+ autoincrement: None = None,
+ comment: Union[str, "Literal[False]"] = False,
**kw
- ):
+ ) -> None:
existing = self.columns[column_name]
- existing_transfer = self.column_transfers[column_name]
+ existing_transfer: Dict[str, Any] = self.column_transfers[column_name]
if name is not None and name != column_name:
# note that we don't change '.key' - we keep referring
# to the renamed column by its old key in _create(). neat!
@@ -431,8 +465,8 @@ class ApplyBatchImpl:
# we also ignore the drop_constraint that will come here from
# Operations.implementation_for(alter_column)
if isinstance(existing.type, SchemaEventTarget):
- existing.type._create_events = (
- existing.type.create_constraint
+ existing.type._create_events = ( # type:ignore[attr-defined]
+ existing.type.create_constraint # type:ignore[attr-defined] # noqa
) = False
self.impl.cast_for_batch_migrate(
@@ -452,7 +486,11 @@ class ApplyBatchImpl:
if server_default is None:
existing.server_default = None
else:
- sql_schema.DefaultClause(server_default)._set_parent(existing)
+ sql_schema.DefaultClause(
+ server_default
+ )._set_parent( # type:ignore[attr-defined]
+ existing
+ )
if autoincrement is not None:
existing.autoincrement = bool(autoincrement)
@@ -460,8 +498,11 @@ class ApplyBatchImpl:
existing.comment = comment
def _setup_dependencies_for_add_column(
- self, colname, insert_before, insert_after
- ):
+ self,
+ colname: str,
+ insert_before: Optional[str],
+ insert_after: Optional[str],
+ ) -> None:
index_cols = self.existing_ordering
col_indexes = {name: i for i, name in enumerate(index_cols)}
@@ -505,8 +546,13 @@ class ApplyBatchImpl:
self.add_col_ordering += ((index_cols[-1], colname),)
def add_column(
- self, table_name, column, insert_before=None, insert_after=None, **kw
- ):
+ self,
+ table_name: str,
+ column: "Column",
+ insert_before: Optional[str] = None,
+ insert_after: Optional[str] = None,
+ **kw
+ ) -> None:
self._setup_dependencies_for_add_column(
column.name, insert_before, insert_after
)
@@ -515,7 +561,9 @@ class ApplyBatchImpl:
self.columns[column.name] = _copy(column, schema=self.table.schema)
self.column_transfers[column.name] = {}
- def drop_column(self, table_name, column, **kw):
+ def drop_column(
+ self, table_name: str, column: Union["ColumnClause", "Column"], **kw
+ ) -> None:
if column.name in self.table.primary_key.columns:
_remove_column_from_collection(
self.table.primary_key.columns, column
@@ -546,7 +594,7 @@ class ApplyBatchImpl:
"""
- def add_constraint(self, const):
+ def add_constraint(self, const: "Constraint") -> None:
if not const.name:
raise ValueError("Constraint must have a name")
if isinstance(const, sql_schema.PrimaryKeyConstraint):
@@ -555,7 +603,7 @@ class ApplyBatchImpl:
self.named_constraints[const.name] = const
- def drop_constraint(self, const):
+ def drop_constraint(self, const: "Constraint") -> None:
if not const.name:
raise ValueError("Constraint must have a name")
try:
@@ -566,7 +614,7 @@ class ApplyBatchImpl:
if col_const.name == const.name:
self.columns[col.name].constraints.remove(col_const)
else:
- const = self.named_constraints.pop(const.name)
+ const = self.named_constraints.pop(cast(str, const.name))
except KeyError:
if _is_type_bound(const):
# type-bound constraints are only included in the new
@@ -580,10 +628,10 @@ class ApplyBatchImpl:
for col in const.columns:
self.columns[col.name].primary_key = False
- def create_index(self, idx):
+ def create_index(self, idx: "Index") -> None:
self.new_indexes[idx.name] = idx
- def drop_index(self, idx):
+ def drop_index(self, idx: "Index") -> None:
try:
del self.indexes[idx.name]
except KeyError:
diff --git a/alembic/operations/ops.py b/alembic/operations/ops.py
index 8979355..d5ddbc9 100644
--- a/alembic/operations/ops.py
+++ b/alembic/operations/ops.py
@@ -1,4 +1,19 @@
+from abc import abstractmethod
import re
+from typing import Any
+from typing import Callable
+from typing import cast
+from typing import FrozenSet
+from typing import Iterator
+from typing import List
+from typing import MutableMapping
+from typing import Optional
+from typing import Sequence
+from typing import Set
+from typing import Tuple
+from typing import Type
+from typing import TYPE_CHECKING
+from typing import Union
from sqlalchemy.types import NULLTYPE
@@ -8,6 +23,33 @@ from .base import Operations
from .. import util
from ..util import sqla_compat
+if TYPE_CHECKING:
+ from sqlalchemy.sql.dml import Insert
+ from sqlalchemy.sql.dml import Update
+ from sqlalchemy.sql.elements import BinaryExpression
+ from sqlalchemy.sql.elements import ColumnElement
+ from sqlalchemy.sql.elements import conv
+ from sqlalchemy.sql.elements import quoted_name
+ from sqlalchemy.sql.elements import TextClause
+ from sqlalchemy.sql.functions import Function
+ from sqlalchemy.sql.schema import CheckConstraint
+ from sqlalchemy.sql.schema import Column
+ from sqlalchemy.sql.schema import Computed
+ from sqlalchemy.sql.schema import Constraint
+ from sqlalchemy.sql.schema import ForeignKeyConstraint
+ from sqlalchemy.sql.schema import Identity
+ from sqlalchemy.sql.schema import Index
+ from sqlalchemy.sql.schema import MetaData
+ from sqlalchemy.sql.schema import PrimaryKeyConstraint
+ from sqlalchemy.sql.schema import SchemaItem
+ from sqlalchemy.sql.schema import Table
+ from sqlalchemy.sql.schema import UniqueConstraint
+ from sqlalchemy.sql.selectable import TableClause
+ from sqlalchemy.sql.type_api import TypeEngine
+
+ from ..autogenerate.rewriter import Rewriter
+ from ..runtime.migration import MigrationContext
+
class MigrateOperation:
"""base class for migration command and organization objects.
@@ -32,7 +74,13 @@ class MigrateOperation:
"""
return {}
- _mutations = frozenset()
+ _mutations: FrozenSet["Rewriter"] = frozenset()
+
+ def reverse(self) -> "MigrateOperation":
+ raise NotImplementedError
+
+ def to_diff_tuple(self) -> Tuple[Any, ...]:
+ raise NotImplementedError
class AddConstraintOp(MigrateOperation):
@@ -45,7 +93,7 @@ class AddConstraintOp(MigrateOperation):
raise NotImplementedError()
@classmethod
- def register_add_constraint(cls, type_):
+ def register_add_constraint(cls, type_: str) -> Callable:
def go(klass):
cls.add_constraint_ops.dispatch_for(type_)(klass.from_constraint)
return klass
@@ -53,15 +101,21 @@ class AddConstraintOp(MigrateOperation):
return go
@classmethod
- def from_constraint(cls, constraint):
+ def from_constraint(cls, constraint: "Constraint") -> "AddConstraintOp":
return cls.add_constraint_ops.dispatch(constraint.__visit_name__)(
constraint
)
- def reverse(self):
+ @abstractmethod
+ def to_constraint(
+ self, migration_context: Optional["MigrationContext"] = None
+ ) -> "Constraint":
+ pass
+
+ def reverse(self) -> "DropConstraintOp":
return DropConstraintOp.from_constraint(self.to_constraint())
- def to_diff_tuple(self):
+ def to_diff_tuple(self) -> Tuple[str, "Constraint"]:
return ("add_constraint", self.to_constraint())
@@ -72,29 +126,34 @@ class DropConstraintOp(MigrateOperation):
def __init__(
self,
- constraint_name,
- table_name,
- type_=None,
- schema=None,
- _reverse=None,
- ):
+ constraint_name: Optional[str],
+ table_name: str,
+ type_: Optional[str] = None,
+ schema: Optional[str] = None,
+ _reverse: Optional["AddConstraintOp"] = None,
+ ) -> None:
self.constraint_name = constraint_name
self.table_name = table_name
self.constraint_type = type_
self.schema = schema
self._reverse = _reverse
- def reverse(self):
+ def reverse(self) -> "AddConstraintOp":
return AddConstraintOp.from_constraint(self.to_constraint())
- def to_diff_tuple(self):
+ def to_diff_tuple(
+ self,
+ ) -> Tuple[str, "SchemaItem"]:
if self.constraint_type == "foreignkey":
return ("remove_fk", self.to_constraint())
else:
return ("remove_constraint", self.to_constraint())
@classmethod
- def from_constraint(cls, constraint):
+ def from_constraint(
+ cls,
+ constraint: "Constraint",
+ ) -> "DropConstraintOp":
types = {
"unique_constraint": "unique",
"foreign_key_constraint": "foreignkey",
@@ -113,7 +172,9 @@ class DropConstraintOp(MigrateOperation):
_reverse=AddConstraintOp.from_constraint(constraint),
)
- def to_constraint(self):
+ def to_constraint(
+ self,
+ ) -> "Constraint":
if self._reverse is not None:
constraint = self._reverse.to_constraint()
@@ -131,8 +192,13 @@ class DropConstraintOp(MigrateOperation):
@classmethod
def drop_constraint(
- cls, operations, constraint_name, table_name, type_=None, schema=None
- ):
+ cls,
+ operations: "Operations",
+ constraint_name: str,
+ table_name: str,
+ type_: Optional[str] = None,
+ schema: Optional[str] = None,
+ ) -> Optional["Table"]:
r"""Drop a constraint of the given name, typically via DROP CONSTRAINT.
:param constraint_name: name of the constraint.
@@ -150,7 +216,12 @@ class DropConstraintOp(MigrateOperation):
return operations.invoke(op)
@classmethod
- def batch_drop_constraint(cls, operations, constraint_name, type_=None):
+ def batch_drop_constraint(
+ cls,
+ operations: "BatchOperations",
+ constraint_name: str,
+ type_: Optional[str] = None,
+ ) -> None:
"""Issue a "drop constraint" instruction using the
current batch migration context.
@@ -182,8 +253,13 @@ class CreatePrimaryKeyOp(AddConstraintOp):
constraint_type = "primarykey"
def __init__(
- self, constraint_name, table_name, columns, schema=None, **kw
- ):
+ self,
+ constraint_name: Optional[str],
+ table_name: str,
+ columns: Sequence[str],
+ schema: Optional[str] = None,
+ **kw
+ ) -> None:
self.constraint_name = constraint_name
self.table_name = table_name
self.columns = columns
@@ -191,18 +267,23 @@ class CreatePrimaryKeyOp(AddConstraintOp):
self.kw = kw
@classmethod
- def from_constraint(cls, constraint):
+ def from_constraint(cls, constraint: "Constraint") -> "CreatePrimaryKeyOp":
constraint_table = sqla_compat._table_for_constraint(constraint)
+ pk_constraint = cast("PrimaryKeyConstraint", constraint)
+
return cls(
- constraint.name,
+ pk_constraint.name,
constraint_table.name,
- constraint.columns.keys(),
+ pk_constraint.columns.keys(),
schema=constraint_table.schema,
- **constraint.dialect_kwargs,
+ **pk_constraint.dialect_kwargs,
)
- def to_constraint(self, migration_context=None):
+ def to_constraint(
+ self, migration_context: Optional["MigrationContext"] = None
+ ) -> "PrimaryKeyConstraint":
schema_obj = schemaobj.SchemaObjects(migration_context)
+
return schema_obj.primary_key_constraint(
self.constraint_name,
self.table_name,
@@ -213,8 +294,13 @@ class CreatePrimaryKeyOp(AddConstraintOp):
@classmethod
def create_primary_key(
- cls, operations, constraint_name, table_name, columns, schema=None
- ):
+ cls,
+ operations: "Operations",
+ constraint_name: str,
+ table_name: str,
+ columns: List[str],
+ schema: Optional[str] = None,
+ ) -> Optional["Table"]:
"""Issue a "create primary key" instruction using the current
migration context.
@@ -255,7 +341,12 @@ class CreatePrimaryKeyOp(AddConstraintOp):
return operations.invoke(op)
@classmethod
- def batch_create_primary_key(cls, operations, constraint_name, columns):
+ def batch_create_primary_key(
+ cls,
+ operations: "BatchOperations",
+ constraint_name: str,
+ columns: List[str],
+ ) -> None:
"""Issue a "create primary key" instruction using the
current batch migration context.
@@ -287,8 +378,13 @@ class CreateUniqueConstraintOp(AddConstraintOp):
constraint_type = "unique"
def __init__(
- self, constraint_name, table_name, columns, schema=None, **kw
- ):
+ self,
+ constraint_name: Optional[str],
+ table_name: str,
+ columns: Sequence[str],
+ schema: Optional[str] = None,
+ **kw
+ ) -> None:
self.constraint_name = constraint_name
self.table_name = table_name
self.columns = columns
@@ -296,24 +392,31 @@ class CreateUniqueConstraintOp(AddConstraintOp):
self.kw = kw
@classmethod
- def from_constraint(cls, constraint):
+ def from_constraint(
+ cls, constraint: "Constraint"
+ ) -> "CreateUniqueConstraintOp":
+
constraint_table = sqla_compat._table_for_constraint(constraint)
- kw = {}
- if constraint.deferrable:
- kw["deferrable"] = constraint.deferrable
- if constraint.initially:
- kw["initially"] = constraint.initially
- kw.update(constraint.dialect_kwargs)
+ uq_constraint = cast("UniqueConstraint", constraint)
+
+ kw: dict = {}
+ if uq_constraint.deferrable:
+ kw["deferrable"] = uq_constraint.deferrable
+ if uq_constraint.initially:
+ kw["initially"] = uq_constraint.initially
+ kw.update(uq_constraint.dialect_kwargs)
return cls(
- constraint.name,
+ uq_constraint.name,
constraint_table.name,
- [c.name for c in constraint.columns],
+ [c.name for c in uq_constraint.columns],
schema=constraint_table.schema,
**kw,
)
- def to_constraint(self, migration_context=None):
+ def to_constraint(
+ self, migration_context: Optional["MigrationContext"] = None
+ ) -> "UniqueConstraint":
schema_obj = schemaobj.SchemaObjects(migration_context)
return schema_obj.unique_constraint(
self.constraint_name,
@@ -326,13 +429,13 @@ class CreateUniqueConstraintOp(AddConstraintOp):
@classmethod
def create_unique_constraint(
cls,
- operations,
- constraint_name,
- table_name,
- columns,
- schema=None,
+ operations: "Operations",
+ constraint_name: Optional[str],
+ table_name: str,
+ columns: Sequence[str],
+ schema: Optional[str] = None,
**kw
- ):
+ ) -> Any:
"""Issue a "create unique constraint" instruction using the
current migration context.
@@ -376,8 +479,12 @@ class CreateUniqueConstraintOp(AddConstraintOp):
@classmethod
def batch_create_unique_constraint(
- cls, operations, constraint_name, columns, **kw
- ):
+ cls,
+ operations: "BatchOperations",
+ constraint_name: str,
+ columns: Sequence[str],
+ **kw
+ ) -> Any:
"""Issue a "create unique constraint" instruction using the
current batch migration context.
@@ -406,13 +513,13 @@ class CreateForeignKeyOp(AddConstraintOp):
def __init__(
self,
- constraint_name,
- source_table,
- referent_table,
- local_cols,
- remote_cols,
+ constraint_name: Optional[str],
+ source_table: str,
+ referent_table: str,
+ local_cols: List[str],
+ remote_cols: List[str],
**kw
- ):
+ ) -> None:
self.constraint_name = constraint_name
self.source_table = source_table
self.referent_table = referent_table
@@ -420,22 +527,24 @@ class CreateForeignKeyOp(AddConstraintOp):
self.remote_cols = remote_cols
self.kw = kw
- def to_diff_tuple(self):
+ def to_diff_tuple(self) -> Tuple[str, "ForeignKeyConstraint"]:
return ("add_fk", self.to_constraint())
@classmethod
- def from_constraint(cls, constraint):
- kw = {}
- if constraint.onupdate:
- kw["onupdate"] = constraint.onupdate
- if constraint.ondelete:
- kw["ondelete"] = constraint.ondelete
- if constraint.initially:
- kw["initially"] = constraint.initially
- if constraint.deferrable:
- kw["deferrable"] = constraint.deferrable
- if constraint.use_alter:
- kw["use_alter"] = constraint.use_alter
+ def from_constraint(cls, constraint: "Constraint") -> "CreateForeignKeyOp":
+
+ fk_constraint = cast("ForeignKeyConstraint", constraint)
+ kw: dict = {}
+ if fk_constraint.onupdate:
+ kw["onupdate"] = fk_constraint.onupdate
+ if fk_constraint.ondelete:
+ kw["ondelete"] = fk_constraint.ondelete
+ if fk_constraint.initially:
+ kw["initially"] = fk_constraint.initially
+ if fk_constraint.deferrable:
+ kw["deferrable"] = fk_constraint.deferrable
+ if fk_constraint.use_alter:
+ kw["use_alter"] = fk_constraint.use_alter
(
source_schema,
@@ -448,13 +557,13 @@ class CreateForeignKeyOp(AddConstraintOp):
ondelete,
deferrable,
initially,
- ) = sqla_compat._fk_spec(constraint)
+ ) = sqla_compat._fk_spec(fk_constraint)
kw["source_schema"] = source_schema
kw["referent_schema"] = target_schema
- kw.update(constraint.dialect_kwargs)
+ kw.update(fk_constraint.dialect_kwargs)
return cls(
- constraint.name,
+ fk_constraint.name,
source_table,
target_table,
source_columns,
@@ -462,7 +571,9 @@ class CreateForeignKeyOp(AddConstraintOp):
**kw,
)
- def to_constraint(self, migration_context=None):
+ def to_constraint(
+ self, migration_context: Optional["MigrationContext"] = None
+ ) -> "ForeignKeyConstraint":
schema_obj = schemaobj.SchemaObjects(migration_context)
return schema_obj.foreign_key_constraint(
self.constraint_name,
@@ -476,21 +587,21 @@ class CreateForeignKeyOp(AddConstraintOp):
@classmethod
def create_foreign_key(
cls,
- operations,
- constraint_name,
- source_table,
- referent_table,
- local_cols,
- remote_cols,
- onupdate=None,
- ondelete=None,
- deferrable=None,
- initially=None,
- match=None,
- source_schema=None,
- referent_schema=None,
+ operations: "Operations",
+ constraint_name: str,
+ source_table: str,
+ referent_table: str,
+ local_cols: List[str],
+ remote_cols: List[str],
+ onupdate: Optional[str] = None,
+ ondelete: Optional[str] = None,
+ deferrable: Optional[bool] = None,
+ initially: Optional[str] = None,
+ match: Optional[str] = None,
+ source_schema: Optional[str] = None,
+ referent_schema: Optional[str] = None,
**dialect_kw
- ):
+ ) -> Optional["Table"]:
"""Issue a "create foreign key" instruction using the
current migration context.
@@ -556,19 +667,19 @@ class CreateForeignKeyOp(AddConstraintOp):
@classmethod
def batch_create_foreign_key(
cls,
- operations,
- constraint_name,
- referent_table,
- local_cols,
- remote_cols,
- referent_schema=None,
- onupdate=None,
- ondelete=None,
- deferrable=None,
- initially=None,
- match=None,
+ operations: "BatchOperations",
+ constraint_name: str,
+ referent_table: str,
+ local_cols: List[str],
+ remote_cols: List[str],
+ referent_schema: Optional[str] = None,
+ onupdate: None = None,
+ ondelete: None = None,
+ deferrable: None = None,
+ initially: None = None,
+ match: None = None,
**dialect_kw
- ):
+ ) -> None:
"""Issue a "create foreign key" instruction using the
current batch migration context.
@@ -618,8 +729,13 @@ class CreateCheckConstraintOp(AddConstraintOp):
constraint_type = "check"
def __init__(
- self, constraint_name, table_name, condition, schema=None, **kw
- ):
+ self,
+ constraint_name: Optional[str],
+ table_name: str,
+ condition: Union["TextClause", "ColumnElement[Any]"],
+ schema: Optional[str] = None,
+ **kw
+ ) -> None:
self.constraint_name = constraint_name
self.table_name = table_name
self.condition = condition
@@ -627,18 +743,26 @@ class CreateCheckConstraintOp(AddConstraintOp):
self.kw = kw
@classmethod
- def from_constraint(cls, constraint):
+ def from_constraint(
+ cls, constraint: "Constraint"
+ ) -> "CreateCheckConstraintOp":
constraint_table = sqla_compat._table_for_constraint(constraint)
+ ck_constraint = cast("CheckConstraint", constraint)
+
return cls(
- constraint.name,
+ ck_constraint.name,
constraint_table.name,
- constraint.sqltext,
+ cast(
+ "Union[TextClause, ColumnElement[Any]]", ck_constraint.sqltext
+ ),
schema=constraint_table.schema,
- **constraint.dialect_kwargs,
+ **ck_constraint.dialect_kwargs,
)
- def to_constraint(self, migration_context=None):
+ def to_constraint(
+ self, migration_context: Optional["MigrationContext"] = None
+ ) -> "CheckConstraint":
schema_obj = schemaobj.SchemaObjects(migration_context)
return schema_obj.check_constraint(
self.constraint_name,
@@ -651,13 +775,13 @@ class CreateCheckConstraintOp(AddConstraintOp):
@classmethod
def create_check_constraint(
cls,
- operations,
- constraint_name,
- table_name,
- condition,
- schema=None,
+ operations: "Operations",
+ constraint_name: Optional[str],
+ table_name: str,
+ condition: "BinaryExpression",
+ schema: Optional[str] = None,
**kw
- ):
+ ) -> Optional["Table"]:
"""Issue a "create check constraint" instruction using the
current migration context.
@@ -703,8 +827,12 @@ class CreateCheckConstraintOp(AddConstraintOp):
@classmethod
def batch_create_check_constraint(
- cls, operations, constraint_name, condition, **kw
- ):
+ cls,
+ operations: "BatchOperations",
+ constraint_name: str,
+ condition: "TextClause",
+ **kw
+ ) -> Optional["Table"]:
"""Issue a "create check constraint" instruction using the
current batch migration context.
@@ -732,8 +860,14 @@ class CreateIndexOp(MigrateOperation):
"""Represent a create index operation."""
def __init__(
- self, index_name, table_name, columns, schema=None, unique=False, **kw
- ):
+ self,
+ index_name: str,
+ table_name: str,
+ columns: Sequence[Union[str, "TextClause", "ColumnElement[Any]"]],
+ schema: Optional[str] = None,
+ unique: bool = False,
+ **kw
+ ) -> None:
self.index_name = index_name
self.table_name = table_name
self.columns = columns
@@ -741,14 +875,15 @@ class CreateIndexOp(MigrateOperation):
self.unique = unique
self.kw = kw
- def reverse(self):
+ def reverse(self) -> "DropIndexOp":
return DropIndexOp.from_index(self.to_index())
- def to_diff_tuple(self):
+ def to_diff_tuple(self) -> Tuple[str, "Index"]:
return ("add_index", self.to_index())
@classmethod
- def from_index(cls, index):
+ def from_index(cls, index: "Index") -> "CreateIndexOp":
+ assert index.table is not None
return cls(
index.name,
index.table.name,
@@ -758,7 +893,9 @@ class CreateIndexOp(MigrateOperation):
**index.kwargs,
)
- def to_index(self, migration_context=None):
+ def to_index(
+ self, migration_context: Optional["MigrationContext"] = None
+ ) -> "Index":
schema_obj = schemaobj.SchemaObjects(migration_context)
idx = schema_obj.index(
@@ -774,14 +911,14 @@ class CreateIndexOp(MigrateOperation):
@classmethod
def create_index(
cls,
- operations,
- index_name,
- table_name,
- columns,
- schema=None,
- unique=False,
+ operations: Operations,
+ index_name: str,
+ table_name: str,
+ columns: Sequence[Union[str, "TextClause", "Function"]],
+ schema: Optional[str] = None,
+ unique: bool = False,
**kw
- ):
+ ) -> Optional["Table"]:
r"""Issue a "create index" instruction using the current
migration context.
@@ -829,7 +966,13 @@ class CreateIndexOp(MigrateOperation):
return operations.invoke(op)
@classmethod
- def batch_create_index(cls, operations, index_name, columns, **kw):
+ def batch_create_index(
+ cls,
+ operations: "BatchOperations",
+ index_name: str,
+ columns: List[str],
+ **kw
+ ) -> Optional["Table"]:
"""Issue a "create index" instruction using the
current batch migration context.
@@ -855,22 +998,28 @@ class DropIndexOp(MigrateOperation):
"""Represent a drop index operation."""
def __init__(
- self, index_name, table_name=None, schema=None, _reverse=None, **kw
- ):
+ self,
+ index_name: Union["quoted_name", str, "conv"],
+ table_name: Optional[str] = None,
+ schema: Optional[str] = None,
+ _reverse: Optional["CreateIndexOp"] = None,
+ **kw
+ ) -> None:
self.index_name = index_name
self.table_name = table_name
self.schema = schema
self._reverse = _reverse
self.kw = kw
- def to_diff_tuple(self):
+ def to_diff_tuple(self) -> Tuple[str, "Index"]:
return ("remove_index", self.to_index())
- def reverse(self):
+ def reverse(self) -> "CreateIndexOp":
return CreateIndexOp.from_index(self.to_index())
@classmethod
- def from_index(cls, index):
+ def from_index(cls, index: "Index") -> "DropIndexOp":
+ assert index.table is not None
return cls(
index.name,
index.table.name,
@@ -879,7 +1028,9 @@ class DropIndexOp(MigrateOperation):
**index.kwargs,
)
- def to_index(self, migration_context=None):
+ def to_index(
+ self, migration_context: Optional["MigrationContext"] = None
+ ) -> "Index":
schema_obj = schemaobj.SchemaObjects(migration_context)
# need a dummy column name here since SQLAlchemy
@@ -894,8 +1045,13 @@ class DropIndexOp(MigrateOperation):
@classmethod
def drop_index(
- cls, operations, index_name, table_name=None, schema=None, **kw
- ):
+ cls,
+ operations: "Operations",
+ index_name: str,
+ table_name: Optional[str] = None,
+ schema: Optional[str] = None,
+ **kw
+ ) -> Optional["Table"]:
r"""Issue a "drop index" instruction using the current
migration context.
@@ -921,7 +1077,9 @@ class DropIndexOp(MigrateOperation):
return operations.invoke(op)
@classmethod
- def batch_drop_index(cls, operations, index_name, **kw):
+ def batch_drop_index(
+ cls, operations: BatchOperations, index_name: str, **kw
+ ) -> Optional["Table"]:
"""Issue a "drop index" instruction using the
current batch migration context.
@@ -946,13 +1104,13 @@ class CreateTableOp(MigrateOperation):
def __init__(
self,
- table_name,
- columns,
- schema=None,
- _namespace_metadata=None,
- _constraints_included=False,
+ table_name: str,
+ columns: Sequence[Union["Column", "Constraint"]],
+ schema: Optional[str] = None,
+ _namespace_metadata: Optional["MetaData"] = None,
+ _constraints_included: bool = False,
**kw
- ):
+ ) -> None:
self.table_name = table_name
self.columns = columns
self.schema = schema
@@ -963,22 +1121,24 @@ class CreateTableOp(MigrateOperation):
self._namespace_metadata = _namespace_metadata
self._constraints_included = _constraints_included
- def reverse(self):
+ def reverse(self) -> "DropTableOp":
return DropTableOp.from_table(
self.to_table(), _namespace_metadata=self._namespace_metadata
)
- def to_diff_tuple(self):
+ def to_diff_tuple(self) -> Tuple[str, "Table"]:
return ("add_table", self.to_table())
@classmethod
- def from_table(cls, table, _namespace_metadata=None):
+ def from_table(
+ cls, table: "Table", _namespace_metadata: Optional["MetaData"] = None
+ ) -> "CreateTableOp":
if _namespace_metadata is None:
_namespace_metadata = table.metadata
return cls(
table.name,
- list(table.c) + list(table.constraints),
+ list(table.c) + list(table.constraints), # type:ignore[arg-type]
schema=table.schema,
_namespace_metadata=_namespace_metadata,
# given a Table() object, this Table will contain full Index()
@@ -989,12 +1149,14 @@ class CreateTableOp(MigrateOperation):
# not doubled up. see #844 #848
_constraints_included=True,
comment=table.comment,
- info=table.info.copy(),
+ info=dict(table.info),
prefixes=list(table._prefixes),
**table.kwargs,
)
- def to_table(self, migration_context=None):
+ def to_table(
+ self, migration_context: Optional["MigrationContext"] = None
+ ) -> "Table":
schema_obj = schemaobj.SchemaObjects(migration_context)
return schema_obj.table(
@@ -1009,7 +1171,9 @@ class CreateTableOp(MigrateOperation):
)
@classmethod
- def create_table(cls, operations, table_name, *columns, **kw):
+ def create_table(
+ cls, operations: "Operations", table_name: str, *columns, **kw
+ ) -> Optional["Table"]:
r"""Issue a "create table" instruction using the current migration
context.
@@ -1094,7 +1258,13 @@ class CreateTableOp(MigrateOperation):
class DropTableOp(MigrateOperation):
"""Represent a drop table operation."""
- def __init__(self, table_name, schema=None, table_kw=None, _reverse=None):
+ def __init__(
+ self,
+ table_name: str,
+ schema: Optional[str] = None,
+ table_kw: Optional[MutableMapping[Any, Any]] = None,
+ _reverse: Optional["CreateTableOp"] = None,
+ ) -> None:
self.table_name = table_name
self.schema = schema
self.table_kw = table_kw or {}
@@ -1103,20 +1273,22 @@ class DropTableOp(MigrateOperation):
self.prefixes = self.table_kw.pop("prefixes", None)
self._reverse = _reverse
- def to_diff_tuple(self):
+ def to_diff_tuple(self) -> Tuple[str, "Table"]:
return ("remove_table", self.to_table())
- def reverse(self):
+ def reverse(self) -> "CreateTableOp":
return CreateTableOp.from_table(self.to_table())
@classmethod
- def from_table(cls, table, _namespace_metadata=None):
+ def from_table(
+ cls, table: "Table", _namespace_metadata: Optional["MetaData"] = None
+ ) -> "DropTableOp":
return cls(
table.name,
schema=table.schema,
table_kw={
"comment": table.comment,
- "info": table.info.copy(),
+ "info": dict(table.info),
"prefixes": list(table._prefixes),
**table.kwargs,
},
@@ -1125,7 +1297,9 @@ class DropTableOp(MigrateOperation):
),
)
- def to_table(self, migration_context=None):
+ def to_table(
+ self, migration_context: Optional["MigrationContext"] = None
+ ) -> "Table":
if self._reverse:
cols_and_constraints = self._reverse.columns
else:
@@ -1139,14 +1313,21 @@ class DropTableOp(MigrateOperation):
info=self.info.copy() if self.info else {},
prefixes=list(self.prefixes) if self.prefixes else [],
schema=self.schema,
- _constraints_included=bool(self._reverse)
- and self._reverse._constraints_included,
+ _constraints_included=self._reverse._constraints_included
+ if self._reverse
+ else False,
**self.table_kw,
)
return t
@classmethod
- def drop_table(cls, operations, table_name, schema=None, **kw):
+ def drop_table(
+ cls,
+ operations: "Operations",
+ table_name: str,
+ schema: Optional[str] = None,
+ **kw: Any
+ ) -> None:
r"""Issue a "drop table" instruction using the current
migration context.
@@ -1171,7 +1352,11 @@ class DropTableOp(MigrateOperation):
class AlterTableOp(MigrateOperation):
"""Represent an alter table operation."""
- def __init__(self, table_name, schema=None):
+ def __init__(
+ self,
+ table_name: str,
+ schema: Optional[str] = None,
+ ) -> None:
self.table_name = table_name
self.schema = schema
@@ -1180,14 +1365,23 @@ class AlterTableOp(MigrateOperation):
class RenameTableOp(AlterTableOp):
"""Represent a rename table operation."""
- def __init__(self, old_table_name, new_table_name, schema=None):
+ def __init__(
+ self,
+ old_table_name: str,
+ new_table_name: str,
+ schema: Optional[str] = None,
+ ) -> None:
super(RenameTableOp, self).__init__(old_table_name, schema=schema)
self.new_table_name = new_table_name
@classmethod
def rename_table(
- cls, operations, old_table_name, new_table_name, schema=None
- ):
+ cls,
+ operations: "Operations",
+ old_table_name: str,
+ new_table_name: str,
+ schema: Optional[str] = None,
+ ) -> Optional["Table"]:
"""Emit an ALTER TABLE to rename a table.
:param old_table_name: old name.
@@ -1210,8 +1404,12 @@ class CreateTableCommentOp(AlterTableOp):
"""Represent a COMMENT ON `table` operation."""
def __init__(
- self, table_name, comment, schema=None, existing_comment=None
- ):
+ self,
+ table_name: str,
+ comment: Optional[str],
+ schema: Optional[str] = None,
+ existing_comment: Optional[str] = None,
+ ) -> None:
self.table_name = table_name
self.comment = comment
self.existing_comment = existing_comment
@@ -1220,12 +1418,12 @@ class CreateTableCommentOp(AlterTableOp):
@classmethod
def create_table_comment(
cls,
- operations,
- table_name,
- comment,
- existing_comment=None,
- schema=None,
- ):
+ operations: "Operations",
+ table_name: str,
+ comment: Optional[str],
+ existing_comment: None = None,
+ schema: Optional[str] = None,
+ ) -> Optional["Table"]:
"""Emit a COMMENT ON operation to set the comment for a table.
.. versionadded:: 1.0.6
@@ -1317,15 +1515,24 @@ class CreateTableCommentOp(AlterTableOp):
class DropTableCommentOp(AlterTableOp):
"""Represent an operation to remove the comment from a table."""
- def __init__(self, table_name, schema=None, existing_comment=None):
+ def __init__(
+ self,
+ table_name: str,
+ schema: Optional[str] = None,
+ existing_comment: Optional[str] = None,
+ ) -> None:
self.table_name = table_name
self.existing_comment = existing_comment
self.schema = schema
@classmethod
def drop_table_comment(
- cls, operations, table_name, existing_comment=None, schema=None
- ):
+ cls,
+ operations: "Operations",
+ table_name: str,
+ existing_comment: Optional[str] = None,
+ schema: Optional[str] = None,
+ ) -> Optional["Table"]:
"""Issue a "drop table comment" operation to
remove an existing comment set on a table.
@@ -1388,20 +1595,20 @@ class AlterColumnOp(AlterTableOp):
def __init__(
self,
- table_name,
- column_name,
- schema=None,
- existing_type=None,
- existing_server_default=False,
- existing_nullable=None,
- existing_comment=None,
- modify_nullable=None,
- modify_comment=False,
- modify_server_default=False,
- modify_name=None,
- modify_type=None,
+ table_name: str,
+ column_name: str,
+ schema: Optional[str] = None,
+ existing_type: Optional[Any] = None,
+ existing_server_default: Any = False,
+ existing_nullable: Optional[bool] = None,
+ existing_comment: Optional[str] = None,
+ modify_nullable: Optional[bool] = None,
+ modify_comment: Optional[Union[str, bool]] = False,
+ modify_server_default: Any = False,
+ modify_name: Optional[str] = None,
+ modify_type: Optional[Any] = None,
**kw
- ):
+ ) -> None:
super(AlterColumnOp, self).__init__(table_name, schema=schema)
self.column_name = column_name
self.existing_type = existing_type
@@ -1415,7 +1622,7 @@ class AlterColumnOp(AlterTableOp):
self.modify_type = modify_type
self.kw = kw
- def to_diff_tuple(self):
+ def to_diff_tuple(self) -> Any:
col_diff = []
schema, tname, cname = self.schema, self.table_name, self.column_name
@@ -1495,7 +1702,7 @@ class AlterColumnOp(AlterTableOp):
return col_diff
- def has_changes(self):
+ def has_changes(self) -> bool:
hc1 = (
self.modify_nullable is not None
or self.modify_server_default is not False
@@ -1510,7 +1717,7 @@ class AlterColumnOp(AlterTableOp):
else:
return False
- def reverse(self):
+ def reverse(self) -> "AlterColumnOp":
kw = self.kw.copy()
kw["existing_type"] = self.existing_type
@@ -1546,21 +1753,25 @@ class AlterColumnOp(AlterTableOp):
@classmethod
def alter_column(
cls,
- operations,
- table_name,
- column_name,
- nullable=None,
- comment=False,
- server_default=False,
- new_column_name=None,
- type_=None,
- existing_type=None,
- existing_server_default=False,
- existing_nullable=None,
- existing_comment=None,
- schema=None,
+ operations: Operations,
+ table_name: str,
+ column_name: str,
+ nullable: Optional[bool] = None,
+ comment: Optional[Union[str, bool]] = False,
+ server_default: Any = False,
+ new_column_name: Optional[str] = None,
+ type_: Optional[Union["TypeEngine", Type["TypeEngine"]]] = None,
+ existing_type: Optional[
+ Union["TypeEngine", Type["TypeEngine"]]
+ ] = None,
+ existing_server_default: Optional[
+ Union[str, bool, "Identity", "Computed"]
+ ] = False,
+ existing_nullable: Optional[bool] = None,
+ existing_comment: Optional[str] = None,
+ schema: Optional[str] = None,
**kw
- ):
+ ) -> Optional["Table"]:
r"""Issue an "alter column" instruction using the
current migration context.
@@ -1671,21 +1882,23 @@ class AlterColumnOp(AlterTableOp):
@classmethod
def batch_alter_column(
cls,
- operations,
- column_name,
- nullable=None,
- comment=False,
- server_default=False,
- new_column_name=None,
- type_=None,
- existing_type=None,
- existing_server_default=False,
- existing_nullable=None,
- existing_comment=None,
- insert_before=None,
- insert_after=None,
+ operations: BatchOperations,
+ column_name: str,
+ nullable: Optional[bool] = None,
+ comment: bool = False,
+ server_default: Union["Function", bool] = False,
+ new_column_name: Optional[str] = None,
+ type_: Optional[Union["TypeEngine", Type["TypeEngine"]]] = None,
+ existing_type: Optional[
+ Union["TypeEngine", Type["TypeEngine"]]
+ ] = None,
+ existing_server_default: bool = False,
+ existing_nullable: None = None,
+ existing_comment: None = None,
+ insert_before: None = None,
+ insert_after: None = None,
**kw
- ):
+ ) -> Optional["Table"]:
"""Issue an "alter column" instruction using the current
batch migration context.
@@ -1736,32 +1949,51 @@ class AlterColumnOp(AlterTableOp):
class AddColumnOp(AlterTableOp):
"""Represent an add column operation."""
- def __init__(self, table_name, column, schema=None, **kw):
+ def __init__(
+ self,
+ table_name: str,
+ column: "Column",
+ schema: Optional[str] = None,
+ **kw
+ ) -> None:
super(AddColumnOp, self).__init__(table_name, schema=schema)
self.column = column
self.kw = kw
- def reverse(self):
+ def reverse(self) -> "DropColumnOp":
return DropColumnOp.from_column_and_tablename(
self.schema, self.table_name, self.column
)
- def to_diff_tuple(self):
+ def to_diff_tuple(
+ self,
+ ) -> Tuple[str, Optional[str], str, "Column"]:
return ("add_column", self.schema, self.table_name, self.column)
- def to_column(self):
+ def to_column(self) -> "Column":
return self.column
@classmethod
- def from_column(cls, col):
+ def from_column(cls, col: "Column") -> "AddColumnOp":
return cls(col.table.name, col, schema=col.table.schema)
@classmethod
- def from_column_and_tablename(cls, schema, tname, col):
+ def from_column_and_tablename(
+ cls,
+ schema: Optional[str],
+ tname: str,
+ col: "Column",
+ ) -> "AddColumnOp":
return cls(tname, col, schema=schema)
@classmethod
- def add_column(cls, operations, table_name, column, schema=None):
+ def add_column(
+ cls,
+ operations: "Operations",
+ table_name: str,
+ column: "Column",
+ schema: Optional[str] = None,
+ ) -> Optional["Table"]:
"""Issue an "add column" instruction using the current
migration context.
@@ -1816,8 +2048,12 @@ class AddColumnOp(AlterTableOp):
@classmethod
def batch_add_column(
- cls, operations, column, insert_before=None, insert_after=None
- ):
+ cls,
+ operations: "BatchOperations",
+ column: "Column",
+ insert_before: Optional[str] = None,
+ insert_after: Optional[str] = None,
+ ) -> Optional["Table"]:
"""Issue an "add column" instruction using the current
batch migration context.
@@ -1848,14 +2084,21 @@ class DropColumnOp(AlterTableOp):
"""Represent a drop column operation."""
def __init__(
- self, table_name, column_name, schema=None, _reverse=None, **kw
- ):
+ self,
+ table_name: str,
+ column_name: str,
+ schema: Optional[str] = None,
+ _reverse: Optional["AddColumnOp"] = None,
+ **kw
+ ) -> None:
super(DropColumnOp, self).__init__(table_name, schema=schema)
self.column_name = column_name
self.kw = kw
self._reverse = _reverse
- def to_diff_tuple(self):
+ def to_diff_tuple(
+ self,
+ ) -> Tuple[str, Optional[str], str, "Column"]:
return (
"remove_column",
self.schema,
@@ -1863,7 +2106,7 @@ class DropColumnOp(AlterTableOp):
self.to_column(),
)
- def reverse(self):
+ def reverse(self) -> "AddColumnOp":
if self._reverse is None:
raise ValueError(
"operation is not reversible; "
@@ -1875,7 +2118,12 @@ class DropColumnOp(AlterTableOp):
)
@classmethod
- def from_column_and_tablename(cls, schema, tname, col):
+ def from_column_and_tablename(
+ cls,
+ schema: Optional[str],
+ tname: str,
+ col: "Column",
+ ) -> "DropColumnOp":
return cls(
tname,
col.name,
@@ -1883,7 +2131,9 @@ class DropColumnOp(AlterTableOp):
_reverse=AddColumnOp.from_column_and_tablename(schema, tname, col),
)
- def to_column(self, migration_context=None):
+ def to_column(
+ self, migration_context: Optional["MigrationContext"] = None
+ ) -> "Column":
if self._reverse is not None:
return self._reverse.column
schema_obj = schemaobj.SchemaObjects(migration_context)
@@ -1891,8 +2141,13 @@ class DropColumnOp(AlterTableOp):
@classmethod
def drop_column(
- cls, operations, table_name, column_name, schema=None, **kw
- ):
+ cls,
+ operations: "Operations",
+ table_name: str,
+ column_name: str,
+ schema: Optional[str] = None,
+ **kw
+ ) -> Optional["Table"]:
"""Issue a "drop column" instruction using the current
migration context.
@@ -1934,7 +2189,9 @@ class DropColumnOp(AlterTableOp):
return operations.invoke(op)
@classmethod
- def batch_drop_column(cls, operations, column_name, **kw):
+ def batch_drop_column(
+ cls, operations: "BatchOperations", column_name: str, **kw
+ ) -> Optional["Table"]:
"""Issue a "drop column" instruction using the current
batch migration context.
@@ -1956,13 +2213,24 @@ class DropColumnOp(AlterTableOp):
class BulkInsertOp(MigrateOperation):
"""Represent a bulk insert operation."""
- def __init__(self, table, rows, multiinsert=True):
+ def __init__(
+ self,
+ table: Union["Table", "TableClause"],
+ rows: List[dict],
+ multiinsert: bool = True,
+ ) -> None:
self.table = table
self.rows = rows
self.multiinsert = multiinsert
@classmethod
- def bulk_insert(cls, operations, table, rows, multiinsert=True):
+ def bulk_insert(
+ cls,
+ operations: Operations,
+ table: Union["Table", "TableClause"],
+ rows: List[dict],
+ multiinsert: bool = True,
+ ) -> None:
"""Issue a "bulk insert" operation using the current
migration context.
@@ -2046,12 +2314,21 @@ class BulkInsertOp(MigrateOperation):
class ExecuteSQLOp(MigrateOperation):
"""Represent an execute SQL operation."""
- def __init__(self, sqltext, execution_options=None):
+ def __init__(
+ self,
+ sqltext: Union["Update", str, "Insert", "TextClause"],
+ execution_options: None = None,
+ ) -> None:
self.sqltext = sqltext
self.execution_options = execution_options
@classmethod
- def execute(cls, operations, sqltext, execution_options=None):
+ def execute(
+ cls,
+ operations: Operations,
+ sqltext: Union[str, "TextClause", "Update"],
+ execution_options: None = None,
+ ) -> Optional["Table"]:
r"""Execute the given SQL using the current migration context.
The given SQL can be a plain string, e.g.::
@@ -2140,20 +2417,22 @@ class ExecuteSQLOp(MigrateOperation):
class OpContainer(MigrateOperation):
"""Represent a sequence of operations operation."""
- def __init__(self, ops=()):
- self.ops = ops
+ def __init__(self, ops: Sequence[MigrateOperation] = ()) -> None:
+ self.ops = list(ops)
- def is_empty(self):
+ def is_empty(self) -> bool:
return not self.ops
- def as_diffs(self):
+ def as_diffs(self) -> Any:
return list(OpContainer._ops_as_diffs(self))
@classmethod
- def _ops_as_diffs(cls, migrations):
+ def _ops_as_diffs(
+ cls, migrations: "OpContainer"
+ ) -> Iterator[Tuple[Any, ...]]:
for op in migrations.ops:
if hasattr(op, "ops"):
- for sub_op in cls._ops_as_diffs(op):
+ for sub_op in cls._ops_as_diffs(cast("OpContainer", op)):
yield sub_op
else:
yield op.to_diff_tuple()
@@ -2162,12 +2441,17 @@ class OpContainer(MigrateOperation):
class ModifyTableOps(OpContainer):
"""Contains a sequence of operations that all apply to a single Table."""
- def __init__(self, table_name, ops, schema=None):
+ def __init__(
+ self,
+ table_name: str,
+ ops: Sequence[MigrateOperation],
+ schema: Optional[str] = None,
+ ) -> None:
super(ModifyTableOps, self).__init__(ops)
self.table_name = table_name
self.schema = schema
- def reverse(self):
+ def reverse(self) -> "ModifyTableOps":
return ModifyTableOps(
self.table_name,
ops=list(reversed([op.reverse() for op in self.ops])),
@@ -2185,17 +2469,21 @@ class UpgradeOps(OpContainer):
"""
- def __init__(self, ops=(), upgrade_token="upgrades"):
+ def __init__(
+ self,
+ ops: Sequence[MigrateOperation] = (),
+ upgrade_token: str = "upgrades",
+ ) -> None:
super(UpgradeOps, self).__init__(ops=ops)
self.upgrade_token = upgrade_token
- def reverse_into(self, downgrade_ops):
- downgrade_ops.ops[:] = list(
+ def reverse_into(self, downgrade_ops: "DowngradeOps") -> "DowngradeOps":
+ downgrade_ops.ops[:] = list( # type:ignore[index]
reversed([op.reverse() for op in self.ops])
)
return downgrade_ops
- def reverse(self):
+ def reverse(self) -> "DowngradeOps":
return self.reverse_into(DowngradeOps(ops=[]))
@@ -2209,7 +2497,11 @@ class DowngradeOps(OpContainer):
"""
- def __init__(self, ops=(), downgrade_token="downgrades"):
+ def __init__(
+ self,
+ ops: Sequence[MigrateOperation] = (),
+ downgrade_token: str = "downgrades",
+ ) -> None:
super(DowngradeOps, self).__init__(ops=ops)
self.downgrade_token = downgrade_token
@@ -2243,19 +2535,21 @@ class MigrationScript(MigrateOperation):
"""
+ _needs_render: Optional[bool]
+
def __init__(
self,
- rev_id,
- upgrade_ops,
- downgrade_ops,
- message=None,
- imports=set(),
- head=None,
- splice=None,
- branch_label=None,
- version_path=None,
- depends_on=None,
- ):
+ rev_id: Optional[str],
+ upgrade_ops: "UpgradeOps",
+ downgrade_ops: "DowngradeOps",
+ message: Optional[str] = None,
+ imports: Set[str] = set(),
+ head: Optional[str] = None,
+ splice: Optional[bool] = None,
+ branch_label: Optional[str] = None,
+ version_path: Optional[str] = None,
+ depends_on: Optional[Union[str, Sequence[str]]] = None,
+ ) -> None:
self.rev_id = rev_id
self.message = message
self.imports = imports
@@ -2318,7 +2612,7 @@ class MigrationScript(MigrateOperation):
assert isinstance(elem, DowngradeOps)
@property
- def upgrade_ops_list(self):
+ def upgrade_ops_list(self) -> List["UpgradeOps"]:
"""A list of :class:`.UpgradeOps` instances.
This is used in place of the :attr:`.MigrationScript.upgrade_ops`
@@ -2329,7 +2623,7 @@ class MigrationScript(MigrateOperation):
return self._upgrade_ops
@property
- def downgrade_ops_list(self):
+ def downgrade_ops_list(self) -> List["DowngradeOps"]:
"""A list of :class:`.DowngradeOps` instances.
This is used in place of the :attr:`.MigrationScript.downgrade_ops`
diff --git a/alembic/operations/schemaobj.py b/alembic/operations/schemaobj.py
index adbffdc..0d40dc7 100644
--- a/alembic/operations/schemaobj.py
+++ b/alembic/operations/schemaobj.py
@@ -1,3 +1,12 @@
+from typing import Any
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Sequence
+from typing import Tuple
+from typing import TYPE_CHECKING
+from typing import Union
+
from sqlalchemy import schema as sa_schema
from sqlalchemy.sql.schema import Column
from sqlalchemy.sql.schema import Constraint
@@ -9,34 +18,59 @@ from .. import util
from ..util import sqla_compat
from ..util.compat import string_types
+if TYPE_CHECKING:
+ from sqlalchemy.sql.elements import ColumnElement
+ from sqlalchemy.sql.elements import TextClause
+ from sqlalchemy.sql.schema import CheckConstraint
+ from sqlalchemy.sql.schema import ForeignKey
+ from sqlalchemy.sql.schema import ForeignKeyConstraint
+ from sqlalchemy.sql.schema import MetaData
+ from sqlalchemy.sql.schema import PrimaryKeyConstraint
+ from sqlalchemy.sql.schema import Table
+ from sqlalchemy.sql.schema import UniqueConstraint
+ from sqlalchemy.sql.type_api import TypeEngine
+
+ from ..runtime.migration import MigrationContext
+
class SchemaObjects:
- def __init__(self, migration_context=None):
+ def __init__(
+ self, migration_context: Optional["MigrationContext"] = None
+ ) -> None:
self.migration_context = migration_context
- def primary_key_constraint(self, name, table_name, cols, schema=None):
+ def primary_key_constraint(
+ self,
+ name: Optional[str],
+ table_name: str,
+ cols: Sequence[str],
+ schema: Optional[str] = None,
+ **dialect_kw
+ ) -> "PrimaryKeyConstraint":
m = self.metadata()
columns = [sa_schema.Column(n, NULLTYPE) for n in cols]
t = sa_schema.Table(table_name, m, *columns, schema=schema)
- p = sa_schema.PrimaryKeyConstraint(*[t.c[n] for n in cols], name=name)
+ p = sa_schema.PrimaryKeyConstraint(
+ *[t.c[n] for n in cols], name=name, **dialect_kw
+ )
return p
def foreign_key_constraint(
self,
- name,
- source,
- referent,
- local_cols,
- remote_cols,
- onupdate=None,
- ondelete=None,
- deferrable=None,
- source_schema=None,
- referent_schema=None,
- initially=None,
- match=None,
+ name: Optional[str],
+ source: str,
+ referent: str,
+ local_cols: List[str],
+ remote_cols: List[str],
+ onupdate: Optional[str] = None,
+ ondelete: Optional[str] = None,
+ deferrable: Optional[bool] = None,
+ source_schema: Optional[str] = None,
+ referent_schema: Optional[str] = None,
+ initially: Optional[str] = None,
+ match: Optional[str] = None,
**dialect_kw
- ):
+ ) -> "ForeignKeyConstraint":
m = self.metadata()
if source == referent and source_schema == referent_schema:
t1_cols = local_cols + remote_cols
@@ -78,7 +112,14 @@ class SchemaObjects:
return f
- def unique_constraint(self, name, source, local_cols, schema=None, **kw):
+ def unique_constraint(
+ self,
+ name: Optional[str],
+ source: str,
+ local_cols: Sequence[str],
+ schema: Optional[str] = None,
+ **kw
+ ) -> "UniqueConstraint":
t = sa_schema.Table(
source,
self.metadata(),
@@ -92,7 +133,14 @@ class SchemaObjects:
t.append_constraint(uq)
return uq
- def check_constraint(self, name, source, condition, schema=None, **kw):
+ def check_constraint(
+ self,
+ name: Optional[str],
+ source: str,
+ condition: Union["TextClause", "ColumnElement[Any]"],
+ schema: Optional[str] = None,
+ **kw
+ ) -> Union["CheckConstraint"]:
t = sa_schema.Table(
source,
self.metadata(),
@@ -103,9 +151,16 @@ class SchemaObjects:
t.append_constraint(ck)
return ck
- def generic_constraint(self, name, table_name, type_, schema=None, **kw):
+ def generic_constraint(
+ self,
+ name: Optional[str],
+ table_name: str,
+ type_: Optional[str],
+ schema: Optional[str] = None,
+ **kw
+ ) -> Any:
t = self.table(table_name, schema=schema)
- types = {
+ types: Dict[Optional[str], Any] = {
"foreignkey": lambda name: sa_schema.ForeignKeyConstraint(
[], [], name=name
),
@@ -126,7 +181,7 @@ class SchemaObjects:
t.append_constraint(const)
return const
- def metadata(self):
+ def metadata(self) -> "MetaData":
kw = {}
if (
self.migration_context is not None
@@ -137,7 +192,7 @@ class SchemaObjects:
kw["naming_convention"] = mt.naming_convention
return sa_schema.MetaData(**kw)
- def table(self, name, *columns, **kw):
+ def table(self, name: str, *columns, **kw) -> "Table":
m = self.metadata()
cols = [
@@ -173,10 +228,17 @@ class SchemaObjects:
self._ensure_table_for_fk(m, f)
return t
- def column(self, name, type_, **kw):
+ def column(self, name: str, type_: "TypeEngine", **kw) -> "Column":
return sa_schema.Column(name, type_, **kw)
- def index(self, name, tablename, columns, schema=None, **kw):
+ def index(
+ self,
+ name: str,
+ tablename: Optional[str],
+ columns: Sequence[Union[str, "TextClause", "ColumnElement[Any]"]],
+ schema: Optional[str] = None,
+ **kw
+ ) -> "Index":
t = sa_schema.Table(
tablename or "no_table",
self.metadata(),
@@ -190,23 +252,27 @@ class SchemaObjects:
)
return idx
- def _parse_table_key(self, table_key):
+ def _parse_table_key(self, table_key: str) -> Tuple[Optional[str], str]:
if "." in table_key:
tokens = table_key.split(".")
- sname = ".".join(tokens[0:-1])
+ sname: Optional[str] = ".".join(tokens[0:-1])
tname = tokens[-1]
else:
tname = table_key
sname = None
return (sname, tname)
- def _ensure_table_for_fk(self, metadata, fk):
+ def _ensure_table_for_fk(
+ self, metadata: "MetaData", fk: "ForeignKey"
+ ) -> None:
"""create a placeholder Table object for the referent of a
ForeignKey.
"""
- if isinstance(fk._colspec, string_types):
- table_key, cname = fk._colspec.rsplit(".", 1)
+ if isinstance(fk._colspec, string_types): # type:ignore[attr-defined]
+ table_key, cname = fk._colspec.rsplit( # type:ignore[attr-defined]
+ ".", 1
+ )
sname, tname = self._parse_table_key(table_key)
if table_key not in metadata.tables:
rel_t = sa_schema.Table(tname, metadata, schema=sname)
diff --git a/alembic/operations/toimpl.py b/alembic/operations/toimpl.py
index 10a41e4..f97983e 100644
--- a/alembic/operations/toimpl.py
+++ b/alembic/operations/toimpl.py
@@ -1,12 +1,19 @@
+from typing import TYPE_CHECKING
+
from sqlalchemy import schema as sa_schema
from . import ops
from .base import Operations
from ..util.sqla_compat import _copy
+if TYPE_CHECKING:
+ from sqlalchemy.sql.schema import Table
+
@Operations.implementation_for(ops.AlterColumnOp)
-def alter_column(operations, operation):
+def alter_column(
+ operations: "Operations", operation: "ops.AlterColumnOp"
+) -> None:
compiler = operations.impl.dialect.statement_compiler(
operations.impl.dialect, None
@@ -68,14 +75,16 @@ def alter_column(operations, operation):
@Operations.implementation_for(ops.DropTableOp)
-def drop_table(operations, operation):
+def drop_table(operations: "Operations", operation: "ops.DropTableOp") -> None:
operations.impl.drop_table(
operation.to_table(operations.migration_context)
)
@Operations.implementation_for(ops.DropColumnOp)
-def drop_column(operations, operation):
+def drop_column(
+ operations: "Operations", operation: "ops.DropColumnOp"
+) -> None:
column = operation.to_column(operations.migration_context)
operations.impl.drop_column(
operation.table_name, column, schema=operation.schema, **operation.kw
@@ -83,46 +92,56 @@ def drop_column(operations, operation):
@Operations.implementation_for(ops.CreateIndexOp)
-def create_index(operations, operation):
+def create_index(
+ operations: "Operations", operation: "ops.CreateIndexOp"
+) -> None:
idx = operation.to_index(operations.migration_context)
operations.impl.create_index(idx)
@Operations.implementation_for(ops.DropIndexOp)
-def drop_index(operations, operation):
+def drop_index(operations: "Operations", operation: "ops.DropIndexOp") -> None:
operations.impl.drop_index(
operation.to_index(operations.migration_context)
)
@Operations.implementation_for(ops.CreateTableOp)
-def create_table(operations, operation):
+def create_table(
+ operations: "Operations", operation: "ops.CreateTableOp"
+) -> "Table":
table = operation.to_table(operations.migration_context)
operations.impl.create_table(table)
return table
@Operations.implementation_for(ops.RenameTableOp)
-def rename_table(operations, operation):
+def rename_table(
+ operations: "Operations", operation: "ops.RenameTableOp"
+) -> None:
operations.impl.rename_table(
operation.table_name, operation.new_table_name, schema=operation.schema
)
@Operations.implementation_for(ops.CreateTableCommentOp)
-def create_table_comment(operations, operation):
+def create_table_comment(
+ operations: "Operations", operation: "ops.CreateTableCommentOp"
+) -> None:
table = operation.to_table(operations.migration_context)
operations.impl.create_table_comment(table)
@Operations.implementation_for(ops.DropTableCommentOp)
-def drop_table_comment(operations, operation):
+def drop_table_comment(
+ operations: "Operations", operation: "ops.DropTableCommentOp"
+) -> None:
table = operation.to_table(operations.migration_context)
operations.impl.drop_table_comment(table)
@Operations.implementation_for(ops.AddColumnOp)
-def add_column(operations, operation):
+def add_column(operations: "Operations", operation: "ops.AddColumnOp") -> None:
table_name = operation.table_name
column = operation.column
schema = operation.schema
@@ -150,14 +169,18 @@ def add_column(operations, operation):
@Operations.implementation_for(ops.AddConstraintOp)
-def create_constraint(operations, operation):
+def create_constraint(
+ operations: "Operations", operation: "ops.AddConstraintOp"
+) -> None:
operations.impl.add_constraint(
operation.to_constraint(operations.migration_context)
)
@Operations.implementation_for(ops.DropConstraintOp)
-def drop_constraint(operations, operation):
+def drop_constraint(
+ operations: "Operations", operation: "ops.DropConstraintOp"
+) -> None:
operations.impl.drop_constraint(
operations.schema_obj.generic_constraint(
operation.constraint_name,
@@ -169,14 +192,18 @@ def drop_constraint(operations, operation):
@Operations.implementation_for(ops.BulkInsertOp)
-def bulk_insert(operations, operation):
+def bulk_insert(
+ operations: "Operations", operation: "ops.BulkInsertOp"
+) -> None:
operations.impl.bulk_insert(
operation.table, operation.rows, multiinsert=operation.multiinsert
)
@Operations.implementation_for(ops.ExecuteSQLOp)
-def execute_sql(operations, operation):
+def execute_sql(
+ operations: "Operations", operation: "ops.ExecuteSQLOp"
+) -> None:
operations.migration_context.impl.execute(
operation.sqltext, execution_options=operation.execution_options
)
diff --git a/alembic/runtime/environment.py b/alembic/runtime/environment.py
index e4f4d42..f3473de 100644
--- a/alembic/runtime/environment.py
+++ b/alembic/runtime/environment.py
@@ -1,7 +1,30 @@
+from typing import Callable
+from typing import ContextManager
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import overload
+from typing import TextIO
+from typing import Tuple
+from typing import TYPE_CHECKING
+from typing import Union
+
from .migration import MigrationContext
from .. import util
from ..operations import Operations
+if TYPE_CHECKING:
+ from typing import Literal
+
+ from sqlalchemy.engine.base import Connection
+ from sqlalchemy.sql.schema import MetaData
+
+ from .migration import _ProxyTransaction
+ from ..config import Config
+ from ..script.base import ScriptDirectory
+
+_RevNumber = Optional[Union[str, Tuple[str, ...]]]
+
class EnvironmentContext(util.ModuleClsProxy):
@@ -66,21 +89,23 @@ class EnvironmentContext(util.ModuleClsProxy):
"""
- _migration_context = None
+ _migration_context: Optional["MigrationContext"] = None
- config = None
+ config: "Config" = None # type:ignore[assignment]
"""An instance of :class:`.Config` representing the
configuration file contents as well as other variables
set programmatically within it."""
- script = None
+ script: "ScriptDirectory" = None # type:ignore[assignment]
"""An instance of :class:`.ScriptDirectory` which provides
programmatic access to version files within the ``versions/``
directory.
"""
- def __init__(self, config, script, **kw):
+ def __init__(
+ self, config: "Config", script: "ScriptDirectory", **kw
+ ) -> None:
r"""Construct a new :class:`.EnvironmentContext`.
:param config: a :class:`.Config` instance.
@@ -94,7 +119,7 @@ class EnvironmentContext(util.ModuleClsProxy):
self.script = script
self.context_opts = kw
- def __enter__(self):
+ def __enter__(self) -> "EnvironmentContext":
"""Establish a context which provides a
:class:`.EnvironmentContext` object to
env.py scripts.
@@ -106,10 +131,10 @@ class EnvironmentContext(util.ModuleClsProxy):
self._install_proxy()
return self
- def __exit__(self, *arg, **kw):
+ def __exit__(self, *arg, **kw) -> None:
self._remove_proxy()
- def is_offline_mode(self):
+ def is_offline_mode(self) -> bool:
"""Return True if the current migrations environment
is running in "offline mode".
@@ -136,10 +161,10 @@ class EnvironmentContext(util.ModuleClsProxy):
"""
return self.get_context().impl.transactional_ddl
- def requires_connection(self):
+ def requires_connection(self) -> bool:
return not self.is_offline_mode()
- def get_head_revision(self):
+ def get_head_revision(self) -> _RevNumber:
"""Return the hex identifier of the 'head' script revision.
If the script directory has multiple heads, this
@@ -154,7 +179,7 @@ class EnvironmentContext(util.ModuleClsProxy):
"""
return self.script.as_revision_number("head")
- def get_head_revisions(self):
+ def get_head_revisions(self) -> _RevNumber:
"""Return the hex identifier of the 'heads' script revision(s).
This returns a tuple containing the version number of all
@@ -166,7 +191,7 @@ class EnvironmentContext(util.ModuleClsProxy):
"""
return self.script.as_revision_number("heads")
- def get_starting_revision_argument(self):
+ def get_starting_revision_argument(self) -> _RevNumber:
"""Return the 'starting revision' argument,
if the revision was passed using ``start:end``.
@@ -195,7 +220,7 @@ class EnvironmentContext(util.ModuleClsProxy):
"No starting revision argument is available."
)
- def get_revision_argument(self):
+ def get_revision_argument(self) -> _RevNumber:
"""Get the 'destination' revision argument.
This is typically the argument passed to the
@@ -213,7 +238,7 @@ class EnvironmentContext(util.ModuleClsProxy):
self.context_opts["destination_rev"]
)
- def get_tag_argument(self):
+ def get_tag_argument(self) -> Optional[str]:
"""Return the value passed for the ``--tag`` argument, if any.
The ``--tag`` argument is not used directly by Alembic,
@@ -233,7 +258,19 @@ class EnvironmentContext(util.ModuleClsProxy):
"""
return self.context_opts.get("tag", None)
- def get_x_argument(self, as_dictionary=False):
+ @overload
+ def get_x_argument( # type:ignore[misc]
+ self, as_dictionary: "Literal[False]" = ...
+ ) -> List[str]:
+ ...
+
+ @overload
+ def get_x_argument( # type:ignore[misc]
+ self, as_dictionary: "Literal[True]" = ...
+ ) -> Dict[str, str]:
+ ...
+
+ def get_x_argument(self, as_dictionary: bool = False):
"""Return the value(s) passed for the ``-x`` argument, if any.
The ``-x`` argument is an open ended flag that allows any user-defined
@@ -282,34 +319,34 @@ class EnvironmentContext(util.ModuleClsProxy):
def configure(
self,
- connection=None,
- url=None,
- dialect_name=None,
- dialect_opts=None,
- transactional_ddl=None,
- transaction_per_migration=False,
- output_buffer=None,
- starting_rev=None,
- tag=None,
- template_args=None,
- render_as_batch=False,
- target_metadata=None,
- include_name=None,
- include_object=None,
- include_schemas=False,
- process_revision_directives=None,
- compare_type=False,
- compare_server_default=False,
- render_item=None,
- literal_binds=False,
- upgrade_token="upgrades",
- downgrade_token="downgrades",
- alembic_module_prefix="op.",
- sqlalchemy_module_prefix="sa.",
- user_module_prefix=None,
- on_version_apply=None,
+ connection: Optional["Connection"] = None,
+ url: Optional[str] = None,
+ dialect_name: Optional[str] = None,
+ dialect_opts: Optional[dict] = None,
+ transactional_ddl: Optional[bool] = None,
+ transaction_per_migration: bool = False,
+ output_buffer: Optional[TextIO] = None,
+ starting_rev: Optional[str] = None,
+ tag: Optional[str] = None,
+ template_args: Optional[dict] = None,
+ render_as_batch: bool = False,
+ target_metadata: Optional["MetaData"] = None,
+ include_name: Optional[Callable] = None,
+ include_object: Optional[Callable] = None,
+ include_schemas: bool = False,
+ process_revision_directives: Optional[Callable] = None,
+ compare_type: bool = False,
+ compare_server_default: bool = False,
+ render_item: Optional[Callable] = None,
+ literal_binds: bool = False,
+ upgrade_token: str = "upgrades",
+ downgrade_token: str = "downgrades",
+ alembic_module_prefix: str = "op.",
+ sqlalchemy_module_prefix: str = "sa.",
+ user_module_prefix: Optional[str] = None,
+ on_version_apply: Optional[Callable] = None,
**kw
- ):
+ ) -> None:
"""Configure a :class:`.MigrationContext` within this
:class:`.EnvironmentContext` which will provide database
connectivity and other configuration to a series of
@@ -789,7 +826,7 @@ class EnvironmentContext(util.ModuleClsProxy):
opts=opts,
)
- def run_migrations(self, **kw):
+ def run_migrations(self, **kw) -> None:
"""Run migrations as determined by the current command line
configuration
as well as versioning information present (or not) in the current
@@ -809,6 +846,7 @@ class EnvironmentContext(util.ModuleClsProxy):
first been made available via :meth:`.configure`.
"""
+ assert self._migration_context is not None
with Operations.context(self._migration_context):
self.get_context().run_migrations(**kw)
@@ -837,7 +875,9 @@ class EnvironmentContext(util.ModuleClsProxy):
"""
self.get_context().impl.static_output(text)
- def begin_transaction(self):
+ def begin_transaction(
+ self,
+ ) -> Union["_ProxyTransaction", ContextManager]:
"""Return a context manager that will
enclose an operation within a "transaction",
as defined by the environment's offline
@@ -883,7 +923,7 @@ class EnvironmentContext(util.ModuleClsProxy):
return self.get_context().begin_transaction()
- def get_context(self):
+ def get_context(self) -> "MigrationContext":
"""Return the current :class:`.MigrationContext` object.
If :meth:`.EnvironmentContext.configure` has not been
diff --git a/alembic/runtime/migration.py b/alembic/runtime/migration.py
index 5ed2136..c64e91f 100644
--- a/alembic/runtime/migration.py
+++ b/alembic/runtime/migration.py
@@ -1,6 +1,18 @@
from contextlib import contextmanager
import logging
import sys
+from typing import Any
+from typing import cast
+from typing import Collection
+from typing import ContextManager
+from typing import Dict
+from typing import Iterator
+from typing import List
+from typing import Optional
+from typing import Set
+from typing import Tuple
+from typing import TYPE_CHECKING
+from typing import Union
from sqlalchemy import Column
from sqlalchemy import literal_column
@@ -17,29 +29,46 @@ from .. import util
from ..util import sqla_compat
from ..util.compat import EncodedIO
+if TYPE_CHECKING:
+ from sqlalchemy.engine import Dialect
+ from sqlalchemy.engine.base import Connection
+ from sqlalchemy.engine.base import Transaction
+ from sqlalchemy.engine.mock import MockConnection
+
+ from .environment import EnvironmentContext
+ from ..config import Config
+ from ..script.base import Script
+ from ..script.base import ScriptDirectory
+ from ..script.revision import Revision
+ from ..script.revision import RevisionMap
+
log = logging.getLogger(__name__)
class _ProxyTransaction:
- def __init__(self, migration_context):
+ def __init__(self, migration_context: "MigrationContext") -> None:
self.migration_context = migration_context
@property
- def _proxied_transaction(self):
+ def _proxied_transaction(self) -> Optional["Transaction"]:
return self.migration_context._transaction
- def rollback(self):
- self._proxied_transaction.rollback()
+ def rollback(self) -> None:
+ t = self._proxied_transaction
+ assert t is not None
+ t.rollback()
self.migration_context._transaction = None
- def commit(self):
- self._proxied_transaction.commit()
+ def commit(self) -> None:
+ t = self._proxied_transaction
+ assert t is not None
+ t.commit()
self.migration_context._transaction = None
- def __enter__(self):
+ def __enter__(self) -> "_ProxyTransaction":
return self
- def __exit__(self, type_, value, traceback):
+ def __exit__(self, type_: None, value: None, traceback: None) -> None:
if self._proxied_transaction is not None:
self._proxied_transaction.__exit__(type_, value, traceback)
self.migration_context._transaction = None
@@ -92,21 +121,29 @@ class MigrationContext:
"""
- def __init__(self, dialect, connection, opts, environment_context=None):
+ def __init__(
+ self,
+ dialect: "Dialect",
+ connection: Optional["Connection"],
+ opts: Dict[str, Any],
+ environment_context: Optional["EnvironmentContext"] = None,
+ ) -> None:
self.environment_context = environment_context
self.opts = opts
self.dialect = dialect
- self.script = opts.get("script")
- as_sql = opts.get("as_sql", False)
+ self.script: Optional["ScriptDirectory"] = opts.get("script")
+ as_sql: bool = opts.get("as_sql", False)
transactional_ddl = opts.get("transactional_ddl")
self._transaction_per_migration = opts.get(
"transaction_per_migration", False
)
self.on_version_apply_callbacks = opts.get("on_version_apply", ())
- self._transaction = None
+ self._transaction: Optional["Transaction"] = None
if as_sql:
- self.connection = self._stdout_connection(connection)
+ self.connection = cast(
+ Optional["Connection"], self._stdout_connection(connection)
+ )
assert self.connection is not None
self._in_external_transaction = False
else:
@@ -122,7 +159,8 @@ class MigrationContext:
if "output_encoding" in opts:
self.output_buffer = EncodedIO(
- opts.get("output_buffer") or sys.stdout,
+ opts.get("output_buffer")
+ or sys.stdout, # type:ignore[arg-type]
opts["output_encoding"],
)
else:
@@ -151,7 +189,7 @@ class MigrationContext:
)
)
- self._start_from_rev = opts.get("starting_rev")
+ self._start_from_rev: Optional[str] = opts.get("starting_rev")
self.impl = ddl.DefaultImpl.get_by_dialect(dialect)(
dialect,
self.connection,
@@ -173,14 +211,14 @@ class MigrationContext:
@classmethod
def configure(
cls,
- connection=None,
- url=None,
- dialect_name=None,
- dialect=None,
- environment_context=None,
- dialect_opts=None,
- opts=None,
- ):
+ connection: Optional["Connection"] = None,
+ url: Optional[str] = None,
+ dialect_name: Optional[str] = None,
+ dialect: Optional["Dialect"] = None,
+ environment_context: Optional["EnvironmentContext"] = None,
+ dialect_opts: Optional[Dict[str, str]] = None,
+ opts: Optional[Any] = None,
+ ) -> "MigrationContext":
"""Create a new :class:`.MigrationContext`.
This is a factory method usually called
@@ -216,18 +254,18 @@ class MigrationContext:
dialect = connection.dialect
elif url:
- url = sqla_url.make_url(url)
- dialect = url.get_dialect()(**dialect_opts)
+ url_obj = sqla_url.make_url(url)
+ dialect = url_obj.get_dialect()(**dialect_opts)
elif dialect_name:
- url = sqla_url.make_url("%s://" % dialect_name)
- dialect = url.get_dialect()(**dialect_opts)
+ url_obj = sqla_url.make_url("%s://" % dialect_name)
+ dialect = url_obj.get_dialect()(**dialect_opts)
elif not dialect:
raise Exception("Connection, url, or dialect_name is required.")
-
+ assert dialect is not None
return MigrationContext(dialect, connection, opts, environment_context)
@contextmanager
- def autocommit_block(self):
+ def autocommit_block(self) -> Iterator[None]:
"""Enter an "autocommit" block, for databases that support AUTOCOMMIT
isolation levels.
@@ -285,6 +323,7 @@ class MigrationContext:
self._transaction = None
if not self.as_sql:
+ assert self.connection is not None
current_level = self.connection.get_isolation_level()
base_connection = self.connection
@@ -300,6 +339,7 @@ class MigrationContext:
yield
finally:
if not self.as_sql:
+ assert self.connection is not None
self.connection.execution_options(
isolation_level=current_level
)
@@ -309,9 +349,12 @@ class MigrationContext:
self.impl.emit_begin()
elif _in_connection_transaction:
+ assert self.connection is not None
self._transaction = self.connection.begin()
- def begin_transaction(self, _per_migration=False):
+ def begin_transaction(
+ self, _per_migration: bool = False
+ ) -> Union["_ProxyTransaction", ContextManager]:
"""Begin a logical transaction for migration operations.
This method is used within an ``env.py`` script to demarcate where
@@ -390,6 +433,7 @@ class MigrationContext:
if in_transaction:
return do_nothing()
else:
+ assert self.connection is not None
self._transaction = (
sqla_compat._safe_begin_connection_transaction(
self.connection
@@ -406,12 +450,13 @@ class MigrationContext:
return begin_commit()
else:
+ assert self.connection is not None
self._transaction = sqla_compat._safe_begin_connection_transaction(
self.connection
)
return _ProxyTransaction(self)
- def get_current_revision(self):
+ def get_current_revision(self) -> Optional[str]:
"""Return the current revision, usually that which is present
in the ``alembic_version`` table in the database.
@@ -438,7 +483,7 @@ class MigrationContext:
else:
return heads[0]
- def get_current_heads(self):
+ def get_current_heads(self) -> Tuple[str, ...]:
"""Return a tuple of the current 'head versions' that are represented
in the target database.
@@ -457,7 +502,7 @@ class MigrationContext:
"""
if self.as_sql:
- start_from_rev = self._start_from_rev
+ start_from_rev: Any = self._start_from_rev
if start_from_rev == "base":
start_from_rev = None
elif start_from_rev is not None and self.script:
@@ -476,22 +521,27 @@ class MigrationContext:
)
if not self._has_version_table():
return ()
+ assert self.connection is not None
return tuple(
row[0] for row in self.connection.execute(self._version.select())
)
- def _ensure_version_table(self, purge=False):
+ def _ensure_version_table(self, purge: bool = False) -> None:
with sqla_compat._ensure_scope_for_ddl(self.connection):
self._version.create(self.connection, checkfirst=True)
if purge:
+ assert self.connection is not None
self.connection.execute(self._version.delete())
- def _has_version_table(self):
+ def _has_version_table(self) -> bool:
+ assert self.connection is not None
return sqla_compat._connectable_has_table(
self.connection, self.version_table, self.version_table_schema
)
- def stamp(self, script_directory, revision):
+ def stamp(
+ self, script_directory: "ScriptDirectory", revision: str
+ ) -> None:
"""Stamp the version table with a specific revision.
This method calculates those branches to which the given revision
@@ -507,7 +557,7 @@ class MigrationContext:
for step in script_directory._stamp_revs(revision, heads):
head_maintainer.update_to_step(step)
- def run_migrations(self, **kw):
+ def run_migrations(self, **kw) -> None:
r"""Run the migration scripts established for this
:class:`.MigrationContext`, if any.
@@ -530,6 +580,7 @@ class MigrationContext:
"""
self.impl.start_migrations()
+ heads: Tuple[str, ...]
if self.purge:
if self.as_sql:
raise util.CommandError("Can't use --purge with --sql mode")
@@ -545,6 +596,7 @@ class MigrationContext:
head_maintainer = HeadMaintainer(self, heads)
+ assert self._migrations_fn is not None
for step in self._migrations_fn(heads, self):
with self.begin_transaction(_per_migration=True):
@@ -576,15 +628,15 @@ class MigrationContext:
if self.as_sql and not head_maintainer.heads:
self._version.drop(self.connection)
- def _in_connection_transaction(self):
+ def _in_connection_transaction(self) -> bool:
try:
- meth = self.connection.in_transaction
+ meth = self.connection.in_transaction # type:ignore[union-attr]
except AttributeError:
return False
else:
return meth()
- def execute(self, sql, execution_options=None):
+ def execute(self, sql: str, execution_options: None = None) -> None:
"""Execute a SQL construct or string statement.
The underlying execution mechanics are used, that is
@@ -595,14 +647,16 @@ class MigrationContext:
"""
self.impl._exec(sql, execution_options)
- def _stdout_connection(self, connection):
+ def _stdout_connection(
+ self, connection: Optional["Connection"]
+ ) -> "MockConnection":
def dump(construct, *multiparams, **params):
self.impl._exec(construct)
return MockEngineStrategy.MockConnection(self.dialect, dump)
@property
- def bind(self):
+ def bind(self) -> Optional["Connection"]:
"""Return the current "bind".
In online mode, this is an instance of
@@ -623,7 +677,7 @@ class MigrationContext:
return self.connection
@property
- def config(self):
+ def config(self) -> Optional["Config"]:
"""Return the :class:`.Config` used by the current environment,
if any."""
@@ -632,7 +686,9 @@ class MigrationContext:
else:
return None
- def _compare_type(self, inspector_column, metadata_column):
+ def _compare_type(
+ self, inspector_column: "Column", metadata_column: "Column"
+ ) -> bool:
if self._user_compare_type is False:
return False
@@ -651,11 +707,11 @@ class MigrationContext:
def _compare_server_default(
self,
- inspector_column,
- metadata_column,
- rendered_metadata_default,
- rendered_column_default,
- ):
+ inspector_column: "Column",
+ metadata_column: "Column",
+ rendered_metadata_default: Optional[str],
+ rendered_column_default: Optional[str],
+ ) -> bool:
if self._user_compare_server_default is False:
return False
@@ -681,11 +737,11 @@ class MigrationContext:
class HeadMaintainer:
- def __init__(self, context, heads):
+ def __init__(self, context: "MigrationContext", heads: Any) -> None:
self.context = context
self.heads = set(heads)
- def _insert_version(self, version):
+ def _insert_version(self, version: str) -> None:
assert version not in self.heads
self.heads.add(version)
@@ -695,7 +751,7 @@ class HeadMaintainer:
)
)
- def _delete_version(self, version):
+ def _delete_version(self, version: str) -> None:
self.heads.remove(version)
ret = self.context.impl._exec(
@@ -716,7 +772,7 @@ class HeadMaintainer:
% (version, self.context.version_table, ret.rowcount)
)
- def _update_version(self, from_, to_):
+ def _update_version(self, from_: str, to_: str) -> None:
assert to_ not in self.heads
self.heads.remove(from_)
self.heads.add(to_)
@@ -741,7 +797,7 @@ class HeadMaintainer:
% (from_, to_, self.context.version_table, ret.rowcount)
)
- def update_to_step(self, step):
+ def update_to_step(self, step: Union["RevisionStep", "StampStep"]) -> None:
if step.should_delete_branch(self.heads):
vers = step.delete_version_num
log.debug("branch delete %s", vers)
@@ -796,15 +852,15 @@ class MigrationInfo:
"""
- is_upgrade = None
+ is_upgrade: bool = None # type:ignore[assignment]
"""True/False: indicates whether this operation ascends or descends the
version tree."""
- is_stamp = None
+ is_stamp: bool = None # type:ignore[assignment]
"""True/False: indicates whether this operation is a stamp (i.e. whether
it results in any actual database operations)."""
- up_revision_id = None
+ up_revision_id: Optional[str] = None
"""Version string corresponding to :attr:`.Revision.revision`.
In the case of a stamp operation, it is advised to use the
@@ -818,7 +874,7 @@ class MigrationInfo:
"""
- up_revision_ids = None
+ up_revision_ids: Tuple[str, ...] = None # type:ignore[assignment]
"""Tuple of version strings corresponding to :attr:`.Revision.revision`.
In the majority of cases, this tuple will be a single value, synonomous
@@ -829,7 +885,7 @@ class MigrationInfo:
"""
- down_revision_ids = None
+ down_revision_ids: Tuple[str, ...] = None # type:ignore[assignment]
"""Tuple of strings representing the base revisions of this migration step.
If empty, this represents a root revision; otherwise, the first item
@@ -837,12 +893,17 @@ class MigrationInfo:
from dependencies.
"""
- revision_map = None
+ revision_map: "RevisionMap" = None # type:ignore[assignment]
"""The revision map inside of which this operation occurs."""
def __init__(
- self, revision_map, is_upgrade, is_stamp, up_revisions, down_revisions
- ):
+ self,
+ revision_map: "RevisionMap",
+ is_upgrade: bool,
+ is_stamp: bool,
+ up_revisions: Union[str, Tuple[str, ...]],
+ down_revisions: Union[str, Tuple[str, ...]],
+ ) -> None:
self.revision_map = revision_map
self.is_upgrade = is_upgrade
self.is_stamp = is_stamp
@@ -857,7 +918,7 @@ class MigrationInfo:
self.down_revision_ids = util.to_tuple(down_revisions, default=())
@property
- def is_migration(self):
+ def is_migration(self) -> bool:
"""True/False: indicates whether this operation is a migration.
At present this is true if and only the migration is not a stamp.
@@ -867,21 +928,21 @@ class MigrationInfo:
return not self.is_stamp
@property
- def source_revision_ids(self):
+ def source_revision_ids(self) -> Tuple[str, ...]:
"""Active revisions before this migration step is applied."""
return (
self.down_revision_ids if self.is_upgrade else self.up_revision_ids
)
@property
- def destination_revision_ids(self):
+ def destination_revision_ids(self) -> Tuple[str, ...]:
"""Active revisions after this migration step is applied."""
return (
self.up_revision_ids if self.is_upgrade else self.down_revision_ids
)
@property
- def up_revision(self):
+ def up_revision(self) -> "Revision":
"""Get :attr:`~.MigrationInfo.up_revision_id` as
a :class:`.Revision`.
@@ -889,49 +950,59 @@ class MigrationInfo:
return self.revision_map.get_revision(self.up_revision_id)
@property
- def up_revisions(self):
+ def up_revisions(self) -> Tuple["Revision", ...]:
"""Get :attr:`~.MigrationInfo.up_revision_ids` as a
:class:`.Revision`."""
return self.revision_map.get_revisions(self.up_revision_ids)
@property
- def down_revisions(self):
+ def down_revisions(self) -> Tuple["Revision", ...]:
"""Get :attr:`~.MigrationInfo.down_revision_ids` as a tuple of
:class:`Revisions <.Revision>`."""
return self.revision_map.get_revisions(self.down_revision_ids)
@property
- def source_revisions(self):
+ def source_revisions(self) -> Tuple["Revision", ...]:
"""Get :attr:`~MigrationInfo.source_revision_ids` as a tuple of
:class:`Revisions <.Revision>`."""
return self.revision_map.get_revisions(self.source_revision_ids)
@property
- def destination_revisions(self):
+ def destination_revisions(self) -> Tuple["Revision", ...]:
"""Get :attr:`~MigrationInfo.destination_revision_ids` as a tuple of
:class:`Revisions <.Revision>`."""
return self.revision_map.get_revisions(self.destination_revision_ids)
class MigrationStep:
+
+ from_revisions_no_deps: Tuple[str, ...]
+ to_revisions_no_deps: Tuple[str, ...]
+ is_upgrade: bool
+ migration_fn: Any
+
@property
- def name(self):
+ def name(self) -> str:
return self.migration_fn.__name__
@classmethod
- def upgrade_from_script(cls, revision_map, script):
+ def upgrade_from_script(
+ cls, revision_map: "RevisionMap", script: "Script"
+ ) -> "RevisionStep":
return RevisionStep(revision_map, script, True)
@classmethod
- def downgrade_from_script(cls, revision_map, script):
+ def downgrade_from_script(
+ cls, revision_map: "RevisionMap", script: "Script"
+ ) -> "RevisionStep":
return RevisionStep(revision_map, script, False)
@property
- def is_downgrade(self):
+ def is_downgrade(self) -> bool:
return not self.is_upgrade
@property
- def short_log(self):
+ def short_log(self) -> str:
return "%s %s -> %s" % (
self.name,
util.format_as_comma(self.from_revisions_no_deps),
@@ -951,14 +1022,20 @@ class MigrationStep:
class RevisionStep(MigrationStep):
- def __init__(self, revision_map, revision, is_upgrade):
+ def __init__(
+ self, revision_map: "RevisionMap", revision: "Script", is_upgrade: bool
+ ) -> None:
self.revision_map = revision_map
self.revision = revision
self.is_upgrade = is_upgrade
if is_upgrade:
- self.migration_fn = revision.module.upgrade
+ self.migration_fn = (
+ revision.module.upgrade # type:ignore[attr-defined]
+ )
else:
- self.migration_fn = revision.module.downgrade
+ self.migration_fn = (
+ revision.module.downgrade # type:ignore[attr-defined]
+ )
def __repr__(self):
return "RevisionStep(%r, is_upgrade=%r)" % (
@@ -966,7 +1043,7 @@ class RevisionStep(MigrationStep):
self.is_upgrade,
)
- def __eq__(self, other):
+ def __eq__(self, other: object) -> bool:
return (
isinstance(other, RevisionStep)
and other.revision == self.revision
@@ -978,38 +1055,42 @@ class RevisionStep(MigrationStep):
return self.revision.doc
@property
- def from_revisions(self):
+ def from_revisions(self) -> Tuple[str, ...]:
if self.is_upgrade:
return self.revision._normalized_down_revisions
else:
return (self.revision.revision,)
@property
- def from_revisions_no_deps(self):
+ def from_revisions_no_deps( # type:ignore[override]
+ self,
+ ) -> Tuple[str, ...]:
if self.is_upgrade:
return self.revision._versioned_down_revisions
else:
return (self.revision.revision,)
@property
- def to_revisions(self):
+ def to_revisions(self) -> Tuple[str, ...]:
if self.is_upgrade:
return (self.revision.revision,)
else:
return self.revision._normalized_down_revisions
@property
- def to_revisions_no_deps(self):
+ def to_revisions_no_deps( # type:ignore[override]
+ self,
+ ) -> Tuple[str, ...]:
if self.is_upgrade:
return (self.revision.revision,)
else:
return self.revision._versioned_down_revisions
@property
- def _has_scalar_down_revision(self):
+ def _has_scalar_down_revision(self) -> bool:
return len(self.revision._normalized_down_revisions) == 1
- def should_delete_branch(self, heads):
+ def should_delete_branch(self, heads: Set[str]) -> bool:
"""A delete is when we are a. in a downgrade and b.
we are going to the "base" or we are going to a version that
is implied as a dependency on another version that is remaining.
@@ -1032,7 +1113,9 @@ class RevisionStep(MigrationStep):
to_revisions = self._unmerge_to_revisions(heads)
return not to_revisions
- def merge_branch_idents(self, heads):
+ def merge_branch_idents(
+ self, heads: Set[str]
+ ) -> Tuple[List[str], str, str]:
other_heads = set(heads).difference(self.from_revisions)
if other_heads:
@@ -1055,7 +1138,7 @@ class RevisionStep(MigrationStep):
self.to_revisions[0],
)
- def _unmerge_to_revisions(self, heads):
+ def _unmerge_to_revisions(self, heads: Collection[str]) -> Tuple[str, ...]:
other_heads = set(heads).difference([self.revision.revision])
if other_heads:
ancestors = set(
@@ -1064,11 +1147,13 @@ class RevisionStep(MigrationStep):
self.revision_map.get_revisions(other_heads), check=False
)
)
- return list(set(self.to_revisions).difference(ancestors))
+ return tuple(set(self.to_revisions).difference(ancestors))
else:
return self.to_revisions
- def unmerge_branch_idents(self, heads):
+ def unmerge_branch_idents(
+ self, heads: Collection[str]
+ ) -> Tuple[str, str, Tuple[str, ...]]:
to_revisions = self._unmerge_to_revisions(heads)
return (
@@ -1078,7 +1163,7 @@ class RevisionStep(MigrationStep):
to_revisions[0:-1],
)
- def should_create_branch(self, heads):
+ def should_create_branch(self, heads: Set[str]) -> bool:
if not self.is_upgrade:
return False
@@ -1097,7 +1182,7 @@ class RevisionStep(MigrationStep):
else:
return False
- def should_merge_branches(self, heads):
+ def should_merge_branches(self, heads: Set[str]) -> bool:
if not self.is_upgrade:
return False
@@ -1108,7 +1193,7 @@ class RevisionStep(MigrationStep):
return False
- def should_unmerge_branches(self, heads):
+ def should_unmerge_branches(self, heads: Set[str]) -> bool:
if not self.is_downgrade:
return False
@@ -1119,7 +1204,7 @@ class RevisionStep(MigrationStep):
return False
- def update_version_num(self, heads):
+ def update_version_num(self, heads: Set[str]) -> Tuple[str, str]:
if not self._has_scalar_down_revision:
downrev = heads.intersection(
self.revision._normalized_down_revisions
@@ -1137,15 +1222,15 @@ class RevisionStep(MigrationStep):
return self.revision.revision, down_revision
@property
- def delete_version_num(self):
+ def delete_version_num(self) -> str:
return self.revision.revision
@property
- def insert_version_num(self):
+ def insert_version_num(self) -> str:
return self.revision.revision
@property
- def info(self):
+ def info(self) -> "MigrationInfo":
return MigrationInfo(
revision_map=self.revision_map,
up_revisions=self.revision.revision,
@@ -1156,9 +1241,16 @@ class RevisionStep(MigrationStep):
class StampStep(MigrationStep):
- def __init__(self, from_, to_, is_upgrade, branch_move, revision_map=None):
- self.from_ = util.to_tuple(from_, default=())
- self.to_ = util.to_tuple(to_, default=())
+ def __init__(
+ self,
+ from_: Optional[Union[str, Collection[str]]],
+ to_: Optional[Union[str, Collection[str]]],
+ is_upgrade: bool,
+ branch_move: bool,
+ revision_map: Optional["RevisionMap"] = None,
+ ) -> None:
+ self.from_: Tuple[str, ...] = util.to_tuple(from_, default=())
+ self.to_: Tuple[str, ...] = util.to_tuple(to_, default=())
self.is_upgrade = is_upgrade
self.branch_move = branch_move
self.migration_fn = self.stamp_revision
@@ -1166,7 +1258,7 @@ class StampStep(MigrationStep):
doc = None
- def stamp_revision(self, **kw):
+ def stamp_revision(self, **kw) -> None:
return None
def __eq__(self, other):
@@ -1183,33 +1275,39 @@ class StampStep(MigrationStep):
return self.from_
@property
- def to_revisions(self):
+ def to_revisions(self) -> Tuple[str, ...]:
return self.to_
@property
- def from_revisions_no_deps(self):
+ def from_revisions_no_deps( # type:ignore[override]
+ self,
+ ) -> Tuple[str, ...]:
return self.from_
@property
- def to_revisions_no_deps(self):
+ def to_revisions_no_deps( # type:ignore[override]
+ self,
+ ) -> Tuple[str, ...]:
return self.to_
@property
- def delete_version_num(self):
+ def delete_version_num(self) -> str:
assert len(self.from_) == 1
return self.from_[0]
@property
- def insert_version_num(self):
+ def insert_version_num(self) -> str:
assert len(self.to_) == 1
return self.to_[0]
- def update_version_num(self, heads):
+ def update_version_num(self, heads: Set[str]) -> Tuple[str, str]:
assert len(self.from_) == 1
assert len(self.to_) == 1
return self.from_[0], self.to_[0]
- def merge_branch_idents(self, heads):
+ def merge_branch_idents(
+ self, heads: Union[Set[str], List[str]]
+ ) -> Union[Tuple[List[Any], str, str], Tuple[List[str], str, str]]:
return (
# delete revs, update from rev, update to rev
list(self.from_[0:-1]),
@@ -1217,7 +1315,9 @@ class StampStep(MigrationStep):
self.to_[0],
)
- def unmerge_branch_idents(self, heads):
+ def unmerge_branch_idents(
+ self, heads: Set[str]
+ ) -> Tuple[str, str, List[str]]:
return (
# update from rev, update to rev, insert revs
self.from_[0],
@@ -1225,32 +1325,33 @@ class StampStep(MigrationStep):
list(self.to_[0:-1]),
)
- def should_delete_branch(self, heads):
+ def should_delete_branch(self, heads: Set[str]) -> bool:
# TODO: we probably need to look for self.to_ inside of heads,
# in a similar manner as should_create_branch, however we have
# no tests for this yet (stamp downgrades w/ branches)
return self.is_downgrade and self.branch_move
- def should_create_branch(self, heads):
+ def should_create_branch(self, heads: Set[str]) -> Union[Set[str], bool]:
return (
self.is_upgrade
and (self.branch_move or set(self.from_).difference(heads))
and set(self.to_).difference(heads)
)
- def should_merge_branches(self, heads):
+ def should_merge_branches(self, heads: Set[str]) -> bool:
return len(self.from_) > 1
- def should_unmerge_branches(self, heads):
+ def should_unmerge_branches(self, heads: Set[str]) -> bool:
return len(self.to_) > 1
@property
- def info(self):
+ def info(self) -> "MigrationInfo":
up, down = (
(self.to_, self.from_)
if self.is_upgrade
else (self.from_, self.to_)
)
+ assert self.revision_map is not None
return MigrationInfo(
revision_map=self.revision_map,
up_revisions=up,
diff --git a/alembic/script/base.py b/alembic/script/base.py
index d0500c4..ef0fd52 100644
--- a/alembic/script/base.py
+++ b/alembic/script/base.py
@@ -4,16 +4,35 @@ import os
import re
import shutil
import sys
+from types import ModuleType
+from typing import Any
+from typing import cast
+from typing import Dict
+from typing import Iterator
+from typing import List
+from typing import Optional
+from typing import Sequence
+from typing import Set
+from typing import Tuple
+from typing import TYPE_CHECKING
+from typing import Union
from . import revision
from . import write_hooks
from .. import util
from ..runtime import migration
+if TYPE_CHECKING:
+ from ..config import Config
+ from ..runtime.migration import RevisionStep
+ from ..runtime.migration import StampStep
+
try:
from dateutil import tz
except ImportError:
- tz = None # noqa
+ tz = None # type: ignore[assignment]
+
+_RevIdType = Union[str, Sequence[str]]
_sourceless_rev_file = re.compile(r"(?!\.\#|__init__)(.*\.py)(c|o)?$")
_only_source_rev_file = re.compile(r"(?!\.\#|__init__)(.*\.py)$")
@@ -49,15 +68,15 @@ class ScriptDirectory:
def __init__(
self,
- dir, # noqa
- file_template=_default_file_template,
- truncate_slug_length=40,
- version_locations=None,
- sourceless=False,
- output_encoding="utf-8",
- timezone=None,
- hook_config=None,
- ):
+ dir: str, # noqa
+ file_template: str = _default_file_template,
+ truncate_slug_length: Optional[int] = 40,
+ version_locations: Optional[List[str]] = None,
+ sourceless: bool = False,
+ output_encoding: str = "utf-8",
+ timezone: Optional[str] = None,
+ hook_config: Optional[Dict[str, str]] = None,
+ ) -> None:
self.dir = dir
self.file_template = file_template
self.version_locations = version_locations
@@ -76,7 +95,7 @@ class ScriptDirectory:
)
@property
- def versions(self):
+ def versions(self) -> str:
loc = self._version_locations
if len(loc) > 1:
raise util.CommandError("Multiple version_locations present")
@@ -93,7 +112,7 @@ class ScriptDirectory:
else:
return (os.path.abspath(os.path.join(self.dir, "versions")),)
- def _load_revisions(self):
+ def _load_revisions(self) -> Iterator["Script"]:
if self.version_locations:
paths = [
vers
@@ -120,7 +139,7 @@ class ScriptDirectory:
yield script
@classmethod
- def from_config(cls, config):
+ def from_config(cls, config: "Config") -> "ScriptDirectory":
"""Produce a new :class:`.ScriptDirectory` given a :class:`.Config`
instance.
@@ -133,7 +152,9 @@ class ScriptDirectory:
raise util.CommandError(
"No 'script_location' key " "found in configuration."
)
- truncate_slug_length = config.get_main_option("truncate_slug_length")
+ truncate_slug_length = cast(
+ Optional[int], config.get_main_option("truncate_slug_length")
+ )
if truncate_slug_length is not None:
truncate_slug_length = int(truncate_slug_length)
@@ -162,13 +183,17 @@ class ScriptDirectory:
else:
if split_char is None:
# legacy behaviour for backwards compatibility
- version_locations = _split_on_space_comma.split(
- version_locations
+ vl = _split_on_space_comma.split(
+ cast(str, version_locations)
)
+ version_locations: List[str] = vl # type: ignore[no-redef]
else:
- version_locations = [
- x for x in version_locations.split(split_char) if x
+ vl = [
+ x
+ for x in cast(str, version_locations).split(split_char)
+ if x
]
+ version_locations: List[str] = vl # type: ignore[no-redef]
prepend_sys_path = config.get_main_option("prepend_sys_path")
if prepend_sys_path:
@@ -184,7 +209,7 @@ class ScriptDirectory:
truncate_slug_length=truncate_slug_length,
sourceless=config.get_main_option("sourceless") == "true",
output_encoding=config.get_main_option("output_encoding", "utf-8"),
- version_locations=version_locations,
+ version_locations=cast("Optional[List[str]]", version_locations),
timezone=config.get_main_option("timezone"),
hook_config=config.get_section("post_write_hooks", {}),
)
@@ -192,19 +217,19 @@ class ScriptDirectory:
@contextmanager
def _catch_revision_errors(
self,
- ancestor=None,
- multiple_heads=None,
- start=None,
- end=None,
- resolution=None,
- ):
+ ancestor: Optional[str] = None,
+ multiple_heads: Optional[str] = None,
+ start: Optional[str] = None,
+ end: Optional[str] = None,
+ resolution: Optional[str] = None,
+ ) -> Iterator[None]:
try:
yield
except revision.RangeNotAncestorError as rna:
if start is None:
- start = rna.lower
+ start = cast(Any, rna.lower)
if end is None:
- end = rna.upper
+ end = cast(Any, rna.upper)
if not ancestor:
ancestor = (
"Requested range %(start)s:%(end)s does not refer to "
@@ -235,7 +260,9 @@ class ScriptDirectory:
except revision.RevisionError as err:
raise util.CommandError(err.args[0]) from err
- def walk_revisions(self, base="base", head="heads"):
+ def walk_revisions(
+ self, base: str = "base", head: str = "heads"
+ ) -> Iterator["Script"]:
"""Iterate through all revisions.
:param base: the base revision, or "base" to start from the
@@ -250,28 +277,36 @@ class ScriptDirectory:
for rev in self.revision_map.iterate_revisions(
head, base, inclusive=True, assert_relative_length=False
):
- yield rev
+ yield cast(Script, rev)
- def get_revisions(self, id_):
+ def get_revisions(self, id_: _RevIdType) -> Tuple["Script", ...]:
"""Return the :class:`.Script` instance with the given rev identifier,
symbolic name, or sequence of identifiers.
"""
with self._catch_revision_errors():
- return self.revision_map.get_revisions(id_)
+ return cast(
+ "Tuple[Script, ...]", self.revision_map.get_revisions(id_)
+ )
- def get_all_current(self, id_):
+ def get_all_current(self, id_: Tuple[str, ...]) -> Set["Script"]:
with self._catch_revision_errors():
- top_revs = set(self.revision_map.get_revisions(id_))
+ top_revs = cast(
+ "Set[Script]",
+ set(self.revision_map.get_revisions(id_)),
+ )
top_revs.update(
- self.revision_map._get_ancestor_nodes(
- list(top_revs), include_dependencies=True
+ cast(
+ "Iterator[Script]",
+ self.revision_map._get_ancestor_nodes(
+ list(top_revs), include_dependencies=True
+ ),
)
)
top_revs = self.revision_map._filter_into_branch_heads(top_revs)
return top_revs
- def get_revision(self, id_):
+ def get_revision(self, id_: str) -> "Script":
"""Return the :class:`.Script` instance with the given rev id.
.. seealso::
@@ -281,9 +316,11 @@ class ScriptDirectory:
"""
with self._catch_revision_errors():
- return self.revision_map.get_revision(id_)
+ return cast(Script, self.revision_map.get_revision(id_))
- def as_revision_number(self, id_):
+ def as_revision_number(
+ self, id_: Optional[str]
+ ) -> Optional[Union[str, Tuple[str, ...]]]:
"""Convert a symbolic revision, i.e. 'head' or 'base', into
an actual revision number."""
@@ -340,7 +377,7 @@ class ScriptDirectory:
):
return self.revision_map.get_current_head()
- def get_heads(self):
+ def get_heads(self) -> List[str]:
"""Return all "versioned head" revisions as strings.
This is normally a list of length one,
@@ -353,7 +390,7 @@ class ScriptDirectory:
"""
return list(self.revision_map.heads)
- def get_base(self):
+ def get_base(self) -> Optional[str]:
"""Return the "base" revision as a string.
This is the revision number of the script that
@@ -375,7 +412,7 @@ class ScriptDirectory:
else:
return None
- def get_bases(self):
+ def get_bases(self) -> List[str]:
"""return all "base" revisions as strings.
This is the revision number of all scripts that
@@ -384,7 +421,9 @@ class ScriptDirectory:
"""
return list(self.revision_map.bases)
- def _upgrade_revs(self, destination, current_rev):
+ def _upgrade_revs(
+ self, destination: str, current_rev: str
+ ) -> List["RevisionStep"]:
with self._catch_revision_errors(
ancestor="Destination %(end)s is not a valid upgrade "
"target from current head(s)",
@@ -393,15 +432,16 @@ class ScriptDirectory:
revs = self.revision_map.iterate_revisions(
destination, current_rev, implicit_base=True
)
- revs = list(revs)
return [
migration.MigrationStep.upgrade_from_script(
- self.revision_map, script
+ self.revision_map, cast(Script, script)
)
for script in reversed(list(revs))
]
- def _downgrade_revs(self, destination, current_rev):
+ def _downgrade_revs(
+ self, destination: str, current_rev: Optional[str]
+ ) -> List["RevisionStep"]:
with self._catch_revision_errors(
ancestor="Destination %(end)s is not a valid downgrade "
"target from current head(s)",
@@ -412,30 +452,32 @@ class ScriptDirectory:
)
return [
migration.MigrationStep.downgrade_from_script(
- self.revision_map, script
+ self.revision_map, cast(Script, script)
)
for script in revs
]
- def _stamp_revs(self, revision, heads):
+ def _stamp_revs(
+ self, revision: _RevIdType, heads: _RevIdType
+ ) -> List["StampStep"]:
with self._catch_revision_errors(
multiple_heads="Multiple heads are present; please specify a "
"single target revision"
):
- heads = self.get_revisions(heads)
+ heads_revs = self.get_revisions(heads)
steps = []
if not revision:
revision = "base"
- filtered_heads = []
+ filtered_heads: List["Script"] = []
for rev in util.to_tuple(revision):
if rev:
filtered_heads.extend(
self.revision_map.filter_for_lineage(
- heads, rev, include_dependencies=True
+ heads_revs, rev, include_dependencies=True
)
)
filtered_heads = util.unique_list(filtered_heads)
@@ -509,7 +551,7 @@ class ScriptDirectory:
return steps
- def run_env(self):
+ def run_env(self) -> None:
"""Run the script environment.
This basically runs the ``env.py`` script present
@@ -524,7 +566,7 @@ class ScriptDirectory:
def env_py_location(self):
return os.path.abspath(os.path.join(self.dir, "env.py"))
- def _generate_template(self, src, dest, **kw):
+ def _generate_template(self, src: str, dest: str, **kw: Any) -> None:
util.status(
"Generating %s" % os.path.abspath(dest),
util.template_to_file,
@@ -534,17 +576,17 @@ class ScriptDirectory:
**kw
)
- def _copy_file(self, src, dest):
+ def _copy_file(self, src: str, dest: str) -> None:
util.status(
"Generating %s" % os.path.abspath(dest), shutil.copy, src, dest
)
- def _ensure_directory(self, path):
+ def _ensure_directory(self, path: str) -> None:
path = os.path.abspath(path)
if not os.path.exists(path):
util.status("Creating directory %s" % path, os.makedirs, path)
- def _generate_create_date(self):
+ def _generate_create_date(self) -> "datetime.datetime":
if self.timezone is not None:
if tz is None:
raise util.CommandError(
@@ -571,16 +613,16 @@ class ScriptDirectory:
def generate_revision(
self,
- revid,
- message,
- head=None,
- refresh=False,
- splice=False,
- branch_labels=None,
- version_path=None,
- depends_on=None,
- **kw
- ):
+ revid: str,
+ message: Optional[str],
+ head: Optional[str] = None,
+ refresh: bool = False,
+ splice: Optional[bool] = False,
+ branch_labels: Optional[str] = None,
+ version_path: Optional[str] = None,
+ depends_on: Optional[_RevIdType] = None,
+ **kw: Any
+ ) -> Optional["Script"]:
"""Generate a new revision file.
This runs the ``script.py.mako`` template, given
@@ -623,9 +665,10 @@ class ScriptDirectory:
if version_path is None:
if len(self._version_locations) > 1:
- for head in heads:
- if head is not None:
- version_path = os.path.dirname(head.path)
+ for head_ in heads:
+ if head_ is not None:
+ assert isinstance(head_, Script)
+ version_path = os.path.dirname(head_.path)
break
else:
raise util.CommandError(
@@ -651,12 +694,12 @@ class ScriptDirectory:
path = self._rev_path(version_path, revid, message, create_date)
if not splice:
- for head in heads:
- if head is not None and not head.is_head:
+ for head_ in heads:
+ if head_ is not None and not head_.is_head:
raise util.CommandError(
"Revision %s is not a head revision; please specify "
"--splice to create a new branch from this revision"
- % head.revision
+ % head_.revision
)
if depends_on:
@@ -679,7 +722,9 @@ class ScriptDirectory:
tuple(h.revision if h is not None else None for h in heads)
),
branch_labels=util.to_tuple(branch_labels),
- depends_on=revision.tuple_rev_as_scalar(depends_on),
+ depends_on=revision.tuple_rev_as_scalar(
+ cast("Optional[List[str]]", depends_on)
+ ),
create_date=create_date,
comma=util.format_as_comma,
message=message if message is not None else ("empty message"),
@@ -694,6 +739,8 @@ class ScriptDirectory:
script = Script._from_path(self, path)
except revision.RevisionError as err:
raise util.CommandError(err.args[0]) from err
+ if script is None:
+ return None
if branch_labels and not script.branch_labels:
raise util.CommandError(
"Version %s specified branch_labels %s, however the "
@@ -702,11 +749,16 @@ class ScriptDirectory:
"'branch_labels' section?"
% (script.revision, branch_labels, script.path)
)
-
self.revision_map.add_revision(script)
return script
- def _rev_path(self, path, rev_id, message, create_date):
+ def _rev_path(
+ self,
+ path: str,
+ rev_id: str,
+ message: Optional[str],
+ create_date: "datetime.datetime",
+ ) -> str:
slug = "_".join(_slug_re.findall(message or "")).lower()
if len(slug) > self.truncate_slug_length:
slug = slug[: self.truncate_slug_length].rsplit("_", 1)[0] + "_"
@@ -735,12 +787,12 @@ class Script(revision.Revision):
"""
- def __init__(self, module, rev_id, path):
+ def __init__(self, module: ModuleType, rev_id: str, path: str):
self.module = module
self.path = path
super(Script, self).__init__(
rev_id,
- module.down_revision,
+ module.down_revision, # type: ignore[attr-defined]
branch_labels=util.to_tuple(
getattr(module, "branch_labels", None), default=()
),
@@ -749,10 +801,10 @@ class Script(revision.Revision):
),
)
- module = None
+ module: ModuleType = None # type: ignore[assignment]
"""The Python module representing the actual script itself."""
- path = None
+ path: str = None # type: ignore[assignment]
"""Filesystem path of the script."""
_db_current_indicator = None
@@ -760,25 +812,27 @@ class Script(revision.Revision):
this is a "current" version in some database"""
@property
- def doc(self):
+ def doc(self) -> str:
"""Return the docstring given in the script."""
return re.split("\n\n", self.longdoc)[0]
@property
- def longdoc(self):
+ def longdoc(self) -> str:
"""Return the docstring given in the script."""
doc = self.module.__doc__
if doc:
if hasattr(self.module, "_alembic_source_encoding"):
- doc = doc.decode(self.module._alembic_source_encoding)
- return doc.strip()
+ doc = doc.decode( # type: ignore[attr-defined]
+ self.module._alembic_source_encoding # type: ignore[attr-defined] # noqa
+ )
+ return doc.strip() # type: ignore[union-attr]
else:
return ""
@property
- def log_entry(self):
+ def log_entry(self) -> str:
entry = "Rev: %s%s%s%s%s\n" % (
self.revision,
" (head)" if self.is_head else "",
@@ -825,12 +879,12 @@ class Script(revision.Revision):
def _head_only(
self,
- include_branches=False,
- include_doc=False,
- include_parents=False,
- tree_indicators=True,
- head_indicators=True,
- ):
+ include_branches: bool = False,
+ include_doc: bool = False,
+ include_parents: bool = False,
+ tree_indicators: bool = True,
+ head_indicators: bool = True,
+ ) -> str:
text = self.revision
if include_parents:
if self.dependencies:
@@ -841,6 +895,7 @@ class Script(revision.Revision):
)
else:
text = "%s -> %s" % (self._format_down_revision(), text)
+ assert text is not None
if include_branches and self.branch_labels:
text += " (%s)" % util.format_as_comma(self.branch_labels)
if head_indicators or tree_indicators:
@@ -862,12 +917,12 @@ class Script(revision.Revision):
def cmd_format(
self,
- verbose,
- include_branches=False,
- include_doc=False,
- include_parents=False,
- tree_indicators=True,
- ):
+ verbose: bool,
+ include_branches: bool = False,
+ include_doc: bool = False,
+ include_parents: bool = False,
+ tree_indicators: bool = True,
+ ) -> str:
if verbose:
return self.log_entry
else:
@@ -875,19 +930,21 @@ class Script(revision.Revision):
include_branches, include_doc, include_parents, tree_indicators
)
- def _format_down_revision(self):
+ def _format_down_revision(self) -> str:
if not self.down_revision:
return "<base>"
else:
return util.format_as_comma(self._versioned_down_revisions)
@classmethod
- def _from_path(cls, scriptdir, path):
+ def _from_path(
+ cls, scriptdir: ScriptDirectory, path: str
+ ) -> Optional["Script"]:
dir_, filename = os.path.split(path)
return cls._from_filename(scriptdir, dir_, filename)
@classmethod
- def _list_py_dir(cls, scriptdir, path):
+ def _list_py_dir(cls, scriptdir: ScriptDirectory, path: str) -> List[str]:
if scriptdir.sourceless:
# read files in version path, e.g. pyc or pyo files
# in the immediate path
@@ -910,7 +967,9 @@ class Script(revision.Revision):
return os.listdir(path)
@classmethod
- def _from_filename(cls, scriptdir, dir_, filename):
+ def _from_filename(
+ cls, scriptdir: ScriptDirectory, dir_: str, filename: str
+ ) -> Optional["Script"]:
if scriptdir.sourceless:
py_match = _sourceless_rev_file.match(filename)
else:
diff --git a/alembic/script/revision.py b/alembic/script/revision.py
index bdae805..eccb98e 100644
--- a/alembic/script/revision.py
+++ b/alembic/script/revision.py
@@ -1,11 +1,40 @@
import collections
import re
+from typing import Any
+from typing import Callable
+from typing import cast
+from typing import Collection
+from typing import Deque
+from typing import Dict
+from typing import FrozenSet
+from typing import Iterator
+from typing import List
+from typing import Optional
+from typing import Sequence
+from typing import Set
+from typing import Tuple
+from typing import TYPE_CHECKING
+from typing import TypeVar
+from typing import Union
from sqlalchemy import util as sqlautil
from .. import util
from ..util import compat
+if TYPE_CHECKING:
+ from typing import Literal
+
+ from .base import Script
+
+_RevIdType = Union[str, Sequence[str]]
+_RevisionIdentifierType = Union[str, Tuple[str, ...], None]
+_RevisionOrStr = Union["Revision", str]
+_RevisionOrBase = Union["Revision", "Literal['base']"]
+_InterimRevisionMapType = Dict[str, "Revision"]
+_RevisionMapType = Dict[Union[None, str, Tuple[()]], Optional["Revision"]]
+_T = TypeVar("_T", bound=Union[str, "Revision"])
+
_relative_destination = re.compile(r"(?:(.+?)@)?(\w+)?((?:\+|-)\d+)")
_revision_illegal_chars = ["@", "-", "+"]
@@ -15,7 +44,9 @@ class RevisionError(Exception):
class RangeNotAncestorError(RevisionError):
- def __init__(self, lower, upper):
+ def __init__(
+ self, lower: _RevisionIdentifierType, upper: _RevisionIdentifierType
+ ) -> None:
self.lower = lower
self.upper = upper
super(RangeNotAncestorError, self).__init__(
@@ -25,7 +56,7 @@ class RangeNotAncestorError(RevisionError):
class MultipleHeads(RevisionError):
- def __init__(self, heads, argument):
+ def __init__(self, heads: Sequence[str], argument: Optional[str]) -> None:
self.heads = heads
self.argument = argument
super(MultipleHeads, self).__init__(
@@ -35,7 +66,7 @@ class MultipleHeads(RevisionError):
class ResolutionError(RevisionError):
- def __init__(self, message, argument):
+ def __init__(self, message: str, argument: str) -> None:
super(ResolutionError, self).__init__(message)
self.argument = argument
@@ -43,7 +74,7 @@ class ResolutionError(RevisionError):
class CycleDetected(RevisionError):
kind = "Cycle"
- def __init__(self, revisions):
+ def __init__(self, revisions: Sequence[str]) -> None:
self.revisions = revisions
super(CycleDetected, self).__init__(
"%s is detected in revisions (%s)"
@@ -54,21 +85,21 @@ class CycleDetected(RevisionError):
class DependencyCycleDetected(CycleDetected):
kind = "Dependency cycle"
- def __init__(self, revisions):
+ def __init__(self, revisions: Sequence[str]) -> None:
super(DependencyCycleDetected, self).__init__(revisions)
class LoopDetected(CycleDetected):
kind = "Self-loop"
- def __init__(self, revision):
+ def __init__(self, revision: str) -> None:
super(LoopDetected, self).__init__([revision])
class DependencyLoopDetected(DependencyCycleDetected, LoopDetected):
kind = "Dependency self-loop"
- def __init__(self, revision):
+ def __init__(self, revision: Sequence[str]) -> None:
super(DependencyLoopDetected, self).__init__(revision)
@@ -81,7 +112,7 @@ class RevisionMap:
"""
- def __init__(self, generator):
+ def __init__(self, generator: Callable[[], Iterator["Revision"]]) -> None:
"""Construct a new :class:`.RevisionMap`.
:param generator: a zero-arg callable that will generate an iterable
@@ -92,7 +123,7 @@ class RevisionMap:
self._generator = generator
@util.memoized_property
- def heads(self):
+ def heads(self) -> Tuple[str, ...]:
"""All "head" revisions as strings.
This is normally a tuple of length one,
@@ -105,7 +136,7 @@ class RevisionMap:
return self.heads
@util.memoized_property
- def bases(self):
+ def bases(self) -> Tuple[str, ...]:
"""All "base" revisions as strings.
These are revisions that have a ``down_revision`` of None,
@@ -118,7 +149,7 @@ class RevisionMap:
return self.bases
@util.memoized_property
- def _real_heads(self):
+ def _real_heads(self) -> Tuple[str, ...]:
"""All "real" head revisions as strings.
:return: a tuple of string revision numbers.
@@ -128,7 +159,7 @@ class RevisionMap:
return self._real_heads
@util.memoized_property
- def _real_bases(self):
+ def _real_bases(self) -> Tuple[str, ...]:
"""All "real" base revisions as strings.
:return: a tuple of string revision numbers.
@@ -138,19 +169,19 @@ class RevisionMap:
return self._real_bases
@util.memoized_property
- def _revision_map(self):
+ def _revision_map(self) -> _RevisionMapType:
"""memoized attribute, initializes the revision map from the
initial collection.
"""
# Ordering required for some tests to pass (but not required in
# general)
- map_ = sqlautil.OrderedDict()
+ map_: _InterimRevisionMapType = sqlautil.OrderedDict()
- heads = sqlautil.OrderedSet()
- _real_heads = sqlautil.OrderedSet()
- bases = ()
- _real_bases = ()
+ heads: Set["Revision"] = sqlautil.OrderedSet()
+ _real_heads: Set["Revision"] = sqlautil.OrderedSet()
+ bases: Tuple["Revision", ...] = ()
+ _real_bases: Tuple["Revision", ...] = ()
has_branch_labels = set()
all_revisions = set()
@@ -176,11 +207,13 @@ class RevisionMap:
# add the branch_labels to the map_. We'll need these
# to resolve the dependencies.
rev_map = map_.copy()
- self._map_branch_labels(has_branch_labels, map_)
+ self._map_branch_labels(
+ has_branch_labels, cast(_RevisionMapType, map_)
+ )
# resolve dependency names from branch labels and symbolic
# names
- self._add_depends_on(all_revisions, map_)
+ self._add_depends_on(all_revisions, cast(_RevisionMapType, map_))
for rev in map_.values():
for downrev in rev._all_down_revisions:
@@ -198,32 +231,44 @@ class RevisionMap:
# once the map has downrevisions populated, the dependencies
# can be further refined to include only those which are not
# already ancestors
- self._normalize_depends_on(all_revisions, map_)
+ self._normalize_depends_on(all_revisions, cast(_RevisionMapType, map_))
self._detect_cycles(rev_map, heads, bases, _real_heads, _real_bases)
- map_[None] = map_[()] = None
+ revision_map: _RevisionMapType = dict(map_.items())
+ revision_map[None] = revision_map[()] = None
self.heads = tuple(rev.revision for rev in heads)
self._real_heads = tuple(rev.revision for rev in _real_heads)
self.bases = tuple(rev.revision for rev in bases)
self._real_bases = tuple(rev.revision for rev in _real_bases)
- self._add_branches(has_branch_labels, map_)
- return map_
+ self._add_branches(has_branch_labels, revision_map)
+ return revision_map
- def _detect_cycles(self, rev_map, heads, bases, _real_heads, _real_bases):
+ def _detect_cycles(
+ self,
+ rev_map: _InterimRevisionMapType,
+ heads: Set["Revision"],
+ bases: Tuple["Revision", ...],
+ _real_heads: Set["Revision"],
+ _real_bases: Tuple["Revision", ...],
+ ) -> None:
if not rev_map:
return
if not heads or not bases:
- raise CycleDetected(rev_map.keys())
+ raise CycleDetected(list(rev_map))
total_space = {
rev.revision
for rev in self._iterate_related_revisions(
- lambda r: r._versioned_down_revisions, heads, map_=rev_map
+ lambda r: r._versioned_down_revisions,
+ heads,
+ map_=cast(_RevisionMapType, rev_map),
)
}.intersection(
rev.revision
for rev in self._iterate_related_revisions(
- lambda r: r.nextrev, bases, map_=rev_map
+ lambda r: r.nextrev,
+ bases,
+ map_=cast(_RevisionMapType, rev_map),
)
)
deleted_revs = set(rev_map.keys()) - total_space
@@ -231,39 +276,50 @@ class RevisionMap:
raise CycleDetected(sorted(deleted_revs))
if not _real_heads or not _real_bases:
- raise DependencyCycleDetected(rev_map.keys())
+ raise DependencyCycleDetected(list(rev_map))
total_space = {
rev.revision
for rev in self._iterate_related_revisions(
- lambda r: r._all_down_revisions, _real_heads, map_=rev_map
+ lambda r: r._all_down_revisions,
+ _real_heads,
+ map_=cast(_RevisionMapType, rev_map),
)
}.intersection(
rev.revision
for rev in self._iterate_related_revisions(
- lambda r: r._all_nextrev, _real_bases, map_=rev_map
+ lambda r: r._all_nextrev,
+ _real_bases,
+ map_=cast(_RevisionMapType, rev_map),
)
)
deleted_revs = set(rev_map.keys()) - total_space
if deleted_revs:
raise DependencyCycleDetected(sorted(deleted_revs))
- def _map_branch_labels(self, revisions, map_):
+ def _map_branch_labels(
+ self, revisions: Collection["Revision"], map_: _RevisionMapType
+ ) -> None:
for revision in revisions:
if revision.branch_labels:
+ assert revision._orig_branch_labels is not None
for branch_label in revision._orig_branch_labels:
if branch_label in map_:
+ map_rev = map_[branch_label]
+ assert map_rev is not None
raise RevisionError(
"Branch name '%s' in revision %s already "
"used by revision %s"
% (
branch_label,
revision.revision,
- map_[branch_label].revision,
+ map_rev.revision,
)
)
map_[branch_label] = revision
- def _add_branches(self, revisions, map_):
+ def _add_branches(
+ self, revisions: Collection["Revision"], map_: _RevisionMapType
+ ) -> None:
for revision in revisions:
if revision.branch_labels:
revision.branch_labels.update(revision.branch_labels)
@@ -285,7 +341,9 @@ class RevisionMap:
else:
break
- def _add_depends_on(self, revisions, map_):
+ def _add_depends_on(
+ self, revisions: Collection["Revision"], map_: _RevisionMapType
+ ) -> None:
"""Resolve the 'dependencies' for each revision in a collection
in terms of actual revision ids, as opposed to branch labels or other
symbolic names.
@@ -301,12 +359,14 @@ class RevisionMap:
map_[dep] for dep in util.to_tuple(revision.dependencies)
]
revision._resolved_dependencies = tuple(
- [d.revision for d in deps]
+ [d.revision for d in deps if d is not None]
)
else:
revision._resolved_dependencies = ()
- def _normalize_depends_on(self, revisions, map_):
+ def _normalize_depends_on(
+ self, revisions: Collection["Revision"], map_: _RevisionMapType
+ ) -> None:
"""Create a collection of "dependencies" that omits dependencies
that are already ancestor nodes for each revision in a given
collection.
@@ -327,7 +387,9 @@ class RevisionMap:
if revision._resolved_dependencies:
normalized_resolved = set(revision._resolved_dependencies)
for rev in self._get_ancestor_nodes(
- [revision], include_dependencies=False, map_=map_
+ [revision],
+ include_dependencies=False,
+ map_=cast(_RevisionMapType, map_),
):
if rev is revision:
continue
@@ -342,7 +404,9 @@ class RevisionMap:
else:
revision._normalized_resolved_dependencies = ()
- def add_revision(self, revision, _replace=False):
+ def add_revision(
+ self, revision: "Revision", _replace: bool = False
+ ) -> None:
"""add a single revision to an existing map.
This method is for single-revision use cases, it's not
@@ -375,7 +439,7 @@ class RevisionMap:
"Revision %s referenced from %s is not present"
% (downrev, revision)
)
- map_[downrev].add_nextrev(revision)
+ cast("Revision", map_[downrev]).add_nextrev(revision)
self._normalize_depends_on(revisions, map_)
@@ -398,7 +462,9 @@ class RevisionMap:
)
) + (revision.revision,)
- def get_current_head(self, branch_label=None):
+ def get_current_head(
+ self, branch_label: Optional[str] = None
+ ) -> Optional[str]:
"""Return the current head revision.
If the script directory has multiple heads
@@ -416,7 +482,7 @@ class RevisionMap:
:meth:`.ScriptDirectory.get_heads`
"""
- current_heads = self.heads
+ current_heads: Sequence[str] = self.heads
if branch_label:
current_heads = self.filter_for_lineage(
current_heads, branch_label
@@ -432,10 +498,12 @@ class RevisionMap:
else:
return None
- def _get_base_revisions(self, identifier):
+ def _get_base_revisions(self, identifier: str) -> Tuple[str, ...]:
return self.filter_for_lineage(self.bases, identifier)
- def get_revisions(self, id_):
+ def get_revisions(
+ self, id_: Union[str, Collection[str], None]
+ ) -> Tuple["Revision", ...]:
"""Return the :class:`.Revision` instances with the given rev id
or identifiers.
@@ -456,7 +524,9 @@ class RevisionMap:
if isinstance(id_, (list, tuple, set, frozenset)):
return sum([self.get_revisions(id_elem) for id_elem in id_], ())
else:
- resolved_id, branch_label = self._resolve_revision_number(id_)
+ resolved_id, branch_label = self._resolve_revision_number(
+ id_ # type:ignore [arg-type]
+ )
if len(resolved_id) == 1:
try:
rint = int(resolved_id[0])
@@ -464,11 +534,11 @@ class RevisionMap:
# branch@-n -> walk down from heads
select_heads = self.get_revisions("heads")
if branch_label is not None:
- select_heads = [
+ select_heads = tuple(
head
for head in select_heads
if branch_label in head.branch_labels
- ]
+ )
return tuple(
self._walk(head, steps=rint)
for head in select_heads
@@ -481,7 +551,7 @@ class RevisionMap:
for rev_id in resolved_id
)
- def get_revision(self, id_):
+ def get_revision(self, id_: Optional[str]) -> "Revision":
"""Return the :class:`.Revision` instance with the given rev id.
If a symbolic name such as "head" or "base" is given, resolves
@@ -499,11 +569,11 @@ class RevisionMap:
if len(resolved_id) > 1:
raise MultipleHeads(resolved_id, id_)
elif resolved_id:
- resolved_id = resolved_id[0]
+ resolved_id = resolved_id[0] # type:ignore[assignment]
- return self._revision_for_ident(resolved_id, branch_label)
+ return self._revision_for_ident(cast(str, resolved_id), branch_label)
- def _resolve_branch(self, branch_label):
+ def _resolve_branch(self, branch_label: str) -> "Revision":
try:
branch_rev = self._revision_map[branch_label]
except KeyError:
@@ -517,19 +587,24 @@ class RevisionMap:
else:
return nonbranch_rev
else:
- return branch_rev
+ return cast("Revision", branch_rev)
- def _revision_for_ident(self, resolved_id, check_branch=None):
+ def _revision_for_ident(
+ self, resolved_id: str, check_branch: Optional[str] = None
+ ) -> "Revision":
+ branch_rev: Optional["Revision"]
if check_branch:
branch_rev = self._resolve_branch(check_branch)
else:
branch_rev = None
+ revision: Union["Revision", "Literal[False]"]
try:
- revision = self._revision_map[resolved_id]
+ revision = cast("Revision", self._revision_map[resolved_id])
except KeyError:
# break out to avoid misleading py3k stack traces
revision = False
+ revs: Sequence[str]
if revision is False:
# do a partial lookup
revs = [
@@ -562,9 +637,11 @@ class RevisionMap:
resolved_id,
)
else:
- revision = self._revision_map[revs[0]]
+ revision = cast("Revision", self._revision_map[revs[0]])
+ revision = cast("Revision", revision)
if check_branch and revision is not None:
+ assert branch_rev is not None
if not self._shares_lineage(
revision.revision, branch_rev.revision
):
@@ -575,7 +652,9 @@ class RevisionMap:
)
return revision
- def _filter_into_branch_heads(self, targets):
+ def _filter_into_branch_heads(
+ self, targets: Set["Script"]
+ ) -> Set["Script"]:
targets = set(targets)
for rev in list(targets):
@@ -586,8 +665,11 @@ class RevisionMap:
return targets
def filter_for_lineage(
- self, targets, check_against, include_dependencies=False
- ):
+ self,
+ targets: Sequence[_T],
+ check_against: Optional[str],
+ include_dependencies: bool = False,
+ ) -> Tuple[_T, ...]:
id_, branch_label = self._resolve_revision_number(check_against)
shares = []
@@ -596,17 +678,20 @@ class RevisionMap:
if id_:
shares.extend(id_)
- return [
+ return tuple(
tg
for tg in targets
if self._shares_lineage(
tg, shares, include_dependencies=include_dependencies
)
- ]
+ )
def _shares_lineage(
- self, target, test_against_revs, include_dependencies=False
- ):
+ self,
+ target: _RevisionOrStr,
+ test_against_revs: Sequence[_RevisionOrStr],
+ include_dependencies: bool = False,
+ ) -> bool:
if not test_against_revs:
return True
if not isinstance(target, Revision):
@@ -635,7 +720,10 @@ class RevisionMap:
.intersection(test_against_revs)
)
- def _resolve_revision_number(self, id_):
+ def _resolve_revision_number(
+ self, id_: Optional[str]
+ ) -> Tuple[Tuple[str, ...], Optional[str]]:
+ branch_label: Optional[str]
if isinstance(id_, compat.string_types) and "@" in id_:
branch_label, id_ = id_.split("@", 1)
@@ -678,13 +766,13 @@ class RevisionMap:
def iterate_revisions(
self,
- upper,
- lower,
- implicit_base=False,
- inclusive=False,
- assert_relative_length=True,
- select_for_downgrade=False,
- ):
+ upper: _RevisionIdentifierType,
+ lower: _RevisionIdentifierType,
+ implicit_base: bool = False,
+ inclusive: bool = False,
+ assert_relative_length: bool = True,
+ select_for_downgrade: bool = False,
+ ) -> Iterator["Revision"]:
"""Iterate through script revisions, starting at the given
upper revision identifier and ending at the lower.
@@ -696,6 +784,7 @@ class RevisionMap:
The iterator yields :class:`.Revision` objects.
"""
+ fn: Callable
if select_for_downgrade:
fn = self._collect_downgrade_revisions
else:
@@ -714,12 +803,12 @@ class RevisionMap:
def _get_descendant_nodes(
self,
- targets,
- map_=None,
- check=False,
- omit_immediate_dependencies=False,
- include_dependencies=True,
- ):
+ targets: Collection["Revision"],
+ map_: Optional[_RevisionMapType] = None,
+ check: bool = False,
+ omit_immediate_dependencies: bool = False,
+ include_dependencies: bool = True,
+ ) -> Iterator[Any]:
if omit_immediate_dependencies:
@@ -744,8 +833,12 @@ class RevisionMap:
)
def _get_ancestor_nodes(
- self, targets, map_=None, check=False, include_dependencies=True
- ):
+ self,
+ targets: Collection["Revision"],
+ map_: Optional[_RevisionMapType] = None,
+ check: bool = False,
+ include_dependencies: bool = True,
+ ) -> Iterator["Revision"]:
if include_dependencies:
@@ -761,12 +854,18 @@ class RevisionMap:
fn, targets, map_=map_, check=check
)
- def _iterate_related_revisions(self, fn, targets, map_, check=False):
+ def _iterate_related_revisions(
+ self,
+ fn: Callable,
+ targets: Collection["Revision"],
+ map_: Optional[_RevisionMapType],
+ check: bool = False,
+ ) -> Iterator["Revision"]:
if map_ is None:
map_ = self._revision_map
seen = set()
- todo = collections.deque()
+ todo: Deque["Revision"] = collections.deque()
for target in targets:
todo.append(target)
@@ -784,6 +883,7 @@ class RevisionMap:
# Check for map errors before collecting.
for rev_id in fn(rev):
next_rev = map_[rev_id]
+ assert next_rev is not None
if next_rev.revision != rev_id:
raise RevisionError(
"Dependency resolution failed; broken map"
@@ -804,7 +904,11 @@ class RevisionMap:
)
)
- def _topological_sort(self, revisions, heads):
+ def _topological_sort(
+ self,
+ revisions: Collection["Revision"],
+ heads: Any,
+ ) -> List[str]:
"""Yield revision ids of a collection of Revision objects in
topological sorted order (i.e. revisions always come after their
down_revisions and dependencies). Uses the order of keys in
@@ -860,6 +964,7 @@ class RevisionMap:
# now update the heads with our ancestors.
candidate_rev = id_to_rev[candidate]
+ assert candidate_rev is not None
heads_to_add = [
r
@@ -873,7 +978,6 @@ class RevisionMap:
del ancestors_by_idx[current_candidate_idx]
current_candidate_idx = max(current_candidate_idx - 1, 0)
else:
-
if (
not candidate_rev._normalized_resolved_dependencies
and len(candidate_rev._versioned_down_revisions) == 1
@@ -905,7 +1009,13 @@ class RevisionMap:
assert not todo
return output
- def _walk(self, start, steps, branch_label=None, no_overwalk=True):
+ def _walk(
+ self,
+ start: Optional[Union[str, "Revision"]],
+ steps: int,
+ branch_label: Optional[str] = None,
+ no_overwalk: bool = True,
+ ) -> "Revision":
"""
Walk the requested number of :steps up (steps > 0) or down (steps < 0)
the revision tree.
@@ -918,44 +1028,55 @@ class RevisionMap:
A RevisionError is raised if there is no unambiguous revision to
walk to.
"""
-
+ initial: Optional[_RevisionOrBase]
if isinstance(start, compat.string_types):
- start = self.get_revision(start)
+ initial = self.get_revision(start)
+ else:
+ initial = start
+ children: Sequence[_RevisionOrBase]
for _ in range(abs(steps)):
if steps > 0:
# Walk up
children = [
rev
for rev in self.get_revisions(
- self.bases if start is None else start.nextrev
+ self.bases
+ if initial is None
+ else cast("Revision", initial).nextrev
)
]
if branch_label:
children = self.filter_for_lineage(children, branch_label)
else:
# Walk down
- if start == "base":
- children = tuple()
+ if initial == "base":
+ children = ()
else:
children = self.get_revisions(
- self.heads if start is None else start.down_revision
+ self.heads
+ if initial is None
+ else initial.down_revision
)
if not children:
- children = ("base",)
+ children = cast("Tuple[Literal['base']]", ("base",))
if not children:
# This will return an invalid result if no_overwalk, otherwise
# further steps will stay where we are.
- return None if no_overwalk else start
+ ret = None if no_overwalk else initial
+ return ret # type:ignore[return-value]
elif len(children) > 1:
raise RevisionError("Ambiguous walk")
- start = children[0]
+ initial = children[0]
- return start
+ return cast("Revision", initial)
def _parse_downgrade_target(
- self, current_revisions, target, assert_relative_length
- ):
+ self,
+ current_revisions: _RevisionIdentifierType,
+ target: _RevisionIdentifierType,
+ assert_relative_length: bool,
+ ) -> Tuple[Optional[str], Optional[_RevisionOrBase]]:
"""
Parse downgrade command syntax :target to retrieve the target revision
and branch label (if any) given the :current_revisons stamp of the
@@ -999,11 +1120,11 @@ class RevisionMap:
if relative_revision:
# Find target revision relative to current state.
if branch_label:
- symbol = self.filter_for_lineage(
+ symbol_list = self.filter_for_lineage(
util.to_tuple(current_revisions), branch_label
)
- assert len(symbol) == 1
- symbol = symbol[0]
+ assert len(symbol_list) == 1
+ symbol = symbol_list[0]
else:
current_revisions = util.to_tuple(current_revisions)
if not current_revisions:
@@ -1045,12 +1166,15 @@ class RevisionMap:
# No relative destination given, revision specified is absolute.
branch_label, _, symbol = target.rpartition("@")
if not branch_label:
- branch_label = None
+ branch_label = None # type:ignore[assignment]
return branch_label, self.get_revision(symbol)
def _parse_upgrade_target(
- self, current_revisions, target, assert_relative_length
- ):
+ self,
+ current_revisions: _RevisionIdentifierType,
+ target: _RevisionIdentifierType,
+ assert_relative_length: bool,
+ ) -> Tuple["Revision", ...]:
"""
Parse upgrade command syntax :target to retrieve the target revision
and given the :current_revisons stamp of the database.
@@ -1070,9 +1194,8 @@ class RevisionMap:
current_revisions = util.to_tuple(current_revisions)
- branch_label, symbol, relative = match.groups()
- relative_str = relative
- relative = int(relative)
+ branch_label, symbol, relative_str = match.groups()
+ relative = int(relative_str)
if relative > 0:
if symbol is None:
if not current_revisions:
@@ -1151,8 +1274,13 @@ class RevisionMap:
)
def _collect_downgrade_revisions(
- self, upper, target, inclusive, implicit_base, assert_relative_length
- ):
+ self,
+ upper: _RevisionIdentifierType,
+ target: _RevisionIdentifierType,
+ inclusive: bool,
+ implicit_base: bool,
+ assert_relative_length: bool,
+ ) -> Any:
"""
Compute the set of current revisions specified by :upper, and the
downgrade target specified by :target. Return all dependents of target
@@ -1244,8 +1372,13 @@ class RevisionMap:
return downgrade_revisions, heads
def _collect_upgrade_revisions(
- self, upper, lower, inclusive, implicit_base, assert_relative_length
- ):
+ self,
+ upper: _RevisionIdentifierType,
+ lower: _RevisionIdentifierType,
+ inclusive: bool,
+ implicit_base: bool,
+ assert_relative_length: bool,
+ ) -> Tuple[Set["Revision"], Tuple[Optional[_RevisionOrBase]]]:
"""
Compute the set of required revisions specified by :upper, and the
current set of active revisions specified by :lower. Find the
@@ -1257,14 +1390,13 @@ class RevisionMap:
of the current/lower revisions. Dependencies from branches with
different bases will not be included.
"""
- targets = self._parse_upgrade_target(
+ targets: Collection["Revision"] = self._parse_upgrade_target(
current_revisions=lower,
target=upper,
assert_relative_length=assert_relative_length,
)
- assert targets is not None
- assert type(targets) is tuple, "targets should be a tuple"
+ # assert type(targets) is tuple, "targets should be a tuple"
# Handled named bases (e.g. branch@... -> heads should only produce
# targets on the given branch)
@@ -1332,7 +1464,7 @@ class RevisionMap:
)
needs.intersection_update(lower_descendents)
- return needs, targets
+ return needs, tuple(targets) # type:ignore[return-value]
class Revision:
@@ -1346,15 +1478,15 @@ class Revision:
"""
- nextrev = frozenset()
+ nextrev: FrozenSet[str] = frozenset()
"""following revisions, based on down_revision only."""
- _all_nextrev = frozenset()
+ _all_nextrev: FrozenSet[str] = frozenset()
- revision = None
+ revision: str = None # type: ignore[assignment]
"""The string revision number."""
- down_revision = None
+ down_revision: Optional[_RevIdType] = None
"""The ``down_revision`` identifier(s) within the migration script.
Note that the total set of "down" revisions is
@@ -1362,7 +1494,7 @@ class Revision:
"""
- dependencies = None
+ dependencies: Optional[_RevIdType] = None
"""Additional revisions which this revision is dependent on.
From a migration standpoint, these dependencies are added to the
@@ -1372,12 +1504,15 @@ class Revision:
"""
- branch_labels = None
+ branch_labels: Set[str] = None # type: ignore[assignment]
"""Optional string/tuple of symbolic names to apply to this
revision's branch"""
+ _resolved_dependencies: Tuple[str, ...]
+ _normalized_resolved_dependencies: Tuple[str, ...]
+
@classmethod
- def verify_rev_id(cls, revision):
+ def verify_rev_id(cls, revision: str) -> None:
illegal_chars = set(revision).intersection(_revision_illegal_chars)
if illegal_chars:
raise RevisionError(
@@ -1386,8 +1521,12 @@ class Revision:
)
def __init__(
- self, revision, down_revision, dependencies=None, branch_labels=None
- ):
+ self,
+ revision: str,
+ down_revision: Optional[Union[str, Tuple[str, ...]]],
+ dependencies: Optional[Tuple[str, ...]] = None,
+ branch_labels: Optional[Tuple[str, ...]] = None,
+ ) -> None:
if down_revision and revision in util.to_tuple(down_revision):
raise LoopDetected(revision)
elif dependencies is not None and revision in util.to_tuple(
@@ -1402,7 +1541,7 @@ class Revision:
self._orig_branch_labels = util.to_tuple(branch_labels, default=())
self.branch_labels = set(self._orig_branch_labels)
- def __repr__(self):
+ def __repr__(self) -> str:
args = [repr(self.revision), repr(self.down_revision)]
if self.dependencies:
args.append("dependencies=%r" % (self.dependencies,))
@@ -1410,20 +1549,20 @@ class Revision:
args.append("branch_labels=%r" % (self.branch_labels,))
return "%s(%s)" % (self.__class__.__name__, ", ".join(args))
- def add_nextrev(self, revision):
+ def add_nextrev(self, revision: "Revision") -> None:
self._all_nextrev = self._all_nextrev.union([revision.revision])
if self.revision in revision._versioned_down_revisions:
self.nextrev = self.nextrev.union([revision.revision])
@property
- def _all_down_revisions(self):
+ def _all_down_revisions(self) -> Tuple[str, ...]:
return util.dedupe_tuple(
util.to_tuple(self.down_revision, default=())
+ self._resolved_dependencies
)
@property
- def _normalized_down_revisions(self):
+ def _normalized_down_revisions(self) -> Tuple[str, ...]:
"""return immediate down revisions for a rev, omitting dependencies
that are still dependencies of ancestors.
@@ -1434,11 +1573,11 @@ class Revision:
)
@property
- def _versioned_down_revisions(self):
+ def _versioned_down_revisions(self) -> Tuple[str, ...]:
return util.to_tuple(self.down_revision, default=())
@property
- def is_head(self):
+ def is_head(self) -> bool:
"""Return True if this :class:`.Revision` is a 'head' revision.
This is determined based on whether any other :class:`.Script`
@@ -1449,17 +1588,17 @@ class Revision:
return not bool(self.nextrev)
@property
- def _is_real_head(self):
+ def _is_real_head(self) -> bool:
return not bool(self._all_nextrev)
@property
- def is_base(self):
+ def is_base(self) -> bool:
"""Return True if this :class:`.Revision` is a 'base' revision."""
return self.down_revision is None
@property
- def _is_real_base(self):
+ def _is_real_base(self) -> bool:
"""Return True if this :class:`.Revision` is a "real" base revision,
e.g. that it has no dependencies either."""
@@ -1469,7 +1608,7 @@ class Revision:
return self.down_revision is None and self.dependencies is None
@property
- def is_branch_point(self):
+ def is_branch_point(self) -> bool:
"""Return True if this :class:`.Script` is a branch point.
A branchpoint is defined as a :class:`.Script` which is referred
@@ -1481,7 +1620,7 @@ class Revision:
return len(self.nextrev) > 1
@property
- def _is_real_branch_point(self):
+ def _is_real_branch_point(self) -> bool:
"""Return True if this :class:`.Script` is a 'real' branch point,
taking into account dependencies as well.
@@ -1489,13 +1628,15 @@ class Revision:
return len(self._all_nextrev) > 1
@property
- def is_merge_point(self):
+ def is_merge_point(self) -> bool:
"""Return True if this :class:`.Script` is a merge point."""
return len(self._versioned_down_revisions) > 1
-def tuple_rev_as_scalar(rev):
+def tuple_rev_as_scalar(
+ rev: Optional[Sequence[str]],
+) -> Optional[Union[str, Sequence[str]]]:
if not rev:
return None
elif len(rev) == 1:
diff --git a/alembic/script/write_hooks.py b/alembic/script/write_hooks.py
index 8cd3dcc..8f9e35e 100644
--- a/alembic/script/write_hooks.py
+++ b/alembic/script/write_hooks.py
@@ -1,6 +1,11 @@
import shlex
import subprocess
import sys
+from typing import Any
+from typing import Callable
+from typing import Dict
+from typing import List
+from typing import Union
from .. import util
from ..util import compat
@@ -11,7 +16,7 @@ REVISION_SCRIPT_TOKEN = "REVISION_SCRIPT_FILENAME"
_registry = {}
-def register(name):
+def register(name: str) -> Callable:
"""A function decorator that will register that function as a write hook.
See the documentation linked below for an example.
@@ -31,7 +36,9 @@ def register(name):
return decorate
-def _invoke(name, revision, options):
+def _invoke(
+ name: str, revision: str, options: Dict[str, Union[str, int]]
+) -> Any:
"""Invokes the formatter registered for the given name.
:param name: The name of a formatter in the registry
@@ -50,7 +57,7 @@ def _invoke(name, revision, options):
return hook(revision, options)
-def _run_hooks(path, hook_config):
+def _run_hooks(path: str, hook_config: Dict[str, str]) -> None:
"""Invoke hooks for a generated revision."""
from .base import _split_on_space_comma
@@ -83,7 +90,7 @@ def _run_hooks(path, hook_config):
)
-def _parse_cmdline_options(cmdline_options_str, path):
+def _parse_cmdline_options(cmdline_options_str: str, path: str) -> List[str]:
"""Parse options from a string into a list.
Also substitutes the revision script token with the actual filename of
diff --git a/alembic/testing/assertions.py b/alembic/testing/assertions.py
index e22ac6b..ed53206 100644
--- a/alembic/testing/assertions.py
+++ b/alembic/testing/assertions.py
@@ -1,8 +1,8 @@
-from __future__ import absolute_import
-
import contextlib
import re
import sys
+from typing import Any
+from typing import Dict
from sqlalchemy import exc as sa_exc
from sqlalchemy import util
@@ -114,7 +114,7 @@ def eq_ignore_whitespace(a, b, msg=None):
assert a == b, msg or "%r != %r" % (a, b)
-_dialect_mods = {}
+_dialect_mods: Dict[Any, Any] = {}
def _get_dialect(name):
diff --git a/alembic/testing/fixtures.py b/alembic/testing/fixtures.py
index cccc382..c273665 100644
--- a/alembic/testing/fixtures.py
+++ b/alembic/testing/fixtures.py
@@ -3,6 +3,8 @@ import configparser
from contextlib import contextmanager
import io
import re
+from typing import Any
+from typing import Dict
from sqlalchemy import Column
from sqlalchemy import inspect
@@ -61,7 +63,7 @@ if sqla_14:
from sqlalchemy.testing.fixtures import FutureEngineMixin
else:
- class FutureEngineMixin:
+ class FutureEngineMixin: # type:ignore[no-redef]
__requires__ = ("sqlalchemy_14",)
@@ -78,7 +80,7 @@ def capture_db(dialect="postgresql://"):
return engine, buf
-_engs = {}
+_engs: Dict[Any, Any] = {}
@contextmanager
diff --git a/alembic/testing/requirements.py b/alembic/testing/requirements.py
index f1792f8..37780ab 100644
--- a/alembic/testing/requirements.py
+++ b/alembic/testing/requirements.py
@@ -1,5 +1,3 @@
-import sys
-
from sqlalchemy.testing.requirements import Requirements
from alembic import util
@@ -86,12 +84,6 @@ class SuiteRequirements(Requirements):
)
@property
- def python3(self):
- return exclusions.skip_if(
- lambda: sys.version_info < (3,), "Python version 3.xx is required."
- )
-
- @property
def comments(self):
return exclusions.only_if(
lambda config: config.db.dialect.supports_comments
diff --git a/alembic/testing/suite/_autogen_fixtures.py b/alembic/testing/suite/_autogen_fixtures.py
index 44fc24f..ea1957a 100644
--- a/alembic/testing/suite/_autogen_fixtures.py
+++ b/alembic/testing/suite/_autogen_fixtures.py
@@ -1,3 +1,6 @@
+from typing import Any
+from typing import Dict
+
from sqlalchemy import CHAR
from sqlalchemy import CheckConstraint
from sqlalchemy import Column
@@ -211,7 +214,7 @@ class AutogenTest(_ComparesFKs):
def _get_bind(cls):
return config.db
- configure_opts = {}
+ configure_opts: Dict[Any, Any] = {}
@classmethod
def setup_class(cls):
diff --git a/alembic/util/compat.py b/alembic/util/compat.py
index 0fdd86d..a07813c 100644
--- a/alembic/util/compat.py
+++ b/alembic/util/compat.py
@@ -2,6 +2,12 @@ import collections
import inspect
import io
import os
+from typing import Any
+from typing import Callable
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Type
is_posix = os.name == "posix"
@@ -10,11 +16,11 @@ ArgSpec = collections.namedtuple(
)
-def inspect_getargspec(func):
+def inspect_getargspec(func: Callable) -> ArgSpec:
"""getargspec based on fully vendored getfullargspec from Python 3.3."""
if inspect.ismethod(func):
- func = func.__func__
+ func = func.__func__ # type: ignore
if not inspect.isfunction(func):
raise TypeError("{!r} is not a Python function".format(func))
@@ -36,7 +42,7 @@ def inspect_getargspec(func):
if co.co_flags & inspect.CO_VARKEYWORDS:
varkw = co.co_varnames[nargs]
- return ArgSpec(args, varargs, varkw, func.__defaults__)
+ return ArgSpec(args, varargs, varkw, func.__defaults__) # type: ignore
string_types = (str,)
@@ -57,20 +63,20 @@ def _formatannotation(annotation, base_module=None):
def inspect_formatargspec(
- args,
- varargs=None,
- varkw=None,
- defaults=None,
- kwonlyargs=(),
- kwonlydefaults={},
- annotations={},
- formatarg=str,
- formatvarargs=lambda name: "*" + name,
- formatvarkw=lambda name: "**" + name,
- formatvalue=lambda value: "=" + repr(value),
- formatreturns=lambda text: " -> " + text,
- formatannotation=_formatannotation,
-):
+ args: List[str],
+ varargs: Optional[str] = None,
+ varkw: Optional[str] = None,
+ defaults: Optional[Any] = None,
+ kwonlyargs: tuple = (),
+ kwonlydefaults: Dict[Any, Any] = {},
+ annotations: Dict[Any, Any] = {},
+ formatarg: Type[str] = str,
+ formatvarargs: Callable = lambda name: "*" + name,
+ formatvarkw: Callable = lambda name: "**" + name,
+ formatvalue: Callable = lambda value: "=" + repr(value),
+ formatreturns: Callable = lambda text: " -> " + text,
+ formatannotation: Callable = _formatannotation,
+) -> str:
"""Copy formatargspec from python 3.7 standard library.
Python 3 has deprecated formatargspec and requested that Signature
@@ -118,5 +124,5 @@ def inspect_formatargspec(
# into a given buffer, but doesn't close it.
# not sure of a more idiomatic approach to this.
class EncodedIO(io.TextIOWrapper):
- def close(self):
+ def close(self) -> None:
pass
diff --git a/alembic/util/editor.py b/alembic/util/editor.py
index c27f0f3..ba376c0 100644
--- a/alembic/util/editor.py
+++ b/alembic/util/editor.py
@@ -3,12 +3,18 @@ from os.path import exists
from os.path import join
from os.path import splitext
from subprocess import check_call
+from typing import Dict
+from typing import List
+from typing import Mapping
+from typing import Optional
from .compat import is_posix
from .exc import CommandError
-def open_in_editor(filename, environ=None):
+def open_in_editor(
+ filename: str, environ: Optional[Dict[str, str]] = None
+) -> None:
"""
Opens the given file in a text editor. If the environment variable
``EDITOR`` is set, this is taken as preference.
@@ -22,15 +28,15 @@ def open_in_editor(filename, environ=None):
:param environ: An optional drop-in replacement for ``os.environ``. Used
mainly for testing.
"""
-
+ env = os.environ if environ is None else environ
try:
- editor = _find_editor(environ)
+ editor = _find_editor(env)
check_call([editor, filename])
except Exception as exc:
raise CommandError("Error executing editor (%s)" % (exc,)) from exc
-def _find_editor(environ=None):
+def _find_editor(environ: Mapping[str, str]) -> str:
candidates = _default_editors()
for i, var in enumerate(("EDITOR", "VISUAL")):
if var in environ:
@@ -50,7 +56,9 @@ def _find_editor(environ=None):
)
-def _find_executable(candidate, environ):
+def _find_executable(
+ candidate: str, environ: Mapping[str, str]
+) -> Optional[str]:
# Assuming this is on the PATH, we need to determine it's absolute
# location. Otherwise, ``check_call`` will fail
if not is_posix and splitext(candidate)[1] != ".exe":
@@ -62,7 +70,7 @@ def _find_executable(candidate, environ):
return None
-def _default_editors():
+def _default_editors() -> List[str]:
# Look for an editor. Prefer the user's choice by env-var, fall back to
# most commonly installed editor (nano/vim)
if is_posix:
diff --git a/alembic/util/langhelpers.py b/alembic/util/langhelpers.py
index dbd1f21..87a9aca 100644
--- a/alembic/util/langhelpers.py
+++ b/alembic/util/langhelpers.py
@@ -1,6 +1,16 @@
import collections
from collections.abc import Iterable
import textwrap
+from typing import Any
+from typing import Callable
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import overload
+from typing import Sequence
+from typing import Tuple
+from typing import TypeVar
+from typing import Union
import uuid
import warnings
@@ -14,10 +24,13 @@ from .compat import inspect_getargspec
from .compat import string_types
+_T = TypeVar("_T")
+
+
class _ModuleClsMeta(type):
- def __setattr__(cls, key, value):
+ def __setattr__(cls, key: str, value: Callable) -> None:
super(_ModuleClsMeta, cls).__setattr__(key, value)
- cls._update_module_proxies(key)
+ cls._update_module_proxies(key) # type: ignore
class ModuleClsProxy(metaclass=_ModuleClsMeta):
@@ -29,22 +42,24 @@ class ModuleClsProxy(metaclass=_ModuleClsMeta):
"""
- _setups = collections.defaultdict(lambda: (set(), []))
+ _setups: Dict[type, Tuple[set, list]] = collections.defaultdict(
+ lambda: (set(), [])
+ )
@classmethod
- def _update_module_proxies(cls, name):
+ def _update_module_proxies(cls, name: str) -> None:
attr_names, modules = cls._setups[cls]
for globals_, locals_ in modules:
cls._add_proxied_attribute(name, globals_, locals_, attr_names)
- def _install_proxy(self):
+ def _install_proxy(self) -> None:
attr_names, modules = self._setups[self.__class__]
for globals_, locals_ in modules:
globals_["_proxy"] = self
for attr_name in attr_names:
globals_[attr_name] = getattr(self, attr_name)
- def _remove_proxy(self):
+ def _remove_proxy(self) -> None:
attr_names, modules = self._setups[self.__class__]
for globals_, locals_ in modules:
globals_["_proxy"] = None
@@ -171,10 +186,25 @@ def _with_legacy_names(translations):
return decorate
-def rev_id():
+def rev_id() -> str:
return uuid.uuid4().hex[-12:]
+@overload
+def to_tuple(x: Any, default: tuple) -> tuple:
+ ...
+
+
+@overload
+def to_tuple(x: None, default: _T = None) -> _T:
+ ...
+
+
+@overload
+def to_tuple(x: Any, default: Optional[tuple] = None) -> tuple:
+ ...
+
+
def to_tuple(x, default=None):
if x is None:
return default
@@ -186,16 +216,18 @@ def to_tuple(x, default=None):
return (x,)
-def dedupe_tuple(tup):
+def dedupe_tuple(tup: Tuple[str, ...]) -> Tuple[str, ...]:
return tuple(unique_list(tup))
class Dispatcher:
- def __init__(self, uselist=False):
- self._registry = {}
+ def __init__(self, uselist: bool = False) -> None:
+ self._registry: Dict[tuple, Any] = {}
self.uselist = uselist
- def dispatch_for(self, target, qualifier="default"):
+ def dispatch_for(
+ self, target: Any, qualifier: str = "default"
+ ) -> Callable:
def decorate(fn):
if self.uselist:
self._registry.setdefault((target, qualifier), []).append(fn)
@@ -206,10 +238,10 @@ class Dispatcher:
return decorate
- def dispatch(self, obj, qualifier="default"):
+ def dispatch(self, obj: Any, qualifier: str = "default") -> Any:
if isinstance(obj, string_types):
- targets = [obj]
+ targets: Sequence = [obj]
elif isinstance(obj, type):
targets = obj.__mro__
else:
@@ -223,7 +255,9 @@ class Dispatcher:
else:
raise ValueError("no dispatch function for object: %s" % obj)
- def _fn_or_list(self, fn_or_list):
+ def _fn_or_list(
+ self, fn_or_list: Union[List[Callable], Callable]
+ ) -> Callable:
if self.uselist:
def go(*arg, **kw):
@@ -232,9 +266,9 @@ class Dispatcher:
return go
else:
- return fn_or_list
+ return fn_or_list # type: ignore
- def branch(self):
+ def branch(self) -> "Dispatcher":
"""Return a copy of this dispatcher that is independently
writable."""
diff --git a/alembic/util/messaging.py b/alembic/util/messaging.py
index 70c9128..062890a 100644
--- a/alembic/util/messaging.py
+++ b/alembic/util/messaging.py
@@ -2,6 +2,11 @@ from collections.abc import Iterable
import logging
import sys
import textwrap
+from typing import Any
+from typing import Callable
+from typing import Optional
+from typing import TextIO
+from typing import Union
import warnings
from sqlalchemy.engine import url
@@ -29,7 +34,7 @@ except (ImportError, IOError):
TERMWIDTH = None
-def write_outstream(stream, *text):
+def write_outstream(stream: TextIO, *text) -> None:
encoding = getattr(stream, "encoding", "ascii") or "ascii"
for t in text:
if not isinstance(t, binary_type):
@@ -44,7 +49,7 @@ def write_outstream(stream, *text):
break
-def status(_statmsg, fn, *arg, **kw):
+def status(_statmsg: str, fn: Callable, *arg, **kw) -> Any:
newline = kw.pop("newline", False)
msg(_statmsg + " ...", newline, True)
try:
@@ -56,27 +61,27 @@ def status(_statmsg, fn, *arg, **kw):
raise
-def err(message):
+def err(message: str):
log.error(message)
msg("FAILED: %s" % message)
sys.exit(-1)
-def obfuscate_url_pw(u):
- u = url.make_url(u)
+def obfuscate_url_pw(input_url: str) -> str:
+ u = url.make_url(input_url)
if u.password:
if sqla_compat.sqla_14:
u = u.set(password="XXXXX")
else:
- u.password = "XXXXX"
+ u.password = "XXXXX" # type: ignore[misc]
return str(u)
-def warn(msg, stacklevel=2):
+def warn(msg: str, stacklevel: int = 2) -> None:
warnings.warn(msg, UserWarning, stacklevel=stacklevel)
-def msg(msg, newline=True, flush=False):
+def msg(msg: str, newline: bool = True, flush: bool = False) -> None:
if TERMWIDTH is None:
write_outstream(sys.stdout, msg)
if newline:
@@ -92,7 +97,7 @@ def msg(msg, newline=True, flush=False):
sys.stdout.flush()
-def format_as_comma(value):
+def format_as_comma(value: Optional[Union[str, "Iterable[str]"]]) -> str:
if value is None:
return ""
elif isinstance(value, string_types):
diff --git a/alembic/util/pyfiles.py b/alembic/util/pyfiles.py
index 53cc3cc..7eb582e 100644
--- a/alembic/util/pyfiles.py
+++ b/alembic/util/pyfiles.py
@@ -4,6 +4,7 @@ import importlib.util
import os
import re
import tempfile
+from typing import Optional
from mako import exceptions
from mako.template import Template
@@ -11,7 +12,9 @@ from mako.template import Template
from .exc import CommandError
-def template_to_file(template_file, dest, output_encoding, **kw):
+def template_to_file(
+ template_file: str, dest: str, output_encoding: str, **kw
+) -> None:
template = Template(filename=template_file)
try:
output = template.render_unicode(**kw).encode(output_encoding)
@@ -32,7 +35,7 @@ def template_to_file(template_file, dest, output_encoding, **kw):
f.write(output)
-def coerce_resource_to_filename(fname):
+def coerce_resource_to_filename(fname: str) -> str:
"""Interpret a filename as either a filesystem location or as a package
resource.
@@ -47,7 +50,7 @@ def coerce_resource_to_filename(fname):
return fname
-def pyc_file_from_path(path):
+def pyc_file_from_path(path: str) -> Optional[str]:
"""Given a python source path, locate the .pyc."""
candidate = importlib.util.cache_from_source(path)
@@ -64,7 +67,7 @@ def pyc_file_from_path(path):
return None
-def load_python_file(dir_, filename):
+def load_python_file(dir_: str, filename: str):
"""Load a file from the given path as a Python module."""
module_id = re.sub(r"\W", "_", filename)
@@ -78,21 +81,15 @@ def load_python_file(dir_, filename):
if pyc_path is None:
raise ImportError("Can't find Python file %s" % path)
else:
- module = load_module_pyc(module_id, pyc_path)
+ module = load_module_py(module_id, pyc_path)
elif ext in (".pyc", ".pyo"):
- module = load_module_pyc(module_id, path)
+ module = load_module_py(module_id, path)
return module
-def load_module_py(module_id, path):
+def load_module_py(module_id: str, path: str):
spec = importlib.util.spec_from_file_location(module_id, path)
+ assert spec
module = importlib.util.module_from_spec(spec)
- spec.loader.exec_module(module)
- return module
-
-
-def load_module_pyc(module_id, path):
- spec = importlib.util.spec_from_file_location(module_id, path)
- module = importlib.util.module_from_spec(spec)
- spec.loader.exec_module(module)
+ spec.loader.exec_module(module) # type: ignore
return module
diff --git a/alembic/util/sqla_compat.py b/alembic/util/sqla_compat.py
index a04ab2e..e1ccd41 100644
--- a/alembic/util/sqla_compat.py
+++ b/alembic/util/sqla_compat.py
@@ -1,5 +1,11 @@
import contextlib
import re
+from typing import Iterator
+from typing import Mapping
+from typing import Optional
+from typing import TYPE_CHECKING
+from typing import TypeVar
+from typing import Union
from sqlalchemy import __version__
from sqlalchemy import inspect
@@ -12,15 +18,34 @@ from sqlalchemy.schema import CheckConstraint
from sqlalchemy.schema import Column
from sqlalchemy.schema import ForeignKeyConstraint
from sqlalchemy.sql import visitors
+from sqlalchemy.sql.elements import BindParameter
from sqlalchemy.sql.elements import quoted_name
-from sqlalchemy.sql.expression import _BindParamClause
-from sqlalchemy.sql.expression import _TextClause as TextClause
+from sqlalchemy.sql.elements import TextClause
from sqlalchemy.sql.visitors import traverse
from . import compat
-
-def _safe_int(value):
+if TYPE_CHECKING:
+ from sqlalchemy import Index
+ from sqlalchemy import Table
+ from sqlalchemy.engine import Connection
+ from sqlalchemy.engine import Dialect
+ from sqlalchemy.engine import Transaction
+ from sqlalchemy.engine.reflection import Inspector
+ from sqlalchemy.sql.base import ColumnCollection
+ from sqlalchemy.sql.compiler import SQLCompiler
+ from sqlalchemy.sql.dml import Insert
+ from sqlalchemy.sql.elements import ColumnClause
+ from sqlalchemy.sql.elements import ColumnElement
+ from sqlalchemy.sql.schema import Constraint
+ from sqlalchemy.sql.schema import SchemaItem
+ from sqlalchemy.sql.selectable import Select
+ from sqlalchemy.sql.selectable import TableClause
+
+_CE = TypeVar("_CE", bound=Union["ColumnElement", "SchemaItem"])
+
+
+def _safe_int(value: str) -> Union[int, str]:
try:
return int(value)
except:
@@ -36,6 +61,7 @@ sqla_14 = _vers >= (1, 4)
try:
from sqlalchemy import Computed # noqa
except ImportError:
+ Computed = None # type: ignore
has_computed = False
has_computed_reflection = False
else:
@@ -45,6 +71,7 @@ else:
try:
from sqlalchemy import Identity # noqa
except ImportError:
+ Identity = None # type: ignore
has_identity = False
else:
# attributes common to Indentity and Sequence
@@ -67,21 +94,26 @@ AUTOINCREMENT_DEFAULT = "auto"
@contextlib.contextmanager
-def _ensure_scope_for_ddl(connection):
+def _ensure_scope_for_ddl(
+ connection: Optional["Connection"],
+) -> Iterator[None]:
try:
- in_transaction = connection.in_transaction
+ in_transaction = connection.in_transaction # type: ignore[union-attr]
except AttributeError:
- # catch for MockConnection
+ # catch for MockConnection, None
yield
else:
if not in_transaction():
+ assert connection is not None
with connection.begin():
yield
else:
yield
-def _safe_begin_connection_transaction(connection):
+def _safe_begin_connection_transaction(
+ connection: "Connection",
+) -> "Transaction":
transaction = _get_connection_transaction(connection)
if transaction:
return transaction
@@ -89,9 +121,9 @@ def _safe_begin_connection_transaction(connection):
return connection.begin()
-def _get_connection_in_transaction(connection):
+def _get_connection_in_transaction(connection: Optional["Connection"]) -> bool:
try:
- in_transaction = connection.in_transaction
+ in_transaction = connection.in_transaction # type: ignore
except AttributeError:
# catch for MockConnection
return False
@@ -99,28 +131,33 @@ def _get_connection_in_transaction(connection):
return in_transaction()
-def _copy(schema_item, **kw):
+def _copy(schema_item: _CE, **kw) -> _CE:
if hasattr(schema_item, "_copy"):
return schema_item._copy(**kw)
else:
return schema_item.copy(**kw)
-def _get_connection_transaction(connection):
+def _get_connection_transaction(
+ connection: "Connection",
+) -> Optional["Transaction"]:
if sqla_14:
return connection.get_transaction()
else:
- return connection._root._Connection__transaction
+ r = connection._root # type: ignore[attr-defined]
+ return r._Connection__transaction
-def _create_url(*arg, **kw):
+def _create_url(*arg, **kw) -> url.URL:
if hasattr(url.URL, "create"):
return url.URL.create(*arg, **kw)
else:
return url.URL(*arg, **kw)
-def _connectable_has_table(connectable, tablename, schemaname):
+def _connectable_has_table(
+ connectable: "Connection", tablename: str, schemaname: Union[str, None]
+) -> bool:
if sqla_14:
return inspect(connectable).has_table(tablename, schemaname)
else:
@@ -148,23 +185,25 @@ def _nullability_might_be_unset(metadata_column):
)
-def _server_default_is_computed(*server_default):
+def _server_default_is_computed(*server_default) -> bool:
if not has_computed:
return False
else:
return any(isinstance(sd, Computed) for sd in server_default)
-def _server_default_is_identity(*server_default):
+def _server_default_is_identity(*server_default) -> bool:
if not sqla_14:
return False
else:
return any(isinstance(sd, Identity) for sd in server_default)
-def _table_for_constraint(constraint):
+def _table_for_constraint(constraint: "Constraint") -> "Table":
if isinstance(constraint, ForeignKeyConstraint):
- return constraint.parent
+ table = constraint.parent
+ assert table is not None
+ return table
else:
return constraint.table
@@ -178,7 +217,9 @@ def _columns_for_constraint(constraint):
return list(constraint.columns)
-def _reflect_table(inspector, table, include_cols):
+def _reflect_table(
+ inspector: "Inspector", table: "Table", include_cols: None
+) -> None:
if sqla_14:
return inspector.reflect_table(table, None)
else:
@@ -213,19 +254,20 @@ def _fk_spec(constraint):
)
-def _fk_is_self_referential(constraint):
- spec = constraint.elements[0]._get_colspec()
+def _fk_is_self_referential(constraint: "ForeignKeyConstraint") -> bool:
+ spec = constraint.elements[0]._get_colspec() # type: ignore[attr-defined]
tokens = spec.split(".")
tokens.pop(-1) # colname
tablekey = ".".join(tokens)
+ assert constraint.parent is not None
return tablekey == constraint.parent.key
-def _is_type_bound(constraint):
+def _is_type_bound(constraint: "Constraint") -> bool:
# this deals with SQLAlchemy #3260, don't copy CHECK constraints
# that will be generated by the type.
# new feature added for #3260
- return constraint._type_bound
+ return constraint._type_bound # type: ignore[attr-defined]
def _find_columns(clause):
@@ -236,16 +278,21 @@ def _find_columns(clause):
return cols
-def _remove_column_from_collection(collection, column):
+def _remove_column_from_collection(
+ collection: "ColumnCollection", column: Union["Column", "ColumnClause"]
+) -> None:
"""remove a column from a ColumnCollection."""
# workaround for older SQLAlchemy, remove the
# same object that's present
+ assert column.key is not None
to_remove = collection[column.key]
collection.remove(to_remove)
-def _textual_index_column(table, text_):
+def _textual_index_column(
+ table: "Table", text_: Union[str, "TextClause", "ColumnElement"]
+) -> Union["ColumnElement", "Column"]:
"""a workaround for the Index construct's severe lack of flexibility"""
if isinstance(text_, compat.string_types):
c = Column(text_, sqltypes.NULLTYPE)
@@ -259,7 +306,7 @@ def _textual_index_column(table, text_):
raise ValueError("String or text() construct expected")
-def _copy_expression(expression, target_table):
+def _copy_expression(expression: _CE, target_table: "Table") -> _CE:
def replace(col):
if (
isinstance(col, Column)
@@ -296,7 +343,7 @@ class _textual_index_element(sql.ColumnElement):
__visit_name__ = "_textual_idx_element"
- def __init__(self, table, text):
+ def __init__(self, table: "Table", text: "TextClause") -> None:
self.table = table
self.text = text
self.key = text.text
@@ -308,16 +355,20 @@ class _textual_index_element(sql.ColumnElement):
@compiles(_textual_index_element)
-def _render_textual_index_column(element, compiler, **kw):
+def _render_textual_index_column(
+ element: _textual_index_element, compiler: "SQLCompiler", **kw
+) -> str:
return compiler.process(element.text, **kw)
-class _literal_bindparam(_BindParamClause):
+class _literal_bindparam(BindParameter):
pass
@compiles(_literal_bindparam)
-def _render_literal_bindparam(element, compiler, **kw):
+def _render_literal_bindparam(
+ element: _literal_bindparam, compiler: "SQLCompiler", **kw
+) -> str:
return compiler.render_literal_bindparam(element, **kw)
@@ -329,17 +380,20 @@ def _get_index_column_names(idx):
return [getattr(exp, "name", None) for exp in _get_index_expressions(idx)]
-def _column_kwargs(col):
+def _column_kwargs(col: "Column") -> Mapping:
if sqla_13:
return col.kwargs
else:
return {}
-def _get_constraint_final_name(constraint, dialect):
+def _get_constraint_final_name(
+ constraint: Union["Index", "Constraint"], dialect: Optional["Dialect"]
+) -> Optional[str]:
if constraint.name is None:
return None
- elif sqla_14:
+ assert dialect is not None
+ if sqla_14:
# for SQLAlchemy 1.4 we would like to have the option to expand
# the use of "deferred" names for constraints as well as to have
# some flexibility with "None" name and similar; make use of new
@@ -355,7 +409,7 @@ def _get_constraint_final_name(constraint, dialect):
if hasattr(constraint.name, "quote"):
# might be quoted_name, might be truncated_name, keep it the
# same
- quoted_name_cls = type(constraint.name)
+ quoted_name_cls: type = type(constraint.name)
else:
quoted_name_cls = quoted_name
@@ -364,7 +418,8 @@ def _get_constraint_final_name(constraint, dialect):
if isinstance(constraint, schema.Index):
# name should not be quoted.
- return dialect.ddl_compiler(dialect, None)._prepared_index_name(
+ d = dialect.ddl_compiler(dialect, None)
+ return d._prepared_index_name( # type: ignore[attr-defined]
constraint
)
else:
@@ -372,10 +427,13 @@ def _get_constraint_final_name(constraint, dialect):
return dialect.identifier_preparer.format_constraint(constraint)
-def _constraint_is_named(constraint, dialect):
+def _constraint_is_named(
+ constraint: Union["Constraint", "Index"], dialect: Optional["Dialect"]
+) -> bool:
if sqla_14:
if constraint.name is None:
return False
+ assert dialect is not None
name = dialect.identifier_preparer.format_constraint(
constraint, _alembic_quote=False
)
@@ -384,18 +442,21 @@ def _constraint_is_named(constraint, dialect):
return constraint.name is not None
-def _is_mariadb(mysql_dialect):
+def _is_mariadb(mysql_dialect: "Dialect") -> bool:
if sqla_14:
- return mysql_dialect.is_mariadb
+ return mysql_dialect.is_mariadb # type: ignore[attr-defined]
else:
- return mysql_dialect.server_version_info and mysql_dialect._is_mariadb
+ return bool(
+ mysql_dialect.server_version_info
+ and mysql_dialect._is_mariadb # type: ignore[attr-defined]
+ )
def _mariadb_normalized_version_info(mysql_dialect):
return mysql_dialect._mariadb_normalized_version_info
-def _insert_inline(table):
+def _insert_inline(table: Union["TableClause", "Table"]) -> "Insert":
if sqla_14:
return table.insert().inline()
else:
@@ -408,10 +469,10 @@ if sqla_14:
else:
from sqlalchemy import create_engine
- def create_mock_engine(url, executor):
+ def create_mock_engine(url, executor, **kw): # type: ignore[misc]
return create_engine(
"postgresql://", strategy="mock", executor=executor
)
- def _select(*columns):
- return sql.select(list(columns))
+ def _select(*columns, **kw) -> "Select":
+ return sql.select(list(columns), **kw)
diff --git a/docs/build/unreleased/py3_typing.rst b/docs/build/unreleased/py3_typing.rst
new file mode 100644
index 0000000..7f8aa6c
--- /dev/null
+++ b/docs/build/unreleased/py3_typing.rst
@@ -0,0 +1,8 @@
+.. change::
+ :tags: feature, general
+
+ pep-484 type annotations have been added throughout the library. This
+ should be helpful in providing Mypy and IDE support, however there is not
+ full support for Alembic's dynamically modified "op" namespace as of yet; a
+ future release will likely modify the approach used for importing this
+ namespace to be better compatible with pep-484 capabilities. \ No newline at end of file
diff --git a/setup.cfg b/setup.cfg
index 7514d8b..025a93a 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -91,6 +91,7 @@ import-order-style = google
application-import-names = alembic,tests
per-file-ignores =
**/__init__.py:F401
+max-line-length = 79
[sqla_testing]
requirement_cls=tests.requirements:DefaultRequirements
@@ -115,4 +116,12 @@ oracle8=oracle://scott:tiger@127.0.0.1:1521/?use_ansi=0
addopts= --tb native -v -r sfxX -p no:warnings -p no:logging --maxfail=25
python_files=tests/test_*.py
+[mypy]
+show_error_codes = True
+allow_redefinition = True
+[mypy-mako.*]
+ignore_missing_imports = True
+
+[mypy-sqlalchemy.testing.*]
+ignore_missing_imports = True
diff --git a/tests/test_revision.py b/tests/test_revision.py
index c2c1410..61998bf 100644
--- a/tests/test_revision.py
+++ b/tests/test_revision.py
@@ -9,14 +9,13 @@ from alembic.script.revision import Revision
from alembic.script.revision import RevisionError
from alembic.script.revision import RevisionMap
from alembic.testing import assert_raises_message
-from alembic.testing import config
from alembic.testing import eq_
+from alembic.testing import expect_raises_message
from alembic.testing.fixtures import TestBase
from . import _large_map
class APITest(TestBase):
- @config.requirements.python3
def test_invalid_datatype(self):
map_ = RevisionMap(
lambda: [
@@ -25,29 +24,26 @@ class APITest(TestBase):
Revision("c", ("b",)),
]
)
- assert_raises_message(
+ with expect_raises_message(
RevisionError,
"revision identifier b'12345' is not a string; "
"ensure database driver settings are correct",
- map_.get_revisions,
- b"12345",
- )
+ ):
+ map_.get_revisions(b"12345")
- assert_raises_message(
+ with expect_raises_message(
RevisionError,
"revision identifier b'12345' is not a string; "
"ensure database driver settings are correct",
- map_.get_revision,
- b"12345",
- )
+ ):
+ map_.get_revision(b"12345")
- assert_raises_message(
+ with expect_raises_message(
RevisionError,
r"revision identifier \(b'12345',\) is not a string; "
"ensure database driver settings are correct",
- map_.get_revision,
- (b"12345",),
- )
+ ):
+ map_.get_revision((b"12345",))
map_.get_revision(("a",))
map_.get_revision("a")
@@ -310,12 +306,12 @@ class LabeledBranchTest(DownIterateTest):
c1 = map_.get_revision("c1")
c2 = map_.get_revision("c2")
d = map_.get_revision("d")
- eq_(map_.filter_for_lineage([c1, c2, d], "c1branch@head"), [c1, c2, d])
+ eq_(map_.filter_for_lineage([c1, c2, d], "c1branch@head"), (c1, c2, d))
def test_filter_for_lineage_heads(self):
eq_(
self.map.filter_for_lineage([self.map.get_revision("f")], "heads"),
- [self.map.get_revision("f")],
+ (self.map.get_revision("f"),),
)
def setUp(self):
@@ -333,13 +329,13 @@ class LabeledBranchTest(DownIterateTest):
)
def test_get_base_revisions_labeled(self):
- eq_(self.map._get_base_revisions("somelongername@base"), ["a"])
+ eq_(self.map._get_base_revisions("somelongername@base"), ("a",))
def test_get_current_named_rev(self):
eq_(self.map.get_revision("ebranch@head"), self.map.get_revision("f"))
def test_get_base_revisions(self):
- eq_(self.map._get_base_revisions("base"), ["a", "d"])
+ eq_(self.map._get_base_revisions("base"), ("a", "d"))
def test_iterate_head_to_named_base(self):
self._assert_iteration(
diff --git a/tox.ini b/tox.ini
index f5456d8..20e55ab 100644
--- a/tox.ini
+++ b/tox.ini
@@ -55,6 +55,19 @@ commands=
{oracle,mssql}: python reap_dbs.py db_idents.txt
+[testenv:mypy]
+basepython = python3
+deps=
+ mypy
+ sqlalchemy>=1.4.0
+ sqlalchemy2-stubs
+ mako
+ types-pkg-resources
+ types-python-dateutil
+ # is imported in alembic/testing and mypy complains if it's installed.
+ pytest
+commands = mypy ./alembic/ --exclude alembic/templates
+
[testenv:pep8]
basepython = python3
deps=