summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorHernan Grecco <hgrecco@gmail.com>2023-05-01 19:23:47 -0300
committerHernan Grecco <hgrecco@gmail.com>2023-05-01 19:24:18 -0300
commit556aeea0d363f5757c42296ad66ffe47f05c02b2 (patch)
tree476472495bd9c623e098aa6b5ee7554fd44387a5
parent95f3eaca1129b735cb3eae8702ea857928a05909 (diff)
parentb1c01862a3811f77cca726675b52a32418c4d853 (diff)
downloadpint-556aeea0d363f5757c42296ad66ffe47f05c02b2.tar.gz
Merge changes to modernize code from 0.21 to 0.22
See #1751
-rw-r--r--.github/workflows/ci.yml26
-rw-r--r--.github/workflows/docs.yml4
-rw-r--r--.github/workflows/publish.yml27
-rw-r--r--.readthedocs.yaml2
-rw-r--r--README.rst2
-rw-r--r--benchmarks/benchmarks/20_quantity.py2
-rw-r--r--benchmarks/benchmarks/30_numpy.py4
-rw-r--r--bors.toml8
-rw-r--r--docs/dev/contributing.rst2
-rw-r--r--docs/getting/index.rst2
-rw-r--r--docs/getting/overview.rst2
-rw-r--r--pint/_typing.py48
-rw-r--r--pint/babel_names.py6
-rw-r--r--pint/compat.py89
-rw-r--r--pint/context.py2
-rw-r--r--pint/converters.py20
-rw-r--r--pint/definitions.py22
-rw-r--r--pint/delegates/__init__.py2
-rw-r--r--pint/delegates/base_defparser.py25
-rw-r--r--pint/delegates/txt_defparser/__init__.py4
-rw-r--r--pint/delegates/txt_defparser/block.py19
-rw-r--r--pint/delegates/txt_defparser/common.py6
-rw-r--r--pint/delegates/txt_defparser/context.py81
-rw-r--r--pint/delegates/txt_defparser/defaults.py20
-rw-r--r--pint/delegates/txt_defparser/defparser.py45
-rw-r--r--pint/delegates/txt_defparser/group.py24
-rw-r--r--pint/delegates/txt_defparser/plain.py24
-rw-r--r--pint/delegates/txt_defparser/system.py23
-rw-r--r--pint/errors.py28
-rw-r--r--pint/facets/__init__.py22
-rw-r--r--pint/facets/context/definitions.py19
-rw-r--r--pint/facets/context/objects.py21
-rw-r--r--pint/facets/context/registry.py22
-rw-r--r--pint/facets/dask/__init__.py10
-rw-r--r--pint/facets/formatting/objects.py24
-rw-r--r--pint/facets/formatting/registry.py4
-rw-r--r--pint/facets/group/definitions.py15
-rw-r--r--pint/facets/group/objects.py56
-rw-r--r--pint/facets/group/registry.py18
-rw-r--r--pint/facets/measurement/objects.py8
-rw-r--r--pint/facets/measurement/registry.py16
-rw-r--r--pint/facets/nonmultiplicative/definitions.py10
-rw-r--r--pint/facets/nonmultiplicative/objects.py10
-rw-r--r--pint/facets/nonmultiplicative/registry.py8
-rw-r--r--pint/facets/numpy/numpy_func.py56
-rw-r--r--pint/facets/numpy/quantity.py12
-rw-r--r--pint/facets/numpy/registry.py4
-rw-r--r--pint/facets/numpy/unit.py11
-rw-r--r--pint/facets/plain/definitions.py33
-rw-r--r--pint/facets/plain/objects.py2
-rw-r--r--pint/facets/plain/quantity.py113
-rw-r--r--pint/facets/plain/registry.py150
-rw-r--r--pint/facets/plain/unit.py18
-rw-r--r--pint/facets/system/definitions.py17
-rw-r--r--pint/facets/system/objects.py53
-rw-r--r--pint/facets/system/registry.py24
-rw-r--r--pint/formatting.py54
-rw-r--r--pint/matplotlib.py8
-rwxr-xr-xpint/pint_convert.py17
-rw-r--r--pint/pint_eval.py167
-rw-r--r--pint/registry.py32
-rw-r--r--pint/registry_helpers.py11
-rw-r--r--pint/testing.py10
-rw-r--r--pint/testsuite/__init__.py12
-rw-r--r--pint/testsuite/helpers.py14
-rw-r--r--pint/testsuite/test_babel.py6
-rw-r--r--pint/testsuite/test_compat_downcast.py31
-rw-r--r--pint/testsuite/test_compat_upcast.py9
-rw-r--r--pint/testsuite/test_contexts.py4
-rw-r--r--pint/testsuite/test_converters.py2
-rw-r--r--pint/testsuite/test_dask.py7
-rw-r--r--pint/testsuite/test_definitions.py8
-rw-r--r--pint/testsuite/test_errors.py4
-rw-r--r--pint/testsuite/test_formatter.py4
-rw-r--r--pint/testsuite/test_infer_base_unit.py10
-rw-r--r--pint/testsuite/test_issues.py4
-rw-r--r--pint/testsuite/test_log_units.py2
-rw-r--r--pint/testsuite/test_measurement.py2
-rw-r--r--pint/testsuite/test_non_int.py14
-rw-r--r--pint/testsuite/test_numpy.py4
-rw-r--r--pint/testsuite/test_quantity.py20
-rw-r--r--pint/testsuite/test_umath.py6
-rw-r--r--pint/testsuite/test_unit.py18
-rw-r--r--pint/testsuite/test_util.py16
-rw-r--r--pint/util.py589
-rw-r--r--pyproject.toml5
86 files changed, 1377 insertions, 1038 deletions
diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 369b9b9..7dd55db 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -7,15 +7,15 @@ jobs:
strategy:
fail-fast: false
matrix:
- python-version: [3.8, 3.9, "3.10", "3.11"]
- numpy: [null, "numpy>=1.19,<2.0.0"]
+ python-version: [3.9, "3.10", "3.11"]
+ numpy: [null, "numpy>=1.21,<2.0.0"]
uncertainties: [null, "uncertainties==3.1.6", "uncertainties>=3.1.6,<4.0.0"]
extras: [null]
include:
- - python-version: 3.8 # Minimal versions
+ - python-version: 3.9 # Minimal versions
numpy: "numpy"
extras: matplotlib==2.2.5
- - python-version: 3.8
+ - python-version: 3.9
numpy: "numpy"
uncertainties: "uncertainties"
extras: "sparse xarray netCDF4 dask[complete]==2023.4.0 graphviz babel==2.8"
@@ -92,8 +92,8 @@ jobs:
strategy:
fail-fast: false
matrix:
- python-version: [3.8, 3.9, "3.10", "3.11"]
- numpy: [ "numpy>=1.19,<2.0.0" ]
+ python-version: [3.9, "3.10", "3.11"]
+ numpy: [ "numpy>=1.21,<2.0.0" ]
runs-on: windows-latest
env:
@@ -153,8 +153,8 @@ jobs:
strategy:
fail-fast: false
matrix:
- python-version: [3.8, 3.9, "3.10", "3.11"]
- numpy: [null, "numpy>=1.19,<2.0.0" ]
+ python-version: [3.9, "3.10", "3.11"]
+ numpy: [null, "numpy>=1.21,<2.0.0" ]
runs-on: macos-latest
env:
@@ -226,13 +226,3 @@ jobs:
# run: |
# pip install coveralls "requests<2.29"
# coveralls --finish
-
- # Dummy task to summarize all. See https://github.com/bors-ng/bors-ng/issues/1300
- # ci-success:
- # name: ci
- # if: ${{ success() }}
- # needs: test-linux
- # runs-on: ubuntu-latest
- # steps:
- # - name: CI succeeded
- # run: exit 0
diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml
index 2340683..0a26da8 100644
--- a/.github/workflows/docs.yml
+++ b/.github/workflows/docs.yml
@@ -14,10 +14,10 @@ jobs:
- name: Get tags
run: git fetch --depth=1 origin +refs/tags/*:refs/tags/*
- - name: Set up Python 3.8
+ - name: Set up minimal Python version
uses: actions/setup-python@v2
with:
- python-version: 3.8
+ python-version: 3.9
- name: Get pip cache dir
id: pip-cache
diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml
new file mode 100644
index 0000000..3cf9f79
--- /dev/null
+++ b/.github/workflows/publish.yml
@@ -0,0 +1,27 @@
+name: Build and publish to PyPI
+
+on:
+ push:
+ tags:
+ - '*'
+
+jobs:
+ publish:
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v3
+
+ - uses: actions/setup-python@v4
+ with:
+ python-version: '3.x'
+
+ - name: Install dependencies
+ run: python -m pip install build
+
+ - name: Build package
+ run: python -m build
+
+ - name: Publish to PyPI
+ uses: pypa/gh-action-pypi-publish@release/v1
+ with:
+ password: ${{ secrets.PYPI_API_TOKEN }}
diff --git a/.readthedocs.yaml b/.readthedocs.yaml
index 2bda3d4..830a8c2 100644
--- a/.readthedocs.yaml
+++ b/.readthedocs.yaml
@@ -5,7 +5,7 @@ sphinx:
configuration: docs/conf.py
fail_on_warning: false
python:
- version: 3.8
+ version: 3.9
install:
- requirements: requirements_docs.txt
- method: pip
diff --git a/README.rst b/README.rst
index 32879d9..89f19f4 100644
--- a/README.rst
+++ b/README.rst
@@ -43,7 +43,7 @@ and constants. Due to its modular design, you can extend (or even rewrite!)
the complete list without changing the source code. It supports a lot of
numpy mathematical operations **without monkey patching or wrapping numpy**.
-It has a complete test coverage. It runs in Python 3.8+ with no other dependency.
+It has a complete test coverage. It runs in Python 3.9+ with no other dependency.
It is licensed under BSD.
It is extremely easy and natural to use:
diff --git a/benchmarks/benchmarks/20_quantity.py b/benchmarks/benchmarks/20_quantity.py
index c0174ef..cbd03b2 100644
--- a/benchmarks/benchmarks/20_quantity.py
+++ b/benchmarks/benchmarks/20_quantity.py
@@ -8,7 +8,7 @@ from . import util
units = ("meter", "kilometer", "second", "minute", "angstrom")
all_values = ("int", "float", "complex")
all_values_q = tuple(
- "%s_%s" % (a, b) for a, b in it.product(all_values, ("meter", "kilometer"))
+ f"{a}_{b}" for a, b in it.product(all_values, ("meter", "kilometer"))
)
op1 = (operator.neg, operator.truth)
diff --git a/benchmarks/benchmarks/30_numpy.py b/benchmarks/benchmarks/30_numpy.py
index 15ae66c..139ce58 100644
--- a/benchmarks/benchmarks/30_numpy.py
+++ b/benchmarks/benchmarks/30_numpy.py
@@ -9,11 +9,11 @@ from . import util
lengths = ("short", "mid")
all_values = tuple(
- "%s_%s" % (a, b) for a, b in it.product(lengths, ("list", "tuple", "array"))
+ f"{a}_{b}" for a, b in it.product(lengths, ("list", "tuple", "array"))
)
all_arrays = ("short_array", "mid_array")
units = ("meter", "kilometer")
-all_arrays_q = tuple("%s_%s" % (a, b) for a, b in it.product(all_arrays, units))
+all_arrays_q = tuple(f"{a}_{b}" for a, b in it.product(all_arrays, units))
ureg = None
data = {}
diff --git a/bors.toml b/bors.toml
deleted file mode 100644
index 4e9e7be..0000000
--- a/bors.toml
+++ /dev/null
@@ -1,8 +0,0 @@
-status = [
- "ci",
- "docbuild",
- "lint"
-]
-delete_merged_branches = true
-timeout_sec = 10800
-block_labels = [ "do-not-merge-yet" ]
diff --git a/docs/dev/contributing.rst b/docs/dev/contributing.rst
index c63381b..e70a375 100644
--- a/docs/dev/contributing.rst
+++ b/docs/dev/contributing.rst
@@ -9,7 +9,6 @@ Pint uses (and thanks):
- `github actions`_ to test all commits and PRs.
- coveralls_ to monitor coverage test coverage
- readthedocs_ to host the documentation.
-- `bors-ng`_ as a merge bot and therefore every PR is tested before merging.
- black_, isort_ and flake8_ as code linters and pre-commit_ to enforce them.
- pytest_ to write tests
- sphinx_ to write docs.
@@ -133,7 +132,6 @@ features that work best as an extension package versus direct inclusion in Pint
.. _github: http://github.com/hgrecco/pint
.. _`issue tracker`: https://github.com/hgrecco/pint/issues
-.. _`bors-ng`: https://github.com/bors-ng/bors-ng
.. _`github docs`: https://help.github.com/articles/closing-issues-via-commit-messages/
.. _`github actions`: https://docs.github.com/en/actions
.. _coveralls: https://coveralls.io/
diff --git a/docs/getting/index.rst b/docs/getting/index.rst
index 9907aeb..41ffaf9 100644
--- a/docs/getting/index.rst
+++ b/docs/getting/index.rst
@@ -8,7 +8,7 @@ The getting started guide aims to get you using pint productively as quickly as
Installation
------------
-Pint has no dependencies except Python itself. In runs on Python 3.8+.
+Pint has no dependencies except Python itself. In runs on Python 3.9+.
.. grid:: 2
diff --git a/docs/getting/overview.rst b/docs/getting/overview.rst
index cd639aa..61dfc14 100644
--- a/docs/getting/overview.rst
+++ b/docs/getting/overview.rst
@@ -14,7 +14,7 @@ Due to its modular design, you can extend (or even rewrite!) the complete list
without changing the source code. It supports a lot of numpy mathematical
operations **without monkey patching or wrapping numpy**.
-It has a complete test coverage. It runs in Python 3.8+ with no other
+It has a complete test coverage. It runs in Python 3.9+ with no other
dependencies. It is licensed under a `BSD 3-clause style license`_.
It is extremely easy and natural to use:
diff --git a/pint/_typing.py b/pint/_typing.py
index 64c3a2b..5547f85 100644
--- a/pint/_typing.py
+++ b/pint/_typing.py
@@ -1,17 +1,61 @@
from __future__ import annotations
-from typing import TYPE_CHECKING, Any, Callable, Tuple, TypeVar, Union
+from typing import TYPE_CHECKING, Any, Callable, TypeVar, Union, Protocol
+
+# TODO: Remove when 3.11 becomes minimal version.
+Self = TypeVar("Self")
if TYPE_CHECKING:
from .facets.plain import PlainQuantity as Quantity
from .facets.plain import PlainUnit as Unit
from .util import UnitsContainer
+
+class PintScalar(Protocol):
+ def __add__(self, other: Any) -> Any:
+ ...
+
+ def __sub__(self, other: Any) -> Any:
+ ...
+
+ def __mul__(self, other: Any) -> Any:
+ ...
+
+ def __truediv__(self, other: Any) -> Any:
+ ...
+
+ def __floordiv__(self, other: Any) -> Any:
+ ...
+
+ def __mod__(self, other: Any) -> Any:
+ ...
+
+ def __divmod__(self, other: Any) -> Any:
+ ...
+
+ def __pow__(self, other: Any, modulo: Any) -> Any:
+ ...
+
+
+class PintArray(Protocol):
+ def __len__(self) -> int:
+ ...
+
+ def __getitem__(self, key: Any) -> Any:
+ ...
+
+ def __setitem__(self, key: Any, value: Any) -> None:
+ ...
+
+
+Magnitude = PintScalar | PintScalar
+
+
UnitLike = Union[str, "UnitsContainer", "Unit"]
QuantityOrUnitLike = Union["Quantity", UnitLike]
-Shape = Tuple[int, ...]
+Shape = tuple[int, ...]
_MagnitudeType = TypeVar("_MagnitudeType")
S = TypeVar("S")
diff --git a/pint/babel_names.py b/pint/babel_names.py
index 09fa046..408ef8f 100644
--- a/pint/babel_names.py
+++ b/pint/babel_names.py
@@ -10,7 +10,7 @@ from __future__ import annotations
from .compat import HAS_BABEL
-_babel_units = dict(
+_babel_units: dict[str, str] = dict(
standard_gravity="acceleration-g-force",
millibar="pressure-millibar",
metric_ton="mass-metric-ton",
@@ -141,6 +141,6 @@ _babel_units = dict(
if not HAS_BABEL:
_babel_units = {}
-_babel_systems = dict(mks="metric", imperial="uksystem", US="ussystem")
+_babel_systems: dict[str, str] = dict(mks="metric", imperial="uksystem", US="ussystem")
-_babel_lengths = ["narrow", "short", "long"]
+_babel_lengths: list[str] = ["narrow", "short", "long"]
diff --git a/pint/compat.py b/pint/compat.py
index c585455..7b48efa 100644
--- a/pint/compat.py
+++ b/pint/compat.py
@@ -16,13 +16,21 @@ from decimal import Decimal
from importlib import import_module
from io import BytesIO
from numbers import Number
-from typing import Mapping, Optional
+from collections.abc import Mapping
+from typing import Any, NoReturn, Callable
+from collections.abc import Generator, Iterable
-def missing_dependency(package, display_name=None):
+def missing_dependency(
+ package: str, display_name: str | None = None
+) -> Callable[..., NoReturn]:
+ """Return a helper function that raises an exception when used.
+
+ It provides a way delay a missing dependency exception until it is used.
+ """
display_name = display_name or package
- def _inner(*args, **kwargs):
+ def _inner(*args: Any, **kwargs: Any) -> NoReturn:
raise Exception(
"This feature requires %s. Please install it by running:\n"
"pip install %s" % (display_name, package)
@@ -31,7 +39,14 @@ def missing_dependency(package, display_name=None):
return _inner
-def tokenizer(input_string):
+def tokenizer(input_string: str) -> Generator[tokenize.TokenInfo, None, None]:
+ """Tokenize an input string, encoded as UTF-8
+ and skipping the ENCODING token.
+
+ See Also
+ --------
+ tokenize.tokenize
+ """
for tokinfo in tokenize.tokenize(BytesIO(input_string.encode("utf-8")).readline):
if tokinfo.type != tokenize.ENCODING:
yield tokinfo
@@ -53,7 +68,7 @@ try:
def _to_magnitude(value, force_ndarray=False, force_ndarray_like=False):
if isinstance(value, (dict, bool)) or value is None:
- raise TypeError("Invalid magnitude for Quantity: {0!r}".format(value))
+ raise TypeError(f"Invalid magnitude for Quantity: {value!r}")
elif isinstance(value, str) and value == "":
raise ValueError("Quantity magnitude cannot be an empty string.")
elif isinstance(value, (list, tuple)):
@@ -102,7 +117,7 @@ except ImportError:
"Cannot force to ndarray or ndarray-like when NumPy is not present."
)
elif isinstance(value, (dict, bool)) or value is None:
- raise TypeError("Invalid magnitude for Quantity: {0!r}".format(value))
+ raise TypeError(f"Invalid magnitude for Quantity: {value!r}")
elif isinstance(value, str) and value == "":
raise ValueError("Quantity magnitude cannot be an empty string.")
elif isinstance(value, (list, tuple)):
@@ -154,7 +169,8 @@ else:
from math import log # noqa: F401
if not HAS_BABEL:
- babel_parse = babel_units = missing_dependency("Babel") # noqa: F811
+ babel_parse = missing_dependency("Babel") # noqa: F811
+ babel_units = babel_parse
if not HAS_MIP:
mip_missing = missing_dependency("mip")
@@ -176,6 +192,9 @@ except ImportError:
dask_array = None
+# TODO: merge with upcast_type_map
+
+#: List upcast type names
upcast_type_names = (
"pint_pandas.PintArray",
"pandas.Series",
@@ -186,10 +205,12 @@ upcast_type_names = (
"xarray.core.dataarray.DataArray",
)
-upcast_type_map: Mapping[str : Optional[type]] = {k: None for k in upcast_type_names}
+#: Map type name to the actual type (for upcast types).
+upcast_type_map: Mapping[str, type | None] = {k: None for k in upcast_type_names}
def fully_qualified_name(t: type) -> str:
+ """Return the fully qualified name of a type."""
module = t.__module__
name = t.__qualname__
@@ -200,6 +221,10 @@ def fully_qualified_name(t: type) -> str:
def check_upcast_type(obj: type) -> bool:
+ """Check if the type object is an upcast type."""
+
+ # TODO: merge or unify name with is_upcast_type
+
fqn = fully_qualified_name(obj)
if fqn not in upcast_type_map:
return False
@@ -215,22 +240,17 @@ def check_upcast_type(obj: type) -> bool:
def is_upcast_type(other: type) -> bool:
+ """Check if the type object is an upcast type."""
+
+ # TODO: merge or unify name with check_upcast_type
+
if other in upcast_type_map.values():
return True
return check_upcast_type(other)
-def is_duck_array_type(cls) -> bool:
- """Check if the type object represents a (non-Quantity) duck array type.
-
- Parameters
- ----------
- cls : class
-
- Returns
- -------
- bool
- """
+def is_duck_array_type(cls: type) -> bool:
+ """Check if the type object represents a (non-Quantity) duck array type."""
# TODO (NEP 30): replace duck array check with hasattr(other, "__duckarray__")
return issubclass(cls, ndarray) or (
not hasattr(cls, "_magnitude")
@@ -242,20 +262,21 @@ def is_duck_array_type(cls) -> bool:
)
-def is_duck_array(obj):
+def is_duck_array(obj: type) -> bool:
+ """Check if an object represents a (non-Quantity) duck array type."""
return is_duck_array_type(type(obj))
-def eq(lhs, rhs, check_all: bool):
+def eq(lhs: Any, rhs: Any, check_all: bool) -> bool | Iterable[bool]:
"""Comparison of scalars and arrays.
Parameters
----------
- lhs : object
+ lhs
left-hand side
- rhs : object
+ rhs
right-hand side
- check_all : bool
+ check_all
if True, reduce sequence to single bool;
return True if all the elements are equal.
@@ -269,21 +290,21 @@ def eq(lhs, rhs, check_all: bool):
return out
-def isnan(obj, check_all: bool):
- """Test for NaN or NaT
+def isnan(obj: Any, check_all: bool) -> bool | Iterable[bool]:
+ """Test for NaN or NaT.
Parameters
----------
- obj : object
+ obj
scalar or vector
- check_all : bool
+ check_all
if True, reduce sequence to single bool;
return True if any of the elements are NaN.
Returns
-------
bool or array_like of bool.
- Always return False for non-numeric types.
+ Always return False for non-numeric types.
"""
if is_duck_array_type(type(obj)):
if obj.dtype.kind in "if":
@@ -302,21 +323,21 @@ def isnan(obj, check_all: bool):
return False
-def zero_or_nan(obj, check_all: bool):
- """Test if obj is zero, NaN, or NaT
+def zero_or_nan(obj: Any, check_all: bool) -> bool | Iterable[bool]:
+ """Test if obj is zero, NaN, or NaT.
Parameters
----------
- obj : object
+ obj
scalar or vector
- check_all : bool
+ check_all
if True, reduce sequence to single bool;
return True if all the elements are zero, NaN, or NaT.
Returns
-------
bool or array_like of bool.
- Always return False for non-numeric types.
+ Always return False for non-numeric types.
"""
out = eq(obj, 0, False) + isnan(obj, False)
if check_all and is_duck_array_type(type(out)):
diff --git a/pint/context.py b/pint/context.py
index 4839926..6c74f65 100644
--- a/pint/context.py
+++ b/pint/context.py
@@ -18,3 +18,5 @@ if TYPE_CHECKING:
#: Regex to match the header parts of a context.
#: Regex to match variable names in an equation.
+
+# TODO: delete this file
diff --git a/pint/converters.py b/pint/converters.py
index 12248a8..9494ad1 100644
--- a/pint/converters.py
+++ b/pint/converters.py
@@ -13,6 +13,10 @@ from __future__ import annotations
from dataclasses import dataclass
from dataclasses import fields as dc_fields
+from typing import Any
+
+from ._typing import Self, Magnitude
+
from .compat import HAS_NUMPY, exp, log # noqa: F401
@@ -24,17 +28,17 @@ class Converter:
_param_names_to_subclass = {}
@property
- def is_multiplicative(self):
+ def is_multiplicative(self) -> bool:
return True
@property
- def is_logarithmic(self):
+ def is_logarithmic(self) -> bool:
return False
- def to_reference(self, value, inplace=False):
+ def to_reference(self, value: Magnitude, inplace: bool = False) -> Magnitude:
return value
- def from_reference(self, value, inplace=False):
+ def from_reference(self, value: Magnitude, inplace: bool = False) -> Magnitude:
return value
def __init_subclass__(cls, **kwargs):
@@ -43,21 +47,21 @@ class Converter:
cls._subclasses.append(cls)
@classmethod
- def get_field_names(cls, new_cls):
- return frozenset((p.name for p in dc_fields(new_cls)))
+ def get_field_names(cls, new_cls) -> frozenset[str]:
+ return frozenset(p.name for p in dc_fields(new_cls))
@classmethod
def preprocess_kwargs(cls, **kwargs):
return None
@classmethod
- def from_arguments(cls, **kwargs):
+ def from_arguments(cls: type[Self], **kwargs: Any) -> Self:
kwk = frozenset(kwargs.keys())
try:
new_cls = cls._param_names_to_subclass[kwk]
except KeyError:
for new_cls in cls._subclasses:
- p_names = frozenset((p.name for p in dc_fields(new_cls)))
+ p_names = frozenset(p.name for p in dc_fields(new_cls))
if p_names == kwk:
cls._param_names_to_subclass[kwk] = new_cls
break
diff --git a/pint/definitions.py b/pint/definitions.py
index 789d9e3..ce89e94 100644
--- a/pint/definitions.py
+++ b/pint/definitions.py
@@ -8,6 +8,8 @@
:license: BSD, see LICENSE for more details.
"""
+from __future__ import annotations
+
from . import errors
from ._vendor import flexparser as fp
from .delegates import ParserConfig, txt_defparser
@@ -17,12 +19,28 @@ class Definition:
"""This is kept for backwards compatibility"""
@classmethod
- def from_string(cls, s: str, non_int_type=float):
+ def from_string(cls, input_string: str, non_int_type: type = float) -> Definition:
+ """Parse a string into a definition object.
+
+ Parameters
+ ----------
+ input_string
+ Single line string.
+ non_int_type
+ Numerical type used for non integer values.
+
+ Raises
+ ------
+ DefinitionSyntaxError
+ If a syntax error was found.
+ """
cfg = ParserConfig(non_int_type)
parser = txt_defparser.DefParser(cfg, None)
- pp = parser.parse_string(s)
+ pp = parser.parse_string(input_string)
for definition in parser.iter_parsed_project(pp):
if isinstance(definition, Exception):
raise errors.DefinitionSyntaxError(str(definition))
if not isinstance(definition, (fp.BOS, fp.BOF, fp.BOS)):
return definition
+
+ # TODO: What shall we do in this return path.
diff --git a/pint/delegates/__init__.py b/pint/delegates/__init__.py
index 363ef9c..b2eb9a3 100644
--- a/pint/delegates/__init__.py
+++ b/pint/delegates/__init__.py
@@ -11,4 +11,4 @@
from . import txt_defparser
from .base_defparser import ParserConfig, build_disk_cache_class
-__all__ = [txt_defparser, ParserConfig, build_disk_cache_class]
+__all__ = ["txt_defparser", "ParserConfig", "build_disk_cache_class"]
diff --git a/pint/delegates/base_defparser.py b/pint/delegates/base_defparser.py
index d35f3e3..9e784ac 100644
--- a/pint/delegates/base_defparser.py
+++ b/pint/delegates/base_defparser.py
@@ -14,7 +14,6 @@ import functools
import itertools
import numbers
import pathlib
-import typing as ty
from dataclasses import dataclass, field
from pint import errors
@@ -27,10 +26,10 @@ from .._vendor import flexparser as fp
@dataclass(frozen=True)
class ParserConfig:
- """Configuration used by the parser."""
+ """Configuration used by the parser in Pint."""
#: Indicates the output type of non integer numbers.
- non_int_type: ty.Type[numbers.Number] = float
+ non_int_type: type[numbers.Number] = float
def to_scaled_units_container(self, s: str):
return ParserHelper.from_string(s, self.non_int_type)
@@ -67,7 +66,12 @@ class ParserConfig:
return val.scale
-@functools.lru_cache()
+@dataclass(frozen=True)
+class PintParsedStatement(fp.ParsedStatement[ParserConfig]):
+ """A parsed statement for pint, specialized in the actual config."""
+
+
+@functools.lru_cache
def build_disk_cache_class(non_int_type: type):
"""Build disk cache class, taking into account the non_int_type."""
@@ -84,14 +88,11 @@ def build_disk_cache_class(non_int_type: type):
class ParsedProjecHeader(fc.NameByHashIter, PintHeader):
@classmethod
def from_parsed_project(cls, pp: fp.ParsedProject, reader_id):
- tmp = []
- for stmt in pp.iter_statements():
- if isinstance(stmt, fp.BOS):
- tmp.append(
- stmt.content_hash.algorithm_name
- + ":"
- + stmt.content_hash.hexdigest
- )
+ tmp = (
+ f"{stmt.content_hash.algorithm_name}:{stmt.content_hash.hexdigest}"
+ for stmt in pp.iter_statements()
+ if isinstance(stmt, fp.BOS)
+ )
return cls(tuple(tmp), reader_id)
diff --git a/pint/delegates/txt_defparser/__init__.py b/pint/delegates/txt_defparser/__init__.py
index 5572ca1..49e4a0b 100644
--- a/pint/delegates/txt_defparser/__init__.py
+++ b/pint/delegates/txt_defparser/__init__.py
@@ -11,4 +11,6 @@
from .defparser import DefParser
-__all__ = [DefParser]
+__all__ = [
+ "DefParser",
+]
diff --git a/pint/delegates/txt_defparser/block.py b/pint/delegates/txt_defparser/block.py
index 20ebcba..e8d8aa4 100644
--- a/pint/delegates/txt_defparser/block.py
+++ b/pint/delegates/txt_defparser/block.py
@@ -17,11 +17,14 @@ from __future__ import annotations
from dataclasses import dataclass
+from typing import Generic, TypeVar
+
+from ..base_defparser import PintParsedStatement, ParserConfig
from ..._vendor import flexparser as fp
@dataclass(frozen=True)
-class EndDirectiveBlock(fp.ParsedStatement):
+class EndDirectiveBlock(PintParsedStatement):
"""An EndDirectiveBlock is simply an "@end" statement."""
@classmethod
@@ -31,8 +34,16 @@ class EndDirectiveBlock(fp.ParsedStatement):
return None
+OPST = TypeVar("OPST", bound="PintParsedStatement")
+IPST = TypeVar("IPST", bound="PintParsedStatement")
+
+DefT = TypeVar("DefT")
+
+
@dataclass(frozen=True)
-class DirectiveBlock(fp.Block):
+class DirectiveBlock(
+ Generic[DefT, OPST, IPST], fp.Block[OPST, IPST, EndDirectiveBlock, ParserConfig]
+):
"""Directive blocks have beginning statement starting with a @ character.
and ending with a "@end" (captured using a EndDirectiveBlock).
@@ -41,5 +52,5 @@ class DirectiveBlock(fp.Block):
closing: EndDirectiveBlock
- def derive_definition(self):
- pass
+ def derive_definition(self) -> DefT:
+ ...
diff --git a/pint/delegates/txt_defparser/common.py b/pint/delegates/txt_defparser/common.py
index 493d0ec..a1195b3 100644
--- a/pint/delegates/txt_defparser/common.py
+++ b/pint/delegates/txt_defparser/common.py
@@ -30,7 +30,7 @@ class DefinitionSyntaxError(errors.DefinitionSyntaxError, fp.ParsingError):
location: str = field(init=False, default="")
- def __str__(self):
+ def __str__(self) -> str:
msg = (
self.msg + "\n " + (self.format_position or "") + " " + (self.raw or "")
)
@@ -38,7 +38,7 @@ class DefinitionSyntaxError(errors.DefinitionSyntaxError, fp.ParsingError):
msg += "\n " + self.location
return msg
- def set_location(self, value):
+ def set_location(self, value: str) -> None:
super().__setattr__("location", value)
@@ -47,7 +47,7 @@ class ImportDefinition(fp.IncludeStatement):
value: str
@property
- def target(self):
+ def target(self) -> str:
return self.value
@classmethod
diff --git a/pint/delegates/txt_defparser/context.py b/pint/delegates/txt_defparser/context.py
index 5c54b4c..ce9fc9b 100644
--- a/pint/delegates/txt_defparser/context.py
+++ b/pint/delegates/txt_defparser/context.py
@@ -20,36 +20,35 @@ import numbers
import re
import typing as ty
from dataclasses import dataclass
-from typing import Dict, Tuple
from ..._vendor import flexparser as fp
from ...facets.context import definitions
-from ..base_defparser import ParserConfig
+from ..base_defparser import ParserConfig, PintParsedStatement
from . import block, common, plain
+# TODO check syntax
+T = ty.TypeVar("T", bound="ForwardRelation | BidirectionalRelation")
-@dataclass(frozen=True)
-class Relation(definitions.Relation):
- @classmethod
- def _from_string_and_context_sep(
- cls, s: str, config: ParserConfig, separator: str
- ) -> fp.FromString[Relation]:
- if separator not in s:
- return None
- if ":" not in s:
- return None
- rel, eq = s.split(":")
+def _from_string_and_context_sep(
+ cls: type[T], s: str, config: ParserConfig, separator: str
+) -> T | None:
+ if separator not in s:
+ return None
+ if ":" not in s:
+ return None
+
+ rel, eq = s.split(":")
- parts = rel.split(separator)
+ parts = rel.split(separator)
- src, dst = (config.to_dimension_container(s) for s in parts)
+ src, dst = (config.to_dimension_container(s) for s in parts)
- return cls(src, dst, eq.strip())
+ return cls(src, dst, eq.strip())
@dataclass(frozen=True)
-class ForwardRelation(fp.ParsedStatement, definitions.ForwardRelation, Relation):
+class ForwardRelation(PintParsedStatement, definitions.ForwardRelation):
"""A relation connecting a dimension to another via a transformation function.
<source dimension> -> <target dimension>: <transformation function>
@@ -59,13 +58,11 @@ class ForwardRelation(fp.ParsedStatement, definitions.ForwardRelation, Relation)
def from_string_and_config(
cls, s: str, config: ParserConfig
) -> fp.FromString[ForwardRelation]:
- return super()._from_string_and_context_sep(s, config, "->")
+ return _from_string_and_context_sep(cls, s, config, "->")
@dataclass(frozen=True)
-class BidirectionalRelation(
- fp.ParsedStatement, definitions.BidirectionalRelation, Relation
-):
+class BidirectionalRelation(PintParsedStatement, definitions.BidirectionalRelation):
"""A bidirectional relation connecting a dimension to another
via a simple transformation function.
@@ -77,11 +74,11 @@ class BidirectionalRelation(
def from_string_and_config(
cls, s: str, config: ParserConfig
) -> fp.FromString[BidirectionalRelation]:
- return super()._from_string_and_context_sep(s, config, "<->")
+ return _from_string_and_context_sep(cls, s, config, "<->")
@dataclass(frozen=True)
-class BeginContext(fp.ParsedStatement):
+class BeginContext(PintParsedStatement):
"""Being of a context directive.
@context[(defaults)] <canonical name> [= <alias>] [= <alias>]
@@ -92,8 +89,8 @@ class BeginContext(fp.ParsedStatement):
)
name: str
- aliases: Tuple[str, ...]
- defaults: Dict[str, numbers.Number]
+ aliases: tuple[str]
+ defaults: dict[str, numbers.Number]
@classmethod
def from_string_and_config(
@@ -131,7 +128,18 @@ class BeginContext(fp.ParsedStatement):
@dataclass(frozen=True)
-class ContextDefinition(block.DirectiveBlock):
+class ContextDefinition(
+ block.DirectiveBlock[
+ definitions.ContextDefinition,
+ BeginContext,
+ ty.Union[
+ plain.CommentDefinition,
+ BidirectionalRelation,
+ ForwardRelation,
+ plain.UnitDefinition,
+ ],
+ ]
+):
"""Definition of a Context
@context[(defaults)] <canonical name> [= <alias>] [= <alias>]
@@ -170,27 +178,34 @@ class ContextDefinition(block.DirectiveBlock):
]
]
- def derive_definition(self):
+ def derive_definition(self) -> definitions.ContextDefinition:
return definitions.ContextDefinition(
self.name, self.aliases, self.defaults, self.relations, self.redefinitions
)
@property
- def name(self):
+ def name(self) -> str:
+ assert isinstance(self.opening, BeginContext)
return self.opening.name
@property
- def aliases(self):
+ def aliases(self) -> tuple[str]:
+ assert isinstance(self.opening, BeginContext)
return self.opening.aliases
@property
- def defaults(self):
+ def defaults(self) -> dict[str, numbers.Number]:
+ assert isinstance(self.opening, BeginContext)
return self.opening.defaults
@property
- def relations(self):
- return tuple(r for r in self.body if isinstance(r, Relation))
+ def relations(self) -> tuple[BidirectionalRelation | ForwardRelation]:
+ return tuple(
+ r
+ for r in self.body
+ if isinstance(r, (ForwardRelation, BidirectionalRelation))
+ )
@property
- def redefinitions(self):
+ def redefinitions(self) -> tuple[plain.UnitDefinition]:
return tuple(r for r in self.body if isinstance(r, plain.UnitDefinition))
diff --git a/pint/delegates/txt_defparser/defaults.py b/pint/delegates/txt_defparser/defaults.py
index af6e31f..688d90f 100644
--- a/pint/delegates/txt_defparser/defaults.py
+++ b/pint/delegates/txt_defparser/defaults.py
@@ -19,10 +19,11 @@ from dataclasses import dataclass, fields
from ..._vendor import flexparser as fp
from ...facets.plain import definitions
from . import block, plain
+from ..base_defparser import PintParsedStatement
@dataclass(frozen=True)
-class BeginDefaults(fp.ParsedStatement):
+class BeginDefaults(PintParsedStatement):
"""Being of a defaults directive.
@defaults
@@ -36,7 +37,16 @@ class BeginDefaults(fp.ParsedStatement):
@dataclass(frozen=True)
-class DefaultsDefinition(block.DirectiveBlock):
+class DefaultsDefinition(
+ block.DirectiveBlock[
+ definitions.DefaultsDefinition,
+ BeginDefaults,
+ ty.Union[
+ plain.CommentDefinition,
+ plain.Equality,
+ ],
+ ]
+):
"""Directive to store values.
@defaults
@@ -55,10 +65,10 @@ class DefaultsDefinition(block.DirectiveBlock):
]
@property
- def _valid_fields(self):
+ def _valid_fields(self) -> tuple[str]:
return tuple(f.name for f in fields(definitions.DefaultsDefinition))
- def derive_definition(self):
+ def derive_definition(self) -> definitions.DefaultsDefinition:
for definition in self.filter_by(plain.Equality):
if definition.lhs not in self._valid_fields:
raise ValueError(
@@ -70,7 +80,7 @@ class DefaultsDefinition(block.DirectiveBlock):
*tuple(self.get_key(key) for key in self._valid_fields)
)
- def get_key(self, key):
+ def get_key(self, key: str) -> str:
for stmt in self.body:
if isinstance(stmt, plain.Equality) and stmt.lhs == key:
return stmt.rhs
diff --git a/pint/delegates/txt_defparser/defparser.py b/pint/delegates/txt_defparser/defparser.py
index 0b99d6d..f1b8e45 100644
--- a/pint/delegates/txt_defparser/defparser.py
+++ b/pint/delegates/txt_defparser/defparser.py
@@ -5,11 +5,28 @@ import typing as ty
from ..._vendor import flexcache as fc
from ..._vendor import flexparser as fp
-from .. import base_defparser
+from ..base_defparser import ParserConfig
from . import block, common, context, defaults, group, plain, system
-class PintRootBlock(fp.RootBlock):
+class PintRootBlock(
+ fp.RootBlock[
+ ty.Union[
+ plain.CommentDefinition,
+ common.ImportDefinition,
+ context.ContextDefinition,
+ defaults.DefaultsDefinition,
+ system.SystemDefinition,
+ group.GroupDefinition,
+ plain.AliasDefinition,
+ plain.DerivedDimensionDefinition,
+ plain.DimensionDefinition,
+ plain.PrefixDefinition,
+ plain.UnitDefinition,
+ ],
+ ParserConfig,
+ ]
+):
body: fp.Multi[
ty.Union[
plain.CommentDefinition,
@@ -27,11 +44,15 @@ class PintRootBlock(fp.RootBlock):
]
+class PintSource(fp.ParsedSource[PintRootBlock, ParserConfig]):
+ """Source code in Pint."""
+
+
class HashTuple(tuple):
pass
-class _PintParser(fp.Parser):
+class _PintParser(fp.Parser[PintRootBlock, ParserConfig]):
"""Parser for the original Pint definition file, with cache."""
_delimiters = {
@@ -46,11 +67,11 @@ class _PintParser(fp.Parser):
_diskcache: fc.DiskCache
- def __init__(self, config: base_defparser.ParserConfig, *args, **kwargs):
+ def __init__(self, config: ParserConfig, *args, **kwargs):
self._diskcache = kwargs.pop("diskcache", None)
super().__init__(config, *args, **kwargs)
- def parse_file(self, path: pathlib.Path) -> fp.ParsedSource:
+ def parse_file(self, path: pathlib.Path) -> PintSource:
if self._diskcache is None:
return super().parse_file(path)
content, basename = self._diskcache.load(path, super().parse_file)
@@ -58,7 +79,13 @@ class _PintParser(fp.Parser):
class DefParser:
- skip_classes = (fp.BOF, fp.BOR, fp.BOS, fp.EOS, plain.CommentDefinition)
+ skip_classes: tuple[type] = (
+ fp.BOF,
+ fp.BOR,
+ fp.BOS,
+ fp.EOS,
+ plain.CommentDefinition,
+ )
def __init__(self, default_config, diskcache):
self._default_config = default_config
@@ -78,6 +105,8 @@ class DefParser:
continue
if isinstance(stmt, common.DefinitionSyntaxError):
+ # TODO: check why this assert fails
+ # assert isinstance(last_location, str)
stmt.set_location(last_location)
raise stmt
elif isinstance(stmt, block.DirectiveBlock):
@@ -101,7 +130,7 @@ class DefParser:
else:
yield stmt
- def parse_file(self, filename: pathlib.Path, cfg=None):
+ def parse_file(self, filename: pathlib.Path, cfg: ParserConfig | None = None):
return fp.parse(
filename,
_PintParser,
@@ -109,7 +138,7 @@ class DefParser:
diskcache=self._diskcache,
)
- def parse_string(self, content: str, cfg=None):
+ def parse_string(self, content: str, cfg: ParserConfig | None = None):
return fp.parse_bytes(
content.encode("utf-8"),
_PintParser,
diff --git a/pint/delegates/txt_defparser/group.py b/pint/delegates/txt_defparser/group.py
index 5be42ac..e96d44b 100644
--- a/pint/delegates/txt_defparser/group.py
+++ b/pint/delegates/txt_defparser/group.py
@@ -23,10 +23,11 @@ from dataclasses import dataclass
from ..._vendor import flexparser as fp
from ...facets.group import definitions
from . import block, common, plain
+from ..base_defparser import PintParsedStatement
@dataclass(frozen=True)
-class BeginGroup(fp.ParsedStatement):
+class BeginGroup(PintParsedStatement):
"""Being of a group directive.
@group <name> [using <group 1>, ..., <group N>]
@@ -59,7 +60,16 @@ class BeginGroup(fp.ParsedStatement):
@dataclass(frozen=True)
-class GroupDefinition(block.DirectiveBlock):
+class GroupDefinition(
+ block.DirectiveBlock[
+ definitions.GroupDefinition,
+ BeginGroup,
+ ty.Union[
+ plain.CommentDefinition,
+ plain.UnitDefinition,
+ ],
+ ]
+):
"""Definition of a group.
@group <name> [using <group 1>, ..., <group N>]
@@ -88,19 +98,21 @@ class GroupDefinition(block.DirectiveBlock):
]
]
- def derive_definition(self):
+ def derive_definition(self) -> definitions.GroupDefinition:
return definitions.GroupDefinition(
self.name, self.using_group_names, self.definitions
)
@property
- def name(self):
+ def name(self) -> str:
+ assert isinstance(self.opening, BeginGroup)
return self.opening.name
@property
- def using_group_names(self):
+ def using_group_names(self) -> tuple[str]:
+ assert isinstance(self.opening, BeginGroup)
return self.opening.using_group_names
@property
- def definitions(self) -> ty.Tuple[plain.UnitDefinition, ...]:
+ def definitions(self) -> tuple[plain.UnitDefinition]:
return tuple(el for el in self.body if isinstance(el, plain.UnitDefinition))
diff --git a/pint/delegates/txt_defparser/plain.py b/pint/delegates/txt_defparser/plain.py
index 428df10..9c7bd42 100644
--- a/pint/delegates/txt_defparser/plain.py
+++ b/pint/delegates/txt_defparser/plain.py
@@ -29,12 +29,12 @@ from ..._vendor import flexparser as fp
from ...converters import Converter
from ...facets.plain import definitions
from ...util import UnitsContainer
-from ..base_defparser import ParserConfig
+from ..base_defparser import ParserConfig, PintParsedStatement
from . import common
@dataclass(frozen=True)
-class Equality(fp.ParsedStatement, definitions.Equality):
+class Equality(PintParsedStatement, definitions.Equality):
"""An equality statement contains a left and right hand separated
lhs and rhs should be space stripped.
@@ -53,7 +53,7 @@ class Equality(fp.ParsedStatement, definitions.Equality):
@dataclass(frozen=True)
-class CommentDefinition(fp.ParsedStatement, definitions.CommentDefinition):
+class CommentDefinition(PintParsedStatement, definitions.CommentDefinition):
"""Comments start with a # character.
# This is a comment.
@@ -63,14 +63,14 @@ class CommentDefinition(fp.ParsedStatement, definitions.CommentDefinition):
"""
@classmethod
- def from_string(cls, s: str) -> fp.FromString[fp.ParsedStatement]:
+ def from_string(cls, s: str) -> fp.FromString[CommentDefinition]:
if not s.startswith("#"):
return None
return cls(s[1:].strip())
@dataclass(frozen=True)
-class PrefixDefinition(fp.ParsedStatement, definitions.PrefixDefinition):
+class PrefixDefinition(PintParsedStatement, definitions.PrefixDefinition):
"""Definition of a prefix::
<prefix>- = <value> [= <symbol>] [= <alias>] [ = <alias> ] [...]
@@ -119,7 +119,7 @@ class PrefixDefinition(fp.ParsedStatement, definitions.PrefixDefinition):
@dataclass(frozen=True)
-class UnitDefinition(fp.ParsedStatement, definitions.UnitDefinition):
+class UnitDefinition(PintParsedStatement, definitions.UnitDefinition):
"""Definition of a unit::
<canonical name> = <relation to another unit or dimension> [= <symbol>] [= <alias>] [ = <alias> ] [...]
@@ -159,10 +159,10 @@ class UnitDefinition(fp.ParsedStatement, definitions.UnitDefinition):
[converter, modifiers] = value.split(";", 1)
try:
- modifiers = dict(
- (key.strip(), config.to_number(value))
+ modifiers = {
+ key.strip(): config.to_number(value)
for key, value in (part.split(":") for part in modifiers.split(";"))
- )
+ }
except definitions.NotNumeric as ex:
return common.DefinitionSyntaxError(
f"Unit definition ('{name}') must contain only numbers in modifier, not {ex.value}"
@@ -194,7 +194,7 @@ class UnitDefinition(fp.ParsedStatement, definitions.UnitDefinition):
@dataclass(frozen=True)
-class DimensionDefinition(fp.ParsedStatement, definitions.DimensionDefinition):
+class DimensionDefinition(PintParsedStatement, definitions.DimensionDefinition):
"""Definition of a root dimension::
[dimension name]
@@ -221,7 +221,7 @@ class DimensionDefinition(fp.ParsedStatement, definitions.DimensionDefinition):
@dataclass(frozen=True)
class DerivedDimensionDefinition(
- fp.ParsedStatement, definitions.DerivedDimensionDefinition
+ PintParsedStatement, definitions.DerivedDimensionDefinition
):
"""Definition of a derived dimension::
@@ -261,7 +261,7 @@ class DerivedDimensionDefinition(
@dataclass(frozen=True)
-class AliasDefinition(fp.ParsedStatement, definitions.AliasDefinition):
+class AliasDefinition(PintParsedStatement, definitions.AliasDefinition):
"""Additional alias(es) for an already existing unit::
@alias <canonical name or previous alias> = <alias> [ = <alias> ] [...]
diff --git a/pint/delegates/txt_defparser/system.py b/pint/delegates/txt_defparser/system.py
index b21fd7a..4efbb4d 100644
--- a/pint/delegates/txt_defparser/system.py
+++ b/pint/delegates/txt_defparser/system.py
@@ -14,11 +14,12 @@ from dataclasses import dataclass
from ..._vendor import flexparser as fp
from ...facets.system import definitions
+from ..base_defparser import PintParsedStatement
from . import block, common, plain
@dataclass(frozen=True)
-class BaseUnitRule(fp.ParsedStatement, definitions.BaseUnitRule):
+class BaseUnitRule(PintParsedStatement, definitions.BaseUnitRule):
@classmethod
def from_string(cls, s: str) -> fp.FromString[BaseUnitRule]:
if ":" not in s:
@@ -32,7 +33,7 @@ class BaseUnitRule(fp.ParsedStatement, definitions.BaseUnitRule):
@dataclass(frozen=True)
-class BeginSystem(fp.ParsedStatement):
+class BeginSystem(PintParsedStatement):
"""Being of a system directive.
@system <name> [using <group 1>, ..., <group N>]
@@ -67,7 +68,13 @@ class BeginSystem(fp.ParsedStatement):
@dataclass(frozen=True)
-class SystemDefinition(block.DirectiveBlock):
+class SystemDefinition(
+ block.DirectiveBlock[
+ definitions.SystemDefinition,
+ BeginSystem,
+ ty.Union[plain.CommentDefinition, BaseUnitRule],
+ ]
+):
"""Definition of a System:
@system <name> [using <group 1>, ..., <group N>]
@@ -92,19 +99,21 @@ class SystemDefinition(block.DirectiveBlock):
opening: fp.Single[BeginSystem]
body: fp.Multi[ty.Union[plain.CommentDefinition, BaseUnitRule]]
- def derive_definition(self):
+ def derive_definition(self) -> definitions.SystemDefinition:
return definitions.SystemDefinition(
self.name, self.using_group_names, self.rules
)
@property
- def name(self):
+ def name(self) -> str:
+ assert isinstance(self.opening, BeginSystem)
return self.opening.name
@property
- def using_group_names(self):
+ def using_group_names(self) -> tuple[str]:
+ assert isinstance(self.opening, BeginSystem)
return self.opening.using_group_names
@property
- def rules(self):
+ def rules(self) -> tuple[BaseUnitRule]:
return tuple(el for el in self.body if isinstance(el, BaseUnitRule))
diff --git a/pint/errors.py b/pint/errors.py
index 8f849da..6cebb21 100644
--- a/pint/errors.py
+++ b/pint/errors.py
@@ -36,18 +36,21 @@ MSG_INVALID_SYSTEM_NAME = (
)
-def is_dim(name):
+def is_dim(name: str) -> bool:
+ """Return True if the name is flanked by square brackets `[` and `]`."""
return name[0] == "[" and name[-1] == "]"
-def is_valid_prefix_name(name):
+def is_valid_prefix_name(name: str) -> bool:
+ """Return True if the name is a valid python identifier or empty."""
return str.isidentifier(name) or name == ""
is_valid_unit_name = is_valid_system_name = is_valid_context_name = str.isidentifier
-def _no_space(name):
+def _no_space(name: str) -> bool:
+ """Return False if the name contains a space in any position."""
return name.strip() == name and " " not in name
@@ -58,7 +61,14 @@ is_valid_unit_alias = (
) = is_valid_unit_symbol = is_valid_prefix_symbol = _no_space
-def is_valid_dimension_name(name):
+def is_valid_dimension_name(name: str) -> bool:
+ """Return True if the name is consistent with a dimension name.
+
+ - flanked by square brackets.
+ - empty dimension name or identifier.
+ """
+
+ # TODO: shall we check also fro spaces?
return name == "[]" or (
len(name) > 1 and is_dim(name) and str.isidentifier(name[1:-1])
)
@@ -67,8 +77,8 @@ def is_valid_dimension_name(name):
class WithDefErr:
"""Mixing class to make some classes more readable."""
- def def_err(self, msg):
- return DefinitionError(self.name, self.__class__.__name__, msg)
+ def def_err(self, msg: str):
+ return DefinitionError(self.name, self.__class__, msg)
@dataclass(frozen=False)
@@ -81,7 +91,7 @@ class DefinitionError(ValueError, PintError):
"""Raised when a definition is not properly constructed."""
name: str
- definition_type: ty.Type
+ definition_type: type
msg: str
def __str__(self):
@@ -110,7 +120,7 @@ class RedefinitionError(ValueError, PintError):
"""Raised when a unit or prefix is redefined."""
name: str
- definition_type: ty.Type
+ definition_type: type
def __str__(self):
msg = f"Cannot redefine '{self.name}' ({self.definition_type})"
@@ -124,7 +134,7 @@ class RedefinitionError(ValueError, PintError):
class UndefinedUnitError(AttributeError, PintError):
"""Raised when the units are not defined in the unit registry."""
- unit_names: ty.Union[str, ty.Tuple[str, ...]]
+ unit_names: str | tuple[str]
def __str__(self):
if isinstance(self.unit_names, str):
diff --git a/pint/facets/__init__.py b/pint/facets/__init__.py
index d669b9f..750f729 100644
--- a/pint/facets/__init__.py
+++ b/pint/facets/__init__.py
@@ -30,8 +30,8 @@
class NumpyRegistry:
- _quantity_class = NumpyQuantity
- _unit_class = NumpyUnit
+ Quantity = NumpyQuantity
+ Unit = NumpyUnit
This tells pint that it should use NumpyQuantity as base class for a quantity
class that belongs to a registry that has NumpyRegistry as one of its bases.
@@ -82,13 +82,13 @@ from .plain import PlainRegistry
from .system import SystemRegistry
__all__ = [
- ContextRegistry,
- DaskRegistry,
- FormattingRegistry,
- GroupRegistry,
- MeasurementRegistry,
- NonMultiplicativeRegistry,
- NumpyRegistry,
- PlainRegistry,
- SystemRegistry,
+ "ContextRegistry",
+ "DaskRegistry",
+ "FormattingRegistry",
+ "GroupRegistry",
+ "MeasurementRegistry",
+ "NonMultiplicativeRegistry",
+ "NumpyRegistry",
+ "PlainRegistry",
+ "SystemRegistry",
]
diff --git a/pint/facets/context/definitions.py b/pint/facets/context/definitions.py
index fbdb390..833857e 100644
--- a/pint/facets/context/definitions.py
+++ b/pint/facets/context/definitions.py
@@ -12,7 +12,8 @@ import itertools
import numbers
import re
from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Callable, Dict, Set, Tuple
+from typing import TYPE_CHECKING, Any, Callable
+from collections.abc import Iterable
from ... import errors
from ..plain import UnitDefinition
@@ -41,7 +42,7 @@ class Relation:
# could be used.
@property
- def variables(self) -> Set[str, ...]:
+ def variables(self) -> set[str]:
"""Find all variables names in the equation."""
return set(self._varname_re.findall(self.equation))
@@ -55,7 +56,7 @@ class Relation:
)
@property
- def bidirectional(self):
+ def bidirectional(self) -> bool:
raise NotImplementedError
@@ -92,18 +93,18 @@ class ContextDefinition(errors.WithDefErr):
#: name of the context
name: str
#: other na
- aliases: Tuple[str, ...]
- defaults: Dict[str, numbers.Number]
- relations: Tuple[Relation, ...]
- redefinitions: Tuple[UnitDefinition, ...]
+ aliases: tuple[str]
+ defaults: dict[str, numbers.Number]
+ relations: tuple[Relation]
+ redefinitions: tuple[UnitDefinition]
@property
- def variables(self) -> Set[str, ...]:
+ def variables(self) -> set[str]:
"""Return all variable names in all transformations."""
return set().union(*(r.variables for r in self.relations))
@classmethod
- def from_lines(cls, lines, non_int_type):
+ def from_lines(cls, lines: Iterable[str], non_int_type: type):
# TODO: this is to keep it backwards compatible
from ...delegates import ParserConfig, txt_defparser
diff --git a/pint/facets/context/objects.py b/pint/facets/context/objects.py
index 40c2bb5..38d8805 100644
--- a/pint/facets/context/objects.py
+++ b/pint/facets/context/objects.py
@@ -10,7 +10,8 @@ from __future__ import annotations
import weakref
from collections import ChainMap, defaultdict
-from typing import Optional, Tuple
+from typing import Any
+from collections.abc import Iterable
from ...facets.plain import UnitDefinition
from ...util import UnitsContainer, to_units_container
@@ -70,9 +71,9 @@ class Context:
def __init__(
self,
- name: Optional[str] = None,
- aliases: Tuple[str, ...] = (),
- defaults: Optional[dict] = None,
+ name: str | None = None,
+ aliases: tuple[str] = tuple(),
+ defaults: dict[str, Any] | None = None,
) -> None:
self.name = name
self.aliases = aliases
@@ -94,7 +95,7 @@ class Context:
self.relation_to_context = weakref.WeakValueDictionary()
@classmethod
- def from_context(cls, context: Context, **defaults) -> Context:
+ def from_context(cls, context: Context, **defaults: Any) -> Context:
"""Creates a new context that shares the funcs dictionary with the
original context. The default values are copied from the original
context and updated with the new defaults.
@@ -123,7 +124,9 @@ class Context:
return context
@classmethod
- def from_lines(cls, lines, to_base_func=None, non_int_type=float) -> Context:
+ def from_lines(
+ cls, lines: Iterable[str], to_base_func=None, non_int_type: type = float
+ ) -> Context:
cd = ContextDefinition.from_lines(lines, non_int_type)
return cls.from_definition(cd, to_base_func)
@@ -166,7 +169,7 @@ class Context:
del self.relation_to_context[_key]
@staticmethod
- def __keytransform__(src, dst) -> Tuple[UnitsContainer, UnitsContainer]:
+ def __keytransform__(src, dst) -> tuple[UnitsContainer, UnitsContainer]:
return to_units_container(src), to_units_container(dst)
def transform(self, src, dst, registry, value):
@@ -199,7 +202,7 @@ class Context:
def hashable(
self,
- ) -> Tuple[Optional[str], Tuple[str, ...], frozenset, frozenset, tuple]:
+ ) -> tuple[str | None, tuple[str, ...], frozenset, frozenset, tuple]:
"""Generate a unique hashable and comparable representation of self, which can
be used as a key in a dict. This class cannot define ``__hash__`` because it is
mutable, and the Python interpreter does cache the output of ``__hash__``.
@@ -274,7 +277,7 @@ class ContextChain(ChainMap):
"""
return self[(src, dst)].transform(src, dst, registry, value)
- def hashable(self):
+ def hashable(self) -> tuple[Any]:
"""Generate a unique hashable and comparable representation of self, which can
be used as a key in a dict. This class cannot define ``__hash__`` because it is
mutable, and the Python interpreter does cache the output of ``__hash__``.
diff --git a/pint/facets/context/registry.py b/pint/facets/context/registry.py
index ccf69d2..a36d82d 100644
--- a/pint/facets/context/registry.py
+++ b/pint/facets/context/registry.py
@@ -11,14 +11,14 @@ from __future__ import annotations
import functools
from collections import ChainMap
from contextlib import contextmanager
-from typing import Any, Callable, ContextManager, Dict, Union
+from typing import Any, Callable, ContextManager
from ..._typing import F
from ...errors import UndefinedUnitError
from ...util import find_connected_nodes, find_shortest_path, logger
from ..plain import PlainRegistry, UnitDefinition
from .definitions import ContextDefinition
-from .objects import Context, ContextChain
+from . import objects
# TODO: Put back annotation when possible
# registry_cache: "RegistryCache"
@@ -50,13 +50,13 @@ class ContextRegistry(PlainRegistry):
- Parse @context directive.
"""
- Context = Context
+ Context = objects.Context
def __init__(self, **kwargs: Any) -> None:
# Map context name (string) or abbreviation to context.
- self._contexts: Dict[str, Context] = {}
+ self._contexts: dict[str, objects.Context] = {}
# Stores active contexts.
- self._active_ctx = ContextChain()
+ self._active_ctx = objects.ContextChain()
# Map context chain to cache
self._caches = {}
# Map context chain to units override
@@ -71,7 +71,7 @@ class ContextRegistry(PlainRegistry):
super()._register_definition_adders()
self._register_adder(ContextDefinition, self.add_context)
- def add_context(self, context: Union[Context, ContextDefinition]) -> None:
+ def add_context(self, context: Context | ContextDefinition) -> None:
"""Add a context object to the registry.
The context will be accessible by its name and aliases.
@@ -80,7 +80,7 @@ class ContextRegistry(PlainRegistry):
see :meth:`enable_contexts`.
"""
if isinstance(context, ContextDefinition):
- context = Context.from_definition(context, self.get_dimensionality)
+ context = objects.Context.from_definition(context, self.get_dimensionality)
if not context.name:
raise ValueError("Can't add unnamed context to registry")
@@ -97,7 +97,7 @@ class ContextRegistry(PlainRegistry):
)
self._contexts[alias] = context
- def remove_context(self, name_or_alias: str) -> Context:
+ def remove_context(self, name_or_alias: str) -> objects.Context:
"""Remove a context from the registry and return it.
Notice that this methods will not disable the context;
@@ -194,7 +194,7 @@ class ContextRegistry(PlainRegistry):
self.define(definition)
def enable_contexts(
- self, *names_or_contexts: Union[str, Context], **kwargs
+ self, *names_or_contexts: str | objects.Context, **kwargs
) -> None:
"""Enable contexts provided by name or by object.
@@ -235,7 +235,7 @@ class ContextRegistry(PlainRegistry):
ctx.checked = True
# and create a new one with the new defaults.
- contexts = tuple(Context.from_context(ctx, **kwargs) for ctx in ctxs)
+ contexts = tuple(objects.Context.from_context(ctx, **kwargs) for ctx in ctxs)
# Finally we add them to the active context.
self._active_ctx.insert_contexts(*contexts)
@@ -253,7 +253,7 @@ class ContextRegistry(PlainRegistry):
self._switch_context_cache_and_units()
@contextmanager
- def context(self, *names, **kwargs) -> ContextManager[Context]:
+ def context(self, *names, **kwargs) -> ContextManager[objects.Context]:
"""Used as a context manager, this function enables to activate a context
which is removed after usage.
diff --git a/pint/facets/dask/__init__.py b/pint/facets/dask/__init__.py
index 42fced0..90c8972 100644
--- a/pint/facets/dask/__init__.py
+++ b/pint/facets/dask/__init__.py
@@ -14,7 +14,7 @@ from __future__ import annotations
import functools
from ...compat import compute, dask_array, persist, visualize
-from ..plain import PlainRegistry
+from ..plain import PlainRegistry, PlainQuantity
def check_dask_array(f):
@@ -31,13 +31,13 @@ def check_dask_array(f):
return wrapper
-class DaskQuantity:
+class DaskQuantity(PlainQuantity):
# Dask.array.Array ducking
def __dask_graph__(self):
if isinstance(self._magnitude, dask_array.Array):
return self._magnitude.__dask_graph__()
- else:
- return None
+
+ return None
def __dask_keys__(self):
return self._magnitude.__dask_keys__()
@@ -120,4 +120,4 @@ class DaskQuantity:
class DaskRegistry(PlainRegistry):
- _quantity_class = DaskQuantity
+ Quantity = DaskQuantity
diff --git a/pint/facets/formatting/objects.py b/pint/facets/formatting/objects.py
index 1ba92c9..5df937c 100644
--- a/pint/facets/formatting/objects.py
+++ b/pint/facets/formatting/objects.py
@@ -23,8 +23,10 @@ from ...formatting import (
)
from ...util import UnitsContainer, iterable
+from ..plain import PlainQuantity, PlainUnit
-class FormattingQuantity:
+
+class FormattingQuantity(PlainQuantity):
_exp_pattern = re.compile(r"([0-9]\.?[0-9]*)e(-?)\+?0*([0-9]+)")
def __format__(self, spec: str) -> str:
@@ -80,7 +82,7 @@ class FormattingQuantity:
else:
if isinstance(self.magnitude, ndarray):
# Use custom ndarray text formatting with monospace font
- formatter = "{{:{}}}".format(mspec)
+ formatter = f"{{:{mspec}}}"
# Need to override for scalars, which are detected as iterable,
# and don't respond to printoptions.
if self.magnitude.ndim == 0:
@@ -112,7 +114,7 @@ class FormattingQuantity:
else:
# Use custom ndarray text formatting--need to handle scalars differently
# since they don't respond to printoptions
- formatter = "{{:{}}}".format(mspec)
+ formatter = f"{{:{mspec}}}"
if obj.magnitude.ndim == 0:
mstr = formatter.format(obj.magnitude)
else:
@@ -154,7 +156,7 @@ class FormattingQuantity:
obj = self.to_compact()
else:
obj = self
- kwspec = dict(kwspec)
+ kwspec = kwspec.copy()
if "length" in kwspec:
kwspec["babel_length"] = kwspec.pop("length")
@@ -176,7 +178,7 @@ class FormattingQuantity:
return format(self)
-class FormattingUnit:
+class FormattingUnit(PlainUnit):
def __str__(self):
return format(self)
@@ -188,10 +190,10 @@ class FormattingUnit:
if not self._units:
return ""
units = UnitsContainer(
- dict(
- (self._REGISTRY._get_symbol(key), value)
+ {
+ self._REGISTRY._get_symbol(key): value
for key, value in self._units.items()
- )
+ }
)
uspec = uspec.replace("~", "")
else:
@@ -206,10 +208,10 @@ class FormattingUnit:
if self.dimensionless:
return ""
units = UnitsContainer(
- dict(
- (self._REGISTRY._get_symbol(key), value)
+ {
+ self._REGISTRY._get_symbol(key): value
for key, value in self._units.items()
- )
+ }
)
spec = spec.replace("~", "")
else:
diff --git a/pint/facets/formatting/registry.py b/pint/facets/formatting/registry.py
index bd9c74c..c4dc373 100644
--- a/pint/facets/formatting/registry.py
+++ b/pint/facets/formatting/registry.py
@@ -13,5 +13,5 @@ from .objects import FormattingQuantity, FormattingUnit
class FormattingRegistry(PlainRegistry):
- _quantity_class = FormattingQuantity
- _unit_class = FormattingUnit
+ Quantity = FormattingQuantity
+ Unit = FormattingUnit
diff --git a/pint/facets/group/definitions.py b/pint/facets/group/definitions.py
index c0abced..554a63b 100644
--- a/pint/facets/group/definitions.py
+++ b/pint/facets/group/definitions.py
@@ -8,9 +8,10 @@
from __future__ import annotations
-import typing as ty
+from collections.abc import Iterable
from dataclasses import dataclass
+from ..._typing import Self
from ... import errors
from .. import plain
@@ -22,12 +23,14 @@ class GroupDefinition(errors.WithDefErr):
#: name of the group
name: str
#: unit groups that will be included within the group
- using_group_names: ty.Tuple[str, ...]
+ using_group_names: tuple[str]
#: definitions for the units existing within the group
- definitions: ty.Tuple[plain.UnitDefinition, ...]
+ definitions: tuple[plain.UnitDefinition]
@classmethod
- def from_lines(cls, lines, non_int_type):
+ def from_lines(
+ cls: type[Self], lines: Iterable[str], non_int_type: type
+ ) -> Self | None:
# TODO: this is to keep it backwards compatible
from ...delegates import ParserConfig, txt_defparser
@@ -39,10 +42,10 @@ class GroupDefinition(errors.WithDefErr):
return definition
@property
- def unit_names(self) -> ty.Tuple[str, ...]:
+ def unit_names(self) -> tuple[str]:
return tuple(el.name for el in self.definitions)
- def __post_init__(self):
+ def __post_init__(self) -> None:
if not errors.is_valid_group_name(self.name):
raise self.def_err(errors.MSG_INVALID_GROUP_NAME)
diff --git a/pint/facets/group/objects.py b/pint/facets/group/objects.py
index 558a107..200a323 100644
--- a/pint/facets/group/objects.py
+++ b/pint/facets/group/objects.py
@@ -8,6 +8,7 @@
from __future__ import annotations
+from collections.abc import Generator, Iterable
from ...util import SharedRegistryObject, getattr_maybe_raise
from .definitions import GroupDefinition
@@ -23,32 +24,26 @@ class Group(SharedRegistryObject):
The group belongs to one Registry.
See GroupDefinition for the definition file syntax.
- """
- def __init__(self, name):
- """
- :param name: Name of the group. If not given, a root Group will be created.
- :type name: str
- :param groups: dictionary like object groups and system.
- The newly created group will be added after creation.
- :type groups: dict[str | Group]
- """
+ Parameters
+ ----------
+ name
+ If not given, a root Group will be created.
+ """
+ def __init__(self, name: str):
# The name of the group.
- #: type: str
self.name = name
#: Names of the units in this group.
#: :type: set[str]
- self._unit_names = set()
+ self._unit_names: set[str] = set()
#: Names of the groups in this group.
- #: :type: set[str]
- self._used_groups = set()
+ self._used_groups: set[str] = set()
#: Names of the groups in which this group is contained.
- #: :type: set[str]
- self._used_by = set()
+ self._used_by: set[str] = set()
# Add this group to the group dictionary
self._REGISTRY._groups[self.name] = self
@@ -59,8 +54,7 @@ class Group(SharedRegistryObject):
#: A cache of the included units.
#: None indicates that the cache has been invalidated.
- #: :type: frozenset[str] | None
- self._computed_members = None
+ self._computed_members: frozenset[str] | None = None
@property
def members(self):
@@ -70,23 +64,23 @@ class Group(SharedRegistryObject):
"""
if self._computed_members is None:
- self._computed_members = set(self._unit_names)
+ tmp = set(self._unit_names)
for _, group in self.iter_used_groups():
- self._computed_members |= group.members
+ tmp |= group.members
- self._computed_members = frozenset(self._computed_members)
+ self._computed_members = frozenset(tmp)
return self._computed_members
- def invalidate_members(self):
+ def invalidate_members(self) -> None:
"""Invalidate computed members in this Group and all parent nodes."""
self._computed_members = None
d = self._REGISTRY._groups
for name in self._used_by:
d[name].invalidate_members()
- def iter_used_groups(self):
+ def iter_used_groups(self) -> Generator[tuple[str, Group], None, None]:
pending = set(self._used_groups)
d = self._REGISTRY._groups
while pending:
@@ -95,13 +89,13 @@ class Group(SharedRegistryObject):
pending |= group._used_groups
yield name, d[name]
- def is_used_group(self, group_name):
+ def is_used_group(self, group_name: str) -> bool:
for name, _ in self.iter_used_groups():
if name == group_name:
return True
return False
- def add_units(self, *unit_names):
+ def add_units(self, *unit_names: str) -> None:
"""Add units to group."""
for unit_name in unit_names:
self._unit_names.add(unit_name)
@@ -109,17 +103,17 @@ class Group(SharedRegistryObject):
self.invalidate_members()
@property
- def non_inherited_unit_names(self):
+ def non_inherited_unit_names(self) -> frozenset[str]:
return frozenset(self._unit_names)
- def remove_units(self, *unit_names):
+ def remove_units(self, *unit_names: str) -> None:
"""Remove units from group."""
for unit_name in unit_names:
self._unit_names.remove(unit_name)
self.invalidate_members()
- def add_groups(self, *group_names):
+ def add_groups(self, *group_names: str) -> None:
"""Add groups to group."""
d = self._REGISTRY._groups
for group_name in group_names:
@@ -136,7 +130,7 @@ class Group(SharedRegistryObject):
self.invalidate_members()
- def remove_groups(self, *group_names):
+ def remove_groups(self, *group_names: str) -> None:
"""Remove groups from group."""
d = self._REGISTRY._groups
for group_name in group_names:
@@ -148,7 +142,9 @@ class Group(SharedRegistryObject):
self.invalidate_members()
@classmethod
- def from_lines(cls, lines, define_func, non_int_type=float) -> Group:
+ def from_lines(
+ cls, lines: Iterable[str], define_func, non_int_type: type = float
+ ) -> Group:
"""Return a Group object parsing an iterable of lines.
Parameters
@@ -190,6 +186,6 @@ class Group(SharedRegistryObject):
return grp
- def __getattr__(self, item):
+ def __getattr__(self, item: str):
getattr_maybe_raise(self, item)
return self._REGISTRY
diff --git a/pint/facets/group/registry.py b/pint/facets/group/registry.py
index 7269082..0d35ae0 100644
--- a/pint/facets/group/registry.py
+++ b/pint/facets/group/registry.py
@@ -8,17 +8,17 @@
from __future__ import annotations
-from typing import TYPE_CHECKING, Dict, FrozenSet
+from typing import TYPE_CHECKING
from ... import errors
if TYPE_CHECKING:
from ..._typing import Unit
-from ...util import build_dependent_class, create_class_with_registry
+from ...util import create_class_with_registry
from ..plain import PlainRegistry, UnitDefinition
from .definitions import GroupDefinition
-from .objects import Group
+from . import objects
class GroupRegistry(PlainRegistry):
@@ -34,19 +34,15 @@ class GroupRegistry(PlainRegistry):
# TODO: Change this to Group: Group to specify class
# and use introspection to get system class as a way
# to enjoy typing goodies
- _group_class = Group
+ Group = objects.Group
def __init__(self, **kwargs):
super().__init__(**kwargs)
#: Map group name to group.
#: :type: dict[ str | Group]
- self._groups: Dict[str, Group] = {}
+ self._groups: dict[str, objects.Group] = {}
self._groups["root"] = self.Group("root")
- def __init_subclass__(cls, **kwargs):
- super().__init_subclass__()
- cls.Group = build_dependent_class(cls, "Group", "_group_class")
-
def _init_dynamic_classes(self) -> None:
"""Generate subclasses on the fly and attach them to self"""
super()._init_dynamic_classes()
@@ -93,7 +89,7 @@ class GroupRegistry(PlainRegistry):
except KeyError as e:
raise errors.DefinitionSyntaxError(f"unknown dimension {e} in context")
- def get_group(self, name: str, create_if_needed: bool = True) -> Group:
+ def get_group(self, name: str, create_if_needed: bool = True) -> objects.Group:
"""Return a Group.
Parameters
@@ -117,7 +113,7 @@ class GroupRegistry(PlainRegistry):
return self.Group(name)
- def _get_compatible_units(self, input_units, group) -> FrozenSet["Unit"]:
+ def _get_compatible_units(self, input_units, group) -> frozenset[Unit]:
ret = super()._get_compatible_units(input_units, group)
if not group:
diff --git a/pint/facets/measurement/objects.py b/pint/facets/measurement/objects.py
index 0fed93f..5f3ba7a 100644
--- a/pint/facets/measurement/objects.py
+++ b/pint/facets/measurement/objects.py
@@ -18,12 +18,12 @@ from ..plain import PlainQuantity
MISSING = object()
-class MeasurementQuantity:
+class MeasurementQuantity(PlainQuantity):
# Measurement support
def plus_minus(self, error, relative=False):
if isinstance(error, self.__class__):
if relative:
- raise ValueError("{} is not a valid relative error.".format(error))
+ raise ValueError(f"{error} is not a valid relative error.")
error = error.to(self._units).magnitude
else:
if relative:
@@ -98,7 +98,7 @@ class Measurement(PlainQuantity):
)
def __str__(self):
- return "{}".format(self)
+ return f"{self}"
def __format__(self, spec):
spec = spec or self.default_format
@@ -133,7 +133,7 @@ class Measurement(PlainQuantity):
# scientific notation ('e' or 'E' and sometimes 'g' or 'G').
mstr = mstr.replace("(", "").replace(")", " ")
ustr = siunitx_format_unit(self.units._units, self._REGISTRY)
- return r"\SI%s{%s}{%s}" % (opts, mstr, ustr)
+ return rf"\SI{opts}{{{mstr}}}{{{ustr}}}"
# standard cases
if "L" in spec:
diff --git a/pint/facets/measurement/registry.py b/pint/facets/measurement/registry.py
index e704399..0fc4391 100644
--- a/pint/facets/measurement/registry.py
+++ b/pint/facets/measurement/registry.py
@@ -10,21 +10,15 @@
from __future__ import annotations
from ...compat import ufloat
-from ...util import build_dependent_class, create_class_with_registry
+from ...util import create_class_with_registry
from ..plain import PlainRegistry
-from .objects import Measurement, MeasurementQuantity
+from .objects import MeasurementQuantity
+from . import objects
class MeasurementRegistry(PlainRegistry):
- _quantity_class = MeasurementQuantity
- _measurement_class = Measurement
-
- def __init_subclass__(cls, **kwargs):
- super().__init_subclass__()
-
- cls.Measurement = build_dependent_class(
- cls, "Measurement", "_measurement_class"
- )
+ Quantity = MeasurementQuantity
+ Measurement = objects.Measurement
def _init_dynamic_classes(self) -> None:
"""Generate subclasses on the fly and attach them to self"""
diff --git a/pint/facets/nonmultiplicative/definitions.py b/pint/facets/nonmultiplicative/definitions.py
index dbfc0ff..f795cf0 100644
--- a/pint/facets/nonmultiplicative/definitions.py
+++ b/pint/facets/nonmultiplicative/definitions.py
@@ -10,6 +10,7 @@ from __future__ import annotations
from dataclasses import dataclass
+from ..._typing import Magnitude
from ...compat import HAS_NUMPY, exp, log
from ..plain import ScaleConverter
@@ -24,7 +25,7 @@ class OffsetConverter(ScaleConverter):
def is_multiplicative(self):
return self.offset == 0
- def to_reference(self, value, inplace=False):
+ def to_reference(self, value: Magnitude, inplace: bool = False) -> Magnitude:
if inplace:
value *= self.scale
value += self.offset
@@ -33,7 +34,7 @@ class OffsetConverter(ScaleConverter):
return value
- def from_reference(self, value, inplace=False):
+ def from_reference(self, value: Magnitude, inplace: bool = False) -> Magnitude:
if inplace:
value -= self.offset
value /= self.scale
@@ -66,6 +67,7 @@ class LogarithmicConverter(ScaleConverter):
controls if computation is done in place
"""
+ # TODO: Can I use PintScalar here?
logbase: float
logfactor: float
@@ -77,7 +79,7 @@ class LogarithmicConverter(ScaleConverter):
def is_logarithmic(self):
return True
- def from_reference(self, value, inplace=False):
+ def from_reference(self, value: Magnitude, inplace: bool = False) -> Magnitude:
"""Converts value from the reference unit to the logarithmic unit
dBm <------ mW
@@ -95,7 +97,7 @@ class LogarithmicConverter(ScaleConverter):
return value
- def to_reference(self, value, inplace=False):
+ def to_reference(self, value: Magnitude, inplace: bool = False) -> Magnitude:
"""Converts value to the reference unit from the logarithmic unit
dBm ------> mW
diff --git a/pint/facets/nonmultiplicative/objects.py b/pint/facets/nonmultiplicative/objects.py
index 1708e32..0ab743e 100644
--- a/pint/facets/nonmultiplicative/objects.py
+++ b/pint/facets/nonmultiplicative/objects.py
@@ -8,16 +8,16 @@
from __future__ import annotations
-from typing import List
+from ..plain import PlainQuantity
-class NonMultiplicativeQuantity:
+class NonMultiplicativeQuantity(PlainQuantity):
@property
def _is_multiplicative(self) -> bool:
"""Check if the PlainQuantity object has only multiplicative units."""
return not self._get_non_multiplicative_units()
- def _get_non_multiplicative_units(self) -> List[str]:
+ def _get_non_multiplicative_units(self) -> list[str]:
"""Return a list of the of non-multiplicative units of the PlainQuantity object."""
return [
unit
@@ -25,7 +25,7 @@ class NonMultiplicativeQuantity:
if not self._get_unit_definition(unit).is_multiplicative
]
- def _get_delta_units(self) -> List[str]:
+ def _get_delta_units(self) -> list[str]:
"""Return list of delta units ot the PlainQuantity object."""
return [u for u in self._units if u.startswith("delta_")]
@@ -40,7 +40,7 @@ class NonMultiplicativeQuantity:
self._get_unit_definition(d).reference == offset_unit_dim for d in deltas
)
- def _ok_for_muldiv(self, no_offset_units=None) -> bool:
+ def _ok_for_muldiv(self, no_offset_units: int | None = None) -> bool:
"""Checks if PlainQuantity object can be multiplied or divided"""
is_ok = True
diff --git a/pint/facets/nonmultiplicative/registry.py b/pint/facets/nonmultiplicative/registry.py
index 17b053e..8bc04db 100644
--- a/pint/facets/nonmultiplicative/registry.py
+++ b/pint/facets/nonmultiplicative/registry.py
@@ -8,7 +8,7 @@
from __future__ import annotations
-from typing import Any, Optional
+from typing import Any
from ...errors import DimensionalityError, UndefinedUnitError
from ...util import UnitsContainer, logger
@@ -35,7 +35,7 @@ class NonMultiplicativeRegistry(PlainRegistry):
"""
- _quantity_class = NonMultiplicativeQuantity
+ Quantity = NonMultiplicativeQuantity
def __init__(
self,
@@ -56,8 +56,8 @@ class NonMultiplicativeRegistry(PlainRegistry):
def _parse_units(
self,
input_string: str,
- as_delta: Optional[bool] = None,
- case_sensitive: Optional[bool] = None,
+ as_delta: bool | None = None,
+ case_sensitive: bool | None = None,
):
""" """
if as_delta is None:
diff --git a/pint/facets/numpy/numpy_func.py b/pint/facets/numpy/numpy_func.py
index f25f4a4..e7a9b67 100644
--- a/pint/facets/numpy/numpy_func.py
+++ b/pint/facets/numpy/numpy_func.py
@@ -220,7 +220,7 @@ def get_op_output_unit(unit_op, first_input_units, all_args=None, size=None):
product /= x.units
result_unit = product**-1
else:
- raise ValueError("Output unit method {} not understood".format(unit_op))
+ raise ValueError(f"Output unit method {unit_op} not understood")
return result_unit
@@ -237,7 +237,7 @@ def implements(numpy_func_string, func_type):
elif func_type == "ufunc":
HANDLED_UFUNCS[numpy_func_string] = func
else:
- raise ValueError("Invalid func_type {}".format(func_type))
+ raise ValueError(f"Invalid func_type {func_type}")
return func
return decorator
@@ -311,7 +311,7 @@ def implement_func(func_type, func_str, input_units=None, output_unit=None):
return result_magnitude
elif output_unit == "match_input":
result_unit = first_input_units
- elif output_unit in [
+ elif output_unit in (
"sum",
"mul",
"delta",
@@ -324,7 +324,7 @@ def implement_func(func_type, func_str, input_units=None, output_unit=None):
"cbrt",
"reciprocal",
"size",
- ]:
+ ):
result_unit = get_op_output_unit(
output_unit, first_input_units, tuple(chain(args, kwargs.values()))
)
@@ -499,8 +499,8 @@ def _frexp(x, *args, **kwargs):
def _power(x1, x2):
if _is_quantity(x1):
return x1**x2
- else:
- return x2.__rpow__(x1)
+
+ return x2.__rpow__(x1)
@implements("add", "ufunc")
@@ -535,8 +535,8 @@ def _full_like(a, fill_value, **kwargs):
np.ones_like(a, **kwargs) * fill_value.m,
fill_value.units,
)
- else:
- return np.ones_like(a, **kwargs) * fill_value
+
+ return np.ones_like(a, **kwargs) * fill_value
@implements("interp", "function")
@@ -671,8 +671,8 @@ def _any(a, *args, **kwargs):
# Only valid when multiplicative unit/no offset
if a._is_multiplicative:
return np.any(a._magnitude, *args, **kwargs)
- else:
- raise ValueError("Boolean value of Quantity with offset unit is ambiguous.")
+
+ raise ValueError("Boolean value of Quantity with offset unit is ambiguous.")
@implements("all", "function")
@@ -725,7 +725,7 @@ def implement_prod_func(name):
return registry.Quantity(result, units)
-for name in ["prod", "nanprod"]:
+for name in ("prod", "nanprod"):
implement_prod_func(name)
@@ -780,7 +780,7 @@ def implement_mul_func(func):
return a.units._REGISTRY.Quantity(mag, units)
-for func_str in ["cross", "dot"]:
+for func_str in ("cross", "dot"):
implement_mul_func(func_str)
@@ -830,11 +830,11 @@ def implement_consistent_units_by_argument(func_str, unit_arguments, wrap_output
# Conditionally wrap output
if wrap_output:
return output_wrap(ret)
- else:
- return ret
+
+ return ret
-for func_str, unit_arguments, wrap_output in [
+for func_str, unit_arguments, wrap_output in (
("expand_dims", "a", True),
("squeeze", "a", True),
("rollaxis", "a", True),
@@ -884,7 +884,7 @@ for func_str, unit_arguments, wrap_output in [
("reshape", "a", True),
("allclose", ["a", "b", "atol"], False),
("intersect1d", ["ar1", "ar2"], True),
-]:
+):
implement_consistent_units_by_argument(func_str, unit_arguments, wrap_output)
@@ -914,7 +914,7 @@ def implement_atleast_nd(func_str):
return output_unit._REGISTRY.Quantity(arrays_magnitude, output_unit)
-for func_str in ["atleast_1d", "atleast_2d", "atleast_3d"]:
+for func_str in ("atleast_1d", "atleast_2d", "atleast_3d"):
implement_atleast_nd(func_str)
@@ -935,24 +935,24 @@ def implement_single_dimensionless_argument_func(func_str):
return a._REGISTRY.Quantity(func(a_stripped, *args, **kwargs))
-for func_str in ["cumprod", "cumproduct", "nancumprod"]:
+for func_str in ("cumprod", "cumproduct", "nancumprod"):
implement_single_dimensionless_argument_func(func_str)
# Handle single-argument consistent unit functions
-for func_str in [
+for func_str in (
"block",
"hstack",
"vstack",
"dstack",
"column_stack",
"broadcast_arrays",
-]:
+):
implement_func(
"function", func_str, input_units="all_consistent", output_unit="match_input"
)
# Handle functions that ignore units on input and output
-for func_str in [
+for func_str in (
"size",
"isreal",
"iscomplex",
@@ -969,19 +969,19 @@ for func_str in [
"count_nonzero",
"nonzero",
"result_type",
-]:
+):
implement_func("function", func_str, input_units=None, output_unit=None)
# Handle functions with output unit defined by operation
-for func_str in ["std", "nanstd", "sum", "nansum", "cumsum", "nancumsum"]:
+for func_str in ("std", "nanstd", "sum", "nansum", "cumsum", "nancumsum"):
implement_func("function", func_str, input_units=None, output_unit="sum")
-for func_str in ["diff", "ediff1d"]:
+for func_str in ("diff", "ediff1d"):
implement_func("function", func_str, input_units=None, output_unit="delta")
-for func_str in ["gradient"]:
+for func_str in ("gradient",):
implement_func("function", func_str, input_units=None, output_unit="delta,div")
-for func_str in ["linalg.solve"]:
+for func_str in ("linalg.solve",):
implement_func("function", func_str, input_units=None, output_unit="invdiv")
-for func_str in ["var", "nanvar"]:
+for func_str in ("var", "nanvar"):
implement_func("function", func_str, input_units=None, output_unit="variance")
@@ -997,7 +997,7 @@ def numpy_wrap(func_type, func, args, kwargs, types):
# ufuncs do not have func.__module__
name = func.__name__
else:
- raise ValueError("Invalid func_type {}".format(func_type))
+ raise ValueError(f"Invalid func_type {func_type}")
if name not in handled or any(is_upcast_type(t) for t in types):
return NotImplemented
diff --git a/pint/facets/numpy/quantity.py b/pint/facets/numpy/quantity.py
index 9aa55ce..131983c 100644
--- a/pint/facets/numpy/quantity.py
+++ b/pint/facets/numpy/quantity.py
@@ -13,6 +13,8 @@ import math
import warnings
from typing import Any
+from ..plain import PlainQuantity
+
from ..._typing import Shape, _MagnitudeType
from ...compat import _to_magnitude, np
from ...errors import DimensionalityError, PintTypeError, UnitStrippedWarning
@@ -40,7 +42,7 @@ def method_wraps(numpy_func):
return wrapper
-class NumpyQuantity:
+class NumpyQuantity(PlainQuantity):
""" """
# NumPy function/ufunc support
@@ -52,11 +54,11 @@ class NumpyQuantity:
return NotImplemented
# Replicate types from __array_function__
- types = set(
+ types = {
type(arg)
for arg in list(inputs) + list(kwargs.values())
if hasattr(arg, "__array_ufunc__")
- )
+ }
return numpy_wrap("ufunc", ufunc, inputs, kwargs, types)
@@ -99,8 +101,8 @@ class NumpyQuantity:
if output_unit is not None:
return self.__class__(value, output_unit)
- else:
- return value
+
+ return value
def __array__(self, t=None) -> np.ndarray:
warnings.warn(
diff --git a/pint/facets/numpy/registry.py b/pint/facets/numpy/registry.py
index fa4768f..11d57f3 100644
--- a/pint/facets/numpy/registry.py
+++ b/pint/facets/numpy/registry.py
@@ -15,5 +15,5 @@ from .unit import NumpyUnit
class NumpyRegistry(PlainRegistry):
- _quantity_class = NumpyQuantity
- _unit_class = NumpyUnit
+ Quantity = NumpyQuantity
+ Unit = NumpyUnit
diff --git a/pint/facets/numpy/unit.py b/pint/facets/numpy/unit.py
index 0b5007f..d6bf140 100644
--- a/pint/facets/numpy/unit.py
+++ b/pint/facets/numpy/unit.py
@@ -9,9 +9,10 @@
from __future__ import annotations
from ...compat import is_upcast_type
+from ..plain import PlainUnit
-class NumpyUnit:
+class NumpyUnit(PlainUnit):
__array_priority__ = 17
def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
@@ -20,11 +21,11 @@ class NumpyUnit:
return NotImplemented
# Check types and return NotImplemented when upcast type encountered
- types = set(
+ types = {
type(arg)
for arg in list(inputs) + list(kwargs.values())
if hasattr(arg, "__array_ufunc__")
- )
+ }
if any(is_upcast_type(other) for other in types):
return NotImplemented
@@ -38,5 +39,5 @@ class NumpyUnit:
),
**kwargs,
)
- else:
- return NotImplemented
+
+ return NotImplemented
diff --git a/pint/facets/plain/definitions.py b/pint/facets/plain/definitions.py
index 11a3095..79a44f1 100644
--- a/pint/facets/plain/definitions.py
+++ b/pint/facets/plain/definitions.py
@@ -13,8 +13,9 @@ import numbers
import typing as ty
from dataclasses import dataclass
from functools import cached_property
-from typing import Callable, Optional
+from typing import Callable, Any
+from ..._typing import Magnitude
from ... import errors
from ...converters import Converter
from ...util import UnitsContainer
@@ -23,7 +24,7 @@ from ...util import UnitsContainer
class NotNumeric(Exception):
"""Internal exception. Do not expose outside Pint"""
- def __init__(self, value):
+ def __init__(self, value: Any):
self.value = value
@@ -76,7 +77,7 @@ class PrefixDefinition(errors.WithDefErr):
#: scaling value for this prefix
value: numbers.Number
#: canonical symbol
- defined_symbol: Optional[str] = ""
+ defined_symbol: str | None = ""
#: additional names for the same prefix
aliases: ty.Tuple[str, ...] = ()
@@ -115,18 +116,26 @@ class UnitDefinition(errors.WithDefErr):
#: canonical name of the unit
name: str
#: canonical symbol
- defined_symbol: ty.Optional[str]
+ defined_symbol: str | None
#: additional names for the same unit
- aliases: ty.Tuple[str, ...]
+ aliases: tuple[str]
#: A functiont that converts a value in these units into the reference units
- converter: ty.Optional[ty.Union[Callable, Converter]]
+ converter: Callable[
+ [
+ Magnitude,
+ ],
+ Magnitude,
+ ] | Converter | None
#: Reference units.
- reference: ty.Optional[UnitsContainer]
+ reference: UnitsContainer | None
def __post_init__(self):
if not errors.is_valid_unit_name(self.name):
raise self.def_err(errors.MSG_INVALID_UNIT_NAME)
+ # TODO: check why reference: UnitsContainer | None
+ assert isinstance(self.reference, UnitsContainer)
+
if not any(map(errors.is_dim, self.reference.keys())):
invalid = tuple(
itertools.filterfalse(errors.is_valid_unit_name, self.reference.keys())
@@ -180,14 +189,20 @@ class UnitDefinition(errors.WithDefErr):
@property
def is_base(self) -> bool:
"""Indicates if it is a base unit."""
+
+ # TODO: why is this here
return self._is_base
@property
def is_multiplicative(self) -> bool:
+ # TODO: Check how to avoid this check
+ assert isinstance(self.converter, Converter)
return self.converter.is_multiplicative
@property
def is_logarithmic(self) -> bool:
+ # TODO: Check how to avoid this check
+ assert isinstance(self.converter, Converter)
return self.converter.is_logarithmic
@property
@@ -272,7 +287,7 @@ class ScaleConverter(Converter):
scale: float
- def to_reference(self, value, inplace=False):
+ def to_reference(self, value: Magnitude, inplace: bool = False) -> Magnitude:
if inplace:
value *= self.scale
else:
@@ -280,7 +295,7 @@ class ScaleConverter(Converter):
return value
- def from_reference(self, value, inplace=False):
+ def from_reference(self, value: Magnitude, inplace: bool = False) -> Magnitude:
if inplace:
value /= self.scale
else:
diff --git a/pint/facets/plain/objects.py b/pint/facets/plain/objects.py
index 5b2837b..a868c7f 100644
--- a/pint/facets/plain/objects.py
+++ b/pint/facets/plain/objects.py
@@ -11,4 +11,4 @@ from __future__ import annotations
from .quantity import PlainQuantity
from .unit import PlainUnit, UnitsContainer
-__all__ = [PlainUnit, PlainQuantity, UnitsContainer]
+__all__ = ["PlainUnit", "PlainQuantity", "UnitsContainer"]
diff --git a/pint/facets/plain/quantity.py b/pint/facets/plain/quantity.py
index 359e613..1eaaa3d 100644
--- a/pint/facets/plain/quantity.py
+++ b/pint/facets/plain/quantity.py
@@ -20,18 +20,11 @@ from typing import (
TYPE_CHECKING,
Any,
Callable,
- Dict,
Generic,
- Iterable,
- Iterator,
- List,
- Optional,
- Sequence,
- Tuple,
TypeVar,
- Union,
overload,
)
+from collections.abc import Iterable, Iterator, Sequence
from ..._typing import S, UnitLike, _MagnitudeType
from ...compat import (
@@ -179,25 +172,25 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType]
@overload
def __new__(
- cls, value: str, units: Optional[UnitLike] = None
+ cls, value: str, units: UnitLike | None = None
) -> PlainQuantity[Magnitude]:
...
@overload
def __new__( # type: ignore[misc]
- cls, value: Sequence, units: Optional[UnitLike] = None
+ cls, value: Sequence, units: UnitLike | None = None
) -> PlainQuantity[np.ndarray]:
...
@overload
def __new__(
- cls, value: PlainQuantity[Magnitude], units: Optional[UnitLike] = None
+ cls, value: PlainQuantity[Magnitude], units: UnitLike | None = None
) -> PlainQuantity[Magnitude]:
...
@overload
def __new__(
- cls, value: Magnitude, units: Optional[UnitLike] = None
+ cls, value: Magnitude, units: UnitLike | None = None
) -> PlainQuantity[Magnitude]:
...
@@ -281,15 +274,15 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType]
def __repr__(self) -> str:
if isinstance(self._magnitude, float):
return f"<Quantity({self._magnitude:.9}, '{self._units}')>"
- else:
- return f"<Quantity({self._magnitude}, '{self._units}')>"
+
+ return f"<Quantity({self._magnitude}, '{self._units}')>"
def __hash__(self) -> int:
self_base = self.to_base_units()
if self_base.dimensionless:
return hash(self_base.magnitude)
- else:
- return hash((self_base.__class__, self_base.magnitude, self_base.units))
+
+ return hash((self_base.__class__, self_base.magnitude, self_base.units))
@property
def magnitude(self) -> _MagnitudeType:
@@ -316,12 +309,12 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType]
return self.to(units).magnitude
@property
- def units(self) -> "Unit":
+ def units(self) -> Unit:
"""PlainQuantity's units. Long form for `u`"""
return self._REGISTRY.Unit(self._units)
@property
- def u(self) -> "Unit":
+ def u(self) -> Unit:
"""PlainQuantity's units. Short form for `units`"""
return self._REGISTRY.Unit(self._units)
@@ -337,7 +330,7 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType]
return not bool(tmp.dimensionality)
- _dimensionality: Optional[UnitsContainerT] = None
+ _dimensionality: UnitsContainerT | None = None
@property
def dimensionality(self) -> UnitsContainerT:
@@ -358,7 +351,7 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType]
@classmethod
def from_list(
- cls, quant_list: List[PlainQuantity], units=None
+ cls, quant_list: list[PlainQuantity], units=None
) -> PlainQuantity[np.ndarray]:
"""Transforms a list of Quantities into an numpy.array quantity.
If no units are specified, the unit of the first element will be used.
@@ -421,7 +414,7 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType]
def from_tuple(cls, tup):
return cls(tup[0], cls._REGISTRY.UnitsContainer(tup[1]))
- def to_tuple(self) -> Tuple[_MagnitudeType, Tuple[Tuple[str]]]:
+ def to_tuple(self) -> tuple[_MagnitudeType, tuple[tuple[str]]]:
return self.m, tuple(self._units.items())
def compatible_units(self, *contexts):
@@ -432,7 +425,7 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType]
return self._REGISTRY.get_compatible_units(self._units)
def is_compatible_with(
- self, other: Any, *contexts: Union[str, Context], **ctx_kwargs: Any
+ self, other: Any, *contexts: str | Context, **ctx_kwargs: Any
) -> bool:
"""check if the other object is compatible
@@ -652,7 +645,7 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType]
):
return self
- SI_prefixes: Dict[int, str] = {}
+ SI_prefixes: dict[int, str] = {}
for prefix in self._REGISTRY._prefixes.values():
try:
scale = prefix.converter.scale
@@ -702,7 +695,7 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType]
return self.to(new_unit_container)
def to_preferred(
- self, preferred_units: List[UnitLike]
+ self, preferred_units: list[UnitLike]
) -> PlainQuantity[_MagnitudeType]:
"""Return Quantity converted to a unit composed of the preferred units.
@@ -732,9 +725,9 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType]
for preferred_unit in preferred_units:
dims = sorted(preferred_unit.dimensionality)
if dims == self_dims:
- p_exps_head, *p_exps_tail = [
+ p_exps_head, *p_exps_tail = (
preferred_unit.dimensionality[d] for d in dims
- ]
+ )
if all(
s_exps_tail[i] * p_exps_head == p_exps_tail[i] ** s_exps_head
for i in range(n)
@@ -812,15 +805,15 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType]
# update preferred_units with the selected units that were originally preferred
preferred_units = list(
- set(u for d, u in unit_selections.items() if d in preferred_dims)
+ {u for d, u in unit_selections.items() if d in preferred_dims}
)
- preferred_units.sort(key=lambda unit: str(unit)) # for determinism
+ preferred_units.sort(key=str) # for determinism
# and unpreferred_units are the selected units that weren't originally preferred
unpreferred_units = list(
- set(u for d, u in unit_selections.items() if d not in preferred_dims)
+ {u for d, u in unit_selections.items() if d not in preferred_dims}
)
- unpreferred_units.sort(key=lambda unit: str(unit)) # for determinism
+ unpreferred_units.sort(key=str) # for determinism
# for indexability
dimensions = list(dimension_set)
@@ -918,10 +911,10 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType]
result_unit = sorting_keys[min_key]
return self.to(result_unit)
- else:
- # for whatever reason, a solution wasn't found
- # return the original quantity
- return self
+
+ # for whatever reason, a solution wasn't found
+ # return the original quantity
+ return self
# Mathematical operations
def __int__(self) -> int:
@@ -1178,22 +1171,22 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType]
return self.to_timedelta() + other
elif is_duck_array_type(type(self._magnitude)):
return self._iadd_sub(other, operator.iadd)
- else:
- return self._add_sub(other, operator.add)
+
+ return self._add_sub(other, operator.add)
def __add__(self, other):
if isinstance(other, datetime.datetime):
return self.to_timedelta() + other
- else:
- return self._add_sub(other, operator.add)
+
+ return self._add_sub(other, operator.add)
__radd__ = __add__
def __isub__(self, other):
if is_duck_array_type(type(self._magnitude)):
return self._iadd_sub(other, operator.isub)
- else:
- return self._add_sub(other, operator.sub)
+
+ return self._add_sub(other, operator.sub)
def __sub__(self, other):
return self._add_sub(other, operator.sub)
@@ -1201,8 +1194,8 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType]
def __rsub__(self, other):
if isinstance(other, datetime.datetime):
return other - self.to_timedelta()
- else:
- return -self._add_sub(other, operator.sub)
+
+ return -self._add_sub(other, operator.sub)
@check_implemented
@ireduce_dimensions
@@ -1235,10 +1228,10 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType]
if not self._ok_for_muldiv(no_offset_units_self):
raise OffsetUnitCalculusError(self._units, getattr(other, "units", ""))
if len(offset_units_self) == 1:
- if self._units[offset_units_self[0]] != 1 or magnitude_op not in [
+ if self._units[offset_units_self[0]] != 1 or magnitude_op not in (
operator.mul,
operator.imul,
- ]:
+ ):
raise OffsetUnitCalculusError(
self._units, getattr(other, "units", "")
)
@@ -1259,14 +1252,14 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType]
if not self._ok_for_muldiv(no_offset_units_self):
raise OffsetUnitCalculusError(self._units, other._units)
- elif no_offset_units_self == 1 and len(self._units) == 1:
+ elif no_offset_units_self == len(self._units) == 1:
self.ito_root_units()
no_offset_units_other = len(other._get_non_multiplicative_units())
if not other._ok_for_muldiv(no_offset_units_other):
raise OffsetUnitCalculusError(self._units, other._units)
- elif no_offset_units_other == 1 and len(other._units) == 1:
+ elif no_offset_units_other == len(other._units) == 1:
other.ito_root_units()
self._magnitude = magnitude_op(self._magnitude, other._magnitude)
@@ -1304,10 +1297,10 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType]
if not self._ok_for_muldiv(no_offset_units_self):
raise OffsetUnitCalculusError(self._units, getattr(other, "units", ""))
if len(offset_units_self) == 1:
- if self._units[offset_units_self[0]] != 1 or magnitude_op not in [
+ if self._units[offset_units_self[0]] != 1 or magnitude_op not in (
operator.mul,
operator.imul,
- ]:
+ ):
raise OffsetUnitCalculusError(
self._units, getattr(other, "units", "")
)
@@ -1332,14 +1325,14 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType]
if not self._ok_for_muldiv(no_offset_units_self):
raise OffsetUnitCalculusError(self._units, other._units)
- elif no_offset_units_self == 1 and len(self._units) == 1:
+ elif no_offset_units_self == len(self._units) == 1:
new_self = self.to_root_units()
no_offset_units_other = len(other._get_non_multiplicative_units())
if not other._ok_for_muldiv(no_offset_units_other):
raise OffsetUnitCalculusError(self._units, other._units)
- elif no_offset_units_other == 1 and len(other._units) == 1:
+ elif no_offset_units_other == len(other._units) == 1:
other = other.to_root_units()
magnitude = magnitude_op(new_self._magnitude, other._magnitude)
@@ -1350,8 +1343,8 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType]
def __imul__(self, other):
if is_duck_array_type(type(self._magnitude)):
return self._imul_div(other, operator.imul)
- else:
- return self._mul_div(other, operator.mul)
+
+ return self._mul_div(other, operator.mul)
def __mul__(self, other):
return self._mul_div(other, operator.mul)
@@ -1374,8 +1367,8 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType]
def __itruediv__(self, other):
if is_duck_array_type(type(self._magnitude)):
return self._imul_div(other, operator.itruediv)
- else:
- return self._mul_div(other, operator.truediv)
+
+ return self._mul_div(other, operator.truediv)
def __truediv__(self, other):
if isinstance(self.m, int) or isinstance(getattr(other, "m", None), int):
@@ -1395,7 +1388,7 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType]
no_offset_units_self = len(self._get_non_multiplicative_units())
if not self._ok_for_muldiv(no_offset_units_self):
raise OffsetUnitCalculusError(self._units, "")
- elif no_offset_units_self == 1 and len(self._units) == 1:
+ elif no_offset_units_self == len(self._units) == 1:
self = self.to_root_units()
return self.__class__(other_magnitude / self._magnitude, 1 / self._units)
@@ -1627,7 +1620,7 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType]
def __abs__(self) -> PlainQuantity[_MagnitudeType]:
return self.__class__(abs(self._magnitude), self._units)
- def __round__(self, ndigits: Optional[int] = 0) -> PlainQuantity[int]:
+ def __round__(self, ndigits: int | None = 0) -> PlainQuantity[int]:
return self.__class__(round(self._magnitude, ndigits=ndigits), self._units)
def __pos__(self) -> PlainQuantity[_MagnitudeType]:
@@ -1720,9 +1713,7 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType]
else:
raise OffsetUnitCalculusError(self._units)
else:
- raise ValueError(
- "Cannot compare PlainQuantity and {}".format(type(other))
- )
+ raise ValueError(f"Cannot compare PlainQuantity and {type(other)}")
# Registry equality check based on util.SharedRegistryObject
if self._REGISTRY is not other._REGISTRY:
@@ -1791,11 +1782,11 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType]
"""Check if the PlainQuantity object has only multiplicative units."""
return True
- def _get_non_multiplicative_units(self) -> List[str]:
+ def _get_non_multiplicative_units(self) -> list[str]:
"""Return a list of the of non-multiplicative units of the PlainQuantity object."""
return []
- def _get_delta_units(self) -> List[str]:
+ def _get_delta_units(self) -> list[str]:
"""Return list of delta units ot the PlainQuantity object."""
return [u for u in self._units if u.startswith("delta_")]
diff --git a/pint/facets/plain/registry.py b/pint/facets/plain/registry.py
index 0bf1545..d3baff4 100644
--- a/pint/facets/plain/registry.py
+++ b/pint/facets/plain/registry.py
@@ -24,18 +24,10 @@ from typing import (
TYPE_CHECKING,
Any,
Callable,
- Dict,
- FrozenSet,
- Iterable,
- Iterator,
- List,
- Optional,
- Set,
- Tuple,
- Type,
TypeVar,
Union,
)
+from collections.abc import Iterable, Iterator
if TYPE_CHECKING:
from ..context import Context
@@ -51,7 +43,6 @@ from ...util import UnitsContainer
from ...util import UnitsContainer as UnitsContainerT
from ...util import (
_is_dim,
- build_dependent_class,
create_class_with_registry,
getattr_maybe_raise,
logger,
@@ -83,7 +74,7 @@ T = TypeVar("T")
_BLOCK_RE = re.compile(r"[ (]")
-@functools.lru_cache()
+@functools.lru_cache
def pattern_to_regex(pattern):
if hasattr(pattern, "finditer"):
pattern = pattern.pattern
@@ -96,7 +87,7 @@ def pattern_to_regex(pattern):
return re.compile(pattern)
-NON_INT_TYPE = Type[Union[float, Decimal, Fraction]]
+NON_INT_TYPE = type[Union[float, Decimal, Fraction]]
PreprocessorType = Callable[[str], str]
@@ -105,13 +96,13 @@ class RegistryCache:
def __init__(self) -> None:
#: Maps dimensionality (UnitsContainer) to Units (str)
- self.dimensional_equivalents: Dict[UnitsContainer, Set[str]] = {}
+ self.dimensional_equivalents: dict[UnitsContainer, set[str]] = {}
#: Maps dimensionality (UnitsContainer) to Dimensionality (UnitsContainer)
self.root_units = {}
#: Maps dimensionality (UnitsContainer) to Units (UnitsContainer)
- self.dimensionality: Dict[UnitsContainer, UnitsContainer] = {}
+ self.dimensionality: dict[UnitsContainer, UnitsContainer] = {}
#: Cache the unit name associated to user input. ('mV' -> 'millivolt')
- self.parse_unit: Dict[str, UnitsContainer] = {}
+ self.parse_unit: dict[str, UnitsContainer] = {}
def __eq__(self, other):
if not isinstance(other, self.__class__):
@@ -181,12 +172,12 @@ class PlainRegistry(metaclass=RegistryMeta):
"""
#: Babel.Locale instance or None
- fmt_locale: Optional[Locale] = None
+ fmt_locale: Locale | None = None
_diskcache = None
- _quantity_class = PlainQuantity
- _unit_class = PlainUnit
+ Quantity = PlainQuantity
+ Unit = PlainUnit
_def_parser = None
@@ -197,16 +188,16 @@ class PlainRegistry(metaclass=RegistryMeta):
force_ndarray_like: bool = False,
on_redefinition: str = "warn",
auto_reduce_dimensions: bool = False,
- preprocessors: Optional[List[PreprocessorType]] = None,
- fmt_locale: Optional[str] = None,
+ preprocessors: list[PreprocessorType] | None = None,
+ fmt_locale: str | None = None,
non_int_type: NON_INT_TYPE = float,
case_sensitive: bool = True,
- cache_folder: Union[str, pathlib.Path, None] = None,
- separate_format_defaults: Optional[bool] = None,
+ cache_folder: str | pathlib.Path | None = None,
+ separate_format_defaults: bool | None = None,
mpl_formatter: str = "{:P}",
):
#: Map a definition class to a adder methods.
- self._adders = dict()
+ self._adders = {}
self._register_definition_adders()
self._init_dynamic_classes()
@@ -255,44 +246,37 @@ class PlainRegistry(metaclass=RegistryMeta):
#: Map between name (string) and value (string) of defaults stored in the
#: definitions file.
- self._defaults: Dict[str, str] = {}
+ self._defaults: dict[str, str] = {}
#: Map dimension name (string) to its definition (DimensionDefinition).
- self._dimensions: Dict[
- str, Union[DimensionDefinition, DerivedDimensionDefinition]
+ self._dimensions: dict[
+ str, DimensionDefinition | DerivedDimensionDefinition
] = {}
#: Map unit name (string) to its definition (UnitDefinition).
#: Might contain prefixed units.
- self._units: Dict[str, UnitDefinition] = {}
+ self._units: dict[str, UnitDefinition] = {}
#: List base unit names
- self._base_units: List[str] = []
+ self._base_units: list[str] = []
#: Map unit name in lower case (string) to a set of unit names with the right
#: case.
#: Does not contain prefixed units.
#: e.g: 'hz' - > set('Hz', )
- self._units_casei: Dict[str, Set[str]] = defaultdict(set)
+ self._units_casei: dict[str, set[str]] = defaultdict(set)
#: Map prefix name (string) to its definition (PrefixDefinition).
- self._prefixes: Dict[str, PrefixDefinition] = {"": PrefixDefinition("", 1)}
+ self._prefixes: dict[str, PrefixDefinition] = {"": PrefixDefinition("", 1)}
#: Map suffix name (string) to canonical , and unit alias to canonical unit name
- self._suffixes: Dict[str, str] = {"": "", "s": ""}
+ self._suffixes: dict[str, str] = {"": "", "s": ""}
#: Map contexts to RegistryCache
self._cache = RegistryCache()
self._initialized = False
- def __init_subclass__(cls, **kwargs):
- super().__init_subclass__()
- cls.Unit: Unit = build_dependent_class(cls, "Unit", "_unit_class")
- cls.Quantity: Quantity = build_dependent_class(
- cls, "Quantity", "_quantity_class"
- )
-
def _init_dynamic_classes(self) -> None:
"""Generate subclasses on the fly and attach them to self"""
@@ -326,7 +310,7 @@ class PlainRegistry(metaclass=RegistryMeta):
self._register_adder(DimensionDefinition, self._add_dimension)
self._register_adder(DerivedDimensionDefinition, self._add_derived_dimension)
- def __deepcopy__(self, memo) -> "PlainRegistry":
+ def __deepcopy__(self, memo) -> PlainRegistry:
new = object.__new__(type(self))
new.__dict__ = copy.deepcopy(self.__dict__, memo)
new._init_dynamic_classes()
@@ -351,7 +335,7 @@ class PlainRegistry(metaclass=RegistryMeta):
except UndefinedUnitError:
return False
- def __dir__(self) -> List[str]:
+ def __dir__(self) -> list[str]:
#: Calling dir(registry) gives all units, methods, and attributes.
#: Also used for autocompletion in IPython.
return list(self._units.keys()) + list(object.__dir__(self))
@@ -365,7 +349,7 @@ class PlainRegistry(metaclass=RegistryMeta):
"""
return iter(sorted(self._units.keys()))
- def set_fmt_locale(self, loc: Optional[str]) -> None:
+ def set_fmt_locale(self, loc: str | None) -> None:
"""Change the locale used by default by `format_babel`.
Parameters
@@ -397,7 +381,7 @@ class PlainRegistry(metaclass=RegistryMeta):
self.Measurement.default_format = value
@property
- def cache_folder(self) -> Optional[pathlib.Path]:
+ def cache_folder(self) -> pathlib.Path | None:
if self._diskcache:
return self._diskcache.cache_folder
return None
@@ -472,7 +456,7 @@ class PlainRegistry(metaclass=RegistryMeta):
if self._on_redefinition == "raise":
raise RedefinitionError(key, type(value))
elif self._on_redefinition == "warn":
- logger.warning("Redefining '%s' (%s)" % (key, type(value)))
+ logger.warning(f"Redefining '{key}' ({type(value)})")
target_dict[key] = value
if casei_target_dict is not None:
@@ -581,9 +565,7 @@ class PlainRegistry(metaclass=RegistryMeta):
logger.warning(f"Could not resolve {unit_name}: {exc!r}")
return self._cache
- def get_name(
- self, name_or_alias: str, case_sensitive: Optional[bool] = None
- ) -> str:
+ def get_name(self, name_or_alias: str, case_sensitive: bool | None = None) -> str:
"""Return the canonical name of a unit."""
if name_or_alias == "dimensionless":
@@ -621,9 +603,7 @@ class PlainRegistry(metaclass=RegistryMeta):
return unit_name
- def get_symbol(
- self, name_or_alias: str, case_sensitive: Optional[bool] = None
- ) -> str:
+ def get_symbol(self, name_or_alias: str, case_sensitive: bool | None = None) -> str:
"""Return the preferred alias for a unit."""
candidates = self.parse_unit_name(name_or_alias, case_sensitive)
if not candidates:
@@ -632,8 +612,8 @@ class PlainRegistry(metaclass=RegistryMeta):
prefix, unit_name, _ = candidates[0]
else:
logger.warning(
- "Parsing {0} yield multiple results. "
- "Options are: {1!r}".format(name_or_alias, candidates)
+ "Parsing {} yield multiple results. "
+ "Options are: {!r}".format(name_or_alias, candidates)
)
prefix, unit_name, _ = candidates[0]
@@ -654,7 +634,7 @@ class PlainRegistry(metaclass=RegistryMeta):
return self._get_dimensionality(input_units)
def _get_dimensionality(
- self, input_units: Optional[UnitsContainerT]
+ self, input_units: UnitsContainerT | None
) -> UnitsContainerT:
"""Convert a UnitsContainer to plain dimensions."""
if not input_units:
@@ -727,7 +707,7 @@ class PlainRegistry(metaclass=RegistryMeta):
def get_root_units(
self, input_units: UnitLike, check_nonmult: bool = True
- ) -> Tuple[Number, PlainUnit]:
+ ) -> tuple[Number, PlainUnit]:
"""Convert unit or dict of units to the root units.
If any unit is non multiplicative and check_converter is True,
@@ -840,7 +820,7 @@ class PlainRegistry(metaclass=RegistryMeta):
def get_compatible_units(
self, input_units, group_or_system=None
- ) -> FrozenSet[Unit]:
+ ) -> frozenset[Unit]:
""" """
input_units = to_units_container(input_units)
@@ -858,7 +838,7 @@ class PlainRegistry(metaclass=RegistryMeta):
# TODO: remove context from here
def is_compatible_with(
- self, obj1: Any, obj2: Any, *contexts: Union[str, Context], **ctx_kwargs
+ self, obj1: Any, obj2: Any, *contexts: str | Context, **ctx_kwargs
) -> bool:
"""check if the other object is compatible
@@ -972,8 +952,8 @@ class PlainRegistry(metaclass=RegistryMeta):
return value
def parse_unit_name(
- self, unit_name: str, case_sensitive: Optional[bool] = None
- ) -> Tuple[Tuple[str, str, str], ...]:
+ self, unit_name: str, case_sensitive: bool | None = None
+ ) -> tuple[tuple[str, str, str], ...]:
"""Parse a unit to identify prefix, unit name and suffix
by walking the list of prefix and suffix.
In case of equivalent combinations (e.g. ('kilo', 'gram', '') and
@@ -997,8 +977,8 @@ class PlainRegistry(metaclass=RegistryMeta):
)
def _parse_unit_name(
- self, unit_name: str, case_sensitive: Optional[bool] = None
- ) -> Iterator[Tuple[str, str, str]]:
+ self, unit_name: str, case_sensitive: bool | None = None
+ ) -> Iterator[tuple[str, str, str]]:
"""Helper of parse_unit_name."""
case_sensitive = (
self.case_sensitive if case_sensitive is None else case_sensitive
@@ -1029,8 +1009,8 @@ class PlainRegistry(metaclass=RegistryMeta):
@staticmethod
def _dedup_candidates(
- candidates: Iterable[Tuple[str, str, str]]
- ) -> Tuple[Tuple[str, str, str], ...]:
+ candidates: Iterable[tuple[str, str, str]]
+ ) -> tuple[tuple[str, str, str], ...]:
"""Helper of parse_unit_name.
Given an iterable of unit triplets (prefix, name, suffix), remove those with
@@ -1051,8 +1031,8 @@ class PlainRegistry(metaclass=RegistryMeta):
def parse_units(
self,
input_string: str,
- as_delta: Optional[bool] = None,
- case_sensitive: Optional[bool] = None,
+ as_delta: bool | None = None,
+ case_sensitive: bool | None = None,
) -> Unit:
"""Parse a units expression and returns a UnitContainer with
the canonical names.
@@ -1083,7 +1063,7 @@ class PlainRegistry(metaclass=RegistryMeta):
self,
input_string: str,
as_delta: bool = True,
- case_sensitive: Optional[bool] = None,
+ case_sensitive: bool | None = None,
) -> UnitsContainerT:
"""Parse a units expression and returns a UnitContainer with
the canonical names.
@@ -1124,15 +1104,7 @@ class PlainRegistry(metaclass=RegistryMeta):
return ret
- def _eval_token(self, token, case_sensitive=None, use_decimal=False, **values):
- # TODO: remove this code when use_decimal is deprecated
- if use_decimal:
- raise DeprecationWarning(
- "`use_decimal` is deprecated, use `non_int_type` keyword argument when instantiating the registry.\n"
- ">>> from decimal import Decimal\n"
- ">>> ureg = UnitRegistry(non_int_type=Decimal)"
- )
-
+ def _eval_token(self, token, case_sensitive=None, **values):
token_type = token[0]
token_text = token[1]
if token_type == NAME:
@@ -1160,10 +1132,9 @@ class PlainRegistry(metaclass=RegistryMeta):
self,
input_string: str,
pattern: str,
- case_sensitive: Optional[bool] = None,
- use_decimal: bool = False,
+ case_sensitive: bool | None = None,
many: bool = False,
- ) -> Union[List[str], str, None]:
+ ) -> list[str] | str | None:
"""Parse a string with a given regex pattern and returns result.
Parameters
@@ -1174,8 +1145,6 @@ class PlainRegistry(metaclass=RegistryMeta):
The regex parse string
case_sensitive :
(Default value = None, which uses registry setting)
- use_decimal :
- (Default value = False)
many :
Match many results
(Default value = False)
@@ -1200,13 +1169,10 @@ class PlainRegistry(metaclass=RegistryMeta):
match = match.groupdict()
# Parse units
- units = []
- for unit, value in match.items():
- # Construct measure by multiplying value by unit
- units.append(
- float(value)
- * self.parse_expression(unit, case_sensitive, use_decimal)
- )
+ units = [
+ float(value) * self.parse_expression(unit, case_sensitive)
+ for unit, value in match.items()
+ ]
# Add to results
results.append(units)
@@ -1220,8 +1186,7 @@ class PlainRegistry(metaclass=RegistryMeta):
def parse_expression(
self,
input_string: str,
- case_sensitive: Optional[bool] = None,
- use_decimal: bool = False,
+ case_sensitive: bool | None = None,
**values,
) -> Quantity:
"""Parse a mathematical expression including units and return a quantity object.
@@ -1235,8 +1200,6 @@ class PlainRegistry(metaclass=RegistryMeta):
case_sensitive :
(Default value = None, which uses registry setting)
- use_decimal :
- (Default value = False)
**values :
@@ -1244,15 +1207,6 @@ class PlainRegistry(metaclass=RegistryMeta):
-------
"""
-
- # TODO: remove this code when use_decimal is deprecated
- if use_decimal:
- raise DeprecationWarning(
- "`use_decimal` is deprecated, use `non_int_type` keyword argument when instantiating the registry.\n"
- ">>> from decimal import Decimal\n"
- ">>> ureg = UnitRegistry(non_int_type=Decimal)"
- )
-
if not input_string:
return self.Quantity(1)
diff --git a/pint/facets/plain/unit.py b/pint/facets/plain/unit.py
index b608c05..64a7d3c 100644
--- a/pint/facets/plain/unit.py
+++ b/pint/facets/plain/unit.py
@@ -12,7 +12,7 @@ import copy
import locale
import operator
from numbers import Number
-from typing import TYPE_CHECKING, Any, Union
+from typing import TYPE_CHECKING, Any
from ..._typing import UnitLike
from ...compat import NUMERIC_TYPES
@@ -65,7 +65,7 @@ class PlainUnit(PrettyIPython, SharedRegistryObject):
return str(self).encode(locale.getpreferredencoding())
def __repr__(self) -> str:
- return "<Unit('{}')>".format(self._units)
+ return f"<Unit('{self._units}')>"
@property
def dimensionless(self) -> bool:
@@ -96,7 +96,7 @@ class PlainUnit(PrettyIPython, SharedRegistryObject):
return self._REGISTRY.get_compatible_units(self)
def is_compatible_with(
- self, other: Any, *contexts: Union[str, Context], **ctx_kwargs: Any
+ self, other: Any, *contexts: str | Context, **ctx_kwargs: Any
) -> bool:
"""check if the other object is compatible
@@ -165,18 +165,18 @@ class PlainUnit(PrettyIPython, SharedRegistryObject):
return self._REGISTRY.Quantity(other, 1 / self._units)
elif isinstance(other, UnitsContainer):
return self.__class__(other / self._units)
- else:
- return NotImplemented
+
+ return NotImplemented
__div__ = __truediv__
__rdiv__ = __rtruediv__
- def __pow__(self, other) -> "PlainUnit":
+ def __pow__(self, other) -> PlainUnit:
if isinstance(other, NUMERIC_TYPES):
return self.__class__(self._units**other)
else:
- mess = "Cannot power PlainUnit by {}".format(type(other))
+ mess = f"Cannot power PlainUnit by {type(other)}"
raise TypeError(mess)
def __hash__(self) -> int:
@@ -207,8 +207,8 @@ class PlainUnit(PrettyIPython, SharedRegistryObject):
return self_q.compare(other, op)
elif isinstance(other, (PlainUnit, UnitsContainer, dict)):
return self_q.compare(self._REGISTRY.Quantity(1, other), op)
- else:
- return NotImplemented
+
+ return NotImplemented
__lt__ = lambda self, other: self.compare(other, op=operator.lt)
__le__ = lambda self, other: self.compare(other, op=operator.le)
diff --git a/pint/facets/system/definitions.py b/pint/facets/system/definitions.py
index 8243324..1ce8269 100644
--- a/pint/facets/system/definitions.py
+++ b/pint/facets/system/definitions.py
@@ -8,9 +8,10 @@
from __future__ import annotations
-import typing as ty
+from collections.abc import Iterable
from dataclasses import dataclass
+from ..._typing import Self
from ... import errors
@@ -23,7 +24,7 @@ class BaseUnitRule:
new_unit_name: str
#: name of the unit to be kicked out to make room for the new base uni
#: If None, the current base unit with the same dimensionality will be used
- old_unit_name: ty.Optional[str] = None
+ old_unit_name: str | None = None
# Instead of defining __post_init__ here,
# it will be added to the container class
@@ -38,13 +39,16 @@ class SystemDefinition(errors.WithDefErr):
#: name of the system
name: str
#: unit groups that will be included within the system
- using_group_names: ty.Tuple[str, ...]
+ using_group_names: tuple[str]
#: rules to define new base unit within the system.
- rules: ty.Tuple[BaseUnitRule, ...]
+ rules: tuple[BaseUnitRule]
@classmethod
- def from_lines(cls, lines, non_int_type):
+ def from_lines(
+ cls: type[Self], lines: Iterable[str], non_int_type: type
+ ) -> Self | None:
# TODO: this is to keep it backwards compatible
+ # TODO: check when is None returned.
from ...delegates import ParserConfig, txt_defparser
cfg = ParserConfig(non_int_type)
@@ -55,7 +59,8 @@ class SystemDefinition(errors.WithDefErr):
return definition
@property
- def unit_replacements(self) -> ty.Tuple[ty.Tuple[str, str], ...]:
+ def unit_replacements(self) -> tuple[tuple[str, str | None]]:
+ # TODO: check if None can be dropped.
return tuple((el.new_unit_name, el.old_unit_name) for el in self.rules)
def __post_init__(self):
diff --git a/pint/facets/system/objects.py b/pint/facets/system/objects.py
index 829fb5c..69b1c84 100644
--- a/pint/facets/system/objects.py
+++ b/pint/facets/system/objects.py
@@ -9,6 +9,13 @@
from __future__ import annotations
+import numbers
+
+from typing import Any
+from collections.abc import Iterable
+
+from ..._typing import Self
+
from ...babel_names import _babel_systems
from ...compat import babel_parse
from ...util import (
@@ -29,32 +36,28 @@ class System(SharedRegistryObject):
The System belongs to one Registry.
See SystemDefinition for the definition file syntax.
- """
- def __init__(self, name):
- """
- :param name: Name of the group
- :type name: str
- """
+ Parameters
+ ----------
+ name
+ Name of the group.
+ """
+ def __init__(self, name: str):
#: Name of the system
#: :type: str
self.name = name
#: Maps root unit names to a dict indicating the new unit and its exponent.
- #: :type: dict[str, dict[str, number]]]
- self.base_units = {}
+ self.base_units: dict[str, dict[str, numbers.Number]] = {}
#: Derived unit names.
- #: :type: set(str)
- self.derived_units = set()
+ self.derived_units: set[str] = set()
#: Names of the _used_groups in used by this system.
- #: :type: set(str)
- self._used_groups = set()
+ self._used_groups: set[str] = set()
- #: :type: frozenset | None
- self._computed_members = None
+ self._computed_members: frozenset[str] | None = None
# Add this system to the system dictionary
self._REGISTRY._systems[self.name] = self
@@ -62,7 +65,7 @@ class System(SharedRegistryObject):
def __dir__(self):
return list(self.members)
- def __getattr__(self, item):
+ def __getattr__(self, item: str) -> Any:
getattr_maybe_raise(self, item)
u = getattr(self._REGISTRY, self.name + "_" + item, None)
if u is not None:
@@ -93,19 +96,19 @@ class System(SharedRegistryObject):
"""Invalidate computed members in this Group and all parent nodes."""
self._computed_members = None
- def add_groups(self, *group_names):
+ def add_groups(self, *group_names: str) -> None:
"""Add groups to group."""
self._used_groups |= set(group_names)
self.invalidate_members()
- def remove_groups(self, *group_names):
+ def remove_groups(self, *group_names: str) -> None:
"""Remove groups from group."""
self._used_groups -= set(group_names)
self.invalidate_members()
- def format_babel(self, locale):
+ def format_babel(self, locale: str) -> str:
"""translate the name of the system."""
if locale and self.name in _babel_systems:
name = _babel_systems[self.name]
@@ -114,8 +117,12 @@ class System(SharedRegistryObject):
return self.name
@classmethod
- def from_lines(cls, lines, get_root_func, non_int_type=float):
- system_definition = SystemDefinition.from_lines(lines, get_root_func)
+ def from_lines(
+ cls: type[Self], lines: Iterable[str], get_root_func, non_int_type: type = float
+ ) -> Self:
+ # TODO: we changed something here it used to be
+ # system_definition = SystemDefinition.from_lines(lines, get_root_func)
+ system_definition = SystemDefinition.from_lines(lines, non_int_type)
return cls.from_definition(system_definition, get_root_func)
@classmethod
@@ -174,12 +181,12 @@ class System(SharedRegistryObject):
class Lister:
- def __init__(self, d):
+ def __init__(self, d: dict[str, Any]):
self.d = d
- def __dir__(self):
+ def __dir__(self) -> list[str]:
return list(self.d.keys())
- def __getattr__(self, item):
+ def __getattr__(self, item: str) -> Any:
getattr_maybe_raise(self, item)
return self.d[item]
diff --git a/pint/facets/system/registry.py b/pint/facets/system/registry.py
index 527440a..6e0878e 100644
--- a/pint/facets/system/registry.py
+++ b/pint/facets/system/registry.py
@@ -9,7 +9,7 @@
from __future__ import annotations
from numbers import Number
-from typing import TYPE_CHECKING, Dict, FrozenSet, Tuple, Union
+from typing import TYPE_CHECKING
from ... import errors
@@ -19,13 +19,13 @@ if TYPE_CHECKING:
from ..._typing import UnitLike
from ...util import UnitsContainer as UnitsContainerT
from ...util import (
- build_dependent_class,
create_class_with_registry,
to_units_container,
)
from ..group import GroupRegistry
from .definitions import SystemDefinition
from .objects import Lister, System
+from . import objects
class SystemRegistry(GroupRegistry):
@@ -46,24 +46,20 @@ class SystemRegistry(GroupRegistry):
# TODO: Change this to System: System to specify class
# and use introspection to get system class as a way
# to enjoy typing goodies
- _system_class = System
+ System = objects.System
def __init__(self, system=None, **kwargs):
super().__init__(**kwargs)
#: Map system name to system.
#: :type: dict[ str | System]
- self._systems: Dict[str, System] = {}
+ self._systems: dict[str, System] = {}
#: Maps dimensionality (UnitsContainer) to Dimensionality (UnitsContainer)
- self._base_units_cache = dict()
+ self._base_units_cache = {}
self._default_system = system
- def __init_subclass__(cls, **kwargs):
- super().__init_subclass__()
- cls.System = build_dependent_class(cls, "System", "_system_class")
-
def _init_dynamic_classes(self) -> None:
"""Generate subclasses on the fly and attach them to self"""
super()._init_dynamic_classes()
@@ -143,10 +139,10 @@ class SystemRegistry(GroupRegistry):
def get_base_units(
self,
- input_units: Union[UnitLike, Quantity],
+ input_units: UnitLike | Quantity,
check_nonmult: bool = True,
- system: Union[str, System, None] = None,
- ) -> Tuple[Number, Unit]:
+ system: str | System | None = None,
+ ) -> tuple[Number, Unit]:
"""Convert unit or dict of units to the plain units.
If any unit is non multiplicative and check_converter is True,
@@ -183,7 +179,7 @@ class SystemRegistry(GroupRegistry):
self,
input_units: UnitsContainerT,
check_nonmult: bool = True,
- system: Union[str, System, None] = None,
+ system: str | System | None = None,
):
if system is None:
system = self._default_system
@@ -224,7 +220,7 @@ class SystemRegistry(GroupRegistry):
return base_factor, destination_units
- def _get_compatible_units(self, input_units, group_or_system) -> FrozenSet[Unit]:
+ def _get_compatible_units(self, input_units, group_or_system) -> frozenset[Unit]:
if group_or_system is None:
group_or_system = self._default_system
diff --git a/pint/formatting.py b/pint/formatting.py
index f450d5f..880f55b 100644
--- a/pint/formatting.py
+++ b/pint/formatting.py
@@ -13,7 +13,9 @@ from __future__ import annotations
import functools
import re
import warnings
-from typing import Callable, Dict
+from typing import Callable, Any
+from collections.abc import Iterable
+from numbers import Number
from .babel_names import _babel_lengths, _babel_units
from .compat import babel_parse
@@ -21,7 +23,7 @@ from .compat import babel_parse
__JOIN_REG_EXP = re.compile(r"{\d*}")
-def _join(fmt, iterable):
+def _join(fmt: str, iterable: Iterable[Any]):
"""Join an iterable with the format specified in fmt.
The format can be specified in two ways:
@@ -55,7 +57,7 @@ def _join(fmt, iterable):
_PRETTY_EXPONENTS = "⁰¹²³⁴⁵⁶⁷⁸⁹"
-def _pretty_fmt_exponent(num):
+def _pretty_fmt_exponent(num: Number) -> str:
"""Format an number into a pretty printed exponent.
Parameters
@@ -76,7 +78,7 @@ def _pretty_fmt_exponent(num):
#: _FORMATS maps format specifications to the corresponding argument set to
#: formatter().
-_FORMATS: Dict[str, dict] = {
+_FORMATS: dict[str, dict[str, Any]] = {
"P": { # Pretty format.
"as_ratio": True,
"single_denominator": False,
@@ -122,10 +124,10 @@ _FORMATS: Dict[str, dict] = {
}
#: _FORMATTERS maps format names to callables doing the formatting
-_FORMATTERS: Dict[str, Callable] = {}
+_FORMATTERS: dict[str, Callable] = {}
-def register_unit_format(name):
+def register_unit_format(name: str):
"""register a function as a new format for units
The registered function must have a signature of:
@@ -197,9 +199,7 @@ def latex_escape(string):
@register_unit_format("L")
def format_latex(unit, registry, **options):
- preprocessed = {
- r"\mathrm{{{}}}".format(latex_escape(u)): p for u, p in unit.items()
- }
+ preprocessed = {rf"\mathrm{{{latex_escape(u)}}}": p for u, p in unit.items()}
formatted = formatter(
preprocessed.items(),
as_ratio=True,
@@ -270,18 +270,18 @@ def format_compact(unit, registry, **options):
def formatter(
- items,
- as_ratio=True,
- single_denominator=False,
- product_fmt=" * ",
- division_fmt=" / ",
- power_fmt="{} ** {}",
- parentheses_fmt="({0})",
+ items: list[tuple[str, Number]],
+ as_ratio: bool = True,
+ single_denominator: bool = False,
+ product_fmt: str = " * ",
+ division_fmt: str = " / ",
+ power_fmt: str = "{} ** {}",
+ parentheses_fmt: str = "({0})",
exp_call=lambda x: f"{x:n}",
- locale=None,
- babel_length="long",
- babel_plural_form="one",
- sort=True,
+ locale: str | None = None,
+ babel_length: str = "long",
+ babel_plural_form: str = "one",
+ sort: bool = True,
):
"""Format a list of (name, exponent) pairs.
@@ -442,10 +442,10 @@ def siunitx_format_unit(units, registry):
elif power == 3:
return r"\cubed"
else:
- return r"\tothe{{{:d}}}".format(int(power))
+ return rf"\tothe{{{int(power):d}}}"
else:
# limit float powers to 3 decimal places
- return r"\tothe{{{:.3f}}}".format(power).rstrip("0")
+ return rf"\tothe{{{power:.3f}}}".rstrip("0")
lpos = []
lneg = []
@@ -466,9 +466,9 @@ def siunitx_format_unit(units, registry):
if power < 0:
lpick.append(r"\per")
if prefix is not None:
- lpick.append(r"\{}".format(prefix))
- lpick.append(r"\{}".format(unit))
- lpick.append(r"{}".format(_tothe(abs(power))))
+ lpick.append(rf"\{prefix}")
+ lpick.append(rf"\{unit}")
+ lpick.append(rf"{_tothe(abs(power))}")
return "".join(lpos) + "".join(lneg)
@@ -529,8 +529,8 @@ def split_format(spec, default, separate_format_defaults=True):
elif not spec:
mspec, uspec = default_mspec, default_uspec
else:
- mspec = mspec if mspec else default_mspec
- uspec = uspec if uspec else default_uspec
+ mspec = mspec or default_mspec
+ uspec = uspec or default_uspec
return mspec, uspec
diff --git a/pint/matplotlib.py b/pint/matplotlib.py
index ea88c70..25c257b 100644
--- a/pint/matplotlib.py
+++ b/pint/matplotlib.py
@@ -36,15 +36,15 @@ class PintConverter(matplotlib.units.ConversionInterface):
"""Convert :`Quantity` instances for matplotlib to use."""
if iterable(value):
return [self._convert_value(v, unit, axis) for v in value]
- else:
- return self._convert_value(value, unit, axis)
+
+ return self._convert_value(value, unit, axis)
def _convert_value(self, value, unit, axis):
"""Handle converting using attached unit or falling back to axis units."""
if hasattr(value, "units"):
return value.to(unit).magnitude
- else:
- return self._reg.Quantity(value, axis.get_units()).to(unit).magnitude
+
+ return self._reg.Quantity(value, axis.get_units()).to(unit).magnitude
@staticmethod
def axisinfo(unit, axis):
diff --git a/pint/pint_convert.py b/pint/pint_convert.py
index d8d60e8..bf90972 100755
--- a/pint/pint_convert.py
+++ b/pint/pint_convert.py
@@ -11,6 +11,7 @@
from __future__ import annotations
import argparse
+import contextlib
import re
from pint import UnitRegistry
@@ -154,13 +155,13 @@ if args.unc:
),
)
- ureg._root_units_cache = dict()
+ ureg._root_units_cache = {}
ureg._build_cache()
def convert(u_from, u_to=None, unc=None, factor=None):
q = ureg.Quantity(u_from)
- fmt = ".{}g".format(args.prec)
+ fmt = f".{args.prec}g"
if unc:
q = q.plus_minus(unc)
if u_to:
@@ -172,25 +173,23 @@ def convert(u_from, u_to=None, unc=None, factor=None):
nq *= ureg.Quantity(factor).to_base_units()
prec_unc = use_unc(nq.magnitude, fmt, args.prec_unc)
if prec_unc > 0:
- fmt = ".{}uS".format(prec_unc)
+ fmt = f".{prec_unc}uS"
else:
- try:
+ with contextlib.suppress(Exception):
nq = nq.magnitude.n * nq.units
- except Exception:
- pass
+
fmt = "{:" + fmt + "} {:~P}"
print(("{:} = " + fmt).format(q, nq.magnitude, nq.units))
def use_unc(num, fmt, prec_unc):
unc = 0
- try:
+ with contextlib.suppress(Exception):
if isinstance(num, uncertainties.UFloat):
full = ("{:" + fmt + "}").format(num)
unc = re.search(r"\+/-[0.]*([\d.]*)", full).group(1)
unc = len(unc.replace(".", ""))
- except Exception:
- pass
+
return max(0, min(prec_unc, unc))
diff --git a/pint/pint_eval.py b/pint/pint_eval.py
index 2054260..d476eae 100644
--- a/pint/pint_eval.py
+++ b/pint/pint_eval.py
@@ -11,7 +11,9 @@ from __future__ import annotations
import operator
import token as tokenlib
-import tokenize
+from tokenize import TokenInfo
+
+from typing import Any
from .errors import DefinitionSyntaxError
@@ -30,7 +32,7 @@ _OP_PRIORITY = {
}
-def _power(left, right):
+def _power(left: Any, right: Any) -> Any:
from . import Quantity
from .compat import is_duck_array
@@ -45,7 +47,19 @@ def _power(left, right):
return operator.pow(left, right)
-_BINARY_OPERATOR_MAP = {
+import typing
+
+UnaryOpT = typing.Callable[
+ [
+ Any,
+ ],
+ Any,
+]
+BinaryOpT = typing.Callable[[Any, Any], Any]
+
+_UNARY_OPERATOR_MAP: dict[str, UnaryOpT] = {"+": lambda x: x, "-": lambda x: x * -1}
+
+_BINARY_OPERATOR_MAP: dict[str, BinaryOpT] = {
"**": _power,
"*": operator.mul,
"": operator.mul, # operator for implicit ops
@@ -56,8 +70,6 @@ _BINARY_OPERATOR_MAP = {
"//": operator.floordiv,
}
-_UNARY_OPERATOR_MAP = {"+": lambda x: x, "-": lambda x: x * -1}
-
class EvalTreeNode:
"""Single node within an evaluation tree
@@ -68,25 +80,43 @@ class EvalTreeNode:
left --> single value
"""
- def __init__(self, left, operator=None, right=None):
+ def __init__(
+ self,
+ left: EvalTreeNode | TokenInfo,
+ operator: TokenInfo | None = None,
+ right: EvalTreeNode | None = None,
+ ):
self.left = left
self.operator = operator
self.right = right
- def to_string(self):
+ def to_string(self) -> str:
# For debugging purposes
if self.right:
+ assert isinstance(self.left, EvalTreeNode), "self.left not EvalTreeNode (1)"
comps = [self.left.to_string()]
if self.operator:
- comps.append(self.operator[1])
+ comps.append(self.operator.string)
comps.append(self.right.to_string())
elif self.operator:
- comps = [self.operator[1], self.left.to_string()]
+ assert isinstance(self.left, EvalTreeNode), "self.left not EvalTreeNode (2)"
+ comps = [self.operator.string, self.left.to_string()]
else:
- return self.left[1]
+ assert isinstance(self.left, TokenInfo), "self.left not TokenInfo (1)"
+ return self.left.string
return "(%s)" % " ".join(comps)
- def evaluate(self, define_op, bin_op=None, un_op=None):
+ def evaluate(
+ self,
+ define_op: typing.Callable[
+ [
+ Any,
+ ],
+ Any,
+ ],
+ bin_op: dict[str, BinaryOpT] | None = None,
+ un_op: dict[str, UnaryOpT] | None = None,
+ ):
"""Evaluate node.
Parameters
@@ -107,33 +137,38 @@ class EvalTreeNode:
un_op = un_op or _UNARY_OPERATOR_MAP
if self.right:
+ assert isinstance(self.left, EvalTreeNode), "self.left not EvalTreeNode (3)"
# binary or implicit operator
- op_text = self.operator[1] if self.operator else ""
+ op_text = self.operator.string if self.operator else ""
if op_text not in bin_op:
- raise DefinitionSyntaxError('missing binary operator "%s"' % op_text)
- left = self.left.evaluate(define_op, bin_op, un_op)
- return bin_op[op_text](left, self.right.evaluate(define_op, bin_op, un_op))
+ raise DefinitionSyntaxError(f"missing binary operator '{op_text}'")
+
+ return bin_op[op_text](
+ self.left.evaluate(define_op, bin_op, un_op),
+ self.right.evaluate(define_op, bin_op, un_op),
+ )
elif self.operator:
+ assert isinstance(self.left, EvalTreeNode), "self.left not EvalTreeNode (4)"
# unary operator
- op_text = self.operator[1]
+ op_text = self.operator.string
if op_text not in un_op:
- raise DefinitionSyntaxError('missing unary operator "%s"' % op_text)
+ raise DefinitionSyntaxError(f"missing unary operator '{op_text}'")
return un_op[op_text](self.left.evaluate(define_op, bin_op, un_op))
- else:
- # single value
- return define_op(self.left)
+ # single value
+ return define_op(self.left)
-from typing import Iterable
+from collections.abc import Iterable
-def build_eval_tree(
- tokens: Iterable[tokenize.TokenInfo],
- op_priority=None,
- index=0,
- depth=0,
- prev_op=None,
-) -> tuple[EvalTreeNode | None, int] | EvalTreeNode:
+
+def _build_eval_tree(
+ tokens: list[TokenInfo],
+ op_priority: dict[str, int],
+ index: int = 0,
+ depth: int = 0,
+ prev_op: str = "<none>",
+) -> tuple[EvalTreeNode, int]:
"""Build an evaluation tree from a set of tokens.
Params:
@@ -153,14 +188,12 @@ def build_eval_tree(
5) Combine left side, operator, and right side into a new left side
6) Go back to step #2
- """
+ Raises
+ ------
+ DefinitionSyntaxError
+ If there is a syntax error.
- if op_priority is None:
- op_priority = _OP_PRIORITY
-
- if depth == 0 and prev_op is None:
- # ensure tokens is list so we can access by index
- tokens = list(tokens)
+ """
result = None
@@ -171,19 +204,21 @@ def build_eval_tree(
if token_type == tokenlib.OP:
if token_text == ")":
- if prev_op is None:
+ if prev_op == "<none>":
raise DefinitionSyntaxError(
- "unopened parentheses in tokens: %s" % current_token
+ f"unopened parentheses in tokens: {current_token}"
)
elif prev_op == "(":
# close parenthetical group
+ assert result is not None
return result, index
else:
# parenthetical group ending, but we need to close sub-operations within group
+ assert result is not None
return result, index - 1
elif token_text == "(":
# gather parenthetical group
- right, index = build_eval_tree(
+ right, index = _build_eval_tree(
tokens, op_priority, index + 1, 0, token_text
)
if not tokens[index][1] == ")":
@@ -204,11 +239,11 @@ def build_eval_tree(
# (2 * 3 / 4) --> ((2 * 3) / 4)
if op_priority[token_text] <= op_priority.get(
prev_op, -1
- ) and token_text not in ["**", "^"]:
+ ) and token_text not in ("**", "^"):
# previous operator is higher priority, so end previous binary op
return result, index - 1
# get right side of binary op
- right, index = build_eval_tree(
+ right, index = _build_eval_tree(
tokens, op_priority, index + 1, depth + 1, token_text
)
result = EvalTreeNode(
@@ -216,18 +251,18 @@ def build_eval_tree(
)
else:
# unary operator
- right, index = build_eval_tree(
+ right, index = _build_eval_tree(
tokens, op_priority, index + 1, depth + 1, "unary"
)
result = EvalTreeNode(left=right, operator=current_token)
- elif token_type == tokenlib.NUMBER or token_type == tokenlib.NAME:
+ elif token_type in (tokenlib.NUMBER, tokenlib.NAME):
if result:
# tokens with an implicit operation i.e. "1 kg"
if op_priority[""] <= op_priority.get(prev_op, -1):
# previous operator is higher priority than implicit, so end
# previous binary op
return result, index - 1
- right, index = build_eval_tree(
+ right, index = _build_eval_tree(
tokens, op_priority, index, depth + 1, ""
)
result = EvalTreeNode(left=result, right=right)
@@ -240,13 +275,57 @@ def build_eval_tree(
raise DefinitionSyntaxError("unclosed parentheses in tokens")
if depth > 0 or prev_op:
# have to close recursion
+ assert result is not None
return result, index
else:
# recursion all closed, so just return the final result
- return result
+ assert result is not None
+ return result, -1
if index + 1 >= len(tokens):
# should hit ENDMARKER before this ever happens
raise DefinitionSyntaxError("unexpected end to tokens")
index += 1
+
+
+def build_eval_tree(
+ tokens: Iterable[TokenInfo],
+ op_priority: dict[str, int] | None = None,
+) -> EvalTreeNode:
+ """Build an evaluation tree from a set of tokens.
+
+ Params:
+ Index, depth, and prev_op used recursively, so don't touch.
+ Tokens is an iterable of tokens from an expression to be evaluated.
+
+ Transform the tokens from an expression into a recursive parse tree, following order
+ of operations. Operations can include binary ops (3 + 4), implicit ops (3 kg), or
+ unary ops (-1).
+
+ General Strategy:
+ 1) Get left side of operator
+ 2) If no tokens left, return final result
+ 3) Get operator
+ 4) Use recursion to create tree starting at token on right side of operator (start at step #1)
+ 4.1) If recursive call encounters an operator with lower or equal priority to step #2, exit recursion
+ 5) Combine left side, operator, and right side into a new left side
+ 6) Go back to step #2
+
+ Raises
+ ------
+ DefinitionSyntaxError
+ If there is a syntax error.
+
+ """
+
+ if op_priority is None:
+ op_priority = _OP_PRIORITY
+
+ if not isinstance(tokens, list):
+ # ensure tokens is list so we can access by index
+ tokens = list(tokens)
+
+ result, _ = _build_eval_tree(tokens, op_priority, 0, 0)
+
+ return result
diff --git a/pint/registry.py b/pint/registry.py
index 29d5c89..474eb77 100644
--- a/pint/registry.py
+++ b/pint/registry.py
@@ -27,6 +27,35 @@ from .facets import (
from .util import logger, pi_theorem
+# To build the Quantity and Unit classes
+# we follow the UnitRegistry bases
+# but
+
+
+class Quantity(
+ # SystemRegistry.Quantity,
+ # ContextRegistry.Quantity,
+ DaskRegistry.Quantity,
+ NumpyRegistry.Quantity,
+ MeasurementRegistry.Quantity,
+ FormattingRegistry.Quantity,
+ NonMultiplicativeRegistry.Quantity,
+):
+ pass
+
+
+class Unit(
+ # SystemRegistry.Unit,
+ # ContextRegistry.Unit,
+ # DaskRegistry.Unit,
+ NumpyRegistry.Unit,
+ # MeasurementRegistry.Unit,
+ FormattingRegistry.Unit,
+ NonMultiplicativeRegistry.Unit,
+):
+ pass
+
+
class UnitRegistry(
SystemRegistry,
ContextRegistry,
@@ -72,6 +101,9 @@ class UnitRegistry(
If None, the cache is disabled. (default)
"""
+ Quantity = Quantity
+ Unit = Unit
+
def __init__(
self,
filename="",
diff --git a/pint/registry_helpers.py b/pint/registry_helpers.py
index 07b00ff..1f28036 100644
--- a/pint/registry_helpers.py
+++ b/pint/registry_helpers.py
@@ -13,7 +13,8 @@ from __future__ import annotations
import functools
from inspect import signature
from itertools import zip_longest
-from typing import TYPE_CHECKING, Callable, Iterable, TypeVar, Union
+from typing import TYPE_CHECKING, Callable, TypeVar
+from collections.abc import Iterable
from ._typing import F
from .errors import DimensionalityError
@@ -184,9 +185,9 @@ def _apply_defaults(func, args, kwargs):
def wraps(
- ureg: "UnitRegistry",
- ret: Union[str, "Unit", Iterable[Union[str, "Unit", None]], None],
- args: Union[str, "Unit", Iterable[Union[str, "Unit", None]], None],
+ ureg: UnitRegistry,
+ ret: str | Unit | Iterable[str | Unit | None] | None,
+ args: str | Unit | Iterable[str | Unit | None] | None,
strict: bool = True,
) -> Callable[[Callable[..., T]], Callable[..., Quantity[T]]]:
"""Wraps a function to become pint-aware.
@@ -300,7 +301,7 @@ def wraps(
def check(
- ureg: "UnitRegistry", *args: Union[str, UnitsContainer, "Unit", None]
+ ureg: UnitRegistry, *args: str | UnitsContainer | Unit | None
) -> Callable[[F], F]:
"""Decorator to for quantity type checking for function inputs.
diff --git a/pint/testing.py b/pint/testing.py
index 1c458f5..8e4f15f 100644
--- a/pint/testing.py
+++ b/pint/testing.py
@@ -36,10 +36,10 @@ def _get_comparable_magnitudes(first, second, msg):
def assert_equal(first, second, msg=None):
if msg is None:
- msg = "Comparing %r and %r. " % (first, second)
+ msg = f"Comparing {first!r} and {second!r}. "
m1, m2 = _get_comparable_magnitudes(first, second, msg)
- msg += " (Converted to %r and %r): Magnitudes are not equal" % (m1, m2)
+ msg += f" (Converted to {m1!r} and {m2!r}): Magnitudes are not equal"
if isinstance(m1, ndarray) or isinstance(m2, ndarray):
np.testing.assert_array_equal(m1, m2, err_msg=msg)
@@ -60,15 +60,15 @@ def assert_equal(first, second, msg=None):
def assert_allclose(first, second, rtol=1e-07, atol=0, msg=None):
if msg is None:
try:
- msg = "Comparing %r and %r. " % (first, second)
+ msg = f"Comparing {first!r} and {second!r}. "
except TypeError:
try:
- msg = "Comparing %s and %s. " % (first, second)
+ msg = f"Comparing {first} and {second}. "
except Exception:
msg = "Comparing"
m1, m2 = _get_comparable_magnitudes(first, second, msg)
- msg += " (Converted to %r and %r)" % (m1, m2)
+ msg += f" (Converted to {m1!r} and {m2!r})"
if isinstance(m1, ndarray) or isinstance(m2, ndarray):
np.testing.assert_allclose(m1, m2, rtol=rtol, atol=atol, err_msg=msg)
diff --git a/pint/testsuite/__init__.py b/pint/testsuite/__init__.py
index 8c0cd09..35b0d91 100644
--- a/pint/testsuite/__init__.py
+++ b/pint/testsuite/__init__.py
@@ -3,7 +3,8 @@ import math
import os
import unittest
import warnings
-from contextlib import contextmanager
+import contextlib
+import pathlib
from pint import UnitRegistry
from pint.testsuite.helpers import PintOutputChecker
@@ -25,7 +26,7 @@ class QuantityTestCase:
cls.U_ = None
-@contextmanager
+@contextlib.contextmanager
def assert_no_warnings():
with warnings.catch_warnings():
warnings.simplefilter("error")
@@ -40,13 +41,12 @@ def testsuite():
# TESTING THE DOCUMENTATION requires pyyaml, serialize, numpy and uncertainties
if HAS_NUMPY and HAS_UNCERTAINTIES:
- try:
+ with contextlib.suppress(ImportError):
import serialize # noqa: F401
import yaml # noqa: F401
add_docs(suite)
- except ImportError:
- pass
+
return suite
@@ -98,7 +98,7 @@ def add_docs(suite):
"""
docpath = os.path.join(os.path.dirname(__file__), "..", "..", "docs")
docpath = os.path.abspath(docpath)
- if os.path.exists(docpath):
+ if pathlib.Path(docpath).exists():
checker = PintOutputChecker()
for name in (name for name in os.listdir(docpath) if name.endswith(".rst")):
file = os.path.join(docpath, name)
diff --git a/pint/testsuite/helpers.py b/pint/testsuite/helpers.py
index 4c560fb..191f4c3 100644
--- a/pint/testsuite/helpers.py
+++ b/pint/testsuite/helpers.py
@@ -1,6 +1,7 @@
import doctest
import pickle
import re
+import contextlib
import pytest
from packaging.version import parse as version_parse
@@ -41,14 +42,12 @@ class PintOutputChecker(doctest.OutputChecker):
if check:
return check
- try:
+ with contextlib.suppress(Exception):
if eval(want) == eval(got):
return True
- except Exception:
- pass
for regex in (_q_re, _sq_re):
- try:
+ with contextlib.suppress(Exception):
parsed_got = regex.match(got.replace(r"\\", "")).groupdict()
parsed_want = regex.match(want.replace(r"\\", "")).groupdict()
@@ -62,12 +61,10 @@ class PintOutputChecker(doctest.OutputChecker):
return False
return True
- except Exception:
- pass
cnt = 0
for regex in (_unit_re,):
- try:
+ with contextlib.suppress(Exception):
parsed_got, tmp = regex.subn("\1", got)
cnt += tmp
parsed_want, temp = regex.subn("\1", want)
@@ -76,9 +73,6 @@ class PintOutputChecker(doctest.OutputChecker):
if parsed_got == parsed_want:
return True
- except Exception:
- pass
-
if cnt:
# If there was any replacement, we try again the previous methods.
return self.check_output(parsed_want, parsed_got, optionflags)
diff --git a/pint/testsuite/test_babel.py b/pint/testsuite/test_babel.py
index 5c32879..7842d54 100644
--- a/pint/testsuite/test_babel.py
+++ b/pint/testsuite/test_babel.py
@@ -84,16 +84,16 @@ def test_str(func_registry):
s = "24.0 meter"
assert str(d) == s
assert "%s" % d == s
- assert "{}".format(d) == s
+ assert f"{d}" == s
ureg.set_fmt_locale("fr_FR")
s = "24.0 mètres"
assert str(d) == s
assert "%s" % d == s
- assert "{}".format(d) == s
+ assert f"{d}" == s
ureg.set_fmt_locale(None)
s = "24.0 meter"
assert str(d) == s
assert "%s" % d == s
- assert "{}".format(d) == s
+ assert f"{d}" == s
diff --git a/pint/testsuite/test_compat_downcast.py b/pint/testsuite/test_compat_downcast.py
index ebb5907..ed43e94 100644
--- a/pint/testsuite/test_compat_downcast.py
+++ b/pint/testsuite/test_compat_downcast.py
@@ -1,3 +1,4 @@
+import operator
import pytest
from pint import UnitRegistry
@@ -37,7 +38,7 @@ def q_base(local_registry):
# Define identity function for use in tests
-def identity(ureg, x):
+def id_matrix(ureg, x):
return x
@@ -62,17 +63,17 @@ def array(request):
@pytest.mark.parametrize(
"op, magnitude_op, unit_op",
[
- pytest.param(identity, identity, identity, id="identity"),
+ pytest.param(id_matrix, id_matrix, id_matrix, id="identity"),
pytest.param(
lambda ureg, x: x + 1 * ureg.m,
lambda ureg, x: x + 1,
- identity,
+ id_matrix,
id="addition",
),
pytest.param(
lambda ureg, x: x - 20 * ureg.cm,
lambda ureg, x: x - 0.2,
- identity,
+ id_matrix,
id="subtraction",
),
pytest.param(
@@ -83,7 +84,7 @@ def array(request):
),
pytest.param(
lambda ureg, x: x / (1 * ureg.s),
- identity,
+ id_matrix,
lambda ureg, u: u / ureg.s,
id="division",
),
@@ -93,17 +94,17 @@ def array(request):
WR(lambda u: u**2),
id="square",
),
- pytest.param(WR(lambda x: x.T), WR(lambda x: x.T), identity, id="transpose"),
- pytest.param(WR(np.mean), WR(np.mean), identity, id="mean ufunc"),
- pytest.param(WR(np.sum), WR(np.sum), identity, id="sum ufunc"),
+ pytest.param(WR(lambda x: x.T), WR(lambda x: x.T), id_matrix, id="transpose"),
+ pytest.param(WR(np.mean), WR(np.mean), id_matrix, id="mean ufunc"),
+ pytest.param(WR(np.sum), WR(np.sum), id_matrix, id="sum ufunc"),
pytest.param(WR(np.sqrt), WR(np.sqrt), WR(lambda u: u**0.5), id="sqrt ufunc"),
pytest.param(
WR(lambda x: np.reshape(x, (25,))),
WR(lambda x: np.reshape(x, (25,))),
- identity,
+ id_matrix,
id="reshape function",
),
- pytest.param(WR(np.amax), WR(np.amax), identity, id="amax function"),
+ pytest.param(WR(np.amax), WR(np.amax), id_matrix, id="amax function"),
],
)
def test_univariate_op_consistency(
@@ -121,10 +122,8 @@ def test_univariate_op_consistency(
@pytest.mark.parametrize(
"op, unit",
[
- pytest.param(
- lambda x, y: x * y, lambda ureg: ureg("kg m"), id="multiplication"
- ),
- pytest.param(lambda x, y: x / y, lambda ureg: ureg("m / kg"), id="division"),
+ pytest.param(operator.mul, lambda ureg: ureg("kg m"), id="multiplication"),
+ pytest.param(operator.truediv, lambda ureg: ureg("m / kg"), id="division"),
pytest.param(np.multiply, lambda ureg: ureg("kg m"), id="multiply ufunc"),
],
)
@@ -143,11 +142,11 @@ def test_bivariate_op_consistency(local_registry, q_base, op, unit, array):
"op",
[
pytest.param(
- WR2(lambda a, u: a * u),
+ WR2(operator.mul),
id="array-first",
marks=pytest.mark.xfail(reason="upstream issue numpy/numpy#15200"),
),
- pytest.param(WR2(lambda a, u: u * a), id="unit-first"),
+ pytest.param(WR2(operator.mul), id="unit-first"),
],
)
@pytest.mark.parametrize(
diff --git a/pint/testsuite/test_compat_upcast.py b/pint/testsuite/test_compat_upcast.py
index ad267c1..c8266f7 100644
--- a/pint/testsuite/test_compat_upcast.py
+++ b/pint/testsuite/test_compat_upcast.py
@@ -1,3 +1,4 @@
+import operator
import pytest
# Conditionally import NumPy and any upcast type libraries
@@ -49,9 +50,9 @@ def test_quantification(module_registry, ds):
@pytest.mark.parametrize(
"op",
[
- lambda x, y: x + y,
+ operator.add,
lambda x, y: x - (-y),
- lambda x, y: x * y,
+ operator.mul,
lambda x, y: x / (y**-1),
],
)
@@ -126,9 +127,7 @@ def test_array_function_deferral(da, module_registry):
upper = 3 * module_registry.m
args = (da, lower, upper)
assert (
- lower.__array_function__(
- np.clip, tuple(set(type(arg) for arg in args)), args, {}
- )
+ lower.__array_function__(np.clip, tuple({type(arg) for arg in args}), args, {})
is NotImplemented
)
diff --git a/pint/testsuite/test_contexts.py b/pint/testsuite/test_contexts.py
index c7551e4..ea6525d 100644
--- a/pint/testsuite/test_contexts.py
+++ b/pint/testsuite/test_contexts.py
@@ -683,7 +683,7 @@ class TestDefinedContexts:
)
p = find_shortest_path(ureg._active_ctx.graph, da, db)
assert p
- msg = "{} <-> {}".format(a, b)
+ msg = f"{a} <-> {b}"
# assertAlmostEqualRelError converts second to first
helpers.assert_quantity_almost_equal(b, a, rtol=0.01, msg=msg)
@@ -705,7 +705,7 @@ class TestDefinedContexts:
da, db = Context.__keytransform__(a.dimensionality, b.dimensionality)
p = find_shortest_path(ureg._active_ctx.graph, da, db)
assert p
- msg = "{} <-> {}".format(a, b)
+ msg = f"{a} <-> {b}"
helpers.assert_quantity_almost_equal(b, a, rtol=0.01, msg=msg)
# Check RKM <-> cN/tex conversion
diff --git a/pint/testsuite/test_converters.py b/pint/testsuite/test_converters.py
index 62ffdb7..71a076f 100644
--- a/pint/testsuite/test_converters.py
+++ b/pint/testsuite/test_converters.py
@@ -69,7 +69,7 @@ class TestConverter:
@helpers.requires_numpy
def test_log_converter_inplace(self):
- arb_value = 3.14
+ arb_value = 3.13
c = LogarithmicConverter(scale=1, logbase=10, logfactor=1)
from_to = lambda value, inplace: c.from_reference(
diff --git a/pint/testsuite/test_dask.py b/pint/testsuite/test_dask.py
index f4dee6a..0e6a1cf 100644
--- a/pint/testsuite/test_dask.py
+++ b/pint/testsuite/test_dask.py
@@ -1,5 +1,6 @@
import importlib
-import os
+
+import pathlib
import pytest
@@ -135,8 +136,8 @@ def test_visualize(local_registry, dask_array):
assert res is None
# These commands only work on Unix and Windows
- assert os.path.exists("mydask.png")
- os.remove("mydask.png")
+ assert pathlib.Path("mydask.png").exists()
+ pathlib.Path("mydask.png").unlink()
def test_compute_persist_equivalent(local_registry, dask_array, numpy_array):
diff --git a/pint/testsuite/test_definitions.py b/pint/testsuite/test_definitions.py
index 2618c6e..69a337d 100644
--- a/pint/testsuite/test_definitions.py
+++ b/pint/testsuite/test_definitions.py
@@ -1,5 +1,7 @@
import pytest
+import math
+
from pint.definitions import Definition
from pint.errors import DefinitionSyntaxError
from pint.facets.nonmultiplicative.definitions import (
@@ -81,7 +83,7 @@ class TestDefinition:
assert x.reference == UnitsContainer(kelvin=1)
x = Definition.from_string(
- "turn = 6.28 * radian = _ = revolution = = cycle = _"
+ f"turn = {math.tau} * radian = _ = revolution = = cycle = _"
)
assert isinstance(x, UnitDefinition)
assert x.name == "turn"
@@ -89,7 +91,7 @@ class TestDefinition:
assert x.symbol == "turn"
assert not x.is_base
assert isinstance(x.converter, ScaleConverter)
- assert x.converter.scale == 6.28
+ assert x.converter.scale == math.tau
assert x.reference == UnitsContainer(radian=1)
with pytest.raises(ValueError):
@@ -136,7 +138,7 @@ class TestDefinition:
assert x.converter.logfactor == 1
assert x.reference == UnitsContainer()
- eulersnumber = 2.71828182845904523536028747135266249775724709369995
+ eulersnumber = math.e
x = Definition.from_string(
"neper = 1 ; logbase: %1.50f; logfactor: 0.5 = Np" % eulersnumber
)
diff --git a/pint/testsuite/test_errors.py b/pint/testsuite/test_errors.py
index 6a42eec..a045f6e 100644
--- a/pint/testsuite/test_errors.py
+++ b/pint/testsuite/test_errors.py
@@ -116,7 +116,7 @@ class TestErrors:
q2 = ureg.Quantity("1 bar")
for protocol in range(pickle.HIGHEST_PROTOCOL + 1):
- for ex in [
+ for ex in (
DefinitionSyntaxError("foo"),
RedefinitionError("foo", "bar"),
UndefinedUnitError("meter"),
@@ -125,7 +125,7 @@ class TestErrors:
Quantity("1 kg")._units, Quantity("1 s")._units
),
OffsetUnitCalculusError(q1._units, q2._units),
- ]:
+ ):
with subtests.test(protocol=protocol, etype=type(ex)):
pik = pickle.dumps(ureg.Quantity("1 foo"), protocol)
with pytest.raises(UndefinedUnitError):
diff --git a/pint/testsuite/test_formatter.py b/pint/testsuite/test_formatter.py
index 9e362fc..5a51a0a 100644
--- a/pint/testsuite/test_formatter.py
+++ b/pint/testsuite/test_formatter.py
@@ -5,13 +5,13 @@ from pint import formatting as fmt
class TestFormatter:
def test_join(self):
- for empty in (tuple(), []):
+ for empty in ((), []):
assert fmt._join("s", empty) == ""
assert fmt._join("*", "1 2 3".split()) == "1*2*3"
assert fmt._join("{0}*{1}", "1 2 3".split()) == "1*2*3"
def test_formatter(self):
- assert fmt.formatter(dict().items()) == ""
+ assert fmt.formatter({}.items()) == ""
assert fmt.formatter(dict(meter=1).items()) == "meter"
assert fmt.formatter(dict(meter=-1).items()) == "1 / meter"
assert fmt.formatter(dict(meter=-1).items(), as_ratio=False) == "meter ** -1"
diff --git a/pint/testsuite/test_infer_base_unit.py b/pint/testsuite/test_infer_base_unit.py
index f2605c6..9a27362 100644
--- a/pint/testsuite/test_infer_base_unit.py
+++ b/pint/testsuite/test_infer_base_unit.py
@@ -34,9 +34,9 @@ class TestInferBaseUnit:
ureg = UnitRegistry(non_int_type=Decimal)
QD = ureg.Quantity
- ibu_d = infer_base_unit(QD(Decimal("1"), "millimeter * nanometer"))
+ ibu_d = infer_base_unit(QD(Decimal(1), "millimeter * nanometer"))
- assert ibu_d == QD(Decimal("1"), "meter**2").units
+ assert ibu_d == QD(Decimal(1), "meter**2").units
assert all(isinstance(v, Decimal) for v in ibu_d.values())
@@ -69,9 +69,9 @@ class TestInferBaseUnit:
Q = ureg.Quantity
r = (
Q(Decimal("1000000000.0"), "m")
- * Q(Decimal("1"), "mm")
- / Q(Decimal("1"), "s")
- / Q(Decimal("1"), "ms")
+ * Q(Decimal(1), "mm")
+ / Q(Decimal(1), "s")
+ / Q(Decimal(1), "ms")
)
compact_r = r.to_compact()
expected = Q(Decimal("1000.0"), "kilometer**2 / second**2")
diff --git a/pint/testsuite/test_issues.py b/pint/testsuite/test_issues.py
index 8517bd9..9540814 100644
--- a/pint/testsuite/test_issues.py
+++ b/pint/testsuite/test_issues.py
@@ -445,10 +445,10 @@ class TestIssues(QuantityTestCase):
def test_issue354_356_370(self, module_registry):
assert (
- "{:~}".format(1 * module_registry.second / module_registry.millisecond)
+ f"{1 * module_registry.second / module_registry.millisecond:~}"
== "1.0 s / ms"
)
- assert "{:~}".format(1 * module_registry.count) == "1 count"
+ assert f"{1 * module_registry.count:~}" == "1 count"
assert "{:~}".format(1 * module_registry("MiB")) == "1 MiB"
def test_issue468(self, module_registry):
diff --git a/pint/testsuite/test_log_units.py b/pint/testsuite/test_log_units.py
index 2a048f6..3d1c905 100644
--- a/pint/testsuite/test_log_units.py
+++ b/pint/testsuite/test_log_units.py
@@ -56,7 +56,7 @@ class TestLogarithmicQuantity(QuantityTestCase):
# ## Test dB to dB units octave - decade
# 1 decade = log2(10) octave
helpers.assert_quantity_almost_equal(
- self.Q_(1.0, "decade"), self.Q_(math.log(10, 2), "octave")
+ self.Q_(1.0, "decade"), self.Q_(math.log2(10), "octave")
)
# ## Test dB to dB units dBm - dBu
# 0 dBm = 1mW = 1e3 uW = 30 dBu
diff --git a/pint/testsuite/test_measurement.py b/pint/testsuite/test_measurement.py
index b78ca0e..9de2762 100644
--- a/pint/testsuite/test_measurement.py
+++ b/pint/testsuite/test_measurement.py
@@ -178,7 +178,7 @@ class TestMeasurement(QuantityTestCase):
):
with subtests.test(spec):
self.ureg.default_format = spec
- assert "{}".format(m) == result
+ assert f"{m}" == result
def test_raise_build(self):
v, u = self.Q_(1.0, "s"), self.Q_(0.1, "s")
diff --git a/pint/testsuite/test_non_int.py b/pint/testsuite/test_non_int.py
index 66637e1..5a74a99 100644
--- a/pint/testsuite/test_non_int.py
+++ b/pint/testsuite/test_non_int.py
@@ -740,10 +740,10 @@ class _TestQuantityBasicMath(NonIntTypeTestCase):
zy = self.Q_(fun(y.magnitude), "meter")
rx = fun(x)
ry = fun(y)
- assert rx == zx, "while testing {0}".format(fun)
- assert ry == zy, "while testing {0}".format(fun)
- assert rx is not zx, "while testing {0}".format(fun)
- assert ry is not zy, "while testing {0}".format(fun)
+ assert rx == zx, f"while testing {fun}"
+ assert ry == zy, f"while testing {fun}"
+ assert rx is not zx, f"while testing {fun}"
+ assert ry is not zy, f"while testing {fun}"
def test_quantity_float_complex(self):
x = self.QP_("-4.2", None)
@@ -1093,7 +1093,7 @@ class _TestOffsetUnitMath(NonIntTypeTestCase):
else:
in1, in2 = self.kwargs["non_int_type"](in1), self.QP_(*in2)
input_tuple = in1, in2 # update input_tuple for better tracebacks
- expected_copy = expected_output[:]
+ expected_copy = expected_output.copy()
for i, mode in enumerate([False, True]):
self.ureg.autoconvert_offset_to_baseunit = mode
if expected_copy[i] == "error":
@@ -1130,14 +1130,14 @@ class _TestOffsetUnitMath(NonIntTypeTestCase):
def test_exponentiation(self, input_tuple, expected_output):
self.ureg.default_as_delta = False
in1, in2 = input_tuple
- if type(in1) is tuple and type(in2) is tuple:
+ if type(in1) is type(in2) is tuple:
in1, in2 = self.QP_(*in1), self.QP_(*in2)
elif type(in1) is not tuple and type(in2) is tuple:
in1, in2 = self.kwargs["non_int_type"](in1), self.QP_(*in2)
else:
in1, in2 = self.QP_(*in1), self.kwargs["non_int_type"](in2)
input_tuple = in1, in2
- expected_copy = expected_output[:]
+ expected_copy = expected_output.copy()
for i, mode in enumerate([False, True]):
self.ureg.autoconvert_offset_to_baseunit = mode
if expected_copy[i] == "error":
diff --git a/pint/testsuite/test_numpy.py b/pint/testsuite/test_numpy.py
index 1e0b928..0e96c77 100644
--- a/pint/testsuite/test_numpy.py
+++ b/pint/testsuite/test_numpy.py
@@ -303,7 +303,7 @@ class TestNumpyMathematicalFunctions(TestNumpyMethods):
@helpers.requires_array_function_protocol()
def test_fix(self):
- helpers.assert_quantity_equal(np.fix(3.14 * self.ureg.m), 3.0 * self.ureg.m)
+ helpers.assert_quantity_equal(np.fix(3.13 * self.ureg.m), 3.0 * self.ureg.m)
helpers.assert_quantity_equal(np.fix(3.0 * self.ureg.m), 3.0 * self.ureg.m)
helpers.assert_quantity_equal(
np.fix([2.1, 2.9, -2.1, -2.9] * self.ureg.m),
@@ -505,7 +505,7 @@ class TestNumpyMathematicalFunctions(TestNumpyMethods):
arr = np.array(range(3), dtype=float)
q = self.Q_(arr, "meter")
- for op_ in [op.pow, op.ipow, np.power]:
+ for op_ in (op.pow, op.ipow, np.power):
q_cp = copy.copy(q)
with pytest.raises(DimensionalityError):
op_(2.0, q_cp)
diff --git a/pint/testsuite/test_quantity.py b/pint/testsuite/test_quantity.py
index 8fb712a..45b163d 100644
--- a/pint/testsuite/test_quantity.py
+++ b/pint/testsuite/test_quantity.py
@@ -393,7 +393,7 @@ class TestQuantity(QuantityTestCase):
temp = (Q_(" 1 lbf*m")).to_preferred(preferred_units)
# would prefer this to be repeatable, but mip doesn't guarantee that currently
- assert temp.units in [ureg.W * ureg.s, ureg.ft * ureg.lbf]
+ assert temp.units in (ureg.W * ureg.s, ureg.ft * ureg.lbf)
temp = Q_("1 kg").to_preferred(preferred_units)
assert temp.units == ureg.slug
@@ -1050,10 +1050,10 @@ class TestQuantityBasicMath(QuantityTestCase):
zy = self.Q_(fun(y.magnitude), "meter")
rx = fun(x)
ry = fun(y)
- assert rx == zx, "while testing {0}".format(fun)
- assert ry == zy, "while testing {0}".format(fun)
- assert rx is not zx, "while testing {0}".format(fun)
- assert ry is not zy, "while testing {0}".format(fun)
+ assert rx == zx, f"while testing {fun}"
+ assert ry == zy, f"while testing {fun}"
+ assert rx is not zx, f"while testing {fun}"
+ assert ry is not zy, f"while testing {fun}"
def test_quantity_float_complex(self):
x = self.Q_(-4.2, None)
@@ -1661,7 +1661,7 @@ class TestOffsetUnitMath(QuantityTestCase):
else:
in1, in2 = in1, self.Q_(*in2)
input_tuple = in1, in2 # update input_tuple for better tracebacks
- expected_copy = expected[:]
+ expected_copy = expected.copy()
for i, mode in enumerate([False, True]):
self.ureg.autoconvert_offset_to_baseunit = mode
if expected_copy[i] == "error":
@@ -1695,14 +1695,14 @@ class TestOffsetUnitMath(QuantityTestCase):
def test_exponentiation(self, input_tuple, expected):
self.ureg.default_as_delta = False
in1, in2 = input_tuple
- if type(in1) is tuple and type(in2) is tuple:
+ if type(in1) is type(in2) is tuple:
in1, in2 = self.Q_(*in1), self.Q_(*in2)
elif type(in1) is not tuple and type(in2) is tuple:
in2 = self.Q_(*in2)
else:
in1 = self.Q_(*in1)
input_tuple = in1, in2
- expected_copy = expected[:]
+ expected_copy = expected.copy()
for i, mode in enumerate([False, True]):
self.ureg.autoconvert_offset_to_baseunit = mode
if expected_copy[i] == "error":
@@ -1733,7 +1733,7 @@ class TestOffsetUnitMath(QuantityTestCase):
def test_inplace_exponentiation(self, input_tuple, expected):
self.ureg.default_as_delta = False
in1, in2 = input_tuple
- if type(in1) is tuple and type(in2) is tuple:
+ if type(in1) is type(in2) is tuple:
(q1v, q1u), (q2v, q2u) = in1, in2
in1 = self.Q_(*(np.array([q1v] * 2, dtype=float), q1u))
in2 = self.Q_(q2v, q2u)
@@ -1744,7 +1744,7 @@ class TestOffsetUnitMath(QuantityTestCase):
input_tuple = in1, in2
- expected_copy = expected[:]
+ expected_copy = expected.copy()
for i, mode in enumerate([False, True]):
self.ureg.autoconvert_offset_to_baseunit = mode
in1_cp = copy.copy(in1)
diff --git a/pint/testsuite/test_umath.py b/pint/testsuite/test_umath.py
index 6f32ab5..73d0ae7 100644
--- a/pint/testsuite/test_umath.py
+++ b/pint/testsuite/test_umath.py
@@ -79,7 +79,7 @@ class TestUFuncs:
if results is None:
results = [None] * len(ok_with)
for x1, res in zip(ok_with, results):
- err_msg = "At {} with {}".format(func.__name__, x1)
+ err_msg = f"At {func.__name__} with {x1}"
if output_units == "same":
ou = x1.units
elif isinstance(output_units, (int, float)):
@@ -163,7 +163,7 @@ class TestUFuncs:
if results is None:
results = [None] * len(ok_with)
for x1, res in zip(ok_with, results):
- err_msg = "At {} with {}".format(func.__name__, x1)
+ err_msg = f"At {func.__name__} with {x1}"
qms = func(x1)
if res is None:
res = func(x1.magnitude)
@@ -223,7 +223,7 @@ class TestUFuncs:
"""
for x2 in ok_with:
- err_msg = "At {} with {} and {}".format(func.__name__, x1, x2)
+ err_msg = f"At {func.__name__} with {x1} and {x2}"
if output_units == "same":
ou = x1.units
elif output_units == "prod":
diff --git a/pint/testsuite/test_unit.py b/pint/testsuite/test_unit.py
index 98a4fcc..c1a2704 100644
--- a/pint/testsuite/test_unit.py
+++ b/pint/testsuite/test_unit.py
@@ -2,6 +2,7 @@ import copy
import functools
import logging
import math
+import operator
import re
from contextlib import nullcontext as does_not_raise
@@ -144,7 +145,7 @@ class TestUnit(QuantityTestCase):
ureg = UnitRegistry()
- assert "{:new}".format(ureg.m) == "new format"
+ assert f"{ureg.m:new}" == "new format"
def test_ipython(self):
alltext = []
@@ -193,7 +194,7 @@ class TestUnit(QuantityTestCase):
("unit", "power_ratio", "expectation", "expected_unit"),
[
("m", 2, does_not_raise(), "m**2"),
- ("m", dict(), pytest.raises(TypeError), None),
+ ("m", {}, pytest.raises(TypeError), None),
],
)
def test_unit_pow(self, unit, power_ratio, expectation, expected_unit):
@@ -283,7 +284,7 @@ class TestRegistry(QuantityTestCase):
with pytest.raises(errors.RedefinitionError):
ureg.define("meter = [length]")
with pytest.raises(TypeError):
- ureg.define(list())
+ ureg.define([])
ureg.define("degC = kelvin; offset: 273.15")
def test_define(self):
@@ -394,7 +395,7 @@ class TestRegistry(QuantityTestCase):
)
def test_parse_pretty_degrees(self):
- for exp in ["1Δ°C", "1 Δ°C", "ΔdegC", "delta_°C"]:
+ for exp in ("1Δ°C", "1 Δ°C", "ΔdegC", "delta_°C"):
assert self.ureg.parse_expression(exp) == self.Q_(
1, UnitsContainer(delta_degree_Celsius=1)
)
@@ -566,8 +567,7 @@ class TestRegistry(QuantityTestCase):
assert f3(3.0 * ureg.centimeter) == 0.03 * ureg.centimeter
assert f3(3.0 * ureg.meter) == 3.0 * ureg.centimeter
- def gfunc(x, y):
- return x + y
+ gfunc = operator.add
g0 = ureg.wraps(None, [None, None])(gfunc)
assert g0(3, 1) == 4
@@ -596,8 +596,7 @@ class TestRegistry(QuantityTestCase):
def test_wrap_referencing(self):
ureg = self.ureg
- def gfunc(x, y):
- return x + y
+ gfunc = operator.add
def gfunc2(x, y):
return x**2 + y
@@ -650,8 +649,7 @@ class TestRegistry(QuantityTestCase):
with pytest.raises(DimensionalityError):
f0b(3.0 * ureg.kilogram)
- def gfunc(x, y):
- return x / y
+ gfunc = operator.truediv
g0 = ureg.check(None, None)(gfunc)
assert g0(6, 2) == 3
diff --git a/pint/testsuite/test_util.py b/pint/testsuite/test_util.py
index fd6494a..a61194d 100644
--- a/pint/testsuite/test_util.py
+++ b/pint/testsuite/test_util.py
@@ -120,13 +120,13 @@ class TestUnitsContainer:
UnitsContainer({"1": "2"})
d = UnitsContainer()
with pytest.raises(TypeError):
- d.__mul__(list())
+ d.__mul__([])
with pytest.raises(TypeError):
- d.__pow__(list())
+ d.__pow__([])
with pytest.raises(TypeError):
- d.__truediv__(list())
+ d.__truediv__([])
with pytest.raises(TypeError):
- d.__rtruediv__(list())
+ d.__rtruediv__([])
class TestToUnitsContainer:
@@ -193,9 +193,9 @@ class TestParseHelper:
assert "seconds" / z() == ParserHelper(0.5, seconds=1, meter=-2)
assert dict(seconds=1) / z() == ParserHelper(0.5, seconds=1, meter=-2)
- def _test_eval_token(self, expected, expression, use_decimal=False):
+ def _test_eval_token(self, expected, expression):
token = next(tokenizer(expression))
- actual = ParserHelper.eval_token(token, use_decimal=use_decimal)
+ actual = ParserHelper.eval_token(token)
assert expected == actual
assert type(expected) == type(actual)
@@ -353,12 +353,12 @@ class TestOtherUtils:
# Test with list, string, generator, and scalar
assert iterable([0, 1, 2, 3])
assert iterable("test")
- assert iterable((i for i in range(5)))
+ assert iterable(i for i in range(5))
assert not iterable(0)
def test_sized(self):
# Test with list, string, generator, and scalar
assert sized([0, 1, 2, 3])
assert sized("test")
- assert not sized((i for i in range(5)))
+ assert not sized(i for i in range(5))
assert not sized(0)
diff --git a/pint/util.py b/pint/util.py
index d5f3aab..149945b 100644
--- a/pint/util.py
+++ b/pint/util.py
@@ -10,54 +10,85 @@
from __future__ import annotations
-import functools
-import inspect
import logging
import math
import operator
import re
-from collections.abc import Mapping
+from collections.abc import Mapping, Iterable, Iterator
from fractions import Fraction
from functools import lru_cache, partial
from logging import NullHandler
from numbers import Number
from token import NAME, NUMBER
-from typing import TYPE_CHECKING, ClassVar, Optional, Type, Union
+import tokenize
+from typing import (
+ TYPE_CHECKING,
+ ClassVar,
+ TypeAlias,
+ Callable,
+ TypeVar,
+ Any,
+)
+from collections.abc import Hashable, Generator
from .compat import NUMERIC_TYPES, tokenizer
from .errors import DefinitionSyntaxError
from .formatting import format_unit
from .pint_eval import build_eval_tree
+from ._typing import PintScalar
+
if TYPE_CHECKING:
- from ._typing import Quantity, UnitLike
+ from ._typing import Quantity, UnitLike, Self
from .registry import UnitRegistry
+
logger = logging.getLogger(__name__)
logger.addHandler(NullHandler())
+T = TypeVar("T")
+TH = TypeVar("TH", bound=Hashable)
+ItMatrix: TypeAlias = Iterable[Iterable[PintScalar]]
+Matrix: TypeAlias = list[list[PintScalar]]
+
+
+def _noop(x: T) -> T:
+ return x
+
def matrix_to_string(
- matrix, row_headers=None, col_headers=None, fmtfun=lambda x: str(int(x))
-):
- """Takes a 2D matrix (as nested list) and returns a string.
+ matrix: ItMatrix,
+ row_headers: Iterable[str] | None = None,
+ col_headers: Iterable[str] | None = None,
+ fmtfun: Callable[
+ [
+ PintScalar,
+ ],
+ str,
+ ] = "{:0.0f}".format,
+) -> str:
+ """Return a string representation of a matrix.
Parameters
----------
- matrix :
-
- row_headers :
- (Default value = None)
- col_headers :
- (Default value = None)
- fmtfun :
- (Default value = lambda x: str(int(x)))
+ matrix
+ A matrix given as an iterable of an iterable of numbers.
+ row_headers
+ An iterable of strings to serve as row headers.
+ (default = None, meaning no row headers are printed.)
+ col_headers
+ An iterable of strings to serve as column headers.
+ (default = None, meaning no col headers are printed.)
+ fmtfun
+ A callable to convert a number into string.
+ (default = `"{:0.0f}".format`)
Returns
-------
-
+ str
+ String representation of the matrix.
"""
- ret = []
+ ret: list[str] = []
if col_headers:
ret.append(("\t" if row_headers else "") + "\t".join(col_headers))
if row_headers:
@@ -71,99 +102,124 @@ def matrix_to_string(
return "\n".join(ret)
-def transpose(matrix):
- """Takes a 2D matrix (as nested list) and returns the transposed version.
+def transpose(matrix: ItMatrix) -> Matrix:
+ """Return the transposed version of a matrix.
Parameters
----------
- matrix :
-
+ matrix
+ A matrix given as an iterable of an iterable of numbers.
Returns
-------
-
+ Matrix
+ The transposed version of the matrix.
"""
return [list(val) for val in zip(*matrix)]
-def column_echelon_form(matrix, ntype=Fraction, transpose_result=False):
- """Calculates the column echelon form using Gaussian elimination.
+def matrix_apply(
+ matrix: ItMatrix,
+ func: Callable[
+ [
+ PintScalar,
+ ],
+ PintScalar,
+ ],
+) -> Matrix:
+ """Apply a function to individual elements within a matrix.
Parameters
----------
- matrix :
- a 2D matrix as nested list.
- ntype :
- the numerical type to use in the calculation. (Default value = Fraction)
- transpose_result :
- indicates if the returned matrix should be transposed. (Default value = False)
+ matrix
+ A matrix given as an iterable of an iterable of numbers.
+ func
+ A callable that converts a number to another.
Returns
-------
- type
- column echelon form, transformed identity matrix, swapped rows
-
+ A new matrix in which each element has been replaced by new one.
"""
- lead = 0
+ return [[func(x) for x in row] for row in matrix]
- M = transpose(matrix)
- _transpose = transpose if transpose_result else lambda x: x
+def column_echelon_form(
+ matrix: ItMatrix, ntype: type = Fraction, transpose_result: bool = False
+) -> tuple[Matrix, Matrix, list[int]]:
+ """Calculate the column echelon form using Gaussian elimination.
- rows, cols = len(M), len(M[0])
+ Parameters
+ ----------
+ matrix
+ A 2D matrix as nested list.
+ ntype
+ The numerical type to use in the calculation.
+ (default = Fraction)
+ transpose_result
+ Indicates if the returned matrix should be transposed.
+ (default = False)
- new_M = []
- for row in M:
- r = []
- for x in row:
- if isinstance(x, float):
- x = ntype.from_float(x)
- else:
- x = ntype(x)
- r.append(x)
- new_M.append(r)
- M = new_M
+ Returns
+ -------
+ ech_matrix
+ Column echelon form.
+ id_matrix
+ Transformed identity matrix.
+ swapped
+ Swapped rows.
+ """
+
+ _transpose = transpose if transpose_result else _noop
+
+ ech_matrix = matrix_apply(
+ transpose(matrix),
+ lambda x: ntype.from_float(x) if isinstance(x, float) else ntype(x), # type: ignore
+ )
+ rows, cols = len(ech_matrix), len(ech_matrix[0])
# M = [[ntype(x) for x in row] for row in M]
- I = [ # noqa: E741
+ id_matrix: list[list[PintScalar]] = [ # noqa: E741
[ntype(1) if n == nc else ntype(0) for nc in range(rows)] for n in range(rows)
]
- swapped = []
+ swapped: list[int] = []
+ lead = 0
for r in range(rows):
if lead >= cols:
- return _transpose(M), _transpose(I), swapped
- i = r
- while M[i][lead] == 0:
- i += 1
- if i != rows:
+ return _transpose(ech_matrix), _transpose(id_matrix), swapped
+ s = r
+ while ech_matrix[s][lead] == 0: # type: ignore
+ s += 1
+ if s != rows:
continue
- i = r
+ s = r
lead += 1
if cols == lead:
- return _transpose(M), _transpose(I), swapped
+ return _transpose(ech_matrix), _transpose(id_matrix), swapped
- M[i], M[r] = M[r], M[i]
- I[i], I[r] = I[r], I[i]
+ ech_matrix[s], ech_matrix[r] = ech_matrix[r], ech_matrix[s]
+ id_matrix[s], id_matrix[r] = id_matrix[r], id_matrix[s]
- swapped.append(i)
- lv = M[r][lead]
- M[r] = [mrx / lv for mrx in M[r]]
- I[r] = [mrx / lv for mrx in I[r]]
+ swapped.append(s)
+ lv = ech_matrix[r][lead]
+ ech_matrix[r] = [mrx / lv for mrx in ech_matrix[r]]
+ id_matrix[r] = [mrx / lv for mrx in id_matrix[r]]
- for i in range(rows):
- if i == r:
+ for s in range(rows):
+ if s == r:
continue
- lv = M[i][lead]
- M[i] = [iv - lv * rv for rv, iv in zip(M[r], M[i])]
- I[i] = [iv - lv * rv for rv, iv in zip(I[r], I[i])]
+ lv = ech_matrix[s][lead]
+ ech_matrix[s] = [
+ iv - lv * rv for rv, iv in zip(ech_matrix[r], ech_matrix[s])
+ ]
+ id_matrix[s] = [iv - lv * rv for rv, iv in zip(id_matrix[r], id_matrix[s])]
lead += 1
- return _transpose(M), _transpose(I), swapped
+ return _transpose(ech_matrix), _transpose(id_matrix), swapped
-def pi_theorem(quantities, registry=None):
+def pi_theorem(quantities: dict[str, Any], registry: UnitRegistry | None = None):
"""Builds dimensionless quantities using the Buckingham π theorem
Parameters
@@ -171,7 +227,7 @@ def pi_theorem(quantities, registry=None):
quantities : dict
mapping between variable name and units
registry :
- (Default value = None)
+ (default value = None)
Returns
-------
@@ -185,7 +241,7 @@ def pi_theorem(quantities, registry=None):
dimensions = set()
if registry is None:
- getdim = lambda x: x
+ getdim = _noop
non_int_type = float
else:
getdim = registry.get_dimensionality
@@ -213,33 +269,35 @@ def pi_theorem(quantities, registry=None):
dimensions = list(dimensions)
# Calculate dimensionless quantities
- M = [
+ matrix = [
[dimensionality[dimension] for name, dimensionality in quant]
for dimension in dimensions
]
- M, identity, pivot = column_echelon_form(M, transpose_result=False)
+ ech_matrix, id_matrix, pivot = column_echelon_form(matrix, transpose_result=False)
# Collect results
# Make all numbers integers and minimize the number of negative exponents.
# Remove zeros
results = []
- for rowm, rowi in zip(M, identity):
+ for rowm, rowi in zip(ech_matrix, id_matrix):
if any(el != 0 for el in rowm):
continue
max_den = max(f.denominator for f in rowi)
neg = -1 if sum(f < 0 for f in rowi) > sum(f > 0 for f in rowi) else 1
results.append(
- dict(
- (q[0], neg * f.numerator * max_den / f.denominator)
+ {
+ q[0]: neg * f.numerator * max_den / f.denominator
for q, f in zip(quant, rowi)
if f.numerator != 0
- )
+ }
)
return results
-def solve_dependencies(dependencies):
+def solve_dependencies(
+ dependencies: dict[TH, set[TH]]
+) -> Generator[set[TH], None, None]:
"""Solve a dependency graph.
Parameters
@@ -248,12 +306,16 @@ def solve_dependencies(dependencies):
dependency dictionary. For each key, the value is an iterable indicating its
dependencies.
- Returns
- -------
- type
+ Yields
+ ------
+ set
iterator of sets, each containing keys of independents tasks dependent only of
the previous tasks in the list.
+ Raises
+ ------
+ ValueError
+ if a cyclic dependency is found.
"""
while dependencies:
# values not in keys (items without dep)
@@ -272,12 +334,37 @@ def solve_dependencies(dependencies):
yield t
-def find_shortest_path(graph, start, end, path=None):
+def find_shortest_path(
+ graph: dict[TH, set[TH]], start: TH, end: TH, path: list[TH] | None = None
+):
+ """Find shortest path between two nodes within a graph.
+
+ Parameters
+ ----------
+ graph
+ A graph given as a mapping of nodes
+ to a set of all connected nodes to it.
+ start
+ Starting node.
+ end
+ End node.
+ path
+ Path to prepend to the one found.
+ (default = None, empty path.)
+
+ Returns
+ -------
+ list[TH]
+ The shortest path between two nodes.
+ """
path = (path or []) + [start]
if start == end:
return path
+
+ # TODO: raise ValueError when start not in graph
if start not in graph:
return None
+
shortest = None
for node in graph[start]:
if node not in path:
@@ -285,10 +372,33 @@ def find_shortest_path(graph, start, end, path=None):
if newpath:
if not shortest or len(newpath) < len(shortest):
shortest = newpath
+
return shortest
-def find_connected_nodes(graph, start, visited=None):
+def find_connected_nodes(
+ graph: dict[TH, set[TH]], start: TH, visited: set[TH] | None = None
+) -> set[TH] | None:
+ """Find all nodes connected to a start node within a graph.
+
+ Parameters
+ ----------
+ graph
+ A graph given as a mapping of nodes
+ to a set of all connected nodes to it.
+ start
+ Starting node.
+ visited
+ Mutable set to collect visited nodes.
+ (default = None, empty set)
+
+ Returns
+ -------
+ set[TH]
+ The shortest path between two nodes.
+ """
+
+ # TODO: raise ValueError when start not in graph
if start not in graph:
return None
@@ -302,17 +412,17 @@ def find_connected_nodes(graph, start, visited=None):
return visited
-class udict(dict):
+class udict(dict[str, PintScalar]):
"""Custom dict implementing __missing__."""
- def __missing__(self, key):
+ def __missing__(self, key: str):
return 0
- def copy(self):
+ def copy(self: Self) -> Self:
return udict(self)
-class UnitsContainer(Mapping):
+class UnitsContainer(Mapping[str, PintScalar]):
"""The UnitsContainer stores the product of units and their respective
exponent and implements the corresponding operations.
@@ -320,23 +430,24 @@ class UnitsContainer(Mapping):
Parameters
----------
-
- Returns
- -------
- type
-
-
+ non_int_type
+ Numerical type used for non integer values.
"""
__slots__ = ("_d", "_hash", "_one", "_non_int_type")
- def __init__(self, *args, **kwargs) -> None:
+ _d: udict
+ _hash: int | None
+ _one: PintScalar
+ _non_int_type: type
+
+ def __init__(self, *args, non_int_type: type | None = None, **kwargs) -> None:
if args and isinstance(args[0], UnitsContainer):
default_non_int_type = args[0]._non_int_type
else:
default_non_int_type = float
- self._non_int_type = kwargs.pop("non_int_type", default_non_int_type)
+ self._non_int_type = non_int_type or default_non_int_type
if self._non_int_type is float:
self._one = 1
@@ -347,17 +458,33 @@ class UnitsContainer(Mapping):
self._d = d
for key, value in d.items():
if not isinstance(key, str):
- raise TypeError("key must be a str, not {}".format(type(key)))
+ raise TypeError(f"key must be a str, not {type(key)}")
if not isinstance(value, Number):
- raise TypeError("value must be a number, not {}".format(type(value)))
+ raise TypeError(f"value must be a number, not {type(value)}")
if not isinstance(value, int) and not isinstance(value, self._non_int_type):
d[key] = self._non_int_type(value)
self._hash = None
- def copy(self):
+ def copy(self: Self) -> Self:
+ """Create a copy of this UnitsContainer."""
return self.__copy__()
- def add(self, key, value):
+ def add(self: Self, key: str, value: Number) -> Self:
+ """Create a new UnitsContainer adding value to
+ the value existing for a given key.
+
+ Parameters
+ ----------
+ key
+ unit to which the value will be added.
+ value
+ value to be added.
+
+ Returns
+ -------
+ UnitsContainer
+ A copy of this container.
+ """
newval = self._d[key] + value
new = self.copy()
if newval:
@@ -367,17 +494,18 @@ class UnitsContainer(Mapping):
new._hash = None
return new
- def remove(self, keys):
- """Create a new UnitsContainer purged from given keys.
+ def remove(self: Self, keys: Iterable[str]) -> Self:
+ """Create a new UnitsContainer purged from given entries.
Parameters
----------
- keys :
-
+ keys
+ Iterable of keys (units) to remove.
Returns
-------
-
+ UnitsContainer
+ A copy of this container.
"""
new = self.copy()
for k in keys:
@@ -385,51 +513,52 @@ class UnitsContainer(Mapping):
new._hash = None
return new
- def rename(self, oldkey, newkey):
+ def rename(self: Self, oldkey: str, newkey: str) -> Self:
"""Create a new UnitsContainer in which an entry has been renamed.
Parameters
----------
- oldkey :
-
- newkey :
-
+ oldkey
+ Existing key (unit).
+ newkey
+ New key (unit).
Returns
-------
-
+ UnitsContainer
+ A copy of this container.
"""
new = self.copy()
new._d[newkey] = new._d.pop(oldkey)
new._hash = None
return new
- def __iter__(self):
+ def __iter__(self) -> Iterator[str]:
return iter(self._d)
def __len__(self) -> int:
return len(self._d)
- def __getitem__(self, key):
+ def __getitem__(self, key: str) -> PintScalar:
return self._d[key]
- def __contains__(self, key):
+ def __contains__(self, key: str) -> bool:
return key in self._d
- def __hash__(self):
+ def __hash__(self) -> int:
if self._hash is None:
self._hash = hash(frozenset(self._d.items()))
return self._hash
# Only needed by pickle protocol 0 and 1 (used by pytables)
- def __getstate__(self):
+ def __getstate__(self) -> tuple[udict, PintScalar, type]:
return self._d, self._one, self._non_int_type
- def __setstate__(self, state):
+ def __setstate__(self, state: tuple[udict, PintScalar, type]):
self._d, self._one, self._non_int_type = state
self._hash = None
- def __eq__(self, other) -> bool:
+ def __eq__(self, other: Any) -> bool:
if isinstance(other, UnitsContainer):
# UnitsContainer.__hash__(self) is not the same as hash(self); see
# ParserHelper.__hash__ and __eq__.
@@ -455,9 +584,9 @@ class UnitsContainer(Mapping):
def __repr__(self) -> str:
tmp = "{%s}" % ", ".join(
- ["'{}': {}".format(key, value) for key, value in sorted(self._d.items())]
+ [f"'{key}': {value}" for key, value in sorted(self._d.items())]
)
- return "<UnitsContainer({})>".format(tmp)
+ return f"<UnitsContainer({tmp})>"
def __format__(self, spec: str) -> str:
return format_unit(self, spec)
@@ -474,7 +603,7 @@ class UnitsContainer(Mapping):
out._one = self._one
return out
- def __mul__(self, other):
+ def __mul__(self, other: Any):
if not isinstance(other, self.__class__):
err = "Cannot multiply UnitsContainer by {}"
raise TypeError(err.format(type(other)))
@@ -490,7 +619,7 @@ class UnitsContainer(Mapping):
__rmul__ = __mul__
- def __pow__(self, other):
+ def __pow__(self, other: Any):
if not isinstance(other, NUMERIC_TYPES):
err = "Cannot power UnitsContainer by {}"
raise TypeError(err.format(type(other)))
@@ -501,7 +630,7 @@ class UnitsContainer(Mapping):
new._hash = None
return new
- def __truediv__(self, other):
+ def __truediv__(self, other: Any):
if not isinstance(other, self.__class__):
err = "Cannot divide UnitsContainer by {}"
raise TypeError(err.format(type(other)))
@@ -515,7 +644,7 @@ class UnitsContainer(Mapping):
new._hash = None
return new
- def __rtruediv__(self, other):
+ def __rtruediv__(self, other: Any):
if not isinstance(other, self.__class__) and other != 1:
err = "Cannot divide {} by UnitsContainer"
raise TypeError(err.format(type(other)))
@@ -526,41 +655,48 @@ class UnitsContainer(Mapping):
class ParserHelper(UnitsContainer):
"""The ParserHelper stores in place the product of variables and
their respective exponent and implements the corresponding operations.
+ It also provides a scaling factor.
+
+ For example:
+ `3 * m ** 2` becomes ParserHelper(3, m=2)
+
+ Briefly is a UnitsContainer with a scaling factor.
ParserHelper is a read-only mapping. All operations (even in place ones)
+ WARNING : The hash value used does not take into account the scale
+ attribute so be careful if you use it as a dict key and then two unequal
+ object can have the same hash.
+
Parameters
----------
-
- Returns
- -------
- type
- WARNING : The hash value used does not take into account the scale
- attribute so be careful if you use it as a dict key and then two unequal
- object can have the same hash.
-
+ scale
+ Scaling factor.
+ (default = 1)
+ **kwargs
+ Used to populate the dict of units and exponents.
"""
__slots__ = ("scale",)
- def __init__(self, scale=1, *args, **kwargs):
+ scale: PintScalar
+
+ def __init__(self, scale: PintScalar = 1, *args, **kwargs):
super().__init__(*args, **kwargs)
self.scale = scale
@classmethod
- def from_word(cls, input_word, non_int_type=float):
+ def from_word(cls, input_word: str, non_int_type: type = float) -> ParserHelper:
"""Creates a ParserHelper object with a single variable with exponent one.
- Equivalent to: ParserHelper({'word': 1})
+ Equivalent to: ParserHelper(1, {input_word: 1})
Parameters
----------
- input_word :
-
-
- Returns
- -------
+ input_word
+ non_int_type
+ Numerical type used for non integer values.
"""
if non_int_type is float:
return cls(1, [(input_word, 1)], non_int_type=non_int_type)
@@ -569,15 +705,7 @@ class ParserHelper(UnitsContainer):
return cls(ONE, [(input_word, ONE)], non_int_type=non_int_type)
@classmethod
- def eval_token(cls, token, use_decimal=False, non_int_type=float):
- # TODO: remove this code when use_decimal is deprecated
- if use_decimal:
- raise DeprecationWarning(
- "`use_decimal` is deprecated, use `non_int_type` keyword argument when instantiating the registry.\n"
- ">>> from decimal import Decimal\n"
- ">>> ureg = UnitRegistry(non_int_type=Decimal)"
- )
-
+ def eval_token(cls, token: tokenize.TokenInfo, non_int_type: type = float):
token_type = token.type
token_text = token.string
if token_type == NUMBER:
@@ -594,18 +722,16 @@ class ParserHelper(UnitsContainer):
raise Exception("unknown token type")
@classmethod
- @lru_cache()
- def from_string(cls, input_string, non_int_type=float):
+ @lru_cache
+ def from_string(cls, input_string: str, non_int_type: type = float) -> ParserHelper:
"""Parse linear expression mathematical units and return a quantity object.
Parameters
----------
- input_string :
-
-
- Returns
- -------
+ input_string
+ non_int_type
+ Numerical type used for non integer values.
"""
if not input_string:
return cls(non_int_type=non_int_type)
@@ -666,17 +792,17 @@ class ParserHelper(UnitsContainer):
super().__setstate__(state[:-1])
self.scale = state[-1]
- def __eq__(self, other):
+ def __eq__(self, other: Any) -> bool:
if isinstance(other, ParserHelper):
return self.scale == other.scale and super().__eq__(other)
elif isinstance(other, str):
return self == ParserHelper.from_string(other, self._non_int_type)
elif isinstance(other, Number):
return self.scale == other and not len(self._d)
- else:
- return self.scale == 1 and super().__eq__(other)
- def operate(self, items, op=operator.iadd, cleanup=True):
+ return self.scale == 1 and super().__eq__(other)
+
+ def operate(self, items, op=operator.iadd, cleanup: bool = True):
d = udict(self._d)
for key, value in items:
d[key] = op(d[key], value)
@@ -690,15 +816,15 @@ class ParserHelper(UnitsContainer):
def __str__(self):
tmp = "{%s}" % ", ".join(
- ["'{}': {}".format(key, value) for key, value in sorted(self._d.items())]
+ [f"'{key}': {value}" for key, value in sorted(self._d.items())]
)
- return "{} {}".format(self.scale, tmp)
+ return f"{self.scale} {tmp}"
def __repr__(self):
tmp = "{%s}" % ", ".join(
- ["'{}': {}".format(key, value) for key, value in sorted(self._d.items())]
+ [f"'{key}': {value}" for key, value in sorted(self._d.items())]
)
- return "<ParserHelper({}, {})>".format(self.scale, tmp)
+ return f"<ParserHelper({self.scale}, {tmp})>"
def __mul__(self, other):
if isinstance(other, str):
@@ -821,21 +947,22 @@ class SharedRegistryObject:
inst._REGISTRY = application_registry.get()
return inst
- def _check(self, other) -> bool:
+ def _check(self, other: Any) -> bool:
"""Check if the other object use a registry and if so that it is the
same registry.
Parameters
----------
- other :
-
+ other
Returns
-------
- type
- other don't use a registry and raise ValueError if other don't use the
- same unit registry.
+ bool
+ Raises
+ ------
+ ValueError
+ if other don't use the same unit registry.
"""
if self._REGISTRY is getattr(other, "_REGISTRY", None):
return True
@@ -854,40 +981,39 @@ class PrettyIPython:
default_format: str
- def _repr_html_(self):
+ def _repr_html_(self) -> str:
if "~" in self.default_format:
- return "{:~H}".format(self)
- else:
- return "{:H}".format(self)
+ return f"{self:~H}"
+ return f"{self:H}"
- def _repr_latex_(self):
+ def _repr_latex_(self) -> str:
if "~" in self.default_format:
- return "${:~L}$".format(self)
- else:
- return "${:L}$".format(self)
+ return f"${self:~L}$"
+ return f"${self:L}$"
- def _repr_pretty_(self, p, cycle):
+ def _repr_pretty_(self, p, cycle: bool):
if "~" in self.default_format:
- p.text("{:~P}".format(self))
+ p.text(f"{self:~P}")
else:
- p.text("{:P}".format(self))
+ p.text(f"{self:P}")
def to_units_container(
- unit_like: Union[UnitLike, Quantity], registry: Optional[UnitRegistry] = None
+ unit_like: UnitLike | Quantity, registry: UnitRegistry | None = None
) -> UnitsContainer:
"""Convert a unit compatible type to a UnitsContainer.
Parameters
----------
- unit_like :
-
- registry :
- (Default value = None)
+ unit_like
+ Quantity or Unit to infer the plain units from.
+ registry
+ If provided, uses the registry's UnitsContainer and parse_unit_name. If None,
+ uses the registry attached to unit_like.
Returns
-------
-
+ UnitsContainer
"""
mro = type(unit_like).mro()
if UnitsContainer in mro:
@@ -907,17 +1033,16 @@ def to_units_container(
def infer_base_unit(
- unit_like: Union[UnitLike, Quantity], registry: Optional[UnitRegistry] = None
+ unit_like: UnitLike | Quantity, registry: UnitRegistry | None = None
) -> UnitsContainer:
"""
Given a Quantity or UnitLike, give the UnitsContainer for it's plain units.
Parameters
----------
- unit_like : Union[UnitLike, Quantity]
+ unit_like
Quantity or Unit to infer the plain units from.
-
- registry: Optional[UnitRegistry]
+ registry
If provided, uses the registry's UnitsContainer and parse_unit_name. If None,
uses the registry attached to unit_like.
@@ -952,7 +1077,7 @@ def infer_base_unit(
return registry.UnitsContainer(nonzero_dict)
-def getattr_maybe_raise(self, item):
+def getattr_maybe_raise(obj: Any, item: str):
"""Helper function invoked at start of all overridden ``__getattr__``.
Raise AttributeError if the user tries to ask for a _ or __ attribute,
@@ -961,39 +1086,25 @@ def getattr_maybe_raise(self, item):
Parameters
----------
- item : string
- Item to be found.
-
-
- Returns
- -------
+ item
+ attribute to be found.
+ Raises
+ ------
+ AttributeError
"""
# Double-underscore attributes are tricky to detect because they are
- # automatically prefixed with the class name - which may be a subclass of self
+ # automatically prefixed with the class name - which may be a subclass of obj
if (
item.endswith("__")
or len(item.lstrip("_")) == 0
or (item.startswith("_") and not item.lstrip("_")[0].isdigit())
):
- raise AttributeError("%r object has no attribute %r" % (self, item))
-
+ raise AttributeError(f"{obj!r} object has no attribute {item!r}")
-def iterable(y) -> bool:
- """Check whether or not an object can be iterated over.
- Vendored from numpy under the terms of the BSD 3-Clause License. (Copyright
- (c) 2005-2019, NumPy Developers.)
-
- Parameters
- ----------
- value :
- Input object.
- type :
- object
- y :
-
- """
+def iterable(y: Any) -> bool:
+ """Check whether or not an object can be iterated over."""
try:
iter(y)
except TypeError:
@@ -1001,18 +1112,8 @@ def iterable(y) -> bool:
return True
-def sized(y) -> bool:
- """Check whether or not an object has a defined length.
-
- Parameters
- ----------
- value :
- Input object.
- type :
- object
- y :
-
- """
+def sized(y: Any) -> bool:
+ """Check whether or not an object has a defined length."""
try:
len(y)
except TypeError:
@@ -1020,37 +1121,9 @@ def sized(y) -> bool:
return True
-@functools.lru_cache(
- maxsize=None
-) # TODO: replace with cache when Python 3.8 is dropped.
-def _build_type(class_name: str, bases):
- return type(class_name, bases, dict())
-
-
-def build_dependent_class(registry_class, class_name: str, attribute_name: str) -> Type:
- """Creates a class specifically for the given registry that
- subclass all the classes named by the registry bases in a
- specific attribute
-
- 1. List the 'attribute_name' attribute for each of the bases of the registry class.
- 2. Use this list as bases for the new class
- 3. Add the provided registry as the class registry.
-
- """
- bases = (
- getattr(base, attribute_name)
- for base in inspect.getmro(registry_class)
- if attribute_name in base.__dict__
- )
- bases = tuple(dict.fromkeys(bases, None).keys())
- if len(bases) == 1 and bases[0].__name__ == class_name:
- return bases[0]
- return _build_type(class_name, bases)
-
-
-def create_class_with_registry(registry, base_class) -> Type:
+def create_class_with_registry(registry: UnitRegistry, base_class: type) -> type:
"""Create new class inheriting from base_class and
filling _REGISTRY class attribute with an actual instanced registry.
"""
- return type(base_class.__name__, tuple((base_class,)), dict(_REGISTRY=registry))
+ return type(base_class.__name__, (base_class,), dict(_REGISTRY=registry))
diff --git a/pyproject.toml b/pyproject.toml
index 72b6560..bbcfbdf 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -22,13 +22,12 @@ classifiers = [
"Programming Language :: Python",
"Topic :: Scientific/Engineering",
"Topic :: Software Development :: Libraries",
- "Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11"
]
-requires-python = ">=3.8"
-dynamic = ["version"]
+requires-python = ">=3.9"
+dynamic = ["version"] # Version is taken from git tags using setuptools_scm
[tool.setuptools.package-data]
pint = [