summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorEric Wieser <wieser.eric@gmail.com>2017-03-22 22:22:28 +0000
committerEric Wieser <wieser.eric@gmail.com>2017-05-05 21:42:17 +0100
commit37d756c46424b2da04a87e7df45a6c64f9b50117 (patch)
tree9520af9aabfba8e759dbd35e681e6be0e302e899 /numpy
parente6b8e75547af0cc4d38af458eff5e5d6c14102b8 (diff)
downloadnumpy-37d756c46424b2da04a87e7df45a6c64f9b50117.tar.gz
BUG: Remove mutable state from AxisConcatenator
Fixes #8815
Diffstat (limited to 'numpy')
-rw-r--r--numpy/lib/index_tricks.py55
-rw-r--r--numpy/lib/tests/test_index_tricks.py11
2 files changed, 37 insertions, 29 deletions
diff --git a/numpy/lib/index_tricks.py b/numpy/lib/index_tricks.py
index 58d3e0dcf..719698bc8 100644
--- a/numpy/lib/index_tricks.py
+++ b/numpy/lib/index_tricks.py
@@ -10,12 +10,12 @@ from numpy.core.numeric import (
from numpy.core.numerictypes import find_common_type, issubdtype
from . import function_base
-import numpy.matrixlib as matrix
+import numpy.matrixlib as matrixlib
from .function_base import diff
from numpy.core.multiarray import ravel_multi_index, unravel_index
from numpy.lib.stride_tricks import as_strided
-makemat = matrix.matrix
+makemat = matrixlib.matrix
__all__ = [
@@ -235,44 +235,36 @@ class AxisConcatenator(object):
Translates slice objects to concatenation along an axis.
For detailed documentation on usage, see `r_`.
-
"""
# allow ma.mr_ to override this
concatenate = staticmethod(_nx.concatenate)
-
- def _retval(self, res):
- if self.matrix:
- oldndim = res.ndim
- res = makemat(res)
- if oldndim == 1 and self.col:
- res = res.T
- self.axis = self._axis
- self.matrix = self._matrix
- self.col = 0
- return res
-
def __init__(self, axis=0, matrix=False, ndmin=1, trans1d=-1):
- self._axis = axis
- self._matrix = matrix
self.axis = axis
self.matrix = matrix
- self.col = 0
self.trans1d = trans1d
self.ndmin = ndmin
def __getitem__(self, key):
- trans1d = self.trans1d
- ndmin = self.ndmin
+ # handle matrix builder syntax
if isinstance(key, str):
frame = sys._getframe().f_back
- mymat = matrix.bmat(key, frame.f_globals, frame.f_locals)
+ mymat = matrixlib.bmat(key, frame.f_globals, frame.f_locals)
return mymat
+
if not isinstance(key, tuple):
key = (key,)
+
+ # copy attributes, since they can be overriden in the first argument
+ trans1d = self.trans1d
+ ndmin = self.ndmin
+ matrix = self.matrix
+ axis = self.axis
+
objs = []
scalars = []
arraytypes = []
scalartypes = []
+
for k in range(len(key)):
scalar = False
if isinstance(key[k], slice):
@@ -298,21 +290,20 @@ class AxisConcatenator(object):
"first entry.")
key0 = key[0]
if key0 in 'rc':
- self.matrix = True
- self.col = (key0 == 'c')
+ matrix = True
+ col = (key0 == 'c')
continue
if ',' in key0:
vec = key0.split(',')
try:
- self.axis, ndmin = \
- [int(x) for x in vec[:2]]
+ axis, ndmin = [int(x) for x in vec[:2]]
if len(vec) == 3:
trans1d = int(vec[2])
continue
except:
raise ValueError("unknown special directive")
try:
- self.axis = int(key[k])
+ axis = int(key[k])
continue
except (ValueError, TypeError):
raise ValueError("unknown special directive")
@@ -341,14 +332,20 @@ class AxisConcatenator(object):
if not scalar and isinstance(newobj, _nx.ndarray):
arraytypes.append(newobj.dtype)
- # Esure that scalars won't up-cast unless warranted
+ # Ensure that scalars won't up-cast unless warranted
final_dtype = find_common_type(arraytypes, scalartypes)
if final_dtype is not None:
for k in scalars:
objs[k] = objs[k].astype(final_dtype)
- res = self.concatenate(tuple(objs), axis=self.axis)
- return self._retval(res)
+ res = self.concatenate(tuple(objs), axis=axis)
+
+ if matrix:
+ oldndim = res.ndim
+ res = makemat(res)
+ if oldndim == 1 and col:
+ res = res.T
+ return res
def __len__(self):
return 0
diff --git a/numpy/lib/tests/test_index_tricks.py b/numpy/lib/tests/test_index_tricks.py
index d9fa1f43e..434dd14b4 100644
--- a/numpy/lib/tests/test_index_tricks.py
+++ b/numpy/lib/tests/test_index_tricks.py
@@ -174,6 +174,17 @@ class TestConcatenator(TestCase):
assert_array_equal(d[:5, :], b)
assert_array_equal(d[5:, :], c)
+ def test_matrix_builder(self):
+ a = np.array([1])
+ b = np.array([2])
+ c = np.array([3])
+ d = np.array([4])
+ actual = np.r_['a, b; c, d']
+ expected = np.bmat([[a, b], [c, d]])
+
+ assert_equal(actual, expected)
+ assert_equal(type(actual), type(expected))
+
class TestNdenumerate(TestCase):
def test_basic(self):