diff options
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/__init__.pyi | 2 | ||||
-rw-r--r-- | numpy/core/src/multiarray/dlpack.c | 6 |
2 files changed, 3 insertions, 5 deletions
diff --git a/numpy/__init__.pyi b/numpy/__init__.pyi index 63e723a35..1562ce89e 100644 --- a/numpy/__init__.pyi +++ b/numpy/__init__.pyi @@ -4333,7 +4333,7 @@ class chararray(ndarray[_ShapeType, _CharDType]): # class MachAr: ... class _SupportsDLPack(Protocol[_T_contra]): - def __dlpack__(self, *, stream: Optional[int] = ...) -> _PyCapsule: ... + def __dlpack__(self, *, stream: None | _T_contra = ...) -> _PyCapsule: ... def from_dlpack(__obj: _SupportsDLPack[None]) -> NDArray[Any]: ... diff --git a/numpy/core/src/multiarray/dlpack.c b/numpy/core/src/multiarray/dlpack.c index f1591bb1f..9de304379 100644 --- a/numpy/core/src/multiarray/dlpack.c +++ b/numpy/core/src/multiarray/dlpack.c @@ -211,7 +211,6 @@ array_dlpack(PyArrayObject *self, managed->dl_tensor.device = device; managed->dl_tensor.dtype = managed_dtype; - int64_t *managed_shape_strides = PyMem_Malloc(sizeof(int64_t) * ndim * 2); if (managed_shape_strides == NULL) { PyErr_NoMemory(); @@ -307,7 +306,7 @@ from_dlpack(PyObject *NPY_UNUSED(self), PyObject *obj) { int typenum = -1; const uint8_t bits = managed->dl_tensor.dtype.bits; const npy_intp itemsize = bits / 8; - switch(managed->dl_tensor.dtype.code) { + switch (managed->dl_tensor.dtype.code) { case kDLInt: switch (bits) { @@ -356,8 +355,7 @@ from_dlpack(PyObject *NPY_UNUSED(self), PyObject *obj) { for (int i = 0; i < ndim; ++i) { shape[i] = managed->dl_tensor.shape[i]; // DLPack has elements as stride units, NumPy has bytes. - if (managed->dl_tensor.strides != NULL) - { + if (managed->dl_tensor.strides != NULL) { strides[i] = managed->dl_tensor.strides[i] * itemsize; } } |