summaryrefslogtreecommitdiff
path: root/alembic/util
diff options
context:
space:
mode:
authorCaselIT <cfederico87@gmail.com>2021-04-18 15:44:50 +0200
committerMike Bayer <mike_mp@zzzcomputing.com>2021-08-11 15:04:56 -0400
commit6aad68605f510e8b51f42efa812e02b3831d6e33 (patch)
treecc0e98b8ad8245add8692d8e4910faf57abf7ae3 /alembic/util
parent3bf6a326c0a11e4f05c94008709d6b0b8e9e051a (diff)
downloadalembic-6aad68605f510e8b51f42efa812e02b3831d6e33.tar.gz
Add pep-484 type annotations
pep-484 type annotations have been added throughout the library. This should be helpful in providing Mypy and IDE support, however there is not full support for Alembic's dynamically modified "op" namespace as of yet; a future release will likely modify the approach used for importing this namespace to be better compatible with pep-484 capabilities. Type originally created using MonkeyType Add types extracted with the MonkeyType https://github.com/instagram/MonkeyType library by running the unit tests using ``monkeytype run -m pytest tests``, then ``monkeytype apply <module>`` (see below for further details). USed MonkeyType version 20.5 on Python 3.8, since newer version have issues After applying the types, the new imports are placed in a ``TYPE_CHECKING`` guard and all type definition of non base types are deferred by using the string notation. NOTE: since to apply the types MonkeType need to import the module, also the test ones, the patch below mocks the setup done by pytest so that the tests could be correctly imported diff --git a/alembic/testing/__init__.py b/alembic/testing/__init__.py index bdd1746..b1090c7 100644 Change-Id: Iff93628f4b43c740848871ce077a118db5e75d41 --- a/alembic/testing/__init__.py +++ b/alembic/testing/__init__.py @@ -9,6 +9,12 @@ from sqlalchemy.testing.config import combinations from sqlalchemy.testing.config import fixture from sqlalchemy.testing.config import requirements as requires +from sqlalchemy.testing.plugin.pytestplugin import PytestFixtureFunctions +from sqlalchemy.testing.plugin.plugin_base import _setup_requirements + +config._fixture_functions = PytestFixtureFunctions() +_setup_requirements("tests.requirements:DefaultRequirements") + from alembic import util from .assertions import assert_raises from .assertions import assert_raises_message Currently I'm using this branch of the sqlalchemy stubs: https://github.com/sqlalchemy/sqlalchemy2-stubs/tree/alembic_updates Change-Id: I8fd0700aab1913f395302626b8b84fea60334abd
Diffstat (limited to 'alembic/util')
-rw-r--r--alembic/util/compat.py42
-rw-r--r--alembic/util/editor.py20
-rw-r--r--alembic/util/langhelpers.py66
-rw-r--r--alembic/util/messaging.py23
-rw-r--r--alembic/util/pyfiles.py27
-rw-r--r--alembic/util/sqla_compat.py149
6 files changed, 219 insertions, 108 deletions
diff --git a/alembic/util/compat.py b/alembic/util/compat.py
index 0fdd86d..a07813c 100644
--- a/alembic/util/compat.py
+++ b/alembic/util/compat.py
@@ -2,6 +2,12 @@ import collections
import inspect
import io
import os
+from typing import Any
+from typing import Callable
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Type
is_posix = os.name == "posix"
@@ -10,11 +16,11 @@ ArgSpec = collections.namedtuple(
)
-def inspect_getargspec(func):
+def inspect_getargspec(func: Callable) -> ArgSpec:
"""getargspec based on fully vendored getfullargspec from Python 3.3."""
if inspect.ismethod(func):
- func = func.__func__
+ func = func.__func__ # type: ignore
if not inspect.isfunction(func):
raise TypeError("{!r} is not a Python function".format(func))
@@ -36,7 +42,7 @@ def inspect_getargspec(func):
if co.co_flags & inspect.CO_VARKEYWORDS:
varkw = co.co_varnames[nargs]
- return ArgSpec(args, varargs, varkw, func.__defaults__)
+ return ArgSpec(args, varargs, varkw, func.__defaults__) # type: ignore
string_types = (str,)
@@ -57,20 +63,20 @@ def _formatannotation(annotation, base_module=None):
def inspect_formatargspec(
- args,
- varargs=None,
- varkw=None,
- defaults=None,
- kwonlyargs=(),
- kwonlydefaults={},
- annotations={},
- formatarg=str,
- formatvarargs=lambda name: "*" + name,
- formatvarkw=lambda name: "**" + name,
- formatvalue=lambda value: "=" + repr(value),
- formatreturns=lambda text: " -> " + text,
- formatannotation=_formatannotation,
-):
+ args: List[str],
+ varargs: Optional[str] = None,
+ varkw: Optional[str] = None,
+ defaults: Optional[Any] = None,
+ kwonlyargs: tuple = (),
+ kwonlydefaults: Dict[Any, Any] = {},
+ annotations: Dict[Any, Any] = {},
+ formatarg: Type[str] = str,
+ formatvarargs: Callable = lambda name: "*" + name,
+ formatvarkw: Callable = lambda name: "**" + name,
+ formatvalue: Callable = lambda value: "=" + repr(value),
+ formatreturns: Callable = lambda text: " -> " + text,
+ formatannotation: Callable = _formatannotation,
+) -> str:
"""Copy formatargspec from python 3.7 standard library.
Python 3 has deprecated formatargspec and requested that Signature
@@ -118,5 +124,5 @@ def inspect_formatargspec(
# into a given buffer, but doesn't close it.
# not sure of a more idiomatic approach to this.
class EncodedIO(io.TextIOWrapper):
- def close(self):
+ def close(self) -> None:
pass
diff --git a/alembic/util/editor.py b/alembic/util/editor.py
index c27f0f3..ba376c0 100644
--- a/alembic/util/editor.py
+++ b/alembic/util/editor.py
@@ -3,12 +3,18 @@ from os.path import exists
from os.path import join
from os.path import splitext
from subprocess import check_call
+from typing import Dict
+from typing import List
+from typing import Mapping
+from typing import Optional
from .compat import is_posix
from .exc import CommandError
-def open_in_editor(filename, environ=None):
+def open_in_editor(
+ filename: str, environ: Optional[Dict[str, str]] = None
+) -> None:
"""
Opens the given file in a text editor. If the environment variable
``EDITOR`` is set, this is taken as preference.
@@ -22,15 +28,15 @@ def open_in_editor(filename, environ=None):
:param environ: An optional drop-in replacement for ``os.environ``. Used
mainly for testing.
"""
-
+ env = os.environ if environ is None else environ
try:
- editor = _find_editor(environ)
+ editor = _find_editor(env)
check_call([editor, filename])
except Exception as exc:
raise CommandError("Error executing editor (%s)" % (exc,)) from exc
-def _find_editor(environ=None):
+def _find_editor(environ: Mapping[str, str]) -> str:
candidates = _default_editors()
for i, var in enumerate(("EDITOR", "VISUAL")):
if var in environ:
@@ -50,7 +56,9 @@ def _find_editor(environ=None):
)
-def _find_executable(candidate, environ):
+def _find_executable(
+ candidate: str, environ: Mapping[str, str]
+) -> Optional[str]:
# Assuming this is on the PATH, we need to determine it's absolute
# location. Otherwise, ``check_call`` will fail
if not is_posix and splitext(candidate)[1] != ".exe":
@@ -62,7 +70,7 @@ def _find_executable(candidate, environ):
return None
-def _default_editors():
+def _default_editors() -> List[str]:
# Look for an editor. Prefer the user's choice by env-var, fall back to
# most commonly installed editor (nano/vim)
if is_posix:
diff --git a/alembic/util/langhelpers.py b/alembic/util/langhelpers.py
index dbd1f21..87a9aca 100644
--- a/alembic/util/langhelpers.py
+++ b/alembic/util/langhelpers.py
@@ -1,6 +1,16 @@
import collections
from collections.abc import Iterable
import textwrap
+from typing import Any
+from typing import Callable
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import overload
+from typing import Sequence
+from typing import Tuple
+from typing import TypeVar
+from typing import Union
import uuid
import warnings
@@ -14,10 +24,13 @@ from .compat import inspect_getargspec
from .compat import string_types
+_T = TypeVar("_T")
+
+
class _ModuleClsMeta(type):
- def __setattr__(cls, key, value):
+ def __setattr__(cls, key: str, value: Callable) -> None:
super(_ModuleClsMeta, cls).__setattr__(key, value)
- cls._update_module_proxies(key)
+ cls._update_module_proxies(key) # type: ignore
class ModuleClsProxy(metaclass=_ModuleClsMeta):
@@ -29,22 +42,24 @@ class ModuleClsProxy(metaclass=_ModuleClsMeta):
"""
- _setups = collections.defaultdict(lambda: (set(), []))
+ _setups: Dict[type, Tuple[set, list]] = collections.defaultdict(
+ lambda: (set(), [])
+ )
@classmethod
- def _update_module_proxies(cls, name):
+ def _update_module_proxies(cls, name: str) -> None:
attr_names, modules = cls._setups[cls]
for globals_, locals_ in modules:
cls._add_proxied_attribute(name, globals_, locals_, attr_names)
- def _install_proxy(self):
+ def _install_proxy(self) -> None:
attr_names, modules = self._setups[self.__class__]
for globals_, locals_ in modules:
globals_["_proxy"] = self
for attr_name in attr_names:
globals_[attr_name] = getattr(self, attr_name)
- def _remove_proxy(self):
+ def _remove_proxy(self) -> None:
attr_names, modules = self._setups[self.__class__]
for globals_, locals_ in modules:
globals_["_proxy"] = None
@@ -171,10 +186,25 @@ def _with_legacy_names(translations):
return decorate
-def rev_id():
+def rev_id() -> str:
return uuid.uuid4().hex[-12:]
+@overload
+def to_tuple(x: Any, default: tuple) -> tuple:
+ ...
+
+
+@overload
+def to_tuple(x: None, default: _T = None) -> _T:
+ ...
+
+
+@overload
+def to_tuple(x: Any, default: Optional[tuple] = None) -> tuple:
+ ...
+
+
def to_tuple(x, default=None):
if x is None:
return default
@@ -186,16 +216,18 @@ def to_tuple(x, default=None):
return (x,)
-def dedupe_tuple(tup):
+def dedupe_tuple(tup: Tuple[str, ...]) -> Tuple[str, ...]:
return tuple(unique_list(tup))
class Dispatcher:
- def __init__(self, uselist=False):
- self._registry = {}
+ def __init__(self, uselist: bool = False) -> None:
+ self._registry: Dict[tuple, Any] = {}
self.uselist = uselist
- def dispatch_for(self, target, qualifier="default"):
+ def dispatch_for(
+ self, target: Any, qualifier: str = "default"
+ ) -> Callable:
def decorate(fn):
if self.uselist:
self._registry.setdefault((target, qualifier), []).append(fn)
@@ -206,10 +238,10 @@ class Dispatcher:
return decorate
- def dispatch(self, obj, qualifier="default"):
+ def dispatch(self, obj: Any, qualifier: str = "default") -> Any:
if isinstance(obj, string_types):
- targets = [obj]
+ targets: Sequence = [obj]
elif isinstance(obj, type):
targets = obj.__mro__
else:
@@ -223,7 +255,9 @@ class Dispatcher:
else:
raise ValueError("no dispatch function for object: %s" % obj)
- def _fn_or_list(self, fn_or_list):
+ def _fn_or_list(
+ self, fn_or_list: Union[List[Callable], Callable]
+ ) -> Callable:
if self.uselist:
def go(*arg, **kw):
@@ -232,9 +266,9 @@ class Dispatcher:
return go
else:
- return fn_or_list
+ return fn_or_list # type: ignore
- def branch(self):
+ def branch(self) -> "Dispatcher":
"""Return a copy of this dispatcher that is independently
writable."""
diff --git a/alembic/util/messaging.py b/alembic/util/messaging.py
index 70c9128..062890a 100644
--- a/alembic/util/messaging.py
+++ b/alembic/util/messaging.py
@@ -2,6 +2,11 @@ from collections.abc import Iterable
import logging
import sys
import textwrap
+from typing import Any
+from typing import Callable
+from typing import Optional
+from typing import TextIO
+from typing import Union
import warnings
from sqlalchemy.engine import url
@@ -29,7 +34,7 @@ except (ImportError, IOError):
TERMWIDTH = None
-def write_outstream(stream, *text):
+def write_outstream(stream: TextIO, *text) -> None:
encoding = getattr(stream, "encoding", "ascii") or "ascii"
for t in text:
if not isinstance(t, binary_type):
@@ -44,7 +49,7 @@ def write_outstream(stream, *text):
break
-def status(_statmsg, fn, *arg, **kw):
+def status(_statmsg: str, fn: Callable, *arg, **kw) -> Any:
newline = kw.pop("newline", False)
msg(_statmsg + " ...", newline, True)
try:
@@ -56,27 +61,27 @@ def status(_statmsg, fn, *arg, **kw):
raise
-def err(message):
+def err(message: str):
log.error(message)
msg("FAILED: %s" % message)
sys.exit(-1)
-def obfuscate_url_pw(u):
- u = url.make_url(u)
+def obfuscate_url_pw(input_url: str) -> str:
+ u = url.make_url(input_url)
if u.password:
if sqla_compat.sqla_14:
u = u.set(password="XXXXX")
else:
- u.password = "XXXXX"
+ u.password = "XXXXX" # type: ignore[misc]
return str(u)
-def warn(msg, stacklevel=2):
+def warn(msg: str, stacklevel: int = 2) -> None:
warnings.warn(msg, UserWarning, stacklevel=stacklevel)
-def msg(msg, newline=True, flush=False):
+def msg(msg: str, newline: bool = True, flush: bool = False) -> None:
if TERMWIDTH is None:
write_outstream(sys.stdout, msg)
if newline:
@@ -92,7 +97,7 @@ def msg(msg, newline=True, flush=False):
sys.stdout.flush()
-def format_as_comma(value):
+def format_as_comma(value: Optional[Union[str, "Iterable[str]"]]) -> str:
if value is None:
return ""
elif isinstance(value, string_types):
diff --git a/alembic/util/pyfiles.py b/alembic/util/pyfiles.py
index 53cc3cc..7eb582e 100644
--- a/alembic/util/pyfiles.py
+++ b/alembic/util/pyfiles.py
@@ -4,6 +4,7 @@ import importlib.util
import os
import re
import tempfile
+from typing import Optional
from mako import exceptions
from mako.template import Template
@@ -11,7 +12,9 @@ from mako.template import Template
from .exc import CommandError
-def template_to_file(template_file, dest, output_encoding, **kw):
+def template_to_file(
+ template_file: str, dest: str, output_encoding: str, **kw
+) -> None:
template = Template(filename=template_file)
try:
output = template.render_unicode(**kw).encode(output_encoding)
@@ -32,7 +35,7 @@ def template_to_file(template_file, dest, output_encoding, **kw):
f.write(output)
-def coerce_resource_to_filename(fname):
+def coerce_resource_to_filename(fname: str) -> str:
"""Interpret a filename as either a filesystem location or as a package
resource.
@@ -47,7 +50,7 @@ def coerce_resource_to_filename(fname):
return fname
-def pyc_file_from_path(path):
+def pyc_file_from_path(path: str) -> Optional[str]:
"""Given a python source path, locate the .pyc."""
candidate = importlib.util.cache_from_source(path)
@@ -64,7 +67,7 @@ def pyc_file_from_path(path):
return None
-def load_python_file(dir_, filename):
+def load_python_file(dir_: str, filename: str):
"""Load a file from the given path as a Python module."""
module_id = re.sub(r"\W", "_", filename)
@@ -78,21 +81,15 @@ def load_python_file(dir_, filename):
if pyc_path is None:
raise ImportError("Can't find Python file %s" % path)
else:
- module = load_module_pyc(module_id, pyc_path)
+ module = load_module_py(module_id, pyc_path)
elif ext in (".pyc", ".pyo"):
- module = load_module_pyc(module_id, path)
+ module = load_module_py(module_id, path)
return module
-def load_module_py(module_id, path):
+def load_module_py(module_id: str, path: str):
spec = importlib.util.spec_from_file_location(module_id, path)
+ assert spec
module = importlib.util.module_from_spec(spec)
- spec.loader.exec_module(module)
- return module
-
-
-def load_module_pyc(module_id, path):
- spec = importlib.util.spec_from_file_location(module_id, path)
- module = importlib.util.module_from_spec(spec)
- spec.loader.exec_module(module)
+ spec.loader.exec_module(module) # type: ignore
return module
diff --git a/alembic/util/sqla_compat.py b/alembic/util/sqla_compat.py
index a04ab2e..e1ccd41 100644
--- a/alembic/util/sqla_compat.py
+++ b/alembic/util/sqla_compat.py
@@ -1,5 +1,11 @@
import contextlib
import re
+from typing import Iterator
+from typing import Mapping
+from typing import Optional
+from typing import TYPE_CHECKING
+from typing import TypeVar
+from typing import Union
from sqlalchemy import __version__
from sqlalchemy import inspect
@@ -12,15 +18,34 @@ from sqlalchemy.schema import CheckConstraint
from sqlalchemy.schema import Column
from sqlalchemy.schema import ForeignKeyConstraint
from sqlalchemy.sql import visitors
+from sqlalchemy.sql.elements import BindParameter
from sqlalchemy.sql.elements import quoted_name
-from sqlalchemy.sql.expression import _BindParamClause
-from sqlalchemy.sql.expression import _TextClause as TextClause
+from sqlalchemy.sql.elements import TextClause
from sqlalchemy.sql.visitors import traverse
from . import compat
-
-def _safe_int(value):
+if TYPE_CHECKING:
+ from sqlalchemy import Index
+ from sqlalchemy import Table
+ from sqlalchemy.engine import Connection
+ from sqlalchemy.engine import Dialect
+ from sqlalchemy.engine import Transaction
+ from sqlalchemy.engine.reflection import Inspector
+ from sqlalchemy.sql.base import ColumnCollection
+ from sqlalchemy.sql.compiler import SQLCompiler
+ from sqlalchemy.sql.dml import Insert
+ from sqlalchemy.sql.elements import ColumnClause
+ from sqlalchemy.sql.elements import ColumnElement
+ from sqlalchemy.sql.schema import Constraint
+ from sqlalchemy.sql.schema import SchemaItem
+ from sqlalchemy.sql.selectable import Select
+ from sqlalchemy.sql.selectable import TableClause
+
+_CE = TypeVar("_CE", bound=Union["ColumnElement", "SchemaItem"])
+
+
+def _safe_int(value: str) -> Union[int, str]:
try:
return int(value)
except:
@@ -36,6 +61,7 @@ sqla_14 = _vers >= (1, 4)
try:
from sqlalchemy import Computed # noqa
except ImportError:
+ Computed = None # type: ignore
has_computed = False
has_computed_reflection = False
else:
@@ -45,6 +71,7 @@ else:
try:
from sqlalchemy import Identity # noqa
except ImportError:
+ Identity = None # type: ignore
has_identity = False
else:
# attributes common to Indentity and Sequence
@@ -67,21 +94,26 @@ AUTOINCREMENT_DEFAULT = "auto"
@contextlib.contextmanager
-def _ensure_scope_for_ddl(connection):
+def _ensure_scope_for_ddl(
+ connection: Optional["Connection"],
+) -> Iterator[None]:
try:
- in_transaction = connection.in_transaction
+ in_transaction = connection.in_transaction # type: ignore[union-attr]
except AttributeError:
- # catch for MockConnection
+ # catch for MockConnection, None
yield
else:
if not in_transaction():
+ assert connection is not None
with connection.begin():
yield
else:
yield
-def _safe_begin_connection_transaction(connection):
+def _safe_begin_connection_transaction(
+ connection: "Connection",
+) -> "Transaction":
transaction = _get_connection_transaction(connection)
if transaction:
return transaction
@@ -89,9 +121,9 @@ def _safe_begin_connection_transaction(connection):
return connection.begin()
-def _get_connection_in_transaction(connection):
+def _get_connection_in_transaction(connection: Optional["Connection"]) -> bool:
try:
- in_transaction = connection.in_transaction
+ in_transaction = connection.in_transaction # type: ignore
except AttributeError:
# catch for MockConnection
return False
@@ -99,28 +131,33 @@ def _get_connection_in_transaction(connection):
return in_transaction()
-def _copy(schema_item, **kw):
+def _copy(schema_item: _CE, **kw) -> _CE:
if hasattr(schema_item, "_copy"):
return schema_item._copy(**kw)
else:
return schema_item.copy(**kw)
-def _get_connection_transaction(connection):
+def _get_connection_transaction(
+ connection: "Connection",
+) -> Optional["Transaction"]:
if sqla_14:
return connection.get_transaction()
else:
- return connection._root._Connection__transaction
+ r = connection._root # type: ignore[attr-defined]
+ return r._Connection__transaction
-def _create_url(*arg, **kw):
+def _create_url(*arg, **kw) -> url.URL:
if hasattr(url.URL, "create"):
return url.URL.create(*arg, **kw)
else:
return url.URL(*arg, **kw)
-def _connectable_has_table(connectable, tablename, schemaname):
+def _connectable_has_table(
+ connectable: "Connection", tablename: str, schemaname: Union[str, None]
+) -> bool:
if sqla_14:
return inspect(connectable).has_table(tablename, schemaname)
else:
@@ -148,23 +185,25 @@ def _nullability_might_be_unset(metadata_column):
)
-def _server_default_is_computed(*server_default):
+def _server_default_is_computed(*server_default) -> bool:
if not has_computed:
return False
else:
return any(isinstance(sd, Computed) for sd in server_default)
-def _server_default_is_identity(*server_default):
+def _server_default_is_identity(*server_default) -> bool:
if not sqla_14:
return False
else:
return any(isinstance(sd, Identity) for sd in server_default)
-def _table_for_constraint(constraint):
+def _table_for_constraint(constraint: "Constraint") -> "Table":
if isinstance(constraint, ForeignKeyConstraint):
- return constraint.parent
+ table = constraint.parent
+ assert table is not None
+ return table
else:
return constraint.table
@@ -178,7 +217,9 @@ def _columns_for_constraint(constraint):
return list(constraint.columns)
-def _reflect_table(inspector, table, include_cols):
+def _reflect_table(
+ inspector: "Inspector", table: "Table", include_cols: None
+) -> None:
if sqla_14:
return inspector.reflect_table(table, None)
else:
@@ -213,19 +254,20 @@ def _fk_spec(constraint):
)
-def _fk_is_self_referential(constraint):
- spec = constraint.elements[0]._get_colspec()
+def _fk_is_self_referential(constraint: "ForeignKeyConstraint") -> bool:
+ spec = constraint.elements[0]._get_colspec() # type: ignore[attr-defined]
tokens = spec.split(".")
tokens.pop(-1) # colname
tablekey = ".".join(tokens)
+ assert constraint.parent is not None
return tablekey == constraint.parent.key
-def _is_type_bound(constraint):
+def _is_type_bound(constraint: "Constraint") -> bool:
# this deals with SQLAlchemy #3260, don't copy CHECK constraints
# that will be generated by the type.
# new feature added for #3260
- return constraint._type_bound
+ return constraint._type_bound # type: ignore[attr-defined]
def _find_columns(clause):
@@ -236,16 +278,21 @@ def _find_columns(clause):
return cols
-def _remove_column_from_collection(collection, column):
+def _remove_column_from_collection(
+ collection: "ColumnCollection", column: Union["Column", "ColumnClause"]
+) -> None:
"""remove a column from a ColumnCollection."""
# workaround for older SQLAlchemy, remove the
# same object that's present
+ assert column.key is not None
to_remove = collection[column.key]
collection.remove(to_remove)
-def _textual_index_column(table, text_):
+def _textual_index_column(
+ table: "Table", text_: Union[str, "TextClause", "ColumnElement"]
+) -> Union["ColumnElement", "Column"]:
"""a workaround for the Index construct's severe lack of flexibility"""
if isinstance(text_, compat.string_types):
c = Column(text_, sqltypes.NULLTYPE)
@@ -259,7 +306,7 @@ def _textual_index_column(table, text_):
raise ValueError("String or text() construct expected")
-def _copy_expression(expression, target_table):
+def _copy_expression(expression: _CE, target_table: "Table") -> _CE:
def replace(col):
if (
isinstance(col, Column)
@@ -296,7 +343,7 @@ class _textual_index_element(sql.ColumnElement):
__visit_name__ = "_textual_idx_element"
- def __init__(self, table, text):
+ def __init__(self, table: "Table", text: "TextClause") -> None:
self.table = table
self.text = text
self.key = text.text
@@ -308,16 +355,20 @@ class _textual_index_element(sql.ColumnElement):
@compiles(_textual_index_element)
-def _render_textual_index_column(element, compiler, **kw):
+def _render_textual_index_column(
+ element: _textual_index_element, compiler: "SQLCompiler", **kw
+) -> str:
return compiler.process(element.text, **kw)
-class _literal_bindparam(_BindParamClause):
+class _literal_bindparam(BindParameter):
pass
@compiles(_literal_bindparam)
-def _render_literal_bindparam(element, compiler, **kw):
+def _render_literal_bindparam(
+ element: _literal_bindparam, compiler: "SQLCompiler", **kw
+) -> str:
return compiler.render_literal_bindparam(element, **kw)
@@ -329,17 +380,20 @@ def _get_index_column_names(idx):
return [getattr(exp, "name", None) for exp in _get_index_expressions(idx)]
-def _column_kwargs(col):
+def _column_kwargs(col: "Column") -> Mapping:
if sqla_13:
return col.kwargs
else:
return {}
-def _get_constraint_final_name(constraint, dialect):
+def _get_constraint_final_name(
+ constraint: Union["Index", "Constraint"], dialect: Optional["Dialect"]
+) -> Optional[str]:
if constraint.name is None:
return None
- elif sqla_14:
+ assert dialect is not None
+ if sqla_14:
# for SQLAlchemy 1.4 we would like to have the option to expand
# the use of "deferred" names for constraints as well as to have
# some flexibility with "None" name and similar; make use of new
@@ -355,7 +409,7 @@ def _get_constraint_final_name(constraint, dialect):
if hasattr(constraint.name, "quote"):
# might be quoted_name, might be truncated_name, keep it the
# same
- quoted_name_cls = type(constraint.name)
+ quoted_name_cls: type = type(constraint.name)
else:
quoted_name_cls = quoted_name
@@ -364,7 +418,8 @@ def _get_constraint_final_name(constraint, dialect):
if isinstance(constraint, schema.Index):
# name should not be quoted.
- return dialect.ddl_compiler(dialect, None)._prepared_index_name(
+ d = dialect.ddl_compiler(dialect, None)
+ return d._prepared_index_name( # type: ignore[attr-defined]
constraint
)
else:
@@ -372,10 +427,13 @@ def _get_constraint_final_name(constraint, dialect):
return dialect.identifier_preparer.format_constraint(constraint)
-def _constraint_is_named(constraint, dialect):
+def _constraint_is_named(
+ constraint: Union["Constraint", "Index"], dialect: Optional["Dialect"]
+) -> bool:
if sqla_14:
if constraint.name is None:
return False
+ assert dialect is not None
name = dialect.identifier_preparer.format_constraint(
constraint, _alembic_quote=False
)
@@ -384,18 +442,21 @@ def _constraint_is_named(constraint, dialect):
return constraint.name is not None
-def _is_mariadb(mysql_dialect):
+def _is_mariadb(mysql_dialect: "Dialect") -> bool:
if sqla_14:
- return mysql_dialect.is_mariadb
+ return mysql_dialect.is_mariadb # type: ignore[attr-defined]
else:
- return mysql_dialect.server_version_info and mysql_dialect._is_mariadb
+ return bool(
+ mysql_dialect.server_version_info
+ and mysql_dialect._is_mariadb # type: ignore[attr-defined]
+ )
def _mariadb_normalized_version_info(mysql_dialect):
return mysql_dialect._mariadb_normalized_version_info
-def _insert_inline(table):
+def _insert_inline(table: Union["TableClause", "Table"]) -> "Insert":
if sqla_14:
return table.insert().inline()
else:
@@ -408,10 +469,10 @@ if sqla_14:
else:
from sqlalchemy import create_engine
- def create_mock_engine(url, executor):
+ def create_mock_engine(url, executor, **kw): # type: ignore[misc]
return create_engine(
"postgresql://", strategy="mock", executor=executor
)
- def _select(*columns):
- return sql.select(list(columns))
+ def _select(*columns, **kw) -> "Select":
+ return sql.select(list(columns), **kw)