From 84475222a82ded97409472d83b16afb1c08a9af8 Mon Sep 17 00:00:00 2001 From: Patrick Freed Date: Wed, 17 May 2023 13:23:13 +0000 Subject: SERVER-74015 Introduce mocked gRPC stub and client stream --- src/mongo/transport/grpc/SConscript | 1 + src/mongo/transport/grpc/bidirectional_pipe.h | 31 ++- src/mongo/transport/grpc/client_context.h | 92 +++++++++ src/mongo/transport/grpc/client_stream.h | 78 ++++++++ src/mongo/transport/grpc/grpc_server_context.h | 8 +- src/mongo/transport/grpc/grpc_session.h | 60 ++++-- src/mongo/transport/grpc/grpc_session_test.cpp | 16 +- src/mongo/transport/grpc/mock_client_context.h | 43 +++- src/mongo/transport/grpc/mock_client_stream.cpp | 50 ++++- src/mongo/transport/grpc/mock_client_stream.h | 43 ++-- src/mongo/transport/grpc/mock_server_context.cpp | 4 +- src/mongo/transport/grpc/mock_server_context.h | 7 +- src/mongo/transport/grpc/mock_server_stream.cpp | 52 +++-- src/mongo/transport/grpc/mock_server_stream.h | 58 +++++- .../transport/grpc/mock_server_stream_test.cpp | 66 +++++-- src/mongo/transport/grpc/mock_stub.h | 217 +++++++++++++++++++++ src/mongo/transport/grpc/mock_stub_test.cpp | 204 +++++++++++++++++++ src/mongo/transport/grpc/mock_util.h | 49 +++++ src/mongo/transport/grpc/server_context.h | 11 +- src/mongo/transport/grpc/service_test.cpp | 38 ++++ src/mongo/transport/grpc/test_fixtures.h | 72 +++++-- src/mongo/util/producer_consumer_queue.h | 5 +- 22 files changed, 1078 insertions(+), 127 deletions(-) create mode 100644 src/mongo/transport/grpc/client_context.h create mode 100644 src/mongo/transport/grpc/client_stream.h create mode 100644 src/mongo/transport/grpc/mock_stub.h create mode 100644 src/mongo/transport/grpc/mock_stub_test.cpp diff --git a/src/mongo/transport/grpc/SConscript b/src/mongo/transport/grpc/SConscript index 17ef132dff3..b3317cbaff8 100644 --- a/src/mongo/transport/grpc/SConscript +++ b/src/mongo/transport/grpc/SConscript @@ -48,6 +48,7 @@ env.CppUnitTest( 'grpc_session_test.cpp', 'grpc_transport_layer_test.cpp', 'mock_server_stream_test.cpp', + 'mock_stub_test.cpp', 'server_test.cpp', 'service_test.cpp', ], diff --git a/src/mongo/transport/grpc/bidirectional_pipe.h b/src/mongo/transport/grpc/bidirectional_pipe.h index f1c29471eb0..b53d61b519c 100644 --- a/src/mongo/transport/grpc/bidirectional_pipe.h +++ b/src/mongo/transport/grpc/bidirectional_pipe.h @@ -89,20 +89,36 @@ public: } /** - * Close both ends of the pipe. In progress reads and writes on either end will be - * interrupted. + * Close both the read and write halves of this end of the pipe. In-progress reads and + * writes on this end and writes on the other end will be interrupted. + * + * Messages that have already been transmitted through this end of the pipe can still be + * read by the other end. */ void close() { _sendHalf.close(); _recvHalf.close(); } + /** + * Returns true when at least one of the following conditions is met: + * - This end of the pipe is closed. + * - The other end of the pipe is closed and there are no more messages to be read. + */ + bool isConsumed() const { + auto stats = _recvHalfCtrl.getStats(); + return stats.consumerEndClosed || (stats.queueDepth == 0 && stats.producerEndClosed); + } + private: friend BidirectionalPipe; explicit End(SingleProducerSingleConsumerQueue::Producer send, - SingleProducerSingleConsumerQueue::Consumer recv) - : _sendHalf{std::move(send)}, _recvHalf{std::move(recv)} {} + SingleProducerSingleConsumerQueue::Consumer recv, + SingleProducerSingleConsumerQueue::Controller recvCtrl) + : _sendHalf{std::move(send)}, + _recvHalf{std::move(recv)}, + _recvHalfCtrl(std::move(recvCtrl)) {} bool _isPipeClosedError(const DBException& e) const { return e.code() == ErrorCodes::ProducerConsumerQueueEndClosed || @@ -111,14 +127,17 @@ public: SingleProducerSingleConsumerQueue::Producer _sendHalf; SingleProducerSingleConsumerQueue::Consumer _recvHalf; + SingleProducerSingleConsumerQueue::Controller _recvHalfCtrl; }; BidirectionalPipe() { SingleProducerSingleConsumerQueue::Pipe pipe1; SingleProducerSingleConsumerQueue::Pipe pipe2; - left = std::unique_ptr(new End(std::move(pipe1.producer), std::move(pipe2.consumer))); - right = std::unique_ptr(new End(std::move(pipe2.producer), std::move(pipe1.consumer))); + left = std::unique_ptr(new End( + std::move(pipe1.producer), std::move(pipe2.consumer), std::move(pipe2.controller))); + right = std::unique_ptr(new End( + std::move(pipe2.producer), std::move(pipe1.consumer), std::move(pipe1.controller))); } std::unique_ptr left; diff --git a/src/mongo/transport/grpc/client_context.h b/src/mongo/transport/grpc/client_context.h new file mode 100644 index 00000000000..b4b8238cce6 --- /dev/null +++ b/src/mongo/transport/grpc/client_context.h @@ -0,0 +1,92 @@ +/** + * Copyright (C) 2023-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * . + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#pragma once + +#include +#include + +#include "mongo/transport/grpc/metadata.h" +#include "mongo/util/net/hostandport.h" +#include "mongo/util/time_support.h" + +namespace mongo::transport::grpc { + +/** + * Base class modeling a gRPC ClientContext. + * See: https://grpc.github.io/grpc/cpp/classgrpc_1_1_client_context.html + */ +class ClientContext { +public: + virtual ~ClientContext() = default; + + /** + * Add an entry to the metadata associated with the RPC. + * + * This must only be called before invoking the RPC. + */ + virtual void addMetadataEntry(const std::string& key, const std::string& value) = 0; + + /** + * Retrieve the server's initial metadata. + * + * This must only be called after the first message has been received on the ClientStream + * created from the RPC that this context is associated with. + */ + virtual boost::optional getServerInitialMetadata() const = 0; + + /** + * Set the deadline for the RPC to be executed using this context. + * + * This must only be called before invoking the RPC. + */ + virtual void setDeadline(Date_t deadline) = 0; + + virtual Date_t getDeadline() const = 0; + + virtual HostAndPort getRemote() const = 0; + + /** + * Send a best-effort out-of-band cancel on the call associated with this ClientContext. There + * is no guarantee the call will be cancelled (e.g. if the call has already finished by the time + * the cancellation is received). + * + * Note that tryCancel() will not impede the execution of any already scheduled work (e.g. + * messages already queued to be sent on a stream will still be sent), though the reported + * sucess or failure of such work may reflect the cancellation. + * + * This method is thread-safe, and can be called multiple times from any thread. It should not + * be called before this ClientContext has been used to invoke an RPC. + * + * See: + * https://grpc.github.io/grpc/cpp/classgrpc_1_1_client_context.html#abd0f6715c30287b75288015eee628984 + */ + virtual void tryCancel() = 0; +}; +} // namespace mongo::transport::grpc diff --git a/src/mongo/transport/grpc/client_stream.h b/src/mongo/transport/grpc/client_stream.h new file mode 100644 index 00000000000..e1dcbe80827 --- /dev/null +++ b/src/mongo/transport/grpc/client_stream.h @@ -0,0 +1,78 @@ +/** + * Copyright (C) 2023-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * . + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#pragma once + +#include +#include + +#include "mongo/util/shared_buffer.h" + +namespace mongo::transport::grpc { + +/** + * Base class modeling a synchronous client side of a gRPC stream. + * See: https://grpc.github.io/grpc/cpp/classgrpc_1_1_client_reader_writer.html + * + * ClientStream::read() is thread safe with respect to ClientStream::write(), but neither method + * should be called concurrently with another invocation of itself on the same stream. + * + * ClientStream::finish() is thread safe with respect to ClientStream::read(). + */ +class ClientStream { +public: + virtual ~ClientStream() = default; + + /** + * Block to read a message from the stream. + * + * Returns boost::none if the stream is closed, either cleanly or due to an underlying + * connection failure. + */ + virtual boost::optional read() = 0; + + /** + * Block to write a message to the stream. + * + * Returns true if the write was successful or false if it failed due to the stream being + * closed, either explicitly or due to an underlying connection failure. + */ + virtual bool write(ConstSharedBuffer msg) = 0; + + /** + * Block waiting until all received messages have been read and the stream has been closed. + * + * Returns the final status of the RPC associated with this stream. + * + * This method should only be called once. + */ + virtual ::grpc::Status finish() = 0; +}; + +} // namespace mongo::transport::grpc diff --git a/src/mongo/transport/grpc/grpc_server_context.h b/src/mongo/transport/grpc/grpc_server_context.h index 400719eafa8..672074ddf07 100644 --- a/src/mongo/transport/grpc/grpc_server_context.h +++ b/src/mongo/transport/grpc/grpc_server_context.h @@ -53,7 +53,7 @@ public: } explicit GRPCServerContext(::grpc::ServerContext* ctx) - : _ctx{ctx}, _hostAndPort{parseURI(_ctx->peer())} { + : _ctx{ctx}, _remote{parseURI(_ctx->peer())} { for (auto& kvp : _ctx->client_metadata()) { _clientMetadata.insert({StringData{kvp.first.data(), kvp.first.length()}, StringData{kvp.second.data(), kvp.second.length()}}); @@ -74,8 +74,8 @@ public: return Date_t{_ctx->deadline()}; } - HostAndPort getHostAndPort() const override { - return _hostAndPort; + HostAndPort getRemote() const override { + return _remote; } bool isCancelled() const override { @@ -89,7 +89,7 @@ public: private: ::grpc::ServerContext* _ctx; MetadataView _clientMetadata; - HostAndPort _hostAndPort; + HostAndPort _remote; }; } // namespace mongo::transport::grpc diff --git a/src/mongo/transport/grpc/grpc_session.h b/src/mongo/transport/grpc/grpc_session.h index 43d4d255442..7dbb7ffc7a0 100644 --- a/src/mongo/transport/grpc/grpc_session.h +++ b/src/mongo/transport/grpc/grpc_session.h @@ -66,28 +66,41 @@ public: virtual boost::optional clientId() const = 0; /** - * Terminates the underlying gRPC stream. + * Cancels the underlying gRPC stream and updates the termination status of the session. + * If this session is already terminated, this has no effect. + * + * It is an error to terminate a session with an OK status. Instead, provide a status that + * explains the reason for the cancellation or use end() if the intention is to mark the session + * as successfully terminated without cancelling the underlying stream. */ - virtual void terminate(Status status) { - auto ts = _terminationStatus.synchronize(); - if (MONGO_unlikely(ts->has_value())) - return; - ts->emplace(std::move(status)); + void terminate(Status status) { + tassert(7401590, + "gRPC sessions should only be manually terminated with non-OK statuses", + !status.isOK()); + + // Need to update terminationStatus before cancelling so that when the RPC caller/handler is + // interrupted, the it will be guaranteed to have access to the reason for cancellation. + if (_setTerminationStatus(std::move(status))) { + _tryCancel(); + } } /** * Returns the termination status (always set at termination). Remains unset until termination. */ boost::optional terminationStatus() const { - return **_terminationStatus; + return *_terminationStatus; } TransportLayer* getTransportLayer() const final { return _tl; } + /** + * Marks the session as having terminated successfully. + */ void end() final { - terminate(Status::OK()); + _setTerminationStatus(Status::OK()); } /** @@ -145,6 +158,20 @@ public: } private: + /** + * Sets the termination status if it hasn't been set already. + * Returns whether the termination status was updated or not. + */ + bool _setTerminationStatus(Status status) { + auto ts = _terminationStatus.synchronize(); + if (MONGO_unlikely(ts->has_value())) + return false; + ts->emplace(std::move(status)); + return true; + } + + virtual void _tryCancel() = 0; + // TODO SERVER-74020: replace this with `GRPCTransportLayer`. TransportLayer* const _tl; @@ -168,7 +195,7 @@ public: _ctx(ctx), _stream(stream), _clientId(std::move(clientId)), - _remote(ctx->getHostAndPort()) { + _remote(ctx->getRemote()) { LOGV2_DEBUG(7401101, 2, "Constructed a new gRPC ingress session", @@ -219,17 +246,6 @@ public: return Status(ErrorCodes::StreamTerminated, "Unable to write to ingress session"); } - void terminate(Status status) override { - auto shouldCancel = !status.isOK() && !terminationStatus(); - // Need to invoke GRPCSession::terminate before cancelling the server context so that when - // an RPC handler is interrupted, it will be guaranteed to have access to its termination - // status. - GRPCSession::terminate(std::move(status)); - if (MONGO_unlikely(shouldCancel)) { - _ctx->tryCancel(); - } - } - const HostAndPort& remote() const override { return _remote; } @@ -239,6 +255,10 @@ public: } private: + void _tryCancel() override { + _ctx->tryCancel(); + } + // _ctx and _stream are only valid while the RPC handler is still running. They should not be // accessed after the stream has been terminated. ServerContext* const _ctx; diff --git a/src/mongo/transport/grpc/grpc_session_test.cpp b/src/mongo/transport/grpc/grpc_session_test.cpp index 24baf977f1e..de4f7e2897e 100644 --- a/src/mongo/transport/grpc/grpc_session_test.cpp +++ b/src/mongo/transport/grpc/grpc_session_test.cpp @@ -61,8 +61,11 @@ public: void setUp() override { SessionTest::setUp(); - _fixture = std::make_unique( - HostAndPort{kRemote}, kStreamTimeout, _clientMetadata); + // The MockStreamTestFixtures created here doesn't contain any references to the channel or + // server, so it's okay to let stubFixture go out of scope. + MockStubTestFixtures stubFixture; + _fixture = stubFixture.makeStreamTestFixtures( + getServiceContext()->getFastClockSource()->now() + kStreamTimeout, _clientMetadata); } void tearDown() override { @@ -109,7 +112,7 @@ TEST_F(IngressSessionTest, GetClientId) { TEST_F(IngressSessionTest, GetRemote) { auto session = makeSession(); - ASSERT_EQ(session->remote().toString(), kRemote); + ASSERT_EQ(session->remote().toString(), MockStubTestFixtures::kClientAddress); } TEST_F(IngressSessionTest, IsConnected) { @@ -119,9 +122,9 @@ TEST_F(IngressSessionTest, IsConnected) { ASSERT_FALSE(session->isConnected()); } -TEST_F(IngressSessionTest, Terminate) { +TEST_F(IngressSessionTest, End) { auto session = makeSession(); - session->terminate(Status::OK()); + session->end(); ASSERT_FALSE(session->isConnected()); ASSERT_TRUE(session->terminationStatus()); ASSERT_OK(*session->terminationStatus()); @@ -134,13 +137,14 @@ TEST_F(IngressSessionTest, TerminateWithError) { ASSERT_FALSE(session->isConnected()); ASSERT_TRUE(session->terminationStatus()); ASSERT_EQ(*session->terminationStatus(), error); + ASSERT_TRUE(fixture()->serverCtx->isCancelled()); } TEST_F(IngressSessionTest, TerminateRetainsStatus) { const Status error(ErrorCodes::InternalError, "Some Error"); auto session = makeSession(); session->terminate(error); - session->terminate(Status::OK()); + session->end(); ASSERT_EQ(*session->terminationStatus(), error); } diff --git a/src/mongo/transport/grpc/mock_client_context.h b/src/mongo/transport/grpc/mock_client_context.h index 04311414e35..5dad9078403 100644 --- a/src/mongo/transport/grpc/mock_client_context.h +++ b/src/mongo/transport/grpc/mock_client_context.h @@ -29,29 +29,54 @@ #pragma once -#include -#include -#include +#include "mongo/transport/grpc/client_context.h" #include "mongo/transport/grpc/mock_client_stream.h" namespace mongo::transport::grpc { -// TODO: SERVER-74015 introduce a ClientContext interface that covers the whole API surface of -// gRPC's ClientContext type, and implement that interface here. -class MockClientContext { +class MockClientContext : public ClientContext { public: - explicit MockClientContext(MockClientStream* stream) : _stream{stream} {} - ~MockClientContext() = default; + MockClientContext() : _deadline{Date_t::max()}, _stream{nullptr} {} - boost::optional getServerInitialMetadata() const { + void addMetadataEntry(const std::string& key, const std::string& value) override { + invariant(!_stream); + _metadata.insert({key, value}); + }; + + boost::optional getServerInitialMetadata() const override { + invariant(_stream); if (!_stream->_serverInitialMetadata.isReady()) { return boost::none; } return _stream->_serverInitialMetadata.get(); } + Date_t getDeadline() const override { + return _deadline; + } + + void setDeadline(Date_t deadline) override { + invariant(!_stream); + _deadline = deadline; + } + + HostAndPort getRemote() const override { + invariant(_stream); + return _stream->_remote; + } + + void tryCancel() override { + invariant(_stream); + _stream->_cancel(); + } + private: + friend class MockStub; + friend struct MockStreamTestFixtures; + + Date_t _deadline; + MetadataContainer _metadata; MockClientStream* _stream; }; diff --git a/src/mongo/transport/grpc/mock_client_stream.cpp b/src/mongo/transport/grpc/mock_client_stream.cpp index 0952c8be49c..dfc00f8c51a 100644 --- a/src/mongo/transport/grpc/mock_client_stream.cpp +++ b/src/mongo/transport/grpc/mock_client_stream.cpp @@ -29,26 +29,64 @@ #include "mongo/transport/grpc/mock_client_stream.h" -#include "mongo/db/service_context.h" #include "mongo/transport/grpc/mock_util.h" +#include "mongo/util/interruptible.h" namespace mongo::transport::grpc { -MockClientStream::MockClientStream(HostAndPort hostAndPort, - Milliseconds timeout, +MockClientStream::MockClientStream(HostAndPort remote, Future&& initialMetadataFuture, + Future<::grpc::Status>&& rpcReturnStatus, + std::shared_ptr rpcCancellationState, BidirectionalPipe::End&& pipe) - : _deadline{getGlobalServiceContext()->getFastClockSource()->now() + timeout}, + : _remote{std::move(remote)}, _serverInitialMetadata{std::move(initialMetadataFuture)}, + _rpcReturnStatus{std::move(rpcReturnStatus)}, + _rpcCancellationState(std::move(rpcCancellationState)), _pipe{std::move(pipe)} {} boost::optional MockClientStream::read() { + // Even if the server side handler of this stream has set a final status for the RPC (i.e. + // _rpcReturnStatus is ready), there may still be unread messages in the queue that the server + // sent before setting that status, so only return early here if the RPC was cancelled. + // Otherwise, try to read whatever messages are in the queue. + if (_rpcCancellationState->isCancelled()) { + return boost::none; + } + return runWithDeadline>( - _deadline, [&](Interruptible* i) { return _pipe.read(i); }); + _rpcCancellationState->getDeadline(), [&](Interruptible* i) { return _pipe.read(i); }); } bool MockClientStream::write(ConstSharedBuffer msg) { - return runWithDeadline(_deadline, [&](Interruptible* i) { return _pipe.write(msg, i); }); + if (_rpcCancellationState->isCancelled() || _rpcReturnStatus.isReady()) { + return false; + } + + return runWithDeadline(_rpcCancellationState->getDeadline(), + [&](Interruptible* i) { return _pipe.write(msg, i); }); +} + +::grpc::Status MockClientStream::finish() { + // We use a busy wait here because there is no easy way to wait until all the messages in the + // pipe have been read. + while (!_pipe.isConsumed() && !_rpcCancellationState->isDeadlineExceeded()) { + sleepFor(Milliseconds(1)); + } + + invariant(_rpcReturnStatus.isReady() || _rpcCancellationState->isCancelled()); + + if (auto cancellationStatus = _rpcCancellationState->getCancellationStatus(); + cancellationStatus.has_value()) { + return *cancellationStatus; + } + + return _rpcReturnStatus.get(); +} + +void MockClientStream::_cancel() { + _rpcCancellationState->cancel(::grpc::Status::CANCELLED); + _pipe.close(); } } // namespace mongo::transport::grpc diff --git a/src/mongo/transport/grpc/mock_client_stream.h b/src/mongo/transport/grpc/mock_client_stream.h index 694581c393d..f07af95f8d1 100644 --- a/src/mongo/transport/grpc/mock_client_stream.h +++ b/src/mongo/transport/grpc/mock_client_stream.h @@ -29,37 +29,52 @@ #pragma once -#include -#include +#include "mongo/transport/grpc/client_stream.h" #include "mongo/transport/grpc/bidirectional_pipe.h" #include "mongo/transport/grpc/metadata.h" -#include "mongo/transport/grpc/server_stream.h" +#include "mongo/transport/grpc/mock_util.h" #include "mongo/util/future.h" #include "mongo/util/net/hostandport.h" +#include "mongo/util/synchronized_value.h" #include "mongo/util/time_support.h" namespace mongo::transport::grpc { -// TODO: SERVER-74015 introduce a ClientStream interface that covers the whole API surface of -// gRPC's ClientReaderWriter type, and implement that interface here. -class MockClientStream { +class MockClientStream : public ClientStream { public: - ~MockClientStream() = default; + MockClientStream(HostAndPort remote, + Future&& serverInitialMetadata, + Future<::grpc::Status>&& rpcReturnStatus, + std::shared_ptr rpcCancellationState, + BidirectionalPipe::End&& pipe); - boost::optional read(); - bool write(ConstSharedBuffer msg); + boost::optional read() override; - explicit MockClientStream(HostAndPort hostAndPort, - Milliseconds timeout, - Future&& serverInitialMetadata, - BidirectionalPipe::End&& pipe); + bool write(ConstSharedBuffer msg) override; + + ::grpc::Status finish() override; private: friend class MockClientContext; - Date_t _deadline; + void _cancel(); + + HostAndPort _remote; + MetadataContainer _clientMetadata; Future _serverInitialMetadata; + + /** + * The mocked equivalent of a status returned from a server-side RPC handler. + */ + Future<::grpc::Status> _rpcReturnStatus; + + /** + * State used to mock RPC cancellation, including explicit cancellation (client or server side) + * or network errors. + */ + std::shared_ptr _rpcCancellationState; + BidirectionalPipe::End _pipe; }; } // namespace mongo::transport::grpc diff --git a/src/mongo/transport/grpc/mock_server_context.cpp b/src/mongo/transport/grpc/mock_server_context.cpp index a862fe17e49..a4b91ce5b01 100644 --- a/src/mongo/transport/grpc/mock_server_context.cpp +++ b/src/mongo/transport/grpc/mock_server_context.cpp @@ -40,11 +40,11 @@ const MetadataView& MockServerContext::getClientMetadata() const { } Date_t MockServerContext::getDeadline() const { - return _stream->_deadline; + return _stream->_rpcCancellationState->getDeadline(); } void MockServerContext::tryCancel() { - _stream->close(); + _stream->cancel(::grpc::Status::CANCELLED); } bool MockServerContext::isCancelled() const { diff --git a/src/mongo/transport/grpc/mock_server_context.h b/src/mongo/transport/grpc/mock_server_context.h index bd1fa05d47b..0c867f43e18 100644 --- a/src/mongo/transport/grpc/mock_server_context.h +++ b/src/mongo/transport/grpc/mock_server_context.h @@ -43,11 +43,8 @@ public: const MetadataView& getClientMetadata() const override; Date_t getDeadline() const override; bool isCancelled() const override; - HostAndPort getHostAndPort() const override { - return _stream->_hostAndPort; - } - CancellationToken getCancellationToken() { - return _stream->_cancellationSource.token(); + HostAndPort getRemote() const override { + return _stream->_remote; } void tryCancel() override; diff --git a/src/mongo/transport/grpc/mock_server_stream.cpp b/src/mongo/transport/grpc/mock_server_stream.cpp index 8c8ca1df559..8fa89bdba90 100644 --- a/src/mongo/transport/grpc/mock_server_stream.cpp +++ b/src/mongo/transport/grpc/mock_server_stream.cpp @@ -31,42 +31,68 @@ #include "mongo/db/service_context.h" #include "mongo/transport/grpc/mock_util.h" +#include "mongo/util/assert_util.h" #include "mongo/util/interruptible.h" namespace mongo::transport::grpc { -MockServerStream::MockServerStream(HostAndPort hostAndPort, - Milliseconds timeout, +MockServerStream::MockServerStream(HostAndPort remote, Promise&& initialMetadataPromise, + Promise<::grpc::Status>&& rpcTerminationStatusPromise, + std::shared_ptr rpcCancellationState, BidirectionalPipe::End&& serverPipeEnd, MetadataView clientMetadata) - : _deadline{getGlobalServiceContext()->getFastClockSource()->now() + timeout}, + : _remote(std::move(remote)), _initialMetadata(std::move(initialMetadataPromise)), + _rpcReturnStatus(std::move(rpcTerminationStatusPromise)), + _finalStatusReturned(false), + _rpcCancellationState(std::move(rpcCancellationState)), _pipe{std::move(serverPipeEnd)}, - _clientMetadata{std::move(clientMetadata)}, - _hostAndPort(std::move(hostAndPort)) {} + _clientMetadata{std::move(clientMetadata)} {} boost::optional MockServerStream::read() { + invariant(!*_finalStatusReturned); + return runWithDeadline>( - _deadline, [&](Interruptible* i) { return _pipe.read(i); }); + _rpcCancellationState->getDeadline(), [&](Interruptible* i) { return _pipe.read(i); }); } bool MockServerStream::isCancelled() const { - return _cancellationSource.token().isCanceled() || - getGlobalServiceContext()->getFastClockSource()->now() > _deadline; + return _rpcCancellationState->isCancelled(); } bool MockServerStream::write(ConstSharedBuffer msg) { - if (_cancellationSource.token().isCanceled() || - getGlobalServiceContext()->getFastClockSource()->now() > _deadline) { + invariant(!*_finalStatusReturned); + if (isCancelled()) { return false; } + _initialMetadata.trySend(); - return runWithDeadline(_deadline, [&](Interruptible* i) { return _pipe.write(msg, i); }); + return runWithDeadline(_rpcCancellationState->getDeadline(), + [&](Interruptible* i) { return _pipe.write(msg, i); }); } -void MockServerStream::close() { - _cancellationSource.cancel(); +void MockServerStream::sendReturnStatus(::grpc::Status status) { + { + auto finalStatusReturned = _finalStatusReturned.synchronize(); + invariant(!*finalStatusReturned); + *finalStatusReturned = true; + // Client side ignores the mocked return value in the event of a cancellation, so don't need + // to check if stream has been cancelled before sending the status. + } + _rpcReturnStatus.emplaceValue(std::move(status)); + _pipe.close(); +} + +void MockServerStream::cancel(::grpc::Status status) { + // Only mark the RPC as cancelled if a status hasn't already been returned to client. + auto finalStatusReturned = _finalStatusReturned.synchronize(); + if (*finalStatusReturned) { + return; + } + // Need to update the cancellation state before closing the pipe so that when a stream + // read/write is interrupted, the cancellation state will already be up to date. + _rpcCancellationState->cancel(std::move(status)); _pipe.close(); } diff --git a/src/mongo/transport/grpc/mock_server_stream.h b/src/mongo/transport/grpc/mock_server_stream.h index eb7177f46cc..88fbe51d5eb 100644 --- a/src/mongo/transport/grpc/mock_server_stream.h +++ b/src/mongo/transport/grpc/mock_server_stream.h @@ -32,13 +32,14 @@ #include #include +#include + #include "mongo/transport/grpc/bidirectional_pipe.h" #include "mongo/transport/grpc/metadata.h" +#include "mongo/transport/grpc/mock_util.h" #include "mongo/transport/grpc/server_stream.h" -#include "mongo/util/cancellation.h" #include "mongo/util/future.h" #include "mongo/util/net/hostandport.h" -#include "mongo/util/time_support.h" namespace mongo::transport::grpc { @@ -47,16 +48,19 @@ public: ~MockServerStream() = default; boost::optional read() override; + bool write(ConstSharedBuffer msg) override; - explicit MockServerStream(HostAndPort hostAndPort, - Milliseconds timeout, + explicit MockServerStream(HostAndPort remote, Promise&& initialMetadataPromise, + Promise<::grpc::Status>&& rpcTerminationStatusPromise, + std::shared_ptr rpcCancellationState, BidirectionalPipe::End&& serverPipeEnd, MetadataView clientMetadata); private: friend class MockServerContext; + friend class MockRPC; class InitialMetadata { public: @@ -86,13 +90,51 @@ private: }; bool isCancelled() const; - void close(); - CancellationSource _cancellationSource; - Date_t _deadline; + /** + * Cancel the RPC associated with this stream. This is used for mocking situations in + * which an RPC handler was never able to return a final status to the client (e.g. manual + * cancellation or a network interruption). + * + * This method has no effect if the stream is already terminated. + */ + void cancel(::grpc::Status status); + + /** + * Closes the stream and sends the final return status of the RPC to the client. This is the + * mocked equivalent of an RPC handler returning a status. + * + * This does not mark the stream as cancelled. + * + * This method must only be called once, and this stream must not be used after this method has + * been called. + */ + void sendReturnStatus(::grpc::Status status); + + HostAndPort _remote; InitialMetadata _initialMetadata; + + /** + * _rpcReturnStatus is set in sendReturnStatus(), and it is used to mock returning a status from + * an RPC handler. sendReturnStatus itself is called via MockRPC::sendReturnStatus(). + */ + Promise<::grpc::Status> _rpcReturnStatus; + + /** + * _finalStatusReturned is also set in sendReturnStatus(), and it is used to denote that a + * status has been returned and the stream should no longer be used. + */ + synchronized_value _finalStatusReturned; + + /** + * _rpcCancellationState is set via cancel(), which is called by either + * MockServerContext::tryCancel() or MockRPC::cancel(). It is used to mock situations in which a + * server RPC handler is unable to return a status to the client (e.g. explicit cancellation or + * a network interruption). + */ + std::shared_ptr _rpcCancellationState; + BidirectionalPipe::End _pipe; MetadataView _clientMetadata; - HostAndPort _hostAndPort; }; } // namespace mongo::transport::grpc diff --git a/src/mongo/transport/grpc/mock_server_stream_test.cpp b/src/mongo/transport/grpc/mock_server_stream_test.cpp index 1452f9f3f4a..1f01a7b7ff0 100644 --- a/src/mongo/transport/grpc/mock_server_stream_test.cpp +++ b/src/mongo/transport/grpc/mock_server_stream_test.cpp @@ -33,6 +33,8 @@ #include #include +#include + #include "mongo/db/concurrency/locker_noop_service_context_test_fixture.h" #include "mongo/platform/mutex.h" #include "mongo/rpc/message.h" @@ -40,6 +42,7 @@ #include "mongo/transport/grpc/metadata.h" #include "mongo/transport/grpc/mock_server_context.h" #include "mongo/transport/grpc/mock_server_stream.h" +#include "mongo/transport/grpc/mock_stub.h" #include "mongo/transport/grpc/test_fixtures.h" #include "mongo/unittest/assert.h" #include "mongo/unittest/death_test.h" @@ -47,6 +50,7 @@ #include "mongo/unittest/unittest.h" #include "mongo/util/concurrency/notification.h" #include "mongo/util/duration.h" +#include "mongo/util/future.h" #include "mongo/util/net/hostandport.h" #include "mongo/util/scopeguard.h" #include "mongo/util/system_clock_source.h" @@ -58,32 +62,29 @@ template class MockServerStreamBase : public Base { public: static constexpr Milliseconds kTimeout = Milliseconds(100); - static constexpr const char* kRemote = "abc:123"; virtual void setUp() override { Base::setUp(); - _fixtures = std::make_unique( - HostAndPort{kRemote}, kTimeout, _clientMetadata); - } - MockStreamTestFixtures& getFixtures() { - return *_fixtures; + MockStubTestFixtures fixtures; + _fixtures = fixtures.makeStreamTestFixtures( + Base::getServiceContext()->getFastClockSource()->now() + kTimeout, _clientMetadata); } MockServerStream& getServerStream() { - return *getFixtures().serverStream; + return *_fixtures->serverStream; } MockServerContext& getServerContext() { - return *getFixtures().serverCtx; + return *_fixtures->serverCtx; } - MockClientStream& getClientStream() { - return *getFixtures().clientStream; + ClientStream& getClientStream() { + return *_fixtures->clientStream; } - MockClientContext& getClientContext() { - return *getFixtures().clientCtx; + ClientContext& getClientContext() { + return *_fixtures->clientCtx; } const Message& getClientFirstMessage() const { @@ -220,6 +221,7 @@ TEST_F(MockServerStreamTestWithMockedClockSource, DeadlineIsEnforced) { ASSERT_FALSE(getServerStream().write(makeUniqueMessage().sharedBuffer())); ASSERT_FALSE(getClientContext().getServerInitialMetadata()); ASSERT_FALSE(getClientStream().read()); + ASSERT_EQ(getClientStream().finish().error_code(), ::grpc::StatusCode::DEADLINE_EXCEEDED); } TEST_F(MockServerStreamTest, TryCancelSubsequentServerRead) { @@ -265,4 +267,44 @@ TEST_F(MockServerStreamTest, ClientMetadataIsAccessible) { ASSERT_EQ(getServerContext().getClientMetadata(), getClientMetadata()); } +TEST_F(MockServerStreamTest, ClientSideCancellation) { + ASSERT_TRUE(getClientStream().write(makeUniqueMessage().sharedBuffer())); + ASSERT_TRUE(getServerStream().read()); + + getClientContext().tryCancel(); + + ASSERT_FALSE(getClientStream().read()); + ASSERT_FALSE(getClientStream().write(makeUniqueMessage().sharedBuffer())); + ASSERT_FALSE(getServerStream().read()); + ASSERT_FALSE(getServerStream().write(makeUniqueMessage().sharedBuffer())); + ASSERT_TRUE(getServerContext().isCancelled()); + + ASSERT_EQ(getClientStream().finish().error_code(), ::grpc::StatusCode::CANCELLED); +} + +TEST_F(MockServerStreamTest, CancellationInterruptsFinish) { + auto pf = makePromiseFuture<::grpc::Status>(); + auto finishThread = + stdx::thread([&] { pf.promise.setWith([&] { return getClientStream().finish(); }); }); + ON_BLOCK_EXIT([&] { finishThread.join(); }); + + // finish() won't return until server end hangs up too. + ASSERT_FALSE(pf.future.isReady()); + + getServerContext().tryCancel(); + ASSERT_EQ(pf.future.get().error_code(), ::grpc::StatusCode::CANCELLED); +} + +TEST_F(MockServerStreamTestWithMockedClockSource, DeadlineExceededInterruptsFinish) { + auto pf = makePromiseFuture<::grpc::Status>(); + auto finishThread = + stdx::thread([&] { pf.promise.setWith([&] { return getClientStream().finish(); }); }); + ON_BLOCK_EXIT([&] { finishThread.join(); }); + + // finish() won't return until server end hangs up too. + ASSERT_FALSE(pf.future.isReady()); + + clockSource().advance(kTimeout * 2); + ASSERT_EQ(pf.future.get().error_code(), ::grpc::StatusCode::DEADLINE_EXCEEDED); +} } // namespace mongo::transport::grpc diff --git a/src/mongo/transport/grpc/mock_stub.h b/src/mongo/transport/grpc/mock_stub.h new file mode 100644 index 00000000000..5af39aeda93 --- /dev/null +++ b/src/mongo/transport/grpc/mock_stub.h @@ -0,0 +1,217 @@ +/** + * Copyright (C) 2023-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * . + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#pragma once + +#include "mongo/base/string_data.h" +#include "mongo/transport/grpc/mock_client_context.h" +#include "mongo/transport/grpc/mock_server_context.h" +#include "mongo/transport/grpc/mock_server_stream.h" +#include "mongo/transport/grpc/mock_util.h" +#include "mongo/transport/grpc/service.h" +#include "mongo/unittest/thread_assertion_monitor.h" +#include "mongo/util/assert_util.h" +#include "mongo/util/concurrency/notification.h" +#include "mongo/util/producer_consumer_queue.h" + +namespace mongo::transport::grpc { + +class MockRPC { +public: + /** + * Close the stream and send the final return status of the RPC to the client. This is the + * mocked equivalent of returning a status from an RPC handler. + * + * The RPC's stream must not be used after calling this method. + * This method must only be called once. + */ + void sendReturnStatus(::grpc::Status status) { + serverStream->sendReturnStatus(std::move(status)); + } + + /** + * Cancel the RPC with the provided status. This is used for mocking situations in which an RPC + * handler was never able to return a final status to the client (e.g. network interruption). + * + * For mocking an explicit server-side cancellation, use serverCtx->tryCancel(). + * This method has no effect if the RPC has already been terminated, either by returning a + * status or prior cancellation. + */ + void cancel(::grpc::Status status) { + serverStream->cancel(std::move(status)); + } + + StringData methodName; + std::unique_ptr serverStream; + std::unique_ptr serverCtx; +}; + +using MockRPCQueue = MultiProducerMultiConsumerQueue, MockRPC>>; + +class MockServer { +public: + explicit MockServer(MockRPCQueue::Consumer queue) : _queue(std::move(queue)) {} + + boost::optional acceptRPC() { + try { + auto entry = _queue.pop(); + entry.first.emplaceValue(); + return std::move(entry.second); + } catch (const DBException& e) { + if (e.code() == ErrorCodes::ProducerConsumerQueueEndClosed || + e.code() == ErrorCodes::ProducerConsumerQueueConsumed) { + return boost::none; + } + throw; + } + } + + /** + * Starts up a thread that listens for incoming RPCs and then returns immediately. + * The listener thread will spawn a new thread for each RPC it receives and pass it to the + * provided handler. + * + * The provided handler is expected to throw assertion exceptions, hence the use of + * ThreadAssertionMonitor to spawn threads here. + */ + void start(unittest::ThreadAssertionMonitor& monitor, + std::function<::grpc::Status(MockRPC&)> handler) { + _listenerThread = monitor.spawn([&, handler = std::move(handler)] { + std::vector rpcHandlers; + while (auto rpc = acceptRPC()) { + rpcHandlers.push_back(monitor.spawn([rpc = std::move(*rpc), handler]() mutable { + try { + auto status = handler(rpc); + rpc.sendReturnStatus(std::move(status)); + } catch (DBException& e) { + rpc.sendReturnStatus( + ::grpc::Status(::grpc::StatusCode::UNKNOWN, e.toString())); + } + })); + } + + for (auto& thread : rpcHandlers) { + thread.join(); + } + }); + } + + /** + * Close the mocked channel and then block until all RPC handler threads (if any) have exited. + */ + void shutdown() { + _queue.close(); + if (_listenerThread) { + _listenerThread->join(); + } + } + +private: + boost::optional _listenerThread; + MockRPCQueue::Consumer _queue; +}; + +class MockChannel { +public: + explicit MockChannel(HostAndPort local, HostAndPort remote, MockRPCQueue::Producer queue) + : _local(std::move(local)), _remote{std::move(remote)}, _rpcQueue{std::move(queue)} {}; + + void sendRPC(MockRPC&& rpc) { + auto pf = makePromiseFuture(); + _rpcQueue.push({std::move(pf.promise), std::move(rpc)}); + pf.future.get(); + } + + const HostAndPort& getLocal() const { + return _local; + } + + const HostAndPort& getRemote() const { + return _remote; + } + +private: + HostAndPort _local; + HostAndPort _remote; + MockRPCQueue::Producer _rpcQueue; +}; + +class MockStub { +public: + explicit MockStub(std::shared_ptr channel) : _channel{std::move(channel)} {} + + ~MockStub() {} + + std::shared_ptr unauthenticatedCommandStream(MockClientContext* ctx) { + return _makeStream(CommandService::kUnauthenticatedCommandStreamMethodName, ctx); + } + + std::shared_ptr authenticatedCommandStream(MockClientContext* ctx) { + return _makeStream(CommandService::kAuthenticatedCommandStreamMethodName, ctx); + } + +private: + std::shared_ptr _makeStream(const StringData methodName, + MockClientContext* ctx) { + MetadataView clientMetadata; + for (auto& kvp : ctx->_metadata) { + clientMetadata.insert(kvp); + } + + BidirectionalPipe pipe; + auto metadataPF = makePromiseFuture(); + auto terminationStatusPF = makePromiseFuture<::grpc::Status>(); + auto cancellationState = std::make_shared(ctx->getDeadline()); + + MockRPC rpc; + rpc.methodName = methodName; + rpc.serverStream = + std::make_unique(_channel->getLocal(), + std::move(metadataPF.promise), + std::move(terminationStatusPF.promise), + cancellationState, + std::move(*pipe.left), + clientMetadata); + rpc.serverCtx = std::make_unique(rpc.serverStream.get()); + auto clientStream = + std::make_shared(_channel->getRemote(), + std::move(metadataPF.future), + std::move(terminationStatusPF.future), + cancellationState, + std::move(*pipe.right)); + + ctx->_stream = clientStream.get(); + _channel->sendRPC(std::move(rpc)); + return clientStream; + } + + std::shared_ptr _channel; +}; + +} // namespace mongo::transport::grpc diff --git a/src/mongo/transport/grpc/mock_stub_test.cpp b/src/mongo/transport/grpc/mock_stub_test.cpp new file mode 100644 index 00000000000..1a24f8e09a1 --- /dev/null +++ b/src/mongo/transport/grpc/mock_stub_test.cpp @@ -0,0 +1,204 @@ +/** + * Copyright (C) 2023-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * . + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#include +#include + +#include "mongo/db/concurrency/locker_noop_service_context_test_fixture.h" +#include "mongo/rpc/message.h" +#include "mongo/stdx/thread.h" +#include "mongo/transport/grpc/mock_client_context.h" +#include "mongo/transport/grpc/mock_stub.h" +#include "mongo/transport/grpc/test_fixtures.h" +#include "mongo/unittest/assert.h" +#include "mongo/unittest/thread_assertion_monitor.h" +#include "mongo/unittest/unittest.h" +#include "mongo/util/net/hostandport.h" +#include "mongo/util/producer_consumer_queue.h" +#include "mongo/util/scopeguard.h" + +namespace mongo::transport::grpc { + +class MockStubTest : public LockerNoopServiceContextTest { +public: + void setUp() override { + _fixtures = std::make_unique(); + } + + MockStub makeStub() { + return _fixtures->makeStub(); + } + + MockServer& getServer() { + return _fixtures->getServer(); + } + + void runEchoTest(std::function(MockClientContext&)> makeStream) { + unittest::threadAssertionMonitoredTest([&](unittest::ThreadAssertionMonitor& monitor) { + getServer().start(monitor, [](auto& rpc) { + ASSERT_EQ(rpc.serverCtx->getRemote().toString(), + MockStubTestFixtures::kClientAddress); + auto msg = rpc.serverStream->read(); + ASSERT_TRUE(msg); + ASSERT_TRUE(rpc.serverStream->write(*msg)); + return ::grpc::Status::OK; + }); + + std::vector clientThreads; + for (int i = 0; i < 10; i++) { + clientThreads.push_back(monitor.spawn([&]() { + auto clientMessage = makeUniqueMessage(); + MockClientContext ctx; + auto stream = makeStream(ctx); + ASSERT_TRUE(stream->write(clientMessage.sharedBuffer())); + + auto serverResponse = stream->read(); + ASSERT_TRUE(serverResponse); + ASSERT_EQ_MSG(Message{*serverResponse}, clientMessage); + })); + } + + for (auto& thread : clientThreads) { + thread.join(); + } + + getServer().shutdown(); + }); + } + + std::pair, MockRPC> makeRPC(MockClientContext& ctx) { + auto clientStreamPf = makePromiseFuture>(); + auto th = stdx::thread([&, promise = std::move(clientStreamPf.promise)]() mutable { + promise.setWith([&] { + auto stub = makeStub(); + return stub.unauthenticatedCommandStream(&ctx); + }); + }); + ON_BLOCK_EXIT([&th] { th.join(); }); + auto rpc = getServer().acceptRPC(); + ASSERT_TRUE(rpc); + return {clientStreamPf.future.get(), std::move(*rpc)}; + } + +private: + std::unique_ptr _fixtures; +}; + +TEST_F(MockStubTest, ConcurrentStreamsAuth) { + auto stub = makeStub(); + runEchoTest([&](auto& ctx) { return stub.authenticatedCommandStream(&ctx); }); +} + +TEST_F(MockStubTest, ConcurrentStreamsNoAuth) { + auto stub = makeStub(); + runEchoTest([&](auto& ctx) { return stub.unauthenticatedCommandStream(&ctx); }); +} + +TEST_F(MockStubTest, ConcurrentStubsAuth) { + runEchoTest([&](auto& ctx) { return makeStub().authenticatedCommandStream(&ctx); }); +} + +TEST_F(MockStubTest, ConcurrentStubsNoAuth) { + runEchoTest([&](auto& ctx) { return makeStub().unauthenticatedCommandStream(&ctx); }); +} + +TEST_F(MockStubTest, RPCReturn) { + const ::grpc::Status kFinalStatus = + ::grpc::Status{::grpc::StatusCode::FAILED_PRECONDITION, "test"}; + const int kMessageCount = 5; + + MockClientContext ctx; + auto [clientStream, rpc] = makeRPC(ctx); + + for (auto i = 0; i < kMessageCount; i++) { + ASSERT_TRUE(rpc.serverStream->write(makeUniqueMessage().sharedBuffer())); + } + + rpc.sendReturnStatus(kFinalStatus); + ASSERT_FALSE(rpc.serverCtx->isCancelled()) + << "returning a status should not mark stream as cancelled"; + ASSERT_FALSE(clientStream->write(makeUniqueMessage().sharedBuffer())); + + auto finishPf = makePromiseFuture<::grpc::Status>(); + auto finishThread = stdx::thread( + [&clientStream = *clientStream, promise = std::move(finishPf.promise)]() mutable { + promise.setWith([&] { return clientStream.finish(); }); + }); + ON_BLOCK_EXIT([&finishThread] { finishThread.join(); }); + // finish() should not return until all messages have been read. + ASSERT_FALSE(finishPf.future.isReady()); + + // Ensure messages sent before the RPC was finished can still be read. + for (auto i = 0; i < kMessageCount; i++) { + ASSERT_TRUE(clientStream->read()); + } + ASSERT_FALSE(clientStream->read()); + + // Ensure that finish() returns now that all the messages have been read. + auto status = finishPf.future.get(); + ASSERT_EQ(status.error_code(), kFinalStatus.error_code()); + + // Cancelling a finished RPC should have no effect. + rpc.serverCtx->tryCancel(); + ASSERT_FALSE(rpc.serverCtx->isCancelled()); +} + +TEST_F(MockStubTest, RPCCancellation) { + const ::grpc::Status kCancellationStatus = + ::grpc::Status{::grpc::StatusCode::UNAVAILABLE, "mock network error"}; + + MockClientContext ctx; + auto [clientStream, rpc] = makeRPC(ctx); + + ASSERT_TRUE(clientStream->write(makeUniqueMessage().sharedBuffer())); + + rpc.cancel(kCancellationStatus); + + ASSERT_TRUE(rpc.serverCtx->isCancelled()); + ASSERT_FALSE(clientStream->write(makeUniqueMessage().sharedBuffer())); + ASSERT_FALSE(clientStream->read()); + ASSERT_EQ(clientStream->finish().error_code(), kCancellationStatus.error_code()); +} + +TEST_F(MockStubTest, CannotReturnStatusForCancelledRPC) { + MockClientContext ctx; + auto [clientStream, rpc] = makeRPC(ctx); + + ASSERT_TRUE(clientStream->write(makeUniqueMessage().sharedBuffer())); + + rpc.cancel(::grpc::Status::CANCELLED); + rpc.sendReturnStatus(::grpc::Status::OK); + + ASSERT_TRUE(rpc.serverCtx->isCancelled()); + ASSERT_FALSE(clientStream->write(makeUniqueMessage().sharedBuffer())); + ASSERT_FALSE(clientStream->read()); + ASSERT_EQ(clientStream->finish().error_code(), ::grpc::StatusCode::CANCELLED); +} + +} // namespace mongo::transport::grpc diff --git a/src/mongo/transport/grpc/mock_util.h b/src/mongo/transport/grpc/mock_util.h index 456437cfe1f..2e0fc9c02ed 100644 --- a/src/mongo/transport/grpc/mock_util.h +++ b/src/mongo/transport/grpc/mock_util.h @@ -32,13 +32,62 @@ #include #include +#include +#include +#include + #include "mongo/db/operation_context.h" #include "mongo/db/service_context.h" #include "mongo/util/interruptible.h" +#include "mongo/util/synchronized_value.h" #include "mongo/util/time_support.h" namespace mongo::transport::grpc { +/** + * Class containing the shared cancellation state between a MockServerStream and its corresponding + * MockClientStream. This mocks cases in which an RPC is terminated before the server's RPC handler + * is able to return a status (e.g. explicit client/server cancellation or a network error). + */ +class MockCancellationState { +public: + explicit MockCancellationState(Date_t deadline) : _deadline(deadline) {} + + Date_t getDeadline() const { + return _deadline; + } + + bool isCancelled() const { + return _cancellationStatus->has_value() || isDeadlineExceeded(); + } + + bool isDeadlineExceeded() const { + return getGlobalServiceContext()->getFastClockSource()->now() > _deadline; + } + + boost::optional<::grpc::Status> getCancellationStatus() { + if (auto status = _cancellationStatus.synchronize(); status->has_value()) { + return *status; + } else if (isDeadlineExceeded()) { + return ::grpc::Status(::grpc::StatusCode::DEADLINE_EXCEEDED, "Deadline exceeded"); + } else { + return boost::none; + } + } + + void cancel(::grpc::Status status) { + auto statusGuard = _cancellationStatus.synchronize(); + if (statusGuard->has_value()) { + return; + } + *statusGuard = std::move(status); + } + +private: + Date_t _deadline; + synchronized_value> _cancellationStatus; +}; + /** * Performs the provided lambda and returns its return value. If the lambda's execution is * interrupted due to the deadline being exceeded, this returns a default-constructed T instead. diff --git a/src/mongo/transport/grpc/server_context.h b/src/mongo/transport/grpc/server_context.h index a957bc26582..97e467c5260 100644 --- a/src/mongo/transport/grpc/server_context.h +++ b/src/mongo/transport/grpc/server_context.h @@ -56,7 +56,7 @@ public: virtual void addInitialMetadataEntry(const std::string& key, const std::string& value) = 0; virtual const MetadataView& getClientMetadata() const = 0; virtual Date_t getDeadline() const = 0; - virtual HostAndPort getHostAndPort() const = 0; + virtual HostAndPort getRemote() const = 0; /** * Attempt to cancel the RPC this context is associated with. This may not have an effect if the @@ -65,6 +65,15 @@ public: * This is thread-safe. */ virtual void tryCancel() = 0; + + /** + * Return true if the RPC associated with this ServerContext failed before the RPC handler could + * return its final status back to the client (e.g. due to explicit cancellation or a network + * issue). + * + * If the handler was able to return a status successfully, even if that status was + * Status::CANCELLED, then this method will return false. + */ virtual bool isCancelled() const = 0; }; diff --git a/src/mongo/transport/grpc/service_test.cpp b/src/mongo/transport/grpc/service_test.cpp index ef306ee6b46..aa637f84695 100644 --- a/src/mongo/transport/grpc/service_test.cpp +++ b/src/mongo/transport/grpc/service_test.cpp @@ -56,6 +56,7 @@ #include "mongo/unittest/thread_assertion_monitor.h" #include "mongo/unittest/unittest.h" #include "mongo/util/concurrency/notification.h" +#include "mongo/util/scopeguard.h" #define MONGO_LOGV2_DEFAULT_COMPONENT ::mongo::logv2::LogComponent::kTest @@ -194,6 +195,43 @@ TEST_F(CommandServiceTest, Echo) { }); } +TEST_F(CommandServiceTest, SessionTerminate) { + const int kMessageCount = 5; + auto termination = std::make_unique>(); + + auto handler = [&termination](IngressSession& session) { + for (int i = 0; i < kMessageCount; i++) { + auto status = session.sinkMessage(makeUniqueMessage()); + ASSERT_OK(status); + } + session.terminate(Status(ErrorCodes::StreamTerminated, "dummy error")); + termination->set(); + ASSERT_NOT_OK(session.sinkMessage(makeUniqueMessage())); + return ::grpc::Status::CANCELLED; + }; + + runTestWithBothMethods(handler, [&](auto&, auto& monitor, auto methodCallback) { + ::grpc::ClientContext ctx; + CommandServiceTestFixtures::addAllClientMetadata(ctx); + auto stream = methodCallback(ctx); + + // Messages sent before the RPC was cancelled should be able to be read. + termination->get(); + ON_BLOCK_EXIT([&] { + // Reset the termination notification for the next test run. + termination = std::make_unique>(); + }); + for (int i = 0; i < kMessageCount; i++) { + SharedBuffer b; + ASSERT_TRUE(stream->Read(&b)); + } + + SharedBuffer b; + ASSERT_FALSE(stream->Read(&b)); + ASSERT_EQ(stream->Finish().error_code(), ::grpc::StatusCode::CANCELLED); + }); +} + TEST_F(CommandServiceTest, NewClientsAreLogged) { runMetadataLogTest( [clientId = UUID::gen().toString()](::grpc::ClientContext& ctx) { diff --git a/src/mongo/transport/grpc/test_fixtures.h b/src/mongo/transport/grpc/test_fixtures.h index 56735b37997..eaa53fb0733 100644 --- a/src/mongo/transport/grpc/test_fixtures.h +++ b/src/mongo/transport/grpc/test_fixtures.h @@ -49,6 +49,7 @@ #include "mongo/transport/grpc/mock_client_stream.h" #include "mongo/transport/grpc/mock_server_context.h" #include "mongo/transport/grpc/mock_server_stream.h" +#include "mongo/transport/grpc/mock_stub.h" #include "mongo/transport/grpc/server.h" #include "mongo/transport/grpc/service.h" #include "mongo/transport/grpc/util.h" @@ -71,28 +72,59 @@ inline Message makeUniqueMessage() { } struct MockStreamTestFixtures { - MockStreamTestFixtures(HostAndPort hostAndPort, - Milliseconds timeout, - MetadataView clientMetadata) { - BidirectionalPipe pipe; - auto promiseAndFuture = makePromiseFuture(); - - serverStream = std::make_unique(hostAndPort, - timeout, - std::move(promiseAndFuture.promise), - std::move(*pipe.left), - clientMetadata); - serverCtx = std::make_unique(serverStream.get()); - - clientStream = std::make_unique( - hostAndPort, timeout, std::move(promiseAndFuture.future), std::move(*pipe.right)); - clientCtx = std::make_unique(clientStream.get()); + std::shared_ptr clientStream; + std::shared_ptr clientCtx; + std::shared_ptr serverStream; + std::shared_ptr serverCtx; +}; + +class MockStubTestFixtures { +public: + static constexpr auto kBindAddress = "localhost:1234"; + static constexpr auto kClientAddress = "abc:5678"; + + MockStubTestFixtures() { + MockRPCQueue::Pipe pipe; + + _channel = std::make_shared( + HostAndPort(kClientAddress), HostAndPort(kBindAddress), std::move(pipe.producer)); + _server = std::make_unique(std::move(pipe.consumer)); } - std::unique_ptr clientStream; - std::unique_ptr clientCtx; - std::unique_ptr serverStream; - std::unique_ptr serverCtx; + MockStub makeStub() { + return MockStub(_channel); + } + + std::unique_ptr makeStreamTestFixtures( + Date_t deadline, const MetadataView& clientMetadata) { + MockStreamTestFixtures fixtures{ + nullptr, std::make_shared(), nullptr, nullptr}; + + auto clientThread = stdx::thread([&] { + fixtures.clientCtx->setDeadline(deadline); + for (auto& kvp : clientMetadata) { + fixtures.clientCtx->addMetadataEntry(kvp.first.toString(), kvp.second.toString()); + } + fixtures.clientStream = + makeStub().unauthenticatedCommandStream(fixtures.clientCtx.get()); + }); + + auto rpc = getServer().acceptRPC(); + ASSERT_TRUE(rpc); + fixtures.serverCtx = std::move(rpc->serverCtx); + fixtures.serverStream = std::move(rpc->serverStream); + clientThread.join(); + + return std::make_unique(std::move(fixtures)); + } + + MockServer& getServer() { + return *_server; + } + +private: + std::unique_ptr _server; + std::shared_ptr _channel; }; class ServiceContextWithClockSourceMockTest : public LockerNoopServiceContextTest { diff --git a/src/mongo/util/producer_consumer_queue.h b/src/mongo/util/producer_consumer_queue.h index 0c5118e52d6..bef44b6f29f 100644 --- a/src/mongo/util/producer_consumer_queue.h +++ b/src/mongo/util/producer_consumer_queue.h @@ -439,10 +439,11 @@ public: size_t waitingConsumers; size_t waitingProducers; size_t producerQueueDepth; + bool producerEndClosed; + bool consumerEndClosed; // TODO more stats // // totalTimeBlocked on either side - // closed ends // count of producers and consumers (blocked, or existing if we're a pipe) }; @@ -629,6 +630,8 @@ public: stats.waitingConsumers = _consumers; stats.waitingProducers = _producers; stats.producerQueueDepth = _producers.queueDepth(); + stats.producerEndClosed = _producerEndClosed; + stats.consumerEndClosed = _consumerEndClosed; return stats; } -- cgit v1.2.1