diff options
author | Hernan Grecco <hgrecco@gmail.com> | 2023-05-01 19:23:47 -0300 |
---|---|---|
committer | Hernan Grecco <hgrecco@gmail.com> | 2023-05-01 19:24:18 -0300 |
commit | 556aeea0d363f5757c42296ad66ffe47f05c02b2 (patch) | |
tree | 476472495bd9c623e098aa6b5ee7554fd44387a5 | |
parent | 95f3eaca1129b735cb3eae8702ea857928a05909 (diff) | |
parent | b1c01862a3811f77cca726675b52a32418c4d853 (diff) | |
download | pint-556aeea0d363f5757c42296ad66ffe47f05c02b2.tar.gz |
Merge changes to modernize code from 0.21 to 0.22
See #1751
86 files changed, 1377 insertions, 1038 deletions
diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 369b9b9..7dd55db 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -7,15 +7,15 @@ jobs: strategy: fail-fast: false matrix: - python-version: [3.8, 3.9, "3.10", "3.11"] - numpy: [null, "numpy>=1.19,<2.0.0"] + python-version: [3.9, "3.10", "3.11"] + numpy: [null, "numpy>=1.21,<2.0.0"] uncertainties: [null, "uncertainties==3.1.6", "uncertainties>=3.1.6,<4.0.0"] extras: [null] include: - - python-version: 3.8 # Minimal versions + - python-version: 3.9 # Minimal versions numpy: "numpy" extras: matplotlib==2.2.5 - - python-version: 3.8 + - python-version: 3.9 numpy: "numpy" uncertainties: "uncertainties" extras: "sparse xarray netCDF4 dask[complete]==2023.4.0 graphviz babel==2.8" @@ -92,8 +92,8 @@ jobs: strategy: fail-fast: false matrix: - python-version: [3.8, 3.9, "3.10", "3.11"] - numpy: [ "numpy>=1.19,<2.0.0" ] + python-version: [3.9, "3.10", "3.11"] + numpy: [ "numpy>=1.21,<2.0.0" ] runs-on: windows-latest env: @@ -153,8 +153,8 @@ jobs: strategy: fail-fast: false matrix: - python-version: [3.8, 3.9, "3.10", "3.11"] - numpy: [null, "numpy>=1.19,<2.0.0" ] + python-version: [3.9, "3.10", "3.11"] + numpy: [null, "numpy>=1.21,<2.0.0" ] runs-on: macos-latest env: @@ -226,13 +226,3 @@ jobs: # run: | # pip install coveralls "requests<2.29" # coveralls --finish - - # Dummy task to summarize all. See https://github.com/bors-ng/bors-ng/issues/1300 - # ci-success: - # name: ci - # if: ${{ success() }} - # needs: test-linux - # runs-on: ubuntu-latest - # steps: - # - name: CI succeeded - # run: exit 0 diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 2340683..0a26da8 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -14,10 +14,10 @@ jobs: - name: Get tags run: git fetch --depth=1 origin +refs/tags/*:refs/tags/* - - name: Set up Python 3.8 + - name: Set up minimal Python version uses: actions/setup-python@v2 with: - python-version: 3.8 + python-version: 3.9 - name: Get pip cache dir id: pip-cache diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml new file mode 100644 index 0000000..3cf9f79 --- /dev/null +++ b/.github/workflows/publish.yml @@ -0,0 +1,27 @@ +name: Build and publish to PyPI + +on: + push: + tags: + - '*' + +jobs: + publish: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + + - uses: actions/setup-python@v4 + with: + python-version: '3.x' + + - name: Install dependencies + run: python -m pip install build + + - name: Build package + run: python -m build + + - name: Publish to PyPI + uses: pypa/gh-action-pypi-publish@release/v1 + with: + password: ${{ secrets.PYPI_API_TOKEN }} diff --git a/.readthedocs.yaml b/.readthedocs.yaml index 2bda3d4..830a8c2 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -5,7 +5,7 @@ sphinx: configuration: docs/conf.py fail_on_warning: false python: - version: 3.8 + version: 3.9 install: - requirements: requirements_docs.txt - method: pip @@ -43,7 +43,7 @@ and constants. Due to its modular design, you can extend (or even rewrite!) the complete list without changing the source code. It supports a lot of numpy mathematical operations **without monkey patching or wrapping numpy**. -It has a complete test coverage. It runs in Python 3.8+ with no other dependency. +It has a complete test coverage. It runs in Python 3.9+ with no other dependency. It is licensed under BSD. It is extremely easy and natural to use: diff --git a/benchmarks/benchmarks/20_quantity.py b/benchmarks/benchmarks/20_quantity.py index c0174ef..cbd03b2 100644 --- a/benchmarks/benchmarks/20_quantity.py +++ b/benchmarks/benchmarks/20_quantity.py @@ -8,7 +8,7 @@ from . import util units = ("meter", "kilometer", "second", "minute", "angstrom") all_values = ("int", "float", "complex") all_values_q = tuple( - "%s_%s" % (a, b) for a, b in it.product(all_values, ("meter", "kilometer")) + f"{a}_{b}" for a, b in it.product(all_values, ("meter", "kilometer")) ) op1 = (operator.neg, operator.truth) diff --git a/benchmarks/benchmarks/30_numpy.py b/benchmarks/benchmarks/30_numpy.py index 15ae66c..139ce58 100644 --- a/benchmarks/benchmarks/30_numpy.py +++ b/benchmarks/benchmarks/30_numpy.py @@ -9,11 +9,11 @@ from . import util lengths = ("short", "mid") all_values = tuple( - "%s_%s" % (a, b) for a, b in it.product(lengths, ("list", "tuple", "array")) + f"{a}_{b}" for a, b in it.product(lengths, ("list", "tuple", "array")) ) all_arrays = ("short_array", "mid_array") units = ("meter", "kilometer") -all_arrays_q = tuple("%s_%s" % (a, b) for a, b in it.product(all_arrays, units)) +all_arrays_q = tuple(f"{a}_{b}" for a, b in it.product(all_arrays, units)) ureg = None data = {} diff --git a/bors.toml b/bors.toml deleted file mode 100644 index 4e9e7be..0000000 --- a/bors.toml +++ /dev/null @@ -1,8 +0,0 @@ -status = [ - "ci", - "docbuild", - "lint" -] -delete_merged_branches = true -timeout_sec = 10800 -block_labels = [ "do-not-merge-yet" ] diff --git a/docs/dev/contributing.rst b/docs/dev/contributing.rst index c63381b..e70a375 100644 --- a/docs/dev/contributing.rst +++ b/docs/dev/contributing.rst @@ -9,7 +9,6 @@ Pint uses (and thanks): - `github actions`_ to test all commits and PRs. - coveralls_ to monitor coverage test coverage - readthedocs_ to host the documentation. -- `bors-ng`_ as a merge bot and therefore every PR is tested before merging. - black_, isort_ and flake8_ as code linters and pre-commit_ to enforce them. - pytest_ to write tests - sphinx_ to write docs. @@ -133,7 +132,6 @@ features that work best as an extension package versus direct inclusion in Pint .. _github: http://github.com/hgrecco/pint .. _`issue tracker`: https://github.com/hgrecco/pint/issues -.. _`bors-ng`: https://github.com/bors-ng/bors-ng .. _`github docs`: https://help.github.com/articles/closing-issues-via-commit-messages/ .. _`github actions`: https://docs.github.com/en/actions .. _coveralls: https://coveralls.io/ diff --git a/docs/getting/index.rst b/docs/getting/index.rst index 9907aeb..41ffaf9 100644 --- a/docs/getting/index.rst +++ b/docs/getting/index.rst @@ -8,7 +8,7 @@ The getting started guide aims to get you using pint productively as quickly as Installation ------------ -Pint has no dependencies except Python itself. In runs on Python 3.8+. +Pint has no dependencies except Python itself. In runs on Python 3.9+. .. grid:: 2 diff --git a/docs/getting/overview.rst b/docs/getting/overview.rst index cd639aa..61dfc14 100644 --- a/docs/getting/overview.rst +++ b/docs/getting/overview.rst @@ -14,7 +14,7 @@ Due to its modular design, you can extend (or even rewrite!) the complete list without changing the source code. It supports a lot of numpy mathematical operations **without monkey patching or wrapping numpy**. -It has a complete test coverage. It runs in Python 3.8+ with no other +It has a complete test coverage. It runs in Python 3.9+ with no other dependencies. It is licensed under a `BSD 3-clause style license`_. It is extremely easy and natural to use: diff --git a/pint/_typing.py b/pint/_typing.py index 64c3a2b..5547f85 100644 --- a/pint/_typing.py +++ b/pint/_typing.py @@ -1,17 +1,61 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Callable, Tuple, TypeVar, Union +from typing import TYPE_CHECKING, Any, Callable, TypeVar, Union, Protocol + +# TODO: Remove when 3.11 becomes minimal version. +Self = TypeVar("Self") if TYPE_CHECKING: from .facets.plain import PlainQuantity as Quantity from .facets.plain import PlainUnit as Unit from .util import UnitsContainer + +class PintScalar(Protocol): + def __add__(self, other: Any) -> Any: + ... + + def __sub__(self, other: Any) -> Any: + ... + + def __mul__(self, other: Any) -> Any: + ... + + def __truediv__(self, other: Any) -> Any: + ... + + def __floordiv__(self, other: Any) -> Any: + ... + + def __mod__(self, other: Any) -> Any: + ... + + def __divmod__(self, other: Any) -> Any: + ... + + def __pow__(self, other: Any, modulo: Any) -> Any: + ... + + +class PintArray(Protocol): + def __len__(self) -> int: + ... + + def __getitem__(self, key: Any) -> Any: + ... + + def __setitem__(self, key: Any, value: Any) -> None: + ... + + +Magnitude = PintScalar | PintScalar + + UnitLike = Union[str, "UnitsContainer", "Unit"] QuantityOrUnitLike = Union["Quantity", UnitLike] -Shape = Tuple[int, ...] +Shape = tuple[int, ...] _MagnitudeType = TypeVar("_MagnitudeType") S = TypeVar("S") diff --git a/pint/babel_names.py b/pint/babel_names.py index 09fa046..408ef8f 100644 --- a/pint/babel_names.py +++ b/pint/babel_names.py @@ -10,7 +10,7 @@ from __future__ import annotations from .compat import HAS_BABEL -_babel_units = dict( +_babel_units: dict[str, str] = dict( standard_gravity="acceleration-g-force", millibar="pressure-millibar", metric_ton="mass-metric-ton", @@ -141,6 +141,6 @@ _babel_units = dict( if not HAS_BABEL: _babel_units = {} -_babel_systems = dict(mks="metric", imperial="uksystem", US="ussystem") +_babel_systems: dict[str, str] = dict(mks="metric", imperial="uksystem", US="ussystem") -_babel_lengths = ["narrow", "short", "long"] +_babel_lengths: list[str] = ["narrow", "short", "long"] diff --git a/pint/compat.py b/pint/compat.py index c585455..7b48efa 100644 --- a/pint/compat.py +++ b/pint/compat.py @@ -16,13 +16,21 @@ from decimal import Decimal from importlib import import_module from io import BytesIO from numbers import Number -from typing import Mapping, Optional +from collections.abc import Mapping +from typing import Any, NoReturn, Callable +from collections.abc import Generator, Iterable -def missing_dependency(package, display_name=None): +def missing_dependency( + package: str, display_name: str | None = None +) -> Callable[..., NoReturn]: + """Return a helper function that raises an exception when used. + + It provides a way delay a missing dependency exception until it is used. + """ display_name = display_name or package - def _inner(*args, **kwargs): + def _inner(*args: Any, **kwargs: Any) -> NoReturn: raise Exception( "This feature requires %s. Please install it by running:\n" "pip install %s" % (display_name, package) @@ -31,7 +39,14 @@ def missing_dependency(package, display_name=None): return _inner -def tokenizer(input_string): +def tokenizer(input_string: str) -> Generator[tokenize.TokenInfo, None, None]: + """Tokenize an input string, encoded as UTF-8 + and skipping the ENCODING token. + + See Also + -------- + tokenize.tokenize + """ for tokinfo in tokenize.tokenize(BytesIO(input_string.encode("utf-8")).readline): if tokinfo.type != tokenize.ENCODING: yield tokinfo @@ -53,7 +68,7 @@ try: def _to_magnitude(value, force_ndarray=False, force_ndarray_like=False): if isinstance(value, (dict, bool)) or value is None: - raise TypeError("Invalid magnitude for Quantity: {0!r}".format(value)) + raise TypeError(f"Invalid magnitude for Quantity: {value!r}") elif isinstance(value, str) and value == "": raise ValueError("Quantity magnitude cannot be an empty string.") elif isinstance(value, (list, tuple)): @@ -102,7 +117,7 @@ except ImportError: "Cannot force to ndarray or ndarray-like when NumPy is not present." ) elif isinstance(value, (dict, bool)) or value is None: - raise TypeError("Invalid magnitude for Quantity: {0!r}".format(value)) + raise TypeError(f"Invalid magnitude for Quantity: {value!r}") elif isinstance(value, str) and value == "": raise ValueError("Quantity magnitude cannot be an empty string.") elif isinstance(value, (list, tuple)): @@ -154,7 +169,8 @@ else: from math import log # noqa: F401 if not HAS_BABEL: - babel_parse = babel_units = missing_dependency("Babel") # noqa: F811 + babel_parse = missing_dependency("Babel") # noqa: F811 + babel_units = babel_parse if not HAS_MIP: mip_missing = missing_dependency("mip") @@ -176,6 +192,9 @@ except ImportError: dask_array = None +# TODO: merge with upcast_type_map + +#: List upcast type names upcast_type_names = ( "pint_pandas.PintArray", "pandas.Series", @@ -186,10 +205,12 @@ upcast_type_names = ( "xarray.core.dataarray.DataArray", ) -upcast_type_map: Mapping[str : Optional[type]] = {k: None for k in upcast_type_names} +#: Map type name to the actual type (for upcast types). +upcast_type_map: Mapping[str, type | None] = {k: None for k in upcast_type_names} def fully_qualified_name(t: type) -> str: + """Return the fully qualified name of a type.""" module = t.__module__ name = t.__qualname__ @@ -200,6 +221,10 @@ def fully_qualified_name(t: type) -> str: def check_upcast_type(obj: type) -> bool: + """Check if the type object is an upcast type.""" + + # TODO: merge or unify name with is_upcast_type + fqn = fully_qualified_name(obj) if fqn not in upcast_type_map: return False @@ -215,22 +240,17 @@ def check_upcast_type(obj: type) -> bool: def is_upcast_type(other: type) -> bool: + """Check if the type object is an upcast type.""" + + # TODO: merge or unify name with check_upcast_type + if other in upcast_type_map.values(): return True return check_upcast_type(other) -def is_duck_array_type(cls) -> bool: - """Check if the type object represents a (non-Quantity) duck array type. - - Parameters - ---------- - cls : class - - Returns - ------- - bool - """ +def is_duck_array_type(cls: type) -> bool: + """Check if the type object represents a (non-Quantity) duck array type.""" # TODO (NEP 30): replace duck array check with hasattr(other, "__duckarray__") return issubclass(cls, ndarray) or ( not hasattr(cls, "_magnitude") @@ -242,20 +262,21 @@ def is_duck_array_type(cls) -> bool: ) -def is_duck_array(obj): +def is_duck_array(obj: type) -> bool: + """Check if an object represents a (non-Quantity) duck array type.""" return is_duck_array_type(type(obj)) -def eq(lhs, rhs, check_all: bool): +def eq(lhs: Any, rhs: Any, check_all: bool) -> bool | Iterable[bool]: """Comparison of scalars and arrays. Parameters ---------- - lhs : object + lhs left-hand side - rhs : object + rhs right-hand side - check_all : bool + check_all if True, reduce sequence to single bool; return True if all the elements are equal. @@ -269,21 +290,21 @@ def eq(lhs, rhs, check_all: bool): return out -def isnan(obj, check_all: bool): - """Test for NaN or NaT +def isnan(obj: Any, check_all: bool) -> bool | Iterable[bool]: + """Test for NaN or NaT. Parameters ---------- - obj : object + obj scalar or vector - check_all : bool + check_all if True, reduce sequence to single bool; return True if any of the elements are NaN. Returns ------- bool or array_like of bool. - Always return False for non-numeric types. + Always return False for non-numeric types. """ if is_duck_array_type(type(obj)): if obj.dtype.kind in "if": @@ -302,21 +323,21 @@ def isnan(obj, check_all: bool): return False -def zero_or_nan(obj, check_all: bool): - """Test if obj is zero, NaN, or NaT +def zero_or_nan(obj: Any, check_all: bool) -> bool | Iterable[bool]: + """Test if obj is zero, NaN, or NaT. Parameters ---------- - obj : object + obj scalar or vector - check_all : bool + check_all if True, reduce sequence to single bool; return True if all the elements are zero, NaN, or NaT. Returns ------- bool or array_like of bool. - Always return False for non-numeric types. + Always return False for non-numeric types. """ out = eq(obj, 0, False) + isnan(obj, False) if check_all and is_duck_array_type(type(out)): diff --git a/pint/context.py b/pint/context.py index 4839926..6c74f65 100644 --- a/pint/context.py +++ b/pint/context.py @@ -18,3 +18,5 @@ if TYPE_CHECKING: #: Regex to match the header parts of a context. #: Regex to match variable names in an equation. + +# TODO: delete this file diff --git a/pint/converters.py b/pint/converters.py index 12248a8..9494ad1 100644 --- a/pint/converters.py +++ b/pint/converters.py @@ -13,6 +13,10 @@ from __future__ import annotations from dataclasses import dataclass from dataclasses import fields as dc_fields +from typing import Any + +from ._typing import Self, Magnitude + from .compat import HAS_NUMPY, exp, log # noqa: F401 @@ -24,17 +28,17 @@ class Converter: _param_names_to_subclass = {} @property - def is_multiplicative(self): + def is_multiplicative(self) -> bool: return True @property - def is_logarithmic(self): + def is_logarithmic(self) -> bool: return False - def to_reference(self, value, inplace=False): + def to_reference(self, value: Magnitude, inplace: bool = False) -> Magnitude: return value - def from_reference(self, value, inplace=False): + def from_reference(self, value: Magnitude, inplace: bool = False) -> Magnitude: return value def __init_subclass__(cls, **kwargs): @@ -43,21 +47,21 @@ class Converter: cls._subclasses.append(cls) @classmethod - def get_field_names(cls, new_cls): - return frozenset((p.name for p in dc_fields(new_cls))) + def get_field_names(cls, new_cls) -> frozenset[str]: + return frozenset(p.name for p in dc_fields(new_cls)) @classmethod def preprocess_kwargs(cls, **kwargs): return None @classmethod - def from_arguments(cls, **kwargs): + def from_arguments(cls: type[Self], **kwargs: Any) -> Self: kwk = frozenset(kwargs.keys()) try: new_cls = cls._param_names_to_subclass[kwk] except KeyError: for new_cls in cls._subclasses: - p_names = frozenset((p.name for p in dc_fields(new_cls))) + p_names = frozenset(p.name for p in dc_fields(new_cls)) if p_names == kwk: cls._param_names_to_subclass[kwk] = new_cls break diff --git a/pint/definitions.py b/pint/definitions.py index 789d9e3..ce89e94 100644 --- a/pint/definitions.py +++ b/pint/definitions.py @@ -8,6 +8,8 @@ :license: BSD, see LICENSE for more details. """ +from __future__ import annotations + from . import errors from ._vendor import flexparser as fp from .delegates import ParserConfig, txt_defparser @@ -17,12 +19,28 @@ class Definition: """This is kept for backwards compatibility""" @classmethod - def from_string(cls, s: str, non_int_type=float): + def from_string(cls, input_string: str, non_int_type: type = float) -> Definition: + """Parse a string into a definition object. + + Parameters + ---------- + input_string + Single line string. + non_int_type + Numerical type used for non integer values. + + Raises + ------ + DefinitionSyntaxError + If a syntax error was found. + """ cfg = ParserConfig(non_int_type) parser = txt_defparser.DefParser(cfg, None) - pp = parser.parse_string(s) + pp = parser.parse_string(input_string) for definition in parser.iter_parsed_project(pp): if isinstance(definition, Exception): raise errors.DefinitionSyntaxError(str(definition)) if not isinstance(definition, (fp.BOS, fp.BOF, fp.BOS)): return definition + + # TODO: What shall we do in this return path. diff --git a/pint/delegates/__init__.py b/pint/delegates/__init__.py index 363ef9c..b2eb9a3 100644 --- a/pint/delegates/__init__.py +++ b/pint/delegates/__init__.py @@ -11,4 +11,4 @@ from . import txt_defparser from .base_defparser import ParserConfig, build_disk_cache_class -__all__ = [txt_defparser, ParserConfig, build_disk_cache_class] +__all__ = ["txt_defparser", "ParserConfig", "build_disk_cache_class"] diff --git a/pint/delegates/base_defparser.py b/pint/delegates/base_defparser.py index d35f3e3..9e784ac 100644 --- a/pint/delegates/base_defparser.py +++ b/pint/delegates/base_defparser.py @@ -14,7 +14,6 @@ import functools import itertools import numbers import pathlib -import typing as ty from dataclasses import dataclass, field from pint import errors @@ -27,10 +26,10 @@ from .._vendor import flexparser as fp @dataclass(frozen=True) class ParserConfig: - """Configuration used by the parser.""" + """Configuration used by the parser in Pint.""" #: Indicates the output type of non integer numbers. - non_int_type: ty.Type[numbers.Number] = float + non_int_type: type[numbers.Number] = float def to_scaled_units_container(self, s: str): return ParserHelper.from_string(s, self.non_int_type) @@ -67,7 +66,12 @@ class ParserConfig: return val.scale -@functools.lru_cache() +@dataclass(frozen=True) +class PintParsedStatement(fp.ParsedStatement[ParserConfig]): + """A parsed statement for pint, specialized in the actual config.""" + + +@functools.lru_cache def build_disk_cache_class(non_int_type: type): """Build disk cache class, taking into account the non_int_type.""" @@ -84,14 +88,11 @@ def build_disk_cache_class(non_int_type: type): class ParsedProjecHeader(fc.NameByHashIter, PintHeader): @classmethod def from_parsed_project(cls, pp: fp.ParsedProject, reader_id): - tmp = [] - for stmt in pp.iter_statements(): - if isinstance(stmt, fp.BOS): - tmp.append( - stmt.content_hash.algorithm_name - + ":" - + stmt.content_hash.hexdigest - ) + tmp = ( + f"{stmt.content_hash.algorithm_name}:{stmt.content_hash.hexdigest}" + for stmt in pp.iter_statements() + if isinstance(stmt, fp.BOS) + ) return cls(tuple(tmp), reader_id) diff --git a/pint/delegates/txt_defparser/__init__.py b/pint/delegates/txt_defparser/__init__.py index 5572ca1..49e4a0b 100644 --- a/pint/delegates/txt_defparser/__init__.py +++ b/pint/delegates/txt_defparser/__init__.py @@ -11,4 +11,6 @@ from .defparser import DefParser -__all__ = [DefParser] +__all__ = [ + "DefParser", +] diff --git a/pint/delegates/txt_defparser/block.py b/pint/delegates/txt_defparser/block.py index 20ebcba..e8d8aa4 100644 --- a/pint/delegates/txt_defparser/block.py +++ b/pint/delegates/txt_defparser/block.py @@ -17,11 +17,14 @@ from __future__ import annotations from dataclasses import dataclass +from typing import Generic, TypeVar + +from ..base_defparser import PintParsedStatement, ParserConfig from ..._vendor import flexparser as fp @dataclass(frozen=True) -class EndDirectiveBlock(fp.ParsedStatement): +class EndDirectiveBlock(PintParsedStatement): """An EndDirectiveBlock is simply an "@end" statement.""" @classmethod @@ -31,8 +34,16 @@ class EndDirectiveBlock(fp.ParsedStatement): return None +OPST = TypeVar("OPST", bound="PintParsedStatement") +IPST = TypeVar("IPST", bound="PintParsedStatement") + +DefT = TypeVar("DefT") + + @dataclass(frozen=True) -class DirectiveBlock(fp.Block): +class DirectiveBlock( + Generic[DefT, OPST, IPST], fp.Block[OPST, IPST, EndDirectiveBlock, ParserConfig] +): """Directive blocks have beginning statement starting with a @ character. and ending with a "@end" (captured using a EndDirectiveBlock). @@ -41,5 +52,5 @@ class DirectiveBlock(fp.Block): closing: EndDirectiveBlock - def derive_definition(self): - pass + def derive_definition(self) -> DefT: + ... diff --git a/pint/delegates/txt_defparser/common.py b/pint/delegates/txt_defparser/common.py index 493d0ec..a1195b3 100644 --- a/pint/delegates/txt_defparser/common.py +++ b/pint/delegates/txt_defparser/common.py @@ -30,7 +30,7 @@ class DefinitionSyntaxError(errors.DefinitionSyntaxError, fp.ParsingError): location: str = field(init=False, default="") - def __str__(self): + def __str__(self) -> str: msg = ( self.msg + "\n " + (self.format_position or "") + " " + (self.raw or "") ) @@ -38,7 +38,7 @@ class DefinitionSyntaxError(errors.DefinitionSyntaxError, fp.ParsingError): msg += "\n " + self.location return msg - def set_location(self, value): + def set_location(self, value: str) -> None: super().__setattr__("location", value) @@ -47,7 +47,7 @@ class ImportDefinition(fp.IncludeStatement): value: str @property - def target(self): + def target(self) -> str: return self.value @classmethod diff --git a/pint/delegates/txt_defparser/context.py b/pint/delegates/txt_defparser/context.py index 5c54b4c..ce9fc9b 100644 --- a/pint/delegates/txt_defparser/context.py +++ b/pint/delegates/txt_defparser/context.py @@ -20,36 +20,35 @@ import numbers import re import typing as ty from dataclasses import dataclass -from typing import Dict, Tuple from ..._vendor import flexparser as fp from ...facets.context import definitions -from ..base_defparser import ParserConfig +from ..base_defparser import ParserConfig, PintParsedStatement from . import block, common, plain +# TODO check syntax +T = ty.TypeVar("T", bound="ForwardRelation | BidirectionalRelation") -@dataclass(frozen=True) -class Relation(definitions.Relation): - @classmethod - def _from_string_and_context_sep( - cls, s: str, config: ParserConfig, separator: str - ) -> fp.FromString[Relation]: - if separator not in s: - return None - if ":" not in s: - return None - rel, eq = s.split(":") +def _from_string_and_context_sep( + cls: type[T], s: str, config: ParserConfig, separator: str +) -> T | None: + if separator not in s: + return None + if ":" not in s: + return None + + rel, eq = s.split(":") - parts = rel.split(separator) + parts = rel.split(separator) - src, dst = (config.to_dimension_container(s) for s in parts) + src, dst = (config.to_dimension_container(s) for s in parts) - return cls(src, dst, eq.strip()) + return cls(src, dst, eq.strip()) @dataclass(frozen=True) -class ForwardRelation(fp.ParsedStatement, definitions.ForwardRelation, Relation): +class ForwardRelation(PintParsedStatement, definitions.ForwardRelation): """A relation connecting a dimension to another via a transformation function. <source dimension> -> <target dimension>: <transformation function> @@ -59,13 +58,11 @@ class ForwardRelation(fp.ParsedStatement, definitions.ForwardRelation, Relation) def from_string_and_config( cls, s: str, config: ParserConfig ) -> fp.FromString[ForwardRelation]: - return super()._from_string_and_context_sep(s, config, "->") + return _from_string_and_context_sep(cls, s, config, "->") @dataclass(frozen=True) -class BidirectionalRelation( - fp.ParsedStatement, definitions.BidirectionalRelation, Relation -): +class BidirectionalRelation(PintParsedStatement, definitions.BidirectionalRelation): """A bidirectional relation connecting a dimension to another via a simple transformation function. @@ -77,11 +74,11 @@ class BidirectionalRelation( def from_string_and_config( cls, s: str, config: ParserConfig ) -> fp.FromString[BidirectionalRelation]: - return super()._from_string_and_context_sep(s, config, "<->") + return _from_string_and_context_sep(cls, s, config, "<->") @dataclass(frozen=True) -class BeginContext(fp.ParsedStatement): +class BeginContext(PintParsedStatement): """Being of a context directive. @context[(defaults)] <canonical name> [= <alias>] [= <alias>] @@ -92,8 +89,8 @@ class BeginContext(fp.ParsedStatement): ) name: str - aliases: Tuple[str, ...] - defaults: Dict[str, numbers.Number] + aliases: tuple[str] + defaults: dict[str, numbers.Number] @classmethod def from_string_and_config( @@ -131,7 +128,18 @@ class BeginContext(fp.ParsedStatement): @dataclass(frozen=True) -class ContextDefinition(block.DirectiveBlock): +class ContextDefinition( + block.DirectiveBlock[ + definitions.ContextDefinition, + BeginContext, + ty.Union[ + plain.CommentDefinition, + BidirectionalRelation, + ForwardRelation, + plain.UnitDefinition, + ], + ] +): """Definition of a Context @context[(defaults)] <canonical name> [= <alias>] [= <alias>] @@ -170,27 +178,34 @@ class ContextDefinition(block.DirectiveBlock): ] ] - def derive_definition(self): + def derive_definition(self) -> definitions.ContextDefinition: return definitions.ContextDefinition( self.name, self.aliases, self.defaults, self.relations, self.redefinitions ) @property - def name(self): + def name(self) -> str: + assert isinstance(self.opening, BeginContext) return self.opening.name @property - def aliases(self): + def aliases(self) -> tuple[str]: + assert isinstance(self.opening, BeginContext) return self.opening.aliases @property - def defaults(self): + def defaults(self) -> dict[str, numbers.Number]: + assert isinstance(self.opening, BeginContext) return self.opening.defaults @property - def relations(self): - return tuple(r for r in self.body if isinstance(r, Relation)) + def relations(self) -> tuple[BidirectionalRelation | ForwardRelation]: + return tuple( + r + for r in self.body + if isinstance(r, (ForwardRelation, BidirectionalRelation)) + ) @property - def redefinitions(self): + def redefinitions(self) -> tuple[plain.UnitDefinition]: return tuple(r for r in self.body if isinstance(r, plain.UnitDefinition)) diff --git a/pint/delegates/txt_defparser/defaults.py b/pint/delegates/txt_defparser/defaults.py index af6e31f..688d90f 100644 --- a/pint/delegates/txt_defparser/defaults.py +++ b/pint/delegates/txt_defparser/defaults.py @@ -19,10 +19,11 @@ from dataclasses import dataclass, fields from ..._vendor import flexparser as fp from ...facets.plain import definitions from . import block, plain +from ..base_defparser import PintParsedStatement @dataclass(frozen=True) -class BeginDefaults(fp.ParsedStatement): +class BeginDefaults(PintParsedStatement): """Being of a defaults directive. @defaults @@ -36,7 +37,16 @@ class BeginDefaults(fp.ParsedStatement): @dataclass(frozen=True) -class DefaultsDefinition(block.DirectiveBlock): +class DefaultsDefinition( + block.DirectiveBlock[ + definitions.DefaultsDefinition, + BeginDefaults, + ty.Union[ + plain.CommentDefinition, + plain.Equality, + ], + ] +): """Directive to store values. @defaults @@ -55,10 +65,10 @@ class DefaultsDefinition(block.DirectiveBlock): ] @property - def _valid_fields(self): + def _valid_fields(self) -> tuple[str]: return tuple(f.name for f in fields(definitions.DefaultsDefinition)) - def derive_definition(self): + def derive_definition(self) -> definitions.DefaultsDefinition: for definition in self.filter_by(plain.Equality): if definition.lhs not in self._valid_fields: raise ValueError( @@ -70,7 +80,7 @@ class DefaultsDefinition(block.DirectiveBlock): *tuple(self.get_key(key) for key in self._valid_fields) ) - def get_key(self, key): + def get_key(self, key: str) -> str: for stmt in self.body: if isinstance(stmt, plain.Equality) and stmt.lhs == key: return stmt.rhs diff --git a/pint/delegates/txt_defparser/defparser.py b/pint/delegates/txt_defparser/defparser.py index 0b99d6d..f1b8e45 100644 --- a/pint/delegates/txt_defparser/defparser.py +++ b/pint/delegates/txt_defparser/defparser.py @@ -5,11 +5,28 @@ import typing as ty from ..._vendor import flexcache as fc from ..._vendor import flexparser as fp -from .. import base_defparser +from ..base_defparser import ParserConfig from . import block, common, context, defaults, group, plain, system -class PintRootBlock(fp.RootBlock): +class PintRootBlock( + fp.RootBlock[ + ty.Union[ + plain.CommentDefinition, + common.ImportDefinition, + context.ContextDefinition, + defaults.DefaultsDefinition, + system.SystemDefinition, + group.GroupDefinition, + plain.AliasDefinition, + plain.DerivedDimensionDefinition, + plain.DimensionDefinition, + plain.PrefixDefinition, + plain.UnitDefinition, + ], + ParserConfig, + ] +): body: fp.Multi[ ty.Union[ plain.CommentDefinition, @@ -27,11 +44,15 @@ class PintRootBlock(fp.RootBlock): ] +class PintSource(fp.ParsedSource[PintRootBlock, ParserConfig]): + """Source code in Pint.""" + + class HashTuple(tuple): pass -class _PintParser(fp.Parser): +class _PintParser(fp.Parser[PintRootBlock, ParserConfig]): """Parser for the original Pint definition file, with cache.""" _delimiters = { @@ -46,11 +67,11 @@ class _PintParser(fp.Parser): _diskcache: fc.DiskCache - def __init__(self, config: base_defparser.ParserConfig, *args, **kwargs): + def __init__(self, config: ParserConfig, *args, **kwargs): self._diskcache = kwargs.pop("diskcache", None) super().__init__(config, *args, **kwargs) - def parse_file(self, path: pathlib.Path) -> fp.ParsedSource: + def parse_file(self, path: pathlib.Path) -> PintSource: if self._diskcache is None: return super().parse_file(path) content, basename = self._diskcache.load(path, super().parse_file) @@ -58,7 +79,13 @@ class _PintParser(fp.Parser): class DefParser: - skip_classes = (fp.BOF, fp.BOR, fp.BOS, fp.EOS, plain.CommentDefinition) + skip_classes: tuple[type] = ( + fp.BOF, + fp.BOR, + fp.BOS, + fp.EOS, + plain.CommentDefinition, + ) def __init__(self, default_config, diskcache): self._default_config = default_config @@ -78,6 +105,8 @@ class DefParser: continue if isinstance(stmt, common.DefinitionSyntaxError): + # TODO: check why this assert fails + # assert isinstance(last_location, str) stmt.set_location(last_location) raise stmt elif isinstance(stmt, block.DirectiveBlock): @@ -101,7 +130,7 @@ class DefParser: else: yield stmt - def parse_file(self, filename: pathlib.Path, cfg=None): + def parse_file(self, filename: pathlib.Path, cfg: ParserConfig | None = None): return fp.parse( filename, _PintParser, @@ -109,7 +138,7 @@ class DefParser: diskcache=self._diskcache, ) - def parse_string(self, content: str, cfg=None): + def parse_string(self, content: str, cfg: ParserConfig | None = None): return fp.parse_bytes( content.encode("utf-8"), _PintParser, diff --git a/pint/delegates/txt_defparser/group.py b/pint/delegates/txt_defparser/group.py index 5be42ac..e96d44b 100644 --- a/pint/delegates/txt_defparser/group.py +++ b/pint/delegates/txt_defparser/group.py @@ -23,10 +23,11 @@ from dataclasses import dataclass from ..._vendor import flexparser as fp from ...facets.group import definitions from . import block, common, plain +from ..base_defparser import PintParsedStatement @dataclass(frozen=True) -class BeginGroup(fp.ParsedStatement): +class BeginGroup(PintParsedStatement): """Being of a group directive. @group <name> [using <group 1>, ..., <group N>] @@ -59,7 +60,16 @@ class BeginGroup(fp.ParsedStatement): @dataclass(frozen=True) -class GroupDefinition(block.DirectiveBlock): +class GroupDefinition( + block.DirectiveBlock[ + definitions.GroupDefinition, + BeginGroup, + ty.Union[ + plain.CommentDefinition, + plain.UnitDefinition, + ], + ] +): """Definition of a group. @group <name> [using <group 1>, ..., <group N>] @@ -88,19 +98,21 @@ class GroupDefinition(block.DirectiveBlock): ] ] - def derive_definition(self): + def derive_definition(self) -> definitions.GroupDefinition: return definitions.GroupDefinition( self.name, self.using_group_names, self.definitions ) @property - def name(self): + def name(self) -> str: + assert isinstance(self.opening, BeginGroup) return self.opening.name @property - def using_group_names(self): + def using_group_names(self) -> tuple[str]: + assert isinstance(self.opening, BeginGroup) return self.opening.using_group_names @property - def definitions(self) -> ty.Tuple[plain.UnitDefinition, ...]: + def definitions(self) -> tuple[plain.UnitDefinition]: return tuple(el for el in self.body if isinstance(el, plain.UnitDefinition)) diff --git a/pint/delegates/txt_defparser/plain.py b/pint/delegates/txt_defparser/plain.py index 428df10..9c7bd42 100644 --- a/pint/delegates/txt_defparser/plain.py +++ b/pint/delegates/txt_defparser/plain.py @@ -29,12 +29,12 @@ from ..._vendor import flexparser as fp from ...converters import Converter from ...facets.plain import definitions from ...util import UnitsContainer -from ..base_defparser import ParserConfig +from ..base_defparser import ParserConfig, PintParsedStatement from . import common @dataclass(frozen=True) -class Equality(fp.ParsedStatement, definitions.Equality): +class Equality(PintParsedStatement, definitions.Equality): """An equality statement contains a left and right hand separated lhs and rhs should be space stripped. @@ -53,7 +53,7 @@ class Equality(fp.ParsedStatement, definitions.Equality): @dataclass(frozen=True) -class CommentDefinition(fp.ParsedStatement, definitions.CommentDefinition): +class CommentDefinition(PintParsedStatement, definitions.CommentDefinition): """Comments start with a # character. # This is a comment. @@ -63,14 +63,14 @@ class CommentDefinition(fp.ParsedStatement, definitions.CommentDefinition): """ @classmethod - def from_string(cls, s: str) -> fp.FromString[fp.ParsedStatement]: + def from_string(cls, s: str) -> fp.FromString[CommentDefinition]: if not s.startswith("#"): return None return cls(s[1:].strip()) @dataclass(frozen=True) -class PrefixDefinition(fp.ParsedStatement, definitions.PrefixDefinition): +class PrefixDefinition(PintParsedStatement, definitions.PrefixDefinition): """Definition of a prefix:: <prefix>- = <value> [= <symbol>] [= <alias>] [ = <alias> ] [...] @@ -119,7 +119,7 @@ class PrefixDefinition(fp.ParsedStatement, definitions.PrefixDefinition): @dataclass(frozen=True) -class UnitDefinition(fp.ParsedStatement, definitions.UnitDefinition): +class UnitDefinition(PintParsedStatement, definitions.UnitDefinition): """Definition of a unit:: <canonical name> = <relation to another unit or dimension> [= <symbol>] [= <alias>] [ = <alias> ] [...] @@ -159,10 +159,10 @@ class UnitDefinition(fp.ParsedStatement, definitions.UnitDefinition): [converter, modifiers] = value.split(";", 1) try: - modifiers = dict( - (key.strip(), config.to_number(value)) + modifiers = { + key.strip(): config.to_number(value) for key, value in (part.split(":") for part in modifiers.split(";")) - ) + } except definitions.NotNumeric as ex: return common.DefinitionSyntaxError( f"Unit definition ('{name}') must contain only numbers in modifier, not {ex.value}" @@ -194,7 +194,7 @@ class UnitDefinition(fp.ParsedStatement, definitions.UnitDefinition): @dataclass(frozen=True) -class DimensionDefinition(fp.ParsedStatement, definitions.DimensionDefinition): +class DimensionDefinition(PintParsedStatement, definitions.DimensionDefinition): """Definition of a root dimension:: [dimension name] @@ -221,7 +221,7 @@ class DimensionDefinition(fp.ParsedStatement, definitions.DimensionDefinition): @dataclass(frozen=True) class DerivedDimensionDefinition( - fp.ParsedStatement, definitions.DerivedDimensionDefinition + PintParsedStatement, definitions.DerivedDimensionDefinition ): """Definition of a derived dimension:: @@ -261,7 +261,7 @@ class DerivedDimensionDefinition( @dataclass(frozen=True) -class AliasDefinition(fp.ParsedStatement, definitions.AliasDefinition): +class AliasDefinition(PintParsedStatement, definitions.AliasDefinition): """Additional alias(es) for an already existing unit:: @alias <canonical name or previous alias> = <alias> [ = <alias> ] [...] diff --git a/pint/delegates/txt_defparser/system.py b/pint/delegates/txt_defparser/system.py index b21fd7a..4efbb4d 100644 --- a/pint/delegates/txt_defparser/system.py +++ b/pint/delegates/txt_defparser/system.py @@ -14,11 +14,12 @@ from dataclasses import dataclass from ..._vendor import flexparser as fp from ...facets.system import definitions +from ..base_defparser import PintParsedStatement from . import block, common, plain @dataclass(frozen=True) -class BaseUnitRule(fp.ParsedStatement, definitions.BaseUnitRule): +class BaseUnitRule(PintParsedStatement, definitions.BaseUnitRule): @classmethod def from_string(cls, s: str) -> fp.FromString[BaseUnitRule]: if ":" not in s: @@ -32,7 +33,7 @@ class BaseUnitRule(fp.ParsedStatement, definitions.BaseUnitRule): @dataclass(frozen=True) -class BeginSystem(fp.ParsedStatement): +class BeginSystem(PintParsedStatement): """Being of a system directive. @system <name> [using <group 1>, ..., <group N>] @@ -67,7 +68,13 @@ class BeginSystem(fp.ParsedStatement): @dataclass(frozen=True) -class SystemDefinition(block.DirectiveBlock): +class SystemDefinition( + block.DirectiveBlock[ + definitions.SystemDefinition, + BeginSystem, + ty.Union[plain.CommentDefinition, BaseUnitRule], + ] +): """Definition of a System: @system <name> [using <group 1>, ..., <group N>] @@ -92,19 +99,21 @@ class SystemDefinition(block.DirectiveBlock): opening: fp.Single[BeginSystem] body: fp.Multi[ty.Union[plain.CommentDefinition, BaseUnitRule]] - def derive_definition(self): + def derive_definition(self) -> definitions.SystemDefinition: return definitions.SystemDefinition( self.name, self.using_group_names, self.rules ) @property - def name(self): + def name(self) -> str: + assert isinstance(self.opening, BeginSystem) return self.opening.name @property - def using_group_names(self): + def using_group_names(self) -> tuple[str]: + assert isinstance(self.opening, BeginSystem) return self.opening.using_group_names @property - def rules(self): + def rules(self) -> tuple[BaseUnitRule]: return tuple(el for el in self.body if isinstance(el, BaseUnitRule)) diff --git a/pint/errors.py b/pint/errors.py index 8f849da..6cebb21 100644 --- a/pint/errors.py +++ b/pint/errors.py @@ -36,18 +36,21 @@ MSG_INVALID_SYSTEM_NAME = ( ) -def is_dim(name): +def is_dim(name: str) -> bool: + """Return True if the name is flanked by square brackets `[` and `]`.""" return name[0] == "[" and name[-1] == "]" -def is_valid_prefix_name(name): +def is_valid_prefix_name(name: str) -> bool: + """Return True if the name is a valid python identifier or empty.""" return str.isidentifier(name) or name == "" is_valid_unit_name = is_valid_system_name = is_valid_context_name = str.isidentifier -def _no_space(name): +def _no_space(name: str) -> bool: + """Return False if the name contains a space in any position.""" return name.strip() == name and " " not in name @@ -58,7 +61,14 @@ is_valid_unit_alias = ( ) = is_valid_unit_symbol = is_valid_prefix_symbol = _no_space -def is_valid_dimension_name(name): +def is_valid_dimension_name(name: str) -> bool: + """Return True if the name is consistent with a dimension name. + + - flanked by square brackets. + - empty dimension name or identifier. + """ + + # TODO: shall we check also fro spaces? return name == "[]" or ( len(name) > 1 and is_dim(name) and str.isidentifier(name[1:-1]) ) @@ -67,8 +77,8 @@ def is_valid_dimension_name(name): class WithDefErr: """Mixing class to make some classes more readable.""" - def def_err(self, msg): - return DefinitionError(self.name, self.__class__.__name__, msg) + def def_err(self, msg: str): + return DefinitionError(self.name, self.__class__, msg) @dataclass(frozen=False) @@ -81,7 +91,7 @@ class DefinitionError(ValueError, PintError): """Raised when a definition is not properly constructed.""" name: str - definition_type: ty.Type + definition_type: type msg: str def __str__(self): @@ -110,7 +120,7 @@ class RedefinitionError(ValueError, PintError): """Raised when a unit or prefix is redefined.""" name: str - definition_type: ty.Type + definition_type: type def __str__(self): msg = f"Cannot redefine '{self.name}' ({self.definition_type})" @@ -124,7 +134,7 @@ class RedefinitionError(ValueError, PintError): class UndefinedUnitError(AttributeError, PintError): """Raised when the units are not defined in the unit registry.""" - unit_names: ty.Union[str, ty.Tuple[str, ...]] + unit_names: str | tuple[str] def __str__(self): if isinstance(self.unit_names, str): diff --git a/pint/facets/__init__.py b/pint/facets/__init__.py index d669b9f..750f729 100644 --- a/pint/facets/__init__.py +++ b/pint/facets/__init__.py @@ -30,8 +30,8 @@ class NumpyRegistry: - _quantity_class = NumpyQuantity - _unit_class = NumpyUnit + Quantity = NumpyQuantity + Unit = NumpyUnit This tells pint that it should use NumpyQuantity as base class for a quantity class that belongs to a registry that has NumpyRegistry as one of its bases. @@ -82,13 +82,13 @@ from .plain import PlainRegistry from .system import SystemRegistry __all__ = [ - ContextRegistry, - DaskRegistry, - FormattingRegistry, - GroupRegistry, - MeasurementRegistry, - NonMultiplicativeRegistry, - NumpyRegistry, - PlainRegistry, - SystemRegistry, + "ContextRegistry", + "DaskRegistry", + "FormattingRegistry", + "GroupRegistry", + "MeasurementRegistry", + "NonMultiplicativeRegistry", + "NumpyRegistry", + "PlainRegistry", + "SystemRegistry", ] diff --git a/pint/facets/context/definitions.py b/pint/facets/context/definitions.py index fbdb390..833857e 100644 --- a/pint/facets/context/definitions.py +++ b/pint/facets/context/definitions.py @@ -12,7 +12,8 @@ import itertools import numbers import re from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Callable, Dict, Set, Tuple +from typing import TYPE_CHECKING, Any, Callable +from collections.abc import Iterable from ... import errors from ..plain import UnitDefinition @@ -41,7 +42,7 @@ class Relation: # could be used. @property - def variables(self) -> Set[str, ...]: + def variables(self) -> set[str]: """Find all variables names in the equation.""" return set(self._varname_re.findall(self.equation)) @@ -55,7 +56,7 @@ class Relation: ) @property - def bidirectional(self): + def bidirectional(self) -> bool: raise NotImplementedError @@ -92,18 +93,18 @@ class ContextDefinition(errors.WithDefErr): #: name of the context name: str #: other na - aliases: Tuple[str, ...] - defaults: Dict[str, numbers.Number] - relations: Tuple[Relation, ...] - redefinitions: Tuple[UnitDefinition, ...] + aliases: tuple[str] + defaults: dict[str, numbers.Number] + relations: tuple[Relation] + redefinitions: tuple[UnitDefinition] @property - def variables(self) -> Set[str, ...]: + def variables(self) -> set[str]: """Return all variable names in all transformations.""" return set().union(*(r.variables for r in self.relations)) @classmethod - def from_lines(cls, lines, non_int_type): + def from_lines(cls, lines: Iterable[str], non_int_type: type): # TODO: this is to keep it backwards compatible from ...delegates import ParserConfig, txt_defparser diff --git a/pint/facets/context/objects.py b/pint/facets/context/objects.py index 40c2bb5..38d8805 100644 --- a/pint/facets/context/objects.py +++ b/pint/facets/context/objects.py @@ -10,7 +10,8 @@ from __future__ import annotations import weakref from collections import ChainMap, defaultdict -from typing import Optional, Tuple +from typing import Any +from collections.abc import Iterable from ...facets.plain import UnitDefinition from ...util import UnitsContainer, to_units_container @@ -70,9 +71,9 @@ class Context: def __init__( self, - name: Optional[str] = None, - aliases: Tuple[str, ...] = (), - defaults: Optional[dict] = None, + name: str | None = None, + aliases: tuple[str] = tuple(), + defaults: dict[str, Any] | None = None, ) -> None: self.name = name self.aliases = aliases @@ -94,7 +95,7 @@ class Context: self.relation_to_context = weakref.WeakValueDictionary() @classmethod - def from_context(cls, context: Context, **defaults) -> Context: + def from_context(cls, context: Context, **defaults: Any) -> Context: """Creates a new context that shares the funcs dictionary with the original context. The default values are copied from the original context and updated with the new defaults. @@ -123,7 +124,9 @@ class Context: return context @classmethod - def from_lines(cls, lines, to_base_func=None, non_int_type=float) -> Context: + def from_lines( + cls, lines: Iterable[str], to_base_func=None, non_int_type: type = float + ) -> Context: cd = ContextDefinition.from_lines(lines, non_int_type) return cls.from_definition(cd, to_base_func) @@ -166,7 +169,7 @@ class Context: del self.relation_to_context[_key] @staticmethod - def __keytransform__(src, dst) -> Tuple[UnitsContainer, UnitsContainer]: + def __keytransform__(src, dst) -> tuple[UnitsContainer, UnitsContainer]: return to_units_container(src), to_units_container(dst) def transform(self, src, dst, registry, value): @@ -199,7 +202,7 @@ class Context: def hashable( self, - ) -> Tuple[Optional[str], Tuple[str, ...], frozenset, frozenset, tuple]: + ) -> tuple[str | None, tuple[str, ...], frozenset, frozenset, tuple]: """Generate a unique hashable and comparable representation of self, which can be used as a key in a dict. This class cannot define ``__hash__`` because it is mutable, and the Python interpreter does cache the output of ``__hash__``. @@ -274,7 +277,7 @@ class ContextChain(ChainMap): """ return self[(src, dst)].transform(src, dst, registry, value) - def hashable(self): + def hashable(self) -> tuple[Any]: """Generate a unique hashable and comparable representation of self, which can be used as a key in a dict. This class cannot define ``__hash__`` because it is mutable, and the Python interpreter does cache the output of ``__hash__``. diff --git a/pint/facets/context/registry.py b/pint/facets/context/registry.py index ccf69d2..a36d82d 100644 --- a/pint/facets/context/registry.py +++ b/pint/facets/context/registry.py @@ -11,14 +11,14 @@ from __future__ import annotations import functools from collections import ChainMap from contextlib import contextmanager -from typing import Any, Callable, ContextManager, Dict, Union +from typing import Any, Callable, ContextManager from ..._typing import F from ...errors import UndefinedUnitError from ...util import find_connected_nodes, find_shortest_path, logger from ..plain import PlainRegistry, UnitDefinition from .definitions import ContextDefinition -from .objects import Context, ContextChain +from . import objects # TODO: Put back annotation when possible # registry_cache: "RegistryCache" @@ -50,13 +50,13 @@ class ContextRegistry(PlainRegistry): - Parse @context directive. """ - Context = Context + Context = objects.Context def __init__(self, **kwargs: Any) -> None: # Map context name (string) or abbreviation to context. - self._contexts: Dict[str, Context] = {} + self._contexts: dict[str, objects.Context] = {} # Stores active contexts. - self._active_ctx = ContextChain() + self._active_ctx = objects.ContextChain() # Map context chain to cache self._caches = {} # Map context chain to units override @@ -71,7 +71,7 @@ class ContextRegistry(PlainRegistry): super()._register_definition_adders() self._register_adder(ContextDefinition, self.add_context) - def add_context(self, context: Union[Context, ContextDefinition]) -> None: + def add_context(self, context: Context | ContextDefinition) -> None: """Add a context object to the registry. The context will be accessible by its name and aliases. @@ -80,7 +80,7 @@ class ContextRegistry(PlainRegistry): see :meth:`enable_contexts`. """ if isinstance(context, ContextDefinition): - context = Context.from_definition(context, self.get_dimensionality) + context = objects.Context.from_definition(context, self.get_dimensionality) if not context.name: raise ValueError("Can't add unnamed context to registry") @@ -97,7 +97,7 @@ class ContextRegistry(PlainRegistry): ) self._contexts[alias] = context - def remove_context(self, name_or_alias: str) -> Context: + def remove_context(self, name_or_alias: str) -> objects.Context: """Remove a context from the registry and return it. Notice that this methods will not disable the context; @@ -194,7 +194,7 @@ class ContextRegistry(PlainRegistry): self.define(definition) def enable_contexts( - self, *names_or_contexts: Union[str, Context], **kwargs + self, *names_or_contexts: str | objects.Context, **kwargs ) -> None: """Enable contexts provided by name or by object. @@ -235,7 +235,7 @@ class ContextRegistry(PlainRegistry): ctx.checked = True # and create a new one with the new defaults. - contexts = tuple(Context.from_context(ctx, **kwargs) for ctx in ctxs) + contexts = tuple(objects.Context.from_context(ctx, **kwargs) for ctx in ctxs) # Finally we add them to the active context. self._active_ctx.insert_contexts(*contexts) @@ -253,7 +253,7 @@ class ContextRegistry(PlainRegistry): self._switch_context_cache_and_units() @contextmanager - def context(self, *names, **kwargs) -> ContextManager[Context]: + def context(self, *names, **kwargs) -> ContextManager[objects.Context]: """Used as a context manager, this function enables to activate a context which is removed after usage. diff --git a/pint/facets/dask/__init__.py b/pint/facets/dask/__init__.py index 42fced0..90c8972 100644 --- a/pint/facets/dask/__init__.py +++ b/pint/facets/dask/__init__.py @@ -14,7 +14,7 @@ from __future__ import annotations import functools from ...compat import compute, dask_array, persist, visualize -from ..plain import PlainRegistry +from ..plain import PlainRegistry, PlainQuantity def check_dask_array(f): @@ -31,13 +31,13 @@ def check_dask_array(f): return wrapper -class DaskQuantity: +class DaskQuantity(PlainQuantity): # Dask.array.Array ducking def __dask_graph__(self): if isinstance(self._magnitude, dask_array.Array): return self._magnitude.__dask_graph__() - else: - return None + + return None def __dask_keys__(self): return self._magnitude.__dask_keys__() @@ -120,4 +120,4 @@ class DaskQuantity: class DaskRegistry(PlainRegistry): - _quantity_class = DaskQuantity + Quantity = DaskQuantity diff --git a/pint/facets/formatting/objects.py b/pint/facets/formatting/objects.py index 1ba92c9..5df937c 100644 --- a/pint/facets/formatting/objects.py +++ b/pint/facets/formatting/objects.py @@ -23,8 +23,10 @@ from ...formatting import ( ) from ...util import UnitsContainer, iterable +from ..plain import PlainQuantity, PlainUnit -class FormattingQuantity: + +class FormattingQuantity(PlainQuantity): _exp_pattern = re.compile(r"([0-9]\.?[0-9]*)e(-?)\+?0*([0-9]+)") def __format__(self, spec: str) -> str: @@ -80,7 +82,7 @@ class FormattingQuantity: else: if isinstance(self.magnitude, ndarray): # Use custom ndarray text formatting with monospace font - formatter = "{{:{}}}".format(mspec) + formatter = f"{{:{mspec}}}" # Need to override for scalars, which are detected as iterable, # and don't respond to printoptions. if self.magnitude.ndim == 0: @@ -112,7 +114,7 @@ class FormattingQuantity: else: # Use custom ndarray text formatting--need to handle scalars differently # since they don't respond to printoptions - formatter = "{{:{}}}".format(mspec) + formatter = f"{{:{mspec}}}" if obj.magnitude.ndim == 0: mstr = formatter.format(obj.magnitude) else: @@ -154,7 +156,7 @@ class FormattingQuantity: obj = self.to_compact() else: obj = self - kwspec = dict(kwspec) + kwspec = kwspec.copy() if "length" in kwspec: kwspec["babel_length"] = kwspec.pop("length") @@ -176,7 +178,7 @@ class FormattingQuantity: return format(self) -class FormattingUnit: +class FormattingUnit(PlainUnit): def __str__(self): return format(self) @@ -188,10 +190,10 @@ class FormattingUnit: if not self._units: return "" units = UnitsContainer( - dict( - (self._REGISTRY._get_symbol(key), value) + { + self._REGISTRY._get_symbol(key): value for key, value in self._units.items() - ) + } ) uspec = uspec.replace("~", "") else: @@ -206,10 +208,10 @@ class FormattingUnit: if self.dimensionless: return "" units = UnitsContainer( - dict( - (self._REGISTRY._get_symbol(key), value) + { + self._REGISTRY._get_symbol(key): value for key, value in self._units.items() - ) + } ) spec = spec.replace("~", "") else: diff --git a/pint/facets/formatting/registry.py b/pint/facets/formatting/registry.py index bd9c74c..c4dc373 100644 --- a/pint/facets/formatting/registry.py +++ b/pint/facets/formatting/registry.py @@ -13,5 +13,5 @@ from .objects import FormattingQuantity, FormattingUnit class FormattingRegistry(PlainRegistry): - _quantity_class = FormattingQuantity - _unit_class = FormattingUnit + Quantity = FormattingQuantity + Unit = FormattingUnit diff --git a/pint/facets/group/definitions.py b/pint/facets/group/definitions.py index c0abced..554a63b 100644 --- a/pint/facets/group/definitions.py +++ b/pint/facets/group/definitions.py @@ -8,9 +8,10 @@ from __future__ import annotations -import typing as ty +from collections.abc import Iterable from dataclasses import dataclass +from ..._typing import Self from ... import errors from .. import plain @@ -22,12 +23,14 @@ class GroupDefinition(errors.WithDefErr): #: name of the group name: str #: unit groups that will be included within the group - using_group_names: ty.Tuple[str, ...] + using_group_names: tuple[str] #: definitions for the units existing within the group - definitions: ty.Tuple[plain.UnitDefinition, ...] + definitions: tuple[plain.UnitDefinition] @classmethod - def from_lines(cls, lines, non_int_type): + def from_lines( + cls: type[Self], lines: Iterable[str], non_int_type: type + ) -> Self | None: # TODO: this is to keep it backwards compatible from ...delegates import ParserConfig, txt_defparser @@ -39,10 +42,10 @@ class GroupDefinition(errors.WithDefErr): return definition @property - def unit_names(self) -> ty.Tuple[str, ...]: + def unit_names(self) -> tuple[str]: return tuple(el.name for el in self.definitions) - def __post_init__(self): + def __post_init__(self) -> None: if not errors.is_valid_group_name(self.name): raise self.def_err(errors.MSG_INVALID_GROUP_NAME) diff --git a/pint/facets/group/objects.py b/pint/facets/group/objects.py index 558a107..200a323 100644 --- a/pint/facets/group/objects.py +++ b/pint/facets/group/objects.py @@ -8,6 +8,7 @@ from __future__ import annotations +from collections.abc import Generator, Iterable from ...util import SharedRegistryObject, getattr_maybe_raise from .definitions import GroupDefinition @@ -23,32 +24,26 @@ class Group(SharedRegistryObject): The group belongs to one Registry. See GroupDefinition for the definition file syntax. - """ - def __init__(self, name): - """ - :param name: Name of the group. If not given, a root Group will be created. - :type name: str - :param groups: dictionary like object groups and system. - The newly created group will be added after creation. - :type groups: dict[str | Group] - """ + Parameters + ---------- + name + If not given, a root Group will be created. + """ + def __init__(self, name: str): # The name of the group. - #: type: str self.name = name #: Names of the units in this group. #: :type: set[str] - self._unit_names = set() + self._unit_names: set[str] = set() #: Names of the groups in this group. - #: :type: set[str] - self._used_groups = set() + self._used_groups: set[str] = set() #: Names of the groups in which this group is contained. - #: :type: set[str] - self._used_by = set() + self._used_by: set[str] = set() # Add this group to the group dictionary self._REGISTRY._groups[self.name] = self @@ -59,8 +54,7 @@ class Group(SharedRegistryObject): #: A cache of the included units. #: None indicates that the cache has been invalidated. - #: :type: frozenset[str] | None - self._computed_members = None + self._computed_members: frozenset[str] | None = None @property def members(self): @@ -70,23 +64,23 @@ class Group(SharedRegistryObject): """ if self._computed_members is None: - self._computed_members = set(self._unit_names) + tmp = set(self._unit_names) for _, group in self.iter_used_groups(): - self._computed_members |= group.members + tmp |= group.members - self._computed_members = frozenset(self._computed_members) + self._computed_members = frozenset(tmp) return self._computed_members - def invalidate_members(self): + def invalidate_members(self) -> None: """Invalidate computed members in this Group and all parent nodes.""" self._computed_members = None d = self._REGISTRY._groups for name in self._used_by: d[name].invalidate_members() - def iter_used_groups(self): + def iter_used_groups(self) -> Generator[tuple[str, Group], None, None]: pending = set(self._used_groups) d = self._REGISTRY._groups while pending: @@ -95,13 +89,13 @@ class Group(SharedRegistryObject): pending |= group._used_groups yield name, d[name] - def is_used_group(self, group_name): + def is_used_group(self, group_name: str) -> bool: for name, _ in self.iter_used_groups(): if name == group_name: return True return False - def add_units(self, *unit_names): + def add_units(self, *unit_names: str) -> None: """Add units to group.""" for unit_name in unit_names: self._unit_names.add(unit_name) @@ -109,17 +103,17 @@ class Group(SharedRegistryObject): self.invalidate_members() @property - def non_inherited_unit_names(self): + def non_inherited_unit_names(self) -> frozenset[str]: return frozenset(self._unit_names) - def remove_units(self, *unit_names): + def remove_units(self, *unit_names: str) -> None: """Remove units from group.""" for unit_name in unit_names: self._unit_names.remove(unit_name) self.invalidate_members() - def add_groups(self, *group_names): + def add_groups(self, *group_names: str) -> None: """Add groups to group.""" d = self._REGISTRY._groups for group_name in group_names: @@ -136,7 +130,7 @@ class Group(SharedRegistryObject): self.invalidate_members() - def remove_groups(self, *group_names): + def remove_groups(self, *group_names: str) -> None: """Remove groups from group.""" d = self._REGISTRY._groups for group_name in group_names: @@ -148,7 +142,9 @@ class Group(SharedRegistryObject): self.invalidate_members() @classmethod - def from_lines(cls, lines, define_func, non_int_type=float) -> Group: + def from_lines( + cls, lines: Iterable[str], define_func, non_int_type: type = float + ) -> Group: """Return a Group object parsing an iterable of lines. Parameters @@ -190,6 +186,6 @@ class Group(SharedRegistryObject): return grp - def __getattr__(self, item): + def __getattr__(self, item: str): getattr_maybe_raise(self, item) return self._REGISTRY diff --git a/pint/facets/group/registry.py b/pint/facets/group/registry.py index 7269082..0d35ae0 100644 --- a/pint/facets/group/registry.py +++ b/pint/facets/group/registry.py @@ -8,17 +8,17 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Dict, FrozenSet +from typing import TYPE_CHECKING from ... import errors if TYPE_CHECKING: from ..._typing import Unit -from ...util import build_dependent_class, create_class_with_registry +from ...util import create_class_with_registry from ..plain import PlainRegistry, UnitDefinition from .definitions import GroupDefinition -from .objects import Group +from . import objects class GroupRegistry(PlainRegistry): @@ -34,19 +34,15 @@ class GroupRegistry(PlainRegistry): # TODO: Change this to Group: Group to specify class # and use introspection to get system class as a way # to enjoy typing goodies - _group_class = Group + Group = objects.Group def __init__(self, **kwargs): super().__init__(**kwargs) #: Map group name to group. #: :type: dict[ str | Group] - self._groups: Dict[str, Group] = {} + self._groups: dict[str, objects.Group] = {} self._groups["root"] = self.Group("root") - def __init_subclass__(cls, **kwargs): - super().__init_subclass__() - cls.Group = build_dependent_class(cls, "Group", "_group_class") - def _init_dynamic_classes(self) -> None: """Generate subclasses on the fly and attach them to self""" super()._init_dynamic_classes() @@ -93,7 +89,7 @@ class GroupRegistry(PlainRegistry): except KeyError as e: raise errors.DefinitionSyntaxError(f"unknown dimension {e} in context") - def get_group(self, name: str, create_if_needed: bool = True) -> Group: + def get_group(self, name: str, create_if_needed: bool = True) -> objects.Group: """Return a Group. Parameters @@ -117,7 +113,7 @@ class GroupRegistry(PlainRegistry): return self.Group(name) - def _get_compatible_units(self, input_units, group) -> FrozenSet["Unit"]: + def _get_compatible_units(self, input_units, group) -> frozenset[Unit]: ret = super()._get_compatible_units(input_units, group) if not group: diff --git a/pint/facets/measurement/objects.py b/pint/facets/measurement/objects.py index 0fed93f..5f3ba7a 100644 --- a/pint/facets/measurement/objects.py +++ b/pint/facets/measurement/objects.py @@ -18,12 +18,12 @@ from ..plain import PlainQuantity MISSING = object() -class MeasurementQuantity: +class MeasurementQuantity(PlainQuantity): # Measurement support def plus_minus(self, error, relative=False): if isinstance(error, self.__class__): if relative: - raise ValueError("{} is not a valid relative error.".format(error)) + raise ValueError(f"{error} is not a valid relative error.") error = error.to(self._units).magnitude else: if relative: @@ -98,7 +98,7 @@ class Measurement(PlainQuantity): ) def __str__(self): - return "{}".format(self) + return f"{self}" def __format__(self, spec): spec = spec or self.default_format @@ -133,7 +133,7 @@ class Measurement(PlainQuantity): # scientific notation ('e' or 'E' and sometimes 'g' or 'G'). mstr = mstr.replace("(", "").replace(")", " ") ustr = siunitx_format_unit(self.units._units, self._REGISTRY) - return r"\SI%s{%s}{%s}" % (opts, mstr, ustr) + return rf"\SI{opts}{{{mstr}}}{{{ustr}}}" # standard cases if "L" in spec: diff --git a/pint/facets/measurement/registry.py b/pint/facets/measurement/registry.py index e704399..0fc4391 100644 --- a/pint/facets/measurement/registry.py +++ b/pint/facets/measurement/registry.py @@ -10,21 +10,15 @@ from __future__ import annotations from ...compat import ufloat -from ...util import build_dependent_class, create_class_with_registry +from ...util import create_class_with_registry from ..plain import PlainRegistry -from .objects import Measurement, MeasurementQuantity +from .objects import MeasurementQuantity +from . import objects class MeasurementRegistry(PlainRegistry): - _quantity_class = MeasurementQuantity - _measurement_class = Measurement - - def __init_subclass__(cls, **kwargs): - super().__init_subclass__() - - cls.Measurement = build_dependent_class( - cls, "Measurement", "_measurement_class" - ) + Quantity = MeasurementQuantity + Measurement = objects.Measurement def _init_dynamic_classes(self) -> None: """Generate subclasses on the fly and attach them to self""" diff --git a/pint/facets/nonmultiplicative/definitions.py b/pint/facets/nonmultiplicative/definitions.py index dbfc0ff..f795cf0 100644 --- a/pint/facets/nonmultiplicative/definitions.py +++ b/pint/facets/nonmultiplicative/definitions.py @@ -10,6 +10,7 @@ from __future__ import annotations from dataclasses import dataclass +from ..._typing import Magnitude from ...compat import HAS_NUMPY, exp, log from ..plain import ScaleConverter @@ -24,7 +25,7 @@ class OffsetConverter(ScaleConverter): def is_multiplicative(self): return self.offset == 0 - def to_reference(self, value, inplace=False): + def to_reference(self, value: Magnitude, inplace: bool = False) -> Magnitude: if inplace: value *= self.scale value += self.offset @@ -33,7 +34,7 @@ class OffsetConverter(ScaleConverter): return value - def from_reference(self, value, inplace=False): + def from_reference(self, value: Magnitude, inplace: bool = False) -> Magnitude: if inplace: value -= self.offset value /= self.scale @@ -66,6 +67,7 @@ class LogarithmicConverter(ScaleConverter): controls if computation is done in place """ + # TODO: Can I use PintScalar here? logbase: float logfactor: float @@ -77,7 +79,7 @@ class LogarithmicConverter(ScaleConverter): def is_logarithmic(self): return True - def from_reference(self, value, inplace=False): + def from_reference(self, value: Magnitude, inplace: bool = False) -> Magnitude: """Converts value from the reference unit to the logarithmic unit dBm <------ mW @@ -95,7 +97,7 @@ class LogarithmicConverter(ScaleConverter): return value - def to_reference(self, value, inplace=False): + def to_reference(self, value: Magnitude, inplace: bool = False) -> Magnitude: """Converts value to the reference unit from the logarithmic unit dBm ------> mW diff --git a/pint/facets/nonmultiplicative/objects.py b/pint/facets/nonmultiplicative/objects.py index 1708e32..0ab743e 100644 --- a/pint/facets/nonmultiplicative/objects.py +++ b/pint/facets/nonmultiplicative/objects.py @@ -8,16 +8,16 @@ from __future__ import annotations -from typing import List +from ..plain import PlainQuantity -class NonMultiplicativeQuantity: +class NonMultiplicativeQuantity(PlainQuantity): @property def _is_multiplicative(self) -> bool: """Check if the PlainQuantity object has only multiplicative units.""" return not self._get_non_multiplicative_units() - def _get_non_multiplicative_units(self) -> List[str]: + def _get_non_multiplicative_units(self) -> list[str]: """Return a list of the of non-multiplicative units of the PlainQuantity object.""" return [ unit @@ -25,7 +25,7 @@ class NonMultiplicativeQuantity: if not self._get_unit_definition(unit).is_multiplicative ] - def _get_delta_units(self) -> List[str]: + def _get_delta_units(self) -> list[str]: """Return list of delta units ot the PlainQuantity object.""" return [u for u in self._units if u.startswith("delta_")] @@ -40,7 +40,7 @@ class NonMultiplicativeQuantity: self._get_unit_definition(d).reference == offset_unit_dim for d in deltas ) - def _ok_for_muldiv(self, no_offset_units=None) -> bool: + def _ok_for_muldiv(self, no_offset_units: int | None = None) -> bool: """Checks if PlainQuantity object can be multiplied or divided""" is_ok = True diff --git a/pint/facets/nonmultiplicative/registry.py b/pint/facets/nonmultiplicative/registry.py index 17b053e..8bc04db 100644 --- a/pint/facets/nonmultiplicative/registry.py +++ b/pint/facets/nonmultiplicative/registry.py @@ -8,7 +8,7 @@ from __future__ import annotations -from typing import Any, Optional +from typing import Any from ...errors import DimensionalityError, UndefinedUnitError from ...util import UnitsContainer, logger @@ -35,7 +35,7 @@ class NonMultiplicativeRegistry(PlainRegistry): """ - _quantity_class = NonMultiplicativeQuantity + Quantity = NonMultiplicativeQuantity def __init__( self, @@ -56,8 +56,8 @@ class NonMultiplicativeRegistry(PlainRegistry): def _parse_units( self, input_string: str, - as_delta: Optional[bool] = None, - case_sensitive: Optional[bool] = None, + as_delta: bool | None = None, + case_sensitive: bool | None = None, ): """ """ if as_delta is None: diff --git a/pint/facets/numpy/numpy_func.py b/pint/facets/numpy/numpy_func.py index f25f4a4..e7a9b67 100644 --- a/pint/facets/numpy/numpy_func.py +++ b/pint/facets/numpy/numpy_func.py @@ -220,7 +220,7 @@ def get_op_output_unit(unit_op, first_input_units, all_args=None, size=None): product /= x.units result_unit = product**-1 else: - raise ValueError("Output unit method {} not understood".format(unit_op)) + raise ValueError(f"Output unit method {unit_op} not understood") return result_unit @@ -237,7 +237,7 @@ def implements(numpy_func_string, func_type): elif func_type == "ufunc": HANDLED_UFUNCS[numpy_func_string] = func else: - raise ValueError("Invalid func_type {}".format(func_type)) + raise ValueError(f"Invalid func_type {func_type}") return func return decorator @@ -311,7 +311,7 @@ def implement_func(func_type, func_str, input_units=None, output_unit=None): return result_magnitude elif output_unit == "match_input": result_unit = first_input_units - elif output_unit in [ + elif output_unit in ( "sum", "mul", "delta", @@ -324,7 +324,7 @@ def implement_func(func_type, func_str, input_units=None, output_unit=None): "cbrt", "reciprocal", "size", - ]: + ): result_unit = get_op_output_unit( output_unit, first_input_units, tuple(chain(args, kwargs.values())) ) @@ -499,8 +499,8 @@ def _frexp(x, *args, **kwargs): def _power(x1, x2): if _is_quantity(x1): return x1**x2 - else: - return x2.__rpow__(x1) + + return x2.__rpow__(x1) @implements("add", "ufunc") @@ -535,8 +535,8 @@ def _full_like(a, fill_value, **kwargs): np.ones_like(a, **kwargs) * fill_value.m, fill_value.units, ) - else: - return np.ones_like(a, **kwargs) * fill_value + + return np.ones_like(a, **kwargs) * fill_value @implements("interp", "function") @@ -671,8 +671,8 @@ def _any(a, *args, **kwargs): # Only valid when multiplicative unit/no offset if a._is_multiplicative: return np.any(a._magnitude, *args, **kwargs) - else: - raise ValueError("Boolean value of Quantity with offset unit is ambiguous.") + + raise ValueError("Boolean value of Quantity with offset unit is ambiguous.") @implements("all", "function") @@ -725,7 +725,7 @@ def implement_prod_func(name): return registry.Quantity(result, units) -for name in ["prod", "nanprod"]: +for name in ("prod", "nanprod"): implement_prod_func(name) @@ -780,7 +780,7 @@ def implement_mul_func(func): return a.units._REGISTRY.Quantity(mag, units) -for func_str in ["cross", "dot"]: +for func_str in ("cross", "dot"): implement_mul_func(func_str) @@ -830,11 +830,11 @@ def implement_consistent_units_by_argument(func_str, unit_arguments, wrap_output # Conditionally wrap output if wrap_output: return output_wrap(ret) - else: - return ret + + return ret -for func_str, unit_arguments, wrap_output in [ +for func_str, unit_arguments, wrap_output in ( ("expand_dims", "a", True), ("squeeze", "a", True), ("rollaxis", "a", True), @@ -884,7 +884,7 @@ for func_str, unit_arguments, wrap_output in [ ("reshape", "a", True), ("allclose", ["a", "b", "atol"], False), ("intersect1d", ["ar1", "ar2"], True), -]: +): implement_consistent_units_by_argument(func_str, unit_arguments, wrap_output) @@ -914,7 +914,7 @@ def implement_atleast_nd(func_str): return output_unit._REGISTRY.Quantity(arrays_magnitude, output_unit) -for func_str in ["atleast_1d", "atleast_2d", "atleast_3d"]: +for func_str in ("atleast_1d", "atleast_2d", "atleast_3d"): implement_atleast_nd(func_str) @@ -935,24 +935,24 @@ def implement_single_dimensionless_argument_func(func_str): return a._REGISTRY.Quantity(func(a_stripped, *args, **kwargs)) -for func_str in ["cumprod", "cumproduct", "nancumprod"]: +for func_str in ("cumprod", "cumproduct", "nancumprod"): implement_single_dimensionless_argument_func(func_str) # Handle single-argument consistent unit functions -for func_str in [ +for func_str in ( "block", "hstack", "vstack", "dstack", "column_stack", "broadcast_arrays", -]: +): implement_func( "function", func_str, input_units="all_consistent", output_unit="match_input" ) # Handle functions that ignore units on input and output -for func_str in [ +for func_str in ( "size", "isreal", "iscomplex", @@ -969,19 +969,19 @@ for func_str in [ "count_nonzero", "nonzero", "result_type", -]: +): implement_func("function", func_str, input_units=None, output_unit=None) # Handle functions with output unit defined by operation -for func_str in ["std", "nanstd", "sum", "nansum", "cumsum", "nancumsum"]: +for func_str in ("std", "nanstd", "sum", "nansum", "cumsum", "nancumsum"): implement_func("function", func_str, input_units=None, output_unit="sum") -for func_str in ["diff", "ediff1d"]: +for func_str in ("diff", "ediff1d"): implement_func("function", func_str, input_units=None, output_unit="delta") -for func_str in ["gradient"]: +for func_str in ("gradient",): implement_func("function", func_str, input_units=None, output_unit="delta,div") -for func_str in ["linalg.solve"]: +for func_str in ("linalg.solve",): implement_func("function", func_str, input_units=None, output_unit="invdiv") -for func_str in ["var", "nanvar"]: +for func_str in ("var", "nanvar"): implement_func("function", func_str, input_units=None, output_unit="variance") @@ -997,7 +997,7 @@ def numpy_wrap(func_type, func, args, kwargs, types): # ufuncs do not have func.__module__ name = func.__name__ else: - raise ValueError("Invalid func_type {}".format(func_type)) + raise ValueError(f"Invalid func_type {func_type}") if name not in handled or any(is_upcast_type(t) for t in types): return NotImplemented diff --git a/pint/facets/numpy/quantity.py b/pint/facets/numpy/quantity.py index 9aa55ce..131983c 100644 --- a/pint/facets/numpy/quantity.py +++ b/pint/facets/numpy/quantity.py @@ -13,6 +13,8 @@ import math import warnings from typing import Any +from ..plain import PlainQuantity + from ..._typing import Shape, _MagnitudeType from ...compat import _to_magnitude, np from ...errors import DimensionalityError, PintTypeError, UnitStrippedWarning @@ -40,7 +42,7 @@ def method_wraps(numpy_func): return wrapper -class NumpyQuantity: +class NumpyQuantity(PlainQuantity): """ """ # NumPy function/ufunc support @@ -52,11 +54,11 @@ class NumpyQuantity: return NotImplemented # Replicate types from __array_function__ - types = set( + types = { type(arg) for arg in list(inputs) + list(kwargs.values()) if hasattr(arg, "__array_ufunc__") - ) + } return numpy_wrap("ufunc", ufunc, inputs, kwargs, types) @@ -99,8 +101,8 @@ class NumpyQuantity: if output_unit is not None: return self.__class__(value, output_unit) - else: - return value + + return value def __array__(self, t=None) -> np.ndarray: warnings.warn( diff --git a/pint/facets/numpy/registry.py b/pint/facets/numpy/registry.py index fa4768f..11d57f3 100644 --- a/pint/facets/numpy/registry.py +++ b/pint/facets/numpy/registry.py @@ -15,5 +15,5 @@ from .unit import NumpyUnit class NumpyRegistry(PlainRegistry): - _quantity_class = NumpyQuantity - _unit_class = NumpyUnit + Quantity = NumpyQuantity + Unit = NumpyUnit diff --git a/pint/facets/numpy/unit.py b/pint/facets/numpy/unit.py index 0b5007f..d6bf140 100644 --- a/pint/facets/numpy/unit.py +++ b/pint/facets/numpy/unit.py @@ -9,9 +9,10 @@ from __future__ import annotations from ...compat import is_upcast_type +from ..plain import PlainUnit -class NumpyUnit: +class NumpyUnit(PlainUnit): __array_priority__ = 17 def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): @@ -20,11 +21,11 @@ class NumpyUnit: return NotImplemented # Check types and return NotImplemented when upcast type encountered - types = set( + types = { type(arg) for arg in list(inputs) + list(kwargs.values()) if hasattr(arg, "__array_ufunc__") - ) + } if any(is_upcast_type(other) for other in types): return NotImplemented @@ -38,5 +39,5 @@ class NumpyUnit: ), **kwargs, ) - else: - return NotImplemented + + return NotImplemented diff --git a/pint/facets/plain/definitions.py b/pint/facets/plain/definitions.py index 11a3095..79a44f1 100644 --- a/pint/facets/plain/definitions.py +++ b/pint/facets/plain/definitions.py @@ -13,8 +13,9 @@ import numbers import typing as ty from dataclasses import dataclass from functools import cached_property -from typing import Callable, Optional +from typing import Callable, Any +from ..._typing import Magnitude from ... import errors from ...converters import Converter from ...util import UnitsContainer @@ -23,7 +24,7 @@ from ...util import UnitsContainer class NotNumeric(Exception): """Internal exception. Do not expose outside Pint""" - def __init__(self, value): + def __init__(self, value: Any): self.value = value @@ -76,7 +77,7 @@ class PrefixDefinition(errors.WithDefErr): #: scaling value for this prefix value: numbers.Number #: canonical symbol - defined_symbol: Optional[str] = "" + defined_symbol: str | None = "" #: additional names for the same prefix aliases: ty.Tuple[str, ...] = () @@ -115,18 +116,26 @@ class UnitDefinition(errors.WithDefErr): #: canonical name of the unit name: str #: canonical symbol - defined_symbol: ty.Optional[str] + defined_symbol: str | None #: additional names for the same unit - aliases: ty.Tuple[str, ...] + aliases: tuple[str] #: A functiont that converts a value in these units into the reference units - converter: ty.Optional[ty.Union[Callable, Converter]] + converter: Callable[ + [ + Magnitude, + ], + Magnitude, + ] | Converter | None #: Reference units. - reference: ty.Optional[UnitsContainer] + reference: UnitsContainer | None def __post_init__(self): if not errors.is_valid_unit_name(self.name): raise self.def_err(errors.MSG_INVALID_UNIT_NAME) + # TODO: check why reference: UnitsContainer | None + assert isinstance(self.reference, UnitsContainer) + if not any(map(errors.is_dim, self.reference.keys())): invalid = tuple( itertools.filterfalse(errors.is_valid_unit_name, self.reference.keys()) @@ -180,14 +189,20 @@ class UnitDefinition(errors.WithDefErr): @property def is_base(self) -> bool: """Indicates if it is a base unit.""" + + # TODO: why is this here return self._is_base @property def is_multiplicative(self) -> bool: + # TODO: Check how to avoid this check + assert isinstance(self.converter, Converter) return self.converter.is_multiplicative @property def is_logarithmic(self) -> bool: + # TODO: Check how to avoid this check + assert isinstance(self.converter, Converter) return self.converter.is_logarithmic @property @@ -272,7 +287,7 @@ class ScaleConverter(Converter): scale: float - def to_reference(self, value, inplace=False): + def to_reference(self, value: Magnitude, inplace: bool = False) -> Magnitude: if inplace: value *= self.scale else: @@ -280,7 +295,7 @@ class ScaleConverter(Converter): return value - def from_reference(self, value, inplace=False): + def from_reference(self, value: Magnitude, inplace: bool = False) -> Magnitude: if inplace: value /= self.scale else: diff --git a/pint/facets/plain/objects.py b/pint/facets/plain/objects.py index 5b2837b..a868c7f 100644 --- a/pint/facets/plain/objects.py +++ b/pint/facets/plain/objects.py @@ -11,4 +11,4 @@ from __future__ import annotations from .quantity import PlainQuantity from .unit import PlainUnit, UnitsContainer -__all__ = [PlainUnit, PlainQuantity, UnitsContainer] +__all__ = ["PlainUnit", "PlainQuantity", "UnitsContainer"] diff --git a/pint/facets/plain/quantity.py b/pint/facets/plain/quantity.py index 359e613..1eaaa3d 100644 --- a/pint/facets/plain/quantity.py +++ b/pint/facets/plain/quantity.py @@ -20,18 +20,11 @@ from typing import ( TYPE_CHECKING, Any, Callable, - Dict, Generic, - Iterable, - Iterator, - List, - Optional, - Sequence, - Tuple, TypeVar, - Union, overload, ) +from collections.abc import Iterable, Iterator, Sequence from ..._typing import S, UnitLike, _MagnitudeType from ...compat import ( @@ -179,25 +172,25 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType] @overload def __new__( - cls, value: str, units: Optional[UnitLike] = None + cls, value: str, units: UnitLike | None = None ) -> PlainQuantity[Magnitude]: ... @overload def __new__( # type: ignore[misc] - cls, value: Sequence, units: Optional[UnitLike] = None + cls, value: Sequence, units: UnitLike | None = None ) -> PlainQuantity[np.ndarray]: ... @overload def __new__( - cls, value: PlainQuantity[Magnitude], units: Optional[UnitLike] = None + cls, value: PlainQuantity[Magnitude], units: UnitLike | None = None ) -> PlainQuantity[Magnitude]: ... @overload def __new__( - cls, value: Magnitude, units: Optional[UnitLike] = None + cls, value: Magnitude, units: UnitLike | None = None ) -> PlainQuantity[Magnitude]: ... @@ -281,15 +274,15 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType] def __repr__(self) -> str: if isinstance(self._magnitude, float): return f"<Quantity({self._magnitude:.9}, '{self._units}')>" - else: - return f"<Quantity({self._magnitude}, '{self._units}')>" + + return f"<Quantity({self._magnitude}, '{self._units}')>" def __hash__(self) -> int: self_base = self.to_base_units() if self_base.dimensionless: return hash(self_base.magnitude) - else: - return hash((self_base.__class__, self_base.magnitude, self_base.units)) + + return hash((self_base.__class__, self_base.magnitude, self_base.units)) @property def magnitude(self) -> _MagnitudeType: @@ -316,12 +309,12 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType] return self.to(units).magnitude @property - def units(self) -> "Unit": + def units(self) -> Unit: """PlainQuantity's units. Long form for `u`""" return self._REGISTRY.Unit(self._units) @property - def u(self) -> "Unit": + def u(self) -> Unit: """PlainQuantity's units. Short form for `units`""" return self._REGISTRY.Unit(self._units) @@ -337,7 +330,7 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType] return not bool(tmp.dimensionality) - _dimensionality: Optional[UnitsContainerT] = None + _dimensionality: UnitsContainerT | None = None @property def dimensionality(self) -> UnitsContainerT: @@ -358,7 +351,7 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType] @classmethod def from_list( - cls, quant_list: List[PlainQuantity], units=None + cls, quant_list: list[PlainQuantity], units=None ) -> PlainQuantity[np.ndarray]: """Transforms a list of Quantities into an numpy.array quantity. If no units are specified, the unit of the first element will be used. @@ -421,7 +414,7 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType] def from_tuple(cls, tup): return cls(tup[0], cls._REGISTRY.UnitsContainer(tup[1])) - def to_tuple(self) -> Tuple[_MagnitudeType, Tuple[Tuple[str]]]: + def to_tuple(self) -> tuple[_MagnitudeType, tuple[tuple[str]]]: return self.m, tuple(self._units.items()) def compatible_units(self, *contexts): @@ -432,7 +425,7 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType] return self._REGISTRY.get_compatible_units(self._units) def is_compatible_with( - self, other: Any, *contexts: Union[str, Context], **ctx_kwargs: Any + self, other: Any, *contexts: str | Context, **ctx_kwargs: Any ) -> bool: """check if the other object is compatible @@ -652,7 +645,7 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType] ): return self - SI_prefixes: Dict[int, str] = {} + SI_prefixes: dict[int, str] = {} for prefix in self._REGISTRY._prefixes.values(): try: scale = prefix.converter.scale @@ -702,7 +695,7 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType] return self.to(new_unit_container) def to_preferred( - self, preferred_units: List[UnitLike] + self, preferred_units: list[UnitLike] ) -> PlainQuantity[_MagnitudeType]: """Return Quantity converted to a unit composed of the preferred units. @@ -732,9 +725,9 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType] for preferred_unit in preferred_units: dims = sorted(preferred_unit.dimensionality) if dims == self_dims: - p_exps_head, *p_exps_tail = [ + p_exps_head, *p_exps_tail = ( preferred_unit.dimensionality[d] for d in dims - ] + ) if all( s_exps_tail[i] * p_exps_head == p_exps_tail[i] ** s_exps_head for i in range(n) @@ -812,15 +805,15 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType] # update preferred_units with the selected units that were originally preferred preferred_units = list( - set(u for d, u in unit_selections.items() if d in preferred_dims) + {u for d, u in unit_selections.items() if d in preferred_dims} ) - preferred_units.sort(key=lambda unit: str(unit)) # for determinism + preferred_units.sort(key=str) # for determinism # and unpreferred_units are the selected units that weren't originally preferred unpreferred_units = list( - set(u for d, u in unit_selections.items() if d not in preferred_dims) + {u for d, u in unit_selections.items() if d not in preferred_dims} ) - unpreferred_units.sort(key=lambda unit: str(unit)) # for determinism + unpreferred_units.sort(key=str) # for determinism # for indexability dimensions = list(dimension_set) @@ -918,10 +911,10 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType] result_unit = sorting_keys[min_key] return self.to(result_unit) - else: - # for whatever reason, a solution wasn't found - # return the original quantity - return self + + # for whatever reason, a solution wasn't found + # return the original quantity + return self # Mathematical operations def __int__(self) -> int: @@ -1178,22 +1171,22 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType] return self.to_timedelta() + other elif is_duck_array_type(type(self._magnitude)): return self._iadd_sub(other, operator.iadd) - else: - return self._add_sub(other, operator.add) + + return self._add_sub(other, operator.add) def __add__(self, other): if isinstance(other, datetime.datetime): return self.to_timedelta() + other - else: - return self._add_sub(other, operator.add) + + return self._add_sub(other, operator.add) __radd__ = __add__ def __isub__(self, other): if is_duck_array_type(type(self._magnitude)): return self._iadd_sub(other, operator.isub) - else: - return self._add_sub(other, operator.sub) + + return self._add_sub(other, operator.sub) def __sub__(self, other): return self._add_sub(other, operator.sub) @@ -1201,8 +1194,8 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType] def __rsub__(self, other): if isinstance(other, datetime.datetime): return other - self.to_timedelta() - else: - return -self._add_sub(other, operator.sub) + + return -self._add_sub(other, operator.sub) @check_implemented @ireduce_dimensions @@ -1235,10 +1228,10 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType] if not self._ok_for_muldiv(no_offset_units_self): raise OffsetUnitCalculusError(self._units, getattr(other, "units", "")) if len(offset_units_self) == 1: - if self._units[offset_units_self[0]] != 1 or magnitude_op not in [ + if self._units[offset_units_self[0]] != 1 or magnitude_op not in ( operator.mul, operator.imul, - ]: + ): raise OffsetUnitCalculusError( self._units, getattr(other, "units", "") ) @@ -1259,14 +1252,14 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType] if not self._ok_for_muldiv(no_offset_units_self): raise OffsetUnitCalculusError(self._units, other._units) - elif no_offset_units_self == 1 and len(self._units) == 1: + elif no_offset_units_self == len(self._units) == 1: self.ito_root_units() no_offset_units_other = len(other._get_non_multiplicative_units()) if not other._ok_for_muldiv(no_offset_units_other): raise OffsetUnitCalculusError(self._units, other._units) - elif no_offset_units_other == 1 and len(other._units) == 1: + elif no_offset_units_other == len(other._units) == 1: other.ito_root_units() self._magnitude = magnitude_op(self._magnitude, other._magnitude) @@ -1304,10 +1297,10 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType] if not self._ok_for_muldiv(no_offset_units_self): raise OffsetUnitCalculusError(self._units, getattr(other, "units", "")) if len(offset_units_self) == 1: - if self._units[offset_units_self[0]] != 1 or magnitude_op not in [ + if self._units[offset_units_self[0]] != 1 or magnitude_op not in ( operator.mul, operator.imul, - ]: + ): raise OffsetUnitCalculusError( self._units, getattr(other, "units", "") ) @@ -1332,14 +1325,14 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType] if not self._ok_for_muldiv(no_offset_units_self): raise OffsetUnitCalculusError(self._units, other._units) - elif no_offset_units_self == 1 and len(self._units) == 1: + elif no_offset_units_self == len(self._units) == 1: new_self = self.to_root_units() no_offset_units_other = len(other._get_non_multiplicative_units()) if not other._ok_for_muldiv(no_offset_units_other): raise OffsetUnitCalculusError(self._units, other._units) - elif no_offset_units_other == 1 and len(other._units) == 1: + elif no_offset_units_other == len(other._units) == 1: other = other.to_root_units() magnitude = magnitude_op(new_self._magnitude, other._magnitude) @@ -1350,8 +1343,8 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType] def __imul__(self, other): if is_duck_array_type(type(self._magnitude)): return self._imul_div(other, operator.imul) - else: - return self._mul_div(other, operator.mul) + + return self._mul_div(other, operator.mul) def __mul__(self, other): return self._mul_div(other, operator.mul) @@ -1374,8 +1367,8 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType] def __itruediv__(self, other): if is_duck_array_type(type(self._magnitude)): return self._imul_div(other, operator.itruediv) - else: - return self._mul_div(other, operator.truediv) + + return self._mul_div(other, operator.truediv) def __truediv__(self, other): if isinstance(self.m, int) or isinstance(getattr(other, "m", None), int): @@ -1395,7 +1388,7 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType] no_offset_units_self = len(self._get_non_multiplicative_units()) if not self._ok_for_muldiv(no_offset_units_self): raise OffsetUnitCalculusError(self._units, "") - elif no_offset_units_self == 1 and len(self._units) == 1: + elif no_offset_units_self == len(self._units) == 1: self = self.to_root_units() return self.__class__(other_magnitude / self._magnitude, 1 / self._units) @@ -1627,7 +1620,7 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType] def __abs__(self) -> PlainQuantity[_MagnitudeType]: return self.__class__(abs(self._magnitude), self._units) - def __round__(self, ndigits: Optional[int] = 0) -> PlainQuantity[int]: + def __round__(self, ndigits: int | None = 0) -> PlainQuantity[int]: return self.__class__(round(self._magnitude, ndigits=ndigits), self._units) def __pos__(self) -> PlainQuantity[_MagnitudeType]: @@ -1720,9 +1713,7 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType] else: raise OffsetUnitCalculusError(self._units) else: - raise ValueError( - "Cannot compare PlainQuantity and {}".format(type(other)) - ) + raise ValueError(f"Cannot compare PlainQuantity and {type(other)}") # Registry equality check based on util.SharedRegistryObject if self._REGISTRY is not other._REGISTRY: @@ -1791,11 +1782,11 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType] """Check if the PlainQuantity object has only multiplicative units.""" return True - def _get_non_multiplicative_units(self) -> List[str]: + def _get_non_multiplicative_units(self) -> list[str]: """Return a list of the of non-multiplicative units of the PlainQuantity object.""" return [] - def _get_delta_units(self) -> List[str]: + def _get_delta_units(self) -> list[str]: """Return list of delta units ot the PlainQuantity object.""" return [u for u in self._units if u.startswith("delta_")] diff --git a/pint/facets/plain/registry.py b/pint/facets/plain/registry.py index 0bf1545..d3baff4 100644 --- a/pint/facets/plain/registry.py +++ b/pint/facets/plain/registry.py @@ -24,18 +24,10 @@ from typing import ( TYPE_CHECKING, Any, Callable, - Dict, - FrozenSet, - Iterable, - Iterator, - List, - Optional, - Set, - Tuple, - Type, TypeVar, Union, ) +from collections.abc import Iterable, Iterator if TYPE_CHECKING: from ..context import Context @@ -51,7 +43,6 @@ from ...util import UnitsContainer from ...util import UnitsContainer as UnitsContainerT from ...util import ( _is_dim, - build_dependent_class, create_class_with_registry, getattr_maybe_raise, logger, @@ -83,7 +74,7 @@ T = TypeVar("T") _BLOCK_RE = re.compile(r"[ (]") -@functools.lru_cache() +@functools.lru_cache def pattern_to_regex(pattern): if hasattr(pattern, "finditer"): pattern = pattern.pattern @@ -96,7 +87,7 @@ def pattern_to_regex(pattern): return re.compile(pattern) -NON_INT_TYPE = Type[Union[float, Decimal, Fraction]] +NON_INT_TYPE = type[Union[float, Decimal, Fraction]] PreprocessorType = Callable[[str], str] @@ -105,13 +96,13 @@ class RegistryCache: def __init__(self) -> None: #: Maps dimensionality (UnitsContainer) to Units (str) - self.dimensional_equivalents: Dict[UnitsContainer, Set[str]] = {} + self.dimensional_equivalents: dict[UnitsContainer, set[str]] = {} #: Maps dimensionality (UnitsContainer) to Dimensionality (UnitsContainer) self.root_units = {} #: Maps dimensionality (UnitsContainer) to Units (UnitsContainer) - self.dimensionality: Dict[UnitsContainer, UnitsContainer] = {} + self.dimensionality: dict[UnitsContainer, UnitsContainer] = {} #: Cache the unit name associated to user input. ('mV' -> 'millivolt') - self.parse_unit: Dict[str, UnitsContainer] = {} + self.parse_unit: dict[str, UnitsContainer] = {} def __eq__(self, other): if not isinstance(other, self.__class__): @@ -181,12 +172,12 @@ class PlainRegistry(metaclass=RegistryMeta): """ #: Babel.Locale instance or None - fmt_locale: Optional[Locale] = None + fmt_locale: Locale | None = None _diskcache = None - _quantity_class = PlainQuantity - _unit_class = PlainUnit + Quantity = PlainQuantity + Unit = PlainUnit _def_parser = None @@ -197,16 +188,16 @@ class PlainRegistry(metaclass=RegistryMeta): force_ndarray_like: bool = False, on_redefinition: str = "warn", auto_reduce_dimensions: bool = False, - preprocessors: Optional[List[PreprocessorType]] = None, - fmt_locale: Optional[str] = None, + preprocessors: list[PreprocessorType] | None = None, + fmt_locale: str | None = None, non_int_type: NON_INT_TYPE = float, case_sensitive: bool = True, - cache_folder: Union[str, pathlib.Path, None] = None, - separate_format_defaults: Optional[bool] = None, + cache_folder: str | pathlib.Path | None = None, + separate_format_defaults: bool | None = None, mpl_formatter: str = "{:P}", ): #: Map a definition class to a adder methods. - self._adders = dict() + self._adders = {} self._register_definition_adders() self._init_dynamic_classes() @@ -255,44 +246,37 @@ class PlainRegistry(metaclass=RegistryMeta): #: Map between name (string) and value (string) of defaults stored in the #: definitions file. - self._defaults: Dict[str, str] = {} + self._defaults: dict[str, str] = {} #: Map dimension name (string) to its definition (DimensionDefinition). - self._dimensions: Dict[ - str, Union[DimensionDefinition, DerivedDimensionDefinition] + self._dimensions: dict[ + str, DimensionDefinition | DerivedDimensionDefinition ] = {} #: Map unit name (string) to its definition (UnitDefinition). #: Might contain prefixed units. - self._units: Dict[str, UnitDefinition] = {} + self._units: dict[str, UnitDefinition] = {} #: List base unit names - self._base_units: List[str] = [] + self._base_units: list[str] = [] #: Map unit name in lower case (string) to a set of unit names with the right #: case. #: Does not contain prefixed units. #: e.g: 'hz' - > set('Hz', ) - self._units_casei: Dict[str, Set[str]] = defaultdict(set) + self._units_casei: dict[str, set[str]] = defaultdict(set) #: Map prefix name (string) to its definition (PrefixDefinition). - self._prefixes: Dict[str, PrefixDefinition] = {"": PrefixDefinition("", 1)} + self._prefixes: dict[str, PrefixDefinition] = {"": PrefixDefinition("", 1)} #: Map suffix name (string) to canonical , and unit alias to canonical unit name - self._suffixes: Dict[str, str] = {"": "", "s": ""} + self._suffixes: dict[str, str] = {"": "", "s": ""} #: Map contexts to RegistryCache self._cache = RegistryCache() self._initialized = False - def __init_subclass__(cls, **kwargs): - super().__init_subclass__() - cls.Unit: Unit = build_dependent_class(cls, "Unit", "_unit_class") - cls.Quantity: Quantity = build_dependent_class( - cls, "Quantity", "_quantity_class" - ) - def _init_dynamic_classes(self) -> None: """Generate subclasses on the fly and attach them to self""" @@ -326,7 +310,7 @@ class PlainRegistry(metaclass=RegistryMeta): self._register_adder(DimensionDefinition, self._add_dimension) self._register_adder(DerivedDimensionDefinition, self._add_derived_dimension) - def __deepcopy__(self, memo) -> "PlainRegistry": + def __deepcopy__(self, memo) -> PlainRegistry: new = object.__new__(type(self)) new.__dict__ = copy.deepcopy(self.__dict__, memo) new._init_dynamic_classes() @@ -351,7 +335,7 @@ class PlainRegistry(metaclass=RegistryMeta): except UndefinedUnitError: return False - def __dir__(self) -> List[str]: + def __dir__(self) -> list[str]: #: Calling dir(registry) gives all units, methods, and attributes. #: Also used for autocompletion in IPython. return list(self._units.keys()) + list(object.__dir__(self)) @@ -365,7 +349,7 @@ class PlainRegistry(metaclass=RegistryMeta): """ return iter(sorted(self._units.keys())) - def set_fmt_locale(self, loc: Optional[str]) -> None: + def set_fmt_locale(self, loc: str | None) -> None: """Change the locale used by default by `format_babel`. Parameters @@ -397,7 +381,7 @@ class PlainRegistry(metaclass=RegistryMeta): self.Measurement.default_format = value @property - def cache_folder(self) -> Optional[pathlib.Path]: + def cache_folder(self) -> pathlib.Path | None: if self._diskcache: return self._diskcache.cache_folder return None @@ -472,7 +456,7 @@ class PlainRegistry(metaclass=RegistryMeta): if self._on_redefinition == "raise": raise RedefinitionError(key, type(value)) elif self._on_redefinition == "warn": - logger.warning("Redefining '%s' (%s)" % (key, type(value))) + logger.warning(f"Redefining '{key}' ({type(value)})") target_dict[key] = value if casei_target_dict is not None: @@ -581,9 +565,7 @@ class PlainRegistry(metaclass=RegistryMeta): logger.warning(f"Could not resolve {unit_name}: {exc!r}") return self._cache - def get_name( - self, name_or_alias: str, case_sensitive: Optional[bool] = None - ) -> str: + def get_name(self, name_or_alias: str, case_sensitive: bool | None = None) -> str: """Return the canonical name of a unit.""" if name_or_alias == "dimensionless": @@ -621,9 +603,7 @@ class PlainRegistry(metaclass=RegistryMeta): return unit_name - def get_symbol( - self, name_or_alias: str, case_sensitive: Optional[bool] = None - ) -> str: + def get_symbol(self, name_or_alias: str, case_sensitive: bool | None = None) -> str: """Return the preferred alias for a unit.""" candidates = self.parse_unit_name(name_or_alias, case_sensitive) if not candidates: @@ -632,8 +612,8 @@ class PlainRegistry(metaclass=RegistryMeta): prefix, unit_name, _ = candidates[0] else: logger.warning( - "Parsing {0} yield multiple results. " - "Options are: {1!r}".format(name_or_alias, candidates) + "Parsing {} yield multiple results. " + "Options are: {!r}".format(name_or_alias, candidates) ) prefix, unit_name, _ = candidates[0] @@ -654,7 +634,7 @@ class PlainRegistry(metaclass=RegistryMeta): return self._get_dimensionality(input_units) def _get_dimensionality( - self, input_units: Optional[UnitsContainerT] + self, input_units: UnitsContainerT | None ) -> UnitsContainerT: """Convert a UnitsContainer to plain dimensions.""" if not input_units: @@ -727,7 +707,7 @@ class PlainRegistry(metaclass=RegistryMeta): def get_root_units( self, input_units: UnitLike, check_nonmult: bool = True - ) -> Tuple[Number, PlainUnit]: + ) -> tuple[Number, PlainUnit]: """Convert unit or dict of units to the root units. If any unit is non multiplicative and check_converter is True, @@ -840,7 +820,7 @@ class PlainRegistry(metaclass=RegistryMeta): def get_compatible_units( self, input_units, group_or_system=None - ) -> FrozenSet[Unit]: + ) -> frozenset[Unit]: """ """ input_units = to_units_container(input_units) @@ -858,7 +838,7 @@ class PlainRegistry(metaclass=RegistryMeta): # TODO: remove context from here def is_compatible_with( - self, obj1: Any, obj2: Any, *contexts: Union[str, Context], **ctx_kwargs + self, obj1: Any, obj2: Any, *contexts: str | Context, **ctx_kwargs ) -> bool: """check if the other object is compatible @@ -972,8 +952,8 @@ class PlainRegistry(metaclass=RegistryMeta): return value def parse_unit_name( - self, unit_name: str, case_sensitive: Optional[bool] = None - ) -> Tuple[Tuple[str, str, str], ...]: + self, unit_name: str, case_sensitive: bool | None = None + ) -> tuple[tuple[str, str, str], ...]: """Parse a unit to identify prefix, unit name and suffix by walking the list of prefix and suffix. In case of equivalent combinations (e.g. ('kilo', 'gram', '') and @@ -997,8 +977,8 @@ class PlainRegistry(metaclass=RegistryMeta): ) def _parse_unit_name( - self, unit_name: str, case_sensitive: Optional[bool] = None - ) -> Iterator[Tuple[str, str, str]]: + self, unit_name: str, case_sensitive: bool | None = None + ) -> Iterator[tuple[str, str, str]]: """Helper of parse_unit_name.""" case_sensitive = ( self.case_sensitive if case_sensitive is None else case_sensitive @@ -1029,8 +1009,8 @@ class PlainRegistry(metaclass=RegistryMeta): @staticmethod def _dedup_candidates( - candidates: Iterable[Tuple[str, str, str]] - ) -> Tuple[Tuple[str, str, str], ...]: + candidates: Iterable[tuple[str, str, str]] + ) -> tuple[tuple[str, str, str], ...]: """Helper of parse_unit_name. Given an iterable of unit triplets (prefix, name, suffix), remove those with @@ -1051,8 +1031,8 @@ class PlainRegistry(metaclass=RegistryMeta): def parse_units( self, input_string: str, - as_delta: Optional[bool] = None, - case_sensitive: Optional[bool] = None, + as_delta: bool | None = None, + case_sensitive: bool | None = None, ) -> Unit: """Parse a units expression and returns a UnitContainer with the canonical names. @@ -1083,7 +1063,7 @@ class PlainRegistry(metaclass=RegistryMeta): self, input_string: str, as_delta: bool = True, - case_sensitive: Optional[bool] = None, + case_sensitive: bool | None = None, ) -> UnitsContainerT: """Parse a units expression and returns a UnitContainer with the canonical names. @@ -1124,15 +1104,7 @@ class PlainRegistry(metaclass=RegistryMeta): return ret - def _eval_token(self, token, case_sensitive=None, use_decimal=False, **values): - # TODO: remove this code when use_decimal is deprecated - if use_decimal: - raise DeprecationWarning( - "`use_decimal` is deprecated, use `non_int_type` keyword argument when instantiating the registry.\n" - ">>> from decimal import Decimal\n" - ">>> ureg = UnitRegistry(non_int_type=Decimal)" - ) - + def _eval_token(self, token, case_sensitive=None, **values): token_type = token[0] token_text = token[1] if token_type == NAME: @@ -1160,10 +1132,9 @@ class PlainRegistry(metaclass=RegistryMeta): self, input_string: str, pattern: str, - case_sensitive: Optional[bool] = None, - use_decimal: bool = False, + case_sensitive: bool | None = None, many: bool = False, - ) -> Union[List[str], str, None]: + ) -> list[str] | str | None: """Parse a string with a given regex pattern and returns result. Parameters @@ -1174,8 +1145,6 @@ class PlainRegistry(metaclass=RegistryMeta): The regex parse string case_sensitive : (Default value = None, which uses registry setting) - use_decimal : - (Default value = False) many : Match many results (Default value = False) @@ -1200,13 +1169,10 @@ class PlainRegistry(metaclass=RegistryMeta): match = match.groupdict() # Parse units - units = [] - for unit, value in match.items(): - # Construct measure by multiplying value by unit - units.append( - float(value) - * self.parse_expression(unit, case_sensitive, use_decimal) - ) + units = [ + float(value) * self.parse_expression(unit, case_sensitive) + for unit, value in match.items() + ] # Add to results results.append(units) @@ -1220,8 +1186,7 @@ class PlainRegistry(metaclass=RegistryMeta): def parse_expression( self, input_string: str, - case_sensitive: Optional[bool] = None, - use_decimal: bool = False, + case_sensitive: bool | None = None, **values, ) -> Quantity: """Parse a mathematical expression including units and return a quantity object. @@ -1235,8 +1200,6 @@ class PlainRegistry(metaclass=RegistryMeta): case_sensitive : (Default value = None, which uses registry setting) - use_decimal : - (Default value = False) **values : @@ -1244,15 +1207,6 @@ class PlainRegistry(metaclass=RegistryMeta): ------- """ - - # TODO: remove this code when use_decimal is deprecated - if use_decimal: - raise DeprecationWarning( - "`use_decimal` is deprecated, use `non_int_type` keyword argument when instantiating the registry.\n" - ">>> from decimal import Decimal\n" - ">>> ureg = UnitRegistry(non_int_type=Decimal)" - ) - if not input_string: return self.Quantity(1) diff --git a/pint/facets/plain/unit.py b/pint/facets/plain/unit.py index b608c05..64a7d3c 100644 --- a/pint/facets/plain/unit.py +++ b/pint/facets/plain/unit.py @@ -12,7 +12,7 @@ import copy import locale import operator from numbers import Number -from typing import TYPE_CHECKING, Any, Union +from typing import TYPE_CHECKING, Any from ..._typing import UnitLike from ...compat import NUMERIC_TYPES @@ -65,7 +65,7 @@ class PlainUnit(PrettyIPython, SharedRegistryObject): return str(self).encode(locale.getpreferredencoding()) def __repr__(self) -> str: - return "<Unit('{}')>".format(self._units) + return f"<Unit('{self._units}')>" @property def dimensionless(self) -> bool: @@ -96,7 +96,7 @@ class PlainUnit(PrettyIPython, SharedRegistryObject): return self._REGISTRY.get_compatible_units(self) def is_compatible_with( - self, other: Any, *contexts: Union[str, Context], **ctx_kwargs: Any + self, other: Any, *contexts: str | Context, **ctx_kwargs: Any ) -> bool: """check if the other object is compatible @@ -165,18 +165,18 @@ class PlainUnit(PrettyIPython, SharedRegistryObject): return self._REGISTRY.Quantity(other, 1 / self._units) elif isinstance(other, UnitsContainer): return self.__class__(other / self._units) - else: - return NotImplemented + + return NotImplemented __div__ = __truediv__ __rdiv__ = __rtruediv__ - def __pow__(self, other) -> "PlainUnit": + def __pow__(self, other) -> PlainUnit: if isinstance(other, NUMERIC_TYPES): return self.__class__(self._units**other) else: - mess = "Cannot power PlainUnit by {}".format(type(other)) + mess = f"Cannot power PlainUnit by {type(other)}" raise TypeError(mess) def __hash__(self) -> int: @@ -207,8 +207,8 @@ class PlainUnit(PrettyIPython, SharedRegistryObject): return self_q.compare(other, op) elif isinstance(other, (PlainUnit, UnitsContainer, dict)): return self_q.compare(self._REGISTRY.Quantity(1, other), op) - else: - return NotImplemented + + return NotImplemented __lt__ = lambda self, other: self.compare(other, op=operator.lt) __le__ = lambda self, other: self.compare(other, op=operator.le) diff --git a/pint/facets/system/definitions.py b/pint/facets/system/definitions.py index 8243324..1ce8269 100644 --- a/pint/facets/system/definitions.py +++ b/pint/facets/system/definitions.py @@ -8,9 +8,10 @@ from __future__ import annotations -import typing as ty +from collections.abc import Iterable from dataclasses import dataclass +from ..._typing import Self from ... import errors @@ -23,7 +24,7 @@ class BaseUnitRule: new_unit_name: str #: name of the unit to be kicked out to make room for the new base uni #: If None, the current base unit with the same dimensionality will be used - old_unit_name: ty.Optional[str] = None + old_unit_name: str | None = None # Instead of defining __post_init__ here, # it will be added to the container class @@ -38,13 +39,16 @@ class SystemDefinition(errors.WithDefErr): #: name of the system name: str #: unit groups that will be included within the system - using_group_names: ty.Tuple[str, ...] + using_group_names: tuple[str] #: rules to define new base unit within the system. - rules: ty.Tuple[BaseUnitRule, ...] + rules: tuple[BaseUnitRule] @classmethod - def from_lines(cls, lines, non_int_type): + def from_lines( + cls: type[Self], lines: Iterable[str], non_int_type: type + ) -> Self | None: # TODO: this is to keep it backwards compatible + # TODO: check when is None returned. from ...delegates import ParserConfig, txt_defparser cfg = ParserConfig(non_int_type) @@ -55,7 +59,8 @@ class SystemDefinition(errors.WithDefErr): return definition @property - def unit_replacements(self) -> ty.Tuple[ty.Tuple[str, str], ...]: + def unit_replacements(self) -> tuple[tuple[str, str | None]]: + # TODO: check if None can be dropped. return tuple((el.new_unit_name, el.old_unit_name) for el in self.rules) def __post_init__(self): diff --git a/pint/facets/system/objects.py b/pint/facets/system/objects.py index 829fb5c..69b1c84 100644 --- a/pint/facets/system/objects.py +++ b/pint/facets/system/objects.py @@ -9,6 +9,13 @@ from __future__ import annotations +import numbers + +from typing import Any +from collections.abc import Iterable + +from ..._typing import Self + from ...babel_names import _babel_systems from ...compat import babel_parse from ...util import ( @@ -29,32 +36,28 @@ class System(SharedRegistryObject): The System belongs to one Registry. See SystemDefinition for the definition file syntax. - """ - def __init__(self, name): - """ - :param name: Name of the group - :type name: str - """ + Parameters + ---------- + name + Name of the group. + """ + def __init__(self, name: str): #: Name of the system #: :type: str self.name = name #: Maps root unit names to a dict indicating the new unit and its exponent. - #: :type: dict[str, dict[str, number]]] - self.base_units = {} + self.base_units: dict[str, dict[str, numbers.Number]] = {} #: Derived unit names. - #: :type: set(str) - self.derived_units = set() + self.derived_units: set[str] = set() #: Names of the _used_groups in used by this system. - #: :type: set(str) - self._used_groups = set() + self._used_groups: set[str] = set() - #: :type: frozenset | None - self._computed_members = None + self._computed_members: frozenset[str] | None = None # Add this system to the system dictionary self._REGISTRY._systems[self.name] = self @@ -62,7 +65,7 @@ class System(SharedRegistryObject): def __dir__(self): return list(self.members) - def __getattr__(self, item): + def __getattr__(self, item: str) -> Any: getattr_maybe_raise(self, item) u = getattr(self._REGISTRY, self.name + "_" + item, None) if u is not None: @@ -93,19 +96,19 @@ class System(SharedRegistryObject): """Invalidate computed members in this Group and all parent nodes.""" self._computed_members = None - def add_groups(self, *group_names): + def add_groups(self, *group_names: str) -> None: """Add groups to group.""" self._used_groups |= set(group_names) self.invalidate_members() - def remove_groups(self, *group_names): + def remove_groups(self, *group_names: str) -> None: """Remove groups from group.""" self._used_groups -= set(group_names) self.invalidate_members() - def format_babel(self, locale): + def format_babel(self, locale: str) -> str: """translate the name of the system.""" if locale and self.name in _babel_systems: name = _babel_systems[self.name] @@ -114,8 +117,12 @@ class System(SharedRegistryObject): return self.name @classmethod - def from_lines(cls, lines, get_root_func, non_int_type=float): - system_definition = SystemDefinition.from_lines(lines, get_root_func) + def from_lines( + cls: type[Self], lines: Iterable[str], get_root_func, non_int_type: type = float + ) -> Self: + # TODO: we changed something here it used to be + # system_definition = SystemDefinition.from_lines(lines, get_root_func) + system_definition = SystemDefinition.from_lines(lines, non_int_type) return cls.from_definition(system_definition, get_root_func) @classmethod @@ -174,12 +181,12 @@ class System(SharedRegistryObject): class Lister: - def __init__(self, d): + def __init__(self, d: dict[str, Any]): self.d = d - def __dir__(self): + def __dir__(self) -> list[str]: return list(self.d.keys()) - def __getattr__(self, item): + def __getattr__(self, item: str) -> Any: getattr_maybe_raise(self, item) return self.d[item] diff --git a/pint/facets/system/registry.py b/pint/facets/system/registry.py index 527440a..6e0878e 100644 --- a/pint/facets/system/registry.py +++ b/pint/facets/system/registry.py @@ -9,7 +9,7 @@ from __future__ import annotations from numbers import Number -from typing import TYPE_CHECKING, Dict, FrozenSet, Tuple, Union +from typing import TYPE_CHECKING from ... import errors @@ -19,13 +19,13 @@ if TYPE_CHECKING: from ..._typing import UnitLike from ...util import UnitsContainer as UnitsContainerT from ...util import ( - build_dependent_class, create_class_with_registry, to_units_container, ) from ..group import GroupRegistry from .definitions import SystemDefinition from .objects import Lister, System +from . import objects class SystemRegistry(GroupRegistry): @@ -46,24 +46,20 @@ class SystemRegistry(GroupRegistry): # TODO: Change this to System: System to specify class # and use introspection to get system class as a way # to enjoy typing goodies - _system_class = System + System = objects.System def __init__(self, system=None, **kwargs): super().__init__(**kwargs) #: Map system name to system. #: :type: dict[ str | System] - self._systems: Dict[str, System] = {} + self._systems: dict[str, System] = {} #: Maps dimensionality (UnitsContainer) to Dimensionality (UnitsContainer) - self._base_units_cache = dict() + self._base_units_cache = {} self._default_system = system - def __init_subclass__(cls, **kwargs): - super().__init_subclass__() - cls.System = build_dependent_class(cls, "System", "_system_class") - def _init_dynamic_classes(self) -> None: """Generate subclasses on the fly and attach them to self""" super()._init_dynamic_classes() @@ -143,10 +139,10 @@ class SystemRegistry(GroupRegistry): def get_base_units( self, - input_units: Union[UnitLike, Quantity], + input_units: UnitLike | Quantity, check_nonmult: bool = True, - system: Union[str, System, None] = None, - ) -> Tuple[Number, Unit]: + system: str | System | None = None, + ) -> tuple[Number, Unit]: """Convert unit or dict of units to the plain units. If any unit is non multiplicative and check_converter is True, @@ -183,7 +179,7 @@ class SystemRegistry(GroupRegistry): self, input_units: UnitsContainerT, check_nonmult: bool = True, - system: Union[str, System, None] = None, + system: str | System | None = None, ): if system is None: system = self._default_system @@ -224,7 +220,7 @@ class SystemRegistry(GroupRegistry): return base_factor, destination_units - def _get_compatible_units(self, input_units, group_or_system) -> FrozenSet[Unit]: + def _get_compatible_units(self, input_units, group_or_system) -> frozenset[Unit]: if group_or_system is None: group_or_system = self._default_system diff --git a/pint/formatting.py b/pint/formatting.py index f450d5f..880f55b 100644 --- a/pint/formatting.py +++ b/pint/formatting.py @@ -13,7 +13,9 @@ from __future__ import annotations import functools import re import warnings -from typing import Callable, Dict +from typing import Callable, Any +from collections.abc import Iterable +from numbers import Number from .babel_names import _babel_lengths, _babel_units from .compat import babel_parse @@ -21,7 +23,7 @@ from .compat import babel_parse __JOIN_REG_EXP = re.compile(r"{\d*}") -def _join(fmt, iterable): +def _join(fmt: str, iterable: Iterable[Any]): """Join an iterable with the format specified in fmt. The format can be specified in two ways: @@ -55,7 +57,7 @@ def _join(fmt, iterable): _PRETTY_EXPONENTS = "⁰¹²³⁴⁵⁶⁷⁸⁹" -def _pretty_fmt_exponent(num): +def _pretty_fmt_exponent(num: Number) -> str: """Format an number into a pretty printed exponent. Parameters @@ -76,7 +78,7 @@ def _pretty_fmt_exponent(num): #: _FORMATS maps format specifications to the corresponding argument set to #: formatter(). -_FORMATS: Dict[str, dict] = { +_FORMATS: dict[str, dict[str, Any]] = { "P": { # Pretty format. "as_ratio": True, "single_denominator": False, @@ -122,10 +124,10 @@ _FORMATS: Dict[str, dict] = { } #: _FORMATTERS maps format names to callables doing the formatting -_FORMATTERS: Dict[str, Callable] = {} +_FORMATTERS: dict[str, Callable] = {} -def register_unit_format(name): +def register_unit_format(name: str): """register a function as a new format for units The registered function must have a signature of: @@ -197,9 +199,7 @@ def latex_escape(string): @register_unit_format("L") def format_latex(unit, registry, **options): - preprocessed = { - r"\mathrm{{{}}}".format(latex_escape(u)): p for u, p in unit.items() - } + preprocessed = {rf"\mathrm{{{latex_escape(u)}}}": p for u, p in unit.items()} formatted = formatter( preprocessed.items(), as_ratio=True, @@ -270,18 +270,18 @@ def format_compact(unit, registry, **options): def formatter( - items, - as_ratio=True, - single_denominator=False, - product_fmt=" * ", - division_fmt=" / ", - power_fmt="{} ** {}", - parentheses_fmt="({0})", + items: list[tuple[str, Number]], + as_ratio: bool = True, + single_denominator: bool = False, + product_fmt: str = " * ", + division_fmt: str = " / ", + power_fmt: str = "{} ** {}", + parentheses_fmt: str = "({0})", exp_call=lambda x: f"{x:n}", - locale=None, - babel_length="long", - babel_plural_form="one", - sort=True, + locale: str | None = None, + babel_length: str = "long", + babel_plural_form: str = "one", + sort: bool = True, ): """Format a list of (name, exponent) pairs. @@ -442,10 +442,10 @@ def siunitx_format_unit(units, registry): elif power == 3: return r"\cubed" else: - return r"\tothe{{{:d}}}".format(int(power)) + return rf"\tothe{{{int(power):d}}}" else: # limit float powers to 3 decimal places - return r"\tothe{{{:.3f}}}".format(power).rstrip("0") + return rf"\tothe{{{power:.3f}}}".rstrip("0") lpos = [] lneg = [] @@ -466,9 +466,9 @@ def siunitx_format_unit(units, registry): if power < 0: lpick.append(r"\per") if prefix is not None: - lpick.append(r"\{}".format(prefix)) - lpick.append(r"\{}".format(unit)) - lpick.append(r"{}".format(_tothe(abs(power)))) + lpick.append(rf"\{prefix}") + lpick.append(rf"\{unit}") + lpick.append(rf"{_tothe(abs(power))}") return "".join(lpos) + "".join(lneg) @@ -529,8 +529,8 @@ def split_format(spec, default, separate_format_defaults=True): elif not spec: mspec, uspec = default_mspec, default_uspec else: - mspec = mspec if mspec else default_mspec - uspec = uspec if uspec else default_uspec + mspec = mspec or default_mspec + uspec = uspec or default_uspec return mspec, uspec diff --git a/pint/matplotlib.py b/pint/matplotlib.py index ea88c70..25c257b 100644 --- a/pint/matplotlib.py +++ b/pint/matplotlib.py @@ -36,15 +36,15 @@ class PintConverter(matplotlib.units.ConversionInterface): """Convert :`Quantity` instances for matplotlib to use.""" if iterable(value): return [self._convert_value(v, unit, axis) for v in value] - else: - return self._convert_value(value, unit, axis) + + return self._convert_value(value, unit, axis) def _convert_value(self, value, unit, axis): """Handle converting using attached unit or falling back to axis units.""" if hasattr(value, "units"): return value.to(unit).magnitude - else: - return self._reg.Quantity(value, axis.get_units()).to(unit).magnitude + + return self._reg.Quantity(value, axis.get_units()).to(unit).magnitude @staticmethod def axisinfo(unit, axis): diff --git a/pint/pint_convert.py b/pint/pint_convert.py index d8d60e8..bf90972 100755 --- a/pint/pint_convert.py +++ b/pint/pint_convert.py @@ -11,6 +11,7 @@ from __future__ import annotations import argparse +import contextlib import re from pint import UnitRegistry @@ -154,13 +155,13 @@ if args.unc: ), ) - ureg._root_units_cache = dict() + ureg._root_units_cache = {} ureg._build_cache() def convert(u_from, u_to=None, unc=None, factor=None): q = ureg.Quantity(u_from) - fmt = ".{}g".format(args.prec) + fmt = f".{args.prec}g" if unc: q = q.plus_minus(unc) if u_to: @@ -172,25 +173,23 @@ def convert(u_from, u_to=None, unc=None, factor=None): nq *= ureg.Quantity(factor).to_base_units() prec_unc = use_unc(nq.magnitude, fmt, args.prec_unc) if prec_unc > 0: - fmt = ".{}uS".format(prec_unc) + fmt = f".{prec_unc}uS" else: - try: + with contextlib.suppress(Exception): nq = nq.magnitude.n * nq.units - except Exception: - pass + fmt = "{:" + fmt + "} {:~P}" print(("{:} = " + fmt).format(q, nq.magnitude, nq.units)) def use_unc(num, fmt, prec_unc): unc = 0 - try: + with contextlib.suppress(Exception): if isinstance(num, uncertainties.UFloat): full = ("{:" + fmt + "}").format(num) unc = re.search(r"\+/-[0.]*([\d.]*)", full).group(1) unc = len(unc.replace(".", "")) - except Exception: - pass + return max(0, min(prec_unc, unc)) diff --git a/pint/pint_eval.py b/pint/pint_eval.py index 2054260..d476eae 100644 --- a/pint/pint_eval.py +++ b/pint/pint_eval.py @@ -11,7 +11,9 @@ from __future__ import annotations import operator import token as tokenlib -import tokenize +from tokenize import TokenInfo + +from typing import Any from .errors import DefinitionSyntaxError @@ -30,7 +32,7 @@ _OP_PRIORITY = { } -def _power(left, right): +def _power(left: Any, right: Any) -> Any: from . import Quantity from .compat import is_duck_array @@ -45,7 +47,19 @@ def _power(left, right): return operator.pow(left, right) -_BINARY_OPERATOR_MAP = { +import typing + +UnaryOpT = typing.Callable[ + [ + Any, + ], + Any, +] +BinaryOpT = typing.Callable[[Any, Any], Any] + +_UNARY_OPERATOR_MAP: dict[str, UnaryOpT] = {"+": lambda x: x, "-": lambda x: x * -1} + +_BINARY_OPERATOR_MAP: dict[str, BinaryOpT] = { "**": _power, "*": operator.mul, "": operator.mul, # operator for implicit ops @@ -56,8 +70,6 @@ _BINARY_OPERATOR_MAP = { "//": operator.floordiv, } -_UNARY_OPERATOR_MAP = {"+": lambda x: x, "-": lambda x: x * -1} - class EvalTreeNode: """Single node within an evaluation tree @@ -68,25 +80,43 @@ class EvalTreeNode: left --> single value """ - def __init__(self, left, operator=None, right=None): + def __init__( + self, + left: EvalTreeNode | TokenInfo, + operator: TokenInfo | None = None, + right: EvalTreeNode | None = None, + ): self.left = left self.operator = operator self.right = right - def to_string(self): + def to_string(self) -> str: # For debugging purposes if self.right: + assert isinstance(self.left, EvalTreeNode), "self.left not EvalTreeNode (1)" comps = [self.left.to_string()] if self.operator: - comps.append(self.operator[1]) + comps.append(self.operator.string) comps.append(self.right.to_string()) elif self.operator: - comps = [self.operator[1], self.left.to_string()] + assert isinstance(self.left, EvalTreeNode), "self.left not EvalTreeNode (2)" + comps = [self.operator.string, self.left.to_string()] else: - return self.left[1] + assert isinstance(self.left, TokenInfo), "self.left not TokenInfo (1)" + return self.left.string return "(%s)" % " ".join(comps) - def evaluate(self, define_op, bin_op=None, un_op=None): + def evaluate( + self, + define_op: typing.Callable[ + [ + Any, + ], + Any, + ], + bin_op: dict[str, BinaryOpT] | None = None, + un_op: dict[str, UnaryOpT] | None = None, + ): """Evaluate node. Parameters @@ -107,33 +137,38 @@ class EvalTreeNode: un_op = un_op or _UNARY_OPERATOR_MAP if self.right: + assert isinstance(self.left, EvalTreeNode), "self.left not EvalTreeNode (3)" # binary or implicit operator - op_text = self.operator[1] if self.operator else "" + op_text = self.operator.string if self.operator else "" if op_text not in bin_op: - raise DefinitionSyntaxError('missing binary operator "%s"' % op_text) - left = self.left.evaluate(define_op, bin_op, un_op) - return bin_op[op_text](left, self.right.evaluate(define_op, bin_op, un_op)) + raise DefinitionSyntaxError(f"missing binary operator '{op_text}'") + + return bin_op[op_text]( + self.left.evaluate(define_op, bin_op, un_op), + self.right.evaluate(define_op, bin_op, un_op), + ) elif self.operator: + assert isinstance(self.left, EvalTreeNode), "self.left not EvalTreeNode (4)" # unary operator - op_text = self.operator[1] + op_text = self.operator.string if op_text not in un_op: - raise DefinitionSyntaxError('missing unary operator "%s"' % op_text) + raise DefinitionSyntaxError(f"missing unary operator '{op_text}'") return un_op[op_text](self.left.evaluate(define_op, bin_op, un_op)) - else: - # single value - return define_op(self.left) + # single value + return define_op(self.left) -from typing import Iterable +from collections.abc import Iterable -def build_eval_tree( - tokens: Iterable[tokenize.TokenInfo], - op_priority=None, - index=0, - depth=0, - prev_op=None, -) -> tuple[EvalTreeNode | None, int] | EvalTreeNode: + +def _build_eval_tree( + tokens: list[TokenInfo], + op_priority: dict[str, int], + index: int = 0, + depth: int = 0, + prev_op: str = "<none>", +) -> tuple[EvalTreeNode, int]: """Build an evaluation tree from a set of tokens. Params: @@ -153,14 +188,12 @@ def build_eval_tree( 5) Combine left side, operator, and right side into a new left side 6) Go back to step #2 - """ + Raises + ------ + DefinitionSyntaxError + If there is a syntax error. - if op_priority is None: - op_priority = _OP_PRIORITY - - if depth == 0 and prev_op is None: - # ensure tokens is list so we can access by index - tokens = list(tokens) + """ result = None @@ -171,19 +204,21 @@ def build_eval_tree( if token_type == tokenlib.OP: if token_text == ")": - if prev_op is None: + if prev_op == "<none>": raise DefinitionSyntaxError( - "unopened parentheses in tokens: %s" % current_token + f"unopened parentheses in tokens: {current_token}" ) elif prev_op == "(": # close parenthetical group + assert result is not None return result, index else: # parenthetical group ending, but we need to close sub-operations within group + assert result is not None return result, index - 1 elif token_text == "(": # gather parenthetical group - right, index = build_eval_tree( + right, index = _build_eval_tree( tokens, op_priority, index + 1, 0, token_text ) if not tokens[index][1] == ")": @@ -204,11 +239,11 @@ def build_eval_tree( # (2 * 3 / 4) --> ((2 * 3) / 4) if op_priority[token_text] <= op_priority.get( prev_op, -1 - ) and token_text not in ["**", "^"]: + ) and token_text not in ("**", "^"): # previous operator is higher priority, so end previous binary op return result, index - 1 # get right side of binary op - right, index = build_eval_tree( + right, index = _build_eval_tree( tokens, op_priority, index + 1, depth + 1, token_text ) result = EvalTreeNode( @@ -216,18 +251,18 @@ def build_eval_tree( ) else: # unary operator - right, index = build_eval_tree( + right, index = _build_eval_tree( tokens, op_priority, index + 1, depth + 1, "unary" ) result = EvalTreeNode(left=right, operator=current_token) - elif token_type == tokenlib.NUMBER or token_type == tokenlib.NAME: + elif token_type in (tokenlib.NUMBER, tokenlib.NAME): if result: # tokens with an implicit operation i.e. "1 kg" if op_priority[""] <= op_priority.get(prev_op, -1): # previous operator is higher priority than implicit, so end # previous binary op return result, index - 1 - right, index = build_eval_tree( + right, index = _build_eval_tree( tokens, op_priority, index, depth + 1, "" ) result = EvalTreeNode(left=result, right=right) @@ -240,13 +275,57 @@ def build_eval_tree( raise DefinitionSyntaxError("unclosed parentheses in tokens") if depth > 0 or prev_op: # have to close recursion + assert result is not None return result, index else: # recursion all closed, so just return the final result - return result + assert result is not None + return result, -1 if index + 1 >= len(tokens): # should hit ENDMARKER before this ever happens raise DefinitionSyntaxError("unexpected end to tokens") index += 1 + + +def build_eval_tree( + tokens: Iterable[TokenInfo], + op_priority: dict[str, int] | None = None, +) -> EvalTreeNode: + """Build an evaluation tree from a set of tokens. + + Params: + Index, depth, and prev_op used recursively, so don't touch. + Tokens is an iterable of tokens from an expression to be evaluated. + + Transform the tokens from an expression into a recursive parse tree, following order + of operations. Operations can include binary ops (3 + 4), implicit ops (3 kg), or + unary ops (-1). + + General Strategy: + 1) Get left side of operator + 2) If no tokens left, return final result + 3) Get operator + 4) Use recursion to create tree starting at token on right side of operator (start at step #1) + 4.1) If recursive call encounters an operator with lower or equal priority to step #2, exit recursion + 5) Combine left side, operator, and right side into a new left side + 6) Go back to step #2 + + Raises + ------ + DefinitionSyntaxError + If there is a syntax error. + + """ + + if op_priority is None: + op_priority = _OP_PRIORITY + + if not isinstance(tokens, list): + # ensure tokens is list so we can access by index + tokens = list(tokens) + + result, _ = _build_eval_tree(tokens, op_priority, 0, 0) + + return result diff --git a/pint/registry.py b/pint/registry.py index 29d5c89..474eb77 100644 --- a/pint/registry.py +++ b/pint/registry.py @@ -27,6 +27,35 @@ from .facets import ( from .util import logger, pi_theorem +# To build the Quantity and Unit classes +# we follow the UnitRegistry bases +# but + + +class Quantity( + # SystemRegistry.Quantity, + # ContextRegistry.Quantity, + DaskRegistry.Quantity, + NumpyRegistry.Quantity, + MeasurementRegistry.Quantity, + FormattingRegistry.Quantity, + NonMultiplicativeRegistry.Quantity, +): + pass + + +class Unit( + # SystemRegistry.Unit, + # ContextRegistry.Unit, + # DaskRegistry.Unit, + NumpyRegistry.Unit, + # MeasurementRegistry.Unit, + FormattingRegistry.Unit, + NonMultiplicativeRegistry.Unit, +): + pass + + class UnitRegistry( SystemRegistry, ContextRegistry, @@ -72,6 +101,9 @@ class UnitRegistry( If None, the cache is disabled. (default) """ + Quantity = Quantity + Unit = Unit + def __init__( self, filename="", diff --git a/pint/registry_helpers.py b/pint/registry_helpers.py index 07b00ff..1f28036 100644 --- a/pint/registry_helpers.py +++ b/pint/registry_helpers.py @@ -13,7 +13,8 @@ from __future__ import annotations import functools from inspect import signature from itertools import zip_longest -from typing import TYPE_CHECKING, Callable, Iterable, TypeVar, Union +from typing import TYPE_CHECKING, Callable, TypeVar +from collections.abc import Iterable from ._typing import F from .errors import DimensionalityError @@ -184,9 +185,9 @@ def _apply_defaults(func, args, kwargs): def wraps( - ureg: "UnitRegistry", - ret: Union[str, "Unit", Iterable[Union[str, "Unit", None]], None], - args: Union[str, "Unit", Iterable[Union[str, "Unit", None]], None], + ureg: UnitRegistry, + ret: str | Unit | Iterable[str | Unit | None] | None, + args: str | Unit | Iterable[str | Unit | None] | None, strict: bool = True, ) -> Callable[[Callable[..., T]], Callable[..., Quantity[T]]]: """Wraps a function to become pint-aware. @@ -300,7 +301,7 @@ def wraps( def check( - ureg: "UnitRegistry", *args: Union[str, UnitsContainer, "Unit", None] + ureg: UnitRegistry, *args: str | UnitsContainer | Unit | None ) -> Callable[[F], F]: """Decorator to for quantity type checking for function inputs. diff --git a/pint/testing.py b/pint/testing.py index 1c458f5..8e4f15f 100644 --- a/pint/testing.py +++ b/pint/testing.py @@ -36,10 +36,10 @@ def _get_comparable_magnitudes(first, second, msg): def assert_equal(first, second, msg=None): if msg is None: - msg = "Comparing %r and %r. " % (first, second) + msg = f"Comparing {first!r} and {second!r}. " m1, m2 = _get_comparable_magnitudes(first, second, msg) - msg += " (Converted to %r and %r): Magnitudes are not equal" % (m1, m2) + msg += f" (Converted to {m1!r} and {m2!r}): Magnitudes are not equal" if isinstance(m1, ndarray) or isinstance(m2, ndarray): np.testing.assert_array_equal(m1, m2, err_msg=msg) @@ -60,15 +60,15 @@ def assert_equal(first, second, msg=None): def assert_allclose(first, second, rtol=1e-07, atol=0, msg=None): if msg is None: try: - msg = "Comparing %r and %r. " % (first, second) + msg = f"Comparing {first!r} and {second!r}. " except TypeError: try: - msg = "Comparing %s and %s. " % (first, second) + msg = f"Comparing {first} and {second}. " except Exception: msg = "Comparing" m1, m2 = _get_comparable_magnitudes(first, second, msg) - msg += " (Converted to %r and %r)" % (m1, m2) + msg += f" (Converted to {m1!r} and {m2!r})" if isinstance(m1, ndarray) or isinstance(m2, ndarray): np.testing.assert_allclose(m1, m2, rtol=rtol, atol=atol, err_msg=msg) diff --git a/pint/testsuite/__init__.py b/pint/testsuite/__init__.py index 8c0cd09..35b0d91 100644 --- a/pint/testsuite/__init__.py +++ b/pint/testsuite/__init__.py @@ -3,7 +3,8 @@ import math import os import unittest import warnings -from contextlib import contextmanager +import contextlib +import pathlib from pint import UnitRegistry from pint.testsuite.helpers import PintOutputChecker @@ -25,7 +26,7 @@ class QuantityTestCase: cls.U_ = None -@contextmanager +@contextlib.contextmanager def assert_no_warnings(): with warnings.catch_warnings(): warnings.simplefilter("error") @@ -40,13 +41,12 @@ def testsuite(): # TESTING THE DOCUMENTATION requires pyyaml, serialize, numpy and uncertainties if HAS_NUMPY and HAS_UNCERTAINTIES: - try: + with contextlib.suppress(ImportError): import serialize # noqa: F401 import yaml # noqa: F401 add_docs(suite) - except ImportError: - pass + return suite @@ -98,7 +98,7 @@ def add_docs(suite): """ docpath = os.path.join(os.path.dirname(__file__), "..", "..", "docs") docpath = os.path.abspath(docpath) - if os.path.exists(docpath): + if pathlib.Path(docpath).exists(): checker = PintOutputChecker() for name in (name for name in os.listdir(docpath) if name.endswith(".rst")): file = os.path.join(docpath, name) diff --git a/pint/testsuite/helpers.py b/pint/testsuite/helpers.py index 4c560fb..191f4c3 100644 --- a/pint/testsuite/helpers.py +++ b/pint/testsuite/helpers.py @@ -1,6 +1,7 @@ import doctest import pickle import re +import contextlib import pytest from packaging.version import parse as version_parse @@ -41,14 +42,12 @@ class PintOutputChecker(doctest.OutputChecker): if check: return check - try: + with contextlib.suppress(Exception): if eval(want) == eval(got): return True - except Exception: - pass for regex in (_q_re, _sq_re): - try: + with contextlib.suppress(Exception): parsed_got = regex.match(got.replace(r"\\", "")).groupdict() parsed_want = regex.match(want.replace(r"\\", "")).groupdict() @@ -62,12 +61,10 @@ class PintOutputChecker(doctest.OutputChecker): return False return True - except Exception: - pass cnt = 0 for regex in (_unit_re,): - try: + with contextlib.suppress(Exception): parsed_got, tmp = regex.subn("\1", got) cnt += tmp parsed_want, temp = regex.subn("\1", want) @@ -76,9 +73,6 @@ class PintOutputChecker(doctest.OutputChecker): if parsed_got == parsed_want: return True - except Exception: - pass - if cnt: # If there was any replacement, we try again the previous methods. return self.check_output(parsed_want, parsed_got, optionflags) diff --git a/pint/testsuite/test_babel.py b/pint/testsuite/test_babel.py index 5c32879..7842d54 100644 --- a/pint/testsuite/test_babel.py +++ b/pint/testsuite/test_babel.py @@ -84,16 +84,16 @@ def test_str(func_registry): s = "24.0 meter" assert str(d) == s assert "%s" % d == s - assert "{}".format(d) == s + assert f"{d}" == s ureg.set_fmt_locale("fr_FR") s = "24.0 mètres" assert str(d) == s assert "%s" % d == s - assert "{}".format(d) == s + assert f"{d}" == s ureg.set_fmt_locale(None) s = "24.0 meter" assert str(d) == s assert "%s" % d == s - assert "{}".format(d) == s + assert f"{d}" == s diff --git a/pint/testsuite/test_compat_downcast.py b/pint/testsuite/test_compat_downcast.py index ebb5907..ed43e94 100644 --- a/pint/testsuite/test_compat_downcast.py +++ b/pint/testsuite/test_compat_downcast.py @@ -1,3 +1,4 @@ +import operator import pytest from pint import UnitRegistry @@ -37,7 +38,7 @@ def q_base(local_registry): # Define identity function for use in tests -def identity(ureg, x): +def id_matrix(ureg, x): return x @@ -62,17 +63,17 @@ def array(request): @pytest.mark.parametrize( "op, magnitude_op, unit_op", [ - pytest.param(identity, identity, identity, id="identity"), + pytest.param(id_matrix, id_matrix, id_matrix, id="identity"), pytest.param( lambda ureg, x: x + 1 * ureg.m, lambda ureg, x: x + 1, - identity, + id_matrix, id="addition", ), pytest.param( lambda ureg, x: x - 20 * ureg.cm, lambda ureg, x: x - 0.2, - identity, + id_matrix, id="subtraction", ), pytest.param( @@ -83,7 +84,7 @@ def array(request): ), pytest.param( lambda ureg, x: x / (1 * ureg.s), - identity, + id_matrix, lambda ureg, u: u / ureg.s, id="division", ), @@ -93,17 +94,17 @@ def array(request): WR(lambda u: u**2), id="square", ), - pytest.param(WR(lambda x: x.T), WR(lambda x: x.T), identity, id="transpose"), - pytest.param(WR(np.mean), WR(np.mean), identity, id="mean ufunc"), - pytest.param(WR(np.sum), WR(np.sum), identity, id="sum ufunc"), + pytest.param(WR(lambda x: x.T), WR(lambda x: x.T), id_matrix, id="transpose"), + pytest.param(WR(np.mean), WR(np.mean), id_matrix, id="mean ufunc"), + pytest.param(WR(np.sum), WR(np.sum), id_matrix, id="sum ufunc"), pytest.param(WR(np.sqrt), WR(np.sqrt), WR(lambda u: u**0.5), id="sqrt ufunc"), pytest.param( WR(lambda x: np.reshape(x, (25,))), WR(lambda x: np.reshape(x, (25,))), - identity, + id_matrix, id="reshape function", ), - pytest.param(WR(np.amax), WR(np.amax), identity, id="amax function"), + pytest.param(WR(np.amax), WR(np.amax), id_matrix, id="amax function"), ], ) def test_univariate_op_consistency( @@ -121,10 +122,8 @@ def test_univariate_op_consistency( @pytest.mark.parametrize( "op, unit", [ - pytest.param( - lambda x, y: x * y, lambda ureg: ureg("kg m"), id="multiplication" - ), - pytest.param(lambda x, y: x / y, lambda ureg: ureg("m / kg"), id="division"), + pytest.param(operator.mul, lambda ureg: ureg("kg m"), id="multiplication"), + pytest.param(operator.truediv, lambda ureg: ureg("m / kg"), id="division"), pytest.param(np.multiply, lambda ureg: ureg("kg m"), id="multiply ufunc"), ], ) @@ -143,11 +142,11 @@ def test_bivariate_op_consistency(local_registry, q_base, op, unit, array): "op", [ pytest.param( - WR2(lambda a, u: a * u), + WR2(operator.mul), id="array-first", marks=pytest.mark.xfail(reason="upstream issue numpy/numpy#15200"), ), - pytest.param(WR2(lambda a, u: u * a), id="unit-first"), + pytest.param(WR2(operator.mul), id="unit-first"), ], ) @pytest.mark.parametrize( diff --git a/pint/testsuite/test_compat_upcast.py b/pint/testsuite/test_compat_upcast.py index ad267c1..c8266f7 100644 --- a/pint/testsuite/test_compat_upcast.py +++ b/pint/testsuite/test_compat_upcast.py @@ -1,3 +1,4 @@ +import operator import pytest # Conditionally import NumPy and any upcast type libraries @@ -49,9 +50,9 @@ def test_quantification(module_registry, ds): @pytest.mark.parametrize( "op", [ - lambda x, y: x + y, + operator.add, lambda x, y: x - (-y), - lambda x, y: x * y, + operator.mul, lambda x, y: x / (y**-1), ], ) @@ -126,9 +127,7 @@ def test_array_function_deferral(da, module_registry): upper = 3 * module_registry.m args = (da, lower, upper) assert ( - lower.__array_function__( - np.clip, tuple(set(type(arg) for arg in args)), args, {} - ) + lower.__array_function__(np.clip, tuple({type(arg) for arg in args}), args, {}) is NotImplemented ) diff --git a/pint/testsuite/test_contexts.py b/pint/testsuite/test_contexts.py index c7551e4..ea6525d 100644 --- a/pint/testsuite/test_contexts.py +++ b/pint/testsuite/test_contexts.py @@ -683,7 +683,7 @@ class TestDefinedContexts: ) p = find_shortest_path(ureg._active_ctx.graph, da, db) assert p - msg = "{} <-> {}".format(a, b) + msg = f"{a} <-> {b}" # assertAlmostEqualRelError converts second to first helpers.assert_quantity_almost_equal(b, a, rtol=0.01, msg=msg) @@ -705,7 +705,7 @@ class TestDefinedContexts: da, db = Context.__keytransform__(a.dimensionality, b.dimensionality) p = find_shortest_path(ureg._active_ctx.graph, da, db) assert p - msg = "{} <-> {}".format(a, b) + msg = f"{a} <-> {b}" helpers.assert_quantity_almost_equal(b, a, rtol=0.01, msg=msg) # Check RKM <-> cN/tex conversion diff --git a/pint/testsuite/test_converters.py b/pint/testsuite/test_converters.py index 62ffdb7..71a076f 100644 --- a/pint/testsuite/test_converters.py +++ b/pint/testsuite/test_converters.py @@ -69,7 +69,7 @@ class TestConverter: @helpers.requires_numpy def test_log_converter_inplace(self): - arb_value = 3.14 + arb_value = 3.13 c = LogarithmicConverter(scale=1, logbase=10, logfactor=1) from_to = lambda value, inplace: c.from_reference( diff --git a/pint/testsuite/test_dask.py b/pint/testsuite/test_dask.py index f4dee6a..0e6a1cf 100644 --- a/pint/testsuite/test_dask.py +++ b/pint/testsuite/test_dask.py @@ -1,5 +1,6 @@ import importlib -import os + +import pathlib import pytest @@ -135,8 +136,8 @@ def test_visualize(local_registry, dask_array): assert res is None # These commands only work on Unix and Windows - assert os.path.exists("mydask.png") - os.remove("mydask.png") + assert pathlib.Path("mydask.png").exists() + pathlib.Path("mydask.png").unlink() def test_compute_persist_equivalent(local_registry, dask_array, numpy_array): diff --git a/pint/testsuite/test_definitions.py b/pint/testsuite/test_definitions.py index 2618c6e..69a337d 100644 --- a/pint/testsuite/test_definitions.py +++ b/pint/testsuite/test_definitions.py @@ -1,5 +1,7 @@ import pytest +import math + from pint.definitions import Definition from pint.errors import DefinitionSyntaxError from pint.facets.nonmultiplicative.definitions import ( @@ -81,7 +83,7 @@ class TestDefinition: assert x.reference == UnitsContainer(kelvin=1) x = Definition.from_string( - "turn = 6.28 * radian = _ = revolution = = cycle = _" + f"turn = {math.tau} * radian = _ = revolution = = cycle = _" ) assert isinstance(x, UnitDefinition) assert x.name == "turn" @@ -89,7 +91,7 @@ class TestDefinition: assert x.symbol == "turn" assert not x.is_base assert isinstance(x.converter, ScaleConverter) - assert x.converter.scale == 6.28 + assert x.converter.scale == math.tau assert x.reference == UnitsContainer(radian=1) with pytest.raises(ValueError): @@ -136,7 +138,7 @@ class TestDefinition: assert x.converter.logfactor == 1 assert x.reference == UnitsContainer() - eulersnumber = 2.71828182845904523536028747135266249775724709369995 + eulersnumber = math.e x = Definition.from_string( "neper = 1 ; logbase: %1.50f; logfactor: 0.5 = Np" % eulersnumber ) diff --git a/pint/testsuite/test_errors.py b/pint/testsuite/test_errors.py index 6a42eec..a045f6e 100644 --- a/pint/testsuite/test_errors.py +++ b/pint/testsuite/test_errors.py @@ -116,7 +116,7 @@ class TestErrors: q2 = ureg.Quantity("1 bar") for protocol in range(pickle.HIGHEST_PROTOCOL + 1): - for ex in [ + for ex in ( DefinitionSyntaxError("foo"), RedefinitionError("foo", "bar"), UndefinedUnitError("meter"), @@ -125,7 +125,7 @@ class TestErrors: Quantity("1 kg")._units, Quantity("1 s")._units ), OffsetUnitCalculusError(q1._units, q2._units), - ]: + ): with subtests.test(protocol=protocol, etype=type(ex)): pik = pickle.dumps(ureg.Quantity("1 foo"), protocol) with pytest.raises(UndefinedUnitError): diff --git a/pint/testsuite/test_formatter.py b/pint/testsuite/test_formatter.py index 9e362fc..5a51a0a 100644 --- a/pint/testsuite/test_formatter.py +++ b/pint/testsuite/test_formatter.py @@ -5,13 +5,13 @@ from pint import formatting as fmt class TestFormatter: def test_join(self): - for empty in (tuple(), []): + for empty in ((), []): assert fmt._join("s", empty) == "" assert fmt._join("*", "1 2 3".split()) == "1*2*3" assert fmt._join("{0}*{1}", "1 2 3".split()) == "1*2*3" def test_formatter(self): - assert fmt.formatter(dict().items()) == "" + assert fmt.formatter({}.items()) == "" assert fmt.formatter(dict(meter=1).items()) == "meter" assert fmt.formatter(dict(meter=-1).items()) == "1 / meter" assert fmt.formatter(dict(meter=-1).items(), as_ratio=False) == "meter ** -1" diff --git a/pint/testsuite/test_infer_base_unit.py b/pint/testsuite/test_infer_base_unit.py index f2605c6..9a27362 100644 --- a/pint/testsuite/test_infer_base_unit.py +++ b/pint/testsuite/test_infer_base_unit.py @@ -34,9 +34,9 @@ class TestInferBaseUnit: ureg = UnitRegistry(non_int_type=Decimal) QD = ureg.Quantity - ibu_d = infer_base_unit(QD(Decimal("1"), "millimeter * nanometer")) + ibu_d = infer_base_unit(QD(Decimal(1), "millimeter * nanometer")) - assert ibu_d == QD(Decimal("1"), "meter**2").units + assert ibu_d == QD(Decimal(1), "meter**2").units assert all(isinstance(v, Decimal) for v in ibu_d.values()) @@ -69,9 +69,9 @@ class TestInferBaseUnit: Q = ureg.Quantity r = ( Q(Decimal("1000000000.0"), "m") - * Q(Decimal("1"), "mm") - / Q(Decimal("1"), "s") - / Q(Decimal("1"), "ms") + * Q(Decimal(1), "mm") + / Q(Decimal(1), "s") + / Q(Decimal(1), "ms") ) compact_r = r.to_compact() expected = Q(Decimal("1000.0"), "kilometer**2 / second**2") diff --git a/pint/testsuite/test_issues.py b/pint/testsuite/test_issues.py index 8517bd9..9540814 100644 --- a/pint/testsuite/test_issues.py +++ b/pint/testsuite/test_issues.py @@ -445,10 +445,10 @@ class TestIssues(QuantityTestCase): def test_issue354_356_370(self, module_registry): assert ( - "{:~}".format(1 * module_registry.second / module_registry.millisecond) + f"{1 * module_registry.second / module_registry.millisecond:~}" == "1.0 s / ms" ) - assert "{:~}".format(1 * module_registry.count) == "1 count" + assert f"{1 * module_registry.count:~}" == "1 count" assert "{:~}".format(1 * module_registry("MiB")) == "1 MiB" def test_issue468(self, module_registry): diff --git a/pint/testsuite/test_log_units.py b/pint/testsuite/test_log_units.py index 2a048f6..3d1c905 100644 --- a/pint/testsuite/test_log_units.py +++ b/pint/testsuite/test_log_units.py @@ -56,7 +56,7 @@ class TestLogarithmicQuantity(QuantityTestCase): # ## Test dB to dB units octave - decade # 1 decade = log2(10) octave helpers.assert_quantity_almost_equal( - self.Q_(1.0, "decade"), self.Q_(math.log(10, 2), "octave") + self.Q_(1.0, "decade"), self.Q_(math.log2(10), "octave") ) # ## Test dB to dB units dBm - dBu # 0 dBm = 1mW = 1e3 uW = 30 dBu diff --git a/pint/testsuite/test_measurement.py b/pint/testsuite/test_measurement.py index b78ca0e..9de2762 100644 --- a/pint/testsuite/test_measurement.py +++ b/pint/testsuite/test_measurement.py @@ -178,7 +178,7 @@ class TestMeasurement(QuantityTestCase): ): with subtests.test(spec): self.ureg.default_format = spec - assert "{}".format(m) == result + assert f"{m}" == result def test_raise_build(self): v, u = self.Q_(1.0, "s"), self.Q_(0.1, "s") diff --git a/pint/testsuite/test_non_int.py b/pint/testsuite/test_non_int.py index 66637e1..5a74a99 100644 --- a/pint/testsuite/test_non_int.py +++ b/pint/testsuite/test_non_int.py @@ -740,10 +740,10 @@ class _TestQuantityBasicMath(NonIntTypeTestCase): zy = self.Q_(fun(y.magnitude), "meter") rx = fun(x) ry = fun(y) - assert rx == zx, "while testing {0}".format(fun) - assert ry == zy, "while testing {0}".format(fun) - assert rx is not zx, "while testing {0}".format(fun) - assert ry is not zy, "while testing {0}".format(fun) + assert rx == zx, f"while testing {fun}" + assert ry == zy, f"while testing {fun}" + assert rx is not zx, f"while testing {fun}" + assert ry is not zy, f"while testing {fun}" def test_quantity_float_complex(self): x = self.QP_("-4.2", None) @@ -1093,7 +1093,7 @@ class _TestOffsetUnitMath(NonIntTypeTestCase): else: in1, in2 = self.kwargs["non_int_type"](in1), self.QP_(*in2) input_tuple = in1, in2 # update input_tuple for better tracebacks - expected_copy = expected_output[:] + expected_copy = expected_output.copy() for i, mode in enumerate([False, True]): self.ureg.autoconvert_offset_to_baseunit = mode if expected_copy[i] == "error": @@ -1130,14 +1130,14 @@ class _TestOffsetUnitMath(NonIntTypeTestCase): def test_exponentiation(self, input_tuple, expected_output): self.ureg.default_as_delta = False in1, in2 = input_tuple - if type(in1) is tuple and type(in2) is tuple: + if type(in1) is type(in2) is tuple: in1, in2 = self.QP_(*in1), self.QP_(*in2) elif type(in1) is not tuple and type(in2) is tuple: in1, in2 = self.kwargs["non_int_type"](in1), self.QP_(*in2) else: in1, in2 = self.QP_(*in1), self.kwargs["non_int_type"](in2) input_tuple = in1, in2 - expected_copy = expected_output[:] + expected_copy = expected_output.copy() for i, mode in enumerate([False, True]): self.ureg.autoconvert_offset_to_baseunit = mode if expected_copy[i] == "error": diff --git a/pint/testsuite/test_numpy.py b/pint/testsuite/test_numpy.py index 1e0b928..0e96c77 100644 --- a/pint/testsuite/test_numpy.py +++ b/pint/testsuite/test_numpy.py @@ -303,7 +303,7 @@ class TestNumpyMathematicalFunctions(TestNumpyMethods): @helpers.requires_array_function_protocol() def test_fix(self): - helpers.assert_quantity_equal(np.fix(3.14 * self.ureg.m), 3.0 * self.ureg.m) + helpers.assert_quantity_equal(np.fix(3.13 * self.ureg.m), 3.0 * self.ureg.m) helpers.assert_quantity_equal(np.fix(3.0 * self.ureg.m), 3.0 * self.ureg.m) helpers.assert_quantity_equal( np.fix([2.1, 2.9, -2.1, -2.9] * self.ureg.m), @@ -505,7 +505,7 @@ class TestNumpyMathematicalFunctions(TestNumpyMethods): arr = np.array(range(3), dtype=float) q = self.Q_(arr, "meter") - for op_ in [op.pow, op.ipow, np.power]: + for op_ in (op.pow, op.ipow, np.power): q_cp = copy.copy(q) with pytest.raises(DimensionalityError): op_(2.0, q_cp) diff --git a/pint/testsuite/test_quantity.py b/pint/testsuite/test_quantity.py index 8fb712a..45b163d 100644 --- a/pint/testsuite/test_quantity.py +++ b/pint/testsuite/test_quantity.py @@ -393,7 +393,7 @@ class TestQuantity(QuantityTestCase): temp = (Q_(" 1 lbf*m")).to_preferred(preferred_units) # would prefer this to be repeatable, but mip doesn't guarantee that currently - assert temp.units in [ureg.W * ureg.s, ureg.ft * ureg.lbf] + assert temp.units in (ureg.W * ureg.s, ureg.ft * ureg.lbf) temp = Q_("1 kg").to_preferred(preferred_units) assert temp.units == ureg.slug @@ -1050,10 +1050,10 @@ class TestQuantityBasicMath(QuantityTestCase): zy = self.Q_(fun(y.magnitude), "meter") rx = fun(x) ry = fun(y) - assert rx == zx, "while testing {0}".format(fun) - assert ry == zy, "while testing {0}".format(fun) - assert rx is not zx, "while testing {0}".format(fun) - assert ry is not zy, "while testing {0}".format(fun) + assert rx == zx, f"while testing {fun}" + assert ry == zy, f"while testing {fun}" + assert rx is not zx, f"while testing {fun}" + assert ry is not zy, f"while testing {fun}" def test_quantity_float_complex(self): x = self.Q_(-4.2, None) @@ -1661,7 +1661,7 @@ class TestOffsetUnitMath(QuantityTestCase): else: in1, in2 = in1, self.Q_(*in2) input_tuple = in1, in2 # update input_tuple for better tracebacks - expected_copy = expected[:] + expected_copy = expected.copy() for i, mode in enumerate([False, True]): self.ureg.autoconvert_offset_to_baseunit = mode if expected_copy[i] == "error": @@ -1695,14 +1695,14 @@ class TestOffsetUnitMath(QuantityTestCase): def test_exponentiation(self, input_tuple, expected): self.ureg.default_as_delta = False in1, in2 = input_tuple - if type(in1) is tuple and type(in2) is tuple: + if type(in1) is type(in2) is tuple: in1, in2 = self.Q_(*in1), self.Q_(*in2) elif type(in1) is not tuple and type(in2) is tuple: in2 = self.Q_(*in2) else: in1 = self.Q_(*in1) input_tuple = in1, in2 - expected_copy = expected[:] + expected_copy = expected.copy() for i, mode in enumerate([False, True]): self.ureg.autoconvert_offset_to_baseunit = mode if expected_copy[i] == "error": @@ -1733,7 +1733,7 @@ class TestOffsetUnitMath(QuantityTestCase): def test_inplace_exponentiation(self, input_tuple, expected): self.ureg.default_as_delta = False in1, in2 = input_tuple - if type(in1) is tuple and type(in2) is tuple: + if type(in1) is type(in2) is tuple: (q1v, q1u), (q2v, q2u) = in1, in2 in1 = self.Q_(*(np.array([q1v] * 2, dtype=float), q1u)) in2 = self.Q_(q2v, q2u) @@ -1744,7 +1744,7 @@ class TestOffsetUnitMath(QuantityTestCase): input_tuple = in1, in2 - expected_copy = expected[:] + expected_copy = expected.copy() for i, mode in enumerate([False, True]): self.ureg.autoconvert_offset_to_baseunit = mode in1_cp = copy.copy(in1) diff --git a/pint/testsuite/test_umath.py b/pint/testsuite/test_umath.py index 6f32ab5..73d0ae7 100644 --- a/pint/testsuite/test_umath.py +++ b/pint/testsuite/test_umath.py @@ -79,7 +79,7 @@ class TestUFuncs: if results is None: results = [None] * len(ok_with) for x1, res in zip(ok_with, results): - err_msg = "At {} with {}".format(func.__name__, x1) + err_msg = f"At {func.__name__} with {x1}" if output_units == "same": ou = x1.units elif isinstance(output_units, (int, float)): @@ -163,7 +163,7 @@ class TestUFuncs: if results is None: results = [None] * len(ok_with) for x1, res in zip(ok_with, results): - err_msg = "At {} with {}".format(func.__name__, x1) + err_msg = f"At {func.__name__} with {x1}" qms = func(x1) if res is None: res = func(x1.magnitude) @@ -223,7 +223,7 @@ class TestUFuncs: """ for x2 in ok_with: - err_msg = "At {} with {} and {}".format(func.__name__, x1, x2) + err_msg = f"At {func.__name__} with {x1} and {x2}" if output_units == "same": ou = x1.units elif output_units == "prod": diff --git a/pint/testsuite/test_unit.py b/pint/testsuite/test_unit.py index 98a4fcc..c1a2704 100644 --- a/pint/testsuite/test_unit.py +++ b/pint/testsuite/test_unit.py @@ -2,6 +2,7 @@ import copy import functools import logging import math +import operator import re from contextlib import nullcontext as does_not_raise @@ -144,7 +145,7 @@ class TestUnit(QuantityTestCase): ureg = UnitRegistry() - assert "{:new}".format(ureg.m) == "new format" + assert f"{ureg.m:new}" == "new format" def test_ipython(self): alltext = [] @@ -193,7 +194,7 @@ class TestUnit(QuantityTestCase): ("unit", "power_ratio", "expectation", "expected_unit"), [ ("m", 2, does_not_raise(), "m**2"), - ("m", dict(), pytest.raises(TypeError), None), + ("m", {}, pytest.raises(TypeError), None), ], ) def test_unit_pow(self, unit, power_ratio, expectation, expected_unit): @@ -283,7 +284,7 @@ class TestRegistry(QuantityTestCase): with pytest.raises(errors.RedefinitionError): ureg.define("meter = [length]") with pytest.raises(TypeError): - ureg.define(list()) + ureg.define([]) ureg.define("degC = kelvin; offset: 273.15") def test_define(self): @@ -394,7 +395,7 @@ class TestRegistry(QuantityTestCase): ) def test_parse_pretty_degrees(self): - for exp in ["1Δ°C", "1 Δ°C", "ΔdegC", "delta_°C"]: + for exp in ("1Δ°C", "1 Δ°C", "ΔdegC", "delta_°C"): assert self.ureg.parse_expression(exp) == self.Q_( 1, UnitsContainer(delta_degree_Celsius=1) ) @@ -566,8 +567,7 @@ class TestRegistry(QuantityTestCase): assert f3(3.0 * ureg.centimeter) == 0.03 * ureg.centimeter assert f3(3.0 * ureg.meter) == 3.0 * ureg.centimeter - def gfunc(x, y): - return x + y + gfunc = operator.add g0 = ureg.wraps(None, [None, None])(gfunc) assert g0(3, 1) == 4 @@ -596,8 +596,7 @@ class TestRegistry(QuantityTestCase): def test_wrap_referencing(self): ureg = self.ureg - def gfunc(x, y): - return x + y + gfunc = operator.add def gfunc2(x, y): return x**2 + y @@ -650,8 +649,7 @@ class TestRegistry(QuantityTestCase): with pytest.raises(DimensionalityError): f0b(3.0 * ureg.kilogram) - def gfunc(x, y): - return x / y + gfunc = operator.truediv g0 = ureg.check(None, None)(gfunc) assert g0(6, 2) == 3 diff --git a/pint/testsuite/test_util.py b/pint/testsuite/test_util.py index fd6494a..a61194d 100644 --- a/pint/testsuite/test_util.py +++ b/pint/testsuite/test_util.py @@ -120,13 +120,13 @@ class TestUnitsContainer: UnitsContainer({"1": "2"}) d = UnitsContainer() with pytest.raises(TypeError): - d.__mul__(list()) + d.__mul__([]) with pytest.raises(TypeError): - d.__pow__(list()) + d.__pow__([]) with pytest.raises(TypeError): - d.__truediv__(list()) + d.__truediv__([]) with pytest.raises(TypeError): - d.__rtruediv__(list()) + d.__rtruediv__([]) class TestToUnitsContainer: @@ -193,9 +193,9 @@ class TestParseHelper: assert "seconds" / z() == ParserHelper(0.5, seconds=1, meter=-2) assert dict(seconds=1) / z() == ParserHelper(0.5, seconds=1, meter=-2) - def _test_eval_token(self, expected, expression, use_decimal=False): + def _test_eval_token(self, expected, expression): token = next(tokenizer(expression)) - actual = ParserHelper.eval_token(token, use_decimal=use_decimal) + actual = ParserHelper.eval_token(token) assert expected == actual assert type(expected) == type(actual) @@ -353,12 +353,12 @@ class TestOtherUtils: # Test with list, string, generator, and scalar assert iterable([0, 1, 2, 3]) assert iterable("test") - assert iterable((i for i in range(5))) + assert iterable(i for i in range(5)) assert not iterable(0) def test_sized(self): # Test with list, string, generator, and scalar assert sized([0, 1, 2, 3]) assert sized("test") - assert not sized((i for i in range(5))) + assert not sized(i for i in range(5)) assert not sized(0) diff --git a/pint/util.py b/pint/util.py index d5f3aab..149945b 100644 --- a/pint/util.py +++ b/pint/util.py @@ -10,54 +10,85 @@ from __future__ import annotations -import functools -import inspect import logging import math import operator import re -from collections.abc import Mapping +from collections.abc import Mapping, Iterable, Iterator from fractions import Fraction from functools import lru_cache, partial from logging import NullHandler from numbers import Number from token import NAME, NUMBER -from typing import TYPE_CHECKING, ClassVar, Optional, Type, Union +import tokenize +from typing import ( + TYPE_CHECKING, + ClassVar, + TypeAlias, + Callable, + TypeVar, + Any, +) +from collections.abc import Hashable, Generator from .compat import NUMERIC_TYPES, tokenizer from .errors import DefinitionSyntaxError from .formatting import format_unit from .pint_eval import build_eval_tree +from ._typing import PintScalar + if TYPE_CHECKING: - from ._typing import Quantity, UnitLike + from ._typing import Quantity, UnitLike, Self from .registry import UnitRegistry + logger = logging.getLogger(__name__) logger.addHandler(NullHandler()) +T = TypeVar("T") +TH = TypeVar("TH", bound=Hashable) +ItMatrix: TypeAlias = Iterable[Iterable[PintScalar]] +Matrix: TypeAlias = list[list[PintScalar]] + + +def _noop(x: T) -> T: + return x + def matrix_to_string( - matrix, row_headers=None, col_headers=None, fmtfun=lambda x: str(int(x)) -): - """Takes a 2D matrix (as nested list) and returns a string. + matrix: ItMatrix, + row_headers: Iterable[str] | None = None, + col_headers: Iterable[str] | None = None, + fmtfun: Callable[ + [ + PintScalar, + ], + str, + ] = "{:0.0f}".format, +) -> str: + """Return a string representation of a matrix. Parameters ---------- - matrix : - - row_headers : - (Default value = None) - col_headers : - (Default value = None) - fmtfun : - (Default value = lambda x: str(int(x))) + matrix + A matrix given as an iterable of an iterable of numbers. + row_headers + An iterable of strings to serve as row headers. + (default = None, meaning no row headers are printed.) + col_headers + An iterable of strings to serve as column headers. + (default = None, meaning no col headers are printed.) + fmtfun + A callable to convert a number into string. + (default = `"{:0.0f}".format`) Returns ------- - + str + String representation of the matrix. """ - ret = [] + ret: list[str] = [] if col_headers: ret.append(("\t" if row_headers else "") + "\t".join(col_headers)) if row_headers: @@ -71,99 +102,124 @@ def matrix_to_string( return "\n".join(ret) -def transpose(matrix): - """Takes a 2D matrix (as nested list) and returns the transposed version. +def transpose(matrix: ItMatrix) -> Matrix: + """Return the transposed version of a matrix. Parameters ---------- - matrix : - + matrix + A matrix given as an iterable of an iterable of numbers. Returns ------- - + Matrix + The transposed version of the matrix. """ return [list(val) for val in zip(*matrix)] -def column_echelon_form(matrix, ntype=Fraction, transpose_result=False): - """Calculates the column echelon form using Gaussian elimination. +def matrix_apply( + matrix: ItMatrix, + func: Callable[ + [ + PintScalar, + ], + PintScalar, + ], +) -> Matrix: + """Apply a function to individual elements within a matrix. Parameters ---------- - matrix : - a 2D matrix as nested list. - ntype : - the numerical type to use in the calculation. (Default value = Fraction) - transpose_result : - indicates if the returned matrix should be transposed. (Default value = False) + matrix + A matrix given as an iterable of an iterable of numbers. + func + A callable that converts a number to another. Returns ------- - type - column echelon form, transformed identity matrix, swapped rows - + A new matrix in which each element has been replaced by new one. """ - lead = 0 + return [[func(x) for x in row] for row in matrix] - M = transpose(matrix) - _transpose = transpose if transpose_result else lambda x: x +def column_echelon_form( + matrix: ItMatrix, ntype: type = Fraction, transpose_result: bool = False +) -> tuple[Matrix, Matrix, list[int]]: + """Calculate the column echelon form using Gaussian elimination. - rows, cols = len(M), len(M[0]) + Parameters + ---------- + matrix + A 2D matrix as nested list. + ntype + The numerical type to use in the calculation. + (default = Fraction) + transpose_result + Indicates if the returned matrix should be transposed. + (default = False) - new_M = [] - for row in M: - r = [] - for x in row: - if isinstance(x, float): - x = ntype.from_float(x) - else: - x = ntype(x) - r.append(x) - new_M.append(r) - M = new_M + Returns + ------- + ech_matrix + Column echelon form. + id_matrix + Transformed identity matrix. + swapped + Swapped rows. + """ + + _transpose = transpose if transpose_result else _noop + + ech_matrix = matrix_apply( + transpose(matrix), + lambda x: ntype.from_float(x) if isinstance(x, float) else ntype(x), # type: ignore + ) + rows, cols = len(ech_matrix), len(ech_matrix[0]) # M = [[ntype(x) for x in row] for row in M] - I = [ # noqa: E741 + id_matrix: list[list[PintScalar]] = [ # noqa: E741 [ntype(1) if n == nc else ntype(0) for nc in range(rows)] for n in range(rows) ] - swapped = [] + swapped: list[int] = [] + lead = 0 for r in range(rows): if lead >= cols: - return _transpose(M), _transpose(I), swapped - i = r - while M[i][lead] == 0: - i += 1 - if i != rows: + return _transpose(ech_matrix), _transpose(id_matrix), swapped + s = r + while ech_matrix[s][lead] == 0: # type: ignore + s += 1 + if s != rows: continue - i = r + s = r lead += 1 if cols == lead: - return _transpose(M), _transpose(I), swapped + return _transpose(ech_matrix), _transpose(id_matrix), swapped - M[i], M[r] = M[r], M[i] - I[i], I[r] = I[r], I[i] + ech_matrix[s], ech_matrix[r] = ech_matrix[r], ech_matrix[s] + id_matrix[s], id_matrix[r] = id_matrix[r], id_matrix[s] - swapped.append(i) - lv = M[r][lead] - M[r] = [mrx / lv for mrx in M[r]] - I[r] = [mrx / lv for mrx in I[r]] + swapped.append(s) + lv = ech_matrix[r][lead] + ech_matrix[r] = [mrx / lv for mrx in ech_matrix[r]] + id_matrix[r] = [mrx / lv for mrx in id_matrix[r]] - for i in range(rows): - if i == r: + for s in range(rows): + if s == r: continue - lv = M[i][lead] - M[i] = [iv - lv * rv for rv, iv in zip(M[r], M[i])] - I[i] = [iv - lv * rv for rv, iv in zip(I[r], I[i])] + lv = ech_matrix[s][lead] + ech_matrix[s] = [ + iv - lv * rv for rv, iv in zip(ech_matrix[r], ech_matrix[s]) + ] + id_matrix[s] = [iv - lv * rv for rv, iv in zip(id_matrix[r], id_matrix[s])] lead += 1 - return _transpose(M), _transpose(I), swapped + return _transpose(ech_matrix), _transpose(id_matrix), swapped -def pi_theorem(quantities, registry=None): +def pi_theorem(quantities: dict[str, Any], registry: UnitRegistry | None = None): """Builds dimensionless quantities using the Buckingham π theorem Parameters @@ -171,7 +227,7 @@ def pi_theorem(quantities, registry=None): quantities : dict mapping between variable name and units registry : - (Default value = None) + (default value = None) Returns ------- @@ -185,7 +241,7 @@ def pi_theorem(quantities, registry=None): dimensions = set() if registry is None: - getdim = lambda x: x + getdim = _noop non_int_type = float else: getdim = registry.get_dimensionality @@ -213,33 +269,35 @@ def pi_theorem(quantities, registry=None): dimensions = list(dimensions) # Calculate dimensionless quantities - M = [ + matrix = [ [dimensionality[dimension] for name, dimensionality in quant] for dimension in dimensions ] - M, identity, pivot = column_echelon_form(M, transpose_result=False) + ech_matrix, id_matrix, pivot = column_echelon_form(matrix, transpose_result=False) # Collect results # Make all numbers integers and minimize the number of negative exponents. # Remove zeros results = [] - for rowm, rowi in zip(M, identity): + for rowm, rowi in zip(ech_matrix, id_matrix): if any(el != 0 for el in rowm): continue max_den = max(f.denominator for f in rowi) neg = -1 if sum(f < 0 for f in rowi) > sum(f > 0 for f in rowi) else 1 results.append( - dict( - (q[0], neg * f.numerator * max_den / f.denominator) + { + q[0]: neg * f.numerator * max_den / f.denominator for q, f in zip(quant, rowi) if f.numerator != 0 - ) + } ) return results -def solve_dependencies(dependencies): +def solve_dependencies( + dependencies: dict[TH, set[TH]] +) -> Generator[set[TH], None, None]: """Solve a dependency graph. Parameters @@ -248,12 +306,16 @@ def solve_dependencies(dependencies): dependency dictionary. For each key, the value is an iterable indicating its dependencies. - Returns - ------- - type + Yields + ------ + set iterator of sets, each containing keys of independents tasks dependent only of the previous tasks in the list. + Raises + ------ + ValueError + if a cyclic dependency is found. """ while dependencies: # values not in keys (items without dep) @@ -272,12 +334,37 @@ def solve_dependencies(dependencies): yield t -def find_shortest_path(graph, start, end, path=None): +def find_shortest_path( + graph: dict[TH, set[TH]], start: TH, end: TH, path: list[TH] | None = None +): + """Find shortest path between two nodes within a graph. + + Parameters + ---------- + graph + A graph given as a mapping of nodes + to a set of all connected nodes to it. + start + Starting node. + end + End node. + path + Path to prepend to the one found. + (default = None, empty path.) + + Returns + ------- + list[TH] + The shortest path between two nodes. + """ path = (path or []) + [start] if start == end: return path + + # TODO: raise ValueError when start not in graph if start not in graph: return None + shortest = None for node in graph[start]: if node not in path: @@ -285,10 +372,33 @@ def find_shortest_path(graph, start, end, path=None): if newpath: if not shortest or len(newpath) < len(shortest): shortest = newpath + return shortest -def find_connected_nodes(graph, start, visited=None): +def find_connected_nodes( + graph: dict[TH, set[TH]], start: TH, visited: set[TH] | None = None +) -> set[TH] | None: + """Find all nodes connected to a start node within a graph. + + Parameters + ---------- + graph + A graph given as a mapping of nodes + to a set of all connected nodes to it. + start + Starting node. + visited + Mutable set to collect visited nodes. + (default = None, empty set) + + Returns + ------- + set[TH] + The shortest path between two nodes. + """ + + # TODO: raise ValueError when start not in graph if start not in graph: return None @@ -302,17 +412,17 @@ def find_connected_nodes(graph, start, visited=None): return visited -class udict(dict): +class udict(dict[str, PintScalar]): """Custom dict implementing __missing__.""" - def __missing__(self, key): + def __missing__(self, key: str): return 0 - def copy(self): + def copy(self: Self) -> Self: return udict(self) -class UnitsContainer(Mapping): +class UnitsContainer(Mapping[str, PintScalar]): """The UnitsContainer stores the product of units and their respective exponent and implements the corresponding operations. @@ -320,23 +430,24 @@ class UnitsContainer(Mapping): Parameters ---------- - - Returns - ------- - type - - + non_int_type + Numerical type used for non integer values. """ __slots__ = ("_d", "_hash", "_one", "_non_int_type") - def __init__(self, *args, **kwargs) -> None: + _d: udict + _hash: int | None + _one: PintScalar + _non_int_type: type + + def __init__(self, *args, non_int_type: type | None = None, **kwargs) -> None: if args and isinstance(args[0], UnitsContainer): default_non_int_type = args[0]._non_int_type else: default_non_int_type = float - self._non_int_type = kwargs.pop("non_int_type", default_non_int_type) + self._non_int_type = non_int_type or default_non_int_type if self._non_int_type is float: self._one = 1 @@ -347,17 +458,33 @@ class UnitsContainer(Mapping): self._d = d for key, value in d.items(): if not isinstance(key, str): - raise TypeError("key must be a str, not {}".format(type(key))) + raise TypeError(f"key must be a str, not {type(key)}") if not isinstance(value, Number): - raise TypeError("value must be a number, not {}".format(type(value))) + raise TypeError(f"value must be a number, not {type(value)}") if not isinstance(value, int) and not isinstance(value, self._non_int_type): d[key] = self._non_int_type(value) self._hash = None - def copy(self): + def copy(self: Self) -> Self: + """Create a copy of this UnitsContainer.""" return self.__copy__() - def add(self, key, value): + def add(self: Self, key: str, value: Number) -> Self: + """Create a new UnitsContainer adding value to + the value existing for a given key. + + Parameters + ---------- + key + unit to which the value will be added. + value + value to be added. + + Returns + ------- + UnitsContainer + A copy of this container. + """ newval = self._d[key] + value new = self.copy() if newval: @@ -367,17 +494,18 @@ class UnitsContainer(Mapping): new._hash = None return new - def remove(self, keys): - """Create a new UnitsContainer purged from given keys. + def remove(self: Self, keys: Iterable[str]) -> Self: + """Create a new UnitsContainer purged from given entries. Parameters ---------- - keys : - + keys + Iterable of keys (units) to remove. Returns ------- - + UnitsContainer + A copy of this container. """ new = self.copy() for k in keys: @@ -385,51 +513,52 @@ class UnitsContainer(Mapping): new._hash = None return new - def rename(self, oldkey, newkey): + def rename(self: Self, oldkey: str, newkey: str) -> Self: """Create a new UnitsContainer in which an entry has been renamed. Parameters ---------- - oldkey : - - newkey : - + oldkey + Existing key (unit). + newkey + New key (unit). Returns ------- - + UnitsContainer + A copy of this container. """ new = self.copy() new._d[newkey] = new._d.pop(oldkey) new._hash = None return new - def __iter__(self): + def __iter__(self) -> Iterator[str]: return iter(self._d) def __len__(self) -> int: return len(self._d) - def __getitem__(self, key): + def __getitem__(self, key: str) -> PintScalar: return self._d[key] - def __contains__(self, key): + def __contains__(self, key: str) -> bool: return key in self._d - def __hash__(self): + def __hash__(self) -> int: if self._hash is None: self._hash = hash(frozenset(self._d.items())) return self._hash # Only needed by pickle protocol 0 and 1 (used by pytables) - def __getstate__(self): + def __getstate__(self) -> tuple[udict, PintScalar, type]: return self._d, self._one, self._non_int_type - def __setstate__(self, state): + def __setstate__(self, state: tuple[udict, PintScalar, type]): self._d, self._one, self._non_int_type = state self._hash = None - def __eq__(self, other) -> bool: + def __eq__(self, other: Any) -> bool: if isinstance(other, UnitsContainer): # UnitsContainer.__hash__(self) is not the same as hash(self); see # ParserHelper.__hash__ and __eq__. @@ -455,9 +584,9 @@ class UnitsContainer(Mapping): def __repr__(self) -> str: tmp = "{%s}" % ", ".join( - ["'{}': {}".format(key, value) for key, value in sorted(self._d.items())] + [f"'{key}': {value}" for key, value in sorted(self._d.items())] ) - return "<UnitsContainer({})>".format(tmp) + return f"<UnitsContainer({tmp})>" def __format__(self, spec: str) -> str: return format_unit(self, spec) @@ -474,7 +603,7 @@ class UnitsContainer(Mapping): out._one = self._one return out - def __mul__(self, other): + def __mul__(self, other: Any): if not isinstance(other, self.__class__): err = "Cannot multiply UnitsContainer by {}" raise TypeError(err.format(type(other))) @@ -490,7 +619,7 @@ class UnitsContainer(Mapping): __rmul__ = __mul__ - def __pow__(self, other): + def __pow__(self, other: Any): if not isinstance(other, NUMERIC_TYPES): err = "Cannot power UnitsContainer by {}" raise TypeError(err.format(type(other))) @@ -501,7 +630,7 @@ class UnitsContainer(Mapping): new._hash = None return new - def __truediv__(self, other): + def __truediv__(self, other: Any): if not isinstance(other, self.__class__): err = "Cannot divide UnitsContainer by {}" raise TypeError(err.format(type(other))) @@ -515,7 +644,7 @@ class UnitsContainer(Mapping): new._hash = None return new - def __rtruediv__(self, other): + def __rtruediv__(self, other: Any): if not isinstance(other, self.__class__) and other != 1: err = "Cannot divide {} by UnitsContainer" raise TypeError(err.format(type(other))) @@ -526,41 +655,48 @@ class UnitsContainer(Mapping): class ParserHelper(UnitsContainer): """The ParserHelper stores in place the product of variables and their respective exponent and implements the corresponding operations. + It also provides a scaling factor. + + For example: + `3 * m ** 2` becomes ParserHelper(3, m=2) + + Briefly is a UnitsContainer with a scaling factor. ParserHelper is a read-only mapping. All operations (even in place ones) + WARNING : The hash value used does not take into account the scale + attribute so be careful if you use it as a dict key and then two unequal + object can have the same hash. + Parameters ---------- - - Returns - ------- - type - WARNING : The hash value used does not take into account the scale - attribute so be careful if you use it as a dict key and then two unequal - object can have the same hash. - + scale + Scaling factor. + (default = 1) + **kwargs + Used to populate the dict of units and exponents. """ __slots__ = ("scale",) - def __init__(self, scale=1, *args, **kwargs): + scale: PintScalar + + def __init__(self, scale: PintScalar = 1, *args, **kwargs): super().__init__(*args, **kwargs) self.scale = scale @classmethod - def from_word(cls, input_word, non_int_type=float): + def from_word(cls, input_word: str, non_int_type: type = float) -> ParserHelper: """Creates a ParserHelper object with a single variable with exponent one. - Equivalent to: ParserHelper({'word': 1}) + Equivalent to: ParserHelper(1, {input_word: 1}) Parameters ---------- - input_word : - - - Returns - ------- + input_word + non_int_type + Numerical type used for non integer values. """ if non_int_type is float: return cls(1, [(input_word, 1)], non_int_type=non_int_type) @@ -569,15 +705,7 @@ class ParserHelper(UnitsContainer): return cls(ONE, [(input_word, ONE)], non_int_type=non_int_type) @classmethod - def eval_token(cls, token, use_decimal=False, non_int_type=float): - # TODO: remove this code when use_decimal is deprecated - if use_decimal: - raise DeprecationWarning( - "`use_decimal` is deprecated, use `non_int_type` keyword argument when instantiating the registry.\n" - ">>> from decimal import Decimal\n" - ">>> ureg = UnitRegistry(non_int_type=Decimal)" - ) - + def eval_token(cls, token: tokenize.TokenInfo, non_int_type: type = float): token_type = token.type token_text = token.string if token_type == NUMBER: @@ -594,18 +722,16 @@ class ParserHelper(UnitsContainer): raise Exception("unknown token type") @classmethod - @lru_cache() - def from_string(cls, input_string, non_int_type=float): + @lru_cache + def from_string(cls, input_string: str, non_int_type: type = float) -> ParserHelper: """Parse linear expression mathematical units and return a quantity object. Parameters ---------- - input_string : - - - Returns - ------- + input_string + non_int_type + Numerical type used for non integer values. """ if not input_string: return cls(non_int_type=non_int_type) @@ -666,17 +792,17 @@ class ParserHelper(UnitsContainer): super().__setstate__(state[:-1]) self.scale = state[-1] - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: if isinstance(other, ParserHelper): return self.scale == other.scale and super().__eq__(other) elif isinstance(other, str): return self == ParserHelper.from_string(other, self._non_int_type) elif isinstance(other, Number): return self.scale == other and not len(self._d) - else: - return self.scale == 1 and super().__eq__(other) - def operate(self, items, op=operator.iadd, cleanup=True): + return self.scale == 1 and super().__eq__(other) + + def operate(self, items, op=operator.iadd, cleanup: bool = True): d = udict(self._d) for key, value in items: d[key] = op(d[key], value) @@ -690,15 +816,15 @@ class ParserHelper(UnitsContainer): def __str__(self): tmp = "{%s}" % ", ".join( - ["'{}': {}".format(key, value) for key, value in sorted(self._d.items())] + [f"'{key}': {value}" for key, value in sorted(self._d.items())] ) - return "{} {}".format(self.scale, tmp) + return f"{self.scale} {tmp}" def __repr__(self): tmp = "{%s}" % ", ".join( - ["'{}': {}".format(key, value) for key, value in sorted(self._d.items())] + [f"'{key}': {value}" for key, value in sorted(self._d.items())] ) - return "<ParserHelper({}, {})>".format(self.scale, tmp) + return f"<ParserHelper({self.scale}, {tmp})>" def __mul__(self, other): if isinstance(other, str): @@ -821,21 +947,22 @@ class SharedRegistryObject: inst._REGISTRY = application_registry.get() return inst - def _check(self, other) -> bool: + def _check(self, other: Any) -> bool: """Check if the other object use a registry and if so that it is the same registry. Parameters ---------- - other : - + other Returns ------- - type - other don't use a registry and raise ValueError if other don't use the - same unit registry. + bool + Raises + ------ + ValueError + if other don't use the same unit registry. """ if self._REGISTRY is getattr(other, "_REGISTRY", None): return True @@ -854,40 +981,39 @@ class PrettyIPython: default_format: str - def _repr_html_(self): + def _repr_html_(self) -> str: if "~" in self.default_format: - return "{:~H}".format(self) - else: - return "{:H}".format(self) + return f"{self:~H}" + return f"{self:H}" - def _repr_latex_(self): + def _repr_latex_(self) -> str: if "~" in self.default_format: - return "${:~L}$".format(self) - else: - return "${:L}$".format(self) + return f"${self:~L}$" + return f"${self:L}$" - def _repr_pretty_(self, p, cycle): + def _repr_pretty_(self, p, cycle: bool): if "~" in self.default_format: - p.text("{:~P}".format(self)) + p.text(f"{self:~P}") else: - p.text("{:P}".format(self)) + p.text(f"{self:P}") def to_units_container( - unit_like: Union[UnitLike, Quantity], registry: Optional[UnitRegistry] = None + unit_like: UnitLike | Quantity, registry: UnitRegistry | None = None ) -> UnitsContainer: """Convert a unit compatible type to a UnitsContainer. Parameters ---------- - unit_like : - - registry : - (Default value = None) + unit_like + Quantity or Unit to infer the plain units from. + registry + If provided, uses the registry's UnitsContainer and parse_unit_name. If None, + uses the registry attached to unit_like. Returns ------- - + UnitsContainer """ mro = type(unit_like).mro() if UnitsContainer in mro: @@ -907,17 +1033,16 @@ def to_units_container( def infer_base_unit( - unit_like: Union[UnitLike, Quantity], registry: Optional[UnitRegistry] = None + unit_like: UnitLike | Quantity, registry: UnitRegistry | None = None ) -> UnitsContainer: """ Given a Quantity or UnitLike, give the UnitsContainer for it's plain units. Parameters ---------- - unit_like : Union[UnitLike, Quantity] + unit_like Quantity or Unit to infer the plain units from. - - registry: Optional[UnitRegistry] + registry If provided, uses the registry's UnitsContainer and parse_unit_name. If None, uses the registry attached to unit_like. @@ -952,7 +1077,7 @@ def infer_base_unit( return registry.UnitsContainer(nonzero_dict) -def getattr_maybe_raise(self, item): +def getattr_maybe_raise(obj: Any, item: str): """Helper function invoked at start of all overridden ``__getattr__``. Raise AttributeError if the user tries to ask for a _ or __ attribute, @@ -961,39 +1086,25 @@ def getattr_maybe_raise(self, item): Parameters ---------- - item : string - Item to be found. - - - Returns - ------- + item + attribute to be found. + Raises + ------ + AttributeError """ # Double-underscore attributes are tricky to detect because they are - # automatically prefixed with the class name - which may be a subclass of self + # automatically prefixed with the class name - which may be a subclass of obj if ( item.endswith("__") or len(item.lstrip("_")) == 0 or (item.startswith("_") and not item.lstrip("_")[0].isdigit()) ): - raise AttributeError("%r object has no attribute %r" % (self, item)) - + raise AttributeError(f"{obj!r} object has no attribute {item!r}") -def iterable(y) -> bool: - """Check whether or not an object can be iterated over. - Vendored from numpy under the terms of the BSD 3-Clause License. (Copyright - (c) 2005-2019, NumPy Developers.) - - Parameters - ---------- - value : - Input object. - type : - object - y : - - """ +def iterable(y: Any) -> bool: + """Check whether or not an object can be iterated over.""" try: iter(y) except TypeError: @@ -1001,18 +1112,8 @@ def iterable(y) -> bool: return True -def sized(y) -> bool: - """Check whether or not an object has a defined length. - - Parameters - ---------- - value : - Input object. - type : - object - y : - - """ +def sized(y: Any) -> bool: + """Check whether or not an object has a defined length.""" try: len(y) except TypeError: @@ -1020,37 +1121,9 @@ def sized(y) -> bool: return True -@functools.lru_cache( - maxsize=None -) # TODO: replace with cache when Python 3.8 is dropped. -def _build_type(class_name: str, bases): - return type(class_name, bases, dict()) - - -def build_dependent_class(registry_class, class_name: str, attribute_name: str) -> Type: - """Creates a class specifically for the given registry that - subclass all the classes named by the registry bases in a - specific attribute - - 1. List the 'attribute_name' attribute for each of the bases of the registry class. - 2. Use this list as bases for the new class - 3. Add the provided registry as the class registry. - - """ - bases = ( - getattr(base, attribute_name) - for base in inspect.getmro(registry_class) - if attribute_name in base.__dict__ - ) - bases = tuple(dict.fromkeys(bases, None).keys()) - if len(bases) == 1 and bases[0].__name__ == class_name: - return bases[0] - return _build_type(class_name, bases) - - -def create_class_with_registry(registry, base_class) -> Type: +def create_class_with_registry(registry: UnitRegistry, base_class: type) -> type: """Create new class inheriting from base_class and filling _REGISTRY class attribute with an actual instanced registry. """ - return type(base_class.__name__, tuple((base_class,)), dict(_REGISTRY=registry)) + return type(base_class.__name__, (base_class,), dict(_REGISTRY=registry)) diff --git a/pyproject.toml b/pyproject.toml index 72b6560..bbcfbdf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,13 +22,12 @@ classifiers = [ "Programming Language :: Python", "Topic :: Scientific/Engineering", "Topic :: Software Development :: Libraries", - "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11" ] -requires-python = ">=3.8" -dynamic = ["version"] +requires-python = ">=3.9" +dynamic = ["version"] # Version is taken from git tags using setuptools_scm [tool.setuptools.package-data] pint = [ |