summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAsh McKenzie <amckenzie@gitlab.com>2020-04-17 00:33:11 +1000
committerAsh McKenzie <amckenzie@gitlab.com>2020-04-17 16:23:32 +1000
commit1e3d7b6d51c119d34df5a19be4490e4662829128 (patch)
treee245a806b758417dd3c77759bc9490df2f83dc50
parent932e5d46a96bb1daebb9268789e765dda9f6d3ef (diff)
downloadgitlab-shell-1e3d7b6d51c119d34df5a19be4490e4662829128.tar.gz
New pktline package
Package is responsible for parsing git pkt lines. Copied from gitaly, for now.
-rw-r--r--internal/pktline/pktline.go75
-rw-r--r--internal/pktline/pktline_test.go87
2 files changed, 162 insertions, 0 deletions
diff --git a/internal/pktline/pktline.go b/internal/pktline/pktline.go
new file mode 100644
index 0000000..35fceb2
--- /dev/null
+++ b/internal/pktline/pktline.go
@@ -0,0 +1,75 @@
+package pktline
+
+// Utility functions for working with the Git pkt-line format. See
+// https://github.com/git/git/blob/master/Documentation/technical/protocol-common.txt
+
+import (
+ "bufio"
+ "bytes"
+ "fmt"
+ "io"
+ "strconv"
+)
+
+const (
+ maxPktSize = 0xffff
+ pktDelim = "0001"
+)
+
+// NewScanner returns a bufio.Scanner that splits on Git pktline boundaries
+func NewScanner(r io.Reader) *bufio.Scanner {
+ scanner := bufio.NewScanner(r)
+ scanner.Buffer(make([]byte, maxPktSize), maxPktSize)
+ scanner.Split(pktLineSplitter)
+ return scanner
+}
+
+// IsDone detects the special flush packet '0009done\n'
+func IsDone(pkt []byte) bool {
+ return bytes.Equal(pkt, PktDone())
+}
+
+// PktDone returns the bytes for a "done" packet.
+func PktDone() []byte {
+ return []byte("0009done\n")
+}
+
+func pktLineSplitter(data []byte, atEOF bool) (advance int, token []byte, err error) {
+ if len(data) < 4 {
+ if atEOF && len(data) > 0 {
+ return 0, nil, fmt.Errorf("pktLineSplitter: incomplete length prefix on %q", data)
+ }
+ return 0, nil, nil // want more data
+ }
+
+ // We have at least 4 bytes available so we can decode the 4-hex digit
+ // length prefix of the packet line.
+ pktLength64, err := strconv.ParseInt(string(data[:4]), 16, 0)
+ if err != nil {
+ return 0, nil, fmt.Errorf("pktLineSplitter: decode length: %v", err)
+ }
+
+ // Cast is safe because we requested an int-size number from strconv.ParseInt
+ pktLength := int(pktLength64)
+
+ if pktLength < 0 {
+ return 0, nil, fmt.Errorf("pktLineSplitter: invalid length: %d", pktLength)
+ }
+
+ if pktLength < 4 {
+ // Special case: magic empty packet 0000, 0001, 0002 or 0003.
+ return 4, data[:4], nil
+ }
+
+ if len(data) < pktLength {
+ // data contains incomplete packet
+
+ if atEOF {
+ return 0, nil, fmt.Errorf("pktLineSplitter: less than %d bytes in input %q", pktLength, data)
+ }
+
+ return 0, nil, nil // want more data
+ }
+
+ return pktLength, data[:pktLength], nil
+}
diff --git a/internal/pktline/pktline_test.go b/internal/pktline/pktline_test.go
new file mode 100644
index 0000000..cf3f6fd
--- /dev/null
+++ b/internal/pktline/pktline_test.go
@@ -0,0 +1,87 @@
+package pktline
+
+import (
+ "strings"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+var (
+ largestString = strings.Repeat("z", 0xffff-4)
+)
+
+func TestScanner(t *testing.T) {
+ largestPacket := "ffff" + largestString
+ testCases := []struct {
+ desc string
+ in string
+ out []string
+ fail bool
+ }{
+ {
+ desc: "happy path",
+ in: "0010hello world!000000010010hello world!",
+ out: []string{"0010hello world!", "0000", "0001", "0010hello world!"},
+ },
+ {
+ desc: "large input",
+ in: "0010hello world!0000" + largestPacket + "0000",
+ out: []string{"0010hello world!", "0000", largestPacket, "0000"},
+ },
+ {
+ desc: "missing byte middle",
+ in: "0010hello world!00000010010hello world!",
+ out: []string{"0010hello world!", "0000", "0010010hello wor"},
+ fail: true,
+ },
+ {
+ desc: "unfinished prefix",
+ in: "0010hello world!000",
+ out: []string{"0010hello world!"},
+ fail: true,
+ },
+ {
+ desc: "short read in data, only prefix",
+ in: "0010hello world!0005",
+ out: []string{"0010hello world!"},
+ fail: true,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.desc, func(t *testing.T) {
+ scanner := NewScanner(strings.NewReader(tc.in))
+ var output []string
+ for scanner.Scan() {
+ output = append(output, scanner.Text())
+ }
+
+ if tc.fail {
+ require.Error(t, scanner.Err())
+ } else {
+ require.NoError(t, scanner.Err())
+ }
+
+ require.Equal(t, tc.out, output)
+ })
+ }
+}
+
+func TestIsDone(t *testing.T) {
+ testCases := []struct {
+ in string
+ done bool
+ }{
+ {in: "0008abcd", done: false},
+ {in: "invalid packet", done: false},
+ {in: "0009done\n", done: true},
+ {in: "0001", done: false},
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.in, func(t *testing.T) {
+ require.Equal(t, tc.done, IsDone([]byte(tc.in)))
+ })
+ }
+}