From d55cc0bf9134958ec5651daef9b04e9eee4c3dbb Mon Sep 17 00:00:00 2001 From: Hugh Tong Date: Wed, 3 May 2023 22:26:55 +0000 Subject: SERVER-76635 Override SerializationContext when parsing agg requests --- src/mongo/db/commands/count_cmd.cpp | 3 ++- src/mongo/db/commands/distinct.cpp | 6 +++++- src/mongo/db/commands/find_cmd.cpp | 25 ++++++++++++++++------ src/mongo/db/commands/pipeline_command.cpp | 7 +++++- .../db/pipeline/aggregation_request_helper.cpp | 23 +++++++++++++------- src/mongo/db/pipeline/aggregation_request_helper.h | 24 ++++++++++++--------- ...source_change_stream_handle_topology_change.cpp | 7 +++++- 7 files changed, 67 insertions(+), 28 deletions(-) diff --git a/src/mongo/db/commands/count_cmd.cpp b/src/mongo/db/commands/count_cmd.cpp index 821ad1bc6de..5285347a715 100644 --- a/src/mongo/db/commands/count_cmd.cpp +++ b/src/mongo/db/commands/count_cmd.cpp @@ -189,7 +189,8 @@ public: nss, viewAggCmd, verbosity, - APIParameters::get(opCtx).getAPIStrict().value_or(false)); + APIParameters::get(opCtx).getAPIStrict().value_or(false), + request.getSerializationContext()); // An empty PrivilegeVector is acceptable because these privileges are only checked on // getMore and explain will not open a cursor. diff --git a/src/mongo/db/commands/distinct.cpp b/src/mongo/db/commands/distinct.cpp index fc0c1b0b98f..ba8226ca8a4 100644 --- a/src/mongo/db/commands/distinct.cpp +++ b/src/mongo/db/commands/distinct.cpp @@ -180,12 +180,16 @@ public: OpMsgRequestBuilder::createWithValidatedTenancyScope( nss.dbName(), request.validatedTenancyScope, viewAggregation.getValue()) .body; + // TODO SERVER-75930: expose serializatonContext from when ParseDistinct calls + // ParseDistinctRequest, and pass it onto parseFromBSON to override + // parse(IDLParseContext&, BSONObj&) call auto viewAggRequest = aggregation_request_helper::parseFromBSON( opCtx, nss, viewAggCmd, verbosity, - APIParameters::get(opCtx).getAPIStrict().value_or(false)); + APIParameters::get(opCtx).getAPIStrict().value_or(false), + SerializationContext::stateCommandRequest()); // An empty PrivilegeVector is acceptable because these privileges are only checked on // getMore and explain will not open a cursor. diff --git a/src/mongo/db/commands/find_cmd.cpp b/src/mongo/db/commands/find_cmd.cpp index 983661fbd15..19233f3eee5 100644 --- a/src/mongo/db/commands/find_cmd.cpp +++ b/src/mongo/db/commands/find_cmd.cpp @@ -268,6 +268,12 @@ public: // Parse the command BSON to a FindCommandRequest. auto findCommand = _parseCmdObjectToFindCommandRequest(opCtx, nss, _request.body); + // check validated tenantId and correct the serialization context object on the request + auto reqSerializationContext = findCommand->getSerializationContext(); + reqSerializationContext.setTenantIdSource(_request.getValidatedTenantId() != + boost::none); + findCommand->setSerializationContext(reqSerializationContext); + // Finish the parsing step by using the FindCommandRequest to create a CanonicalQuery. const ExtensionsCallbackReal extensionsCallback(opCtx, &nss); @@ -312,7 +318,8 @@ public: nss, viewAggCmd, verbosity, - APIParameters::get(opCtx).getAPIStrict().value_or(false)); + APIParameters::get(opCtx).getAPIStrict().value_or(false), + reqSerializationContext); try { // An empty PrivilegeVector is acceptable because these privileges are only @@ -368,6 +375,16 @@ public: const bool isExplain = false; const bool isOplogNss = (_ns == NamespaceString::kRsOplogNamespace); auto findCommand = _parseCmdObjectToFindCommandRequest(opCtx, _ns, cmdObj); + + // check validated tenantId and correct the serialization context object on the request + auto reqSerializationContext = findCommand->getSerializationContext(); + reqSerializationContext.setTenantIdSource(_request.getValidatedTenantId() != + boost::none); + findCommand->setSerializationContext(reqSerializationContext); + + auto respSerializationContext = + SerializationContext::stateCommandReply(reqSerializationContext); + CurOp::get(opCtx)->beginQueryPlanningTimer(); // Only allow speculative majority for internal commands that specify the correct flag. @@ -572,10 +589,6 @@ public: } } - // We need to copy the serialization context from the request to the reply object before - // the request object goes out of scope - const auto serializationContext = cq->getFindCommandRequest().getSerializationContext(); - // Get the execution plan for the query. bool permitYield = true; auto exec = @@ -732,7 +745,7 @@ public: auto& metricsCollector = ResourceConsumption::MetricsCollector::get(opCtx); metricsCollector.incrementDocUnitsReturned(nss.ns(), docUnitsReturned); query_request_helper::validateCursorResponse( - result->getBodyBuilder().asTempObj(), nss.tenantId(), serializationContext); + result->getBodyBuilder().asTempObj(), nss.tenantId(), respSerializationContext); } void appendMirrorableRequest(BSONObjBuilder* bob) const override { diff --git a/src/mongo/db/commands/pipeline_command.cpp b/src/mongo/db/commands/pipeline_command.cpp index 20eab15b603..48f6d7f7e4c 100644 --- a/src/mongo/db/commands/pipeline_command.cpp +++ b/src/mongo/db/commands/pipeline_command.cpp @@ -81,13 +81,18 @@ public: OperationContext* opCtx, const OpMsgRequest& opMsgRequest, boost::optional explainVerbosity) override { + + SerializationContext serializationCtx = SerializationContext::stateCommandRequest(); + serializationCtx.setTenantIdSource(opMsgRequest.getValidatedTenantId() != boost::none); + const auto aggregationRequest = aggregation_request_helper::parseFromBSON( opCtx, DatabaseNameUtil::deserialize(opMsgRequest.getValidatedTenantId(), opMsgRequest.getDatabase()), opMsgRequest.body, explainVerbosity, - APIParameters::get(opCtx).getAPIStrict().value_or(false)); + APIParameters::get(opCtx).getAPIStrict().value_or(false), + serializationCtx); auto privileges = uassertStatusOK( auth::getPrivilegesForAggregate(AuthorizationSession::get(opCtx->getClient()), diff --git a/src/mongo/db/pipeline/aggregation_request_helper.cpp b/src/mongo/db/pipeline/aggregation_request_helper.cpp index 589353fcdc7..c66e501833c 100644 --- a/src/mongo/db/pipeline/aggregation_request_helper.cpp +++ b/src/mongo/db/pipeline/aggregation_request_helper.cpp @@ -58,8 +58,10 @@ AggregateCommandRequest parseFromBSON(OperationContext* opCtx, const DatabaseName& dbName, const BSONObj& cmdObj, boost::optional explainVerbosity, - bool apiStrict) { - return parseFromBSON(opCtx, parseNs(dbName, cmdObj), cmdObj, explainVerbosity, apiStrict); + bool apiStrict, + const SerializationContext& serializationContext) { + return parseFromBSON( + opCtx, parseNs(dbName, cmdObj), cmdObj, explainVerbosity, apiStrict, serializationContext); } StatusWith parseFromBSONForTests( @@ -68,7 +70,8 @@ StatusWith parseFromBSONForTests( boost::optional explainVerbosity, bool apiStrict) { try { - return parseFromBSON(/*opCtx=*/nullptr, nss, cmdObj, explainVerbosity, apiStrict); + return parseFromBSON( + /*opCtx=*/nullptr, nss, cmdObj, explainVerbosity, apiStrict, SerializationContext()); } catch (const AssertionException&) { return exceptionToStatus(); } @@ -80,8 +83,9 @@ StatusWith parseFromBSONForTests( boost::optional explainVerbosity, bool apiStrict) { try { + // TODO SERVER-75930: pass serializationContext in return parseFromBSON( - /*opCtx=*/nullptr, dbName, cmdObj, explainVerbosity, apiStrict); + /*opCtx=*/nullptr, dbName, cmdObj, explainVerbosity, apiStrict, SerializationContext()); } catch (const AssertionException&) { return exceptionToStatus(); } @@ -91,7 +95,8 @@ AggregateCommandRequest parseFromBSON(OperationContext* opCtx, NamespaceString nss, const BSONObj& cmdObj, boost::optional explainVerbosity, - bool apiStrict) { + bool apiStrict, + const SerializationContext& serializationContext) { // if the command object lacks field 'aggregate' or '$db', we will use the namespace in 'nss'. bool cmdObjChanged = false; @@ -104,9 +109,11 @@ AggregateCommandRequest parseFromBSON(OperationContext* opCtx, } AggregateCommandRequest request(nss); - request = - AggregateCommandRequest::parse(IDLParserContext("aggregate", apiStrict, nss.tenantId()), - cmdObjChanged ? cmdObjBob.obj() : cmdObj); + // TODO SERVER-75930: tenantId in VTS isn't properly detected by call to parse(IDLParseContext&, + // BSONObj&) + request = AggregateCommandRequest::parse( + IDLParserContext("aggregate", apiStrict, nss.tenantId(), serializationContext), + cmdObjChanged ? cmdObjBob.obj() : cmdObj); if (explainVerbosity) { uassert(ErrorCodes::FailedToParse, diff --git a/src/mongo/db/pipeline/aggregation_request_helper.h b/src/mongo/db/pipeline/aggregation_request_helper.h index 63d77287db4..c8adf37b2fe 100644 --- a/src/mongo/db/pipeline/aggregation_request_helper.h +++ b/src/mongo/db/pipeline/aggregation_request_helper.h @@ -67,11 +67,13 @@ static constexpr long long kDefaultBatchSize = 101; * then 'explainVerbosity' contains this information. In this case, 'cmdObj' may not itself * contain the explain specifier. Otherwise, 'explainVerbosity' should be boost::none. */ -AggregateCommandRequest parseFromBSON(OperationContext* opCtx, - NamespaceString nss, - const BSONObj& cmdObj, - boost::optional explainVerbosity, - bool apiStrict); +AggregateCommandRequest parseFromBSON( + OperationContext* opCtx, + NamespaceString nss, + const BSONObj& cmdObj, + boost::optional explainVerbosity, + bool apiStrict, + const SerializationContext& serializationContext = SerializationContext()); StatusWith parseFromBSONForTests( NamespaceString nss, @@ -83,11 +85,13 @@ StatusWith parseFromBSONForTests( * Convenience overload which constructs the request's NamespaceString from the given database * name and command object. */ -AggregateCommandRequest parseFromBSON(OperationContext* opCtx, - const DatabaseName& dbName, - const BSONObj& cmdObj, - boost::optional explainVerbosity, - bool apiStrict); +AggregateCommandRequest parseFromBSON( + OperationContext* opCtx, + const DatabaseName& dbName, + const BSONObj& cmdObj, + boost::optional explainVerbosity, + bool apiStrict, + const SerializationContext& serializationContext = SerializationContext()); StatusWith parseFromBSONForTests( const DatabaseName& dbName, diff --git a/src/mongo/db/pipeline/document_source_change_stream_handle_topology_change.cpp b/src/mongo/db/pipeline/document_source_change_stream_handle_topology_change.cpp index 28b695c89be..95f59ee5794 100644 --- a/src/mongo/db/pipeline/document_source_change_stream_handle_topology_change.cpp +++ b/src/mongo/db/pipeline/document_source_change_stream_handle_topology_change.cpp @@ -210,9 +210,14 @@ BSONObj DocumentSourceChangeStreamHandleTopologyChange::createUpdatedCommandForN auto* opCtx = pExpCtx->opCtx; bool apiStrict = APIParameters::get(opCtx).getAPIStrict().value_or(false); + tassert(7663502, + str::stream() << "SerializationContext on the expCtx should not be empty, with ns: " + << pExpCtx->ns.ns(), + pExpCtx->serializationCtxt != SerializationContext::stateDefault()); + // Create the 'AggregateCommandRequest' object which will help in creating the parsed pipeline. auto aggCmdRequest = aggregation_request_helper::parseFromBSON( - opCtx, pExpCtx->ns, shardCommand, boost::none, apiStrict); + opCtx, pExpCtx->ns, shardCommand, boost::none, apiStrict, pExpCtx->serializationCtxt); // Parse and optimize the pipeline. auto pipeline = Pipeline::parse(aggCmdRequest.getPipeline(), pExpCtx); -- cgit v1.2.1