diff options
Diffstat (limited to 'babel/numbers.py')
-rw-r--r-- | babel/numbers.py | 45 |
1 files changed, 23 insertions, 22 deletions
diff --git a/babel/numbers.py b/babel/numbers.py index 59acee2..1a86d9e 100644 --- a/babel/numbers.py +++ b/babel/numbers.py @@ -23,7 +23,7 @@ import datetime import decimal import re import warnings -from typing import TYPE_CHECKING, Any, overload +from typing import TYPE_CHECKING, Any, cast, overload from babel.core import Locale, default_locale, get_global from babel.localedata import LocaleDataDict @@ -428,7 +428,7 @@ def get_decimal_quantum(precision: int | decimal.Decimal) -> decimal.Decimal: def format_decimal( number: float | decimal.Decimal | str, - format: str | None = None, + format: str | NumberPattern | None = None, locale: Locale | str | None = LC_NUMERIC, decimal_quantization: bool = True, group_separator: bool = True, @@ -474,8 +474,8 @@ def format_decimal( number format. """ locale = Locale.parse(locale) - if not format: - format = locale.decimal_formats.get(format) + if format is None: + format = locale.decimal_formats[format] pattern = parse_pattern(format) return pattern.apply( number, locale, decimal_quantization=decimal_quantization, group_separator=group_separator) @@ -513,7 +513,7 @@ def format_compact_decimal( number, format = _get_compact_format(number, compact_format, locale, fraction_digits) # Did not find a format, fall back. if format is None: - format = locale.decimal_formats.get(None) + format = locale.decimal_formats[None] pattern = parse_pattern(format) return pattern.apply(number, locale, decimal_quantization=False) @@ -521,7 +521,7 @@ def format_compact_decimal( def _get_compact_format( number: float | decimal.Decimal | str, compact_format: LocaleDataDict, - locale: Locale | str | None, + locale: Locale, fraction_digits: int, ) -> tuple[decimal.Decimal, NumberPattern | None]: """Returns the number after dividing by the unit and the format pattern to use. @@ -543,7 +543,7 @@ def _get_compact_format( break # otherwise, we need to divide the number by the magnitude but remove zeros # equal to the number of 0's in the pattern minus 1 - number = number / (magnitude // (10 ** (pattern.count("0") - 1))) + number = cast(decimal.Decimal, number / (magnitude // (10 ** (pattern.count("0") - 1)))) # round to the number of fraction digits requested rounded = round(number, fraction_digits) # if the remaining number is singular, use the singular format @@ -565,7 +565,7 @@ class UnknownCurrencyFormatError(KeyError): def format_currency( number: float | decimal.Decimal | str, currency: str, - format: str | None = None, + format: str | NumberPattern | None = None, locale: Locale | str | None = LC_NUMERIC, currency_digits: bool = True, format_type: Literal["name", "standard", "accounting"] = "standard", @@ -680,7 +680,7 @@ def format_currency( def _format_currency_long_name( number: float | decimal.Decimal | str, currency: str, - format: str | None = None, + format: str | NumberPattern | None = None, locale: Locale | str | None = LC_NUMERIC, currency_digits: bool = True, format_type: Literal["name", "standard", "accounting"] = "standard", @@ -706,7 +706,7 @@ def _format_currency_long_name( # Step 5. if not format: - format = locale.decimal_formats.get(format) + format = locale.decimal_formats[format] pattern = parse_pattern(format) @@ -758,13 +758,15 @@ def format_compact_currency( # compress adjacent spaces into one format = re.sub(r'(\s)\s+', r'\1', format).strip() break + if format is None: + raise ValueError('No compact currency format found for the given number and locale.') pattern = parse_pattern(format) return pattern.apply(number, locale, currency=currency, currency_digits=False, decimal_quantization=False) def format_percent( number: float | decimal.Decimal | str, - format: str | None = None, + format: str | NumberPattern | None = None, locale: Locale | str | None = LC_NUMERIC, decimal_quantization: bool = True, group_separator: bool = True, @@ -808,7 +810,7 @@ def format_percent( """ locale = Locale.parse(locale) if not format: - format = locale.percent_formats.get(format) + format = locale.percent_formats[format] pattern = parse_pattern(format) return pattern.apply( number, locale, decimal_quantization=decimal_quantization, group_separator=group_separator) @@ -816,7 +818,7 @@ def format_percent( def format_scientific( number: float | decimal.Decimal | str, - format: str | None = None, + format: str | NumberPattern | None = None, locale: Locale | str | None = LC_NUMERIC, decimal_quantization: bool = True, ) -> str: @@ -847,7 +849,7 @@ def format_scientific( """ locale = Locale.parse(locale) if not format: - format = locale.scientific_formats.get(format) + format = locale.scientific_formats[format] pattern = parse_pattern(format) return pattern.apply( number, locale, decimal_quantization=decimal_quantization) @@ -856,7 +858,7 @@ def format_scientific( class NumberFormatError(ValueError): """Exception raised when a string cannot be parsed into a number.""" - def __init__(self, message: str, suggestions: str | None = None) -> None: + def __init__(self, message: str, suggestions: list[str] | None = None) -> None: super().__init__(message) #: a list of properly formatted numbers derived from the invalid input self.suggestions = suggestions @@ -1140,7 +1142,7 @@ class NumberPattern: def apply( self, - value: float | decimal.Decimal, + value: float | decimal.Decimal | str, locale: Locale | str | None, currency: str | None = None, currency_digits: bool = True, @@ -1211,9 +1213,9 @@ class NumberPattern: number = ''.join([ self._quantize_value(value, locale, frac_prec, group_separator), get_exponential_symbol(locale), - exp_sign, - self._format_int( - str(exp), self.exp_prec[0], self.exp_prec[1], locale)]) + exp_sign, # type: ignore # exp_sign is always defined here + self._format_int(str(exp), self.exp_prec[0], self.exp_prec[1], locale) # type: ignore # exp is always defined here + ]) # Is it a significant digits pattern? elif '@' in self.pattern: @@ -1234,9 +1236,8 @@ class NumberPattern: number if self.number_pattern != '' else '', self.suffix[is_negative]]) - if '¤' in retval: - retval = retval.replace('¤¤¤', - get_currency_name(currency, value, locale)) + if '¤' in retval and currency is not None: + retval = retval.replace('¤¤¤', get_currency_name(currency, value, locale)) retval = retval.replace('¤¤', currency.upper()) retval = retval.replace('¤', get_currency_symbol(currency, locale)) |