summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--pint/_typing.py53
-rw-r--r--pint/compat.py14
-rw-r--r--pint/converters.py14
-rw-r--r--pint/delegates/txt_defparser/defparser.py2
-rw-r--r--pint/facets/__init__.py33
-rw-r--r--pint/facets/context/__init__.py4
-rw-r--r--pint/facets/context/definitions.py8
-rw-r--r--pint/facets/context/objects.py97
-rw-r--r--pint/facets/context/registry.py48
-rw-r--r--pint/facets/dask/__init__.py29
-rw-r--r--pint/facets/formatting/__init__.py9
-rw-r--r--pint/facets/formatting/objects.py6
-rw-r--r--pint/facets/formatting/registry.py21
-rw-r--r--pint/facets/group/__init__.py13
-rw-r--r--pint/facets/group/definitions.py2
-rw-r--r--pint/facets/group/objects.py37
-rw-r--r--pint/facets/group/registry.py50
-rw-r--r--pint/facets/measurement/__init__.py9
-rw-r--r--pint/facets/measurement/objects.py9
-rw-r--r--pint/facets/measurement/registry.py21
-rw-r--r--pint/facets/nonmultiplicative/__init__.py6
-rw-r--r--pint/facets/nonmultiplicative/objects.py10
-rw-r--r--pint/facets/nonmultiplicative/registry.py90
-rw-r--r--pint/facets/numpy/__init__.py4
-rw-r--r--pint/facets/numpy/quantity.py16
-rw-r--r--pint/facets/numpy/registry.py17
-rw-r--r--pint/facets/plain/__init__.py7
-rw-r--r--pint/facets/plain/definitions.py44
-rw-r--r--pint/facets/plain/qto.py386
-rw-r--r--pint/facets/plain/quantity.py493
-rw-r--r--pint/facets/plain/registry.py329
-rw-r--r--pint/facets/system/__init__.py4
-rw-r--r--pint/facets/system/definitions.py2
-rw-r--r--pint/facets/system/objects.py43
-rw-r--r--pint/facets/system/registry.py80
-rw-r--r--pint/formatting.py63
-rw-r--r--pint/registry.py59
-rw-r--r--pint/registry_helpers.py8
-rw-r--r--pint/testing.py6
-rw-r--r--pint/util.py55
40 files changed, 1368 insertions, 833 deletions
diff --git a/pint/_typing.py b/pint/_typing.py
index 65e355c..5177e78 100644
--- a/pint/_typing.py
+++ b/pint/_typing.py
@@ -1,9 +1,9 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Callable, TypeVar, Union, Protocol
+from decimal import Decimal
+from fractions import Fraction
-# TODO: Remove when 3.11 becomes minimal version.
-Self = TypeVar("Self")
if TYPE_CHECKING:
from .facets.plain import PlainQuantity as Quantity
@@ -11,7 +11,7 @@ if TYPE_CHECKING:
from .util import UnitsContainer
-class PintScalar(Protocol):
+class ScalarProtocol(Protocol):
def __add__(self, other: Any) -> Any:
...
@@ -36,8 +36,20 @@ class PintScalar(Protocol):
def __pow__(self, other: Any, modulo: Any) -> Any:
...
+ def __gt__(self, other: Any) -> bool:
+ ...
+
+ def __lt__(self, other: Any) -> bool:
+ ...
+
+ def __ge__(self, other: Any) -> bool:
+ ...
-class PintArray(Protocol):
+ def __le__(self, other: Any) -> bool:
+ ...
+
+
+class ArrayProtocol(Protocol):
def __len__(self) -> int:
...
@@ -48,18 +60,41 @@ class PintArray(Protocol):
...
+HAS_NUMPY = False
+if TYPE_CHECKING:
+ from .compat import HAS_NUMPY
+
+if HAS_NUMPY:
+ from .compat import np
+
+ Scalar = Union[ScalarProtocol, float, int, Decimal, Fraction, np.number[Any]]
+ Array = Union[np.ndarray[Any, Any]]
+else:
+ Scalar = Union[ScalarProtocol, float, int, Decimal, Fraction]
+ Array = ArrayProtocol
+
+
# TODO: Change when Python 3.10 becomes minimal version.
-# Magnitude = PintScalar | PintArray
-Magnitude = Union[PintScalar, PintArray]
+Magnitude = Union[ScalarProtocol, ArrayProtocol]
-UnitLike = Union[str, "UnitsContainer", "Unit"]
+UnitLike = Union[str, dict[str, Scalar], "UnitsContainer", "Unit"]
QuantityOrUnitLike = Union["Quantity", UnitLike]
-Shape = tuple[int, ...]
+Shape = tuple[int]
-_MagnitudeType = TypeVar("_MagnitudeType")
S = TypeVar("S")
FuncType = Callable[..., Any]
F = TypeVar("F", bound=FuncType)
+
+
+# TODO: Improve or delete types
+QuantityArgument = Any
+
+T = TypeVar("T")
+
+
+class Handler(Protocol):
+ def __getitem__(self, item: type[T]) -> Callable[[T], None]:
+ ...
diff --git a/pint/compat.py b/pint/compat.py
index 7b48efa..727ff99 100644
--- a/pint/compat.py
+++ b/pint/compat.py
@@ -20,6 +20,16 @@ from collections.abc import Mapping
from typing import Any, NoReturn, Callable
from collections.abc import Generator, Iterable
+try:
+ from typing import TypeAlias # noqa
+except ImportError:
+ from typing_extensions import TypeAlias # noqa
+
+try:
+ from typing import Self # noqa
+except ImportError:
+ from typing_extensions import Self # noqa
+
def missing_dependency(
package: str, display_name: str | None = None
@@ -137,10 +147,10 @@ except ImportError:
HAS_UNCERTAINTIES = False
try:
- from babel import Locale as Loc
+ from babel import Locale
from babel import units as babel_units
- babel_parse = Loc.parse
+ babel_parse = Locale.parse
HAS_BABEL = hasattr(babel_units, "format_unit")
except ImportError:
diff --git a/pint/converters.py b/pint/converters.py
index 9494ad1..822b8a0 100644
--- a/pint/converters.py
+++ b/pint/converters.py
@@ -15,16 +15,18 @@ from dataclasses import fields as dc_fields
from typing import Any
-from ._typing import Self, Magnitude
+from ._typing import Magnitude
-from .compat import HAS_NUMPY, exp, log # noqa: F401
+from .compat import HAS_NUMPY, exp, log, Self # noqa: F401
@dataclass(frozen=True)
class Converter:
"""Base class for value converters."""
+ # list[type[Converter]]
_subclasses = []
+ # dict[frozenset[str], type[Converter]]
_param_names_to_subclass = {}
@property
@@ -41,21 +43,21 @@ class Converter:
def from_reference(self, value: Magnitude, inplace: bool = False) -> Magnitude:
return value
- def __init_subclass__(cls, **kwargs):
+ def __init_subclass__(cls, **kwargs: Any):
# Get constructor parameters
super().__init_subclass__(**kwargs)
cls._subclasses.append(cls)
@classmethod
- def get_field_names(cls, new_cls) -> frozenset[str]:
+ def get_field_names(cls, new_cls: type) -> frozenset[str]:
return frozenset(p.name for p in dc_fields(new_cls))
@classmethod
- def preprocess_kwargs(cls, **kwargs):
+ def preprocess_kwargs(cls, **kwargs: Any) -> dict[str, Any] | None:
return None
@classmethod
- def from_arguments(cls: type[Self], **kwargs: Any) -> Self:
+ def from_arguments(cls, **kwargs: Any) -> Converter:
kwk = frozenset(kwargs.keys())
try:
new_cls = cls._param_names_to_subclass[kwk]
diff --git a/pint/delegates/txt_defparser/defparser.py b/pint/delegates/txt_defparser/defparser.py
index f1b8e45..4acea2f 100644
--- a/pint/delegates/txt_defparser/defparser.py
+++ b/pint/delegates/txt_defparser/defparser.py
@@ -130,7 +130,7 @@ class DefParser:
else:
yield stmt
- def parse_file(self, filename: pathlib.Path, cfg: ParserConfig | None = None):
+ def parse_file(self, filename: pathlib.Path | str, cfg: ParserConfig | None = None):
return fp.parse(
filename,
_PintParser,
diff --git a/pint/facets/__init__.py b/pint/facets/__init__.py
index 750f729..4fd1597 100644
--- a/pint/facets/__init__.py
+++ b/pint/facets/__init__.py
@@ -71,15 +71,18 @@
from __future__ import annotations
-from .context import ContextRegistry
-from .dask import DaskRegistry
-from .formatting import FormattingRegistry
-from .group import GroupRegistry
-from .measurement import MeasurementRegistry
-from .nonmultiplicative import NonMultiplicativeRegistry
-from .numpy import NumpyRegistry
-from .plain import PlainRegistry
-from .system import SystemRegistry
+from .context import ContextRegistry, GenericContextRegistry
+from .dask import DaskRegistry, GenericDaskRegistry
+from .formatting import FormattingRegistry, GenericFormattingRegistry
+from .group import GroupRegistry, GenericGroupRegistry
+from .measurement import MeasurementRegistry, GenericMeasurementRegistry
+from .nonmultiplicative import (
+ NonMultiplicativeRegistry,
+ GenericNonMultiplicativeRegistry,
+)
+from .numpy import NumpyRegistry, GenericNumpyRegistry
+from .plain import PlainRegistry, GenericPlainRegistry, QuantityT, UnitT, MagnitudeT
+from .system import SystemRegistry, GenericSystemRegistry
__all__ = [
"ContextRegistry",
@@ -91,4 +94,16 @@ __all__ = [
"NumpyRegistry",
"PlainRegistry",
"SystemRegistry",
+ "GenericContextRegistry",
+ "GenericDaskRegistry",
+ "GenericFormattingRegistry",
+ "GenericGroupRegistry",
+ "GenericMeasurementRegistry",
+ "GenericNonMultiplicativeRegistry",
+ "GenericNumpyRegistry",
+ "GenericPlainRegistry",
+ "GenericSystemRegistry",
+ "QuantityT",
+ "UnitT",
+ "MagnitudeT",
]
diff --git a/pint/facets/context/__init__.py b/pint/facets/context/__init__.py
index db28436..28c7b5c 100644
--- a/pint/facets/context/__init__.py
+++ b/pint/facets/context/__init__.py
@@ -13,6 +13,6 @@ from __future__ import annotations
from .definitions import ContextDefinition
from .objects import Context
-from .registry import ContextRegistry
+from .registry import ContextRegistry, GenericContextRegistry
-__all__ = ["ContextDefinition", "Context", "ContextRegistry"]
+__all__ = ["ContextDefinition", "Context", "ContextRegistry", "GenericContextRegistry"]
diff --git a/pint/facets/context/definitions.py b/pint/facets/context/definitions.py
index 833857e..d2581d5 100644
--- a/pint/facets/context/definitions.py
+++ b/pint/facets/context/definitions.py
@@ -12,7 +12,7 @@ import itertools
import numbers
import re
from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Callable
+from typing import TYPE_CHECKING, Callable
from collections.abc import Iterable
from ... import errors
@@ -47,7 +47,7 @@ class Relation:
return set(self._varname_re.findall(self.equation))
@property
- def transformation(self) -> Callable[..., Quantity[Any]]:
+ def transformation(self) -> Callable[..., Quantity]:
"""Return a transformation callable that uses the registry
to parse the transformation equation.
"""
@@ -68,7 +68,7 @@ class ForwardRelation(Relation):
"""
@property
- def bidirectional(self):
+ def bidirectional(self) -> bool:
return False
@@ -82,7 +82,7 @@ class BidirectionalRelation(Relation):
"""
@property
- def bidirectional(self):
+ def bidirectional(self) -> bool:
return True
diff --git a/pint/facets/context/objects.py b/pint/facets/context/objects.py
index 38d8805..9517821 100644
--- a/pint/facets/context/objects.py
+++ b/pint/facets/context/objects.py
@@ -10,12 +10,32 @@ from __future__ import annotations
import weakref
from collections import ChainMap, defaultdict
-from typing import Any
+from typing import Any, Callable, Protocol, Generic
from collections.abc import Iterable
-from ...facets.plain import UnitDefinition
+from ...facets.plain import UnitDefinition, PlainQuantity, PlainUnit, MagnitudeT
from ...util import UnitsContainer, to_units_container
from .definitions import ContextDefinition
+from ..._typing import Magnitude
+
+
+class Transformation(Protocol):
+ def __call__(self, value: Magnitude, **kwargs: Any) -> Magnitude:
+ ...
+
+
+from ..._typing import UnitLike
+
+ToBaseFunc = Callable[[UnitsContainer], UnitsContainer]
+SrcDst = tuple[UnitsContainer, UnitsContainer]
+
+
+class ContextQuantity(Generic[MagnitudeT], PlainQuantity[MagnitudeT]):
+ pass
+
+
+class ContextUnit(PlainUnit):
+ pass
class Context:
@@ -75,24 +95,27 @@ class Context:
aliases: tuple[str] = tuple(),
defaults: dict[str, Any] | None = None,
) -> None:
- self.name = name
- self.aliases = aliases
+ self.name: str | None = name
+ self.aliases: tuple[str] = aliases
#: Maps (src, dst) -> transformation function
- self.funcs = {}
+ self.funcs: dict[SrcDst, Transformation] = {}
#: Maps defaults variable names to values
- self.defaults = defaults or {}
+ self.defaults: dict[str, Any] = defaults or {}
# Store Definition objects that are context-specific
- self.redefinitions = []
+ # TODO: narrow type this if possible.
+ self.redefinitions: list[Any] = []
# Flag set to True by the Registry the first time the context is enabled
self.checked = False
#: Maps (src, dst) -> self
#: Used as a convenience dictionary to be composed by ContextChain
- self.relation_to_context = weakref.WeakValueDictionary()
+ self.relation_to_context: weakref.WeakValueDictionary[
+ SrcDst, Context
+ ] = weakref.WeakValueDictionary()
@classmethod
def from_context(cls, context: Context, **defaults: Any) -> Context:
@@ -125,13 +148,22 @@ class Context:
@classmethod
def from_lines(
- cls, lines: Iterable[str], to_base_func=None, non_int_type: type = float
+ cls,
+ lines: Iterable[str],
+ to_base_func: ToBaseFunc | None = None,
+ non_int_type: type = float,
) -> Context:
- cd = ContextDefinition.from_lines(lines, non_int_type)
- return cls.from_definition(cd, to_base_func)
+ context_definition = ContextDefinition.from_lines(lines, non_int_type)
+
+ if context_definition is None:
+ raise ValueError(f"Could not define Context from from {lines}")
+
+ return cls.from_definition(context_definition, to_base_func)
@classmethod
- def from_definition(cls, cd: ContextDefinition, to_base_func=None) -> Context:
+ def from_definition(
+ cls, cd: ContextDefinition, to_base_func: ToBaseFunc | None = None
+ ) -> Context:
ctx = cls(cd.name, cd.aliases, cd.defaults)
for definition in cd.redefinitions:
@@ -139,6 +171,7 @@ class Context:
for relation in cd.relations:
try:
+ # TODO: check to_base_func. Is it a good API idea?
if to_base_func:
src = to_base_func(relation.src)
dst = to_base_func(relation.dst)
@@ -154,14 +187,16 @@ class Context:
return ctx
- def add_transformation(self, src, dst, func) -> None:
+ def add_transformation(
+ self, src: UnitLike, dst: UnitLike, func: Transformation
+ ) -> None:
"""Add a transformation function to the context."""
_key = self.__keytransform__(src, dst)
self.funcs[_key] = func
self.relation_to_context[_key] = self
- def remove_transformation(self, src, dst) -> None:
+ def remove_transformation(self, src: UnitLike, dst: UnitLike) -> None:
"""Add a transformation function to the context."""
_key = self.__keytransform__(src, dst)
@@ -169,14 +204,17 @@ class Context:
del self.relation_to_context[_key]
@staticmethod
- def __keytransform__(src, dst) -> tuple[UnitsContainer, UnitsContainer]:
+ def __keytransform__(src: UnitLike, dst: UnitLike) -> SrcDst:
return to_units_container(src), to_units_container(dst)
- def transform(self, src, dst, registry, value):
+ def transform(
+ self, src: UnitLike, dst: UnitLike, registry: Any, value: Magnitude
+ ) -> Magnitude:
"""Transform a value."""
_key = self.__keytransform__(src, dst)
- return self.funcs[_key](registry, value, **self.defaults)
+ func = self.funcs[_key]
+ return func(registry, value, **self.defaults)
def redefine(self, definition: str) -> None:
"""Override the definition of a unit in the registry.
@@ -202,7 +240,13 @@ class Context:
def hashable(
self,
- ) -> tuple[str | None, tuple[str, ...], frozenset, frozenset, tuple]:
+ ) -> tuple[
+ str | None,
+ tuple[str],
+ frozenset[tuple[SrcDst, int]],
+ frozenset[tuple[str, Any]],
+ tuple[Any],
+ ]:
"""Generate a unique hashable and comparable representation of self, which can
be used as a key in a dict. This class cannot define ``__hash__`` because it is
mutable, and the Python interpreter does cache the output of ``__hash__``.
@@ -220,18 +264,18 @@ class Context:
)
-class ContextChain(ChainMap):
+class ContextChain(ChainMap[SrcDst, Context]):
"""A specialized ChainMap for contexts that simplifies finding rules
to transform from one dimension to another.
"""
def __init__(self):
super().__init__()
- self.contexts = []
+ self.contexts: list[Context] = []
self.maps.clear() # Remove default empty map
- self._graph = None
+ self._graph: dict[SrcDst, set[UnitsContainer]] | None = None
- def insert_contexts(self, *contexts):
+ def insert_contexts(self, *contexts: Context):
"""Insert one or more contexts in reversed order the chained map.
(A rule in last context will take precedence)
@@ -243,7 +287,7 @@ class ContextChain(ChainMap):
self.maps = [ctx.relation_to_context for ctx in reversed(contexts)] + self.maps
self._graph = None
- def remove_contexts(self, n: int = None):
+ def remove_contexts(self, n: int | None = None):
"""Remove the last n inserted contexts from the chain.
Parameters
@@ -257,7 +301,7 @@ class ContextChain(ChainMap):
self._graph = None
@property
- def defaults(self):
+ def defaults(self) -> dict[str, Any]:
for ctx in self.values():
return ctx.defaults
return {}
@@ -271,7 +315,10 @@ class ContextChain(ChainMap):
self._graph[fr_].add(to_)
return self._graph
- def transform(self, src, dst, registry, value):
+ # TODO: type registry
+ def transform(
+ self, src: UnitsContainer, dst: UnitsContainer, registry: Any, value: Magnitude
+ ):
"""Transform the value, finding the rule in the chained context.
(A rule in last context will take precedence)
"""
diff --git a/pint/facets/context/registry.py b/pint/facets/context/registry.py
index a36d82d..746e79c 100644
--- a/pint/facets/context/registry.py
+++ b/pint/facets/context/registry.py
@@ -11,12 +11,13 @@ from __future__ import annotations
import functools
from collections import ChainMap
from contextlib import contextmanager
-from typing import Any, Callable, ContextManager
+from typing import Any, Callable, Generator, Generic
-from ..._typing import F
+from ...compat import TypeAlias
+from ..._typing import F, Magnitude
from ...errors import UndefinedUnitError
-from ...util import find_connected_nodes, find_shortest_path, logger
-from ..plain import PlainRegistry, UnitDefinition
+from ...util import find_connected_nodes, find_shortest_path, logger, UnitsContainer
+from ..plain import GenericPlainRegistry, UnitDefinition, QuantityT, UnitT
from .definitions import ContextDefinition
from . import objects
@@ -36,7 +37,9 @@ class ContextCacheOverlay:
self.parse_unit = registry_cache.parse_unit
-class ContextRegistry(PlainRegistry):
+class GenericContextRegistry(
+ Generic[QuantityT, UnitT], GenericPlainRegistry[QuantityT, UnitT]
+):
"""Handle of Contexts.
Conversion between units with different dimensions according
@@ -50,7 +53,7 @@ class ContextRegistry(PlainRegistry):
- Parse @context directive.
"""
- Context = objects.Context
+ Context: type[objects.Context] = objects.Context
def __init__(self, **kwargs: Any) -> None:
# Map context name (string) or abbreviation to context.
@@ -65,13 +68,13 @@ class ContextRegistry(PlainRegistry):
super().__init__(**kwargs)
# Allow contexts to add override layers to the units
- self._units = ChainMap(self._units)
+ self._units: ChainMap[str, UnitDefinition] = ChainMap(self._units)
def _register_definition_adders(self) -> None:
super()._register_definition_adders()
self._register_adder(ContextDefinition, self.add_context)
- def add_context(self, context: Context | ContextDefinition) -> None:
+ def add_context(self, context: objects.Context | ContextDefinition) -> None:
"""Add a context object to the registry.
The context will be accessible by its name and aliases.
@@ -194,7 +197,7 @@ class ContextRegistry(PlainRegistry):
self.define(definition)
def enable_contexts(
- self, *names_or_contexts: str | objects.Context, **kwargs
+ self, *names_or_contexts: str | objects.Context, **kwargs: Any
) -> None:
"""Enable contexts provided by name or by object.
@@ -241,7 +244,7 @@ class ContextRegistry(PlainRegistry):
self._active_ctx.insert_contexts(*contexts)
self._switch_context_cache_and_units()
- def disable_contexts(self, n: int = None) -> None:
+ def disable_contexts(self, n: int | None = None) -> None:
"""Disable the last n enabled contexts.
Parameters
@@ -253,7 +256,9 @@ class ContextRegistry(PlainRegistry):
self._switch_context_cache_and_units()
@contextmanager
- def context(self, *names, **kwargs) -> ContextManager[objects.Context]:
+ def context(
+ self: GenericContextRegistry[QuantityT, UnitT], *names: str, **kwargs: Any
+ ) -> Generator[GenericContextRegistry[QuantityT, UnitT], None, None]:
"""Used as a context manager, this function enables to activate a context
which is removed after usage.
@@ -309,7 +314,7 @@ class ContextRegistry(PlainRegistry):
# the added contexts are removed from the active one.
self.disable_contexts(len(names))
- def with_context(self, name, **kwargs) -> Callable[[F], F]:
+ def with_context(self, name: str, **kwargs: Any) -> Callable[[F], F]:
"""Decorator to wrap a function call in a Pint context.
Use it to ensure that a certain context is active when
@@ -351,7 +356,13 @@ class ContextRegistry(PlainRegistry):
return decorator
- def _convert(self, value, src, dst, inplace=False):
+ def _convert(
+ self,
+ value: Magnitude,
+ src: UnitsContainer,
+ dst: UnitsContainer,
+ inplace: bool = False,
+ ) -> Magnitude:
"""Convert value from some source to destination units.
In addition to what is done by the PlainRegistry,
@@ -391,7 +402,9 @@ class ContextRegistry(PlainRegistry):
return super()._convert(value, src, dst, inplace)
- def _get_compatible_units(self, input_units, group_or_system):
+ def _get_compatible_units(
+ self, input_units: UnitsContainer, group_or_system: str | None = None
+ ):
src_dim = self._get_dimensionality(input_units)
ret = super()._get_compatible_units(input_units, group_or_system)
@@ -404,3 +417,10 @@ class ContextRegistry(PlainRegistry):
ret |= self._cache.dimensional_equivalents[node]
return ret
+
+
+class ContextRegistry(
+ GenericContextRegistry[objects.ContextQuantity[Any], objects.ContextUnit]
+):
+ Quantity: TypeAlias = objects.ContextQuantity[Any]
+ Unit: TypeAlias = objects.ContextUnit
diff --git a/pint/facets/dask/__init__.py b/pint/facets/dask/__init__.py
index 90c8972..8d62f55 100644
--- a/pint/facets/dask/__init__.py
+++ b/pint/facets/dask/__init__.py
@@ -11,10 +11,18 @@
from __future__ import annotations
+from typing import Generic, Any
import functools
-from ...compat import compute, dask_array, persist, visualize
-from ..plain import PlainRegistry, PlainQuantity
+from ...compat import compute, dask_array, persist, visualize, TypeAlias
+from ..plain import (
+ GenericPlainRegistry,
+ PlainQuantity,
+ QuantityT,
+ UnitT,
+ PlainUnit,
+ MagnitudeT,
+)
def check_dask_array(f):
@@ -31,7 +39,7 @@ def check_dask_array(f):
return wrapper
-class DaskQuantity(PlainQuantity):
+class DaskQuantity(Generic[MagnitudeT], PlainQuantity[MagnitudeT]):
# Dask.array.Array ducking
def __dask_graph__(self):
if isinstance(self._magnitude, dask_array.Array):
@@ -119,5 +127,16 @@ class DaskQuantity(PlainQuantity):
visualize(self, **kwargs)
-class DaskRegistry(PlainRegistry):
- Quantity = DaskQuantity
+class DaskUnit(PlainUnit):
+ pass
+
+
+class GenericDaskRegistry(
+ Generic[QuantityT, UnitT], GenericPlainRegistry[QuantityT, UnitT]
+):
+ pass
+
+
+class DaskRegistry(GenericDaskRegistry[DaskQuantity[Any], DaskUnit]):
+ Quantity: TypeAlias = DaskQuantity[Any]
+ Unit: TypeAlias = DaskUnit
diff --git a/pint/facets/formatting/__init__.py b/pint/facets/formatting/__init__.py
index e3f4381..799fa31 100644
--- a/pint/facets/formatting/__init__.py
+++ b/pint/facets/formatting/__init__.py
@@ -11,6 +11,11 @@
from __future__ import annotations
from .objects import FormattingQuantity, FormattingUnit
-from .registry import FormattingRegistry
+from .registry import FormattingRegistry, GenericFormattingRegistry
-__all__ = ["FormattingQuantity", "FormattingUnit", "FormattingRegistry"]
+__all__ = [
+ "FormattingQuantity",
+ "FormattingUnit",
+ "FormattingRegistry",
+ "GenericFormattingRegistry",
+]
diff --git a/pint/facets/formatting/objects.py b/pint/facets/formatting/objects.py
index 5df937c..7d39e91 100644
--- a/pint/facets/formatting/objects.py
+++ b/pint/facets/formatting/objects.py
@@ -9,7 +9,7 @@
from __future__ import annotations
import re
-from typing import Any
+from typing import Any, Generic
from ...compat import babel_parse, ndarray, np
from ...formatting import (
@@ -23,10 +23,10 @@ from ...formatting import (
)
from ...util import UnitsContainer, iterable
-from ..plain import PlainQuantity, PlainUnit
+from ..plain import PlainQuantity, PlainUnit, MagnitudeT
-class FormattingQuantity(PlainQuantity):
+class FormattingQuantity(Generic[MagnitudeT], PlainQuantity[MagnitudeT]):
_exp_pattern = re.compile(r"([0-9]\.?[0-9]*)e(-?)\+?0*([0-9]+)")
def __format__(self, spec: str) -> str:
diff --git a/pint/facets/formatting/registry.py b/pint/facets/formatting/registry.py
index c4dc373..7684597 100644
--- a/pint/facets/formatting/registry.py
+++ b/pint/facets/formatting/registry.py
@@ -8,10 +8,21 @@
from __future__ import annotations
-from ..plain import PlainRegistry
-from .objects import FormattingQuantity, FormattingUnit
+from typing import Generic, Any
+from ...compat import TypeAlias
+from ..plain import GenericPlainRegistry, QuantityT, UnitT
+from . import objects
-class FormattingRegistry(PlainRegistry):
- Quantity = FormattingQuantity
- Unit = FormattingUnit
+
+class GenericFormattingRegistry(
+ Generic[QuantityT, UnitT], GenericPlainRegistry[QuantityT, UnitT]
+):
+ pass
+
+
+class FormattingRegistry(
+ GenericFormattingRegistry[objects.FormattingQuantity[Any], objects.FormattingUnit]
+):
+ Quantity: TypeAlias = objects.FormattingQuantity[Any]
+ Unit: TypeAlias = objects.FormattingUnit
diff --git a/pint/facets/group/__init__.py b/pint/facets/group/__init__.py
index e1fad04..b25ea85 100644
--- a/pint/facets/group/__init__.py
+++ b/pint/facets/group/__init__.py
@@ -11,7 +11,14 @@
from __future__ import annotations
from .definitions import GroupDefinition
-from .objects import Group
-from .registry import GroupRegistry
+from .objects import Group, GroupQuantity, GroupUnit
+from .registry import GroupRegistry, GenericGroupRegistry
-__all__ = ["GroupDefinition", "Group", "GroupRegistry"]
+__all__ = [
+ "GroupDefinition",
+ "Group",
+ "GroupRegistry",
+ "GenericGroupRegistry",
+ "GroupQuantity",
+ "GroupUnit",
+]
diff --git a/pint/facets/group/definitions.py b/pint/facets/group/definitions.py
index 554a63b..2f34750 100644
--- a/pint/facets/group/definitions.py
+++ b/pint/facets/group/definitions.py
@@ -11,7 +11,7 @@ from __future__ import annotations
from collections.abc import Iterable
from dataclasses import dataclass
-from ..._typing import Self
+from ...compat import Self
from ... import errors
from .. import plain
diff --git a/pint/facets/group/objects.py b/pint/facets/group/objects.py
index 200a323..64d91c1 100644
--- a/pint/facets/group/objects.py
+++ b/pint/facets/group/objects.py
@@ -8,9 +8,36 @@
from __future__ import annotations
+from typing import Callable, Any, TYPE_CHECKING, Generic
+
from collections.abc import Generator, Iterable
from ...util import SharedRegistryObject, getattr_maybe_raise
from .definitions import GroupDefinition
+from ..plain import PlainQuantity, PlainUnit, MagnitudeT
+
+if TYPE_CHECKING:
+ from ..plain import UnitDefinition
+
+ DefineFunc = Callable[
+ [
+ Any,
+ ],
+ None,
+ ]
+ AddUnitFunc = Callable[
+ [
+ UnitDefinition,
+ ],
+ None,
+ ]
+
+
+class GroupQuantity(Generic[MagnitudeT], PlainQuantity[MagnitudeT]):
+ pass
+
+
+class GroupUnit(PlainUnit):
+ pass
class Group(SharedRegistryObject):
@@ -57,7 +84,7 @@ class Group(SharedRegistryObject):
self._computed_members: frozenset[str] | None = None
@property
- def members(self):
+ def members(self) -> frozenset[str]:
"""Names of the units that are members of the group.
Calculated to include to all units in all included _used_groups.
@@ -143,7 +170,7 @@ class Group(SharedRegistryObject):
@classmethod
def from_lines(
- cls, lines: Iterable[str], define_func, non_int_type: type = float
+ cls, lines: Iterable[str], define_func: DefineFunc, non_int_type: type = float
) -> Group:
"""Return a Group object parsing an iterable of lines.
@@ -160,11 +187,15 @@ class Group(SharedRegistryObject):
"""
group_definition = GroupDefinition.from_lines(lines, non_int_type)
+
+ if group_definition is None:
+ raise ValueError(f"Could not define group from {lines}")
+
return cls.from_definition(group_definition, define_func)
@classmethod
def from_definition(
- cls, group_definition: GroupDefinition, add_unit_func=None
+ cls, group_definition: GroupDefinition, add_unit_func: AddUnitFunc | None = None
) -> Group:
grp = cls(group_definition.name)
diff --git a/pint/facets/group/registry.py b/pint/facets/group/registry.py
index 0d35ae0..f130e61 100644
--- a/pint/facets/group/registry.py
+++ b/pint/facets/group/registry.py
@@ -8,20 +8,28 @@
from __future__ import annotations
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Generic, Any
+from ...compat import TypeAlias
from ... import errors
if TYPE_CHECKING:
- from ..._typing import Unit
-
-from ...util import create_class_with_registry
-from ..plain import PlainRegistry, UnitDefinition
+ from ..._typing import Unit, UnitsContainer
+
+from ...util import create_class_with_registry, to_units_container
+from ..plain import (
+ GenericPlainRegistry,
+ UnitDefinition,
+ QuantityT,
+ UnitT,
+)
from .definitions import GroupDefinition
from . import objects
-class GroupRegistry(PlainRegistry):
+class GenericGroupRegistry(
+ Generic[QuantityT, UnitT], GenericPlainRegistry[QuantityT, UnitT]
+):
"""Handle of Groups.
Group units
@@ -34,7 +42,7 @@ class GroupRegistry(PlainRegistry):
# TODO: Change this to Group: Group to specify class
# and use introspection to get system class as a way
# to enjoy typing goodies
- Group = objects.Group
+ Group = type[objects.Group]
def __init__(self, **kwargs):
super().__init__(**kwargs)
@@ -46,7 +54,7 @@ class GroupRegistry(PlainRegistry):
def _init_dynamic_classes(self) -> None:
"""Generate subclasses on the fly and attach them to self"""
super()._init_dynamic_classes()
- self.Group = create_class_with_registry(self, self.Group)
+ self.Group = create_class_with_registry(self, objects.Group)
def _after_init(self) -> None:
"""Invoked at the end of ``__init__``.
@@ -113,8 +121,23 @@ class GroupRegistry(PlainRegistry):
return self.Group(name)
- def _get_compatible_units(self, input_units, group) -> frozenset[Unit]:
- ret = super()._get_compatible_units(input_units, group)
+ def get_compatible_units(
+ self, input_units: UnitsContainer, group: str | None = None
+ ) -> frozenset[Unit]:
+ """ """
+ if group is None:
+ return super().get_compatible_units(input_units)
+
+ input_units = to_units_container(input_units)
+
+ equiv = self._get_compatible_units(input_units, group)
+
+ return frozenset(self.Unit(eq) for eq in equiv)
+
+ def _get_compatible_units(
+ self, input_units: UnitsContainer, group: str | None = None
+ ) -> frozenset[str]:
+ ret = super()._get_compatible_units(input_units)
if not group:
return ret
@@ -124,3 +147,10 @@ class GroupRegistry(PlainRegistry):
else:
raise ValueError("Unknown Group with name '%s'" % group)
return frozenset(ret & members)
+
+
+class GroupRegistry(
+ GenericGroupRegistry[objects.GroupQuantity[Any], objects.GroupUnit]
+):
+ Quantity: TypeAlias = objects.GroupQuantity[Any]
+ Unit: TypeAlias = objects.GroupUnit
diff --git a/pint/facets/measurement/__init__.py b/pint/facets/measurement/__init__.py
index 21539dc..d36a5c3 100644
--- a/pint/facets/measurement/__init__.py
+++ b/pint/facets/measurement/__init__.py
@@ -11,6 +11,11 @@
from __future__ import annotations
from .objects import Measurement, MeasurementQuantity
-from .registry import MeasurementRegistry
+from .registry import MeasurementRegistry, GenericMeasurementRegistry
-__all__ = ["Measurement", "MeasurementQuantity", "MeasurementRegistry"]
+__all__ = [
+ "Measurement",
+ "MeasurementQuantity",
+ "MeasurementRegistry",
+ "GenericMeasurementRegistry",
+]
diff --git a/pint/facets/measurement/objects.py b/pint/facets/measurement/objects.py
index 5f3ba7a..b9cacda 100644
--- a/pint/facets/measurement/objects.py
+++ b/pint/facets/measurement/objects.py
@@ -10,15 +10,16 @@ from __future__ import annotations
import copy
import re
+from typing import Generic
from ...compat import ufloat
from ...formatting import _FORMATS, extract_custom_flags, siunitx_format_unit
-from ..plain import PlainQuantity
+from ..plain import PlainQuantity, PlainUnit, MagnitudeT
MISSING = object()
-class MeasurementQuantity(PlainQuantity):
+class MeasurementQuantity(Generic[MagnitudeT], PlainQuantity[MagnitudeT]):
# Measurement support
def plus_minus(self, error, relative=False):
if isinstance(error, self.__class__):
@@ -32,6 +33,10 @@ class MeasurementQuantity(PlainQuantity):
return self._REGISTRY.Measurement(copy.copy(self.magnitude), error, self._units)
+class MeasurementUnit(PlainUnit):
+ pass
+
+
class Measurement(PlainQuantity):
"""Implements a class to describe a quantity with uncertainty.
diff --git a/pint/facets/measurement/registry.py b/pint/facets/measurement/registry.py
index 0fc4391..4a3e878 100644
--- a/pint/facets/measurement/registry.py
+++ b/pint/facets/measurement/registry.py
@@ -9,15 +9,17 @@
from __future__ import annotations
-from ...compat import ufloat
+from typing import Generic, Any
+
+from ...compat import ufloat, TypeAlias
from ...util import create_class_with_registry
-from ..plain import PlainRegistry
-from .objects import MeasurementQuantity
+from ..plain import GenericPlainRegistry, QuantityT, UnitT
from . import objects
-class MeasurementRegistry(PlainRegistry):
- Quantity = MeasurementQuantity
+class GenericMeasurementRegistry(
+ Generic[QuantityT, UnitT], GenericPlainRegistry[QuantityT, UnitT]
+):
Measurement = objects.Measurement
def _init_dynamic_classes(self) -> None:
@@ -34,3 +36,12 @@ class MeasurementRegistry(PlainRegistry):
)
self.Measurement = no_uncertainties
+
+
+class MeasurementRegistry(
+ GenericMeasurementRegistry[
+ objects.MeasurementQuantity[Any], objects.MeasurementUnit
+ ]
+):
+ Quantity: TypeAlias = objects.MeasurementQuantity[Any]
+ Unit: TypeAlias = objects.MeasurementUnit
diff --git a/pint/facets/nonmultiplicative/__init__.py b/pint/facets/nonmultiplicative/__init__.py
index cbba410..eb3292b 100644
--- a/pint/facets/nonmultiplicative/__init__.py
+++ b/pint/facets/nonmultiplicative/__init__.py
@@ -15,8 +15,6 @@ from __future__ import annotations
# This import register LogarithmicConverter and OffsetConverter to be usable
# (via subclassing)
from .definitions import LogarithmicConverter, OffsetConverter # noqa: F401
-from .registry import NonMultiplicativeRegistry
+from .registry import NonMultiplicativeRegistry, GenericNonMultiplicativeRegistry
-__all__ = [
- "NonMultiplicativeRegistry",
-]
+__all__ = ["NonMultiplicativeRegistry", "GenericNonMultiplicativeRegistry"]
diff --git a/pint/facets/nonmultiplicative/objects.py b/pint/facets/nonmultiplicative/objects.py
index 0ab743e..8b944b1 100644
--- a/pint/facets/nonmultiplicative/objects.py
+++ b/pint/facets/nonmultiplicative/objects.py
@@ -8,10 +8,12 @@
from __future__ import annotations
-from ..plain import PlainQuantity
+from typing import Generic
+from ..plain import PlainQuantity, PlainUnit, MagnitudeT
-class NonMultiplicativeQuantity(PlainQuantity):
+
+class NonMultiplicativeQuantity(Generic[MagnitudeT], PlainQuantity[MagnitudeT]):
@property
def _is_multiplicative(self) -> bool:
"""Check if the PlainQuantity object has only multiplicative units."""
@@ -59,3 +61,7 @@ class NonMultiplicativeQuantity(PlainQuantity):
if next(iter(self._units.values())) != 1:
is_ok = False
return is_ok
+
+
+class NonMultiplicativeUnit(PlainUnit):
+ pass
diff --git a/pint/facets/nonmultiplicative/registry.py b/pint/facets/nonmultiplicative/registry.py
index 8bc04db..505406c 100644
--- a/pint/facets/nonmultiplicative/registry.py
+++ b/pint/facets/nonmultiplicative/registry.py
@@ -8,16 +8,22 @@
from __future__ import annotations
-from typing import Any
+from typing import Any, TypeVar, Generic
+from ...compat import TypeAlias
from ...errors import DimensionalityError, UndefinedUnitError
from ...util import UnitsContainer, logger
-from ..plain import PlainRegistry, UnitDefinition
+from ..plain import GenericPlainRegistry, UnitDefinition, QuantityT, UnitT
from .definitions import OffsetConverter, ScaleConverter
-from .objects import NonMultiplicativeQuantity
+from . import objects
-class NonMultiplicativeRegistry(PlainRegistry):
+T = TypeVar("T")
+
+
+class GenericNonMultiplicativeRegistry(
+ Generic[QuantityT, UnitT], GenericPlainRegistry[QuantityT, UnitT]
+):
"""Handle of non multiplicative units (e.g. Temperature).
Capabilities:
@@ -35,8 +41,6 @@ class NonMultiplicativeRegistry(PlainRegistry):
"""
- Quantity = NonMultiplicativeQuantity
-
def __init__(
self,
default_as_delta: bool = True,
@@ -58,14 +62,14 @@ class NonMultiplicativeRegistry(PlainRegistry):
input_string: str,
as_delta: bool | None = None,
case_sensitive: bool | None = None,
- ):
+ ) -> UnitsContainer:
""" """
if as_delta is None:
as_delta = self.default_as_delta
return super()._parse_units(input_string, as_delta, case_sensitive)
- def _add_unit(self, definition: UnitDefinition):
+ def _add_unit(self, definition: UnitDefinition) -> None:
super()._add_unit(definition)
if definition.is_multiplicative:
@@ -104,22 +108,60 @@ class NonMultiplicativeRegistry(PlainRegistry):
)
super()._add_unit(delta_def)
- def _is_multiplicative(self, u) -> bool:
- if u in self._units:
- return self._units[u].is_multiplicative
+ def _is_multiplicative(self, unit_name: str) -> bool:
+ """True if the unit is multiplicative.
+
+ Parameters
+ ----------
+ unit_name
+ Name of the unit to check.
+ Can be prefixed, pluralized or even an alias
+
+ Raises
+ ------
+ UndefinedUnitError
+ If the unit is not in the registyr.
+ """
+ if unit_name in self._units:
+ return self._units[unit_name].is_multiplicative
# If the unit is not in the registry might be because it is not
# registered with its prefixed version.
# TODO: Might be better to register them.
- names = self.parse_unit_name(u)
+ names = self.parse_unit_name(unit_name)
assert len(names) == 1
_, base_name, _ = names[0]
try:
return self._units[base_name].is_multiplicative
except KeyError:
- raise UndefinedUnitError(u)
+ raise UndefinedUnitError(unit_name)
+
+ def _validate_and_extract(self, units: UnitsContainer) -> str | None:
+ """Used to check if a given units is suitable for a simple
+ conversion.
+
+ Return None if all units are non-multiplicative
+ Return the unit name if a single non-multiplicative unit is found
+ and is raised to a power equals to 1.
+
+ Otherwise, raise an Exception.
+
+ Parameters
+ ----------
+ units
+ Compound dictionary.
+
+ Raises
+ ------
+ ValueError
+ If the more than a single non-multiplicative unit is present,
+ or a single one is present but raised to a power different from 1.
+
+ """
+
+ # TODO: document what happens if autoconvert_offset_to_baseunit
+ # TODO: Clarify docs
- def _validate_and_extract(self, units):
# u is for unit, e is for exponent
nonmult_units = [
(u, e) for u, e in units.items() if not self._is_multiplicative(u)
@@ -147,11 +189,16 @@ class NonMultiplicativeRegistry(PlainRegistry):
return None
- def _add_ref_of_log_or_offset_unit(self, offset_unit, all_units):
+ def _add_ref_of_log_or_offset_unit(
+ self, offset_unit: str, all_units: UnitsContainer
+ ) -> UnitsContainer:
slct_unit = self._units[offset_unit]
if slct_unit.is_logarithmic or (not slct_unit.is_multiplicative):
# Extract reference unit
slct_ref = slct_unit.reference
+
+ # TODO: Check that reference is None
+
# If reference unit is not dimensionless
if slct_ref != UnitsContainer():
# Extract reference unit
@@ -161,7 +208,9 @@ class NonMultiplicativeRegistry(PlainRegistry):
# Otherwise, return the units unmodified
return all_units
- def _convert(self, value, src, dst, inplace=False):
+ def _convert(
+ self, value: T, src: UnitsContainer, dst: UnitsContainer, inplace: bool = False
+ ) -> T:
"""Convert value from some source to destination units.
In addition to what is done by the PlainRegistry,
@@ -235,3 +284,12 @@ class NonMultiplicativeRegistry(PlainRegistry):
)
return value
+
+
+class NonMultiplicativeRegistry(
+ GenericNonMultiplicativeRegistry[
+ objects.NonMultiplicativeQuantity[Any], objects.NonMultiplicativeUnit
+ ]
+):
+ Quantity: TypeAlias = objects.NonMultiplicativeQuantity[Any]
+ Unit: TypeAlias = objects.NonMultiplicativeUnit
diff --git a/pint/facets/numpy/__init__.py b/pint/facets/numpy/__init__.py
index aad9508..2e38dc1 100644
--- a/pint/facets/numpy/__init__.py
+++ b/pint/facets/numpy/__init__.py
@@ -10,6 +10,6 @@
from __future__ import annotations
-from .registry import NumpyRegistry
+from .registry import NumpyRegistry, GenericNumpyRegistry
-__all__ = ["NumpyRegistry"]
+__all__ = ["NumpyRegistry", "GenericNumpyRegistry"]
diff --git a/pint/facets/numpy/quantity.py b/pint/facets/numpy/quantity.py
index 131983c..880f860 100644
--- a/pint/facets/numpy/quantity.py
+++ b/pint/facets/numpy/quantity.py
@@ -11,11 +11,11 @@ from __future__ import annotations
import functools
import math
import warnings
-from typing import Any
+from typing import Any, Generic
-from ..plain import PlainQuantity
+from ..plain import PlainQuantity, MagnitudeT
-from ..._typing import Shape, _MagnitudeType
+from ..._typing import Shape
from ...compat import _to_magnitude, np
from ...errors import DimensionalityError, PintTypeError, UnitStrippedWarning
from .numpy_func import (
@@ -42,7 +42,7 @@ def method_wraps(numpy_func):
return wrapper
-class NumpyQuantity(PlainQuantity):
+class NumpyQuantity(Generic[MagnitudeT], PlainQuantity[MagnitudeT]):
""" """
# NumPy function/ufunc support
@@ -130,11 +130,11 @@ class NumpyQuantity(PlainQuantity):
raise DimensionalityError("dimensionless", self._units)
return self.__class__(self.magnitude.clip(min, max, out, **kwargs), self._units)
- def fill(self: NumpyQuantity[np.ndarray], value) -> None:
+ def fill(self: NumpyQuantity, value) -> None:
self._units = value._units
return self.magnitude.fill(value.magnitude)
- def put(self: NumpyQuantity[np.ndarray], indices, values, mode="raise") -> None:
+ def put(self: NumpyQuantity, indices, values, mode="raise") -> None:
if isinstance(values, self.__class__):
values = values.to(self).magnitude
elif self.dimensionless:
@@ -144,11 +144,11 @@ class NumpyQuantity(PlainQuantity):
self.magnitude.put(indices, values, mode)
@property
- def real(self) -> NumpyQuantity[_MagnitudeType]:
+ def real(self) -> NumpyQuantity:
return self.__class__(self._magnitude.real, self._units)
@property
- def imag(self) -> NumpyQuantity[_MagnitudeType]:
+ def imag(self) -> NumpyQuantity:
return self.__class__(self._magnitude.imag, self._units)
@property
diff --git a/pint/facets/numpy/registry.py b/pint/facets/numpy/registry.py
index 11d57f3..e93de44 100644
--- a/pint/facets/numpy/registry.py
+++ b/pint/facets/numpy/registry.py
@@ -9,11 +9,20 @@
from __future__ import annotations
-from ..plain import PlainRegistry
+from typing import Generic, Any
+
+from ...compat import TypeAlias
+from ..plain import GenericPlainRegistry, QuantityT, UnitT
from .quantity import NumpyQuantity
from .unit import NumpyUnit
-class NumpyRegistry(PlainRegistry):
- Quantity = NumpyQuantity
- Unit = NumpyUnit
+class GenericNumpyRegistry(
+ Generic[QuantityT, UnitT], GenericPlainRegistry[QuantityT, UnitT]
+):
+ pass
+
+
+class NumpyRegistry(GenericPlainRegistry[NumpyQuantity[Any], NumpyUnit]):
+ Quantity: TypeAlias = NumpyQuantity[Any]
+ Unit: TypeAlias = NumpyUnit
diff --git a/pint/facets/plain/__init__.py b/pint/facets/plain/__init__.py
index 211d017..90bf2e3 100644
--- a/pint/facets/plain/__init__.py
+++ b/pint/facets/plain/__init__.py
@@ -19,9 +19,11 @@ from .definitions import (
UnitDefinition,
)
from .objects import PlainQuantity, PlainUnit
-from .registry import PlainRegistry
+from .registry import PlainRegistry, GenericPlainRegistry, QuantityT, UnitT
+from .quantity import MagnitudeT
__all__ = [
+ "GenericPlainRegistry",
"PlainUnit",
"PlainQuantity",
"PlainRegistry",
@@ -31,4 +33,7 @@ __all__ = [
"PrefixDefinition",
"ScaleConverter",
"UnitDefinition",
+ "QuantityT",
+ "UnitT",
+ "MagnitudeT",
]
diff --git a/pint/facets/plain/definitions.py b/pint/facets/plain/definitions.py
index 79a44f1..4b352e7 100644
--- a/pint/facets/plain/definitions.py
+++ b/pint/facets/plain/definitions.py
@@ -13,7 +13,7 @@ import numbers
import typing as ty
from dataclasses import dataclass
from functools import cached_property
-from typing import Callable, Any
+from typing import Any
from ..._typing import Magnitude
from ... import errors
@@ -69,11 +69,15 @@ class DefaultsDefinition:
@dataclass(frozen=True)
-class PrefixDefinition(errors.WithDefErr):
- """Definition of a prefix."""
-
+class NamedDefinition:
#: name of the prefix
name: str
+
+
+@dataclass(frozen=True)
+class PrefixDefinition(NamedDefinition, errors.WithDefErr):
+ """Definition of a prefix."""
+
#: scaling value for this prefix
value: numbers.Number
#: canonical symbol
@@ -90,8 +94,8 @@ class PrefixDefinition(errors.WithDefErr):
return bool(self.defined_symbol)
@cached_property
- def converter(self):
- return Converter.from_arguments(scale=self.value)
+ def converter(self) -> ScaleConverter:
+ return ScaleConverter(self.value)
def __post_init__(self):
if not errors.is_valid_prefix_name(self.name):
@@ -110,22 +114,19 @@ class PrefixDefinition(errors.WithDefErr):
@dataclass(frozen=True)
-class UnitDefinition(errors.WithDefErr):
+class UnitDefinition(NamedDefinition, errors.WithDefErr):
"""Definition of a unit."""
- #: canonical name of the unit
- name: str
#: canonical symbol
defined_symbol: str | None
#: additional names for the same unit
aliases: tuple[str]
#: A functiont that converts a value in these units into the reference units
- converter: Callable[
- [
- Magnitude,
- ],
- Magnitude,
- ] | Converter | None
+ # TODO: this has changed as converter is now annotated as converter.
+ # Briefly, in several places converter attributes like as_multiplicative were
+ # accesed. So having a generic function is a no go.
+ # I guess this was never used as errors where not raised.
+ converter: Converter | None
#: Reference units.
reference: UnitsContainer | None
@@ -190,7 +191,7 @@ class UnitDefinition(errors.WithDefErr):
def is_base(self) -> bool:
"""Indicates if it is a base unit."""
- # TODO: why is this here
+ # TODO: This is set in __post_init__
return self._is_base
@property
@@ -215,17 +216,14 @@ class UnitDefinition(errors.WithDefErr):
@dataclass(frozen=True)
-class DimensionDefinition(errors.WithDefErr):
+class DimensionDefinition(NamedDefinition, errors.WithDefErr):
"""Definition of a root dimension"""
- #: name of the dimension
- name: str
-
@property
- def is_base(self):
+ def is_base(self) -> bool:
return True
- def __post_init__(self):
+ def __post_init__(self) -> None:
if not errors.is_valid_dimension_name(self.name):
raise self.def_err(errors.MSG_INVALID_DIMENSION_NAME)
@@ -238,7 +236,7 @@ class DerivedDimensionDefinition(DimensionDefinition):
reference: UnitsContainer
@property
- def is_base(self):
+ def is_base(self) -> bool:
return False
def __post_init__(self):
diff --git a/pint/facets/plain/qto.py b/pint/facets/plain/qto.py
new file mode 100644
index 0000000..72b8157
--- /dev/null
+++ b/pint/facets/plain/qto.py
@@ -0,0 +1,386 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import bisect
+import math
+import numbers
+from ...util import infer_base_unit
+import warnings
+from ...compat import (
+ mip_INF,
+ mip_INTEGER,
+ mip_model,
+ mip_Model,
+ mip_OptimizationStatus,
+ mip_xsum,
+)
+
+if TYPE_CHECKING:
+ from ..._typing import UnitLike
+ from ...util import UnitsContainer
+ from .quantity import PlainQuantity
+
+
+def _get_reduced_units(
+ quantity: PlainQuantity, units: UnitsContainer
+) -> UnitsContainer:
+ # loop through individual units and compare to each other unit
+ # can we do better than a nested loop here?
+ for unit1, exp in units.items():
+ # make sure it wasn't already reduced to zero exponent on prior pass
+ if unit1 not in units:
+ continue
+ for unit2 in units:
+ # get exponent after reduction
+ exp = units[unit1]
+ if unit1 != unit2:
+ power = quantity._REGISTRY._get_dimensionality_ratio(unit1, unit2)
+ if power:
+ units = units.add(unit2, exp / power).remove([unit1])
+ break
+ return units
+
+
+def ito_reduced_units(quantity: PlainQuantity) -> None:
+ """Return PlainQuantity scaled in place to reduced units, i.e. one unit per
+ dimension. This will not reduce compound units (e.g., 'J/kg' will not
+ be reduced to m**2/s**2), nor can it make use of contexts at this time.
+ """
+
+ # shortcuts in case we're dimensionless or only a single unit
+ if quantity.dimensionless:
+ return quantity.ito({})
+ if len(quantity._units) == 1:
+ return None
+
+ units = quantity._units.copy()
+ new_units = _get_reduced_units(quantity, units)
+
+ return quantity.ito(new_units)
+
+
+def to_reduced_units(
+ quantity: PlainQuantity,
+) -> PlainQuantity:
+ """Return PlainQuantity scaled in place to reduced units, i.e. one unit per
+ dimension. This will not reduce compound units (intentionally), nor
+ can it make use of contexts at this time.
+ """
+
+ # shortcuts in case we're dimensionless or only a single unit
+ if quantity.dimensionless:
+ return quantity.to({})
+ if len(quantity._units) == 1:
+ return quantity
+
+ units = quantity._units.copy()
+ new_units = _get_reduced_units(quantity, units)
+
+ return quantity.to(new_units)
+
+
+def to_compact(
+ quantity: PlainQuantity, unit: UnitsContainer | None = None
+) -> PlainQuantity:
+ """ "Return PlainQuantity rescaled to compact, human-readable units.
+
+ To get output in terms of a different unit, use the unit parameter.
+
+
+ Examples
+ --------
+
+ >>> import pint
+ >>> ureg = pint.UnitRegistry()
+ >>> (200e-9*ureg.s).to_compact()
+ <Quantity(200.0, 'nanosecond')>
+ >>> (1e-2*ureg('kg m/s^2')).to_compact('N')
+ <Quantity(10.0, 'millinewton')>
+ """
+
+ if not isinstance(quantity.magnitude, numbers.Number):
+ msg = "to_compact applied to non numerical types " "has an undefined behavior."
+ w = RuntimeWarning(msg)
+ warnings.warn(w, stacklevel=2)
+ return quantity
+
+ if (
+ quantity.unitless
+ or quantity.magnitude == 0
+ or math.isnan(quantity.magnitude)
+ or math.isinf(quantity.magnitude)
+ ):
+ return quantity
+
+ SI_prefixes: dict[int, str] = {}
+ for prefix in quantity._REGISTRY._prefixes.values():
+ try:
+ scale = prefix.converter.scale
+ # Kludgy way to check if this is an SI prefix
+ log10_scale = int(math.log10(scale))
+ if log10_scale == math.log10(scale):
+ SI_prefixes[log10_scale] = prefix.name
+ except Exception:
+ SI_prefixes[0] = ""
+
+ SI_prefixes_list = sorted(SI_prefixes.items())
+ SI_powers = [item[0] for item in SI_prefixes_list]
+ SI_bases = [item[1] for item in SI_prefixes_list]
+
+ if unit is None:
+ unit = infer_base_unit(quantity, registry=quantity._REGISTRY)
+ else:
+ unit = infer_base_unit(quantity.__class__(1, unit), registry=quantity._REGISTRY)
+
+ q_base = quantity.to(unit)
+
+ magnitude = q_base.magnitude
+
+ units = list(q_base._units.items())
+ units_numerator = [a for a in units if a[1] > 0]
+
+ if len(units_numerator) > 0:
+ unit_str, unit_power = units_numerator[0]
+ else:
+ unit_str, unit_power = units[0]
+
+ if unit_power > 0:
+ power = math.floor(math.log10(abs(magnitude)) / float(unit_power) / 3) * 3
+ else:
+ power = math.ceil(math.log10(abs(magnitude)) / float(unit_power) / 3) * 3
+
+ index = bisect.bisect_left(SI_powers, power)
+
+ if index >= len(SI_bases):
+ index = -1
+
+ prefix_str = SI_bases[index]
+
+ new_unit_str = prefix_str + unit_str
+ new_unit_container = q_base._units.rename(unit_str, new_unit_str)
+
+ return quantity.to(new_unit_container)
+
+
+def to_preferred(
+ quantity: PlainQuantity, preferred_units: list[UnitLike]
+) -> PlainQuantity:
+ """Return Quantity converted to a unit composed of the preferred units.
+
+ Examples
+ --------
+
+ >>> import pint
+ >>> ureg = pint.UnitRegistry()
+ >>> (1*ureg.acre).to_preferred([ureg.meters])
+ <Quantity(4046.87261, 'meter ** 2')>
+ >>> (1*(ureg.force_pound*ureg.m)).to_preferred([ureg.W])
+ <Quantity(4.44822162, 'second * watt')>
+ """
+
+ if not quantity.dimensionality:
+ return quantity
+
+ # The optimizer isn't perfect, and will sometimes miss obvious solutions.
+ # This sub-algorithm is less powerful, but always finds the very simple solutions.
+ def find_simple():
+ best_ratio = None
+ best_unit = None
+ self_dims = sorted(quantity.dimensionality)
+ self_exps = [quantity.dimensionality[d] for d in self_dims]
+ s_exps_head, *s_exps_tail = self_exps
+ n = len(s_exps_tail)
+ for preferred_unit in preferred_units:
+ dims = sorted(preferred_unit.dimensionality)
+ if dims == self_dims:
+ p_exps_head, *p_exps_tail = (
+ preferred_unit.dimensionality[d] for d in dims
+ )
+ if all(
+ s_exps_tail[i] * p_exps_head == p_exps_tail[i] ** s_exps_head
+ for i in range(n)
+ ):
+ ratio = p_exps_head / s_exps_head
+ ratio = max(ratio, 1 / ratio)
+ if best_ratio is None or ratio < best_ratio:
+ best_ratio = ratio
+ best_unit = preferred_unit ** (s_exps_head / p_exps_head)
+ return best_unit
+
+ simple = find_simple()
+ if simple is not None:
+ return quantity.to(simple)
+
+ # For each dimension (e.g. T(ime), L(ength), M(ass)), assign a default base unit from
+ # the collection of base units
+
+ unit_selections = {
+ base_unit.dimensionality: base_unit
+ for base_unit in map(quantity._REGISTRY.Unit, quantity._REGISTRY._base_units)
+ }
+
+ # Override the default unit of each dimension with the 1D-units used in this Quantity
+ unit_selections.update(
+ {
+ unit.dimensionality: unit
+ for unit in map(quantity._REGISTRY.Unit, quantity._units.keys())
+ }
+ )
+
+ # Determine the preferred unit for each dimensionality from the preferred_units
+ # (A prefered unit doesn't have to be only one dimensional, e.g. Watts)
+ preferred_dims = {
+ preferred_unit.dimensionality: preferred_unit
+ for preferred_unit in map(quantity._REGISTRY.Unit, preferred_units)
+ }
+
+ # Combine the defaults and preferred, favoring the preferred
+ unit_selections.update(preferred_dims)
+
+ # This algorithm has poor asymptotic time complexity, so first reduce the considered
+ # dimensions and units to only those that are useful to the problem
+
+ # The dimensions (without powers) of this Quantity
+ dimension_set = set(quantity.dimensionality)
+
+ # Getting zero exponents in dimensions not in dimension_set can be facilitated
+ # by units that interact with that dimension and one or more dimension_set members.
+ # For example MT^1 * LT^-1 lets you get MLT^0 when T is not in dimension_set.
+ # For each candidate unit that interacts with a dimension_set member, add the
+ # candidate unit's other dimensions to dimension_set, and repeat until no more
+ # dimensions are selected.
+
+ discovery_done = False
+ while not discovery_done:
+ discovery_done = True
+ for d in unit_selections:
+ unit_dimensions = set(d)
+ intersection = unit_dimensions.intersection(dimension_set)
+ if 0 < len(intersection) < len(unit_dimensions):
+ # there are dimensions in this unit that are in dimension set
+ # and others that are not in dimension set
+ dimension_set = dimension_set.union(unit_dimensions)
+ discovery_done = False
+ break
+
+ # filter out dimensions and their unit selections that don't interact with any
+ # dimension_set members
+ unit_selections = {
+ dimensionality: unit
+ for dimensionality, unit in unit_selections.items()
+ if set(dimensionality).intersection(dimension_set)
+ }
+
+ # update preferred_units with the selected units that were originally preferred
+ preferred_units = list(
+ {u for d, u in unit_selections.items() if d in preferred_dims}
+ )
+ preferred_units.sort(key=str) # for determinism
+
+ # and unpreferred_units are the selected units that weren't originally preferred
+ unpreferred_units = list(
+ {u for d, u in unit_selections.items() if d not in preferred_dims}
+ )
+ unpreferred_units.sort(key=str) # for determinism
+
+ # for indexability
+ dimensions = list(dimension_set)
+ dimensions.sort() # for determinism
+
+ # the powers for each elemet of dimensions (the list) for this Quantity
+ dimensionality = [quantity.dimensionality[dimension] for dimension in dimensions]
+
+ # Now that the input data is minimized, setup the optimization problem
+
+ # use mip to select units from preferred units
+
+ model = mip_Model()
+ model.verbose = 0
+
+ # Make one variable for each candidate unit
+
+ vars = [
+ model.add_var(str(unit), lb=-mip_INF, ub=mip_INF, var_type=mip_INTEGER)
+ for unit in (preferred_units + unpreferred_units)
+ ]
+
+ # where [u1 ... uN] are powers of N candidate units (vars)
+ # and [d1(uI) ... dK(uI)] are the K dimensional exponents of candidate unit I
+ # and [t1 ... tK] are the dimensional exponents of the quantity (quantity)
+ # create the following constraints
+ #
+ # ⎡ d1(u1) ⋯ dK(u1) ⎤
+ # [ u1 ⋯ uN ] * ⎢ ⋮ ⋱ ⎢ = [ t1 ⋯ tK ]
+ # ⎣ d1(uN) dK(uN) ⎦
+ #
+ # in English, the units we choose, and their exponents, when combined, must have the
+ # target dimensionality
+
+ matrix = [
+ [preferred_unit.dimensionality[dimension] for dimension in dimensions]
+ for preferred_unit in (preferred_units + unpreferred_units)
+ ]
+
+ # Do the matrix multiplication with mip_model.xsum for performance and create constraints
+ for i in range(len(dimensions)):
+ dot = mip_model.xsum([var * vector[i] for var, vector in zip(vars, matrix)])
+ # add constraint to the model
+ model += dot == dimensionality[i]
+
+ # where [c1 ... cN] are costs, 1 when a preferred variable, and a large value when not
+ # minimize sum(abs(u1) * c1 ... abs(uN) * cN)
+
+ # linearize the optimization variable via a proxy
+ objective = model.add_var("objective", lb=0, ub=mip_INF, var_type=mip_INTEGER)
+
+ # Constrain the objective to be equal to the sums of the absolute values of the preferred
+ # unit powers. Do this by making a separate constraint for each permutation of signedness.
+ # Also apply the cost coefficient, which causes the output to prefer the preferred units
+
+ # prefer units that interact with fewer dimensions
+ cost = [len(p.dimensionality) for p in preferred_units]
+
+ # set the cost for non preferred units to a higher number
+ bias = (
+ max(map(abs, dimensionality)) * max((1, *cost)) * 10
+ ) # arbitrary, just needs to be larger
+ cost.extend([bias] * len(unpreferred_units))
+
+ for i in range(1 << len(vars)):
+ sum = mip_xsum(
+ [
+ (-1 if i & 1 << (len(vars) - j - 1) else 1) * cost[j] * var
+ for j, var in enumerate(vars)
+ ]
+ )
+ model += objective >= sum
+
+ model.objective = objective
+
+ # run the mips minimizer and extract the result if successful
+ if model.optimize() == mip_OptimizationStatus.OPTIMAL:
+ optimal_units = []
+ min_objective = float("inf")
+ for i in range(model.num_solutions):
+ if model.objective_values[i] < min_objective:
+ min_objective = model.objective_values[i]
+ optimal_units.clear()
+ elif model.objective_values[i] > min_objective:
+ continue
+
+ temp_unit = quantity._REGISTRY.Unit("")
+ for var in vars:
+ if var.xi(i):
+ temp_unit *= quantity._REGISTRY.Unit(var.name) ** var.xi(i)
+ optimal_units.append(temp_unit)
+
+ sorting_keys = {tuple(sorted(unit._units)): unit for unit in optimal_units}
+ min_key = sorted(sorting_keys)[0]
+ result_unit = sorting_keys[min_key]
+
+ return quantity.to(result_unit)
+
+ # for whatever reason, a solution wasn't found
+ # return the original quantity
+ return quantity
diff --git a/pint/facets/plain/quantity.py b/pint/facets/plain/quantity.py
index 1eaaa3d..0058549 100644
--- a/pint/facets/plain/quantity.py
+++ b/pint/facets/plain/quantity.py
@@ -8,37 +8,22 @@
from __future__ import annotations
-import bisect
+
import copy
import datetime
import locale
-import math
import numbers
import operator
-import warnings
-from typing import (
- TYPE_CHECKING,
- Any,
- Callable,
- Generic,
- TypeVar,
- overload,
-)
-from collections.abc import Iterable, Iterator, Sequence
+from typing import TYPE_CHECKING, Any, Callable, overload, Generic, TypeVar
+from collections.abc import Iterator, Sequence
-from ..._typing import S, UnitLike, _MagnitudeType
+from ..._typing import UnitLike, QuantityOrUnitLike, Magnitude
from ...compat import (
HAS_NUMPY,
_to_magnitude,
eq,
is_duck_array_type,
is_upcast_type,
- mip_INF,
- mip_INTEGER,
- mip_model,
- mip_Model,
- mip_OptimizationStatus,
- mip_xsum,
np,
zero_or_nan,
)
@@ -47,11 +32,11 @@ from ...util import (
PrettyIPython,
SharedRegistryObject,
UnitsContainer,
- infer_base_unit,
logger,
to_units_container,
)
from .definitions import UnitDefinition
+from . import qto
if TYPE_CHECKING:
from ..context import Context
@@ -61,6 +46,10 @@ if TYPE_CHECKING:
if HAS_NUMPY:
import numpy as np # noqa
+MagnitudeT = TypeVar("MagnitudeT", bound=Magnitude)
+
+T = TypeVar("T", bound=Magnitude)
+
def reduce_dimensions(f):
def wrapped(self, *args, **kwargs):
@@ -115,14 +104,10 @@ def method_wraps(numpy_func):
return wrapper
-# Workaround to bypass dynamically generated PlainQuantity with overload method
-Magnitude = TypeVar("Magnitude")
-
-
# TODO: remove all nonmultiplicative remnants
-class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType]):
+class PlainQuantity(Generic[MagnitudeT], PrettyIPython, SharedRegistryObject):
"""Implements a class to describe a physical quantity:
the product of a numerical value and a unit of measurement.
@@ -140,7 +125,7 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType]
#: Default formatting string.
default_format: str = ""
- _magnitude: _MagnitudeType
+ _magnitude: MagnitudeT
@property
def ndim(self) -> int:
@@ -156,11 +141,7 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType]
def force_ndarray_like(self) -> bool:
return self._REGISTRY.force_ndarray_like
- @property
- def UnitsContainer(self) -> Callable[..., UnitsContainerT]:
- return self._REGISTRY.UnitsContainer
-
- def __reduce__(self) -> tuple:
+ def __reduce__(self) -> tuple[type, Magnitude, UnitsContainer]:
"""Allow pickling quantities. Since UnitRegistries are not pickled, upon
unpickling the new object is always attached to the application registry.
"""
@@ -168,12 +149,17 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType]
# Note: type(self) would be a mistake as subclasses built by
# dinamically can't be pickled
+ # TODO: Check if this is still the case.
return _unpickle_quantity, (PlainQuantity, self.magnitude, self._units)
+ # @overload
+ # def __new__(
+ # cls, value: T, units: UnitLike | None = None
+ # ) -> PlainQuantity[T]:
+ # ...
+
@overload
- def __new__(
- cls, value: str, units: UnitLike | None = None
- ) -> PlainQuantity[Magnitude]:
+ def __new__(cls, value: str, units: UnitLike | None = None) -> PlainQuantity[int]:
...
@overload
@@ -182,17 +168,11 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType]
) -> PlainQuantity[np.ndarray]:
...
- @overload
- def __new__(
- cls, value: PlainQuantity[Magnitude], units: UnitLike | None = None
- ) -> PlainQuantity[Magnitude]:
- ...
-
- @overload
- def __new__(
- cls, value: Magnitude, units: UnitLike | None = None
- ) -> PlainQuantity[Magnitude]:
- ...
+ # @overload
+ # def __new__(
+ # cls, value: PlainQuantity[Any], units: UnitLike | None = None
+ # ) -> PlainQuantity[Any]:
+ # ...
def __new__(cls, value, units=None):
if is_upcast_type(type(value)):
@@ -243,7 +223,7 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType]
return inst
- def __iter__(self: PlainQuantity[Iterable[S]]) -> Iterator[S]:
+ def __iter__(self: PlainQuantity[MagnitudeT]) -> Iterator[Any]:
# Make sure that, if self.magnitude is not iterable, we raise TypeError as soon
# as one calls iter(self) without waiting for the first element to be drawn from
# the iterator
@@ -255,11 +235,11 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType]
return it_outer()
- def __copy__(self) -> PlainQuantity[_MagnitudeType]:
+ def __copy__(self) -> PlainQuantity[MagnitudeT]:
ret = self.__class__(copy.copy(self._magnitude), self._units)
return ret
- def __deepcopy__(self, memo) -> PlainQuantity[_MagnitudeType]:
+ def __deepcopy__(self, memo) -> PlainQuantity[MagnitudeT]:
ret = self.__class__(
copy.deepcopy(self._magnitude, memo), copy.deepcopy(self._units, memo)
)
@@ -285,16 +265,16 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType]
return hash((self_base.__class__, self_base.magnitude, self_base.units))
@property
- def magnitude(self) -> _MagnitudeType:
+ def magnitude(self) -> MagnitudeT:
"""PlainQuantity's magnitude. Long form for `m`"""
return self._magnitude
@property
- def m(self) -> _MagnitudeType:
+ def m(self) -> MagnitudeT:
"""PlainQuantity's magnitude. Short form for `magnitude`"""
return self._magnitude
- def m_as(self, units) -> _MagnitudeType:
+ def m_as(self, units) -> MagnitudeT:
"""PlainQuantity's magnitude expressed in particular units.
Parameters
@@ -351,8 +331,8 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType]
@classmethod
def from_list(
- cls, quant_list: list[PlainQuantity], units=None
- ) -> PlainQuantity[np.ndarray]:
+ cls, quant_list: list[PlainQuantity[MagnitudeT]], units=None
+ ) -> PlainQuantity[MagnitudeT]:
"""Transforms a list of Quantities into an numpy.array quantity.
If no units are specified, the unit of the first element will be used.
Same as from_sequence.
@@ -375,8 +355,8 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType]
@classmethod
def from_sequence(
- cls, seq: Sequence[PlainQuantity], units=None
- ) -> PlainQuantity[np.ndarray]:
+ cls, seq: Sequence[PlainQuantity[MagnitudeT]], units=None
+ ) -> PlainQuantity[MagnitudeT]:
"""Transforms a sequence of Quantities into an numpy.array quantity.
If no units are specified, the unit of the first element will be used.
@@ -414,7 +394,7 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType]
def from_tuple(cls, tup):
return cls(tup[0], cls._REGISTRY.UnitsContainer(tup[1]))
- def to_tuple(self) -> tuple[_MagnitudeType, tuple[tuple[str]]]:
+ def to_tuple(self) -> tuple[MagnitudeT, tuple[tuple[str]]]:
return self.m, tuple(self._units.items())
def compatible_units(self, *contexts):
@@ -452,7 +432,7 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType]
except DimensionalityError:
return False
- if isinstance(other, (PlainQuantity, PlainUnit)):
+ if isinstance(other, (PlainQuantity[MagnitudeT], PlainUnit)):
return self.dimensionality == other.dimensionality
if isinstance(other, str):
@@ -481,7 +461,9 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType]
inplace=is_duck_array_type(type(self._magnitude)),
)
- def ito(self, other=None, *contexts, **ctx_kwargs) -> None:
+ def ito(
+ self, other: QuantityOrUnitLike | None = None, *contexts, **ctx_kwargs
+ ) -> None:
"""Inplace rescale to different units.
Parameters
@@ -500,7 +482,9 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType]
return None
- def to(self, other=None, *contexts, **ctx_kwargs) -> PlainQuantity[_MagnitudeType]:
+ def to(
+ self, other: QuantityOrUnitLike | None = None, *contexts, **ctx_kwargs
+ ) -> PlainQuantity:
"""Return PlainQuantity rescaled to different units.
Parameters
@@ -532,7 +516,7 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType]
return None
- def to_root_units(self) -> PlainQuantity[_MagnitudeType]:
+ def to_root_units(self) -> PlainQuantity[MagnitudeT]:
"""Return PlainQuantity rescaled to root units."""
_, other = self._REGISTRY._get_root_units(self._units)
@@ -551,7 +535,7 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType]
return None
- def to_base_units(self) -> PlainQuantity[_MagnitudeType]:
+ def to_base_units(self) -> PlainQuantity[MagnitudeT]:
"""Return PlainQuantity rescaled to plain units."""
_, other = self._REGISTRY._get_base_units(self._units)
@@ -560,361 +544,13 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType]
return self.__class__(magnitude, other)
- def _get_reduced_units(self, units):
- # loop through individual units and compare to each other unit
- # can we do better than a nested loop here?
- for unit1, exp in units.items():
- # make sure it wasn't already reduced to zero exponent on prior pass
- if unit1 not in units:
- continue
- for unit2 in units:
- # get exponent after reduction
- exp = units[unit1]
- if unit1 != unit2:
- power = self._REGISTRY._get_dimensionality_ratio(unit1, unit2)
- if power:
- units = units.add(unit2, exp / power).remove([unit1])
- break
- return units
-
- def ito_reduced_units(self) -> None:
- """Return PlainQuantity scaled in place to reduced units, i.e. one unit per
- dimension. This will not reduce compound units (e.g., 'J/kg' will not
- be reduced to m**2/s**2), nor can it make use of contexts at this time.
- """
-
- # shortcuts in case we're dimensionless or only a single unit
- if self.dimensionless:
- return self.ito({})
- if len(self._units) == 1:
- return None
-
- units = self._units.copy()
- new_units = self._get_reduced_units(units)
-
- return self.ito(new_units)
-
- def to_reduced_units(self) -> PlainQuantity[_MagnitudeType]:
- """Return PlainQuantity scaled in place to reduced units, i.e. one unit per
- dimension. This will not reduce compound units (intentionally), nor
- can it make use of contexts at this time.
- """
-
- # shortcuts in case we're dimensionless or only a single unit
- if self.dimensionless:
- return self.to({})
- if len(self._units) == 1:
- return self
-
- units = self._units.copy()
- new_units = self._get_reduced_units(units)
-
- return self.to(new_units)
-
- def to_compact(self, unit=None) -> PlainQuantity[_MagnitudeType]:
- """ "Return PlainQuantity rescaled to compact, human-readable units.
-
- To get output in terms of a different unit, use the unit parameter.
-
-
- Examples
- --------
-
- >>> import pint
- >>> ureg = pint.UnitRegistry()
- >>> (200e-9*ureg.s).to_compact()
- <Quantity(200.0, 'nanosecond')>
- >>> (1e-2*ureg('kg m/s^2')).to_compact('N')
- <Quantity(10.0, 'millinewton')>
- """
-
- if not isinstance(self.magnitude, numbers.Number):
- msg = (
- "to_compact applied to non numerical types "
- "has an undefined behavior."
- )
- w = RuntimeWarning(msg)
- warnings.warn(w, stacklevel=2)
- return self
-
- if (
- self.unitless
- or self.magnitude == 0
- or math.isnan(self.magnitude)
- or math.isinf(self.magnitude)
- ):
- return self
-
- SI_prefixes: dict[int, str] = {}
- for prefix in self._REGISTRY._prefixes.values():
- try:
- scale = prefix.converter.scale
- # Kludgy way to check if this is an SI prefix
- log10_scale = int(math.log10(scale))
- if log10_scale == math.log10(scale):
- SI_prefixes[log10_scale] = prefix.name
- except Exception:
- SI_prefixes[0] = ""
-
- SI_prefixes_list = sorted(SI_prefixes.items())
- SI_powers = [item[0] for item in SI_prefixes_list]
- SI_bases = [item[1] for item in SI_prefixes_list]
-
- if unit is None:
- unit = infer_base_unit(self, registry=self._REGISTRY)
- else:
- unit = infer_base_unit(self.__class__(1, unit), registry=self._REGISTRY)
-
- q_base = self.to(unit)
-
- magnitude = q_base.magnitude
-
- units = list(q_base._units.items())
- units_numerator = [a for a in units if a[1] > 0]
-
- if len(units_numerator) > 0:
- unit_str, unit_power = units_numerator[0]
- else:
- unit_str, unit_power = units[0]
-
- if unit_power > 0:
- power = math.floor(math.log10(abs(magnitude)) / float(unit_power) / 3) * 3
- else:
- power = math.ceil(math.log10(abs(magnitude)) / float(unit_power) / 3) * 3
-
- index = bisect.bisect_left(SI_powers, power)
-
- if index >= len(SI_bases):
- index = -1
-
- prefix_str = SI_bases[index]
-
- new_unit_str = prefix_str + unit_str
- new_unit_container = q_base._units.rename(unit_str, new_unit_str)
-
- return self.to(new_unit_container)
-
- def to_preferred(
- self, preferred_units: list[UnitLike]
- ) -> PlainQuantity[_MagnitudeType]:
- """Return Quantity converted to a unit composed of the preferred units.
-
- Examples
- --------
-
- >>> import pint
- >>> ureg = pint.UnitRegistry()
- >>> (1*ureg.acre).to_preferred([ureg.meters])
- <Quantity(4046.87261, 'meter ** 2')>
- >>> (1*(ureg.force_pound*ureg.m)).to_preferred([ureg.W])
- <Quantity(4.44822162, 'second * watt')>
- """
-
- if not self.dimensionality:
- return self
-
- # The optimizer isn't perfect, and will sometimes miss obvious solutions.
- # This sub-algorithm is less powerful, but always finds the very simple solutions.
- def find_simple():
- best_ratio = None
- best_unit = None
- self_dims = sorted(self.dimensionality)
- self_exps = [self.dimensionality[d] for d in self_dims]
- s_exps_head, *s_exps_tail = self_exps
- n = len(s_exps_tail)
- for preferred_unit in preferred_units:
- dims = sorted(preferred_unit.dimensionality)
- if dims == self_dims:
- p_exps_head, *p_exps_tail = (
- preferred_unit.dimensionality[d] for d in dims
- )
- if all(
- s_exps_tail[i] * p_exps_head == p_exps_tail[i] ** s_exps_head
- for i in range(n)
- ):
- ratio = p_exps_head / s_exps_head
- ratio = max(ratio, 1 / ratio)
- if best_ratio is None or ratio < best_ratio:
- best_ratio = ratio
- best_unit = preferred_unit ** (s_exps_head / p_exps_head)
- return best_unit
-
- simple = find_simple()
- if simple is not None:
- return self.to(simple)
-
- # For each dimension (e.g. T(ime), L(ength), M(ass)), assign a default base unit from
- # the collection of base units
-
- unit_selections = {
- base_unit.dimensionality: base_unit
- for base_unit in map(self._REGISTRY.Unit, self._REGISTRY._base_units)
- }
-
- # Override the default unit of each dimension with the 1D-units used in this Quantity
- unit_selections.update(
- {
- unit.dimensionality: unit
- for unit in map(self._REGISTRY.Unit, self._units.keys())
- }
- )
-
- # Determine the preferred unit for each dimensionality from the preferred_units
- # (A prefered unit doesn't have to be only one dimensional, e.g. Watts)
- preferred_dims = {
- preferred_unit.dimensionality: preferred_unit
- for preferred_unit in map(self._REGISTRY.Unit, preferred_units)
- }
-
- # Combine the defaults and preferred, favoring the preferred
- unit_selections.update(preferred_dims)
-
- # This algorithm has poor asymptotic time complexity, so first reduce the considered
- # dimensions and units to only those that are useful to the problem
-
- # The dimensions (without powers) of this Quantity
- dimension_set = set(self.dimensionality)
-
- # Getting zero exponents in dimensions not in dimension_set can be facilitated
- # by units that interact with that dimension and one or more dimension_set members.
- # For example MT^1 * LT^-1 lets you get MLT^0 when T is not in dimension_set.
- # For each candidate unit that interacts with a dimension_set member, add the
- # candidate unit's other dimensions to dimension_set, and repeat until no more
- # dimensions are selected.
-
- discovery_done = False
- while not discovery_done:
- discovery_done = True
- for d in unit_selections:
- unit_dimensions = set(d)
- intersection = unit_dimensions.intersection(dimension_set)
- if 0 < len(intersection) < len(unit_dimensions):
- # there are dimensions in this unit that are in dimension set
- # and others that are not in dimension set
- dimension_set = dimension_set.union(unit_dimensions)
- discovery_done = False
- break
-
- # filter out dimensions and their unit selections that don't interact with any
- # dimension_set members
- unit_selections = {
- dimensionality: unit
- for dimensionality, unit in unit_selections.items()
- if set(dimensionality).intersection(dimension_set)
- }
-
- # update preferred_units with the selected units that were originally preferred
- preferred_units = list(
- {u for d, u in unit_selections.items() if d in preferred_dims}
- )
- preferred_units.sort(key=str) # for determinism
-
- # and unpreferred_units are the selected units that weren't originally preferred
- unpreferred_units = list(
- {u for d, u in unit_selections.items() if d not in preferred_dims}
- )
- unpreferred_units.sort(key=str) # for determinism
-
- # for indexability
- dimensions = list(dimension_set)
- dimensions.sort() # for determinism
-
- # the powers for each elemet of dimensions (the list) for this Quantity
- dimensionality = [self.dimensionality[dimension] for dimension in dimensions]
-
- # Now that the input data is minimized, setup the optimization problem
-
- # use mip to select units from preferred units
-
- model = mip_Model()
- model.verbose = 0
-
- # Make one variable for each candidate unit
-
- vars = [
- model.add_var(str(unit), lb=-mip_INF, ub=mip_INF, var_type=mip_INTEGER)
- for unit in (preferred_units + unpreferred_units)
- ]
-
- # where [u1 ... uN] are powers of N candidate units (vars)
- # and [d1(uI) ... dK(uI)] are the K dimensional exponents of candidate unit I
- # and [t1 ... tK] are the dimensional exponents of the quantity (self)
- # create the following constraints
- #
- # ⎡ d1(u1) ⋯ dK(u1) ⎤
- # [ u1 ⋯ uN ] * ⎢ ⋮ ⋱ ⎢ = [ t1 ⋯ tK ]
- # ⎣ d1(uN) dK(uN) ⎦
- #
- # in English, the units we choose, and their exponents, when combined, must have the
- # target dimensionality
-
- matrix = [
- [preferred_unit.dimensionality[dimension] for dimension in dimensions]
- for preferred_unit in (preferred_units + unpreferred_units)
- ]
-
- # Do the matrix multiplication with mip_model.xsum for performance and create constraints
- for i in range(len(dimensions)):
- dot = mip_model.xsum([var * vector[i] for var, vector in zip(vars, matrix)])
- # add constraint to the model
- model += dot == dimensionality[i]
-
- # where [c1 ... cN] are costs, 1 when a preferred variable, and a large value when not
- # minimize sum(abs(u1) * c1 ... abs(uN) * cN)
-
- # linearize the optimization variable via a proxy
- objective = model.add_var("objective", lb=0, ub=mip_INF, var_type=mip_INTEGER)
-
- # Constrain the objective to be equal to the sums of the absolute values of the preferred
- # unit powers. Do this by making a separate constraint for each permutation of signedness.
- # Also apply the cost coefficient, which causes the output to prefer the preferred units
-
- # prefer units that interact with fewer dimensions
- cost = [len(p.dimensionality) for p in preferred_units]
-
- # set the cost for non preferred units to a higher number
- bias = (
- max(map(abs, dimensionality)) * max((1, *cost)) * 10
- ) # arbitrary, just needs to be larger
- cost.extend([bias] * len(unpreferred_units))
-
- for i in range(1 << len(vars)):
- sum = mip_xsum(
- [
- (-1 if i & 1 << (len(vars) - j - 1) else 1) * cost[j] * var
- for j, var in enumerate(vars)
- ]
- )
- model += objective >= sum
-
- model.objective = objective
-
- # run the mips minimizer and extract the result if successful
- if model.optimize() == mip_OptimizationStatus.OPTIMAL:
- optimal_units = []
- min_objective = float("inf")
- for i in range(model.num_solutions):
- if model.objective_values[i] < min_objective:
- min_objective = model.objective_values[i]
- optimal_units.clear()
- elif model.objective_values[i] > min_objective:
- continue
-
- temp_unit = self._REGISTRY.Unit("")
- for var in vars:
- if var.xi(i):
- temp_unit *= self._REGISTRY.Unit(var.name) ** var.xi(i)
- optimal_units.append(temp_unit)
-
- sorting_keys = {tuple(sorted(unit._units)): unit for unit in optimal_units}
- min_key = sorted(sorting_keys)[0]
- result_unit = sorting_keys[min_key]
-
- return self.to(result_unit)
-
- # for whatever reason, a solution wasn't found
- # return the original quantity
- return self
+ # Functions not essential to a Quantity but it is
+ # convenient that they live in PlainQuantity.
+ # They are implemented elsewhere to keep Quantity class clean.
+ to_compact = qto.to_compact
+ to_preferred = qto.to_preferred
+ to_reduced_units = qto.to_reduced_units
+ ito_reduced_units = qto.ito_reduced_units
# Mathematical operations
def __int__(self) -> int:
@@ -1163,7 +799,7 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType]
...
@overload
- def __iadd__(self, other) -> PlainQuantity[_MagnitudeType]:
+ def __iadd__(self, other) -> PlainQuantity[MagnitudeT]:
...
def __iadd__(self, other):
@@ -1539,7 +1175,7 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType]
return self
@check_implemented
- def __pow__(self, other) -> PlainQuantity[_MagnitudeType]:
+ def __pow__(self, other) -> PlainQuantity[MagnitudeT]:
try:
_to_magnitude(other, self.force_ndarray, self.force_ndarray_like)
except PintTypeError:
@@ -1604,7 +1240,7 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType]
return self.__class__(magnitude, units)
@check_implemented
- def __rpow__(self, other) -> PlainQuantity[_MagnitudeType]:
+ def __rpow__(self, other) -> PlainQuantity[MagnitudeT]:
try:
_to_magnitude(other, self.force_ndarray, self.force_ndarray_like)
except PintTypeError:
@@ -1617,16 +1253,16 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType]
new_self = self.to_root_units()
return other**new_self._magnitude
- def __abs__(self) -> PlainQuantity[_MagnitudeType]:
+ def __abs__(self) -> PlainQuantity[MagnitudeT]:
return self.__class__(abs(self._magnitude), self._units)
- def __round__(self, ndigits: int | None = 0) -> PlainQuantity[int]:
+ def __round__(self, ndigits: int | None = 0) -> PlainQuantity[MagnitudeT]:
return self.__class__(round(self._magnitude, ndigits=ndigits), self._units)
- def __pos__(self) -> PlainQuantity[_MagnitudeType]:
+ def __pos__(self) -> PlainQuantity[MagnitudeT]:
return self.__class__(operator.pos(self._magnitude), self._units)
- def __neg__(self) -> PlainQuantity[_MagnitudeType]:
+ def __neg__(self) -> PlainQuantity[MagnitudeT]:
return self.__class__(operator.neg(self._magnitude), self._units)
@check_implemented
@@ -1797,5 +1433,14 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType]
def _ok_for_muldiv(self, no_offset_units=None) -> bool:
return True
- def to_timedelta(self: PlainQuantity[float]) -> datetime.timedelta:
+ def to_timedelta(self: PlainQuantity[MagnitudeT]) -> datetime.timedelta:
return datetime.timedelta(microseconds=self.to("microseconds").magnitude)
+
+ # We put this last to avoid overriding UnitsContainer
+ # and I do not want to rename it.
+ # TODO: Maybe in the future we need to change it to a more meaningful
+ # non-colliding name.
+
+ @property
+ def UnitsContainer(self) -> Callable[..., UnitsContainerT]:
+ return self._REGISTRY.UnitsContainer
diff --git a/pint/facets/plain/registry.py b/pint/facets/plain/registry.py
index d3baff4..ed46608 100644
--- a/pint/facets/plain/registry.py
+++ b/pint/facets/plain/registry.py
@@ -20,27 +20,38 @@ from decimal import Decimal
from fractions import Fraction
from numbers import Number
from token import NAME, NUMBER
+from tokenize import TokenInfo
+
from typing import (
TYPE_CHECKING,
Any,
Callable,
TypeVar,
Union,
+ Generic,
)
from collections.abc import Iterable, Iterator
if TYPE_CHECKING:
from ..context import Context
- from ..._typing import Quantity, Unit
+ from ...compat import Locale
+
+ # from ..._typing import Quantity, Unit
+
+from ..._typing import (
+ QuantityOrUnitLike,
+ UnitLike,
+ QuantityArgument,
+ Scalar,
+ Handler,
+)
-from ..._typing import QuantityOrUnitLike, UnitLike
from ..._vendor import appdirs
-from ...compat import HAS_BABEL, babel_parse, tokenizer
+from ...compat import babel_parse, tokenizer, TypeAlias, Self
from ...errors import DimensionalityError, RedefinitionError, UndefinedUnitError
from ...pint_eval import build_eval_tree
from ...util import ParserHelper
-from ...util import UnitsContainer
-from ...util import UnitsContainer as UnitsContainerT
+from ...util import UnitsContainer as UnitsContainer
from ...util import (
_is_dim,
create_class_with_registry,
@@ -58,25 +69,20 @@ from .definitions import (
DimensionDefinition,
PrefixDefinition,
UnitDefinition,
+ NamedDefinition,
)
from .objects import PlainQuantity, PlainUnit
-if TYPE_CHECKING:
- if HAS_BABEL:
- import babel
-
- Locale = babel.Locale
- else:
- Locale = None
-
T = TypeVar("T")
_BLOCK_RE = re.compile(r"[ (]")
@functools.lru_cache
-def pattern_to_regex(pattern):
- if hasattr(pattern, "finditer"):
+def pattern_to_regex(pattern: str | re.Pattern[str]) -> re.Pattern[str]:
+ # TODO: This has been changed during typing improvements.
+ # if hasattr(pattern, "finditer"):
+ if not isinstance(pattern, str):
pattern = pattern.pattern
# Replace "{unit_name}" match string with float regex with unit_name as group
@@ -96,15 +102,19 @@ class RegistryCache:
def __init__(self) -> None:
#: Maps dimensionality (UnitsContainer) to Units (str)
- self.dimensional_equivalents: dict[UnitsContainer, set[str]] = {}
+ self.dimensional_equivalents: dict[UnitsContainer, frozenset[str]] = {}
+
#: Maps dimensionality (UnitsContainer) to Dimensionality (UnitsContainer)
- self.root_units = {}
+ # TODO: this description is not right.
+ self.root_units: dict[UnitsContainer, tuple[Scalar, UnitsContainer]] = {}
+
#: Maps dimensionality (UnitsContainer) to Units (UnitsContainer)
self.dimensionality: dict[UnitsContainer, UnitsContainer] = {}
+
#: Cache the unit name associated to user input. ('mV' -> 'millivolt')
self.parse_unit: dict[str, UnitsContainer] = {}
- def __eq__(self, other):
+ def __eq__(self, other: Any):
if not isinstance(other, self.__class__):
return False
attrs = (
@@ -127,7 +137,12 @@ class RegistryMeta(type):
return obj
-class PlainRegistry(metaclass=RegistryMeta):
+# Generic types used to mark types associated to Registries.
+QuantityT = TypeVar("QuantityT", bound=PlainQuantity)
+UnitT = TypeVar("UnitT", bound=PlainUnit)
+
+
+class GenericPlainRegistry(Generic[QuantityT, UnitT], metaclass=RegistryMeta):
"""Base class for all registries.
Capabilities:
@@ -174,11 +189,10 @@ class PlainRegistry(metaclass=RegistryMeta):
#: Babel.Locale instance or None
fmt_locale: Locale | None = None
- _diskcache = None
-
- Quantity = PlainQuantity
- Unit = PlainUnit
+ Quantity: type[QuantityT]
+ Unit: type[UnitT]
+ _diskcache = None
_def_parser = None
def __init__(
@@ -197,7 +211,7 @@ class PlainRegistry(metaclass=RegistryMeta):
mpl_formatter: str = "{:P}",
):
#: Map a definition class to a adder methods.
- self._adders = {}
+ self._adders: Handler = {}
self._register_definition_adders()
self._init_dynamic_classes()
@@ -280,8 +294,8 @@ class PlainRegistry(metaclass=RegistryMeta):
def _init_dynamic_classes(self) -> None:
"""Generate subclasses on the fly and attach them to self"""
- self.Unit: Unit = create_class_with_registry(self, self.Unit)
- self.Quantity: Quantity = create_class_with_registry(self, self.Quantity)
+ self.Unit = create_class_with_registry(self, self.Unit)
+ self.Quantity = create_class_with_registry(self, self.Quantity)
def _after_init(self) -> None:
"""This should be called after all __init__"""
@@ -297,7 +311,16 @@ class PlainRegistry(metaclass=RegistryMeta):
self._build_cache(loaded_files)
self._initialized = True
- def _register_adder(self, definition_class, adder_func):
+ def _register_adder(
+ self,
+ definition_class: type[T],
+ adder_func: Callable[
+ [
+ T,
+ ],
+ None,
+ ],
+ ) -> None:
"""Register a block definition."""
self._adders[definition_class] = adder_func
@@ -310,24 +333,25 @@ class PlainRegistry(metaclass=RegistryMeta):
self._register_adder(DimensionDefinition, self._add_dimension)
self._register_adder(DerivedDimensionDefinition, self._add_derived_dimension)
- def __deepcopy__(self, memo) -> PlainRegistry:
+ def __deepcopy__(self: Self, memo) -> type[Self]:
new = object.__new__(type(self))
new.__dict__ = copy.deepcopy(self.__dict__, memo)
new._init_dynamic_classes()
return new
- def __getattr__(self, item):
+ def __getattr__(self, item: str) -> QuantityT:
getattr_maybe_raise(self, item)
return self.Unit(item)
- def __getitem__(self, item):
+ def __getitem__(self, item: str) -> UnitT:
logger.warning(
"Calling the getitem method from a UnitRegistry is deprecated. "
"use `parse_expression` method or use the registry as a callable."
)
- return self.parse_expression(item)
+ return self.Quantity()
+ # return self.parse_expression(item)
- def __contains__(self, item) -> bool:
+ def __contains__(self, item: str) -> bool:
"""Support checking prefixed units with the `in` operator"""
try:
self.__getattr__(item)
@@ -366,16 +390,13 @@ class PlainRegistry(metaclass=RegistryMeta):
self.fmt_locale = loc
- def UnitsContainer(self, *args, **kwargs) -> UnitsContainerT:
- return UnitsContainer(*args, non_int_type=self.non_int_type, **kwargs)
-
@property
def default_format(self) -> str:
"""Default formatting string for quantities."""
return self.Quantity.default_format
@default_format.setter
- def default_format(self, value: str):
+ def default_format(self, value: str) -> None:
self.Unit.default_format = value
self.Quantity.default_format = value
self.Measurement.default_format = value
@@ -390,7 +411,7 @@ class PlainRegistry(metaclass=RegistryMeta):
def non_int_type(self):
return self._non_int_type
- def define(self, definition):
+ def define(self, definition: str | type) -> None:
"""Add unit to the registry.
Parameters
@@ -413,7 +434,7 @@ class PlainRegistry(metaclass=RegistryMeta):
# - then we define specific adder for each definition class. :-D
############
- def _helper_dispatch_adder(self, definition):
+ def _helper_dispatch_adder(self, definition: Any) -> None:
"""Helper function to add a single definition,
choosing the appropiate method by class.
"""
@@ -428,7 +449,12 @@ class PlainRegistry(metaclass=RegistryMeta):
adder_func(definition)
- def _helper_adder(self, definition, target_dict, casei_target_dict):
+ def _helper_adder(
+ self,
+ definition: NamedDefinition,
+ target_dict: dict[str, Any],
+ casei_target_dict: dict[str, Any] | None,
+ ) -> None:
"""Helper function to store a definition in the internal dictionaries.
It stores the definition under its name, symbol and aliases.
"""
@@ -436,6 +462,7 @@ class PlainRegistry(metaclass=RegistryMeta):
definition.name, definition, target_dict, casei_target_dict
)
+ # TODO: Not sure why but using hasattr does not work here.
if getattr(definition, "has_symbol", ""):
self._helper_single_adder(
definition.symbol, definition, target_dict, casei_target_dict
@@ -447,7 +474,13 @@ class PlainRegistry(metaclass=RegistryMeta):
self._helper_single_adder(alias, definition, target_dict, casei_target_dict)
- def _helper_single_adder(self, key, value, target_dict, casei_target_dict):
+ def _helper_single_adder(
+ self,
+ key: str,
+ value: NamedDefinition,
+ target_dict: dict[str, Any],
+ casei_target_dict: dict[str, Any] | None,
+ ) -> None:
"""Helper function to store a definition in the internal dictionaries.
It warns or raise error on redefinition.
@@ -462,11 +495,11 @@ class PlainRegistry(metaclass=RegistryMeta):
if casei_target_dict is not None:
casei_target_dict[key.lower()].add(key)
- def _add_defaults(self, defaults_definition: DefaultsDefinition):
+ def _add_defaults(self, defaults_definition: DefaultsDefinition) -> None:
for k, v in defaults_definition.items():
self._defaults[k] = v
- def _add_alias(self, definition: AliasDefinition):
+ def _add_alias(self, definition: AliasDefinition) -> None:
unit_dict = self._units
unit = unit_dict[definition.name]
while not isinstance(unit, UnitDefinition):
@@ -474,19 +507,19 @@ class PlainRegistry(metaclass=RegistryMeta):
for alias in definition.aliases:
self._helper_single_adder(alias, unit, self._units, self._units_casei)
- def _add_dimension(self, definition: DimensionDefinition):
+ def _add_dimension(self, definition: DimensionDefinition) -> None:
self._helper_adder(definition, self._dimensions, None)
- def _add_derived_dimension(self, definition: DerivedDimensionDefinition):
+ def _add_derived_dimension(self, definition: DerivedDimensionDefinition) -> None:
for dim_name in definition.reference.keys():
if dim_name not in self._dimensions:
self._add_dimension(DimensionDefinition(dim_name))
self._helper_adder(definition, self._dimensions, None)
- def _add_prefix(self, definition: PrefixDefinition):
+ def _add_prefix(self, definition: PrefixDefinition) -> None:
self._helper_adder(definition, self._prefixes, None)
- def _add_unit(self, definition: UnitDefinition):
+ def _add_unit(self, definition: UnitDefinition) -> None:
if definition.is_base:
self._base_units.append(definition.name)
for dim_name in definition.reference.keys():
@@ -495,7 +528,9 @@ class PlainRegistry(metaclass=RegistryMeta):
self._helper_adder(definition, self._units, self._units_casei)
- def load_definitions(self, file, is_resource: bool = False):
+ def load_definitions(
+ self, file: Iterable[str] | str | pathlib.Path, is_resource: bool = False
+ ):
"""Add units and prefixes defined in a definition text file.
Parameters
@@ -531,8 +566,8 @@ class PlainRegistry(metaclass=RegistryMeta):
self._cache = RegistryCache()
- deps = {
- name: definition.reference.keys() if definition.reference else set()
+ deps: dict[str, set[str]] = {
+ name: set(definition.reference.keys()) if definition.reference else set()
for name, definition in self._units.items()
}
@@ -579,14 +614,13 @@ class PlainRegistry(metaclass=RegistryMeta):
candidates = self.parse_unit_name(name_or_alias, case_sensitive)
if not candidates:
raise UndefinedUnitError(name_or_alias)
- elif len(candidates) == 1:
- prefix, unit_name, _ = candidates[0]
- else:
+
+ prefix, unit_name, _ = candidates[0]
+ if len(candidates) > 1:
logger.warning(
"Parsing {} yield multiple results. "
- "Options are: {}".format(name_or_alias, candidates)
+ "Options are: {!r}".format(name_or_alias, candidates)
)
- prefix, unit_name, _ = candidates[0]
if prefix:
name = prefix + unit_name
@@ -595,7 +629,7 @@ class PlainRegistry(metaclass=RegistryMeta):
self._units[name] = UnitDefinition(
name,
symbol,
- (),
+ tuple(),
prefix_def.converter,
self.UnitsContainer({unit_name: 1}),
)
@@ -608,21 +642,20 @@ class PlainRegistry(metaclass=RegistryMeta):
candidates = self.parse_unit_name(name_or_alias, case_sensitive)
if not candidates:
raise UndefinedUnitError(name_or_alias)
- elif len(candidates) == 1:
- prefix, unit_name, _ = candidates[0]
- else:
+
+ prefix, unit_name, _ = candidates[0]
+ if len(candidates) > 1:
logger.warning(
"Parsing {} yield multiple results. "
"Options are: {!r}".format(name_or_alias, candidates)
)
- prefix, unit_name, _ = candidates[0]
return self._prefixes[prefix].symbol + self._units[unit_name].symbol
def _get_symbol(self, name: str) -> str:
return self._units[name].symbol
- def get_dimensionality(self, input_units) -> UnitsContainerT:
+ def get_dimensionality(self, input_units: UnitLike) -> UnitsContainer:
"""Convert unit or dict of units or dimensions to a dict of plain dimensions
dimensions
"""
@@ -633,9 +666,7 @@ class PlainRegistry(metaclass=RegistryMeta):
return self._get_dimensionality(input_units)
- def _get_dimensionality(
- self, input_units: UnitsContainerT | None
- ) -> UnitsContainerT:
+ def _get_dimensionality(self, input_units: UnitsContainer | None) -> UnitsContainer:
"""Convert a UnitsContainer to plain dimensions."""
if not input_units:
return self.UnitsContainer()
@@ -647,7 +678,7 @@ class PlainRegistry(metaclass=RegistryMeta):
except KeyError:
pass
- accumulator = defaultdict(int)
+ accumulator: dict[str, int] = defaultdict(int)
self._get_dimensionality_recurse(input_units, 1, accumulator)
if "[]" in accumulator:
@@ -659,21 +690,25 @@ class PlainRegistry(metaclass=RegistryMeta):
return dims
- def _get_dimensionality_recurse(self, ref, exp, accumulator):
+ def _get_dimensionality_recurse(
+ self, ref: UnitsContainer, exp: Scalar, accumulator: dict[str, int]
+ ) -> None:
for key in ref:
exp2 = exp * ref[key]
if _is_dim(key):
reg = self._dimensions[key]
- if reg.is_base:
- accumulator[key] += exp2
- elif reg.reference is not None:
+ if isinstance(reg, DerivedDimensionDefinition):
self._get_dimensionality_recurse(reg.reference, exp2, accumulator)
+ else:
+ # DimensionDefinition.
+ accumulator[key] += exp2
+
else:
reg = self._units[self.get_name(key)]
if reg.reference is not None:
self._get_dimensionality_recurse(reg.reference, exp2, accumulator)
- def _get_dimensionality_ratio(self, unit1, unit2):
+ def _get_dimensionality_ratio(self, unit1: UnitLike, unit2: UnitLike):
"""Get the exponential ratio between two units, i.e. solve unit2 = unit1**x for x.
Parameters
@@ -707,7 +742,7 @@ class PlainRegistry(metaclass=RegistryMeta):
def get_root_units(
self, input_units: UnitLike, check_nonmult: bool = True
- ) -> tuple[Number, PlainUnit]:
+ ) -> tuple[Number, UnitT]:
"""Convert unit or dict of units to the root units.
If any unit is non multiplicative and check_converter is True,
@@ -734,7 +769,9 @@ class PlainRegistry(metaclass=RegistryMeta):
return f, self.Unit(units)
- def _get_root_units(self, input_units, check_nonmult=True):
+ def _get_root_units(
+ self, input_units: UnitsContainer, check_nonmult: bool = True
+ ) -> tuple[Scalar, UnitsContainer]:
"""Convert unit or dict of units to the root units.
If any unit is non multiplicative and check_converter is True,
@@ -764,12 +801,13 @@ class PlainRegistry(metaclass=RegistryMeta):
except KeyError:
pass
- accumulators = [1, defaultdict(int)]
+ accumulators: dict[str | None, int] = defaultdict(int)
+ accumulators[None] = 1
self._get_root_units_recurse(input_units, 1, accumulators)
- factor = accumulators[0]
+ factor = accumulators[None]
units = self.UnitsContainer(
- {k: v for k, v in accumulators[1].items() if v != 0}
+ {k: v for k, v in accumulators.items() if k is not None and v != 0}
)
# Check if any of the final units is non multiplicative and return None instead.
@@ -780,7 +818,9 @@ class PlainRegistry(metaclass=RegistryMeta):
cache[input_units] = factor, units
return factor, units
- def get_base_units(self, input_units, check_nonmult=True, system=None):
+ def get_base_units(
+ self, input_units: UnitsContainer | str, check_nonmult: bool = True, system=None
+ ) -> tuple[Number, UnitT]:
"""Convert unit or dict of units to the plain units.
If any unit is non multiplicative and check_converter is True,
@@ -806,35 +846,44 @@ class PlainRegistry(metaclass=RegistryMeta):
return self.get_root_units(input_units, check_nonmult)
- def _get_root_units_recurse(self, ref, exp, accumulators):
+ # TODO: accumulators breaks typing list[int, dict[str, int]]
+ # So we have changed the behavior here
+ def _get_root_units_recurse(
+ self, ref: UnitsContainer, exp: Scalar, accumulators: dict[str | None, int]
+ ) -> None:
+ """
+
+ accumulators None keeps the scalar prefactor not associated with a specific unit.
+
+ """
for key in ref:
exp2 = exp * ref[key]
key = self.get_name(key)
reg = self._units[key]
if reg.is_base:
- accumulators[1][key] += exp2
+ accumulators[key] += exp2
else:
- accumulators[0] *= reg.converter.scale**exp2
+ accumulators[None] *= reg.converter.scale**exp2
if reg.reference is not None:
self._get_root_units_recurse(reg.reference, exp2, accumulators)
- def get_compatible_units(
- self, input_units, group_or_system=None
- ) -> frozenset[Unit]:
+ def get_compatible_units(self, input_units: QuantityOrUnitLike) -> frozenset[UnitT]:
""" """
input_units = to_units_container(input_units)
- equiv = self._get_compatible_units(input_units, group_or_system)
+ equiv = self._get_compatible_units(input_units)
return frozenset(self.Unit(eq) for eq in equiv)
- def _get_compatible_units(self, input_units, group_or_system):
+ def _get_compatible_units(
+ self, input_units: UnitsContainer, *args, **kwargs
+ ) -> frozenset[str]:
""" """
if not input_units:
return frozenset()
src_dim = self._get_dimensionality(input_units)
- return self._cache.dimensional_equivalents.setdefault(src_dim, set())
+ return self._cache.dimensional_equivalents.setdefault(src_dim, frozenset())
# TODO: remove context from here
def is_compatible_with(
@@ -901,7 +950,14 @@ class PlainRegistry(metaclass=RegistryMeta):
return self._convert(value, src, dst, inplace)
- def _convert(self, value, src, dst, inplace=False, check_dimensionality=True):
+ def _convert(
+ self,
+ value: T,
+ src: UnitsContainer,
+ dst: UnitsContainer,
+ inplace: bool = False,
+ check_dimensionality: bool = True,
+ ) -> T:
"""Convert value from some source to destination units.
Parameters
@@ -931,7 +987,7 @@ class PlainRegistry(metaclass=RegistryMeta):
# If the source and destination dimensionality are different,
# then the conversion cannot be performed.
if src_dim != dst_dim:
- raise DimensionalityError(src, dst, src_dim, dst_dim)
+ raise DimensionalityError(src, dst, str(src_dim), str(dst_dim))
# Here src and dst have only multiplicative units left. Thus we can
# convert with a factor.
@@ -953,7 +1009,7 @@ class PlainRegistry(metaclass=RegistryMeta):
def parse_unit_name(
self, unit_name: str, case_sensitive: bool | None = None
- ) -> tuple[tuple[str, str, str], ...]:
+ ) -> tuple[tuple[str, str, str]]:
"""Parse a unit to identify prefix, unit name and suffix
by walking the list of prefix and suffix.
In case of equivalent combinations (e.g. ('kilo', 'gram', '') and
@@ -1033,7 +1089,7 @@ class PlainRegistry(metaclass=RegistryMeta):
input_string: str,
as_delta: bool | None = None,
case_sensitive: bool | None = None,
- ) -> Unit:
+ ) -> UnitT:
"""Parse a units expression and returns a UnitContainer with
the canonical names.
@@ -1054,6 +1110,8 @@ class PlainRegistry(metaclass=RegistryMeta):
pint.Unit
"""
+
+ # TODO: deal or remove with as_delta = None
for p in self.preprocessors:
input_string = p(input_string)
units = self._parse_units(input_string, as_delta, case_sensitive)
@@ -1064,7 +1122,7 @@ class PlainRegistry(metaclass=RegistryMeta):
input_string: str,
as_delta: bool = True,
case_sensitive: bool | None = None,
- ) -> UnitsContainerT:
+ ) -> UnitsContainer:
"""Parse a units expression and returns a UnitContainer with
the canonical names.
"""
@@ -1104,12 +1162,37 @@ class PlainRegistry(metaclass=RegistryMeta):
return ret
- def _eval_token(self, token, case_sensitive=None, **values):
+ def _eval_token(
+ self,
+ token: TokenInfo,
+ case_sensitive: bool | None = None,
+ **values: QuantityArgument,
+ ):
+ """Evaluate a single token using the following rules:
+
+ 1. numerical values as strings are replaced by their numeric counterparts
+ - integers are parsed as integers
+ - other numeric values are parses of non_int_type
+ 2. strings in (inf, infinity, nan, dimensionless) with their numerical value.
+ 3. strings in values.keys() are replaced by Quantity(values[key])
+ 4. in other cases, the values are parsed as units and replaced by their canonical name.
+
+ Parameters
+ ----------
+ token
+ Token to evaluate.
+ case_sensitive, optional
+ If true, a case sensitive matching of the unit name will be done in the registry.
+ If false, a case INsensitive matching of the unit name will be done in the registry.
+ (Default value = None, which uses registry setting)
+ **values
+ Other string that will be parsed using the Quantity constructor on their corresponding value.
+ """
token_type = token[0]
token_text = token[1]
if token_type == NAME:
if token_text == "dimensionless":
- return self.Quantity(1, self.dimensionless)
+ return self.Quantity(1)
elif token_text.lower() in ("inf", "infinity"):
return self.non_int_type("inf")
elif token_text.lower() == "nan":
@@ -1139,28 +1222,25 @@ class PlainRegistry(metaclass=RegistryMeta):
Parameters
----------
- input_string :
+ input_string
pattern_string:
- The regex parse string
- case_sensitive :
- (Default value = None, which uses registry setting)
- many :
+ The regex parse string
+ case_sensitive, optional
+ If true, a case sensitive matching of the unit name will be done in the registry.
+ If false, a case INsensitive matching of the unit name will be done in the registry.
+ (Default value = None, which uses registry setting)
+ many, optional
Match many results
(Default value = False)
-
-
- Returns
- -------
-
"""
if not input_string:
return [] if many else None
# Parse string
- pattern = pattern_to_regex(pattern)
- matched = re.finditer(pattern, input_string)
+ regex = pattern_to_regex(pattern)
+ matched = re.finditer(regex, input_string)
# Extract result(s)
results = []
@@ -1184,11 +1264,11 @@ class PlainRegistry(metaclass=RegistryMeta):
return results
def parse_expression(
- self,
+ self: Self,
input_string: str,
case_sensitive: bool | None = None,
- **values,
- ) -> Quantity:
+ **values: QuantityArgument,
+ ) -> QuantityT:
"""Parse a mathematical expression including units and return a quantity object.
Numerical constants can be specified as keyword arguments and will take precedence
@@ -1196,16 +1276,14 @@ class PlainRegistry(metaclass=RegistryMeta):
Parameters
----------
- input_string :
-
- case_sensitive :
- (Default value = None, which uses registry setting)
- **values :
-
-
- Returns
- -------
-
+ input_string
+
+ case_sensitive, optional
+ If true, a case sensitive matching of the unit name will be done in the registry.
+ If false, a case INsensitive matching of the unit name will be done in the registry.
+ (Default value = None, which uses registry setting)
+ **values
+ Other string that will be parsed using the Quantity constructor on their corresponding value.
"""
if not input_string:
return self.Quantity(1)
@@ -1215,8 +1293,21 @@ class PlainRegistry(metaclass=RegistryMeta):
input_string = string_preprocessor(input_string)
gen = tokenizer(input_string)
- return build_eval_tree(gen).evaluate(
- lambda x: self._eval_token(x, case_sensitive=case_sensitive, **values)
- )
+ def _define_op(s: str):
+ return self._eval_token(s, case_sensitive=case_sensitive, **values)
+
+ return build_eval_tree(gen).evaluate(_define_op)
+
+ # We put this last to avoid overriding UnitsContainer
+ # and I do not want to rename it.
+ # TODO: Maybe in the future we need to change it to a more meaningful
+ # non-colliding name.
+ def UnitsContainer(self, *args: Any, **kwargs: Any) -> UnitsContainer:
+ return UnitsContainer(*args, non_int_type=self.non_int_type, **kwargs)
__call__ = parse_expression
+
+
+class PlainRegistry(GenericPlainRegistry[PlainQuantity[Any], PlainUnit]):
+ Quantity: TypeAlias = PlainQuantity[Any]
+ Unit: TypeAlias = PlainUnit
diff --git a/pint/facets/system/__init__.py b/pint/facets/system/__init__.py
index e95098b..24e68b7 100644
--- a/pint/facets/system/__init__.py
+++ b/pint/facets/system/__init__.py
@@ -12,6 +12,6 @@ from __future__ import annotations
from .definitions import SystemDefinition
from .objects import System
-from .registry import SystemRegistry
+from .registry import SystemRegistry, GenericSystemRegistry
-__all__ = ["SystemDefinition", "System", "SystemRegistry"]
+__all__ = ["SystemDefinition", "System", "SystemRegistry", "GenericSystemRegistry"]
diff --git a/pint/facets/system/definitions.py b/pint/facets/system/definitions.py
index 1ce8269..eb582f3 100644
--- a/pint/facets/system/definitions.py
+++ b/pint/facets/system/definitions.py
@@ -11,7 +11,7 @@ from __future__ import annotations
from collections.abc import Iterable
from dataclasses import dataclass
-from ..._typing import Self
+from ...compat import Self
from ... import errors
diff --git a/pint/facets/system/objects.py b/pint/facets/system/objects.py
index 69b1c84..cf6a24f 100644
--- a/pint/facets/system/objects.py
+++ b/pint/facets/system/objects.py
@@ -14,7 +14,9 @@ import numbers
from typing import Any
from collections.abc import Iterable
-from ..._typing import Self
+
+from typing import Callable, Generic
+from numbers import Number
from ...babel_names import _babel_systems
from ...compat import babel_parse
@@ -25,6 +27,20 @@ from ...util import (
to_units_container,
)
from .definitions import SystemDefinition
+from .. import group
+from ..plain import MagnitudeT
+
+from ..._typing import UnitLike
+
+GetRootUnits = Callable[[UnitLike, bool], tuple[Number, UnitLike]]
+
+
+class SystemQuantity(Generic[MagnitudeT], group.GroupQuantity[MagnitudeT]):
+ pass
+
+
+class SystemUnit(group.GroupUnit):
+ pass
class System(SharedRegistryObject):
@@ -76,11 +92,11 @@ class System(SharedRegistryObject):
def members(self):
d = self._REGISTRY._groups
if self._computed_members is None:
- self._computed_members = set()
+ tmp: set[str] = set()
for group_name in self._used_groups:
try:
- self._computed_members |= d[group_name].members
+ tmp |= d[group_name].members
except KeyError:
logger.warning(
"Could not resolve {} in System {}".format(
@@ -88,7 +104,7 @@ class System(SharedRegistryObject):
)
)
- self._computed_members = frozenset(self._computed_members)
+ self._computed_members = frozenset(tmp)
return self._computed_members
@@ -116,17 +132,30 @@ class System(SharedRegistryObject):
return locale.measurement_systems[name]
return self.name
+ # TODO: When 3.11 is minimal version, use Self
+
@classmethod
def from_lines(
- cls: type[Self], lines: Iterable[str], get_root_func, non_int_type: type = float
- ) -> Self:
+ cls: type[System],
+ lines: Iterable[str],
+ get_root_func: GetRootUnits,
+ non_int_type: type = float,
+ ) -> System:
# TODO: we changed something here it used to be
# system_definition = SystemDefinition.from_lines(lines, get_root_func)
system_definition = SystemDefinition.from_lines(lines, non_int_type)
+
+ if system_definition is None:
+ raise ValueError(f"Could not define System from from {lines}")
+
return cls.from_definition(system_definition, get_root_func)
@classmethod
- def from_definition(cls, system_definition: SystemDefinition, get_root_func=None):
+ def from_definition(
+ cls: type[System],
+ system_definition: SystemDefinition,
+ get_root_func: GetRootUnits | None = None,
+ ) -> System:
if get_root_func is None:
# TODO: kept for backwards compatibility
get_root_func = cls._REGISTRY.get_root_units
diff --git a/pint/facets/system/registry.py b/pint/facets/system/registry.py
index 6e0878e..30921bd 100644
--- a/pint/facets/system/registry.py
+++ b/pint/facets/system/registry.py
@@ -9,10 +9,14 @@
from __future__ import annotations
from numbers import Number
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Generic, Any
from ... import errors
+from ...compat import TypeAlias
+
+from ..plain import QuantityT, UnitT
+
if TYPE_CHECKING:
from ..._typing import Quantity, Unit
@@ -22,13 +26,14 @@ from ...util import (
create_class_with_registry,
to_units_container,
)
-from ..group import GroupRegistry
+from ..group import GenericGroupRegistry
from .definitions import SystemDefinition
-from .objects import Lister, System
from . import objects
-class SystemRegistry(GroupRegistry):
+class GenericSystemRegistry(
+ Generic[QuantityT, UnitT], GenericGroupRegistry[QuantityT, UnitT]
+):
"""Handle of Systems.
Conversion between units with different dimensions according
@@ -46,24 +51,24 @@ class SystemRegistry(GroupRegistry):
# TODO: Change this to System: System to specify class
# and use introspection to get system class as a way
# to enjoy typing goodies
- System = objects.System
+ System: type[objects.System]
- def __init__(self, system=None, **kwargs):
+ def __init__(self, system: str | None = None, **kwargs):
super().__init__(**kwargs)
#: Map system name to system.
#: :type: dict[ str | System]
- self._systems: dict[str, System] = {}
+ self._systems: dict[str, objects.System] = {}
#: Maps dimensionality (UnitsContainer) to Dimensionality (UnitsContainer)
- self._base_units_cache = {}
+ self._base_units_cache: dict[UnitsContainerT, UnitsContainerT] = {}
- self._default_system = system
+ self._default_system_name: str | None = system
def _init_dynamic_classes(self) -> None:
"""Generate subclasses on the fly and attach them to self"""
super()._init_dynamic_classes()
- self.System = create_class_with_registry(self, self.System)
+ self.System = create_class_with_registry(self, objects.System)
def _after_init(self) -> None:
"""Invoked at the end of ``__init__``.
@@ -74,7 +79,7 @@ class SystemRegistry(GroupRegistry):
super()._after_init()
#: System name to be used by default.
- self._default_system = self._default_system or self._defaults.get(
+ self._default_system_name = self._default_system_name or self._defaults.get(
"system", None
)
@@ -82,7 +87,7 @@ class SystemRegistry(GroupRegistry):
super()._register_definition_adders()
self._register_adder(SystemDefinition, self._add_system)
- def _add_system(self, sd: SystemDefinition):
+ def _add_system(self, sd: SystemDefinition) -> None:
if sd.name in self._systems:
raise ValueError(f"System {sd.name} already present in registry")
@@ -96,29 +101,29 @@ class SystemRegistry(GroupRegistry):
@property
def sys(self):
- return Lister(self._systems)
+ return objects.Lister(self._systems)
@property
- def default_system(self) -> System:
- return self._default_system
+ def default_system(self) -> str | None:
+ return self._default_system_name
@default_system.setter
- def default_system(self, name):
+ def default_system(self, name: str) -> None:
if name:
if name not in self._systems:
raise ValueError("Unknown system %s" % name)
self._base_units_cache = {}
- self._default_system = name
+ self._default_system_name = name
- def get_system(self, name: str, create_if_needed: bool = True) -> System:
+ def get_system(self, name: str, create_if_needed: bool = True) -> objects.System:
"""Return a Group.
Parameters
----------
name : str
- Name of the group to be
+ Name of the group to be.
create_if_needed : bool
If True, create a group if not found. If False, raise an Exception.
(Default value = True)
@@ -141,7 +146,7 @@ class SystemRegistry(GroupRegistry):
self,
input_units: UnitLike | Quantity,
check_nonmult: bool = True,
- system: str | System | None = None,
+ system: str | objects.System | None = None,
) -> tuple[Number, Unit]:
"""Convert unit or dict of units to the plain units.
@@ -179,15 +184,15 @@ class SystemRegistry(GroupRegistry):
self,
input_units: UnitsContainerT,
check_nonmult: bool = True,
- system: str | System | None = None,
+ system: str | objects.System | None = None,
):
if system is None:
- system = self._default_system
+ system = self._default_system_name
# The cache is only done for check_nonmult=True and the current system.
if (
check_nonmult
- and system == self._default_system
+ and system == self._default_system_name
and input_units in self._base_units_cache
):
return self._base_units_cache[input_units]
@@ -220,16 +225,32 @@ class SystemRegistry(GroupRegistry):
return base_factor, destination_units
- def _get_compatible_units(self, input_units, group_or_system) -> frozenset[Unit]:
+ def get_compatible_units(
+ self, input_units: UnitsContainerT, group_or_system: str | None = None
+ ) -> frozenset[Unit]:
+ """ """
+
+ group_or_system = group_or_system or self._default_system_name
+
if group_or_system is None:
- group_or_system = self._default_system
+ return super().get_compatible_units(input_units)
+
+ input_units = to_units_container(input_units)
+
+ equiv = self._get_compatible_units(input_units, group_or_system)
+
+ return frozenset(self.Unit(eq) for eq in equiv)
+ def _get_compatible_units(
+ self, input_units: UnitsContainerT, group_or_system: str | None = None
+ ) -> frozenset[Unit]:
if group_or_system and group_or_system in self._systems:
members = self._systems[group_or_system].members
# group_or_system has been handled by System
- return frozenset(members & super()._get_compatible_units(input_units, None))
+ return frozenset(members & super()._get_compatible_units(input_units))
try:
+ # This will be handled by groups
return super()._get_compatible_units(input_units, group_or_system)
except ValueError as ex:
# It might be also a system
@@ -238,3 +259,10 @@ class SystemRegistry(GroupRegistry):
"Unknown Group o System with name '%s'" % group_or_system
) from ex
raise ex
+
+
+class SystemRegistry(
+ GenericSystemRegistry[objects.SystemQuantity[Any], objects.SystemUnit]
+):
+ Quantity: TypeAlias = objects.SystemQuantity[Any]
+ Unit: TypeAlias = objects.SystemUnit
diff --git a/pint/formatting.py b/pint/formatting.py
index 880f55b..28adf25 100644
--- a/pint/formatting.py
+++ b/pint/formatting.py
@@ -13,17 +13,27 @@ from __future__ import annotations
import functools
import re
import warnings
-from typing import Callable, Any
+from typing import Callable, Any, TYPE_CHECKING, TypeVar
from collections.abc import Iterable
from numbers import Number
from .babel_names import _babel_lengths, _babel_units
-from .compat import babel_parse
+from .compat import babel_parse, HAS_BABEL
+
+if TYPE_CHECKING:
+ from .util import ItMatrix, UnitsContainer
+
+ if HAS_BABEL:
+ import babel
+
+ Locale = babel.Locale
+ else:
+ Locale = TypeVar("Locale")
__JOIN_REG_EXP = re.compile(r"{\d*}")
-def _join(fmt: str, iterable: Iterable[Any]):
+def _join(fmt: str, iterable: Iterable[Any]) -> str:
"""Join an iterable with the format specified in fmt.
The format can be specified in two ways:
@@ -124,6 +134,7 @@ _FORMATS: dict[str, dict[str, Any]] = {
}
#: _FORMATTERS maps format names to callables doing the formatting
+# TODO fix Callable typing
_FORMATTERS: dict[str, Callable] = {}
@@ -167,7 +178,7 @@ def register_unit_format(name: str):
@register_unit_format("P")
-def format_pretty(unit, registry, **options):
+def format_pretty(unit: UnitsContainer, registry, **options) -> str:
return formatter(
unit.items(),
as_ratio=True,
@@ -181,7 +192,7 @@ def format_pretty(unit, registry, **options):
)
-def latex_escape(string):
+def latex_escape(string: str) -> str:
"""
Prepend characters that have a special meaning in LaTeX with a backslash.
"""
@@ -198,7 +209,7 @@ def latex_escape(string):
@register_unit_format("L")
-def format_latex(unit, registry, **options):
+def format_latex(unit: UnitsContainer, registry, **options) -> str:
preprocessed = {rf"\mathrm{{{latex_escape(u)}}}": p for u, p in unit.items()}
formatted = formatter(
preprocessed.items(),
@@ -214,7 +225,7 @@ def format_latex(unit, registry, **options):
@register_unit_format("Lx")
-def format_latex_siunitx(unit, registry, **options):
+def format_latex_siunitx(unit: UnitsContainer, registry, **options) -> str:
if registry is None:
raise ValueError(
"Can't format as siunitx without a registry."
@@ -228,7 +239,7 @@ def format_latex_siunitx(unit, registry, **options):
@register_unit_format("H")
-def format_html(unit, registry, **options):
+def format_html(unit: UnitsContainer, registry, **options) -> str:
return formatter(
unit.items(),
as_ratio=True,
@@ -242,7 +253,7 @@ def format_html(unit, registry, **options):
@register_unit_format("D")
-def format_default(unit, registry, **options):
+def format_default(unit: UnitsContainer, registry, **options) -> str:
return formatter(
unit.items(),
as_ratio=True,
@@ -256,7 +267,7 @@ def format_default(unit, registry, **options):
@register_unit_format("C")
-def format_compact(unit, registry, **options):
+def format_compact(unit: UnitsContainer, registry, **options) -> str:
return formatter(
unit.items(),
as_ratio=True,
@@ -270,7 +281,7 @@ def format_compact(unit, registry, **options):
def formatter(
- items: list[tuple[str, Number]],
+ items: Iterable[tuple[str, Number]],
as_ratio: bool = True,
single_denominator: bool = False,
product_fmt: str = " * ",
@@ -282,7 +293,7 @@ def formatter(
babel_length: str = "long",
babel_plural_form: str = "one",
sort: bool = True,
-):
+) -> str:
"""Format a list of (name, exponent) pairs.
Parameters
@@ -393,7 +404,7 @@ def formatter(
_BASIC_TYPES = frozenset("bcdeEfFgGnosxX%uS")
-def _parse_spec(spec):
+def _parse_spec(spec: str) -> str:
result = ""
for ch in reversed(spec):
if ch == "~" or ch in _BASIC_TYPES:
@@ -410,7 +421,7 @@ def _parse_spec(spec):
return result
-def format_unit(unit, spec, registry=None, **options):
+def format_unit(unit, spec: str, registry=None, **options):
# registry may be None to allow formatting `UnitsContainer` objects
# in that case, the spec may not be "Lx"
@@ -430,10 +441,10 @@ def format_unit(unit, spec, registry=None, **options):
return fmt(unit, registry=registry, **options)
-def siunitx_format_unit(units, registry):
+def siunitx_format_unit(units: UnitsContainer, registry) -> str:
"""Returns LaTeX code for the unit that can be put into an siunitx command."""
- def _tothe(power):
+ def _tothe(power: int | float) -> str:
if isinstance(power, int) or (isinstance(power, float) and power.is_integer()):
if power == 1:
return ""
@@ -473,7 +484,7 @@ def siunitx_format_unit(units, registry):
return "".join(lpos) + "".join(lneg)
-def extract_custom_flags(spec):
+def extract_custom_flags(spec: str) -> str:
import re
if not spec:
@@ -488,14 +499,16 @@ def extract_custom_flags(spec):
return "".join(custom_flags)
-def remove_custom_flags(spec):
+def remove_custom_flags(spec: str) -> str:
for flag in sorted(_FORMATTERS.keys(), key=len, reverse=True) + ["~"]:
if flag:
spec = spec.replace(flag, "")
return spec
-def split_format(spec, default, separate_format_defaults=True):
+def split_format(
+ spec: str, default: str, separate_format_defaults: bool = True
+) -> tuple[str, str]:
mspec = remove_custom_flags(spec)
uspec = extract_custom_flags(spec)
@@ -535,11 +548,11 @@ def split_format(spec, default, separate_format_defaults=True):
return mspec, uspec
-def vector_to_latex(vec, fmtfun=lambda x: format(x, ".2f")):
+def vector_to_latex(vec: Iterable[Any], fmtfun=lambda x: format(x, ".2f")) -> str:
return matrix_to_latex([vec], fmtfun)
-def matrix_to_latex(matrix, fmtfun=lambda x: format(x, ".2f")):
+def matrix_to_latex(matrix: ItMatrix, fmtfun=lambda x: format(x, ".2f")) -> str:
ret = []
for row in matrix:
@@ -548,7 +561,9 @@ def matrix_to_latex(matrix, fmtfun=lambda x: format(x, ".2f")):
return r"\begin{pmatrix}%s\end{pmatrix}" % "\\\\ \n".join(ret)
-def ndarray_to_latex_parts(ndarr, fmtfun=lambda x: format(x, ".2f"), dim=()):
+def ndarray_to_latex_parts(
+ ndarr, fmtfun=lambda x: format(x, ".2f"), dim: tuple[int] = tuple()
+):
if isinstance(fmtfun, str):
fmt = fmtfun
fmtfun = lambda x: format(x, fmt)
@@ -573,5 +588,7 @@ def ndarray_to_latex_parts(ndarr, fmtfun=lambda x: format(x, ".2f"), dim=()):
return ret
-def ndarray_to_latex(ndarr, fmtfun=lambda x: format(x, ".2f"), dim=()):
+def ndarray_to_latex(
+ ndarr, fmtfun=lambda x: format(x, ".2f"), dim: tuple[int] = tuple()
+) -> str:
return "\n".join(ndarray_to_latex_parts(ndarr, fmtfun, dim))
diff --git a/pint/registry.py b/pint/registry.py
index 474eb77..964d8a5 100644
--- a/pint/registry.py
+++ b/pint/registry.py
@@ -14,16 +14,10 @@
from __future__ import annotations
+from typing import Generic
+
from . import registry_helpers
-from .facets import (
- ContextRegistry,
- DaskRegistry,
- FormattingRegistry,
- MeasurementRegistry,
- NonMultiplicativeRegistry,
- NumpyRegistry,
- SystemRegistry,
-)
+from . import facets
from .util import logger, pi_theorem
@@ -33,37 +27,40 @@ from .util import logger, pi_theorem
class Quantity(
- # SystemRegistry.Quantity,
- # ContextRegistry.Quantity,
- DaskRegistry.Quantity,
- NumpyRegistry.Quantity,
- MeasurementRegistry.Quantity,
- FormattingRegistry.Quantity,
- NonMultiplicativeRegistry.Quantity,
+ facets.SystemRegistry.Quantity,
+ facets.ContextRegistry.Quantity,
+ facets.DaskRegistry.Quantity,
+ facets.NumpyRegistry.Quantity,
+ facets.MeasurementRegistry.Quantity,
+ facets.FormattingRegistry.Quantity,
+ facets.NonMultiplicativeRegistry.Quantity,
+ facets.PlainRegistry.Quantity,
):
pass
class Unit(
- # SystemRegistry.Unit,
- # ContextRegistry.Unit,
- # DaskRegistry.Unit,
- NumpyRegistry.Unit,
- # MeasurementRegistry.Unit,
- FormattingRegistry.Unit,
- NonMultiplicativeRegistry.Unit,
+ facets.SystemRegistry.Unit,
+ facets.ContextRegistry.Unit,
+ facets.DaskRegistry.Unit,
+ facets.NumpyRegistry.Unit,
+ facets.MeasurementRegistry.Unit,
+ facets.FormattingRegistry.Unit,
+ facets.NonMultiplicativeRegistry.Unit,
+ facets.PlainRegistry.Unit,
):
pass
class UnitRegistry(
- SystemRegistry,
- ContextRegistry,
- DaskRegistry,
- NumpyRegistry,
- MeasurementRegistry,
- FormattingRegistry,
- NonMultiplicativeRegistry,
+ facets.GenericSystemRegistry[Quantity, Unit],
+ facets.GenericContextRegistry[Quantity, Unit],
+ facets.GenericDaskRegistry[Quantity, Unit],
+ facets.GenericNumpyRegistry[Quantity, Unit],
+ facets.GenericMeasurementRegistry[Quantity, Unit],
+ facets.GenericFormattingRegistry[Quantity, Unit],
+ facets.GenericNonMultiplicativeRegistry[Quantity, Unit],
+ facets.GenericPlainRegistry[Quantity, Unit],
):
"""The unit registry stores the definitions and relationships between units.
@@ -171,7 +168,7 @@ class UnitRegistry(
check = registry_helpers.check
-class LazyRegistry:
+class LazyRegistry(Generic[facets.QuantityT, facets.UnitT]):
def __init__(self, args=None, kwargs=None):
self.__dict__["params"] = args or (), kwargs or {}
diff --git a/pint/registry_helpers.py b/pint/registry_helpers.py
index 1f28036..7eee694 100644
--- a/pint/registry_helpers.py
+++ b/pint/registry_helpers.py
@@ -13,7 +13,7 @@ from __future__ import annotations
import functools
from inspect import signature
from itertools import zip_longest
-from typing import TYPE_CHECKING, Callable, TypeVar
+from typing import TYPE_CHECKING, Callable, TypeVar, Any
from collections.abc import Iterable
from ._typing import F
@@ -189,7 +189,7 @@ def wraps(
ret: str | Unit | Iterable[str | Unit | None] | None,
args: str | Unit | Iterable[str | Unit | None] | None,
strict: bool = True,
-) -> Callable[[Callable[..., T]], Callable[..., Quantity[T]]]:
+) -> Callable[[Callable[..., Any]], Callable[..., Quantity]]:
"""Wraps a function to become pint-aware.
Use it when a function requires a numerical value but in some specific
@@ -253,7 +253,7 @@ def wraps(
)
ret = _to_units_container(ret, ureg)
- def decorator(func: Callable[..., T]) -> Callable[..., Quantity[T]]:
+ def decorator(func: Callable[..., Any]) -> Callable[..., Quantity]:
count_params = len(signature(func).parameters)
if len(args) != count_params:
raise TypeError(
@@ -269,7 +269,7 @@ def wraps(
)
@functools.wraps(func, assigned=assigned, updated=updated)
- def wrapper(*values, **kw) -> Quantity[T]:
+ def wrapper(*values, **kw) -> Quantity:
values, kw = _apply_defaults(func, values, kw)
# In principle, the values are used as is
diff --git a/pint/testing.py b/pint/testing.py
index 8e4f15f..d99df0b 100644
--- a/pint/testing.py
+++ b/pint/testing.py
@@ -34,7 +34,7 @@ def _get_comparable_magnitudes(first, second, msg):
return m1, m2
-def assert_equal(first, second, msg=None):
+def assert_equal(first, second, msg: str | None = None) -> None:
if msg is None:
msg = f"Comparing {first!r} and {second!r}. "
@@ -57,7 +57,9 @@ def assert_equal(first, second, msg=None):
assert m1 == m2, msg
-def assert_allclose(first, second, rtol=1e-07, atol=0, msg=None):
+def assert_allclose(
+ first, second, rtol: float = 1e-07, atol: float = 0, msg: str | None = None
+) -> None:
if msg is None:
try:
msg = f"Comparing {first!r} and {second!r}. "
diff --git a/pint/util.py b/pint/util.py
index d75d1b5..40ea39e 100644
--- a/pint/util.py
+++ b/pint/util.py
@@ -30,15 +30,15 @@ from typing import (
)
from collections.abc import Hashable, Generator
-from .compat import NUMERIC_TYPES, tokenizer
+from .compat import NUMERIC_TYPES, tokenizer, Self
from .errors import DefinitionSyntaxError
from .formatting import format_unit
from .pint_eval import build_eval_tree
-from ._typing import PintScalar
+from ._typing import Scalar
if TYPE_CHECKING:
- from ._typing import Quantity, UnitLike, Self
+ from ._typing import Quantity, UnitLike, QuantityOrUnitLike
from .registry import UnitRegistry
@@ -47,12 +47,13 @@ logger.addHandler(NullHandler())
T = TypeVar("T")
TH = TypeVar("TH", bound=Hashable)
+TT = TypeVar("TT", bound=type)
# TODO: Change when Python 3.10 becomes minimal version.
# ItMatrix: TypeAlias = Iterable[Iterable[PintScalar]]
# Matrix: TypeAlias = list[list[PintScalar]]
-ItMatrix = Iterable[Iterable[PintScalar]]
-Matrix = list[list[PintScalar]]
+ItMatrix = Iterable[Iterable[Scalar]]
+Matrix = list[list[Scalar]]
def _noop(x: T) -> T:
@@ -65,7 +66,7 @@ def matrix_to_string(
col_headers: Iterable[str] | None = None,
fmtfun: Callable[
[
- PintScalar,
+ Scalar,
],
str,
] = "{:0.0f}".format,
@@ -125,9 +126,9 @@ def matrix_apply(
matrix: ItMatrix,
func: Callable[
[
- PintScalar,
+ Scalar,
],
- PintScalar,
+ Scalar,
],
) -> Matrix:
"""Apply a function to individual elements within a matrix.
@@ -172,7 +173,14 @@ def column_echelon_form(
Swapped rows.
"""
- _transpose = transpose if transpose_result else _noop
+ _transpose: Callable[
+ [
+ ItMatrix,
+ ],
+ Matrix,
+ ] = (
+ transpose if transpose_result else _noop
+ )
ech_matrix = matrix_apply(
transpose(matrix),
@@ -181,7 +189,7 @@ def column_echelon_form(
rows, cols = len(ech_matrix), len(ech_matrix[0])
# M = [[ntype(x) for x in row] for row in M]
- id_matrix: list[list[PintScalar]] = [ # noqa: E741
+ id_matrix: list[list[Scalar]] = [ # noqa: E741
[ntype(1) if n == nc else ntype(0) for nc in range(rows)] for n in range(rows)
]
@@ -415,7 +423,7 @@ def find_connected_nodes(
return visited
-class udict(dict[str, PintScalar]):
+class udict(dict[str, Scalar]):
"""Custom dict implementing __missing__."""
def __missing__(self, key: str):
@@ -425,7 +433,7 @@ class udict(dict[str, PintScalar]):
return udict(self)
-class UnitsContainer(Mapping[str, PintScalar]):
+class UnitsContainer(Mapping[str, Scalar]):
"""The UnitsContainer stores the product of units and their respective
exponent and implements the corresponding operations.
@@ -441,10 +449,12 @@ class UnitsContainer(Mapping[str, PintScalar]):
_d: udict
_hash: int | None
- _one: PintScalar
+ _one: Scalar
_non_int_type: type
- def __init__(self, *args, non_int_type: type | None = None, **kwargs) -> None:
+ def __init__(
+ self, *args: Any, non_int_type: type | None = None, **kwargs: Any
+ ) -> None:
if args and isinstance(args[0], UnitsContainer):
default_non_int_type = args[0]._non_int_type
else:
@@ -542,7 +552,7 @@ class UnitsContainer(Mapping[str, PintScalar]):
def __len__(self) -> int:
return len(self._d)
- def __getitem__(self, key: str) -> PintScalar:
+ def __getitem__(self, key: str) -> Scalar:
return self._d[key]
def __contains__(self, key: str) -> bool:
@@ -554,10 +564,10 @@ class UnitsContainer(Mapping[str, PintScalar]):
return self._hash
# Only needed by pickle protocol 0 and 1 (used by pytables)
- def __getstate__(self) -> tuple[udict, PintScalar, type]:
+ def __getstate__(self) -> tuple[udict, Scalar, type]:
return self._d, self._one, self._non_int_type
- def __setstate__(self, state: tuple[udict, PintScalar, type]):
+ def __setstate__(self, state: tuple[udict, Scalar, type]):
self._d, self._one, self._non_int_type = state
self._hash = None
@@ -682,9 +692,9 @@ class ParserHelper(UnitsContainer):
__slots__ = ("scale",)
- scale: PintScalar
+ scale: Scalar
- def __init__(self, scale: PintScalar = 1, *args, **kwargs):
+ def __init__(self, scale: Scalar = 1, *args, **kwargs):
super().__init__(*args, **kwargs)
self.scale = scale
@@ -1002,7 +1012,7 @@ class PrettyIPython:
def to_units_container(
- unit_like: UnitLike | Quantity, registry: UnitRegistry | None = None
+ unit_like: QuantityOrUnitLike, registry: UnitRegistry | None = None
) -> UnitsContainer:
"""Convert a unit compatible type to a UnitsContainer.
@@ -1025,6 +1035,7 @@ def to_units_container(
return unit_like._units
elif str in mro:
if registry:
+ # TODO: Why not parse.units here?
return registry._parse_units(unit_like)
else:
return ParserHelper.from_string(unit_like)
@@ -1124,7 +1135,9 @@ def sized(y: Any) -> bool:
return True
-def create_class_with_registry(registry: UnitRegistry, base_class: type) -> type:
+def create_class_with_registry(
+ registry: UnitRegistry, base_class: type[TT]
+) -> type[TT]:
"""Create new class inheriting from base_class and
filling _REGISTRY class attribute with an actual instanced registry.
"""