diff options
-rw-r--r-- | alembic/autogenerate/api.py | 6 | ||||
-rw-r--r-- | alembic/autogenerate/compare.py | 7 | ||||
-rw-r--r-- | alembic/autogenerate/render.py | 2 | ||||
-rw-r--r-- | alembic/context.pyi | 15 | ||||
-rw-r--r-- | alembic/ddl/base.py | 2 | ||||
-rw-r--r-- | alembic/ddl/impl.py | 45 | ||||
-rw-r--r-- | alembic/ddl/mssql.py | 7 | ||||
-rw-r--r-- | alembic/ddl/oracle.py | 6 | ||||
-rw-r--r-- | alembic/op.pyi | 11 | ||||
-rw-r--r-- | alembic/operations/base.py | 6 | ||||
-rw-r--r-- | alembic/operations/batch.py | 10 | ||||
-rw-r--r-- | alembic/operations/ops.py | 22 | ||||
-rw-r--r-- | alembic/operations/toimpl.py | 2 | ||||
-rw-r--r-- | alembic/runtime/environment.py | 20 | ||||
-rw-r--r-- | alembic/runtime/migration.py | 18 | ||||
-rw-r--r-- | alembic/util/sqla_compat.py | 18 | ||||
-rw-r--r-- | pyproject.toml | 11 | ||||
-rw-r--r-- | tools/write_pyi.py | 5 |
18 files changed, 133 insertions, 80 deletions
diff --git a/alembic/autogenerate/api.py b/alembic/autogenerate/api.py index 5f1c7f3..cbd64e1 100644 --- a/alembic/autogenerate/api.py +++ b/alembic/autogenerate/api.py @@ -523,12 +523,12 @@ class RevisionContext: def run_autogenerate( self, rev: tuple, migration_context: "MigrationContext" - ): + ) -> None: self._run_environment(rev, migration_context, True) def run_no_autogenerate( self, rev: tuple, migration_context: "MigrationContext" - ): + ) -> None: self._run_environment(rev, migration_context, False) def _run_environment( @@ -536,7 +536,7 @@ class RevisionContext: rev: tuple, migration_context: "MigrationContext", autogenerate: bool, - ): + ) -> None: if autogenerate: if self.command_args["sql"]: raise util.CommandError( diff --git a/alembic/autogenerate/compare.py b/alembic/autogenerate/compare.py index 5b69815..c32ab4d 100644 --- a/alembic/autogenerate/compare.py +++ b/alembic/autogenerate/compare.py @@ -616,8 +616,8 @@ def _compare_indexes_and_uniques( # we know are either added implicitly by the DB or that the DB # can't accurately report on autogen_context.migration_context.impl.correct_for_autogen_constraints( - conn_uniques, - conn_indexes, + conn_uniques, # type: ignore[arg-type] + conn_indexes, # type: ignore[arg-type] metadata_unique_constraints, metadata_indexes, ) @@ -1274,7 +1274,8 @@ def _compare_foreign_keys( ) conn_fks = set( - _make_foreign_key(const, conn_table) for const in conn_fks_list + _make_foreign_key(const, conn_table) # type: ignore[arg-type] + for const in conn_fks_list ) # give the dialect a chance to correct the FKs to match more diff --git a/alembic/autogenerate/render.py b/alembic/autogenerate/render.py index 9c992b4..1ac6753 100644 --- a/alembic/autogenerate/render.py +++ b/alembic/autogenerate/render.py @@ -989,7 +989,7 @@ def _fk_colspec( if table_fullname in namespace_metadata.tables: col = namespace_metadata.tables[table_fullname].c.get(colname) if col is not None: - colname = _ident(col.name) + colname = _ident(col.name) # type: ignore[assignment] colspec = "%s.%s" % (table_fullname, colname) diff --git a/alembic/context.pyi b/alembic/context.pyi index 14e1b5f..a2e5399 100644 --- a/alembic/context.pyi +++ b/alembic/context.pyi @@ -5,6 +5,8 @@ from __future__ import annotations from typing import Any from typing import Callable from typing import ContextManager +from typing import Dict +from typing import List from typing import Optional from typing import TextIO from typing import Tuple @@ -13,6 +15,7 @@ from typing import Union if TYPE_CHECKING: from sqlalchemy.engine.base import Connection + from sqlalchemy.sql.elements import ClauseElement from sqlalchemy.sql.schema import MetaData from .config import Config @@ -530,7 +533,9 @@ def configure( """ -def execute(sql, execution_options=None): +def execute( + sql: Union[ClauseElement, str], execution_options: Optional[dict] = None +) -> None: """Execute the given SQL using the current change context. The behavior of :meth:`.execute` is the same @@ -543,7 +548,7 @@ def execute(sql, execution_options=None): """ -def get_bind(): +def get_bind() -> Connection: """Return the current 'bind'. In "online" mode, this is the @@ -635,7 +640,9 @@ def get_tag_argument() -> Optional[str]: """ -def get_x_argument(as_dictionary: bool = False): +def get_x_argument( + as_dictionary: bool = False, +) -> Union[List[str], Dict[str, str]]: """Return the value(s) passed for the ``-x`` argument, if any. The ``-x`` argument is an open ended flag that allows any user-defined @@ -723,7 +730,7 @@ def run_migrations(**kw: Any) -> None: script: ScriptDirectory -def static_output(text): +def static_output(text: str) -> None: """Emit text directly to the "offline" SQL stream. Typically this is for emitting comments that diff --git a/alembic/ddl/base.py b/alembic/ddl/base.py index 7b0f63e..c910786 100644 --- a/alembic/ddl/base.py +++ b/alembic/ddl/base.py @@ -294,7 +294,7 @@ def format_table_name( def format_column_name( compiler: "DDLCompiler", name: Optional[Union["quoted_name", str]] ) -> Union["quoted_name", str]: - return compiler.preparer.quote(name) + return compiler.preparer.quote(name) # type: ignore[arg-type] def format_server_default( diff --git a/alembic/ddl/impl.py b/alembic/ddl/impl.py index 070c124..79d5245 100644 --- a/alembic/ddl/impl.py +++ b/alembic/ddl/impl.py @@ -23,19 +23,16 @@ from .. import util from ..util import sqla_compat if TYPE_CHECKING: - from io import StringIO from typing import Literal + from typing import TextIO from sqlalchemy.engine import Connection from sqlalchemy.engine import Dialect from sqlalchemy.engine.cursor import CursorResult - from sqlalchemy.engine.cursor import LegacyCursorResult from sqlalchemy.engine.reflection import Inspector - from sqlalchemy.sql.dml import Update from sqlalchemy.sql.elements import ClauseElement from sqlalchemy.sql.elements import ColumnElement from sqlalchemy.sql.elements import quoted_name - from sqlalchemy.sql.elements import TextClause from sqlalchemy.sql.schema import Column from sqlalchemy.sql.schema import Constraint from sqlalchemy.sql.schema import ForeignKeyConstraint @@ -60,11 +57,11 @@ class ImplMeta(type): ): newtype = type.__init__(cls, classname, bases, dict_) if "__dialect__" in dict_: - _impls[dict_["__dialect__"]] = cls + _impls[dict_["__dialect__"]] = cls # type: ignore[assignment] return newtype -_impls: dict = {} +_impls: Dict[str, Type["DefaultImpl"]] = {} Params = namedtuple("Params", ["token0", "tokens", "args", "kwargs"]) @@ -98,7 +95,7 @@ class DefaultImpl(metaclass=ImplMeta): connection: Optional["Connection"], as_sql: bool, transactional_ddl: Optional[bool], - output_buffer: Optional["StringIO"], + output_buffer: Optional["TextIO"], context_opts: Dict[str, Any], ) -> None: self.dialect = dialect @@ -119,7 +116,7 @@ class DefaultImpl(metaclass=ImplMeta): ) @classmethod - def get_by_dialect(cls, dialect: "Dialect") -> Any: + def get_by_dialect(cls, dialect: "Dialect") -> Type["DefaultImpl"]: return _impls[dialect.name] def static_output(self, text: str) -> None: @@ -158,10 +155,10 @@ class DefaultImpl(metaclass=ImplMeta): def _exec( self, construct: Union["ClauseElement", str], - execution_options: None = None, + execution_options: Optional[dict] = None, multiparams: Sequence[dict] = (), params: Dict[str, int] = util.immutabledict(), - ) -> Optional[Union["LegacyCursorResult", "CursorResult"]]: + ) -> Optional["CursorResult"]: if isinstance(construct, str): construct = text(construct) if self.as_sql: @@ -176,10 +173,11 @@ class DefaultImpl(metaclass=ImplMeta): else: compile_kw = {} + compiled = construct.compile( + dialect=self.dialect, **compile_kw # type: ignore[arg-type] + ) self.static_output( - str(construct.compile(dialect=self.dialect, **compile_kw)) - .replace("\t", " ") - .strip() + str(compiled).replace("\t", " ").strip() + self.command_terminator ) return None @@ -192,11 +190,13 @@ class DefaultImpl(metaclass=ImplMeta): assert isinstance(multiparams, tuple) multiparams += (params,) - return conn.execute(construct, multiparams) + return conn.execute( # type: ignore[call-overload] + construct, multiparams + ) def execute( self, - sql: Union["Update", "TextClause", str], + sql: Union["ClauseElement", str], execution_options: None = None, ) -> None: self._exec(sql, execution_options) @@ -424,9 +424,6 @@ class DefaultImpl(metaclass=ImplMeta): ) ) else: - # work around http://www.sqlalchemy.org/trac/ticket/2461 - if not hasattr(table, "_autoincrement_column"): - table._autoincrement_column = None if rows: if multiinsert: self._exec( @@ -572,7 +569,7 @@ class DefaultImpl(metaclass=ImplMeta): ) def render_ddl_sql_expr( - self, expr: "ClauseElement", is_server_default: bool = False, **kw + self, expr: "ClauseElement", is_server_default: bool = False, **kw: Any ) -> str: """Render a SQL expression that is typically a server default, index expression, etc. @@ -581,10 +578,14 @@ class DefaultImpl(metaclass=ImplMeta): """ - compile_kw = dict( - compile_kwargs={"literal_binds": True, "include_table": False} + compile_kw = { + "compile_kwargs": {"literal_binds": True, "include_table": False} + } + return str( + expr.compile( + dialect=self.dialect, **compile_kw # type: ignore[arg-type] + ) ) - return str(expr.compile(dialect=self.dialect, **compile_kw)) def _compat_autogen_column_reflect( self, inspector: "Inspector" diff --git a/alembic/ddl/mssql.py b/alembic/ddl/mssql.py index b48f8ba..28f0678 100644 --- a/alembic/ddl/mssql.py +++ b/alembic/ddl/mssql.py @@ -35,7 +35,6 @@ if TYPE_CHECKING: from sqlalchemy.dialects.mssql.base import MSDDLCompiler from sqlalchemy.dialects.mssql.base import MSSQLCompiler from sqlalchemy.engine.cursor import CursorResult - from sqlalchemy.engine.cursor import LegacyCursorResult from sqlalchemy.sql.schema import Index from sqlalchemy.sql.schema import Table from sqlalchemy.sql.selectable import TableClause @@ -68,9 +67,7 @@ class MSSQLImpl(DefaultImpl): "mssql_batch_separator", self.batch_separator ) - def _exec( - self, construct: Any, *args, **kw - ) -> Optional[Union["LegacyCursorResult", "CursorResult"]]: + def _exec(self, construct: Any, *args, **kw) -> Optional["CursorResult"]: result = super(MSSQLImpl, self)._exec(construct, *args, **kw) if self.as_sql and self.batch_separator: self.static_output(self.batch_separator) @@ -359,7 +356,7 @@ def visit_column_nullable( return "%s %s %s %s" % ( alter_table(compiler, element.table_name, element.schema), alter_column(compiler, element.column_name), - format_type(compiler, element.existing_type), + format_type(compiler, element.existing_type), # type: ignore[arg-type] "NULL" if element.nullable else "NOT NULL", ) diff --git a/alembic/ddl/oracle.py b/alembic/ddl/oracle.py index 0e787fb..accd1fc 100644 --- a/alembic/ddl/oracle.py +++ b/alembic/ddl/oracle.py @@ -3,7 +3,6 @@ from __future__ import annotations from typing import Any from typing import Optional from typing import TYPE_CHECKING -from typing import Union from sqlalchemy.ext.compiler import compiles from sqlalchemy.sql import sqltypes @@ -26,7 +25,6 @@ from .impl import DefaultImpl if TYPE_CHECKING: from sqlalchemy.dialects.oracle.base import OracleDDLCompiler from sqlalchemy.engine.cursor import CursorResult - from sqlalchemy.engine.cursor import LegacyCursorResult from sqlalchemy.sql.schema import Column @@ -48,9 +46,7 @@ class OracleImpl(DefaultImpl): "oracle_batch_separator", self.batch_separator ) - def _exec( - self, construct: Any, *args, **kw - ) -> Optional[Union["LegacyCursorResult", "CursorResult"]]: + def _exec(self, construct: Any, *args, **kw) -> Optional["CursorResult"]: result = super(OracleImpl, self)._exec(construct, *args, **kw) if self.as_sql and self.batch_separator: self.static_output(self.batch_separator) diff --git a/alembic/op.pyi b/alembic/op.pyi index 59dfc58..490d714 100644 --- a/alembic/op.pyi +++ b/alembic/op.pyi @@ -28,6 +28,7 @@ if TYPE_CHECKING: from sqlalchemy.sql.schema import Column from sqlalchemy.sql.schema import Computed from sqlalchemy.sql.schema import Identity + from sqlalchemy.sql.schema import SchemaItem from sqlalchemy.sql.schema import Table from sqlalchemy.sql.type_api import TypeEngine from sqlalchemy.util import immutabledict @@ -94,7 +95,7 @@ def alter_column( table_name: str, column_name: str, nullable: Optional[bool] = None, - comment: Union[str, bool, None] = False, + comment: Union[str, Literal[False], None] = False, server_default: Any = False, new_column_name: Optional[str] = None, type_: Union[TypeEngine, Type[TypeEngine], None] = None, @@ -202,13 +203,13 @@ def batch_alter_table( schema: Optional[str] = None, recreate: Literal["auto", "always", "never"] = "auto", partial_reordering: Optional[tuple] = None, - copy_from: Optional["Table"] = None, + copy_from: Optional[Table] = None, table_args: Tuple[Any, ...] = (), table_kwargs: Mapping[str, Any] = immutabledict({}), reflect_args: Tuple[Any, ...] = (), reflect_kwargs: Mapping[str, Any] = immutabledict({}), naming_convention: Optional[Dict[str, str]] = None, -) -> Iterator["BatchOperations"]: +) -> Iterator[BatchOperations]: """Invoke a series of per-table migrations in batch. Batch mode allows a series of operations specific to a table @@ -667,7 +668,9 @@ def create_primary_key( """ -def create_table(table_name: str, *columns, **kw: Any) -> Optional[Table]: +def create_table( + table_name: str, *columns: SchemaItem, **kw: Any +) -> Optional[Table]: """Issue a "create table" instruction using the current migration context. diff --git a/alembic/operations/base.py b/alembic/operations/base.py index 9ecf3d4..535dff0 100644 --- a/alembic/operations/base.py +++ b/alembic/operations/base.py @@ -37,6 +37,7 @@ if TYPE_CHECKING: from .batch import BatchOperationsImpl from .ops import MigrateOperation + from ..ddl import DefaultImpl from ..runtime.migration import MigrationContext from ..util.sqla_compat import _literal_bindparam @@ -74,6 +75,7 @@ class Operations(util.ModuleClsProxy): """ + impl: Union["DefaultImpl", "BatchOperationsImpl"] _to_impl = util.Dispatcher() def __init__( @@ -492,7 +494,7 @@ class Operations(util.ModuleClsProxy): In a SQL script context, this value is ``None``. [TODO: verify this] """ - return self.migration_context.impl.bind + return self.migration_context.impl.bind # type: ignore[return-value] class BatchOperations(Operations): @@ -512,6 +514,8 @@ class BatchOperations(Operations): """ + impl: "BatchOperationsImpl" + def _noop(self, operation): raise NotImplementedError( "The %s method does not apply to a batch table alter operation." diff --git a/alembic/operations/batch.py b/alembic/operations/batch.py index 71d2681..f1459e2 100644 --- a/alembic/operations/batch.py +++ b/alembic/operations/batch.py @@ -236,7 +236,7 @@ class ApplyBatchImpl: self._grab_table_elements() @classmethod - def _calc_temp_name(cls, tablename: "quoted_name") -> str: + def _calc_temp_name(cls, tablename: Union["quoted_name", str]) -> str: return ("_alembic_tmp_%s" % tablename)[0:50] def _grab_table_elements(self) -> None: @@ -280,7 +280,7 @@ class ApplyBatchImpl: self.col_named_constraints[const.name] = (col, const) for idx in self.table.indexes: - self.indexes[idx.name] = idx + self.indexes[idx.name] = idx # type: ignore[index] for k in self.table.kwargs: self.table_kwargs.setdefault(k, self.table.kwargs[k]) @@ -546,7 +546,7 @@ class ApplyBatchImpl: existing.server_default = None else: sql_schema.DefaultClause( - server_default + server_default # type: ignore[arg-type] )._set_parent( # type:ignore[attr-defined] existing ) @@ -699,11 +699,11 @@ class ApplyBatchImpl: self.columns[col.name].primary_key = False def create_index(self, idx: "Index") -> None: - self.new_indexes[idx.name] = idx + self.new_indexes[idx.name] = idx # type: ignore[index] def drop_index(self, idx: "Index") -> None: try: - del self.indexes[idx.name] + del self.indexes[idx.name] # type: ignore[arg-type] except KeyError: raise ValueError("No such index: '%s'" % idx.name) diff --git a/alembic/operations/ops.py b/alembic/operations/ops.py index 997274d..85ffe14 100644 --- a/alembic/operations/ops.py +++ b/alembic/operations/ops.py @@ -26,6 +26,8 @@ from .. import util from ..util import sqla_compat if TYPE_CHECKING: + from typing import Literal + from sqlalchemy.sql.dml import Insert from sqlalchemy.sql.dml import Update from sqlalchemy.sql.elements import BinaryExpression @@ -885,7 +887,7 @@ class CreateIndexOp(MigrateOperation): def from_index(cls, index: "Index") -> "CreateIndexOp": assert index.table is not None return cls( - index.name, + index.name, # type: ignore[arg-type] index.table.name, sqla_compat._get_index_expressions(index), schema=index.table.schema, @@ -1021,7 +1023,7 @@ class DropIndexOp(MigrateOperation): def from_index(cls, index: "Index") -> "DropIndexOp": assert index.table is not None return cls( - index.name, + index.name, # type: ignore[arg-type] index.table.name, schema=index.table.schema, _reverse=CreateIndexOp.from_index(index), @@ -1105,7 +1107,7 @@ class CreateTableOp(MigrateOperation): def __init__( self, table_name: str, - columns: Sequence[Union["Column", "Constraint"]], + columns: Sequence["SchemaItem"], schema: Optional[str] = None, _namespace_metadata: Optional["MetaData"] = None, _constraints_included: bool = False, @@ -1172,8 +1174,12 @@ class CreateTableOp(MigrateOperation): @classmethod def create_table( - cls, operations: "Operations", table_name: str, *columns, **kw: Any - ) -> Optional["Table"]: + cls, + operations: "Operations", + table_name: str, + *columns: "SchemaItem", + **kw: Any, + ) -> "Optional[Table]": r"""Issue a "create table" instruction using the current migration context. @@ -1603,7 +1609,7 @@ class AlterColumnOp(AlterTableOp): existing_nullable: Optional[bool] = None, existing_comment: Optional[str] = None, modify_nullable: Optional[bool] = None, - modify_comment: Optional[Union[str, bool]] = False, + modify_comment: Optional[Union[str, "Literal[False]"]] = False, modify_server_default: Any = False, modify_name: Optional[str] = None, modify_type: Optional[Any] = None, @@ -1757,7 +1763,7 @@ class AlterColumnOp(AlterTableOp): table_name: str, column_name: str, nullable: Optional[bool] = None, - comment: Optional[Union[str, bool]] = False, + comment: Optional[Union[str, "Literal[False]"]] = False, server_default: Any = False, new_column_name: Optional[str] = None, type_: Optional[Union["TypeEngine", Type["TypeEngine"]]] = None, @@ -1885,7 +1891,7 @@ class AlterColumnOp(AlterTableOp): operations: BatchOperations, column_name: str, nullable: Optional[bool] = None, - comment: bool = False, + comment: Union[str, "Literal[False]"] = False, server_default: Union["Function", bool] = False, new_column_name: Optional[str] = None, type_: Optional[Union["TypeEngine", Type["TypeEngine"]]] = None, diff --git a/alembic/operations/toimpl.py b/alembic/operations/toimpl.py index f97983e..add142d 100644 --- a/alembic/operations/toimpl.py +++ b/alembic/operations/toimpl.py @@ -195,7 +195,7 @@ def drop_constraint( def bulk_insert( operations: "Operations", operation: "ops.BulkInsertOp" ) -> None: - operations.impl.bulk_insert( + operations.impl.bulk_insert( # type: ignore[union-attr] operation.table, operation.rows, multiinsert=operation.multiinsert ) diff --git a/alembic/runtime/environment.py b/alembic/runtime/environment.py index b95e0b5..3cec5b1 100644 --- a/alembic/runtime/environment.py +++ b/alembic/runtime/environment.py @@ -20,10 +20,12 @@ if TYPE_CHECKING: from typing import Literal from sqlalchemy.engine.base import Connection + from sqlalchemy.sql.elements import ClauseElement from sqlalchemy.sql.schema import MetaData from .migration import _ProxyTransaction from ..config import Config + from ..ddl import DefaultImpl from ..script.base import ScriptDirectory _RevNumber = Optional[Union[str, Tuple[str, ...]]] @@ -273,7 +275,9 @@ class EnvironmentContext(util.ModuleClsProxy): ) -> Dict[str, str]: ... - def get_x_argument(self, as_dictionary: bool = False): + def get_x_argument( + self, as_dictionary: bool = False + ) -> Union[List[str], Dict[str, str]]: """Return the value(s) passed for the ``-x`` argument, if any. The ``-x`` argument is an open ended flag that allows any user-defined @@ -853,7 +857,11 @@ class EnvironmentContext(util.ModuleClsProxy): with Operations.context(self._migration_context): self.get_context().run_migrations(**kw) - def execute(self, sql, execution_options=None): + def execute( + self, + sql: Union["ClauseElement", str], + execution_options: Optional[dict] = None, + ) -> None: """Execute the given SQL using the current change context. The behavior of :meth:`.execute` is the same @@ -867,7 +875,7 @@ class EnvironmentContext(util.ModuleClsProxy): """ self.get_context().execute(sql, execution_options=execution_options) - def static_output(self, text): + def static_output(self, text: str) -> None: """Emit text directly to the "offline" SQL stream. Typically this is for emitting comments that @@ -938,7 +946,7 @@ class EnvironmentContext(util.ModuleClsProxy): raise Exception("No context has been configured yet.") return self._migration_context - def get_bind(self): + def get_bind(self) -> "Connection": """Return the current 'bind'. In "online" mode, this is the @@ -949,7 +957,7 @@ class EnvironmentContext(util.ModuleClsProxy): has first been made available via :meth:`.configure`. """ - return self.get_context().bind + return self.get_context().bind # type: ignore[return-value] - def get_impl(self): + def get_impl(self) -> "DefaultImpl": return self.get_context().impl diff --git a/alembic/runtime/migration.py b/alembic/runtime/migration.py index c09c8e4..677d0c7 100644 --- a/alembic/runtime/migration.py +++ b/alembic/runtime/migration.py @@ -36,6 +36,7 @@ if TYPE_CHECKING: from sqlalchemy.engine.base import Connection from sqlalchemy.engine.base import Transaction from sqlalchemy.engine.mock import MockConnection + from sqlalchemy.sql.elements import ClauseElement from .environment import EnvironmentContext from ..config import Config @@ -539,6 +540,7 @@ class MigrationContext: def _ensure_version_table(self, purge: bool = False) -> None: with sqla_compat._ensure_scope_for_ddl(self.connection): + assert self.connection is not None self._version.create(self.connection, checkfirst=True) if purge: assert self.connection is not None @@ -568,7 +570,7 @@ class MigrationContext: for step in script_directory._stamp_revs(revision, heads): head_maintainer.update_to_step(step) - def run_migrations(self, **kw) -> None: + def run_migrations(self, **kw: Any) -> None: r"""Run the migration scripts established for this :class:`.MigrationContext`, if any. @@ -614,6 +616,7 @@ class MigrationContext: if self.as_sql and not head_maintainer.heads: # for offline mode, include a CREATE TABLE from # the base + assert self.connection is not None self._version.create(self.connection) log.info("Running %s", step) if self.as_sql: @@ -637,6 +640,7 @@ class MigrationContext: ) if self.as_sql and not head_maintainer.heads: + assert self.connection is not None self._version.drop(self.connection) def _in_connection_transaction(self) -> bool: @@ -647,7 +651,11 @@ class MigrationContext: else: return meth() - def execute(self, sql: str, execution_options: None = None) -> None: + def execute( + self, + sql: Union["ClauseElement", str], + execution_options: Optional[dict] = None, + ) -> None: """Execute a SQL construct or string statement. The underlying execution mechanics are used, that is @@ -771,9 +779,11 @@ class HeadMaintainer: == literal_column("'%s'" % version) ) ) + if ( not self.context.as_sql and self.context.dialect.supports_sane_rowcount + and ret is not None and ret.rowcount != 1 ): raise util.CommandError( @@ -796,9 +806,11 @@ class HeadMaintainer: == literal_column("'%s'" % from_) ) ) + if ( not self.context.as_sql and self.context.dialect.supports_sane_rowcount + and ret is not None and ret.rowcount != 1 ): raise util.CommandError( @@ -1269,7 +1281,7 @@ class StampStep(MigrationStep): doc: None = None - def stamp_revision(self, **kw) -> None: + def stamp_revision(self, **kw: Any) -> None: return None def __eq__(self, other): diff --git a/alembic/util/sqla_compat.py b/alembic/util/sqla_compat.py index 179e548..7e98bb5 100644 --- a/alembic/util/sqla_compat.py +++ b/alembic/util/sqla_compat.py @@ -241,7 +241,7 @@ def _table_for_constraint(constraint: "Constraint") -> "Table": if isinstance(constraint, ForeignKeyConstraint): table = constraint.parent assert table is not None - return table + return table # type: ignore[return-value] else: return constraint.table @@ -261,7 +261,9 @@ def _reflect_table( if sqla_14: return inspector.reflect_table(table, None) else: - return inspector.reflecttable(table, None) + return inspector.reflecttable( # type: ignore[attr-defined] + table, None + ) def _resolve_for_variant(type_, dialect): @@ -391,7 +393,9 @@ def _copy_expression(expression: _CE, target_table: "Table") -> _CE: else: return None - return visitors.replacement_traverse(expression, {}, replace) + return visitors.replacement_traverse( # type: ignore[call-overload] + expression, {}, replace + ) class _textual_index_element(sql.ColumnElement): @@ -487,7 +491,7 @@ def _get_constraint_final_name( if isinstance(constraint, schema.Index): # name should not be quoted. - d = dialect.ddl_compiler(dialect, None) + d = dialect.ddl_compiler(dialect, None) # type: ignore[arg-type] return d._prepared_index_name( # type: ignore[attr-defined] constraint ) @@ -529,7 +533,7 @@ def _insert_inline(table: Union["TableClause", "Table"]) -> "Insert": if sqla_14: return table.insert().inline() else: - return table.insert(inline=True) + return table.insert(inline=True) # type: ignore[call-arg] if sqla_14: @@ -543,5 +547,5 @@ else: "postgresql://", strategy="mock", executor=executor ) - def _select(*columns, **kw) -> "Select": - return sql.select(list(columns), **kw) + def _select(*columns, **kw) -> "Select": # type: ignore[no-redef] + return sql.select(list(columns), **kw) # type: ignore[call-overload] diff --git a/pyproject.toml b/pyproject.toml index 2a8de06..f66269a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,6 +17,17 @@ show_error_codes = true [[tool.mypy.overrides]] module = [ + 'alembic.operations.ops', + 'alembic.op', + 'alembic.context', + 'alembic.autogenerate.api', + 'alembic.runtime.*', +] + +disallow_incomplete_defs = true + +[[tool.mypy.overrides]] +module = [ 'mako.*', 'sqlalchemy.testing.*' ] diff --git a/tools/write_pyi.py b/tools/write_pyi.py index cf42d1b..52fac3c 100644 --- a/tools/write_pyi.py +++ b/tools/write_pyi.py @@ -29,6 +29,7 @@ IGNORE_ITEMS = { } TRIM_MODULE = [ "alembic.runtime.migration.", + "alembic.operations.base.", "alembic.operations.ops.", "sqlalchemy.engine.base.", "sqlalchemy.sql.schema.", @@ -85,6 +86,8 @@ def generate_pyi_for_proxy( module = sys.modules[cls.__module__] env = { + **typing.__dict__, + **sa.sql.schema.__dict__, **sa.__dict__, **sa.types.__dict__, **ops.__dict__, @@ -141,7 +144,7 @@ def _generate_stub_for_meth(cls, name, printer, env, is_context_manager): annotations = typing.get_type_hints(fn, env) spec.annotations.update(annotations) except NameError as e: - pass + print(f"{cls.__name__}.{name} NameError: {e}", file=sys.stderr) name_args = spec[0] assert name_args[0:1] == ["self"] or name_args[0:1] == ["cls"] |