From d0442e776bff3053ef900143cd64facf173e1650 Mon Sep 17 00:00:00 2001 From: Hernan Grecco Date: Mon, 1 May 2023 10:00:26 -0300 Subject: Typing improvements While there is still a lot of work to do (mainly in Registry, Quantity, Unit), this large PR makes several changes all around the code. There has not been any intended functional change, but certain typing improvements required code minor code refactoring to streamline input and output types of functions. An important experimental idea is the PintScalar and PintArray protocols, and Magnitude type. This is to overcome the lack of a proper numerical hierarchy in Python. --- pint/_typing.py | 46 ++- pint/babel_names.py | 6 +- pint/compat.py | 82 +++-- pint/context.py | 2 + pint/converters.py | 16 +- pint/definitions.py | 22 +- pint/delegates/__init__.py | 2 +- pint/delegates/base_defparser.py | 10 +- pint/delegates/txt_defparser/__init__.py | 4 +- pint/delegates/txt_defparser/block.py | 19 +- pint/delegates/txt_defparser/common.py | 6 +- pint/delegates/txt_defparser/context.py | 78 +++-- pint/delegates/txt_defparser/defaults.py | 20 +- pint/delegates/txt_defparser/defparser.py | 45 ++- pint/delegates/txt_defparser/group.py | 24 +- pint/delegates/txt_defparser/plain.py | 18 +- pint/delegates/txt_defparser/system.py | 23 +- pint/errors.py | 28 +- pint/facets/__init__.py | 18 +- pint/facets/context/definitions.py | 16 +- pint/facets/context/objects.py | 13 +- pint/facets/group/definitions.py | 15 +- pint/facets/group/objects.py | 56 ++- pint/facets/nonmultiplicative/definitions.py | 10 +- pint/facets/nonmultiplicative/objects.py | 2 +- pint/facets/plain/definitions.py | 31 +- pint/facets/plain/objects.py | 2 +- pint/facets/system/definitions.py | 17 +- pint/facets/system/objects.py | 52 +-- pint/formatting.py | 33 +- pint/pint_eval.py | 155 ++++++-- pint/testsuite/test_compat_downcast.py | 20 +- pint/util.py | 504 ++++++++++++++++----------- 33 files changed, 905 insertions(+), 490 deletions(-) diff --git a/pint/_typing.py b/pint/_typing.py index 1dc3ea6..5547f85 100644 --- a/pint/_typing.py +++ b/pint/_typing.py @@ -1,12 +1,56 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Callable, TypeVar, Union +from typing import TYPE_CHECKING, Any, Callable, TypeVar, Union, Protocol + +# TODO: Remove when 3.11 becomes minimal version. +Self = TypeVar("Self") if TYPE_CHECKING: from .facets.plain import PlainQuantity as Quantity from .facets.plain import PlainUnit as Unit from .util import UnitsContainer + +class PintScalar(Protocol): + def __add__(self, other: Any) -> Any: + ... + + def __sub__(self, other: Any) -> Any: + ... + + def __mul__(self, other: Any) -> Any: + ... + + def __truediv__(self, other: Any) -> Any: + ... + + def __floordiv__(self, other: Any) -> Any: + ... + + def __mod__(self, other: Any) -> Any: + ... + + def __divmod__(self, other: Any) -> Any: + ... + + def __pow__(self, other: Any, modulo: Any) -> Any: + ... + + +class PintArray(Protocol): + def __len__(self) -> int: + ... + + def __getitem__(self, key: Any) -> Any: + ... + + def __setitem__(self, key: Any, value: Any) -> None: + ... + + +Magnitude = PintScalar | PintScalar + + UnitLike = Union[str, "UnitsContainer", "Unit"] QuantityOrUnitLike = Union["Quantity", UnitLike] diff --git a/pint/babel_names.py b/pint/babel_names.py index 09fa046..408ef8f 100644 --- a/pint/babel_names.py +++ b/pint/babel_names.py @@ -10,7 +10,7 @@ from __future__ import annotations from .compat import HAS_BABEL -_babel_units = dict( +_babel_units: dict[str, str] = dict( standard_gravity="acceleration-g-force", millibar="pressure-millibar", metric_ton="mass-metric-ton", @@ -141,6 +141,6 @@ _babel_units = dict( if not HAS_BABEL: _babel_units = {} -_babel_systems = dict(mks="metric", imperial="uksystem", US="ussystem") +_babel_systems: dict[str, str] = dict(mks="metric", imperial="uksystem", US="ussystem") -_babel_lengths = ["narrow", "short", "long"] +_babel_lengths: list[str] = ["narrow", "short", "long"] diff --git a/pint/compat.py b/pint/compat.py index ee8d443..f58e9cb 100644 --- a/pint/compat.py +++ b/pint/compat.py @@ -17,12 +17,19 @@ from importlib import import_module from io import BytesIO from numbers import Number from collections.abc import Mapping +from typing import Any, NoReturn, Callable, Generator, Iterable -def missing_dependency(package, display_name=None): +def missing_dependency( + package: str, display_name: str | None = None +) -> Callable[..., NoReturn]: + """Return a helper function that raises an exception when used. + + It provides a way delay a missing dependency exception until it is used. + """ display_name = display_name or package - def _inner(*args, **kwargs): + def _inner(*args: Any, **kwargs: Any) -> NoReturn: raise Exception( "This feature requires %s. Please install it by running:\n" "pip install %s" % (display_name, package) @@ -31,7 +38,14 @@ def missing_dependency(package, display_name=None): return _inner -def tokenizer(input_string): +def tokenizer(input_string: str) -> Generator[tokenize.TokenInfo, None, None]: + """Tokenize an input string, encoded as UTF-8 + and skipping the ENCODING token. + + See Also + -------- + tokenize.tokenize + """ for tokinfo in tokenize.tokenize(BytesIO(input_string.encode("utf-8")).readline): if tokinfo.type != tokenize.ENCODING: yield tokinfo @@ -154,7 +168,8 @@ else: from math import log # noqa: F401 if not HAS_BABEL: - babel_parse = babel_units = missing_dependency("Babel") # noqa: F811 + babel_parse = missing_dependency("Babel") # noqa: F811 + babel_units = babel_parse if not HAS_MIP: mip_missing = missing_dependency("mip") @@ -176,6 +191,9 @@ except ImportError: dask_array = None +# TODO: merge with upcast_type_map + +#: List upcast type names upcast_type_names = ( "pint_pandas.PintArray", "pandas.Series", @@ -186,10 +204,12 @@ upcast_type_names = ( "xarray.core.dataarray.DataArray", ) -upcast_type_map: Mapping[str : type | None] = {k: None for k in upcast_type_names} +#: Map type name to the actual type (for upcast types). +upcast_type_map: Mapping[str, type | None] = {k: None for k in upcast_type_names} def fully_qualified_name(t: type) -> str: + """Return the fully qualified name of a type.""" module = t.__module__ name = t.__qualname__ @@ -200,6 +220,10 @@ def fully_qualified_name(t: type) -> str: def check_upcast_type(obj: type) -> bool: + """Check if the type object is an upcast type.""" + + # TODO: merge or unify name with is_upcast_type + fqn = fully_qualified_name(obj) if fqn not in upcast_type_map: return False @@ -215,22 +239,17 @@ def check_upcast_type(obj: type) -> bool: def is_upcast_type(other: type) -> bool: + """Check if the type object is an upcast type.""" + + # TODO: merge or unify name with check_upcast_type + if other in upcast_type_map.values(): return True return check_upcast_type(other) -def is_duck_array_type(cls) -> bool: - """Check if the type object represents a (non-Quantity) duck array type. - - Parameters - ---------- - cls : class - - Returns - ------- - bool - """ +def is_duck_array_type(cls: type) -> bool: + """Check if the type object represents a (non-Quantity) duck array type.""" # TODO (NEP 30): replace duck array check with hasattr(other, "__duckarray__") return issubclass(cls, ndarray) or ( not hasattr(cls, "_magnitude") @@ -242,20 +261,21 @@ def is_duck_array_type(cls) -> bool: ) -def is_duck_array(obj): +def is_duck_array(obj: type) -> bool: + """Check if an object represents a (non-Quantity) duck array type.""" return is_duck_array_type(type(obj)) -def eq(lhs, rhs, check_all: bool): +def eq(lhs: Any, rhs: Any, check_all: bool) -> bool | Iterable[bool]: """Comparison of scalars and arrays. Parameters ---------- - lhs : object + lhs left-hand side - rhs : object + rhs right-hand side - check_all : bool + check_all if True, reduce sequence to single bool; return True if all the elements are equal. @@ -269,21 +289,21 @@ def eq(lhs, rhs, check_all: bool): return out -def isnan(obj, check_all: bool): - """Test for NaN or NaT +def isnan(obj: Any, check_all: bool) -> bool | Iterable[bool]: + """Test for NaN or NaT. Parameters ---------- - obj : object + obj scalar or vector - check_all : bool + check_all if True, reduce sequence to single bool; return True if any of the elements are NaN. Returns ------- bool or array_like of bool. - Always return False for non-numeric types. + Always return False for non-numeric types. """ if is_duck_array_type(type(obj)): if obj.dtype.kind in "if": @@ -302,21 +322,21 @@ def isnan(obj, check_all: bool): return False -def zero_or_nan(obj, check_all: bool): - """Test if obj is zero, NaN, or NaT +def zero_or_nan(obj: Any, check_all: bool) -> bool | Iterable[bool]: + """Test if obj is zero, NaN, or NaT. Parameters ---------- - obj : object + obj scalar or vector - check_all : bool + check_all if True, reduce sequence to single bool; return True if all the elements are zero, NaN, or NaT. Returns ------- bool or array_like of bool. - Always return False for non-numeric types. + Always return False for non-numeric types. """ out = eq(obj, 0, False) + isnan(obj, False) if check_all and is_duck_array_type(type(out)): diff --git a/pint/context.py b/pint/context.py index 4839926..6c74f65 100644 --- a/pint/context.py +++ b/pint/context.py @@ -18,3 +18,5 @@ if TYPE_CHECKING: #: Regex to match the header parts of a context. #: Regex to match variable names in an equation. + +# TODO: delete this file diff --git a/pint/converters.py b/pint/converters.py index 9b8513f..9494ad1 100644 --- a/pint/converters.py +++ b/pint/converters.py @@ -13,6 +13,10 @@ from __future__ import annotations from dataclasses import dataclass from dataclasses import fields as dc_fields +from typing import Any + +from ._typing import Self, Magnitude + from .compat import HAS_NUMPY, exp, log # noqa: F401 @@ -24,17 +28,17 @@ class Converter: _param_names_to_subclass = {} @property - def is_multiplicative(self): + def is_multiplicative(self) -> bool: return True @property - def is_logarithmic(self): + def is_logarithmic(self) -> bool: return False - def to_reference(self, value, inplace=False): + def to_reference(self, value: Magnitude, inplace: bool = False) -> Magnitude: return value - def from_reference(self, value, inplace=False): + def from_reference(self, value: Magnitude, inplace: bool = False) -> Magnitude: return value def __init_subclass__(cls, **kwargs): @@ -43,7 +47,7 @@ class Converter: cls._subclasses.append(cls) @classmethod - def get_field_names(cls, new_cls): + def get_field_names(cls, new_cls) -> frozenset[str]: return frozenset(p.name for p in dc_fields(new_cls)) @classmethod @@ -51,7 +55,7 @@ class Converter: return None @classmethod - def from_arguments(cls, **kwargs): + def from_arguments(cls: type[Self], **kwargs: Any) -> Self: kwk = frozenset(kwargs.keys()) try: new_cls = cls._param_names_to_subclass[kwk] diff --git a/pint/definitions.py b/pint/definitions.py index 789d9e3..ce89e94 100644 --- a/pint/definitions.py +++ b/pint/definitions.py @@ -8,6 +8,8 @@ :license: BSD, see LICENSE for more details. """ +from __future__ import annotations + from . import errors from ._vendor import flexparser as fp from .delegates import ParserConfig, txt_defparser @@ -17,12 +19,28 @@ class Definition: """This is kept for backwards compatibility""" @classmethod - def from_string(cls, s: str, non_int_type=float): + def from_string(cls, input_string: str, non_int_type: type = float) -> Definition: + """Parse a string into a definition object. + + Parameters + ---------- + input_string + Single line string. + non_int_type + Numerical type used for non integer values. + + Raises + ------ + DefinitionSyntaxError + If a syntax error was found. + """ cfg = ParserConfig(non_int_type) parser = txt_defparser.DefParser(cfg, None) - pp = parser.parse_string(s) + pp = parser.parse_string(input_string) for definition in parser.iter_parsed_project(pp): if isinstance(definition, Exception): raise errors.DefinitionSyntaxError(str(definition)) if not isinstance(definition, (fp.BOS, fp.BOF, fp.BOS)): return definition + + # TODO: What shall we do in this return path. diff --git a/pint/delegates/__init__.py b/pint/delegates/__init__.py index 363ef9c..b2eb9a3 100644 --- a/pint/delegates/__init__.py +++ b/pint/delegates/__init__.py @@ -11,4 +11,4 @@ from . import txt_defparser from .base_defparser import ParserConfig, build_disk_cache_class -__all__ = [txt_defparser, ParserConfig, build_disk_cache_class] +__all__ = ["txt_defparser", "ParserConfig", "build_disk_cache_class"] diff --git a/pint/delegates/base_defparser.py b/pint/delegates/base_defparser.py index 88d9d37..9e784ac 100644 --- a/pint/delegates/base_defparser.py +++ b/pint/delegates/base_defparser.py @@ -14,7 +14,6 @@ import functools import itertools import numbers import pathlib -import typing as ty from dataclasses import dataclass, field from pint import errors @@ -27,10 +26,10 @@ from .._vendor import flexparser as fp @dataclass(frozen=True) class ParserConfig: - """Configuration used by the parser.""" + """Configuration used by the parser in Pint.""" #: Indicates the output type of non integer numbers. - non_int_type: ty.Type[numbers.Number] = float + non_int_type: type[numbers.Number] = float def to_scaled_units_container(self, s: str): return ParserHelper.from_string(s, self.non_int_type) @@ -67,6 +66,11 @@ class ParserConfig: return val.scale +@dataclass(frozen=True) +class PintParsedStatement(fp.ParsedStatement[ParserConfig]): + """A parsed statement for pint, specialized in the actual config.""" + + @functools.lru_cache def build_disk_cache_class(non_int_type: type): """Build disk cache class, taking into account the non_int_type.""" diff --git a/pint/delegates/txt_defparser/__init__.py b/pint/delegates/txt_defparser/__init__.py index 5572ca1..49e4a0b 100644 --- a/pint/delegates/txt_defparser/__init__.py +++ b/pint/delegates/txt_defparser/__init__.py @@ -11,4 +11,6 @@ from .defparser import DefParser -__all__ = [DefParser] +__all__ = [ + "DefParser", +] diff --git a/pint/delegates/txt_defparser/block.py b/pint/delegates/txt_defparser/block.py index 20ebcba..e8d8aa4 100644 --- a/pint/delegates/txt_defparser/block.py +++ b/pint/delegates/txt_defparser/block.py @@ -17,11 +17,14 @@ from __future__ import annotations from dataclasses import dataclass +from typing import Generic, TypeVar + +from ..base_defparser import PintParsedStatement, ParserConfig from ..._vendor import flexparser as fp @dataclass(frozen=True) -class EndDirectiveBlock(fp.ParsedStatement): +class EndDirectiveBlock(PintParsedStatement): """An EndDirectiveBlock is simply an "@end" statement.""" @classmethod @@ -31,8 +34,16 @@ class EndDirectiveBlock(fp.ParsedStatement): return None +OPST = TypeVar("OPST", bound="PintParsedStatement") +IPST = TypeVar("IPST", bound="PintParsedStatement") + +DefT = TypeVar("DefT") + + @dataclass(frozen=True) -class DirectiveBlock(fp.Block): +class DirectiveBlock( + Generic[DefT, OPST, IPST], fp.Block[OPST, IPST, EndDirectiveBlock, ParserConfig] +): """Directive blocks have beginning statement starting with a @ character. and ending with a "@end" (captured using a EndDirectiveBlock). @@ -41,5 +52,5 @@ class DirectiveBlock(fp.Block): closing: EndDirectiveBlock - def derive_definition(self): - pass + def derive_definition(self) -> DefT: + ... diff --git a/pint/delegates/txt_defparser/common.py b/pint/delegates/txt_defparser/common.py index 493d0ec..a1195b3 100644 --- a/pint/delegates/txt_defparser/common.py +++ b/pint/delegates/txt_defparser/common.py @@ -30,7 +30,7 @@ class DefinitionSyntaxError(errors.DefinitionSyntaxError, fp.ParsingError): location: str = field(init=False, default="") - def __str__(self): + def __str__(self) -> str: msg = ( self.msg + "\n " + (self.format_position or "") + " " + (self.raw or "") ) @@ -38,7 +38,7 @@ class DefinitionSyntaxError(errors.DefinitionSyntaxError, fp.ParsingError): msg += "\n " + self.location return msg - def set_location(self, value): + def set_location(self, value: str) -> None: super().__setattr__("location", value) @@ -47,7 +47,7 @@ class ImportDefinition(fp.IncludeStatement): value: str @property - def target(self): + def target(self) -> str: return self.value @classmethod diff --git a/pint/delegates/txt_defparser/context.py b/pint/delegates/txt_defparser/context.py index b7e5a67..ce9fc9b 100644 --- a/pint/delegates/txt_defparser/context.py +++ b/pint/delegates/txt_defparser/context.py @@ -23,32 +23,32 @@ from dataclasses import dataclass from ..._vendor import flexparser as fp from ...facets.context import definitions -from ..base_defparser import ParserConfig +from ..base_defparser import ParserConfig, PintParsedStatement from . import block, common, plain +# TODO check syntax +T = ty.TypeVar("T", bound="ForwardRelation | BidirectionalRelation") -@dataclass(frozen=True) -class Relation(definitions.Relation): - @classmethod - def _from_string_and_context_sep( - cls, s: str, config: ParserConfig, separator: str - ) -> fp.FromString[Relation]: - if separator not in s: - return None - if ":" not in s: - return None - rel, eq = s.split(":") +def _from_string_and_context_sep( + cls: type[T], s: str, config: ParserConfig, separator: str +) -> T | None: + if separator not in s: + return None + if ":" not in s: + return None + + rel, eq = s.split(":") - parts = rel.split(separator) + parts = rel.split(separator) - src, dst = (config.to_dimension_container(s) for s in parts) + src, dst = (config.to_dimension_container(s) for s in parts) - return cls(src, dst, eq.strip()) + return cls(src, dst, eq.strip()) @dataclass(frozen=True) -class ForwardRelation(fp.ParsedStatement, definitions.ForwardRelation, Relation): +class ForwardRelation(PintParsedStatement, definitions.ForwardRelation): """A relation connecting a dimension to another via a transformation function. -> : @@ -58,13 +58,11 @@ class ForwardRelation(fp.ParsedStatement, definitions.ForwardRelation, Relation) def from_string_and_config( cls, s: str, config: ParserConfig ) -> fp.FromString[ForwardRelation]: - return super()._from_string_and_context_sep(s, config, "->") + return _from_string_and_context_sep(cls, s, config, "->") @dataclass(frozen=True) -class BidirectionalRelation( - fp.ParsedStatement, definitions.BidirectionalRelation, Relation -): +class BidirectionalRelation(PintParsedStatement, definitions.BidirectionalRelation): """A bidirectional relation connecting a dimension to another via a simple transformation function. @@ -76,11 +74,11 @@ class BidirectionalRelation( def from_string_and_config( cls, s: str, config: ParserConfig ) -> fp.FromString[BidirectionalRelation]: - return super()._from_string_and_context_sep(s, config, "<->") + return _from_string_and_context_sep(cls, s, config, "<->") @dataclass(frozen=True) -class BeginContext(fp.ParsedStatement): +class BeginContext(PintParsedStatement): """Being of a context directive. @context[(defaults)] [= ] [= ] @@ -91,7 +89,7 @@ class BeginContext(fp.ParsedStatement): ) name: str - aliases: tuple[str, ...] + aliases: tuple[str] defaults: dict[str, numbers.Number] @classmethod @@ -130,7 +128,18 @@ class BeginContext(fp.ParsedStatement): @dataclass(frozen=True) -class ContextDefinition(block.DirectiveBlock): +class ContextDefinition( + block.DirectiveBlock[ + definitions.ContextDefinition, + BeginContext, + ty.Union[ + plain.CommentDefinition, + BidirectionalRelation, + ForwardRelation, + plain.UnitDefinition, + ], + ] +): """Definition of a Context @context[(defaults)] [= ] [= ] @@ -169,27 +178,34 @@ class ContextDefinition(block.DirectiveBlock): ] ] - def derive_definition(self): + def derive_definition(self) -> definitions.ContextDefinition: return definitions.ContextDefinition( self.name, self.aliases, self.defaults, self.relations, self.redefinitions ) @property - def name(self): + def name(self) -> str: + assert isinstance(self.opening, BeginContext) return self.opening.name @property - def aliases(self): + def aliases(self) -> tuple[str]: + assert isinstance(self.opening, BeginContext) return self.opening.aliases @property - def defaults(self): + def defaults(self) -> dict[str, numbers.Number]: + assert isinstance(self.opening, BeginContext) return self.opening.defaults @property - def relations(self): - return tuple(r for r in self.body if isinstance(r, Relation)) + def relations(self) -> tuple[BidirectionalRelation | ForwardRelation]: + return tuple( + r + for r in self.body + if isinstance(r, (ForwardRelation, BidirectionalRelation)) + ) @property - def redefinitions(self): + def redefinitions(self) -> tuple[plain.UnitDefinition]: return tuple(r for r in self.body if isinstance(r, plain.UnitDefinition)) diff --git a/pint/delegates/txt_defparser/defaults.py b/pint/delegates/txt_defparser/defaults.py index af6e31f..688d90f 100644 --- a/pint/delegates/txt_defparser/defaults.py +++ b/pint/delegates/txt_defparser/defaults.py @@ -19,10 +19,11 @@ from dataclasses import dataclass, fields from ..._vendor import flexparser as fp from ...facets.plain import definitions from . import block, plain +from ..base_defparser import PintParsedStatement @dataclass(frozen=True) -class BeginDefaults(fp.ParsedStatement): +class BeginDefaults(PintParsedStatement): """Being of a defaults directive. @defaults @@ -36,7 +37,16 @@ class BeginDefaults(fp.ParsedStatement): @dataclass(frozen=True) -class DefaultsDefinition(block.DirectiveBlock): +class DefaultsDefinition( + block.DirectiveBlock[ + definitions.DefaultsDefinition, + BeginDefaults, + ty.Union[ + plain.CommentDefinition, + plain.Equality, + ], + ] +): """Directive to store values. @defaults @@ -55,10 +65,10 @@ class DefaultsDefinition(block.DirectiveBlock): ] @property - def _valid_fields(self): + def _valid_fields(self) -> tuple[str]: return tuple(f.name for f in fields(definitions.DefaultsDefinition)) - def derive_definition(self): + def derive_definition(self) -> definitions.DefaultsDefinition: for definition in self.filter_by(plain.Equality): if definition.lhs not in self._valid_fields: raise ValueError( @@ -70,7 +80,7 @@ class DefaultsDefinition(block.DirectiveBlock): *tuple(self.get_key(key) for key in self._valid_fields) ) - def get_key(self, key): + def get_key(self, key: str) -> str: for stmt in self.body: if isinstance(stmt, plain.Equality) and stmt.lhs == key: return stmt.rhs diff --git a/pint/delegates/txt_defparser/defparser.py b/pint/delegates/txt_defparser/defparser.py index 0b99d6d..f1b8e45 100644 --- a/pint/delegates/txt_defparser/defparser.py +++ b/pint/delegates/txt_defparser/defparser.py @@ -5,11 +5,28 @@ import typing as ty from ..._vendor import flexcache as fc from ..._vendor import flexparser as fp -from .. import base_defparser +from ..base_defparser import ParserConfig from . import block, common, context, defaults, group, plain, system -class PintRootBlock(fp.RootBlock): +class PintRootBlock( + fp.RootBlock[ + ty.Union[ + plain.CommentDefinition, + common.ImportDefinition, + context.ContextDefinition, + defaults.DefaultsDefinition, + system.SystemDefinition, + group.GroupDefinition, + plain.AliasDefinition, + plain.DerivedDimensionDefinition, + plain.DimensionDefinition, + plain.PrefixDefinition, + plain.UnitDefinition, + ], + ParserConfig, + ] +): body: fp.Multi[ ty.Union[ plain.CommentDefinition, @@ -27,11 +44,15 @@ class PintRootBlock(fp.RootBlock): ] +class PintSource(fp.ParsedSource[PintRootBlock, ParserConfig]): + """Source code in Pint.""" + + class HashTuple(tuple): pass -class _PintParser(fp.Parser): +class _PintParser(fp.Parser[PintRootBlock, ParserConfig]): """Parser for the original Pint definition file, with cache.""" _delimiters = { @@ -46,11 +67,11 @@ class _PintParser(fp.Parser): _diskcache: fc.DiskCache - def __init__(self, config: base_defparser.ParserConfig, *args, **kwargs): + def __init__(self, config: ParserConfig, *args, **kwargs): self._diskcache = kwargs.pop("diskcache", None) super().__init__(config, *args, **kwargs) - def parse_file(self, path: pathlib.Path) -> fp.ParsedSource: + def parse_file(self, path: pathlib.Path) -> PintSource: if self._diskcache is None: return super().parse_file(path) content, basename = self._diskcache.load(path, super().parse_file) @@ -58,7 +79,13 @@ class _PintParser(fp.Parser): class DefParser: - skip_classes = (fp.BOF, fp.BOR, fp.BOS, fp.EOS, plain.CommentDefinition) + skip_classes: tuple[type] = ( + fp.BOF, + fp.BOR, + fp.BOS, + fp.EOS, + plain.CommentDefinition, + ) def __init__(self, default_config, diskcache): self._default_config = default_config @@ -78,6 +105,8 @@ class DefParser: continue if isinstance(stmt, common.DefinitionSyntaxError): + # TODO: check why this assert fails + # assert isinstance(last_location, str) stmt.set_location(last_location) raise stmt elif isinstance(stmt, block.DirectiveBlock): @@ -101,7 +130,7 @@ class DefParser: else: yield stmt - def parse_file(self, filename: pathlib.Path, cfg=None): + def parse_file(self, filename: pathlib.Path, cfg: ParserConfig | None = None): return fp.parse( filename, _PintParser, @@ -109,7 +138,7 @@ class DefParser: diskcache=self._diskcache, ) - def parse_string(self, content: str, cfg=None): + def parse_string(self, content: str, cfg: ParserConfig | None = None): return fp.parse_bytes( content.encode("utf-8"), _PintParser, diff --git a/pint/delegates/txt_defparser/group.py b/pint/delegates/txt_defparser/group.py index 5be42ac..e96d44b 100644 --- a/pint/delegates/txt_defparser/group.py +++ b/pint/delegates/txt_defparser/group.py @@ -23,10 +23,11 @@ from dataclasses import dataclass from ..._vendor import flexparser as fp from ...facets.group import definitions from . import block, common, plain +from ..base_defparser import PintParsedStatement @dataclass(frozen=True) -class BeginGroup(fp.ParsedStatement): +class BeginGroup(PintParsedStatement): """Being of a group directive. @group [using , ..., ] @@ -59,7 +60,16 @@ class BeginGroup(fp.ParsedStatement): @dataclass(frozen=True) -class GroupDefinition(block.DirectiveBlock): +class GroupDefinition( + block.DirectiveBlock[ + definitions.GroupDefinition, + BeginGroup, + ty.Union[ + plain.CommentDefinition, + plain.UnitDefinition, + ], + ] +): """Definition of a group. @group [using , ..., ] @@ -88,19 +98,21 @@ class GroupDefinition(block.DirectiveBlock): ] ] - def derive_definition(self): + def derive_definition(self) -> definitions.GroupDefinition: return definitions.GroupDefinition( self.name, self.using_group_names, self.definitions ) @property - def name(self): + def name(self) -> str: + assert isinstance(self.opening, BeginGroup) return self.opening.name @property - def using_group_names(self): + def using_group_names(self) -> tuple[str]: + assert isinstance(self.opening, BeginGroup) return self.opening.using_group_names @property - def definitions(self) -> ty.Tuple[plain.UnitDefinition, ...]: + def definitions(self) -> tuple[plain.UnitDefinition]: return tuple(el for el in self.body if isinstance(el, plain.UnitDefinition)) diff --git a/pint/delegates/txt_defparser/plain.py b/pint/delegates/txt_defparser/plain.py index 749e7fd..9c7bd42 100644 --- a/pint/delegates/txt_defparser/plain.py +++ b/pint/delegates/txt_defparser/plain.py @@ -29,12 +29,12 @@ from ..._vendor import flexparser as fp from ...converters import Converter from ...facets.plain import definitions from ...util import UnitsContainer -from ..base_defparser import ParserConfig +from ..base_defparser import ParserConfig, PintParsedStatement from . import common @dataclass(frozen=True) -class Equality(fp.ParsedStatement, definitions.Equality): +class Equality(PintParsedStatement, definitions.Equality): """An equality statement contains a left and right hand separated lhs and rhs should be space stripped. @@ -53,7 +53,7 @@ class Equality(fp.ParsedStatement, definitions.Equality): @dataclass(frozen=True) -class CommentDefinition(fp.ParsedStatement, definitions.CommentDefinition): +class CommentDefinition(PintParsedStatement, definitions.CommentDefinition): """Comments start with a # character. # This is a comment. @@ -63,14 +63,14 @@ class CommentDefinition(fp.ParsedStatement, definitions.CommentDefinition): """ @classmethod - def from_string(cls, s: str) -> fp.FromString[fp.ParsedStatement]: + def from_string(cls, s: str) -> fp.FromString[CommentDefinition]: if not s.startswith("#"): return None return cls(s[1:].strip()) @dataclass(frozen=True) -class PrefixDefinition(fp.ParsedStatement, definitions.PrefixDefinition): +class PrefixDefinition(PintParsedStatement, definitions.PrefixDefinition): """Definition of a prefix:: - = [= ] [= ] [ = ] [...] @@ -119,7 +119,7 @@ class PrefixDefinition(fp.ParsedStatement, definitions.PrefixDefinition): @dataclass(frozen=True) -class UnitDefinition(fp.ParsedStatement, definitions.UnitDefinition): +class UnitDefinition(PintParsedStatement, definitions.UnitDefinition): """Definition of a unit:: = [= ] [= ] [ = ] [...] @@ -194,7 +194,7 @@ class UnitDefinition(fp.ParsedStatement, definitions.UnitDefinition): @dataclass(frozen=True) -class DimensionDefinition(fp.ParsedStatement, definitions.DimensionDefinition): +class DimensionDefinition(PintParsedStatement, definitions.DimensionDefinition): """Definition of a root dimension:: [dimension name] @@ -221,7 +221,7 @@ class DimensionDefinition(fp.ParsedStatement, definitions.DimensionDefinition): @dataclass(frozen=True) class DerivedDimensionDefinition( - fp.ParsedStatement, definitions.DerivedDimensionDefinition + PintParsedStatement, definitions.DerivedDimensionDefinition ): """Definition of a derived dimension:: @@ -261,7 +261,7 @@ class DerivedDimensionDefinition( @dataclass(frozen=True) -class AliasDefinition(fp.ParsedStatement, definitions.AliasDefinition): +class AliasDefinition(PintParsedStatement, definitions.AliasDefinition): """Additional alias(es) for an already existing unit:: @alias = [ = ] [...] diff --git a/pint/delegates/txt_defparser/system.py b/pint/delegates/txt_defparser/system.py index b21fd7a..4efbb4d 100644 --- a/pint/delegates/txt_defparser/system.py +++ b/pint/delegates/txt_defparser/system.py @@ -14,11 +14,12 @@ from dataclasses import dataclass from ..._vendor import flexparser as fp from ...facets.system import definitions +from ..base_defparser import PintParsedStatement from . import block, common, plain @dataclass(frozen=True) -class BaseUnitRule(fp.ParsedStatement, definitions.BaseUnitRule): +class BaseUnitRule(PintParsedStatement, definitions.BaseUnitRule): @classmethod def from_string(cls, s: str) -> fp.FromString[BaseUnitRule]: if ":" not in s: @@ -32,7 +33,7 @@ class BaseUnitRule(fp.ParsedStatement, definitions.BaseUnitRule): @dataclass(frozen=True) -class BeginSystem(fp.ParsedStatement): +class BeginSystem(PintParsedStatement): """Being of a system directive. @system [using , ..., ] @@ -67,7 +68,13 @@ class BeginSystem(fp.ParsedStatement): @dataclass(frozen=True) -class SystemDefinition(block.DirectiveBlock): +class SystemDefinition( + block.DirectiveBlock[ + definitions.SystemDefinition, + BeginSystem, + ty.Union[plain.CommentDefinition, BaseUnitRule], + ] +): """Definition of a System: @system [using , ..., ] @@ -92,19 +99,21 @@ class SystemDefinition(block.DirectiveBlock): opening: fp.Single[BeginSystem] body: fp.Multi[ty.Union[plain.CommentDefinition, BaseUnitRule]] - def derive_definition(self): + def derive_definition(self) -> definitions.SystemDefinition: return definitions.SystemDefinition( self.name, self.using_group_names, self.rules ) @property - def name(self): + def name(self) -> str: + assert isinstance(self.opening, BeginSystem) return self.opening.name @property - def using_group_names(self): + def using_group_names(self) -> tuple[str]: + assert isinstance(self.opening, BeginSystem) return self.opening.using_group_names @property - def rules(self): + def rules(self) -> tuple[BaseUnitRule]: return tuple(el for el in self.body if isinstance(el, BaseUnitRule)) diff --git a/pint/errors.py b/pint/errors.py index 8f849da..6cebb21 100644 --- a/pint/errors.py +++ b/pint/errors.py @@ -36,18 +36,21 @@ MSG_INVALID_SYSTEM_NAME = ( ) -def is_dim(name): +def is_dim(name: str) -> bool: + """Return True if the name is flanked by square brackets `[` and `]`.""" return name[0] == "[" and name[-1] == "]" -def is_valid_prefix_name(name): +def is_valid_prefix_name(name: str) -> bool: + """Return True if the name is a valid python identifier or empty.""" return str.isidentifier(name) or name == "" is_valid_unit_name = is_valid_system_name = is_valid_context_name = str.isidentifier -def _no_space(name): +def _no_space(name: str) -> bool: + """Return False if the name contains a space in any position.""" return name.strip() == name and " " not in name @@ -58,7 +61,14 @@ is_valid_unit_alias = ( ) = is_valid_unit_symbol = is_valid_prefix_symbol = _no_space -def is_valid_dimension_name(name): +def is_valid_dimension_name(name: str) -> bool: + """Return True if the name is consistent with a dimension name. + + - flanked by square brackets. + - empty dimension name or identifier. + """ + + # TODO: shall we check also fro spaces? return name == "[]" or ( len(name) > 1 and is_dim(name) and str.isidentifier(name[1:-1]) ) @@ -67,8 +77,8 @@ def is_valid_dimension_name(name): class WithDefErr: """Mixing class to make some classes more readable.""" - def def_err(self, msg): - return DefinitionError(self.name, self.__class__.__name__, msg) + def def_err(self, msg: str): + return DefinitionError(self.name, self.__class__, msg) @dataclass(frozen=False) @@ -81,7 +91,7 @@ class DefinitionError(ValueError, PintError): """Raised when a definition is not properly constructed.""" name: str - definition_type: ty.Type + definition_type: type msg: str def __str__(self): @@ -110,7 +120,7 @@ class RedefinitionError(ValueError, PintError): """Raised when a unit or prefix is redefined.""" name: str - definition_type: ty.Type + definition_type: type def __str__(self): msg = f"Cannot redefine '{self.name}' ({self.definition_type})" @@ -124,7 +134,7 @@ class RedefinitionError(ValueError, PintError): class UndefinedUnitError(AttributeError, PintError): """Raised when the units are not defined in the unit registry.""" - unit_names: ty.Union[str, ty.Tuple[str, ...]] + unit_names: str | tuple[str] def __str__(self): if isinstance(self.unit_names, str): diff --git a/pint/facets/__init__.py b/pint/facets/__init__.py index 7b24463..750f729 100644 --- a/pint/facets/__init__.py +++ b/pint/facets/__init__.py @@ -82,13 +82,13 @@ from .plain import PlainRegistry from .system import SystemRegistry __all__ = [ - ContextRegistry, - DaskRegistry, - FormattingRegistry, - GroupRegistry, - MeasurementRegistry, - NonMultiplicativeRegistry, - NumpyRegistry, - PlainRegistry, - SystemRegistry, + "ContextRegistry", + "DaskRegistry", + "FormattingRegistry", + "GroupRegistry", + "MeasurementRegistry", + "NonMultiplicativeRegistry", + "NumpyRegistry", + "PlainRegistry", + "SystemRegistry", ] diff --git a/pint/facets/context/definitions.py b/pint/facets/context/definitions.py index d9ba473..07eb92f 100644 --- a/pint/facets/context/definitions.py +++ b/pint/facets/context/definitions.py @@ -12,7 +12,7 @@ import itertools import numbers import re from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Callable +from typing import TYPE_CHECKING, Any, Callable, Iterable from ... import errors from ..plain import UnitDefinition @@ -41,7 +41,7 @@ class Relation: # could be used. @property - def variables(self) -> set[str, ...]: + def variables(self) -> set[str]: """Find all variables names in the equation.""" return set(self._varname_re.findall(self.equation)) @@ -55,7 +55,7 @@ class Relation: ) @property - def bidirectional(self): + def bidirectional(self) -> bool: raise NotImplementedError @@ -92,18 +92,18 @@ class ContextDefinition(errors.WithDefErr): #: name of the context name: str #: other na - aliases: tuple[str, ...] + aliases: tuple[str] defaults: dict[str, numbers.Number] - relations: tuple[Relation, ...] - redefinitions: tuple[UnitDefinition, ...] + relations: tuple[Relation] + redefinitions: tuple[UnitDefinition] @property - def variables(self) -> set[str, ...]: + def variables(self) -> set[str]: """Return all variable names in all transformations.""" return set().union(*(r.variables for r in self.relations)) @classmethod - def from_lines(cls, lines, non_int_type): + def from_lines(cls, lines: Iterable[str], non_int_type: type): # TODO: this is to keep it backwards compatible from ...delegates import ParserConfig, txt_defparser diff --git a/pint/facets/context/objects.py b/pint/facets/context/objects.py index 58f8bb8..bec2a43 100644 --- a/pint/facets/context/objects.py +++ b/pint/facets/context/objects.py @@ -10,6 +10,7 @@ from __future__ import annotations import weakref from collections import ChainMap, defaultdict +from typing import Any, Iterable from ...facets.plain import UnitDefinition from ...util import UnitsContainer, to_units_container @@ -70,8 +71,8 @@ class Context: def __init__( self, name: str | None = None, - aliases: tuple[str, ...] = (), - defaults: dict | None = None, + aliases: tuple[str] = tuple(), + defaults: dict[str, Any] | None = None, ) -> None: self.name = name self.aliases = aliases @@ -93,7 +94,7 @@ class Context: self.relation_to_context = weakref.WeakValueDictionary() @classmethod - def from_context(cls, context: Context, **defaults) -> Context: + def from_context(cls, context: Context, **defaults: Any) -> Context: """Creates a new context that shares the funcs dictionary with the original context. The default values are copied from the original context and updated with the new defaults. @@ -122,7 +123,9 @@ class Context: return context @classmethod - def from_lines(cls, lines, to_base_func=None, non_int_type=float) -> Context: + def from_lines( + cls, lines: Iterable[str], to_base_func=None, non_int_type: type = float + ) -> Context: cd = ContextDefinition.from_lines(lines, non_int_type) return cls.from_definition(cd, to_base_func) @@ -273,7 +276,7 @@ class ContextChain(ChainMap): """ return self[(src, dst)].transform(src, dst, registry, value) - def hashable(self): + def hashable(self) -> tuple[Any]: """Generate a unique hashable and comparable representation of self, which can be used as a key in a dict. This class cannot define ``__hash__`` because it is mutable, and the Python interpreter does cache the output of ``__hash__``. diff --git a/pint/facets/group/definitions.py b/pint/facets/group/definitions.py index c0abced..48c6f4b 100644 --- a/pint/facets/group/definitions.py +++ b/pint/facets/group/definitions.py @@ -8,9 +8,10 @@ from __future__ import annotations -import typing as ty +from typing import Iterable from dataclasses import dataclass +from ..._typing import Self from ... import errors from .. import plain @@ -22,12 +23,14 @@ class GroupDefinition(errors.WithDefErr): #: name of the group name: str #: unit groups that will be included within the group - using_group_names: ty.Tuple[str, ...] + using_group_names: tuple[str] #: definitions for the units existing within the group - definitions: ty.Tuple[plain.UnitDefinition, ...] + definitions: tuple[plain.UnitDefinition] @classmethod - def from_lines(cls, lines, non_int_type): + def from_lines( + cls: type[Self], lines: Iterable[str], non_int_type: type + ) -> Self | None: # TODO: this is to keep it backwards compatible from ...delegates import ParserConfig, txt_defparser @@ -39,10 +42,10 @@ class GroupDefinition(errors.WithDefErr): return definition @property - def unit_names(self) -> ty.Tuple[str, ...]: + def unit_names(self) -> tuple[str]: return tuple(el.name for el in self.definitions) - def __post_init__(self): + def __post_init__(self) -> None: if not errors.is_valid_group_name(self.name): raise self.def_err(errors.MSG_INVALID_GROUP_NAME) diff --git a/pint/facets/group/objects.py b/pint/facets/group/objects.py index 558a107..a0a81be 100644 --- a/pint/facets/group/objects.py +++ b/pint/facets/group/objects.py @@ -8,6 +8,7 @@ from __future__ import annotations +from typing import Generator, Iterable from ...util import SharedRegistryObject, getattr_maybe_raise from .definitions import GroupDefinition @@ -23,32 +24,26 @@ class Group(SharedRegistryObject): The group belongs to one Registry. See GroupDefinition for the definition file syntax. - """ - def __init__(self, name): - """ - :param name: Name of the group. If not given, a root Group will be created. - :type name: str - :param groups: dictionary like object groups and system. - The newly created group will be added after creation. - :type groups: dict[str | Group] - """ + Parameters + ---------- + name + If not given, a root Group will be created. + """ + def __init__(self, name: str): # The name of the group. - #: type: str self.name = name #: Names of the units in this group. #: :type: set[str] - self._unit_names = set() + self._unit_names: set[str] = set() #: Names of the groups in this group. - #: :type: set[str] - self._used_groups = set() + self._used_groups: set[str] = set() #: Names of the groups in which this group is contained. - #: :type: set[str] - self._used_by = set() + self._used_by: set[str] = set() # Add this group to the group dictionary self._REGISTRY._groups[self.name] = self @@ -59,8 +54,7 @@ class Group(SharedRegistryObject): #: A cache of the included units. #: None indicates that the cache has been invalidated. - #: :type: frozenset[str] | None - self._computed_members = None + self._computed_members: frozenset[str] | None = None @property def members(self): @@ -70,23 +64,23 @@ class Group(SharedRegistryObject): """ if self._computed_members is None: - self._computed_members = set(self._unit_names) + tmp = set(self._unit_names) for _, group in self.iter_used_groups(): - self._computed_members |= group.members + tmp |= group.members - self._computed_members = frozenset(self._computed_members) + self._computed_members = frozenset(tmp) return self._computed_members - def invalidate_members(self): + def invalidate_members(self) -> None: """Invalidate computed members in this Group and all parent nodes.""" self._computed_members = None d = self._REGISTRY._groups for name in self._used_by: d[name].invalidate_members() - def iter_used_groups(self): + def iter_used_groups(self) -> Generator[tuple[str, Group], None, None]: pending = set(self._used_groups) d = self._REGISTRY._groups while pending: @@ -95,13 +89,13 @@ class Group(SharedRegistryObject): pending |= group._used_groups yield name, d[name] - def is_used_group(self, group_name): + def is_used_group(self, group_name: str) -> bool: for name, _ in self.iter_used_groups(): if name == group_name: return True return False - def add_units(self, *unit_names): + def add_units(self, *unit_names: str) -> None: """Add units to group.""" for unit_name in unit_names: self._unit_names.add(unit_name) @@ -109,17 +103,17 @@ class Group(SharedRegistryObject): self.invalidate_members() @property - def non_inherited_unit_names(self): + def non_inherited_unit_names(self) -> frozenset[str]: return frozenset(self._unit_names) - def remove_units(self, *unit_names): + def remove_units(self, *unit_names: str) -> None: """Remove units from group.""" for unit_name in unit_names: self._unit_names.remove(unit_name) self.invalidate_members() - def add_groups(self, *group_names): + def add_groups(self, *group_names: str) -> None: """Add groups to group.""" d = self._REGISTRY._groups for group_name in group_names: @@ -136,7 +130,7 @@ class Group(SharedRegistryObject): self.invalidate_members() - def remove_groups(self, *group_names): + def remove_groups(self, *group_names: str) -> None: """Remove groups from group.""" d = self._REGISTRY._groups for group_name in group_names: @@ -148,7 +142,9 @@ class Group(SharedRegistryObject): self.invalidate_members() @classmethod - def from_lines(cls, lines, define_func, non_int_type=float) -> Group: + def from_lines( + cls, lines: Iterable[str], define_func, non_int_type: type = float + ) -> Group: """Return a Group object parsing an iterable of lines. Parameters @@ -190,6 +186,6 @@ class Group(SharedRegistryObject): return grp - def __getattr__(self, item): + def __getattr__(self, item: str): getattr_maybe_raise(self, item) return self._REGISTRY diff --git a/pint/facets/nonmultiplicative/definitions.py b/pint/facets/nonmultiplicative/definitions.py index dbfc0ff..f795cf0 100644 --- a/pint/facets/nonmultiplicative/definitions.py +++ b/pint/facets/nonmultiplicative/definitions.py @@ -10,6 +10,7 @@ from __future__ import annotations from dataclasses import dataclass +from ..._typing import Magnitude from ...compat import HAS_NUMPY, exp, log from ..plain import ScaleConverter @@ -24,7 +25,7 @@ class OffsetConverter(ScaleConverter): def is_multiplicative(self): return self.offset == 0 - def to_reference(self, value, inplace=False): + def to_reference(self, value: Magnitude, inplace: bool = False) -> Magnitude: if inplace: value *= self.scale value += self.offset @@ -33,7 +34,7 @@ class OffsetConverter(ScaleConverter): return value - def from_reference(self, value, inplace=False): + def from_reference(self, value: Magnitude, inplace: bool = False) -> Magnitude: if inplace: value -= self.offset value /= self.scale @@ -66,6 +67,7 @@ class LogarithmicConverter(ScaleConverter): controls if computation is done in place """ + # TODO: Can I use PintScalar here? logbase: float logfactor: float @@ -77,7 +79,7 @@ class LogarithmicConverter(ScaleConverter): def is_logarithmic(self): return True - def from_reference(self, value, inplace=False): + def from_reference(self, value: Magnitude, inplace: bool = False) -> Magnitude: """Converts value from the reference unit to the logarithmic unit dBm <------ mW @@ -95,7 +97,7 @@ class LogarithmicConverter(ScaleConverter): return value - def to_reference(self, value, inplace=False): + def to_reference(self, value: Magnitude, inplace: bool = False) -> Magnitude: """Converts value to the reference unit from the logarithmic unit dBm ------> mW diff --git a/pint/facets/nonmultiplicative/objects.py b/pint/facets/nonmultiplicative/objects.py index 7f9064d..0ab743e 100644 --- a/pint/facets/nonmultiplicative/objects.py +++ b/pint/facets/nonmultiplicative/objects.py @@ -40,7 +40,7 @@ class NonMultiplicativeQuantity(PlainQuantity): self._get_unit_definition(d).reference == offset_unit_dim for d in deltas ) - def _ok_for_muldiv(self, no_offset_units=None) -> bool: + def _ok_for_muldiv(self, no_offset_units: int | None = None) -> bool: """Checks if PlainQuantity object can be multiplied or divided""" is_ok = True diff --git a/pint/facets/plain/definitions.py b/pint/facets/plain/definitions.py index eb45db1..79a44f1 100644 --- a/pint/facets/plain/definitions.py +++ b/pint/facets/plain/definitions.py @@ -13,8 +13,9 @@ import numbers import typing as ty from dataclasses import dataclass from functools import cached_property -from typing import Callable +from typing import Callable, Any +from ..._typing import Magnitude from ... import errors from ...converters import Converter from ...util import UnitsContainer @@ -23,7 +24,7 @@ from ...util import UnitsContainer class NotNumeric(Exception): """Internal exception. Do not expose outside Pint""" - def __init__(self, value): + def __init__(self, value: Any): self.value = value @@ -115,18 +116,26 @@ class UnitDefinition(errors.WithDefErr): #: canonical name of the unit name: str #: canonical symbol - defined_symbol: ty.Optional[str] + defined_symbol: str | None #: additional names for the same unit - aliases: ty.Tuple[str, ...] + aliases: tuple[str] #: A functiont that converts a value in these units into the reference units - converter: ty.Optional[ty.Union[Callable, Converter]] + converter: Callable[ + [ + Magnitude, + ], + Magnitude, + ] | Converter | None #: Reference units. - reference: ty.Optional[UnitsContainer] + reference: UnitsContainer | None def __post_init__(self): if not errors.is_valid_unit_name(self.name): raise self.def_err(errors.MSG_INVALID_UNIT_NAME) + # TODO: check why reference: UnitsContainer | None + assert isinstance(self.reference, UnitsContainer) + if not any(map(errors.is_dim, self.reference.keys())): invalid = tuple( itertools.filterfalse(errors.is_valid_unit_name, self.reference.keys()) @@ -180,14 +189,20 @@ class UnitDefinition(errors.WithDefErr): @property def is_base(self) -> bool: """Indicates if it is a base unit.""" + + # TODO: why is this here return self._is_base @property def is_multiplicative(self) -> bool: + # TODO: Check how to avoid this check + assert isinstance(self.converter, Converter) return self.converter.is_multiplicative @property def is_logarithmic(self) -> bool: + # TODO: Check how to avoid this check + assert isinstance(self.converter, Converter) return self.converter.is_logarithmic @property @@ -272,7 +287,7 @@ class ScaleConverter(Converter): scale: float - def to_reference(self, value, inplace=False): + def to_reference(self, value: Magnitude, inplace: bool = False) -> Magnitude: if inplace: value *= self.scale else: @@ -280,7 +295,7 @@ class ScaleConverter(Converter): return value - def from_reference(self, value, inplace=False): + def from_reference(self, value: Magnitude, inplace: bool = False) -> Magnitude: if inplace: value /= self.scale else: diff --git a/pint/facets/plain/objects.py b/pint/facets/plain/objects.py index 5b2837b..a868c7f 100644 --- a/pint/facets/plain/objects.py +++ b/pint/facets/plain/objects.py @@ -11,4 +11,4 @@ from __future__ import annotations from .quantity import PlainQuantity from .unit import PlainUnit, UnitsContainer -__all__ = [PlainUnit, PlainQuantity, UnitsContainer] +__all__ = ["PlainUnit", "PlainQuantity", "UnitsContainer"] diff --git a/pint/facets/system/definitions.py b/pint/facets/system/definitions.py index 8243324..893c510 100644 --- a/pint/facets/system/definitions.py +++ b/pint/facets/system/definitions.py @@ -8,9 +8,10 @@ from __future__ import annotations -import typing as ty +from typing import Iterable from dataclasses import dataclass +from ..._typing import Self from ... import errors @@ -23,7 +24,7 @@ class BaseUnitRule: new_unit_name: str #: name of the unit to be kicked out to make room for the new base uni #: If None, the current base unit with the same dimensionality will be used - old_unit_name: ty.Optional[str] = None + old_unit_name: str | None = None # Instead of defining __post_init__ here, # it will be added to the container class @@ -38,13 +39,16 @@ class SystemDefinition(errors.WithDefErr): #: name of the system name: str #: unit groups that will be included within the system - using_group_names: ty.Tuple[str, ...] + using_group_names: tuple[str] #: rules to define new base unit within the system. - rules: ty.Tuple[BaseUnitRule, ...] + rules: tuple[BaseUnitRule] @classmethod - def from_lines(cls, lines, non_int_type): + def from_lines( + cls: type[Self], lines: Iterable[str], non_int_type: type + ) -> Self | None: # TODO: this is to keep it backwards compatible + # TODO: check when is None returned. from ...delegates import ParserConfig, txt_defparser cfg = ParserConfig(non_int_type) @@ -55,7 +59,8 @@ class SystemDefinition(errors.WithDefErr): return definition @property - def unit_replacements(self) -> ty.Tuple[ty.Tuple[str, str], ...]: + def unit_replacements(self) -> tuple[tuple[str, str | None]]: + # TODO: check if None can be dropped. return tuple((el.new_unit_name, el.old_unit_name) for el in self.rules) def __post_init__(self): diff --git a/pint/facets/system/objects.py b/pint/facets/system/objects.py index 829fb5c..7af65a6 100644 --- a/pint/facets/system/objects.py +++ b/pint/facets/system/objects.py @@ -9,6 +9,12 @@ from __future__ import annotations +import numbers + +from typing import Any, Iterable + +from ..._typing import Self + from ...babel_names import _babel_systems from ...compat import babel_parse from ...util import ( @@ -29,32 +35,28 @@ class System(SharedRegistryObject): The System belongs to one Registry. See SystemDefinition for the definition file syntax. - """ - def __init__(self, name): - """ - :param name: Name of the group - :type name: str - """ + Parameters + ---------- + name + Name of the group. + """ + def __init__(self, name: str): #: Name of the system #: :type: str self.name = name #: Maps root unit names to a dict indicating the new unit and its exponent. - #: :type: dict[str, dict[str, number]]] - self.base_units = {} + self.base_units: dict[str, dict[str, numbers.Number]] = {} #: Derived unit names. - #: :type: set(str) - self.derived_units = set() + self.derived_units: set[str] = set() #: Names of the _used_groups in used by this system. - #: :type: set(str) - self._used_groups = set() + self._used_groups: set[str] = set() - #: :type: frozenset | None - self._computed_members = None + self._computed_members: frozenset[str] | None = None # Add this system to the system dictionary self._REGISTRY._systems[self.name] = self @@ -62,7 +64,7 @@ class System(SharedRegistryObject): def __dir__(self): return list(self.members) - def __getattr__(self, item): + def __getattr__(self, item: str) -> Any: getattr_maybe_raise(self, item) u = getattr(self._REGISTRY, self.name + "_" + item, None) if u is not None: @@ -93,19 +95,19 @@ class System(SharedRegistryObject): """Invalidate computed members in this Group and all parent nodes.""" self._computed_members = None - def add_groups(self, *group_names): + def add_groups(self, *group_names: str) -> None: """Add groups to group.""" self._used_groups |= set(group_names) self.invalidate_members() - def remove_groups(self, *group_names): + def remove_groups(self, *group_names: str) -> None: """Remove groups from group.""" self._used_groups -= set(group_names) self.invalidate_members() - def format_babel(self, locale): + def format_babel(self, locale: str) -> str: """translate the name of the system.""" if locale and self.name in _babel_systems: name = _babel_systems[self.name] @@ -114,8 +116,12 @@ class System(SharedRegistryObject): return self.name @classmethod - def from_lines(cls, lines, get_root_func, non_int_type=float): - system_definition = SystemDefinition.from_lines(lines, get_root_func) + def from_lines( + cls: type[Self], lines: Iterable[str], get_root_func, non_int_type: type = float + ) -> Self: + # TODO: we changed something here it used to be + # system_definition = SystemDefinition.from_lines(lines, get_root_func) + system_definition = SystemDefinition.from_lines(lines, non_int_type) return cls.from_definition(system_definition, get_root_func) @classmethod @@ -174,12 +180,12 @@ class System(SharedRegistryObject): class Lister: - def __init__(self, d): + def __init__(self, d: dict[str, Any]): self.d = d - def __dir__(self): + def __dir__(self) -> list[str]: return list(self.d.keys()) - def __getattr__(self, item): + def __getattr__(self, item: str) -> Any: getattr_maybe_raise(self, item) return self.d[item] diff --git a/pint/formatting.py b/pint/formatting.py index dcc8725..637d838 100644 --- a/pint/formatting.py +++ b/pint/formatting.py @@ -13,7 +13,8 @@ from __future__ import annotations import functools import re import warnings -from typing import Callable +from typing import Callable, Iterable, Any +from numbers import Number from .babel_names import _babel_lengths, _babel_units from .compat import babel_parse @@ -21,7 +22,7 @@ from .compat import babel_parse __JOIN_REG_EXP = re.compile(r"{\d*}") -def _join(fmt, iterable): +def _join(fmt: str, iterable: Iterable[Any]): """Join an iterable with the format specified in fmt. The format can be specified in two ways: @@ -55,7 +56,7 @@ def _join(fmt, iterable): _PRETTY_EXPONENTS = "⁰¹²³⁴⁵⁶⁷⁸⁹" -def _pretty_fmt_exponent(num): +def _pretty_fmt_exponent(num: Number) -> str: """Format an number into a pretty printed exponent. Parameters @@ -76,7 +77,7 @@ def _pretty_fmt_exponent(num): #: _FORMATS maps format specifications to the corresponding argument set to #: formatter(). -_FORMATS: dict[str, dict] = { +_FORMATS: dict[str, dict[str, Any]] = { "P": { # Pretty format. "as_ratio": True, "single_denominator": False, @@ -125,7 +126,7 @@ _FORMATS: dict[str, dict] = { _FORMATTERS: dict[str, Callable] = {} -def register_unit_format(name): +def register_unit_format(name: str): """register a function as a new format for units The registered function must have a signature of: @@ -268,18 +269,18 @@ def format_compact(unit, registry, **options): def formatter( - items, - as_ratio=True, - single_denominator=False, - product_fmt=" * ", - division_fmt=" / ", - power_fmt="{} ** {}", - parentheses_fmt="({0})", + items: list[tuple[str, Number]], + as_ratio: bool = True, + single_denominator: bool = False, + product_fmt: str = " * ", + division_fmt: str = " / ", + power_fmt: str = "{} ** {}", + parentheses_fmt: str = "({0})", exp_call=lambda x: f"{x:n}", - locale=None, - babel_length="long", - babel_plural_form="one", - sort=True, + locale: str | None = None, + babel_length: str = "long", + babel_plural_form: str = "one", + sort: bool = True, ): """Format a list of (name, exponent) pairs. diff --git a/pint/pint_eval.py b/pint/pint_eval.py index e776d60..d476eae 100644 --- a/pint/pint_eval.py +++ b/pint/pint_eval.py @@ -11,7 +11,9 @@ from __future__ import annotations import operator import token as tokenlib -import tokenize +from tokenize import TokenInfo + +from typing import Any from .errors import DefinitionSyntaxError @@ -30,7 +32,7 @@ _OP_PRIORITY = { } -def _power(left, right): +def _power(left: Any, right: Any) -> Any: from . import Quantity from .compat import is_duck_array @@ -45,7 +47,19 @@ def _power(left, right): return operator.pow(left, right) -_BINARY_OPERATOR_MAP = { +import typing + +UnaryOpT = typing.Callable[ + [ + Any, + ], + Any, +] +BinaryOpT = typing.Callable[[Any, Any], Any] + +_UNARY_OPERATOR_MAP: dict[str, UnaryOpT] = {"+": lambda x: x, "-": lambda x: x * -1} + +_BINARY_OPERATOR_MAP: dict[str, BinaryOpT] = { "**": _power, "*": operator.mul, "": operator.mul, # operator for implicit ops @@ -56,8 +70,6 @@ _BINARY_OPERATOR_MAP = { "//": operator.floordiv, } -_UNARY_OPERATOR_MAP = {"+": lambda x: x, "-": lambda x: x * -1} - class EvalTreeNode: """Single node within an evaluation tree @@ -68,25 +80,43 @@ class EvalTreeNode: left --> single value """ - def __init__(self, left, operator=None, right=None): + def __init__( + self, + left: EvalTreeNode | TokenInfo, + operator: TokenInfo | None = None, + right: EvalTreeNode | None = None, + ): self.left = left self.operator = operator self.right = right - def to_string(self): + def to_string(self) -> str: # For debugging purposes if self.right: + assert isinstance(self.left, EvalTreeNode), "self.left not EvalTreeNode (1)" comps = [self.left.to_string()] if self.operator: - comps.append(self.operator[1]) + comps.append(self.operator.string) comps.append(self.right.to_string()) elif self.operator: - comps = [self.operator[1], self.left.to_string()] + assert isinstance(self.left, EvalTreeNode), "self.left not EvalTreeNode (2)" + comps = [self.operator.string, self.left.to_string()] else: - return self.left[1] + assert isinstance(self.left, TokenInfo), "self.left not TokenInfo (1)" + return self.left.string return "(%s)" % " ".join(comps) - def evaluate(self, define_op, bin_op=None, un_op=None): + def evaluate( + self, + define_op: typing.Callable[ + [ + Any, + ], + Any, + ], + bin_op: dict[str, BinaryOpT] | None = None, + un_op: dict[str, UnaryOpT] | None = None, + ): """Evaluate node. Parameters @@ -107,17 +137,22 @@ class EvalTreeNode: un_op = un_op or _UNARY_OPERATOR_MAP if self.right: + assert isinstance(self.left, EvalTreeNode), "self.left not EvalTreeNode (3)" # binary or implicit operator - op_text = self.operator[1] if self.operator else "" + op_text = self.operator.string if self.operator else "" if op_text not in bin_op: - raise DefinitionSyntaxError('missing binary operator "%s"' % op_text) - left = self.left.evaluate(define_op, bin_op, un_op) - return bin_op[op_text](left, self.right.evaluate(define_op, bin_op, un_op)) + raise DefinitionSyntaxError(f"missing binary operator '{op_text}'") + + return bin_op[op_text]( + self.left.evaluate(define_op, bin_op, un_op), + self.right.evaluate(define_op, bin_op, un_op), + ) elif self.operator: + assert isinstance(self.left, EvalTreeNode), "self.left not EvalTreeNode (4)" # unary operator - op_text = self.operator[1] + op_text = self.operator.string if op_text not in un_op: - raise DefinitionSyntaxError('missing unary operator "%s"' % op_text) + raise DefinitionSyntaxError(f"missing unary operator '{op_text}'") return un_op[op_text](self.left.evaluate(define_op, bin_op, un_op)) # single value @@ -127,13 +162,13 @@ class EvalTreeNode: from collections.abc import Iterable -def build_eval_tree( - tokens: Iterable[tokenize.TokenInfo], - op_priority=None, - index=0, - depth=0, - prev_op=None, -) -> tuple[EvalTreeNode | None, int] | EvalTreeNode: +def _build_eval_tree( + tokens: list[TokenInfo], + op_priority: dict[str, int], + index: int = 0, + depth: int = 0, + prev_op: str = "", +) -> tuple[EvalTreeNode, int]: """Build an evaluation tree from a set of tokens. Params: @@ -153,14 +188,12 @@ def build_eval_tree( 5) Combine left side, operator, and right side into a new left side 6) Go back to step #2 - """ - - if op_priority is None: - op_priority = _OP_PRIORITY + Raises + ------ + DefinitionSyntaxError + If there is a syntax error. - if depth == 0 and prev_op is None: - # ensure tokens is list so we can access by index - tokens = list(tokens) + """ result = None @@ -171,19 +204,21 @@ def build_eval_tree( if token_type == tokenlib.OP: if token_text == ")": - if prev_op is None: + if prev_op == "": raise DefinitionSyntaxError( - "unopened parentheses in tokens: %s" % current_token + f"unopened parentheses in tokens: {current_token}" ) elif prev_op == "(": # close parenthetical group + assert result is not None return result, index else: # parenthetical group ending, but we need to close sub-operations within group + assert result is not None return result, index - 1 elif token_text == "(": # gather parenthetical group - right, index = build_eval_tree( + right, index = _build_eval_tree( tokens, op_priority, index + 1, 0, token_text ) if not tokens[index][1] == ")": @@ -208,7 +243,7 @@ def build_eval_tree( # previous operator is higher priority, so end previous binary op return result, index - 1 # get right side of binary op - right, index = build_eval_tree( + right, index = _build_eval_tree( tokens, op_priority, index + 1, depth + 1, token_text ) result = EvalTreeNode( @@ -216,7 +251,7 @@ def build_eval_tree( ) else: # unary operator - right, index = build_eval_tree( + right, index = _build_eval_tree( tokens, op_priority, index + 1, depth + 1, "unary" ) result = EvalTreeNode(left=right, operator=current_token) @@ -227,7 +262,7 @@ def build_eval_tree( # previous operator is higher priority than implicit, so end # previous binary op return result, index - 1 - right, index = build_eval_tree( + right, index = _build_eval_tree( tokens, op_priority, index, depth + 1, "" ) result = EvalTreeNode(left=result, right=right) @@ -240,13 +275,57 @@ def build_eval_tree( raise DefinitionSyntaxError("unclosed parentheses in tokens") if depth > 0 or prev_op: # have to close recursion + assert result is not None return result, index else: # recursion all closed, so just return the final result - return result + assert result is not None + return result, -1 if index + 1 >= len(tokens): # should hit ENDMARKER before this ever happens raise DefinitionSyntaxError("unexpected end to tokens") index += 1 + + +def build_eval_tree( + tokens: Iterable[TokenInfo], + op_priority: dict[str, int] | None = None, +) -> EvalTreeNode: + """Build an evaluation tree from a set of tokens. + + Params: + Index, depth, and prev_op used recursively, so don't touch. + Tokens is an iterable of tokens from an expression to be evaluated. + + Transform the tokens from an expression into a recursive parse tree, following order + of operations. Operations can include binary ops (3 + 4), implicit ops (3 kg), or + unary ops (-1). + + General Strategy: + 1) Get left side of operator + 2) If no tokens left, return final result + 3) Get operator + 4) Use recursion to create tree starting at token on right side of operator (start at step #1) + 4.1) If recursive call encounters an operator with lower or equal priority to step #2, exit recursion + 5) Combine left side, operator, and right side into a new left side + 6) Go back to step #2 + + Raises + ------ + DefinitionSyntaxError + If there is a syntax error. + + """ + + if op_priority is None: + op_priority = _OP_PRIORITY + + if not isinstance(tokens, list): + # ensure tokens is list so we can access by index + tokens = list(tokens) + + result, _ = _build_eval_tree(tokens, op_priority, 0, 0) + + return result diff --git a/pint/testsuite/test_compat_downcast.py b/pint/testsuite/test_compat_downcast.py index 4ca611d..ed43e94 100644 --- a/pint/testsuite/test_compat_downcast.py +++ b/pint/testsuite/test_compat_downcast.py @@ -38,7 +38,7 @@ def q_base(local_registry): # Define identity function for use in tests -def identity(ureg, x): +def id_matrix(ureg, x): return x @@ -63,17 +63,17 @@ def array(request): @pytest.mark.parametrize( "op, magnitude_op, unit_op", [ - pytest.param(identity, identity, identity, id="identity"), + pytest.param(id_matrix, id_matrix, id_matrix, id="identity"), pytest.param( lambda ureg, x: x + 1 * ureg.m, lambda ureg, x: x + 1, - identity, + id_matrix, id="addition", ), pytest.param( lambda ureg, x: x - 20 * ureg.cm, lambda ureg, x: x - 0.2, - identity, + id_matrix, id="subtraction", ), pytest.param( @@ -84,7 +84,7 @@ def array(request): ), pytest.param( lambda ureg, x: x / (1 * ureg.s), - identity, + id_matrix, lambda ureg, u: u / ureg.s, id="division", ), @@ -94,17 +94,17 @@ def array(request): WR(lambda u: u**2), id="square", ), - pytest.param(WR(lambda x: x.T), WR(lambda x: x.T), identity, id="transpose"), - pytest.param(WR(np.mean), WR(np.mean), identity, id="mean ufunc"), - pytest.param(WR(np.sum), WR(np.sum), identity, id="sum ufunc"), + pytest.param(WR(lambda x: x.T), WR(lambda x: x.T), id_matrix, id="transpose"), + pytest.param(WR(np.mean), WR(np.mean), id_matrix, id="mean ufunc"), + pytest.param(WR(np.sum), WR(np.sum), id_matrix, id="sum ufunc"), pytest.param(WR(np.sqrt), WR(np.sqrt), WR(lambda u: u**0.5), id="sqrt ufunc"), pytest.param( WR(lambda x: np.reshape(x, (25,))), WR(lambda x: np.reshape(x, (25,))), - identity, + id_matrix, id="reshape function", ), - pytest.param(WR(np.amax), WR(np.amax), identity, id="amax function"), + pytest.param(WR(np.amax), WR(np.amax), id_matrix, id="amax function"), ], ) def test_univariate_op_consistency( diff --git a/pint/util.py b/pint/util.py index 28710e7..807c3ac 100644 --- a/pint/util.py +++ b/pint/util.py @@ -14,48 +14,82 @@ import logging import math import operator import re -from collections.abc import Mapping +from collections.abc import Mapping, Iterable, Iterator from fractions import Fraction from functools import lru_cache, partial from logging import NullHandler from numbers import Number from token import NAME, NUMBER -from typing import TYPE_CHECKING, ClassVar +import tokenize +from typing import ( + TYPE_CHECKING, + ClassVar, + TypeAlias, + Callable, + TypeVar, + Hashable, + Generator, + Any, +) from .compat import NUMERIC_TYPES, tokenizer from .errors import DefinitionSyntaxError from .formatting import format_unit from .pint_eval import build_eval_tree +from ._typing import PintScalar + if TYPE_CHECKING: - from ._typing import Quantity, UnitLike + from ._typing import Quantity, UnitLike, Self from .registry import UnitRegistry + logger = logging.getLogger(__name__) logger.addHandler(NullHandler()) +T = TypeVar("T") +TH = TypeVar("TH", bound=Hashable) +ItMatrix: TypeAlias = Iterable[Iterable[PintScalar]] +Matrix: TypeAlias = list[list[PintScalar]] + + +def _noop(x: T) -> T: + return x + def matrix_to_string( - matrix, row_headers=None, col_headers=None, fmtfun=lambda x: str(int(x)) -): - """Takes a 2D matrix (as nested list) and returns a string. + matrix: ItMatrix, + row_headers: Iterable[str] | None = None, + col_headers: Iterable[str] | None = None, + fmtfun: Callable[ + [ + PintScalar, + ], + str, + ] = "{:0.0f}".format, +) -> str: + """Return a string representation of a matrix. Parameters ---------- - matrix : - - row_headers : - (Default value = None) - col_headers : - (Default value = None) - fmtfun : - (Default value = lambda x: str(int(x))) + matrix + A matrix given as an iterable of an iterable of numbers. + row_headers + An iterable of strings to serve as row headers. + (default = None, meaning no row headers are printed.) + col_headers + An iterable of strings to serve as column headers. + (default = None, meaning no col headers are printed.) + fmtfun + A callable to convert a number into string. + (default = `"{:0.0f}".format`) Returns ------- - + str + String representation of the matrix. """ - ret = [] + ret: list[str] = [] if col_headers: ret.append(("\t" if row_headers else "") + "\t".join(col_headers)) if row_headers: @@ -69,99 +103,124 @@ def matrix_to_string( return "\n".join(ret) -def transpose(matrix): - """Takes a 2D matrix (as nested list) and returns the transposed version. +def transpose(matrix: ItMatrix) -> Matrix: + """Return the transposed version of a matrix. Parameters ---------- - matrix : - + matrix + A matrix given as an iterable of an iterable of numbers. Returns ------- - + Matrix + The transposed version of the matrix. """ return [list(val) for val in zip(*matrix)] -def column_echelon_form(matrix, ntype=Fraction, transpose_result=False): - """Calculates the column echelon form using Gaussian elimination. +def matrix_apply( + matrix: ItMatrix, + func: Callable[ + [ + PintScalar, + ], + PintScalar, + ], +) -> Matrix: + """Apply a function to individual elements within a matrix. Parameters ---------- - matrix : - a 2D matrix as nested list. - ntype : - the numerical type to use in the calculation. (Default value = Fraction) - transpose_result : - indicates if the returned matrix should be transposed. (Default value = False) + matrix + A matrix given as an iterable of an iterable of numbers. + func + A callable that converts a number to another. Returns ------- - type - column echelon form, transformed identity matrix, swapped rows - + A new matrix in which each element has been replaced by new one. """ - lead = 0 + return [[func(x) for x in row] for row in matrix] + + +def column_echelon_form( + matrix: ItMatrix, ntype: type = Fraction, transpose_result: bool = False +) -> tuple[Matrix, Matrix, list[int]]: + """Calculate the column echelon form using Gaussian elimination. - M = transpose(matrix) + Parameters + ---------- + matrix + A 2D matrix as nested list. + ntype + The numerical type to use in the calculation. + (default = Fraction) + transpose_result + Indicates if the returned matrix should be transposed. + (default = False) - _transpose = transpose if transpose_result else lambda x: x + Returns + ------- + ech_matrix + Column echelon form. + id_matrix + Transformed identity matrix. + swapped + Swapped rows. + """ - rows, cols = len(M), len(M[0]) + _transpose = transpose if transpose_result else _noop - new_M = [] - for row in M: - r = [] - for x in row: - if isinstance(x, float): - x = ntype.from_float(x) - else: - x = ntype(x) - r.append(x) - new_M.append(r) - M = new_M + ech_matrix = matrix_apply( + transpose(matrix), + lambda x: ntype.from_float(x) if isinstance(x, float) else ntype(x), # type: ignore + ) + rows, cols = len(ech_matrix), len(ech_matrix[0]) # M = [[ntype(x) for x in row] for row in M] - I = [ # noqa: E741 + id_matrix: list[list[PintScalar]] = [ # noqa: E741 [ntype(1) if n == nc else ntype(0) for nc in range(rows)] for n in range(rows) ] - swapped = [] + swapped: list[int] = [] + lead = 0 for r in range(rows): if lead >= cols: - return _transpose(M), _transpose(I), swapped - i = r - while M[i][lead] == 0: - i += 1 - if i != rows: + return _transpose(ech_matrix), _transpose(id_matrix), swapped + s = r + while ech_matrix[s][lead] == 0: # type: ignore + s += 1 + if s != rows: continue - i = r + s = r lead += 1 if cols == lead: - return _transpose(M), _transpose(I), swapped + return _transpose(ech_matrix), _transpose(id_matrix), swapped - M[i], M[r] = M[r], M[i] - I[i], I[r] = I[r], I[i] + ech_matrix[s], ech_matrix[r] = ech_matrix[r], ech_matrix[s] + id_matrix[s], id_matrix[r] = id_matrix[r], id_matrix[s] - swapped.append(i) - lv = M[r][lead] - M[r] = [mrx / lv for mrx in M[r]] - I[r] = [mrx / lv for mrx in I[r]] + swapped.append(s) + lv = ech_matrix[r][lead] + ech_matrix[r] = [mrx / lv for mrx in ech_matrix[r]] + id_matrix[r] = [mrx / lv for mrx in id_matrix[r]] - for i in range(rows): - if i == r: + for s in range(rows): + if s == r: continue - lv = M[i][lead] - M[i] = [iv - lv * rv for rv, iv in zip(M[r], M[i])] - I[i] = [iv - lv * rv for rv, iv in zip(I[r], I[i])] + lv = ech_matrix[s][lead] + ech_matrix[s] = [ + iv - lv * rv for rv, iv in zip(ech_matrix[r], ech_matrix[s]) + ] + id_matrix[s] = [iv - lv * rv for rv, iv in zip(id_matrix[r], id_matrix[s])] lead += 1 - return _transpose(M), _transpose(I), swapped + return _transpose(ech_matrix), _transpose(id_matrix), swapped -def pi_theorem(quantities, registry=None): +def pi_theorem(quantities: dict[str, Any], registry: UnitRegistry | None = None): """Builds dimensionless quantities using the Buckingham π theorem Parameters @@ -169,7 +228,7 @@ def pi_theorem(quantities, registry=None): quantities : dict mapping between variable name and units registry : - (Default value = None) + (default value = None) Returns ------- @@ -183,7 +242,7 @@ def pi_theorem(quantities, registry=None): dimensions = set() if registry is None: - getdim = lambda x: x + getdim = _noop non_int_type = float else: getdim = registry.get_dimensionality @@ -211,18 +270,18 @@ def pi_theorem(quantities, registry=None): dimensions = list(dimensions) # Calculate dimensionless quantities - M = [ + matrix = [ [dimensionality[dimension] for name, dimensionality in quant] for dimension in dimensions ] - M, identity, pivot = column_echelon_form(M, transpose_result=False) + ech_matrix, id_matrix, pivot = column_echelon_form(matrix, transpose_result=False) # Collect results # Make all numbers integers and minimize the number of negative exponents. # Remove zeros results = [] - for rowm, rowi in zip(M, identity): + for rowm, rowi in zip(ech_matrix, id_matrix): if any(el != 0 for el in rowm): continue max_den = max(f.denominator for f in rowi) @@ -237,7 +296,9 @@ def pi_theorem(quantities, registry=None): return results -def solve_dependencies(dependencies): +def solve_dependencies( + dependencies: dict[TH, set[TH]] +) -> Generator[set[TH], None, None]: """Solve a dependency graph. Parameters @@ -246,12 +307,16 @@ def solve_dependencies(dependencies): dependency dictionary. For each key, the value is an iterable indicating its dependencies. - Returns - ------- - type + Yields + ------ + set iterator of sets, each containing keys of independents tasks dependent only of the previous tasks in the list. + Raises + ------ + ValueError + if a cyclic dependency is found. """ while dependencies: # values not in keys (items without dep) @@ -270,12 +335,37 @@ def solve_dependencies(dependencies): yield t -def find_shortest_path(graph, start, end, path=None): +def find_shortest_path( + graph: dict[TH, set[TH]], start: TH, end: TH, path: list[TH] | None = None +): + """Find shortest path between two nodes within a graph. + + Parameters + ---------- + graph + A graph given as a mapping of nodes + to a set of all connected nodes to it. + start + Starting node. + end + End node. + path + Path to prepend to the one found. + (default = None, empty path.) + + Returns + ------- + list[TH] + The shortest path between two nodes. + """ path = (path or []) + [start] if start == end: return path + + # TODO: raise ValueError when start not in graph if start not in graph: return None + shortest = None for node in graph[start]: if node not in path: @@ -283,10 +373,33 @@ def find_shortest_path(graph, start, end, path=None): if newpath: if not shortest or len(newpath) < len(shortest): shortest = newpath + return shortest -def find_connected_nodes(graph, start, visited=None): +def find_connected_nodes( + graph: dict[TH, set[TH]], start: TH, visited: set[TH] | None = None +) -> set[TH] | None: + """Find all nodes connected to a start node within a graph. + + Parameters + ---------- + graph + A graph given as a mapping of nodes + to a set of all connected nodes to it. + start + Starting node. + visited + Mutable set to collect visited nodes. + (default = None, empty set) + + Returns + ------- + set[TH] + The shortest path between two nodes. + """ + + # TODO: raise ValueError when start not in graph if start not in graph: return None @@ -300,17 +413,17 @@ def find_connected_nodes(graph, start, visited=None): return visited -class udict(dict): +class udict(dict[str, PintScalar]): """Custom dict implementing __missing__.""" - def __missing__(self, key): + def __missing__(self, key: str): return 0 - def copy(self): + def copy(self: Self) -> Self: return udict(self) -class UnitsContainer(Mapping): +class UnitsContainer(Mapping[str, PintScalar]): """The UnitsContainer stores the product of units and their respective exponent and implements the corresponding operations. @@ -318,23 +431,24 @@ class UnitsContainer(Mapping): Parameters ---------- - - Returns - ------- - type - - + non_int_type + Numerical type used for non integer values. """ __slots__ = ("_d", "_hash", "_one", "_non_int_type") - def __init__(self, *args, **kwargs) -> None: + _d: udict + _hash: int | None + _one: PintScalar + _non_int_type: type + + def __init__(self, *args, non_int_type: type | None = None, **kwargs) -> None: if args and isinstance(args[0], UnitsContainer): default_non_int_type = args[0]._non_int_type else: default_non_int_type = float - self._non_int_type = kwargs.pop("non_int_type", default_non_int_type) + self._non_int_type = non_int_type or default_non_int_type if self._non_int_type is float: self._one = 1 @@ -352,10 +466,26 @@ class UnitsContainer(Mapping): d[key] = self._non_int_type(value) self._hash = None - def copy(self): + def copy(self: Self) -> Self: + """Create a copy of this UnitsContainer.""" return self.__copy__() - def add(self, key, value): + def add(self: Self, key: str, value: Number) -> Self: + """Create a new UnitsContainer adding value to + the value existing for a given key. + + Parameters + ---------- + key + unit to which the value will be added. + value + value to be added. + + Returns + ------- + UnitsContainer + A copy of this container. + """ newval = self._d[key] + value new = self.copy() if newval: @@ -365,17 +495,18 @@ class UnitsContainer(Mapping): new._hash = None return new - def remove(self, keys): - """Create a new UnitsContainer purged from given keys. + def remove(self: Self, keys: Iterable[str]) -> Self: + """Create a new UnitsContainer purged from given entries. Parameters ---------- - keys : - + keys + Iterable of keys (units) to remove. Returns ------- - + UnitsContainer + A copy of this container. """ new = self.copy() for k in keys: @@ -383,51 +514,52 @@ class UnitsContainer(Mapping): new._hash = None return new - def rename(self, oldkey, newkey): + def rename(self: Self, oldkey: str, newkey: str) -> Self: """Create a new UnitsContainer in which an entry has been renamed. Parameters ---------- - oldkey : - - newkey : - + oldkey + Existing key (unit). + newkey + New key (unit). Returns ------- - + UnitsContainer + A copy of this container. """ new = self.copy() new._d[newkey] = new._d.pop(oldkey) new._hash = None return new - def __iter__(self): + def __iter__(self) -> Iterator[str]: return iter(self._d) def __len__(self) -> int: return len(self._d) - def __getitem__(self, key): + def __getitem__(self, key: str) -> PintScalar: return self._d[key] - def __contains__(self, key): + def __contains__(self, key: str) -> bool: return key in self._d - def __hash__(self): + def __hash__(self) -> int: if self._hash is None: self._hash = hash(frozenset(self._d.items())) return self._hash # Only needed by pickle protocol 0 and 1 (used by pytables) - def __getstate__(self): + def __getstate__(self) -> tuple[udict, PintScalar, type]: return self._d, self._one, self._non_int_type - def __setstate__(self, state): + def __setstate__(self, state: tuple[udict, PintScalar, type]): self._d, self._one, self._non_int_type = state self._hash = None - def __eq__(self, other) -> bool: + def __eq__(self, other: Any) -> bool: if isinstance(other, UnitsContainer): # UnitsContainer.__hash__(self) is not the same as hash(self); see # ParserHelper.__hash__ and __eq__. @@ -472,7 +604,7 @@ class UnitsContainer(Mapping): out._one = self._one return out - def __mul__(self, other): + def __mul__(self, other: Any): if not isinstance(other, self.__class__): err = "Cannot multiply UnitsContainer by {}" raise TypeError(err.format(type(other))) @@ -488,7 +620,7 @@ class UnitsContainer(Mapping): __rmul__ = __mul__ - def __pow__(self, other): + def __pow__(self, other: Any): if not isinstance(other, NUMERIC_TYPES): err = "Cannot power UnitsContainer by {}" raise TypeError(err.format(type(other))) @@ -499,7 +631,7 @@ class UnitsContainer(Mapping): new._hash = None return new - def __truediv__(self, other): + def __truediv__(self, other: Any): if not isinstance(other, self.__class__): err = "Cannot divide UnitsContainer by {}" raise TypeError(err.format(type(other))) @@ -513,7 +645,7 @@ class UnitsContainer(Mapping): new._hash = None return new - def __rtruediv__(self, other): + def __rtruediv__(self, other: Any): if not isinstance(other, self.__class__) and other != 1: err = "Cannot divide {} by UnitsContainer" raise TypeError(err.format(type(other))) @@ -524,41 +656,48 @@ class UnitsContainer(Mapping): class ParserHelper(UnitsContainer): """The ParserHelper stores in place the product of variables and their respective exponent and implements the corresponding operations. + It also provides a scaling factor. + + For example: + `3 * m ** 2` becomes ParserHelper(3, m=2) + + Briefly is a UnitsContainer with a scaling factor. ParserHelper is a read-only mapping. All operations (even in place ones) + WARNING : The hash value used does not take into account the scale + attribute so be careful if you use it as a dict key and then two unequal + object can have the same hash. + Parameters ---------- - - Returns - ------- - type - WARNING : The hash value used does not take into account the scale - attribute so be careful if you use it as a dict key and then two unequal - object can have the same hash. - + scale + Scaling factor. + (default = 1) + **kwargs + Used to populate the dict of units and exponents. """ __slots__ = ("scale",) - def __init__(self, scale=1, *args, **kwargs): + scale: PintScalar + + def __init__(self, scale: PintScalar = 1, *args, **kwargs): super().__init__(*args, **kwargs) self.scale = scale @classmethod - def from_word(cls, input_word, non_int_type=float): + def from_word(cls, input_word: str, non_int_type: type = float) -> ParserHelper: """Creates a ParserHelper object with a single variable with exponent one. - Equivalent to: ParserHelper({'word': 1}) + Equivalent to: ParserHelper(1, {input_word: 1}) Parameters ---------- - input_word : - - - Returns - ------- + input_word + non_int_type + Numerical type used for non integer values. """ if non_int_type is float: return cls(1, [(input_word, 1)], non_int_type=non_int_type) @@ -567,7 +706,7 @@ class ParserHelper(UnitsContainer): return cls(ONE, [(input_word, ONE)], non_int_type=non_int_type) @classmethod - def eval_token(cls, token, non_int_type=float): + def eval_token(cls, token: tokenize.TokenInfo, non_int_type: type = float): token_type = token.type token_text = token.string if token_type == NUMBER: @@ -585,17 +724,15 @@ class ParserHelper(UnitsContainer): @classmethod @lru_cache - def from_string(cls, input_string, non_int_type=float): + def from_string(cls, input_string: str, non_int_type: type = float) -> ParserHelper: """Parse linear expression mathematical units and return a quantity object. Parameters ---------- - input_string : - - - Returns - ------- + input_string + non_int_type + Numerical type used for non integer values. """ if not input_string: return cls(non_int_type=non_int_type) @@ -656,7 +793,7 @@ class ParserHelper(UnitsContainer): super().__setstate__(state[:-1]) self.scale = state[-1] - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: if isinstance(other, ParserHelper): return self.scale == other.scale and super().__eq__(other) elif isinstance(other, str): @@ -666,7 +803,7 @@ class ParserHelper(UnitsContainer): return self.scale == 1 and super().__eq__(other) - def operate(self, items, op=operator.iadd, cleanup=True): + def operate(self, items, op=operator.iadd, cleanup: bool = True): d = udict(self._d) for key, value in items: d[key] = op(d[key], value) @@ -811,21 +948,22 @@ class SharedRegistryObject: inst._REGISTRY = application_registry.get() return inst - def _check(self, other) -> bool: + def _check(self, other: Any) -> bool: """Check if the other object use a registry and if so that it is the same registry. Parameters ---------- - other : - + other Returns ------- - type - other don't use a registry and raise ValueError if other don't use the - same unit registry. + bool + Raises + ------ + ValueError + if other don't use the same unit registry. """ if self._REGISTRY is getattr(other, "_REGISTRY", None): return True @@ -844,17 +982,17 @@ class PrettyIPython: default_format: str - def _repr_html_(self): + def _repr_html_(self) -> str: if "~" in self.default_format: return f"{self:~H}" return f"{self:H}" - def _repr_latex_(self): + def _repr_latex_(self) -> str: if "~" in self.default_format: return f"${self:~L}$" return f"${self:L}$" - def _repr_pretty_(self, p, cycle): + def _repr_pretty_(self, p, cycle: bool): if "~" in self.default_format: p.text(f"{self:~P}") else: @@ -868,14 +1006,15 @@ def to_units_container( Parameters ---------- - unit_like : - - registry : - (Default value = None) + unit_like + Quantity or Unit to infer the plain units from. + registry + If provided, uses the registry's UnitsContainer and parse_unit_name. If None, + uses the registry attached to unit_like. Returns ------- - + UnitsContainer """ mro = type(unit_like).mro() if UnitsContainer in mro: @@ -902,10 +1041,9 @@ def infer_base_unit( Parameters ---------- - unit_like : Union[UnitLike, Quantity] + unit_like Quantity or Unit to infer the plain units from. - - registry: Optional[UnitRegistry] + registry If provided, uses the registry's UnitsContainer and parse_unit_name. If None, uses the registry attached to unit_like. @@ -940,7 +1078,7 @@ def infer_base_unit( return registry.UnitsContainer(nonzero_dict) -def getattr_maybe_raise(self, item): +def getattr_maybe_raise(obj: Any, item: str): """Helper function invoked at start of all overridden ``__getattr__``. Raise AttributeError if the user tries to ask for a _ or __ attribute, @@ -949,39 +1087,25 @@ def getattr_maybe_raise(self, item): Parameters ---------- - item : string - Item to be found. - - - Returns - ------- + item + attribute to be found. + Raises + ------ + AttributeError """ # Double-underscore attributes are tricky to detect because they are - # automatically prefixed with the class name - which may be a subclass of self + # automatically prefixed with the class name - which may be a subclass of obj if ( item.endswith("__") or len(item.lstrip("_")) == 0 or (item.startswith("_") and not item.lstrip("_")[0].isdigit()) ): - raise AttributeError(f"{self!r} object has no attribute {item!r}") - + raise AttributeError(f"{obj!r} object has no attribute {item!r}") -def iterable(y) -> bool: - """Check whether or not an object can be iterated over. - - Vendored from numpy under the terms of the BSD 3-Clause License. (Copyright - (c) 2005-2019, NumPy Developers.) - - Parameters - ---------- - value : - Input object. - type : - object - y : - """ +def iterable(y: Any) -> bool: + """Check whether or not an object can be iterated over.""" try: iter(y) except TypeError: @@ -989,18 +1113,8 @@ def iterable(y) -> bool: return True -def sized(y) -> bool: - """Check whether or not an object has a defined length. - - Parameters - ---------- - value : - Input object. - type : - object - y : - - """ +def sized(y: Any) -> bool: + """Check whether or not an object has a defined length.""" try: len(y) except TypeError: @@ -1008,7 +1122,7 @@ def sized(y) -> bool: return True -def create_class_with_registry(registry, base_class) -> type: +def create_class_with_registry(registry: UnitRegistry, base_class: type) -> type: """Create new class inheriting from base_class and filling _REGISTRY class attribute with an actual instanced registry. """ -- cgit v1.2.1