diff options
Diffstat (limited to 'buildscripts')
-rw-r--r-- | buildscripts/idl/idl/generator.py | 149 | ||||
-rw-r--r-- | buildscripts/idl/tests/test_binder.py | 14 |
2 files changed, 141 insertions, 22 deletions
diff --git a/buildscripts/idl/idl/generator.py b/buildscripts/idl/idl/generator.py index c2d198e50cf..38094c787eb 100644 --- a/buildscripts/idl/idl/generator.py +++ b/buildscripts/idl/idl/generator.py @@ -1275,6 +1275,90 @@ class _CppSourceFileWriter(_CppFileWriterBase): else: self._writer.write_line('%s = std::move(values);' % (_get_field_member_name(field))) + def _gen_variant_deserializer(self, field, bson_element): + # type: (ast.Field, str) -> None + # pylint: disable=too-many-statements + """Generate the C++ deserializer piece for a variant field.""" + self._writer.write_empty_line() + self._writer.write_line('const BSONType variantType = %s.type();' % (bson_element, )) + + array_types = [v for v in field.type.variant_types if v.is_array] + scalar_types = [v for v in field.type.variant_types if not v.is_array] + + self._writer.write_line('switch (variantType) {') + if array_types: + self._writer.write_line('case Array:') + self._writer.indent() + with self._predicate('%s.Obj().isEmpty()' % (bson_element, )): + # Can't determine element type of an empty array, use the first array type. + self._gen_array_deserializer(field, bson_element, array_types[0]) + + with self._block('else {', '}'): + self._writer.write_line( + 'const BSONType elemType = %s.Obj().firstElement().type();' % (bson_element, )) + + # Start inner switch statement, for each type the first element could be. + self._writer.write_line('switch (elemType) {') + for array_type in array_types: + for bson_type in array_type.bson_serialization_type: + self._writer.write_line('case %s:' % (bson.cpp_bson_type_name(bson_type), )) + # Each copy of the array deserialization code gets an anonymous block. + with self._block('{', '}'): + self._gen_array_deserializer(field, bson_element, array_type) + self._writer.write_line('break;') + + self._writer.write_line('default:') + self._writer.indent() + expected_types = ', '.join( + 'BSONType::%s' % bson.cpp_bson_type_name(t.bson_serialization_type[0]) + for t in array_types) + self._writer.write_line( + 'ctxt.throwBadType(%s, {%s});' % (bson_element, expected_types)) + self._writer.write_line('break;') + self._writer.unindent() + # End of inner switch. + self._writer.write_line('}') + + # End of "case Array:". + self._writer.write_line('break;') + self._writer.unindent() + + for scalar_type in scalar_types: + for bson_type in scalar_type.bson_serialization_type: + self._writer.write_line('case %s:' % (bson.cpp_bson_type_name(bson_type), )) + self._writer.indent() + self.gen_field_deserializer(field, scalar_type, "bsonObject", bson_element, None, + check_type=False) + self._writer.write_line('break;') + self._writer.unindent() + + if field.type.variant_struct_type: + self._writer.write_line('case Object:') + self._writer.indent() + object_value = '%s::parse(ctxt, %s.Obj())' % (field.type.variant_struct_type.cpp_type, + bson_element) + + if field.chained_struct_field: + self._writer.write_line( + '%s.%s(%s);' % (_get_field_member_name(field.chained_struct_field), + _get_field_member_setter_name(field), object_value)) + else: + self._writer.write_line('%s = %s;' % (_get_field_member_name(field), object_value)) + self._writer.write_line('break;') + self._writer.unindent() + + self._writer.write_line('default:') + self._writer.indent() + expected_types = ', '.join( + 'BSONType::%s' % bson.cpp_bson_type_name(t.bson_serialization_type[0]) + for t in scalar_types) + self._writer.write_line('ctxt.throwBadType(%s, {%s});' % (bson_element, expected_types)) + self._writer.write_line('break;') + self._writer.unindent() + + # End of outer switch statement. + self._writer.write_line('}') + def _gen_usage_check(self, field, bson_element, field_usage_check): # type: (ast.Field, str, _FieldUsageCheckerBase) -> None """Generate the field usage check and insert the required field check.""" @@ -1284,22 +1368,25 @@ class _CppSourceFileWriter(_CppFileWriterBase): if _is_required_serializer_field(field): self._writer.write_line('%s = true;' % (_get_has_field_member_name(field))) - def gen_field_deserializer(self, field, bson_object, bson_element, field_usage_check, - is_command_field=False): - # type: (ast.Field, str, str, _FieldUsageCheckerBase, bool) -> None - """Generate the C++ deserializer piece for a field.""" + def gen_field_deserializer(self, field, field_type, bson_object, bson_element, + field_usage_check, is_command_field=False, check_type=True): + # type: (ast.Field, ast.Type, str, str, _FieldUsageCheckerBase, bool, bool) -> None + """Generate the C++ deserializer piece for a field. + + If field_type is scalar and check_type is True (the default), generate type-checking code. + Array elements are always type-checked. + """ # pylint: disable=too-many-arguments - if field.type.is_array: + if field_type.is_array: predicate = "MONGO_likely(ctxt.checkAndAssertType(%s, Array))" % (bson_element) with self._predicate(predicate): self._gen_usage_check(field, bson_element, field_usage_check) - - self._gen_array_deserializer(field, bson_element, field.type) + self._gen_array_deserializer(field, bson_element, field_type) return - elif field.type.is_variant: + elif field_type.is_variant: self._gen_usage_check(field, bson_element, field_usage_check) - # TODO (SERVER-51369): implement _gen_variant_deserializer. + self._gen_variant_deserializer(field, bson_element) return def validate_and_assign_or_uassert(field, expression): @@ -1318,28 +1405,31 @@ class _CppSourceFileWriter(_CppFileWriterBase): if field.chained: # Do not generate a predicate check since we always call these deserializers. - if field.type.is_struct: + if field_type.is_struct: # Do not generate a new parser context, reuse the current one since we are not # entering a nested document. - expression = '%s::parse(ctxt, %s)' % (field.type.cpp_type, bson_object) + expression = '%s::parse(ctxt, %s)' % (field_type.cpp_type, bson_object) else: method_name = writer.get_method_name_from_qualified_method_name( - field.type.deserializer) + field_type.deserializer) expression = "%s(%s)" % (method_name, bson_object) self._gen_usage_check(field, bson_element, field_usage_check) validate_and_assign_or_uassert(field, expression) else: - predicate = _get_bson_type_check(bson_element, 'ctxt', field.type) - if predicate: - predicate = "MONGO_likely(%s)" % (predicate) + predicate = None + if check_type: + predicate = _get_bson_type_check(bson_element, 'ctxt', field_type) + if predicate: + predicate = "MONGO_likely(%s)" % (predicate) + with self._predicate(predicate): self._gen_usage_check(field, bson_element, field_usage_check) object_value = self._gen_field_deserializer_expression( - bson_element, field, field.type) + bson_element, field, field_type) if field.chained_struct_field: if field.optional: # We must invoke the boost::optional constructor when setting optional view @@ -1497,8 +1587,9 @@ class _CppSourceFileWriter(_CppFileWriterBase): if isinstance(struct, ast.Command) and struct.command_field: with self._block('{', '}'): - self.gen_field_deserializer(struct.command_field, bson_object, "commandElement", - None, is_command_field=True) + self.gen_field_deserializer(struct.command_field, struct.command_field.type, + bson_object, "commandElement", None, + is_command_field=True) else: struct_type_info = struct_types.get_struct_info(struct) @@ -1549,7 +1640,7 @@ class _CppSourceFileWriter(_CppFileWriterBase): self._writer.write_line('// ignore field') else: - self.gen_field_deserializer(field, bson_object, "element", + self.gen_field_deserializer(field, field.type, bson_object, "element", field_usage_check) if first_field: @@ -1576,7 +1667,7 @@ class _CppSourceFileWriter(_CppFileWriterBase): continue # Simply generate deserializers since these are all 'any' types - self.gen_field_deserializer(field, bson_object, "element", None) + self.gen_field_deserializer(field, field.type, bson_object, "element", None) self._writer.write_empty_line() self._writer.write_empty_line() @@ -1854,6 +1945,21 @@ class _CppSourceFileWriter(_CppFileWriterBase): 'BSONObjBuilder subObjBuilder(builder->subobjStart(${field_name}));') self._writer.write_template('${access_member}.serialize(&subObjBuilder);') + def _gen_serializer_method_variant(self, field): + # type: (ast.Field) -> None + """Generate the serialize method definition for a variant type.""" + template_params = { + 'field_name': _get_field_constant_name(field), + 'access_member': _access_member(field), + } + + with self._with_template(template_params): + # See https://en.cppreference.com/w/cpp/utility/variant/visit + # This lambda is a template instantiated for each alternate type. Use "if constexpr" + # to compile the appropriate serialization code for each. + with self._block('stdx::visit([builder](auto&& arg) {', '}, ${access_member});'): + self._writer.write_template('idlSerialize(builder, ${field_name}, arg);') + def _gen_serializer_method_common(self, field): # type: (ast.Field) -> None # pylint: disable=too-many-locals @@ -1879,8 +1985,7 @@ class _CppSourceFileWriter(_CppFileWriterBase): if needs_custom_serializer: self._gen_serializer_method_custom(field) elif field.type.is_variant: - # TODO (SERVER-51369): implement deserializer. - return + self._gen_serializer_method_variant(field) else: # Generate default serialization using BSONObjBuilder::append # Note: BSONObjBuilder::append has overrides for std::vector also diff --git a/buildscripts/idl/tests/test_binder.py b/buildscripts/idl/tests/test_binder.py index 5e20a27a346..86bd63db773 100644 --- a/buildscripts/idl/tests/test_binder.py +++ b/buildscripts/idl/tests/test_binder.py @@ -756,6 +756,20 @@ class TestBinder(testcase.IDLTestcase): - array<struct1> """), idl.errors.ERROR_ID_VARIANT_DUPLICATE_TYPES) + # At most one array can have BSON serialization type NumberInt. + self.assert_bind_fail( + test_preamble + textwrap.dedent(""" + structs: + foo: + description: foo + fields: + my_variant_field: + type: + variant: + - array<int> + - array<safeInt> + """), idl.errors.ERROR_ID_VARIANT_DUPLICATE_TYPES) + self.assert_bind_fail( test_preamble + textwrap.dedent(""" structs: |