diff options
Diffstat (limited to 'numpy/lib/function_base.py')
-rw-r--r-- | numpy/lib/function_base.py | 9 |
1 files changed, 5 insertions, 4 deletions
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py index fb5dd6fdd..960285f7d 100644 --- a/numpy/lib/function_base.py +++ b/numpy/lib/function_base.py @@ -516,7 +516,8 @@ def average(a, axis=None, weights=None, returned=False, *, if weights is None: avg = a.mean(axis, **keepdims_kw) - scl = avg.dtype.type(a.size/avg.size) + avg_as_array = np.asanyarray(avg) + scl = avg_as_array.dtype.type(a.size/avg_as_array.size) else: wgt = np.asanyarray(weights) @@ -547,12 +548,12 @@ def average(a, axis=None, weights=None, returned=False, *, raise ZeroDivisionError( "Weights sum to zero, can't be normalized") - avg = np.multiply(a, wgt, + avg = avg_as_array = np.multiply(a, wgt, dtype=result_dtype).sum(axis, **keepdims_kw) / scl if returned: - if scl.shape != avg.shape: - scl = np.broadcast_to(scl, avg.shape).copy() + if scl.shape != avg_as_array.shape: + scl = np.broadcast_to(scl, avg_as_array.shape).copy() return avg, scl else: return avg |