summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--CHANGES.md2
-rw-r--r--compiler/cpp/src/thrift/generate/t_go_generator.cc140
-rw-r--r--lib/go/test/tests/client_error_test.go246
-rw-r--r--lib/go/test/tests/optional_fields_test.go79
-rw-r--r--lib/go/test/tests/protocol_mock.go336
-rw-r--r--lib/go/test/tests/required_fields_test.go46
-rw-r--r--lib/go/thrift/application_exception.go48
-rw-r--r--lib/go/thrift/binary_protocol.go185
-rw-r--r--lib/go/thrift/client.go20
-rw-r--r--lib/go/thrift/compact_protocol.go117
-rw-r--r--lib/go/thrift/debug_protocol.go168
-rw-r--r--lib/go/thrift/deserializer.go17
-rw-r--r--lib/go/thrift/header_protocol.go184
-rw-r--r--lib/go/thrift/header_transport.go39
-rw-r--r--lib/go/thrift/header_transport_test.go6
-rw-r--r--lib/go/thrift/json_protocol.go158
-rw-r--r--lib/go/thrift/json_protocol_test.go62
-rw-r--r--lib/go/thrift/multiplexed_protocol.go10
-rw-r--r--lib/go/thrift/protocol.go134
-rw-r--r--lib/go/thrift/protocol_test.go82
-rw-r--r--lib/go/thrift/serializer.go8
-rw-r--r--lib/go/thrift/serializer_test.go10
-rw-r--r--lib/go/thrift/serializer_types_test.go241
-rw-r--r--lib/go/thrift/simple_json_protocol.go147
-rw-r--r--lib/go/thrift/simple_json_protocol_test.go66
-rw-r--r--lib/go/thrift/simple_server.go2
-rw-r--r--lib/go/thrift/transport_exception.go14
-rw-r--r--lib/go/thrift/transport_exception_test.go31
28 files changed, 1342 insertions, 1256 deletions
diff --git a/CHANGES.md b/CHANGES.md
index fbaf35dff..ceb8f8b6f 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -14,6 +14,7 @@
- [THRIFT-5138](https://issues.apache.org/jira/browse/THRIFT-5138) - Swift generator does not escape keywords properly
- [THRIFT-5164](https://issues.apache.org/jira/browse/THRIFT-5164) - In Go library TProcessor interface now includes ProcessorMap and AddToProcessorMap functions.
- [THRIFT-5186](https://issues.apache.org/jira/browse/THRIFT-5186) - cpp: use all getaddrinfo() results when retrying failed bind() in T{Nonblocking,}ServerSocket
+- [THRIFT-5233](https://issues.apache.org/jira/browse/THRIFT-5233) - go: Now all Read*, Write* and Skip functions in TProtocol accept context arg
### Java
@@ -24,6 +25,7 @@
- [THRIFT-5069](https://issues.apache.org/jira/browse/THRIFT-5069) - Add TSerializerPool and TDeserializerPool, which are thread-safe versions of TSerializer and TDeserializer.
- [THRIFT-5164](https://issues.apache.org/jira/browse/THRIFT-5164) - Add ClientMiddleware function type and WrapClient function to support wrapping a TClient with middleware functions.
- [THRIFT-5164](https://issues.apache.org/jira/browse/THRIFT-5164) - Add ProcessorMiddleware function type and WrapProcessor function to support wrapping a TProcessor with middleware functions.
+- [THRIFT-5233](https://issues.apache.org/jira/browse/THRIFT-5233) - Add context deadline check to ReadMessageBegin in TBinaryProtocol, TCompactProtocol, and THeaderProtocol.
## 0.13.0
diff --git a/compiler/cpp/src/thrift/generate/t_go_generator.cc b/compiler/cpp/src/thrift/generate/t_go_generator.cc
index 4f8715159..b89052b5e 100644
--- a/compiler/cpp/src/thrift/generate/t_go_generator.cc
+++ b/compiler/cpp/src/thrift/generate/t_go_generator.cc
@@ -888,7 +888,7 @@ string t_go_generator::render_fastbinary_includes() {
}
/**
- * Autogen'd comment. The different text is necessary due to
+ * Autogen'd comment. The different text is necessary due to
* https://github.com/golang/go/issues/13560#issuecomment-288457920
*/
string t_go_generator::go_autogen_comment() {
@@ -1585,10 +1585,10 @@ void t_go_generator::generate_go_struct_reader(ostream& out,
const vector<t_field*>& fields = tstruct->get_members();
vector<t_field*>::const_iterator f_iter;
string escaped_tstruct_name(escape_string(tstruct->get_name()));
- out << indent() << "func (p *" << tstruct_name << ") " << read_method_name_ << "(iprot thrift.TProtocol) error {"
+ out << indent() << "func (p *" << tstruct_name << ") " << read_method_name_ << "(ctx context.Context, iprot thrift.TProtocol) error {"
<< endl;
indent_up();
- out << indent() << "if _, err := iprot.ReadStructBegin(); err != nil {" << endl;
+ out << indent() << "if _, err := iprot.ReadStructBegin(ctx); err != nil {" << endl;
out << indent() << " return thrift.PrependError(fmt.Sprintf(\"%T read error: \", p), err)"
<< endl;
out << indent() << "}" << endl << endl;
@@ -1606,7 +1606,7 @@ void t_go_generator::generate_go_struct_reader(ostream& out,
indent(out) << "for {" << endl;
indent_up();
// Read beginning field marker
- out << indent() << "_, fieldTypeId, fieldId, err := iprot.ReadFieldBegin()" << endl;
+ out << indent() << "_, fieldTypeId, fieldId, err := iprot.ReadFieldBegin(ctx)" << endl;
out << indent() << "if err != nil {" << endl;
out << indent() << " return thrift.PrependError(fmt.Sprintf("
"\"%T field %d read error: \", p, fieldId), err)" << endl;
@@ -1646,7 +1646,7 @@ void t_go_generator::generate_go_struct_reader(ostream& out,
}
out << indent() << "if fieldTypeId == " << thriftFieldTypeId << " {" << endl;
- out << indent() << " if err := p." << field_method_prefix << field_method_suffix << "(iprot); err != nil {"
+ out << indent() << " if err := p." << field_method_prefix << field_method_suffix << "(ctx, iprot); err != nil {"
<< endl;
out << indent() << " return err" << endl;
out << indent() << " }" << endl;
@@ -1658,7 +1658,7 @@ void t_go_generator::generate_go_struct_reader(ostream& out,
}
out << indent() << "} else {" << endl;
- out << indent() << " if err := iprot.Skip(fieldTypeId); err != nil {" << endl;
+ out << indent() << " if err := iprot.Skip(ctx, fieldTypeId); err != nil {" << endl;
out << indent() << " return err" << endl;
out << indent() << " }" << endl;
out << indent() << "}" << endl;
@@ -1674,7 +1674,7 @@ void t_go_generator::generate_go_struct_reader(ostream& out,
}
// Skip unknown fields in either case
- out << indent() << "if err := iprot.Skip(fieldTypeId); err != nil {" << endl;
+ out << indent() << "if err := iprot.Skip(ctx, fieldTypeId); err != nil {" << endl;
out << indent() << " return err" << endl;
out << indent() << "}" << endl;
@@ -1685,12 +1685,12 @@ void t_go_generator::generate_go_struct_reader(ostream& out,
}
// Read field end marker
- out << indent() << "if err := iprot.ReadFieldEnd(); err != nil {" << endl;
+ out << indent() << "if err := iprot.ReadFieldEnd(ctx); err != nil {" << endl;
out << indent() << " return err" << endl;
out << indent() << "}" << endl;
indent_down();
out << indent() << "}" << endl;
- out << indent() << "if err := iprot.ReadStructEnd(); err != nil {" << endl;
+ out << indent() << "if err := iprot.ReadStructEnd(ctx); err != nil {" << endl;
out << indent() << " return thrift.PrependError(fmt.Sprintf("
"\"%T read struct end error: \", p), err)" << endl;
out << indent() << "}" << endl;
@@ -1723,7 +1723,7 @@ void t_go_generator::generate_go_struct_reader(ostream& out,
}
out << indent() << "func (p *" << tstruct_name << ") " << field_method_prefix << field_method_suffix
- << "(iprot thrift.TProtocol) error {" << endl;
+ << "(ctx context.Context, iprot thrift.TProtocol) error {" << endl;
indent_up();
generate_deserialize_field(out, *f_iter, false, "p.");
indent_down();
@@ -1741,7 +1741,7 @@ void t_go_generator::generate_go_struct_writer(ostream& out,
string name(tstruct->get_name());
const vector<t_field*>& fields = tstruct->get_sorted_members();
vector<t_field*>::const_iterator f_iter;
- indent(out) << "func (p *" << tstruct_name << ") " << write_method_name_ << "(oprot thrift.TProtocol) error {" << endl;
+ indent(out) << "func (p *" << tstruct_name << ") " << write_method_name_ << "(ctx context.Context, oprot thrift.TProtocol) error {" << endl;
indent_up();
if (tstruct->is_union() && uses_countsetfields) {
std::string tstruct_name(publicize(tstruct->get_name()));
@@ -1750,7 +1750,7 @@ void t_go_generator::generate_go_struct_writer(ostream& out,
<< " return fmt.Errorf(\"%T write union: exactly one field must be set (%d set).\", p, c)"
<< endl << indent() << "}" << endl;
}
- out << indent() << "if err := oprot.WriteStructBegin(\"" << name << "\"); err != nil {" << endl;
+ out << indent() << "if err := oprot.WriteStructBegin(ctx, \"" << name << "\"); err != nil {" << endl;
out << indent() << " return thrift.PrependError(fmt.Sprintf("
"\"%T write struct begin error: \", p), err) }" << endl;
@@ -1776,16 +1776,16 @@ void t_go_generator::generate_go_struct_writer(ostream& out,
}
out << indent() << "if err := p." << field_method_prefix << field_method_suffix
- << "(oprot); err != nil { return err }" << endl;
+ << "(ctx, oprot); err != nil { return err }" << endl;
}
indent_down();
out << indent() << "}" << endl;
// Write the struct map
- out << indent() << "if err := oprot.WriteFieldStop(); err != nil {" << endl;
+ out << indent() << "if err := oprot.WriteFieldStop(ctx); err != nil {" << endl;
out << indent() << " return thrift.PrependError(\"write field stop error: \", err) }" << endl;
- out << indent() << "if err := oprot.WriteStructEnd(); err != nil {" << endl;
+ out << indent() << "if err := oprot.WriteStructEnd(ctx); err != nil {" << endl;
out << indent() << " return thrift.PrependError(\"write struct stop error: \", err) }" << endl;
out << indent() << "return nil" << endl;
indent_down();
@@ -1806,7 +1806,7 @@ void t_go_generator::generate_go_struct_writer(ostream& out,
}
out << indent() << "func (p *" << tstruct_name << ") " << field_method_prefix << field_method_suffix
- << "(oprot thrift.TProtocol) (err error) {" << endl;
+ << "(ctx context.Context, oprot thrift.TProtocol) (err error) {" << endl;
indent_up();
if (field_required == t_field::T_OPTIONAL) {
@@ -1814,7 +1814,7 @@ void t_go_generator::generate_go_struct_writer(ostream& out,
indent_up();
}
- out << indent() << "if err := oprot.WriteFieldBegin(\"" << escape_field_name << "\", "
+ out << indent() << "if err := oprot.WriteFieldBegin(ctx, \"" << escape_field_name << "\", "
<< type_to_enum((*f_iter)->get_type()) << ", " << field_id << "); err != nil {" << endl;
out << indent() << " return thrift.PrependError(fmt.Sprintf(\"%T write field begin error "
<< field_id << ":" << escape_field_name << ": \", p), err) }" << endl;
@@ -1823,7 +1823,7 @@ void t_go_generator::generate_go_struct_writer(ostream& out,
generate_serialize_field(out, *f_iter, "p.");
// Write field closer
- out << indent() << "if err := oprot.WriteFieldEnd(); err != nil {" << endl;
+ out << indent() << "if err := oprot.WriteFieldEnd(ctx); err != nil {" << endl;
out << indent() << " return thrift.PrependError(fmt.Sprintf(\"%T write field end error "
<< field_id << ":" << escape_field_name << ": \", p), err) }" << endl;
@@ -2503,7 +2503,7 @@ void t_go_generator::generate_service_remote(t_service* tservice) {
<< endl;
f_remote << indent() << "argvalue" << i << " := " << tstruct_module << ".New" << tstruct_name
<< "()" << endl;
- f_remote << indent() << err2 << " := argvalue" << i << "." << read_method_name_ << "(" << jsProt << ")" << endl;
+ f_remote << indent() << err2 << " := argvalue" << i << "." << read_method_name_ << "(context.Background(), " << jsProt << ")" << endl;
f_remote << indent() << "if " << err2 << " != nil {" << endl;
f_remote << indent() << " Usage()" << endl;
f_remote << indent() << " return" << endl;
@@ -2531,7 +2531,7 @@ void t_go_generator::generate_service_remote(t_service* tservice) {
<< endl;
f_remote << indent() << "containerStruct" << i << " := " << package_name_aliased << ".New"
<< argumentsName << "()" << endl;
- f_remote << indent() << err2 << " := containerStruct" << i << ".ReadField" << (i + 1) << "("
+ f_remote << indent() << err2 << " := containerStruct" << i << ".ReadField" << (i + 1) << "(context.Background(), "
<< jsProt << ")" << endl;
f_remote << indent() << "if " << err2 << " != nil {" << endl;
f_remote << indent() << " Usage()" << endl;
@@ -2698,19 +2698,19 @@ void t_go_generator::generate_service_server(t_service* tservice) {
f_types_ << indent() << "func (p *" << serviceName
<< "Processor) Process(ctx context.Context, iprot, oprot thrift.TProtocol) (success bool, err "
"thrift.TException) {" << endl;
- f_types_ << indent() << " name, _, seqId, err := iprot.ReadMessageBegin()" << endl;
+ f_types_ << indent() << " name, _, seqId, err := iprot.ReadMessageBegin(ctx)" << endl;
f_types_ << indent() << " if err != nil { return false, err }" << endl;
f_types_ << indent() << " if processor, ok := p.GetProcessorFunction(name); ok {" << endl;
f_types_ << indent() << " return processor.Process(ctx, seqId, iprot, oprot)" << endl;
f_types_ << indent() << " }" << endl;
- f_types_ << indent() << " iprot.Skip(thrift.STRUCT)" << endl;
- f_types_ << indent() << " iprot.ReadMessageEnd()" << endl;
+ f_types_ << indent() << " iprot.Skip(ctx, thrift.STRUCT)" << endl;
+ f_types_ << indent() << " iprot.ReadMessageEnd(ctx)" << endl;
f_types_ << indent() << " " << x
<< " := thrift.NewTApplicationException(thrift.UNKNOWN_METHOD, \"Unknown function "
"\" + name)" << endl;
- f_types_ << indent() << " oprot.WriteMessageBegin(name, thrift.EXCEPTION, seqId)" << endl;
- f_types_ << indent() << " " << x << ".Write(oprot)" << endl;
- f_types_ << indent() << " oprot.WriteMessageEnd()" << endl;
+ f_types_ << indent() << " oprot.WriteMessageBegin(ctx, name, thrift.EXCEPTION, seqId)" << endl;
+ f_types_ << indent() << " " << x << ".Write(ctx, oprot)" << endl;
+ f_types_ << indent() << " oprot.WriteMessageEnd(ctx)" << endl;
f_types_ << indent() << " oprot.Flush(ctx)" << endl;
f_types_ << indent() << " return false, " << x << endl;
f_types_ << indent() << "" << endl;
@@ -2765,21 +2765,21 @@ void t_go_generator::generate_process_function(t_service* tservice, t_function*
"thrift.TException) {" << endl;
indent_up();
f_types_ << indent() << "args := " << argsname << "{}" << endl;
- f_types_ << indent() << "if err = args." << read_method_name_ << "(iprot); err != nil {" << endl;
- f_types_ << indent() << " iprot.ReadMessageEnd()" << endl;
+ f_types_ << indent() << "if err = args." << read_method_name_ << "(ctx, iprot); err != nil {" << endl;
+ f_types_ << indent() << " iprot.ReadMessageEnd(ctx)" << endl;
if (!tfunction->is_oneway()) {
f_types_ << indent()
<< " x := thrift.NewTApplicationException(thrift.PROTOCOL_ERROR, err.Error())"
<< endl;
- f_types_ << indent() << " oprot.WriteMessageBegin(\"" << escape_string(tfunction->get_name())
+ f_types_ << indent() << " oprot.WriteMessageBegin(ctx, \"" << escape_string(tfunction->get_name())
<< "\", thrift.EXCEPTION, seqId)" << endl;
- f_types_ << indent() << " x.Write(oprot)" << endl;
- f_types_ << indent() << " oprot.WriteMessageEnd()" << endl;
+ f_types_ << indent() << " x.Write(ctx, oprot)" << endl;
+ f_types_ << indent() << " oprot.WriteMessageEnd(ctx)" << endl;
f_types_ << indent() << " oprot.Flush(ctx)" << endl;
}
f_types_ << indent() << " return false, err" << endl;
f_types_ << indent() << "}" << endl << endl;
- f_types_ << indent() << "iprot.ReadMessageEnd()" << endl;
+ f_types_ << indent() << "iprot.ReadMessageEnd(ctx)" << endl;
if (!tfunction->is_oneway()) {
f_types_ << indent() << "result := " << resultname << "{}" << endl;
@@ -2839,10 +2839,10 @@ void t_go_generator::generate_process_function(t_service* tservice, t_function*
f_types_ << indent() << " x := thrift.NewTApplicationException(thrift.INTERNAL_ERROR, "
"\"Internal error processing " << escape_string(tfunction->get_name())
<< ": \" + err2.Error())" << endl;
- f_types_ << indent() << " oprot.WriteMessageBegin(\"" << escape_string(tfunction->get_name())
+ f_types_ << indent() << " oprot.WriteMessageBegin(ctx, \"" << escape_string(tfunction->get_name())
<< "\", thrift.EXCEPTION, seqId)" << endl;
- f_types_ << indent() << " x.Write(oprot)" << endl;
- f_types_ << indent() << " oprot.WriteMessageEnd()" << endl;
+ f_types_ << indent() << " x.Write(ctx, oprot)" << endl;
+ f_types_ << indent() << " oprot.WriteMessageEnd(ctx)" << endl;
f_types_ << indent() << " oprot.Flush(ctx)" << endl;
}
@@ -2868,15 +2868,15 @@ void t_go_generator::generate_process_function(t_service* tservice, t_function*
} else {
f_types_ << endl;
}
- f_types_ << indent() << "if err2 = oprot.WriteMessageBegin(\""
+ f_types_ << indent() << "if err2 = oprot.WriteMessageBegin(ctx, \""
<< escape_string(tfunction->get_name()) << "\", thrift.REPLY, seqId); err2 != nil {"
<< endl;
f_types_ << indent() << " err = err2" << endl;
f_types_ << indent() << "}" << endl;
- f_types_ << indent() << "if err2 = result." << write_method_name_ << "(oprot); err == nil && err2 != nil {" << endl;
+ f_types_ << indent() << "if err2 = result." << write_method_name_ << "(ctx, oprot); err == nil && err2 != nil {" << endl;
f_types_ << indent() << " err = err2" << endl;
f_types_ << indent() << "}" << endl;
- f_types_ << indent() << "if err2 = oprot.WriteMessageEnd(); err == nil && err2 != nil {"
+ f_types_ << indent() << "if err2 = oprot.WriteMessageEnd(ctx); err == nil && err2 != nil {"
<< endl;
f_types_ << indent() << " err = err2" << endl;
f_types_ << indent() << "}" << endl;
@@ -2945,42 +2945,42 @@ void t_go_generator::generate_deserialize_field(ostream& out,
case t_base_type::TYPE_STRING:
if (type->is_binary() && !inkey) {
- out << "ReadBinary()";
+ out << "ReadBinary(ctx)";
} else {
- out << "ReadString()";
+ out << "ReadString(ctx)";
}
break;
case t_base_type::TYPE_BOOL:
- out << "ReadBool()";
+ out << "ReadBool(ctx)";
break;
case t_base_type::TYPE_I8:
- out << "ReadByte()";
+ out << "ReadByte(ctx)";
break;
case t_base_type::TYPE_I16:
- out << "ReadI16()";
+ out << "ReadI16(ctx)";
break;
case t_base_type::TYPE_I32:
- out << "ReadI32()";
+ out << "ReadI32(ctx)";
break;
case t_base_type::TYPE_I64:
- out << "ReadI64()";
+ out << "ReadI64(ctx)";
break;
case t_base_type::TYPE_DOUBLE:
- out << "ReadDouble()";
+ out << "ReadDouble(ctx)";
break;
default:
throw "compiler error: no Go name for base type " + t_base_type::t_base_name(tbase);
}
} else if (type->is_enum()) {
- out << "ReadI32()";
+ out << "ReadI32(ctx)";
}
out << "; err != nil {" << endl;
@@ -3023,7 +3023,7 @@ void t_go_generator::generate_deserialize_struct(ostream& out,
out << indent() << prefix << eq << (pointer_field ? "&" : "");
generate_go_struct_initializer(out, tstruct);
- out << indent() << "if err := " << prefix << "." << read_method_name_ << "(iprot); err != nil {" << endl;
+ out << indent() << "if err := " << prefix << "." << read_method_name_ << "(ctx, iprot); err != nil {" << endl;
out << indent() << " return thrift.PrependError(fmt.Sprintf(\"%T error reading struct: \", "
<< prefix << "), err)" << endl;
out << indent() << "}" << endl;
@@ -3047,21 +3047,21 @@ void t_go_generator::generate_deserialize_container(ostream& out,
// Declare variables, read header
if (ttype->is_map()) {
- out << indent() << "_, _, size, err := iprot.ReadMapBegin()" << endl;
+ out << indent() << "_, _, size, err := iprot.ReadMapBegin(ctx)" << endl;
out << indent() << "if err != nil {" << endl;
out << indent() << " return thrift.PrependError(\"error reading map begin: \", err)" << endl;
out << indent() << "}" << endl;
out << indent() << "tMap := make(" << type_to_go_type(orig_type) << ", size)" << endl;
out << indent() << prefix << eq << " " << (pointer_field ? "&" : "") << "tMap" << endl;
} else if (ttype->is_set()) {
- out << indent() << "_, size, err := iprot.ReadSetBegin()" << endl;
+ out << indent() << "_, size, err := iprot.ReadSetBegin(ctx)" << endl;
out << indent() << "if err != nil {" << endl;
out << indent() << " return thrift.PrependError(\"error reading set begin: \", err)" << endl;
out << indent() << "}" << endl;
out << indent() << "tSet := make(" << type_to_go_type(orig_type) << ", 0, size)" << endl;
out << indent() << prefix << eq << " " << (pointer_field ? "&" : "") << "tSet" << endl;
} else if (ttype->is_list()) {
- out << indent() << "_, size, err := iprot.ReadListBegin()" << endl;
+ out << indent() << "_, size, err := iprot.ReadListBegin(ctx)" << endl;
out << indent() << "if err != nil {" << endl;
out << indent() << " return thrift.PrependError(\"error reading list begin: \", err)" << endl;
out << indent() << "}" << endl;
@@ -3092,15 +3092,15 @@ void t_go_generator::generate_deserialize_container(ostream& out,
// Read container end
if (ttype->is_map()) {
- out << indent() << "if err := iprot.ReadMapEnd(); err != nil {" << endl;
+ out << indent() << "if err := iprot.ReadMapEnd(ctx); err != nil {" << endl;
out << indent() << " return thrift.PrependError(\"error reading map end: \", err)" << endl;
out << indent() << "}" << endl;
} else if (ttype->is_set()) {
- out << indent() << "if err := iprot.ReadSetEnd(); err != nil {" << endl;
+ out << indent() << "if err := iprot.ReadSetEnd(ctx); err != nil {" << endl;
out << indent() << " return thrift.PrependError(\"error reading set end: \", err)" << endl;
out << indent() << "}" << endl;
} else if (ttype->is_list()) {
- out << indent() << "if err := iprot.ReadListEnd(); err != nil {" << endl;
+ out << indent() << "if err := iprot.ReadListEnd(ctx); err != nil {" << endl;
out << indent() << " return thrift.PrependError(\"error reading list end: \", err)" << endl;
out << indent() << "}" << endl;
}
@@ -3194,42 +3194,42 @@ void t_go_generator::generate_serialize_field(ostream& out,
case t_base_type::TYPE_STRING:
if (type->is_binary() && !inkey) {
- out << "WriteBinary(" << name << ")";
+ out << "WriteBinary(ctx, " << name << ")";
} else {
- out << "WriteString(string(" << name << "))";
+ out << "WriteString(ctx, string(" << name << "))";
}
break;
case t_base_type::TYPE_BOOL:
- out << "WriteBool(bool(" << name << "))";
+ out << "WriteBool(ctx, bool(" << name << "))";
break;
case t_base_type::TYPE_I8:
- out << "WriteByte(int8(" << name << "))";
+ out << "WriteByte(ctx, int8(" << name << "))";
break;
case t_base_type::TYPE_I16:
- out << "WriteI16(int16(" << name << "))";
+ out << "WriteI16(ctx, int16(" << name << "))";
break;
case t_base_type::TYPE_I32:
- out << "WriteI32(int32(" << name << "))";
+ out << "WriteI32(ctx, int32(" << name << "))";
break;
case t_base_type::TYPE_I64:
- out << "WriteI64(int64(" << name << "))";
+ out << "WriteI64(ctx, int64(" << name << "))";
break;
case t_base_type::TYPE_DOUBLE:
- out << "WriteDouble(float64(" << name << "))";
+ out << "WriteDouble(ctx, float64(" << name << "))";
break;
default:
throw "compiler error: no Go name for base type " + t_base_type::t_base_name(tbase);
}
} else if (type->is_enum()) {
- out << "WriteI32(int32(" << name << "))";
+ out << "WriteI32(ctx, int32(" << name << "))";
}
out << "; err != nil {" << endl;
@@ -3250,7 +3250,7 @@ void t_go_generator::generate_serialize_field(ostream& out,
*/
void t_go_generator::generate_serialize_struct(ostream& out, t_struct* tstruct, string prefix) {
(void)tstruct;
- out << indent() << "if err := " << prefix << "." << write_method_name_ << "(oprot); err != nil {" << endl;
+ out << indent() << "if err := " << prefix << "." << write_method_name_ << "(ctx, oprot); err != nil {" << endl;
out << indent() << " return thrift.PrependError(fmt.Sprintf(\"%T error writing struct: \", "
<< prefix << "), err)" << endl;
out << indent() << "}" << endl;
@@ -3264,20 +3264,20 @@ void t_go_generator::generate_serialize_container(ostream& out,
prefix = "*" + prefix;
}
if (ttype->is_map()) {
- out << indent() << "if err := oprot.WriteMapBegin("
+ out << indent() << "if err := oprot.WriteMapBegin(ctx, "
<< type_to_enum(((t_map*)ttype)->get_key_type()) << ", "
<< type_to_enum(((t_map*)ttype)->get_val_type()) << ", "
<< "len(" << prefix << ")); err != nil {" << endl;
out << indent() << " return thrift.PrependError(\"error writing map begin: \", err)" << endl;
out << indent() << "}" << endl;
} else if (ttype->is_set()) {
- out << indent() << "if err := oprot.WriteSetBegin("
+ out << indent() << "if err := oprot.WriteSetBegin(ctx, "
<< type_to_enum(((t_set*)ttype)->get_elem_type()) << ", "
<< "len(" << prefix << ")); err != nil {" << endl;
out << indent() << " return thrift.PrependError(\"error writing set begin: \", err)" << endl;
out << indent() << "}" << endl;
} else if (ttype->is_list()) {
- out << indent() << "if err := oprot.WriteListBegin("
+ out << indent() << "if err := oprot.WriteListBegin(ctx, "
<< type_to_enum(((t_list*)ttype)->get_elem_type()) << ", "
<< "len(" << prefix << ")); err != nil {" << endl;
out << indent() << " return thrift.PrependError(\"error writing list begin: \", err)" << endl;
@@ -3323,15 +3323,15 @@ void t_go_generator::generate_serialize_container(ostream& out,
}
if (ttype->is_map()) {
- out << indent() << "if err := oprot.WriteMapEnd(); err != nil {" << endl;
+ out << indent() << "if err := oprot.WriteMapEnd(ctx); err != nil {" << endl;
out << indent() << " return thrift.PrependError(\"error writing map end: \", err)" << endl;
out << indent() << "}" << endl;
} else if (ttype->is_set()) {
- out << indent() << "if err := oprot.WriteSetEnd(); err != nil {" << endl;
+ out << indent() << "if err := oprot.WriteSetEnd(ctx); err != nil {" << endl;
out << indent() << " return thrift.PrependError(\"error writing set end: \", err)" << endl;
out << indent() << "}" << endl;
} else if (ttype->is_list()) {
- out << indent() << "if err := oprot.WriteListEnd(); err != nil {" << endl;
+ out << indent() << "if err := oprot.WriteListEnd(ctx); err != nil {" << endl;
out << indent() << " return thrift.PrependError(\"error writing list end: \", err)" << endl;
out << indent() << "}" << endl;
}
diff --git a/lib/go/test/tests/client_error_test.go b/lib/go/test/tests/client_error_test.go
index fdec4ea57..a06416365 100644
--- a/lib/go/test/tests/client_error_test.go
+++ b/lib/go/test/tests/client_error_test.go
@@ -37,84 +37,84 @@ func prepareClientCallReply(protocol *MockTProtocol, failAt int, failWith error)
if failAt == 0 {
err = failWith
}
- last := protocol.EXPECT().WriteMessageBegin("testStruct", thrift.CALL, int32(1)).Return(err)
+ last := protocol.EXPECT().WriteMessageBegin(context.Background(), "testStruct", thrift.CALL, int32(1)).Return(err)
if failAt == 0 {
return true
}
if failAt == 1 {
err = failWith
}
- last = protocol.EXPECT().WriteStructBegin("testStruct_args").Return(err).After(last)
+ last = protocol.EXPECT().WriteStructBegin(context.Background(), "testStruct_args").Return(err).After(last)
if failAt == 1 {
return true
}
if failAt == 2 {
err = failWith
}
- last = protocol.EXPECT().WriteFieldBegin("thing", thrift.TType(thrift.STRUCT), int16(1)).Return(err).After(last)
+ last = protocol.EXPECT().WriteFieldBegin(context.Background(), "thing", thrift.TType(thrift.STRUCT), int16(1)).Return(err).After(last)
if failAt == 2 {
return true
}
if failAt == 3 {
err = failWith
}
- last = protocol.EXPECT().WriteStructBegin("TestStruct").Return(err).After(last)
+ last = protocol.EXPECT().WriteStructBegin(context.Background(), "TestStruct").Return(err).After(last)
if failAt == 3 {
return true
}
if failAt == 4 {
err = failWith
}
- last = protocol.EXPECT().WriteFieldBegin("m", thrift.TType(thrift.MAP), int16(1)).Return(err).After(last)
+ last = protocol.EXPECT().WriteFieldBegin(context.Background(), "m", thrift.TType(thrift.MAP), int16(1)).Return(err).After(last)
if failAt == 4 {
return true
}
if failAt == 5 {
err = failWith
}
- last = protocol.EXPECT().WriteMapBegin(thrift.TType(thrift.STRING), thrift.TType(thrift.STRING), 0).Return(err).After(last)
+ last = protocol.EXPECT().WriteMapBegin(context.Background(), thrift.TType(thrift.STRING), thrift.TType(thrift.STRING), 0).Return(err).After(last)
if failAt == 5 {
return true
}
if failAt == 6 {
err = failWith
}
- last = protocol.EXPECT().WriteMapEnd().Return(err).After(last)
+ last = protocol.EXPECT().WriteMapEnd(context.Background()).Return(err).After(last)
if failAt == 6 {
return true
}
if failAt == 7 {
err = failWith
}
- last = protocol.EXPECT().WriteFieldEnd().Return(err).After(last)
+ last = protocol.EXPECT().WriteFieldEnd(context.Background()).Return(err).After(last)
if failAt == 7 {
return true
}
if failAt == 8 {
err = failWith
}
- last = protocol.EXPECT().WriteFieldBegin("l", thrift.TType(thrift.LIST), int16(2)).Return(err).After(last)
+ last = protocol.EXPECT().WriteFieldBegin(context.Background(), "l", thrift.TType(thrift.LIST), int16(2)).Return(err).After(last)
if failAt == 8 {
return true
}
if failAt == 9 {
err = failWith
}
- last = protocol.EXPECT().WriteListBegin(thrift.TType(thrift.STRING), 0).Return(err).After(last)
+ last = protocol.EXPECT().WriteListBegin(context.Background(), thrift.TType(thrift.STRING), 0).Return(err).After(last)
if failAt == 9 {
return true
}
if failAt == 10 {
err = failWith
}
- last = protocol.EXPECT().WriteListEnd().Return(err).After(last)
+ last = protocol.EXPECT().WriteListEnd(context.Background()).Return(err).After(last)
if failAt == 10 {
return true
}
if failAt == 11 {
err = failWith
}
- last = protocol.EXPECT().WriteFieldEnd().Return(err).After(last)
+ last = protocol.EXPECT().WriteFieldEnd(context.Background()).Return(err).After(last)
if failAt == 11 {
return true
}
@@ -122,91 +122,91 @@ func prepareClientCallReply(protocol *MockTProtocol, failAt int, failWith error)
err = failWith
}
- last = protocol.EXPECT().WriteFieldBegin("s", thrift.TType(thrift.SET), int16(3)).Return(err).After(last)
+ last = protocol.EXPECT().WriteFieldBegin(context.Background(), "s", thrift.TType(thrift.SET), int16(3)).Return(err).After(last)
if failAt == 12 {
return true
}
if failAt == 13 {
err = failWith
}
- last = protocol.EXPECT().WriteSetBegin(thrift.TType(thrift.STRING), 0).Return(err).After(last)
+ last = protocol.EXPECT().WriteSetBegin(context.Background(), thrift.TType(thrift.STRING), 0).Return(err).After(last)
if failAt == 13 {
return true
}
if failAt == 14 {
err = failWith
}
- last = protocol.EXPECT().WriteSetEnd().Return(err).After(last)
+ last = protocol.EXPECT().WriteSetEnd(context.Background()).Return(err).After(last)
if failAt == 14 {
return true
}
if failAt == 15 {
err = failWith
}
- last = protocol.EXPECT().WriteFieldEnd().Return(err).After(last)
+ last = protocol.EXPECT().WriteFieldEnd(context.Background()).Return(err).After(last)
if failAt == 15 {
return true
}
if failAt == 16 {
err = failWith
}
- last = protocol.EXPECT().WriteFieldBegin("i", thrift.TType(thrift.I32), int16(4)).Return(err).After(last)
+ last = protocol.EXPECT().WriteFieldBegin(context.Background(), "i", thrift.TType(thrift.I32), int16(4)).Return(err).After(last)
if failAt == 16 {
return true
}
if failAt == 17 {
err = failWith
}
- last = protocol.EXPECT().WriteI32(int32(3)).Return(err).After(last)
+ last = protocol.EXPECT().WriteI32(context.Background(), int32(3)).Return(err).After(last)
if failAt == 17 {
return true
}
if failAt == 18 {
err = failWith
}
- last = protocol.EXPECT().WriteFieldEnd().Return(err).After(last)
+ last = protocol.EXPECT().WriteFieldEnd(context.Background()).Return(err).After(last)
if failAt == 18 {
return true
}
if failAt == 19 {
err = failWith
}
- last = protocol.EXPECT().WriteFieldStop().Return(err).After(last)
+ last = protocol.EXPECT().WriteFieldStop(context.Background()).Return(err).After(last)
if failAt == 19 {
return true
}
if failAt == 20 {
err = failWith
}
- last = protocol.EXPECT().WriteStructEnd().Return(err).After(last)
+ last = protocol.EXPECT().WriteStructEnd(context.Background()).Return(err).After(last)
if failAt == 20 {
return true
}
if failAt == 21 {
err = failWith
}
- last = protocol.EXPECT().WriteFieldEnd().Return(err).After(last)
+ last = protocol.EXPECT().WriteFieldEnd(context.Background()).Return(err).After(last)
if failAt == 21 {
return true
}
if failAt == 22 {
err = failWith
}
- last = protocol.EXPECT().WriteFieldStop().Return(err).After(last)
+ last = protocol.EXPECT().WriteFieldStop(context.Background()).Return(err).After(last)
if failAt == 22 {
return true
}
if failAt == 23 {
err = failWith
}
- last = protocol.EXPECT().WriteStructEnd().Return(err).After(last)
+ last = protocol.EXPECT().WriteStructEnd(context.Background()).Return(err).After(last)
if failAt == 23 {
return true
}
if failAt == 24 {
err = failWith
}
- last = protocol.EXPECT().WriteMessageEnd().Return(err).After(last)
+ last = protocol.EXPECT().WriteMessageEnd(context.Background()).Return(err).After(last)
if failAt == 24 {
return true
}
@@ -220,175 +220,175 @@ func prepareClientCallReply(protocol *MockTProtocol, failAt int, failWith error)
if failAt == 26 {
err = failWith
}
- last = protocol.EXPECT().ReadMessageBegin().Return("testStruct", thrift.REPLY, int32(1), err).After(last)
+ last = protocol.EXPECT().ReadMessageBegin(context.Background()).Return("testStruct", thrift.REPLY, int32(1), err).After(last)
if failAt == 26 {
return true
}
if failAt == 27 {
err = failWith
}
- last = protocol.EXPECT().ReadStructBegin().Return("testStruct_args", err).After(last)
+ last = protocol.EXPECT().ReadStructBegin(context.Background()).Return("testStruct_args", err).After(last)
if failAt == 27 {
return true
}
if failAt == 28 {
err = failWith
}
- last = protocol.EXPECT().ReadFieldBegin().Return("_", thrift.TType(thrift.STRUCT), int16(0), err).After(last)
+ last = protocol.EXPECT().ReadFieldBegin(context.Background()).Return("_", thrift.TType(thrift.STRUCT), int16(0), err).After(last)
if failAt == 28 {
return true
}
if failAt == 29 {
err = failWith
}
- last = protocol.EXPECT().ReadStructBegin().Return("TestStruct", err).After(last)
+ last = protocol.EXPECT().ReadStructBegin(context.Background()).Return("TestStruct", err).After(last)
if failAt == 29 {
return true
}
if failAt == 30 {
err = failWith
}
- last = protocol.EXPECT().ReadFieldBegin().Return("m", thrift.TType(thrift.MAP), int16(1), err).After(last)
+ last = protocol.EXPECT().ReadFieldBegin(context.Background()).Return("m", thrift.TType(thrift.MAP), int16(1), err).After(last)
if failAt == 30 {
return true
}
if failAt == 31 {
err = failWith
}
- last = protocol.EXPECT().ReadMapBegin().Return(thrift.TType(thrift.STRING), thrift.TType(thrift.STRING), 0, err).After(last)
+ last = protocol.EXPECT().ReadMapBegin(context.Background()).Return(thrift.TType(thrift.STRING), thrift.TType(thrift.STRING), 0, err).After(last)
if failAt == 31 {
return true
}
if failAt == 32 {
err = failWith
}
- last = protocol.EXPECT().ReadMapEnd().Return(err).After(last)
+ last = protocol.EXPECT().ReadMapEnd(context.Background()).Return(err).After(last)
if failAt == 32 {
return true
}
if failAt == 33 {
err = failWith
}
- last = protocol.EXPECT().ReadFieldEnd().Return(err).After(last)
+ last = protocol.EXPECT().ReadFieldEnd(context.Background()).Return(err).After(last)
if failAt == 33 {
return true
}
if failAt == 34 {
err = failWith
}
- last = protocol.EXPECT().ReadFieldBegin().Return("l", thrift.TType(thrift.LIST), int16(2), err).After(last)
+ last = protocol.EXPECT().ReadFieldBegin(context.Background()).Return("l", thrift.TType(thrift.LIST), int16(2), err).After(last)
if failAt == 34 {
return true
}
if failAt == 35 {
err = failWith
}
- last = protocol.EXPECT().ReadListBegin().Return(thrift.TType(thrift.STRING), 0, err).After(last)
+ last = protocol.EXPECT().ReadListBegin(context.Background()).Return(thrift.TType(thrift.STRING), 0, err).After(last)
if failAt == 35 {
return true
}
if failAt == 36 {
err = failWith
}
- last = protocol.EXPECT().ReadListEnd().Return(err).After(last)
+ last = protocol.EXPECT().ReadListEnd(context.Background()).Return(err).After(last)
if failAt == 36 {
return true
}
if failAt == 37 {
err = failWith
}
- last = protocol.EXPECT().ReadFieldEnd().Return(err).After(last)
+ last = protocol.EXPECT().ReadFieldEnd(context.Background()).Return(err).After(last)
if failAt == 37 {
return true
}
if failAt == 38 {
err = failWith
}
- last = protocol.EXPECT().ReadFieldBegin().Return("s", thrift.TType(thrift.SET), int16(3), err).After(last)
+ last = protocol.EXPECT().ReadFieldBegin(context.Background()).Return("s", thrift.TType(thrift.SET), int16(3), err).After(last)
if failAt == 38 {
return true
}
if failAt == 39 {
err = failWith
}
- last = protocol.EXPECT().ReadSetBegin().Return(thrift.TType(thrift.STRING), 0, err).After(last)
+ last = protocol.EXPECT().ReadSetBegin(context.Background()).Return(thrift.TType(thrift.STRING), 0, err).After(last)
if failAt == 39 {
return true
}
if failAt == 40 {
err = failWith
}
- last = protocol.EXPECT().ReadSetEnd().Return(err).After(last)
+ last = protocol.EXPECT().ReadSetEnd(context.Background()).Return(err).After(last)
if failAt == 40 {
return true
}
if failAt == 41 {
err = failWith
}
- last = protocol.EXPECT().ReadFieldEnd().Return(err).After(last)
+ last = protocol.EXPECT().ReadFieldEnd(context.Background()).Return(err).After(last)
if failAt == 41 {
return true
}
if failAt == 42 {
err = failWith
}
- last = protocol.EXPECT().ReadFieldBegin().Return("i", thrift.TType(thrift.I32), int16(4), err).After(last)
+ last = protocol.EXPECT().ReadFieldBegin(context.Background()).Return("i", thrift.TType(thrift.I32), int16(4), err).After(last)
if failAt == 42 {
return true
}
if failAt == 43 {
err = failWith
}
- last = protocol.EXPECT().ReadI32().Return(int32(3), err).After(last)
+ last = protocol.EXPECT().ReadI32(context.Background()).Return(int32(3), err).After(last)
if failAt == 43 {
return true
}
if failAt == 44 {
err = failWith
}
- last = protocol.EXPECT().ReadFieldEnd().Return(err).After(last)
+ last = protocol.EXPECT().ReadFieldEnd(context.Background()).Return(err).After(last)
if failAt == 44 {
return true
}
if failAt == 45 {
err = failWith
}
- last = protocol.EXPECT().ReadFieldBegin().Return("_", thrift.TType(thrift.STOP), int16(5), err).After(last)
+ last = protocol.EXPECT().ReadFieldBegin(context.Background()).Return("_", thrift.TType(thrift.STOP), int16(5), err).After(last)
if failAt == 45 {
return true
}
if failAt == 46 {
err = failWith
}
- last = protocol.EXPECT().ReadStructEnd().Return(err).After(last)
+ last = protocol.EXPECT().ReadStructEnd(context.Background()).Return(err).After(last)
if failAt == 46 {
return true
}
if failAt == 47 {
err = failWith
}
- last = protocol.EXPECT().ReadFieldEnd().Return(err).After(last)
+ last = protocol.EXPECT().ReadFieldEnd(context.Background()).Return(err).After(last)
if failAt == 47 {
return true
}
if failAt == 48 {
err = failWith
}
- last = protocol.EXPECT().ReadFieldBegin().Return("_", thrift.TType(thrift.STOP), int16(1), err).After(last)
+ last = protocol.EXPECT().ReadFieldBegin(context.Background()).Return("_", thrift.TType(thrift.STOP), int16(1), err).After(last)
if failAt == 48 {
return true
}
if failAt == 49 {
err = failWith
}
- last = protocol.EXPECT().ReadStructEnd().Return(err).After(last)
+ last = protocol.EXPECT().ReadStructEnd(context.Background()).Return(err).After(last)
if failAt == 49 {
return true
}
if failAt == 50 {
err = failWith
}
- last = protocol.EXPECT().ReadMessageEnd().Return(err).After(last)
+ last = protocol.EXPECT().ReadMessageEnd(context.Background()).Return(err).After(last)
if failAt == 50 {
return true
}
@@ -529,91 +529,91 @@ func prepareClientCallException(protocol *MockTProtocol, failAt int, failWith er
var err error = nil
// No need to test failure in this block, because it is covered in other test cases
- last := protocol.EXPECT().WriteMessageBegin("testString", thrift.CALL, int32(1))
- last = protocol.EXPECT().WriteStructBegin("testString_args").After(last)
- last = protocol.EXPECT().WriteFieldBegin("s", thrift.TType(thrift.STRING), int16(1)).After(last)
- last = protocol.EXPECT().WriteString("test").After(last)
- last = protocol.EXPECT().WriteFieldEnd().After(last)
- last = protocol.EXPECT().WriteFieldStop().After(last)
- last = protocol.EXPECT().WriteStructEnd().After(last)
- last = protocol.EXPECT().WriteMessageEnd().After(last)
+ last := protocol.EXPECT().WriteMessageBegin(context.Background(), "testString", thrift.CALL, int32(1))
+ last = protocol.EXPECT().WriteStructBegin(context.Background(), "testString_args").After(last)
+ last = protocol.EXPECT().WriteFieldBegin(context.Background(), "s", thrift.TType(thrift.STRING), int16(1)).After(last)
+ last = protocol.EXPECT().WriteString(context.Background(), "test").After(last)
+ last = protocol.EXPECT().WriteFieldEnd(context.Background()).After(last)
+ last = protocol.EXPECT().WriteFieldStop(context.Background()).After(last)
+ last = protocol.EXPECT().WriteStructEnd(context.Background()).After(last)
+ last = protocol.EXPECT().WriteMessageEnd(context.Background()).After(last)
last = protocol.EXPECT().Flush(context.Background()).After(last)
// Reading the exception, might fail.
if failAt == 0 {
err = failWith
}
- last = protocol.EXPECT().ReadMessageBegin().Return("testString", thrift.EXCEPTION, int32(1), err).After(last)
+ last = protocol.EXPECT().ReadMessageBegin(context.Background()).Return("testString", thrift.EXCEPTION, int32(1), err).After(last)
if failAt == 0 {
return true
}
if failAt == 1 {
err = failWith
}
- last = protocol.EXPECT().ReadStructBegin().Return("TApplicationException", err).After(last)
+ last = protocol.EXPECT().ReadStructBegin(context.Background()).Return("TApplicationException", err).After(last)
if failAt == 1 {
return true
}
if failAt == 2 {
err = failWith
}
- last = protocol.EXPECT().ReadFieldBegin().Return("message", thrift.TType(thrift.STRING), int16(1), err).After(last)
+ last = protocol.EXPECT().ReadFieldBegin(context.Background()).Return("message", thrift.TType(thrift.STRING), int16(1), err).After(last)
if failAt == 2 {
return true
}
if failAt == 3 {
err = failWith
}
- last = protocol.EXPECT().ReadString().Return("test", err).After(last)
+ last = protocol.EXPECT().ReadString(context.Background()).Return("test", err).After(last)
if failAt == 3 {
return true
}
if failAt == 4 {
err = failWith
}
- last = protocol.EXPECT().ReadFieldEnd().Return(err).After(last)
+ last = protocol.EXPECT().ReadFieldEnd(context.Background()).Return(err).After(last)
if failAt == 4 {
return true
}
if failAt == 5 {
err = failWith
}
- last = protocol.EXPECT().ReadFieldBegin().Return("type", thrift.TType(thrift.I32), int16(2), err).After(last)
+ last = protocol.EXPECT().ReadFieldBegin(context.Background()).Return("type", thrift.TType(thrift.I32), int16(2), err).After(last)
if failAt == 5 {
return true
}
if failAt == 6 {
err = failWith
}
- last = protocol.EXPECT().ReadI32().Return(int32(thrift.PROTOCOL_ERROR), err).After(last)
+ last = protocol.EXPECT().ReadI32(context.Background()).Return(int32(thrift.PROTOCOL_ERROR), err).After(last)
if failAt == 6 {
return true
}
if failAt == 7 {
err = failWith
}
- last = protocol.EXPECT().ReadFieldEnd().Return(err).After(last)
+ last = protocol.EXPECT().ReadFieldEnd(context.Background()).Return(err).After(last)
if failAt == 7 {
return true
}
if failAt == 8 {
err = failWith
}
- last = protocol.EXPECT().ReadFieldBegin().Return("_", thrift.TType(thrift.STOP), int16(2), err).After(last)
+ last = protocol.EXPECT().ReadFieldBegin(context.Background()).Return("_", thrift.TType(thrift.STOP), int16(2), err).After(last)
if failAt == 8 {
return true
}
if failAt == 9 {
err = failWith
}
- last = protocol.EXPECT().ReadStructEnd().Return(err).After(last)
+ last = protocol.EXPECT().ReadStructEnd(context.Background()).Return(err).After(last)
if failAt == 9 {
return true
}
if failAt == 10 {
err = failWith
}
- last = protocol.EXPECT().ReadMessageEnd().Return(err).After(last)
+ last = protocol.EXPECT().ReadMessageEnd(context.Background()).Return(err).After(last)
if failAt == 10 {
return true
}
@@ -697,16 +697,16 @@ func TestClientSeqIdMismatch(t *testing.T) {
mockCtrl := gomock.NewController(t)
protocol := NewMockTProtocol(mockCtrl)
gomock.InOrder(
- protocol.EXPECT().WriteMessageBegin("testString", thrift.CALL, int32(1)),
- protocol.EXPECT().WriteStructBegin("testString_args"),
- protocol.EXPECT().WriteFieldBegin("s", thrift.TType(thrift.STRING), int16(1)),
- protocol.EXPECT().WriteString("test"),
- protocol.EXPECT().WriteFieldEnd(),
- protocol.EXPECT().WriteFieldStop(),
- protocol.EXPECT().WriteStructEnd(),
- protocol.EXPECT().WriteMessageEnd(),
+ protocol.EXPECT().WriteMessageBegin(context.Background(), "testString", thrift.CALL, int32(1)),
+ protocol.EXPECT().WriteStructBegin(context.Background(), "testString_args"),
+ protocol.EXPECT().WriteFieldBegin(context.Background(), "s", thrift.TType(thrift.STRING), int16(1)),
+ protocol.EXPECT().WriteString(context.Background(), "test"),
+ protocol.EXPECT().WriteFieldEnd(context.Background()),
+ protocol.EXPECT().WriteFieldStop(context.Background()),
+ protocol.EXPECT().WriteStructEnd(context.Background()),
+ protocol.EXPECT().WriteMessageEnd(context.Background()),
protocol.EXPECT().Flush(context.Background()),
- protocol.EXPECT().ReadMessageBegin().Return("testString", thrift.REPLY, int32(2), nil),
+ protocol.EXPECT().ReadMessageBegin(context.Background()).Return("testString", thrift.REPLY, int32(2), nil),
)
client := errortest.NewErrorTestClient(thrift.NewTStandardClient(protocol, protocol))
@@ -728,16 +728,16 @@ func TestClientSeqIdMismatchLegeacy(t *testing.T) {
transport := thrift.NewTMemoryBuffer()
protocol := NewMockTProtocol(mockCtrl)
gomock.InOrder(
- protocol.EXPECT().WriteMessageBegin("testString", thrift.CALL, int32(1)),
- protocol.EXPECT().WriteStructBegin("testString_args"),
- protocol.EXPECT().WriteFieldBegin("s", thrift.TType(thrift.STRING), int16(1)),
- protocol.EXPECT().WriteString("test"),
- protocol.EXPECT().WriteFieldEnd(),
- protocol.EXPECT().WriteFieldStop(),
- protocol.EXPECT().WriteStructEnd(),
- protocol.EXPECT().WriteMessageEnd(),
+ protocol.EXPECT().WriteMessageBegin(context.Background(), "testString", thrift.CALL, int32(1)),
+ protocol.EXPECT().WriteStructBegin(context.Background(), "testString_args"),
+ protocol.EXPECT().WriteFieldBegin(context.Background(), "s", thrift.TType(thrift.STRING), int16(1)),
+ protocol.EXPECT().WriteString(context.Background(), "test"),
+ protocol.EXPECT().WriteFieldEnd(context.Background()),
+ protocol.EXPECT().WriteFieldStop(context.Background()),
+ protocol.EXPECT().WriteStructEnd(context.Background()),
+ protocol.EXPECT().WriteMessageEnd(context.Background()),
protocol.EXPECT().Flush(context.Background()),
- protocol.EXPECT().ReadMessageBegin().Return("testString", thrift.REPLY, int32(2), nil),
+ protocol.EXPECT().ReadMessageBegin(context.Background()).Return("testString", thrift.REPLY, int32(2), nil),
)
client := errortest.NewErrorTestClientProtocol(transport, protocol, protocol)
@@ -757,16 +757,16 @@ func TestClientWrongMethodName(t *testing.T) {
mockCtrl := gomock.NewController(t)
protocol := NewMockTProtocol(mockCtrl)
gomock.InOrder(
- protocol.EXPECT().WriteMessageBegin("testString", thrift.CALL, int32(1)),
- protocol.EXPECT().WriteStructBegin("testString_args"),
- protocol.EXPECT().WriteFieldBegin("s", thrift.TType(thrift.STRING), int16(1)),
- protocol.EXPECT().WriteString("test"),
- protocol.EXPECT().WriteFieldEnd(),
- protocol.EXPECT().WriteFieldStop(),
- protocol.EXPECT().WriteStructEnd(),
- protocol.EXPECT().WriteMessageEnd(),
+ protocol.EXPECT().WriteMessageBegin(context.Background(), "testString", thrift.CALL, int32(1)),
+ protocol.EXPECT().WriteStructBegin(context.Background(), "testString_args"),
+ protocol.EXPECT().WriteFieldBegin(context.Background(), "s", thrift.TType(thrift.STRING), int16(1)),
+ protocol.EXPECT().WriteString(context.Background(), "test"),
+ protocol.EXPECT().WriteFieldEnd(context.Background()),
+ protocol.EXPECT().WriteFieldStop(context.Background()),
+ protocol.EXPECT().WriteStructEnd(context.Background()),
+ protocol.EXPECT().WriteMessageEnd(context.Background()),
protocol.EXPECT().Flush(context.Background()),
- protocol.EXPECT().ReadMessageBegin().Return("unknown", thrift.REPLY, int32(1), nil),
+ protocol.EXPECT().ReadMessageBegin(context.Background()).Return("unknown", thrift.REPLY, int32(1), nil),
)
client := errortest.NewErrorTestClient(thrift.NewTStandardClient(protocol, protocol))
@@ -788,16 +788,16 @@ func TestClientWrongMethodNameLegacy(t *testing.T) {
transport := thrift.NewTMemoryBuffer()
protocol := NewMockTProtocol(mockCtrl)
gomock.InOrder(
- protocol.EXPECT().WriteMessageBegin("testString", thrift.CALL, int32(1)),
- protocol.EXPECT().WriteStructBegin("testString_args"),
- protocol.EXPECT().WriteFieldBegin("s", thrift.TType(thrift.STRING), int16(1)),
- protocol.EXPECT().WriteString("test"),
- protocol.EXPECT().WriteFieldEnd(),
- protocol.EXPECT().WriteFieldStop(),
- protocol.EXPECT().WriteStructEnd(),
- protocol.EXPECT().WriteMessageEnd(),
+ protocol.EXPECT().WriteMessageBegin(context.Background(), "testString", thrift.CALL, int32(1)),
+ protocol.EXPECT().WriteStructBegin(context.Background(), "testString_args"),
+ protocol.EXPECT().WriteFieldBegin(context.Background(), "s", thrift.TType(thrift.STRING), int16(1)),
+ protocol.EXPECT().WriteString(context.Background(), "test"),
+ protocol.EXPECT().WriteFieldEnd(context.Background()),
+ protocol.EXPECT().WriteFieldStop(context.Background()),
+ protocol.EXPECT().WriteStructEnd(context.Background()),
+ protocol.EXPECT().WriteMessageEnd(context.Background()),
protocol.EXPECT().Flush(context.Background()),
- protocol.EXPECT().ReadMessageBegin().Return("unknown", thrift.REPLY, int32(1), nil),
+ protocol.EXPECT().ReadMessageBegin(context.Background()).Return("unknown", thrift.REPLY, int32(1), nil),
)
client := errortest.NewErrorTestClientProtocol(transport, protocol, protocol)
@@ -817,16 +817,16 @@ func TestClientWrongMessageType(t *testing.T) {
mockCtrl := gomock.NewController(t)
protocol := NewMockTProtocol(mockCtrl)
gomock.InOrder(
- protocol.EXPECT().WriteMessageBegin("testString", thrift.CALL, int32(1)),
- protocol.EXPECT().WriteStructBegin("testString_args"),
- protocol.EXPECT().WriteFieldBegin("s", thrift.TType(thrift.STRING), int16(1)),
- protocol.EXPECT().WriteString("test"),
- protocol.EXPECT().WriteFieldEnd(),
- protocol.EXPECT().WriteFieldStop(),
- protocol.EXPECT().WriteStructEnd(),
- protocol.EXPECT().WriteMessageEnd(),
+ protocol.EXPECT().WriteMessageBegin(context.Background(), "testString", thrift.CALL, int32(1)),
+ protocol.EXPECT().WriteStructBegin(context.Background(), "testString_args"),
+ protocol.EXPECT().WriteFieldBegin(context.Background(), "s", thrift.TType(thrift.STRING), int16(1)),
+ protocol.EXPECT().WriteString(context.Background(), "test"),
+ protocol.EXPECT().WriteFieldEnd(context.Background()),
+ protocol.EXPECT().WriteFieldStop(context.Background()),
+ protocol.EXPECT().WriteStructEnd(context.Background()),
+ protocol.EXPECT().WriteMessageEnd(context.Background()),
protocol.EXPECT().Flush(context.Background()),
- protocol.EXPECT().ReadMessageBegin().Return("testString", thrift.INVALID_TMESSAGE_TYPE, int32(1), nil),
+ protocol.EXPECT().ReadMessageBegin(context.Background()).Return("testString", thrift.INVALID_TMESSAGE_TYPE, int32(1), nil),
)
client := errortest.NewErrorTestClient(thrift.NewTStandardClient(protocol, protocol))
@@ -848,16 +848,16 @@ func TestClientWrongMessageTypeLegacy(t *testing.T) {
transport := thrift.NewTMemoryBuffer()
protocol := NewMockTProtocol(mockCtrl)
gomock.InOrder(
- protocol.EXPECT().WriteMessageBegin("testString", thrift.CALL, int32(1)),
- protocol.EXPECT().WriteStructBegin("testString_args"),
- protocol.EXPECT().WriteFieldBegin("s", thrift.TType(thrift.STRING), int16(1)),
- protocol.EXPECT().WriteString("test"),
- protocol.EXPECT().WriteFieldEnd(),
- protocol.EXPECT().WriteFieldStop(),
- protocol.EXPECT().WriteStructEnd(),
- protocol.EXPECT().WriteMessageEnd(),
+ protocol.EXPECT().WriteMessageBegin(context.Background(), "testString", thrift.CALL, int32(1)),
+ protocol.EXPECT().WriteStructBegin(context.Background(), "testString_args"),
+ protocol.EXPECT().WriteFieldBegin(context.Background(), "s", thrift.TType(thrift.STRING), int16(1)),
+ protocol.EXPECT().WriteString(context.Background(), "test"),
+ protocol.EXPECT().WriteFieldEnd(context.Background()),
+ protocol.EXPECT().WriteFieldStop(context.Background()),
+ protocol.EXPECT().WriteStructEnd(context.Background()),
+ protocol.EXPECT().WriteMessageEnd(context.Background()),
protocol.EXPECT().Flush(context.Background()),
- protocol.EXPECT().ReadMessageBegin().Return("testString", thrift.INVALID_TMESSAGE_TYPE, int32(1), nil),
+ protocol.EXPECT().ReadMessageBegin(context.Background()).Return("testString", thrift.INVALID_TMESSAGE_TYPE, int32(1), nil),
)
client := errortest.NewErrorTestClientProtocol(transport, protocol, protocol)
diff --git a/lib/go/test/tests/optional_fields_test.go b/lib/go/test/tests/optional_fields_test.go
index 34ad6605a..7e240e6e2 100644
--- a/lib/go/test/tests/optional_fields_test.go
+++ b/lib/go/test/tests/optional_fields_test.go
@@ -21,6 +21,7 @@ package tests
import (
"bytes"
+ "context"
gomock "github.com/golang/mock/gomock"
"optionalfieldstest"
"testing"
@@ -185,12 +186,12 @@ func TestNoOptionalUnsetFieldsOnWire(t *testing.T) {
defer mockCtrl.Finish()
proto := NewMockTProtocol(mockCtrl)
gomock.InOrder(
- proto.EXPECT().WriteStructBegin("all_optional").Return(nil),
- proto.EXPECT().WriteFieldStop().Return(nil),
- proto.EXPECT().WriteStructEnd().Return(nil),
+ proto.EXPECT().WriteStructBegin(context.Background(), "all_optional").Return(nil),
+ proto.EXPECT().WriteFieldStop(context.Background()).Return(nil),
+ proto.EXPECT().WriteStructEnd(context.Background()).Return(nil),
)
ao := optionalfieldstest.NewAllOptional()
- ao.Write(proto)
+ ao.Write(context.Background(), proto)
}
func TestNoSetToDefaultFieldsOnWire(t *testing.T) {
@@ -198,13 +199,13 @@ func TestNoSetToDefaultFieldsOnWire(t *testing.T) {
defer mockCtrl.Finish()
proto := NewMockTProtocol(mockCtrl)
gomock.InOrder(
- proto.EXPECT().WriteStructBegin("all_optional").Return(nil),
- proto.EXPECT().WriteFieldStop().Return(nil),
- proto.EXPECT().WriteStructEnd().Return(nil),
+ proto.EXPECT().WriteStructBegin(context.Background(), "all_optional").Return(nil),
+ proto.EXPECT().WriteFieldStop(context.Background()).Return(nil),
+ proto.EXPECT().WriteStructEnd(context.Background()).Return(nil),
)
ao := optionalfieldstest.NewAllOptional()
ao.I = 42
- ao.Write(proto)
+ ao.Write(context.Background(), proto)
}
//Make sure that only one field is being serialized when set to non-default
@@ -213,16 +214,16 @@ func TestOneISetFieldOnWire(t *testing.T) {
defer mockCtrl.Finish()
proto := NewMockTProtocol(mockCtrl)
gomock.InOrder(
- proto.EXPECT().WriteStructBegin("all_optional").Return(nil),
- proto.EXPECT().WriteFieldBegin("i", thrift.TType(thrift.I64), int16(2)).Return(nil),
- proto.EXPECT().WriteI64(int64(123)).Return(nil),
- proto.EXPECT().WriteFieldEnd().Return(nil),
- proto.EXPECT().WriteFieldStop().Return(nil),
- proto.EXPECT().WriteStructEnd().Return(nil),
+ proto.EXPECT().WriteStructBegin(context.Background(), "all_optional").Return(nil),
+ proto.EXPECT().WriteFieldBegin(context.Background(), "i", thrift.TType(thrift.I64), int16(2)).Return(nil),
+ proto.EXPECT().WriteI64(context.Background(), int64(123)).Return(nil),
+ proto.EXPECT().WriteFieldEnd(context.Background()).Return(nil),
+ proto.EXPECT().WriteFieldStop(context.Background()).Return(nil),
+ proto.EXPECT().WriteStructEnd(context.Background()).Return(nil),
)
ao := optionalfieldstest.NewAllOptional()
ao.I = 123
- ao.Write(proto)
+ ao.Write(context.Background(), proto)
}
func TestOneLSetFieldOnWire(t *testing.T) {
@@ -230,19 +231,19 @@ func TestOneLSetFieldOnWire(t *testing.T) {
defer mockCtrl.Finish()
proto := NewMockTProtocol(mockCtrl)
gomock.InOrder(
- proto.EXPECT().WriteStructBegin("all_optional").Return(nil),
- proto.EXPECT().WriteFieldBegin("l", thrift.TType(thrift.LIST), int16(9)).Return(nil),
- proto.EXPECT().WriteListBegin(thrift.TType(thrift.I64), 2).Return(nil),
- proto.EXPECT().WriteI64(int64(1)).Return(nil),
- proto.EXPECT().WriteI64(int64(2)).Return(nil),
- proto.EXPECT().WriteListEnd().Return(nil),
- proto.EXPECT().WriteFieldEnd().Return(nil),
- proto.EXPECT().WriteFieldStop().Return(nil),
- proto.EXPECT().WriteStructEnd().Return(nil),
+ proto.EXPECT().WriteStructBegin(context.Background(), "all_optional").Return(nil),
+ proto.EXPECT().WriteFieldBegin(context.Background(), "l", thrift.TType(thrift.LIST), int16(9)).Return(nil),
+ proto.EXPECT().WriteListBegin(context.Background(), thrift.TType(thrift.I64), 2).Return(nil),
+ proto.EXPECT().WriteI64(context.Background(), int64(1)).Return(nil),
+ proto.EXPECT().WriteI64(context.Background(), int64(2)).Return(nil),
+ proto.EXPECT().WriteListEnd(context.Background()).Return(nil),
+ proto.EXPECT().WriteFieldEnd(context.Background()).Return(nil),
+ proto.EXPECT().WriteFieldStop(context.Background()).Return(nil),
+ proto.EXPECT().WriteStructEnd(context.Background()).Return(nil),
)
ao := optionalfieldstest.NewAllOptional()
ao.L = []int64{1, 2}
- ao.Write(proto)
+ ao.Write(context.Background(), proto)
}
func TestOneBinSetFieldOnWire(t *testing.T) {
@@ -250,16 +251,16 @@ func TestOneBinSetFieldOnWire(t *testing.T) {
defer mockCtrl.Finish()
proto := NewMockTProtocol(mockCtrl)
gomock.InOrder(
- proto.EXPECT().WriteStructBegin("all_optional").Return(nil),
- proto.EXPECT().WriteFieldBegin("bin", thrift.TType(thrift.STRING), int16(13)).Return(nil),
- proto.EXPECT().WriteBinary([]byte("somebytestring")).Return(nil),
- proto.EXPECT().WriteFieldEnd().Return(nil),
- proto.EXPECT().WriteFieldStop().Return(nil),
- proto.EXPECT().WriteStructEnd().Return(nil),
+ proto.EXPECT().WriteStructBegin(context.Background(), "all_optional").Return(nil),
+ proto.EXPECT().WriteFieldBegin(context.Background(), "bin", thrift.TType(thrift.STRING), int16(13)).Return(nil),
+ proto.EXPECT().WriteBinary(context.Background(), []byte("somebytestring")).Return(nil),
+ proto.EXPECT().WriteFieldEnd(context.Background()).Return(nil),
+ proto.EXPECT().WriteFieldStop(context.Background()).Return(nil),
+ proto.EXPECT().WriteStructEnd(context.Background()).Return(nil),
)
ao := optionalfieldstest.NewAllOptional()
ao.Bin = []byte("somebytestring")
- ao.Write(proto)
+ ao.Write(context.Background(), proto)
}
func TestOneEmptyBinSetFieldOnWire(t *testing.T) {
@@ -267,14 +268,14 @@ func TestOneEmptyBinSetFieldOnWire(t *testing.T) {
defer mockCtrl.Finish()
proto := NewMockTProtocol(mockCtrl)
gomock.InOrder(
- proto.EXPECT().WriteStructBegin("all_optional").Return(nil),
- proto.EXPECT().WriteFieldBegin("bin", thrift.TType(thrift.STRING), int16(13)).Return(nil),
- proto.EXPECT().WriteBinary([]byte{}).Return(nil),
- proto.EXPECT().WriteFieldEnd().Return(nil),
- proto.EXPECT().WriteFieldStop().Return(nil),
- proto.EXPECT().WriteStructEnd().Return(nil),
+ proto.EXPECT().WriteStructBegin(context.Background(), "all_optional").Return(nil),
+ proto.EXPECT().WriteFieldBegin(context.Background(), "bin", thrift.TType(thrift.STRING), int16(13)).Return(nil),
+ proto.EXPECT().WriteBinary(context.Background(), []byte{}).Return(nil),
+ proto.EXPECT().WriteFieldEnd(context.Background()).Return(nil),
+ proto.EXPECT().WriteFieldStop(context.Background()).Return(nil),
+ proto.EXPECT().WriteStructEnd(context.Background()).Return(nil),
)
ao := optionalfieldstest.NewAllOptional()
ao.Bin = []byte{}
- ao.Write(proto)
+ ao.Write(context.Background(), proto)
}
diff --git a/lib/go/test/tests/protocol_mock.go b/lib/go/test/tests/protocol_mock.go
index 51d7a02ff..793e4e1c0 100644
--- a/lib/go/test/tests/protocol_mock.go
+++ b/lib/go/test/tests/protocol_mock.go
@@ -60,52 +60,52 @@ func (_mr *_MockTProtocolRecorder) Flush(ctx context.Context) *gomock.Call {
return _mr.mock.ctrl.RecordCall(_mr.mock, "Flush")
}
-func (_m *MockTProtocol) ReadBinary() ([]byte, error) {
- ret := _m.ctrl.Call(_m, "ReadBinary")
+func (_m *MockTProtocol) ReadBinary(ctx context.Context) ([]byte, error) {
+ ret := _m.ctrl.Call(_m, "ReadBinary", ctx)
ret0, _ := ret[0].([]byte)
ret1, _ := ret[1].(error)
return ret0, ret1
}
-func (_mr *_MockTProtocolRecorder) ReadBinary() *gomock.Call {
- return _mr.mock.ctrl.RecordCall(_mr.mock, "ReadBinary")
+func (_mr *_MockTProtocolRecorder) ReadBinary(ctx context.Context) *gomock.Call {
+ return _mr.mock.ctrl.RecordCall(_mr.mock, "ReadBinary", ctx)
}
-func (_m *MockTProtocol) ReadBool() (bool, error) {
- ret := _m.ctrl.Call(_m, "ReadBool")
+func (_m *MockTProtocol) ReadBool(ctx context.Context) (bool, error) {
+ ret := _m.ctrl.Call(_m, "ReadBool", ctx)
ret0, _ := ret[0].(bool)
ret1, _ := ret[1].(error)
return ret0, ret1
}
-func (_mr *_MockTProtocolRecorder) ReadBool() *gomock.Call {
- return _mr.mock.ctrl.RecordCall(_mr.mock, "ReadBool")
+func (_mr *_MockTProtocolRecorder) ReadBool(ctx context.Context) *gomock.Call {
+ return _mr.mock.ctrl.RecordCall(_mr.mock, "ReadBool", ctx)
}
-func (_m *MockTProtocol) ReadByte() (int8, error) {
- ret := _m.ctrl.Call(_m, "ReadByte")
+func (_m *MockTProtocol) ReadByte(ctx context.Context) (int8, error) {
+ ret := _m.ctrl.Call(_m, "ReadByte", ctx)
ret0, _ := ret[0].(int8)
ret1, _ := ret[1].(error)
return ret0, ret1
}
-func (_mr *_MockTProtocolRecorder) ReadByte() *gomock.Call {
- return _mr.mock.ctrl.RecordCall(_mr.mock, "ReadByte")
+func (_mr *_MockTProtocolRecorder) ReadByte(ctx context.Context) *gomock.Call {
+ return _mr.mock.ctrl.RecordCall(_mr.mock, "ReadByte", ctx)
}
-func (_m *MockTProtocol) ReadDouble() (float64, error) {
- ret := _m.ctrl.Call(_m, "ReadDouble")
+func (_m *MockTProtocol) ReadDouble(ctx context.Context) (float64, error) {
+ ret := _m.ctrl.Call(_m, "ReadDouble", ctx)
ret0, _ := ret[0].(float64)
ret1, _ := ret[1].(error)
return ret0, ret1
}
-func (_mr *_MockTProtocolRecorder) ReadDouble() *gomock.Call {
- return _mr.mock.ctrl.RecordCall(_mr.mock, "ReadDouble")
+func (_mr *_MockTProtocolRecorder) ReadDouble(ctx context.Context) *gomock.Call {
+ return _mr.mock.ctrl.RecordCall(_mr.mock, "ReadDouble", ctx)
}
-func (_m *MockTProtocol) ReadFieldBegin() (string, thrift.TType, int16, error) {
- ret := _m.ctrl.Call(_m, "ReadFieldBegin")
+func (_m *MockTProtocol) ReadFieldBegin(ctx context.Context) (string, thrift.TType, int16, error) {
+ ret := _m.ctrl.Call(_m, "ReadFieldBegin", ctx)
ret0, _ := ret[0].(string)
ret1, _ := ret[1].(thrift.TType)
ret2, _ := ret[2].(int16)
@@ -113,77 +113,77 @@ func (_m *MockTProtocol) ReadFieldBegin() (string, thrift.TType, int16, error) {
return ret0, ret1, ret2, ret3
}
-func (_mr *_MockTProtocolRecorder) ReadFieldBegin() *gomock.Call {
- return _mr.mock.ctrl.RecordCall(_mr.mock, "ReadFieldBegin")
+func (_mr *_MockTProtocolRecorder) ReadFieldBegin(ctx context.Context) *gomock.Call {
+ return _mr.mock.ctrl.RecordCall(_mr.mock, "ReadFieldBegin", ctx)
}
-func (_m *MockTProtocol) ReadFieldEnd() error {
- ret := _m.ctrl.Call(_m, "ReadFieldEnd")
+func (_m *MockTProtocol) ReadFieldEnd(ctx context.Context) error {
+ ret := _m.ctrl.Call(_m, "ReadFieldEnd", ctx)
ret0, _ := ret[0].(error)
return ret0
}
-func (_mr *_MockTProtocolRecorder) ReadFieldEnd() *gomock.Call {
- return _mr.mock.ctrl.RecordCall(_mr.mock, "ReadFieldEnd")
+func (_mr *_MockTProtocolRecorder) ReadFieldEnd(ctx context.Context) *gomock.Call {
+ return _mr.mock.ctrl.RecordCall(_mr.mock, "ReadFieldEnd", ctx)
}
-func (_m *MockTProtocol) ReadI16() (int16, error) {
- ret := _m.ctrl.Call(_m, "ReadI16")
+func (_m *MockTProtocol) ReadI16(ctx context.Context) (int16, error) {
+ ret := _m.ctrl.Call(_m, "ReadI16", ctx)
ret0, _ := ret[0].(int16)
ret1, _ := ret[1].(error)
return ret0, ret1
}
-func (_mr *_MockTProtocolRecorder) ReadI16() *gomock.Call {
- return _mr.mock.ctrl.RecordCall(_mr.mock, "ReadI16")
+func (_mr *_MockTProtocolRecorder) ReadI16(ctx context.Context) *gomock.Call {
+ return _mr.mock.ctrl.RecordCall(_mr.mock, "ReadI16", ctx)
}
-func (_m *MockTProtocol) ReadI32() (int32, error) {
- ret := _m.ctrl.Call(_m, "ReadI32")
+func (_m *MockTProtocol) ReadI32(ctx context.Context) (int32, error) {
+ ret := _m.ctrl.Call(_m, "ReadI32", ctx)
ret0, _ := ret[0].(int32)
ret1, _ := ret[1].(error)
return ret0, ret1
}
-func (_mr *_MockTProtocolRecorder) ReadI32() *gomock.Call {
- return _mr.mock.ctrl.RecordCall(_mr.mock, "ReadI32")
+func (_mr *_MockTProtocolRecorder) ReadI32(ctx context.Context) *gomock.Call {
+ return _mr.mock.ctrl.RecordCall(_mr.mock, "ReadI32", ctx)
}
-func (_m *MockTProtocol) ReadI64() (int64, error) {
- ret := _m.ctrl.Call(_m, "ReadI64")
+func (_m *MockTProtocol) ReadI64(ctx context.Context) (int64, error) {
+ ret := _m.ctrl.Call(_m, "ReadI64", ctx)
ret0, _ := ret[0].(int64)
ret1, _ := ret[1].(error)
return ret0, ret1
}
-func (_mr *_MockTProtocolRecorder) ReadI64() *gomock.Call {
- return _mr.mock.ctrl.RecordCall(_mr.mock, "ReadI64")
+func (_mr *_MockTProtocolRecorder) ReadI64(ctx context.Context) *gomock.Call {
+ return _mr.mock.ctrl.RecordCall(_mr.mock, "ReadI64", ctx)
}
-func (_m *MockTProtocol) ReadListBegin() (thrift.TType, int, error) {
- ret := _m.ctrl.Call(_m, "ReadListBegin")
+func (_m *MockTProtocol) ReadListBegin(ctx context.Context) (thrift.TType, int, error) {
+ ret := _m.ctrl.Call(_m, "ReadListBegin", ctx)
ret0, _ := ret[0].(thrift.TType)
ret1, _ := ret[1].(int)
ret2, _ := ret[2].(error)
return ret0, ret1, ret2
}
-func (_mr *_MockTProtocolRecorder) ReadListBegin() *gomock.Call {
- return _mr.mock.ctrl.RecordCall(_mr.mock, "ReadListBegin")
+func (_mr *_MockTProtocolRecorder) ReadListBegin(ctx context.Context) *gomock.Call {
+ return _mr.mock.ctrl.RecordCall(_mr.mock, "ReadListBegin", ctx)
}
-func (_m *MockTProtocol) ReadListEnd() error {
- ret := _m.ctrl.Call(_m, "ReadListEnd")
+func (_m *MockTProtocol) ReadListEnd(ctx context.Context) error {
+ ret := _m.ctrl.Call(_m, "ReadListEnd", ctx)
ret0, _ := ret[0].(error)
return ret0
}
-func (_mr *_MockTProtocolRecorder) ReadListEnd() *gomock.Call {
- return _mr.mock.ctrl.RecordCall(_mr.mock, "ReadListEnd")
+func (_mr *_MockTProtocolRecorder) ReadListEnd(ctx context.Context) *gomock.Call {
+ return _mr.mock.ctrl.RecordCall(_mr.mock, "ReadListEnd", ctx)
}
-func (_m *MockTProtocol) ReadMapBegin() (thrift.TType, thrift.TType, int, error) {
- ret := _m.ctrl.Call(_m, "ReadMapBegin")
+func (_m *MockTProtocol) ReadMapBegin(ctx context.Context) (thrift.TType, thrift.TType, int, error) {
+ ret := _m.ctrl.Call(_m, "ReadMapBegin", ctx)
ret0, _ := ret[0].(thrift.TType)
ret1, _ := ret[1].(thrift.TType)
ret2, _ := ret[2].(int)
@@ -191,22 +191,22 @@ func (_m *MockTProtocol) ReadMapBegin() (thrift.TType, thrift.TType, int, error)
return ret0, ret1, ret2, ret3
}
-func (_mr *_MockTProtocolRecorder) ReadMapBegin() *gomock.Call {
- return _mr.mock.ctrl.RecordCall(_mr.mock, "ReadMapBegin")
+func (_mr *_MockTProtocolRecorder) ReadMapBegin(ctx context.Context) *gomock.Call {
+ return _mr.mock.ctrl.RecordCall(_mr.mock, "ReadMapBegin", ctx)
}
-func (_m *MockTProtocol) ReadMapEnd() error {
- ret := _m.ctrl.Call(_m, "ReadMapEnd")
+func (_m *MockTProtocol) ReadMapEnd(ctx context.Context) error {
+ ret := _m.ctrl.Call(_m, "ReadMapEnd", ctx)
ret0, _ := ret[0].(error)
return ret0
}
-func (_mr *_MockTProtocolRecorder) ReadMapEnd() *gomock.Call {
- return _mr.mock.ctrl.RecordCall(_mr.mock, "ReadMapEnd")
+func (_mr *_MockTProtocolRecorder) ReadMapEnd(ctx context.Context) *gomock.Call {
+ return _mr.mock.ctrl.RecordCall(_mr.mock, "ReadMapEnd", ctx)
}
-func (_m *MockTProtocol) ReadMessageBegin() (string, thrift.TMessageType, int32, error) {
- ret := _m.ctrl.Call(_m, "ReadMessageBegin")
+func (_m *MockTProtocol) ReadMessageBegin(ctx context.Context) (string, thrift.TMessageType, int32, error) {
+ ret := _m.ctrl.Call(_m, "ReadMessageBegin", ctx)
ret0, _ := ret[0].(string)
ret1, _ := ret[1].(thrift.TMessageType)
ret2, _ := ret[2].(int32)
@@ -214,82 +214,82 @@ func (_m *MockTProtocol) ReadMessageBegin() (string, thrift.TMessageType, int32,
return ret0, ret1, ret2, ret3
}
-func (_mr *_MockTProtocolRecorder) ReadMessageBegin() *gomock.Call {
- return _mr.mock.ctrl.RecordCall(_mr.mock, "ReadMessageBegin")
+func (_mr *_MockTProtocolRecorder) ReadMessageBegin(ctx context.Context) *gomock.Call {
+ return _mr.mock.ctrl.RecordCall(_mr.mock, "ReadMessageBegin", ctx)
}
-func (_m *MockTProtocol) ReadMessageEnd() error {
- ret := _m.ctrl.Call(_m, "ReadMessageEnd")
+func (_m *MockTProtocol) ReadMessageEnd(ctx context.Context) error {
+ ret := _m.ctrl.Call(_m, "ReadMessageEnd", ctx)
ret0, _ := ret[0].(error)
return ret0
}
-func (_mr *_MockTProtocolRecorder) ReadMessageEnd() *gomock.Call {
- return _mr.mock.ctrl.RecordCall(_mr.mock, "ReadMessageEnd")
+func (_mr *_MockTProtocolRecorder) ReadMessageEnd(ctx context.Context) *gomock.Call {
+ return _mr.mock.ctrl.RecordCall(_mr.mock, "ReadMessageEnd", ctx)
}
-func (_m *MockTProtocol) ReadSetBegin() (thrift.TType, int, error) {
- ret := _m.ctrl.Call(_m, "ReadSetBegin")
+func (_m *MockTProtocol) ReadSetBegin(ctx context.Context) (thrift.TType, int, error) {
+ ret := _m.ctrl.Call(_m, "ReadSetBegin", ctx)
ret0, _ := ret[0].(thrift.TType)
ret1, _ := ret[1].(int)
ret2, _ := ret[2].(error)
return ret0, ret1, ret2
}
-func (_mr *_MockTProtocolRecorder) ReadSetBegin() *gomock.Call {
- return _mr.mock.ctrl.RecordCall(_mr.mock, "ReadSetBegin")
+func (_mr *_MockTProtocolRecorder) ReadSetBegin(ctx context.Context) *gomock.Call {
+ return _mr.mock.ctrl.RecordCall(_mr.mock, "ReadSetBegin", ctx)
}
-func (_m *MockTProtocol) ReadSetEnd() error {
- ret := _m.ctrl.Call(_m, "ReadSetEnd")
+func (_m *MockTProtocol) ReadSetEnd(ctx context.Context) error {
+ ret := _m.ctrl.Call(_m, "ReadSetEnd", ctx)
ret0, _ := ret[0].(error)
return ret0
}
-func (_mr *_MockTProtocolRecorder) ReadSetEnd() *gomock.Call {
- return _mr.mock.ctrl.RecordCall(_mr.mock, "ReadSetEnd")
+func (_mr *_MockTProtocolRecorder) ReadSetEnd(ctx context.Context) *gomock.Call {
+ return _mr.mock.ctrl.RecordCall(_mr.mock, "ReadSetEnd", ctx)
}
-func (_m *MockTProtocol) ReadString() (string, error) {
- ret := _m.ctrl.Call(_m, "ReadString")
+func (_m *MockTProtocol) ReadString(ctx context.Context) (string, error) {
+ ret := _m.ctrl.Call(_m, "ReadString", ctx)
ret0, _ := ret[0].(string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
-func (_mr *_MockTProtocolRecorder) ReadString() *gomock.Call {
- return _mr.mock.ctrl.RecordCall(_mr.mock, "ReadString")
+func (_mr *_MockTProtocolRecorder) ReadString(ctx context.Context) *gomock.Call {
+ return _mr.mock.ctrl.RecordCall(_mr.mock, "ReadString", ctx)
}
-func (_m *MockTProtocol) ReadStructBegin() (string, error) {
- ret := _m.ctrl.Call(_m, "ReadStructBegin")
+func (_m *MockTProtocol) ReadStructBegin(ctx context.Context) (string, error) {
+ ret := _m.ctrl.Call(_m, "ReadStructBegin", ctx)
ret0, _ := ret[0].(string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
-func (_mr *_MockTProtocolRecorder) ReadStructBegin() *gomock.Call {
- return _mr.mock.ctrl.RecordCall(_mr.mock, "ReadStructBegin")
+func (_mr *_MockTProtocolRecorder) ReadStructBegin(ctx context.Context) *gomock.Call {
+ return _mr.mock.ctrl.RecordCall(_mr.mock, "ReadStructBegin", ctx)
}
-func (_m *MockTProtocol) ReadStructEnd() error {
- ret := _m.ctrl.Call(_m, "ReadStructEnd")
+func (_m *MockTProtocol) ReadStructEnd(ctx context.Context) error {
+ ret := _m.ctrl.Call(_m, "ReadStructEnd", ctx)
ret0, _ := ret[0].(error)
return ret0
}
-func (_mr *_MockTProtocolRecorder) ReadStructEnd() *gomock.Call {
- return _mr.mock.ctrl.RecordCall(_mr.mock, "ReadStructEnd")
+func (_mr *_MockTProtocolRecorder) ReadStructEnd(ctx context.Context) *gomock.Call {
+ return _mr.mock.ctrl.RecordCall(_mr.mock, "ReadStructEnd", ctx)
}
-func (_m *MockTProtocol) Skip(_param0 thrift.TType) error {
- ret := _m.ctrl.Call(_m, "Skip", _param0)
+func (_m *MockTProtocol) Skip(ctx context.Context, _param0 thrift.TType) error {
+ ret := _m.ctrl.Call(_m, "Skip", ctx, _param0)
ret0, _ := ret[0].(error)
return ret0
}
-func (_mr *_MockTProtocolRecorder) Skip(arg0 interface{}) *gomock.Call {
- return _mr.mock.ctrl.RecordCall(_mr.mock, "Skip", arg0)
+func (_mr *_MockTProtocolRecorder) Skip(ctx context.Context, arg0 interface{}) *gomock.Call {
+ return _mr.mock.ctrl.RecordCall(_mr.mock, "Skip", ctx, arg0)
}
func (_m *MockTProtocol) Transport() thrift.TTransport {
@@ -302,212 +302,212 @@ func (_mr *_MockTProtocolRecorder) Transport() *gomock.Call {
return _mr.mock.ctrl.RecordCall(_mr.mock, "Transport")
}
-func (_m *MockTProtocol) WriteBinary(_param0 []byte) error {
- ret := _m.ctrl.Call(_m, "WriteBinary", _param0)
+func (_m *MockTProtocol) WriteBinary(ctx context.Context, _param0 []byte) error {
+ ret := _m.ctrl.Call(_m, "WriteBinary", ctx, _param0)
ret0, _ := ret[0].(error)
return ret0
}
-func (_mr *_MockTProtocolRecorder) WriteBinary(arg0 interface{}) *gomock.Call {
- return _mr.mock.ctrl.RecordCall(_mr.mock, "WriteBinary", arg0)
+func (_mr *_MockTProtocolRecorder) WriteBinary(ctx context.Context, arg0 interface{}) *gomock.Call {
+ return _mr.mock.ctrl.RecordCall(_mr.mock, "WriteBinary", ctx, arg0)
}
-func (_m *MockTProtocol) WriteBool(_param0 bool) error {
- ret := _m.ctrl.Call(_m, "WriteBool", _param0)
+func (_m *MockTProtocol) WriteBool(ctx context.Context, _param0 bool) error {
+ ret := _m.ctrl.Call(_m, "WriteBool", ctx, _param0)
ret0, _ := ret[0].(error)
return ret0
}
-func (_mr *_MockTProtocolRecorder) WriteBool(arg0 interface{}) *gomock.Call {
- return _mr.mock.ctrl.RecordCall(_mr.mock, "WriteBool", arg0)
+func (_mr *_MockTProtocolRecorder) WriteBool(ctx context.Context, arg0 interface{}) *gomock.Call {
+ return _mr.mock.ctrl.RecordCall(_mr.mock, "WriteBool", ctx, arg0)
}
-func (_m *MockTProtocol) WriteByte(_param0 int8) error {
- ret := _m.ctrl.Call(_m, "WriteByte", _param0)
+func (_m *MockTProtocol) WriteByte(ctx context.Context, _param0 int8) error {
+ ret := _m.ctrl.Call(_m, "WriteByte", ctx, _param0)
ret0, _ := ret[0].(error)
return ret0
}
-func (_mr *_MockTProtocolRecorder) WriteByte(arg0 interface{}) *gomock.Call {
- return _mr.mock.ctrl.RecordCall(_mr.mock, "WriteByte", arg0)
+func (_mr *_MockTProtocolRecorder) WriteByte(ctx context.Context, arg0 interface{}) *gomock.Call {
+ return _mr.mock.ctrl.RecordCall(_mr.mock, "WriteByte", ctx, arg0)
}
-func (_m *MockTProtocol) WriteDouble(_param0 float64) error {
- ret := _m.ctrl.Call(_m, "WriteDouble", _param0)
+func (_m *MockTProtocol) WriteDouble(ctx context.Context, _param0 float64) error {
+ ret := _m.ctrl.Call(_m, "WriteDouble", ctx, _param0)
ret0, _ := ret[0].(error)
return ret0
}
-func (_mr *_MockTProtocolRecorder) WriteDouble(arg0 interface{}) *gomock.Call {
- return _mr.mock.ctrl.RecordCall(_mr.mock, "WriteDouble", arg0)
+func (_mr *_MockTProtocolRecorder) WriteDouble(ctx context.Context, arg0 interface{}) *gomock.Call {
+ return _mr.mock.ctrl.RecordCall(_mr.mock, "WriteDouble", ctx, arg0)
}
-func (_m *MockTProtocol) WriteFieldBegin(_param0 string, _param1 thrift.TType, _param2 int16) error {
- ret := _m.ctrl.Call(_m, "WriteFieldBegin", _param0, _param1, _param2)
+func (_m *MockTProtocol) WriteFieldBegin(ctx context.Context, _param0 string, _param1 thrift.TType, _param2 int16) error {
+ ret := _m.ctrl.Call(_m, "WriteFieldBegin", ctx, _param0, _param1, _param2)
ret0, _ := ret[0].(error)
return ret0
}
-func (_mr *_MockTProtocolRecorder) WriteFieldBegin(arg0, arg1, arg2 interface{}) *gomock.Call {
- return _mr.mock.ctrl.RecordCall(_mr.mock, "WriteFieldBegin", arg0, arg1, arg2)
+func (_mr *_MockTProtocolRecorder) WriteFieldBegin(ctx context.Context, arg0, arg1, arg2 interface{}) *gomock.Call {
+ return _mr.mock.ctrl.RecordCall(_mr.mock, "WriteFieldBegin", ctx, arg0, arg1, arg2)
}
-func (_m *MockTProtocol) WriteFieldEnd() error {
- ret := _m.ctrl.Call(_m, "WriteFieldEnd")
+func (_m *MockTProtocol) WriteFieldEnd(ctx context.Context) error {
+ ret := _m.ctrl.Call(_m, "WriteFieldEnd", ctx)
ret0, _ := ret[0].(error)
return ret0
}
-func (_mr *_MockTProtocolRecorder) WriteFieldEnd() *gomock.Call {
- return _mr.mock.ctrl.RecordCall(_mr.mock, "WriteFieldEnd")
+func (_mr *_MockTProtocolRecorder) WriteFieldEnd(ctx context.Context) *gomock.Call {
+ return _mr.mock.ctrl.RecordCall(_mr.mock, "WriteFieldEnd", ctx)
}
-func (_m *MockTProtocol) WriteFieldStop() error {
- ret := _m.ctrl.Call(_m, "WriteFieldStop")
+func (_m *MockTProtocol) WriteFieldStop(ctx context.Context) error {
+ ret := _m.ctrl.Call(_m, "WriteFieldStop", ctx)
ret0, _ := ret[0].(error)
return ret0
}
-func (_mr *_MockTProtocolRecorder) WriteFieldStop() *gomock.Call {
- return _mr.mock.ctrl.RecordCall(_mr.mock, "WriteFieldStop")
+func (_mr *_MockTProtocolRecorder) WriteFieldStop(ctx context.Context) *gomock.Call {
+ return _mr.mock.ctrl.RecordCall(_mr.mock, "WriteFieldStop", ctx)
}
-func (_m *MockTProtocol) WriteI16(_param0 int16) error {
- ret := _m.ctrl.Call(_m, "WriteI16", _param0)
+func (_m *MockTProtocol) WriteI16(ctx context.Context, _param0 int16) error {
+ ret := _m.ctrl.Call(_m, "WriteI16", ctx, _param0)
ret0, _ := ret[0].(error)
return ret0
}
-func (_mr *_MockTProtocolRecorder) WriteI16(arg0 interface{}) *gomock.Call {
- return _mr.mock.ctrl.RecordCall(_mr.mock, "WriteI16", arg0)
+func (_mr *_MockTProtocolRecorder) WriteI16(ctx context.Context, arg0 interface{}) *gomock.Call {
+ return _mr.mock.ctrl.RecordCall(_mr.mock, "WriteI16", ctx, arg0)
}
-func (_m *MockTProtocol) WriteI32(_param0 int32) error {
- ret := _m.ctrl.Call(_m, "WriteI32", _param0)
+func (_m *MockTProtocol) WriteI32(ctx context.Context, _param0 int32) error {
+ ret := _m.ctrl.Call(_m, "WriteI32", ctx, _param0)
ret0, _ := ret[0].(error)
return ret0
}
-func (_mr *_MockTProtocolRecorder) WriteI32(arg0 interface{}) *gomock.Call {
- return _mr.mock.ctrl.RecordCall(_mr.mock, "WriteI32", arg0)
+func (_mr *_MockTProtocolRecorder) WriteI32(ctx context.Context, arg0 interface{}) *gomock.Call {
+ return _mr.mock.ctrl.RecordCall(_mr.mock, "WriteI32", ctx, arg0)
}
-func (_m *MockTProtocol) WriteI64(_param0 int64) error {
- ret := _m.ctrl.Call(_m, "WriteI64", _param0)
+func (_m *MockTProtocol) WriteI64(ctx context.Context, _param0 int64) error {
+ ret := _m.ctrl.Call(_m, "WriteI64", ctx, _param0)
ret0, _ := ret[0].(error)
return ret0
}
-func (_mr *_MockTProtocolRecorder) WriteI64(arg0 interface{}) *gomock.Call {
- return _mr.mock.ctrl.RecordCall(_mr.mock, "WriteI64", arg0)
+func (_mr *_MockTProtocolRecorder) WriteI64(ctx context.Context, arg0 interface{}) *gomock.Call {
+ return _mr.mock.ctrl.RecordCall(_mr.mock, "WriteI64", ctx, arg0)
}
-func (_m *MockTProtocol) WriteListBegin(_param0 thrift.TType, _param1 int) error {
- ret := _m.ctrl.Call(_m, "WriteListBegin", _param0, _param1)
+func (_m *MockTProtocol) WriteListBegin(ctx context.Context, _param0 thrift.TType, _param1 int) error {
+ ret := _m.ctrl.Call(_m, "WriteListBegin", ctx, _param0, _param1)
ret0, _ := ret[0].(error)
return ret0
}
-func (_mr *_MockTProtocolRecorder) WriteListBegin(arg0, arg1 interface{}) *gomock.Call {
- return _mr.mock.ctrl.RecordCall(_mr.mock, "WriteListBegin", arg0, arg1)
+func (_mr *_MockTProtocolRecorder) WriteListBegin(ctx context.Context, arg0, arg1 interface{}) *gomock.Call {
+ return _mr.mock.ctrl.RecordCall(_mr.mock, "WriteListBegin", ctx, arg0, arg1)
}
-func (_m *MockTProtocol) WriteListEnd() error {
- ret := _m.ctrl.Call(_m, "WriteListEnd")
+func (_m *MockTProtocol) WriteListEnd(ctx context.Context) error {
+ ret := _m.ctrl.Call(_m, "WriteListEnd", ctx)
ret0, _ := ret[0].(error)
return ret0
}
-func (_mr *_MockTProtocolRecorder) WriteListEnd() *gomock.Call {
- return _mr.mock.ctrl.RecordCall(_mr.mock, "WriteListEnd")
+func (_mr *_MockTProtocolRecorder) WriteListEnd(ctx context.Context) *gomock.Call {
+ return _mr.mock.ctrl.RecordCall(_mr.mock, "WriteListEnd", ctx)
}
-func (_m *MockTProtocol) WriteMapBegin(_param0 thrift.TType, _param1 thrift.TType, _param2 int) error {
- ret := _m.ctrl.Call(_m, "WriteMapBegin", _param0, _param1, _param2)
+func (_m *MockTProtocol) WriteMapBegin(ctx context.Context, _param0 thrift.TType, _param1 thrift.TType, _param2 int) error {
+ ret := _m.ctrl.Call(_m, "WriteMapBegin", ctx, _param0, _param1, _param2)
ret0, _ := ret[0].(error)
return ret0
}
-func (_mr *_MockTProtocolRecorder) WriteMapBegin(arg0, arg1, arg2 interface{}) *gomock.Call {
- return _mr.mock.ctrl.RecordCall(_mr.mock, "WriteMapBegin", arg0, arg1, arg2)
+func (_mr *_MockTProtocolRecorder) WriteMapBegin(ctx context.Context, arg0, arg1, arg2 interface{}) *gomock.Call {
+ return _mr.mock.ctrl.RecordCall(_mr.mock, "WriteMapBegin", ctx, arg0, arg1, arg2)
}
-func (_m *MockTProtocol) WriteMapEnd() error {
- ret := _m.ctrl.Call(_m, "WriteMapEnd")
+func (_m *MockTProtocol) WriteMapEnd(ctx context.Context) error {
+ ret := _m.ctrl.Call(_m, "WriteMapEnd", ctx)
ret0, _ := ret[0].(error)
return ret0
}
-func (_mr *_MockTProtocolRecorder) WriteMapEnd() *gomock.Call {
- return _mr.mock.ctrl.RecordCall(_mr.mock, "WriteMapEnd")
+func (_mr *_MockTProtocolRecorder) WriteMapEnd(ctx context.Context) *gomock.Call {
+ return _mr.mock.ctrl.RecordCall(_mr.mock, "WriteMapEnd", ctx)
}
-func (_m *MockTProtocol) WriteMessageBegin(_param0 string, _param1 thrift.TMessageType, _param2 int32) error {
- ret := _m.ctrl.Call(_m, "WriteMessageBegin", _param0, _param1, _param2)
+func (_m *MockTProtocol) WriteMessageBegin(ctx context.Context, _param0 string, _param1 thrift.TMessageType, _param2 int32) error {
+ ret := _m.ctrl.Call(_m, "WriteMessageBegin", ctx, _param0, _param1, _param2)
ret0, _ := ret[0].(error)
return ret0
}
-func (_mr *_MockTProtocolRecorder) WriteMessageBegin(arg0, arg1, arg2 interface{}) *gomock.Call {
- return _mr.mock.ctrl.RecordCall(_mr.mock, "WriteMessageBegin", arg0, arg1, arg2)
+func (_mr *_MockTProtocolRecorder) WriteMessageBegin(ctx context.Context, arg0, arg1, arg2 interface{}) *gomock.Call {
+ return _mr.mock.ctrl.RecordCall(_mr.mock, "WriteMessageBegin", ctx, arg0, arg1, arg2)
}
-func (_m *MockTProtocol) WriteMessageEnd() error {
- ret := _m.ctrl.Call(_m, "WriteMessageEnd")
+func (_m *MockTProtocol) WriteMessageEnd(ctx context.Context) error {
+ ret := _m.ctrl.Call(_m, "WriteMessageEnd", ctx)
ret0, _ := ret[0].(error)
return ret0
}
-func (_mr *_MockTProtocolRecorder) WriteMessageEnd() *gomock.Call {
- return _mr.mock.ctrl.RecordCall(_mr.mock, "WriteMessageEnd")
+func (_mr *_MockTProtocolRecorder) WriteMessageEnd(ctx context.Context) *gomock.Call {
+ return _mr.mock.ctrl.RecordCall(_mr.mock, "WriteMessageEnd", ctx)
}
-func (_m *MockTProtocol) WriteSetBegin(_param0 thrift.TType, _param1 int) error {
- ret := _m.ctrl.Call(_m, "WriteSetBegin", _param0, _param1)
+func (_m *MockTProtocol) WriteSetBegin(ctx context.Context, _param0 thrift.TType, _param1 int) error {
+ ret := _m.ctrl.Call(_m, "WriteSetBegin", ctx, _param0, _param1)
ret0, _ := ret[0].(error)
return ret0
}
-func (_mr *_MockTProtocolRecorder) WriteSetBegin(arg0, arg1 interface{}) *gomock.Call {
- return _mr.mock.ctrl.RecordCall(_mr.mock, "WriteSetBegin", arg0, arg1)
+func (_mr *_MockTProtocolRecorder) WriteSetBegin(ctx context.Context, arg0, arg1 interface{}) *gomock.Call {
+ return _mr.mock.ctrl.RecordCall(_mr.mock, "WriteSetBegin", ctx, arg0, arg1)
}
-func (_m *MockTProtocol) WriteSetEnd() error {
- ret := _m.ctrl.Call(_m, "WriteSetEnd")
+func (_m *MockTProtocol) WriteSetEnd(ctx context.Context) error {
+ ret := _m.ctrl.Call(_m, "WriteSetEnd", ctx)
ret0, _ := ret[0].(error)
return ret0
}
-func (_mr *_MockTProtocolRecorder) WriteSetEnd() *gomock.Call {
- return _mr.mock.ctrl.RecordCall(_mr.mock, "WriteSetEnd")
+func (_mr *_MockTProtocolRecorder) WriteSetEnd(ctx context.Context) *gomock.Call {
+ return _mr.mock.ctrl.RecordCall(_mr.mock, "WriteSetEnd", ctx)
}
-func (_m *MockTProtocol) WriteString(_param0 string) error {
- ret := _m.ctrl.Call(_m, "WriteString", _param0)
+func (_m *MockTProtocol) WriteString(ctx context.Context, _param0 string) error {
+ ret := _m.ctrl.Call(_m, "WriteString", ctx, _param0)
ret0, _ := ret[0].(error)
return ret0
}
-func (_mr *_MockTProtocolRecorder) WriteString(arg0 interface{}) *gomock.Call {
- return _mr.mock.ctrl.RecordCall(_mr.mock, "WriteString", arg0)
+func (_mr *_MockTProtocolRecorder) WriteString(ctx context.Context, arg0 interface{}) *gomock.Call {
+ return _mr.mock.ctrl.RecordCall(_mr.mock, "WriteString", ctx, arg0)
}
-func (_m *MockTProtocol) WriteStructBegin(_param0 string) error {
- ret := _m.ctrl.Call(_m, "WriteStructBegin", _param0)
+func (_m *MockTProtocol) WriteStructBegin(ctx context.Context, _param0 string) error {
+ ret := _m.ctrl.Call(_m, "WriteStructBegin", ctx, _param0)
ret0, _ := ret[0].(error)
return ret0
}
-func (_mr *_MockTProtocolRecorder) WriteStructBegin(arg0 interface{}) *gomock.Call {
- return _mr.mock.ctrl.RecordCall(_mr.mock, "WriteStructBegin", arg0)
+func (_mr *_MockTProtocolRecorder) WriteStructBegin(ctx context.Context, arg0 interface{}) *gomock.Call {
+ return _mr.mock.ctrl.RecordCall(_mr.mock, "WriteStructBegin", ctx, arg0)
}
-func (_m *MockTProtocol) WriteStructEnd() error {
- ret := _m.ctrl.Call(_m, "WriteStructEnd")
+func (_m *MockTProtocol) WriteStructEnd(ctx context.Context) error {
+ ret := _m.ctrl.Call(_m, "WriteStructEnd", ctx)
ret0, _ := ret[0].(error)
return ret0
}
-func (_mr *_MockTProtocolRecorder) WriteStructEnd() *gomock.Call {
- return _mr.mock.ctrl.RecordCall(_mr.mock, "WriteStructEnd")
+func (_mr *_MockTProtocolRecorder) WriteStructEnd(ctx context.Context) *gomock.Call {
+ return _mr.mock.ctrl.RecordCall(_mr.mock, "WriteStructEnd", ctx)
}
diff --git a/lib/go/test/tests/required_fields_test.go b/lib/go/test/tests/required_fields_test.go
index 3fa414ad8..06e8560e5 100644
--- a/lib/go/test/tests/required_fields_test.go
+++ b/lib/go/test/tests/required_fields_test.go
@@ -37,7 +37,7 @@ func TestRequiredField_SucecssWhenSet(t *testing.T) {
}
d := thrift.NewTDeserializer()
- err = d.Read(&requiredfieldtest.RequiredField{}, sourceData)
+ err = d.Read(context.Background(), &requiredfieldtest.RequiredField{}, sourceData)
if err != nil {
t.Fatalf("Did not expect an error when trying to deserialize the requiredfieldtest.RequiredField: %v", err)
}
@@ -53,7 +53,7 @@ func TestRequiredField_ErrorWhenMissing(t *testing.T) {
// attempt to deserialize into a different type (which should fail)
d := thrift.NewTDeserializer()
- err = d.Read(&requiredfieldtest.RequiredField{}, sourceData)
+ err = d.Read(context.Background(), &requiredfieldtest.RequiredField{}, sourceData)
if err == nil {
t.Fatal("Expected an error when trying to deserialize an object which is missing a required field")
}
@@ -66,12 +66,12 @@ func TestStructReadRequiredFields(t *testing.T) {
// None of required fields are set
gomock.InOrder(
- protocol.EXPECT().ReadStructBegin().Return("StructC", nil),
- protocol.EXPECT().ReadFieldBegin().Return("_", thrift.TType(thrift.STOP), int16(1), nil),
- protocol.EXPECT().ReadStructEnd().Return(nil),
+ protocol.EXPECT().ReadStructBegin(context.Background()).Return("StructC", nil),
+ protocol.EXPECT().ReadFieldBegin(context.Background()).Return("_", thrift.TType(thrift.STOP), int16(1), nil),
+ protocol.EXPECT().ReadStructEnd(context.Background()).Return(nil),
)
- err := testStruct.Read(protocol)
+ err := testStruct.Read(context.Background(), protocol)
mockCtrl.Finish()
mockCtrl = gomock.NewController(t)
if err == nil {
@@ -87,15 +87,15 @@ func TestStructReadRequiredFields(t *testing.T) {
// One of the required fields is set
gomock.InOrder(
- protocol.EXPECT().ReadStructBegin().Return("StructC", nil),
- protocol.EXPECT().ReadFieldBegin().Return("I", thrift.TType(thrift.I32), int16(2), nil),
- protocol.EXPECT().ReadI32().Return(int32(1), nil),
- protocol.EXPECT().ReadFieldEnd().Return(nil),
- protocol.EXPECT().ReadFieldBegin().Return("_", thrift.TType(thrift.STOP), int16(1), nil),
- protocol.EXPECT().ReadStructEnd().Return(nil),
+ protocol.EXPECT().ReadStructBegin(context.Background()).Return("StructC", nil),
+ protocol.EXPECT().ReadFieldBegin(context.Background()).Return("I", thrift.TType(thrift.I32), int16(2), nil),
+ protocol.EXPECT().ReadI32(context.Background()).Return(int32(1), nil),
+ protocol.EXPECT().ReadFieldEnd(context.Background()).Return(nil),
+ protocol.EXPECT().ReadFieldBegin(context.Background()).Return("_", thrift.TType(thrift.STOP), int16(1), nil),
+ protocol.EXPECT().ReadStructEnd(context.Background()).Return(nil),
)
- err = testStruct.Read(protocol)
+ err = testStruct.Read(context.Background(), protocol)
mockCtrl.Finish()
mockCtrl = gomock.NewController(t)
if err == nil {
@@ -111,18 +111,18 @@ func TestStructReadRequiredFields(t *testing.T) {
// Both of the required fields are set
gomock.InOrder(
- protocol.EXPECT().ReadStructBegin().Return("StructC", nil),
- protocol.EXPECT().ReadFieldBegin().Return("i", thrift.TType(thrift.I32), int16(2), nil),
- protocol.EXPECT().ReadI32().Return(int32(1), nil),
- protocol.EXPECT().ReadFieldEnd().Return(nil),
- protocol.EXPECT().ReadFieldBegin().Return("s2", thrift.TType(thrift.STRING), int16(4), nil),
- protocol.EXPECT().ReadString().Return("test", nil),
- protocol.EXPECT().ReadFieldEnd().Return(nil),
- protocol.EXPECT().ReadFieldBegin().Return("_", thrift.TType(thrift.STOP), int16(1), nil),
- protocol.EXPECT().ReadStructEnd().Return(nil),
+ protocol.EXPECT().ReadStructBegin(context.Background()).Return("StructC", nil),
+ protocol.EXPECT().ReadFieldBegin(context.Background()).Return("i", thrift.TType(thrift.I32), int16(2), nil),
+ protocol.EXPECT().ReadI32(context.Background()).Return(int32(1), nil),
+ protocol.EXPECT().ReadFieldEnd(context.Background()).Return(nil),
+ protocol.EXPECT().ReadFieldBegin(context.Background()).Return("s2", thrift.TType(thrift.STRING), int16(4), nil),
+ protocol.EXPECT().ReadString(context.Background()).Return("test", nil),
+ protocol.EXPECT().ReadFieldEnd(context.Background()).Return(nil),
+ protocol.EXPECT().ReadFieldBegin(context.Background()).Return("_", thrift.TType(thrift.STOP), int16(1), nil),
+ protocol.EXPECT().ReadStructEnd(context.Background()).Return(nil),
)
- err = testStruct.Read(protocol)
+ err = testStruct.Read(context.Background(), protocol)
mockCtrl.Finish()
if err != nil {
t.Fatal("Expected read to succeed")
diff --git a/lib/go/thrift/application_exception.go b/lib/go/thrift/application_exception.go
index 0023c57cf..6de37ee73 100644
--- a/lib/go/thrift/application_exception.go
+++ b/lib/go/thrift/application_exception.go
@@ -19,6 +19,10 @@
package thrift
+import (
+ "context"
+)
+
const (
UNKNOWN_APPLICATION_EXCEPTION = 0
UNKNOWN_METHOD = 1
@@ -51,8 +55,8 @@ var defaultApplicationExceptionMessage = map[int32]string{
type TApplicationException interface {
TException
TypeId() int32
- Read(iprot TProtocol) error
- Write(oprot TProtocol) error
+ Read(ctx context.Context, iprot TProtocol) error
+ Write(ctx context.Context, oprot TProtocol) error
}
type tApplicationException struct {
@@ -75,9 +79,9 @@ func (p *tApplicationException) TypeId() int32 {
return p.type_
}
-func (p *tApplicationException) Read(iprot TProtocol) error {
+func (p *tApplicationException) Read(ctx context.Context, iprot TProtocol) error {
// TODO: this should really be generated by the compiler
- _, err := iprot.ReadStructBegin()
+ _, err := iprot.ReadStructBegin(ctx)
if err != nil {
return err
}
@@ -86,7 +90,7 @@ func (p *tApplicationException) Read(iprot TProtocol) error {
type_ := int32(UNKNOWN_APPLICATION_EXCEPTION)
for {
- _, ttype, id, err := iprot.ReadFieldBegin()
+ _, ttype, id, err := iprot.ReadFieldBegin(ctx)
if err != nil {
return err
}
@@ -96,34 +100,34 @@ func (p *tApplicationException) Read(iprot TProtocol) error {
switch id {
case 1:
if ttype == STRING {
- if message, err = iprot.ReadString(); err != nil {
+ if message, err = iprot.ReadString(ctx); err != nil {
return err
}
} else {
- if err = SkipDefaultDepth(iprot, ttype); err != nil {
+ if err = SkipDefaultDepth(ctx, iprot, ttype); err != nil {
return err
}
}
case 2:
if ttype == I32 {
- if type_, err = iprot.ReadI32(); err != nil {
+ if type_, err = iprot.ReadI32(ctx); err != nil {
return err
}
} else {
- if err = SkipDefaultDepth(iprot, ttype); err != nil {
+ if err = SkipDefaultDepth(ctx, iprot, ttype); err != nil {
return err
}
}
default:
- if err = SkipDefaultDepth(iprot, ttype); err != nil {
+ if err = SkipDefaultDepth(ctx, iprot, ttype); err != nil {
return err
}
}
- if err = iprot.ReadFieldEnd(); err != nil {
+ if err = iprot.ReadFieldEnd(ctx); err != nil {
return err
}
}
- if err := iprot.ReadStructEnd(); err != nil {
+ if err := iprot.ReadStructEnd(ctx); err != nil {
return err
}
@@ -133,38 +137,38 @@ func (p *tApplicationException) Read(iprot TProtocol) error {
return nil
}
-func (p *tApplicationException) Write(oprot TProtocol) (err error) {
- err = oprot.WriteStructBegin("TApplicationException")
+func (p *tApplicationException) Write(ctx context.Context, oprot TProtocol) (err error) {
+ err = oprot.WriteStructBegin(ctx, "TApplicationException")
if len(p.Error()) > 0 {
- err = oprot.WriteFieldBegin("message", STRING, 1)
+ err = oprot.WriteFieldBegin(ctx, "message", STRING, 1)
if err != nil {
return
}
- err = oprot.WriteString(p.Error())
+ err = oprot.WriteString(ctx, p.Error())
if err != nil {
return
}
- err = oprot.WriteFieldEnd()
+ err = oprot.WriteFieldEnd(ctx)
if err != nil {
return
}
}
- err = oprot.WriteFieldBegin("type", I32, 2)
+ err = oprot.WriteFieldBegin(ctx, "type", I32, 2)
if err != nil {
return
}
- err = oprot.WriteI32(p.type_)
+ err = oprot.WriteI32(ctx, p.type_)
if err != nil {
return
}
- err = oprot.WriteFieldEnd()
+ err = oprot.WriteFieldEnd(ctx)
if err != nil {
return
}
- err = oprot.WriteFieldStop()
+ err = oprot.WriteFieldStop(ctx)
if err != nil {
return
}
- err = oprot.WriteStructEnd()
+ err = oprot.WriteStructEnd(ctx)
return
}
diff --git a/lib/go/thrift/binary_protocol.go b/lib/go/thrift/binary_protocol.go
index 93ae898cf..c87d23a1b 100644
--- a/lib/go/thrift/binary_protocol.go
+++ b/lib/go/thrift/binary_protocol.go
@@ -72,146 +72,146 @@ func (p *TBinaryProtocolFactory) GetProtocol(t TTransport) TProtocol {
* Writing Methods
*/
-func (p *TBinaryProtocol) WriteMessageBegin(name string, typeId TMessageType, seqId int32) error {
+func (p *TBinaryProtocol) WriteMessageBegin(ctx context.Context, name string, typeId TMessageType, seqId int32) error {
if p.strictWrite {
version := uint32(VERSION_1) | uint32(typeId)
- e := p.WriteI32(int32(version))
+ e := p.WriteI32(ctx, int32(version))
if e != nil {
return e
}
- e = p.WriteString(name)
+ e = p.WriteString(ctx, name)
if e != nil {
return e
}
- e = p.WriteI32(seqId)
+ e = p.WriteI32(ctx, seqId)
return e
} else {
- e := p.WriteString(name)
+ e := p.WriteString(ctx, name)
if e != nil {
return e
}
- e = p.WriteByte(int8(typeId))
+ e = p.WriteByte(ctx, int8(typeId))
if e != nil {
return e
}
- e = p.WriteI32(seqId)
+ e = p.WriteI32(ctx, seqId)
return e
}
return nil
}
-func (p *TBinaryProtocol) WriteMessageEnd() error {
+func (p *TBinaryProtocol) WriteMessageEnd(ctx context.Context) error {
return nil
}
-func (p *TBinaryProtocol) WriteStructBegin(name string) error {
+func (p *TBinaryProtocol) WriteStructBegin(ctx context.Context, name string) error {
return nil
}
-func (p *TBinaryProtocol) WriteStructEnd() error {
+func (p *TBinaryProtocol) WriteStructEnd(ctx context.Context) error {
return nil
}
-func (p *TBinaryProtocol) WriteFieldBegin(name string, typeId TType, id int16) error {
- e := p.WriteByte(int8(typeId))
+func (p *TBinaryProtocol) WriteFieldBegin(ctx context.Context, name string, typeId TType, id int16) error {
+ e := p.WriteByte(ctx, int8(typeId))
if e != nil {
return e
}
- e = p.WriteI16(id)
+ e = p.WriteI16(ctx, id)
return e
}
-func (p *TBinaryProtocol) WriteFieldEnd() error {
+func (p *TBinaryProtocol) WriteFieldEnd(ctx context.Context) error {
return nil
}
-func (p *TBinaryProtocol) WriteFieldStop() error {
- e := p.WriteByte(STOP)
+func (p *TBinaryProtocol) WriteFieldStop(ctx context.Context) error {
+ e := p.WriteByte(ctx, STOP)
return e
}
-func (p *TBinaryProtocol) WriteMapBegin(keyType TType, valueType TType, size int) error {
- e := p.WriteByte(int8(keyType))
+func (p *TBinaryProtocol) WriteMapBegin(ctx context.Context, keyType TType, valueType TType, size int) error {
+ e := p.WriteByte(ctx, int8(keyType))
if e != nil {
return e
}
- e = p.WriteByte(int8(valueType))
+ e = p.WriteByte(ctx, int8(valueType))
if e != nil {
return e
}
- e = p.WriteI32(int32(size))
+ e = p.WriteI32(ctx, int32(size))
return e
}
-func (p *TBinaryProtocol) WriteMapEnd() error {
+func (p *TBinaryProtocol) WriteMapEnd(ctx context.Context) error {
return nil
}
-func (p *TBinaryProtocol) WriteListBegin(elemType TType, size int) error {
- e := p.WriteByte(int8(elemType))
+func (p *TBinaryProtocol) WriteListBegin(ctx context.Context, elemType TType, size int) error {
+ e := p.WriteByte(ctx, int8(elemType))
if e != nil {
return e
}
- e = p.WriteI32(int32(size))
+ e = p.WriteI32(ctx, int32(size))
return e
}
-func (p *TBinaryProtocol) WriteListEnd() error {
+func (p *TBinaryProtocol) WriteListEnd(ctx context.Context) error {
return nil
}
-func (p *TBinaryProtocol) WriteSetBegin(elemType TType, size int) error {
- e := p.WriteByte(int8(elemType))
+func (p *TBinaryProtocol) WriteSetBegin(ctx context.Context, elemType TType, size int) error {
+ e := p.WriteByte(ctx, int8(elemType))
if e != nil {
return e
}
- e = p.WriteI32(int32(size))
+ e = p.WriteI32(ctx, int32(size))
return e
}
-func (p *TBinaryProtocol) WriteSetEnd() error {
+func (p *TBinaryProtocol) WriteSetEnd(ctx context.Context) error {
return nil
}
-func (p *TBinaryProtocol) WriteBool(value bool) error {
+func (p *TBinaryProtocol) WriteBool(ctx context.Context, value bool) error {
if value {
- return p.WriteByte(1)
+ return p.WriteByte(ctx, 1)
}
- return p.WriteByte(0)
+ return p.WriteByte(ctx, 0)
}
-func (p *TBinaryProtocol) WriteByte(value int8) error {
+func (p *TBinaryProtocol) WriteByte(ctx context.Context, value int8) error {
e := p.trans.WriteByte(byte(value))
return NewTProtocolException(e)
}
-func (p *TBinaryProtocol) WriteI16(value int16) error {
+func (p *TBinaryProtocol) WriteI16(ctx context.Context, value int16) error {
v := p.buffer[0:2]
binary.BigEndian.PutUint16(v, uint16(value))
_, e := p.trans.Write(v)
return NewTProtocolException(e)
}
-func (p *TBinaryProtocol) WriteI32(value int32) error {
+func (p *TBinaryProtocol) WriteI32(ctx context.Context, value int32) error {
v := p.buffer[0:4]
binary.BigEndian.PutUint32(v, uint32(value))
_, e := p.trans.Write(v)
return NewTProtocolException(e)
}
-func (p *TBinaryProtocol) WriteI64(value int64) error {
+func (p *TBinaryProtocol) WriteI64(ctx context.Context, value int64) error {
v := p.buffer[0:8]
binary.BigEndian.PutUint64(v, uint64(value))
_, err := p.trans.Write(v)
return NewTProtocolException(err)
}
-func (p *TBinaryProtocol) WriteDouble(value float64) error {
- return p.WriteI64(int64(math.Float64bits(value)))
+func (p *TBinaryProtocol) WriteDouble(ctx context.Context, value float64) error {
+ return p.WriteI64(ctx, int64(math.Float64bits(value)))
}
-func (p *TBinaryProtocol) WriteString(value string) error {
- e := p.WriteI32(int32(len(value)))
+func (p *TBinaryProtocol) WriteString(ctx context.Context, value string) error {
+ e := p.WriteI32(ctx, int32(len(value)))
if e != nil {
return e
}
@@ -219,8 +219,8 @@ func (p *TBinaryProtocol) WriteString(value string) error {
return NewTProtocolException(err)
}
-func (p *TBinaryProtocol) WriteBinary(value []byte) error {
- e := p.WriteI32(int32(len(value)))
+func (p *TBinaryProtocol) WriteBinary(ctx context.Context, value []byte) error {
+ e := p.WriteI32(ctx, int32(len(value)))
if e != nil {
return e
}
@@ -232,8 +232,8 @@ func (p *TBinaryProtocol) WriteBinary(value []byte) error {
* Reading methods
*/
-func (p *TBinaryProtocol) ReadMessageBegin() (name string, typeId TMessageType, seqId int32, err error) {
- size, e := p.ReadI32()
+func (p *TBinaryProtocol) ReadMessageBegin(ctx context.Context) (name string, typeId TMessageType, seqId int32, err error) {
+ size, e := p.ReadI32(ctx)
if e != nil {
return "", typeId, 0, NewTProtocolException(e)
}
@@ -243,11 +243,11 @@ func (p *TBinaryProtocol) ReadMessageBegin() (name string, typeId TMessageType,
if version != VERSION_1 {
return name, typeId, seqId, NewTProtocolExceptionWithType(BAD_VERSION, fmt.Errorf("Bad version in ReadMessageBegin"))
}
- name, e = p.ReadString()
+ name, e = p.ReadString(ctx)
if e != nil {
return name, typeId, seqId, NewTProtocolException(e)
}
- seqId, e = p.ReadI32()
+ seqId, e = p.ReadI32(ctx)
if e != nil {
return name, typeId, seqId, NewTProtocolException(e)
}
@@ -260,62 +260,62 @@ func (p *TBinaryProtocol) ReadMessageBegin() (name string, typeId TMessageType,
if e2 != nil {
return name, typeId, seqId, e2
}
- b, e3 := p.ReadByte()
+ b, e3 := p.ReadByte(ctx)
if e3 != nil {
return name, typeId, seqId, e3
}
typeId = TMessageType(b)
- seqId, e4 := p.ReadI32()
+ seqId, e4 := p.ReadI32(ctx)
if e4 != nil {
return name, typeId, seqId, e4
}
return name, typeId, seqId, nil
}
-func (p *TBinaryProtocol) ReadMessageEnd() error {
+func (p *TBinaryProtocol) ReadMessageEnd(ctx context.Context) error {
return nil
}
-func (p *TBinaryProtocol) ReadStructBegin() (name string, err error) {
+func (p *TBinaryProtocol) ReadStructBegin(ctx context.Context) (name string, err error) {
return
}
-func (p *TBinaryProtocol) ReadStructEnd() error {
+func (p *TBinaryProtocol) ReadStructEnd(ctx context.Context) error {
return nil
}
-func (p *TBinaryProtocol) ReadFieldBegin() (name string, typeId TType, seqId int16, err error) {
- t, err := p.ReadByte()
+func (p *TBinaryProtocol) ReadFieldBegin(ctx context.Context) (name string, typeId TType, seqId int16, err error) {
+ t, err := p.ReadByte(ctx)
typeId = TType(t)
if err != nil {
return name, typeId, seqId, err
}
if t != STOP {
- seqId, err = p.ReadI16()
+ seqId, err = p.ReadI16(ctx)
}
return name, typeId, seqId, err
}
-func (p *TBinaryProtocol) ReadFieldEnd() error {
+func (p *TBinaryProtocol) ReadFieldEnd(ctx context.Context) error {
return nil
}
var invalidDataLength = NewTProtocolExceptionWithType(INVALID_DATA, errors.New("Invalid data length"))
-func (p *TBinaryProtocol) ReadMapBegin() (kType, vType TType, size int, err error) {
- k, e := p.ReadByte()
+func (p *TBinaryProtocol) ReadMapBegin(ctx context.Context) (kType, vType TType, size int, err error) {
+ k, e := p.ReadByte(ctx)
if e != nil {
err = NewTProtocolException(e)
return
}
kType = TType(k)
- v, e := p.ReadByte()
+ v, e := p.ReadByte(ctx)
if e != nil {
err = NewTProtocolException(e)
return
}
vType = TType(v)
- size32, e := p.ReadI32()
+ size32, e := p.ReadI32(ctx)
if e != nil {
err = NewTProtocolException(e)
return
@@ -328,18 +328,18 @@ func (p *TBinaryProtocol) ReadMapBegin() (kType, vType TType, size int, err erro
return kType, vType, size, nil
}
-func (p *TBinaryProtocol) ReadMapEnd() error {
+func (p *TBinaryProtocol) ReadMapEnd(ctx context.Context) error {
return nil
}
-func (p *TBinaryProtocol) ReadListBegin() (elemType TType, size int, err error) {
- b, e := p.ReadByte()
+func (p *TBinaryProtocol) ReadListBegin(ctx context.Context) (elemType TType, size int, err error) {
+ b, e := p.ReadByte(ctx)
if e != nil {
err = NewTProtocolException(e)
return
}
elemType = TType(b)
- size32, e := p.ReadI32()
+ size32, e := p.ReadI32(ctx)
if e != nil {
err = NewTProtocolException(e)
return
@@ -353,18 +353,18 @@ func (p *TBinaryProtocol) ReadListBegin() (elemType TType, size int, err error)
return
}
-func (p *TBinaryProtocol) ReadListEnd() error {
+func (p *TBinaryProtocol) ReadListEnd(ctx context.Context) error {
return nil
}
-func (p *TBinaryProtocol) ReadSetBegin() (elemType TType, size int, err error) {
- b, e := p.ReadByte()
+func (p *TBinaryProtocol) ReadSetBegin(ctx context.Context) (elemType TType, size int, err error) {
+ b, e := p.ReadByte(ctx)
if e != nil {
err = NewTProtocolException(e)
return
}
elemType = TType(b)
- size32, e := p.ReadI32()
+ size32, e := p.ReadI32(ctx)
if e != nil {
err = NewTProtocolException(e)
return
@@ -377,12 +377,12 @@ func (p *TBinaryProtocol) ReadSetBegin() (elemType TType, size int, err error) {
return elemType, size, nil
}
-func (p *TBinaryProtocol) ReadSetEnd() error {
+func (p *TBinaryProtocol) ReadSetEnd(ctx context.Context) error {
return nil
}
-func (p *TBinaryProtocol) ReadBool() (bool, error) {
- b, e := p.ReadByte()
+func (p *TBinaryProtocol) ReadBool(ctx context.Context) (bool, error) {
+ b, e := p.ReadByte(ctx)
v := true
if b != 1 {
v = false
@@ -390,41 +390,41 @@ func (p *TBinaryProtocol) ReadBool() (bool, error) {
return v, e
}
-func (p *TBinaryProtocol) ReadByte() (int8, error) {
+func (p *TBinaryProtocol) ReadByte(ctx context.Context) (int8, error) {
v, err := p.trans.ReadByte()
return int8(v), err
}
-func (p *TBinaryProtocol) ReadI16() (value int16, err error) {
+func (p *TBinaryProtocol) ReadI16(ctx context.Context) (value int16, err error) {
buf := p.buffer[0:2]
- err = p.readAll(buf)
+ err = p.readAll(ctx, buf)
value = int16(binary.BigEndian.Uint16(buf))
return value, err
}
-func (p *TBinaryProtocol) ReadI32() (value int32, err error) {
+func (p *TBinaryProtocol) ReadI32(ctx context.Context) (value int32, err error) {
buf := p.buffer[0:4]
- err = p.readAll(buf)
+ err = p.readAll(ctx, buf)
value = int32(binary.BigEndian.Uint32(buf))
return value, err
}
-func (p *TBinaryProtocol) ReadI64() (value int64, err error) {
+func (p *TBinaryProtocol) ReadI64(ctx context.Context) (value int64, err error) {
buf := p.buffer[0:8]
- err = p.readAll(buf)
+ err = p.readAll(ctx, buf)
value = int64(binary.BigEndian.Uint64(buf))
return value, err
}
-func (p *TBinaryProtocol) ReadDouble() (value float64, err error) {
+func (p *TBinaryProtocol) ReadDouble(ctx context.Context) (value float64, err error) {
buf := p.buffer[0:8]
- err = p.readAll(buf)
+ err = p.readAll(ctx, buf)
value = math.Float64frombits(binary.BigEndian.Uint64(buf))
return value, err
}
-func (p *TBinaryProtocol) ReadString() (value string, err error) {
- size, e := p.ReadI32()
+func (p *TBinaryProtocol) ReadString(ctx context.Context) (value string, err error) {
+ size, e := p.ReadI32(ctx)
if e != nil {
return "", e
}
@@ -436,8 +436,8 @@ func (p *TBinaryProtocol) ReadString() (value string, err error) {
return p.readStringBody(size)
}
-func (p *TBinaryProtocol) ReadBinary() ([]byte, error) {
- size, e := p.ReadI32()
+func (p *TBinaryProtocol) ReadBinary(ctx context.Context) ([]byte, error) {
+ size, e := p.ReadI32(ctx)
if e != nil {
return nil, e
}
@@ -455,16 +455,27 @@ func (p *TBinaryProtocol) Flush(ctx context.Context) (err error) {
return NewTProtocolException(p.trans.Flush(ctx))
}
-func (p *TBinaryProtocol) Skip(fieldType TType) (err error) {
- return SkipDefaultDepth(p, fieldType)
+func (p *TBinaryProtocol) Skip(ctx context.Context, fieldType TType) (err error) {
+ return SkipDefaultDepth(ctx, p, fieldType)
}
func (p *TBinaryProtocol) Transport() TTransport {
return p.origTransport
}
-func (p *TBinaryProtocol) readAll(buf []byte) error {
- _, err := io.ReadFull(p.trans, buf)
+func (p *TBinaryProtocol) readAll(ctx context.Context, buf []byte) (err error) {
+ var read int
+ _, deadlineSet := ctx.Deadline()
+ for {
+ read, err = io.ReadFull(p.trans, buf)
+ if deadlineSet && read == 0 && isTimeoutError(err) && ctx.Err() == nil {
+ // This is I/O timeout without anything read,
+ // and we still have time left, keep retrying.
+ continue
+ }
+ // For anything else, don't retry
+ break
+ }
return NewTProtocolException(err)
}
diff --git a/lib/go/thrift/client.go b/lib/go/thrift/client.go
index b073a952d..1c5705d55 100644
--- a/lib/go/thrift/client.go
+++ b/lib/go/thrift/client.go
@@ -34,20 +34,20 @@ func (p *TStandardClient) Send(ctx context.Context, oprot TProtocol, seqId int32
}
}
- if err := oprot.WriteMessageBegin(method, CALL, seqId); err != nil {
+ if err := oprot.WriteMessageBegin(ctx, method, CALL, seqId); err != nil {
return err
}
- if err := args.Write(oprot); err != nil {
+ if err := args.Write(ctx, oprot); err != nil {
return err
}
- if err := oprot.WriteMessageEnd(); err != nil {
+ if err := oprot.WriteMessageEnd(ctx); err != nil {
return err
}
return oprot.Flush(ctx)
}
-func (p *TStandardClient) Recv(iprot TProtocol, seqId int32, method string, result TStruct) error {
- rMethod, rTypeId, rSeqId, err := iprot.ReadMessageBegin()
+func (p *TStandardClient) Recv(ctx context.Context, iprot TProtocol, seqId int32, method string, result TStruct) error {
+ rMethod, rTypeId, rSeqId, err := iprot.ReadMessageBegin(ctx)
if err != nil {
return err
}
@@ -58,11 +58,11 @@ func (p *TStandardClient) Recv(iprot TProtocol, seqId int32, method string, resu
return NewTApplicationException(BAD_SEQUENCE_ID, fmt.Sprintf("%s: out of order sequence response", method))
} else if rTypeId == EXCEPTION {
var exception tApplicationException
- if err := exception.Read(iprot); err != nil {
+ if err := exception.Read(ctx, iprot); err != nil {
return err
}
- if err := iprot.ReadMessageEnd(); err != nil {
+ if err := iprot.ReadMessageEnd(ctx); err != nil {
return err
}
@@ -71,11 +71,11 @@ func (p *TStandardClient) Recv(iprot TProtocol, seqId int32, method string, resu
return NewTApplicationException(INVALID_MESSAGE_TYPE_EXCEPTION, fmt.Sprintf("%s: invalid message type", method))
}
- if err := result.Read(iprot); err != nil {
+ if err := result.Read(ctx, iprot); err != nil {
return err
}
- return iprot.ReadMessageEnd()
+ return iprot.ReadMessageEnd(ctx)
}
func (p *TStandardClient) Call(ctx context.Context, method string, args, result TStruct) error {
@@ -91,5 +91,5 @@ func (p *TStandardClient) Call(ctx context.Context, method string, args, result
return nil
}
- return p.Recv(p.iprot, seqId, method, result)
+ return p.Recv(ctx, p.iprot, seqId, method, result)
}
diff --git a/lib/go/thrift/compact_protocol.go b/lib/go/thrift/compact_protocol.go
index 1900d50c3..468935781 100644
--- a/lib/go/thrift/compact_protocol.go
+++ b/lib/go/thrift/compact_protocol.go
@@ -125,7 +125,7 @@ func NewTCompactProtocol(trans TTransport) *TCompactProtocol {
// Write a message header to the wire. Compact Protocol messages contain the
// protocol version so we can migrate forwards in the future if need be.
-func (p *TCompactProtocol) WriteMessageBegin(name string, typeId TMessageType, seqid int32) error {
+func (p *TCompactProtocol) WriteMessageBegin(ctx context.Context, name string, typeId TMessageType, seqid int32) error {
err := p.writeByteDirect(COMPACT_PROTOCOL_ID)
if err != nil {
return NewTProtocolException(err)
@@ -138,17 +138,17 @@ func (p *TCompactProtocol) WriteMessageBegin(name string, typeId TMessageType, s
if err != nil {
return NewTProtocolException(err)
}
- e := p.WriteString(name)
+ e := p.WriteString(ctx, name)
return e
}
-func (p *TCompactProtocol) WriteMessageEnd() error { return nil }
+func (p *TCompactProtocol) WriteMessageEnd(ctx context.Context) error { return nil }
// Write a struct begin. This doesn't actually put anything on the wire. We
// use it as an opportunity to put special placeholder markers on the field
// stack so we can get the field id deltas correct.
-func (p *TCompactProtocol) WriteStructBegin(name string) error {
+func (p *TCompactProtocol) WriteStructBegin(ctx context.Context, name string) error {
p.lastField = append(p.lastField, p.lastFieldId)
p.lastFieldId = 0
return nil
@@ -157,26 +157,26 @@ func (p *TCompactProtocol) WriteStructBegin(name string) error {
// Write a struct end. This doesn't actually put anything on the wire. We use
// this as an opportunity to pop the last field from the current struct off
// of the field stack.
-func (p *TCompactProtocol) WriteStructEnd() error {
+func (p *TCompactProtocol) WriteStructEnd(ctx context.Context) error {
p.lastFieldId = p.lastField[len(p.lastField)-1]
p.lastField = p.lastField[:len(p.lastField)-1]
return nil
}
-func (p *TCompactProtocol) WriteFieldBegin(name string, typeId TType, id int16) error {
+func (p *TCompactProtocol) WriteFieldBegin(ctx context.Context, name string, typeId TType, id int16) error {
if typeId == BOOL {
// we want to possibly include the value, so we'll wait.
p.booleanFieldName, p.booleanFieldId, p.booleanFieldPending = name, id, true
return nil
}
- _, err := p.writeFieldBeginInternal(name, typeId, id, 0xFF)
+ _, err := p.writeFieldBeginInternal(ctx, name, typeId, id, 0xFF)
return NewTProtocolException(err)
}
// The workhorse of writeFieldBegin. It has the option of doing a
// 'type override' of the type header. This is used specifically in the
// boolean field case.
-func (p *TCompactProtocol) writeFieldBeginInternal(name string, typeId TType, id int16, typeOverride byte) (int, error) {
+func (p *TCompactProtocol) writeFieldBeginInternal(ctx context.Context, name string, typeId TType, id int16, typeOverride byte) (int, error) {
// short lastField = lastField_.pop();
// if there's a type override, use that.
@@ -201,7 +201,7 @@ func (p *TCompactProtocol) writeFieldBeginInternal(name string, typeId TType, id
if err != nil {
return 0, err
}
- err = p.WriteI16(id)
+ err = p.WriteI16(ctx, id)
written = 1 + 2
if err != nil {
return 0, err
@@ -213,14 +213,14 @@ func (p *TCompactProtocol) writeFieldBeginInternal(name string, typeId TType, id
return written, nil
}
-func (p *TCompactProtocol) WriteFieldEnd() error { return nil }
+func (p *TCompactProtocol) WriteFieldEnd(ctx context.Context) error { return nil }
-func (p *TCompactProtocol) WriteFieldStop() error {
+func (p *TCompactProtocol) WriteFieldStop(ctx context.Context) error {
err := p.writeByteDirect(STOP)
return NewTProtocolException(err)
}
-func (p *TCompactProtocol) WriteMapBegin(keyType TType, valueType TType, size int) error {
+func (p *TCompactProtocol) WriteMapBegin(ctx context.Context, keyType TType, valueType TType, size int) error {
if size == 0 {
err := p.writeByteDirect(0)
return NewTProtocolException(err)
@@ -233,32 +233,32 @@ func (p *TCompactProtocol) WriteMapBegin(keyType TType, valueType TType, size in
return NewTProtocolException(err)
}
-func (p *TCompactProtocol) WriteMapEnd() error { return nil }
+func (p *TCompactProtocol) WriteMapEnd(ctx context.Context) error { return nil }
// Write a list header.
-func (p *TCompactProtocol) WriteListBegin(elemType TType, size int) error {
+func (p *TCompactProtocol) WriteListBegin(ctx context.Context, elemType TType, size int) error {
_, err := p.writeCollectionBegin(elemType, size)
return NewTProtocolException(err)
}
-func (p *TCompactProtocol) WriteListEnd() error { return nil }
+func (p *TCompactProtocol) WriteListEnd(ctx context.Context) error { return nil }
// Write a set header.
-func (p *TCompactProtocol) WriteSetBegin(elemType TType, size int) error {
+func (p *TCompactProtocol) WriteSetBegin(ctx context.Context, elemType TType, size int) error {
_, err := p.writeCollectionBegin(elemType, size)
return NewTProtocolException(err)
}
-func (p *TCompactProtocol) WriteSetEnd() error { return nil }
+func (p *TCompactProtocol) WriteSetEnd(ctx context.Context) error { return nil }
-func (p *TCompactProtocol) WriteBool(value bool) error {
+func (p *TCompactProtocol) WriteBool(ctx context.Context, value bool) error {
v := byte(COMPACT_BOOLEAN_FALSE)
if value {
v = byte(COMPACT_BOOLEAN_TRUE)
}
if p.booleanFieldPending {
// we haven't written the field header yet
- _, err := p.writeFieldBeginInternal(p.booleanFieldName, BOOL, p.booleanFieldId, v)
+ _, err := p.writeFieldBeginInternal(ctx, p.booleanFieldName, BOOL, p.booleanFieldId, v)
p.booleanFieldPending = false
return NewTProtocolException(err)
}
@@ -268,31 +268,31 @@ func (p *TCompactProtocol) WriteBool(value bool) error {
}
// Write a byte. Nothing to see here!
-func (p *TCompactProtocol) WriteByte(value int8) error {
+func (p *TCompactProtocol) WriteByte(ctx context.Context, value int8) error {
err := p.writeByteDirect(byte(value))
return NewTProtocolException(err)
}
// Write an I16 as a zigzag varint.
-func (p *TCompactProtocol) WriteI16(value int16) error {
+func (p *TCompactProtocol) WriteI16(ctx context.Context, value int16) error {
_, err := p.writeVarint32(p.int32ToZigzag(int32(value)))
return NewTProtocolException(err)
}
// Write an i32 as a zigzag varint.
-func (p *TCompactProtocol) WriteI32(value int32) error {
+func (p *TCompactProtocol) WriteI32(ctx context.Context, value int32) error {
_, err := p.writeVarint32(p.int32ToZigzag(value))
return NewTProtocolException(err)
}
// Write an i64 as a zigzag varint.
-func (p *TCompactProtocol) WriteI64(value int64) error {
+func (p *TCompactProtocol) WriteI64(ctx context.Context, value int64) error {
_, err := p.writeVarint64(p.int64ToZigzag(value))
return NewTProtocolException(err)
}
// Write a double to the wire as 8 bytes.
-func (p *TCompactProtocol) WriteDouble(value float64) error {
+func (p *TCompactProtocol) WriteDouble(ctx context.Context, value float64) error {
buf := p.buffer[0:8]
binary.LittleEndian.PutUint64(buf, math.Float64bits(value))
_, err := p.trans.Write(buf)
@@ -300,7 +300,7 @@ func (p *TCompactProtocol) WriteDouble(value float64) error {
}
// Write a string to the wire with a varint size preceding.
-func (p *TCompactProtocol) WriteString(value string) error {
+func (p *TCompactProtocol) WriteString(ctx context.Context, value string) error {
_, e := p.writeVarint32(int32(len(value)))
if e != nil {
return NewTProtocolException(e)
@@ -312,7 +312,7 @@ func (p *TCompactProtocol) WriteString(value string) error {
}
// Write a byte array, using a varint for the size.
-func (p *TCompactProtocol) WriteBinary(bin []byte) error {
+func (p *TCompactProtocol) WriteBinary(ctx context.Context, bin []byte) error {
_, e := p.writeVarint32(int32(len(bin)))
if e != nil {
return NewTProtocolException(e)
@@ -329,9 +329,20 @@ func (p *TCompactProtocol) WriteBinary(bin []byte) error {
//
// Read a message header.
-func (p *TCompactProtocol) ReadMessageBegin() (name string, typeId TMessageType, seqId int32, err error) {
+func (p *TCompactProtocol) ReadMessageBegin(ctx context.Context) (name string, typeId TMessageType, seqId int32, err error) {
+ var protocolId byte
- protocolId, err := p.readByteDirect()
+ _, deadlineSet := ctx.Deadline()
+ for {
+ protocolId, err = p.readByteDirect()
+ if deadlineSet && isTimeoutError(err) && ctx.Err() == nil {
+ // keep retrying I/O timeout errors since we still have
+ // time left
+ continue
+ }
+ // For anything else, don't retry
+ break
+ }
if err != nil {
return
}
@@ -358,15 +369,15 @@ func (p *TCompactProtocol) ReadMessageBegin() (name string, typeId TMessageType,
err = NewTProtocolException(e)
return
}
- name, err = p.ReadString()
+ name, err = p.ReadString(ctx)
return
}
-func (p *TCompactProtocol) ReadMessageEnd() error { return nil }
+func (p *TCompactProtocol) ReadMessageEnd(ctx context.Context) error { return nil }
// Read a struct begin. There's nothing on the wire for this, but it is our
// opportunity to push a new struct begin marker onto the field stack.
-func (p *TCompactProtocol) ReadStructBegin() (name string, err error) {
+func (p *TCompactProtocol) ReadStructBegin(ctx context.Context) (name string, err error) {
p.lastField = append(p.lastField, p.lastFieldId)
p.lastFieldId = 0
return
@@ -374,7 +385,7 @@ func (p *TCompactProtocol) ReadStructBegin() (name string, err error) {
// Doesn't actually consume any wire data, just removes the last field for
// this struct from the field stack.
-func (p *TCompactProtocol) ReadStructEnd() error {
+func (p *TCompactProtocol) ReadStructEnd(ctx context.Context) error {
// consume the last field we read off the wire.
p.lastFieldId = p.lastField[len(p.lastField)-1]
p.lastField = p.lastField[:len(p.lastField)-1]
@@ -382,7 +393,7 @@ func (p *TCompactProtocol) ReadStructEnd() error {
}
// Read a field header off the wire.
-func (p *TCompactProtocol) ReadFieldBegin() (name string, typeId TType, id int16, err error) {
+func (p *TCompactProtocol) ReadFieldBegin(ctx context.Context) (name string, typeId TType, id int16, err error) {
t, err := p.readByteDirect()
if err != nil {
return
@@ -397,7 +408,7 @@ func (p *TCompactProtocol) ReadFieldBegin() (name string, typeId TType, id int16
modifier := int16((t & 0xf0) >> 4)
if modifier == 0 {
// not a delta. look ahead for the zigzag varint field id.
- id, err = p.ReadI16()
+ id, err = p.ReadI16(ctx)
if err != nil {
return
}
@@ -423,12 +434,12 @@ func (p *TCompactProtocol) ReadFieldBegin() (name string, typeId TType, id int16
return
}
-func (p *TCompactProtocol) ReadFieldEnd() error { return nil }
+func (p *TCompactProtocol) ReadFieldEnd(ctx context.Context) error { return nil }
// Read a map header off the wire. If the size is zero, skip reading the key
// and value type. This means that 0-length maps will yield TMaps without the
// "correct" types.
-func (p *TCompactProtocol) ReadMapBegin() (keyType TType, valueType TType, size int, err error) {
+func (p *TCompactProtocol) ReadMapBegin(ctx context.Context) (keyType TType, valueType TType, size int, err error) {
size32, e := p.readVarint32()
if e != nil {
err = NewTProtocolException(e)
@@ -452,13 +463,13 @@ func (p *TCompactProtocol) ReadMapBegin() (keyType TType, valueType TType, size
return
}
-func (p *TCompactProtocol) ReadMapEnd() error { return nil }
+func (p *TCompactProtocol) ReadMapEnd(ctx context.Context) error { return nil }
// Read a list header off the wire. If the list size is 0-14, the size will
// be packed into the element type header. If it's a longer list, the 4 MSB
// of the element type header will be 0xF, and a varint will follow with the
// true size.
-func (p *TCompactProtocol) ReadListBegin() (elemType TType, size int, err error) {
+func (p *TCompactProtocol) ReadListBegin(ctx context.Context) (elemType TType, size int, err error) {
size_and_type, err := p.readByteDirect()
if err != nil {
return
@@ -484,22 +495,22 @@ func (p *TCompactProtocol) ReadListBegin() (elemType TType, size int, err error)
return
}
-func (p *TCompactProtocol) ReadListEnd() error { return nil }
+func (p *TCompactProtocol) ReadListEnd(ctx context.Context) error { return nil }
// Read a set header off the wire. If the set size is 0-14, the size will
// be packed into the element type header. If it's a longer set, the 4 MSB
// of the element type header will be 0xF, and a varint will follow with the
// true size.
-func (p *TCompactProtocol) ReadSetBegin() (elemType TType, size int, err error) {
- return p.ReadListBegin()
+func (p *TCompactProtocol) ReadSetBegin(ctx context.Context) (elemType TType, size int, err error) {
+ return p.ReadListBegin(ctx)
}
-func (p *TCompactProtocol) ReadSetEnd() error { return nil }
+func (p *TCompactProtocol) ReadSetEnd(ctx context.Context) error { return nil }
// Read a boolean off the wire. If this is a boolean field, the value should
// already have been read during readFieldBegin, so we'll just consume the
// pre-stored value. Otherwise, read a byte.
-func (p *TCompactProtocol) ReadBool() (value bool, err error) {
+func (p *TCompactProtocol) ReadBool(ctx context.Context) (value bool, err error) {
if p.boolValueIsNotNull {
p.boolValueIsNotNull = false
return p.boolValue, nil
@@ -509,7 +520,7 @@ func (p *TCompactProtocol) ReadBool() (value bool, err error) {
}
// Read a single byte off the wire. Nothing interesting here.
-func (p *TCompactProtocol) ReadByte() (int8, error) {
+func (p *TCompactProtocol) ReadByte(ctx context.Context) (int8, error) {
v, err := p.readByteDirect()
if err != nil {
return 0, NewTProtocolException(err)
@@ -518,13 +529,13 @@ func (p *TCompactProtocol) ReadByte() (int8, error) {
}
// Read an i16 from the wire as a zigzag varint.
-func (p *TCompactProtocol) ReadI16() (value int16, err error) {
- v, err := p.ReadI32()
+func (p *TCompactProtocol) ReadI16(ctx context.Context) (value int16, err error) {
+ v, err := p.ReadI32(ctx)
return int16(v), err
}
// Read an i32 from the wire as a zigzag varint.
-func (p *TCompactProtocol) ReadI32() (value int32, err error) {
+func (p *TCompactProtocol) ReadI32(ctx context.Context) (value int32, err error) {
v, e := p.readVarint32()
if e != nil {
return 0, NewTProtocolException(e)
@@ -534,7 +545,7 @@ func (p *TCompactProtocol) ReadI32() (value int32, err error) {
}
// Read an i64 from the wire as a zigzag varint.
-func (p *TCompactProtocol) ReadI64() (value int64, err error) {
+func (p *TCompactProtocol) ReadI64(ctx context.Context) (value int64, err error) {
v, e := p.readVarint64()
if e != nil {
return 0, NewTProtocolException(e)
@@ -544,7 +555,7 @@ func (p *TCompactProtocol) ReadI64() (value int64, err error) {
}
// No magic here - just read a double off the wire.
-func (p *TCompactProtocol) ReadDouble() (value float64, err error) {
+func (p *TCompactProtocol) ReadDouble(ctx context.Context) (value float64, err error) {
longBits := p.buffer[0:8]
_, e := io.ReadFull(p.trans, longBits)
if e != nil {
@@ -554,7 +565,7 @@ func (p *TCompactProtocol) ReadDouble() (value float64, err error) {
}
// Reads a []byte (via readBinary), and then UTF-8 decodes it.
-func (p *TCompactProtocol) ReadString() (value string, err error) {
+func (p *TCompactProtocol) ReadString(ctx context.Context) (value string, err error) {
length, e := p.readVarint32()
if e != nil {
return "", NewTProtocolException(e)
@@ -577,7 +588,7 @@ func (p *TCompactProtocol) ReadString() (value string, err error) {
}
// Read a []byte from the wire.
-func (p *TCompactProtocol) ReadBinary() (value []byte, err error) {
+func (p *TCompactProtocol) ReadBinary(ctx context.Context) (value []byte, err error) {
length, e := p.readVarint32()
if e != nil {
return nil, NewTProtocolException(e)
@@ -598,8 +609,8 @@ func (p *TCompactProtocol) Flush(ctx context.Context) (err error) {
return NewTProtocolException(p.trans.Flush(ctx))
}
-func (p *TCompactProtocol) Skip(fieldType TType) (err error) {
- return SkipDefaultDepth(p, fieldType)
+func (p *TCompactProtocol) Skip(ctx context.Context, fieldType TType) (err error) {
+ return SkipDefaultDepth(ctx, p, fieldType)
}
func (p *TCompactProtocol) Transport() TTransport {
diff --git a/lib/go/thrift/debug_protocol.go b/lib/go/thrift/debug_protocol.go
index c33fba879..a0920031a 100644
--- a/lib/go/thrift/debug_protocol.go
+++ b/lib/go/thrift/debug_protocol.go
@@ -70,214 +70,214 @@ func (tdp *TDebugProtocol) logf(format string, v ...interface{}) {
fallbackLogger(tdp.Logger)(fmt.Sprintf(format, v...))
}
-func (tdp *TDebugProtocol) WriteMessageBegin(name string, typeId TMessageType, seqid int32) error {
- err := tdp.Delegate.WriteMessageBegin(name, typeId, seqid)
+func (tdp *TDebugProtocol) WriteMessageBegin(ctx context.Context, name string, typeId TMessageType, seqid int32) error {
+ err := tdp.Delegate.WriteMessageBegin(ctx, name, typeId, seqid)
tdp.logf("%sWriteMessageBegin(name=%#v, typeId=%#v, seqid=%#v) => %#v", tdp.LogPrefix, name, typeId, seqid, err)
return err
}
-func (tdp *TDebugProtocol) WriteMessageEnd() error {
- err := tdp.Delegate.WriteMessageEnd()
+func (tdp *TDebugProtocol) WriteMessageEnd(ctx context.Context) error {
+ err := tdp.Delegate.WriteMessageEnd(ctx)
tdp.logf("%sWriteMessageEnd() => %#v", tdp.LogPrefix, err)
return err
}
-func (tdp *TDebugProtocol) WriteStructBegin(name string) error {
- err := tdp.Delegate.WriteStructBegin(name)
+func (tdp *TDebugProtocol) WriteStructBegin(ctx context.Context, name string) error {
+ err := tdp.Delegate.WriteStructBegin(ctx, name)
tdp.logf("%sWriteStructBegin(name=%#v) => %#v", tdp.LogPrefix, name, err)
return err
}
-func (tdp *TDebugProtocol) WriteStructEnd() error {
- err := tdp.Delegate.WriteStructEnd()
+func (tdp *TDebugProtocol) WriteStructEnd(ctx context.Context) error {
+ err := tdp.Delegate.WriteStructEnd(ctx)
tdp.logf("%sWriteStructEnd() => %#v", tdp.LogPrefix, err)
return err
}
-func (tdp *TDebugProtocol) WriteFieldBegin(name string, typeId TType, id int16) error {
- err := tdp.Delegate.WriteFieldBegin(name, typeId, id)
+func (tdp *TDebugProtocol) WriteFieldBegin(ctx context.Context, name string, typeId TType, id int16) error {
+ err := tdp.Delegate.WriteFieldBegin(ctx, name, typeId, id)
tdp.logf("%sWriteFieldBegin(name=%#v, typeId=%#v, id%#v) => %#v", tdp.LogPrefix, name, typeId, id, err)
return err
}
-func (tdp *TDebugProtocol) WriteFieldEnd() error {
- err := tdp.Delegate.WriteFieldEnd()
+func (tdp *TDebugProtocol) WriteFieldEnd(ctx context.Context) error {
+ err := tdp.Delegate.WriteFieldEnd(ctx)
tdp.logf("%sWriteFieldEnd() => %#v", tdp.LogPrefix, err)
return err
}
-func (tdp *TDebugProtocol) WriteFieldStop() error {
- err := tdp.Delegate.WriteFieldStop()
+func (tdp *TDebugProtocol) WriteFieldStop(ctx context.Context) error {
+ err := tdp.Delegate.WriteFieldStop(ctx)
tdp.logf("%sWriteFieldStop() => %#v", tdp.LogPrefix, err)
return err
}
-func (tdp *TDebugProtocol) WriteMapBegin(keyType TType, valueType TType, size int) error {
- err := tdp.Delegate.WriteMapBegin(keyType, valueType, size)
+func (tdp *TDebugProtocol) WriteMapBegin(ctx context.Context, keyType TType, valueType TType, size int) error {
+ err := tdp.Delegate.WriteMapBegin(ctx, keyType, valueType, size)
tdp.logf("%sWriteMapBegin(keyType=%#v, valueType=%#v, size=%#v) => %#v", tdp.LogPrefix, keyType, valueType, size, err)
return err
}
-func (tdp *TDebugProtocol) WriteMapEnd() error {
- err := tdp.Delegate.WriteMapEnd()
+func (tdp *TDebugProtocol) WriteMapEnd(ctx context.Context) error {
+ err := tdp.Delegate.WriteMapEnd(ctx)
tdp.logf("%sWriteMapEnd() => %#v", tdp.LogPrefix, err)
return err
}
-func (tdp *TDebugProtocol) WriteListBegin(elemType TType, size int) error {
- err := tdp.Delegate.WriteListBegin(elemType, size)
+func (tdp *TDebugProtocol) WriteListBegin(ctx context.Context, elemType TType, size int) error {
+ err := tdp.Delegate.WriteListBegin(ctx, elemType, size)
tdp.logf("%sWriteListBegin(elemType=%#v, size=%#v) => %#v", tdp.LogPrefix, elemType, size, err)
return err
}
-func (tdp *TDebugProtocol) WriteListEnd() error {
- err := tdp.Delegate.WriteListEnd()
+func (tdp *TDebugProtocol) WriteListEnd(ctx context.Context) error {
+ err := tdp.Delegate.WriteListEnd(ctx)
tdp.logf("%sWriteListEnd() => %#v", tdp.LogPrefix, err)
return err
}
-func (tdp *TDebugProtocol) WriteSetBegin(elemType TType, size int) error {
- err := tdp.Delegate.WriteSetBegin(elemType, size)
+func (tdp *TDebugProtocol) WriteSetBegin(ctx context.Context, elemType TType, size int) error {
+ err := tdp.Delegate.WriteSetBegin(ctx, elemType, size)
tdp.logf("%sWriteSetBegin(elemType=%#v, size=%#v) => %#v", tdp.LogPrefix, elemType, size, err)
return err
}
-func (tdp *TDebugProtocol) WriteSetEnd() error {
- err := tdp.Delegate.WriteSetEnd()
+func (tdp *TDebugProtocol) WriteSetEnd(ctx context.Context) error {
+ err := tdp.Delegate.WriteSetEnd(ctx)
tdp.logf("%sWriteSetEnd() => %#v", tdp.LogPrefix, err)
return err
}
-func (tdp *TDebugProtocol) WriteBool(value bool) error {
- err := tdp.Delegate.WriteBool(value)
+func (tdp *TDebugProtocol) WriteBool(ctx context.Context, value bool) error {
+ err := tdp.Delegate.WriteBool(ctx, value)
tdp.logf("%sWriteBool(value=%#v) => %#v", tdp.LogPrefix, value, err)
return err
}
-func (tdp *TDebugProtocol) WriteByte(value int8) error {
- err := tdp.Delegate.WriteByte(value)
+func (tdp *TDebugProtocol) WriteByte(ctx context.Context, value int8) error {
+ err := tdp.Delegate.WriteByte(ctx, value)
tdp.logf("%sWriteByte(value=%#v) => %#v", tdp.LogPrefix, value, err)
return err
}
-func (tdp *TDebugProtocol) WriteI16(value int16) error {
- err := tdp.Delegate.WriteI16(value)
+func (tdp *TDebugProtocol) WriteI16(ctx context.Context, value int16) error {
+ err := tdp.Delegate.WriteI16(ctx, value)
tdp.logf("%sWriteI16(value=%#v) => %#v", tdp.LogPrefix, value, err)
return err
}
-func (tdp *TDebugProtocol) WriteI32(value int32) error {
- err := tdp.Delegate.WriteI32(value)
+func (tdp *TDebugProtocol) WriteI32(ctx context.Context, value int32) error {
+ err := tdp.Delegate.WriteI32(ctx, value)
tdp.logf("%sWriteI32(value=%#v) => %#v", tdp.LogPrefix, value, err)
return err
}
-func (tdp *TDebugProtocol) WriteI64(value int64) error {
- err := tdp.Delegate.WriteI64(value)
+func (tdp *TDebugProtocol) WriteI64(ctx context.Context, value int64) error {
+ err := tdp.Delegate.WriteI64(ctx, value)
tdp.logf("%sWriteI64(value=%#v) => %#v", tdp.LogPrefix, value, err)
return err
}
-func (tdp *TDebugProtocol) WriteDouble(value float64) error {
- err := tdp.Delegate.WriteDouble(value)
+func (tdp *TDebugProtocol) WriteDouble(ctx context.Context, value float64) error {
+ err := tdp.Delegate.WriteDouble(ctx, value)
tdp.logf("%sWriteDouble(value=%#v) => %#v", tdp.LogPrefix, value, err)
return err
}
-func (tdp *TDebugProtocol) WriteString(value string) error {
- err := tdp.Delegate.WriteString(value)
+func (tdp *TDebugProtocol) WriteString(ctx context.Context, value string) error {
+ err := tdp.Delegate.WriteString(ctx, value)
tdp.logf("%sWriteString(value=%#v) => %#v", tdp.LogPrefix, value, err)
return err
}
-func (tdp *TDebugProtocol) WriteBinary(value []byte) error {
- err := tdp.Delegate.WriteBinary(value)
+func (tdp *TDebugProtocol) WriteBinary(ctx context.Context, value []byte) error {
+ err := tdp.Delegate.WriteBinary(ctx, value)
tdp.logf("%sWriteBinary(value=%#v) => %#v", tdp.LogPrefix, value, err)
return err
}
-func (tdp *TDebugProtocol) ReadMessageBegin() (name string, typeId TMessageType, seqid int32, err error) {
- name, typeId, seqid, err = tdp.Delegate.ReadMessageBegin()
+func (tdp *TDebugProtocol) ReadMessageBegin(ctx context.Context) (name string, typeId TMessageType, seqid int32, err error) {
+ name, typeId, seqid, err = tdp.Delegate.ReadMessageBegin(ctx)
tdp.logf("%sReadMessageBegin() (name=%#v, typeId=%#v, seqid=%#v, err=%#v)", tdp.LogPrefix, name, typeId, seqid, err)
return
}
-func (tdp *TDebugProtocol) ReadMessageEnd() (err error) {
- err = tdp.Delegate.ReadMessageEnd()
+func (tdp *TDebugProtocol) ReadMessageEnd(ctx context.Context) (err error) {
+ err = tdp.Delegate.ReadMessageEnd(ctx)
tdp.logf("%sReadMessageEnd() err=%#v", tdp.LogPrefix, err)
return
}
-func (tdp *TDebugProtocol) ReadStructBegin() (name string, err error) {
- name, err = tdp.Delegate.ReadStructBegin()
+func (tdp *TDebugProtocol) ReadStructBegin(ctx context.Context) (name string, err error) {
+ name, err = tdp.Delegate.ReadStructBegin(ctx)
tdp.logf("%sReadStructBegin() (name%#v, err=%#v)", tdp.LogPrefix, name, err)
return
}
-func (tdp *TDebugProtocol) ReadStructEnd() (err error) {
- err = tdp.Delegate.ReadStructEnd()
+func (tdp *TDebugProtocol) ReadStructEnd(ctx context.Context) (err error) {
+ err = tdp.Delegate.ReadStructEnd(ctx)
tdp.logf("%sReadStructEnd() err=%#v", tdp.LogPrefix, err)
return
}
-func (tdp *TDebugProtocol) ReadFieldBegin() (name string, typeId TType, id int16, err error) {
- name, typeId, id, err = tdp.Delegate.ReadFieldBegin()
+func (tdp *TDebugProtocol) ReadFieldBegin(ctx context.Context) (name string, typeId TType, id int16, err error) {
+ name, typeId, id, err = tdp.Delegate.ReadFieldBegin(ctx)
tdp.logf("%sReadFieldBegin() (name=%#v, typeId=%#v, id=%#v, err=%#v)", tdp.LogPrefix, name, typeId, id, err)
return
}
-func (tdp *TDebugProtocol) ReadFieldEnd() (err error) {
- err = tdp.Delegate.ReadFieldEnd()
+func (tdp *TDebugProtocol) ReadFieldEnd(ctx context.Context) (err error) {
+ err = tdp.Delegate.ReadFieldEnd(ctx)
tdp.logf("%sReadFieldEnd() err=%#v", tdp.LogPrefix, err)
return
}
-func (tdp *TDebugProtocol) ReadMapBegin() (keyType TType, valueType TType, size int, err error) {
- keyType, valueType, size, err = tdp.Delegate.ReadMapBegin()
+func (tdp *TDebugProtocol) ReadMapBegin(ctx context.Context) (keyType TType, valueType TType, size int, err error) {
+ keyType, valueType, size, err = tdp.Delegate.ReadMapBegin(ctx)
tdp.logf("%sReadMapBegin() (keyType=%#v, valueType=%#v, size=%#v, err=%#v)", tdp.LogPrefix, keyType, valueType, size, err)
return
}
-func (tdp *TDebugProtocol) ReadMapEnd() (err error) {
- err = tdp.Delegate.ReadMapEnd()
+func (tdp *TDebugProtocol) ReadMapEnd(ctx context.Context) (err error) {
+ err = tdp.Delegate.ReadMapEnd(ctx)
tdp.logf("%sReadMapEnd() err=%#v", tdp.LogPrefix, err)
return
}
-func (tdp *TDebugProtocol) ReadListBegin() (elemType TType, size int, err error) {
- elemType, size, err = tdp.Delegate.ReadListBegin()
+func (tdp *TDebugProtocol) ReadListBegin(ctx context.Context) (elemType TType, size int, err error) {
+ elemType, size, err = tdp.Delegate.ReadListBegin(ctx)
tdp.logf("%sReadListBegin() (elemType=%#v, size=%#v, err=%#v)", tdp.LogPrefix, elemType, size, err)
return
}
-func (tdp *TDebugProtocol) ReadListEnd() (err error) {
- err = tdp.Delegate.ReadListEnd()
+func (tdp *TDebugProtocol) ReadListEnd(ctx context.Context) (err error) {
+ err = tdp.Delegate.ReadListEnd(ctx)
tdp.logf("%sReadListEnd() err=%#v", tdp.LogPrefix, err)
return
}
-func (tdp *TDebugProtocol) ReadSetBegin() (elemType TType, size int, err error) {
- elemType, size, err = tdp.Delegate.ReadSetBegin()
+func (tdp *TDebugProtocol) ReadSetBegin(ctx context.Context) (elemType TType, size int, err error) {
+ elemType, size, err = tdp.Delegate.ReadSetBegin(ctx)
tdp.logf("%sReadSetBegin() (elemType=%#v, size=%#v, err=%#v)", tdp.LogPrefix, elemType, size, err)
return
}
-func (tdp *TDebugProtocol) ReadSetEnd() (err error) {
- err = tdp.Delegate.ReadSetEnd()
+func (tdp *TDebugProtocol) ReadSetEnd(ctx context.Context) (err error) {
+ err = tdp.Delegate.ReadSetEnd(ctx)
tdp.logf("%sReadSetEnd() err=%#v", tdp.LogPrefix, err)
return
}
-func (tdp *TDebugProtocol) ReadBool() (value bool, err error) {
- value, err = tdp.Delegate.ReadBool()
+func (tdp *TDebugProtocol) ReadBool(ctx context.Context) (value bool, err error) {
+ value, err = tdp.Delegate.ReadBool(ctx)
tdp.logf("%sReadBool() (value=%#v, err=%#v)", tdp.LogPrefix, value, err)
return
}
-func (tdp *TDebugProtocol) ReadByte() (value int8, err error) {
- value, err = tdp.Delegate.ReadByte()
+func (tdp *TDebugProtocol) ReadByte(ctx context.Context) (value int8, err error) {
+ value, err = tdp.Delegate.ReadByte(ctx)
tdp.logf("%sReadByte() (value=%#v, err=%#v)", tdp.LogPrefix, value, err)
return
}
-func (tdp *TDebugProtocol) ReadI16() (value int16, err error) {
- value, err = tdp.Delegate.ReadI16()
+func (tdp *TDebugProtocol) ReadI16(ctx context.Context) (value int16, err error) {
+ value, err = tdp.Delegate.ReadI16(ctx)
tdp.logf("%sReadI16() (value=%#v, err=%#v)", tdp.LogPrefix, value, err)
return
}
-func (tdp *TDebugProtocol) ReadI32() (value int32, err error) {
- value, err = tdp.Delegate.ReadI32()
+func (tdp *TDebugProtocol) ReadI32(ctx context.Context) (value int32, err error) {
+ value, err = tdp.Delegate.ReadI32(ctx)
tdp.logf("%sReadI32() (value=%#v, err=%#v)", tdp.LogPrefix, value, err)
return
}
-func (tdp *TDebugProtocol) ReadI64() (value int64, err error) {
- value, err = tdp.Delegate.ReadI64()
+func (tdp *TDebugProtocol) ReadI64(ctx context.Context) (value int64, err error) {
+ value, err = tdp.Delegate.ReadI64(ctx)
tdp.logf("%sReadI64() (value=%#v, err=%#v)", tdp.LogPrefix, value, err)
return
}
-func (tdp *TDebugProtocol) ReadDouble() (value float64, err error) {
- value, err = tdp.Delegate.ReadDouble()
+func (tdp *TDebugProtocol) ReadDouble(ctx context.Context) (value float64, err error) {
+ value, err = tdp.Delegate.ReadDouble(ctx)
tdp.logf("%sReadDouble() (value=%#v, err=%#v)", tdp.LogPrefix, value, err)
return
}
-func (tdp *TDebugProtocol) ReadString() (value string, err error) {
- value, err = tdp.Delegate.ReadString()
+func (tdp *TDebugProtocol) ReadString(ctx context.Context) (value string, err error) {
+ value, err = tdp.Delegate.ReadString(ctx)
tdp.logf("%sReadString() (value=%#v, err=%#v)", tdp.LogPrefix, value, err)
return
}
-func (tdp *TDebugProtocol) ReadBinary() (value []byte, err error) {
- value, err = tdp.Delegate.ReadBinary()
+func (tdp *TDebugProtocol) ReadBinary(ctx context.Context) (value []byte, err error) {
+ value, err = tdp.Delegate.ReadBinary(ctx)
tdp.logf("%sReadBinary() (value=%#v, err=%#v)", tdp.LogPrefix, value, err)
return
}
-func (tdp *TDebugProtocol) Skip(fieldType TType) (err error) {
- err = tdp.Delegate.Skip(fieldType)
+func (tdp *TDebugProtocol) Skip(ctx context.Context, fieldType TType) (err error) {
+ err = tdp.Delegate.Skip(ctx, fieldType)
tdp.logf("%sSkip(fieldType=%#v) (err=%#v)", tdp.LogPrefix, fieldType, err)
return
}
diff --git a/lib/go/thrift/deserializer.go b/lib/go/thrift/deserializer.go
index 2ab82145e..e1203a868 100644
--- a/lib/go/thrift/deserializer.go
+++ b/lib/go/thrift/deserializer.go
@@ -20,6 +20,7 @@
package thrift
import (
+ "context"
"sync"
)
@@ -38,27 +39,27 @@ func NewTDeserializer() *TDeserializer {
protocol}
}
-func (t *TDeserializer) ReadString(msg TStruct, s string) (err error) {
+func (t *TDeserializer) ReadString(ctx context.Context, msg TStruct, s string) (err error) {
t.Transport.Reset()
err = nil
if _, err = t.Transport.Write([]byte(s)); err != nil {
return
}
- if err = msg.Read(t.Protocol); err != nil {
+ if err = msg.Read(ctx, t.Protocol); err != nil {
return
}
return
}
-func (t *TDeserializer) Read(msg TStruct, b []byte) (err error) {
+func (t *TDeserializer) Read(ctx context.Context, msg TStruct, b []byte) (err error) {
t.Transport.Reset()
err = nil
if _, err = t.Transport.Write(b); err != nil {
return
}
- if err = msg.Read(t.Protocol); err != nil {
+ if err = msg.Read(ctx, t.Protocol); err != nil {
return
}
return
@@ -85,14 +86,14 @@ func NewTDeserializerPool(f func() *TDeserializer) *TDeserializerPool {
}
}
-func (t *TDeserializerPool) ReadString(msg TStruct, s string) error {
+func (t *TDeserializerPool) ReadString(ctx context.Context, msg TStruct, s string) error {
d := t.pool.Get().(*TDeserializer)
defer t.pool.Put(d)
- return d.ReadString(msg, s)
+ return d.ReadString(ctx, msg, s)
}
-func (t *TDeserializerPool) Read(msg TStruct, b []byte) error {
+func (t *TDeserializerPool) Read(ctx context.Context, msg TStruct, b []byte) error {
d := t.pool.Get().(*TDeserializer)
defer t.pool.Put(d)
- return d.Read(msg, b)
+ return d.Read(ctx, msg, b)
}
diff --git a/lib/go/thrift/header_protocol.go b/lib/go/thrift/header_protocol.go
index 99deaf7dc..428b26148 100644
--- a/lib/go/thrift/header_protocol.go
+++ b/lib/go/thrift/header_protocol.go
@@ -95,106 +95,106 @@ func (p *THeaderProtocol) Flush(ctx context.Context) error {
return p.transport.Flush(ctx)
}
-func (p *THeaderProtocol) WriteMessageBegin(name string, typeID TMessageType, seqID int32) error {
+func (p *THeaderProtocol) WriteMessageBegin(ctx context.Context, name string, typeID TMessageType, seqID int32) error {
newProto, err := p.transport.Protocol().GetProtocol(p.transport)
if err != nil {
return err
}
p.protocol = newProto
p.transport.SequenceID = seqID
- return p.protocol.WriteMessageBegin(name, typeID, seqID)
+ return p.protocol.WriteMessageBegin(ctx, name, typeID, seqID)
}
-func (p *THeaderProtocol) WriteMessageEnd() error {
- if err := p.protocol.WriteMessageEnd(); err != nil {
+func (p *THeaderProtocol) WriteMessageEnd(ctx context.Context) error {
+ if err := p.protocol.WriteMessageEnd(ctx); err != nil {
return err
}
- return p.transport.Flush(context.Background())
+ return p.transport.Flush(ctx)
}
-func (p *THeaderProtocol) WriteStructBegin(name string) error {
- return p.protocol.WriteStructBegin(name)
+func (p *THeaderProtocol) WriteStructBegin(ctx context.Context, name string) error {
+ return p.protocol.WriteStructBegin(ctx, name)
}
-func (p *THeaderProtocol) WriteStructEnd() error {
- return p.protocol.WriteStructEnd()
+func (p *THeaderProtocol) WriteStructEnd(ctx context.Context) error {
+ return p.protocol.WriteStructEnd(ctx)
}
-func (p *THeaderProtocol) WriteFieldBegin(name string, typeID TType, id int16) error {
- return p.protocol.WriteFieldBegin(name, typeID, id)
+func (p *THeaderProtocol) WriteFieldBegin(ctx context.Context, name string, typeID TType, id int16) error {
+ return p.protocol.WriteFieldBegin(ctx, name, typeID, id)
}
-func (p *THeaderProtocol) WriteFieldEnd() error {
- return p.protocol.WriteFieldEnd()
+func (p *THeaderProtocol) WriteFieldEnd(ctx context.Context) error {
+ return p.protocol.WriteFieldEnd(ctx)
}
-func (p *THeaderProtocol) WriteFieldStop() error {
- return p.protocol.WriteFieldStop()
+func (p *THeaderProtocol) WriteFieldStop(ctx context.Context) error {
+ return p.protocol.WriteFieldStop(ctx)
}
-func (p *THeaderProtocol) WriteMapBegin(keyType TType, valueType TType, size int) error {
- return p.protocol.WriteMapBegin(keyType, valueType, size)
+func (p *THeaderProtocol) WriteMapBegin(ctx context.Context, keyType TType, valueType TType, size int) error {
+ return p.protocol.WriteMapBegin(ctx, keyType, valueType, size)
}
-func (p *THeaderProtocol) WriteMapEnd() error {
- return p.protocol.WriteMapEnd()
+func (p *THeaderProtocol) WriteMapEnd(ctx context.Context) error {
+ return p.protocol.WriteMapEnd(ctx)
}
-func (p *THeaderProtocol) WriteListBegin(elemType TType, size int) error {
- return p.protocol.WriteListBegin(elemType, size)
+func (p *THeaderProtocol) WriteListBegin(ctx context.Context, elemType TType, size int) error {
+ return p.protocol.WriteListBegin(ctx, elemType, size)
}
-func (p *THeaderProtocol) WriteListEnd() error {
- return p.protocol.WriteListEnd()
+func (p *THeaderProtocol) WriteListEnd(ctx context.Context) error {
+ return p.protocol.WriteListEnd(ctx)
}
-func (p *THeaderProtocol) WriteSetBegin(elemType TType, size int) error {
- return p.protocol.WriteSetBegin(elemType, size)
+func (p *THeaderProtocol) WriteSetBegin(ctx context.Context, elemType TType, size int) error {
+ return p.protocol.WriteSetBegin(ctx, elemType, size)
}
-func (p *THeaderProtocol) WriteSetEnd() error {
- return p.protocol.WriteSetEnd()
+func (p *THeaderProtocol) WriteSetEnd(ctx context.Context) error {
+ return p.protocol.WriteSetEnd(ctx)
}
-func (p *THeaderProtocol) WriteBool(value bool) error {
- return p.protocol.WriteBool(value)
+func (p *THeaderProtocol) WriteBool(ctx context.Context, value bool) error {
+ return p.protocol.WriteBool(ctx, value)
}
-func (p *THeaderProtocol) WriteByte(value int8) error {
- return p.protocol.WriteByte(value)
+func (p *THeaderProtocol) WriteByte(ctx context.Context, value int8) error {
+ return p.protocol.WriteByte(ctx, value)
}
-func (p *THeaderProtocol) WriteI16(value int16) error {
- return p.protocol.WriteI16(value)
+func (p *THeaderProtocol) WriteI16(ctx context.Context, value int16) error {
+ return p.protocol.WriteI16(ctx, value)
}
-func (p *THeaderProtocol) WriteI32(value int32) error {
- return p.protocol.WriteI32(value)
+func (p *THeaderProtocol) WriteI32(ctx context.Context, value int32) error {
+ return p.protocol.WriteI32(ctx, value)
}
-func (p *THeaderProtocol) WriteI64(value int64) error {
- return p.protocol.WriteI64(value)
+func (p *THeaderProtocol) WriteI64(ctx context.Context, value int64) error {
+ return p.protocol.WriteI64(ctx, value)
}
-func (p *THeaderProtocol) WriteDouble(value float64) error {
- return p.protocol.WriteDouble(value)
+func (p *THeaderProtocol) WriteDouble(ctx context.Context, value float64) error {
+ return p.protocol.WriteDouble(ctx, value)
}
-func (p *THeaderProtocol) WriteString(value string) error {
- return p.protocol.WriteString(value)
+func (p *THeaderProtocol) WriteString(ctx context.Context, value string) error {
+ return p.protocol.WriteString(ctx, value)
}
-func (p *THeaderProtocol) WriteBinary(value []byte) error {
- return p.protocol.WriteBinary(value)
+func (p *THeaderProtocol) WriteBinary(ctx context.Context, value []byte) error {
+ return p.protocol.WriteBinary(ctx, value)
}
// ReadFrame calls underlying THeaderTransport's ReadFrame function.
-func (p *THeaderProtocol) ReadFrame() error {
- return p.transport.ReadFrame()
+func (p *THeaderProtocol) ReadFrame(ctx context.Context) error {
+ return p.transport.ReadFrame(ctx)
}
-func (p *THeaderProtocol) ReadMessageBegin() (name string, typeID TMessageType, seqID int32, err error) {
- if err = p.transport.ReadFrame(); err != nil {
+func (p *THeaderProtocol) ReadMessageBegin(ctx context.Context) (name string, typeID TMessageType, seqID int32, err error) {
+ if err = p.transport.ReadFrame(ctx); err != nil {
return
}
@@ -205,103 +205,103 @@ func (p *THeaderProtocol) ReadMessageBegin() (name string, typeID TMessageType,
if !ok {
return
}
- if e := p.protocol.WriteMessageBegin("", EXCEPTION, seqID); e != nil {
+ if e := p.protocol.WriteMessageBegin(ctx, "", EXCEPTION, seqID); e != nil {
return
}
- if e := tAppExc.Write(p.protocol); e != nil {
+ if e := tAppExc.Write(ctx, p.protocol); e != nil {
return
}
- if e := p.protocol.WriteMessageEnd(); e != nil {
+ if e := p.protocol.WriteMessageEnd(ctx); e != nil {
return
}
- if e := p.transport.Flush(context.Background()); e != nil {
+ if e := p.transport.Flush(ctx); e != nil {
return
}
return
}
p.protocol = newProto
- return p.protocol.ReadMessageBegin()
+ return p.protocol.ReadMessageBegin(ctx)
}
-func (p *THeaderProtocol) ReadMessageEnd() error {
- return p.protocol.ReadMessageEnd()
+func (p *THeaderProtocol) ReadMessageEnd(ctx context.Context) error {
+ return p.protocol.ReadMessageEnd(ctx)
}
-func (p *THeaderProtocol) ReadStructBegin() (name string, err error) {
- return p.protocol.ReadStructBegin()
+func (p *THeaderProtocol) ReadStructBegin(ctx context.Context) (name string, err error) {
+ return p.protocol.ReadStructBegin(ctx)
}
-func (p *THeaderProtocol) ReadStructEnd() error {
- return p.protocol.ReadStructEnd()
+func (p *THeaderProtocol) ReadStructEnd(ctx context.Context) error {
+ return p.protocol.ReadStructEnd(ctx)
}
-func (p *THeaderProtocol) ReadFieldBegin() (name string, typeID TType, id int16, err error) {
- return p.protocol.ReadFieldBegin()
+func (p *THeaderProtocol) ReadFieldBegin(ctx context.Context) (name string, typeID TType, id int16, err error) {
+ return p.protocol.ReadFieldBegin(ctx)
}
-func (p *THeaderProtocol) ReadFieldEnd() error {
- return p.protocol.ReadFieldEnd()
+func (p *THeaderProtocol) ReadFieldEnd(ctx context.Context) error {
+ return p.protocol.ReadFieldEnd(ctx)
}
-func (p *THeaderProtocol) ReadMapBegin() (keyType TType, valueType TType, size int, err error) {
- return p.protocol.ReadMapBegin()
+func (p *THeaderProtocol) ReadMapBegin(ctx context.Context) (keyType TType, valueType TType, size int, err error) {
+ return p.protocol.ReadMapBegin(ctx)
}
-func (p *THeaderProtocol) ReadMapEnd() error {
- return p.protocol.ReadMapEnd()
+func (p *THeaderProtocol) ReadMapEnd(ctx context.Context) error {
+ return p.protocol.ReadMapEnd(ctx)
}
-func (p *THeaderProtocol) ReadListBegin() (elemType TType, size int, err error) {
- return p.protocol.ReadListBegin()
+func (p *THeaderProtocol) ReadListBegin(ctx context.Context) (elemType TType, size int, err error) {
+ return p.protocol.ReadListBegin(ctx)
}
-func (p *THeaderProtocol) ReadListEnd() error {
- return p.protocol.ReadListEnd()
+func (p *THeaderProtocol) ReadListEnd(ctx context.Context) error {
+ return p.protocol.ReadListEnd(ctx)
}
-func (p *THeaderProtocol) ReadSetBegin() (elemType TType, size int, err error) {
- return p.protocol.ReadSetBegin()
+func (p *THeaderProtocol) ReadSetBegin(ctx context.Context) (elemType TType, size int, err error) {
+ return p.protocol.ReadSetBegin(ctx)
}
-func (p *THeaderProtocol) ReadSetEnd() error {
- return p.protocol.ReadSetEnd()
+func (p *THeaderProtocol) ReadSetEnd(ctx context.Context) error {
+ return p.protocol.ReadSetEnd(ctx)
}
-func (p *THeaderProtocol) ReadBool() (value bool, err error) {
- return p.protocol.ReadBool()
+func (p *THeaderProtocol) ReadBool(ctx context.Context) (value bool, err error) {
+ return p.protocol.ReadBool(ctx)
}
-func (p *THeaderProtocol) ReadByte() (value int8, err error) {
- return p.protocol.ReadByte()
+func (p *THeaderProtocol) ReadByte(ctx context.Context) (value int8, err error) {
+ return p.protocol.ReadByte(ctx)
}
-func (p *THeaderProtocol) ReadI16() (value int16, err error) {
- return p.protocol.ReadI16()
+func (p *THeaderProtocol) ReadI16(ctx context.Context) (value int16, err error) {
+ return p.protocol.ReadI16(ctx)
}
-func (p *THeaderProtocol) ReadI32() (value int32, err error) {
- return p.protocol.ReadI32()
+func (p *THeaderProtocol) ReadI32(ctx context.Context) (value int32, err error) {
+ return p.protocol.ReadI32(ctx)
}
-func (p *THeaderProtocol) ReadI64() (value int64, err error) {
- return p.protocol.ReadI64()
+func (p *THeaderProtocol) ReadI64(ctx context.Context) (value int64, err error) {
+ return p.protocol.ReadI64(ctx)
}
-func (p *THeaderProtocol) ReadDouble() (value float64, err error) {
- return p.protocol.ReadDouble()
+func (p *THeaderProtocol) ReadDouble(ctx context.Context) (value float64, err error) {
+ return p.protocol.ReadDouble(ctx)
}
-func (p *THeaderProtocol) ReadString() (value string, err error) {
- return p.protocol.ReadString()
+func (p *THeaderProtocol) ReadString(ctx context.Context) (value string, err error) {
+ return p.protocol.ReadString(ctx)
}
-func (p *THeaderProtocol) ReadBinary() (value []byte, err error) {
- return p.protocol.ReadBinary()
+func (p *THeaderProtocol) ReadBinary(ctx context.Context) (value []byte, err error) {
+ return p.protocol.ReadBinary(ctx)
}
-func (p *THeaderProtocol) Skip(fieldType TType) error {
- return p.protocol.Skip(fieldType)
+func (p *THeaderProtocol) Skip(ctx context.Context, fieldType TType) error {
+ return p.protocol.Skip(ctx, fieldType)
}
// GetResponseHeadersFromClient is a helper function to get the read THeaderMap
diff --git a/lib/go/thrift/header_transport.go b/lib/go/thrift/header_transport.go
index 85d296d6b..c622c0e4f 100644
--- a/lib/go/thrift/header_transport.go
+++ b/lib/go/thrift/header_transport.go
@@ -297,18 +297,34 @@ func (t *THeaderTransport) IsOpen() bool {
// ReadFrame tries to read the frame header, guess the client type, and handle
// unframed clients.
-func (t *THeaderTransport) ReadFrame() error {
+func (t *THeaderTransport) ReadFrame(ctx context.Context) error {
if !t.needReadFrame() {
// No need to read frame, skipping.
return nil
}
+
// Peek and handle the first 32 bits.
// They could either be the length field of a framed message,
// or the first bytes of an unframed message.
- buf, err := t.reader.Peek(size32)
+ var buf []byte
+ var err error
+ // This is also usually the first read from a connection,
+ // so handle retries around socket timeouts.
+ _, deadlineSet := ctx.Deadline()
+ for {
+ buf, err = t.reader.Peek(size32)
+ if deadlineSet && isTimeoutError(err) && ctx.Err() == nil {
+ // This is I/O timeout and we still have time,
+ // continue trying
+ continue
+ }
+ // For anything else, do not retry
+ break
+ }
if err != nil {
return err
}
+
frameSize := binary.BigEndian.Uint32(buf)
if frameSize&VERSION_MASK == VERSION_1 {
t.clientType = clientUnframedBinary
@@ -341,7 +357,7 @@ func (t *THeaderTransport) ReadFrame() error {
version := binary.BigEndian.Uint32(buf)
if version&THeaderHeaderMask == THeaderHeaderMagic {
t.clientType = clientHeaders
- return t.parseHeaders(frameSize)
+ return t.parseHeaders(ctx, frameSize)
}
if version&VERSION_MASK == VERSION_1 {
t.clientType = clientFramedBinary
@@ -371,7 +387,7 @@ func (t *THeaderTransport) endOfFrame() error {
return t.frameReader.Close()
}
-func (t *THeaderTransport) parseHeaders(frameSize uint32) error {
+func (t *THeaderTransport) parseHeaders(ctx context.Context, frameSize uint32) error {
if t.clientType != clientHeaders {
return nil
}
@@ -451,11 +467,11 @@ func (t *THeaderTransport) parseHeaders(frameSize uint32) error {
return err
}
for i := 0; i < int(count); i++ {
- key, err := hp.ReadString()
+ key, err := hp.ReadString(ctx)
if err != nil {
return err
}
- value, err := hp.ReadString()
+ value, err := hp.ReadString(ctx)
if err != nil {
return err
}
@@ -485,7 +501,12 @@ func (t *THeaderTransport) needReadFrame() bool {
}
func (t *THeaderTransport) Read(p []byte) (read int, err error) {
- err = t.ReadFrame()
+ // Here using context.Background instead of a context passed in is safe.
+ // First is that there's no way to pass context into this function.
+ // Then, 99% of the case when calling this Read frame is already read
+ // into frameReader. ReadFrame here is more of preventing bugs that
+ // didn't call ReadFrame before calling Read.
+ err = t.ReadFrame(context.Background())
if err != nil {
return
}
@@ -557,10 +578,10 @@ func (t *THeaderTransport) Flush(ctx context.Context) error {
return NewTTransportExceptionFromError(err)
}
for key, value := range t.writeHeaders {
- if err := hp.WriteString(key); err != nil {
+ if err := hp.WriteString(ctx, key); err != nil {
return NewTTransportExceptionFromError(err)
}
- if err := hp.WriteString(value); err != nil {
+ if err := hp.WriteString(ctx, value); err != nil {
return NewTTransportExceptionFromError(err)
}
}
diff --git a/lib/go/thrift/header_transport_test.go b/lib/go/thrift/header_transport_test.go
index e3ae41b02..cee1cadc3 100644
--- a/lib/go/thrift/header_transport_test.go
+++ b/lib/go/thrift/header_transport_test.go
@@ -78,17 +78,17 @@ func TestTHeaderHeadersReadWrite(t *testing.T) {
// Read
// Make sure multiple calls to ReadFrame is fine.
- if err := reader.ReadFrame(); err != nil {
+ if err := reader.ReadFrame(context.Background()); err != nil {
t.Errorf("reader.ReadFrame returned error: %v", err)
}
- if err := reader.ReadFrame(); err != nil {
+ if err := reader.ReadFrame(context.Background()); err != nil {
t.Errorf("reader.ReadFrame returned error: %v", err)
}
read, err := ioutil.ReadAll(reader)
if err != nil {
t.Errorf("Read returned error: %v", err)
}
- if err := reader.ReadFrame(); err != nil && err != io.EOF {
+ if err := reader.ReadFrame(context.Background()); err != nil && err != io.EOF {
t.Errorf("reader.ReadFrame returned error: %v", err)
}
if string(read) != payload1+payload2 {
diff --git a/lib/go/thrift/json_protocol.go b/lib/go/thrift/json_protocol.go
index 800ac22c7..9a9328dc7 100644
--- a/lib/go/thrift/json_protocol.go
+++ b/lib/go/thrift/json_protocol.go
@@ -57,43 +57,43 @@ func NewTJSONProtocolFactory() *TJSONProtocolFactory {
return &TJSONProtocolFactory{}
}
-func (p *TJSONProtocol) WriteMessageBegin(name string, typeId TMessageType, seqId int32) error {
+func (p *TJSONProtocol) WriteMessageBegin(ctx context.Context, name string, typeId TMessageType, seqId int32) error {
p.resetContextStack() // THRIFT-3735
if e := p.OutputListBegin(); e != nil {
return e
}
- if e := p.WriteI32(THRIFT_JSON_PROTOCOL_VERSION); e != nil {
+ if e := p.WriteI32(ctx, THRIFT_JSON_PROTOCOL_VERSION); e != nil {
return e
}
- if e := p.WriteString(name); e != nil {
+ if e := p.WriteString(ctx, name); e != nil {
return e
}
- if e := p.WriteByte(int8(typeId)); e != nil {
+ if e := p.WriteByte(ctx, int8(typeId)); e != nil {
return e
}
- if e := p.WriteI32(seqId); e != nil {
+ if e := p.WriteI32(ctx, seqId); e != nil {
return e
}
return nil
}
-func (p *TJSONProtocol) WriteMessageEnd() error {
+func (p *TJSONProtocol) WriteMessageEnd(ctx context.Context) error {
return p.OutputListEnd()
}
-func (p *TJSONProtocol) WriteStructBegin(name string) error {
+func (p *TJSONProtocol) WriteStructBegin(ctx context.Context, name string) error {
if e := p.OutputObjectBegin(); e != nil {
return e
}
return nil
}
-func (p *TJSONProtocol) WriteStructEnd() error {
+func (p *TJSONProtocol) WriteStructEnd(ctx context.Context) error {
return p.OutputObjectEnd()
}
-func (p *TJSONProtocol) WriteFieldBegin(name string, typeId TType, id int16) error {
- if e := p.WriteI16(id); e != nil {
+func (p *TJSONProtocol) WriteFieldBegin(ctx context.Context, name string, typeId TType, id int16) error {
+ if e := p.WriteI16(ctx, id); e != nil {
return e
}
if e := p.OutputObjectBegin(); e != nil {
@@ -103,19 +103,19 @@ func (p *TJSONProtocol) WriteFieldBegin(name string, typeId TType, id int16) err
if e1 != nil {
return e1
}
- if e := p.WriteString(s); e != nil {
+ if e := p.WriteString(ctx, s); e != nil {
return e
}
return nil
}
-func (p *TJSONProtocol) WriteFieldEnd() error {
+func (p *TJSONProtocol) WriteFieldEnd(ctx context.Context) error {
return p.OutputObjectEnd()
}
-func (p *TJSONProtocol) WriteFieldStop() error { return nil }
+func (p *TJSONProtocol) WriteFieldStop(ctx context.Context) error { return nil }
-func (p *TJSONProtocol) WriteMapBegin(keyType TType, valueType TType, size int) error {
+func (p *TJSONProtocol) WriteMapBegin(ctx context.Context, keyType TType, valueType TType, size int) error {
if e := p.OutputListBegin(); e != nil {
return e
}
@@ -123,77 +123,77 @@ func (p *TJSONProtocol) WriteMapBegin(keyType TType, valueType TType, size int)
if e1 != nil {
return e1
}
- if e := p.WriteString(s); e != nil {
+ if e := p.WriteString(ctx, s); e != nil {
return e
}
s, e1 = p.TypeIdToString(valueType)
if e1 != nil {
return e1
}
- if e := p.WriteString(s); e != nil {
+ if e := p.WriteString(ctx, s); e != nil {
return e
}
- if e := p.WriteI64(int64(size)); e != nil {
+ if e := p.WriteI64(ctx, int64(size)); e != nil {
return e
}
return p.OutputObjectBegin()
}
-func (p *TJSONProtocol) WriteMapEnd() error {
+func (p *TJSONProtocol) WriteMapEnd(ctx context.Context) error {
if e := p.OutputObjectEnd(); e != nil {
return e
}
return p.OutputListEnd()
}
-func (p *TJSONProtocol) WriteListBegin(elemType TType, size int) error {
+func (p *TJSONProtocol) WriteListBegin(ctx context.Context, elemType TType, size int) error {
return p.OutputElemListBegin(elemType, size)
}
-func (p *TJSONProtocol) WriteListEnd() error {
+func (p *TJSONProtocol) WriteListEnd(ctx context.Context) error {
return p.OutputListEnd()
}
-func (p *TJSONProtocol) WriteSetBegin(elemType TType, size int) error {
+func (p *TJSONProtocol) WriteSetBegin(ctx context.Context, elemType TType, size int) error {
return p.OutputElemListBegin(elemType, size)
}
-func (p *TJSONProtocol) WriteSetEnd() error {
+func (p *TJSONProtocol) WriteSetEnd(ctx context.Context) error {
return p.OutputListEnd()
}
-func (p *TJSONProtocol) WriteBool(b bool) error {
+func (p *TJSONProtocol) WriteBool(ctx context.Context, b bool) error {
if b {
- return p.WriteI32(1)
+ return p.WriteI32(ctx, 1)
}
- return p.WriteI32(0)
+ return p.WriteI32(ctx, 0)
}
-func (p *TJSONProtocol) WriteByte(b int8) error {
- return p.WriteI32(int32(b))
+func (p *TJSONProtocol) WriteByte(ctx context.Context, b int8) error {
+ return p.WriteI32(ctx, int32(b))
}
-func (p *TJSONProtocol) WriteI16(v int16) error {
- return p.WriteI32(int32(v))
+func (p *TJSONProtocol) WriteI16(ctx context.Context, v int16) error {
+ return p.WriteI32(ctx, int32(v))
}
-func (p *TJSONProtocol) WriteI32(v int32) error {
+func (p *TJSONProtocol) WriteI32(ctx context.Context, v int32) error {
return p.OutputI64(int64(v))
}
-func (p *TJSONProtocol) WriteI64(v int64) error {
+func (p *TJSONProtocol) WriteI64(ctx context.Context, v int64) error {
return p.OutputI64(int64(v))
}
-func (p *TJSONProtocol) WriteDouble(v float64) error {
+func (p *TJSONProtocol) WriteDouble(ctx context.Context, v float64) error {
return p.OutputF64(v)
}
-func (p *TJSONProtocol) WriteString(v string) error {
+func (p *TJSONProtocol) WriteString(ctx context.Context, v string) error {
return p.OutputString(v)
}
-func (p *TJSONProtocol) WriteBinary(v []byte) error {
+func (p *TJSONProtocol) WriteBinary(ctx context.Context, v []byte) error {
// JSON library only takes in a string,
// not an arbitrary byte array, to ensure bytes are transmitted
// efficiently we must convert this into a valid JSON string
@@ -219,12 +219,12 @@ func (p *TJSONProtocol) WriteBinary(v []byte) error {
}
// Reading methods.
-func (p *TJSONProtocol) ReadMessageBegin() (name string, typeId TMessageType, seqId int32, err error) {
+func (p *TJSONProtocol) ReadMessageBegin(ctx context.Context) (name string, typeId TMessageType, seqId int32, err error) {
p.resetContextStack() // THRIFT-3735
if isNull, err := p.ParseListBegin(); isNull || err != nil {
return name, typeId, seqId, err
}
- version, err := p.ReadI32()
+ version, err := p.ReadI32(ctx)
if err != nil {
return name, typeId, seqId, err
}
@@ -233,47 +233,47 @@ func (p *TJSONProtocol) ReadMessageBegin() (name string, typeId TMessageType, se
return name, typeId, seqId, NewTProtocolExceptionWithType(INVALID_DATA, e)
}
- if name, err = p.ReadString(); err != nil {
+ if name, err = p.ReadString(ctx); err != nil {
return name, typeId, seqId, err
}
- bTypeId, err := p.ReadByte()
+ bTypeId, err := p.ReadByte(ctx)
typeId = TMessageType(bTypeId)
if err != nil {
return name, typeId, seqId, err
}
- if seqId, err = p.ReadI32(); err != nil {
+ if seqId, err = p.ReadI32(ctx); err != nil {
return name, typeId, seqId, err
}
return name, typeId, seqId, nil
}
-func (p *TJSONProtocol) ReadMessageEnd() error {
+func (p *TJSONProtocol) ReadMessageEnd(ctx context.Context) error {
err := p.ParseListEnd()
return err
}
-func (p *TJSONProtocol) ReadStructBegin() (name string, err error) {
+func (p *TJSONProtocol) ReadStructBegin(ctx context.Context) (name string, err error) {
_, err = p.ParseObjectStart()
return "", err
}
-func (p *TJSONProtocol) ReadStructEnd() error {
+func (p *TJSONProtocol) ReadStructEnd(ctx context.Context) error {
return p.ParseObjectEnd()
}
-func (p *TJSONProtocol) ReadFieldBegin() (string, TType, int16, error) {
+func (p *TJSONProtocol) ReadFieldBegin(ctx context.Context) (string, TType, int16, error) {
b, _ := p.reader.Peek(1)
if len(b) < 1 || b[0] == JSON_RBRACE[0] || b[0] == JSON_RBRACKET[0] {
return "", STOP, -1, nil
}
- fieldId, err := p.ReadI16()
+ fieldId, err := p.ReadI16(ctx)
if err != nil {
return "", STOP, fieldId, err
}
if _, err = p.ParseObjectStart(); err != nil {
return "", STOP, fieldId, err
}
- sType, err := p.ReadString()
+ sType, err := p.ReadString(ctx)
if err != nil {
return "", STOP, fieldId, err
}
@@ -281,17 +281,17 @@ func (p *TJSONProtocol) ReadFieldBegin() (string, TType, int16, error) {
return "", fType, fieldId, err
}
-func (p *TJSONProtocol) ReadFieldEnd() error {
+func (p *TJSONProtocol) ReadFieldEnd(ctx context.Context) error {
return p.ParseObjectEnd()
}
-func (p *TJSONProtocol) ReadMapBegin() (keyType TType, valueType TType, size int, e error) {
+func (p *TJSONProtocol) ReadMapBegin(ctx context.Context) (keyType TType, valueType TType, size int, e error) {
if isNull, e := p.ParseListBegin(); isNull || e != nil {
return VOID, VOID, 0, e
}
// read keyType
- sKeyType, e := p.ReadString()
+ sKeyType, e := p.ReadString(ctx)
if e != nil {
return keyType, valueType, size, e
}
@@ -301,7 +301,7 @@ func (p *TJSONProtocol) ReadMapBegin() (keyType TType, valueType TType, size int
}
// read valueType
- sValueType, e := p.ReadString()
+ sValueType, e := p.ReadString(ctx)
if e != nil {
return keyType, valueType, size, e
}
@@ -311,7 +311,7 @@ func (p *TJSONProtocol) ReadMapBegin() (keyType TType, valueType TType, size int
}
// read size
- iSize, e := p.ReadI64()
+ iSize, e := p.ReadI64(ctx)
if e != nil {
return keyType, valueType, size, e
}
@@ -321,7 +321,7 @@ func (p *TJSONProtocol) ReadMapBegin() (keyType TType, valueType TType, size int
return keyType, valueType, size, e
}
-func (p *TJSONProtocol) ReadMapEnd() error {
+func (p *TJSONProtocol) ReadMapEnd(ctx context.Context) error {
e := p.ParseObjectEnd()
if e != nil {
return e
@@ -329,53 +329,53 @@ func (p *TJSONProtocol) ReadMapEnd() error {
return p.ParseListEnd()
}
-func (p *TJSONProtocol) ReadListBegin() (elemType TType, size int, e error) {
+func (p *TJSONProtocol) ReadListBegin(ctx context.Context) (elemType TType, size int, e error) {
return p.ParseElemListBegin()
}
-func (p *TJSONProtocol) ReadListEnd() error {
+func (p *TJSONProtocol) ReadListEnd(ctx context.Context) error {
return p.ParseListEnd()
}
-func (p *TJSONProtocol) ReadSetBegin() (elemType TType, size int, e error) {
+func (p *TJSONProtocol) ReadSetBegin(ctx context.Context) (elemType TType, size int, e error) {
return p.ParseElemListBegin()
}
-func (p *TJSONProtocol) ReadSetEnd() error {
+func (p *TJSONProtocol) ReadSetEnd(ctx context.Context) error {
return p.ParseListEnd()
}
-func (p *TJSONProtocol) ReadBool() (bool, error) {
- value, err := p.ReadI32()
+func (p *TJSONProtocol) ReadBool(ctx context.Context) (bool, error) {
+ value, err := p.ReadI32(ctx)
return (value != 0), err
}
-func (p *TJSONProtocol) ReadByte() (int8, error) {
- v, err := p.ReadI64()
+func (p *TJSONProtocol) ReadByte(ctx context.Context) (int8, error) {
+ v, err := p.ReadI64(ctx)
return int8(v), err
}
-func (p *TJSONProtocol) ReadI16() (int16, error) {
- v, err := p.ReadI64()
+func (p *TJSONProtocol) ReadI16(ctx context.Context) (int16, error) {
+ v, err := p.ReadI64(ctx)
return int16(v), err
}
-func (p *TJSONProtocol) ReadI32() (int32, error) {
- v, err := p.ReadI64()
+func (p *TJSONProtocol) ReadI32(ctx context.Context) (int32, error) {
+ v, err := p.ReadI64(ctx)
return int32(v), err
}
-func (p *TJSONProtocol) ReadI64() (int64, error) {
+func (p *TJSONProtocol) ReadI64(ctx context.Context) (int64, error) {
v, _, err := p.ParseI64()
return v, err
}
-func (p *TJSONProtocol) ReadDouble() (float64, error) {
+func (p *TJSONProtocol) ReadDouble(ctx context.Context) (float64, error) {
v, _, err := p.ParseF64()
return v, err
}
-func (p *TJSONProtocol) ReadString() (string, error) {
+func (p *TJSONProtocol) ReadString(ctx context.Context) (string, error) {
var v string
if err := p.ParsePreValue(); err != nil {
return v, err
@@ -405,7 +405,7 @@ func (p *TJSONProtocol) ReadString() (string, error) {
return v, p.ParsePostValue()
}
-func (p *TJSONProtocol) ReadBinary() ([]byte, error) {
+func (p *TJSONProtocol) ReadBinary(ctx context.Context) ([]byte, error) {
var v []byte
if err := p.ParsePreValue(); err != nil {
return nil, err
@@ -444,8 +444,8 @@ func (p *TJSONProtocol) Flush(ctx context.Context) (err error) {
return NewTProtocolException(err)
}
-func (p *TJSONProtocol) Skip(fieldType TType) (err error) {
- return SkipDefaultDepth(p, fieldType)
+func (p *TJSONProtocol) Skip(ctx context.Context, fieldType TType) (err error) {
+ return SkipDefaultDepth(ctx, p, fieldType)
}
func (p *TJSONProtocol) Transport() TTransport {
@@ -460,10 +460,10 @@ func (p *TJSONProtocol) OutputElemListBegin(elemType TType, size int) error {
if e1 != nil {
return e1
}
- if e := p.WriteString(s); e != nil {
+ if e := p.OutputString(s); e != nil {
return e
}
- if e := p.WriteI64(int64(size)); e != nil {
+ if e := p.OutputI64(int64(size)); e != nil {
return e
}
return nil
@@ -473,7 +473,11 @@ func (p *TJSONProtocol) ParseElemListBegin() (elemType TType, size int, e error)
if isNull, e := p.ParseListBegin(); isNull || e != nil {
return VOID, 0, e
}
- sElemType, err := p.ReadString()
+ // We don't really use the ctx in ReadString implementation,
+ // so this is safe for now.
+ // We might want to add context to ParseElemListBegin if we start to use
+ // ctx in ReadString implementation in the future.
+ sElemType, err := p.ReadString(context.Background())
if err != nil {
return VOID, size, err
}
@@ -481,7 +485,7 @@ func (p *TJSONProtocol) ParseElemListBegin() (elemType TType, size int, e error)
if err != nil {
return elemType, size, err
}
- nSize, err2 := p.ReadI64()
+ nSize, _, err2 := p.ParseI64()
size = int(nSize)
return elemType, size, err2
}
@@ -490,7 +494,11 @@ func (p *TJSONProtocol) readElemListBegin() (elemType TType, size int, e error)
if isNull, e := p.ParseListBegin(); isNull || e != nil {
return VOID, 0, e
}
- sElemType, err := p.ReadString()
+ // We don't really use the ctx in ReadString implementation,
+ // so this is safe for now.
+ // We might want to add context to ParseElemListBegin if we start to use
+ // ctx in ReadString implementation in the future.
+ sElemType, err := p.ReadString(context.Background())
if err != nil {
return VOID, size, err
}
@@ -498,7 +506,7 @@ func (p *TJSONProtocol) readElemListBegin() (elemType TType, size int, e error)
if err != nil {
return elemType, size, err
}
- nSize, err2 := p.ReadI64()
+ nSize, _, err2 := p.ParseI64()
size = int(nSize)
return elemType, size, err2
}
diff --git a/lib/go/thrift/json_protocol_test.go b/lib/go/thrift/json_protocol_test.go
index 59c4d64a2..07afa9683 100644
--- a/lib/go/thrift/json_protocol_test.go
+++ b/lib/go/thrift/json_protocol_test.go
@@ -34,7 +34,7 @@ func TestWriteJSONProtocolBool(t *testing.T) {
trans := NewTMemoryBuffer()
p := NewTJSONProtocol(trans)
for _, value := range BOOL_VALUES {
- if e := p.WriteBool(value); e != nil {
+ if e := p.WriteBool(context.Background(), value); e != nil {
t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error())
}
if e := p.Flush(context.Background()); e != nil {
@@ -71,7 +71,7 @@ func TestReadJSONProtocolBool(t *testing.T) {
}
trans.Flush(context.Background())
s := trans.String()
- v, e := p.ReadBool()
+ v, e := p.ReadBool(context.Background())
if e != nil {
t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error())
}
@@ -92,7 +92,7 @@ func TestWriteJSONProtocolByte(t *testing.T) {
trans := NewTMemoryBuffer()
p := NewTJSONProtocol(trans)
for _, value := range BYTE_VALUES {
- if e := p.WriteByte(value); e != nil {
+ if e := p.WriteByte(context.Background(), value); e != nil {
t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error())
}
if e := p.Flush(context.Background()); e != nil {
@@ -119,7 +119,7 @@ func TestReadJSONProtocolByte(t *testing.T) {
trans.WriteString(strconv.Itoa(int(value)))
trans.Flush(context.Background())
s := trans.String()
- v, e := p.ReadByte()
+ v, e := p.ReadByte(context.Background())
if e != nil {
t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error())
}
@@ -139,7 +139,7 @@ func TestWriteJSONProtocolI16(t *testing.T) {
trans := NewTMemoryBuffer()
p := NewTJSONProtocol(trans)
for _, value := range INT16_VALUES {
- if e := p.WriteI16(value); e != nil {
+ if e := p.WriteI16(context.Background(), value); e != nil {
t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error())
}
if e := p.Flush(context.Background()); e != nil {
@@ -166,7 +166,7 @@ func TestReadJSONProtocolI16(t *testing.T) {
trans.WriteString(strconv.Itoa(int(value)))
trans.Flush(context.Background())
s := trans.String()
- v, e := p.ReadI16()
+ v, e := p.ReadI16(context.Background())
if e != nil {
t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error())
}
@@ -186,7 +186,7 @@ func TestWriteJSONProtocolI32(t *testing.T) {
trans := NewTMemoryBuffer()
p := NewTJSONProtocol(trans)
for _, value := range INT32_VALUES {
- if e := p.WriteI32(value); e != nil {
+ if e := p.WriteI32(context.Background(), value); e != nil {
t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error())
}
if e := p.Flush(context.Background()); e != nil {
@@ -213,7 +213,7 @@ func TestReadJSONProtocolI32(t *testing.T) {
trans.WriteString(strconv.Itoa(int(value)))
trans.Flush(context.Background())
s := trans.String()
- v, e := p.ReadI32()
+ v, e := p.ReadI32(context.Background())
if e != nil {
t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error())
}
@@ -233,7 +233,7 @@ func TestWriteJSONProtocolI64(t *testing.T) {
trans := NewTMemoryBuffer()
p := NewTJSONProtocol(trans)
for _, value := range INT64_VALUES {
- if e := p.WriteI64(value); e != nil {
+ if e := p.WriteI64(context.Background(), value); e != nil {
t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error())
}
if e := p.Flush(context.Background()); e != nil {
@@ -260,7 +260,7 @@ func TestReadJSONProtocolI64(t *testing.T) {
trans.WriteString(strconv.FormatInt(value, 10))
trans.Flush(context.Background())
s := trans.String()
- v, e := p.ReadI64()
+ v, e := p.ReadI64(context.Background())
if e != nil {
t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error())
}
@@ -280,7 +280,7 @@ func TestWriteJSONProtocolDouble(t *testing.T) {
trans := NewTMemoryBuffer()
p := NewTJSONProtocol(trans)
for _, value := range DOUBLE_VALUES {
- if e := p.WriteDouble(value); e != nil {
+ if e := p.WriteDouble(context.Background(), value); e != nil {
t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error())
}
if e := p.Flush(context.Background()); e != nil {
@@ -322,7 +322,7 @@ func TestReadJSONProtocolDouble(t *testing.T) {
trans.WriteString(n.String())
trans.Flush(context.Background())
s := trans.String()
- v, e := p.ReadDouble()
+ v, e := p.ReadDouble(context.Background())
if e != nil {
t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error())
}
@@ -356,7 +356,7 @@ func TestWriteJSONProtocolString(t *testing.T) {
trans := NewTMemoryBuffer()
p := NewTJSONProtocol(trans)
for _, value := range STRING_VALUES {
- if e := p.WriteString(value); e != nil {
+ if e := p.WriteString(context.Background(), value); e != nil {
t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error())
}
if e := p.Flush(context.Background()); e != nil {
@@ -383,7 +383,7 @@ func TestReadJSONProtocolString(t *testing.T) {
trans.WriteString(jsonQuote(value))
trans.Flush(context.Background())
s := trans.String()
- v, e := p.ReadString()
+ v, e := p.ReadString(context.Background())
if e != nil {
t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error())
}
@@ -407,7 +407,7 @@ func TestWriteJSONProtocolBinary(t *testing.T) {
b64String := string(b64value)
trans := NewTMemoryBuffer()
p := NewTJSONProtocol(trans)
- if e := p.WriteBinary(value); e != nil {
+ if e := p.WriteBinary(context.Background(), value); e != nil {
t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error())
}
if e := p.Flush(context.Background()); e != nil {
@@ -418,7 +418,7 @@ func TestWriteJSONProtocolBinary(t *testing.T) {
if s != expectedString {
t.Fatalf("Bad value for %s %v\n wrote: \"%v\"\nexpected: \"%v\"", thetype, value, s, expectedString)
}
- v1, err := p.ReadBinary()
+ v1, err := p.ReadBinary(context.Background())
if err != nil {
t.Fatalf("Unable to read binary: %s", err.Error())
}
@@ -444,7 +444,7 @@ func TestReadJSONProtocolBinary(t *testing.T) {
trans.WriteString(jsonQuote(b64String))
trans.Flush(context.Background())
s := trans.String()
- v, e := p.ReadBinary()
+ v, e := p.ReadBinary(context.Background())
if e != nil {
t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error())
}
@@ -468,13 +468,13 @@ func TestWriteJSONProtocolList(t *testing.T) {
thetype := "list"
trans := NewTMemoryBuffer()
p := NewTJSONProtocol(trans)
- p.WriteListBegin(TType(DOUBLE), len(DOUBLE_VALUES))
+ p.WriteListBegin(context.Background(), TType(DOUBLE), len(DOUBLE_VALUES))
for _, value := range DOUBLE_VALUES {
- if e := p.WriteDouble(value); e != nil {
+ if e := p.WriteDouble(context.Background(), value); e != nil {
t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error())
}
}
- p.WriteListEnd()
+ p.WriteListEnd(context.Background())
if e := p.Flush(context.Background()); e != nil {
t.Fatalf("Unable to write %s due to error flushing: %s", thetype, e.Error())
}
@@ -522,13 +522,13 @@ func TestWriteJSONProtocolSet(t *testing.T) {
thetype := "set"
trans := NewTMemoryBuffer()
p := NewTJSONProtocol(trans)
- p.WriteSetBegin(TType(DOUBLE), len(DOUBLE_VALUES))
+ p.WriteSetBegin(context.Background(), TType(DOUBLE), len(DOUBLE_VALUES))
for _, value := range DOUBLE_VALUES {
- if e := p.WriteDouble(value); e != nil {
+ if e := p.WriteDouble(context.Background(), value); e != nil {
t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error())
}
}
- p.WriteSetEnd()
+ p.WriteSetEnd(context.Background())
if e := p.Flush(context.Background()); e != nil {
t.Fatalf("Unable to write %s due to error flushing: %s", thetype, e.Error())
}
@@ -576,16 +576,16 @@ func TestWriteJSONProtocolMap(t *testing.T) {
thetype := "map"
trans := NewTMemoryBuffer()
p := NewTJSONProtocol(trans)
- p.WriteMapBegin(TType(I32), TType(DOUBLE), len(DOUBLE_VALUES))
+ p.WriteMapBegin(context.Background(), TType(I32), TType(DOUBLE), len(DOUBLE_VALUES))
for k, value := range DOUBLE_VALUES {
- if e := p.WriteI32(int32(k)); e != nil {
+ if e := p.WriteI32(context.Background(), int32(k)); e != nil {
t.Fatalf("Unable to write %s key int32 value %v due to error: %s", thetype, k, e.Error())
}
- if e := p.WriteDouble(value); e != nil {
+ if e := p.WriteDouble(context.Background(), value); e != nil {
t.Fatalf("Unable to write %s value float64 value %v due to error: %s", thetype, value, e.Error())
}
}
- p.WriteMapEnd()
+ p.WriteMapEnd(context.Background())
if e := p.Flush(context.Background()); e != nil {
t.Fatalf("Unable to write %s due to error flushing: %s", thetype, e.Error())
}
@@ -593,7 +593,7 @@ func TestWriteJSONProtocolMap(t *testing.T) {
if str[0] != '[' || str[len(str)-1] != ']' {
t.Fatalf("Bad value for %s, wrote: %v, in go: %v", thetype, str, DOUBLE_VALUES)
}
- expectedKeyType, expectedValueType, expectedSize, err := p.ReadMapBegin()
+ expectedKeyType, expectedValueType, expectedSize, err := p.ReadMapBegin(context.Background())
if err != nil {
t.Fatalf("Error while reading map begin: %s", err.Error())
}
@@ -607,14 +607,14 @@ func TestWriteJSONProtocolMap(t *testing.T) {
t.Fatal("Expected map size of ", len(DOUBLE_VALUES), ", but was ", expectedSize)
}
for k, value := range DOUBLE_VALUES {
- ik, err := p.ReadI32()
+ ik, err := p.ReadI32(context.Background())
if err != nil {
t.Fatalf("Bad key for %s index %v, wrote: %v, expected: %v, error: %s", thetype, k, ik, string(k), err.Error())
}
if int(ik) != k {
t.Fatalf("Bad key for %s index %v, wrote: %v, expected: %v", thetype, k, ik, k)
}
- dv, err := p.ReadDouble()
+ dv, err := p.ReadDouble(context.Background())
if err != nil {
t.Fatalf("Bad value for %s index %v, wrote: %v, expected: %v, error: %s", thetype, k, dv, value, err.Error())
}
@@ -642,7 +642,7 @@ func TestWriteJSONProtocolMap(t *testing.T) {
}
}
}
- err = p.ReadMapEnd()
+ err = p.ReadMapEnd(context.Background())
if err != nil {
t.Fatalf("Error while reading map end: %s", err.Error())
}
diff --git a/lib/go/thrift/multiplexed_protocol.go b/lib/go/thrift/multiplexed_protocol.go
index 9db59c4c9..2f7997e77 100644
--- a/lib/go/thrift/multiplexed_protocol.go
+++ b/lib/go/thrift/multiplexed_protocol.go
@@ -68,11 +68,11 @@ func NewTMultiplexedProtocol(protocol TProtocol, serviceName string) *TMultiplex
}
}
-func (t *TMultiplexedProtocol) WriteMessageBegin(name string, typeId TMessageType, seqid int32) error {
+func (t *TMultiplexedProtocol) WriteMessageBegin(ctx context.Context, name string, typeId TMessageType, seqid int32) error {
if typeId == CALL || typeId == ONEWAY {
- return t.TProtocol.WriteMessageBegin(t.serviceName+MULTIPLEXED_SEPARATOR+name, typeId, seqid)
+ return t.TProtocol.WriteMessageBegin(ctx, t.serviceName+MULTIPLEXED_SEPARATOR+name, typeId, seqid)
} else {
- return t.TProtocol.WriteMessageBegin(name, typeId, seqid)
+ return t.TProtocol.WriteMessageBegin(ctx, name, typeId, seqid)
}
}
@@ -190,7 +190,7 @@ func (t *TMultiplexedProcessor) RegisterProcessor(name string, processor TProces
}
func (t *TMultiplexedProcessor) Process(ctx context.Context, in, out TProtocol) (bool, TException) {
- name, typeId, seqid, err := in.ReadMessageBegin()
+ name, typeId, seqid, err := in.ReadMessageBegin(ctx)
if err != nil {
return false, err
}
@@ -226,6 +226,6 @@ func NewStoredMessageProtocol(protocol TProtocol, name string, typeId TMessageTy
return &storedMessageProtocol{protocol, name, typeId, seqid}
}
-func (s *storedMessageProtocol) ReadMessageBegin() (name string, typeId TMessageType, seqid int32, err error) {
+func (s *storedMessageProtocol) ReadMessageBegin(ctx context.Context) (name string, typeId TMessageType, seqid int32, err error) {
return s.name, s.typeId, s.seqid, nil
}
diff --git a/lib/go/thrift/protocol.go b/lib/go/thrift/protocol.go
index 2e6bc4b16..0a69bd416 100644
--- a/lib/go/thrift/protocol.go
+++ b/lib/go/thrift/protocol.go
@@ -31,50 +31,50 @@ const (
)
type TProtocol interface {
- WriteMessageBegin(name string, typeId TMessageType, seqid int32) error
- WriteMessageEnd() error
- WriteStructBegin(name string) error
- WriteStructEnd() error
- WriteFieldBegin(name string, typeId TType, id int16) error
- WriteFieldEnd() error
- WriteFieldStop() error
- WriteMapBegin(keyType TType, valueType TType, size int) error
- WriteMapEnd() error
- WriteListBegin(elemType TType, size int) error
- WriteListEnd() error
- WriteSetBegin(elemType TType, size int) error
- WriteSetEnd() error
- WriteBool(value bool) error
- WriteByte(value int8) error
- WriteI16(value int16) error
- WriteI32(value int32) error
- WriteI64(value int64) error
- WriteDouble(value float64) error
- WriteString(value string) error
- WriteBinary(value []byte) error
+ WriteMessageBegin(ctx context.Context, name string, typeId TMessageType, seqid int32) error
+ WriteMessageEnd(ctx context.Context) error
+ WriteStructBegin(ctx context.Context, name string) error
+ WriteStructEnd(ctx context.Context) error
+ WriteFieldBegin(ctx context.Context, name string, typeId TType, id int16) error
+ WriteFieldEnd(ctx context.Context) error
+ WriteFieldStop(ctx context.Context) error
+ WriteMapBegin(ctx context.Context, keyType TType, valueType TType, size int) error
+ WriteMapEnd(ctx context.Context) error
+ WriteListBegin(ctx context.Context, elemType TType, size int) error
+ WriteListEnd(ctx context.Context) error
+ WriteSetBegin(ctx context.Context, elemType TType, size int) error
+ WriteSetEnd(ctx context.Context) error
+ WriteBool(ctx context.Context, value bool) error
+ WriteByte(ctx context.Context, value int8) error
+ WriteI16(ctx context.Context, value int16) error
+ WriteI32(ctx context.Context, value int32) error
+ WriteI64(ctx context.Context, value int64) error
+ WriteDouble(ctx context.Context, value float64) error
+ WriteString(ctx context.Context, value string) error
+ WriteBinary(ctx context.Context, value []byte) error
- ReadMessageBegin() (name string, typeId TMessageType, seqid int32, err error)
- ReadMessageEnd() error
- ReadStructBegin() (name string, err error)
- ReadStructEnd() error
- ReadFieldBegin() (name string, typeId TType, id int16, err error)
- ReadFieldEnd() error
- ReadMapBegin() (keyType TType, valueType TType, size int, err error)
- ReadMapEnd() error
- ReadListBegin() (elemType TType, size int, err error)
- ReadListEnd() error
- ReadSetBegin() (elemType TType, size int, err error)
- ReadSetEnd() error
- ReadBool() (value bool, err error)
- ReadByte() (value int8, err error)
- ReadI16() (value int16, err error)
- ReadI32() (value int32, err error)
- ReadI64() (value int64, err error)
- ReadDouble() (value float64, err error)
- ReadString() (value string, err error)
- ReadBinary() (value []byte, err error)
+ ReadMessageBegin(ctx context.Context) (name string, typeId TMessageType, seqid int32, err error)
+ ReadMessageEnd(ctx context.Context) error
+ ReadStructBegin(ctx context.Context) (name string, err error)
+ ReadStructEnd(ctx context.Context) error
+ ReadFieldBegin(ctx context.Context) (name string, typeId TType, id int16, err error)
+ ReadFieldEnd(ctx context.Context) error
+ ReadMapBegin(ctx context.Context) (keyType TType, valueType TType, size int, err error)
+ ReadMapEnd(ctx context.Context) error
+ ReadListBegin(ctx context.Context) (elemType TType, size int, err error)
+ ReadListEnd(ctx context.Context) error
+ ReadSetBegin(ctx context.Context) (elemType TType, size int, err error)
+ ReadSetEnd(ctx context.Context) error
+ ReadBool(ctx context.Context) (value bool, err error)
+ ReadByte(ctx context.Context) (value int8, err error)
+ ReadI16(ctx context.Context) (value int16, err error)
+ ReadI32(ctx context.Context) (value int32, err error)
+ ReadI64(ctx context.Context) (value int64, err error)
+ ReadDouble(ctx context.Context) (value float64, err error)
+ ReadString(ctx context.Context) (value string, err error)
+ ReadBinary(ctx context.Context) (value []byte, err error)
- Skip(fieldType TType) (err error)
+ Skip(ctx context.Context, fieldType TType) (err error)
Flush(ctx context.Context) (err error)
Transport() TTransport
@@ -84,12 +84,12 @@ type TProtocol interface {
const DEFAULT_RECURSION_DEPTH = 64
// Skips over the next data element from the provided input TProtocol object.
-func SkipDefaultDepth(prot TProtocol, typeId TType) (err error) {
- return Skip(prot, typeId, DEFAULT_RECURSION_DEPTH)
+func SkipDefaultDepth(ctx context.Context, prot TProtocol, typeId TType) (err error) {
+ return Skip(ctx, prot, typeId, DEFAULT_RECURSION_DEPTH)
}
// Skips over the next data element from the provided input TProtocol object.
-func Skip(self TProtocol, fieldType TType, maxDepth int) (err error) {
+func Skip(ctx context.Context, self TProtocol, fieldType TType, maxDepth int) (err error) {
if maxDepth <= 0 {
return NewTProtocolExceptionWithType(DEPTH_LIMIT, errors.New("Depth limit exceeded"))
@@ -97,79 +97,79 @@ func Skip(self TProtocol, fieldType TType, maxDepth int) (err error) {
switch fieldType {
case BOOL:
- _, err = self.ReadBool()
+ _, err = self.ReadBool(ctx)
return
case BYTE:
- _, err = self.ReadByte()
+ _, err = self.ReadByte(ctx)
return
case I16:
- _, err = self.ReadI16()
+ _, err = self.ReadI16(ctx)
return
case I32:
- _, err = self.ReadI32()
+ _, err = self.ReadI32(ctx)
return
case I64:
- _, err = self.ReadI64()
+ _, err = self.ReadI64(ctx)
return
case DOUBLE:
- _, err = self.ReadDouble()
+ _, err = self.ReadDouble(ctx)
return
case STRING:
- _, err = self.ReadString()
+ _, err = self.ReadString(ctx)
return
case STRUCT:
- if _, err = self.ReadStructBegin(); err != nil {
+ if _, err = self.ReadStructBegin(ctx); err != nil {
return err
}
for {
- _, typeId, _, _ := self.ReadFieldBegin()
+ _, typeId, _, _ := self.ReadFieldBegin(ctx)
if typeId == STOP {
break
}
- err := Skip(self, typeId, maxDepth-1)
+ err := Skip(ctx, self, typeId, maxDepth-1)
if err != nil {
return err
}
- self.ReadFieldEnd()
+ self.ReadFieldEnd(ctx)
}
- return self.ReadStructEnd()
+ return self.ReadStructEnd(ctx)
case MAP:
- keyType, valueType, size, err := self.ReadMapBegin()
+ keyType, valueType, size, err := self.ReadMapBegin(ctx)
if err != nil {
return err
}
for i := 0; i < size; i++ {
- err := Skip(self, keyType, maxDepth-1)
+ err := Skip(ctx, self, keyType, maxDepth-1)
if err != nil {
return err
}
- self.Skip(valueType)
+ self.Skip(ctx, valueType)
}
- return self.ReadMapEnd()
+ return self.ReadMapEnd(ctx)
case SET:
- elemType, size, err := self.ReadSetBegin()
+ elemType, size, err := self.ReadSetBegin(ctx)
if err != nil {
return err
}
for i := 0; i < size; i++ {
- err := Skip(self, elemType, maxDepth-1)
+ err := Skip(ctx, self, elemType, maxDepth-1)
if err != nil {
return err
}
}
- return self.ReadSetEnd()
+ return self.ReadSetEnd(ctx)
case LIST:
- elemType, size, err := self.ReadListBegin()
+ elemType, size, err := self.ReadListBegin(ctx)
if err != nil {
return err
}
for i := 0; i < size; i++ {
- err := Skip(self, elemType, maxDepth-1)
+ err := Skip(ctx, self, elemType, maxDepth-1)
if err != nil {
return err
}
}
- return self.ReadListEnd()
+ return self.ReadListEnd(ctx)
default:
return NewTProtocolExceptionWithType(INVALID_DATA, errors.New(fmt.Sprintf("Unknown data type %d", fieldType)))
}
diff --git a/lib/go/thrift/protocol_test.go b/lib/go/thrift/protocol_test.go
index 944055c0b..c1c67e8ca 100644
--- a/lib/go/thrift/protocol_test.go
+++ b/lib/go/thrift/protocol_test.go
@@ -222,22 +222,22 @@ func ReadWriteProtocolTest(t *testing.T, protocolFactory TProtocolFactory) {
func ReadWriteBool(t testing.TB, p TProtocol, trans TTransport) {
thetype := TType(BOOL)
thelen := len(BOOL_VALUES)
- err := p.WriteListBegin(thetype, thelen)
+ err := p.WriteListBegin(context.Background(), thetype, thelen)
if err != nil {
t.Errorf("%s: %T %T %q Error writing list begin: %q", "ReadWriteBool", p, trans, err, thetype)
}
for k, v := range BOOL_VALUES {
- err = p.WriteBool(v)
+ err = p.WriteBool(context.Background(), v)
if err != nil {
t.Errorf("%s: %T %T %v Error writing bool in list at index %v: %v", "ReadWriteBool", p, trans, err, k, v)
}
}
- p.WriteListEnd()
+ p.WriteListEnd(context.Background())
if err != nil {
t.Errorf("%s: %T %T %v Error writing list end: %v", "ReadWriteBool", p, trans, err, BOOL_VALUES)
}
p.Flush(context.Background())
- thetype2, thelen2, err := p.ReadListBegin()
+ thetype2, thelen2, err := p.ReadListBegin(context.Background())
if err != nil {
t.Errorf("%s: %T %T %v Error reading list: %v", "ReadWriteBool", p, trans, err, BOOL_VALUES)
}
@@ -251,7 +251,7 @@ func ReadWriteBool(t testing.TB, p TProtocol, trans TTransport) {
}
}
for k, v := range BOOL_VALUES {
- value, err := p.ReadBool()
+ value, err := p.ReadBool(context.Background())
if err != nil {
t.Errorf("%s: %T %T %v Error reading bool at index %v: %v", "ReadWriteBool", p, trans, err, k, v)
}
@@ -259,7 +259,7 @@ func ReadWriteBool(t testing.TB, p TProtocol, trans TTransport) {
t.Errorf("%s: index %v %v %v %v != %v", "ReadWriteBool", k, p, trans, v, value)
}
}
- err = p.ReadListEnd()
+ err = p.ReadListEnd(context.Background())
if err != nil {
t.Errorf("%s: %T %T Unable to read list end: %q", "ReadWriteBool", p, trans, err)
}
@@ -268,17 +268,17 @@ func ReadWriteBool(t testing.TB, p TProtocol, trans TTransport) {
func ReadWriteByte(t testing.TB, p TProtocol, trans TTransport) {
thetype := TType(BYTE)
thelen := len(BYTE_VALUES)
- err := p.WriteListBegin(thetype, thelen)
+ err := p.WriteListBegin(context.Background(), thetype, thelen)
if err != nil {
t.Errorf("%s: %T %T %q Error writing list begin: %q", "ReadWriteByte", p, trans, err, thetype)
}
for k, v := range BYTE_VALUES {
- err = p.WriteByte(v)
+ err = p.WriteByte(context.Background(), v)
if err != nil {
t.Errorf("%s: %T %T %q Error writing byte in list at index %d: %q", "ReadWriteByte", p, trans, err, k, v)
}
}
- err = p.WriteListEnd()
+ err = p.WriteListEnd(context.Background())
if err != nil {
t.Errorf("%s: %T %T %q Error writing list end: %q", "ReadWriteByte", p, trans, err, BYTE_VALUES)
}
@@ -286,7 +286,7 @@ func ReadWriteByte(t testing.TB, p TProtocol, trans TTransport) {
if err != nil {
t.Errorf("%s: %T %T %q Error flushing list of bytes: %q", "ReadWriteByte", p, trans, err, BYTE_VALUES)
}
- thetype2, thelen2, err := p.ReadListBegin()
+ thetype2, thelen2, err := p.ReadListBegin(context.Background())
if err != nil {
t.Errorf("%s: %T %T %q Error reading list: %q", "ReadWriteByte", p, trans, err, BYTE_VALUES)
}
@@ -300,7 +300,7 @@ func ReadWriteByte(t testing.TB, p TProtocol, trans TTransport) {
}
}
for k, v := range BYTE_VALUES {
- value, err := p.ReadByte()
+ value, err := p.ReadByte(context.Background())
if err != nil {
t.Errorf("%s: %T %T %q Error reading byte at index %d: %q", "ReadWriteByte", p, trans, err, k, v)
}
@@ -308,7 +308,7 @@ func ReadWriteByte(t testing.TB, p TProtocol, trans TTransport) {
t.Errorf("%s: %T %T %d != %d", "ReadWriteByte", p, trans, v, value)
}
}
- err = p.ReadListEnd()
+ err = p.ReadListEnd(context.Background())
if err != nil {
t.Errorf("%s: %T %T Unable to read list end: %q", "ReadWriteByte", p, trans, err)
}
@@ -317,13 +317,13 @@ func ReadWriteByte(t testing.TB, p TProtocol, trans TTransport) {
func ReadWriteI16(t testing.TB, p TProtocol, trans TTransport) {
thetype := TType(I16)
thelen := len(INT16_VALUES)
- p.WriteListBegin(thetype, thelen)
+ p.WriteListBegin(context.Background(), thetype, thelen)
for _, v := range INT16_VALUES {
- p.WriteI16(v)
+ p.WriteI16(context.Background(), v)
}
- p.WriteListEnd()
+ p.WriteListEnd(context.Background())
p.Flush(context.Background())
- thetype2, thelen2, err := p.ReadListBegin()
+ thetype2, thelen2, err := p.ReadListBegin(context.Background())
if err != nil {
t.Errorf("%s: %T %T %q Error reading list: %q", "ReadWriteI16", p, trans, err, INT16_VALUES)
}
@@ -337,7 +337,7 @@ func ReadWriteI16(t testing.TB, p TProtocol, trans TTransport) {
}
}
for k, v := range INT16_VALUES {
- value, err := p.ReadI16()
+ value, err := p.ReadI16(context.Background())
if err != nil {
t.Errorf("%s: %T %T %q Error reading int16 at index %d: %q", "ReadWriteI16", p, trans, err, k, v)
}
@@ -345,7 +345,7 @@ func ReadWriteI16(t testing.TB, p TProtocol, trans TTransport) {
t.Errorf("%s: %T %T %d != %d", "ReadWriteI16", p, trans, v, value)
}
}
- err = p.ReadListEnd()
+ err = p.ReadListEnd(context.Background())
if err != nil {
t.Errorf("%s: %T %T Unable to read list end: %q", "ReadWriteI16", p, trans, err)
}
@@ -354,13 +354,13 @@ func ReadWriteI16(t testing.TB, p TProtocol, trans TTransport) {
func ReadWriteI32(t testing.TB, p TProtocol, trans TTransport) {
thetype := TType(I32)
thelen := len(INT32_VALUES)
- p.WriteListBegin(thetype, thelen)
+ p.WriteListBegin(context.Background(), thetype, thelen)
for _, v := range INT32_VALUES {
- p.WriteI32(v)
+ p.WriteI32(context.Background(), v)
}
- p.WriteListEnd()
+ p.WriteListEnd(context.Background())
p.Flush(context.Background())
- thetype2, thelen2, err := p.ReadListBegin()
+ thetype2, thelen2, err := p.ReadListBegin(context.Background())
if err != nil {
t.Errorf("%s: %T %T %q Error reading list: %q", "ReadWriteI32", p, trans, err, INT32_VALUES)
}
@@ -374,7 +374,7 @@ func ReadWriteI32(t testing.TB, p TProtocol, trans TTransport) {
}
}
for k, v := range INT32_VALUES {
- value, err := p.ReadI32()
+ value, err := p.ReadI32(context.Background())
if err != nil {
t.Errorf("%s: %T %T %q Error reading int32 at index %d: %q", "ReadWriteI32", p, trans, err, k, v)
}
@@ -390,13 +390,13 @@ func ReadWriteI32(t testing.TB, p TProtocol, trans TTransport) {
func ReadWriteI64(t testing.TB, p TProtocol, trans TTransport) {
thetype := TType(I64)
thelen := len(INT64_VALUES)
- p.WriteListBegin(thetype, thelen)
+ p.WriteListBegin(context.Background(), thetype, thelen)
for _, v := range INT64_VALUES {
- p.WriteI64(v)
+ p.WriteI64(context.Background(), v)
}
- p.WriteListEnd()
+ p.WriteListEnd(context.Background())
p.Flush(context.Background())
- thetype2, thelen2, err := p.ReadListBegin()
+ thetype2, thelen2, err := p.ReadListBegin(context.Background())
if err != nil {
t.Errorf("%s: %T %T %q Error reading list: %q", "ReadWriteI64", p, trans, err, INT64_VALUES)
}
@@ -410,7 +410,7 @@ func ReadWriteI64(t testing.TB, p TProtocol, trans TTransport) {
}
}
for k, v := range INT64_VALUES {
- value, err := p.ReadI64()
+ value, err := p.ReadI64(context.Background())
if err != nil {
t.Errorf("%s: %T %T %q Error reading int64 at index %d: %q", "ReadWriteI64", p, trans, err, k, v)
}
@@ -426,13 +426,13 @@ func ReadWriteI64(t testing.TB, p TProtocol, trans TTransport) {
func ReadWriteDouble(t testing.TB, p TProtocol, trans TTransport) {
thetype := TType(DOUBLE)
thelen := len(DOUBLE_VALUES)
- p.WriteListBegin(thetype, thelen)
+ p.WriteListBegin(context.Background(), thetype, thelen)
for _, v := range DOUBLE_VALUES {
- p.WriteDouble(v)
+ p.WriteDouble(context.Background(), v)
}
- p.WriteListEnd()
+ p.WriteListEnd(context.Background())
p.Flush(context.Background())
- thetype2, thelen2, err := p.ReadListBegin()
+ thetype2, thelen2, err := p.ReadListBegin(context.Background())
if err != nil {
t.Errorf("%s: %T %T %v Error reading list: %v", "ReadWriteDouble", p, trans, err, DOUBLE_VALUES)
}
@@ -443,7 +443,7 @@ func ReadWriteDouble(t testing.TB, p TProtocol, trans TTransport) {
t.Errorf("%s: %T %T len %v != len %v", "ReadWriteDouble", p, trans, thelen, thelen2)
}
for k, v := range DOUBLE_VALUES {
- value, err := p.ReadDouble()
+ value, err := p.ReadDouble(context.Background())
if err != nil {
t.Errorf("%s: %T %T %q Error reading double at index %d: %v", "ReadWriteDouble", p, trans, err, k, v)
}
@@ -455,7 +455,7 @@ func ReadWriteDouble(t testing.TB, p TProtocol, trans TTransport) {
t.Errorf("%s: %T %T %v != %v", "ReadWriteDouble", p, trans, v, value)
}
}
- err = p.ReadListEnd()
+ err = p.ReadListEnd(context.Background())
if err != nil {
t.Errorf("%s: %T %T Unable to read list end: %q", "ReadWriteDouble", p, trans, err)
}
@@ -464,13 +464,13 @@ func ReadWriteDouble(t testing.TB, p TProtocol, trans TTransport) {
func ReadWriteString(t testing.TB, p TProtocol, trans TTransport) {
thetype := TType(STRING)
thelen := len(STRING_VALUES)
- p.WriteListBegin(thetype, thelen)
+ p.WriteListBegin(context.Background(), thetype, thelen)
for _, v := range STRING_VALUES {
- p.WriteString(v)
+ p.WriteString(context.Background(), v)
}
- p.WriteListEnd()
+ p.WriteListEnd(context.Background())
p.Flush(context.Background())
- thetype2, thelen2, err := p.ReadListBegin()
+ thetype2, thelen2, err := p.ReadListBegin(context.Background())
if err != nil {
t.Errorf("%s: %T %T %q Error reading list: %q", "ReadWriteString", p, trans, err, STRING_VALUES)
}
@@ -484,7 +484,7 @@ func ReadWriteString(t testing.TB, p TProtocol, trans TTransport) {
}
}
for k, v := range STRING_VALUES {
- value, err := p.ReadString()
+ value, err := p.ReadString(context.Background())
if err != nil {
t.Errorf("%s: %T %T %q Error reading string at index %d: %q", "ReadWriteString", p, trans, err, k, v)
}
@@ -499,9 +499,9 @@ func ReadWriteString(t testing.TB, p TProtocol, trans TTransport) {
func ReadWriteBinary(t testing.TB, p TProtocol, trans TTransport) {
v := protocol_bdata
- p.WriteBinary(v)
+ p.WriteBinary(context.Background(), v)
p.Flush(context.Background())
- value, err := p.ReadBinary()
+ value, err := p.ReadBinary(context.Background())
if err != nil {
t.Errorf("%s: %T %T Unable to read binary: %s", "ReadWriteBinary", p, trans, err.Error())
}
diff --git a/lib/go/thrift/serializer.go b/lib/go/thrift/serializer.go
index d85d2049d..b1b8061d6 100644
--- a/lib/go/thrift/serializer.go
+++ b/lib/go/thrift/serializer.go
@@ -30,8 +30,8 @@ type TSerializer struct {
}
type TStruct interface {
- Write(p TProtocol) error
- Read(p TProtocol) error
+ Write(ctx context.Context, p TProtocol) error
+ Read(ctx context.Context, p TProtocol) error
}
func NewTSerializer() *TSerializer {
@@ -46,7 +46,7 @@ func NewTSerializer() *TSerializer {
func (t *TSerializer) WriteString(ctx context.Context, msg TStruct) (s string, err error) {
t.Transport.Reset()
- if err = msg.Write(t.Protocol); err != nil {
+ if err = msg.Write(ctx, t.Protocol); err != nil {
return
}
@@ -63,7 +63,7 @@ func (t *TSerializer) WriteString(ctx context.Context, msg TStruct) (s string, e
func (t *TSerializer) Write(ctx context.Context, msg TStruct) (b []byte, err error) {
t.Transport.Reset()
- if err = msg.Write(t.Protocol); err != nil {
+ if err = msg.Write(ctx, t.Protocol); err != nil {
return
}
diff --git a/lib/go/thrift/serializer_test.go b/lib/go/thrift/serializer_test.go
index 52ebdca89..2e37ea2b1 100644
--- a/lib/go/thrift/serializer_test.go
+++ b/lib/go/thrift/serializer_test.go
@@ -80,7 +80,7 @@ type serializer interface {
}
type deserializer interface {
- ReadString(TStruct, string) error
+ ReadString(context.Context, TStruct, string) error
}
func plainSerializer(pf ProtocolFactory) serializer {
@@ -158,7 +158,7 @@ func ProtocolTest1(t *testing.T, pf ProtocolFactory) {
t1 := impl.Deserializer(pf)
var m1 MyTestStruct
- if err = t1.ReadString(&m1, s); err != nil {
+ if err = t1.ReadString(context.Background(), &m1, s); err != nil {
test.Fatalf("Unable to Deserialize struct: %v", err)
}
@@ -199,7 +199,7 @@ func ProtocolTest2(t *testing.T, pf ProtocolFactory) {
t1 := impl.Deserializer(pf)
var m1 MyTestStruct
- if err = t1.ReadString(&m1, s); err != nil {
+ if err = t1.ReadString(context.Background(), &m1, s); err != nil {
test.Fatalf("Unable to Deserialize struct: %v", err)
}
@@ -264,7 +264,7 @@ func TestSerializerPoolAsync(t *testing.T) {
t.Fatal("serialize:", err)
}
var m1 MyTestStruct
- if err = d.ReadString(&m1, str); err != nil {
+ if err = d.ReadString(context.Background(), &m1, str); err != nil {
t.Fatal("deserialize:", err)
}
@@ -335,7 +335,7 @@ func BenchmarkSerializer(b *testing.B) {
str, _ := s.WriteString(context.Background(), &m)
var m1 MyTestStruct
d := c.Deserializer()
- d.ReadString(&m1, str)
+ d.ReadString(context.Background(), &m1, str)
}
},
)
diff --git a/lib/go/thrift/serializer_types_test.go b/lib/go/thrift/serializer_types_test.go
index e5472bbff..4d1e992ae 100644
--- a/lib/go/thrift/serializer_types_test.go
+++ b/lib/go/thrift/serializer_types_test.go
@@ -48,6 +48,7 @@ struct MyTestStruct {
*/
import (
+ "context"
"fmt"
)
@@ -162,12 +163,12 @@ func (p *MyTestStruct) GetStringSet() map[string]struct{} {
func (p *MyTestStruct) GetE() MyTestEnum {
return p.E
}
-func (p *MyTestStruct) Read(iprot TProtocol) error {
- if _, err := iprot.ReadStructBegin(); err != nil {
+func (p *MyTestStruct) Read(ctx context.Context, iprot TProtocol) error {
+ if _, err := iprot.ReadStructBegin(ctx); err != nil {
return PrependError(fmt.Sprintf("%T read error: ", p), err)
}
for {
- _, fieldTypeId, fieldId, err := iprot.ReadFieldBegin()
+ _, fieldTypeId, fieldId, err := iprot.ReadFieldBegin(ctx)
if err != nil {
return PrependError(fmt.Sprintf("%T field %d read error: ", p, fieldId), err)
}
@@ -176,70 +177,70 @@ func (p *MyTestStruct) Read(iprot TProtocol) error {
}
switch fieldId {
case 1:
- if err := p.readField1(iprot); err != nil {
+ if err := p.readField1(ctx, iprot); err != nil {
return err
}
case 2:
- if err := p.readField2(iprot); err != nil {
+ if err := p.readField2(ctx, iprot); err != nil {
return err
}
case 3:
- if err := p.readField3(iprot); err != nil {
+ if err := p.readField3(ctx, iprot); err != nil {
return err
}
case 4:
- if err := p.readField4(iprot); err != nil {
+ if err := p.readField4(ctx, iprot); err != nil {
return err
}
case 5:
- if err := p.readField5(iprot); err != nil {
+ if err := p.readField5(ctx, iprot); err != nil {
return err
}
case 6:
- if err := p.readField6(iprot); err != nil {
+ if err := p.readField6(ctx, iprot); err != nil {
return err
}
case 7:
- if err := p.readField7(iprot); err != nil {
+ if err := p.readField7(ctx, iprot); err != nil {
return err
}
case 8:
- if err := p.readField8(iprot); err != nil {
+ if err := p.readField8(ctx, iprot); err != nil {
return err
}
case 9:
- if err := p.readField9(iprot); err != nil {
+ if err := p.readField9(ctx, iprot); err != nil {
return err
}
case 10:
- if err := p.readField10(iprot); err != nil {
+ if err := p.readField10(ctx, iprot); err != nil {
return err
}
case 11:
- if err := p.readField11(iprot); err != nil {
+ if err := p.readField11(ctx, iprot); err != nil {
return err
}
case 12:
- if err := p.readField12(iprot); err != nil {
+ if err := p.readField12(ctx, iprot); err != nil {
return err
}
default:
- if err := iprot.Skip(fieldTypeId); err != nil {
+ if err := iprot.Skip(ctx, fieldTypeId); err != nil {
return err
}
}
- if err := iprot.ReadFieldEnd(); err != nil {
+ if err := iprot.ReadFieldEnd(ctx); err != nil {
return err
}
}
- if err := iprot.ReadStructEnd(); err != nil {
+ if err := iprot.ReadStructEnd(ctx); err != nil {
return PrependError(fmt.Sprintf("%T read struct end error: ", p), err)
}
return nil
}
-func (p *MyTestStruct) readField1(iprot TProtocol) error {
- if v, err := iprot.ReadBool(); err != nil {
+func (p *MyTestStruct) readField1(ctx context.Context, iprot TProtocol) error {
+ if v, err := iprot.ReadBool(ctx); err != nil {
return PrependError("error reading field 1: ", err)
} else {
p.On = v
@@ -247,8 +248,8 @@ func (p *MyTestStruct) readField1(iprot TProtocol) error {
return nil
}
-func (p *MyTestStruct) readField2(iprot TProtocol) error {
- if v, err := iprot.ReadByte(); err != nil {
+func (p *MyTestStruct) readField2(ctx context.Context, iprot TProtocol) error {
+ if v, err := iprot.ReadByte(ctx); err != nil {
return PrependError("error reading field 2: ", err)
} else {
temp := int8(v)
@@ -257,8 +258,8 @@ func (p *MyTestStruct) readField2(iprot TProtocol) error {
return nil
}
-func (p *MyTestStruct) readField3(iprot TProtocol) error {
- if v, err := iprot.ReadI16(); err != nil {
+func (p *MyTestStruct) readField3(ctx context.Context, iprot TProtocol) error {
+ if v, err := iprot.ReadI16(ctx); err != nil {
return PrependError("error reading field 3: ", err)
} else {
p.Int16 = v
@@ -266,8 +267,8 @@ func (p *MyTestStruct) readField3(iprot TProtocol) error {
return nil
}
-func (p *MyTestStruct) readField4(iprot TProtocol) error {
- if v, err := iprot.ReadI32(); err != nil {
+func (p *MyTestStruct) readField4(ctx context.Context, iprot TProtocol) error {
+ if v, err := iprot.ReadI32(ctx); err != nil {
return PrependError("error reading field 4: ", err)
} else {
p.Int32 = v
@@ -275,8 +276,8 @@ func (p *MyTestStruct) readField4(iprot TProtocol) error {
return nil
}
-func (p *MyTestStruct) readField5(iprot TProtocol) error {
- if v, err := iprot.ReadI64(); err != nil {
+func (p *MyTestStruct) readField5(ctx context.Context, iprot TProtocol) error {
+ if v, err := iprot.ReadI64(ctx); err != nil {
return PrependError("error reading field 5: ", err)
} else {
p.Int64 = v
@@ -284,8 +285,8 @@ func (p *MyTestStruct) readField5(iprot TProtocol) error {
return nil
}
-func (p *MyTestStruct) readField6(iprot TProtocol) error {
- if v, err := iprot.ReadDouble(); err != nil {
+func (p *MyTestStruct) readField6(ctx context.Context, iprot TProtocol) error {
+ if v, err := iprot.ReadDouble(ctx); err != nil {
return PrependError("error reading field 6: ", err)
} else {
p.D = v
@@ -293,8 +294,8 @@ func (p *MyTestStruct) readField6(iprot TProtocol) error {
return nil
}
-func (p *MyTestStruct) readField7(iprot TProtocol) error {
- if v, err := iprot.ReadString(); err != nil {
+func (p *MyTestStruct) readField7(ctx context.Context, iprot TProtocol) error {
+ if v, err := iprot.ReadString(ctx); err != nil {
return PrependError("error reading field 7: ", err)
} else {
p.St = v
@@ -302,8 +303,8 @@ func (p *MyTestStruct) readField7(iprot TProtocol) error {
return nil
}
-func (p *MyTestStruct) readField8(iprot TProtocol) error {
- if v, err := iprot.ReadBinary(); err != nil {
+func (p *MyTestStruct) readField8(ctx context.Context, iprot TProtocol) error {
+ if v, err := iprot.ReadBinary(ctx); err != nil {
return PrependError("error reading field 8: ", err)
} else {
p.Bin = v
@@ -311,8 +312,8 @@ func (p *MyTestStruct) readField8(iprot TProtocol) error {
return nil
}
-func (p *MyTestStruct) readField9(iprot TProtocol) error {
- _, _, size, err := iprot.ReadMapBegin()
+func (p *MyTestStruct) readField9(ctx context.Context, iprot TProtocol) error {
+ _, _, size, err := iprot.ReadMapBegin(ctx)
if err != nil {
return PrependError("error reading map begin: ", err)
}
@@ -320,27 +321,27 @@ func (p *MyTestStruct) readField9(iprot TProtocol) error {
p.StringMap = tMap
for i := 0; i < size; i++ {
var _key0 string
- if v, err := iprot.ReadString(); err != nil {
+ if v, err := iprot.ReadString(ctx); err != nil {
return PrependError("error reading field 0: ", err)
} else {
_key0 = v
}
var _val1 string
- if v, err := iprot.ReadString(); err != nil {
+ if v, err := iprot.ReadString(ctx); err != nil {
return PrependError("error reading field 0: ", err)
} else {
_val1 = v
}
p.StringMap[_key0] = _val1
}
- if err := iprot.ReadMapEnd(); err != nil {
+ if err := iprot.ReadMapEnd(ctx); err != nil {
return PrependError("error reading map end: ", err)
}
return nil
}
-func (p *MyTestStruct) readField10(iprot TProtocol) error {
- _, size, err := iprot.ReadListBegin()
+func (p *MyTestStruct) readField10(ctx context.Context, iprot TProtocol) error {
+ _, size, err := iprot.ReadListBegin(ctx)
if err != nil {
return PrependError("error reading list begin: ", err)
}
@@ -348,21 +349,21 @@ func (p *MyTestStruct) readField10(iprot TProtocol) error {
p.StringList = tSlice
for i := 0; i < size; i++ {
var _elem2 string
- if v, err := iprot.ReadString(); err != nil {
+ if v, err := iprot.ReadString(ctx); err != nil {
return PrependError("error reading field 0: ", err)
} else {
_elem2 = v
}
p.StringList = append(p.StringList, _elem2)
}
- if err := iprot.ReadListEnd(); err != nil {
+ if err := iprot.ReadListEnd(ctx); err != nil {
return PrependError("error reading list end: ", err)
}
return nil
}
-func (p *MyTestStruct) readField11(iprot TProtocol) error {
- _, size, err := iprot.ReadSetBegin()
+func (p *MyTestStruct) readField11(ctx context.Context, iprot TProtocol) error {
+ _, size, err := iprot.ReadSetBegin(ctx)
if err != nil {
return PrependError("error reading set begin: ", err)
}
@@ -370,21 +371,21 @@ func (p *MyTestStruct) readField11(iprot TProtocol) error {
p.StringSet = tSet
for i := 0; i < size; i++ {
var _elem3 string
- if v, err := iprot.ReadString(); err != nil {
+ if v, err := iprot.ReadString(ctx); err != nil {
return PrependError("error reading field 0: ", err)
} else {
_elem3 = v
}
p.StringSet[_elem3] = struct{}{}
}
- if err := iprot.ReadSetEnd(); err != nil {
+ if err := iprot.ReadSetEnd(ctx); err != nil {
return PrependError("error reading set end: ", err)
}
return nil
}
-func (p *MyTestStruct) readField12(iprot TProtocol) error {
- if v, err := iprot.ReadI32(); err != nil {
+func (p *MyTestStruct) readField12(ctx context.Context, iprot TProtocol) error {
+ if v, err := iprot.ReadI32(ctx); err != nil {
return PrependError("error reading field 12: ", err)
} else {
temp := MyTestEnum(v)
@@ -393,233 +394,233 @@ func (p *MyTestStruct) readField12(iprot TProtocol) error {
return nil
}
-func (p *MyTestStruct) Write(oprot TProtocol) error {
- if err := oprot.WriteStructBegin("MyTestStruct"); err != nil {
+func (p *MyTestStruct) Write(ctx context.Context, oprot TProtocol) error {
+ if err := oprot.WriteStructBegin(ctx, "MyTestStruct"); err != nil {
return PrependError(fmt.Sprintf("%T write struct begin error: ", p), err)
}
- if err := p.writeField1(oprot); err != nil {
+ if err := p.writeField1(ctx, oprot); err != nil {
return err
}
- if err := p.writeField2(oprot); err != nil {
+ if err := p.writeField2(ctx, oprot); err != nil {
return err
}
- if err := p.writeField3(oprot); err != nil {
+ if err := p.writeField3(ctx, oprot); err != nil {
return err
}
- if err := p.writeField4(oprot); err != nil {
+ if err := p.writeField4(ctx, oprot); err != nil {
return err
}
- if err := p.writeField5(oprot); err != nil {
+ if err := p.writeField5(ctx, oprot); err != nil {
return err
}
- if err := p.writeField6(oprot); err != nil {
+ if err := p.writeField6(ctx, oprot); err != nil {
return err
}
- if err := p.writeField7(oprot); err != nil {
+ if err := p.writeField7(ctx, oprot); err != nil {
return err
}
- if err := p.writeField8(oprot); err != nil {
+ if err := p.writeField8(ctx, oprot); err != nil {
return err
}
- if err := p.writeField9(oprot); err != nil {
+ if err := p.writeField9(ctx, oprot); err != nil {
return err
}
- if err := p.writeField10(oprot); err != nil {
+ if err := p.writeField10(ctx, oprot); err != nil {
return err
}
- if err := p.writeField11(oprot); err != nil {
+ if err := p.writeField11(ctx, oprot); err != nil {
return err
}
- if err := p.writeField12(oprot); err != nil {
+ if err := p.writeField12(ctx, oprot); err != nil {
return err
}
- if err := oprot.WriteFieldStop(); err != nil {
+ if err := oprot.WriteFieldStop(ctx); err != nil {
return PrependError("write field stop error: ", err)
}
- if err := oprot.WriteStructEnd(); err != nil {
+ if err := oprot.WriteStructEnd(ctx); err != nil {
return PrependError("write struct stop error: ", err)
}
return nil
}
-func (p *MyTestStruct) writeField1(oprot TProtocol) (err error) {
- if err := oprot.WriteFieldBegin("on", BOOL, 1); err != nil {
+func (p *MyTestStruct) writeField1(ctx context.Context, oprot TProtocol) (err error) {
+ if err := oprot.WriteFieldBegin(ctx, "on", BOOL, 1); err != nil {
return PrependError(fmt.Sprintf("%T write field begin error 1:on: ", p), err)
}
- if err := oprot.WriteBool(bool(p.On)); err != nil {
+ if err := oprot.WriteBool(ctx, bool(p.On)); err != nil {
return PrependError(fmt.Sprintf("%T.on (1) field write error: ", p), err)
}
- if err := oprot.WriteFieldEnd(); err != nil {
+ if err := oprot.WriteFieldEnd(ctx); err != nil {
return PrependError(fmt.Sprintf("%T write field end error 1:on: ", p), err)
}
return err
}
-func (p *MyTestStruct) writeField2(oprot TProtocol) (err error) {
- if err := oprot.WriteFieldBegin("b", BYTE, 2); err != nil {
+func (p *MyTestStruct) writeField2(ctx context.Context, oprot TProtocol) (err error) {
+ if err := oprot.WriteFieldBegin(ctx, "b", BYTE, 2); err != nil {
return PrependError(fmt.Sprintf("%T write field begin error 2:b: ", p), err)
}
- if err := oprot.WriteByte(int8(p.B)); err != nil {
+ if err := oprot.WriteByte(ctx, int8(p.B)); err != nil {
return PrependError(fmt.Sprintf("%T.b (2) field write error: ", p), err)
}
- if err := oprot.WriteFieldEnd(); err != nil {
+ if err := oprot.WriteFieldEnd(ctx); err != nil {
return PrependError(fmt.Sprintf("%T write field end error 2:b: ", p), err)
}
return err
}
-func (p *MyTestStruct) writeField3(oprot TProtocol) (err error) {
- if err := oprot.WriteFieldBegin("int16", I16, 3); err != nil {
+func (p *MyTestStruct) writeField3(ctx context.Context, oprot TProtocol) (err error) {
+ if err := oprot.WriteFieldBegin(ctx, "int16", I16, 3); err != nil {
return PrependError(fmt.Sprintf("%T write field begin error 3:int16: ", p), err)
}
- if err := oprot.WriteI16(int16(p.Int16)); err != nil {
+ if err := oprot.WriteI16(ctx, int16(p.Int16)); err != nil {
return PrependError(fmt.Sprintf("%T.int16 (3) field write error: ", p), err)
}
- if err := oprot.WriteFieldEnd(); err != nil {
+ if err := oprot.WriteFieldEnd(ctx); err != nil {
return PrependError(fmt.Sprintf("%T write field end error 3:int16: ", p), err)
}
return err
}
-func (p *MyTestStruct) writeField4(oprot TProtocol) (err error) {
- if err := oprot.WriteFieldBegin("int32", I32, 4); err != nil {
+func (p *MyTestStruct) writeField4(ctx context.Context, oprot TProtocol) (err error) {
+ if err := oprot.WriteFieldBegin(ctx, "int32", I32, 4); err != nil {
return PrependError(fmt.Sprintf("%T write field begin error 4:int32: ", p), err)
}
- if err := oprot.WriteI32(int32(p.Int32)); err != nil {
+ if err := oprot.WriteI32(ctx, int32(p.Int32)); err != nil {
return PrependError(fmt.Sprintf("%T.int32 (4) field write error: ", p), err)
}
- if err := oprot.WriteFieldEnd(); err != nil {
+ if err := oprot.WriteFieldEnd(ctx); err != nil {
return PrependError(fmt.Sprintf("%T write field end error 4:int32: ", p), err)
}
return err
}
-func (p *MyTestStruct) writeField5(oprot TProtocol) (err error) {
- if err := oprot.WriteFieldBegin("int64", I64, 5); err != nil {
+func (p *MyTestStruct) writeField5(ctx context.Context, oprot TProtocol) (err error) {
+ if err := oprot.WriteFieldBegin(ctx, "int64", I64, 5); err != nil {
return PrependError(fmt.Sprintf("%T write field begin error 5:int64: ", p), err)
}
- if err := oprot.WriteI64(int64(p.Int64)); err != nil {
+ if err := oprot.WriteI64(ctx, int64(p.Int64)); err != nil {
return PrependError(fmt.Sprintf("%T.int64 (5) field write error: ", p), err)
}
- if err := oprot.WriteFieldEnd(); err != nil {
+ if err := oprot.WriteFieldEnd(ctx); err != nil {
return PrependError(fmt.Sprintf("%T write field end error 5:int64: ", p), err)
}
return err
}
-func (p *MyTestStruct) writeField6(oprot TProtocol) (err error) {
- if err := oprot.WriteFieldBegin("d", DOUBLE, 6); err != nil {
+func (p *MyTestStruct) writeField6(ctx context.Context, oprot TProtocol) (err error) {
+ if err := oprot.WriteFieldBegin(ctx, "d", DOUBLE, 6); err != nil {
return PrependError(fmt.Sprintf("%T write field begin error 6:d: ", p), err)
}
- if err := oprot.WriteDouble(float64(p.D)); err != nil {
+ if err := oprot.WriteDouble(ctx, float64(p.D)); err != nil {
return PrependError(fmt.Sprintf("%T.d (6) field write error: ", p), err)
}
- if err := oprot.WriteFieldEnd(); err != nil {
+ if err := oprot.WriteFieldEnd(ctx); err != nil {
return PrependError(fmt.Sprintf("%T write field end error 6:d: ", p), err)
}
return err
}
-func (p *MyTestStruct) writeField7(oprot TProtocol) (err error) {
- if err := oprot.WriteFieldBegin("st", STRING, 7); err != nil {
+func (p *MyTestStruct) writeField7(ctx context.Context, oprot TProtocol) (err error) {
+ if err := oprot.WriteFieldBegin(ctx, "st", STRING, 7); err != nil {
return PrependError(fmt.Sprintf("%T write field begin error 7:st: ", p), err)
}
- if err := oprot.WriteString(string(p.St)); err != nil {
+ if err := oprot.WriteString(ctx, string(p.St)); err != nil {
return PrependError(fmt.Sprintf("%T.st (7) field write error: ", p), err)
}
- if err := oprot.WriteFieldEnd(); err != nil {
+ if err := oprot.WriteFieldEnd(ctx); err != nil {
return PrependError(fmt.Sprintf("%T write field end error 7:st: ", p), err)
}
return err
}
-func (p *MyTestStruct) writeField8(oprot TProtocol) (err error) {
- if err := oprot.WriteFieldBegin("bin", STRING, 8); err != nil {
+func (p *MyTestStruct) writeField8(ctx context.Context, oprot TProtocol) (err error) {
+ if err := oprot.WriteFieldBegin(ctx, "bin", STRING, 8); err != nil {
return PrependError(fmt.Sprintf("%T write field begin error 8:bin: ", p), err)
}
- if err := oprot.WriteBinary(p.Bin); err != nil {
+ if err := oprot.WriteBinary(ctx, p.Bin); err != nil {
return PrependError(fmt.Sprintf("%T.bin (8) field write error: ", p), err)
}
- if err := oprot.WriteFieldEnd(); err != nil {
+ if err := oprot.WriteFieldEnd(ctx); err != nil {
return PrependError(fmt.Sprintf("%T write field end error 8:bin: ", p), err)
}
return err
}
-func (p *MyTestStruct) writeField9(oprot TProtocol) (err error) {
- if err := oprot.WriteFieldBegin("stringMap", MAP, 9); err != nil {
+func (p *MyTestStruct) writeField9(ctx context.Context, oprot TProtocol) (err error) {
+ if err := oprot.WriteFieldBegin(ctx, "stringMap", MAP, 9); err != nil {
return PrependError(fmt.Sprintf("%T write field begin error 9:stringMap: ", p), err)
}
- if err := oprot.WriteMapBegin(STRING, STRING, len(p.StringMap)); err != nil {
+ if err := oprot.WriteMapBegin(ctx, STRING, STRING, len(p.StringMap)); err != nil {
return PrependError("error writing map begin: ", err)
}
for k, v := range p.StringMap {
- if err := oprot.WriteString(string(k)); err != nil {
+ if err := oprot.WriteString(ctx, string(k)); err != nil {
return PrependError(fmt.Sprintf("%T. (0) field write error: ", p), err)
}
- if err := oprot.WriteString(string(v)); err != nil {
+ if err := oprot.WriteString(ctx, string(v)); err != nil {
return PrependError(fmt.Sprintf("%T. (0) field write error: ", p), err)
}
}
- if err := oprot.WriteMapEnd(); err != nil {
+ if err := oprot.WriteMapEnd(ctx); err != nil {
return PrependError("error writing map end: ", err)
}
- if err := oprot.WriteFieldEnd(); err != nil {
+ if err := oprot.WriteFieldEnd(ctx); err != nil {
return PrependError(fmt.Sprintf("%T write field end error 9:stringMap: ", p), err)
}
return err
}
-func (p *MyTestStruct) writeField10(oprot TProtocol) (err error) {
- if err := oprot.WriteFieldBegin("stringList", LIST, 10); err != nil {
+func (p *MyTestStruct) writeField10(ctx context.Context, oprot TProtocol) (err error) {
+ if err := oprot.WriteFieldBegin(ctx, "stringList", LIST, 10); err != nil {
return PrependError(fmt.Sprintf("%T write field begin error 10:stringList: ", p), err)
}
- if err := oprot.WriteListBegin(STRING, len(p.StringList)); err != nil {
+ if err := oprot.WriteListBegin(ctx, STRING, len(p.StringList)); err != nil {
return PrependError("error writing list begin: ", err)
}
for _, v := range p.StringList {
- if err := oprot.WriteString(string(v)); err != nil {
+ if err := oprot.WriteString(ctx, string(v)); err != nil {
return PrependError(fmt.Sprintf("%T. (0) field write error: ", p), err)
}
}
- if err := oprot.WriteListEnd(); err != nil {
+ if err := oprot.WriteListEnd(ctx); err != nil {
return PrependError("error writing list end: ", err)
}
- if err := oprot.WriteFieldEnd(); err != nil {
+ if err := oprot.WriteFieldEnd(ctx); err != nil {
return PrependError(fmt.Sprintf("%T write field end error 10:stringList: ", p), err)
}
return err
}
-func (p *MyTestStruct) writeField11(oprot TProtocol) (err error) {
- if err := oprot.WriteFieldBegin("stringSet", SET, 11); err != nil {
+func (p *MyTestStruct) writeField11(ctx context.Context, oprot TProtocol) (err error) {
+ if err := oprot.WriteFieldBegin(ctx, "stringSet", SET, 11); err != nil {
return PrependError(fmt.Sprintf("%T write field begin error 11:stringSet: ", p), err)
}
- if err := oprot.WriteSetBegin(STRING, len(p.StringSet)); err != nil {
+ if err := oprot.WriteSetBegin(ctx, STRING, len(p.StringSet)); err != nil {
return PrependError("error writing set begin: ", err)
}
for v := range p.StringSet {
- if err := oprot.WriteString(string(v)); err != nil {
+ if err := oprot.WriteString(ctx, string(v)); err != nil {
return PrependError(fmt.Sprintf("%T. (0) field write error: ", p), err)
}
}
- if err := oprot.WriteSetEnd(); err != nil {
+ if err := oprot.WriteSetEnd(ctx); err != nil {
return PrependError("error writing set end: ", err)
}
- if err := oprot.WriteFieldEnd(); err != nil {
+ if err := oprot.WriteFieldEnd(ctx); err != nil {
return PrependError(fmt.Sprintf("%T write field end error 11:stringSet: ", p), err)
}
return err
}
-func (p *MyTestStruct) writeField12(oprot TProtocol) (err error) {
- if err := oprot.WriteFieldBegin("e", I32, 12); err != nil {
+func (p *MyTestStruct) writeField12(ctx context.Context, oprot TProtocol) (err error) {
+ if err := oprot.WriteFieldBegin(ctx, "e", I32, 12); err != nil {
return PrependError(fmt.Sprintf("%T write field begin error 12:e: ", p), err)
}
- if err := oprot.WriteI32(int32(p.E)); err != nil {
+ if err := oprot.WriteI32(ctx, int32(p.E)); err != nil {
return PrependError(fmt.Sprintf("%T.e (12) field write error: ", p), err)
}
- if err := oprot.WriteFieldEnd(); err != nil {
+ if err := oprot.WriteFieldEnd(ctx); err != nil {
return PrependError(fmt.Sprintf("%T write field end error 12:e: ", p), err)
}
return err
diff --git a/lib/go/thrift/simple_json_protocol.go b/lib/go/thrift/simple_json_protocol.go
index f5e0c05d1..d101b993c 100644
--- a/lib/go/thrift/simple_json_protocol.go
+++ b/lib/go/thrift/simple_json_protocol.go
@@ -156,114 +156,113 @@ func mismatch(expected, actual string) error {
return fmt.Errorf("Expected '%s' but found '%s' while parsing JSON.", expected, actual)
}
-func (p *TSimpleJSONProtocol) WriteMessageBegin(name string, typeId TMessageType, seqId int32) error {
+func (p *TSimpleJSONProtocol) WriteMessageBegin(ctx context.Context, name string, typeId TMessageType, seqId int32) error {
p.resetContextStack() // THRIFT-3735
if e := p.OutputListBegin(); e != nil {
return e
}
- if e := p.WriteString(name); e != nil {
+ if e := p.WriteString(ctx, name); e != nil {
return e
}
- if e := p.WriteByte(int8(typeId)); e != nil {
+ if e := p.WriteByte(ctx, int8(typeId)); e != nil {
return e
}
- if e := p.WriteI32(seqId); e != nil {
+ if e := p.WriteI32(ctx, seqId); e != nil {
return e
}
return nil
}
-func (p *TSimpleJSONProtocol) WriteMessageEnd() error {
+func (p *TSimpleJSONProtocol) WriteMessageEnd(ctx context.Context) error {
return p.OutputListEnd()
}
-func (p *TSimpleJSONProtocol) WriteStructBegin(name string) error {
+func (p *TSimpleJSONProtocol) WriteStructBegin(ctx context.Context, name string) error {
if e := p.OutputObjectBegin(); e != nil {
return e
}
return nil
}
-func (p *TSimpleJSONProtocol) WriteStructEnd() error {
+func (p *TSimpleJSONProtocol) WriteStructEnd(ctx context.Context) error {
return p.OutputObjectEnd()
}
-func (p *TSimpleJSONProtocol) WriteFieldBegin(name string, typeId TType, id int16) error {
- if e := p.WriteString(name); e != nil {
+func (p *TSimpleJSONProtocol) WriteFieldBegin(ctx context.Context, name string, typeId TType, id int16) error {
+ if e := p.WriteString(ctx, name); e != nil {
return e
}
return nil
}
-func (p *TSimpleJSONProtocol) WriteFieldEnd() error {
- //return p.OutputListEnd()
+func (p *TSimpleJSONProtocol) WriteFieldEnd(ctx context.Context) error {
return nil
}
-func (p *TSimpleJSONProtocol) WriteFieldStop() error { return nil }
+func (p *TSimpleJSONProtocol) WriteFieldStop(ctx context.Context) error { return nil }
-func (p *TSimpleJSONProtocol) WriteMapBegin(keyType TType, valueType TType, size int) error {
+func (p *TSimpleJSONProtocol) WriteMapBegin(ctx context.Context, keyType TType, valueType TType, size int) error {
if e := p.OutputListBegin(); e != nil {
return e
}
- if e := p.WriteByte(int8(keyType)); e != nil {
+ if e := p.WriteByte(ctx, int8(keyType)); e != nil {
return e
}
- if e := p.WriteByte(int8(valueType)); e != nil {
+ if e := p.WriteByte(ctx, int8(valueType)); e != nil {
return e
}
- return p.WriteI32(int32(size))
+ return p.WriteI32(ctx, int32(size))
}
-func (p *TSimpleJSONProtocol) WriteMapEnd() error {
+func (p *TSimpleJSONProtocol) WriteMapEnd(ctx context.Context) error {
return p.OutputListEnd()
}
-func (p *TSimpleJSONProtocol) WriteListBegin(elemType TType, size int) error {
+func (p *TSimpleJSONProtocol) WriteListBegin(ctx context.Context, elemType TType, size int) error {
return p.OutputElemListBegin(elemType, size)
}
-func (p *TSimpleJSONProtocol) WriteListEnd() error {
+func (p *TSimpleJSONProtocol) WriteListEnd(ctx context.Context) error {
return p.OutputListEnd()
}
-func (p *TSimpleJSONProtocol) WriteSetBegin(elemType TType, size int) error {
+func (p *TSimpleJSONProtocol) WriteSetBegin(ctx context.Context, elemType TType, size int) error {
return p.OutputElemListBegin(elemType, size)
}
-func (p *TSimpleJSONProtocol) WriteSetEnd() error {
+func (p *TSimpleJSONProtocol) WriteSetEnd(ctx context.Context) error {
return p.OutputListEnd()
}
-func (p *TSimpleJSONProtocol) WriteBool(b bool) error {
+func (p *TSimpleJSONProtocol) WriteBool(ctx context.Context, b bool) error {
return p.OutputBool(b)
}
-func (p *TSimpleJSONProtocol) WriteByte(b int8) error {
- return p.WriteI32(int32(b))
+func (p *TSimpleJSONProtocol) WriteByte(ctx context.Context, b int8) error {
+ return p.WriteI32(ctx, int32(b))
}
-func (p *TSimpleJSONProtocol) WriteI16(v int16) error {
- return p.WriteI32(int32(v))
+func (p *TSimpleJSONProtocol) WriteI16(ctx context.Context, v int16) error {
+ return p.WriteI32(ctx, int32(v))
}
-func (p *TSimpleJSONProtocol) WriteI32(v int32) error {
+func (p *TSimpleJSONProtocol) WriteI32(ctx context.Context, v int32) error {
return p.OutputI64(int64(v))
}
-func (p *TSimpleJSONProtocol) WriteI64(v int64) error {
+func (p *TSimpleJSONProtocol) WriteI64(ctx context.Context, v int64) error {
return p.OutputI64(int64(v))
}
-func (p *TSimpleJSONProtocol) WriteDouble(v float64) error {
+func (p *TSimpleJSONProtocol) WriteDouble(ctx context.Context, v float64) error {
return p.OutputF64(v)
}
-func (p *TSimpleJSONProtocol) WriteString(v string) error {
+func (p *TSimpleJSONProtocol) WriteString(ctx context.Context, v string) error {
return p.OutputString(v)
}
-func (p *TSimpleJSONProtocol) WriteBinary(v []byte) error {
+func (p *TSimpleJSONProtocol) WriteBinary(ctx context.Context, v []byte) error {
// JSON library only takes in a string,
// not an arbitrary byte array, to ensure bytes are transmitted
// efficiently we must convert this into a valid JSON string
@@ -289,39 +288,39 @@ func (p *TSimpleJSONProtocol) WriteBinary(v []byte) error {
}
// Reading methods.
-func (p *TSimpleJSONProtocol) ReadMessageBegin() (name string, typeId TMessageType, seqId int32, err error) {
+func (p *TSimpleJSONProtocol) ReadMessageBegin(ctx context.Context) (name string, typeId TMessageType, seqId int32, err error) {
p.resetContextStack() // THRIFT-3735
if isNull, err := p.ParseListBegin(); isNull || err != nil {
return name, typeId, seqId, err
}
- if name, err = p.ReadString(); err != nil {
+ if name, err = p.ReadString(ctx); err != nil {
return name, typeId, seqId, err
}
- bTypeId, err := p.ReadByte()
+ bTypeId, err := p.ReadByte(ctx)
typeId = TMessageType(bTypeId)
if err != nil {
return name, typeId, seqId, err
}
- if seqId, err = p.ReadI32(); err != nil {
+ if seqId, err = p.ReadI32(ctx); err != nil {
return name, typeId, seqId, err
}
return name, typeId, seqId, nil
}
-func (p *TSimpleJSONProtocol) ReadMessageEnd() error {
+func (p *TSimpleJSONProtocol) ReadMessageEnd(ctx context.Context) error {
return p.ParseListEnd()
}
-func (p *TSimpleJSONProtocol) ReadStructBegin() (name string, err error) {
+func (p *TSimpleJSONProtocol) ReadStructBegin(ctx context.Context) (name string, err error) {
_, err = p.ParseObjectStart()
return "", err
}
-func (p *TSimpleJSONProtocol) ReadStructEnd() error {
+func (p *TSimpleJSONProtocol) ReadStructEnd(ctx context.Context) error {
return p.ParseObjectEnd()
}
-func (p *TSimpleJSONProtocol) ReadFieldBegin() (string, TType, int16, error) {
+func (p *TSimpleJSONProtocol) ReadFieldBegin(ctx context.Context) (string, TType, int16, error) {
if err := p.ParsePreValue(); err != nil {
return "", STOP, 0, err
}
@@ -340,21 +339,6 @@ func (p *TSimpleJSONProtocol) ReadFieldBegin() (string, TType, int16, error) {
return name, STOP, 0, err
}
return name, STOP, -1, p.ParsePostValue()
- /*
- if err = p.ParsePostValue(); err != nil {
- return name, STOP, 0, err
- }
- if isNull, err := p.ParseListBegin(); isNull || err != nil {
- return name, STOP, 0, err
- }
- bType, err := p.ReadByte()
- thetype := TType(bType)
- if err != nil {
- return name, thetype, 0, err
- }
- id, err := p.ReadI16()
- return name, thetype, id, err
- */
}
e := fmt.Errorf("Expected \"}\" or '\"', but found: '%s'", string(b))
return "", STOP, 0, NewTProtocolExceptionWithType(INVALID_DATA, e)
@@ -362,57 +346,56 @@ func (p *TSimpleJSONProtocol) ReadFieldBegin() (string, TType, int16, error) {
return "", STOP, 0, NewTProtocolException(io.EOF)
}
-func (p *TSimpleJSONProtocol) ReadFieldEnd() error {
+func (p *TSimpleJSONProtocol) ReadFieldEnd(ctx context.Context) error {
return nil
- //return p.ParseListEnd()
}
-func (p *TSimpleJSONProtocol) ReadMapBegin() (keyType TType, valueType TType, size int, e error) {
+func (p *TSimpleJSONProtocol) ReadMapBegin(ctx context.Context) (keyType TType, valueType TType, size int, e error) {
if isNull, e := p.ParseListBegin(); isNull || e != nil {
return VOID, VOID, 0, e
}
// read keyType
- bKeyType, e := p.ReadByte()
+ bKeyType, e := p.ReadByte(ctx)
keyType = TType(bKeyType)
if e != nil {
return keyType, valueType, size, e
}
// read valueType
- bValueType, e := p.ReadByte()
+ bValueType, e := p.ReadByte(ctx)
valueType = TType(bValueType)
if e != nil {
return keyType, valueType, size, e
}
// read size
- iSize, err := p.ReadI64()
+ iSize, err := p.ReadI64(ctx)
size = int(iSize)
return keyType, valueType, size, err
}
-func (p *TSimpleJSONProtocol) ReadMapEnd() error {
+func (p *TSimpleJSONProtocol) ReadMapEnd(ctx context.Context) error {
return p.ParseListEnd()
}
-func (p *TSimpleJSONProtocol) ReadListBegin() (elemType TType, size int, e error) {
+func (p *TSimpleJSONProtocol) ReadListBegin(ctx context.Context) (elemType TType, size int, e error) {
return p.ParseElemListBegin()
}
-func (p *TSimpleJSONProtocol) ReadListEnd() error {
+func (p *TSimpleJSONProtocol) ReadListEnd(ctx context.Context) error {
return p.ParseListEnd()
}
-func (p *TSimpleJSONProtocol) ReadSetBegin() (elemType TType, size int, e error) {
+func (p *TSimpleJSONProtocol) ReadSetBegin(ctx context.Context) (elemType TType, size int, e error) {
return p.ParseElemListBegin()
}
-func (p *TSimpleJSONProtocol) ReadSetEnd() error {
+func (p *TSimpleJSONProtocol) ReadSetEnd(ctx context.Context) error {
return p.ParseListEnd()
}
-func (p *TSimpleJSONProtocol) ReadBool() (bool, error) {
+func (p *TSimpleJSONProtocol) ReadBool(ctx context.Context) (bool, error) {
var value bool
if err := p.ParsePreValue(); err != nil {
@@ -467,32 +450,32 @@ func (p *TSimpleJSONProtocol) ReadBool() (bool, error) {
return value, p.ParsePostValue()
}
-func (p *TSimpleJSONProtocol) ReadByte() (int8, error) {
- v, err := p.ReadI64()
+func (p *TSimpleJSONProtocol) ReadByte(ctx context.Context) (int8, error) {
+ v, err := p.ReadI64(ctx)
return int8(v), err
}
-func (p *TSimpleJSONProtocol) ReadI16() (int16, error) {
- v, err := p.ReadI64()
+func (p *TSimpleJSONProtocol) ReadI16(ctx context.Context) (int16, error) {
+ v, err := p.ReadI64(ctx)
return int16(v), err
}
-func (p *TSimpleJSONProtocol) ReadI32() (int32, error) {
- v, err := p.ReadI64()
+func (p *TSimpleJSONProtocol) ReadI32(ctx context.Context) (int32, error) {
+ v, err := p.ReadI64(ctx)
return int32(v), err
}
-func (p *TSimpleJSONProtocol) ReadI64() (int64, error) {
+func (p *TSimpleJSONProtocol) ReadI64(ctx context.Context) (int64, error) {
v, _, err := p.ParseI64()
return v, err
}
-func (p *TSimpleJSONProtocol) ReadDouble() (float64, error) {
+func (p *TSimpleJSONProtocol) ReadDouble(ctx context.Context) (float64, error) {
v, _, err := p.ParseF64()
return v, err
}
-func (p *TSimpleJSONProtocol) ReadString() (string, error) {
+func (p *TSimpleJSONProtocol) ReadString(ctx context.Context) (string, error) {
var v string
if err := p.ParsePreValue(); err != nil {
return v, err
@@ -522,7 +505,7 @@ func (p *TSimpleJSONProtocol) ReadString() (string, error) {
return v, p.ParsePostValue()
}
-func (p *TSimpleJSONProtocol) ReadBinary() ([]byte, error) {
+func (p *TSimpleJSONProtocol) ReadBinary(ctx context.Context) ([]byte, error) {
var v []byte
if err := p.ParsePreValue(); err != nil {
return nil, err
@@ -557,8 +540,8 @@ func (p *TSimpleJSONProtocol) Flush(ctx context.Context) (err error) {
return NewTProtocolException(p.writer.Flush())
}
-func (p *TSimpleJSONProtocol) Skip(fieldType TType) (err error) {
- return SkipDefaultDepth(p, fieldType)
+func (p *TSimpleJSONProtocol) Skip(ctx context.Context, fieldType TType) (err error) {
+ return SkipDefaultDepth(ctx, p, fieldType)
}
func (p *TSimpleJSONProtocol) Transport() TTransport {
@@ -740,10 +723,10 @@ func (p *TSimpleJSONProtocol) OutputElemListBegin(elemType TType, size int) erro
if e := p.OutputListBegin(); e != nil {
return e
}
- if e := p.WriteByte(int8(elemType)); e != nil {
+ if e := p.OutputI64(int64(elemType)); e != nil {
return e
}
- if e := p.WriteI64(int64(size)); e != nil {
+ if e := p.OutputI64(int64(size)); e != nil {
return e
}
return nil
@@ -1039,12 +1022,12 @@ func (p *TSimpleJSONProtocol) ParseElemListBegin() (elemType TType, size int, e
if isNull, e := p.ParseListBegin(); isNull || e != nil {
return VOID, 0, e
}
- bElemType, err := p.ReadByte()
+ bElemType, _, err := p.ParseI64()
elemType = TType(bElemType)
if err != nil {
return elemType, size, err
}
- nSize, err2 := p.ReadI64()
+ nSize, _, err2 := p.ParseI64()
size = int(nSize)
return elemType, size, err2
}
diff --git a/lib/go/thrift/simple_json_protocol_test.go b/lib/go/thrift/simple_json_protocol_test.go
index 0126da0a8..951389a60 100644
--- a/lib/go/thrift/simple_json_protocol_test.go
+++ b/lib/go/thrift/simple_json_protocol_test.go
@@ -35,7 +35,7 @@ func TestWriteSimpleJSONProtocolBool(t *testing.T) {
trans := NewTMemoryBuffer()
p := NewTSimpleJSONProtocol(trans)
for _, value := range BOOL_VALUES {
- if e := p.WriteBool(value); e != nil {
+ if e := p.WriteBool(context.Background(), value); e != nil {
t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error())
}
if e := p.Flush(context.Background()); e != nil {
@@ -66,7 +66,7 @@ func TestReadSimpleJSONProtocolBool(t *testing.T) {
}
trans.Flush(context.Background())
s := trans.String()
- v, e := p.ReadBool()
+ v, e := p.ReadBool(context.Background())
if e != nil {
t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error())
}
@@ -86,7 +86,7 @@ func TestWriteSimpleJSONProtocolByte(t *testing.T) {
trans := NewTMemoryBuffer()
p := NewTSimpleJSONProtocol(trans)
for _, value := range BYTE_VALUES {
- if e := p.WriteByte(value); e != nil {
+ if e := p.WriteByte(context.Background(), value); e != nil {
t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error())
}
if e := p.Flush(context.Background()); e != nil {
@@ -113,7 +113,7 @@ func TestReadSimpleJSONProtocolByte(t *testing.T) {
trans.WriteString(strconv.Itoa(int(value)))
trans.Flush(context.Background())
s := trans.String()
- v, e := p.ReadByte()
+ v, e := p.ReadByte(context.Background())
if e != nil {
t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error())
}
@@ -133,7 +133,7 @@ func TestWriteSimpleJSONProtocolI16(t *testing.T) {
trans := NewTMemoryBuffer()
p := NewTSimpleJSONProtocol(trans)
for _, value := range INT16_VALUES {
- if e := p.WriteI16(value); e != nil {
+ if e := p.WriteI16(context.Background(), value); e != nil {
t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error())
}
if e := p.Flush(context.Background()); e != nil {
@@ -160,7 +160,7 @@ func TestReadSimpleJSONProtocolI16(t *testing.T) {
trans.WriteString(strconv.Itoa(int(value)))
trans.Flush(context.Background())
s := trans.String()
- v, e := p.ReadI16()
+ v, e := p.ReadI16(context.Background())
if e != nil {
t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error())
}
@@ -180,7 +180,7 @@ func TestWriteSimpleJSONProtocolI32(t *testing.T) {
trans := NewTMemoryBuffer()
p := NewTSimpleJSONProtocol(trans)
for _, value := range INT32_VALUES {
- if e := p.WriteI32(value); e != nil {
+ if e := p.WriteI32(context.Background(), value); e != nil {
t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error())
}
if e := p.Flush(context.Background()); e != nil {
@@ -207,7 +207,7 @@ func TestReadSimpleJSONProtocolI32(t *testing.T) {
trans.WriteString(strconv.Itoa(int(value)))
trans.Flush(context.Background())
s := trans.String()
- v, e := p.ReadI32()
+ v, e := p.ReadI32(context.Background())
if e != nil {
t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error())
}
@@ -231,7 +231,7 @@ func TestReadSimpleJSONProtocolI32Null(t *testing.T) {
trans.WriteString(value)
trans.Flush(context.Background())
s := trans.String()
- v, e := p.ReadI32()
+ v, e := p.ReadI32(context.Background())
if e != nil {
t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error())
@@ -248,7 +248,7 @@ func TestWriteSimpleJSONProtocolI64(t *testing.T) {
trans := NewTMemoryBuffer()
p := NewTSimpleJSONProtocol(trans)
for _, value := range INT64_VALUES {
- if e := p.WriteI64(value); e != nil {
+ if e := p.WriteI64(context.Background(), value); e != nil {
t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error())
}
if e := p.Flush(context.Background()); e != nil {
@@ -275,7 +275,7 @@ func TestReadSimpleJSONProtocolI64(t *testing.T) {
trans.WriteString(strconv.FormatInt(value, 10))
trans.Flush(context.Background())
s := trans.String()
- v, e := p.ReadI64()
+ v, e := p.ReadI64(context.Background())
if e != nil {
t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error())
}
@@ -299,7 +299,7 @@ func TestReadSimpleJSONProtocolI64Null(t *testing.T) {
trans.WriteString(value)
trans.Flush(context.Background())
s := trans.String()
- v, e := p.ReadI64()
+ v, e := p.ReadI64(context.Background())
if e != nil {
t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error())
@@ -316,7 +316,7 @@ func TestWriteSimpleJSONProtocolDouble(t *testing.T) {
trans := NewTMemoryBuffer()
p := NewTSimpleJSONProtocol(trans)
for _, value := range DOUBLE_VALUES {
- if e := p.WriteDouble(value); e != nil {
+ if e := p.WriteDouble(context.Background(), value); e != nil {
t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error())
}
if e := p.Flush(context.Background()); e != nil {
@@ -358,7 +358,7 @@ func TestReadSimpleJSONProtocolDouble(t *testing.T) {
trans.WriteString(n.String())
trans.Flush(context.Background())
s := trans.String()
- v, e := p.ReadDouble()
+ v, e := p.ReadDouble(context.Background())
if e != nil {
t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error())
}
@@ -392,7 +392,7 @@ func TestWriteSimpleJSONProtocolString(t *testing.T) {
trans := NewTMemoryBuffer()
p := NewTSimpleJSONProtocol(trans)
for _, value := range STRING_VALUES {
- if e := p.WriteString(value); e != nil {
+ if e := p.WriteString(context.Background(), value); e != nil {
t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error())
}
if e := p.Flush(context.Background()); e != nil {
@@ -419,7 +419,7 @@ func TestReadSimpleJSONProtocolString(t *testing.T) {
trans.WriteString(jsonQuote(value))
trans.Flush(context.Background())
s := trans.String()
- v, e := p.ReadString()
+ v, e := p.ReadString(context.Background())
if e != nil {
t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error())
}
@@ -443,7 +443,7 @@ func TestReadSimpleJSONProtocolStringNull(t *testing.T) {
trans.WriteString(value)
trans.Flush(context.Background())
s := trans.String()
- v, e := p.ReadString()
+ v, e := p.ReadString(context.Background())
if e != nil {
t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error())
}
@@ -462,7 +462,7 @@ func TestWriteSimpleJSONProtocolBinary(t *testing.T) {
b64String := string(b64value)
trans := NewTMemoryBuffer()
p := NewTSimpleJSONProtocol(trans)
- if e := p.WriteBinary(value); e != nil {
+ if e := p.WriteBinary(context.Background(), value); e != nil {
t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error())
}
if e := p.Flush(context.Background()); e != nil {
@@ -490,7 +490,7 @@ func TestReadSimpleJSONProtocolBinary(t *testing.T) {
trans.WriteString(jsonQuote(b64String))
trans.Flush(context.Background())
s := trans.String()
- v, e := p.ReadBinary()
+ v, e := p.ReadBinary(context.Background())
if e != nil {
t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error())
}
@@ -519,7 +519,7 @@ func TestReadSimpleJSONProtocolBinaryNull(t *testing.T) {
trans.WriteString(value)
trans.Flush(context.Background())
s := trans.String()
- b, e := p.ReadBinary()
+ b, e := p.ReadBinary(context.Background())
v := string(b)
if e != nil {
@@ -536,13 +536,13 @@ func TestWriteSimpleJSONProtocolList(t *testing.T) {
thetype := "list"
trans := NewTMemoryBuffer()
p := NewTSimpleJSONProtocol(trans)
- p.WriteListBegin(TType(DOUBLE), len(DOUBLE_VALUES))
+ p.WriteListBegin(context.Background(), TType(DOUBLE), len(DOUBLE_VALUES))
for _, value := range DOUBLE_VALUES {
- if e := p.WriteDouble(value); e != nil {
+ if e := p.WriteDouble(context.Background(), value); e != nil {
t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error())
}
}
- p.WriteListEnd()
+ p.WriteListEnd(context.Background())
if e := p.Flush(context.Background()); e != nil {
t.Fatalf("Unable to write %s due to error flushing: %s", thetype, e.Error())
}
@@ -590,13 +590,13 @@ func TestWriteSimpleJSONProtocolSet(t *testing.T) {
thetype := "set"
trans := NewTMemoryBuffer()
p := NewTSimpleJSONProtocol(trans)
- p.WriteSetBegin(TType(DOUBLE), len(DOUBLE_VALUES))
+ p.WriteSetBegin(context.Background(), TType(DOUBLE), len(DOUBLE_VALUES))
for _, value := range DOUBLE_VALUES {
- if e := p.WriteDouble(value); e != nil {
+ if e := p.WriteDouble(context.Background(), value); e != nil {
t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error())
}
}
- p.WriteSetEnd()
+ p.WriteSetEnd(context.Background())
if e := p.Flush(context.Background()); e != nil {
t.Fatalf("Unable to write %s due to error flushing: %s", thetype, e.Error())
}
@@ -644,16 +644,16 @@ func TestWriteSimpleJSONProtocolMap(t *testing.T) {
thetype := "map"
trans := NewTMemoryBuffer()
p := NewTSimpleJSONProtocol(trans)
- p.WriteMapBegin(TType(I32), TType(DOUBLE), len(DOUBLE_VALUES))
+ p.WriteMapBegin(context.Background(), TType(I32), TType(DOUBLE), len(DOUBLE_VALUES))
for k, value := range DOUBLE_VALUES {
- if e := p.WriteI32(int32(k)); e != nil {
+ if e := p.WriteI32(context.Background(), int32(k)); e != nil {
t.Fatalf("Unable to write %s key int32 value %v due to error: %s", thetype, k, e.Error())
}
- if e := p.WriteDouble(value); e != nil {
+ if e := p.WriteDouble(context.Background(), value); e != nil {
t.Fatalf("Unable to write %s value float64 value %v due to error: %s", thetype, value, e.Error())
}
}
- p.WriteMapEnd()
+ p.WriteMapEnd(context.Background())
if e := p.Flush(context.Background()); e != nil {
t.Fatalf("Unable to write %s due to error flushing: %s", thetype, e.Error())
}
@@ -720,17 +720,17 @@ func TestWriteSimpleJSONProtocolSafePeek(t *testing.T) {
p := NewTSimpleJSONProtocol(trans)
trans.Write([]byte{'a', 'b'})
trans.Flush(context.Background())
-
+
test1 := p.safePeekContains([]byte{'a', 'b'})
if !test1 {
t.Fatalf("Should match at test 1")
}
-
+
test2 := p.safePeekContains([]byte{'a', 'b', 'c', 'd'})
if test2 {
t.Fatalf("Should not match at test 2")
}
-
+
test3 := p.safePeekContains([]byte{'x', 'y'})
if test3 {
t.Fatalf("Should not match at test 3")
diff --git a/lib/go/thrift/simple_server.go b/lib/go/thrift/simple_server.go
index 5a9c9c9e1..85baa4ee1 100644
--- a/lib/go/thrift/simple_server.go
+++ b/lib/go/thrift/simple_server.go
@@ -285,7 +285,7 @@ func (p *TSimpleServer) processRequests(client TTransport) (err error) {
// ReadFrame is safe to be called multiple times so it
// won't break when it's called again later when we
// actually start to read the message.
- if err := headerProtocol.ReadFrame(); err != nil {
+ if err := headerProtocol.ReadFrame(ctx); err != nil {
return err
}
ctx = AddReadTHeaderToContext(ctx, headerProtocol.GetReadHeaders())
diff --git a/lib/go/thrift/transport_exception.go b/lib/go/thrift/transport_exception.go
index d2283ea2e..16193ee86 100644
--- a/lib/go/thrift/transport_exception.go
+++ b/lib/go/thrift/transport_exception.go
@@ -64,6 +64,10 @@ func (p *tTransportException) Unwrap() error {
return p.err
}
+func (p *tTransportException) Timeout() bool {
+ return p.typeId == TIMED_OUT
+}
+
func NewTTransportException(t int, e string) TTransportException {
return &tTransportException{typeId: t, err: errors.New(e)}
}
@@ -92,3 +96,13 @@ func NewTTransportExceptionFromError(e error) TTransportException {
return &tTransportException{typeId: UNKNOWN_TRANSPORT_EXCEPTION, err: e}
}
+
+// isTimeoutError returns true when err is a timeout error.
+//
+// Note that this also includes TTransportException wrapped timeout errors.
+func isTimeoutError(err error) bool {
+ if t, ok := err.(timeoutable); ok {
+ return t.Timeout()
+ }
+ return false
+}
diff --git a/lib/go/thrift/transport_exception_test.go b/lib/go/thrift/transport_exception_test.go
index cf26258d9..fb1dc2602 100644
--- a/lib/go/thrift/transport_exception_test.go
+++ b/lib/go/thrift/transport_exception_test.go
@@ -20,9 +20,9 @@
package thrift
import (
+ "errors"
"fmt"
"io"
-
"testing"
)
@@ -78,3 +78,32 @@ func TestTExceptionEOF(t *testing.T) {
t.Errorf("Unwrapped exception did not match: expected %v, got %v", io.EOF, e.Unwrap())
}
}
+
+func TestIsTimeoutError(t *testing.T) {
+ te := &timeout{true}
+ if !isTimeoutError(te) {
+ t.Error("isTimeoutError expected true, got false")
+ }
+ e := NewTTransportExceptionFromError(te)
+ if !isTimeoutError(e) {
+ t.Error("isTimeoutError on wrapped TTransportException expected true, got false")
+ }
+
+ te = &timeout{false}
+ if isTimeoutError(te) {
+ t.Error("isTimeoutError expected false, got true")
+ }
+ e = NewTTransportExceptionFromError(te)
+ if isTimeoutError(e) {
+ t.Error("isTimeoutError on wrapped TTransportException expected false, got true")
+ }
+
+ err := errors.New("foo")
+ if isTimeoutError(err) {
+ t.Error("isTimeoutError expected false, got true")
+ }
+ e = NewTTransportExceptionFromError(err)
+ if isTimeoutError(e) {
+ t.Error("isTimeoutError on wrapped TTransportException expected false, got true")
+ }
+}