summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorHernan Grecco <hgrecco@gmail.com>2023-05-01 10:00:26 -0300
committerHernan Grecco <hgrecco@gmail.com>2023-05-01 19:12:18 -0300
commitd0442e776bff3053ef900143cd64facf173e1650 (patch)
treec25053cb5391121c785222320f1a8829db878a37
parenta814d1bc4a2e1b0786ca1e56e0f9bf7e363434b4 (diff)
downloadpint-d0442e776bff3053ef900143cd64facf173e1650.tar.gz
Typing improvements
While there is still a lot of work to do (mainly in Registry, Quantity, Unit), this large PR makes several changes all around the code. There has not been any intended functional change, but certain typing improvements required minor code refactoring to streamline input and output types of functions. An important experimental idea is the PintScalar and PintArray protocols, and the Magnitude type. This is to overcome the lack of a proper numerical hierarchy in Python.
-rw-r--r--pint/_typing.py46
-rw-r--r--pint/babel_names.py6
-rw-r--r--pint/compat.py82
-rw-r--r--pint/context.py2
-rw-r--r--pint/converters.py16
-rw-r--r--pint/definitions.py22
-rw-r--r--pint/delegates/__init__.py2
-rw-r--r--pint/delegates/base_defparser.py10
-rw-r--r--pint/delegates/txt_defparser/__init__.py4
-rw-r--r--pint/delegates/txt_defparser/block.py19
-rw-r--r--pint/delegates/txt_defparser/common.py6
-rw-r--r--pint/delegates/txt_defparser/context.py78
-rw-r--r--pint/delegates/txt_defparser/defaults.py20
-rw-r--r--pint/delegates/txt_defparser/defparser.py45
-rw-r--r--pint/delegates/txt_defparser/group.py24
-rw-r--r--pint/delegates/txt_defparser/plain.py18
-rw-r--r--pint/delegates/txt_defparser/system.py23
-rw-r--r--pint/errors.py28
-rw-r--r--pint/facets/__init__.py18
-rw-r--r--pint/facets/context/definitions.py16
-rw-r--r--pint/facets/context/objects.py13
-rw-r--r--pint/facets/group/definitions.py15
-rw-r--r--pint/facets/group/objects.py56
-rw-r--r--pint/facets/nonmultiplicative/definitions.py10
-rw-r--r--pint/facets/nonmultiplicative/objects.py2
-rw-r--r--pint/facets/plain/definitions.py31
-rw-r--r--pint/facets/plain/objects.py2
-rw-r--r--pint/facets/system/definitions.py17
-rw-r--r--pint/facets/system/objects.py52
-rw-r--r--pint/formatting.py33
-rw-r--r--pint/pint_eval.py155
-rw-r--r--pint/testsuite/test_compat_downcast.py20
-rw-r--r--pint/util.py504
33 files changed, 905 insertions, 490 deletions
diff --git a/pint/_typing.py b/pint/_typing.py
index 1dc3ea6..5547f85 100644
--- a/pint/_typing.py
+++ b/pint/_typing.py
@@ -1,12 +1,56 @@
from __future__ import annotations
-from typing import TYPE_CHECKING, Any, Callable, TypeVar, Union
+from typing import TYPE_CHECKING, Any, Callable, TypeVar, Union, Protocol
+
+# TODO: Remove when 3.11 becomes minimal version.
+Self = TypeVar("Self")
if TYPE_CHECKING:
from .facets.plain import PlainQuantity as Quantity
from .facets.plain import PlainUnit as Unit
from .util import UnitsContainer
+
+class PintScalar(Protocol):
+ def __add__(self, other: Any) -> Any:
+ ...
+
+ def __sub__(self, other: Any) -> Any:
+ ...
+
+ def __mul__(self, other: Any) -> Any:
+ ...
+
+ def __truediv__(self, other: Any) -> Any:
+ ...
+
+ def __floordiv__(self, other: Any) -> Any:
+ ...
+
+ def __mod__(self, other: Any) -> Any:
+ ...
+
+ def __divmod__(self, other: Any) -> Any:
+ ...
+
+ def __pow__(self, other: Any, modulo: Any) -> Any:
+ ...
+
+
+class PintArray(Protocol):
+ def __len__(self) -> int:
+ ...
+
+ def __getitem__(self, key: Any) -> Any:
+ ...
+
+ def __setitem__(self, key: Any, value: Any) -> None:
+ ...
+
+
+Magnitude = PintScalar | PintArray
+
+
UnitLike = Union[str, "UnitsContainer", "Unit"]
QuantityOrUnitLike = Union["Quantity", UnitLike]
diff --git a/pint/babel_names.py b/pint/babel_names.py
index 09fa046..408ef8f 100644
--- a/pint/babel_names.py
+++ b/pint/babel_names.py
@@ -10,7 +10,7 @@ from __future__ import annotations
from .compat import HAS_BABEL
-_babel_units = dict(
+_babel_units: dict[str, str] = dict(
standard_gravity="acceleration-g-force",
millibar="pressure-millibar",
metric_ton="mass-metric-ton",
@@ -141,6 +141,6 @@ _babel_units = dict(
if not HAS_BABEL:
_babel_units = {}
-_babel_systems = dict(mks="metric", imperial="uksystem", US="ussystem")
+_babel_systems: dict[str, str] = dict(mks="metric", imperial="uksystem", US="ussystem")
-_babel_lengths = ["narrow", "short", "long"]
+_babel_lengths: list[str] = ["narrow", "short", "long"]
diff --git a/pint/compat.py b/pint/compat.py
index ee8d443..f58e9cb 100644
--- a/pint/compat.py
+++ b/pint/compat.py
@@ -17,12 +17,19 @@ from importlib import import_module
from io import BytesIO
from numbers import Number
from collections.abc import Mapping
+from typing import Any, NoReturn, Callable, Generator, Iterable
-def missing_dependency(package, display_name=None):
+def missing_dependency(
+ package: str, display_name: str | None = None
+) -> Callable[..., NoReturn]:
+ """Return a helper function that raises an exception when used.
+
+ It provides a way to delay a missing dependency exception until it is used.
+ """
display_name = display_name or package
- def _inner(*args, **kwargs):
+ def _inner(*args: Any, **kwargs: Any) -> NoReturn:
raise Exception(
"This feature requires %s. Please install it by running:\n"
"pip install %s" % (display_name, package)
@@ -31,7 +38,14 @@ def missing_dependency(package, display_name=None):
return _inner
-def tokenizer(input_string):
+def tokenizer(input_string: str) -> Generator[tokenize.TokenInfo, None, None]:
+ """Tokenize an input string, encoded as UTF-8
+ and skipping the ENCODING token.
+
+ See Also
+ --------
+ tokenize.tokenize
+ """
for tokinfo in tokenize.tokenize(BytesIO(input_string.encode("utf-8")).readline):
if tokinfo.type != tokenize.ENCODING:
yield tokinfo
@@ -154,7 +168,8 @@ else:
from math import log # noqa: F401
if not HAS_BABEL:
- babel_parse = babel_units = missing_dependency("Babel") # noqa: F811
+ babel_parse = missing_dependency("Babel") # noqa: F811
+ babel_units = babel_parse
if not HAS_MIP:
mip_missing = missing_dependency("mip")
@@ -176,6 +191,9 @@ except ImportError:
dask_array = None
+# TODO: merge with upcast_type_map
+
+#: List upcast type names
upcast_type_names = (
"pint_pandas.PintArray",
"pandas.Series",
@@ -186,10 +204,12 @@ upcast_type_names = (
"xarray.core.dataarray.DataArray",
)
-upcast_type_map: Mapping[str : type | None] = {k: None for k in upcast_type_names}
+#: Map type name to the actual type (for upcast types).
+upcast_type_map: Mapping[str, type | None] = {k: None for k in upcast_type_names}
def fully_qualified_name(t: type) -> str:
+ """Return the fully qualified name of a type."""
module = t.__module__
name = t.__qualname__
@@ -200,6 +220,10 @@ def fully_qualified_name(t: type) -> str:
def check_upcast_type(obj: type) -> bool:
+ """Check if the type object is an upcast type."""
+
+ # TODO: merge or unify name with is_upcast_type
+
fqn = fully_qualified_name(obj)
if fqn not in upcast_type_map:
return False
@@ -215,22 +239,17 @@ def check_upcast_type(obj: type) -> bool:
def is_upcast_type(other: type) -> bool:
+ """Check if the type object is an upcast type."""
+
+ # TODO: merge or unify name with check_upcast_type
+
if other in upcast_type_map.values():
return True
return check_upcast_type(other)
-def is_duck_array_type(cls) -> bool:
- """Check if the type object represents a (non-Quantity) duck array type.
-
- Parameters
- ----------
- cls : class
-
- Returns
- -------
- bool
- """
+def is_duck_array_type(cls: type) -> bool:
+ """Check if the type object represents a (non-Quantity) duck array type."""
# TODO (NEP 30): replace duck array check with hasattr(other, "__duckarray__")
return issubclass(cls, ndarray) or (
not hasattr(cls, "_magnitude")
@@ -242,20 +261,21 @@ def is_duck_array_type(cls) -> bool:
)
-def is_duck_array(obj):
+def is_duck_array(obj: type) -> bool:
+ """Check if an object represents a (non-Quantity) duck array type."""
return is_duck_array_type(type(obj))
-def eq(lhs, rhs, check_all: bool):
+def eq(lhs: Any, rhs: Any, check_all: bool) -> bool | Iterable[bool]:
"""Comparison of scalars and arrays.
Parameters
----------
- lhs : object
+ lhs
left-hand side
- rhs : object
+ rhs
right-hand side
- check_all : bool
+ check_all
if True, reduce sequence to single bool;
return True if all the elements are equal.
@@ -269,21 +289,21 @@ def eq(lhs, rhs, check_all: bool):
return out
-def isnan(obj, check_all: bool):
- """Test for NaN or NaT
+def isnan(obj: Any, check_all: bool) -> bool | Iterable[bool]:
+ """Test for NaN or NaT.
Parameters
----------
- obj : object
+ obj
scalar or vector
- check_all : bool
+ check_all
if True, reduce sequence to single bool;
return True if any of the elements are NaN.
Returns
-------
bool or array_like of bool.
- Always return False for non-numeric types.
+ Always return False for non-numeric types.
"""
if is_duck_array_type(type(obj)):
if obj.dtype.kind in "if":
@@ -302,21 +322,21 @@ def isnan(obj, check_all: bool):
return False
-def zero_or_nan(obj, check_all: bool):
- """Test if obj is zero, NaN, or NaT
+def zero_or_nan(obj: Any, check_all: bool) -> bool | Iterable[bool]:
+ """Test if obj is zero, NaN, or NaT.
Parameters
----------
- obj : object
+ obj
scalar or vector
- check_all : bool
+ check_all
if True, reduce sequence to single bool;
return True if all the elements are zero, NaN, or NaT.
Returns
-------
bool or array_like of bool.
- Always return False for non-numeric types.
+ Always return False for non-numeric types.
"""
out = eq(obj, 0, False) + isnan(obj, False)
if check_all and is_duck_array_type(type(out)):
diff --git a/pint/context.py b/pint/context.py
index 4839926..6c74f65 100644
--- a/pint/context.py
+++ b/pint/context.py
@@ -18,3 +18,5 @@ if TYPE_CHECKING:
#: Regex to match the header parts of a context.
#: Regex to match variable names in an equation.
+
+# TODO: delete this file
diff --git a/pint/converters.py b/pint/converters.py
index 9b8513f..9494ad1 100644
--- a/pint/converters.py
+++ b/pint/converters.py
@@ -13,6 +13,10 @@ from __future__ import annotations
from dataclasses import dataclass
from dataclasses import fields as dc_fields
+from typing import Any
+
+from ._typing import Self, Magnitude
+
from .compat import HAS_NUMPY, exp, log # noqa: F401
@@ -24,17 +28,17 @@ class Converter:
_param_names_to_subclass = {}
@property
- def is_multiplicative(self):
+ def is_multiplicative(self) -> bool:
return True
@property
- def is_logarithmic(self):
+ def is_logarithmic(self) -> bool:
return False
- def to_reference(self, value, inplace=False):
+ def to_reference(self, value: Magnitude, inplace: bool = False) -> Magnitude:
return value
- def from_reference(self, value, inplace=False):
+ def from_reference(self, value: Magnitude, inplace: bool = False) -> Magnitude:
return value
def __init_subclass__(cls, **kwargs):
@@ -43,7 +47,7 @@ class Converter:
cls._subclasses.append(cls)
@classmethod
- def get_field_names(cls, new_cls):
+ def get_field_names(cls, new_cls) -> frozenset[str]:
return frozenset(p.name for p in dc_fields(new_cls))
@classmethod
@@ -51,7 +55,7 @@ class Converter:
return None
@classmethod
- def from_arguments(cls, **kwargs):
+ def from_arguments(cls: type[Self], **kwargs: Any) -> Self:
kwk = frozenset(kwargs.keys())
try:
new_cls = cls._param_names_to_subclass[kwk]
diff --git a/pint/definitions.py b/pint/definitions.py
index 789d9e3..ce89e94 100644
--- a/pint/definitions.py
+++ b/pint/definitions.py
@@ -8,6 +8,8 @@
:license: BSD, see LICENSE for more details.
"""
+from __future__ import annotations
+
from . import errors
from ._vendor import flexparser as fp
from .delegates import ParserConfig, txt_defparser
@@ -17,12 +19,28 @@ class Definition:
"""This is kept for backwards compatibility"""
@classmethod
- def from_string(cls, s: str, non_int_type=float):
+ def from_string(cls, input_string: str, non_int_type: type = float) -> Definition:
+ """Parse a string into a definition object.
+
+ Parameters
+ ----------
+ input_string
+ Single line string.
+ non_int_type
+ Numerical type used for non integer values.
+
+ Raises
+ ------
+ DefinitionSyntaxError
+ If a syntax error was found.
+ """
cfg = ParserConfig(non_int_type)
parser = txt_defparser.DefParser(cfg, None)
- pp = parser.parse_string(s)
+ pp = parser.parse_string(input_string)
for definition in parser.iter_parsed_project(pp):
if isinstance(definition, Exception):
raise errors.DefinitionSyntaxError(str(definition))
if not isinstance(definition, (fp.BOS, fp.BOF)):
return definition
+
+ # TODO: What shall we do in this return path.
diff --git a/pint/delegates/__init__.py b/pint/delegates/__init__.py
index 363ef9c..b2eb9a3 100644
--- a/pint/delegates/__init__.py
+++ b/pint/delegates/__init__.py
@@ -11,4 +11,4 @@
from . import txt_defparser
from .base_defparser import ParserConfig, build_disk_cache_class
-__all__ = [txt_defparser, ParserConfig, build_disk_cache_class]
+__all__ = ["txt_defparser", "ParserConfig", "build_disk_cache_class"]
diff --git a/pint/delegates/base_defparser.py b/pint/delegates/base_defparser.py
index 88d9d37..9e784ac 100644
--- a/pint/delegates/base_defparser.py
+++ b/pint/delegates/base_defparser.py
@@ -14,7 +14,6 @@ import functools
import itertools
import numbers
import pathlib
-import typing as ty
from dataclasses import dataclass, field
from pint import errors
@@ -27,10 +26,10 @@ from .._vendor import flexparser as fp
@dataclass(frozen=True)
class ParserConfig:
- """Configuration used by the parser."""
+ """Configuration used by the parser in Pint."""
#: Indicates the output type of non integer numbers.
- non_int_type: ty.Type[numbers.Number] = float
+ non_int_type: type[numbers.Number] = float
def to_scaled_units_container(self, s: str):
return ParserHelper.from_string(s, self.non_int_type)
@@ -67,6 +66,11 @@ class ParserConfig:
return val.scale
+@dataclass(frozen=True)
+class PintParsedStatement(fp.ParsedStatement[ParserConfig]):
+ """A parsed statement for pint, specialized in the actual config."""
+
+
@functools.lru_cache
def build_disk_cache_class(non_int_type: type):
"""Build disk cache class, taking into account the non_int_type."""
diff --git a/pint/delegates/txt_defparser/__init__.py b/pint/delegates/txt_defparser/__init__.py
index 5572ca1..49e4a0b 100644
--- a/pint/delegates/txt_defparser/__init__.py
+++ b/pint/delegates/txt_defparser/__init__.py
@@ -11,4 +11,6 @@
from .defparser import DefParser
-__all__ = [DefParser]
+__all__ = [
+ "DefParser",
+]
diff --git a/pint/delegates/txt_defparser/block.py b/pint/delegates/txt_defparser/block.py
index 20ebcba..e8d8aa4 100644
--- a/pint/delegates/txt_defparser/block.py
+++ b/pint/delegates/txt_defparser/block.py
@@ -17,11 +17,14 @@ from __future__ import annotations
from dataclasses import dataclass
+from typing import Generic, TypeVar
+
+from ..base_defparser import PintParsedStatement, ParserConfig
from ..._vendor import flexparser as fp
@dataclass(frozen=True)
-class EndDirectiveBlock(fp.ParsedStatement):
+class EndDirectiveBlock(PintParsedStatement):
"""An EndDirectiveBlock is simply an "@end" statement."""
@classmethod
@@ -31,8 +34,16 @@ class EndDirectiveBlock(fp.ParsedStatement):
return None
+OPST = TypeVar("OPST", bound="PintParsedStatement")
+IPST = TypeVar("IPST", bound="PintParsedStatement")
+
+DefT = TypeVar("DefT")
+
+
@dataclass(frozen=True)
-class DirectiveBlock(fp.Block):
+class DirectiveBlock(
+ Generic[DefT, OPST, IPST], fp.Block[OPST, IPST, EndDirectiveBlock, ParserConfig]
+):
"""Directive blocks have a beginning statement starting with a @ character
and ending with an "@end" statement (captured using an EndDirectiveBlock).
@@ -41,5 +52,5 @@ class DirectiveBlock(fp.Block):
closing: EndDirectiveBlock
- def derive_definition(self):
- pass
+ def derive_definition(self) -> DefT:
+ ...
diff --git a/pint/delegates/txt_defparser/common.py b/pint/delegates/txt_defparser/common.py
index 493d0ec..a1195b3 100644
--- a/pint/delegates/txt_defparser/common.py
+++ b/pint/delegates/txt_defparser/common.py
@@ -30,7 +30,7 @@ class DefinitionSyntaxError(errors.DefinitionSyntaxError, fp.ParsingError):
location: str = field(init=False, default="")
- def __str__(self):
+ def __str__(self) -> str:
msg = (
self.msg + "\n " + (self.format_position or "") + " " + (self.raw or "")
)
@@ -38,7 +38,7 @@ class DefinitionSyntaxError(errors.DefinitionSyntaxError, fp.ParsingError):
msg += "\n " + self.location
return msg
- def set_location(self, value):
+ def set_location(self, value: str) -> None:
super().__setattr__("location", value)
@@ -47,7 +47,7 @@ class ImportDefinition(fp.IncludeStatement):
value: str
@property
- def target(self):
+ def target(self) -> str:
return self.value
@classmethod
diff --git a/pint/delegates/txt_defparser/context.py b/pint/delegates/txt_defparser/context.py
index b7e5a67..ce9fc9b 100644
--- a/pint/delegates/txt_defparser/context.py
+++ b/pint/delegates/txt_defparser/context.py
@@ -23,32 +23,32 @@ from dataclasses import dataclass
from ..._vendor import flexparser as fp
from ...facets.context import definitions
-from ..base_defparser import ParserConfig
+from ..base_defparser import ParserConfig, PintParsedStatement
from . import block, common, plain
+# TODO check syntax
+T = ty.TypeVar("T", bound="ForwardRelation | BidirectionalRelation")
-@dataclass(frozen=True)
-class Relation(definitions.Relation):
- @classmethod
- def _from_string_and_context_sep(
- cls, s: str, config: ParserConfig, separator: str
- ) -> fp.FromString[Relation]:
- if separator not in s:
- return None
- if ":" not in s:
- return None
- rel, eq = s.split(":")
+def _from_string_and_context_sep(
+ cls: type[T], s: str, config: ParserConfig, separator: str
+) -> T | None:
+ if separator not in s:
+ return None
+ if ":" not in s:
+ return None
+
+ rel, eq = s.split(":")
- parts = rel.split(separator)
+ parts = rel.split(separator)
- src, dst = (config.to_dimension_container(s) for s in parts)
+ src, dst = (config.to_dimension_container(s) for s in parts)
- return cls(src, dst, eq.strip())
+ return cls(src, dst, eq.strip())
@dataclass(frozen=True)
-class ForwardRelation(fp.ParsedStatement, definitions.ForwardRelation, Relation):
+class ForwardRelation(PintParsedStatement, definitions.ForwardRelation):
"""A relation connecting a dimension to another via a transformation function.
<source dimension> -> <target dimension>: <transformation function>
@@ -58,13 +58,11 @@ class ForwardRelation(fp.ParsedStatement, definitions.ForwardRelation, Relation)
def from_string_and_config(
cls, s: str, config: ParserConfig
) -> fp.FromString[ForwardRelation]:
- return super()._from_string_and_context_sep(s, config, "->")
+ return _from_string_and_context_sep(cls, s, config, "->")
@dataclass(frozen=True)
-class BidirectionalRelation(
- fp.ParsedStatement, definitions.BidirectionalRelation, Relation
-):
+class BidirectionalRelation(PintParsedStatement, definitions.BidirectionalRelation):
"""A bidirectional relation connecting a dimension to another
via a simple transformation function.
@@ -76,11 +74,11 @@ class BidirectionalRelation(
def from_string_and_config(
cls, s: str, config: ParserConfig
) -> fp.FromString[BidirectionalRelation]:
- return super()._from_string_and_context_sep(s, config, "<->")
+ return _from_string_and_context_sep(cls, s, config, "<->")
@dataclass(frozen=True)
-class BeginContext(fp.ParsedStatement):
+class BeginContext(PintParsedStatement):
"""Beginning of a context directive.
@context[(defaults)] <canonical name> [= <alias>] [= <alias>]
@@ -91,7 +89,7 @@ class BeginContext(fp.ParsedStatement):
)
name: str
- aliases: tuple[str, ...]
+ aliases: tuple[str, ...]
defaults: dict[str, numbers.Number]
@classmethod
@@ -130,7 +128,18 @@ class BeginContext(fp.ParsedStatement):
@dataclass(frozen=True)
-class ContextDefinition(block.DirectiveBlock):
+class ContextDefinition(
+ block.DirectiveBlock[
+ definitions.ContextDefinition,
+ BeginContext,
+ ty.Union[
+ plain.CommentDefinition,
+ BidirectionalRelation,
+ ForwardRelation,
+ plain.UnitDefinition,
+ ],
+ ]
+):
"""Definition of a Context
@context[(defaults)] <canonical name> [= <alias>] [= <alias>]
@@ -169,27 +178,34 @@ class ContextDefinition(block.DirectiveBlock):
]
]
- def derive_definition(self):
+ def derive_definition(self) -> definitions.ContextDefinition:
return definitions.ContextDefinition(
self.name, self.aliases, self.defaults, self.relations, self.redefinitions
)
@property
- def name(self):
+ def name(self) -> str:
+ assert isinstance(self.opening, BeginContext)
return self.opening.name
@property
- def aliases(self):
+ def aliases(self) -> tuple[str, ...]:
+ assert isinstance(self.opening, BeginContext)
return self.opening.aliases
@property
- def defaults(self):
+ def defaults(self) -> dict[str, numbers.Number]:
+ assert isinstance(self.opening, BeginContext)
return self.opening.defaults
@property
- def relations(self):
- return tuple(r for r in self.body if isinstance(r, Relation))
+ def relations(self) -> tuple[BidirectionalRelation | ForwardRelation, ...]:
+ return tuple(
+ r
+ for r in self.body
+ if isinstance(r, (ForwardRelation, BidirectionalRelation))
+ )
@property
- def redefinitions(self):
+ def redefinitions(self) -> tuple[plain.UnitDefinition, ...]:
return tuple(r for r in self.body if isinstance(r, plain.UnitDefinition))
diff --git a/pint/delegates/txt_defparser/defaults.py b/pint/delegates/txt_defparser/defaults.py
index af6e31f..688d90f 100644
--- a/pint/delegates/txt_defparser/defaults.py
+++ b/pint/delegates/txt_defparser/defaults.py
@@ -19,10 +19,11 @@ from dataclasses import dataclass, fields
from ..._vendor import flexparser as fp
from ...facets.plain import definitions
from . import block, plain
+from ..base_defparser import PintParsedStatement
@dataclass(frozen=True)
-class BeginDefaults(fp.ParsedStatement):
+class BeginDefaults(PintParsedStatement):
"""Beginning of a defaults directive.
@defaults
@@ -36,7 +37,16 @@ class BeginDefaults(fp.ParsedStatement):
@dataclass(frozen=True)
-class DefaultsDefinition(block.DirectiveBlock):
+class DefaultsDefinition(
+ block.DirectiveBlock[
+ definitions.DefaultsDefinition,
+ BeginDefaults,
+ ty.Union[
+ plain.CommentDefinition,
+ plain.Equality,
+ ],
+ ]
+):
"""Directive to store values.
@defaults
@@ -55,10 +65,10 @@ class DefaultsDefinition(block.DirectiveBlock):
]
@property
- def _valid_fields(self):
+ def _valid_fields(self) -> tuple[str, ...]:
return tuple(f.name for f in fields(definitions.DefaultsDefinition))
- def derive_definition(self):
+ def derive_definition(self) -> definitions.DefaultsDefinition:
for definition in self.filter_by(plain.Equality):
if definition.lhs not in self._valid_fields:
raise ValueError(
@@ -70,7 +80,7 @@ class DefaultsDefinition(block.DirectiveBlock):
*tuple(self.get_key(key) for key in self._valid_fields)
)
- def get_key(self, key):
+ def get_key(self, key: str) -> str:
for stmt in self.body:
if isinstance(stmt, plain.Equality) and stmt.lhs == key:
return stmt.rhs
diff --git a/pint/delegates/txt_defparser/defparser.py b/pint/delegates/txt_defparser/defparser.py
index 0b99d6d..f1b8e45 100644
--- a/pint/delegates/txt_defparser/defparser.py
+++ b/pint/delegates/txt_defparser/defparser.py
@@ -5,11 +5,28 @@ import typing as ty
from ..._vendor import flexcache as fc
from ..._vendor import flexparser as fp
-from .. import base_defparser
+from ..base_defparser import ParserConfig
from . import block, common, context, defaults, group, plain, system
-class PintRootBlock(fp.RootBlock):
+class PintRootBlock(
+ fp.RootBlock[
+ ty.Union[
+ plain.CommentDefinition,
+ common.ImportDefinition,
+ context.ContextDefinition,
+ defaults.DefaultsDefinition,
+ system.SystemDefinition,
+ group.GroupDefinition,
+ plain.AliasDefinition,
+ plain.DerivedDimensionDefinition,
+ plain.DimensionDefinition,
+ plain.PrefixDefinition,
+ plain.UnitDefinition,
+ ],
+ ParserConfig,
+ ]
+):
body: fp.Multi[
ty.Union[
plain.CommentDefinition,
@@ -27,11 +44,15 @@ class PintRootBlock(fp.RootBlock):
]
+class PintSource(fp.ParsedSource[PintRootBlock, ParserConfig]):
+ """Source code in Pint."""
+
+
class HashTuple(tuple):
pass
-class _PintParser(fp.Parser):
+class _PintParser(fp.Parser[PintRootBlock, ParserConfig]):
"""Parser for the original Pint definition file, with cache."""
_delimiters = {
@@ -46,11 +67,11 @@ class _PintParser(fp.Parser):
_diskcache: fc.DiskCache
- def __init__(self, config: base_defparser.ParserConfig, *args, **kwargs):
+ def __init__(self, config: ParserConfig, *args, **kwargs):
self._diskcache = kwargs.pop("diskcache", None)
super().__init__(config, *args, **kwargs)
- def parse_file(self, path: pathlib.Path) -> fp.ParsedSource:
+ def parse_file(self, path: pathlib.Path) -> PintSource:
if self._diskcache is None:
return super().parse_file(path)
content, basename = self._diskcache.load(path, super().parse_file)
@@ -58,7 +79,13 @@ class _PintParser(fp.Parser):
class DefParser:
- skip_classes = (fp.BOF, fp.BOR, fp.BOS, fp.EOS, plain.CommentDefinition)
+ skip_classes: tuple[type, ...] = (
+ fp.BOF,
+ fp.BOR,
+ fp.BOS,
+ fp.EOS,
+ plain.CommentDefinition,
+ )
def __init__(self, default_config, diskcache):
self._default_config = default_config
@@ -78,6 +105,8 @@ class DefParser:
continue
if isinstance(stmt, common.DefinitionSyntaxError):
+ # TODO: check why this assert fails
+ # assert isinstance(last_location, str)
stmt.set_location(last_location)
raise stmt
elif isinstance(stmt, block.DirectiveBlock):
@@ -101,7 +130,7 @@ class DefParser:
else:
yield stmt
- def parse_file(self, filename: pathlib.Path, cfg=None):
+ def parse_file(self, filename: pathlib.Path, cfg: ParserConfig | None = None):
return fp.parse(
filename,
_PintParser,
@@ -109,7 +138,7 @@ class DefParser:
diskcache=self._diskcache,
)
- def parse_string(self, content: str, cfg=None):
+ def parse_string(self, content: str, cfg: ParserConfig | None = None):
return fp.parse_bytes(
content.encode("utf-8"),
_PintParser,
diff --git a/pint/delegates/txt_defparser/group.py b/pint/delegates/txt_defparser/group.py
index 5be42ac..e96d44b 100644
--- a/pint/delegates/txt_defparser/group.py
+++ b/pint/delegates/txt_defparser/group.py
@@ -23,10 +23,11 @@ from dataclasses import dataclass
from ..._vendor import flexparser as fp
from ...facets.group import definitions
from . import block, common, plain
+from ..base_defparser import PintParsedStatement
@dataclass(frozen=True)
-class BeginGroup(fp.ParsedStatement):
+class BeginGroup(PintParsedStatement):
"""Beginning of a group directive.
@group <name> [using <group 1>, ..., <group N>]
@@ -59,7 +60,16 @@ class BeginGroup(fp.ParsedStatement):
@dataclass(frozen=True)
-class GroupDefinition(block.DirectiveBlock):
+class GroupDefinition(
+ block.DirectiveBlock[
+ definitions.GroupDefinition,
+ BeginGroup,
+ ty.Union[
+ plain.CommentDefinition,
+ plain.UnitDefinition,
+ ],
+ ]
+):
"""Definition of a group.
@group <name> [using <group 1>, ..., <group N>]
@@ -88,19 +98,21 @@ class GroupDefinition(block.DirectiveBlock):
]
]
- def derive_definition(self):
+ def derive_definition(self) -> definitions.GroupDefinition:
return definitions.GroupDefinition(
self.name, self.using_group_names, self.definitions
)
@property
- def name(self):
+ def name(self) -> str:
+ assert isinstance(self.opening, BeginGroup)
return self.opening.name
@property
- def using_group_names(self):
+ def using_group_names(self) -> tuple[str, ...]:
+ assert isinstance(self.opening, BeginGroup)
return self.opening.using_group_names
@property
- def definitions(self) -> ty.Tuple[plain.UnitDefinition, ...]:
+ def definitions(self) -> tuple[plain.UnitDefinition, ...]:
return tuple(el for el in self.body if isinstance(el, plain.UnitDefinition))
diff --git a/pint/delegates/txt_defparser/plain.py b/pint/delegates/txt_defparser/plain.py
index 749e7fd..9c7bd42 100644
--- a/pint/delegates/txt_defparser/plain.py
+++ b/pint/delegates/txt_defparser/plain.py
@@ -29,12 +29,12 @@ from ..._vendor import flexparser as fp
from ...converters import Converter
from ...facets.plain import definitions
from ...util import UnitsContainer
-from ..base_defparser import ParserConfig
+from ..base_defparser import ParserConfig, PintParsedStatement
from . import common
@dataclass(frozen=True)
-class Equality(fp.ParsedStatement, definitions.Equality):
+class Equality(PintParsedStatement, definitions.Equality):
"""An equality statement contains a left and a right hand side separated
lhs and rhs should be space stripped.
@@ -53,7 +53,7 @@ class Equality(fp.ParsedStatement, definitions.Equality):
@dataclass(frozen=True)
-class CommentDefinition(fp.ParsedStatement, definitions.CommentDefinition):
+class CommentDefinition(PintParsedStatement, definitions.CommentDefinition):
"""Comments start with a # character.
# This is a comment.
@@ -63,14 +63,14 @@ class CommentDefinition(fp.ParsedStatement, definitions.CommentDefinition):
"""
@classmethod
- def from_string(cls, s: str) -> fp.FromString[fp.ParsedStatement]:
+ def from_string(cls, s: str) -> fp.FromString[CommentDefinition]:
if not s.startswith("#"):
return None
return cls(s[1:].strip())
@dataclass(frozen=True)
-class PrefixDefinition(fp.ParsedStatement, definitions.PrefixDefinition):
+class PrefixDefinition(PintParsedStatement, definitions.PrefixDefinition):
"""Definition of a prefix::
<prefix>- = <value> [= <symbol>] [= <alias>] [ = <alias> ] [...]
@@ -119,7 +119,7 @@ class PrefixDefinition(fp.ParsedStatement, definitions.PrefixDefinition):
@dataclass(frozen=True)
-class UnitDefinition(fp.ParsedStatement, definitions.UnitDefinition):
+class UnitDefinition(PintParsedStatement, definitions.UnitDefinition):
"""Definition of a unit::
<canonical name> = <relation to another unit or dimension> [= <symbol>] [= <alias>] [ = <alias> ] [...]
@@ -194,7 +194,7 @@ class UnitDefinition(fp.ParsedStatement, definitions.UnitDefinition):
@dataclass(frozen=True)
-class DimensionDefinition(fp.ParsedStatement, definitions.DimensionDefinition):
+class DimensionDefinition(PintParsedStatement, definitions.DimensionDefinition):
"""Definition of a root dimension::
[dimension name]
@@ -221,7 +221,7 @@ class DimensionDefinition(fp.ParsedStatement, definitions.DimensionDefinition):
@dataclass(frozen=True)
class DerivedDimensionDefinition(
- fp.ParsedStatement, definitions.DerivedDimensionDefinition
+ PintParsedStatement, definitions.DerivedDimensionDefinition
):
"""Definition of a derived dimension::
@@ -261,7 +261,7 @@ class DerivedDimensionDefinition(
@dataclass(frozen=True)
-class AliasDefinition(fp.ParsedStatement, definitions.AliasDefinition):
+class AliasDefinition(PintParsedStatement, definitions.AliasDefinition):
"""Additional alias(es) for an already existing unit::
@alias <canonical name or previous alias> = <alias> [ = <alias> ] [...]
diff --git a/pint/delegates/txt_defparser/system.py b/pint/delegates/txt_defparser/system.py
index b21fd7a..4efbb4d 100644
--- a/pint/delegates/txt_defparser/system.py
+++ b/pint/delegates/txt_defparser/system.py
@@ -14,11 +14,12 @@ from dataclasses import dataclass
from ..._vendor import flexparser as fp
from ...facets.system import definitions
+from ..base_defparser import PintParsedStatement
from . import block, common, plain
@dataclass(frozen=True)
-class BaseUnitRule(fp.ParsedStatement, definitions.BaseUnitRule):
+class BaseUnitRule(PintParsedStatement, definitions.BaseUnitRule):
@classmethod
def from_string(cls, s: str) -> fp.FromString[BaseUnitRule]:
if ":" not in s:
@@ -32,7 +33,7 @@ class BaseUnitRule(fp.ParsedStatement, definitions.BaseUnitRule):
@dataclass(frozen=True)
-class BeginSystem(fp.ParsedStatement):
+class BeginSystem(PintParsedStatement):
"""Beginning of a system directive.
@system <name> [using <group 1>, ..., <group N>]
@@ -67,7 +68,13 @@ class BeginSystem(fp.ParsedStatement):
@dataclass(frozen=True)
-class SystemDefinition(block.DirectiveBlock):
+class SystemDefinition(
+ block.DirectiveBlock[
+ definitions.SystemDefinition,
+ BeginSystem,
+ ty.Union[plain.CommentDefinition, BaseUnitRule],
+ ]
+):
"""Definition of a System:
@system <name> [using <group 1>, ..., <group N>]
@@ -92,19 +99,21 @@ class SystemDefinition(block.DirectiveBlock):
opening: fp.Single[BeginSystem]
body: fp.Multi[ty.Union[plain.CommentDefinition, BaseUnitRule]]
- def derive_definition(self):
+ def derive_definition(self) -> definitions.SystemDefinition:
return definitions.SystemDefinition(
self.name, self.using_group_names, self.rules
)
@property
- def name(self):
+ def name(self) -> str:
+ assert isinstance(self.opening, BeginSystem)
return self.opening.name
@property
- def using_group_names(self):
+ def using_group_names(self) -> tuple[str]:
+ assert isinstance(self.opening, BeginSystem)
return self.opening.using_group_names
@property
- def rules(self):
+ def rules(self) -> tuple[BaseUnitRule]:
return tuple(el for el in self.body if isinstance(el, BaseUnitRule))
diff --git a/pint/errors.py b/pint/errors.py
index 8f849da..6cebb21 100644
--- a/pint/errors.py
+++ b/pint/errors.py
@@ -36,18 +36,21 @@ MSG_INVALID_SYSTEM_NAME = (
)
-def is_dim(name):
+def is_dim(name: str) -> bool:
+ """Return True if the name is flanked by square brackets `[` and `]`."""
return name[0] == "[" and name[-1] == "]"
-def is_valid_prefix_name(name):
+def is_valid_prefix_name(name: str) -> bool:
+ """Return True if the name is a valid python identifier or empty."""
return str.isidentifier(name) or name == ""
is_valid_unit_name = is_valid_system_name = is_valid_context_name = str.isidentifier
-def _no_space(name):
+def _no_space(name: str) -> bool:
+    """Return True if the name has no spaces and no surrounding whitespace."""
return name.strip() == name and " " not in name
@@ -58,7 +61,14 @@ is_valid_unit_alias = (
) = is_valid_unit_symbol = is_valid_prefix_symbol = _no_space
-def is_valid_dimension_name(name):
+def is_valid_dimension_name(name: str) -> bool:
+ """Return True if the name is consistent with a dimension name.
+
+ - flanked by square brackets.
+ - empty dimension name or identifier.
+ """
+
+    # TODO: shall we check also for spaces?
return name == "[]" or (
len(name) > 1 and is_dim(name) and str.isidentifier(name[1:-1])
)
@@ -67,8 +77,8 @@ def is_valid_dimension_name(name):
class WithDefErr:
"""Mixing class to make some classes more readable."""
- def def_err(self, msg):
- return DefinitionError(self.name, self.__class__.__name__, msg)
+ def def_err(self, msg: str):
+ return DefinitionError(self.name, self.__class__, msg)
@dataclass(frozen=False)
@@ -81,7 +91,7 @@ class DefinitionError(ValueError, PintError):
"""Raised when a definition is not properly constructed."""
name: str
- definition_type: ty.Type
+ definition_type: type
msg: str
def __str__(self):
@@ -110,7 +120,7 @@ class RedefinitionError(ValueError, PintError):
"""Raised when a unit or prefix is redefined."""
name: str
- definition_type: ty.Type
+ definition_type: type
def __str__(self):
msg = f"Cannot redefine '{self.name}' ({self.definition_type})"
@@ -124,7 +134,7 @@ class RedefinitionError(ValueError, PintError):
class UndefinedUnitError(AttributeError, PintError):
"""Raised when the units are not defined in the unit registry."""
- unit_names: ty.Union[str, ty.Tuple[str, ...]]
+ unit_names: str | tuple[str]
def __str__(self):
if isinstance(self.unit_names, str):
diff --git a/pint/facets/__init__.py b/pint/facets/__init__.py
index 7b24463..750f729 100644
--- a/pint/facets/__init__.py
+++ b/pint/facets/__init__.py
@@ -82,13 +82,13 @@ from .plain import PlainRegistry
from .system import SystemRegistry
__all__ = [
- ContextRegistry,
- DaskRegistry,
- FormattingRegistry,
- GroupRegistry,
- MeasurementRegistry,
- NonMultiplicativeRegistry,
- NumpyRegistry,
- PlainRegistry,
- SystemRegistry,
+ "ContextRegistry",
+ "DaskRegistry",
+ "FormattingRegistry",
+ "GroupRegistry",
+ "MeasurementRegistry",
+ "NonMultiplicativeRegistry",
+ "NumpyRegistry",
+ "PlainRegistry",
+ "SystemRegistry",
]
diff --git a/pint/facets/context/definitions.py b/pint/facets/context/definitions.py
index d9ba473..07eb92f 100644
--- a/pint/facets/context/definitions.py
+++ b/pint/facets/context/definitions.py
@@ -12,7 +12,7 @@ import itertools
import numbers
import re
from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Callable
+from typing import TYPE_CHECKING, Any, Callable, Iterable
from ... import errors
from ..plain import UnitDefinition
@@ -41,7 +41,7 @@ class Relation:
# could be used.
@property
- def variables(self) -> set[str, ...]:
+ def variables(self) -> set[str]:
"""Find all variables names in the equation."""
return set(self._varname_re.findall(self.equation))
@@ -55,7 +55,7 @@ class Relation:
)
@property
- def bidirectional(self):
+ def bidirectional(self) -> bool:
raise NotImplementedError
@@ -92,18 +92,18 @@ class ContextDefinition(errors.WithDefErr):
#: name of the context
name: str
#: other na
- aliases: tuple[str, ...]
+ aliases: tuple[str]
defaults: dict[str, numbers.Number]
- relations: tuple[Relation, ...]
- redefinitions: tuple[UnitDefinition, ...]
+ relations: tuple[Relation]
+ redefinitions: tuple[UnitDefinition]
@property
- def variables(self) -> set[str, ...]:
+ def variables(self) -> set[str]:
"""Return all variable names in all transformations."""
return set().union(*(r.variables for r in self.relations))
@classmethod
- def from_lines(cls, lines, non_int_type):
+ def from_lines(cls, lines: Iterable[str], non_int_type: type):
# TODO: this is to keep it backwards compatible
from ...delegates import ParserConfig, txt_defparser
diff --git a/pint/facets/context/objects.py b/pint/facets/context/objects.py
index 58f8bb8..bec2a43 100644
--- a/pint/facets/context/objects.py
+++ b/pint/facets/context/objects.py
@@ -10,6 +10,7 @@ from __future__ import annotations
import weakref
from collections import ChainMap, defaultdict
+from typing import Any, Iterable
from ...facets.plain import UnitDefinition
from ...util import UnitsContainer, to_units_container
@@ -70,8 +71,8 @@ class Context:
def __init__(
self,
name: str | None = None,
- aliases: tuple[str, ...] = (),
- defaults: dict | None = None,
+ aliases: tuple[str] = tuple(),
+ defaults: dict[str, Any] | None = None,
) -> None:
self.name = name
self.aliases = aliases
@@ -93,7 +94,7 @@ class Context:
self.relation_to_context = weakref.WeakValueDictionary()
@classmethod
- def from_context(cls, context: Context, **defaults) -> Context:
+ def from_context(cls, context: Context, **defaults: Any) -> Context:
"""Creates a new context that shares the funcs dictionary with the
original context. The default values are copied from the original
context and updated with the new defaults.
@@ -122,7 +123,9 @@ class Context:
return context
@classmethod
- def from_lines(cls, lines, to_base_func=None, non_int_type=float) -> Context:
+ def from_lines(
+ cls, lines: Iterable[str], to_base_func=None, non_int_type: type = float
+ ) -> Context:
cd = ContextDefinition.from_lines(lines, non_int_type)
return cls.from_definition(cd, to_base_func)
@@ -273,7 +276,7 @@ class ContextChain(ChainMap):
"""
return self[(src, dst)].transform(src, dst, registry, value)
- def hashable(self):
+ def hashable(self) -> tuple[Any]:
"""Generate a unique hashable and comparable representation of self, which can
be used as a key in a dict. This class cannot define ``__hash__`` because it is
mutable, and the Python interpreter does cache the output of ``__hash__``.
diff --git a/pint/facets/group/definitions.py b/pint/facets/group/definitions.py
index c0abced..48c6f4b 100644
--- a/pint/facets/group/definitions.py
+++ b/pint/facets/group/definitions.py
@@ -8,9 +8,10 @@
from __future__ import annotations
-import typing as ty
+from typing import Iterable
from dataclasses import dataclass
+from ..._typing import Self
from ... import errors
from .. import plain
@@ -22,12 +23,14 @@ class GroupDefinition(errors.WithDefErr):
#: name of the group
name: str
#: unit groups that will be included within the group
- using_group_names: ty.Tuple[str, ...]
+ using_group_names: tuple[str]
#: definitions for the units existing within the group
- definitions: ty.Tuple[plain.UnitDefinition, ...]
+ definitions: tuple[plain.UnitDefinition]
@classmethod
- def from_lines(cls, lines, non_int_type):
+ def from_lines(
+ cls: type[Self], lines: Iterable[str], non_int_type: type
+ ) -> Self | None:
# TODO: this is to keep it backwards compatible
from ...delegates import ParserConfig, txt_defparser
@@ -39,10 +42,10 @@ class GroupDefinition(errors.WithDefErr):
return definition
@property
- def unit_names(self) -> ty.Tuple[str, ...]:
+ def unit_names(self) -> tuple[str]:
return tuple(el.name for el in self.definitions)
- def __post_init__(self):
+ def __post_init__(self) -> None:
if not errors.is_valid_group_name(self.name):
raise self.def_err(errors.MSG_INVALID_GROUP_NAME)
diff --git a/pint/facets/group/objects.py b/pint/facets/group/objects.py
index 558a107..a0a81be 100644
--- a/pint/facets/group/objects.py
+++ b/pint/facets/group/objects.py
@@ -8,6 +8,7 @@
from __future__ import annotations
+from typing import Generator, Iterable
from ...util import SharedRegistryObject, getattr_maybe_raise
from .definitions import GroupDefinition
@@ -23,32 +24,26 @@ class Group(SharedRegistryObject):
The group belongs to one Registry.
See GroupDefinition for the definition file syntax.
- """
- def __init__(self, name):
- """
- :param name: Name of the group. If not given, a root Group will be created.
- :type name: str
- :param groups: dictionary like object groups and system.
- The newly created group will be added after creation.
- :type groups: dict[str | Group]
- """
+ Parameters
+ ----------
+ name
+ If not given, a root Group will be created.
+ """
+ def __init__(self, name: str):
# The name of the group.
- #: type: str
self.name = name
#: Names of the units in this group.
#: :type: set[str]
- self._unit_names = set()
+ self._unit_names: set[str] = set()
#: Names of the groups in this group.
- #: :type: set[str]
- self._used_groups = set()
+ self._used_groups: set[str] = set()
#: Names of the groups in which this group is contained.
- #: :type: set[str]
- self._used_by = set()
+ self._used_by: set[str] = set()
# Add this group to the group dictionary
self._REGISTRY._groups[self.name] = self
@@ -59,8 +54,7 @@ class Group(SharedRegistryObject):
#: A cache of the included units.
#: None indicates that the cache has been invalidated.
- #: :type: frozenset[str] | None
- self._computed_members = None
+ self._computed_members: frozenset[str] | None = None
@property
def members(self):
@@ -70,23 +64,23 @@ class Group(SharedRegistryObject):
"""
if self._computed_members is None:
- self._computed_members = set(self._unit_names)
+ tmp = set(self._unit_names)
for _, group in self.iter_used_groups():
- self._computed_members |= group.members
+ tmp |= group.members
- self._computed_members = frozenset(self._computed_members)
+ self._computed_members = frozenset(tmp)
return self._computed_members
- def invalidate_members(self):
+ def invalidate_members(self) -> None:
"""Invalidate computed members in this Group and all parent nodes."""
self._computed_members = None
d = self._REGISTRY._groups
for name in self._used_by:
d[name].invalidate_members()
- def iter_used_groups(self):
+ def iter_used_groups(self) -> Generator[tuple[str, Group], None, None]:
pending = set(self._used_groups)
d = self._REGISTRY._groups
while pending:
@@ -95,13 +89,13 @@ class Group(SharedRegistryObject):
pending |= group._used_groups
yield name, d[name]
- def is_used_group(self, group_name):
+ def is_used_group(self, group_name: str) -> bool:
for name, _ in self.iter_used_groups():
if name == group_name:
return True
return False
- def add_units(self, *unit_names):
+ def add_units(self, *unit_names: str) -> None:
"""Add units to group."""
for unit_name in unit_names:
self._unit_names.add(unit_name)
@@ -109,17 +103,17 @@ class Group(SharedRegistryObject):
self.invalidate_members()
@property
- def non_inherited_unit_names(self):
+ def non_inherited_unit_names(self) -> frozenset[str]:
return frozenset(self._unit_names)
- def remove_units(self, *unit_names):
+ def remove_units(self, *unit_names: str) -> None:
"""Remove units from group."""
for unit_name in unit_names:
self._unit_names.remove(unit_name)
self.invalidate_members()
- def add_groups(self, *group_names):
+ def add_groups(self, *group_names: str) -> None:
"""Add groups to group."""
d = self._REGISTRY._groups
for group_name in group_names:
@@ -136,7 +130,7 @@ class Group(SharedRegistryObject):
self.invalidate_members()
- def remove_groups(self, *group_names):
+ def remove_groups(self, *group_names: str) -> None:
"""Remove groups from group."""
d = self._REGISTRY._groups
for group_name in group_names:
@@ -148,7 +142,9 @@ class Group(SharedRegistryObject):
self.invalidate_members()
@classmethod
- def from_lines(cls, lines, define_func, non_int_type=float) -> Group:
+ def from_lines(
+ cls, lines: Iterable[str], define_func, non_int_type: type = float
+ ) -> Group:
"""Return a Group object parsing an iterable of lines.
Parameters
@@ -190,6 +186,6 @@ class Group(SharedRegistryObject):
return grp
- def __getattr__(self, item):
+ def __getattr__(self, item: str):
getattr_maybe_raise(self, item)
return self._REGISTRY
diff --git a/pint/facets/nonmultiplicative/definitions.py b/pint/facets/nonmultiplicative/definitions.py
index dbfc0ff..f795cf0 100644
--- a/pint/facets/nonmultiplicative/definitions.py
+++ b/pint/facets/nonmultiplicative/definitions.py
@@ -10,6 +10,7 @@ from __future__ import annotations
from dataclasses import dataclass
+from ..._typing import Magnitude
from ...compat import HAS_NUMPY, exp, log
from ..plain import ScaleConverter
@@ -24,7 +25,7 @@ class OffsetConverter(ScaleConverter):
def is_multiplicative(self):
return self.offset == 0
- def to_reference(self, value, inplace=False):
+ def to_reference(self, value: Magnitude, inplace: bool = False) -> Magnitude:
if inplace:
value *= self.scale
value += self.offset
@@ -33,7 +34,7 @@ class OffsetConverter(ScaleConverter):
return value
- def from_reference(self, value, inplace=False):
+ def from_reference(self, value: Magnitude, inplace: bool = False) -> Magnitude:
if inplace:
value -= self.offset
value /= self.scale
@@ -66,6 +67,7 @@ class LogarithmicConverter(ScaleConverter):
controls if computation is done in place
"""
+ # TODO: Can I use PintScalar here?
logbase: float
logfactor: float
@@ -77,7 +79,7 @@ class LogarithmicConverter(ScaleConverter):
def is_logarithmic(self):
return True
- def from_reference(self, value, inplace=False):
+ def from_reference(self, value: Magnitude, inplace: bool = False) -> Magnitude:
"""Converts value from the reference unit to the logarithmic unit
dBm <------ mW
@@ -95,7 +97,7 @@ class LogarithmicConverter(ScaleConverter):
return value
- def to_reference(self, value, inplace=False):
+ def to_reference(self, value: Magnitude, inplace: bool = False) -> Magnitude:
"""Converts value to the reference unit from the logarithmic unit
dBm ------> mW
diff --git a/pint/facets/nonmultiplicative/objects.py b/pint/facets/nonmultiplicative/objects.py
index 7f9064d..0ab743e 100644
--- a/pint/facets/nonmultiplicative/objects.py
+++ b/pint/facets/nonmultiplicative/objects.py
@@ -40,7 +40,7 @@ class NonMultiplicativeQuantity(PlainQuantity):
self._get_unit_definition(d).reference == offset_unit_dim for d in deltas
)
- def _ok_for_muldiv(self, no_offset_units=None) -> bool:
+ def _ok_for_muldiv(self, no_offset_units: int | None = None) -> bool:
"""Checks if PlainQuantity object can be multiplied or divided"""
is_ok = True
diff --git a/pint/facets/plain/definitions.py b/pint/facets/plain/definitions.py
index eb45db1..79a44f1 100644
--- a/pint/facets/plain/definitions.py
+++ b/pint/facets/plain/definitions.py
@@ -13,8 +13,9 @@ import numbers
import typing as ty
from dataclasses import dataclass
from functools import cached_property
-from typing import Callable
+from typing import Callable, Any
+from ..._typing import Magnitude
from ... import errors
from ...converters import Converter
from ...util import UnitsContainer
@@ -23,7 +24,7 @@ from ...util import UnitsContainer
class NotNumeric(Exception):
"""Internal exception. Do not expose outside Pint"""
- def __init__(self, value):
+ def __init__(self, value: Any):
self.value = value
@@ -115,18 +116,26 @@ class UnitDefinition(errors.WithDefErr):
#: canonical name of the unit
name: str
#: canonical symbol
- defined_symbol: ty.Optional[str]
+ defined_symbol: str | None
#: additional names for the same unit
- aliases: ty.Tuple[str, ...]
+ aliases: tuple[str]
#: A functiont that converts a value in these units into the reference units
- converter: ty.Optional[ty.Union[Callable, Converter]]
+ converter: Callable[
+ [
+ Magnitude,
+ ],
+ Magnitude,
+ ] | Converter | None
#: Reference units.
- reference: ty.Optional[UnitsContainer]
+ reference: UnitsContainer | None
def __post_init__(self):
if not errors.is_valid_unit_name(self.name):
raise self.def_err(errors.MSG_INVALID_UNIT_NAME)
+ # TODO: check why reference: UnitsContainer | None
+ assert isinstance(self.reference, UnitsContainer)
+
if not any(map(errors.is_dim, self.reference.keys())):
invalid = tuple(
itertools.filterfalse(errors.is_valid_unit_name, self.reference.keys())
@@ -180,14 +189,20 @@ class UnitDefinition(errors.WithDefErr):
@property
def is_base(self) -> bool:
"""Indicates if it is a base unit."""
+
+ # TODO: why is this here
return self._is_base
@property
def is_multiplicative(self) -> bool:
+ # TODO: Check how to avoid this check
+ assert isinstance(self.converter, Converter)
return self.converter.is_multiplicative
@property
def is_logarithmic(self) -> bool:
+ # TODO: Check how to avoid this check
+ assert isinstance(self.converter, Converter)
return self.converter.is_logarithmic
@property
@@ -272,7 +287,7 @@ class ScaleConverter(Converter):
scale: float
- def to_reference(self, value, inplace=False):
+ def to_reference(self, value: Magnitude, inplace: bool = False) -> Magnitude:
if inplace:
value *= self.scale
else:
@@ -280,7 +295,7 @@ class ScaleConverter(Converter):
return value
- def from_reference(self, value, inplace=False):
+ def from_reference(self, value: Magnitude, inplace: bool = False) -> Magnitude:
if inplace:
value /= self.scale
else:
diff --git a/pint/facets/plain/objects.py b/pint/facets/plain/objects.py
index 5b2837b..a868c7f 100644
--- a/pint/facets/plain/objects.py
+++ b/pint/facets/plain/objects.py
@@ -11,4 +11,4 @@ from __future__ import annotations
from .quantity import PlainQuantity
from .unit import PlainUnit, UnitsContainer
-__all__ = [PlainUnit, PlainQuantity, UnitsContainer]
+__all__ = ["PlainUnit", "PlainQuantity", "UnitsContainer"]
diff --git a/pint/facets/system/definitions.py b/pint/facets/system/definitions.py
index 8243324..893c510 100644
--- a/pint/facets/system/definitions.py
+++ b/pint/facets/system/definitions.py
@@ -8,9 +8,10 @@
from __future__ import annotations
-import typing as ty
+from typing import Iterable
from dataclasses import dataclass
+from ..._typing import Self
from ... import errors
@@ -23,7 +24,7 @@ class BaseUnitRule:
new_unit_name: str
#: name of the unit to be kicked out to make room for the new base uni
#: If None, the current base unit with the same dimensionality will be used
- old_unit_name: ty.Optional[str] = None
+ old_unit_name: str | None = None
# Instead of defining __post_init__ here,
# it will be added to the container class
@@ -38,13 +39,16 @@ class SystemDefinition(errors.WithDefErr):
#: name of the system
name: str
#: unit groups that will be included within the system
- using_group_names: ty.Tuple[str, ...]
+ using_group_names: tuple[str]
#: rules to define new base unit within the system.
- rules: ty.Tuple[BaseUnitRule, ...]
+ rules: tuple[BaseUnitRule]
@classmethod
- def from_lines(cls, lines, non_int_type):
+ def from_lines(
+ cls: type[Self], lines: Iterable[str], non_int_type: type
+ ) -> Self | None:
# TODO: this is to keep it backwards compatible
+        # TODO: check when None is returned.
from ...delegates import ParserConfig, txt_defparser
cfg = ParserConfig(non_int_type)
@@ -55,7 +59,8 @@ class SystemDefinition(errors.WithDefErr):
return definition
@property
- def unit_replacements(self) -> ty.Tuple[ty.Tuple[str, str], ...]:
+ def unit_replacements(self) -> tuple[tuple[str, str | None]]:
+ # TODO: check if None can be dropped.
return tuple((el.new_unit_name, el.old_unit_name) for el in self.rules)
def __post_init__(self):
diff --git a/pint/facets/system/objects.py b/pint/facets/system/objects.py
index 829fb5c..7af65a6 100644
--- a/pint/facets/system/objects.py
+++ b/pint/facets/system/objects.py
@@ -9,6 +9,12 @@
from __future__ import annotations
+import numbers
+
+from typing import Any, Iterable
+
+from ..._typing import Self
+
from ...babel_names import _babel_systems
from ...compat import babel_parse
from ...util import (
@@ -29,32 +35,28 @@ class System(SharedRegistryObject):
The System belongs to one Registry.
See SystemDefinition for the definition file syntax.
- """
- def __init__(self, name):
- """
- :param name: Name of the group
- :type name: str
- """
+ Parameters
+ ----------
+ name
+ Name of the group.
+ """
+ def __init__(self, name: str):
#: Name of the system
#: :type: str
self.name = name
#: Maps root unit names to a dict indicating the new unit and its exponent.
- #: :type: dict[str, dict[str, number]]]
- self.base_units = {}
+ self.base_units: dict[str, dict[str, numbers.Number]] = {}
#: Derived unit names.
- #: :type: set(str)
- self.derived_units = set()
+ self.derived_units: set[str] = set()
#: Names of the _used_groups in used by this system.
- #: :type: set(str)
- self._used_groups = set()
+ self._used_groups: set[str] = set()
- #: :type: frozenset | None
- self._computed_members = None
+ self._computed_members: frozenset[str] | None = None
# Add this system to the system dictionary
self._REGISTRY._systems[self.name] = self
@@ -62,7 +64,7 @@ class System(SharedRegistryObject):
def __dir__(self):
return list(self.members)
- def __getattr__(self, item):
+ def __getattr__(self, item: str) -> Any:
getattr_maybe_raise(self, item)
u = getattr(self._REGISTRY, self.name + "_" + item, None)
if u is not None:
@@ -93,19 +95,19 @@ class System(SharedRegistryObject):
"""Invalidate computed members in this Group and all parent nodes."""
self._computed_members = None
- def add_groups(self, *group_names):
+ def add_groups(self, *group_names: str) -> None:
"""Add groups to group."""
self._used_groups |= set(group_names)
self.invalidate_members()
- def remove_groups(self, *group_names):
+ def remove_groups(self, *group_names: str) -> None:
"""Remove groups from group."""
self._used_groups -= set(group_names)
self.invalidate_members()
- def format_babel(self, locale):
+ def format_babel(self, locale: str) -> str:
"""translate the name of the system."""
if locale and self.name in _babel_systems:
name = _babel_systems[self.name]
@@ -114,8 +116,12 @@ class System(SharedRegistryObject):
return self.name
@classmethod
- def from_lines(cls, lines, get_root_func, non_int_type=float):
- system_definition = SystemDefinition.from_lines(lines, get_root_func)
+ def from_lines(
+ cls: type[Self], lines: Iterable[str], get_root_func, non_int_type: type = float
+ ) -> Self:
+        # TODO: we changed something here; it used to be:
+ # system_definition = SystemDefinition.from_lines(lines, get_root_func)
+ system_definition = SystemDefinition.from_lines(lines, non_int_type)
return cls.from_definition(system_definition, get_root_func)
@classmethod
@@ -174,12 +180,12 @@ class System(SharedRegistryObject):
class Lister:
- def __init__(self, d):
+ def __init__(self, d: dict[str, Any]):
self.d = d
- def __dir__(self):
+ def __dir__(self) -> list[str]:
return list(self.d.keys())
- def __getattr__(self, item):
+ def __getattr__(self, item: str) -> Any:
getattr_maybe_raise(self, item)
return self.d[item]
diff --git a/pint/formatting.py b/pint/formatting.py
index dcc8725..637d838 100644
--- a/pint/formatting.py
+++ b/pint/formatting.py
@@ -13,7 +13,8 @@ from __future__ import annotations
import functools
import re
import warnings
-from typing import Callable
+from typing import Callable, Iterable, Any
+from numbers import Number
from .babel_names import _babel_lengths, _babel_units
from .compat import babel_parse
@@ -21,7 +22,7 @@ from .compat import babel_parse
__JOIN_REG_EXP = re.compile(r"{\d*}")
-def _join(fmt, iterable):
+def _join(fmt: str, iterable: Iterable[Any]):
"""Join an iterable with the format specified in fmt.
The format can be specified in two ways:
@@ -55,7 +56,7 @@ def _join(fmt, iterable):
_PRETTY_EXPONENTS = "⁰¹²³⁴⁵⁶⁷⁸⁹"
-def _pretty_fmt_exponent(num):
+def _pretty_fmt_exponent(num: Number) -> str:
"""Format an number into a pretty printed exponent.
Parameters
@@ -76,7 +77,7 @@ def _pretty_fmt_exponent(num):
#: _FORMATS maps format specifications to the corresponding argument set to
#: formatter().
-_FORMATS: dict[str, dict] = {
+_FORMATS: dict[str, dict[str, Any]] = {
"P": { # Pretty format.
"as_ratio": True,
"single_denominator": False,
@@ -125,7 +126,7 @@ _FORMATS: dict[str, dict] = {
_FORMATTERS: dict[str, Callable] = {}
-def register_unit_format(name):
+def register_unit_format(name: str):
"""register a function as a new format for units
The registered function must have a signature of:
@@ -268,18 +269,18 @@ def format_compact(unit, registry, **options):
def formatter(
- items,
- as_ratio=True,
- single_denominator=False,
- product_fmt=" * ",
- division_fmt=" / ",
- power_fmt="{} ** {}",
- parentheses_fmt="({0})",
+ items: list[tuple[str, Number]],
+ as_ratio: bool = True,
+ single_denominator: bool = False,
+ product_fmt: str = " * ",
+ division_fmt: str = " / ",
+ power_fmt: str = "{} ** {}",
+ parentheses_fmt: str = "({0})",
exp_call=lambda x: f"{x:n}",
- locale=None,
- babel_length="long",
- babel_plural_form="one",
- sort=True,
+ locale: str | None = None,
+ babel_length: str = "long",
+ babel_plural_form: str = "one",
+ sort: bool = True,
):
"""Format a list of (name, exponent) pairs.
diff --git a/pint/pint_eval.py b/pint/pint_eval.py
index e776d60..d476eae 100644
--- a/pint/pint_eval.py
+++ b/pint/pint_eval.py
@@ -11,7 +11,9 @@ from __future__ import annotations
import operator
import token as tokenlib
-import tokenize
+from tokenize import TokenInfo
+
+from typing import Any
from .errors import DefinitionSyntaxError
@@ -30,7 +32,7 @@ _OP_PRIORITY = {
}
-def _power(left, right):
+def _power(left: Any, right: Any) -> Any:
from . import Quantity
from .compat import is_duck_array
@@ -45,7 +47,19 @@ def _power(left, right):
return operator.pow(left, right)
-_BINARY_OPERATOR_MAP = {
+import typing
+
+UnaryOpT = typing.Callable[
+ [
+ Any,
+ ],
+ Any,
+]
+BinaryOpT = typing.Callable[[Any, Any], Any]
+
+_UNARY_OPERATOR_MAP: dict[str, UnaryOpT] = {"+": lambda x: x, "-": lambda x: x * -1}
+
+_BINARY_OPERATOR_MAP: dict[str, BinaryOpT] = {
"**": _power,
"*": operator.mul,
"": operator.mul, # operator for implicit ops
@@ -56,8 +70,6 @@ _BINARY_OPERATOR_MAP = {
"//": operator.floordiv,
}
-_UNARY_OPERATOR_MAP = {"+": lambda x: x, "-": lambda x: x * -1}
-
class EvalTreeNode:
"""Single node within an evaluation tree
@@ -68,25 +80,43 @@ class EvalTreeNode:
left --> single value
"""
- def __init__(self, left, operator=None, right=None):
+ def __init__(
+ self,
+ left: EvalTreeNode | TokenInfo,
+ operator: TokenInfo | None = None,
+ right: EvalTreeNode | None = None,
+ ):
self.left = left
self.operator = operator
self.right = right
- def to_string(self):
+ def to_string(self) -> str:
# For debugging purposes
if self.right:
+ assert isinstance(self.left, EvalTreeNode), "self.left not EvalTreeNode (1)"
comps = [self.left.to_string()]
if self.operator:
- comps.append(self.operator[1])
+ comps.append(self.operator.string)
comps.append(self.right.to_string())
elif self.operator:
- comps = [self.operator[1], self.left.to_string()]
+ assert isinstance(self.left, EvalTreeNode), "self.left not EvalTreeNode (2)"
+ comps = [self.operator.string, self.left.to_string()]
else:
- return self.left[1]
+ assert isinstance(self.left, TokenInfo), "self.left not TokenInfo (1)"
+ return self.left.string
return "(%s)" % " ".join(comps)
- def evaluate(self, define_op, bin_op=None, un_op=None):
+ def evaluate(
+ self,
+ define_op: typing.Callable[
+ [
+ Any,
+ ],
+ Any,
+ ],
+ bin_op: dict[str, BinaryOpT] | None = None,
+ un_op: dict[str, UnaryOpT] | None = None,
+ ):
"""Evaluate node.
Parameters
@@ -107,17 +137,22 @@ class EvalTreeNode:
un_op = un_op or _UNARY_OPERATOR_MAP
if self.right:
+ assert isinstance(self.left, EvalTreeNode), "self.left not EvalTreeNode (3)"
# binary or implicit operator
- op_text = self.operator[1] if self.operator else ""
+ op_text = self.operator.string if self.operator else ""
if op_text not in bin_op:
- raise DefinitionSyntaxError('missing binary operator "%s"' % op_text)
- left = self.left.evaluate(define_op, bin_op, un_op)
- return bin_op[op_text](left, self.right.evaluate(define_op, bin_op, un_op))
+ raise DefinitionSyntaxError(f"missing binary operator '{op_text}'")
+
+ return bin_op[op_text](
+ self.left.evaluate(define_op, bin_op, un_op),
+ self.right.evaluate(define_op, bin_op, un_op),
+ )
elif self.operator:
+ assert isinstance(self.left, EvalTreeNode), "self.left not EvalTreeNode (4)"
# unary operator
- op_text = self.operator[1]
+ op_text = self.operator.string
if op_text not in un_op:
- raise DefinitionSyntaxError('missing unary operator "%s"' % op_text)
+ raise DefinitionSyntaxError(f"missing unary operator '{op_text}'")
return un_op[op_text](self.left.evaluate(define_op, bin_op, un_op))
# single value
@@ -127,13 +162,13 @@ class EvalTreeNode:
from collections.abc import Iterable
-def build_eval_tree(
- tokens: Iterable[tokenize.TokenInfo],
- op_priority=None,
- index=0,
- depth=0,
- prev_op=None,
-) -> tuple[EvalTreeNode | None, int] | EvalTreeNode:
+def _build_eval_tree(
+ tokens: list[TokenInfo],
+ op_priority: dict[str, int],
+ index: int = 0,
+ depth: int = 0,
+ prev_op: str = "<none>",
+) -> tuple[EvalTreeNode, int]:
"""Build an evaluation tree from a set of tokens.
Params:
@@ -153,14 +188,12 @@ def build_eval_tree(
5) Combine left side, operator, and right side into a new left side
6) Go back to step #2
- """
-
- if op_priority is None:
- op_priority = _OP_PRIORITY
+ Raises
+ ------
+ DefinitionSyntaxError
+ If there is a syntax error.
- if depth == 0 and prev_op is None:
- # ensure tokens is list so we can access by index
- tokens = list(tokens)
+ """
result = None
@@ -171,19 +204,21 @@ def build_eval_tree(
if token_type == tokenlib.OP:
if token_text == ")":
- if prev_op is None:
+ if prev_op == "<none>":
raise DefinitionSyntaxError(
- "unopened parentheses in tokens: %s" % current_token
+ f"unopened parentheses in tokens: {current_token}"
)
elif prev_op == "(":
# close parenthetical group
+ assert result is not None
return result, index
else:
# parenthetical group ending, but we need to close sub-operations within group
+ assert result is not None
return result, index - 1
elif token_text == "(":
# gather parenthetical group
- right, index = build_eval_tree(
+ right, index = _build_eval_tree(
tokens, op_priority, index + 1, 0, token_text
)
if not tokens[index][1] == ")":
@@ -208,7 +243,7 @@ def build_eval_tree(
# previous operator is higher priority, so end previous binary op
return result, index - 1
# get right side of binary op
- right, index = build_eval_tree(
+ right, index = _build_eval_tree(
tokens, op_priority, index + 1, depth + 1, token_text
)
result = EvalTreeNode(
@@ -216,7 +251,7 @@ def build_eval_tree(
)
else:
# unary operator
- right, index = build_eval_tree(
+ right, index = _build_eval_tree(
tokens, op_priority, index + 1, depth + 1, "unary"
)
result = EvalTreeNode(left=right, operator=current_token)
@@ -227,7 +262,7 @@ def build_eval_tree(
# previous operator is higher priority than implicit, so end
# previous binary op
return result, index - 1
- right, index = build_eval_tree(
+ right, index = _build_eval_tree(
tokens, op_priority, index, depth + 1, ""
)
result = EvalTreeNode(left=result, right=right)
@@ -240,13 +275,57 @@ def build_eval_tree(
raise DefinitionSyntaxError("unclosed parentheses in tokens")
if depth > 0 or prev_op:
# have to close recursion
+ assert result is not None
return result, index
else:
# recursion all closed, so just return the final result
- return result
+ assert result is not None
+ return result, -1
if index + 1 >= len(tokens):
# should hit ENDMARKER before this ever happens
raise DefinitionSyntaxError("unexpected end to tokens")
index += 1
+
+
+def build_eval_tree(
+ tokens: Iterable[TokenInfo],
+ op_priority: dict[str, int] | None = None,
+) -> EvalTreeNode:
+ """Build an evaluation tree from a set of tokens.
+
+ Params:
+        tokens: an iterable of tokens from an expression to be evaluated.
+        op_priority: a mapping from operator to priority; defaults to _OP_PRIORITY.
+
+ Transform the tokens from an expression into a recursive parse tree, following order
+ of operations. Operations can include binary ops (3 + 4), implicit ops (3 kg), or
+ unary ops (-1).
+
+ General Strategy:
+ 1) Get left side of operator
+ 2) If no tokens left, return final result
+ 3) Get operator
+ 4) Use recursion to create tree starting at token on right side of operator (start at step #1)
+ 4.1) If recursive call encounters an operator with lower or equal priority to step #2, exit recursion
+ 5) Combine left side, operator, and right side into a new left side
+ 6) Go back to step #2
+
+ Raises
+ ------
+ DefinitionSyntaxError
+ If there is a syntax error.
+
+ """
+
+ if op_priority is None:
+ op_priority = _OP_PRIORITY
+
+ if not isinstance(tokens, list):
+ # ensure tokens is list so we can access by index
+ tokens = list(tokens)
+
+ result, _ = _build_eval_tree(tokens, op_priority, 0, 0)
+
+ return result
diff --git a/pint/testsuite/test_compat_downcast.py b/pint/testsuite/test_compat_downcast.py
index 4ca611d..ed43e94 100644
--- a/pint/testsuite/test_compat_downcast.py
+++ b/pint/testsuite/test_compat_downcast.py
@@ -38,7 +38,7 @@ def q_base(local_registry):
# Define identity function for use in tests
-def identity(ureg, x):
+def id_matrix(ureg, x):
return x
@@ -63,17 +63,17 @@ def array(request):
@pytest.mark.parametrize(
"op, magnitude_op, unit_op",
[
- pytest.param(identity, identity, identity, id="identity"),
+ pytest.param(id_matrix, id_matrix, id_matrix, id="identity"),
pytest.param(
lambda ureg, x: x + 1 * ureg.m,
lambda ureg, x: x + 1,
- identity,
+ id_matrix,
id="addition",
),
pytest.param(
lambda ureg, x: x - 20 * ureg.cm,
lambda ureg, x: x - 0.2,
- identity,
+ id_matrix,
id="subtraction",
),
pytest.param(
@@ -84,7 +84,7 @@ def array(request):
),
pytest.param(
lambda ureg, x: x / (1 * ureg.s),
- identity,
+ id_matrix,
lambda ureg, u: u / ureg.s,
id="division",
),
@@ -94,17 +94,17 @@ def array(request):
WR(lambda u: u**2),
id="square",
),
- pytest.param(WR(lambda x: x.T), WR(lambda x: x.T), identity, id="transpose"),
- pytest.param(WR(np.mean), WR(np.mean), identity, id="mean ufunc"),
- pytest.param(WR(np.sum), WR(np.sum), identity, id="sum ufunc"),
+ pytest.param(WR(lambda x: x.T), WR(lambda x: x.T), id_matrix, id="transpose"),
+ pytest.param(WR(np.mean), WR(np.mean), id_matrix, id="mean ufunc"),
+ pytest.param(WR(np.sum), WR(np.sum), id_matrix, id="sum ufunc"),
pytest.param(WR(np.sqrt), WR(np.sqrt), WR(lambda u: u**0.5), id="sqrt ufunc"),
pytest.param(
WR(lambda x: np.reshape(x, (25,))),
WR(lambda x: np.reshape(x, (25,))),
- identity,
+ id_matrix,
id="reshape function",
),
- pytest.param(WR(np.amax), WR(np.amax), identity, id="amax function"),
+ pytest.param(WR(np.amax), WR(np.amax), id_matrix, id="amax function"),
],
)
def test_univariate_op_consistency(
diff --git a/pint/util.py b/pint/util.py
index 28710e7..807c3ac 100644
--- a/pint/util.py
+++ b/pint/util.py
@@ -14,48 +14,82 @@ import logging
import math
import operator
import re
-from collections.abc import Mapping
+from collections.abc import Mapping, Iterable, Iterator
from fractions import Fraction
from functools import lru_cache, partial
from logging import NullHandler
from numbers import Number
from token import NAME, NUMBER
-from typing import TYPE_CHECKING, ClassVar
+import tokenize
+from typing import (
+ TYPE_CHECKING,
+ ClassVar,
+ TypeAlias,
+ Callable,
+ TypeVar,
+ Hashable,
+ Generator,
+ Any,
+)
from .compat import NUMERIC_TYPES, tokenizer
from .errors import DefinitionSyntaxError
from .formatting import format_unit
from .pint_eval import build_eval_tree
+from ._typing import PintScalar
+
if TYPE_CHECKING:
- from ._typing import Quantity, UnitLike
+ from ._typing import Quantity, UnitLike, Self
from .registry import UnitRegistry
+
logger = logging.getLogger(__name__)
logger.addHandler(NullHandler())
+T = TypeVar("T")
+TH = TypeVar("TH", bound=Hashable)
+ItMatrix: TypeAlias = Iterable[Iterable[PintScalar]]
+Matrix: TypeAlias = list[list[PintScalar]]
+
+
+def _noop(x: T) -> T:
+ return x
+
def matrix_to_string(
- matrix, row_headers=None, col_headers=None, fmtfun=lambda x: str(int(x))
-):
- """Takes a 2D matrix (as nested list) and returns a string.
+ matrix: ItMatrix,
+ row_headers: Iterable[str] | None = None,
+ col_headers: Iterable[str] | None = None,
+ fmtfun: Callable[
+ [
+ PintScalar,
+ ],
+ str,
+ ] = "{:0.0f}".format,
+) -> str:
+ """Return a string representation of a matrix.
Parameters
----------
- matrix :
-
- row_headers :
- (Default value = None)
- col_headers :
- (Default value = None)
- fmtfun :
- (Default value = lambda x: str(int(x)))
+ matrix
+ A matrix given as an iterable of an iterable of numbers.
+ row_headers
+ An iterable of strings to serve as row headers.
+ (default = None, meaning no row headers are printed.)
+ col_headers
+ An iterable of strings to serve as column headers.
+ (default = None, meaning no col headers are printed.)
+ fmtfun
+ A callable to convert a number into string.
+ (default = `"{:0.0f}".format`)
Returns
-------
-
+ str
+ String representation of the matrix.
"""
- ret = []
+ ret: list[str] = []
if col_headers:
ret.append(("\t" if row_headers else "") + "\t".join(col_headers))
if row_headers:
@@ -69,99 +103,124 @@ def matrix_to_string(
return "\n".join(ret)
-def transpose(matrix):
- """Takes a 2D matrix (as nested list) and returns the transposed version.
+def transpose(matrix: ItMatrix) -> Matrix:
+ """Return the transposed version of a matrix.
Parameters
----------
- matrix :
-
+ matrix
+ A matrix given as an iterable of an iterable of numbers.
Returns
-------
-
+ Matrix
+ The transposed version of the matrix.
"""
return [list(val) for val in zip(*matrix)]
-def column_echelon_form(matrix, ntype=Fraction, transpose_result=False):
- """Calculates the column echelon form using Gaussian elimination.
+def matrix_apply(
+ matrix: ItMatrix,
+ func: Callable[
+ [
+ PintScalar,
+ ],
+ PintScalar,
+ ],
+) -> Matrix:
+ """Apply a function to individual elements within a matrix.
Parameters
----------
- matrix :
- a 2D matrix as nested list.
- ntype :
- the numerical type to use in the calculation. (Default value = Fraction)
- transpose_result :
- indicates if the returned matrix should be transposed. (Default value = False)
+ matrix
+ A matrix given as an iterable of an iterable of numbers.
+ func
+ A callable that converts a number to another.
Returns
-------
- type
- column echelon form, transformed identity matrix, swapped rows
-
+        A new matrix in which each element has been replaced by a new one.
"""
- lead = 0
+ return [[func(x) for x in row] for row in matrix]
+
+
+def column_echelon_form(
+ matrix: ItMatrix, ntype: type = Fraction, transpose_result: bool = False
+) -> tuple[Matrix, Matrix, list[int]]:
+ """Calculate the column echelon form using Gaussian elimination.
- M = transpose(matrix)
+ Parameters
+ ----------
+ matrix
+ A 2D matrix as nested list.
+ ntype
+ The numerical type to use in the calculation.
+ (default = Fraction)
+ transpose_result
+ Indicates if the returned matrix should be transposed.
+ (default = False)
- _transpose = transpose if transpose_result else lambda x: x
+ Returns
+ -------
+ ech_matrix
+ Column echelon form.
+ id_matrix
+ Transformed identity matrix.
+ swapped
+ Swapped rows.
+ """
- rows, cols = len(M), len(M[0])
+ _transpose = transpose if transpose_result else _noop
- new_M = []
- for row in M:
- r = []
- for x in row:
- if isinstance(x, float):
- x = ntype.from_float(x)
- else:
- x = ntype(x)
- r.append(x)
- new_M.append(r)
- M = new_M
+ ech_matrix = matrix_apply(
+ transpose(matrix),
+ lambda x: ntype.from_float(x) if isinstance(x, float) else ntype(x), # type: ignore
+ )
+ rows, cols = len(ech_matrix), len(ech_matrix[0])
# M = [[ntype(x) for x in row] for row in M]
- I = [ # noqa: E741
+ id_matrix: list[list[PintScalar]] = [ # noqa: E741
[ntype(1) if n == nc else ntype(0) for nc in range(rows)] for n in range(rows)
]
- swapped = []
+ swapped: list[int] = []
+ lead = 0
for r in range(rows):
if lead >= cols:
- return _transpose(M), _transpose(I), swapped
- i = r
- while M[i][lead] == 0:
- i += 1
- if i != rows:
+ return _transpose(ech_matrix), _transpose(id_matrix), swapped
+ s = r
+ while ech_matrix[s][lead] == 0: # type: ignore
+ s += 1
+ if s != rows:
continue
- i = r
+ s = r
lead += 1
if cols == lead:
- return _transpose(M), _transpose(I), swapped
+ return _transpose(ech_matrix), _transpose(id_matrix), swapped
- M[i], M[r] = M[r], M[i]
- I[i], I[r] = I[r], I[i]
+ ech_matrix[s], ech_matrix[r] = ech_matrix[r], ech_matrix[s]
+ id_matrix[s], id_matrix[r] = id_matrix[r], id_matrix[s]
- swapped.append(i)
- lv = M[r][lead]
- M[r] = [mrx / lv for mrx in M[r]]
- I[r] = [mrx / lv for mrx in I[r]]
+ swapped.append(s)
+ lv = ech_matrix[r][lead]
+ ech_matrix[r] = [mrx / lv for mrx in ech_matrix[r]]
+ id_matrix[r] = [mrx / lv for mrx in id_matrix[r]]
- for i in range(rows):
- if i == r:
+ for s in range(rows):
+ if s == r:
continue
- lv = M[i][lead]
- M[i] = [iv - lv * rv for rv, iv in zip(M[r], M[i])]
- I[i] = [iv - lv * rv for rv, iv in zip(I[r], I[i])]
+ lv = ech_matrix[s][lead]
+ ech_matrix[s] = [
+ iv - lv * rv for rv, iv in zip(ech_matrix[r], ech_matrix[s])
+ ]
+ id_matrix[s] = [iv - lv * rv for rv, iv in zip(id_matrix[r], id_matrix[s])]
lead += 1
- return _transpose(M), _transpose(I), swapped
+ return _transpose(ech_matrix), _transpose(id_matrix), swapped
-def pi_theorem(quantities, registry=None):
+def pi_theorem(quantities: dict[str, Any], registry: UnitRegistry | None = None):
"""Builds dimensionless quantities using the Buckingham π theorem
Parameters
@@ -169,7 +228,7 @@ def pi_theorem(quantities, registry=None):
quantities : dict
mapping between variable name and units
registry :
- (Default value = None)
+ (default value = None)
Returns
-------
@@ -183,7 +242,7 @@ def pi_theorem(quantities, registry=None):
dimensions = set()
if registry is None:
- getdim = lambda x: x
+ getdim = _noop
non_int_type = float
else:
getdim = registry.get_dimensionality
@@ -211,18 +270,18 @@ def pi_theorem(quantities, registry=None):
dimensions = list(dimensions)
# Calculate dimensionless quantities
- M = [
+ matrix = [
[dimensionality[dimension] for name, dimensionality in quant]
for dimension in dimensions
]
- M, identity, pivot = column_echelon_form(M, transpose_result=False)
+ ech_matrix, id_matrix, pivot = column_echelon_form(matrix, transpose_result=False)
# Collect results
# Make all numbers integers and minimize the number of negative exponents.
# Remove zeros
results = []
- for rowm, rowi in zip(M, identity):
+ for rowm, rowi in zip(ech_matrix, id_matrix):
if any(el != 0 for el in rowm):
continue
max_den = max(f.denominator for f in rowi)
@@ -237,7 +296,9 @@ def pi_theorem(quantities, registry=None):
return results
-def solve_dependencies(dependencies):
+def solve_dependencies(
+ dependencies: dict[TH, set[TH]]
+) -> Generator[set[TH], None, None]:
"""Solve a dependency graph.
Parameters
@@ -246,12 +307,16 @@ def solve_dependencies(dependencies):
dependency dictionary. For each key, the value is an iterable indicating its
dependencies.
- Returns
- -------
- type
+ Yields
+ ------
+ set
iterator of sets, each containing keys of independents tasks dependent only of
the previous tasks in the list.
+ Raises
+ ------
+ ValueError
+ if a cyclic dependency is found.
"""
while dependencies:
# values not in keys (items without dep)
@@ -270,12 +335,37 @@ def solve_dependencies(dependencies):
yield t
-def find_shortest_path(graph, start, end, path=None):
+def find_shortest_path(
+ graph: dict[TH, set[TH]], start: TH, end: TH, path: list[TH] | None = None
+):
+ """Find shortest path between two nodes within a graph.
+
+ Parameters
+ ----------
+ graph
+        A graph given as a mapping of nodes
+        to the set of nodes connected to it.
+ start
+ Starting node.
+ end
+ End node.
+ path
+ Path to prepend to the one found.
+ (default = None, empty path.)
+
+    Returns
+    -------
+    list[TH] | None
+        The shortest path between the two nodes,
+        or None if no path exists.
+    """
path = (path or []) + [start]
if start == end:
return path
+
+ # TODO: raise ValueError when start not in graph
if start not in graph:
return None
+
shortest = None
for node in graph[start]:
if node not in path:
@@ -283,10 +373,33 @@ def find_shortest_path(graph, start, end, path=None):
if newpath:
if not shortest or len(newpath) < len(shortest):
shortest = newpath
+
return shortest
-def find_connected_nodes(graph, start, visited=None):
+def find_connected_nodes(
+ graph: dict[TH, set[TH]], start: TH, visited: set[TH] | None = None
+) -> set[TH] | None:
+ """Find all nodes connected to a start node within a graph.
+
+ Parameters
+ ----------
+ graph
+        A graph given as a mapping of nodes
+        to the set of nodes connected to it.
+ start
+ Starting node.
+ visited
+ Mutable set to collect visited nodes.
+ (default = None, empty set)
+
+    Returns
+    -------
+    set[TH] | None
+        All nodes connected to the start node,
+        or None if start is not in the graph.
+    """
+
+ # TODO: raise ValueError when start not in graph
if start not in graph:
return None
@@ -300,17 +413,17 @@ def find_connected_nodes(graph, start, visited=None):
return visited
-class udict(dict):
+class udict(dict[str, PintScalar]):
"""Custom dict implementing __missing__."""
- def __missing__(self, key):
+ def __missing__(self, key: str):
return 0
- def copy(self):
+ def copy(self: Self) -> Self:
return udict(self)
-class UnitsContainer(Mapping):
+class UnitsContainer(Mapping[str, PintScalar]):
"""The UnitsContainer stores the product of units and their respective
exponent and implements the corresponding operations.
@@ -318,23 +431,24 @@ class UnitsContainer(Mapping):
Parameters
----------
-
- Returns
- -------
- type
-
-
+ non_int_type
+ Numerical type used for non integer values.
"""
__slots__ = ("_d", "_hash", "_one", "_non_int_type")
- def __init__(self, *args, **kwargs) -> None:
+ _d: udict
+ _hash: int | None
+ _one: PintScalar
+ _non_int_type: type
+
+ def __init__(self, *args, non_int_type: type | None = None, **kwargs) -> None:
if args and isinstance(args[0], UnitsContainer):
default_non_int_type = args[0]._non_int_type
else:
default_non_int_type = float
- self._non_int_type = kwargs.pop("non_int_type", default_non_int_type)
+ self._non_int_type = non_int_type or default_non_int_type
if self._non_int_type is float:
self._one = 1
@@ -352,10 +466,26 @@ class UnitsContainer(Mapping):
d[key] = self._non_int_type(value)
self._hash = None
- def copy(self):
+ def copy(self: Self) -> Self:
+ """Create a copy of this UnitsContainer."""
return self.__copy__()
- def add(self, key, value):
+ def add(self: Self, key: str, value: Number) -> Self:
+ """Create a new UnitsContainer adding value to
+ the value existing for a given key.
+
+ Parameters
+ ----------
+ key
+ unit to which the value will be added.
+ value
+ value to be added.
+
+ Returns
+ -------
+    UnitsContainer
+        A new container with the value added to the entry for the given key.
+ """
newval = self._d[key] + value
new = self.copy()
if newval:
@@ -365,17 +495,18 @@ class UnitsContainer(Mapping):
new._hash = None
return new
- def remove(self, keys):
- """Create a new UnitsContainer purged from given keys.
+ def remove(self: Self, keys: Iterable[str]) -> Self:
+ """Create a new UnitsContainer purged from given entries.
Parameters
----------
- keys :
-
+ keys
+ Iterable of keys (units) to remove.
Returns
-------
-
+    UnitsContainer
+        A new container without the given keys.
"""
new = self.copy()
for k in keys:
@@ -383,51 +514,52 @@ class UnitsContainer(Mapping):
new._hash = None
return new
- def rename(self, oldkey, newkey):
+ def rename(self: Self, oldkey: str, newkey: str) -> Self:
"""Create a new UnitsContainer in which an entry has been renamed.
Parameters
----------
- oldkey :
-
- newkey :
-
+ oldkey
+ Existing key (unit).
+ newkey
+ New key (unit).
Returns
-------
-
+    UnitsContainer
+        A new container with the entry renamed.
"""
new = self.copy()
new._d[newkey] = new._d.pop(oldkey)
new._hash = None
return new
- def __iter__(self):
+ def __iter__(self) -> Iterator[str]:
return iter(self._d)
def __len__(self) -> int:
return len(self._d)
- def __getitem__(self, key):
+ def __getitem__(self, key: str) -> PintScalar:
return self._d[key]
- def __contains__(self, key):
+ def __contains__(self, key: str) -> bool:
return key in self._d
- def __hash__(self):
+ def __hash__(self) -> int:
if self._hash is None:
self._hash = hash(frozenset(self._d.items()))
return self._hash
# Only needed by pickle protocol 0 and 1 (used by pytables)
- def __getstate__(self):
+ def __getstate__(self) -> tuple[udict, PintScalar, type]:
return self._d, self._one, self._non_int_type
- def __setstate__(self, state):
+ def __setstate__(self, state: tuple[udict, PintScalar, type]):
self._d, self._one, self._non_int_type = state
self._hash = None
- def __eq__(self, other) -> bool:
+ def __eq__(self, other: Any) -> bool:
if isinstance(other, UnitsContainer):
# UnitsContainer.__hash__(self) is not the same as hash(self); see
# ParserHelper.__hash__ and __eq__.
@@ -472,7 +604,7 @@ class UnitsContainer(Mapping):
out._one = self._one
return out
- def __mul__(self, other):
+ def __mul__(self, other: Any):
if not isinstance(other, self.__class__):
err = "Cannot multiply UnitsContainer by {}"
raise TypeError(err.format(type(other)))
@@ -488,7 +620,7 @@ class UnitsContainer(Mapping):
__rmul__ = __mul__
- def __pow__(self, other):
+ def __pow__(self, other: Any):
if not isinstance(other, NUMERIC_TYPES):
err = "Cannot power UnitsContainer by {}"
raise TypeError(err.format(type(other)))
@@ -499,7 +631,7 @@ class UnitsContainer(Mapping):
new._hash = None
return new
- def __truediv__(self, other):
+ def __truediv__(self, other: Any):
if not isinstance(other, self.__class__):
err = "Cannot divide UnitsContainer by {}"
raise TypeError(err.format(type(other)))
@@ -513,7 +645,7 @@ class UnitsContainer(Mapping):
new._hash = None
return new
- def __rtruediv__(self, other):
+ def __rtruediv__(self, other: Any):
if not isinstance(other, self.__class__) and other != 1:
err = "Cannot divide {} by UnitsContainer"
raise TypeError(err.format(type(other)))
@@ -524,41 +656,48 @@ class UnitsContainer(Mapping):
class ParserHelper(UnitsContainer):
"""The ParserHelper stores in place the product of variables and
their respective exponent and implements the corresponding operations.
+ It also provides a scaling factor.
+
+ For example:
+ `3 * m ** 2` becomes ParserHelper(3, m=2)
+
+ Briefly is a UnitsContainer with a scaling factor.
ParserHelper is a read-only mapping. All operations (even in place ones)
+    WARNING : The hash value used does not take into account the scale
+    attribute, so be careful if you use it as a dict key: two unequal
+    objects can have the same hash.
+
Parameters
----------
-
- Returns
- -------
- type
- WARNING : The hash value used does not take into account the scale
- attribute so be careful if you use it as a dict key and then two unequal
- object can have the same hash.
-
+ scale
+ Scaling factor.
+ (default = 1)
+ **kwargs
+ Used to populate the dict of units and exponents.
"""
__slots__ = ("scale",)
- def __init__(self, scale=1, *args, **kwargs):
+ scale: PintScalar
+
+ def __init__(self, scale: PintScalar = 1, *args, **kwargs):
super().__init__(*args, **kwargs)
self.scale = scale
@classmethod
- def from_word(cls, input_word, non_int_type=float):
+ def from_word(cls, input_word: str, non_int_type: type = float) -> ParserHelper:
"""Creates a ParserHelper object with a single variable with exponent one.
- Equivalent to: ParserHelper({'word': 1})
+ Equivalent to: ParserHelper(1, {input_word: 1})
Parameters
----------
- input_word :
-
-
- Returns
- -------
+ input_word
+ non_int_type
+ Numerical type used for non integer values.
"""
if non_int_type is float:
return cls(1, [(input_word, 1)], non_int_type=non_int_type)
@@ -567,7 +706,7 @@ class ParserHelper(UnitsContainer):
return cls(ONE, [(input_word, ONE)], non_int_type=non_int_type)
@classmethod
- def eval_token(cls, token, non_int_type=float):
+ def eval_token(cls, token: tokenize.TokenInfo, non_int_type: type = float):
token_type = token.type
token_text = token.string
if token_type == NUMBER:
@@ -585,17 +724,15 @@ class ParserHelper(UnitsContainer):
@classmethod
@lru_cache
- def from_string(cls, input_string, non_int_type=float):
+ def from_string(cls, input_string: str, non_int_type: type = float) -> ParserHelper:
"""Parse linear expression mathematical units and return a quantity object.
Parameters
----------
- input_string :
-
-
- Returns
- -------
+ input_string
+ non_int_type
+ Numerical type used for non integer values.
"""
if not input_string:
return cls(non_int_type=non_int_type)
@@ -656,7 +793,7 @@ class ParserHelper(UnitsContainer):
super().__setstate__(state[:-1])
self.scale = state[-1]
- def __eq__(self, other):
+ def __eq__(self, other: Any) -> bool:
if isinstance(other, ParserHelper):
return self.scale == other.scale and super().__eq__(other)
elif isinstance(other, str):
@@ -666,7 +803,7 @@ class ParserHelper(UnitsContainer):
return self.scale == 1 and super().__eq__(other)
- def operate(self, items, op=operator.iadd, cleanup=True):
+ def operate(self, items, op=operator.iadd, cleanup: bool = True):
d = udict(self._d)
for key, value in items:
d[key] = op(d[key], value)
@@ -811,21 +948,22 @@ class SharedRegistryObject:
inst._REGISTRY = application_registry.get()
return inst
- def _check(self, other) -> bool:
+ def _check(self, other: Any) -> bool:
"""Check if the other object use a registry and if so that it is the
same registry.
Parameters
----------
- other :
-
+ other
Returns
-------
- type
- other don't use a registry and raise ValueError if other don't use the
- same unit registry.
+ bool
+ Raises
+ ------
+    ValueError
+        if other doesn't use the same unit registry.
"""
if self._REGISTRY is getattr(other, "_REGISTRY", None):
return True
@@ -844,17 +982,17 @@ class PrettyIPython:
default_format: str
- def _repr_html_(self):
+ def _repr_html_(self) -> str:
if "~" in self.default_format:
return f"{self:~H}"
return f"{self:H}"
- def _repr_latex_(self):
+ def _repr_latex_(self) -> str:
if "~" in self.default_format:
return f"${self:~L}$"
return f"${self:L}$"
- def _repr_pretty_(self, p, cycle):
+ def _repr_pretty_(self, p, cycle: bool):
if "~" in self.default_format:
p.text(f"{self:~P}")
else:
@@ -868,14 +1006,15 @@ def to_units_container(
Parameters
----------
- unit_like :
-
- registry :
- (Default value = None)
+ unit_like
+ Quantity or Unit to infer the plain units from.
+ registry
+ If provided, uses the registry's UnitsContainer and parse_unit_name. If None,
+ uses the registry attached to unit_like.
Returns
-------
-
+ UnitsContainer
"""
mro = type(unit_like).mro()
if UnitsContainer in mro:
@@ -902,10 +1041,9 @@ def infer_base_unit(
Parameters
----------
- unit_like : Union[UnitLike, Quantity]
+ unit_like
Quantity or Unit to infer the plain units from.
-
- registry: Optional[UnitRegistry]
+ registry
If provided, uses the registry's UnitsContainer and parse_unit_name. If None,
uses the registry attached to unit_like.
@@ -940,7 +1078,7 @@ def infer_base_unit(
return registry.UnitsContainer(nonzero_dict)
-def getattr_maybe_raise(self, item):
+def getattr_maybe_raise(obj: Any, item: str):
"""Helper function invoked at start of all overridden ``__getattr__``.
Raise AttributeError if the user tries to ask for a _ or __ attribute,
@@ -949,39 +1087,25 @@ def getattr_maybe_raise(self, item):
Parameters
----------
- item : string
- Item to be found.
-
-
- Returns
- -------
+ item
+ attribute to be found.
+ Raises
+ ------
+ AttributeError
"""
# Double-underscore attributes are tricky to detect because they are
- # automatically prefixed with the class name - which may be a subclass of self
+ # automatically prefixed with the class name - which may be a subclass of obj
if (
item.endswith("__")
or len(item.lstrip("_")) == 0
or (item.startswith("_") and not item.lstrip("_")[0].isdigit())
):
- raise AttributeError(f"{self!r} object has no attribute {item!r}")
-
+ raise AttributeError(f"{obj!r} object has no attribute {item!r}")
-def iterable(y) -> bool:
- """Check whether or not an object can be iterated over.
-
- Vendored from numpy under the terms of the BSD 3-Clause License. (Copyright
- (c) 2005-2019, NumPy Developers.)
-
- Parameters
- ----------
- value :
- Input object.
- type :
- object
- y :
- """
+def iterable(y: Any) -> bool:
+ """Check whether or not an object can be iterated over."""
try:
iter(y)
except TypeError:
@@ -989,18 +1113,8 @@ def iterable(y) -> bool:
return True
-def sized(y) -> bool:
- """Check whether or not an object has a defined length.
-
- Parameters
- ----------
- value :
- Input object.
- type :
- object
- y :
-
- """
+def sized(y: Any) -> bool:
+ """Check whether or not an object has a defined length."""
try:
len(y)
except TypeError:
@@ -1008,7 +1122,7 @@ def sized(y) -> bool:
return True
-def create_class_with_registry(registry, base_class) -> type:
+def create_class_with_registry(registry: UnitRegistry, base_class: type) -> type:
"""Create new class inheriting from base_class and
filling _REGISTRY class attribute with an actual instanced registry.
"""