diff options
Diffstat (limited to 'pint/facets/group/registry.py')
-rw-r--r-- | pint/facets/group/registry.py | 50 |
1 files changed, 40 insertions, 10 deletions
diff --git a/pint/facets/group/registry.py b/pint/facets/group/registry.py index 0d35ae0..f130e61 100644 --- a/pint/facets/group/registry.py +++ b/pint/facets/group/registry.py @@ -8,20 +8,28 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Generic, Any +from ...compat import TypeAlias from ... import errors if TYPE_CHECKING: - from ..._typing import Unit - -from ...util import create_class_with_registry -from ..plain import PlainRegistry, UnitDefinition + from ..._typing import Unit, UnitsContainer + +from ...util import create_class_with_registry, to_units_container +from ..plain import ( + GenericPlainRegistry, + UnitDefinition, + QuantityT, + UnitT, +) from .definitions import GroupDefinition from . import objects -class GroupRegistry(PlainRegistry): +class GenericGroupRegistry( + Generic[QuantityT, UnitT], GenericPlainRegistry[QuantityT, UnitT] +): """Handle of Groups. Group units @@ -34,7 +42,7 @@ class GroupRegistry(PlainRegistry): # TODO: Change this to Group: Group to specify class # and use introspection to get system class as a way # to enjoy typing goodies - Group = objects.Group + Group = type[objects.Group] def __init__(self, **kwargs): super().__init__(**kwargs) @@ -46,7 +54,7 @@ class GroupRegistry(PlainRegistry): def _init_dynamic_classes(self) -> None: """Generate subclasses on the fly and attach them to self""" super()._init_dynamic_classes() - self.Group = create_class_with_registry(self, self.Group) + self.Group = create_class_with_registry(self, objects.Group) def _after_init(self) -> None: """Invoked at the end of ``__init__``. @@ -113,8 +121,23 @@ class GroupRegistry(PlainRegistry): return self.Group(name) - def _get_compatible_units(self, input_units, group) -> frozenset[Unit]: - ret = super()._get_compatible_units(input_units, group) + def get_compatible_units( + self, input_units: UnitsContainer, group: str | None = None + ) -> frozenset[Unit]: + """ """ + if group is None: + return super().get_compatible_units(input_units) + + input_units = to_units_container(input_units) + + equiv = self._get_compatible_units(input_units, group) + + return frozenset(self.Unit(eq) for eq in equiv) + + def _get_compatible_units( + self, input_units: UnitsContainer, group: str | None = None + ) -> frozenset[str]: + ret = super()._get_compatible_units(input_units) if not group: return ret @@ -124,3 +147,10 @@ class GroupRegistry(PlainRegistry): else: raise ValueError("Unknown Group with name '%s'" % group) return frozenset(ret & members) + + +class GroupRegistry( + GenericGroupRegistry[objects.GroupQuantity[Any], objects.GroupUnit] +): + Quantity: TypeAlias = objects.GroupQuantity[Any] + Unit: TypeAlias = objects.GroupUnit |