diff options
Diffstat (limited to 'natsort/utils.py')
-rw-r--r-- | natsort/utils.py | 278 |
1 files changed, 210 insertions, 68 deletions
diff --git a/natsort/utils.py b/natsort/utils.py index 3f6641b..7102f41 100644 --- a/natsort/utils.py +++ b/natsort/utils.py @@ -38,19 +38,93 @@ that ensures "val" is a local variable instead of global variable and thus has a slightly improved performance at runtime. """ - import re from functools import partial, reduce from itertools import chain as ichain from operator import methodcaller from pathlib import PurePath +from typing import ( + Any, + Callable, + Dict, + Iterable, + Iterator, + List, + Match, + Optional, + Pattern, + Tuple, + Union, + cast, + overload, +) from unicodedata import normalize from natsort.compat.fastnumbers import fast_float, fast_int -from natsort.compat.locale import get_decimal_point, get_strxfrm, get_thousands_sep -from natsort.ns_enum import NS_DUMB, ns +from natsort.compat.locale import ( + StrOrBytes, + get_decimal_point, + get_strxfrm, + get_thousands_sep, +) +from natsort.ns_enum import NSType, NS_DUMB, ns from natsort.unicode_numbers import digits_no_decimals, numeric_no_decimals +# +# Pre-define a slew of aggregate types which makes the type hinting below easier +# +StrToStr = Callable[[str], str] +AnyCall = Callable[[Any], Any] + +# For the bytes transform factory +BytesTuple = Tuple[bytes] +NestedBytesTuple = Tuple[Tuple[bytes]] +BytesTransform = Union[BytesTuple, NestedBytesTuple] +BytesTransformer = Callable[[bytes], BytesTransform] + +# For the number transform factory +NumType = Union[float, int] +MaybeNumType = Optional[NumType] +NumTuple = Tuple[StrOrBytes, NumType] +NestedNumTuple = Tuple[NumTuple] +StrNumTuple = Tuple[Tuple[str], NumTuple] +NestedStrNumTuple = Tuple[StrNumTuple] +MaybeNumTransform = Union[NumTuple, NestedNumTuple, StrNumTuple, NestedStrNumTuple] +MaybeNumTransformer = Callable[[MaybeNumType], MaybeNumTransform] + +# For the string component transform factory +StrBytesNum = Union[str, bytes, float, int] +StrTransformer = Callable[[str], StrBytesNum] + +# For the final data transform factory +TwoBlankTuple = Tuple[Tuple[()], Tuple[()]] +TupleOfAny = Tuple[Any, ...] +TupleOfStrAnyPair = Tuple[Tuple[str], TupleOfAny] +FinalTransform = Union[TwoBlankTuple, TupleOfAny, TupleOfStrAnyPair] +FinalTransformer = Callable[[Iterable[Any], str], FinalTransform] + +# For the string parsing factory +StrSplitter = Callable[[str], Iterable[str]] +StrParser = Callable[[str], FinalTransform] + +# For the path splitter +PathArg = Union[str, PurePath] +MatchFn = Callable[[str], Optional[Match]] + +# For the path parsing factory +PathSplitter = Callable[[PathArg], Tuple[FinalTransform, ...]] + +# For the natsort key +StrBytesPathNum = Union[str, bytes, float, int, PurePath] +NatsortInType = Union[ + Optional[StrBytesPathNum], Iterable[Union[Optional[StrBytesPathNum], Iterable[Any]]] +] +NatsortOutType = Tuple[ + Union[StrBytesNum, Tuple[Union[StrBytesNum, Tuple[Any, ...]], ...]], ... +] +KeyType = Callable[[Any], NatsortInType] +MaybeKeyType = Optional[KeyType] + class NumericalRegularExpressions: """ @@ -62,51 +136,51 @@ class NumericalRegularExpressions: """ # All unicode numeric characters (minus the decimal characters). - numeric = numeric_no_decimals + numeric: str = numeric_no_decimals # All unicode digit characters (minus the decimal characters). - digits = digits_no_decimals + digits: str = digits_no_decimals # Regular expression to match exponential component of a float. - exp = r"(?:[eE][-+]?\d+)?" + exp: str = r"(?:[eE][-+]?\d+)?" # Regular expression to match a floating point number. - float_num = r"(?:\d+\.?\d*|\.\d+)" + float_num: str = r"(?:\d+\.?\d*|\.\d+)" @classmethod - def _construct_regex(cls, fmt): + def _construct_regex(cls, fmt: str) -> Pattern[str]: """Given a format string, construct the regex with class attributes.""" return re.compile(fmt.format(**vars(cls)), flags=re.U) @classmethod - def int_sign(cls): + def int_sign(cls) -> Pattern[str]: """Regular expression to match a signed int.""" return cls._construct_regex(r"([-+]?\d+|[{digits}])") @classmethod - def int_nosign(cls): + def int_nosign(cls) -> Pattern[str]: """Regular expression to match an unsigned int.""" return cls._construct_regex(r"(\d+|[{digits}])") @classmethod - def float_sign_exp(cls): + def float_sign_exp(cls) -> Pattern[str]: """Regular expression to match a signed float with exponent.""" return cls._construct_regex(r"([-+]?{float_num}{exp}|[{numeric}])") @classmethod - def float_nosign_exp(cls): + def float_nosign_exp(cls) -> Pattern[str]: """Regular expression to match an unsigned float with exponent.""" return cls._construct_regex(r"({float_num}{exp}|[{numeric}])") @classmethod - def float_sign_noexp(cls): + def float_sign_noexp(cls) -> Pattern[str]: """Regular expression to match a signed float without exponent.""" return cls._construct_regex(r"([-+]?{float_num}|[{numeric}])") @classmethod - def float_nosign_noexp(cls): + def float_nosign_noexp(cls) -> Pattern[str]: """Regular expression to match an unsigned float without exponent.""" return cls._construct_regex(r"({float_num}|[{numeric}])") -def regex_chooser(alg): +def regex_chooser(alg: NSType) -> Pattern[str]: """ Select an appropriate regex for the type of number of interest. @@ -136,12 +210,12 @@ def regex_chooser(alg): }[alg] -def _no_op(x): +def _no_op(x: Any) -> Any: """A function that does nothing and returns the input as-is.""" return x -def _normalize_input_factory(alg): +def _normalize_input_factory(alg: NSType) -> StrToStr: """ Create a function that will normalize unicode input data. @@ -161,7 +235,35 @@ def _normalize_input_factory(alg): return partial(normalize, normalization_form) -def natsort_key(val, key, string_func, bytes_func, num_func): +@overload +def natsort_key( + val: NatsortInType, + key: None, + string_func: Union[StrParser, PathSplitter], + bytes_func: BytesTransformer, + num_func: MaybeNumTransformer, +) -> NatsortOutType: + ... + + +@overload +def natsort_key( + val: Any, + key: KeyType, + string_func: Union[StrParser, PathSplitter], + bytes_func: BytesTransformer, + num_func: MaybeNumTransformer, +) -> NatsortOutType: + ... + + +def natsort_key( + val: Union[NatsortInType, Any], + key: MaybeKeyType, + string_func: Union[StrParser, PathSplitter], + bytes_func: BytesTransformer, + num_func: MaybeNumTransformer, +) -> NatsortOutType: """ Key to sort strings and numbers naturally. @@ -170,7 +272,7 @@ def natsort_key(val, key, string_func, bytes_func, num_func): Parameters ---------- - val : str | unicode | bytes | int | float | iterable + val : str | bytes | int | float | iterable key : callable | None A key to apply to the *val* before any other operations are performed. string_func : callable @@ -210,26 +312,27 @@ def natsort_key(val, key, string_func, bytes_func, num_func): # Assume the input are strings, which is the most common case try: - return string_func(val) + return string_func(cast(str, val)) except (TypeError, AttributeError): # If bytes type, use the bytes_func if type(val) in (bytes,): - return bytes_func(val) + return bytes_func(cast(bytes, val)) # Otherwise, assume it is an iterable that must be parsed recursively. # Do not apply the key recursively. try: return tuple( - natsort_key(x, None, string_func, bytes_func, num_func) for x in val + natsort_key(x, None, string_func, bytes_func, num_func) + for x in cast(Iterable[Any], val) ) # If that failed, it must be a number. except TypeError: - return num_func(val) + return num_func(cast(NumType, val)) -def parse_bytes_factory(alg): +def parse_bytes_factory(alg: NSType) -> BytesTransformer: """ Create a function that will format a *bytes* object into a tuple. @@ -262,7 +365,9 @@ def parse_bytes_factory(alg): return lambda x: (x,) -def parse_number_or_none_factory(alg, sep, pre_sep): +def parse_number_or_none_factory( + alg: NSType, sep: StrOrBytes, pre_sep: str +) -> MaybeNumTransformer: """ Create a function that will format a number (or None) into a tuple. @@ -293,7 +398,9 @@ def parse_number_or_none_factory(alg, sep, pre_sep): """ nan_replace = float("+inf") if alg & ns.NANLAST else float("-inf") - def func(val, _nan_replace=nan_replace, _sep=sep): + def func( + val: MaybeNumType, _nan_replace: float = nan_replace, _sep: StrOrBytes = sep + ) -> NumTuple: """Given a number, place it in a tuple with a leading null string.""" return _sep, (_nan_replace if val != val or val is None else val) @@ -309,8 +416,13 @@ def parse_number_or_none_factory(alg, sep, pre_sep): def parse_string_factory( - alg, sep, splitter, input_transform, component_transform, final_transform -): + alg: NSType, + sep: StrOrBytes, + splitter: StrSplitter, + input_transform: StrToStr, + component_transform: StrTransformer, + final_transform: FinalTransformer, +) -> StrParser: """ Create a function that will split and format a *str* into a tuple. @@ -361,22 +473,22 @@ def parse_string_factory( original_func = input_transform if orig_after_xfrm else _no_op normalize_input = _normalize_input_factory(alg) - def func(x): + def func(x: str) -> FinalTransform: # Apply string input transformation function and return to x. # Original function is usually a no-op, but some algorithms require it # to also be the transformation function. - x = normalize_input(x) - x, original = input_transform(x), original_func(x) - x = splitter(x) # Split string into components. - x = filter(None, x) # Remove empty strings. - x = map(component_transform, x) # Apply transform on components. - x = sep_inserter(x, sep) # Insert '' between numbers. - return final_transform(x, original) # Apply the final transform. + a = normalize_input(x) + b, original = input_transform(a), original_func(a) + c = splitter(b) # Split string into components. + d = filter(None, c) # Remove empty strings. + e = map(component_transform, d) # Apply transform on components. + f = sep_inserter(e, sep) # Insert '' between numbers. + return final_transform(f, original) # Apply the final transform. return func -def parse_path_factory(str_split): +def parse_path_factory(str_split: StrParser) -> PathSplitter: """ Create a function that will properly split and format a path. @@ -403,41 +515,41 @@ def parse_path_factory(str_split): return lambda x: tuple(map(str_split, path_splitter(x))) -def sep_inserter(iterable, sep): +def sep_inserter(iterator: Iterator[Any], sep: StrOrBytes) -> Iterator[Any]: """ - Insert '' between numbers in an iterable. + Insert '' between numbers in an iterator. Parameters ---------- - iterable + iterator sep : str The string character to be inserted between adjacent numeric objects. Yields ------ - The values of *iterable* in order, with *sep* inserted where adjacent + The values of *iterator* in order, with *sep* inserted where adjacent elements are numeric. If the first element in the input is numeric then *sep* will be the first value yielded. """ try: - # Get the first element. A StopIteration indicates an empty iterable. + # Get the first element. A StopIteration indicates an empty iterator. # Since we are controlling the types of the input, 'type' is used # instead of 'isinstance' for the small speed advantage it offers. types = (int, float) - first = next(iterable) + first = next(iterator) if type(first) in types: yield sep yield first # Now, check if pair of elements are both numbers. If so, add ''. - second = next(iterable) + second = next(iterator) if type(first) in types and type(second) in types: yield sep yield second # Now repeat in a loop. - for x in iterable: + for x in iterator: first, second = second, x if type(first) in types and type(second) in types: yield sep @@ -448,7 +560,7 @@ def sep_inserter(iterable, sep): return -def input_string_transform_factory(alg): +def input_string_transform_factory(alg: NSType) -> StrToStr: """ Create a function to transform a string. @@ -473,7 +585,7 @@ def input_string_transform_factory(alg): dumb = alg & NS_DUMB # Build the chain of functions to execute in order. - function_chain = [] + function_chain: List[StrToStr] = [] if (dumb and not lowfirst) or (lowfirst and not dumb): function_chain.append(methodcaller("swapcase")) @@ -502,8 +614,8 @@ def input_string_transform_factory(alg): strip_thousands = strip_thousands.format( thou=re.escape(get_thousands_sep()), nodecimal=nodecimal ) - strip_thousands = re.compile(strip_thousands, flags=re.VERBOSE) - function_chain.append(partial(strip_thousands.sub, "")) + strip_thousands_re = re.compile(strip_thousands, flags=re.VERBOSE) + function_chain.append(partial(strip_thousands_re.sub, "")) # Create a regular expression that will change the decimal point to # a period if not already a period. @@ -511,14 +623,14 @@ def input_string_transform_factory(alg): if alg & ns.FLOAT and decimal != ".": switch_decimal = r"(?<=[0-9]){decimal}|{decimal}(?=[0-9])" switch_decimal = switch_decimal.format(decimal=re.escape(decimal)) - switch_decimal = re.compile(switch_decimal) - function_chain.append(partial(switch_decimal.sub, ".")) + switch_decimal_re = re.compile(switch_decimal) + function_chain.append(partial(switch_decimal_re.sub, ".")) # Return the chained functions. return chain_functions(function_chain) -def string_component_transform_factory(alg): +def string_component_transform_factory(alg: NSType) -> StrTransformer: """ Create a function to either transform a string or convert to a number. @@ -545,23 +657,26 @@ def string_component_transform_factory(alg): nan_val = float("+inf") if alg & ns.NANLAST else float("-inf") # Build the chain of functions to execute in order. - func_chain = [] + func_chain: List[Callable[[str], StrOrBytes]] = [] if group_letters: func_chain.append(groupletters) if use_locale: func_chain.append(get_strxfrm()) - kwargs = {"key": chain_functions(func_chain)} if func_chain else {} # Return the correct chained functions. + kwargs: Dict[str, Union[float, Callable[[str], StrOrBytes]]] + kwargs = {"key": chain_functions(func_chain)} if func_chain else {} if alg & ns.FLOAT: # noinspection PyTypeChecker kwargs["nan"] = nan_val - return partial(fast_float, **kwargs) + return cast(Callable[[str], StrOrBytes], partial(fast_float, **kwargs)) else: - return partial(fast_int, **kwargs) + return cast(Callable[[str], StrOrBytes], partial(fast_int, **kwargs)) -def final_data_transform_factory(alg, sep, pre_sep): +def final_data_transform_factory( + alg: NSType, sep: StrOrBytes, pre_sep: str +) -> FinalTransformer: """ Create a function to transform a tuple. @@ -589,9 +704,15 @@ def final_data_transform_factory(alg, sep, pre_sep): """ if alg & ns.UNGROUPLETTERS and alg & ns.LOCALEALPHA: swap = alg & NS_DUMB and alg & ns.LOWERCASEFIRST - transform = methodcaller("swapcase") if swap else _no_op - - def func(split_val, val, _transform=transform, _sep=sep, _pre_sep=pre_sep): + transform = cast(StrToStr, methodcaller("swapcase")) if swap else _no_op + + def func( + split_val: Iterable[NatsortInType], + val: str, + _transform: StrToStr = transform, + _sep: StrOrBytes = sep, + _pre_sep: str = pre_sep, + ) -> FinalTransform: """ Return a tuple with the first character of the first element of the return value as the first element, and the return value @@ -606,16 +727,25 @@ def final_data_transform_factory(alg, sep, pre_sep): else: return (_transform(val[0]),), split_val - return func else: - return lambda split_val, val: tuple(split_val) + def func( + split_val: Iterable[NatsortInType], + val: str, + _transform: StrToStr = _no_op, + _sep: StrOrBytes = sep, + _pre_sep: str = pre_sep, + ) -> FinalTransform: + return tuple(split_val) + + return func -lower_function = methodcaller("casefold") + +lower_function: StrToStr = cast(StrToStr, methodcaller("casefold")) # noinspection PyIncorrectDocstring -def groupletters(x, _low=lower_function): +def groupletters(x: str, _low: StrToStr = lower_function) -> str: """ Double all characters, making doubled letters lowercase. @@ -637,7 +767,7 @@ def groupletters(x, _low=lower_function): return "".join(ichain.from_iterable((_low(y), y) for y in x)) -def chain_functions(functions): +def chain_functions(functions: Iterable[AnyCall]) -> AnyCall: """ Chain a list of single-argument functions together and return. @@ -674,7 +804,17 @@ def chain_functions(functions): return partial(reduce, lambda res, f: f(res), functions) -def do_decoding(s, encoding): +@overload +def do_decoding(s: bytes, encoding: str) -> str: + ... + + +@overload +def do_decoding(s: NatsortInType, encoding: str) -> NatsortInType: + ... + + +def do_decoding(s: NatsortInType, encoding: str) -> NatsortInType: """ Helper to decode a *bytes* object, or return the object as-is. @@ -692,13 +832,15 @@ def do_decoding(s, encoding): """ try: - return s.decode(encoding) + return cast(bytes, s).decode(encoding) except (AttributeError, TypeError): return s # noinspection PyIncorrectDocstring -def path_splitter(s, _d_match=re.compile(r"\.\d").match): +def path_splitter( + s: PathArg, _d_match: MatchFn = re.compile(r"\.\d").match +) -> Iterator[str]: """ Split a string into its path components. |