summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSeth Morton <seth.m.morton@gmail.com>2022-09-01 13:51:27 -0700
committerSeth Morton <seth.m.morton@gmail.com>2022-09-01 13:59:30 -0700
commitc634a4a8437de989b4394c48cd80e19ea82040f8 (patch)
tree758b831ad72636471e0630f38b676b86683f39e9
parent22cdc562baf60389a5f64a0ce241caf34d09abea (diff)
downloadnatsort-c634a4a8437de989b4394c48cd80e19ea82040f8.tar.gz
Add stubs for icu
This way we don't need to use casts in-code.
-rw-r--r--mypy_stubs/icu.pyi25
-rw-r--r--natsort/compat/locale.py11
-rw-r--r--setup.cfg4
3 files changed, 32 insertions, 8 deletions
diff --git a/mypy_stubs/icu.pyi b/mypy_stubs/icu.pyi
new file mode 100644
index 0000000..be46c19
--- /dev/null
+++ b/mypy_stubs/icu.pyi
@@ -0,0 +1,25 @@
+from __future__ import annotations
+
+from typing import overload
+@overload
+def Locale() -> str: ...
+@overload
+def Locale(x: str) -> str: ...
+
+class UCollAttribute:
+ NUMERIC_COLLATION: int
+
+class UCollAttributeValue:
+ ON: int
+
+class DecimalFormatSymbols:
+ kGroupingSeparatorSymbol: int
+ kDecimalSeparatorSymbol: int
+ def __init__(self, locale: str) -> None: ...
+ def getSymbol(self, symbol: int) -> str: ...
+
+class Collator:
+ @classmethod
+ def createInstance(cls, locale: str) -> Collator: ...
+ def getSortKey(self, source: str) -> bytes: ...
+ def setAttribute(self, attr: int, value: int) -> None: ...
diff --git a/natsort/compat/locale.py b/natsort/compat/locale.py
index 53080c3..8d7ae48 100644
--- a/natsort/compat/locale.py
+++ b/natsort/compat/locale.py
@@ -40,19 +40,20 @@ try: # noqa: C901
def get_icu_locale() -> str:
language_code, encoding = getlocale()
if language_code is None or encoding is None: # pragma: no cover
- return cast(str, icu.Locale())
- return cast(str, icu.Locale(f"{language_code}.{encoding}"))
+ return icu.Locale()
+ return icu.Locale(f"{language_code}.{encoding}")
def get_strxfrm() -> TrxfmFunc:
- return cast(TrxfmFunc, icu.Collator.createInstance(get_icu_locale()).getSortKey)
+ return icu.Collator.createInstance(get_icu_locale()).getSortKey
def get_thousands_sep() -> str:
sep = icu.DecimalFormatSymbols.kGroupingSeparatorSymbol
- return cast(str, icu.DecimalFormatSymbols(get_icu_locale()).getSymbol(sep))
+ return icu.DecimalFormatSymbols(get_icu_locale()).getSymbol(sep)
def get_decimal_point() -> str:
sep = icu.DecimalFormatSymbols.kDecimalSeparatorSymbol
- return cast(str, icu.DecimalFormatSymbols(get_icu_locale()).getSymbol(sep))
+ return icu.DecimalFormatSymbols(get_icu_locale()).getSymbol(sep)
+
except ImportError:
import locale
diff --git a/setup.cfg b/setup.cfg
index 6fd9d1d..a99791c 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -64,6 +64,4 @@ exclude =
.venv
[mypy]
-
-[mypy-icu]
-ignore_missing_imports = True
+mypy_path = mypy_stubs