summaryrefslogtreecommitdiff
path: root/numpy/lib
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/lib')
-rw-r--r--numpy/lib/recfunctions.py23
-rw-r--r--numpy/lib/tests/test_recfunctions.py13
2 files changed, 25 insertions, 11 deletions
diff --git a/numpy/lib/recfunctions.py b/numpy/lib/recfunctions.py
index 08faeee0e..e42421786 100644
--- a/numpy/lib/recfunctions.py
+++ b/numpy/lib/recfunctions.py
@@ -963,27 +963,28 @@ def join_by(key, r1, r2, jointype='inner', r1postfix='1', r2postfix='2',
ndtype = [list(_) for _ in r1k.dtype.descr]
# Add the other fields
ndtype.extend(list(_) for _ in r1.dtype.descr if _[0] not in key)
- # Find the new list of names (it may be different from r1names)
- names = list(_[0] for _ in ndtype)
+
for desc in r2.dtype.descr:
desc = list(desc)
- name = desc[0]
# Have we seen the current name already ?
- if name in names:
- nameidx = ndtype.index(desc)
+ name = desc[0]
+ names = list(_[0] for _ in ndtype)
+ try:
+ nameidx = names.index(name)
+ except ValueError:
+ #... we haven't: just add the description to the current list
+ ndtype.append(desc)
+ else:
current = ndtype[nameidx]
- # The current field is part of the key: take the largest dtype
if name in key:
+ # The current field is part of the key: take the largest dtype
current[-1] = max(desc[1], current[-1])
- # The current field is not part of the key: add the suffixes
else:
+ # The current field is not part of the key: add the suffixes,
+ # and place the new field adjacent to the old one
current[0] += r1postfix
desc[0] += r2postfix
ndtype.insert(nameidx + 1, desc)
- #... we haven't: just add the description to the current list
- else:
- names.extend(desc[0])
- ndtype.append(desc)
# Revert the elements to tuples
ndtype = [tuple(_) for _ in ndtype]
# Find the largest nb of common fields :
diff --git a/numpy/lib/tests/test_recfunctions.py b/numpy/lib/tests/test_recfunctions.py
index e9cfa4993..a5d15cb24 100644
--- a/numpy/lib/tests/test_recfunctions.py
+++ b/numpy/lib/tests/test_recfunctions.py
@@ -656,6 +656,19 @@ class TestJoinBy(TestCase):
b = np.ones(3, dtype=[('c', 'u1'), ('b', 'f4'), ('a', 'i4')])
assert_raises(ValueError, join_by, ['a', 'b', 'b'], a, b)
+ def test_same_name_different_dtypes(self):
+ # gh-9338
+ a_dtype = np.dtype([('key', 'S10'), ('value', '<f4')])
+ b_dtype = np.dtype([('key', 'S10'), ('value', '<f8')])
+ expected_dtype = np.dtype([
+ ('key', '|S10'), ('value1', '<f4'), ('value2', '<f8')])
+
+ a = np.array([('Sarah', 8.0), ('John', 6.0)], dtype=a_dtype)
+ b = np.array([('Sarah', 10.0), ('John', 7.0)], dtype=b_dtype)
+ res = join_by('key', a, b)
+
+ assert_equal(res.dtype, expected_dtype)
+
class TestJoinBy2(TestCase):
@classmethod