summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBas van Beek <43369155+BvB93@users.noreply.github.com>2022-09-07 09:22:54 +0200
committerCharles Harris <charlesr.harris@gmail.com>2022-09-07 07:37:54 -0600
commit222cc37acbfe3ef7d26bfe31e70e76642ac1b49b (patch)
tree7b9c03cface81e713229084e6b0a9e25366bf969
parente18dd98ca72441f3eead9c974550ccd75b2247dd (diff)
downloadnumpy-222cc37acbfe3ef7d26bfe31e70e76642ac1b49b.tar.gz
TYP,BUG: Reduce argument validation in C-based `__class_getitem__` (#22212)
Closes #22185 The __class_getitem__ implementations would previously perform basic validation of the passed value, i.e. it would check whether a tuple of the appropriate length was passed (e.g. np.dtype.__class_getitem__ would expect a single item or a length-1 tuple). As noted in aforementioned issue: this approach can cause issues when (a. 2 or more parameters are involved and (b. a subclasses is created one or more parameters are declared constant (e.g. a fixed dtype & variably shaped array). This PR fixes aforementioned issue by relaxing the runtime argument validation, thus mimicking the behavior of the standard library (more closely). While we could alternatively fix this by adding more special casing (e.g. only disable validation when cls is not np.ndarray), I'm not convinced this would be worth the additional complexity, especially since the standard library also has zero runtime validation for all of its Py_GenericAlias-based implementations of __class_getitem__. (Some edits by seberg to the commit message)
-rw-r--r--numpy/core/src/multiarray/methods.c2
-rw-r--r--numpy/core/src/multiarray/scalartypes.c.src2
-rw-r--r--numpy/core/tests/test_arraymethod.py27
-rw-r--r--numpy/core/tests/test_scalar_methods.py10
4 files changed, 25 insertions, 16 deletions
diff --git a/numpy/core/src/multiarray/methods.c b/numpy/core/src/multiarray/methods.c
index f10f68ea5..3e8c78dd0 100644
--- a/numpy/core/src/multiarray/methods.c
+++ b/numpy/core/src/multiarray/methods.c
@@ -2829,7 +2829,7 @@ array_class_getitem(PyObject *cls, PyObject *args)
Py_ssize_t args_len;
args_len = PyTuple_Check(args) ? PyTuple_Size(args) : 1;
- if (args_len != 2) {
+ if ((args_len > 2) || (args_len == 0)) {
return PyErr_Format(PyExc_TypeError,
"Too %s arguments for %s",
args_len > 2 ? "many" : "few",
diff --git a/numpy/core/src/multiarray/scalartypes.c.src b/numpy/core/src/multiarray/scalartypes.c.src
index 459e5b222..e1f236001 100644
--- a/numpy/core/src/multiarray/scalartypes.c.src
+++ b/numpy/core/src/multiarray/scalartypes.c.src
@@ -1855,7 +1855,7 @@ numbertype_class_getitem_abc(PyObject *cls, PyObject *args)
}
args_len = PyTuple_Check(args) ? PyTuple_Size(args) : 1;
- if (args_len != args_len_expected) {
+ if ((args_len > args_len_expected) || (args_len == 0)) {
return PyErr_Format(PyExc_TypeError,
"Too %s arguments for %s",
args_len > args_len_expected ? "many" : "few",
diff --git a/numpy/core/tests/test_arraymethod.py b/numpy/core/tests/test_arraymethod.py
index 49aa9f6df..6b75d1921 100644
--- a/numpy/core/tests/test_arraymethod.py
+++ b/numpy/core/tests/test_arraymethod.py
@@ -3,9 +3,11 @@ This file tests the generic aspects of ArrayMethod. At the time of writing
this is private API, but when added, public API may be added here.
"""
+from __future__ import annotations
+
import sys
import types
-from typing import Any, Type
+from typing import Any
import pytest
@@ -63,28 +65,25 @@ class TestSimpleStridedCall:
@pytest.mark.skipif(sys.version_info < (3, 9), reason="Requires python 3.9")
+@pytest.mark.parametrize(
+ "cls", [np.ndarray, np.recarray, np.chararray, np.matrix, np.memmap]
+)
class TestClassGetItem:
- @pytest.mark.parametrize(
- "cls", [np.ndarray, np.recarray, np.chararray, np.matrix, np.memmap]
- )
- def test_class_getitem(self, cls: Type[np.ndarray]) -> None:
+ def test_class_getitem(self, cls: type[np.ndarray]) -> None:
"""Test `ndarray.__class_getitem__`."""
alias = cls[Any, Any]
assert isinstance(alias, types.GenericAlias)
assert alias.__origin__ is cls
@pytest.mark.parametrize("arg_len", range(4))
- def test_subscript_tuple(self, arg_len: int) -> None:
+ def test_subscript_tup(self, cls: type[np.ndarray], arg_len: int) -> None:
arg_tup = (Any,) * arg_len
- if arg_len == 2:
- assert np.ndarray[arg_tup]
+ if arg_len in (1, 2):
+ assert cls[arg_tup]
else:
- with pytest.raises(TypeError):
- np.ndarray[arg_tup]
-
- def test_subscript_scalar(self) -> None:
- with pytest.raises(TypeError):
- np.ndarray[Any]
+ match = f"Too {'few' if arg_len == 0 else 'many'} arguments"
+ with pytest.raises(TypeError, match=match):
+ cls[arg_tup]
@pytest.mark.skipif(sys.version_info >= (3, 9), reason="Requires python 3.8")
diff --git a/numpy/core/tests/test_scalar_methods.py b/numpy/core/tests/test_scalar_methods.py
index eef4c1433..769bfd500 100644
--- a/numpy/core/tests/test_scalar_methods.py
+++ b/numpy/core/tests/test_scalar_methods.py
@@ -153,6 +153,16 @@ class TestClassGetItem:
assert isinstance(alias, types.GenericAlias)
assert alias.__origin__ is np.complexfloating
+ @pytest.mark.parametrize("arg_len", range(4))
+ def test_abc_complexfloating_subscript_tuple(self, arg_len: int) -> None:
+ arg_tup = (Any,) * arg_len
+ if arg_len in (1, 2):
+ assert np.complexfloating[arg_tup]
+ else:
+ match = f"Too {'few' if arg_len == 0 else 'many'} arguments"
+ with pytest.raises(TypeError, match=match):
+ np.complexfloating[arg_tup]
+
@pytest.mark.parametrize("cls", [np.generic, np.flexible, np.character])
def test_abc_non_numeric(self, cls: Type[np.generic]) -> None:
with pytest.raises(TypeError):