summaryrefslogtreecommitdiff
path: root/mlir/lib/ExecutionEngine
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/ExecutionEngine')
-rw-r--r--mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp16
-rw-r--r--mlir/lib/ExecutionEngine/RocmRuntimeWrappers.cpp16
2 files changed, 32 insertions, 0 deletions
diff --git a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
index 44ed5b0cd205..4065c6531669 100644
--- a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
+++ b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
@@ -192,6 +192,22 @@ mgpuMemHostRegisterMemRef(int64_t rank, StridedMemRefType<char, 1> *descriptor,
mgpuMemHostRegister(ptr, sizeBytes);
}
+// Allows to unregister byte array with the CUDA runtime.
+extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuMemHostUnregister(void *ptr) {
+ ScopedContext scopedContext;
+ CUDA_REPORT_IF_ERROR(cuMemHostUnregister(ptr));
+}
+
+/// Unregisters a memref with the CUDA runtime. `descriptor` is a pointer to a
+/// ranked memref descriptor struct of rank `rank`
+extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
+mgpuMemHostUnregisterMemRef(int64_t rank,
+ StridedMemRefType<char, 1> *descriptor,
+ int64_t elementSizeBytes) {
+ auto *ptr = descriptor->data + descriptor->offset * elementSizeBytes;
+ mgpuMemHostUnregister(ptr);
+}
+
extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuSetDefaultDevice(int32_t device) {
defaultDevice = device;
}
diff --git a/mlir/lib/ExecutionEngine/RocmRuntimeWrappers.cpp b/mlir/lib/ExecutionEngine/RocmRuntimeWrappers.cpp
index 43a7e3c62089..bd3868a8e196 100644
--- a/mlir/lib/ExecutionEngine/RocmRuntimeWrappers.cpp
+++ b/mlir/lib/ExecutionEngine/RocmRuntimeWrappers.cpp
@@ -152,6 +152,22 @@ mgpuMemHostRegisterMemRef(int64_t rank, StridedMemRefType<char, 1> *descriptor,
mgpuMemHostRegister(ptr, sizeBytes);
}
+// Allows to unregister byte array with the ROCM runtime. Helpful until we have
+// transfer functions implemented.
+extern "C" void mgpuMemHostUnregister(void *ptr) {
+ HIP_REPORT_IF_ERROR(hipHostUnregister(ptr));
+}
+
+// Allows to unregister a MemRef with the ROCm runtime. Helpful until we have
+// transfer functions implemented.
+extern "C" void
+mgpuMemHostUnregisterMemRef(int64_t rank,
+ StridedMemRefType<char, 1> *descriptor,
+ int64_t elementSizeBytes) {
+ auto ptr = descriptor->data + descriptor->offset * elementSizeBytes;
+ mgpuMemHostUnregister(ptr);
+}
+
template <typename T>
void mgpuMemGetDevicePointer(T *hostPtr, T **devicePtr) {
HIP_REPORT_IF_ERROR(hipSetDevice(0));