summaryrefslogtreecommitdiff
path: root/numpy/_array_api
diff options
context:
space:
mode:
authorAaron Meurer <asmeurer@gmail.com>2021-04-15 15:31:33 -0600
committerAaron Meurer <asmeurer@gmail.com>2021-04-15 15:31:33 -0600
commit9af1cc60edd4fdbb7e9c18d124e639a44ce420c7 (patch)
tree458fbf2cbc465f7d5109f4fee9dc1109802ba0d2 /numpy/_array_api
parent844fcd39692da676a7204c6cc7feea428ba49609 (diff)
downloadnumpy-9af1cc60edd4fdbb7e9c18d124e639a44ce420c7.tar.gz
Use dtype objects instead of classes in the array API
The array API does not require any methods or behaviors on dtype objects, other than that they be literals that can be compared for equality and passed to dtype keywords in functions. Since dtype objects are already used by the dtype attribute of ndarray, this makes it consistent, so that func(dtype=<dtype>).dtype will give exactly <dtype> back, which will be the same thing as numpy._array_api.<dtype>. This also fixes an issue in the array API test suite due to the fact that dtype classes and objects are not equal as dictionary keys.
Diffstat (limited to 'numpy/_array_api')
-rw-r--r--numpy/_array_api/_dtypes.py17
1 files changed, 15 insertions, 2 deletions
diff --git a/numpy/_array_api/_dtypes.py b/numpy/_array_api/_dtypes.py
index d33ae1fce..c874763dd 100644
--- a/numpy/_array_api/_dtypes.py
+++ b/numpy/_array_api/_dtypes.py
@@ -1,6 +1,19 @@
-from .. import int8, int16, int32, int64, uint8, uint16, uint32, uint64, float32, float64
+import numpy as np
+
+# Note: we use dtype objects instead of dtype classes. The spec does not
+# require any behavior on dtypes other than equality.
+int8 = np.dtype('int8')
+int16 = np.dtype('int16')
+int32 = np.dtype('int32')
+int64 = np.dtype('int64')
+uint8 = np.dtype('uint8')
+uint16 = np.dtype('uint16')
+uint32 = np.dtype('uint32')
+uint64 = np.dtype('uint64')
+float32 = np.dtype('float32')
+float64 = np.dtype('float64')
# Note: This name is changed
-from .. import bool_ as bool
+bool = np.dtype('bool')
_all_dtypes = [int8, int16, int32, int64, uint8, uint16, uint32, uint64,
float32, float64, bool]