diff options
author | CaselIT <cfederico87@gmail.com> | 2022-09-11 22:38:46 +0200 |
---|---|---|
committer | CaselIT <cfederico87@gmail.com> | 2022-09-12 21:00:52 +0200 |
commit | 0e83fddb6c110bf1658564c248ffad9163a365fa (patch) | |
tree | 9b61207cd621fb526cda1daa55eaf94bd50aeb25 /alembic/ddl | |
parent | 747ec301529cf2a08d56d3596aedbf54a93b8742 (diff) | |
download | alembic-0e83fddb6c110bf1658564c248ffad9163a365fa.tar.gz |
Improve typing
Change-Id: I9fc86c4a92e1b76d19c9e891ff08ce8a46ad4e35
Diffstat (limited to 'alembic/ddl')
-rw-r--r-- | alembic/ddl/base.py | 2 | ||||
-rw-r--r-- | alembic/ddl/impl.py | 45 | ||||
-rw-r--r-- | alembic/ddl/mssql.py | 7 | ||||
-rw-r--r-- | alembic/ddl/oracle.py | 6 |
4 files changed, 27 insertions, 33 deletions
diff --git a/alembic/ddl/base.py b/alembic/ddl/base.py index 7b0f63e..c910786 100644 --- a/alembic/ddl/base.py +++ b/alembic/ddl/base.py @@ -294,7 +294,7 @@ def format_table_name( def format_column_name( compiler: "DDLCompiler", name: Optional[Union["quoted_name", str]] ) -> Union["quoted_name", str]: - return compiler.preparer.quote(name) + return compiler.preparer.quote(name) # type: ignore[arg-type] def format_server_default( diff --git a/alembic/ddl/impl.py b/alembic/ddl/impl.py index 070c124..79d5245 100644 --- a/alembic/ddl/impl.py +++ b/alembic/ddl/impl.py @@ -23,19 +23,16 @@ from .. import util from ..util import sqla_compat if TYPE_CHECKING: - from io import StringIO from typing import Literal + from typing import TextIO from sqlalchemy.engine import Connection from sqlalchemy.engine import Dialect from sqlalchemy.engine.cursor import CursorResult - from sqlalchemy.engine.cursor import LegacyCursorResult from sqlalchemy.engine.reflection import Inspector - from sqlalchemy.sql.dml import Update from sqlalchemy.sql.elements import ClauseElement from sqlalchemy.sql.elements import ColumnElement from sqlalchemy.sql.elements import quoted_name - from sqlalchemy.sql.elements import TextClause from sqlalchemy.sql.schema import Column from sqlalchemy.sql.schema import Constraint from sqlalchemy.sql.schema import ForeignKeyConstraint @@ -60,11 +57,11 @@ class ImplMeta(type): ): newtype = type.__init__(cls, classname, bases, dict_) if "__dialect__" in dict_: - _impls[dict_["__dialect__"]] = cls + _impls[dict_["__dialect__"]] = cls # type: ignore[assignment] return newtype -_impls: dict = {} +_impls: Dict[str, Type["DefaultImpl"]] = {} Params = namedtuple("Params", ["token0", "tokens", "args", "kwargs"]) @@ -98,7 +95,7 @@ class DefaultImpl(metaclass=ImplMeta): connection: Optional["Connection"], as_sql: bool, transactional_ddl: Optional[bool], - output_buffer: Optional["StringIO"], + output_buffer: Optional["TextIO"], context_opts: Dict[str, Any], ) -> None: self.dialect = dialect @@ -119,7 +116,7 @@ class DefaultImpl(metaclass=ImplMeta): ) @classmethod - def get_by_dialect(cls, dialect: "Dialect") -> Any: + def get_by_dialect(cls, dialect: "Dialect") -> Type["DefaultImpl"]: return _impls[dialect.name] def static_output(self, text: str) -> None: @@ -158,10 +155,10 @@ class DefaultImpl(metaclass=ImplMeta): def _exec( self, construct: Union["ClauseElement", str], - execution_options: None = None, + execution_options: Optional[dict] = None, multiparams: Sequence[dict] = (), params: Dict[str, int] = util.immutabledict(), - ) -> Optional[Union["LegacyCursorResult", "CursorResult"]]: + ) -> Optional["CursorResult"]: if isinstance(construct, str): construct = text(construct) if self.as_sql: @@ -176,10 +173,11 @@ class DefaultImpl(metaclass=ImplMeta): else: compile_kw = {} + compiled = construct.compile( + dialect=self.dialect, **compile_kw # type: ignore[arg-type] + ) self.static_output( - str(construct.compile(dialect=self.dialect, **compile_kw)) - .replace("\t", " ") - .strip() + str(compiled).replace("\t", " ").strip() + self.command_terminator ) return None @@ -192,11 +190,13 @@ class DefaultImpl(metaclass=ImplMeta): assert isinstance(multiparams, tuple) multiparams += (params,) - return conn.execute(construct, multiparams) + return conn.execute( # type: ignore[call-overload] + construct, multiparams + ) def execute( self, - sql: Union["Update", "TextClause", str], + sql: Union["ClauseElement", str], execution_options: None = None, ) -> None: self._exec(sql, execution_options) @@ -424,9 +424,6 @@ class DefaultImpl(metaclass=ImplMeta): ) ) else: - # work around http://www.sqlalchemy.org/trac/ticket/2461 - if not hasattr(table, "_autoincrement_column"): - table._autoincrement_column = None if rows: if multiinsert: self._exec( @@ -572,7 +569,7 @@ class DefaultImpl(metaclass=ImplMeta): ) def render_ddl_sql_expr( - self, expr: "ClauseElement", is_server_default: bool = False, **kw + self, expr: "ClauseElement", is_server_default: bool = False, **kw: Any ) -> str: """Render a SQL expression that is typically a server default, index expression, etc. @@ -581,10 +578,14 @@ class DefaultImpl(metaclass=ImplMeta): """ - compile_kw = dict( - compile_kwargs={"literal_binds": True, "include_table": False} + compile_kw = { + "compile_kwargs": {"literal_binds": True, "include_table": False} + } + return str( + expr.compile( + dialect=self.dialect, **compile_kw # type: ignore[arg-type] + ) ) - return str(expr.compile(dialect=self.dialect, **compile_kw)) def _compat_autogen_column_reflect( self, inspector: "Inspector" diff --git a/alembic/ddl/mssql.py b/alembic/ddl/mssql.py index b48f8ba..28f0678 100644 --- a/alembic/ddl/mssql.py +++ b/alembic/ddl/mssql.py @@ -35,7 +35,6 @@ if TYPE_CHECKING: from sqlalchemy.dialects.mssql.base import MSDDLCompiler from sqlalchemy.dialects.mssql.base import MSSQLCompiler from sqlalchemy.engine.cursor import CursorResult - from sqlalchemy.engine.cursor import LegacyCursorResult from sqlalchemy.sql.schema import Index from sqlalchemy.sql.schema import Table from sqlalchemy.sql.selectable import TableClause @@ -68,9 +67,7 @@ class MSSQLImpl(DefaultImpl): "mssql_batch_separator", self.batch_separator ) - def _exec( - self, construct: Any, *args, **kw - ) -> Optional[Union["LegacyCursorResult", "CursorResult"]]: + def _exec(self, construct: Any, *args, **kw) -> Optional["CursorResult"]: result = super(MSSQLImpl, self)._exec(construct, *args, **kw) if self.as_sql and self.batch_separator: self.static_output(self.batch_separator) @@ -359,7 +356,7 @@ def visit_column_nullable( return "%s %s %s %s" % ( alter_table(compiler, element.table_name, element.schema), alter_column(compiler, element.column_name), - format_type(compiler, element.existing_type), + format_type(compiler, element.existing_type), # type: ignore[arg-type] "NULL" if element.nullable else "NOT NULL", ) diff --git a/alembic/ddl/oracle.py b/alembic/ddl/oracle.py index 0e787fb..accd1fc 100644 --- a/alembic/ddl/oracle.py +++ b/alembic/ddl/oracle.py @@ -3,7 +3,6 @@ from __future__ import annotations from typing import Any from typing import Optional from typing import TYPE_CHECKING -from typing import Union from sqlalchemy.ext.compiler import compiles from sqlalchemy.sql import sqltypes @@ -26,7 +25,6 @@ from .impl import DefaultImpl if TYPE_CHECKING: from sqlalchemy.dialects.oracle.base import OracleDDLCompiler from sqlalchemy.engine.cursor import CursorResult - from sqlalchemy.engine.cursor import LegacyCursorResult from sqlalchemy.sql.schema import Column @@ -48,9 +46,7 @@ class OracleImpl(DefaultImpl): "oracle_batch_separator", self.batch_separator ) - def _exec( - self, construct: Any, *args, **kw - ) -> Optional[Union["LegacyCursorResult", "CursorResult"]]: + def _exec(self, construct: Any, *args, **kw) -> Optional["CursorResult"]: result = super(OracleImpl, self)._exec(construct, *args, **kw) if self.as_sql and self.batch_separator: self.static_output(self.batch_separator) |