diff options
Diffstat (limited to 'pint/facets/context/registry.py')
-rw-r--r-- | pint/facets/context/registry.py | 48 |
1 files changed, 34 insertions, 14 deletions
diff --git a/pint/facets/context/registry.py b/pint/facets/context/registry.py index a36d82d..746e79c 100644 --- a/pint/facets/context/registry.py +++ b/pint/facets/context/registry.py @@ -11,12 +11,13 @@ from __future__ import annotations import functools from collections import ChainMap from contextlib import contextmanager -from typing import Any, Callable, ContextManager +from typing import Any, Callable, Generator, Generic -from ..._typing import F +from ...compat import TypeAlias +from ..._typing import F, Magnitude from ...errors import UndefinedUnitError -from ...util import find_connected_nodes, find_shortest_path, logger -from ..plain import PlainRegistry, UnitDefinition +from ...util import find_connected_nodes, find_shortest_path, logger, UnitsContainer +from ..plain import GenericPlainRegistry, UnitDefinition, QuantityT, UnitT from .definitions import ContextDefinition from . import objects @@ -36,7 +37,9 @@ class ContextCacheOverlay: self.parse_unit = registry_cache.parse_unit -class ContextRegistry(PlainRegistry): +class GenericContextRegistry( + Generic[QuantityT, UnitT], GenericPlainRegistry[QuantityT, UnitT] +): """Handle of Contexts. Conversion between units with different dimensions according @@ -50,7 +53,7 @@ class ContextRegistry(PlainRegistry): - Parse @context directive. """ - Context = objects.Context + Context: type[objects.Context] = objects.Context def __init__(self, **kwargs: Any) -> None: # Map context name (string) or abbreviation to context. @@ -65,13 +68,13 @@ class ContextRegistry(PlainRegistry): super().__init__(**kwargs) # Allow contexts to add override layers to the units - self._units = ChainMap(self._units) + self._units: ChainMap[str, UnitDefinition] = ChainMap(self._units) def _register_definition_adders(self) -> None: super()._register_definition_adders() self._register_adder(ContextDefinition, self.add_context) - def add_context(self, context: Context | ContextDefinition) -> None: + def add_context(self, context: objects.Context | ContextDefinition) -> None: """Add a context object to the registry. The context will be accessible by its name and aliases. @@ -194,7 +197,7 @@ class ContextRegistry(PlainRegistry): self.define(definition) def enable_contexts( - self, *names_or_contexts: str | objects.Context, **kwargs + self, *names_or_contexts: str | objects.Context, **kwargs: Any ) -> None: """Enable contexts provided by name or by object. @@ -241,7 +244,7 @@ class ContextRegistry(PlainRegistry): self._active_ctx.insert_contexts(*contexts) self._switch_context_cache_and_units() - def disable_contexts(self, n: int = None) -> None: + def disable_contexts(self, n: int | None = None) -> None: """Disable the last n enabled contexts. Parameters @@ -253,7 +256,9 @@ class ContextRegistry(PlainRegistry): self._switch_context_cache_and_units() @contextmanager - def context(self, *names, **kwargs) -> ContextManager[objects.Context]: + def context( + self: GenericContextRegistry[QuantityT, UnitT], *names: str, **kwargs: Any + ) -> Generator[GenericContextRegistry[QuantityT, UnitT], None, None]: """Used as a context manager, this function enables to activate a context which is removed after usage. @@ -309,7 +314,7 @@ class ContextRegistry(PlainRegistry): # the added contexts are removed from the active one. self.disable_contexts(len(names)) - def with_context(self, name, **kwargs) -> Callable[[F], F]: + def with_context(self, name: str, **kwargs: Any) -> Callable[[F], F]: """Decorator to wrap a function call in a Pint context. Use it to ensure that a certain context is active when @@ -351,7 +356,13 @@ class ContextRegistry(PlainRegistry): return decorator - def _convert(self, value, src, dst, inplace=False): + def _convert( + self, + value: Magnitude, + src: UnitsContainer, + dst: UnitsContainer, + inplace: bool = False, + ) -> Magnitude: """Convert value from some source to destination units. In addition to what is done by the PlainRegistry, @@ -391,7 +402,9 @@ class ContextRegistry(PlainRegistry): return super()._convert(value, src, dst, inplace) - def _get_compatible_units(self, input_units, group_or_system): + def _get_compatible_units( + self, input_units: UnitsContainer, group_or_system: str | None = None + ): src_dim = self._get_dimensionality(input_units) ret = super()._get_compatible_units(input_units, group_or_system) @@ -404,3 +417,10 @@ class ContextRegistry(PlainRegistry): ret |= self._cache.dimensional_equivalents[node] return ret + + +class ContextRegistry( + GenericContextRegistry[objects.ContextQuantity[Any], objects.ContextUnit] +): + Quantity: TypeAlias = objects.ContextQuantity[Any] + Unit: TypeAlias = objects.ContextUnit |