diff options
-rw-r--r-- | doc/source/reference/routines.other.rst | 9 | ||||
-rw-r--r-- | numpy/add_newdocs.py | 31 | ||||
-rw-r--r-- | numpy/core/_internal.py | 4 | ||||
-rw-r--r-- | numpy/core/function_base.py | 47 | ||||
-rw-r--r-- | numpy/core/numeric.py | 9 | ||||
-rw-r--r-- | numpy/core/src/multiarray/multiarray_tests.c.src | 72 | ||||
-rw-r--r-- | numpy/core/src/multiarray/multiarraymodule.c | 82 | ||||
-rw-r--r-- | numpy/core/tests/test_mem_overlap.py | 32 |
8 files changed, 184 insertions, 102 deletions
diff --git a/doc/source/reference/routines.other.rst b/doc/source/reference/routines.other.rst index 354f45733..a3a1f8a06 100644 --- a/doc/source/reference/routines.other.rst +++ b/doc/source/reference/routines.other.rst @@ -22,3 +22,12 @@ Performance tuning restoredot setbufsize getbufsize + +Memory ranges +------------- + +.. autosummary:: + :toctree: generated/ + + shares_memory + may_share_memory diff --git a/numpy/add_newdocs.py b/numpy/add_newdocs.py index 195789a39..0af053f89 100644 --- a/numpy/add_newdocs.py +++ b/numpy/add_newdocs.py @@ -3784,24 +3784,41 @@ add_newdoc('numpy.core.multiarray', 'ndarray', ('min', """)) -add_newdoc('numpy.core.multiarray', 'may_share_memory', +add_newdoc('numpy.core.multiarray', 'shares_memory', """ - Determine if two arrays can share memory + shares_memory(a, b, max_work=None) - The memory-bounds of a and b are computed. If they overlap then - this function returns True. Otherwise, it returns False. - - A return of True does not necessarily mean that the two arrays - share any element. It just means that they *might*. + Determine if two arrays share memory Parameters ---------- a, b : ndarray + Input arrays + max_work : int, optional + Effort to spend on solving the overlap problem (maximum number + of candidate solutions to consider). Note max_work=1 handles + most usual cases. In addition, the following special values + are recognized: + + max_work=MAY_SHARE_EXACT (default) + The problem is solved exactly. In this case, the function returns + True only if there is an element shared between the arrays. + max_work=MAY_SHARE_BOUNDS + Only the memory bounds of a and b are checked. + + Raises + ------ + numpy.TooHardError + Exceeded max_work. Returns ------- out : bool + See Also + -------- + may_share_memory + Examples -------- >>> np.may_share_memory(np.array([1,2]), np.array([5,8,9])) diff --git a/numpy/core/_internal.py b/numpy/core/_internal.py index bf492d105..879f4a224 100644 --- a/numpy/core/_internal.py +++ b/numpy/core/_internal.py @@ -759,3 +759,7 @@ def _gcd(a, b): while b: a, b = b, a % b return a + +# Exception used in shares_memory() +class TooHardError(RuntimeError): + pass diff --git a/numpy/core/function_base.py b/numpy/core/function_base.py index 532ef2950..05fea557a 100644 --- a/numpy/core/function_base.py +++ b/numpy/core/function_base.py @@ -1,9 +1,9 @@ from __future__ import division, absolute_import, print_function -__all__ = ['logspace', 'linspace'] +__all__ = ['logspace', 'linspace', 'may_share_memory'] from . import numeric as _nx -from .numeric import result_type, NaN +from .numeric import result_type, NaN, shares_memory, MAY_SHARE_BOUNDS, TooHardError def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None): @@ -201,3 +201,46 @@ def logspace(start, stop, num=50, endpoint=True, base=10.0, dtype=None): if dtype is None: return _nx.power(base, y) return _nx.power(base, y).astype(dtype) + + +def may_share_memory(a, b, max_work=None): + """Determine if two arrays can share memory + + A return of True does not necessarily mean that the two arrays + share any element. It just means that they *might*. + + Only the memory bounds of a and b are checked by default. + + Parameters + ---------- + a, b : ndarray + Input arrays + max_work : int, optional + Effort to spend on solving the overlap problem. See + `shares_memory` for details. Default for ``may_share_memory`` + is to do a bounds check. + + Returns + ------- + out : bool + + See Also + -------- + shares_memory + + Examples + -------- + >>> np.may_share_memory(np.array([1,2]), np.array([5,8,9])) + False + >>> x = np.zeros([3, 4]) + >>> np.may_share_memory(x[:,0], x[:,1]) + True + + """ + if max_work is None: + max_work = MAY_SHARE_BOUNDS + try: + return shares_memory(a, b, max_work=max_work) + except (TooHardError, OverflowError): + # Unable to determine, assume yes + return True diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py index f0163876f..1b7dfca3e 100644 --- a/numpy/core/numeric.py +++ b/numpy/core/numeric.py @@ -10,6 +10,7 @@ from .umath import (invert, sin, UFUNC_BUFSIZE_DEFAULT, ERR_IGNORE, ERR_DEFAULT, PINF, NAN) from . import numerictypes from .numerictypes import longlong, intc, int_, float_, complex_, bool_ +from ._internal import TooHardError if sys.version_info[0] >= 3: import pickle @@ -39,8 +40,8 @@ __all__ = [ 'getbufsize', 'seterrcall', 'geterrcall', 'errstate', 'flatnonzero', 'Inf', 'inf', 'infty', 'Infinity', 'nan', 'NaN', 'False_', 'True_', 'bitwise_not', 'CLIP', 'RAISE', 'WRAP', 'MAXDIMS', 'BUFSIZE', - 'ALLOW_THREADS', 'ComplexWarning', 'may_share_memory', 'full', - 'full_like', 'matmul', + 'ALLOW_THREADS', 'ComplexWarning', 'full', 'full_like', 'matmul', + 'shares_memory', 'MAY_SHARE_BOUNDS', 'MAY_SHARE_EXACT', 'TooHardError', ] if sys.version_info[0] < 3: @@ -65,6 +66,8 @@ RAISE = multiarray.RAISE MAXDIMS = multiarray.MAXDIMS ALLOW_THREADS = multiarray.ALLOW_THREADS BUFSIZE = multiarray.BUFSIZE +MAY_SHARE_BOUNDS = multiarray.MAY_SHARE_BOUNDS +MAY_SHARE_EXACT = multiarray.MAY_SHARE_EXACT ndarray = multiarray.ndarray flatiter = multiarray.flatiter @@ -375,7 +378,7 @@ fromstring = multiarray.fromstring fromiter = multiarray.fromiter fromfile = multiarray.fromfile frombuffer = multiarray.frombuffer -may_share_memory = multiarray.may_share_memory +shares_memory = multiarray.shares_memory if sys.version_info[0] < 3: newbuffer = multiarray.newbuffer getbuffer = multiarray.getbuffer diff --git a/numpy/core/src/multiarray/multiarray_tests.c.src b/numpy/core/src/multiarray/multiarray_tests.c.src index 706fa022f..4e59f57f7 100644 --- a/numpy/core/src/multiarray/multiarray_tests.c.src +++ b/numpy/core/src/multiarray/multiarray_tests.c.src @@ -1054,75 +1054,6 @@ fail: return NULL; } -static PyObject * -array_solve_may_share_memory(PyObject *NPY_UNUSED(ignored), PyObject *args, PyObject *kwds) -{ - PyArrayObject * self = NULL; - PyArrayObject * other = NULL; - PyObject *max_work_obj = NULL; - static char *kwlist[] = {"self", "other", "max_work", NULL}; - - mem_overlap_t result; - Py_ssize_t max_work; - NPY_BEGIN_THREADS_DEF; - - if (!PyArg_ParseTupleAndKeywords(args, kwds, "O&O&|O", kwlist, - PyArray_Converter, &self, - PyArray_Converter, &other, - &max_work_obj)) { - return NULL; - } - - if (max_work_obj == NULL || max_work_obj == Py_None) { - max_work = NPY_MAY_SHARE_BOUNDS; - } -#if defined(NPY_PY3K) - else if (PyLong_Check(max_work_obj)) { - max_work = PyLong_AsSsize_t(max_work_obj); - } -#else - else if (PyInt_Check(max_work_obj)) { - max_work = PyInt_AsSsize_t(max_work_obj); - } -#endif - else { - PyErr_SetString(PyExc_ValueError, "max_work must be an integer"); - goto fail; - } - - if (max_work < -1) { - PyErr_SetString(PyExc_ValueError, "Invalid value for max_work"); - goto fail; - } - - NPY_BEGIN_THREADS; - result = solve_may_share_memory(self, other, max_work); - NPY_END_THREADS; - - Py_XDECREF(self); - Py_XDECREF(other); - - if (result == MEM_OVERLAP_NO) { - Py_RETURN_FALSE; - } - else if (result == MEM_OVERLAP_YES || - result == MEM_OVERLAP_OVERFLOW || - result == MEM_OVERLAP_TOO_HARD) { - Py_RETURN_TRUE; - } - else { - /* Doesn't happen usually */ - PyErr_SetString(PyExc_RuntimeError, - "Internal error in may_share_memory"); - return NULL; - } - -fail: - Py_XDECREF(self); - Py_XDECREF(other); - return NULL; -} - static PyObject * pylong_from_int128(npy_extint128_t value) @@ -1606,9 +1537,6 @@ static PyMethodDef Multiarray_TestsMethods[] = { {"solve_diophantine", (PyCFunction)array_solve_diophantine, METH_VARARGS | METH_KEYWORDS, NULL}, - {"solve_may_share_memory", - (PyCFunction)array_solve_may_share_memory, - METH_VARARGS | METH_KEYWORDS, NULL}, {"extint_safe_binop", extint_safe_binop, METH_VARARGS, NULL}, diff --git a/numpy/core/src/multiarray/multiarraymodule.c b/numpy/core/src/multiarray/multiarraymodule.c index 04513c56c..e72c355dc 100644 --- a/numpy/core/src/multiarray/multiarraymodule.c +++ b/numpy/core/src/multiarray/multiarraymodule.c @@ -60,6 +60,7 @@ NPY_NO_EXPORT int NPY_NUMUSERTYPES = 0; #include "vdot.h" #include "templ_common.h" /* for npy_mul_with_overflow_intp */ #include "compiled_base.h" +#include "mem_overlap.h" /* Only here for API compatibility */ NPY_NO_EXPORT PyTypeObject PyBigArray_Type; @@ -3993,30 +3994,88 @@ test_interrupt(PyObject *NPY_UNUSED(self), PyObject *args) return PyInt_FromLong(a); } + static PyObject * -array_may_share_memory(PyObject *NPY_UNUSED(ignored), PyObject *args) +array_shares_memory(PyObject *NPY_UNUSED(ignored), PyObject *args, PyObject *kwds) { PyArrayObject * self = NULL; PyArrayObject * other = NULL; - int overlap; + PyObject *max_work_obj = NULL; + static char *kwlist[] = {"self", "other", "max_work", NULL}; + + mem_overlap_t result; + static PyObject *too_hard_cls = NULL; + Py_ssize_t max_work = NPY_MAY_SHARE_EXACT; + NPY_BEGIN_THREADS_DEF; - if (!PyArg_ParseTuple(args, "O&O&", PyArray_Converter, &self, - PyArray_Converter, &other)) { + if (!PyArg_ParseTupleAndKeywords(args, kwds, "O&O&|O", kwlist, + PyArray_Converter, &self, + PyArray_Converter, &other, + &max_work_obj)) { return NULL; } - overlap = arrays_overlap(self, other); + if (max_work_obj == NULL || max_work_obj == Py_None) { + /* noop */ + } + else if (PyLong_Check(max_work_obj)) { + max_work = PyLong_AsSsize_t(max_work_obj); + } +#if !defined(NPY_PY3K) + else if (PyInt_Check(max_work_obj)) { + max_work = PyInt_AsSsize_t(max_work_obj); + } +#endif + else { + PyErr_SetString(PyExc_ValueError, "max_work must be an integer"); + goto fail; + } + + if (max_work < -2) { + PyErr_SetString(PyExc_ValueError, "Invalid value for max_work"); + goto fail; + } + + NPY_BEGIN_THREADS; + result = solve_may_share_memory(self, other, max_work); + NPY_END_THREADS; + Py_XDECREF(self); Py_XDECREF(other); - if (overlap) { + if (result == MEM_OVERLAP_NO) { + Py_RETURN_FALSE; + } + else if (result == MEM_OVERLAP_YES) { Py_RETURN_TRUE; } + else if (result == MEM_OVERLAP_OVERFLOW) { + PyErr_SetString(PyExc_OverflowError, + "Integer overflow in computing overlap"); + return NULL; + } + else if (result == MEM_OVERLAP_TOO_HARD) { + npy_cache_import("numpy.core._internal", "TooHardError", + &too_hard_cls); + if (too_hard_cls) { + PyErr_SetString(too_hard_cls, "Exceeded max_work"); + } + return NULL; + } else { - Py_RETURN_FALSE; + /* Doesn't happen usually */ + PyErr_SetString(PyExc_RuntimeError, + "Error in computing overlap"); + return NULL; } + +fail: + Py_XDECREF(self); + Py_XDECREF(other); + return NULL; } + static struct PyMethodDef array_module_methods[] = { {"_get_ndarray_c_version", (PyCFunction)array__get_ndarray_c_version, @@ -4123,9 +4182,9 @@ static struct PyMethodDef array_module_methods[] = { {"result_type", (PyCFunction)array_result_type, METH_VARARGS, NULL}, - {"may_share_memory", - (PyCFunction)array_may_share_memory, - METH_VARARGS, NULL}, + {"shares_memory", + (PyCFunction)array_shares_memory, + METH_VARARGS | METH_KEYWORDS, NULL}, /* Datetime-related functions */ {"datetime_data", (PyCFunction)array_datetime_data, @@ -4583,6 +4642,9 @@ PyMODINIT_FUNC initmultiarray(void) { ADDCONST(RAISE); ADDCONST(WRAP); ADDCONST(MAXDIMS); + + ADDCONST(MAY_SHARE_BOUNDS); + ADDCONST(MAY_SHARE_EXACT); #undef ADDCONST Py_INCREF(&PyArray_Type); diff --git a/numpy/core/tests/test_mem_overlap.py b/numpy/core/tests/test_mem_overlap.py index e48d4891d..64938a938 100644 --- a/numpy/core/tests/test_mem_overlap.py +++ b/numpy/core/tests/test_mem_overlap.py @@ -6,8 +6,9 @@ import itertools import numpy as np from numpy.testing import run_module_suite, assert_, assert_raises, assert_equal -from numpy.core.multiarray_tests import solve_diophantine, solve_may_share_memory +from numpy.core.multiarray_tests import solve_diophantine from numpy.lib.stride_tricks import as_strided +from numpy.compat import long if sys.version_info[0] >= 3: xrange = range @@ -172,10 +173,10 @@ def test_diophantine_overflow(): def check_may_share_memory_exact(a, b): - got = solve_may_share_memory(a, b, max_work=MAY_SHARE_EXACT) + got = np.may_share_memory(a, b, max_work=MAY_SHARE_EXACT) assert_equal(np.may_share_memory(a, b), - solve_may_share_memory(a, b, max_work=MAY_SHARE_BOUNDS)) + np.may_share_memory(a, b, max_work=MAY_SHARE_BOUNDS)) a.fill(0) b.fill(0) @@ -215,8 +216,8 @@ def test_may_share_memory_manual(): for x in xs: # The default is a simple extent check - assert_(solve_may_share_memory(x[:,0,:], x[:,1,:])) - assert_(solve_may_share_memory(x[:,0,:], x[:,1,:], max_work=None)) + assert_(np.may_share_memory(x[:,0,:], x[:,1,:])) + assert_(np.may_share_memory(x[:,0,:], x[:,1,:], max_work=None)) # Exact checks check_may_share_memory_exact(x[:,0,:], x[:,1,:]) @@ -286,10 +287,10 @@ def check_may_share_memory_easy_fuzz(get_max_work, same_steps, min_count): a = x[s1].transpose(t1) b = x[s2].transpose(t2) - bounds_overlap = solve_may_share_memory(a, b) + bounds_overlap = np.may_share_memory(a, b) may_share_answer = np.may_share_memory(a, b) - easy_answer = solve_may_share_memory(a, b, max_work=get_max_work(a, b)) - exact_answer = solve_may_share_memory(a, b, max_work=MAY_SHARE_EXACT) + easy_answer = np.may_share_memory(a, b, max_work=get_max_work(a, b)) + exact_answer = np.may_share_memory(a, b, max_work=MAY_SHARE_EXACT) if easy_answer != exact_answer: # assert_equal is slow... @@ -328,5 +329,20 @@ def test_may_share_memory_harder_fuzz(): min_count=2000) + +def test_shares_memory_api(): + x = np.zeros([4, 5, 6], dtype=np.int8) + + assert_equal(np.shares_memory(x, x), True) + assert_equal(np.shares_memory(x, x.copy()), False) + + a = x[:,::2,::3] + b = x[:,::3,::2] + assert_equal(np.shares_memory(a, b), True) + assert_equal(np.shares_memory(a, b, max_work=None), True) + assert_raises(np.TooHardError, np.shares_memory, a, b, max_work=1) + assert_raises(np.TooHardError, np.shares_memory, a, b, max_work=long(1)) + + if __name__ == "__main__": run_module_suite() |