diff options
40 files changed, 1368 insertions, 833 deletions
diff --git a/pint/_typing.py b/pint/_typing.py index 65e355c..5177e78 100644 --- a/pint/_typing.py +++ b/pint/_typing.py @@ -1,9 +1,9 @@ from __future__ import annotations from typing import TYPE_CHECKING, Any, Callable, TypeVar, Union, Protocol +from decimal import Decimal +from fractions import Fraction -# TODO: Remove when 3.11 becomes minimal version. -Self = TypeVar("Self") if TYPE_CHECKING: from .facets.plain import PlainQuantity as Quantity @@ -11,7 +11,7 @@ if TYPE_CHECKING: from .util import UnitsContainer -class PintScalar(Protocol): +class ScalarProtocol(Protocol): def __add__(self, other: Any) -> Any: ... @@ -36,8 +36,20 @@ class PintScalar(Protocol): def __pow__(self, other: Any, modulo: Any) -> Any: ... + def __gt__(self, other: Any) -> bool: + ... + + def __lt__(self, other: Any) -> bool: + ... + + def __ge__(self, other: Any) -> bool: + ... -class PintArray(Protocol): + def __le__(self, other: Any) -> bool: + ... + + +class ArrayProtocol(Protocol): def __len__(self) -> int: ... @@ -48,18 +60,41 @@ class PintArray(Protocol): ... +HAS_NUMPY = False +if TYPE_CHECKING: + from .compat import HAS_NUMPY + +if HAS_NUMPY: + from .compat import np + + Scalar = Union[ScalarProtocol, float, int, Decimal, Fraction, np.number[Any]] + Array = Union[np.ndarray[Any, Any]] +else: + Scalar = Union[ScalarProtocol, float, int, Decimal, Fraction] + Array = ArrayProtocol + + # TODO: Change when Python 3.10 becomes minimal version. -# Magnitude = PintScalar | PintArray -Magnitude = Union[PintScalar, PintArray] +Magnitude = Union[ScalarProtocol, ArrayProtocol] -UnitLike = Union[str, "UnitsContainer", "Unit"] +UnitLike = Union[str, dict[str, Scalar], "UnitsContainer", "Unit"] QuantityOrUnitLike = Union["Quantity", UnitLike] -Shape = tuple[int, ...] +Shape = tuple[int] -_MagnitudeType = TypeVar("_MagnitudeType") S = TypeVar("S") FuncType = Callable[..., Any] F = TypeVar("F", bound=FuncType) + + +# TODO: Improve or delete types +QuantityArgument = Any + +T = TypeVar("T") + + +class Handler(Protocol): + def __getitem__(self, item: type[T]) -> Callable[[T], None]: + ... diff --git a/pint/compat.py b/pint/compat.py index 7b48efa..727ff99 100644 --- a/pint/compat.py +++ b/pint/compat.py @@ -20,6 +20,16 @@ from collections.abc import Mapping from typing import Any, NoReturn, Callable from collections.abc import Generator, Iterable +try: + from typing import TypeAlias # noqa +except ImportError: + from typing_extensions import TypeAlias # noqa + +try: + from typing import Self # noqa +except ImportError: + from typing_extensions import Self # noqa + def missing_dependency( package: str, display_name: str | None = None @@ -137,10 +147,10 @@ except ImportError: HAS_UNCERTAINTIES = False try: - from babel import Locale as Loc + from babel import Locale from babel import units as babel_units - babel_parse = Loc.parse + babel_parse = Locale.parse HAS_BABEL = hasattr(babel_units, "format_unit") except ImportError: diff --git a/pint/converters.py b/pint/converters.py index 9494ad1..822b8a0 100644 --- a/pint/converters.py +++ b/pint/converters.py @@ -15,16 +15,18 @@ from dataclasses import fields as dc_fields from typing import Any -from ._typing import Self, Magnitude +from ._typing import Magnitude -from .compat import HAS_NUMPY, exp, log # noqa: F401 +from .compat import HAS_NUMPY, exp, log, Self # noqa: F401 @dataclass(frozen=True) class Converter: """Base class for value converters.""" + # list[type[Converter]] _subclasses = [] + # dict[frozenset[str], type[Converter]] _param_names_to_subclass = {} @property @@ -41,21 +43,21 @@ class Converter: def from_reference(self, value: Magnitude, inplace: bool = False) -> Magnitude: return value - def __init_subclass__(cls, **kwargs): + def __init_subclass__(cls, **kwargs: Any): # Get constructor parameters super().__init_subclass__(**kwargs) cls._subclasses.append(cls) @classmethod - def get_field_names(cls, new_cls) -> frozenset[str]: + def get_field_names(cls, new_cls: type) -> frozenset[str]: return frozenset(p.name for p in dc_fields(new_cls)) @classmethod - def preprocess_kwargs(cls, **kwargs): + def preprocess_kwargs(cls, **kwargs: Any) -> dict[str, Any] | None: return None @classmethod - def from_arguments(cls: type[Self], **kwargs: Any) -> Self: + def from_arguments(cls, **kwargs: Any) -> Converter: kwk = frozenset(kwargs.keys()) try: new_cls = cls._param_names_to_subclass[kwk] diff --git a/pint/delegates/txt_defparser/defparser.py b/pint/delegates/txt_defparser/defparser.py index f1b8e45..4acea2f 100644 --- a/pint/delegates/txt_defparser/defparser.py +++ b/pint/delegates/txt_defparser/defparser.py @@ -130,7 +130,7 @@ class DefParser: else: yield stmt - def parse_file(self, filename: pathlib.Path, cfg: ParserConfig | None = None): + def parse_file(self, filename: pathlib.Path | str, cfg: ParserConfig | None = None): return fp.parse( filename, _PintParser, diff --git a/pint/facets/__init__.py b/pint/facets/__init__.py index 750f729..4fd1597 100644 --- a/pint/facets/__init__.py +++ b/pint/facets/__init__.py @@ -71,15 +71,18 @@ from __future__ import annotations -from .context import ContextRegistry -from .dask import DaskRegistry -from .formatting import FormattingRegistry -from .group import GroupRegistry -from .measurement import MeasurementRegistry -from .nonmultiplicative import NonMultiplicativeRegistry -from .numpy import NumpyRegistry -from .plain import PlainRegistry -from .system import SystemRegistry +from .context import ContextRegistry, GenericContextRegistry +from .dask import DaskRegistry, GenericDaskRegistry +from .formatting import FormattingRegistry, GenericFormattingRegistry +from .group import GroupRegistry, GenericGroupRegistry +from .measurement import MeasurementRegistry, GenericMeasurementRegistry +from .nonmultiplicative import ( + NonMultiplicativeRegistry, + GenericNonMultiplicativeRegistry, +) +from .numpy import NumpyRegistry, GenericNumpyRegistry +from .plain import PlainRegistry, GenericPlainRegistry, QuantityT, UnitT, MagnitudeT +from .system import SystemRegistry, GenericSystemRegistry __all__ = [ "ContextRegistry", @@ -91,4 +94,16 @@ __all__ = [ "NumpyRegistry", "PlainRegistry", "SystemRegistry", + "GenericContextRegistry", + "GenericDaskRegistry", + "GenericFormattingRegistry", + "GenericGroupRegistry", + "GenericMeasurementRegistry", + "GenericNonMultiplicativeRegistry", + "GenericNumpyRegistry", + "GenericPlainRegistry", + "GenericSystemRegistry", + "QuantityT", + "UnitT", + "MagnitudeT", ] diff --git a/pint/facets/context/__init__.py b/pint/facets/context/__init__.py index db28436..28c7b5c 100644 --- a/pint/facets/context/__init__.py +++ b/pint/facets/context/__init__.py @@ -13,6 +13,6 @@ from __future__ import annotations from .definitions import ContextDefinition from .objects import Context -from .registry import ContextRegistry +from .registry import ContextRegistry, GenericContextRegistry -__all__ = ["ContextDefinition", "Context", "ContextRegistry"] +__all__ = ["ContextDefinition", "Context", "ContextRegistry", "GenericContextRegistry"] diff --git a/pint/facets/context/definitions.py b/pint/facets/context/definitions.py index 833857e..d2581d5 100644 --- a/pint/facets/context/definitions.py +++ b/pint/facets/context/definitions.py @@ -12,7 +12,7 @@ import itertools import numbers import re from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Callable +from typing import TYPE_CHECKING, Callable from collections.abc import Iterable from ... import errors @@ -47,7 +47,7 @@ class Relation: return set(self._varname_re.findall(self.equation)) @property - def transformation(self) -> Callable[..., Quantity[Any]]: + def transformation(self) -> Callable[..., Quantity]: """Return a transformation callable that uses the registry to parse the transformation equation. """ @@ -68,7 +68,7 @@ class ForwardRelation(Relation): """ @property - def bidirectional(self): + def bidirectional(self) -> bool: return False @@ -82,7 +82,7 @@ class BidirectionalRelation(Relation): """ @property - def bidirectional(self): + def bidirectional(self) -> bool: return True diff --git a/pint/facets/context/objects.py b/pint/facets/context/objects.py index 38d8805..9517821 100644 --- a/pint/facets/context/objects.py +++ b/pint/facets/context/objects.py @@ -10,12 +10,32 @@ from __future__ import annotations import weakref from collections import ChainMap, defaultdict -from typing import Any +from typing import Any, Callable, Protocol, Generic from collections.abc import Iterable -from ...facets.plain import UnitDefinition +from ...facets.plain import UnitDefinition, PlainQuantity, PlainUnit, MagnitudeT from ...util import UnitsContainer, to_units_container from .definitions import ContextDefinition +from ..._typing import Magnitude + + +class Transformation(Protocol): + def __call__(self, value: Magnitude, **kwargs: Any) -> Magnitude: + ... + + +from ..._typing import UnitLike + +ToBaseFunc = Callable[[UnitsContainer], UnitsContainer] +SrcDst = tuple[UnitsContainer, UnitsContainer] + + +class ContextQuantity(Generic[MagnitudeT], PlainQuantity[MagnitudeT]): + pass + + +class ContextUnit(PlainUnit): + pass class Context: @@ -75,24 +95,27 @@ class Context: aliases: tuple[str] = tuple(), defaults: dict[str, Any] | None = None, ) -> None: - self.name = name - self.aliases = aliases + self.name: str | None = name + self.aliases: tuple[str] = aliases #: Maps (src, dst) -> transformation function - self.funcs = {} + self.funcs: dict[SrcDst, Transformation] = {} #: Maps defaults variable names to values - self.defaults = defaults or {} + self.defaults: dict[str, Any] = defaults or {} # Store Definition objects that are context-specific - self.redefinitions = [] + # TODO: narrow type this if possible. + self.redefinitions: list[Any] = [] # Flag set to True by the Registry the first time the context is enabled self.checked = False #: Maps (src, dst) -> self #: Used as a convenience dictionary to be composed by ContextChain - self.relation_to_context = weakref.WeakValueDictionary() + self.relation_to_context: weakref.WeakValueDictionary[ + SrcDst, Context + ] = weakref.WeakValueDictionary() @classmethod def from_context(cls, context: Context, **defaults: Any) -> Context: @@ -125,13 +148,22 @@ class Context: @classmethod def from_lines( - cls, lines: Iterable[str], to_base_func=None, non_int_type: type = float + cls, + lines: Iterable[str], + to_base_func: ToBaseFunc | None = None, + non_int_type: type = float, ) -> Context: - cd = ContextDefinition.from_lines(lines, non_int_type) - return cls.from_definition(cd, to_base_func) + context_definition = ContextDefinition.from_lines(lines, non_int_type) + + if context_definition is None: + raise ValueError(f"Could not define Context from from {lines}") + + return cls.from_definition(context_definition, to_base_func) @classmethod - def from_definition(cls, cd: ContextDefinition, to_base_func=None) -> Context: + def from_definition( + cls, cd: ContextDefinition, to_base_func: ToBaseFunc | None = None + ) -> Context: ctx = cls(cd.name, cd.aliases, cd.defaults) for definition in cd.redefinitions: @@ -139,6 +171,7 @@ class Context: for relation in cd.relations: try: + # TODO: check to_base_func. Is it a good API idea? if to_base_func: src = to_base_func(relation.src) dst = to_base_func(relation.dst) @@ -154,14 +187,16 @@ class Context: return ctx - def add_transformation(self, src, dst, func) -> None: + def add_transformation( + self, src: UnitLike, dst: UnitLike, func: Transformation + ) -> None: """Add a transformation function to the context.""" _key = self.__keytransform__(src, dst) self.funcs[_key] = func self.relation_to_context[_key] = self - def remove_transformation(self, src, dst) -> None: + def remove_transformation(self, src: UnitLike, dst: UnitLike) -> None: """Add a transformation function to the context.""" _key = self.__keytransform__(src, dst) @@ -169,14 +204,17 @@ class Context: del self.relation_to_context[_key] @staticmethod - def __keytransform__(src, dst) -> tuple[UnitsContainer, UnitsContainer]: + def __keytransform__(src: UnitLike, dst: UnitLike) -> SrcDst: return to_units_container(src), to_units_container(dst) - def transform(self, src, dst, registry, value): + def transform( + self, src: UnitLike, dst: UnitLike, registry: Any, value: Magnitude + ) -> Magnitude: """Transform a value.""" _key = self.__keytransform__(src, dst) - return self.funcs[_key](registry, value, **self.defaults) + func = self.funcs[_key] + return func(registry, value, **self.defaults) def redefine(self, definition: str) -> None: """Override the definition of a unit in the registry. @@ -202,7 +240,13 @@ class Context: def hashable( self, - ) -> tuple[str | None, tuple[str, ...], frozenset, frozenset, tuple]: + ) -> tuple[ + str | None, + tuple[str], + frozenset[tuple[SrcDst, int]], + frozenset[tuple[str, Any]], + tuple[Any], + ]: """Generate a unique hashable and comparable representation of self, which can be used as a key in a dict. This class cannot define ``__hash__`` because it is mutable, and the Python interpreter does cache the output of ``__hash__``. @@ -220,18 +264,18 @@ class Context: ) -class ContextChain(ChainMap): +class ContextChain(ChainMap[SrcDst, Context]): """A specialized ChainMap for contexts that simplifies finding rules to transform from one dimension to another. """ def __init__(self): super().__init__() - self.contexts = [] + self.contexts: list[Context] = [] self.maps.clear() # Remove default empty map - self._graph = None + self._graph: dict[SrcDst, set[UnitsContainer]] | None = None - def insert_contexts(self, *contexts): + def insert_contexts(self, *contexts: Context): """Insert one or more contexts in reversed order the chained map. (A rule in last context will take precedence) @@ -243,7 +287,7 @@ class ContextChain(ChainMap): self.maps = [ctx.relation_to_context for ctx in reversed(contexts)] + self.maps self._graph = None - def remove_contexts(self, n: int = None): + def remove_contexts(self, n: int | None = None): """Remove the last n inserted contexts from the chain. Parameters @@ -257,7 +301,7 @@ class ContextChain(ChainMap): self._graph = None @property - def defaults(self): + def defaults(self) -> dict[str, Any]: for ctx in self.values(): return ctx.defaults return {} @@ -271,7 +315,10 @@ class ContextChain(ChainMap): self._graph[fr_].add(to_) return self._graph - def transform(self, src, dst, registry, value): + # TODO: type registry + def transform( + self, src: UnitsContainer, dst: UnitsContainer, registry: Any, value: Magnitude + ): """Transform the value, finding the rule in the chained context. (A rule in last context will take precedence) """ diff --git a/pint/facets/context/registry.py b/pint/facets/context/registry.py index a36d82d..746e79c 100644 --- a/pint/facets/context/registry.py +++ b/pint/facets/context/registry.py @@ -11,12 +11,13 @@ from __future__ import annotations import functools from collections import ChainMap from contextlib import contextmanager -from typing import Any, Callable, ContextManager +from typing import Any, Callable, Generator, Generic -from ..._typing import F +from ...compat import TypeAlias +from ..._typing import F, Magnitude from ...errors import UndefinedUnitError -from ...util import find_connected_nodes, find_shortest_path, logger -from ..plain import PlainRegistry, UnitDefinition +from ...util import find_connected_nodes, find_shortest_path, logger, UnitsContainer +from ..plain import GenericPlainRegistry, UnitDefinition, QuantityT, UnitT from .definitions import ContextDefinition from . import objects @@ -36,7 +37,9 @@ class ContextCacheOverlay: self.parse_unit = registry_cache.parse_unit -class ContextRegistry(PlainRegistry): +class GenericContextRegistry( + Generic[QuantityT, UnitT], GenericPlainRegistry[QuantityT, UnitT] +): """Handle of Contexts. Conversion between units with different dimensions according @@ -50,7 +53,7 @@ class ContextRegistry(PlainRegistry): - Parse @context directive. """ - Context = objects.Context + Context: type[objects.Context] = objects.Context def __init__(self, **kwargs: Any) -> None: # Map context name (string) or abbreviation to context. @@ -65,13 +68,13 @@ class ContextRegistry(PlainRegistry): super().__init__(**kwargs) # Allow contexts to add override layers to the units - self._units = ChainMap(self._units) + self._units: ChainMap[str, UnitDefinition] = ChainMap(self._units) def _register_definition_adders(self) -> None: super()._register_definition_adders() self._register_adder(ContextDefinition, self.add_context) - def add_context(self, context: Context | ContextDefinition) -> None: + def add_context(self, context: objects.Context | ContextDefinition) -> None: """Add a context object to the registry. The context will be accessible by its name and aliases. @@ -194,7 +197,7 @@ class ContextRegistry(PlainRegistry): self.define(definition) def enable_contexts( - self, *names_or_contexts: str | objects.Context, **kwargs + self, *names_or_contexts: str | objects.Context, **kwargs: Any ) -> None: """Enable contexts provided by name or by object. @@ -241,7 +244,7 @@ class ContextRegistry(PlainRegistry): self._active_ctx.insert_contexts(*contexts) self._switch_context_cache_and_units() - def disable_contexts(self, n: int = None) -> None: + def disable_contexts(self, n: int | None = None) -> None: """Disable the last n enabled contexts. Parameters @@ -253,7 +256,9 @@ class ContextRegistry(PlainRegistry): self._switch_context_cache_and_units() @contextmanager - def context(self, *names, **kwargs) -> ContextManager[objects.Context]: + def context( + self: GenericContextRegistry[QuantityT, UnitT], *names: str, **kwargs: Any + ) -> Generator[GenericContextRegistry[QuantityT, UnitT], None, None]: """Used as a context manager, this function enables to activate a context which is removed after usage. @@ -309,7 +314,7 @@ class ContextRegistry(PlainRegistry): # the added contexts are removed from the active one. self.disable_contexts(len(names)) - def with_context(self, name, **kwargs) -> Callable[[F], F]: + def with_context(self, name: str, **kwargs: Any) -> Callable[[F], F]: """Decorator to wrap a function call in a Pint context. Use it to ensure that a certain context is active when @@ -351,7 +356,13 @@ class ContextRegistry(PlainRegistry): return decorator - def _convert(self, value, src, dst, inplace=False): + def _convert( + self, + value: Magnitude, + src: UnitsContainer, + dst: UnitsContainer, + inplace: bool = False, + ) -> Magnitude: """Convert value from some source to destination units. In addition to what is done by the PlainRegistry, @@ -391,7 +402,9 @@ class ContextRegistry(PlainRegistry): return super()._convert(value, src, dst, inplace) - def _get_compatible_units(self, input_units, group_or_system): + def _get_compatible_units( + self, input_units: UnitsContainer, group_or_system: str | None = None + ): src_dim = self._get_dimensionality(input_units) ret = super()._get_compatible_units(input_units, group_or_system) @@ -404,3 +417,10 @@ class ContextRegistry(PlainRegistry): ret |= self._cache.dimensional_equivalents[node] return ret + + +class ContextRegistry( + GenericContextRegistry[objects.ContextQuantity[Any], objects.ContextUnit] +): + Quantity: TypeAlias = objects.ContextQuantity[Any] + Unit: TypeAlias = objects.ContextUnit diff --git a/pint/facets/dask/__init__.py b/pint/facets/dask/__init__.py index 90c8972..8d62f55 100644 --- a/pint/facets/dask/__init__.py +++ b/pint/facets/dask/__init__.py @@ -11,10 +11,18 @@ from __future__ import annotations +from typing import Generic, Any import functools -from ...compat import compute, dask_array, persist, visualize -from ..plain import PlainRegistry, PlainQuantity +from ...compat import compute, dask_array, persist, visualize, TypeAlias +from ..plain import ( + GenericPlainRegistry, + PlainQuantity, + QuantityT, + UnitT, + PlainUnit, + MagnitudeT, +) def check_dask_array(f): @@ -31,7 +39,7 @@ def check_dask_array(f): return wrapper -class DaskQuantity(PlainQuantity): +class DaskQuantity(Generic[MagnitudeT], PlainQuantity[MagnitudeT]): # Dask.array.Array ducking def __dask_graph__(self): if isinstance(self._magnitude, dask_array.Array): @@ -119,5 +127,16 @@ class DaskQuantity(PlainQuantity): visualize(self, **kwargs) -class DaskRegistry(PlainRegistry): - Quantity = DaskQuantity +class DaskUnit(PlainUnit): + pass + + +class GenericDaskRegistry( + Generic[QuantityT, UnitT], GenericPlainRegistry[QuantityT, UnitT] +): + pass + + +class DaskRegistry(GenericDaskRegistry[DaskQuantity[Any], DaskUnit]): + Quantity: TypeAlias = DaskQuantity[Any] + Unit: TypeAlias = DaskUnit diff --git a/pint/facets/formatting/__init__.py b/pint/facets/formatting/__init__.py index e3f4381..799fa31 100644 --- a/pint/facets/formatting/__init__.py +++ b/pint/facets/formatting/__init__.py @@ -11,6 +11,11 @@ from __future__ import annotations from .objects import FormattingQuantity, FormattingUnit -from .registry import FormattingRegistry +from .registry import FormattingRegistry, GenericFormattingRegistry -__all__ = ["FormattingQuantity", "FormattingUnit", "FormattingRegistry"] +__all__ = [ + "FormattingQuantity", + "FormattingUnit", + "FormattingRegistry", + "GenericFormattingRegistry", +] diff --git a/pint/facets/formatting/objects.py b/pint/facets/formatting/objects.py index 5df937c..7d39e91 100644 --- a/pint/facets/formatting/objects.py +++ b/pint/facets/formatting/objects.py @@ -9,7 +9,7 @@ from __future__ import annotations import re -from typing import Any +from typing import Any, Generic from ...compat import babel_parse, ndarray, np from ...formatting import ( @@ -23,10 +23,10 @@ from ...formatting import ( ) from ...util import UnitsContainer, iterable -from ..plain import PlainQuantity, PlainUnit +from ..plain import PlainQuantity, PlainUnit, MagnitudeT -class FormattingQuantity(PlainQuantity): +class FormattingQuantity(Generic[MagnitudeT], PlainQuantity[MagnitudeT]): _exp_pattern = re.compile(r"([0-9]\.?[0-9]*)e(-?)\+?0*([0-9]+)") def __format__(self, spec: str) -> str: diff --git a/pint/facets/formatting/registry.py b/pint/facets/formatting/registry.py index c4dc373..7684597 100644 --- a/pint/facets/formatting/registry.py +++ b/pint/facets/formatting/registry.py @@ -8,10 +8,21 @@ from __future__ import annotations -from ..plain import PlainRegistry -from .objects import FormattingQuantity, FormattingUnit +from typing import Generic, Any +from ...compat import TypeAlias +from ..plain import GenericPlainRegistry, QuantityT, UnitT +from . import objects -class FormattingRegistry(PlainRegistry): - Quantity = FormattingQuantity - Unit = FormattingUnit + +class GenericFormattingRegistry( + Generic[QuantityT, UnitT], GenericPlainRegistry[QuantityT, UnitT] +): + pass + + +class FormattingRegistry( + GenericFormattingRegistry[objects.FormattingQuantity[Any], objects.FormattingUnit] +): + Quantity: TypeAlias = objects.FormattingQuantity[Any] + Unit: TypeAlias = objects.FormattingUnit diff --git a/pint/facets/group/__init__.py b/pint/facets/group/__init__.py index e1fad04..b25ea85 100644 --- a/pint/facets/group/__init__.py +++ b/pint/facets/group/__init__.py @@ -11,7 +11,14 @@ from __future__ import annotations from .definitions import GroupDefinition -from .objects import Group -from .registry import GroupRegistry +from .objects import Group, GroupQuantity, GroupUnit +from .registry import GroupRegistry, GenericGroupRegistry -__all__ = ["GroupDefinition", "Group", "GroupRegistry"] +__all__ = [ + "GroupDefinition", + "Group", + "GroupRegistry", + "GenericGroupRegistry", + "GroupQuantity", + "GroupUnit", +] diff --git a/pint/facets/group/definitions.py b/pint/facets/group/definitions.py index 554a63b..2f34750 100644 --- a/pint/facets/group/definitions.py +++ b/pint/facets/group/definitions.py @@ -11,7 +11,7 @@ from __future__ import annotations from collections.abc import Iterable from dataclasses import dataclass -from ..._typing import Self +from ...compat import Self from ... import errors from .. import plain diff --git a/pint/facets/group/objects.py b/pint/facets/group/objects.py index 200a323..64d91c1 100644 --- a/pint/facets/group/objects.py +++ b/pint/facets/group/objects.py @@ -8,9 +8,36 @@ from __future__ import annotations +from typing import Callable, Any, TYPE_CHECKING, Generic + from collections.abc import Generator, Iterable from ...util import SharedRegistryObject, getattr_maybe_raise from .definitions import GroupDefinition +from ..plain import PlainQuantity, PlainUnit, MagnitudeT + +if TYPE_CHECKING: + from ..plain import UnitDefinition + + DefineFunc = Callable[ + [ + Any, + ], + None, + ] + AddUnitFunc = Callable[ + [ + UnitDefinition, + ], + None, + ] + + +class GroupQuantity(Generic[MagnitudeT], PlainQuantity[MagnitudeT]): + pass + + +class GroupUnit(PlainUnit): + pass class Group(SharedRegistryObject): @@ -57,7 +84,7 @@ class Group(SharedRegistryObject): self._computed_members: frozenset[str] | None = None @property - def members(self): + def members(self) -> frozenset[str]: """Names of the units that are members of the group. Calculated to include to all units in all included _used_groups. @@ -143,7 +170,7 @@ class Group(SharedRegistryObject): @classmethod def from_lines( - cls, lines: Iterable[str], define_func, non_int_type: type = float + cls, lines: Iterable[str], define_func: DefineFunc, non_int_type: type = float ) -> Group: """Return a Group object parsing an iterable of lines. @@ -160,11 +187,15 @@ class Group(SharedRegistryObject): """ group_definition = GroupDefinition.from_lines(lines, non_int_type) + + if group_definition is None: + raise ValueError(f"Could not define group from {lines}") + return cls.from_definition(group_definition, define_func) @classmethod def from_definition( - cls, group_definition: GroupDefinition, add_unit_func=None + cls, group_definition: GroupDefinition, add_unit_func: AddUnitFunc | None = None ) -> Group: grp = cls(group_definition.name) diff --git a/pint/facets/group/registry.py b/pint/facets/group/registry.py index 0d35ae0..f130e61 100644 --- a/pint/facets/group/registry.py +++ b/pint/facets/group/registry.py @@ -8,20 +8,28 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Generic, Any +from ...compat import TypeAlias from ... import errors if TYPE_CHECKING: - from ..._typing import Unit - -from ...util import create_class_with_registry -from ..plain import PlainRegistry, UnitDefinition + from ..._typing import Unit, UnitsContainer + +from ...util import create_class_with_registry, to_units_container +from ..plain import ( + GenericPlainRegistry, + UnitDefinition, + QuantityT, + UnitT, +) from .definitions import GroupDefinition from . import objects -class GroupRegistry(PlainRegistry): +class GenericGroupRegistry( + Generic[QuantityT, UnitT], GenericPlainRegistry[QuantityT, UnitT] +): """Handle of Groups. Group units @@ -34,7 +42,7 @@ class GroupRegistry(PlainRegistry): # TODO: Change this to Group: Group to specify class # and use introspection to get system class as a way # to enjoy typing goodies - Group = objects.Group + Group = type[objects.Group] def __init__(self, **kwargs): super().__init__(**kwargs) @@ -46,7 +54,7 @@ class GroupRegistry(PlainRegistry): def _init_dynamic_classes(self) -> None: """Generate subclasses on the fly and attach them to self""" super()._init_dynamic_classes() - self.Group = create_class_with_registry(self, self.Group) + self.Group = create_class_with_registry(self, objects.Group) def _after_init(self) -> None: """Invoked at the end of ``__init__``. @@ -113,8 +121,23 @@ class GroupRegistry(PlainRegistry): return self.Group(name) - def _get_compatible_units(self, input_units, group) -> frozenset[Unit]: - ret = super()._get_compatible_units(input_units, group) + def get_compatible_units( + self, input_units: UnitsContainer, group: str | None = None + ) -> frozenset[Unit]: + """ """ + if group is None: + return super().get_compatible_units(input_units) + + input_units = to_units_container(input_units) + + equiv = self._get_compatible_units(input_units, group) + + return frozenset(self.Unit(eq) for eq in equiv) + + def _get_compatible_units( + self, input_units: UnitsContainer, group: str | None = None + ) -> frozenset[str]: + ret = super()._get_compatible_units(input_units) if not group: return ret @@ -124,3 +147,10 @@ class GroupRegistry(PlainRegistry): else: raise ValueError("Unknown Group with name '%s'" % group) return frozenset(ret & members) + + +class GroupRegistry( + GenericGroupRegistry[objects.GroupQuantity[Any], objects.GroupUnit] +): + Quantity: TypeAlias = objects.GroupQuantity[Any] + Unit: TypeAlias = objects.GroupUnit diff --git a/pint/facets/measurement/__init__.py b/pint/facets/measurement/__init__.py index 21539dc..d36a5c3 100644 --- a/pint/facets/measurement/__init__.py +++ b/pint/facets/measurement/__init__.py @@ -11,6 +11,11 @@ from __future__ import annotations from .objects import Measurement, MeasurementQuantity -from .registry import MeasurementRegistry +from .registry import MeasurementRegistry, GenericMeasurementRegistry -__all__ = ["Measurement", "MeasurementQuantity", "MeasurementRegistry"] +__all__ = [ + "Measurement", + "MeasurementQuantity", + "MeasurementRegistry", + "GenericMeasurementRegistry", +] diff --git a/pint/facets/measurement/objects.py b/pint/facets/measurement/objects.py index 5f3ba7a..b9cacda 100644 --- a/pint/facets/measurement/objects.py +++ b/pint/facets/measurement/objects.py @@ -10,15 +10,16 @@ from __future__ import annotations import copy import re +from typing import Generic from ...compat import ufloat from ...formatting import _FORMATS, extract_custom_flags, siunitx_format_unit -from ..plain import PlainQuantity +from ..plain import PlainQuantity, PlainUnit, MagnitudeT MISSING = object() -class MeasurementQuantity(PlainQuantity): +class MeasurementQuantity(Generic[MagnitudeT], PlainQuantity[MagnitudeT]): # Measurement support def plus_minus(self, error, relative=False): if isinstance(error, self.__class__): @@ -32,6 +33,10 @@ class MeasurementQuantity(PlainQuantity): return self._REGISTRY.Measurement(copy.copy(self.magnitude), error, self._units) +class MeasurementUnit(PlainUnit): + pass + + class Measurement(PlainQuantity): """Implements a class to describe a quantity with uncertainty. diff --git a/pint/facets/measurement/registry.py b/pint/facets/measurement/registry.py index 0fc4391..4a3e878 100644 --- a/pint/facets/measurement/registry.py +++ b/pint/facets/measurement/registry.py @@ -9,15 +9,17 @@ from __future__ import annotations -from ...compat import ufloat +from typing import Generic, Any + +from ...compat import ufloat, TypeAlias from ...util import create_class_with_registry -from ..plain import PlainRegistry -from .objects import MeasurementQuantity +from ..plain import GenericPlainRegistry, QuantityT, UnitT from . import objects -class MeasurementRegistry(PlainRegistry): - Quantity = MeasurementQuantity +class GenericMeasurementRegistry( + Generic[QuantityT, UnitT], GenericPlainRegistry[QuantityT, UnitT] +): Measurement = objects.Measurement def _init_dynamic_classes(self) -> None: @@ -34,3 +36,12 @@ class MeasurementRegistry(PlainRegistry): ) self.Measurement = no_uncertainties + + +class MeasurementRegistry( + GenericMeasurementRegistry[ + objects.MeasurementQuantity[Any], objects.MeasurementUnit + ] +): + Quantity: TypeAlias = objects.MeasurementQuantity[Any] + Unit: TypeAlias = objects.MeasurementUnit diff --git a/pint/facets/nonmultiplicative/__init__.py b/pint/facets/nonmultiplicative/__init__.py index cbba410..eb3292b 100644 --- a/pint/facets/nonmultiplicative/__init__.py +++ b/pint/facets/nonmultiplicative/__init__.py @@ -15,8 +15,6 @@ from __future__ import annotations # This import register LogarithmicConverter and OffsetConverter to be usable # (via subclassing) from .definitions import LogarithmicConverter, OffsetConverter # noqa: F401 -from .registry import NonMultiplicativeRegistry +from .registry import NonMultiplicativeRegistry, GenericNonMultiplicativeRegistry -__all__ = [ - "NonMultiplicativeRegistry", -] +__all__ = ["NonMultiplicativeRegistry", "GenericNonMultiplicativeRegistry"] diff --git a/pint/facets/nonmultiplicative/objects.py b/pint/facets/nonmultiplicative/objects.py index 0ab743e..8b944b1 100644 --- a/pint/facets/nonmultiplicative/objects.py +++ b/pint/facets/nonmultiplicative/objects.py @@ -8,10 +8,12 @@ from __future__ import annotations -from ..plain import PlainQuantity +from typing import Generic +from ..plain import PlainQuantity, PlainUnit, MagnitudeT -class NonMultiplicativeQuantity(PlainQuantity): + +class NonMultiplicativeQuantity(Generic[MagnitudeT], PlainQuantity[MagnitudeT]): @property def _is_multiplicative(self) -> bool: """Check if the PlainQuantity object has only multiplicative units.""" @@ -59,3 +61,7 @@ class NonMultiplicativeQuantity(PlainQuantity): if next(iter(self._units.values())) != 1: is_ok = False return is_ok + + +class NonMultiplicativeUnit(PlainUnit): + pass diff --git a/pint/facets/nonmultiplicative/registry.py b/pint/facets/nonmultiplicative/registry.py index 8bc04db..505406c 100644 --- a/pint/facets/nonmultiplicative/registry.py +++ b/pint/facets/nonmultiplicative/registry.py @@ -8,16 +8,22 @@ from __future__ import annotations -from typing import Any +from typing import Any, TypeVar, Generic +from ...compat import TypeAlias from ...errors import DimensionalityError, UndefinedUnitError from ...util import UnitsContainer, logger -from ..plain import PlainRegistry, UnitDefinition +from ..plain import GenericPlainRegistry, UnitDefinition, QuantityT, UnitT from .definitions import OffsetConverter, ScaleConverter -from .objects import NonMultiplicativeQuantity +from . import objects -class NonMultiplicativeRegistry(PlainRegistry): +T = TypeVar("T") + + +class GenericNonMultiplicativeRegistry( + Generic[QuantityT, UnitT], GenericPlainRegistry[QuantityT, UnitT] +): """Handle of non multiplicative units (e.g. Temperature). Capabilities: @@ -35,8 +41,6 @@ class NonMultiplicativeRegistry(PlainRegistry): """ - Quantity = NonMultiplicativeQuantity - def __init__( self, default_as_delta: bool = True, @@ -58,14 +62,14 @@ class NonMultiplicativeRegistry(PlainRegistry): input_string: str, as_delta: bool | None = None, case_sensitive: bool | None = None, - ): + ) -> UnitsContainer: """ """ if as_delta is None: as_delta = self.default_as_delta return super()._parse_units(input_string, as_delta, case_sensitive) - def _add_unit(self, definition: UnitDefinition): + def _add_unit(self, definition: UnitDefinition) -> None: super()._add_unit(definition) if definition.is_multiplicative: @@ -104,22 +108,60 @@ class NonMultiplicativeRegistry(PlainRegistry): ) super()._add_unit(delta_def) - def _is_multiplicative(self, u) -> bool: - if u in self._units: - return self._units[u].is_multiplicative + def _is_multiplicative(self, unit_name: str) -> bool: + """True if the unit is multiplicative. + + Parameters + ---------- + unit_name + Name of the unit to check. + Can be prefixed, pluralized or even an alias + + Raises + ------ + UndefinedUnitError + If the unit is not in the registyr. + """ + if unit_name in self._units: + return self._units[unit_name].is_multiplicative # If the unit is not in the registry might be because it is not # registered with its prefixed version. # TODO: Might be better to register them. - names = self.parse_unit_name(u) + names = self.parse_unit_name(unit_name) assert len(names) == 1 _, base_name, _ = names[0] try: return self._units[base_name].is_multiplicative except KeyError: - raise UndefinedUnitError(u) + raise UndefinedUnitError(unit_name) + + def _validate_and_extract(self, units: UnitsContainer) -> str | None: + """Used to check if a given units is suitable for a simple + conversion. + + Return None if all units are non-multiplicative + Return the unit name if a single non-multiplicative unit is found + and is raised to a power equals to 1. + + Otherwise, raise an Exception. + + Parameters + ---------- + units + Compound dictionary. + + Raises + ------ + ValueError + If the more than a single non-multiplicative unit is present, + or a single one is present but raised to a power different from 1. + + """ + + # TODO: document what happens if autoconvert_offset_to_baseunit + # TODO: Clarify docs - def _validate_and_extract(self, units): # u is for unit, e is for exponent nonmult_units = [ (u, e) for u, e in units.items() if not self._is_multiplicative(u) @@ -147,11 +189,16 @@ class NonMultiplicativeRegistry(PlainRegistry): return None - def _add_ref_of_log_or_offset_unit(self, offset_unit, all_units): + def _add_ref_of_log_or_offset_unit( + self, offset_unit: str, all_units: UnitsContainer + ) -> UnitsContainer: slct_unit = self._units[offset_unit] if slct_unit.is_logarithmic or (not slct_unit.is_multiplicative): # Extract reference unit slct_ref = slct_unit.reference + + # TODO: Check that reference is None + # If reference unit is not dimensionless if slct_ref != UnitsContainer(): # Extract reference unit @@ -161,7 +208,9 @@ class NonMultiplicativeRegistry(PlainRegistry): # Otherwise, return the units unmodified return all_units - def _convert(self, value, src, dst, inplace=False): + def _convert( + self, value: T, src: UnitsContainer, dst: UnitsContainer, inplace: bool = False + ) -> T: """Convert value from some source to destination units. In addition to what is done by the PlainRegistry, @@ -235,3 +284,12 @@ class NonMultiplicativeRegistry(PlainRegistry): ) return value + + +class NonMultiplicativeRegistry( + GenericNonMultiplicativeRegistry[ + objects.NonMultiplicativeQuantity[Any], objects.NonMultiplicativeUnit + ] +): + Quantity: TypeAlias = objects.NonMultiplicativeQuantity[Any] + Unit: TypeAlias = objects.NonMultiplicativeUnit diff --git a/pint/facets/numpy/__init__.py b/pint/facets/numpy/__init__.py index aad9508..2e38dc1 100644 --- a/pint/facets/numpy/__init__.py +++ b/pint/facets/numpy/__init__.py @@ -10,6 +10,6 @@ from __future__ import annotations -from .registry import NumpyRegistry +from .registry import NumpyRegistry, GenericNumpyRegistry -__all__ = ["NumpyRegistry"] +__all__ = ["NumpyRegistry", "GenericNumpyRegistry"] diff --git a/pint/facets/numpy/quantity.py b/pint/facets/numpy/quantity.py index 131983c..880f860 100644 --- a/pint/facets/numpy/quantity.py +++ b/pint/facets/numpy/quantity.py @@ -11,11 +11,11 @@ from __future__ import annotations import functools import math import warnings -from typing import Any +from typing import Any, Generic -from ..plain import PlainQuantity +from ..plain import PlainQuantity, MagnitudeT -from ..._typing import Shape, _MagnitudeType +from ..._typing import Shape from ...compat import _to_magnitude, np from ...errors import DimensionalityError, PintTypeError, UnitStrippedWarning from .numpy_func import ( @@ -42,7 +42,7 @@ def method_wraps(numpy_func): return wrapper -class NumpyQuantity(PlainQuantity): +class NumpyQuantity(Generic[MagnitudeT], PlainQuantity[MagnitudeT]): """ """ # NumPy function/ufunc support @@ -130,11 +130,11 @@ class NumpyQuantity(PlainQuantity): raise DimensionalityError("dimensionless", self._units) return self.__class__(self.magnitude.clip(min, max, out, **kwargs), self._units) - def fill(self: NumpyQuantity[np.ndarray], value) -> None: + def fill(self: NumpyQuantity, value) -> None: self._units = value._units return self.magnitude.fill(value.magnitude) - def put(self: NumpyQuantity[np.ndarray], indices, values, mode="raise") -> None: + def put(self: NumpyQuantity, indices, values, mode="raise") -> None: if isinstance(values, self.__class__): values = values.to(self).magnitude elif self.dimensionless: @@ -144,11 +144,11 @@ class NumpyQuantity(PlainQuantity): self.magnitude.put(indices, values, mode) @property - def real(self) -> NumpyQuantity[_MagnitudeType]: + def real(self) -> NumpyQuantity: return self.__class__(self._magnitude.real, self._units) @property - def imag(self) -> NumpyQuantity[_MagnitudeType]: + def imag(self) -> NumpyQuantity: return self.__class__(self._magnitude.imag, self._units) @property diff --git a/pint/facets/numpy/registry.py b/pint/facets/numpy/registry.py index 11d57f3..e93de44 100644 --- a/pint/facets/numpy/registry.py +++ b/pint/facets/numpy/registry.py @@ -9,11 +9,20 @@ from __future__ import annotations -from ..plain import PlainRegistry +from typing import Generic, Any + +from ...compat import TypeAlias +from ..plain import GenericPlainRegistry, QuantityT, UnitT from .quantity import NumpyQuantity from .unit import NumpyUnit -class NumpyRegistry(PlainRegistry): - Quantity = NumpyQuantity - Unit = NumpyUnit +class GenericNumpyRegistry( + Generic[QuantityT, UnitT], GenericPlainRegistry[QuantityT, UnitT] +): + pass + + +class NumpyRegistry(GenericPlainRegistry[NumpyQuantity[Any], NumpyUnit]): + Quantity: TypeAlias = NumpyQuantity[Any] + Unit: TypeAlias = NumpyUnit diff --git a/pint/facets/plain/__init__.py b/pint/facets/plain/__init__.py index 211d017..90bf2e3 100644 --- a/pint/facets/plain/__init__.py +++ b/pint/facets/plain/__init__.py @@ -19,9 +19,11 @@ from .definitions import ( UnitDefinition, ) from .objects import PlainQuantity, PlainUnit -from .registry import PlainRegistry +from .registry import PlainRegistry, GenericPlainRegistry, QuantityT, UnitT +from .quantity import MagnitudeT __all__ = [ + "GenericPlainRegistry", "PlainUnit", "PlainQuantity", "PlainRegistry", @@ -31,4 +33,7 @@ __all__ = [ "PrefixDefinition", "ScaleConverter", "UnitDefinition", + "QuantityT", + "UnitT", + "MagnitudeT", ] diff --git a/pint/facets/plain/definitions.py b/pint/facets/plain/definitions.py index 79a44f1..4b352e7 100644 --- a/pint/facets/plain/definitions.py +++ b/pint/facets/plain/definitions.py @@ -13,7 +13,7 @@ import numbers import typing as ty from dataclasses import dataclass from functools import cached_property -from typing import Callable, Any +from typing import Any from ..._typing import Magnitude from ... import errors @@ -69,11 +69,15 @@ class DefaultsDefinition: @dataclass(frozen=True) -class PrefixDefinition(errors.WithDefErr): - """Definition of a prefix.""" - +class NamedDefinition: #: name of the prefix name: str + + +@dataclass(frozen=True) +class PrefixDefinition(NamedDefinition, errors.WithDefErr): + """Definition of a prefix.""" + #: scaling value for this prefix value: numbers.Number #: canonical symbol @@ -90,8 +94,8 @@ class PrefixDefinition(errors.WithDefErr): return bool(self.defined_symbol) @cached_property - def converter(self): - return Converter.from_arguments(scale=self.value) + def converter(self) -> ScaleConverter: + return ScaleConverter(self.value) def __post_init__(self): if not errors.is_valid_prefix_name(self.name): @@ -110,22 +114,19 @@ class PrefixDefinition(errors.WithDefErr): @dataclass(frozen=True) -class UnitDefinition(errors.WithDefErr): +class UnitDefinition(NamedDefinition, errors.WithDefErr): """Definition of a unit.""" - #: canonical name of the unit - name: str #: canonical symbol defined_symbol: str | None #: additional names for the same unit aliases: tuple[str] #: A functiont that converts a value in these units into the reference units - converter: Callable[ - [ - Magnitude, - ], - Magnitude, - ] | Converter | None + # TODO: this has changed as converter is now annotated as converter. + # Briefly, in several places converter attributes like as_multiplicative were + # accesed. So having a generic function is a no go. + # I guess this was never used as errors where not raised. + converter: Converter | None #: Reference units. reference: UnitsContainer | None @@ -190,7 +191,7 @@ class UnitDefinition(errors.WithDefErr): def is_base(self) -> bool: """Indicates if it is a base unit.""" - # TODO: why is this here + # TODO: This is set in __post_init__ return self._is_base @property @@ -215,17 +216,14 @@ class UnitDefinition(errors.WithDefErr): @dataclass(frozen=True) -class DimensionDefinition(errors.WithDefErr): +class DimensionDefinition(NamedDefinition, errors.WithDefErr): """Definition of a root dimension""" - #: name of the dimension - name: str - @property - def is_base(self): + def is_base(self) -> bool: return True - def __post_init__(self): + def __post_init__(self) -> None: if not errors.is_valid_dimension_name(self.name): raise self.def_err(errors.MSG_INVALID_DIMENSION_NAME) @@ -238,7 +236,7 @@ class DerivedDimensionDefinition(DimensionDefinition): reference: UnitsContainer @property - def is_base(self): + def is_base(self) -> bool: return False def __post_init__(self): diff --git a/pint/facets/plain/qto.py b/pint/facets/plain/qto.py new file mode 100644 index 0000000..72b8157 --- /dev/null +++ b/pint/facets/plain/qto.py @@ -0,0 +1,386 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import bisect +import math +import numbers +from ...util import infer_base_unit +import warnings +from ...compat import ( + mip_INF, + mip_INTEGER, + mip_model, + mip_Model, + mip_OptimizationStatus, + mip_xsum, +) + +if TYPE_CHECKING: + from ..._typing import UnitLike + from ...util import UnitsContainer + from .quantity import PlainQuantity + + +def _get_reduced_units( + quantity: PlainQuantity, units: UnitsContainer +) -> UnitsContainer: + # loop through individual units and compare to each other unit + # can we do better than a nested loop here? + for unit1, exp in units.items(): + # make sure it wasn't already reduced to zero exponent on prior pass + if unit1 not in units: + continue + for unit2 in units: + # get exponent after reduction + exp = units[unit1] + if unit1 != unit2: + power = quantity._REGISTRY._get_dimensionality_ratio(unit1, unit2) + if power: + units = units.add(unit2, exp / power).remove([unit1]) + break + return units + + +def ito_reduced_units(quantity: PlainQuantity) -> None: + """Return PlainQuantity scaled in place to reduced units, i.e. one unit per + dimension. This will not reduce compound units (e.g., 'J/kg' will not + be reduced to m**2/s**2), nor can it make use of contexts at this time. + """ + + # shortcuts in case we're dimensionless or only a single unit + if quantity.dimensionless: + return quantity.ito({}) + if len(quantity._units) == 1: + return None + + units = quantity._units.copy() + new_units = _get_reduced_units(quantity, units) + + return quantity.ito(new_units) + + +def to_reduced_units( + quantity: PlainQuantity, +) -> PlainQuantity: + """Return PlainQuantity scaled in place to reduced units, i.e. one unit per + dimension. This will not reduce compound units (intentionally), nor + can it make use of contexts at this time. + """ + + # shortcuts in case we're dimensionless or only a single unit + if quantity.dimensionless: + return quantity.to({}) + if len(quantity._units) == 1: + return quantity + + units = quantity._units.copy() + new_units = _get_reduced_units(quantity, units) + + return quantity.to(new_units) + + +def to_compact( + quantity: PlainQuantity, unit: UnitsContainer | None = None +) -> PlainQuantity: + """ "Return PlainQuantity rescaled to compact, human-readable units. + + To get output in terms of a different unit, use the unit parameter. + + + Examples + -------- + + >>> import pint + >>> ureg = pint.UnitRegistry() + >>> (200e-9*ureg.s).to_compact() + <Quantity(200.0, 'nanosecond')> + >>> (1e-2*ureg('kg m/s^2')).to_compact('N') + <Quantity(10.0, 'millinewton')> + """ + + if not isinstance(quantity.magnitude, numbers.Number): + msg = "to_compact applied to non numerical types " "has an undefined behavior." + w = RuntimeWarning(msg) + warnings.warn(w, stacklevel=2) + return quantity + + if ( + quantity.unitless + or quantity.magnitude == 0 + or math.isnan(quantity.magnitude) + or math.isinf(quantity.magnitude) + ): + return quantity + + SI_prefixes: dict[int, str] = {} + for prefix in quantity._REGISTRY._prefixes.values(): + try: + scale = prefix.converter.scale + # Kludgy way to check if this is an SI prefix + log10_scale = int(math.log10(scale)) + if log10_scale == math.log10(scale): + SI_prefixes[log10_scale] = prefix.name + except Exception: + SI_prefixes[0] = "" + + SI_prefixes_list = sorted(SI_prefixes.items()) + SI_powers = [item[0] for item in SI_prefixes_list] + SI_bases = [item[1] for item in SI_prefixes_list] + + if unit is None: + unit = infer_base_unit(quantity, registry=quantity._REGISTRY) + else: + unit = infer_base_unit(quantity.__class__(1, unit), registry=quantity._REGISTRY) + + q_base = quantity.to(unit) + + magnitude = q_base.magnitude + + units = list(q_base._units.items()) + units_numerator = [a for a in units if a[1] > 0] + + if len(units_numerator) > 0: + unit_str, unit_power = units_numerator[0] + else: + unit_str, unit_power = units[0] + + if unit_power > 0: + power = math.floor(math.log10(abs(magnitude)) / float(unit_power) / 3) * 3 + else: + power = math.ceil(math.log10(abs(magnitude)) / float(unit_power) / 3) * 3 + + index = bisect.bisect_left(SI_powers, power) + + if index >= len(SI_bases): + index = -1 + + prefix_str = SI_bases[index] + + new_unit_str = prefix_str + unit_str + new_unit_container = q_base._units.rename(unit_str, new_unit_str) + + return quantity.to(new_unit_container) + + +def to_preferred( + quantity: PlainQuantity, preferred_units: list[UnitLike] +) -> PlainQuantity: + """Return Quantity converted to a unit composed of the preferred units. + + Examples + -------- + + >>> import pint + >>> ureg = pint.UnitRegistry() + >>> (1*ureg.acre).to_preferred([ureg.meters]) + <Quantity(4046.87261, 'meter ** 2')> + >>> (1*(ureg.force_pound*ureg.m)).to_preferred([ureg.W]) + <Quantity(4.44822162, 'second * watt')> + """ + + if not quantity.dimensionality: + return quantity + + # The optimizer isn't perfect, and will sometimes miss obvious solutions. + # This sub-algorithm is less powerful, but always finds the very simple solutions. + def find_simple(): + best_ratio = None + best_unit = None + self_dims = sorted(quantity.dimensionality) + self_exps = [quantity.dimensionality[d] for d in self_dims] + s_exps_head, *s_exps_tail = self_exps + n = len(s_exps_tail) + for preferred_unit in preferred_units: + dims = sorted(preferred_unit.dimensionality) + if dims == self_dims: + p_exps_head, *p_exps_tail = ( + preferred_unit.dimensionality[d] for d in dims + ) + if all( + s_exps_tail[i] * p_exps_head == p_exps_tail[i] ** s_exps_head + for i in range(n) + ): + ratio = p_exps_head / s_exps_head + ratio = max(ratio, 1 / ratio) + if best_ratio is None or ratio < best_ratio: + best_ratio = ratio + best_unit = preferred_unit ** (s_exps_head / p_exps_head) + return best_unit + + simple = find_simple() + if simple is not None: + return quantity.to(simple) + + # For each dimension (e.g. T(ime), L(ength), M(ass)), assign a default base unit from + # the collection of base units + + unit_selections = { + base_unit.dimensionality: base_unit + for base_unit in map(quantity._REGISTRY.Unit, quantity._REGISTRY._base_units) + } + + # Override the default unit of each dimension with the 1D-units used in this Quantity + unit_selections.update( + { + unit.dimensionality: unit + for unit in map(quantity._REGISTRY.Unit, quantity._units.keys()) + } + ) + + # Determine the preferred unit for each dimensionality from the preferred_units + # (A prefered unit doesn't have to be only one dimensional, e.g. Watts) + preferred_dims = { + preferred_unit.dimensionality: preferred_unit + for preferred_unit in map(quantity._REGISTRY.Unit, preferred_units) + } + + # Combine the defaults and preferred, favoring the preferred + unit_selections.update(preferred_dims) + + # This algorithm has poor asymptotic time complexity, so first reduce the considered + # dimensions and units to only those that are useful to the problem + + # The dimensions (without powers) of this Quantity + dimension_set = set(quantity.dimensionality) + + # Getting zero exponents in dimensions not in dimension_set can be facilitated + # by units that interact with that dimension and one or more dimension_set members. + # For example MT^1 * LT^-1 lets you get MLT^0 when T is not in dimension_set. + # For each candidate unit that interacts with a dimension_set member, add the + # candidate unit's other dimensions to dimension_set, and repeat until no more + # dimensions are selected. + + discovery_done = False + while not discovery_done: + discovery_done = True + for d in unit_selections: + unit_dimensions = set(d) + intersection = unit_dimensions.intersection(dimension_set) + if 0 < len(intersection) < len(unit_dimensions): + # there are dimensions in this unit that are in dimension set + # and others that are not in dimension set + dimension_set = dimension_set.union(unit_dimensions) + discovery_done = False + break + + # filter out dimensions and their unit selections that don't interact with any + # dimension_set members + unit_selections = { + dimensionality: unit + for dimensionality, unit in unit_selections.items() + if set(dimensionality).intersection(dimension_set) + } + + # update preferred_units with the selected units that were originally preferred + preferred_units = list( + {u for d, u in unit_selections.items() if d in preferred_dims} + ) + preferred_units.sort(key=str) # for determinism + + # and unpreferred_units are the selected units that weren't originally preferred + unpreferred_units = list( + {u for d, u in unit_selections.items() if d not in preferred_dims} + ) + unpreferred_units.sort(key=str) # for determinism + + # for indexability + dimensions = list(dimension_set) + dimensions.sort() # for determinism + + # the powers for each elemet of dimensions (the list) for this Quantity + dimensionality = [quantity.dimensionality[dimension] for dimension in dimensions] + + # Now that the input data is minimized, setup the optimization problem + + # use mip to select units from preferred units + + model = mip_Model() + model.verbose = 0 + + # Make one variable for each candidate unit + + vars = [ + model.add_var(str(unit), lb=-mip_INF, ub=mip_INF, var_type=mip_INTEGER) + for unit in (preferred_units + unpreferred_units) + ] + + # where [u1 ... uN] are powers of N candidate units (vars) + # and [d1(uI) ... dK(uI)] are the K dimensional exponents of candidate unit I + # and [t1 ... tK] are the dimensional exponents of the quantity (quantity) + # create the following constraints + # + # ⎡ d1(u1) ⋯ dK(u1) ⎤ + # [ u1 ⋯ uN ] * ⎢ ⋮ ⋱ ⎢ = [ t1 ⋯ tK ] + # ⎣ d1(uN) dK(uN) ⎦ + # + # in English, the units we choose, and their exponents, when combined, must have the + # target dimensionality + + matrix = [ + [preferred_unit.dimensionality[dimension] for dimension in dimensions] + for preferred_unit in (preferred_units + unpreferred_units) + ] + + # Do the matrix multiplication with mip_model.xsum for performance and create constraints + for i in range(len(dimensions)): + dot = mip_model.xsum([var * vector[i] for var, vector in zip(vars, matrix)]) + # add constraint to the model + model += dot == dimensionality[i] + + # where [c1 ... cN] are costs, 1 when a preferred variable, and a large value when not + # minimize sum(abs(u1) * c1 ... abs(uN) * cN) + + # linearize the optimization variable via a proxy + objective = model.add_var("objective", lb=0, ub=mip_INF, var_type=mip_INTEGER) + + # Constrain the objective to be equal to the sums of the absolute values of the preferred + # unit powers. Do this by making a separate constraint for each permutation of signedness. + # Also apply the cost coefficient, which causes the output to prefer the preferred units + + # prefer units that interact with fewer dimensions + cost = [len(p.dimensionality) for p in preferred_units] + + # set the cost for non preferred units to a higher number + bias = ( + max(map(abs, dimensionality)) * max((1, *cost)) * 10 + ) # arbitrary, just needs to be larger + cost.extend([bias] * len(unpreferred_units)) + + for i in range(1 << len(vars)): + sum = mip_xsum( + [ + (-1 if i & 1 << (len(vars) - j - 1) else 1) * cost[j] * var + for j, var in enumerate(vars) + ] + ) + model += objective >= sum + + model.objective = objective + + # run the mips minimizer and extract the result if successful + if model.optimize() == mip_OptimizationStatus.OPTIMAL: + optimal_units = [] + min_objective = float("inf") + for i in range(model.num_solutions): + if model.objective_values[i] < min_objective: + min_objective = model.objective_values[i] + optimal_units.clear() + elif model.objective_values[i] > min_objective: + continue + + temp_unit = quantity._REGISTRY.Unit("") + for var in vars: + if var.xi(i): + temp_unit *= quantity._REGISTRY.Unit(var.name) ** var.xi(i) + optimal_units.append(temp_unit) + + sorting_keys = {tuple(sorted(unit._units)): unit for unit in optimal_units} + min_key = sorted(sorting_keys)[0] + result_unit = sorting_keys[min_key] + + return quantity.to(result_unit) + + # for whatever reason, a solution wasn't found + # return the original quantity + return quantity diff --git a/pint/facets/plain/quantity.py b/pint/facets/plain/quantity.py index 1eaaa3d..0058549 100644 --- a/pint/facets/plain/quantity.py +++ b/pint/facets/plain/quantity.py @@ -8,37 +8,22 @@ from __future__ import annotations -import bisect + import copy import datetime import locale -import math import numbers import operator -import warnings -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Generic, - TypeVar, - overload, -) -from collections.abc import Iterable, Iterator, Sequence +from typing import TYPE_CHECKING, Any, Callable, overload, Generic, TypeVar +from collections.abc import Iterator, Sequence -from ..._typing import S, UnitLike, _MagnitudeType +from ..._typing import UnitLike, QuantityOrUnitLike, Magnitude from ...compat import ( HAS_NUMPY, _to_magnitude, eq, is_duck_array_type, is_upcast_type, - mip_INF, - mip_INTEGER, - mip_model, - mip_Model, - mip_OptimizationStatus, - mip_xsum, np, zero_or_nan, ) @@ -47,11 +32,11 @@ from ...util import ( PrettyIPython, SharedRegistryObject, UnitsContainer, - infer_base_unit, logger, to_units_container, ) from .definitions import UnitDefinition +from . import qto if TYPE_CHECKING: from ..context import Context @@ -61,6 +46,10 @@ if TYPE_CHECKING: if HAS_NUMPY: import numpy as np # noqa +MagnitudeT = TypeVar("MagnitudeT", bound=Magnitude) + +T = TypeVar("T", bound=Magnitude) + def reduce_dimensions(f): def wrapped(self, *args, **kwargs): @@ -115,14 +104,10 @@ def method_wraps(numpy_func): return wrapper -# Workaround to bypass dynamically generated PlainQuantity with overload method -Magnitude = TypeVar("Magnitude") - - # TODO: remove all nonmultiplicative remnants -class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType]): +class PlainQuantity(Generic[MagnitudeT], PrettyIPython, SharedRegistryObject): """Implements a class to describe a physical quantity: the product of a numerical value and a unit of measurement. @@ -140,7 +125,7 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType] #: Default formatting string. default_format: str = "" - _magnitude: _MagnitudeType + _magnitude: MagnitudeT @property def ndim(self) -> int: @@ -156,11 +141,7 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType] def force_ndarray_like(self) -> bool: return self._REGISTRY.force_ndarray_like - @property - def UnitsContainer(self) -> Callable[..., UnitsContainerT]: - return self._REGISTRY.UnitsContainer - - def __reduce__(self) -> tuple: + def __reduce__(self) -> tuple[type, Magnitude, UnitsContainer]: """Allow pickling quantities. Since UnitRegistries are not pickled, upon unpickling the new object is always attached to the application registry. """ @@ -168,12 +149,17 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType] # Note: type(self) would be a mistake as subclasses built by # dinamically can't be pickled + # TODO: Check if this is still the case. return _unpickle_quantity, (PlainQuantity, self.magnitude, self._units) + # @overload + # def __new__( + # cls, value: T, units: UnitLike | None = None + # ) -> PlainQuantity[T]: + # ... + @overload - def __new__( - cls, value: str, units: UnitLike | None = None - ) -> PlainQuantity[Magnitude]: + def __new__(cls, value: str, units: UnitLike | None = None) -> PlainQuantity[int]: ... @overload @@ -182,17 +168,11 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType] ) -> PlainQuantity[np.ndarray]: ... - @overload - def __new__( - cls, value: PlainQuantity[Magnitude], units: UnitLike | None = None - ) -> PlainQuantity[Magnitude]: - ... - - @overload - def __new__( - cls, value: Magnitude, units: UnitLike | None = None - ) -> PlainQuantity[Magnitude]: - ... + # @overload + # def __new__( + # cls, value: PlainQuantity[Any], units: UnitLike | None = None + # ) -> PlainQuantity[Any]: + # ... def __new__(cls, value, units=None): if is_upcast_type(type(value)): @@ -243,7 +223,7 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType] return inst - def __iter__(self: PlainQuantity[Iterable[S]]) -> Iterator[S]: + def __iter__(self: PlainQuantity[MagnitudeT]) -> Iterator[Any]: # Make sure that, if self.magnitude is not iterable, we raise TypeError as soon # as one calls iter(self) without waiting for the first element to be drawn from # the iterator @@ -255,11 +235,11 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType] return it_outer() - def __copy__(self) -> PlainQuantity[_MagnitudeType]: + def __copy__(self) -> PlainQuantity[MagnitudeT]: ret = self.__class__(copy.copy(self._magnitude), self._units) return ret - def __deepcopy__(self, memo) -> PlainQuantity[_MagnitudeType]: + def __deepcopy__(self, memo) -> PlainQuantity[MagnitudeT]: ret = self.__class__( copy.deepcopy(self._magnitude, memo), copy.deepcopy(self._units, memo) ) @@ -285,16 +265,16 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType] return hash((self_base.__class__, self_base.magnitude, self_base.units)) @property - def magnitude(self) -> _MagnitudeType: + def magnitude(self) -> MagnitudeT: """PlainQuantity's magnitude. Long form for `m`""" return self._magnitude @property - def m(self) -> _MagnitudeType: + def m(self) -> MagnitudeT: """PlainQuantity's magnitude. Short form for `magnitude`""" return self._magnitude - def m_as(self, units) -> _MagnitudeType: + def m_as(self, units) -> MagnitudeT: """PlainQuantity's magnitude expressed in particular units. Parameters @@ -351,8 +331,8 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType] @classmethod def from_list( - cls, quant_list: list[PlainQuantity], units=None - ) -> PlainQuantity[np.ndarray]: + cls, quant_list: list[PlainQuantity[MagnitudeT]], units=None + ) -> PlainQuantity[MagnitudeT]: """Transforms a list of Quantities into an numpy.array quantity. If no units are specified, the unit of the first element will be used. Same as from_sequence. @@ -375,8 +355,8 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType] @classmethod def from_sequence( - cls, seq: Sequence[PlainQuantity], units=None - ) -> PlainQuantity[np.ndarray]: + cls, seq: Sequence[PlainQuantity[MagnitudeT]], units=None + ) -> PlainQuantity[MagnitudeT]: """Transforms a sequence of Quantities into an numpy.array quantity. If no units are specified, the unit of the first element will be used. @@ -414,7 +394,7 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType] def from_tuple(cls, tup): return cls(tup[0], cls._REGISTRY.UnitsContainer(tup[1])) - def to_tuple(self) -> tuple[_MagnitudeType, tuple[tuple[str]]]: + def to_tuple(self) -> tuple[MagnitudeT, tuple[tuple[str]]]: return self.m, tuple(self._units.items()) def compatible_units(self, *contexts): @@ -452,7 +432,7 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType] except DimensionalityError: return False - if isinstance(other, (PlainQuantity, PlainUnit)): + if isinstance(other, (PlainQuantity[MagnitudeT], PlainUnit)): return self.dimensionality == other.dimensionality if isinstance(other, str): @@ -481,7 +461,9 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType] inplace=is_duck_array_type(type(self._magnitude)), ) - def ito(self, other=None, *contexts, **ctx_kwargs) -> None: + def ito( + self, other: QuantityOrUnitLike | None = None, *contexts, **ctx_kwargs + ) -> None: """Inplace rescale to different units. Parameters @@ -500,7 +482,9 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType] return None - def to(self, other=None, *contexts, **ctx_kwargs) -> PlainQuantity[_MagnitudeType]: + def to( + self, other: QuantityOrUnitLike | None = None, *contexts, **ctx_kwargs + ) -> PlainQuantity: """Return PlainQuantity rescaled to different units. Parameters @@ -532,7 +516,7 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType] return None - def to_root_units(self) -> PlainQuantity[_MagnitudeType]: + def to_root_units(self) -> PlainQuantity[MagnitudeT]: """Return PlainQuantity rescaled to root units.""" _, other = self._REGISTRY._get_root_units(self._units) @@ -551,7 +535,7 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType] return None - def to_base_units(self) -> PlainQuantity[_MagnitudeType]: + def to_base_units(self) -> PlainQuantity[MagnitudeT]: """Return PlainQuantity rescaled to plain units.""" _, other = self._REGISTRY._get_base_units(self._units) @@ -560,361 +544,13 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType] return self.__class__(magnitude, other) - def _get_reduced_units(self, units): - # loop through individual units and compare to each other unit - # can we do better than a nested loop here? - for unit1, exp in units.items(): - # make sure it wasn't already reduced to zero exponent on prior pass - if unit1 not in units: - continue - for unit2 in units: - # get exponent after reduction - exp = units[unit1] - if unit1 != unit2: - power = self._REGISTRY._get_dimensionality_ratio(unit1, unit2) - if power: - units = units.add(unit2, exp / power).remove([unit1]) - break - return units - - def ito_reduced_units(self) -> None: - """Return PlainQuantity scaled in place to reduced units, i.e. one unit per - dimension. This will not reduce compound units (e.g., 'J/kg' will not - be reduced to m**2/s**2), nor can it make use of contexts at this time. - """ - - # shortcuts in case we're dimensionless or only a single unit - if self.dimensionless: - return self.ito({}) - if len(self._units) == 1: - return None - - units = self._units.copy() - new_units = self._get_reduced_units(units) - - return self.ito(new_units) - - def to_reduced_units(self) -> PlainQuantity[_MagnitudeType]: - """Return PlainQuantity scaled in place to reduced units, i.e. one unit per - dimension. This will not reduce compound units (intentionally), nor - can it make use of contexts at this time. - """ - - # shortcuts in case we're dimensionless or only a single unit - if self.dimensionless: - return self.to({}) - if len(self._units) == 1: - return self - - units = self._units.copy() - new_units = self._get_reduced_units(units) - - return self.to(new_units) - - def to_compact(self, unit=None) -> PlainQuantity[_MagnitudeType]: - """ "Return PlainQuantity rescaled to compact, human-readable units. - - To get output in terms of a different unit, use the unit parameter. - - - Examples - -------- - - >>> import pint - >>> ureg = pint.UnitRegistry() - >>> (200e-9*ureg.s).to_compact() - <Quantity(200.0, 'nanosecond')> - >>> (1e-2*ureg('kg m/s^2')).to_compact('N') - <Quantity(10.0, 'millinewton')> - """ - - if not isinstance(self.magnitude, numbers.Number): - msg = ( - "to_compact applied to non numerical types " - "has an undefined behavior." - ) - w = RuntimeWarning(msg) - warnings.warn(w, stacklevel=2) - return self - - if ( - self.unitless - or self.magnitude == 0 - or math.isnan(self.magnitude) - or math.isinf(self.magnitude) - ): - return self - - SI_prefixes: dict[int, str] = {} - for prefix in self._REGISTRY._prefixes.values(): - try: - scale = prefix.converter.scale - # Kludgy way to check if this is an SI prefix - log10_scale = int(math.log10(scale)) - if log10_scale == math.log10(scale): - SI_prefixes[log10_scale] = prefix.name - except Exception: - SI_prefixes[0] = "" - - SI_prefixes_list = sorted(SI_prefixes.items()) - SI_powers = [item[0] for item in SI_prefixes_list] - SI_bases = [item[1] for item in SI_prefixes_list] - - if unit is None: - unit = infer_base_unit(self, registry=self._REGISTRY) - else: - unit = infer_base_unit(self.__class__(1, unit), registry=self._REGISTRY) - - q_base = self.to(unit) - - magnitude = q_base.magnitude - - units = list(q_base._units.items()) - units_numerator = [a for a in units if a[1] > 0] - - if len(units_numerator) > 0: - unit_str, unit_power = units_numerator[0] - else: - unit_str, unit_power = units[0] - - if unit_power > 0: - power = math.floor(math.log10(abs(magnitude)) / float(unit_power) / 3) * 3 - else: - power = math.ceil(math.log10(abs(magnitude)) / float(unit_power) / 3) * 3 - - index = bisect.bisect_left(SI_powers, power) - - if index >= len(SI_bases): - index = -1 - - prefix_str = SI_bases[index] - - new_unit_str = prefix_str + unit_str - new_unit_container = q_base._units.rename(unit_str, new_unit_str) - - return self.to(new_unit_container) - - def to_preferred( - self, preferred_units: list[UnitLike] - ) -> PlainQuantity[_MagnitudeType]: - """Return Quantity converted to a unit composed of the preferred units. - - Examples - -------- - - >>> import pint - >>> ureg = pint.UnitRegistry() - >>> (1*ureg.acre).to_preferred([ureg.meters]) - <Quantity(4046.87261, 'meter ** 2')> - >>> (1*(ureg.force_pound*ureg.m)).to_preferred([ureg.W]) - <Quantity(4.44822162, 'second * watt')> - """ - - if not self.dimensionality: - return self - - # The optimizer isn't perfect, and will sometimes miss obvious solutions. - # This sub-algorithm is less powerful, but always finds the very simple solutions. - def find_simple(): - best_ratio = None - best_unit = None - self_dims = sorted(self.dimensionality) - self_exps = [self.dimensionality[d] for d in self_dims] - s_exps_head, *s_exps_tail = self_exps - n = len(s_exps_tail) - for preferred_unit in preferred_units: - dims = sorted(preferred_unit.dimensionality) - if dims == self_dims: - p_exps_head, *p_exps_tail = ( - preferred_unit.dimensionality[d] for d in dims - ) - if all( - s_exps_tail[i] * p_exps_head == p_exps_tail[i] ** s_exps_head - for i in range(n) - ): - ratio = p_exps_head / s_exps_head - ratio = max(ratio, 1 / ratio) - if best_ratio is None or ratio < best_ratio: - best_ratio = ratio - best_unit = preferred_unit ** (s_exps_head / p_exps_head) - return best_unit - - simple = find_simple() - if simple is not None: - return self.to(simple) - - # For each dimension (e.g. T(ime), L(ength), M(ass)), assign a default base unit from - # the collection of base units - - unit_selections = { - base_unit.dimensionality: base_unit - for base_unit in map(self._REGISTRY.Unit, self._REGISTRY._base_units) - } - - # Override the default unit of each dimension with the 1D-units used in this Quantity - unit_selections.update( - { - unit.dimensionality: unit - for unit in map(self._REGISTRY.Unit, self._units.keys()) - } - ) - - # Determine the preferred unit for each dimensionality from the preferred_units - # (A prefered unit doesn't have to be only one dimensional, e.g. Watts) - preferred_dims = { - preferred_unit.dimensionality: preferred_unit - for preferred_unit in map(self._REGISTRY.Unit, preferred_units) - } - - # Combine the defaults and preferred, favoring the preferred - unit_selections.update(preferred_dims) - - # This algorithm has poor asymptotic time complexity, so first reduce the considered - # dimensions and units to only those that are useful to the problem - - # The dimensions (without powers) of this Quantity - dimension_set = set(self.dimensionality) - - # Getting zero exponents in dimensions not in dimension_set can be facilitated - # by units that interact with that dimension and one or more dimension_set members. - # For example MT^1 * LT^-1 lets you get MLT^0 when T is not in dimension_set. - # For each candidate unit that interacts with a dimension_set member, add the - # candidate unit's other dimensions to dimension_set, and repeat until no more - # dimensions are selected. - - discovery_done = False - while not discovery_done: - discovery_done = True - for d in unit_selections: - unit_dimensions = set(d) - intersection = unit_dimensions.intersection(dimension_set) - if 0 < len(intersection) < len(unit_dimensions): - # there are dimensions in this unit that are in dimension set - # and others that are not in dimension set - dimension_set = dimension_set.union(unit_dimensions) - discovery_done = False - break - - # filter out dimensions and their unit selections that don't interact with any - # dimension_set members - unit_selections = { - dimensionality: unit - for dimensionality, unit in unit_selections.items() - if set(dimensionality).intersection(dimension_set) - } - - # update preferred_units with the selected units that were originally preferred - preferred_units = list( - {u for d, u in unit_selections.items() if d in preferred_dims} - ) - preferred_units.sort(key=str) # for determinism - - # and unpreferred_units are the selected units that weren't originally preferred - unpreferred_units = list( - {u for d, u in unit_selections.items() if d not in preferred_dims} - ) - unpreferred_units.sort(key=str) # for determinism - - # for indexability - dimensions = list(dimension_set) - dimensions.sort() # for determinism - - # the powers for each elemet of dimensions (the list) for this Quantity - dimensionality = [self.dimensionality[dimension] for dimension in dimensions] - - # Now that the input data is minimized, setup the optimization problem - - # use mip to select units from preferred units - - model = mip_Model() - model.verbose = 0 - - # Make one variable for each candidate unit - - vars = [ - model.add_var(str(unit), lb=-mip_INF, ub=mip_INF, var_type=mip_INTEGER) - for unit in (preferred_units + unpreferred_units) - ] - - # where [u1 ... uN] are powers of N candidate units (vars) - # and [d1(uI) ... dK(uI)] are the K dimensional exponents of candidate unit I - # and [t1 ... tK] are the dimensional exponents of the quantity (self) - # create the following constraints - # - # ⎡ d1(u1) ⋯ dK(u1) ⎤ - # [ u1 ⋯ uN ] * ⎢ ⋮ ⋱ ⎢ = [ t1 ⋯ tK ] - # ⎣ d1(uN) dK(uN) ⎦ - # - # in English, the units we choose, and their exponents, when combined, must have the - # target dimensionality - - matrix = [ - [preferred_unit.dimensionality[dimension] for dimension in dimensions] - for preferred_unit in (preferred_units + unpreferred_units) - ] - - # Do the matrix multiplication with mip_model.xsum for performance and create constraints - for i in range(len(dimensions)): - dot = mip_model.xsum([var * vector[i] for var, vector in zip(vars, matrix)]) - # add constraint to the model - model += dot == dimensionality[i] - - # where [c1 ... cN] are costs, 1 when a preferred variable, and a large value when not - # minimize sum(abs(u1) * c1 ... abs(uN) * cN) - - # linearize the optimization variable via a proxy - objective = model.add_var("objective", lb=0, ub=mip_INF, var_type=mip_INTEGER) - - # Constrain the objective to be equal to the sums of the absolute values of the preferred - # unit powers. Do this by making a separate constraint for each permutation of signedness. - # Also apply the cost coefficient, which causes the output to prefer the preferred units - - # prefer units that interact with fewer dimensions - cost = [len(p.dimensionality) for p in preferred_units] - - # set the cost for non preferred units to a higher number - bias = ( - max(map(abs, dimensionality)) * max((1, *cost)) * 10 - ) # arbitrary, just needs to be larger - cost.extend([bias] * len(unpreferred_units)) - - for i in range(1 << len(vars)): - sum = mip_xsum( - [ - (-1 if i & 1 << (len(vars) - j - 1) else 1) * cost[j] * var - for j, var in enumerate(vars) - ] - ) - model += objective >= sum - - model.objective = objective - - # run the mips minimizer and extract the result if successful - if model.optimize() == mip_OptimizationStatus.OPTIMAL: - optimal_units = [] - min_objective = float("inf") - for i in range(model.num_solutions): - if model.objective_values[i] < min_objective: - min_objective = model.objective_values[i] - optimal_units.clear() - elif model.objective_values[i] > min_objective: - continue - - temp_unit = self._REGISTRY.Unit("") - for var in vars: - if var.xi(i): - temp_unit *= self._REGISTRY.Unit(var.name) ** var.xi(i) - optimal_units.append(temp_unit) - - sorting_keys = {tuple(sorted(unit._units)): unit for unit in optimal_units} - min_key = sorted(sorting_keys)[0] - result_unit = sorting_keys[min_key] - - return self.to(result_unit) - - # for whatever reason, a solution wasn't found - # return the original quantity - return self + # Functions not essential to a Quantity but it is + # convenient that they live in PlainQuantity. + # They are implemented elsewhere to keep Quantity class clean. + to_compact = qto.to_compact + to_preferred = qto.to_preferred + to_reduced_units = qto.to_reduced_units + ito_reduced_units = qto.ito_reduced_units # Mathematical operations def __int__(self) -> int: @@ -1163,7 +799,7 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType] ... @overload - def __iadd__(self, other) -> PlainQuantity[_MagnitudeType]: + def __iadd__(self, other) -> PlainQuantity[MagnitudeT]: ... def __iadd__(self, other): @@ -1539,7 +1175,7 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType] return self @check_implemented - def __pow__(self, other) -> PlainQuantity[_MagnitudeType]: + def __pow__(self, other) -> PlainQuantity[MagnitudeT]: try: _to_magnitude(other, self.force_ndarray, self.force_ndarray_like) except PintTypeError: @@ -1604,7 +1240,7 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType] return self.__class__(magnitude, units) @check_implemented - def __rpow__(self, other) -> PlainQuantity[_MagnitudeType]: + def __rpow__(self, other) -> PlainQuantity[MagnitudeT]: try: _to_magnitude(other, self.force_ndarray, self.force_ndarray_like) except PintTypeError: @@ -1617,16 +1253,16 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType] new_self = self.to_root_units() return other**new_self._magnitude - def __abs__(self) -> PlainQuantity[_MagnitudeType]: + def __abs__(self) -> PlainQuantity[MagnitudeT]: return self.__class__(abs(self._magnitude), self._units) - def __round__(self, ndigits: int | None = 0) -> PlainQuantity[int]: + def __round__(self, ndigits: int | None = 0) -> PlainQuantity[MagnitudeT]: return self.__class__(round(self._magnitude, ndigits=ndigits), self._units) - def __pos__(self) -> PlainQuantity[_MagnitudeType]: + def __pos__(self) -> PlainQuantity[MagnitudeT]: return self.__class__(operator.pos(self._magnitude), self._units) - def __neg__(self) -> PlainQuantity[_MagnitudeType]: + def __neg__(self) -> PlainQuantity[MagnitudeT]: return self.__class__(operator.neg(self._magnitude), self._units) @check_implemented @@ -1797,5 +1433,14 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType] def _ok_for_muldiv(self, no_offset_units=None) -> bool: return True - def to_timedelta(self: PlainQuantity[float]) -> datetime.timedelta: + def to_timedelta(self: PlainQuantity[MagnitudeT]) -> datetime.timedelta: return datetime.timedelta(microseconds=self.to("microseconds").magnitude) + + # We put this last to avoid overriding UnitsContainer + # and I do not want to rename it. + # TODO: Maybe in the future we need to change it to a more meaningful + # non-colliding name. + + @property + def UnitsContainer(self) -> Callable[..., UnitsContainerT]: + return self._REGISTRY.UnitsContainer diff --git a/pint/facets/plain/registry.py b/pint/facets/plain/registry.py index d3baff4..ed46608 100644 --- a/pint/facets/plain/registry.py +++ b/pint/facets/plain/registry.py @@ -20,27 +20,38 @@ from decimal import Decimal from fractions import Fraction from numbers import Number from token import NAME, NUMBER +from tokenize import TokenInfo + from typing import ( TYPE_CHECKING, Any, Callable, TypeVar, Union, + Generic, ) from collections.abc import Iterable, Iterator if TYPE_CHECKING: from ..context import Context - from ..._typing import Quantity, Unit + from ...compat import Locale + + # from ..._typing import Quantity, Unit + +from ..._typing import ( + QuantityOrUnitLike, + UnitLike, + QuantityArgument, + Scalar, + Handler, +) -from ..._typing import QuantityOrUnitLike, UnitLike from ..._vendor import appdirs -from ...compat import HAS_BABEL, babel_parse, tokenizer +from ...compat import babel_parse, tokenizer, TypeAlias, Self from ...errors import DimensionalityError, RedefinitionError, UndefinedUnitError from ...pint_eval import build_eval_tree from ...util import ParserHelper -from ...util import UnitsContainer -from ...util import UnitsContainer as UnitsContainerT +from ...util import UnitsContainer as UnitsContainer from ...util import ( _is_dim, create_class_with_registry, @@ -58,25 +69,20 @@ from .definitions import ( DimensionDefinition, PrefixDefinition, UnitDefinition, + NamedDefinition, ) from .objects import PlainQuantity, PlainUnit -if TYPE_CHECKING: - if HAS_BABEL: - import babel - - Locale = babel.Locale - else: - Locale = None - T = TypeVar("T") _BLOCK_RE = re.compile(r"[ (]") @functools.lru_cache -def pattern_to_regex(pattern): - if hasattr(pattern, "finditer"): +def pattern_to_regex(pattern: str | re.Pattern[str]) -> re.Pattern[str]: + # TODO: This has been changed during typing improvements. + # if hasattr(pattern, "finditer"): + if not isinstance(pattern, str): pattern = pattern.pattern # Replace "{unit_name}" match string with float regex with unit_name as group @@ -96,15 +102,19 @@ class RegistryCache: def __init__(self) -> None: #: Maps dimensionality (UnitsContainer) to Units (str) - self.dimensional_equivalents: dict[UnitsContainer, set[str]] = {} + self.dimensional_equivalents: dict[UnitsContainer, frozenset[str]] = {} + #: Maps dimensionality (UnitsContainer) to Dimensionality (UnitsContainer) - self.root_units = {} + # TODO: this description is not right. + self.root_units: dict[UnitsContainer, tuple[Scalar, UnitsContainer]] = {} + #: Maps dimensionality (UnitsContainer) to Units (UnitsContainer) self.dimensionality: dict[UnitsContainer, UnitsContainer] = {} + #: Cache the unit name associated to user input. ('mV' -> 'millivolt') self.parse_unit: dict[str, UnitsContainer] = {} - def __eq__(self, other): + def __eq__(self, other: Any): if not isinstance(other, self.__class__): return False attrs = ( @@ -127,7 +137,12 @@ class RegistryMeta(type): return obj -class PlainRegistry(metaclass=RegistryMeta): +# Generic types used to mark types associated to Registries. +QuantityT = TypeVar("QuantityT", bound=PlainQuantity) +UnitT = TypeVar("UnitT", bound=PlainUnit) + + +class GenericPlainRegistry(Generic[QuantityT, UnitT], metaclass=RegistryMeta): """Base class for all registries. Capabilities: @@ -174,11 +189,10 @@ class PlainRegistry(metaclass=RegistryMeta): #: Babel.Locale instance or None fmt_locale: Locale | None = None - _diskcache = None - - Quantity = PlainQuantity - Unit = PlainUnit + Quantity: type[QuantityT] + Unit: type[UnitT] + _diskcache = None _def_parser = None def __init__( @@ -197,7 +211,7 @@ class PlainRegistry(metaclass=RegistryMeta): mpl_formatter: str = "{:P}", ): #: Map a definition class to a adder methods. - self._adders = {} + self._adders: Handler = {} self._register_definition_adders() self._init_dynamic_classes() @@ -280,8 +294,8 @@ class PlainRegistry(metaclass=RegistryMeta): def _init_dynamic_classes(self) -> None: """Generate subclasses on the fly and attach them to self""" - self.Unit: Unit = create_class_with_registry(self, self.Unit) - self.Quantity: Quantity = create_class_with_registry(self, self.Quantity) + self.Unit = create_class_with_registry(self, self.Unit) + self.Quantity = create_class_with_registry(self, self.Quantity) def _after_init(self) -> None: """This should be called after all __init__""" @@ -297,7 +311,16 @@ class PlainRegistry(metaclass=RegistryMeta): self._build_cache(loaded_files) self._initialized = True - def _register_adder(self, definition_class, adder_func): + def _register_adder( + self, + definition_class: type[T], + adder_func: Callable[ + [ + T, + ], + None, + ], + ) -> None: """Register a block definition.""" self._adders[definition_class] = adder_func @@ -310,24 +333,25 @@ class PlainRegistry(metaclass=RegistryMeta): self._register_adder(DimensionDefinition, self._add_dimension) self._register_adder(DerivedDimensionDefinition, self._add_derived_dimension) - def __deepcopy__(self, memo) -> PlainRegistry: + def __deepcopy__(self: Self, memo) -> type[Self]: new = object.__new__(type(self)) new.__dict__ = copy.deepcopy(self.__dict__, memo) new._init_dynamic_classes() return new - def __getattr__(self, item): + def __getattr__(self, item: str) -> QuantityT: getattr_maybe_raise(self, item) return self.Unit(item) - def __getitem__(self, item): + def __getitem__(self, item: str) -> UnitT: logger.warning( "Calling the getitem method from a UnitRegistry is deprecated. " "use `parse_expression` method or use the registry as a callable." ) - return self.parse_expression(item) + return self.Quantity() + # return self.parse_expression(item) - def __contains__(self, item) -> bool: + def __contains__(self, item: str) -> bool: """Support checking prefixed units with the `in` operator""" try: self.__getattr__(item) @@ -366,16 +390,13 @@ class PlainRegistry(metaclass=RegistryMeta): self.fmt_locale = loc - def UnitsContainer(self, *args, **kwargs) -> UnitsContainerT: - return UnitsContainer(*args, non_int_type=self.non_int_type, **kwargs) - @property def default_format(self) -> str: """Default formatting string for quantities.""" return self.Quantity.default_format @default_format.setter - def default_format(self, value: str): + def default_format(self, value: str) -> None: self.Unit.default_format = value self.Quantity.default_format = value self.Measurement.default_format = value @@ -390,7 +411,7 @@ class PlainRegistry(metaclass=RegistryMeta): def non_int_type(self): return self._non_int_type - def define(self, definition): + def define(self, definition: str | type) -> None: """Add unit to the registry. Parameters @@ -413,7 +434,7 @@ class PlainRegistry(metaclass=RegistryMeta): # - then we define specific adder for each definition class. :-D ############ - def _helper_dispatch_adder(self, definition): + def _helper_dispatch_adder(self, definition: Any) -> None: """Helper function to add a single definition, choosing the appropiate method by class. """ @@ -428,7 +449,12 @@ class PlainRegistry(metaclass=RegistryMeta): adder_func(definition) - def _helper_adder(self, definition, target_dict, casei_target_dict): + def _helper_adder( + self, + definition: NamedDefinition, + target_dict: dict[str, Any], + casei_target_dict: dict[str, Any] | None, + ) -> None: """Helper function to store a definition in the internal dictionaries. It stores the definition under its name, symbol and aliases. """ @@ -436,6 +462,7 @@ class PlainRegistry(metaclass=RegistryMeta): definition.name, definition, target_dict, casei_target_dict ) + # TODO: Not sure why but using hasattr does not work here. if getattr(definition, "has_symbol", ""): self._helper_single_adder( definition.symbol, definition, target_dict, casei_target_dict @@ -447,7 +474,13 @@ class PlainRegistry(metaclass=RegistryMeta): self._helper_single_adder(alias, definition, target_dict, casei_target_dict) - def _helper_single_adder(self, key, value, target_dict, casei_target_dict): + def _helper_single_adder( + self, + key: str, + value: NamedDefinition, + target_dict: dict[str, Any], + casei_target_dict: dict[str, Any] | None, + ) -> None: """Helper function to store a definition in the internal dictionaries. It warns or raise error on redefinition. @@ -462,11 +495,11 @@ class PlainRegistry(metaclass=RegistryMeta): if casei_target_dict is not None: casei_target_dict[key.lower()].add(key) - def _add_defaults(self, defaults_definition: DefaultsDefinition): + def _add_defaults(self, defaults_definition: DefaultsDefinition) -> None: for k, v in defaults_definition.items(): self._defaults[k] = v - def _add_alias(self, definition: AliasDefinition): + def _add_alias(self, definition: AliasDefinition) -> None: unit_dict = self._units unit = unit_dict[definition.name] while not isinstance(unit, UnitDefinition): @@ -474,19 +507,19 @@ class PlainRegistry(metaclass=RegistryMeta): for alias in definition.aliases: self._helper_single_adder(alias, unit, self._units, self._units_casei) - def _add_dimension(self, definition: DimensionDefinition): + def _add_dimension(self, definition: DimensionDefinition) -> None: self._helper_adder(definition, self._dimensions, None) - def _add_derived_dimension(self, definition: DerivedDimensionDefinition): + def _add_derived_dimension(self, definition: DerivedDimensionDefinition) -> None: for dim_name in definition.reference.keys(): if dim_name not in self._dimensions: self._add_dimension(DimensionDefinition(dim_name)) self._helper_adder(definition, self._dimensions, None) - def _add_prefix(self, definition: PrefixDefinition): + def _add_prefix(self, definition: PrefixDefinition) -> None: self._helper_adder(definition, self._prefixes, None) - def _add_unit(self, definition: UnitDefinition): + def _add_unit(self, definition: UnitDefinition) -> None: if definition.is_base: self._base_units.append(definition.name) for dim_name in definition.reference.keys(): @@ -495,7 +528,9 @@ class PlainRegistry(metaclass=RegistryMeta): self._helper_adder(definition, self._units, self._units_casei) - def load_definitions(self, file, is_resource: bool = False): + def load_definitions( + self, file: Iterable[str] | str | pathlib.Path, is_resource: bool = False + ): """Add units and prefixes defined in a definition text file. Parameters @@ -531,8 +566,8 @@ class PlainRegistry(metaclass=RegistryMeta): self._cache = RegistryCache() - deps = { - name: definition.reference.keys() if definition.reference else set() + deps: dict[str, set[str]] = { + name: set(definition.reference.keys()) if definition.reference else set() for name, definition in self._units.items() } @@ -579,14 +614,13 @@ class PlainRegistry(metaclass=RegistryMeta): candidates = self.parse_unit_name(name_or_alias, case_sensitive) if not candidates: raise UndefinedUnitError(name_or_alias) - elif len(candidates) == 1: - prefix, unit_name, _ = candidates[0] - else: + + prefix, unit_name, _ = candidates[0] + if len(candidates) > 1: logger.warning( "Parsing {} yield multiple results. " - "Options are: {}".format(name_or_alias, candidates) + "Options are: {!r}".format(name_or_alias, candidates) ) - prefix, unit_name, _ = candidates[0] if prefix: name = prefix + unit_name @@ -595,7 +629,7 @@ class PlainRegistry(metaclass=RegistryMeta): self._units[name] = UnitDefinition( name, symbol, - (), + tuple(), prefix_def.converter, self.UnitsContainer({unit_name: 1}), ) @@ -608,21 +642,20 @@ class PlainRegistry(metaclass=RegistryMeta): candidates = self.parse_unit_name(name_or_alias, case_sensitive) if not candidates: raise UndefinedUnitError(name_or_alias) - elif len(candidates) == 1: - prefix, unit_name, _ = candidates[0] - else: + + prefix, unit_name, _ = candidates[0] + if len(candidates) > 1: logger.warning( "Parsing {} yield multiple results. " "Options are: {!r}".format(name_or_alias, candidates) ) - prefix, unit_name, _ = candidates[0] return self._prefixes[prefix].symbol + self._units[unit_name].symbol def _get_symbol(self, name: str) -> str: return self._units[name].symbol - def get_dimensionality(self, input_units) -> UnitsContainerT: + def get_dimensionality(self, input_units: UnitLike) -> UnitsContainer: """Convert unit or dict of units or dimensions to a dict of plain dimensions dimensions """ @@ -633,9 +666,7 @@ class PlainRegistry(metaclass=RegistryMeta): return self._get_dimensionality(input_units) - def _get_dimensionality( - self, input_units: UnitsContainerT | None - ) -> UnitsContainerT: + def _get_dimensionality(self, input_units: UnitsContainer | None) -> UnitsContainer: """Convert a UnitsContainer to plain dimensions.""" if not input_units: return self.UnitsContainer() @@ -647,7 +678,7 @@ class PlainRegistry(metaclass=RegistryMeta): except KeyError: pass - accumulator = defaultdict(int) + accumulator: dict[str, int] = defaultdict(int) self._get_dimensionality_recurse(input_units, 1, accumulator) if "[]" in accumulator: @@ -659,21 +690,25 @@ class PlainRegistry(metaclass=RegistryMeta): return dims - def _get_dimensionality_recurse(self, ref, exp, accumulator): + def _get_dimensionality_recurse( + self, ref: UnitsContainer, exp: Scalar, accumulator: dict[str, int] + ) -> None: for key in ref: exp2 = exp * ref[key] if _is_dim(key): reg = self._dimensions[key] - if reg.is_base: - accumulator[key] += exp2 - elif reg.reference is not None: + if isinstance(reg, DerivedDimensionDefinition): self._get_dimensionality_recurse(reg.reference, exp2, accumulator) + else: + # DimensionDefinition. + accumulator[key] += exp2 + else: reg = self._units[self.get_name(key)] if reg.reference is not None: self._get_dimensionality_recurse(reg.reference, exp2, accumulator) - def _get_dimensionality_ratio(self, unit1, unit2): + def _get_dimensionality_ratio(self, unit1: UnitLike, unit2: UnitLike): """Get the exponential ratio between two units, i.e. solve unit2 = unit1**x for x. Parameters @@ -707,7 +742,7 @@ class PlainRegistry(metaclass=RegistryMeta): def get_root_units( self, input_units: UnitLike, check_nonmult: bool = True - ) -> tuple[Number, PlainUnit]: + ) -> tuple[Number, UnitT]: """Convert unit or dict of units to the root units. If any unit is non multiplicative and check_converter is True, @@ -734,7 +769,9 @@ class PlainRegistry(metaclass=RegistryMeta): return f, self.Unit(units) - def _get_root_units(self, input_units, check_nonmult=True): + def _get_root_units( + self, input_units: UnitsContainer, check_nonmult: bool = True + ) -> tuple[Scalar, UnitsContainer]: """Convert unit or dict of units to the root units. If any unit is non multiplicative and check_converter is True, @@ -764,12 +801,13 @@ class PlainRegistry(metaclass=RegistryMeta): except KeyError: pass - accumulators = [1, defaultdict(int)] + accumulators: dict[str | None, int] = defaultdict(int) + accumulators[None] = 1 self._get_root_units_recurse(input_units, 1, accumulators) - factor = accumulators[0] + factor = accumulators[None] units = self.UnitsContainer( - {k: v for k, v in accumulators[1].items() if v != 0} + {k: v for k, v in accumulators.items() if k is not None and v != 0} ) # Check if any of the final units is non multiplicative and return None instead. @@ -780,7 +818,9 @@ class PlainRegistry(metaclass=RegistryMeta): cache[input_units] = factor, units return factor, units - def get_base_units(self, input_units, check_nonmult=True, system=None): + def get_base_units( + self, input_units: UnitsContainer | str, check_nonmult: bool = True, system=None + ) -> tuple[Number, UnitT]: """Convert unit or dict of units to the plain units. If any unit is non multiplicative and check_converter is True, @@ -806,35 +846,44 @@ class PlainRegistry(metaclass=RegistryMeta): return self.get_root_units(input_units, check_nonmult) - def _get_root_units_recurse(self, ref, exp, accumulators): + # TODO: accumulators breaks typing list[int, dict[str, int]] + # So we have changed the behavior here + def _get_root_units_recurse( + self, ref: UnitsContainer, exp: Scalar, accumulators: dict[str | None, int] + ) -> None: + """ + + accumulators None keeps the scalar prefactor not associated with a specific unit. + + """ for key in ref: exp2 = exp * ref[key] key = self.get_name(key) reg = self._units[key] if reg.is_base: - accumulators[1][key] += exp2 + accumulators[key] += exp2 else: - accumulators[0] *= reg.converter.scale**exp2 + accumulators[None] *= reg.converter.scale**exp2 if reg.reference is not None: self._get_root_units_recurse(reg.reference, exp2, accumulators) - def get_compatible_units( - self, input_units, group_or_system=None - ) -> frozenset[Unit]: + def get_compatible_units(self, input_units: QuantityOrUnitLike) -> frozenset[UnitT]: """ """ input_units = to_units_container(input_units) - equiv = self._get_compatible_units(input_units, group_or_system) + equiv = self._get_compatible_units(input_units) return frozenset(self.Unit(eq) for eq in equiv) - def _get_compatible_units(self, input_units, group_or_system): + def _get_compatible_units( + self, input_units: UnitsContainer, *args, **kwargs + ) -> frozenset[str]: """ """ if not input_units: return frozenset() src_dim = self._get_dimensionality(input_units) - return self._cache.dimensional_equivalents.setdefault(src_dim, set()) + return self._cache.dimensional_equivalents.setdefault(src_dim, frozenset()) # TODO: remove context from here def is_compatible_with( @@ -901,7 +950,14 @@ class PlainRegistry(metaclass=RegistryMeta): return self._convert(value, src, dst, inplace) - def _convert(self, value, src, dst, inplace=False, check_dimensionality=True): + def _convert( + self, + value: T, + src: UnitsContainer, + dst: UnitsContainer, + inplace: bool = False, + check_dimensionality: bool = True, + ) -> T: """Convert value from some source to destination units. Parameters @@ -931,7 +987,7 @@ class PlainRegistry(metaclass=RegistryMeta): # If the source and destination dimensionality are different, # then the conversion cannot be performed. if src_dim != dst_dim: - raise DimensionalityError(src, dst, src_dim, dst_dim) + raise DimensionalityError(src, dst, str(src_dim), str(dst_dim)) # Here src and dst have only multiplicative units left. Thus we can # convert with a factor. @@ -953,7 +1009,7 @@ class PlainRegistry(metaclass=RegistryMeta): def parse_unit_name( self, unit_name: str, case_sensitive: bool | None = None - ) -> tuple[tuple[str, str, str], ...]: + ) -> tuple[tuple[str, str, str]]: """Parse a unit to identify prefix, unit name and suffix by walking the list of prefix and suffix. In case of equivalent combinations (e.g. ('kilo', 'gram', '') and @@ -1033,7 +1089,7 @@ class PlainRegistry(metaclass=RegistryMeta): input_string: str, as_delta: bool | None = None, case_sensitive: bool | None = None, - ) -> Unit: + ) -> UnitT: """Parse a units expression and returns a UnitContainer with the canonical names. @@ -1054,6 +1110,8 @@ class PlainRegistry(metaclass=RegistryMeta): pint.Unit """ + + # TODO: deal or remove with as_delta = None for p in self.preprocessors: input_string = p(input_string) units = self._parse_units(input_string, as_delta, case_sensitive) @@ -1064,7 +1122,7 @@ class PlainRegistry(metaclass=RegistryMeta): input_string: str, as_delta: bool = True, case_sensitive: bool | None = None, - ) -> UnitsContainerT: + ) -> UnitsContainer: """Parse a units expression and returns a UnitContainer with the canonical names. """ @@ -1104,12 +1162,37 @@ class PlainRegistry(metaclass=RegistryMeta): return ret - def _eval_token(self, token, case_sensitive=None, **values): + def _eval_token( + self, + token: TokenInfo, + case_sensitive: bool | None = None, + **values: QuantityArgument, + ): + """Evaluate a single token using the following rules: + + 1. numerical values as strings are replaced by their numeric counterparts + - integers are parsed as integers + - other numeric values are parses of non_int_type + 2. strings in (inf, infinity, nan, dimensionless) with their numerical value. + 3. strings in values.keys() are replaced by Quantity(values[key]) + 4. in other cases, the values are parsed as units and replaced by their canonical name. + + Parameters + ---------- + token + Token to evaluate. + case_sensitive, optional + If true, a case sensitive matching of the unit name will be done in the registry. + If false, a case INsensitive matching of the unit name will be done in the registry. + (Default value = None, which uses registry setting) + **values + Other string that will be parsed using the Quantity constructor on their corresponding value. + """ token_type = token[0] token_text = token[1] if token_type == NAME: if token_text == "dimensionless": - return self.Quantity(1, self.dimensionless) + return self.Quantity(1) elif token_text.lower() in ("inf", "infinity"): return self.non_int_type("inf") elif token_text.lower() == "nan": @@ -1139,28 +1222,25 @@ class PlainRegistry(metaclass=RegistryMeta): Parameters ---------- - input_string : + input_string pattern_string: - The regex parse string - case_sensitive : - (Default value = None, which uses registry setting) - many : + The regex parse string + case_sensitive, optional + If true, a case sensitive matching of the unit name will be done in the registry. + If false, a case INsensitive matching of the unit name will be done in the registry. + (Default value = None, which uses registry setting) + many, optional Match many results (Default value = False) - - - Returns - ------- - """ if not input_string: return [] if many else None # Parse string - pattern = pattern_to_regex(pattern) - matched = re.finditer(pattern, input_string) + regex = pattern_to_regex(pattern) + matched = re.finditer(regex, input_string) # Extract result(s) results = [] @@ -1184,11 +1264,11 @@ class PlainRegistry(metaclass=RegistryMeta): return results def parse_expression( - self, + self: Self, input_string: str, case_sensitive: bool | None = None, - **values, - ) -> Quantity: + **values: QuantityArgument, + ) -> QuantityT: """Parse a mathematical expression including units and return a quantity object. Numerical constants can be specified as keyword arguments and will take precedence @@ -1196,16 +1276,14 @@ class PlainRegistry(metaclass=RegistryMeta): Parameters ---------- - input_string : - - case_sensitive : - (Default value = None, which uses registry setting) - **values : - - - Returns - ------- - + input_string + + case_sensitive, optional + If true, a case sensitive matching of the unit name will be done in the registry. + If false, a case INsensitive matching of the unit name will be done in the registry. + (Default value = None, which uses registry setting) + **values + Other string that will be parsed using the Quantity constructor on their corresponding value. """ if not input_string: return self.Quantity(1) @@ -1215,8 +1293,21 @@ class PlainRegistry(metaclass=RegistryMeta): input_string = string_preprocessor(input_string) gen = tokenizer(input_string) - return build_eval_tree(gen).evaluate( - lambda x: self._eval_token(x, case_sensitive=case_sensitive, **values) - ) + def _define_op(s: str): + return self._eval_token(s, case_sensitive=case_sensitive, **values) + + return build_eval_tree(gen).evaluate(_define_op) + + # We put this last to avoid overriding UnitsContainer + # and I do not want to rename it. + # TODO: Maybe in the future we need to change it to a more meaningful + # non-colliding name. + def UnitsContainer(self, *args: Any, **kwargs: Any) -> UnitsContainer: + return UnitsContainer(*args, non_int_type=self.non_int_type, **kwargs) __call__ = parse_expression + + +class PlainRegistry(GenericPlainRegistry[PlainQuantity[Any], PlainUnit]): + Quantity: TypeAlias = PlainQuantity[Any] + Unit: TypeAlias = PlainUnit diff --git a/pint/facets/system/__init__.py b/pint/facets/system/__init__.py index e95098b..24e68b7 100644 --- a/pint/facets/system/__init__.py +++ b/pint/facets/system/__init__.py @@ -12,6 +12,6 @@ from __future__ import annotations from .definitions import SystemDefinition from .objects import System -from .registry import SystemRegistry +from .registry import SystemRegistry, GenericSystemRegistry -__all__ = ["SystemDefinition", "System", "SystemRegistry"] +__all__ = ["SystemDefinition", "System", "SystemRegistry", "GenericSystemRegistry"] diff --git a/pint/facets/system/definitions.py b/pint/facets/system/definitions.py index 1ce8269..eb582f3 100644 --- a/pint/facets/system/definitions.py +++ b/pint/facets/system/definitions.py @@ -11,7 +11,7 @@ from __future__ import annotations from collections.abc import Iterable from dataclasses import dataclass -from ..._typing import Self +from ...compat import Self from ... import errors diff --git a/pint/facets/system/objects.py b/pint/facets/system/objects.py index 69b1c84..cf6a24f 100644 --- a/pint/facets/system/objects.py +++ b/pint/facets/system/objects.py @@ -14,7 +14,9 @@ import numbers from typing import Any from collections.abc import Iterable -from ..._typing import Self + +from typing import Callable, Generic +from numbers import Number from ...babel_names import _babel_systems from ...compat import babel_parse @@ -25,6 +27,20 @@ from ...util import ( to_units_container, ) from .definitions import SystemDefinition +from .. import group +from ..plain import MagnitudeT + +from ..._typing import UnitLike + +GetRootUnits = Callable[[UnitLike, bool], tuple[Number, UnitLike]] + + +class SystemQuantity(Generic[MagnitudeT], group.GroupQuantity[MagnitudeT]): + pass + + +class SystemUnit(group.GroupUnit): + pass class System(SharedRegistryObject): @@ -76,11 +92,11 @@ class System(SharedRegistryObject): def members(self): d = self._REGISTRY._groups if self._computed_members is None: - self._computed_members = set() + tmp: set[str] = set() for group_name in self._used_groups: try: - self._computed_members |= d[group_name].members + tmp |= d[group_name].members except KeyError: logger.warning( "Could not resolve {} in System {}".format( @@ -88,7 +104,7 @@ class System(SharedRegistryObject): ) ) - self._computed_members = frozenset(self._computed_members) + self._computed_members = frozenset(tmp) return self._computed_members @@ -116,17 +132,30 @@ class System(SharedRegistryObject): return locale.measurement_systems[name] return self.name + # TODO: When 3.11 is minimal version, use Self + @classmethod def from_lines( - cls: type[Self], lines: Iterable[str], get_root_func, non_int_type: type = float - ) -> Self: + cls: type[System], + lines: Iterable[str], + get_root_func: GetRootUnits, + non_int_type: type = float, + ) -> System: # TODO: we changed something here it used to be # system_definition = SystemDefinition.from_lines(lines, get_root_func) system_definition = SystemDefinition.from_lines(lines, non_int_type) + + if system_definition is None: + raise ValueError(f"Could not define System from from {lines}") + return cls.from_definition(system_definition, get_root_func) @classmethod - def from_definition(cls, system_definition: SystemDefinition, get_root_func=None): + def from_definition( + cls: type[System], + system_definition: SystemDefinition, + get_root_func: GetRootUnits | None = None, + ) -> System: if get_root_func is None: # TODO: kept for backwards compatibility get_root_func = cls._REGISTRY.get_root_units diff --git a/pint/facets/system/registry.py b/pint/facets/system/registry.py index 6e0878e..30921bd 100644 --- a/pint/facets/system/registry.py +++ b/pint/facets/system/registry.py @@ -9,10 +9,14 @@ from __future__ import annotations from numbers import Number -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Generic, Any from ... import errors +from ...compat import TypeAlias + +from ..plain import QuantityT, UnitT + if TYPE_CHECKING: from ..._typing import Quantity, Unit @@ -22,13 +26,14 @@ from ...util import ( create_class_with_registry, to_units_container, ) -from ..group import GroupRegistry +from ..group import GenericGroupRegistry from .definitions import SystemDefinition -from .objects import Lister, System from . import objects -class SystemRegistry(GroupRegistry): +class GenericSystemRegistry( + Generic[QuantityT, UnitT], GenericGroupRegistry[QuantityT, UnitT] +): """Handle of Systems. Conversion between units with different dimensions according @@ -46,24 +51,24 @@ class SystemRegistry(GroupRegistry): # TODO: Change this to System: System to specify class # and use introspection to get system class as a way # to enjoy typing goodies - System = objects.System + System: type[objects.System] - def __init__(self, system=None, **kwargs): + def __init__(self, system: str | None = None, **kwargs): super().__init__(**kwargs) #: Map system name to system. #: :type: dict[ str | System] - self._systems: dict[str, System] = {} + self._systems: dict[str, objects.System] = {} #: Maps dimensionality (UnitsContainer) to Dimensionality (UnitsContainer) - self._base_units_cache = {} + self._base_units_cache: dict[UnitsContainerT, UnitsContainerT] = {} - self._default_system = system + self._default_system_name: str | None = system def _init_dynamic_classes(self) -> None: """Generate subclasses on the fly and attach them to self""" super()._init_dynamic_classes() - self.System = create_class_with_registry(self, self.System) + self.System = create_class_with_registry(self, objects.System) def _after_init(self) -> None: """Invoked at the end of ``__init__``. @@ -74,7 +79,7 @@ class SystemRegistry(GroupRegistry): super()._after_init() #: System name to be used by default. - self._default_system = self._default_system or self._defaults.get( + self._default_system_name = self._default_system_name or self._defaults.get( "system", None ) @@ -82,7 +87,7 @@ class SystemRegistry(GroupRegistry): super()._register_definition_adders() self._register_adder(SystemDefinition, self._add_system) - def _add_system(self, sd: SystemDefinition): + def _add_system(self, sd: SystemDefinition) -> None: if sd.name in self._systems: raise ValueError(f"System {sd.name} already present in registry") @@ -96,29 +101,29 @@ class SystemRegistry(GroupRegistry): @property def sys(self): - return Lister(self._systems) + return objects.Lister(self._systems) @property - def default_system(self) -> System: - return self._default_system + def default_system(self) -> str | None: + return self._default_system_name @default_system.setter - def default_system(self, name): + def default_system(self, name: str) -> None: if name: if name not in self._systems: raise ValueError("Unknown system %s" % name) self._base_units_cache = {} - self._default_system = name + self._default_system_name = name - def get_system(self, name: str, create_if_needed: bool = True) -> System: + def get_system(self, name: str, create_if_needed: bool = True) -> objects.System: """Return a Group. Parameters ---------- name : str - Name of the group to be + Name of the group to be. create_if_needed : bool If True, create a group if not found. If False, raise an Exception. (Default value = True) @@ -141,7 +146,7 @@ class SystemRegistry(GroupRegistry): self, input_units: UnitLike | Quantity, check_nonmult: bool = True, - system: str | System | None = None, + system: str | objects.System | None = None, ) -> tuple[Number, Unit]: """Convert unit or dict of units to the plain units. @@ -179,15 +184,15 @@ class SystemRegistry(GroupRegistry): self, input_units: UnitsContainerT, check_nonmult: bool = True, - system: str | System | None = None, + system: str | objects.System | None = None, ): if system is None: - system = self._default_system + system = self._default_system_name # The cache is only done for check_nonmult=True and the current system. if ( check_nonmult - and system == self._default_system + and system == self._default_system_name and input_units in self._base_units_cache ): return self._base_units_cache[input_units] @@ -220,16 +225,32 @@ class SystemRegistry(GroupRegistry): return base_factor, destination_units - def _get_compatible_units(self, input_units, group_or_system) -> frozenset[Unit]: + def get_compatible_units( + self, input_units: UnitsContainerT, group_or_system: str | None = None + ) -> frozenset[Unit]: + """ """ + + group_or_system = group_or_system or self._default_system_name + if group_or_system is None: - group_or_system = self._default_system + return super().get_compatible_units(input_units) + + input_units = to_units_container(input_units) + + equiv = self._get_compatible_units(input_units, group_or_system) + + return frozenset(self.Unit(eq) for eq in equiv) + def _get_compatible_units( + self, input_units: UnitsContainerT, group_or_system: str | None = None + ) -> frozenset[Unit]: if group_or_system and group_or_system in self._systems: members = self._systems[group_or_system].members # group_or_system has been handled by System - return frozenset(members & super()._get_compatible_units(input_units, None)) + return frozenset(members & super()._get_compatible_units(input_units)) try: + # This will be handled by groups return super()._get_compatible_units(input_units, group_or_system) except ValueError as ex: # It might be also a system @@ -238,3 +259,10 @@ class SystemRegistry(GroupRegistry): "Unknown Group o System with name '%s'" % group_or_system ) from ex raise ex + + +class SystemRegistry( + GenericSystemRegistry[objects.SystemQuantity[Any], objects.SystemUnit] +): + Quantity: TypeAlias = objects.SystemQuantity[Any] + Unit: TypeAlias = objects.SystemUnit diff --git a/pint/formatting.py b/pint/formatting.py index 880f55b..28adf25 100644 --- a/pint/formatting.py +++ b/pint/formatting.py @@ -13,17 +13,27 @@ from __future__ import annotations import functools import re import warnings -from typing import Callable, Any +from typing import Callable, Any, TYPE_CHECKING, TypeVar from collections.abc import Iterable from numbers import Number from .babel_names import _babel_lengths, _babel_units -from .compat import babel_parse +from .compat import babel_parse, HAS_BABEL + +if TYPE_CHECKING: + from .util import ItMatrix, UnitsContainer + + if HAS_BABEL: + import babel + + Locale = babel.Locale + else: + Locale = TypeVar("Locale") __JOIN_REG_EXP = re.compile(r"{\d*}") -def _join(fmt: str, iterable: Iterable[Any]): +def _join(fmt: str, iterable: Iterable[Any]) -> str: """Join an iterable with the format specified in fmt. The format can be specified in two ways: @@ -124,6 +134,7 @@ _FORMATS: dict[str, dict[str, Any]] = { } #: _FORMATTERS maps format names to callables doing the formatting +# TODO fix Callable typing _FORMATTERS: dict[str, Callable] = {} @@ -167,7 +178,7 @@ def register_unit_format(name: str): @register_unit_format("P") -def format_pretty(unit, registry, **options): +def format_pretty(unit: UnitsContainer, registry, **options) -> str: return formatter( unit.items(), as_ratio=True, @@ -181,7 +192,7 @@ def format_pretty(unit, registry, **options): ) -def latex_escape(string): +def latex_escape(string: str) -> str: """ Prepend characters that have a special meaning in LaTeX with a backslash. """ @@ -198,7 +209,7 @@ def latex_escape(string): @register_unit_format("L") -def format_latex(unit, registry, **options): +def format_latex(unit: UnitsContainer, registry, **options) -> str: preprocessed = {rf"\mathrm{{{latex_escape(u)}}}": p for u, p in unit.items()} formatted = formatter( preprocessed.items(), @@ -214,7 +225,7 @@ def format_latex(unit, registry, **options): @register_unit_format("Lx") -def format_latex_siunitx(unit, registry, **options): +def format_latex_siunitx(unit: UnitsContainer, registry, **options) -> str: if registry is None: raise ValueError( "Can't format as siunitx without a registry." @@ -228,7 +239,7 @@ def format_latex_siunitx(unit, registry, **options): @register_unit_format("H") -def format_html(unit, registry, **options): +def format_html(unit: UnitsContainer, registry, **options) -> str: return formatter( unit.items(), as_ratio=True, @@ -242,7 +253,7 @@ def format_html(unit, registry, **options): @register_unit_format("D") -def format_default(unit, registry, **options): +def format_default(unit: UnitsContainer, registry, **options) -> str: return formatter( unit.items(), as_ratio=True, @@ -256,7 +267,7 @@ def format_default(unit, registry, **options): @register_unit_format("C") -def format_compact(unit, registry, **options): +def format_compact(unit: UnitsContainer, registry, **options) -> str: return formatter( unit.items(), as_ratio=True, @@ -270,7 +281,7 @@ def format_compact(unit, registry, **options): def formatter( - items: list[tuple[str, Number]], + items: Iterable[tuple[str, Number]], as_ratio: bool = True, single_denominator: bool = False, product_fmt: str = " * ", @@ -282,7 +293,7 @@ def formatter( babel_length: str = "long", babel_plural_form: str = "one", sort: bool = True, -): +) -> str: """Format a list of (name, exponent) pairs. Parameters @@ -393,7 +404,7 @@ def formatter( _BASIC_TYPES = frozenset("bcdeEfFgGnosxX%uS") -def _parse_spec(spec): +def _parse_spec(spec: str) -> str: result = "" for ch in reversed(spec): if ch == "~" or ch in _BASIC_TYPES: @@ -410,7 +421,7 @@ def _parse_spec(spec): return result -def format_unit(unit, spec, registry=None, **options): +def format_unit(unit, spec: str, registry=None, **options): # registry may be None to allow formatting `UnitsContainer` objects # in that case, the spec may not be "Lx" @@ -430,10 +441,10 @@ def format_unit(unit, spec, registry=None, **options): return fmt(unit, registry=registry, **options) -def siunitx_format_unit(units, registry): +def siunitx_format_unit(units: UnitsContainer, registry) -> str: """Returns LaTeX code for the unit that can be put into an siunitx command.""" - def _tothe(power): + def _tothe(power: int | float) -> str: if isinstance(power, int) or (isinstance(power, float) and power.is_integer()): if power == 1: return "" @@ -473,7 +484,7 @@ def siunitx_format_unit(units, registry): return "".join(lpos) + "".join(lneg) -def extract_custom_flags(spec): +def extract_custom_flags(spec: str) -> str: import re if not spec: @@ -488,14 +499,16 @@ def extract_custom_flags(spec): return "".join(custom_flags) -def remove_custom_flags(spec): +def remove_custom_flags(spec: str) -> str: for flag in sorted(_FORMATTERS.keys(), key=len, reverse=True) + ["~"]: if flag: spec = spec.replace(flag, "") return spec -def split_format(spec, default, separate_format_defaults=True): +def split_format( + spec: str, default: str, separate_format_defaults: bool = True +) -> tuple[str, str]: mspec = remove_custom_flags(spec) uspec = extract_custom_flags(spec) @@ -535,11 +548,11 @@ def split_format(spec, default, separate_format_defaults=True): return mspec, uspec -def vector_to_latex(vec, fmtfun=lambda x: format(x, ".2f")): +def vector_to_latex(vec: Iterable[Any], fmtfun=lambda x: format(x, ".2f")) -> str: return matrix_to_latex([vec], fmtfun) -def matrix_to_latex(matrix, fmtfun=lambda x: format(x, ".2f")): +def matrix_to_latex(matrix: ItMatrix, fmtfun=lambda x: format(x, ".2f")) -> str: ret = [] for row in matrix: @@ -548,7 +561,9 @@ def matrix_to_latex(matrix, fmtfun=lambda x: format(x, ".2f")): return r"\begin{pmatrix}%s\end{pmatrix}" % "\\\\ \n".join(ret) -def ndarray_to_latex_parts(ndarr, fmtfun=lambda x: format(x, ".2f"), dim=()): +def ndarray_to_latex_parts( + ndarr, fmtfun=lambda x: format(x, ".2f"), dim: tuple[int] = tuple() +): if isinstance(fmtfun, str): fmt = fmtfun fmtfun = lambda x: format(x, fmt) @@ -573,5 +588,7 @@ def ndarray_to_latex_parts(ndarr, fmtfun=lambda x: format(x, ".2f"), dim=()): return ret -def ndarray_to_latex(ndarr, fmtfun=lambda x: format(x, ".2f"), dim=()): +def ndarray_to_latex( + ndarr, fmtfun=lambda x: format(x, ".2f"), dim: tuple[int] = tuple() +) -> str: return "\n".join(ndarray_to_latex_parts(ndarr, fmtfun, dim)) diff --git a/pint/registry.py b/pint/registry.py index 474eb77..964d8a5 100644 --- a/pint/registry.py +++ b/pint/registry.py @@ -14,16 +14,10 @@ from __future__ import annotations +from typing import Generic + from . import registry_helpers -from .facets import ( - ContextRegistry, - DaskRegistry, - FormattingRegistry, - MeasurementRegistry, - NonMultiplicativeRegistry, - NumpyRegistry, - SystemRegistry, -) +from . import facets from .util import logger, pi_theorem @@ -33,37 +27,40 @@ from .util import logger, pi_theorem class Quantity( - # SystemRegistry.Quantity, - # ContextRegistry.Quantity, - DaskRegistry.Quantity, - NumpyRegistry.Quantity, - MeasurementRegistry.Quantity, - FormattingRegistry.Quantity, - NonMultiplicativeRegistry.Quantity, + facets.SystemRegistry.Quantity, + facets.ContextRegistry.Quantity, + facets.DaskRegistry.Quantity, + facets.NumpyRegistry.Quantity, + facets.MeasurementRegistry.Quantity, + facets.FormattingRegistry.Quantity, + facets.NonMultiplicativeRegistry.Quantity, + facets.PlainRegistry.Quantity, ): pass class Unit( - # SystemRegistry.Unit, - # ContextRegistry.Unit, - # DaskRegistry.Unit, - NumpyRegistry.Unit, - # MeasurementRegistry.Unit, - FormattingRegistry.Unit, - NonMultiplicativeRegistry.Unit, + facets.SystemRegistry.Unit, + facets.ContextRegistry.Unit, + facets.DaskRegistry.Unit, + facets.NumpyRegistry.Unit, + facets.MeasurementRegistry.Unit, + facets.FormattingRegistry.Unit, + facets.NonMultiplicativeRegistry.Unit, + facets.PlainRegistry.Unit, ): pass class UnitRegistry( - SystemRegistry, - ContextRegistry, - DaskRegistry, - NumpyRegistry, - MeasurementRegistry, - FormattingRegistry, - NonMultiplicativeRegistry, + facets.GenericSystemRegistry[Quantity, Unit], + facets.GenericContextRegistry[Quantity, Unit], + facets.GenericDaskRegistry[Quantity, Unit], + facets.GenericNumpyRegistry[Quantity, Unit], + facets.GenericMeasurementRegistry[Quantity, Unit], + facets.GenericFormattingRegistry[Quantity, Unit], + facets.GenericNonMultiplicativeRegistry[Quantity, Unit], + facets.GenericPlainRegistry[Quantity, Unit], ): """The unit registry stores the definitions and relationships between units. @@ -171,7 +168,7 @@ class UnitRegistry( check = registry_helpers.check -class LazyRegistry: +class LazyRegistry(Generic[facets.QuantityT, facets.UnitT]): def __init__(self, args=None, kwargs=None): self.__dict__["params"] = args or (), kwargs or {} diff --git a/pint/registry_helpers.py b/pint/registry_helpers.py index 1f28036..7eee694 100644 --- a/pint/registry_helpers.py +++ b/pint/registry_helpers.py @@ -13,7 +13,7 @@ from __future__ import annotations import functools from inspect import signature from itertools import zip_longest -from typing import TYPE_CHECKING, Callable, TypeVar +from typing import TYPE_CHECKING, Callable, TypeVar, Any from collections.abc import Iterable from ._typing import F @@ -189,7 +189,7 @@ def wraps( ret: str | Unit | Iterable[str | Unit | None] | None, args: str | Unit | Iterable[str | Unit | None] | None, strict: bool = True, -) -> Callable[[Callable[..., T]], Callable[..., Quantity[T]]]: +) -> Callable[[Callable[..., Any]], Callable[..., Quantity]]: """Wraps a function to become pint-aware. Use it when a function requires a numerical value but in some specific @@ -253,7 +253,7 @@ def wraps( ) ret = _to_units_container(ret, ureg) - def decorator(func: Callable[..., T]) -> Callable[..., Quantity[T]]: + def decorator(func: Callable[..., Any]) -> Callable[..., Quantity]: count_params = len(signature(func).parameters) if len(args) != count_params: raise TypeError( @@ -269,7 +269,7 @@ def wraps( ) @functools.wraps(func, assigned=assigned, updated=updated) - def wrapper(*values, **kw) -> Quantity[T]: + def wrapper(*values, **kw) -> Quantity: values, kw = _apply_defaults(func, values, kw) # In principle, the values are used as is diff --git a/pint/testing.py b/pint/testing.py index 8e4f15f..d99df0b 100644 --- a/pint/testing.py +++ b/pint/testing.py @@ -34,7 +34,7 @@ def _get_comparable_magnitudes(first, second, msg): return m1, m2 -def assert_equal(first, second, msg=None): +def assert_equal(first, second, msg: str | None = None) -> None: if msg is None: msg = f"Comparing {first!r} and {second!r}. " @@ -57,7 +57,9 @@ def assert_equal(first, second, msg=None): assert m1 == m2, msg -def assert_allclose(first, second, rtol=1e-07, atol=0, msg=None): +def assert_allclose( + first, second, rtol: float = 1e-07, atol: float = 0, msg: str | None = None +) -> None: if msg is None: try: msg = f"Comparing {first!r} and {second!r}. " diff --git a/pint/util.py b/pint/util.py index d75d1b5..40ea39e 100644 --- a/pint/util.py +++ b/pint/util.py @@ -30,15 +30,15 @@ from typing import ( ) from collections.abc import Hashable, Generator -from .compat import NUMERIC_TYPES, tokenizer +from .compat import NUMERIC_TYPES, tokenizer, Self from .errors import DefinitionSyntaxError from .formatting import format_unit from .pint_eval import build_eval_tree -from ._typing import PintScalar +from ._typing import Scalar if TYPE_CHECKING: - from ._typing import Quantity, UnitLike, Self + from ._typing import Quantity, UnitLike, QuantityOrUnitLike from .registry import UnitRegistry @@ -47,12 +47,13 @@ logger.addHandler(NullHandler()) T = TypeVar("T") TH = TypeVar("TH", bound=Hashable) +TT = TypeVar("TT", bound=type) # TODO: Change when Python 3.10 becomes minimal version. # ItMatrix: TypeAlias = Iterable[Iterable[PintScalar]] # Matrix: TypeAlias = list[list[PintScalar]] -ItMatrix = Iterable[Iterable[PintScalar]] -Matrix = list[list[PintScalar]] +ItMatrix = Iterable[Iterable[Scalar]] +Matrix = list[list[Scalar]] def _noop(x: T) -> T: @@ -65,7 +66,7 @@ def matrix_to_string( col_headers: Iterable[str] | None = None, fmtfun: Callable[ [ - PintScalar, + Scalar, ], str, ] = "{:0.0f}".format, @@ -125,9 +126,9 @@ def matrix_apply( matrix: ItMatrix, func: Callable[ [ - PintScalar, + Scalar, ], - PintScalar, + Scalar, ], ) -> Matrix: """Apply a function to individual elements within a matrix. @@ -172,7 +173,14 @@ def column_echelon_form( Swapped rows. """ - _transpose = transpose if transpose_result else _noop + _transpose: Callable[ + [ + ItMatrix, + ], + Matrix, + ] = ( + transpose if transpose_result else _noop + ) ech_matrix = matrix_apply( transpose(matrix), @@ -181,7 +189,7 @@ def column_echelon_form( rows, cols = len(ech_matrix), len(ech_matrix[0]) # M = [[ntype(x) for x in row] for row in M] - id_matrix: list[list[PintScalar]] = [ # noqa: E741 + id_matrix: list[list[Scalar]] = [ # noqa: E741 [ntype(1) if n == nc else ntype(0) for nc in range(rows)] for n in range(rows) ] @@ -415,7 +423,7 @@ def find_connected_nodes( return visited -class udict(dict[str, PintScalar]): +class udict(dict[str, Scalar]): """Custom dict implementing __missing__.""" def __missing__(self, key: str): @@ -425,7 +433,7 @@ class udict(dict[str, PintScalar]): return udict(self) -class UnitsContainer(Mapping[str, PintScalar]): +class UnitsContainer(Mapping[str, Scalar]): """The UnitsContainer stores the product of units and their respective exponent and implements the corresponding operations. @@ -441,10 +449,12 @@ class UnitsContainer(Mapping[str, PintScalar]): _d: udict _hash: int | None - _one: PintScalar + _one: Scalar _non_int_type: type - def __init__(self, *args, non_int_type: type | None = None, **kwargs) -> None: + def __init__( + self, *args: Any, non_int_type: type | None = None, **kwargs: Any + ) -> None: if args and isinstance(args[0], UnitsContainer): default_non_int_type = args[0]._non_int_type else: @@ -542,7 +552,7 @@ class UnitsContainer(Mapping[str, PintScalar]): def __len__(self) -> int: return len(self._d) - def __getitem__(self, key: str) -> PintScalar: + def __getitem__(self, key: str) -> Scalar: return self._d[key] def __contains__(self, key: str) -> bool: @@ -554,10 +564,10 @@ class UnitsContainer(Mapping[str, PintScalar]): return self._hash # Only needed by pickle protocol 0 and 1 (used by pytables) - def __getstate__(self) -> tuple[udict, PintScalar, type]: + def __getstate__(self) -> tuple[udict, Scalar, type]: return self._d, self._one, self._non_int_type - def __setstate__(self, state: tuple[udict, PintScalar, type]): + def __setstate__(self, state: tuple[udict, Scalar, type]): self._d, self._one, self._non_int_type = state self._hash = None @@ -682,9 +692,9 @@ class ParserHelper(UnitsContainer): __slots__ = ("scale",) - scale: PintScalar + scale: Scalar - def __init__(self, scale: PintScalar = 1, *args, **kwargs): + def __init__(self, scale: Scalar = 1, *args, **kwargs): super().__init__(*args, **kwargs) self.scale = scale @@ -1002,7 +1012,7 @@ class PrettyIPython: def to_units_container( - unit_like: UnitLike | Quantity, registry: UnitRegistry | None = None + unit_like: QuantityOrUnitLike, registry: UnitRegistry | None = None ) -> UnitsContainer: """Convert a unit compatible type to a UnitsContainer. @@ -1025,6 +1035,7 @@ def to_units_container( return unit_like._units elif str in mro: if registry: + # TODO: Why not parse.units here? return registry._parse_units(unit_like) else: return ParserHelper.from_string(unit_like) @@ -1124,7 +1135,9 @@ def sized(y: Any) -> bool: return True -def create_class_with_registry(registry: UnitRegistry, base_class: type) -> type: +def create_class_with_registry( + registry: UnitRegistry, base_class: type[TT] +) -> type[TT]: """Create new class inheriting from base_class and filling _REGISTRY class attribute with an actual instanced registry. """ |