summaryrefslogtreecommitdiff
path: root/workhorse/internal/helper/context_reader_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'workhorse/internal/helper/context_reader_test.go')
-rw-r--r--workhorse/internal/helper/context_reader_test.go83
1 files changed, 83 insertions, 0 deletions
diff --git a/workhorse/internal/helper/context_reader_test.go b/workhorse/internal/helper/context_reader_test.go
new file mode 100644
index 00000000000..257ec4e35f2
--- /dev/null
+++ b/workhorse/internal/helper/context_reader_test.go
@@ -0,0 +1,83 @@
+package helper
+
+import (
+ "context"
+ "io"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/require"
+)
+
+type fakeReader struct {
+ n int
+ err error
+}
+
+func (f *fakeReader) Read(b []byte) (int, error) {
+ return f.n, f.err
+}
+
+type fakeContextWithTimeout struct {
+ n int
+ threshold int
+}
+
+func (*fakeContextWithTimeout) Deadline() (deadline time.Time, ok bool) {
+ return
+}
+
+func (*fakeContextWithTimeout) Done() <-chan struct{} {
+ return nil
+}
+
+func (*fakeContextWithTimeout) Value(key interface{}) interface{} {
+ return nil
+}
+
+func (f *fakeContextWithTimeout) Err() error {
+ f.n++
+ if f.n > f.threshold {
+ return context.DeadlineExceeded
+ }
+
+ return nil
+}
+
+func TestContextReaderRead(t *testing.T) {
+ underlyingReader := &fakeReader{n: 1, err: io.EOF}
+
+ for _, tc := range []struct {
+ desc string
+ ctx *fakeContextWithTimeout
+ expectedN int
+ expectedErr error
+ }{
+ {
+ desc: "Before and after read deadline checks are fine",
+ ctx: &fakeContextWithTimeout{n: 0, threshold: 2},
+ expectedN: underlyingReader.n,
+ expectedErr: underlyingReader.err,
+ },
+ {
+ desc: "Before read deadline check fails",
+ ctx: &fakeContextWithTimeout{n: 0, threshold: 0},
+ expectedN: 0,
+ expectedErr: context.DeadlineExceeded,
+ },
+ {
+ desc: "After read deadline check fails",
+ ctx: &fakeContextWithTimeout{n: 0, threshold: 1},
+ expectedN: underlyingReader.n,
+ expectedErr: context.DeadlineExceeded,
+ },
+ } {
+ t.Run(tc.desc, func(t *testing.T) {
+ cr := NewContextReader(tc.ctx, underlyingReader)
+
+ n, err := cr.Read(nil)
+ require.Equal(t, tc.expectedN, n)
+ require.Equal(t, tc.expectedErr, err)
+ })
+ }
+}