diff options
Diffstat (limited to 'workhorse/internal/httprs/httprs_test.go')
-rw-r--r-- | workhorse/internal/httprs/httprs_test.go | 257 |
1 files changed, 257 insertions, 0 deletions
diff --git a/workhorse/internal/httprs/httprs_test.go b/workhorse/internal/httprs/httprs_test.go new file mode 100644 index 00000000000..62279d895c9 --- /dev/null +++ b/workhorse/internal/httprs/httprs_test.go @@ -0,0 +1,257 @@ +package httprs + +import ( + "fmt" + "io" + "io/ioutil" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + "time" + + . "github.com/smartystreets/goconvey/convey" +) + +type fakeResponseWriter struct { + code int + h http.Header + tmp *os.File +} + +func (f *fakeResponseWriter) Header() http.Header { + return f.h +} + +func (f *fakeResponseWriter) Write(b []byte) (int, error) { + return f.tmp.Write(b) +} + +func (f *fakeResponseWriter) Close(b []byte) error { + return f.tmp.Close() +} + +func (f *fakeResponseWriter) WriteHeader(code int) { + f.code = code +} + +func (f *fakeResponseWriter) Response() *http.Response { + f.tmp.Seek(0, io.SeekStart) + return &http.Response{Body: f.tmp, StatusCode: f.code, Header: f.h} +} + +type fakeRoundTripper struct { + src *os.File + downgradeZeroToNoRange bool +} + +func (f *fakeRoundTripper) RoundTrip(r *http.Request) (*http.Response, error) { + fw := &fakeResponseWriter{h: http.Header{}} + var err error + fw.tmp, err = ioutil.TempFile(os.TempDir(), "httprs") + if err != nil { + return nil, err + } + if f.downgradeZeroToNoRange { + // There are implementations that downgrades bytes=0- to a normal un-ranged GET + if r.Header.Get("Range") == "bytes=0-" { + r.Header.Del("Range") + } + } + http.ServeContent(fw, r, "temp.txt", time.Now(), f.src) + + return fw.Response(), nil +} + +const SZ = 4096 + +const ( + downgradeZeroToNoRange = 1 << iota + sendAcceptRanges +) + +type RSFactory func() *HttpReadSeeker + +func newRSFactory(flags int) RSFactory { + return func() *HttpReadSeeker { + tmp, err := ioutil.TempFile(os.TempDir(), "httprs") + if err != nil { + return nil + } + for i := 0; i < SZ; i++ { + tmp.WriteString(fmt.Sprintf("%04d", i)) + } + + req, err := http.NewRequest("GET", "http://www.example.com", nil) + if err != nil { + return nil + } + res := &http.Response{ + Request: req, + ContentLength: SZ * 4, + } + + if flags&sendAcceptRanges > 0 { + res.Header = http.Header{"Accept-Ranges": []string{"bytes"}} + } + + downgradeZeroToNoRange := (flags & downgradeZeroToNoRange) > 0 + return NewHttpReadSeeker(res, &http.Client{Transport: &fakeRoundTripper{src: tmp, downgradeZeroToNoRange: downgradeZeroToNoRange}}) + } +} + +func TestHttpWebServer(t *testing.T) { + Convey("Scenario: testing WebServer", t, func() { + dir, err := ioutil.TempDir("", "webserver") + So(err, ShouldBeNil) + defer os.RemoveAll(dir) + + err = ioutil.WriteFile(filepath.Join(dir, "file"), make([]byte, 10000), 0755) + So(err, ShouldBeNil) + + server := httptest.NewServer(http.FileServer(http.Dir(dir))) + + Convey("When requesting /file", func() { + res, err := http.Get(server.URL + "/file") + So(err, ShouldBeNil) + + stream := NewHttpReadSeeker(res) + So(stream, ShouldNotBeNil) + + Convey("Can read 100 bytes from start of file", func() { + n, err := stream.Read(make([]byte, 100)) + So(err, ShouldBeNil) + So(n, ShouldEqual, 100) + + Convey("When seeking 4KiB forward", func() { + pos, err := stream.Seek(4096, io.SeekCurrent) + So(err, ShouldBeNil) + So(pos, ShouldEqual, 4096+100) + + Convey("Can read 100 bytes", func() { + n, err := stream.Read(make([]byte, 100)) + So(err, ShouldBeNil) + So(n, ShouldEqual, 100) + }) + }) + }) + }) + }) +} + +func TestHttpReaderSeeker(t *testing.T) { + tests := []struct { + name string + newRS func() *HttpReadSeeker + }{ + {name: "with no flags", newRS: newRSFactory(0)}, + {name: "with only Accept-Ranges", newRS: newRSFactory(sendAcceptRanges)}, + {name: "downgrade 0-range to no range", newRS: newRSFactory(downgradeZeroToNoRange)}, + {name: "downgrade 0-range with Accept-Ranges", newRS: newRSFactory(downgradeZeroToNoRange | sendAcceptRanges)}, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + testHttpReaderSeeker(t, test.newRS) + }) + } +} + +func testHttpReaderSeeker(t *testing.T, newRS RSFactory) { + Convey("Scenario: testing HttpReaderSeeker", t, func() { + + Convey("Read should start at the beginning", func() { + r := newRS() + So(r, ShouldNotBeNil) + defer r.Close() + buf := make([]byte, 4) + n, err := io.ReadFull(r, buf) + So(n, ShouldEqual, 4) + So(err, ShouldBeNil) + So(string(buf), ShouldEqual, "0000") + }) + + Convey("Seek w SEEK_SET should seek to right offset", func() { + r := newRS() + So(r, ShouldNotBeNil) + defer r.Close() + s, err := r.Seek(4*64, io.SeekStart) + So(s, ShouldEqual, 4*64) + So(err, ShouldBeNil) + buf := make([]byte, 4) + n, err := io.ReadFull(r, buf) + So(n, ShouldEqual, 4) + So(err, ShouldBeNil) + So(string(buf), ShouldEqual, "0064") + }) + + Convey("Read + Seek w SEEK_CUR should seek to right offset", func() { + r := newRS() + So(r, ShouldNotBeNil) + defer r.Close() + buf := make([]byte, 4) + io.ReadFull(r, buf) + s, err := r.Seek(4*64, os.SEEK_CUR) + So(s, ShouldEqual, 4*64+4) + So(err, ShouldBeNil) + n, err := io.ReadFull(r, buf) + So(n, ShouldEqual, 4) + So(err, ShouldBeNil) + So(string(buf), ShouldEqual, "0065") + }) + + Convey("Seek w SEEK_END should seek to right offset", func() { + r := newRS() + So(r, ShouldNotBeNil) + defer r.Close() + buf := make([]byte, 4) + io.ReadFull(r, buf) + s, err := r.Seek(4, os.SEEK_END) + So(s, ShouldEqual, SZ*4-4) + So(err, ShouldBeNil) + n, err := io.ReadFull(r, buf) + So(n, ShouldEqual, 4) + So(err, ShouldBeNil) + So(string(buf), ShouldEqual, fmt.Sprintf("%04d", SZ-1)) + }) + + Convey("Short seek should consume existing request", func() { + r := newRS() + So(r, ShouldNotBeNil) + defer r.Close() + buf := make([]byte, 4) + So(r.Requests, ShouldEqual, 0) + io.ReadFull(r, buf) + So(r.Requests, ShouldEqual, 1) + s, err := r.Seek(shortSeekBytes, os.SEEK_CUR) + So(r.Requests, ShouldEqual, 1) + So(s, ShouldEqual, shortSeekBytes+4) + So(err, ShouldBeNil) + n, err := io.ReadFull(r, buf) + So(n, ShouldEqual, 4) + So(err, ShouldBeNil) + So(string(buf), ShouldEqual, "0257") + So(r.Requests, ShouldEqual, 1) + }) + + Convey("Long seek should do a new request", func() { + r := newRS() + So(r, ShouldNotBeNil) + defer r.Close() + buf := make([]byte, 4) + So(r.Requests, ShouldEqual, 0) + io.ReadFull(r, buf) + So(r.Requests, ShouldEqual, 1) + s, err := r.Seek(shortSeekBytes+1, os.SEEK_CUR) + So(r.Requests, ShouldEqual, 1) + So(s, ShouldEqual, shortSeekBytes+4+1) + So(err, ShouldBeNil) + n, err := io.ReadFull(r, buf) + So(n, ShouldEqual, 4) + So(err, ShouldBeNil) + So(string(buf), ShouldEqual, "2570") + So(r.Requests, ShouldEqual, 2) + }) + }) +} |