diff options
author | jaimefrio <jaime.frio@gmail.com> | 2014-07-10 09:03:16 -0700 |
---|---|---|
committer | jaimefrio <jaime.frio@gmail.com> | 2014-07-10 11:44:50 -0700 |
commit | 1392888ec4ddb37d4fa7bbb9231712af0dda4ea6 (patch) | |
tree | fa427ad63edb8ee4a506b0dbb22508741116dfa3 /numpy/lib/twodim_base.py | |
parent | aba6adbc9d4ffa6b9b4fcc41afcb4884944ce181 (diff) | |
download | numpy-1392888ec4ddb37d4fa7bbb9231712af0dda4ea6.tar.gz |
BUG: Use `np.where` in np.triu/np.tril, fixes #4859
Replaces the current method to zero items, from multiplication to
using `np.where`.
Diffstat (limited to 'numpy/lib/twodim_base.py')
-rw-r--r-- | numpy/lib/twodim_base.py | 9 |
1 files changed, 6 insertions, 3 deletions
diff --git a/numpy/lib/twodim_base.py b/numpy/lib/twodim_base.py index a8925592a..f26ff0619 100644 --- a/numpy/lib/twodim_base.py +++ b/numpy/lib/twodim_base.py @@ -387,7 +387,6 @@ def tri(N, M=None, k=0, dtype=float): dtype : dtype, optional Data type of the returned array. The default is float. - Returns ------- tri : ndarray of shape (N, M) @@ -452,7 +451,9 @@ def tril(m, k=0): """ m = asanyarray(m) - return multiply(tri(*m.shape[-2:], k=k, dtype=bool), m, dtype=m.dtype) + mask = tri(*m.shape[-2:], k=k, dtype=bool) + + return where(mask, m, 0) def triu(m, k=0): @@ -478,7 +479,9 @@ def triu(m, k=0): """ m = asanyarray(m) - return multiply(~tri(*m.shape[-2:], k=k-1, dtype=bool), m, dtype=m.dtype) + mask = tri(*m.shape[-2:], k=k-1, dtype=bool) + + return where(mask, 0, m) # Originally borrowed from John Hunter and matplotlib |