From 146f4f16a6268860e0f27c1e129df0ac341eebb4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jules=20Ch=C3=A9ron?= Date: Tue, 16 Feb 2021 17:06:03 +0100 Subject: Fix numpy.linalg.solve units output. Update get_op_output_unit with new type invdiv. It outputs the product of the following units over the first one in the args list. Update tests with values & np.dot(A, x) == b. Where x = np.linalg.solve(A, b) Closes #1246 --- CHANGES | 1 + pint/numpy_func.py | 14 ++++++++++++-- pint/testsuite/test_numpy.py | 11 +++++++---- 3 files changed, 20 insertions(+), 6 deletions(-) diff --git a/CHANGES b/CHANGES index b93b83b..cf50d1e 100644 --- a/CHANGES +++ b/CHANGES @@ -12,6 +12,7 @@ Pint Changelog - UnitsContainer returns false if other is str and cannnot be parsed (Issue #1179, thanks rfrowe) - Add Github Actions CI. (Issue #1236) +- Fix numpy.linalg.solve unit output. (Issue #1246) 0.16.1 (2020-09-22) ------------------- diff --git a/pint/numpy_func.py b/pint/numpy_func.py index 1dce044..c335f3d 100644 --- a/pint/numpy_func.py +++ b/pint/numpy_func.py @@ -148,6 +148,7 @@ def get_op_output_unit(unit_op, first_input_units, all_args=None, size=None): - "sqrt": square root of `first_input_units` - "reciprocal": reciprocal of `first_input_units` - "size": `first_input_units` raised to the power of `size` + - "invdiv": inverse of `div`, product of all following units divided by first argument unit Parameters ---------- @@ -205,7 +206,15 @@ def get_op_output_unit(unit_op, first_input_units, all_args=None, size=None): if size is None: raise ValueError('size argument must be given when unit_op=="size"') result_unit = first_input_units ** size - + elif unit_op == "invdiv": + # Start with first arg in numerator, all others in denominator + product = getattr( + all_args[0], "units", first_input_units._REGISTRY.parse_units("") + ) + for x in all_args[1:]: + if hasattr(x, "units"): + product /= x.units + result_unit = product ** -1 else: raise ValueError("Output unit method {} not understood".format(unit_op)) @@ -304,6 +313,7 @@ def implement_func(func_type, func_str, input_units=None, output_unit=None): "delta", "delta,div", "div", + "invdiv", "variance", "square", "sqrt", @@ -878,7 +888,7 @@ for func_str in ["diff", "ediff1d"]: for func_str in ["gradient"]: implement_func("function", func_str, input_units=None, output_unit="delta,div") for func_str in ["linalg.solve"]: - implement_func("function", func_str, input_units=None, output_unit="div") + implement_func("function", func_str, input_units=None, output_unit="invdiv") for func_str in ["var", "nanvar"]: implement_func("function", func_str, input_units=None, output_unit="variance") diff --git a/pint/testsuite/test_numpy.py b/pint/testsuite/test_numpy.py index beea69d..44d4271 100644 --- a/pint/testsuite/test_numpy.py +++ b/pint/testsuite/test_numpy.py @@ -449,10 +449,13 @@ class TestNumpyMathematicalFunctions(TestNumpyMethods): @helpers.requires_array_function_protocol() def test_solve(self): - helpers.assert_quantity_almost_equal( - np.linalg.solve(self.q, [[3], [7]] * self.ureg.s), - self.Q_([[1], [1]], "m / s"), - ) + A = self.q + b = [[3], [7]] * self.ureg.s + x = np.linalg.solve(A, b) + + helpers.assert_quantity_almost_equal(x, self.Q_([[1], [1]], "s / m")) + + helpers.assert_quantity_almost_equal(np.dot(A, x), b) # Arithmetic operations def test_addition_with_scalar(self): -- cgit v1.2.1