diff options
Diffstat (limited to 'natsort/compat/locale.py')
-rw-r--r-- | natsort/compat/locale.py | 43 |
1 files changed, 25 insertions, 18 deletions
diff --git a/natsort/compat/locale.py b/natsort/compat/locale.py index ccb5592..9af5e7a 100644 --- a/natsort/compat/locale.py +++ b/natsort/compat/locale.py @@ -3,9 +3,11 @@ Interface for natsort to access locale functionality without having to worry about if it is using PyICU or the built-in locale. """ - -# Std. lib imports. import sys +from typing import Callable, Union, cast + +StrOrBytes = Union[str, bytes] +TrxfmFunc = Callable[[str], StrOrBytes] # This string should be sorted after any other byte string because # it contains the max unicode character repeated 20 times. @@ -13,6 +15,11 @@ import sys null_string = "" null_string_max = chr(sys.maxunicode) * 20 +# This variable could be str or bytes depending on the locale library +# being used, so give the type-checker this information. +null_string_locale: StrOrBytes +null_string_locale_max: StrOrBytes + # strxfrm can be buggy (especially on BSD-based systems), # so prefer icu if available. try: # noqa: C901 @@ -26,26 +33,26 @@ try: # noqa: C901 # You would need some odd data to come after that. null_string_locale_max = b"x7f" * 50 - def dumb_sort(): + def dumb_sort() -> bool: return False # If using icu, get the locale from the current global locale, - def get_icu_locale(): + def get_icu_locale() -> str: try: - return icu.Locale(".".join(getlocale())) + return cast(str, icu.Locale(".".join(getlocale()))) except TypeError: # pragma: no cover - return icu.Locale() + return cast(str, icu.Locale()) - def get_strxfrm(): - return icu.Collator.createInstance(get_icu_locale()).getSortKey + def get_strxfrm() -> TrxfmFunc: + return cast(TrxfmFunc, icu.Collator.createInstance(get_icu_locale()).getSortKey) - def get_thousands_sep(): + def get_thousands_sep() -> str: sep = icu.DecimalFormatSymbols.kGroupingSeparatorSymbol - return icu.DecimalFormatSymbols(get_icu_locale()).getSymbol(sep) + return cast(str, icu.DecimalFormatSymbols(get_icu_locale()).getSymbol(sep)) - def get_decimal_point(): + def get_decimal_point() -> str: sep = icu.DecimalFormatSymbols.kDecimalSeparatorSymbol - return icu.DecimalFormatSymbols(get_icu_locale()).getSymbol(sep) + return cast(str, icu.DecimalFormatSymbols(get_icu_locale()).getSymbol(sep)) except ImportError: @@ -57,14 +64,14 @@ except ImportError: # On some systems, locale is broken and does not sort in the expected # order. We will try to detect this and compensate. - def dumb_sort(): + def dumb_sort() -> bool: return strxfrm("A") < strxfrm("a") - def get_strxfrm(): + def get_strxfrm() -> TrxfmFunc: return strxfrm - def get_thousands_sep(): - sep = locale.localeconv()["thousands_sep"] + def get_thousands_sep() -> str: + sep = cast(str, locale.localeconv()["thousands_sep"]) # If this locale library is broken, some of the thousands separator # characters are incorrectly blank. Here is a lookup table of the # corrections I am aware of. @@ -111,5 +118,5 @@ except ImportError: else: return sep - def get_decimal_point(): - return locale.localeconv()["decimal_point"] + def get_decimal_point() -> str: + return cast(str, locale.localeconv()["decimal_point"]) |