summaryrefslogtreecommitdiff
path: root/natsort/compat/locale.py
diff options
context:
space:
mode:
Diffstat (limited to 'natsort/compat/locale.py')
-rw-r--r--natsort/compat/locale.py43
1 files changed, 25 insertions, 18 deletions
diff --git a/natsort/compat/locale.py b/natsort/compat/locale.py
index ccb5592..9af5e7a 100644
--- a/natsort/compat/locale.py
+++ b/natsort/compat/locale.py
@@ -3,9 +3,11 @@
Interface for natsort to access locale functionality without
having to worry about if it is using PyICU or the built-in locale.
"""
-
-# Std. lib imports.
import sys
+from typing import Callable, Union, cast
+
+StrOrBytes = Union[str, bytes]
+TrxfmFunc = Callable[[str], StrOrBytes]
# This string should be sorted after any other byte string because
# it contains the max unicode character repeated 20 times.
@@ -13,6 +15,11 @@ import sys
null_string = ""
null_string_max = chr(sys.maxunicode) * 20
+# This variable could be str or bytes depending on the locale library
+# being used, so give the type-checker this information.
+null_string_locale: StrOrBytes
+null_string_locale_max: StrOrBytes
+
# strxfrm can be buggy (especially on BSD-based systems),
# so prefer icu if available.
try: # noqa: C901
@@ -26,26 +33,26 @@ try: # noqa: C901
# You would need some odd data to come after that.
null_string_locale_max = b"x7f" * 50
- def dumb_sort():
+ def dumb_sort() -> bool:
return False
# If using icu, get the locale from the current global locale,
- def get_icu_locale():
+ def get_icu_locale() -> str:
try:
- return icu.Locale(".".join(getlocale()))
+ return cast(str, icu.Locale(".".join(getlocale())))
except TypeError: # pragma: no cover
- return icu.Locale()
+ return cast(str, icu.Locale())
- def get_strxfrm():
- return icu.Collator.createInstance(get_icu_locale()).getSortKey
+ def get_strxfrm() -> TrxfmFunc:
+ return cast(TrxfmFunc, icu.Collator.createInstance(get_icu_locale()).getSortKey)
- def get_thousands_sep():
+ def get_thousands_sep() -> str:
sep = icu.DecimalFormatSymbols.kGroupingSeparatorSymbol
- return icu.DecimalFormatSymbols(get_icu_locale()).getSymbol(sep)
+ return cast(str, icu.DecimalFormatSymbols(get_icu_locale()).getSymbol(sep))
- def get_decimal_point():
+ def get_decimal_point() -> str:
sep = icu.DecimalFormatSymbols.kDecimalSeparatorSymbol
- return icu.DecimalFormatSymbols(get_icu_locale()).getSymbol(sep)
+ return cast(str, icu.DecimalFormatSymbols(get_icu_locale()).getSymbol(sep))
except ImportError:
@@ -57,14 +64,14 @@ except ImportError:
# On some systems, locale is broken and does not sort in the expected
# order. We will try to detect this and compensate.
- def dumb_sort():
+ def dumb_sort() -> bool:
return strxfrm("A") < strxfrm("a")
- def get_strxfrm():
+ def get_strxfrm() -> TrxfmFunc:
return strxfrm
- def get_thousands_sep():
- sep = locale.localeconv()["thousands_sep"]
+ def get_thousands_sep() -> str:
+ sep = cast(str, locale.localeconv()["thousands_sep"])
# If this locale library is broken, some of the thousands separator
# characters are incorrectly blank. Here is a lookup table of the
# corrections I am aware of.
@@ -111,5 +118,5 @@ except ImportError:
else:
return sep
- def get_decimal_point():
- return locale.localeconv()["decimal_point"]
+ def get_decimal_point() -> str:
+ return cast(str, locale.localeconv()["decimal_point"])