summaryrefslogtreecommitdiff
path: root/pint/facets/nonmultiplicative/registry.py
blob: 7d783de11d53635cd0dc0c182ee2a80dc41d260f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
"""
    pint.facets.nonmultiplicative.registry
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    :copyright: 2022 by Pint Authors, see AUTHORS for more details.
    :license: BSD, see LICENSE for more details.
"""

from __future__ import annotations

from typing import Any, TypeVar, Generic, Optional

from ...compat import TypeAlias
from ...errors import DimensionalityError, UndefinedUnitError
from ...util import UnitsContainer, logger
from ..plain import GenericPlainRegistry, UnitDefinition, QuantityT, UnitT
from .definitions import OffsetConverter, ScaleConverter
from . import objects


T = TypeVar("T")


class GenericNonMultiplicativeRegistry(
    Generic[QuantityT, UnitT], GenericPlainRegistry[QuantityT, UnitT]
):
    """Handle non-multiplicative units (e.g. temperature).

    Capabilities:
    - Register non-multiplicative units and their relations.
    - Convert between non-multiplicative units.

    Parameters
    ----------
    default_as_delta : bool
        If True, non-multiplicative units are interpreted as
        their *delta* counterparts in multiplications.
    autoconvert_offset_to_baseunit : bool
        If True, non-multiplicative units are
        converted to plain units in multiplications.

    """

    def __init__(
        self,
        default_as_delta: bool = True,
        autoconvert_offset_to_baseunit: bool = False,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)

        #: When performing a multiplication of units, interpret
        #: non-multiplicative units as their *delta* counterparts.
        self.default_as_delta = default_as_delta

        # Determines if quantities with offset units are converted to their
        # plain units on multiplication and division.
        self.autoconvert_offset_to_baseunit = autoconvert_offset_to_baseunit

    def _parse_units(
        self,
        input_string: str,
        as_delta: Optional[bool] = None,
        case_sensitive: Optional[bool] = None,
    ) -> UnitsContainer:
        """Parse a units expression, filling in the registry's delta default.

        Parameters
        ----------
        input_string
            Units expression to parse.
        as_delta
            Forwarded to the parent parser. When None,
            ``self.default_as_delta`` is used instead.
        case_sensitive
            Forwarded unchanged to the parent parser.

        Returns
        -------
        UnitsContainer
            Whatever the parent ``_parse_units`` returns.
        """
        if as_delta is None:
            as_delta = self.default_as_delta

        return super()._parse_units(input_string, as_delta, case_sensitive)

    def _add_unit(self, definition: UnitDefinition) -> None:
        """Register ``definition`` and, for offset units, its delta twin.

        The unit itself is always registered through the parent. When the
        unit is non-multiplicative, not logarithmic, and uses an
        ``OffsetConverter``, a second unit named ``delta_<name>`` is also
        registered. It gets a ``Δ``-prefixed symbol and ``Δ``/``delta_``
        prefixed aliases, and it keeps only the scale of the original
        converter (a ``ScaleConverter``), with the same reference.
        """
        super()._add_unit(definition)

        if definition.is_multiplicative:
            return

        if definition.is_logarithmic:
            return

        if not isinstance(definition.converter, OffsetConverter):
            logger.debug(
                "Cannot autogenerate delta version for a unit in "
                "which the converter is not an OffsetConverter"
            )
            return

        delta_name = "delta_" + definition.name
        if definition.symbol:
            delta_symbol = "Δ" + definition.symbol
        else:
            delta_symbol = None

        delta_aliases = tuple("Δ" + alias for alias in definition.aliases) + tuple(
            "delta_" + alias for alias in definition.aliases
        )

        # Copy of the original reference; a difference of offset units has
        # the same dimensionality as the unit itself.
        delta_reference = self.UnitsContainer(
            {ref: value for ref, value in definition.reference.items()}
        )

        delta_def = UnitDefinition(
            delta_name,
            delta_symbol,
            delta_aliases,
            # Differences are unaffected by the offset: keep only the scale.
            ScaleConverter(definition.converter.scale),
            delta_reference,
        )
        super()._add_unit(delta_def)

    def _is_multiplicative(self, unit_name: str) -> bool:
        """True if the unit is multiplicative.

        Parameters
        ----------
        unit_name
            Name of the unit to check.
            Can be prefixed, pluralized or even an alias

        Raises
        ------
        UndefinedUnitError
            If the unit is not in the registry.
        """
        if unit_name in self._units:
            return self._units[unit_name].is_multiplicative

        # If the unit is not in the registry might be because it is not
        # registered with its prefixed version.
        # TODO: Might be better to register them.
        names = self.parse_unit_name(unit_name)
        if not names:
            # No candidate parse: report it as undefined instead of failing
            # an assert (asserts are stripped under ``python -O``).
            raise UndefinedUnitError(unit_name)
        # Internal invariant: a registered name parses unambiguously.
        assert len(names) == 1
        _, base_name, _ = names[0]
        try:
            return self._units[base_name].is_multiplicative
        except KeyError:
            raise UndefinedUnitError(unit_name) from None

    def _validate_and_extract(self, units: UnitsContainer) -> Optional[str]:
        """Check whether ``units`` is suitable for a simple conversion.

        Return None if all units are multiplicative.
        Return the unit name if a single non-multiplicative unit is found
        and it is raised to a power equal to 1.

        Otherwise, raise a ValueError.

        Parameters
        ----------
        units
            Compound dictionary.

        Raises
        ------
        ValueError
            If more than a single non-multiplicative unit is present,
            or a single one is present but raised to a power different
            from 1, or a single one is combined with other units while
            ``autoconvert_offset_to_baseunit`` is False.
        UndefinedUnitError
            Propagated from ``_is_multiplicative`` for unknown unit names.

        """

        # TODO: document what happens if autoconvert_offset_to_baseunit
        # TODO: Clarify docs

        # u is for unit, e is for exponent
        nonmult_units = [
            (u, e) for u, e in units.items() if not self._is_multiplicative(u)
        ]

        # Let's validate source offset units
        if len(nonmult_units) > 1:
            # More than one src offset unit is not allowed
            raise ValueError("more than one offset unit.")

        elif len(nonmult_units) == 1:
            # A single src offset unit is present. Extract it
            # But check that:
            # - the exponent is 1
            # - is not used in multiplicative context
            nonmult_unit, exponent = nonmult_units.pop()

            if exponent != 1:
                raise ValueError("offset units in higher order.")

            if len(units) > 1 and not self.autoconvert_offset_to_baseunit:
                raise ValueError("offset unit used in multiplicative context.")

            return nonmult_unit

        return None

    def _add_ref_of_log_or_offset_unit(
        self, offset_unit: str, all_units: UnitsContainer
    ) -> UnitsContainer:
        """Add the reference units of ``offset_unit`` to ``all_units``.

        Only applies when ``offset_unit`` is logarithmic or
        non-multiplicative; otherwise ``all_units`` is returned unmodified.
        Every entry of the unit's reference container is added with its
        exponent, so compound references are handled as well. A
        dimensionless (empty) reference adds nothing.

        Parameters
        ----------
        offset_unit
            Name of a unit registered in ``self._units``.
        all_units
            Units to which the reference is added.

        Returns
        -------
        UnitsContainer
            The result of the ``UnitsContainer.add`` calls, or ``all_units``
            unchanged.
        """
        slct_unit = self._units[offset_unit]
        if slct_unit.is_logarithmic or (not slct_unit.is_multiplicative):
            # TODO: Check that reference is None
            # Add every reference unit back to the unit list. The previous
            # implementation kept only one entry, which silently dropped
            # factors of a compound reference.
            for ref_unit, ref_exponent in slct_unit.reference.items():
                all_units = all_units.add(ref_unit, ref_exponent)
        return all_units

    def _convert(
        self, value: T, src: UnitsContainer, dst: UnitsContainer, inplace: bool = False
    ) -> T:
        """Convert value from some source to destination units.

        In addition to what is done by the PlainRegistry,
        converts between non-multiplicative units.

        Parameters
        ----------
        value :
            value
        src : UnitsContainer
            source units.
        dst : UnitsContainer
            destination units.
        inplace :
             (Default value = False)

        Returns
        -------
        T
            converted value

        Raises
        ------
        DimensionalityError
            If either side fails offset-unit validation, or if the source
            and destination dimensionalities differ.

        """

        # Conversion needs to consider if non-multiplicative (AKA offset
        # units) are involved. Conversion is only possible if src and dst
        # each have at most one offset unit. Other rules are applied
        # by validate and extract.
        try:
            src_offset_unit = self._validate_and_extract(src)
        except ValueError as ex:
            raise DimensionalityError(
                src, dst, extra_msg=f" - In source units, {ex}"
            ) from ex

        try:
            dst_offset_unit = self._validate_and_extract(dst)
        except ValueError as ex:
            raise DimensionalityError(
                src, dst, extra_msg=f" - In destination units, {ex}"
            ) from ex

        if not (src_offset_unit or dst_offset_unit):
            return super()._convert(value, src, dst, inplace)

        src_dim = self._get_dimensionality(src)
        dst_dim = self._get_dimensionality(dst)

        # If the source and destination dimensionality are different,
        # then the conversion cannot be performed.
        if src_dim != dst_dim:
            raise DimensionalityError(src, dst, src_dim, dst_dim)

        # clean src from offset units by converting to reference
        if src_offset_unit:
            value = self._units[src_offset_unit].converter.to_reference(value, inplace)
            src = src.remove([src_offset_unit])
            # Add reference unit for multiplicative section
            src = self._add_ref_of_log_or_offset_unit(src_offset_unit, src)

        # clean dst units from offset units
        if dst_offset_unit:
            dst = dst.remove([dst_offset_unit])
            # Add reference unit for multiplicative section
            dst = self._add_ref_of_log_or_offset_unit(dst_offset_unit, dst)

        # Convert the remaining multiplicative part; dimensionality was
        # already checked above, so the parent's check is disabled.
        value = super()._convert(value, src, dst, inplace, False)

        # Finally convert to offset units specified in destination
        if dst_offset_unit:
            value = self._units[dst_offset_unit].converter.from_reference(
                value, inplace
            )

        return value


class NonMultiplicativeRegistry(
    GenericNonMultiplicativeRegistry[
        objects.NonMultiplicativeQuantity[Any],
        objects.NonMultiplicativeUnit,
    ]
):
    """Concrete non-multiplicative registry.

    Specializes ``GenericNonMultiplicativeRegistry`` with
    ``objects.NonMultiplicativeQuantity[Any]`` and
    ``objects.NonMultiplicativeUnit`` and exposes those classes as the
    ``Quantity`` and ``Unit`` class attributes.
    """

    # Quantity class bound to this registry specialization.
    Quantity: TypeAlias = objects.NonMultiplicativeQuantity[Any]
    # Unit class bound to this registry specialization.
    Unit: TypeAlias = objects.NonMultiplicativeUnit