summaryrefslogtreecommitdiff
path: root/numpy/lib/twodim_base.py
diff options
context:
space:
mode:
authorjaimefrio <jaime.frio@gmail.com>2014-07-10 09:03:16 -0700
committerjaimefrio <jaime.frio@gmail.com>2014-07-10 11:44:50 -0700
commit1392888ec4ddb37d4fa7bbb9231712af0dda4ea6 (patch)
treefa427ad63edb8ee4a506b0dbb22508741116dfa3 /numpy/lib/twodim_base.py
parentaba6adbc9d4ffa6b9b4fcc41afcb4884944ce181 (diff)
downloadnumpy-1392888ec4ddb37d4fa7bbb9231712af0dda4ea6.tar.gz
BUG: Use `np.where` in np.triu/np.tril, fixes #4859
Replaces the current method to zero items, from multiplication to using `np.where`.
Diffstat (limited to 'numpy/lib/twodim_base.py')
-rw-r--r--numpy/lib/twodim_base.py9
1 files changed, 6 insertions, 3 deletions
diff --git a/numpy/lib/twodim_base.py b/numpy/lib/twodim_base.py
index a8925592a..f26ff0619 100644
--- a/numpy/lib/twodim_base.py
+++ b/numpy/lib/twodim_base.py
@@ -387,7 +387,6 @@ def tri(N, M=None, k=0, dtype=float):
dtype : dtype, optional
Data type of the returned array. The default is float.
-
Returns
-------
tri : ndarray of shape (N, M)
@@ -452,7 +451,9 @@ def tril(m, k=0):
"""
m = asanyarray(m)
- return multiply(tri(*m.shape[-2:], k=k, dtype=bool), m, dtype=m.dtype)
+ mask = tri(*m.shape[-2:], k=k, dtype=bool)
+
+ return where(mask, m, 0)
def triu(m, k=0):
@@ -478,7 +479,9 @@ def triu(m, k=0):
"""
m = asanyarray(m)
- return multiply(~tri(*m.shape[-2:], k=k-1, dtype=bool), m, dtype=m.dtype)
+ mask = tri(*m.shape[-2:], k=k-1, dtype=bool)
+
+ return where(mask, 0, m)
# Originally borrowed from John Hunter and matplotlib