summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJules Chéron <jules.cheron@gmail.com>2021-02-16 17:06:03 +0100
committerJules Chéron <jules.cheron@gmail.com>2021-02-16 17:14:27 +0100
commit146f4f16a6268860e0f27c1e129df0ac341eebb4 (patch)
tree10b67988ccf4e33f5116e98d54a68bb77a6e1cbb
parent397022713184adb0acd9eedf15d1df236db39a7a (diff)
downloadpint-146f4f16a6268860e0f27c1e129df0ac341eebb4.tar.gz
Fix numpy.linalg.solve units output.
Update get_op_output_unit with new type invdiv. It outputs the product of the following units over the first one in the args list. Update tests with values & np.dot(A, x) == b. Where x = np.linalg.solve(A, b) Closes #1246
-rw-r--r--CHANGES1
-rw-r--r--pint/numpy_func.py14
-rw-r--r--pint/testsuite/test_numpy.py11
3 files changed, 20 insertions, 6 deletions
diff --git a/CHANGES b/CHANGES
index b93b83b..cf50d1e 100644
--- a/CHANGES
+++ b/CHANGES
@@ -12,6 +12,7 @@ Pint Changelog
- UnitsContainer returns false if other is str and cannnot be parsed
(Issue #1179, thanks rfrowe)
- Add Github Actions CI. (Issue #1236)
+- Fix numpy.linalg.solve unit output. (Issue #1246)
0.16.1 (2020-09-22)
-------------------
diff --git a/pint/numpy_func.py b/pint/numpy_func.py
index 1dce044..c335f3d 100644
--- a/pint/numpy_func.py
+++ b/pint/numpy_func.py
@@ -148,6 +148,7 @@ def get_op_output_unit(unit_op, first_input_units, all_args=None, size=None):
- "sqrt": square root of `first_input_units`
- "reciprocal": reciprocal of `first_input_units`
- "size": `first_input_units` raised to the power of `size`
+ - "invdiv": inverse of `div`, product of all following units divided by first argument unit
Parameters
----------
@@ -205,7 +206,15 @@ def get_op_output_unit(unit_op, first_input_units, all_args=None, size=None):
if size is None:
raise ValueError('size argument must be given when unit_op=="size"')
result_unit = first_input_units ** size
-
+ elif unit_op == "invdiv":
+ # Start with first arg in numerator, all others in denominator
+ product = getattr(
+ all_args[0], "units", first_input_units._REGISTRY.parse_units("")
+ )
+ for x in all_args[1:]:
+ if hasattr(x, "units"):
+ product /= x.units
+ result_unit = product ** -1
else:
raise ValueError("Output unit method {} not understood".format(unit_op))
@@ -304,6 +313,7 @@ def implement_func(func_type, func_str, input_units=None, output_unit=None):
"delta",
"delta,div",
"div",
+ "invdiv",
"variance",
"square",
"sqrt",
@@ -878,7 +888,7 @@ for func_str in ["diff", "ediff1d"]:
for func_str in ["gradient"]:
implement_func("function", func_str, input_units=None, output_unit="delta,div")
for func_str in ["linalg.solve"]:
- implement_func("function", func_str, input_units=None, output_unit="div")
+ implement_func("function", func_str, input_units=None, output_unit="invdiv")
for func_str in ["var", "nanvar"]:
implement_func("function", func_str, input_units=None, output_unit="variance")
diff --git a/pint/testsuite/test_numpy.py b/pint/testsuite/test_numpy.py
index beea69d..44d4271 100644
--- a/pint/testsuite/test_numpy.py
+++ b/pint/testsuite/test_numpy.py
@@ -449,10 +449,13 @@ class TestNumpyMathematicalFunctions(TestNumpyMethods):
@helpers.requires_array_function_protocol()
def test_solve(self):
- helpers.assert_quantity_almost_equal(
- np.linalg.solve(self.q, [[3], [7]] * self.ureg.s),
- self.Q_([[1], [1]], "m / s"),
- )
+ A = self.q
+ b = [[3], [7]] * self.ureg.s
+ x = np.linalg.solve(A, b)
+
+ helpers.assert_quantity_almost_equal(x, self.Q_([[1], [1]], "s / m"))
+
+ helpers.assert_quantity_almost_equal(np.dot(A, x), b)
# Arithmetic operations
def test_addition_with_scalar(self):