summaryrefslogtreecommitdiff
path: root/alembic/ddl/impl.py
diff options
context:
space:
mode:
Diffstat (limited to 'alembic/ddl/impl.py')
-rw-r--r--alembic/ddl/impl.py239
1 files changed, 170 insertions, 69 deletions
diff --git a/alembic/ddl/impl.py b/alembic/ddl/impl.py
index 710509c..2ca316c 100644
--- a/alembic/ddl/impl.py
+++ b/alembic/ddl/impl.py
@@ -1,5 +1,16 @@
from collections import namedtuple
import re
+from typing import Any
+from typing import Callable
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Sequence
+from typing import Set
+from typing import Tuple
+from typing import Type
+from typing import TYPE_CHECKING
+from typing import Union
from sqlalchemy import cast
from sqlalchemy import schema
@@ -11,16 +22,49 @@ from ..util import sqla_compat
from ..util.compat import string_types
from ..util.compat import text_type
+if TYPE_CHECKING:
+ from io import StringIO
+ from typing import Literal
+
+ from sqlalchemy.engine import Connection
+ from sqlalchemy.engine import Dialect
+ from sqlalchemy.engine.cursor import CursorResult
+ from sqlalchemy.engine.cursor import LegacyCursorResult
+ from sqlalchemy.engine.reflection import Inspector
+ from sqlalchemy.sql.dml import Update
+ from sqlalchemy.sql.elements import ClauseElement
+ from sqlalchemy.sql.elements import ColumnElement
+ from sqlalchemy.sql.elements import quoted_name
+ from sqlalchemy.sql.elements import TextClause
+ from sqlalchemy.sql.schema import Column
+ from sqlalchemy.sql.schema import Constraint
+ from sqlalchemy.sql.schema import ForeignKeyConstraint
+ from sqlalchemy.sql.schema import Index
+ from sqlalchemy.sql.schema import Table
+ from sqlalchemy.sql.schema import UniqueConstraint
+ from sqlalchemy.sql.selectable import TableClause
+ from sqlalchemy.sql.type_api import TypeEngine
+
+ from .base import _ServerDefault
+ from ..autogenerate.api import AutogenContext
+ from ..operations.batch import ApplyBatchImpl
+ from ..operations.batch import BatchOperationsImpl
+
class ImplMeta(type):
- def __init__(cls, classname, bases, dict_):
+ def __init__(
+ cls,
+ classname: str,
+ bases: Tuple[Type["DefaultImpl"]],
+ dict_: Dict[str, Any],
+ ):
newtype = type.__init__(cls, classname, bases, dict_)
if "__dialect__" in dict_:
_impls[dict_["__dialect__"]] = cls
return newtype
-_impls = {}
+_impls: dict = {}
Params = namedtuple("Params", ["token0", "tokens", "args", "kwargs"])
@@ -43,27 +87,27 @@ class DefaultImpl(metaclass=ImplMeta):
transactional_ddl = False
command_terminator = ";"
- type_synonyms = ({"NUMERIC", "DECIMAL"},)
- type_arg_extract = ()
+ type_synonyms: Tuple[Set[str], ...] = ({"NUMERIC", "DECIMAL"},)
+ type_arg_extract: Sequence[str] = ()
# on_null is known to be supported only by oracle
- identity_attrs_ignore = ("on_null",)
+ identity_attrs_ignore: Tuple[str, ...] = ("on_null",)
def __init__(
self,
- dialect,
- connection,
- as_sql,
- transactional_ddl,
- output_buffer,
- context_opts,
- ):
+ dialect: "Dialect",
+ connection: Optional["Connection"],
+ as_sql: bool,
+ transactional_ddl: Optional[bool],
+ output_buffer: Optional["StringIO"],
+ context_opts: Dict[str, Any],
+ ) -> None:
self.dialect = dialect
self.connection = connection
self.as_sql = as_sql
self.literal_binds = context_opts.get("literal_binds", False)
self.output_buffer = output_buffer
- self.memo = {}
+ self.memo: dict = {}
self.context_opts = context_opts
if transactional_ddl is not None:
self.transactional_ddl = transactional_ddl
@@ -75,14 +119,17 @@ class DefaultImpl(metaclass=ImplMeta):
)
@classmethod
- def get_by_dialect(cls, dialect):
+ def get_by_dialect(cls, dialect: "Dialect") -> Any:
return _impls[dialect.name]
- def static_output(self, text):
+ def static_output(self, text: str) -> None:
+ assert self.output_buffer is not None
self.output_buffer.write(text_type(text + "\n\n"))
self.output_buffer.flush()
- def requires_recreate_in_batch(self, batch_op):
+ def requires_recreate_in_batch(
+ self, batch_op: "BatchOperationsImpl"
+ ) -> bool:
"""Return True if the given :class:`.BatchOperationsImpl`
would need the table to be recreated and copied in order to
proceed.
@@ -93,7 +140,9 @@ class DefaultImpl(metaclass=ImplMeta):
"""
return False
- def prep_table_for_batch(self, batch_impl, table):
+ def prep_table_for_batch(
+ self, batch_impl: "ApplyBatchImpl", table: "Table"
+ ) -> None:
"""perform any operations needed on a table before a new
one is created to replace it in batch mode.
@@ -103,16 +152,16 @@ class DefaultImpl(metaclass=ImplMeta):
"""
@property
- def bind(self):
+ def bind(self) -> Optional["Connection"]:
return self.connection
def _exec(
self,
- construct,
- execution_options=None,
- multiparams=(),
- params=util.immutabledict(),
- ):
+ construct: Union["ClauseElement", str],
+ execution_options: None = None,
+ multiparams: Sequence[dict] = (),
+ params: Dict[str, int] = util.immutabledict(),
+ ) -> Optional[Union["LegacyCursorResult", "CursorResult"]]:
if isinstance(construct, string_types):
construct = text(construct)
if self.as_sql:
@@ -135,35 +184,43 @@ class DefaultImpl(metaclass=ImplMeta):
.strip()
+ self.command_terminator
)
+ return None
else:
conn = self.connection
+ assert conn is not None
if execution_options:
conn = conn.execution_options(**execution_options)
if params:
+ assert isinstance(multiparams, tuple)
multiparams += (params,)
return conn.execute(construct, multiparams)
- def execute(self, sql, execution_options=None):
+ def execute(
+ self,
+ sql: Union["Update", "TextClause", str],
+ execution_options: None = None,
+ ) -> None:
self._exec(sql, execution_options)
def alter_column(
self,
- table_name,
- column_name,
- nullable=None,
- server_default=False,
- name=None,
- type_=None,
- schema=None,
- autoincrement=None,
- comment=False,
- existing_comment=None,
- existing_type=None,
- existing_server_default=None,
- existing_nullable=None,
- existing_autoincrement=None,
- ):
+ table_name: str,
+ column_name: str,
+ nullable: Optional[bool] = None,
+ server_default: Union["_ServerDefault", "Literal[False]"] = False,
+ name: Optional[str] = None,
+ type_: Optional["TypeEngine"] = None,
+ schema: Optional[str] = None,
+ autoincrement: Optional[bool] = None,
+ comment: Optional[Union[str, "Literal[False]"]] = False,
+ existing_comment: Optional[str] = None,
+ existing_type: Optional["TypeEngine"] = None,
+ existing_server_default: Optional["_ServerDefault"] = None,
+ existing_nullable: Optional[bool] = None,
+ existing_autoincrement: Optional[bool] = None,
+ **kw: Any
+ ) -> None:
if autoincrement is not None or existing_autoincrement is not None:
util.warn(
"autoincrement and existing_autoincrement "
@@ -185,6 +242,13 @@ class DefaultImpl(metaclass=ImplMeta):
)
if server_default is not False:
kw = {}
+ cls_: Type[
+ Union[
+ base.ComputedColumnDefault,
+ base.IdentityColumnDefault,
+ base.ColumnDefault,
+ ]
+ ]
if sqla_compat._server_default_is_computed(
server_default, existing_server_default
):
@@ -200,7 +264,7 @@ class DefaultImpl(metaclass=ImplMeta):
cls_(
table_name,
column_name,
- server_default,
+ server_default, # type:ignore[arg-type]
schema=schema,
existing_type=existing_type,
existing_server_default=existing_server_default,
@@ -251,25 +315,41 @@ class DefaultImpl(metaclass=ImplMeta):
)
)
- def add_column(self, table_name, column, schema=None):
+ def add_column(
+ self,
+ table_name: str,
+ column: "Column",
+ schema: Optional[Union[str, "quoted_name"]] = None,
+ ) -> None:
self._exec(base.AddColumn(table_name, column, schema=schema))
- def drop_column(self, table_name, column, schema=None, **kw):
+ def drop_column(
+ self,
+ table_name: str,
+ column: "Column",
+ schema: Optional[str] = None,
+ **kw
+ ) -> None:
self._exec(base.DropColumn(table_name, column, schema=schema))
- def add_constraint(self, const):
+ def add_constraint(self, const: Any) -> None:
if const._create_rule is None or const._create_rule(self):
self._exec(schema.AddConstraint(const))
- def drop_constraint(self, const):
+ def drop_constraint(self, const: "Constraint") -> None:
self._exec(schema.DropConstraint(const))
- def rename_table(self, old_table_name, new_table_name, schema=None):
+ def rename_table(
+ self,
+ old_table_name: str,
+ new_table_name: Union[str, "quoted_name"],
+ schema: Optional[Union[str, "quoted_name"]] = None,
+ ) -> None:
self._exec(
base.RenameTable(old_table_name, new_table_name, schema=schema)
)
- def create_table(self, table):
+ def create_table(self, table: "Table") -> None:
table.dispatch.before_create(
table, self.connection, checkfirst=False, _ddl_runner=self
)
@@ -292,25 +372,30 @@ class DefaultImpl(metaclass=ImplMeta):
if comment and with_comment:
self.create_column_comment(column)
- def drop_table(self, table):
+ def drop_table(self, table: "Table") -> None:
self._exec(schema.DropTable(table))
- def create_index(self, index):
+ def create_index(self, index: "Index") -> None:
self._exec(schema.CreateIndex(index))
- def create_table_comment(self, table):
+ def create_table_comment(self, table: "Table") -> None:
self._exec(schema.SetTableComment(table))
- def drop_table_comment(self, table):
+ def drop_table_comment(self, table: "Table") -> None:
self._exec(schema.DropTableComment(table))
- def create_column_comment(self, column):
+ def create_column_comment(self, column: "ColumnElement") -> None:
self._exec(schema.SetColumnComment(column))
- def drop_index(self, index):
+ def drop_index(self, index: "Index") -> None:
self._exec(schema.DropIndex(index))
- def bulk_insert(self, table, rows, multiinsert=True):
+ def bulk_insert(
+ self,
+ table: Union["TableClause", "Table"],
+ rows: List[dict],
+ multiinsert: bool = True,
+ ) -> None:
if not isinstance(rows, list):
raise TypeError("List expected")
elif rows and not isinstance(rows[0], dict):
@@ -349,7 +434,7 @@ class DefaultImpl(metaclass=ImplMeta):
sqla_compat._insert_inline(table).values(**row)
)
- def _tokenize_column_type(self, column):
+ def _tokenize_column_type(self, column: "Column") -> Params:
definition = self.dialect.type_compiler.process(column.type).lower()
# tokenize the SQLAlchemy-generated version of a type, so that
@@ -387,7 +472,9 @@ class DefaultImpl(metaclass=ImplMeta):
return params
- def _column_types_match(self, inspector_params, metadata_params):
+ def _column_types_match(
+ self, inspector_params: "Params", metadata_params: "Params"
+ ) -> bool:
if inspector_params.token0 == metadata_params.token0:
return True
@@ -407,7 +494,9 @@ class DefaultImpl(metaclass=ImplMeta):
return True
return False
- def _column_args_match(self, inspected_params, meta_params):
+ def _column_args_match(
+ self, inspected_params: "Params", meta_params: "Params"
+ ) -> bool:
"""We want to compare column parameters. However, we only want
to compare parameters that are set. If they both have `collation`,
we want to make sure they are the same. However, if only one
@@ -438,7 +527,9 @@ class DefaultImpl(metaclass=ImplMeta):
return True
- def compare_type(self, inspector_column, metadata_column):
+ def compare_type(
+ self, inspector_column: "Column", metadata_column: "Column"
+ ) -> bool:
"""Returns True if there ARE differences between the types of the two
columns. Takes impl.type_synonyms into account between retrospected
and metadata types
@@ -463,11 +554,11 @@ class DefaultImpl(metaclass=ImplMeta):
def correct_for_autogen_constraints(
self,
- conn_uniques,
- conn_indexes,
- metadata_unique_constraints,
- metadata_indexes,
- ):
+ conn_uniques: Union[Set["UniqueConstraint"]],
+ conn_indexes: Union[Set["Index"]],
+ metadata_unique_constraints: Set["UniqueConstraint"],
+ metadata_indexes: Set["Index"],
+ ) -> None:
pass
def cast_for_batch_migrate(self, existing, existing_transfer, new_type):
@@ -476,7 +567,9 @@ class DefaultImpl(metaclass=ImplMeta):
existing_transfer["expr"], new_type
)
- def render_ddl_sql_expr(self, expr, is_server_default=False, **kw):
+ def render_ddl_sql_expr(
+ self, expr: "ClauseElement", is_server_default: bool = False, **kw
+ ) -> str:
"""Render a SQL expression that is typically a server default,
index expression, etc.
@@ -489,10 +582,16 @@ class DefaultImpl(metaclass=ImplMeta):
)
return text_type(expr.compile(dialect=self.dialect, **compile_kw))
- def _compat_autogen_column_reflect(self, inspector):
+ def _compat_autogen_column_reflect(
+ self, inspector: "Inspector"
+ ) -> Callable:
return self.autogen_column_reflect
- def correct_for_autogen_foreignkeys(self, conn_fks, metadata_fks):
+ def correct_for_autogen_foreignkeys(
+ self,
+ conn_fks: Set["ForeignKeyConstraint"],
+ metadata_fks: Set["ForeignKeyConstraint"],
+ ) -> None:
pass
def autogen_column_reflect(self, inspector, table, column_info):
@@ -504,7 +603,7 @@ class DefaultImpl(metaclass=ImplMeta):
"""
- def start_migrations(self):
+ def start_migrations(self) -> None:
"""A hook called when :meth:`.EnvironmentContext.run_migrations`
is called.
@@ -512,7 +611,7 @@ class DefaultImpl(metaclass=ImplMeta):
"""
- def emit_begin(self):
+ def emit_begin(self) -> None:
"""Emit the string ``BEGIN``, or the backend-specific
equivalent, on the current connection context.
@@ -522,7 +621,7 @@ class DefaultImpl(metaclass=ImplMeta):
"""
self.static_output("BEGIN" + self.command_terminator)
- def emit_commit(self):
+ def emit_commit(self) -> None:
"""Emit the string ``COMMIT``, or the backend-specific
equivalent, on the current connection context.
@@ -532,7 +631,9 @@ class DefaultImpl(metaclass=ImplMeta):
"""
self.static_output("COMMIT" + self.command_terminator)
- def render_type(self, type_obj, autogen_context):
+ def render_type(
+ self, type_obj: "TypeEngine", autogen_context: "AutogenContext"
+ ) -> Union[str, "Literal[False]"]:
return False
def _compare_identity_default(self, metadata_identity, inspector_identity):