diff options
author | Sebastian Berg <sebastian@sipsolutions.net> | 2022-05-26 20:48:02 -0700 |
---|---|---|
committer | Sebastian Berg <sebastian@sipsolutions.net> | 2022-06-15 11:42:02 -0700 |
commit | baaeb9a16c9c28683db97c4fc3d047e86d32a0c5 (patch) | |
tree | 0070a9e10de564c6d6b7d66548419e684cfca77c /numpy | |
parent | 2a6a3931c0af78180fa6984fd09bc0264c156fd0 (diff) | |
download | numpy-baaeb9a16c9c28683db97c4fc3d047e86d32a0c5.tar.gz |
WIP: Add warning context manager and fix min_scalar for new promotion
Even the new promotion has to use the min-scalar logic to avoid
picking up a float16 loop for `np.int8(3) * 3.`.
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/core/_ufunc_config.py | 23 | ||||
-rw-r--r-- | numpy/core/numeric.py | 4 | ||||
-rw-r--r-- | numpy/core/src/multiarray/convert_datatype.c | 63 | ||||
-rw-r--r-- | numpy/core/src/multiarray/convert_datatype.h | 4 | ||||
-rw-r--r-- | numpy/core/src/umath/dispatching.c | 3 | ||||
-rw-r--r-- | numpy/testing/_private/utils.py | 4 |
6 files changed, 93 insertions, 8 deletions
diff --git a/numpy/core/_ufunc_config.py b/numpy/core/_ufunc_config.py index a731f6bf7..1937c4ca2 100644 --- a/numpy/core/_ufunc_config.py +++ b/numpy/core/_ufunc_config.py @@ -5,6 +5,7 @@ This provides helpers which wrap `umath.geterrobj` and `umath.seterrobj` """ import collections.abc import contextlib +import contextvars from .overrides import set_module from .umath import ( @@ -16,7 +17,7 @@ from . import umath __all__ = [ "seterr", "geterr", "setbufsize", "getbufsize", "seterrcall", "geterrcall", - "errstate", + "errstate", 'no_nep50_warning' ] _errdict = {"ignore": ERR_IGNORE, @@ -444,3 +445,23 @@ def _setdef(): # set the default values _setdef() + + +NO_NEP50_WARNING = contextvars.ContextVar("no_nep50_warning", default=False) + +@set_module('numpy') +@contextlib.contextmanager +def no_nep50_warning(): + """ + Context manager to disable NEP 50 warnings. This context manager is + only relevant if the NEP 50 warnings are enabled globally (which is NOT + thread/context safe). + On the other hand, this warning suppression is safe. + """ + # TODO: We could skip the manager entirely if NumPy as a whole is not + # in the warning mode. (Which is NOT thread/context safe.) + token = NO_NEP50_WARNING.set(True) + try: + yield + finally: + NO_NEP50_WARNING.reset(token) diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py index bb3cbf054..38d85da6e 100644 --- a/numpy/core/numeric.py +++ b/numpy/core/numeric.py @@ -27,7 +27,7 @@ from .umath import (multiply, invert, sin, PINF, NAN) from . import numerictypes from .numerictypes import longlong, intc, int_, float_, complex_, bool_ from ._exceptions import TooHardError, AxisError -from ._ufunc_config import errstate +from ._ufunc_config import errstate, no_nep50_warning bitwise_not = invert ufunc = type(sin) @@ -2352,7 +2352,7 @@ def isclose(a, b, rtol=1.e-5, atol=1.e-8, equal_nan=False): array([False, True]) """ def within_tol(x, y, atol, rtol): - with errstate(invalid='ignore'): + with errstate(invalid='ignore'), no_nep50_warning(): return less_equal(abs(x-y), atol + rtol * abs(y)) x = asanyarray(a) diff --git a/numpy/core/src/multiarray/convert_datatype.c b/numpy/core/src/multiarray/convert_datatype.c index 77e8a6a1d..94b7d6d98 100644 --- a/numpy/core/src/multiarray/convert_datatype.c +++ b/numpy/core/src/multiarray/convert_datatype.c @@ -49,6 +49,7 @@ NPY_NO_EXPORT npy_intp REQUIRED_STR_LEN[] = {0, 3, 5, 10, 10, 20, 20, 20, 20}; * Whether or not legacy value-based promotion/casting is used. */ NPY_NO_EXPORT int npy_promotion_state = NPY_USE_WEAK_PROMOTION_AND_WARN; +NPY_NO_EXPORT PyObject *NO_NEP50_WARNING_CTX = NULL; static PyObject * PyArray_GetGenericToVoidCastingImpl(void); @@ -63,6 +64,33 @@ static PyObject * PyArray_GetObjectToGenericCastingImpl(void); +/* + * Return 1 if promotion warnings should be given and 0 if they are currently + * suppressed in the local context. + */ +NPY_NO_EXPORT int +npy_give_promotion_warnings(void) +{ + PyObject *val; + + npy_cache_import( + "numpy.core._ufunc_config", "NO_NEP50_WARNING", + &NO_NEP50_WARNING_CTX); + if (NO_NEP50_WARNING_CTX == NULL) { + PyErr_WriteUnraisable(NULL); + return 1; + } + + if (PyContextVar_Get(NO_NEP50_WARNING_CTX, Py_False, &val) < 0) { + /* Errors should not really happen, but if it does assume we warn. */ + PyErr_WriteUnraisable(NULL); + return 1; + } + Py_DECREF(val); + /* only when the no-warnings context is false, we give warnings */ + return val == Py_False; +} + /** * Fetch the casting implementation from one DType to another. * @@ -1634,15 +1662,38 @@ should_use_min_scalar(npy_intp narrs, PyArrayObject **arr, NPY_NO_EXPORT int should_use_min_scalar_weak_literals(int narrs, PyArrayObject **arr) { - int count_literals = 0; + int all_scalars = 1; + int max_scalar_kind = -1; + int max_array_kind = -1; + for (int i = 0; i < narrs; i++) { - if (PyArray_FLAGS(arr[i]) & NPY_ARRAY_WAS_PYTHON_LITERAL) { - count_literals++; + if (PyArray_FLAGS(arr[i]) & NPY_ARRAY_WAS_PYTHON_INT) { + /* A Python integer could be `u` so is effectively that: */ + int new = dtype_kind_to_simplified_ordering('u'); + if (new > max_scalar_kind) { + max_scalar_kind = new; + } + } + /* For the new logic, only complex or not matters: */ + else if (PyArray_FLAGS(arr[i]) & NPY_ARRAY_WAS_PYTHON_FLOAT) { + max_scalar_kind = dtype_kind_to_simplified_ordering('f'); + } + else if (PyArray_FLAGS(arr[i]) & NPY_ARRAY_WAS_PYTHON_COMPLEX) { + max_scalar_kind = dtype_kind_to_simplified_ordering('f'); + } + else { + all_scalars = 0; + int kind = dtype_kind_to_simplified_ordering( + PyArray_DESCR(arr[i])->kind); + if (kind > max_array_kind) { + max_array_kind = kind; + } } } - if (count_literals > 0 && count_literals < narrs) { + if (!all_scalars && max_array_kind >= max_scalar_kind) { return 1; } + return 0; } @@ -1856,6 +1907,10 @@ PyArray_CheckLegacyResultType( if (npy_promotion_state == NPY_USE_WEAK_PROMOTION) { return 0; } + if (npy_promotion_state == NPY_USE_WEAK_PROMOTION_AND_WARN + && !npy_give_promotion_warnings()) { + return 0; + } npy_intp i; diff --git a/numpy/core/src/multiarray/convert_datatype.h b/numpy/core/src/multiarray/convert_datatype.h index 08a45ceae..3550f45d2 100644 --- a/numpy/core/src/multiarray/convert_datatype.h +++ b/numpy/core/src/multiarray/convert_datatype.h @@ -13,6 +13,10 @@ extern NPY_NO_EXPORT npy_intp REQUIRED_STR_LEN[]; #define NPY_USE_WEAK_PROMOTION 1 #define NPY_USE_WEAK_PROMOTION_AND_WARN 2 extern NPY_NO_EXPORT int npy_promotion_state; +extern NPY_NO_EXPORT PyObject *NO_NEP50_WARNING_CTX; + +NPY_NO_EXPORT int +npy_give_promotion_warnings(void); NPY_NO_EXPORT PyObject * PyArray_GetCastingImpl(PyArray_DTypeMeta *from, PyArray_DTypeMeta *to); diff --git a/numpy/core/src/umath/dispatching.c b/numpy/core/src/umath/dispatching.c index 19db47d6a..8b5db2e71 100644 --- a/numpy/core/src/umath/dispatching.c +++ b/numpy/core/src/umath/dispatching.c @@ -958,7 +958,8 @@ promote_and_get_ufuncimpl(PyUFuncObject *ufunc, /* If necessary, check if the old result would have been different */ if (NPY_UNLIKELY(npy_promotion_state == NPY_USE_WEAK_PROMOTION_AND_WARN) - && (force_legacy_promotion || promoting_pyscalars)) { + && (force_legacy_promotion || promoting_pyscalars) + && npy_give_promotion_warnings()) { PyArray_DTypeMeta *check_dtypes[NPY_MAXARGS]; for (int i = 0; i < nargs; i++) { check_dtypes[i] = (PyArray_DTypeMeta *)PyTuple_GET_ITEM( diff --git a/numpy/testing/_private/utils.py b/numpy/testing/_private/utils.py index e4f8b9892..ca64446db 100644 --- a/numpy/testing/_private/utils.py +++ b/numpy/testing/_private/utils.py @@ -473,6 +473,7 @@ def print_assert_equal(test_string, actual, desired): raise AssertionError(msg.getvalue()) +@np.no_nep50_warning() def assert_almost_equal(actual,desired,decimal=7,err_msg='',verbose=True): """ Raises an AssertionError if two items are not equal up to desired @@ -599,6 +600,7 @@ def assert_almost_equal(actual,desired,decimal=7,err_msg='',verbose=True): raise AssertionError(_build_err_msg()) +@np.no_nep50_warning() def assert_approx_equal(actual,desired,significant=7,err_msg='',verbose=True): """ Raises an AssertionError if two items are not equal up to significant @@ -698,6 +700,7 @@ def assert_approx_equal(actual,desired,significant=7,err_msg='',verbose=True): raise AssertionError(msg) +@np.no_nep50_warning() def assert_array_compare(comparison, x, y, err_msg='', verbose=True, header='', precision=6, equal_nan=True, equal_inf=True): __tracebackhide__ = True # Hide traceback for py.test @@ -935,6 +938,7 @@ def assert_array_equal(x, y, err_msg='', verbose=True): verbose=verbose, header='Arrays are not equal') +@np.no_nep50_warning() def assert_array_almost_equal(x, y, decimal=6, err_msg='', verbose=True): """ Raises an AssertionError if two objects are not equal up to desired |