diff options
Diffstat (limited to 'natsort/utils.py')
-rw-r--r-- | natsort/utils.py | 52 |
1 files changed, 34 insertions, 18 deletions
diff --git a/natsort/utils.py b/natsort/utils.py index 2d15e39..27791d7 100644 --- a/natsort/utils.py +++ b/natsort/utils.py @@ -115,9 +115,18 @@ MatchFn = Callable[[str], Optional[Match]] PathSplitter = Callable[[PathArg], Tuple[FinalTransform, ...]] # For the natsort key -NatsortInType = Union[StrBytesNum, Iterable[Union[StrBytesNum, Iterable]]] -NatsortOutType = Tuple[Union[StrBytesNum, Tuple[Union[StrBytesNum, tuple], ...]], ...] +NatsortIterType = Iterable[Union[StrBytesNum, Iterable[Any_t]]] +NatsortInType = Union[StrBytesNum, NatsortIterType] +NatsortOutElement = Union[ + FinalTransform, + Tuple[FinalTransform], + StrBytesNum, + MaybeNumTransform, + BytesTransform, +] +NatsortOutType = Union[NatsortOutElement, Tuple[Union[NatsortOutElement, tuple], ...]] KeyType = Callable[[Any_t], NatsortInType] +MaybeKeyType = Optional[KeyType] class NumericalRegularExpressions: @@ -139,42 +148,42 @@ class NumericalRegularExpressions: float_num: str = r"(?:\d+\.?\d*|\.\d+)" @classmethod - def _construct_regex(cls, fmt: str) -> Pattern: + def _construct_regex(cls, fmt: str) -> Pattern[str]: """Given a format string, construct the regex with class attributes.""" return re.compile(fmt.format(**vars(cls)), flags=re.U) @classmethod - def int_sign(cls) -> Pattern: + def int_sign(cls) -> Pattern[str]: """Regular expression to match a signed int.""" return cls._construct_regex(r"([-+]?\d+|[{digits}])") @classmethod - def int_nosign(cls) -> Pattern: + def int_nosign(cls) -> Pattern[str]: """Regular expression to match an unsigned int.""" return cls._construct_regex(r"(\d+|[{digits}])") @classmethod - def float_sign_exp(cls) -> Pattern: + def float_sign_exp(cls) -> Pattern[str]: """Regular expression to match a signed float with exponent.""" return cls._construct_regex(r"([-+]?{float_num}{exp}|[{numeric}])") @classmethod - def float_nosign_exp(cls) -> Pattern: + def float_nosign_exp(cls) -> Pattern[str]: """Regular expression to match an unsigned float with exponent.""" return cls._construct_regex(r"({float_num}{exp}|[{numeric}])") @classmethod - def float_sign_noexp(cls) -> Pattern: + def float_sign_noexp(cls) -> Pattern[str]: """Regular expression to match a signed float without exponent.""" return cls._construct_regex(r"([-+]?{float_num}|[{numeric}])") @classmethod - def float_nosign_noexp(cls) -> Pattern: + def float_nosign_noexp(cls) -> Pattern[str]: """Regular expression to match an unsigned float without exponent.""" return cls._construct_regex(r"({float_num}|[{numeric}])") -def regex_chooser(alg: NS_t) -> Pattern: +def regex_chooser(alg: NS_t) -> Pattern[str]: """ Select an appropriate regex for the type of number of interest. @@ -251,7 +260,13 @@ def natsort_key( ... -def natsort_key(val, key, string_func, bytes_func, num_func): +def natsort_key( + val: Union[NatsortInType, Any_t], + key: MaybeKeyType, + string_func: Union[StrParser, PathSplitter], + bytes_func: BytesTransformer, + num_func: MaybeNumTransformer, +) -> NatsortOutType: """ Key to sort strings and numbers naturally. @@ -300,23 +315,24 @@ def natsort_key(val, key, string_func, bytes_func, num_func): # Assume the input are strings, which is the most common case try: - return string_func(val) + return string_func(cast(str, val)) except (TypeError, AttributeError): # If bytes type, use the bytes_func if type(val) in (bytes,): - return bytes_func(val) + return bytes_func(cast(bytes, val)) # Otherwise, assume it is an iterable that must be parsed recursively. # Do not apply the key recursively. try: return tuple( - natsort_key(x, None, string_func, bytes_func, num_func) for x in val + natsort_key(x, None, string_func, bytes_func, num_func) + for x in cast(NatsortIterType, val) ) # If that failed, it must be a number. except TypeError: - return num_func(val) + return num_func(cast(NumType, val)) def parse_bytes_factory(alg: NS_t) -> BytesTransformer: @@ -502,7 +518,7 @@ def parse_path_factory(str_split: StrParser) -> PathSplitter: return lambda x: tuple(map(str_split, path_splitter(x))) -def sep_inserter(iterator: Iterator, sep: StrOrBytes) -> Iterator: +def sep_inserter(iterator: Iterator[Any_t], sep: StrOrBytes) -> Iterator[Any_t]: """ Insert '' between numbers in an iterator. @@ -801,7 +817,7 @@ def do_decoding(s: NatsortInType, encoding: str) -> NatsortInType: ... -def do_decoding(s, encoding): +def do_decoding(s: NatsortInType, encoding: str) -> NatsortInType: """ Helper to decode a *bytes* object, or return the object as-is. @@ -819,7 +835,7 @@ def do_decoding(s, encoding): """ try: - return s.decode(encoding) + return cast(bytes, s).decode(encoding) except (AttributeError, TypeError): return s |