diff options
author | Seth Morton <seth.m.morton@gmail.com> | 2021-11-02 19:52:31 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-11-02 19:52:31 -0700 |
commit | 5eaed4145d174ac05752095ec205ce2cd2e90a5f (patch) | |
tree | 1fcc1c7ce9698515ea1ae3f239d4339783e6c4d6 | |
parent | afe12261977a219bef0d0c0e6a40e1f81cf44d4f (diff) | |
parent | 3461338e52292926bc2148dd5fdda37f253b5860 (diff) | |
download | natsort-5eaed4145d174ac05752095ec205ce2cd2e90a5f.tar.gz |
Merge pull request #138 from SethMMorton/type-hinting
Type hinting
36 files changed, 976 insertions, 434 deletions
diff --git a/.github/workflows/code-quality.yml b/.github/workflows/code-quality.yml index f95887b..89d31f2 100644 --- a/.github/workflows/code-quality.yml +++ b/.github/workflows/code-quality.yml @@ -38,7 +38,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v2 with: - python-version: '3.5' + python-version: '3.6' - name: Install Flake8 run: pip install flake8 flake8-import-order flake8-bugbear pep8-naming @@ -46,6 +46,24 @@ jobs: - name: Run Flake8 run: flake8 + type-checking: + name: Type Checking + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v2 + + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: '3.6' + + - name: Install MyPy + run: pip install mypy hypothesis pytest pytest-mock fastnumbers + + - name: Run MyPy + run: mypy --strict natsort tests + package-validation: name: Package Validation runs-on: ubuntu-latest @@ -56,7 +74,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v2 with: - python-version: '3.5' + python-version: '3.6' - name: Install Validators run: pip install twine check-manifest diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 3263030..ce64b65 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -15,7 +15,7 @@ jobs: runs-on: ${{ matrix.os }} strategy: matrix: - python-version: [3.5, 3.6, 3.7, 3.8, 3.9] + python-version: [3.6, 3.7, 3.8, 3.9] os: [ubuntu-latest] extras: [false] include: diff --git a/CHANGELOG.md b/CHANGELOG.md index 4e0c86d..1e1d147 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,15 @@ Unreleased --- +### Changed + + - The `ns` enum is now implemented as an `enum.IntEnum` instead of a + `collections.namedtuple` + +### Removed + + - Support for Python 3.4 and Python 3.5 + [7.1.1] - 2021-01-24 --- diff --git a/docs/api.rst b/docs/api.rst index 86a26ee..39d7cec 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -117,3 +117,25 @@ Given your chosen algorithm (selected using the :class:`~natsort.ns` enum), the corresponding regular expression to locate numbers will be returned. .. autofunction:: numeric_regex_chooser + +Help With Type Hinting +++++++++++++++++++++++ + +If you need to explictly specify the types that natsort accepts or returns +in your code, the following types have been exposed for your convenience. + ++--------------------------------+----------------------------------------------------------------------------------------+ +| Type | Purpose | ++================================+========================================================================================+ +|:attr:`natsort.NatsortKeyType` | Returned by :func:`natsort.natsort_keygen`, and type of :attr:`natsort.natsort_key` | ++--------------------------------+----------------------------------------------------------------------------------------+ +|:attr:`natsort.OSSortKeyType` | Returned by :func:`natsort.os_sort_keygen`, and type of :attr:`natsort.os_sort_key` | ++--------------------------------+----------------------------------------------------------------------------------------+ +|:attr:`natsort.KeyType` | Type of `key` argument to :func:`natsort.natsorted` and :func:`natsort.natsort_keygen` | ++--------------------------------+----------------------------------------------------------------------------------------+ +|:attr:`natsort.NatsortInType` | The input type of :attr:`natsort.NatsortKeyType` | ++--------------------------------+----------------------------------------------------------------------------------------+ +|:attr:`natsort.NatsortOutType` | The output type of :attr:`natsort.NatsortKeyType` | ++--------------------------------+----------------------------------------------------------------------------------------+ +|:attr:`natsort.NSType` | The type of the :class:`ns` enum | ++--------------------------------+----------------------------------------------------------------------------------------+ diff --git a/docs/conf.py b/docs/conf.py index 05919ec..feab38e 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -35,6 +35,7 @@ extensions = [ "sphinx.ext.napoleon", "m2r2", ] +autodoc_typehints = "none" # Add any paths that contain templates here, relative to this directory. templates_path = ["_templates"] diff --git a/natsort/__init__.py b/natsort/__init__.py index 8c8f87f..4207f79 100644 --- a/natsort/__init__.py +++ b/natsort/__init__.py @@ -1,6 +1,8 @@ # -*- coding: utf-8 -*- from natsort.natsort import ( + NatsortKeyType, + OSSortKeyType, as_ascii, as_utf8, decoder, @@ -11,7 +13,6 @@ from natsort.natsort import ( natsort_key, natsort_keygen, natsorted, - ns, numeric_regex_chooser, order_by_index, os_sort_key, @@ -19,7 +20,8 @@ from natsort.natsort import ( os_sorted, realsorted, ) -from natsort.utils import chain_functions +from natsort.ns_enum import NSType, ns +from natsort.utils import KeyType, NatsortInType, NatsortOutType, chain_functions __version__ = "7.1.1" @@ -42,7 +44,13 @@ __all__ = [ "os_sort_key", "os_sort_keygen", "os_sorted", + "NatsortKeyType", + "OSSortKeyType", + "KeyType", + "NatsortInType", + "NatsortOutType", + "NSType", ] # Add the ns keys to this namespace for convenience. -globals().update(ns._asdict()) +globals().update({name: value for name, value in ns.__members__.items()}) diff --git a/natsort/__main__.py b/natsort/__main__.py index e2f3299..4dffc4b 100644 --- a/natsort/__main__.py +++ b/natsort/__main__.py @@ -1,12 +1,53 @@ # -*- coding: utf-8 -*- +import argparse import sys +from typing import Callable, Iterable, List, Optional, Pattern, Tuple, Union, cast import natsort from natsort.utils import regex_chooser - -def main(*arguments): +Num = Union[float, int] +NumIter = Iterable[Num] +NumPair = Tuple[Num, Num] +NumPairIter = Iterable[NumPair] +NumConverter = Callable[[str], Num] + + +class TypedArgs(argparse.Namespace): + paths: bool + filter: Optional[List[NumPair]] + reverse_filter: Optional[List[NumPair]] + exclude: List[Num] + reverse: bool + number_type: str + nosign: bool + sign: bool + noexp: bool + locale: bool + entries: List[str] + + def __init__( + self, + filter: Optional[List[NumPair]] = None, + reverse_filter: Optional[List[NumPair]] = None, + exclude: Optional[List[Num]] = None, + paths: bool = False, + reverse: bool = False, + ) -> None: + """Used by testing only""" + self.filter = filter + self.reverse_filter = reverse_filter + self.exclude = [] if exclude is None else exclude + self.paths = paths + self.reverse = reverse + self.number_type = "int" + self.signed = False + self.exp = True + self.locale = False + + +def main(*arguments: str) -> None: """ Performs a natural sort on entries given on the command-line. @@ -17,7 +58,8 @@ def main(*arguments): from textwrap import dedent parser = ArgumentParser( - description=dedent(main.__doc__), formatter_class=RawDescriptionHelpFormatter + description=dedent(cast(str, main.__doc__)), + formatter_class=RawDescriptionHelpFormatter, ) parser.add_argument( "--version", @@ -126,7 +168,7 @@ def main(*arguments): help="The entries to sort. Taken from stdin if nothing is given on " "the command line.", ) - args = parser.parse_args(arguments or None) + args = parser.parse_args(arguments or None, namespace=TypedArgs()) # Make sure the filter range is given properly. Does nothing if no filter args.filter = check_filters(args.filter) @@ -139,7 +181,7 @@ def main(*arguments): sort_and_print_entries(entries, args) -def range_check(low, high): +def range_check(low: Num, high: Num) -> NumPair: """ Verify that that given range has a low lower than the high. @@ -164,7 +206,7 @@ def range_check(low, high): return low, high -def check_filters(filters): +def check_filters(filters: Optional[NumPairIter]) -> Optional[List[NumPair]]: """ Execute range_check for every element of an iterable. @@ -192,7 +234,13 @@ def check_filters(filters): raise ValueError("Error in --filter: " + str(err)) -def keep_entry_range(entry, lows, highs, converter, regex): +def keep_entry_range( + entry: str, + lows: NumIter, + highs: NumIter, + converter: NumConverter, + regex: Pattern[str], +) -> bool: """ Check if an entry falls into a desired range. @@ -224,7 +272,9 @@ def keep_entry_range(entry, lows, highs, converter, regex): ) -def keep_entry_value(entry, values, converter, regex): +def keep_entry_value( + entry: str, values: NumIter, converter: NumConverter, regex: Pattern[str] +) -> bool: """ Check if an entry does not match a given value. @@ -249,13 +299,13 @@ def keep_entry_value(entry, values, converter, regex): return not any(converter(num) in values for num in regex.findall(entry)) -def sort_and_print_entries(entries, args): +def sort_and_print_entries(entries: List[str], args: TypedArgs) -> None: """Sort the entries, applying the filters first if necessary.""" # Extract the proper number type. is_float = args.number_type in ("float", "real", "f", "r") signed = args.signed or args.number_type in ("real", "r") - alg = ( + alg: int = ( natsort.ns.FLOAT * is_float | natsort.ns.SIGNED * signed | natsort.ns.NOEXP * (not args.exp) diff --git a/natsort/compat/fake_fastnumbers.py b/natsort/compat/fake_fastnumbers.py index 7177551..5d44605 100644 --- a/natsort/compat/fake_fastnumbers.py +++ b/natsort/compat/fake_fastnumbers.py @@ -3,14 +3,12 @@ This module is intended to replicate some of the functionality from the fastnumbers module in the event that module is not installed. """ - -# Std. lib imports. import unicodedata +from typing import Callable, FrozenSet, Optional, Union -# Local imports. from natsort.unicode_numbers import decimal_chars -NAN_INF = [ +_NAN_INF = [ "INF", "INf", "Inf", @@ -28,21 +26,24 @@ NAN_INF = [ "nAN", "Nan", ] -NAN_INF.extend(["+" + x[:2] for x in NAN_INF] + ["-" + x[:2] for x in NAN_INF]) -NAN_INF = frozenset(NAN_INF) +_NAN_INF.extend(["+" + x[:2] for x in _NAN_INF] + ["-" + x[:2] for x in _NAN_INF]) +NAN_INF = frozenset(_NAN_INF) ASCII_NUMS = "0123456789+-" POTENTIAL_FIRST_CHAR = frozenset(decimal_chars + list(ASCII_NUMS + ".")) +StrOrFloat = Union[str, float] +StrOrInt = Union[str, int] + # noinspection PyIncorrectDocstring def fast_float( - x, - key=lambda x: x, - nan=None, - _uni=unicodedata.numeric, - _nan_inf=NAN_INF, - _first_char=POTENTIAL_FIRST_CHAR, -): + x: str, + key: Callable[[str], StrOrFloat] = lambda x: x, + nan: Optional[StrOrFloat] = None, + _uni: Callable[[str, StrOrFloat], StrOrFloat] = unicodedata.numeric, + _nan_inf: FrozenSet[str] = NAN_INF, + _first_char: FrozenSet[str] = POTENTIAL_FIRST_CHAR, +) -> StrOrFloat: """ Convert a string to a float quickly, return input as-is if not possible. @@ -65,8 +66,8 @@ def fast_float( """ if x[0] in _first_char or x.lstrip()[:3] in _nan_inf: try: - x = float(x) - return nan if nan is not None and x != x else x + ret = float(x) + return nan if nan is not None and ret != ret else ret except ValueError: try: return _uni(x, key(x)) if len(x) == 1 else key(x) @@ -81,8 +82,11 @@ def fast_float( # noinspection PyIncorrectDocstring def fast_int( - x, key=lambda x: x, _uni=unicodedata.digit, _first_char=POTENTIAL_FIRST_CHAR -): + x: str, + key: Callable[[str], StrOrInt] = lambda x: x, + _uni: Callable[[str, StrOrInt], StrOrInt] = unicodedata.digit, + _first_char: FrozenSet[str] = POTENTIAL_FIRST_CHAR, +) -> StrOrInt: """ Convert a string to a int quickly, return input as-is if not possible. diff --git a/natsort/compat/fastnumbers.py b/natsort/compat/fastnumbers.py index 4f4d75a..049030d 100644 --- a/natsort/compat/fastnumbers.py +++ b/natsort/compat/fastnumbers.py @@ -3,9 +3,10 @@ Interface for natsort to access fastnumbers functions without having to worry if it is actually installed. """ - import re +__all__ = ["fast_float", "fast_int"] + def is_supported_fastnumbers(fastnumbers_version: str) -> bool: match = re.match( @@ -34,4 +35,4 @@ try: if not is_supported_fastnumbers(fn_ver): raise ImportError # pragma: no cover except ImportError: - from natsort.compat.fake_fastnumbers import fast_float, fast_int # noqa: F401 + from natsort.compat.fake_fastnumbers import fast_float, fast_int # type: ignore diff --git a/natsort/compat/locale.py b/natsort/compat/locale.py index ccb5592..9af5e7a 100644 --- a/natsort/compat/locale.py +++ b/natsort/compat/locale.py @@ -3,9 +3,11 @@ Interface for natsort to access locale functionality without having to worry about if it is using PyICU or the built-in locale. """ - -# Std. lib imports. import sys +from typing import Callable, Union, cast + +StrOrBytes = Union[str, bytes] +TrxfmFunc = Callable[[str], StrOrBytes] # This string should be sorted after any other byte string because # it contains the max unicode character repeated 20 times. @@ -13,6 +15,11 @@ import sys null_string = "" null_string_max = chr(sys.maxunicode) * 20 +# This variable could be str or bytes depending on the locale library +# being used, so give the type-checker this information. +null_string_locale: StrOrBytes +null_string_locale_max: StrOrBytes + # strxfrm can be buggy (especially on BSD-based systems), # so prefer icu if available. try: # noqa: C901 @@ -26,26 +33,26 @@ try: # noqa: C901 # You would need some odd data to come after that. null_string_locale_max = b"x7f" * 50 - def dumb_sort(): + def dumb_sort() -> bool: return False # If using icu, get the locale from the current global locale, - def get_icu_locale(): + def get_icu_locale() -> str: try: - return icu.Locale(".".join(getlocale())) + return cast(str, icu.Locale(".".join(getlocale()))) except TypeError: # pragma: no cover - return icu.Locale() + return cast(str, icu.Locale()) - def get_strxfrm(): - return icu.Collator.createInstance(get_icu_locale()).getSortKey + def get_strxfrm() -> TrxfmFunc: + return cast(TrxfmFunc, icu.Collator.createInstance(get_icu_locale()).getSortKey) - def get_thousands_sep(): + def get_thousands_sep() -> str: sep = icu.DecimalFormatSymbols.kGroupingSeparatorSymbol - return icu.DecimalFormatSymbols(get_icu_locale()).getSymbol(sep) + return cast(str, icu.DecimalFormatSymbols(get_icu_locale()).getSymbol(sep)) - def get_decimal_point(): + def get_decimal_point() -> str: sep = icu.DecimalFormatSymbols.kDecimalSeparatorSymbol - return icu.DecimalFormatSymbols(get_icu_locale()).getSymbol(sep) + return cast(str, icu.DecimalFormatSymbols(get_icu_locale()).getSymbol(sep)) except ImportError: @@ -57,14 +64,14 @@ except ImportError: # On some systems, locale is broken and does not sort in the expected # order. We will try to detect this and compensate. - def dumb_sort(): + def dumb_sort() -> bool: return strxfrm("A") < strxfrm("a") - def get_strxfrm(): + def get_strxfrm() -> TrxfmFunc: return strxfrm - def get_thousands_sep(): - sep = locale.localeconv()["thousands_sep"] + def get_thousands_sep() -> str: + sep = cast(str, locale.localeconv()["thousands_sep"]) # If this locale library is broken, some of the thousands separator # characters are incorrectly blank. Here is a lookup table of the # corrections I am aware of. @@ -111,5 +118,5 @@ except ImportError: else: return sep - def get_decimal_point(): - return locale.localeconv()["decimal_point"] + def get_decimal_point() -> str: + return cast(str, locale.localeconv()["decimal_point"]) diff --git a/natsort/natsort.py b/natsort/natsort.py index 8e3a7b5..a95f9a9 100644 --- a/natsort/natsort.py +++ b/natsort/natsort.py @@ -9,13 +9,51 @@ The majority of the "work" is defined in utils.py. import platform from functools import partial from operator import itemgetter +from typing import ( + Any, + Callable, + Iterable, + Iterator, + List, + Optional, + Sequence, + Tuple, + Union, + cast, + overload, +) import natsort.compat.locale from natsort import utils -from natsort.ns_enum import NS_DUMB, ns - - -def decoder(encoding): +from natsort.ns_enum import NSType, NS_DUMB, ns +from natsort.utils import ( + KeyType, + MaybeKeyType, + NatsortInType, + NatsortOutType, + StrBytesNum, + StrBytesPathNum, +) + +# Common input and output types +Iter_ns = Iterable[NatsortInType] +Iter_any = Iterable[Any] +List_ns = List[NatsortInType] +List_any = List[Any] +List_int = List[int] + +# The type that natsort_key returns +NatsortKeyType = Callable[[NatsortInType], NatsortOutType] + +# Types for os_sorted +OSSortInType = Iterable[Optional[StrBytesPathNum]] +OSSortOutType = Tuple[Union[StrBytesNum, Tuple[StrBytesNum, ...]], ...] +OSSortKeyType = Callable[[Optional[StrBytesPathNum]], OSSortOutType] +Iter_path = Iterable[Optional[StrBytesPathNum]] +List_path = List[StrBytesPathNum] + + +def decoder(encoding: str) -> Callable[[NatsortInType], NatsortInType]: """ Return a function that can be used to decode bytes to unicode. @@ -56,7 +94,7 @@ def decoder(encoding): return partial(utils.do_decoding, encoding=encoding) -def as_ascii(s): +def as_ascii(s: NatsortInType) -> NatsortInType: """ Function to decode an input with the ASCII codec, or return as-is. @@ -79,7 +117,7 @@ def as_ascii(s): return utils.do_decoding(s, "ascii") -def as_utf8(s): +def as_utf8(s: NatsortInType) -> NatsortInType: """ Function to decode an input with the UTF-8 codec, or return as-is. @@ -102,7 +140,9 @@ def as_utf8(s): return utils.do_decoding(s, "utf-8") -def natsort_keygen(key=None, alg=ns.DEFAULT): +def natsort_keygen( + key: MaybeKeyType = None, alg: NSType = ns.DEFAULT +) -> NatsortKeyType: """ Generate a key to sort strings and numbers naturally. @@ -212,7 +252,26 @@ natsort_keygen """ -def natsorted(seq, key=None, reverse=False, alg=ns.DEFAULT): +@overload +def natsorted( + seq: Iter_ns, key: None = None, reverse: bool = False, alg: NSType = ns.DEFAULT +) -> List_ns: + ... + + +@overload +def natsorted( + seq: Iter_any, key: KeyType, reverse: bool = False, alg: NSType = ns.DEFAULT +) -> List_any: + ... + + +def natsorted( + seq: Iter_any, + key: MaybeKeyType = None, + reverse: bool = False, + alg: NSType = ns.DEFAULT, +) -> List_any: """ Sorts an iterable naturally. @@ -257,11 +316,29 @@ def natsorted(seq, key=None, reverse=False, alg=ns.DEFAULT): ['num2', 'num3', 'num5'] """ - key = natsort_keygen(key, alg) - return sorted(seq, reverse=reverse, key=key) + return sorted(seq, reverse=reverse, key=natsort_keygen(key, alg)) + + +@overload +def humansorted( + seq: Iter_ns, key: None = None, reverse: bool = False, alg: NSType = ns.DEFAULT +) -> List_ns: + ... -def humansorted(seq, key=None, reverse=False, alg=ns.DEFAULT): +@overload +def humansorted( + seq: Iter_any, key: KeyType, reverse: bool = False, alg: NSType = ns.DEFAULT +) -> List_any: + ... + + +def humansorted( + seq: Iter_any, + key: MaybeKeyType = None, + reverse: bool = False, + alg: NSType = ns.DEFAULT, +) -> List_any: """ Convenience function to properly sort non-numeric characters. @@ -313,7 +390,26 @@ def humansorted(seq, key=None, reverse=False, alg=ns.DEFAULT): return natsorted(seq, key, reverse, alg | ns.LOCALE) -def realsorted(seq, key=None, reverse=False, alg=ns.DEFAULT): +@overload +def realsorted( + seq: Iter_ns, key: None = None, reverse: bool = False, alg: NSType = ns.DEFAULT +) -> List_ns: + ... + + +@overload +def realsorted( + seq: Iter_any, key: KeyType, reverse: bool = False, alg: NSType = ns.DEFAULT +) -> List_any: + ... + + +def realsorted( + seq: Iter_any, + key: MaybeKeyType = None, + reverse: bool = False, + alg: NSType = ns.DEFAULT, +) -> List_any: """ Convenience function to properly sort signed floats. @@ -366,7 +462,26 @@ def realsorted(seq, key=None, reverse=False, alg=ns.DEFAULT): return natsorted(seq, key, reverse, alg | ns.REAL) -def index_natsorted(seq, key=None, reverse=False, alg=ns.DEFAULT): +@overload +def index_natsorted( + seq: Iter_ns, key: None = None, reverse: bool = False, alg: NSType = ns.DEFAULT +) -> List_int: + ... + + +@overload +def index_natsorted( + seq: Iter_any, key: KeyType, reverse: bool = False, alg: NSType = ns.DEFAULT +) -> List_int: + ... + + +def index_natsorted( + seq: Iter_any, + key: MaybeKeyType = None, + reverse: bool = False, + alg: NSType = ns.DEFAULT, +) -> List_int: """ Determine the list of the indexes used to sort the input sequence. @@ -422,12 +537,13 @@ def index_natsorted(seq, key=None, reverse=False, alg=ns.DEFAULT): ['baz', 'foo', 'bar'] """ + newkey: KeyType if key is None: newkey = itemgetter(1) else: - def newkey(x): - return key(itemgetter(1)(x)) + def newkey(x: Any) -> NatsortInType: + return cast(KeyType, key)(itemgetter(1)(x)) # Pair the index and sequence together, then sort by element index_seq_pair = [(x, y) for x, y in enumerate(seq)] @@ -435,7 +551,26 @@ def index_natsorted(seq, key=None, reverse=False, alg=ns.DEFAULT): return [x for x, _ in index_seq_pair] -def index_humansorted(seq, key=None, reverse=False, alg=ns.DEFAULT): +@overload +def index_humansorted( + seq: Iter_ns, key: None = None, reverse: bool = False, alg: NSType = ns.DEFAULT +) -> List_int: + ... + + +@overload +def index_humansorted( + seq: Iter_any, key: KeyType, reverse: bool = False, alg: NSType = ns.DEFAULT +) -> List_int: + ... + + +def index_humansorted( + seq: Iter_any, + key: MaybeKeyType = None, + reverse: bool = False, + alg: NSType = ns.DEFAULT, +) -> List_int: """ This is a wrapper around ``index_natsorted(seq, alg=ns.LOCALE)``. @@ -484,7 +619,26 @@ def index_humansorted(seq, key=None, reverse=False, alg=ns.DEFAULT): return index_natsorted(seq, key, reverse, alg | ns.LOCALE) -def index_realsorted(seq, key=None, reverse=False, alg=ns.DEFAULT): +@overload +def index_realsorted( + seq: Iter_ns, key: None = None, reverse: bool = False, alg: NSType = ns.DEFAULT +) -> List_int: + ... + + +@overload +def index_realsorted( + seq: Iter_any, key: KeyType, reverse: bool = False, alg: NSType = ns.DEFAULT +) -> List_int: + ... + + +def index_realsorted( + seq: Iter_any, + key: MaybeKeyType = None, + reverse: bool = False, + alg: NSType = ns.DEFAULT, +) -> List_int: """ This is a wrapper around ``index_natsorted(seq, alg=ns.REAL)``. @@ -530,7 +684,9 @@ def index_realsorted(seq, key=None, reverse=False, alg=ns.DEFAULT): # noinspection PyShadowingBuiltins,PyUnresolvedReferences -def order_by_index(seq, index, iter=False): +def order_by_index( + seq: Sequence[Any], index: Iterable[int], iter: bool = False +) -> Iter_any: """ Order a given sequence by an index sequence. @@ -589,7 +745,7 @@ def order_by_index(seq, index, iter=False): return (seq[i] for i in index) if iter else [seq[i] for i in index] -def numeric_regex_chooser(alg): +def numeric_regex_chooser(alg: NSType) -> str: """ Select an appropriate regex for the type of number of interest. @@ -608,7 +764,7 @@ def numeric_regex_chooser(alg): return utils.regex_chooser(alg).pattern[1:-1] -def _split_apply(v, key=None): +def _split_apply(v: Any, key: MaybeKeyType = None) -> Iterator[str]: if key is not None: v = key(v) return utils.path_splitter(str(v)) @@ -617,7 +773,7 @@ def _split_apply(v, key=None): # Choose the implementation based on the host OS if platform.system() == "Windows": - from ctypes import wintypes, windll + from ctypes import wintypes, windll # type: ignore from functools import cmp_to_key _windows_sort_cmp = windll.Shlwapi.StrCmpLogicalW @@ -625,8 +781,10 @@ if platform.system() == "Windows": _windows_sort_cmp.restype = wintypes.INT _winsort_key = cmp_to_key(_windows_sort_cmp) - def os_sort_keygen(key=None): - return lambda x: tuple(map(_winsort_key, _split_apply(x, key))) + def os_sort_keygen(key: MaybeKeyType = None) -> OSSortKeyType: + return cast( + OSSortKeyType, lambda x: tuple(map(_winsort_key, _split_apply(x, key))) + ) else: @@ -645,12 +803,15 @@ else: except ImportError: # No ICU installed - def os_sort_keygen(key=None): - return natsort_keygen(key=key, alg=ns.LOCALE | ns.PATH | ns.IGNORECASE) + def os_sort_keygen(key: MaybeKeyType = None) -> OSSortKeyType: + return cast( + OSSortKeyType, + natsort_keygen(key=key, alg=ns.LOCALE | ns.PATH | ns.IGNORECASE), + ) else: # ICU installed - def os_sort_keygen(key=None): + def os_sort_keygen(key: MaybeKeyType = None) -> OSSortKeyType: loc = natsort.compat.locale.get_icu_locale() collator = icu.Collator.createInstance(loc) collator.setAttribute( @@ -697,7 +858,19 @@ os_sort_keygen """ -def os_sorted(seq, key=None, reverse=False): +@overload +def os_sorted(seq: Iter_path, key: None = None, reverse: bool = False) -> List_path: + ... + + +@overload +def os_sorted(seq: Iter_any, key: KeyType, reverse: bool = False) -> List_any: + ... + + +def os_sorted( + seq: Iter_any, key: MaybeKeyType = None, reverse: bool = False +) -> List_any: """ Sort elements in the same order as your operating system's file browser diff --git a/natsort/ns_enum.py b/natsort/ns_enum.py index 283a793..c147909 100644 --- a/natsort/ns_enum.py +++ b/natsort/ns_enum.py @@ -4,73 +4,15 @@ This module defines the "ns" enum for natsort is used to determine what algorithm natsort uses. """ -import collections - -# The below are the base ns options. The values will be stored as powers -# of two so bitmasks can be used to extract the user's requested options. -enum_options = [ - "FLOAT", - "SIGNED", - "NOEXP", - "PATH", - "LOCALEALPHA", - "LOCALENUM", - "IGNORECASE", - "LOWERCASEFIRST", - "GROUPLETTERS", - "UNGROUPLETTERS", - "NANLAST", - "COMPATIBILITYNORMALIZE", - "NUMAFTER", -] - -# Following were previously options but are now defaults. -enum_do_nothing = ["DEFAULT", "INT", "UNSIGNED"] - -# The following are bitwise-OR combinations of other fields. -enum_combos = [("REAL", ("FLOAT", "SIGNED")), ("LOCALE", ("LOCALEALPHA", "LOCALENUM"))] - -# The following are aliases for other fields. -enum_aliases = [ - ("I", "INT"), - ("U", "UNSIGNED"), - ("F", "FLOAT"), - ("S", "SIGNED"), - ("R", "REAL"), - ("N", "NOEXP"), - ("P", "PATH"), - ("LA", "LOCALEALPHA"), - ("LN", "LOCALENUM"), - ("L", "LOCALE"), - ("IC", "IGNORECASE"), - ("LF", "LOWERCASEFIRST"), - ("G", "GROUPLETTERS"), - ("UG", "UNGROUPLETTERS"), - ("C", "UNGROUPLETTERS"), - ("CAPITALFIRST", "UNGROUPLETTERS"), - ("NL", "NANLAST"), - ("CN", "COMPATIBILITYNORMALIZE"), - ("NA", "NUMAFTER"), -] - -# Construct the list of bitwise distinct enums with their fields. -enum_fields = collections.OrderedDict( - (name, 1 << i) for i, name in enumerate(enum_options) -) -enum_fields.update((name, 0) for name in enum_do_nothing) - -for name, combo in enum_combos: - combined_value = enum_fields[combo[0]] - for combo_name in combo[1:]: - combined_value |= enum_fields[combo_name] - enum_fields[name] = combined_value - -enum_fields.update((alias, enum_fields[name]) for alias, name in enum_aliases) - - -# Subclass the namedtuple to improve the docstring. -# noinspection PyUnresolvedReferences -class _NSEnum(collections.namedtuple("_NSEnum", enum_fields.keys())): +import enum +import itertools +import typing + + +_counter = itertools.count(0) + + +class ns(enum.IntEnum): # noqa: N801 """ Enum to control the `natsort` algorithm. @@ -186,10 +128,35 @@ class _NSEnum(collections.namedtuple("_NSEnum", enum_fields.keys())): """ + # The below are the base ns options. The values will be stored as powers + # of two so bitmasks can be used to extract the user's requested options. + FLOAT = F = 1 << next(_counter) + SIGNED = S = 1 << next(_counter) + NOEXP = N = 1 << next(_counter) + PATH = P = 1 << next(_counter) + LOCALEALPHA = LA = 1 << next(_counter) + LOCALENUM = LN = 1 << next(_counter) + IGNORECASE = IC = 1 << next(_counter) + LOWERCASEFIRST = LF = 1 << next(_counter) + GROUPLETTERS = G = 1 << next(_counter) + UNGROUPLETTERS = CAPITALFIRST = C = UG = 1 << next(_counter) + NANLAST = NL = 1 << next(_counter) + COMPATIBILITYNORMALIZE = CN = 1 << next(_counter) + NUMAFTER = NA = 1 << next(_counter) + + # Following were previously options but are now defaults. + DEFAULT = 0 + INT = I = 0 # noqa: E741 + UNSIGNED = U = 0 + + # The following are bitwise-OR combinations of other fields. + REAL = R = FLOAT | SIGNED + LOCALE = L = LOCALEALPHA | LOCALENUM -# Here is where the instance of the ns enum that will be exported is created. -# It is a poor-man's singleton. -ns = _NSEnum(*enum_fields.values()) # The below is private for internal use only. NS_DUMB = 1 << 31 + +# An integer can be used in place of the ns enum so make the +# type to use for this enum a union of it and an inteter. +NSType = typing.Union[ns, int] diff --git a/natsort/py.typed b/natsort/py.typed new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/natsort/py.typed diff --git a/natsort/utils.py b/natsort/utils.py index 3f6641b..7102f41 100644 --- a/natsort/utils.py +++ b/natsort/utils.py @@ -38,19 +38,93 @@ that ensures "val" is a local variable instead of global variable and thus has a slightly improved performance at runtime. """ - import re from functools import partial, reduce from itertools import chain as ichain from operator import methodcaller from pathlib import PurePath +from typing import ( + Any, + Callable, + Dict, + Iterable, + Iterator, + List, + Match, + Optional, + Pattern, + Tuple, + Union, + cast, + overload, +) from unicodedata import normalize from natsort.compat.fastnumbers import fast_float, fast_int -from natsort.compat.locale import get_decimal_point, get_strxfrm, get_thousands_sep -from natsort.ns_enum import NS_DUMB, ns +from natsort.compat.locale import ( + StrOrBytes, + get_decimal_point, + get_strxfrm, + get_thousands_sep, +) +from natsort.ns_enum import NSType, NS_DUMB, ns from natsort.unicode_numbers import digits_no_decimals, numeric_no_decimals +# +# Pre-define a slew of aggregate types which makes the type hinting below easier +# +StrToStr = Callable[[str], str] +AnyCall = Callable[[Any], Any] + +# For the bytes transform factory +BytesTuple = Tuple[bytes] +NestedBytesTuple = Tuple[Tuple[bytes]] +BytesTransform = Union[BytesTuple, NestedBytesTuple] +BytesTransformer = Callable[[bytes], BytesTransform] + +# For the number transform factory +NumType = Union[float, int] +MaybeNumType = Optional[NumType] +NumTuple = Tuple[StrOrBytes, NumType] +NestedNumTuple = Tuple[NumTuple] +StrNumTuple = Tuple[Tuple[str], NumTuple] +NestedStrNumTuple = Tuple[StrNumTuple] +MaybeNumTransform = Union[NumTuple, NestedNumTuple, StrNumTuple, NestedStrNumTuple] +MaybeNumTransformer = Callable[[MaybeNumType], MaybeNumTransform] + +# For the string component transform factory +StrBytesNum = Union[str, bytes, float, int] +StrTransformer = Callable[[str], StrBytesNum] + +# For the final data transform factory +TwoBlankTuple = Tuple[Tuple[()], Tuple[()]] +TupleOfAny = Tuple[Any, ...] +TupleOfStrAnyPair = Tuple[Tuple[str], TupleOfAny] +FinalTransform = Union[TwoBlankTuple, TupleOfAny, TupleOfStrAnyPair] +FinalTransformer = Callable[[Iterable[Any], str], FinalTransform] + +# For the string parsing factory +StrSplitter = Callable[[str], Iterable[str]] +StrParser = Callable[[str], FinalTransform] + +# For the path splitter +PathArg = Union[str, PurePath] +MatchFn = Callable[[str], Optional[Match]] + +# For the path parsing factory +PathSplitter = Callable[[PathArg], Tuple[FinalTransform, ...]] + +# For the natsort key +StrBytesPathNum = Union[str, bytes, float, int, PurePath] +NatsortInType = Union[ + Optional[StrBytesPathNum], Iterable[Union[Optional[StrBytesPathNum], Iterable[Any]]] +] +NatsortOutType = Tuple[ + Union[StrBytesNum, Tuple[Union[StrBytesNum, Tuple[Any, ...]], ...]], ... +] +KeyType = Callable[[Any], NatsortInType] +MaybeKeyType = Optional[KeyType] + class NumericalRegularExpressions: """ @@ -62,51 +136,51 @@ class NumericalRegularExpressions: """ # All unicode numeric characters (minus the decimal characters). - numeric = numeric_no_decimals + numeric: str = numeric_no_decimals # All unicode digit characters (minus the decimal characters). - digits = digits_no_decimals + digits: str = digits_no_decimals # Regular expression to match exponential component of a float. - exp = r"(?:[eE][-+]?\d+)?" + exp: str = r"(?:[eE][-+]?\d+)?" # Regular expression to match a floating point number. - float_num = r"(?:\d+\.?\d*|\.\d+)" + float_num: str = r"(?:\d+\.?\d*|\.\d+)" @classmethod - def _construct_regex(cls, fmt): + def _construct_regex(cls, fmt: str) -> Pattern[str]: """Given a format string, construct the regex with class attributes.""" return re.compile(fmt.format(**vars(cls)), flags=re.U) @classmethod - def int_sign(cls): + def int_sign(cls) -> Pattern[str]: """Regular expression to match a signed int.""" return cls._construct_regex(r"([-+]?\d+|[{digits}])") @classmethod - def int_nosign(cls): + def int_nosign(cls) -> Pattern[str]: """Regular expression to match an unsigned int.""" return cls._construct_regex(r"(\d+|[{digits}])") @classmethod - def float_sign_exp(cls): + def float_sign_exp(cls) -> Pattern[str]: """Regular expression to match a signed float with exponent.""" return cls._construct_regex(r"([-+]?{float_num}{exp}|[{numeric}])") @classmethod - def float_nosign_exp(cls): + def float_nosign_exp(cls) -> Pattern[str]: """Regular expression to match an unsigned float with exponent.""" return cls._construct_regex(r"({float_num}{exp}|[{numeric}])") @classmethod - def float_sign_noexp(cls): + def float_sign_noexp(cls) -> Pattern[str]: """Regular expression to match a signed float without exponent.""" return cls._construct_regex(r"([-+]?{float_num}|[{numeric}])") @classmethod - def float_nosign_noexp(cls): + def float_nosign_noexp(cls) -> Pattern[str]: """Regular expression to match an unsigned float without exponent.""" return cls._construct_regex(r"({float_num}|[{numeric}])") -def regex_chooser(alg): +def regex_chooser(alg: NSType) -> Pattern[str]: """ Select an appropriate regex for the type of number of interest. @@ -136,12 +210,12 @@ def regex_chooser(alg): }[alg] -def _no_op(x): +def _no_op(x: Any) -> Any: """A function that does nothing and returns the input as-is.""" return x -def _normalize_input_factory(alg): +def _normalize_input_factory(alg: NSType) -> StrToStr: """ Create a function that will normalize unicode input data. @@ -161,7 +235,35 @@ def _normalize_input_factory(alg): return partial(normalize, normalization_form) -def natsort_key(val, key, string_func, bytes_func, num_func): +@overload +def natsort_key( + val: NatsortInType, + key: None, + string_func: Union[StrParser, PathSplitter], + bytes_func: BytesTransformer, + num_func: MaybeNumTransformer, +) -> NatsortOutType: + ... + + +@overload +def natsort_key( + val: Any, + key: KeyType, + string_func: Union[StrParser, PathSplitter], + bytes_func: BytesTransformer, + num_func: MaybeNumTransformer, +) -> NatsortOutType: + ... + + +def natsort_key( + val: Union[NatsortInType, Any], + key: MaybeKeyType, + string_func: Union[StrParser, PathSplitter], + bytes_func: BytesTransformer, + num_func: MaybeNumTransformer, +) -> NatsortOutType: """ Key to sort strings and numbers naturally. @@ -170,7 +272,7 @@ def natsort_key(val, key, string_func, bytes_func, num_func): Parameters ---------- - val : str | unicode | bytes | int | float | iterable + val : str | bytes | int | float | iterable key : callable | None A key to apply to the *val* before any other operations are performed. string_func : callable @@ -210,26 +312,27 @@ def natsort_key(val, key, string_func, bytes_func, num_func): # Assume the input are strings, which is the most common case try: - return string_func(val) + return string_func(cast(str, val)) except (TypeError, AttributeError): # If bytes type, use the bytes_func if type(val) in (bytes,): - return bytes_func(val) + return bytes_func(cast(bytes, val)) # Otherwise, assume it is an iterable that must be parsed recursively. # Do not apply the key recursively. try: return tuple( - natsort_key(x, None, string_func, bytes_func, num_func) for x in val + natsort_key(x, None, string_func, bytes_func, num_func) + for x in cast(Iterable[Any], val) ) # If that failed, it must be a number. except TypeError: - return num_func(val) + return num_func(cast(NumType, val)) -def parse_bytes_factory(alg): +def parse_bytes_factory(alg: NSType) -> BytesTransformer: """ Create a function that will format a *bytes* object into a tuple. @@ -262,7 +365,9 @@ def parse_bytes_factory(alg): return lambda x: (x,) -def parse_number_or_none_factory(alg, sep, pre_sep): +def parse_number_or_none_factory( + alg: NSType, sep: StrOrBytes, pre_sep: str +) -> MaybeNumTransformer: """ Create a function that will format a number (or None) into a tuple. @@ -293,7 +398,9 @@ def parse_number_or_none_factory(alg, sep, pre_sep): """ nan_replace = float("+inf") if alg & ns.NANLAST else float("-inf") - def func(val, _nan_replace=nan_replace, _sep=sep): + def func( + val: MaybeNumType, _nan_replace: float = nan_replace, _sep: StrOrBytes = sep + ) -> NumTuple: """Given a number, place it in a tuple with a leading null string.""" return _sep, (_nan_replace if val != val or val is None else val) @@ -309,8 +416,13 @@ def parse_number_or_none_factory(alg, sep, pre_sep): def parse_string_factory( - alg, sep, splitter, input_transform, component_transform, final_transform -): + alg: NSType, + sep: StrOrBytes, + splitter: StrSplitter, + input_transform: StrToStr, + component_transform: StrTransformer, + final_transform: FinalTransformer, +) -> StrParser: """ Create a function that will split and format a *str* into a tuple. @@ -361,22 +473,22 @@ def parse_string_factory( original_func = input_transform if orig_after_xfrm else _no_op normalize_input = _normalize_input_factory(alg) - def func(x): + def func(x: str) -> FinalTransform: # Apply string input transformation function and return to x. # Original function is usually a no-op, but some algorithms require it # to also be the transformation function. - x = normalize_input(x) - x, original = input_transform(x), original_func(x) - x = splitter(x) # Split string into components. - x = filter(None, x) # Remove empty strings. - x = map(component_transform, x) # Apply transform on components. - x = sep_inserter(x, sep) # Insert '' between numbers. - return final_transform(x, original) # Apply the final transform. + a = normalize_input(x) + b, original = input_transform(a), original_func(a) + c = splitter(b) # Split string into components. + d = filter(None, c) # Remove empty strings. + e = map(component_transform, d) # Apply transform on components. + f = sep_inserter(e, sep) # Insert '' between numbers. + return final_transform(f, original) # Apply the final transform. return func -def parse_path_factory(str_split): +def parse_path_factory(str_split: StrParser) -> PathSplitter: """ Create a function that will properly split and format a path. @@ -403,41 +515,41 @@ def parse_path_factory(str_split): return lambda x: tuple(map(str_split, path_splitter(x))) -def sep_inserter(iterable, sep): +def sep_inserter(iterator: Iterator[Any], sep: StrOrBytes) -> Iterator[Any]: """ - Insert '' between numbers in an iterable. + Insert '' between numbers in an iterator. Parameters ---------- - iterable + iterator sep : str The string character to be inserted between adjacent numeric objects. Yields ------ - The values of *iterable* in order, with *sep* inserted where adjacent + The values of *iterator* in order, with *sep* inserted where adjacent elements are numeric. If the first element in the input is numeric then *sep* will be the first value yielded. """ try: - # Get the first element. A StopIteration indicates an empty iterable. + # Get the first element. A StopIteration indicates an empty iterator. # Since we are controlling the types of the input, 'type' is used # instead of 'isinstance' for the small speed advantage it offers. types = (int, float) - first = next(iterable) + first = next(iterator) if type(first) in types: yield sep yield first # Now, check if pair of elements are both numbers. If so, add ''. - second = next(iterable) + second = next(iterator) if type(first) in types and type(second) in types: yield sep yield second # Now repeat in a loop. - for x in iterable: + for x in iterator: first, second = second, x if type(first) in types and type(second) in types: yield sep @@ -448,7 +560,7 @@ def sep_inserter(iterable, sep): return -def input_string_transform_factory(alg): +def input_string_transform_factory(alg: NSType) -> StrToStr: """ Create a function to transform a string. @@ -473,7 +585,7 @@ def input_string_transform_factory(alg): dumb = alg & NS_DUMB # Build the chain of functions to execute in order. - function_chain = [] + function_chain: List[StrToStr] = [] if (dumb and not lowfirst) or (lowfirst and not dumb): function_chain.append(methodcaller("swapcase")) @@ -502,8 +614,8 @@ def input_string_transform_factory(alg): strip_thousands = strip_thousands.format( thou=re.escape(get_thousands_sep()), nodecimal=nodecimal ) - strip_thousands = re.compile(strip_thousands, flags=re.VERBOSE) - function_chain.append(partial(strip_thousands.sub, "")) + strip_thousands_re = re.compile(strip_thousands, flags=re.VERBOSE) + function_chain.append(partial(strip_thousands_re.sub, "")) # Create a regular expression that will change the decimal point to # a period if not already a period. @@ -511,14 +623,14 @@ def input_string_transform_factory(alg): if alg & ns.FLOAT and decimal != ".": switch_decimal = r"(?<=[0-9]){decimal}|{decimal}(?=[0-9])" switch_decimal = switch_decimal.format(decimal=re.escape(decimal)) - switch_decimal = re.compile(switch_decimal) - function_chain.append(partial(switch_decimal.sub, ".")) + switch_decimal_re = re.compile(switch_decimal) + function_chain.append(partial(switch_decimal_re.sub, ".")) # Return the chained functions. return chain_functions(function_chain) -def string_component_transform_factory(alg): +def string_component_transform_factory(alg: NSType) -> StrTransformer: """ Create a function to either transform a string or convert to a number. @@ -545,23 +657,26 @@ def string_component_transform_factory(alg): nan_val = float("+inf") if alg & ns.NANLAST else float("-inf") # Build the chain of functions to execute in order. - func_chain = [] + func_chain: List[Callable[[str], StrOrBytes]] = [] if group_letters: func_chain.append(groupletters) if use_locale: func_chain.append(get_strxfrm()) - kwargs = {"key": chain_functions(func_chain)} if func_chain else {} # Return the correct chained functions. + kwargs: Dict[str, Union[float, Callable[[str], StrOrBytes]]] + kwargs = {"key": chain_functions(func_chain)} if func_chain else {} if alg & ns.FLOAT: # noinspection PyTypeChecker kwargs["nan"] = nan_val - return partial(fast_float, **kwargs) + return cast(Callable[[str], StrOrBytes], partial(fast_float, **kwargs)) else: - return partial(fast_int, **kwargs) + return cast(Callable[[str], StrOrBytes], partial(fast_int, **kwargs)) -def final_data_transform_factory(alg, sep, pre_sep): +def final_data_transform_factory( + alg: NSType, sep: StrOrBytes, pre_sep: str +) -> FinalTransformer: """ Create a function to transform a tuple. @@ -589,9 +704,15 @@ def final_data_transform_factory(alg, sep, pre_sep): """ if alg & ns.UNGROUPLETTERS and alg & ns.LOCALEALPHA: swap = alg & NS_DUMB and alg & ns.LOWERCASEFIRST - transform = methodcaller("swapcase") if swap else _no_op - - def func(split_val, val, _transform=transform, _sep=sep, _pre_sep=pre_sep): + transform = cast(StrToStr, methodcaller("swapcase")) if swap else _no_op + + def func( + split_val: Iterable[NatsortInType], + val: str, + _transform: StrToStr = transform, + _sep: StrOrBytes = sep, + _pre_sep: str = pre_sep, + ) -> FinalTransform: """ Return a tuple with the first character of the first element of the return value as the first element, and the return value @@ -606,16 +727,25 @@ def final_data_transform_factory(alg, sep, pre_sep): else: return (_transform(val[0]),), split_val - return func else: - return lambda split_val, val: tuple(split_val) + def func( + split_val: Iterable[NatsortInType], + val: str, + _transform: StrToStr = _no_op, + _sep: StrOrBytes = sep, + _pre_sep: str = pre_sep, + ) -> FinalTransform: + return tuple(split_val) + + return func -lower_function = methodcaller("casefold") + +lower_function: StrToStr = cast(StrToStr, methodcaller("casefold")) # noinspection PyIncorrectDocstring -def groupletters(x, _low=lower_function): +def groupletters(x: str, _low: StrToStr = lower_function) -> str: """ Double all characters, making doubled letters lowercase. @@ -637,7 +767,7 @@ def groupletters(x, _low=lower_function): return "".join(ichain.from_iterable((_low(y), y) for y in x)) -def chain_functions(functions): +def chain_functions(functions: Iterable[AnyCall]) -> AnyCall: """ Chain a list of single-argument functions together and return. @@ -674,7 +804,17 @@ def chain_functions(functions): return partial(reduce, lambda res, f: f(res), functions) -def do_decoding(s, encoding): +@overload +def do_decoding(s: bytes, encoding: str) -> str: + ... + + +@overload +def do_decoding(s: NatsortInType, encoding: str) -> NatsortInType: + ... + + +def do_decoding(s: NatsortInType, encoding: str) -> NatsortInType: """ Helper to decode a *bytes* object, or return the object as-is. @@ -692,13 +832,15 @@ def do_decoding(s, encoding): """ try: - return s.decode(encoding) + return cast(bytes, s).decode(encoding) except (AttributeError, TypeError): return s # noinspection PyIncorrectDocstring -def path_splitter(s, _d_match=re.compile(r"\.\d").match): +def path_splitter( + s: PathArg, _d_match: MatchFn = re.compile(r"\.\d").match +) -> Iterator[str]: """ Split a string into its path components. @@ -63,3 +63,9 @@ exclude = dist, docs, .venv + +[mypy] + +[mypy-icu] +ignore_missing_imports = True + @@ -7,6 +7,8 @@ setup( version="7.1.1", packages=find_packages(), entry_points={"console_scripts": ["natsort = natsort.__main__:main"]}, - python_requires=">=3.4", + python_requires=">=3.6", extras_require={"fast": ["fastnumbers >= 2.0.0"], "icu": ["PyICU >= 1.0.0"]}, + package_data={"": ["py.typed"]}, + zip_safe=False, ) diff --git a/tests/conftest.py b/tests/conftest.py index 74d7f4f..c63e149 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,6 +3,7 @@ Fixtures for pytest. """ import locale +from typing import Iterator import hypothesis import pytest @@ -16,7 +17,7 @@ hypothesis.settings.register_profile( ) -def load_locale(x): +def load_locale(x: str) -> None: """Convenience to load a locale, trying ISO8859-1 first.""" try: locale.setlocale(locale.LC_ALL, str("{}.ISO8859-1".format(x))) @@ -25,15 +26,16 @@ def load_locale(x): @pytest.fixture() -def with_locale_en_us(): +def with_locale_en_us() -> Iterator[None]: """Convenience to load the en_US locale - reset when complete.""" orig = locale.getlocale() - yield load_locale("en_US") + load_locale("en_US") + yield locale.setlocale(locale.LC_ALL, orig) @pytest.fixture() -def with_locale_de_de(): +def with_locale_de_de() -> Iterator[None]: """ Convenience to load the de_DE locale - reset when complete - skip if missing. """ diff --git a/tests/profile_natsorted.py b/tests/profile_natsorted.py index a2b1a5c..f6580a3 100644 --- a/tests/profile_natsorted.py +++ b/tests/profile_natsorted.py @@ -7,6 +7,7 @@ inputs and different settings. import cProfile import locale import sys +from typing import List, Union try: from natsort import ns, natsort_keygen @@ -14,6 +15,8 @@ except ImportError: sys.path.insert(0, ".") from natsort import ns, natsort_keygen +from natsort.natsort import NatsortKeyType + locale.setlocale(locale.LC_ALL, "en_US.UTF-8") # Samples to parse @@ -32,7 +35,7 @@ path_key = natsort_keygen(alg=ns.PATH) locale_key = natsort_keygen(alg=ns.LOCALE) -def prof_time_to_generate(): +def prof_time_to_generate() -> None: print("*** Generate Plain Key ***") for _ in range(100000): natsort_keygen() @@ -41,7 +44,9 @@ def prof_time_to_generate(): cProfile.run("prof_time_to_generate()", sort="time") -def prof_parsing(a, msg, key=basic_key): +def prof_parsing( + a: Union[str, int, bytes, List[str]], msg: str, key: NatsortKeyType = basic_key +) -> None: print(msg) for _ in range(100000): key(a) diff --git a/tests/test_fake_fastnumbers.py b/tests/test_fake_fastnumbers.py index c75bb11..574f7cf 100644 --- a/tests/test_fake_fastnumbers.py +++ b/tests/test_fake_fastnumbers.py @@ -5,13 +5,14 @@ Test the fake fastnumbers module. import unicodedata from math import isnan +from typing import Union, cast from hypothesis import given from hypothesis.strategies import floats, integers, text from natsort.compat.fake_fastnumbers import fast_float, fast_int -def is_float(x): +def is_float(x: str) -> bool: try: float(x) except ValueError: @@ -25,19 +26,19 @@ def is_float(x): return True -def not_a_float(x): +def not_a_float(x: str) -> bool: return not is_float(x) -def is_int(x): +def is_int(x: Union[str, float]) -> bool: try: - return x.is_integer() + return cast(float, x).is_integer() except AttributeError: try: int(x) except ValueError: try: - unicodedata.digit(x) + unicodedata.digit(cast(str, x)) except (ValueError, TypeError): return False else: @@ -46,7 +47,7 @@ def is_int(x): return True -def not_an_int(x): +def not_an_int(x: Union[str, float]) -> bool: return not is_int(x) @@ -54,56 +55,56 @@ def not_an_int(x): # and a test that uses the hypothesis module. -def test_fast_float_returns_nan_alternate_if_nan_option_is_given(): +def test_fast_float_returns_nan_alternate_if_nan_option_is_given() -> None: assert fast_float("nan", nan=7) == 7 -def test_fast_float_converts_float_string_to_float_example(): +def test_fast_float_converts_float_string_to_float_example() -> None: assert fast_float("45.8") == 45.8 assert fast_float("-45") == -45.0 assert fast_float("45.8e-2", key=len) == 45.8e-2 - assert isnan(fast_float("nan")) - assert isnan(fast_float("+nan")) - assert isnan(fast_float("-NaN")) + assert isnan(cast(float, fast_float("nan"))) + assert isnan(cast(float, fast_float("+nan"))) + assert isnan(cast(float, fast_float("-NaN"))) assert fast_float("۱۲.۱۲") == 12.12 assert fast_float("-۱۲.۱۲") == -12.12 @given(floats(allow_nan=False)) -def test_fast_float_converts_float_string_to_float(x): +def test_fast_float_converts_float_string_to_float(x: float) -> None: assert fast_float(repr(x)) == x -def test_fast_float_leaves_string_as_is_example(): +def test_fast_float_leaves_string_as_is_example() -> None: assert fast_float("invalid") == "invalid" @given(text().filter(not_a_float).filter(bool)) -def test_fast_float_leaves_string_as_is(x): +def test_fast_float_leaves_string_as_is(x: str) -> None: assert fast_float(x) == x -def test_fast_float_with_key_applies_to_string_example(): +def test_fast_float_with_key_applies_to_string_example() -> None: assert fast_float("invalid", key=len) == len("invalid") @given(text().filter(not_a_float).filter(bool)) -def test_fast_float_with_key_applies_to_string(x): +def test_fast_float_with_key_applies_to_string(x: str) -> None: assert fast_float(x, key=len) == len(x) -def test_fast_int_leaves_float_string_as_is_example(): +def test_fast_int_leaves_float_string_as_is_example() -> None: assert fast_int("45.8") == "45.8" assert fast_int("nan") == "nan" assert fast_int("inf") == "inf" @given(floats().filter(not_an_int)) -def test_fast_int_leaves_float_string_as_is(x): +def test_fast_int_leaves_float_string_as_is(x: float) -> None: assert fast_int(repr(x)) == repr(x) -def test_fast_int_converts_int_string_to_int_example(): +def test_fast_int_converts_int_string_to_int_example() -> None: assert fast_int("-45") == -45 assert fast_int("+45") == 45 assert fast_int("۱۲") == 12 @@ -111,23 +112,23 @@ def test_fast_int_converts_int_string_to_int_example(): @given(integers()) -def test_fast_int_converts_int_string_to_int(x): +def test_fast_int_converts_int_string_to_int(x: int) -> None: assert fast_int(repr(x)) == x -def test_fast_int_leaves_string_as_is_example(): +def test_fast_int_leaves_string_as_is_example() -> None: assert fast_int("invalid") == "invalid" @given(text().filter(not_an_int).filter(bool)) -def test_fast_int_leaves_string_as_is(x): +def test_fast_int_leaves_string_as_is(x: str) -> None: assert fast_int(x) == x -def test_fast_int_with_key_applies_to_string_example(): +def test_fast_int_with_key_applies_to_string_example() -> None: assert fast_int("invalid", key=len) == len("invalid") @given(text().filter(not_an_int).filter(bool)) -def test_fast_int_with_key_applies_to_string(x): +def test_fast_int_with_key_applies_to_string(x: str) -> None: assert fast_int(x, key=len) == len(x) diff --git a/tests/test_final_data_transform_factory.py b/tests/test_final_data_transform_factory.py index f6bf636..36607b6 100644 --- a/tests/test_final_data_transform_factory.py +++ b/tests/test_final_data_transform_factory.py @@ -1,17 +1,20 @@ # -*- coding: utf-8 -*- """These test the utils.py functions.""" +from typing import Callable, Union import pytest from hypothesis import example, given from hypothesis.strategies import floats, integers, text -from natsort.ns_enum import NS_DUMB, ns +from natsort.ns_enum import NSType, NS_DUMB, ns from natsort.utils import final_data_transform_factory @pytest.mark.parametrize("alg", [ns.DEFAULT, ns.UNGROUPLETTERS, ns.LOCALE]) @given(x=text(), y=floats(allow_nan=False, allow_infinity=False) | integers()) @pytest.mark.usefixtures("with_locale_en_us") -def test_final_data_transform_factory_default(x, y, alg): +def test_final_data_transform_factory_default( + x: str, y: Union[int, float], alg: NSType +) -> None: final_data_transform_func = final_data_transform_factory(alg, "", "::") value = (x, y) original_value = "".join(map(str, value)) @@ -34,7 +37,9 @@ def test_final_data_transform_factory_default(x, y, alg): @given(x=text(), y=floats(allow_nan=False, allow_infinity=False) | integers()) @example(x="İ", y=0) @pytest.mark.usefixtures("with_locale_en_us") -def test_final_data_transform_factory_ungroup_and_locale(x, y, alg, func): +def test_final_data_transform_factory_ungroup_and_locale( + x: str, y: Union[int, float], alg: NSType, func: Callable[[str], str] +) -> None: final_data_transform_func = final_data_transform_factory(alg, "", "::") value = (x, y) original_value = "".join(map(str, value)) @@ -46,6 +51,6 @@ def test_final_data_transform_factory_ungroup_and_locale(x, y, alg, func): assert result == expected -def test_final_data_transform_factory_ungroup_and_locale_empty_tuple(): +def test_final_data_transform_factory_ungroup_and_locale_empty_tuple() -> None: final_data_transform_func = final_data_transform_factory(ns.UG | ns.L, "", "::") assert final_data_transform_func((), "") == ((), ()) diff --git a/tests/test_input_string_transform_factory.py b/tests/test_input_string_transform_factory.py index 7d54afd..6a08318 100644 --- a/tests/test_input_string_transform_factory.py +++ b/tests/test_input_string_transform_factory.py @@ -1,14 +1,15 @@ # -*- coding: utf-8 -*- """These test the utils.py functions.""" +from typing import Callable import pytest from hypothesis import example, given from hypothesis.strategies import integers, text -from natsort.ns_enum import NS_DUMB, ns +from natsort.ns_enum import NSType, NS_DUMB, ns from natsort.utils import input_string_transform_factory -def thousands_separated_int(n): +def thousands_separated_int(n: str) -> str: """Insert thousands separators in an int.""" new_int = "" for i, y in enumerate(reversed(n), 1): @@ -20,7 +21,7 @@ def thousands_separated_int(n): @given(text()) -def test_input_string_transform_factory_is_no_op_for_no_alg_options(x): +def test_input_string_transform_factory_is_no_op_for_no_alg_options(x: str) -> None: input_string_transform_func = input_string_transform_factory(ns.DEFAULT) assert input_string_transform_func(x) is x @@ -36,7 +37,9 @@ def test_input_string_transform_factory_is_no_op_for_no_alg_options(x): ], ) @given(x=text()) -def test_input_string_transform_factory(x, alg, example_func): +def test_input_string_transform_factory( + x: str, alg: NSType, example_func: Callable[[str], str] +) -> None: input_string_transform_func = input_string_transform_factory(alg) assert input_string_transform_func(x) == example_func(x) @@ -44,7 +47,7 @@ def test_input_string_transform_factory(x, alg, example_func): @example(12543642642534980) # 12,543,642,642,534,980 => 12543642642534980 @given(x=integers(min_value=1000)) @pytest.mark.usefixtures("with_locale_en_us") -def test_input_string_transform_factory_cleans_thousands(x): +def test_input_string_transform_factory_cleans_thousands(x: int) -> None: int_str = str(x).rstrip("lL") thousands_int_str = thousands_separated_int(int_str) assert thousands_int_str.replace(",", "") != thousands_int_str @@ -69,7 +72,9 @@ def test_input_string_transform_factory_cleans_thousands(x): ], ) @pytest.mark.usefixtures("with_locale_en_us") -def test_input_string_transform_factory_handles_us_locale(x, expected): +def test_input_string_transform_factory_handles_us_locale( + x: str, expected: str +) -> None: input_string_transform_func = input_string_transform_factory(ns.LOCALE) assert input_string_transform_func(x) == expected @@ -83,7 +88,9 @@ def test_input_string_transform_factory_handles_us_locale(x, expected): ], ) @pytest.mark.usefixtures("with_locale_de_de") -def test_input_string_transform_factory_handles_de_locale(x, expected): +def test_input_string_transform_factory_handles_de_locale( + x: str, expected: str +) -> None: input_string_transform_func = input_string_transform_factory(ns.LOCALE) assert input_string_transform_func(x) == expected @@ -97,13 +104,15 @@ def test_input_string_transform_factory_handles_de_locale(x, expected): ], ) @pytest.mark.usefixtures("with_locale_de_de") -def test_input_string_transform_factory_handles_german_locale(alg, expected): +def test_input_string_transform_factory_handles_german_locale( + alg: NSType, expected: str +) -> None: input_string_transform_func = input_string_transform_factory(alg) assert input_string_transform_func("1543,753") == expected @pytest.mark.usefixtures("with_locale_de_de") -def test_input_string_transform_factory_does_nothing_with_non_num_input(): +def test_input_string_transform_factory_does_nothing_with_non_num_input() -> None: input_string_transform_func = input_string_transform_factory(ns.LOCALE | ns.FLOAT) expected = "154s,t53" assert input_string_transform_func("154s,t53") == expected diff --git a/tests/test_main.py b/tests/test_main.py index da91fdd..2f784a1 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -5,11 +5,13 @@ Test the natsort command-line tool functions. import re import sys +from typing import Any, List, Union import pytest from hypothesis import given -from hypothesis.strategies import data, floats, integers, lists +from hypothesis.strategies import DataObject, data, floats, integers, lists from natsort.__main__ import ( + TypedArgs, check_filters, keep_entry_range, keep_entry_value, @@ -17,16 +19,19 @@ from natsort.__main__ import ( range_check, sort_and_print_entries, ) +from pytest_mock import MockerFixture -def test_main_passes_default_arguments_with_no_command_line_options(mocker): +def test_main_passes_default_arguments_with_no_command_line_options( + mocker: MockerFixture, +) -> None: p = mocker.patch("natsort.__main__.sort_and_print_entries") main("num-2", "num-6", "num-1") args = p.call_args[0][1] assert not args.paths assert args.filter is None assert args.reverse_filter is None - assert args.exclude is None + assert args.exclude == [] assert not args.reverse assert args.number_type == "int" assert not args.signed @@ -34,7 +39,9 @@ def test_main_passes_default_arguments_with_no_command_line_options(mocker): assert not args.locale -def test_main_passes_arguments_with_all_command_line_options(mocker): +def test_main_passes_arguments_with_all_command_line_options( + mocker: MockerFixture, +) -> None: arguments = ["--paths", "--reverse", "--locale"] arguments.extend(["--filter", "4", "10"]) arguments.extend(["--reverse-filter", "100", "110"]) @@ -57,21 +64,6 @@ def test_main_passes_arguments_with_all_command_line_options(mocker): assert args.locale -class Args: - """A dummy class to simulate the argparse Namespace object""" - - def __init__(self, filt, reverse_filter, exclude, as_path, reverse): - self.filter = filt - self.reverse_filter = reverse_filter - self.exclude = exclude - self.reverse = reverse - self.number_type = "float" - self.signed = True - self.exp = True - self.paths = as_path - self.locale = 0 - - mock_print = "__builtin__.print" if sys.version[0] == "2" else "builtins.print" entries = [ @@ -135,9 +127,11 @@ entries = [ ([None, None, False, True, True], reversed([2, 3, 1, 0, 5, 6, 4])), ], ) -def test_sort_and_print_entries(options, order, mocker): +def test_sort_and_print_entries( + options: List[Any], order: List[int], mocker: MockerFixture +) -> None: p = mocker.patch(mock_print) - sort_and_print_entries(entries, Args(*options)) + sort_and_print_entries(entries, TypedArgs(*options)) e = [mocker.call(entries[i]) for i in order] p.assert_has_calls(e) @@ -146,13 +140,15 @@ def test_sort_and_print_entries(options, order, mocker): # and a test that uses the hypothesis module. -def test_range_check_returns_range_as_is_but_with_floats_example(): +def test_range_check_returns_range_as_is_but_with_floats_example() -> None: assert range_check(10, 11) == (10.0, 11.0) assert range_check(6.4, 30) == (6.4, 30.0) @given(x=floats(allow_nan=False, min_value=-1e8, max_value=1e8) | integers(), d=data()) -def test_range_check_returns_range_as_is_if_first_is_less_than_second(x, d): +def test_range_check_returns_range_as_is_if_first_is_less_than_second( + x: Union[int, float], d: DataObject +) -> None: # Pull data such that the first is less than the second. if isinstance(x, float): y = d.draw(floats(min_value=x + 1.0, max_value=1e9, allow_nan=False)) @@ -161,44 +157,48 @@ def test_range_check_returns_range_as_is_if_first_is_less_than_second(x, d): assert range_check(x, y) == (x, y) -def test_range_check_raises_value_error_if_second_is_less_than_first_example(): +def test_range_check_raises_value_error_if_second_is_less_than_first_example() -> None: with pytest.raises(ValueError, match="low >= high"): range_check(7, 2) @given(x=floats(allow_nan=False), d=data()) -def test_range_check_raises_value_error_if_second_is_less_than_first(x, d): +def test_range_check_raises_value_error_if_second_is_less_than_first( + x: float, d: DataObject +) -> None: # Pull data such that the first is greater than or equal to the second. y = d.draw(floats(max_value=x, allow_nan=False)) with pytest.raises(ValueError, match="low >= high"): range_check(x, y) -def test_check_filters_returns_none_if_filter_evaluates_to_false(): +def test_check_filters_returns_none_if_filter_evaluates_to_false() -> None: assert check_filters(()) is None - assert check_filters(False) is None - assert check_filters(None) is None -def test_check_filters_returns_input_as_is_if_filter_is_valid_example(): +def test_check_filters_returns_input_as_is_if_filter_is_valid_example() -> None: assert check_filters([(6, 7)]) == [(6, 7)] assert check_filters([(6, 7), (2, 8)]) == [(6, 7), (2, 8)] @given(x=lists(integers(), min_size=1), d=data()) -def test_check_filters_returns_input_as_is_if_filter_is_valid(x, d): +def test_check_filters_returns_input_as_is_if_filter_is_valid( + x: List[int], d: DataObject +) -> None: # ensure y is element-wise greater than x y = [d.draw(integers(min_value=val + 1)) for val in x] assert check_filters(list(zip(x, y))) == [(i, j) for i, j in zip(x, y)] -def test_check_filters_raises_value_error_if_filter_is_invalid_example(): +def test_check_filters_raises_value_error_if_filter_is_invalid_example() -> None: with pytest.raises(ValueError, match="Error in --filter: low >= high"): check_filters([(7, 2)]) @given(x=lists(integers(), min_size=1), d=data()) -def test_check_filters_raises_value_error_if_filter_is_invalid(x, d): +def test_check_filters_raises_value_error_if_filter_is_invalid( + x: List[int], d: DataObject +) -> None: # ensure y is element-wise less than or equal to x y = [d.draw(integers(max_value=val)) for val in x] with pytest.raises(ValueError, match="Error in --filter: low >= high"): @@ -212,11 +212,11 @@ def test_check_filters_raises_value_error_if_filter_is_invalid(x, d): # 3. No portion is between the bounds => False. [([0], [100], True), ([1, 88], [20, 90], True), ([1], [20], False)], ) -def test_keep_entry_range(lows, highs, truth): +def test_keep_entry_range(lows: List[int], highs: List[int], truth: bool) -> None: assert keep_entry_range("a56b23c89", lows, highs, int, re.compile(r"\d+")) is truth # 1. Values not in entry => True. 2. Values in entry => False. @pytest.mark.parametrize("values, truth", [([100, 45], True), ([23], False)]) -def test_keep_entry_value(values, truth): +def test_keep_entry_value(values: List[int], truth: bool) -> None: assert keep_entry_value("a56b23c89", values, int, re.compile(r"\d+")) is truth diff --git a/tests/test_natsort_key.py b/tests/test_natsort_key.py index 6b56bb1..cdfdc67 100644 --- a/tests/test_natsort_key.py +++ b/tests/test_natsort_key.py @@ -1,42 +1,45 @@ # -*- coding: utf-8 -*- """These test the utils.py functions.""" +from typing import Any, List, NoReturn, Tuple, Union, cast from hypothesis import given from hypothesis.strategies import binary, floats, integers, lists, text from natsort.utils import natsort_key -def str_func(x): +def str_func(x: Any) -> Tuple[str]: if isinstance(x, str): - return x + return (x,) else: raise TypeError("Not a str!") -def fail(_): +def fail(_: Any) -> NoReturn: raise AssertionError("This should never be reached!") @given(floats(allow_nan=False) | integers()) -def test_natsort_key_with_numeric_input_takes_number_path(x): - assert natsort_key(x, None, str_func, fail, lambda y: y) is x +def test_natsort_key_with_numeric_input_takes_number_path(x: Union[float, int]) -> None: + assert natsort_key(x, None, str_func, fail, lambda y: ("", y))[1] is x @given(binary().filter(bool)) -def test_natsort_key_with_bytes_input_takes_bytes_path(x): - assert natsort_key(x, None, str_func, lambda y: y, fail) is x +def test_natsort_key_with_bytes_input_takes_bytes_path(x: bytes) -> None: + assert natsort_key(x, None, str_func, lambda y: (y,), fail)[0] is x @given(text()) -def test_natsort_key_with_text_input_takes_string_path(x): - assert natsort_key(x, None, str_func, fail, fail) is x +def test_natsort_key_with_text_input_takes_string_path(x: str) -> None: + assert natsort_key(x, None, str_func, fail, fail)[0] is x @given(lists(elements=text(), min_size=1, max_size=10)) -def test_natsort_key_with_nested_input_takes_nested_path(x): - assert natsort_key(x, None, str_func, fail, fail) == tuple(x) +def test_natsort_key_with_nested_input_takes_nested_path(x: List[str]) -> None: + assert natsort_key(x, None, str_func, fail, fail) == tuple((y,) for y in x) @given(text()) -def test_natsort_key_with_key_argument_applies_key_before_processing(x): - assert natsort_key(x, len, str_func, fail, lambda y: y) == len(x) +def test_natsort_key_with_key_argument_applies_key_before_processing(x: str) -> None: + assert natsort_key(x, len, str_func, fail, lambda y: ("", cast(int, y)))[1] == len( + x + ) diff --git a/tests/test_natsort_keygen.py b/tests/test_natsort_keygen.py index 5fb8cf3..d4da2e3 100644 --- a/tests/test_natsort_keygen.py +++ b/tests/test_natsort_keygen.py @@ -5,23 +5,27 @@ See the README or the natsort homepage for more details. """ import os +from typing import List, Tuple, Union import pytest from natsort import natsort_key, natsort_keygen, natsorted, ns from natsort.compat.locale import get_strxfrm, null_string_locale +from natsort.ns_enum import NSType +from natsort.utils import BytesTransform, FinalTransform +from pytest_mock import MockerFixture @pytest.fixture -def arbitrary_input(): +def arbitrary_input() -> List[Union[str, float]]: return ["6A-5.034e+1", "/Folder (1)/Foo", 56.7] @pytest.fixture -def bytes_input(): +def bytes_input() -> bytes: return b"6A-5.034e+1" -def test_natsort_keygen_demonstration(): +def test_natsort_keygen_demonstration() -> None: original_list = ["a50", "a51.", "a50.31", "a50.4", "a5.034e1", "a50.300"] copy_of_list = original_list[:] original_list.sort(key=natsort_keygen(alg=ns.F)) @@ -29,21 +33,23 @@ def test_natsort_keygen_demonstration(): assert original_list == natsorted(copy_of_list, alg=ns.F) -def test_natsort_key_public(): +def test_natsort_key_public() -> None: assert natsort_key("a-5.034e2") == ("a-", 5, ".", 34, "e", 2) -def test_natsort_keygen_with_invalid_alg_input_raises_value_error(): +def test_natsort_keygen_with_invalid_alg_input_raises_value_error() -> None: # Invalid arguments give the correct response with pytest.raises(ValueError, match="'alg' argument"): - natsort_keygen(None, "1") + natsort_keygen(None, "1") # type: ignore @pytest.mark.parametrize( "alg, expected", [(ns.DEFAULT, ("a-", 5, ".", 34, "e", 1)), (ns.FLOAT | ns.SIGNED, ("a", -50.34))], ) -def test_natsort_keygen_returns_natsort_key_that_parses_input(alg, expected): +def test_natsort_keygen_returns_natsort_key_that_parses_input( + alg: NSType, expected: Tuple[Union[str, int, float], ...] +) -> None: ns_key = natsort_keygen(alg=alg) assert ns_key("a-5.034e1") == expected @@ -78,7 +84,9 @@ def test_natsort_keygen_returns_natsort_key_that_parses_input(alg, expected): ), ], ) -def test_natsort_keygen_handles_arbitrary_input(arbitrary_input, alg, expected): +def test_natsort_keygen_handles_arbitrary_input( + arbitrary_input: List[Union[str, float]], alg: NSType, expected: FinalTransform +) -> None: ns_key = natsort_keygen(alg=alg) assert ns_key(arbitrary_input) == expected @@ -93,7 +101,9 @@ def test_natsort_keygen_handles_arbitrary_input(arbitrary_input, alg, expected): (ns.PATH | ns.GROUPLETTERS, ((b"6A-5.034e+1",),)), ], ) -def test_natsort_keygen_handles_bytes_input(bytes_input, alg, expected): +def test_natsort_keygen_handles_bytes_input( + bytes_input: bytes, alg: NSType, expected: BytesTransform +) -> None: ns_key = natsort_keygen(alg=alg) assert ns_key(bytes_input) == expected @@ -131,23 +141,29 @@ def test_natsort_keygen_handles_bytes_input(bytes_input, alg, expected): ], ) @pytest.mark.usefixtures("with_locale_en_us") -def test_natsort_keygen_with_locale(mocker, arbitrary_input, alg, expected, is_dumb): +def test_natsort_keygen_with_locale( + mocker: MockerFixture, + arbitrary_input: List[Union[str, float]], + alg: NSType, + expected: FinalTransform, + is_dumb: bool, +) -> None: # First, apply the correct strxfrm function to the string values. strxfrm = get_strxfrm() - expected = [list(sub) for sub in expected] + expected_tmp = [list(sub) for sub in expected] try: for i in (2, 4, 6): - expected[0][i] = strxfrm(expected[0][i]) + expected_tmp[0][i] = strxfrm(expected_tmp[0][i]) for i in (0, 2): - expected[1][i] = strxfrm(expected[1][i]) - expected = tuple(tuple(sub) for sub in expected) + expected_tmp[1][i] = strxfrm(expected_tmp[1][i]) + expected = tuple(tuple(sub) for sub in expected_tmp) except IndexError: # ns.LOCALE | ns.CAPITALFIRST - expected = [[list(subsub) for subsub in sub] for sub in expected] + expected_tmp = [[list(subsub) for subsub in sub] for sub in expected_tmp] for i in (2, 4, 6): - expected[0][1][i] = strxfrm(expected[0][1][i]) + expected_tmp[0][1][i] = strxfrm(expected_tmp[0][1][i]) for i in (0, 2): - expected[1][1][i] = strxfrm(expected[1][1][i]) - expected = tuple(tuple(tuple(subsub) for subsub in sub) for sub in expected) + expected_tmp[1][1][i] = strxfrm(expected_tmp[1][1][i]) + expected = tuple(tuple(tuple(subsub) for subsub in sub) for sub in expected_tmp) mocker.patch("natsort.compat.locale.dumb_sort", return_value=is_dumb) ns_key = natsort_keygen(alg=alg) @@ -159,7 +175,9 @@ def test_natsort_keygen_with_locale(mocker, arbitrary_input, alg, expected, is_d [(ns.LOCALE, False), (ns.LOCALE, True), (ns.LOCALE | ns.CAPITALFIRST, False)], ) @pytest.mark.usefixtures("with_locale_en_us") -def test_natsort_keygen_with_locale_bytes(mocker, bytes_input, alg, is_dumb): +def test_natsort_keygen_with_locale_bytes( + mocker: MockerFixture, bytes_input: bytes, alg: NSType, is_dumb: bool +) -> None: expected = (b"6A-5.034e+1",) mocker.patch("natsort.compat.locale.dumb_sort", return_value=is_dumb) ns_key = natsort_keygen(alg=alg) diff --git a/tests/test_natsorted.py b/tests/test_natsorted.py index 4254e6c..d043ab4 100644 --- a/tests/test_natsorted.py +++ b/tests/test_natsorted.py @@ -5,34 +5,38 @@ See the README or the natsort homepage for more details. """ from operator import itemgetter +from typing import List, Tuple, Union import pytest from natsort import as_utf8, natsorted, ns +from natsort.ns_enum import NSType from pytest import raises @pytest.fixture -def float_list(): +def float_list() -> List[str]: return ["a50", "a51.", "a50.31", "a-50", "a50.4", "a5.034e1", "a50.300"] @pytest.fixture -def fruit_list(): +def fruit_list() -> List[str]: return ["Apple", "corn", "Corn", "Banana", "apple", "banana"] @pytest.fixture -def mixed_list(): +def mixed_list() -> List[Union[str, int, float]]: return ["Ä", "0", "ä", 3, "b", 1.5, "2", "Z"] -def test_natsorted_numbers_in_ascending_order(): +def test_natsorted_numbers_in_ascending_order() -> None: given = ["a2", "a5", "a9", "a1", "a4", "a10", "a6"] expected = ["a1", "a2", "a4", "a5", "a6", "a9", "a10"] assert natsorted(given) == expected -def test_natsorted_can_sort_as_signed_floats_with_exponents(float_list): +def test_natsorted_can_sort_as_signed_floats_with_exponents( + float_list: List[str], +) -> None: expected = ["a-50", "a50", "a50.300", "a50.31", "a5.034e1", "a50.4", "a51."] assert natsorted(float_list, alg=ns.REAL) == expected @@ -42,19 +46,23 @@ def test_natsorted_can_sort_as_signed_floats_with_exponents(float_list): "alg", [ns.NOEXP | ns.FLOAT | ns.UNSIGNED, ns.NOEXP | ns.FLOAT], ) -def test_natsorted_can_sort_as_unsigned_and_ignore_exponents(float_list, alg): +def test_natsorted_can_sort_as_unsigned_and_ignore_exponents( + float_list: List[str], alg: NSType +) -> None: expected = ["a5.034e1", "a50", "a50.300", "a50.31", "a50.4", "a51.", "a-50"] assert natsorted(float_list, alg=alg) == expected # DEFAULT and INT are all equivalent. @pytest.mark.parametrize("alg", [ns.DEFAULT, ns.INT]) -def test_natsorted_can_sort_as_unsigned_ints_which_is_default(float_list, alg): +def test_natsorted_can_sort_as_unsigned_ints_which_is_default( + float_list: List[str], alg: NSType +) -> None: expected = ["a5.034e1", "a50", "a50.4", "a50.31", "a50.300", "a51.", "a-50"] assert natsorted(float_list, alg=alg) == expected -def test_natsorted_can_sort_as_signed_ints(float_list): +def test_natsorted_can_sort_as_signed_ints(float_list: List[str]) -> None: expected = ["a-50", "a5.034e1", "a50", "a50.4", "a50.31", "a50.300", "a51."] assert natsorted(float_list, alg=ns.SIGNED) == expected @@ -63,12 +71,14 @@ def test_natsorted_can_sort_as_signed_ints(float_list): "alg, expected", [(ns.UNSIGNED, ["a7", "a+2", "a-5"]), (ns.SIGNED, ["a-5", "a+2", "a7"])], ) -def test_natsorted_can_sort_with_or_without_accounting_for_sign(alg, expected): +def test_natsorted_can_sort_with_or_without_accounting_for_sign( + alg: NSType, expected: List[str] +) -> None: given = ["a-5", "a7", "a+2"] assert natsorted(given, alg=alg) == expected -def test_natsorted_can_sort_as_version_numbers(): +def test_natsorted_can_sort_as_version_numbers() -> None: given = ["1.9.9a", "1.11", "1.9.9b", "1.11.4", "1.10.1"] expected = ["1.9.9a", "1.9.9b", "1.10.1", "1.11", "1.11.4"] assert natsorted(given) == expected @@ -81,7 +91,11 @@ def test_natsorted_can_sort_as_version_numbers(): (ns.NUMAFTER, ["Ä", "Z", "ä", "b", "0", 1.5, "2", 3]), ], ) -def test_natsorted_handles_mixed_types(mixed_list, alg, expected): +def test_natsorted_handles_mixed_types( + mixed_list: List[Union[str, int, float]], + alg: NSType, + expected: List[Union[str, int, float]], +) -> None: assert natsorted(mixed_list, alg=alg) == expected @@ -92,14 +106,16 @@ def test_natsorted_handles_mixed_types(mixed_list, alg, expected): (ns.NANLAST, [5, "25", 1e40, float("nan")], slice(None, 3)), ], ) -def test_natsorted_handles_nan(alg, expected, slc): - given = ["25", 5, float("nan"), 1e40] +def test_natsorted_handles_nan( + alg: NSType, expected: List[Union[str, float, int]], slc: slice +) -> None: + given: List[Union[str, float, int]] = ["25", 5, float("nan"), 1e40] # The slice is because NaN != NaN # noinspection PyUnresolvedReferences assert natsorted(given, alg=alg)[slc] == expected[slc] -def test_natsorted_with_mixed_bytes_and_str_input_raises_type_error(): +def test_natsorted_with_mixed_bytes_and_str_input_raises_type_error() -> None: with raises(TypeError, match="bytes"): natsorted(["ä", b"b"]) @@ -107,29 +123,31 @@ def test_natsorted_with_mixed_bytes_and_str_input_raises_type_error(): assert natsorted(["ä", b"b"], key=as_utf8) == ["ä", b"b"] -def test_natsorted_raises_type_error_for_non_iterable_input(): +def test_natsorted_raises_type_error_for_non_iterable_input() -> None: with raises(TypeError, match="'int' object is not iterable"): - natsorted(100) + natsorted(100) # type: ignore -def test_natsorted_recurses_into_nested_lists(): +def test_natsorted_recurses_into_nested_lists() -> None: given = [["a1", "a5"], ["a1", "a40"], ["a10", "a1"], ["a2", "a5"]] expected = [["a1", "a5"], ["a1", "a40"], ["a2", "a5"], ["a10", "a1"]] assert natsorted(given) == expected -def test_natsorted_applies_key_to_each_list_element_before_sorting_list(): +def test_natsorted_applies_key_to_each_list_element_before_sorting_list() -> None: given = [("a", "num3"), ("b", "num5"), ("c", "num2")] expected = [("c", "num2"), ("a", "num3"), ("b", "num5")] assert natsorted(given, key=itemgetter(1)) == expected -def test_natsorted_returns_list_in_reversed_order_with_reverse_option(float_list): +def test_natsorted_returns_list_in_reversed_order_with_reverse_option( + float_list: List[str], +) -> None: expected = natsorted(float_list)[::-1] assert natsorted(float_list, reverse=True) == expected -def test_natsorted_handles_filesystem_paths(): +def test_natsorted_handles_filesystem_paths() -> None: given = [ "/p/Folder (10)/file.tar.gz", "/p/Folder (1)/file (1).tar.gz", @@ -157,10 +175,10 @@ def test_natsorted_handles_filesystem_paths(): assert natsorted(given, alg=ns.FLOAT | ns.PATH) == expected_correct -def test_natsorted_handles_numbers_and_filesystem_paths_simultaneously(): +def test_natsorted_handles_numbers_and_filesystem_paths_simultaneously() -> None: # You can sort paths and numbers, not that you'd want to - given = ["/Folder (9)/file.exe", 43] - expected = [43, "/Folder (9)/file.exe"] + given: List[Union[str, int]] = ["/Folder (9)/file.exe", 43] + expected: List[Union[str, int]] = [43, "/Folder (9)/file.exe"] assert natsorted(given, alg=ns.PATH) == expected @@ -174,7 +192,9 @@ def test_natsorted_handles_numbers_and_filesystem_paths_simultaneously(): (ns.G | ns.LF, ["apple", "Apple", "banana", "Banana", "corn", "Corn"]), ], ) -def test_natsorted_supports_case_handling(alg, expected, fruit_list): +def test_natsorted_supports_case_handling( + alg: NSType, expected: List[str], fruit_list: List[str] +) -> None: assert natsorted(fruit_list, alg=alg) == expected @@ -186,7 +206,9 @@ def test_natsorted_supports_case_handling(alg, expected, fruit_list): (ns.IGNORECASE, [("a3", "a1"), ("A5", "a6")]), ], ) -def test_natsorted_supports_nested_case_handling(alg, expected): +def test_natsorted_supports_nested_case_handling( + alg: NSType, expected: List[Tuple[str, str]] +) -> None: given = [("A5", "a6"), ("a3", "a1")] assert natsorted(given, alg=alg) == expected @@ -201,26 +223,28 @@ def test_natsorted_supports_nested_case_handling(alg, expected): ], ) @pytest.mark.usefixtures("with_locale_en_us") -def test_natsorted_can_sort_using_locale(fruit_list, alg, expected): +def test_natsorted_can_sort_using_locale( + fruit_list: List[str], alg: NSType, expected: List[str] +) -> None: assert natsorted(fruit_list, alg=ns.LOCALE | alg) == expected @pytest.mark.usefixtures("with_locale_en_us") -def test_natsorted_can_sort_locale_specific_numbers_en(): +def test_natsorted_can_sort_locale_specific_numbers_en() -> None: given = ["c", "a5,467.86", "ä", "b", "a5367.86", "a5,6", "a5,50"] expected = ["a5,6", "a5,50", "a5367.86", "a5,467.86", "ä", "b", "c"] assert natsorted(given, alg=ns.LOCALE | ns.F) == expected @pytest.mark.usefixtures("with_locale_de_de") -def test_natsorted_can_sort_locale_specific_numbers_de(): +def test_natsorted_can_sort_locale_specific_numbers_de() -> None: given = ["c", "a5.467,86", "ä", "b", "a5367.86", "a5,6", "a5,50"] expected = ["a5,50", "a5,6", "a5367.86", "a5.467,86", "ä", "b", "c"] assert natsorted(given, alg=ns.LOCALE | ns.F) == expected @pytest.mark.usefixtures("with_locale_de_de") -def test_natsorted_locale_bug_regression_test_109(): +def test_natsorted_locale_bug_regression_test_109() -> None: # https://github.com/SethMMorton/natsort/issues/109 given = ["462166", "461761"] expected = ["461761", "462166"] @@ -242,7 +266,11 @@ def test_natsorted_locale_bug_regression_test_109(): ], ) @pytest.mark.usefixtures("with_locale_en_us") -def test_natsorted_handles_mixed_types_with_locale(mixed_list, alg, expected): +def test_natsorted_handles_mixed_types_with_locale( + mixed_list: List[Union[str, int, float]], + alg: NSType, + expected: List[Union[str, int, float]], +) -> None: assert natsorted(mixed_list, alg=ns.LOCALE | alg) == expected @@ -253,12 +281,14 @@ def test_natsorted_handles_mixed_types_with_locale(mixed_list, alg, expected): (ns.NUMAFTER, ["Banana", "apple", "corn", "~~~~~~", "73", "5039"]), ], ) -def test_natsorted_sorts_an_odd_collection_of_strings(alg, expected): +def test_natsorted_sorts_an_odd_collection_of_strings( + alg: NSType, expected: List[str] +) -> None: given = ["apple", "Banana", "73", "5039", "corn", "~~~~~~"] assert natsorted(given, alg=alg) == expected -def test_natsorted_sorts_mixed_ascii_and_non_ascii_numbers(): +def test_natsorted_sorts_mixed_ascii_and_non_ascii_numbers() -> None: given = [ "1st street", "10th street", diff --git a/tests/test_natsorted_convenience.py b/tests/test_natsorted_convenience.py index cdc2c50..0b2cd75 100644 --- a/tests/test_natsorted_convenience.py +++ b/tests/test_natsorted_convenience.py @@ -5,6 +5,7 @@ See the README or the natsort homepage for more details. """ from operator import itemgetter +from typing import List import pytest from natsort import ( @@ -23,21 +24,21 @@ from natsort import ( @pytest.fixture -def version_list(): +def version_list() -> List[str]: return ["1.9.9a", "1.11", "1.9.9b", "1.11.4", "1.10.1"] @pytest.fixture -def float_list(): +def float_list() -> List[str]: return ["a50", "a51.", "a50.31", "a-50", "a50.4", "a5.034e1", "a50.300"] @pytest.fixture -def fruit_list(): +def fruit_list() -> List[str]: return ["Apple", "corn", "Corn", "Banana", "apple", "banana"] -def test_decoder_returns_function_that_can_decode_bytes_but_return_non_bytes_as_is(): +def test_decoder_returns_function_that_decodes_bytes_but_returns_other_as_is() -> None: func = decoder("latin1") str_obj = "bytes" int_obj = 14 @@ -46,24 +47,28 @@ def test_decoder_returns_function_that_can_decode_bytes_but_return_non_bytes_as_ assert func(str_obj) is str_obj # same object returned b/c only bytes has decode -def test_as_ascii_converts_bytes_to_ascii(): +def test_as_ascii_converts_bytes_to_ascii() -> None: assert decoder("ascii")(b"bytes") == as_ascii(b"bytes") -def test_as_utf8_converts_bytes_to_utf8(): +def test_as_utf8_converts_bytes_to_utf8() -> None: assert decoder("utf8")(b"bytes") == as_utf8(b"bytes") -def test_realsorted_is_identical_to_natsorted_with_real_alg(float_list): +def test_realsorted_is_identical_to_natsorted_with_real_alg( + float_list: List[str], +) -> None: assert realsorted(float_list) == natsorted(float_list, alg=ns.REAL) @pytest.mark.usefixtures("with_locale_en_us") -def test_humansorted_is_identical_to_natsorted_with_locale_alg(fruit_list): +def test_humansorted_is_identical_to_natsorted_with_locale_alg( + fruit_list: List[str], +) -> None: assert humansorted(fruit_list) == natsorted(fruit_list, alg=ns.LOCALE) -def test_index_natsorted_returns_integer_list_of_sort_order_for_input_list(): +def test_index_natsorted_returns_integer_list_of_sort_order_for_input_list() -> None: given = ["num3", "num5", "num2"] other = ["foo", "bar", "baz"] index = index_natsorted(given) @@ -72,27 +77,31 @@ def test_index_natsorted_returns_integer_list_of_sort_order_for_input_list(): assert [other[i] for i in index] == ["baz", "foo", "bar"] -def test_index_natsorted_reverse(): +def test_index_natsorted_reverse() -> None: given = ["num3", "num5", "num2"] assert index_natsorted(given, reverse=True) == index_natsorted(given)[::-1] -def test_index_natsorted_applies_key_function_before_sorting(): +def test_index_natsorted_applies_key_function_before_sorting() -> None: given = [("a", "num3"), ("b", "num5"), ("c", "num2")] expected = [2, 0, 1] assert index_natsorted(given, key=itemgetter(1)) == expected -def test_index_realsorted_is_identical_to_index_natsorted_with_real_alg(float_list): +def test_index_realsorted_is_identical_to_index_natsorted_with_real_alg( + float_list: List[str], +) -> None: assert index_realsorted(float_list) == index_natsorted(float_list, alg=ns.REAL) @pytest.mark.usefixtures("with_locale_en_us") -def test_index_humansorted_is_identical_to_index_natsorted_with_locale_alg(fruit_list): +def test_index_humansorted_is_identical_to_index_natsorted_with_locale_alg( + fruit_list: List[str], +) -> None: assert index_humansorted(fruit_list) == index_natsorted(fruit_list, alg=ns.LOCALE) -def test_order_by_index_sorts_list_according_to_order_of_integer_list(): +def test_order_by_index_sorts_list_according_to_order_of_integer_list() -> None: given = ["num3", "num5", "num2"] index = [2, 0, 1] expected = [given[i] for i in index] @@ -100,7 +109,7 @@ def test_order_by_index_sorts_list_according_to_order_of_integer_list(): assert order_by_index(given, index) == expected -def test_order_by_index_returns_generator_with_iter_true(): +def test_order_by_index_returns_generator_with_iter_true() -> None: given = ["num3", "num5", "num2"] index = [2, 0, 1] assert order_by_index(given, index, True) != [given[i] for i in index] diff --git a/tests/test_ns_enum.py b/tests/test_ns_enum.py index 1d3803b..7a30718 100644 --- a/tests/test_ns_enum.py +++ b/tests/test_ns_enum.py @@ -1,8 +1,10 @@ +import pytest from natsort import ns -def test_ns_enum(): - enum_name_values = [ +@pytest.mark.parametrize( + "given, expected", + [ ("FLOAT", 0x0001), ("SIGNED", 0x0002), ("NOEXP", 0x0004), @@ -40,5 +42,7 @@ def test_ns_enum(): ("NL", 0x0400), ("CN", 0x0800), ("NA", 0x1000), - ] - assert list(ns._asdict().items()) == enum_name_values + ], +) +def test_ns_enum(given: str, expected: int) -> None: + assert ns[given] == expected diff --git a/tests/test_os_sorted.py b/tests/test_os_sorted.py index afb15cf..d0ecc79 100644 --- a/tests/test_os_sorted.py +++ b/tests/test_os_sorted.py @@ -3,6 +3,7 @@ Testing for the OS sorting """ import platform +from typing import cast import natsort import pytest @@ -15,7 +16,7 @@ else: has_icu = True -def test_os_sorted_compound(): +def test_os_sorted_compound() -> None: given = [ "/p/Folder (10)/file.tar.gz", "/p/Folder (1)/file (1).tar.gz", @@ -36,14 +37,14 @@ def test_os_sorted_compound(): assert result == expected -def test_os_sorted_misc_no_fail(): +def test_os_sorted_misc_no_fail() -> None: natsort.os_sorted([9, 4.3, None, float("nan")]) -def test_os_sorted_key(): +def test_os_sorted_key() -> None: given = ["foo0", "foo2", "goo1"] expected = ["foo0", "goo1", "foo2"] - result = natsort.os_sorted(given, key=lambda x: x.replace("g", "f")) + result = natsort.os_sorted(given, key=lambda x: cast(str, x).replace("g", "f")) assert result == expected @@ -199,7 +200,7 @@ else: @pytest.mark.usefixtures("with_locale_en_us") -def test_os_sorted_corpus(): +def test_os_sorted_corpus() -> None: result = natsort.os_sorted(given) print(result) assert result == expected diff --git a/tests/test_parse_bytes_function.py b/tests/test_parse_bytes_function.py index 6637cbd..318c4aa 100644 --- a/tests/test_parse_bytes_function.py +++ b/tests/test_parse_bytes_function.py @@ -4,8 +4,8 @@ import pytest from hypothesis import given from hypothesis.strategies import binary -from natsort.ns_enum import ns -from natsort.utils import parse_bytes_factory +from natsort.ns_enum import NSType, ns +from natsort.utils import BytesTransformer, parse_bytes_factory @pytest.mark.parametrize( @@ -19,6 +19,8 @@ from natsort.utils import parse_bytes_factory ], ) @given(x=binary()) -def test_parse_bytest_factory_makes_function_that_returns_tuple(x, alg, example_func): +def test_parse_bytest_factory_makes_function_that_returns_tuple( + x: bytes, alg: NSType, example_func: BytesTransformer +) -> None: parse_bytes_func = parse_bytes_factory(alg) assert parse_bytes_func(x) == example_func(x) diff --git a/tests/test_parse_number_function.py b/tests/test_parse_number_function.py index 29ee5a3..e5f417d 100644 --- a/tests/test_parse_number_function.py +++ b/tests/test_parse_number_function.py @@ -1,11 +1,13 @@ # -*- coding: utf-8 -*- """These test the utils.py functions.""" +from typing import Optional, Tuple, Union + import pytest from hypothesis import given from hypothesis.strategies import floats, integers -from natsort.ns_enum import ns -from natsort.utils import parse_number_or_none_factory +from natsort.ns_enum import NSType, ns +from natsort.utils import MaybeNumTransformer, parse_number_or_none_factory @pytest.mark.usefixtures("with_locale_en_us") @@ -19,7 +21,9 @@ from natsort.utils import parse_number_or_none_factory ], ) @given(x=floats(allow_nan=False) | integers()) -def test_parse_number_factory_makes_function_that_returns_tuple(x, alg, example_func): +def test_parse_number_factory_makes_function_that_returns_tuple( + x: Union[float, int], alg: NSType, example_func: MaybeNumTransformer +) -> None: parse_number_func = parse_number_or_none_factory(alg, "", "xx") assert parse_number_func(x) == example_func(x) @@ -34,6 +38,8 @@ def test_parse_number_factory_makes_function_that_returns_tuple(x, alg, example_ (ns.NANLAST, None, ("", float("+inf"))), # NANLAST makes it +infinity ], ) -def test_parse_number_factory_treats_nan_and_none_special(alg, x, result): +def test_parse_number_factory_treats_nan_and_none_special( + alg: NSType, x: Optional[Union[float, int]], result: Tuple[str, Union[float, int]] +) -> None: parse_number_func = parse_number_or_none_factory(alg, "", "xx") assert parse_number_func(x) == result diff --git a/tests/test_parse_string_function.py b/tests/test_parse_string_function.py index 46347f1..653a065 100644 --- a/tests/test_parse_string_function.py +++ b/tests/test_parse_string_function.py @@ -2,23 +2,28 @@ """These test the utils.py functions.""" import unicodedata +from typing import Any, Callable, Iterable, List, Tuple, Union import pytest from hypothesis import given from hypothesis.strategies import floats, integers, lists, text from natsort.compat.fastnumbers import fast_float -from natsort.ns_enum import NS_DUMB, ns -from natsort.utils import NumericalRegularExpressions as NumRegex +from natsort.ns_enum import NSType, NS_DUMB, ns +from natsort.utils import ( + FinalTransform, + NumericalRegularExpressions as NumRegex, + StrParser, +) from natsort.utils import parse_string_factory -class CustomTuple(tuple): +class CustomTuple(Tuple[Any, ...]): """Used to ensure what is given during testing is what is returned.""" - original = None + original: Any = None -def input_transform(x): +def input_transform(x: Any) -> Any: """Make uppercase.""" try: return x.upper() @@ -26,14 +31,14 @@ def input_transform(x): return x -def final_transform(x, original): +def final_transform(x: Iterable[Any], original: str) -> FinalTransform: """Make the input a CustomTuple.""" t = CustomTuple(x) t.original = original return t -def parse_string_func_factory(alg): +def parse_string_func_factory(alg: NSType) -> StrParser: """A parse_string_factory result with sample arguments.""" sep = "" return parse_string_factory( @@ -47,10 +52,12 @@ def parse_string_func_factory(alg): @given(x=floats() | integers()) -def test_parse_string_factory_raises_type_error_if_given_number(x): +def test_parse_string_factory_raises_type_error_if_given_number( + x: Union[int, float] +) -> None: parse_string_func = parse_string_func_factory(ns.DEFAULT) with pytest.raises(TypeError): - assert parse_string_func(x) + assert parse_string_func(x) # type: ignore # noinspection PyCallingNonCallable @@ -68,7 +75,9 @@ def test_parse_string_factory_raises_type_error_if_given_number(x): ) ) @pytest.mark.usefixtures("with_locale_en_us") -def test_parse_string_factory_invariance(x, alg, orig_func): +def test_parse_string_factory_invariance( + x: List[Union[float, str, int]], alg: NSType, orig_func: Callable[[str], str] +) -> None: parse_string_func = parse_string_func_factory(alg) # parse_string_factory is the high-level combination of several dedicated # functions involved in splitting and manipulating a string. The details of diff --git a/tests/test_regex.py b/tests/test_regex.py index f647f5f..08314b5 100644 --- a/tests/test_regex.py +++ b/tests/test_regex.py @@ -1,8 +1,11 @@ # -*- coding: utf-8 -*- """These test the splitting regular expressions.""" +from typing import List, Pattern + import pytest from natsort import ns, numeric_regex_chooser +from natsort.ns_enum import NSType from natsort.utils import NumericalRegularExpressions as NumRegex @@ -95,7 +98,9 @@ labels = ["{}-{}".format(given, regex_names[regex]) for given, _, regex in regex @pytest.mark.parametrize("x, expected, regex", regex_params, ids=labels) -def test_regex_splits_correctly(x, expected, regex): +def test_regex_splits_correctly( + x: str, expected: List[str], regex: Pattern[str] +) -> None: # noinspection PyUnresolvedReferences assert regex.split(x) == expected @@ -115,5 +120,5 @@ def test_regex_splits_correctly(x, expected, regex): (ns.FLOAT | ns.UNSIGNED | ns.NOEXP, NumRegex.float_nosign_noexp()), ], ) -def test_regex_chooser(given, expected): +def test_regex_chooser(given: NSType, expected: Pattern[str]) -> None: assert numeric_regex_chooser(given) == expected.pattern[1:-1] # remove parens diff --git a/tests/test_string_component_transform_factory.py b/tests/test_string_component_transform_factory.py index 8b77d38..99df7ea 100644 --- a/tests/test_string_component_transform_factory.py +++ b/tests/test_string_component_transform_factory.py @@ -2,13 +2,14 @@ """These test the utils.py functions.""" from functools import partial +from typing import Any, Callable, FrozenSet, Union import pytest from hypothesis import example, given from hypothesis.strategies import floats, integers, text from natsort.compat.fastnumbers import fast_float, fast_int from natsort.compat.locale import get_strxfrm -from natsort.ns_enum import NS_DUMB, ns +from natsort.ns_enum import NSType, NS_DUMB, ns from natsort.utils import groupletters, string_component_transform_factory # There are some unicode values that are known failures with the builtin locale @@ -21,12 +22,12 @@ except ValueError: bad_uni_chars = frozenset() -def no_bad_uni_chars(x, _bad_chars=bad_uni_chars): +def no_bad_uni_chars(x: str, _bad_chars: FrozenSet[str] = bad_uni_chars) -> bool: """Ensure text does not contain bad unicode characters""" return not any(y in _bad_chars for y in x) -def no_null(x): +def no_null(x: str) -> bool: """Ensure text does not contain a null character.""" return "\0" not in x @@ -65,7 +66,9 @@ def no_null(x): | text().filter(bool).filter(no_bad_uni_chars).filter(no_null) ) @pytest.mark.usefixtures("with_locale_en_us") -def test_string_component_transform_factory(x, alg, example_func): +def test_string_component_transform_factory( + x: Union[str, float, int], alg: NSType, example_func: Callable[[str], Any] +) -> None: string_component_transform_func = string_component_transform_factory(alg) try: assert string_component_transform_func(str(x)) == example_func(str(x)) diff --git a/tests/test_unicode_numbers.py b/tests/test_unicode_numbers.py index be0f4ee..eb71125 100644 --- a/tests/test_unicode_numbers.py +++ b/tests/test_unicode_numbers.py @@ -14,27 +14,27 @@ from natsort.unicode_numbers import ( digits_no_decimals, numeric, numeric_chars, - numeric_hex, numeric_no_decimals, ) +from natsort.unicode_numeric_hex import numeric_hex -def test_numeric_chars_contains_only_valid_unicode_numeric_characters(): +def test_numeric_chars_contains_only_valid_unicode_numeric_characters() -> None: for a in numeric_chars: assert unicodedata.numeric(a, None) is not None -def test_digit_chars_contains_only_valid_unicode_digit_characters(): +def test_digit_chars_contains_only_valid_unicode_digit_characters() -> None: for a in digit_chars: assert unicodedata.digit(a, None) is not None -def test_decimal_chars_contains_only_valid_unicode_decimal_characters(): +def test_decimal_chars_contains_only_valid_unicode_decimal_characters() -> None: for a in decimal_chars: assert unicodedata.decimal(a, None) is not None -def test_numeric_chars_contains_all_valid_unicode_numeric_and_digit_characters(): +def test_numeric_chars_contains_all_valid_unicode_numeric_and_digit_characters() -> None: set_numeric_chars = set(numeric_chars) set_digit_chars = set(digit_chars) set_decimal_chars = set(decimal_chars) @@ -46,7 +46,7 @@ def test_numeric_chars_contains_all_valid_unicode_numeric_and_digit_characters() assert set_numeric_chars.issuperset(numeric_no_decimals) -def test_missing_unicode_number_in_collection(): +def test_missing_unicode_number_in_collection() -> None: ok = True set_numeric_hex = set(numeric_hex) for i in range(0x110000): @@ -71,7 +71,7 @@ repository (https://github.com/SethMMorton/natsort) with the resulting change. ) -def test_combined_string_contains_all_characters_in_list(): +def test_combined_string_contains_all_characters_in_list() -> None: assert numeric == "".join(numeric_chars) assert digits == "".join(digit_chars) assert decimals == "".join(decimal_chars) diff --git a/tests/test_utils.py b/tests/test_utils.py index d559803..38df303 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -6,15 +6,16 @@ import pathlib import string from itertools import chain from operator import neg as op_neg +from typing import List, Pattern, Union import pytest from hypothesis import given from hypothesis.strategies import integers, lists, sampled_from, text from natsort import utils -from natsort.ns_enum import ns +from natsort.ns_enum import NSType, ns -def test_do_decoding_decodes_bytes_string_to_unicode(): +def test_do_decoding_decodes_bytes_string_to_unicode() -> None: assert type(utils.do_decoding(b"bytes", "ascii")) is str assert utils.do_decoding(b"bytes", "ascii") == "bytes" assert utils.do_decoding(b"bytes", "ascii") == b"bytes".decode("ascii") @@ -33,7 +34,9 @@ def test_do_decoding_decodes_bytes_string_to_unicode(): (ns.F | ns.S | ns.N, utils.NumericalRegularExpressions.float_sign_noexp()), ], ) -def test_regex_chooser_returns_correct_regular_expression_object(alg, expected): +def test_regex_chooser_returns_correct_regular_expression_object( + alg: NSType, expected: Pattern[str] +) -> None: assert utils.regex_chooser(alg).pattern == expected.pattern @@ -68,21 +71,21 @@ def test_regex_chooser_returns_correct_regular_expression_object(alg, expected): (ns.REAL, ns.FLOAT | ns.SIGNED), ], ) -def test_ns_enum_values_and_aliases(alg, value_or_alias): +def test_ns_enum_values_and_aliases(alg: NSType, value_or_alias: NSType) -> None: assert alg == value_or_alias -def test_chain_functions_is_a_no_op_if_no_functions_are_given(): +def test_chain_functions_is_a_no_op_if_no_functions_are_given() -> None: x = 2345 assert utils.chain_functions([])(x) is x -def test_chain_functions_does_one_function_if_one_function_is_given(): +def test_chain_functions_does_one_function_if_one_function_is_given() -> None: x = "2345" assert utils.chain_functions([len])(x) == 4 -def test_chain_functions_combines_functions_in_given_order(): +def test_chain_functions_combines_functions_in_given_order() -> None: x = 2345 assert utils.chain_functions([str, len, op_neg])(x) == -len(str(x)) @@ -91,33 +94,37 @@ def test_chain_functions_combines_functions_in_given_order(): # and a test that uses the hypothesis module. -def test_groupletters_returns_letters_with_lowercase_transform_of_letter_example(): +def test_groupletters_gives_letters_with_lowercase_letter_transform_example() -> None: assert utils.groupletters("HELLO") == "hHeElLlLoO" assert utils.groupletters("hello") == "hheelllloo" @given(text().filter(bool)) -def test_groupletters_returns_letters_with_lowercase_transform_of_letter(x): +def test_groupletters_gives_letters_with_lowercase_letter_transform( + x: str, +) -> None: assert utils.groupletters(x) == "".join( chain.from_iterable([y.casefold(), y] for y in x) ) -def test_sep_inserter_does_nothing_if_no_numbers_example(): +def test_sep_inserter_does_nothing_if_no_numbers_example() -> None: assert list(utils.sep_inserter(iter(["a", "b", "c"]), "")) == ["a", "b", "c"] assert list(utils.sep_inserter(iter(["a"]), "")) == ["a"] -def test_sep_inserter_does_nothing_if_only_one_number_example(): +def test_sep_inserter_does_nothing_if_only_one_number_example() -> None: assert list(utils.sep_inserter(iter(["a", 5]), "")) == ["a", 5] -def test_sep_inserter_inserts_separator_string_between_two_numbers_example(): +def test_sep_inserter_inserts_separator_string_between_two_numbers_example() -> None: assert list(utils.sep_inserter(iter([5, 9]), "")) == ["", 5, "", 9] @given(lists(elements=text().filter(bool) | integers(), min_size=3)) -def test_sep_inserter_inserts_separator_between_two_numbers(x): +def test_sep_inserter_inserts_separator_between_two_numbers( + x: List[Union[str, int]] +) -> None: # Rather than just replicating the results in a different algorithm, # validate that the "shape" of the output is as expected. result = list(utils.sep_inserter(iter(x), "")) @@ -127,28 +134,29 @@ def test_sep_inserter_inserts_separator_between_two_numbers(x): assert isinstance(result[i + 1], int) -def test_path_splitter_splits_path_string_by_separator_example(): +def test_path_splitter_splits_path_string_by_sep_example() -> None: given = "/this/is/a/path" expected = (os.sep, "this", "is", "a", "path") assert tuple(utils.path_splitter(given)) == tuple(expected) - given = pathlib.Path(given) - assert tuple(utils.path_splitter(given)) == tuple(expected) + assert tuple(utils.path_splitter(pathlib.Path(given))) == tuple(expected) @given(lists(sampled_from(string.ascii_letters), min_size=2).filter(all)) -def test_path_splitter_splits_path_string_by_separator(x): +def test_path_splitter_splits_path_string_by_sep(x: List[str]) -> None: z = str(pathlib.Path(*x)) assert tuple(utils.path_splitter(z)) == tuple(pathlib.Path(z).parts) -def test_path_splitter_splits_path_string_by_separator_and_removes_extension_example(): +def test_path_splitter_splits_path_string_by_sep_and_removes_extension_example() -> None: given = "/this/is/a/path/file.x1.10.tar.gz" expected = (os.sep, "this", "is", "a", "path", "file.x1.10", ".tar", ".gz") assert tuple(utils.path_splitter(given)) == tuple(expected) @given(lists(sampled_from(string.ascii_letters), min_size=3).filter(all)) -def test_path_splitter_splits_path_string_by_separator_and_removes_extension(x): +def test_path_splitter_splits_path_string_by_sep_and_removes_extension( + x: List[str], +) -> None: z = str(pathlib.Path(*x[:-2])) + "." + x[-1] y = tuple(pathlib.Path(z).parts) assert tuple(utils.path_splitter(z)) == y[:-1] + ( @@ -5,7 +5,7 @@ [tox] envlist = - flake8, py35, py36, py37, py38, py39 + flake8, mypy, py36, py37, py38, py39 # Other valid evironments are: # docs # release @@ -52,6 +52,18 @@ commands = twine check dist/* skip_install = true +# Type checking +[testenv:mypy] +deps = + mypy + hypothesis + pytest + pytest-mock + fastnumbers +commands = + mypy --strict natsort tests +skip_install = true + # Build documentation. # sphinx and sphinx_rtd_theme not in docs/requirements.txt because they # will already be installed on readthedocs. |