summaryrefslogtreecommitdiff
path: root/natsort/compat/locale.py
blob: b1b7c00931d813845be0e88ed78e6d8216380f8f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
# -*- coding: utf-8 -*-
"""
Interface for natsort to access locale functionality without
having to worry about if it is using PyICU or the built-in locale.
"""
import sys
from typing import Callable, Union, cast

StrOrBytes = Union[str, bytes]
TrxfmFunc = Callable[[str], StrOrBytes]

# This string should be sorted after any other byte string because
# it contains the max unicode character repeated 20 times.
# You would need some odd data to come after that.
null_string = ""
null_string_max = chr(sys.maxunicode) * 20

# This variable could be str or bytes depending on the locale library
# being used, so give the type-checker this information.
null_string_locale: StrOrBytes
null_string_locale_max: StrOrBytes

# strxfrm can be buggy (especially on OSX and *possibly* some other
# BSD-based systems), so prefer icu if available.
try:  # noqa: C901
    import icu
    from locale import getlocale

    null_string_locale = b""

    # This string should in theory be sorted after any other byte
    # string because it contains the max byte char repeated many times.
    # You would need some odd data to come after that.
    null_string_locale_max = b"x7f" * 50

    def dumb_sort() -> bool:
        return False

    # If using icu, get the locale from the current global locale,
    def get_icu_locale() -> str:
        language_code, encoding = getlocale()
        if language_code is None or encoding is None:  # pragma: no cover
            return icu.Locale()
        return icu.Locale(f"{language_code}.{encoding}")

    def get_strxfrm() -> TrxfmFunc:
        return icu.Collator.createInstance(get_icu_locale()).getSortKey

    def get_thousands_sep() -> str:
        sep = icu.DecimalFormatSymbols.kGroupingSeparatorSymbol
        return icu.DecimalFormatSymbols(get_icu_locale()).getSymbol(sep)

    def get_decimal_point() -> str:
        sep = icu.DecimalFormatSymbols.kDecimalSeparatorSymbol
        return icu.DecimalFormatSymbols(get_icu_locale()).getSymbol(sep)

except ImportError:
    import locale
    from locale import strxfrm

    null_string_locale = null_string
    null_string_locale_max = null_string_max

    # On some systems, locale is broken and does not sort in the expected
    # order. We will try to detect this and compensate.
    def dumb_sort() -> bool:
        return strxfrm("A") < strxfrm("a")

    def get_strxfrm() -> TrxfmFunc:
        return strxfrm

    def get_thousands_sep() -> str:
        sep = cast(str, locale.localeconv()["thousands_sep"])
        # If this locale library is broken, some of the thousands separator
        # characters are incorrectly blank. Here is a lookup table of the
        # corrections I am aware of.
        if dumb_sort():
            language_code, encoding = locale.getlocale()
            if language_code is None or encoding is None:
                # No locale loaded, default to ','
                return ","
            loc = f"{language_code}.{encoding}"
            return {
                "de_DE.ISO8859-15": ".",
                "es_ES.ISO8859-1": ".",
                "de_AT.ISO8859-1": ".",
                "de_at": "\xa0",
                "nl_NL.UTF-8": ".",
                "es_es": ".",
                "fr_CH.ISO8859-15": "\xa0",
                "fr_CA.ISO8859-1": "\xa0",
                "de_CH.ISO8859-1": ".",
                "fr_FR.ISO8859-15": "\xa0",
                "nl_NL.ISO8859-1": ".",
                "ca_ES.UTF-8": ".",
                "nl_NL.ISO8859-15": ".",
                "de_ch": "'",
                "ca_es": ".",
                "de_AT.ISO8859-15": ".",
                "ca_ES.ISO8859-1": ".",
                "de_AT.UTF-8": ".",
                "es_ES.UTF-8": ".",
                "fr_fr": "\xa0",
                "es_ES.ISO8859-15": ".",
                "de_DE.ISO8859-1": ".",
                "nl_nl": ".",
                "fr_ch": "\xa0",
                "fr_ca": "\xa0",
                "de_DE.UTF-8": ".",
                "ca_ES.ISO8859-15": ".",
                "de_CH.ISO8859-15": ".",
                "fr_FR.ISO8859-1": "\xa0",
                "fr_CH.ISO8859-1": "\xa0",
                "de_de": ".",
                "fr_FR.UTF-8": "\xa0",
                "fr_CA.ISO8859-15": "\xa0",
            }.get(loc, sep)
        else:
            return sep

    def get_decimal_point() -> str:
        return cast(str, locale.localeconv()["decimal_point"])