summaryrefslogtreecommitdiff
path: root/buildscripts/idl
diff options
context:
space:
mode:
authorSara Golemon <sara.golemon@mongodb.com>2018-09-26 15:17:56 +0000
committerSara Golemon <sara.golemon@mongodb.com>2018-10-05 23:33:33 +0000
commitf9600749453db5dec58bd1bbfd967e16b1578a24 (patch)
tree22e0f6e8f773483ad12828d92f4bf8d1496fbe70 /buildscripts/idl
parent2db04b524dc5e2121b74829814db7c8e84e5696d (diff)
downloadmongo-f9600749453db5dec58bd1bbfd967e16b1578a24.tar.gz
SERVER-37168 Add validators for IDL fields
Diffstat (limited to 'buildscripts/idl')
-rw-r--r--buildscripts/idl/idl/ast.py27
-rw-r--r--buildscripts/idl/idl/binder.py33
-rw-r--r--buildscripts/idl/idl/errors.py9
-rw-r--r--buildscripts/idl/idl/generator.py109
-rw-r--r--buildscripts/idl/idl/parser.py18
-rw-r--r--buildscripts/idl/idl/syntax.py25
6 files changed, 217 insertions, 4 deletions
diff --git a/buildscripts/idl/idl/ast.py b/buildscripts/idl/idl/ast.py
index e5eda3b8426..7b2c85fda1f 100644
--- a/buildscripts/idl/idl/ast.py
+++ b/buildscripts/idl/idl/ast.py
@@ -92,6 +92,30 @@ class Struct(common.SourceLocation):
super(Struct, self).__init__(file_name, line, column)
+class Validator(common.SourceLocation):
+ """
+ An instance of a validator for a field.
+
+ The validator must include at least one of the defined validation predicates.
+ If more than one is included, they must ALL evaluate to true.
+ """
+
+ # pylint: disable=too-many-instance-attributes
+
+ def __init__(self, file_name, line, column):
+ # type: (unicode, int, int) -> None
+ """Construct a Validator."""
+ # Don't lint gt/lt as bad attribute names.
+ # pylint: disable=C0103
+ self.gt = None # type: Optional[Union[int, float]]
+ self.lt = None # type: Optional[Union[int, float]]
+ self.gte = None # type: Optional[Union[int, float]]
+ self.lte = None # type: Optional[Union[int, float]]
+ self.callback = None # type: Optional[unicode]
+
+ super(Validator, self).__init__(file_name, line, column)
+
+
class Field(common.SourceLocation):
"""
An instance of a field in a struct.
@@ -139,6 +163,9 @@ class Field(common.SourceLocation):
self.serialize_op_msg_request_only = False # type: bool
self.constructed = False # type: bool
+ # Validation rules.
+ self.validator = None # type: Optional[Validator]
+
super(Field, self).__init__(file_name, line, column)
diff --git a/buildscripts/idl/idl/binder.py b/buildscripts/idl/idl/binder.py
index bf8c1881514..02955009e73 100644
--- a/buildscripts/idl/idl/binder.py
+++ b/buildscripts/idl/idl/binder.py
@@ -505,6 +505,34 @@ def _normalize_method_name(cpp_type_name, cpp_method_name):
return cpp_method_name
+def _bind_validator(ctxt, validator):
+ # type: (errors.ParserContext, syntax.Validator) -> ast.Validator
+ """Bind a validator from the idl.syntax tree."""
+
+ ast_validator = ast.Validator(validator.file_name, validator.line, validator.column)
+
+ # Parse syntax value as numeric if possible.
+ for pred in ["gt", "lt", "gte", "lte"]:
+ val = getattr(validator, pred)
+ if val is None:
+ continue
+
+ try:
+ intval = int(val)
+ if (intval < -0x80000000) or (intval > 0x7FFFFFFF):
+ raise ValueError('IDL ints are limited to int32_t')
+ setattr(ast_validator, pred, intval)
+ except ValueError:
+ try:
+ setattr(ast_validator, pred, float(val))
+ except ValueError:
+ ctxt.add_value_not_numeric_error(ast_validator, pred, val)
+ return None
+
+ ast_validator.callback = validator.callback
+ return ast_validator
+
+
def _bind_field(ctxt, parsed_spec, field):
# type: (errors.ParserContext, syntax.IDLSpec, syntax.Field) -> ast.Field
"""
@@ -598,6 +626,11 @@ def _bind_field(ctxt, parsed_spec, field):
# Validation doc_sequence types
_validate_doc_sequence_field(ctxt, ast_field)
+ if field.validator is not None:
+ ast_field.validator = _bind_validator(ctxt, field.validator)
+ if ast_field.validator is None:
+ return None
+
return ast_field
diff --git a/buildscripts/idl/idl/errors.py b/buildscripts/idl/idl/errors.py
index 9fc78547f8c..b9785c2bf66 100644
--- a/buildscripts/idl/idl/errors.py
+++ b/buildscripts/idl/idl/errors.py
@@ -85,6 +85,7 @@ ERROR_ID_IS_NODE_VALID_INT = "ID0049"
ERROR_ID_IS_NODE_VALID_NON_NEGATIVE_INT = "ID0050"
ERROR_ID_IS_DUPLICATE_COMPARISON_ORDER = "ID0051"
ERROR_ID_IS_COMMAND_TYPE_EXTRANEOUS = "ID0052"
+ERROR_ID_VALUE_NOT_NUMERIC = "ID0053"
class IDLError(Exception):
@@ -665,6 +666,14 @@ class ParserContext(object):
("Command '%s' cannot have a 'type' property unless namespace equals 'type'.") %
(command_name))
+ def add_value_not_numeric_error(self, location, attrname, value):
+ # type: (common.SourceLocation, unicode, unicode) -> None
+ """Add an error about non-numeric value where number expected."""
+ # pylint: disable=invalid-name
+ self._add_error(location, ERROR_ID_VALUE_NOT_NUMERIC,
+ ("'%s' requires a numeric value, but %s can not be cast") % (attrname,
+ value))
+
def _assert_unique_error_messages():
# type: () -> None
diff --git a/buildscripts/idl/idl/generator.py b/buildscripts/idl/idl/generator.py
index dd8349dea2e..3da969c756b 100644
--- a/buildscripts/idl/idl/generator.py
+++ b/buildscripts/idl/idl/generator.py
@@ -76,6 +76,16 @@ def _get_field_constant_name(field):
field.cpp_name))
+def _get_field_member_validator_name(field):
+ # type (ast.Field) -> unicode
+ """
+ Get the name of the validator method for this field.
+
+ Fields with no validation rules will have a stub validator which returns Status::OK().
+ """
+ return 'validate%s' % common.title_case(field.cpp_name)
+
+
def _access_member(field):
# type: (ast.Field) -> unicode
"""Get the declaration to access a member for a field."""
@@ -481,6 +491,26 @@ class _CppHeaderFileWriter(_CppFileWriterBase):
self._writer.write_template(
'${const_type}${param_type} ${method_name}() const { ${body} }')
+ def gen_validator(self, field):
+ # type: (ast.Field) -> None
+ """Generate the C++ validator definition for a field."""
+
+ template_params = {
+ 'method_name': _get_field_member_validator_name(field),
+ 'param_type': cpp_types.get_cpp_type(field).get_getter_setter_type()
+ }
+
+ with self._with_template(template_params):
+ if field.validator is None:
+ # Header inline the Status::OK stub for non-validated fields.
+ self._writer.write_template(
+ 'Status ${method_name}(${param_type}) { return Status::OK(); }')
+ else:
+ # Declare method implemented in C++ file.
+ self._writer.write_template('Status ${method_name}(${param_type});')
+
+ self._writer.write_empty_line()
+
def gen_setter(self, field):
# type: (ast.Field) -> None
"""Generate the C++ setter definition for a field."""
@@ -492,17 +522,22 @@ class _CppHeaderFileWriter(_CppFileWriterBase):
if _is_required_serializer_field(field):
post_body = '%s = true;' % (_get_has_field_member_name(field))
+ validator = ''
+ if field.validator is not None:
+ validator = 'uassertStatusOK(%s(value));' % _get_field_member_validator_name(field)
+
template_params = {
'method_name': _get_field_member_setter_name(field),
'member_name': member_name,
'param_type': param_type,
'body': cpp_type_info.get_setter_body(member_name),
'post_body': post_body,
+ 'validator': validator,
}
with self._with_template(template_params):
self._writer.write_template(
- 'void ${method_name}(${param_type} value) & ' + '{ ${body} ${post_body} }')
+ 'void ${method_name}(${param_type} value) & { ${validator} ${body} ${post_body} }')
self._writer.write_empty_line()
@@ -704,6 +739,7 @@ class _CppHeaderFileWriter(_CppFileWriterBase):
self.gen_description_comment(field.description)
self.gen_getter(struct, field)
if not struct.immutable and not field.chained_struct_field:
+ self.gen_validator(field)
self.gen_setter(field)
if struct.generate_comparison_operators:
@@ -868,6 +904,20 @@ class _CppSourceFileWriter(_CppFileWriterBase):
self._gen_array_deserializer(field, bson_element)
return
+ def validate_and_assign_or_uassert(field, expression):
+ # type: (ast.Field, unicode) -> None
+ """Perform field value validation post-assignment."""
+ field_name = _get_field_member_name(field)
+ if field.validator is None:
+ self._writer.write_line('%s = %s;' % (field_name, expression))
+ return
+
+ with self._block('{', '}'):
+ self._writer.write_line('auto value = %s;' % (expression))
+ self._writer.write_line('uassertStatusOK(%s(value));' %
+ (_get_field_member_validator_name(field)))
+ self._writer.write_line('%s = std::move(value);' % (field_name))
+
if field.chained:
# Do not generate a predicate check since we always call these deserializers.
@@ -881,8 +931,8 @@ class _CppSourceFileWriter(_CppFileWriterBase):
expression = "%s(%s)" % (method_name, bson_object)
self._gen_usage_check(field, bson_element, field_usage_check)
+ validate_and_assign_or_uassert(field, expression)
- self._writer.write_line('%s = %s;' % (_get_field_member_name(field), expression))
else:
predicate = _get_bson_type_check(bson_element, 'ctxt', field)
if predicate:
@@ -893,12 +943,12 @@ class _CppSourceFileWriter(_CppFileWriterBase):
object_value = self._gen_field_deserializer_expression(bson_element, field)
if field.chained_struct_field:
+ # No need for explicit validation as setter will throw for us.
self._writer.write_line('%s.%s(%s);' %
(_get_field_member_name(field.chained_struct_field),
_get_field_member_setter_name(field), object_value))
else:
- self._writer.write_line('%s = %s;' % (_get_field_member_name(field),
- object_value))
+ validate_and_assign_or_uassert(field, object_value)
def gen_doc_sequence_deserializer(self, field):
# type: (ast.Field) -> None
@@ -1132,6 +1182,53 @@ class _CppSourceFileWriter(_CppFileWriterBase):
self._writer.write_line(method_info.get_call('object'))
self._writer.write_line('return object;')
+ def gen_field_validators(self, struct):
+ # type: (ast.Struct) -> None
+ """Generate non-trivial field validators."""
+ for field in struct.fields:
+ if field.validator is None:
+ # Fields without validators are implemented in the header.
+ continue
+
+ cpp_type = cpp_types.get_cpp_type(field)
+
+ method_template = {
+ 'class_name': common.title_case(struct.name),
+ 'method_name': _get_field_member_validator_name(field),
+ 'param_type': cpp_type.get_getter_setter_type(),
+ }
+
+ def compare_and_return_status(op, limit):
+ # type: (unicode, Union[int, float]) -> None
+ """Emit a comparison which returns an BadValue Status on failure."""
+ with self._block('if (!(value %s %s)) {' % (op, repr(limit)), '}'):
+ self._writer.write_line(
+ 'return {::mongo::ErrorCodes::BadValue, str::stream() << ' +
+ '"Value must be %s %s, \'" << value << "\' provided"};' % (op, limit))
+
+ validator = field.validator
+ with self._with_template(method_template):
+ self._writer.write_template(
+ 'Status ${class_name}::${method_name}(${param_type} value)')
+ with self._block('{', '}'):
+ if validator.gt is not None:
+ compare_and_return_status('>', validator.gt)
+ if validator.gte is not None:
+ compare_and_return_status('>=', validator.gte)
+ if validator.lt is not None:
+ compare_and_return_status('<', validator.lt)
+ if validator.lte is not None:
+ compare_and_return_status('<=', validator.lte)
+
+ if validator.callback is not None:
+ with self._block('{', '}'):
+ self._writer.write_line('Status status = %s(value);' %
+ (validator.callback))
+ with self._block('if (!status.isOK()) {', '}'):
+ self._writer.write_line('return status;')
+
+ self._writer.write_line('return Status::OK();')
+
def gen_bson_deserializer_methods(self, struct):
# type: (ast.Struct) -> None
"""Generate the C++ deserializer method definitions."""
@@ -1584,6 +1681,10 @@ class _CppSourceFileWriter(_CppFileWriterBase):
self.gen_constructors(struct)
self.write_empty_line()
+ # Write field validators
+ self.gen_field_validators(struct)
+ self.write_empty_line()
+
# Write deserializers
self.gen_bson_deserializer_methods(struct)
self.write_empty_line()
diff --git a/buildscripts/idl/idl/parser.py b/buildscripts/idl/idl/parser.py
index 052f9a21782..bb778a5633f 100644
--- a/buildscripts/idl/idl/parser.py
+++ b/buildscripts/idl/idl/parser.py
@@ -191,6 +191,23 @@ def _parse_type(ctxt, spec, name, node):
spec.symbols.add_type(ctxt, idltype)
+def _parse_validator(ctxt, node):
+ # type: (errors.ParserContext, yaml.nodes.MappingNode) -> syntax.Validator
+ """Parse a validator for a field."""
+ validator = syntax.Validator(ctxt.file_name, node.start_mark.line, node.start_mark.column)
+
+ _generic_parser(
+ ctxt, node, "validator", validator, {
+ "gt": _RuleDesc("scalar"),
+ "lt": _RuleDesc("scalar"),
+ "gte": _RuleDesc("scalar"),
+ "lte": _RuleDesc("scalar"),
+ "callback": _RuleDesc("scalar"),
+ })
+
+ return validator
+
+
def _parse_field(ctxt, name, node):
# type: (errors.ParserContext, str, Union[yaml.nodes.MappingNode, yaml.nodes.ScalarNode, yaml.nodes.SequenceNode]) -> syntax.Field
"""Parse a field in a struct/command in the IDL file."""
@@ -207,6 +224,7 @@ def _parse_field(ctxt, name, node):
"default": _RuleDesc('scalar'),
"supports_doc_sequence": _RuleDesc("bool_scalar"),
"comparison_order": _RuleDesc("int_scalar"),
+ "validator": _RuleDesc('mapping', mapping_parser_func=_parse_validator),
})
return field
diff --git a/buildscripts/idl/idl/syntax.py b/buildscripts/idl/idl/syntax.py
index 049114b5d9f..1f354b9768f 100644
--- a/buildscripts/idl/idl/syntax.py
+++ b/buildscripts/idl/idl/syntax.py
@@ -263,6 +263,30 @@ class Type(common.SourceLocation):
super(Type, self).__init__(file_name, line, column)
+class Validator(common.SourceLocation):
+ """
+ An instance of a validator for a field.
+
+ The validator must include at least one of the defined validation predicates.
+ If more than one is included, they must ALL evaluate to true.
+ """
+
+ # pylint: disable=too-many-instance-attributes
+
+ def __init__(self, file_name, line, column):
+ # type: (unicode, int, int) -> None
+ """Construct a Validator."""
+ # Don't lint gt/lt as bad attibute names.
+ # pylint: disable=C0103
+ self.gt = None # type: unicode
+ self.lt = None # type: unicode
+ self.gte = None # type: unicode
+ self.lte = None # type: unicode
+ self.callback = None # type: unicode
+
+ super(Validator, self).__init__(file_name, line, column)
+
+
class Field(common.SourceLocation):
"""
An instance of a field in a struct.
@@ -286,6 +310,7 @@ class Field(common.SourceLocation):
self.default = None # type: unicode
self.supports_doc_sequence = False # type: bool
self.comparison_order = -1 # type: int
+ self.validator = None # type: Validator
# Internal fields - not generated by parser
self.serialize_op_msg_request_only = False # type: bool