From ed0dffba9cb908bb22c21d3d79f6434e25fa8947 Mon Sep 17 00:00:00 2001 From: George Wangensteen Date: Mon, 26 Sep 2022 15:58:35 +0000 Subject: SERVER-67826 Allow IDL types to own or preserve the lifetime of underlying data --- buildscripts/idl/idl/generator.py | 119 +++++++++++++++++---- buildscripts/idl/idl/struct_types.py | 56 +++++++++- src/mongo/db/process_health/fault_manager_config.h | 2 +- src/mongo/db/query/ce/scalar_histogram.cpp | 2 +- src/mongo/db/query/query_request_test.cpp | 20 ++++ src/mongo/db/repl/oplog_entry_test.cpp | 1 - src/mongo/idl/idl_test.cpp | 34 ++++++ 7 files changed, 204 insertions(+), 30 deletions(-) diff --git a/buildscripts/idl/idl/generator.py b/buildscripts/idl/idl/generator.py index 490f819990f..3cb145c09eb 100644 --- a/buildscripts/idl/idl/generator.py +++ b/buildscripts/idl/idl/generator.py @@ -35,12 +35,21 @@ import re import sys import textwrap from abc import ABCMeta, abstractmethod +from enum import Enum from typing import Dict, Iterable, List, Mapping, Tuple, Union, cast from . import (ast, bson, common, cpp_types, enum_types, generic_field_list_types, struct_types, writer) +class _StructDataOwnership(Enum): + """Enumerates the ways a struct may participate in ownership of it's data.""" + + VIEW = 1 # Doesn't participate in ownership + SHARED = 2 # Participates in non-exclusive ownership + OWNER = 3 # Takes ownership of underlying data + + def _get_field_member_name(field): # type: (ast.Field) -> str """Get the C++ class member name for a field.""" @@ -467,26 +476,64 @@ class _CppHeaderFileWriter(_CppFileWriterBase): def gen_serializer_methods(self, struct): # type: (ast.Struct) -> None - """Generate a serializer method declarations.""" - + """Generate serializer method declarations.""" struct_type_info = struct_types.get_struct_info(struct) + self._writer.write_line(struct_type_info.get_serializer_method().get_declaration()) - parse_method = struct_type_info.get_deserializer_static_method() - if parse_method: - self._writer.write_line(parse_method.get_declaration()) + maybe_op_msg_serializer = struct_type_info.get_op_msg_request_serializer_method() + if maybe_op_msg_serializer: + self._writer.write_line(maybe_op_msg_serializer.get_declaration()) - parse_method = struct_type_info.get_op_msg_request_deserializer_static_method() - if parse_method: - self._writer.write_line(parse_method.get_declaration()) + self._writer.write_line(struct_type_info.get_to_bson_method().get_declaration()) - self._writer.write_line(struct_type_info.get_serializer_method().get_declaration()) + self._writer.write_empty_line() - parse_method = struct_type_info.get_op_msg_request_serializer_method() - if parse_method: - self._writer.write_line(parse_method.get_declaration()) + def gen_deserializer_methods(self, struct): + # type: (ast.Struct) -> None + """Generate deserializer method declarations.""" + struct_type_info = struct_types.get_struct_info(struct) + possible_deserializer_methods = [ + struct_type_info.get_deserializer_static_method(), + struct_type_info.get_owned_deserializer_static_method(), + struct_type_info.get_sharing_deserializer_static_method(), + struct_type_info.get_op_msg_request_deserializer_static_method() + ] + for maybe_parse_method in possible_deserializer_methods: + if maybe_parse_method: + comment = maybe_parse_method.get_desc_for_comment() + if comment: + self.gen_description_comment(comment) + self._writer.write_line(maybe_parse_method.get_declaration()) - self._writer.write_line(struct_type_info.get_to_bson_method().get_declaration()) + self._writer.write_empty_line() + def gen_ownership_getter(self): + # type: () -> None + """Generate a getter that returns true if this IDL object owns its underlying data.""" + self.gen_description_comment( + textwrap.dedent("""\ + An IDL struct can either provide a view onto some underlying BSON data, or it can + participate in owning that data. This function returns true if the struct participates in + owning the underlying data. + + Note that the underlying data is not synchronized with the IDL struct over its lifetime; to + generate a BSON representation of an IDL struct, use its `serialize` member functions. + Participating in ownership of the underlying data merely allows the struct to ensure that + struct members that are pointers-into-BSON (i.e. BSONElement and BSONObject) are valid for + the lifetime of the struct itself.""")) + self._writer.write_line("bool isOwned() const { return _anchorObj.isOwned(); }") + + def gen_private_ownership_setters(self): + # type: () -> None + """Generate a setter that can be used to allow this IDL struct to particpate in the ownership of some BSON that the struct's members refer to.""" + with self._block('void setAnchor(const BSONObj& obj) {', '}'): + self._writer.write_line("invariant(obj.isOwned());") + self._writer.write_line("_anchorObj = obj;") + self._writer.write_empty_line() + + with self._block('void setAnchor(BSONObj&& obj) {', '}'): + self._writer.write_line("invariant(obj.isOwned());") + self._writer.write_line("_anchorObj = std::move(obj);") self._writer.write_empty_line() def gen_protected_serializer_methods(self, struct): @@ -1075,9 +1122,14 @@ class _CppHeaderFileWriter(_CppFileWriterBase): # Write serialization self.gen_serializer_methods(struct) + # Write deserialization + self.gen_deserializer_methods(struct) + if isinstance(struct, ast.Command): self.gen_op_msg_request_methods(struct) + self.gen_ownership_getter() + # Write getters & setters for field in struct.fields: if not field.ignore: @@ -1103,6 +1155,8 @@ class _CppHeaderFileWriter(_CppFileWriterBase): self.write_unindented_line('private:') + self.gen_private_ownership_setters() + if struct.generate_comparison_operators: self.gen_comparison_operators_declarations(struct) @@ -1124,7 +1178,7 @@ class _CppHeaderFileWriter(_CppFileWriterBase): for field in struct.fields: if _is_required_serializer_field(field): self.gen_serializer_member(field) - + self._writer.write_line("BSONObj _anchorObj;") # Write constexpr struct data self.gen_constexpr_members(struct) @@ -1696,8 +1750,9 @@ class _CppSourceFileWriter(_CppFileWriterBase): return field_usage_check - def get_bson_deserializer_static_common(self, struct, static_method_info, method_info): - # type: (ast.Struct, struct_types.MethodInfo, struct_types.MethodInfo) -> None + def get_bson_deserializer_static_common(self, struct, static_method_info, method_info, + ownership): + # type: (ast.Struct, struct_types.MethodInfo, struct_types.MethodInfo, _StructDataOwnership) -> None """Generate the C++ deserializer static method.""" func_def = static_method_info.get_definition() @@ -1730,8 +1785,17 @@ class _CppSourceFileWriter(_CppFileWriterBase): common.title_case(struct.cpp_name)) self._writer.write_line(method_info.get_call('object')) + + if ownership == _StructDataOwnership.OWNER: + self._writer.write_line('object.setAnchor(std::move(bsonObject));') + + elif ownership == _StructDataOwnership.SHARED: + self._writer.write_line('object.setAnchor(bsonObject);') + self._writer.write_line('return object;') + self.write_empty_line() + def _compare_and_return_status(self, op, limit, field, optional_param): # type: (str, ast.Expression, ast.Field, str) -> None """Throw an error on comparison failure.""" @@ -1791,20 +1855,30 @@ class _CppSourceFileWriter(_CppFileWriterBase): # type: (ast.Struct) -> None """Generate the C++ deserializer method definitions.""" struct_type_info = struct_types.get_struct_info(struct) + method = struct_type_info.get_deserializer_method() self.get_bson_deserializer_static_common(struct, struct_type_info.get_deserializer_static_method(), - struct_type_info.get_deserializer_method()) + method, _StructDataOwnership.VIEW) + self.get_bson_deserializer_static_common( + struct, struct_type_info.get_sharing_deserializer_static_method(), method, + _StructDataOwnership.SHARED) + self.get_bson_deserializer_static_common( + struct, struct_type_info.get_owned_deserializer_static_method(), method, + _StructDataOwnership.OWNER) - func_def = struct_type_info.get_deserializer_method().get_definition() - with self._block('%s {' % (func_def), '}'): + func_def = method.get_definition() + # Name of the variable that we are deserialzing from + variable_name = "bsonObject" + + with self._block('%s {' % (func_def), '}'): # If the struct contains no fields, there's nothing to deserialize, so we write an empty function stub. if not struct.fields: return # Deserialize all the fields - field_usage_check = self._gen_fields_deserializer_common(struct, "bsonObject", + field_usage_check = self._gen_fields_deserializer_common(struct, variable_name, "ctxt.getTenantId()") # Check for required fields @@ -1814,7 +1888,7 @@ class _CppSourceFileWriter(_CppFileWriterBase): if struct.cpp_validator_func is not None: self._writer.write_line(struct.cpp_validator_func + "(this);") - self._gen_command_deserializer(struct, "bsonObject") + self._gen_command_deserializer(struct, variable_name) def gen_op_msg_request_deserializer_methods(self, struct): # type: (ast.Struct) -> None @@ -1828,7 +1902,7 @@ class _CppSourceFileWriter(_CppFileWriterBase): self.get_bson_deserializer_static_common( struct, struct_type_info.get_op_msg_request_deserializer_static_method(), - struct_type_info.get_op_msg_request_deserializer_method()) + struct_type_info.get_op_msg_request_deserializer_method(), _StructDataOwnership.VIEW) func_def = struct_type_info.get_op_msg_request_deserializer_method().get_definition() with self._block('%s {' % (func_def), '}'): @@ -2695,7 +2769,6 @@ class _CppSourceFileWriter(_CppFileWriterBase): self.gen_field_validators(struct) self.write_empty_line() - # Write deserializers self.gen_bson_deserializer_methods(struct) self.write_empty_line() diff --git a/buildscripts/idl/idl/struct_types.py b/buildscripts/idl/idl/struct_types.py index 1559d8ecd52..3d44458a436 100644 --- a/buildscripts/idl/idl/struct_types.py +++ b/buildscripts/idl/idl/struct_types.py @@ -27,6 +27,7 @@ # """Provide code generation information for structs and commands in a polymorphic way.""" +import textwrap from abc import ABCMeta, abstractmethod from typing import Optional, List @@ -80,9 +81,12 @@ class ArgumentInfo(object): class MethodInfo(object): """Class that encapslates information about a method and how to declare, define, and call it.""" + # pylint: disable=too-many-instance-attributes + def __init__(self, class_name, method_name, args, return_type=None, static=False, const=False, - explicit=False): - # type: (str, str, List[str], str, bool, bool, bool) -> None + explicit=False, desc_for_comment=None): + # type: (str, str, List[str], str, bool, bool, bool, Optional[str]) -> None + # pylint: disable=too-many-arguments """Create a MethodInfo instance.""" self.class_name = class_name self.method_name = method_name @@ -91,6 +95,7 @@ class MethodInfo(object): self.static = static self.const = const self.explicit = explicit + self.desc_for_comment = desc_for_comment def get_declaration(self): # type: () -> str @@ -137,7 +142,7 @@ class MethodInfo(object): def get_call(self, obj): # type: (Optional[str]) -> str - """Generate a simply call to the method using the defined args list.""" + """Generate a simple call to the method using the defined args list.""" args = ', '.join([arg.name for arg in self.args]) @@ -148,6 +153,11 @@ class MethodInfo(object): return common.template_args("${method_name}(${args});", method_name=self.method_name, args=args) + def get_desc_for_comment(self): + # type: () -> Optional[str] + """Get the description of this method suitable for commenting it.""" + return self.desc_for_comment + class StructTypeInfoBase(object, metaclass=ABCMeta): """Base class for struct and command code generation.""" @@ -182,6 +192,18 @@ class StructTypeInfoBase(object, metaclass=ABCMeta): """Get the public static deserializer method for a struct.""" pass + @abstractmethod + def get_sharing_deserializer_static_method(self): + # type: () -> MethodInfo + """Get the public static deserializer method for a struct that participates in shared ownership of underlying data we are deserializing from.""" + pass + + @abstractmethod + def get_owned_deserializer_static_method(self): + # type: () -> MethodInfo + """Get the public static deserializer method for a struct that takes exclusive ownership of underlying data we are deserializing from.""" + pass + @abstractmethod def get_deserializer_method(self): # type: () -> MethodInfo @@ -249,12 +271,38 @@ class _StructTypeInfo(StructTypeInfoBase): class_name = common.title_case(self._struct.cpp_name) return MethodInfo(class_name, class_name, _get_required_parameters(self._struct)) + def get_sharing_deserializer_static_method(self): + # type: () -> MethodInfo + class_name = common.title_case(self._struct.cpp_name) + comment = textwrap.dedent(f"""\ + Factory function that parses a {class_name} from a BSONObj. A {class_name} parsed + this way participates in ownership of the data underlying the BSONObj.""") + return MethodInfo(class_name, 'parseSharingOwnership', + ['const IDLParserContext& ctxt', 'const BSONObj& bsonObject'], class_name, + static=True, desc_for_comment=comment) + + def get_owned_deserializer_static_method(self): + # type: () -> MethodInfo + class_name = common.title_case(self._struct.cpp_name) + comment = textwrap.dedent(f"""\ + Factory function that parses a {class_name} from a BSONObj. A {class_name} parsed + this way takes ownership of the data underlying the BSONObj.""") + return MethodInfo(class_name, 'parseOwned', + ['const IDLParserContext& ctxt', 'BSONObj&& bsonObject'], class_name, + static=True, desc_for_comment=comment) + def get_deserializer_static_method(self): # type: () -> MethodInfo class_name = common.title_case(self._struct.cpp_name) + comment = textwrap.dedent(f"""\ + Factory function that parses a {class_name} from a BSONObj. A {class_name} parsed + this way is strictly a view onto that BSONObj; the BSONObj must be kept valid to + ensure the validity any members of this struct that point-into the BSONObj (i.e. + unowned + objects).""") return MethodInfo(class_name, 'parse', ['const IDLParserContext& ctxt', 'const BSONObj& bsonObject'], class_name, - static=True) + static=True, desc_for_comment=comment) def get_deserializer_method(self): # type: () -> MethodInfo diff --git a/src/mongo/db/process_health/fault_manager_config.h b/src/mongo/db/process_health/fault_manager_config.h index 5918c677e41..f4d63721053 100644 --- a/src/mongo/db/process_health/fault_manager_config.h +++ b/src/mongo/db/process_health/fault_manager_config.h @@ -117,7 +117,7 @@ public: auto x = intensities->_data->getValues(); if (x) { - for (auto setting : *x) { + for (const auto& setting : *x) { if (setting.getType() == observerType) { return setting.getIntensity(); } diff --git a/src/mongo/db/query/ce/scalar_histogram.cpp b/src/mongo/db/query/ce/scalar_histogram.cpp index 400c9b03e60..e95215890f6 100644 --- a/src/mongo/db/query/ce/scalar_histogram.cpp +++ b/src/mongo/db/query/ce/scalar_histogram.cpp @@ -61,7 +61,7 @@ ScalarHistogram::ScalarHistogram() : ScalarHistogram({}, {}) {} ScalarHistogram::ScalarHistogram(std::vector buckets) { - for (auto bucket : buckets) { + for (const auto& bucket : buckets) { Bucket b(bucket.getBoundaryCount(), bucket.getRangeCount(), bucket.getCumulativeCount(), diff --git a/src/mongo/db/query/query_request_test.cpp b/src/mongo/db/query/query_request_test.cpp index 78c6c02d3ed..e9dc4dda22d 100644 --- a/src/mongo/db/query/query_request_test.cpp +++ b/src/mongo/db/query/query_request_test.cpp @@ -1536,6 +1536,26 @@ TEST(QueryRequestHelperTest, ValidateResponseWrongDataType) { ErrorCodes::TypeMismatch); } +TEST(QueryRequestHelperTest, ParsedCursorRemainsValidAfterBSONDestroyed) { + std::vector batch = {BSON("_id" << 1), BSON("_id" << 2)}; + CursorInitialReply cir; + { + BSONObj cursorObj = + BSON("cursor" << BSON("id" << CursorId(123) << "ns" + << "testdb.testcoll" + << "firstBatch" + << BSON_ARRAY(BSON("_id" << 1) << BSON("_id" << 2)))); + cir = CursorInitialReply::parseOwned( + IDLParserContext("QueryRequestHelperTest::ParsedCursorRemainsValidAFterBSONDestroyed"), + std::move(cursorObj)); + cursorObj = BSONObj(); + } + ASSERT_EQ(cir.getCursor()->getFirstBatch().size(), batch.size()); + for (std::vector::size_type i = 0; i < batch.size(); ++i) { + ASSERT_BSONOBJ_EQ(batch[i], cir.getCursor()->getFirstBatch()[i]); + } +} + class QueryRequestTest : public ServiceContextTest {}; TEST_F(QueryRequestTest, ParseFromUUID) { diff --git a/src/mongo/db/repl/oplog_entry_test.cpp b/src/mongo/db/repl/oplog_entry_test.cpp index b7a6052582f..6d4021bcc33 100644 --- a/src/mongo/db/repl/oplog_entry_test.cpp +++ b/src/mongo/db/repl/oplog_entry_test.cpp @@ -228,7 +228,6 @@ TEST(OplogEntryTest, ParseReplOperationIncludesTidField) { ASSERT_EQ(replOp.getTid(), tid); ASSERT_EQ(replOp.getNss(), nssWithTid); } - } // namespace } // namespace repl } // namespace mongo diff --git a/src/mongo/idl/idl_test.cpp b/src/mongo/idl/idl_test.cpp index 6b285d93e59..4e825e6db1a 100644 --- a/src/mongo/idl/idl_test.cpp +++ b/src/mongo/idl/idl_test.cpp @@ -4089,5 +4089,39 @@ TEST(IDLFieldTests, TenantOverrideFieldWithInvalidValue) { } } +TEST(IDLOwnershipTests, ParseOwnAssumesOwnership) { + IDLParserContext ctxt("root"); + One_plain_object idlStruct; + { + auto tmp = BSON("value" << BSON("x" << 42)); + idlStruct = One_plain_object::parseOwned(ctxt, std::move(tmp)); + } + // Now that tmp is out of scope, if idlStruct didn't retain ownership, it would be accessing + // free'd memory which should error on ASAN and debug builds. + auto obj = idlStruct.getValue(); + ASSERT_BSONOBJ_EQ(obj, BSON("x" << 42)); +} + +TEST(IDLOwnershipTests, ParseSharingOwnershipTmpBSON) { + IDLParserContext ctxt("root"); + One_plain_object idlStruct; + { + auto tmp = BSON("value" << BSON("x" << 42)); + idlStruct = One_plain_object::parseSharingOwnership(ctxt, tmp); + } + // Now that tmp is out of scope, if idlStruct didn't particpate in ownership, it would be + // accessing free'd memory which should error on ASAN and debug builds. + auto obj = idlStruct.getValue(); + ASSERT_BSONOBJ_EQ(obj, BSON("x" << 42)); +} + +TEST(IDLOwnershipTests, ParseSharingOwnershipTmpIDLStruct) { + IDLParserContext ctxt("root"); + auto bson = BSON("value" << BSON("x" << 42)); + { auto idlStruct = One_plain_object::parseSharingOwnership(ctxt, bson); } + // Now that idlStruct is out of scope, if bson didn't particpate in ownership, it would be + // accessing free'd memory which should error on ASAN and debug builds. + ASSERT_BSONOBJ_EQ(bson["value"].Obj(), BSON("x" << 42)); +} } // namespace } // namespace mongo -- cgit v1.2.1