summaryrefslogtreecommitdiff
path: root/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp')
-rw-r--r--mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp16
1 files changed, 16 insertions, 0 deletions
diff --git a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
index 44ed5b0cd205..4065c6531669 100644
--- a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
+++ b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
@@ -192,6 +192,22 @@ mgpuMemHostRegisterMemRef(int64_t rank, StridedMemRefType<char, 1> *descriptor,
mgpuMemHostRegister(ptr, sizeBytes);
}
+// Allows to unregister byte array with the CUDA runtime.
+extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuMemHostUnregister(void *ptr) {
+ ScopedContext scopedContext;
+ CUDA_REPORT_IF_ERROR(cuMemHostUnregister(ptr));
+}
+
+/// Unregisters a memref with the CUDA runtime. `descriptor` is a pointer to a
+/// ranked memref descriptor struct of rank `rank`
+extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
+mgpuMemHostUnregisterMemRef(int64_t rank,
+ StridedMemRefType<char, 1> *descriptor,
+ int64_t elementSizeBytes) {
+ auto *ptr = descriptor->data + descriptor->offset * elementSizeBytes;
+ mgpuMemHostUnregister(ptr);
+}
+
extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuSetDefaultDevice(int32_t device) {
defaultDevice = device;
}