diff options
Diffstat (limited to 'pint')
33 files changed, 905 insertions, 490 deletions
diff --git a/pint/_typing.py b/pint/_typing.py index 1dc3ea6..5547f85 100644 --- a/pint/_typing.py +++ b/pint/_typing.py @@ -1,12 +1,56 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Callable, TypeVar, Union +from typing import TYPE_CHECKING, Any, Callable, TypeVar, Union, Protocol + +# TODO: Remove when 3.11 becomes minimal version. +Self = TypeVar("Self") if TYPE_CHECKING: from .facets.plain import PlainQuantity as Quantity from .facets.plain import PlainUnit as Unit from .util import UnitsContainer + +class PintScalar(Protocol): + def __add__(self, other: Any) -> Any: + ... + + def __sub__(self, other: Any) -> Any: + ... + + def __mul__(self, other: Any) -> Any: + ... + + def __truediv__(self, other: Any) -> Any: + ... + + def __floordiv__(self, other: Any) -> Any: + ... + + def __mod__(self, other: Any) -> Any: + ... + + def __divmod__(self, other: Any) -> Any: + ... + + def __pow__(self, other: Any, modulo: Any) -> Any: + ... + + +class PintArray(Protocol): + def __len__(self) -> int: + ... + + def __getitem__(self, key: Any) -> Any: + ... + + def __setitem__(self, key: Any, value: Any) -> None: + ... + + +Magnitude = PintScalar | PintScalar + + UnitLike = Union[str, "UnitsContainer", "Unit"] QuantityOrUnitLike = Union["Quantity", UnitLike] diff --git a/pint/babel_names.py b/pint/babel_names.py index 09fa046..408ef8f 100644 --- a/pint/babel_names.py +++ b/pint/babel_names.py @@ -10,7 +10,7 @@ from __future__ import annotations from .compat import HAS_BABEL -_babel_units = dict( +_babel_units: dict[str, str] = dict( standard_gravity="acceleration-g-force", millibar="pressure-millibar", metric_ton="mass-metric-ton", @@ -141,6 +141,6 @@ _babel_units = dict( if not HAS_BABEL: _babel_units = {} -_babel_systems = dict(mks="metric", imperial="uksystem", US="ussystem") +_babel_systems: dict[str, str] = dict(mks="metric", imperial="uksystem", US="ussystem") -_babel_lengths = ["narrow", "short", "long"] +_babel_lengths: list[str] = ["narrow", "short", "long"] diff --git a/pint/compat.py b/pint/compat.py index ee8d443..f58e9cb 100644 --- a/pint/compat.py +++ b/pint/compat.py @@ -17,12 +17,19 @@ from importlib import import_module from io import BytesIO from numbers import Number from collections.abc import Mapping +from typing import Any, NoReturn, Callable, Generator, Iterable -def missing_dependency(package, display_name=None): +def missing_dependency( + package: str, display_name: str | None = None +) -> Callable[..., NoReturn]: + """Return a helper function that raises an exception when used. + + It provides a way delay a missing dependency exception until it is used. + """ display_name = display_name or package - def _inner(*args, **kwargs): + def _inner(*args: Any, **kwargs: Any) -> NoReturn: raise Exception( "This feature requires %s. Please install it by running:\n" "pip install %s" % (display_name, package) @@ -31,7 +38,14 @@ def missing_dependency(package, display_name=None): return _inner -def tokenizer(input_string): +def tokenizer(input_string: str) -> Generator[tokenize.TokenInfo, None, None]: + """Tokenize an input string, encoded as UTF-8 + and skipping the ENCODING token. + + See Also + -------- + tokenize.tokenize + """ for tokinfo in tokenize.tokenize(BytesIO(input_string.encode("utf-8")).readline): if tokinfo.type != tokenize.ENCODING: yield tokinfo @@ -154,7 +168,8 @@ else: from math import log # noqa: F401 if not HAS_BABEL: - babel_parse = babel_units = missing_dependency("Babel") # noqa: F811 + babel_parse = missing_dependency("Babel") # noqa: F811 + babel_units = babel_parse if not HAS_MIP: mip_missing = missing_dependency("mip") @@ -176,6 +191,9 @@ except ImportError: dask_array = None +# TODO: merge with upcast_type_map + +#: List upcast type names upcast_type_names = ( "pint_pandas.PintArray", "pandas.Series", @@ -186,10 +204,12 @@ upcast_type_names = ( "xarray.core.dataarray.DataArray", ) -upcast_type_map: Mapping[str : type | None] = {k: None for k in upcast_type_names} +#: Map type name to the actual type (for upcast types). +upcast_type_map: Mapping[str, type | None] = {k: None for k in upcast_type_names} def fully_qualified_name(t: type) -> str: + """Return the fully qualified name of a type.""" module = t.__module__ name = t.__qualname__ @@ -200,6 +220,10 @@ def fully_qualified_name(t: type) -> str: def check_upcast_type(obj: type) -> bool: + """Check if the type object is an upcast type.""" + + # TODO: merge or unify name with is_upcast_type + fqn = fully_qualified_name(obj) if fqn not in upcast_type_map: return False @@ -215,22 +239,17 @@ def check_upcast_type(obj: type) -> bool: def is_upcast_type(other: type) -> bool: + """Check if the type object is an upcast type.""" + + # TODO: merge or unify name with check_upcast_type + if other in upcast_type_map.values(): return True return check_upcast_type(other) -def is_duck_array_type(cls) -> bool: - """Check if the type object represents a (non-Quantity) duck array type. - - Parameters - ---------- - cls : class - - Returns - ------- - bool - """ +def is_duck_array_type(cls: type) -> bool: + """Check if the type object represents a (non-Quantity) duck array type.""" # TODO (NEP 30): replace duck array check with hasattr(other, "__duckarray__") return issubclass(cls, ndarray) or ( not hasattr(cls, "_magnitude") @@ -242,20 +261,21 @@ def is_duck_array_type(cls) -> bool: ) -def is_duck_array(obj): +def is_duck_array(obj: type) -> bool: + """Check if an object represents a (non-Quantity) duck array type.""" return is_duck_array_type(type(obj)) -def eq(lhs, rhs, check_all: bool): +def eq(lhs: Any, rhs: Any, check_all: bool) -> bool | Iterable[bool]: """Comparison of scalars and arrays. Parameters ---------- - lhs : object + lhs left-hand side - rhs : object + rhs right-hand side - check_all : bool + check_all if True, reduce sequence to single bool; return True if all the elements are equal. @@ -269,21 +289,21 @@ def eq(lhs, rhs, check_all: bool): return out -def isnan(obj, check_all: bool): - """Test for NaN or NaT +def isnan(obj: Any, check_all: bool) -> bool | Iterable[bool]: + """Test for NaN or NaT. Parameters ---------- - obj : object + obj scalar or vector - check_all : bool + check_all if True, reduce sequence to single bool; return True if any of the elements are NaN. Returns ------- bool or array_like of bool. - Always return False for non-numeric types. + Always return False for non-numeric types. """ if is_duck_array_type(type(obj)): if obj.dtype.kind in "if": @@ -302,21 +322,21 @@ def isnan(obj, check_all: bool): return False -def zero_or_nan(obj, check_all: bool): - """Test if obj is zero, NaN, or NaT +def zero_or_nan(obj: Any, check_all: bool) -> bool | Iterable[bool]: + """Test if obj is zero, NaN, or NaT. Parameters ---------- - obj : object + obj scalar or vector - check_all : bool + check_all if True, reduce sequence to single bool; return True if all the elements are zero, NaN, or NaT. Returns ------- bool or array_like of bool. - Always return False for non-numeric types. + Always return False for non-numeric types. """ out = eq(obj, 0, False) + isnan(obj, False) if check_all and is_duck_array_type(type(out)): diff --git a/pint/context.py b/pint/context.py index 4839926..6c74f65 100644 --- a/pint/context.py +++ b/pint/context.py @@ -18,3 +18,5 @@ if TYPE_CHECKING: #: Regex to match the header parts of a context. #: Regex to match variable names in an equation. + +# TODO: delete this file diff --git a/pint/converters.py b/pint/converters.py index 9b8513f..9494ad1 100644 --- a/pint/converters.py +++ b/pint/converters.py @@ -13,6 +13,10 @@ from __future__ import annotations from dataclasses import dataclass from dataclasses import fields as dc_fields +from typing import Any + +from ._typing import Self, Magnitude + from .compat import HAS_NUMPY, exp, log # noqa: F401 @@ -24,17 +28,17 @@ class Converter: _param_names_to_subclass = {} @property - def is_multiplicative(self): + def is_multiplicative(self) -> bool: return True @property - def is_logarithmic(self): + def is_logarithmic(self) -> bool: return False - def to_reference(self, value, inplace=False): + def to_reference(self, value: Magnitude, inplace: bool = False) -> Magnitude: return value - def from_reference(self, value, inplace=False): + def from_reference(self, value: Magnitude, inplace: bool = False) -> Magnitude: return value def __init_subclass__(cls, **kwargs): @@ -43,7 +47,7 @@ class Converter: cls._subclasses.append(cls) @classmethod - def get_field_names(cls, new_cls): + def get_field_names(cls, new_cls) -> frozenset[str]: return frozenset(p.name for p in dc_fields(new_cls)) @classmethod @@ -51,7 +55,7 @@ class Converter: return None @classmethod - def from_arguments(cls, **kwargs): + def from_arguments(cls: type[Self], **kwargs: Any) -> Self: kwk = frozenset(kwargs.keys()) try: new_cls = cls._param_names_to_subclass[kwk] diff --git a/pint/definitions.py b/pint/definitions.py index 789d9e3..ce89e94 100644 --- a/pint/definitions.py +++ b/pint/definitions.py @@ -8,6 +8,8 @@ :license: BSD, see LICENSE for more details. """ +from __future__ import annotations + from . import errors from ._vendor import flexparser as fp from .delegates import ParserConfig, txt_defparser @@ -17,12 +19,28 @@ class Definition: """This is kept for backwards compatibility""" @classmethod - def from_string(cls, s: str, non_int_type=float): + def from_string(cls, input_string: str, non_int_type: type = float) -> Definition: + """Parse a string into a definition object. + + Parameters + ---------- + input_string + Single line string. + non_int_type + Numerical type used for non integer values. + + Raises + ------ + DefinitionSyntaxError + If a syntax error was found. + """ cfg = ParserConfig(non_int_type) parser = txt_defparser.DefParser(cfg, None) - pp = parser.parse_string(s) + pp = parser.parse_string(input_string) for definition in parser.iter_parsed_project(pp): if isinstance(definition, Exception): raise errors.DefinitionSyntaxError(str(definition)) if not isinstance(definition, (fp.BOS, fp.BOF, fp.BOS)): return definition + + # TODO: What shall we do in this return path. diff --git a/pint/delegates/__init__.py b/pint/delegates/__init__.py index 363ef9c..b2eb9a3 100644 --- a/pint/delegates/__init__.py +++ b/pint/delegates/__init__.py @@ -11,4 +11,4 @@ from . import txt_defparser from .base_defparser import ParserConfig, build_disk_cache_class -__all__ = [txt_defparser, ParserConfig, build_disk_cache_class] +__all__ = ["txt_defparser", "ParserConfig", "build_disk_cache_class"] diff --git a/pint/delegates/base_defparser.py b/pint/delegates/base_defparser.py index 88d9d37..9e784ac 100644 --- a/pint/delegates/base_defparser.py +++ b/pint/delegates/base_defparser.py @@ -14,7 +14,6 @@ import functools import itertools import numbers import pathlib -import typing as ty from dataclasses import dataclass, field from pint import errors @@ -27,10 +26,10 @@ from .._vendor import flexparser as fp @dataclass(frozen=True) class ParserConfig: - """Configuration used by the parser.""" + """Configuration used by the parser in Pint.""" #: Indicates the output type of non integer numbers. - non_int_type: ty.Type[numbers.Number] = float + non_int_type: type[numbers.Number] = float def to_scaled_units_container(self, s: str): return ParserHelper.from_string(s, self.non_int_type) @@ -67,6 +66,11 @@ class ParserConfig: return val.scale +@dataclass(frozen=True) +class PintParsedStatement(fp.ParsedStatement[ParserConfig]): + """A parsed statement for pint, specialized in the actual config.""" + + @functools.lru_cache def build_disk_cache_class(non_int_type: type): """Build disk cache class, taking into account the non_int_type.""" diff --git a/pint/delegates/txt_defparser/__init__.py b/pint/delegates/txt_defparser/__init__.py index 5572ca1..49e4a0b 100644 --- a/pint/delegates/txt_defparser/__init__.py +++ b/pint/delegates/txt_defparser/__init__.py @@ -11,4 +11,6 @@ from .defparser import DefParser -__all__ = [DefParser] +__all__ = [ + "DefParser", +] diff --git a/pint/delegates/txt_defparser/block.py b/pint/delegates/txt_defparser/block.py index 20ebcba..e8d8aa4 100644 --- a/pint/delegates/txt_defparser/block.py +++ b/pint/delegates/txt_defparser/block.py @@ -17,11 +17,14 @@ from __future__ import annotations from dataclasses import dataclass +from typing import Generic, TypeVar + +from ..base_defparser import PintParsedStatement, ParserConfig from ..._vendor import flexparser as fp @dataclass(frozen=True) -class EndDirectiveBlock(fp.ParsedStatement): +class EndDirectiveBlock(PintParsedStatement): """An EndDirectiveBlock is simply an "@end" statement.""" @classmethod @@ -31,8 +34,16 @@ class EndDirectiveBlock(fp.ParsedStatement): return None +OPST = TypeVar("OPST", bound="PintParsedStatement") +IPST = TypeVar("IPST", bound="PintParsedStatement") + +DefT = TypeVar("DefT") + + @dataclass(frozen=True) -class DirectiveBlock(fp.Block): +class DirectiveBlock( + Generic[DefT, OPST, IPST], fp.Block[OPST, IPST, EndDirectiveBlock, ParserConfig] +): """Directive blocks have beginning statement starting with a @ character. and ending with a "@end" (captured using a EndDirectiveBlock). @@ -41,5 +52,5 @@ class DirectiveBlock(fp.Block): closing: EndDirectiveBlock - def derive_definition(self): - pass + def derive_definition(self) -> DefT: + ... diff --git a/pint/delegates/txt_defparser/common.py b/pint/delegates/txt_defparser/common.py index 493d0ec..a1195b3 100644 --- a/pint/delegates/txt_defparser/common.py +++ b/pint/delegates/txt_defparser/common.py @@ -30,7 +30,7 @@ class DefinitionSyntaxError(errors.DefinitionSyntaxError, fp.ParsingError): location: str = field(init=False, default="") - def __str__(self): + def __str__(self) -> str: msg = ( self.msg + "\n " + (self.format_position or "") + " " + (self.raw or "") ) @@ -38,7 +38,7 @@ class DefinitionSyntaxError(errors.DefinitionSyntaxError, fp.ParsingError): msg += "\n " + self.location return msg - def set_location(self, value): + def set_location(self, value: str) -> None: super().__setattr__("location", value) @@ -47,7 +47,7 @@ class ImportDefinition(fp.IncludeStatement): value: str @property - def target(self): + def target(self) -> str: return self.value @classmethod diff --git a/pint/delegates/txt_defparser/context.py b/pint/delegates/txt_defparser/context.py index b7e5a67..ce9fc9b 100644 --- a/pint/delegates/txt_defparser/context.py +++ b/pint/delegates/txt_defparser/context.py @@ -23,32 +23,32 @@ from dataclasses import dataclass from ..._vendor import flexparser as fp from ...facets.context import definitions -from ..base_defparser import ParserConfig +from ..base_defparser import ParserConfig, PintParsedStatement from . import block, common, plain +# TODO check syntax +T = ty.TypeVar("T", bound="ForwardRelation | BidirectionalRelation") -@dataclass(frozen=True) -class Relation(definitions.Relation): - @classmethod - def _from_string_and_context_sep( - cls, s: str, config: ParserConfig, separator: str - ) -> fp.FromString[Relation]: - if separator not in s: - return None - if ":" not in s: - return None - rel, eq = s.split(":") +def _from_string_and_context_sep( + cls: type[T], s: str, config: ParserConfig, separator: str +) -> T | None: + if separator not in s: + return None + if ":" not in s: + return None + + rel, eq = s.split(":") - parts = rel.split(separator) + parts = rel.split(separator) - src, dst = (config.to_dimension_container(s) for s in parts) + src, dst = (config.to_dimension_container(s) for s in parts) - return cls(src, dst, eq.strip()) + return cls(src, dst, eq.strip()) @dataclass(frozen=True) -class ForwardRelation(fp.ParsedStatement, definitions.ForwardRelation, Relation): +class ForwardRelation(PintParsedStatement, definitions.ForwardRelation): """A relation connecting a dimension to another via a transformation function. <source dimension> -> <target dimension>: <transformation function> @@ -58,13 +58,11 @@ class ForwardRelation(fp.ParsedStatement, definitions.ForwardRelation, Relation) def from_string_and_config( cls, s: str, config: ParserConfig ) -> fp.FromString[ForwardRelation]: - return super()._from_string_and_context_sep(s, config, "->") + return _from_string_and_context_sep(cls, s, config, "->") @dataclass(frozen=True) -class BidirectionalRelation( - fp.ParsedStatement, definitions.BidirectionalRelation, Relation -): +class BidirectionalRelation(PintParsedStatement, definitions.BidirectionalRelation): """A bidirectional relation connecting a dimension to another via a simple transformation function. @@ -76,11 +74,11 @@ class BidirectionalRelation( def from_string_and_config( cls, s: str, config: ParserConfig ) -> fp.FromString[BidirectionalRelation]: - return super()._from_string_and_context_sep(s, config, "<->") + return _from_string_and_context_sep(cls, s, config, "<->") @dataclass(frozen=True) -class BeginContext(fp.ParsedStatement): +class BeginContext(PintParsedStatement): """Being of a context directive. @context[(defaults)] <canonical name> [= <alias>] [= <alias>] @@ -91,7 +89,7 @@ class BeginContext(fp.ParsedStatement): ) name: str - aliases: tuple[str, ...] + aliases: tuple[str] defaults: dict[str, numbers.Number] @classmethod @@ -130,7 +128,18 @@ class BeginContext(fp.ParsedStatement): @dataclass(frozen=True) -class ContextDefinition(block.DirectiveBlock): +class ContextDefinition( + block.DirectiveBlock[ + definitions.ContextDefinition, + BeginContext, + ty.Union[ + plain.CommentDefinition, + BidirectionalRelation, + ForwardRelation, + plain.UnitDefinition, + ], + ] +): """Definition of a Context @context[(defaults)] <canonical name> [= <alias>] [= <alias>] @@ -169,27 +178,34 @@ class ContextDefinition(block.DirectiveBlock): ] ] - def derive_definition(self): + def derive_definition(self) -> definitions.ContextDefinition: return definitions.ContextDefinition( self.name, self.aliases, self.defaults, self.relations, self.redefinitions ) @property - def name(self): + def name(self) -> str: + assert isinstance(self.opening, BeginContext) return self.opening.name @property - def aliases(self): + def aliases(self) -> tuple[str]: + assert isinstance(self.opening, BeginContext) return self.opening.aliases @property - def defaults(self): + def defaults(self) -> dict[str, numbers.Number]: + assert isinstance(self.opening, BeginContext) return self.opening.defaults @property - def relations(self): - return tuple(r for r in self.body if isinstance(r, Relation)) + def relations(self) -> tuple[BidirectionalRelation | ForwardRelation]: + return tuple( + r + for r in self.body + if isinstance(r, (ForwardRelation, BidirectionalRelation)) + ) @property - def redefinitions(self): + def redefinitions(self) -> tuple[plain.UnitDefinition]: return tuple(r for r in self.body if isinstance(r, plain.UnitDefinition)) diff --git a/pint/delegates/txt_defparser/defaults.py b/pint/delegates/txt_defparser/defaults.py index af6e31f..688d90f 100644 --- a/pint/delegates/txt_defparser/defaults.py +++ b/pint/delegates/txt_defparser/defaults.py @@ -19,10 +19,11 @@ from dataclasses import dataclass, fields from ..._vendor import flexparser as fp from ...facets.plain import definitions from . import block, plain +from ..base_defparser import PintParsedStatement @dataclass(frozen=True) -class BeginDefaults(fp.ParsedStatement): +class BeginDefaults(PintParsedStatement): """Being of a defaults directive. @defaults @@ -36,7 +37,16 @@ class BeginDefaults(fp.ParsedStatement): @dataclass(frozen=True) -class DefaultsDefinition(block.DirectiveBlock): +class DefaultsDefinition( + block.DirectiveBlock[ + definitions.DefaultsDefinition, + BeginDefaults, + ty.Union[ + plain.CommentDefinition, + plain.Equality, + ], + ] +): """Directive to store values. @defaults @@ -55,10 +65,10 @@ class DefaultsDefinition(block.DirectiveBlock): ] @property - def _valid_fields(self): + def _valid_fields(self) -> tuple[str]: return tuple(f.name for f in fields(definitions.DefaultsDefinition)) - def derive_definition(self): + def derive_definition(self) -> definitions.DefaultsDefinition: for definition in self.filter_by(plain.Equality): if definition.lhs not in self._valid_fields: raise ValueError( @@ -70,7 +80,7 @@ class DefaultsDefinition(block.DirectiveBlock): *tuple(self.get_key(key) for key in self._valid_fields) ) - def get_key(self, key): + def get_key(self, key: str) -> str: for stmt in self.body: if isinstance(stmt, plain.Equality) and stmt.lhs == key: return stmt.rhs diff --git a/pint/delegates/txt_defparser/defparser.py b/pint/delegates/txt_defparser/defparser.py index 0b99d6d..f1b8e45 100644 --- a/pint/delegates/txt_defparser/defparser.py +++ b/pint/delegates/txt_defparser/defparser.py @@ -5,11 +5,28 @@ import typing as ty from ..._vendor import flexcache as fc from ..._vendor import flexparser as fp -from .. import base_defparser +from ..base_defparser import ParserConfig from . import block, common, context, defaults, group, plain, system -class PintRootBlock(fp.RootBlock): +class PintRootBlock( + fp.RootBlock[ + ty.Union[ + plain.CommentDefinition, + common.ImportDefinition, + context.ContextDefinition, + defaults.DefaultsDefinition, + system.SystemDefinition, + group.GroupDefinition, + plain.AliasDefinition, + plain.DerivedDimensionDefinition, + plain.DimensionDefinition, + plain.PrefixDefinition, + plain.UnitDefinition, + ], + ParserConfig, + ] +): body: fp.Multi[ ty.Union[ plain.CommentDefinition, @@ -27,11 +44,15 @@ class PintRootBlock(fp.RootBlock): ] +class PintSource(fp.ParsedSource[PintRootBlock, ParserConfig]): + """Source code in Pint.""" + + class HashTuple(tuple): pass -class _PintParser(fp.Parser): +class _PintParser(fp.Parser[PintRootBlock, ParserConfig]): """Parser for the original Pint definition file, with cache.""" _delimiters = { @@ -46,11 +67,11 @@ class _PintParser(fp.Parser): _diskcache: fc.DiskCache - def __init__(self, config: base_defparser.ParserConfig, *args, **kwargs): + def __init__(self, config: ParserConfig, *args, **kwargs): self._diskcache = kwargs.pop("diskcache", None) super().__init__(config, *args, **kwargs) - def parse_file(self, path: pathlib.Path) -> fp.ParsedSource: + def parse_file(self, path: pathlib.Path) -> PintSource: if self._diskcache is None: return super().parse_file(path) content, basename = self._diskcache.load(path, super().parse_file) @@ -58,7 +79,13 @@ class _PintParser(fp.Parser): class DefParser: - skip_classes = (fp.BOF, fp.BOR, fp.BOS, fp.EOS, plain.CommentDefinition) + skip_classes: tuple[type] = ( + fp.BOF, + fp.BOR, + fp.BOS, + fp.EOS, + plain.CommentDefinition, + ) def __init__(self, default_config, diskcache): self._default_config = default_config @@ -78,6 +105,8 @@ class DefParser: continue if isinstance(stmt, common.DefinitionSyntaxError): + # TODO: check why this assert fails + # assert isinstance(last_location, str) stmt.set_location(last_location) raise stmt elif isinstance(stmt, block.DirectiveBlock): @@ -101,7 +130,7 @@ class DefParser: else: yield stmt - def parse_file(self, filename: pathlib.Path, cfg=None): + def parse_file(self, filename: pathlib.Path, cfg: ParserConfig | None = None): return fp.parse( filename, _PintParser, @@ -109,7 +138,7 @@ class DefParser: diskcache=self._diskcache, ) - def parse_string(self, content: str, cfg=None): + def parse_string(self, content: str, cfg: ParserConfig | None = None): return fp.parse_bytes( content.encode("utf-8"), _PintParser, diff --git a/pint/delegates/txt_defparser/group.py b/pint/delegates/txt_defparser/group.py index 5be42ac..e96d44b 100644 --- a/pint/delegates/txt_defparser/group.py +++ b/pint/delegates/txt_defparser/group.py @@ -23,10 +23,11 @@ from dataclasses import dataclass from ..._vendor import flexparser as fp from ...facets.group import definitions from . import block, common, plain +from ..base_defparser import PintParsedStatement @dataclass(frozen=True) -class BeginGroup(fp.ParsedStatement): +class BeginGroup(PintParsedStatement): """Being of a group directive. @group <name> [using <group 1>, ..., <group N>] @@ -59,7 +60,16 @@ class BeginGroup(fp.ParsedStatement): @dataclass(frozen=True) -class GroupDefinition(block.DirectiveBlock): +class GroupDefinition( + block.DirectiveBlock[ + definitions.GroupDefinition, + BeginGroup, + ty.Union[ + plain.CommentDefinition, + plain.UnitDefinition, + ], + ] +): """Definition of a group. @group <name> [using <group 1>, ..., <group N>] @@ -88,19 +98,21 @@ class GroupDefinition(block.DirectiveBlock): ] ] - def derive_definition(self): + def derive_definition(self) -> definitions.GroupDefinition: return definitions.GroupDefinition( self.name, self.using_group_names, self.definitions ) @property - def name(self): + def name(self) -> str: + assert isinstance(self.opening, BeginGroup) return self.opening.name @property - def using_group_names(self): + def using_group_names(self) -> tuple[str]: + assert isinstance(self.opening, BeginGroup) return self.opening.using_group_names @property - def definitions(self) -> ty.Tuple[plain.UnitDefinition, ...]: + def definitions(self) -> tuple[plain.UnitDefinition]: return tuple(el for el in self.body if isinstance(el, plain.UnitDefinition)) diff --git a/pint/delegates/txt_defparser/plain.py b/pint/delegates/txt_defparser/plain.py index 749e7fd..9c7bd42 100644 --- a/pint/delegates/txt_defparser/plain.py +++ b/pint/delegates/txt_defparser/plain.py @@ -29,12 +29,12 @@ from ..._vendor import flexparser as fp from ...converters import Converter from ...facets.plain import definitions from ...util import UnitsContainer -from ..base_defparser import ParserConfig +from ..base_defparser import ParserConfig, PintParsedStatement from . import common @dataclass(frozen=True) -class Equality(fp.ParsedStatement, definitions.Equality): +class Equality(PintParsedStatement, definitions.Equality): """An equality statement contains a left and right hand separated lhs and rhs should be space stripped. @@ -53,7 +53,7 @@ class Equality(fp.ParsedStatement, definitions.Equality): @dataclass(frozen=True) -class CommentDefinition(fp.ParsedStatement, definitions.CommentDefinition): +class CommentDefinition(PintParsedStatement, definitions.CommentDefinition): """Comments start with a # character. # This is a comment. @@ -63,14 +63,14 @@ class CommentDefinition(fp.ParsedStatement, definitions.CommentDefinition): """ @classmethod - def from_string(cls, s: str) -> fp.FromString[fp.ParsedStatement]: + def from_string(cls, s: str) -> fp.FromString[CommentDefinition]: if not s.startswith("#"): return None return cls(s[1:].strip()) @dataclass(frozen=True) -class PrefixDefinition(fp.ParsedStatement, definitions.PrefixDefinition): +class PrefixDefinition(PintParsedStatement, definitions.PrefixDefinition): """Definition of a prefix:: <prefix>- = <value> [= <symbol>] [= <alias>] [ = <alias> ] [...] @@ -119,7 +119,7 @@ class PrefixDefinition(fp.ParsedStatement, definitions.PrefixDefinition): @dataclass(frozen=True) -class UnitDefinition(fp.ParsedStatement, definitions.UnitDefinition): +class UnitDefinition(PintParsedStatement, definitions.UnitDefinition): """Definition of a unit:: <canonical name> = <relation to another unit or dimension> [= <symbol>] [= <alias>] [ = <alias> ] [...] @@ -194,7 +194,7 @@ class UnitDefinition(fp.ParsedStatement, definitions.UnitDefinition): @dataclass(frozen=True) -class DimensionDefinition(fp.ParsedStatement, definitions.DimensionDefinition): +class DimensionDefinition(PintParsedStatement, definitions.DimensionDefinition): """Definition of a root dimension:: [dimension name] @@ -221,7 +221,7 @@ class DimensionDefinition(fp.ParsedStatement, definitions.DimensionDefinition): @dataclass(frozen=True) class DerivedDimensionDefinition( - fp.ParsedStatement, definitions.DerivedDimensionDefinition + PintParsedStatement, definitions.DerivedDimensionDefinition ): """Definition of a derived dimension:: @@ -261,7 +261,7 @@ class DerivedDimensionDefinition( @dataclass(frozen=True) -class AliasDefinition(fp.ParsedStatement, definitions.AliasDefinition): +class AliasDefinition(PintParsedStatement, definitions.AliasDefinition): """Additional alias(es) for an already existing unit:: @alias <canonical name or previous alias> = <alias> [ = <alias> ] [...] diff --git a/pint/delegates/txt_defparser/system.py b/pint/delegates/txt_defparser/system.py index b21fd7a..4efbb4d 100644 --- a/pint/delegates/txt_defparser/system.py +++ b/pint/delegates/txt_defparser/system.py @@ -14,11 +14,12 @@ from dataclasses import dataclass from ..._vendor import flexparser as fp from ...facets.system import definitions +from ..base_defparser import PintParsedStatement from . import block, common, plain @dataclass(frozen=True) -class BaseUnitRule(fp.ParsedStatement, definitions.BaseUnitRule): +class BaseUnitRule(PintParsedStatement, definitions.BaseUnitRule): @classmethod def from_string(cls, s: str) -> fp.FromString[BaseUnitRule]: if ":" not in s: @@ -32,7 +33,7 @@ class BaseUnitRule(fp.ParsedStatement, definitions.BaseUnitRule): @dataclass(frozen=True) -class BeginSystem(fp.ParsedStatement): +class BeginSystem(PintParsedStatement): """Being of a system directive. @system <name> [using <group 1>, ..., <group N>] @@ -67,7 +68,13 @@ class BeginSystem(fp.ParsedStatement): @dataclass(frozen=True) -class SystemDefinition(block.DirectiveBlock): +class SystemDefinition( + block.DirectiveBlock[ + definitions.SystemDefinition, + BeginSystem, + ty.Union[plain.CommentDefinition, BaseUnitRule], + ] +): """Definition of a System: @system <name> [using <group 1>, ..., <group N>] @@ -92,19 +99,21 @@ class SystemDefinition(block.DirectiveBlock): opening: fp.Single[BeginSystem] body: fp.Multi[ty.Union[plain.CommentDefinition, BaseUnitRule]] - def derive_definition(self): + def derive_definition(self) -> definitions.SystemDefinition: return definitions.SystemDefinition( self.name, self.using_group_names, self.rules ) @property - def name(self): + def name(self) -> str: + assert isinstance(self.opening, BeginSystem) return self.opening.name @property - def using_group_names(self): + def using_group_names(self) -> tuple[str]: + assert isinstance(self.opening, BeginSystem) return self.opening.using_group_names @property - def rules(self): + def rules(self) -> tuple[BaseUnitRule]: return tuple(el for el in self.body if isinstance(el, BaseUnitRule)) diff --git a/pint/errors.py b/pint/errors.py index 8f849da..6cebb21 100644 --- a/pint/errors.py +++ b/pint/errors.py @@ -36,18 +36,21 @@ MSG_INVALID_SYSTEM_NAME = ( ) -def is_dim(name): +def is_dim(name: str) -> bool: + """Return True if the name is flanked by square brackets `[` and `]`.""" return name[0] == "[" and name[-1] == "]" -def is_valid_prefix_name(name): +def is_valid_prefix_name(name: str) -> bool: + """Return True if the name is a valid python identifier or empty.""" return str.isidentifier(name) or name == "" is_valid_unit_name = is_valid_system_name = is_valid_context_name = str.isidentifier -def _no_space(name): +def _no_space(name: str) -> bool: + """Return False if the name contains a space in any position.""" return name.strip() == name and " " not in name @@ -58,7 +61,14 @@ is_valid_unit_alias = ( ) = is_valid_unit_symbol = is_valid_prefix_symbol = _no_space -def is_valid_dimension_name(name): +def is_valid_dimension_name(name: str) -> bool: + """Return True if the name is consistent with a dimension name. + + - flanked by square brackets. + - empty dimension name or identifier. + """ + + # TODO: shall we check also fro spaces? return name == "[]" or ( len(name) > 1 and is_dim(name) and str.isidentifier(name[1:-1]) ) @@ -67,8 +77,8 @@ def is_valid_dimension_name(name): class WithDefErr: """Mixing class to make some classes more readable.""" - def def_err(self, msg): - return DefinitionError(self.name, self.__class__.__name__, msg) + def def_err(self, msg: str): + return DefinitionError(self.name, self.__class__, msg) @dataclass(frozen=False) @@ -81,7 +91,7 @@ class DefinitionError(ValueError, PintError): """Raised when a definition is not properly constructed.""" name: str - definition_type: ty.Type + definition_type: type msg: str def __str__(self): @@ -110,7 +120,7 @@ class RedefinitionError(ValueError, PintError): """Raised when a unit or prefix is redefined.""" name: str - definition_type: ty.Type + definition_type: type def __str__(self): msg = f"Cannot redefine '{self.name}' ({self.definition_type})" @@ -124,7 +134,7 @@ class RedefinitionError(ValueError, PintError): class UndefinedUnitError(AttributeError, PintError): """Raised when the units are not defined in the unit registry.""" - unit_names: ty.Union[str, ty.Tuple[str, ...]] + unit_names: str | tuple[str] def __str__(self): if isinstance(self.unit_names, str): diff --git a/pint/facets/__init__.py b/pint/facets/__init__.py index 7b24463..750f729 100644 --- a/pint/facets/__init__.py +++ b/pint/facets/__init__.py @@ -82,13 +82,13 @@ from .plain import PlainRegistry from .system import SystemRegistry __all__ = [ - ContextRegistry, - DaskRegistry, - FormattingRegistry, - GroupRegistry, - MeasurementRegistry, - NonMultiplicativeRegistry, - NumpyRegistry, - PlainRegistry, - SystemRegistry, + "ContextRegistry", + "DaskRegistry", + "FormattingRegistry", + "GroupRegistry", + "MeasurementRegistry", + "NonMultiplicativeRegistry", + "NumpyRegistry", + "PlainRegistry", + "SystemRegistry", ] diff --git a/pint/facets/context/definitions.py b/pint/facets/context/definitions.py index d9ba473..07eb92f 100644 --- a/pint/facets/context/definitions.py +++ b/pint/facets/context/definitions.py @@ -12,7 +12,7 @@ import itertools import numbers import re from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Callable +from typing import TYPE_CHECKING, Any, Callable, Iterable from ... import errors from ..plain import UnitDefinition @@ -41,7 +41,7 @@ class Relation: # could be used. @property - def variables(self) -> set[str, ...]: + def variables(self) -> set[str]: """Find all variables names in the equation.""" return set(self._varname_re.findall(self.equation)) @@ -55,7 +55,7 @@ class Relation: ) @property - def bidirectional(self): + def bidirectional(self) -> bool: raise NotImplementedError @@ -92,18 +92,18 @@ class ContextDefinition(errors.WithDefErr): #: name of the context name: str #: other na - aliases: tuple[str, ...] + aliases: tuple[str] defaults: dict[str, numbers.Number] - relations: tuple[Relation, ...] - redefinitions: tuple[UnitDefinition, ...] + relations: tuple[Relation] + redefinitions: tuple[UnitDefinition] @property - def variables(self) -> set[str, ...]: + def variables(self) -> set[str]: """Return all variable names in all transformations.""" return set().union(*(r.variables for r in self.relations)) @classmethod - def from_lines(cls, lines, non_int_type): + def from_lines(cls, lines: Iterable[str], non_int_type: type): # TODO: this is to keep it backwards compatible from ...delegates import ParserConfig, txt_defparser diff --git a/pint/facets/context/objects.py b/pint/facets/context/objects.py index 58f8bb8..bec2a43 100644 --- a/pint/facets/context/objects.py +++ b/pint/facets/context/objects.py @@ -10,6 +10,7 @@ from __future__ import annotations import weakref from collections import ChainMap, defaultdict +from typing import Any, Iterable from ...facets.plain import UnitDefinition from ...util import UnitsContainer, to_units_container @@ -70,8 +71,8 @@ class Context: def __init__( self, name: str | None = None, - aliases: tuple[str, ...] = (), - defaults: dict | None = None, + aliases: tuple[str] = tuple(), + defaults: dict[str, Any] | None = None, ) -> None: self.name = name self.aliases = aliases @@ -93,7 +94,7 @@ class Context: self.relation_to_context = weakref.WeakValueDictionary() @classmethod - def from_context(cls, context: Context, **defaults) -> Context: + def from_context(cls, context: Context, **defaults: Any) -> Context: """Creates a new context that shares the funcs dictionary with the original context. The default values are copied from the original context and updated with the new defaults. @@ -122,7 +123,9 @@ class Context: return context @classmethod - def from_lines(cls, lines, to_base_func=None, non_int_type=float) -> Context: + def from_lines( + cls, lines: Iterable[str], to_base_func=None, non_int_type: type = float + ) -> Context: cd = ContextDefinition.from_lines(lines, non_int_type) return cls.from_definition(cd, to_base_func) @@ -273,7 +276,7 @@ class ContextChain(ChainMap): """ return self[(src, dst)].transform(src, dst, registry, value) - def hashable(self): + def hashable(self) -> tuple[Any]: """Generate a unique hashable and comparable representation of self, which can be used as a key in a dict. This class cannot define ``__hash__`` because it is mutable, and the Python interpreter does cache the output of ``__hash__``. diff --git a/pint/facets/group/definitions.py b/pint/facets/group/definitions.py index c0abced..48c6f4b 100644 --- a/pint/facets/group/definitions.py +++ b/pint/facets/group/definitions.py @@ -8,9 +8,10 @@ from __future__ import annotations -import typing as ty +from typing import Iterable from dataclasses import dataclass +from ..._typing import Self from ... import errors from .. import plain @@ -22,12 +23,14 @@ class GroupDefinition(errors.WithDefErr): #: name of the group name: str #: unit groups that will be included within the group - using_group_names: ty.Tuple[str, ...] + using_group_names: tuple[str] #: definitions for the units existing within the group - definitions: ty.Tuple[plain.UnitDefinition, ...] + definitions: tuple[plain.UnitDefinition] @classmethod - def from_lines(cls, lines, non_int_type): + def from_lines( + cls: type[Self], lines: Iterable[str], non_int_type: type + ) -> Self | None: # TODO: this is to keep it backwards compatible from ...delegates import ParserConfig, txt_defparser @@ -39,10 +42,10 @@ class GroupDefinition(errors.WithDefErr): return definition @property - def unit_names(self) -> ty.Tuple[str, ...]: + def unit_names(self) -> tuple[str]: return tuple(el.name for el in self.definitions) - def __post_init__(self): + def __post_init__(self) -> None: if not errors.is_valid_group_name(self.name): raise self.def_err(errors.MSG_INVALID_GROUP_NAME) diff --git a/pint/facets/group/objects.py b/pint/facets/group/objects.py index 558a107..a0a81be 100644 --- a/pint/facets/group/objects.py +++ b/pint/facets/group/objects.py @@ -8,6 +8,7 @@ from __future__ import annotations +from typing import Generator, Iterable from ...util import SharedRegistryObject, getattr_maybe_raise from .definitions import GroupDefinition @@ -23,32 +24,26 @@ class Group(SharedRegistryObject): The group belongs to one Registry. See GroupDefinition for the definition file syntax. - """ - def __init__(self, name): - """ - :param name: Name of the group. If not given, a root Group will be created. - :type name: str - :param groups: dictionary like object groups and system. - The newly created group will be added after creation. - :type groups: dict[str | Group] - """ + Parameters + ---------- + name + If not given, a root Group will be created. + """ + def __init__(self, name: str): # The name of the group. - #: type: str self.name = name #: Names of the units in this group. #: :type: set[str] - self._unit_names = set() + self._unit_names: set[str] = set() #: Names of the groups in this group. - #: :type: set[str] - self._used_groups = set() + self._used_groups: set[str] = set() #: Names of the groups in which this group is contained. - #: :type: set[str] - self._used_by = set() + self._used_by: set[str] = set() # Add this group to the group dictionary self._REGISTRY._groups[self.name] = self @@ -59,8 +54,7 @@ class Group(SharedRegistryObject): #: A cache of the included units. #: None indicates that the cache has been invalidated. - #: :type: frozenset[str] | None - self._computed_members = None + self._computed_members: frozenset[str] | None = None @property def members(self): @@ -70,23 +64,23 @@ class Group(SharedRegistryObject): """ if self._computed_members is None: - self._computed_members = set(self._unit_names) + tmp = set(self._unit_names) for _, group in self.iter_used_groups(): - self._computed_members |= group.members + tmp |= group.members - self._computed_members = frozenset(self._computed_members) + self._computed_members = frozenset(tmp) return self._computed_members - def invalidate_members(self): + def invalidate_members(self) -> None: """Invalidate computed members in this Group and all parent nodes.""" self._computed_members = None d = self._REGISTRY._groups for name in self._used_by: d[name].invalidate_members() - def iter_used_groups(self): + def iter_used_groups(self) -> Generator[tuple[str, Group], None, None]: pending = set(self._used_groups) d = self._REGISTRY._groups while pending: @@ -95,13 +89,13 @@ class Group(SharedRegistryObject): pending |= group._used_groups yield name, d[name] - def is_used_group(self, group_name): + def is_used_group(self, group_name: str) -> bool: for name, _ in self.iter_used_groups(): if name == group_name: return True return False - def add_units(self, *unit_names): + def add_units(self, *unit_names: str) -> None: """Add units to group.""" for unit_name in unit_names: self._unit_names.add(unit_name) @@ -109,17 +103,17 @@ class Group(SharedRegistryObject): self.invalidate_members() @property - def non_inherited_unit_names(self): + def non_inherited_unit_names(self) -> frozenset[str]: return frozenset(self._unit_names) - def remove_units(self, *unit_names): + def remove_units(self, *unit_names: str) -> None: """Remove units from group.""" for unit_name in unit_names: self._unit_names.remove(unit_name) self.invalidate_members() - def add_groups(self, *group_names): + def add_groups(self, *group_names: str) -> None: """Add groups to group.""" d = self._REGISTRY._groups for group_name in group_names: @@ -136,7 +130,7 @@ class Group(SharedRegistryObject): self.invalidate_members() - def remove_groups(self, *group_names): + def remove_groups(self, *group_names: str) -> None: """Remove groups from group.""" d = self._REGISTRY._groups for group_name in group_names: @@ -148,7 +142,9 @@ class Group(SharedRegistryObject): self.invalidate_members() @classmethod - def from_lines(cls, lines, define_func, non_int_type=float) -> Group: + def from_lines( + cls, lines: Iterable[str], define_func, non_int_type: type = float + ) -> Group: """Return a Group object parsing an iterable of lines. Parameters @@ -190,6 +186,6 @@ class Group(SharedRegistryObject): return grp - def __getattr__(self, item): + def __getattr__(self, item: str): getattr_maybe_raise(self, item) return self._REGISTRY diff --git a/pint/facets/nonmultiplicative/definitions.py b/pint/facets/nonmultiplicative/definitions.py index dbfc0ff..f795cf0 100644 --- a/pint/facets/nonmultiplicative/definitions.py +++ b/pint/facets/nonmultiplicative/definitions.py @@ -10,6 +10,7 @@ from __future__ import annotations from dataclasses import dataclass +from ..._typing import Magnitude from ...compat import HAS_NUMPY, exp, log from ..plain import ScaleConverter @@ -24,7 +25,7 @@ class OffsetConverter(ScaleConverter): def is_multiplicative(self): return self.offset == 0 - def to_reference(self, value, inplace=False): + def to_reference(self, value: Magnitude, inplace: bool = False) -> Magnitude: if inplace: value *= self.scale value += self.offset @@ -33,7 +34,7 @@ class OffsetConverter(ScaleConverter): return value - def from_reference(self, value, inplace=False): + def from_reference(self, value: Magnitude, inplace: bool = False) -> Magnitude: if inplace: value -= self.offset value /= self.scale @@ -66,6 +67,7 @@ class LogarithmicConverter(ScaleConverter): controls if computation is done in place """ + # TODO: Can I use PintScalar here? logbase: float logfactor: float @@ -77,7 +79,7 @@ class LogarithmicConverter(ScaleConverter): def is_logarithmic(self): return True - def from_reference(self, value, inplace=False): + def from_reference(self, value: Magnitude, inplace: bool = False) -> Magnitude: """Converts value from the reference unit to the logarithmic unit dBm <------ mW @@ -95,7 +97,7 @@ class LogarithmicConverter(ScaleConverter): return value - def to_reference(self, value, inplace=False): + def to_reference(self, value: Magnitude, inplace: bool = False) -> Magnitude: """Converts value to the reference unit from the logarithmic unit dBm ------> mW diff --git a/pint/facets/nonmultiplicative/objects.py b/pint/facets/nonmultiplicative/objects.py index 7f9064d..0ab743e 100644 --- a/pint/facets/nonmultiplicative/objects.py +++ b/pint/facets/nonmultiplicative/objects.py @@ -40,7 +40,7 @@ class NonMultiplicativeQuantity(PlainQuantity): self._get_unit_definition(d).reference == offset_unit_dim for d in deltas ) - def _ok_for_muldiv(self, no_offset_units=None) -> bool: + def _ok_for_muldiv(self, no_offset_units: int | None = None) -> bool: """Checks if PlainQuantity object can be multiplied or divided""" is_ok = True diff --git a/pint/facets/plain/definitions.py b/pint/facets/plain/definitions.py index eb45db1..79a44f1 100644 --- a/pint/facets/plain/definitions.py +++ b/pint/facets/plain/definitions.py @@ -13,8 +13,9 @@ import numbers import typing as ty from dataclasses import dataclass from functools import cached_property -from typing import Callable +from typing import Callable, Any +from ..._typing import Magnitude from ... import errors from ...converters import Converter from ...util import UnitsContainer @@ -23,7 +24,7 @@ from ...util import UnitsContainer class NotNumeric(Exception): """Internal exception. Do not expose outside Pint""" - def __init__(self, value): + def __init__(self, value: Any): self.value = value @@ -115,18 +116,26 @@ class UnitDefinition(errors.WithDefErr): #: canonical name of the unit name: str #: canonical symbol - defined_symbol: ty.Optional[str] + defined_symbol: str | None #: additional names for the same unit - aliases: ty.Tuple[str, ...] + aliases: tuple[str] #: A functiont that converts a value in these units into the reference units - converter: ty.Optional[ty.Union[Callable, Converter]] + converter: Callable[ + [ + Magnitude, + ], + Magnitude, + ] | Converter | None #: Reference units. - reference: ty.Optional[UnitsContainer] + reference: UnitsContainer | None def __post_init__(self): if not errors.is_valid_unit_name(self.name): raise self.def_err(errors.MSG_INVALID_UNIT_NAME) + # TODO: check why reference: UnitsContainer | None + assert isinstance(self.reference, UnitsContainer) + if not any(map(errors.is_dim, self.reference.keys())): invalid = tuple( itertools.filterfalse(errors.is_valid_unit_name, self.reference.keys()) @@ -180,14 +189,20 @@ class UnitDefinition(errors.WithDefErr): @property def is_base(self) -> bool: """Indicates if it is a base unit.""" + + # TODO: why is this here return self._is_base @property def is_multiplicative(self) -> bool: + # TODO: Check how to avoid this check + assert isinstance(self.converter, Converter) return self.converter.is_multiplicative @property def is_logarithmic(self) -> bool: + # TODO: Check how to avoid this check + assert isinstance(self.converter, Converter) return self.converter.is_logarithmic @property @@ -272,7 +287,7 @@ class ScaleConverter(Converter): scale: float - def to_reference(self, value, inplace=False): + def to_reference(self, value: Magnitude, inplace: bool = False) -> Magnitude: if inplace: value *= self.scale else: @@ -280,7 +295,7 @@ class ScaleConverter(Converter): return value - def from_reference(self, value, inplace=False): + def from_reference(self, value: Magnitude, inplace: bool = False) -> Magnitude: if inplace: value /= self.scale else: diff --git a/pint/facets/plain/objects.py b/pint/facets/plain/objects.py index 5b2837b..a868c7f 100644 --- a/pint/facets/plain/objects.py +++ b/pint/facets/plain/objects.py @@ -11,4 +11,4 @@ from __future__ import annotations from .quantity import PlainQuantity from .unit import PlainUnit, UnitsContainer -__all__ = [PlainUnit, PlainQuantity, UnitsContainer] +__all__ = ["PlainUnit", "PlainQuantity", "UnitsContainer"] diff --git a/pint/facets/system/definitions.py b/pint/facets/system/definitions.py index 8243324..893c510 100644 --- a/pint/facets/system/definitions.py +++ b/pint/facets/system/definitions.py @@ -8,9 +8,10 @@ from __future__ import annotations -import typing as ty +from typing import Iterable from dataclasses import dataclass +from ..._typing import Self from ... import errors @@ -23,7 +24,7 @@ class BaseUnitRule: new_unit_name: str #: name of the unit to be kicked out to make room for the new base uni #: If None, the current base unit with the same dimensionality will be used - old_unit_name: ty.Optional[str] = None + old_unit_name: str | None = None # Instead of defining __post_init__ here, # it will be added to the container class @@ -38,13 +39,16 @@ class SystemDefinition(errors.WithDefErr): #: name of the system name: str #: unit groups that will be included within the system - using_group_names: ty.Tuple[str, ...] + using_group_names: tuple[str] #: rules to define new base unit within the system. - rules: ty.Tuple[BaseUnitRule, ...] + rules: tuple[BaseUnitRule] @classmethod - def from_lines(cls, lines, non_int_type): + def from_lines( + cls: type[Self], lines: Iterable[str], non_int_type: type + ) -> Self | None: # TODO: this is to keep it backwards compatible + # TODO: check when is None returned. from ...delegates import ParserConfig, txt_defparser cfg = ParserConfig(non_int_type) @@ -55,7 +59,8 @@ class SystemDefinition(errors.WithDefErr): return definition @property - def unit_replacements(self) -> ty.Tuple[ty.Tuple[str, str], ...]: + def unit_replacements(self) -> tuple[tuple[str, str | None]]: + # TODO: check if None can be dropped. return tuple((el.new_unit_name, el.old_unit_name) for el in self.rules) def __post_init__(self): diff --git a/pint/facets/system/objects.py b/pint/facets/system/objects.py index 829fb5c..7af65a6 100644 --- a/pint/facets/system/objects.py +++ b/pint/facets/system/objects.py @@ -9,6 +9,12 @@ from __future__ import annotations +import numbers + +from typing import Any, Iterable + +from ..._typing import Self + from ...babel_names import _babel_systems from ...compat import babel_parse from ...util import ( @@ -29,32 +35,28 @@ class System(SharedRegistryObject): The System belongs to one Registry. See SystemDefinition for the definition file syntax. - """ - def __init__(self, name): - """ - :param name: Name of the group - :type name: str - """ + Parameters + ---------- + name + Name of the group. + """ + def __init__(self, name: str): #: Name of the system #: :type: str self.name = name #: Maps root unit names to a dict indicating the new unit and its exponent. - #: :type: dict[str, dict[str, number]]] - self.base_units = {} + self.base_units: dict[str, dict[str, numbers.Number]] = {} #: Derived unit names. - #: :type: set(str) - self.derived_units = set() + self.derived_units: set[str] = set() #: Names of the _used_groups in used by this system. - #: :type: set(str) - self._used_groups = set() + self._used_groups: set[str] = set() - #: :type: frozenset | None - self._computed_members = None + self._computed_members: frozenset[str] | None = None # Add this system to the system dictionary self._REGISTRY._systems[self.name] = self @@ -62,7 +64,7 @@ class System(SharedRegistryObject): def __dir__(self): return list(self.members) - def __getattr__(self, item): + def __getattr__(self, item: str) -> Any: getattr_maybe_raise(self, item) u = getattr(self._REGISTRY, self.name + "_" + item, None) if u is not None: @@ -93,19 +95,19 @@ class System(SharedRegistryObject): """Invalidate computed members in this Group and all parent nodes.""" self._computed_members = None - def add_groups(self, *group_names): + def add_groups(self, *group_names: str) -> None: """Add groups to group.""" self._used_groups |= set(group_names) self.invalidate_members() - def remove_groups(self, *group_names): + def remove_groups(self, *group_names: str) -> None: """Remove groups from group.""" self._used_groups -= set(group_names) self.invalidate_members() - def format_babel(self, locale): + def format_babel(self, locale: str) -> str: """translate the name of the system.""" if locale and self.name in _babel_systems: name = _babel_systems[self.name] @@ -114,8 +116,12 @@ class System(SharedRegistryObject): return self.name @classmethod - def from_lines(cls, lines, get_root_func, non_int_type=float): - system_definition = SystemDefinition.from_lines(lines, get_root_func) + def from_lines( + cls: type[Self], lines: Iterable[str], get_root_func, non_int_type: type = float + ) -> Self: + # TODO: we changed something here it used to be + # system_definition = SystemDefinition.from_lines(lines, get_root_func) + system_definition = SystemDefinition.from_lines(lines, non_int_type) return cls.from_definition(system_definition, get_root_func) @classmethod @@ -174,12 +180,12 @@ class System(SharedRegistryObject): class Lister: - def __init__(self, d): + def __init__(self, d: dict[str, Any]): self.d = d - def __dir__(self): + def __dir__(self) -> list[str]: return list(self.d.keys()) - def __getattr__(self, item): + def __getattr__(self, item: str) -> Any: getattr_maybe_raise(self, item) return self.d[item] diff --git a/pint/formatting.py b/pint/formatting.py index dcc8725..637d838 100644 --- a/pint/formatting.py +++ b/pint/formatting.py @@ -13,7 +13,8 @@ from __future__ import annotations import functools import re import warnings -from typing import Callable +from typing import Callable, Iterable, Any +from numbers import Number from .babel_names import _babel_lengths, _babel_units from .compat import babel_parse @@ -21,7 +22,7 @@ from .compat import babel_parse __JOIN_REG_EXP = re.compile(r"{\d*}") -def _join(fmt, iterable): +def _join(fmt: str, iterable: Iterable[Any]): """Join an iterable with the format specified in fmt. The format can be specified in two ways: @@ -55,7 +56,7 @@ def _join(fmt, iterable): _PRETTY_EXPONENTS = "⁰¹²³⁴⁵⁶⁷⁸⁹" -def _pretty_fmt_exponent(num): +def _pretty_fmt_exponent(num: Number) -> str: """Format an number into a pretty printed exponent. Parameters @@ -76,7 +77,7 @@ def _pretty_fmt_exponent(num): #: _FORMATS maps format specifications to the corresponding argument set to #: formatter(). -_FORMATS: dict[str, dict] = { +_FORMATS: dict[str, dict[str, Any]] = { "P": { # Pretty format. "as_ratio": True, "single_denominator": False, @@ -125,7 +126,7 @@ _FORMATS: dict[str, dict] = { _FORMATTERS: dict[str, Callable] = {} -def register_unit_format(name): +def register_unit_format(name: str): """register a function as a new format for units The registered function must have a signature of: @@ -268,18 +269,18 @@ def format_compact(unit, registry, **options): def formatter( - items, - as_ratio=True, - single_denominator=False, - product_fmt=" * ", - division_fmt=" / ", - power_fmt="{} ** {}", - parentheses_fmt="({0})", + items: list[tuple[str, Number]], + as_ratio: bool = True, + single_denominator: bool = False, + product_fmt: str = " * ", + division_fmt: str = " / ", + power_fmt: str = "{} ** {}", + parentheses_fmt: str = "({0})", exp_call=lambda x: f"{x:n}", - locale=None, - babel_length="long", - babel_plural_form="one", - sort=True, + locale: str | None = None, + babel_length: str = "long", + babel_plural_form: str = "one", + sort: bool = True, ): """Format a list of (name, exponent) pairs. diff --git a/pint/pint_eval.py b/pint/pint_eval.py index e776d60..d476eae 100644 --- a/pint/pint_eval.py +++ b/pint/pint_eval.py @@ -11,7 +11,9 @@ from __future__ import annotations import operator import token as tokenlib -import tokenize +from tokenize import TokenInfo + +from typing import Any from .errors import DefinitionSyntaxError @@ -30,7 +32,7 @@ _OP_PRIORITY = { } -def _power(left, right): +def _power(left: Any, right: Any) -> Any: from . import Quantity from .compat import is_duck_array @@ -45,7 +47,19 @@ def _power(left, right): return operator.pow(left, right) -_BINARY_OPERATOR_MAP = { +import typing + +UnaryOpT = typing.Callable[ + [ + Any, + ], + Any, +] +BinaryOpT = typing.Callable[[Any, Any], Any] + +_UNARY_OPERATOR_MAP: dict[str, UnaryOpT] = {"+": lambda x: x, "-": lambda x: x * -1} + +_BINARY_OPERATOR_MAP: dict[str, BinaryOpT] = { "**": _power, "*": operator.mul, "": operator.mul, # operator for implicit ops @@ -56,8 +70,6 @@ _BINARY_OPERATOR_MAP = { "//": operator.floordiv, } -_UNARY_OPERATOR_MAP = {"+": lambda x: x, "-": lambda x: x * -1} - class EvalTreeNode: """Single node within an evaluation tree @@ -68,25 +80,43 @@ class EvalTreeNode: left --> single value """ - def __init__(self, left, operator=None, right=None): + def __init__( + self, + left: EvalTreeNode | TokenInfo, + operator: TokenInfo | None = None, + right: EvalTreeNode | None = None, + ): self.left = left self.operator = operator self.right = right - def to_string(self): + def to_string(self) -> str: # For debugging purposes if self.right: + assert isinstance(self.left, EvalTreeNode), "self.left not EvalTreeNode (1)" comps = [self.left.to_string()] if self.operator: - comps.append(self.operator[1]) + comps.append(self.operator.string) comps.append(self.right.to_string()) elif self.operator: - comps = [self.operator[1], self.left.to_string()] + assert isinstance(self.left, EvalTreeNode), "self.left not EvalTreeNode (2)" + comps = [self.operator.string, self.left.to_string()] else: - return self.left[1] + assert isinstance(self.left, TokenInfo), "self.left not TokenInfo (1)" + return self.left.string return "(%s)" % " ".join(comps) - def evaluate(self, define_op, bin_op=None, un_op=None): + def evaluate( + self, + define_op: typing.Callable[ + [ + Any, + ], + Any, + ], + bin_op: dict[str, BinaryOpT] | None = None, + un_op: dict[str, UnaryOpT] | None = None, + ): """Evaluate node. Parameters @@ -107,17 +137,22 @@ class EvalTreeNode: un_op = un_op or _UNARY_OPERATOR_MAP if self.right: + assert isinstance(self.left, EvalTreeNode), "self.left not EvalTreeNode (3)" # binary or implicit operator - op_text = self.operator[1] if self.operator else "" + op_text = self.operator.string if self.operator else "" if op_text not in bin_op: - raise DefinitionSyntaxError('missing binary operator "%s"' % op_text) - left = self.left.evaluate(define_op, bin_op, un_op) - return bin_op[op_text](left, self.right.evaluate(define_op, bin_op, un_op)) + raise DefinitionSyntaxError(f"missing binary operator '{op_text}'") + + return bin_op[op_text]( + self.left.evaluate(define_op, bin_op, un_op), + self.right.evaluate(define_op, bin_op, un_op), + ) elif self.operator: + assert isinstance(self.left, EvalTreeNode), "self.left not EvalTreeNode (4)" # unary operator - op_text = self.operator[1] + op_text = self.operator.string if op_text not in un_op: - raise DefinitionSyntaxError('missing unary operator "%s"' % op_text) + raise DefinitionSyntaxError(f"missing unary operator '{op_text}'") return un_op[op_text](self.left.evaluate(define_op, bin_op, un_op)) # single value @@ -127,13 +162,13 @@ class EvalTreeNode: from collections.abc import Iterable -def build_eval_tree( - tokens: Iterable[tokenize.TokenInfo], - op_priority=None, - index=0, - depth=0, - prev_op=None, -) -> tuple[EvalTreeNode | None, int] | EvalTreeNode: +def _build_eval_tree( + tokens: list[TokenInfo], + op_priority: dict[str, int], + index: int = 0, + depth: int = 0, + prev_op: str = "<none>", +) -> tuple[EvalTreeNode, int]: """Build an evaluation tree from a set of tokens. Params: @@ -153,14 +188,12 @@ def build_eval_tree( 5) Combine left side, operator, and right side into a new left side 6) Go back to step #2 - """ - - if op_priority is None: - op_priority = _OP_PRIORITY + Raises + ------ + DefinitionSyntaxError + If there is a syntax error. - if depth == 0 and prev_op is None: - # ensure tokens is list so we can access by index - tokens = list(tokens) + """ result = None @@ -171,19 +204,21 @@ def build_eval_tree( if token_type == tokenlib.OP: if token_text == ")": - if prev_op is None: + if prev_op == "<none>": raise DefinitionSyntaxError( - "unopened parentheses in tokens: %s" % current_token + f"unopened parentheses in tokens: {current_token}" ) elif prev_op == "(": # close parenthetical group + assert result is not None return result, index else: # parenthetical group ending, but we need to close sub-operations within group + assert result is not None return result, index - 1 elif token_text == "(": # gather parenthetical group - right, index = build_eval_tree( + right, index = _build_eval_tree( tokens, op_priority, index + 1, 0, token_text ) if not tokens[index][1] == ")": @@ -208,7 +243,7 @@ def build_eval_tree( # previous operator is higher priority, so end previous binary op return result, index - 1 # get right side of binary op - right, index = build_eval_tree( + right, index = _build_eval_tree( tokens, op_priority, index + 1, depth + 1, token_text ) result = EvalTreeNode( @@ -216,7 +251,7 @@ def build_eval_tree( ) else: # unary operator - right, index = build_eval_tree( + right, index = _build_eval_tree( tokens, op_priority, index + 1, depth + 1, "unary" ) result = EvalTreeNode(left=right, operator=current_token) @@ -227,7 +262,7 @@ def build_eval_tree( # previous operator is higher priority than implicit, so end # previous binary op return result, index - 1 - right, index = build_eval_tree( + right, index = _build_eval_tree( tokens, op_priority, index, depth + 1, "" ) result = EvalTreeNode(left=result, right=right) @@ -240,13 +275,57 @@ def build_eval_tree( raise DefinitionSyntaxError("unclosed parentheses in tokens") if depth > 0 or prev_op: # have to close recursion + assert result is not None return result, index else: # recursion all closed, so just return the final result - return result + assert result is not None + return result, -1 if index + 1 >= len(tokens): # should hit ENDMARKER before this ever happens raise DefinitionSyntaxError("unexpected end to tokens") index += 1 + + +def build_eval_tree( + tokens: Iterable[TokenInfo], + op_priority: dict[str, int] | None = None, +) -> EvalTreeNode: + """Build an evaluation tree from a set of tokens. + + Params: + Index, depth, and prev_op used recursively, so don't touch. + Tokens is an iterable of tokens from an expression to be evaluated. + + Transform the tokens from an expression into a recursive parse tree, following order + of operations. Operations can include binary ops (3 + 4), implicit ops (3 kg), or + unary ops (-1). + + General Strategy: + 1) Get left side of operator + 2) If no tokens left, return final result + 3) Get operator + 4) Use recursion to create tree starting at token on right side of operator (start at step #1) + 4.1) If recursive call encounters an operator with lower or equal priority to step #2, exit recursion + 5) Combine left side, operator, and right side into a new left side + 6) Go back to step #2 + + Raises + ------ + DefinitionSyntaxError + If there is a syntax error. + + """ + + if op_priority is None: + op_priority = _OP_PRIORITY + + if not isinstance(tokens, list): + # ensure tokens is list so we can access by index + tokens = list(tokens) + + result, _ = _build_eval_tree(tokens, op_priority, 0, 0) + + return result diff --git a/pint/testsuite/test_compat_downcast.py b/pint/testsuite/test_compat_downcast.py index 4ca611d..ed43e94 100644 --- a/pint/testsuite/test_compat_downcast.py +++ b/pint/testsuite/test_compat_downcast.py @@ -38,7 +38,7 @@ def q_base(local_registry): # Define identity function for use in tests -def identity(ureg, x): +def id_matrix(ureg, x): return x @@ -63,17 +63,17 @@ def array(request): @pytest.mark.parametrize( "op, magnitude_op, unit_op", [ - pytest.param(identity, identity, identity, id="identity"), + pytest.param(id_matrix, id_matrix, id_matrix, id="identity"), pytest.param( lambda ureg, x: x + 1 * ureg.m, lambda ureg, x: x + 1, - identity, + id_matrix, id="addition", ), pytest.param( lambda ureg, x: x - 20 * ureg.cm, lambda ureg, x: x - 0.2, - identity, + id_matrix, id="subtraction", ), pytest.param( @@ -84,7 +84,7 @@ def array(request): ), pytest.param( lambda ureg, x: x / (1 * ureg.s), - identity, + id_matrix, lambda ureg, u: u / ureg.s, id="division", ), @@ -94,17 +94,17 @@ def array(request): WR(lambda u: u**2), id="square", ), - pytest.param(WR(lambda x: x.T), WR(lambda x: x.T), identity, id="transpose"), - pytest.param(WR(np.mean), WR(np.mean), identity, id="mean ufunc"), - pytest.param(WR(np.sum), WR(np.sum), identity, id="sum ufunc"), + pytest.param(WR(lambda x: x.T), WR(lambda x: x.T), id_matrix, id="transpose"), + pytest.param(WR(np.mean), WR(np.mean), id_matrix, id="mean ufunc"), + pytest.param(WR(np.sum), WR(np.sum), id_matrix, id="sum ufunc"), pytest.param(WR(np.sqrt), WR(np.sqrt), WR(lambda u: u**0.5), id="sqrt ufunc"), pytest.param( WR(lambda x: np.reshape(x, (25,))), WR(lambda x: np.reshape(x, (25,))), - identity, + id_matrix, id="reshape function", ), - pytest.param(WR(np.amax), WR(np.amax), identity, id="amax function"), + pytest.param(WR(np.amax), WR(np.amax), id_matrix, id="amax function"), ], ) def test_univariate_op_consistency( diff --git a/pint/util.py b/pint/util.py index 28710e7..807c3ac 100644 --- a/pint/util.py +++ b/pint/util.py @@ -14,48 +14,82 @@ import logging import math import operator import re -from collections.abc import Mapping +from collections.abc import Mapping, Iterable, Iterator from fractions import Fraction from functools import lru_cache, partial from logging import NullHandler from numbers import Number from token import NAME, NUMBER -from typing import TYPE_CHECKING, ClassVar +import tokenize +from typing import ( + TYPE_CHECKING, + ClassVar, + TypeAlias, + Callable, + TypeVar, + Hashable, + Generator, + Any, +) from .compat import NUMERIC_TYPES, tokenizer from .errors import DefinitionSyntaxError from .formatting import format_unit from .pint_eval import build_eval_tree +from ._typing import PintScalar + if TYPE_CHECKING: - from ._typing import Quantity, UnitLike + from ._typing import Quantity, UnitLike, Self from .registry import UnitRegistry + logger = logging.getLogger(__name__) logger.addHandler(NullHandler()) +T = TypeVar("T") +TH = TypeVar("TH", bound=Hashable) +ItMatrix: TypeAlias = Iterable[Iterable[PintScalar]] +Matrix: TypeAlias = list[list[PintScalar]] + + +def _noop(x: T) -> T: + return x + def matrix_to_string( - matrix, row_headers=None, col_headers=None, fmtfun=lambda x: str(int(x)) -): - """Takes a 2D matrix (as nested list) and returns a string. + matrix: ItMatrix, + row_headers: Iterable[str] | None = None, + col_headers: Iterable[str] | None = None, + fmtfun: Callable[ + [ + PintScalar, + ], + str, + ] = "{:0.0f}".format, +) -> str: + """Return a string representation of a matrix. Parameters ---------- - matrix : - - row_headers : - (Default value = None) - col_headers : - (Default value = None) - fmtfun : - (Default value = lambda x: str(int(x))) + matrix + A matrix given as an iterable of an iterable of numbers. + row_headers + An iterable of strings to serve as row headers. + (default = None, meaning no row headers are printed.) + col_headers + An iterable of strings to serve as column headers. + (default = None, meaning no col headers are printed.) + fmtfun + A callable to convert a number into string. + (default = `"{:0.0f}".format`) Returns ------- - + str + String representation of the matrix. """ - ret = [] + ret: list[str] = [] if col_headers: ret.append(("\t" if row_headers else "") + "\t".join(col_headers)) if row_headers: @@ -69,99 +103,124 @@ def matrix_to_string( return "\n".join(ret) -def transpose(matrix): - """Takes a 2D matrix (as nested list) and returns the transposed version. +def transpose(matrix: ItMatrix) -> Matrix: + """Return the transposed version of a matrix. Parameters ---------- - matrix : - + matrix + A matrix given as an iterable of an iterable of numbers. Returns ------- - + Matrix + The transposed version of the matrix. """ return [list(val) for val in zip(*matrix)] -def column_echelon_form(matrix, ntype=Fraction, transpose_result=False): - """Calculates the column echelon form using Gaussian elimination. +def matrix_apply( + matrix: ItMatrix, + func: Callable[ + [ + PintScalar, + ], + PintScalar, + ], +) -> Matrix: + """Apply a function to individual elements within a matrix. Parameters ---------- - matrix : - a 2D matrix as nested list. - ntype : - the numerical type to use in the calculation. (Default value = Fraction) - transpose_result : - indicates if the returned matrix should be transposed. (Default value = False) + matrix + A matrix given as an iterable of an iterable of numbers. + func + A callable that converts a number to another. Returns ------- - type - column echelon form, transformed identity matrix, swapped rows - + A new matrix in which each element has been replaced by new one. """ - lead = 0 + return [[func(x) for x in row] for row in matrix] + + +def column_echelon_form( + matrix: ItMatrix, ntype: type = Fraction, transpose_result: bool = False +) -> tuple[Matrix, Matrix, list[int]]: + """Calculate the column echelon form using Gaussian elimination. - M = transpose(matrix) + Parameters + ---------- + matrix + A 2D matrix as nested list. + ntype + The numerical type to use in the calculation. + (default = Fraction) + transpose_result + Indicates if the returned matrix should be transposed. + (default = False) - _transpose = transpose if transpose_result else lambda x: x + Returns + ------- + ech_matrix + Column echelon form. + id_matrix + Transformed identity matrix. + swapped + Swapped rows. + """ - rows, cols = len(M), len(M[0]) + _transpose = transpose if transpose_result else _noop - new_M = [] - for row in M: - r = [] - for x in row: - if isinstance(x, float): - x = ntype.from_float(x) - else: - x = ntype(x) - r.append(x) - new_M.append(r) - M = new_M + ech_matrix = matrix_apply( + transpose(matrix), + lambda x: ntype.from_float(x) if isinstance(x, float) else ntype(x), # type: ignore + ) + rows, cols = len(ech_matrix), len(ech_matrix[0]) # M = [[ntype(x) for x in row] for row in M] - I = [ # noqa: E741 + id_matrix: list[list[PintScalar]] = [ # noqa: E741 [ntype(1) if n == nc else ntype(0) for nc in range(rows)] for n in range(rows) ] - swapped = [] + swapped: list[int] = [] + lead = 0 for r in range(rows): if lead >= cols: - return _transpose(M), _transpose(I), swapped - i = r - while M[i][lead] == 0: - i += 1 - if i != rows: + return _transpose(ech_matrix), _transpose(id_matrix), swapped + s = r + while ech_matrix[s][lead] == 0: # type: ignore + s += 1 + if s != rows: continue - i = r + s = r lead += 1 if cols == lead: - return _transpose(M), _transpose(I), swapped + return _transpose(ech_matrix), _transpose(id_matrix), swapped - M[i], M[r] = M[r], M[i] - I[i], I[r] = I[r], I[i] + ech_matrix[s], ech_matrix[r] = ech_matrix[r], ech_matrix[s] + id_matrix[s], id_matrix[r] = id_matrix[r], id_matrix[s] - swapped.append(i) - lv = M[r][lead] - M[r] = [mrx / lv for mrx in M[r]] - I[r] = [mrx / lv for mrx in I[r]] + swapped.append(s) + lv = ech_matrix[r][lead] + ech_matrix[r] = [mrx / lv for mrx in ech_matrix[r]] + id_matrix[r] = [mrx / lv for mrx in id_matrix[r]] - for i in range(rows): - if i == r: + for s in range(rows): + if s == r: continue - lv = M[i][lead] - M[i] = [iv - lv * rv for rv, iv in zip(M[r], M[i])] - I[i] = [iv - lv * rv for rv, iv in zip(I[r], I[i])] + lv = ech_matrix[s][lead] + ech_matrix[s] = [ + iv - lv * rv for rv, iv in zip(ech_matrix[r], ech_matrix[s]) + ] + id_matrix[s] = [iv - lv * rv for rv, iv in zip(id_matrix[r], id_matrix[s])] lead += 1 - return _transpose(M), _transpose(I), swapped + return _transpose(ech_matrix), _transpose(id_matrix), swapped -def pi_theorem(quantities, registry=None): +def pi_theorem(quantities: dict[str, Any], registry: UnitRegistry | None = None): """Builds dimensionless quantities using the Buckingham π theorem Parameters @@ -169,7 +228,7 @@ def pi_theorem(quantities, registry=None): quantities : dict mapping between variable name and units registry : - (Default value = None) + (default value = None) Returns ------- @@ -183,7 +242,7 @@ def pi_theorem(quantities, registry=None): dimensions = set() if registry is None: - getdim = lambda x: x + getdim = _noop non_int_type = float else: getdim = registry.get_dimensionality @@ -211,18 +270,18 @@ def pi_theorem(quantities, registry=None): dimensions = list(dimensions) # Calculate dimensionless quantities - M = [ + matrix = [ [dimensionality[dimension] for name, dimensionality in quant] for dimension in dimensions ] - M, identity, pivot = column_echelon_form(M, transpose_result=False) + ech_matrix, id_matrix, pivot = column_echelon_form(matrix, transpose_result=False) # Collect results # Make all numbers integers and minimize the number of negative exponents. # Remove zeros results = [] - for rowm, rowi in zip(M, identity): + for rowm, rowi in zip(ech_matrix, id_matrix): if any(el != 0 for el in rowm): continue max_den = max(f.denominator for f in rowi) @@ -237,7 +296,9 @@ def pi_theorem(quantities, registry=None): return results -def solve_dependencies(dependencies): +def solve_dependencies( + dependencies: dict[TH, set[TH]] +) -> Generator[set[TH], None, None]: """Solve a dependency graph. Parameters @@ -246,12 +307,16 @@ def solve_dependencies(dependencies): dependency dictionary. For each key, the value is an iterable indicating its dependencies. - Returns - ------- - type + Yields + ------ + set iterator of sets, each containing keys of independents tasks dependent only of the previous tasks in the list. + Raises + ------ + ValueError + if a cyclic dependency is found. """ while dependencies: # values not in keys (items without dep) @@ -270,12 +335,37 @@ def solve_dependencies(dependencies): yield t -def find_shortest_path(graph, start, end, path=None): +def find_shortest_path( + graph: dict[TH, set[TH]], start: TH, end: TH, path: list[TH] | None = None +): + """Find shortest path between two nodes within a graph. + + Parameters + ---------- + graph + A graph given as a mapping of nodes + to a set of all connected nodes to it. + start + Starting node. + end + End node. + path + Path to prepend to the one found. + (default = None, empty path.) + + Returns + ------- + list[TH] + The shortest path between two nodes. + """ path = (path or []) + [start] if start == end: return path + + # TODO: raise ValueError when start not in graph if start not in graph: return None + shortest = None for node in graph[start]: if node not in path: @@ -283,10 +373,33 @@ def find_shortest_path(graph, start, end, path=None): if newpath: if not shortest or len(newpath) < len(shortest): shortest = newpath + return shortest -def find_connected_nodes(graph, start, visited=None): +def find_connected_nodes( + graph: dict[TH, set[TH]], start: TH, visited: set[TH] | None = None +) -> set[TH] | None: + """Find all nodes connected to a start node within a graph. + + Parameters + ---------- + graph + A graph given as a mapping of nodes + to a set of all connected nodes to it. + start + Starting node. + visited + Mutable set to collect visited nodes. + (default = None, empty set) + + Returns + ------- + set[TH] + The shortest path between two nodes. + """ + + # TODO: raise ValueError when start not in graph if start not in graph: return None @@ -300,17 +413,17 @@ def find_connected_nodes(graph, start, visited=None): return visited -class udict(dict): +class udict(dict[str, PintScalar]): """Custom dict implementing __missing__.""" - def __missing__(self, key): + def __missing__(self, key: str): return 0 - def copy(self): + def copy(self: Self) -> Self: return udict(self) -class UnitsContainer(Mapping): +class UnitsContainer(Mapping[str, PintScalar]): """The UnitsContainer stores the product of units and their respective exponent and implements the corresponding operations. @@ -318,23 +431,24 @@ class UnitsContainer(Mapping): Parameters ---------- - - Returns - ------- - type - - + non_int_type + Numerical type used for non integer values. """ __slots__ = ("_d", "_hash", "_one", "_non_int_type") - def __init__(self, *args, **kwargs) -> None: + _d: udict + _hash: int | None + _one: PintScalar + _non_int_type: type + + def __init__(self, *args, non_int_type: type | None = None, **kwargs) -> None: if args and isinstance(args[0], UnitsContainer): default_non_int_type = args[0]._non_int_type else: default_non_int_type = float - self._non_int_type = kwargs.pop("non_int_type", default_non_int_type) + self._non_int_type = non_int_type or default_non_int_type if self._non_int_type is float: self._one = 1 @@ -352,10 +466,26 @@ class UnitsContainer(Mapping): d[key] = self._non_int_type(value) self._hash = None - def copy(self): + def copy(self: Self) -> Self: + """Create a copy of this UnitsContainer.""" return self.__copy__() - def add(self, key, value): + def add(self: Self, key: str, value: Number) -> Self: + """Create a new UnitsContainer adding value to + the value existing for a given key. + + Parameters + ---------- + key + unit to which the value will be added. + value + value to be added. + + Returns + ------- + UnitsContainer + A copy of this container. + """ newval = self._d[key] + value new = self.copy() if newval: @@ -365,17 +495,18 @@ class UnitsContainer(Mapping): new._hash = None return new - def remove(self, keys): - """Create a new UnitsContainer purged from given keys. + def remove(self: Self, keys: Iterable[str]) -> Self: + """Create a new UnitsContainer purged from given entries. Parameters ---------- - keys : - + keys + Iterable of keys (units) to remove. Returns ------- - + UnitsContainer + A copy of this container. """ new = self.copy() for k in keys: @@ -383,51 +514,52 @@ class UnitsContainer(Mapping): new._hash = None return new - def rename(self, oldkey, newkey): + def rename(self: Self, oldkey: str, newkey: str) -> Self: """Create a new UnitsContainer in which an entry has been renamed. Parameters ---------- - oldkey : - - newkey : - + oldkey + Existing key (unit). + newkey + New key (unit). Returns ------- - + UnitsContainer + A copy of this container. """ new = self.copy() new._d[newkey] = new._d.pop(oldkey) new._hash = None return new - def __iter__(self): + def __iter__(self) -> Iterator[str]: return iter(self._d) def __len__(self) -> int: return len(self._d) - def __getitem__(self, key): + def __getitem__(self, key: str) -> PintScalar: return self._d[key] - def __contains__(self, key): + def __contains__(self, key: str) -> bool: return key in self._d - def __hash__(self): + def __hash__(self) -> int: if self._hash is None: self._hash = hash(frozenset(self._d.items())) return self._hash # Only needed by pickle protocol 0 and 1 (used by pytables) - def __getstate__(self): + def __getstate__(self) -> tuple[udict, PintScalar, type]: return self._d, self._one, self._non_int_type - def __setstate__(self, state): + def __setstate__(self, state: tuple[udict, PintScalar, type]): self._d, self._one, self._non_int_type = state self._hash = None - def __eq__(self, other) -> bool: + def __eq__(self, other: Any) -> bool: if isinstance(other, UnitsContainer): # UnitsContainer.__hash__(self) is not the same as hash(self); see # ParserHelper.__hash__ and __eq__. @@ -472,7 +604,7 @@ class UnitsContainer(Mapping): out._one = self._one return out - def __mul__(self, other): + def __mul__(self, other: Any): if not isinstance(other, self.__class__): err = "Cannot multiply UnitsContainer by {}" raise TypeError(err.format(type(other))) @@ -488,7 +620,7 @@ class UnitsContainer(Mapping): __rmul__ = __mul__ - def __pow__(self, other): + def __pow__(self, other: Any): if not isinstance(other, NUMERIC_TYPES): err = "Cannot power UnitsContainer by {}" raise TypeError(err.format(type(other))) @@ -499,7 +631,7 @@ class UnitsContainer(Mapping): new._hash = None return new - def __truediv__(self, other): + def __truediv__(self, other: Any): if not isinstance(other, self.__class__): err = "Cannot divide UnitsContainer by {}" raise TypeError(err.format(type(other))) @@ -513,7 +645,7 @@ class UnitsContainer(Mapping): new._hash = None return new - def __rtruediv__(self, other): + def __rtruediv__(self, other: Any): if not isinstance(other, self.__class__) and other != 1: err = "Cannot divide {} by UnitsContainer" raise TypeError(err.format(type(other))) @@ -524,41 +656,48 @@ class UnitsContainer(Mapping): class ParserHelper(UnitsContainer): """The ParserHelper stores in place the product of variables and their respective exponent and implements the corresponding operations. + It also provides a scaling factor. + + For example: + `3 * m ** 2` becomes ParserHelper(3, m=2) + + Briefly is a UnitsContainer with a scaling factor. ParserHelper is a read-only mapping. All operations (even in place ones) + WARNING : The hash value used does not take into account the scale + attribute so be careful if you use it as a dict key and then two unequal + object can have the same hash. + Parameters ---------- - - Returns - ------- - type - WARNING : The hash value used does not take into account the scale - attribute so be careful if you use it as a dict key and then two unequal - object can have the same hash. - + scale + Scaling factor. + (default = 1) + **kwargs + Used to populate the dict of units and exponents. """ __slots__ = ("scale",) - def __init__(self, scale=1, *args, **kwargs): + scale: PintScalar + + def __init__(self, scale: PintScalar = 1, *args, **kwargs): super().__init__(*args, **kwargs) self.scale = scale @classmethod - def from_word(cls, input_word, non_int_type=float): + def from_word(cls, input_word: str, non_int_type: type = float) -> ParserHelper: """Creates a ParserHelper object with a single variable with exponent one. - Equivalent to: ParserHelper({'word': 1}) + Equivalent to: ParserHelper(1, {input_word: 1}) Parameters ---------- - input_word : - - - Returns - ------- + input_word + non_int_type + Numerical type used for non integer values. """ if non_int_type is float: return cls(1, [(input_word, 1)], non_int_type=non_int_type) @@ -567,7 +706,7 @@ class ParserHelper(UnitsContainer): return cls(ONE, [(input_word, ONE)], non_int_type=non_int_type) @classmethod - def eval_token(cls, token, non_int_type=float): + def eval_token(cls, token: tokenize.TokenInfo, non_int_type: type = float): token_type = token.type token_text = token.string if token_type == NUMBER: @@ -585,17 +724,15 @@ class ParserHelper(UnitsContainer): @classmethod @lru_cache - def from_string(cls, input_string, non_int_type=float): + def from_string(cls, input_string: str, non_int_type: type = float) -> ParserHelper: """Parse linear expression mathematical units and return a quantity object. Parameters ---------- - input_string : - - - Returns - ------- + input_string + non_int_type + Numerical type used for non integer values. """ if not input_string: return cls(non_int_type=non_int_type) @@ -656,7 +793,7 @@ class ParserHelper(UnitsContainer): super().__setstate__(state[:-1]) self.scale = state[-1] - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: if isinstance(other, ParserHelper): return self.scale == other.scale and super().__eq__(other) elif isinstance(other, str): @@ -666,7 +803,7 @@ class ParserHelper(UnitsContainer): return self.scale == 1 and super().__eq__(other) - def operate(self, items, op=operator.iadd, cleanup=True): + def operate(self, items, op=operator.iadd, cleanup: bool = True): d = udict(self._d) for key, value in items: d[key] = op(d[key], value) @@ -811,21 +948,22 @@ class SharedRegistryObject: inst._REGISTRY = application_registry.get() return inst - def _check(self, other) -> bool: + def _check(self, other: Any) -> bool: """Check if the other object use a registry and if so that it is the same registry. Parameters ---------- - other : - + other Returns ------- - type - other don't use a registry and raise ValueError if other don't use the - same unit registry. + bool + Raises + ------ + ValueError + if other don't use the same unit registry. """ if self._REGISTRY is getattr(other, "_REGISTRY", None): return True @@ -844,17 +982,17 @@ class PrettyIPython: default_format: str - def _repr_html_(self): + def _repr_html_(self) -> str: if "~" in self.default_format: return f"{self:~H}" return f"{self:H}" - def _repr_latex_(self): + def _repr_latex_(self) -> str: if "~" in self.default_format: return f"${self:~L}$" return f"${self:L}$" - def _repr_pretty_(self, p, cycle): + def _repr_pretty_(self, p, cycle: bool): if "~" in self.default_format: p.text(f"{self:~P}") else: @@ -868,14 +1006,15 @@ def to_units_container( Parameters ---------- - unit_like : - - registry : - (Default value = None) + unit_like + Quantity or Unit to infer the plain units from. + registry + If provided, uses the registry's UnitsContainer and parse_unit_name. If None, + uses the registry attached to unit_like. Returns ------- - + UnitsContainer """ mro = type(unit_like).mro() if UnitsContainer in mro: @@ -902,10 +1041,9 @@ def infer_base_unit( Parameters ---------- - unit_like : Union[UnitLike, Quantity] + unit_like Quantity or Unit to infer the plain units from. - - registry: Optional[UnitRegistry] + registry If provided, uses the registry's UnitsContainer and parse_unit_name. If None, uses the registry attached to unit_like. @@ -940,7 +1078,7 @@ def infer_base_unit( return registry.UnitsContainer(nonzero_dict) -def getattr_maybe_raise(self, item): +def getattr_maybe_raise(obj: Any, item: str): """Helper function invoked at start of all overridden ``__getattr__``. Raise AttributeError if the user tries to ask for a _ or __ attribute, @@ -949,39 +1087,25 @@ def getattr_maybe_raise(self, item): Parameters ---------- - item : string - Item to be found. - - - Returns - ------- + item + attribute to be found. + Raises + ------ + AttributeError """ # Double-underscore attributes are tricky to detect because they are - # automatically prefixed with the class name - which may be a subclass of self + # automatically prefixed with the class name - which may be a subclass of obj if ( item.endswith("__") or len(item.lstrip("_")) == 0 or (item.startswith("_") and not item.lstrip("_")[0].isdigit()) ): - raise AttributeError(f"{self!r} object has no attribute {item!r}") - + raise AttributeError(f"{obj!r} object has no attribute {item!r}") -def iterable(y) -> bool: - """Check whether or not an object can be iterated over. - - Vendored from numpy under the terms of the BSD 3-Clause License. (Copyright - (c) 2005-2019, NumPy Developers.) - - Parameters - ---------- - value : - Input object. - type : - object - y : - """ +def iterable(y: Any) -> bool: + """Check whether or not an object can be iterated over.""" try: iter(y) except TypeError: @@ -989,18 +1113,8 @@ def iterable(y) -> bool: return True -def sized(y) -> bool: - """Check whether or not an object has a defined length. - - Parameters - ---------- - value : - Input object. - type : - object - y : - - """ +def sized(y: Any) -> bool: + """Check whether or not an object has a defined length.""" try: len(y) except TypeError: @@ -1008,7 +1122,7 @@ def sized(y) -> bool: return True -def create_class_with_registry(registry, base_class) -> type: +def create_class_with_registry(registry: UnitRegistry, base_class: type) -> type: """Create new class inheriting from base_class and filling _REGISTRY class attribute with an actual instanced registry. """ |