summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--doc/source/reference/routines.other.rst9
-rw-r--r--numpy/add_newdocs.py31
-rw-r--r--numpy/core/_internal.py4
-rw-r--r--numpy/core/function_base.py47
-rw-r--r--numpy/core/numeric.py9
-rw-r--r--numpy/core/src/multiarray/multiarray_tests.c.src72
-rw-r--r--numpy/core/src/multiarray/multiarraymodule.c82
-rw-r--r--numpy/core/tests/test_mem_overlap.py32
8 files changed, 184 insertions, 102 deletions
diff --git a/doc/source/reference/routines.other.rst b/doc/source/reference/routines.other.rst
index 354f45733..a3a1f8a06 100644
--- a/doc/source/reference/routines.other.rst
+++ b/doc/source/reference/routines.other.rst
@@ -22,3 +22,12 @@ Performance tuning
restoredot
setbufsize
getbufsize
+
+Memory ranges
+-------------
+
+.. autosummary::
+ :toctree: generated/
+
+ shares_memory
+ may_share_memory
diff --git a/numpy/add_newdocs.py b/numpy/add_newdocs.py
index 195789a39..0af053f89 100644
--- a/numpy/add_newdocs.py
+++ b/numpy/add_newdocs.py
@@ -3784,24 +3784,41 @@ add_newdoc('numpy.core.multiarray', 'ndarray', ('min',
"""))
-add_newdoc('numpy.core.multiarray', 'may_share_memory',
+add_newdoc('numpy.core.multiarray', 'shares_memory',
"""
- Determine if two arrays can share memory
+ shares_memory(a, b, max_work=None)
- The memory-bounds of a and b are computed. If they overlap then
- this function returns True. Otherwise, it returns False.
-
- A return of True does not necessarily mean that the two arrays
- share any element. It just means that they *might*.
+ Determine if two arrays share memory
Parameters
----------
a, b : ndarray
+ Input arrays
+ max_work : int, optional
+ Effort to spend on solving the overlap problem (maximum number
+ of candidate solutions to consider). Note max_work=1 handles
+ most usual cases. In addition, the following special values
+ are recognized:
+
+ max_work=MAY_SHARE_EXACT (default)
+ The problem is solved exactly. In this case, the function returns
+ True only if there is an element shared between the arrays.
+ max_work=MAY_SHARE_BOUNDS
+ Only the memory bounds of a and b are checked.
+
+ Raises
+ ------
+ numpy.TooHardError
+ Exceeded max_work.
Returns
-------
out : bool
+ See Also
+ --------
+ may_share_memory
+
Examples
--------
>>> np.may_share_memory(np.array([1,2]), np.array([5,8,9]))
diff --git a/numpy/core/_internal.py b/numpy/core/_internal.py
index bf492d105..879f4a224 100644
--- a/numpy/core/_internal.py
+++ b/numpy/core/_internal.py
@@ -759,3 +759,7 @@ def _gcd(a, b):
while b:
a, b = b, a % b
return a
+
+# Exception used in shares_memory()
+class TooHardError(RuntimeError):
+ pass
diff --git a/numpy/core/function_base.py b/numpy/core/function_base.py
index 532ef2950..05fea557a 100644
--- a/numpy/core/function_base.py
+++ b/numpy/core/function_base.py
@@ -1,9 +1,9 @@
from __future__ import division, absolute_import, print_function
-__all__ = ['logspace', 'linspace']
+__all__ = ['logspace', 'linspace', 'may_share_memory']
from . import numeric as _nx
-from .numeric import result_type, NaN
+from .numeric import result_type, NaN, shares_memory, MAY_SHARE_BOUNDS, TooHardError
def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None):
@@ -201,3 +201,46 @@ def logspace(start, stop, num=50, endpoint=True, base=10.0, dtype=None):
if dtype is None:
return _nx.power(base, y)
return _nx.power(base, y).astype(dtype)
+
+
+def may_share_memory(a, b, max_work=None):
+ """Determine if two arrays can share memory
+
+ A return of True does not necessarily mean that the two arrays
+ share any element. It just means that they *might*.
+
+ Only the memory bounds of a and b are checked by default.
+
+ Parameters
+ ----------
+ a, b : ndarray
+ Input arrays
+ max_work : int, optional
+ Effort to spend on solving the overlap problem. See
+ `shares_memory` for details. Default for ``may_share_memory``
+ is to do a bounds check.
+
+ Returns
+ -------
+ out : bool
+
+ See Also
+ --------
+ shares_memory
+
+ Examples
+ --------
+ >>> np.may_share_memory(np.array([1,2]), np.array([5,8,9]))
+ False
+ >>> x = np.zeros([3, 4])
+ >>> np.may_share_memory(x[:,0], x[:,1])
+ True
+
+ """
+ if max_work is None:
+ max_work = MAY_SHARE_BOUNDS
+ try:
+ return shares_memory(a, b, max_work=max_work)
+ except (TooHardError, OverflowError):
+ # Unable to determine, assume yes
+ return True
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py
index f0163876f..1b7dfca3e 100644
--- a/numpy/core/numeric.py
+++ b/numpy/core/numeric.py
@@ -10,6 +10,7 @@ from .umath import (invert, sin, UFUNC_BUFSIZE_DEFAULT, ERR_IGNORE,
ERR_DEFAULT, PINF, NAN)
from . import numerictypes
from .numerictypes import longlong, intc, int_, float_, complex_, bool_
+from ._internal import TooHardError
if sys.version_info[0] >= 3:
import pickle
@@ -39,8 +40,8 @@ __all__ = [
'getbufsize', 'seterrcall', 'geterrcall', 'errstate', 'flatnonzero',
'Inf', 'inf', 'infty', 'Infinity', 'nan', 'NaN', 'False_', 'True_',
'bitwise_not', 'CLIP', 'RAISE', 'WRAP', 'MAXDIMS', 'BUFSIZE',
- 'ALLOW_THREADS', 'ComplexWarning', 'may_share_memory', 'full',
- 'full_like', 'matmul',
+ 'ALLOW_THREADS', 'ComplexWarning', 'full', 'full_like', 'matmul',
+ 'shares_memory', 'MAY_SHARE_BOUNDS', 'MAY_SHARE_EXACT', 'TooHardError',
]
if sys.version_info[0] < 3:
@@ -65,6 +66,8 @@ RAISE = multiarray.RAISE
MAXDIMS = multiarray.MAXDIMS
ALLOW_THREADS = multiarray.ALLOW_THREADS
BUFSIZE = multiarray.BUFSIZE
+MAY_SHARE_BOUNDS = multiarray.MAY_SHARE_BOUNDS
+MAY_SHARE_EXACT = multiarray.MAY_SHARE_EXACT
ndarray = multiarray.ndarray
flatiter = multiarray.flatiter
@@ -375,7 +378,7 @@ fromstring = multiarray.fromstring
fromiter = multiarray.fromiter
fromfile = multiarray.fromfile
frombuffer = multiarray.frombuffer
-may_share_memory = multiarray.may_share_memory
+shares_memory = multiarray.shares_memory
if sys.version_info[0] < 3:
newbuffer = multiarray.newbuffer
getbuffer = multiarray.getbuffer
diff --git a/numpy/core/src/multiarray/multiarray_tests.c.src b/numpy/core/src/multiarray/multiarray_tests.c.src
index 706fa022f..4e59f57f7 100644
--- a/numpy/core/src/multiarray/multiarray_tests.c.src
+++ b/numpy/core/src/multiarray/multiarray_tests.c.src
@@ -1054,75 +1054,6 @@ fail:
return NULL;
}
-static PyObject *
-array_solve_may_share_memory(PyObject *NPY_UNUSED(ignored), PyObject *args, PyObject *kwds)
-{
- PyArrayObject * self = NULL;
- PyArrayObject * other = NULL;
- PyObject *max_work_obj = NULL;
- static char *kwlist[] = {"self", "other", "max_work", NULL};
-
- mem_overlap_t result;
- Py_ssize_t max_work;
- NPY_BEGIN_THREADS_DEF;
-
- if (!PyArg_ParseTupleAndKeywords(args, kwds, "O&O&|O", kwlist,
- PyArray_Converter, &self,
- PyArray_Converter, &other,
- &max_work_obj)) {
- return NULL;
- }
-
- if (max_work_obj == NULL || max_work_obj == Py_None) {
- max_work = NPY_MAY_SHARE_BOUNDS;
- }
-#if defined(NPY_PY3K)
- else if (PyLong_Check(max_work_obj)) {
- max_work = PyLong_AsSsize_t(max_work_obj);
- }
-#else
- else if (PyInt_Check(max_work_obj)) {
- max_work = PyInt_AsSsize_t(max_work_obj);
- }
-#endif
- else {
- PyErr_SetString(PyExc_ValueError, "max_work must be an integer");
- goto fail;
- }
-
- if (max_work < -1) {
- PyErr_SetString(PyExc_ValueError, "Invalid value for max_work");
- goto fail;
- }
-
- NPY_BEGIN_THREADS;
- result = solve_may_share_memory(self, other, max_work);
- NPY_END_THREADS;
-
- Py_XDECREF(self);
- Py_XDECREF(other);
-
- if (result == MEM_OVERLAP_NO) {
- Py_RETURN_FALSE;
- }
- else if (result == MEM_OVERLAP_YES ||
- result == MEM_OVERLAP_OVERFLOW ||
- result == MEM_OVERLAP_TOO_HARD) {
- Py_RETURN_TRUE;
- }
- else {
- /* Doesn't happen usually */
- PyErr_SetString(PyExc_RuntimeError,
- "Internal error in may_share_memory");
- return NULL;
- }
-
-fail:
- Py_XDECREF(self);
- Py_XDECREF(other);
- return NULL;
-}
-
static PyObject *
pylong_from_int128(npy_extint128_t value)
@@ -1606,9 +1537,6 @@ static PyMethodDef Multiarray_TestsMethods[] = {
{"solve_diophantine",
(PyCFunction)array_solve_diophantine,
METH_VARARGS | METH_KEYWORDS, NULL},
- {"solve_may_share_memory",
- (PyCFunction)array_solve_may_share_memory,
- METH_VARARGS | METH_KEYWORDS, NULL},
{"extint_safe_binop",
extint_safe_binop,
METH_VARARGS, NULL},
diff --git a/numpy/core/src/multiarray/multiarraymodule.c b/numpy/core/src/multiarray/multiarraymodule.c
index 04513c56c..e72c355dc 100644
--- a/numpy/core/src/multiarray/multiarraymodule.c
+++ b/numpy/core/src/multiarray/multiarraymodule.c
@@ -60,6 +60,7 @@ NPY_NO_EXPORT int NPY_NUMUSERTYPES = 0;
#include "vdot.h"
#include "templ_common.h" /* for npy_mul_with_overflow_intp */
#include "compiled_base.h"
+#include "mem_overlap.h"
/* Only here for API compatibility */
NPY_NO_EXPORT PyTypeObject PyBigArray_Type;
@@ -3993,30 +3994,88 @@ test_interrupt(PyObject *NPY_UNUSED(self), PyObject *args)
return PyInt_FromLong(a);
}
+
static PyObject *
-array_may_share_memory(PyObject *NPY_UNUSED(ignored), PyObject *args)
+array_shares_memory(PyObject *NPY_UNUSED(ignored), PyObject *args, PyObject *kwds)
{
PyArrayObject * self = NULL;
PyArrayObject * other = NULL;
- int overlap;
+ PyObject *max_work_obj = NULL;
+ static char *kwlist[] = {"self", "other", "max_work", NULL};
+
+ mem_overlap_t result;
+ static PyObject *too_hard_cls = NULL;
+ Py_ssize_t max_work = NPY_MAY_SHARE_EXACT;
+ NPY_BEGIN_THREADS_DEF;
- if (!PyArg_ParseTuple(args, "O&O&", PyArray_Converter, &self,
- PyArray_Converter, &other)) {
+ if (!PyArg_ParseTupleAndKeywords(args, kwds, "O&O&|O", kwlist,
+ PyArray_Converter, &self,
+ PyArray_Converter, &other,
+ &max_work_obj)) {
return NULL;
}
- overlap = arrays_overlap(self, other);
+ if (max_work_obj == NULL || max_work_obj == Py_None) {
+ /* noop */
+ }
+ else if (PyLong_Check(max_work_obj)) {
+ max_work = PyLong_AsSsize_t(max_work_obj);
+ }
+#if !defined(NPY_PY3K)
+ else if (PyInt_Check(max_work_obj)) {
+ max_work = PyInt_AsSsize_t(max_work_obj);
+ }
+#endif
+ else {
+ PyErr_SetString(PyExc_ValueError, "max_work must be an integer");
+ goto fail;
+ }
+
+ if (max_work < -2) {
+ PyErr_SetString(PyExc_ValueError, "Invalid value for max_work");
+ goto fail;
+ }
+
+ NPY_BEGIN_THREADS;
+ result = solve_may_share_memory(self, other, max_work);
+ NPY_END_THREADS;
+
Py_XDECREF(self);
Py_XDECREF(other);
- if (overlap) {
+ if (result == MEM_OVERLAP_NO) {
+ Py_RETURN_FALSE;
+ }
+ else if (result == MEM_OVERLAP_YES) {
Py_RETURN_TRUE;
}
+ else if (result == MEM_OVERLAP_OVERFLOW) {
+ PyErr_SetString(PyExc_OverflowError,
+ "Integer overflow in computing overlap");
+ return NULL;
+ }
+ else if (result == MEM_OVERLAP_TOO_HARD) {
+ npy_cache_import("numpy.core._internal", "TooHardError",
+ &too_hard_cls);
+ if (too_hard_cls) {
+ PyErr_SetString(too_hard_cls, "Exceeded max_work");
+ }
+ return NULL;
+ }
else {
- Py_RETURN_FALSE;
+ /* Doesn't happen usually */
+ PyErr_SetString(PyExc_RuntimeError,
+ "Error in computing overlap");
+ return NULL;
}
+
+fail:
+ Py_XDECREF(self);
+ Py_XDECREF(other);
+ return NULL;
}
+
static struct PyMethodDef array_module_methods[] = {
{"_get_ndarray_c_version",
(PyCFunction)array__get_ndarray_c_version,
@@ -4123,9 +4182,9 @@ static struct PyMethodDef array_module_methods[] = {
{"result_type",
(PyCFunction)array_result_type,
METH_VARARGS, NULL},
- {"may_share_memory",
- (PyCFunction)array_may_share_memory,
- METH_VARARGS, NULL},
+ {"shares_memory",
+ (PyCFunction)array_shares_memory,
+ METH_VARARGS | METH_KEYWORDS, NULL},
/* Datetime-related functions */
{"datetime_data",
(PyCFunction)array_datetime_data,
@@ -4583,6 +4642,9 @@ PyMODINIT_FUNC initmultiarray(void) {
ADDCONST(RAISE);
ADDCONST(WRAP);
ADDCONST(MAXDIMS);
+
+ ADDCONST(MAY_SHARE_BOUNDS);
+ ADDCONST(MAY_SHARE_EXACT);
#undef ADDCONST
Py_INCREF(&PyArray_Type);
diff --git a/numpy/core/tests/test_mem_overlap.py b/numpy/core/tests/test_mem_overlap.py
index e48d4891d..64938a938 100644
--- a/numpy/core/tests/test_mem_overlap.py
+++ b/numpy/core/tests/test_mem_overlap.py
@@ -6,8 +6,9 @@ import itertools
import numpy as np
from numpy.testing import run_module_suite, assert_, assert_raises, assert_equal
-from numpy.core.multiarray_tests import solve_diophantine, solve_may_share_memory
+from numpy.core.multiarray_tests import solve_diophantine
from numpy.lib.stride_tricks import as_strided
+from numpy.compat import long
if sys.version_info[0] >= 3:
xrange = range
@@ -172,10 +173,10 @@ def test_diophantine_overflow():
def check_may_share_memory_exact(a, b):
- got = solve_may_share_memory(a, b, max_work=MAY_SHARE_EXACT)
+ got = np.may_share_memory(a, b, max_work=MAY_SHARE_EXACT)
assert_equal(np.may_share_memory(a, b),
- solve_may_share_memory(a, b, max_work=MAY_SHARE_BOUNDS))
+ np.may_share_memory(a, b, max_work=MAY_SHARE_BOUNDS))
a.fill(0)
b.fill(0)
@@ -215,8 +216,8 @@ def test_may_share_memory_manual():
for x in xs:
# The default is a simple extent check
- assert_(solve_may_share_memory(x[:,0,:], x[:,1,:]))
- assert_(solve_may_share_memory(x[:,0,:], x[:,1,:], max_work=None))
+ assert_(np.may_share_memory(x[:,0,:], x[:,1,:]))
+ assert_(np.may_share_memory(x[:,0,:], x[:,1,:], max_work=None))
# Exact checks
check_may_share_memory_exact(x[:,0,:], x[:,1,:])
@@ -286,10 +287,10 @@ def check_may_share_memory_easy_fuzz(get_max_work, same_steps, min_count):
a = x[s1].transpose(t1)
b = x[s2].transpose(t2)
- bounds_overlap = solve_may_share_memory(a, b)
+ bounds_overlap = np.may_share_memory(a, b)
may_share_answer = np.may_share_memory(a, b)
- easy_answer = solve_may_share_memory(a, b, max_work=get_max_work(a, b))
- exact_answer = solve_may_share_memory(a, b, max_work=MAY_SHARE_EXACT)
+ easy_answer = np.may_share_memory(a, b, max_work=get_max_work(a, b))
+ exact_answer = np.may_share_memory(a, b, max_work=MAY_SHARE_EXACT)
if easy_answer != exact_answer:
# assert_equal is slow...
@@ -328,5 +329,20 @@ def test_may_share_memory_harder_fuzz():
min_count=2000)
+
+def test_shares_memory_api():
+ x = np.zeros([4, 5, 6], dtype=np.int8)
+
+ assert_equal(np.shares_memory(x, x), True)
+ assert_equal(np.shares_memory(x, x.copy()), False)
+
+ a = x[:,::2,::3]
+ b = x[:,::3,::2]
+ assert_equal(np.shares_memory(a, b), True)
+ assert_equal(np.shares_memory(a, b, max_work=None), True)
+ assert_raises(np.TooHardError, np.shares_memory, a, b, max_work=1)
+ assert_raises(np.TooHardError, np.shares_memory, a, b, max_work=long(1))
+
+
if __name__ == "__main__":
run_module_suite()