From 6ad07e26f8b8fc3da931ea572547bb6f2643e088 Mon Sep 17 00:00:00 2001 From: Federico Caselli Date: Tue, 16 May 2023 21:52:02 +0200 Subject: Define type for generic classes Fixed typing use of :class:`~sqlalchemy.schema.Column` and other generic SQLAlchemy classes. Fixes: #1246 Change-Id: I5ee80395d626894a52e3395c9986213289576355 --- alembic/autogenerate/compare.py | 24 ++++++++++++------------ alembic/autogenerate/render.py | 8 ++++++-- alembic/context.pyi | 4 ++-- alembic/ddl/base.py | 6 +++--- alembic/ddl/impl.py | 8 ++++---- alembic/ddl/mssql.py | 10 ++++++---- alembic/ddl/oracle.py | 2 +- alembic/ddl/postgresql.py | 6 ++++-- alembic/ddl/sqlite.py | 6 +++--- alembic/op.pyi | 2 +- alembic/operations/base.py | 4 ++-- alembic/operations/batch.py | 9 ++++++--- alembic/operations/ops.py | 14 +++++++------- alembic/runtime/environment.py | 4 ++-- alembic/runtime/migration.py | 6 +++--- alembic/util/sqla_compat.py | 8 ++++---- 16 files changed, 66 insertions(+), 55 deletions(-) (limited to 'alembic') diff --git a/alembic/autogenerate/compare.py b/alembic/autogenerate/compare.py index 5727891..031d683 100644 --- a/alembic/autogenerate/compare.py +++ b/alembic/autogenerate/compare.py @@ -926,8 +926,8 @@ def _compare_nullable( schema: Optional[str], tname: Union[quoted_name, str], cname: Union[quoted_name, str], - conn_col: Column, - metadata_col: Column, + conn_col: Column[Any], + metadata_col: Column[Any], ) -> None: metadata_col_nullable = metadata_col.nullable @@ -968,8 +968,8 @@ def _setup_autoincrement( schema: Optional[str], tname: Union[quoted_name, str], cname: quoted_name, - conn_col: Column, - metadata_col: Column, + conn_col: Column[Any], + metadata_col: Column[Any], ) -> None: if metadata_col.table._autoincrement_column is metadata_col: @@ -987,8 +987,8 @@ def _compare_type( schema: Optional[str], tname: Union[quoted_name, str], cname: Union[quoted_name, str], - conn_col: Column, - metadata_col: Column, + conn_col: Column[Any], + metadata_col: Column[Any], ) -> None: conn_type = conn_col.type @@ -1060,8 +1060,8 @@ def _compare_computed_default( schema: Optional[str], tname: str, cname: str, - conn_col: Column, - metadata_col: Column, + conn_col: Column[Any], + metadata_col: Column[Any], ) -> None: rendered_metadata_default = str( cast(sa_schema.Computed, metadata_col.server_default).sqltext.compile( @@ -1126,8 +1126,8 @@ def _compare_server_default( schema: Optional[str], tname: Union[quoted_name, str], cname: Union[quoted_name, str], - conn_col: Column, - metadata_col: Column, + conn_col: Column[Any], + metadata_col: Column[Any], ) -> Optional[bool]: metadata_default = metadata_col.server_default @@ -1215,8 +1215,8 @@ def _compare_column_comment( schema: Optional[str], tname: Union[quoted_name, str], cname: quoted_name, - conn_col: Column, - metadata_col: Column, + conn_col: Column[Any], + metadata_col: Column[Any], ) -> Optional[Literal[False]]: assert autogen_context.dialect is not None diff --git a/alembic/autogenerate/render.py b/alembic/autogenerate/render.py index dc841f8..00d1d2f 100644 --- a/alembic/autogenerate/render.py +++ b/alembic/autogenerate/render.py @@ -664,7 +664,9 @@ def _user_defined_render( return False -def _render_column(column: Column, autogen_context: AutogenContext) -> str: +def _render_column( + column: Column[Any], autogen_context: AutogenContext +) -> str: rendered = _user_defined_render("column", column, autogen_context) if rendered is not False: return rendered @@ -727,7 +729,9 @@ def _should_render_server_default_positionally(server_default: Any) -> bool: def _render_server_default( - default: Optional[Union[FetchedValue, str, TextClause, ColumnElement]], + default: Optional[ + Union[FetchedValue, str, TextClause, ColumnElement[Any]] + ], autogen_context: AutogenContext, repr_: bool = True, ) -> Optional[str]: diff --git a/alembic/context.pyi b/alembic/context.pyi index 621599d..eedf7af 100644 --- a/alembic/context.pyi +++ b/alembic/context.pyi @@ -151,8 +151,8 @@ def configure( Callable[ [ MigrationContext, - Column, - Column, + Column[Any], + Column[Any], Optional[str], Optional[FetchedValue], Optional[str], diff --git a/alembic/ddl/base.py b/alembic/ddl/base.py index 65da32f..339db0c 100644 --- a/alembic/ddl/base.py +++ b/alembic/ddl/base.py @@ -150,7 +150,7 @@ class AddColumn(AlterTable): def __init__( self, name: str, - column: Column, + column: Column[Any], schema: Optional[Union[quoted_name, str]] = None, ) -> None: super().__init__(name, schema=schema) @@ -159,7 +159,7 @@ class AddColumn(AlterTable): class DropColumn(AlterTable): def __init__( - self, name: str, column: Column, schema: Optional[str] = None + self, name: str, column: Column[Any], schema: Optional[str] = None ) -> None: super().__init__(name, schema=schema) self.column = column @@ -320,7 +320,7 @@ def alter_column(compiler: DDLCompiler, name: str) -> str: return "ALTER COLUMN %s" % format_column_name(compiler, name) -def add_column(compiler: DDLCompiler, column: Column, **kw) -> str: +def add_column(compiler: DDLCompiler, column: Column[Any], **kw) -> str: text = "ADD COLUMN %s" % compiler.get_column_specification(column, **kw) const = " ".join( diff --git a/alembic/ddl/impl.py b/alembic/ddl/impl.py index 84f5d86..03f134d 100644 --- a/alembic/ddl/impl.py +++ b/alembic/ddl/impl.py @@ -316,7 +316,7 @@ class DefaultImpl(metaclass=ImplMeta): def add_column( self, table_name: str, - column: Column, + column: Column[Any], schema: Optional[Union[str, quoted_name]] = None, ) -> None: self._exec(base.AddColumn(table_name, column, schema=schema)) @@ -324,7 +324,7 @@ class DefaultImpl(metaclass=ImplMeta): def drop_column( self, table_name: str, - column: Column, + column: Column[Any], schema: Optional[str] = None, **kw, ) -> None: @@ -388,7 +388,7 @@ class DefaultImpl(metaclass=ImplMeta): def drop_table_comment(self, table: Table) -> None: self._exec(schema.DropTableComment(table)) - def create_column_comment(self, column: ColumnElement) -> None: + def create_column_comment(self, column: ColumnElement[Any]) -> None: self._exec(schema.SetColumnComment(column)) def drop_index(self, index: Index) -> None: @@ -526,7 +526,7 @@ class DefaultImpl(metaclass=ImplMeta): return True def compare_type( - self, inspector_column: Column, metadata_column: Column + self, inspector_column: Column[Any], metadata_column: Column ) -> bool: """Returns True if there ARE differences between the types of the two columns. Takes impl.type_synonyms into account between retrospected diff --git a/alembic/ddl/mssql.py b/alembic/ddl/mssql.py index ebf4db1..10c1a6b 100644 --- a/alembic/ddl/mssql.py +++ b/alembic/ddl/mssql.py @@ -201,7 +201,7 @@ class MSSQLImpl(DefaultImpl): def drop_column( self, table_name: str, - column: Column, + column: Column[Any], schema: Optional[str] = None, **kw, ) -> None: @@ -273,7 +273,7 @@ class _ExecDropConstraint(Executable, ClauseElement): def __init__( self, tname: str, - colname: Union[Column, str], + colname: Union[Column[Any], str], type_: str, schema: Optional[str], ) -> None: @@ -287,7 +287,7 @@ class _ExecDropFKConstraint(Executable, ClauseElement): inherit_cache = False def __init__( - self, tname: str, colname: Column, schema: Optional[str] + self, tname: str, colname: Column[Any], schema: Optional[str] ) -> None: self.tname = tname self.colname = colname @@ -347,7 +347,9 @@ def visit_add_column(element: AddColumn, compiler: MSDDLCompiler, **kw) -> str: ) -def mssql_add_column(compiler: MSDDLCompiler, column: Column, **kw) -> str: +def mssql_add_column( + compiler: MSDDLCompiler, column: Column[Any], **kw +) -> str: return "ADD %s" % compiler.get_column_specification(column, **kw) diff --git a/alembic/ddl/oracle.py b/alembic/ddl/oracle.py index 9715c1e..e56bb21 100644 --- a/alembic/ddl/oracle.py +++ b/alembic/ddl/oracle.py @@ -176,7 +176,7 @@ def alter_column(compiler: OracleDDLCompiler, name: str) -> str: return "MODIFY %s" % format_column_name(compiler, name) -def add_column(compiler: OracleDDLCompiler, column: Column, **kw) -> str: +def add_column(compiler: OracleDDLCompiler, column: Column[Any], **kw) -> str: return "ADD %s" % compiler.get_column_specification(column, **kw) diff --git a/alembic/ddl/postgresql.py b/alembic/ddl/postgresql.py index 6c858e7..e3ada90 100644 --- a/alembic/ddl/postgresql.py +++ b/alembic/ddl/postgresql.py @@ -486,7 +486,7 @@ class CreateExcludeConstraintOp(ops.AddConstraintOp): table_name: Union[str, quoted_name], elements: Union[ Sequence[Tuple[str, str]], - Sequence[Tuple[ColumnClause, str]], + Sequence[Tuple[ColumnClause[Any], str]], ], where: Optional[Union[BinaryExpression, str]] = None, schema: Optional[str] = None, @@ -706,7 +706,9 @@ def _exclude_constraint( def _render_potential_column( - value: Union[ColumnClause, Column, TextClause, FunctionElement], + value: Union[ + ColumnClause[Any], Column[Any], TextClause, FunctionElement[Any] + ], autogen_context: AutogenContext, ) -> str: if isinstance(value, ColumnClause): diff --git a/alembic/ddl/sqlite.py b/alembic/ddl/sqlite.py index 302a877..67a1c28 100644 --- a/alembic/ddl/sqlite.py +++ b/alembic/ddl/sqlite.py @@ -95,8 +95,8 @@ class SQLiteImpl(DefaultImpl): def compare_server_default( self, - inspector_column: Column, - metadata_column: Column, + inspector_column: Column[Any], + metadata_column: Column[Any], rendered_metadata_default: Optional[str], rendered_inspector_default: Optional[str], ) -> bool: @@ -173,7 +173,7 @@ class SQLiteImpl(DefaultImpl): def cast_for_batch_migrate( self, - existing: Column, + existing: Column[Any], existing_transfer: Dict[str, Union[TypeEngine, Cast]], new_type: TypeEngine, ) -> None: diff --git a/alembic/op.pyi b/alembic/op.pyi index 10e6f59..1eb1495 100644 --- a/alembic/op.pyi +++ b/alembic/op.pyi @@ -45,7 +45,7 @@ _T = TypeVar("_T") ### end imports ### def add_column( - table_name: str, column: Column, *, schema: Optional[str] = None + table_name: str, column: Column[Any], *, schema: Optional[str] = None ) -> None: """Issue an "add column" instruction using the current migration context. diff --git a/alembic/operations/base.py b/alembic/operations/base.py index 4e59e5b..fa3fe13 100644 --- a/alembic/operations/base.py +++ b/alembic/operations/base.py @@ -569,7 +569,7 @@ class Operations(AbstractOperations): def add_column( self, table_name: str, - column: Column, + column: Column[Any], *, schema: Optional[str] = None, ) -> None: @@ -1574,7 +1574,7 @@ class BatchOperations(AbstractOperations): def add_column( self, - column: Column, + column: Column[Any], *, insert_before: Optional[str] = None, insert_after: Optional[str] = None, diff --git a/alembic/operations/batch.py b/alembic/operations/batch.py index f4a058b..5b6b547 100644 --- a/alembic/operations/batch.py +++ b/alembic/operations/batch.py @@ -243,7 +243,7 @@ class ApplyBatchImpl: def _grab_table_elements(self) -> None: schema = self.table.schema - self.columns: Dict[str, Column] = OrderedDict() + self.columns: Dict[str, Column[Any]] = OrderedDict() for c in self.table.c: c_copy = _copy(c, schema=schema) c_copy.unique = c_copy.index = False @@ -607,7 +607,7 @@ class ApplyBatchImpl: def add_column( self, table_name: str, - column: Column, + column: Column[Any], insert_before: Optional[str] = None, insert_after: Optional[str] = None, **kw, @@ -621,7 +621,10 @@ class ApplyBatchImpl: self.column_transfers[column.name] = {} def drop_column( - self, table_name: str, column: Union[ColumnClause, Column], **kw + self, + table_name: str, + column: Union[ColumnClause[Any], Column[Any]], + **kw, ) -> None: if column.name in self.table.primary_key.columns: _remove_column_from_collection( diff --git a/alembic/operations/ops.py b/alembic/operations/ops.py index 5334a01..472c0e8 100644 --- a/alembic/operations/ops.py +++ b/alembic/operations/ops.py @@ -1994,7 +1994,7 @@ class AddColumnOp(AlterTableOp): def __init__( self, table_name: str, - column: Column, + column: Column[Any], *, schema: Optional[str] = None, **kw: Any, @@ -2010,7 +2010,7 @@ class AddColumnOp(AlterTableOp): def to_diff_tuple( self, - ) -> Tuple[str, Optional[str], str, Column]: + ) -> Tuple[str, Optional[str], str, Column[Any]]: return ("add_column", self.schema, self.table_name, self.column) def to_column(self) -> Column: @@ -2025,7 +2025,7 @@ class AddColumnOp(AlterTableOp): cls, schema: Optional[str], tname: str, - col: Column, + col: Column[Any], ) -> AddColumnOp: return cls(tname, col, schema=schema) @@ -2034,7 +2034,7 @@ class AddColumnOp(AlterTableOp): cls, operations: Operations, table_name: str, - column: Column, + column: Column[Any], *, schema: Optional[str] = None, ) -> None: @@ -2123,7 +2123,7 @@ class AddColumnOp(AlterTableOp): def batch_add_column( cls, operations: BatchOperations, - column: Column, + column: Column[Any], *, insert_before: Optional[str] = None, insert_after: Optional[str] = None, @@ -2173,7 +2173,7 @@ class DropColumnOp(AlterTableOp): def to_diff_tuple( self, - ) -> Tuple[str, Optional[str], str, Column]: + ) -> Tuple[str, Optional[str], str, Column[Any]]: return ( "remove_column", self.schema, @@ -2197,7 +2197,7 @@ class DropColumnOp(AlterTableOp): cls, schema: Optional[str], tname: str, - col: Column, + col: Column[Any], ) -> DropColumnOp: return cls( tname, diff --git a/alembic/runtime/environment.py b/alembic/runtime/environment.py index 3087377..acd5cd1 100644 --- a/alembic/runtime/environment.py +++ b/alembic/runtime/environment.py @@ -84,8 +84,8 @@ OnVersionApplyFn = Callable[ CompareServerDefault = Callable[ [ MigrationContext, - Column, - Column, + "Column[Any]", + "Column[Any]", Optional[str], Optional[FetchedValue], Optional[str], diff --git a/alembic/runtime/migration.py b/alembic/runtime/migration.py index 8baeaf0..1715e8a 100644 --- a/alembic/runtime/migration.py +++ b/alembic/runtime/migration.py @@ -708,7 +708,7 @@ class MigrationContext: return None def _compare_type( - self, inspector_column: Column, metadata_column: Column + self, inspector_column: Column[Any], metadata_column: Column ) -> bool: if self._user_compare_type is False: return False @@ -728,8 +728,8 @@ class MigrationContext: def _compare_server_default( self, - inspector_column: Column, - metadata_column: Column, + inspector_column: Column[Any], + metadata_column: Column[Any], rendered_metadata_default: Optional[str], rendered_column_default: Optional[str], ) -> bool: diff --git a/alembic/util/sqla_compat.py b/alembic/util/sqla_compat.py index 37e1ee1..376448a 100644 --- a/alembic/util/sqla_compat.py +++ b/alembic/util/sqla_compat.py @@ -46,7 +46,7 @@ if TYPE_CHECKING: from sqlalchemy.sql.selectable import Select from sqlalchemy.sql.selectable import TableClause -_CE = TypeVar("_CE", bound=Union["ColumnElement", "SchemaItem"]) +_CE = TypeVar("_CE", bound=Union["ColumnElement[Any]", "SchemaItem"]) def _safe_int(value: str) -> Union[int, str]: @@ -390,7 +390,7 @@ def _find_columns(clause): def _remove_column_from_collection( - collection: ColumnCollection, column: Union[Column, ColumnClause] + collection: ColumnCollection, column: Union[Column[Any], ColumnClause[Any]] ) -> None: """remove a column from a ColumnCollection.""" @@ -408,8 +408,8 @@ def _remove_column_from_collection( def _textual_index_column( - table: Table, text_: Union[str, TextClause, ColumnElement] -) -> Union[ColumnElement, Column]: + table: Table, text_: Union[str, TextClause, ColumnElement[Any]] +) -> Union[ColumnElement[Any], Column[Any]]: """a workaround for the Index construct's severe lack of flexibility""" if isinstance(text_, str): c = Column(text_, sqltypes.NULLTYPE) -- cgit v1.2.1