diff options
author | Simon Wang <jellalleonhardt4869@gmail.com> | 2021-09-13 19:50:45 +0800 |
---|---|---|
committer | Yuxuan 'fishy' Wang <yuxuan.wang@reddit.com> | 2022-09-11 08:25:32 -0700 |
commit | d5927a96019154fa590c38f3a7ca70275af11b3c (patch) | |
tree | 4c0a412c9d0bd333e24f21cc6d13bedadc487268 /compiler | |
parent | 944b8e68a099392d80153ebcf26f32ff7f1d893a (diff) | |
download | thrift-d5927a96019154fa590c38f3a7ca70275af11b3c.tar.gz |
THRIFT-5423: IDL parameter validation for Go
Closes https://github.com/apache/thrift/pull/2469.
Diffstat (limited to 'compiler')
24 files changed, 2236 insertions, 361 deletions
diff --git a/compiler/cpp/CMakeLists.txt b/compiler/cpp/CMakeLists.txt index a23004110..b0f123555 100644 --- a/compiler/cpp/CMakeLists.txt +++ b/compiler/cpp/CMakeLists.txt @@ -46,6 +46,8 @@ add_library(parse STATIC ${parse_SOURCES}) set(compiler_core src/thrift/common.cc src/thrift/generate/t_generator.cc + src/thrift/generate/validator_parser.cc + src/thrift/generate/validator_parser.h src/thrift/parse/t_typedef.cc src/thrift/parse/parse.cc src/thrift/version.h @@ -71,36 +73,51 @@ macro(THRIFT_ADD_COMPILER name description initial) endif() endmacro() +# This macro adds an option THRIFT_VALIDATOR_COMPILER_${NAME} +# that allows enabling or disabling certain languages' validator +macro(THRIFT_ADD_VALIDATOR_COMPILER name description initial) + string(TOUPPER "THRIFT_COMPILER_${name}" enabler) + set(src "src/thrift/generate/${name}_validator_generator.cc") + list(APPEND "src/thrift/generate/${name}_validator_generator.h") + option(${enabler} ${description} ${initial}) + if(${enabler}) + list(APPEND thrift-compiler_SOURCES ${src}) + endif() +endmacro() + # The following compiler can be enabled or disabled -THRIFT_ADD_COMPILER(c_glib "Enable compiler for C with Glib" ON) -THRIFT_ADD_COMPILER(cl "Enable compiler for Common LISP" ON) -THRIFT_ADD_COMPILER(cpp "Enable compiler for C++" ON) -THRIFT_ADD_COMPILER(d "Enable compiler for D" ON) -THRIFT_ADD_COMPILER(dart "Enable compiler for Dart" ON) -THRIFT_ADD_COMPILER(delphi "Enable compiler for Delphi" ON) -THRIFT_ADD_COMPILER(erl "Enable compiler for Erlang" ON) -THRIFT_ADD_COMPILER(go "Enable compiler for Go" ON) -THRIFT_ADD_COMPILER(gv "Enable compiler for GraphViz" ON) -THRIFT_ADD_COMPILER(haxe "Enable compiler for Haxe" ON) -THRIFT_ADD_COMPILER(html "Enable compiler for HTML Documentation" ON) -THRIFT_ADD_COMPILER(markdown "Enable compiler for Markdown Documentation" ON) -THRIFT_ADD_COMPILER(java "Enable compiler for Java" ON) -THRIFT_ADD_COMPILER(javame "Enable compiler for Java ME" ON) -THRIFT_ADD_COMPILER(js "Enable compiler for JavaScript" ON) -THRIFT_ADD_COMPILER(json "Enable compiler for JSON" ON) -THRIFT_ADD_COMPILER(kotlin "Enable compiler for Kotlin" ON) -THRIFT_ADD_COMPILER(lua "Enable compiler for Lua" ON) -THRIFT_ADD_COMPILER(netstd "Enable compiler for .NET Standard" ON) -THRIFT_ADD_COMPILER(ocaml "Enable compiler for OCaml" ON) -THRIFT_ADD_COMPILER(perl "Enable compiler for Perl" ON) -THRIFT_ADD_COMPILER(php "Enable compiler for PHP" ON) -THRIFT_ADD_COMPILER(py "Enable compiler for Python 2.0" ON) -THRIFT_ADD_COMPILER(rb "Enable compiler for Ruby" ON) -THRIFT_ADD_COMPILER(rs "Enable compiler for Rust" ON) -THRIFT_ADD_COMPILER(st "Enable compiler for Smalltalk" ON) -THRIFT_ADD_COMPILER(swift "Enable compiler for Cocoa Swift" ON) -THRIFT_ADD_COMPILER(xml "Enable compiler for XML" ON) -THRIFT_ADD_COMPILER(xsd "Enable compiler for XSD" ON) +THRIFT_ADD_COMPILER(c_glib "Enable compiler for C with Glib" ON) +THRIFT_ADD_COMPILER(cl "Enable compiler for Common LISP" ON) +THRIFT_ADD_COMPILER(cpp "Enable compiler for C++" ON) +THRIFT_ADD_COMPILER(d "Enable compiler for D" ON) +THRIFT_ADD_COMPILER(dart "Enable compiler for Dart" ON) +THRIFT_ADD_COMPILER(delphi "Enable compiler for Delphi" ON) +THRIFT_ADD_COMPILER(erl "Enable compiler for Erlang" ON) +THRIFT_ADD_COMPILER(go "Enable compiler for Go" ON) +THRIFT_ADD_COMPILER(gv "Enable compiler for GraphViz" ON) +THRIFT_ADD_COMPILER(haxe "Enable compiler for Haxe" ON) +THRIFT_ADD_COMPILER(html "Enable compiler for HTML Documentation" ON) +THRIFT_ADD_COMPILER(markdown "Enable compiler for Markdown Documentation" ON) +THRIFT_ADD_COMPILER(java "Enable compiler for Java" ON) +THRIFT_ADD_COMPILER(javame "Enable compiler for Java ME" ON) +THRIFT_ADD_COMPILER(js "Enable compiler for JavaScript" ON) +THRIFT_ADD_COMPILER(json "Enable compiler for JSON" ON) +THRIFT_ADD_COMPILER(kotlin "Enable compiler for Kotlin" ON) +THRIFT_ADD_COMPILER(lua "Enable compiler for Lua" ON) +THRIFT_ADD_COMPILER(netstd "Enable compiler for .NET Standard" ON) +THRIFT_ADD_COMPILER(ocaml "Enable compiler for OCaml" ON) +THRIFT_ADD_COMPILER(perl "Enable compiler for Perl" ON) +THRIFT_ADD_COMPILER(php "Enable compiler for PHP" ON) +THRIFT_ADD_COMPILER(py "Enable compiler for Python 2.0" ON) +THRIFT_ADD_COMPILER(rb "Enable compiler for Ruby" ON) +THRIFT_ADD_COMPILER(rs "Enable compiler for Rust" ON) +THRIFT_ADD_COMPILER(st "Enable compiler for Smalltalk" ON) +THRIFT_ADD_COMPILER(swift "Enable compiler for Cocoa Swift" ON) +THRIFT_ADD_COMPILER(xml "Enable compiler for XML" ON) +THRIFT_ADD_COMPILER(xsd "Enable compiler for XSD" ON) + +# The following compiler can be enabled or disabled by enabling or disabling certain languages +THRIFT_ADD_VALIDATOR_COMPILER(go "Enable validator compiler for Go" ON) # Thrift is looking for include files in the src directory # we also add the current binary directory for generated files diff --git a/compiler/cpp/Makefile.am b/compiler/cpp/Makefile.am index 55c82943f..bb29d8c47 100644 --- a/compiler/cpp/Makefile.am +++ b/compiler/cpp/Makefile.am @@ -77,6 +77,7 @@ thrift_SOURCES += src/thrift/generate/t_c_glib_generator.cc \ src/thrift/generate/t_delphi_generator.cc \ src/thrift/generate/t_erl_generator.cc \ src/thrift/generate/t_go_generator.cc \ + src/thrift/generate/t_go_generator.h \ src/thrift/generate/t_gv_generator.cc \ src/thrift/generate/t_haxe_generator.cc \ src/thrift/generate/t_html_generator.cc \ @@ -98,7 +99,11 @@ thrift_SOURCES += src/thrift/generate/t_c_glib_generator.cc \ src/thrift/generate/t_st_generator.cc \ src/thrift/generate/t_swift_generator.cc \ src/thrift/generate/t_xml_generator.cc \ - src/thrift/generate/t_xsd_generator.cc + src/thrift/generate/t_xsd_generator.cc \ + src/thrift/generate/validator_parser.cc \ + src/thrift/generate/validator_parser.h \ + src/thrift/generate/go_validator_generator.cc \ + src/thrift/generate/go_validator_generator.h thrift_CPPFLAGS = -I$(srcdir)/src thrift_CXXFLAGS = -Wall -Wextra -pedantic -Werror diff --git a/compiler/cpp/src/thrift/generate/go_validator_generator.cc b/compiler/cpp/src/thrift/generate/go_validator_generator.cc new file mode 100644 index 000000000..1f5a3ad6e --- /dev/null +++ b/compiler/cpp/src/thrift/generate/go_validator_generator.cc @@ -0,0 +1,906 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/* + * This file is programmatically sanitized for style: + * astyle --style=1tbs -f -p -H -j -U go_validator_generator.cc + * + * The output of astyle should not be taken unquestioningly, but it is a good + * guide for ensuring uniformity and readability. + */ + +#include <fstream> +#include <iostream> +#include <limits> +#include <string> +#include <unordered_map> +#include <vector> + +#include "thrift/generate/go_validator_generator.h" +#include "thrift/generate/validator_parser.h" +#include "thrift/platform.h" +#include "thrift/version.h" +#include <algorithm> +#include <clocale> +#include <sstream> +#include <stdlib.h> +#include <sys/stat.h> +#include <sys/types.h> + +std::string go_validator_generator::get_field_reference_name(t_field* field) { + t_type* type(field->get_type()); + std::string tgt; + t_const_value* def_value; + go_generator->get_publicized_name_and_def_value(field, &tgt, &def_value); + tgt = "p." + tgt; + if (go_generator->is_pointer_field(field) + && (type->is_base_type() || type->is_enum() || type->is_container())) { + tgt = "*" + tgt; + } + return tgt; +} + +void go_validator_generator::generate_struct_validator(std::ostream& out, t_struct* tstruct) { + std::vector<t_field*> members = tstruct->get_members(); + validation_parser parser(tstruct); + for (auto it = members.begin(); it != members.end(); it++) { + t_field* field(*it); + const std::vector<validation_rule*>& rules + = parser.parse_field(field->get_type(), field->annotations_); + if (rules.size() == 0) { + continue; + } + bool opt = field->get_req() == t_field::T_OPTIONAL; + t_type* type = field->get_type(); + std::string tgt = get_field_reference_name(field); + std::string field_symbol = tstruct->get_name() + "." + field->get_name(); + generate_field_validator(out, generator_context{field_symbol, "", tgt, opt, type, rules}); + } +} + +void go_validator_generator::generate_field_validator(std::ostream& out, + const generator_context& context) { + t_type* type = context.type; + if (type->is_typedef()) { + type = type->get_true_type(); + } + if (type->is_enum()) { + if (context.tgt[0] == '*') { + out << indent() << "if " << context.tgt.substr(1) << " != nil {" << endl; + indent_up(); + } + generate_enum_field_validator(out, context); + if (context.tgt[0] == '*') { + indent_down(); + out << indent() << "}" << endl; + } + return; + } else if (type->is_base_type()) { + t_base_type::t_base tbase = ((t_base_type*)type)->get_base(); + if (context.tgt[0] == '*') { + out << indent() << "if " << context.tgt.substr(1) << " != nil {" << endl; + indent_up(); + } + switch (tbase) { + case t_base_type::TYPE_UUID: + case t_base_type::TYPE_VOID: + break; + case t_base_type::TYPE_I8: + case t_base_type::TYPE_I16: + case t_base_type::TYPE_I32: + case t_base_type::TYPE_I64: + generate_integer_field_validator(out, context); + break; + case t_base_type::TYPE_DOUBLE: + generate_double_field_validator(out, context); + break; + case t_base_type::TYPE_STRING: + generate_string_field_validator(out, context); + break; + case t_base_type::TYPE_BOOL: + generate_bool_field_validator(out, context); + break; + } + if (context.tgt[0] == '*') { + indent_down(); + out << indent() << "}" << endl; + } + return; + } else if (type->is_list()) { + return generate_list_field_validator(out, context); + } else if (type->is_set()) { + return generate_set_field_validator(out, context); + } else if (type->is_map()) { + return generate_map_field_validator(out, context); + } else if (type->is_struct() || type->is_xception()) { + return generate_struct_field_validator(out, context); + } + throw "validator error: unsupported type: " + type->get_name(); +} + +void go_validator_generator::generate_enum_field_validator(std::ostream& out, + const generator_context& context) { + for (auto it = context.rules.begin(); it != context.rules.end(); it++) { + const std::vector<validation_value*>& values = (*it)->get_values(); + if (values.size() == 0) { + continue; + } + std::string key = (*it)->get_name(); + + if (key == "vt.in") { + if (values.size() > 1) { + std::string exist = GenID("_exist"); + out << indent() << "var " << exist << " bool" << endl; + + std::string src = GenID("_src"); + out << indent() << src << " := []int64{"; + for (auto it = values.begin(); it != values.end(); it++) { + if (it != values.begin()) { + out << ", "; + } + out << "int64("; + if ((*it)->is_field_reference()) { + out << get_field_reference_name((*it)->get_field_reference()); + } else { + out << (*it)->get_enum()->get_value(); + } + out << ")"; + } + out << "}" << endl; + + out << indent() << "for _, src := range " << src << " {" << endl; + indent_up(); + out << indent() << "if int64(" << context.tgt << ") == src {" << endl; + indent_up(); + out << indent() << exist << " = true" << endl; + out << indent() << "break" << endl; + indent_down(); + out << indent() << "}" << endl; + indent_down(); + out << indent() << "}" << endl; + out << indent() << "if " << exist << " == false {" << endl; + } else { + out << indent() << "if int64(" << context.tgt << ") != int64("; + if (values[0]->is_field_reference()) { + out << get_field_reference_name(values[0]->get_field_reference()); + } else { + out << values[0]->get_enum()->get_value(); + } + out << ") {" << endl; + } + indent_up(); + out << indent() + << "return thrift.NewValidationException(thrift.VALIDATION_FAILED, \"vt.in\", \"" + << context.field_symbol << "\", \"" << context.field_symbol + << " not valid, rule vt.in check failed\")" << endl; + indent_down(); + out << indent() << "}" << endl; + if (values.size() > 1) { + indent_down(); + out << indent() << "}" << endl; + } + } else if (key == "vt.not_in") { + if (values.size() > 1) { + std::string src = GenID("_src"); + out << indent() << src << " := []int64{"; + for (auto it = values.begin(); it != values.end(); it++) { + if (it != values.begin()) { + out << ", "; + } + out << "int64("; + if ((*it)->is_field_reference()) { + out << get_field_reference_name((*it)->get_field_reference()); + } else { + out << (*it)->get_enum()->get_value(); + } + out << ")"; + } + out << "}" << endl; + + out << indent() << "for _, src := range " << src << " {" << endl; + indent_up(); + out << indent() << "if int64(" << context.tgt << ") == src {" << endl; + } else { + out << indent() << "if int64(" << context.tgt << ") == "; + out << "int64("; + if (values[0]->is_field_reference()) { + out << get_field_reference_name(values[0]->get_field_reference()); + } else { + out << values[0]->get_enum()->get_value(); + } + out << ") {" << endl; + } + indent_up(); + out << indent() + << "return thrift.NewValidationException(thrift.VALIDATION_FAILED, \"vt.not_in\", \"" + << context.field_symbol << "\", \"" << context.field_symbol + << " not valid, rule vt.not_in check failed\")" << endl; + indent_down(); + out << indent() << "}" << endl; + if (values.size() > 1) { + indent_down(); + out << indent() << "}" << endl; + } + } else if (key == "vt.defined_only") { + if (values[0]->get_bool()) { + out << indent() << "if (" << context.tgt << ").String() == \"<UNSET>\" "; + } else { + continue; + } + out << "{" << endl; + indent_up(); + out << indent() + << "return thrift.NewValidationException(thrift.VALIDATION_FAILED, \"" + key + "\", \"" + << context.field_symbol << "\", \"" << context.field_symbol << " not valid, rule " << key + << " check failed\")" << endl; + indent_down(); + out << indent() << "}" << endl; + } + } +} + +void go_validator_generator::generate_bool_field_validator(std::ostream& out, + const generator_context& context) { + for (auto it = context.rules.begin(); it != context.rules.end(); it++) { + const std::vector<validation_value*>& values = (*it)->get_values(); + if (values.size() == 0) { + continue; + } + std::string key = (*it)->get_name(); + + if (key == "vt.const") { + out << indent() << "if " << context.tgt << " != "; + if (values[0]->is_field_reference()) { + out << get_field_reference_name(values[0]->get_field_reference()); + } else { + if (values[0]->get_bool()) { + out << "true"; + } else { + out << "false"; + } + } + } + out << "{" << endl; + indent_up(); + out << indent() + << "return thrift.NewValidationException(thrift.VALIDATION_FAILED, \"" + key + "\", \"" + << context.field_symbol << "\", \"" << context.field_symbol << " not valid, rule " << key + << " check failed\")" << endl; + indent_down(); + out << indent() << "}" << endl; + } +} + +void go_validator_generator::generate_double_field_validator(std::ostream& out, + const generator_context& context) { + for (auto it = context.rules.begin(); it != context.rules.end(); it++) { + const std::vector<validation_value*>& values = (*it)->get_values(); + if (values.size() == 0) { + continue; + } + + std::map<std::string, std::string> signs{{"vt.lt", ">="}, + {"vt.le", ">"}, + {"vt.gt", "<="}, + {"vt.ge", "<"}}; + std::string key = (*it)->get_name(); + auto key_it = signs.find(key); + if (key_it != signs.end()) { + out << indent() << "if " << context.tgt << " " << key_it->second << " "; + if (values[0]->is_field_reference()) { + out << get_field_reference_name(values[0]->get_field_reference()); + } else { + out << values[0]->get_double(); + } + out << "{" << endl; + indent_up(); + out << indent() + << "return thrift.NewValidationException(thrift.VALIDATION_FAILED, \"" + key + "\", \"" + << context.field_symbol << "\", \"" << context.field_symbol << " not valid, rule " << key + << " check failed\")" << endl; + indent_down(); + out << indent() << "}" << endl; + continue; + } else if (key == "vt.in") { + if (values.size() > 1) { + std::string exist = GenID("_exist"); + out << indent() << "var " << exist << " bool" << endl; + + std::string src = GenID("_src"); + out << indent() << src << " := []float64{"; + for (auto it = values.begin(); it != values.end(); it++) { + if (it != values.begin()) { + out << ", "; + } + if ((*it)->is_field_reference()) { + out << get_field_reference_name((*it)->get_field_reference()); + } else { + out << (*it)->get_double(); + } + } + out << "}" << endl; + + out << indent() << "for _, src := range " << src << " {" << endl; + indent_up(); + out << indent() << "if " << context.tgt << " == src {" << endl; + indent_up(); + out << indent() << exist << " = true" << endl; + out << indent() << "break" << endl; + indent_down(); + out << indent() << "}" << endl; + indent_down(); + out << indent() << "}" << endl; + out << indent() << "if " << exist << " == false {" << endl; + } else { + out << indent() << "if " << context.tgt << " != "; + if (values[0]->is_field_reference()) { + out << get_field_reference_name(values[0]->get_field_reference()); + } else { + out << values[0]->get_double(); + } + out << "{" << endl; + } + + indent_up(); + out << indent() + << "return thrift.NewValidationException(thrift.VALIDATION_FAILED, \"vt.in\", \"" + << context.field_symbol << "\", \"" << context.field_symbol + << " not valid, rule vt.in check failed\")" << endl; + indent_down(); + out << indent() << "}" << endl; + } else if (key == "vt.not_in") { + if (values.size() > 1) { + std::string src = GenID("_src"); + out << indent() << src << " := []float64{"; + for (auto it = values.begin(); it != values.end(); it++) { + if (it != values.begin()) { + out << ", "; + } + if ((*it)->is_field_reference()) { + out << get_field_reference_name((*it)->get_field_reference()); + } else { + out << (*it)->get_double(); + } + } + out << "}" << endl; + + out << indent() << "for _, src := range " << src << " {" << endl; + indent_up(); + out << indent() << "if " << context.tgt << " == src {" << endl; + } else { + out << indent() << "if " << context.tgt << " == "; + if (values[0]->is_field_reference()) { + out << get_field_reference_name(values[0]->get_field_reference()); + } else { + out << values[0]->get_double(); + } + out << "{" << endl; + } + indent_up(); + out << indent() + << "return thrift.NewValidationException(thrift.VALIDATION_FAILED, \"vt.not_in\", \"" + << context.field_symbol << "\", \"" << context.field_symbol + << " not valid, rule vt.not_in check failed\")" << endl; + indent_down(); + out << indent() << "}" << endl; + if (values.size() > 1) { + indent_down(); + out << indent() << "}" << endl; + } + } + } +} + +void go_validator_generator::generate_integer_field_validator(std::ostream& out, + const generator_context& context) { + auto generate_current_type = [](std::ostream& out, t_type* type) { + t_base_type::t_base tbase = ((t_base_type*)type)->get_base(); + switch (tbase) { + case t_base_type::TYPE_I8: + out << "int8"; + break; + case t_base_type::TYPE_I16: + out << "int16"; + break; + case t_base_type::TYPE_I32: + out << "int32"; + break; + case t_base_type::TYPE_I64: + out << "int64"; + break; + default: + throw "validator error: unsupported integer type: " + type->get_name(); + } + }; + + for (auto it = context.rules.begin(); it != context.rules.end(); it++) { + const std::vector<validation_value*>& values = (*it)->get_values(); + if (values.size() == 0) { + continue; + } + + std::map<std::string, std::string> signs{{"vt.lt", ">="}, + {"vt.le", ">"}, + {"vt.gt", "<="}, + {"vt.ge", "<"}}; + std::string key = (*it)->get_name(); + auto key_it = signs.find(key); + if (key_it != signs.end()) { + out << indent() << "if " << context.tgt << " " << key_it->second << " "; + if (values[0]->is_field_reference()) { + out << get_field_reference_name(values[0]->get_field_reference()); + } else if (values[0]->is_validation_function()) { + generate_current_type(out, context.type); + out << "("; + validation_value::validation_function* func = values[0]->get_function(); + if (func->name == "len") { + out << "len("; + if (func->arguments[0]->is_field_reference()) { + out << get_field_reference_name(func->arguments[0]->get_field_reference()); + } + out << ")"; + } + out << ")"; + } else { + out << values[0]->get_int(); + } + out << "{" << endl; + indent_up(); + out << indent() + << "return thrift.NewValidationException(thrift.VALIDATION_FAILED, \"" + key + "\", \"" + << context.field_symbol << "\", \"" << context.field_symbol << " not valid, rule " << key + << " check failed\")" << endl; + indent_down(); + out << indent() << "}" << endl; + } else if (key == "vt.in") { + if (values.size() > 1) { + std::string exist = GenID("_exist"); + out << indent() << "var " << exist << " bool" << endl; + + std::string src = GenID("_src"); + out << indent() << src << " := []"; + generate_current_type(out, context.type); + out << "{"; + for (auto it = values.begin(); it != values.end(); it++) { + if (it != values.begin()) { + out << ", "; + } + if ((*it)->is_field_reference()) { + out << get_field_reference_name((*it)->get_field_reference()); + } else if ((*it)->is_validation_function()) { + generate_current_type(out, context.type); + out << "("; + validation_value::validation_function* func = (*it)->get_function(); + if (func->name == "len") { + out << "len("; + if (func->arguments[0]->is_field_reference()) { + out << get_field_reference_name(func->arguments[0]->get_field_reference()); + } + out << ")"; + } + out << ")"; + } else { + out << (*it)->get_int(); + } + } + out << "}" << endl; + + out << indent() << "for _, src := range " << src << " {" << endl; + indent_up(); + out << indent() << "if " << context.tgt << " == src {" << endl; + indent_up(); + out << indent() << exist << " = true" << endl; + out << indent() << "break" << endl; + indent_down(); + out << indent() << "}" << endl; + indent_down(); + out << indent() << "}" << endl; + out << indent() << "if " << exist << " == false {" << endl; + } else { + out << indent() << "if " << context.tgt << " != "; + if (values[0]->is_field_reference()) { + out << get_field_reference_name(values[0]->get_field_reference()); + } else if (values[0]->is_validation_function()) { + generate_current_type(out, context.type); + out << "("; + validation_value::validation_function* func = values[0]->get_function(); + if (func->name == "len") { + out << "len("; + if (func->arguments[0]->is_field_reference()) { + out << get_field_reference_name(func->arguments[0]->get_field_reference()); + } + out << ")"; + } + out << ")"; + } else { + out << values[0]->get_int(); + } + out << "{" << endl; + } + indent_up(); + out << indent() + << "return thrift.NewValidationException(thrift.VALIDATION_FAILED, \"vt.in\", \"" + << context.field_symbol << "\", \"" << context.field_symbol + << " not valid, rule vt.in check failed\")" << endl; + indent_down(); + out << indent() << "}" << endl; + } else if (key == "vt.not_in") { + if (values.size() > 1) { + std::string src = GenID("_src"); + out << indent() << src << " := []"; + t_base_type::t_base tbase = ((t_base_type*)context.type)->get_base(); + switch (tbase) { + case t_base_type::TYPE_I8: + out << "int8"; + break; + case t_base_type::TYPE_I16: + out << "int16"; + break; + case t_base_type::TYPE_I32: + out << "int32"; + break; + case t_base_type::TYPE_I64: + out << "int64"; + break; + default: + throw "validator error: unsupported integer type: " + context.type->get_name(); + } + out << "{"; + for (auto it = values.begin(); it != values.end(); it++) { + if (it != values.begin()) { + out << ", "; + } + if ((*it)->is_field_reference()) { + out << get_field_reference_name((*it)->get_field_reference()); + } else if ((*it)->is_validation_function()) { + generate_current_type(out, context.type); + out << "("; + validation_value::validation_function* func = (*it)->get_function(); + if (func->name == "len") { + out << "len("; + if (func->arguments[0]->is_field_reference()) { + out << get_field_reference_name(func->arguments[0]->get_field_reference()); + } + out << ")"; + } + out << ")"; + } else { + out << (*it)->get_int(); + } + } + out << "}" << endl; + + out << indent() << "for _, src := range " << src << " {" << endl; + indent_up(); + out << indent() << "if " << context.tgt << " == src {" << endl; + } else { + out << indent() << "if " << context.tgt << " == "; + if (values[0]->is_field_reference()) { + out << get_field_reference_name(values[0]->get_field_reference()); + } else if (values[0]->is_validation_function()) { + generate_current_type(out, context.type); + out << "("; + validation_value::validation_function* func = values[0]->get_function(); + if (func->name == "len") { + out << "len("; + if (func->arguments[0]->is_field_reference()) { + out << get_field_reference_name(func->arguments[0]->get_field_reference()); + } + out << ")"; + } + out << ")"; + } else { + out << values[0]->get_int(); + } + out << "{" << endl; + } + indent_up(); + out << indent() + << "return thrift.NewValidationException(thrift.VALIDATION_FAILED, \"vt.not_in\", \"" + << context.field_symbol << "\", \"" << context.field_symbol + << " not valid, rule vt.not_in check failed\")" << endl; + indent_down(); + out << indent() << "}" << endl; + if (values.size() > 1) { + indent_down(); + out << indent() << "}" << endl; + } + } + } +} + +void go_validator_generator::generate_string_field_validator(std::ostream& out, + const generator_context& context) { + std::string target = context.tgt; + t_type* type = context.type; + if (type->is_typedef()) { + type = type->get_true_type(); + } + if (type->is_binary()) { + target = GenID("_tgt"); + out << indent() << target << " := " + << "string(" << context.tgt << ")" << endl; + } + for (auto it = context.rules.begin(); it != context.rules.end(); it++) { + const std::vector<validation_value*>& values = (*it)->get_values(); + if (values.size() == 0) { + continue; + } + std::string key = (*it)->get_name(); + + if (key == "vt.const") { + out << indent() << "if " << target << " != "; + if (values[0]->is_field_reference()) { + out << "string("; + out << get_field_reference_name(values[0]->get_field_reference()); + out << ")"; + } else { + out << "\"" << values[0]->get_string() << "\""; + } + } else if (key == "vt.min_size" || key == "vt.max_size") { + out << indent() << "if len(" << target << ") "; + if (key == "vt.min_size") { + out << "<"; + } else { + out << ">"; + } + out << " int("; + if (values[0]->is_field_reference()) { + out << get_field_reference_name(values[0]->get_field_reference()); + } else if (values[0]->is_validation_function()) { + validation_value::validation_function* func = values[0]->get_function(); + if (func->name == "len") { + out << "len("; + if (func->arguments[0]->is_field_reference()) { + out << "string("; + out << get_field_reference_name(values[0]->get_field_reference()); + out << ")"; + } + out << ")"; + } + } else { + out << values[0]->get_int(); + } + out << ")"; + } else if (key == "vt.pattern") { + out << indent() << "if ok, _ := regexp.MatchString(" << target << ","; + if (values[0]->is_field_reference()) { + out << "string("; + out << get_field_reference_name(values[0]->get_field_reference()); + out << ")"; + } else { + out << "\"" << values[0]->get_string() << "\""; + } + out << "); ok "; + } else if (key == "vt.prefix") { + out << indent() << "if !strings.HasPrefix(" << target << ","; + if (values[0]->is_field_reference()) { + out << "string("; + out << get_field_reference_name(values[0]->get_field_reference()); + out << ")"; + } else { + out << "\"" << values[0]->get_string() << "\""; + } + out << ")"; + } else if (key == "vt.suffix") { + out << indent() << "if !strings.HasSuffix(" << target << ","; + if (values[0]->is_field_reference()) { + out << "string("; + out << get_field_reference_name(values[0]->get_field_reference()); + out << ")"; + } else { + out << "\"" << values[0]->get_string() << "\""; + } + out << ")"; + } else if (key == "vt.contains") { + out << indent() << "if !strings.Contains(" << target << ","; + if (values[0]->is_field_reference()) { + out << "string("; + out << get_field_reference_name(values[0]->get_field_reference()); + out << ")"; + } else { + out << "\"" << values[0]->get_string() << "\""; + } + out << ")"; + } else if (key == "vt.not_contains") { + out << indent() << "if strings.Contains(" << target << ","; + if (values[0]->is_field_reference()) { + out << "string("; + out << get_field_reference_name(values[0]->get_field_reference()); + out << ")"; + } else { + out << "\"" << values[0]->get_string() << "\""; + } + out << ")"; + } + out << "{" << endl; + indent_up(); + out << indent() + << "return thrift.NewValidationException(thrift.VALIDATION_FAILED, \"" + key + "\", \"" + << context.field_symbol << "\", \"" << context.field_symbol << " not valid, rule " << key + << " check failed\")" << endl; + indent_down(); + out << indent() << "}" << endl; + } +} + +void go_validator_generator::generate_set_field_validator(std::ostream& out, + const generator_context& context) { + return generate_list_field_validator(out, context); +} + +void go_validator_generator::generate_list_field_validator(std::ostream& out, + const generator_context& context) { + for (auto it = context.rules.begin(); it != context.rules.end(); it++) { + const std::vector<validation_value*>& values = (*it)->get_values(); + std::string key = (*it)->get_name(); + if (key == "vt.min_size" || key == "vt.max_size") { + out << indent() << "if len(" << context.tgt << ")"; + if (key == "vt.min_size") { + out << " < "; + } else { + out << " > "; + } + if (values[0]->is_field_reference()) { + out << "int("; + out << get_field_reference_name(values[0]->get_field_reference()); + out << ")"; + } else { + out << values[0]->get_int(); + } + out << "{" << endl; + indent_up(); + out << indent() + << "return thrift.NewValidationException(thrift.VALIDATION_FAILED, \"" + key + "\", \"" + << context.field_symbol << "\", \"" << context.field_symbol << " not valid, rule " << key + << " check failed\")" << endl; + indent_down(); + out << indent() << "}" << endl; + } else if (key == "vt.elem") { + out << indent() << "for i := 0; i < len(" << context.tgt << ");i++ {" << endl; + indent_up(); + std::string src = GenID("_elem"); + out << indent() << src << " := " << context.tgt << "[i]" << endl; + t_type* elem_type; + if (context.type->is_list()) { + elem_type = ((t_list*)context.type)->get_elem_type(); + } else { + elem_type = ((t_set*)context.type)->get_elem_type(); + } + generator_context ctx{context.field_symbol + ".elem", + "", + src, + false, + elem_type, + std::vector<validation_rule*>{(*it)->get_inner()}}; + generate_field_validator(out, ctx); + indent_down(); + out << indent() << "}" << endl; + } + } +} + +void go_validator_generator::generate_map_field_validator(std::ostream& out, + const generator_context& context) { + for (auto it = context.rules.begin(); it != context.rules.end(); it++) { + const std::vector<validation_value*>& values = (*it)->get_values(); + std::string key = (*it)->get_name(); + if (key == "vt.min_size" || key == "vt.max_size") { + out << indent() << "if len(" << context.tgt << ")"; + if (key == "vt.min_size") { + out << " < "; + } else { + out << " > "; + } + if (values[0]->is_field_reference()) { + out << "int("; + out << get_field_reference_name(values[0]->get_field_reference()); + out << ")"; + } else { + out << values[0]->get_int(); + } + out << "{" << endl; + indent_up(); + out << indent() + << "return thrift.NewValidationException(thrift.VALIDATION_FAILED, \"" + key + "\", \"" + << context.field_symbol << "\", \"" << context.field_symbol << " not valid, rule " << key + << " check failed\")" << endl; + indent_down(); + out << indent() << "}" << endl; + } else if (key == "vt.key") { + std::string src = GenID("_key"); + out << indent() << "for " << src << " := range " << context.tgt << " {" << endl; + indent_up(); + generator_context ctx{context.field_symbol + ".key", + "", + src, + false, + ((t_map*)context.type)->get_key_type(), + std::vector<validation_rule*>{(*it)->get_inner()}}; + generate_field_validator(out, ctx); + indent_down(); + out << indent() << "}" << endl; + } else if (key == "vt.value") { + std::string src = GenID("_value"); + out << indent() << "for _, " << src << " := range " << context.tgt << " {" << endl; + indent_up(); + generator_context ctx{context.field_symbol + ".value", + "", + src, + false, + ((t_map*)context.type)->get_val_type(), + std::vector<validation_rule*>{(*it)->get_inner()}}; + generate_field_validator(out, ctx); + indent_down(); + out << indent() << "}" << endl; + } + } +} + +void go_validator_generator::generate_struct_field_validator(std::ostream& out, + const generator_context& context) { + bool generate_valid = true; + validation_rule* last_valid_rule = nullptr; + for (auto it = context.rules.begin(); it != context.rules.end(); it++) { + const std::vector<validation_value*>& values = (*it)->get_values(); + if (values.size() == 0) { + continue; + } + std::string key = (*it)->get_name(); + + if (key == "vt.skip") { + if (values[0]->is_field_reference() || !values[0]->get_bool()) { + generate_valid = true; + } else if (values[0]->get_bool()) { + generate_valid = false; + } + last_valid_rule = *it; + } + } + if (generate_valid) { + if (last_valid_rule == nullptr) { + out << indent() << "if err := " << context.tgt << ".Validate(); err != nil {" << endl; + indent_up(); + out << indent() << "return err" << endl; + indent_down(); + out << indent() << "}" << endl; + } else { + const std::vector<validation_value*>& values = last_valid_rule->get_values(); + if (!values[0]->get_bool()) { + out << indent() << "if err := " << context.tgt << ".Validate(); err != nil {" << endl; + indent_up(); + out << indent() << "return err" << endl; + indent_down(); + out << indent() << "}" << endl; + } else if (values[0]->is_field_reference()) { + out << indent() << "if !"; + out << get_field_reference_name(values[0]->get_field_reference()); + out << "{" << endl; + indent_up(); + out << indent() << "if err := " << context.tgt << ".Validate(); err != nil {" << endl; + indent_up(); + out << indent() << "return err" << endl; + indent_down(); + out << indent() << "}" << endl; + indent_down(); + out << indent() << "}" << endl; + } + } + } +} diff --git a/compiler/cpp/src/thrift/generate/go_validator_generator.h b/compiler/cpp/src/thrift/generate/go_validator_generator.h new file mode 100644 index 000000000..ca36347cb --- /dev/null +++ b/compiler/cpp/src/thrift/generate/go_validator_generator.h @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef T_GO_VALIDATOR_GENERATOR_H +#define T_GO_VALIDATOR_GENERATOR_H + +#include "thrift/generate/t_generator.h" +#include "thrift/generate/t_go_generator.h" +#include "thrift/generate/validator_parser.h" +#include <fstream> +#include <iostream> +#include <limits> +#include <string> +#include <vector> + +class go_validator_generator { +public: + go_validator_generator(t_go_generator* gg) : go_generator(gg){}; + void generate_struct_validator(std::ostream& out, t_struct* tstruct); + + struct generator_context { + std::string field_symbol; + std::string src; + std::string tgt; + bool opt; + t_type* type; + std::vector<validation_rule*> rules; + }; + +private: + void generate_field_validator(std::ostream& out, const generator_context& context); + void generate_enum_field_validator(std::ostream& out, const generator_context& context); + void generate_bool_field_validator(std::ostream& out, const generator_context& context); + void generate_integer_field_validator(std::ostream& out, const generator_context& context); + void generate_double_field_validator(std::ostream& out, const generator_context& context); + void generate_string_field_validator(std::ostream& out, const generator_context& context); + void generate_list_field_validator(std::ostream& out, const generator_context& context); + void generate_set_field_validator(std::ostream& out, const generator_context& context); + void generate_map_field_validator(std::ostream& out, const generator_context& context); + void generate_struct_field_validator(std::ostream& out, const generator_context& context); + + void indent_up() { go_generator->indent_up(); } + void indent_down() { go_generator->indent_down(); } + std::string indent() { return go_generator->indent(); } + + std::string get_field_name(t_field* field); + std::string get_field_reference_name(t_field* field); + + std::string GenID(std::string id) { return id + std::to_string(tmp_[id]++); }; + + t_go_generator* go_generator; + + std::map<std::string, int> tmp_; +}; + +#endif diff --git a/compiler/cpp/src/thrift/generate/t_cpp_generator.cc b/compiler/cpp/src/thrift/generate/t_cpp_generator.cc index 0420e62e7..dd444bc9c 100644 --- a/compiler/cpp/src/thrift/generate/t_cpp_generator.cc +++ b/compiler/cpp/src/thrift/generate/t_cpp_generator.cc @@ -4398,9 +4398,9 @@ string t_cpp_generator::namespace_close(string ns) { string t_cpp_generator::type_name(t_type* ttype, bool in_typedef, bool arg) { if (ttype->is_base_type()) { string bname = base_type_name(((t_base_type*)ttype)->get_base()); - std::map<string, string>::iterator it = ttype->annotations_.find("cpp.type"); - if (it != ttype->annotations_.end()) { - bname = it->second; + std::map<string, std::vector<string>>::iterator it = ttype->annotations_.find("cpp.type"); + if (it != ttype->annotations_.end() && !it->second.empty()) { + bname = it->second.back(); } if (!arg) { diff --git a/compiler/cpp/src/thrift/generate/t_delphi_generator.cc b/compiler/cpp/src/thrift/generate/t_delphi_generator.cc index f35ffcb99..625179f8f 100644 --- a/compiler/cpp/src/thrift/generate/t_delphi_generator.cc +++ b/compiler/cpp/src/thrift/generate/t_delphi_generator.cc @@ -3329,11 +3329,11 @@ string t_delphi_generator::function_signature(t_function* tfunction, // deprecated method? only at intf decl! if( full_cls == "") { auto iter = tfunction->annotations_.find("deprecated"); - if( tfunction->annotations_.end() != iter) { + if( tfunction->annotations_.end() != iter && !iter->second.empty()) { signature += " deprecated"; // empty annotation values end up with "1" somewhere, ignore these as well - if ((iter->second.length() > 0) && (iter->second != "1")) { - signature += " " + make_pascal_string_literal(iter->second); + if ((iter->second.back().length() > 0) && (iter->second.back() != "1")) { + signature += " " + make_pascal_string_literal(iter->second.back()); } signature += ";"; } diff --git a/compiler/cpp/src/thrift/generate/t_go_generator.cc b/compiler/cpp/src/thrift/generate/t_go_generator.cc index 90f34c8ca..e0ca489f5 100644 --- a/compiler/cpp/src/thrift/generate/t_go_generator.cc +++ b/compiler/cpp/src/thrift/generate/t_go_generator.cc @@ -41,6 +41,8 @@ #include "thrift/platform.h" #include "thrift/version.h" #include "thrift/generate/t_generator.h" +#include "thrift/generate/t_go_generator.h" +#include "thrift/generate/go_validator_generator.h" using std::map; using std::ostream; @@ -49,8 +51,6 @@ using std::string; using std::stringstream; using std::vector; -static const string endl = "\n"; // avoid ostream << std::endl flushes - /** * A helper for automatically formatting the emitted Go code from the Thrift * IDL per the Go style guide. @@ -62,277 +62,6 @@ static const string endl = "\n"; // avoid ostream << std::endl flushes */ bool format_go_output(const string& file_path); -const string DEFAULT_THRIFT_IMPORT = "github.com/apache/thrift/lib/go/thrift"; -static std::string package_flag; - -/** - * Go code generator. - */ -class t_go_generator : public t_generator { -public: - t_go_generator(t_program* program, - const std::map<std::string, std::string>& parsed_options, - const std::string& option_string) - : t_generator(program) { - (void)option_string; - std::map<std::string, std::string>::const_iterator iter; - - - gen_thrift_import_ = DEFAULT_THRIFT_IMPORT; - gen_package_prefix_ = ""; - package_flag = ""; - read_write_private_ = false; - ignore_initialisms_ = false; - skip_remote_ = false; - for( iter = parsed_options.begin(); iter != parsed_options.end(); ++iter) { - if( iter->first.compare("package_prefix") == 0) { - gen_package_prefix_ = (iter->second); - } else if( iter->first.compare("thrift_import") == 0) { - gen_thrift_import_ = (iter->second); - } else if( iter->first.compare("package") == 0) { - package_flag = (iter->second); - } else if( iter->first.compare("read_write_private") == 0) { - read_write_private_ = true; - } else if( iter->first.compare("ignore_initialisms") == 0) { - ignore_initialisms_ = true; - } else if( iter->first.compare("skip_remote") == 0) { - skip_remote_ = true; - } else { - throw "unknown option go:" + iter->first; - } - } - - out_dir_base_ = "gen-go"; - } - - /** - * Init and close methods - */ - - void init_generator() override; - void close_generator() override; - - /** - * Program-level generation functions - */ - - void generate_typedef(t_typedef* ttypedef) override; - void generate_enum(t_enum* tenum) override; - void generate_const(t_const* tconst) override; - void generate_struct(t_struct* tstruct) override; - void generate_xception(t_struct* txception) override; - void generate_service(t_service* tservice) override; - - std::string render_const_value(t_type* type, t_const_value* value, const string& name, bool opt = false); - - /** - * Struct generation code - */ - - void generate_go_struct(t_struct* tstruct, bool is_exception); - void generate_go_struct_definition(std::ostream& out, - t_struct* tstruct, - bool is_xception = false, - bool is_result = false, - bool is_args = false); - void generate_go_struct_initializer(std::ostream& out, - t_struct* tstruct, - bool is_args_or_result = false); - void generate_isset_helpers(std::ostream& out, - t_struct* tstruct, - const string& tstruct_name, - bool is_result = false); - void generate_countsetfields_helper(std::ostream& out, - t_struct* tstruct, - const string& tstruct_name, - bool is_result = false); - void generate_go_struct_reader(std::ostream& out, - t_struct* tstruct, - const string& tstruct_name, - bool is_result = false); - void generate_go_struct_writer(std::ostream& out, - t_struct* tstruct, - const string& tstruct_name, - bool is_result = false, - bool uses_countsetfields = false); - void generate_go_struct_equals(std::ostream& out, t_struct* tstruct, const string& tstruct_name); - void generate_go_function_helpers(t_function* tfunction); - void get_publicized_name_and_def_value(t_field* tfield, - string* OUT_pub_name, - t_const_value** OUT_def_value) const; - - /** - * Service-level generation functions - */ - - void generate_service_helpers(t_service* tservice); - void generate_service_interface(t_service* tservice); - void generate_service_client(t_service* tservice); - void generate_service_remote(t_service* tservice); - void generate_service_server(t_service* tservice); - void generate_process_function(t_service* tservice, t_function* tfunction); - - /** - * Serialization constructs - */ - - void generate_deserialize_field(std::ostream& out, - t_field* tfield, - bool declare, - std::string prefix = "", - bool inclass = false, - bool coerceData = false, - bool inkey = false, - bool in_container = false); - - void generate_deserialize_struct(std::ostream& out, - t_struct* tstruct, - bool is_pointer_field, - bool declare, - std::string prefix = ""); - - void generate_deserialize_container(std::ostream& out, - t_type* ttype, - bool pointer_field, - bool declare, - std::string prefix = ""); - - void generate_deserialize_set_element(std::ostream& out, - t_set* tset, - bool declare, - std::string prefix = ""); - - void generate_deserialize_map_element(std::ostream& out, - t_map* tmap, - bool declare, - std::string prefix = ""); - - void generate_deserialize_list_element(std::ostream& out, - t_list* tlist, - bool declare, - std::string prefix = ""); - - void generate_serialize_field(std::ostream& out, - t_field* tfield, - std::string prefix = "", - bool inkey = false); - - void generate_serialize_struct(std::ostream& out, t_struct* tstruct, std::string prefix = ""); - - void generate_serialize_container(std::ostream& out, - t_type* ttype, - bool pointer_field, - std::string prefix = ""); - - void generate_serialize_map_element(std::ostream& out, - t_map* tmap, - std::string kiter, - std::string viter); - - void generate_serialize_set_element(std::ostream& out, t_set* tmap, std::string iter); - - void generate_serialize_list_element(std::ostream& out, t_list* tlist, std::string iter); - - void generate_go_equals(std::ostream& out, t_type* ttype, string tgt, string src); - - void generate_go_equals_struct(std::ostream& out, t_type* ttype, string tgt, string src); - - void generate_go_equals_container(std::ostream& out, t_type* ttype, string tgt, string src); - - void generate_go_docstring(std::ostream& out, t_struct* tstruct); - - void generate_go_docstring(std::ostream& out, t_function* tfunction); - - void generate_go_docstring(std::ostream& out, - t_doc* tdoc, - t_struct* tstruct, - const char* subheader); - - void generate_go_docstring(std::ostream& out, t_doc* tdoc); - - void parse_go_tags(map<string,string>* tags, const string in); - - /** - * Helper rendering functions - */ - - std::string go_autogen_comment(); - std::string go_package(); - std::string go_imports_begin(bool consts); - std::string go_imports_end(); - std::string render_includes(bool consts); - std::string render_included_programs(string& unused_protection); - std::string render_program_import(const t_program* program, string& unused_protection); - std::string render_system_packages(std::vector<string> &system_packages); - std::string render_import_protection(); - std::string render_fastbinary_includes(); - std::string declare_argument(t_field* tfield); - std::string render_field_initial_value(t_field* tfield, const string& name, bool optional_field); - std::string type_name(t_type* ttype); - std::string module_name(t_type* ttype); - std::string function_signature(t_function* tfunction, std::string prefix = ""); - std::string function_signature_if(t_function* tfunction, - std::string prefix = "", - bool addError = false); - std::string argument_list(t_struct* tstruct); - std::string type_to_enum(t_type* ttype); - std::string type_to_go_type(t_type* ttype); - std::string type_to_go_type_with_opt(t_type* ttype, - bool optional_field); - std::string type_to_go_key_type(t_type* ttype); - std::string type_to_spec_args(t_type* ttype); - - static std::string get_real_go_module(const t_program* program) { - - if (!package_flag.empty()) { - return package_flag; - } - std::string real_module = program->get_namespace("go"); - if (!real_module.empty()) { - return real_module; - } - - return lowercase(program->get_name()); - } - -private: - std::string gen_package_prefix_; - std::string gen_thrift_import_; - bool read_write_private_; - bool ignore_initialisms_; - bool skip_remote_; - - /** - * File streams - */ - - ofstream_with_content_based_conditional_update f_types_; - std::string f_types_name_; - ofstream_with_content_based_conditional_update f_consts_; - std::string f_consts_name_; - std::stringstream f_const_values_; - - std::string package_name_; - std::string package_dir_; - std::unordered_map<std::string, std::string> package_identifiers_; - std::set<std::string> package_identifiers_set_; - std::string read_method_name_; - std::string write_method_name_; - std::string equals_method_name_; - - std::set<std::string> commonInitialisms; - - std::string camelcase(const std::string& value) const; - void fix_common_initialism(std::string& value, int i) const; - std::string publicize(const std::string& value, bool is_args_or_result = false) const; - std::string publicize(const std::string& value, bool is_args_or_result, const std::string& service_name) const; - std::string privatize(const std::string& value) const; - std::string new_prefix(const std::string& value) const; - static std::string variable_name_to_go_name(const std::string& value); - static bool is_pointer_field(t_field* tfield, bool in_container = false); - static bool omit_initialization(t_field* tfield); -}; - // returns true if field initialization can be omitted since it has corresponding go type zero value // or default value is not set bool t_go_generator::omit_initialization(t_field* tfield) { @@ -973,6 +702,10 @@ string t_go_generator::go_imports_begin(bool consts) { system_packages.push_back("time"); // For the thrift import, always do rename import to make sure it's called thrift. system_packages.push_back("thrift \"" + gen_thrift_import_ + "\""); + + // validator import + system_packages.push_back("strings"); + system_packages.push_back("regexp"); return "import (\n" + render_system_packages(system_packages); } @@ -983,7 +716,7 @@ string t_go_generator::go_imports_begin(bool consts) { * This will have to do in lieu of more intelligent import statement construction */ string t_go_generator::go_imports_end() { - return string( + string import_end = string( ")\n\n" "// (needed to ensure safety because of naive import list construction.)\n" "var _ = thrift.ZERO\n" @@ -991,7 +724,11 @@ string t_go_generator::go_imports_end() { "var _ = errors.New\n" "var _ = context.Background\n" "var _ = time.Now\n" - "var _ = bytes.Equal\n\n"); + "var _ = bytes.Equal\n" + "// (needed by validator.)\n" + "var _ = strings.Contains\n" + "var _ = regexp.MatchString\n\n"); + return import_end; } /** @@ -1384,6 +1121,14 @@ void t_go_generator::generate_xception(t_struct* txception) { */ void t_go_generator::generate_go_struct(t_struct* tstruct, bool is_exception) { generate_go_struct_definition(f_types_, tstruct, is_exception); + // generate Validate function + std::string tstruct_name(publicize(tstruct->get_name(), false)); + f_types_ << "func (p *" << tstruct_name << ") Validate() error {" << endl; + indent_up(); + go_validator_generator(this).generate_struct_validator(f_types_, tstruct); + f_types_ << indent() << "return nil" << endl; + indent_down(); + f_types_ << "}" << endl; } void t_go_generator::get_publicized_name_and_def_value(t_field* tfield, @@ -1498,9 +1243,9 @@ void t_go_generator::generate_go_struct_definition(ostream& out, // Check for user defined tags and them if there are any. User defined tags // can override the above db and json tags. - std::map<string, string>::iterator it = (*m_iter)->annotations_.find("go.tag"); + std::map<string, std::vector<string>>::iterator it = (*m_iter)->annotations_.find("go.tag"); if (it != (*m_iter)->annotations_.end()) { - parse_go_tags(&tags, it->second); + parse_go_tags(&tags, it->second.back()); } string gotag; diff --git a/compiler/cpp/src/thrift/generate/t_go_generator.h b/compiler/cpp/src/thrift/generate/t_go_generator.h new file mode 100644 index 000000000..5080e1a0d --- /dev/null +++ b/compiler/cpp/src/thrift/generate/t_go_generator.h @@ -0,0 +1,329 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef T_GO_GENERATOR_H +#define T_GO_GENERATOR_H + +#include <fstream> +#include <iostream> +#include <limits> +#include <string> +#include <unordered_map> +#include <vector> + +#include "thrift/generate/t_generator.h" +#include "thrift/platform.h" +#include "thrift/version.h" +#include <algorithm> +#include <clocale> +#include <sstream> +#include <stdlib.h> +#include <sys/stat.h> +#include <sys/types.h> + +using std::map; +using std::ostream; +using std::ostringstream; +using std::string; +using std::stringstream; +using std::vector; + +static const string endl = "\n"; // avoid ostream << std::endl flushes + +const string DEFAULT_THRIFT_IMPORT = "github.com/apache/thrift/lib/go/thrift"; +static std::string package_flag; + +/** + * Go code generator. + */ +class t_go_generator : public t_generator { +public: + t_go_generator(t_program* program, + const std::map<std::string, std::string>& parsed_options, + const std::string& option_string) + : t_generator(program) { + (void)option_string; + std::map<std::string, std::string>::const_iterator iter; + + gen_thrift_import_ = DEFAULT_THRIFT_IMPORT; + gen_package_prefix_ = ""; + package_flag = ""; + read_write_private_ = false; + ignore_initialisms_ = false; + skip_remote_ = false; + for (iter = parsed_options.begin(); iter != parsed_options.end(); ++iter) { + if (iter->first.compare("package_prefix") == 0) { + gen_package_prefix_ = (iter->second); + } else if (iter->first.compare("thrift_import") == 0) { + gen_thrift_import_ = (iter->second); + } else if (iter->first.compare("package") == 0) { + package_flag = (iter->second); + } else if (iter->first.compare("read_write_private") == 0) { + read_write_private_ = true; + } else if (iter->first.compare("ignore_initialisms") == 0) { + ignore_initialisms_ = true; + } else if( iter->first.compare("skip_remote") == 0) { + skip_remote_ = true; + } else { + throw "unknown option go:" + iter->first; + } + } + + out_dir_base_ = "gen-go"; + } + + /** + * Init and close methods + */ + + void init_generator() override; + void close_generator() override; + + /** + * Program-level generation functions + */ + + void generate_typedef(t_typedef* ttypedef) override; + void generate_enum(t_enum* tenum) override; + void generate_const(t_const* tconst) override; + void generate_struct(t_struct* tstruct) override; + void generate_xception(t_struct* txception) override; + void generate_service(t_service* tservice) override; + + std::string render_const_value(t_type* type, + t_const_value* value, + const string& name, + bool opt = false); + + /** + * Struct generation code + */ + + void generate_go_struct(t_struct* tstruct, bool is_exception); + void generate_go_struct_definition(std::ostream& out, + t_struct* tstruct, + bool is_xception = false, + bool is_result = false, + bool is_args = false); + void generate_go_struct_initializer(std::ostream& out, + t_struct* tstruct, + bool is_args_or_result = false); + void generate_isset_helpers(std::ostream& out, + t_struct* tstruct, + const string& tstruct_name, + bool is_result = false); + void generate_countsetfields_helper(std::ostream& out, + t_struct* tstruct, + const string& tstruct_name, + bool is_result = false); + void generate_go_struct_reader(std::ostream& out, + t_struct* tstruct, + const string& tstruct_name, + bool is_result = false); + void generate_go_struct_writer(std::ostream& out, + t_struct* tstruct, + const string& tstruct_name, + bool is_result = false, + bool uses_countsetfields = false); + void generate_go_struct_equals(std::ostream& out, t_struct* tstruct, const string& tstruct_name); + void generate_go_function_helpers(t_function* tfunction); + void get_publicized_name_and_def_value(t_field* tfield, + string* OUT_pub_name, + t_const_value** OUT_def_value) const; + + /** + * Service-level generation functions + */ + + void generate_service_helpers(t_service* tservice); + void generate_service_interface(t_service* tservice); + void generate_service_client(t_service* tservice); + void generate_service_remote(t_service* tservice); + void generate_service_server(t_service* tservice); + void generate_process_function(t_service* tservice, t_function* tfunction); + + /** + * Serialization constructs + */ + + void generate_deserialize_field(std::ostream& out, + t_field* tfield, + bool declare, + std::string prefix = "", + bool inclass = false, + bool coerceData = false, + bool inkey = false, + bool in_container = false); + + void generate_deserialize_struct(std::ostream& out, + t_struct* tstruct, + bool is_pointer_field, + bool declare, + std::string prefix = ""); + + void generate_deserialize_container(std::ostream& out, + t_type* ttype, + bool pointer_field, + bool declare, + std::string prefix = ""); + + void generate_deserialize_set_element(std::ostream& out, + t_set* tset, + bool declare, + std::string prefix = ""); + + void generate_deserialize_map_element(std::ostream& out, + t_map* tmap, + bool declare, + std::string prefix = ""); + + void generate_deserialize_list_element(std::ostream& out, + t_list* tlist, + bool declare, + std::string prefix = ""); + + void generate_serialize_field(std::ostream& out, + t_field* tfield, + std::string prefix = "", + bool inkey = false); + + void generate_serialize_struct(std::ostream& out, t_struct* tstruct, std::string prefix = ""); + + void generate_serialize_container(std::ostream& out, + t_type* ttype, + bool pointer_field, + std::string prefix = ""); + + void generate_serialize_map_element(std::ostream& out, + t_map* tmap, + std::string kiter, + std::string viter); + + void generate_serialize_set_element(std::ostream& out, t_set* tmap, std::string iter); + + void generate_serialize_list_element(std::ostream& out, t_list* tlist, std::string iter); + + void generate_go_equals(std::ostream& out, t_type* ttype, string tgt, string src); + + void generate_go_equals_struct(std::ostream& out, t_type* ttype, string tgt, string src); + + void generate_go_equals_container(std::ostream& out, t_type* ttype, string tgt, string src); + + void generate_go_docstring(std::ostream& out, t_struct* tstruct); + + void generate_go_docstring(std::ostream& out, t_function* tfunction); + + void generate_go_docstring(std::ostream& out, + t_doc* tdoc, + t_struct* tstruct, + const char* subheader); + + void generate_go_docstring(std::ostream& out, t_doc* tdoc); + + void parse_go_tags(map<string, string>* tags, const string in); + + /** + * Helper rendering functions + */ + + std::string go_autogen_comment(); + std::string go_package(); + std::string go_imports_begin(bool consts); + std::string go_imports_end(); + std::string render_includes(bool consts); + std::string render_included_programs(string& unused_protection); + std::string render_program_import(const t_program* program, string& unused_protection); + std::string render_system_packages(std::vector<string>& system_packages); + std::string render_import_protection(); + std::string render_fastbinary_includes(); + std::string declare_argument(t_field* tfield); + std::string render_field_initial_value(t_field* tfield, const string& name, bool optional_field); + std::string type_name(t_type* ttype); + std::string module_name(t_type* ttype); + std::string function_signature(t_function* tfunction, std::string prefix = ""); + std::string function_signature_if(t_function* tfunction, + std::string prefix = "", + bool addError = false); + std::string argument_list(t_struct* tstruct); + std::string type_to_enum(t_type* ttype); + std::string type_to_go_type(t_type* ttype); + std::string type_to_go_type_with_opt(t_type* ttype, bool optional_field); + std::string type_to_go_key_type(t_type* ttype); + std::string type_to_spec_args(t_type* ttype); + + void indent_up() { t_generator::indent_up(); } + void indent_down() { t_generator::indent_down(); } + std::string indent() { return t_generator::indent(); } + std::ostream& indent(std::ostream& os) { return t_generator::indent(os); } + + static std::string get_real_go_module(const t_program* program) { + + if (!package_flag.empty()) { + return package_flag; + } + std::string real_module = program->get_namespace("go"); + if (!real_module.empty()) { + return real_module; + } + + return lowercase(program->get_name()); + } + + static bool is_pointer_field(t_field* tfield, bool in_container = false); + +private: + std::string gen_package_prefix_; + std::string gen_thrift_import_; + bool read_write_private_; + bool ignore_initialisms_; + bool skip_remote_; + + /** + * File streams + */ + + ofstream_with_content_based_conditional_update f_types_; + std::string f_types_name_; + ofstream_with_content_based_conditional_update f_consts_; + std::string f_consts_name_; + std::stringstream f_const_values_; + + std::string package_name_; + std::string package_dir_; + std::unordered_map<std::string, std::string> package_identifiers_; + std::set<std::string> package_identifiers_set_; + std::string read_method_name_; + std::string write_method_name_; + std::string equals_method_name_; + + std::set<std::string> commonInitialisms; + + std::string camelcase(const std::string& value) const; + void fix_common_initialism(std::string& value, int i) const; + std::string publicize(const std::string& value, bool is_args_or_result = false) const; + std::string publicize(const std::string& value, + bool is_args_or_result, + const std::string& service_name) const; + std::string privatize(const std::string& value) const; + std::string new_prefix(const std::string& value) const; + static std::string variable_name_to_go_name(const std::string& value); + static bool omit_initialization(t_field* tfield); +}; + +#endif diff --git a/compiler/cpp/src/thrift/generate/t_java_generator.cc b/compiler/cpp/src/thrift/generate/t_java_generator.cc index 3ef9028d3..a6041d7c4 100644 --- a/compiler/cpp/src/thrift/generate/t_java_generator.cc +++ b/compiler/cpp/src/thrift/generate/t_java_generator.cc @@ -374,6 +374,10 @@ public: || ttype->is_uuid() || ttype->is_enum(); } + bool is_deprecated(const std::map<std::string, std::vector<std::string>>& annotations) { + return annotations.find("deprecated") != annotations.end(); + } + bool is_deprecated(const std::map<std::string, std::string>& annotations) { return annotations.find("deprecated") != annotations.end(); } @@ -2966,7 +2970,7 @@ void t_java_generator::generate_metadata_for_field_annotations(std::ostream& out indent_up(); for (auto& annotation : field->annotations_) { indent(out) << ".add(new java.util.AbstractMap.SimpleImmutableEntry<>(\"" + annotation.first - + "\", \"" + annotation.second + "\"))" + + "\", \"" + annotation.second.back() + "\"))" << endl; } indent(out) << ".build().collect(java.util.stream.Collectors.toMap(java.util.Map.Entry::getKey, " diff --git a/compiler/cpp/src/thrift/generate/t_json_generator.cc b/compiler/cpp/src/thrift/generate/t_json_generator.cc index e16bc5098..5a854cea6 100644 --- a/compiler/cpp/src/thrift/generate/t_json_generator.cc +++ b/compiler/cpp/src/thrift/generate/t_json_generator.cc @@ -268,7 +268,9 @@ void t_json_generator::write_type_spec(t_type* ttype) { write_key_and("annotations"); start_object(); for (auto & annotation : ttype->annotations_) { - write_key_and_string(annotation.first, annotation.second); + for (auto& annotation_value : annotation.second) { + write_key_and_string(annotation.first, annotation_value); + } } end_object(); } @@ -459,7 +461,9 @@ void t_json_generator::generate_typedef(t_typedef* ttypedef) { write_key_and("annotations"); start_object(); for (auto & annotation : ttypedef->annotations_) { - write_key_and_string(annotation.first, annotation.second); + for (auto& annotation_value : annotation.second) { + write_key_and_string(annotation.first, annotation_value); + } } end_object(); } @@ -566,7 +570,9 @@ void t_json_generator::generate_enum(t_enum* tenum) { write_key_and("annotations"); start_object(); for (auto & annotation : tenum->annotations_) { - write_key_and_string(annotation.first, annotation.second); + for (auto& annotation_value : annotation.second) { + write_key_and_string(annotation.first, annotation_value); + } } end_object(); } @@ -605,7 +611,9 @@ void t_json_generator::generate_struct(t_struct* tstruct) { write_key_and("annotations"); start_object(); for (auto & annotation : tstruct->annotations_) { - write_key_and_string(annotation.first, annotation.second); + for (auto& annotation_value : annotation.second) { + write_key_and_string(annotation.first, annotation_value); + } } end_object(); } @@ -645,7 +653,9 @@ void t_json_generator::generate_service(t_service* tservice) { write_key_and("annotations"); start_object(); for (auto & annotation : tservice->annotations_) { - write_key_and_string(annotation.first, annotation.second); + for (auto& annotation_value : annotation.second) { + write_key_and_string(annotation.first, annotation_value); + } } end_object(); } @@ -682,7 +692,9 @@ void t_json_generator::generate_function(t_function* tfunc) { write_key_and("annotations"); start_object(); for (auto & annotation : tfunc->annotations_) { - write_key_and_string(annotation.first, annotation.second); + for (auto& annotation_value : annotation.second) { + write_key_and_string(annotation.first, annotation_value); + } } end_object(); } @@ -728,7 +740,9 @@ void t_json_generator::generate_field(t_field* field) { write_key_and("annotations"); start_object(); for (auto & annotation : field->annotations_) { - write_key_and_string(annotation.first, annotation.second); + for (auto& annotation_value : annotation.second) { + write_key_and_string(annotation.first, annotation_value); + } } end_object(); } diff --git a/compiler/cpp/src/thrift/generate/t_kotlin_generator.cc b/compiler/cpp/src/thrift/generate/t_kotlin_generator.cc index 1ac9c34fb..29cf00a74 100644 --- a/compiler/cpp/src/thrift/generate/t_kotlin_generator.cc +++ b/compiler/cpp/src/thrift/generate/t_kotlin_generator.cc @@ -583,7 +583,7 @@ void t_kotlin_generator::generate_metadata_for_field_annotations(std::ostream& o out << "mapOf(" << endl; indent_up(); for (auto& annotation : field->annotations_) { - indent(out) << "\"" + annotation.first + "\" to \"" + annotation.second + "\"," << endl; + indent(out) << "\"" + annotation.first + "\" to \"" + annotation.second.back() + "\"," << endl; } indent_down(); indent(out) << ")"; diff --git a/compiler/cpp/src/thrift/generate/t_netstd_generator.cc b/compiler/cpp/src/thrift/generate/t_netstd_generator.cc index 547457130..4cf3db55b 100644 --- a/compiler/cpp/src/thrift/generate/t_netstd_generator.cc +++ b/compiler/cpp/src/thrift/generate/t_netstd_generator.cc @@ -2059,8 +2059,8 @@ void t_netstd_generator::generate_deprecation_attribute(ostream& out, t_function if( func->annotations_.end() != iter) { out << indent() << "[Obsolete"; // empty annotation values end up with "1" somewhere, ignore these as well - if ((iter->second.length() > 0) && (iter->second != "1")) { - out << "(" << make_csharp_string_literal(iter->second) << ")"; + if ((iter->second.back().length() > 0) && (iter->second.back() != "1")) { + out << "(" << make_csharp_string_literal(iter->second.back()) << ")"; } out << "]" << endl; } diff --git a/compiler/cpp/src/thrift/generate/t_py_generator.cc b/compiler/cpp/src/thrift/generate/t_py_generator.cc index 33437fd00..8ab8b9881 100644 --- a/compiler/cpp/src/thrift/generate/t_py_generator.cc +++ b/compiler/cpp/src/thrift/generate/t_py_generator.cc @@ -280,12 +280,12 @@ public: } static bool is_immutable(t_type* ttype) { - std::map<std::string, std::string>::iterator it = ttype->annotations_.find("python.immutable"); + std::map<std::string, std::vector<std::string>>::iterator it = ttype->annotations_.find("python.immutable"); if (it == ttype->annotations_.end()) { // Exceptions are immutable by default. return ttype->is_xception(); - } else if (it->second == "false") { + } else if (!it->second.empty() && it->second.back() == "false") { return false; } else { return true; diff --git a/compiler/cpp/src/thrift/generate/t_xml_generator.cc b/compiler/cpp/src/thrift/generate/t_xml_generator.cc index b6692938f..220d50c68 100644 --- a/compiler/cpp/src/thrift/generate/t_xml_generator.cc +++ b/compiler/cpp/src/thrift/generate/t_xml_generator.cc @@ -94,7 +94,7 @@ public: void generate_service(t_service* tservice) override; void generate_struct(t_struct* tstruct) override; - void generate_annotations(std::map<std::string, std::string> annotations); + void generate_annotations(std::map<std::string, std::vector<std::string>> annotations); private: bool should_merge_includes_; @@ -165,17 +165,23 @@ void t_xml_generator::init_generator() { string t_xml_generator::target_namespace(t_program* program) { std::map<std::string, std::string> map; std::map<std::string, std::string>::iterator iter; - map = program->get_namespace_annotations("xml"); - if ((iter = map.find("targetNamespace")) != map.end()) { - return iter->second; + std::map<std::string, std::vector<std::string>> annotations; + std::map<std::string, std::vector<std::string>>::iterator annotations_iter; + annotations = program->get_namespace_annotations("xml"); + if ((annotations_iter = annotations.find("targetNamespace")) != annotations.end()) { + if (!annotations_iter->second.empty()) { + return annotations_iter->second.back(); + } } map = program->get_namespaces(); if ((iter = map.find("xml")) != map.end()) { return default_ns_prefix + iter->second; } - map = program->get_namespace_annotations("*"); - if ((iter = map.find("xml.targetNamespace")) != map.end()) { - return iter->second; + annotations = program->get_namespace_annotations("*"); + if ((annotations_iter = annotations.find("xml.targetNamespace")) != annotations.end()) { + if (!annotations_iter->second.empty()) { + return annotations_iter->second.back(); + } } map = program->get_namespaces(); if ((iter = map.find("*")) != map.end()) { @@ -432,13 +438,15 @@ void t_xml_generator::write_doc(t_doc* tdoc) { } void t_xml_generator::generate_annotations( - std::map<std::string, std::string> annotations) { - std::map<std::string, std::string>::iterator iter; + std::map<std::string, std::vector<std::string>> annotations) { + std::map<std::string, std::vector<std::string>>::iterator iter; for (iter = annotations.begin(); iter != annotations.end(); ++iter) { - write_element_start("annotation"); - write_attribute("key", iter->first); - write_attribute("value", iter->second); - write_element_end(); + for (auto& annotations_value : iter->second) { + write_element_start("annotation"); + write_attribute("key", iter->first); + write_attribute("value", annotations_value); + write_element_end(); + } } } diff --git a/compiler/cpp/src/thrift/generate/t_xsd_generator.cc b/compiler/cpp/src/thrift/generate/t_xsd_generator.cc index a10f05959..15ede750f 100644 --- a/compiler/cpp/src/thrift/generate/t_xsd_generator.cc +++ b/compiler/cpp/src/thrift/generate/t_xsd_generator.cc @@ -262,10 +262,10 @@ void t_xsd_generator::generate_service(t_service* tservice) { f_xsd_.open(f_xsd_name.c_str()); string ns = program_->get_namespace("xsd"); - const std::map<std::string, std::string> annot = program_->get_namespace_annotations("xsd"); - const std::map<std::string, std::string>::const_iterator uri = annot.find("uri"); - if (uri != annot.end()) { - ns = uri->second; + const std::map<std::string, std::vector<std::string>> annot = program_->get_namespace_annotations("xsd"); + const std::map<std::string, std::vector<std::string>>::const_iterator uri = annot.find("uri"); + if (uri != annot.end() && !uri->second.empty()) { + ns = uri->second.back(); } if (ns.size() > 0) { ns = " targetNamespace=\"" + ns + "\" xmlns=\"" + ns + "\" " diff --git a/compiler/cpp/src/thrift/generate/validator_parser.cc b/compiler/cpp/src/thrift/generate/validator_parser.cc new file mode 100644 index 000000000..84261fe6f --- /dev/null +++ b/compiler/cpp/src/thrift/generate/validator_parser.cc @@ -0,0 +1,550 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/* + * This file is programmatically sanitized for style: + * astyle --style=1tbs -f -p -H -j -U t_validator_parser.cc + * + * The output of astyle should not be taken unquestioningly, but it is a good + * guide for ensuring uniformity and readability. + */ + +#include <fstream> +#include <iostream> +#include <limits> +#include <string> +#include <unordered_map> +#include <vector> + +#include "thrift/generate/t_generator.h" +#include "thrift/generate/validator_parser.h" +#include "thrift/platform.h" +#include "thrift/version.h" +#include <algorithm> +#include <clocale> +#include <sstream> +#include <stdlib.h> +#include <sys/stat.h> +#include <sys/types.h> + +const char* list_delimiter = "[], "; + +std::vector<validation_rule*> validation_parser::parse_field( + t_type* type, + std::map<std::string, std::vector<std::string>>& annotations) { + if (type->is_typedef()) { + type = type->get_true_type(); + } + if (type->is_enum()) { + return parse_enum_field(type, annotations); + } else if (type->is_base_type()) { + t_base_type::t_base tbase = ((t_base_type*)type)->get_base(); + switch (tbase) { + case t_base_type::TYPE_UUID: + case t_base_type::TYPE_VOID: + break; + case t_base_type::TYPE_I8: + case t_base_type::TYPE_I16: + case t_base_type::TYPE_I32: + case t_base_type::TYPE_I64: + return parse_integer_field(type, annotations); + case t_base_type::TYPE_DOUBLE: + return parse_double_field(type, annotations); + case t_base_type::TYPE_STRING: + return parse_string_field(type, annotations); + case t_base_type::TYPE_BOOL: + return parse_bool_field(type, annotations); + } + } else if (type->is_list()) { + return parse_list_field(type, annotations); + } else if (type->is_set()) { + return parse_set_field(type, annotations); + } else if (type->is_map()) { + return parse_map_field(type, annotations); + } else if (type->is_struct()) { + if (((t_struct*)type)->is_union()) { + return parse_union_field(type, annotations); + } + return parse_struct_field(type, annotations); + } else if (type->is_xception()) { + return parse_xception_field(type, annotations); + } + throw "validator error: unsupported type: " + type->get_name(); +} + +std::vector<validation_rule*> validation_parser::parse_bool_field( + t_type* type, + std::map<std::string, std::vector<std::string>>& annotations) { + (void)type; + std::vector<validation_rule*> rules; + add_bool_rule(rules, "vt.const", annotations); + return rules; +} + +std::vector<validation_rule*> validation_parser::parse_enum_field( + t_type* type, + std::map<std::string, std::vector<std::string>>& annotations) { + std::vector<validation_rule*> rules; + add_bool_rule(rules, "vt.defined_only", annotations); + add_enum_list_rule(rules, (t_enum*)type, "vt.in", annotations); + add_enum_list_rule(rules, (t_enum*)type, "vt.not_in", annotations); + return rules; +} + +std::vector<validation_rule*> validation_parser::parse_double_field( + t_type* type, + std::map<std::string, std::vector<std::string>>& annotations) { + (void)type; + std::vector<validation_rule*> rules; + add_double_rule(rules, "vt.lt", annotations); + add_double_rule(rules, "vt.le", annotations); + add_double_rule(rules, "vt.gt", annotations); + add_double_rule(rules, "vt.ge", annotations); + add_double_list_rule(rules, "vt.in", annotations); + add_double_list_rule(rules, "vt.not_in", annotations); + return rules; +} + +std::vector<validation_rule*> validation_parser::parse_integer_field( + t_type* type, + std::map<std::string, std::vector<std::string>>& annotations) { + (void)type; + std::vector<validation_rule*> rules; + add_integer_rule(rules, "vt.lt", annotations); + add_integer_rule(rules, "vt.le", annotations); + add_integer_rule(rules, "vt.gt", annotations); + add_integer_rule(rules, "vt.ge", annotations); + add_integer_list_rule(rules, "vt.in", annotations); + add_integer_list_rule(rules, "vt.not_in", annotations); + return rules; +} + +std::vector<validation_rule*> validation_parser::parse_string_field( + t_type* type, + std::map<std::string, std::vector<std::string>>& annotations) { + (void)type; + std::vector<validation_rule*> rules; + add_string_rule(rules, "vt.const", annotations); + add_integer_rule(rules, "vt.min_size", annotations); + add_integer_rule(rules, "vt.max_size", annotations); + add_string_rule(rules, "vt.pattern", annotations); + add_string_rule(rules, "vt.prefix", annotations); + add_string_rule(rules, "vt.suffix", annotations); + add_string_rule(rules, "vt.contains", annotations); + add_string_rule(rules, "vt.not_contains", annotations); + return rules; +} + +std::vector<validation_rule*> validation_parser::parse_set_field( + t_type* type, + std::map<std::string, std::vector<std::string>>& annotations) { + (void)type; + return parse_list_field(type, annotations); +} + +std::vector<validation_rule*> validation_parser::parse_list_field( + t_type* type, + std::map<std::string, std::vector<std::string>>& annotations) { + (void)type; + std::vector<validation_rule*> rules; + add_integer_rule(rules, "vt.min_size", annotations); + add_integer_rule(rules, "vt.max_size", annotations); + std::string elem_prefix("vt.elem"); + std::map<std::string, std::vector<std::string>> elem_annotations; + for (auto it = annotations.begin(); it != annotations.end(); it++) { + if (it->first.compare(0, elem_prefix.size(), elem_prefix) == 0) { + std::string elem_key = "vt" + it->first.substr(elem_prefix.size()); + elem_annotations[elem_key] = it->second; + } + } + std::vector<validation_rule*> elem_rules; + if (type->is_list()) { + elem_rules = parse_field(((t_list*)type)->get_elem_type(), elem_annotations); + } else if (type->is_set()) { + elem_rules = parse_field(((t_set*)type)->get_elem_type(), elem_annotations); + } + for (auto it = elem_rules.begin(); it != elem_rules.end(); it++) { + validation_rule* rule = new validation_rule(elem_prefix, *it); + rules.push_back(rule); + } + return rules; +} + +std::vector<validation_rule*> validation_parser::parse_map_field( + t_type* type, + std::map<std::string, std::vector<std::string>>& annotations) { + std::vector<validation_rule*> rules; + add_integer_rule(rules, "vt.min_size", annotations); + add_integer_rule(rules, "vt.max_size", annotations); + std::string key_prefix("vt.key"); + std::map<std::string, std::vector<std::string>> key_annotations; + for (auto it = annotations.begin(); it != annotations.end(); it++) { + if (it->first.compare(0, key_prefix.size(), key_prefix) == 0) { + std::string key_key = "vt" + it->first.substr(key_prefix.size()); + key_annotations[key_key] = it->second; + } + } + std::vector<validation_rule*> key_rules; + key_rules = parse_field(((t_map*)type)->get_key_type(), key_annotations); + for (auto it = key_rules.begin(); it != key_rules.end(); it++) { + validation_rule* rule = new validation_rule(key_prefix, *it); + rules.push_back(rule); + } + + std::string value_prefix("vt.value"); + std::map<std::string, std::vector<std::string>> value_annotations; + for (auto it = annotations.begin(); it != annotations.end(); it++) { + if (it->first.compare(0, value_prefix.size(), value_prefix) == 0) { + std::string value_key = "vt" + it->first.substr(value_prefix.size()); + value_annotations[value_key] = it->second; + } + } + std::vector<validation_rule*> value_rules; + value_rules = parse_field(((t_map*)type)->get_val_type(), value_annotations); + for (auto it = value_rules.begin(); it != value_rules.end(); it++) { + validation_rule* rule = new validation_rule(value_prefix, *it); + rules.push_back(rule); + } + return rules; +} + +std::vector<validation_rule*> validation_parser::parse_struct_field( + t_type* type, + std::map<std::string, std::vector<std::string>>& annotations) { + (void)type; + std::vector<validation_rule*> rules; + add_bool_rule(rules, "vt.skip", annotations); + return rules; +} + +std::vector<validation_rule*> validation_parser::parse_xception_field( + t_type* type, + std::map<std::string, std::vector<std::string>>& annotations) { + return parse_struct_field(type, annotations); +} + +std::vector<validation_rule*> validation_parser::parse_union_field( + t_type* type, + std::map<std::string, std::vector<std::string>>& annotations) { + return parse_struct_field(type, annotations); +} + +bool validation_parser::is_reference_field(std::string value) { + if (value[0] != '$') { + return false; + } + value.erase(value.begin()); + t_field* field = this->reference->get_field_by_name(value); + return field != nullptr; +} + +bool validation_parser::is_validation_function(std::string value) { + if (value[0] != '@') { + return false; + } + return true; +} + +void validation_parser::add_bool_rule( + std::vector<validation_rule*>& rules, + std::string key, + std::map<std::string, std::vector<std::string>>& annotations) { + auto it = annotations.find(key); + if (it != annotations.end() && !it->second.empty()) { + for (auto& annotation_value : it->second) { + validation_rule* rule = new validation_rule(key); + validation_value* value; + if (is_reference_field(annotation_value)) { + t_field* field = get_referenced_field(annotation_value); + value = new validation_value(field); + } else { + bool constant; + std::istringstream(it->second.back()) >> std::boolalpha >> constant; + value = new validation_value(constant); + } + rule->append_value(value); + rules.push_back(rule); + } + } +} + +void validation_parser::add_double_rule( + std::vector<validation_rule*>& rules, + std::string key, + std::map<std::string, std::vector<std::string>>& annotations) { + auto it = annotations.find(key); + if (it != annotations.end() && !it->second.empty()) { + for (auto& annotation_value : it->second) { + if (annotation_value.size() == 0) { + continue; + } + validation_rule* rule = new validation_rule(key); + validation_value* value; + if (is_validation_function(annotation_value)) { + validation_value::validation_function* function = get_validation_function(annotation_value); + value = new validation_value(function); + } else if (is_reference_field(annotation_value)) { + t_field* field = get_referenced_field(annotation_value); + value = new validation_value(field); + } else { + double constant = std::stod(annotation_value); + value = new validation_value(constant); + } + rule->append_value(value); + rules.push_back(rule); + } + } +} + +void validation_parser::add_enum_list_rule( + std::vector<validation_rule*>& rules, + t_enum* enum_, + std::string key, + std::map<std::string, std::vector<std::string>>& annotations) { + auto it = annotations.find(key); + if (it != annotations.end() && !it->second.empty()) { + for (auto& annotation_value : it->second) { + if (annotation_value.size() == 0) { + continue; + } + validation_rule* rule = new validation_rule(key); + if (annotation_value[0] == '[') { + validation_value* value; + char* str = strdup(annotation_value.c_str()); + char* pch = strtok(str, list_delimiter); + std::string val; + while (pch != NULL) { + std::string temp(pch); + if (is_validation_function(temp)) { + validation_value::validation_function* function = get_validation_function(temp); + value = new validation_value(function); + } else if (is_reference_field(temp)) { + t_field* field = get_referenced_field(temp); + value = new validation_value(field); + } else if (std::stringstream(temp) >> val) { + std::string::size_type dot = val.rfind('.'); + if (dot != std::string::npos) { + val = val.substr(dot + 1); + } + t_enum_value* enum_val = enum_->get_constant_by_name(val); + value = new validation_value(enum_val); + } else { + delete rule; + throw "validator error: validation double list parse failed: " + temp; + } + rule->append_value(value); + pch = strtok(NULL, list_delimiter); + } + } else { + validation_value* value; + std::string val = annotation_value; + std::string::size_type dot = val.rfind('.'); + if (dot != std::string::npos) { + val = val.substr(dot + 1); + } + t_enum_value* enum_val = enum_->get_constant_by_name(val); + value = new validation_value(enum_val); + rule->append_value(value); + } + rules.push_back(rule); + } + } +} + +void validation_parser::add_double_list_rule( + std::vector<validation_rule*>& rules, + std::string key, + std::map<std::string, std::vector<std::string>>& annotations) { + std::map<std::string, std::vector<std::string>> double_rules; + auto it = annotations.find(key); + if (it != annotations.end() && !it->second.empty()) { + for (auto& annotation_value : it->second) { + if (annotation_value.size() == 0) { + continue; + } + if (annotation_value[0] == '[') { + validation_rule* rule = new validation_rule(key); + validation_value* value; + char* str = strdup(annotation_value.c_str()); + char* pch = strtok(str, list_delimiter); + double val; + while (pch != NULL) { + std::string temp(pch); + if (is_validation_function(temp)) { + validation_value::validation_function* function = get_validation_function(temp); + value = new validation_value(function); + } else if (is_reference_field(temp)) { + t_field* field = get_referenced_field(temp); + value = new validation_value(field); + } else if (std::stringstream(temp) >> val) { + value = new validation_value(val); + } else { + delete rule; + throw "validator error: validation double list parse failed: " + temp; + } + rule->append_value(value); + pch = strtok(NULL, list_delimiter); + } + rules.push_back(rule); + } else { + double_rules[key].push_back(annotation_value); + } + } + } + if (double_rules[key].size() > 0) { + add_double_rule(rules, key, double_rules); + } +} + +void validation_parser::add_integer_rule( + std::vector<validation_rule*>& rules, + std::string key, + std::map<std::string, std::vector<std::string>>& annotations) { + auto it = annotations.find(key); + if (it != annotations.end() && !it->second.empty()) { + for (auto& annotation_value : it->second) { + if (annotation_value.size() == 0) { + continue; + } + validation_rule* rule = new validation_rule(key); + validation_value* value; + if (is_reference_field(annotation_value)) { + t_field* field = get_referenced_field(annotation_value); + value = new validation_value(field); + } else if (is_validation_function(annotation_value)) { + validation_value::validation_function* function = get_validation_function(annotation_value); + value = new validation_value(function); + } else { + int64_t constant = std::stoll(annotation_value); + value = new validation_value(constant); + } + rule->append_value(value); + rules.push_back(rule); + } + } +} + +void validation_parser::add_integer_list_rule( + std::vector<validation_rule*>& rules, + std::string key, + std::map<std::string, std::vector<std::string>>& annotations) { + std::map<std::string, std::vector<std::string>> integer_rules; + auto it = annotations.find(key); + if (it != annotations.end() && !it->second.empty()) { + for (auto& annotation_value : it->second) { + if (annotation_value.size() == 0) { + continue; + } + if (annotation_value[0] == '[') { + validation_rule* rule = new validation_rule(key); + validation_value* value; + char* str = strdup(annotation_value.c_str()); + char* pch = strtok(str, list_delimiter); + int64_t val; + while (pch != NULL) { + std::string temp(pch); + if (is_validation_function(temp)) { + validation_value::validation_function* function = get_validation_function(temp); + value = new validation_value(function); + } else if (is_reference_field(temp)) { + t_field* field = get_referenced_field(temp); + value = new validation_value(field); + } else if (std::stringstream(temp) >> val) { + value = new validation_value(val); + } else { + delete rule; + throw "validator error: validation integer list parse failed: " + temp; + } + rule->append_value(value); + pch = strtok(NULL, list_delimiter); + } + rules.push_back(rule); + } else { + integer_rules[key].push_back(annotation_value); + } + } + } + if (integer_rules[key].size() > 0) { + add_integer_rule(rules, key, integer_rules); + } +} + +void validation_parser::add_string_rule( + std::vector<validation_rule*>& rules, + std::string key, + std::map<std::string, std::vector<std::string>>& annotations) { + auto it = annotations.find(key); + if (it != annotations.end() && !it->second.empty()) { + for (auto& annotation_value : it->second) { + validation_rule* rule = new validation_rule(key); + validation_value* value; + if (is_reference_field(annotation_value)) { + t_field* field = get_referenced_field(annotation_value); + value = new validation_value(field); + } else { + value = new validation_value(annotation_value); + } + rule->append_value(value); + rules.push_back(rule); + } + } +} + +t_field* validation_parser::get_referenced_field(std::string annotation_value) { + annotation_value.erase(annotation_value.begin()); + return reference->get_field_by_name(annotation_value); +} + +validation_value::validation_function* validation_parser::get_validation_function( + std::string annotation_value) { + std::string value = annotation_value; + value.erase(value.begin()); + validation_value::validation_function* function = new validation_value::validation_function; + + size_t name_end = value.find_first_of('('); + if (name_end >= value.size()) { + delete function; + throw "validator error: validation function parse failed: " + annotation_value; + } + function->name = value.substr(0, name_end); + value.erase(0, name_end + 1); // name( + + if (function->name == "len") { + size_t argument_end = value.find_first_of(')'); + if (argument_end >= value.size()) { + delete function; + throw "validator error: validation function parse failed: " + annotation_value; + } + std::string argument = value.substr(0, argument_end); + if (argument.size() > 0 && argument[0] == '$') { + t_field* field = get_referenced_field(argument); + validation_value* value = new validation_value(field); + function->arguments.push_back(value); + } else { + delete function; + throw "validator error: validation function parse failed, unrecognized argument: " + + annotation_value; + } + } else { + delete function; + throw "validator error: validation function parse failed, function not supported: " + + annotation_value; + } + return function; +}
\ No newline at end of file diff --git a/compiler/cpp/src/thrift/generate/validator_parser.h b/compiler/cpp/src/thrift/generate/validator_parser.h new file mode 100644 index 000000000..076af2ed6 --- /dev/null +++ b/compiler/cpp/src/thrift/generate/validator_parser.h @@ -0,0 +1,208 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef T_VALIDATOR_GENERATOR_H +#define T_VALIDATOR_GENERATOR_H + +#include "thrift/generate/t_generator.h" +#include <fstream> +#include <iostream> +#include <limits> +#include <string> +#include <vector> + +class validation_value { +public: + struct validation_function { + public: + std::string name; + std::vector<validation_value*> arguments; + }; + + enum validation_value_type { + VV_INTEGER, + VV_DOUBLE, + VV_BOOL, + VV_ENUM, + VV_STRING, + VV_FUNCTION, + VV_FIELD_REFERENCE, + VV_UNKNOWN + }; + + validation_value() : val_type(VV_UNKNOWN) {} + validation_value(const int64_t val) : int_val(val), val_type(VV_INTEGER) {} + validation_value(const double val) : double_val(val), val_type(VV_DOUBLE) {} + validation_value(const bool val) : bool_val(val), val_type(VV_BOOL) {} + validation_value(t_enum_value* val) : enum_val(val), val_type(VV_ENUM) {} + validation_value(const std::string val) : string_val(val), val_type(VV_STRING) {} + validation_value(validation_function* val) : function_val(val), val_type(VV_FUNCTION) {} + validation_value(t_field* val) : field_reference_val(val), val_type(VV_FIELD_REFERENCE) {} + + void set_int(const int64_t val) { + int_val = val; + val_type = VV_INTEGER; + } + int64_t get_int() const { return int_val; }; + + void set_double(const double val) { + double_val = val; + val_type = VV_DOUBLE; + } + double get_double() { return double_val; }; + + void set_bool(const bool val) { + bool_val = val; + val_type = VV_BOOL; + } + bool get_bool() const { return bool_val; }; + + void set_enum(t_enum_value* val) { + enum_val = val; + val_type = VV_ENUM; + } + t_enum_value* get_enum() const { return enum_val; }; + + void set_string(const std::string val) { + string_val = val; + val_type = VV_STRING; + } + std::string get_string() const { return string_val; }; + + void set_function(validation_function* val) { + function_val = val; + val_type = VV_FUNCTION; + } + + validation_function* get_function() { return function_val; }; + + void set_field_reference(t_field* val) { + field_reference_val = val; + val_type = VV_FIELD_REFERENCE; + } + t_field* get_field_reference() const { return field_reference_val; }; + + bool is_field_reference() const { return val_type == VV_FIELD_REFERENCE; }; + + bool is_validation_function() const { return val_type == VV_FUNCTION; }; + + validation_value_type get_type() const { return val_type; }; + +private: + int64_t int_val = 0; + double double_val = 0.0; + bool bool_val = false; + t_enum_value* enum_val = nullptr; + std::string string_val; + validation_function* function_val; + t_field* field_reference_val; + + validation_value_type val_type; +}; + +class validation_rule { +public: + validation_rule(){}; + validation_rule(std::string name) : name(name){}; + validation_rule(std::string name, validation_rule* inner) : name(name), inner(inner){}; + + std::string get_name() { return name; }; + void append_value(validation_value* value) { values.push_back(value); } + const std::vector<validation_value*>& get_values() { return values; }; + validation_rule* get_inner() { return inner; }; + +private: + std::string name; + std::vector<validation_value*> values; + validation_rule* inner; +}; + +class validation_parser { +public: + validation_parser() {} + validation_parser(t_struct* reference) : reference(reference) {} + std::vector<validation_rule*> parse_field( + t_type* type, + std::map<std::string, std::vector<std::string>>& annotations); + void set_reference(t_struct* reference) { this->reference = reference; }; + +private: + std::vector<validation_rule*> parse_bool_field( + t_type* type, + std::map<std::string, std::vector<std::string>>& annotations); + std::vector<validation_rule*> parse_enum_field( + t_type* type, + std::map<std::string, std::vector<std::string>>& annotations); + std::vector<validation_rule*> parse_double_field( + t_type* type, + std::map<std::string, std::vector<std::string>>& annotations); + std::vector<validation_rule*> parse_integer_field( + t_type* type, + std::map<std::string, std::vector<std::string>>& annotations); + std::vector<validation_rule*> parse_string_field( + t_type* type, + std::map<std::string, std::vector<std::string>>& annotations); + std::vector<validation_rule*> parse_set_field( + t_type* type, + std::map<std::string, std::vector<std::string>>& annotations); + std::vector<validation_rule*> parse_list_field( + t_type* type, + std::map<std::string, std::vector<std::string>>& annotations); + std::vector<validation_rule*> parse_map_field( + t_type* type, + std::map<std::string, std::vector<std::string>>& annotations); + std::vector<validation_rule*> parse_struct_field( + t_type* type, + std::map<std::string, std::vector<std::string>>& annotations); + std::vector<validation_rule*> parse_xception_field( + t_type* type, + std::map<std::string, std::vector<std::string>>& annotations); + std::vector<validation_rule*> parse_union_field( + t_type* type, + std::map<std::string, std::vector<std::string>>& annotations); + bool is_reference_field(std::string value); + bool is_validation_function(std::string value); + void add_bool_rule(std::vector<validation_rule*>& rules, + std::string key, + std::map<std::string, std::vector<std::string>>& annotations); + void add_double_rule(std::vector<validation_rule*>& rules, + std::string key, + std::map<std::string, std::vector<std::string>>& annotations); + void add_double_list_rule(std::vector<validation_rule*>& rules, + std::string key, + std::map<std::string, std::vector<std::string>>& annotations); + void add_integer_rule(std::vector<validation_rule*>& rules, + std::string key, + std::map<std::string, std::vector<std::string>>& annotations); + void add_integer_list_rule(std::vector<validation_rule*>& rules, + std::string key, + std::map<std::string, std::vector<std::string>>& annotations); + void add_string_rule(std::vector<validation_rule*>& rules, + std::string key, + std::map<std::string, std::vector<std::string>>& annotations); + void add_enum_list_rule(std::vector<validation_rule*>& rules, + t_enum* enum_, + std::string key, + std::map<std::string, std::vector<std::string>>& annotations); + t_field* get_referenced_field(std::string annotation_value); + validation_value::validation_function* get_validation_function(std::string annotation_value); + t_struct* reference; +}; + +#endif diff --git a/compiler/cpp/src/thrift/parse/t_enum_value.h b/compiler/cpp/src/thrift/parse/t_enum_value.h index 70eee8618..c6558d781 100644 --- a/compiler/cpp/src/thrift/parse/t_enum_value.h +++ b/compiler/cpp/src/thrift/parse/t_enum_value.h @@ -20,9 +20,9 @@ #ifndef T_ENUM_VALUE_H #define T_ENUM_VALUE_H +#include "thrift/parse/t_doc.h" #include <map> #include <string> -#include "thrift/parse/t_doc.h" /** * A constant. These are used inside of enum definitions. Constants are just @@ -40,7 +40,7 @@ public: int get_value() const { return value_; } - std::map<std::string, std::string> annotations_; + std::map<std::string, std::vector<std::string>> annotations_; private: std::string name_; diff --git a/compiler/cpp/src/thrift/parse/t_field.h b/compiler/cpp/src/thrift/parse/t_field.h index f0a607de0..928fdcf93 100644 --- a/compiler/cpp/src/thrift/parse/t_field.h +++ b/compiler/cpp/src/thrift/parse/t_field.h @@ -21,8 +21,8 @@ #define T_FIELD_H #include <map> -#include <string> #include <sstream> +#include <string> #include "thrift/parse/t_doc.h" #include "thrift/parse/t_type.h" @@ -106,7 +106,7 @@ public: } }; - std::map<std::string, std::string> annotations_; + std::map<std::string, std::vector<std::string>> annotations_; bool get_reference() const { return reference_; } diff --git a/compiler/cpp/src/thrift/parse/t_function.h b/compiler/cpp/src/thrift/parse/t_function.h index d30c8a46e..bc0ae465b 100644 --- a/compiler/cpp/src/thrift/parse/t_function.h +++ b/compiler/cpp/src/thrift/parse/t_function.h @@ -20,10 +20,10 @@ #ifndef T_FUNCTION_H #define T_FUNCTION_H -#include <string> -#include "thrift/parse/t_type.h" -#include "thrift/parse/t_struct.h" #include "thrift/parse/t_doc.h" +#include "thrift/parse/t_struct.h" +#include "thrift/parse/t_type.h" +#include <string> /** * Representation of a function. Key parts are return type, function name, @@ -79,7 +79,7 @@ public: bool is_oneway() const { return oneway_; } - std::map<std::string, std::string> annotations_; + std::map<std::string, std::vector<std::string>> annotations_; private: t_type* returntype_; diff --git a/compiler/cpp/src/thrift/parse/t_program.h b/compiler/cpp/src/thrift/parse/t_program.h index b6b1332c0..23c6463a2 100644 --- a/compiler/cpp/src/thrift/parse/t_program.h +++ b/compiler/cpp/src/thrift/parse/t_program.h @@ -331,20 +331,20 @@ public: return namespaces_; } - void set_namespace_annotations(std::string language, std::map<std::string, std::string> annotations) { + void set_namespace_annotations(std::string language, std::map<std::string, std::vector<std::string>> annotations) { namespace_annotations_[language] = annotations; } - const std::map<std::string, std::string>& get_namespace_annotations(const std::string& language) const { + const std::map<std::string, std::vector<std::string>>& get_namespace_annotations(const std::string& language) const { auto it = namespace_annotations_.find(language); if (namespace_annotations_.end() != it) { return it->second; } - static const std::map<std::string, std::string> emptyMap; + static const std::map<std::string, std::vector<std::string>> emptyMap; return emptyMap; } - std::map<std::string, std::string>& get_namespace_annotations(const std::string& language) { + std::map<std::string, std::vector<std::string>>& get_namespace_annotations(const std::string& language) { return namespace_annotations_[language]; } @@ -400,7 +400,7 @@ private: std::map<std::string, std::string> namespaces_; // Annotations for dynamic namespaces - std::map<std::string, std::map<std::string, std::string> > namespace_annotations_; + std::map<std::string, std::map<std::string, std::vector<std::string>>> namespace_annotations_; // C++ extra includes std::vector<std::string> cpp_includes_; diff --git a/compiler/cpp/src/thrift/parse/t_type.h b/compiler/cpp/src/thrift/parse/t_type.h index 8dbeb9efc..f4082426c 100644 --- a/compiler/cpp/src/thrift/parse/t_type.h +++ b/compiler/cpp/src/thrift/parse/t_type.h @@ -20,11 +20,11 @@ #ifndef T_TYPE_H #define T_TYPE_H -#include <string> -#include <map> +#include "thrift/parse/t_doc.h" #include <cstring> +#include <map> #include <stdint.h> -#include "thrift/parse/t_doc.h" +#include <string> class t_program; @@ -83,7 +83,7 @@ public: return rv; } - std::map<std::string, std::string> annotations_; + std::map<std::string, std::vector<std::string>> annotations_; protected: t_type() : program_(nullptr) { ; } diff --git a/compiler/cpp/src/thrift/thrifty.yy b/compiler/cpp/src/thrift/thrifty.yy index 82b2be57e..bb2c19e44 100644 --- a/compiler/cpp/src/thrift/thrifty.yy +++ b/compiler/cpp/src/thrift/thrifty.yy @@ -1301,7 +1301,7 @@ TypeAnnotationList: { pdebug("TypeAnnotationList -> TypeAnnotationList , TypeAnnotation"); $$ = $1; - $$->annotations_[$2->key] = $2->val; + $$->annotations_[$2->key].push_back($2->val); delete $2; } | diff --git a/compiler/cpp/tests/CMakeLists.txt b/compiler/cpp/tests/CMakeLists.txt index 0e8254158..b8b2777a9 100644 --- a/compiler/cpp/tests/CMakeLists.txt +++ b/compiler/cpp/tests/CMakeLists.txt @@ -76,6 +76,8 @@ set(thrift_compiler_SOURCES ${THRIFT_COMPILER_SOURCE_DIR}/src/thrift/audit/t_audit.cpp ${THRIFT_COMPILER_SOURCE_DIR}/src/thrift/common.cc ${THRIFT_COMPILER_SOURCE_DIR}/src/thrift/generate/t_generator.cc + ${THRIFT_COMPILER_SOURCE_DIR}/src/thrift/generate/validator_parser.cc + ${THRIFT_COMPILER_SOURCE_DIR}/src/thrift/generate/validator_parser.h ${THRIFT_COMPILER_SOURCE_DIR}/src/thrift/parse/t_typedef.cc ${THRIFT_COMPILER_SOURCE_DIR}/src/thrift/parse/parse.cc ${THRIFT_COMPILER_SOURCE_DIR}/thrift/version.h @@ -96,6 +98,18 @@ macro(THRIFT_ADD_COMPILER name description initial) endif() endmacro() +# This macro adds an option THRIFT_VALIDATOR_COMPILER_${NAME} +# that allows enabling or disabling certain languages +macro(THRIFT_ADD_VALIDATOR_COMPILER name description initial) + string(TOUPPER "THRIFT_VALIDATOR_COMPILER_${name}" enabler) + set(src "${THRIFT_COMPILER_SOURCE_DIR}/src/thrift/generate/${name}_validator_generator.cc") + list(APPEND "${THRIFT_COMPILER_SOURCE_DIR}/src/thrift/generate/${name}_validator_generator.h") + option(${enabler} ${description} ${initial}) + if(${enabler}) + list(APPEND thrift-compiler_SOURCES ${src}) + endif() +endmacro() + # The following compiler with unit tests can be enabled or disabled THRIFT_ADD_COMPILER(c_glib "Enable compiler for C with Glib" OFF) THRIFT_ADD_COMPILER(cl "Enable compiler for Common LISP" OFF) @@ -125,6 +139,9 @@ THRIFT_ADD_COMPILER(swift "Enable compiler for Swift" OFF) THRIFT_ADD_COMPILER(xml "Enable compiler for XML" OFF) THRIFT_ADD_COMPILER(xsd "Enable compiler for XSD" OFF) +# The following compiler can be enabled or disabled by enabling or disabling certain languages +THRIFT_ADD_VALIDATOR_COMPILER(go "Enable validator compiler for Go" ON) + # Thrift is looking for include files in the src directory # we also add the current binary directory for generated files include_directories(${CMAKE_CURRENT_BINARY_DIR} ${THRIFT_COMPILER_SOURCE_DIR}/src ${CMAKE_CURRENT_SOURCE_DIR}/catch) |