summaryrefslogtreecommitdiff
path: root/numpy/lib/function_base.py
diff options
context:
space:
mode:
authorabel <aoun@cerfacs.fr>2021-10-19 12:10:16 +0200
committerSebastian Berg <sebastian@sipsolutions.net>2021-11-04 14:50:27 -0500
commitab19ed256bf9b20340c92cebcfd6158241122c88 (patch)
tree4c0e4f39f7881475137aeddf164fff15d0f722b6 /numpy/lib/function_base.py
parent303c12cfe7ad1b8b6ed5417c126857b29355b1fb (diff)
downloadnumpy-ab19ed256bf9b20340c92cebcfd6158241122c88.tar.gz
Fix _lerp
- some changes were unrelated to the PR and have been reverted, including, renaming and moving the logic around. - Also renamed _quantile_ureduce_func to its original name
Diffstat (limited to 'numpy/lib/function_base.py')
-rw-r--r--numpy/lib/function_base.py42
1 files changed, 19 insertions, 23 deletions
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py
index dbaba87f9..353490fc2 100644
--- a/numpy/lib/function_base.py
+++ b/numpy/lib/function_base.py
@@ -4329,7 +4329,7 @@ def _quantile_unchecked(a,
keepdims=False):
"""Assumes that q is in [0, 1], and is an ndarray"""
r, k = _ureduce(a,
- func=_quantiles_ureduce_func,
+ func=_quantile_ureduce_func,
q=q,
axis=axis,
out=out,
@@ -4383,22 +4383,22 @@ def _get_gamma(virtual_indexes: np.array,
return np.asanyarray(gamma)
-def _linear_interpolation_formula(
- left: np.array, right: np.array, gamma: np.array, out: np.array = None
-) -> np.array:
+def _lerp(a, b, t, out=None):
"""
Compute the linear interpolation weighted by gamma on each point of
two same shape array.
"""
- # Equivalent to gamma * right + (1 - gamma) * left
- # see gh-14685
- diff_right_left = subtract(right, left)
- result = asanyarray(add(left, diff_right_left * gamma, out=out))
- result = subtract(right,
- diff_right_left * (1 - gamma),
- out=result,
- where=gamma >= 0.5)
- return result
+ # Equivalent to gamma * right + (1 - gamma) * left, see gh-14685
+ diff_b_a = subtract(b, a)
+ # asanyarray is a stop-gap until gh-13105
+ lerp_interpolation = asanyarray(add(a, diff_b_a * t, out=out))
+ lerp_interpolation = subtract(b,
+ diff_b_a * (1 - t),
+ out=lerp_interpolation,
+ where=t >= 0.5)
+ if lerp_interpolation.ndim == 0 and out is None:
+ lerp_interpolation = lerp_interpolation[()] # unpack 0d arrays
+ return lerp_interpolation
def _get_gamma_mask(shape, default_value, conditioned_value, where):
@@ -4430,7 +4430,7 @@ def _inverted_cdf(n, quantiles):
gamma_fun)
-def _quantiles_ureduce_func(
+def _quantile_ureduce_func(
a: np.array,
q: np.array,
axis: int = None,
@@ -4460,10 +4460,6 @@ def _quantiles_ureduce_func(
axis=axis,
interpolation=interpolation,
out=out)
- if result.ndim == 0 and out is None:
- result = result[()] # unpack 0d arrays
- elif result.size == 1 and out is None and q.ndim == 0:
- result = result[0]
return result
@@ -4551,7 +4547,7 @@ def _quantile(
# cannot contain nan
arr.partition(virtual_indexes.ravel(), axis=0)
slices_having_nans = np.array(False, dtype=bool)
- result = np.asanyarray(take(arr, virtual_indexes, axis=0, out=out))
+ result = take(arr, virtual_indexes, axis=0, out=out)
else:
previous_indexes, next_indexes = _get_indexes(arr,
virtual_indexes,
@@ -4578,10 +4574,10 @@ def _quantile(
interpolation)
result_shape = virtual_indexes.shape + (1,) * (arr.ndim - 1)
gamma = gamma.reshape(result_shape)
- result = _linear_interpolation_formula(previous,
- next,
- gamma,
- out=out)
+ result = _lerp(previous,
+ next,
+ gamma,
+ out=out)
if np.any(slices_having_nans):
if result.ndim == 0 and out is None:
# can't write to a scalar