summaryrefslogtreecommitdiff
path: root/numpy/core
diff options
context:
space:
mode:
authorMarten van Kerkwijk <mhvk@astro.utoronto.ca>2017-02-26 12:48:14 -0500
committerGitHub <noreply@github.com>2017-02-26 12:48:14 -0500
commitad8afe82e7b7643607a348c0e02b45c9131c6a06 (patch)
treee75ca21f4a8bdda9de4d7465ace4b0f0bc9770de /numpy/core
parentc47198b4102cf3974e7fee2aa78d822877e391c2 (diff)
parent3461ea868bc7520d4569f7d61397044cf4e593d0 (diff)
downloadnumpy-ad8afe82e7b7643607a348c0e02b45c9131c6a06.tar.gz
Merge pull request #8663 from eric-wieser/subclass-repr
ENH: Fix alignment of repr for array subclasses
Diffstat (limited to 'numpy/core')
-rw-r--r--numpy/core/numeric.py20
-rw-r--r--numpy/core/tests/test_arrayprint.py22
2 files changed, 32 insertions, 10 deletions
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py
index 896ad7f6a..01dd46c3c 100644
--- a/numpy/core/numeric.py
+++ b/numpy/core/numeric.py
@@ -1890,35 +1890,35 @@ def array_repr(arr, max_line_width=None, precision=None, suppress_small=None):
'array([ 0.000001, 0. , 2. , 3. ])'
"""
+ if type(arr) is not ndarray:
+ class_name = type(arr).__name__
+ else:
+ class_name = "array"
+
if arr.size > 0 or arr.shape == (0,):
lst = array2string(arr, max_line_width, precision, suppress_small,
- ', ', "array(")
+ ', ', class_name + "(")
else: # show zero-length shape unless it is (0,)
lst = "[], shape=%s" % (repr(arr.shape),)
- if arr.__class__ is not ndarray:
- cName = arr.__class__.__name__
- else:
- cName = "array"
-
skipdtype = (arr.dtype.type in _typelessdata) and arr.size > 0
if skipdtype:
- return "%s(%s)" % (cName, lst)
+ return "%s(%s)" % (class_name, lst)
else:
typename = arr.dtype.name
# Quote typename in the output if it is "complex".
if typename and not (typename[0].isalpha() and typename.isalnum()):
typename = "'%s'" % typename
- lf = ''
+ lf = ' '
if issubclass(arr.dtype.type, flexible):
if arr.dtype.names:
typename = "%s" % str(arr.dtype)
else:
typename = "'%s'" % str(arr.dtype)
- lf = '\n'+' '*len("array(")
- return cName + "(%s, %sdtype=%s)" % (lst, lf, typename)
+ lf = '\n'+' '*len(class_name + "(")
+ return "%s(%s,%sdtype=%s)" % (class_name, lst, lf, typename)
def array_str(a, max_line_width=None, precision=None, suppress_small=None):
diff --git a/numpy/core/tests/test_arrayprint.py b/numpy/core/tests/test_arrayprint.py
index 9aa7b2609..607fa7010 100644
--- a/numpy/core/tests/test_arrayprint.py
+++ b/numpy/core/tests/test_arrayprint.py
@@ -14,6 +14,28 @@ class TestArrayRepr(object):
x = np.array([np.nan, np.inf])
assert_equal(repr(x), 'array([ nan, inf])')
+ def test_subclass(self):
+ class sub(np.ndarray): pass
+
+ # one dimensional
+ x1d = np.array([1, 2]).view(sub)
+ assert_equal(repr(x1d), 'sub([1, 2])')
+
+ # two dimensional
+ x2d = np.array([[1, 2], [3, 4]]).view(sub)
+ assert_equal(repr(x2d),
+ 'sub([[1, 2],\n'
+ ' [3, 4]])')
+
+ # two dimensional with flexible dtype
+ xstruct = np.ones((2,2), dtype=[('a', 'i4')]).view(sub)
+ assert_equal(repr(xstruct),
+ "sub([[(1,), (1,)],\n"
+ " [(1,), (1,)]],\n"
+ " dtype=[('a', '<i4')])"
+ )
+
+
class TestComplexArray(TestCase):
def test_str(self):
rvals = [0, 1, -1, np.inf, -np.inf, np.nan]