summaryrefslogtreecommitdiff
path: root/natsort/utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'natsort/utils.py')
-rw-r--r--natsort/utils.py52
1 files changed, 34 insertions, 18 deletions
diff --git a/natsort/utils.py b/natsort/utils.py
index 2d15e39..27791d7 100644
--- a/natsort/utils.py
+++ b/natsort/utils.py
@@ -115,9 +115,18 @@ MatchFn = Callable[[str], Optional[Match]]
PathSplitter = Callable[[PathArg], Tuple[FinalTransform, ...]]
# For the natsort key
-NatsortInType = Union[StrBytesNum, Iterable[Union[StrBytesNum, Iterable]]]
-NatsortOutType = Tuple[Union[StrBytesNum, Tuple[Union[StrBytesNum, tuple], ...]], ...]
+NatsortIterType = Iterable[Union[StrBytesNum, Iterable[Any_t]]]
+NatsortInType = Union[StrBytesNum, NatsortIterType]
+NatsortOutElement = Union[
+ FinalTransform,
+ Tuple[FinalTransform],
+ StrBytesNum,
+ MaybeNumTransform,
+ BytesTransform,
+]
+NatsortOutType = Union[NatsortOutElement, Tuple[Union[NatsortOutElement, tuple], ...]]
KeyType = Callable[[Any_t], NatsortInType]
+MaybeKeyType = Optional[KeyType]
class NumericalRegularExpressions:
@@ -139,42 +148,42 @@ class NumericalRegularExpressions:
float_num: str = r"(?:\d+\.?\d*|\.\d+)"
@classmethod
- def _construct_regex(cls, fmt: str) -> Pattern:
+ def _construct_regex(cls, fmt: str) -> Pattern[str]:
"""Given a format string, construct the regex with class attributes."""
return re.compile(fmt.format(**vars(cls)), flags=re.U)
@classmethod
- def int_sign(cls) -> Pattern:
+ def int_sign(cls) -> Pattern[str]:
"""Regular expression to match a signed int."""
return cls._construct_regex(r"([-+]?\d+|[{digits}])")
@classmethod
- def int_nosign(cls) -> Pattern:
+ def int_nosign(cls) -> Pattern[str]:
"""Regular expression to match an unsigned int."""
return cls._construct_regex(r"(\d+|[{digits}])")
@classmethod
- def float_sign_exp(cls) -> Pattern:
+ def float_sign_exp(cls) -> Pattern[str]:
"""Regular expression to match a signed float with exponent."""
return cls._construct_regex(r"([-+]?{float_num}{exp}|[{numeric}])")
@classmethod
- def float_nosign_exp(cls) -> Pattern:
+ def float_nosign_exp(cls) -> Pattern[str]:
"""Regular expression to match an unsigned float with exponent."""
return cls._construct_regex(r"({float_num}{exp}|[{numeric}])")
@classmethod
- def float_sign_noexp(cls) -> Pattern:
+ def float_sign_noexp(cls) -> Pattern[str]:
"""Regular expression to match a signed float without exponent."""
return cls._construct_regex(r"([-+]?{float_num}|[{numeric}])")
@classmethod
- def float_nosign_noexp(cls) -> Pattern:
+ def float_nosign_noexp(cls) -> Pattern[str]:
"""Regular expression to match an unsigned float without exponent."""
return cls._construct_regex(r"({float_num}|[{numeric}])")
-def regex_chooser(alg: NS_t) -> Pattern:
+def regex_chooser(alg: NS_t) -> Pattern[str]:
"""
Select an appropriate regex for the type of number of interest.
@@ -251,7 +260,13 @@ def natsort_key(
...
-def natsort_key(val, key, string_func, bytes_func, num_func):
+def natsort_key(
+ val: Union[NatsortInType, Any_t],
+ key: MaybeKeyType,
+ string_func: Union[StrParser, PathSplitter],
+ bytes_func: BytesTransformer,
+ num_func: MaybeNumTransformer,
+) -> NatsortOutType:
"""
Key to sort strings and numbers naturally.
@@ -300,23 +315,24 @@ def natsort_key(val, key, string_func, bytes_func, num_func):
# Assume the input are strings, which is the most common case
try:
- return string_func(val)
+ return string_func(cast(str, val))
except (TypeError, AttributeError):
# If bytes type, use the bytes_func
if type(val) in (bytes,):
- return bytes_func(val)
+ return bytes_func(cast(bytes, val))
# Otherwise, assume it is an iterable that must be parsed recursively.
# Do not apply the key recursively.
try:
return tuple(
- natsort_key(x, None, string_func, bytes_func, num_func) for x in val
+ natsort_key(x, None, string_func, bytes_func, num_func)
+ for x in cast(NatsortIterType, val)
)
# If that failed, it must be a number.
except TypeError:
- return num_func(val)
+ return num_func(cast(NumType, val))
def parse_bytes_factory(alg: NS_t) -> BytesTransformer:
@@ -502,7 +518,7 @@ def parse_path_factory(str_split: StrParser) -> PathSplitter:
return lambda x: tuple(map(str_split, path_splitter(x)))
-def sep_inserter(iterator: Iterator, sep: StrOrBytes) -> Iterator:
+def sep_inserter(iterator: Iterator[Any_t], sep: StrOrBytes) -> Iterator[Any_t]:
"""
Insert '' between numbers in an iterator.
@@ -801,7 +817,7 @@ def do_decoding(s: NatsortInType, encoding: str) -> NatsortInType:
...
-def do_decoding(s, encoding):
+def do_decoding(s: NatsortInType, encoding: str) -> NatsortInType:
"""
Helper to decode a *bytes* object, or return the object as-is.
@@ -819,7 +835,7 @@ def do_decoding(s, encoding):
"""
try:
- return s.decode(encoding)
+ return cast(bytes, s).decode(encoding)
except (AttributeError, TypeError):
return s