diff options
Diffstat (limited to 'alembic/ddl/impl.py')
-rw-r--r-- | alembic/ddl/impl.py | 239 |
1 files changed, 170 insertions, 69 deletions
diff --git a/alembic/ddl/impl.py b/alembic/ddl/impl.py index 710509c..2ca316c 100644 --- a/alembic/ddl/impl.py +++ b/alembic/ddl/impl.py @@ -1,5 +1,16 @@ from collections import namedtuple import re +from typing import Any +from typing import Callable +from typing import Dict +from typing import List +from typing import Optional +from typing import Sequence +from typing import Set +from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING +from typing import Union from sqlalchemy import cast from sqlalchemy import schema @@ -11,16 +22,49 @@ from ..util import sqla_compat from ..util.compat import string_types from ..util.compat import text_type +if TYPE_CHECKING: + from io import StringIO + from typing import Literal + + from sqlalchemy.engine import Connection + from sqlalchemy.engine import Dialect + from sqlalchemy.engine.cursor import CursorResult + from sqlalchemy.engine.cursor import LegacyCursorResult + from sqlalchemy.engine.reflection import Inspector + from sqlalchemy.sql.dml import Update + from sqlalchemy.sql.elements import ClauseElement + from sqlalchemy.sql.elements import ColumnElement + from sqlalchemy.sql.elements import quoted_name + from sqlalchemy.sql.elements import TextClause + from sqlalchemy.sql.schema import Column + from sqlalchemy.sql.schema import Constraint + from sqlalchemy.sql.schema import ForeignKeyConstraint + from sqlalchemy.sql.schema import Index + from sqlalchemy.sql.schema import Table + from sqlalchemy.sql.schema import UniqueConstraint + from sqlalchemy.sql.selectable import TableClause + from sqlalchemy.sql.type_api import TypeEngine + + from .base import _ServerDefault + from ..autogenerate.api import AutogenContext + from ..operations.batch import ApplyBatchImpl + from ..operations.batch import BatchOperationsImpl + class ImplMeta(type): - def __init__(cls, classname, bases, dict_): + def __init__( + cls, + classname: str, + bases: Tuple[Type["DefaultImpl"]], + dict_: Dict[str, Any], + ): newtype = type.__init__(cls, classname, bases, dict_) if "__dialect__" in dict_: _impls[dict_["__dialect__"]] = cls return newtype -_impls = {} +_impls: dict = {} Params = namedtuple("Params", ["token0", "tokens", "args", "kwargs"]) @@ -43,27 +87,27 @@ class DefaultImpl(metaclass=ImplMeta): transactional_ddl = False command_terminator = ";" - type_synonyms = ({"NUMERIC", "DECIMAL"},) - type_arg_extract = () + type_synonyms: Tuple[Set[str], ...] = ({"NUMERIC", "DECIMAL"},) + type_arg_extract: Sequence[str] = () # on_null is known to be supported only by oracle - identity_attrs_ignore = ("on_null",) + identity_attrs_ignore: Tuple[str, ...] = ("on_null",) def __init__( self, - dialect, - connection, - as_sql, - transactional_ddl, - output_buffer, - context_opts, - ): + dialect: "Dialect", + connection: Optional["Connection"], + as_sql: bool, + transactional_ddl: Optional[bool], + output_buffer: Optional["StringIO"], + context_opts: Dict[str, Any], + ) -> None: self.dialect = dialect self.connection = connection self.as_sql = as_sql self.literal_binds = context_opts.get("literal_binds", False) self.output_buffer = output_buffer - self.memo = {} + self.memo: dict = {} self.context_opts = context_opts if transactional_ddl is not None: self.transactional_ddl = transactional_ddl @@ -75,14 +119,17 @@ class DefaultImpl(metaclass=ImplMeta): ) @classmethod - def get_by_dialect(cls, dialect): + def get_by_dialect(cls, dialect: "Dialect") -> Any: return _impls[dialect.name] - def static_output(self, text): + def static_output(self, text: str) -> None: + assert self.output_buffer is not None self.output_buffer.write(text_type(text + "\n\n")) self.output_buffer.flush() - def requires_recreate_in_batch(self, batch_op): + def requires_recreate_in_batch( + self, batch_op: "BatchOperationsImpl" + ) -> bool: """Return True if the given :class:`.BatchOperationsImpl` would need the table to be recreated and copied in order to proceed. @@ -93,7 +140,9 @@ class DefaultImpl(metaclass=ImplMeta): """ return False - def prep_table_for_batch(self, batch_impl, table): + def prep_table_for_batch( + self, batch_impl: "ApplyBatchImpl", table: "Table" + ) -> None: """perform any operations needed on a table before a new one is created to replace it in batch mode. @@ -103,16 +152,16 @@ class DefaultImpl(metaclass=ImplMeta): """ @property - def bind(self): + def bind(self) -> Optional["Connection"]: return self.connection def _exec( self, - construct, - execution_options=None, - multiparams=(), - params=util.immutabledict(), - ): + construct: Union["ClauseElement", str], + execution_options: None = None, + multiparams: Sequence[dict] = (), + params: Dict[str, int] = util.immutabledict(), + ) -> Optional[Union["LegacyCursorResult", "CursorResult"]]: if isinstance(construct, string_types): construct = text(construct) if self.as_sql: @@ -135,35 +184,43 @@ class DefaultImpl(metaclass=ImplMeta): .strip() + self.command_terminator ) + return None else: conn = self.connection + assert conn is not None if execution_options: conn = conn.execution_options(**execution_options) if params: + assert isinstance(multiparams, tuple) multiparams += (params,) return conn.execute(construct, multiparams) - def execute(self, sql, execution_options=None): + def execute( + self, + sql: Union["Update", "TextClause", str], + execution_options: None = None, + ) -> None: self._exec(sql, execution_options) def alter_column( self, - table_name, - column_name, - nullable=None, - server_default=False, - name=None, - type_=None, - schema=None, - autoincrement=None, - comment=False, - existing_comment=None, - existing_type=None, - existing_server_default=None, - existing_nullable=None, - existing_autoincrement=None, - ): + table_name: str, + column_name: str, + nullable: Optional[bool] = None, + server_default: Union["_ServerDefault", "Literal[False]"] = False, + name: Optional[str] = None, + type_: Optional["TypeEngine"] = None, + schema: Optional[str] = None, + autoincrement: Optional[bool] = None, + comment: Optional[Union[str, "Literal[False]"]] = False, + existing_comment: Optional[str] = None, + existing_type: Optional["TypeEngine"] = None, + existing_server_default: Optional["_ServerDefault"] = None, + existing_nullable: Optional[bool] = None, + existing_autoincrement: Optional[bool] = None, + **kw: Any + ) -> None: if autoincrement is not None or existing_autoincrement is not None: util.warn( "autoincrement and existing_autoincrement " @@ -185,6 +242,13 @@ class DefaultImpl(metaclass=ImplMeta): ) if server_default is not False: kw = {} + cls_: Type[ + Union[ + base.ComputedColumnDefault, + base.IdentityColumnDefault, + base.ColumnDefault, + ] + ] if sqla_compat._server_default_is_computed( server_default, existing_server_default ): @@ -200,7 +264,7 @@ class DefaultImpl(metaclass=ImplMeta): cls_( table_name, column_name, - server_default, + server_default, # type:ignore[arg-type] schema=schema, existing_type=existing_type, existing_server_default=existing_server_default, @@ -251,25 +315,41 @@ class DefaultImpl(metaclass=ImplMeta): ) ) - def add_column(self, table_name, column, schema=None): + def add_column( + self, + table_name: str, + column: "Column", + schema: Optional[Union[str, "quoted_name"]] = None, + ) -> None: self._exec(base.AddColumn(table_name, column, schema=schema)) - def drop_column(self, table_name, column, schema=None, **kw): + def drop_column( + self, + table_name: str, + column: "Column", + schema: Optional[str] = None, + **kw + ) -> None: self._exec(base.DropColumn(table_name, column, schema=schema)) - def add_constraint(self, const): + def add_constraint(self, const: Any) -> None: if const._create_rule is None or const._create_rule(self): self._exec(schema.AddConstraint(const)) - def drop_constraint(self, const): + def drop_constraint(self, const: "Constraint") -> None: self._exec(schema.DropConstraint(const)) - def rename_table(self, old_table_name, new_table_name, schema=None): + def rename_table( + self, + old_table_name: str, + new_table_name: Union[str, "quoted_name"], + schema: Optional[Union[str, "quoted_name"]] = None, + ) -> None: self._exec( base.RenameTable(old_table_name, new_table_name, schema=schema) ) - def create_table(self, table): + def create_table(self, table: "Table") -> None: table.dispatch.before_create( table, self.connection, checkfirst=False, _ddl_runner=self ) @@ -292,25 +372,30 @@ class DefaultImpl(metaclass=ImplMeta): if comment and with_comment: self.create_column_comment(column) - def drop_table(self, table): + def drop_table(self, table: "Table") -> None: self._exec(schema.DropTable(table)) - def create_index(self, index): + def create_index(self, index: "Index") -> None: self._exec(schema.CreateIndex(index)) - def create_table_comment(self, table): + def create_table_comment(self, table: "Table") -> None: self._exec(schema.SetTableComment(table)) - def drop_table_comment(self, table): + def drop_table_comment(self, table: "Table") -> None: self._exec(schema.DropTableComment(table)) - def create_column_comment(self, column): + def create_column_comment(self, column: "ColumnElement") -> None: self._exec(schema.SetColumnComment(column)) - def drop_index(self, index): + def drop_index(self, index: "Index") -> None: self._exec(schema.DropIndex(index)) - def bulk_insert(self, table, rows, multiinsert=True): + def bulk_insert( + self, + table: Union["TableClause", "Table"], + rows: List[dict], + multiinsert: bool = True, + ) -> None: if not isinstance(rows, list): raise TypeError("List expected") elif rows and not isinstance(rows[0], dict): @@ -349,7 +434,7 @@ class DefaultImpl(metaclass=ImplMeta): sqla_compat._insert_inline(table).values(**row) ) - def _tokenize_column_type(self, column): + def _tokenize_column_type(self, column: "Column") -> Params: definition = self.dialect.type_compiler.process(column.type).lower() # tokenize the SQLAlchemy-generated version of a type, so that @@ -387,7 +472,9 @@ class DefaultImpl(metaclass=ImplMeta): return params - def _column_types_match(self, inspector_params, metadata_params): + def _column_types_match( + self, inspector_params: "Params", metadata_params: "Params" + ) -> bool: if inspector_params.token0 == metadata_params.token0: return True @@ -407,7 +494,9 @@ class DefaultImpl(metaclass=ImplMeta): return True return False - def _column_args_match(self, inspected_params, meta_params): + def _column_args_match( + self, inspected_params: "Params", meta_params: "Params" + ) -> bool: """We want to compare column parameters. However, we only want to compare parameters that are set. If they both have `collation`, we want to make sure they are the same. However, if only one @@ -438,7 +527,9 @@ class DefaultImpl(metaclass=ImplMeta): return True - def compare_type(self, inspector_column, metadata_column): + def compare_type( + self, inspector_column: "Column", metadata_column: "Column" + ) -> bool: """Returns True if there ARE differences between the types of the two columns. Takes impl.type_synonyms into account between retrospected and metadata types @@ -463,11 +554,11 @@ class DefaultImpl(metaclass=ImplMeta): def correct_for_autogen_constraints( self, - conn_uniques, - conn_indexes, - metadata_unique_constraints, - metadata_indexes, - ): + conn_uniques: Union[Set["UniqueConstraint"]], + conn_indexes: Union[Set["Index"]], + metadata_unique_constraints: Set["UniqueConstraint"], + metadata_indexes: Set["Index"], + ) -> None: pass def cast_for_batch_migrate(self, existing, existing_transfer, new_type): @@ -476,7 +567,9 @@ class DefaultImpl(metaclass=ImplMeta): existing_transfer["expr"], new_type ) - def render_ddl_sql_expr(self, expr, is_server_default=False, **kw): + def render_ddl_sql_expr( + self, expr: "ClauseElement", is_server_default: bool = False, **kw + ) -> str: """Render a SQL expression that is typically a server default, index expression, etc. @@ -489,10 +582,16 @@ class DefaultImpl(metaclass=ImplMeta): ) return text_type(expr.compile(dialect=self.dialect, **compile_kw)) - def _compat_autogen_column_reflect(self, inspector): + def _compat_autogen_column_reflect( + self, inspector: "Inspector" + ) -> Callable: return self.autogen_column_reflect - def correct_for_autogen_foreignkeys(self, conn_fks, metadata_fks): + def correct_for_autogen_foreignkeys( + self, + conn_fks: Set["ForeignKeyConstraint"], + metadata_fks: Set["ForeignKeyConstraint"], + ) -> None: pass def autogen_column_reflect(self, inspector, table, column_info): @@ -504,7 +603,7 @@ class DefaultImpl(metaclass=ImplMeta): """ - def start_migrations(self): + def start_migrations(self) -> None: """A hook called when :meth:`.EnvironmentContext.run_migrations` is called. @@ -512,7 +611,7 @@ class DefaultImpl(metaclass=ImplMeta): """ - def emit_begin(self): + def emit_begin(self) -> None: """Emit the string ``BEGIN``, or the backend-specific equivalent, on the current connection context. @@ -522,7 +621,7 @@ class DefaultImpl(metaclass=ImplMeta): """ self.static_output("BEGIN" + self.command_terminator) - def emit_commit(self): + def emit_commit(self) -> None: """Emit the string ``COMMIT``, or the backend-specific equivalent, on the current connection context. @@ -532,7 +631,9 @@ class DefaultImpl(metaclass=ImplMeta): """ self.static_output("COMMIT" + self.command_terminator) - def render_type(self, type_obj, autogen_context): + def render_type( + self, type_obj: "TypeEngine", autogen_context: "AutogenContext" + ) -> Union[str, "Literal[False]"]: return False def _compare_identity_default(self, metadata_identity, inspector_identity): |