summaryrefslogtreecommitdiff
path: root/buildscripts/idl
diff options
context:
space:
mode:
Diffstat (limited to 'buildscripts/idl')
-rw-r--r--buildscripts/idl/idl/ast.py2
-rw-r--r--buildscripts/idl/idl/binder.py25
-rw-r--r--buildscripts/idl/idl/errors.py42
-rw-r--r--buildscripts/idl/idl/generator.py53
-rw-r--r--buildscripts/idl/idl/parser.py15
-rw-r--r--buildscripts/idl/idl/syntax.py2
-rw-r--r--buildscripts/idl/tests/test_binder.py16
-rw-r--r--buildscripts/idl/tests/test_parser.py73
8 files changed, 225 insertions, 3 deletions
diff --git a/buildscripts/idl/idl/ast.py b/buildscripts/idl/idl/ast.py
index e0aa71d0b34..c1bbba4f266 100644
--- a/buildscripts/idl/idl/ast.py
+++ b/buildscripts/idl/idl/ast.py
@@ -84,6 +84,7 @@ class Struct(common.SourceLocation):
self.strict = True # type: bool
self.immutable = False # type: bool
self.inline_chained_structs = False # type: bool
+ self.generate_comparison_operators = False # type: bool
self.fields = [] # type: List[Field]
super(Struct, self).__init__(file_name, line, column)
@@ -108,6 +109,7 @@ class Field(common.SourceLocation):
self.optional = False # type: bool
self.ignore = False # type: bool
self.chained = False # type: bool
+ self.comparison_order = -1 # type: int
# Properties specific to fields which are types.
self.cpp_type = None # type: unicode
diff --git a/buildscripts/idl/idl/binder.py b/buildscripts/idl/idl/binder.py
index 9addc6b6104..d0888f7efc0 100644
--- a/buildscripts/idl/idl/binder.py
+++ b/buildscripts/idl/idl/binder.py
@@ -237,12 +237,14 @@ def _is_duplicate_field(ctxt, field_container, fields, ast_field):
def _bind_struct_common(ctxt, parsed_spec, struct, ast_struct):
# type: (errors.ParserContext, syntax.IDLSpec, syntax.Struct, ast.Struct) -> None
+ # pylint: disable=too-many-branches
ast_struct.name = struct.name
ast_struct.description = struct.description
ast_struct.strict = struct.strict
ast_struct.immutable = struct.immutable
ast_struct.inline_chained_structs = struct.inline_chained_structs
+ ast_struct.generate_comparison_operators = struct.generate_comparison_operators
# Validate naming restrictions
if ast_struct.name.startswith("array<"):
@@ -275,6 +277,28 @@ def _bind_struct_common(ctxt, parsed_spec, struct, ast_struct):
if not _is_duplicate_field(ctxt, ast_struct.name, ast_struct.fields, ast_field):
ast_struct.fields.append(ast_field)
+ # Fill out the field comparison_order property as needed
+ if ast_struct.generate_comparison_operators and ast_struct.fields:
+ # If the user did not specify an ordering of fields, then number all fields in
+ # declared field.
+ use_default_order = True
+ comparison_orders = set() # type: Set[int]
+
+ for ast_field in ast_struct.fields:
+ if not ast_field.comparison_order == -1:
+ use_default_order = False
+ if ast_field.comparison_order in comparison_orders:
+ ctxt.add_duplicate_comparison_order_field_error(ast_struct, ast_struct.name,
+ ast_field.comparison_order)
+
+ comparison_orders.add(ast_field.comparison_order)
+
+ if use_default_order:
+ pos = 0
+ for ast_field in ast_struct.fields:
+ ast_field.comparison_order = pos
+ pos += 1
+
def _bind_struct(ctxt, parsed_spec, struct):
# type: (errors.ParserContext, syntax.IDLSpec, syntax.Struct) -> ast.Struct
@@ -433,6 +457,7 @@ def _bind_field(ctxt, parsed_spec, field):
ast_field.supports_doc_sequence = field.supports_doc_sequence
ast_field.serialize_op_msg_request_only = field.serialize_op_msg_request_only
ast_field.constructed = field.constructed
+ ast_field.comparison_order = field.comparison_order
ast_field.cpp_name = field.name
if field.cpp_name:
diff --git a/buildscripts/idl/idl/errors.py b/buildscripts/idl/idl/errors.py
index 71054a15652..a883d09d40a 100644
--- a/buildscripts/idl/idl/errors.py
+++ b/buildscripts/idl/idl/errors.py
@@ -82,6 +82,9 @@ ERROR_ID_STRUCT_NO_DOC_SEQUENCE = "ID0045"
ERROR_ID_NO_DOC_SEQUENCE_FOR_NON_ARRAY = "ID0046"
ERROR_ID_NO_DOC_SEQUENCE_FOR_NON_OBJECT = "ID0047"
ERROR_ID_COMMAND_DUPLICATES_FIELD = "ID0048"
+ERROR_ID_IS_NODE_VALID_INT = "ID0049"
+ERROR_ID_IS_NODE_VALID_NON_NEGATIVE_INT = "ID0050"
+ERROR_ID_IS_DUPLICATE_COMPARISON_ORDER = "ID0051"
class IDLError(Exception):
@@ -616,6 +619,45 @@ class ParserContext(object):
self._add_error(location, ERROR_ID_COMMAND_DUPLICATES_FIELD,
("Command '%s' cannot have the same name as a field.") % (command_name))
+ def is_scalar_non_negative_int_node(self, node, node_name):
+ # type: (Union[yaml.nodes.MappingNode, yaml.nodes.ScalarNode, yaml.nodes.SequenceNode], unicode) -> bool
+ """Return True if this YAML node is a Scalar and a valid non-negative int."""
+ if not self._is_node_type(node, node_name, "scalar"):
+ return False
+
+ try:
+ value = int(node.value)
+ if value < 0:
+ self._add_node_error(
+ node, ERROR_ID_IS_NODE_VALID_NON_NEGATIVE_INT,
+ "Illegal negative integer value for '%s', expected 0 or positive integer." %
+ (node_name))
+ return False
+
+ except ValueError as value_error:
+ self._add_node_error(node, ERROR_ID_IS_NODE_VALID_INT,
+ "Illegal integer value for '%s', message '%s'." %
+ (node_name, value_error))
+ return False
+
+ return True
+
+ def get_non_negative_int(self, node):
+ # type: (Union[yaml.nodes.MappingNode, yaml.nodes.ScalarNode, yaml.nodes.SequenceNode]) -> int
+ """Convert a scalar to an int."""
+ assert self.is_scalar_non_negative_int_node(node, "unknown")
+
+ return int(node.value)
+
+ def add_duplicate_comparison_order_field_error(self, location, struct_name, comparison_order):
+ # type: (common.SourceLocation, unicode, int) -> None
+ """Add an error about fields having duplicate comparison_orders."""
+ # pylint: disable=invalid-name
+ self._add_error(
+ location, ERROR_ID_IS_DUPLICATE_COMPARISON_ORDER,
+ ("Struct '%s' cannot have two fields with the same comparison_order value '%d'.") %
+ (struct_name, comparison_order))
+
def _assert_unique_error_messages():
# type: () -> None
diff --git a/buildscripts/idl/idl/generator.py b/buildscripts/idl/idl/generator.py
index 9f591eccc4f..35e250998b7 100644
--- a/buildscripts/idl/idl/generator.py
+++ b/buildscripts/idl/idl/generator.py
@@ -549,10 +549,55 @@ class _CppHeaderFileWriter(_CppFileWriterBase):
self._writer.write_line("static const std::vector<StringData> _knownFields;")
self.write_empty_line()
+ def gen_comparison_operators_declarations(self, struct):
+ # type: (ast.Struct) -> None
+ """Generate comparison operators declarations for the type."""
+ # pylint: disable=invalid-name
+
+ template_params = {'class_name': common.title_case(struct.name)}
+
+ with self._with_template(template_params):
+ self._writer.write_template(
+ 'friend bool operator==(const ${class_name}& left, const ${class_name}& right);')
+ self._writer.write_template(
+ 'friend bool operator!=(const ${class_name}& left, const ${class_name}& right);')
+ self._writer.write_template(
+ 'friend bool operator<(const ${class_name}& left, const ${class_name}& right);')
+
+ self.write_empty_line()
+
+ def gen_comparison_operators_definitions(self, struct):
+ # type: (ast.Struct) -> None
+ """Generate comparison operators definitions for the type."""
+ # pylint: disable=invalid-name
+
+ sorted_fields = sorted(
+ [
+ field for field in struct.fields
+ if (not field.ignore) and field.comparison_order != -1
+ ],
+ key=lambda f: f.comparison_order)
+ fields = [_get_field_member_name(field) for field in sorted_fields]
+
+ for rel_op in ['==', '!=', '<']:
+ decl = common.template_args(
+ "inline bool operator${rel_op}(const ${class_name}& left, const ${class_name}& right) {",
+ rel_op=rel_op,
+ class_name=common.title_case(struct.name))
+
+ with self._block(decl, "}"):
+ self._writer.write_line('return std::tie(%s) %s std::tie(%s);' % (','.join(
+ ["left.%s" % (field) for field in fields]), rel_op, ','.join(
+ ["right.%s" % (field) for field in fields])))
+
+ self.write_empty_line()
+
+ self.write_empty_line()
+
def generate(self, spec):
# type: (ast.IDLAST) -> None
"""Generate the C++ header to a stream."""
- # pylint: disable=too-many-branches
+ # pylint: disable=too-many-branches,too-many-statements
self.gen_file_header()
self._writer.write_unindented_line('#pragma once')
@@ -634,6 +679,9 @@ class _CppHeaderFileWriter(_CppFileWriterBase):
if not struct.immutable:
self.gen_setter(field)
+ if struct.generate_comparison_operators:
+ self.gen_comparison_operators_declarations(struct)
+
self.write_unindented_line('protected:')
self.gen_protected_serializer_methods(struct)
@@ -660,6 +708,9 @@ class _CppHeaderFileWriter(_CppFileWriterBase):
self.write_empty_line()
+ if struct.generate_comparison_operators:
+ self.gen_comparison_operators_definitions(struct)
+
class _CppSourceFileWriter(_CppFileWriterBase):
"""C++ .cpp File writer."""
diff --git a/buildscripts/idl/idl/parser.py b/buildscripts/idl/idl/parser.py
index 0f45fe93255..49821707c94 100644
--- a/buildscripts/idl/idl/parser.py
+++ b/buildscripts/idl/idl/parser.py
@@ -36,9 +36,12 @@ class _RuleDesc(object):
"""
Describe a simple parser rule for the generic YAML node parser.
- node_type is either (scalar, scalar_bool, scalar_or_sequence, or mapping)
- - scalar_bool - means a scalar node which is a valid bool, populates a bool
+ node_type is either (scalar, bool_scalar, int_scalar, scalar_or_sequence, sequence, or mapping)
+ - bool_scalar - means a scalar node which is a valid bool, populates a bool
+ - int_scalar - means a scalar node which is a valid non-negative int, populates a int
- scalar_or_sequence - means a scalar or sequence node, populates a list
+ - sequence - a sequence node, populates a list
+ - mapping - a mapping node, calls another parser
mapping_parser_func is only called when parsing a mapping yaml node
"""
@@ -83,6 +86,9 @@ def _generic_parser(
elif rule_desc.node_type == "bool_scalar":
if ctxt.is_scalar_bool_node(second_node, first_name):
syntax_node.__dict__[first_name] = ctxt.get_bool(second_node)
+ elif rule_desc.node_type == "int_scalar":
+ if ctxt.is_scalar_non_negative_int_node(second_node, first_name):
+ syntax_node.__dict__[first_name] = ctxt.get_non_negative_int(second_node)
elif rule_desc.node_type == "scalar_or_sequence":
if ctxt.is_scalar_sequence_or_scalar_node(second_node, first_name):
syntax_node.__dict__[first_name] = ctxt.get_list(second_node)
@@ -198,6 +204,7 @@ def _parse_field(ctxt, name, node):
"optional": _RuleDesc("bool_scalar"),
"default": _RuleDesc('scalar'),
"supports_doc_sequence": _RuleDesc("bool_scalar"),
+ "comparison_order": _RuleDesc("int_scalar"),
})
return field
@@ -337,6 +344,7 @@ def _parse_struct(ctxt, spec, name, node):
"strict": _RuleDesc("bool_scalar"),
"inline_chained_structs": _RuleDesc("bool_scalar"),
"immutable": _RuleDesc('bool_scalar'),
+ "generate_comparison_operators": _RuleDesc("bool_scalar"),
})
# TODO: SHOULD WE ALLOW STRUCTS ONLY WITH CHAINED STUFF and no fields???
@@ -412,6 +420,9 @@ def _parse_command(ctxt, spec, name, node):
"fields": _RuleDesc('mapping', mapping_parser_func=_parse_fields),
"namespace": _RuleDesc('scalar', _RuleDesc.REQUIRED),
"strict": _RuleDesc("bool_scalar"),
+ "inline_chained_structs": _RuleDesc("bool_scalar"),
+ "immutable": _RuleDesc('bool_scalar'),
+ "generate_comparison_operators": _RuleDesc("bool_scalar"),
})
# TODO: support the first argument as UUID depending on outcome of Catalog Versioning changes.
diff --git a/buildscripts/idl/idl/syntax.py b/buildscripts/idl/idl/syntax.py
index 6ea14171f74..df355084555 100644
--- a/buildscripts/idl/idl/syntax.py
+++ b/buildscripts/idl/idl/syntax.py
@@ -285,6 +285,7 @@ class Field(common.SourceLocation):
self.optional = False # type: bool
self.default = None # type: unicode
self.supports_doc_sequence = False # type: bool
+ self.comparison_order = -1 # type: int
# Internal fields - not generated by parser
self.serialize_op_msg_request_only = False # type: bool
@@ -342,6 +343,7 @@ class Struct(common.SourceLocation):
self.strict = True # type: bool
self.immutable = False # type: bool
self.inline_chained_structs = False # type: bool
+ self.generate_comparison_operators = False # type: bool
self.chained_types = None # type: List[ChainedType]
self.chained_structs = None # type: List[ChainedStruct]
self.fields = None # type: List[Field]
diff --git a/buildscripts/idl/tests/test_binder.py b/buildscripts/idl/tests/test_binder.py
index 5502b69d366..5a086885b5b 100644
--- a/buildscripts/idl/tests/test_binder.py
+++ b/buildscripts/idl/tests/test_binder.py
@@ -751,6 +751,22 @@ class TestBinder(testcase.IDLTestcase):
optional: true
"""), idl.errors.ERROR_ID_ILLEGAL_FIELD_DEFAULT_AND_OPTIONAL)
+ # Test duplicate comparison order
+ self.assert_bind_fail(test_preamble + textwrap.dedent("""
+ structs:
+ foo:
+ description: foo
+ strict: false
+ generate_comparison_operators: true
+ fields:
+ foo:
+ type: string
+ comparison_order: 1
+ bar:
+ type: string
+ comparison_order: 1
+ """), idl.errors.ERROR_ID_IS_DUPLICATE_COMPARISON_ORDER)
+
def test_ignored_field_negative(self):
# type: () -> None
"""Test that if a field is marked as ignored, no other properties are set."""
diff --git a/buildscripts/idl/tests/test_parser.py b/buildscripts/idl/tests/test_parser.py
index becae57acdb..32b0fa4b627 100644
--- a/buildscripts/idl/tests/test_parser.py
+++ b/buildscripts/idl/tests/test_parser.py
@@ -14,6 +14,7 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
"""Test cases for IDL parser."""
+# pylint: disable=too-many-lines
from __future__ import absolute_import, print_function, unicode_literals
@@ -278,6 +279,7 @@ class TestParser(testcase.IDLTestcase):
strict: true
immutable: true
inline_chained_structs: true
+ generate_comparison_operators: true
fields:
foo: bar
"""))
@@ -291,6 +293,7 @@ class TestParser(testcase.IDLTestcase):
strict: false
immutable: false
inline_chained_structs: false
+ generate_comparison_operators: false
fields:
foo: bar
"""))
@@ -348,6 +351,28 @@ class TestParser(testcase.IDLTestcase):
foo: bar
"""), idl.errors.ERROR_ID_IS_NODE_VALID_BOOL)
+ # inline_chained_structs is a bool
+ self.assert_parse_fail(
+ textwrap.dedent("""
+ structs:
+ foo:
+ description: foo
+ inline_chained_structs: bar
+ fields:
+ foo: bar
+ """), idl.errors.ERROR_ID_IS_NODE_VALID_BOOL)
+
+ # generate_comparison_operators is a bool
+ self.assert_parse_fail(
+ textwrap.dedent("""
+ structs:
+ foo:
+ description: foo
+ generate_comparison_operators: bar
+ fields:
+ foo: bar
+ """), idl.errors.ERROR_ID_IS_NODE_VALID_BOOL)
+
def test_field_positive(self):
# type: () -> None
"""Positive field test cases."""
@@ -375,6 +400,7 @@ class TestParser(testcase.IDLTestcase):
optional: true
ignore: true
cpp_name: bar
+ comparison_order: 3
"""))
# Test false bools
@@ -433,6 +459,47 @@ class TestParser(testcase.IDLTestcase):
ignore: bar
"""), idl.errors.ERROR_ID_IS_NODE_VALID_BOOL)
+ # Test bad int scalar
+ self.assert_parse_fail(
+ textwrap.dedent("""
+ structs:
+ foo:
+ description: foo
+ strict: false
+ fields:
+ foo:
+ type: string
+ comparison_order:
+ - a
+ - b
+ """), idl.errors.ERROR_ID_IS_NODE_TYPE)
+
+ # Test bad int
+ self.assert_parse_fail(
+ textwrap.dedent("""
+ structs:
+ foo:
+ description: foo
+ strict: false
+ fields:
+ foo:
+ type: string
+ comparison_order: 3.14159
+ """), idl.errors.ERROR_ID_IS_NODE_VALID_INT)
+
+ # Test bad negative int
+ self.assert_parse_fail(
+ textwrap.dedent("""
+ structs:
+ foo:
+ description: foo
+ strict: false
+ fields:
+ foo:
+ type: string
+ comparison_order: -1
+ """), idl.errors.ERROR_ID_IS_NODE_VALID_NON_NEGATIVE_INT)
+
def test_name_collisions_negative(self):
# type: () -> None
"""Negative tests for type collisions."""
@@ -758,6 +825,9 @@ class TestParser(testcase.IDLTestcase):
description: foo
strict: true
namespace: ignored
+ immutable: true
+ inline_chained_structs: true
+ generate_comparison_operators: true
fields:
foo: bar
"""))
@@ -770,6 +840,9 @@ class TestParser(testcase.IDLTestcase):
description: foo
strict: false
namespace: ignored
+ immutable: false
+ inline_chained_structs: false
+ generate_comparison_operators: false
fields:
foo: bar
"""))