summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--libc/src/__support/RPC/rpc.h21
-rw-r--r--libc/test/integration/startup/gpu/CMakeLists.txt7
-rw-r--r--libc/test/integration/startup/gpu/rpc_interface_test.cpp43
-rw-r--r--libc/utils/gpu/loader/Server.h35
4 files changed, 103 insertions, 3 deletions
diff --git a/libc/src/__support/RPC/rpc.h b/libc/src/__support/RPC/rpc.h
index d448e190214f..1285f2b7cd50 100644
--- a/libc/src/__support/RPC/rpc.h
+++ b/libc/src/__support/RPC/rpc.h
@@ -36,6 +36,7 @@ enum Opcode : uint16_t {
PRINT_TO_STDERR = 1,
EXIT = 2,
TEST_INCREMENT = 3,
+ TEST_INTERFACE = 4,
};
/// A fixed size channel used to communicate between the RPC client and server.
@@ -252,7 +253,8 @@ template <bool InvertInbox> struct Process {
template <bool T> struct Port {
LIBC_INLINE Port(Process<T> &process, uint64_t lane_mask, uint64_t index,
uint32_t out)
- : process(process), lane_mask(lane_mask), index(index), out(out) {}
+ : process(process), lane_mask(lane_mask), index(index), out(out),
+ receive(false) {}
LIBC_INLINE ~Port() = default;
private:
@@ -278,13 +280,20 @@ public:
return process.get_packet(index).header.opcode;
}
- LIBC_INLINE void close() { process.unlock(lane_mask, index); }
+ LIBC_INLINE void close() {
+ // If the server last did a receive it needs to exchange ownership before
+ // closing the port.
+ if (receive && T)
+ out = process.invert_outbox(index, out);
+ process.unlock(lane_mask, index);
+ }
private:
Process<T> &process;
uint64_t lane_mask;
uint64_t index;
uint32_t out;
+ bool receive;
};
/// The RPC client used to make requests to the server.
@@ -325,10 +334,16 @@ template <bool T> template <typename F> LIBC_INLINE void Port<T>::send(F fill) {
process.invoke_rpc(fill, process.get_packet(index));
atomic_thread_fence(cpp::MemoryOrder::RELEASE);
out = process.invert_outbox(index, out);
+ receive = false;
}
/// Applies \p use to the shared buffer and acknowledges the send.
template <bool T> template <typename U> LIBC_INLINE void Port<T>::recv(U use) {
+ // We only exchange ownership of the buffer during a receive if we are waiting
+ // for a previous receive to finish.
+ if (receive)
+ out = process.invert_outbox(index, out);
+
uint32_t in = process.load_inbox(index);
// We need to wait until we own the buffer before receiving.
@@ -340,7 +355,7 @@ template <bool T> template <typename U> LIBC_INLINE void Port<T>::recv(U use) {
// Apply the \p use function to read the memory out of the buffer.
process.invoke_rpc(use, process.get_packet(index));
- out = process.invert_outbox(index, out);
+ receive = true;
}
/// Combines a send and receive into a single function.
diff --git a/libc/test/integration/startup/gpu/CMakeLists.txt b/libc/test/integration/startup/gpu/CMakeLists.txt
index d2028cc941f0..12ff49c46dbe 100644
--- a/libc/test/integration/startup/gpu/CMakeLists.txt
+++ b/libc/test/integration/startup/gpu/CMakeLists.txt
@@ -36,3 +36,10 @@ add_integration_test(
SRCS
init_fini_array_test.cpp
)
+
+add_integration_test(
+ startup_rpc_interface_test
+ SUITE libc-startup-tests
+ SRCS
+ rpc_interface_test.cpp
+)
diff --git a/libc/test/integration/startup/gpu/rpc_interface_test.cpp b/libc/test/integration/startup/gpu/rpc_interface_test.cpp
new file mode 100644
index 000000000000..b4b03ead31c1
--- /dev/null
+++ b/libc/test/integration/startup/gpu/rpc_interface_test.cpp
@@ -0,0 +1,43 @@
+//===-- Loader test to check the RPC interface with the loader ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "src/__support/GPU/utils.h"
+#include "src/__support/RPC/rpc_client.h"
+#include "test/IntegrationTest/test.h"
+
+using namespace __llvm_libc;
+
+// Test to ensure that we can use aribtrary combinations of sends and recieves
+// as long as they are mirrored.
+static void test_interface(bool end_with_send) {
+ uint64_t cnt = 0;
+ rpc::Client::Port port = rpc::client.open<rpc::TEST_INTERFACE>();
+ port.send([&](rpc::Buffer *buffer) { buffer->data[0] = end_with_send; });
+ port.send([&](rpc::Buffer *buffer) { buffer->data[0] = cnt = cnt + 1; });
+ port.recv([&](rpc::Buffer *buffer) { cnt = buffer->data[0]; });
+ port.send([&](rpc::Buffer *buffer) { buffer->data[0] = cnt = cnt + 1; });
+ port.recv([&](rpc::Buffer *buffer) { cnt = buffer->data[0]; });
+ port.send([&](rpc::Buffer *buffer) { buffer->data[0] = cnt = cnt + 1; });
+ port.send([&](rpc::Buffer *buffer) { buffer->data[0] = cnt = cnt + 1; });
+ port.recv([&](rpc::Buffer *buffer) { cnt = buffer->data[0]; });
+ port.recv([&](rpc::Buffer *buffer) { cnt = buffer->data[0]; });
+ if (end_with_send)
+ port.send([&](rpc::Buffer *buffer) { buffer->data[0] = cnt = cnt + 1; });
+ else
+ port.recv([&](rpc::Buffer *buffer) { cnt = buffer->data[0]; });
+ port.close();
+
+ ASSERT_TRUE(cnt == 9 && "Invalid number of increments");
+}
+
+TEST_MAIN(int argc, char **argv, char **envp) {
+ test_interface(true);
+ test_interface(false);
+
+ return 0;
+}
diff --git a/libc/utils/gpu/loader/Server.h b/libc/utils/gpu/loader/Server.h
index 89ef712e8596..c79277dd8fdc 100644
--- a/libc/utils/gpu/loader/Server.h
+++ b/libc/utils/gpu/loader/Server.h
@@ -57,6 +57,41 @@ void handle_server() {
});
break;
}
+ case __llvm_libc::rpc::Opcode::TEST_INTERFACE: {
+ uint64_t cnt = 0;
+ bool end_with_recv;
+ port->recv([&](__llvm_libc::rpc::Buffer *buffer) {
+ end_with_recv = buffer->data[0];
+ });
+ port->recv(
+ [&](__llvm_libc::rpc::Buffer *buffer) { cnt = buffer->data[0]; });
+ port->send([&](__llvm_libc::rpc::Buffer *buffer) {
+ buffer->data[0] = cnt = cnt + 1;
+ });
+ port->recv(
+ [&](__llvm_libc::rpc::Buffer *buffer) { cnt = buffer->data[0]; });
+ port->send([&](__llvm_libc::rpc::Buffer *buffer) {
+ buffer->data[0] = cnt = cnt + 1;
+ });
+ port->recv(
+ [&](__llvm_libc::rpc::Buffer *buffer) { cnt = buffer->data[0]; });
+ port->recv(
+ [&](__llvm_libc::rpc::Buffer *buffer) { cnt = buffer->data[0]; });
+ port->send([&](__llvm_libc::rpc::Buffer *buffer) {
+ buffer->data[0] = cnt = cnt + 1;
+ });
+ port->send([&](__llvm_libc::rpc::Buffer *buffer) {
+ buffer->data[0] = cnt = cnt + 1;
+ });
+ if (end_with_recv)
+ port->recv(
+ [&](__llvm_libc::rpc::Buffer *buffer) { cnt = buffer->data[0]; });
+ else
+ port->send([&](__llvm_libc::rpc::Buffer *buffer) {
+ buffer->data[0] = cnt = cnt + 1;
+ });
+ break;
+ }
default:
port->recv([](__llvm_libc::rpc::Buffer *buffer) {});
}