summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorYuxuan 'fishy' Wang <yuxuan.wang@reddit.com>2020-09-21 12:33:26 -0700
committerYuxuan 'fishy' Wang <fishywang@gmail.com>2020-09-22 16:48:38 -0700
commita2c44665b416522477cffa6752c2f323768d0507 (patch)
tree9942f7d24637c0ee3b4e51b8c8e26e6b3f56c640
parent03f4729f7c19abb206fae439e27c1dda7250fdcd (diff)
downloadthrift-a2c44665b416522477cffa6752c2f323768d0507.tar.gz
THRIFT-5278: Allow set protoID in go THeader transport/protocol
Client: go In Go library code, allow setting the underlying protoID to a non-default (TCompactProtocol) one for THeaderTransport/THeaderProtocol.
-rw-r--r--lib/go/thrift/header_protocol.go60
-rw-r--r--lib/go/thrift/header_protocol_test.go18
-rw-r--r--lib/go/thrift/header_transport.go44
-rw-r--r--lib/go/thrift/header_transport_test.go23
4 files changed, 128 insertions, 17 deletions
diff --git a/lib/go/thrift/header_protocol.go b/lib/go/thrift/header_protocol.go
index 428b26148..f86d558aa 100644
--- a/lib/go/thrift/header_protocol.go
+++ b/lib/go/thrift/header_protocol.go
@@ -37,31 +37,73 @@ type THeaderProtocol struct {
}
// NewTHeaderProtocol creates a new THeaderProtocol from the underlying
-// transport. The passed in transport will be wrapped with THeaderTransport.
+// transport with default protocol ID.
+//
+// The passed in transport will be wrapped with THeaderTransport.
//
// Note that THeaderTransport handles frame and zlib by itself,
// so the underlying transport should be a raw socket transports (TSocket or TSSLSocket),
// instead of rich transports like TZlibTransport or TFramedTransport.
func NewTHeaderProtocol(trans TTransport) *THeaderProtocol {
- t := NewTHeaderTransport(trans)
- p, _ := THeaderProtocolDefault.GetProtocol(t)
+ p, err := newTHeaderProtocolWithProtocolID(trans, THeaderProtocolDefault)
+ if err != nil {
+ // Since we used THeaderProtocolDefault this should never happen,
+ // but put a sanity check here just in case.
+ panic(err)
+ }
+ return p
+}
+
+func newTHeaderProtocolWithProtocolID(trans TTransport, protoID THeaderProtocolID) (*THeaderProtocol, error) {
+ t, err := NewTHeaderTransportWithProtocolID(trans, protoID)
+ if err != nil {
+ return nil, err
+ }
+ p, err := t.protocolID.GetProtocol(t)
+ if err != nil {
+ return nil, err
+ }
return &THeaderProtocol{
transport: t,
protocol: p,
- }
+ }, nil
}
-type tHeaderProtocolFactory struct{}
+type tHeaderProtocolFactory struct {
+ protoID THeaderProtocolID
+}
-func (tHeaderProtocolFactory) GetProtocol(trans TTransport) TProtocol {
- return NewTHeaderProtocol(trans)
+func (f tHeaderProtocolFactory) GetProtocol(trans TTransport) TProtocol {
+ p, err := newTHeaderProtocolWithProtocolID(trans, f.protoID)
+ if err != nil {
+ // Currently there's no way for external users to construct a
+ // valid factory with invalid protoID, so this should never
+ // happen. But put a sanity check here just in case in the
+ // future a bug made that possible.
+ panic(err)
+ }
+ return p
}
-// NewTHeaderProtocolFactory creates a factory for THeader.
+// NewTHeaderProtocolFactory creates a factory for THeader with default protocol
+// ID.
//
// It's a wrapper for NewTHeaderProtocol
func NewTHeaderProtocolFactory() TProtocolFactory {
- return tHeaderProtocolFactory{}
+ return tHeaderProtocolFactory{
+ protoID: THeaderProtocolDefault,
+ }
+}
+
+// NewTHeaderProtocolFactoryWithProtocolID creates a factory for THeader with
+// given protocol ID.
+func NewTHeaderProtocolFactoryWithProtocolID(protoID THeaderProtocolID) (TProtocolFactory, error) {
+ if err := protoID.Validate(); err != nil {
+ return nil, err
+ }
+ return tHeaderProtocolFactory{
+ protoID: protoID,
+ }, nil
}
// Transport returns the underlying transport.
diff --git a/lib/go/thrift/header_protocol_test.go b/lib/go/thrift/header_protocol_test.go
index 9b6019bcf..f66ea6463 100644
--- a/lib/go/thrift/header_protocol_test.go
+++ b/lib/go/thrift/header_protocol_test.go
@@ -24,5 +24,21 @@ import (
)
func TestReadWriteHeaderProtocol(t *testing.T) {
- ReadWriteProtocolTest(t, NewTHeaderProtocolFactory())
+ t.Run(
+ "default",
+ func(t *testing.T) {
+ ReadWriteProtocolTest(t, NewTHeaderProtocolFactory())
+ },
+ )
+
+ t.Run(
+ "compact",
+ func(t *testing.T) {
+ f, err := NewTHeaderProtocolFactoryWithProtocolID(THeaderProtocolCompact)
+ if err != nil {
+ t.Fatal(err)
+ }
+ ReadWriteProtocolTest(t, f)
+ },
+ )
}
diff --git a/lib/go/thrift/header_transport.go b/lib/go/thrift/header_transport.go
index e2080344b..562d02fa4 100644
--- a/lib/go/thrift/header_transport.go
+++ b/lib/go/thrift/header_transport.go
@@ -75,6 +75,15 @@ const (
THeaderProtocolDefault = THeaderProtocolBinary
)
+// Declared globally to avoid repetitive allocations, not really used.
+var globalMemoryBuffer = NewTMemoryBuffer()
+
+// Validate checks whether the THeaderProtocolID is a valid/supported one.
+func (id THeaderProtocolID) Validate() error {
+ _, err := id.GetProtocol(globalMemoryBuffer)
+ return err
+}
+
// GetProtocol gets the corresponding TProtocol from the wrapped protocol id.
func (id THeaderProtocolID) GetProtocol(trans TTransport) (TProtocol, error) {
switch id {
@@ -84,7 +93,7 @@ func (id THeaderProtocolID) GetProtocol(trans TTransport) (TProtocol, error) {
fmt.Sprintf("THeader protocol id %d not supported", id),
)
case THeaderProtocolBinary:
- return NewTBinaryProtocolFactoryDefault().GetProtocol(trans), nil
+ return NewTBinaryProtocolTransport(trans), nil
case THeaderProtocolCompact:
return NewTCompactProtocol(trans), nil
}
@@ -93,11 +102,12 @@ func (id THeaderProtocolID) GetProtocol(trans TTransport) (TProtocol, error) {
// THeaderTransformID defines the numeric id of the transform used.
type THeaderTransformID int32
-// THeaderTransformID values
+// THeaderTransformID values.
+//
+// Values not defined here are not currently supported, namely HMAC and Snappy.
const (
TransformNone THeaderTransformID = iota // 0, no special handling
TransformZlib // 1, zlib
- // Rest of the values are not currently supported, namely HMAC and Snappy.
)
var supportedTransformIDs = map[THeaderTransformID]bool{
@@ -285,6 +295,34 @@ func NewTHeaderTransport(trans TTransport) *THeaderTransport {
}
}
+// NewTHeaderTransportWithProtocolID creates THeaderTransport from the
+// underlying transport, with given protocol ID set.
+//
+// If trans is already a *THeaderTransport, it will be returned as is,
+// but with protocol ID overridden by the value passed in.
+//
+// If the passed in protocol ID is an invalid/unsupported one,
+// this function returns error.
+//
+// The protocol ID overridden is only useful for client transports.
+// For servers,
+// the protocol ID will be overridden again to the one set by the client,
+// to ensure that servers always speak the same dialect as the client.
+func NewTHeaderTransportWithProtocolID(trans TTransport, protoID THeaderProtocolID) (*THeaderTransport, error) {
+ if err := protoID.Validate(); err != nil {
+ return nil, err
+ }
+ if ht, ok := trans.(*THeaderTransport); ok {
+ return ht, nil
+ }
+ return &THeaderTransport{
+ transport: trans,
+ reader: bufio.NewReader(trans),
+ writeHeaders: make(THeaderMap),
+ protocolID: protoID,
+ }, nil
+}
+
// Open calls the underlying transport's Open function.
func (t *THeaderTransport) Open() error {
return t.transport.Open()
diff --git a/lib/go/thrift/header_transport_test.go b/lib/go/thrift/header_transport_test.go
index 320fb2a6f..5b47680e8 100644
--- a/lib/go/thrift/header_transport_test.go
+++ b/lib/go/thrift/header_transport_test.go
@@ -28,10 +28,13 @@ import (
"testing/quick"
)
-func TestTHeaderHeadersReadWrite(t *testing.T) {
+func testTHeaderHeadersReadWriteProtocolID(t *testing.T, protoID THeaderProtocolID) {
trans := NewTMemoryBuffer()
reader := NewTHeaderTransport(trans)
- writer := NewTHeaderTransport(trans)
+ writer, err := NewTHeaderTransportWithProtocolID(trans, protoID)
+ if err != nil {
+ t.Fatal(err)
+ }
const key1 = "key1"
const value1 = "value1"
@@ -98,10 +101,10 @@ func TestTHeaderHeadersReadWrite(t *testing.T) {
read,
)
}
- if prot := reader.Protocol(); prot != THeaderProtocolBinary {
+ if prot := reader.Protocol(); prot != protoID {
t.Errorf(
"reader.Protocol() expected %d, got %d",
- THeaderProtocolBinary,
+ protoID,
prot,
)
}
@@ -121,6 +124,18 @@ func TestTHeaderHeadersReadWrite(t *testing.T) {
}
}
+func TestTHeaderHeadersReadWrite(t *testing.T) {
+ for label, id := range map[string]THeaderProtocolID{
+ "default": THeaderProtocolDefault,
+ "binary": THeaderProtocolBinary,
+ "compact": THeaderProtocolCompact,
+ } {
+ t.Run(label, func(t *testing.T) {
+ testTHeaderHeadersReadWriteProtocolID(t, id)
+ })
+ }
+}
+
func TestTHeaderTransportNoDoubleWrapping(t *testing.T) {
trans := NewTMemoryBuffer()
orig := NewTHeaderTransport(trans)