diff options
author | Eric Wieser <wieser.eric@gmail.com> | 2017-03-22 22:22:28 +0000 |
---|---|---|
committer | Eric Wieser <wieser.eric@gmail.com> | 2017-05-05 21:42:17 +0100 |
commit | 37d756c46424b2da04a87e7df45a6c64f9b50117 (patch) | |
tree | 9520af9aabfba8e759dbd35e681e6be0e302e899 /numpy | |
parent | e6b8e75547af0cc4d38af458eff5e5d6c14102b8 (diff) | |
download | numpy-37d756c46424b2da04a87e7df45a6c64f9b50117.tar.gz |
BUG: Remove mutable state from AxisConcatenator
Fixes #8815
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/lib/index_tricks.py | 55 | ||||
-rw-r--r-- | numpy/lib/tests/test_index_tricks.py | 11 |
2 files changed, 37 insertions, 29 deletions
diff --git a/numpy/lib/index_tricks.py b/numpy/lib/index_tricks.py index 58d3e0dcf..719698bc8 100644 --- a/numpy/lib/index_tricks.py +++ b/numpy/lib/index_tricks.py @@ -10,12 +10,12 @@ from numpy.core.numeric import ( from numpy.core.numerictypes import find_common_type, issubdtype from . import function_base -import numpy.matrixlib as matrix +import numpy.matrixlib as matrixlib from .function_base import diff from numpy.core.multiarray import ravel_multi_index, unravel_index from numpy.lib.stride_tricks import as_strided -makemat = matrix.matrix +makemat = matrixlib.matrix __all__ = [ @@ -235,44 +235,36 @@ class AxisConcatenator(object): Translates slice objects to concatenation along an axis. For detailed documentation on usage, see `r_`. - """ # allow ma.mr_ to override this concatenate = staticmethod(_nx.concatenate) - - def _retval(self, res): - if self.matrix: - oldndim = res.ndim - res = makemat(res) - if oldndim == 1 and self.col: - res = res.T - self.axis = self._axis - self.matrix = self._matrix - self.col = 0 - return res - def __init__(self, axis=0, matrix=False, ndmin=1, trans1d=-1): - self._axis = axis - self._matrix = matrix self.axis = axis self.matrix = matrix - self.col = 0 self.trans1d = trans1d self.ndmin = ndmin def __getitem__(self, key): - trans1d = self.trans1d - ndmin = self.ndmin + # handle matrix builder syntax if isinstance(key, str): frame = sys._getframe().f_back - mymat = matrix.bmat(key, frame.f_globals, frame.f_locals) + mymat = matrixlib.bmat(key, frame.f_globals, frame.f_locals) return mymat + if not isinstance(key, tuple): key = (key,) + + # copy attributes, since they can be overriden in the first argument + trans1d = self.trans1d + ndmin = self.ndmin + matrix = self.matrix + axis = self.axis + objs = [] scalars = [] arraytypes = [] scalartypes = [] + for k in range(len(key)): scalar = False if isinstance(key[k], slice): @@ -298,21 +290,20 @@ class AxisConcatenator(object): "first entry.") key0 = key[0] if key0 in 'rc': - self.matrix = True - self.col = (key0 == 'c') + matrix = True + col = (key0 == 'c') continue if ',' in key0: vec = key0.split(',') try: - self.axis, ndmin = \ - [int(x) for x in vec[:2]] + axis, ndmin = [int(x) for x in vec[:2]] if len(vec) == 3: trans1d = int(vec[2]) continue except: raise ValueError("unknown special directive") try: - self.axis = int(key[k]) + axis = int(key[k]) continue except (ValueError, TypeError): raise ValueError("unknown special directive") @@ -341,14 +332,20 @@ class AxisConcatenator(object): if not scalar and isinstance(newobj, _nx.ndarray): arraytypes.append(newobj.dtype) - # Esure that scalars won't up-cast unless warranted + # Ensure that scalars won't up-cast unless warranted final_dtype = find_common_type(arraytypes, scalartypes) if final_dtype is not None: for k in scalars: objs[k] = objs[k].astype(final_dtype) - res = self.concatenate(tuple(objs), axis=self.axis) - return self._retval(res) + res = self.concatenate(tuple(objs), axis=axis) + + if matrix: + oldndim = res.ndim + res = makemat(res) + if oldndim == 1 and col: + res = res.T + return res def __len__(self): return 0 diff --git a/numpy/lib/tests/test_index_tricks.py b/numpy/lib/tests/test_index_tricks.py index d9fa1f43e..434dd14b4 100644 --- a/numpy/lib/tests/test_index_tricks.py +++ b/numpy/lib/tests/test_index_tricks.py @@ -174,6 +174,17 @@ class TestConcatenator(TestCase): assert_array_equal(d[:5, :], b) assert_array_equal(d[5:, :], c) + def test_matrix_builder(self): + a = np.array([1]) + b = np.array([2]) + c = np.array([3]) + d = np.array([4]) + actual = np.r_['a, b; c, d'] + expected = np.bmat([[a, b], [c, d]]) + + assert_equal(actual, expected) + assert_equal(type(actual), type(expected)) + class TestNdenumerate(TestCase): def test_basic(self): |