summaryrefslogtreecommitdiff
path: root/pint/facets/nonmultiplicative/registry.py
diff options
context:
space:
mode:
Diffstat (limited to 'pint/facets/nonmultiplicative/registry.py')
-rw-r--r--pint/facets/nonmultiplicative/registry.py90
1 files changed, 74 insertions, 16 deletions
diff --git a/pint/facets/nonmultiplicative/registry.py b/pint/facets/nonmultiplicative/registry.py
index 8bc04db..505406c 100644
--- a/pint/facets/nonmultiplicative/registry.py
+++ b/pint/facets/nonmultiplicative/registry.py
@@ -8,16 +8,22 @@
from __future__ import annotations
-from typing import Any
+from typing import Any, TypeVar, Generic
+from ...compat import TypeAlias
from ...errors import DimensionalityError, UndefinedUnitError
from ...util import UnitsContainer, logger
-from ..plain import PlainRegistry, UnitDefinition
+from ..plain import GenericPlainRegistry, UnitDefinition, QuantityT, UnitT
from .definitions import OffsetConverter, ScaleConverter
-from .objects import NonMultiplicativeQuantity
+from . import objects
-class NonMultiplicativeRegistry(PlainRegistry):
+T = TypeVar("T")
+
+
+class GenericNonMultiplicativeRegistry(
+ Generic[QuantityT, UnitT], GenericPlainRegistry[QuantityT, UnitT]
+):
"""Handle of non multiplicative units (e.g. Temperature).
Capabilities:
@@ -35,8 +41,6 @@ class NonMultiplicativeRegistry(PlainRegistry):
"""
- Quantity = NonMultiplicativeQuantity
-
def __init__(
self,
default_as_delta: bool = True,
@@ -58,14 +62,14 @@ class NonMultiplicativeRegistry(PlainRegistry):
input_string: str,
as_delta: bool | None = None,
case_sensitive: bool | None = None,
- ):
+ ) -> UnitsContainer:
""" """
if as_delta is None:
as_delta = self.default_as_delta
return super()._parse_units(input_string, as_delta, case_sensitive)
- def _add_unit(self, definition: UnitDefinition):
+ def _add_unit(self, definition: UnitDefinition) -> None:
super()._add_unit(definition)
if definition.is_multiplicative:
@@ -104,22 +108,60 @@ class NonMultiplicativeRegistry(PlainRegistry):
)
super()._add_unit(delta_def)
- def _is_multiplicative(self, u) -> bool:
- if u in self._units:
- return self._units[u].is_multiplicative
+ def _is_multiplicative(self, unit_name: str) -> bool:
+ """True if the unit is multiplicative.
+
+ Parameters
+ ----------
+ unit_name
+ Name of the unit to check.
+ Can be prefixed, pluralized or even an alias
+
+ Raises
+ ------
+ UndefinedUnitError
+ If the unit is not in the registyr.
+ """
+ if unit_name in self._units:
+ return self._units[unit_name].is_multiplicative
# If the unit is not in the registry might be because it is not
# registered with its prefixed version.
# TODO: Might be better to register them.
- names = self.parse_unit_name(u)
+ names = self.parse_unit_name(unit_name)
assert len(names) == 1
_, base_name, _ = names[0]
try:
return self._units[base_name].is_multiplicative
except KeyError:
- raise UndefinedUnitError(u)
+ raise UndefinedUnitError(unit_name)
+
+ def _validate_and_extract(self, units: UnitsContainer) -> str | None:
+ """Used to check if a given units is suitable for a simple
+ conversion.
+
+ Return None if all units are non-multiplicative
+ Return the unit name if a single non-multiplicative unit is found
+ and is raised to a power equals to 1.
+
+ Otherwise, raise an Exception.
+
+ Parameters
+ ----------
+ units
+ Compound dictionary.
+
+ Raises
+ ------
+ ValueError
+ If the more than a single non-multiplicative unit is present,
+ or a single one is present but raised to a power different from 1.
+
+ """
+
+ # TODO: document what happens if autoconvert_offset_to_baseunit
+ # TODO: Clarify docs
- def _validate_and_extract(self, units):
# u is for unit, e is for exponent
nonmult_units = [
(u, e) for u, e in units.items() if not self._is_multiplicative(u)
@@ -147,11 +189,16 @@ class NonMultiplicativeRegistry(PlainRegistry):
return None
- def _add_ref_of_log_or_offset_unit(self, offset_unit, all_units):
+ def _add_ref_of_log_or_offset_unit(
+ self, offset_unit: str, all_units: UnitsContainer
+ ) -> UnitsContainer:
slct_unit = self._units[offset_unit]
if slct_unit.is_logarithmic or (not slct_unit.is_multiplicative):
# Extract reference unit
slct_ref = slct_unit.reference
+
+ # TODO: Check that reference is None
+
# If reference unit is not dimensionless
if slct_ref != UnitsContainer():
# Extract reference unit
@@ -161,7 +208,9 @@ class NonMultiplicativeRegistry(PlainRegistry):
# Otherwise, return the units unmodified
return all_units
- def _convert(self, value, src, dst, inplace=False):
+ def _convert(
+ self, value: T, src: UnitsContainer, dst: UnitsContainer, inplace: bool = False
+ ) -> T:
"""Convert value from some source to destination units.
In addition to what is done by the PlainRegistry,
@@ -235,3 +284,12 @@ class NonMultiplicativeRegistry(PlainRegistry):
)
return value
+
+
+class NonMultiplicativeRegistry(
+ GenericNonMultiplicativeRegistry[
+ objects.NonMultiplicativeQuantity[Any], objects.NonMultiplicativeUnit
+ ]
+):
+ Quantity: TypeAlias = objects.NonMultiplicativeQuantity[Any]
+ Unit: TypeAlias = objects.NonMultiplicativeUnit