diff options
-rw-r--r-- | numpy/linalg/lapack_litemodule.c | 20 | ||||
-rw-r--r-- | numpy/linalg/tests/test_linalg.py | 20 |
2 files changed, 39 insertions, 1 deletions
diff --git a/numpy/linalg/lapack_litemodule.c b/numpy/linalg/lapack_litemodule.c index 53e9f7a90..2f142f105 100644 --- a/numpy/linalg/lapack_litemodule.c +++ b/numpy/linalg/lapack_litemodule.c @@ -93,6 +93,8 @@ extern int FNAME(zungqr)(int *m, int *n, int *k, f2c_doublecomplex a[], int *lda, f2c_doublecomplex tau[], f2c_doublecomplex work[], int *lwork, int *info); +extern int FNAME(xerbla)(char *srname, int *info); + static PyObject *LapackError; #define TRY(E) if (!(E)) return NULL @@ -857,6 +859,23 @@ lapack_lite_zungqr(PyObject *NPY_UNUSED(self), PyObject *args) } +static PyObject * +lapack_lite_xerbla(PyObject *NPY_UNUSED(self), PyObject *args) +{ + int info = -1; + + NPY_BEGIN_THREADS_DEF; + NPY_BEGIN_THREADS; + FNAME(xerbla)("test", &info); + NPY_END_THREADS; + + if (PyErr_Occurred()) { + return NULL; + } + Py_INCREF(Py_None); + return Py_None; +} + #define STR(x) #x #define lameth(name) {STR(name), lapack_lite_##name, METH_VARARGS, NULL} @@ -879,6 +898,7 @@ static struct PyMethodDef lapack_lite_module_methods[] = { lameth(zpotrf), lameth(zgeqrf), lameth(zungqr), + lameth(xerbla), { NULL,NULL,0, NULL} }; diff --git a/numpy/linalg/tests/test_linalg.py b/numpy/linalg/tests/test_linalg.py index 84c95af10..9e12d7e87 100644 --- a/numpy/linalg/tests/test_linalg.py +++ b/numpy/linalg/tests/test_linalg.py @@ -7,7 +7,7 @@ import sys import numpy as np from numpy.testing import (TestCase, assert_, assert_equal, assert_raises, assert_array_equal, assert_almost_equal, - run_module_suite) + run_module_suite, dec) from numpy import array, single, double, csingle, cdouble, dot, identity from numpy import multiply, atleast_2d, inf, asarray, matrix from numpy import linalg @@ -750,5 +750,23 @@ def test_generalized_raise_multiloop(): assert_raises(np.linalg.LinAlgError, np.linalg.inv, x) + +@dec.skipif(sys.platform == "win32", "python_xerbla not enabled on Win32") +def test_xerbla(): + # Test that xerbla works (with GIL) + a = np.array([[1]]) + try: + np.linalg.lapack_lite.dgetrf( + 1, 1, a.astype(np.double), + 0, # <- invalid value + a.astype(np.intc), 0) + except ValueError as e: + assert_("DGETRF parameter number 4" in str(e)) + else: + assert_(False) + + # Test that xerbla works (without GIL) + assert_raises(ValueError, np.linalg.lapack_lite.xerbla) + if __name__ == "__main__": run_module_suite() |