summaryrefslogtreecommitdiff
path: root/alembic/ddl
diff options
context:
space:
mode:
authorCaselIT <cfederico87@gmail.com>2022-09-11 22:38:46 +0200
committerCaselIT <cfederico87@gmail.com>2022-09-12 21:00:52 +0200
commit0e83fddb6c110bf1658564c248ffad9163a365fa (patch)
tree9b61207cd621fb526cda1daa55eaf94bd50aeb25 /alembic/ddl
parent747ec301529cf2a08d56d3596aedbf54a93b8742 (diff)
downloadalembic-0e83fddb6c110bf1658564c248ffad9163a365fa.tar.gz
Improve typing
Change-Id: I9fc86c4a92e1b76d19c9e891ff08ce8a46ad4e35
Diffstat (limited to 'alembic/ddl')
-rw-r--r--alembic/ddl/base.py2
-rw-r--r--alembic/ddl/impl.py45
-rw-r--r--alembic/ddl/mssql.py7
-rw-r--r--alembic/ddl/oracle.py6
4 files changed, 27 insertions, 33 deletions
diff --git a/alembic/ddl/base.py b/alembic/ddl/base.py
index 7b0f63e..c910786 100644
--- a/alembic/ddl/base.py
+++ b/alembic/ddl/base.py
@@ -294,7 +294,7 @@ def format_table_name(
def format_column_name(
compiler: "DDLCompiler", name: Optional[Union["quoted_name", str]]
) -> Union["quoted_name", str]:
- return compiler.preparer.quote(name)
+ return compiler.preparer.quote(name) # type: ignore[arg-type]
def format_server_default(
diff --git a/alembic/ddl/impl.py b/alembic/ddl/impl.py
index 070c124..79d5245 100644
--- a/alembic/ddl/impl.py
+++ b/alembic/ddl/impl.py
@@ -23,19 +23,16 @@ from .. import util
from ..util import sqla_compat
if TYPE_CHECKING:
- from io import StringIO
from typing import Literal
+ from typing import TextIO
from sqlalchemy.engine import Connection
from sqlalchemy.engine import Dialect
from sqlalchemy.engine.cursor import CursorResult
- from sqlalchemy.engine.cursor import LegacyCursorResult
from sqlalchemy.engine.reflection import Inspector
- from sqlalchemy.sql.dml import Update
from sqlalchemy.sql.elements import ClauseElement
from sqlalchemy.sql.elements import ColumnElement
from sqlalchemy.sql.elements import quoted_name
- from sqlalchemy.sql.elements import TextClause
from sqlalchemy.sql.schema import Column
from sqlalchemy.sql.schema import Constraint
from sqlalchemy.sql.schema import ForeignKeyConstraint
@@ -60,11 +57,11 @@ class ImplMeta(type):
):
newtype = type.__init__(cls, classname, bases, dict_)
if "__dialect__" in dict_:
- _impls[dict_["__dialect__"]] = cls
+ _impls[dict_["__dialect__"]] = cls # type: ignore[assignment]
return newtype
-_impls: dict = {}
+_impls: Dict[str, Type["DefaultImpl"]] = {}
Params = namedtuple("Params", ["token0", "tokens", "args", "kwargs"])
@@ -98,7 +95,7 @@ class DefaultImpl(metaclass=ImplMeta):
connection: Optional["Connection"],
as_sql: bool,
transactional_ddl: Optional[bool],
- output_buffer: Optional["StringIO"],
+ output_buffer: Optional["TextIO"],
context_opts: Dict[str, Any],
) -> None:
self.dialect = dialect
@@ -119,7 +116,7 @@ class DefaultImpl(metaclass=ImplMeta):
)
@classmethod
- def get_by_dialect(cls, dialect: "Dialect") -> Any:
+ def get_by_dialect(cls, dialect: "Dialect") -> Type["DefaultImpl"]:
return _impls[dialect.name]
def static_output(self, text: str) -> None:
@@ -158,10 +155,10 @@ class DefaultImpl(metaclass=ImplMeta):
def _exec(
self,
construct: Union["ClauseElement", str],
- execution_options: None = None,
+ execution_options: Optional[dict] = None,
multiparams: Sequence[dict] = (),
params: Dict[str, int] = util.immutabledict(),
- ) -> Optional[Union["LegacyCursorResult", "CursorResult"]]:
+ ) -> Optional["CursorResult"]:
if isinstance(construct, str):
construct = text(construct)
if self.as_sql:
@@ -176,10 +173,11 @@ class DefaultImpl(metaclass=ImplMeta):
else:
compile_kw = {}
+ compiled = construct.compile(
+ dialect=self.dialect, **compile_kw # type: ignore[arg-type]
+ )
self.static_output(
- str(construct.compile(dialect=self.dialect, **compile_kw))
- .replace("\t", " ")
- .strip()
+ str(compiled).replace("\t", " ").strip()
+ self.command_terminator
)
return None
@@ -192,11 +190,13 @@ class DefaultImpl(metaclass=ImplMeta):
assert isinstance(multiparams, tuple)
multiparams += (params,)
- return conn.execute(construct, multiparams)
+ return conn.execute( # type: ignore[call-overload]
+ construct, multiparams
+ )
def execute(
self,
- sql: Union["Update", "TextClause", str],
+ sql: Union["ClauseElement", str],
execution_options: None = None,
) -> None:
self._exec(sql, execution_options)
@@ -424,9 +424,6 @@ class DefaultImpl(metaclass=ImplMeta):
)
)
else:
- # work around http://www.sqlalchemy.org/trac/ticket/2461
- if not hasattr(table, "_autoincrement_column"):
- table._autoincrement_column = None
if rows:
if multiinsert:
self._exec(
@@ -572,7 +569,7 @@ class DefaultImpl(metaclass=ImplMeta):
)
def render_ddl_sql_expr(
- self, expr: "ClauseElement", is_server_default: bool = False, **kw
+ self, expr: "ClauseElement", is_server_default: bool = False, **kw: Any
) -> str:
"""Render a SQL expression that is typically a server default,
index expression, etc.
@@ -581,10 +578,14 @@ class DefaultImpl(metaclass=ImplMeta):
"""
- compile_kw = dict(
- compile_kwargs={"literal_binds": True, "include_table": False}
+ compile_kw = {
+ "compile_kwargs": {"literal_binds": True, "include_table": False}
+ }
+ return str(
+ expr.compile(
+ dialect=self.dialect, **compile_kw # type: ignore[arg-type]
+ )
)
- return str(expr.compile(dialect=self.dialect, **compile_kw))
def _compat_autogen_column_reflect(
self, inspector: "Inspector"
diff --git a/alembic/ddl/mssql.py b/alembic/ddl/mssql.py
index b48f8ba..28f0678 100644
--- a/alembic/ddl/mssql.py
+++ b/alembic/ddl/mssql.py
@@ -35,7 +35,6 @@ if TYPE_CHECKING:
from sqlalchemy.dialects.mssql.base import MSDDLCompiler
from sqlalchemy.dialects.mssql.base import MSSQLCompiler
from sqlalchemy.engine.cursor import CursorResult
- from sqlalchemy.engine.cursor import LegacyCursorResult
from sqlalchemy.sql.schema import Index
from sqlalchemy.sql.schema import Table
from sqlalchemy.sql.selectable import TableClause
@@ -68,9 +67,7 @@ class MSSQLImpl(DefaultImpl):
"mssql_batch_separator", self.batch_separator
)
- def _exec(
- self, construct: Any, *args, **kw
- ) -> Optional[Union["LegacyCursorResult", "CursorResult"]]:
+ def _exec(self, construct: Any, *args, **kw) -> Optional["CursorResult"]:
result = super(MSSQLImpl, self)._exec(construct, *args, **kw)
if self.as_sql and self.batch_separator:
self.static_output(self.batch_separator)
@@ -359,7 +356,7 @@ def visit_column_nullable(
return "%s %s %s %s" % (
alter_table(compiler, element.table_name, element.schema),
alter_column(compiler, element.column_name),
- format_type(compiler, element.existing_type),
+ format_type(compiler, element.existing_type), # type: ignore[arg-type]
"NULL" if element.nullable else "NOT NULL",
)
diff --git a/alembic/ddl/oracle.py b/alembic/ddl/oracle.py
index 0e787fb..accd1fc 100644
--- a/alembic/ddl/oracle.py
+++ b/alembic/ddl/oracle.py
@@ -3,7 +3,6 @@ from __future__ import annotations
from typing import Any
from typing import Optional
from typing import TYPE_CHECKING
-from typing import Union
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql import sqltypes
@@ -26,7 +25,6 @@ from .impl import DefaultImpl
if TYPE_CHECKING:
from sqlalchemy.dialects.oracle.base import OracleDDLCompiler
from sqlalchemy.engine.cursor import CursorResult
- from sqlalchemy.engine.cursor import LegacyCursorResult
from sqlalchemy.sql.schema import Column
@@ -48,9 +46,7 @@ class OracleImpl(DefaultImpl):
"oracle_batch_separator", self.batch_separator
)
- def _exec(
- self, construct: Any, *args, **kw
- ) -> Optional[Union["LegacyCursorResult", "CursorResult"]]:
+ def _exec(self, construct: Any, *args, **kw) -> Optional["CursorResult"]:
result = super(OracleImpl, self)._exec(construct, *args, **kw)
if self.as_sql and self.batch_separator:
self.static_output(self.batch_separator)