summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPatrick Freed <patrick.freed@mongodb.com>2023-05-17 13:23:13 +0000
committerEvergreen Agent <no-reply@evergreen.mongodb.com>2023-05-17 14:28:11 +0000
commit84475222a82ded97409472d83b16afb1c08a9af8 (patch)
tree07c94f2cc924cf99801d67658391a6376ad262e6
parenta4cffa9d6a2c902101f19b3b54af2714b0f4af4a (diff)
downloadmongo-84475222a82ded97409472d83b16afb1c08a9af8.tar.gz
SERVER-74015 Introduce mocked gRPC stub and client stream
-rw-r--r--src/mongo/transport/grpc/SConscript1
-rw-r--r--src/mongo/transport/grpc/bidirectional_pipe.h31
-rw-r--r--src/mongo/transport/grpc/client_context.h92
-rw-r--r--src/mongo/transport/grpc/client_stream.h78
-rw-r--r--src/mongo/transport/grpc/grpc_server_context.h8
-rw-r--r--src/mongo/transport/grpc/grpc_session.h60
-rw-r--r--src/mongo/transport/grpc/grpc_session_test.cpp16
-rw-r--r--src/mongo/transport/grpc/mock_client_context.h43
-rw-r--r--src/mongo/transport/grpc/mock_client_stream.cpp50
-rw-r--r--src/mongo/transport/grpc/mock_client_stream.h43
-rw-r--r--src/mongo/transport/grpc/mock_server_context.cpp4
-rw-r--r--src/mongo/transport/grpc/mock_server_context.h7
-rw-r--r--src/mongo/transport/grpc/mock_server_stream.cpp52
-rw-r--r--src/mongo/transport/grpc/mock_server_stream.h58
-rw-r--r--src/mongo/transport/grpc/mock_server_stream_test.cpp66
-rw-r--r--src/mongo/transport/grpc/mock_stub.h217
-rw-r--r--src/mongo/transport/grpc/mock_stub_test.cpp204
-rw-r--r--src/mongo/transport/grpc/mock_util.h49
-rw-r--r--src/mongo/transport/grpc/server_context.h11
-rw-r--r--src/mongo/transport/grpc/service_test.cpp38
-rw-r--r--src/mongo/transport/grpc/test_fixtures.h72
-rw-r--r--src/mongo/util/producer_consumer_queue.h5
22 files changed, 1078 insertions, 127 deletions
diff --git a/src/mongo/transport/grpc/SConscript b/src/mongo/transport/grpc/SConscript
index 17ef132dff3..b3317cbaff8 100644
--- a/src/mongo/transport/grpc/SConscript
+++ b/src/mongo/transport/grpc/SConscript
@@ -48,6 +48,7 @@ env.CppUnitTest(
'grpc_session_test.cpp',
'grpc_transport_layer_test.cpp',
'mock_server_stream_test.cpp',
+ 'mock_stub_test.cpp',
'server_test.cpp',
'service_test.cpp',
],
diff --git a/src/mongo/transport/grpc/bidirectional_pipe.h b/src/mongo/transport/grpc/bidirectional_pipe.h
index f1c29471eb0..b53d61b519c 100644
--- a/src/mongo/transport/grpc/bidirectional_pipe.h
+++ b/src/mongo/transport/grpc/bidirectional_pipe.h
@@ -89,20 +89,36 @@ public:
}
/**
- * Close both ends of the pipe. In progress reads and writes on either end will be
- * interrupted.
+ * Close both the read and write halves of this end of the pipe. In-progress reads and
+ * writes on this end and writes on the other end will be interrupted.
+ *
+ * Messages that have already been transmitted through this end of the pipe can still be
+ * read by the other end.
*/
void close() {
_sendHalf.close();
_recvHalf.close();
}
+ /**
+ * Returns true when at least one of the following conditions is met:
+ * - This end of the pipe is closed.
+ * - The other end of the pipe is closed and there are no more messages to be read.
+ */
+ bool isConsumed() const {
+ auto stats = _recvHalfCtrl.getStats();
+ return stats.consumerEndClosed || (stats.queueDepth == 0 && stats.producerEndClosed);
+ }
+
private:
friend BidirectionalPipe;
explicit End(SingleProducerSingleConsumerQueue<SharedBuffer>::Producer send,
- SingleProducerSingleConsumerQueue<SharedBuffer>::Consumer recv)
- : _sendHalf{std::move(send)}, _recvHalf{std::move(recv)} {}
+ SingleProducerSingleConsumerQueue<SharedBuffer>::Consumer recv,
+ SingleProducerSingleConsumerQueue<SharedBuffer>::Controller recvCtrl)
+ : _sendHalf{std::move(send)},
+ _recvHalf{std::move(recv)},
+ _recvHalfCtrl(std::move(recvCtrl)) {}
bool _isPipeClosedError(const DBException& e) const {
return e.code() == ErrorCodes::ProducerConsumerQueueEndClosed ||
@@ -111,14 +127,17 @@ public:
SingleProducerSingleConsumerQueue<SharedBuffer>::Producer _sendHalf;
SingleProducerSingleConsumerQueue<SharedBuffer>::Consumer _recvHalf;
+ SingleProducerSingleConsumerQueue<SharedBuffer>::Controller _recvHalfCtrl;
};
BidirectionalPipe() {
SingleProducerSingleConsumerQueue<SharedBuffer>::Pipe pipe1;
SingleProducerSingleConsumerQueue<SharedBuffer>::Pipe pipe2;
- left = std::unique_ptr<End>(new End(std::move(pipe1.producer), std::move(pipe2.consumer)));
- right = std::unique_ptr<End>(new End(std::move(pipe2.producer), std::move(pipe1.consumer)));
+ left = std::unique_ptr<End>(new End(
+ std::move(pipe1.producer), std::move(pipe2.consumer), std::move(pipe2.controller)));
+ right = std::unique_ptr<End>(new End(
+ std::move(pipe2.producer), std::move(pipe1.consumer), std::move(pipe1.controller)));
}
std::unique_ptr<End> left;
diff --git a/src/mongo/transport/grpc/client_context.h b/src/mongo/transport/grpc/client_context.h
new file mode 100644
index 00000000000..b4b8238cce6
--- /dev/null
+++ b/src/mongo/transport/grpc/client_context.h
@@ -0,0 +1,92 @@
+/**
+ * Copyright (C) 2023-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#pragma once
+
+#include <map>
+#include <string>
+
+#include "mongo/transport/grpc/metadata.h"
+#include "mongo/util/net/hostandport.h"
+#include "mongo/util/time_support.h"
+
+namespace mongo::transport::grpc {
+
+/**
+ * Base class modeling a gRPC ClientContext.
+ * See: https://grpc.github.io/grpc/cpp/classgrpc_1_1_client_context.html
+ */
+class ClientContext {
+public:
+ virtual ~ClientContext() = default;
+
+ /**
+ * Add an entry to the metadata associated with the RPC.
+ *
+ * This must only be called before invoking the RPC.
+ */
+ virtual void addMetadataEntry(const std::string& key, const std::string& value) = 0;
+
+ /**
+ * Retrieve the server's initial metadata.
+ *
+ * This must only be called after the first message has been received on the ClientStream
+ * created from the RPC that this context is associated with.
+ */
+ virtual boost::optional<const MetadataContainer&> getServerInitialMetadata() const = 0;
+
+ /**
+ * Set the deadline for the RPC to be executed using this context.
+ *
+ * This must only be called before invoking the RPC.
+ */
+ virtual void setDeadline(Date_t deadline) = 0;
+
+ virtual Date_t getDeadline() const = 0;
+
+ virtual HostAndPort getRemote() const = 0;
+
+ /**
+ * Send a best-effort out-of-band cancel on the call associated with this ClientContext. There
+ * is no guarantee the call will be cancelled (e.g. if the call has already finished by the time
+ * the cancellation is received).
+ *
+ * Note that tryCancel() will not impede the execution of any already scheduled work (e.g.
+ * messages already queued to be sent on a stream will still be sent), though the reported
+ * sucess or failure of such work may reflect the cancellation.
+ *
+ * This method is thread-safe, and can be called multiple times from any thread. It should not
+ * be called before this ClientContext has been used to invoke an RPC.
+ *
+ * See:
+ * https://grpc.github.io/grpc/cpp/classgrpc_1_1_client_context.html#abd0f6715c30287b75288015eee628984
+ */
+ virtual void tryCancel() = 0;
+};
+} // namespace mongo::transport::grpc
diff --git a/src/mongo/transport/grpc/client_stream.h b/src/mongo/transport/grpc/client_stream.h
new file mode 100644
index 00000000000..e1dcbe80827
--- /dev/null
+++ b/src/mongo/transport/grpc/client_stream.h
@@ -0,0 +1,78 @@
+/**
+ * Copyright (C) 2023-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#pragma once
+
+#include <boost/optional.hpp>
+#include <grpcpp/grpcpp.h>
+
+#include "mongo/util/shared_buffer.h"
+
+namespace mongo::transport::grpc {
+
+/**
+ * Base class modeling a synchronous client side of a gRPC stream.
+ * See: https://grpc.github.io/grpc/cpp/classgrpc_1_1_client_reader_writer.html
+ *
+ * ClientStream::read() is thread safe with respect to ClientStream::write(), but neither method
+ * should be called concurrently with another invocation of itself on the same stream.
+ *
+ * ClientStream::finish() is thread safe with respect to ClientStream::read().
+ */
+class ClientStream {
+public:
+ virtual ~ClientStream() = default;
+
+ /**
+ * Block to read a message from the stream.
+ *
+ * Returns boost::none if the stream is closed, either cleanly or due to an underlying
+ * connection failure.
+ */
+ virtual boost::optional<SharedBuffer> read() = 0;
+
+ /**
+ * Block to write a message to the stream.
+ *
+ * Returns true if the write was successful or false if it failed due to the stream being
+ * closed, either explicitly or due to an underlying connection failure.
+ */
+ virtual bool write(ConstSharedBuffer msg) = 0;
+
+ /**
+ * Block waiting until all received messages have been read and the stream has been closed.
+ *
+ * Returns the final status of the RPC associated with this stream.
+ *
+ * This method should only be called once.
+ */
+ virtual ::grpc::Status finish() = 0;
+};
+
+} // namespace mongo::transport::grpc
diff --git a/src/mongo/transport/grpc/grpc_server_context.h b/src/mongo/transport/grpc/grpc_server_context.h
index 400719eafa8..672074ddf07 100644
--- a/src/mongo/transport/grpc/grpc_server_context.h
+++ b/src/mongo/transport/grpc/grpc_server_context.h
@@ -53,7 +53,7 @@ public:
}
explicit GRPCServerContext(::grpc::ServerContext* ctx)
- : _ctx{ctx}, _hostAndPort{parseURI(_ctx->peer())} {
+ : _ctx{ctx}, _remote{parseURI(_ctx->peer())} {
for (auto& kvp : _ctx->client_metadata()) {
_clientMetadata.insert({StringData{kvp.first.data(), kvp.first.length()},
StringData{kvp.second.data(), kvp.second.length()}});
@@ -74,8 +74,8 @@ public:
return Date_t{_ctx->deadline()};
}
- HostAndPort getHostAndPort() const override {
- return _hostAndPort;
+ HostAndPort getRemote() const override {
+ return _remote;
}
bool isCancelled() const override {
@@ -89,7 +89,7 @@ public:
private:
::grpc::ServerContext* _ctx;
MetadataView _clientMetadata;
- HostAndPort _hostAndPort;
+ HostAndPort _remote;
};
} // namespace mongo::transport::grpc
diff --git a/src/mongo/transport/grpc/grpc_session.h b/src/mongo/transport/grpc/grpc_session.h
index 43d4d255442..7dbb7ffc7a0 100644
--- a/src/mongo/transport/grpc/grpc_session.h
+++ b/src/mongo/transport/grpc/grpc_session.h
@@ -66,28 +66,41 @@ public:
virtual boost::optional<UUID> clientId() const = 0;
/**
- * Terminates the underlying gRPC stream.
+ * Cancels the underlying gRPC stream and updates the termination status of the session.
+ * If this session is already terminated, this has no effect.
+ *
+ * It is an error to terminate a session with an OK status. Instead, provide a status that
+ * explains the reason for the cancellation or use end() if the intention is to mark the session
+ * as successfully terminated without cancelling the underlying stream.
*/
- virtual void terminate(Status status) {
- auto ts = _terminationStatus.synchronize();
- if (MONGO_unlikely(ts->has_value()))
- return;
- ts->emplace(std::move(status));
+ void terminate(Status status) {
+ tassert(7401590,
+ "gRPC sessions should only be manually terminated with non-OK statuses",
+ !status.isOK());
+
+ // Need to update terminationStatus before cancelling so that when the RPC caller/handler is
+ // interrupted, the it will be guaranteed to have access to the reason for cancellation.
+ if (_setTerminationStatus(std::move(status))) {
+ _tryCancel();
+ }
}
/**
* Returns the termination status (always set at termination). Remains unset until termination.
*/
boost::optional<Status> terminationStatus() const {
- return **_terminationStatus;
+ return *_terminationStatus;
}
TransportLayer* getTransportLayer() const final {
return _tl;
}
+ /**
+ * Marks the session as having terminated successfully.
+ */
void end() final {
- terminate(Status::OK());
+ _setTerminationStatus(Status::OK());
}
/**
@@ -145,6 +158,20 @@ public:
}
private:
+ /**
+ * Sets the termination status if it hasn't been set already.
+ * Returns whether the termination status was updated or not.
+ */
+ bool _setTerminationStatus(Status status) {
+ auto ts = _terminationStatus.synchronize();
+ if (MONGO_unlikely(ts->has_value()))
+ return false;
+ ts->emplace(std::move(status));
+ return true;
+ }
+
+ virtual void _tryCancel() = 0;
+
// TODO SERVER-74020: replace this with `GRPCTransportLayer`.
TransportLayer* const _tl;
@@ -168,7 +195,7 @@ public:
_ctx(ctx),
_stream(stream),
_clientId(std::move(clientId)),
- _remote(ctx->getHostAndPort()) {
+ _remote(ctx->getRemote()) {
LOGV2_DEBUG(7401101,
2,
"Constructed a new gRPC ingress session",
@@ -219,17 +246,6 @@ public:
return Status(ErrorCodes::StreamTerminated, "Unable to write to ingress session");
}
- void terminate(Status status) override {
- auto shouldCancel = !status.isOK() && !terminationStatus();
- // Need to invoke GRPCSession::terminate before cancelling the server context so that when
- // an RPC handler is interrupted, it will be guaranteed to have access to its termination
- // status.
- GRPCSession::terminate(std::move(status));
- if (MONGO_unlikely(shouldCancel)) {
- _ctx->tryCancel();
- }
- }
-
const HostAndPort& remote() const override {
return _remote;
}
@@ -239,6 +255,10 @@ public:
}
private:
+ void _tryCancel() override {
+ _ctx->tryCancel();
+ }
+
// _ctx and _stream are only valid while the RPC handler is still running. They should not be
// accessed after the stream has been terminated.
ServerContext* const _ctx;
diff --git a/src/mongo/transport/grpc/grpc_session_test.cpp b/src/mongo/transport/grpc/grpc_session_test.cpp
index 24baf977f1e..de4f7e2897e 100644
--- a/src/mongo/transport/grpc/grpc_session_test.cpp
+++ b/src/mongo/transport/grpc/grpc_session_test.cpp
@@ -61,8 +61,11 @@ public:
void setUp() override {
SessionTest::setUp();
- _fixture = std::make_unique<MockStreamTestFixtures>(
- HostAndPort{kRemote}, kStreamTimeout, _clientMetadata);
+ // The MockStreamTestFixtures created here doesn't contain any references to the channel or
+ // server, so it's okay to let stubFixture go out of scope.
+ MockStubTestFixtures stubFixture;
+ _fixture = stubFixture.makeStreamTestFixtures(
+ getServiceContext()->getFastClockSource()->now() + kStreamTimeout, _clientMetadata);
}
void tearDown() override {
@@ -109,7 +112,7 @@ TEST_F(IngressSessionTest, GetClientId) {
TEST_F(IngressSessionTest, GetRemote) {
auto session = makeSession();
- ASSERT_EQ(session->remote().toString(), kRemote);
+ ASSERT_EQ(session->remote().toString(), MockStubTestFixtures::kClientAddress);
}
TEST_F(IngressSessionTest, IsConnected) {
@@ -119,9 +122,9 @@ TEST_F(IngressSessionTest, IsConnected) {
ASSERT_FALSE(session->isConnected());
}
-TEST_F(IngressSessionTest, Terminate) {
+TEST_F(IngressSessionTest, End) {
auto session = makeSession();
- session->terminate(Status::OK());
+ session->end();
ASSERT_FALSE(session->isConnected());
ASSERT_TRUE(session->terminationStatus());
ASSERT_OK(*session->terminationStatus());
@@ -134,13 +137,14 @@ TEST_F(IngressSessionTest, TerminateWithError) {
ASSERT_FALSE(session->isConnected());
ASSERT_TRUE(session->terminationStatus());
ASSERT_EQ(*session->terminationStatus(), error);
+ ASSERT_TRUE(fixture()->serverCtx->isCancelled());
}
TEST_F(IngressSessionTest, TerminateRetainsStatus) {
const Status error(ErrorCodes::InternalError, "Some Error");
auto session = makeSession();
session->terminate(error);
- session->terminate(Status::OK());
+ session->end();
ASSERT_EQ(*session->terminationStatus(), error);
}
diff --git a/src/mongo/transport/grpc/mock_client_context.h b/src/mongo/transport/grpc/mock_client_context.h
index 04311414e35..5dad9078403 100644
--- a/src/mongo/transport/grpc/mock_client_context.h
+++ b/src/mongo/transport/grpc/mock_client_context.h
@@ -29,29 +29,54 @@
#pragma once
-#include <boost/optional.hpp>
-#include <map>
-#include <string>
+#include "mongo/transport/grpc/client_context.h"
#include "mongo/transport/grpc/mock_client_stream.h"
namespace mongo::transport::grpc {
-// TODO: SERVER-74015 introduce a ClientContext interface that covers the whole API surface of
-// gRPC's ClientContext type, and implement that interface here.
-class MockClientContext {
+class MockClientContext : public ClientContext {
public:
- explicit MockClientContext(MockClientStream* stream) : _stream{stream} {}
- ~MockClientContext() = default;
+ MockClientContext() : _deadline{Date_t::max()}, _stream{nullptr} {}
- boost::optional<const MetadataContainer&> getServerInitialMetadata() const {
+ void addMetadataEntry(const std::string& key, const std::string& value) override {
+ invariant(!_stream);
+ _metadata.insert({key, value});
+ };
+
+ boost::optional<const MetadataContainer&> getServerInitialMetadata() const override {
+ invariant(_stream);
if (!_stream->_serverInitialMetadata.isReady()) {
return boost::none;
}
return _stream->_serverInitialMetadata.get();
}
+ Date_t getDeadline() const override {
+ return _deadline;
+ }
+
+ void setDeadline(Date_t deadline) override {
+ invariant(!_stream);
+ _deadline = deadline;
+ }
+
+ HostAndPort getRemote() const override {
+ invariant(_stream);
+ return _stream->_remote;
+ }
+
+ void tryCancel() override {
+ invariant(_stream);
+ _stream->_cancel();
+ }
+
private:
+ friend class MockStub;
+ friend struct MockStreamTestFixtures;
+
+ Date_t _deadline;
+ MetadataContainer _metadata;
MockClientStream* _stream;
};
diff --git a/src/mongo/transport/grpc/mock_client_stream.cpp b/src/mongo/transport/grpc/mock_client_stream.cpp
index 0952c8be49c..dfc00f8c51a 100644
--- a/src/mongo/transport/grpc/mock_client_stream.cpp
+++ b/src/mongo/transport/grpc/mock_client_stream.cpp
@@ -29,26 +29,64 @@
#include "mongo/transport/grpc/mock_client_stream.h"
-#include "mongo/db/service_context.h"
#include "mongo/transport/grpc/mock_util.h"
+#include "mongo/util/interruptible.h"
namespace mongo::transport::grpc {
-MockClientStream::MockClientStream(HostAndPort hostAndPort,
- Milliseconds timeout,
+MockClientStream::MockClientStream(HostAndPort remote,
Future<MetadataContainer>&& initialMetadataFuture,
+ Future<::grpc::Status>&& rpcReturnStatus,
+ std::shared_ptr<MockCancellationState> rpcCancellationState,
BidirectionalPipe::End&& pipe)
- : _deadline{getGlobalServiceContext()->getFastClockSource()->now() + timeout},
+ : _remote{std::move(remote)},
_serverInitialMetadata{std::move(initialMetadataFuture)},
+ _rpcReturnStatus{std::move(rpcReturnStatus)},
+ _rpcCancellationState(std::move(rpcCancellationState)),
_pipe{std::move(pipe)} {}
boost::optional<SharedBuffer> MockClientStream::read() {
+ // Even if the server side handler of this stream has set a final status for the RPC (i.e.
+ // _rpcReturnStatus is ready), there may still be unread messages in the queue that the server
+ // sent before setting that status, so only return early here if the RPC was cancelled.
+ // Otherwise, try to read whatever messages are in the queue.
+ if (_rpcCancellationState->isCancelled()) {
+ return boost::none;
+ }
+
return runWithDeadline<boost::optional<SharedBuffer>>(
- _deadline, [&](Interruptible* i) { return _pipe.read(i); });
+ _rpcCancellationState->getDeadline(), [&](Interruptible* i) { return _pipe.read(i); });
}
bool MockClientStream::write(ConstSharedBuffer msg) {
- return runWithDeadline<bool>(_deadline, [&](Interruptible* i) { return _pipe.write(msg, i); });
+ if (_rpcCancellationState->isCancelled() || _rpcReturnStatus.isReady()) {
+ return false;
+ }
+
+ return runWithDeadline<bool>(_rpcCancellationState->getDeadline(),
+ [&](Interruptible* i) { return _pipe.write(msg, i); });
+}
+
+::grpc::Status MockClientStream::finish() {
+ // We use a busy wait here because there is no easy way to wait until all the messages in the
+ // pipe have been read.
+ while (!_pipe.isConsumed() && !_rpcCancellationState->isDeadlineExceeded()) {
+ sleepFor(Milliseconds(1));
+ }
+
+ invariant(_rpcReturnStatus.isReady() || _rpcCancellationState->isCancelled());
+
+ if (auto cancellationStatus = _rpcCancellationState->getCancellationStatus();
+ cancellationStatus.has_value()) {
+ return *cancellationStatus;
+ }
+
+ return _rpcReturnStatus.get();
+}
+
+void MockClientStream::_cancel() {
+ _rpcCancellationState->cancel(::grpc::Status::CANCELLED);
+ _pipe.close();
}
} // namespace mongo::transport::grpc
diff --git a/src/mongo/transport/grpc/mock_client_stream.h b/src/mongo/transport/grpc/mock_client_stream.h
index 694581c393d..f07af95f8d1 100644
--- a/src/mongo/transport/grpc/mock_client_stream.h
+++ b/src/mongo/transport/grpc/mock_client_stream.h
@@ -29,37 +29,52 @@
#pragma once
-#include <map>
-#include <string>
+#include "mongo/transport/grpc/client_stream.h"
#include "mongo/transport/grpc/bidirectional_pipe.h"
#include "mongo/transport/grpc/metadata.h"
-#include "mongo/transport/grpc/server_stream.h"
+#include "mongo/transport/grpc/mock_util.h"
#include "mongo/util/future.h"
#include "mongo/util/net/hostandport.h"
+#include "mongo/util/synchronized_value.h"
#include "mongo/util/time_support.h"
namespace mongo::transport::grpc {
-// TODO: SERVER-74015 introduce a ClientStream interface that covers the whole API surface of
-// gRPC's ClientReaderWriter type, and implement that interface here.
-class MockClientStream {
+class MockClientStream : public ClientStream {
public:
- ~MockClientStream() = default;
+ MockClientStream(HostAndPort remote,
+ Future<MetadataContainer>&& serverInitialMetadata,
+ Future<::grpc::Status>&& rpcReturnStatus,
+ std::shared_ptr<MockCancellationState> rpcCancellationState,
+ BidirectionalPipe::End&& pipe);
- boost::optional<SharedBuffer> read();
- bool write(ConstSharedBuffer msg);
+ boost::optional<SharedBuffer> read() override;
- explicit MockClientStream(HostAndPort hostAndPort,
- Milliseconds timeout,
- Future<MetadataContainer>&& serverInitialMetadata,
- BidirectionalPipe::End&& pipe);
+ bool write(ConstSharedBuffer msg) override;
+
+ ::grpc::Status finish() override;
private:
friend class MockClientContext;
- Date_t _deadline;
+ void _cancel();
+
+ HostAndPort _remote;
+ MetadataContainer _clientMetadata;
Future<MetadataContainer> _serverInitialMetadata;
+
+ /**
+ * The mocked equivalent of a status returned from a server-side RPC handler.
+ */
+ Future<::grpc::Status> _rpcReturnStatus;
+
+ /**
+ * State used to mock RPC cancellation, including explicit cancellation (client or server side)
+ * or network errors.
+ */
+ std::shared_ptr<MockCancellationState> _rpcCancellationState;
+
BidirectionalPipe::End _pipe;
};
} // namespace mongo::transport::grpc
diff --git a/src/mongo/transport/grpc/mock_server_context.cpp b/src/mongo/transport/grpc/mock_server_context.cpp
index a862fe17e49..a4b91ce5b01 100644
--- a/src/mongo/transport/grpc/mock_server_context.cpp
+++ b/src/mongo/transport/grpc/mock_server_context.cpp
@@ -40,11 +40,11 @@ const MetadataView& MockServerContext::getClientMetadata() const {
}
Date_t MockServerContext::getDeadline() const {
- return _stream->_deadline;
+ return _stream->_rpcCancellationState->getDeadline();
}
void MockServerContext::tryCancel() {
- _stream->close();
+ _stream->cancel(::grpc::Status::CANCELLED);
}
bool MockServerContext::isCancelled() const {
diff --git a/src/mongo/transport/grpc/mock_server_context.h b/src/mongo/transport/grpc/mock_server_context.h
index bd1fa05d47b..0c867f43e18 100644
--- a/src/mongo/transport/grpc/mock_server_context.h
+++ b/src/mongo/transport/grpc/mock_server_context.h
@@ -43,11 +43,8 @@ public:
const MetadataView& getClientMetadata() const override;
Date_t getDeadline() const override;
bool isCancelled() const override;
- HostAndPort getHostAndPort() const override {
- return _stream->_hostAndPort;
- }
- CancellationToken getCancellationToken() {
- return _stream->_cancellationSource.token();
+ HostAndPort getRemote() const override {
+ return _stream->_remote;
}
void tryCancel() override;
diff --git a/src/mongo/transport/grpc/mock_server_stream.cpp b/src/mongo/transport/grpc/mock_server_stream.cpp
index 8c8ca1df559..8fa89bdba90 100644
--- a/src/mongo/transport/grpc/mock_server_stream.cpp
+++ b/src/mongo/transport/grpc/mock_server_stream.cpp
@@ -31,42 +31,68 @@
#include "mongo/db/service_context.h"
#include "mongo/transport/grpc/mock_util.h"
+#include "mongo/util/assert_util.h"
#include "mongo/util/interruptible.h"
namespace mongo::transport::grpc {
-MockServerStream::MockServerStream(HostAndPort hostAndPort,
- Milliseconds timeout,
+MockServerStream::MockServerStream(HostAndPort remote,
Promise<MetadataContainer>&& initialMetadataPromise,
+ Promise<::grpc::Status>&& rpcTerminationStatusPromise,
+ std::shared_ptr<MockCancellationState> rpcCancellationState,
BidirectionalPipe::End&& serverPipeEnd,
MetadataView clientMetadata)
- : _deadline{getGlobalServiceContext()->getFastClockSource()->now() + timeout},
+ : _remote(std::move(remote)),
_initialMetadata(std::move(initialMetadataPromise)),
+ _rpcReturnStatus(std::move(rpcTerminationStatusPromise)),
+ _finalStatusReturned(false),
+ _rpcCancellationState(std::move(rpcCancellationState)),
_pipe{std::move(serverPipeEnd)},
- _clientMetadata{std::move(clientMetadata)},
- _hostAndPort(std::move(hostAndPort)) {}
+ _clientMetadata{std::move(clientMetadata)} {}
boost::optional<SharedBuffer> MockServerStream::read() {
+ invariant(!*_finalStatusReturned);
+
return runWithDeadline<boost::optional<SharedBuffer>>(
- _deadline, [&](Interruptible* i) { return _pipe.read(i); });
+ _rpcCancellationState->getDeadline(), [&](Interruptible* i) { return _pipe.read(i); });
}
bool MockServerStream::isCancelled() const {
- return _cancellationSource.token().isCanceled() ||
- getGlobalServiceContext()->getFastClockSource()->now() > _deadline;
+ return _rpcCancellationState->isCancelled();
}
bool MockServerStream::write(ConstSharedBuffer msg) {
- if (_cancellationSource.token().isCanceled() ||
- getGlobalServiceContext()->getFastClockSource()->now() > _deadline) {
+ invariant(!*_finalStatusReturned);
+ if (isCancelled()) {
return false;
}
+
_initialMetadata.trySend();
- return runWithDeadline<bool>(_deadline, [&](Interruptible* i) { return _pipe.write(msg, i); });
+ return runWithDeadline<bool>(_rpcCancellationState->getDeadline(),
+ [&](Interruptible* i) { return _pipe.write(msg, i); });
}
-void MockServerStream::close() {
- _cancellationSource.cancel();
+void MockServerStream::sendReturnStatus(::grpc::Status status) {
+ {
+ auto finalStatusReturned = _finalStatusReturned.synchronize();
+ invariant(!*finalStatusReturned);
+ *finalStatusReturned = true;
+ // Client side ignores the mocked return value in the event of a cancellation, so don't need
+ // to check if stream has been cancelled before sending the status.
+ }
+ _rpcReturnStatus.emplaceValue(std::move(status));
+ _pipe.close();
+}
+
+void MockServerStream::cancel(::grpc::Status status) {
+ // Only mark the RPC as cancelled if a status hasn't already been returned to client.
+ auto finalStatusReturned = _finalStatusReturned.synchronize();
+ if (*finalStatusReturned) {
+ return;
+ }
+ // Need to update the cancellation state before closing the pipe so that when a stream
+ // read/write is interrupted, the cancellation state will already be up to date.
+ _rpcCancellationState->cancel(std::move(status));
_pipe.close();
}
diff --git a/src/mongo/transport/grpc/mock_server_stream.h b/src/mongo/transport/grpc/mock_server_stream.h
index eb7177f46cc..88fbe51d5eb 100644
--- a/src/mongo/transport/grpc/mock_server_stream.h
+++ b/src/mongo/transport/grpc/mock_server_stream.h
@@ -32,13 +32,14 @@
#include <map>
#include <string>
+#include <grpcpp/grpcpp.h>
+
#include "mongo/transport/grpc/bidirectional_pipe.h"
#include "mongo/transport/grpc/metadata.h"
+#include "mongo/transport/grpc/mock_util.h"
#include "mongo/transport/grpc/server_stream.h"
-#include "mongo/util/cancellation.h"
#include "mongo/util/future.h"
#include "mongo/util/net/hostandport.h"
-#include "mongo/util/time_support.h"
namespace mongo::transport::grpc {
@@ -47,16 +48,19 @@ public:
~MockServerStream() = default;
boost::optional<SharedBuffer> read() override;
+
bool write(ConstSharedBuffer msg) override;
- explicit MockServerStream(HostAndPort hostAndPort,
- Milliseconds timeout,
+ explicit MockServerStream(HostAndPort remote,
Promise<MetadataContainer>&& initialMetadataPromise,
+ Promise<::grpc::Status>&& rpcTerminationStatusPromise,
+ std::shared_ptr<MockCancellationState> rpcCancellationState,
BidirectionalPipe::End&& serverPipeEnd,
MetadataView clientMetadata);
private:
friend class MockServerContext;
+ friend class MockRPC;
class InitialMetadata {
public:
@@ -86,13 +90,51 @@ private:
};
bool isCancelled() const;
- void close();
- CancellationSource _cancellationSource;
- Date_t _deadline;
+ /**
+ * Cancel the RPC associated with this stream. This is used for mocking situations in
+ * which an RPC handler was never able to return a final status to the client (e.g. manual
+ * cancellation or a network interruption).
+ *
+ * This method has no effect if the stream is already terminated.
+ */
+ void cancel(::grpc::Status status);
+
+ /**
+ * Closes the stream and sends the final return status of the RPC to the client. This is the
+ * mocked equivalent of an RPC handler returning a status.
+ *
+ * This does not mark the stream as cancelled.
+ *
+ * This method must only be called once, and this stream must not be used after this method has
+ * been called.
+ */
+ void sendReturnStatus(::grpc::Status status);
+
+ HostAndPort _remote;
InitialMetadata _initialMetadata;
+
+ /**
+ * _rpcReturnStatus is set in sendReturnStatus(), and it is used to mock returning a status from
+ * an RPC handler. sendReturnStatus itself is called via MockRPC::sendReturnStatus().
+ */
+ Promise<::grpc::Status> _rpcReturnStatus;
+
+ /**
+ * _finalStatusReturned is also set in sendReturnStatus(), and it is used to denote that a
+ * status has been returned and the stream should no longer be used.
+ */
+ synchronized_value<bool> _finalStatusReturned;
+
+ /**
+ * _rpcCancellationState is set via cancel(), which is called by either
+ * MockServerContext::tryCancel() or MockRPC::cancel(). It is used to mock situations in which a
+ * server RPC handler is unable to return a status to the client (e.g. explicit cancellation or
+ * a network interruption).
+ */
+ std::shared_ptr<MockCancellationState> _rpcCancellationState;
+
BidirectionalPipe::End _pipe;
MetadataView _clientMetadata;
- HostAndPort _hostAndPort;
};
} // namespace mongo::transport::grpc
diff --git a/src/mongo/transport/grpc/mock_server_stream_test.cpp b/src/mongo/transport/grpc/mock_server_stream_test.cpp
index 1452f9f3f4a..1f01a7b7ff0 100644
--- a/src/mongo/transport/grpc/mock_server_stream_test.cpp
+++ b/src/mongo/transport/grpc/mock_server_stream_test.cpp
@@ -33,6 +33,8 @@
#include <utility>
#include <vector>
+#include <grpcpp/support/status.h>
+
#include "mongo/db/concurrency/locker_noop_service_context_test_fixture.h"
#include "mongo/platform/mutex.h"
#include "mongo/rpc/message.h"
@@ -40,6 +42,7 @@
#include "mongo/transport/grpc/metadata.h"
#include "mongo/transport/grpc/mock_server_context.h"
#include "mongo/transport/grpc/mock_server_stream.h"
+#include "mongo/transport/grpc/mock_stub.h"
#include "mongo/transport/grpc/test_fixtures.h"
#include "mongo/unittest/assert.h"
#include "mongo/unittest/death_test.h"
@@ -47,6 +50,7 @@
#include "mongo/unittest/unittest.h"
#include "mongo/util/concurrency/notification.h"
#include "mongo/util/duration.h"
+#include "mongo/util/future.h"
#include "mongo/util/net/hostandport.h"
#include "mongo/util/scopeguard.h"
#include "mongo/util/system_clock_source.h"
@@ -58,32 +62,29 @@ template <class Base>
class MockServerStreamBase : public Base {
public:
static constexpr Milliseconds kTimeout = Milliseconds(100);
- static constexpr const char* kRemote = "abc:123";
virtual void setUp() override {
Base::setUp();
- _fixtures = std::make_unique<MockStreamTestFixtures>(
- HostAndPort{kRemote}, kTimeout, _clientMetadata);
- }
- MockStreamTestFixtures& getFixtures() {
- return *_fixtures;
+ MockStubTestFixtures fixtures;
+ _fixtures = fixtures.makeStreamTestFixtures(
+ Base::getServiceContext()->getFastClockSource()->now() + kTimeout, _clientMetadata);
}
MockServerStream& getServerStream() {
- return *getFixtures().serverStream;
+ return *_fixtures->serverStream;
}
MockServerContext& getServerContext() {
- return *getFixtures().serverCtx;
+ return *_fixtures->serverCtx;
}
- MockClientStream& getClientStream() {
- return *getFixtures().clientStream;
+ ClientStream& getClientStream() {
+ return *_fixtures->clientStream;
}
- MockClientContext& getClientContext() {
- return *getFixtures().clientCtx;
+ ClientContext& getClientContext() {
+ return *_fixtures->clientCtx;
}
const Message& getClientFirstMessage() const {
@@ -220,6 +221,7 @@ TEST_F(MockServerStreamTestWithMockedClockSource, DeadlineIsEnforced) {
ASSERT_FALSE(getServerStream().write(makeUniqueMessage().sharedBuffer()));
ASSERT_FALSE(getClientContext().getServerInitialMetadata());
ASSERT_FALSE(getClientStream().read());
+ ASSERT_EQ(getClientStream().finish().error_code(), ::grpc::StatusCode::DEADLINE_EXCEEDED);
}
TEST_F(MockServerStreamTest, TryCancelSubsequentServerRead) {
@@ -265,4 +267,44 @@ TEST_F(MockServerStreamTest, ClientMetadataIsAccessible) {
ASSERT_EQ(getServerContext().getClientMetadata(), getClientMetadata());
}
+TEST_F(MockServerStreamTest, ClientSideCancellation) {
+ ASSERT_TRUE(getClientStream().write(makeUniqueMessage().sharedBuffer()));
+ ASSERT_TRUE(getServerStream().read());
+
+ getClientContext().tryCancel();
+
+ ASSERT_FALSE(getClientStream().read());
+ ASSERT_FALSE(getClientStream().write(makeUniqueMessage().sharedBuffer()));
+ ASSERT_FALSE(getServerStream().read());
+ ASSERT_FALSE(getServerStream().write(makeUniqueMessage().sharedBuffer()));
+ ASSERT_TRUE(getServerContext().isCancelled());
+
+ ASSERT_EQ(getClientStream().finish().error_code(), ::grpc::StatusCode::CANCELLED);
+}
+
+TEST_F(MockServerStreamTest, CancellationInterruptsFinish) {
+ auto pf = makePromiseFuture<::grpc::Status>();
+ auto finishThread =
+ stdx::thread([&] { pf.promise.setWith([&] { return getClientStream().finish(); }); });
+ ON_BLOCK_EXIT([&] { finishThread.join(); });
+
+ // finish() won't return until server end hangs up too.
+ ASSERT_FALSE(pf.future.isReady());
+
+ getServerContext().tryCancel();
+ ASSERT_EQ(pf.future.get().error_code(), ::grpc::StatusCode::CANCELLED);
+}
+
+TEST_F(MockServerStreamTestWithMockedClockSource, DeadlineExceededInterruptsFinish) {
+ auto pf = makePromiseFuture<::grpc::Status>();
+ auto finishThread =
+ stdx::thread([&] { pf.promise.setWith([&] { return getClientStream().finish(); }); });
+ ON_BLOCK_EXIT([&] { finishThread.join(); });
+
+ // finish() won't return until server end hangs up too.
+ ASSERT_FALSE(pf.future.isReady());
+
+ clockSource().advance(kTimeout * 2);
+ ASSERT_EQ(pf.future.get().error_code(), ::grpc::StatusCode::DEADLINE_EXCEEDED);
+}
} // namespace mongo::transport::grpc
diff --git a/src/mongo/transport/grpc/mock_stub.h b/src/mongo/transport/grpc/mock_stub.h
new file mode 100644
index 00000000000..5af39aeda93
--- /dev/null
+++ b/src/mongo/transport/grpc/mock_stub.h
@@ -0,0 +1,217 @@
+/**
+ * Copyright (C) 2023-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#pragma once
+
+#include "mongo/base/string_data.h"
+#include "mongo/transport/grpc/mock_client_context.h"
+#include "mongo/transport/grpc/mock_server_context.h"
+#include "mongo/transport/grpc/mock_server_stream.h"
+#include "mongo/transport/grpc/mock_util.h"
+#include "mongo/transport/grpc/service.h"
+#include "mongo/unittest/thread_assertion_monitor.h"
+#include "mongo/util/assert_util.h"
+#include "mongo/util/concurrency/notification.h"
+#include "mongo/util/producer_consumer_queue.h"
+
+namespace mongo::transport::grpc {
+
+class MockRPC {
+public:
+ /**
+ * Close the stream and send the final return status of the RPC to the client. This is the
+ * mocked equivalent of returning a status from an RPC handler.
+ *
+ * The RPC's stream must not be used after calling this method.
+ * This method must only be called once.
+ */
+ void sendReturnStatus(::grpc::Status status) {
+ serverStream->sendReturnStatus(std::move(status));
+ }
+
+ /**
+ * Cancel the RPC with the provided status. This is used for mocking situations in which an RPC
+ * handler was never able to return a final status to the client (e.g. network interruption).
+ *
+ * For mocking an explicit server-side cancellation, use serverCtx->tryCancel().
+ * This method has no effect if the RPC has already been terminated, either by returning a
+ * status or prior cancellation.
+ */
+ void cancel(::grpc::Status status) {
+ serverStream->cancel(std::move(status));
+ }
+
+ StringData methodName;
+ std::unique_ptr<MockServerStream> serverStream;
+ std::unique_ptr<MockServerContext> serverCtx;
+};
+
+using MockRPCQueue = MultiProducerMultiConsumerQueue<std::pair<Promise<void>, MockRPC>>;
+
+class MockServer {
+public:
+ explicit MockServer(MockRPCQueue::Consumer queue) : _queue(std::move(queue)) {}
+
+ boost::optional<MockRPC> acceptRPC() {
+ try {
+ auto entry = _queue.pop();
+ entry.first.emplaceValue();
+ return std::move(entry.second);
+ } catch (const DBException& e) {
+ if (e.code() == ErrorCodes::ProducerConsumerQueueEndClosed ||
+ e.code() == ErrorCodes::ProducerConsumerQueueConsumed) {
+ return boost::none;
+ }
+ throw;
+ }
+ }
+
+ /**
+ * Starts up a thread that listens for incoming RPCs and then returns immediately.
+ * The listener thread will spawn a new thread for each RPC it receives and pass it to the
+ * provided handler.
+ *
+ * The provided handler is expected to throw assertion exceptions, hence the use of
+ * ThreadAssertionMonitor to spawn threads here.
+ */
+ void start(unittest::ThreadAssertionMonitor& monitor,
+ std::function<::grpc::Status(MockRPC&)> handler) {
+ _listenerThread = monitor.spawn([&, handler = std::move(handler)] {
+ std::vector<stdx::thread> rpcHandlers;
+ while (auto rpc = acceptRPC()) {
+ rpcHandlers.push_back(monitor.spawn([rpc = std::move(*rpc), handler]() mutable {
+ try {
+ auto status = handler(rpc);
+ rpc.sendReturnStatus(std::move(status));
+ } catch (DBException& e) {
+ rpc.sendReturnStatus(
+ ::grpc::Status(::grpc::StatusCode::UNKNOWN, e.toString()));
+ }
+ }));
+ }
+
+ for (auto& thread : rpcHandlers) {
+ thread.join();
+ }
+ });
+ }
+
+ /**
+ * Close the mocked channel and then block until all RPC handler threads (if any) have exited.
+ */
+ void shutdown() {
+ _queue.close();
+ if (_listenerThread) {
+ _listenerThread->join();
+ }
+ }
+
+private:
+ boost::optional<stdx::thread> _listenerThread;
+ MockRPCQueue::Consumer _queue;
+};
+
+class MockChannel {
+public:
+ explicit MockChannel(HostAndPort local, HostAndPort remote, MockRPCQueue::Producer queue)
+ : _local(std::move(local)), _remote{std::move(remote)}, _rpcQueue{std::move(queue)} {};
+
+ void sendRPC(MockRPC&& rpc) {
+ auto pf = makePromiseFuture<void>();
+ _rpcQueue.push({std::move(pf.promise), std::move(rpc)});
+ pf.future.get();
+ }
+
+ const HostAndPort& getLocal() const {
+ return _local;
+ }
+
+ const HostAndPort& getRemote() const {
+ return _remote;
+ }
+
+private:
+ HostAndPort _local;
+ HostAndPort _remote;
+ MockRPCQueue::Producer _rpcQueue;
+};
+
+class MockStub {
+public:
+ explicit MockStub(std::shared_ptr<MockChannel> channel) : _channel{std::move(channel)} {}
+
+ ~MockStub() {}
+
+ std::shared_ptr<MockClientStream> unauthenticatedCommandStream(MockClientContext* ctx) {
+ return _makeStream(CommandService::kUnauthenticatedCommandStreamMethodName, ctx);
+ }
+
+ std::shared_ptr<MockClientStream> authenticatedCommandStream(MockClientContext* ctx) {
+ return _makeStream(CommandService::kAuthenticatedCommandStreamMethodName, ctx);
+ }
+
+private:
+ std::shared_ptr<MockClientStream> _makeStream(const StringData methodName,
+ MockClientContext* ctx) {
+ MetadataView clientMetadata;
+ for (auto& kvp : ctx->_metadata) {
+ clientMetadata.insert(kvp);
+ }
+
+ BidirectionalPipe pipe;
+ auto metadataPF = makePromiseFuture<MetadataContainer>();
+ auto terminationStatusPF = makePromiseFuture<::grpc::Status>();
+ auto cancellationState = std::make_shared<MockCancellationState>(ctx->getDeadline());
+
+ MockRPC rpc;
+ rpc.methodName = methodName;
+ rpc.serverStream =
+ std::make_unique<MockServerStream>(_channel->getLocal(),
+ std::move(metadataPF.promise),
+ std::move(terminationStatusPF.promise),
+ cancellationState,
+ std::move(*pipe.left),
+ clientMetadata);
+ rpc.serverCtx = std::make_unique<MockServerContext>(rpc.serverStream.get());
+ auto clientStream =
+ std::make_shared<MockClientStream>(_channel->getRemote(),
+ std::move(metadataPF.future),
+ std::move(terminationStatusPF.future),
+ cancellationState,
+ std::move(*pipe.right));
+
+ ctx->_stream = clientStream.get();
+ _channel->sendRPC(std::move(rpc));
+ return clientStream;
+ }
+
+ std::shared_ptr<MockChannel> _channel;
+};
+
+} // namespace mongo::transport::grpc
diff --git a/src/mongo/transport/grpc/mock_stub_test.cpp b/src/mongo/transport/grpc/mock_stub_test.cpp
new file mode 100644
index 00000000000..1a24f8e09a1
--- /dev/null
+++ b/src/mongo/transport/grpc/mock_stub_test.cpp
@@ -0,0 +1,204 @@
+/**
+ * Copyright (C) 2023-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include <memory>
+#include <vector>
+
+#include "mongo/db/concurrency/locker_noop_service_context_test_fixture.h"
+#include "mongo/rpc/message.h"
+#include "mongo/stdx/thread.h"
+#include "mongo/transport/grpc/mock_client_context.h"
+#include "mongo/transport/grpc/mock_stub.h"
+#include "mongo/transport/grpc/test_fixtures.h"
+#include "mongo/unittest/assert.h"
+#include "mongo/unittest/thread_assertion_monitor.h"
+#include "mongo/unittest/unittest.h"
+#include "mongo/util/net/hostandport.h"
+#include "mongo/util/producer_consumer_queue.h"
+#include "mongo/util/scopeguard.h"
+
+namespace mongo::transport::grpc {
+
+class MockStubTest : public LockerNoopServiceContextTest {
+public:
+ void setUp() override {
+ _fixtures = std::make_unique<MockStubTestFixtures>();
+ }
+
+ MockStub makeStub() {
+ return _fixtures->makeStub();
+ }
+
+ MockServer& getServer() {
+ return _fixtures->getServer();
+ }
+
+ void runEchoTest(std::function<std::shared_ptr<ClientStream>(MockClientContext&)> makeStream) {
+ unittest::threadAssertionMonitoredTest([&](unittest::ThreadAssertionMonitor& monitor) {
+ getServer().start(monitor, [](auto& rpc) {
+ ASSERT_EQ(rpc.serverCtx->getRemote().toString(),
+ MockStubTestFixtures::kClientAddress);
+ auto msg = rpc.serverStream->read();
+ ASSERT_TRUE(msg);
+ ASSERT_TRUE(rpc.serverStream->write(*msg));
+ return ::grpc::Status::OK;
+ });
+
+ std::vector<stdx::thread> clientThreads;
+ for (int i = 0; i < 10; i++) {
+ clientThreads.push_back(monitor.spawn([&]() {
+ auto clientMessage = makeUniqueMessage();
+ MockClientContext ctx;
+ auto stream = makeStream(ctx);
+ ASSERT_TRUE(stream->write(clientMessage.sharedBuffer()));
+
+ auto serverResponse = stream->read();
+ ASSERT_TRUE(serverResponse);
+ ASSERT_EQ_MSG(Message{*serverResponse}, clientMessage);
+ }));
+ }
+
+ for (auto& thread : clientThreads) {
+ thread.join();
+ }
+
+ getServer().shutdown();
+ });
+ }
+
+ std::pair<std::shared_ptr<MockClientStream>, MockRPC> makeRPC(MockClientContext& ctx) {
+ auto clientStreamPf = makePromiseFuture<std::shared_ptr<MockClientStream>>();
+ auto th = stdx::thread([&, promise = std::move(clientStreamPf.promise)]() mutable {
+ promise.setWith([&] {
+ auto stub = makeStub();
+ return stub.unauthenticatedCommandStream(&ctx);
+ });
+ });
+ ON_BLOCK_EXIT([&th] { th.join(); });
+ auto rpc = getServer().acceptRPC();
+ ASSERT_TRUE(rpc);
+ return {clientStreamPf.future.get(), std::move(*rpc)};
+ }
+
+private:
+ std::unique_ptr<MockStubTestFixtures> _fixtures;
+};
+
+TEST_F(MockStubTest, ConcurrentStreamsAuth) {
+ auto stub = makeStub();
+ runEchoTest([&](auto& ctx) { return stub.authenticatedCommandStream(&ctx); });
+}
+
+TEST_F(MockStubTest, ConcurrentStreamsNoAuth) {
+ auto stub = makeStub();
+ runEchoTest([&](auto& ctx) { return stub.unauthenticatedCommandStream(&ctx); });
+}
+
+TEST_F(MockStubTest, ConcurrentStubsAuth) {
+ runEchoTest([&](auto& ctx) { return makeStub().authenticatedCommandStream(&ctx); });
+}
+
+TEST_F(MockStubTest, ConcurrentStubsNoAuth) {
+ runEchoTest([&](auto& ctx) { return makeStub().unauthenticatedCommandStream(&ctx); });
+}
+
+TEST_F(MockStubTest, RPCReturn) {
+ const ::grpc::Status kFinalStatus =
+ ::grpc::Status{::grpc::StatusCode::FAILED_PRECONDITION, "test"};
+ const int kMessageCount = 5;
+
+ MockClientContext ctx;
+ auto [clientStream, rpc] = makeRPC(ctx);
+
+ for (auto i = 0; i < kMessageCount; i++) {
+ ASSERT_TRUE(rpc.serverStream->write(makeUniqueMessage().sharedBuffer()));
+ }
+
+ rpc.sendReturnStatus(kFinalStatus);
+ ASSERT_FALSE(rpc.serverCtx->isCancelled())
+ << "returning a status should not mark stream as cancelled";
+ ASSERT_FALSE(clientStream->write(makeUniqueMessage().sharedBuffer()));
+
+ auto finishPf = makePromiseFuture<::grpc::Status>();
+ auto finishThread = stdx::thread(
+ [&clientStream = *clientStream, promise = std::move(finishPf.promise)]() mutable {
+ promise.setWith([&] { return clientStream.finish(); });
+ });
+ ON_BLOCK_EXIT([&finishThread] { finishThread.join(); });
+ // finish() should not return until all messages have been read.
+ ASSERT_FALSE(finishPf.future.isReady());
+
+ // Ensure messages sent before the RPC was finished can still be read.
+ for (auto i = 0; i < kMessageCount; i++) {
+ ASSERT_TRUE(clientStream->read());
+ }
+ ASSERT_FALSE(clientStream->read());
+
+ // Ensure that finish() returns now that all the messages have been read.
+ auto status = finishPf.future.get();
+ ASSERT_EQ(status.error_code(), kFinalStatus.error_code());
+
+ // Cancelling a finished RPC should have no effect.
+ rpc.serverCtx->tryCancel();
+ ASSERT_FALSE(rpc.serverCtx->isCancelled());
+}
+
+TEST_F(MockStubTest, RPCCancellation) {
+ const ::grpc::Status kCancellationStatus =
+ ::grpc::Status{::grpc::StatusCode::UNAVAILABLE, "mock network error"};
+
+ MockClientContext ctx;
+ auto [clientStream, rpc] = makeRPC(ctx);
+
+ ASSERT_TRUE(clientStream->write(makeUniqueMessage().sharedBuffer()));
+
+ rpc.cancel(kCancellationStatus);
+
+ ASSERT_TRUE(rpc.serverCtx->isCancelled());
+ ASSERT_FALSE(clientStream->write(makeUniqueMessage().sharedBuffer()));
+ ASSERT_FALSE(clientStream->read());
+ ASSERT_EQ(clientStream->finish().error_code(), kCancellationStatus.error_code());
+}
+
+TEST_F(MockStubTest, CannotReturnStatusForCancelledRPC) {
+ MockClientContext ctx;
+ auto [clientStream, rpc] = makeRPC(ctx);
+
+ ASSERT_TRUE(clientStream->write(makeUniqueMessage().sharedBuffer()));
+
+ rpc.cancel(::grpc::Status::CANCELLED);
+ rpc.sendReturnStatus(::grpc::Status::OK);
+
+ ASSERT_TRUE(rpc.serverCtx->isCancelled());
+ ASSERT_FALSE(clientStream->write(makeUniqueMessage().sharedBuffer()));
+ ASSERT_FALSE(clientStream->read());
+ ASSERT_EQ(clientStream->finish().error_code(), ::grpc::StatusCode::CANCELLED);
+}
+
+} // namespace mongo::transport::grpc
diff --git a/src/mongo/transport/grpc/mock_util.h b/src/mongo/transport/grpc/mock_util.h
index 456437cfe1f..2e0fc9c02ed 100644
--- a/src/mongo/transport/grpc/mock_util.h
+++ b/src/mongo/transport/grpc/mock_util.h
@@ -32,14 +32,63 @@
#include <map>
#include <string>
+#include <boost/optional.hpp>
+#include <grpcpp/grpcpp.h>
+#include <grpcpp/support/status.h>
+
#include "mongo/db/operation_context.h"
#include "mongo/db/service_context.h"
#include "mongo/util/interruptible.h"
+#include "mongo/util/synchronized_value.h"
#include "mongo/util/time_support.h"
namespace mongo::transport::grpc {
/**
+ * Class containing the shared cancellation state between a MockServerStream and its corresponding
+ * MockClientStream. This mocks cases in which an RPC is terminated before the server's RPC handler
+ * is able to return a status (e.g. explicit client/server cancellation or a network error).
+ */
+class MockCancellationState {
+public:
+ explicit MockCancellationState(Date_t deadline) : _deadline(deadline) {}
+
+ Date_t getDeadline() const {
+ return _deadline;
+ }
+
+ bool isCancelled() const {
+ return _cancellationStatus->has_value() || isDeadlineExceeded();
+ }
+
+ bool isDeadlineExceeded() const {
+ return getGlobalServiceContext()->getFastClockSource()->now() > _deadline;
+ }
+
+ boost::optional<::grpc::Status> getCancellationStatus() {
+ if (auto status = _cancellationStatus.synchronize(); status->has_value()) {
+ return *status;
+ } else if (isDeadlineExceeded()) {
+ return ::grpc::Status(::grpc::StatusCode::DEADLINE_EXCEEDED, "Deadline exceeded");
+ } else {
+ return boost::none;
+ }
+ }
+
+ void cancel(::grpc::Status status) {
+ auto statusGuard = _cancellationStatus.synchronize();
+ if (statusGuard->has_value()) {
+ return;
+ }
+ *statusGuard = std::move(status);
+ }
+
+private:
+ Date_t _deadline;
+ synchronized_value<boost::optional<::grpc::Status>> _cancellationStatus;
+};
+
+/**
* Performs the provided lambda and returns its return value. If the lambda's execution is
* interrupted due to the deadline being exceeded, this returns a default-constructed T instead.
*/
diff --git a/src/mongo/transport/grpc/server_context.h b/src/mongo/transport/grpc/server_context.h
index a957bc26582..97e467c5260 100644
--- a/src/mongo/transport/grpc/server_context.h
+++ b/src/mongo/transport/grpc/server_context.h
@@ -56,7 +56,7 @@ public:
virtual void addInitialMetadataEntry(const std::string& key, const std::string& value) = 0;
virtual const MetadataView& getClientMetadata() const = 0;
virtual Date_t getDeadline() const = 0;
- virtual HostAndPort getHostAndPort() const = 0;
+ virtual HostAndPort getRemote() const = 0;
/**
* Attempt to cancel the RPC this context is associated with. This may not have an effect if the
@@ -65,6 +65,15 @@ public:
* This is thread-safe.
*/
virtual void tryCancel() = 0;
+
+ /**
+ * Return true if the RPC associated with this ServerContext failed before the RPC handler could
+ * return its final status back to the client (e.g. due to explicit cancellation or a network
+ * issue).
+ *
+ * If the handler was able to return a status successfully, even if that status was
+ * Status::CANCELLED, then this method will return false.
+ */
virtual bool isCancelled() const = 0;
};
diff --git a/src/mongo/transport/grpc/service_test.cpp b/src/mongo/transport/grpc/service_test.cpp
index ef306ee6b46..aa637f84695 100644
--- a/src/mongo/transport/grpc/service_test.cpp
+++ b/src/mongo/transport/grpc/service_test.cpp
@@ -56,6 +56,7 @@
#include "mongo/unittest/thread_assertion_monitor.h"
#include "mongo/unittest/unittest.h"
#include "mongo/util/concurrency/notification.h"
+#include "mongo/util/scopeguard.h"
#define MONGO_LOGV2_DEFAULT_COMPONENT ::mongo::logv2::LogComponent::kTest
@@ -194,6 +195,43 @@ TEST_F(CommandServiceTest, Echo) {
});
}
+TEST_F(CommandServiceTest, SessionTerminate) {
+ const int kMessageCount = 5;
+ auto termination = std::make_unique<Notification<void>>();
+
+ auto handler = [&termination](IngressSession& session) {
+ for (int i = 0; i < kMessageCount; i++) {
+ auto status = session.sinkMessage(makeUniqueMessage());
+ ASSERT_OK(status);
+ }
+ session.terminate(Status(ErrorCodes::StreamTerminated, "dummy error"));
+ termination->set();
+ ASSERT_NOT_OK(session.sinkMessage(makeUniqueMessage()));
+ return ::grpc::Status::CANCELLED;
+ };
+
+ runTestWithBothMethods(handler, [&](auto&, auto& monitor, auto methodCallback) {
+ ::grpc::ClientContext ctx;
+ CommandServiceTestFixtures::addAllClientMetadata(ctx);
+ auto stream = methodCallback(ctx);
+
+ // Messages sent before the RPC was cancelled should be able to be read.
+ termination->get();
+ ON_BLOCK_EXIT([&] {
+ // Reset the termination notification for the next test run.
+ termination = std::make_unique<Notification<void>>();
+ });
+ for (int i = 0; i < kMessageCount; i++) {
+ SharedBuffer b;
+ ASSERT_TRUE(stream->Read(&b));
+ }
+
+ SharedBuffer b;
+ ASSERT_FALSE(stream->Read(&b));
+ ASSERT_EQ(stream->Finish().error_code(), ::grpc::StatusCode::CANCELLED);
+ });
+}
+
TEST_F(CommandServiceTest, NewClientsAreLogged) {
runMetadataLogTest(
[clientId = UUID::gen().toString()](::grpc::ClientContext& ctx) {
diff --git a/src/mongo/transport/grpc/test_fixtures.h b/src/mongo/transport/grpc/test_fixtures.h
index 56735b37997..eaa53fb0733 100644
--- a/src/mongo/transport/grpc/test_fixtures.h
+++ b/src/mongo/transport/grpc/test_fixtures.h
@@ -49,6 +49,7 @@
#include "mongo/transport/grpc/mock_client_stream.h"
#include "mongo/transport/grpc/mock_server_context.h"
#include "mongo/transport/grpc/mock_server_stream.h"
+#include "mongo/transport/grpc/mock_stub.h"
#include "mongo/transport/grpc/server.h"
#include "mongo/transport/grpc/service.h"
#include "mongo/transport/grpc/util.h"
@@ -71,28 +72,59 @@ inline Message makeUniqueMessage() {
}
struct MockStreamTestFixtures {
- MockStreamTestFixtures(HostAndPort hostAndPort,
- Milliseconds timeout,
- MetadataView clientMetadata) {
- BidirectionalPipe pipe;
- auto promiseAndFuture = makePromiseFuture<MetadataContainer>();
-
- serverStream = std::make_unique<MockServerStream>(hostAndPort,
- timeout,
- std::move(promiseAndFuture.promise),
- std::move(*pipe.left),
- clientMetadata);
- serverCtx = std::make_unique<MockServerContext>(serverStream.get());
-
- clientStream = std::make_unique<MockClientStream>(
- hostAndPort, timeout, std::move(promiseAndFuture.future), std::move(*pipe.right));
- clientCtx = std::make_unique<MockClientContext>(clientStream.get());
+ std::shared_ptr<MockClientStream> clientStream;
+ std::shared_ptr<MockClientContext> clientCtx;
+ std::shared_ptr<MockServerStream> serverStream;
+ std::shared_ptr<MockServerContext> serverCtx;
+};
+
+class MockStubTestFixtures {
+public:
+ static constexpr auto kBindAddress = "localhost:1234";
+ static constexpr auto kClientAddress = "abc:5678";
+
+ MockStubTestFixtures() {
+ MockRPCQueue::Pipe pipe;
+
+ _channel = std::make_shared<MockChannel>(
+ HostAndPort(kClientAddress), HostAndPort(kBindAddress), std::move(pipe.producer));
+ _server = std::make_unique<MockServer>(std::move(pipe.consumer));
}
- std::unique_ptr<MockClientStream> clientStream;
- std::unique_ptr<MockClientContext> clientCtx;
- std::unique_ptr<MockServerStream> serverStream;
- std::unique_ptr<MockServerContext> serverCtx;
+ MockStub makeStub() {
+ return MockStub(_channel);
+ }
+
+ std::unique_ptr<MockStreamTestFixtures> makeStreamTestFixtures(
+ Date_t deadline, const MetadataView& clientMetadata) {
+ MockStreamTestFixtures fixtures{
+ nullptr, std::make_shared<MockClientContext>(), nullptr, nullptr};
+
+ auto clientThread = stdx::thread([&] {
+ fixtures.clientCtx->setDeadline(deadline);
+ for (auto& kvp : clientMetadata) {
+ fixtures.clientCtx->addMetadataEntry(kvp.first.toString(), kvp.second.toString());
+ }
+ fixtures.clientStream =
+ makeStub().unauthenticatedCommandStream(fixtures.clientCtx.get());
+ });
+
+ auto rpc = getServer().acceptRPC();
+ ASSERT_TRUE(rpc);
+ fixtures.serverCtx = std::move(rpc->serverCtx);
+ fixtures.serverStream = std::move(rpc->serverStream);
+ clientThread.join();
+
+ return std::make_unique<MockStreamTestFixtures>(std::move(fixtures));
+ }
+
+ MockServer& getServer() {
+ return *_server;
+ }
+
+private:
+ std::unique_ptr<MockServer> _server;
+ std::shared_ptr<MockChannel> _channel;
};
class ServiceContextWithClockSourceMockTest : public LockerNoopServiceContextTest {
diff --git a/src/mongo/util/producer_consumer_queue.h b/src/mongo/util/producer_consumer_queue.h
index 0c5118e52d6..bef44b6f29f 100644
--- a/src/mongo/util/producer_consumer_queue.h
+++ b/src/mongo/util/producer_consumer_queue.h
@@ -439,10 +439,11 @@ public:
size_t waitingConsumers;
size_t waitingProducers;
size_t producerQueueDepth;
+ bool producerEndClosed;
+ bool consumerEndClosed;
// TODO more stats
//
// totalTimeBlocked on either side
- // closed ends
// count of producers and consumers (blocked, or existing if we're a pipe)
};
@@ -629,6 +630,8 @@ public:
stats.waitingConsumers = _consumers;
stats.waitingProducers = _producers;
stats.producerQueueDepth = _producers.queueDepth();
+ stats.producerEndClosed = _producerEndClosed;
+ stats.consumerEndClosed = _consumerEndClosed;
return stats;
}