summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSeth Morton <seth.m.morton@gmail.com>2021-10-27 09:13:05 -0700
committerSeth Morton <seth.m.morton@gmail.com>2021-10-27 09:13:05 -0700
commit31235520d7f630df371edd35d53b39fd6b5987d3 (patch)
treee97aed212c173b365919fe0b7d8f4be649886748
parentd8b0f2547188772831e24c69037c6139d4c5f4ab (diff)
downloadnatsort-31235520d7f630df371edd35d53b39fd6b5987d3.tar.gz
Fully type hint natsort.py
-rw-r--r--natsort/natsort.py257
1 files changed, 233 insertions, 24 deletions
diff --git a/natsort/natsort.py b/natsort/natsort.py
index 1402cf8..1dcf9ce 100644
--- a/natsort/natsort.py
+++ b/natsort/natsort.py
@@ -9,18 +9,55 @@ The majority of the "work" is defined in utils.py.
import platform
from functools import partial
from operator import itemgetter
-from typing import Callable, Iterable, TypeVar
-
-from _typeshed import SupportsLessThan
+from typing import (
+ Any as Any_t,
+ Callable,
+ Iterable,
+ Iterator,
+ List,
+ Optional,
+ Sequence,
+ Tuple,
+ TypeVar,
+ Union,
+ cast,
+ overload,
+)
import natsort.compat.locale
from natsort import utils
-from natsort.ns_enum import NS_DUMB, ns
-
-_T = TypeVar("_T")
-
-
-def decoder(encoding):
+from natsort.compat.locale import StrOrBytes
+from natsort.ns_enum import NS_DUMB, NS_t, ns
+from natsort.utils import KeyType, NatsortInType, NatsortOutType, PathArg
+
+# Some generic types
+T_ns = TypeVar("T_ns", bound=NatsortInType)
+T_path = TypeVar("T_path", bound=PathArg)
+T_any = TypeVar("T_any")
+
+# Keys that natsort accepts
+MaybeKeyType = Optional[KeyType]
+PathKeyType = Callable[[Any_t], PathArg]
+MaybePathKeyType = Optional[PathKeyType]
+
+# Common input and output types
+Iter_ns = Iterable[T_ns]
+Iter_any = Iterable[T_any]
+Iter_path = Iterable[T_path]
+List_ns = List[T_ns]
+List_any = List[T_any]
+List_path = List[T_path]
+List_int = List[int]
+
+# The type that natsort_key returns
+NatsortKeyType = Callable[[NatsortInType], NatsortOutType]
+
+# The type that os_sort_key returns
+OSSortKeyType_ = Callable[[PathArg], Tuple[Tuple[StrOrBytes, ...], ...]]
+OSSortKeyType = Union[OSSortKeyType_, NatsortKeyType]
+
+
+def decoder(encoding: str) -> Callable[[NatsortInType], NatsortInType]:
"""
Return a function that can be used to decode bytes to unicode.
@@ -61,7 +98,7 @@ def decoder(encoding):
return partial(utils.do_decoding, encoding=encoding)
-def as_ascii(s):
+def as_ascii(s: NatsortInType) -> NatsortInType:
"""
Function to decode an input with the ASCII codec, or return as-is.
@@ -84,7 +121,7 @@ def as_ascii(s):
return utils.do_decoding(s, "ascii")
-def as_utf8(s):
+def as_utf8(s: NatsortInType) -> NatsortInType:
"""
Function to decode an input with the UTF-8 codec, or return as-is.
@@ -107,7 +144,7 @@ def as_utf8(s):
return utils.do_decoding(s, "utf-8")
-def natsort_keygen(key=None, alg=ns.DEFAULT):
+def natsort_keygen(key=MaybeKeyType, alg: NS_t = ns.DEFAULT) -> NatsortKeyType:
"""
Generate a key to sort strings and numbers naturally.
@@ -217,7 +254,67 @@ natsort_keygen
"""
-def natsorted(seq: Iterable[_T], key: Callable[[_T], SupportsLessThan]=None, reverse=False, alg=ns.DEFAULT):
+@overload
+def natsorted(seq: Iter_ns) -> List_ns:
+ ...
+
+
+@overload
+def natsorted(seq: Iter_ns, reverse: bool) -> List_ns:
+ ...
+
+
+@overload
+def natsorted(seq: Iter_ns, alg: NS_t) -> List_ns:
+ ...
+
+
+@overload
+def natsorted(seq: Iter_ns, reverse: bool, alg: NS_t) -> List_ns:
+ ...
+
+
+@overload
+def natsorted(seq: Iter_ns, key: None) -> List_ns:
+ ...
+
+
+@overload
+def natsorted(seq: Iter_ns, key: None, reverse: bool) -> List_ns:
+ ...
+
+
+@overload
+def natsorted(seq: Iter_ns, key: None, alg: NS_t) -> List_ns:
+ ...
+
+
+@overload
+def natsorted(seq: Iter_ns, key: None, reverse: bool, alg: NS_t) -> List_ns:
+ ...
+
+
+@overload
+def natsorted(seq: Iter_any, key: KeyType) -> List_any:
+ ...
+
+
+@overload
+def natsorted(seq: Iter_any, key: KeyType, reverse: bool) -> List_any:
+ ...
+
+
+@overload
+def natsorted(seq: Iter_any, key: KeyType, alg: NS_t) -> List_any:
+ ...
+
+
+@overload
+def natsorted(seq: Iter_any, key: KeyType, reverse: bool, alg: NS_t) -> List_any:
+ ...
+
+
+def natsorted(seq, key=None, reverse=False, alg=ns.DEFAULT):
"""
Sorts an iterable naturally.
@@ -266,7 +363,9 @@ def natsorted(seq: Iterable[_T], key: Callable[[_T], SupportsLessThan]=None, rev
return sorted(seq, reverse=reverse, key=key)
-def humansorted(seq, key=None, reverse=False, alg=ns.DEFAULT):
+def humansorted(
+ seq: Iter_any, key=MaybeKeyType, reverse: bool = False, alg: NS_t = ns.DEFAULT
+) -> List_any:
"""
Convenience function to properly sort non-numeric characters.
@@ -318,7 +417,12 @@ def humansorted(seq, key=None, reverse=False, alg=ns.DEFAULT):
return natsorted(seq, key, reverse, alg | ns.LOCALE)
-def realsorted(seq, key=None, reverse=False, alg=ns.DEFAULT):
+def realsorted(
+ seq: Iter_any,
+ key: MaybeKeyType = None,
+ reverse: bool = False,
+ alg: NS_t = ns.DEFAULT,
+) -> List_any:
"""
Convenience function to properly sort signed floats.
@@ -371,6 +475,66 @@ def realsorted(seq, key=None, reverse=False, alg=ns.DEFAULT):
return natsorted(seq, key, reverse, alg | ns.REAL)
+@overload
+def index_natsorted(seq: Iter_ns) -> List_int:
+ ...
+
+
+@overload
+def index_natsorted(seq: Iter_ns, reverse: bool) -> List_int:
+ ...
+
+
+@overload
+def index_natsorted(seq: Iter_ns, alg: NS_t) -> List_int:
+ ...
+
+
+@overload
+def index_natsorted(seq: Iter_ns, reverse: bool, alg: NS_t) -> List_int:
+ ...
+
+
+@overload
+def index_natsorted(seq: Iter_ns, key: None) -> List_int:
+ ...
+
+
+@overload
+def index_natsorted(seq: Iter_ns, key: None, reverse: bool) -> List_int:
+ ...
+
+
+@overload
+def index_natsorted(seq: Iter_ns, key: None, alg: NS_t) -> List_int:
+ ...
+
+
+@overload
+def index_natsorted(seq: Iter_ns, key: None, reverse: bool, alg: NS_t) -> List_int:
+ ...
+
+
+@overload
+def index_natsorted(seq: Iter_any, key: KeyType) -> List_int:
+ ...
+
+
+@overload
+def index_natsorted(seq: Iter_any, key: KeyType, reverse: bool) -> List_int:
+ ...
+
+
+@overload
+def index_natsorted(seq: Iter_any, key: KeyType, alg: NS_t) -> List_int:
+ ...
+
+
+@overload
+def index_natsorted(seq: Iter_any, key: KeyType, reverse: bool, alg: NS_t) -> List_int:
+ ...
+
+
def index_natsorted(seq, key=None, reverse=False, alg=ns.DEFAULT):
"""
Determine the list of the indexes used to sort the input sequence.
@@ -427,6 +591,7 @@ def index_natsorted(seq, key=None, reverse=False, alg=ns.DEFAULT):
['baz', 'foo', 'bar']
"""
+ newkey: KeyType
if key is None:
newkey = itemgetter(1)
else:
@@ -440,7 +605,12 @@ def index_natsorted(seq, key=None, reverse=False, alg=ns.DEFAULT):
return [x for x, _ in index_seq_pair]
-def index_humansorted(seq, key=None, reverse=False, alg=ns.DEFAULT):
+def index_humansorted(
+ seq: Iter_any,
+ key: MaybeKeyType = None,
+ reverse: bool = False,
+ alg: NS_t = ns.DEFAULT,
+) -> List_int:
"""
This is a wrapper around ``index_natsorted(seq, alg=ns.LOCALE)``.
@@ -489,7 +659,12 @@ def index_humansorted(seq, key=None, reverse=False, alg=ns.DEFAULT):
return index_natsorted(seq, key, reverse, alg | ns.LOCALE)
-def index_realsorted(seq, key=None, reverse=False, alg=ns.DEFAULT):
+def index_realsorted(
+ seq: Iter_any,
+ key: MaybeKeyType = None,
+ reverse: bool = False,
+ alg: NS_t = ns.DEFAULT,
+) -> List_int:
"""
This is a wrapper around ``index_natsorted(seq, alg=ns.REAL)``.
@@ -535,7 +710,9 @@ def index_realsorted(seq, key=None, reverse=False, alg=ns.DEFAULT):
# noinspection PyShadowingBuiltins,PyUnresolvedReferences
-def order_by_index(seq, index, iter=False):
+def order_by_index(
+ seq: Sequence[T_any], index: Iterable[int], iter: bool = False
+) -> Iter_any:
"""
Order a given sequence by an index sequence.
@@ -594,7 +771,7 @@ def order_by_index(seq, index, iter=False):
return (seq[i] for i in index) if iter else [seq[i] for i in index]
-def numeric_regex_chooser(alg):
+def numeric_regex_chooser(alg: NS_t) -> str:
"""
Select an appropriate regex for the type of number of interest.
@@ -613,7 +790,7 @@ def numeric_regex_chooser(alg):
return utils.regex_chooser(alg).pattern[1:-1]
-def _split_apply(v, key=None):
+def _split_apply(v: Any_t, key: MaybePathKeyType = None) -> Iterator[str]:
if key is not None:
v = key(v)
return utils.path_splitter(str(v))
@@ -630,8 +807,10 @@ if platform.system() == "Windows":
_windows_sort_cmp.restype = wintypes.INT
_winsort_key = cmp_to_key(_windows_sort_cmp)
- def os_sort_keygen(key=None):
- return lambda x: tuple(map(_winsort_key, _split_apply(x, key)))
+ def os_sort_keygen(key: MaybePathKeyType = None) -> OSSortKeyType:
+ return cast(
+ OSSortKeyType_, lambda x: tuple(map(_winsort_key, _split_apply(x, key)))
+ )
else:
@@ -650,12 +829,12 @@ else:
except ImportError:
# No ICU installed
- def os_sort_keygen(key=None):
+ def os_sort_keygen(key: MaybePathKeyType = None) -> OSSortKeyType:
return natsort_keygen(key=key, alg=ns.LOCALE | ns.PATH | ns.IGNORECASE)
else:
# ICU installed
- def os_sort_keygen(key=None):
+ def os_sort_keygen(key: MaybePathKeyType = None) -> OSSortKeyType:
loc = natsort.compat.locale.get_icu_locale()
collator = icu.Collator.createInstance(loc)
collator.setAttribute(
@@ -702,6 +881,36 @@ os_sort_keygen
"""
+@overload
+def os_sorted(seq: Iter_path) -> List_path:
+ ...
+
+
+@overload
+def os_sorted(seq: Iter_path, reverse: bool) -> List_path:
+ ...
+
+
+@overload
+def os_sorted(seq: Iter_path, key: None) -> List_path:
+ ...
+
+
+@overload
+def os_sorted(seq: Iter_path, key: None, reverse: bool) -> List_path:
+ ...
+
+
+@overload
+def os_sorted(seq: Iter_any, key: PathKeyType) -> List_any:
+ ...
+
+
+@overload
+def os_sorted(seq: Iter_any, key: PathKeyType, reverse: bool) -> List_any:
+ ...
+
+
def os_sorted(seq, key=None, reverse=False):
"""
Sort elements in the same order as your operating system's file browser