summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/util/langhelpers.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/util/langhelpers.py')
-rw-r--r--lib/sqlalchemy/util/langhelpers.py47
1 files changed, 34 insertions, 13 deletions
diff --git a/lib/sqlalchemy/util/langhelpers.py b/lib/sqlalchemy/util/langhelpers.py
index ed879894d..9e024b3c0 100644
--- a/lib/sqlalchemy/util/langhelpers.py
+++ b/lib/sqlalchemy/util/langhelpers.py
@@ -9,6 +9,7 @@
modules, classes, hierarchies, attributes, functions, and methods.
"""
+from __future__ import annotations
import collections
from functools import update_wrapper
@@ -452,7 +453,9 @@ def get_func_kwargs(func):
return compat.inspect_getfullargspec(func)[0]
-def get_callable_argspec(fn, no_self=False, _is_init=False):
+def get_callable_argspec(
+ fn: Callable[..., Any], no_self: bool = False, _is_init: bool = False
+) -> compat.FullArgSpec:
"""Return the argument signature for any callable.
All pure-Python callables are accepted, including
@@ -496,10 +499,12 @@ def get_callable_argspec(fn, no_self=False, _is_init=False):
fn.__init__, no_self=no_self, _is_init=True
)
elif hasattr(fn, "__func__"):
- return compat.inspect_getfullargspec(fn.__func__)
+ return compat.inspect_getfullargspec(fn.__func__) # type: ignore[attr-defined] # noqa E501
elif hasattr(fn, "__call__"):
- if inspect.ismethod(fn.__call__):
- return get_callable_argspec(fn.__call__, no_self=no_self)
+ if inspect.ismethod(fn.__call__): # type: ignore [operator]
+ return get_callable_argspec(
+ fn.__call__, no_self=no_self # type: ignore [operator]
+ )
else:
raise TypeError("Can't inspect callable: %s" % fn)
else:
@@ -1521,7 +1526,12 @@ class hybridmethod:
class _symbol(int):
name: str
- def __new__(cls, name, doc=None, canonical=None):
+ def __new__(
+ cls,
+ name: str,
+ doc: Optional[str] = None,
+ canonical: Optional[int] = None,
+ ) -> "_symbol":
"""Construct a new named symbol."""
assert isinstance(name, str)
if canonical is None:
@@ -1570,7 +1580,12 @@ class symbol:
symbols: Dict[str, "_symbol"] = {}
_lock = threading.Lock()
- def __new__(cls, name, doc=None, canonical=None):
+ def __new__( # type: ignore[misc]
+ cls,
+ name: str,
+ doc: Optional[str] = None,
+ canonical: Optional[int] = None,
+ ) -> _symbol:
with cls._lock:
sym = cls.symbols.get(name)
if sym is None:
@@ -1730,13 +1745,15 @@ def _warnings_warn(message, category=None, stacklevel=2):
warnings.warn(message, stacklevel=stacklevel + 1)
-def only_once(fn, retry_on_exception):
+def only_once(
+ fn: Callable[..., _T], retry_on_exception: bool
+) -> Callable[..., Optional[_T]]:
"""Decorate the given function to be a no-op after it is called exactly
once."""
once = [fn]
- def go(*arg, **kw):
+ def go(*arg: Any, **kw: Any) -> Optional[_T]:
# strong reference fn so that it isn't garbage collected,
# which interferes with the event system's expectations
strong_fn = fn # noqa
@@ -1749,6 +1766,8 @@ def only_once(fn, retry_on_exception):
once.insert(0, once_fn)
raise
+ return None
+
return go
@@ -1936,7 +1955,7 @@ def add_parameter_text(params, text):
return decorate
-def _dedent_docstring(text):
+def _dedent_docstring(text: str) -> str:
split_text = text.split("\n", 1)
if len(split_text) == 1:
return text
@@ -1948,8 +1967,10 @@ def _dedent_docstring(text):
return textwrap.dedent(text)
-def inject_docstring_text(doctext, injecttext, pos):
- doctext = _dedent_docstring(doctext or "")
+def inject_docstring_text(
+ given_doctext: Optional[str], injecttext: str, pos: int
+) -> str:
+ doctext: str = _dedent_docstring(given_doctext or "")
lines = doctext.split("\n")
if len(lines) == 1:
lines.append("")
@@ -1969,7 +1990,7 @@ def inject_docstring_text(doctext, injecttext, pos):
_param_reg = re.compile(r"(\s+):param (.+?):")
-def inject_param_text(doctext, inject_params):
+def inject_param_text(doctext: str, inject_params: Dict[str, str]) -> str:
doclines = collections.deque(doctext.splitlines())
lines = []
@@ -2012,7 +2033,7 @@ def inject_param_text(doctext, inject_params):
return "\n".join(lines)
-def repr_tuple_names(names):
+def repr_tuple_names(names: List[str]) -> Optional[str]:
"""Trims a list of strings from the middle and return a string of up to
four elements. Strings greater than 11 characters will be truncated"""
if len(names) == 0: