summaryrefslogtreecommitdiff
path: root/workhorse/internal/httprs/httprs_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'workhorse/internal/httprs/httprs_test.go')
-rw-r--r--workhorse/internal/httprs/httprs_test.go257
1 files changed, 257 insertions, 0 deletions
diff --git a/workhorse/internal/httprs/httprs_test.go b/workhorse/internal/httprs/httprs_test.go
new file mode 100644
index 00000000000..62279d895c9
--- /dev/null
+++ b/workhorse/internal/httprs/httprs_test.go
@@ -0,0 +1,257 @@
+package httprs
+
+import (
+ "fmt"
+ "io"
+ "io/ioutil"
+ "net/http"
+ "net/http/httptest"
+ "os"
+ "path/filepath"
+ "testing"
+ "time"
+
+ . "github.com/smartystreets/goconvey/convey"
+)
+
+type fakeResponseWriter struct {
+ code int
+ h http.Header
+ tmp *os.File
+}
+
+func (f *fakeResponseWriter) Header() http.Header {
+ return f.h
+}
+
+func (f *fakeResponseWriter) Write(b []byte) (int, error) {
+ return f.tmp.Write(b)
+}
+
+func (f *fakeResponseWriter) Close(b []byte) error {
+ return f.tmp.Close()
+}
+
+func (f *fakeResponseWriter) WriteHeader(code int) {
+ f.code = code
+}
+
+func (f *fakeResponseWriter) Response() *http.Response {
+ f.tmp.Seek(0, io.SeekStart)
+ return &http.Response{Body: f.tmp, StatusCode: f.code, Header: f.h}
+}
+
+type fakeRoundTripper struct {
+ src *os.File
+ downgradeZeroToNoRange bool
+}
+
+func (f *fakeRoundTripper) RoundTrip(r *http.Request) (*http.Response, error) {
+ fw := &fakeResponseWriter{h: http.Header{}}
+ var err error
+ fw.tmp, err = ioutil.TempFile(os.TempDir(), "httprs")
+ if err != nil {
+ return nil, err
+ }
+ if f.downgradeZeroToNoRange {
+ // There are implementations that downgrades bytes=0- to a normal un-ranged GET
+ if r.Header.Get("Range") == "bytes=0-" {
+ r.Header.Del("Range")
+ }
+ }
+ http.ServeContent(fw, r, "temp.txt", time.Now(), f.src)
+
+ return fw.Response(), nil
+}
+
+const SZ = 4096
+
+const (
+ downgradeZeroToNoRange = 1 << iota
+ sendAcceptRanges
+)
+
+type RSFactory func() *HttpReadSeeker
+
+func newRSFactory(flags int) RSFactory {
+ return func() *HttpReadSeeker {
+ tmp, err := ioutil.TempFile(os.TempDir(), "httprs")
+ if err != nil {
+ return nil
+ }
+ for i := 0; i < SZ; i++ {
+ tmp.WriteString(fmt.Sprintf("%04d", i))
+ }
+
+ req, err := http.NewRequest("GET", "http://www.example.com", nil)
+ if err != nil {
+ return nil
+ }
+ res := &http.Response{
+ Request: req,
+ ContentLength: SZ * 4,
+ }
+
+ if flags&sendAcceptRanges > 0 {
+ res.Header = http.Header{"Accept-Ranges": []string{"bytes"}}
+ }
+
+ downgradeZeroToNoRange := (flags & downgradeZeroToNoRange) > 0
+ return NewHttpReadSeeker(res, &http.Client{Transport: &fakeRoundTripper{src: tmp, downgradeZeroToNoRange: downgradeZeroToNoRange}})
+ }
+}
+
+func TestHttpWebServer(t *testing.T) {
+ Convey("Scenario: testing WebServer", t, func() {
+ dir, err := ioutil.TempDir("", "webserver")
+ So(err, ShouldBeNil)
+ defer os.RemoveAll(dir)
+
+ err = ioutil.WriteFile(filepath.Join(dir, "file"), make([]byte, 10000), 0755)
+ So(err, ShouldBeNil)
+
+ server := httptest.NewServer(http.FileServer(http.Dir(dir)))
+
+ Convey("When requesting /file", func() {
+ res, err := http.Get(server.URL + "/file")
+ So(err, ShouldBeNil)
+
+ stream := NewHttpReadSeeker(res)
+ So(stream, ShouldNotBeNil)
+
+ Convey("Can read 100 bytes from start of file", func() {
+ n, err := stream.Read(make([]byte, 100))
+ So(err, ShouldBeNil)
+ So(n, ShouldEqual, 100)
+
+ Convey("When seeking 4KiB forward", func() {
+ pos, err := stream.Seek(4096, io.SeekCurrent)
+ So(err, ShouldBeNil)
+ So(pos, ShouldEqual, 4096+100)
+
+ Convey("Can read 100 bytes", func() {
+ n, err := stream.Read(make([]byte, 100))
+ So(err, ShouldBeNil)
+ So(n, ShouldEqual, 100)
+ })
+ })
+ })
+ })
+ })
+}
+
+func TestHttpReaderSeeker(t *testing.T) {
+ tests := []struct {
+ name string
+ newRS func() *HttpReadSeeker
+ }{
+ {name: "with no flags", newRS: newRSFactory(0)},
+ {name: "with only Accept-Ranges", newRS: newRSFactory(sendAcceptRanges)},
+ {name: "downgrade 0-range to no range", newRS: newRSFactory(downgradeZeroToNoRange)},
+ {name: "downgrade 0-range with Accept-Ranges", newRS: newRSFactory(downgradeZeroToNoRange | sendAcceptRanges)},
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ testHttpReaderSeeker(t, test.newRS)
+ })
+ }
+}
+
+func testHttpReaderSeeker(t *testing.T, newRS RSFactory) {
+ Convey("Scenario: testing HttpReaderSeeker", t, func() {
+
+ Convey("Read should start at the beginning", func() {
+ r := newRS()
+ So(r, ShouldNotBeNil)
+ defer r.Close()
+ buf := make([]byte, 4)
+ n, err := io.ReadFull(r, buf)
+ So(n, ShouldEqual, 4)
+ So(err, ShouldBeNil)
+ So(string(buf), ShouldEqual, "0000")
+ })
+
+ Convey("Seek w SEEK_SET should seek to right offset", func() {
+ r := newRS()
+ So(r, ShouldNotBeNil)
+ defer r.Close()
+ s, err := r.Seek(4*64, io.SeekStart)
+ So(s, ShouldEqual, 4*64)
+ So(err, ShouldBeNil)
+ buf := make([]byte, 4)
+ n, err := io.ReadFull(r, buf)
+ So(n, ShouldEqual, 4)
+ So(err, ShouldBeNil)
+ So(string(buf), ShouldEqual, "0064")
+ })
+
+ Convey("Read + Seek w SEEK_CUR should seek to right offset", func() {
+ r := newRS()
+ So(r, ShouldNotBeNil)
+ defer r.Close()
+ buf := make([]byte, 4)
+ io.ReadFull(r, buf)
+ s, err := r.Seek(4*64, os.SEEK_CUR)
+ So(s, ShouldEqual, 4*64+4)
+ So(err, ShouldBeNil)
+ n, err := io.ReadFull(r, buf)
+ So(n, ShouldEqual, 4)
+ So(err, ShouldBeNil)
+ So(string(buf), ShouldEqual, "0065")
+ })
+
+ Convey("Seek w SEEK_END should seek to right offset", func() {
+ r := newRS()
+ So(r, ShouldNotBeNil)
+ defer r.Close()
+ buf := make([]byte, 4)
+ io.ReadFull(r, buf)
+ s, err := r.Seek(4, os.SEEK_END)
+ So(s, ShouldEqual, SZ*4-4)
+ So(err, ShouldBeNil)
+ n, err := io.ReadFull(r, buf)
+ So(n, ShouldEqual, 4)
+ So(err, ShouldBeNil)
+ So(string(buf), ShouldEqual, fmt.Sprintf("%04d", SZ-1))
+ })
+
+ Convey("Short seek should consume existing request", func() {
+ r := newRS()
+ So(r, ShouldNotBeNil)
+ defer r.Close()
+ buf := make([]byte, 4)
+ So(r.Requests, ShouldEqual, 0)
+ io.ReadFull(r, buf)
+ So(r.Requests, ShouldEqual, 1)
+ s, err := r.Seek(shortSeekBytes, os.SEEK_CUR)
+ So(r.Requests, ShouldEqual, 1)
+ So(s, ShouldEqual, shortSeekBytes+4)
+ So(err, ShouldBeNil)
+ n, err := io.ReadFull(r, buf)
+ So(n, ShouldEqual, 4)
+ So(err, ShouldBeNil)
+ So(string(buf), ShouldEqual, "0257")
+ So(r.Requests, ShouldEqual, 1)
+ })
+
+ Convey("Long seek should do a new request", func() {
+ r := newRS()
+ So(r, ShouldNotBeNil)
+ defer r.Close()
+ buf := make([]byte, 4)
+ So(r.Requests, ShouldEqual, 0)
+ io.ReadFull(r, buf)
+ So(r.Requests, ShouldEqual, 1)
+ s, err := r.Seek(shortSeekBytes+1, os.SEEK_CUR)
+ So(r.Requests, ShouldEqual, 1)
+ So(s, ShouldEqual, shortSeekBytes+4+1)
+ So(err, ShouldBeNil)
+ n, err := io.ReadFull(r, buf)
+ So(n, ShouldEqual, 4)
+ So(err, ShouldBeNil)
+ So(string(buf), ShouldEqual, "2570")
+ So(r.Requests, ShouldEqual, 2)
+ })
+ })
+}