diff options
author | CaselIT <cfederico87@gmail.com> | 2021-04-18 15:44:50 +0200 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2021-08-11 15:04:56 -0400 |
commit | 6aad68605f510e8b51f42efa812e02b3831d6e33 (patch) | |
tree | cc0e98b8ad8245add8692d8e4910faf57abf7ae3 /alembic/util | |
parent | 3bf6a326c0a11e4f05c94008709d6b0b8e9e051a (diff) | |
download | alembic-6aad68605f510e8b51f42efa812e02b3831d6e33.tar.gz |
Add pep-484 type annotations
pep-484 type annotations have been added throughout the library. This
should be helpful in providing Mypy and IDE support, however there is not
full support for Alembic's dynamically modified "op" namespace as of yet; a
future release will likely modify the approach used for importing this
namespace to be better compatible with pep-484 capabilities.
Type originally created using MonkeyType
Add types extracted with the MonkeyType https://github.com/instagram/MonkeyType
library by running the unit tests using ``monkeytype run -m pytest tests``, then
``monkeytype apply <module>`` (see below for further details).
USed MonkeyType version 20.5 on Python 3.8, since newer version have issues
After applying the types, the new imports are placed in a ``TYPE_CHECKING`` guard
and all type definition of non base types are deferred by using the string notation.
NOTE: since to apply the types MonkeType need to import the module, also the test
ones, the patch below mocks the setup done by pytest so that the tests could be
correctly imported
diff --git a/alembic/testing/__init__.py b/alembic/testing/__init__.py
index bdd1746..b1090c7 100644
Change-Id: Iff93628f4b43c740848871ce077a118db5e75d41
--- a/alembic/testing/__init__.py
+++ b/alembic/testing/__init__.py
@@ -9,6 +9,12 @@ from sqlalchemy.testing.config import combinations
from sqlalchemy.testing.config import fixture
from sqlalchemy.testing.config import requirements as requires
+from sqlalchemy.testing.plugin.pytestplugin import PytestFixtureFunctions
+from sqlalchemy.testing.plugin.plugin_base import _setup_requirements
+
+config._fixture_functions = PytestFixtureFunctions()
+_setup_requirements("tests.requirements:DefaultRequirements")
+
from alembic import util
from .assertions import assert_raises
from .assertions import assert_raises_message
Currently I'm using this branch of the sqlalchemy stubs:
https://github.com/sqlalchemy/sqlalchemy2-stubs/tree/alembic_updates
Change-Id: I8fd0700aab1913f395302626b8b84fea60334abd
Diffstat (limited to 'alembic/util')
-rw-r--r-- | alembic/util/compat.py | 42 | ||||
-rw-r--r-- | alembic/util/editor.py | 20 | ||||
-rw-r--r-- | alembic/util/langhelpers.py | 66 | ||||
-rw-r--r-- | alembic/util/messaging.py | 23 | ||||
-rw-r--r-- | alembic/util/pyfiles.py | 27 | ||||
-rw-r--r-- | alembic/util/sqla_compat.py | 149 |
6 files changed, 219 insertions, 108 deletions
diff --git a/alembic/util/compat.py b/alembic/util/compat.py index 0fdd86d..a07813c 100644 --- a/alembic/util/compat.py +++ b/alembic/util/compat.py @@ -2,6 +2,12 @@ import collections import inspect import io import os +from typing import Any +from typing import Callable +from typing import Dict +from typing import List +from typing import Optional +from typing import Type is_posix = os.name == "posix" @@ -10,11 +16,11 @@ ArgSpec = collections.namedtuple( ) -def inspect_getargspec(func): +def inspect_getargspec(func: Callable) -> ArgSpec: """getargspec based on fully vendored getfullargspec from Python 3.3.""" if inspect.ismethod(func): - func = func.__func__ + func = func.__func__ # type: ignore if not inspect.isfunction(func): raise TypeError("{!r} is not a Python function".format(func)) @@ -36,7 +42,7 @@ def inspect_getargspec(func): if co.co_flags & inspect.CO_VARKEYWORDS: varkw = co.co_varnames[nargs] - return ArgSpec(args, varargs, varkw, func.__defaults__) + return ArgSpec(args, varargs, varkw, func.__defaults__) # type: ignore string_types = (str,) @@ -57,20 +63,20 @@ def _formatannotation(annotation, base_module=None): def inspect_formatargspec( - args, - varargs=None, - varkw=None, - defaults=None, - kwonlyargs=(), - kwonlydefaults={}, - annotations={}, - formatarg=str, - formatvarargs=lambda name: "*" + name, - formatvarkw=lambda name: "**" + name, - formatvalue=lambda value: "=" + repr(value), - formatreturns=lambda text: " -> " + text, - formatannotation=_formatannotation, -): + args: List[str], + varargs: Optional[str] = None, + varkw: Optional[str] = None, + defaults: Optional[Any] = None, + kwonlyargs: tuple = (), + kwonlydefaults: Dict[Any, Any] = {}, + annotations: Dict[Any, Any] = {}, + formatarg: Type[str] = str, + formatvarargs: Callable = lambda name: "*" + name, + formatvarkw: Callable = lambda name: "**" + name, + formatvalue: Callable = lambda value: "=" + repr(value), + formatreturns: Callable = lambda text: " -> " + text, + formatannotation: Callable = _formatannotation, +) -> str: """Copy formatargspec from python 3.7 standard library. Python 3 has deprecated formatargspec and requested that Signature @@ -118,5 +124,5 @@ def inspect_formatargspec( # into a given buffer, but doesn't close it. # not sure of a more idiomatic approach to this. class EncodedIO(io.TextIOWrapper): - def close(self): + def close(self) -> None: pass diff --git a/alembic/util/editor.py b/alembic/util/editor.py index c27f0f3..ba376c0 100644 --- a/alembic/util/editor.py +++ b/alembic/util/editor.py @@ -3,12 +3,18 @@ from os.path import exists from os.path import join from os.path import splitext from subprocess import check_call +from typing import Dict +from typing import List +from typing import Mapping +from typing import Optional from .compat import is_posix from .exc import CommandError -def open_in_editor(filename, environ=None): +def open_in_editor( + filename: str, environ: Optional[Dict[str, str]] = None +) -> None: """ Opens the given file in a text editor. If the environment variable ``EDITOR`` is set, this is taken as preference. @@ -22,15 +28,15 @@ def open_in_editor(filename, environ=None): :param environ: An optional drop-in replacement for ``os.environ``. Used mainly for testing. """ - + env = os.environ if environ is None else environ try: - editor = _find_editor(environ) + editor = _find_editor(env) check_call([editor, filename]) except Exception as exc: raise CommandError("Error executing editor (%s)" % (exc,)) from exc -def _find_editor(environ=None): +def _find_editor(environ: Mapping[str, str]) -> str: candidates = _default_editors() for i, var in enumerate(("EDITOR", "VISUAL")): if var in environ: @@ -50,7 +56,9 @@ def _find_editor(environ=None): ) -def _find_executable(candidate, environ): +def _find_executable( + candidate: str, environ: Mapping[str, str] +) -> Optional[str]: # Assuming this is on the PATH, we need to determine it's absolute # location. Otherwise, ``check_call`` will fail if not is_posix and splitext(candidate)[1] != ".exe": @@ -62,7 +70,7 @@ def _find_executable(candidate, environ): return None -def _default_editors(): +def _default_editors() -> List[str]: # Look for an editor. Prefer the user's choice by env-var, fall back to # most commonly installed editor (nano/vim) if is_posix: diff --git a/alembic/util/langhelpers.py b/alembic/util/langhelpers.py index dbd1f21..87a9aca 100644 --- a/alembic/util/langhelpers.py +++ b/alembic/util/langhelpers.py @@ -1,6 +1,16 @@ import collections from collections.abc import Iterable import textwrap +from typing import Any +from typing import Callable +from typing import Dict +from typing import List +from typing import Optional +from typing import overload +from typing import Sequence +from typing import Tuple +from typing import TypeVar +from typing import Union import uuid import warnings @@ -14,10 +24,13 @@ from .compat import inspect_getargspec from .compat import string_types +_T = TypeVar("_T") + + class _ModuleClsMeta(type): - def __setattr__(cls, key, value): + def __setattr__(cls, key: str, value: Callable) -> None: super(_ModuleClsMeta, cls).__setattr__(key, value) - cls._update_module_proxies(key) + cls._update_module_proxies(key) # type: ignore class ModuleClsProxy(metaclass=_ModuleClsMeta): @@ -29,22 +42,24 @@ class ModuleClsProxy(metaclass=_ModuleClsMeta): """ - _setups = collections.defaultdict(lambda: (set(), [])) + _setups: Dict[type, Tuple[set, list]] = collections.defaultdict( + lambda: (set(), []) + ) @classmethod - def _update_module_proxies(cls, name): + def _update_module_proxies(cls, name: str) -> None: attr_names, modules = cls._setups[cls] for globals_, locals_ in modules: cls._add_proxied_attribute(name, globals_, locals_, attr_names) - def _install_proxy(self): + def _install_proxy(self) -> None: attr_names, modules = self._setups[self.__class__] for globals_, locals_ in modules: globals_["_proxy"] = self for attr_name in attr_names: globals_[attr_name] = getattr(self, attr_name) - def _remove_proxy(self): + def _remove_proxy(self) -> None: attr_names, modules = self._setups[self.__class__] for globals_, locals_ in modules: globals_["_proxy"] = None @@ -171,10 +186,25 @@ def _with_legacy_names(translations): return decorate -def rev_id(): +def rev_id() -> str: return uuid.uuid4().hex[-12:] +@overload +def to_tuple(x: Any, default: tuple) -> tuple: + ... + + +@overload +def to_tuple(x: None, default: _T = None) -> _T: + ... + + +@overload +def to_tuple(x: Any, default: Optional[tuple] = None) -> tuple: + ... + + def to_tuple(x, default=None): if x is None: return default @@ -186,16 +216,18 @@ def to_tuple(x, default=None): return (x,) -def dedupe_tuple(tup): +def dedupe_tuple(tup: Tuple[str, ...]) -> Tuple[str, ...]: return tuple(unique_list(tup)) class Dispatcher: - def __init__(self, uselist=False): - self._registry = {} + def __init__(self, uselist: bool = False) -> None: + self._registry: Dict[tuple, Any] = {} self.uselist = uselist - def dispatch_for(self, target, qualifier="default"): + def dispatch_for( + self, target: Any, qualifier: str = "default" + ) -> Callable: def decorate(fn): if self.uselist: self._registry.setdefault((target, qualifier), []).append(fn) @@ -206,10 +238,10 @@ class Dispatcher: return decorate - def dispatch(self, obj, qualifier="default"): + def dispatch(self, obj: Any, qualifier: str = "default") -> Any: if isinstance(obj, string_types): - targets = [obj] + targets: Sequence = [obj] elif isinstance(obj, type): targets = obj.__mro__ else: @@ -223,7 +255,9 @@ class Dispatcher: else: raise ValueError("no dispatch function for object: %s" % obj) - def _fn_or_list(self, fn_or_list): + def _fn_or_list( + self, fn_or_list: Union[List[Callable], Callable] + ) -> Callable: if self.uselist: def go(*arg, **kw): @@ -232,9 +266,9 @@ class Dispatcher: return go else: - return fn_or_list + return fn_or_list # type: ignore - def branch(self): + def branch(self) -> "Dispatcher": """Return a copy of this dispatcher that is independently writable.""" diff --git a/alembic/util/messaging.py b/alembic/util/messaging.py index 70c9128..062890a 100644 --- a/alembic/util/messaging.py +++ b/alembic/util/messaging.py @@ -2,6 +2,11 @@ from collections.abc import Iterable import logging import sys import textwrap +from typing import Any +from typing import Callable +from typing import Optional +from typing import TextIO +from typing import Union import warnings from sqlalchemy.engine import url @@ -29,7 +34,7 @@ except (ImportError, IOError): TERMWIDTH = None -def write_outstream(stream, *text): +def write_outstream(stream: TextIO, *text) -> None: encoding = getattr(stream, "encoding", "ascii") or "ascii" for t in text: if not isinstance(t, binary_type): @@ -44,7 +49,7 @@ def write_outstream(stream, *text): break -def status(_statmsg, fn, *arg, **kw): +def status(_statmsg: str, fn: Callable, *arg, **kw) -> Any: newline = kw.pop("newline", False) msg(_statmsg + " ...", newline, True) try: @@ -56,27 +61,27 @@ def status(_statmsg, fn, *arg, **kw): raise -def err(message): +def err(message: str): log.error(message) msg("FAILED: %s" % message) sys.exit(-1) -def obfuscate_url_pw(u): - u = url.make_url(u) +def obfuscate_url_pw(input_url: str) -> str: + u = url.make_url(input_url) if u.password: if sqla_compat.sqla_14: u = u.set(password="XXXXX") else: - u.password = "XXXXX" + u.password = "XXXXX" # type: ignore[misc] return str(u) -def warn(msg, stacklevel=2): +def warn(msg: str, stacklevel: int = 2) -> None: warnings.warn(msg, UserWarning, stacklevel=stacklevel) -def msg(msg, newline=True, flush=False): +def msg(msg: str, newline: bool = True, flush: bool = False) -> None: if TERMWIDTH is None: write_outstream(sys.stdout, msg) if newline: @@ -92,7 +97,7 @@ def msg(msg, newline=True, flush=False): sys.stdout.flush() -def format_as_comma(value): +def format_as_comma(value: Optional[Union[str, "Iterable[str]"]]) -> str: if value is None: return "" elif isinstance(value, string_types): diff --git a/alembic/util/pyfiles.py b/alembic/util/pyfiles.py index 53cc3cc..7eb582e 100644 --- a/alembic/util/pyfiles.py +++ b/alembic/util/pyfiles.py @@ -4,6 +4,7 @@ import importlib.util import os import re import tempfile +from typing import Optional from mako import exceptions from mako.template import Template @@ -11,7 +12,9 @@ from mako.template import Template from .exc import CommandError -def template_to_file(template_file, dest, output_encoding, **kw): +def template_to_file( + template_file: str, dest: str, output_encoding: str, **kw +) -> None: template = Template(filename=template_file) try: output = template.render_unicode(**kw).encode(output_encoding) @@ -32,7 +35,7 @@ def template_to_file(template_file, dest, output_encoding, **kw): f.write(output) -def coerce_resource_to_filename(fname): +def coerce_resource_to_filename(fname: str) -> str: """Interpret a filename as either a filesystem location or as a package resource. @@ -47,7 +50,7 @@ def coerce_resource_to_filename(fname): return fname -def pyc_file_from_path(path): +def pyc_file_from_path(path: str) -> Optional[str]: """Given a python source path, locate the .pyc.""" candidate = importlib.util.cache_from_source(path) @@ -64,7 +67,7 @@ def pyc_file_from_path(path): return None -def load_python_file(dir_, filename): +def load_python_file(dir_: str, filename: str): """Load a file from the given path as a Python module.""" module_id = re.sub(r"\W", "_", filename) @@ -78,21 +81,15 @@ def load_python_file(dir_, filename): if pyc_path is None: raise ImportError("Can't find Python file %s" % path) else: - module = load_module_pyc(module_id, pyc_path) + module = load_module_py(module_id, pyc_path) elif ext in (".pyc", ".pyo"): - module = load_module_pyc(module_id, path) + module = load_module_py(module_id, path) return module -def load_module_py(module_id, path): +def load_module_py(module_id: str, path: str): spec = importlib.util.spec_from_file_location(module_id, path) + assert spec module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) - return module - - -def load_module_pyc(module_id, path): - spec = importlib.util.spec_from_file_location(module_id, path) - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) + spec.loader.exec_module(module) # type: ignore return module diff --git a/alembic/util/sqla_compat.py b/alembic/util/sqla_compat.py index a04ab2e..e1ccd41 100644 --- a/alembic/util/sqla_compat.py +++ b/alembic/util/sqla_compat.py @@ -1,5 +1,11 @@ import contextlib import re +from typing import Iterator +from typing import Mapping +from typing import Optional +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union from sqlalchemy import __version__ from sqlalchemy import inspect @@ -12,15 +18,34 @@ from sqlalchemy.schema import CheckConstraint from sqlalchemy.schema import Column from sqlalchemy.schema import ForeignKeyConstraint from sqlalchemy.sql import visitors +from sqlalchemy.sql.elements import BindParameter from sqlalchemy.sql.elements import quoted_name -from sqlalchemy.sql.expression import _BindParamClause -from sqlalchemy.sql.expression import _TextClause as TextClause +from sqlalchemy.sql.elements import TextClause from sqlalchemy.sql.visitors import traverse from . import compat - -def _safe_int(value): +if TYPE_CHECKING: + from sqlalchemy import Index + from sqlalchemy import Table + from sqlalchemy.engine import Connection + from sqlalchemy.engine import Dialect + from sqlalchemy.engine import Transaction + from sqlalchemy.engine.reflection import Inspector + from sqlalchemy.sql.base import ColumnCollection + from sqlalchemy.sql.compiler import SQLCompiler + from sqlalchemy.sql.dml import Insert + from sqlalchemy.sql.elements import ColumnClause + from sqlalchemy.sql.elements import ColumnElement + from sqlalchemy.sql.schema import Constraint + from sqlalchemy.sql.schema import SchemaItem + from sqlalchemy.sql.selectable import Select + from sqlalchemy.sql.selectable import TableClause + +_CE = TypeVar("_CE", bound=Union["ColumnElement", "SchemaItem"]) + + +def _safe_int(value: str) -> Union[int, str]: try: return int(value) except: @@ -36,6 +61,7 @@ sqla_14 = _vers >= (1, 4) try: from sqlalchemy import Computed # noqa except ImportError: + Computed = None # type: ignore has_computed = False has_computed_reflection = False else: @@ -45,6 +71,7 @@ else: try: from sqlalchemy import Identity # noqa except ImportError: + Identity = None # type: ignore has_identity = False else: # attributes common to Indentity and Sequence @@ -67,21 +94,26 @@ AUTOINCREMENT_DEFAULT = "auto" @contextlib.contextmanager -def _ensure_scope_for_ddl(connection): +def _ensure_scope_for_ddl( + connection: Optional["Connection"], +) -> Iterator[None]: try: - in_transaction = connection.in_transaction + in_transaction = connection.in_transaction # type: ignore[union-attr] except AttributeError: - # catch for MockConnection + # catch for MockConnection, None yield else: if not in_transaction(): + assert connection is not None with connection.begin(): yield else: yield -def _safe_begin_connection_transaction(connection): +def _safe_begin_connection_transaction( + connection: "Connection", +) -> "Transaction": transaction = _get_connection_transaction(connection) if transaction: return transaction @@ -89,9 +121,9 @@ def _safe_begin_connection_transaction(connection): return connection.begin() -def _get_connection_in_transaction(connection): +def _get_connection_in_transaction(connection: Optional["Connection"]) -> bool: try: - in_transaction = connection.in_transaction + in_transaction = connection.in_transaction # type: ignore except AttributeError: # catch for MockConnection return False @@ -99,28 +131,33 @@ def _get_connection_in_transaction(connection): return in_transaction() -def _copy(schema_item, **kw): +def _copy(schema_item: _CE, **kw) -> _CE: if hasattr(schema_item, "_copy"): return schema_item._copy(**kw) else: return schema_item.copy(**kw) -def _get_connection_transaction(connection): +def _get_connection_transaction( + connection: "Connection", +) -> Optional["Transaction"]: if sqla_14: return connection.get_transaction() else: - return connection._root._Connection__transaction + r = connection._root # type: ignore[attr-defined] + return r._Connection__transaction -def _create_url(*arg, **kw): +def _create_url(*arg, **kw) -> url.URL: if hasattr(url.URL, "create"): return url.URL.create(*arg, **kw) else: return url.URL(*arg, **kw) -def _connectable_has_table(connectable, tablename, schemaname): +def _connectable_has_table( + connectable: "Connection", tablename: str, schemaname: Union[str, None] +) -> bool: if sqla_14: return inspect(connectable).has_table(tablename, schemaname) else: @@ -148,23 +185,25 @@ def _nullability_might_be_unset(metadata_column): ) -def _server_default_is_computed(*server_default): +def _server_default_is_computed(*server_default) -> bool: if not has_computed: return False else: return any(isinstance(sd, Computed) for sd in server_default) -def _server_default_is_identity(*server_default): +def _server_default_is_identity(*server_default) -> bool: if not sqla_14: return False else: return any(isinstance(sd, Identity) for sd in server_default) -def _table_for_constraint(constraint): +def _table_for_constraint(constraint: "Constraint") -> "Table": if isinstance(constraint, ForeignKeyConstraint): - return constraint.parent + table = constraint.parent + assert table is not None + return table else: return constraint.table @@ -178,7 +217,9 @@ def _columns_for_constraint(constraint): return list(constraint.columns) -def _reflect_table(inspector, table, include_cols): +def _reflect_table( + inspector: "Inspector", table: "Table", include_cols: None +) -> None: if sqla_14: return inspector.reflect_table(table, None) else: @@ -213,19 +254,20 @@ def _fk_spec(constraint): ) -def _fk_is_self_referential(constraint): - spec = constraint.elements[0]._get_colspec() +def _fk_is_self_referential(constraint: "ForeignKeyConstraint") -> bool: + spec = constraint.elements[0]._get_colspec() # type: ignore[attr-defined] tokens = spec.split(".") tokens.pop(-1) # colname tablekey = ".".join(tokens) + assert constraint.parent is not None return tablekey == constraint.parent.key -def _is_type_bound(constraint): +def _is_type_bound(constraint: "Constraint") -> bool: # this deals with SQLAlchemy #3260, don't copy CHECK constraints # that will be generated by the type. # new feature added for #3260 - return constraint._type_bound + return constraint._type_bound # type: ignore[attr-defined] def _find_columns(clause): @@ -236,16 +278,21 @@ def _find_columns(clause): return cols -def _remove_column_from_collection(collection, column): +def _remove_column_from_collection( + collection: "ColumnCollection", column: Union["Column", "ColumnClause"] +) -> None: """remove a column from a ColumnCollection.""" # workaround for older SQLAlchemy, remove the # same object that's present + assert column.key is not None to_remove = collection[column.key] collection.remove(to_remove) -def _textual_index_column(table, text_): +def _textual_index_column( + table: "Table", text_: Union[str, "TextClause", "ColumnElement"] +) -> Union["ColumnElement", "Column"]: """a workaround for the Index construct's severe lack of flexibility""" if isinstance(text_, compat.string_types): c = Column(text_, sqltypes.NULLTYPE) @@ -259,7 +306,7 @@ def _textual_index_column(table, text_): raise ValueError("String or text() construct expected") -def _copy_expression(expression, target_table): +def _copy_expression(expression: _CE, target_table: "Table") -> _CE: def replace(col): if ( isinstance(col, Column) @@ -296,7 +343,7 @@ class _textual_index_element(sql.ColumnElement): __visit_name__ = "_textual_idx_element" - def __init__(self, table, text): + def __init__(self, table: "Table", text: "TextClause") -> None: self.table = table self.text = text self.key = text.text @@ -308,16 +355,20 @@ class _textual_index_element(sql.ColumnElement): @compiles(_textual_index_element) -def _render_textual_index_column(element, compiler, **kw): +def _render_textual_index_column( + element: _textual_index_element, compiler: "SQLCompiler", **kw +) -> str: return compiler.process(element.text, **kw) -class _literal_bindparam(_BindParamClause): +class _literal_bindparam(BindParameter): pass @compiles(_literal_bindparam) -def _render_literal_bindparam(element, compiler, **kw): +def _render_literal_bindparam( + element: _literal_bindparam, compiler: "SQLCompiler", **kw +) -> str: return compiler.render_literal_bindparam(element, **kw) @@ -329,17 +380,20 @@ def _get_index_column_names(idx): return [getattr(exp, "name", None) for exp in _get_index_expressions(idx)] -def _column_kwargs(col): +def _column_kwargs(col: "Column") -> Mapping: if sqla_13: return col.kwargs else: return {} -def _get_constraint_final_name(constraint, dialect): +def _get_constraint_final_name( + constraint: Union["Index", "Constraint"], dialect: Optional["Dialect"] +) -> Optional[str]: if constraint.name is None: return None - elif sqla_14: + assert dialect is not None + if sqla_14: # for SQLAlchemy 1.4 we would like to have the option to expand # the use of "deferred" names for constraints as well as to have # some flexibility with "None" name and similar; make use of new @@ -355,7 +409,7 @@ def _get_constraint_final_name(constraint, dialect): if hasattr(constraint.name, "quote"): # might be quoted_name, might be truncated_name, keep it the # same - quoted_name_cls = type(constraint.name) + quoted_name_cls: type = type(constraint.name) else: quoted_name_cls = quoted_name @@ -364,7 +418,8 @@ def _get_constraint_final_name(constraint, dialect): if isinstance(constraint, schema.Index): # name should not be quoted. - return dialect.ddl_compiler(dialect, None)._prepared_index_name( + d = dialect.ddl_compiler(dialect, None) + return d._prepared_index_name( # type: ignore[attr-defined] constraint ) else: @@ -372,10 +427,13 @@ def _get_constraint_final_name(constraint, dialect): return dialect.identifier_preparer.format_constraint(constraint) -def _constraint_is_named(constraint, dialect): +def _constraint_is_named( + constraint: Union["Constraint", "Index"], dialect: Optional["Dialect"] +) -> bool: if sqla_14: if constraint.name is None: return False + assert dialect is not None name = dialect.identifier_preparer.format_constraint( constraint, _alembic_quote=False ) @@ -384,18 +442,21 @@ def _constraint_is_named(constraint, dialect): return constraint.name is not None -def _is_mariadb(mysql_dialect): +def _is_mariadb(mysql_dialect: "Dialect") -> bool: if sqla_14: - return mysql_dialect.is_mariadb + return mysql_dialect.is_mariadb # type: ignore[attr-defined] else: - return mysql_dialect.server_version_info and mysql_dialect._is_mariadb + return bool( + mysql_dialect.server_version_info + and mysql_dialect._is_mariadb # type: ignore[attr-defined] + ) def _mariadb_normalized_version_info(mysql_dialect): return mysql_dialect._mariadb_normalized_version_info -def _insert_inline(table): +def _insert_inline(table: Union["TableClause", "Table"]) -> "Insert": if sqla_14: return table.insert().inline() else: @@ -408,10 +469,10 @@ if sqla_14: else: from sqlalchemy import create_engine - def create_mock_engine(url, executor): + def create_mock_engine(url, executor, **kw): # type: ignore[misc] return create_engine( "postgresql://", strategy="mock", executor=executor ) - def _select(*columns): - return sql.select(list(columns)) + def _select(*columns, **kw) -> "Select": + return sql.select(list(columns), **kw) |