summaryrefslogtreecommitdiff
path: root/buildscripts/idl
diff options
context:
space:
mode:
authorMark Benvenuto <mark.benvenuto@mongodb.com>2017-06-27 11:39:11 -0400
committerMark Benvenuto <mark.benvenuto@mongodb.com>2017-06-27 15:39:30 -0400
commit68537d7033973a4ae499c030c087e344f8b83b1f (patch)
treed3c14872ddd1aa1d029377bf3e61db1da677a70a /buildscripts/idl
parentbeea8e5d090d269ca0a0390785bd417fcf4cfcf2 (diff)
downloadmongo-68537d7033973a4ae499c030c087e344f8b83b1f.tar.gz
SERVER-29849 IDL generated code should ensure fields are set on serialization
Diffstat (limited to 'buildscripts/idl')
-rw-r--r--buildscripts/idl/idl/generator.py85
-rw-r--r--buildscripts/idl/idl/struct_types.py92
2 files changed, 128 insertions, 49 deletions
diff --git a/buildscripts/idl/idl/generator.py b/buildscripts/idl/idl/generator.py
index 851c7a4da41..91f09e70799 100644
--- a/buildscripts/idl/idl/generator.py
+++ b/buildscripts/idl/idl/generator.py
@@ -40,6 +40,23 @@ def _get_field_member_name(field):
return '_%s' % (common.camel_case(field.cpp_name))
+def _get_has_field_member_name(field):
+ # type: (ast.Field) -> unicode
+ """Get the C++ class member name for bool 'has' member field."""
+ return '_has%s' % (common.title_case(field.cpp_name))
+
+
+def _is_required_serializer_field(field):
+ # type: (ast.Field) -> bool
+ """
+ Get whether we require this field to have a value set before serialization.
+
+ Fields that must be set before serialization are fields without default values, that are not
+ optional, and are not chained.
+ """
+ return not field.ignore and not field.optional and not field.default and not field.chained
+
+
def _get_field_constant_name(field):
# type: (ast.Field) -> unicode
"""Get the C++ string constant name for a field."""
@@ -337,8 +354,7 @@ class _CppHeaderFileWriter(_CppFileWriterBase):
"""Generate the declarations for the class constructors."""
struct_type_info = struct_types.get_struct_info(struct)
- if struct_type_info.get_constructor_method():
- self._writer.write_line(struct_type_info.get_constructor_method().get_declaration())
+ self._writer.write_line(struct_type_info.get_constructor_method().get_declaration())
def gen_serializer_methods(self, struct):
# type: (ast.Struct) -> None
@@ -399,15 +415,21 @@ class _CppHeaderFileWriter(_CppFileWriterBase):
param_type = cpp_type_info.get_getter_setter_type()
member_name = _get_field_member_name(field)
+ post_body = ''
+ if _is_required_serializer_field(field):
+ post_body = '%s = true;' % (_get_has_field_member_name(field))
+
template_params = {
'method_name': common.title_case(field.cpp_name),
'member_name': member_name,
'param_type': param_type,
- 'body': cpp_type_info.get_setter_body(member_name)
+ 'body': cpp_type_info.get_setter_body(member_name),
+ 'post_body': post_body,
}
with self._with_template(template_params):
- self._writer.write_template('void set${method_name}(${param_type} value) & { ${body} }')
+ self._writer.write_template('void set${method_name}(${param_type} value) & ' +
+ '{ ${body} ${post_body} }')
self._writer.write_empty_line()
@@ -423,6 +445,14 @@ class _CppHeaderFileWriter(_CppFileWriterBase):
else:
self._writer.write_line('%s %s;' % (member_type, member_name))
+ def gen_serializer_member(self, field):
+ # type: (ast.Field) -> None
+ """Generate the C++ class bool has_<field> member definition for a field."""
+ has_member_name = _get_has_field_member_name(field)
+
+ # Use a bitfield to save space
+ self._writer.write_line('bool %s : 1;' % (has_member_name))
+
def gen_string_constants_declarations(self, struct):
# type: (ast.Struct) -> None
# pylint: disable=invalid-name
@@ -484,6 +514,7 @@ class _CppHeaderFileWriter(_CppFileWriterBase):
def generate(self, spec):
# type: (ast.IDLAST) -> None
"""Generate the C++ header to a stream."""
+ # pylint: disable=too-many-branches
self.gen_file_header()
self._writer.write_unindented_line('#pragma once')
@@ -577,6 +608,13 @@ class _CppHeaderFileWriter(_CppFileWriterBase):
if not field.ignore:
self.gen_member(field)
+ # Write serializer member variables
+ # Note: we write these out second to ensure the bit fields can be packed by
+ # the compiler.
+ for field in struct.fields:
+ if _is_required_serializer_field(field):
+ self.gen_serializer_member(field)
+
self.write_empty_line()
@@ -716,24 +754,37 @@ class _CppSourceFileWriter(_CppFileWriterBase):
def gen_constructors(self, struct):
# type: (ast.Struct) -> None
- """Generate the C++ constructor definitions."""
+ """Generate the C++ constructor definition."""
struct_type_info = struct_types.get_struct_info(struct)
- if struct_type_info.get_constructor_method():
- with self._block('%s : _nss(nss) {' %
- (struct_type_info.get_constructor_method().get_definition()), '}'):
- self._writer.write_line('// Used for initialization only')
+ constructor = struct_type_info.get_constructor_method()
+
+ initializers = ['_%s(%s)' % (arg.name, arg.name) for arg in constructor.args]
+
+ initializers += [
+ '%s(false)' % _get_has_field_member_name(field) for field in struct.fields
+ if _is_required_serializer_field(field)
+ ]
+
+ initializers_str = ''
+ if initializers:
+ initializers_str = ': ' + ', '.join(initializers)
+
+ with self._block('%s %s {' % (constructor.get_definition(), initializers_str), '}'):
+ self._writer.write_line('// Used for initialization only')
def gen_deserializer_methods(self, struct):
# type: (ast.Struct) -> None
"""Generate the C++ deserializer method definitions."""
+ # pylint: disable=too-many-branches
# Commands that have concatentate_with_db namespaces require db name as a parameter
struct_type_info = struct_types.get_struct_info(struct)
with self._block('%s {' %
(struct_type_info.get_deserializer_static_method().get_definition()), '}'):
- if isinstance(struct, ast.Command) and struct_type_info.get_constructor_method():
+ if isinstance(struct,
+ ast.Command) and struct.namespace != common.COMMAND_NAMESPACE_IGNORED:
self._writer.write_line('%s object(%s);' % (
common.title_case(struct.name),
'ctxt.parseNSCollectionRequired(dbName, bsonObject.firstElement())'))
@@ -781,6 +832,10 @@ class _CppSourceFileWriter(_CppFileWriterBase):
if field.ignore:
self._writer.write_line('// ignore field')
else:
+ if _is_required_serializer_field(field):
+ self._writer.write_line('%s = true;' %
+ (_get_has_field_member_name(field)))
+
self.gen_field_deserializer(field)
if first_field:
@@ -907,6 +962,16 @@ class _CppSourceFileWriter(_CppFileWriterBase):
struct_type_info = struct_types.get_struct_info(struct)
with self._block('%s {' % (struct_type_info.get_serializer_method().get_definition()), '}'):
+ # Check all required fields have been specified
+ required_fields = [
+ _get_has_field_member_name(field) for field in struct.fields
+ if _is_required_serializer_field(field)
+ ]
+
+ if required_fields:
+ assert_fields_set = ' && '.join(required_fields)
+ self._writer.write_line('invariant(%s);' % assert_fields_set)
+ self._writer.write_empty_line()
# Serialize the namespace as the first field
if isinstance(struct, ast.Command):
diff --git a/buildscripts/idl/idl/struct_types.py b/buildscripts/idl/idl/struct_types.py
index 82ec208d65c..402cc8ae002 100644
--- a/buildscripts/idl/idl/struct_types.py
+++ b/buildscripts/idl/idl/struct_types.py
@@ -24,6 +24,22 @@ from . import common
from . import writer
+class ArgumentInfo(object):
+ """Class that encapsulates information about an argument to a method."""
+
+ def __init__(self, arg):
+ # type: (unicode) -> None
+ """Create a instance of the ArgumentInfo class by parsing the argument string."""
+ parts = arg.split(' ')
+ self.type = ' '.join(parts[0:-1])
+ self.name = parts[-1]
+
+ def __str__(self):
+ # type: () -> str
+ """Return a formatted argument string."""
+ return "%s %s" % (self.type, self.name) # type: ignore
+
+
class MethodInfo(object):
"""Class that encapslates information about a method and how to declare, define, and call it."""
@@ -31,12 +47,12 @@ class MethodInfo(object):
# type: (unicode, unicode, List[unicode], unicode, bool, bool) -> None
# pylint: disable=too-many-arguments
"""Create a MethodInfo instance."""
- self._class_name = class_name
- self._method_name = method_name
- self._args = args
- self._return_type = return_type
- self._static = static
- self._const = const
+ self.class_name = class_name
+ self.method_name = method_name
+ self.args = [ArgumentInfo(arg) for arg in args]
+ self.return_type = return_type
+ self.static = static
+ self.const = const
def get_declaration(self):
# type: () -> unicode
@@ -45,21 +61,21 @@ class MethodInfo(object):
post_modifiers = ''
return_type_str = ''
- if self._static:
+ if self.static:
pre_modifiers = 'static '
- if self._const:
+ if self.const:
post_modifiers = ' const'
- if self._return_type:
- return_type_str = self._return_type + ' '
+ if self.return_type:
+ return_type_str = self.return_type + ' '
return common.template_args(
"${pre_modifiers}${return_type}${method_name}(${args})${post_modifiers};",
pre_modifiers=pre_modifiers,
return_type=return_type_str,
- method_name=self._method_name,
- args=', '.join(self._args),
+ method_name=self.method_name,
+ args=', '.join([str(arg) for arg in self.args]),
post_modifiers=post_modifiers)
def get_definition(self):
@@ -69,36 +85,33 @@ class MethodInfo(object):
post_modifiers = ''
return_type_str = ''
- if self._const:
+ if self.const:
post_modifiers = ' const'
- if self._return_type:
- return_type_str = self._return_type + ' '
+ if self.return_type:
+ return_type_str = self.return_type + ' '
return common.template_args(
"${pre_modifiers}${return_type}${class_name}::${method_name}(${args})${post_modifiers}",
pre_modifiers=pre_modifiers,
return_type=return_type_str,
- class_name=self._class_name,
- method_name=self._method_name,
- args=', '.join(self._args),
+ class_name=self.class_name,
+ method_name=self.method_name,
+ args=', '.join([str(arg) for arg in self.args]),
post_modifiers=post_modifiers)
def get_call(self, obj):
# type: (Optional[unicode]) -> unicode
"""Generate a simply call to the method using the defined args list."""
- args = ', '.join([a.split(' ')[-1] for a in self._args])
+ args = ', '.join([arg.name for arg in self.args])
if obj:
return common.template_args(
- "${obj}.${method_name}(${args});",
- obj=obj,
- method_name=self._method_name,
- args=args)
+ "${obj}.${method_name}(${args});", obj=obj, method_name=self.method_name, args=args)
return common.template_args(
- "${method_name}(${args});", method_name=self._method_name, args=args)
+ "${method_name}(${args});", method_name=self.method_name, args=args)
class StructTypeInfoBase(object):
@@ -161,27 +174,28 @@ class _StructTypeInfo(StructTypeInfoBase):
def __init__(self, struct):
# type: (ast.Struct) -> None
"""Create a _StructTypeInfo instance."""
- self._struct = struct
+ self.struct = struct
def get_constructor_method(self):
# type: () -> MethodInfo
- pass
+ class_name = common.title_case(self.struct.name)
+ return MethodInfo(class_name, class_name, [])
def get_serializer_method(self):
# type: () -> MethodInfo
return MethodInfo(
- common.title_case(self._struct.name),
+ common.title_case(self.struct.name),
'serialize', ['BSONObjBuilder* builder'],
'void',
const=True)
def get_to_bson_method(self):
# type: () -> MethodInfo
- return MethodInfo(common.title_case(self._struct.name), 'toBSON', [], 'BSONObj', const=True)
+ return MethodInfo(common.title_case(self.struct.name), 'toBSON', [], 'BSONObj', const=True)
def get_deserializer_static_method(self):
# type: () -> MethodInfo
- class_name = common.title_case(self._struct.name)
+ class_name = common.title_case(self.struct.name)
return MethodInfo(
class_name,
'parse', ['const IDLParserErrorContext& ctxt', 'const BSONObj& bsonObject'],
@@ -191,7 +205,7 @@ class _StructTypeInfo(StructTypeInfoBase):
def get_deserializer_method(self):
# type: () -> MethodInfo
return MethodInfo(
- common.title_case(self._struct.name), 'parseProtected',
+ common.title_case(self.struct.name), 'parseProtected',
['const IDLParserErrorContext& ctxt', 'const BSONObj& bsonObject'], 'void')
def gen_getter_method(self, indented_writer):
@@ -213,7 +227,7 @@ class _IgnoredCommandTypeInfo(_StructTypeInfo):
def __init__(self, command):
# type: (ast.Command) -> None
"""Create a _IgnoredCommandTypeInfo instance."""
- self._command = command
+ self.command = command
super(_IgnoredCommandTypeInfo, self).__init__(command)
@@ -243,7 +257,7 @@ class _IgnoredCommandTypeInfo(_StructTypeInfo):
def gen_serializer(self, indented_writer):
# type: (writer.IndentedTextWriter) -> None
- indented_writer.write_line('builder->append("%s", 1);' % (self._command.name))
+ indented_writer.write_line('builder->append("%s", 1);' % (self.command.name))
class _CommandWithNamespaceTypeInfo(_StructTypeInfo):
@@ -252,20 +266,20 @@ class _CommandWithNamespaceTypeInfo(_StructTypeInfo):
def __init__(self, command):
# type: (ast.Command) -> None
"""Create a _CommandWithNamespaceTypeInfo instance."""
- self._command = command
+ self.command = command
super(_CommandWithNamespaceTypeInfo, self).__init__(command)
def get_constructor_method(self):
# type: () -> MethodInfo
- class_name = common.title_case(self._struct.name)
+ class_name = common.title_case(self.struct.name)
return MethodInfo(class_name, class_name, ['const NamespaceString& nss'])
def get_serializer_method(self):
# type: () -> MethodInfo
# Commands that require namespaces require it as a parameter to serialize()
return MethodInfo(
- common.title_case(self._struct.name),
+ common.title_case(self.struct.name),
'serialize', ['BSONObjBuilder* builder'],
'void',
const=True)
@@ -273,12 +287,12 @@ class _CommandWithNamespaceTypeInfo(_StructTypeInfo):
def get_to_bson_method(self):
# type: () -> MethodInfo
# Commands that require namespaces require it as a parameter to serialize()
- return MethodInfo(common.title_case(self._struct.name), 'toBSON', [], 'BSONObj', const=True)
+ return MethodInfo(common.title_case(self.struct.name), 'toBSON', [], 'BSONObj', const=True)
def get_deserializer_static_method(self):
# type: () -> MethodInfo
# Commands that have concatentate_with_db namespaces require db name as a parameter
- class_name = common.title_case(self._struct.name)
+ class_name = common.title_case(self.struct.name)
return MethodInfo(
class_name,
'parse',
@@ -290,7 +304,7 @@ class _CommandWithNamespaceTypeInfo(_StructTypeInfo):
# type: () -> MethodInfo
# Commands that have concatentate_with_db namespaces require db name as a parameter
return MethodInfo(
- common.title_case(self._struct.name), 'parseProtected',
+ common.title_case(self.struct.name), 'parseProtected',
['const IDLParserErrorContext& ctxt', 'StringData dbName',
'const BSONObj& bsonObject'], 'void')
@@ -304,7 +318,7 @@ class _CommandWithNamespaceTypeInfo(_StructTypeInfo):
def gen_serializer(self, indented_writer):
# type: (writer.IndentedTextWriter) -> None
- indented_writer.write_line('builder->append("%s", _nss.coll());' % (self._command.name))
+ indented_writer.write_line('builder->append("%s", _nss.coll());' % (self.command.name))
indented_writer.write_empty_line()