diff options
author | Yuxuan 'fishy' Wang <yuxuan.wang@reddit.com> | 2020-12-22 09:53:58 -0800 |
---|---|---|
committer | Yuxuan 'fishy' Wang <fishywang@gmail.com> | 2021-01-17 15:19:44 -0800 |
commit | d831230929bb332189c9509d07102e4be9e7f681 (patch) | |
tree | c1989efa6dc2dd66f0ba7afdf9682b9e764145c4 | |
parent | c4d1c0d80067986dbee124887bcb402ee1c6538e (diff) | |
download | thrift-d831230929bb332189c9509d07102e4be9e7f681.tar.gz |
THRIFT-5326: Expand TException interface in go library
Client: go
Add TExceptionType enum type, and add
TExceptionType() TExceptionType
function to TException definition.
Also make TProtocolException unwrap-able.
-rw-r--r-- | CHANGES.md | 1 | ||||
-rw-r--r-- | compiler/cpp/src/thrift/generate/t_go_generator.cc | 33 | ||||
-rw-r--r-- | lib/go/test/tests/thrifttest_handler.go | 2 | ||||
-rw-r--r-- | lib/go/thrift/application_exception.go | 6 | ||||
-rw-r--r-- | lib/go/thrift/compact_protocol.go | 2 | ||||
-rw-r--r-- | lib/go/thrift/exception.go | 81 | ||||
-rw-r--r-- | lib/go/thrift/multiplexed_protocol.go | 14 | ||||
-rw-r--r-- | lib/go/thrift/protocol_exception.go | 33 | ||||
-rw-r--r-- | lib/go/thrift/simple_server.go | 18 | ||||
-rw-r--r-- | lib/go/thrift/transport_exception.go | 6 | ||||
-rw-r--r-- | lib/go/thrift/transport_exception_test.go | 4 |
11 files changed, 162 insertions, 38 deletions
diff --git a/CHANGES.md b/CHANGES.md index 663c4c18c..8e4d08edd 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -17,6 +17,7 @@ - [THRIFT-5233](https://issues.apache.org/jira/browse/THRIFT-5233) - go: Now all Read*, Write* and Skip functions in TProtocol accept context arg - [THRIFT-5152](https://issues.apache.org/jira/browse/THRIFT-5152) - go: TSocket and TSSLSocket now have separated connect timeout and socket timeout - c++: dropped support for Windows XP +- [THRIFT-5326](https://issues.apache.org/jira/browse/THRIFT-5326) - go: TException interface now has a new function: TExceptionType ### Java diff --git a/compiler/cpp/src/thrift/generate/t_go_generator.cc b/compiler/cpp/src/thrift/generate/t_go_generator.cc index 3bb2a5cf2..49d8bc119 100644 --- a/compiler/cpp/src/thrift/generate/t_go_generator.cc +++ b/compiler/cpp/src/thrift/generate/t_go_generator.cc @@ -1493,8 +1493,15 @@ void t_go_generator::generate_go_struct_definition(ostream& out, if (is_exception) { out << indent() << "func (p *" << tstruct_name << ") Error() string {" << endl; - out << indent() << " return p.String()" << endl; + out << indent() << indent() << "return p.String()" << endl; out << indent() << "}" << endl << endl; + + out << indent() << "func (" << tstruct_name << ") TExceptionType() thrift.TExceptionType {" << endl; + out << indent() << indent() << "return thrift.TExceptionTypeCompiled" << endl; + out << indent() << "}" << endl << endl; + + out << indent() << "var _ thrift.TException = (*" << tstruct_name << ")(nil)" + << endl << endl; } } @@ -2700,8 +2707,8 @@ void t_go_generator::generate_service_server(t_service* tservice) { f_types_ << indent() << "func (p *" << serviceName << "Processor) Process(ctx context.Context, iprot, oprot thrift.TProtocol) (success bool, err " "thrift.TException) {" << endl; - f_types_ << indent() << " name, _, seqId, err := iprot.ReadMessageBegin(ctx)" << endl; - f_types_ << indent() << " if err != nil { return false, err }" << endl; + f_types_ << indent() << " name, _, seqId, err2 := iprot.ReadMessageBegin(ctx)" << endl; + f_types_ << indent() << " if err2 != nil { return false, thrift.WrapTException(err2) }" << endl; f_types_ << indent() << " if processor, ok := p.GetProcessorFunction(name); ok {" << endl; f_types_ << indent() << " return processor.Process(ctx, seqId, iprot, oprot)" << endl; f_types_ << indent() << " }" << endl; @@ -2767,11 +2774,12 @@ void t_go_generator::generate_process_function(t_service* tservice, t_function* "thrift.TException) {" << endl; indent_up(); f_types_ << indent() << "args := " << argsname << "{}" << endl; - f_types_ << indent() << "if err = args." << read_method_name_ << "(ctx, iprot); err != nil {" << endl; + f_types_ << indent() << "var err2 error" << endl; + f_types_ << indent() << "if err2 = args." << read_method_name_ << "(ctx, iprot); err2 != nil {" << endl; f_types_ << indent() << " iprot.ReadMessageEnd(ctx)" << endl; if (!tfunction->is_oneway()) { f_types_ << indent() - << " x := thrift.NewTApplicationException(thrift.PROTOCOL_ERROR, err.Error())" + << " x := thrift.NewTApplicationException(thrift.PROTOCOL_ERROR, err2.Error())" << endl; f_types_ << indent() << " oprot.WriteMessageBegin(ctx, \"" << escape_string(tfunction->get_name()) << "\", thrift.EXCEPTION, seqId)" << endl; @@ -2779,7 +2787,7 @@ void t_go_generator::generate_process_function(t_service* tservice, t_function* f_types_ << indent() << " oprot.WriteMessageEnd(ctx)" << endl; f_types_ << indent() << " oprot.Flush(ctx)" << endl; } - f_types_ << indent() << " return false, err" << endl; + f_types_ << indent() << " return false, thrift.WrapTException(err2)" << endl; f_types_ << indent() << "}" << endl; f_types_ << indent() << "iprot.ReadMessageEnd(ctx)" << endl << endl; @@ -2842,7 +2850,6 @@ void t_go_generator::generate_process_function(t_service* tservice, t_function* f_types_ << indent() << "var retval " << type_to_go_type(tfunction->get_returntype()) << endl; } - f_types_ << indent() << "var err2 error" << endl; f_types_ << indent() << "if "; if (!tfunction->is_oneway()) { @@ -2892,7 +2899,7 @@ void t_go_generator::generate_process_function(t_service* tservice, t_function* if (!tfunction->is_oneway()) { // Avoid writing the error to the wire if it's ErrAbandonRequest f_types_ << indent() << " if err2 == thrift.ErrAbandonRequest {" << endl; - f_types_ << indent() << " return false, err2" << endl; + f_types_ << indent() << " return false, thrift.WrapTException(err2)" << endl; f_types_ << indent() << " }" << endl; f_types_ << indent() << " x := thrift.NewTApplicationException(thrift.INTERNAL_ERROR, " @@ -2905,7 +2912,7 @@ void t_go_generator::generate_process_function(t_service* tservice, t_function* f_types_ << indent() << " oprot.Flush(ctx)" << endl; } - f_types_ << indent() << " return true, err2" << endl; + f_types_ << indent() << " return true, thrift.WrapTException(err2)" << endl; if (!x_fields.empty()) { f_types_ << indent() << "}" << endl; @@ -2931,17 +2938,17 @@ void t_go_generator::generate_process_function(t_service* tservice, t_function* f_types_ << indent() << "if err2 = oprot.WriteMessageBegin(ctx, \"" << escape_string(tfunction->get_name()) << "\", thrift.REPLY, seqId); err2 != nil {" << endl; - f_types_ << indent() << " err = err2" << endl; + f_types_ << indent() << " err = thrift.WrapTException(err2)" << endl; f_types_ << indent() << "}" << endl; f_types_ << indent() << "if err2 = result." << write_method_name_ << "(ctx, oprot); err == nil && err2 != nil {" << endl; - f_types_ << indent() << " err = err2" << endl; + f_types_ << indent() << " err = thrift.WrapTException(err2)" << endl; f_types_ << indent() << "}" << endl; f_types_ << indent() << "if err2 = oprot.WriteMessageEnd(ctx); err == nil && err2 != nil {" << endl; - f_types_ << indent() << " err = err2" << endl; + f_types_ << indent() << " err = thrift.WrapTException(err2)" << endl; f_types_ << indent() << "}" << endl; f_types_ << indent() << "if err2 = oprot.Flush(ctx); err == nil && err2 != nil {" << endl; - f_types_ << indent() << " err = err2" << endl; + f_types_ << indent() << " err = thrift.WrapTException(err2)" << endl; f_types_ << indent() << "}" << endl; f_types_ << indent() << "if err != nil {" << endl; f_types_ << indent() << " return" << endl; diff --git a/lib/go/test/tests/thrifttest_handler.go b/lib/go/test/tests/thrifttest_handler.go index 31b9ee23e..7b115ec40 100644 --- a/lib/go/test/tests/thrifttest_handler.go +++ b/lib/go/test/tests/thrifttest_handler.go @@ -179,7 +179,7 @@ func (p *ThriftTestHandler) TestException(ctx context.Context, arg string) (err x.Message = arg return x } else if arg == "TException" { - return thrift.TException(errors.New(arg)) + return thrift.WrapTException(errors.New(arg)) } else { return nil } diff --git a/lib/go/thrift/application_exception.go b/lib/go/thrift/application_exception.go index 6de37ee73..32d5b0147 100644 --- a/lib/go/thrift/application_exception.go +++ b/lib/go/thrift/application_exception.go @@ -64,6 +64,12 @@ type tApplicationException struct { type_ int32 } +var _ TApplicationException = (*tApplicationException)(nil) + +func (tApplicationException) TExceptionType() TExceptionType { + return TExceptionTypeApplication +} + func (e tApplicationException) Error() string { if e.message != "" { return e.message diff --git a/lib/go/thrift/compact_protocol.go b/lib/go/thrift/compact_protocol.go index 25e6d0ccd..a49225dab 100644 --- a/lib/go/thrift/compact_protocol.go +++ b/lib/go/thrift/compact_protocol.go @@ -845,7 +845,7 @@ func (p *TCompactProtocol) getTType(t tCompactType) (TType, error) { case COMPACT_STRUCT: return STRUCT, nil } - return STOP, TException(fmt.Errorf("don't know what type: %v", t&0x0f)) + return STOP, NewTProtocolException(fmt.Errorf("don't know what type: %v", t&0x0f)) } // Given a TType value, find the appropriate TCompactProtocol.Types constant. diff --git a/lib/go/thrift/exception.go b/lib/go/thrift/exception.go index ea8d6f661..b6885face 100644 --- a/lib/go/thrift/exception.go +++ b/lib/go/thrift/exception.go @@ -26,19 +26,86 @@ import ( // Generic Thrift exception type TException interface { error + + TExceptionType() TExceptionType } // Prepends additional information to an error without losing the Thrift exception interface func PrependError(prepend string, err error) error { - if t, ok := err.(TTransportException); ok { - return NewTTransportException(t.TypeId(), prepend+t.Error()) + msg := prepend + err.Error() + + if te, ok := err.(TException); ok { + switch te.TExceptionType() { + case TExceptionTypeTransport: + if t, ok := err.(TTransportException); ok { + return NewTTransportException(t.TypeId(), msg) + } + case TExceptionTypeProtocol: + if t, ok := err.(TProtocolException); ok { + return NewTProtocolExceptionWithType(t.TypeId(), errors.New(msg)) + } + case TExceptionTypeApplication: + if t, ok := err.(TApplicationException); ok { + return NewTApplicationException(t.TypeId(), msg) + } + } + + return wrappedTException{ + err: errors.New(msg), + tExceptionType: te.TExceptionType(), + } + } + + return errors.New(msg) +} + +// TExceptionType is an enum type to categorize different "subclasses" of TExceptions. +type TExceptionType byte + +// TExceptionType values +const ( + TExceptionTypeUnknown TExceptionType = iota + TExceptionTypeCompiled // TExceptions defined in thrift files and generated by thrift compiler + TExceptionTypeApplication // TApplicationExceptions + TExceptionTypeProtocol // TProtocolExceptions + TExceptionTypeTransport // TTransportExceptions +) + +// WrapTException wraps an error into TException. +// +// If err is nil or already TException, it's returned as-is. +// Otherwise it will be wraped into TException with TExceptionType() returning +// TExceptionTypeUnknown, and Unwrap() returning the original error. +func WrapTException(err error) TException { + if err == nil { + return nil } - if t, ok := err.(TProtocolException); ok { - return NewTProtocolExceptionWithType(t.TypeId(), errors.New(prepend+err.Error())) + + if te, ok := err.(TException); ok { + return te } - if t, ok := err.(TApplicationException); ok { - return NewTApplicationException(t.TypeId(), prepend+t.Error()) + + return wrappedTException{ + err: err, + tExceptionType: TExceptionTypeUnknown, } +} + +type wrappedTException struct { + err error + tExceptionType TExceptionType +} - return errors.New(prepend + err.Error()) +func (w wrappedTException) Error() string { + return w.err.Error() } + +func (w wrappedTException) TExceptionType() TExceptionType { + return w.tExceptionType +} + +func (w wrappedTException) Unwrap() error { + return w.err +} + +var _ TException = wrappedTException{} diff --git a/lib/go/thrift/multiplexed_protocol.go b/lib/go/thrift/multiplexed_protocol.go index 2f7997e77..cacbf6bef 100644 --- a/lib/go/thrift/multiplexed_protocol.go +++ b/lib/go/thrift/multiplexed_protocol.go @@ -192,10 +192,10 @@ func (t *TMultiplexedProcessor) RegisterProcessor(name string, processor TProces func (t *TMultiplexedProcessor) Process(ctx context.Context, in, out TProtocol) (bool, TException) { name, typeId, seqid, err := in.ReadMessageBegin(ctx) if err != nil { - return false, err + return false, NewTProtocolException(err) } if typeId != CALL && typeId != ONEWAY { - return false, fmt.Errorf("Unexpected message type %v", typeId) + return false, NewTProtocolException(fmt.Errorf("Unexpected message type %v", typeId)) } //extract the service name v := strings.SplitN(name, MULTIPLEXED_SEPARATOR, 2) @@ -204,11 +204,17 @@ func (t *TMultiplexedProcessor) Process(ctx context.Context, in, out TProtocol) smb := NewStoredMessageProtocol(in, name, typeId, seqid) return t.DefaultProcessor.Process(ctx, smb, out) } - return false, fmt.Errorf("Service name not found in message name: %s. Did you forget to use a TMultiplexProtocol in your client?", name) + return false, NewTProtocolException(fmt.Errorf( + "Service name not found in message name: %s. Did you forget to use a TMultiplexProtocol in your client?", + name, + )) } actualProcessor, ok := t.serviceProcessorMap[v[0]] if !ok { - return false, fmt.Errorf("Service name not found: %s. Did you forget to call registerProcessor()?", v[0]) + return false, NewTProtocolException(fmt.Errorf( + "Service name not found: %s. Did you forget to call registerProcessor()?", + v[0], + )) } smb := NewStoredMessageProtocol(in, v[1], typeId, seqid) return actualProcessor.Process(ctx, smb, out) diff --git a/lib/go/thrift/protocol_exception.go b/lib/go/thrift/protocol_exception.go index 29ab75d92..b088caf13 100644 --- a/lib/go/thrift/protocol_exception.go +++ b/lib/go/thrift/protocol_exception.go @@ -40,8 +40,14 @@ const ( ) type tProtocolException struct { - typeId int - message string + typeId int + err error +} + +var _ TProtocolException = (*tProtocolException)(nil) + +func (tProtocolException) TExceptionType() TExceptionType { + return TExceptionTypeProtocol } func (p *tProtocolException) TypeId() int { @@ -49,11 +55,15 @@ func (p *tProtocolException) TypeId() int { } func (p *tProtocolException) String() string { - return p.message + return p.err.Error() } func (p *tProtocolException) Error() string { - return p.message + return p.err.Error() +} + +func (p *tProtocolException) Unwrap() error { + return p.err } func NewTProtocolException(err error) TProtocolException { @@ -64,14 +74,23 @@ func NewTProtocolException(err error) TProtocolException { return e } if _, ok := err.(base64.CorruptInputError); ok { - return &tProtocolException{INVALID_DATA, err.Error()} + return &tProtocolException{ + typeId: INVALID_DATA, + err: err, + } + } + return &tProtocolException{ + typeId: UNKNOWN_PROTOCOL_EXCEPTION, + err: err, } - return &tProtocolException{UNKNOWN_PROTOCOL_EXCEPTION, err.Error()} } func NewTProtocolExceptionWithType(errType int, err error) TProtocolException { if err == nil { return nil } - return &tProtocolException{errType, err.Error()} + return &tProtocolException{ + typeId: errType, + err: err, + } } diff --git a/lib/go/thrift/simple_server.go b/lib/go/thrift/simple_server.go index e9fea86d2..ca0e61d0e 100644 --- a/lib/go/thrift/simple_server.go +++ b/lib/go/thrift/simple_server.go @@ -315,7 +315,9 @@ func (p *TSimpleServer) processRequests(client TTransport) (err error) { } ok, err := processor.Process(ctx, inputProtocol, outputProtocol) - if err == ErrAbandonRequest { + // Once we dropped support for pre-go1.13 this can be replaced by: + // errors.Is(err, ErrAbandonRequest) + if unwrapError(err) == ErrAbandonRequest { return client.Close() } if _, ok := err.(TTransportException); ok && err != nil { @@ -330,3 +332,17 @@ func (p *TSimpleServer) processRequests(client TTransport) (err error) { } return nil } + +type unwrapper interface { + Unwrap() error +} + +func unwrapError(err error) error { + for { + if u, ok := err.(unwrapper); ok { + err = u.Unwrap() + } else { + return err + } + } +} diff --git a/lib/go/thrift/transport_exception.go b/lib/go/thrift/transport_exception.go index 16193ee86..cf2cc0059 100644 --- a/lib/go/thrift/transport_exception.go +++ b/lib/go/thrift/transport_exception.go @@ -48,6 +48,12 @@ type tTransportException struct { err error } +var _ TTransportException = (*tTransportException)(nil) + +func (tTransportException) TExceptionType() TExceptionType { + return TExceptionTypeTransport +} + func (p *tTransportException) TypeId() int { return p.typeId } diff --git a/lib/go/thrift/transport_exception_test.go b/lib/go/thrift/transport_exception_test.go index fb1dc2602..57386cb28 100644 --- a/lib/go/thrift/transport_exception_test.go +++ b/lib/go/thrift/transport_exception_test.go @@ -36,10 +36,6 @@ func (t *timeout) Error() string { return fmt.Sprintf("Timeout: %v", t.timedout) } -type unwrapper interface { - Unwrap() error -} - func TestTExceptionTimeout(t *testing.T) { timeout := &timeout{true} exception := NewTTransportExceptionFromError(timeout) |