diff options
author | Mark Benvenuto <mark.benvenuto@mongodb.com> | 2017-06-27 11:39:11 -0400 |
---|---|---|
committer | Mark Benvenuto <mark.benvenuto@mongodb.com> | 2017-06-27 15:39:30 -0400 |
commit | 68537d7033973a4ae499c030c087e344f8b83b1f (patch) | |
tree | d3c14872ddd1aa1d029377bf3e61db1da677a70a /buildscripts/idl | |
parent | beea8e5d090d269ca0a0390785bd417fcf4cfcf2 (diff) | |
download | mongo-68537d7033973a4ae499c030c087e344f8b83b1f.tar.gz |
SERVER-29849 IDL generated code should ensure fields are set on serialization
Diffstat (limited to 'buildscripts/idl')
-rw-r--r-- | buildscripts/idl/idl/generator.py | 85 | ||||
-rw-r--r-- | buildscripts/idl/idl/struct_types.py | 92 |
2 files changed, 128 insertions, 49 deletions
diff --git a/buildscripts/idl/idl/generator.py b/buildscripts/idl/idl/generator.py index 851c7a4da41..91f09e70799 100644 --- a/buildscripts/idl/idl/generator.py +++ b/buildscripts/idl/idl/generator.py @@ -40,6 +40,23 @@ def _get_field_member_name(field): return '_%s' % (common.camel_case(field.cpp_name)) +def _get_has_field_member_name(field): + # type: (ast.Field) -> unicode + """Get the C++ class member name for bool 'has' member field.""" + return '_has%s' % (common.title_case(field.cpp_name)) + + +def _is_required_serializer_field(field): + # type: (ast.Field) -> bool + """ + Get whether we require this field to have a value set before serialization. + + Fields that must be set before serialization are fields without default values, that are not + optional, and are not chained. + """ + return not field.ignore and not field.optional and not field.default and not field.chained + + def _get_field_constant_name(field): # type: (ast.Field) -> unicode """Get the C++ string constant name for a field.""" @@ -337,8 +354,7 @@ class _CppHeaderFileWriter(_CppFileWriterBase): """Generate the declarations for the class constructors.""" struct_type_info = struct_types.get_struct_info(struct) - if struct_type_info.get_constructor_method(): - self._writer.write_line(struct_type_info.get_constructor_method().get_declaration()) + self._writer.write_line(struct_type_info.get_constructor_method().get_declaration()) def gen_serializer_methods(self, struct): # type: (ast.Struct) -> None @@ -399,15 +415,21 @@ class _CppHeaderFileWriter(_CppFileWriterBase): param_type = cpp_type_info.get_getter_setter_type() member_name = _get_field_member_name(field) + post_body = '' + if _is_required_serializer_field(field): + post_body = '%s = true;' % (_get_has_field_member_name(field)) + template_params = { 'method_name': common.title_case(field.cpp_name), 'member_name': member_name, 'param_type': param_type, - 'body': cpp_type_info.get_setter_body(member_name) + 'body': cpp_type_info.get_setter_body(member_name), + 'post_body': post_body, } with self._with_template(template_params): - self._writer.write_template('void set${method_name}(${param_type} value) & { ${body} }') + self._writer.write_template('void set${method_name}(${param_type} value) & ' + + '{ ${body} ${post_body} }') self._writer.write_empty_line() @@ -423,6 +445,14 @@ class _CppHeaderFileWriter(_CppFileWriterBase): else: self._writer.write_line('%s %s;' % (member_type, member_name)) + def gen_serializer_member(self, field): + # type: (ast.Field) -> None + """Generate the C++ class bool has_<field> member definition for a field.""" + has_member_name = _get_has_field_member_name(field) + + # Use a bitfield to save space + self._writer.write_line('bool %s : 1;' % (has_member_name)) + def gen_string_constants_declarations(self, struct): # type: (ast.Struct) -> None # pylint: disable=invalid-name @@ -484,6 +514,7 @@ class _CppHeaderFileWriter(_CppFileWriterBase): def generate(self, spec): # type: (ast.IDLAST) -> None """Generate the C++ header to a stream.""" + # pylint: disable=too-many-branches self.gen_file_header() self._writer.write_unindented_line('#pragma once') @@ -577,6 +608,13 @@ class _CppHeaderFileWriter(_CppFileWriterBase): if not field.ignore: self.gen_member(field) + # Write serializer member variables + # Note: we write these out second to ensure the bit fields can be packed by + # the compiler. + for field in struct.fields: + if _is_required_serializer_field(field): + self.gen_serializer_member(field) + self.write_empty_line() @@ -716,24 +754,37 @@ class _CppSourceFileWriter(_CppFileWriterBase): def gen_constructors(self, struct): # type: (ast.Struct) -> None - """Generate the C++ constructor definitions.""" + """Generate the C++ constructor definition.""" struct_type_info = struct_types.get_struct_info(struct) - if struct_type_info.get_constructor_method(): - with self._block('%s : _nss(nss) {' % - (struct_type_info.get_constructor_method().get_definition()), '}'): - self._writer.write_line('// Used for initialization only') + constructor = struct_type_info.get_constructor_method() + + initializers = ['_%s(%s)' % (arg.name, arg.name) for arg in constructor.args] + + initializers += [ + '%s(false)' % _get_has_field_member_name(field) for field in struct.fields + if _is_required_serializer_field(field) + ] + + initializers_str = '' + if initializers: + initializers_str = ': ' + ', '.join(initializers) + + with self._block('%s %s {' % (constructor.get_definition(), initializers_str), '}'): + self._writer.write_line('// Used for initialization only') def gen_deserializer_methods(self, struct): # type: (ast.Struct) -> None """Generate the C++ deserializer method definitions.""" + # pylint: disable=too-many-branches # Commands that have concatentate_with_db namespaces require db name as a parameter struct_type_info = struct_types.get_struct_info(struct) with self._block('%s {' % (struct_type_info.get_deserializer_static_method().get_definition()), '}'): - if isinstance(struct, ast.Command) and struct_type_info.get_constructor_method(): + if isinstance(struct, + ast.Command) and struct.namespace != common.COMMAND_NAMESPACE_IGNORED: self._writer.write_line('%s object(%s);' % ( common.title_case(struct.name), 'ctxt.parseNSCollectionRequired(dbName, bsonObject.firstElement())')) @@ -781,6 +832,10 @@ class _CppSourceFileWriter(_CppFileWriterBase): if field.ignore: self._writer.write_line('// ignore field') else: + if _is_required_serializer_field(field): + self._writer.write_line('%s = true;' % + (_get_has_field_member_name(field))) + self.gen_field_deserializer(field) if first_field: @@ -907,6 +962,16 @@ class _CppSourceFileWriter(_CppFileWriterBase): struct_type_info = struct_types.get_struct_info(struct) with self._block('%s {' % (struct_type_info.get_serializer_method().get_definition()), '}'): + # Check all required fields have been specified + required_fields = [ + _get_has_field_member_name(field) for field in struct.fields + if _is_required_serializer_field(field) + ] + + if required_fields: + assert_fields_set = ' && '.join(required_fields) + self._writer.write_line('invariant(%s);' % assert_fields_set) + self._writer.write_empty_line() # Serialize the namespace as the first field if isinstance(struct, ast.Command): diff --git a/buildscripts/idl/idl/struct_types.py b/buildscripts/idl/idl/struct_types.py index 82ec208d65c..402cc8ae002 100644 --- a/buildscripts/idl/idl/struct_types.py +++ b/buildscripts/idl/idl/struct_types.py @@ -24,6 +24,22 @@ from . import common from . import writer +class ArgumentInfo(object): + """Class that encapsulates information about an argument to a method.""" + + def __init__(self, arg): + # type: (unicode) -> None + """Create a instance of the ArgumentInfo class by parsing the argument string.""" + parts = arg.split(' ') + self.type = ' '.join(parts[0:-1]) + self.name = parts[-1] + + def __str__(self): + # type: () -> str + """Return a formatted argument string.""" + return "%s %s" % (self.type, self.name) # type: ignore + + class MethodInfo(object): """Class that encapslates information about a method and how to declare, define, and call it.""" @@ -31,12 +47,12 @@ class MethodInfo(object): # type: (unicode, unicode, List[unicode], unicode, bool, bool) -> None # pylint: disable=too-many-arguments """Create a MethodInfo instance.""" - self._class_name = class_name - self._method_name = method_name - self._args = args - self._return_type = return_type - self._static = static - self._const = const + self.class_name = class_name + self.method_name = method_name + self.args = [ArgumentInfo(arg) for arg in args] + self.return_type = return_type + self.static = static + self.const = const def get_declaration(self): # type: () -> unicode @@ -45,21 +61,21 @@ class MethodInfo(object): post_modifiers = '' return_type_str = '' - if self._static: + if self.static: pre_modifiers = 'static ' - if self._const: + if self.const: post_modifiers = ' const' - if self._return_type: - return_type_str = self._return_type + ' ' + if self.return_type: + return_type_str = self.return_type + ' ' return common.template_args( "${pre_modifiers}${return_type}${method_name}(${args})${post_modifiers};", pre_modifiers=pre_modifiers, return_type=return_type_str, - method_name=self._method_name, - args=', '.join(self._args), + method_name=self.method_name, + args=', '.join([str(arg) for arg in self.args]), post_modifiers=post_modifiers) def get_definition(self): @@ -69,36 +85,33 @@ class MethodInfo(object): post_modifiers = '' return_type_str = '' - if self._const: + if self.const: post_modifiers = ' const' - if self._return_type: - return_type_str = self._return_type + ' ' + if self.return_type: + return_type_str = self.return_type + ' ' return common.template_args( "${pre_modifiers}${return_type}${class_name}::${method_name}(${args})${post_modifiers}", pre_modifiers=pre_modifiers, return_type=return_type_str, - class_name=self._class_name, - method_name=self._method_name, - args=', '.join(self._args), + class_name=self.class_name, + method_name=self.method_name, + args=', '.join([str(arg) for arg in self.args]), post_modifiers=post_modifiers) def get_call(self, obj): # type: (Optional[unicode]) -> unicode """Generate a simply call to the method using the defined args list.""" - args = ', '.join([a.split(' ')[-1] for a in self._args]) + args = ', '.join([arg.name for arg in self.args]) if obj: return common.template_args( - "${obj}.${method_name}(${args});", - obj=obj, - method_name=self._method_name, - args=args) + "${obj}.${method_name}(${args});", obj=obj, method_name=self.method_name, args=args) return common.template_args( - "${method_name}(${args});", method_name=self._method_name, args=args) + "${method_name}(${args});", method_name=self.method_name, args=args) class StructTypeInfoBase(object): @@ -161,27 +174,28 @@ class _StructTypeInfo(StructTypeInfoBase): def __init__(self, struct): # type: (ast.Struct) -> None """Create a _StructTypeInfo instance.""" - self._struct = struct + self.struct = struct def get_constructor_method(self): # type: () -> MethodInfo - pass + class_name = common.title_case(self.struct.name) + return MethodInfo(class_name, class_name, []) def get_serializer_method(self): # type: () -> MethodInfo return MethodInfo( - common.title_case(self._struct.name), + common.title_case(self.struct.name), 'serialize', ['BSONObjBuilder* builder'], 'void', const=True) def get_to_bson_method(self): # type: () -> MethodInfo - return MethodInfo(common.title_case(self._struct.name), 'toBSON', [], 'BSONObj', const=True) + return MethodInfo(common.title_case(self.struct.name), 'toBSON', [], 'BSONObj', const=True) def get_deserializer_static_method(self): # type: () -> MethodInfo - class_name = common.title_case(self._struct.name) + class_name = common.title_case(self.struct.name) return MethodInfo( class_name, 'parse', ['const IDLParserErrorContext& ctxt', 'const BSONObj& bsonObject'], @@ -191,7 +205,7 @@ class _StructTypeInfo(StructTypeInfoBase): def get_deserializer_method(self): # type: () -> MethodInfo return MethodInfo( - common.title_case(self._struct.name), 'parseProtected', + common.title_case(self.struct.name), 'parseProtected', ['const IDLParserErrorContext& ctxt', 'const BSONObj& bsonObject'], 'void') def gen_getter_method(self, indented_writer): @@ -213,7 +227,7 @@ class _IgnoredCommandTypeInfo(_StructTypeInfo): def __init__(self, command): # type: (ast.Command) -> None """Create a _IgnoredCommandTypeInfo instance.""" - self._command = command + self.command = command super(_IgnoredCommandTypeInfo, self).__init__(command) @@ -243,7 +257,7 @@ class _IgnoredCommandTypeInfo(_StructTypeInfo): def gen_serializer(self, indented_writer): # type: (writer.IndentedTextWriter) -> None - indented_writer.write_line('builder->append("%s", 1);' % (self._command.name)) + indented_writer.write_line('builder->append("%s", 1);' % (self.command.name)) class _CommandWithNamespaceTypeInfo(_StructTypeInfo): @@ -252,20 +266,20 @@ class _CommandWithNamespaceTypeInfo(_StructTypeInfo): def __init__(self, command): # type: (ast.Command) -> None """Create a _CommandWithNamespaceTypeInfo instance.""" - self._command = command + self.command = command super(_CommandWithNamespaceTypeInfo, self).__init__(command) def get_constructor_method(self): # type: () -> MethodInfo - class_name = common.title_case(self._struct.name) + class_name = common.title_case(self.struct.name) return MethodInfo(class_name, class_name, ['const NamespaceString& nss']) def get_serializer_method(self): # type: () -> MethodInfo # Commands that require namespaces require it as a parameter to serialize() return MethodInfo( - common.title_case(self._struct.name), + common.title_case(self.struct.name), 'serialize', ['BSONObjBuilder* builder'], 'void', const=True) @@ -273,12 +287,12 @@ class _CommandWithNamespaceTypeInfo(_StructTypeInfo): def get_to_bson_method(self): # type: () -> MethodInfo # Commands that require namespaces require it as a parameter to serialize() - return MethodInfo(common.title_case(self._struct.name), 'toBSON', [], 'BSONObj', const=True) + return MethodInfo(common.title_case(self.struct.name), 'toBSON', [], 'BSONObj', const=True) def get_deserializer_static_method(self): # type: () -> MethodInfo # Commands that have concatentate_with_db namespaces require db name as a parameter - class_name = common.title_case(self._struct.name) + class_name = common.title_case(self.struct.name) return MethodInfo( class_name, 'parse', @@ -290,7 +304,7 @@ class _CommandWithNamespaceTypeInfo(_StructTypeInfo): # type: () -> MethodInfo # Commands that have concatentate_with_db namespaces require db name as a parameter return MethodInfo( - common.title_case(self._struct.name), 'parseProtected', + common.title_case(self.struct.name), 'parseProtected', ['const IDLParserErrorContext& ctxt', 'StringData dbName', 'const BSONObj& bsonObject'], 'void') @@ -304,7 +318,7 @@ class _CommandWithNamespaceTypeInfo(_StructTypeInfo): def gen_serializer(self, indented_writer): # type: (writer.IndentedTextWriter) -> None - indented_writer.write_line('builder->append("%s", _nss.coll());' % (self._command.name)) + indented_writer.write_line('builder->append("%s", _nss.coll());' % (self.command.name)) indented_writer.write_empty_line() |