summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--alembic/autogenerate/api.py6
-rw-r--r--alembic/autogenerate/compare.py7
-rw-r--r--alembic/autogenerate/render.py2
-rw-r--r--alembic/context.pyi15
-rw-r--r--alembic/ddl/base.py2
-rw-r--r--alembic/ddl/impl.py45
-rw-r--r--alembic/ddl/mssql.py7
-rw-r--r--alembic/ddl/oracle.py6
-rw-r--r--alembic/op.pyi11
-rw-r--r--alembic/operations/base.py6
-rw-r--r--alembic/operations/batch.py10
-rw-r--r--alembic/operations/ops.py22
-rw-r--r--alembic/operations/toimpl.py2
-rw-r--r--alembic/runtime/environment.py20
-rw-r--r--alembic/runtime/migration.py18
-rw-r--r--alembic/util/sqla_compat.py18
-rw-r--r--pyproject.toml11
-rw-r--r--tools/write_pyi.py5
18 files changed, 133 insertions, 80 deletions
diff --git a/alembic/autogenerate/api.py b/alembic/autogenerate/api.py
index 5f1c7f3..cbd64e1 100644
--- a/alembic/autogenerate/api.py
+++ b/alembic/autogenerate/api.py
@@ -523,12 +523,12 @@ class RevisionContext:
def run_autogenerate(
self, rev: tuple, migration_context: "MigrationContext"
- ):
+ ) -> None:
self._run_environment(rev, migration_context, True)
def run_no_autogenerate(
self, rev: tuple, migration_context: "MigrationContext"
- ):
+ ) -> None:
self._run_environment(rev, migration_context, False)
def _run_environment(
@@ -536,7 +536,7 @@ class RevisionContext:
rev: tuple,
migration_context: "MigrationContext",
autogenerate: bool,
- ):
+ ) -> None:
if autogenerate:
if self.command_args["sql"]:
raise util.CommandError(
diff --git a/alembic/autogenerate/compare.py b/alembic/autogenerate/compare.py
index 5b69815..c32ab4d 100644
--- a/alembic/autogenerate/compare.py
+++ b/alembic/autogenerate/compare.py
@@ -616,8 +616,8 @@ def _compare_indexes_and_uniques(
# we know are either added implicitly by the DB or that the DB
# can't accurately report on
autogen_context.migration_context.impl.correct_for_autogen_constraints(
- conn_uniques,
- conn_indexes,
+ conn_uniques, # type: ignore[arg-type]
+ conn_indexes, # type: ignore[arg-type]
metadata_unique_constraints,
metadata_indexes,
)
@@ -1274,7 +1274,8 @@ def _compare_foreign_keys(
)
conn_fks = set(
- _make_foreign_key(const, conn_table) for const in conn_fks_list
+ _make_foreign_key(const, conn_table) # type: ignore[arg-type]
+ for const in conn_fks_list
)
# give the dialect a chance to correct the FKs to match more
diff --git a/alembic/autogenerate/render.py b/alembic/autogenerate/render.py
index 9c992b4..1ac6753 100644
--- a/alembic/autogenerate/render.py
+++ b/alembic/autogenerate/render.py
@@ -989,7 +989,7 @@ def _fk_colspec(
if table_fullname in namespace_metadata.tables:
col = namespace_metadata.tables[table_fullname].c.get(colname)
if col is not None:
- colname = _ident(col.name)
+ colname = _ident(col.name) # type: ignore[assignment]
colspec = "%s.%s" % (table_fullname, colname)
diff --git a/alembic/context.pyi b/alembic/context.pyi
index 14e1b5f..a2e5399 100644
--- a/alembic/context.pyi
+++ b/alembic/context.pyi
@@ -5,6 +5,8 @@ from __future__ import annotations
from typing import Any
from typing import Callable
from typing import ContextManager
+from typing import Dict
+from typing import List
from typing import Optional
from typing import TextIO
from typing import Tuple
@@ -13,6 +15,7 @@ from typing import Union
if TYPE_CHECKING:
from sqlalchemy.engine.base import Connection
+ from sqlalchemy.sql.elements import ClauseElement
from sqlalchemy.sql.schema import MetaData
from .config import Config
@@ -530,7 +533,9 @@ def configure(
"""
-def execute(sql, execution_options=None):
+def execute(
+ sql: Union[ClauseElement, str], execution_options: Optional[dict] = None
+) -> None:
"""Execute the given SQL using the current change context.
The behavior of :meth:`.execute` is the same
@@ -543,7 +548,7 @@ def execute(sql, execution_options=None):
"""
-def get_bind():
+def get_bind() -> Connection:
"""Return the current 'bind'.
In "online" mode, this is the
@@ -635,7 +640,9 @@ def get_tag_argument() -> Optional[str]:
"""
-def get_x_argument(as_dictionary: bool = False):
+def get_x_argument(
+ as_dictionary: bool = False,
+) -> Union[List[str], Dict[str, str]]:
"""Return the value(s) passed for the ``-x`` argument, if any.
The ``-x`` argument is an open ended flag that allows any user-defined
@@ -723,7 +730,7 @@ def run_migrations(**kw: Any) -> None:
script: ScriptDirectory
-def static_output(text):
+def static_output(text: str) -> None:
"""Emit text directly to the "offline" SQL stream.
Typically this is for emitting comments that
diff --git a/alembic/ddl/base.py b/alembic/ddl/base.py
index 7b0f63e..c910786 100644
--- a/alembic/ddl/base.py
+++ b/alembic/ddl/base.py
@@ -294,7 +294,7 @@ def format_table_name(
def format_column_name(
compiler: "DDLCompiler", name: Optional[Union["quoted_name", str]]
) -> Union["quoted_name", str]:
- return compiler.preparer.quote(name)
+ return compiler.preparer.quote(name) # type: ignore[arg-type]
def format_server_default(
diff --git a/alembic/ddl/impl.py b/alembic/ddl/impl.py
index 070c124..79d5245 100644
--- a/alembic/ddl/impl.py
+++ b/alembic/ddl/impl.py
@@ -23,19 +23,16 @@ from .. import util
from ..util import sqla_compat
if TYPE_CHECKING:
- from io import StringIO
from typing import Literal
+ from typing import TextIO
from sqlalchemy.engine import Connection
from sqlalchemy.engine import Dialect
from sqlalchemy.engine.cursor import CursorResult
- from sqlalchemy.engine.cursor import LegacyCursorResult
from sqlalchemy.engine.reflection import Inspector
- from sqlalchemy.sql.dml import Update
from sqlalchemy.sql.elements import ClauseElement
from sqlalchemy.sql.elements import ColumnElement
from sqlalchemy.sql.elements import quoted_name
- from sqlalchemy.sql.elements import TextClause
from sqlalchemy.sql.schema import Column
from sqlalchemy.sql.schema import Constraint
from sqlalchemy.sql.schema import ForeignKeyConstraint
@@ -60,11 +57,11 @@ class ImplMeta(type):
):
newtype = type.__init__(cls, classname, bases, dict_)
if "__dialect__" in dict_:
- _impls[dict_["__dialect__"]] = cls
+ _impls[dict_["__dialect__"]] = cls # type: ignore[assignment]
return newtype
-_impls: dict = {}
+_impls: Dict[str, Type["DefaultImpl"]] = {}
Params = namedtuple("Params", ["token0", "tokens", "args", "kwargs"])
@@ -98,7 +95,7 @@ class DefaultImpl(metaclass=ImplMeta):
connection: Optional["Connection"],
as_sql: bool,
transactional_ddl: Optional[bool],
- output_buffer: Optional["StringIO"],
+ output_buffer: Optional["TextIO"],
context_opts: Dict[str, Any],
) -> None:
self.dialect = dialect
@@ -119,7 +116,7 @@ class DefaultImpl(metaclass=ImplMeta):
)
@classmethod
- def get_by_dialect(cls, dialect: "Dialect") -> Any:
+ def get_by_dialect(cls, dialect: "Dialect") -> Type["DefaultImpl"]:
return _impls[dialect.name]
def static_output(self, text: str) -> None:
@@ -158,10 +155,10 @@ class DefaultImpl(metaclass=ImplMeta):
def _exec(
self,
construct: Union["ClauseElement", str],
- execution_options: None = None,
+ execution_options: Optional[dict] = None,
multiparams: Sequence[dict] = (),
params: Dict[str, int] = util.immutabledict(),
- ) -> Optional[Union["LegacyCursorResult", "CursorResult"]]:
+ ) -> Optional["CursorResult"]:
if isinstance(construct, str):
construct = text(construct)
if self.as_sql:
@@ -176,10 +173,11 @@ class DefaultImpl(metaclass=ImplMeta):
else:
compile_kw = {}
+ compiled = construct.compile(
+ dialect=self.dialect, **compile_kw # type: ignore[arg-type]
+ )
self.static_output(
- str(construct.compile(dialect=self.dialect, **compile_kw))
- .replace("\t", " ")
- .strip()
+ str(compiled).replace("\t", " ").strip()
+ self.command_terminator
)
return None
@@ -192,11 +190,13 @@ class DefaultImpl(metaclass=ImplMeta):
assert isinstance(multiparams, tuple)
multiparams += (params,)
- return conn.execute(construct, multiparams)
+ return conn.execute( # type: ignore[call-overload]
+ construct, multiparams
+ )
def execute(
self,
- sql: Union["Update", "TextClause", str],
+ sql: Union["ClauseElement", str],
execution_options: None = None,
) -> None:
self._exec(sql, execution_options)
@@ -424,9 +424,6 @@ class DefaultImpl(metaclass=ImplMeta):
)
)
else:
- # work around http://www.sqlalchemy.org/trac/ticket/2461
- if not hasattr(table, "_autoincrement_column"):
- table._autoincrement_column = None
if rows:
if multiinsert:
self._exec(
@@ -572,7 +569,7 @@ class DefaultImpl(metaclass=ImplMeta):
)
def render_ddl_sql_expr(
- self, expr: "ClauseElement", is_server_default: bool = False, **kw
+ self, expr: "ClauseElement", is_server_default: bool = False, **kw: Any
) -> str:
"""Render a SQL expression that is typically a server default,
index expression, etc.
@@ -581,10 +578,14 @@ class DefaultImpl(metaclass=ImplMeta):
"""
- compile_kw = dict(
- compile_kwargs={"literal_binds": True, "include_table": False}
+ compile_kw = {
+ "compile_kwargs": {"literal_binds": True, "include_table": False}
+ }
+ return str(
+ expr.compile(
+ dialect=self.dialect, **compile_kw # type: ignore[arg-type]
+ )
)
- return str(expr.compile(dialect=self.dialect, **compile_kw))
def _compat_autogen_column_reflect(
self, inspector: "Inspector"
diff --git a/alembic/ddl/mssql.py b/alembic/ddl/mssql.py
index b48f8ba..28f0678 100644
--- a/alembic/ddl/mssql.py
+++ b/alembic/ddl/mssql.py
@@ -35,7 +35,6 @@ if TYPE_CHECKING:
from sqlalchemy.dialects.mssql.base import MSDDLCompiler
from sqlalchemy.dialects.mssql.base import MSSQLCompiler
from sqlalchemy.engine.cursor import CursorResult
- from sqlalchemy.engine.cursor import LegacyCursorResult
from sqlalchemy.sql.schema import Index
from sqlalchemy.sql.schema import Table
from sqlalchemy.sql.selectable import TableClause
@@ -68,9 +67,7 @@ class MSSQLImpl(DefaultImpl):
"mssql_batch_separator", self.batch_separator
)
- def _exec(
- self, construct: Any, *args, **kw
- ) -> Optional[Union["LegacyCursorResult", "CursorResult"]]:
+ def _exec(self, construct: Any, *args, **kw) -> Optional["CursorResult"]:
result = super(MSSQLImpl, self)._exec(construct, *args, **kw)
if self.as_sql and self.batch_separator:
self.static_output(self.batch_separator)
@@ -359,7 +356,7 @@ def visit_column_nullable(
return "%s %s %s %s" % (
alter_table(compiler, element.table_name, element.schema),
alter_column(compiler, element.column_name),
- format_type(compiler, element.existing_type),
+ format_type(compiler, element.existing_type), # type: ignore[arg-type]
"NULL" if element.nullable else "NOT NULL",
)
diff --git a/alembic/ddl/oracle.py b/alembic/ddl/oracle.py
index 0e787fb..accd1fc 100644
--- a/alembic/ddl/oracle.py
+++ b/alembic/ddl/oracle.py
@@ -3,7 +3,6 @@ from __future__ import annotations
from typing import Any
from typing import Optional
from typing import TYPE_CHECKING
-from typing import Union
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql import sqltypes
@@ -26,7 +25,6 @@ from .impl import DefaultImpl
if TYPE_CHECKING:
from sqlalchemy.dialects.oracle.base import OracleDDLCompiler
from sqlalchemy.engine.cursor import CursorResult
- from sqlalchemy.engine.cursor import LegacyCursorResult
from sqlalchemy.sql.schema import Column
@@ -48,9 +46,7 @@ class OracleImpl(DefaultImpl):
"oracle_batch_separator", self.batch_separator
)
- def _exec(
- self, construct: Any, *args, **kw
- ) -> Optional[Union["LegacyCursorResult", "CursorResult"]]:
+ def _exec(self, construct: Any, *args, **kw) -> Optional["CursorResult"]:
result = super(OracleImpl, self)._exec(construct, *args, **kw)
if self.as_sql and self.batch_separator:
self.static_output(self.batch_separator)
diff --git a/alembic/op.pyi b/alembic/op.pyi
index 59dfc58..490d714 100644
--- a/alembic/op.pyi
+++ b/alembic/op.pyi
@@ -28,6 +28,7 @@ if TYPE_CHECKING:
from sqlalchemy.sql.schema import Column
from sqlalchemy.sql.schema import Computed
from sqlalchemy.sql.schema import Identity
+ from sqlalchemy.sql.schema import SchemaItem
from sqlalchemy.sql.schema import Table
from sqlalchemy.sql.type_api import TypeEngine
from sqlalchemy.util import immutabledict
@@ -94,7 +95,7 @@ def alter_column(
table_name: str,
column_name: str,
nullable: Optional[bool] = None,
- comment: Union[str, bool, None] = False,
+ comment: Union[str, Literal[False], None] = False,
server_default: Any = False,
new_column_name: Optional[str] = None,
type_: Union[TypeEngine, Type[TypeEngine], None] = None,
@@ -202,13 +203,13 @@ def batch_alter_table(
schema: Optional[str] = None,
recreate: Literal["auto", "always", "never"] = "auto",
partial_reordering: Optional[tuple] = None,
- copy_from: Optional["Table"] = None,
+ copy_from: Optional[Table] = None,
table_args: Tuple[Any, ...] = (),
table_kwargs: Mapping[str, Any] = immutabledict({}),
reflect_args: Tuple[Any, ...] = (),
reflect_kwargs: Mapping[str, Any] = immutabledict({}),
naming_convention: Optional[Dict[str, str]] = None,
-) -> Iterator["BatchOperations"]:
+) -> Iterator[BatchOperations]:
"""Invoke a series of per-table migrations in batch.
Batch mode allows a series of operations specific to a table
@@ -667,7 +668,9 @@ def create_primary_key(
"""
-def create_table(table_name: str, *columns, **kw: Any) -> Optional[Table]:
+def create_table(
+ table_name: str, *columns: SchemaItem, **kw: Any
+) -> Optional[Table]:
"""Issue a "create table" instruction using the current migration
context.
diff --git a/alembic/operations/base.py b/alembic/operations/base.py
index 9ecf3d4..535dff0 100644
--- a/alembic/operations/base.py
+++ b/alembic/operations/base.py
@@ -37,6 +37,7 @@ if TYPE_CHECKING:
from .batch import BatchOperationsImpl
from .ops import MigrateOperation
+ from ..ddl import DefaultImpl
from ..runtime.migration import MigrationContext
from ..util.sqla_compat import _literal_bindparam
@@ -74,6 +75,7 @@ class Operations(util.ModuleClsProxy):
"""
+ impl: Union["DefaultImpl", "BatchOperationsImpl"]
_to_impl = util.Dispatcher()
def __init__(
@@ -492,7 +494,7 @@ class Operations(util.ModuleClsProxy):
In a SQL script context, this value is ``None``. [TODO: verify this]
"""
- return self.migration_context.impl.bind
+ return self.migration_context.impl.bind # type: ignore[return-value]
class BatchOperations(Operations):
@@ -512,6 +514,8 @@ class BatchOperations(Operations):
"""
+ impl: "BatchOperationsImpl"
+
def _noop(self, operation):
raise NotImplementedError(
"The %s method does not apply to a batch table alter operation."
diff --git a/alembic/operations/batch.py b/alembic/operations/batch.py
index 71d2681..f1459e2 100644
--- a/alembic/operations/batch.py
+++ b/alembic/operations/batch.py
@@ -236,7 +236,7 @@ class ApplyBatchImpl:
self._grab_table_elements()
@classmethod
- def _calc_temp_name(cls, tablename: "quoted_name") -> str:
+ def _calc_temp_name(cls, tablename: Union["quoted_name", str]) -> str:
return ("_alembic_tmp_%s" % tablename)[0:50]
def _grab_table_elements(self) -> None:
@@ -280,7 +280,7 @@ class ApplyBatchImpl:
self.col_named_constraints[const.name] = (col, const)
for idx in self.table.indexes:
- self.indexes[idx.name] = idx
+ self.indexes[idx.name] = idx # type: ignore[index]
for k in self.table.kwargs:
self.table_kwargs.setdefault(k, self.table.kwargs[k])
@@ -546,7 +546,7 @@ class ApplyBatchImpl:
existing.server_default = None
else:
sql_schema.DefaultClause(
- server_default
+ server_default # type: ignore[arg-type]
)._set_parent( # type:ignore[attr-defined]
existing
)
@@ -699,11 +699,11 @@ class ApplyBatchImpl:
self.columns[col.name].primary_key = False
def create_index(self, idx: "Index") -> None:
- self.new_indexes[idx.name] = idx
+ self.new_indexes[idx.name] = idx # type: ignore[index]
def drop_index(self, idx: "Index") -> None:
try:
- del self.indexes[idx.name]
+ del self.indexes[idx.name] # type: ignore[arg-type]
except KeyError:
raise ValueError("No such index: '%s'" % idx.name)
diff --git a/alembic/operations/ops.py b/alembic/operations/ops.py
index 997274d..85ffe14 100644
--- a/alembic/operations/ops.py
+++ b/alembic/operations/ops.py
@@ -26,6 +26,8 @@ from .. import util
from ..util import sqla_compat
if TYPE_CHECKING:
+ from typing import Literal
+
from sqlalchemy.sql.dml import Insert
from sqlalchemy.sql.dml import Update
from sqlalchemy.sql.elements import BinaryExpression
@@ -885,7 +887,7 @@ class CreateIndexOp(MigrateOperation):
def from_index(cls, index: "Index") -> "CreateIndexOp":
assert index.table is not None
return cls(
- index.name,
+ index.name, # type: ignore[arg-type]
index.table.name,
sqla_compat._get_index_expressions(index),
schema=index.table.schema,
@@ -1021,7 +1023,7 @@ class DropIndexOp(MigrateOperation):
def from_index(cls, index: "Index") -> "DropIndexOp":
assert index.table is not None
return cls(
- index.name,
+ index.name, # type: ignore[arg-type]
index.table.name,
schema=index.table.schema,
_reverse=CreateIndexOp.from_index(index),
@@ -1105,7 +1107,7 @@ class CreateTableOp(MigrateOperation):
def __init__(
self,
table_name: str,
- columns: Sequence[Union["Column", "Constraint"]],
+ columns: Sequence["SchemaItem"],
schema: Optional[str] = None,
_namespace_metadata: Optional["MetaData"] = None,
_constraints_included: bool = False,
@@ -1172,8 +1174,12 @@ class CreateTableOp(MigrateOperation):
@classmethod
def create_table(
- cls, operations: "Operations", table_name: str, *columns, **kw: Any
- ) -> Optional["Table"]:
+ cls,
+ operations: "Operations",
+ table_name: str,
+ *columns: "SchemaItem",
+ **kw: Any,
+ ) -> "Optional[Table]":
r"""Issue a "create table" instruction using the current migration
context.
@@ -1603,7 +1609,7 @@ class AlterColumnOp(AlterTableOp):
existing_nullable: Optional[bool] = None,
existing_comment: Optional[str] = None,
modify_nullable: Optional[bool] = None,
- modify_comment: Optional[Union[str, bool]] = False,
+ modify_comment: Optional[Union[str, "Literal[False]"]] = False,
modify_server_default: Any = False,
modify_name: Optional[str] = None,
modify_type: Optional[Any] = None,
@@ -1757,7 +1763,7 @@ class AlterColumnOp(AlterTableOp):
table_name: str,
column_name: str,
nullable: Optional[bool] = None,
- comment: Optional[Union[str, bool]] = False,
+ comment: Optional[Union[str, "Literal[False]"]] = False,
server_default: Any = False,
new_column_name: Optional[str] = None,
type_: Optional[Union["TypeEngine", Type["TypeEngine"]]] = None,
@@ -1885,7 +1891,7 @@ class AlterColumnOp(AlterTableOp):
operations: BatchOperations,
column_name: str,
nullable: Optional[bool] = None,
- comment: bool = False,
+ comment: Union[str, "Literal[False]"] = False,
server_default: Union["Function", bool] = False,
new_column_name: Optional[str] = None,
type_: Optional[Union["TypeEngine", Type["TypeEngine"]]] = None,
diff --git a/alembic/operations/toimpl.py b/alembic/operations/toimpl.py
index f97983e..add142d 100644
--- a/alembic/operations/toimpl.py
+++ b/alembic/operations/toimpl.py
@@ -195,7 +195,7 @@ def drop_constraint(
def bulk_insert(
operations: "Operations", operation: "ops.BulkInsertOp"
) -> None:
- operations.impl.bulk_insert(
+ operations.impl.bulk_insert( # type: ignore[union-attr]
operation.table, operation.rows, multiinsert=operation.multiinsert
)
diff --git a/alembic/runtime/environment.py b/alembic/runtime/environment.py
index b95e0b5..3cec5b1 100644
--- a/alembic/runtime/environment.py
+++ b/alembic/runtime/environment.py
@@ -20,10 +20,12 @@ if TYPE_CHECKING:
from typing import Literal
from sqlalchemy.engine.base import Connection
+ from sqlalchemy.sql.elements import ClauseElement
from sqlalchemy.sql.schema import MetaData
from .migration import _ProxyTransaction
from ..config import Config
+ from ..ddl import DefaultImpl
from ..script.base import ScriptDirectory
_RevNumber = Optional[Union[str, Tuple[str, ...]]]
@@ -273,7 +275,9 @@ class EnvironmentContext(util.ModuleClsProxy):
) -> Dict[str, str]:
...
- def get_x_argument(self, as_dictionary: bool = False):
+ def get_x_argument(
+ self, as_dictionary: bool = False
+ ) -> Union[List[str], Dict[str, str]]:
"""Return the value(s) passed for the ``-x`` argument, if any.
The ``-x`` argument is an open ended flag that allows any user-defined
@@ -853,7 +857,11 @@ class EnvironmentContext(util.ModuleClsProxy):
with Operations.context(self._migration_context):
self.get_context().run_migrations(**kw)
- def execute(self, sql, execution_options=None):
+ def execute(
+ self,
+ sql: Union["ClauseElement", str],
+ execution_options: Optional[dict] = None,
+ ) -> None:
"""Execute the given SQL using the current change context.
The behavior of :meth:`.execute` is the same
@@ -867,7 +875,7 @@ class EnvironmentContext(util.ModuleClsProxy):
"""
self.get_context().execute(sql, execution_options=execution_options)
- def static_output(self, text):
+ def static_output(self, text: str) -> None:
"""Emit text directly to the "offline" SQL stream.
Typically this is for emitting comments that
@@ -938,7 +946,7 @@ class EnvironmentContext(util.ModuleClsProxy):
raise Exception("No context has been configured yet.")
return self._migration_context
- def get_bind(self):
+ def get_bind(self) -> "Connection":
"""Return the current 'bind'.
In "online" mode, this is the
@@ -949,7 +957,7 @@ class EnvironmentContext(util.ModuleClsProxy):
has first been made available via :meth:`.configure`.
"""
- return self.get_context().bind
+ return self.get_context().bind # type: ignore[return-value]
- def get_impl(self):
+ def get_impl(self) -> "DefaultImpl":
return self.get_context().impl
diff --git a/alembic/runtime/migration.py b/alembic/runtime/migration.py
index c09c8e4..677d0c7 100644
--- a/alembic/runtime/migration.py
+++ b/alembic/runtime/migration.py
@@ -36,6 +36,7 @@ if TYPE_CHECKING:
from sqlalchemy.engine.base import Connection
from sqlalchemy.engine.base import Transaction
from sqlalchemy.engine.mock import MockConnection
+ from sqlalchemy.sql.elements import ClauseElement
from .environment import EnvironmentContext
from ..config import Config
@@ -539,6 +540,7 @@ class MigrationContext:
def _ensure_version_table(self, purge: bool = False) -> None:
with sqla_compat._ensure_scope_for_ddl(self.connection):
+ assert self.connection is not None
self._version.create(self.connection, checkfirst=True)
if purge:
assert self.connection is not None
@@ -568,7 +570,7 @@ class MigrationContext:
for step in script_directory._stamp_revs(revision, heads):
head_maintainer.update_to_step(step)
- def run_migrations(self, **kw) -> None:
+ def run_migrations(self, **kw: Any) -> None:
r"""Run the migration scripts established for this
:class:`.MigrationContext`, if any.
@@ -614,6 +616,7 @@ class MigrationContext:
if self.as_sql and not head_maintainer.heads:
# for offline mode, include a CREATE TABLE from
# the base
+ assert self.connection is not None
self._version.create(self.connection)
log.info("Running %s", step)
if self.as_sql:
@@ -637,6 +640,7 @@ class MigrationContext:
)
if self.as_sql and not head_maintainer.heads:
+ assert self.connection is not None
self._version.drop(self.connection)
def _in_connection_transaction(self) -> bool:
@@ -647,7 +651,11 @@ class MigrationContext:
else:
return meth()
- def execute(self, sql: str, execution_options: None = None) -> None:
+ def execute(
+ self,
+ sql: Union["ClauseElement", str],
+ execution_options: Optional[dict] = None,
+ ) -> None:
"""Execute a SQL construct or string statement.
The underlying execution mechanics are used, that is
@@ -771,9 +779,11 @@ class HeadMaintainer:
== literal_column("'%s'" % version)
)
)
+
if (
not self.context.as_sql
and self.context.dialect.supports_sane_rowcount
+ and ret is not None
and ret.rowcount != 1
):
raise util.CommandError(
@@ -796,9 +806,11 @@ class HeadMaintainer:
== literal_column("'%s'" % from_)
)
)
+
if (
not self.context.as_sql
and self.context.dialect.supports_sane_rowcount
+ and ret is not None
and ret.rowcount != 1
):
raise util.CommandError(
@@ -1269,7 +1281,7 @@ class StampStep(MigrationStep):
doc: None = None
- def stamp_revision(self, **kw) -> None:
+ def stamp_revision(self, **kw: Any) -> None:
return None
def __eq__(self, other):
diff --git a/alembic/util/sqla_compat.py b/alembic/util/sqla_compat.py
index 179e548..7e98bb5 100644
--- a/alembic/util/sqla_compat.py
+++ b/alembic/util/sqla_compat.py
@@ -241,7 +241,7 @@ def _table_for_constraint(constraint: "Constraint") -> "Table":
if isinstance(constraint, ForeignKeyConstraint):
table = constraint.parent
assert table is not None
- return table
+ return table # type: ignore[return-value]
else:
return constraint.table
@@ -261,7 +261,9 @@ def _reflect_table(
if sqla_14:
return inspector.reflect_table(table, None)
else:
- return inspector.reflecttable(table, None)
+ return inspector.reflecttable( # type: ignore[attr-defined]
+ table, None
+ )
def _resolve_for_variant(type_, dialect):
@@ -391,7 +393,9 @@ def _copy_expression(expression: _CE, target_table: "Table") -> _CE:
else:
return None
- return visitors.replacement_traverse(expression, {}, replace)
+ return visitors.replacement_traverse( # type: ignore[call-overload]
+ expression, {}, replace
+ )
class _textual_index_element(sql.ColumnElement):
@@ -487,7 +491,7 @@ def _get_constraint_final_name(
if isinstance(constraint, schema.Index):
# name should not be quoted.
- d = dialect.ddl_compiler(dialect, None)
+ d = dialect.ddl_compiler(dialect, None) # type: ignore[arg-type]
return d._prepared_index_name( # type: ignore[attr-defined]
constraint
)
@@ -529,7 +533,7 @@ def _insert_inline(table: Union["TableClause", "Table"]) -> "Insert":
if sqla_14:
return table.insert().inline()
else:
- return table.insert(inline=True)
+ return table.insert(inline=True) # type: ignore[call-arg]
if sqla_14:
@@ -543,5 +547,5 @@ else:
"postgresql://", strategy="mock", executor=executor
)
- def _select(*columns, **kw) -> "Select":
- return sql.select(list(columns), **kw)
+ def _select(*columns, **kw) -> "Select": # type: ignore[no-redef]
+ return sql.select(list(columns), **kw) # type: ignore[call-overload]
diff --git a/pyproject.toml b/pyproject.toml
index 2a8de06..f66269a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -17,6 +17,17 @@ show_error_codes = true
[[tool.mypy.overrides]]
module = [
+ 'alembic.operations.ops',
+ 'alembic.op',
+ 'alembic.context',
+ 'alembic.autogenerate.api',
+ 'alembic.runtime.*',
+]
+
+disallow_incomplete_defs = true
+
+[[tool.mypy.overrides]]
+module = [
'mako.*',
'sqlalchemy.testing.*'
]
diff --git a/tools/write_pyi.py b/tools/write_pyi.py
index cf42d1b..52fac3c 100644
--- a/tools/write_pyi.py
+++ b/tools/write_pyi.py
@@ -29,6 +29,7 @@ IGNORE_ITEMS = {
}
TRIM_MODULE = [
"alembic.runtime.migration.",
+ "alembic.operations.base.",
"alembic.operations.ops.",
"sqlalchemy.engine.base.",
"sqlalchemy.sql.schema.",
@@ -85,6 +86,8 @@ def generate_pyi_for_proxy(
module = sys.modules[cls.__module__]
env = {
+ **typing.__dict__,
+ **sa.sql.schema.__dict__,
**sa.__dict__,
**sa.types.__dict__,
**ops.__dict__,
@@ -141,7 +144,7 @@ def _generate_stub_for_meth(cls, name, printer, env, is_context_manager):
annotations = typing.get_type_hints(fn, env)
spec.annotations.update(annotations)
except NameError as e:
- pass
+ print(f"{cls.__name__}.{name} NameError: {e}", file=sys.stderr)
name_args = spec[0]
assert name_args[0:1] == ["self"] or name_args[0:1] == ["cls"]