summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSimon Wang <jellalleonhardt4869@gmail.com>2021-09-13 19:50:45 +0800
committerYuxuan 'fishy' Wang <yuxuan.wang@reddit.com>2022-09-11 08:25:32 -0700
commitd5927a96019154fa590c38f3a7ca70275af11b3c (patch)
tree4c0a412c9d0bd333e24f21cc6d13bedadc487268
parent944b8e68a099392d80153ebcf26f32ff7f1d893a (diff)
downloadthrift-d5927a96019154fa590c38f3a7ca70275af11b3c.tar.gz
THRIFT-5423: IDL parameter validation for Go
Closes https://github.com/apache/thrift/pull/2469.
-rw-r--r--compiler/cpp/CMakeLists.txt75
-rw-r--r--compiler/cpp/Makefile.am7
-rw-r--r--compiler/cpp/src/thrift/generate/go_validator_generator.cc906
-rw-r--r--compiler/cpp/src/thrift/generate/go_validator_generator.h72
-rw-r--r--compiler/cpp/src/thrift/generate/t_cpp_generator.cc6
-rw-r--r--compiler/cpp/src/thrift/generate/t_delphi_generator.cc6
-rw-r--r--compiler/cpp/src/thrift/generate/t_go_generator.cc299
-rw-r--r--compiler/cpp/src/thrift/generate/t_go_generator.h329
-rw-r--r--compiler/cpp/src/thrift/generate/t_java_generator.cc6
-rw-r--r--compiler/cpp/src/thrift/generate/t_json_generator.cc28
-rw-r--r--compiler/cpp/src/thrift/generate/t_kotlin_generator.cc2
-rw-r--r--compiler/cpp/src/thrift/generate/t_netstd_generator.cc4
-rw-r--r--compiler/cpp/src/thrift/generate/t_py_generator.cc4
-rw-r--r--compiler/cpp/src/thrift/generate/t_xml_generator.cc34
-rw-r--r--compiler/cpp/src/thrift/generate/t_xsd_generator.cc8
-rw-r--r--compiler/cpp/src/thrift/generate/validator_parser.cc550
-rw-r--r--compiler/cpp/src/thrift/generate/validator_parser.h208
-rw-r--r--compiler/cpp/src/thrift/parse/t_enum_value.h4
-rw-r--r--compiler/cpp/src/thrift/parse/t_field.h4
-rw-r--r--compiler/cpp/src/thrift/parse/t_function.h8
-rw-r--r--compiler/cpp/src/thrift/parse/t_program.h10
-rw-r--r--compiler/cpp/src/thrift/parse/t_type.h8
-rw-r--r--compiler/cpp/src/thrift/thrifty.yy2
-rw-r--r--compiler/cpp/tests/CMakeLists.txt17
-rw-r--r--doc/specs/thrift-parameter-validation-proposal.md195
-rw-r--r--lib/go/test/Makefile.am10
-rw-r--r--lib/go/test/ValidateTest.thrift104
-rw-r--r--lib/go/test/tests/validate_test.go494
-rw-r--r--lib/go/thrift/application_exception.go47
29 files changed, 3082 insertions, 365 deletions
diff --git a/compiler/cpp/CMakeLists.txt b/compiler/cpp/CMakeLists.txt
index a23004110..b0f123555 100644
--- a/compiler/cpp/CMakeLists.txt
+++ b/compiler/cpp/CMakeLists.txt
@@ -46,6 +46,8 @@ add_library(parse STATIC ${parse_SOURCES})
set(compiler_core
src/thrift/common.cc
src/thrift/generate/t_generator.cc
+ src/thrift/generate/validator_parser.cc
+ src/thrift/generate/validator_parser.h
src/thrift/parse/t_typedef.cc
src/thrift/parse/parse.cc
src/thrift/version.h
@@ -71,36 +73,51 @@ macro(THRIFT_ADD_COMPILER name description initial)
endif()
endmacro()
+# This macro adds an option THRIFT_VALIDATOR_COMPILER_${NAME}
+# that allows enabling or disabling certain languages' validator
+macro(THRIFT_ADD_VALIDATOR_COMPILER name description initial)
+ string(TOUPPER "THRIFT_COMPILER_${name}" enabler)
+ set(src "src/thrift/generate/${name}_validator_generator.cc")
+ list(APPEND "src/thrift/generate/${name}_validator_generator.h")
+ option(${enabler} ${description} ${initial})
+ if(${enabler})
+ list(APPEND thrift-compiler_SOURCES ${src})
+ endif()
+endmacro()
+
# The following compiler can be enabled or disabled
-THRIFT_ADD_COMPILER(c_glib "Enable compiler for C with Glib" ON)
-THRIFT_ADD_COMPILER(cl "Enable compiler for Common LISP" ON)
-THRIFT_ADD_COMPILER(cpp "Enable compiler for C++" ON)
-THRIFT_ADD_COMPILER(d "Enable compiler for D" ON)
-THRIFT_ADD_COMPILER(dart "Enable compiler for Dart" ON)
-THRIFT_ADD_COMPILER(delphi "Enable compiler for Delphi" ON)
-THRIFT_ADD_COMPILER(erl "Enable compiler for Erlang" ON)
-THRIFT_ADD_COMPILER(go "Enable compiler for Go" ON)
-THRIFT_ADD_COMPILER(gv "Enable compiler for GraphViz" ON)
-THRIFT_ADD_COMPILER(haxe "Enable compiler for Haxe" ON)
-THRIFT_ADD_COMPILER(html "Enable compiler for HTML Documentation" ON)
-THRIFT_ADD_COMPILER(markdown "Enable compiler for Markdown Documentation" ON)
-THRIFT_ADD_COMPILER(java "Enable compiler for Java" ON)
-THRIFT_ADD_COMPILER(javame "Enable compiler for Java ME" ON)
-THRIFT_ADD_COMPILER(js "Enable compiler for JavaScript" ON)
-THRIFT_ADD_COMPILER(json "Enable compiler for JSON" ON)
-THRIFT_ADD_COMPILER(kotlin "Enable compiler for Kotlin" ON)
-THRIFT_ADD_COMPILER(lua "Enable compiler for Lua" ON)
-THRIFT_ADD_COMPILER(netstd "Enable compiler for .NET Standard" ON)
-THRIFT_ADD_COMPILER(ocaml "Enable compiler for OCaml" ON)
-THRIFT_ADD_COMPILER(perl "Enable compiler for Perl" ON)
-THRIFT_ADD_COMPILER(php "Enable compiler for PHP" ON)
-THRIFT_ADD_COMPILER(py "Enable compiler for Python 2.0" ON)
-THRIFT_ADD_COMPILER(rb "Enable compiler for Ruby" ON)
-THRIFT_ADD_COMPILER(rs "Enable compiler for Rust" ON)
-THRIFT_ADD_COMPILER(st "Enable compiler for Smalltalk" ON)
-THRIFT_ADD_COMPILER(swift "Enable compiler for Cocoa Swift" ON)
-THRIFT_ADD_COMPILER(xml "Enable compiler for XML" ON)
-THRIFT_ADD_COMPILER(xsd "Enable compiler for XSD" ON)
+THRIFT_ADD_COMPILER(c_glib "Enable compiler for C with Glib" ON)
+THRIFT_ADD_COMPILER(cl "Enable compiler for Common LISP" ON)
+THRIFT_ADD_COMPILER(cpp "Enable compiler for C++" ON)
+THRIFT_ADD_COMPILER(d "Enable compiler for D" ON)
+THRIFT_ADD_COMPILER(dart "Enable compiler for Dart" ON)
+THRIFT_ADD_COMPILER(delphi "Enable compiler for Delphi" ON)
+THRIFT_ADD_COMPILER(erl "Enable compiler for Erlang" ON)
+THRIFT_ADD_COMPILER(go "Enable compiler for Go" ON)
+THRIFT_ADD_COMPILER(gv "Enable compiler for GraphViz" ON)
+THRIFT_ADD_COMPILER(haxe "Enable compiler for Haxe" ON)
+THRIFT_ADD_COMPILER(html "Enable compiler for HTML Documentation" ON)
+THRIFT_ADD_COMPILER(markdown "Enable compiler for Markdown Documentation" ON)
+THRIFT_ADD_COMPILER(java "Enable compiler for Java" ON)
+THRIFT_ADD_COMPILER(javame "Enable compiler for Java ME" ON)
+THRIFT_ADD_COMPILER(js "Enable compiler for JavaScript" ON)
+THRIFT_ADD_COMPILER(json "Enable compiler for JSON" ON)
+THRIFT_ADD_COMPILER(kotlin "Enable compiler for Kotlin" ON)
+THRIFT_ADD_COMPILER(lua "Enable compiler for Lua" ON)
+THRIFT_ADD_COMPILER(netstd "Enable compiler for .NET Standard" ON)
+THRIFT_ADD_COMPILER(ocaml "Enable compiler for OCaml" ON)
+THRIFT_ADD_COMPILER(perl "Enable compiler for Perl" ON)
+THRIFT_ADD_COMPILER(php "Enable compiler for PHP" ON)
+THRIFT_ADD_COMPILER(py "Enable compiler for Python 2.0" ON)
+THRIFT_ADD_COMPILER(rb "Enable compiler for Ruby" ON)
+THRIFT_ADD_COMPILER(rs "Enable compiler for Rust" ON)
+THRIFT_ADD_COMPILER(st "Enable compiler for Smalltalk" ON)
+THRIFT_ADD_COMPILER(swift "Enable compiler for Cocoa Swift" ON)
+THRIFT_ADD_COMPILER(xml "Enable compiler for XML" ON)
+THRIFT_ADD_COMPILER(xsd "Enable compiler for XSD" ON)
+
+# The following compiler can be enabled or disabled by enabling or disabling certain languages
+THRIFT_ADD_VALIDATOR_COMPILER(go "Enable validator compiler for Go" ON)
# Thrift is looking for include files in the src directory
# we also add the current binary directory for generated files
diff --git a/compiler/cpp/Makefile.am b/compiler/cpp/Makefile.am
index 55c82943f..bb29d8c47 100644
--- a/compiler/cpp/Makefile.am
+++ b/compiler/cpp/Makefile.am
@@ -77,6 +77,7 @@ thrift_SOURCES += src/thrift/generate/t_c_glib_generator.cc \
src/thrift/generate/t_delphi_generator.cc \
src/thrift/generate/t_erl_generator.cc \
src/thrift/generate/t_go_generator.cc \
+ src/thrift/generate/t_go_generator.h \
src/thrift/generate/t_gv_generator.cc \
src/thrift/generate/t_haxe_generator.cc \
src/thrift/generate/t_html_generator.cc \
@@ -98,7 +99,11 @@ thrift_SOURCES += src/thrift/generate/t_c_glib_generator.cc \
src/thrift/generate/t_st_generator.cc \
src/thrift/generate/t_swift_generator.cc \
src/thrift/generate/t_xml_generator.cc \
- src/thrift/generate/t_xsd_generator.cc
+ src/thrift/generate/t_xsd_generator.cc \
+ src/thrift/generate/validator_parser.cc \
+ src/thrift/generate/validator_parser.h \
+ src/thrift/generate/go_validator_generator.cc \
+ src/thrift/generate/go_validator_generator.h
thrift_CPPFLAGS = -I$(srcdir)/src
thrift_CXXFLAGS = -Wall -Wextra -pedantic -Werror
diff --git a/compiler/cpp/src/thrift/generate/go_validator_generator.cc b/compiler/cpp/src/thrift/generate/go_validator_generator.cc
new file mode 100644
index 000000000..1f5a3ad6e
--- /dev/null
+++ b/compiler/cpp/src/thrift/generate/go_validator_generator.cc
@@ -0,0 +1,906 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*
+ * This file is programmatically sanitized for style:
+ * astyle --style=1tbs -f -p -H -j -U go_validator_generator.cc
+ *
+ * The output of astyle should not be taken unquestioningly, but it is a good
+ * guide for ensuring uniformity and readability.
+ */
+
+#include <fstream>
+#include <iostream>
+#include <limits>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "thrift/generate/go_validator_generator.h"
+#include "thrift/generate/validator_parser.h"
+#include "thrift/platform.h"
+#include "thrift/version.h"
+#include <algorithm>
+#include <clocale>
+#include <sstream>
+#include <stdlib.h>
+#include <sys/stat.h>
+#include <sys/types.h>
+
+std::string go_validator_generator::get_field_reference_name(t_field* field) {
+ t_type* type(field->get_type());
+ std::string tgt;
+ t_const_value* def_value;
+ go_generator->get_publicized_name_and_def_value(field, &tgt, &def_value);
+ tgt = "p." + tgt;
+ if (go_generator->is_pointer_field(field)
+ && (type->is_base_type() || type->is_enum() || type->is_container())) {
+ tgt = "*" + tgt;
+ }
+ return tgt;
+}
+
+void go_validator_generator::generate_struct_validator(std::ostream& out, t_struct* tstruct) {
+ std::vector<t_field*> members = tstruct->get_members();
+ validation_parser parser(tstruct);
+ for (auto it = members.begin(); it != members.end(); it++) {
+ t_field* field(*it);
+ const std::vector<validation_rule*>& rules
+ = parser.parse_field(field->get_type(), field->annotations_);
+ if (rules.size() == 0) {
+ continue;
+ }
+ bool opt = field->get_req() == t_field::T_OPTIONAL;
+ t_type* type = field->get_type();
+ std::string tgt = get_field_reference_name(field);
+ std::string field_symbol = tstruct->get_name() + "." + field->get_name();
+ generate_field_validator(out, generator_context{field_symbol, "", tgt, opt, type, rules});
+ }
+}
+
+void go_validator_generator::generate_field_validator(std::ostream& out,
+ const generator_context& context) {
+ t_type* type = context.type;
+ if (type->is_typedef()) {
+ type = type->get_true_type();
+ }
+ if (type->is_enum()) {
+ if (context.tgt[0] == '*') {
+ out << indent() << "if " << context.tgt.substr(1) << " != nil {" << endl;
+ indent_up();
+ }
+ generate_enum_field_validator(out, context);
+ if (context.tgt[0] == '*') {
+ indent_down();
+ out << indent() << "}" << endl;
+ }
+ return;
+ } else if (type->is_base_type()) {
+ t_base_type::t_base tbase = ((t_base_type*)type)->get_base();
+ if (context.tgt[0] == '*') {
+ out << indent() << "if " << context.tgt.substr(1) << " != nil {" << endl;
+ indent_up();
+ }
+ switch (tbase) {
+ case t_base_type::TYPE_UUID:
+ case t_base_type::TYPE_VOID:
+ break;
+ case t_base_type::TYPE_I8:
+ case t_base_type::TYPE_I16:
+ case t_base_type::TYPE_I32:
+ case t_base_type::TYPE_I64:
+ generate_integer_field_validator(out, context);
+ break;
+ case t_base_type::TYPE_DOUBLE:
+ generate_double_field_validator(out, context);
+ break;
+ case t_base_type::TYPE_STRING:
+ generate_string_field_validator(out, context);
+ break;
+ case t_base_type::TYPE_BOOL:
+ generate_bool_field_validator(out, context);
+ break;
+ }
+ if (context.tgt[0] == '*') {
+ indent_down();
+ out << indent() << "}" << endl;
+ }
+ return;
+ } else if (type->is_list()) {
+ return generate_list_field_validator(out, context);
+ } else if (type->is_set()) {
+ return generate_set_field_validator(out, context);
+ } else if (type->is_map()) {
+ return generate_map_field_validator(out, context);
+ } else if (type->is_struct() || type->is_xception()) {
+ return generate_struct_field_validator(out, context);
+ }
+ throw "validator error: unsupported type: " + type->get_name();
+}
+
+void go_validator_generator::generate_enum_field_validator(std::ostream& out,
+ const generator_context& context) {
+ for (auto it = context.rules.begin(); it != context.rules.end(); it++) {
+ const std::vector<validation_value*>& values = (*it)->get_values();
+ if (values.size() == 0) {
+ continue;
+ }
+ std::string key = (*it)->get_name();
+
+ if (key == "vt.in") {
+ if (values.size() > 1) {
+ std::string exist = GenID("_exist");
+ out << indent() << "var " << exist << " bool" << endl;
+
+ std::string src = GenID("_src");
+ out << indent() << src << " := []int64{";
+ for (auto it = values.begin(); it != values.end(); it++) {
+ if (it != values.begin()) {
+ out << ", ";
+ }
+ out << "int64(";
+ if ((*it)->is_field_reference()) {
+ out << get_field_reference_name((*it)->get_field_reference());
+ } else {
+ out << (*it)->get_enum()->get_value();
+ }
+ out << ")";
+ }
+ out << "}" << endl;
+
+ out << indent() << "for _, src := range " << src << " {" << endl;
+ indent_up();
+ out << indent() << "if int64(" << context.tgt << ") == src {" << endl;
+ indent_up();
+ out << indent() << exist << " = true" << endl;
+ out << indent() << "break" << endl;
+ indent_down();
+ out << indent() << "}" << endl;
+ indent_down();
+ out << indent() << "}" << endl;
+ out << indent() << "if " << exist << " == false {" << endl;
+ } else {
+ out << indent() << "if int64(" << context.tgt << ") != int64(";
+ if (values[0]->is_field_reference()) {
+ out << get_field_reference_name(values[0]->get_field_reference());
+ } else {
+ out << values[0]->get_enum()->get_value();
+ }
+ out << ") {" << endl;
+ }
+ indent_up();
+ out << indent()
+ << "return thrift.NewValidationException(thrift.VALIDATION_FAILED, \"vt.in\", \""
+ << context.field_symbol << "\", \"" << context.field_symbol
+ << " not valid, rule vt.in check failed\")" << endl;
+ indent_down();
+ out << indent() << "}" << endl;
+ if (values.size() > 1) {
+ indent_down();
+ out << indent() << "}" << endl;
+ }
+ } else if (key == "vt.not_in") {
+ if (values.size() > 1) {
+ std::string src = GenID("_src");
+ out << indent() << src << " := []int64{";
+ for (auto it = values.begin(); it != values.end(); it++) {
+ if (it != values.begin()) {
+ out << ", ";
+ }
+ out << "int64(";
+ if ((*it)->is_field_reference()) {
+ out << get_field_reference_name((*it)->get_field_reference());
+ } else {
+ out << (*it)->get_enum()->get_value();
+ }
+ out << ")";
+ }
+ out << "}" << endl;
+
+ out << indent() << "for _, src := range " << src << " {" << endl;
+ indent_up();
+ out << indent() << "if int64(" << context.tgt << ") == src {" << endl;
+ } else {
+ out << indent() << "if int64(" << context.tgt << ") == ";
+ out << "int64(";
+ if (values[0]->is_field_reference()) {
+ out << get_field_reference_name(values[0]->get_field_reference());
+ } else {
+ out << values[0]->get_enum()->get_value();
+ }
+ out << ") {" << endl;
+ }
+ indent_up();
+ out << indent()
+ << "return thrift.NewValidationException(thrift.VALIDATION_FAILED, \"vt.not_in\", \""
+ << context.field_symbol << "\", \"" << context.field_symbol
+ << " not valid, rule vt.not_in check failed\")" << endl;
+ indent_down();
+ out << indent() << "}" << endl;
+ if (values.size() > 1) {
+ indent_down();
+ out << indent() << "}" << endl;
+ }
+ } else if (key == "vt.defined_only") {
+ if (values[0]->get_bool()) {
+ out << indent() << "if (" << context.tgt << ").String() == \"<UNSET>\" ";
+ } else {
+ continue;
+ }
+ out << "{" << endl;
+ indent_up();
+ out << indent()
+ << "return thrift.NewValidationException(thrift.VALIDATION_FAILED, \"" + key + "\", \""
+ << context.field_symbol << "\", \"" << context.field_symbol << " not valid, rule " << key
+ << " check failed\")" << endl;
+ indent_down();
+ out << indent() << "}" << endl;
+ }
+ }
+}
+
+void go_validator_generator::generate_bool_field_validator(std::ostream& out,
+ const generator_context& context) {
+ for (auto it = context.rules.begin(); it != context.rules.end(); it++) {
+ const std::vector<validation_value*>& values = (*it)->get_values();
+ if (values.size() == 0) {
+ continue;
+ }
+ std::string key = (*it)->get_name();
+
+ if (key == "vt.const") {
+ out << indent() << "if " << context.tgt << " != ";
+ if (values[0]->is_field_reference()) {
+ out << get_field_reference_name(values[0]->get_field_reference());
+ } else {
+ if (values[0]->get_bool()) {
+ out << "true";
+ } else {
+ out << "false";
+ }
+ }
+ }
+ out << "{" << endl;
+ indent_up();
+ out << indent()
+ << "return thrift.NewValidationException(thrift.VALIDATION_FAILED, \"" + key + "\", \""
+ << context.field_symbol << "\", \"" << context.field_symbol << " not valid, rule " << key
+ << " check failed\")" << endl;
+ indent_down();
+ out << indent() << "}" << endl;
+ }
+}
+
+void go_validator_generator::generate_double_field_validator(std::ostream& out,
+ const generator_context& context) {
+ for (auto it = context.rules.begin(); it != context.rules.end(); it++) {
+ const std::vector<validation_value*>& values = (*it)->get_values();
+ if (values.size() == 0) {
+ continue;
+ }
+
+ std::map<std::string, std::string> signs{{"vt.lt", ">="},
+ {"vt.le", ">"},
+ {"vt.gt", "<="},
+ {"vt.ge", "<"}};
+ std::string key = (*it)->get_name();
+ auto key_it = signs.find(key);
+ if (key_it != signs.end()) {
+ out << indent() << "if " << context.tgt << " " << key_it->second << " ";
+ if (values[0]->is_field_reference()) {
+ out << get_field_reference_name(values[0]->get_field_reference());
+ } else {
+ out << values[0]->get_double();
+ }
+ out << "{" << endl;
+ indent_up();
+ out << indent()
+ << "return thrift.NewValidationException(thrift.VALIDATION_FAILED, \"" + key + "\", \""
+ << context.field_symbol << "\", \"" << context.field_symbol << " not valid, rule " << key
+ << " check failed\")" << endl;
+ indent_down();
+ out << indent() << "}" << endl;
+ continue;
+ } else if (key == "vt.in") {
+ if (values.size() > 1) {
+ std::string exist = GenID("_exist");
+ out << indent() << "var " << exist << " bool" << endl;
+
+ std::string src = GenID("_src");
+ out << indent() << src << " := []float64{";
+ for (auto it = values.begin(); it != values.end(); it++) {
+ if (it != values.begin()) {
+ out << ", ";
+ }
+ if ((*it)->is_field_reference()) {
+ out << get_field_reference_name((*it)->get_field_reference());
+ } else {
+ out << (*it)->get_double();
+ }
+ }
+ out << "}" << endl;
+
+ out << indent() << "for _, src := range " << src << " {" << endl;
+ indent_up();
+ out << indent() << "if " << context.tgt << " == src {" << endl;
+ indent_up();
+ out << indent() << exist << " = true" << endl;
+ out << indent() << "break" << endl;
+ indent_down();
+ out << indent() << "}" << endl;
+ indent_down();
+ out << indent() << "}" << endl;
+ out << indent() << "if " << exist << " == false {" << endl;
+ } else {
+ out << indent() << "if " << context.tgt << " != ";
+ if (values[0]->is_field_reference()) {
+ out << get_field_reference_name(values[0]->get_field_reference());
+ } else {
+ out << values[0]->get_double();
+ }
+ out << "{" << endl;
+ }
+
+ indent_up();
+ out << indent()
+ << "return thrift.NewValidationException(thrift.VALIDATION_FAILED, \"vt.in\", \""
+ << context.field_symbol << "\", \"" << context.field_symbol
+ << " not valid, rule vt.in check failed\")" << endl;
+ indent_down();
+ out << indent() << "}" << endl;
+ } else if (key == "vt.not_in") {
+ if (values.size() > 1) {
+ std::string src = GenID("_src");
+ out << indent() << src << " := []float64{";
+ for (auto it = values.begin(); it != values.end(); it++) {
+ if (it != values.begin()) {
+ out << ", ";
+ }
+ if ((*it)->is_field_reference()) {
+ out << get_field_reference_name((*it)->get_field_reference());
+ } else {
+ out << (*it)->get_double();
+ }
+ }
+ out << "}" << endl;
+
+ out << indent() << "for _, src := range " << src << " {" << endl;
+ indent_up();
+ out << indent() << "if " << context.tgt << " == src {" << endl;
+ } else {
+ out << indent() << "if " << context.tgt << " == ";
+ if (values[0]->is_field_reference()) {
+ out << get_field_reference_name(values[0]->get_field_reference());
+ } else {
+ out << values[0]->get_double();
+ }
+ out << "{" << endl;
+ }
+ indent_up();
+ out << indent()
+ << "return thrift.NewValidationException(thrift.VALIDATION_FAILED, \"vt.not_in\", \""
+ << context.field_symbol << "\", \"" << context.field_symbol
+ << " not valid, rule vt.not_in check failed\")" << endl;
+ indent_down();
+ out << indent() << "}" << endl;
+ if (values.size() > 1) {
+ indent_down();
+ out << indent() << "}" << endl;
+ }
+ }
+ }
+}
+
+void go_validator_generator::generate_integer_field_validator(std::ostream& out,
+ const generator_context& context) {
+ auto generate_current_type = [](std::ostream& out, t_type* type) {
+ t_base_type::t_base tbase = ((t_base_type*)type)->get_base();
+ switch (tbase) {
+ case t_base_type::TYPE_I8:
+ out << "int8";
+ break;
+ case t_base_type::TYPE_I16:
+ out << "int16";
+ break;
+ case t_base_type::TYPE_I32:
+ out << "int32";
+ break;
+ case t_base_type::TYPE_I64:
+ out << "int64";
+ break;
+ default:
+ throw "validator error: unsupported integer type: " + type->get_name();
+ }
+ };
+
+ for (auto it = context.rules.begin(); it != context.rules.end(); it++) {
+ const std::vector<validation_value*>& values = (*it)->get_values();
+ if (values.size() == 0) {
+ continue;
+ }
+
+ std::map<std::string, std::string> signs{{"vt.lt", ">="},
+ {"vt.le", ">"},
+ {"vt.gt", "<="},
+ {"vt.ge", "<"}};
+ std::string key = (*it)->get_name();
+ auto key_it = signs.find(key);
+ if (key_it != signs.end()) {
+ out << indent() << "if " << context.tgt << " " << key_it->second << " ";
+ if (values[0]->is_field_reference()) {
+ out << get_field_reference_name(values[0]->get_field_reference());
+ } else if (values[0]->is_validation_function()) {
+ generate_current_type(out, context.type);
+ out << "(";
+ validation_value::validation_function* func = values[0]->get_function();
+ if (func->name == "len") {
+ out << "len(";
+ if (func->arguments[0]->is_field_reference()) {
+ out << get_field_reference_name(func->arguments[0]->get_field_reference());
+ }
+ out << ")";
+ }
+ out << ")";
+ } else {
+ out << values[0]->get_int();
+ }
+ out << "{" << endl;
+ indent_up();
+ out << indent()
+ << "return thrift.NewValidationException(thrift.VALIDATION_FAILED, \"" + key + "\", \""
+ << context.field_symbol << "\", \"" << context.field_symbol << " not valid, rule " << key
+ << " check failed\")" << endl;
+ indent_down();
+ out << indent() << "}" << endl;
+ } else if (key == "vt.in") {
+ if (values.size() > 1) {
+ std::string exist = GenID("_exist");
+ out << indent() << "var " << exist << " bool" << endl;
+
+ std::string src = GenID("_src");
+ out << indent() << src << " := []";
+ generate_current_type(out, context.type);
+ out << "{";
+ for (auto it = values.begin(); it != values.end(); it++) {
+ if (it != values.begin()) {
+ out << ", ";
+ }
+ if ((*it)->is_field_reference()) {
+ out << get_field_reference_name((*it)->get_field_reference());
+ } else if ((*it)->is_validation_function()) {
+ generate_current_type(out, context.type);
+ out << "(";
+ validation_value::validation_function* func = (*it)->get_function();
+ if (func->name == "len") {
+ out << "len(";
+ if (func->arguments[0]->is_field_reference()) {
+ out << get_field_reference_name(func->arguments[0]->get_field_reference());
+ }
+ out << ")";
+ }
+ out << ")";
+ } else {
+ out << (*it)->get_int();
+ }
+ }
+ out << "}" << endl;
+
+ out << indent() << "for _, src := range " << src << " {" << endl;
+ indent_up();
+ out << indent() << "if " << context.tgt << " == src {" << endl;
+ indent_up();
+ out << indent() << exist << " = true" << endl;
+ out << indent() << "break" << endl;
+ indent_down();
+ out << indent() << "}" << endl;
+ indent_down();
+ out << indent() << "}" << endl;
+ out << indent() << "if " << exist << " == false {" << endl;
+ } else {
+ out << indent() << "if " << context.tgt << " != ";
+ if (values[0]->is_field_reference()) {
+ out << get_field_reference_name(values[0]->get_field_reference());
+ } else if (values[0]->is_validation_function()) {
+ generate_current_type(out, context.type);
+ out << "(";
+ validation_value::validation_function* func = values[0]->get_function();
+ if (func->name == "len") {
+ out << "len(";
+ if (func->arguments[0]->is_field_reference()) {
+ out << get_field_reference_name(func->arguments[0]->get_field_reference());
+ }
+ out << ")";
+ }
+ out << ")";
+ } else {
+ out << values[0]->get_int();
+ }
+ out << "{" << endl;
+ }
+ indent_up();
+ out << indent()
+ << "return thrift.NewValidationException(thrift.VALIDATION_FAILED, \"vt.in\", \""
+ << context.field_symbol << "\", \"" << context.field_symbol
+ << " not valid, rule vt.in check failed\")" << endl;
+ indent_down();
+ out << indent() << "}" << endl;
+ } else if (key == "vt.not_in") {
+ if (values.size() > 1) {
+ std::string src = GenID("_src");
+ out << indent() << src << " := []";
+ t_base_type::t_base tbase = ((t_base_type*)context.type)->get_base();
+ switch (tbase) {
+ case t_base_type::TYPE_I8:
+ out << "int8";
+ break;
+ case t_base_type::TYPE_I16:
+ out << "int16";
+ break;
+ case t_base_type::TYPE_I32:
+ out << "int32";
+ break;
+ case t_base_type::TYPE_I64:
+ out << "int64";
+ break;
+ default:
+ throw "validator error: unsupported integer type: " + context.type->get_name();
+ }
+ out << "{";
+ for (auto it = values.begin(); it != values.end(); it++) {
+ if (it != values.begin()) {
+ out << ", ";
+ }
+ if ((*it)->is_field_reference()) {
+ out << get_field_reference_name((*it)->get_field_reference());
+ } else if ((*it)->is_validation_function()) {
+ generate_current_type(out, context.type);
+ out << "(";
+ validation_value::validation_function* func = (*it)->get_function();
+ if (func->name == "len") {
+ out << "len(";
+ if (func->arguments[0]->is_field_reference()) {
+ out << get_field_reference_name(func->arguments[0]->get_field_reference());
+ }
+ out << ")";
+ }
+ out << ")";
+ } else {
+ out << (*it)->get_int();
+ }
+ }
+ out << "}" << endl;
+
+ out << indent() << "for _, src := range " << src << " {" << endl;
+ indent_up();
+ out << indent() << "if " << context.tgt << " == src {" << endl;
+ } else {
+ out << indent() << "if " << context.tgt << " == ";
+ if (values[0]->is_field_reference()) {
+ out << get_field_reference_name(values[0]->get_field_reference());
+ } else if (values[0]->is_validation_function()) {
+ generate_current_type(out, context.type);
+ out << "(";
+ validation_value::validation_function* func = values[0]->get_function();
+ if (func->name == "len") {
+ out << "len(";
+ if (func->arguments[0]->is_field_reference()) {
+ out << get_field_reference_name(func->arguments[0]->get_field_reference());
+ }
+ out << ")";
+ }
+ out << ")";
+ } else {
+ out << values[0]->get_int();
+ }
+ out << "{" << endl;
+ }
+ indent_up();
+ out << indent()
+ << "return thrift.NewValidationException(thrift.VALIDATION_FAILED, \"vt.not_in\", \""
+ << context.field_symbol << "\", \"" << context.field_symbol
+ << " not valid, rule vt.not_in check failed\")" << endl;
+ indent_down();
+ out << indent() << "}" << endl;
+ if (values.size() > 1) {
+ indent_down();
+ out << indent() << "}" << endl;
+ }
+ }
+ }
+}
+
+void go_validator_generator::generate_string_field_validator(std::ostream& out,
+ const generator_context& context) {
+ std::string target = context.tgt;
+ t_type* type = context.type;
+ if (type->is_typedef()) {
+ type = type->get_true_type();
+ }
+ if (type->is_binary()) {
+ target = GenID("_tgt");
+ out << indent() << target << " := "
+ << "string(" << context.tgt << ")" << endl;
+ }
+ for (auto it = context.rules.begin(); it != context.rules.end(); it++) {
+ const std::vector<validation_value*>& values = (*it)->get_values();
+ if (values.size() == 0) {
+ continue;
+ }
+ std::string key = (*it)->get_name();
+
+ if (key == "vt.const") {
+ out << indent() << "if " << target << " != ";
+ if (values[0]->is_field_reference()) {
+ out << "string(";
+ out << get_field_reference_name(values[0]->get_field_reference());
+ out << ")";
+ } else {
+ out << "\"" << values[0]->get_string() << "\"";
+ }
+ } else if (key == "vt.min_size" || key == "vt.max_size") {
+ out << indent() << "if len(" << target << ") ";
+ if (key == "vt.min_size") {
+ out << "<";
+ } else {
+ out << ">";
+ }
+ out << " int(";
+ if (values[0]->is_field_reference()) {
+ out << get_field_reference_name(values[0]->get_field_reference());
+ } else if (values[0]->is_validation_function()) {
+ validation_value::validation_function* func = values[0]->get_function();
+ if (func->name == "len") {
+ out << "len(";
+ if (func->arguments[0]->is_field_reference()) {
+ out << "string(";
+ out << get_field_reference_name(values[0]->get_field_reference());
+ out << ")";
+ }
+ out << ")";
+ }
+ } else {
+ out << values[0]->get_int();
+ }
+ out << ")";
+ } else if (key == "vt.pattern") {
+ out << indent() << "if ok, _ := regexp.MatchString(" << target << ",";
+ if (values[0]->is_field_reference()) {
+ out << "string(";
+ out << get_field_reference_name(values[0]->get_field_reference());
+ out << ")";
+ } else {
+ out << "\"" << values[0]->get_string() << "\"";
+ }
+ out << "); ok ";
+ } else if (key == "vt.prefix") {
+ out << indent() << "if !strings.HasPrefix(" << target << ",";
+ if (values[0]->is_field_reference()) {
+ out << "string(";
+ out << get_field_reference_name(values[0]->get_field_reference());
+ out << ")";
+ } else {
+ out << "\"" << values[0]->get_string() << "\"";
+ }
+ out << ")";
+ } else if (key == "vt.suffix") {
+ out << indent() << "if !strings.HasSuffix(" << target << ",";
+ if (values[0]->is_field_reference()) {
+ out << "string(";
+ out << get_field_reference_name(values[0]->get_field_reference());
+ out << ")";
+ } else {
+ out << "\"" << values[0]->get_string() << "\"";
+ }
+ out << ")";
+ } else if (key == "vt.contains") {
+ out << indent() << "if !strings.Contains(" << target << ",";
+ if (values[0]->is_field_reference()) {
+ out << "string(";
+ out << get_field_reference_name(values[0]->get_field_reference());
+ out << ")";
+ } else {
+ out << "\"" << values[0]->get_string() << "\"";
+ }
+ out << ")";
+ } else if (key == "vt.not_contains") {
+ out << indent() << "if strings.Contains(" << target << ",";
+ if (values[0]->is_field_reference()) {
+ out << "string(";
+ out << get_field_reference_name(values[0]->get_field_reference());
+ out << ")";
+ } else {
+ out << "\"" << values[0]->get_string() << "\"";
+ }
+ out << ")";
+ }
+ out << "{" << endl;
+ indent_up();
+ out << indent()
+ << "return thrift.NewValidationException(thrift.VALIDATION_FAILED, \"" + key + "\", \""
+ << context.field_symbol << "\", \"" << context.field_symbol << " not valid, rule " << key
+ << " check failed\")" << endl;
+ indent_down();
+ out << indent() << "}" << endl;
+ }
+}
+
+void go_validator_generator::generate_set_field_validator(std::ostream& out,
+ const generator_context& context) {
+ return generate_list_field_validator(out, context);
+}
+
+void go_validator_generator::generate_list_field_validator(std::ostream& out,
+ const generator_context& context) {
+ for (auto it = context.rules.begin(); it != context.rules.end(); it++) {
+ const std::vector<validation_value*>& values = (*it)->get_values();
+ std::string key = (*it)->get_name();
+ if (key == "vt.min_size" || key == "vt.max_size") {
+ out << indent() << "if len(" << context.tgt << ")";
+ if (key == "vt.min_size") {
+ out << " < ";
+ } else {
+ out << " > ";
+ }
+ if (values[0]->is_field_reference()) {
+ out << "int(";
+ out << get_field_reference_name(values[0]->get_field_reference());
+ out << ")";
+ } else {
+ out << values[0]->get_int();
+ }
+ out << "{" << endl;
+ indent_up();
+ out << indent()
+ << "return thrift.NewValidationException(thrift.VALIDATION_FAILED, \"" + key + "\", \""
+ << context.field_symbol << "\", \"" << context.field_symbol << " not valid, rule " << key
+ << " check failed\")" << endl;
+ indent_down();
+ out << indent() << "}" << endl;
+ } else if (key == "vt.elem") {
+ out << indent() << "for i := 0; i < len(" << context.tgt << ");i++ {" << endl;
+ indent_up();
+ std::string src = GenID("_elem");
+ out << indent() << src << " := " << context.tgt << "[i]" << endl;
+ t_type* elem_type;
+ if (context.type->is_list()) {
+ elem_type = ((t_list*)context.type)->get_elem_type();
+ } else {
+ elem_type = ((t_set*)context.type)->get_elem_type();
+ }
+ generator_context ctx{context.field_symbol + ".elem",
+ "",
+ src,
+ false,
+ elem_type,
+ std::vector<validation_rule*>{(*it)->get_inner()}};
+ generate_field_validator(out, ctx);
+ indent_down();
+ out << indent() << "}" << endl;
+ }
+ }
+}
+
+void go_validator_generator::generate_map_field_validator(std::ostream& out,
+ const generator_context& context) {
+ for (auto it = context.rules.begin(); it != context.rules.end(); it++) {
+ const std::vector<validation_value*>& values = (*it)->get_values();
+ std::string key = (*it)->get_name();
+ if (key == "vt.min_size" || key == "vt.max_size") {
+ out << indent() << "if len(" << context.tgt << ")";
+ if (key == "vt.min_size") {
+ out << " < ";
+ } else {
+ out << " > ";
+ }
+ if (values[0]->is_field_reference()) {
+ out << "int(";
+ out << get_field_reference_name(values[0]->get_field_reference());
+ out << ")";
+ } else {
+ out << values[0]->get_int();
+ }
+ out << "{" << endl;
+ indent_up();
+ out << indent()
+ << "return thrift.NewValidationException(thrift.VALIDATION_FAILED, \"" + key + "\", \""
+ << context.field_symbol << "\", \"" << context.field_symbol << " not valid, rule " << key
+ << " check failed\")" << endl;
+ indent_down();
+ out << indent() << "}" << endl;
+ } else if (key == "vt.key") {
+ std::string src = GenID("_key");
+ out << indent() << "for " << src << " := range " << context.tgt << " {" << endl;
+ indent_up();
+ generator_context ctx{context.field_symbol + ".key",
+ "",
+ src,
+ false,
+ ((t_map*)context.type)->get_key_type(),
+ std::vector<validation_rule*>{(*it)->get_inner()}};
+ generate_field_validator(out, ctx);
+ indent_down();
+ out << indent() << "}" << endl;
+ } else if (key == "vt.value") {
+ std::string src = GenID("_value");
+ out << indent() << "for _, " << src << " := range " << context.tgt << " {" << endl;
+ indent_up();
+ generator_context ctx{context.field_symbol + ".value",
+ "",
+ src,
+ false,
+ ((t_map*)context.type)->get_val_type(),
+ std::vector<validation_rule*>{(*it)->get_inner()}};
+ generate_field_validator(out, ctx);
+ indent_down();
+ out << indent() << "}" << endl;
+ }
+ }
+}
+
+void go_validator_generator::generate_struct_field_validator(std::ostream& out,
+ const generator_context& context) {
+ bool generate_valid = true;
+ validation_rule* last_valid_rule = nullptr;
+ for (auto it = context.rules.begin(); it != context.rules.end(); it++) {
+ const std::vector<validation_value*>& values = (*it)->get_values();
+ if (values.size() == 0) {
+ continue;
+ }
+ std::string key = (*it)->get_name();
+
+ if (key == "vt.skip") {
+ if (values[0]->is_field_reference() || !values[0]->get_bool()) {
+ generate_valid = true;
+ } else if (values[0]->get_bool()) {
+ generate_valid = false;
+ }
+ last_valid_rule = *it;
+ }
+ }
+ if (generate_valid) {
+ if (last_valid_rule == nullptr) {
+ out << indent() << "if err := " << context.tgt << ".Validate(); err != nil {" << endl;
+ indent_up();
+ out << indent() << "return err" << endl;
+ indent_down();
+ out << indent() << "}" << endl;
+ } else {
+ const std::vector<validation_value*>& values = last_valid_rule->get_values();
+ if (!values[0]->get_bool()) {
+ out << indent() << "if err := " << context.tgt << ".Validate(); err != nil {" << endl;
+ indent_up();
+ out << indent() << "return err" << endl;
+ indent_down();
+ out << indent() << "}" << endl;
+ } else if (values[0]->is_field_reference()) {
+ out << indent() << "if !";
+ out << get_field_reference_name(values[0]->get_field_reference());
+ out << "{" << endl;
+ indent_up();
+ out << indent() << "if err := " << context.tgt << ".Validate(); err != nil {" << endl;
+ indent_up();
+ out << indent() << "return err" << endl;
+ indent_down();
+ out << indent() << "}" << endl;
+ indent_down();
+ out << indent() << "}" << endl;
+ }
+ }
+ }
+}
diff --git a/compiler/cpp/src/thrift/generate/go_validator_generator.h b/compiler/cpp/src/thrift/generate/go_validator_generator.h
new file mode 100644
index 000000000..ca36347cb
--- /dev/null
+++ b/compiler/cpp/src/thrift/generate/go_validator_generator.h
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef T_GO_VALIDATOR_GENERATOR_H
+#define T_GO_VALIDATOR_GENERATOR_H
+
+#include "thrift/generate/t_generator.h"
+#include "thrift/generate/t_go_generator.h"
+#include "thrift/generate/validator_parser.h"
+#include <fstream>
+#include <iostream>
+#include <limits>
+#include <string>
+#include <vector>
+
+class go_validator_generator {
+public:
+ go_validator_generator(t_go_generator* gg) : go_generator(gg){};
+ void generate_struct_validator(std::ostream& out, t_struct* tstruct);
+
+ struct generator_context {
+ std::string field_symbol;
+ std::string src;
+ std::string tgt;
+ bool opt;
+ t_type* type;
+ std::vector<validation_rule*> rules;
+ };
+
+private:
+ void generate_field_validator(std::ostream& out, const generator_context& context);
+ void generate_enum_field_validator(std::ostream& out, const generator_context& context);
+ void generate_bool_field_validator(std::ostream& out, const generator_context& context);
+ void generate_integer_field_validator(std::ostream& out, const generator_context& context);
+ void generate_double_field_validator(std::ostream& out, const generator_context& context);
+ void generate_string_field_validator(std::ostream& out, const generator_context& context);
+ void generate_list_field_validator(std::ostream& out, const generator_context& context);
+ void generate_set_field_validator(std::ostream& out, const generator_context& context);
+ void generate_map_field_validator(std::ostream& out, const generator_context& context);
+ void generate_struct_field_validator(std::ostream& out, const generator_context& context);
+
+ void indent_up() { go_generator->indent_up(); }
+ void indent_down() { go_generator->indent_down(); }
+ std::string indent() { return go_generator->indent(); }
+
+ std::string get_field_name(t_field* field);
+ std::string get_field_reference_name(t_field* field);
+
+ std::string GenID(std::string id) { return id + std::to_string(tmp_[id]++); };
+
+ t_go_generator* go_generator;
+
+ std::map<std::string, int> tmp_;
+};
+
+#endif
diff --git a/compiler/cpp/src/thrift/generate/t_cpp_generator.cc b/compiler/cpp/src/thrift/generate/t_cpp_generator.cc
index 0420e62e7..dd444bc9c 100644
--- a/compiler/cpp/src/thrift/generate/t_cpp_generator.cc
+++ b/compiler/cpp/src/thrift/generate/t_cpp_generator.cc
@@ -4398,9 +4398,9 @@ string t_cpp_generator::namespace_close(string ns) {
string t_cpp_generator::type_name(t_type* ttype, bool in_typedef, bool arg) {
if (ttype->is_base_type()) {
string bname = base_type_name(((t_base_type*)ttype)->get_base());
- std::map<string, string>::iterator it = ttype->annotations_.find("cpp.type");
- if (it != ttype->annotations_.end()) {
- bname = it->second;
+ std::map<string, std::vector<string>>::iterator it = ttype->annotations_.find("cpp.type");
+ if (it != ttype->annotations_.end() && !it->second.empty()) {
+ bname = it->second.back();
}
if (!arg) {
diff --git a/compiler/cpp/src/thrift/generate/t_delphi_generator.cc b/compiler/cpp/src/thrift/generate/t_delphi_generator.cc
index f35ffcb99..625179f8f 100644
--- a/compiler/cpp/src/thrift/generate/t_delphi_generator.cc
+++ b/compiler/cpp/src/thrift/generate/t_delphi_generator.cc
@@ -3329,11 +3329,11 @@ string t_delphi_generator::function_signature(t_function* tfunction,
// deprecated method? only at intf decl!
if( full_cls == "") {
auto iter = tfunction->annotations_.find("deprecated");
- if( tfunction->annotations_.end() != iter) {
+ if( tfunction->annotations_.end() != iter && !iter->second.empty()) {
signature += " deprecated";
// empty annotation values end up with "1" somewhere, ignore these as well
- if ((iter->second.length() > 0) && (iter->second != "1")) {
- signature += " " + make_pascal_string_literal(iter->second);
+ if ((iter->second.back().length() > 0) && (iter->second.back() != "1")) {
+ signature += " " + make_pascal_string_literal(iter->second.back());
}
signature += ";";
}
diff --git a/compiler/cpp/src/thrift/generate/t_go_generator.cc b/compiler/cpp/src/thrift/generate/t_go_generator.cc
index 90f34c8ca..e0ca489f5 100644
--- a/compiler/cpp/src/thrift/generate/t_go_generator.cc
+++ b/compiler/cpp/src/thrift/generate/t_go_generator.cc
@@ -41,6 +41,8 @@
#include "thrift/platform.h"
#include "thrift/version.h"
#include "thrift/generate/t_generator.h"
+#include "thrift/generate/t_go_generator.h"
+#include "thrift/generate/go_validator_generator.h"
using std::map;
using std::ostream;
@@ -49,8 +51,6 @@ using std::string;
using std::stringstream;
using std::vector;
-static const string endl = "\n"; // avoid ostream << std::endl flushes
-
/**
* A helper for automatically formatting the emitted Go code from the Thrift
* IDL per the Go style guide.
@@ -62,277 +62,6 @@ static const string endl = "\n"; // avoid ostream << std::endl flushes
*/
bool format_go_output(const string& file_path);
-const string DEFAULT_THRIFT_IMPORT = "github.com/apache/thrift/lib/go/thrift";
-static std::string package_flag;
-
-/**
- * Go code generator.
- */
-class t_go_generator : public t_generator {
-public:
- t_go_generator(t_program* program,
- const std::map<std::string, std::string>& parsed_options,
- const std::string& option_string)
- : t_generator(program) {
- (void)option_string;
- std::map<std::string, std::string>::const_iterator iter;
-
-
- gen_thrift_import_ = DEFAULT_THRIFT_IMPORT;
- gen_package_prefix_ = "";
- package_flag = "";
- read_write_private_ = false;
- ignore_initialisms_ = false;
- skip_remote_ = false;
- for( iter = parsed_options.begin(); iter != parsed_options.end(); ++iter) {
- if( iter->first.compare("package_prefix") == 0) {
- gen_package_prefix_ = (iter->second);
- } else if( iter->first.compare("thrift_import") == 0) {
- gen_thrift_import_ = (iter->second);
- } else if( iter->first.compare("package") == 0) {
- package_flag = (iter->second);
- } else if( iter->first.compare("read_write_private") == 0) {
- read_write_private_ = true;
- } else if( iter->first.compare("ignore_initialisms") == 0) {
- ignore_initialisms_ = true;
- } else if( iter->first.compare("skip_remote") == 0) {
- skip_remote_ = true;
- } else {
- throw "unknown option go:" + iter->first;
- }
- }
-
- out_dir_base_ = "gen-go";
- }
-
- /**
- * Init and close methods
- */
-
- void init_generator() override;
- void close_generator() override;
-
- /**
- * Program-level generation functions
- */
-
- void generate_typedef(t_typedef* ttypedef) override;
- void generate_enum(t_enum* tenum) override;
- void generate_const(t_const* tconst) override;
- void generate_struct(t_struct* tstruct) override;
- void generate_xception(t_struct* txception) override;
- void generate_service(t_service* tservice) override;
-
- std::string render_const_value(t_type* type, t_const_value* value, const string& name, bool opt = false);
-
- /**
- * Struct generation code
- */
-
- void generate_go_struct(t_struct* tstruct, bool is_exception);
- void generate_go_struct_definition(std::ostream& out,
- t_struct* tstruct,
- bool is_xception = false,
- bool is_result = false,
- bool is_args = false);
- void generate_go_struct_initializer(std::ostream& out,
- t_struct* tstruct,
- bool is_args_or_result = false);
- void generate_isset_helpers(std::ostream& out,
- t_struct* tstruct,
- const string& tstruct_name,
- bool is_result = false);
- void generate_countsetfields_helper(std::ostream& out,
- t_struct* tstruct,
- const string& tstruct_name,
- bool is_result = false);
- void generate_go_struct_reader(std::ostream& out,
- t_struct* tstruct,
- const string& tstruct_name,
- bool is_result = false);
- void generate_go_struct_writer(std::ostream& out,
- t_struct* tstruct,
- const string& tstruct_name,
- bool is_result = false,
- bool uses_countsetfields = false);
- void generate_go_struct_equals(std::ostream& out, t_struct* tstruct, const string& tstruct_name);
- void generate_go_function_helpers(t_function* tfunction);
- void get_publicized_name_and_def_value(t_field* tfield,
- string* OUT_pub_name,
- t_const_value** OUT_def_value) const;
-
- /**
- * Service-level generation functions
- */
-
- void generate_service_helpers(t_service* tservice);
- void generate_service_interface(t_service* tservice);
- void generate_service_client(t_service* tservice);
- void generate_service_remote(t_service* tservice);
- void generate_service_server(t_service* tservice);
- void generate_process_function(t_service* tservice, t_function* tfunction);
-
- /**
- * Serialization constructs
- */
-
- void generate_deserialize_field(std::ostream& out,
- t_field* tfield,
- bool declare,
- std::string prefix = "",
- bool inclass = false,
- bool coerceData = false,
- bool inkey = false,
- bool in_container = false);
-
- void generate_deserialize_struct(std::ostream& out,
- t_struct* tstruct,
- bool is_pointer_field,
- bool declare,
- std::string prefix = "");
-
- void generate_deserialize_container(std::ostream& out,
- t_type* ttype,
- bool pointer_field,
- bool declare,
- std::string prefix = "");
-
- void generate_deserialize_set_element(std::ostream& out,
- t_set* tset,
- bool declare,
- std::string prefix = "");
-
- void generate_deserialize_map_element(std::ostream& out,
- t_map* tmap,
- bool declare,
- std::string prefix = "");
-
- void generate_deserialize_list_element(std::ostream& out,
- t_list* tlist,
- bool declare,
- std::string prefix = "");
-
- void generate_serialize_field(std::ostream& out,
- t_field* tfield,
- std::string prefix = "",
- bool inkey = false);
-
- void generate_serialize_struct(std::ostream& out, t_struct* tstruct, std::string prefix = "");
-
- void generate_serialize_container(std::ostream& out,
- t_type* ttype,
- bool pointer_field,
- std::string prefix = "");
-
- void generate_serialize_map_element(std::ostream& out,
- t_map* tmap,
- std::string kiter,
- std::string viter);
-
- void generate_serialize_set_element(std::ostream& out, t_set* tmap, std::string iter);
-
- void generate_serialize_list_element(std::ostream& out, t_list* tlist, std::string iter);
-
- void generate_go_equals(std::ostream& out, t_type* ttype, string tgt, string src);
-
- void generate_go_equals_struct(std::ostream& out, t_type* ttype, string tgt, string src);
-
- void generate_go_equals_container(std::ostream& out, t_type* ttype, string tgt, string src);
-
- void generate_go_docstring(std::ostream& out, t_struct* tstruct);
-
- void generate_go_docstring(std::ostream& out, t_function* tfunction);
-
- void generate_go_docstring(std::ostream& out,
- t_doc* tdoc,
- t_struct* tstruct,
- const char* subheader);
-
- void generate_go_docstring(std::ostream& out, t_doc* tdoc);
-
- void parse_go_tags(map<string,string>* tags, const string in);
-
- /**
- * Helper rendering functions
- */
-
- std::string go_autogen_comment();
- std::string go_package();
- std::string go_imports_begin(bool consts);
- std::string go_imports_end();
- std::string render_includes(bool consts);
- std::string render_included_programs(string& unused_protection);
- std::string render_program_import(const t_program* program, string& unused_protection);
- std::string render_system_packages(std::vector<string> &system_packages);
- std::string render_import_protection();
- std::string render_fastbinary_includes();
- std::string declare_argument(t_field* tfield);
- std::string render_field_initial_value(t_field* tfield, const string& name, bool optional_field);
- std::string type_name(t_type* ttype);
- std::string module_name(t_type* ttype);
- std::string function_signature(t_function* tfunction, std::string prefix = "");
- std::string function_signature_if(t_function* tfunction,
- std::string prefix = "",
- bool addError = false);
- std::string argument_list(t_struct* tstruct);
- std::string type_to_enum(t_type* ttype);
- std::string type_to_go_type(t_type* ttype);
- std::string type_to_go_type_with_opt(t_type* ttype,
- bool optional_field);
- std::string type_to_go_key_type(t_type* ttype);
- std::string type_to_spec_args(t_type* ttype);
-
- static std::string get_real_go_module(const t_program* program) {
-
- if (!package_flag.empty()) {
- return package_flag;
- }
- std::string real_module = program->get_namespace("go");
- if (!real_module.empty()) {
- return real_module;
- }
-
- return lowercase(program->get_name());
- }
-
-private:
- std::string gen_package_prefix_;
- std::string gen_thrift_import_;
- bool read_write_private_;
- bool ignore_initialisms_;
- bool skip_remote_;
-
- /**
- * File streams
- */
-
- ofstream_with_content_based_conditional_update f_types_;
- std::string f_types_name_;
- ofstream_with_content_based_conditional_update f_consts_;
- std::string f_consts_name_;
- std::stringstream f_const_values_;
-
- std::string package_name_;
- std::string package_dir_;
- std::unordered_map<std::string, std::string> package_identifiers_;
- std::set<std::string> package_identifiers_set_;
- std::string read_method_name_;
- std::string write_method_name_;
- std::string equals_method_name_;
-
- std::set<std::string> commonInitialisms;
-
- std::string camelcase(const std::string& value) const;
- void fix_common_initialism(std::string& value, int i) const;
- std::string publicize(const std::string& value, bool is_args_or_result = false) const;
- std::string publicize(const std::string& value, bool is_args_or_result, const std::string& service_name) const;
- std::string privatize(const std::string& value) const;
- std::string new_prefix(const std::string& value) const;
- static std::string variable_name_to_go_name(const std::string& value);
- static bool is_pointer_field(t_field* tfield, bool in_container = false);
- static bool omit_initialization(t_field* tfield);
-};
-
// returns true if field initialization can be omitted since it has corresponding go type zero value
// or default value is not set
bool t_go_generator::omit_initialization(t_field* tfield) {
@@ -973,6 +702,10 @@ string t_go_generator::go_imports_begin(bool consts) {
system_packages.push_back("time");
// For the thrift import, always do rename import to make sure it's called thrift.
system_packages.push_back("thrift \"" + gen_thrift_import_ + "\"");
+
+ // validator import
+ system_packages.push_back("strings");
+ system_packages.push_back("regexp");
return "import (\n" + render_system_packages(system_packages);
}
@@ -983,7 +716,7 @@ string t_go_generator::go_imports_begin(bool consts) {
* This will have to do in lieu of more intelligent import statement construction
*/
string t_go_generator::go_imports_end() {
- return string(
+ string import_end = string(
")\n\n"
"// (needed to ensure safety because of naive import list construction.)\n"
"var _ = thrift.ZERO\n"
@@ -991,7 +724,11 @@ string t_go_generator::go_imports_end() {
"var _ = errors.New\n"
"var _ = context.Background\n"
"var _ = time.Now\n"
- "var _ = bytes.Equal\n\n");
+ "var _ = bytes.Equal\n"
+ "// (needed by validator.)\n"
+ "var _ = strings.Contains\n"
+ "var _ = regexp.MatchString\n\n");
+ return import_end;
}
/**
@@ -1384,6 +1121,14 @@ void t_go_generator::generate_xception(t_struct* txception) {
*/
void t_go_generator::generate_go_struct(t_struct* tstruct, bool is_exception) {
generate_go_struct_definition(f_types_, tstruct, is_exception);
+ // generate Validate function
+ std::string tstruct_name(publicize(tstruct->get_name(), false));
+ f_types_ << "func (p *" << tstruct_name << ") Validate() error {" << endl;
+ indent_up();
+ go_validator_generator(this).generate_struct_validator(f_types_, tstruct);
+ f_types_ << indent() << "return nil" << endl;
+ indent_down();
+ f_types_ << "}" << endl;
}
void t_go_generator::get_publicized_name_and_def_value(t_field* tfield,
@@ -1498,9 +1243,9 @@ void t_go_generator::generate_go_struct_definition(ostream& out,
// Check for user defined tags and them if there are any. User defined tags
// can override the above db and json tags.
- std::map<string, string>::iterator it = (*m_iter)->annotations_.find("go.tag");
+ std::map<string, std::vector<string>>::iterator it = (*m_iter)->annotations_.find("go.tag");
if (it != (*m_iter)->annotations_.end()) {
- parse_go_tags(&tags, it->second);
+ parse_go_tags(&tags, it->second.back());
}
string gotag;
diff --git a/compiler/cpp/src/thrift/generate/t_go_generator.h b/compiler/cpp/src/thrift/generate/t_go_generator.h
new file mode 100644
index 000000000..5080e1a0d
--- /dev/null
+++ b/compiler/cpp/src/thrift/generate/t_go_generator.h
@@ -0,0 +1,329 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef T_GO_GENERATOR_H
+#define T_GO_GENERATOR_H
+
+#include <fstream>
+#include <iostream>
+#include <limits>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "thrift/generate/t_generator.h"
+#include "thrift/platform.h"
+#include "thrift/version.h"
+#include <algorithm>
+#include <clocale>
+#include <sstream>
+#include <stdlib.h>
+#include <sys/stat.h>
+#include <sys/types.h>
+
+using std::map;
+using std::ostream;
+using std::ostringstream;
+using std::string;
+using std::stringstream;
+using std::vector;
+
+static const string endl = "\n"; // avoid ostream << std::endl flushes
+
+const string DEFAULT_THRIFT_IMPORT = "github.com/apache/thrift/lib/go/thrift";
+static std::string package_flag;
+
+/**
+ * Go code generator.
+ */
+class t_go_generator : public t_generator {
+public:
+ t_go_generator(t_program* program,
+ const std::map<std::string, std::string>& parsed_options,
+ const std::string& option_string)
+ : t_generator(program) {
+ (void)option_string;
+ std::map<std::string, std::string>::const_iterator iter;
+
+ gen_thrift_import_ = DEFAULT_THRIFT_IMPORT;
+ gen_package_prefix_ = "";
+ package_flag = "";
+ read_write_private_ = false;
+ ignore_initialisms_ = false;
+ skip_remote_ = false;
+ for (iter = parsed_options.begin(); iter != parsed_options.end(); ++iter) {
+ if (iter->first.compare("package_prefix") == 0) {
+ gen_package_prefix_ = (iter->second);
+ } else if (iter->first.compare("thrift_import") == 0) {
+ gen_thrift_import_ = (iter->second);
+ } else if (iter->first.compare("package") == 0) {
+ package_flag = (iter->second);
+ } else if (iter->first.compare("read_write_private") == 0) {
+ read_write_private_ = true;
+ } else if (iter->first.compare("ignore_initialisms") == 0) {
+ ignore_initialisms_ = true;
+ } else if( iter->first.compare("skip_remote") == 0) {
+ skip_remote_ = true;
+ } else {
+ throw "unknown option go:" + iter->first;
+ }
+ }
+
+ out_dir_base_ = "gen-go";
+ }
+
+ /**
+ * Init and close methods
+ */
+
+ void init_generator() override;
+ void close_generator() override;
+
+ /**
+ * Program-level generation functions
+ */
+
+ void generate_typedef(t_typedef* ttypedef) override;
+ void generate_enum(t_enum* tenum) override;
+ void generate_const(t_const* tconst) override;
+ void generate_struct(t_struct* tstruct) override;
+ void generate_xception(t_struct* txception) override;
+ void generate_service(t_service* tservice) override;
+
+ std::string render_const_value(t_type* type,
+ t_const_value* value,
+ const string& name,
+ bool opt = false);
+
+ /**
+ * Struct generation code
+ */
+
+ void generate_go_struct(t_struct* tstruct, bool is_exception);
+ void generate_go_struct_definition(std::ostream& out,
+ t_struct* tstruct,
+ bool is_xception = false,
+ bool is_result = false,
+ bool is_args = false);
+ void generate_go_struct_initializer(std::ostream& out,
+ t_struct* tstruct,
+ bool is_args_or_result = false);
+ void generate_isset_helpers(std::ostream& out,
+ t_struct* tstruct,
+ const string& tstruct_name,
+ bool is_result = false);
+ void generate_countsetfields_helper(std::ostream& out,
+ t_struct* tstruct,
+ const string& tstruct_name,
+ bool is_result = false);
+ void generate_go_struct_reader(std::ostream& out,
+ t_struct* tstruct,
+ const string& tstruct_name,
+ bool is_result = false);
+ void generate_go_struct_writer(std::ostream& out,
+ t_struct* tstruct,
+ const string& tstruct_name,
+ bool is_result = false,
+ bool uses_countsetfields = false);
+ void generate_go_struct_equals(std::ostream& out, t_struct* tstruct, const string& tstruct_name);
+ void generate_go_function_helpers(t_function* tfunction);
+ void get_publicized_name_and_def_value(t_field* tfield,
+ string* OUT_pub_name,
+ t_const_value** OUT_def_value) const;
+
+ /**
+ * Service-level generation functions
+ */
+
+ void generate_service_helpers(t_service* tservice);
+ void generate_service_interface(t_service* tservice);
+ void generate_service_client(t_service* tservice);
+ void generate_service_remote(t_service* tservice);
+ void generate_service_server(t_service* tservice);
+ void generate_process_function(t_service* tservice, t_function* tfunction);
+
+ /**
+ * Serialization constructs
+ */
+
+ void generate_deserialize_field(std::ostream& out,
+ t_field* tfield,
+ bool declare,
+ std::string prefix = "",
+ bool inclass = false,
+ bool coerceData = false,
+ bool inkey = false,
+ bool in_container = false);
+
+ void generate_deserialize_struct(std::ostream& out,
+ t_struct* tstruct,
+ bool is_pointer_field,
+ bool declare,
+ std::string prefix = "");
+
+ void generate_deserialize_container(std::ostream& out,
+ t_type* ttype,
+ bool pointer_field,
+ bool declare,
+ std::string prefix = "");
+
+ void generate_deserialize_set_element(std::ostream& out,
+ t_set* tset,
+ bool declare,
+ std::string prefix = "");
+
+ void generate_deserialize_map_element(std::ostream& out,
+ t_map* tmap,
+ bool declare,
+ std::string prefix = "");
+
+ void generate_deserialize_list_element(std::ostream& out,
+ t_list* tlist,
+ bool declare,
+ std::string prefix = "");
+
+ void generate_serialize_field(std::ostream& out,
+ t_field* tfield,
+ std::string prefix = "",
+ bool inkey = false);
+
+ void generate_serialize_struct(std::ostream& out, t_struct* tstruct, std::string prefix = "");
+
+ void generate_serialize_container(std::ostream& out,
+ t_type* ttype,
+ bool pointer_field,
+ std::string prefix = "");
+
+ void generate_serialize_map_element(std::ostream& out,
+ t_map* tmap,
+ std::string kiter,
+ std::string viter);
+
+ void generate_serialize_set_element(std::ostream& out, t_set* tmap, std::string iter);
+
+ void generate_serialize_list_element(std::ostream& out, t_list* tlist, std::string iter);
+
+ void generate_go_equals(std::ostream& out, t_type* ttype, string tgt, string src);
+
+ void generate_go_equals_struct(std::ostream& out, t_type* ttype, string tgt, string src);
+
+ void generate_go_equals_container(std::ostream& out, t_type* ttype, string tgt, string src);
+
+ void generate_go_docstring(std::ostream& out, t_struct* tstruct);
+
+ void generate_go_docstring(std::ostream& out, t_function* tfunction);
+
+ void generate_go_docstring(std::ostream& out,
+ t_doc* tdoc,
+ t_struct* tstruct,
+ const char* subheader);
+
+ void generate_go_docstring(std::ostream& out, t_doc* tdoc);
+
+ void parse_go_tags(map<string, string>* tags, const string in);
+
+ /**
+ * Helper rendering functions
+ */
+
+ std::string go_autogen_comment();
+ std::string go_package();
+ std::string go_imports_begin(bool consts);
+ std::string go_imports_end();
+ std::string render_includes(bool consts);
+ std::string render_included_programs(string& unused_protection);
+ std::string render_program_import(const t_program* program, string& unused_protection);
+ std::string render_system_packages(std::vector<string>& system_packages);
+ std::string render_import_protection();
+ std::string render_fastbinary_includes();
+ std::string declare_argument(t_field* tfield);
+ std::string render_field_initial_value(t_field* tfield, const string& name, bool optional_field);
+ std::string type_name(t_type* ttype);
+ std::string module_name(t_type* ttype);
+ std::string function_signature(t_function* tfunction, std::string prefix = "");
+ std::string function_signature_if(t_function* tfunction,
+ std::string prefix = "",
+ bool addError = false);
+ std::string argument_list(t_struct* tstruct);
+ std::string type_to_enum(t_type* ttype);
+ std::string type_to_go_type(t_type* ttype);
+ std::string type_to_go_type_with_opt(t_type* ttype, bool optional_field);
+ std::string type_to_go_key_type(t_type* ttype);
+ std::string type_to_spec_args(t_type* ttype);
+
+ void indent_up() { t_generator::indent_up(); }
+ void indent_down() { t_generator::indent_down(); }
+ std::string indent() { return t_generator::indent(); }
+ std::ostream& indent(std::ostream& os) { return t_generator::indent(os); }
+
+ static std::string get_real_go_module(const t_program* program) {
+
+ if (!package_flag.empty()) {
+ return package_flag;
+ }
+ std::string real_module = program->get_namespace("go");
+ if (!real_module.empty()) {
+ return real_module;
+ }
+
+ return lowercase(program->get_name());
+ }
+
+ static bool is_pointer_field(t_field* tfield, bool in_container = false);
+
+private:
+ std::string gen_package_prefix_;
+ std::string gen_thrift_import_;
+ bool read_write_private_;
+ bool ignore_initialisms_;
+ bool skip_remote_;
+
+ /**
+ * File streams
+ */
+
+ ofstream_with_content_based_conditional_update f_types_;
+ std::string f_types_name_;
+ ofstream_with_content_based_conditional_update f_consts_;
+ std::string f_consts_name_;
+ std::stringstream f_const_values_;
+
+ std::string package_name_;
+ std::string package_dir_;
+ std::unordered_map<std::string, std::string> package_identifiers_;
+ std::set<std::string> package_identifiers_set_;
+ std::string read_method_name_;
+ std::string write_method_name_;
+ std::string equals_method_name_;
+
+ std::set<std::string> commonInitialisms;
+
+ std::string camelcase(const std::string& value) const;
+ void fix_common_initialism(std::string& value, int i) const;
+ std::string publicize(const std::string& value, bool is_args_or_result = false) const;
+ std::string publicize(const std::string& value,
+ bool is_args_or_result,
+ const std::string& service_name) const;
+ std::string privatize(const std::string& value) const;
+ std::string new_prefix(const std::string& value) const;
+ static std::string variable_name_to_go_name(const std::string& value);
+ static bool omit_initialization(t_field* tfield);
+};
+
+#endif
diff --git a/compiler/cpp/src/thrift/generate/t_java_generator.cc b/compiler/cpp/src/thrift/generate/t_java_generator.cc
index 3ef9028d3..a6041d7c4 100644
--- a/compiler/cpp/src/thrift/generate/t_java_generator.cc
+++ b/compiler/cpp/src/thrift/generate/t_java_generator.cc
@@ -374,6 +374,10 @@ public:
|| ttype->is_uuid() || ttype->is_enum();
}
+ bool is_deprecated(const std::map<std::string, std::vector<std::string>>& annotations) {
+ return annotations.find("deprecated") != annotations.end();
+ }
+
bool is_deprecated(const std::map<std::string, std::string>& annotations) {
return annotations.find("deprecated") != annotations.end();
}
@@ -2966,7 +2970,7 @@ void t_java_generator::generate_metadata_for_field_annotations(std::ostream& out
indent_up();
for (auto& annotation : field->annotations_) {
indent(out) << ".add(new java.util.AbstractMap.SimpleImmutableEntry<>(\"" + annotation.first
- + "\", \"" + annotation.second + "\"))"
+ + "\", \"" + annotation.second.back() + "\"))"
<< endl;
}
indent(out) << ".build().collect(java.util.stream.Collectors.toMap(java.util.Map.Entry::getKey, "
diff --git a/compiler/cpp/src/thrift/generate/t_json_generator.cc b/compiler/cpp/src/thrift/generate/t_json_generator.cc
index e16bc5098..5a854cea6 100644
--- a/compiler/cpp/src/thrift/generate/t_json_generator.cc
+++ b/compiler/cpp/src/thrift/generate/t_json_generator.cc
@@ -268,7 +268,9 @@ void t_json_generator::write_type_spec(t_type* ttype) {
write_key_and("annotations");
start_object();
for (auto & annotation : ttype->annotations_) {
- write_key_and_string(annotation.first, annotation.second);
+ for (auto& annotation_value : annotation.second) {
+ write_key_and_string(annotation.first, annotation_value);
+ }
}
end_object();
}
@@ -459,7 +461,9 @@ void t_json_generator::generate_typedef(t_typedef* ttypedef) {
write_key_and("annotations");
start_object();
for (auto & annotation : ttypedef->annotations_) {
- write_key_and_string(annotation.first, annotation.second);
+ for (auto& annotation_value : annotation.second) {
+ write_key_and_string(annotation.first, annotation_value);
+ }
}
end_object();
}
@@ -566,7 +570,9 @@ void t_json_generator::generate_enum(t_enum* tenum) {
write_key_and("annotations");
start_object();
for (auto & annotation : tenum->annotations_) {
- write_key_and_string(annotation.first, annotation.second);
+ for (auto& annotation_value : annotation.second) {
+ write_key_and_string(annotation.first, annotation_value);
+ }
}
end_object();
}
@@ -605,7 +611,9 @@ void t_json_generator::generate_struct(t_struct* tstruct) {
write_key_and("annotations");
start_object();
for (auto & annotation : tstruct->annotations_) {
- write_key_and_string(annotation.first, annotation.second);
+ for (auto& annotation_value : annotation.second) {
+ write_key_and_string(annotation.first, annotation_value);
+ }
}
end_object();
}
@@ -645,7 +653,9 @@ void t_json_generator::generate_service(t_service* tservice) {
write_key_and("annotations");
start_object();
for (auto & annotation : tservice->annotations_) {
- write_key_and_string(annotation.first, annotation.second);
+ for (auto& annotation_value : annotation.second) {
+ write_key_and_string(annotation.first, annotation_value);
+ }
}
end_object();
}
@@ -682,7 +692,9 @@ void t_json_generator::generate_function(t_function* tfunc) {
write_key_and("annotations");
start_object();
for (auto & annotation : tfunc->annotations_) {
- write_key_and_string(annotation.first, annotation.second);
+ for (auto& annotation_value : annotation.second) {
+ write_key_and_string(annotation.first, annotation_value);
+ }
}
end_object();
}
@@ -728,7 +740,9 @@ void t_json_generator::generate_field(t_field* field) {
write_key_and("annotations");
start_object();
for (auto & annotation : field->annotations_) {
- write_key_and_string(annotation.first, annotation.second);
+ for (auto& annotation_value : annotation.second) {
+ write_key_and_string(annotation.first, annotation_value);
+ }
}
end_object();
}
diff --git a/compiler/cpp/src/thrift/generate/t_kotlin_generator.cc b/compiler/cpp/src/thrift/generate/t_kotlin_generator.cc
index 1ac9c34fb..29cf00a74 100644
--- a/compiler/cpp/src/thrift/generate/t_kotlin_generator.cc
+++ b/compiler/cpp/src/thrift/generate/t_kotlin_generator.cc
@@ -583,7 +583,7 @@ void t_kotlin_generator::generate_metadata_for_field_annotations(std::ostream& o
out << "mapOf(" << endl;
indent_up();
for (auto& annotation : field->annotations_) {
- indent(out) << "\"" + annotation.first + "\" to \"" + annotation.second + "\"," << endl;
+ indent(out) << "\"" + annotation.first + "\" to \"" + annotation.second.back() + "\"," << endl;
}
indent_down();
indent(out) << ")";
diff --git a/compiler/cpp/src/thrift/generate/t_netstd_generator.cc b/compiler/cpp/src/thrift/generate/t_netstd_generator.cc
index 547457130..4cf3db55b 100644
--- a/compiler/cpp/src/thrift/generate/t_netstd_generator.cc
+++ b/compiler/cpp/src/thrift/generate/t_netstd_generator.cc
@@ -2059,8 +2059,8 @@ void t_netstd_generator::generate_deprecation_attribute(ostream& out, t_function
if( func->annotations_.end() != iter) {
out << indent() << "[Obsolete";
// empty annotation values end up with "1" somewhere, ignore these as well
- if ((iter->second.length() > 0) && (iter->second != "1")) {
- out << "(" << make_csharp_string_literal(iter->second) << ")";
+ if ((iter->second.back().length() > 0) && (iter->second.back() != "1")) {
+ out << "(" << make_csharp_string_literal(iter->second.back()) << ")";
}
out << "]" << endl;
}
diff --git a/compiler/cpp/src/thrift/generate/t_py_generator.cc b/compiler/cpp/src/thrift/generate/t_py_generator.cc
index 33437fd00..8ab8b9881 100644
--- a/compiler/cpp/src/thrift/generate/t_py_generator.cc
+++ b/compiler/cpp/src/thrift/generate/t_py_generator.cc
@@ -280,12 +280,12 @@ public:
}
static bool is_immutable(t_type* ttype) {
- std::map<std::string, std::string>::iterator it = ttype->annotations_.find("python.immutable");
+ std::map<std::string, std::vector<std::string>>::iterator it = ttype->annotations_.find("python.immutable");
if (it == ttype->annotations_.end()) {
// Exceptions are immutable by default.
return ttype->is_xception();
- } else if (it->second == "false") {
+ } else if (!it->second.empty() && it->second.back() == "false") {
return false;
} else {
return true;
diff --git a/compiler/cpp/src/thrift/generate/t_xml_generator.cc b/compiler/cpp/src/thrift/generate/t_xml_generator.cc
index b6692938f..220d50c68 100644
--- a/compiler/cpp/src/thrift/generate/t_xml_generator.cc
+++ b/compiler/cpp/src/thrift/generate/t_xml_generator.cc
@@ -94,7 +94,7 @@ public:
void generate_service(t_service* tservice) override;
void generate_struct(t_struct* tstruct) override;
- void generate_annotations(std::map<std::string, std::string> annotations);
+ void generate_annotations(std::map<std::string, std::vector<std::string>> annotations);
private:
bool should_merge_includes_;
@@ -165,17 +165,23 @@ void t_xml_generator::init_generator() {
string t_xml_generator::target_namespace(t_program* program) {
std::map<std::string, std::string> map;
std::map<std::string, std::string>::iterator iter;
- map = program->get_namespace_annotations("xml");
- if ((iter = map.find("targetNamespace")) != map.end()) {
- return iter->second;
+ std::map<std::string, std::vector<std::string>> annotations;
+ std::map<std::string, std::vector<std::string>>::iterator annotations_iter;
+ annotations = program->get_namespace_annotations("xml");
+ if ((annotations_iter = annotations.find("targetNamespace")) != annotations.end()) {
+ if (!annotations_iter->second.empty()) {
+ return annotations_iter->second.back();
+ }
}
map = program->get_namespaces();
if ((iter = map.find("xml")) != map.end()) {
return default_ns_prefix + iter->second;
}
- map = program->get_namespace_annotations("*");
- if ((iter = map.find("xml.targetNamespace")) != map.end()) {
- return iter->second;
+ annotations = program->get_namespace_annotations("*");
+ if ((annotations_iter = annotations.find("xml.targetNamespace")) != annotations.end()) {
+ if (!annotations_iter->second.empty()) {
+ return annotations_iter->second.back();
+ }
}
map = program->get_namespaces();
if ((iter = map.find("*")) != map.end()) {
@@ -432,13 +438,15 @@ void t_xml_generator::write_doc(t_doc* tdoc) {
}
void t_xml_generator::generate_annotations(
- std::map<std::string, std::string> annotations) {
- std::map<std::string, std::string>::iterator iter;
+ std::map<std::string, std::vector<std::string>> annotations) {
+ std::map<std::string, std::vector<std::string>>::iterator iter;
for (iter = annotations.begin(); iter != annotations.end(); ++iter) {
- write_element_start("annotation");
- write_attribute("key", iter->first);
- write_attribute("value", iter->second);
- write_element_end();
+ for (auto& annotations_value : iter->second) {
+ write_element_start("annotation");
+ write_attribute("key", iter->first);
+ write_attribute("value", annotations_value);
+ write_element_end();
+ }
}
}
diff --git a/compiler/cpp/src/thrift/generate/t_xsd_generator.cc b/compiler/cpp/src/thrift/generate/t_xsd_generator.cc
index a10f05959..15ede750f 100644
--- a/compiler/cpp/src/thrift/generate/t_xsd_generator.cc
+++ b/compiler/cpp/src/thrift/generate/t_xsd_generator.cc
@@ -262,10 +262,10 @@ void t_xsd_generator::generate_service(t_service* tservice) {
f_xsd_.open(f_xsd_name.c_str());
string ns = program_->get_namespace("xsd");
- const std::map<std::string, std::string> annot = program_->get_namespace_annotations("xsd");
- const std::map<std::string, std::string>::const_iterator uri = annot.find("uri");
- if (uri != annot.end()) {
- ns = uri->second;
+ const std::map<std::string, std::vector<std::string>> annot = program_->get_namespace_annotations("xsd");
+ const std::map<std::string, std::vector<std::string>>::const_iterator uri = annot.find("uri");
+ if (uri != annot.end() && !uri->second.empty()) {
+ ns = uri->second.back();
}
if (ns.size() > 0) {
ns = " targetNamespace=\"" + ns + "\" xmlns=\"" + ns + "\" "
diff --git a/compiler/cpp/src/thrift/generate/validator_parser.cc b/compiler/cpp/src/thrift/generate/validator_parser.cc
new file mode 100644
index 000000000..84261fe6f
--- /dev/null
+++ b/compiler/cpp/src/thrift/generate/validator_parser.cc
@@ -0,0 +1,550 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*
+ * This file is programmatically sanitized for style:
+ * astyle --style=1tbs -f -p -H -j -U t_validator_parser.cc
+ *
+ * The output of astyle should not be taken unquestioningly, but it is a good
+ * guide for ensuring uniformity and readability.
+ */
+
+#include <fstream>
+#include <iostream>
+#include <limits>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "thrift/generate/t_generator.h"
+#include "thrift/generate/validator_parser.h"
+#include "thrift/platform.h"
+#include "thrift/version.h"
+#include <algorithm>
+#include <clocale>
+#include <sstream>
+#include <stdlib.h>
+#include <sys/stat.h>
+#include <sys/types.h>
+
+const char* list_delimiter = "[], ";
+
+std::vector<validation_rule*> validation_parser::parse_field(
+ t_type* type,
+ std::map<std::string, std::vector<std::string>>& annotations) {
+ if (type->is_typedef()) {
+ type = type->get_true_type();
+ }
+ if (type->is_enum()) {
+ return parse_enum_field(type, annotations);
+ } else if (type->is_base_type()) {
+ t_base_type::t_base tbase = ((t_base_type*)type)->get_base();
+ switch (tbase) {
+ case t_base_type::TYPE_UUID:
+ case t_base_type::TYPE_VOID:
+ break;
+ case t_base_type::TYPE_I8:
+ case t_base_type::TYPE_I16:
+ case t_base_type::TYPE_I32:
+ case t_base_type::TYPE_I64:
+ return parse_integer_field(type, annotations);
+ case t_base_type::TYPE_DOUBLE:
+ return parse_double_field(type, annotations);
+ case t_base_type::TYPE_STRING:
+ return parse_string_field(type, annotations);
+ case t_base_type::TYPE_BOOL:
+ return parse_bool_field(type, annotations);
+ }
+ } else if (type->is_list()) {
+ return parse_list_field(type, annotations);
+ } else if (type->is_set()) {
+ return parse_set_field(type, annotations);
+ } else if (type->is_map()) {
+ return parse_map_field(type, annotations);
+ } else if (type->is_struct()) {
+ if (((t_struct*)type)->is_union()) {
+ return parse_union_field(type, annotations);
+ }
+ return parse_struct_field(type, annotations);
+ } else if (type->is_xception()) {
+ return parse_xception_field(type, annotations);
+ }
+ throw "validator error: unsupported type: " + type->get_name();
+}
+
+std::vector<validation_rule*> validation_parser::parse_bool_field(
+ t_type* type,
+ std::map<std::string, std::vector<std::string>>& annotations) {
+ (void)type;
+ std::vector<validation_rule*> rules;
+ add_bool_rule(rules, "vt.const", annotations);
+ return rules;
+}
+
+std::vector<validation_rule*> validation_parser::parse_enum_field(
+ t_type* type,
+ std::map<std::string, std::vector<std::string>>& annotations) {
+ std::vector<validation_rule*> rules;
+ add_bool_rule(rules, "vt.defined_only", annotations);
+ add_enum_list_rule(rules, (t_enum*)type, "vt.in", annotations);
+ add_enum_list_rule(rules, (t_enum*)type, "vt.not_in", annotations);
+ return rules;
+}
+
+std::vector<validation_rule*> validation_parser::parse_double_field(
+ t_type* type,
+ std::map<std::string, std::vector<std::string>>& annotations) {
+ (void)type;
+ std::vector<validation_rule*> rules;
+ add_double_rule(rules, "vt.lt", annotations);
+ add_double_rule(rules, "vt.le", annotations);
+ add_double_rule(rules, "vt.gt", annotations);
+ add_double_rule(rules, "vt.ge", annotations);
+ add_double_list_rule(rules, "vt.in", annotations);
+ add_double_list_rule(rules, "vt.not_in", annotations);
+ return rules;
+}
+
+std::vector<validation_rule*> validation_parser::parse_integer_field(
+ t_type* type,
+ std::map<std::string, std::vector<std::string>>& annotations) {
+ (void)type;
+ std::vector<validation_rule*> rules;
+ add_integer_rule(rules, "vt.lt", annotations);
+ add_integer_rule(rules, "vt.le", annotations);
+ add_integer_rule(rules, "vt.gt", annotations);
+ add_integer_rule(rules, "vt.ge", annotations);
+ add_integer_list_rule(rules, "vt.in", annotations);
+ add_integer_list_rule(rules, "vt.not_in", annotations);
+ return rules;
+}
+
+std::vector<validation_rule*> validation_parser::parse_string_field(
+ t_type* type,
+ std::map<std::string, std::vector<std::string>>& annotations) {
+ (void)type;
+ std::vector<validation_rule*> rules;
+ add_string_rule(rules, "vt.const", annotations);
+ add_integer_rule(rules, "vt.min_size", annotations);
+ add_integer_rule(rules, "vt.max_size", annotations);
+ add_string_rule(rules, "vt.pattern", annotations);
+ add_string_rule(rules, "vt.prefix", annotations);
+ add_string_rule(rules, "vt.suffix", annotations);
+ add_string_rule(rules, "vt.contains", annotations);
+ add_string_rule(rules, "vt.not_contains", annotations);
+ return rules;
+}
+
+std::vector<validation_rule*> validation_parser::parse_set_field(
+ t_type* type,
+ std::map<std::string, std::vector<std::string>>& annotations) {
+ (void)type;
+ return parse_list_field(type, annotations);
+}
+
+std::vector<validation_rule*> validation_parser::parse_list_field(
+ t_type* type,
+ std::map<std::string, std::vector<std::string>>& annotations) {
+ (void)type;
+ std::vector<validation_rule*> rules;
+ add_integer_rule(rules, "vt.min_size", annotations);
+ add_integer_rule(rules, "vt.max_size", annotations);
+ std::string elem_prefix("vt.elem");
+ std::map<std::string, std::vector<std::string>> elem_annotations;
+ for (auto it = annotations.begin(); it != annotations.end(); it++) {
+ if (it->first.compare(0, elem_prefix.size(), elem_prefix) == 0) {
+ std::string elem_key = "vt" + it->first.substr(elem_prefix.size());
+ elem_annotations[elem_key] = it->second;
+ }
+ }
+ std::vector<validation_rule*> elem_rules;
+ if (type->is_list()) {
+ elem_rules = parse_field(((t_list*)type)->get_elem_type(), elem_annotations);
+ } else if (type->is_set()) {
+ elem_rules = parse_field(((t_set*)type)->get_elem_type(), elem_annotations);
+ }
+ for (auto it = elem_rules.begin(); it != elem_rules.end(); it++) {
+ validation_rule* rule = new validation_rule(elem_prefix, *it);
+ rules.push_back(rule);
+ }
+ return rules;
+}
+
+std::vector<validation_rule*> validation_parser::parse_map_field(
+ t_type* type,
+ std::map<std::string, std::vector<std::string>>& annotations) {
+ std::vector<validation_rule*> rules;
+ add_integer_rule(rules, "vt.min_size", annotations);
+ add_integer_rule(rules, "vt.max_size", annotations);
+ std::string key_prefix("vt.key");
+ std::map<std::string, std::vector<std::string>> key_annotations;
+ for (auto it = annotations.begin(); it != annotations.end(); it++) {
+ if (it->first.compare(0, key_prefix.size(), key_prefix) == 0) {
+ std::string key_key = "vt" + it->first.substr(key_prefix.size());
+ key_annotations[key_key] = it->second;
+ }
+ }
+ std::vector<validation_rule*> key_rules;
+ key_rules = parse_field(((t_map*)type)->get_key_type(), key_annotations);
+ for (auto it = key_rules.begin(); it != key_rules.end(); it++) {
+ validation_rule* rule = new validation_rule(key_prefix, *it);
+ rules.push_back(rule);
+ }
+
+ std::string value_prefix("vt.value");
+ std::map<std::string, std::vector<std::string>> value_annotations;
+ for (auto it = annotations.begin(); it != annotations.end(); it++) {
+ if (it->first.compare(0, value_prefix.size(), value_prefix) == 0) {
+ std::string value_key = "vt" + it->first.substr(value_prefix.size());
+ value_annotations[value_key] = it->second;
+ }
+ }
+ std::vector<validation_rule*> value_rules;
+ value_rules = parse_field(((t_map*)type)->get_val_type(), value_annotations);
+ for (auto it = value_rules.begin(); it != value_rules.end(); it++) {
+ validation_rule* rule = new validation_rule(value_prefix, *it);
+ rules.push_back(rule);
+ }
+ return rules;
+}
+
+std::vector<validation_rule*> validation_parser::parse_struct_field(
+ t_type* type,
+ std::map<std::string, std::vector<std::string>>& annotations) {
+ (void)type;
+ std::vector<validation_rule*> rules;
+ add_bool_rule(rules, "vt.skip", annotations);
+ return rules;
+}
+
+std::vector<validation_rule*> validation_parser::parse_xception_field(
+ t_type* type,
+ std::map<std::string, std::vector<std::string>>& annotations) {
+ return parse_struct_field(type, annotations);
+}
+
+std::vector<validation_rule*> validation_parser::parse_union_field(
+ t_type* type,
+ std::map<std::string, std::vector<std::string>>& annotations) {
+ return parse_struct_field(type, annotations);
+}
+
+bool validation_parser::is_reference_field(std::string value) {
+ if (value[0] != '$') {
+ return false;
+ }
+ value.erase(value.begin());
+ t_field* field = this->reference->get_field_by_name(value);
+ return field != nullptr;
+}
+
+bool validation_parser::is_validation_function(std::string value) {
+ if (value[0] != '@') {
+ return false;
+ }
+ return true;
+}
+
+void validation_parser::add_bool_rule(
+ std::vector<validation_rule*>& rules,
+ std::string key,
+ std::map<std::string, std::vector<std::string>>& annotations) {
+ auto it = annotations.find(key);
+ if (it != annotations.end() && !it->second.empty()) {
+ for (auto& annotation_value : it->second) {
+ validation_rule* rule = new validation_rule(key);
+ validation_value* value;
+ if (is_reference_field(annotation_value)) {
+ t_field* field = get_referenced_field(annotation_value);
+ value = new validation_value(field);
+ } else {
+ bool constant;
+ std::istringstream(it->second.back()) >> std::boolalpha >> constant;
+ value = new validation_value(constant);
+ }
+ rule->append_value(value);
+ rules.push_back(rule);
+ }
+ }
+}
+
+void validation_parser::add_double_rule(
+ std::vector<validation_rule*>& rules,
+ std::string key,
+ std::map<std::string, std::vector<std::string>>& annotations) {
+ auto it = annotations.find(key);
+ if (it != annotations.end() && !it->second.empty()) {
+ for (auto& annotation_value : it->second) {
+ if (annotation_value.size() == 0) {
+ continue;
+ }
+ validation_rule* rule = new validation_rule(key);
+ validation_value* value;
+ if (is_validation_function(annotation_value)) {
+ validation_value::validation_function* function = get_validation_function(annotation_value);
+ value = new validation_value(function);
+ } else if (is_reference_field(annotation_value)) {
+ t_field* field = get_referenced_field(annotation_value);
+ value = new validation_value(field);
+ } else {
+ double constant = std::stod(annotation_value);
+ value = new validation_value(constant);
+ }
+ rule->append_value(value);
+ rules.push_back(rule);
+ }
+ }
+}
+
+void validation_parser::add_enum_list_rule(
+ std::vector<validation_rule*>& rules,
+ t_enum* enum_,
+ std::string key,
+ std::map<std::string, std::vector<std::string>>& annotations) {
+ auto it = annotations.find(key);
+ if (it != annotations.end() && !it->second.empty()) {
+ for (auto& annotation_value : it->second) {
+ if (annotation_value.size() == 0) {
+ continue;
+ }
+ validation_rule* rule = new validation_rule(key);
+ if (annotation_value[0] == '[') {
+ validation_value* value;
+ char* str = strdup(annotation_value.c_str());
+ char* pch = strtok(str, list_delimiter);
+ std::string val;
+ while (pch != NULL) {
+ std::string temp(pch);
+ if (is_validation_function(temp)) {
+ validation_value::validation_function* function = get_validation_function(temp);
+ value = new validation_value(function);
+ } else if (is_reference_field(temp)) {
+ t_field* field = get_referenced_field(temp);
+ value = new validation_value(field);
+ } else if (std::stringstream(temp) >> val) {
+ std::string::size_type dot = val.rfind('.');
+ if (dot != std::string::npos) {
+ val = val.substr(dot + 1);
+ }
+ t_enum_value* enum_val = enum_->get_constant_by_name(val);
+ value = new validation_value(enum_val);
+ } else {
+ delete rule;
+ throw "validator error: validation double list parse failed: " + temp;
+ }
+ rule->append_value(value);
+ pch = strtok(NULL, list_delimiter);
+ }
+ } else {
+ validation_value* value;
+ std::string val = annotation_value;
+ std::string::size_type dot = val.rfind('.');
+ if (dot != std::string::npos) {
+ val = val.substr(dot + 1);
+ }
+ t_enum_value* enum_val = enum_->get_constant_by_name(val);
+ value = new validation_value(enum_val);
+ rule->append_value(value);
+ }
+ rules.push_back(rule);
+ }
+ }
+}
+
+void validation_parser::add_double_list_rule(
+ std::vector<validation_rule*>& rules,
+ std::string key,
+ std::map<std::string, std::vector<std::string>>& annotations) {
+ std::map<std::string, std::vector<std::string>> double_rules;
+ auto it = annotations.find(key);
+ if (it != annotations.end() && !it->second.empty()) {
+ for (auto& annotation_value : it->second) {
+ if (annotation_value.size() == 0) {
+ continue;
+ }
+ if (annotation_value[0] == '[') {
+ validation_rule* rule = new validation_rule(key);
+ validation_value* value;
+ char* str = strdup(annotation_value.c_str());
+ char* pch = strtok(str, list_delimiter);
+ double val;
+ while (pch != NULL) {
+ std::string temp(pch);
+ if (is_validation_function(temp)) {
+ validation_value::validation_function* function = get_validation_function(temp);
+ value = new validation_value(function);
+ } else if (is_reference_field(temp)) {
+ t_field* field = get_referenced_field(temp);
+ value = new validation_value(field);
+ } else if (std::stringstream(temp) >> val) {
+ value = new validation_value(val);
+ } else {
+ delete rule;
+ throw "validator error: validation double list parse failed: " + temp;
+ }
+ rule->append_value(value);
+ pch = strtok(NULL, list_delimiter);
+ }
+ rules.push_back(rule);
+ } else {
+ double_rules[key].push_back(annotation_value);
+ }
+ }
+ }
+ if (double_rules[key].size() > 0) {
+ add_double_rule(rules, key, double_rules);
+ }
+}
+
+void validation_parser::add_integer_rule(
+ std::vector<validation_rule*>& rules,
+ std::string key,
+ std::map<std::string, std::vector<std::string>>& annotations) {
+ auto it = annotations.find(key);
+ if (it != annotations.end() && !it->second.empty()) {
+ for (auto& annotation_value : it->second) {
+ if (annotation_value.size() == 0) {
+ continue;
+ }
+ validation_rule* rule = new validation_rule(key);
+ validation_value* value;
+ if (is_reference_field(annotation_value)) {
+ t_field* field = get_referenced_field(annotation_value);
+ value = new validation_value(field);
+ } else if (is_validation_function(annotation_value)) {
+ validation_value::validation_function* function = get_validation_function(annotation_value);
+ value = new validation_value(function);
+ } else {
+ int64_t constant = std::stoll(annotation_value);
+ value = new validation_value(constant);
+ }
+ rule->append_value(value);
+ rules.push_back(rule);
+ }
+ }
+}
+
+void validation_parser::add_integer_list_rule(
+ std::vector<validation_rule*>& rules,
+ std::string key,
+ std::map<std::string, std::vector<std::string>>& annotations) {
+ std::map<std::string, std::vector<std::string>> integer_rules;
+ auto it = annotations.find(key);
+ if (it != annotations.end() && !it->second.empty()) {
+ for (auto& annotation_value : it->second) {
+ if (annotation_value.size() == 0) {
+ continue;
+ }
+ if (annotation_value[0] == '[') {
+ validation_rule* rule = new validation_rule(key);
+ validation_value* value;
+ char* str = strdup(annotation_value.c_str());
+ char* pch = strtok(str, list_delimiter);
+ int64_t val;
+ while (pch != NULL) {
+ std::string temp(pch);
+ if (is_validation_function(temp)) {
+ validation_value::validation_function* function = get_validation_function(temp);
+ value = new validation_value(function);
+ } else if (is_reference_field(temp)) {
+ t_field* field = get_referenced_field(temp);
+ value = new validation_value(field);
+ } else if (std::stringstream(temp) >> val) {
+ value = new validation_value(val);
+ } else {
+ delete rule;
+ throw "validator error: validation integer list parse failed: " + temp;
+ }
+ rule->append_value(value);
+ pch = strtok(NULL, list_delimiter);
+ }
+ rules.push_back(rule);
+ } else {
+ integer_rules[key].push_back(annotation_value);
+ }
+ }
+ }
+ if (integer_rules[key].size() > 0) {
+ add_integer_rule(rules, key, integer_rules);
+ }
+}
+
+void validation_parser::add_string_rule(
+ std::vector<validation_rule*>& rules,
+ std::string key,
+ std::map<std::string, std::vector<std::string>>& annotations) {
+ auto it = annotations.find(key);
+ if (it != annotations.end() && !it->second.empty()) {
+ for (auto& annotation_value : it->second) {
+ validation_rule* rule = new validation_rule(key);
+ validation_value* value;
+ if (is_reference_field(annotation_value)) {
+ t_field* field = get_referenced_field(annotation_value);
+ value = new validation_value(field);
+ } else {
+ value = new validation_value(annotation_value);
+ }
+ rule->append_value(value);
+ rules.push_back(rule);
+ }
+ }
+}
+
+t_field* validation_parser::get_referenced_field(std::string annotation_value) {
+ annotation_value.erase(annotation_value.begin());
+ return reference->get_field_by_name(annotation_value);
+}
+
+validation_value::validation_function* validation_parser::get_validation_function(
+ std::string annotation_value) {
+ std::string value = annotation_value;
+ value.erase(value.begin());
+ validation_value::validation_function* function = new validation_value::validation_function;
+
+ size_t name_end = value.find_first_of('(');
+ if (name_end >= value.size()) {
+ delete function;
+ throw "validator error: validation function parse failed: " + annotation_value;
+ }
+ function->name = value.substr(0, name_end);
+ value.erase(0, name_end + 1); // name(
+
+ if (function->name == "len") {
+ size_t argument_end = value.find_first_of(')');
+ if (argument_end >= value.size()) {
+ delete function;
+ throw "validator error: validation function parse failed: " + annotation_value;
+ }
+ std::string argument = value.substr(0, argument_end);
+ if (argument.size() > 0 && argument[0] == '$') {
+ t_field* field = get_referenced_field(argument);
+ validation_value* value = new validation_value(field);
+ function->arguments.push_back(value);
+ } else {
+ delete function;
+ throw "validator error: validation function parse failed, unrecognized argument: "
+ + annotation_value;
+ }
+ } else {
+ delete function;
+ throw "validator error: validation function parse failed, function not supported: "
+ + annotation_value;
+ }
+ return function;
+} \ No newline at end of file
diff --git a/compiler/cpp/src/thrift/generate/validator_parser.h b/compiler/cpp/src/thrift/generate/validator_parser.h
new file mode 100644
index 000000000..076af2ed6
--- /dev/null
+++ b/compiler/cpp/src/thrift/generate/validator_parser.h
@@ -0,0 +1,208 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef T_VALIDATOR_GENERATOR_H
+#define T_VALIDATOR_GENERATOR_H
+
+#include "thrift/generate/t_generator.h"
+#include <fstream>
+#include <iostream>
+#include <limits>
+#include <string>
+#include <vector>
+
+class validation_value {
+public:
+ struct validation_function {
+ public:
+ std::string name;
+ std::vector<validation_value*> arguments;
+ };
+
+ enum validation_value_type {
+ VV_INTEGER,
+ VV_DOUBLE,
+ VV_BOOL,
+ VV_ENUM,
+ VV_STRING,
+ VV_FUNCTION,
+ VV_FIELD_REFERENCE,
+ VV_UNKNOWN
+ };
+
+ validation_value() : val_type(VV_UNKNOWN) {}
+ validation_value(const int64_t val) : int_val(val), val_type(VV_INTEGER) {}
+ validation_value(const double val) : double_val(val), val_type(VV_DOUBLE) {}
+ validation_value(const bool val) : bool_val(val), val_type(VV_BOOL) {}
+ validation_value(t_enum_value* val) : enum_val(val), val_type(VV_ENUM) {}
+ validation_value(const std::string val) : string_val(val), val_type(VV_STRING) {}
+ validation_value(validation_function* val) : function_val(val), val_type(VV_FUNCTION) {}
+ validation_value(t_field* val) : field_reference_val(val), val_type(VV_FIELD_REFERENCE) {}
+
+ void set_int(const int64_t val) {
+ int_val = val;
+ val_type = VV_INTEGER;
+ }
+ int64_t get_int() const { return int_val; };
+
+ void set_double(const double val) {
+ double_val = val;
+ val_type = VV_DOUBLE;
+ }
+ double get_double() { return double_val; };
+
+ void set_bool(const bool val) {
+ bool_val = val;
+ val_type = VV_BOOL;
+ }
+ bool get_bool() const { return bool_val; };
+
+ void set_enum(t_enum_value* val) {
+ enum_val = val;
+ val_type = VV_ENUM;
+ }
+ t_enum_value* get_enum() const { return enum_val; };
+
+ void set_string(const std::string val) {
+ string_val = val;
+ val_type = VV_STRING;
+ }
+ std::string get_string() const { return string_val; };
+
+ void set_function(validation_function* val) {
+ function_val = val;
+ val_type = VV_FUNCTION;
+ }
+
+ validation_function* get_function() { return function_val; };
+
+ void set_field_reference(t_field* val) {
+ field_reference_val = val;
+ val_type = VV_FIELD_REFERENCE;
+ }
+ t_field* get_field_reference() const { return field_reference_val; };
+
+ bool is_field_reference() const { return val_type == VV_FIELD_REFERENCE; };
+
+ bool is_validation_function() const { return val_type == VV_FUNCTION; };
+
+ validation_value_type get_type() const { return val_type; };
+
+private:
+ int64_t int_val = 0;
+ double double_val = 0.0;
+ bool bool_val = false;
+ t_enum_value* enum_val = nullptr;
+ std::string string_val;
+ validation_function* function_val;
+ t_field* field_reference_val;
+
+ validation_value_type val_type;
+};
+
+class validation_rule {
+public:
+ validation_rule(){};
+ validation_rule(std::string name) : name(name){};
+ validation_rule(std::string name, validation_rule* inner) : name(name), inner(inner){};
+
+ std::string get_name() { return name; };
+ void append_value(validation_value* value) { values.push_back(value); }
+ const std::vector<validation_value*>& get_values() { return values; };
+ validation_rule* get_inner() { return inner; };
+
+private:
+ std::string name;
+ std::vector<validation_value*> values;
+ validation_rule* inner;
+};
+
+class validation_parser {
+public:
+ validation_parser() {}
+ validation_parser(t_struct* reference) : reference(reference) {}
+ std::vector<validation_rule*> parse_field(
+ t_type* type,
+ std::map<std::string, std::vector<std::string>>& annotations);
+ void set_reference(t_struct* reference) { this->reference = reference; };
+
+private:
+ std::vector<validation_rule*> parse_bool_field(
+ t_type* type,
+ std::map<std::string, std::vector<std::string>>& annotations);
+ std::vector<validation_rule*> parse_enum_field(
+ t_type* type,
+ std::map<std::string, std::vector<std::string>>& annotations);
+ std::vector<validation_rule*> parse_double_field(
+ t_type* type,
+ std::map<std::string, std::vector<std::string>>& annotations);
+ std::vector<validation_rule*> parse_integer_field(
+ t_type* type,
+ std::map<std::string, std::vector<std::string>>& annotations);
+ std::vector<validation_rule*> parse_string_field(
+ t_type* type,
+ std::map<std::string, std::vector<std::string>>& annotations);
+ std::vector<validation_rule*> parse_set_field(
+ t_type* type,
+ std::map<std::string, std::vector<std::string>>& annotations);
+ std::vector<validation_rule*> parse_list_field(
+ t_type* type,
+ std::map<std::string, std::vector<std::string>>& annotations);
+ std::vector<validation_rule*> parse_map_field(
+ t_type* type,
+ std::map<std::string, std::vector<std::string>>& annotations);
+ std::vector<validation_rule*> parse_struct_field(
+ t_type* type,
+ std::map<std::string, std::vector<std::string>>& annotations);
+ std::vector<validation_rule*> parse_xception_field(
+ t_type* type,
+ std::map<std::string, std::vector<std::string>>& annotations);
+ std::vector<validation_rule*> parse_union_field(
+ t_type* type,
+ std::map<std::string, std::vector<std::string>>& annotations);
+ bool is_reference_field(std::string value);
+ bool is_validation_function(std::string value);
+ void add_bool_rule(std::vector<validation_rule*>& rules,
+ std::string key,
+ std::map<std::string, std::vector<std::string>>& annotations);
+ void add_double_rule(std::vector<validation_rule*>& rules,
+ std::string key,
+ std::map<std::string, std::vector<std::string>>& annotations);
+ void add_double_list_rule(std::vector<validation_rule*>& rules,
+ std::string key,
+ std::map<std::string, std::vector<std::string>>& annotations);
+ void add_integer_rule(std::vector<validation_rule*>& rules,
+ std::string key,
+ std::map<std::string, std::vector<std::string>>& annotations);
+ void add_integer_list_rule(std::vector<validation_rule*>& rules,
+ std::string key,
+ std::map<std::string, std::vector<std::string>>& annotations);
+ void add_string_rule(std::vector<validation_rule*>& rules,
+ std::string key,
+ std::map<std::string, std::vector<std::string>>& annotations);
+ void add_enum_list_rule(std::vector<validation_rule*>& rules,
+ t_enum* enum_,
+ std::string key,
+ std::map<std::string, std::vector<std::string>>& annotations);
+ t_field* get_referenced_field(std::string annotation_value);
+ validation_value::validation_function* get_validation_function(std::string annotation_value);
+ t_struct* reference;
+};
+
+#endif
diff --git a/compiler/cpp/src/thrift/parse/t_enum_value.h b/compiler/cpp/src/thrift/parse/t_enum_value.h
index 70eee8618..c6558d781 100644
--- a/compiler/cpp/src/thrift/parse/t_enum_value.h
+++ b/compiler/cpp/src/thrift/parse/t_enum_value.h
@@ -20,9 +20,9 @@
#ifndef T_ENUM_VALUE_H
#define T_ENUM_VALUE_H
+#include "thrift/parse/t_doc.h"
#include <map>
#include <string>
-#include "thrift/parse/t_doc.h"
/**
* A constant. These are used inside of enum definitions. Constants are just
@@ -40,7 +40,7 @@ public:
int get_value() const { return value_; }
- std::map<std::string, std::string> annotations_;
+ std::map<std::string, std::vector<std::string>> annotations_;
private:
std::string name_;
diff --git a/compiler/cpp/src/thrift/parse/t_field.h b/compiler/cpp/src/thrift/parse/t_field.h
index f0a607de0..928fdcf93 100644
--- a/compiler/cpp/src/thrift/parse/t_field.h
+++ b/compiler/cpp/src/thrift/parse/t_field.h
@@ -21,8 +21,8 @@
#define T_FIELD_H
#include <map>
-#include <string>
#include <sstream>
+#include <string>
#include "thrift/parse/t_doc.h"
#include "thrift/parse/t_type.h"
@@ -106,7 +106,7 @@ public:
}
};
- std::map<std::string, std::string> annotations_;
+ std::map<std::string, std::vector<std::string>> annotations_;
bool get_reference() const { return reference_; }
diff --git a/compiler/cpp/src/thrift/parse/t_function.h b/compiler/cpp/src/thrift/parse/t_function.h
index d30c8a46e..bc0ae465b 100644
--- a/compiler/cpp/src/thrift/parse/t_function.h
+++ b/compiler/cpp/src/thrift/parse/t_function.h
@@ -20,10 +20,10 @@
#ifndef T_FUNCTION_H
#define T_FUNCTION_H
-#include <string>
-#include "thrift/parse/t_type.h"
-#include "thrift/parse/t_struct.h"
#include "thrift/parse/t_doc.h"
+#include "thrift/parse/t_struct.h"
+#include "thrift/parse/t_type.h"
+#include <string>
/**
* Representation of a function. Key parts are return type, function name,
@@ -79,7 +79,7 @@ public:
bool is_oneway() const { return oneway_; }
- std::map<std::string, std::string> annotations_;
+ std::map<std::string, std::vector<std::string>> annotations_;
private:
t_type* returntype_;
diff --git a/compiler/cpp/src/thrift/parse/t_program.h b/compiler/cpp/src/thrift/parse/t_program.h
index b6b1332c0..23c6463a2 100644
--- a/compiler/cpp/src/thrift/parse/t_program.h
+++ b/compiler/cpp/src/thrift/parse/t_program.h
@@ -331,20 +331,20 @@ public:
return namespaces_;
}
- void set_namespace_annotations(std::string language, std::map<std::string, std::string> annotations) {
+ void set_namespace_annotations(std::string language, std::map<std::string, std::vector<std::string>> annotations) {
namespace_annotations_[language] = annotations;
}
- const std::map<std::string, std::string>& get_namespace_annotations(const std::string& language) const {
+ const std::map<std::string, std::vector<std::string>>& get_namespace_annotations(const std::string& language) const {
auto it = namespace_annotations_.find(language);
if (namespace_annotations_.end() != it) {
return it->second;
}
- static const std::map<std::string, std::string> emptyMap;
+ static const std::map<std::string, std::vector<std::string>> emptyMap;
return emptyMap;
}
- std::map<std::string, std::string>& get_namespace_annotations(const std::string& language) {
+ std::map<std::string, std::vector<std::string>>& get_namespace_annotations(const std::string& language) {
return namespace_annotations_[language];
}
@@ -400,7 +400,7 @@ private:
std::map<std::string, std::string> namespaces_;
// Annotations for dynamic namespaces
- std::map<std::string, std::map<std::string, std::string> > namespace_annotations_;
+ std::map<std::string, std::map<std::string, std::vector<std::string>>> namespace_annotations_;
// C++ extra includes
std::vector<std::string> cpp_includes_;
diff --git a/compiler/cpp/src/thrift/parse/t_type.h b/compiler/cpp/src/thrift/parse/t_type.h
index 8dbeb9efc..f4082426c 100644
--- a/compiler/cpp/src/thrift/parse/t_type.h
+++ b/compiler/cpp/src/thrift/parse/t_type.h
@@ -20,11 +20,11 @@
#ifndef T_TYPE_H
#define T_TYPE_H
-#include <string>
-#include <map>
+#include "thrift/parse/t_doc.h"
#include <cstring>
+#include <map>
#include <stdint.h>
-#include "thrift/parse/t_doc.h"
+#include <string>
class t_program;
@@ -83,7 +83,7 @@ public:
return rv;
}
- std::map<std::string, std::string> annotations_;
+ std::map<std::string, std::vector<std::string>> annotations_;
protected:
t_type() : program_(nullptr) { ; }
diff --git a/compiler/cpp/src/thrift/thrifty.yy b/compiler/cpp/src/thrift/thrifty.yy
index 82b2be57e..bb2c19e44 100644
--- a/compiler/cpp/src/thrift/thrifty.yy
+++ b/compiler/cpp/src/thrift/thrifty.yy
@@ -1301,7 +1301,7 @@ TypeAnnotationList:
{
pdebug("TypeAnnotationList -> TypeAnnotationList , TypeAnnotation");
$$ = $1;
- $$->annotations_[$2->key] = $2->val;
+ $$->annotations_[$2->key].push_back($2->val);
delete $2;
}
|
diff --git a/compiler/cpp/tests/CMakeLists.txt b/compiler/cpp/tests/CMakeLists.txt
index 0e8254158..b8b2777a9 100644
--- a/compiler/cpp/tests/CMakeLists.txt
+++ b/compiler/cpp/tests/CMakeLists.txt
@@ -76,6 +76,8 @@ set(thrift_compiler_SOURCES
${THRIFT_COMPILER_SOURCE_DIR}/src/thrift/audit/t_audit.cpp
${THRIFT_COMPILER_SOURCE_DIR}/src/thrift/common.cc
${THRIFT_COMPILER_SOURCE_DIR}/src/thrift/generate/t_generator.cc
+ ${THRIFT_COMPILER_SOURCE_DIR}/src/thrift/generate/validator_parser.cc
+ ${THRIFT_COMPILER_SOURCE_DIR}/src/thrift/generate/validator_parser.h
${THRIFT_COMPILER_SOURCE_DIR}/src/thrift/parse/t_typedef.cc
${THRIFT_COMPILER_SOURCE_DIR}/src/thrift/parse/parse.cc
${THRIFT_COMPILER_SOURCE_DIR}/thrift/version.h
@@ -96,6 +98,18 @@ macro(THRIFT_ADD_COMPILER name description initial)
endif()
endmacro()
+# This macro adds an option THRIFT_VALIDATOR_COMPILER_${NAME}
+# that allows enabling or disabling certain languages
+macro(THRIFT_ADD_VALIDATOR_COMPILER name description initial)
+ string(TOUPPER "THRIFT_VALIDATOR_COMPILER_${name}" enabler)
+ set(src "${THRIFT_COMPILER_SOURCE_DIR}/src/thrift/generate/${name}_validator_generator.cc")
+ list(APPEND "${THRIFT_COMPILER_SOURCE_DIR}/src/thrift/generate/${name}_validator_generator.h")
+ option(${enabler} ${description} ${initial})
+ if(${enabler})
+ list(APPEND thrift-compiler_SOURCES ${src})
+ endif()
+endmacro()
+
# The following compiler with unit tests can be enabled or disabled
THRIFT_ADD_COMPILER(c_glib "Enable compiler for C with Glib" OFF)
THRIFT_ADD_COMPILER(cl "Enable compiler for Common LISP" OFF)
@@ -125,6 +139,9 @@ THRIFT_ADD_COMPILER(swift "Enable compiler for Swift" OFF)
THRIFT_ADD_COMPILER(xml "Enable compiler for XML" OFF)
THRIFT_ADD_COMPILER(xsd "Enable compiler for XSD" OFF)
+# The following compiler can be enabled or disabled by enabling or disabling certain languages
+THRIFT_ADD_VALIDATOR_COMPILER(go "Enable validator compiler for Go" ON)
+
# Thrift is looking for include files in the src directory
# we also add the current binary directory for generated files
include_directories(${CMAKE_CURRENT_BINARY_DIR} ${THRIFT_COMPILER_SOURCE_DIR}/src ${CMAKE_CURRENT_SOURCE_DIR}/catch)
diff --git a/doc/specs/thrift-parameter-validation-proposal.md b/doc/specs/thrift-parameter-validation-proposal.md
new file mode 100644
index 000000000..42f308404
--- /dev/null
+++ b/doc/specs/thrift-parameter-validation-proposal.md
@@ -0,0 +1,195 @@
+# Thrift Parameter Validation Proposal
+
+> Version 1.1
+>
+> Dec 15, 2021
+>
+> duanyi.aster@bytedance.com, wangtieju@bytedance.com
+
+### 1. Abstract
+***
+This document presents a proposed set of annotations to the Thrift IDL. The new annotations will supports parameter validation using build-in or third-party validators. The goal of this proposal is to define semantics and behavior of validation annotations, rather than to discuss their implementation.
+
+### 2. Background
+***
+Parameter validation is a common need for web service. In the past, we usually write our validating logics after a RPC message deserialized by thrift. This ways works flexibly enough but restrict poorly: It is dangerous that service A and service B using the same IDL have two different validating rule, which often misdirects developers. Even if we extract our validating codes to a single module, simple and repeated work (ex. `if xx.Field1 > 1 then ...`) is really disturbing. If we can use build tool to generating codes for simple and unchangeable restraint, the web service will be more robust and developers will benefits from lighter work.
+Compared to other IDL, the parameter validation gradually gets strong community supports like PGV ([protoc-gen-validate](https://github.com/envoyproxy/protoc-gen-validate)), benefiting from pb's strong plugin mechanism (lacking official plugin mechanism is one reason for we submit this proposal). Take a long-term view, auto-generated parameter validation may be a step towards code-less web service.
+
+### 3. Proposal
+***
+This proposal includes three part: Validate Annotation Semantics, Validate Rule and Validate Feedback. The first declare how to write a validate annotation, the middle explain how every annotation should behave, the last introduces a mechanism of validating feedback.
+
+#### 3.1 Validate Annotation Semantics
+This semantics uses same rule of [Thrift IDL](https://thrift.apache.org/docs/idl). The validate option only works on struct fields, thus we must start from Field semantics part.
+- Field
+```peg
+Field <- FieldID? FieldReq? FieldType Identifier ('=' ConstValue)? ValidateAnnotations? ListSeparator?
+```
+- ValidateAnnotations
+```peg
+ValidateAnnotations <- '(' ValidateRule+ ListSeparator? ')'
+```
+- ValidateRule
+```peg
+ValidateRule <- ('validate' | 'vt') Validator+ = '"' ValidatingValue? '"'
+```
+- Validator
+
+ Build-in validating logics. See [Supported Validator](#321-supported-validator) part.
+```peg
+Validator <- '.' Identifier
+```
+- ValidatingValue
+```peg
+ValidatingValue <- (ToolFunction '(' )? Arguments ')'?
+```
+- ToolFunction
+
+ Build-in or user-defined tool functions. See [Tool Function](#325-tool-function) part.
+```peg
+ToolFunction <- '@' Identifier
+```
+- Arguments
+```peg
+Arguments <- (DynamicValue ListSeparator?)*
+```
+- DynamicValue
+```peg
+DynamicValue <- ConstValue | FieldReference
+```
+- FieldReference
+
+ See [Field Reference](#324-field-reference) part.
+```apache
+FieldReference <- '$' ReferPath
+ReferPath <- FieldName? ( ('['IntConstant']') | ('.'Identifier) )?
+```
+- All other semantics keep same with [standard definition](https://thrift.apache.org/docs/idl)
+
+### 3.2 Validate Rule
+The validate rule is works as a Boolean Expression, and Validator is core logic for one validate rule. Every Validator works like an Operator, calculating the Validating Value and Field Value, and then compare. For example, `gt` (greater than) will compare the right Validating Value with value of the field it belongs to, and return `true` if field value is greater than value or `false` if field value is not. We appoint that: Only if the validate rule returns true, the validated parameter is valid. If there are several validate rules defined in annotations of a field, Validator will take the logical relation as "and". Simply put, commas in annotations can be treated as "and".
+
+
+#### 3.2.1 Supported Validator
+Below lists the support validators. Value type means the type of validating value, field type means type of validated field.
+
+| validator | behavior | value type | field type | secondary validator |
+| ------------ | ------------------------------------- | ------------------------------------ | ---------------------- | ------------------- |
+| const | must be constant | string, bool | same with value | - |
+| defined_only | must be defined value | enum | enum | - |
+| not_nil | must not be empty | "true" | any | - |
+| skip | skip validate | "true" | any | - |
+| eq | equals to (`==`) | i8, i16, i32, i64, f64, string, bool | same with value | - |
+| ne | not equals to (`!=`) | i8, i16, i32, i64, f64, string, bool | same with value | - |
+| lt | less than (`<`) | i8, i16, i32, i64, f64 | same with value | - |
+| le | less equal (`<=`) | i8, i16, i32, i64, f64 | same with value | - |
+| gt | greater than (`>`) | i8, i16, i32, i64, f64 | same with value | - |
+| ge | greater equal (`>=`) | i8, i16, i32, i64, f64 | same with value | - |
+| in | within given container | i8, i16, i32, i64, f64, enum | same with value | - |
+| not_in | not within given container | i8, i16, i32, i64, f64, enum | same with value | - |
+| elem | field's element constraint | any | list, set | support |
+| key | field's element key constraint | any | map | support |
+| value | field's element value constraint | any | map | support |
+| min_size | minimal length | i8, i16, i32, i64 | string, list, set, map | - |
+| max_size | maximal length | i8, i16, i32, i64 | string, list, set, map | - |
+| prefix | field prefix must be (case-sensitive) | string | string | - |
+| suffix | suffix must be (case-sensitive) | string | string | - |
+| contains | must contain (case-sensitive) | string | string | - |
+| not_contains | must not contain (case-sensitive) | string | string | - |
+| pattern | basic regular expression | string | string | - |
+
+- Basic Regular Expression (BRE), the syntax of BRE can be found in [manual](https://www.gnu.org/software/sed/manual/html_node/BRE-syntax.html) of GNU sed.
+- Secondary validator (`elem`, `key` and `value`) is a successive validator, usually used at container-type field. See below Set/List/Map examples.
+- Add suffix "_escape" to validators to prevent value of rule conflicting with tool function. For example, you can use `"vt.eq_escape" = "@len(A)"` to match literal `@len(A)`.
+
+#### 3.2.2 IDL example
+- Number
+```
+struct NumericDemo{
+ 1: double Value (validator.ge = "1000.1", validator.le = "10000.1")
+ 2: i8 Type (validator.in = "[1, 2, 4]")
+}
+```
+- String/Binary
+```
+struct StringDemo{
+ 1: string Uninitialized (vt.const = "abc")
+ 2: string Name (vt.min_size = "6", vt.max_size = "12")
+ 3: string SomeStuffs (vt.pattern = "[0-9A-Za-z]+")
+ 4: string DebugInfo (vt.prefix = "[Debug]")
+ 5: string ErrorMessage (vt.contains = "Error")
+}
+```
+- Bool
+```
+struct BoolDemo {
+ 1: bool AMD (vt.const = "true")
+}
+```
+- Enum
+```
+enum Type {
+ Bool
+ I8
+ I16
+ I32
+ I64
+ String
+ Struct
+ List
+ Set
+ Map
+}
+
+struct EnumDemo {
+ 1: Type AddressType (vt.in = "[String]")
+ 2: Type ValueType (vt.defined_only = "true")
+}
+```
+- Set/List
+```
+struct SetListDemo {
+ 1: list<string> Persons (vt.min_size = "5", vt.max_size = "10")
+ 2: set<double> HealthPoints (vt.elem.gt = "0")
+}
+```
+- Map
+```
+struct MapDemo {
+ 1: map<i32, string> IdName (vt.min_size = "5", vt.max_size = "10")
+ 2: map<i32, double> Some (vt.key.gt = "0", vt.value.lt = "1000")
+}
+```
+
+#### 3.2.3 Arguments
+Arguments can by static literals or dynamic variables. If one literal expression contains any Field Reference or Tool Function, it becomes dynamic variables. Every dynamic variables finally get calculated and finally become a Thrift Constant Value.
+
+#### 3.2.4 Field Reference
+Field Reference is used to refer to another field's value in Validating Value, therefore user can compare more than one field. The referenced field must be within same struct. Identifier must be the field name referred.
+- Field Reference Rule
+1. `$x` represents a variable named x, and its scope is within current struct
+2. `$` indicates the current field in which the validator is located
+3. `$x['k']` indicates a reference to the key k of variable x (which must be map)
+4. `$x[i]` indicates a reference to the i + 1 element of variable x (which must be list)
+- Example
+```
+struct FieldReferenceExample {
+ 1: string A (vt.eq = "$B") //field A must equal to field B
+ 2: list<string> C
+}
+```
+
+#### 3.2.5 Tool Function
+Tool Function is use to enhance the operating of Validating Value. For example, if we want to ensure one field is larger than the length of string field A, we can use `len()` function: `vt.gt = "@len($A)"`. The arguments can be either literals or variables, and no size limit. However, we won't suggest any build-in function here, because the category is too big and always language-related. Instead, we only propose one mechanism for thrift developers to extends their implementation according to used language.
+
+Supported functions:
+| function | behavior | arguments | results | supported language |
+| -------- | ----------------------- | ----------------------------------- | -------- | ------------------ |
+| len | the length of the field | 1: string, binary, list, set or map | 1. int64 | go |
+
+### 3.3 Feedback
+The generated validating codes should be included in struct's `Validate() TApplicationException` method. If all validate rule declared by one struct get passed, the struct's `Validate() TApplicationException` method returns nil (or just returns without exception, depending on specific language implementation); Otherwise it returns `TApplicationException` and report feedback message indicating failure reason. Due to language function implementations are different, we won't constrain the interface of feedback messages. However, by practice we suggest developers to give below three detail information:
+
+- The position where first validating failure happens.
+- The validator who reports the failure.
+- The red-handed field value and validating value when the failure happens
diff --git a/lib/go/test/Makefile.am b/lib/go/test/Makefile.am
index 992a84357..b9c00d969 100644
--- a/lib/go/test/Makefile.am
+++ b/lib/go/test/Makefile.am
@@ -64,7 +64,8 @@ gopath: $(THRIFT) $(THRIFTTEST) \
ConstOptionalFieldImport.thrift \
ConstOptionalField.thrift \
ProcessorMiddlewareTest.thrift \
- ClientMiddlewareExceptionTest.thrift
+ ClientMiddlewareExceptionTest.thrift \
+ ValidateTest.thrift
mkdir -p gopath/src
grep -v list.*map.*list.*map $(THRIFTTEST) | grep -v 'set<Insanity>' > ThriftTest.thrift
$(THRIFT) $(THRIFTARGS) -r IncludesTest.thrift
@@ -98,6 +99,7 @@ gopath: $(THRIFT) $(THRIFTTEST) \
$(THRIFT) $(THRIFTARGS) -r ConstOptionalField.thrift
$(THRIFT) $(THRIFTARGS_SKIP_REMOTE) ProcessorMiddlewareTest.thrift
$(THRIFT) $(THRIFTARGS) ClientMiddlewareExceptionTest.thrift
+ $(THRIFT) $(THRIFTARGS) ValidateTest.thrift
ln -nfs ../../tests gopath/src/tests
cp -r ./dontexportrwtest gopath/src
touch gopath
@@ -122,7 +124,8 @@ check: gopath
./gopath/src/equalstest \
./gopath/src/conflictargnamestest \
./gopath/src/processormiddlewaretest \
- ./gopath/src/clientmiddlewareexceptiontest
+ ./gopath/src/clientmiddlewareexceptiontest \
+ ./gopath/src/validatetest
$(GO) test -mod=mod github.com/apache/thrift/lib/go/thrift
$(GO) test -mod=mod ./gopath/src/tests ./gopath/src/dontexportrwtest
@@ -168,4 +171,5 @@ EXTRA_DIST = \
ServicesTest.thrift \
TypedefFieldTest.thrift \
UnionBinaryTest.thrift \
- UnionDefaultValueTest.thrift
+ UnionDefaultValueTest.thrift \
+ ValidateTest.thrift
diff --git a/lib/go/test/ValidateTest.thrift b/lib/go/test/ValidateTest.thrift
new file mode 100644
index 000000000..c02bfa8cb
--- /dev/null
+++ b/lib/go/test/ValidateTest.thrift
@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+namespace go validatetest
+
+enum EnumFoo {
+ e1
+ e2
+}
+
+struct Foo {
+ 1: bool Bool
+}
+
+struct BasicTest {
+ 1: bool Bool0 = true (vt.const = "true")
+ 2: optional bool Bool1 (vt.const = "true")
+ 3: i8 Byte0 = 1 (vt.lt = "2", vt.le = "2", vt.gt = "0", vt.ge = "0", vt.in = "[0, 1, 2]", vt.not_in = "[3, 4, 5]")
+ 4: optional i8 Byte1 (vt.lt = "1", vt.le = "1", vt.gt = "-1", vt.ge = "-1", vt.in = "[-1, 0, 1]", vt.not_in = "[1, 2, 3]")
+ 5: double Double0 = 1.0 (vt.lt = "2.0", vt.le = "2.0", vt.gt = "0", vt.ge = "0", vt.in = "[0, 1.0, 2.0]", vt.not_in = "[3.0, 4.0, 5.0]")
+ 6: optional double Double1 (vt.lt = "2.0", vt.le = "2.0", vt.gt = "0", vt.ge = "0", vt.in = "[0, 1.0, 2.0]", vt.not_in = "[3.0, 4.0, 5.0]")
+ 7: string String0 = "my const string" (vt.const = "my const string", vt.min_size = "0", vt.max_size = "100", vt.pattern = ".*", vt.prefix = "my", vt.suffix = "string", vt.contains = "const", vt.not_contains = "oh")
+ 8: optional string String1 (vt.const = "my const string", vt.min_size = "0", vt.max_size = "100", vt.pattern = ".*", vt.prefix = "my", vt.suffix = "string", vt.contains = "const", vt.not_contains = "oh")
+ 9: binary Binary0 = "my const string" (vt.const = "my const string", vt.min_size = "0", vt.max_size = "100", vt.pattern = ".*", vt.prefix = "my", vt.suffix = "string", vt.contains = "const", vt.not_contains = "oh")
+ 10: optional binary Binary1 = "my const string" (vt.const = "my const string", vt.min_size = "0", vt.max_size = "100", vt.pattern = ".*", vt.prefix = "my", vt.suffix = "string", vt.contains = "const", vt.not_contains = "oh")
+ 11: map<string, string> Map0 (vt.min_size = "0", vt.max_size = "10", vt.key.min_size = "0", vt.key.max_size = "10", vt.value.min_size = "0", vt.value.max_size = "10")
+ 12: optional map<string, string> Map1 (vt.min_size = "0", vt.max_size = "10", vt.key.min_size = "0", vt.key.max_size = "10", vt.value.min_size = "0", vt.value.max_size = "10")
+ 13: set<string> Set0 (vt.min_size = "0", vt.max_size = "10", vt.elem.min_size = "5")
+ 14: optional set<string> Set1 (vt.min_size = "0", vt.max_size = "10", vt.elem.min_size = "5")
+ 15: EnumFoo Enum0 = EnumFoo.e2 (vt.in = "[EnumFoo.e2]", vt.defined_only = "true")
+ 16: optional EnumFoo Enum1 (vt.in = "[EnumFoo.e1]", vt.defined_only = "true")
+ 17: Foo Struct0 (vt.skip = "true")
+ 18: optional Foo Struct1 (vt.skip = "true")
+ 19: i8 Byte2 = 1 (vt.in = "1", vt.not_in = "2")
+ 20: double Double2 = 3.0 (vt.in = "3.0", vt.not_in = "4.0")
+ 21: EnumFoo Enum2 = EnumFoo.e2 (vt.in = "EnumFoo.e2", vt.not_in = "EnumFoo.e1")
+}
+
+struct FieldReferenceTest {
+ 1: bool Bool0 (vt.const = "$Bool2")
+ 2: optional bool Bool1 (vt.const = "$Bool2")
+ 3: i8 Byte0 = 10 (vt.lt = "$Byte4", vt.le = "$Byte4", vt.gt = "$Byte2", vt.ge = "$Byte2", vt.in = "[$Byte2, $Byte3, $Byte4]", vt.not_in = "[$Byte2, $Byte4]")
+ 4: optional i8 Byte1 (vt.lt = "$Byte4", vt.le = "$Byte4", vt.gt = "$Byte2", vt.ge = "$Byte2", vt.in = "[$Byte2, $Byte3, $Byte4]", vt.not_in = "[$Byte2, $Byte4]")
+ 5: double Double0 = 10.0 (vt.lt = "$Double4", vt.le = "$Double4", vt.gt = "$Double2", vt.ge = "$Double2", vt.in = "[$Double2, $Double3, $Double4]", vt.not_in = "[$Double2, $Double4]")
+ 6: optional double Double1 (vt.lt = "$Double4", vt.le = "$Double4", vt.gt = "$Double2", vt.ge = "$Double2", vt.in = "[$Double2, $Double3, $Double4]", vt.not_in = "[$Double2, $Double4]")
+ 7: string String0 = "my string" (vt.const = "$String2", vt.min_size = "$Byte2", vt.max_size = "$Byte3", vt.pattern = "$String4", vt.prefix = "$String2", vt.suffix = "$String2", vt.contains = "$String2", vt.not_contains = "$String3")
+ 8: optional string String1 (vt.const = "$String2", vt.min_size = "$Byte2", vt.max_size = "$Byte3", vt.pattern = "$String4", vt.prefix = "$String2", vt.suffix = "$String2", vt.contains = "$String2", vt.not_contains = "$String3")
+ 9: binary Binary0 = "my binary" (vt.const = "$Binary2", vt.min_size = "$Byte2", vt.max_size = "$Byte3", vt.pattern = "$Binary4", vt.prefix = "$Binary2", vt.suffix = "$Binary2", vt.contains = "$Binary2", vt.not_contains = "$Binary3")
+ 10: optional binary Binary1 = "my binary" (vt.const = "$Binary2", vt.min_size = "$Byte2", vt.max_size = "$Byte3", vt.pattern = "$Binary4", vt.prefix = "$Binary2", vt.suffix = "$Binary2", vt.contains = "$Binary2", vt.not_contains = "$Binary3")
+ 11: map<string, string> Map0 (vt.min_size = "$Byte2", vt.max_size = "$MaxSize", vt.key.min_size = "$Byte2", vt.key.max_size = "$MaxSize", vt.value.min_size = "$Byte2", vt.value.max_size = "$MaxSize")
+ 12: optional map<string, string> Map1 (vt.min_size = "$Byte2", vt.max_size = "$MaxSize", vt.key.min_size = "$Byte2", vt.key.max_size = "$MaxSize", vt.value.min_size = "$Byte2", vt.value.max_size = "$MaxSize")
+ 13: list<string> List0 (vt.min_size = "$Byte2", vt.max_size = "$MaxSize", vt.elem.min_size = "$Byte2", vt.elem.max_size = "$MaxSize")
+ 14: optional list<string> List1 (vt.min_size = "$Byte2", vt.max_size = "$MaxSize", vt.elem.min_size = "$Byte2", vt.elem.max_size = "$MaxSize")
+ 15: set<string> Set0 (vt.min_size = "$Byte2", vt.max_size = "$MaxSize", vt.elem.min_size = "$Byte2", vt.elem.max_size = "$MaxSize")
+ 16: optional set<string> Set1 (vt.min_size = "$Byte2", vt.max_size = "$MaxSize", vt.elem.min_size = "$Byte2", vt.elem.max_size = "$MaxSize")
+ 17: bool Bool2 = false
+ 18: i8 Byte2 = 0
+ 19: i8 Byte3 = 10
+ 20: i8 Byte4 = 20
+ 21: double Double2 = 0
+ 22: double Double3 = 10.0
+ 23: double Double4 = 20.0
+ 24: string String2 = "my string"
+ 25: string String3 = "other string"
+ 26: string String4 = ".*"
+ 27: binary Binary2 = "my binary"
+ 28: binary Binary3 = "other binary"
+ 29: binary Binary4 = ".*"
+ 30: i64 MaxSize = 10
+}
+
+struct ValidationFunctionTest {
+ 1: string StringFoo
+ 2: i64 StringLength (vt.in = "[@len($StringFoo)]")
+}
+
+struct AnnotationCompatibleTest {
+ 1: bool Bool0 = true (vt.const = "true", go.tag = 'json:"bool1"')
+ 2: i8 Byte0 = 1 (vt.lt = "2", go.tag = 'json:"byte1"')
+ 3: double Double0 = 1.0 (vt.lt = "2.0", go.tag = 'json:"double1"')
+ 4: string String0 = "my const string" (vt.const = "my const string", go.tag = 'json:"string1"')
+ 5: binary Binary0 = "my const string" (vt.const = "my const string", go.tag = 'json:"binary1"')
+ 6: map<string, string> Map0 (vt.max_size = "2", go.tag = 'json:"map1"')
+ 7: set<string> Set0 (vt.max_size = "2", go.tag = 'json:"set1"')
+ 8: list<string> List0 (vt.max_size = "2", go.tag = 'json:"list1"')
+ 9: EnumFoo Enum0 = EnumFoo.e2 (vt.in = "[EnumFoo.e2]", go.tag = 'json:"enum1"')
+ 10: Foo Struct0 (vt.skip = "true", go.tag = 'json:"struct1"')
+}
diff --git a/lib/go/test/tests/validate_test.go b/lib/go/test/tests/validate_test.go
new file mode 100644
index 000000000..957a8df03
--- /dev/null
+++ b/lib/go/test/tests/validate_test.go
@@ -0,0 +1,494 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package tests
+
+import (
+ "encoding/json"
+ "errors"
+ "strconv"
+ "testing"
+
+ "github.com/apache/thrift/lib/go/test/gopath/src/validatetest"
+ thrift "github.com/apache/thrift/lib/go/thrift"
+)
+
+func TestBasicValidator(t *testing.T) {
+ bt := validatetest.NewBasicTest()
+ if err := bt.Validate(); err != nil {
+ t.Error(err)
+ }
+ var ve *thrift.ValidationError
+ bt = validatetest.NewBasicTest()
+ bt.Bool1 = thrift.BoolPtr(false)
+ if err := bt.Validate(); err == nil {
+ t.Error("Expected vt.const error for Bool1")
+ } else if errors.As(err, &ve) {
+ if ve.Check() != "vt.const" {
+ t.Errorf("Expected vt.const check error, but got %v", ve.Check())
+ }
+ if ve.Field() != "Bool1" {
+ t.Errorf("Expected error for Bool1, but got %v", ve.Field())
+ }
+ } else {
+ t.Errorf("Error cannot be unwrapped into *ValidationError: %v", err)
+ }
+ bt = validatetest.NewBasicTest()
+ bt.Byte1 = thrift.Int8Ptr(3)
+ if err := bt.Validate(); err == nil {
+ t.Errorf("Expected vt.lt error for Byte1")
+ } else if errors.As(err, &ve) {
+ if ve.Check() != "vt.lt" {
+ t.Errorf("Expected vt.lt check error, but got %v", ve.Check())
+ }
+ if ve.Field() != "Byte1" {
+ t.Errorf("Expected error for Byte1, but got %v", ve.Field())
+ }
+ } else {
+ t.Errorf("Error cannot be unwrapped into *ValidationError: %v", err)
+ }
+ bt = validatetest.NewBasicTest()
+ bt.Double1 = thrift.Float64Ptr(3.0)
+ if err := bt.Validate(); err == nil {
+ t.Errorf("Expected vt.lt error for Double1")
+ } else if errors.As(err, &ve) {
+ if ve.Check() != "vt.lt" {
+ t.Errorf("Expected vt.lt check error, but got %v", ve.Check())
+ }
+ if ve.Field() != "Double1" {
+ t.Errorf("Expected error for Double1, but got %v", ve.Field())
+ }
+ } else {
+ t.Errorf("Error cannot be unwrapped into *ValidationError: %v", err)
+ }
+ bt = validatetest.NewBasicTest()
+ bt.String1 = thrift.StringPtr("other string")
+ if err := bt.Validate(); err == nil {
+ t.Errorf("Expected vt.const error for String1")
+ } else if errors.As(err, &ve) {
+ if ve.Check() != "vt.const" {
+ t.Errorf("Expected vt.const check error, but got %v", ve.Check())
+ }
+ if ve.Field() != "String1" {
+ t.Errorf("Expected error for String1, but got %v", ve.Field())
+ }
+ } else {
+ t.Errorf("Error cannot be unwrapped into *ValidationError: %v", err)
+ }
+ bt = validatetest.NewBasicTest()
+ bt.Binary1 = []byte("other binary")
+ if err := bt.Validate(); err == nil {
+ t.Errorf("Expected vt.const error for Binary1")
+ } else if errors.As(err, &ve) {
+ if ve.Check() != "vt.const" {
+ t.Errorf("Expected vt.const check error, but got %v", ve.Check())
+ }
+ if ve.Field() != "Binary1" {
+ t.Errorf("Expected error for Binary1, but got %v", ve.Field())
+ }
+ } else {
+ t.Errorf("Error cannot be unwrapped into *ValidationError: %v", err)
+ }
+ bt = validatetest.NewBasicTest()
+ bt.Map1 = make(map[string]string)
+ for i := 0; i < 11; i++ {
+ bt.Map1[strconv.Itoa(i)] = strconv.Itoa(i)
+ }
+ if err := bt.Validate(); err == nil {
+ t.Errorf("Expected vt.max_size error for Map1")
+ } else if errors.As(err, &ve) {
+ if ve.Check() != "vt.max_size" {
+ t.Errorf("Expected vt.max_size check error, but got %v", ve.Check())
+ }
+ if ve.Field() != "Map1" {
+ t.Errorf("Expected error for Map1, but got %v", ve.Field())
+ }
+ } else {
+ t.Errorf("Error cannot be unwrapped into *ValidationError: %v", err)
+ }
+ bt.Map1 = map[string]string{"012345678910": "0"}
+ if err := bt.Validate(); err == nil {
+ t.Errorf("Expected vt.max_size error for Map1")
+ } else if errors.As(err, &ve) {
+ if ve.Check() != "vt.max_size" {
+ t.Errorf("Expected vt.max_size check error, but got %v", ve.Check())
+ }
+ if ve.Field() != "Map1" {
+ t.Errorf("Expected error for Map1, but got %v", ve.Field())
+ }
+ } else {
+ t.Errorf("Error cannot be unwrapped into *ValidationError: %v", err)
+ }
+ bt.Map1 = map[string]string{"0": "012345678910"}
+ if err := bt.Validate(); err == nil {
+ t.Errorf("Expected vt.max_size error for Map1")
+ } else if errors.As(err, &ve) {
+ if ve.Check() != "vt.max_size" {
+ t.Errorf("Expected vt.max_size check error, but got %v", ve.Check())
+ }
+ if ve.Field() != "Map1" {
+ t.Errorf("Expected error for Map1, but got %v", ve.Field())
+ }
+ } else {
+ t.Errorf("Error cannot be unwrapped into *ValidationError: %v", err)
+ }
+ bt = validatetest.NewBasicTest()
+ for i := 0; i < 11; i++ {
+ bt.Set1 = append(bt.Set1, "0")
+ }
+ if err := bt.Validate(); err == nil {
+ t.Errorf("Expected vt.max_size error for Set1")
+ } else if errors.As(err, &ve) {
+ if ve.Check() != "vt.max_size" {
+ t.Errorf("Expected vt.max_size check error, but got %v", ve.Check())
+ }
+ if ve.Field() != "Set1" {
+ t.Errorf("Expected error for Set1, but got %v", ve.Field())
+ }
+ } else {
+ t.Errorf("Error cannot be unwrapped into *ValidationError: %v", err)
+ }
+ bt.Set1 = []string{"0"}
+ if err := bt.Validate(); err == nil {
+ t.Errorf("Expected vt.min_size error for Set1")
+ } else if errors.As(err, &ve) {
+ if ve.Check() != "vt.min_size" {
+ t.Errorf("Expected vt.min_size check error, but got %v", ve.Check())
+ }
+ if ve.Field() != "Set1" {
+ t.Errorf("Expected error for Set1, but got %v", ve.Field())
+ }
+ } else {
+ t.Errorf("Error cannot be unwrapped into *ValidationError: %v", err)
+ }
+ bt = validatetest.NewBasicTest()
+ bt.Enum1 = (*validatetest.EnumFoo)(thrift.Int64Ptr(int64(validatetest.EnumFoo_e2)))
+ if err := bt.Validate(); err == nil {
+ t.Errorf("Expected vt.in error for Enum1")
+ } else if errors.As(err, &ve) {
+ if ve.Check() != "vt.in" {
+ t.Errorf("Expected vt.in check error, but got %v", ve.Check())
+ }
+ if ve.Field() != "Enum1" {
+ t.Errorf("Expected error for Enum1, but got %v", ve.Field())
+ }
+ } else {
+ t.Errorf("Error cannot be unwrapped into *ValidationError: %v", err)
+ }
+}
+
+func TestFieldReference(t *testing.T) {
+ frt := validatetest.NewFieldReferenceTest()
+ if err := frt.Validate(); err != nil {
+ t.Error(err)
+ }
+ var ve *thrift.ValidationError
+ frt = validatetest.NewFieldReferenceTest()
+ frt.Bool2 = true
+ if err := frt.Validate(); err == nil {
+ t.Errorf("Expected vt.const error for Bool0")
+ } else if errors.As(err, &ve) {
+ if ve.Check() != "vt.const" {
+ t.Errorf("Expected vt.const check error, but got %v", ve.Check())
+ }
+ if ve.Field() != "Bool0" {
+ t.Errorf("Expected error for Bool0, but got %v", ve.Field())
+ }
+ } else {
+ t.Errorf("Error cannot be unwrapped into *ValidationError: %v", err)
+ }
+ frt = validatetest.NewFieldReferenceTest()
+ frt.Byte4 = 9
+ if err := frt.Validate(); err == nil {
+ t.Errorf("Expected vt.lt error for Byte0")
+ } else if errors.As(err, &ve) {
+ if ve.Check() != "vt.lt" {
+ t.Errorf("Expected vt.lt check error, but got %v", ve.Check())
+ }
+ if ve.Field() != "Byte0" {
+ t.Errorf("Expected error for Byte0, but got %v", ve.Field())
+ }
+ } else {
+ t.Errorf("Error cannot be unwrapped into *ValidationError: %v", err)
+ }
+ frt = validatetest.NewFieldReferenceTest()
+ frt.Double4 = 9
+ if err := frt.Validate(); err == nil {
+ t.Errorf("Expected vt.lt error for Double0")
+ } else if errors.As(err, &ve) {
+ if ve.Check() != "vt.lt" {
+ t.Errorf("Expected vt.lt check error, but got %v", ve.Check())
+ }
+ if ve.Field() != "Double0" {
+ t.Errorf("Expected error for Double0, but got %v", ve.Field())
+ }
+ } else {
+ t.Errorf("Error cannot be unwrapped into *ValidationError: %v", err)
+ }
+ frt = validatetest.NewFieldReferenceTest()
+ frt.String2 = "other string"
+ if err := frt.Validate(); err == nil {
+ t.Errorf("Expected vt.const error for String0")
+ } else if errors.As(err, &ve) {
+ if ve.Check() != "vt.const" {
+ t.Errorf("Expected vt.const check error, but got %v", ve.Check())
+ }
+ if ve.Field() != "String0" {
+ t.Errorf("Expected error for String0, but got %v", ve.Field())
+ }
+ } else {
+ t.Errorf("Error cannot be unwrapped into *ValidationError: %v", err)
+ }
+ frt = validatetest.NewFieldReferenceTest()
+ frt.Binary2 = []byte("other string")
+ if err := frt.Validate(); err == nil {
+ t.Errorf("Expected vt.const error for Binary0")
+ } else if errors.As(err, &ve) {
+ if ve.Check() != "vt.const" {
+ t.Errorf("Expected vt.const check error, but got %v", ve.Check())
+ }
+ if ve.Field() != "Binary0" {
+ t.Errorf("Expected error for Binary0, but got %v", ve.Field())
+ }
+ } else {
+ t.Errorf("Error cannot be unwrapped into *ValidationError: %v", err)
+ }
+ frt = validatetest.NewFieldReferenceTest()
+ frt.MaxSize = 8
+ frt.Map0 = make(map[string]string)
+ for i := 0; i < 9; i++ {
+ frt.Map0[strconv.Itoa(i)] = strconv.Itoa(i)
+ }
+ if err := frt.Validate(); err == nil {
+ t.Errorf("Expected vt.max_size error for Map0")
+ } else if errors.As(err, &ve) {
+ if ve.Check() != "vt.max_size" {
+ t.Errorf("Expected vt.max_size check error, but got %v", ve.Check())
+ }
+ if ve.Field() != "Map0" {
+ t.Errorf("Expected error for Map0, but got %v", ve.Field())
+ }
+ } else {
+ t.Errorf("Error cannot be unwrapped into *ValidationError: %v", err)
+ }
+ frt = validatetest.NewFieldReferenceTest()
+ frt.MaxSize = 8
+ for i := 0; i < 9; i++ {
+ frt.List0 = append(frt.List0, "0")
+ }
+ if err := frt.Validate(); err == nil {
+ t.Errorf("Expected vt.max_size error for List0")
+ } else if errors.As(err, &ve) {
+ if ve.Check() != "vt.max_size" {
+ t.Errorf("Expected vt.max_size check error, but got %v", ve.Check())
+ }
+ if ve.Field() != "List0" {
+ t.Errorf("Expected error for List0, but got %v", ve.Field())
+ }
+ } else {
+ t.Errorf("Error cannot be unwrapped into *ValidationError: %v", err)
+ }
+ frt = validatetest.NewFieldReferenceTest()
+ frt.MaxSize = 8
+ for i := 0; i < 9; i++ {
+ frt.Set0 = append(frt.Set0, "0")
+ }
+ if err := frt.Validate(); err == nil {
+ t.Errorf("Expected vt.max_size error for Set0")
+ } else if errors.As(err, &ve) {
+ if ve.Check() != "vt.max_size" {
+ t.Errorf("Expected vt.max_size check error, but got %v", ve.Check())
+ }
+ if ve.Field() != "Set0" {
+ t.Errorf("Expected error for Set0, but got %v", ve.Field())
+ }
+ } else {
+ t.Errorf("Error cannot be unwrapped into *ValidationError: %v", err)
+ }
+}
+
+func TestValidationFunction(t *testing.T) {
+ vft := validatetest.NewValidationFunctionTest()
+ if err := vft.Validate(); err != nil {
+ t.Error(err)
+ }
+ var ve *thrift.ValidationError
+ vft = validatetest.NewValidationFunctionTest()
+ vft.StringFoo = "some string"
+ if err := vft.Validate(); err == nil {
+ t.Errorf("Expected vt.in error for StringLength")
+ } else if errors.As(err, &ve) {
+ if ve.Check() != "vt.in" {
+ t.Errorf("Expected vt.in check error, but got %v", ve.Check())
+ }
+ if ve.Field() != "StringLength" {
+ t.Errorf("Expected error for StringLength, but got %v", ve.Field())
+ }
+ } else {
+ t.Errorf("Error cannot be unwrapped into *ValidationError: %v", err)
+ }
+}
+
+func TestAnnotationCompatibleTest(t *testing.T) {
+ act := validatetest.NewAnnotationCompatibleTest()
+ if err := act.Validate(); err != nil {
+ t.Error(err)
+ }
+ var ve *thrift.ValidationError
+ act = validatetest.NewAnnotationCompatibleTest()
+ act.Bool0 = false
+ if err := act.Validate(); err == nil {
+ t.Errorf("Expected vt.const error for Bool0")
+ } else if errors.As(err, &ve) {
+ if ve.Check() != "vt.const" {
+ t.Errorf("Expected vt.const check error, but got %v", ve.Check())
+ }
+ if ve.Field() != "Bool0" {
+ t.Errorf("Expected error for Bool0, but got %v", ve.Field())
+ }
+ } else {
+ t.Errorf("Error cannot be unwrapped into *ValidationError: %v", err)
+ }
+ act = validatetest.NewAnnotationCompatibleTest()
+ act.Byte0 = 3
+ if err := act.Validate(); err == nil {
+ t.Errorf("Expected vt.lt error for Byte0")
+ } else if errors.As(err, &ve) {
+ if ve.Check() != "vt.lt" {
+ t.Errorf("Expected vt.lt check error, but got %v", ve.Check())
+ }
+ if ve.Field() != "Byte0" {
+ t.Errorf("Expected error for Byte0, but got %v", ve.Field())
+ }
+ } else {
+ t.Errorf("Error cannot be unwrapped into *ValidationError: %v", err)
+ }
+ act = validatetest.NewAnnotationCompatibleTest()
+ act.Double0 = 3
+ if err := act.Validate(); err == nil {
+ t.Errorf("Expected vt.lt error for Double0")
+ } else if errors.As(err, &ve) {
+ if ve.Check() != "vt.lt" {
+ t.Errorf("Expected vt.lt check error, but got %v", ve.Check())
+ }
+ if ve.Field() != "Double0" {
+ t.Errorf("Expected error for Double0, but got %v", ve.Field())
+ }
+ } else {
+ t.Errorf("Error cannot be unwrapped into *ValidationError: %v", err)
+ }
+ act = validatetest.NewAnnotationCompatibleTest()
+ act.String0 = "other string"
+ if err := act.Validate(); err == nil {
+ t.Errorf("Expected vt.const error for String0")
+ } else if errors.As(err, &ve) {
+ if ve.Check() != "vt.const" {
+ t.Errorf("Expected vt.const check error, but got %v", ve.Check())
+ }
+ if ve.Field() != "String0" {
+ t.Errorf("Expected error for String0, but got %v", ve.Field())
+ }
+ } else {
+ t.Errorf("Error cannot be unwrapped into *ValidationError: %v", err)
+ }
+ act = validatetest.NewAnnotationCompatibleTest()
+ act.Binary0 = []byte("other string")
+ if err := act.Validate(); err == nil {
+ t.Errorf("Expected vt.const error for Binary0")
+ } else if errors.As(err, &ve) {
+ if ve.Check() != "vt.const" {
+ t.Errorf("Expected vt.const check error, but got %v", ve.Check())
+ }
+ if ve.Field() != "Binary0" {
+ t.Errorf("Expected error for Binary0, but got %v", ve.Field())
+ }
+ } else {
+ t.Errorf("Error cannot be unwrapped into *ValidationError: %v", err)
+ }
+ act = validatetest.NewAnnotationCompatibleTest()
+ act.Map0 = map[string]string{"0": "0", "1": "1", "2": "2"}
+ if err := act.Validate(); err == nil {
+ t.Errorf("Expected vt.max_size error for Map0")
+ } else if errors.As(err, &ve) {
+ if ve.Check() != "vt.max_size" {
+ t.Errorf("Expected vt.max_size check error, but got %v", ve.Check())
+ }
+ if ve.Field() != "Map0" {
+ t.Errorf("Expected error for Map0, but got %v", ve.Field())
+ }
+ } else {
+ t.Errorf("Error cannot be unwrapped into *ValidationError: %v", err)
+ }
+ act = validatetest.NewAnnotationCompatibleTest()
+ act.Set0 = []string{"0", "1", "2"}
+ if err := act.Validate(); err == nil {
+ t.Errorf("Expected vt.max_size error for Set0")
+ } else if errors.As(err, &ve) {
+ if ve.Check() != "vt.max_size" {
+ t.Errorf("Expected vt.max_size check error, but got %v", ve.Check())
+ }
+ if ve.Field() != "Set0" {
+ t.Errorf("Expected error for Set0, but got %v", ve.Field())
+ }
+ } else {
+ t.Errorf("Error cannot be unwrapped into *ValidationError: %v", err)
+ }
+ act = validatetest.NewAnnotationCompatibleTest()
+ act.List0 = []string{"0", "1", "2"}
+ if err := act.Validate(); err == nil {
+ t.Errorf("Expected vt.max_size error for List0")
+ } else if errors.As(err, &ve) {
+ if ve.Check() != "vt.max_size" {
+ t.Errorf("Expected vt.max_size check error, but got %v", ve.Check())
+ }
+ if ve.Field() != "List0" {
+ t.Errorf("Expected error for List0, but got %v", ve.Field())
+ }
+ } else {
+ t.Errorf("Error cannot be unwrapped into *ValidationError: %v", err)
+ }
+ act = validatetest.NewAnnotationCompatibleTest()
+ act.Enum0 = validatetest.EnumFoo_e1
+ if err := act.Validate(); err == nil {
+ t.Errorf("Expected vt.in error for Enum0")
+ } else if errors.As(err, &ve) {
+ if ve.Check() != "vt.in" {
+ t.Errorf("Expected vt.in check error, but got %v", ve.Check())
+ }
+ if ve.Field() != "Enum0" {
+ t.Errorf("Expected error for Enum0, but got %v", ve.Field())
+ }
+ } else {
+ t.Errorf("Error cannot be unwrapped into *ValidationError: %v", err)
+ }
+ fields := []string{"bool1", "byte1", "double1", "string1", "binary1", "enum1", "struct1", "list1", "set1", "map1"}
+ b, err := json.Marshal(act)
+ if err != nil {
+ t.Error(err)
+ }
+ jsonMap := make(map[string]interface{})
+ if err = json.Unmarshal(b, &jsonMap); err != nil {
+ t.Error(err)
+ }
+ for _, field := range fields {
+ if _, ok := jsonMap[field]; !ok {
+ t.Errorf("Expected field %s in JSON, but not found", field)
+ }
+ }
+}
diff --git a/lib/go/thrift/application_exception.go b/lib/go/thrift/application_exception.go
index ed85a645c..8b8137ae8 100644
--- a/lib/go/thrift/application_exception.go
+++ b/lib/go/thrift/application_exception.go
@@ -21,6 +21,7 @@ package thrift
import (
"context"
+ "strings"
)
const (
@@ -35,6 +36,7 @@ const (
INVALID_TRANSFORM = 8
INVALID_PROTOCOL = 9
UNSUPPORTED_CLIENT_TYPE = 10
+ VALIDATION_FAILED = 11
)
var defaultApplicationExceptionMessage = map[int32]string{
@@ -49,6 +51,7 @@ var defaultApplicationExceptionMessage = map[int32]string{
INVALID_TRANSFORM: "Invalid transform",
INVALID_PROTOCOL: "Invalid protocol",
UNSUPPORTED_CLIENT_TYPE: "Unsupported client type",
+ VALIDATION_FAILED: "validation failed",
}
// Application level Thrift exception
@@ -59,9 +62,39 @@ type TApplicationException interface {
Write(ctx context.Context, oprot TProtocol) error
}
+type ValidationError struct {
+ message string
+ check string
+ fieldSymbol string
+}
+
+func (e *ValidationError) Check() string {
+ return e.check
+}
+
+func (e *ValidationError) TypeName() string {
+ return strings.Split(e.fieldSymbol, ".")[0]
+}
+
+func (e *ValidationError) Field() string {
+ if fs := strings.Split(e.fieldSymbol, "."); len(fs) > 1 {
+ return fs[1]
+ }
+ return e.fieldSymbol
+}
+
+func (e *ValidationError) FieldSymbol() string {
+ return e.fieldSymbol
+}
+
+func (e ValidationError) Error() string {
+ return e.message
+}
+
type tApplicationException struct {
message string
type_ int32
+ err error
}
var _ TApplicationException = (*tApplicationException)(nil)
@@ -77,8 +110,20 @@ func (e tApplicationException) Error() string {
return defaultApplicationExceptionMessage[e.type_]
}
+func (e tApplicationException) Unwrap() error {
+ return e.err
+}
+
func NewTApplicationException(type_ int32, message string) TApplicationException {
- return &tApplicationException{message, type_}
+ return &tApplicationException{message, type_, nil}
+}
+
+func NewValidationException(type_ int32, check string, field string, message string) TApplicationException {
+ return &tApplicationException{
+ type_: type_,
+ message: message,
+ err: &ValidationError{message: message, check: check, fieldSymbol: field},
+ }
}
func (p *tApplicationException) TypeId() int32 {