diff options
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/core/arrayprint.py | 6 | ||||
-rw-r--r-- | numpy/core/overrides.py | 57 | ||||
-rw-r--r-- | numpy/core/shape_base.py | 8 | ||||
-rw-r--r-- | numpy/core/tests/test_overrides.py | 20 | ||||
-rw-r--r-- | numpy/lib/shape_base.py | 5 | ||||
-rw-r--r-- | numpy/lib/ufunclike.py | 6 | ||||
-rw-r--r-- | numpy/testing/tests/test_utils.py | 9 |
7 files changed, 95 insertions, 16 deletions
diff --git a/numpy/core/arrayprint.py b/numpy/core/arrayprint.py index ccc1468c4..b578fab54 100644 --- a/numpy/core/arrayprint.py +++ b/numpy/core/arrayprint.py @@ -1547,10 +1547,12 @@ def array_str(a, max_line_width=None, precision=None, suppress_small=None): a, max_line_width, precision, suppress_small) +# needed if __array_function__ is disabled +_array2string_impl = getattr(array2string, '__wrapped__', array2string) _default_array_str = functools.partial(_array_str_implementation, - array2string=array2string.__wrapped__) + array2string=_array2string_impl) _default_array_repr = functools.partial(_array_repr_implementation, - array2string=array2string.__wrapped__) + array2string=_array2string_impl) def set_string_function(f, repr=True): diff --git a/numpy/core/overrides.py b/numpy/core/overrides.py index 088c15e65..e4d505f06 100644 --- a/numpy/core/overrides.py +++ b/numpy/core/overrides.py @@ -4,6 +4,7 @@ TODO: rewrite this in C for performance. """ import collections import functools +import os from numpy.core._multiarray_umath import ndarray from numpy.compat._inspect import getargspec @@ -12,6 +13,9 @@ from numpy.compat._inspect import getargspec _NDARRAY_ARRAY_FUNCTION = ndarray.__array_function__ _NDARRAY_ONLY = [ndarray] +ENABLE_ARRAY_FUNCTION = bool( + int(os.environ.get('NUMPY_EXPERIMENTAL_ARRAY_FUNCTION', 0))) + def get_overloaded_types_and_args(relevant_args): """Returns a list of arguments on which to call __array_function__. @@ -146,12 +150,57 @@ def verify_matching_signatures(implementation, dispatcher): 'default argument values') +def override_module(module): + """Decorator for overriding __module__ on a function or class. + + Example usage:: + + @override_module('numpy') + def example(): + pass + + assert example.__module__ == 'numpy' + """ + def decorator(func): + if module is not None: + func.__module__ = module + return func + return decorator + + def array_function_dispatch(dispatcher, module=None, verify=True): - """Decorator for adding dispatch with the __array_function__ protocol.""" + """Decorator for adding dispatch with the __array_function__ protocol. + + See NEP-18 for example usage. + + Parameters + ---------- + dispatcher : callable + Function that when called like ``dispatcher(*args, **kwargs)`` with + arguments from the NumPy function call returns an iterable of + array-like arguments to check for ``__array_function__``. + module : str, optional + __module__ attribute to set on new function, e.g., ``module='numpy'``. + By default, module is copied from the decorated function. + verify : bool, optional + If True, verify the that the signature of the dispatcher and decorated + function signatures match exactly: all required and optional arguments + should appear in order with the same names, but the default values for + all optional arguments should be ``None``. Only disable verification + if the dispatcher's signature needs to deviate for some particular + reason, e.g., because the function has a signature like + ``func(*args, **kwargs)``. + + Returns + ------- + Function suitable for decorating the implementation of a NumPy function. + """ + + if not ENABLE_ARRAY_FUNCTION: + # __array_function__ requires an explicit opt-in for now + return override_module(module) + def decorator(implementation): - # TODO: only do this check when the appropriate flag is enabled or for - # a dev install. We want this check for testing but don't want to - # slow down all numpy imports. if verify: verify_matching_signatures(implementation, dispatcher) diff --git a/numpy/core/shape_base.py b/numpy/core/shape_base.py index 3edf0824e..6d234e527 100644 --- a/numpy/core/shape_base.py +++ b/numpy/core/shape_base.py @@ -217,6 +217,11 @@ def _arrays_for_stack_dispatcher(arrays, stacklevel=4): return arrays +def _warn_for_nonsequence(arrays): + if not overrides.ENABLE_ARRAY_FUNCTION: + _arrays_for_stack_dispatcher(arrays, stacklevel=4) + + def _vhstack_dispatcher(tup): return _arrays_for_stack_dispatcher(tup) @@ -274,6 +279,7 @@ def vstack(tup): [4]]) """ + _warn_for_nonsequence(tup) return _nx.concatenate([atleast_2d(_m) for _m in tup], 0) @@ -325,6 +331,7 @@ def hstack(tup): [3, 4]]) """ + _warn_for_nonsequence(tup) arrs = [atleast_1d(_m) for _m in tup] # As a special case, dimension 0 of 1-dimensional arrays is "horizontal" if arrs and arrs[0].ndim == 1: @@ -398,6 +405,7 @@ def stack(arrays, axis=0, out=None): [3, 4]]) """ + _warn_for_nonsequence(arrays) arrays = [asanyarray(arr) for arr in arrays] if not arrays: raise ValueError('need at least one array to stack') diff --git a/numpy/core/tests/test_overrides.py b/numpy/core/tests/test_overrides.py index ee6d5da4a..a32049ea1 100644 --- a/numpy/core/tests/test_overrides.py +++ b/numpy/core/tests/test_overrides.py @@ -1,5 +1,6 @@ from __future__ import division, absolute_import, print_function +import inspect import sys import numpy as np @@ -7,8 +8,14 @@ from numpy.testing import ( assert_, assert_equal, assert_raises, assert_raises_regex) from numpy.core.overrides import ( get_overloaded_types_and_args, array_function_dispatch, - verify_matching_signatures) + verify_matching_signatures, ENABLE_ARRAY_FUNCTION) from numpy.core.numeric import pickle +import pytest + + +requires_array_function = pytest.mark.skipif( + not ENABLE_ARRAY_FUNCTION, + reason="__array_function__ dispatch not enabled.") def _get_overloaded_args(relevant_args): @@ -165,6 +172,7 @@ def dispatched_one_arg(array): return 'original' +@requires_array_function class TestArrayFunctionDispatch(object): def test_pickle(self): @@ -204,6 +212,7 @@ class TestArrayFunctionDispatch(object): dispatched_one_arg(array) +@requires_array_function class TestVerifyMatchingSignatures(object): def test_verify_matching_signatures(self): @@ -256,6 +265,7 @@ def _new_duck_type_and_implements(): return (MyArray, implements) +@requires_array_function class TestArrayFunctionImplementation(object): def test_one_arg(self): @@ -322,7 +332,7 @@ class TestNDArrayMethods(object): assert_equal(repr(array), 'MyArray(1)') assert_equal(str(array), '1') - + class TestNumPyFunctions(object): def test_module(self): @@ -331,6 +341,12 @@ class TestNumPyFunctions(object): assert_equal(np.fft.fft.__module__, 'numpy.fft') assert_equal(np.linalg.solve.__module__, 'numpy.linalg') + @pytest.mark.skipif(sys.version_info[0] < 3, reason="Python 3 only") + def test_inspect_sum(self): + signature = inspect.signature(np.sum) + assert_('axis' in signature.parameters) + + @requires_array_function def test_override_sum(self): MyArray, implements = _new_duck_type_and_implements() diff --git a/numpy/lib/shape_base.py b/numpy/lib/shape_base.py index 6e7cab3fa..f56c4f4db 100644 --- a/numpy/lib/shape_base.py +++ b/numpy/lib/shape_base.py @@ -11,7 +11,8 @@ from numpy.core.fromnumeric import product, reshape, transpose from numpy.core.multiarray import normalize_axis_index from numpy.core import overrides from numpy.core import vstack, atleast_3d -from numpy.core.shape_base import _arrays_for_stack_dispatcher +from numpy.core.shape_base import ( + _arrays_for_stack_dispatcher, _warn_for_nonsequence) from numpy.lib.index_tricks import ndindex from numpy.matrixlib.defmatrix import matrix # this raises all the right alarm bells @@ -629,6 +630,7 @@ def column_stack(tup): [3, 4]]) """ + _warn_for_nonsequence(tup) arrays = [] for v in tup: arr = array(v, copy=False, subok=True) @@ -693,6 +695,7 @@ def dstack(tup): [[3, 4]]]) """ + _warn_for_nonsequence(tup) return _nx.concatenate([atleast_3d(_m) for _m in tup], 2) diff --git a/numpy/lib/ufunclike.py b/numpy/lib/ufunclike.py index ac0af0b37..9a9e6f9dd 100644 --- a/numpy/lib/ufunclike.py +++ b/numpy/lib/ufunclike.py @@ -8,7 +8,7 @@ from __future__ import division, absolute_import, print_function __all__ = ['fix', 'isneginf', 'isposinf'] import numpy.core.numeric as nx -from numpy.core.overrides import array_function_dispatch +from numpy.core.overrides import array_function_dispatch, ENABLE_ARRAY_FUNCTION import warnings import functools @@ -55,6 +55,10 @@ def _fix_out_named_y(f): return func +if not ENABLE_ARRAY_FUNCTION: + _fix_out_named_y = _deprecate_out_named_y + + @_deprecate_out_named_y def _dispatcher(x, out=None): return (x, out) diff --git a/numpy/testing/tests/test_utils.py b/numpy/testing/tests/test_utils.py index e54fbc390..e35cdb6cf 100644 --- a/numpy/testing/tests/test_utils.py +++ b/numpy/testing/tests/test_utils.py @@ -180,18 +180,15 @@ class TestArrayEqual(_GenericTest): self._test_not_equal(b, a) def test_subclass_that_does_not_implement_npall(self): - # While we cannot guarantee testing functions will always work for - # subclasses, the tests should ideally rely only on subclasses having - # comparison operators, not on them being able to store booleans - # (which, e.g., astropy Quantity cannot usefully do). See gh-8452. class MyArray(np.ndarray): def __array_function__(self, *args, **kwargs): return NotImplemented a = np.array([1., 2.]).view(MyArray) b = np.array([2., 3.]).view(MyArray) - with assert_raises(TypeError): - np.all(a) + if np.core.overrides.ENABLE_ARRAY_FUNCTION: + with assert_raises(TypeError): + np.all(a) self._test_equal(a, a) self._test_not_equal(a, b) self._test_not_equal(b, a) |