summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2018-11-12 17:17:24 -0600
committerGitHub <noreply@github.com>2018-11-12 17:17:24 -0600
commit067264b20f8d2c043cf5b9cd3d1826384e558e79 (patch)
tree4714149704808160daa1ee9b2bdaa9cfab2672d8 /numpy
parente34f5bb4e3aad2bb35fb6d9e3a5c1bfc85eb97eb (diff)
parent9fa3a4e9802b32c985f6a1fc14dd315a2656ac38 (diff)
downloadnumpy-067264b20f8d2c043cf5b9cd3d1826384e558e79.tar.gz
Merge pull request #12362 from shoyer/disable-array-function-by-default
MAINT: disable `__array_function__` dispatch unless environment variable set
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/arrayprint.py6
-rw-r--r--numpy/core/overrides.py57
-rw-r--r--numpy/core/shape_base.py8
-rw-r--r--numpy/core/tests/test_overrides.py20
-rw-r--r--numpy/lib/shape_base.py5
-rw-r--r--numpy/lib/ufunclike.py6
-rw-r--r--numpy/testing/tests/test_utils.py9
7 files changed, 95 insertions, 16 deletions
diff --git a/numpy/core/arrayprint.py b/numpy/core/arrayprint.py
index ccc1468c4..b578fab54 100644
--- a/numpy/core/arrayprint.py
+++ b/numpy/core/arrayprint.py
@@ -1547,10 +1547,12 @@ def array_str(a, max_line_width=None, precision=None, suppress_small=None):
a, max_line_width, precision, suppress_small)
+# needed if __array_function__ is disabled
+_array2string_impl = getattr(array2string, '__wrapped__', array2string)
_default_array_str = functools.partial(_array_str_implementation,
- array2string=array2string.__wrapped__)
+ array2string=_array2string_impl)
_default_array_repr = functools.partial(_array_repr_implementation,
- array2string=array2string.__wrapped__)
+ array2string=_array2string_impl)
def set_string_function(f, repr=True):
diff --git a/numpy/core/overrides.py b/numpy/core/overrides.py
index 088c15e65..e4d505f06 100644
--- a/numpy/core/overrides.py
+++ b/numpy/core/overrides.py
@@ -4,6 +4,7 @@ TODO: rewrite this in C for performance.
"""
import collections
import functools
+import os
from numpy.core._multiarray_umath import ndarray
from numpy.compat._inspect import getargspec
@@ -12,6 +13,9 @@ from numpy.compat._inspect import getargspec
_NDARRAY_ARRAY_FUNCTION = ndarray.__array_function__
_NDARRAY_ONLY = [ndarray]
+ENABLE_ARRAY_FUNCTION = bool(
+ int(os.environ.get('NUMPY_EXPERIMENTAL_ARRAY_FUNCTION', 0)))
+
def get_overloaded_types_and_args(relevant_args):
"""Returns a list of arguments on which to call __array_function__.
@@ -146,12 +150,57 @@ def verify_matching_signatures(implementation, dispatcher):
'default argument values')
+def override_module(module):
+ """Decorator for overriding __module__ on a function or class.
+
+ Example usage::
+
+ @override_module('numpy')
+ def example():
+ pass
+
+ assert example.__module__ == 'numpy'
+ """
+ def decorator(func):
+ if module is not None:
+ func.__module__ = module
+ return func
+ return decorator
+
+
def array_function_dispatch(dispatcher, module=None, verify=True):
- """Decorator for adding dispatch with the __array_function__ protocol."""
+ """Decorator for adding dispatch with the __array_function__ protocol.
+
+ See NEP-18 for example usage.
+
+ Parameters
+ ----------
+ dispatcher : callable
+ Function that when called like ``dispatcher(*args, **kwargs)`` with
+ arguments from the NumPy function call returns an iterable of
+ array-like arguments to check for ``__array_function__``.
+ module : str, optional
+ __module__ attribute to set on new function, e.g., ``module='numpy'``.
+ By default, module is copied from the decorated function.
+ verify : bool, optional
+ If True, verify the that the signature of the dispatcher and decorated
+ function signatures match exactly: all required and optional arguments
+ should appear in order with the same names, but the default values for
+ all optional arguments should be ``None``. Only disable verification
+ if the dispatcher's signature needs to deviate for some particular
+ reason, e.g., because the function has a signature like
+ ``func(*args, **kwargs)``.
+
+ Returns
+ -------
+ Function suitable for decorating the implementation of a NumPy function.
+ """
+
+ if not ENABLE_ARRAY_FUNCTION:
+ # __array_function__ requires an explicit opt-in for now
+ return override_module(module)
+
def decorator(implementation):
- # TODO: only do this check when the appropriate flag is enabled or for
- # a dev install. We want this check for testing but don't want to
- # slow down all numpy imports.
if verify:
verify_matching_signatures(implementation, dispatcher)
diff --git a/numpy/core/shape_base.py b/numpy/core/shape_base.py
index 3edf0824e..6d234e527 100644
--- a/numpy/core/shape_base.py
+++ b/numpy/core/shape_base.py
@@ -217,6 +217,11 @@ def _arrays_for_stack_dispatcher(arrays, stacklevel=4):
return arrays
+def _warn_for_nonsequence(arrays):
+ if not overrides.ENABLE_ARRAY_FUNCTION:
+ _arrays_for_stack_dispatcher(arrays, stacklevel=4)
+
+
def _vhstack_dispatcher(tup):
return _arrays_for_stack_dispatcher(tup)
@@ -274,6 +279,7 @@ def vstack(tup):
[4]])
"""
+ _warn_for_nonsequence(tup)
return _nx.concatenate([atleast_2d(_m) for _m in tup], 0)
@@ -325,6 +331,7 @@ def hstack(tup):
[3, 4]])
"""
+ _warn_for_nonsequence(tup)
arrs = [atleast_1d(_m) for _m in tup]
# As a special case, dimension 0 of 1-dimensional arrays is "horizontal"
if arrs and arrs[0].ndim == 1:
@@ -398,6 +405,7 @@ def stack(arrays, axis=0, out=None):
[3, 4]])
"""
+ _warn_for_nonsequence(arrays)
arrays = [asanyarray(arr) for arr in arrays]
if not arrays:
raise ValueError('need at least one array to stack')
diff --git a/numpy/core/tests/test_overrides.py b/numpy/core/tests/test_overrides.py
index ee6d5da4a..a32049ea1 100644
--- a/numpy/core/tests/test_overrides.py
+++ b/numpy/core/tests/test_overrides.py
@@ -1,5 +1,6 @@
from __future__ import division, absolute_import, print_function
+import inspect
import sys
import numpy as np
@@ -7,8 +8,14 @@ from numpy.testing import (
assert_, assert_equal, assert_raises, assert_raises_regex)
from numpy.core.overrides import (
get_overloaded_types_and_args, array_function_dispatch,
- verify_matching_signatures)
+ verify_matching_signatures, ENABLE_ARRAY_FUNCTION)
from numpy.core.numeric import pickle
+import pytest
+
+
+requires_array_function = pytest.mark.skipif(
+ not ENABLE_ARRAY_FUNCTION,
+ reason="__array_function__ dispatch not enabled.")
def _get_overloaded_args(relevant_args):
@@ -165,6 +172,7 @@ def dispatched_one_arg(array):
return 'original'
+@requires_array_function
class TestArrayFunctionDispatch(object):
def test_pickle(self):
@@ -204,6 +212,7 @@ class TestArrayFunctionDispatch(object):
dispatched_one_arg(array)
+@requires_array_function
class TestVerifyMatchingSignatures(object):
def test_verify_matching_signatures(self):
@@ -256,6 +265,7 @@ def _new_duck_type_and_implements():
return (MyArray, implements)
+@requires_array_function
class TestArrayFunctionImplementation(object):
def test_one_arg(self):
@@ -322,7 +332,7 @@ class TestNDArrayMethods(object):
assert_equal(repr(array), 'MyArray(1)')
assert_equal(str(array), '1')
-
+
class TestNumPyFunctions(object):
def test_module(self):
@@ -331,6 +341,12 @@ class TestNumPyFunctions(object):
assert_equal(np.fft.fft.__module__, 'numpy.fft')
assert_equal(np.linalg.solve.__module__, 'numpy.linalg')
+ @pytest.mark.skipif(sys.version_info[0] < 3, reason="Python 3 only")
+ def test_inspect_sum(self):
+ signature = inspect.signature(np.sum)
+ assert_('axis' in signature.parameters)
+
+ @requires_array_function
def test_override_sum(self):
MyArray, implements = _new_duck_type_and_implements()
diff --git a/numpy/lib/shape_base.py b/numpy/lib/shape_base.py
index 6e7cab3fa..f56c4f4db 100644
--- a/numpy/lib/shape_base.py
+++ b/numpy/lib/shape_base.py
@@ -11,7 +11,8 @@ from numpy.core.fromnumeric import product, reshape, transpose
from numpy.core.multiarray import normalize_axis_index
from numpy.core import overrides
from numpy.core import vstack, atleast_3d
-from numpy.core.shape_base import _arrays_for_stack_dispatcher
+from numpy.core.shape_base import (
+ _arrays_for_stack_dispatcher, _warn_for_nonsequence)
from numpy.lib.index_tricks import ndindex
from numpy.matrixlib.defmatrix import matrix # this raises all the right alarm bells
@@ -629,6 +630,7 @@ def column_stack(tup):
[3, 4]])
"""
+ _warn_for_nonsequence(tup)
arrays = []
for v in tup:
arr = array(v, copy=False, subok=True)
@@ -693,6 +695,7 @@ def dstack(tup):
[[3, 4]]])
"""
+ _warn_for_nonsequence(tup)
return _nx.concatenate([atleast_3d(_m) for _m in tup], 2)
diff --git a/numpy/lib/ufunclike.py b/numpy/lib/ufunclike.py
index ac0af0b37..9a9e6f9dd 100644
--- a/numpy/lib/ufunclike.py
+++ b/numpy/lib/ufunclike.py
@@ -8,7 +8,7 @@ from __future__ import division, absolute_import, print_function
__all__ = ['fix', 'isneginf', 'isposinf']
import numpy.core.numeric as nx
-from numpy.core.overrides import array_function_dispatch
+from numpy.core.overrides import array_function_dispatch, ENABLE_ARRAY_FUNCTION
import warnings
import functools
@@ -55,6 +55,10 @@ def _fix_out_named_y(f):
return func
+if not ENABLE_ARRAY_FUNCTION:
+ _fix_out_named_y = _deprecate_out_named_y
+
+
@_deprecate_out_named_y
def _dispatcher(x, out=None):
return (x, out)
diff --git a/numpy/testing/tests/test_utils.py b/numpy/testing/tests/test_utils.py
index e54fbc390..e35cdb6cf 100644
--- a/numpy/testing/tests/test_utils.py
+++ b/numpy/testing/tests/test_utils.py
@@ -180,18 +180,15 @@ class TestArrayEqual(_GenericTest):
self._test_not_equal(b, a)
def test_subclass_that_does_not_implement_npall(self):
- # While we cannot guarantee testing functions will always work for
- # subclasses, the tests should ideally rely only on subclasses having
- # comparison operators, not on them being able to store booleans
- # (which, e.g., astropy Quantity cannot usefully do). See gh-8452.
class MyArray(np.ndarray):
def __array_function__(self, *args, **kwargs):
return NotImplemented
a = np.array([1., 2.]).view(MyArray)
b = np.array([2., 3.]).view(MyArray)
- with assert_raises(TypeError):
- np.all(a)
+ if np.core.overrides.ENABLE_ARRAY_FUNCTION:
+ with assert_raises(TypeError):
+ np.all(a)
self._test_equal(a, a)
self._test_not_equal(a, b)
self._test_not_equal(b, a)