diff options
Diffstat (limited to 'mlir/lib/ExecutionEngine')
-rw-r--r-- | mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp | 16 | ||||
-rw-r--r-- | mlir/lib/ExecutionEngine/RocmRuntimeWrappers.cpp | 16 |
2 files changed, 32 insertions, 0 deletions
diff --git a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp index 44ed5b0cd205..4065c6531669 100644 --- a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp +++ b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp @@ -192,6 +192,22 @@ mgpuMemHostRegisterMemRef(int64_t rank, StridedMemRefType<char, 1> *descriptor, mgpuMemHostRegister(ptr, sizeBytes); } +// Allows to unregister byte array with the CUDA runtime. +extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuMemHostUnregister(void *ptr) { + ScopedContext scopedContext; + CUDA_REPORT_IF_ERROR(cuMemHostUnregister(ptr)); +} + +/// Unregisters a memref with the CUDA runtime. `descriptor` is a pointer to a +/// ranked memref descriptor struct of rank `rank` +extern "C" MLIR_CUDA_WRAPPERS_EXPORT void +mgpuMemHostUnregisterMemRef(int64_t rank, + StridedMemRefType<char, 1> *descriptor, + int64_t elementSizeBytes) { + auto *ptr = descriptor->data + descriptor->offset * elementSizeBytes; + mgpuMemHostUnregister(ptr); +} + extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuSetDefaultDevice(int32_t device) { defaultDevice = device; } diff --git a/mlir/lib/ExecutionEngine/RocmRuntimeWrappers.cpp b/mlir/lib/ExecutionEngine/RocmRuntimeWrappers.cpp index 43a7e3c62089..bd3868a8e196 100644 --- a/mlir/lib/ExecutionEngine/RocmRuntimeWrappers.cpp +++ b/mlir/lib/ExecutionEngine/RocmRuntimeWrappers.cpp @@ -152,6 +152,22 @@ mgpuMemHostRegisterMemRef(int64_t rank, StridedMemRefType<char, 1> *descriptor, mgpuMemHostRegister(ptr, sizeBytes); } +// Allows to unregister byte array with the ROCM runtime. Helpful until we have +// transfer functions implemented. +extern "C" void mgpuMemHostUnregister(void *ptr) { + HIP_REPORT_IF_ERROR(hipHostUnregister(ptr)); +} + +// Allows to unregister a MemRef with the ROCm runtime. Helpful until we have +// transfer functions implemented. +extern "C" void +mgpuMemHostUnregisterMemRef(int64_t rank, + StridedMemRefType<char, 1> *descriptor, + int64_t elementSizeBytes) { + auto ptr = descriptor->data + descriptor->offset * elementSizeBytes; + mgpuMemHostUnregister(ptr); +} + template <typename T> void mgpuMemGetDevicePointer(T *hostPtr, T **devicePtr) { HIP_REPORT_IF_ERROR(hipSetDevice(0)); |