summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorSebastian Berg <sebastian@sipsolutions.net>2022-05-26 20:48:02 -0700
committerSebastian Berg <sebastian@sipsolutions.net>2022-06-15 11:42:02 -0700
commitbaaeb9a16c9c28683db97c4fc3d047e86d32a0c5 (patch)
tree0070a9e10de564c6d6b7d66548419e684cfca77c /numpy
parent2a6a3931c0af78180fa6984fd09bc0264c156fd0 (diff)
downloadnumpy-baaeb9a16c9c28683db97c4fc3d047e86d32a0c5.tar.gz
WIP: Add warning context manager and fix min_scalar for new promotion
Even the new promotion has to use the min-scalar logic to avoid picking up a float16 loop for `np.int8(3) * 3.`.
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/_ufunc_config.py23
-rw-r--r--numpy/core/numeric.py4
-rw-r--r--numpy/core/src/multiarray/convert_datatype.c63
-rw-r--r--numpy/core/src/multiarray/convert_datatype.h4
-rw-r--r--numpy/core/src/umath/dispatching.c3
-rw-r--r--numpy/testing/_private/utils.py4
6 files changed, 93 insertions, 8 deletions
diff --git a/numpy/core/_ufunc_config.py b/numpy/core/_ufunc_config.py
index a731f6bf7..1937c4ca2 100644
--- a/numpy/core/_ufunc_config.py
+++ b/numpy/core/_ufunc_config.py
@@ -5,6 +5,7 @@ This provides helpers which wrap `umath.geterrobj` and `umath.seterrobj`
"""
import collections.abc
import contextlib
+import contextvars
from .overrides import set_module
from .umath import (
@@ -16,7 +17,7 @@ from . import umath
__all__ = [
"seterr", "geterr", "setbufsize", "getbufsize", "seterrcall", "geterrcall",
- "errstate",
+ "errstate", 'no_nep50_warning'
]
_errdict = {"ignore": ERR_IGNORE,
@@ -444,3 +445,23 @@ def _setdef():
# set the default values
_setdef()
+
+
+NO_NEP50_WARNING = contextvars.ContextVar("no_nep50_warning", default=False)
+
+@set_module('numpy')
+@contextlib.contextmanager
+def no_nep50_warning():
+ """
+ Context manager to disable NEP 50 warnings. This context manager is
+ only relevant if the NEP 50 warnings are enabled globally (which is NOT
+ thread/context safe).
+ On the other hand, this warning suppression is safe.
+ """
+ # TODO: We could skip the manager entirely if NumPy as a whole is not
+ # in the warning mode. (Which is NOT thread/context safe.)
+ token = NO_NEP50_WARNING.set(True)
+ try:
+ yield
+ finally:
+ NO_NEP50_WARNING.reset(token)
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py
index bb3cbf054..38d85da6e 100644
--- a/numpy/core/numeric.py
+++ b/numpy/core/numeric.py
@@ -27,7 +27,7 @@ from .umath import (multiply, invert, sin, PINF, NAN)
from . import numerictypes
from .numerictypes import longlong, intc, int_, float_, complex_, bool_
from ._exceptions import TooHardError, AxisError
-from ._ufunc_config import errstate
+from ._ufunc_config import errstate, no_nep50_warning
bitwise_not = invert
ufunc = type(sin)
@@ -2352,7 +2352,7 @@ def isclose(a, b, rtol=1.e-5, atol=1.e-8, equal_nan=False):
array([False, True])
"""
def within_tol(x, y, atol, rtol):
- with errstate(invalid='ignore'):
+ with errstate(invalid='ignore'), no_nep50_warning():
return less_equal(abs(x-y), atol + rtol * abs(y))
x = asanyarray(a)
diff --git a/numpy/core/src/multiarray/convert_datatype.c b/numpy/core/src/multiarray/convert_datatype.c
index 77e8a6a1d..94b7d6d98 100644
--- a/numpy/core/src/multiarray/convert_datatype.c
+++ b/numpy/core/src/multiarray/convert_datatype.c
@@ -49,6 +49,7 @@ NPY_NO_EXPORT npy_intp REQUIRED_STR_LEN[] = {0, 3, 5, 10, 10, 20, 20, 20, 20};
* Whether or not legacy value-based promotion/casting is used.
*/
NPY_NO_EXPORT int npy_promotion_state = NPY_USE_WEAK_PROMOTION_AND_WARN;
+NPY_NO_EXPORT PyObject *NO_NEP50_WARNING_CTX = NULL;
static PyObject *
PyArray_GetGenericToVoidCastingImpl(void);
@@ -63,6 +64,33 @@ static PyObject *
PyArray_GetObjectToGenericCastingImpl(void);
+/*
+ * Return 1 if promotion warnings should be given and 0 if they are currently
+ * suppressed in the local context.
+ */
+NPY_NO_EXPORT int
+npy_give_promotion_warnings(void)
+{
+ PyObject *val;
+
+ npy_cache_import(
+ "numpy.core._ufunc_config", "NO_NEP50_WARNING",
+ &NO_NEP50_WARNING_CTX);
+ if (NO_NEP50_WARNING_CTX == NULL) {
+ PyErr_WriteUnraisable(NULL);
+ return 1;
+ }
+
+ if (PyContextVar_Get(NO_NEP50_WARNING_CTX, Py_False, &val) < 0) {
+ /* Errors should not really happen, but if it does assume we warn. */
+ PyErr_WriteUnraisable(NULL);
+ return 1;
+ }
+ Py_DECREF(val);
+ /* only when the no-warnings context is false, we give warnings */
+ return val == Py_False;
+}
+
/**
* Fetch the casting implementation from one DType to another.
*
@@ -1634,15 +1662,38 @@ should_use_min_scalar(npy_intp narrs, PyArrayObject **arr,
NPY_NO_EXPORT int
should_use_min_scalar_weak_literals(int narrs, PyArrayObject **arr) {
- int count_literals = 0;
+ int all_scalars = 1;
+ int max_scalar_kind = -1;
+ int max_array_kind = -1;
+
for (int i = 0; i < narrs; i++) {
- if (PyArray_FLAGS(arr[i]) & NPY_ARRAY_WAS_PYTHON_LITERAL) {
- count_literals++;
+ if (PyArray_FLAGS(arr[i]) & NPY_ARRAY_WAS_PYTHON_INT) {
+ /* A Python integer could be `u` so is effectively that: */
+ int new = dtype_kind_to_simplified_ordering('u');
+ if (new > max_scalar_kind) {
+ max_scalar_kind = new;
+ }
+ }
+ /* For the new logic, only complex or not matters: */
+ else if (PyArray_FLAGS(arr[i]) & NPY_ARRAY_WAS_PYTHON_FLOAT) {
+ max_scalar_kind = dtype_kind_to_simplified_ordering('f');
+ }
+ else if (PyArray_FLAGS(arr[i]) & NPY_ARRAY_WAS_PYTHON_COMPLEX) {
+ max_scalar_kind = dtype_kind_to_simplified_ordering('f');
+ }
+ else {
+ all_scalars = 0;
+ int kind = dtype_kind_to_simplified_ordering(
+ PyArray_DESCR(arr[i])->kind);
+ if (kind > max_array_kind) {
+ max_array_kind = kind;
+ }
}
}
- if (count_literals > 0 && count_literals < narrs) {
+ if (!all_scalars && max_array_kind >= max_scalar_kind) {
return 1;
}
+
return 0;
}
@@ -1856,6 +1907,10 @@ PyArray_CheckLegacyResultType(
if (npy_promotion_state == NPY_USE_WEAK_PROMOTION) {
return 0;
}
+ if (npy_promotion_state == NPY_USE_WEAK_PROMOTION_AND_WARN
+ && !npy_give_promotion_warnings()) {
+ return 0;
+ }
npy_intp i;
diff --git a/numpy/core/src/multiarray/convert_datatype.h b/numpy/core/src/multiarray/convert_datatype.h
index 08a45ceae..3550f45d2 100644
--- a/numpy/core/src/multiarray/convert_datatype.h
+++ b/numpy/core/src/multiarray/convert_datatype.h
@@ -13,6 +13,10 @@ extern NPY_NO_EXPORT npy_intp REQUIRED_STR_LEN[];
#define NPY_USE_WEAK_PROMOTION 1
#define NPY_USE_WEAK_PROMOTION_AND_WARN 2
extern NPY_NO_EXPORT int npy_promotion_state;
+extern NPY_NO_EXPORT PyObject *NO_NEP50_WARNING_CTX;
+
+NPY_NO_EXPORT int
+npy_give_promotion_warnings(void);
NPY_NO_EXPORT PyObject *
PyArray_GetCastingImpl(PyArray_DTypeMeta *from, PyArray_DTypeMeta *to);
diff --git a/numpy/core/src/umath/dispatching.c b/numpy/core/src/umath/dispatching.c
index 19db47d6a..8b5db2e71 100644
--- a/numpy/core/src/umath/dispatching.c
+++ b/numpy/core/src/umath/dispatching.c
@@ -958,7 +958,8 @@ promote_and_get_ufuncimpl(PyUFuncObject *ufunc,
/* If necessary, check if the old result would have been different */
if (NPY_UNLIKELY(npy_promotion_state == NPY_USE_WEAK_PROMOTION_AND_WARN)
- && (force_legacy_promotion || promoting_pyscalars)) {
+ && (force_legacy_promotion || promoting_pyscalars)
+ && npy_give_promotion_warnings()) {
PyArray_DTypeMeta *check_dtypes[NPY_MAXARGS];
for (int i = 0; i < nargs; i++) {
check_dtypes[i] = (PyArray_DTypeMeta *)PyTuple_GET_ITEM(
diff --git a/numpy/testing/_private/utils.py b/numpy/testing/_private/utils.py
index e4f8b9892..ca64446db 100644
--- a/numpy/testing/_private/utils.py
+++ b/numpy/testing/_private/utils.py
@@ -473,6 +473,7 @@ def print_assert_equal(test_string, actual, desired):
raise AssertionError(msg.getvalue())
+@np.no_nep50_warning()
def assert_almost_equal(actual,desired,decimal=7,err_msg='',verbose=True):
"""
Raises an AssertionError if two items are not equal up to desired
@@ -599,6 +600,7 @@ def assert_almost_equal(actual,desired,decimal=7,err_msg='',verbose=True):
raise AssertionError(_build_err_msg())
+@np.no_nep50_warning()
def assert_approx_equal(actual,desired,significant=7,err_msg='',verbose=True):
"""
Raises an AssertionError if two items are not equal up to significant
@@ -698,6 +700,7 @@ def assert_approx_equal(actual,desired,significant=7,err_msg='',verbose=True):
raise AssertionError(msg)
+@np.no_nep50_warning()
def assert_array_compare(comparison, x, y, err_msg='', verbose=True, header='',
precision=6, equal_nan=True, equal_inf=True):
__tracebackhide__ = True # Hide traceback for py.test
@@ -935,6 +938,7 @@ def assert_array_equal(x, y, err_msg='', verbose=True):
verbose=verbose, header='Arrays are not equal')
+@np.no_nep50_warning()
def assert_array_almost_equal(x, y, decimal=6, err_msg='', verbose=True):
"""
Raises an AssertionError if two objects are not equal up to desired