summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSeth Morton <seth.m.morton@gmail.com>2021-11-02 19:52:31 -0700
committerGitHub <noreply@github.com>2021-11-02 19:52:31 -0700
commit5eaed4145d174ac05752095ec205ce2cd2e90a5f (patch)
tree1fcc1c7ce9698515ea1ae3f239d4339783e6c4d6
parentafe12261977a219bef0d0c0e6a40e1f81cf44d4f (diff)
parent3461338e52292926bc2148dd5fdda37f253b5860 (diff)
downloadnatsort-5eaed4145d174ac05752095ec205ce2cd2e90a5f.tar.gz
Merge pull request #138 from SethMMorton/type-hinting
Type hinting
-rw-r--r--.github/workflows/code-quality.yml22
-rw-r--r--.github/workflows/tests.yml2
-rw-r--r--CHANGELOG.md9
-rw-r--r--docs/api.rst22
-rw-r--r--docs/conf.py1
-rw-r--r--natsort/__init__.py14
-rw-r--r--natsort/__main__.py70
-rw-r--r--natsort/compat/fake_fastnumbers.py38
-rw-r--r--natsort/compat/fastnumbers.py5
-rw-r--r--natsort/compat/locale.py43
-rw-r--r--natsort/natsort.py227
-rw-r--r--natsort/ns_enum.py107
-rw-r--r--natsort/py.typed0
-rw-r--r--natsort/utils.py278
-rw-r--r--setup.cfg6
-rw-r--r--setup.py4
-rw-r--r--tests/conftest.py10
-rw-r--r--tests/profile_natsorted.py9
-rw-r--r--tests/test_fake_fastnumbers.py49
-rw-r--r--tests/test_final_data_transform_factory.py13
-rw-r--r--tests/test_input_string_transform_factory.py27
-rw-r--r--tests/test_main.py68
-rw-r--r--tests/test_natsort_key.py29
-rw-r--r--tests/test_natsort_keygen.py56
-rw-r--r--tests/test_natsorted.py94
-rw-r--r--tests/test_natsorted_convenience.py39
-rw-r--r--tests/test_ns_enum.py12
-rw-r--r--tests/test_os_sorted.py11
-rw-r--r--tests/test_parse_bytes_function.py8
-rw-r--r--tests/test_parse_number_function.py14
-rw-r--r--tests/test_parse_string_function.py29
-rw-r--r--tests/test_regex.py9
-rw-r--r--tests/test_string_component_transform_factory.py11
-rw-r--r--tests/test_unicode_numbers.py14
-rw-r--r--tests/test_utils.py46
-rw-r--r--tox.ini14
36 files changed, 976 insertions, 434 deletions
diff --git a/.github/workflows/code-quality.yml b/.github/workflows/code-quality.yml
index f95887b..89d31f2 100644
--- a/.github/workflows/code-quality.yml
+++ b/.github/workflows/code-quality.yml
@@ -38,7 +38,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v2
with:
- python-version: '3.5'
+ python-version: '3.6'
- name: Install Flake8
run: pip install flake8 flake8-import-order flake8-bugbear pep8-naming
@@ -46,6 +46,24 @@ jobs:
- name: Run Flake8
run: flake8
+ type-checking:
+ name: Type Checking
+ runs-on: ubuntu-latest
+ steps:
+ - name: Checkout code
+ uses: actions/checkout@v2
+
+ - name: Set up Python
+ uses: actions/setup-python@v2
+ with:
+ python-version: '3.6'
+
+ - name: Install MyPy
+ run: pip install mypy hypothesis pytest pytest-mock fastnumbers
+
+ - name: Run MyPy
+ run: mypy --strict natsort tests
+
package-validation:
name: Package Validation
runs-on: ubuntu-latest
@@ -56,7 +74,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v2
with:
- python-version: '3.5'
+ python-version: '3.6'
- name: Install Validators
run: pip install twine check-manifest
diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 3263030..ce64b65 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -15,7 +15,7 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
- python-version: [3.5, 3.6, 3.7, 3.8, 3.9]
+ python-version: [3.6, 3.7, 3.8, 3.9]
os: [ubuntu-latest]
extras: [false]
include:
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 4e0c86d..1e1d147 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,6 +1,15 @@
Unreleased
---
+### Changed
+
+ - The `ns` enum is now implemented as an `enum.IntEnum` instead of a
+ `collections.namedtuple`
+
+### Removed
+
+ - Support for Python 3.4 and Python 3.5
+
[7.1.1] - 2021-01-24
---
diff --git a/docs/api.rst b/docs/api.rst
index 86a26ee..39d7cec 100644
--- a/docs/api.rst
+++ b/docs/api.rst
@@ -117,3 +117,25 @@ Given your chosen algorithm (selected using the :class:`~natsort.ns` enum),
the corresponding regular expression to locate numbers will be returned.
.. autofunction:: numeric_regex_chooser
+
+Help With Type Hinting
+++++++++++++++++++++++
+
+If you need to explictly specify the types that natsort accepts or returns
+in your code, the following types have been exposed for your convenience.
+
++--------------------------------+----------------------------------------------------------------------------------------+
+| Type | Purpose |
++================================+========================================================================================+
+|:attr:`natsort.NatsortKeyType` | Returned by :func:`natsort.natsort_keygen`, and type of :attr:`natsort.natsort_key` |
++--------------------------------+----------------------------------------------------------------------------------------+
+|:attr:`natsort.OSSortKeyType` | Returned by :func:`natsort.os_sort_keygen`, and type of :attr:`natsort.os_sort_key` |
++--------------------------------+----------------------------------------------------------------------------------------+
+|:attr:`natsort.KeyType` | Type of `key` argument to :func:`natsort.natsorted` and :func:`natsort.natsort_keygen` |
++--------------------------------+----------------------------------------------------------------------------------------+
+|:attr:`natsort.NatsortInType` | The input type of :attr:`natsort.NatsortKeyType` |
++--------------------------------+----------------------------------------------------------------------------------------+
+|:attr:`natsort.NatsortOutType` | The output type of :attr:`natsort.NatsortKeyType` |
++--------------------------------+----------------------------------------------------------------------------------------+
+|:attr:`natsort.NSType` | The type of the :class:`ns` enum |
++--------------------------------+----------------------------------------------------------------------------------------+
diff --git a/docs/conf.py b/docs/conf.py
index 05919ec..feab38e 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -35,6 +35,7 @@ extensions = [
"sphinx.ext.napoleon",
"m2r2",
]
+autodoc_typehints = "none"
# Add any paths that contain templates here, relative to this directory.
templates_path = ["_templates"]
diff --git a/natsort/__init__.py b/natsort/__init__.py
index 8c8f87f..4207f79 100644
--- a/natsort/__init__.py
+++ b/natsort/__init__.py
@@ -1,6 +1,8 @@
# -*- coding: utf-8 -*-
from natsort.natsort import (
+ NatsortKeyType,
+ OSSortKeyType,
as_ascii,
as_utf8,
decoder,
@@ -11,7 +13,6 @@ from natsort.natsort import (
natsort_key,
natsort_keygen,
natsorted,
- ns,
numeric_regex_chooser,
order_by_index,
os_sort_key,
@@ -19,7 +20,8 @@ from natsort.natsort import (
os_sorted,
realsorted,
)
-from natsort.utils import chain_functions
+from natsort.ns_enum import NSType, ns
+from natsort.utils import KeyType, NatsortInType, NatsortOutType, chain_functions
__version__ = "7.1.1"
@@ -42,7 +44,13 @@ __all__ = [
"os_sort_key",
"os_sort_keygen",
"os_sorted",
+ "NatsortKeyType",
+ "OSSortKeyType",
+ "KeyType",
+ "NatsortInType",
+ "NatsortOutType",
+ "NSType",
]
# Add the ns keys to this namespace for convenience.
-globals().update(ns._asdict())
+globals().update({name: value for name, value in ns.__members__.items()})
diff --git a/natsort/__main__.py b/natsort/__main__.py
index e2f3299..4dffc4b 100644
--- a/natsort/__main__.py
+++ b/natsort/__main__.py
@@ -1,12 +1,53 @@
# -*- coding: utf-8 -*-
+import argparse
import sys
+from typing import Callable, Iterable, List, Optional, Pattern, Tuple, Union, cast
import natsort
from natsort.utils import regex_chooser
-
-def main(*arguments):
+Num = Union[float, int]
+NumIter = Iterable[Num]
+NumPair = Tuple[Num, Num]
+NumPairIter = Iterable[NumPair]
+NumConverter = Callable[[str], Num]
+
+
+class TypedArgs(argparse.Namespace):
+ paths: bool
+ filter: Optional[List[NumPair]]
+ reverse_filter: Optional[List[NumPair]]
+ exclude: List[Num]
+ reverse: bool
+ number_type: str
+ nosign: bool
+ sign: bool
+ noexp: bool
+ locale: bool
+ entries: List[str]
+
+ def __init__(
+ self,
+ filter: Optional[List[NumPair]] = None,
+ reverse_filter: Optional[List[NumPair]] = None,
+ exclude: Optional[List[Num]] = None,
+ paths: bool = False,
+ reverse: bool = False,
+ ) -> None:
+ """Used by testing only"""
+ self.filter = filter
+ self.reverse_filter = reverse_filter
+ self.exclude = [] if exclude is None else exclude
+ self.paths = paths
+ self.reverse = reverse
+ self.number_type = "int"
+ self.signed = False
+ self.exp = True
+ self.locale = False
+
+
+def main(*arguments: str) -> None:
"""
Performs a natural sort on entries given on the command-line.
@@ -17,7 +58,8 @@ def main(*arguments):
from textwrap import dedent
parser = ArgumentParser(
- description=dedent(main.__doc__), formatter_class=RawDescriptionHelpFormatter
+ description=dedent(cast(str, main.__doc__)),
+ formatter_class=RawDescriptionHelpFormatter,
)
parser.add_argument(
"--version",
@@ -126,7 +168,7 @@ def main(*arguments):
help="The entries to sort. Taken from stdin if nothing is given on "
"the command line.",
)
- args = parser.parse_args(arguments or None)
+ args = parser.parse_args(arguments or None, namespace=TypedArgs())
# Make sure the filter range is given properly. Does nothing if no filter
args.filter = check_filters(args.filter)
@@ -139,7 +181,7 @@ def main(*arguments):
sort_and_print_entries(entries, args)
-def range_check(low, high):
+def range_check(low: Num, high: Num) -> NumPair:
"""
Verify that that given range has a low lower than the high.
@@ -164,7 +206,7 @@ def range_check(low, high):
return low, high
-def check_filters(filters):
+def check_filters(filters: Optional[NumPairIter]) -> Optional[List[NumPair]]:
"""
Execute range_check for every element of an iterable.
@@ -192,7 +234,13 @@ def check_filters(filters):
raise ValueError("Error in --filter: " + str(err))
-def keep_entry_range(entry, lows, highs, converter, regex):
+def keep_entry_range(
+ entry: str,
+ lows: NumIter,
+ highs: NumIter,
+ converter: NumConverter,
+ regex: Pattern[str],
+) -> bool:
"""
Check if an entry falls into a desired range.
@@ -224,7 +272,9 @@ def keep_entry_range(entry, lows, highs, converter, regex):
)
-def keep_entry_value(entry, values, converter, regex):
+def keep_entry_value(
+ entry: str, values: NumIter, converter: NumConverter, regex: Pattern[str]
+) -> bool:
"""
Check if an entry does not match a given value.
@@ -249,13 +299,13 @@ def keep_entry_value(entry, values, converter, regex):
return not any(converter(num) in values for num in regex.findall(entry))
-def sort_and_print_entries(entries, args):
+def sort_and_print_entries(entries: List[str], args: TypedArgs) -> None:
"""Sort the entries, applying the filters first if necessary."""
# Extract the proper number type.
is_float = args.number_type in ("float", "real", "f", "r")
signed = args.signed or args.number_type in ("real", "r")
- alg = (
+ alg: int = (
natsort.ns.FLOAT * is_float
| natsort.ns.SIGNED * signed
| natsort.ns.NOEXP * (not args.exp)
diff --git a/natsort/compat/fake_fastnumbers.py b/natsort/compat/fake_fastnumbers.py
index 7177551..5d44605 100644
--- a/natsort/compat/fake_fastnumbers.py
+++ b/natsort/compat/fake_fastnumbers.py
@@ -3,14 +3,12 @@
This module is intended to replicate some of the functionality
from the fastnumbers module in the event that module is not installed.
"""
-
-# Std. lib imports.
import unicodedata
+from typing import Callable, FrozenSet, Optional, Union
-# Local imports.
from natsort.unicode_numbers import decimal_chars
-NAN_INF = [
+_NAN_INF = [
"INF",
"INf",
"Inf",
@@ -28,21 +26,24 @@ NAN_INF = [
"nAN",
"Nan",
]
-NAN_INF.extend(["+" + x[:2] for x in NAN_INF] + ["-" + x[:2] for x in NAN_INF])
-NAN_INF = frozenset(NAN_INF)
+_NAN_INF.extend(["+" + x[:2] for x in _NAN_INF] + ["-" + x[:2] for x in _NAN_INF])
+NAN_INF = frozenset(_NAN_INF)
ASCII_NUMS = "0123456789+-"
POTENTIAL_FIRST_CHAR = frozenset(decimal_chars + list(ASCII_NUMS + "."))
+StrOrFloat = Union[str, float]
+StrOrInt = Union[str, int]
+
# noinspection PyIncorrectDocstring
def fast_float(
- x,
- key=lambda x: x,
- nan=None,
- _uni=unicodedata.numeric,
- _nan_inf=NAN_INF,
- _first_char=POTENTIAL_FIRST_CHAR,
-):
+ x: str,
+ key: Callable[[str], StrOrFloat] = lambda x: x,
+ nan: Optional[StrOrFloat] = None,
+ _uni: Callable[[str, StrOrFloat], StrOrFloat] = unicodedata.numeric,
+ _nan_inf: FrozenSet[str] = NAN_INF,
+ _first_char: FrozenSet[str] = POTENTIAL_FIRST_CHAR,
+) -> StrOrFloat:
"""
Convert a string to a float quickly, return input as-is if not possible.
@@ -65,8 +66,8 @@ def fast_float(
"""
if x[0] in _first_char or x.lstrip()[:3] in _nan_inf:
try:
- x = float(x)
- return nan if nan is not None and x != x else x
+ ret = float(x)
+ return nan if nan is not None and ret != ret else ret
except ValueError:
try:
return _uni(x, key(x)) if len(x) == 1 else key(x)
@@ -81,8 +82,11 @@ def fast_float(
# noinspection PyIncorrectDocstring
def fast_int(
- x, key=lambda x: x, _uni=unicodedata.digit, _first_char=POTENTIAL_FIRST_CHAR
-):
+ x: str,
+ key: Callable[[str], StrOrInt] = lambda x: x,
+ _uni: Callable[[str, StrOrInt], StrOrInt] = unicodedata.digit,
+ _first_char: FrozenSet[str] = POTENTIAL_FIRST_CHAR,
+) -> StrOrInt:
"""
Convert a string to a int quickly, return input as-is if not possible.
diff --git a/natsort/compat/fastnumbers.py b/natsort/compat/fastnumbers.py
index 4f4d75a..049030d 100644
--- a/natsort/compat/fastnumbers.py
+++ b/natsort/compat/fastnumbers.py
@@ -3,9 +3,10 @@
Interface for natsort to access fastnumbers functions without
having to worry if it is actually installed.
"""
-
import re
+__all__ = ["fast_float", "fast_int"]
+
def is_supported_fastnumbers(fastnumbers_version: str) -> bool:
match = re.match(
@@ -34,4 +35,4 @@ try:
if not is_supported_fastnumbers(fn_ver):
raise ImportError # pragma: no cover
except ImportError:
- from natsort.compat.fake_fastnumbers import fast_float, fast_int # noqa: F401
+ from natsort.compat.fake_fastnumbers import fast_float, fast_int # type: ignore
diff --git a/natsort/compat/locale.py b/natsort/compat/locale.py
index ccb5592..9af5e7a 100644
--- a/natsort/compat/locale.py
+++ b/natsort/compat/locale.py
@@ -3,9 +3,11 @@
Interface for natsort to access locale functionality without
having to worry about if it is using PyICU or the built-in locale.
"""
-
-# Std. lib imports.
import sys
+from typing import Callable, Union, cast
+
+StrOrBytes = Union[str, bytes]
+TrxfmFunc = Callable[[str], StrOrBytes]
# This string should be sorted after any other byte string because
# it contains the max unicode character repeated 20 times.
@@ -13,6 +15,11 @@ import sys
null_string = ""
null_string_max = chr(sys.maxunicode) * 20
+# This variable could be str or bytes depending on the locale library
+# being used, so give the type-checker this information.
+null_string_locale: StrOrBytes
+null_string_locale_max: StrOrBytes
+
# strxfrm can be buggy (especially on BSD-based systems),
# so prefer icu if available.
try: # noqa: C901
@@ -26,26 +33,26 @@ try: # noqa: C901
# You would need some odd data to come after that.
null_string_locale_max = b"x7f" * 50
- def dumb_sort():
+ def dumb_sort() -> bool:
return False
# If using icu, get the locale from the current global locale,
- def get_icu_locale():
+ def get_icu_locale() -> str:
try:
- return icu.Locale(".".join(getlocale()))
+ return cast(str, icu.Locale(".".join(getlocale())))
except TypeError: # pragma: no cover
- return icu.Locale()
+ return cast(str, icu.Locale())
- def get_strxfrm():
- return icu.Collator.createInstance(get_icu_locale()).getSortKey
+ def get_strxfrm() -> TrxfmFunc:
+ return cast(TrxfmFunc, icu.Collator.createInstance(get_icu_locale()).getSortKey)
- def get_thousands_sep():
+ def get_thousands_sep() -> str:
sep = icu.DecimalFormatSymbols.kGroupingSeparatorSymbol
- return icu.DecimalFormatSymbols(get_icu_locale()).getSymbol(sep)
+ return cast(str, icu.DecimalFormatSymbols(get_icu_locale()).getSymbol(sep))
- def get_decimal_point():
+ def get_decimal_point() -> str:
sep = icu.DecimalFormatSymbols.kDecimalSeparatorSymbol
- return icu.DecimalFormatSymbols(get_icu_locale()).getSymbol(sep)
+ return cast(str, icu.DecimalFormatSymbols(get_icu_locale()).getSymbol(sep))
except ImportError:
@@ -57,14 +64,14 @@ except ImportError:
# On some systems, locale is broken and does not sort in the expected
# order. We will try to detect this and compensate.
- def dumb_sort():
+ def dumb_sort() -> bool:
return strxfrm("A") < strxfrm("a")
- def get_strxfrm():
+ def get_strxfrm() -> TrxfmFunc:
return strxfrm
- def get_thousands_sep():
- sep = locale.localeconv()["thousands_sep"]
+ def get_thousands_sep() -> str:
+ sep = cast(str, locale.localeconv()["thousands_sep"])
# If this locale library is broken, some of the thousands separator
# characters are incorrectly blank. Here is a lookup table of the
# corrections I am aware of.
@@ -111,5 +118,5 @@ except ImportError:
else:
return sep
- def get_decimal_point():
- return locale.localeconv()["decimal_point"]
+ def get_decimal_point() -> str:
+ return cast(str, locale.localeconv()["decimal_point"])
diff --git a/natsort/natsort.py b/natsort/natsort.py
index 8e3a7b5..a95f9a9 100644
--- a/natsort/natsort.py
+++ b/natsort/natsort.py
@@ -9,13 +9,51 @@ The majority of the "work" is defined in utils.py.
import platform
from functools import partial
from operator import itemgetter
+from typing import (
+ Any,
+ Callable,
+ Iterable,
+ Iterator,
+ List,
+ Optional,
+ Sequence,
+ Tuple,
+ Union,
+ cast,
+ overload,
+)
import natsort.compat.locale
from natsort import utils
-from natsort.ns_enum import NS_DUMB, ns
-
-
-def decoder(encoding):
+from natsort.ns_enum import NSType, NS_DUMB, ns
+from natsort.utils import (
+ KeyType,
+ MaybeKeyType,
+ NatsortInType,
+ NatsortOutType,
+ StrBytesNum,
+ StrBytesPathNum,
+)
+
+# Common input and output types
+Iter_ns = Iterable[NatsortInType]
+Iter_any = Iterable[Any]
+List_ns = List[NatsortInType]
+List_any = List[Any]
+List_int = List[int]
+
+# The type that natsort_key returns
+NatsortKeyType = Callable[[NatsortInType], NatsortOutType]
+
+# Types for os_sorted
+OSSortInType = Iterable[Optional[StrBytesPathNum]]
+OSSortOutType = Tuple[Union[StrBytesNum, Tuple[StrBytesNum, ...]], ...]
+OSSortKeyType = Callable[[Optional[StrBytesPathNum]], OSSortOutType]
+Iter_path = Iterable[Optional[StrBytesPathNum]]
+List_path = List[StrBytesPathNum]
+
+
+def decoder(encoding: str) -> Callable[[NatsortInType], NatsortInType]:
"""
Return a function that can be used to decode bytes to unicode.
@@ -56,7 +94,7 @@ def decoder(encoding):
return partial(utils.do_decoding, encoding=encoding)
-def as_ascii(s):
+def as_ascii(s: NatsortInType) -> NatsortInType:
"""
Function to decode an input with the ASCII codec, or return as-is.
@@ -79,7 +117,7 @@ def as_ascii(s):
return utils.do_decoding(s, "ascii")
-def as_utf8(s):
+def as_utf8(s: NatsortInType) -> NatsortInType:
"""
Function to decode an input with the UTF-8 codec, or return as-is.
@@ -102,7 +140,9 @@ def as_utf8(s):
return utils.do_decoding(s, "utf-8")
-def natsort_keygen(key=None, alg=ns.DEFAULT):
+def natsort_keygen(
+ key: MaybeKeyType = None, alg: NSType = ns.DEFAULT
+) -> NatsortKeyType:
"""
Generate a key to sort strings and numbers naturally.
@@ -212,7 +252,26 @@ natsort_keygen
"""
-def natsorted(seq, key=None, reverse=False, alg=ns.DEFAULT):
+@overload
+def natsorted(
+ seq: Iter_ns, key: None = None, reverse: bool = False, alg: NSType = ns.DEFAULT
+) -> List_ns:
+ ...
+
+
+@overload
+def natsorted(
+ seq: Iter_any, key: KeyType, reverse: bool = False, alg: NSType = ns.DEFAULT
+) -> List_any:
+ ...
+
+
+def natsorted(
+ seq: Iter_any,
+ key: MaybeKeyType = None,
+ reverse: bool = False,
+ alg: NSType = ns.DEFAULT,
+) -> List_any:
"""
Sorts an iterable naturally.
@@ -257,11 +316,29 @@ def natsorted(seq, key=None, reverse=False, alg=ns.DEFAULT):
['num2', 'num3', 'num5']
"""
- key = natsort_keygen(key, alg)
- return sorted(seq, reverse=reverse, key=key)
+ return sorted(seq, reverse=reverse, key=natsort_keygen(key, alg))
+
+
+@overload
+def humansorted(
+ seq: Iter_ns, key: None = None, reverse: bool = False, alg: NSType = ns.DEFAULT
+) -> List_ns:
+ ...
-def humansorted(seq, key=None, reverse=False, alg=ns.DEFAULT):
+@overload
+def humansorted(
+ seq: Iter_any, key: KeyType, reverse: bool = False, alg: NSType = ns.DEFAULT
+) -> List_any:
+ ...
+
+
+def humansorted(
+ seq: Iter_any,
+ key: MaybeKeyType = None,
+ reverse: bool = False,
+ alg: NSType = ns.DEFAULT,
+) -> List_any:
"""
Convenience function to properly sort non-numeric characters.
@@ -313,7 +390,26 @@ def humansorted(seq, key=None, reverse=False, alg=ns.DEFAULT):
return natsorted(seq, key, reverse, alg | ns.LOCALE)
-def realsorted(seq, key=None, reverse=False, alg=ns.DEFAULT):
+@overload
+def realsorted(
+ seq: Iter_ns, key: None = None, reverse: bool = False, alg: NSType = ns.DEFAULT
+) -> List_ns:
+ ...
+
+
+@overload
+def realsorted(
+ seq: Iter_any, key: KeyType, reverse: bool = False, alg: NSType = ns.DEFAULT
+) -> List_any:
+ ...
+
+
+def realsorted(
+ seq: Iter_any,
+ key: MaybeKeyType = None,
+ reverse: bool = False,
+ alg: NSType = ns.DEFAULT,
+) -> List_any:
"""
Convenience function to properly sort signed floats.
@@ -366,7 +462,26 @@ def realsorted(seq, key=None, reverse=False, alg=ns.DEFAULT):
return natsorted(seq, key, reverse, alg | ns.REAL)
-def index_natsorted(seq, key=None, reverse=False, alg=ns.DEFAULT):
+@overload
+def index_natsorted(
+ seq: Iter_ns, key: None = None, reverse: bool = False, alg: NSType = ns.DEFAULT
+) -> List_int:
+ ...
+
+
+@overload
+def index_natsorted(
+ seq: Iter_any, key: KeyType, reverse: bool = False, alg: NSType = ns.DEFAULT
+) -> List_int:
+ ...
+
+
+def index_natsorted(
+ seq: Iter_any,
+ key: MaybeKeyType = None,
+ reverse: bool = False,
+ alg: NSType = ns.DEFAULT,
+) -> List_int:
"""
Determine the list of the indexes used to sort the input sequence.
@@ -422,12 +537,13 @@ def index_natsorted(seq, key=None, reverse=False, alg=ns.DEFAULT):
['baz', 'foo', 'bar']
"""
+ newkey: KeyType
if key is None:
newkey = itemgetter(1)
else:
- def newkey(x):
- return key(itemgetter(1)(x))
+ def newkey(x: Any) -> NatsortInType:
+ return cast(KeyType, key)(itemgetter(1)(x))
# Pair the index and sequence together, then sort by element
index_seq_pair = [(x, y) for x, y in enumerate(seq)]
@@ -435,7 +551,26 @@ def index_natsorted(seq, key=None, reverse=False, alg=ns.DEFAULT):
return [x for x, _ in index_seq_pair]
-def index_humansorted(seq, key=None, reverse=False, alg=ns.DEFAULT):
+@overload
+def index_humansorted(
+ seq: Iter_ns, key: None = None, reverse: bool = False, alg: NSType = ns.DEFAULT
+) -> List_int:
+ ...
+
+
+@overload
+def index_humansorted(
+ seq: Iter_any, key: KeyType, reverse: bool = False, alg: NSType = ns.DEFAULT
+) -> List_int:
+ ...
+
+
+def index_humansorted(
+ seq: Iter_any,
+ key: MaybeKeyType = None,
+ reverse: bool = False,
+ alg: NSType = ns.DEFAULT,
+) -> List_int:
"""
This is a wrapper around ``index_natsorted(seq, alg=ns.LOCALE)``.
@@ -484,7 +619,26 @@ def index_humansorted(seq, key=None, reverse=False, alg=ns.DEFAULT):
return index_natsorted(seq, key, reverse, alg | ns.LOCALE)
-def index_realsorted(seq, key=None, reverse=False, alg=ns.DEFAULT):
+@overload
+def index_realsorted(
+ seq: Iter_ns, key: None = None, reverse: bool = False, alg: NSType = ns.DEFAULT
+) -> List_int:
+ ...
+
+
+@overload
+def index_realsorted(
+ seq: Iter_any, key: KeyType, reverse: bool = False, alg: NSType = ns.DEFAULT
+) -> List_int:
+ ...
+
+
+def index_realsorted(
+ seq: Iter_any,
+ key: MaybeKeyType = None,
+ reverse: bool = False,
+ alg: NSType = ns.DEFAULT,
+) -> List_int:
"""
This is a wrapper around ``index_natsorted(seq, alg=ns.REAL)``.
@@ -530,7 +684,9 @@ def index_realsorted(seq, key=None, reverse=False, alg=ns.DEFAULT):
# noinspection PyShadowingBuiltins,PyUnresolvedReferences
-def order_by_index(seq, index, iter=False):
+def order_by_index(
+ seq: Sequence[Any], index: Iterable[int], iter: bool = False
+) -> Iter_any:
"""
Order a given sequence by an index sequence.
@@ -589,7 +745,7 @@ def order_by_index(seq, index, iter=False):
return (seq[i] for i in index) if iter else [seq[i] for i in index]
-def numeric_regex_chooser(alg):
+def numeric_regex_chooser(alg: NSType) -> str:
"""
Select an appropriate regex for the type of number of interest.
@@ -608,7 +764,7 @@ def numeric_regex_chooser(alg):
return utils.regex_chooser(alg).pattern[1:-1]
-def _split_apply(v, key=None):
+def _split_apply(v: Any, key: MaybeKeyType = None) -> Iterator[str]:
if key is not None:
v = key(v)
return utils.path_splitter(str(v))
@@ -617,7 +773,7 @@ def _split_apply(v, key=None):
# Choose the implementation based on the host OS
if platform.system() == "Windows":
- from ctypes import wintypes, windll
+ from ctypes import wintypes, windll # type: ignore
from functools import cmp_to_key
_windows_sort_cmp = windll.Shlwapi.StrCmpLogicalW
@@ -625,8 +781,10 @@ if platform.system() == "Windows":
_windows_sort_cmp.restype = wintypes.INT
_winsort_key = cmp_to_key(_windows_sort_cmp)
- def os_sort_keygen(key=None):
- return lambda x: tuple(map(_winsort_key, _split_apply(x, key)))
+ def os_sort_keygen(key: MaybeKeyType = None) -> OSSortKeyType:
+ return cast(
+ OSSortKeyType, lambda x: tuple(map(_winsort_key, _split_apply(x, key)))
+ )
else:
@@ -645,12 +803,15 @@ else:
except ImportError:
# No ICU installed
- def os_sort_keygen(key=None):
- return natsort_keygen(key=key, alg=ns.LOCALE | ns.PATH | ns.IGNORECASE)
+ def os_sort_keygen(key: MaybeKeyType = None) -> OSSortKeyType:
+ return cast(
+ OSSortKeyType,
+ natsort_keygen(key=key, alg=ns.LOCALE | ns.PATH | ns.IGNORECASE),
+ )
else:
# ICU installed
- def os_sort_keygen(key=None):
+ def os_sort_keygen(key: MaybeKeyType = None) -> OSSortKeyType:
loc = natsort.compat.locale.get_icu_locale()
collator = icu.Collator.createInstance(loc)
collator.setAttribute(
@@ -697,7 +858,19 @@ os_sort_keygen
"""
-def os_sorted(seq, key=None, reverse=False):
+@overload
+def os_sorted(seq: Iter_path, key: None = None, reverse: bool = False) -> List_path:
+ ...
+
+
+@overload
+def os_sorted(seq: Iter_any, key: KeyType, reverse: bool = False) -> List_any:
+ ...
+
+
+def os_sorted(
+ seq: Iter_any, key: MaybeKeyType = None, reverse: bool = False
+) -> List_any:
"""
Sort elements in the same order as your operating system's file browser
diff --git a/natsort/ns_enum.py b/natsort/ns_enum.py
index 283a793..c147909 100644
--- a/natsort/ns_enum.py
+++ b/natsort/ns_enum.py
@@ -4,73 +4,15 @@ This module defines the "ns" enum for natsort is used to determine
what algorithm natsort uses.
"""
-import collections
-
-# The below are the base ns options. The values will be stored as powers
-# of two so bitmasks can be used to extract the user's requested options.
-enum_options = [
- "FLOAT",
- "SIGNED",
- "NOEXP",
- "PATH",
- "LOCALEALPHA",
- "LOCALENUM",
- "IGNORECASE",
- "LOWERCASEFIRST",
- "GROUPLETTERS",
- "UNGROUPLETTERS",
- "NANLAST",
- "COMPATIBILITYNORMALIZE",
- "NUMAFTER",
-]
-
-# Following were previously options but are now defaults.
-enum_do_nothing = ["DEFAULT", "INT", "UNSIGNED"]
-
-# The following are bitwise-OR combinations of other fields.
-enum_combos = [("REAL", ("FLOAT", "SIGNED")), ("LOCALE", ("LOCALEALPHA", "LOCALENUM"))]
-
-# The following are aliases for other fields.
-enum_aliases = [
- ("I", "INT"),
- ("U", "UNSIGNED"),
- ("F", "FLOAT"),
- ("S", "SIGNED"),
- ("R", "REAL"),
- ("N", "NOEXP"),
- ("P", "PATH"),
- ("LA", "LOCALEALPHA"),
- ("LN", "LOCALENUM"),
- ("L", "LOCALE"),
- ("IC", "IGNORECASE"),
- ("LF", "LOWERCASEFIRST"),
- ("G", "GROUPLETTERS"),
- ("UG", "UNGROUPLETTERS"),
- ("C", "UNGROUPLETTERS"),
- ("CAPITALFIRST", "UNGROUPLETTERS"),
- ("NL", "NANLAST"),
- ("CN", "COMPATIBILITYNORMALIZE"),
- ("NA", "NUMAFTER"),
-]
-
-# Construct the list of bitwise distinct enums with their fields.
-enum_fields = collections.OrderedDict(
- (name, 1 << i) for i, name in enumerate(enum_options)
-)
-enum_fields.update((name, 0) for name in enum_do_nothing)
-
-for name, combo in enum_combos:
- combined_value = enum_fields[combo[0]]
- for combo_name in combo[1:]:
- combined_value |= enum_fields[combo_name]
- enum_fields[name] = combined_value
-
-enum_fields.update((alias, enum_fields[name]) for alias, name in enum_aliases)
-
-
-# Subclass the namedtuple to improve the docstring.
-# noinspection PyUnresolvedReferences
-class _NSEnum(collections.namedtuple("_NSEnum", enum_fields.keys())):
+import enum
+import itertools
+import typing
+
+
+_counter = itertools.count(0)
+
+
+class ns(enum.IntEnum): # noqa: N801
"""
Enum to control the `natsort` algorithm.
@@ -186,10 +128,35 @@ class _NSEnum(collections.namedtuple("_NSEnum", enum_fields.keys())):
"""
+ # The below are the base ns options. The values will be stored as powers
+ # of two so bitmasks can be used to extract the user's requested options.
+ FLOAT = F = 1 << next(_counter)
+ SIGNED = S = 1 << next(_counter)
+ NOEXP = N = 1 << next(_counter)
+ PATH = P = 1 << next(_counter)
+ LOCALEALPHA = LA = 1 << next(_counter)
+ LOCALENUM = LN = 1 << next(_counter)
+ IGNORECASE = IC = 1 << next(_counter)
+ LOWERCASEFIRST = LF = 1 << next(_counter)
+ GROUPLETTERS = G = 1 << next(_counter)
+ UNGROUPLETTERS = CAPITALFIRST = C = UG = 1 << next(_counter)
+ NANLAST = NL = 1 << next(_counter)
+ COMPATIBILITYNORMALIZE = CN = 1 << next(_counter)
+ NUMAFTER = NA = 1 << next(_counter)
+
+ # Following were previously options but are now defaults.
+ DEFAULT = 0
+ INT = I = 0 # noqa: E741
+ UNSIGNED = U = 0
+
+ # The following are bitwise-OR combinations of other fields.
+ REAL = R = FLOAT | SIGNED
+ LOCALE = L = LOCALEALPHA | LOCALENUM
-# Here is where the instance of the ns enum that will be exported is created.
-# It is a poor-man's singleton.
-ns = _NSEnum(*enum_fields.values())
# The below is private for internal use only.
NS_DUMB = 1 << 31
+
+# An integer can be used in place of the ns enum so make the
+# type to use for this enum a union of it and an inteter.
+NSType = typing.Union[ns, int]
diff --git a/natsort/py.typed b/natsort/py.typed
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/natsort/py.typed
diff --git a/natsort/utils.py b/natsort/utils.py
index 3f6641b..7102f41 100644
--- a/natsort/utils.py
+++ b/natsort/utils.py
@@ -38,19 +38,93 @@ that ensures "val" is a local variable instead of global variable
and thus has a slightly improved performance at runtime.
"""
-
import re
from functools import partial, reduce
from itertools import chain as ichain
from operator import methodcaller
from pathlib import PurePath
+from typing import (
+ Any,
+ Callable,
+ Dict,
+ Iterable,
+ Iterator,
+ List,
+ Match,
+ Optional,
+ Pattern,
+ Tuple,
+ Union,
+ cast,
+ overload,
+)
from unicodedata import normalize
from natsort.compat.fastnumbers import fast_float, fast_int
-from natsort.compat.locale import get_decimal_point, get_strxfrm, get_thousands_sep
-from natsort.ns_enum import NS_DUMB, ns
+from natsort.compat.locale import (
+ StrOrBytes,
+ get_decimal_point,
+ get_strxfrm,
+ get_thousands_sep,
+)
+from natsort.ns_enum import NSType, NS_DUMB, ns
from natsort.unicode_numbers import digits_no_decimals, numeric_no_decimals
+#
+# Pre-define a slew of aggregate types which makes the type hinting below easier
+#
+StrToStr = Callable[[str], str]
+AnyCall = Callable[[Any], Any]
+
+# For the bytes transform factory
+BytesTuple = Tuple[bytes]
+NestedBytesTuple = Tuple[Tuple[bytes]]
+BytesTransform = Union[BytesTuple, NestedBytesTuple]
+BytesTransformer = Callable[[bytes], BytesTransform]
+
+# For the number transform factory
+NumType = Union[float, int]
+MaybeNumType = Optional[NumType]
+NumTuple = Tuple[StrOrBytes, NumType]
+NestedNumTuple = Tuple[NumTuple]
+StrNumTuple = Tuple[Tuple[str], NumTuple]
+NestedStrNumTuple = Tuple[StrNumTuple]
+MaybeNumTransform = Union[NumTuple, NestedNumTuple, StrNumTuple, NestedStrNumTuple]
+MaybeNumTransformer = Callable[[MaybeNumType], MaybeNumTransform]
+
+# For the string component transform factory
+StrBytesNum = Union[str, bytes, float, int]
+StrTransformer = Callable[[str], StrBytesNum]
+
+# For the final data transform factory
+TwoBlankTuple = Tuple[Tuple[()], Tuple[()]]
+TupleOfAny = Tuple[Any, ...]
+TupleOfStrAnyPair = Tuple[Tuple[str], TupleOfAny]
+FinalTransform = Union[TwoBlankTuple, TupleOfAny, TupleOfStrAnyPair]
+FinalTransformer = Callable[[Iterable[Any], str], FinalTransform]
+
+# For the string parsing factory
+StrSplitter = Callable[[str], Iterable[str]]
+StrParser = Callable[[str], FinalTransform]
+
+# For the path splitter
+PathArg = Union[str, PurePath]
+MatchFn = Callable[[str], Optional[Match]]
+
+# For the path parsing factory
+PathSplitter = Callable[[PathArg], Tuple[FinalTransform, ...]]
+
+# For the natsort key
+StrBytesPathNum = Union[str, bytes, float, int, PurePath]
+NatsortInType = Union[
+ Optional[StrBytesPathNum], Iterable[Union[Optional[StrBytesPathNum], Iterable[Any]]]
+]
+NatsortOutType = Tuple[
+ Union[StrBytesNum, Tuple[Union[StrBytesNum, Tuple[Any, ...]], ...]], ...
+]
+KeyType = Callable[[Any], NatsortInType]
+MaybeKeyType = Optional[KeyType]
+
class NumericalRegularExpressions:
"""
@@ -62,51 +136,51 @@ class NumericalRegularExpressions:
"""
# All unicode numeric characters (minus the decimal characters).
- numeric = numeric_no_decimals
+ numeric: str = numeric_no_decimals
# All unicode digit characters (minus the decimal characters).
- digits = digits_no_decimals
+ digits: str = digits_no_decimals
# Regular expression to match exponential component of a float.
- exp = r"(?:[eE][-+]?\d+)?"
+ exp: str = r"(?:[eE][-+]?\d+)?"
# Regular expression to match a floating point number.
- float_num = r"(?:\d+\.?\d*|\.\d+)"
+ float_num: str = r"(?:\d+\.?\d*|\.\d+)"
@classmethod
- def _construct_regex(cls, fmt):
+ def _construct_regex(cls, fmt: str) -> Pattern[str]:
"""Given a format string, construct the regex with class attributes."""
return re.compile(fmt.format(**vars(cls)), flags=re.U)
@classmethod
- def int_sign(cls):
+ def int_sign(cls) -> Pattern[str]:
"""Regular expression to match a signed int."""
return cls._construct_regex(r"([-+]?\d+|[{digits}])")
@classmethod
- def int_nosign(cls):
+ def int_nosign(cls) -> Pattern[str]:
"""Regular expression to match an unsigned int."""
return cls._construct_regex(r"(\d+|[{digits}])")
@classmethod
- def float_sign_exp(cls):
+ def float_sign_exp(cls) -> Pattern[str]:
"""Regular expression to match a signed float with exponent."""
return cls._construct_regex(r"([-+]?{float_num}{exp}|[{numeric}])")
@classmethod
- def float_nosign_exp(cls):
+ def float_nosign_exp(cls) -> Pattern[str]:
"""Regular expression to match an unsigned float with exponent."""
return cls._construct_regex(r"({float_num}{exp}|[{numeric}])")
@classmethod
- def float_sign_noexp(cls):
+ def float_sign_noexp(cls) -> Pattern[str]:
"""Regular expression to match a signed float without exponent."""
return cls._construct_regex(r"([-+]?{float_num}|[{numeric}])")
@classmethod
- def float_nosign_noexp(cls):
+ def float_nosign_noexp(cls) -> Pattern[str]:
"""Regular expression to match an unsigned float without exponent."""
return cls._construct_regex(r"({float_num}|[{numeric}])")
-def regex_chooser(alg):
+def regex_chooser(alg: NSType) -> Pattern[str]:
"""
Select an appropriate regex for the type of number of interest.
@@ -136,12 +210,12 @@ def regex_chooser(alg):
}[alg]
-def _no_op(x):
+def _no_op(x: Any) -> Any:
"""A function that does nothing and returns the input as-is."""
return x
-def _normalize_input_factory(alg):
+def _normalize_input_factory(alg: NSType) -> StrToStr:
"""
Create a function that will normalize unicode input data.
@@ -161,7 +235,35 @@ def _normalize_input_factory(alg):
return partial(normalize, normalization_form)
-def natsort_key(val, key, string_func, bytes_func, num_func):
+@overload
+def natsort_key(
+ val: NatsortInType,
+ key: None,
+ string_func: Union[StrParser, PathSplitter],
+ bytes_func: BytesTransformer,
+ num_func: MaybeNumTransformer,
+) -> NatsortOutType:
+ ...
+
+
+@overload
+def natsort_key(
+ val: Any,
+ key: KeyType,
+ string_func: Union[StrParser, PathSplitter],
+ bytes_func: BytesTransformer,
+ num_func: MaybeNumTransformer,
+) -> NatsortOutType:
+ ...
+
+
+def natsort_key(
+ val: Union[NatsortInType, Any],
+ key: MaybeKeyType,
+ string_func: Union[StrParser, PathSplitter],
+ bytes_func: BytesTransformer,
+ num_func: MaybeNumTransformer,
+) -> NatsortOutType:
"""
Key to sort strings and numbers naturally.
@@ -170,7 +272,7 @@ def natsort_key(val, key, string_func, bytes_func, num_func):
Parameters
----------
- val : str | unicode | bytes | int | float | iterable
+ val : str | bytes | int | float | iterable
key : callable | None
A key to apply to the *val* before any other operations are performed.
string_func : callable
@@ -210,26 +312,27 @@ def natsort_key(val, key, string_func, bytes_func, num_func):
# Assume the input are strings, which is the most common case
try:
- return string_func(val)
+ return string_func(cast(str, val))
except (TypeError, AttributeError):
# If bytes type, use the bytes_func
if type(val) in (bytes,):
- return bytes_func(val)
+ return bytes_func(cast(bytes, val))
# Otherwise, assume it is an iterable that must be parsed recursively.
# Do not apply the key recursively.
try:
return tuple(
- natsort_key(x, None, string_func, bytes_func, num_func) for x in val
+ natsort_key(x, None, string_func, bytes_func, num_func)
+ for x in cast(Iterable[Any], val)
)
# If that failed, it must be a number.
except TypeError:
- return num_func(val)
+ return num_func(cast(NumType, val))
-def parse_bytes_factory(alg):
+def parse_bytes_factory(alg: NSType) -> BytesTransformer:
"""
Create a function that will format a *bytes* object into a tuple.
@@ -262,7 +365,9 @@ def parse_bytes_factory(alg):
return lambda x: (x,)
-def parse_number_or_none_factory(alg, sep, pre_sep):
+def parse_number_or_none_factory(
+ alg: NSType, sep: StrOrBytes, pre_sep: str
+) -> MaybeNumTransformer:
"""
Create a function that will format a number (or None) into a tuple.
@@ -293,7 +398,9 @@ def parse_number_or_none_factory(alg, sep, pre_sep):
"""
nan_replace = float("+inf") if alg & ns.NANLAST else float("-inf")
- def func(val, _nan_replace=nan_replace, _sep=sep):
+ def func(
+ val: MaybeNumType, _nan_replace: float = nan_replace, _sep: StrOrBytes = sep
+ ) -> NumTuple:
"""Given a number, place it in a tuple with a leading null string."""
return _sep, (_nan_replace if val != val or val is None else val)
@@ -309,8 +416,13 @@ def parse_number_or_none_factory(alg, sep, pre_sep):
def parse_string_factory(
- alg, sep, splitter, input_transform, component_transform, final_transform
-):
+ alg: NSType,
+ sep: StrOrBytes,
+ splitter: StrSplitter,
+ input_transform: StrToStr,
+ component_transform: StrTransformer,
+ final_transform: FinalTransformer,
+) -> StrParser:
"""
Create a function that will split and format a *str* into a tuple.
@@ -361,22 +473,22 @@ def parse_string_factory(
original_func = input_transform if orig_after_xfrm else _no_op
normalize_input = _normalize_input_factory(alg)
- def func(x):
+ def func(x: str) -> FinalTransform:
# Apply string input transformation function and return to x.
# Original function is usually a no-op, but some algorithms require it
# to also be the transformation function.
- x = normalize_input(x)
- x, original = input_transform(x), original_func(x)
- x = splitter(x) # Split string into components.
- x = filter(None, x) # Remove empty strings.
- x = map(component_transform, x) # Apply transform on components.
- x = sep_inserter(x, sep) # Insert '' between numbers.
- return final_transform(x, original) # Apply the final transform.
+ a = normalize_input(x)
+ b, original = input_transform(a), original_func(a)
+ c = splitter(b) # Split string into components.
+ d = filter(None, c) # Remove empty strings.
+ e = map(component_transform, d) # Apply transform on components.
+ f = sep_inserter(e, sep) # Insert '' between numbers.
+ return final_transform(f, original) # Apply the final transform.
return func
-def parse_path_factory(str_split):
+def parse_path_factory(str_split: StrParser) -> PathSplitter:
"""
Create a function that will properly split and format a path.
@@ -403,41 +515,41 @@ def parse_path_factory(str_split):
return lambda x: tuple(map(str_split, path_splitter(x)))
-def sep_inserter(iterable, sep):
+def sep_inserter(iterator: Iterator[Any], sep: StrOrBytes) -> Iterator[Any]:
"""
- Insert '' between numbers in an iterable.
+ Insert '' between numbers in an iterator.
Parameters
----------
- iterable
+ iterator
sep : str
The string character to be inserted between adjacent numeric objects.
Yields
------
- The values of *iterable* in order, with *sep* inserted where adjacent
+ The values of *iterator* in order, with *sep* inserted where adjacent
elements are numeric. If the first element in the input is numeric
then *sep* will be the first value yielded.
"""
try:
- # Get the first element. A StopIteration indicates an empty iterable.
+ # Get the first element. A StopIteration indicates an empty iterator.
# Since we are controlling the types of the input, 'type' is used
# instead of 'isinstance' for the small speed advantage it offers.
types = (int, float)
- first = next(iterable)
+ first = next(iterator)
if type(first) in types:
yield sep
yield first
# Now, check if pair of elements are both numbers. If so, add ''.
- second = next(iterable)
+ second = next(iterator)
if type(first) in types and type(second) in types:
yield sep
yield second
# Now repeat in a loop.
- for x in iterable:
+ for x in iterator:
first, second = second, x
if type(first) in types and type(second) in types:
yield sep
@@ -448,7 +560,7 @@ def sep_inserter(iterable, sep):
return
-def input_string_transform_factory(alg):
+def input_string_transform_factory(alg: NSType) -> StrToStr:
"""
Create a function to transform a string.
@@ -473,7 +585,7 @@ def input_string_transform_factory(alg):
dumb = alg & NS_DUMB
# Build the chain of functions to execute in order.
- function_chain = []
+ function_chain: List[StrToStr] = []
if (dumb and not lowfirst) or (lowfirst and not dumb):
function_chain.append(methodcaller("swapcase"))
@@ -502,8 +614,8 @@ def input_string_transform_factory(alg):
strip_thousands = strip_thousands.format(
thou=re.escape(get_thousands_sep()), nodecimal=nodecimal
)
- strip_thousands = re.compile(strip_thousands, flags=re.VERBOSE)
- function_chain.append(partial(strip_thousands.sub, ""))
+ strip_thousands_re = re.compile(strip_thousands, flags=re.VERBOSE)
+ function_chain.append(partial(strip_thousands_re.sub, ""))
# Create a regular expression that will change the decimal point to
# a period if not already a period.
@@ -511,14 +623,14 @@ def input_string_transform_factory(alg):
if alg & ns.FLOAT and decimal != ".":
switch_decimal = r"(?<=[0-9]){decimal}|{decimal}(?=[0-9])"
switch_decimal = switch_decimal.format(decimal=re.escape(decimal))
- switch_decimal = re.compile(switch_decimal)
- function_chain.append(partial(switch_decimal.sub, "."))
+ switch_decimal_re = re.compile(switch_decimal)
+ function_chain.append(partial(switch_decimal_re.sub, "."))
# Return the chained functions.
return chain_functions(function_chain)
-def string_component_transform_factory(alg):
+def string_component_transform_factory(alg: NSType) -> StrTransformer:
"""
Create a function to either transform a string or convert to a number.
@@ -545,23 +657,26 @@ def string_component_transform_factory(alg):
nan_val = float("+inf") if alg & ns.NANLAST else float("-inf")
# Build the chain of functions to execute in order.
- func_chain = []
+ func_chain: List[Callable[[str], StrOrBytes]] = []
if group_letters:
func_chain.append(groupletters)
if use_locale:
func_chain.append(get_strxfrm())
- kwargs = {"key": chain_functions(func_chain)} if func_chain else {}
# Return the correct chained functions.
+ kwargs: Dict[str, Union[float, Callable[[str], StrOrBytes]]]
+ kwargs = {"key": chain_functions(func_chain)} if func_chain else {}
if alg & ns.FLOAT:
# noinspection PyTypeChecker
kwargs["nan"] = nan_val
- return partial(fast_float, **kwargs)
+ return cast(Callable[[str], StrOrBytes], partial(fast_float, **kwargs))
else:
- return partial(fast_int, **kwargs)
+ return cast(Callable[[str], StrOrBytes], partial(fast_int, **kwargs))
-def final_data_transform_factory(alg, sep, pre_sep):
+def final_data_transform_factory(
+ alg: NSType, sep: StrOrBytes, pre_sep: str
+) -> FinalTransformer:
"""
Create a function to transform a tuple.
@@ -589,9 +704,15 @@ def final_data_transform_factory(alg, sep, pre_sep):
"""
if alg & ns.UNGROUPLETTERS and alg & ns.LOCALEALPHA:
swap = alg & NS_DUMB and alg & ns.LOWERCASEFIRST
- transform = methodcaller("swapcase") if swap else _no_op
-
- def func(split_val, val, _transform=transform, _sep=sep, _pre_sep=pre_sep):
+ transform = cast(StrToStr, methodcaller("swapcase")) if swap else _no_op
+
+ def func(
+ split_val: Iterable[NatsortInType],
+ val: str,
+ _transform: StrToStr = transform,
+ _sep: StrOrBytes = sep,
+ _pre_sep: str = pre_sep,
+ ) -> FinalTransform:
"""
Return a tuple with the first character of the first element
of the return value as the first element, and the return value
@@ -606,16 +727,25 @@ def final_data_transform_factory(alg, sep, pre_sep):
else:
return (_transform(val[0]),), split_val
- return func
else:
- return lambda split_val, val: tuple(split_val)
+ def func(
+ split_val: Iterable[NatsortInType],
+ val: str,
+ _transform: StrToStr = _no_op,
+ _sep: StrOrBytes = sep,
+ _pre_sep: str = pre_sep,
+ ) -> FinalTransform:
+ return tuple(split_val)
+
+ return func
-lower_function = methodcaller("casefold")
+
+lower_function: StrToStr = cast(StrToStr, methodcaller("casefold"))
# noinspection PyIncorrectDocstring
-def groupletters(x, _low=lower_function):
+def groupletters(x: str, _low: StrToStr = lower_function) -> str:
"""
Double all characters, making doubled letters lowercase.
@@ -637,7 +767,7 @@ def groupletters(x, _low=lower_function):
return "".join(ichain.from_iterable((_low(y), y) for y in x))
-def chain_functions(functions):
+def chain_functions(functions: Iterable[AnyCall]) -> AnyCall:
"""
Chain a list of single-argument functions together and return.
@@ -674,7 +804,17 @@ def chain_functions(functions):
return partial(reduce, lambda res, f: f(res), functions)
-def do_decoding(s, encoding):
+@overload
+def do_decoding(s: bytes, encoding: str) -> str:
+ ...
+
+
+@overload
+def do_decoding(s: NatsortInType, encoding: str) -> NatsortInType:
+ ...
+
+
+def do_decoding(s: NatsortInType, encoding: str) -> NatsortInType:
"""
Helper to decode a *bytes* object, or return the object as-is.
@@ -692,13 +832,15 @@ def do_decoding(s, encoding):
"""
try:
- return s.decode(encoding)
+ return cast(bytes, s).decode(encoding)
except (AttributeError, TypeError):
return s
# noinspection PyIncorrectDocstring
-def path_splitter(s, _d_match=re.compile(r"\.\d").match):
+def path_splitter(
+ s: PathArg, _d_match: MatchFn = re.compile(r"\.\d").match
+) -> Iterator[str]:
"""
Split a string into its path components.
diff --git a/setup.cfg b/setup.cfg
index a045246..cf651bf 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -63,3 +63,9 @@ exclude =
dist,
docs,
.venv
+
+[mypy]
+
+[mypy-icu]
+ignore_missing_imports = True
+
diff --git a/setup.py b/setup.py
index ed1362a..2d27ae0 100644
--- a/setup.py
+++ b/setup.py
@@ -7,6 +7,8 @@ setup(
version="7.1.1",
packages=find_packages(),
entry_points={"console_scripts": ["natsort = natsort.__main__:main"]},
- python_requires=">=3.4",
+ python_requires=">=3.6",
extras_require={"fast": ["fastnumbers >= 2.0.0"], "icu": ["PyICU >= 1.0.0"]},
+ package_data={"": ["py.typed"]},
+ zip_safe=False,
)
diff --git a/tests/conftest.py b/tests/conftest.py
index 74d7f4f..c63e149 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -3,6 +3,7 @@ Fixtures for pytest.
"""
import locale
+from typing import Iterator
import hypothesis
import pytest
@@ -16,7 +17,7 @@ hypothesis.settings.register_profile(
)
-def load_locale(x):
+def load_locale(x: str) -> None:
"""Convenience to load a locale, trying ISO8859-1 first."""
try:
locale.setlocale(locale.LC_ALL, str("{}.ISO8859-1".format(x)))
@@ -25,15 +26,16 @@ def load_locale(x):
@pytest.fixture()
-def with_locale_en_us():
+def with_locale_en_us() -> Iterator[None]:
"""Convenience to load the en_US locale - reset when complete."""
orig = locale.getlocale()
- yield load_locale("en_US")
+ load_locale("en_US")
+ yield
locale.setlocale(locale.LC_ALL, orig)
@pytest.fixture()
-def with_locale_de_de():
+def with_locale_de_de() -> Iterator[None]:
"""
Convenience to load the de_DE locale - reset when complete - skip if missing.
"""
diff --git a/tests/profile_natsorted.py b/tests/profile_natsorted.py
index a2b1a5c..f6580a3 100644
--- a/tests/profile_natsorted.py
+++ b/tests/profile_natsorted.py
@@ -7,6 +7,7 @@ inputs and different settings.
import cProfile
import locale
import sys
+from typing import List, Union
try:
from natsort import ns, natsort_keygen
@@ -14,6 +15,8 @@ except ImportError:
sys.path.insert(0, ".")
from natsort import ns, natsort_keygen
+from natsort.natsort import NatsortKeyType
+
locale.setlocale(locale.LC_ALL, "en_US.UTF-8")
# Samples to parse
@@ -32,7 +35,7 @@ path_key = natsort_keygen(alg=ns.PATH)
locale_key = natsort_keygen(alg=ns.LOCALE)
-def prof_time_to_generate():
+def prof_time_to_generate() -> None:
print("*** Generate Plain Key ***")
for _ in range(100000):
natsort_keygen()
@@ -41,7 +44,9 @@ def prof_time_to_generate():
cProfile.run("prof_time_to_generate()", sort="time")
-def prof_parsing(a, msg, key=basic_key):
+def prof_parsing(
+ a: Union[str, int, bytes, List[str]], msg: str, key: NatsortKeyType = basic_key
+) -> None:
print(msg)
for _ in range(100000):
key(a)
diff --git a/tests/test_fake_fastnumbers.py b/tests/test_fake_fastnumbers.py
index c75bb11..574f7cf 100644
--- a/tests/test_fake_fastnumbers.py
+++ b/tests/test_fake_fastnumbers.py
@@ -5,13 +5,14 @@ Test the fake fastnumbers module.
import unicodedata
from math import isnan
+from typing import Union, cast
from hypothesis import given
from hypothesis.strategies import floats, integers, text
from natsort.compat.fake_fastnumbers import fast_float, fast_int
-def is_float(x):
+def is_float(x: str) -> bool:
try:
float(x)
except ValueError:
@@ -25,19 +26,19 @@ def is_float(x):
return True
-def not_a_float(x):
+def not_a_float(x: str) -> bool:
return not is_float(x)
-def is_int(x):
+def is_int(x: Union[str, float]) -> bool:
try:
- return x.is_integer()
+ return cast(float, x).is_integer()
except AttributeError:
try:
int(x)
except ValueError:
try:
- unicodedata.digit(x)
+ unicodedata.digit(cast(str, x))
except (ValueError, TypeError):
return False
else:
@@ -46,7 +47,7 @@ def is_int(x):
return True
-def not_an_int(x):
+def not_an_int(x: Union[str, float]) -> bool:
return not is_int(x)
@@ -54,56 +55,56 @@ def not_an_int(x):
# and a test that uses the hypothesis module.
-def test_fast_float_returns_nan_alternate_if_nan_option_is_given():
+def test_fast_float_returns_nan_alternate_if_nan_option_is_given() -> None:
assert fast_float("nan", nan=7) == 7
-def test_fast_float_converts_float_string_to_float_example():
+def test_fast_float_converts_float_string_to_float_example() -> None:
assert fast_float("45.8") == 45.8
assert fast_float("-45") == -45.0
assert fast_float("45.8e-2", key=len) == 45.8e-2
- assert isnan(fast_float("nan"))
- assert isnan(fast_float("+nan"))
- assert isnan(fast_float("-NaN"))
+ assert isnan(cast(float, fast_float("nan")))
+ assert isnan(cast(float, fast_float("+nan")))
+ assert isnan(cast(float, fast_float("-NaN")))
assert fast_float("۱۲.۱۲") == 12.12
assert fast_float("-۱۲.۱۲") == -12.12
@given(floats(allow_nan=False))
-def test_fast_float_converts_float_string_to_float(x):
+def test_fast_float_converts_float_string_to_float(x: float) -> None:
assert fast_float(repr(x)) == x
-def test_fast_float_leaves_string_as_is_example():
+def test_fast_float_leaves_string_as_is_example() -> None:
assert fast_float("invalid") == "invalid"
@given(text().filter(not_a_float).filter(bool))
-def test_fast_float_leaves_string_as_is(x):
+def test_fast_float_leaves_string_as_is(x: str) -> None:
assert fast_float(x) == x
-def test_fast_float_with_key_applies_to_string_example():
+def test_fast_float_with_key_applies_to_string_example() -> None:
assert fast_float("invalid", key=len) == len("invalid")
@given(text().filter(not_a_float).filter(bool))
-def test_fast_float_with_key_applies_to_string(x):
+def test_fast_float_with_key_applies_to_string(x: str) -> None:
assert fast_float(x, key=len) == len(x)
-def test_fast_int_leaves_float_string_as_is_example():
+def test_fast_int_leaves_float_string_as_is_example() -> None:
assert fast_int("45.8") == "45.8"
assert fast_int("nan") == "nan"
assert fast_int("inf") == "inf"
@given(floats().filter(not_an_int))
-def test_fast_int_leaves_float_string_as_is(x):
+def test_fast_int_leaves_float_string_as_is(x: float) -> None:
assert fast_int(repr(x)) == repr(x)
-def test_fast_int_converts_int_string_to_int_example():
+def test_fast_int_converts_int_string_to_int_example() -> None:
assert fast_int("-45") == -45
assert fast_int("+45") == 45
assert fast_int("۱۲") == 12
@@ -111,23 +112,23 @@ def test_fast_int_converts_int_string_to_int_example():
@given(integers())
-def test_fast_int_converts_int_string_to_int(x):
+def test_fast_int_converts_int_string_to_int(x: int) -> None:
assert fast_int(repr(x)) == x
-def test_fast_int_leaves_string_as_is_example():
+def test_fast_int_leaves_string_as_is_example() -> None:
assert fast_int("invalid") == "invalid"
@given(text().filter(not_an_int).filter(bool))
-def test_fast_int_leaves_string_as_is(x):
+def test_fast_int_leaves_string_as_is(x: str) -> None:
assert fast_int(x) == x
-def test_fast_int_with_key_applies_to_string_example():
+def test_fast_int_with_key_applies_to_string_example() -> None:
assert fast_int("invalid", key=len) == len("invalid")
@given(text().filter(not_an_int).filter(bool))
-def test_fast_int_with_key_applies_to_string(x):
+def test_fast_int_with_key_applies_to_string(x: str) -> None:
assert fast_int(x, key=len) == len(x)
diff --git a/tests/test_final_data_transform_factory.py b/tests/test_final_data_transform_factory.py
index f6bf636..36607b6 100644
--- a/tests/test_final_data_transform_factory.py
+++ b/tests/test_final_data_transform_factory.py
@@ -1,17 +1,20 @@
# -*- coding: utf-8 -*-
"""These test the utils.py functions."""
+from typing import Callable, Union
import pytest
from hypothesis import example, given
from hypothesis.strategies import floats, integers, text
-from natsort.ns_enum import NS_DUMB, ns
+from natsort.ns_enum import NSType, NS_DUMB, ns
from natsort.utils import final_data_transform_factory
@pytest.mark.parametrize("alg", [ns.DEFAULT, ns.UNGROUPLETTERS, ns.LOCALE])
@given(x=text(), y=floats(allow_nan=False, allow_infinity=False) | integers())
@pytest.mark.usefixtures("with_locale_en_us")
-def test_final_data_transform_factory_default(x, y, alg):
+def test_final_data_transform_factory_default(
+ x: str, y: Union[int, float], alg: NSType
+) -> None:
final_data_transform_func = final_data_transform_factory(alg, "", "::")
value = (x, y)
original_value = "".join(map(str, value))
@@ -34,7 +37,9 @@ def test_final_data_transform_factory_default(x, y, alg):
@given(x=text(), y=floats(allow_nan=False, allow_infinity=False) | integers())
@example(x="İ", y=0)
@pytest.mark.usefixtures("with_locale_en_us")
-def test_final_data_transform_factory_ungroup_and_locale(x, y, alg, func):
+def test_final_data_transform_factory_ungroup_and_locale(
+ x: str, y: Union[int, float], alg: NSType, func: Callable[[str], str]
+) -> None:
final_data_transform_func = final_data_transform_factory(alg, "", "::")
value = (x, y)
original_value = "".join(map(str, value))
@@ -46,6 +51,6 @@ def test_final_data_transform_factory_ungroup_and_locale(x, y, alg, func):
assert result == expected
-def test_final_data_transform_factory_ungroup_and_locale_empty_tuple():
+def test_final_data_transform_factory_ungroup_and_locale_empty_tuple() -> None:
final_data_transform_func = final_data_transform_factory(ns.UG | ns.L, "", "::")
assert final_data_transform_func((), "") == ((), ())
diff --git a/tests/test_input_string_transform_factory.py b/tests/test_input_string_transform_factory.py
index 7d54afd..6a08318 100644
--- a/tests/test_input_string_transform_factory.py
+++ b/tests/test_input_string_transform_factory.py
@@ -1,14 +1,15 @@
# -*- coding: utf-8 -*-
"""These test the utils.py functions."""
+from typing import Callable
import pytest
from hypothesis import example, given
from hypothesis.strategies import integers, text
-from natsort.ns_enum import NS_DUMB, ns
+from natsort.ns_enum import NSType, NS_DUMB, ns
from natsort.utils import input_string_transform_factory
-def thousands_separated_int(n):
+def thousands_separated_int(n: str) -> str:
"""Insert thousands separators in an int."""
new_int = ""
for i, y in enumerate(reversed(n), 1):
@@ -20,7 +21,7 @@ def thousands_separated_int(n):
@given(text())
-def test_input_string_transform_factory_is_no_op_for_no_alg_options(x):
+def test_input_string_transform_factory_is_no_op_for_no_alg_options(x: str) -> None:
input_string_transform_func = input_string_transform_factory(ns.DEFAULT)
assert input_string_transform_func(x) is x
@@ -36,7 +37,9 @@ def test_input_string_transform_factory_is_no_op_for_no_alg_options(x):
],
)
@given(x=text())
-def test_input_string_transform_factory(x, alg, example_func):
+def test_input_string_transform_factory(
+ x: str, alg: NSType, example_func: Callable[[str], str]
+) -> None:
input_string_transform_func = input_string_transform_factory(alg)
assert input_string_transform_func(x) == example_func(x)
@@ -44,7 +47,7 @@ def test_input_string_transform_factory(x, alg, example_func):
@example(12543642642534980) # 12,543,642,642,534,980 => 12543642642534980
@given(x=integers(min_value=1000))
@pytest.mark.usefixtures("with_locale_en_us")
-def test_input_string_transform_factory_cleans_thousands(x):
+def test_input_string_transform_factory_cleans_thousands(x: int) -> None:
int_str = str(x).rstrip("lL")
thousands_int_str = thousands_separated_int(int_str)
assert thousands_int_str.replace(",", "") != thousands_int_str
@@ -69,7 +72,9 @@ def test_input_string_transform_factory_cleans_thousands(x):
],
)
@pytest.mark.usefixtures("with_locale_en_us")
-def test_input_string_transform_factory_handles_us_locale(x, expected):
+def test_input_string_transform_factory_handles_us_locale(
+ x: str, expected: str
+) -> None:
input_string_transform_func = input_string_transform_factory(ns.LOCALE)
assert input_string_transform_func(x) == expected
@@ -83,7 +88,9 @@ def test_input_string_transform_factory_handles_us_locale(x, expected):
],
)
@pytest.mark.usefixtures("with_locale_de_de")
-def test_input_string_transform_factory_handles_de_locale(x, expected):
+def test_input_string_transform_factory_handles_de_locale(
+ x: str, expected: str
+) -> None:
input_string_transform_func = input_string_transform_factory(ns.LOCALE)
assert input_string_transform_func(x) == expected
@@ -97,13 +104,15 @@ def test_input_string_transform_factory_handles_de_locale(x, expected):
],
)
@pytest.mark.usefixtures("with_locale_de_de")
-def test_input_string_transform_factory_handles_german_locale(alg, expected):
+def test_input_string_transform_factory_handles_german_locale(
+ alg: NSType, expected: str
+) -> None:
input_string_transform_func = input_string_transform_factory(alg)
assert input_string_transform_func("1543,753") == expected
@pytest.mark.usefixtures("with_locale_de_de")
-def test_input_string_transform_factory_does_nothing_with_non_num_input():
+def test_input_string_transform_factory_does_nothing_with_non_num_input() -> None:
input_string_transform_func = input_string_transform_factory(ns.LOCALE | ns.FLOAT)
expected = "154s,t53"
assert input_string_transform_func("154s,t53") == expected
diff --git a/tests/test_main.py b/tests/test_main.py
index da91fdd..2f784a1 100644
--- a/tests/test_main.py
+++ b/tests/test_main.py
@@ -5,11 +5,13 @@ Test the natsort command-line tool functions.
import re
import sys
+from typing import Any, List, Union
import pytest
from hypothesis import given
-from hypothesis.strategies import data, floats, integers, lists
+from hypothesis.strategies import DataObject, data, floats, integers, lists
from natsort.__main__ import (
+ TypedArgs,
check_filters,
keep_entry_range,
keep_entry_value,
@@ -17,16 +19,19 @@ from natsort.__main__ import (
range_check,
sort_and_print_entries,
)
+from pytest_mock import MockerFixture
-def test_main_passes_default_arguments_with_no_command_line_options(mocker):
+def test_main_passes_default_arguments_with_no_command_line_options(
+ mocker: MockerFixture,
+) -> None:
p = mocker.patch("natsort.__main__.sort_and_print_entries")
main("num-2", "num-6", "num-1")
args = p.call_args[0][1]
assert not args.paths
assert args.filter is None
assert args.reverse_filter is None
- assert args.exclude is None
+ assert args.exclude == []
assert not args.reverse
assert args.number_type == "int"
assert not args.signed
@@ -34,7 +39,9 @@ def test_main_passes_default_arguments_with_no_command_line_options(mocker):
assert not args.locale
-def test_main_passes_arguments_with_all_command_line_options(mocker):
+def test_main_passes_arguments_with_all_command_line_options(
+ mocker: MockerFixture,
+) -> None:
arguments = ["--paths", "--reverse", "--locale"]
arguments.extend(["--filter", "4", "10"])
arguments.extend(["--reverse-filter", "100", "110"])
@@ -57,21 +64,6 @@ def test_main_passes_arguments_with_all_command_line_options(mocker):
assert args.locale
-class Args:
- """A dummy class to simulate the argparse Namespace object"""
-
- def __init__(self, filt, reverse_filter, exclude, as_path, reverse):
- self.filter = filt
- self.reverse_filter = reverse_filter
- self.exclude = exclude
- self.reverse = reverse
- self.number_type = "float"
- self.signed = True
- self.exp = True
- self.paths = as_path
- self.locale = 0
-
-
mock_print = "__builtin__.print" if sys.version[0] == "2" else "builtins.print"
entries = [
@@ -135,9 +127,11 @@ entries = [
([None, None, False, True, True], reversed([2, 3, 1, 0, 5, 6, 4])),
],
)
-def test_sort_and_print_entries(options, order, mocker):
+def test_sort_and_print_entries(
+ options: List[Any], order: List[int], mocker: MockerFixture
+) -> None:
p = mocker.patch(mock_print)
- sort_and_print_entries(entries, Args(*options))
+ sort_and_print_entries(entries, TypedArgs(*options))
e = [mocker.call(entries[i]) for i in order]
p.assert_has_calls(e)
@@ -146,13 +140,15 @@ def test_sort_and_print_entries(options, order, mocker):
# and a test that uses the hypothesis module.
-def test_range_check_returns_range_as_is_but_with_floats_example():
+def test_range_check_returns_range_as_is_but_with_floats_example() -> None:
assert range_check(10, 11) == (10.0, 11.0)
assert range_check(6.4, 30) == (6.4, 30.0)
@given(x=floats(allow_nan=False, min_value=-1e8, max_value=1e8) | integers(), d=data())
-def test_range_check_returns_range_as_is_if_first_is_less_than_second(x, d):
+def test_range_check_returns_range_as_is_if_first_is_less_than_second(
+ x: Union[int, float], d: DataObject
+) -> None:
# Pull data such that the first is less than the second.
if isinstance(x, float):
y = d.draw(floats(min_value=x + 1.0, max_value=1e9, allow_nan=False))
@@ -161,44 +157,48 @@ def test_range_check_returns_range_as_is_if_first_is_less_than_second(x, d):
assert range_check(x, y) == (x, y)
-def test_range_check_raises_value_error_if_second_is_less_than_first_example():
+def test_range_check_raises_value_error_if_second_is_less_than_first_example() -> None:
with pytest.raises(ValueError, match="low >= high"):
range_check(7, 2)
@given(x=floats(allow_nan=False), d=data())
-def test_range_check_raises_value_error_if_second_is_less_than_first(x, d):
+def test_range_check_raises_value_error_if_second_is_less_than_first(
+ x: float, d: DataObject
+) -> None:
# Pull data such that the first is greater than or equal to the second.
y = d.draw(floats(max_value=x, allow_nan=False))
with pytest.raises(ValueError, match="low >= high"):
range_check(x, y)
-def test_check_filters_returns_none_if_filter_evaluates_to_false():
+def test_check_filters_returns_none_if_filter_evaluates_to_false() -> None:
assert check_filters(()) is None
- assert check_filters(False) is None
- assert check_filters(None) is None
-def test_check_filters_returns_input_as_is_if_filter_is_valid_example():
+def test_check_filters_returns_input_as_is_if_filter_is_valid_example() -> None:
assert check_filters([(6, 7)]) == [(6, 7)]
assert check_filters([(6, 7), (2, 8)]) == [(6, 7), (2, 8)]
@given(x=lists(integers(), min_size=1), d=data())
-def test_check_filters_returns_input_as_is_if_filter_is_valid(x, d):
+def test_check_filters_returns_input_as_is_if_filter_is_valid(
+ x: List[int], d: DataObject
+) -> None:
# ensure y is element-wise greater than x
y = [d.draw(integers(min_value=val + 1)) for val in x]
assert check_filters(list(zip(x, y))) == [(i, j) for i, j in zip(x, y)]
-def test_check_filters_raises_value_error_if_filter_is_invalid_example():
+def test_check_filters_raises_value_error_if_filter_is_invalid_example() -> None:
with pytest.raises(ValueError, match="Error in --filter: low >= high"):
check_filters([(7, 2)])
@given(x=lists(integers(), min_size=1), d=data())
-def test_check_filters_raises_value_error_if_filter_is_invalid(x, d):
+def test_check_filters_raises_value_error_if_filter_is_invalid(
+ x: List[int], d: DataObject
+) -> None:
# ensure y is element-wise less than or equal to x
y = [d.draw(integers(max_value=val)) for val in x]
with pytest.raises(ValueError, match="Error in --filter: low >= high"):
@@ -212,11 +212,11 @@ def test_check_filters_raises_value_error_if_filter_is_invalid(x, d):
# 3. No portion is between the bounds => False.
[([0], [100], True), ([1, 88], [20, 90], True), ([1], [20], False)],
)
-def test_keep_entry_range(lows, highs, truth):
+def test_keep_entry_range(lows: List[int], highs: List[int], truth: bool) -> None:
assert keep_entry_range("a56b23c89", lows, highs, int, re.compile(r"\d+")) is truth
# 1. Values not in entry => True. 2. Values in entry => False.
@pytest.mark.parametrize("values, truth", [([100, 45], True), ([23], False)])
-def test_keep_entry_value(values, truth):
+def test_keep_entry_value(values: List[int], truth: bool) -> None:
assert keep_entry_value("a56b23c89", values, int, re.compile(r"\d+")) is truth
diff --git a/tests/test_natsort_key.py b/tests/test_natsort_key.py
index 6b56bb1..cdfdc67 100644
--- a/tests/test_natsort_key.py
+++ b/tests/test_natsort_key.py
@@ -1,42 +1,45 @@
# -*- coding: utf-8 -*-
"""These test the utils.py functions."""
+from typing import Any, List, NoReturn, Tuple, Union, cast
from hypothesis import given
from hypothesis.strategies import binary, floats, integers, lists, text
from natsort.utils import natsort_key
-def str_func(x):
+def str_func(x: Any) -> Tuple[str]:
if isinstance(x, str):
- return x
+ return (x,)
else:
raise TypeError("Not a str!")
-def fail(_):
+def fail(_: Any) -> NoReturn:
raise AssertionError("This should never be reached!")
@given(floats(allow_nan=False) | integers())
-def test_natsort_key_with_numeric_input_takes_number_path(x):
- assert natsort_key(x, None, str_func, fail, lambda y: y) is x
+def test_natsort_key_with_numeric_input_takes_number_path(x: Union[float, int]) -> None:
+ assert natsort_key(x, None, str_func, fail, lambda y: ("", y))[1] is x
@given(binary().filter(bool))
-def test_natsort_key_with_bytes_input_takes_bytes_path(x):
- assert natsort_key(x, None, str_func, lambda y: y, fail) is x
+def test_natsort_key_with_bytes_input_takes_bytes_path(x: bytes) -> None:
+ assert natsort_key(x, None, str_func, lambda y: (y,), fail)[0] is x
@given(text())
-def test_natsort_key_with_text_input_takes_string_path(x):
- assert natsort_key(x, None, str_func, fail, fail) is x
+def test_natsort_key_with_text_input_takes_string_path(x: str) -> None:
+ assert natsort_key(x, None, str_func, fail, fail)[0] is x
@given(lists(elements=text(), min_size=1, max_size=10))
-def test_natsort_key_with_nested_input_takes_nested_path(x):
- assert natsort_key(x, None, str_func, fail, fail) == tuple(x)
+def test_natsort_key_with_nested_input_takes_nested_path(x: List[str]) -> None:
+ assert natsort_key(x, None, str_func, fail, fail) == tuple((y,) for y in x)
@given(text())
-def test_natsort_key_with_key_argument_applies_key_before_processing(x):
- assert natsort_key(x, len, str_func, fail, lambda y: y) == len(x)
+def test_natsort_key_with_key_argument_applies_key_before_processing(x: str) -> None:
+ assert natsort_key(x, len, str_func, fail, lambda y: ("", cast(int, y)))[1] == len(
+ x
+ )
diff --git a/tests/test_natsort_keygen.py b/tests/test_natsort_keygen.py
index 5fb8cf3..d4da2e3 100644
--- a/tests/test_natsort_keygen.py
+++ b/tests/test_natsort_keygen.py
@@ -5,23 +5,27 @@ See the README or the natsort homepage for more details.
"""
import os
+from typing import List, Tuple, Union
import pytest
from natsort import natsort_key, natsort_keygen, natsorted, ns
from natsort.compat.locale import get_strxfrm, null_string_locale
+from natsort.ns_enum import NSType
+from natsort.utils import BytesTransform, FinalTransform
+from pytest_mock import MockerFixture
@pytest.fixture
-def arbitrary_input():
+def arbitrary_input() -> List[Union[str, float]]:
return ["6A-5.034e+1", "/Folder (1)/Foo", 56.7]
@pytest.fixture
-def bytes_input():
+def bytes_input() -> bytes:
return b"6A-5.034e+1"
-def test_natsort_keygen_demonstration():
+def test_natsort_keygen_demonstration() -> None:
original_list = ["a50", "a51.", "a50.31", "a50.4", "a5.034e1", "a50.300"]
copy_of_list = original_list[:]
original_list.sort(key=natsort_keygen(alg=ns.F))
@@ -29,21 +33,23 @@ def test_natsort_keygen_demonstration():
assert original_list == natsorted(copy_of_list, alg=ns.F)
-def test_natsort_key_public():
+def test_natsort_key_public() -> None:
assert natsort_key("a-5.034e2") == ("a-", 5, ".", 34, "e", 2)
-def test_natsort_keygen_with_invalid_alg_input_raises_value_error():
+def test_natsort_keygen_with_invalid_alg_input_raises_value_error() -> None:
# Invalid arguments give the correct response
with pytest.raises(ValueError, match="'alg' argument"):
- natsort_keygen(None, "1")
+ natsort_keygen(None, "1") # type: ignore
@pytest.mark.parametrize(
"alg, expected",
[(ns.DEFAULT, ("a-", 5, ".", 34, "e", 1)), (ns.FLOAT | ns.SIGNED, ("a", -50.34))],
)
-def test_natsort_keygen_returns_natsort_key_that_parses_input(alg, expected):
+def test_natsort_keygen_returns_natsort_key_that_parses_input(
+ alg: NSType, expected: Tuple[Union[str, int, float], ...]
+) -> None:
ns_key = natsort_keygen(alg=alg)
assert ns_key("a-5.034e1") == expected
@@ -78,7 +84,9 @@ def test_natsort_keygen_returns_natsort_key_that_parses_input(alg, expected):
),
],
)
-def test_natsort_keygen_handles_arbitrary_input(arbitrary_input, alg, expected):
+def test_natsort_keygen_handles_arbitrary_input(
+ arbitrary_input: List[Union[str, float]], alg: NSType, expected: FinalTransform
+) -> None:
ns_key = natsort_keygen(alg=alg)
assert ns_key(arbitrary_input) == expected
@@ -93,7 +101,9 @@ def test_natsort_keygen_handles_arbitrary_input(arbitrary_input, alg, expected):
(ns.PATH | ns.GROUPLETTERS, ((b"6A-5.034e+1",),)),
],
)
-def test_natsort_keygen_handles_bytes_input(bytes_input, alg, expected):
+def test_natsort_keygen_handles_bytes_input(
+ bytes_input: bytes, alg: NSType, expected: BytesTransform
+) -> None:
ns_key = natsort_keygen(alg=alg)
assert ns_key(bytes_input) == expected
@@ -131,23 +141,29 @@ def test_natsort_keygen_handles_bytes_input(bytes_input, alg, expected):
],
)
@pytest.mark.usefixtures("with_locale_en_us")
-def test_natsort_keygen_with_locale(mocker, arbitrary_input, alg, expected, is_dumb):
+def test_natsort_keygen_with_locale(
+ mocker: MockerFixture,
+ arbitrary_input: List[Union[str, float]],
+ alg: NSType,
+ expected: FinalTransform,
+ is_dumb: bool,
+) -> None:
# First, apply the correct strxfrm function to the string values.
strxfrm = get_strxfrm()
- expected = [list(sub) for sub in expected]
+ expected_tmp = [list(sub) for sub in expected]
try:
for i in (2, 4, 6):
- expected[0][i] = strxfrm(expected[0][i])
+ expected_tmp[0][i] = strxfrm(expected_tmp[0][i])
for i in (0, 2):
- expected[1][i] = strxfrm(expected[1][i])
- expected = tuple(tuple(sub) for sub in expected)
+ expected_tmp[1][i] = strxfrm(expected_tmp[1][i])
+ expected = tuple(tuple(sub) for sub in expected_tmp)
except IndexError: # ns.LOCALE | ns.CAPITALFIRST
- expected = [[list(subsub) for subsub in sub] for sub in expected]
+ expected_tmp = [[list(subsub) for subsub in sub] for sub in expected_tmp]
for i in (2, 4, 6):
- expected[0][1][i] = strxfrm(expected[0][1][i])
+ expected_tmp[0][1][i] = strxfrm(expected_tmp[0][1][i])
for i in (0, 2):
- expected[1][1][i] = strxfrm(expected[1][1][i])
- expected = tuple(tuple(tuple(subsub) for subsub in sub) for sub in expected)
+ expected_tmp[1][1][i] = strxfrm(expected_tmp[1][1][i])
+ expected = tuple(tuple(tuple(subsub) for subsub in sub) for sub in expected_tmp)
mocker.patch("natsort.compat.locale.dumb_sort", return_value=is_dumb)
ns_key = natsort_keygen(alg=alg)
@@ -159,7 +175,9 @@ def test_natsort_keygen_with_locale(mocker, arbitrary_input, alg, expected, is_d
[(ns.LOCALE, False), (ns.LOCALE, True), (ns.LOCALE | ns.CAPITALFIRST, False)],
)
@pytest.mark.usefixtures("with_locale_en_us")
-def test_natsort_keygen_with_locale_bytes(mocker, bytes_input, alg, is_dumb):
+def test_natsort_keygen_with_locale_bytes(
+ mocker: MockerFixture, bytes_input: bytes, alg: NSType, is_dumb: bool
+) -> None:
expected = (b"6A-5.034e+1",)
mocker.patch("natsort.compat.locale.dumb_sort", return_value=is_dumb)
ns_key = natsort_keygen(alg=alg)
diff --git a/tests/test_natsorted.py b/tests/test_natsorted.py
index 4254e6c..d043ab4 100644
--- a/tests/test_natsorted.py
+++ b/tests/test_natsorted.py
@@ -5,34 +5,38 @@ See the README or the natsort homepage for more details.
"""
from operator import itemgetter
+from typing import List, Tuple, Union
import pytest
from natsort import as_utf8, natsorted, ns
+from natsort.ns_enum import NSType
from pytest import raises
@pytest.fixture
-def float_list():
+def float_list() -> List[str]:
return ["a50", "a51.", "a50.31", "a-50", "a50.4", "a5.034e1", "a50.300"]
@pytest.fixture
-def fruit_list():
+def fruit_list() -> List[str]:
return ["Apple", "corn", "Corn", "Banana", "apple", "banana"]
@pytest.fixture
-def mixed_list():
+def mixed_list() -> List[Union[str, int, float]]:
return ["Ä", "0", "ä", 3, "b", 1.5, "2", "Z"]
-def test_natsorted_numbers_in_ascending_order():
+def test_natsorted_numbers_in_ascending_order() -> None:
given = ["a2", "a5", "a9", "a1", "a4", "a10", "a6"]
expected = ["a1", "a2", "a4", "a5", "a6", "a9", "a10"]
assert natsorted(given) == expected
-def test_natsorted_can_sort_as_signed_floats_with_exponents(float_list):
+def test_natsorted_can_sort_as_signed_floats_with_exponents(
+ float_list: List[str],
+) -> None:
expected = ["a-50", "a50", "a50.300", "a50.31", "a5.034e1", "a50.4", "a51."]
assert natsorted(float_list, alg=ns.REAL) == expected
@@ -42,19 +46,23 @@ def test_natsorted_can_sort_as_signed_floats_with_exponents(float_list):
"alg",
[ns.NOEXP | ns.FLOAT | ns.UNSIGNED, ns.NOEXP | ns.FLOAT],
)
-def test_natsorted_can_sort_as_unsigned_and_ignore_exponents(float_list, alg):
+def test_natsorted_can_sort_as_unsigned_and_ignore_exponents(
+ float_list: List[str], alg: NSType
+) -> None:
expected = ["a5.034e1", "a50", "a50.300", "a50.31", "a50.4", "a51.", "a-50"]
assert natsorted(float_list, alg=alg) == expected
# DEFAULT and INT are all equivalent.
@pytest.mark.parametrize("alg", [ns.DEFAULT, ns.INT])
-def test_natsorted_can_sort_as_unsigned_ints_which_is_default(float_list, alg):
+def test_natsorted_can_sort_as_unsigned_ints_which_is_default(
+ float_list: List[str], alg: NSType
+) -> None:
expected = ["a5.034e1", "a50", "a50.4", "a50.31", "a50.300", "a51.", "a-50"]
assert natsorted(float_list, alg=alg) == expected
-def test_natsorted_can_sort_as_signed_ints(float_list):
+def test_natsorted_can_sort_as_signed_ints(float_list: List[str]) -> None:
expected = ["a-50", "a5.034e1", "a50", "a50.4", "a50.31", "a50.300", "a51."]
assert natsorted(float_list, alg=ns.SIGNED) == expected
@@ -63,12 +71,14 @@ def test_natsorted_can_sort_as_signed_ints(float_list):
"alg, expected",
[(ns.UNSIGNED, ["a7", "a+2", "a-5"]), (ns.SIGNED, ["a-5", "a+2", "a7"])],
)
-def test_natsorted_can_sort_with_or_without_accounting_for_sign(alg, expected):
+def test_natsorted_can_sort_with_or_without_accounting_for_sign(
+ alg: NSType, expected: List[str]
+) -> None:
given = ["a-5", "a7", "a+2"]
assert natsorted(given, alg=alg) == expected
-def test_natsorted_can_sort_as_version_numbers():
+def test_natsorted_can_sort_as_version_numbers() -> None:
given = ["1.9.9a", "1.11", "1.9.9b", "1.11.4", "1.10.1"]
expected = ["1.9.9a", "1.9.9b", "1.10.1", "1.11", "1.11.4"]
assert natsorted(given) == expected
@@ -81,7 +91,11 @@ def test_natsorted_can_sort_as_version_numbers():
(ns.NUMAFTER, ["Ä", "Z", "ä", "b", "0", 1.5, "2", 3]),
],
)
-def test_natsorted_handles_mixed_types(mixed_list, alg, expected):
+def test_natsorted_handles_mixed_types(
+ mixed_list: List[Union[str, int, float]],
+ alg: NSType,
+ expected: List[Union[str, int, float]],
+) -> None:
assert natsorted(mixed_list, alg=alg) == expected
@@ -92,14 +106,16 @@ def test_natsorted_handles_mixed_types(mixed_list, alg, expected):
(ns.NANLAST, [5, "25", 1e40, float("nan")], slice(None, 3)),
],
)
-def test_natsorted_handles_nan(alg, expected, slc):
- given = ["25", 5, float("nan"), 1e40]
+def test_natsorted_handles_nan(
+ alg: NSType, expected: List[Union[str, float, int]], slc: slice
+) -> None:
+ given: List[Union[str, float, int]] = ["25", 5, float("nan"), 1e40]
# The slice is because NaN != NaN
# noinspection PyUnresolvedReferences
assert natsorted(given, alg=alg)[slc] == expected[slc]
-def test_natsorted_with_mixed_bytes_and_str_input_raises_type_error():
+def test_natsorted_with_mixed_bytes_and_str_input_raises_type_error() -> None:
with raises(TypeError, match="bytes"):
natsorted(["ä", b"b"])
@@ -107,29 +123,31 @@ def test_natsorted_with_mixed_bytes_and_str_input_raises_type_error():
assert natsorted(["ä", b"b"], key=as_utf8) == ["ä", b"b"]
-def test_natsorted_raises_type_error_for_non_iterable_input():
+def test_natsorted_raises_type_error_for_non_iterable_input() -> None:
with raises(TypeError, match="'int' object is not iterable"):
- natsorted(100)
+ natsorted(100) # type: ignore
-def test_natsorted_recurses_into_nested_lists():
+def test_natsorted_recurses_into_nested_lists() -> None:
given = [["a1", "a5"], ["a1", "a40"], ["a10", "a1"], ["a2", "a5"]]
expected = [["a1", "a5"], ["a1", "a40"], ["a2", "a5"], ["a10", "a1"]]
assert natsorted(given) == expected
-def test_natsorted_applies_key_to_each_list_element_before_sorting_list():
+def test_natsorted_applies_key_to_each_list_element_before_sorting_list() -> None:
given = [("a", "num3"), ("b", "num5"), ("c", "num2")]
expected = [("c", "num2"), ("a", "num3"), ("b", "num5")]
assert natsorted(given, key=itemgetter(1)) == expected
-def test_natsorted_returns_list_in_reversed_order_with_reverse_option(float_list):
+def test_natsorted_returns_list_in_reversed_order_with_reverse_option(
+ float_list: List[str],
+) -> None:
expected = natsorted(float_list)[::-1]
assert natsorted(float_list, reverse=True) == expected
-def test_natsorted_handles_filesystem_paths():
+def test_natsorted_handles_filesystem_paths() -> None:
given = [
"/p/Folder (10)/file.tar.gz",
"/p/Folder (1)/file (1).tar.gz",
@@ -157,10 +175,10 @@ def test_natsorted_handles_filesystem_paths():
assert natsorted(given, alg=ns.FLOAT | ns.PATH) == expected_correct
-def test_natsorted_handles_numbers_and_filesystem_paths_simultaneously():
+def test_natsorted_handles_numbers_and_filesystem_paths_simultaneously() -> None:
# You can sort paths and numbers, not that you'd want to
- given = ["/Folder (9)/file.exe", 43]
- expected = [43, "/Folder (9)/file.exe"]
+ given: List[Union[str, int]] = ["/Folder (9)/file.exe", 43]
+ expected: List[Union[str, int]] = [43, "/Folder (9)/file.exe"]
assert natsorted(given, alg=ns.PATH) == expected
@@ -174,7 +192,9 @@ def test_natsorted_handles_numbers_and_filesystem_paths_simultaneously():
(ns.G | ns.LF, ["apple", "Apple", "banana", "Banana", "corn", "Corn"]),
],
)
-def test_natsorted_supports_case_handling(alg, expected, fruit_list):
+def test_natsorted_supports_case_handling(
+ alg: NSType, expected: List[str], fruit_list: List[str]
+) -> None:
assert natsorted(fruit_list, alg=alg) == expected
@@ -186,7 +206,9 @@ def test_natsorted_supports_case_handling(alg, expected, fruit_list):
(ns.IGNORECASE, [("a3", "a1"), ("A5", "a6")]),
],
)
-def test_natsorted_supports_nested_case_handling(alg, expected):
+def test_natsorted_supports_nested_case_handling(
+ alg: NSType, expected: List[Tuple[str, str]]
+) -> None:
given = [("A5", "a6"), ("a3", "a1")]
assert natsorted(given, alg=alg) == expected
@@ -201,26 +223,28 @@ def test_natsorted_supports_nested_case_handling(alg, expected):
],
)
@pytest.mark.usefixtures("with_locale_en_us")
-def test_natsorted_can_sort_using_locale(fruit_list, alg, expected):
+def test_natsorted_can_sort_using_locale(
+ fruit_list: List[str], alg: NSType, expected: List[str]
+) -> None:
assert natsorted(fruit_list, alg=ns.LOCALE | alg) == expected
@pytest.mark.usefixtures("with_locale_en_us")
-def test_natsorted_can_sort_locale_specific_numbers_en():
+def test_natsorted_can_sort_locale_specific_numbers_en() -> None:
given = ["c", "a5,467.86", "ä", "b", "a5367.86", "a5,6", "a5,50"]
expected = ["a5,6", "a5,50", "a5367.86", "a5,467.86", "ä", "b", "c"]
assert natsorted(given, alg=ns.LOCALE | ns.F) == expected
@pytest.mark.usefixtures("with_locale_de_de")
-def test_natsorted_can_sort_locale_specific_numbers_de():
+def test_natsorted_can_sort_locale_specific_numbers_de() -> None:
given = ["c", "a5.467,86", "ä", "b", "a5367.86", "a5,6", "a5,50"]
expected = ["a5,50", "a5,6", "a5367.86", "a5.467,86", "ä", "b", "c"]
assert natsorted(given, alg=ns.LOCALE | ns.F) == expected
@pytest.mark.usefixtures("with_locale_de_de")
-def test_natsorted_locale_bug_regression_test_109():
+def test_natsorted_locale_bug_regression_test_109() -> None:
# https://github.com/SethMMorton/natsort/issues/109
given = ["462166", "461761"]
expected = ["461761", "462166"]
@@ -242,7 +266,11 @@ def test_natsorted_locale_bug_regression_test_109():
],
)
@pytest.mark.usefixtures("with_locale_en_us")
-def test_natsorted_handles_mixed_types_with_locale(mixed_list, alg, expected):
+def test_natsorted_handles_mixed_types_with_locale(
+ mixed_list: List[Union[str, int, float]],
+ alg: NSType,
+ expected: List[Union[str, int, float]],
+) -> None:
assert natsorted(mixed_list, alg=ns.LOCALE | alg) == expected
@@ -253,12 +281,14 @@ def test_natsorted_handles_mixed_types_with_locale(mixed_list, alg, expected):
(ns.NUMAFTER, ["Banana", "apple", "corn", "~~~~~~", "73", "5039"]),
],
)
-def test_natsorted_sorts_an_odd_collection_of_strings(alg, expected):
+def test_natsorted_sorts_an_odd_collection_of_strings(
+ alg: NSType, expected: List[str]
+) -> None:
given = ["apple", "Banana", "73", "5039", "corn", "~~~~~~"]
assert natsorted(given, alg=alg) == expected
-def test_natsorted_sorts_mixed_ascii_and_non_ascii_numbers():
+def test_natsorted_sorts_mixed_ascii_and_non_ascii_numbers() -> None:
given = [
"1st street",
"10th street",
diff --git a/tests/test_natsorted_convenience.py b/tests/test_natsorted_convenience.py
index cdc2c50..0b2cd75 100644
--- a/tests/test_natsorted_convenience.py
+++ b/tests/test_natsorted_convenience.py
@@ -5,6 +5,7 @@ See the README or the natsort homepage for more details.
"""
from operator import itemgetter
+from typing import List
import pytest
from natsort import (
@@ -23,21 +24,21 @@ from natsort import (
@pytest.fixture
-def version_list():
+def version_list() -> List[str]:
return ["1.9.9a", "1.11", "1.9.9b", "1.11.4", "1.10.1"]
@pytest.fixture
-def float_list():
+def float_list() -> List[str]:
return ["a50", "a51.", "a50.31", "a-50", "a50.4", "a5.034e1", "a50.300"]
@pytest.fixture
-def fruit_list():
+def fruit_list() -> List[str]:
return ["Apple", "corn", "Corn", "Banana", "apple", "banana"]
-def test_decoder_returns_function_that_can_decode_bytes_but_return_non_bytes_as_is():
+def test_decoder_returns_function_that_decodes_bytes_but_returns_other_as_is() -> None:
func = decoder("latin1")
str_obj = "bytes"
int_obj = 14
@@ -46,24 +47,28 @@ def test_decoder_returns_function_that_can_decode_bytes_but_return_non_bytes_as_
assert func(str_obj) is str_obj # same object returned b/c only bytes has decode
-def test_as_ascii_converts_bytes_to_ascii():
+def test_as_ascii_converts_bytes_to_ascii() -> None:
assert decoder("ascii")(b"bytes") == as_ascii(b"bytes")
-def test_as_utf8_converts_bytes_to_utf8():
+def test_as_utf8_converts_bytes_to_utf8() -> None:
assert decoder("utf8")(b"bytes") == as_utf8(b"bytes")
-def test_realsorted_is_identical_to_natsorted_with_real_alg(float_list):
+def test_realsorted_is_identical_to_natsorted_with_real_alg(
+ float_list: List[str],
+) -> None:
assert realsorted(float_list) == natsorted(float_list, alg=ns.REAL)
@pytest.mark.usefixtures("with_locale_en_us")
-def test_humansorted_is_identical_to_natsorted_with_locale_alg(fruit_list):
+def test_humansorted_is_identical_to_natsorted_with_locale_alg(
+ fruit_list: List[str],
+) -> None:
assert humansorted(fruit_list) == natsorted(fruit_list, alg=ns.LOCALE)
-def test_index_natsorted_returns_integer_list_of_sort_order_for_input_list():
+def test_index_natsorted_returns_integer_list_of_sort_order_for_input_list() -> None:
given = ["num3", "num5", "num2"]
other = ["foo", "bar", "baz"]
index = index_natsorted(given)
@@ -72,27 +77,31 @@ def test_index_natsorted_returns_integer_list_of_sort_order_for_input_list():
assert [other[i] for i in index] == ["baz", "foo", "bar"]
-def test_index_natsorted_reverse():
+def test_index_natsorted_reverse() -> None:
given = ["num3", "num5", "num2"]
assert index_natsorted(given, reverse=True) == index_natsorted(given)[::-1]
-def test_index_natsorted_applies_key_function_before_sorting():
+def test_index_natsorted_applies_key_function_before_sorting() -> None:
given = [("a", "num3"), ("b", "num5"), ("c", "num2")]
expected = [2, 0, 1]
assert index_natsorted(given, key=itemgetter(1)) == expected
-def test_index_realsorted_is_identical_to_index_natsorted_with_real_alg(float_list):
+def test_index_realsorted_is_identical_to_index_natsorted_with_real_alg(
+ float_list: List[str],
+) -> None:
assert index_realsorted(float_list) == index_natsorted(float_list, alg=ns.REAL)
@pytest.mark.usefixtures("with_locale_en_us")
-def test_index_humansorted_is_identical_to_index_natsorted_with_locale_alg(fruit_list):
+def test_index_humansorted_is_identical_to_index_natsorted_with_locale_alg(
+ fruit_list: List[str],
+) -> None:
assert index_humansorted(fruit_list) == index_natsorted(fruit_list, alg=ns.LOCALE)
-def test_order_by_index_sorts_list_according_to_order_of_integer_list():
+def test_order_by_index_sorts_list_according_to_order_of_integer_list() -> None:
given = ["num3", "num5", "num2"]
index = [2, 0, 1]
expected = [given[i] for i in index]
@@ -100,7 +109,7 @@ def test_order_by_index_sorts_list_according_to_order_of_integer_list():
assert order_by_index(given, index) == expected
-def test_order_by_index_returns_generator_with_iter_true():
+def test_order_by_index_returns_generator_with_iter_true() -> None:
given = ["num3", "num5", "num2"]
index = [2, 0, 1]
assert order_by_index(given, index, True) != [given[i] for i in index]
diff --git a/tests/test_ns_enum.py b/tests/test_ns_enum.py
index 1d3803b..7a30718 100644
--- a/tests/test_ns_enum.py
+++ b/tests/test_ns_enum.py
@@ -1,8 +1,10 @@
+import pytest
from natsort import ns
-def test_ns_enum():
- enum_name_values = [
+@pytest.mark.parametrize(
+ "given, expected",
+ [
("FLOAT", 0x0001),
("SIGNED", 0x0002),
("NOEXP", 0x0004),
@@ -40,5 +42,7 @@ def test_ns_enum():
("NL", 0x0400),
("CN", 0x0800),
("NA", 0x1000),
- ]
- assert list(ns._asdict().items()) == enum_name_values
+ ],
+)
+def test_ns_enum(given: str, expected: int) -> None:
+ assert ns[given] == expected
diff --git a/tests/test_os_sorted.py b/tests/test_os_sorted.py
index afb15cf..d0ecc79 100644
--- a/tests/test_os_sorted.py
+++ b/tests/test_os_sorted.py
@@ -3,6 +3,7 @@
Testing for the OS sorting
"""
import platform
+from typing import cast
import natsort
import pytest
@@ -15,7 +16,7 @@ else:
has_icu = True
-def test_os_sorted_compound():
+def test_os_sorted_compound() -> None:
given = [
"/p/Folder (10)/file.tar.gz",
"/p/Folder (1)/file (1).tar.gz",
@@ -36,14 +37,14 @@ def test_os_sorted_compound():
assert result == expected
-def test_os_sorted_misc_no_fail():
+def test_os_sorted_misc_no_fail() -> None:
natsort.os_sorted([9, 4.3, None, float("nan")])
-def test_os_sorted_key():
+def test_os_sorted_key() -> None:
given = ["foo0", "foo2", "goo1"]
expected = ["foo0", "goo1", "foo2"]
- result = natsort.os_sorted(given, key=lambda x: x.replace("g", "f"))
+ result = natsort.os_sorted(given, key=lambda x: cast(str, x).replace("g", "f"))
assert result == expected
@@ -199,7 +200,7 @@ else:
@pytest.mark.usefixtures("with_locale_en_us")
-def test_os_sorted_corpus():
+def test_os_sorted_corpus() -> None:
result = natsort.os_sorted(given)
print(result)
assert result == expected
diff --git a/tests/test_parse_bytes_function.py b/tests/test_parse_bytes_function.py
index 6637cbd..318c4aa 100644
--- a/tests/test_parse_bytes_function.py
+++ b/tests/test_parse_bytes_function.py
@@ -4,8 +4,8 @@
import pytest
from hypothesis import given
from hypothesis.strategies import binary
-from natsort.ns_enum import ns
-from natsort.utils import parse_bytes_factory
+from natsort.ns_enum import NSType, ns
+from natsort.utils import BytesTransformer, parse_bytes_factory
@pytest.mark.parametrize(
@@ -19,6 +19,8 @@ from natsort.utils import parse_bytes_factory
],
)
@given(x=binary())
-def test_parse_bytest_factory_makes_function_that_returns_tuple(x, alg, example_func):
+def test_parse_bytest_factory_makes_function_that_returns_tuple(
+ x: bytes, alg: NSType, example_func: BytesTransformer
+) -> None:
parse_bytes_func = parse_bytes_factory(alg)
assert parse_bytes_func(x) == example_func(x)
diff --git a/tests/test_parse_number_function.py b/tests/test_parse_number_function.py
index 29ee5a3..e5f417d 100644
--- a/tests/test_parse_number_function.py
+++ b/tests/test_parse_number_function.py
@@ -1,11 +1,13 @@
# -*- coding: utf-8 -*-
"""These test the utils.py functions."""
+from typing import Optional, Tuple, Union
+
import pytest
from hypothesis import given
from hypothesis.strategies import floats, integers
-from natsort.ns_enum import ns
-from natsort.utils import parse_number_or_none_factory
+from natsort.ns_enum import NSType, ns
+from natsort.utils import MaybeNumTransformer, parse_number_or_none_factory
@pytest.mark.usefixtures("with_locale_en_us")
@@ -19,7 +21,9 @@ from natsort.utils import parse_number_or_none_factory
],
)
@given(x=floats(allow_nan=False) | integers())
-def test_parse_number_factory_makes_function_that_returns_tuple(x, alg, example_func):
+def test_parse_number_factory_makes_function_that_returns_tuple(
+ x: Union[float, int], alg: NSType, example_func: MaybeNumTransformer
+) -> None:
parse_number_func = parse_number_or_none_factory(alg, "", "xx")
assert parse_number_func(x) == example_func(x)
@@ -34,6 +38,8 @@ def test_parse_number_factory_makes_function_that_returns_tuple(x, alg, example_
(ns.NANLAST, None, ("", float("+inf"))), # NANLAST makes it +infinity
],
)
-def test_parse_number_factory_treats_nan_and_none_special(alg, x, result):
+def test_parse_number_factory_treats_nan_and_none_special(
+ alg: NSType, x: Optional[Union[float, int]], result: Tuple[str, Union[float, int]]
+) -> None:
parse_number_func = parse_number_or_none_factory(alg, "", "xx")
assert parse_number_func(x) == result
diff --git a/tests/test_parse_string_function.py b/tests/test_parse_string_function.py
index 46347f1..653a065 100644
--- a/tests/test_parse_string_function.py
+++ b/tests/test_parse_string_function.py
@@ -2,23 +2,28 @@
"""These test the utils.py functions."""
import unicodedata
+from typing import Any, Callable, Iterable, List, Tuple, Union
import pytest
from hypothesis import given
from hypothesis.strategies import floats, integers, lists, text
from natsort.compat.fastnumbers import fast_float
-from natsort.ns_enum import NS_DUMB, ns
-from natsort.utils import NumericalRegularExpressions as NumRegex
+from natsort.ns_enum import NSType, NS_DUMB, ns
+from natsort.utils import (
+ FinalTransform,
+ NumericalRegularExpressions as NumRegex,
+ StrParser,
+)
from natsort.utils import parse_string_factory
-class CustomTuple(tuple):
+class CustomTuple(Tuple[Any, ...]):
"""Used to ensure what is given during testing is what is returned."""
- original = None
+ original: Any = None
-def input_transform(x):
+def input_transform(x: Any) -> Any:
"""Make uppercase."""
try:
return x.upper()
@@ -26,14 +31,14 @@ def input_transform(x):
return x
-def final_transform(x, original):
+def final_transform(x: Iterable[Any], original: str) -> FinalTransform:
"""Make the input a CustomTuple."""
t = CustomTuple(x)
t.original = original
return t
-def parse_string_func_factory(alg):
+def parse_string_func_factory(alg: NSType) -> StrParser:
"""A parse_string_factory result with sample arguments."""
sep = ""
return parse_string_factory(
@@ -47,10 +52,12 @@ def parse_string_func_factory(alg):
@given(x=floats() | integers())
-def test_parse_string_factory_raises_type_error_if_given_number(x):
+def test_parse_string_factory_raises_type_error_if_given_number(
+ x: Union[int, float]
+) -> None:
parse_string_func = parse_string_func_factory(ns.DEFAULT)
with pytest.raises(TypeError):
- assert parse_string_func(x)
+ assert parse_string_func(x) # type: ignore
# noinspection PyCallingNonCallable
@@ -68,7 +75,9 @@ def test_parse_string_factory_raises_type_error_if_given_number(x):
)
)
@pytest.mark.usefixtures("with_locale_en_us")
-def test_parse_string_factory_invariance(x, alg, orig_func):
+def test_parse_string_factory_invariance(
+ x: List[Union[float, str, int]], alg: NSType, orig_func: Callable[[str], str]
+) -> None:
parse_string_func = parse_string_func_factory(alg)
# parse_string_factory is the high-level combination of several dedicated
# functions involved in splitting and manipulating a string. The details of
diff --git a/tests/test_regex.py b/tests/test_regex.py
index f647f5f..08314b5 100644
--- a/tests/test_regex.py
+++ b/tests/test_regex.py
@@ -1,8 +1,11 @@
# -*- coding: utf-8 -*-
"""These test the splitting regular expressions."""
+from typing import List, Pattern
+
import pytest
from natsort import ns, numeric_regex_chooser
+from natsort.ns_enum import NSType
from natsort.utils import NumericalRegularExpressions as NumRegex
@@ -95,7 +98,9 @@ labels = ["{}-{}".format(given, regex_names[regex]) for given, _, regex in regex
@pytest.mark.parametrize("x, expected, regex", regex_params, ids=labels)
-def test_regex_splits_correctly(x, expected, regex):
+def test_regex_splits_correctly(
+ x: str, expected: List[str], regex: Pattern[str]
+) -> None:
# noinspection PyUnresolvedReferences
assert regex.split(x) == expected
@@ -115,5 +120,5 @@ def test_regex_splits_correctly(x, expected, regex):
(ns.FLOAT | ns.UNSIGNED | ns.NOEXP, NumRegex.float_nosign_noexp()),
],
)
-def test_regex_chooser(given, expected):
+def test_regex_chooser(given: NSType, expected: Pattern[str]) -> None:
assert numeric_regex_chooser(given) == expected.pattern[1:-1] # remove parens
diff --git a/tests/test_string_component_transform_factory.py b/tests/test_string_component_transform_factory.py
index 8b77d38..99df7ea 100644
--- a/tests/test_string_component_transform_factory.py
+++ b/tests/test_string_component_transform_factory.py
@@ -2,13 +2,14 @@
"""These test the utils.py functions."""
from functools import partial
+from typing import Any, Callable, FrozenSet, Union
import pytest
from hypothesis import example, given
from hypothesis.strategies import floats, integers, text
from natsort.compat.fastnumbers import fast_float, fast_int
from natsort.compat.locale import get_strxfrm
-from natsort.ns_enum import NS_DUMB, ns
+from natsort.ns_enum import NSType, NS_DUMB, ns
from natsort.utils import groupletters, string_component_transform_factory
# There are some unicode values that are known failures with the builtin locale
@@ -21,12 +22,12 @@ except ValueError:
bad_uni_chars = frozenset()
-def no_bad_uni_chars(x, _bad_chars=bad_uni_chars):
+def no_bad_uni_chars(x: str, _bad_chars: FrozenSet[str] = bad_uni_chars) -> bool:
"""Ensure text does not contain bad unicode characters"""
return not any(y in _bad_chars for y in x)
-def no_null(x):
+def no_null(x: str) -> bool:
"""Ensure text does not contain a null character."""
return "\0" not in x
@@ -65,7 +66,9 @@ def no_null(x):
| text().filter(bool).filter(no_bad_uni_chars).filter(no_null)
)
@pytest.mark.usefixtures("with_locale_en_us")
-def test_string_component_transform_factory(x, alg, example_func):
+def test_string_component_transform_factory(
+ x: Union[str, float, int], alg: NSType, example_func: Callable[[str], Any]
+) -> None:
string_component_transform_func = string_component_transform_factory(alg)
try:
assert string_component_transform_func(str(x)) == example_func(str(x))
diff --git a/tests/test_unicode_numbers.py b/tests/test_unicode_numbers.py
index be0f4ee..eb71125 100644
--- a/tests/test_unicode_numbers.py
+++ b/tests/test_unicode_numbers.py
@@ -14,27 +14,27 @@ from natsort.unicode_numbers import (
digits_no_decimals,
numeric,
numeric_chars,
- numeric_hex,
numeric_no_decimals,
)
+from natsort.unicode_numeric_hex import numeric_hex
-def test_numeric_chars_contains_only_valid_unicode_numeric_characters():
+def test_numeric_chars_contains_only_valid_unicode_numeric_characters() -> None:
for a in numeric_chars:
assert unicodedata.numeric(a, None) is not None
-def test_digit_chars_contains_only_valid_unicode_digit_characters():
+def test_digit_chars_contains_only_valid_unicode_digit_characters() -> None:
for a in digit_chars:
assert unicodedata.digit(a, None) is not None
-def test_decimal_chars_contains_only_valid_unicode_decimal_characters():
+def test_decimal_chars_contains_only_valid_unicode_decimal_characters() -> None:
for a in decimal_chars:
assert unicodedata.decimal(a, None) is not None
-def test_numeric_chars_contains_all_valid_unicode_numeric_and_digit_characters():
+def test_numeric_chars_contains_all_valid_unicode_numeric_and_digit_characters() -> None:
set_numeric_chars = set(numeric_chars)
set_digit_chars = set(digit_chars)
set_decimal_chars = set(decimal_chars)
@@ -46,7 +46,7 @@ def test_numeric_chars_contains_all_valid_unicode_numeric_and_digit_characters()
assert set_numeric_chars.issuperset(numeric_no_decimals)
-def test_missing_unicode_number_in_collection():
+def test_missing_unicode_number_in_collection() -> None:
ok = True
set_numeric_hex = set(numeric_hex)
for i in range(0x110000):
@@ -71,7 +71,7 @@ repository (https://github.com/SethMMorton/natsort) with the resulting change.
)
-def test_combined_string_contains_all_characters_in_list():
+def test_combined_string_contains_all_characters_in_list() -> None:
assert numeric == "".join(numeric_chars)
assert digits == "".join(digit_chars)
assert decimals == "".join(decimal_chars)
diff --git a/tests/test_utils.py b/tests/test_utils.py
index d559803..38df303 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -6,15 +6,16 @@ import pathlib
import string
from itertools import chain
from operator import neg as op_neg
+from typing import List, Pattern, Union
import pytest
from hypothesis import given
from hypothesis.strategies import integers, lists, sampled_from, text
from natsort import utils
-from natsort.ns_enum import ns
+from natsort.ns_enum import NSType, ns
-def test_do_decoding_decodes_bytes_string_to_unicode():
+def test_do_decoding_decodes_bytes_string_to_unicode() -> None:
assert type(utils.do_decoding(b"bytes", "ascii")) is str
assert utils.do_decoding(b"bytes", "ascii") == "bytes"
assert utils.do_decoding(b"bytes", "ascii") == b"bytes".decode("ascii")
@@ -33,7 +34,9 @@ def test_do_decoding_decodes_bytes_string_to_unicode():
(ns.F | ns.S | ns.N, utils.NumericalRegularExpressions.float_sign_noexp()),
],
)
-def test_regex_chooser_returns_correct_regular_expression_object(alg, expected):
+def test_regex_chooser_returns_correct_regular_expression_object(
+ alg: NSType, expected: Pattern[str]
+) -> None:
assert utils.regex_chooser(alg).pattern == expected.pattern
@@ -68,21 +71,21 @@ def test_regex_chooser_returns_correct_regular_expression_object(alg, expected):
(ns.REAL, ns.FLOAT | ns.SIGNED),
],
)
-def test_ns_enum_values_and_aliases(alg, value_or_alias):
+def test_ns_enum_values_and_aliases(alg: NSType, value_or_alias: NSType) -> None:
assert alg == value_or_alias
-def test_chain_functions_is_a_no_op_if_no_functions_are_given():
+def test_chain_functions_is_a_no_op_if_no_functions_are_given() -> None:
x = 2345
assert utils.chain_functions([])(x) is x
-def test_chain_functions_does_one_function_if_one_function_is_given():
+def test_chain_functions_does_one_function_if_one_function_is_given() -> None:
x = "2345"
assert utils.chain_functions([len])(x) == 4
-def test_chain_functions_combines_functions_in_given_order():
+def test_chain_functions_combines_functions_in_given_order() -> None:
x = 2345
assert utils.chain_functions([str, len, op_neg])(x) == -len(str(x))
@@ -91,33 +94,37 @@ def test_chain_functions_combines_functions_in_given_order():
# and a test that uses the hypothesis module.
-def test_groupletters_returns_letters_with_lowercase_transform_of_letter_example():
+def test_groupletters_gives_letters_with_lowercase_letter_transform_example() -> None:
assert utils.groupletters("HELLO") == "hHeElLlLoO"
assert utils.groupletters("hello") == "hheelllloo"
@given(text().filter(bool))
-def test_groupletters_returns_letters_with_lowercase_transform_of_letter(x):
+def test_groupletters_gives_letters_with_lowercase_letter_transform(
+ x: str,
+) -> None:
assert utils.groupletters(x) == "".join(
chain.from_iterable([y.casefold(), y] for y in x)
)
-def test_sep_inserter_does_nothing_if_no_numbers_example():
+def test_sep_inserter_does_nothing_if_no_numbers_example() -> None:
assert list(utils.sep_inserter(iter(["a", "b", "c"]), "")) == ["a", "b", "c"]
assert list(utils.sep_inserter(iter(["a"]), "")) == ["a"]
-def test_sep_inserter_does_nothing_if_only_one_number_example():
+def test_sep_inserter_does_nothing_if_only_one_number_example() -> None:
assert list(utils.sep_inserter(iter(["a", 5]), "")) == ["a", 5]
-def test_sep_inserter_inserts_separator_string_between_two_numbers_example():
+def test_sep_inserter_inserts_separator_string_between_two_numbers_example() -> None:
assert list(utils.sep_inserter(iter([5, 9]), "")) == ["", 5, "", 9]
@given(lists(elements=text().filter(bool) | integers(), min_size=3))
-def test_sep_inserter_inserts_separator_between_two_numbers(x):
+def test_sep_inserter_inserts_separator_between_two_numbers(
+ x: List[Union[str, int]]
+) -> None:
# Rather than just replicating the results in a different algorithm,
# validate that the "shape" of the output is as expected.
result = list(utils.sep_inserter(iter(x), ""))
@@ -127,28 +134,29 @@ def test_sep_inserter_inserts_separator_between_two_numbers(x):
assert isinstance(result[i + 1], int)
-def test_path_splitter_splits_path_string_by_separator_example():
+def test_path_splitter_splits_path_string_by_sep_example() -> None:
given = "/this/is/a/path"
expected = (os.sep, "this", "is", "a", "path")
assert tuple(utils.path_splitter(given)) == tuple(expected)
- given = pathlib.Path(given)
- assert tuple(utils.path_splitter(given)) == tuple(expected)
+ assert tuple(utils.path_splitter(pathlib.Path(given))) == tuple(expected)
@given(lists(sampled_from(string.ascii_letters), min_size=2).filter(all))
-def test_path_splitter_splits_path_string_by_separator(x):
+def test_path_splitter_splits_path_string_by_sep(x: List[str]) -> None:
z = str(pathlib.Path(*x))
assert tuple(utils.path_splitter(z)) == tuple(pathlib.Path(z).parts)
-def test_path_splitter_splits_path_string_by_separator_and_removes_extension_example():
+def test_path_splitter_splits_path_string_by_sep_and_removes_extension_example() -> None:
given = "/this/is/a/path/file.x1.10.tar.gz"
expected = (os.sep, "this", "is", "a", "path", "file.x1.10", ".tar", ".gz")
assert tuple(utils.path_splitter(given)) == tuple(expected)
@given(lists(sampled_from(string.ascii_letters), min_size=3).filter(all))
-def test_path_splitter_splits_path_string_by_separator_and_removes_extension(x):
+def test_path_splitter_splits_path_string_by_sep_and_removes_extension(
+ x: List[str],
+) -> None:
z = str(pathlib.Path(*x[:-2])) + "." + x[-1]
y = tuple(pathlib.Path(z).parts)
assert tuple(utils.path_splitter(z)) == y[:-1] + (
diff --git a/tox.ini b/tox.ini
index 2e3d0df..7ab790d 100644
--- a/tox.ini
+++ b/tox.ini
@@ -5,7 +5,7 @@
[tox]
envlist =
- flake8, py35, py36, py37, py38, py39
+ flake8, mypy, py36, py37, py38, py39
# Other valid evironments are:
# docs
# release
@@ -52,6 +52,18 @@ commands =
twine check dist/*
skip_install = true
+# Type checking
+[testenv:mypy]
+deps =
+ mypy
+ hypothesis
+ pytest
+ pytest-mock
+ fastnumbers
+commands =
+ mypy --strict natsort tests
+skip_install = true
+
# Build documentation.
# sphinx and sphinx_rtd_theme not in docs/requirements.txt because they
# will already be installed on readthedocs.