summaryrefslogtreecommitdiff
path: root/pint/facets/context/registry.py
diff options
context:
space:
mode:
Diffstat (limited to 'pint/facets/context/registry.py')
-rw-r--r--pint/facets/context/registry.py48
1 files changed, 34 insertions, 14 deletions
diff --git a/pint/facets/context/registry.py b/pint/facets/context/registry.py
index a36d82d..746e79c 100644
--- a/pint/facets/context/registry.py
+++ b/pint/facets/context/registry.py
@@ -11,12 +11,13 @@ from __future__ import annotations
import functools
from collections import ChainMap
from contextlib import contextmanager
-from typing import Any, Callable, ContextManager
+from typing import Any, Callable, Generator, Generic
-from ..._typing import F
+from ...compat import TypeAlias
+from ..._typing import F, Magnitude
from ...errors import UndefinedUnitError
-from ...util import find_connected_nodes, find_shortest_path, logger
-from ..plain import PlainRegistry, UnitDefinition
+from ...util import find_connected_nodes, find_shortest_path, logger, UnitsContainer
+from ..plain import GenericPlainRegistry, UnitDefinition, QuantityT, UnitT
from .definitions import ContextDefinition
from . import objects
@@ -36,7 +37,9 @@ class ContextCacheOverlay:
self.parse_unit = registry_cache.parse_unit
-class ContextRegistry(PlainRegistry):
+class GenericContextRegistry(
+ Generic[QuantityT, UnitT], GenericPlainRegistry[QuantityT, UnitT]
+):
"""Handle of Contexts.
Conversion between units with different dimensions according
@@ -50,7 +53,7 @@ class ContextRegistry(PlainRegistry):
- Parse @context directive.
"""
- Context = objects.Context
+ Context: type[objects.Context] = objects.Context
def __init__(self, **kwargs: Any) -> None:
# Map context name (string) or abbreviation to context.
@@ -65,13 +68,13 @@ class ContextRegistry(PlainRegistry):
super().__init__(**kwargs)
# Allow contexts to add override layers to the units
- self._units = ChainMap(self._units)
+ self._units: ChainMap[str, UnitDefinition] = ChainMap(self._units)
def _register_definition_adders(self) -> None:
super()._register_definition_adders()
self._register_adder(ContextDefinition, self.add_context)
- def add_context(self, context: Context | ContextDefinition) -> None:
+ def add_context(self, context: objects.Context | ContextDefinition) -> None:
"""Add a context object to the registry.
The context will be accessible by its name and aliases.
@@ -194,7 +197,7 @@ class ContextRegistry(PlainRegistry):
self.define(definition)
def enable_contexts(
- self, *names_or_contexts: str | objects.Context, **kwargs
+ self, *names_or_contexts: str | objects.Context, **kwargs: Any
) -> None:
"""Enable contexts provided by name or by object.
@@ -241,7 +244,7 @@ class ContextRegistry(PlainRegistry):
self._active_ctx.insert_contexts(*contexts)
self._switch_context_cache_and_units()
- def disable_contexts(self, n: int = None) -> None:
+ def disable_contexts(self, n: int | None = None) -> None:
"""Disable the last n enabled contexts.
Parameters
@@ -253,7 +256,9 @@ class ContextRegistry(PlainRegistry):
self._switch_context_cache_and_units()
@contextmanager
- def context(self, *names, **kwargs) -> ContextManager[objects.Context]:
+ def context(
+ self: GenericContextRegistry[QuantityT, UnitT], *names: str, **kwargs: Any
+ ) -> Generator[GenericContextRegistry[QuantityT, UnitT], None, None]:
"""Used as a context manager, this function enables to activate a context
which is removed after usage.
@@ -309,7 +314,7 @@ class ContextRegistry(PlainRegistry):
# the added contexts are removed from the active one.
self.disable_contexts(len(names))
- def with_context(self, name, **kwargs) -> Callable[[F], F]:
+ def with_context(self, name: str, **kwargs: Any) -> Callable[[F], F]:
"""Decorator to wrap a function call in a Pint context.
Use it to ensure that a certain context is active when
@@ -351,7 +356,13 @@ class ContextRegistry(PlainRegistry):
return decorator
- def _convert(self, value, src, dst, inplace=False):
+ def _convert(
+ self,
+ value: Magnitude,
+ src: UnitsContainer,
+ dst: UnitsContainer,
+ inplace: bool = False,
+ ) -> Magnitude:
"""Convert value from some source to destination units.
In addition to what is done by the PlainRegistry,
@@ -391,7 +402,9 @@ class ContextRegistry(PlainRegistry):
return super()._convert(value, src, dst, inplace)
- def _get_compatible_units(self, input_units, group_or_system):
+ def _get_compatible_units(
+ self, input_units: UnitsContainer, group_or_system: str | None = None
+ ):
src_dim = self._get_dimensionality(input_units)
ret = super()._get_compatible_units(input_units, group_or_system)
@@ -404,3 +417,10 @@ class ContextRegistry(PlainRegistry):
ret |= self._cache.dimensional_equivalents[node]
return ret
+
+
+class ContextRegistry(
+ GenericContextRegistry[objects.ContextQuantity[Any], objects.ContextUnit]
+):
+ Quantity: TypeAlias = objects.ContextQuantity[Any]
+ Unit: TypeAlias = objects.ContextUnit