diff options
-rw-r--r-- | buildscripts/idl/idl/ast.py | 1 | ||||
-rw-r--r-- | buildscripts/idl/idl/binder.py | 1 | ||||
-rw-r--r-- | buildscripts/idl/idl/errors.py | 8 | ||||
-rw-r--r-- | buildscripts/idl/idl/generator.py | 43 | ||||
-rw-r--r-- | buildscripts/idl/idl/parser.py | 5 | ||||
-rw-r--r-- | buildscripts/idl/idl/syntax.py | 6 | ||||
-rw-r--r-- | buildscripts/idl/tests/test_parser.py | 20 | ||||
-rw-r--r-- | src/mongo/db/commands.h | 23 | ||||
-rw-r--r-- | src/mongo/idl/unittest.idl | 10 |
9 files changed, 107 insertions, 10 deletions
diff --git a/buildscripts/idl/idl/ast.py b/buildscripts/idl/idl/ast.py index 9f0127224dc..3b5e5e3d122 100644 --- a/buildscripts/idl/idl/ast.py +++ b/buildscripts/idl/idl/ast.py @@ -239,6 +239,7 @@ class Command(Struct): """Construct a command.""" self.namespace = None # type: str self.command_name = None # type: str + self.command_alias = None # type: str self.command_field = None # type: Field self.reply_type = None # type: Field self.api_version = "" # type: str diff --git a/buildscripts/idl/idl/binder.py b/buildscripts/idl/idl/binder.py index a46c5537fd7..3a6ea97531c 100644 --- a/buildscripts/idl/idl/binder.py +++ b/buildscripts/idl/idl/binder.py @@ -562,6 +562,7 @@ def _bind_command(ctxt, parsed_spec, command): ast_command.api_version = command.api_version ast_command.is_deprecated = command.is_deprecated ast_command.command_name = command.command_name + ast_command.command_alias = command.command_alias # Inject special fields used for command parsing _inject_hidden_command_fields(command) diff --git a/buildscripts/idl/idl/errors.py b/buildscripts/idl/idl/errors.py index f7335f2413a..890025789c6 100644 --- a/buildscripts/idl/idl/errors.py +++ b/buildscripts/idl/idl/errors.py @@ -122,6 +122,7 @@ ERROR_ID_VARIANT_NO_DEFAULT = "ID0079" ERROR_ID_VARIANT_DUPLICATE_TYPES = "ID0080" ERROR_ID_VARIANT_STRUCTS = "ID0081" ERROR_ID_NO_VARIANT_ENUM = "ID0082" +ERROR_ID_COMMAND_DUPLICATES_NAME_AND_ALIAS = "ID0083" class IDLError(Exception): @@ -907,6 +908,13 @@ class ParserContext(object): location, ERROR_ID_ILLEGAL_FIELD_ALWAYS_SERIALIZE_NOT_OPTIONAL, ("Field '%s' specifies 'always_serialize' but 'optional' isn't true.") % (field_name)) + def add_duplicate_command_name_and_alias(self, node): + # type: (yaml.nodes.Node) -> None + """Add an error about a command name and command alias having the same name.""" + # pylint: disable=invalid-name + self._add_node_error(node, ERROR_ID_COMMAND_DUPLICATES_NAME_AND_ALIAS, + "Duplicate command_name and command_alias found.") + def _assert_unique_error_messages(): # type: () -> None diff --git a/buildscripts/idl/idl/generator.py b/buildscripts/idl/idl/generator.py index 38094c787eb..87e0264d5ea 100644 --- a/buildscripts/idl/idl/generator.py +++ b/buildscripts/idl/idl/generator.py @@ -725,6 +725,13 @@ class _CppHeaderFileWriter(_CppFileWriterBase): common.template_args('static constexpr auto kCommandName = "${command_name}"_sd;', command_name=struct.command_name)) + # Initialize constexpr for command alias if specified in the IDL spec. + if struct.command_alias: + self._writer.write_line( + common.template_args( + 'static constexpr auto kCommandAlias = "${command_alias}"_sd;', + command_alias=struct.command_alias)) + def gen_enum_functions(self, idl_enum): # type: (ast.Enum) -> None """Generate the declaration for an enum's supporting functions.""" @@ -899,7 +906,7 @@ class _CppHeaderFileWriter(_CppFileWriterBase): # Explicit custom constructor. self._writer.write_line(cls.name + '(StringData name, ServerParameterType spt);') else: - #Inherit base constructor. + # Inherit base constructor. self._writer.write_line('using ServerParameter::ServerParameter;') self.write_empty_line() @@ -935,6 +942,20 @@ class _CppHeaderFileWriter(_CppFileWriterBase): self._writer.write_line( 'using %s = %s;' % (new_type_name, common.title_case(old_type_name))) + def gen_derived_class_constructor(self, command_name, api_version, base_class, + *base_class_args): + # type: (str, str, str, *str) -> None + """Generate a derived class constructor.""" + class_name = common.title_case(command_name) + "CmdVersion" + api_version + "Gen" + args = ", ".join(base_class_args) + self._writer.write_line('%s(): %s(%s) {}' % (class_name, base_class, args)) + + def gen_derived_class_destructor(self, command_name, api_version): + # type: (str, str) -> None + """Generate a derived class destructor.""" + class_name = common.title_case(command_name) + "CmdVersion" + api_version + "Gen" + self._writer.write_line('virtual ~%s() = default;' % (class_name)) + def gen_api_version_fn(self, is_api_versions, api_version): # type: (bool, Union[str, bool]) -> None """Generate an apiVersions or deprecatedApiVersions function for a command's base class.""" @@ -980,6 +1001,20 @@ class _CppHeaderFileWriter(_CppFileWriterBase): self.gen_type_alias_declaration("Request", command.cpp_name) self.gen_type_alias_declaration("Reply", command.reply_type.type.cpp_type) + # Generate a constructor for generated derived class if command alias is specified. + if command.command_alias: + self.write_empty_line() + self.gen_derived_class_constructor(command.command_name, command.api_version, + 'TypedCommand<Derived>', 'Request::kCommandName', + 'Request::kCommandAlias') + + self.write_empty_line() + + # Generate a destructor for generated derived class. + self.gen_derived_class_destructor(command.command_name, command.api_version) + + self.write_empty_line() + # Write apiVersions() and deprecatedApiVersions() functions. self.gen_api_version_fn(True, command.api_version) self.gen_api_version_fn(False, command.is_deprecated) @@ -2158,6 +2193,12 @@ class _CppSourceFileWriter(_CppFileWriterBase): common.template_args('constexpr StringData ${class_name}::kCommandName;', class_name=common.title_case(struct.cpp_name))) + # Declare constexp for commmand alias if specified in the IDL spec. + if struct.command_alias: + self._writer.write_line( + common.template_args('constexpr StringData ${class_name}::kCommandAlias;', + class_name=common.title_case(struct.cpp_name))) + def gen_enum_definition(self, idl_enum): # type: (ast.Enum) -> None """Generate the definitions for an enum's supporting functions.""" diff --git a/buildscripts/idl/idl/parser.py b/buildscripts/idl/idl/parser.py index e3a5a9eb424..ccc07d1b27f 100644 --- a/buildscripts/idl/idl/parser.py +++ b/buildscripts/idl/idl/parser.py @@ -685,6 +685,7 @@ def _parse_enum(ctxt, spec, name, node): def _parse_command(ctxt, spec, name, node): # type: (errors.ParserContext, syntax.IDLSpec, str, Union[yaml.nodes.MappingNode, yaml.nodes.ScalarNode, yaml.nodes.SequenceNode]) -> None """Parse a command section in the IDL file.""" + # pylint: disable=too-many-branches if not ctxt.is_mapping_node(node, "command"): @@ -703,6 +704,7 @@ def _parse_command(ctxt, spec, name, node): "cpp_name": _RuleDesc('scalar'), "type": _RuleDesc('scalar_or_mapping', mapping_parser_func=_parse_field_type), "command_name": _RuleDesc('scalar'), + "command_alias": _RuleDesc('scalar'), "reply_type": _RuleDesc('scalar'), "api_version": _RuleDesc('scalar'), "is_deprecated": _RuleDesc('bool_scalar'), @@ -725,6 +727,9 @@ def _parse_command(ctxt, spec, name, node): if command.api_version is None: ctxt.add_missing_required_field_error(node, "command", "api_version") + if command.command_alias and command.command_alias == command.command_name: + ctxt.add_duplicate_command_name_and_alias(node) + if command.namespace: if command.namespace not in valid_commands: ctxt.add_bad_command_namespace_error(command, command.name, command.namespace, diff --git a/buildscripts/idl/idl/syntax.py b/buildscripts/idl/idl/syntax.py index c721cd10404..bcc854ee796 100644 --- a/buildscripts/idl/idl/syntax.py +++ b/buildscripts/idl/idl/syntax.py @@ -131,7 +131,7 @@ class SymbolTable(object): ctxt.add_duplicate_symbol_error(location, name, duplicate_class_name, entity_type) return True if entity_type == "command": - if item.command_name == name: + if name in [item.command_name, item.command_alias if item.command_alias else '']: ctxt.add_duplicate_symbol_error(location, name, duplicate_class_name, entity_type) return True @@ -159,7 +159,8 @@ class SymbolTable(object): def add_command(self, ctxt, command): # type: (errors.ParserContext, Command) -> None """Add an IDL command to the symbol table and check for duplicates.""" - if not self._is_duplicate(ctxt, command, command.name, "command"): + if (not self._is_duplicate(ctxt, command, command.name, "command") + and not self._is_duplicate(ctxt, command, command.command_alias, "command")): self.commands.append(command) def add_generic_argument_list(self, ctxt, field_list): @@ -529,6 +530,7 @@ class Command(Struct): """Construct a Command.""" self.namespace = None # type: str self.command_name = None # type: str + self.command_alias = None # type: str self.type = None # type: FieldType self.reply_type = None # type: str self.api_version = None # type: str diff --git a/buildscripts/idl/tests/test_parser.py b/buildscripts/idl/tests/test_parser.py index 9486bb467af..77640793029 100644 --- a/buildscripts/idl/tests/test_parser.py +++ b/buildscripts/idl/tests/test_parser.py @@ -1610,6 +1610,26 @@ class TestParser(testcase.IDLTestcase): """Negative generic reply fields list test cases.""" self._test_field_list_negative("generic_reply_field_lists", "forward_from_shards") + def test_command_alias(self): + # type: () -> None + """Test the 'command_alis' field.""" + + # The 'command_name' and 'command_alias' fields cannot have same value. + self.assert_parse_fail( + textwrap.dedent(f""" + commands: + foo: + description: foo + command_name: foo + command_alias: foo + namespace: ignored + api_version: 1 + fields: + foo: + type: bar + reply_type: foo_reply_struct + """), idl.errors.ERROR_ID_COMMAND_DUPLICATES_NAME_AND_ALIAS) + if __name__ == '__main__': diff --git a/src/mongo/db/commands.h b/src/mongo/db/commands.h index 6c9e2cdc668..2d3df2eb106 100644 --- a/src/mongo/db/commands.h +++ b/src/mongo/db/commands.h @@ -1118,9 +1118,7 @@ public: InvocationBaseInternal(OperationContext* opCtx, const Command* command, const OpMsgRequest& opMsgRequest) - : CommandInvocation(command), - - _request{_parseRequest(opCtx, command->getName(), opMsgRequest)} {} + : CommandInvocation(command), _request{_parseRequest(opCtx, command, opMsgRequest)} {} protected: const RequestType& request() const { @@ -1129,11 +1127,22 @@ protected: private: static RequestType _parseRequest(OperationContext* opCtx, - StringData name, + const Command* command, const OpMsgRequest& opMsgRequest) { - return RequestType::parse( - IDLParserErrorContext(name, APIParameters::get(opCtx).getAPIStrict().value_or(false)), - opMsgRequest); + + bool apiStrict = APIParameters::get(opCtx).getAPIStrict().value_or(false); + + // A command with 'apiStrict' cannot be invoked with alias. + if (opMsgRequest.getCommandName() != command->getName() && apiStrict) { + uasserted(ErrorCodes::APIStrictError, + str::stream() << "Command invocation with name '" + << opMsgRequest.getCommandName().toString() + << "' is not allowed in 'apiStrict' mode, use '" + << command->getName() << "' instead"); + } + + return RequestType::parse(IDLParserErrorContext(command->getName(), apiStrict), + opMsgRequest); } RequestType _request; diff --git a/src/mongo/idl/unittest.idl b/src/mongo/idl/unittest.idl index 1f1af1ceb24..b53ea3e28ca 100644 --- a/src/mongo/idl/unittest.idl +++ b/src/mongo/idl/unittest.idl @@ -1042,6 +1042,16 @@ commands: api_version: "1" reply_type: OkReply + # Test whether the C++ code for a command with alias name is currently generated. + APIVersion1CommandWithAlias: + description: A versioned API command with alias + command_name: NewCommandName + command_alias: OldCommandName + namespace: ignored + strict: true + api_version: "1" + reply_type: OkReply + ################################################################################################## # # Test field lists |