From 5bc18e5edf382187c421ae23a5707a3207ef3204 Mon Sep 17 00:00:00 2001 From: Mark Benvenuto Date: Mon, 1 Mar 2021 15:28:37 -0500 Subject: =?UTF-8?q?SERVER-54520=20Extend=20IDL=20for=20new=20access=5Fchec?= =?UTF-8?q?k=20field=20and=20none=20value=20and=20generate=20code=20when?= =?UTF-8?q?=20api=5Fversion=20!=3D=20=E2=80=9C=E2=80=9D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- buildscripts/idl/idl/ast.py | 11 +++++++++++ buildscripts/idl/idl/binder.py | 18 +++++++++++++++++- buildscripts/idl/idl/generator.py | 10 +++++++--- buildscripts/idl/idl/parser.py | 22 +++++++++++++++++++++ buildscripts/idl/idl/syntax.py | 13 +++++++++++++ buildscripts/idl/tests/test_binder.py | 36 ++++++++++++++++++++++++++++++++++- buildscripts/idl/tests/test_parser.py | 29 +++++++++++++++++++++++----- src/mongo/idl/unittest.idl | 12 +++++++++++- 8 files changed, 140 insertions(+), 11 deletions(-) diff --git a/buildscripts/idl/idl/ast.py b/buildscripts/idl/idl/ast.py index 3b5e5e3d122..aa7e6a88ae1 100644 --- a/buildscripts/idl/idl/ast.py +++ b/buildscripts/idl/idl/ast.py @@ -225,6 +225,16 @@ class Field(common.SourceLocation): super(Field, self).__init__(file_name, line, column) +class AccessCheck(common.SourceLocation): + """IDL commmand access check information.""" + + def __init__(self, file_name, line, column): + # type: (str, int, int) -> None + """Construct an AccessCheck.""" + self.placeholder = None # type: str + super(AccessCheck, self).__init__(file_name, line, column) + + class Command(Struct): """ IDL commmand information. @@ -244,6 +254,7 @@ class Command(Struct): self.reply_type = None # type: Field self.api_version = "" # type: str self.is_deprecated = False # type: bool + self.access_checks = None # type: List[AccessCheck] super(Command, self).__init__(file_name, line, column) diff --git a/buildscripts/idl/idl/binder.py b/buildscripts/idl/idl/binder.py index 582c63128cb..3ae92799d1c 100644 --- a/buildscripts/idl/idl/binder.py +++ b/buildscripts/idl/idl/binder.py @@ -31,7 +31,7 @@ import collections import re import typing -from typing import Type, TypeVar, cast, List, Set, Union +from typing import Type, TypeVar, cast, List, Set, Union, Optional from . import ast from . import bson @@ -549,6 +549,20 @@ def _bind_command_reply_type(ctxt, parsed_spec, command): return ast_field +def _bind_access_check(command): + # type: (syntax.Command) -> Optional[List[ast.AccessCheck]] + """Bind the access_check field in a command.""" + if not command.access_check: + return None + + access_check = command.access_check + + if access_check.none: + return [] + + return None + + def _bind_command(ctxt, parsed_spec, command): # type: (errors.ParserContext, syntax.IDLSpec, syntax.Command) -> ast.Command """ @@ -569,6 +583,8 @@ def _bind_command(ctxt, parsed_spec, command): _bind_struct_common(ctxt, parsed_spec, command, ast_command) + ast_command.access_checks = _bind_access_check(command) + ast_command.namespace = command.namespace if command.type: diff --git a/buildscripts/idl/idl/generator.py b/buildscripts/idl/idl/generator.py index b3b2d08fdf0..08aa024ee43 100644 --- a/buildscripts/idl/idl/generator.py +++ b/buildscripts/idl/idl/generator.py @@ -971,8 +971,8 @@ class _CppHeaderFileWriter(_CppFileWriterBase): with self._block('%s {' % (fn_def), '}'): self._writer.write_line('return %s;' % value) - def gen_invocation_base_class_declaration(self): - # type: () -> None + def gen_invocation_base_class_declaration(self, command): + # type: (ast.Command) -> None """Generate the InvocationBaseGen class for a command's base class.""" class_declaration = 'class InvocationBaseGen : public _TypedCommandInvocationBase {' with writer.IndentedScopedBlock(self._writer, class_declaration, '};'): @@ -987,6 +987,10 @@ class _CppHeaderFileWriter(_CppFileWriterBase): self._writer.write_line('virtual Reply typedRun(OperationContext* opCtx) = 0;') + if command.access_checks == []: + self._writer.write_line( + 'void doCheckAuthorization(OperationContext* opCtx) const final {}') + def generate_versioned_command_base_class(self, command): # type: (ast.Command) -> None """Generate a command's C++ base class to a stream.""" @@ -1028,7 +1032,7 @@ class _CppHeaderFileWriter(_CppFileWriterBase): self.gen_api_version_fn(False, command.is_deprecated) # Write InvocationBaseGen class. - self.gen_invocation_base_class_declaration() + self.gen_invocation_base_class_declaration(command) def generate(self, spec): # type: (ast.IDLAST) -> None diff --git a/buildscripts/idl/idl/parser.py b/buildscripts/idl/idl/parser.py index ccc07d1b27f..eb09b454450 100644 --- a/buildscripts/idl/idl/parser.py +++ b/buildscripts/idl/idl/parser.py @@ -682,6 +682,27 @@ def _parse_enum(ctxt, spec, name, node): spec.symbols.add_enum(ctxt, idl_enum) +def _parse_access_checks(ctxt, node): + # type: (errors.ParserContext, yaml.nodes.MappingNode) -> syntax.AccessChecks + """Parse an access check section in a struct in the IDL file.""" + + access_checks = None + + if not ctxt.is_mapping_node(node, "access_check"): + return None + + access_checks = syntax.AccessChecks(ctxt.file_name, node.start_mark.line, + node.start_mark.column) + + _generic_parser(ctxt, node, "access_check", access_checks, { + "none": _RuleDesc('bool_scalar'), + }) + + # TODO(SERVER-54522) TODO(SERVER-54523) - validate only one of none, simple or complex is set + + return access_checks + + def _parse_command(ctxt, spec, name, node): # type: (errors.ParserContext, syntax.IDLSpec, str, Union[yaml.nodes.MappingNode, yaml.nodes.ScalarNode, yaml.nodes.SequenceNode]) -> None """Parse a command section in the IDL file.""" @@ -714,6 +735,7 @@ def _parse_command(ctxt, spec, name, node): "generate_comparison_operators": _RuleDesc("bool_scalar"), "allow_global_collection_name": _RuleDesc('bool_scalar'), "non_const_getter": _RuleDesc('bool_scalar'), + "access_check": _RuleDesc('mapping', mapping_parser_func=_parse_access_checks), }) valid_commands = [ diff --git a/buildscripts/idl/idl/syntax.py b/buildscripts/idl/idl/syntax.py index 092bbd7228a..f6d973a442d 100644 --- a/buildscripts/idl/idl/syntax.py +++ b/buildscripts/idl/idl/syntax.py @@ -549,6 +549,18 @@ class Struct(common.SourceLocation): super(Struct, self).__init__(file_name, line, column) +class AccessChecks(common.SourceLocation): + """IDL access checks information.""" + + def __init__(self, file_name, line, column): + # type: (str, int, int) -> None + """Construct an AccessChecks.""" + + self.none = None # type: bool + + super(AccessChecks, self).__init__(file_name, line, column) + + class Command(Struct): """ IDL command information, a subtype of Struct. @@ -568,6 +580,7 @@ class Command(Struct): self.reply_type = None # type: str self.api_version = None # type: str self.is_deprecated = False # type: bool + self.access_check = None # type: AccessChecks super(Command, self).__init__(file_name, line, column) diff --git a/buildscripts/idl/tests/test_binder.py b/buildscripts/idl/tests/test_binder.py index 86bd63db773..8bba05db7d9 100644 --- a/buildscripts/idl/tests/test_binder.py +++ b/buildscripts/idl/tests/test_binder.py @@ -1649,7 +1649,7 @@ class TestBinder(testcase.IDLTestcase): serializer: foo deserializer: foo default: foo - + structs: reply: description: foo @@ -2363,6 +2363,40 @@ class TestBinder(testcase.IDLTestcase): version: 123 """), idl.errors.ERROR_ID_FEATURE_FLAG_DEFAULT_FALSE_HAS_VERSION) + def test_access_check(self): + # type: () -> None + """Test access check.""" + + test_preamble = textwrap.dedent(""" + types: + string: + description: foo + cpp_type: foo + bson_serialization_type: string + serializer: foo + deserializer: foo + + structs: + reply: + description: foo + fields: + foo: string + """) + + self.assert_bind(test_preamble + textwrap.dedent(""" + commands: + test1: + description: foo + command_name: foo + api_version: "" + namespace: ignored + access_check: + none: true + fields: + foo: string + reply_type: reply + """)) + if __name__ == '__main__': diff --git a/buildscripts/idl/tests/test_parser.py b/buildscripts/idl/tests/test_parser.py index 794b79342cf..87b8b23461d 100644 --- a/buildscripts/idl/tests/test_parser.py +++ b/buildscripts/idl/tests/test_parser.py @@ -885,7 +885,7 @@ class TestParser(testcase.IDLTestcase): deserializer: foo default: foo - structs: + structs: foo: description: foo strict: true @@ -1235,7 +1235,7 @@ class TestParser(testcase.IDLTestcase): # Commands and structs with same name self.assert_parse_fail( test_preamble + textwrap.dedent(""" - commands: + commands: foo: description: foo command_name: foo @@ -1243,7 +1243,7 @@ class TestParser(testcase.IDLTestcase): api_version: "" fields: foo: string - + structs: foo: description: foo @@ -1254,7 +1254,7 @@ class TestParser(testcase.IDLTestcase): # Commands and types with same name self.assert_parse_fail( test_preamble + textwrap.dedent(""" - commands: + commands: string: description: foo command_name: foo @@ -1322,7 +1322,7 @@ class TestParser(testcase.IDLTestcase): fields: foo: type: bar - supports_doc_sequence: false + supports_doc_sequence: false """)) # supports_doc_sequence can be true @@ -1706,6 +1706,25 @@ class TestParser(testcase.IDLTestcase): reply_type: foo_reply_struct """), idl.errors.ERROR_ID_COMMAND_DUPLICATES_NAME_AND_ALIAS) + def test_access_checks_positive(self): + # type: () -> None + """Positive access_check test cases.""" + + self.assert_parse( + textwrap.dedent(""" + commands: + foo: + description: foo + command_name: foo + api_version: 1 + namespace: ignored + access_check: + none: true + fields: + foo: bar + reply_type: foo_reply_struct + """)) + if __name__ == '__main__': diff --git a/src/mongo/idl/unittest.idl b/src/mongo/idl/unittest.idl index ba4fbe95d3a..a53748090e0 100644 --- a/src/mongo/idl/unittest.idl +++ b/src/mongo/idl/unittest.idl @@ -1020,7 +1020,7 @@ commands: reply_type: OkReply fields: anyTypeField: IDLAnyType - + CommandWithAnyTypeOwnedMember: description: "A mock command to test IDLAnyTypeOwned" command_name: CommandWithAnyTypeOwnedMember @@ -1030,6 +1030,16 @@ commands: fields: anyTypeField: IDLAnyTypeOwned + AccessCheckNone: + description: A versioned API command with access_check + command_name: AccessCheckNoneCommandName + namespace: ignored + strict: true + api_version: "1" + access_check: + none: true + reply_type: OkReply + # Test that we correctly generate C++ base classes for versioned API commands with different # key names, command names, and C++ names. APIVersion1CommandIDLName: -- cgit v1.2.1