diff options
author | abel <aoun@cerfacs.fr> | 2021-10-19 12:10:16 +0200 |
---|---|---|
committer | Sebastian Berg <sebastian@sipsolutions.net> | 2021-11-04 14:50:27 -0500 |
commit | ab19ed256bf9b20340c92cebcfd6158241122c88 (patch) | |
tree | 4c0e4f39f7881475137aeddf164fff15d0f722b6 /numpy/lib/function_base.py | |
parent | 303c12cfe7ad1b8b6ed5417c126857b29355b1fb (diff) | |
download | numpy-ab19ed256bf9b20340c92cebcfd6158241122c88.tar.gz |
Fix _lerp
- some changes were unrelated to the PR and have been reverted, including, renaming and moving the logic around.
- Also renamed _quantile_ureduce_func to its original name
Diffstat (limited to 'numpy/lib/function_base.py')
-rw-r--r-- | numpy/lib/function_base.py | 42 |
1 files changed, 19 insertions, 23 deletions
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py index dbaba87f9..353490fc2 100644 --- a/numpy/lib/function_base.py +++ b/numpy/lib/function_base.py @@ -4329,7 +4329,7 @@ def _quantile_unchecked(a, keepdims=False): """Assumes that q is in [0, 1], and is an ndarray""" r, k = _ureduce(a, - func=_quantiles_ureduce_func, + func=_quantile_ureduce_func, q=q, axis=axis, out=out, @@ -4383,22 +4383,22 @@ def _get_gamma(virtual_indexes: np.array, return np.asanyarray(gamma) -def _linear_interpolation_formula( - left: np.array, right: np.array, gamma: np.array, out: np.array = None -) -> np.array: +def _lerp(a, b, t, out=None): """ Compute the linear interpolation weighted by gamma on each point of two same shape array. """ - # Equivalent to gamma * right + (1 - gamma) * left - # see gh-14685 - diff_right_left = subtract(right, left) - result = asanyarray(add(left, diff_right_left * gamma, out=out)) - result = subtract(right, - diff_right_left * (1 - gamma), - out=result, - where=gamma >= 0.5) - return result + # Equivalent to gamma * right + (1 - gamma) * left, see gh-14685 + diff_b_a = subtract(b, a) + # asanyarray is a stop-gap until gh-13105 + lerp_interpolation = asanyarray(add(a, diff_b_a * t, out=out)) + lerp_interpolation = subtract(b, + diff_b_a * (1 - t), + out=lerp_interpolation, + where=t >= 0.5) + if lerp_interpolation.ndim == 0 and out is None: + lerp_interpolation = lerp_interpolation[()] # unpack 0d arrays + return lerp_interpolation def _get_gamma_mask(shape, default_value, conditioned_value, where): @@ -4430,7 +4430,7 @@ def _inverted_cdf(n, quantiles): gamma_fun) -def _quantiles_ureduce_func( +def _quantile_ureduce_func( a: np.array, q: np.array, axis: int = None, @@ -4460,10 +4460,6 @@ def _quantiles_ureduce_func( axis=axis, interpolation=interpolation, out=out) - if result.ndim == 0 and out is None: - result = result[()] # unpack 0d arrays - elif result.size == 1 and out is None and q.ndim == 0: - result = result[0] return result @@ -4551,7 +4547,7 @@ def _quantile( # cannot contain nan arr.partition(virtual_indexes.ravel(), axis=0) slices_having_nans = np.array(False, dtype=bool) - result = np.asanyarray(take(arr, virtual_indexes, axis=0, out=out)) + result = take(arr, virtual_indexes, axis=0, out=out) else: previous_indexes, next_indexes = _get_indexes(arr, virtual_indexes, @@ -4578,10 +4574,10 @@ def _quantile( interpolation) result_shape = virtual_indexes.shape + (1,) * (arr.ndim - 1) gamma = gamma.reshape(result_shape) - result = _linear_interpolation_formula(previous, - next, - gamma, - out=out) + result = _lerp(previous, + next, + gamma, + out=out) if np.any(slices_having_nans): if result.ndim == 0 and out is None: # can't write to a scalar |