summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2013-10-25 12:34:31 -0700
committerCharles Harris <charlesr.harris@gmail.com>2013-10-25 12:34:31 -0700
commit47b5af987bf31553329334fa08898dac67dbf1ac (patch)
tree0dc8194b41d718024296962ab1502ef9621aab19
parent103abc22f803cd825a12c9ec900df26b34d501df (diff)
parentf8e07275f05e95a4d0af098b06d37925602f7861 (diff)
downloadnumpy-47b5af987bf31553329334fa08898dac67dbf1ac.tar.gz
Merge pull request #3978 from juliantaylor/py3-eigh-bug
BUG: fix broken UPLO of eigh in python3
-rw-r--r--numpy/linalg/linalg.py13
-rw-r--r--numpy/linalg/tests/test_linalg.py49
2 files changed, 57 insertions, 5 deletions
diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py
index aa3bdea34..c5621eace 100644
--- a/numpy/linalg/linalg.py
+++ b/numpy/linalg/linalg.py
@@ -914,7 +914,7 @@ def eigvalsh(a, UPLO='L'):
A complex- or real-valued matrix whose eigenvalues are to be
computed.
UPLO : {'L', 'U'}, optional
- Same as `lower`, wth 'L' for lower and 'U' for upper triangular.
+ Same as `lower`, with 'L' for lower and 'U' for upper triangular.
Deprecated.
Returns
@@ -950,10 +950,13 @@ def eigvalsh(a, UPLO='L'):
array([ 0.17157288+0.j, 5.82842712+0.j])
"""
+ UPLO = asbytes(UPLO.upper())
+ if UPLO not in (b'L', b'U'):
+ raise ValueError("UPLO argument must be 'L' or 'U'")
extobj = get_linalg_error_extobj(
_raise_linalgerror_eigenvalues_nonconvergence)
- if UPLO == 'L':
+ if UPLO == _L:
gufunc = _umath_linalg.eigvalsh_lo
else:
gufunc = _umath_linalg.eigvalsh_up
@@ -1194,7 +1197,9 @@ def eigh(a, UPLO='L'):
[ 0.00000000+0.38268343j, 0.00000000-0.92387953j]])
"""
- UPLO = asbytes(UPLO)
+ UPLO = asbytes(UPLO.upper())
+ if UPLO not in (b'L', b'U'):
+ raise ValueError("UPLO argument must be 'L' or 'U'")
a, wrap = _makearray(a)
_assertRankAtLeast2(a)
@@ -1203,7 +1208,7 @@ def eigh(a, UPLO='L'):
extobj = get_linalg_error_extobj(
_raise_linalgerror_eigenvalues_nonconvergence)
- if 'L' == UPLO:
+ if _L == UPLO:
gufunc = _umath_linalg.eigh_lo
else:
gufunc = _umath_linalg.eigh_up
diff --git a/numpy/linalg/tests/test_linalg.py b/numpy/linalg/tests/test_linalg.py
index cc1404bf1..803b4c88f 100644
--- a/numpy/linalg/tests/test_linalg.py
+++ b/numpy/linalg/tests/test_linalg.py
@@ -728,7 +728,7 @@ class TestEigh(HermitianTestCase, HermitianGeneralizedTestCase):
assert_allclose(dot_generalized(a, evc2),
np.asarray(ev2)[...,None,:] * np.asarray(evc2),
- rtol=get_rtol(ev.dtype))
+ rtol=get_rtol(ev.dtype), err_msg=repr(a))
def test_types(self):
def check(dtype):
@@ -739,6 +739,53 @@ class TestEigh(HermitianTestCase, HermitianGeneralizedTestCase):
for dtype in [single, double, csingle, cdouble]:
yield check, dtype
+ def test_invalid(self):
+ x = np.array([[1, 0.5], [0.5, 1]], dtype=np.float32)
+ assert_raises(ValueError, np.linalg.eigh, x, UPLO="lrong")
+ assert_raises(ValueError, np.linalg.eigh, x, "lower")
+ assert_raises(ValueError, np.linalg.eigh, x, "upper")
+ assert_raises(ValueError, np.linalg.eigvalsh, x, UPLO="lrong")
+ assert_raises(ValueError, np.linalg.eigvalsh, x, "lower")
+ assert_raises(ValueError, np.linalg.eigvalsh, x, "upper")
+
+ def test_half_filled(self):
+ expect = np.array([-0.33333333, -0.33333333, -0.33333333, 0.99999999])
+ K = np.array([[ 0. , 0. , 0. , 0. ],
+ [-0.33333333, 0. , 0. , 0. ],
+ [ 0.33333333, -0.33333333, 0. , 0. ],
+ [ 0.33333333, -0.33333333, 0.33333333, 0. ]])
+ Kr = np.rot90(K, k=2)
+
+ w, V = np.linalg.eigh(K)
+ assert_allclose(np.sort(w), expect, rtol=get_rtol(K.dtype))
+
+ w, V = np.linalg.eigh(UPLO='L', a=K)
+ assert_allclose(np.sort(w), expect, rtol=get_rtol(K.dtype))
+ w, V = np.linalg.eigh(K, 'l')
+ w2, V2 = np.linalg.eigh(K, 'L')
+ assert_allclose(w, w2, rtol=get_rtol(K.dtype))
+ assert_allclose(V, V2, rtol=get_rtol(K.dtype))
+
+ w, V = np.linalg.eigh(Kr, 'U')
+ assert_allclose(np.sort(w), expect, rtol=get_rtol(K.dtype))
+ w, V = np.linalg.eigh(Kr, 'u')
+ w2, V2 = np.linalg.eigh(Kr, 'u')
+ assert_allclose(w, w2, rtol=get_rtol(K.dtype))
+ assert_allclose(V, V2, rtol=get_rtol(K.dtype))
+
+ w = np.linalg.eigvalsh(K)
+ assert_allclose(np.sort(w), expect, rtol=get_rtol(K.dtype))
+
+ w = np.linalg.eigvalsh(UPLO='L', a=K)
+ assert_allclose(np.sort(w), expect, rtol=get_rtol(K.dtype))
+ assert_allclose(np.linalg.eigvalsh(K, 'L'),
+ np.linalg.eigvalsh(K, 'l'), rtol=get_rtol(K.dtype))
+
+ w = np.linalg.eigvalsh(Kr, 'U')
+ assert_allclose(np.sort(w), expect, rtol=get_rtol(K.dtype))
+ assert_allclose(np.linalg.eigvalsh(Kr, 'U'),
+ np.linalg.eigvalsh(Kr, 'u'), rtol=get_rtol(K.dtype))
+
class _TestNorm(object):