diff options
author | D. Can Celasun <can@dcc.im> | 2017-09-21 15:21:00 +0200 |
---|---|---|
committer | James E. King, III <jking@apache.org> | 2017-11-03 18:21:40 -0700 |
commit | 4f77ab8e296d64c57e6ea1c6e3f0f152bc7d6a3a (patch) | |
tree | c3acd180d98bcfdb76c40dc5b6177e16bfc13719 | |
parent | 847ecf3c1de8b297d6a29305b9f7871fcf609c36 (diff) | |
download | thrift-4f77ab8e296d64c57e6ea1c6e3f0f152bc7d6a3a.tar.gz |
THRIFT-4285 Move TX/RX methods from gen. code to library
This change removes a lot of duplication from generated code and allows
the caller to customize how they can read from / write to the
transport. Backwards compatible adapters make the change compatible
with existing code in use by consuming applications.
Client: Go
This closes #1382
-rw-r--r-- | compiler/cpp/src/thrift/generate/t_go_generator.cc | 252 | ||||
-rw-r--r-- | lib/go/test/tests/client_error_test.go | 194 | ||||
-rw-r--r-- | lib/go/test/tests/multiplexed_protocol_test.go | 150 | ||||
-rw-r--r-- | lib/go/test/tests/one_way_test.go | 2 | ||||
-rw-r--r-- | lib/go/test/tests/protocol_mock.go | 1 | ||||
-rw-r--r-- | lib/go/test/tests/protocols_test.go | 2 | ||||
-rw-r--r-- | lib/go/thrift/application_exception.go | 30 | ||||
-rw-r--r-- | lib/go/thrift/client.go | 78 | ||||
-rw-r--r-- | lib/go/thrift/client_go17.go | 13 | ||||
-rw-r--r-- | lib/go/thrift/client_pre_go17.go | 13 | ||||
-rw-r--r-- | test/go/Makefile.am | 2 | ||||
-rw-r--r-- | test/go/src/bin/testclient/main.go | 32 | ||||
-rw-r--r-- | test/go/src/common/client.go | 27 | ||||
-rw-r--r-- | test/go/src/common/clientserver_test.go | 33 | ||||
-rw-r--r-- | tutorial/go/src/client.go | 7 |
15 files changed, 545 insertions, 291 deletions
diff --git a/compiler/cpp/src/thrift/generate/t_go_generator.cc b/compiler/cpp/src/thrift/generate/t_go_generator.cc index bac1c57a2..e869b00bf 100644 --- a/compiler/cpp/src/thrift/generate/t_go_generator.cc +++ b/compiler/cpp/src/thrift/generate/t_go_generator.cc @@ -1878,20 +1878,16 @@ void t_go_generator::generate_service_client(t_service* tservice) { f_types_ << indent() << "type " << serviceName << "Client struct {" << endl; indent_up(); + f_types_ << indent() << "c thrift.TClient" << endl; if (!extends_client.empty()) { f_types_ << indent() << "*" << extends_client << endl; - } else { - f_types_ << indent() << "Transport thrift.TTransport" << endl; - f_types_ << indent() << "ProtocolFactory thrift.TProtocolFactory" << endl; - f_types_ << indent() << "InputProtocol thrift.TProtocol" << endl; - f_types_ << indent() << "OutputProtocol thrift.TProtocol" << endl; - f_types_ << indent() << "SeqId int32" << endl; - /*f_types_ << indent() << "reqs map[int32]Deferred" << endl*/; } indent_down(); f_types_ << indent() << "}" << endl << endl; - // Constructor function + + // Legacy constructor function + f_types_ << indent() << "// Deprecated: Use New" << serviceName << " instead" << endl; f_types_ << indent() << "func New" << serviceName << "ClientFactory(t thrift.TTransport, f thrift.TProtocolFactory) *" << serviceName << "Client {" << endl; @@ -1902,19 +1898,16 @@ void t_go_generator::generate_service_client(t_service* tservice) { f_types_ << "{" << extends_field << ": " << extends_client_new << "Factory(t, f)}"; } else { indent_up(); - f_types_ << "{Transport: t," << endl; - f_types_ << indent() << "ProtocolFactory: f," << endl; - f_types_ << indent() << "InputProtocol: f.GetProtocol(t)," << endl; - f_types_ << indent() << "OutputProtocol: f.GetProtocol(t)," << endl; - f_types_ << indent() << "SeqId: 0," << endl; - /*f_types_ << indent() << "Reqs: make(map[int32]Deferred)" << endl*/; + f_types_ << "{" << endl; + f_types_ << indent() << "c: thrift.NewTStandardClient(f.GetProtocol(t), f.GetProtocol(t))," << endl; indent_down(); f_types_ << indent() << "}" << endl; } indent_down(); f_types_ << indent() << "}" << endl << endl; - // Constructor function + // Legacy constructor function with custom input & output protocols + f_types_ << indent() << "// Deprecated: Use New" << serviceName << " instead" << endl; f_types_ << indent() << "func New" << serviceName << "ClientProtocol(t thrift.TTransport, iprot thrift.TProtocol, oprot thrift.TProtocol) *" @@ -1927,18 +1920,32 @@ void t_go_generator::generate_service_client(t_service* tservice) { << endl; } else { indent_up(); - f_types_ << "{Transport: t," << endl; - f_types_ << indent() << "ProtocolFactory: nil," << endl; - f_types_ << indent() << "InputProtocol: iprot," << endl; - f_types_ << indent() << "OutputProtocol: oprot," << endl; - f_types_ << indent() << "SeqId: 0," << endl; - /*f_types_ << indent() << "Reqs: make(map[int32]interface{})" << endl*/; + f_types_ << "{" << endl; + f_types_ << indent() << "c: thrift.NewTStandardClient(iprot, oprot)," << endl; indent_down(); f_types_ << indent() << "}" << endl; } indent_down(); f_types_ << indent() << "}" << endl << endl; + + // Constructor function + f_types_ << indent() << "func New" << serviceName + << "Client(c thrift.TClient) *" << serviceName << "Client {" << endl; + indent_up(); + f_types_ << indent() << "return &" << serviceName << "Client{" << endl; + + indent_up(); + f_types_ << indent() << "c: c," << endl; + if (!extends.empty()) { + f_types_ << indent() << extends_field << ": " << extends_client_new << "(c)," << endl; + } + indent_down(); + f_types_ << indent() << "}" << endl; + + indent_down(); + f_types_ << indent() << "}" << endl << endl; + // Generate client method implementations vector<t_function*> functions = tservice->get_functions(); vector<t_function*>::const_iterator f_iter; @@ -1953,177 +1960,75 @@ void t_go_generator::generate_service_client(t_service* tservice) { f_types_ << indent() << "func (p *" << serviceName << "Client) " << function_signature_if(*f_iter, "", true) << " {" << endl; indent_up(); - /* - f_types_ << - indent() << "p.SeqId += 1" << endl; - if (!(*f_iter)->is_oneway()) { - f_types_ << - indent() << "d := defer.Deferred()" << endl << - indent() << "p.Reqs[p.SeqId] = d" << endl; - } - */ - f_types_ << indent() << "if err = p.send" << funname << "("; - bool first = true; - - for (fld_iter = fields.begin(); fld_iter != fields.end(); ++fld_iter) { - if (first) { - first = false; - } else { - f_types_ << ", "; - } - f_types_ << variable_name_to_go_name((*fld_iter)->get_name()); - } - - f_types_ << "); err != nil { return }" << endl; - - if (!(*f_iter)->is_oneway()) { - f_types_ << indent() << "return p.recv" << funname << "()" << endl; - } else { - f_types_ << indent() << "return" << endl; - } - - indent_down(); - f_types_ << indent() << "}" << endl << endl; - f_types_ << indent() << "func (p *" << serviceName << "Client) send" - << function_signature(*f_iter) << "(err error) {" << endl; - indent_up(); - std::string argsname = publicize((*f_iter)->get_name() + "_args", true); - // Serialize the request header - f_types_ << indent() << "oprot := p.OutputProtocol" << endl; - f_types_ << indent() << "if oprot == nil {" << endl; - f_types_ << indent() << " oprot = p.ProtocolFactory.GetProtocol(p.Transport)" << endl; - f_types_ << indent() << " p.OutputProtocol = oprot" << endl; - f_types_ << indent() << "}" << endl; - f_types_ << indent() << "p.SeqId++" << endl; - f_types_ << indent() << "if err = oprot.WriteMessageBegin(\"" << (*f_iter)->get_name() - << "\", " << ((*f_iter)->is_oneway() ? "thrift.ONEWAY" : "thrift.CALL") - << ", p.SeqId); err != nil {" << endl; - indent_up(); - f_types_ << indent() << " return" << endl; - indent_down(); - f_types_ << indent() << "}" << endl; - f_types_ << indent() << "args := " << argsname << "{" << endl; + std::string method = (*f_iter)->get_name(); + std::string argsType = publicize(method + "_args", true); + std::string argsName = tmp("_args"); + f_types_ << indent() << "var " << argsName << " " << argsType << endl; for (fld_iter = fields.begin(); fld_iter != fields.end(); ++fld_iter) { - f_types_ << indent() << publicize((*fld_iter)->get_name()) << " : " - << variable_name_to_go_name((*fld_iter)->get_name()) << "," << endl; + f_types_ << indent() << argsName << "." << publicize((*fld_iter)->get_name()) + << " = " << variable_name_to_go_name((*fld_iter)->get_name()) << endl; } - f_types_ << indent() << "}" << endl; - - // Write to the stream - f_types_ << indent() << "if err = args." << write_method_name_ << "(oprot); err != nil {" << endl; - indent_up(); - f_types_ << indent() << " return" << endl; - indent_down(); - f_types_ << indent() << "}" << endl; - f_types_ << indent() << "if err = oprot.WriteMessageEnd(); err != nil {" << endl; - indent_up(); - f_types_ << indent() << " return" << endl; - indent_down(); - f_types_ << indent() << "}" << endl; - f_types_ << indent() << "return oprot.Flush()" << endl; - indent_down(); - f_types_ << indent() << "}" << endl << endl; if (!(*f_iter)->is_oneway()) { - std::string resultname = publicize((*f_iter)->get_name() + "_result", true); - // Open function - f_types_ << endl << indent() << "func (p *" << serviceName << "Client) recv" - << publicize((*f_iter)->get_name()) << "() ("; + std::string resultName = tmp("_result"); + std::string resultType = publicize(method + "_result", true); + f_types_ << indent() << "var " << resultName << " " << resultType << endl; + f_types_ << indent() << "if err = p.c.Call(ctx, \"" + << method << "\", &" << argsName << ", &" << resultName << "); err != nil {" << endl; - if (!(*f_iter)->get_returntype()->is_void()) { - f_types_ << "value " << type_to_go_type((*f_iter)->get_returntype()) << ", "; - } - - f_types_ << "err error) {" << endl; indent_up(); - // TODO(mcslee): Validate message reply here, seq ids etc. - string error(tmp("error")); - string error2(tmp("error")); - f_types_ << indent() << "iprot := p.InputProtocol" << endl; - f_types_ << indent() << "if iprot == nil {" << endl; - f_types_ << indent() << " iprot = p.ProtocolFactory.GetProtocol(p.Transport)" << endl; - f_types_ << indent() << " p.InputProtocol = iprot" << endl; - f_types_ << indent() << "}" << endl; - f_types_ << indent() << "method, mTypeId, seqId, err := iprot.ReadMessageBegin()" << endl; - f_types_ << indent() << "if err != nil {" << endl; - f_types_ << indent() << " return" << endl; - f_types_ << indent() << "}" << endl; - f_types_ << indent() << "if method != \"" << (*f_iter)->get_name() << "\" {" << endl; - f_types_ << indent() << " err = thrift.NewTApplicationException(" - << "thrift.WRONG_METHOD_NAME, \"" << (*f_iter)->get_name() - << " failed: wrong method name\")" << endl; - f_types_ << indent() << " return" << endl; - f_types_ << indent() << "}" << endl; - f_types_ << indent() << "if p.SeqId != seqId {" << endl; - f_types_ << indent() << " err = thrift.NewTApplicationException(" - << "thrift.BAD_SEQUENCE_ID, \"" << (*f_iter)->get_name() - << " failed: out of sequence response\")" << endl; - f_types_ << indent() << " return" << endl; - f_types_ << indent() << "}" << endl; - f_types_ << indent() << "if mTypeId == thrift.EXCEPTION {" << endl; - f_types_ << indent() << " " << error - << " := thrift.NewTApplicationException(thrift.UNKNOWN_APPLICATION_EXCEPTION, " - "\"Unknown Exception\")" << endl; - f_types_ << indent() << " var " << error2 << " error" << endl; - f_types_ << indent() << " " << error2 << ", err = " << error << ".Read(iprot)" << endl; - f_types_ << indent() << " if err != nil {" << endl; - f_types_ << indent() << " return" << endl; - f_types_ << indent() << " }" << endl; - f_types_ << indent() << " if err = iprot.ReadMessageEnd(); err != nil {" << endl; - f_types_ << indent() << " return" << endl; - f_types_ << indent() << " }" << endl; - f_types_ << indent() << " err = " << error2 << endl; - f_types_ << indent() << " return" << endl; - f_types_ << indent() << "}" << endl; - f_types_ << indent() << "if mTypeId != thrift.REPLY {" << endl; - f_types_ << indent() << " err = thrift.NewTApplicationException(" - << "thrift.INVALID_MESSAGE_TYPE_EXCEPTION, \"" << (*f_iter)->get_name() - << " failed: invalid message type\")" << endl; - f_types_ << indent() << " return" << endl; - f_types_ << indent() << "}" << endl; - f_types_ << indent() << "result := " << resultname << "{}" << endl; - f_types_ << indent() << "if err = result." << read_method_name_ << "(iprot); err != nil {" << endl; - f_types_ << indent() << " return" << endl; - f_types_ << indent() << "}" << endl; - f_types_ << indent() << "if err = iprot.ReadMessageEnd(); err != nil {" << endl; - f_types_ << indent() << " return" << endl; + f_types_ << indent() << "return" << endl; + indent_down(); f_types_ << indent() << "}" << endl; t_struct* xs = (*f_iter)->get_xceptions(); const std::vector<t_field*>& xceptions = xs->get_members(); vector<t_field*>::const_iterator x_iter; - for (x_iter = xceptions.begin(); x_iter != xceptions.end(); ++x_iter) { - const std::string pubname = publicize((*x_iter)->get_name()); + if (!xceptions.empty()) { + f_types_ << indent() << "switch {" << endl; - f_types_ << indent() << "if result." << pubname << " != nil {" << endl; - f_types_ << indent() << " err = result." << pubname << endl; - f_types_ << indent() << " return " << endl; - f_types_ << indent() << "}"; + for (x_iter = xceptions.begin(); x_iter != xceptions.end(); ++x_iter) { + const std::string pubname = publicize((*x_iter)->get_name()); + const std::string field = resultName + "." + pubname; - if ((x_iter + 1) != xceptions.end()) { - f_types_ << " else "; - } else { - f_types_ << endl; + f_types_ << indent() << "case " << field << "!= nil:" << endl; + indent_up(); + + if (!(*f_iter)->get_returntype()->is_void()) { + f_types_ << indent() << "return r, " << field << endl; + } else { + f_types_ << indent() << "return "<< field << endl; + } + + indent_down(); } + + f_types_ << indent() << "}" << endl << endl; } - // Careful, only return _result if not a void function if (!(*f_iter)->get_returntype()->is_void()) { - f_types_ << indent() << "value = result.GetSuccess()" << endl; + f_types_ << indent() << "return " << resultName << ".GetSuccess(), nil" << endl; + } else { + f_types_ << indent() << "return nil" << endl; } + } else { + // TODO: would be nice to not to duplicate the call generation + f_types_ << indent() << "if err := p.c.Call(ctx, \"" + << method << "\", &"<< argsName << ", nil); err != nil {" << endl; - f_types_ << indent() << "return" << endl; - // Close function + indent_up(); + f_types_ << indent() << "return err" << endl; indent_down(); - f_types_ << indent() << "}" << endl << endl; + f_types_ << indent() << "}" << endl; + f_types_ << indent() << "return nil" << endl; } - } - // indent_down(); - f_types_ << endl; + indent_down(); + f_types_ << "}" << endl << endl; + } } /** @@ -2308,8 +2213,10 @@ void t_go_generator::generate_service_remote(t_service* tservice) { f_remote << indent() << " Usage()" << endl; f_remote << indent() << " os.Exit(1)" << endl; f_remote << indent() << "}" << endl; + f_remote << indent() << "iprot := protocolFactory.GetProtocol(trans)" << endl; + f_remote << indent() << "oprot := protocolFactory.GetProtocol(trans)" << endl; f_remote << indent() << "client := " << package_name_ << ".New" << publicize(service_name_) - << "ClientFactory(trans, protocolFactory)" << endl; + << "Client(thrift.NewTStandardClient(iprot, oprot))" << endl; f_remote << indent() << "if err := trans.Open(); err != nil {" << endl; f_remote << indent() << " fmt.Fprintln(os.Stderr, \"Error opening socket to \", " "host, \":\", port, \" \", err)" << endl; @@ -3444,10 +3351,13 @@ string t_go_generator::function_signature(t_function* tfunction, string prefix) * @return String of rendered function definition */ string t_go_generator::function_signature_if(t_function* tfunction, string prefix, bool addError) { - // TODO(mcslee): Nitpicky, no ',' if argument_list is empty string signature = publicize(prefix + tfunction->get_name()) + "("; - signature += "ctx context.Context, "; - signature += argument_list(tfunction->get_arglist()) + ") ("; + signature += "ctx context.Context"; + if (!tfunction->get_arglist()->get_members().empty()) { + signature += ", " + argument_list(tfunction->get_arglist()); + } + signature += ") ("; + t_type* ret = tfunction->get_returntype(); t_struct* exceptions = tfunction->get_xceptions(); string errs = argument_list(exceptions); diff --git a/lib/go/test/tests/client_error_test.go b/lib/go/test/tests/client_error_test.go index ad43447da..4a8ef1371 100644 --- a/lib/go/test/tests/client_error_test.go +++ b/lib/go/test/tests/client_error_test.go @@ -20,11 +20,12 @@ package tests import ( - "github.com/golang/mock/gomock" "errors" "errortest" "testing" "thrift" + + "github.com/golang/mock/gomock" ) // TestCase: Comprehensive call and reply workflow in the client. @@ -397,7 +398,6 @@ func prepareClientCallReply(protocol *MockTProtocol, failAt int, failWith error) // Expecting TTransportError on fail. func TestClientReportTTransportErrors(t *testing.T) { mockCtrl := gomock.NewController(t) - transport := thrift.NewTMemoryBuffer() thing := errortest.NewTestStruct() thing.M = make(map[string]string) @@ -411,6 +411,38 @@ func TestClientReportTTransportErrors(t *testing.T) { if !prepareClientCallReply(protocol, i, err) { return } + client := errortest.NewErrorTestClient(thrift.NewTStandardClient(protocol, protocol)) + _, retErr := client.TestStruct(defaultCtx, thing) + mockCtrl.Finish() + err2, ok := retErr.(thrift.TTransportException) + if !ok { + t.Fatal("Expected a TTrasportException") + } + + if err2.TypeId() != thrift.TIMED_OUT { + t.Fatal("Expected TIMED_OUT error") + } + } +} + +// TestCase: Comprehensive call and reply workflow in the client. +// Expecting TTransportError on fail. +// Similar to TestClientReportTTransportErrors, but using legacy client constructor. +func TestClientReportTTransportErrorsLegacy(t *testing.T) { + mockCtrl := gomock.NewController(t) + transport := thrift.NewTMemoryBuffer() + thing := errortest.NewTestStruct() + thing.M = make(map[string]string) + thing.L = make([]string, 0) + thing.S = make([]string, 0) + thing.I = 3 + + err := thrift.NewTTransportException(thrift.TIMED_OUT, "test") + for i := 0; ; i++ { + protocol := NewMockTProtocol(mockCtrl) + if !prepareClientCallReply(protocol, i, err) { + return + } client := errortest.NewErrorTestClientProtocol(transport, protocol, protocol) _, retErr := client.TestStruct(defaultCtx, thing) mockCtrl.Finish() @@ -429,7 +461,6 @@ func TestClientReportTTransportErrors(t *testing.T) { // Expecting TTProtocolErrors on fail. func TestClientReportTProtocolErrors(t *testing.T) { mockCtrl := gomock.NewController(t) - transport := thrift.NewTMemoryBuffer() thing := errortest.NewTestStruct() thing.M = make(map[string]string) @@ -443,6 +474,37 @@ func TestClientReportTProtocolErrors(t *testing.T) { if !prepareClientCallReply(protocol, i, err) { return } + client := errortest.NewErrorTestClient(thrift.NewTStandardClient(protocol, protocol)) + _, retErr := client.TestStruct(defaultCtx, thing) + mockCtrl.Finish() + err2, ok := retErr.(thrift.TProtocolException) + if !ok { + t.Fatal("Expected a TProtocolException") + } + if err2.TypeId() != thrift.INVALID_DATA { + t.Fatal("Expected INVALID_DATA error") + } + } +} + +// TestCase: Comprehensive call and reply workflow in the client. +// Expecting TTProtocolErrors on fail. +// Similar to TestClientReportTProtocolErrors, but using legacy client constructor. +func TestClientReportTProtocolErrorsLegacy(t *testing.T) { + mockCtrl := gomock.NewController(t) + transport := thrift.NewTMemoryBuffer() + thing := errortest.NewTestStruct() + thing.M = make(map[string]string) + thing.L = make([]string, 0) + thing.S = make([]string, 0) + thing.I = 3 + + err := thrift.NewTProtocolExceptionWithType(thrift.INVALID_DATA, errors.New("test")) + for i := 0; ; i++ { + protocol := NewMockTProtocol(mockCtrl) + if !prepareClientCallReply(protocol, i, err) { + return + } client := errortest.NewErrorTestClientProtocol(transport, protocol, protocol) _, retErr := client.TestStruct(defaultCtx, thing) mockCtrl.Finish() @@ -557,13 +619,47 @@ func prepareClientCallException(protocol *MockTProtocol, failAt int, failWith er // TestCase: call and reply with exception workflow in the client. func TestClientCallException(t *testing.T) { mockCtrl := gomock.NewController(t) - transport := thrift.NewTMemoryBuffer() err := thrift.NewTTransportException(thrift.TIMED_OUT, "test") for i := 0; ; i++ { protocol := NewMockTProtocol(mockCtrl) willComplete := !prepareClientCallException(protocol, i, err) + client := errortest.NewErrorTestClient(thrift.NewTStandardClient(protocol, protocol)) + _, retErr := client.TestString(defaultCtx, "test") + mockCtrl.Finish() + + if !willComplete { + err2, ok := retErr.(thrift.TTransportException) + if !ok { + t.Fatal("Expected a TTransportException") + } + if err2.TypeId() != thrift.TIMED_OUT { + t.Fatal("Expected TIMED_OUT error") + } + } else { + err2, ok := retErr.(thrift.TApplicationException) + if !ok { + t.Fatal("Expected a TApplicationException") + } + if err2.TypeId() != thrift.PROTOCOL_ERROR { + t.Fatal("Expected PROTOCOL_ERROR error") + } + break + } + } +} + +// TestCase: call and reply with exception workflow in the client. +// Similar to TestClientCallException, but using legacy client constructor. +func TestClientCallExceptionLegacy(t *testing.T) { + mockCtrl := gomock.NewController(t) + transport := thrift.NewTMemoryBuffer() + err := thrift.NewTTransportException(thrift.TIMED_OUT, "test") + for i := 0; ; i++ { + protocol := NewMockTProtocol(mockCtrl) + willComplete := !prepareClientCallException(protocol, i, err) + client := errortest.NewErrorTestClientProtocol(transport, protocol, protocol) _, retErr := client.TestString(defaultCtx, "test") mockCtrl.Finish() @@ -592,6 +688,36 @@ func TestClientCallException(t *testing.T) { // TestCase: Mismatching sequence id has been received in the client. func TestClientSeqIdMismatch(t *testing.T) { mockCtrl := gomock.NewController(t) + protocol := NewMockTProtocol(mockCtrl) + gomock.InOrder( + protocol.EXPECT().WriteMessageBegin("testString", thrift.CALL, int32(1)), + protocol.EXPECT().WriteStructBegin("testString_args"), + protocol.EXPECT().WriteFieldBegin("s", thrift.TType(thrift.STRING), int16(1)), + protocol.EXPECT().WriteString("test"), + protocol.EXPECT().WriteFieldEnd(), + protocol.EXPECT().WriteFieldStop(), + protocol.EXPECT().WriteStructEnd(), + protocol.EXPECT().WriteMessageEnd(), + protocol.EXPECT().Flush(), + protocol.EXPECT().ReadMessageBegin().Return("testString", thrift.REPLY, int32(2), nil), + ) + + client := errortest.NewErrorTestClient(thrift.NewTStandardClient(protocol, protocol)) + _, err := client.TestString(defaultCtx, "test") + mockCtrl.Finish() + appErr, ok := err.(thrift.TApplicationException) + if !ok { + t.Fatal("Expected TApplicationException") + } + if appErr.TypeId() != thrift.BAD_SEQUENCE_ID { + t.Fatal("Expected BAD_SEQUENCE_ID error") + } +} + +// TestCase: Mismatching sequence id has been received in the client. +// Similar to TestClientSeqIdMismatch, but using legacy client constructor. +func TestClientSeqIdMismatchLegeacy(t *testing.T) { + mockCtrl := gomock.NewController(t) transport := thrift.NewTMemoryBuffer() protocol := NewMockTProtocol(mockCtrl) gomock.InOrder( @@ -622,6 +748,36 @@ func TestClientSeqIdMismatch(t *testing.T) { // TestCase: Wrong method name has been received in the client. func TestClientWrongMethodName(t *testing.T) { mockCtrl := gomock.NewController(t) + protocol := NewMockTProtocol(mockCtrl) + gomock.InOrder( + protocol.EXPECT().WriteMessageBegin("testString", thrift.CALL, int32(1)), + protocol.EXPECT().WriteStructBegin("testString_args"), + protocol.EXPECT().WriteFieldBegin("s", thrift.TType(thrift.STRING), int16(1)), + protocol.EXPECT().WriteString("test"), + protocol.EXPECT().WriteFieldEnd(), + protocol.EXPECT().WriteFieldStop(), + protocol.EXPECT().WriteStructEnd(), + protocol.EXPECT().WriteMessageEnd(), + protocol.EXPECT().Flush(), + protocol.EXPECT().ReadMessageBegin().Return("unknown", thrift.REPLY, int32(1), nil), + ) + + client := errortest.NewErrorTestClient(thrift.NewTStandardClient(protocol, protocol)) + _, err := client.TestString(defaultCtx, "test") + mockCtrl.Finish() + appErr, ok := err.(thrift.TApplicationException) + if !ok { + t.Fatal("Expected TApplicationException") + } + if appErr.TypeId() != thrift.WRONG_METHOD_NAME { + t.Fatal("Expected WRONG_METHOD_NAME error") + } +} + +// TestCase: Wrong method name has been received in the client. +// Similar to TestClientWrongMethodName, but using legacy client constructor. +func TestClientWrongMethodNameLegacy(t *testing.T) { + mockCtrl := gomock.NewController(t) transport := thrift.NewTMemoryBuffer() protocol := NewMockTProtocol(mockCtrl) gomock.InOrder( @@ -652,6 +808,36 @@ func TestClientWrongMethodName(t *testing.T) { // TestCase: Wrong message type has been received in the client. func TestClientWrongMessageType(t *testing.T) { mockCtrl := gomock.NewController(t) + protocol := NewMockTProtocol(mockCtrl) + gomock.InOrder( + protocol.EXPECT().WriteMessageBegin("testString", thrift.CALL, int32(1)), + protocol.EXPECT().WriteStructBegin("testString_args"), + protocol.EXPECT().WriteFieldBegin("s", thrift.TType(thrift.STRING), int16(1)), + protocol.EXPECT().WriteString("test"), + protocol.EXPECT().WriteFieldEnd(), + protocol.EXPECT().WriteFieldStop(), + protocol.EXPECT().WriteStructEnd(), + protocol.EXPECT().WriteMessageEnd(), + protocol.EXPECT().Flush(), + protocol.EXPECT().ReadMessageBegin().Return("testString", thrift.INVALID_TMESSAGE_TYPE, int32(1), nil), + ) + + client := errortest.NewErrorTestClient(thrift.NewTStandardClient(protocol, protocol)) + _, err := client.TestString(defaultCtx, "test") + mockCtrl.Finish() + appErr, ok := err.(thrift.TApplicationException) + if !ok { + t.Fatal("Expected TApplicationException") + } + if appErr.TypeId() != thrift.INVALID_MESSAGE_TYPE_EXCEPTION { + t.Fatal("Expected INVALID_MESSAGE_TYPE_EXCEPTION error") + } +} + +// TestCase: Wrong message type has been received in the client. +// Similar to TestClientWrongMessageType, but using legacy client constructor. +func TestClientWrongMessageTypeLegacy(t *testing.T) { + mockCtrl := gomock.NewController(t) transport := thrift.NewTMemoryBuffer() protocol := NewMockTProtocol(mockCtrl) gomock.InOrder( diff --git a/lib/go/test/tests/multiplexed_protocol_test.go b/lib/go/test/tests/multiplexed_protocol_test.go index 27802e5a3..0b5896b60 100644 --- a/lib/go/test/tests/multiplexed_protocol_test.go +++ b/lib/go/test/tests/multiplexed_protocol_test.go @@ -36,15 +36,22 @@ func FindAvailableTCPServerPort() net.Addr { } } +func createTransport(addr net.Addr) (thrift.TTransport, error) { + socket := thrift.NewTSocketFromAddrTimeout(addr, TIMEOUT) + transport := thrift.NewTFramedTransport(socket) + err := transport.Open() + if err != nil { + return nil, err + } + return transport, nil +} -var processor = thrift.NewTMultiplexedProcessor() - -func TestInitTwoServers(t *testing.T) { - var err error +func TestMultiplexedProtocolFirst(t *testing.T) { + processor := thrift.NewTMultiplexedProcessor() protocolFactory := thrift.NewTBinaryProtocolFactoryDefault() transportFactory := thrift.NewTTransportFactory() transportFactory = thrift.NewTFramedTransportFactory(transportFactory) - addr = FindAvailableTCPServerPort() + addr := FindAvailableTCPServerPort() serverTransport, err := thrift.NewTServerSocketTimeout(addr.String(), TIMEOUT) if err != nil { t.Fatal("Unable to create server socket", err) @@ -57,82 +64,117 @@ func TestInitTwoServers(t *testing.T) { secondProcessor := multiplexedprotocoltest.NewSecondProcessor(&SecondImpl{}) processor.RegisterProcessor("SecondService", secondProcessor) + defer server.Stop() go server.Serve() time.Sleep(10 * time.Millisecond) -} - -var firstClient *multiplexedprotocoltest.FirstClient -func TestInitClient1(t *testing.T) { - socket := thrift.NewTSocketFromAddrTimeout(addr, TIMEOUT) - transport := thrift.NewTFramedTransport(socket) - var protocol thrift.TProtocol = thrift.NewTBinaryProtocolTransport(transport) - protocol = thrift.NewTMultiplexedProtocol(protocol, "FirstService") - firstClient = multiplexedprotocoltest.NewFirstClientProtocol(transport, protocol, protocol) - err := transport.Open() + transport, err := createTransport(addr) if err != nil { - t.Fatal("Unable to open client socket", err) + t.Fatal(err) } -} + defer transport.Close() + protocol := thrift.NewTMultiplexedProtocol(thrift.NewTBinaryProtocolTransport(transport), "FirstService") -var secondClient *multiplexedprotocoltest.SecondClient + client := multiplexedprotocoltest.NewFirstClient(thrift.NewTStandardClient(protocol, protocol)) -func TestInitClient2(t *testing.T) { - socket := thrift.NewTSocketFromAddrTimeout(addr, TIMEOUT) - transport := thrift.NewTFramedTransport(socket) - var protocol thrift.TProtocol = thrift.NewTBinaryProtocolTransport(transport) - protocol = thrift.NewTMultiplexedProtocol(protocol, "SecondService") - secondClient = multiplexedprotocoltest.NewSecondClientProtocol(transport, protocol, protocol) - err := transport.Open() + ret, err := client.ReturnOne(defaultCtx) if err != nil { - t.Fatal("Unable to open client socket", err) + t.Fatal("Unable to call first server:", err) + } else if ret != 1 { + t.Fatal("Unexpected result from server: ", ret) } } -//create client without service prefix -func createLegacyClient(t *testing.T) *multiplexedprotocoltest.SecondClient { - socket := thrift.NewTSocketFromAddrTimeout(addr, TIMEOUT) - transport := thrift.NewTFramedTransport(socket) - var protocol thrift.TProtocol = thrift.NewTBinaryProtocolTransport(transport) - legacyClient := multiplexedprotocoltest.NewSecondClientProtocol(transport, protocol, protocol) - err := transport.Open() +func TestMultiplexedProtocolSecond(t *testing.T) { + processor := thrift.NewTMultiplexedProcessor() + protocolFactory := thrift.NewTBinaryProtocolFactoryDefault() + transportFactory := thrift.NewTTransportFactory() + transportFactory = thrift.NewTFramedTransportFactory(transportFactory) + addr := FindAvailableTCPServerPort() + serverTransport, err := thrift.NewTServerSocketTimeout(addr.String(), TIMEOUT) if err != nil { - t.Fatal("Unable to open client socket", err) + t.Fatal("Unable to create server socket", err) } - return legacyClient -} + server = thrift.NewTSimpleServer4(processor, serverTransport, transportFactory, protocolFactory) -func TestCallFirst(t *testing.T) { - ret, err := firstClient.ReturnOne(defaultCtx) + firstProcessor := multiplexedprotocoltest.NewFirstProcessor(&FirstImpl{}) + processor.RegisterProcessor("FirstService", firstProcessor) + + secondProcessor := multiplexedprotocoltest.NewSecondProcessor(&SecondImpl{}) + processor.RegisterProcessor("SecondService", secondProcessor) + + defer server.Stop() + go server.Serve() + time.Sleep(10 * time.Millisecond) + + transport, err := createTransport(addr) if err != nil { - t.Fatal("Unable to call first server:", err) + t.Fatal(err) } - if ret != 1 { + defer transport.Close() + protocol := thrift.NewTMultiplexedProtocol(thrift.NewTBinaryProtocolTransport(transport), "SecondService") + + client := multiplexedprotocoltest.NewSecondClient(thrift.NewTStandardClient(protocol, protocol)) + + ret, err := client.ReturnTwo(defaultCtx) + if err != nil { + t.Fatal("Unable to call second server:", err) + } else if ret != 2 { t.Fatal("Unexpected result from server: ", ret) } } -func TestCallSecond(t *testing.T) { - ret, err := secondClient.ReturnTwo(defaultCtx) +func TestMultiplexedProtocolLegacy(t *testing.T) { + processor := thrift.NewTMultiplexedProcessor() + protocolFactory := thrift.NewTBinaryProtocolFactoryDefault() + transportFactory := thrift.NewTTransportFactory() + transportFactory = thrift.NewTFramedTransportFactory(transportFactory) + addr := FindAvailableTCPServerPort() + serverTransport, err := thrift.NewTServerSocketTimeout(addr.String(), TIMEOUT) if err != nil { - t.Fatal("Unable to call second server:", err) + t.Fatal("Unable to create server socket", err) } - if ret != 2 { - t.Fatal("Unexpected result from server: ", ret) + server = thrift.NewTSimpleServer4(processor, serverTransport, transportFactory, protocolFactory) + + firstProcessor := multiplexedprotocoltest.NewFirstProcessor(&FirstImpl{}) + processor.RegisterProcessor("FirstService", firstProcessor) + + secondProcessor := multiplexedprotocoltest.NewSecondProcessor(&SecondImpl{}) + processor.RegisterProcessor("SecondService", secondProcessor) + + defer server.Stop() + go server.Serve() + time.Sleep(10 * time.Millisecond) + + transport, err := createTransport(addr) + if err != nil { + t.Error(err) + return } -} + defer transport.Close() -func TestCallLegacy(t *testing.T) { - legacyClient := createLegacyClient(t) - ret, err := legacyClient.ReturnTwo(defaultCtx) + protocol := thrift.NewTBinaryProtocolTransport(transport) + client := multiplexedprotocoltest.NewSecondClient(thrift.NewTStandardClient(protocol, protocol)) + + ret, err := client.ReturnTwo(defaultCtx) //expect error since default processor is not registered if err == nil { t.Fatal("Expecting error") } + //register default processor and call again processor.RegisterDefault(multiplexedprotocoltest.NewSecondProcessor(&SecondImpl{})) - legacyClient = createLegacyClient(t) - ret, err = legacyClient.ReturnTwo(defaultCtx) + transport, err = createTransport(addr) + if err != nil { + t.Error(err) + return + } + defer transport.Close() + + protocol = thrift.NewTBinaryProtocolTransport(transport) + client = multiplexedprotocoltest.NewSecondClient(thrift.NewTStandardClient(protocol, protocol)) + + ret, err = client.ReturnTwo(defaultCtx) if err != nil { t.Fatal("Unable to call legacy server:", err) } @@ -140,9 +182,3 @@ func TestCallLegacy(t *testing.T) { t.Fatal("Unexpected result from server: ", ret) } } - -func TestShutdownServerAndClients(t *testing.T) { - firstClient.Transport.Close() - secondClient.Transport.Close() - server.Stop() -} diff --git a/lib/go/test/tests/one_way_test.go b/lib/go/test/tests/one_way_test.go index 32881e2a1..8abd671e6 100644 --- a/lib/go/test/tests/one_way_test.go +++ b/lib/go/test/tests/one_way_test.go @@ -59,7 +59,7 @@ func TestInitOneway(t *testing.T) { func TestInitOnewayClient(t *testing.T) { transport := thrift.NewTSocketFromAddrTimeout(addr, TIMEOUT) protocol := thrift.NewTBinaryProtocolTransport(transport) - client = onewaytest.NewOneWayClientProtocol(transport, protocol, protocol) + client = onewaytest.NewOneWayClient(thrift.NewTStandardClient(protocol, protocol)) err := transport.Open() if err != nil { t.Fatal("Unable to open client socket", err) diff --git a/lib/go/test/tests/protocol_mock.go b/lib/go/test/tests/protocol_mock.go index 9197fedab..8476c8661 100644 --- a/lib/go/test/tests/protocol_mock.go +++ b/lib/go/test/tests/protocol_mock.go @@ -24,6 +24,7 @@ package tests import ( thrift "thrift" + gomock "github.com/golang/mock/gomock" ) diff --git a/lib/go/test/tests/protocols_test.go b/lib/go/test/tests/protocols_test.go index 1580678eb..cffd9c3f7 100644 --- a/lib/go/test/tests/protocols_test.go +++ b/lib/go/test/tests/protocols_test.go @@ -47,7 +47,7 @@ func RunSocketTestSuite(t *testing.T, protocolFactory thrift.TProtocolFactory, t.Fatal(err) } var protocol thrift.TProtocol = protocolFactory.GetProtocol(transport) - thriftTestClient := thrifttest.NewThriftTestClientProtocol(transport, protocol, protocol) + thriftTestClient := thrifttest.NewThriftTestClient(thrift.NewTStandardClient(protocol, protocol)) err = transport.Open() if err != nil { t.Fatal("Unable to open client socket", err) diff --git a/lib/go/thrift/application_exception.go b/lib/go/thrift/application_exception.go index 525bce22b..b9d7eedcd 100644 --- a/lib/go/thrift/application_exception.go +++ b/lib/go/thrift/application_exception.go @@ -45,7 +45,7 @@ var defaultApplicationExceptionMessage = map[int32]string{ type TApplicationException interface { TException TypeId() int32 - Read(iprot TProtocol) (TApplicationException, error) + Read(iprot TProtocol) error Write(oprot TProtocol) error } @@ -69,10 +69,11 @@ func (p *tApplicationException) TypeId() int32 { return p.type_ } -func (p *tApplicationException) Read(iprot TProtocol) (TApplicationException, error) { +func (p *tApplicationException) Read(iprot TProtocol) error { + // TODO: this should really be generated by the compiler _, err := iprot.ReadStructBegin() if err != nil { - return nil, err + return err } message := "" @@ -81,7 +82,7 @@ func (p *tApplicationException) Read(iprot TProtocol) (TApplicationException, er for { _, ttype, id, err := iprot.ReadFieldBegin() if err != nil { - return nil, err + return err } if ttype == STOP { break @@ -90,33 +91,40 @@ func (p *tApplicationException) Read(iprot TProtocol) (TApplicationException, er case 1: if ttype == STRING { if message, err = iprot.ReadString(); err != nil { - return nil, err + return err } } else { if err = SkipDefaultDepth(iprot, ttype); err != nil { - return nil, err + return err } } case 2: if ttype == I32 { if type_, err = iprot.ReadI32(); err != nil { - return nil, err + return err } } else { if err = SkipDefaultDepth(iprot, ttype); err != nil { - return nil, err + return err } } default: if err = SkipDefaultDepth(iprot, ttype); err != nil { - return nil, err + return err } } if err = iprot.ReadFieldEnd(); err != nil { - return nil, err + return err } } - return NewTApplicationException(type_, message), iprot.ReadStructEnd() + if err := iprot.ReadStructEnd(); err != nil { + return err + } + + p.message = message + p.type_ = type_ + + return nil } func (p *tApplicationException) Write(oprot TProtocol) (err error) { diff --git a/lib/go/thrift/client.go b/lib/go/thrift/client.go new file mode 100644 index 000000000..8bdb53d8d --- /dev/null +++ b/lib/go/thrift/client.go @@ -0,0 +1,78 @@ +package thrift + +import "fmt" + +type TStandardClient struct { + seqId int32 + iprot, oprot TProtocol +} + +// TStandardClient implements TClient, and uses the standard message format for Thrift. +// It is not safe for concurrent use. +func NewTStandardClient(inputProtocol, outputProtocol TProtocol) *TStandardClient { + return &TStandardClient{ + iprot: inputProtocol, + oprot: outputProtocol, + } +} + +func (p *TStandardClient) Send(oprot TProtocol, seqId int32, method string, args TStruct) error { + if err := oprot.WriteMessageBegin(method, CALL, seqId); err != nil { + return err + } + if err := args.Write(oprot); err != nil { + return err + } + if err := oprot.WriteMessageEnd(); err != nil { + return err + } + return oprot.Flush() +} + +func (p *TStandardClient) Recv(iprot TProtocol, seqId int32, method string, result TStruct) error { + rMethod, rTypeId, rSeqId, err := iprot.ReadMessageBegin() + if err != nil { + return err + } + + if method != rMethod { + return NewTApplicationException(WRONG_METHOD_NAME, fmt.Sprintf("%s: wrong method name", method)) + } else if seqId != rSeqId { + return NewTApplicationException(BAD_SEQUENCE_ID, fmt.Sprintf("%s: out of order sequence response", method)) + } else if rTypeId == EXCEPTION { + var exception tApplicationException + if err := exception.Read(iprot); err != nil { + return err + } + + if err := iprot.ReadMessageEnd(); err != nil { + return err + } + + return &exception + } else if rTypeId != REPLY { + return NewTApplicationException(INVALID_MESSAGE_TYPE_EXCEPTION, fmt.Sprintf("%s: invalid message type", method)) + } + + if err := result.Read(iprot); err != nil { + return err + } + + return iprot.ReadMessageEnd() +} + +func (p *TStandardClient) call(method string, args, result TStruct) error { + p.seqId++ + seqId := p.seqId + + if err := p.Send(p.oprot, seqId, method, args); err != nil { + return err + } + + // method is oneway + if result == nil { + return nil + } + + return p.Recv(p.iprot, seqId, method, result) +} diff --git a/lib/go/thrift/client_go17.go b/lib/go/thrift/client_go17.go new file mode 100644 index 000000000..15c1c52ca --- /dev/null +++ b/lib/go/thrift/client_go17.go @@ -0,0 +1,13 @@ +// +build go1.7 + +package thrift + +import "context" + +type TClient interface { + Call(ctx context.Context, method string, args, result TStruct) error +} + +func (p *TStandardClient) Call(ctx context.Context, method string, args, result TStruct) error { + return p.call(method, args, result) +} diff --git a/lib/go/thrift/client_pre_go17.go b/lib/go/thrift/client_pre_go17.go new file mode 100644 index 000000000..d2e99ef2a --- /dev/null +++ b/lib/go/thrift/client_pre_go17.go @@ -0,0 +1,13 @@ +// +build !go1.7 + +package thrift + +import "golang.org/x/net/context" + +type TClient interface { + Call(ctx context.Context, method string, args, result TStruct) error +} + +func (p *TStandardClient) Call(ctx context.Context, method string, args, result TStruct) error { + return p.call(method, args, result) +} diff --git a/test/go/Makefile.am b/test/go/Makefile.am index db2725875..6bc97f582 100644 --- a/test/go/Makefile.am +++ b/test/go/Makefile.am @@ -30,6 +30,8 @@ precross: bin/testclient bin/testserver ThriftTest.thrift: $(THRIFTTEST) grep -v list.*map.*list.*map $(THRIFTTEST) > ThriftTest.thrift +.PHONY: gopath + # Thrift for GO has problems with complex map keys: THRIFT-2063 gopath: $(THRIFT) ThriftTest.thrift mkdir -p src/gen diff --git a/test/go/src/bin/testclient/main.go b/test/go/src/bin/testclient/main.go index b34c53927..ab24cbfc7 100644 --- a/test/go/src/bin/testclient/main.go +++ b/test/go/src/bin/testclient/main.go @@ -38,7 +38,7 @@ var testloops = flag.Int("testloops", 1, "Number of Tests") func main() { flag.Parse() - client, err := common.StartClient(*host, *port, *domain_socket, *transport, *protocol, *ssl) + client, _, err := common.StartClient(*host, *port, *domain_socket, *transport, *protocol, *ssl) if err != nil { t.Fatalf("Unable to start client: ", err) } @@ -128,7 +128,7 @@ func callEverything(client *thrifttest.ThriftTestClient) { } bin, err := client.TestBinary(defaultCtx, binout) for i := 0; i < 256; i++ { - if (binout[i] != bin[i]) { + if binout[i] != bin[i] { t.Fatalf("Unexpected TestBinary() result expected %d, got %d ", binout[i], bin[i]) } } @@ -224,21 +224,21 @@ func callEverything(client *thrifttest.ThriftTestClient) { } crazy := thrifttest.NewInsanity() - crazy.UserMap = map[thrifttest.Numberz]thrifttest.UserId { - thrifttest.Numberz_FIVE: 5, + crazy.UserMap = map[thrifttest.Numberz]thrifttest.UserId{ + thrifttest.Numberz_FIVE: 5, thrifttest.Numberz_EIGHT: 8, } truck1 := thrifttest.NewXtruct() truck1.StringThing = "Goodbye4" - truck1.ByteThing = 4; - truck1.I32Thing = 4; - truck1.I64Thing = 4; + truck1.ByteThing = 4 + truck1.I32Thing = 4 + truck1.I64Thing = 4 truck2 := thrifttest.NewXtruct() truck2.StringThing = "Hello2" - truck2.ByteThing = 2; - truck2.I32Thing = 2; - truck2.I64Thing = 2; - crazy.Xtructs = []*thrifttest.Xtruct { + truck2.ByteThing = 2 + truck2.I32Thing = 2 + truck2.I64Thing = 2 + crazy.Xtructs = []*thrifttest.Xtruct{ truck1, truck2, } @@ -248,17 +248,17 @@ func callEverything(client *thrifttest.ThriftTestClient) { } if !reflect.DeepEqual(crazy, insanity[1][2]) { t.Fatalf("Unexpected TestInsanity() first result expected %#v, got %#v ", - crazy, - insanity[1][2]) + crazy, + insanity[1][2]) } if !reflect.DeepEqual(crazy, insanity[1][3]) { t.Fatalf("Unexpected TestInsanity() second result expected %#v, got %#v ", - crazy, - insanity[1][3]) + crazy, + insanity[1][3]) } if len(insanity[2][6].UserMap) > 0 || len(insanity[2][6].Xtructs) > 0 { t.Fatalf("Unexpected TestInsanity() non-empty result got %#v ", - insanity[2][6]) + insanity[2][6]) } xxsret, err := client.TestMulti(defaultCtx, 42, 4242, 424242, map[int16]string{1: "blah", 2: "thing"}, thrifttest.Numberz_EIGHT, thrifttest.UserId(24)) diff --git a/test/go/src/common/client.go b/test/go/src/common/client.go index 4251d910d..236ce43ea 100644 --- a/test/go/src/common/client.go +++ b/test/go/src/common/client.go @@ -41,7 +41,7 @@ func StartClient( domain_socket string, transport string, protocol string, - ssl bool) (client *thrifttest.ThriftTestClient, err error) { + ssl bool) (client *thrifttest.ThriftTestClient, trans thrift.TTransport, err error) { hostPort := fmt.Sprintf("%s:%d", host, port) @@ -56,12 +56,11 @@ func StartClient( case "binary": protocolFactory = thrift.NewTBinaryProtocolFactoryDefault() default: - return nil, fmt.Errorf("Invalid protocol specified %s", protocol) + return nil, nil, fmt.Errorf("Invalid protocol specified %s", protocol) } if debugClientProtocol { protocolFactory = thrift.NewTDebugProtocolFactory(protocolFactory, "client:") } - var trans thrift.TTransport if ssl { trans, err = thrift.NewTSSLSocket(hostPort, &tls.Config{InsecureSkipVerify: true}) } else { @@ -72,7 +71,7 @@ func StartClient( } } if err != nil { - return nil, err + return nil, nil, err } switch transport { case "http": @@ -86,29 +85,25 @@ func StartClient( } else { trans, err = thrift.NewTHttpPostClient(fmt.Sprintf("http://%s/", hostPort)) } - - if err != nil { - return nil, err - } - case "framed": trans = thrift.NewTFramedTransport(trans) case "buffered": trans = thrift.NewTBufferedTransport(trans, 8192) case "zlib": trans, err = thrift.NewTZlibTransport(trans, zlib.BestCompression) - if err != nil { - return nil, err - } case "": trans = trans default: - return nil, fmt.Errorf("Invalid transport specified %s", transport) + return nil, nil, fmt.Errorf("Invalid transport specified %s", transport) + } + if err != nil { + return nil, nil, err } - if err = trans.Open(); err != nil { - return nil, err + return nil, nil, err } - client = thrifttest.NewThriftTestClientFactory(trans, protocolFactory) + iprot := protocolFactory.GetProtocol(trans) + oprot := protocolFactory.GetProtocol(trans) + client = thrifttest.NewThriftTestClient(thrift.NewTStandardClient(iprot, oprot)) return } diff --git a/test/go/src/common/clientserver_test.go b/test/go/src/common/clientserver_test.go index ecd021f3b..c4cfd44f3 100644 --- a/test/go/src/common/clientserver_test.go +++ b/test/go/src/common/clientserver_test.go @@ -23,6 +23,7 @@ import ( "errors" "gen/thrifttest" "reflect" + "sync" "testing" "thrift" @@ -47,10 +48,15 @@ var units = []test_unit{ func TestAllConnection(t *testing.T) { certPath = "../../../keys" + wg := &sync.WaitGroup{} + wg.Add(len(units)) for _, unit := range units { - t.Logf("%#v", unit) - doUnit(t, &unit) + go func(u test_unit) { + defer wg.Done() + doUnit(t, &u) + }(unit) } + wg.Wait() } func doUnit(t *testing.T, unit *test_unit) { @@ -62,17 +68,17 @@ func doUnit(t *testing.T, unit *test_unit) { server := thrift.NewTSimpleServer4(processor, serverTransport, transportFactory, protocolFactory) if err = server.Listen(); err != nil { - t.Errorf("Unable to start server", err) - t.FailNow() + t.Errorf("Unable to start server: %v", err) + return } go server.AcceptLoop() defer server.Stop() - client, err := StartClient(unit.host, unit.port, unit.domain_socket, unit.transport, unit.protocol, unit.ssl) + client, trans, err := StartClient(unit.host, unit.port, unit.domain_socket, unit.transport, unit.protocol, unit.ssl) if err != nil { - t.Errorf("Unable to start client", err) - t.FailNow() + t.Errorf("Unable to start client: %v", err) + return } - defer client.Transport.Close() + defer trans.Close() callEverythingWithMock(t, client, handler) } @@ -273,7 +279,7 @@ func callEverythingWithMock(t *testing.T, client *thrifttest.ThriftTestClient, h xxsret, err := client.TestMulti(defaultCtx, 42, 4242, 424242, map[int16]string{1: "blah", 2: "thing"}, thrifttest.Numberz_EIGHT, thrifttest.UserId(24)) if err != nil { - t.Errorf("Unexpected error in TestMulti() call: ", err) + t.Errorf("Unexpected error in TestMulti() call: %v", err) } if !reflect.DeepEqual(xxs, xxsret) { t.Errorf("Unexpected TestMulti() result expected %#v, got %#v ", xxs, xxsret) @@ -289,9 +295,12 @@ func callEverythingWithMock(t *testing.T, client *thrifttest.ThriftTestClient, h // TODO: connection is being closed on this err = client.TestException(defaultCtx, "TException") - tex, ok := err.(thrift.TApplicationException) - if err == nil || !ok || tex.TypeId() != thrift.INTERNAL_ERROR { - t.Errorf("Unexpected TestException() result expected ApplicationError, got %#v ", err) + if err == nil { + t.Error("expected exception got nil") + } else if tex, ok := err.(thrift.TApplicationException); !ok { + t.Errorf("Unexpected TestException() result expected ApplicationError, got %T ", err) + } else if tex.TypeId() != thrift.INTERNAL_ERROR { + t.Errorf("expected internal_error got %v", tex.TypeId()) } ign, err := client.TestMultiException(defaultCtx, "Xception", "ignoreme") diff --git a/tutorial/go/src/client.go b/tutorial/go/src/client.go index 65027ea05..25616bf4e 100644 --- a/tutorial/go/src/client.go +++ b/tutorial/go/src/client.go @@ -22,8 +22,9 @@ package main import ( "crypto/tls" "fmt" - "git.apache.org/thrift.git/lib/go/thrift" "tutorial" + + "git.apache.org/thrift.git/lib/go/thrift" ) func handleClient(client *tutorial.CalculatorClient) (err error) { @@ -98,5 +99,7 @@ func runClient(transportFactory thrift.TTransportFactory, protocolFactory thrift if err := transport.Open(); err != nil { return err } - return handleClient(tutorial.NewCalculatorClientFactory(transport, protocolFactory)) + iprot := protocolFactory.GetProtocol(transport) + oprot := protocolFactory.GetProtocol(transport) + return handleClient(tutorial.NewCalculatorClient(thrift.NewTStandardClient(iprot, oprot))) } |