summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPauli Virtanen <pav@iki.fi>2010-10-10 23:50:53 +0200
committerPauli Virtanen <pav@iki.fi>2010-10-11 00:01:25 +0200
commitd7ff9074fcde66225478d6721cf22b2db32dc2fd (patch)
tree67fe8488accac8601159eb8d9bd197ce8a85195e
parent68e31fe815e0cb6276970a1c365f21e187d10ca0 (diff)
downloadnumpy-d7ff9074fcde66225478d6721cf22b2db32dc2fd.tar.gz
BUG: lib: clean up ancient-Python era stuff from IndexExpression (#1196)
-rw-r--r--numpy/lib/index_tricks.py11
-rw-r--r--numpy/lib/tests/test_index_tricks.py15
2 files changed, 15 insertions, 11 deletions
diff --git a/numpy/lib/index_tricks.py b/numpy/lib/index_tricks.py
index eb1ab22e9..264ebaad0 100644
--- a/numpy/lib/index_tricks.py
+++ b/numpy/lib/index_tricks.py
@@ -700,24 +700,15 @@ class IndexExpression(object):
array([2, 4])
"""
- maxint = sys.maxint
def __init__(self, maketuple):
self.maketuple = maketuple
def __getitem__(self, item):
- if self.maketuple and type(item) != type(()):
+ if self.maketuple and type(item) != tuple:
return (item,)
else:
return item
- def __len__(self):
- return self.maxint
-
- def __getslice__(self, start, stop):
- if stop == self.maxint:
- stop = None
- return self[start:stop:None]
-
index_exp = IndexExpression(maketuple=True)
s_ = IndexExpression(maketuple=False)
diff --git a/numpy/lib/tests/test_index_tricks.py b/numpy/lib/tests/test_index_tricks.py
index 3307cef3e..c17ee5d6a 100644
--- a/numpy/lib/tests/test_index_tricks.py
+++ b/numpy/lib/tests/test_index_tricks.py
@@ -2,7 +2,7 @@ from numpy.testing import *
import numpy as np
from numpy import ( array, ones, r_, mgrid, unravel_index, zeros, where,
ndenumerate, fill_diagonal, diag_indices,
- diag_indices_from )
+ diag_indices_from, s_, index_exp )
class TestUnravelIndex(TestCase):
def test_basic(self):
@@ -77,6 +77,19 @@ class TestNdenumerate(TestCase):
[((0,0), 1), ((0,1), 2), ((1,0), 3), ((1,1), 4)])
+class TestIndexExpression(TestCase):
+ def test_regression_1(self):
+ # ticket #1196
+ a = np.arange(2)
+ assert_equal(a[:-1], a[s_[:-1]])
+ assert_equal(a[:-1], a[index_exp[:-1]])
+
+ def test_simple_1(self):
+ a = np.random.rand(4,5,6)
+
+ assert_equal(a[:,:3,[1,2]], a[index_exp[:,:3,[1,2]]])
+ assert_equal(a[:,:3,[1,2]], a[s_[:,:3,[1,2]]])
+
def test_fill_diagonal():
a = zeros((3, 3),int)
fill_diagonal(a, 5)