summaryrefslogtreecommitdiff
path: root/pint/facets/group/registry.py
diff options
context:
space:
mode:
Diffstat (limited to 'pint/facets/group/registry.py')
-rw-r--r--pint/facets/group/registry.py50
1 files changed, 40 insertions, 10 deletions
diff --git a/pint/facets/group/registry.py b/pint/facets/group/registry.py
index 0d35ae0..f130e61 100644
--- a/pint/facets/group/registry.py
+++ b/pint/facets/group/registry.py
@@ -8,20 +8,28 @@
from __future__ import annotations
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Generic, Any
+from ...compat import TypeAlias
from ... import errors
if TYPE_CHECKING:
- from ..._typing import Unit
-
-from ...util import create_class_with_registry
-from ..plain import PlainRegistry, UnitDefinition
+ from ..._typing import Unit, UnitsContainer
+
+from ...util import create_class_with_registry, to_units_container
+from ..plain import (
+ GenericPlainRegistry,
+ UnitDefinition,
+ QuantityT,
+ UnitT,
+)
from .definitions import GroupDefinition
from . import objects
-class GroupRegistry(PlainRegistry):
+class GenericGroupRegistry(
+ Generic[QuantityT, UnitT], GenericPlainRegistry[QuantityT, UnitT]
+):
"""Handle of Groups.
Group units
@@ -34,7 +42,7 @@ class GroupRegistry(PlainRegistry):
# TODO: Change this to Group: Group to specify class
# and use introspection to get system class as a way
# to enjoy typing goodies
- Group = objects.Group
+ Group = type[objects.Group]
def __init__(self, **kwargs):
super().__init__(**kwargs)
@@ -46,7 +54,7 @@ class GroupRegistry(PlainRegistry):
def _init_dynamic_classes(self) -> None:
"""Generate subclasses on the fly and attach them to self"""
super()._init_dynamic_classes()
- self.Group = create_class_with_registry(self, self.Group)
+ self.Group = create_class_with_registry(self, objects.Group)
def _after_init(self) -> None:
"""Invoked at the end of ``__init__``.
@@ -113,8 +121,23 @@ class GroupRegistry(PlainRegistry):
return self.Group(name)
- def _get_compatible_units(self, input_units, group) -> frozenset[Unit]:
- ret = super()._get_compatible_units(input_units, group)
+ def get_compatible_units(
+ self, input_units: UnitsContainer, group: str | None = None
+ ) -> frozenset[Unit]:
+ """ """
+ if group is None:
+ return super().get_compatible_units(input_units)
+
+ input_units = to_units_container(input_units)
+
+ equiv = self._get_compatible_units(input_units, group)
+
+ return frozenset(self.Unit(eq) for eq in equiv)
+
+ def _get_compatible_units(
+ self, input_units: UnitsContainer, group: str | None = None
+ ) -> frozenset[str]:
+ ret = super()._get_compatible_units(input_units)
if not group:
return ret
@@ -124,3 +147,10 @@ class GroupRegistry(PlainRegistry):
else:
raise ValueError("Unknown Group with name '%s'" % group)
return frozenset(ret & members)
+
+
+class GroupRegistry(
+ GenericGroupRegistry[objects.GroupQuantity[Any], objects.GroupUnit]
+):
+ Quantity: TypeAlias = objects.GroupQuantity[Any]
+ Unit: TypeAlias = objects.GroupUnit