summaryrefslogtreecommitdiff
path: root/alembic/ddl
diff options
context:
space:
mode:
authorCaselIT <cfederico87@gmail.com>2021-04-18 15:44:50 +0200
committerMike Bayer <mike_mp@zzzcomputing.com>2021-08-11 15:04:56 -0400
commit6aad68605f510e8b51f42efa812e02b3831d6e33 (patch)
treecc0e98b8ad8245add8692d8e4910faf57abf7ae3 /alembic/ddl
parent3bf6a326c0a11e4f05c94008709d6b0b8e9e051a (diff)
downloadalembic-6aad68605f510e8b51f42efa812e02b3831d6e33.tar.gz
Add pep-484 type annotations
pep-484 type annotations have been added throughout the library. This should be helpful in providing Mypy and IDE support, however there is not full support for Alembic's dynamically modified "op" namespace as of yet; a future release will likely modify the approach used for importing this namespace to be better compatible with pep-484 capabilities. Type originally created using MonkeyType Add types extracted with the MonkeyType https://github.com/instagram/MonkeyType library by running the unit tests using ``monkeytype run -m pytest tests``, then ``monkeytype apply <module>`` (see below for further details). USed MonkeyType version 20.5 on Python 3.8, since newer version have issues After applying the types, the new imports are placed in a ``TYPE_CHECKING`` guard and all type definition of non base types are deferred by using the string notation. NOTE: since to apply the types MonkeType need to import the module, also the test ones, the patch below mocks the setup done by pytest so that the tests could be correctly imported diff --git a/alembic/testing/__init__.py b/alembic/testing/__init__.py index bdd1746..b1090c7 100644 Change-Id: Iff93628f4b43c740848871ce077a118db5e75d41 --- a/alembic/testing/__init__.py +++ b/alembic/testing/__init__.py @@ -9,6 +9,12 @@ from sqlalchemy.testing.config import combinations from sqlalchemy.testing.config import fixture from sqlalchemy.testing.config import requirements as requires +from sqlalchemy.testing.plugin.pytestplugin import PytestFixtureFunctions +from sqlalchemy.testing.plugin.plugin_base import _setup_requirements + +config._fixture_functions = PytestFixtureFunctions() +_setup_requirements("tests.requirements:DefaultRequirements") + from alembic import util from .assertions import assert_raises from .assertions import assert_raises_message Currently I'm using this branch of the sqlalchemy stubs: https://github.com/sqlalchemy/sqlalchemy2-stubs/tree/alembic_updates Change-Id: I8fd0700aab1913f395302626b8b84fea60334abd
Diffstat (limited to 'alembic/ddl')
-rw-r--r--alembic/ddl/base.py163
-rw-r--r--alembic/ddl/impl.py239
-rw-r--r--alembic/ddl/mssql.py123
-rw-r--r--alembic/ddl/mysql.py118
-rw-r--r--alembic/ddl/oracle.py57
-rw-r--r--alembic/ddl/postgresql.py202
-rw-r--r--alembic/ddl/sqlite.py65
7 files changed, 703 insertions, 264 deletions
diff --git a/alembic/ddl/base.py b/alembic/ddl/base.py
index da81c72..022dc24 100644
--- a/alembic/ddl/base.py
+++ b/alembic/ddl/base.py
@@ -1,4 +1,7 @@
import functools
+from typing import Optional
+from typing import TYPE_CHECKING
+from typing import Union
from sqlalchemy import exc
from sqlalchemy import Integer
@@ -14,6 +17,20 @@ from ..util.sqla_compat import _fk_spec # noqa
from ..util.sqla_compat import _is_type_bound # noqa
from ..util.sqla_compat import _table_for_constraint # noqa
+if TYPE_CHECKING:
+ from sqlalchemy.sql.compiler import Compiled
+ from sqlalchemy.sql.compiler import DDLCompiler
+ from sqlalchemy.sql.elements import TextClause
+ from sqlalchemy.sql.functions import Function
+ from sqlalchemy.sql.schema import FetchedValue
+ from sqlalchemy.sql.type_api import TypeEngine
+
+ from .impl import DefaultImpl
+ from ..util.sqla_compat import Computed
+ from ..util.sqla_compat import Identity
+
+_ServerDefault = Union["TextClause", "FetchedValue", "Function", str]
+
class AlterTable(DDLElement):
@@ -24,13 +41,22 @@ class AlterTable(DDLElement):
"""
- def __init__(self, table_name, schema=None):
+ def __init__(
+ self,
+ table_name: str,
+ schema: Optional[Union["quoted_name", str]] = None,
+ ) -> None:
self.table_name = table_name
self.schema = schema
class RenameTable(AlterTable):
- def __init__(self, old_table_name, new_table_name, schema=None):
+ def __init__(
+ self,
+ old_table_name: str,
+ new_table_name: Union["quoted_name", str],
+ schema: Optional[Union["quoted_name", str]] = None,
+ ) -> None:
super(RenameTable, self).__init__(old_table_name, schema=schema)
self.new_table_name = new_table_name
@@ -38,14 +64,14 @@ class RenameTable(AlterTable):
class AlterColumn(AlterTable):
def __init__(
self,
- name,
- column_name,
- schema=None,
- existing_type=None,
- existing_nullable=None,
- existing_server_default=None,
- existing_comment=None,
- ):
+ name: str,
+ column_name: str,
+ schema: Optional[str] = None,
+ existing_type: Optional["TypeEngine"] = None,
+ existing_nullable: Optional[bool] = None,
+ existing_server_default: Optional[_ServerDefault] = None,
+ existing_comment: Optional[str] = None,
+ ) -> None:
super(AlterColumn, self).__init__(name, schema=schema)
self.column_name = column_name
self.existing_type = (
@@ -59,62 +85,94 @@ class AlterColumn(AlterTable):
class ColumnNullable(AlterColumn):
- def __init__(self, name, column_name, nullable, **kw):
+ def __init__(
+ self, name: str, column_name: str, nullable: bool, **kw
+ ) -> None:
super(ColumnNullable, self).__init__(name, column_name, **kw)
self.nullable = nullable
class ColumnType(AlterColumn):
- def __init__(self, name, column_name, type_, **kw):
+ def __init__(
+ self, name: str, column_name: str, type_: "TypeEngine", **kw
+ ) -> None:
super(ColumnType, self).__init__(name, column_name, **kw)
self.type_ = sqltypes.to_instance(type_)
class ColumnName(AlterColumn):
- def __init__(self, name, column_name, newname, **kw):
+ def __init__(
+ self, name: str, column_name: str, newname: str, **kw
+ ) -> None:
super(ColumnName, self).__init__(name, column_name, **kw)
self.newname = newname
class ColumnDefault(AlterColumn):
- def __init__(self, name, column_name, default, **kw):
+ def __init__(
+ self,
+ name: str,
+ column_name: str,
+ default: Optional[_ServerDefault],
+ **kw
+ ) -> None:
super(ColumnDefault, self).__init__(name, column_name, **kw)
self.default = default
class ComputedColumnDefault(AlterColumn):
- def __init__(self, name, column_name, default, **kw):
+ def __init__(
+ self, name: str, column_name: str, default: Optional["Computed"], **kw
+ ) -> None:
super(ComputedColumnDefault, self).__init__(name, column_name, **kw)
self.default = default
class IdentityColumnDefault(AlterColumn):
- def __init__(self, name, column_name, default, impl, **kw):
+ def __init__(
+ self,
+ name: str,
+ column_name: str,
+ default: Optional["Identity"],
+ impl: "DefaultImpl",
+ **kw
+ ) -> None:
super(IdentityColumnDefault, self).__init__(name, column_name, **kw)
self.default = default
self.impl = impl
class AddColumn(AlterTable):
- def __init__(self, name, column, schema=None):
+ def __init__(
+ self,
+ name: str,
+ column: "Column",
+ schema: Optional[Union["quoted_name", str]] = None,
+ ) -> None:
super(AddColumn, self).__init__(name, schema=schema)
self.column = column
class DropColumn(AlterTable):
- def __init__(self, name, column, schema=None):
+ def __init__(
+ self, name: str, column: "Column", schema: Optional[str] = None
+ ) -> None:
super(DropColumn, self).__init__(name, schema=schema)
self.column = column
class ColumnComment(AlterColumn):
- def __init__(self, name, column_name, comment, **kw):
+ def __init__(
+ self, name: str, column_name: str, comment: Optional[str], **kw
+ ) -> None:
super(ColumnComment, self).__init__(name, column_name, **kw)
self.comment = comment
@compiles(RenameTable)
-def visit_rename_table(element, compiler, **kw):
+def visit_rename_table(
+ element: "RenameTable", compiler: "DDLCompiler", **kw
+) -> str:
return "%s RENAME TO %s" % (
alter_table(compiler, element.table_name, element.schema),
format_table_name(compiler, element.new_table_name, element.schema),
@@ -122,7 +180,9 @@ def visit_rename_table(element, compiler, **kw):
@compiles(AddColumn)
-def visit_add_column(element, compiler, **kw):
+def visit_add_column(
+ element: "AddColumn", compiler: "DDLCompiler", **kw
+) -> str:
return "%s %s" % (
alter_table(compiler, element.table_name, element.schema),
add_column(compiler, element.column, **kw),
@@ -130,7 +190,9 @@ def visit_add_column(element, compiler, **kw):
@compiles(DropColumn)
-def visit_drop_column(element, compiler, **kw):
+def visit_drop_column(
+ element: "DropColumn", compiler: "DDLCompiler", **kw
+) -> str:
return "%s %s" % (
alter_table(compiler, element.table_name, element.schema),
drop_column(compiler, element.column.name, **kw),
@@ -138,7 +200,9 @@ def visit_drop_column(element, compiler, **kw):
@compiles(ColumnNullable)
-def visit_column_nullable(element, compiler, **kw):
+def visit_column_nullable(
+ element: "ColumnNullable", compiler: "DDLCompiler", **kw
+) -> str:
return "%s %s %s" % (
alter_table(compiler, element.table_name, element.schema),
alter_column(compiler, element.column_name),
@@ -147,7 +211,9 @@ def visit_column_nullable(element, compiler, **kw):
@compiles(ColumnType)
-def visit_column_type(element, compiler, **kw):
+def visit_column_type(
+ element: "ColumnType", compiler: "DDLCompiler", **kw
+) -> str:
return "%s %s %s" % (
alter_table(compiler, element.table_name, element.schema),
alter_column(compiler, element.column_name),
@@ -156,7 +222,9 @@ def visit_column_type(element, compiler, **kw):
@compiles(ColumnName)
-def visit_column_name(element, compiler, **kw):
+def visit_column_name(
+ element: "ColumnName", compiler: "DDLCompiler", **kw
+) -> str:
return "%s RENAME %s TO %s" % (
alter_table(compiler, element.table_name, element.schema),
format_column_name(compiler, element.column_name),
@@ -165,7 +233,9 @@ def visit_column_name(element, compiler, **kw):
@compiles(ColumnDefault)
-def visit_column_default(element, compiler, **kw):
+def visit_column_default(
+ element: "ColumnDefault", compiler: "DDLCompiler", **kw
+) -> str:
return "%s %s %s" % (
alter_table(compiler, element.table_name, element.schema),
alter_column(compiler, element.column_name),
@@ -176,7 +246,9 @@ def visit_column_default(element, compiler, **kw):
@compiles(ComputedColumnDefault)
-def visit_computed_column(element, compiler, **kw):
+def visit_computed_column(
+ element: "ComputedColumnDefault", compiler: "DDLCompiler", **kw
+):
raise exc.CompileError(
'Adding or removing a "computed" construct, e.g. GENERATED '
"ALWAYS AS, to or from an existing column is not supported."
@@ -184,7 +256,9 @@ def visit_computed_column(element, compiler, **kw):
@compiles(IdentityColumnDefault)
-def visit_identity_column(element, compiler, **kw):
+def visit_identity_column(
+ element: "IdentityColumnDefault", compiler: "DDLCompiler", **kw
+):
raise exc.CompileError(
'Adding, removing or modifying an "identity" construct, '
"e.g. GENERATED AS IDENTITY, to or from an existing "
@@ -192,7 +266,9 @@ def visit_identity_column(element, compiler, **kw):
)
-def quote_dotted(name, quote):
+def quote_dotted(
+ name: Union["quoted_name", str], quote: functools.partial
+) -> Union["quoted_name", str]:
"""quote the elements of a dotted name"""
if isinstance(name, quoted_name):
@@ -201,7 +277,11 @@ def quote_dotted(name, quote):
return result
-def format_table_name(compiler, name, schema):
+def format_table_name(
+ compiler: "Compiled",
+ name: Union["quoted_name", str],
+ schema: Optional[Union["quoted_name", str]],
+) -> Union["quoted_name", str]:
quote = functools.partial(compiler.preparer.quote)
if schema:
return quote_dotted(schema, quote) + "." + quote(name)
@@ -209,33 +289,42 @@ def format_table_name(compiler, name, schema):
return quote(name)
-def format_column_name(compiler, name):
+def format_column_name(
+ compiler: "DDLCompiler", name: Optional[Union["quoted_name", str]]
+) -> Union["quoted_name", str]:
return compiler.preparer.quote(name)
-def format_server_default(compiler, default):
+def format_server_default(
+ compiler: "DDLCompiler",
+ default: Optional[_ServerDefault],
+) -> str:
return compiler.get_column_default_string(
Column("x", Integer, server_default=default)
)
-def format_type(compiler, type_):
+def format_type(compiler: "DDLCompiler", type_: "TypeEngine") -> str:
return compiler.dialect.type_compiler.process(type_)
-def alter_table(compiler, name, schema):
+def alter_table(
+ compiler: "DDLCompiler",
+ name: str,
+ schema: Optional[str],
+) -> str:
return "ALTER TABLE %s" % format_table_name(compiler, name, schema)
-def drop_column(compiler, name):
+def drop_column(compiler: "DDLCompiler", name: str, **kw) -> str:
return "DROP COLUMN %s" % format_column_name(compiler, name)
-def alter_column(compiler, name):
+def alter_column(compiler: "DDLCompiler", name: str) -> str:
return "ALTER COLUMN %s" % format_column_name(compiler, name)
-def add_column(compiler, column, **kw):
+def add_column(compiler: "DDLCompiler", column: "Column", **kw) -> str:
text = "ADD COLUMN %s" % compiler.get_column_specification(column, **kw)
const = " ".join(
diff --git a/alembic/ddl/impl.py b/alembic/ddl/impl.py
index 710509c..2ca316c 100644
--- a/alembic/ddl/impl.py
+++ b/alembic/ddl/impl.py
@@ -1,5 +1,16 @@
from collections import namedtuple
import re
+from typing import Any
+from typing import Callable
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Sequence
+from typing import Set
+from typing import Tuple
+from typing import Type
+from typing import TYPE_CHECKING
+from typing import Union
from sqlalchemy import cast
from sqlalchemy import schema
@@ -11,16 +22,49 @@ from ..util import sqla_compat
from ..util.compat import string_types
from ..util.compat import text_type
+if TYPE_CHECKING:
+ from io import StringIO
+ from typing import Literal
+
+ from sqlalchemy.engine import Connection
+ from sqlalchemy.engine import Dialect
+ from sqlalchemy.engine.cursor import CursorResult
+ from sqlalchemy.engine.cursor import LegacyCursorResult
+ from sqlalchemy.engine.reflection import Inspector
+ from sqlalchemy.sql.dml import Update
+ from sqlalchemy.sql.elements import ClauseElement
+ from sqlalchemy.sql.elements import ColumnElement
+ from sqlalchemy.sql.elements import quoted_name
+ from sqlalchemy.sql.elements import TextClause
+ from sqlalchemy.sql.schema import Column
+ from sqlalchemy.sql.schema import Constraint
+ from sqlalchemy.sql.schema import ForeignKeyConstraint
+ from sqlalchemy.sql.schema import Index
+ from sqlalchemy.sql.schema import Table
+ from sqlalchemy.sql.schema import UniqueConstraint
+ from sqlalchemy.sql.selectable import TableClause
+ from sqlalchemy.sql.type_api import TypeEngine
+
+ from .base import _ServerDefault
+ from ..autogenerate.api import AutogenContext
+ from ..operations.batch import ApplyBatchImpl
+ from ..operations.batch import BatchOperationsImpl
+
class ImplMeta(type):
- def __init__(cls, classname, bases, dict_):
+ def __init__(
+ cls,
+ classname: str,
+ bases: Tuple[Type["DefaultImpl"]],
+ dict_: Dict[str, Any],
+ ):
newtype = type.__init__(cls, classname, bases, dict_)
if "__dialect__" in dict_:
_impls[dict_["__dialect__"]] = cls
return newtype
-_impls = {}
+_impls: dict = {}
Params = namedtuple("Params", ["token0", "tokens", "args", "kwargs"])
@@ -43,27 +87,27 @@ class DefaultImpl(metaclass=ImplMeta):
transactional_ddl = False
command_terminator = ";"
- type_synonyms = ({"NUMERIC", "DECIMAL"},)
- type_arg_extract = ()
+ type_synonyms: Tuple[Set[str], ...] = ({"NUMERIC", "DECIMAL"},)
+ type_arg_extract: Sequence[str] = ()
# on_null is known to be supported only by oracle
- identity_attrs_ignore = ("on_null",)
+ identity_attrs_ignore: Tuple[str, ...] = ("on_null",)
def __init__(
self,
- dialect,
- connection,
- as_sql,
- transactional_ddl,
- output_buffer,
- context_opts,
- ):
+ dialect: "Dialect",
+ connection: Optional["Connection"],
+ as_sql: bool,
+ transactional_ddl: Optional[bool],
+ output_buffer: Optional["StringIO"],
+ context_opts: Dict[str, Any],
+ ) -> None:
self.dialect = dialect
self.connection = connection
self.as_sql = as_sql
self.literal_binds = context_opts.get("literal_binds", False)
self.output_buffer = output_buffer
- self.memo = {}
+ self.memo: dict = {}
self.context_opts = context_opts
if transactional_ddl is not None:
self.transactional_ddl = transactional_ddl
@@ -75,14 +119,17 @@ class DefaultImpl(metaclass=ImplMeta):
)
@classmethod
- def get_by_dialect(cls, dialect):
+ def get_by_dialect(cls, dialect: "Dialect") -> Any:
return _impls[dialect.name]
- def static_output(self, text):
+ def static_output(self, text: str) -> None:
+ assert self.output_buffer is not None
self.output_buffer.write(text_type(text + "\n\n"))
self.output_buffer.flush()
- def requires_recreate_in_batch(self, batch_op):
+ def requires_recreate_in_batch(
+ self, batch_op: "BatchOperationsImpl"
+ ) -> bool:
"""Return True if the given :class:`.BatchOperationsImpl`
would need the table to be recreated and copied in order to
proceed.
@@ -93,7 +140,9 @@ class DefaultImpl(metaclass=ImplMeta):
"""
return False
- def prep_table_for_batch(self, batch_impl, table):
+ def prep_table_for_batch(
+ self, batch_impl: "ApplyBatchImpl", table: "Table"
+ ) -> None:
"""perform any operations needed on a table before a new
one is created to replace it in batch mode.
@@ -103,16 +152,16 @@ class DefaultImpl(metaclass=ImplMeta):
"""
@property
- def bind(self):
+ def bind(self) -> Optional["Connection"]:
return self.connection
def _exec(
self,
- construct,
- execution_options=None,
- multiparams=(),
- params=util.immutabledict(),
- ):
+ construct: Union["ClauseElement", str],
+ execution_options: None = None,
+ multiparams: Sequence[dict] = (),
+ params: Dict[str, int] = util.immutabledict(),
+ ) -> Optional[Union["LegacyCursorResult", "CursorResult"]]:
if isinstance(construct, string_types):
construct = text(construct)
if self.as_sql:
@@ -135,35 +184,43 @@ class DefaultImpl(metaclass=ImplMeta):
.strip()
+ self.command_terminator
)
+ return None
else:
conn = self.connection
+ assert conn is not None
if execution_options:
conn = conn.execution_options(**execution_options)
if params:
+ assert isinstance(multiparams, tuple)
multiparams += (params,)
return conn.execute(construct, multiparams)
- def execute(self, sql, execution_options=None):
+ def execute(
+ self,
+ sql: Union["Update", "TextClause", str],
+ execution_options: None = None,
+ ) -> None:
self._exec(sql, execution_options)
def alter_column(
self,
- table_name,
- column_name,
- nullable=None,
- server_default=False,
- name=None,
- type_=None,
- schema=None,
- autoincrement=None,
- comment=False,
- existing_comment=None,
- existing_type=None,
- existing_server_default=None,
- existing_nullable=None,
- existing_autoincrement=None,
- ):
+ table_name: str,
+ column_name: str,
+ nullable: Optional[bool] = None,
+ server_default: Union["_ServerDefault", "Literal[False]"] = False,
+ name: Optional[str] = None,
+ type_: Optional["TypeEngine"] = None,
+ schema: Optional[str] = None,
+ autoincrement: Optional[bool] = None,
+ comment: Optional[Union[str, "Literal[False]"]] = False,
+ existing_comment: Optional[str] = None,
+ existing_type: Optional["TypeEngine"] = None,
+ existing_server_default: Optional["_ServerDefault"] = None,
+ existing_nullable: Optional[bool] = None,
+ existing_autoincrement: Optional[bool] = None,
+ **kw: Any
+ ) -> None:
if autoincrement is not None or existing_autoincrement is not None:
util.warn(
"autoincrement and existing_autoincrement "
@@ -185,6 +242,13 @@ class DefaultImpl(metaclass=ImplMeta):
)
if server_default is not False:
kw = {}
+ cls_: Type[
+ Union[
+ base.ComputedColumnDefault,
+ base.IdentityColumnDefault,
+ base.ColumnDefault,
+ ]
+ ]
if sqla_compat._server_default_is_computed(
server_default, existing_server_default
):
@@ -200,7 +264,7 @@ class DefaultImpl(metaclass=ImplMeta):
cls_(
table_name,
column_name,
- server_default,
+ server_default, # type:ignore[arg-type]
schema=schema,
existing_type=existing_type,
existing_server_default=existing_server_default,
@@ -251,25 +315,41 @@ class DefaultImpl(metaclass=ImplMeta):
)
)
- def add_column(self, table_name, column, schema=None):
+ def add_column(
+ self,
+ table_name: str,
+ column: "Column",
+ schema: Optional[Union[str, "quoted_name"]] = None,
+ ) -> None:
self._exec(base.AddColumn(table_name, column, schema=schema))
- def drop_column(self, table_name, column, schema=None, **kw):
+ def drop_column(
+ self,
+ table_name: str,
+ column: "Column",
+ schema: Optional[str] = None,
+ **kw
+ ) -> None:
self._exec(base.DropColumn(table_name, column, schema=schema))
- def add_constraint(self, const):
+ def add_constraint(self, const: Any) -> None:
if const._create_rule is None or const._create_rule(self):
self._exec(schema.AddConstraint(const))
- def drop_constraint(self, const):
+ def drop_constraint(self, const: "Constraint") -> None:
self._exec(schema.DropConstraint(const))
- def rename_table(self, old_table_name, new_table_name, schema=None):
+ def rename_table(
+ self,
+ old_table_name: str,
+ new_table_name: Union[str, "quoted_name"],
+ schema: Optional[Union[str, "quoted_name"]] = None,
+ ) -> None:
self._exec(
base.RenameTable(old_table_name, new_table_name, schema=schema)
)
- def create_table(self, table):
+ def create_table(self, table: "Table") -> None:
table.dispatch.before_create(
table, self.connection, checkfirst=False, _ddl_runner=self
)
@@ -292,25 +372,30 @@ class DefaultImpl(metaclass=ImplMeta):
if comment and with_comment:
self.create_column_comment(column)
- def drop_table(self, table):
+ def drop_table(self, table: "Table") -> None:
self._exec(schema.DropTable(table))
- def create_index(self, index):
+ def create_index(self, index: "Index") -> None:
self._exec(schema.CreateIndex(index))
- def create_table_comment(self, table):
+ def create_table_comment(self, table: "Table") -> None:
self._exec(schema.SetTableComment(table))
- def drop_table_comment(self, table):
+ def drop_table_comment(self, table: "Table") -> None:
self._exec(schema.DropTableComment(table))
- def create_column_comment(self, column):
+ def create_column_comment(self, column: "ColumnElement") -> None:
self._exec(schema.SetColumnComment(column))
- def drop_index(self, index):
+ def drop_index(self, index: "Index") -> None:
self._exec(schema.DropIndex(index))
- def bulk_insert(self, table, rows, multiinsert=True):
+ def bulk_insert(
+ self,
+ table: Union["TableClause", "Table"],
+ rows: List[dict],
+ multiinsert: bool = True,
+ ) -> None:
if not isinstance(rows, list):
raise TypeError("List expected")
elif rows and not isinstance(rows[0], dict):
@@ -349,7 +434,7 @@ class DefaultImpl(metaclass=ImplMeta):
sqla_compat._insert_inline(table).values(**row)
)
- def _tokenize_column_type(self, column):
+ def _tokenize_column_type(self, column: "Column") -> Params:
definition = self.dialect.type_compiler.process(column.type).lower()
# tokenize the SQLAlchemy-generated version of a type, so that
@@ -387,7 +472,9 @@ class DefaultImpl(metaclass=ImplMeta):
return params
- def _column_types_match(self, inspector_params, metadata_params):
+ def _column_types_match(
+ self, inspector_params: "Params", metadata_params: "Params"
+ ) -> bool:
if inspector_params.token0 == metadata_params.token0:
return True
@@ -407,7 +494,9 @@ class DefaultImpl(metaclass=ImplMeta):
return True
return False
- def _column_args_match(self, inspected_params, meta_params):
+ def _column_args_match(
+ self, inspected_params: "Params", meta_params: "Params"
+ ) -> bool:
"""We want to compare column parameters. However, we only want
to compare parameters that are set. If they both have `collation`,
we want to make sure they are the same. However, if only one
@@ -438,7 +527,9 @@ class DefaultImpl(metaclass=ImplMeta):
return True
- def compare_type(self, inspector_column, metadata_column):
+ def compare_type(
+ self, inspector_column: "Column", metadata_column: "Column"
+ ) -> bool:
"""Returns True if there ARE differences between the types of the two
columns. Takes impl.type_synonyms into account between retrospected
and metadata types
@@ -463,11 +554,11 @@ class DefaultImpl(metaclass=ImplMeta):
def correct_for_autogen_constraints(
self,
- conn_uniques,
- conn_indexes,
- metadata_unique_constraints,
- metadata_indexes,
- ):
+ conn_uniques: Union[Set["UniqueConstraint"]],
+ conn_indexes: Union[Set["Index"]],
+ metadata_unique_constraints: Set["UniqueConstraint"],
+ metadata_indexes: Set["Index"],
+ ) -> None:
pass
def cast_for_batch_migrate(self, existing, existing_transfer, new_type):
@@ -476,7 +567,9 @@ class DefaultImpl(metaclass=ImplMeta):
existing_transfer["expr"], new_type
)
- def render_ddl_sql_expr(self, expr, is_server_default=False, **kw):
+ def render_ddl_sql_expr(
+ self, expr: "ClauseElement", is_server_default: bool = False, **kw
+ ) -> str:
"""Render a SQL expression that is typically a server default,
index expression, etc.
@@ -489,10 +582,16 @@ class DefaultImpl(metaclass=ImplMeta):
)
return text_type(expr.compile(dialect=self.dialect, **compile_kw))
- def _compat_autogen_column_reflect(self, inspector):
+ def _compat_autogen_column_reflect(
+ self, inspector: "Inspector"
+ ) -> Callable:
return self.autogen_column_reflect
- def correct_for_autogen_foreignkeys(self, conn_fks, metadata_fks):
+ def correct_for_autogen_foreignkeys(
+ self,
+ conn_fks: Set["ForeignKeyConstraint"],
+ metadata_fks: Set["ForeignKeyConstraint"],
+ ) -> None:
pass
def autogen_column_reflect(self, inspector, table, column_info):
@@ -504,7 +603,7 @@ class DefaultImpl(metaclass=ImplMeta):
"""
- def start_migrations(self):
+ def start_migrations(self) -> None:
"""A hook called when :meth:`.EnvironmentContext.run_migrations`
is called.
@@ -512,7 +611,7 @@ class DefaultImpl(metaclass=ImplMeta):
"""
- def emit_begin(self):
+ def emit_begin(self) -> None:
"""Emit the string ``BEGIN``, or the backend-specific
equivalent, on the current connection context.
@@ -522,7 +621,7 @@ class DefaultImpl(metaclass=ImplMeta):
"""
self.static_output("BEGIN" + self.command_terminator)
- def emit_commit(self):
+ def emit_commit(self) -> None:
"""Emit the string ``COMMIT``, or the backend-specific
equivalent, on the current connection context.
@@ -532,7 +631,9 @@ class DefaultImpl(metaclass=ImplMeta):
"""
self.static_output("COMMIT" + self.command_terminator)
- def render_type(self, type_obj, autogen_context):
+ def render_type(
+ self, type_obj: "TypeEngine", autogen_context: "AutogenContext"
+ ) -> Union[str, "Literal[False]"]:
return False
def _compare_identity_default(self, metadata_identity, inspector_identity):
diff --git a/alembic/ddl/mssql.py b/alembic/ddl/mssql.py
index 8a99ee6..9e1ef76 100644
--- a/alembic/ddl/mssql.py
+++ b/alembic/ddl/mssql.py
@@ -1,9 +1,15 @@
+from typing import Any
+from typing import List
+from typing import Optional
+from typing import TYPE_CHECKING
+from typing import Union
+
from sqlalchemy import types as sqltypes
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.schema import Column
from sqlalchemy.schema import CreateIndex
-from sqlalchemy.sql.expression import ClauseElement
-from sqlalchemy.sql.expression import Executable
+from sqlalchemy.sql.base import Executable
+from sqlalchemy.sql.elements import ClauseElement
from .base import AddColumn
from .base import alter_column
@@ -21,6 +27,20 @@ from .impl import DefaultImpl
from .. import util
from ..util import sqla_compat
+if TYPE_CHECKING:
+ from typing import Literal
+
+ from sqlalchemy.dialects.mssql.base import MSDDLCompiler
+ from sqlalchemy.dialects.mssql.base import MSSQLCompiler
+ from sqlalchemy.engine.cursor import CursorResult
+ from sqlalchemy.engine.cursor import LegacyCursorResult
+ from sqlalchemy.sql.schema import Index
+ from sqlalchemy.sql.schema import Table
+ from sqlalchemy.sql.selectable import TableClause
+ from sqlalchemy.sql.type_api import TypeEngine
+
+ from .base import _ServerDefault
+
class MSSQLImpl(DefaultImpl):
__dialect__ = "mssql"
@@ -40,40 +60,44 @@ class MSSQLImpl(DefaultImpl):
"order",
)
- def __init__(self, *arg, **kw):
+ def __init__(self, *arg, **kw) -> None:
super(MSSQLImpl, self).__init__(*arg, **kw)
self.batch_separator = self.context_opts.get(
"mssql_batch_separator", self.batch_separator
)
- def _exec(self, construct, *args, **kw):
+ def _exec(
+ self, construct: Any, *args, **kw
+ ) -> Optional[Union["LegacyCursorResult", "CursorResult"]]:
result = super(MSSQLImpl, self)._exec(construct, *args, **kw)
if self.as_sql and self.batch_separator:
self.static_output(self.batch_separator)
return result
- def emit_begin(self):
+ def emit_begin(self) -> None:
self.static_output("BEGIN TRANSACTION" + self.command_terminator)
- def emit_commit(self):
+ def emit_commit(self) -> None:
super(MSSQLImpl, self).emit_commit()
if self.as_sql and self.batch_separator:
self.static_output(self.batch_separator)
- def alter_column(
+ def alter_column( # type:ignore[override]
self,
- table_name,
- column_name,
- nullable=None,
- server_default=False,
- name=None,
- type_=None,
- schema=None,
- existing_type=None,
- existing_server_default=None,
- existing_nullable=None,
- **kw
- ):
+ table_name: str,
+ column_name: str,
+ nullable: Optional[bool] = None,
+ server_default: Optional[
+ Union["_ServerDefault", "Literal[False]"]
+ ] = False,
+ name: Optional[str] = None,
+ type_: Optional["TypeEngine"] = None,
+ schema: Optional[str] = None,
+ existing_type: Optional["TypeEngine"] = None,
+ existing_server_default: Optional["_ServerDefault"] = None,
+ existing_nullable: Optional[bool] = None,
+ **kw: Any
+ ) -> None:
if nullable is not None:
if existing_type is None:
@@ -138,17 +162,20 @@ class MSSQLImpl(DefaultImpl):
table_name, column_name, schema=schema, name=name
)
- def create_index(self, index):
+ def create_index(self, index: "Index") -> None:
# this likely defaults to None if not present, so get()
# should normally not return the default value. being
# defensive in any case
mssql_include = index.kwargs.get("mssql_include", None) or ()
+ assert index.table is not None
for col in mssql_include:
if col not in index.table.c:
index.table.append_column(Column(col, sqltypes.NullType))
self._exec(CreateIndex(index))
- def bulk_insert(self, table, rows, **kw):
+ def bulk_insert( # type:ignore[override]
+ self, table: Union["TableClause", "Table"], rows: List[dict], **kw: Any
+ ) -> None:
if self.as_sql:
self._exec(
"SET IDENTITY_INSERT %s ON"
@@ -162,7 +189,13 @@ class MSSQLImpl(DefaultImpl):
else:
super(MSSQLImpl, self).bulk_insert(table, rows, **kw)
- def drop_column(self, table_name, column, schema=None, **kw):
+ def drop_column(
+ self,
+ table_name: str,
+ column: "Column",
+ schema: Optional[str] = None,
+ **kw
+ ) -> None:
drop_default = kw.pop("mssql_drop_default", False)
if drop_default:
self._exec(
@@ -222,7 +255,13 @@ class MSSQLImpl(DefaultImpl):
class _ExecDropConstraint(Executable, ClauseElement):
- def __init__(self, tname, colname, type_, schema):
+ def __init__(
+ self,
+ tname: str,
+ colname: Union["Column", str],
+ type_: str,
+ schema: Optional[str],
+ ) -> None:
self.tname = tname
self.colname = colname
self.type_ = type_
@@ -230,14 +269,18 @@ class _ExecDropConstraint(Executable, ClauseElement):
class _ExecDropFKConstraint(Executable, ClauseElement):
- def __init__(self, tname, colname, schema):
+ def __init__(
+ self, tname: str, colname: "Column", schema: Optional[str]
+ ) -> None:
self.tname = tname
self.colname = colname
self.schema = schema
@compiles(_ExecDropConstraint, "mssql")
-def _exec_drop_col_constraint(element, compiler, **kw):
+def _exec_drop_col_constraint(
+ element: "_ExecDropConstraint", compiler: "MSSQLCompiler", **kw
+) -> str:
schema, tname, colname, type_ = (
element.schema,
element.tname,
@@ -261,7 +304,9 @@ exec('alter table %(tname_quoted)s drop constraint ' + @const_name)""" % {
@compiles(_ExecDropFKConstraint, "mssql")
-def _exec_drop_col_fk_constraint(element, compiler, **kw):
+def _exec_drop_col_fk_constraint(
+ element: "_ExecDropFKConstraint", compiler: "MSSQLCompiler", **kw
+) -> str:
schema, tname, colname = element.schema, element.tname, element.colname
return """declare @const_name varchar(256)
@@ -279,19 +324,23 @@ exec('alter table %(tname_quoted)s drop constraint ' + @const_name)""" % {
@compiles(AddColumn, "mssql")
-def visit_add_column(element, compiler, **kw):
+def visit_add_column(
+ element: "AddColumn", compiler: "MSDDLCompiler", **kw
+) -> str:
return "%s %s" % (
alter_table(compiler, element.table_name, element.schema),
mssql_add_column(compiler, element.column, **kw),
)
-def mssql_add_column(compiler, column, **kw):
+def mssql_add_column(compiler: "MSDDLCompiler", column: "Column", **kw) -> str:
return "ADD %s" % compiler.get_column_specification(column, **kw)
@compiles(ColumnNullable, "mssql")
-def visit_column_nullable(element, compiler, **kw):
+def visit_column_nullable(
+ element: "ColumnNullable", compiler: "MSDDLCompiler", **kw
+) -> str:
return "%s %s %s %s" % (
alter_table(compiler, element.table_name, element.schema),
alter_column(compiler, element.column_name),
@@ -301,7 +350,9 @@ def visit_column_nullable(element, compiler, **kw):
@compiles(ColumnDefault, "mssql")
-def visit_column_default(element, compiler, **kw):
+def visit_column_default(
+ element: "ColumnDefault", compiler: "MSDDLCompiler", **kw
+) -> str:
# TODO: there can also be a named constraint
# with ADD CONSTRAINT here
return "%s ADD DEFAULT %s FOR %s" % (
@@ -312,7 +363,9 @@ def visit_column_default(element, compiler, **kw):
@compiles(ColumnName, "mssql")
-def visit_rename_column(element, compiler, **kw):
+def visit_rename_column(
+ element: "ColumnName", compiler: "MSDDLCompiler", **kw
+) -> str:
return "EXEC sp_rename '%s.%s', %s, 'COLUMN'" % (
format_table_name(compiler, element.table_name, element.schema),
format_column_name(compiler, element.column_name),
@@ -321,7 +374,9 @@ def visit_rename_column(element, compiler, **kw):
@compiles(ColumnType, "mssql")
-def visit_column_type(element, compiler, **kw):
+def visit_column_type(
+ element: "ColumnType", compiler: "MSDDLCompiler", **kw
+) -> str:
return "%s %s %s" % (
alter_table(compiler, element.table_name, element.schema),
alter_column(compiler, element.column_name),
@@ -330,7 +385,9 @@ def visit_column_type(element, compiler, **kw):
@compiles(RenameTable, "mssql")
-def visit_rename_table(element, compiler, **kw):
+def visit_rename_table(
+ element: "RenameTable", compiler: "MSDDLCompiler", **kw
+) -> str:
return "EXEC sp_rename '%s', %s" % (
format_table_name(compiler, element.table_name, element.schema),
format_table_name(compiler, element.new_table_name, None),
diff --git a/alembic/ddl/mysql.py b/alembic/ddl/mysql.py
index 4761f75..9489560 100644
--- a/alembic/ddl/mysql.py
+++ b/alembic/ddl/mysql.py
@@ -1,4 +1,8 @@
import re
+from typing import Any
+from typing import Optional
+from typing import TYPE_CHECKING
+from typing import Union
from sqlalchemy import schema
from sqlalchemy import types as sqltypes
@@ -19,6 +23,16 @@ from ..util import sqla_compat
from ..util.sqla_compat import _is_mariadb
from ..util.sqla_compat import _is_type_bound
+if TYPE_CHECKING:
+ from typing import Literal
+
+ from sqlalchemy.dialects.mysql.base import MySQLDDLCompiler
+ from sqlalchemy.sql.ddl import DropConstraint
+ from sqlalchemy.sql.schema import Constraint
+ from sqlalchemy.sql.type_api import TypeEngine
+
+ from .base import _ServerDefault
+
class MySQLImpl(DefaultImpl):
__dialect__ = "mysql"
@@ -27,24 +41,24 @@ class MySQLImpl(DefaultImpl):
type_synonyms = DefaultImpl.type_synonyms + ({"BOOL", "TINYINT"},)
type_arg_extract = [r"character set ([\w\-_]+)", r"collate ([\w\-_]+)"]
- def alter_column(
+ def alter_column( # type:ignore[override]
self,
- table_name,
- column_name,
- nullable=None,
- server_default=False,
- name=None,
- type_=None,
- schema=None,
- existing_type=None,
- existing_server_default=None,
- existing_nullable=None,
- autoincrement=None,
- existing_autoincrement=None,
- comment=False,
- existing_comment=None,
- **kw
- ):
+ table_name: str,
+ column_name: str,
+ nullable: Optional[bool] = None,
+ server_default: Union["_ServerDefault", "Literal[False]"] = False,
+ name: Optional[str] = None,
+ type_: Optional["TypeEngine"] = None,
+ schema: Optional[str] = None,
+ existing_type: Optional["TypeEngine"] = None,
+ existing_server_default: Optional["_ServerDefault"] = None,
+ existing_nullable: Optional[bool] = None,
+ autoincrement: Optional[bool] = None,
+ existing_autoincrement: Optional[bool] = None,
+ comment: Optional[Union[str, "Literal[False]"]] = False,
+ existing_comment: Optional[str] = None,
+ **kw: Any
+ ) -> None:
if sqla_compat._server_default_is_identity(
server_default, existing_server_default
) or sqla_compat._server_default_is_computed(
@@ -126,16 +140,24 @@ class MySQLImpl(DefaultImpl):
)
)
- def drop_constraint(self, const):
+ def drop_constraint(
+ self,
+ const: "Constraint",
+ ) -> None:
if isinstance(const, schema.CheckConstraint) and _is_type_bound(const):
return
super(MySQLImpl, self).drop_constraint(const)
- def _is_mysql_allowed_functional_default(self, type_, server_default):
+ def _is_mysql_allowed_functional_default(
+ self,
+ type_: Optional["TypeEngine"],
+ server_default: Union["_ServerDefault", "Literal[False]"],
+ ) -> bool:
return (
type_ is not None
- and type_._type_affinity is sqltypes.DateTime
+ and type_._type_affinity # type:ignore[attr-defined]
+ is sqltypes.DateTime
and server_default is not None
)
@@ -268,7 +290,13 @@ class MariaDBImpl(MySQLImpl):
class MySQLAlterDefault(AlterColumn):
- def __init__(self, name, column_name, default, schema=None):
+ def __init__(
+ self,
+ name: str,
+ column_name: str,
+ default: "_ServerDefault",
+ schema: Optional[str] = None,
+ ) -> None:
super(AlterColumn, self).__init__(name, schema=schema)
self.column_name = column_name
self.default = default
@@ -277,16 +305,16 @@ class MySQLAlterDefault(AlterColumn):
class MySQLChangeColumn(AlterColumn):
def __init__(
self,
- name,
- column_name,
- schema=None,
- newname=None,
- type_=None,
- nullable=None,
- default=False,
- autoincrement=None,
- comment=False,
- ):
+ name: str,
+ column_name: str,
+ schema: Optional[str] = None,
+ newname: Optional[str] = None,
+ type_: Optional["TypeEngine"] = None,
+ nullable: Optional[bool] = None,
+ default: Optional[Union["_ServerDefault", "Literal[False]"]] = False,
+ autoincrement: Optional[bool] = None,
+ comment: Optional[Union[str, "Literal[False]"]] = False,
+ ) -> None:
super(AlterColumn, self).__init__(name, schema=schema)
self.column_name = column_name
self.nullable = nullable
@@ -318,7 +346,9 @@ def _mysql_doesnt_support_individual(element, compiler, **kw):
@compiles(MySQLAlterDefault, "mysql", "mariadb")
-def _mysql_alter_default(element, compiler, **kw):
+def _mysql_alter_default(
+ element: "MySQLAlterDefault", compiler: "MySQLDDLCompiler", **kw
+) -> str:
return "%s ALTER COLUMN %s %s" % (
alter_table(compiler, element.table_name, element.schema),
format_column_name(compiler, element.column_name),
@@ -329,7 +359,9 @@ def _mysql_alter_default(element, compiler, **kw):
@compiles(MySQLModifyColumn, "mysql", "mariadb")
-def _mysql_modify_column(element, compiler, **kw):
+def _mysql_modify_column(
+ element: "MySQLModifyColumn", compiler: "MySQLDDLCompiler", **kw
+) -> str:
return "%s MODIFY %s %s" % (
alter_table(compiler, element.table_name, element.schema),
format_column_name(compiler, element.column_name),
@@ -345,7 +377,9 @@ def _mysql_modify_column(element, compiler, **kw):
@compiles(MySQLChangeColumn, "mysql", "mariadb")
-def _mysql_change_column(element, compiler, **kw):
+def _mysql_change_column(
+ element: "MySQLChangeColumn", compiler: "MySQLDDLCompiler", **kw
+) -> str:
return "%s CHANGE %s %s %s" % (
alter_table(compiler, element.table_name, element.schema),
format_column_name(compiler, element.column_name),
@@ -362,8 +396,13 @@ def _mysql_change_column(element, compiler, **kw):
def _mysql_colspec(
- compiler, nullable, server_default, type_, autoincrement, comment
-):
+ compiler: "MySQLDDLCompiler",
+ nullable: Optional[bool],
+ server_default: Optional[Union["_ServerDefault", "Literal[False]"]],
+ type_: "TypeEngine",
+ autoincrement: Optional[bool],
+ comment: Optional[Union[str, "Literal[False]"]],
+) -> str:
spec = "%s %s" % (
compiler.dialect.type_compiler.process(type_),
"NULL" if nullable else "NOT NULL",
@@ -381,7 +420,9 @@ def _mysql_colspec(
@compiles(schema.DropConstraint, "mysql", "mariadb")
-def _mysql_drop_constraint(element, compiler, **kw):
+def _mysql_drop_constraint(
+ element: "DropConstraint", compiler: "MySQLDDLCompiler", **kw
+) -> str:
"""Redefine SQLAlchemy's drop constraint to
raise errors for invalid constraint type."""
@@ -394,7 +435,8 @@ def _mysql_drop_constraint(element, compiler, **kw):
schema.UniqueConstraint,
),
):
- return compiler.visit_drop_constraint(element, **kw)
+ assert not kw
+ return compiler.visit_drop_constraint(element)
elif isinstance(constraint, schema.CheckConstraint):
# note that SQLAlchemy as of 1.2 does not yet support
# DROP CONSTRAINT for MySQL/MariaDB, so we implement fully
diff --git a/alembic/ddl/oracle.py b/alembic/ddl/oracle.py
index 90f93d2..915edb8 100644
--- a/alembic/ddl/oracle.py
+++ b/alembic/ddl/oracle.py
@@ -1,3 +1,8 @@
+from typing import Any
+from typing import Optional
+from typing import TYPE_CHECKING
+from typing import Union
+
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql import sqltypes
@@ -16,6 +21,12 @@ from .base import IdentityColumnDefault
from .base import RenameTable
from .impl import DefaultImpl
+if TYPE_CHECKING:
+ from sqlalchemy.dialects.oracle.base import OracleDDLCompiler
+ from sqlalchemy.engine.cursor import CursorResult
+ from sqlalchemy.engine.cursor import LegacyCursorResult
+ from sqlalchemy.sql.schema import Column
+
class OracleImpl(DefaultImpl):
__dialect__ = "oracle"
@@ -28,27 +39,31 @@ class OracleImpl(DefaultImpl):
)
identity_attrs_ignore = ()
- def __init__(self, *arg, **kw):
+ def __init__(self, *arg, **kw) -> None:
super(OracleImpl, self).__init__(*arg, **kw)
self.batch_separator = self.context_opts.get(
"oracle_batch_separator", self.batch_separator
)
- def _exec(self, construct, *args, **kw):
+ def _exec(
+ self, construct: Any, *args, **kw
+ ) -> Optional[Union["LegacyCursorResult", "CursorResult"]]:
result = super(OracleImpl, self)._exec(construct, *args, **kw)
if self.as_sql and self.batch_separator:
self.static_output(self.batch_separator)
return result
- def emit_begin(self):
+ def emit_begin(self) -> None:
self._exec("SET TRANSACTION READ WRITE")
- def emit_commit(self):
+ def emit_commit(self) -> None:
self._exec("COMMIT")
@compiles(AddColumn, "oracle")
-def visit_add_column(element, compiler, **kw):
+def visit_add_column(
+ element: "AddColumn", compiler: "OracleDDLCompiler", **kw
+) -> str:
return "%s %s" % (
alter_table(compiler, element.table_name, element.schema),
add_column(compiler, element.column, **kw),
@@ -56,7 +71,9 @@ def visit_add_column(element, compiler, **kw):
@compiles(ColumnNullable, "oracle")
-def visit_column_nullable(element, compiler, **kw):
+def visit_column_nullable(
+ element: "ColumnNullable", compiler: "OracleDDLCompiler", **kw
+) -> str:
return "%s %s %s" % (
alter_table(compiler, element.table_name, element.schema),
alter_column(compiler, element.column_name),
@@ -65,7 +82,9 @@ def visit_column_nullable(element, compiler, **kw):
@compiles(ColumnType, "oracle")
-def visit_column_type(element, compiler, **kw):
+def visit_column_type(
+ element: "ColumnType", compiler: "OracleDDLCompiler", **kw
+) -> str:
return "%s %s %s" % (
alter_table(compiler, element.table_name, element.schema),
alter_column(compiler, element.column_name),
@@ -74,7 +93,9 @@ def visit_column_type(element, compiler, **kw):
@compiles(ColumnName, "oracle")
-def visit_column_name(element, compiler, **kw):
+def visit_column_name(
+ element: "ColumnName", compiler: "OracleDDLCompiler", **kw
+) -> str:
return "%s RENAME COLUMN %s TO %s" % (
alter_table(compiler, element.table_name, element.schema),
format_column_name(compiler, element.column_name),
@@ -83,7 +104,9 @@ def visit_column_name(element, compiler, **kw):
@compiles(ColumnDefault, "oracle")
-def visit_column_default(element, compiler, **kw):
+def visit_column_default(
+ element: "ColumnDefault", compiler: "OracleDDLCompiler", **kw
+) -> str:
return "%s %s %s" % (
alter_table(compiler, element.table_name, element.schema),
alter_column(compiler, element.column_name),
@@ -94,7 +117,9 @@ def visit_column_default(element, compiler, **kw):
@compiles(ColumnComment, "oracle")
-def visit_column_comment(element, compiler, **kw):
+def visit_column_comment(
+ element: "ColumnComment", compiler: "OracleDDLCompiler", **kw
+) -> str:
ddl = "COMMENT ON COLUMN {table_name}.{column_name} IS {comment}"
comment = compiler.sql_compiler.render_literal_value(
@@ -110,23 +135,27 @@ def visit_column_comment(element, compiler, **kw):
@compiles(RenameTable, "oracle")
-def visit_rename_table(element, compiler, **kw):
+def visit_rename_table(
+ element: "RenameTable", compiler: "OracleDDLCompiler", **kw
+) -> str:
return "%s RENAME TO %s" % (
alter_table(compiler, element.table_name, element.schema),
format_table_name(compiler, element.new_table_name, None),
)
-def alter_column(compiler, name):
+def alter_column(compiler: "OracleDDLCompiler", name: str) -> str:
return "MODIFY %s" % format_column_name(compiler, name)
-def add_column(compiler, column, **kw):
+def add_column(compiler: "OracleDDLCompiler", column: "Column", **kw) -> str:
return "ADD %s" % compiler.get_column_specification(column, **kw)
@compiles(IdentityColumnDefault, "oracle")
-def visit_identity_column(element, compiler, **kw):
+def visit_identity_column(
+ element: "IdentityColumnDefault", compiler: "OracleDDLCompiler", **kw
+):
text = "%s %s " % (
alter_table(compiler, element.table_name, element.schema),
alter_column(compiler, element.column_name),
diff --git a/alembic/ddl/postgresql.py b/alembic/ddl/postgresql.py
index 7468f08..c894649 100644
--- a/alembic/ddl/postgresql.py
+++ b/alembic/ddl/postgresql.py
@@ -1,5 +1,13 @@
import logging
import re
+from typing import Any
+from typing import cast
+from typing import List
+from typing import Optional
+from typing import Sequence
+from typing import Tuple
+from typing import TYPE_CHECKING
+from typing import Union
from sqlalchemy import Column
from sqlalchemy import Numeric
@@ -8,8 +16,8 @@ from sqlalchemy import types as sqltypes
from sqlalchemy.dialects.postgresql import BIGINT
from sqlalchemy.dialects.postgresql import ExcludeConstraint
from sqlalchemy.dialects.postgresql import INTEGER
-from sqlalchemy.sql.expression import ColumnClause
-from sqlalchemy.sql.expression import UnaryExpression
+from sqlalchemy.sql.elements import ColumnClause
+from sqlalchemy.sql.elements import UnaryExpression
from sqlalchemy.types import NULLTYPE
from .base import alter_column
@@ -32,6 +40,25 @@ from ..operations.base import Operations
from ..util import compat
from ..util import sqla_compat
+if TYPE_CHECKING:
+ from typing import Literal
+
+ from sqlalchemy.dialects.postgresql.array import ARRAY
+ from sqlalchemy.dialects.postgresql.base import PGDDLCompiler
+ from sqlalchemy.dialects.postgresql.hstore import HSTORE
+ from sqlalchemy.dialects.postgresql.json import JSON
+ from sqlalchemy.dialects.postgresql.json import JSONB
+ from sqlalchemy.sql.elements import BinaryExpression
+ from sqlalchemy.sql.elements import quoted_name
+ from sqlalchemy.sql.schema import MetaData
+ from sqlalchemy.sql.schema import Table
+ from sqlalchemy.sql.type_api import TypeEngine
+
+ from .base import _ServerDefault
+ from ..autogenerate.api import AutogenContext
+ from ..autogenerate.render import _f_name
+ from ..runtime.migration import MigrationContext
+
log = logging.getLogger(__name__)
@@ -94,22 +121,22 @@ class PostgresqlImpl(DefaultImpl):
)
)
- def alter_column(
+ def alter_column( # type:ignore[override]
self,
- table_name,
- column_name,
- nullable=None,
- server_default=False,
- name=None,
- type_=None,
- schema=None,
- autoincrement=None,
- existing_type=None,
- existing_server_default=None,
- existing_nullable=None,
- existing_autoincrement=None,
- **kw
- ):
+ table_name: str,
+ column_name: str,
+ nullable: Optional[bool] = None,
+ server_default: Union["_ServerDefault", "Literal[False]"] = False,
+ name: Optional[str] = None,
+ type_: Optional["TypeEngine"] = None,
+ schema: Optional[str] = None,
+ autoincrement: Optional[bool] = None,
+ existing_type: Optional["TypeEngine"] = None,
+ existing_server_default: Optional["_ServerDefault"] = None,
+ existing_nullable: Optional[bool] = None,
+ existing_autoincrement: Optional[bool] = None,
+ **kw: Any
+ ) -> None:
using = kw.pop("postgresql_using", None)
@@ -218,7 +245,9 @@ class PostgresqlImpl(DefaultImpl):
)
metadata_indexes.discard(idx)
- def render_type(self, type_, autogen_context):
+ def render_type(
+ self, type_: "TypeEngine", autogen_context: "AutogenContext"
+ ) -> Union[str, "Literal[False]"]:
mod = type(type_).__module__
if not mod.startswith("sqlalchemy.dialects.postgresql"):
return False
@@ -229,29 +258,51 @@ class PostgresqlImpl(DefaultImpl):
return False
- def _render_HSTORE_type(self, type_, autogen_context):
- return render._render_type_w_subtype(
- type_, autogen_context, "text_type", r"(.+?\(.*text_type=)"
+ def _render_HSTORE_type(
+ self, type_: "HSTORE", autogen_context: "AutogenContext"
+ ) -> str:
+ return cast(
+ str,
+ render._render_type_w_subtype(
+ type_, autogen_context, "text_type", r"(.+?\(.*text_type=)"
+ ),
)
- def _render_ARRAY_type(self, type_, autogen_context):
- return render._render_type_w_subtype(
- type_, autogen_context, "item_type", r"(.+?\()"
+ def _render_ARRAY_type(
+ self, type_: "ARRAY", autogen_context: "AutogenContext"
+ ) -> str:
+ return cast(
+ str,
+ render._render_type_w_subtype(
+ type_, autogen_context, "item_type", r"(.+?\()"
+ ),
)
- def _render_JSON_type(self, type_, autogen_context):
- return render._render_type_w_subtype(
- type_, autogen_context, "astext_type", r"(.+?\(.*astext_type=)"
+ def _render_JSON_type(
+ self, type_: "JSON", autogen_context: "AutogenContext"
+ ) -> str:
+ return cast(
+ str,
+ render._render_type_w_subtype(
+ type_, autogen_context, "astext_type", r"(.+?\(.*astext_type=)"
+ ),
)
- def _render_JSONB_type(self, type_, autogen_context):
- return render._render_type_w_subtype(
- type_, autogen_context, "astext_type", r"(.+?\(.*astext_type=)"
+ def _render_JSONB_type(
+ self, type_: "JSONB", autogen_context: "AutogenContext"
+ ) -> str:
+ return cast(
+ str,
+ render._render_type_w_subtype(
+ type_, autogen_context, "astext_type", r"(.+?\(.*astext_type=)"
+ ),
)
class PostgresqlColumnType(AlterColumn):
- def __init__(self, name, column_name, type_, **kw):
+ def __init__(
+ self, name: str, column_name: str, type_: "TypeEngine", **kw
+ ) -> None:
using = kw.pop("using", None)
super(PostgresqlColumnType, self).__init__(name, column_name, **kw)
self.type_ = sqltypes.to_instance(type_)
@@ -259,7 +310,9 @@ class PostgresqlColumnType(AlterColumn):
@compiles(RenameTable, "postgresql")
-def visit_rename_table(element, compiler, **kw):
+def visit_rename_table(
+ element: RenameTable, compiler: "PGDDLCompiler", **kw
+) -> str:
return "%s RENAME TO %s" % (
alter_table(compiler, element.table_name, element.schema),
format_table_name(compiler, element.new_table_name, None),
@@ -267,7 +320,9 @@ def visit_rename_table(element, compiler, **kw):
@compiles(PostgresqlColumnType, "postgresql")
-def visit_column_type(element, compiler, **kw):
+def visit_column_type(
+ element: PostgresqlColumnType, compiler: "PGDDLCompiler", **kw
+) -> str:
return "%s %s %s %s" % (
alter_table(compiler, element.table_name, element.schema),
alter_column(compiler, element.column_name),
@@ -277,7 +332,9 @@ def visit_column_type(element, compiler, **kw):
@compiles(ColumnComment, "postgresql")
-def visit_column_comment(element, compiler, **kw):
+def visit_column_comment(
+ element: "ColumnComment", compiler: "PGDDLCompiler", **kw
+) -> str:
ddl = "COMMENT ON COLUMN {table_name}.{column_name} IS {comment}"
comment = (
compiler.sql_compiler.render_literal_value(
@@ -297,7 +354,9 @@ def visit_column_comment(element, compiler, **kw):
@compiles(IdentityColumnDefault, "postgresql")
-def visit_identity_column(element, compiler, **kw):
+def visit_identity_column(
+ element: "IdentityColumnDefault", compiler: "PGDDLCompiler", **kw
+):
text = "%s %s " % (
alter_table(compiler, element.table_name, element.schema),
alter_column(compiler, element.column_name),
@@ -341,14 +400,17 @@ class CreateExcludeConstraintOp(ops.AddConstraintOp):
def __init__(
self,
- constraint_name,
- table_name,
- elements,
- where=None,
- schema=None,
- _orig_constraint=None,
+ constraint_name: Optional[str],
+ table_name: Union[str, "quoted_name"],
+ elements: Union[
+ Sequence[Tuple[str, str]],
+ Sequence[Tuple["ColumnClause", str]],
+ ],
+ where: Optional[Union["BinaryExpression", str]] = None,
+ schema: Optional[str] = None,
+ _orig_constraint: Optional["ExcludeConstraint"] = None,
**kw
- ):
+ ) -> None:
self.constraint_name = constraint_name
self.table_name = table_name
self.elements = elements
@@ -358,13 +420,18 @@ class CreateExcludeConstraintOp(ops.AddConstraintOp):
self.kw = kw
@classmethod
- def from_constraint(cls, constraint):
+ def from_constraint( # type:ignore[override]
+ cls, constraint: "ExcludeConstraint"
+ ) -> "CreateExcludeConstraintOp":
constraint_table = sqla_compat._table_for_constraint(constraint)
return cls(
constraint.name,
constraint_table.name,
- [(expr, op) for expr, name, op in constraint._render_exprs],
+ [
+ (expr, op)
+ for expr, name, op in constraint._render_exprs # type:ignore[attr-defined] # noqa
+ ],
where=constraint.where,
schema=constraint_table.schema,
_orig_constraint=constraint,
@@ -373,7 +440,9 @@ class CreateExcludeConstraintOp(ops.AddConstraintOp):
using=constraint.using,
)
- def to_constraint(self, migration_context=None):
+ def to_constraint(
+ self, migration_context: Optional["MigrationContext"] = None
+ ) -> "ExcludeConstraint":
if self._orig_constraint is not None:
return self._orig_constraint
schema_obj = schemaobj.SchemaObjects(migration_context)
@@ -384,15 +453,24 @@ class CreateExcludeConstraintOp(ops.AddConstraintOp):
where=self.where,
**self.kw
)
- for expr, name, oper in excl._render_exprs:
+ for (
+ expr,
+ name,
+ oper,
+ ) in excl._render_exprs: # type:ignore[attr-defined]
t.append_column(Column(name, NULLTYPE))
t.append_constraint(excl)
return excl
@classmethod
def create_exclude_constraint(
- cls, operations, constraint_name, table_name, *elements, **kw
- ):
+ cls,
+ operations: "Operations",
+ constraint_name: str,
+ table_name: str,
+ *elements: Any,
+ **kw: Any
+ ) -> Optional["Table"]:
"""Issue an alter to create an EXCLUDE constraint using the
current migration context.
@@ -453,14 +531,18 @@ class CreateExcludeConstraintOp(ops.AddConstraintOp):
@render.renderers.dispatch_for(CreateExcludeConstraintOp)
-def _add_exclude_constraint(autogen_context, op):
+def _add_exclude_constraint(
+ autogen_context: "AutogenContext", op: "CreateExcludeConstraintOp"
+) -> str:
return _exclude_constraint(op.to_constraint(), autogen_context, alter=True)
@render._constraint_renderers.dispatch_for(ExcludeConstraint)
def _render_inline_exclude_constraint(
- constraint, autogen_context, namespace_metadata
-):
+ constraint: "ExcludeConstraint",
+ autogen_context: "AutogenContext",
+ namespace_metadata: "MetaData",
+) -> str:
rendered = render._user_defined_render(
"exclude", constraint, autogen_context
)
@@ -470,7 +552,7 @@ def _render_inline_exclude_constraint(
return _exclude_constraint(constraint, autogen_context, False)
-def _postgresql_autogenerate_prefix(autogen_context):
+def _postgresql_autogenerate_prefix(autogen_context: "AutogenContext") -> str:
imports = autogen_context.imports
if imports is not None:
@@ -478,8 +560,12 @@ def _postgresql_autogenerate_prefix(autogen_context):
return "postgresql."
-def _exclude_constraint(constraint, autogen_context, alter):
- opts = []
+def _exclude_constraint(
+ constraint: "ExcludeConstraint",
+ autogen_context: "AutogenContext",
+ alter: bool,
+) -> str:
+ opts: List[Tuple[str, Union[quoted_name, str, _f_name, None]]] = []
has_batch = autogen_context._has_batch
@@ -509,7 +595,7 @@ def _exclude_constraint(constraint, autogen_context, alter):
_render_potential_column(sqltext, autogen_context),
opstring,
)
- for sqltext, name, opstring in constraint._render_exprs
+ for sqltext, name, opstring in constraint._render_exprs # type:ignore[attr-defined] # noqa
]
)
if constraint.where is not None:
@@ -528,7 +614,7 @@ def _exclude_constraint(constraint, autogen_context, alter):
args = [
"(%s, %r)"
% (_render_potential_column(sqltext, autogen_context), opstring)
- for sqltext, name, opstring in constraint._render_exprs
+ for sqltext, name, opstring in constraint._render_exprs # type:ignore[attr-defined] # noqa
]
if constraint.where is not None:
args.append(
@@ -544,7 +630,9 @@ def _exclude_constraint(constraint, autogen_context, alter):
}
-def _render_potential_column(value, autogen_context):
+def _render_potential_column(
+ value: Union["ColumnClause", "Column"], autogen_context: "AutogenContext"
+) -> str:
if isinstance(value, ColumnClause):
template = "%(prefix)scolumn(%(name)r)"
diff --git a/alembic/ddl/sqlite.py b/alembic/ddl/sqlite.py
index cb790ea..2f4ed77 100644
--- a/alembic/ddl/sqlite.py
+++ b/alembic/ddl/sqlite.py
@@ -1,4 +1,9 @@
import re
+from typing import Any
+from typing import Dict
+from typing import Optional
+from typing import TYPE_CHECKING
+from typing import Union
from sqlalchemy import cast
from sqlalchemy import JSON
@@ -8,6 +13,17 @@ from sqlalchemy import sql
from .impl import DefaultImpl
from .. import util
+if TYPE_CHECKING:
+ from sqlalchemy.engine.reflection import Inspector
+ from sqlalchemy.sql.elements import Cast
+ from sqlalchemy.sql.elements import ClauseElement
+ from sqlalchemy.sql.schema import Column
+ from sqlalchemy.sql.schema import Constraint
+ from sqlalchemy.sql.schema import Table
+ from sqlalchemy.sql.type_api import TypeEngine
+
+ from ..operations.batch import BatchOperationsImpl
+
class SQLiteImpl(DefaultImpl):
__dialect__ = "sqlite"
@@ -17,7 +33,9 @@ class SQLiteImpl(DefaultImpl):
see: http://bugs.python.org/issue10740
"""
- def requires_recreate_in_batch(self, batch_op):
+ def requires_recreate_in_batch(
+ self, batch_op: "BatchOperationsImpl"
+ ) -> bool:
"""Return True if the given :class:`.BatchOperationsImpl`
would need the table to be recreated and copied in order to
proceed.
@@ -44,16 +62,16 @@ class SQLiteImpl(DefaultImpl):
else:
return False
- def add_constraint(self, const):
+ def add_constraint(self, const: "Constraint"):
# attempt to distinguish between an
# auto-gen constraint and an explicit one
- if const._create_rule is None:
+ if const._create_rule is None: # type:ignore[attr-defined]
raise NotImplementedError(
"No support for ALTER of constraints in SQLite dialect"
"Please refer to the batch mode feature which allows for "
"SQLite migrations using a copy-and-move strategy."
)
- elif const._create_rule(self):
+ elif const._create_rule(self): # type:ignore[attr-defined]
util.warn(
"Skipping unsupported ALTER for "
"creation of implicit constraint"
@@ -61,8 +79,8 @@ class SQLiteImpl(DefaultImpl):
"SQLite migrations using a copy-and-move strategy."
)
- def drop_constraint(self, const):
- if const._create_rule is None:
+ def drop_constraint(self, const: "Constraint"):
+ if const._create_rule is None: # type:ignore[attr-defined]
raise NotImplementedError(
"No support for ALTER of constraints in SQLite dialect"
"Please refer to the batch mode feature which allows for "
@@ -71,11 +89,11 @@ class SQLiteImpl(DefaultImpl):
def compare_server_default(
self,
- inspector_column,
- metadata_column,
- rendered_metadata_default,
- rendered_inspector_default,
- ):
+ inspector_column: "Column",
+ metadata_column: "Column",
+ rendered_metadata_default: Optional[str],
+ rendered_inspector_default: Optional[str],
+ ) -> bool:
if rendered_metadata_default is not None:
rendered_metadata_default = re.sub(
@@ -93,7 +111,9 @@ class SQLiteImpl(DefaultImpl):
return rendered_inspector_default != rendered_metadata_default
- def _guess_if_default_is_unparenthesized_sql_expr(self, expr):
+ def _guess_if_default_is_unparenthesized_sql_expr(
+ self, expr: Optional[str]
+ ) -> bool:
"""Determine if a server default is a SQL expression or a constant.
There are too many assertions that expect server defaults to round-trip
@@ -112,7 +132,12 @@ class SQLiteImpl(DefaultImpl):
else:
return True
- def autogen_column_reflect(self, inspector, table, column_info):
+ def autogen_column_reflect(
+ self,
+ inspector: "Inspector",
+ table: "Table",
+ column_info: Dict[str, Any],
+ ) -> None:
# SQLite expression defaults require parenthesis when sent
# as DDL
if self._guess_if_default_is_unparenthesized_sql_expr(
@@ -120,7 +145,9 @@ class SQLiteImpl(DefaultImpl):
):
column_info["default"] = "(%s)" % (column_info["default"],)
- def render_ddl_sql_expr(self, expr, is_server_default=False, **kw):
+ def render_ddl_sql_expr(
+ self, expr: "ClauseElement", is_server_default: bool = False, **kw
+ ) -> str:
# SQLite expression defaults require parenthesis when sent
# as DDL
str_expr = super(SQLiteImpl, self).render_ddl_sql_expr(
@@ -134,9 +161,15 @@ class SQLiteImpl(DefaultImpl):
str_expr = "(%s)" % (str_expr,)
return str_expr
- def cast_for_batch_migrate(self, existing, existing_transfer, new_type):
+ def cast_for_batch_migrate(
+ self,
+ existing: "Column",
+ existing_transfer: Dict[str, Union["TypeEngine", "Cast"]],
+ new_type: "TypeEngine",
+ ) -> None:
if (
- existing.type._type_affinity is not new_type._type_affinity
+ existing.type._type_affinity # type:ignore[attr-defined]
+ is not new_type._type_affinity # type:ignore[attr-defined]
and not isinstance(new_type, JSON)
):
existing_transfer["expr"] = cast(