diff options
Diffstat (limited to 'pint/facets/nonmultiplicative/registry.py')
-rw-r--r-- | pint/facets/nonmultiplicative/registry.py | 90 |
1 files changed, 74 insertions, 16 deletions
diff --git a/pint/facets/nonmultiplicative/registry.py b/pint/facets/nonmultiplicative/registry.py index 8bc04db..505406c 100644 --- a/pint/facets/nonmultiplicative/registry.py +++ b/pint/facets/nonmultiplicative/registry.py @@ -8,16 +8,22 @@ from __future__ import annotations -from typing import Any +from typing import Any, TypeVar, Generic +from ...compat import TypeAlias from ...errors import DimensionalityError, UndefinedUnitError from ...util import UnitsContainer, logger -from ..plain import PlainRegistry, UnitDefinition +from ..plain import GenericPlainRegistry, UnitDefinition, QuantityT, UnitT from .definitions import OffsetConverter, ScaleConverter -from .objects import NonMultiplicativeQuantity +from . import objects -class NonMultiplicativeRegistry(PlainRegistry): +T = TypeVar("T") + + +class GenericNonMultiplicativeRegistry( + Generic[QuantityT, UnitT], GenericPlainRegistry[QuantityT, UnitT] +): """Handle of non multiplicative units (e.g. Temperature). Capabilities: @@ -35,8 +41,6 @@ class NonMultiplicativeRegistry(PlainRegistry): """ - Quantity = NonMultiplicativeQuantity - def __init__( self, default_as_delta: bool = True, @@ -58,14 +62,14 @@ class NonMultiplicativeRegistry(PlainRegistry): input_string: str, as_delta: bool | None = None, case_sensitive: bool | None = None, - ): + ) -> UnitsContainer: """ """ if as_delta is None: as_delta = self.default_as_delta return super()._parse_units(input_string, as_delta, case_sensitive) - def _add_unit(self, definition: UnitDefinition): + def _add_unit(self, definition: UnitDefinition) -> None: super()._add_unit(definition) if definition.is_multiplicative: @@ -104,22 +108,60 @@ class NonMultiplicativeRegistry(PlainRegistry): ) super()._add_unit(delta_def) - def _is_multiplicative(self, u) -> bool: - if u in self._units: - return self._units[u].is_multiplicative + def _is_multiplicative(self, unit_name: str) -> bool: + """True if the unit is multiplicative. + + Parameters + ---------- + unit_name + Name of the unit to check. + Can be prefixed, pluralized or even an alias + + Raises + ------ + UndefinedUnitError + If the unit is not in the registyr. + """ + if unit_name in self._units: + return self._units[unit_name].is_multiplicative # If the unit is not in the registry might be because it is not # registered with its prefixed version. # TODO: Might be better to register them. - names = self.parse_unit_name(u) + names = self.parse_unit_name(unit_name) assert len(names) == 1 _, base_name, _ = names[0] try: return self._units[base_name].is_multiplicative except KeyError: - raise UndefinedUnitError(u) + raise UndefinedUnitError(unit_name) + + def _validate_and_extract(self, units: UnitsContainer) -> str | None: + """Used to check if a given units is suitable for a simple + conversion. + + Return None if all units are non-multiplicative + Return the unit name if a single non-multiplicative unit is found + and is raised to a power equals to 1. + + Otherwise, raise an Exception. + + Parameters + ---------- + units + Compound dictionary. + + Raises + ------ + ValueError + If the more than a single non-multiplicative unit is present, + or a single one is present but raised to a power different from 1. + + """ + + # TODO: document what happens if autoconvert_offset_to_baseunit + # TODO: Clarify docs - def _validate_and_extract(self, units): # u is for unit, e is for exponent nonmult_units = [ (u, e) for u, e in units.items() if not self._is_multiplicative(u) @@ -147,11 +189,16 @@ class NonMultiplicativeRegistry(PlainRegistry): return None - def _add_ref_of_log_or_offset_unit(self, offset_unit, all_units): + def _add_ref_of_log_or_offset_unit( + self, offset_unit: str, all_units: UnitsContainer + ) -> UnitsContainer: slct_unit = self._units[offset_unit] if slct_unit.is_logarithmic or (not slct_unit.is_multiplicative): # Extract reference unit slct_ref = slct_unit.reference + + # TODO: Check that reference is None + # If reference unit is not dimensionless if slct_ref != UnitsContainer(): # Extract reference unit @@ -161,7 +208,9 @@ class NonMultiplicativeRegistry(PlainRegistry): # Otherwise, return the units unmodified return all_units - def _convert(self, value, src, dst, inplace=False): + def _convert( + self, value: T, src: UnitsContainer, dst: UnitsContainer, inplace: bool = False + ) -> T: """Convert value from some source to destination units. In addition to what is done by the PlainRegistry, @@ -235,3 +284,12 @@ class NonMultiplicativeRegistry(PlainRegistry): ) return value + + +class NonMultiplicativeRegistry( + GenericNonMultiplicativeRegistry[ + objects.NonMultiplicativeQuantity[Any], objects.NonMultiplicativeUnit + ] +): + Quantity: TypeAlias = objects.NonMultiplicativeQuantity[Any] + Unit: TypeAlias = objects.NonMultiplicativeUnit |