summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSeth Morton <seth.m.morton@gmail.com>2022-09-01 13:12:30 -0700
committerSeth Morton <seth.m.morton@gmail.com>2022-09-01 13:59:30 -0700
commit22cdc562baf60389a5f64a0ce241caf34d09abea (patch)
treed123aeaabb2e269028c558576d4e8a9aa126ed3f
parente5d2e4507728e53d1867ac87e169fca1d251d8cf (diff)
downloadnatsort-22cdc562baf60389a5f64a0ce241caf34d09abea.tar.gz
Simplify type hints for public functions
...and to some degree private as well. Previously, the declared hints for natsort were too restrictive. Generics and protocols are now utilized to make the type hints more "open" which is more realistic, since more than just basic types can be sorted.
-rw-r--r--natsort/natsort.py205
-rw-r--r--natsort/utils.py68
-rw-r--r--tests/test_os_sorted.py3
-rw-r--r--tests/test_parse_number_function.py4
-rw-r--r--tox.ini1
5 files changed, 94 insertions, 187 deletions
diff --git a/natsort/natsort.py b/natsort/natsort.py
index c0eec58..f649500 100644
--- a/natsort/natsort.py
+++ b/natsort/natsort.py
@@ -18,42 +18,27 @@ from typing import (
Optional,
Sequence,
Tuple,
- Union,
+ TypeVar,
cast,
- overload,
)
import natsort.compat.locale
from natsort import utils
from natsort.ns_enum import NSType, NS_DUMB, ns
-from natsort.utils import (
- KeyType,
- MaybeKeyType,
- NatsortInType,
- NatsortOutType,
- StrBytesNum,
- StrBytesPathNum,
-)
+from natsort.utils import NatsortInType, NatsortOutType
# Common input and output types
-Iter_ns = Iterable[NatsortInType]
-Iter_any = Iterable[Any]
-List_ns = List[NatsortInType]
-List_any = List[Any]
-List_int = List[int]
+T = TypeVar("T")
+NatsortInTypeT = TypeVar("NatsortInTypeT", bound=NatsortInType)
# The type that natsort_key returns
NatsortKeyType = Callable[[NatsortInType], NatsortOutType]
# Types for os_sorted
-OSSortInType = Iterable[Optional[StrBytesPathNum]]
-OSSortOutType = Tuple[Union[StrBytesNum, Tuple[StrBytesNum, ...]], ...]
-OSSortKeyType = Callable[[Optional[StrBytesPathNum]], OSSortOutType]
-Iter_path = Iterable[Optional[StrBytesPathNum]]
-List_path = List[StrBytesPathNum]
+OSSortKeyType = Callable[[NatsortInType], NatsortOutType]
-def decoder(encoding: str) -> Callable[[NatsortInType], NatsortInType]:
+def decoder(encoding: str) -> Callable[[Any], Any]:
"""
Return a function that can be used to decode bytes to unicode.
@@ -94,7 +79,7 @@ def decoder(encoding: str) -> Callable[[NatsortInType], NatsortInType]:
return partial(utils.do_decoding, encoding=encoding)
-def as_ascii(s: NatsortInType) -> NatsortInType:
+def as_ascii(s: Any) -> Any:
"""
Function to decode an input with the ASCII codec, or return as-is.
@@ -117,7 +102,7 @@ def as_ascii(s: NatsortInType) -> NatsortInType:
return utils.do_decoding(s, "ascii")
-def as_utf8(s: NatsortInType) -> NatsortInType:
+def as_utf8(s: Any) -> Any:
"""
Function to decode an input with the UTF-8 codec, or return as-is.
@@ -141,8 +126,8 @@ def as_utf8(s: NatsortInType) -> NatsortInType:
def natsort_keygen(
- key: MaybeKeyType = None, alg: NSType = ns.DEFAULT
-) -> NatsortKeyType:
+ key: Optional[Callable[[Any], NatsortInType]] = None, alg: NSType = ns.DEFAULT
+) -> Callable[[Any], NatsortOutType]:
"""
Generate a key to sort strings and numbers naturally.
@@ -252,26 +237,12 @@ natsort_keygen
"""
-@overload
-def natsorted(
- seq: Iter_ns, key: None = None, reverse: bool = False, alg: NSType = ns.DEFAULT
-) -> List_ns:
- ...
-
-
-@overload
def natsorted(
- seq: Iter_any, key: KeyType, reverse: bool = False, alg: NSType = ns.DEFAULT
-) -> List_any:
- ...
-
-
-def natsorted(
- seq: Iter_any,
- key: MaybeKeyType = None,
+ seq: Iterable[T],
+ key: Optional[Callable[[T], NatsortInType]] = None,
reverse: bool = False,
alg: NSType = ns.DEFAULT,
-) -> List_any:
+) -> List[T]:
"""
Sorts an iterable naturally.
@@ -319,26 +290,12 @@ def natsorted(
return sorted(seq, reverse=reverse, key=natsort_keygen(key, alg))
-@overload
-def humansorted(
- seq: Iter_ns, key: None = None, reverse: bool = False, alg: NSType = ns.DEFAULT
-) -> List_ns:
- ...
-
-
-@overload
-def humansorted(
- seq: Iter_any, key: KeyType, reverse: bool = False, alg: NSType = ns.DEFAULT
-) -> List_any:
- ...
-
-
def humansorted(
- seq: Iter_any,
- key: MaybeKeyType = None,
+ seq: Iterable[T],
+ key: Optional[Callable[[T], NatsortInType]] = None,
reverse: bool = False,
alg: NSType = ns.DEFAULT,
-) -> List_any:
+) -> List[T]:
"""
Convenience function to properly sort non-numeric characters.
@@ -390,26 +347,12 @@ def humansorted(
return natsorted(seq, key, reverse, alg | ns.LOCALE)
-@overload
-def realsorted(
- seq: Iter_ns, key: None = None, reverse: bool = False, alg: NSType = ns.DEFAULT
-) -> List_ns:
- ...
-
-
-@overload
def realsorted(
- seq: Iter_any, key: KeyType, reverse: bool = False, alg: NSType = ns.DEFAULT
-) -> List_any:
- ...
-
-
-def realsorted(
- seq: Iter_any,
- key: MaybeKeyType = None,
+ seq: Iterable[T],
+ key: Optional[Callable[[T], NatsortInType]] = None,
reverse: bool = False,
alg: NSType = ns.DEFAULT,
-) -> List_any:
+) -> List[T]:
"""
Convenience function to properly sort signed floats.
@@ -462,26 +405,12 @@ def realsorted(
return natsorted(seq, key, reverse, alg | ns.REAL)
-@overload
-def index_natsorted(
- seq: Iter_ns, key: None = None, reverse: bool = False, alg: NSType = ns.DEFAULT
-) -> List_int:
- ...
-
-
-@overload
def index_natsorted(
- seq: Iter_any, key: KeyType, reverse: bool = False, alg: NSType = ns.DEFAULT
-) -> List_int:
- ...
-
-
-def index_natsorted(
- seq: Iter_any,
- key: MaybeKeyType = None,
+ seq: Iterable[T],
+ key: Optional[Callable[[T], NatsortInType]] = None,
reverse: bool = False,
alg: NSType = ns.DEFAULT,
-) -> List_int:
+) -> List[int]:
"""
Determine the list of the indexes used to sort the input sequence.
@@ -537,13 +466,13 @@ def index_natsorted(
['baz', 'foo', 'bar']
"""
- newkey: KeyType
+ newkey: Callable[[Tuple[int, T]], NatsortInType]
if key is None:
newkey = itemgetter(1)
else:
- def newkey(x: Any) -> NatsortInType:
- return cast(KeyType, key)(itemgetter(1)(x))
+ def newkey(x: Tuple[int, T]) -> NatsortInType:
+ return cast(Callable[[T], NatsortInType], key)(itemgetter(1)(x))
# Pair the index and sequence together, then sort by element
index_seq_pair = [(x, y) for x, y in enumerate(seq)]
@@ -551,26 +480,12 @@ def index_natsorted(
return [x for x, _ in index_seq_pair]
-@overload
-def index_humansorted(
- seq: Iter_ns, key: None = None, reverse: bool = False, alg: NSType = ns.DEFAULT
-) -> List_int:
- ...
-
-
-@overload
def index_humansorted(
- seq: Iter_any, key: KeyType, reverse: bool = False, alg: NSType = ns.DEFAULT
-) -> List_int:
- ...
-
-
-def index_humansorted(
- seq: Iter_any,
- key: MaybeKeyType = None,
+ seq: Iterable[T],
+ key: Optional[Callable[[T], NatsortInType]] = None,
reverse: bool = False,
alg: NSType = ns.DEFAULT,
-) -> List_int:
+) -> List[int]:
"""
This is a wrapper around ``index_natsorted(seq, alg=ns.LOCALE)``.
@@ -619,26 +534,12 @@ def index_humansorted(
return index_natsorted(seq, key, reverse, alg | ns.LOCALE)
-@overload
-def index_realsorted(
- seq: Iter_ns, key: None = None, reverse: bool = False, alg: NSType = ns.DEFAULT
-) -> List_int:
- ...
-
-
-@overload
-def index_realsorted(
- seq: Iter_any, key: KeyType, reverse: bool = False, alg: NSType = ns.DEFAULT
-) -> List_int:
- ...
-
-
def index_realsorted(
- seq: Iter_any,
- key: MaybeKeyType = None,
+ seq: Iterable[T],
+ key: Optional[Callable[[T], NatsortInType]] = None,
reverse: bool = False,
alg: NSType = ns.DEFAULT,
-) -> List_int:
+) -> List[int]:
"""
This is a wrapper around ``index_natsorted(seq, alg=ns.REAL)``.
@@ -683,10 +584,9 @@ def index_realsorted(
return index_natsorted(seq, key, reverse, alg | ns.REAL)
-# noinspection PyShadowingBuiltins,PyUnresolvedReferences
def order_by_index(
seq: Sequence[Any], index: Iterable[int], iter: bool = False
-) -> Iter_any:
+) -> Iterable[Any]:
"""
Order a given sequence by an index sequence.
@@ -764,7 +664,9 @@ def numeric_regex_chooser(alg: NSType) -> str:
return utils.regex_chooser(alg).pattern[1:-1]
-def _split_apply(v: Any, key: MaybeKeyType = None) -> Iterator[str]:
+def _split_apply(
+ v: Any, key: Optional[Callable[[T], NatsortInType]] = None
+) -> Iterator[str]:
if key is not None:
v = key(v)
return utils.path_splitter(str(v))
@@ -781,11 +683,15 @@ if platform.system() == "Windows":
_windows_sort_cmp.restype = wintypes.INT
_winsort_key = cmp_to_key(_windows_sort_cmp)
- def os_sort_keygen(key: MaybeKeyType = None) -> OSSortKeyType:
+ def os_sort_keygen(
+ key: Optional[Callable[[Any], NatsortInType]] = None
+ ) -> Callable[[Any], NatsortOutType]:
return cast(
- OSSortKeyType, lambda x: tuple(map(_winsort_key, _split_apply(x, key)))
+ Callable[[Any], NatsortOutType],
+ lambda x: tuple(map(_winsort_key, _split_apply(x, key))),
)
+
else:
# For UNIX-based platforms, ICU performs MUCH better than locale
@@ -802,15 +708,16 @@ else:
except ImportError:
# No ICU installed
- def os_sort_keygen(key: MaybeKeyType = None) -> OSSortKeyType:
- return cast(
- OSSortKeyType,
- natsort_keygen(key=key, alg=ns.LOCALE | ns.PATH | ns.IGNORECASE),
- )
+ def os_sort_keygen(
+ key: Optional[Callable[[Any], NatsortInType]] = None
+ ) -> Callable[[Any], NatsortOutType]:
+ return natsort_keygen(key=key, alg=ns.LOCALE | ns.PATH | ns.IGNORECASE)
else:
# ICU installed
- def os_sort_keygen(key: MaybeKeyType = None) -> OSSortKeyType:
+ def os_sort_keygen(
+ key: Optional[Callable[[Any], NatsortInType]] = None
+ ) -> Callable[[Any], NatsortOutType]:
loc = natsort.compat.locale.get_icu_locale()
collator = icu.Collator.createInstance(loc)
collator.setAttribute(
@@ -857,19 +764,11 @@ os_sort_keygen
"""
-@overload
-def os_sorted(seq: Iter_path, key: None = None, reverse: bool = False) -> List_path:
- ...
-
-
-@overload
-def os_sorted(seq: Iter_any, key: KeyType, reverse: bool = False) -> List_any:
- ...
-
-
def os_sorted(
- seq: Iter_any, key: MaybeKeyType = None, reverse: bool = False
-) -> List_any:
+ seq: Iterable[T],
+ key: Optional[Callable[[T], NatsortInType]] = None,
+ reverse: bool = False,
+) -> List[T]:
"""
Sort elements in the same order as your operating system's file browser
diff --git a/natsort/utils.py b/natsort/utils.py
index 2bceb85..1e38bec 100644
--- a/natsort/utils.py
+++ b/natsort/utils.py
@@ -53,6 +53,7 @@ from typing import (
Match,
Optional,
Pattern,
+ TYPE_CHECKING,
Tuple,
Union,
cast,
@@ -70,9 +71,28 @@ from natsort.compat.locale import (
from natsort.ns_enum import NSType, NS_DUMB, ns
from natsort.unicode_numbers import digits_no_decimals, numeric_no_decimals
+if TYPE_CHECKING:
+ from typing_extensions import Protocol
+else:
+ Protocol = object
+
#
# Pre-define a slew of aggregate types which makes the type hinting below easier
#
+
+
+class SupportsDunderLT(Protocol):
+ def __lt__(self, __other: Any) -> bool:
+ ...
+
+
+class SupportsDunderGT(Protocol):
+ def __gt__(self, __other: Any) -> bool:
+ ...
+
+
+Sortable = Union[SupportsDunderLT, SupportsDunderGT]
+
StrToStr = Callable[[str], str]
AnyCall = Callable[[Any], Any]
@@ -83,27 +103,20 @@ BytesTransform = Union[BytesTuple, NestedBytesTuple]
BytesTransformer = Callable[[bytes], BytesTransform]
# For the number transform factory
-NumType = Union[float, int]
-MaybeNumType = Optional[NumType]
-NumTuple = Tuple[StrOrBytes, NumType]
-NestedNumTuple = Tuple[NumTuple]
-StrNumTuple = Tuple[Tuple[str], NumTuple]
-NestedStrNumTuple = Tuple[StrNumTuple]
-MaybeNumTransform = Union[NumTuple, NestedNumTuple, StrNumTuple, NestedStrNumTuple]
-MaybeNumTransformer = Callable[[MaybeNumType], MaybeNumTransform]
+BasicTuple = Tuple[Any, ...]
+NestedAnyTuple = Tuple[BasicTuple, ...]
+AnyTuple = Union[BasicTuple, NestedAnyTuple]
+NumTransform = AnyTuple
+NumTransformer = Callable[[Any], NumTransform]
# For the string component transform factory
StrBytesNum = Union[str, bytes, float, int]
StrTransformer = Callable[[str], StrBytesNum]
# For the final data transform factory
-TwoBlankTuple = Tuple[Tuple[()], Tuple[()]]
-TupleOfAny = Tuple[Any, ...]
-TupleOfStrAnyPair = Tuple[Tuple[str], TupleOfAny]
-FinalTransform = Union[TwoBlankTuple, TupleOfAny, TupleOfStrAnyPair]
+FinalTransform = AnyTuple
FinalTransformer = Callable[[Iterable[Any], str], FinalTransform]
-# For the path splitter
PathArg = Union[str, PurePath]
MatchFn = Callable[[str], Optional[Match]]
@@ -115,13 +128,8 @@ StrParser = Callable[[PathArg], FinalTransform]
PathSplitter = Callable[[PathArg], Tuple[FinalTransform, ...]]
# For the natsort key
-StrBytesPathNum = Union[str, bytes, float, int, PurePath]
-NatsortInType = Union[
- Optional[StrBytesPathNum], Iterable[Union[Optional[StrBytesPathNum], Iterable[Any]]]
-]
-NatsortOutType = Tuple[
- Union[StrBytesNum, Tuple[Union[StrBytesNum, Tuple[Any, ...]], ...]], ...
-]
+NatsortInType = Optional[Sortable]
+NatsortOutType = Tuple[Sortable, ...]
KeyType = Callable[[Any], NatsortInType]
MaybeKeyType = Optional[KeyType]
@@ -260,7 +268,7 @@ def natsort_key(
key: None,
string_func: Union[StrParser, PathSplitter],
bytes_func: BytesTransformer,
- num_func: MaybeNumTransformer,
+ num_func: NumTransformer,
) -> NatsortOutType:
...
@@ -271,7 +279,7 @@ def natsort_key(
key: KeyType,
string_func: Union[StrParser, PathSplitter],
bytes_func: BytesTransformer,
- num_func: MaybeNumTransformer,
+ num_func: NumTransformer,
) -> NatsortOutType:
...
@@ -281,7 +289,7 @@ def natsort_key(
key: MaybeKeyType,
string_func: Union[StrParser, PathSplitter],
bytes_func: BytesTransformer,
- num_func: MaybeNumTransformer,
+ num_func: NumTransformer,
) -> NatsortOutType:
"""
Key to sort strings and numbers naturally.
@@ -348,7 +356,7 @@ def natsort_key(
# If that failed, it must be a number.
except TypeError:
- return num_func(cast(NumType, val))
+ return num_func(val)
def parse_bytes_factory(alg: NSType) -> BytesTransformer:
@@ -386,7 +394,7 @@ def parse_bytes_factory(alg: NSType) -> BytesTransformer:
def parse_number_or_none_factory(
alg: NSType, sep: StrOrBytes, pre_sep: str
-) -> MaybeNumTransformer:
+) -> NumTransformer:
"""
Create a function that will format a number (or None) into a tuple.
@@ -418,8 +426,8 @@ def parse_number_or_none_factory(
nan_replace = float("+inf") if alg & ns.NANLAST else float("-inf")
def func(
- val: MaybeNumType, _nan_replace: float = nan_replace, _sep: StrOrBytes = sep
- ) -> NumTuple:
+ val: Any, _nan_replace: float = nan_replace, _sep: StrOrBytes = sep
+ ) -> BasicTuple:
"""Given a number, place it in a tuple with a leading null string."""
return _sep, (_nan_replace if val != val or val is None else val)
@@ -729,7 +737,7 @@ def final_data_transform_factory(
"""
if alg & ns.UNGROUPLETTERS and alg & ns.LOCALEALPHA:
swap = alg & NS_DUMB and alg & ns.LOWERCASEFIRST
- transform = cast(StrToStr, methodcaller("swapcase")) if swap else _no_op
+ transform = cast(StrToStr, methodcaller("swapcase") if swap else _no_op)
def func(
split_val: Iterable[NatsortInType],
@@ -835,11 +843,11 @@ def do_decoding(s: bytes, encoding: str) -> str:
@overload
-def do_decoding(s: NatsortInType, encoding: str) -> NatsortInType:
+def do_decoding(s: Any, encoding: str) -> Any:
...
-def do_decoding(s: NatsortInType, encoding: str) -> NatsortInType:
+def do_decoding(s: Any, encoding: str) -> Any:
"""
Helper to decode a *bytes* object, or return the object as-is.
diff --git a/tests/test_os_sorted.py b/tests/test_os_sorted.py
index d0ecc79..f714437 100644
--- a/tests/test_os_sorted.py
+++ b/tests/test_os_sorted.py
@@ -3,7 +3,6 @@
Testing for the OS sorting
"""
import platform
-from typing import cast
import natsort
import pytest
@@ -44,7 +43,7 @@ def test_os_sorted_misc_no_fail() -> None:
def test_os_sorted_key() -> None:
given = ["foo0", "foo2", "goo1"]
expected = ["foo0", "goo1", "foo2"]
- result = natsort.os_sorted(given, key=lambda x: cast(str, x).replace("g", "f"))
+ result = natsort.os_sorted(given, key=lambda x: x.replace("g", "f"))
assert result == expected
diff --git a/tests/test_parse_number_function.py b/tests/test_parse_number_function.py
index e5f417d..85d6b96 100644
--- a/tests/test_parse_number_function.py
+++ b/tests/test_parse_number_function.py
@@ -7,7 +7,7 @@ import pytest
from hypothesis import given
from hypothesis.strategies import floats, integers
from natsort.ns_enum import NSType, ns
-from natsort.utils import MaybeNumTransformer, parse_number_or_none_factory
+from natsort.utils import NumTransformer, parse_number_or_none_factory
@pytest.mark.usefixtures("with_locale_en_us")
@@ -22,7 +22,7 @@ from natsort.utils import MaybeNumTransformer, parse_number_or_none_factory
)
@given(x=floats(allow_nan=False) | integers())
def test_parse_number_factory_makes_function_that_returns_tuple(
- x: Union[float, int], alg: NSType, example_func: MaybeNumTransformer
+ x: Union[float, int], alg: NSType, example_func: NumTransformer
) -> None:
parse_number_func = parse_number_or_none_factory(alg, "", "xx")
assert parse_number_func(x) == example_func(x)
diff --git a/tox.ini b/tox.ini
index 74a7066..19ab53d 100644
--- a/tox.ini
+++ b/tox.ini
@@ -60,6 +60,7 @@ deps =
pytest
pytest-mock
fastnumbers
+ typing_extensions
commands =
mypy --strict natsort tests
skip_install = true