summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2008-04-27 18:44:47 +0000
committerCharles Harris <charlesr.harris@gmail.com>2008-04-27 18:44:47 +0000
commit06c0d0e97c7781cc81be38c1d3124890822b303f (patch)
treec0e35437cb0b6b29070cc01b08bf9be8b7f136ec /numpy
parent8d915a55c5ecbca15ebaf13584b0c255d22768a1 (diff)
downloadnumpy-06c0d0e97c7781cc81be38c1d3124890822b303f.tar.gz
Fix test of lstsqr to work with matrix tests.
Fix lstsq
Diffstat (limited to 'numpy')
-rw-r--r--numpy/linalg/linalg.py8
-rw-r--r--numpy/linalg/tests/test_linalg.py2
2 files changed, 5 insertions, 5 deletions
diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py
index a557af875..ea4585dfa 100644
--- a/numpy/linalg/linalg.py
+++ b/numpy/linalg/linalg.py
@@ -1147,10 +1147,10 @@ def lstsq(a, b, rcond=-1):
"""
import math
- a = _makearray(a)
+ a, _ = _makearray(a)
b, wrap = _makearray(b)
- one_eq = len(b.shape) == 1
- if one_eq:
+ is_1d = len(b.shape) == 1
+ if is_1d:
b = b[:, newaxis]
_assertRank2(a, b)
m = a.shape[0]
@@ -1199,7 +1199,7 @@ def lstsq(a, b, rcond=-1):
if results['info'] > 0:
raise LinAlgError, 'SVD did not converge in Linear Least Squares'
resids = array([], t)
- if one_eq:
+ if is_1d:
x = array(ravel(bstar)[:n], dtype=result_t, copy=True)
if results['rank'] == n and m > n:
resids = array([sum((ravel(bstar)[n:])**2)], dtype=result_t)
diff --git a/numpy/linalg/tests/test_linalg.py b/numpy/linalg/tests/test_linalg.py
index 9a6e8bfc9..b6274bc35 100644
--- a/numpy/linalg/tests/test_linalg.py
+++ b/numpy/linalg/tests/test_linalg.py
@@ -144,7 +144,7 @@ class TestLstsq(LinalgTestCase):
x, residuals, rank, sv = linalg.lstsq(a, b)
assert_almost_equal(b, dot(a, x))
assert_equal(rank, asarray(a).shape[0])
- assert_almost_equal(sv, s)
+ assert_almost_equal(sv, sv.__array_wrap__(s))
assert imply(isinstance(b, matrix), isinstance(x, matrix))
assert imply(isinstance(b, matrix), isinstance(residuals, matrix))
assert imply(isinstance(b, matrix), isinstance(sv, matrix))