summaryrefslogtreecommitdiff
path: root/numpy/core
diff options
context:
space:
mode:
authorahaldane <ealloc@gmail.com>2017-05-10 21:03:20 -0400
committerGitHub <noreply@github.com>2017-05-10 21:03:20 -0400
commit81728dc0bd862b8cd0b2dfec05afdea9dca419a5 (patch)
treec62be4313d446c817de1fa08fe22d655ca6eafb4 /numpy/core
parent22ccc378253c3c2beefc603a2e0bccc5a3a9e1d1 (diff)
parent3cdcbe0f8402e795262341837afe39a279d9b888 (diff)
downloadnumpy-81728dc0bd862b8cd0b2dfec05afdea9dca419a5.tar.gz
Merge pull request #9083 from eric-wieser/fix-duplicate-order-error
MAINT: Improve error message from sorting with duplicate key
Diffstat (limited to 'numpy/core')
-rw-r--r--numpy/core/_internal.py14
-rw-r--r--numpy/core/tests/test_multiarray.py5
2 files changed, 14 insertions, 5 deletions
diff --git a/numpy/core/_internal.py b/numpy/core/_internal.py
index 9c46b3297..163145cdd 100644
--- a/numpy/core/_internal.py
+++ b/numpy/core/_internal.py
@@ -281,20 +281,26 @@ class _ctypes(object):
_as_parameter_ = property(get_as_parameter, None, doc="_as parameter_")
-# Given a datatype and an order object
-# return a new names tuple
-# with the order indicated
def _newnames(datatype, order):
+ """
+ Given a datatype and an order object, return a new names tuple, with the
+ order indicated
+ """
oldnames = datatype.names
nameslist = list(oldnames)
if isinstance(order, str):
order = [order]
+ seen = set()
if isinstance(order, (list, tuple)):
for name in order:
try:
nameslist.remove(name)
except ValueError:
- raise ValueError("unknown field name: %s" % (name,))
+ if name in seen:
+ raise ValueError("duplicate field name: %s" % (name,))
+ else:
+ raise ValueError("unknown field name: %s" % (name,))
+ seen.add(name)
return tuple(list(order) + nameslist)
raise ValueError("unsupported order value: %s" % (order,))
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index 835d03528..e30cd09e7 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -30,7 +30,7 @@ from numpy.core.multiarray_tests import (
)
from numpy.testing import (
TestCase, run_module_suite, assert_, assert_raises, assert_warns,
- assert_equal, assert_almost_equal, assert_array_equal,
+ assert_equal, assert_almost_equal, assert_array_equal, assert_raises_regex,
assert_array_almost_equal, assert_allclose, IS_PYPY, HAS_REFCOUNT,
assert_array_less, runstring, dec, SkipTest, temppath, suppress_warnings
)
@@ -1538,6 +1538,9 @@ class TestMethods(TestCase):
assert_equal(r.word, np.array(['my', 'first', 'name']))
assert_equal(r.number, np.array([3.1, 4.5, 6.2]))
+ assert_raises_regex(ValueError, 'duplicate',
+ lambda: r.sort(order=['id', 'id']))
+
if sys.byteorder == 'little':
strtype = '>i2'
else: