diff options
author | Ash McKenzie <amckenzie@gitlab.com> | 2020-04-17 00:33:11 +1000 |
---|---|---|
committer | Ash McKenzie <amckenzie@gitlab.com> | 2020-04-17 16:23:32 +1000 |
commit | 1e3d7b6d51c119d34df5a19be4490e4662829128 (patch) | |
tree | e245a806b758417dd3c77759bc9490df2f83dc50 | |
parent | 932e5d46a96bb1daebb9268789e765dda9f6d3ef (diff) | |
download | gitlab-shell-1e3d7b6d51c119d34df5a19be4490e4662829128.tar.gz |
New pktline package
Package is responsible for parsing git
pkt lines.
Copied from gitaly, for now.
-rw-r--r-- | internal/pktline/pktline.go | 75 | ||||
-rw-r--r-- | internal/pktline/pktline_test.go | 87 |
2 files changed, 162 insertions, 0 deletions
// Reconstructed from the flattened diff. The commit adds two files, both in
// package pktline:
//
//   internal/pktline/pktline.go      (imports bufio, bytes, fmt, io, strconv)
//   internal/pktline/pktline_test.go (imports reflect, strings, testing)
//
// The tests below rely only on the standard library.

// ---------------------------------------------------------------------------
// internal/pktline/pktline.go
// ---------------------------------------------------------------------------

// Utility functions for working with the Git pkt-line format. See
// https://github.com/git/git/blob/master/Documentation/technical/protocol-common.txt

const (
	// maxPktSize is the largest length the 4-hex-digit prefix can express. It
	// is used both as the scanner's initial buffer size and as its token limit.
	maxPktSize = 0xffff

	// pktDelim is the "0001" delimiter packet. It is not referenced by the
	// code in this listing.
	pktDelim = "0001"
)

// NewScanner returns a bufio.Scanner that splits r on Git pkt-line
// boundaries. Every token is one complete packet, including its 4-byte
// length prefix. Check scanner.Err() after the scan loop: malformed or
// truncated input is reported there.
func NewScanner(r io.Reader) *bufio.Scanner {
	scanner := bufio.NewScanner(r)
	scanner.Buffer(make([]byte, maxPktSize), maxPktSize)
	scanner.Split(pktLineSplitter)
	return scanner
}

// IsDone reports whether pkt is exactly the "done" packet "0009done\n".
// (This is an ordinary data packet, not the "0000" flush packet.)
func IsDone(pkt []byte) bool {
	return bytes.Equal(pkt, PktDone())
}

// PktDone returns the bytes for a "done" packet. A fresh slice is returned on
// every call, so callers may modify it.
func PktDone() []byte {
	return []byte("0009done\n")
}

// pktLineSplitter is a bufio.SplitFunc that yields one pkt-line per token.
//
// It returns (0, nil, nil) to request more data while a packet is still
// incomplete, and an error when the input ends in the middle of a length
// prefix or packet body, or when the length prefix is not exactly four
// hexadecimal digits.
func pktLineSplitter(data []byte, atEOF bool) (advance int, token []byte, err error) {
	if len(data) < 4 {
		if atEOF && len(data) > 0 {
			return 0, nil, fmt.Errorf("pktLineSplitter: incomplete length prefix on %q", data)
		}
		return 0, nil, nil // want more data
	}

	// We have at least 4 bytes available so we can decode the 4-hex digit
	// length prefix of the packet line. ParseUint is used instead of ParseInt
	// because ParseInt accepts a leading '+' or '-', which would let a
	// malformed prefix such as "+00a" pass as a valid length.
	pktLength64, err := strconv.ParseUint(string(data[:4]), 16, 16)
	if err != nil {
		return 0, nil, fmt.Errorf("pktLineSplitter: decode length: %v", err)
	}

	// Cast is safe because the value was limited to 16 bits above.
	pktLength := int(pktLength64)

	if pktLength < 4 {
		// Special case: magic empty packet 0000, 0001, 0002 or 0003.
		return 4, data[:4], nil
	}

	if len(data) < pktLength {
		// data contains incomplete packet

		if atEOF {
			return 0, nil, fmt.Errorf("pktLineSplitter: less than %d bytes in input %q", pktLength, data)
		}

		return 0, nil, nil // want more data
	}

	return pktLength, data[:pktLength], nil
}

// ---------------------------------------------------------------------------
// internal/pktline/pktline_test.go
// ---------------------------------------------------------------------------

var (
	// largestString is the payload of the biggest packet the length prefix
	// can describe (0xffff bytes including the 4-byte prefix).
	largestString = strings.Repeat("z", 0xffff-4)
)

// TestScanner feeds raw pkt-line streams through NewScanner and checks both
// the packets produced and whether the scanner ends with an error.
func TestScanner(t *testing.T) {
	largestPacket := "ffff" + largestString
	testCases := []struct {
		desc string
		in   string
		out  []string
		fail bool
	}{
		{
			desc: "happy path",
			in:   "0010hello world!000000010010hello world!",
			out:  []string{"0010hello world!", "0000", "0001", "0010hello world!"},
		},
		{
			desc: "large input",
			in:   "0010hello world!0000" + largestPacket + "0000",
			out:  []string{"0010hello world!", "0000", largestPacket, "0000"},
		},
		{
			desc: "missing byte middle",
			in:   "0010hello world!00000010010hello world!",
			out:  []string{"0010hello world!", "0000", "0010010hello wor"},
			fail: true,
		},
		{
			desc: "unfinished prefix",
			in:   "0010hello world!000",
			out:  []string{"0010hello world!"},
			fail: true,
		},
		{
			desc: "short read in data, only prefix",
			in:   "0010hello world!0005",
			out:  []string{"0010hello world!"},
			fail: true,
		},
		{
			desc: "signed length prefix",
			in:   "0010hello world!+00ahello!",
			out:  []string{"0010hello world!"},
			fail: true,
		},
		{
			desc: "signed magic packet",
			in:   "+003",
			fail: true,
		},
		{
			desc: "non-hex length prefix",
			in:   "zzzzhello",
			fail: true,
		},
	}

	for _, tc := range testCases {
		t.Run(tc.desc, func(t *testing.T) {
			scanner := NewScanner(strings.NewReader(tc.in))
			var output []string
			for scanner.Scan() {
				output = append(output, scanner.Text())
			}

			err := scanner.Err()
			if tc.fail && err == nil {
				t.Fatal("expected an error, got none")
			}
			if !tc.fail && err != nil {
				t.Fatalf("unexpected error: %v", err)
			}

			if !reflect.DeepEqual(tc.out, output) {
				t.Fatalf("got %d packets, want %d", len(output), len(tc.out))
			}
		})
	}
}

// TestIsDone checks that only the exact "0009done\n" packet is recognised.
func TestIsDone(t *testing.T) {
	testCases := []struct {
		in   string
		done bool
	}{
		{in: "0008abcd", done: false},
		{in: "invalid packet", done: false},
		{in: "0009done\n", done: true},
		{in: "0001", done: false},
	}

	for _, tc := range testCases {
		t.Run(tc.in, func(t *testing.T) {
			if got := IsDone([]byte(tc.in)); got != tc.done {
				t.Fatalf("IsDone(%q) = %v, want %v", tc.in, got, tc.done)
			}
		})
	}
}