diff options
author | Yuxuan 'fishy' Wang <yuxuan.wang@reddit.com> | 2020-09-21 12:33:26 -0700 |
---|---|---|
committer | Yuxuan 'fishy' Wang <fishywang@gmail.com> | 2020-09-22 16:48:38 -0700 |
commit | a2c44665b416522477cffa6752c2f323768d0507 (patch) | |
tree | 9942f7d24637c0ee3b4e51b8c8e26e6b3f56c640 | |
parent | 03f4729f7c19abb206fae439e27c1dda7250fdcd (diff) | |
download | thrift-a2c44665b416522477cffa6752c2f323768d0507.tar.gz |
THRIFT-5278: Allow set protoID in go THeader transport/protocol
Client: go
In Go library code, allow setting the underlying protoID to a
non-default (TCompactProtocol) one for THeaderTransport/THeaderProtocol.
-rw-r--r-- | lib/go/thrift/header_protocol.go | 60 | ||||
-rw-r--r-- | lib/go/thrift/header_protocol_test.go | 18 | ||||
-rw-r--r-- | lib/go/thrift/header_transport.go | 44 | ||||
-rw-r--r-- | lib/go/thrift/header_transport_test.go | 23 |
4 files changed, 128 insertions, 17 deletions
diff --git a/lib/go/thrift/header_protocol.go b/lib/go/thrift/header_protocol.go index 428b26148..f86d558aa 100644 --- a/lib/go/thrift/header_protocol.go +++ b/lib/go/thrift/header_protocol.go @@ -37,31 +37,73 @@ type THeaderProtocol struct { } // NewTHeaderProtocol creates a new THeaderProtocol from the underlying -// transport. The passed in transport will be wrapped with THeaderTransport. +// transport with default protocol ID. +// +// The passed in transport will be wrapped with THeaderTransport. // // Note that THeaderTransport handles frame and zlib by itself, // so the underlying transport should be a raw socket transports (TSocket or TSSLSocket), // instead of rich transports like TZlibTransport or TFramedTransport. func NewTHeaderProtocol(trans TTransport) *THeaderProtocol { - t := NewTHeaderTransport(trans) - p, _ := THeaderProtocolDefault.GetProtocol(t) + p, err := newTHeaderProtocolWithProtocolID(trans, THeaderProtocolDefault) + if err != nil { + // Since we used THeaderProtocolDefault this should never happen, + // but put a sanity check here just in case. + panic(err) + } + return p +} + +func newTHeaderProtocolWithProtocolID(trans TTransport, protoID THeaderProtocolID) (*THeaderProtocol, error) { + t, err := NewTHeaderTransportWithProtocolID(trans, protoID) + if err != nil { + return nil, err + } + p, err := t.protocolID.GetProtocol(t) + if err != nil { + return nil, err + } return &THeaderProtocol{ transport: t, protocol: p, - } + }, nil } -type tHeaderProtocolFactory struct{} +type tHeaderProtocolFactory struct { + protoID THeaderProtocolID +} -func (tHeaderProtocolFactory) GetProtocol(trans TTransport) TProtocol { - return NewTHeaderProtocol(trans) +func (f tHeaderProtocolFactory) GetProtocol(trans TTransport) TProtocol { + p, err := newTHeaderProtocolWithProtocolID(trans, f.protoID) + if err != nil { + // Currently there's no way for external users to construct a + // valid factory with invalid protoID, so this should never + // happen. But put a sanity check here just in case in the + // future a bug made that possible. + panic(err) + } + return p } -// NewTHeaderProtocolFactory creates a factory for THeader. +// NewTHeaderProtocolFactory creates a factory for THeader with default protocol +// ID. // // It's a wrapper for NewTHeaderProtocol func NewTHeaderProtocolFactory() TProtocolFactory { - return tHeaderProtocolFactory{} + return tHeaderProtocolFactory{ + protoID: THeaderProtocolDefault, + } +} + +// NewTHeaderProtocolFactoryWithProtocolID creates a factory for THeader with +// given protocol ID. +func NewTHeaderProtocolFactoryWithProtocolID(protoID THeaderProtocolID) (TProtocolFactory, error) { + if err := protoID.Validate(); err != nil { + return nil, err + } + return tHeaderProtocolFactory{ + protoID: protoID, + }, nil } // Transport returns the underlying transport. diff --git a/lib/go/thrift/header_protocol_test.go b/lib/go/thrift/header_protocol_test.go index 9b6019bcf..f66ea6463 100644 --- a/lib/go/thrift/header_protocol_test.go +++ b/lib/go/thrift/header_protocol_test.go @@ -24,5 +24,21 @@ import ( ) func TestReadWriteHeaderProtocol(t *testing.T) { - ReadWriteProtocolTest(t, NewTHeaderProtocolFactory()) + t.Run( + "default", + func(t *testing.T) { + ReadWriteProtocolTest(t, NewTHeaderProtocolFactory()) + }, + ) + + t.Run( + "compact", + func(t *testing.T) { + f, err := NewTHeaderProtocolFactoryWithProtocolID(THeaderProtocolCompact) + if err != nil { + t.Fatal(err) + } + ReadWriteProtocolTest(t, f) + }, + ) } diff --git a/lib/go/thrift/header_transport.go b/lib/go/thrift/header_transport.go index e2080344b..562d02fa4 100644 --- a/lib/go/thrift/header_transport.go +++ b/lib/go/thrift/header_transport.go @@ -75,6 +75,15 @@ const ( THeaderProtocolDefault = THeaderProtocolBinary ) +// Declared globally to avoid repetitive allocations, not really used. +var globalMemoryBuffer = NewTMemoryBuffer() + +// Validate checks whether the THeaderProtocolID is a valid/supported one. +func (id THeaderProtocolID) Validate() error { + _, err := id.GetProtocol(globalMemoryBuffer) + return err +} + // GetProtocol gets the corresponding TProtocol from the wrapped protocol id. func (id THeaderProtocolID) GetProtocol(trans TTransport) (TProtocol, error) { switch id { @@ -84,7 +93,7 @@ func (id THeaderProtocolID) GetProtocol(trans TTransport) (TProtocol, error) { fmt.Sprintf("THeader protocol id %d not supported", id), ) case THeaderProtocolBinary: - return NewTBinaryProtocolFactoryDefault().GetProtocol(trans), nil + return NewTBinaryProtocolTransport(trans), nil case THeaderProtocolCompact: return NewTCompactProtocol(trans), nil } @@ -93,11 +102,12 @@ func (id THeaderProtocolID) GetProtocol(trans TTransport) (TProtocol, error) { // THeaderTransformID defines the numeric id of the transform used. type THeaderTransformID int32 -// THeaderTransformID values +// THeaderTransformID values. +// +// Values not defined here are not currently supported, namely HMAC and Snappy. const ( TransformNone THeaderTransformID = iota // 0, no special handling TransformZlib // 1, zlib - // Rest of the values are not currently supported, namely HMAC and Snappy. ) var supportedTransformIDs = map[THeaderTransformID]bool{ @@ -285,6 +295,34 @@ func NewTHeaderTransport(trans TTransport) *THeaderTransport { } } +// NewTHeaderTransportWithProtocolID creates THeaderTransport from the +// underlying transport, with given protocol ID set. +// +// If trans is already a *THeaderTransport, it will be returned as is, +// but with protocol ID overridden by the value passed in. +// +// If the passed in protocol ID is an invalid/unsupported one, +// this function returns error. +// +// The protocol ID overridden is only useful for client transports. +// For servers, +// the protocol ID will be overridden again to the one set by the client, +// to ensure that servers always speak the same dialect as the client. +func NewTHeaderTransportWithProtocolID(trans TTransport, protoID THeaderProtocolID) (*THeaderTransport, error) { + if err := protoID.Validate(); err != nil { + return nil, err + } + if ht, ok := trans.(*THeaderTransport); ok { + return ht, nil + } + return &THeaderTransport{ + transport: trans, + reader: bufio.NewReader(trans), + writeHeaders: make(THeaderMap), + protocolID: protoID, + }, nil +} + // Open calls the underlying transport's Open function. func (t *THeaderTransport) Open() error { return t.transport.Open() diff --git a/lib/go/thrift/header_transport_test.go b/lib/go/thrift/header_transport_test.go index 320fb2a6f..5b47680e8 100644 --- a/lib/go/thrift/header_transport_test.go +++ b/lib/go/thrift/header_transport_test.go @@ -28,10 +28,13 @@ import ( "testing/quick" ) -func TestTHeaderHeadersReadWrite(t *testing.T) { +func testTHeaderHeadersReadWriteProtocolID(t *testing.T, protoID THeaderProtocolID) { trans := NewTMemoryBuffer() reader := NewTHeaderTransport(trans) - writer := NewTHeaderTransport(trans) + writer, err := NewTHeaderTransportWithProtocolID(trans, protoID) + if err != nil { + t.Fatal(err) + } const key1 = "key1" const value1 = "value1" @@ -98,10 +101,10 @@ func TestTHeaderHeadersReadWrite(t *testing.T) { read, ) } - if prot := reader.Protocol(); prot != THeaderProtocolBinary { + if prot := reader.Protocol(); prot != protoID { t.Errorf( "reader.Protocol() expected %d, got %d", - THeaderProtocolBinary, + protoID, prot, ) } @@ -121,6 +124,18 @@ func TestTHeaderHeadersReadWrite(t *testing.T) { } } +func TestTHeaderHeadersReadWrite(t *testing.T) { + for label, id := range map[string]THeaderProtocolID{ + "default": THeaderProtocolDefault, + "binary": THeaderProtocolBinary, + "compact": THeaderProtocolCompact, + } { + t.Run(label, func(t *testing.T) { + testTHeaderHeadersReadWriteProtocolID(t, id) + }) + } +} + func TestTHeaderTransportNoDoubleWrapping(t *testing.T) { trans := NewTMemoryBuffer() orig := NewTHeaderTransport(trans) |