diff options
author | samantharitter <samantha.ritter@10gen.com> | 2016-11-04 14:45:32 -0400 |
---|---|---|
committer | samantharitter <samantha.ritter@10gen.com> | 2016-11-05 21:26:59 -0400 |
commit | 0ac04999faae1d2fc0e10972aaf21082a2e48c8f (patch) | |
tree | d9b74efcf36c5381469cc622c3aea4c0f8166398 /src | |
parent | 2d1dd9e07a40f314853e29bffb56b45bf21df940 (diff) | |
download | mongo-0ac04999faae1d2fc0e10972aaf21082a2e48c8f.tar.gz |
SERVER-26674 transport::Session objects should be shared_ptr managed
Diffstat (limited to 'src')
36 files changed, 291 insertions, 292 deletions
diff --git a/src/mongo/client/scoped_db_conn_test.cpp b/src/mongo/client/scoped_db_conn_test.cpp index 0126e0e900b..cebfe48cd16 100644 --- a/src/mongo/client/scoped_db_conn_test.cpp +++ b/src/mongo/client/scoped_db_conn_test.cpp @@ -87,14 +87,14 @@ public: } } - void startSession(transport::Session&& session) override { + void startSession(transport::SessionHandle session) override { _threads.emplace_back(&DummyServiceEntryPoint::run, this, std::move(session)); } private: - void run(transport::Session&& session) { + void run(transport::SessionHandle session) { Message inMessage; - if (!session.sourceMessage(&inMessage).wait().isOK()) { + if (!session->sourceMessage(&inMessage).wait().isOK()) { return; } @@ -117,7 +117,7 @@ private: response.header().setResponseToMsgId(inMessage.header().getId()); - if (!session.sinkMessage(response).wait().isOK()) { + if (!session->sinkMessage(response).wait().isOK()) { return; } } diff --git a/src/mongo/db/auth/authorization_manager_test.cpp b/src/mongo/db/auth/authorization_manager_test.cpp index 9e9312d0694..50f823f014a 100644 --- a/src/mongo/db/auth/authorization_manager_test.cpp +++ b/src/mongo/db/auth/authorization_manager_test.cpp @@ -244,9 +244,9 @@ TEST_F(AuthorizationManagerTest, testAcquireV2User) { TEST_F(AuthorizationManagerTest, testLocalX509Authorization) { ServiceContextNoop serviceContext; transport::TransportLayerMock transportLayer{}; - transport::Session* session = transportLayer.createSession(); + transport::SessionHandle session = transportLayer.createSession(); transportLayer.setX509PeerInfo( - *session, + session, SSLPeerInfo("CN=mongodb.com", {RoleName("read", "test"), RoleName("readWrite", "test")})); ServiceContext::UniqueClient client = serviceContext.makeClient("testClient", session); ServiceContext::UniqueOperationContext txn = client->makeOperationContext(); @@ -278,9 +278,9 @@ TEST_F(AuthorizationManagerTest, testLocalX509Authorization) { TEST_F(AuthorizationManagerTest, testLocalX509AuthorizationInvalidUser) { ServiceContextNoop serviceContext; transport::TransportLayerMock transportLayer{}; - transport::Session* session = transportLayer.createSession(); + transport::SessionHandle session = transportLayer.createSession(); transportLayer.setX509PeerInfo( - *session, + session, SSLPeerInfo("CN=mongodb.com", {RoleName("read", "test"), RoleName("write", "test")})); ServiceContext::UniqueClient client = serviceContext.makeClient("testClient", session); ServiceContext::UniqueOperationContext txn = client->makeOperationContext(); @@ -293,8 +293,8 @@ TEST_F(AuthorizationManagerTest, testLocalX509AuthorizationInvalidUser) { TEST_F(AuthorizationManagerTest, testLocalX509AuthenticationNoAuthorization) { ServiceContextNoop serviceContext; transport::TransportLayerMock transportLayer{}; - transport::Session* session = transportLayer.createSession(); - transportLayer.setX509PeerInfo(*session, {}); + transport::SessionHandle session = transportLayer.createSession(); + transportLayer.setX509PeerInfo(session, {}); ServiceContext::UniqueClient client = serviceContext.makeClient("testClient", session); ServiceContext::UniqueOperationContext txn = client->makeOperationContext(); diff --git a/src/mongo/db/client.cpp b/src/mongo/db/client.cpp index 179b0938d5b..643a25fda05 100644 --- a/src/mongo/db/client.cpp +++ b/src/mongo/db/client.cpp @@ -44,7 +44,6 @@ #include "mongo/db/lasterror.h" #include "mongo/db/service_context.h" #include "mongo/stdx/thread.h" -#include "mongo/transport/session.h" #include "mongo/util/concurrency/thread_name.h" #include "mongo/util/exit.h" #include "mongo/util/mongoutils/str.h" @@ -64,11 +63,13 @@ void Client::initThreadIfNotAlready() { initThreadIfNotAlready(getThreadName().c_str()); } -void Client::initThread(const char* desc, transport::Session* session) { - initThread(desc, getGlobalServiceContext(), session); +void Client::initThread(const char* desc, transport::SessionHandle session) { + initThread(desc, getGlobalServiceContext(), std::move(session)); } -void Client::initThread(const char* desc, ServiceContext* service, transport::Session* session) { +void Client::initThread(const char* desc, + ServiceContext* service, + transport::SessionHandle session) { invariant(currentClient.getMake()->get() == nullptr); std::string fullDesc; @@ -81,7 +82,7 @@ void Client::initThread(const char* desc, ServiceContext* service, transport::Se setThreadName(fullDesc.c_str()); // Create the client obj, attach to thread - *currentClient.get() = service->makeClient(fullDesc, session); + *currentClient.get() = service->makeClient(fullDesc, std::move(session)); } void Client::destroy() { @@ -99,12 +100,12 @@ int64_t generateSeed(const std::string& desc) { } } // namespace -Client::Client(std::string desc, ServiceContext* serviceContext, transport::Session* session) +Client::Client(std::string desc, ServiceContext* serviceContext, transport::SessionHandle session) : _serviceContext(serviceContext), - _session(session), + _session(std::move(session)), _desc(std::move(desc)), _threadId(stdx::this_thread::get_id()), - _connectionId(session ? session->id() : 0), + _connectionId(_session ? _session->id() : 0), _prng(generateSeed(_desc)) {} void Client::reportState(BSONObjBuilder& builder) { diff --git a/src/mongo/db/client.h b/src/mongo/db/client.h index 8e313379b9c..50d4f5de541 100644 --- a/src/mongo/db/client.h +++ b/src/mongo/db/client.h @@ -56,10 +56,6 @@ class AbstractMessagingPort; class Collection; class OperationContext; -namespace transport { -class Session; -} // namespace transport - typedef long long ConnectionId; /** @@ -73,13 +69,12 @@ public: * An unowned pointer to a transport::Session may optionally be provided. If 'session' * is non-null, then it will be used to augment the thread name, and for reporting purposes. * - * If provided, 'session' must outlive the newly-created Client object. Client::destroy() may - * be used to help enforce that the Client does not outlive 'session.' + * If provided, session's ref count will be bumped by this Client. */ - static void initThread(const char* desc, transport::Session* session = nullptr); + static void initThread(const char* desc, transport::SessionHandle session = nullptr); static void initThread(const char* desc, ServiceContext* serviceContext, - transport::Session* session); + transport::SessionHandle session); static Client* getCurrent(); @@ -91,7 +86,7 @@ public: } bool hasRemote() const { - return _session; + return (_session != nullptr); } HostAndPort getRemote() const { @@ -109,10 +104,14 @@ public: /** * Returns the Session to which this client is bound, if any. */ - transport::Session* session() const { + const transport::SessionHandle& session() const& { return _session; } + transport::SessionHandle session() && { + return std::move(_session); + } + /** * Inits a thread if that thread has not already been init'd, setting the thread name to * "desc". @@ -202,10 +201,12 @@ public: private: friend class ServiceContext; - Client(std::string desc, ServiceContext* serviceContext, transport::Session* session = nullptr); + explicit Client(std::string desc, + ServiceContext* serviceContext, + transport::SessionHandle session); ServiceContext* const _serviceContext; - transport::Session* const _session; + const transport::SessionHandle _session; // Description for the client (e.g. conn8) const std::string _desc; diff --git a/src/mongo/db/dbmessage.cpp b/src/mongo/db/dbmessage.cpp index 182d538764b..1b991041094 100644 --- a/src/mongo/db/dbmessage.cpp +++ b/src/mongo/db/dbmessage.cpp @@ -177,7 +177,7 @@ OpQueryReplyBuilder::OpQueryReplyBuilder() : _buffer(32768) { _buffer.skip(sizeof(QueryResult::Value)); } -void OpQueryReplyBuilder::send(transport::Session* session, +void OpQueryReplyBuilder::send(const transport::SessionHandle& session, int queryResultFlags, const Message& requestMsg, int nReturned, @@ -192,7 +192,8 @@ void OpQueryReplyBuilder::send(transport::Session* session, uassertStatusOK(session->sinkMessage(response).wait()); } -void OpQueryReplyBuilder::sendCommandReply(transport::Session* session, const Message& requestMsg) { +void OpQueryReplyBuilder::sendCommandReply(const transport::SessionHandle& session, + const Message& requestMsg) { send(session, /*queryFlags*/ 0, requestMsg, /*nReturned*/ 1); } @@ -209,7 +210,7 @@ void OpQueryReplyBuilder::putInMessage( } void replyToQuery(int queryResultFlags, - transport::Session* session, + const transport::SessionHandle& session, Message& requestMsg, const void* data, int size, @@ -222,7 +223,7 @@ void replyToQuery(int queryResultFlags, } void replyToQuery(int queryResultFlags, - transport::Session* session, + const transport::SessionHandle& session, Message& requestMsg, const BSONObj& responseObj) { replyToQuery(queryResultFlags, diff --git a/src/mongo/db/dbmessage.h b/src/mongo/db/dbmessage.h index 63d3aaea350..405b96c35af 100644 --- a/src/mongo/db/dbmessage.h +++ b/src/mongo/db/dbmessage.h @@ -35,6 +35,7 @@ #include "mongo/client/constants.h" #include "mongo/db/jsobj.h" #include "mongo/db/server_options.h" +#include "mongo/transport/session.h" #include "mongo/util/net/abstract_message_port.h" #include "mongo/util/net/message.h" @@ -42,10 +43,6 @@ namespace mongo { class OperationContext; -namespace transport { -class Session; -} // namespace transport - /* db response format Query or GetMore: // see struct QueryResult @@ -363,7 +360,7 @@ public: /** * Finishes the reply and sends the message out to 'destination'. */ - void send(transport::Session* session, + void send(const transport::SessionHandle& session, int queryResultFlags, const Message& requestMsg, int nReturned, @@ -373,14 +370,14 @@ public: /** * Similar to send() but used for replying to a command. */ - void sendCommandReply(transport::Session* session, const Message& requestMsg); + void sendCommandReply(const transport::SessionHandle& session, const Message& requestMsg); private: BufBuilder _buffer; }; void replyToQuery(int queryResultFlags, - transport::Session* session, + const transport::SessionHandle& session, Message& requestMsg, const void* data, int size, @@ -390,7 +387,7 @@ void replyToQuery(int queryResultFlags, /* object reply helper. */ void replyToQuery(int queryResultFlags, - transport::Session* session, + const transport::SessionHandle& session, Message& requestMsg, const BSONObj& responseObj); diff --git a/src/mongo/db/repl/repl_set_request_votes.cpp b/src/mongo/db/repl/repl_set_request_votes.cpp index c01628466c0..02eb5311cb3 100644 --- a/src/mongo/db/repl/repl_set_request_votes.cpp +++ b/src/mongo/db/repl/repl_set_request_votes.cpp @@ -67,7 +67,7 @@ private: // We want to keep request vote connection open when relinquishing primary. // Tag it here. transport::Session::TagMask originalTag = 0; - transport::Session* session = txn->getClient()->session(); + auto session = txn->getClient()->session(); if (session) { originalTag = session->getTags(); session->replaceTags(originalTag | transport::Session::kKeepOpen); diff --git a/src/mongo/db/repl/replset_commands.cpp b/src/mongo/db/repl/replset_commands.cpp index f52b29d360a..528ce49dbea 100644 --- a/src/mongo/db/repl/replset_commands.cpp +++ b/src/mongo/db/repl/replset_commands.cpp @@ -740,7 +740,7 @@ public: /* we want to keep heartbeat connections open when relinquishing primary. tag them here. */ transport::Session::TagMask originalTag = 0; - transport::Session* session = txn->getClient()->session(); + auto session = txn->getClient()->session(); if (session) { originalTag = session->getTags(); session->replaceTags(originalTag | transport::Session::kKeepOpen); diff --git a/src/mongo/db/service_context.cpp b/src/mongo/db/service_context.cpp index ffaa0f7f128..09af7608dd0 100644 --- a/src/mongo/db/service_context.cpp +++ b/src/mongo/db/service_context.cpp @@ -130,8 +130,8 @@ ServiceContext::~ServiceContext() { } ServiceContext::UniqueClient ServiceContext::makeClient(std::string desc, - transport::Session* session) { - std::unique_ptr<Client> client(new Client(std::move(desc), this, session)); + transport::SessionHandle session) { + std::unique_ptr<Client> client(new Client(std::move(desc), this, std::move(session))); auto observer = _clientObservers.cbegin(); try { for (; observer != _clientObservers.cend(); ++observer) { diff --git a/src/mongo/db/service_context.h b/src/mongo/db/service_context.h index 12b205dba0b..c1895bf4065 100644 --- a/src/mongo/db/service_context.h +++ b/src/mongo/db/service_context.h @@ -37,6 +37,7 @@ #include "mongo/platform/unordered_set.h" #include "mongo/stdx/functional.h" #include "mongo/stdx/mutex.h" +#include "mongo/transport/session.h" #include "mongo/util/clock_source.h" #include "mongo/util/decorable.h" #include "mongo/util/tick_source.h" @@ -50,7 +51,6 @@ class OpObserver; class ServiceEntryPoint; namespace transport { -class Session; class TransportLayer; class TransportLayerManager; } // namespace transport @@ -216,7 +216,7 @@ public: * * If supplied, "session" is the transport::Session used for communicating with the client. */ - UniqueClient makeClient(std::string desc, transport::Session* session = nullptr); + UniqueClient makeClient(std::string desc, transport::SessionHandle session = nullptr); /** * Creates a new OperationContext on "client". diff --git a/src/mongo/db/service_entry_point_mongod.cpp b/src/mongo/db/service_entry_point_mongod.cpp index 310c939c074..ac00e12d9ac 100644 --- a/src/mongo/db/service_entry_point_mongod.cpp +++ b/src/mongo/db/service_entry_point_mongod.cpp @@ -92,16 +92,19 @@ using transport::TransportLayer; ServiceEntryPointMongod::ServiceEntryPointMongod(TransportLayer* tl) : _tl(tl) {} -void ServiceEntryPointMongod::startSession(Session&& session) { - launchWrappedServiceEntryWorkerThread(std::move(session), [this](Session* session) { - _nWorkers.fetchAndAdd(1); - auto guard = MakeGuard([&] { _nWorkers.fetchAndSubtract(1); }); - - _sessionLoop(session); - }); +void ServiceEntryPointMongod::startSession(transport::SessionHandle session) { + // Pass ownership of the transport::SessionHandle into our worker thread. When this + // thread exits, the session will end. + launchWrappedServiceEntryWorkerThread( + std::move(session), [this](const transport::SessionHandle& session) { + _nWorkers.fetchAndAdd(1); + auto guard = MakeGuard([&] { _nWorkers.fetchAndSubtract(1); }); + + _sessionLoop(session); + }); } -void ServiceEntryPointMongod::_sessionLoop(Session* session) { +void ServiceEntryPointMongod::_sessionLoop(const transport::SessionHandle& session) { Message inMessage; bool inExhaust = false; int64_t counter = 0; diff --git a/src/mongo/db/service_entry_point_mongod.h b/src/mongo/db/service_entry_point_mongod.h index 95a24de92fa..4917430f02f 100644 --- a/src/mongo/db/service_entry_point_mongod.h +++ b/src/mongo/db/service_entry_point_mongod.h @@ -53,14 +53,14 @@ public: virtual ~ServiceEntryPointMongod() = default; - void startSession(transport::Session&& session) override; + void startSession(transport::SessionHandle session) override; std::size_t getNumberOfActiveWorkerThreads() const { return _nWorkers.load(); } private: - void _sessionLoop(transport::Session* session); + void _sessionLoop(const transport::SessionHandle& session); transport::TransportLayer* _tl; AtomicWord<std::size_t> _nWorkers; diff --git a/src/mongo/s/commands/SConscript b/src/mongo/s/commands/SConscript index 7722be731f1..2d8b0f7bf34 100644 --- a/src/mongo/s/commands/SConscript +++ b/src/mongo/s/commands/SConscript @@ -84,5 +84,6 @@ env.Library( '$BUILD_DIR/mongo/rpc/client_metadata', '$BUILD_DIR/mongo/s/coreshard', '$BUILD_DIR/mongo/s/mongoscore', + '$BUILD_DIR/mongo/transport/transport_layer_common', ] ) diff --git a/src/mongo/s/commands/request.cpp b/src/mongo/s/commands/request.cpp index 6605db10424..03dc61f8003 100644 --- a/src/mongo/s/commands/request.cpp +++ b/src/mongo/s/commands/request.cpp @@ -123,7 +123,7 @@ void Request::process(OperationContext* txn, int attempt) { << " op: " << op << " attempt: " << attempt; } -transport::Session* Request::session() const { +const transport::SessionHandle& Request::session() const { return _clientInfo->session(); } diff --git a/src/mongo/s/commands/request.h b/src/mongo/s/commands/request.h index ace3e77a720..6b45e0210ec 100644 --- a/src/mongo/s/commands/request.h +++ b/src/mongo/s/commands/request.h @@ -30,6 +30,7 @@ #pragma once #include "mongo/db/dbmessage.h" +#include "mongo/transport/session.h" #include "mongo/util/net/message.h" namespace mongo { @@ -37,10 +38,6 @@ namespace mongo { class Client; class OperationContext; -namespace transport { -class Session; -} // namespace transport - class Request { MONGO_DISALLOW_COPYING(Request); @@ -76,7 +73,7 @@ public: return _d; } - transport::Session* session() const; + const transport::SessionHandle& session() const; void process(OperationContext* txn, int attempt = 0); diff --git a/src/mongo/s/service_entry_point_mongos.cpp b/src/mongo/s/service_entry_point_mongos.cpp index 427152c8adb..33482db8e72 100644 --- a/src/mongo/s/service_entry_point_mongos.cpp +++ b/src/mongo/s/service_entry_point_mongos.cpp @@ -73,12 +73,13 @@ using transport::TransportLayer; ServiceEntryPointMongos::ServiceEntryPointMongos(TransportLayer* tl) : _tl(tl) {} -void ServiceEntryPointMongos::startSession(Session&& session) { - launchWrappedServiceEntryWorkerThread(std::move(session), - [this](Session* session) { _sessionLoop(session); }); +void ServiceEntryPointMongos::startSession(transport::SessionHandle session) { + launchWrappedServiceEntryWorkerThread( + std::move(session), + [this](const transport::SessionHandle& session) { _sessionLoop(session); }); } -void ServiceEntryPointMongos::_sessionLoop(Session* session) { +void ServiceEntryPointMongos::_sessionLoop(const transport::SessionHandle& session) { Message message; int64_t counter = 0; diff --git a/src/mongo/s/service_entry_point_mongos.h b/src/mongo/s/service_entry_point_mongos.h index d96540b2b1e..acf42d89fc9 100644 --- a/src/mongo/s/service_entry_point_mongos.h +++ b/src/mongo/s/service_entry_point_mongos.h @@ -52,10 +52,10 @@ public: virtual ~ServiceEntryPointMongos() = default; - void startSession(transport::Session&& session) override; + void startSession(transport::SessionHandle session) override; private: - void _sessionLoop(transport::Session* session); + void _sessionLoop(const transport::SessionHandle& session); transport::TransportLayer* _tl; }; diff --git a/src/mongo/s/sharding_test_fixture.cpp b/src/mongo/s/sharding_test_fixture.cpp index 41c8936f49c..715e6d9bbfe 100644 --- a/src/mongo/s/sharding_test_fixture.cpp +++ b/src/mongo/s/sharding_test_fixture.cpp @@ -101,9 +101,8 @@ void ShardingTestFixture::setUp() { _transportLayer = tlMock.get(); _service->addAndStartTransportLayer(std::move(tlMock)); CollatorFactoryInterface::set(_service.get(), stdx::make_unique<CollatorFactoryMock>()); - _transportSession = - stdx::make_unique<transport::Session>(HostAndPort{}, HostAndPort{}, _transportLayer); - _client = _service->makeClient("ShardingTestFixture", _transportSession.get()); + _transportSession = transport::Session::create(HostAndPort{}, HostAndPort{}, _transportLayer); + _client = _service->makeClient("ShardingTestFixture", _transportSession); _opCtx = _client->makeOperationContext(); // Set up executor pool used for most operations. @@ -495,7 +494,7 @@ void ShardingTestFixture::expectCount(const HostAndPort& configHost, } void ShardingTestFixture::setRemote(const HostAndPort& remote) { - *_transportSession = transport::Session{remote, HostAndPort{}, _transportLayer}; + _transportSession = transport::Session::create(remote, HostAndPort{}, _transportLayer); } void ShardingTestFixture::checkReadConcern(const BSONObj& cmdObj, diff --git a/src/mongo/s/sharding_test_fixture.h b/src/mongo/s/sharding_test_fixture.h index 6f4cece4dea..a48a56e0df3 100644 --- a/src/mongo/s/sharding_test_fixture.h +++ b/src/mongo/s/sharding_test_fixture.h @@ -32,6 +32,7 @@ #include "mongo/db/service_context.h" #include "mongo/executor/network_test_env.h" +#include "mongo/transport/session.h" #include "mongo/unittest/unittest.h" #include "mongo/util/net/message_port_mock.h" @@ -97,7 +98,7 @@ protected: executor::TaskExecutor* executor() const; - transport::Session* getTransportSession() const; + const transport::SessionHandle& getTransportSession() const; DistLockManagerMock* distLock() const; @@ -208,7 +209,7 @@ private: ServiceContext::UniqueClient _client; ServiceContext::UniqueOperationContext _opCtx; transport::TransportLayerMock* _transportLayer; - std::unique_ptr<transport::Session> _transportSession; + transport::SessionHandle _transportSession; RemoteCommandTargeterFactoryMock* _targeterFactory; RemoteCommandTargeterMock* _configTargeter; diff --git a/src/mongo/transport/service_entry_point.h b/src/mongo/transport/service_entry_point.h index a1c031c5e3e..e0a1ee7a907 100644 --- a/src/mongo/transport/service_entry_point.h +++ b/src/mongo/transport/service_entry_point.h @@ -29,13 +29,10 @@ #pragma once #include "mongo/base/disallow_copying.h" +#include "mongo/transport/session.h" namespace mongo { -namespace transport { -class Session; -} // namespace transport - /** * This is the entrypoint from the transport layer into mongod or mongos. * @@ -52,7 +49,7 @@ public: /** * Begin running a new Session. This method returns immediately. */ - virtual void startSession(transport::Session&& session) = 0; + virtual void startSession(transport::SessionHandle session) = 0; protected: ServiceEntryPoint() = default; diff --git a/src/mongo/transport/service_entry_point_mock.cpp b/src/mongo/transport/service_entry_point_mock.cpp index 8d4530d7f13..667b1d5ae5e 100644 --- a/src/mongo/transport/service_entry_point_mock.cpp +++ b/src/mongo/transport/service_entry_point_mock.cpp @@ -86,11 +86,11 @@ ServiceEntryPointMock::~ServiceEntryPointMock() { } } -void ServiceEntryPointMock::startSession(transport::Session&& session) { +void ServiceEntryPointMock::startSession(transport::SessionHandle session) { _threads.emplace_back(&ServiceEntryPointMock::run, this, std::move(session)); } -void ServiceEntryPointMock::run(transport::Session&& session) { +void ServiceEntryPointMock::run(transport::SessionHandle session) { Message inMessage; while (true) { { @@ -100,12 +100,12 @@ void ServiceEntryPointMock::run(transport::Session&& session) { } // sourceMessage() - if (!session.sourceMessage(&inMessage).wait().isOK()) { + if (!session->sourceMessage(&inMessage).wait().isOK()) { break; } // sinkMessage() - if (!session.sinkMessage(_outMessage).wait().isOK()) { + if (!session->sinkMessage(_outMessage).wait().isOK()) { break; } } diff --git a/src/mongo/transport/service_entry_point_mock.h b/src/mongo/transport/service_entry_point_mock.h index ceeca2a0985..eadc9367fa8 100644 --- a/src/mongo/transport/service_entry_point_mock.h +++ b/src/mongo/transport/service_entry_point_mock.h @@ -34,13 +34,13 @@ #include "mongo/stdx/mutex.h" #include "mongo/stdx/thread.h" #include "mongo/transport/service_entry_point.h" +#include "mongo/transport/session.h" #include "mongo/util/net/message.h" namespace mongo { namespace transport { -class Session; class TransportLayer; } // namespace transport @@ -63,10 +63,10 @@ public: * * ...repeat until wait() returns an error. */ - void startSession(transport::Session&& session) override; + void startSession(transport::SessionHandle session) override; private: - void run(transport::Session&& session); + void run(transport::SessionHandle session); transport::TransportLayer* _tl; diff --git a/src/mongo/transport/service_entry_point_test_suite.cpp b/src/mongo/transport/service_entry_point_test_suite.cpp index eefb9593218..4587ac8ea8b 100644 --- a/src/mongo/transport/service_entry_point_test_suite.cpp +++ b/src/mongo/transport/service_entry_point_test_suite.cpp @@ -90,7 +90,7 @@ void setPingCommand(Message* m) { } // Some default method implementations -const auto kDefaultEnd = [](const Session& session) { return; }; +const auto kDefaultEnd = [](const SessionHandle& session) { return; }; const auto kDefaultDestroyHook = [](Session& session) { return; }; const auto kDefaultAsyncWait = [](Ticket, TicketCallback cb) { cb(Status::OK()); }; const auto kNoopFunction = [] { return; }; @@ -100,13 +100,13 @@ const auto kEndConnectionStatus = Status(ErrorCodes::HostUnreachable, "connectio } // namespace -ServiceEntryPointTestSuite::MockTicket::MockTicket(const Session& session, +ServiceEntryPointTestSuite::MockTicket::MockTicket(const SessionHandle& session, Message* message, Date_t expiration) - : _message(message), _sessionId(session.id()), _expiration(expiration) {} + : _message(message), _sessionId(session->id()), _expiration(expiration) {} -ServiceEntryPointTestSuite::MockTicket::MockTicket(const Session& session, Date_t expiration) - : _sessionId(session.id()), _expiration(expiration) {} +ServiceEntryPointTestSuite::MockTicket::MockTicket(const SessionHandle& session, Date_t expiration) + : _sessionId(session->id()), _expiration(expiration) {} Session::Id ServiceEntryPointTestSuite::MockTicket::sessionId() const { return _sessionId; @@ -129,13 +129,13 @@ ServiceEntryPointTestSuite::MockTLHarness::MockTLHarness() _asyncWait(kDefaultAsyncWait), _end(kDefaultEnd) {} -Ticket ServiceEntryPointTestSuite::MockTLHarness::sourceMessage(Session& session, +Ticket ServiceEntryPointTestSuite::MockTLHarness::sourceMessage(const SessionHandle& session, Message* message, Date_t expiration) { return _sourceMessage(session, message, expiration); } -Ticket ServiceEntryPointTestSuite::MockTLHarness::sinkMessage(Session& session, +Ticket ServiceEntryPointTestSuite::MockTLHarness::sinkMessage(const SessionHandle& session, const Message& message, Date_t expiration) { return _sinkMessage(session, message, expiration); @@ -151,17 +151,17 @@ void ServiceEntryPointTestSuite::MockTLHarness::asyncWait(Ticket&& ticket, } SSLPeerInfo ServiceEntryPointTestSuite::MockTLHarness::getX509PeerInfo( - const Session& session) const { + const ConstSessionHandle& session) const { return SSLPeerInfo("mock", stdx::unordered_set<RoleName>{}); } -void ServiceEntryPointTestSuite::MockTLHarness::registerTags(const Session& session) {} +void ServiceEntryPointTestSuite::MockTLHarness::registerTags(const ConstSessionHandle& session) {} TransportLayer::Stats ServiceEntryPointTestSuite::MockTLHarness::sessionStats() { return Stats(); } -void ServiceEntryPointTestSuite::MockTLHarness::end(Session& session) { +void ServiceEntryPointTestSuite::MockTLHarness::end(const SessionHandle& session) { return _end(session); } @@ -194,17 +194,19 @@ Status ServiceEntryPointTestSuite::MockTLHarness::_waitOnceThenError(transport:: return _defaultWait(std::move(ticket)); } -Ticket ServiceEntryPointTestSuite::MockTLHarness::_defaultSource(Session& s, Message* m, Date_t d) { +Ticket ServiceEntryPointTestSuite::MockTLHarness::_defaultSource(const SessionHandle& s, + Message* m, + Date_t d) { return Ticket(this, stdx::make_unique<ServiceEntryPointTestSuite::MockTicket>(s, m, d)); } -Ticket ServiceEntryPointTestSuite::MockTLHarness::_defaultSink(Session& s, +Ticket ServiceEntryPointTestSuite::MockTLHarness::_defaultSink(const SessionHandle& s, const Message&, Date_t d) { return Ticket(this, stdx::make_unique<ServiceEntryPointTestSuite::MockTicket>(s, d)); } -Ticket ServiceEntryPointTestSuite::MockTLHarness::_sinkThenErrorOnWait(Session& s, +Ticket ServiceEntryPointTestSuite::MockTLHarness::_sinkThenErrorOnWait(const SessionHandle& s, const Message& m, Date_t d) { _wait = stdx::bind(&ServiceEntryPointTestSuite::MockTLHarness::_waitOnceThenError, this, _1); @@ -251,10 +253,10 @@ void ServiceEntryPointTestSuite::noLifeCycleTest() { _tl->_wait = stdx::bind(&ServiceEntryPointTestSuite::MockTLHarness::_waitError, _tl.get(), _1); // Step 3: SEP destroys the session, which calls end() - _tl->_destroy_hook = [&testComplete](const Session&) { testComplete.set_value(); }; + _tl->_destroy_hook = [&testComplete](Session&) { testComplete.set_value(); }; // Kick off the SEP - Session s(HostAndPort(), HostAndPort(), _tl.get()); + auto s = Session::create(HostAndPort(), HostAndPort(), _tl.get()); _sep->startSession(std::move(s)); testFuture.wait(); @@ -270,7 +272,7 @@ void ServiceEntryPointTestSuite::halfLifeCycleTest() { // Step 1: SEP gets a ticket to source a Message // Step 2: SEP calls wait() on the ticket and receives a Message // Step 3: SEP gets a ticket to sink a Message - _tl->_sinkMessage = [this](Session& session, const Message& m, Date_t expiration) { + _tl->_sinkMessage = [this](const SessionHandle& session, const Message& m, Date_t expiration) { // Step 4: SEP calls wait() on the ticket and receives an error _tl->_wait = @@ -280,10 +282,10 @@ void ServiceEntryPointTestSuite::halfLifeCycleTest() { }; // Step 5: SEP destroys the session, which calls end() - _tl->_destroy_hook = [&testComplete](const Session&) { testComplete.set_value(); }; + _tl->_destroy_hook = [&testComplete](Session&) { testComplete.set_value(); }; // Kick off the SEP - Session s(HostAndPort(), HostAndPort(), _tl.get()); + auto s = Session::create(HostAndPort(), HostAndPort(), _tl.get()); _sep->startSession(std::move(s)); testFuture.wait(); @@ -306,20 +308,20 @@ void ServiceEntryPointTestSuite::fullLifeCycleTest() { // Step 5: SEP gets a ticket to source a Message // Step 6: SEP calls wait() on the ticket and receives and error // Step 7: SEP destroys the session, which calls end() - _tl->_destroy_hook = [&testComplete](const Session& session) { testComplete.set_value(); }; + _tl->_destroy_hook = [&testComplete](Session& session) { testComplete.set_value(); }; // Kick off the SEP - Session s(HostAndPort(), HostAndPort(), _tl.get()); + auto s = Session::create(HostAndPort(), HostAndPort(), _tl.get()); _sep->startSession(std::move(s)); testFuture.wait(); } void ServiceEntryPointTestSuite::interruptingSessionTest() { - Session sA(HostAndPort(), HostAndPort(), _tl.get()); - Session sB(HostAndPort(), HostAndPort(), _tl.get()); - auto idA = sA.id(); - auto idB = sB.id(); + auto sA = Session::create(HostAndPort(), HostAndPort(), _tl.get()); + auto sB = Session::create(HostAndPort(), HostAndPort(), _tl.get()); + auto idA = sA->id(); + auto idB = sB->id(); int waitCountB = 0; stdx::promise<void> startB; @@ -366,7 +368,7 @@ void ServiceEntryPointTestSuite::interruptingSessionTest() { // Step 7: SEP calls sourceMessage() for B, gets tB3 // Step 8: SEP calls wait() for tB3, gets an error // Step 9: SEP calls end(B) - _tl->_destroy_hook = [this, idA, idB, &resumeA, &testComplete](const Session& session) { + _tl->_destroy_hook = [this, idA, idB, &resumeA, &testComplete](Session& session) { // When end(B) is called, time to resume session A if (session.id() == idB) { // Resume session A @@ -450,18 +452,18 @@ void ServiceEntryPointTestSuite::burstStressTest(int numSessions, }; // When we end the last session, end the test. - _tl->_destroy_hook = [&allSessionsComplete, numSessions, &ended](const Session& session) { + _tl->_destroy_hook = [&allSessionsComplete, numSessions, &ended](Session& session) { if (ended.fetchAndAdd(1) == (numSessions - 1)) { allSessionsComplete.set_value(); } }; for (int i = 0; i < numSessions; i++) { - Session s(HostAndPort(), HostAndPort(), _tl.get()); + auto s = Session::create(HostAndPort(), HostAndPort(), _tl.get()); { // This operation may cause a re-hash. stdx::lock_guard<stdx::mutex> lock(cyclesLock); - completedCycles.emplace(s.id(), 0); + completedCycles.emplace(s->id(), 0); } _sep->startSession(std::move(s)); } diff --git a/src/mongo/transport/service_entry_point_test_suite.h b/src/mongo/transport/service_entry_point_test_suite.h index ba1d60fd027..d1d3042cc5e 100644 --- a/src/mongo/transport/service_entry_point_test_suite.h +++ b/src/mongo/transport/service_entry_point_test_suite.h @@ -92,10 +92,10 @@ private: class MockTicket : public transport::TicketImpl { public: // Source constructor - MockTicket(const transport::Session& session, Message* message, Date_t expiration); + MockTicket(const transport::SessionHandle& session, Message* message, Date_t expiration); // Sink constructor - MockTicket(const transport::Session& session, Date_t expiration); + MockTicket(const transport::SessionHandle& session, Date_t expiration); MockTicket(MockTicket&&) = default; MockTicket& operator=(MockTicket&&) = default; @@ -121,19 +121,19 @@ private: MockTLHarness(); transport::Ticket sourceMessage( - transport::Session& session, + const transport::SessionHandle& session, Message* message, Date_t expiration = transport::Ticket::kNoExpirationDate) override; transport::Ticket sinkMessage( - transport::Session& session, + const transport::SessionHandle& session, const Message& message, Date_t expiration = transport::Ticket::kNoExpirationDate) override; Status wait(transport::Ticket&& ticket) override; void asyncWait(transport::Ticket&& ticket, TicketCallback callback) override; - SSLPeerInfo getX509PeerInfo(const transport::Session& session) const override; - void registerTags(const transport::Session& session) override; + SSLPeerInfo getX509PeerInfo(const transport::ConstSessionHandle& session) const override; + void registerTags(const transport::ConstSessionHandle& session) override; Stats sessionStats() override; - void end(transport::Session& session) override; + void end(const transport::SessionHandle& session) override; void endAllSessions(transport::Session::TagMask tags) override; Status start() override; void shutdown() override; @@ -141,11 +141,13 @@ private: ServiceEntryPointTestSuite::MockTicket* getMockTicket(const transport::Ticket& ticket); // Mocked method hooks - stdx::function<transport::Ticket(transport::Session&, Message*, Date_t)> _sourceMessage; - stdx::function<transport::Ticket(transport::Session&, const Message&, Date_t)> _sinkMessage; + stdx::function<transport::Ticket(const transport::SessionHandle&, Message*, Date_t)> + _sourceMessage; + stdx::function<transport::Ticket(const transport::SessionHandle&, const Message&, Date_t)> + _sinkMessage; stdx::function<Status(transport::Ticket)> _wait; stdx::function<void(transport::Ticket, TicketCallback)> _asyncWait; - stdx::function<void(const transport::Session&)> _end; + stdx::function<void(const transport::SessionHandle&)> _end; stdx::function<void(transport::Session& session)> _destroy_hook; stdx::function<void(transport::Session::TagMask tags)> _endAllSessions = [](transport::Session::TagMask tags) {}; @@ -153,9 +155,11 @@ private: stdx::function<void(void)> _shutdown = [] {}; // Pre-set hook methods - transport::Ticket _defaultSource(transport::Session& s, Message* m, Date_t d); - transport::Ticket _defaultSink(transport::Session& s, const Message&, Date_t d); - transport::Ticket _sinkThenErrorOnWait(transport::Session& s, const Message& m, Date_t d); + transport::Ticket _defaultSource(const transport::SessionHandle& s, Message* m, Date_t d); + transport::Ticket _defaultSink(const transport::SessionHandle& s, const Message&, Date_t d); + transport::Ticket _sinkThenErrorOnWait(const transport::SessionHandle& s, + const Message& m, + Date_t d); Status _defaultWait(transport::Ticket ticket); Status _waitError(transport::Ticket ticket); diff --git a/src/mongo/transport/service_entry_point_utils.cpp b/src/mongo/transport/service_entry_point_utils.cpp index ac0a6109e03..9d74de81a01 100644 --- a/src/mongo/transport/service_entry_point_utils.cpp +++ b/src/mongo/transport/service_entry_point_utils.cpp @@ -55,23 +55,27 @@ namespace mongo { namespace { +/** + * This object takes ownership of transport::SessionHandle. + */ struct Context { - Context(transport::Session session, stdx::function<void(transport::Session*)> task) + Context(transport::SessionHandle session, + stdx::function<void(const transport::SessionHandle&)> task) : session(std::move(session)), task(std::move(task)) {} - transport::Session session; - stdx::function<void(transport::Session*)> task; + transport::SessionHandle session; + stdx::function<void(const transport::SessionHandle&)> task; }; void* runFunc(void* ptr) { std::unique_ptr<Context> ctx(static_cast<Context*>(ptr)); - auto tl = ctx->session.getTransportLayer(); - Client::initThread("conn", &ctx->session); - setThreadName(std::string(str::stream() << "conn" << ctx->session.id())); + auto tl = ctx->session->getTransportLayer(); + Client::initThread("conn", ctx->session); + setThreadName(std::string(str::stream() << "conn" << ctx->session->id())); try { - ctx->task(&ctx->session); + ctx->task(ctx->session); } catch (const AssertionException& e) { log() << "AssertionException handling request, closing client connection: " << e; } catch (const SocketException& e) { @@ -89,7 +93,7 @@ void* runFunc(void* ptr) { if (!serverGlobalParams.quiet) { auto conns = tl->sessionStats().numOpenSessions; const char* word = (conns == 1 ? " connection" : " connections"); - log() << "end connection " << ctx->session.remote() << " (" << conns << word + log() << "end connection " << ctx->session->remote() << " (" << conns << word << " now open)"; } @@ -99,8 +103,8 @@ void* runFunc(void* ptr) { } } // namespace -void launchWrappedServiceEntryWorkerThread(transport::Session&& session, - stdx::function<void(transport::Session*)> task) { +void launchWrappedServiceEntryWorkerThread( + transport::SessionHandle session, stdx::function<void(const transport::SessionHandle&)> task) { auto ctx = stdx::make_unique<Context>(std::move(session), std::move(task)); try { @@ -147,7 +151,7 @@ void launchWrappedServiceEntryWorkerThread(transport::Session&& session, #endif // __linux__ } catch (...) { - log() << "failed to create service entry worker thread for " << ctx->session.remote(); + log() << "failed to create service entry worker thread for " << ctx->session->remote(); } } diff --git a/src/mongo/transport/service_entry_point_utils.h b/src/mongo/transport/service_entry_point_utils.h index 42be4079f9c..1c1634af6d5 100644 --- a/src/mongo/transport/service_entry_point_utils.h +++ b/src/mongo/transport/service_entry_point_utils.h @@ -29,14 +29,11 @@ #pragma once #include "mongo/stdx/functional.h" +#include "mongo/transport/session.h" namespace mongo { -namespace transport { -class Session; -} // namespace transport - -void launchWrappedServiceEntryWorkerThread(transport::Session&& session, - stdx::function<void(transport::Session*)> task); +void launchWrappedServiceEntryWorkerThread( + transport::SessionHandle session, stdx::function<void(const transport::SessionHandle&)> task); } // namespace mongo diff --git a/src/mongo/transport/session.cpp b/src/mongo/transport/session.cpp index b32c42609b8..4287d62cfd2 100644 --- a/src/mongo/transport/session.cpp +++ b/src/mongo/transport/session.cpp @@ -56,44 +56,26 @@ Session::~Session() { } } -Session::Session(Session&& other) - : _id(other._id), - _remote(std::move(other._remote)), - _local(std::move(other._local)), - _tl(other._tl) { - // We do not want to call tl->destroy() on moved-from Sessions. - other._tl = nullptr; -} - -Session& Session::operator=(Session&& other) { - if (&other == this) { - return *this; - } - - _id = other._id; - _remote = std::move(other._remote); - _local = std::move(other._local); - _tl = other._tl; - other._tl = nullptr; - - return *this; +SessionHandle Session::create(HostAndPort remote, HostAndPort local, TransportLayer* tl) { + std::shared_ptr<Session> handle(new Session(std::move(remote), std::move(local), tl)); + return handle; } void Session::replaceTags(TagMask tags) { _tags = tags; - _tl->registerTags(*this); + _tl->registerTags(shared_from_this()); } Ticket Session::sourceMessage(Message* message, Date_t expiration) { - return _tl->sourceMessage(*this, message, expiration); + return _tl->sourceMessage(shared_from_this(), message, expiration); } Ticket Session::sinkMessage(const Message& message, Date_t expiration) { - return _tl->sinkMessage(*this, message, expiration); + return _tl->sinkMessage(shared_from_this(), message, expiration); } SSLPeerInfo Session::getX509PeerInfo() const { - return _tl->getX509PeerInfo(*this); + return _tl->getX509PeerInfo(shared_from_this()); } } // namespace transport diff --git a/src/mongo/transport/session.h b/src/mongo/transport/session.h index 026a3445bb4..29e54333ceb 100644 --- a/src/mongo/transport/session.h +++ b/src/mongo/transport/session.h @@ -28,6 +28,8 @@ #pragma once +#include <memory> + #include "mongo/base/disallow_copying.h" #include "mongo/transport/message_compressor_manager.h" #include "mongo/transport/session_id.h" @@ -43,12 +45,16 @@ struct SSLPeerInfo; namespace transport { class TransportLayer; +class Session; + +using SessionHandle = std::shared_ptr<Session>; +using ConstSessionHandle = std::shared_ptr<const Session>; /** * This type contains data needed to associate Messages with connections * (on the transport side) and Messages with Client objects (on the database side). */ -class Session { +class Session : public std::enable_shared_from_this<Session> { MONGO_DISALLOW_COPYING(Session); public: @@ -68,20 +74,14 @@ public: static constexpr TagMask kKeepOpen = 1; /** - * Construct a new session. - */ - Session(HostAndPort remote, HostAndPort local, TransportLayer* tl); - - /** * Destroys a session, calling end() for this session in its TransportLayer. */ ~Session(); /** - * Move constructor and assignment operator. + * A factory for sessions. */ - Session(Session&& other); - Session& operator=(Session&& other); + static SessionHandle create(HostAndPort remote, HostAndPort local, TransportLayer* tl); /** * Return the id for this session. @@ -149,6 +149,11 @@ public: } private: + /** + * Construct a new session. + */ + Session(HostAndPort remote, HostAndPort local, TransportLayer* tl); + Id _id; HostAndPort _remote; diff --git a/src/mongo/transport/transport_layer.h b/src/mongo/transport/transport_layer.h index 3313a68108d..b09ff8a8eb9 100644 --- a/src/mongo/transport/transport_layer.h +++ b/src/mongo/transport/transport_layer.h @@ -101,7 +101,7 @@ public: * TransportLayer is unable to source a Message, this will be a failed status, * and the passed-in Message buffer may be left in an invalid state. */ - virtual Ticket sourceMessage(Session& session, + virtual Ticket sourceMessage(const SessionHandle& session, Message* message, Date_t expiration = Ticket::kNoExpirationDate) = 0; @@ -120,7 +120,7 @@ public: * This method does NOT take ownership of the sunk Message, which must be cleaned * up by the caller. */ - virtual Ticket sinkMessage(Session& session, + virtual Ticket sinkMessage(const SessionHandle& session, const Message& message, Date_t expiration = Ticket::kNoExpirationDate) = 0; @@ -154,13 +154,13 @@ public: * * Before calling this method, use Session::replaceTags() to set the desired TagMask. */ - virtual void registerTags(const Session& session) = 0; + virtual void registerTags(const ConstSessionHandle& session) = 0; /** * Return the stored X509 peer information for this session. If the session does not * exist in this TransportLayer, returns a default constructed object. */ - virtual SSLPeerInfo getX509PeerInfo(const Session& session) const = 0; + virtual SSLPeerInfo getX509PeerInfo(const ConstSessionHandle& session) const = 0; /** * Returns the number of sessions currently open in the transport layer. @@ -178,7 +178,7 @@ public: * * This method is idempotent and synchronous. */ - virtual void end(Session& session) = 0; + virtual void end(const SessionHandle& session) = 0; /** * End all active sessions in the TransportLayer. Tickets that have already been started via diff --git a/src/mongo/transport/transport_layer_legacy.cpp b/src/mongo/transport/transport_layer_legacy.cpp index a410ff09fdb..b5e3c96c6e4 100644 --- a/src/mongo/transport/transport_layer_legacy.cpp +++ b/src/mongo/transport/transport_layer_legacy.cpp @@ -64,10 +64,10 @@ TransportLayerLegacy::TransportLayerLegacy(const TransportLayerLegacy::Options& _running(false), _options(opts) {} -TransportLayerLegacy::LegacyTicket::LegacyTicket(const Session& session, +TransportLayerLegacy::LegacyTicket::LegacyTicket(const SessionHandle& session, Date_t expiration, WorkHandle work) - : _sessionId(session.id()), _expiration(expiration), _fill(std::move(work)) {} + : _sessionId(session->id()), _expiration(expiration), _fill(std::move(work)) {} Session::Id TransportLayerLegacy::LegacyTicket::sessionId() const { return _sessionId; @@ -98,8 +98,10 @@ Status TransportLayerLegacy::start() { TransportLayerLegacy::~TransportLayerLegacy() = default; -Ticket TransportLayerLegacy::sourceMessage(Session& session, Message* message, Date_t expiration) { - auto& compressorMgr = session.getCompressorManager(); +Ticket TransportLayerLegacy::sourceMessage(const SessionHandle& session, + Message* message, + Date_t expiration) { + auto& compressorMgr = session->getCompressorManager(); auto sourceCb = [message, &compressorMgr](AbstractMessagingPort* amp) -> Status { if (!amp->recv(*message)) { return {ErrorCodes::HostUnreachable, "Recv failed"}; @@ -119,10 +121,10 @@ Ticket TransportLayerLegacy::sourceMessage(Session& session, Message* message, D return Ticket(this, stdx::make_unique<LegacyTicket>(session, expiration, std::move(sourceCb))); } -SSLPeerInfo TransportLayerLegacy::getX509PeerInfo(const Session& session) const { +SSLPeerInfo TransportLayerLegacy::getX509PeerInfo(const ConstSessionHandle& session) const { { stdx::lock_guard<stdx::mutex> lk(_connectionsMutex); - auto conn = _connections.find(session.id()); + auto conn = _connections.find(session->id()); if (conn == _connections.end()) { // Return empty string if the session is not found return SSLPeerInfo(); @@ -145,10 +147,10 @@ TransportLayer::Stats TransportLayerLegacy::sessionStats() { return stats; } -Ticket TransportLayerLegacy::sinkMessage(Session& session, +Ticket TransportLayerLegacy::sinkMessage(const SessionHandle& session, const Message& message, Date_t expiration) { - auto& compressorMgr = session.getCompressorManager(); + auto& compressorMgr = session->getCompressorManager(); auto sinkCb = [&message, &compressorMgr](AbstractMessagingPort* amp) -> Status { try { networkCounter.hitLogical(0, message.size()); @@ -179,19 +181,19 @@ void TransportLayerLegacy::asyncWait(Ticket&& ticket, TicketCallback callback) { MONGO_UNREACHABLE; } -void TransportLayerLegacy::end(Session& session) { +void TransportLayerLegacy::end(const SessionHandle& session) { stdx::lock_guard<stdx::mutex> lk(_connectionsMutex); - auto conn = _connections.find(session.id()); + auto conn = _connections.find(session->id()); if (conn != _connections.end()) { _endSession_inlock(conn); } } -void TransportLayerLegacy::registerTags(const Session& session) { +void TransportLayerLegacy::registerTags(const ConstSessionHandle& session) { stdx::lock_guard<stdx::mutex> lk(_connectionsMutex); - auto conn = _connections.find(session.id()); + auto conn = _connections.find(session->id()); if (conn != _connections.end()) { - conn->second.tags = session.getTags(); + conn->second.tags = session->getTags(); } } @@ -311,15 +313,16 @@ void TransportLayerLegacy::_handleNewConnection(std::unique_ptr<AbstractMessagin return; } - Session session(amp->remote(), HostAndPort(amp->localAddr().toString(true)), this); + auto session = + Session::create(amp->remote(), HostAndPort(amp->localAddr().toString(true)), this); amp->setLogLevel(logger::LogSeverity::Debug(1)); { stdx::lock_guard<stdx::mutex> lk(_connectionsMutex); _connections.emplace(std::piecewise_construct, - std::forward_as_tuple(session.id()), - std::forward_as_tuple(std::move(amp), false, session.getTags())); + std::forward_as_tuple(session->id()), + std::forward_as_tuple(std::move(amp), false, session->getTags())); } invariant(_sep); diff --git a/src/mongo/transport/transport_layer_legacy.h b/src/mongo/transport/transport_layer_legacy.h index 8135f559070..3d348076ad7 100644 --- a/src/mongo/transport/transport_layer_legacy.h +++ b/src/mongo/transport/transport_layer_legacy.h @@ -66,23 +66,23 @@ public: Status setup(); Status start() override; - Ticket sourceMessage(Session& session, + Ticket sourceMessage(const SessionHandle& session, Message* message, Date_t expiration = Ticket::kNoExpirationDate) override; - Ticket sinkMessage(Session& session, + Ticket sinkMessage(const SessionHandle& session, const Message& message, Date_t expiration = Ticket::kNoExpirationDate) override; Status wait(Ticket&& ticket) override; void asyncWait(Ticket&& ticket, TicketCallback callback) override; - void registerTags(const Session& session) override; - SSLPeerInfo getX509PeerInfo(const Session& session) const override; + void registerTags(const ConstSessionHandle& session) override; + SSLPeerInfo getX509PeerInfo(const ConstSessionHandle& session) const override; Stats sessionStats() override; - void end(Session& session) override; + void end(const SessionHandle& session) override; void endAllSessions(transport::Session::TagMask tags) override; void shutdown() override; @@ -105,7 +105,7 @@ private: MONGO_DISALLOW_COPYING(LegacyTicket); public: - LegacyTicket(const Session& session, Date_t expiration, WorkHandle work); + LegacyTicket(const SessionHandle& session, Date_t expiration, WorkHandle work); SessionId sessionId() const override; Date_t expiration() const override; diff --git a/src/mongo/transport/transport_layer_manager.cpp b/src/mongo/transport/transport_layer_manager.cpp index 2c12fe98c25..d29c01f54f5 100644 --- a/src/mongo/transport/transport_layer_manager.cpp +++ b/src/mongo/transport/transport_layer_manager.cpp @@ -44,14 +44,16 @@ namespace transport { TransportLayerManager::TransportLayerManager() = default; -Ticket TransportLayerManager::sourceMessage(Session& session, Message* message, Date_t expiration) { - return session.getTransportLayer()->sourceMessage(session, message, expiration); +Ticket TransportLayerManager::sourceMessage(const SessionHandle& session, + Message* message, + Date_t expiration) { + return session->getTransportLayer()->sourceMessage(session, message, expiration); } -Ticket TransportLayerManager::sinkMessage(Session& session, +Ticket TransportLayerManager::sinkMessage(const SessionHandle& session, const Message& message, Date_t expiration) { - return session.getTransportLayer()->sinkMessage(session, message, expiration); + return session->getTransportLayer()->sinkMessage(session, message, expiration); } Status TransportLayerManager::wait(Ticket&& ticket) { @@ -62,8 +64,8 @@ void TransportLayerManager::asyncWait(Ticket&& ticket, TicketCallback callback) return getTicketTransportLayer(ticket)->asyncWait(std::move(ticket), std::move(callback)); } -SSLPeerInfo TransportLayerManager::getX509PeerInfo(const Session& session) const { - return session.getX509PeerInfo(); +SSLPeerInfo TransportLayerManager::getX509PeerInfo(const ConstSessionHandle& session) const { + return session->getX509PeerInfo(); } template <typename Callable> @@ -95,12 +97,12 @@ TransportLayer::Stats TransportLayerManager::sessionStats() { return stats; } -void TransportLayerManager::registerTags(const Session& session) { - session.getTransportLayer()->registerTags(session); +void TransportLayerManager::registerTags(const ConstSessionHandle& session) { + session->getTransportLayer()->registerTags(session); } -void TransportLayerManager::end(Session& session) { - session.getTransportLayer()->end(session); +void TransportLayerManager::end(const SessionHandle& session) { + session->getTransportLayer()->end(session); } void TransportLayerManager::endAllSessions(Session::TagMask tags) { diff --git a/src/mongo/transport/transport_layer_manager.h b/src/mongo/transport/transport_layer_manager.h index bfdf7046da1..deba17146cc 100644 --- a/src/mongo/transport/transport_layer_manager.h +++ b/src/mongo/transport/transport_layer_manager.h @@ -54,22 +54,22 @@ class TransportLayerManager final : public TransportLayer { public: TransportLayerManager(); - Ticket sourceMessage(Session& session, + Ticket sourceMessage(const SessionHandle& session, Message* message, Date_t expiration = Ticket::kNoExpirationDate) override; - Ticket sinkMessage(Session& session, + Ticket sinkMessage(const SessionHandle& session, const Message& message, Date_t expiration = Ticket::kNoExpirationDate) override; Status wait(Ticket&& ticket) override; void asyncWait(Ticket&& ticket, TicketCallback callback) override; - SSLPeerInfo getX509PeerInfo(const Session& session) const override; - void registerTags(const Session& session) override; + SSLPeerInfo getX509PeerInfo(const ConstSessionHandle& session) const override; + void registerTags(const ConstSessionHandle& session) override; Stats sessionStats() override; - void end(Session& session) override; + void end(const SessionHandle& session) override; void endAllSessions(Session::TagMask tags) override; Status start() override; diff --git a/src/mongo/transport/transport_layer_mock.cpp b/src/mongo/transport/transport_layer_mock.cpp index 623980a19fc..550c8585f95 100644 --- a/src/mongo/transport/transport_layer_mock.cpp +++ b/src/mongo/transport/transport_layer_mock.cpp @@ -42,12 +42,12 @@ namespace mongo { namespace transport { -TransportLayerMock::TicketMock::TicketMock(const Session* session, +TransportLayerMock::TicketMock::TicketMock(const SessionHandle& session, Message* message, Date_t expiration) : _session(session), _message(message), _expiration(expiration) {} -TransportLayerMock::TicketMock::TicketMock(const Session* session, Date_t expiration) +TransportLayerMock::TicketMock::TicketMock(const SessionHandle& session, Date_t expiration) : _session(session), _expiration(expiration) {} Session::Id TransportLayerMock::TicketMock::sessionId() const { @@ -64,31 +64,33 @@ boost::optional<Message*> TransportLayerMock::TicketMock::msg() const { TransportLayerMock::TransportLayerMock() : _shutdown(false) {} -Ticket TransportLayerMock::sourceMessage(Session& session, Message* message, Date_t expiration) { +Ticket TransportLayerMock::sourceMessage(const SessionHandle& session, + Message* message, + Date_t expiration) { if (inShutdown()) { return Ticket(TransportLayer::ShutdownStatus); - } else if (!owns(session.id())) { + } else if (!owns(session->id())) { return Ticket(TransportLayer::SessionUnknownStatus); - } else if (_sessions[session.id()].ended) { + } else if (_sessions[session->id()].ended) { return Ticket(TransportLayer::TicketSessionClosedStatus); } return Ticket(this, - stdx::make_unique<TransportLayerMock::TicketMock>(&session, message, expiration)); + stdx::make_unique<TransportLayerMock::TicketMock>(session, message, expiration)); } -Ticket TransportLayerMock::sinkMessage(Session& session, +Ticket TransportLayerMock::sinkMessage(const SessionHandle& session, const Message& message, Date_t expiration) { if (inShutdown()) { return Ticket(TransportLayer::ShutdownStatus); - } else if (!owns(session.id())) { + } else if (!owns(session->id())) { return Ticket(TransportLayer::SessionUnknownStatus); - } else if (_sessions[session.id()].ended) { + } else if (_sessions[session->id()].ended) { return Ticket(TransportLayer::TicketSessionClosedStatus); } - return Ticket(this, stdx::make_unique<TransportLayerMock::TicketMock>(&session, expiration)); + return Ticket(this, stdx::make_unique<TransportLayerMock::TicketMock>(session, expiration)); } Status TransportLayerMock::wait(Ticket&& ticket) { @@ -109,52 +111,51 @@ void TransportLayerMock::asyncWait(Ticket&& ticket, TicketCallback callback) { callback(Status::OK()); } -SSLPeerInfo TransportLayerMock::getX509PeerInfo(const Session& session) const { - return _sessions.at(session.id()).peerInfo; +SSLPeerInfo TransportLayerMock::getX509PeerInfo(const ConstSessionHandle& session) const { + return _sessions.at(session->id()).peerInfo; } -void TransportLayerMock::setX509PeerInfo(const Session& session, SSLPeerInfo peerInfo) { - _sessions[session.id()].peerInfo = std::move(peerInfo); +void TransportLayerMock::setX509PeerInfo(const SessionHandle& session, SSLPeerInfo peerInfo) { + _sessions[session->id()].peerInfo = std::move(peerInfo); } TransportLayer::Stats TransportLayerMock::sessionStats() { return Stats(); } -void TransportLayerMock::registerTags(const Session& session) {} +void TransportLayerMock::registerTags(const ConstSessionHandle& session) {} -Session* TransportLayerMock::createSession() { - std::unique_ptr<Session> session = - stdx::make_unique<Session>(HostAndPort(), HostAndPort(), this); +SessionHandle TransportLayerMock::createSession() { + auto session = Session::create(HostAndPort(), HostAndPort(), this); Session::Id sessionId = session->id(); - _sessions[sessionId] = Connection{false, std::move(session), SSLPeerInfo()}; + _sessions[sessionId] = Connection{false, session, SSLPeerInfo()}; - return _sessions[sessionId].session.get(); + return _sessions[sessionId].session; } -Session* TransportLayerMock::get(Session::Id id) { +SessionHandle TransportLayerMock::get(Session::Id id) { if (!owns(id)) return nullptr; - return _sessions[id].session.get(); + return _sessions[id].session; } bool TransportLayerMock::owns(Session::Id id) { return _sessions.count(id) > 0; } -void TransportLayerMock::end(Session& session) { - if (!owns(session.id())) +void TransportLayerMock::end(const SessionHandle& session) { + if (!owns(session->id())) return; - _sessions[session.id()].ended = true; + _sessions[session->id()].ended = true; } void TransportLayerMock::endAllSessions(Session::TagMask tags) { auto it = _sessions.begin(); while (it != _sessions.end()) { - end(*it->second.session.get()); + end(it->second.session); it++; } } diff --git a/src/mongo/transport/transport_layer_mock.h b/src/mongo/transport/transport_layer_mock.h index a8d37ab7a0d..420360a44de 100644 --- a/src/mongo/transport/transport_layer_mock.h +++ b/src/mongo/transport/transport_layer_mock.h @@ -51,12 +51,12 @@ public: class TicketMock : public TicketImpl { public: // Source constructor - TicketMock(const Session* session, + TicketMock(const SessionHandle& session, Message* message, Date_t expiration = Ticket::kNoExpirationDate); // Sink constructor - TicketMock(const Session* session, Date_t expiration = Ticket::kNoExpirationDate); + TicketMock(const SessionHandle& session, Date_t expiration = Ticket::kNoExpirationDate); TicketMock(TicketMock&&) = default; TicketMock& operator=(TicketMock&&) = default; @@ -68,7 +68,7 @@ public: boost::optional<Message*> msg() const; private: - const Session* _session; + const SessionHandle& _session; boost::optional<Message*> _message; Date_t _expiration; }; @@ -76,26 +76,26 @@ public: TransportLayerMock(); ~TransportLayerMock(); - Ticket sourceMessage(Session& session, + Ticket sourceMessage(const SessionHandle& session, Message* message, Date_t expiration = Ticket::kNoExpirationDate) override; - Ticket sinkMessage(Session& session, + Ticket sinkMessage(const SessionHandle& session, const Message& message, Date_t expiration = Ticket::kNoExpirationDate) override; Status wait(Ticket&& ticket) override; void asyncWait(Ticket&& ticket, TicketCallback callback) override; - SSLPeerInfo getX509PeerInfo(const Session& session) const override; - void setX509PeerInfo(const Session& session, SSLPeerInfo peerInfo); - void registerTags(const Session& session) override; + SSLPeerInfo getX509PeerInfo(const ConstSessionHandle& session) const override; + void setX509PeerInfo(const SessionHandle& session, SSLPeerInfo peerInfo); + void registerTags(const ConstSessionHandle& session) override; Stats sessionStats() override; - Session* createSession(); - Session* get(Session::Id id); + SessionHandle createSession(); + SessionHandle get(Session::Id id); bool owns(Session::Id id); - void end(Session& session) override; + void end(const SessionHandle& session) override; void endAllSessions(Session::TagMask tags) override; Status start() override; @@ -107,7 +107,7 @@ private: struct Connection { bool ended; - std::unique_ptr<Session> session; + SessionHandle session; SSLPeerInfo peerInfo; }; stdx::unordered_map<Session::Id, Connection> _sessions; diff --git a/src/mongo/transport/transport_layer_mock_test.cpp b/src/mongo/transport/transport_layer_mock_test.cpp index a86874b4b00..9876db68d8e 100644 --- a/src/mongo/transport/transport_layer_mock_test.cpp +++ b/src/mongo/transport/transport_layer_mock_test.cpp @@ -54,10 +54,10 @@ private: // sinkMessage() generates a valid Ticket TEST_F(TransportLayerMockTest, SinkMessageGeneratesTicket) { Message msg{}; - Session* session = tl()->createSession(); + SessionHandle session = tl()->createSession(); // call sinkMessage() with no expiration - Ticket ticket = tl()->sinkMessage(*session, msg); + Ticket ticket = tl()->sinkMessage(session, msg); ASSERT(ticket.valid()); ASSERT_OK(ticket.status()); ASSERT_EQUALS(ticket.sessionId(), session->id()); @@ -65,7 +65,7 @@ TEST_F(TransportLayerMockTest, SinkMessageGeneratesTicket) { // call sinkMessage() with an expiration Date_t expiration = Date_t::now() + Hours(1); - ticket = tl()->sinkMessage(*session, msg, expiration); + ticket = tl()->sinkMessage(session, msg, expiration); ASSERT(ticket.valid()); ASSERT_OK(ticket.status()); ASSERT_EQUALS(ticket.sessionId(), session->id()); @@ -75,11 +75,11 @@ TEST_F(TransportLayerMockTest, SinkMessageGeneratesTicket) { // sinkMessage() generates an invalid Ticket if the Session is closed TEST_F(TransportLayerMockTest, SinkMessageSessionClosed) { Message msg{}; - Session* session = tl()->createSession(); + SessionHandle session = tl()->createSession(); - tl()->end(*session); + tl()->end(session); - Ticket ticket = tl()->sinkMessage(*session, msg); + Ticket ticket = tl()->sinkMessage(session, msg); ASSERT_FALSE(ticket.valid()); ASSERT_EQUALS(ticket.status().code(), ErrorCodes::TransportSessionClosed); } @@ -89,9 +89,9 @@ TEST_F(TransportLayerMockTest, SinkMessageSessionUnknown) { Message msg{}; std::unique_ptr<TransportLayerMock> anotherTL = stdx::make_unique<TransportLayerMock>(); - Session* session = anotherTL->createSession(); + SessionHandle session = anotherTL->createSession(); - Ticket ticket = tl()->sinkMessage(*session, msg); + Ticket ticket = tl()->sinkMessage(session, msg); ASSERT_FALSE(ticket.valid()); ASSERT_EQUALS(ticket.status().code(), ErrorCodes::TransportSessionUnknown); } @@ -99,11 +99,11 @@ TEST_F(TransportLayerMockTest, SinkMessageSessionUnknown) { // sinkMessage() generates an invalid Ticket if the TransportLayer is in shutdown TEST_F(TransportLayerMockTest, SinkMessageTLShutdown) { Message msg{}; - Session* session = tl()->createSession(); + SessionHandle session = tl()->createSession(); tl()->shutdown(); - Ticket ticket = tl()->sinkMessage(*session, msg); + Ticket ticket = tl()->sinkMessage(session, msg); ASSERT_FALSE(ticket.valid()); ASSERT_EQUALS(ticket.status().code(), ErrorCodes::ShutdownInProgress); } @@ -111,10 +111,10 @@ TEST_F(TransportLayerMockTest, SinkMessageTLShutdown) { // sourceMessage() generates a valid ticket TEST_F(TransportLayerMockTest, SourceMessageGeneratesTicket) { Message msg{}; - Session* session = tl()->createSession(); + SessionHandle session = tl()->createSession(); // call sourceMessage() with no expiration - Ticket ticket = tl()->sourceMessage(*session, &msg); + Ticket ticket = tl()->sourceMessage(session, &msg); ASSERT(ticket.valid()); ASSERT_OK(ticket.status()); ASSERT_EQUALS(ticket.sessionId(), session->id()); @@ -123,7 +123,7 @@ TEST_F(TransportLayerMockTest, SourceMessageGeneratesTicket) { // call sourceMessage() with an expiration Date_t expiration = Date_t::now() + Hours(1); - ticket = tl()->sourceMessage(*session, &msg, expiration); + ticket = tl()->sourceMessage(session, &msg, expiration); ASSERT(ticket.valid()); ASSERT_OK(ticket.status()); ASSERT_EQUALS(ticket.sessionId(), session->id()); @@ -134,11 +134,11 @@ TEST_F(TransportLayerMockTest, SourceMessageGeneratesTicket) { // sourceMessage() generates an invalid ticket if the Session is closed TEST_F(TransportLayerMockTest, SourceMessageSessionClosed) { Message msg{}; - Session* session = tl()->createSession(); + SessionHandle session = tl()->createSession(); - tl()->end(*session); + tl()->end(session); - Ticket ticket = tl()->sourceMessage(*session, &msg); + Ticket ticket = tl()->sourceMessage(session, &msg); ASSERT_FALSE(ticket.valid()); ASSERT_EQUALS(ticket.status().code(), ErrorCodes::TransportSessionClosed); } @@ -148,9 +148,9 @@ TEST_F(TransportLayerMockTest, SourceMessageSessionUnknown) { Message msg{}; std::unique_ptr<TransportLayerMock> anotherTL = stdx::make_unique<TransportLayerMock>(); - Session* session = anotherTL->createSession(); + SessionHandle session = anotherTL->createSession(); - Ticket ticket = tl()->sourceMessage(*session, &msg); + Ticket ticket = tl()->sourceMessage(session, &msg); ASSERT_FALSE(ticket.valid()); ASSERT_EQUALS(ticket.status().code(), ErrorCodes::TransportSessionUnknown); } @@ -158,18 +158,18 @@ TEST_F(TransportLayerMockTest, SourceMessageSessionUnknown) { // sourceMessage() generates an invalid ticket if the TransportLayer is in shutdown TEST_F(TransportLayerMockTest, SourceMessageTLShutdown) { Message msg{}; - Session* session = tl()->createSession(); + SessionHandle session = tl()->createSession(); tl()->shutdown(); - Ticket ticket = tl()->sourceMessage(*session, &msg); + Ticket ticket = tl()->sourceMessage(session, &msg); ASSERT_FALSE(ticket.valid()); ASSERT_EQUALS(ticket.status().code(), ErrorCodes::ShutdownInProgress); } // wait() returns an OK status TEST_F(TransportLayerMockTest, Wait) { - Session* session = tl()->createSession(); + SessionHandle session = tl()->createSession(); Ticket ticket = Ticket(tl(), stdx::make_unique<TransportLayerMock::TicketMock>(session)); Status status = tl()->wait(std::move(ticket)); @@ -178,7 +178,7 @@ TEST_F(TransportLayerMockTest, Wait) { // wait() returns an TicketExpired error status if the Ticket expired TEST_F(TransportLayerMockTest, WaitExpiredTicket) { - Session* session = tl()->createSession(); + SessionHandle session = tl()->createSession(); Ticket expiredTicket = Ticket(tl(), stdx::make_unique<TransportLayerMock::TicketMock>(session, Date_t::now())); @@ -197,10 +197,10 @@ TEST_F(TransportLayerMockTest, WaitInvalidTicket) { // wait() returns a SessionClosed error status if the Ticket's Session is closed TEST_F(TransportLayerMockTest, WaitSessionClosed) { - Session* session = tl()->createSession(); + SessionHandle session = tl()->createSession(); Ticket ticket = Ticket(tl(), stdx::make_unique<TransportLayerMock::TicketMock>(session)); - tl()->end(*session); + tl()->end(session); Status status = tl()->wait(std::move(ticket)); ASSERT_EQUALS(status.code(), ErrorCodes::TransportSessionClosed); @@ -210,7 +210,7 @@ TEST_F(TransportLayerMockTest, WaitSessionClosed) { // Session TEST_F(TransportLayerMockTest, WaitSessionUnknown) { std::unique_ptr<TransportLayerMock> anotherTL = stdx::make_unique<TransportLayerMock>(); - Session* session = anotherTL->createSession(); + SessionHandle session = anotherTL->createSession(); Ticket ticket = Ticket(tl(), stdx::make_unique<TransportLayerMock::TicketMock>(session)); Status status = tl()->wait(std::move(ticket)); @@ -219,7 +219,7 @@ TEST_F(TransportLayerMockTest, WaitSessionUnknown) { // wait() returns a ShutdownInProgress status if the TransportLayer is in shutdown TEST_F(TransportLayerMockTest, WaitTLShutdown) { - Session* session = tl()->createSession(); + SessionHandle session = tl()->createSession(); Ticket ticket = Ticket(tl(), stdx::make_unique<TransportLayerMock::TicketMock>(session)); tl()->shutdown(); @@ -228,18 +228,18 @@ TEST_F(TransportLayerMockTest, WaitTLShutdown) { ASSERT_EQUALS(status.code(), ErrorCodes::ShutdownInProgress); } -std::vector<Session*> createSessions(TransportLayerMock* tl) { +std::vector<SessionHandle> createSessions(TransportLayerMock* tl) { int numSessions = 10; - std::vector<Session*> sessions; + std::vector<SessionHandle> sessions; for (int i = 0; i < numSessions; i++) { - Session* session = tl->createSession(); + SessionHandle session = tl->createSession(); sessions.push_back(session); } return sessions; } void assertEnded(TransportLayer* tl, - std::vector<Session*> sessions, + std::vector<SessionHandle> sessions, ErrorCodes::Error code = ErrorCodes::TransportSessionClosed) { for (auto session : sessions) { Ticket ticket = Ticket(tl, stdx::make_unique<TransportLayerMock::TicketMock>(session)); @@ -250,14 +250,14 @@ void assertEnded(TransportLayer* tl, // endAllSessions() ends all sessions TEST_F(TransportLayerMockTest, EndAllSessions) { - std::vector<Session*> sessions = createSessions(tl()); + std::vector<SessionHandle> sessions = createSessions(tl()); tl()->endAllSessions(Session::kEmptyTagMask); assertEnded(tl(), sessions); } // shutdown() ends all sessions and shuts down TEST_F(TransportLayerMockTest, Shutdown) { - std::vector<Session*> sessions = createSessions(tl()); + std::vector<SessionHandle> sessions = createSessions(tl()); tl()->shutdown(); assertEnded(tl(), sessions, ErrorCodes::ShutdownInProgress); ASSERT(tl()->inShutdown()); |