summaryrefslogtreecommitdiff
path: root/natsort/utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'natsort/utils.py')
-rw-r--r--natsort/utils.py278
1 files changed, 210 insertions, 68 deletions
diff --git a/natsort/utils.py b/natsort/utils.py
index 3f6641b..7102f41 100644
--- a/natsort/utils.py
+++ b/natsort/utils.py
@@ -38,19 +38,93 @@ that ensures "val" is a local variable instead of global variable
and thus has a slightly improved performance at runtime.
"""
-
import re
from functools import partial, reduce
from itertools import chain as ichain
from operator import methodcaller
from pathlib import PurePath
+from typing import (
+ Any,
+ Callable,
+ Dict,
+ Iterable,
+ Iterator,
+ List,
+ Match,
+ Optional,
+ Pattern,
+ Tuple,
+ Union,
+ cast,
+ overload,
+)
from unicodedata import normalize
from natsort.compat.fastnumbers import fast_float, fast_int
-from natsort.compat.locale import get_decimal_point, get_strxfrm, get_thousands_sep
-from natsort.ns_enum import NS_DUMB, ns
+from natsort.compat.locale import (
+ StrOrBytes,
+ get_decimal_point,
+ get_strxfrm,
+ get_thousands_sep,
+)
+from natsort.ns_enum import NSType, NS_DUMB, ns
from natsort.unicode_numbers import digits_no_decimals, numeric_no_decimals
+#
+# Pre-define a slew of aggregate types to make the type hinting below easier
+#
+StrToStr = Callable[[str], str]
+AnyCall = Callable[[Any], Any]
+
+# For the bytes transform factory
+BytesTuple = Tuple[bytes]
+NestedBytesTuple = Tuple[Tuple[bytes]]
+BytesTransform = Union[BytesTuple, NestedBytesTuple]
+BytesTransformer = Callable[[bytes], BytesTransform]
+
+# For the number transform factory
+NumType = Union[float, int]
+MaybeNumType = Optional[NumType]
+NumTuple = Tuple[StrOrBytes, NumType]
+NestedNumTuple = Tuple[NumTuple]
+StrNumTuple = Tuple[Tuple[str], NumTuple]
+NestedStrNumTuple = Tuple[StrNumTuple]
+MaybeNumTransform = Union[NumTuple, NestedNumTuple, StrNumTuple, NestedStrNumTuple]
+MaybeNumTransformer = Callable[[MaybeNumType], MaybeNumTransform]
+
+# For the string component transform factory
+StrBytesNum = Union[str, bytes, float, int]
+StrTransformer = Callable[[str], StrBytesNum]
+
+# For the final data transform factory
+TwoBlankTuple = Tuple[Tuple[()], Tuple[()]]
+TupleOfAny = Tuple[Any, ...]
+TupleOfStrAnyPair = Tuple[Tuple[str], TupleOfAny]
+FinalTransform = Union[TwoBlankTuple, TupleOfAny, TupleOfStrAnyPair]
+FinalTransformer = Callable[[Iterable[Any], str], FinalTransform]
+
+# For the string parsing factory
+StrSplitter = Callable[[str], Iterable[str]]
+StrParser = Callable[[str], FinalTransform]
+
+# For the path splitter
+PathArg = Union[str, PurePath]
+MatchFn = Callable[[str], Optional[Match]]
+
+# For the path parsing factory
+PathSplitter = Callable[[PathArg], Tuple[FinalTransform, ...]]
+
+# For the natsort key
+StrBytesPathNum = Union[str, bytes, float, int, PurePath]
+NatsortInType = Union[
+ Optional[StrBytesPathNum], Iterable[Union[Optional[StrBytesPathNum], Iterable[Any]]]
+]
+NatsortOutType = Tuple[
+ Union[StrBytesNum, Tuple[Union[StrBytesNum, Tuple[Any, ...]], ...]], ...
+]
+KeyType = Callable[[Any], NatsortInType]
+MaybeKeyType = Optional[KeyType]
+
class NumericalRegularExpressions:
"""
@@ -62,51 +136,51 @@ class NumericalRegularExpressions:
"""
# All unicode numeric characters (minus the decimal characters).
- numeric = numeric_no_decimals
+ numeric: str = numeric_no_decimals
# All unicode digit characters (minus the decimal characters).
- digits = digits_no_decimals
+ digits: str = digits_no_decimals
# Regular expression to match exponential component of a float.
- exp = r"(?:[eE][-+]?\d+)?"
+ exp: str = r"(?:[eE][-+]?\d+)?"
# Regular expression to match a floating point number.
- float_num = r"(?:\d+\.?\d*|\.\d+)"
+ float_num: str = r"(?:\d+\.?\d*|\.\d+)"
@classmethod
- def _construct_regex(cls, fmt):
+ def _construct_regex(cls, fmt: str) -> Pattern[str]:
"""Given a format string, construct the regex with class attributes."""
return re.compile(fmt.format(**vars(cls)), flags=re.U)
@classmethod
- def int_sign(cls):
+ def int_sign(cls) -> Pattern[str]:
"""Regular expression to match a signed int."""
return cls._construct_regex(r"([-+]?\d+|[{digits}])")
@classmethod
- def int_nosign(cls):
+ def int_nosign(cls) -> Pattern[str]:
"""Regular expression to match an unsigned int."""
return cls._construct_regex(r"(\d+|[{digits}])")
@classmethod
- def float_sign_exp(cls):
+ def float_sign_exp(cls) -> Pattern[str]:
"""Regular expression to match a signed float with exponent."""
return cls._construct_regex(r"([-+]?{float_num}{exp}|[{numeric}])")
@classmethod
- def float_nosign_exp(cls):
+ def float_nosign_exp(cls) -> Pattern[str]:
"""Regular expression to match an unsigned float with exponent."""
return cls._construct_regex(r"({float_num}{exp}|[{numeric}])")
@classmethod
- def float_sign_noexp(cls):
+ def float_sign_noexp(cls) -> Pattern[str]:
"""Regular expression to match a signed float without exponent."""
return cls._construct_regex(r"([-+]?{float_num}|[{numeric}])")
@classmethod
- def float_nosign_noexp(cls):
+ def float_nosign_noexp(cls) -> Pattern[str]:
"""Regular expression to match an unsigned float without exponent."""
return cls._construct_regex(r"({float_num}|[{numeric}])")
-def regex_chooser(alg):
+def regex_chooser(alg: NSType) -> Pattern[str]:
"""
Select an appropriate regex for the type of number of interest.
@@ -136,12 +210,12 @@ def regex_chooser(alg):
}[alg]
-def _no_op(x):
+def _no_op(x: Any) -> Any:
"""A function that does nothing and returns the input as-is."""
return x
-def _normalize_input_factory(alg):
+def _normalize_input_factory(alg: NSType) -> StrToStr:
"""
Create a function that will normalize unicode input data.
@@ -161,7 +235,35 @@ def _normalize_input_factory(alg):
return partial(normalize, normalization_form)
-def natsort_key(val, key, string_func, bytes_func, num_func):
+@overload
+def natsort_key(
+ val: NatsortInType,
+ key: None,
+ string_func: Union[StrParser, PathSplitter],
+ bytes_func: BytesTransformer,
+ num_func: MaybeNumTransformer,
+) -> NatsortOutType:
+ ...
+
+
+@overload
+def natsort_key(
+ val: Any,
+ key: KeyType,
+ string_func: Union[StrParser, PathSplitter],
+ bytes_func: BytesTransformer,
+ num_func: MaybeNumTransformer,
+) -> NatsortOutType:
+ ...
+
+
+def natsort_key(
+ val: Union[NatsortInType, Any],
+ key: MaybeKeyType,
+ string_func: Union[StrParser, PathSplitter],
+ bytes_func: BytesTransformer,
+ num_func: MaybeNumTransformer,
+) -> NatsortOutType:
"""
Key to sort strings and numbers naturally.
@@ -170,7 +272,7 @@ def natsort_key(val, key, string_func, bytes_func, num_func):
Parameters
----------
- val : str | unicode | bytes | int | float | iterable
+ val : str | bytes | int | float | iterable
key : callable | None
A key to apply to the *val* before any other operations are performed.
string_func : callable
@@ -210,26 +312,27 @@ def natsort_key(val, key, string_func, bytes_func, num_func):
# Assume the input are strings, which is the most common case
try:
- return string_func(val)
+ return string_func(cast(str, val))
except (TypeError, AttributeError):
# If bytes type, use the bytes_func
if type(val) in (bytes,):
- return bytes_func(val)
+ return bytes_func(cast(bytes, val))
# Otherwise, assume it is an iterable that must be parsed recursively.
# Do not apply the key recursively.
try:
return tuple(
- natsort_key(x, None, string_func, bytes_func, num_func) for x in val
+ natsort_key(x, None, string_func, bytes_func, num_func)
+ for x in cast(Iterable[Any], val)
)
# If that failed, it must be a number.
except TypeError:
- return num_func(val)
+ return num_func(cast(NumType, val))
-def parse_bytes_factory(alg):
+def parse_bytes_factory(alg: NSType) -> BytesTransformer:
"""
Create a function that will format a *bytes* object into a tuple.
@@ -262,7 +365,9 @@ def parse_bytes_factory(alg):
return lambda x: (x,)
-def parse_number_or_none_factory(alg, sep, pre_sep):
+def parse_number_or_none_factory(
+ alg: NSType, sep: StrOrBytes, pre_sep: str
+) -> MaybeNumTransformer:
"""
Create a function that will format a number (or None) into a tuple.
@@ -293,7 +398,9 @@ def parse_number_or_none_factory(alg, sep, pre_sep):
"""
nan_replace = float("+inf") if alg & ns.NANLAST else float("-inf")
- def func(val, _nan_replace=nan_replace, _sep=sep):
+ def func(
+ val: MaybeNumType, _nan_replace: float = nan_replace, _sep: StrOrBytes = sep
+ ) -> NumTuple:
"""Given a number, place it in a tuple with a leading null string."""
return _sep, (_nan_replace if val != val or val is None else val)
@@ -309,8 +416,13 @@ def parse_number_or_none_factory(alg, sep, pre_sep):
def parse_string_factory(
- alg, sep, splitter, input_transform, component_transform, final_transform
-):
+ alg: NSType,
+ sep: StrOrBytes,
+ splitter: StrSplitter,
+ input_transform: StrToStr,
+ component_transform: StrTransformer,
+ final_transform: FinalTransformer,
+) -> StrParser:
"""
Create a function that will split and format a *str* into a tuple.
@@ -361,22 +473,22 @@ def parse_string_factory(
original_func = input_transform if orig_after_xfrm else _no_op
normalize_input = _normalize_input_factory(alg)
- def func(x):
+ def func(x: str) -> FinalTransform:
# Apply string input transformation function and return to x.
# Original function is usually a no-op, but some algorithms require it
# to also be the transformation function.
- x = normalize_input(x)
- x, original = input_transform(x), original_func(x)
- x = splitter(x) # Split string into components.
- x = filter(None, x) # Remove empty strings.
- x = map(component_transform, x) # Apply transform on components.
- x = sep_inserter(x, sep) # Insert '' between numbers.
- return final_transform(x, original) # Apply the final transform.
+ a = normalize_input(x)
+ b, original = input_transform(a), original_func(a)
+ c = splitter(b) # Split string into components.
+ d = filter(None, c) # Remove empty strings.
+ e = map(component_transform, d) # Apply transform on components.
+ f = sep_inserter(e, sep) # Insert '' between numbers.
+ return final_transform(f, original) # Apply the final transform.
return func
-def parse_path_factory(str_split):
+def parse_path_factory(str_split: StrParser) -> PathSplitter:
"""
Create a function that will properly split and format a path.
@@ -403,41 +515,41 @@ def parse_path_factory(str_split):
return lambda x: tuple(map(str_split, path_splitter(x)))
-def sep_inserter(iterable, sep):
+def sep_inserter(iterator: Iterator[Any], sep: StrOrBytes) -> Iterator[Any]:
"""
- Insert '' between numbers in an iterable.
+ Insert '' between numbers in an iterator.
Parameters
----------
- iterable
+ iterator
sep : str
The string character to be inserted between adjacent numeric objects.
Yields
------
- The values of *iterable* in order, with *sep* inserted where adjacent
+ The values of *iterator* in order, with *sep* inserted where adjacent
elements are numeric. If the first element in the input is numeric
then *sep* will be the first value yielded.
"""
try:
- # Get the first element. A StopIteration indicates an empty iterable.
+ # Get the first element. A StopIteration indicates an empty iterator.
# Since we are controlling the types of the input, 'type' is used
# instead of 'isinstance' for the small speed advantage it offers.
types = (int, float)
- first = next(iterable)
+ first = next(iterator)
if type(first) in types:
yield sep
yield first
# Now, check if pair of elements are both numbers. If so, add ''.
- second = next(iterable)
+ second = next(iterator)
if type(first) in types and type(second) in types:
yield sep
yield second
# Now repeat in a loop.
- for x in iterable:
+ for x in iterator:
first, second = second, x
if type(first) in types and type(second) in types:
yield sep
@@ -448,7 +560,7 @@ def sep_inserter(iterable, sep):
return
-def input_string_transform_factory(alg):
+def input_string_transform_factory(alg: NSType) -> StrToStr:
"""
Create a function to transform a string.
@@ -473,7 +585,7 @@ def input_string_transform_factory(alg):
dumb = alg & NS_DUMB
# Build the chain of functions to execute in order.
- function_chain = []
+ function_chain: List[StrToStr] = []
if (dumb and not lowfirst) or (lowfirst and not dumb):
function_chain.append(methodcaller("swapcase"))
@@ -502,8 +614,8 @@ def input_string_transform_factory(alg):
strip_thousands = strip_thousands.format(
thou=re.escape(get_thousands_sep()), nodecimal=nodecimal
)
- strip_thousands = re.compile(strip_thousands, flags=re.VERBOSE)
- function_chain.append(partial(strip_thousands.sub, ""))
+ strip_thousands_re = re.compile(strip_thousands, flags=re.VERBOSE)
+ function_chain.append(partial(strip_thousands_re.sub, ""))
# Create a regular expression that will change the decimal point to
# a period if not already a period.
@@ -511,14 +623,14 @@ def input_string_transform_factory(alg):
if alg & ns.FLOAT and decimal != ".":
switch_decimal = r"(?<=[0-9]){decimal}|{decimal}(?=[0-9])"
switch_decimal = switch_decimal.format(decimal=re.escape(decimal))
- switch_decimal = re.compile(switch_decimal)
- function_chain.append(partial(switch_decimal.sub, "."))
+ switch_decimal_re = re.compile(switch_decimal)
+ function_chain.append(partial(switch_decimal_re.sub, "."))
# Return the chained functions.
return chain_functions(function_chain)
-def string_component_transform_factory(alg):
+def string_component_transform_factory(alg: NSType) -> StrTransformer:
"""
Create a function to either transform a string or convert to a number.
@@ -545,23 +657,26 @@ def string_component_transform_factory(alg):
nan_val = float("+inf") if alg & ns.NANLAST else float("-inf")
# Build the chain of functions to execute in order.
- func_chain = []
+ func_chain: List[Callable[[str], StrOrBytes]] = []
if group_letters:
func_chain.append(groupletters)
if use_locale:
func_chain.append(get_strxfrm())
- kwargs = {"key": chain_functions(func_chain)} if func_chain else {}
# Return the correct chained functions.
+ kwargs: Dict[str, Union[float, Callable[[str], StrOrBytes]]]
+ kwargs = {"key": chain_functions(func_chain)} if func_chain else {}
if alg & ns.FLOAT:
# noinspection PyTypeChecker
kwargs["nan"] = nan_val
- return partial(fast_float, **kwargs)
+ return cast(Callable[[str], StrOrBytes], partial(fast_float, **kwargs))
else:
- return partial(fast_int, **kwargs)
+ return cast(Callable[[str], StrOrBytes], partial(fast_int, **kwargs))
-def final_data_transform_factory(alg, sep, pre_sep):
+def final_data_transform_factory(
+ alg: NSType, sep: StrOrBytes, pre_sep: str
+) -> FinalTransformer:
"""
Create a function to transform a tuple.
@@ -589,9 +704,15 @@ def final_data_transform_factory(alg, sep, pre_sep):
"""
if alg & ns.UNGROUPLETTERS and alg & ns.LOCALEALPHA:
swap = alg & NS_DUMB and alg & ns.LOWERCASEFIRST
- transform = methodcaller("swapcase") if swap else _no_op
-
- def func(split_val, val, _transform=transform, _sep=sep, _pre_sep=pre_sep):
+ transform = cast(StrToStr, methodcaller("swapcase")) if swap else _no_op
+
+ def func(
+ split_val: Iterable[NatsortInType],
+ val: str,
+ _transform: StrToStr = transform,
+ _sep: StrOrBytes = sep,
+ _pre_sep: str = pre_sep,
+ ) -> FinalTransform:
"""
Return a tuple with the first character of the first element
of the return value as the first element, and the return value
@@ -606,16 +727,25 @@ def final_data_transform_factory(alg, sep, pre_sep):
else:
return (_transform(val[0]),), split_val
- return func
else:
- return lambda split_val, val: tuple(split_val)
+ def func(
+ split_val: Iterable[NatsortInType],
+ val: str,
+ _transform: StrToStr = _no_op,
+ _sep: StrOrBytes = sep,
+ _pre_sep: str = pre_sep,
+ ) -> FinalTransform:
+ return tuple(split_val)
+
+ return func
-lower_function = methodcaller("casefold")
+
+lower_function: StrToStr = cast(StrToStr, methodcaller("casefold"))
# noinspection PyIncorrectDocstring
-def groupletters(x, _low=lower_function):
+def groupletters(x: str, _low: StrToStr = lower_function) -> str:
"""
Double all characters, making doubled letters lowercase.
@@ -637,7 +767,7 @@ def groupletters(x, _low=lower_function):
return "".join(ichain.from_iterable((_low(y), y) for y in x))
-def chain_functions(functions):
+def chain_functions(functions: Iterable[AnyCall]) -> AnyCall:
"""
Chain a list of single-argument functions together and return.
@@ -674,7 +804,17 @@ def chain_functions(functions):
return partial(reduce, lambda res, f: f(res), functions)
-def do_decoding(s, encoding):
+@overload
+def do_decoding(s: bytes, encoding: str) -> str:
+ ...
+
+
+@overload
+def do_decoding(s: NatsortInType, encoding: str) -> NatsortInType:
+ ...
+
+
+def do_decoding(s: NatsortInType, encoding: str) -> NatsortInType:
"""
Helper to decode a *bytes* object, or return the object as-is.
@@ -692,13 +832,15 @@ def do_decoding(s, encoding):
"""
try:
- return s.decode(encoding)
+ return cast(bytes, s).decode(encoding)
except (AttributeError, TypeError):
return s
# noinspection PyIncorrectDocstring
-def path_splitter(s, _d_match=re.compile(r"\.\d").match):
+def path_splitter(
+ s: PathArg, _d_match: MatchFn = re.compile(r"\.\d").match
+) -> Iterator[str]:
"""
Split a string into its path components.