summaryrefslogtreecommitdiff
path: root/src/crypto/tls/handshake_messages.go
diff options
context:
space:
mode:
Diffstat (limited to 'src/crypto/tls/handshake_messages.go')
-rw-r--r--src/crypto/tls/handshake_messages.go1438
1 files changed, 1438 insertions, 0 deletions
diff --git a/src/crypto/tls/handshake_messages.go b/src/crypto/tls/handshake_messages.go
new file mode 100644
index 000000000..5d14871a3
--- /dev/null
+++ b/src/crypto/tls/handshake_messages.go
@@ -0,0 +1,1438 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package tls
+
+import "bytes"
+
+type clientHelloMsg struct {
+ raw []byte
+ vers uint16
+ random []byte
+ sessionId []byte
+ cipherSuites []uint16
+ compressionMethods []uint8
+ nextProtoNeg bool
+ serverName string
+ ocspStapling bool
+ supportedCurves []CurveID
+ supportedPoints []uint8
+ ticketSupported bool
+ sessionTicket []uint8
+ signatureAndHashes []signatureAndHash
+ secureRenegotiation bool
+ alpnProtocols []string
+}
+
+func (m *clientHelloMsg) equal(i interface{}) bool {
+ m1, ok := i.(*clientHelloMsg)
+ if !ok {
+ return false
+ }
+
+ return bytes.Equal(m.raw, m1.raw) &&
+ m.vers == m1.vers &&
+ bytes.Equal(m.random, m1.random) &&
+ bytes.Equal(m.sessionId, m1.sessionId) &&
+ eqUint16s(m.cipherSuites, m1.cipherSuites) &&
+ bytes.Equal(m.compressionMethods, m1.compressionMethods) &&
+ m.nextProtoNeg == m1.nextProtoNeg &&
+ m.serverName == m1.serverName &&
+ m.ocspStapling == m1.ocspStapling &&
+ eqCurveIDs(m.supportedCurves, m1.supportedCurves) &&
+ bytes.Equal(m.supportedPoints, m1.supportedPoints) &&
+ m.ticketSupported == m1.ticketSupported &&
+ bytes.Equal(m.sessionTicket, m1.sessionTicket) &&
+ eqSignatureAndHashes(m.signatureAndHashes, m1.signatureAndHashes) &&
+ m.secureRenegotiation == m1.secureRenegotiation &&
+ eqStrings(m.alpnProtocols, m1.alpnProtocols)
+}
+
+func (m *clientHelloMsg) marshal() []byte {
+ if m.raw != nil {
+ return m.raw
+ }
+
+ length := 2 + 32 + 1 + len(m.sessionId) + 2 + len(m.cipherSuites)*2 + 1 + len(m.compressionMethods)
+ numExtensions := 0
+ extensionsLength := 0
+ if m.nextProtoNeg {
+ numExtensions++
+ }
+ if m.ocspStapling {
+ extensionsLength += 1 + 2 + 2
+ numExtensions++
+ }
+ if len(m.serverName) > 0 {
+ extensionsLength += 5 + len(m.serverName)
+ numExtensions++
+ }
+ if len(m.supportedCurves) > 0 {
+ extensionsLength += 2 + 2*len(m.supportedCurves)
+ numExtensions++
+ }
+ if len(m.supportedPoints) > 0 {
+ extensionsLength += 1 + len(m.supportedPoints)
+ numExtensions++
+ }
+ if m.ticketSupported {
+ extensionsLength += len(m.sessionTicket)
+ numExtensions++
+ }
+ if len(m.signatureAndHashes) > 0 {
+ extensionsLength += 2 + 2*len(m.signatureAndHashes)
+ numExtensions++
+ }
+ if m.secureRenegotiation {
+ extensionsLength += 1
+ numExtensions++
+ }
+ if len(m.alpnProtocols) > 0 {
+ extensionsLength += 2
+ for _, s := range m.alpnProtocols {
+ if l := len(s); l == 0 || l > 255 {
+ panic("invalid ALPN protocol")
+ }
+ extensionsLength++
+ extensionsLength += len(s)
+ }
+ numExtensions++
+ }
+ if numExtensions > 0 {
+ extensionsLength += 4 * numExtensions
+ length += 2 + extensionsLength
+ }
+
+ x := make([]byte, 4+length)
+ x[0] = typeClientHello
+ x[1] = uint8(length >> 16)
+ x[2] = uint8(length >> 8)
+ x[3] = uint8(length)
+ x[4] = uint8(m.vers >> 8)
+ x[5] = uint8(m.vers)
+ copy(x[6:38], m.random)
+ x[38] = uint8(len(m.sessionId))
+ copy(x[39:39+len(m.sessionId)], m.sessionId)
+ y := x[39+len(m.sessionId):]
+ y[0] = uint8(len(m.cipherSuites) >> 7)
+ y[1] = uint8(len(m.cipherSuites) << 1)
+ for i, suite := range m.cipherSuites {
+ y[2+i*2] = uint8(suite >> 8)
+ y[3+i*2] = uint8(suite)
+ }
+ z := y[2+len(m.cipherSuites)*2:]
+ z[0] = uint8(len(m.compressionMethods))
+ copy(z[1:], m.compressionMethods)
+
+ z = z[1+len(m.compressionMethods):]
+ if numExtensions > 0 {
+ z[0] = byte(extensionsLength >> 8)
+ z[1] = byte(extensionsLength)
+ z = z[2:]
+ }
+ if m.nextProtoNeg {
+ z[0] = byte(extensionNextProtoNeg >> 8)
+ z[1] = byte(extensionNextProtoNeg & 0xff)
+ // The length is always 0
+ z = z[4:]
+ }
+ if len(m.serverName) > 0 {
+ z[0] = byte(extensionServerName >> 8)
+ z[1] = byte(extensionServerName & 0xff)
+ l := len(m.serverName) + 5
+ z[2] = byte(l >> 8)
+ z[3] = byte(l)
+ z = z[4:]
+
+ // RFC 3546, section 3.1
+ //
+ // struct {
+ // NameType name_type;
+ // select (name_type) {
+ // case host_name: HostName;
+ // } name;
+ // } ServerName;
+ //
+ // enum {
+ // host_name(0), (255)
+ // } NameType;
+ //
+ // opaque HostName<1..2^16-1>;
+ //
+ // struct {
+ // ServerName server_name_list<1..2^16-1>
+ // } ServerNameList;
+
+ z[0] = byte((len(m.serverName) + 3) >> 8)
+ z[1] = byte(len(m.serverName) + 3)
+ z[3] = byte(len(m.serverName) >> 8)
+ z[4] = byte(len(m.serverName))
+ copy(z[5:], []byte(m.serverName))
+ z = z[l:]
+ }
+ if m.ocspStapling {
+ // RFC 4366, section 3.6
+ z[0] = byte(extensionStatusRequest >> 8)
+ z[1] = byte(extensionStatusRequest)
+ z[2] = 0
+ z[3] = 5
+ z[4] = 1 // OCSP type
+ // Two zero valued uint16s for the two lengths.
+ z = z[9:]
+ }
+ if len(m.supportedCurves) > 0 {
+ // http://tools.ietf.org/html/rfc4492#section-5.5.1
+ z[0] = byte(extensionSupportedCurves >> 8)
+ z[1] = byte(extensionSupportedCurves)
+ l := 2 + 2*len(m.supportedCurves)
+ z[2] = byte(l >> 8)
+ z[3] = byte(l)
+ l -= 2
+ z[4] = byte(l >> 8)
+ z[5] = byte(l)
+ z = z[6:]
+ for _, curve := range m.supportedCurves {
+ z[0] = byte(curve >> 8)
+ z[1] = byte(curve)
+ z = z[2:]
+ }
+ }
+ if len(m.supportedPoints) > 0 {
+ // http://tools.ietf.org/html/rfc4492#section-5.5.2
+ z[0] = byte(extensionSupportedPoints >> 8)
+ z[1] = byte(extensionSupportedPoints)
+ l := 1 + len(m.supportedPoints)
+ z[2] = byte(l >> 8)
+ z[3] = byte(l)
+ l--
+ z[4] = byte(l)
+ z = z[5:]
+ for _, pointFormat := range m.supportedPoints {
+ z[0] = byte(pointFormat)
+ z = z[1:]
+ }
+ }
+ if m.ticketSupported {
+ // http://tools.ietf.org/html/rfc5077#section-3.2
+ z[0] = byte(extensionSessionTicket >> 8)
+ z[1] = byte(extensionSessionTicket)
+ l := len(m.sessionTicket)
+ z[2] = byte(l >> 8)
+ z[3] = byte(l)
+ z = z[4:]
+ copy(z, m.sessionTicket)
+ z = z[len(m.sessionTicket):]
+ }
+ if len(m.signatureAndHashes) > 0 {
+ // https://tools.ietf.org/html/rfc5246#section-7.4.1.4.1
+ z[0] = byte(extensionSignatureAlgorithms >> 8)
+ z[1] = byte(extensionSignatureAlgorithms)
+ l := 2 + 2*len(m.signatureAndHashes)
+ z[2] = byte(l >> 8)
+ z[3] = byte(l)
+ z = z[4:]
+
+ l -= 2
+ z[0] = byte(l >> 8)
+ z[1] = byte(l)
+ z = z[2:]
+ for _, sigAndHash := range m.signatureAndHashes {
+ z[0] = sigAndHash.hash
+ z[1] = sigAndHash.signature
+ z = z[2:]
+ }
+ }
+ if m.secureRenegotiation {
+ z[0] = byte(extensionRenegotiationInfo >> 8)
+ z[1] = byte(extensionRenegotiationInfo & 0xff)
+ z[2] = 0
+ z[3] = 1
+ z = z[5:]
+ }
+ if len(m.alpnProtocols) > 0 {
+ z[0] = byte(extensionALPN >> 8)
+ z[1] = byte(extensionALPN & 0xff)
+ lengths := z[2:]
+ z = z[6:]
+
+ stringsLength := 0
+ for _, s := range m.alpnProtocols {
+ l := len(s)
+ z[0] = byte(l)
+ copy(z[1:], s)
+ z = z[1+l:]
+ stringsLength += 1 + l
+ }
+
+ lengths[2] = byte(stringsLength >> 8)
+ lengths[3] = byte(stringsLength)
+ stringsLength += 2
+ lengths[0] = byte(stringsLength >> 8)
+ lengths[1] = byte(stringsLength)
+ }
+
+ m.raw = x
+
+ return x
+}
+
+func (m *clientHelloMsg) unmarshal(data []byte) bool {
+ if len(data) < 42 {
+ return false
+ }
+ m.raw = data
+ m.vers = uint16(data[4])<<8 | uint16(data[5])
+ m.random = data[6:38]
+ sessionIdLen := int(data[38])
+ if sessionIdLen > 32 || len(data) < 39+sessionIdLen {
+ return false
+ }
+ m.sessionId = data[39 : 39+sessionIdLen]
+ data = data[39+sessionIdLen:]
+ if len(data) < 2 {
+ return false
+ }
+ // cipherSuiteLen is the number of bytes of cipher suite numbers. Since
+ // they are uint16s, the number must be even.
+ cipherSuiteLen := int(data[0])<<8 | int(data[1])
+ if cipherSuiteLen%2 == 1 || len(data) < 2+cipherSuiteLen {
+ return false
+ }
+ numCipherSuites := cipherSuiteLen / 2
+ m.cipherSuites = make([]uint16, numCipherSuites)
+ for i := 0; i < numCipherSuites; i++ {
+ m.cipherSuites[i] = uint16(data[2+2*i])<<8 | uint16(data[3+2*i])
+ if m.cipherSuites[i] == scsvRenegotiation {
+ m.secureRenegotiation = true
+ }
+ }
+ data = data[2+cipherSuiteLen:]
+ if len(data) < 1 {
+ return false
+ }
+ compressionMethodsLen := int(data[0])
+ if len(data) < 1+compressionMethodsLen {
+ return false
+ }
+ m.compressionMethods = data[1 : 1+compressionMethodsLen]
+
+ data = data[1+compressionMethodsLen:]
+
+ m.nextProtoNeg = false
+ m.serverName = ""
+ m.ocspStapling = false
+ m.ticketSupported = false
+ m.sessionTicket = nil
+ m.signatureAndHashes = nil
+ m.alpnProtocols = nil
+
+ if len(data) == 0 {
+ // ClientHello is optionally followed by extension data
+ return true
+ }
+ if len(data) < 2 {
+ return false
+ }
+
+ extensionsLength := int(data[0])<<8 | int(data[1])
+ data = data[2:]
+ if extensionsLength != len(data) {
+ return false
+ }
+
+ for len(data) != 0 {
+ if len(data) < 4 {
+ return false
+ }
+ extension := uint16(data[0])<<8 | uint16(data[1])
+ length := int(data[2])<<8 | int(data[3])
+ data = data[4:]
+ if len(data) < length {
+ return false
+ }
+
+ switch extension {
+ case extensionServerName:
+ if length < 2 {
+ return false
+ }
+ numNames := int(data[0])<<8 | int(data[1])
+ d := data[2:]
+ for i := 0; i < numNames; i++ {
+ if len(d) < 3 {
+ return false
+ }
+ nameType := d[0]
+ nameLen := int(d[1])<<8 | int(d[2])
+ d = d[3:]
+ if len(d) < nameLen {
+ return false
+ }
+ if nameType == 0 {
+ m.serverName = string(d[0:nameLen])
+ break
+ }
+ d = d[nameLen:]
+ }
+ case extensionNextProtoNeg:
+ if length > 0 {
+ return false
+ }
+ m.nextProtoNeg = true
+ case extensionStatusRequest:
+ m.ocspStapling = length > 0 && data[0] == statusTypeOCSP
+ case extensionSupportedCurves:
+ // http://tools.ietf.org/html/rfc4492#section-5.5.1
+ if length < 2 {
+ return false
+ }
+ l := int(data[0])<<8 | int(data[1])
+ if l%2 == 1 || length != l+2 {
+ return false
+ }
+ numCurves := l / 2
+ m.supportedCurves = make([]CurveID, numCurves)
+ d := data[2:]
+ for i := 0; i < numCurves; i++ {
+ m.supportedCurves[i] = CurveID(d[0])<<8 | CurveID(d[1])
+ d = d[2:]
+ }
+ case extensionSupportedPoints:
+ // http://tools.ietf.org/html/rfc4492#section-5.5.2
+ if length < 1 {
+ return false
+ }
+ l := int(data[0])
+ if length != l+1 {
+ return false
+ }
+ m.supportedPoints = make([]uint8, l)
+ copy(m.supportedPoints, data[1:])
+ case extensionSessionTicket:
+ // http://tools.ietf.org/html/rfc5077#section-3.2
+ m.ticketSupported = true
+ m.sessionTicket = data[:length]
+ case extensionSignatureAlgorithms:
+ // https://tools.ietf.org/html/rfc5246#section-7.4.1.4.1
+ if length < 2 || length&1 != 0 {
+ return false
+ }
+ l := int(data[0])<<8 | int(data[1])
+ if l != length-2 {
+ return false
+ }
+ n := l / 2
+ d := data[2:]
+ m.signatureAndHashes = make([]signatureAndHash, n)
+ for i := range m.signatureAndHashes {
+ m.signatureAndHashes[i].hash = d[0]
+ m.signatureAndHashes[i].signature = d[1]
+ d = d[2:]
+ }
+ case extensionRenegotiationInfo + 1:
+ if length != 1 || data[0] != 0 {
+ return false
+ }
+ m.secureRenegotiation = true
+ case extensionALPN:
+ if length < 2 {
+ return false
+ }
+ l := int(data[0])<<8 | int(data[1])
+ if l != length-2 {
+ return false
+ }
+ d := data[2:length]
+ for len(d) != 0 {
+ stringLen := int(d[0])
+ d = d[1:]
+ if stringLen == 0 || stringLen > len(d) {
+ return false
+ }
+ m.alpnProtocols = append(m.alpnProtocols, string(d[:stringLen]))
+ d = d[stringLen:]
+ }
+ }
+ data = data[length:]
+ }
+
+ return true
+}
+
+type serverHelloMsg struct {
+ raw []byte
+ vers uint16
+ random []byte
+ sessionId []byte
+ cipherSuite uint16
+ compressionMethod uint8
+ nextProtoNeg bool
+ nextProtos []string
+ ocspStapling bool
+ ticketSupported bool
+ secureRenegotiation bool
+ alpnProtocol string
+}
+
+func (m *serverHelloMsg) equal(i interface{}) bool {
+ m1, ok := i.(*serverHelloMsg)
+ if !ok {
+ return false
+ }
+
+ return bytes.Equal(m.raw, m1.raw) &&
+ m.vers == m1.vers &&
+ bytes.Equal(m.random, m1.random) &&
+ bytes.Equal(m.sessionId, m1.sessionId) &&
+ m.cipherSuite == m1.cipherSuite &&
+ m.compressionMethod == m1.compressionMethod &&
+ m.nextProtoNeg == m1.nextProtoNeg &&
+ eqStrings(m.nextProtos, m1.nextProtos) &&
+ m.ocspStapling == m1.ocspStapling &&
+ m.ticketSupported == m1.ticketSupported &&
+ m.secureRenegotiation == m1.secureRenegotiation &&
+ m.alpnProtocol == m1.alpnProtocol
+}
+
+func (m *serverHelloMsg) marshal() []byte {
+ if m.raw != nil {
+ return m.raw
+ }
+
+ length := 38 + len(m.sessionId)
+ numExtensions := 0
+ extensionsLength := 0
+
+ nextProtoLen := 0
+ if m.nextProtoNeg {
+ numExtensions++
+ for _, v := range m.nextProtos {
+ nextProtoLen += len(v)
+ }
+ nextProtoLen += len(m.nextProtos)
+ extensionsLength += nextProtoLen
+ }
+ if m.ocspStapling {
+ numExtensions++
+ }
+ if m.ticketSupported {
+ numExtensions++
+ }
+ if m.secureRenegotiation {
+ extensionsLength += 1
+ numExtensions++
+ }
+ if alpnLen := len(m.alpnProtocol); alpnLen > 0 {
+ if alpnLen >= 256 {
+ panic("invalid ALPN protocol")
+ }
+ extensionsLength += 2 + 1 + alpnLen
+ numExtensions++
+ }
+
+ if numExtensions > 0 {
+ extensionsLength += 4 * numExtensions
+ length += 2 + extensionsLength
+ }
+
+ x := make([]byte, 4+length)
+ x[0] = typeServerHello
+ x[1] = uint8(length >> 16)
+ x[2] = uint8(length >> 8)
+ x[3] = uint8(length)
+ x[4] = uint8(m.vers >> 8)
+ x[5] = uint8(m.vers)
+ copy(x[6:38], m.random)
+ x[38] = uint8(len(m.sessionId))
+ copy(x[39:39+len(m.sessionId)], m.sessionId)
+ z := x[39+len(m.sessionId):]
+ z[0] = uint8(m.cipherSuite >> 8)
+ z[1] = uint8(m.cipherSuite)
+ z[2] = uint8(m.compressionMethod)
+
+ z = z[3:]
+ if numExtensions > 0 {
+ z[0] = byte(extensionsLength >> 8)
+ z[1] = byte(extensionsLength)
+ z = z[2:]
+ }
+ if m.nextProtoNeg {
+ z[0] = byte(extensionNextProtoNeg >> 8)
+ z[1] = byte(extensionNextProtoNeg & 0xff)
+ z[2] = byte(nextProtoLen >> 8)
+ z[3] = byte(nextProtoLen)
+ z = z[4:]
+
+ for _, v := range m.nextProtos {
+ l := len(v)
+ if l > 255 {
+ l = 255
+ }
+ z[0] = byte(l)
+ copy(z[1:], []byte(v[0:l]))
+ z = z[1+l:]
+ }
+ }
+ if m.ocspStapling {
+ z[0] = byte(extensionStatusRequest >> 8)
+ z[1] = byte(extensionStatusRequest)
+ z = z[4:]
+ }
+ if m.ticketSupported {
+ z[0] = byte(extensionSessionTicket >> 8)
+ z[1] = byte(extensionSessionTicket)
+ z = z[4:]
+ }
+ if m.secureRenegotiation {
+ z[0] = byte(extensionRenegotiationInfo >> 8)
+ z[1] = byte(extensionRenegotiationInfo & 0xff)
+ z[2] = 0
+ z[3] = 1
+ z = z[5:]
+ }
+ if alpnLen := len(m.alpnProtocol); alpnLen > 0 {
+ z[0] = byte(extensionALPN >> 8)
+ z[1] = byte(extensionALPN & 0xff)
+ l := 2 + 1 + alpnLen
+ z[2] = byte(l >> 8)
+ z[3] = byte(l)
+ l -= 2
+ z[4] = byte(l >> 8)
+ z[5] = byte(l)
+ l -= 1
+ z[6] = byte(l)
+ copy(z[7:], []byte(m.alpnProtocol))
+ z = z[7+alpnLen:]
+ }
+
+ m.raw = x
+
+ return x
+}
+
+func (m *serverHelloMsg) unmarshal(data []byte) bool {
+ if len(data) < 42 {
+ return false
+ }
+ m.raw = data
+ m.vers = uint16(data[4])<<8 | uint16(data[5])
+ m.random = data[6:38]
+ sessionIdLen := int(data[38])
+ if sessionIdLen > 32 || len(data) < 39+sessionIdLen {
+ return false
+ }
+ m.sessionId = data[39 : 39+sessionIdLen]
+ data = data[39+sessionIdLen:]
+ if len(data) < 3 {
+ return false
+ }
+ m.cipherSuite = uint16(data[0])<<8 | uint16(data[1])
+ m.compressionMethod = data[2]
+ data = data[3:]
+
+ m.nextProtoNeg = false
+ m.nextProtos = nil
+ m.ocspStapling = false
+ m.ticketSupported = false
+ m.alpnProtocol = ""
+
+ if len(data) == 0 {
+ // ServerHello is optionally followed by extension data
+ return true
+ }
+ if len(data) < 2 {
+ return false
+ }
+
+ extensionsLength := int(data[0])<<8 | int(data[1])
+ data = data[2:]
+ if len(data) != extensionsLength {
+ return false
+ }
+
+ for len(data) != 0 {
+ if len(data) < 4 {
+ return false
+ }
+ extension := uint16(data[0])<<8 | uint16(data[1])
+ length := int(data[2])<<8 | int(data[3])
+ data = data[4:]
+ if len(data) < length {
+ return false
+ }
+
+ switch extension {
+ case extensionNextProtoNeg:
+ m.nextProtoNeg = true
+ d := data[:length]
+ for len(d) > 0 {
+ l := int(d[0])
+ d = d[1:]
+ if l == 0 || l > len(d) {
+ return false
+ }
+ m.nextProtos = append(m.nextProtos, string(d[:l]))
+ d = d[l:]
+ }
+ case extensionStatusRequest:
+ if length > 0 {
+ return false
+ }
+ m.ocspStapling = true
+ case extensionSessionTicket:
+ if length > 0 {
+ return false
+ }
+ m.ticketSupported = true
+ case extensionRenegotiationInfo:
+ if length != 1 || data[0] != 0 {
+ return false
+ }
+ m.secureRenegotiation = true
+ case extensionALPN:
+ d := data[:length]
+ if len(d) < 3 {
+ return false
+ }
+ l := int(d[0])<<8 | int(d[1])
+ if l != len(d)-2 {
+ return false
+ }
+ d = d[2:]
+ l = int(d[0])
+ if l != len(d)-1 {
+ return false
+ }
+ d = d[1:]
+ m.alpnProtocol = string(d)
+ }
+ data = data[length:]
+ }
+
+ return true
+}
+
+type certificateMsg struct {
+ raw []byte
+ certificates [][]byte
+}
+
+func (m *certificateMsg) equal(i interface{}) bool {
+ m1, ok := i.(*certificateMsg)
+ if !ok {
+ return false
+ }
+
+ return bytes.Equal(m.raw, m1.raw) &&
+ eqByteSlices(m.certificates, m1.certificates)
+}
+
+func (m *certificateMsg) marshal() (x []byte) {
+ if m.raw != nil {
+ return m.raw
+ }
+
+ var i int
+ for _, slice := range m.certificates {
+ i += len(slice)
+ }
+
+ length := 3 + 3*len(m.certificates) + i
+ x = make([]byte, 4+length)
+ x[0] = typeCertificate
+ x[1] = uint8(length >> 16)
+ x[2] = uint8(length >> 8)
+ x[3] = uint8(length)
+
+ certificateOctets := length - 3
+ x[4] = uint8(certificateOctets >> 16)
+ x[5] = uint8(certificateOctets >> 8)
+ x[6] = uint8(certificateOctets)
+
+ y := x[7:]
+ for _, slice := range m.certificates {
+ y[0] = uint8(len(slice) >> 16)
+ y[1] = uint8(len(slice) >> 8)
+ y[2] = uint8(len(slice))
+ copy(y[3:], slice)
+ y = y[3+len(slice):]
+ }
+
+ m.raw = x
+ return
+}
+
+func (m *certificateMsg) unmarshal(data []byte) bool {
+ if len(data) < 7 {
+ return false
+ }
+
+ m.raw = data
+ certsLen := uint32(data[4])<<16 | uint32(data[5])<<8 | uint32(data[6])
+ if uint32(len(data)) != certsLen+7 {
+ return false
+ }
+
+ numCerts := 0
+ d := data[7:]
+ for certsLen > 0 {
+ if len(d) < 4 {
+ return false
+ }
+ certLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2])
+ if uint32(len(d)) < 3+certLen {
+ return false
+ }
+ d = d[3+certLen:]
+ certsLen -= 3 + certLen
+ numCerts++
+ }
+
+ m.certificates = make([][]byte, numCerts)
+ d = data[7:]
+ for i := 0; i < numCerts; i++ {
+ certLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2])
+ m.certificates[i] = d[3 : 3+certLen]
+ d = d[3+certLen:]
+ }
+
+ return true
+}
+
+type serverKeyExchangeMsg struct {
+ raw []byte
+ key []byte
+}
+
+func (m *serverKeyExchangeMsg) equal(i interface{}) bool {
+ m1, ok := i.(*serverKeyExchangeMsg)
+ if !ok {
+ return false
+ }
+
+ return bytes.Equal(m.raw, m1.raw) &&
+ bytes.Equal(m.key, m1.key)
+}
+
+func (m *serverKeyExchangeMsg) marshal() []byte {
+ if m.raw != nil {
+ return m.raw
+ }
+ length := len(m.key)
+ x := make([]byte, length+4)
+ x[0] = typeServerKeyExchange
+ x[1] = uint8(length >> 16)
+ x[2] = uint8(length >> 8)
+ x[3] = uint8(length)
+ copy(x[4:], m.key)
+
+ m.raw = x
+ return x
+}
+
+func (m *serverKeyExchangeMsg) unmarshal(data []byte) bool {
+ m.raw = data
+ if len(data) < 4 {
+ return false
+ }
+ m.key = data[4:]
+ return true
+}
+
+type certificateStatusMsg struct {
+ raw []byte
+ statusType uint8
+ response []byte
+}
+
+func (m *certificateStatusMsg) equal(i interface{}) bool {
+ m1, ok := i.(*certificateStatusMsg)
+ if !ok {
+ return false
+ }
+
+ return bytes.Equal(m.raw, m1.raw) &&
+ m.statusType == m1.statusType &&
+ bytes.Equal(m.response, m1.response)
+}
+
+func (m *certificateStatusMsg) marshal() []byte {
+ if m.raw != nil {
+ return m.raw
+ }
+
+ var x []byte
+ if m.statusType == statusTypeOCSP {
+ x = make([]byte, 4+4+len(m.response))
+ x[0] = typeCertificateStatus
+ l := len(m.response) + 4
+ x[1] = byte(l >> 16)
+ x[2] = byte(l >> 8)
+ x[3] = byte(l)
+ x[4] = statusTypeOCSP
+
+ l -= 4
+ x[5] = byte(l >> 16)
+ x[6] = byte(l >> 8)
+ x[7] = byte(l)
+ copy(x[8:], m.response)
+ } else {
+ x = []byte{typeCertificateStatus, 0, 0, 1, m.statusType}
+ }
+
+ m.raw = x
+ return x
+}
+
+func (m *certificateStatusMsg) unmarshal(data []byte) bool {
+ m.raw = data
+ if len(data) < 5 {
+ return false
+ }
+ m.statusType = data[4]
+
+ m.response = nil
+ if m.statusType == statusTypeOCSP {
+ if len(data) < 8 {
+ return false
+ }
+ respLen := uint32(data[5])<<16 | uint32(data[6])<<8 | uint32(data[7])
+ if uint32(len(data)) != 4+4+respLen {
+ return false
+ }
+ m.response = data[8:]
+ }
+ return true
+}
+
+type serverHelloDoneMsg struct{}
+
+func (m *serverHelloDoneMsg) equal(i interface{}) bool {
+ _, ok := i.(*serverHelloDoneMsg)
+ return ok
+}
+
+func (m *serverHelloDoneMsg) marshal() []byte {
+ x := make([]byte, 4)
+ x[0] = typeServerHelloDone
+ return x
+}
+
+func (m *serverHelloDoneMsg) unmarshal(data []byte) bool {
+ return len(data) == 4
+}
+
+type clientKeyExchangeMsg struct {
+ raw []byte
+ ciphertext []byte
+}
+
+func (m *clientKeyExchangeMsg) equal(i interface{}) bool {
+ m1, ok := i.(*clientKeyExchangeMsg)
+ if !ok {
+ return false
+ }
+
+ return bytes.Equal(m.raw, m1.raw) &&
+ bytes.Equal(m.ciphertext, m1.ciphertext)
+}
+
+func (m *clientKeyExchangeMsg) marshal() []byte {
+ if m.raw != nil {
+ return m.raw
+ }
+ length := len(m.ciphertext)
+ x := make([]byte, length+4)
+ x[0] = typeClientKeyExchange
+ x[1] = uint8(length >> 16)
+ x[2] = uint8(length >> 8)
+ x[3] = uint8(length)
+ copy(x[4:], m.ciphertext)
+
+ m.raw = x
+ return x
+}
+
+func (m *clientKeyExchangeMsg) unmarshal(data []byte) bool {
+ m.raw = data
+ if len(data) < 4 {
+ return false
+ }
+ l := int(data[1])<<16 | int(data[2])<<8 | int(data[3])
+ if l != len(data)-4 {
+ return false
+ }
+ m.ciphertext = data[4:]
+ return true
+}
+
+type finishedMsg struct {
+ raw []byte
+ verifyData []byte
+}
+
+func (m *finishedMsg) equal(i interface{}) bool {
+ m1, ok := i.(*finishedMsg)
+ if !ok {
+ return false
+ }
+
+ return bytes.Equal(m.raw, m1.raw) &&
+ bytes.Equal(m.verifyData, m1.verifyData)
+}
+
+func (m *finishedMsg) marshal() (x []byte) {
+ if m.raw != nil {
+ return m.raw
+ }
+
+ x = make([]byte, 4+len(m.verifyData))
+ x[0] = typeFinished
+ x[3] = byte(len(m.verifyData))
+ copy(x[4:], m.verifyData)
+ m.raw = x
+ return
+}
+
+func (m *finishedMsg) unmarshal(data []byte) bool {
+ m.raw = data
+ if len(data) < 4 {
+ return false
+ }
+ m.verifyData = data[4:]
+ return true
+}
+
+type nextProtoMsg struct {
+ raw []byte
+ proto string
+}
+
+func (m *nextProtoMsg) equal(i interface{}) bool {
+ m1, ok := i.(*nextProtoMsg)
+ if !ok {
+ return false
+ }
+
+ return bytes.Equal(m.raw, m1.raw) &&
+ m.proto == m1.proto
+}
+
+func (m *nextProtoMsg) marshal() []byte {
+ if m.raw != nil {
+ return m.raw
+ }
+ l := len(m.proto)
+ if l > 255 {
+ l = 255
+ }
+
+ padding := 32 - (l+2)%32
+ length := l + padding + 2
+ x := make([]byte, length+4)
+ x[0] = typeNextProtocol
+ x[1] = uint8(length >> 16)
+ x[2] = uint8(length >> 8)
+ x[3] = uint8(length)
+
+ y := x[4:]
+ y[0] = byte(l)
+ copy(y[1:], []byte(m.proto[0:l]))
+ y = y[1+l:]
+ y[0] = byte(padding)
+
+ m.raw = x
+
+ return x
+}
+
+func (m *nextProtoMsg) unmarshal(data []byte) bool {
+ m.raw = data
+
+ if len(data) < 5 {
+ return false
+ }
+ data = data[4:]
+ protoLen := int(data[0])
+ data = data[1:]
+ if len(data) < protoLen {
+ return false
+ }
+ m.proto = string(data[0:protoLen])
+ data = data[protoLen:]
+
+ if len(data) < 1 {
+ return false
+ }
+ paddingLen := int(data[0])
+ data = data[1:]
+ if len(data) != paddingLen {
+ return false
+ }
+
+ return true
+}
+
+type certificateRequestMsg struct {
+ raw []byte
+ // hasSignatureAndHash indicates whether this message includes a list
+ // of signature and hash functions. This change was introduced with TLS
+ // 1.2.
+ hasSignatureAndHash bool
+
+ certificateTypes []byte
+ signatureAndHashes []signatureAndHash
+ certificateAuthorities [][]byte
+}
+
+func (m *certificateRequestMsg) equal(i interface{}) bool {
+ m1, ok := i.(*certificateRequestMsg)
+ if !ok {
+ return false
+ }
+
+ return bytes.Equal(m.raw, m1.raw) &&
+ bytes.Equal(m.certificateTypes, m1.certificateTypes) &&
+ eqByteSlices(m.certificateAuthorities, m1.certificateAuthorities) &&
+ eqSignatureAndHashes(m.signatureAndHashes, m1.signatureAndHashes)
+}
+
+func (m *certificateRequestMsg) marshal() (x []byte) {
+ if m.raw != nil {
+ return m.raw
+ }
+
+ // See http://tools.ietf.org/html/rfc4346#section-7.4.4
+ length := 1 + len(m.certificateTypes) + 2
+ casLength := 0
+ for _, ca := range m.certificateAuthorities {
+ casLength += 2 + len(ca)
+ }
+ length += casLength
+
+ if m.hasSignatureAndHash {
+ length += 2 + 2*len(m.signatureAndHashes)
+ }
+
+ x = make([]byte, 4+length)
+ x[0] = typeCertificateRequest
+ x[1] = uint8(length >> 16)
+ x[2] = uint8(length >> 8)
+ x[3] = uint8(length)
+
+ x[4] = uint8(len(m.certificateTypes))
+
+ copy(x[5:], m.certificateTypes)
+ y := x[5+len(m.certificateTypes):]
+
+ if m.hasSignatureAndHash {
+ n := len(m.signatureAndHashes) * 2
+ y[0] = uint8(n >> 8)
+ y[1] = uint8(n)
+ y = y[2:]
+ for _, sigAndHash := range m.signatureAndHashes {
+ y[0] = sigAndHash.hash
+ y[1] = sigAndHash.signature
+ y = y[2:]
+ }
+ }
+
+ y[0] = uint8(casLength >> 8)
+ y[1] = uint8(casLength)
+ y = y[2:]
+ for _, ca := range m.certificateAuthorities {
+ y[0] = uint8(len(ca) >> 8)
+ y[1] = uint8(len(ca))
+ y = y[2:]
+ copy(y, ca)
+ y = y[len(ca):]
+ }
+
+ m.raw = x
+ return
+}
+
+func (m *certificateRequestMsg) unmarshal(data []byte) bool {
+ m.raw = data
+
+ if len(data) < 5 {
+ return false
+ }
+
+ length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3])
+ if uint32(len(data))-4 != length {
+ return false
+ }
+
+ numCertTypes := int(data[4])
+ data = data[5:]
+ if numCertTypes == 0 || len(data) <= numCertTypes {
+ return false
+ }
+
+ m.certificateTypes = make([]byte, numCertTypes)
+ if copy(m.certificateTypes, data) != numCertTypes {
+ return false
+ }
+
+ data = data[numCertTypes:]
+
+ if m.hasSignatureAndHash {
+ if len(data) < 2 {
+ return false
+ }
+ sigAndHashLen := uint16(data[0])<<8 | uint16(data[1])
+ data = data[2:]
+ if sigAndHashLen&1 != 0 {
+ return false
+ }
+ if len(data) < int(sigAndHashLen) {
+ return false
+ }
+ numSigAndHash := sigAndHashLen / 2
+ m.signatureAndHashes = make([]signatureAndHash, numSigAndHash)
+ for i := range m.signatureAndHashes {
+ m.signatureAndHashes[i].hash = data[0]
+ m.signatureAndHashes[i].signature = data[1]
+ data = data[2:]
+ }
+ }
+
+ if len(data) < 2 {
+ return false
+ }
+ casLength := uint16(data[0])<<8 | uint16(data[1])
+ data = data[2:]
+ if len(data) < int(casLength) {
+ return false
+ }
+ cas := make([]byte, casLength)
+ copy(cas, data)
+ data = data[casLength:]
+
+ m.certificateAuthorities = nil
+ for len(cas) > 0 {
+ if len(cas) < 2 {
+ return false
+ }
+ caLen := uint16(cas[0])<<8 | uint16(cas[1])
+ cas = cas[2:]
+
+ if len(cas) < int(caLen) {
+ return false
+ }
+
+ m.certificateAuthorities = append(m.certificateAuthorities, cas[:caLen])
+ cas = cas[caLen:]
+ }
+ if len(data) > 0 {
+ return false
+ }
+
+ return true
+}
+
+type certificateVerifyMsg struct {
+ raw []byte
+ hasSignatureAndHash bool
+ signatureAndHash signatureAndHash
+ signature []byte
+}
+
+func (m *certificateVerifyMsg) equal(i interface{}) bool {
+ m1, ok := i.(*certificateVerifyMsg)
+ if !ok {
+ return false
+ }
+
+ return bytes.Equal(m.raw, m1.raw) &&
+ m.hasSignatureAndHash == m1.hasSignatureAndHash &&
+ m.signatureAndHash.hash == m1.signatureAndHash.hash &&
+ m.signatureAndHash.signature == m1.signatureAndHash.signature &&
+ bytes.Equal(m.signature, m1.signature)
+}
+
+func (m *certificateVerifyMsg) marshal() (x []byte) {
+ if m.raw != nil {
+ return m.raw
+ }
+
+ // See http://tools.ietf.org/html/rfc4346#section-7.4.8
+ siglength := len(m.signature)
+ length := 2 + siglength
+ if m.hasSignatureAndHash {
+ length += 2
+ }
+ x = make([]byte, 4+length)
+ x[0] = typeCertificateVerify
+ x[1] = uint8(length >> 16)
+ x[2] = uint8(length >> 8)
+ x[3] = uint8(length)
+ y := x[4:]
+ if m.hasSignatureAndHash {
+ y[0] = m.signatureAndHash.hash
+ y[1] = m.signatureAndHash.signature
+ y = y[2:]
+ }
+ y[0] = uint8(siglength >> 8)
+ y[1] = uint8(siglength)
+ copy(y[2:], m.signature)
+
+ m.raw = x
+
+ return
+}
+
+func (m *certificateVerifyMsg) unmarshal(data []byte) bool {
+ m.raw = data
+
+ if len(data) < 6 {
+ return false
+ }
+
+ length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3])
+ if uint32(len(data))-4 != length {
+ return false
+ }
+
+ data = data[4:]
+ if m.hasSignatureAndHash {
+ m.signatureAndHash.hash = data[0]
+ m.signatureAndHash.signature = data[1]
+ data = data[2:]
+ }
+
+ if len(data) < 2 {
+ return false
+ }
+ siglength := int(data[0])<<8 + int(data[1])
+ data = data[2:]
+ if len(data) != siglength {
+ return false
+ }
+
+ m.signature = data
+
+ return true
+}
+
+type newSessionTicketMsg struct {
+ raw []byte
+ ticket []byte
+}
+
+func (m *newSessionTicketMsg) equal(i interface{}) bool {
+ m1, ok := i.(*newSessionTicketMsg)
+ if !ok {
+ return false
+ }
+
+ return bytes.Equal(m.raw, m1.raw) &&
+ bytes.Equal(m.ticket, m1.ticket)
+}
+
+func (m *newSessionTicketMsg) marshal() (x []byte) {
+ if m.raw != nil {
+ return m.raw
+ }
+
+ // See http://tools.ietf.org/html/rfc5077#section-3.3
+ ticketLen := len(m.ticket)
+ length := 2 + 4 + ticketLen
+ x = make([]byte, 4+length)
+ x[0] = typeNewSessionTicket
+ x[1] = uint8(length >> 16)
+ x[2] = uint8(length >> 8)
+ x[3] = uint8(length)
+ x[8] = uint8(ticketLen >> 8)
+ x[9] = uint8(ticketLen)
+ copy(x[10:], m.ticket)
+
+ m.raw = x
+
+ return
+}
+
+func (m *newSessionTicketMsg) unmarshal(data []byte) bool {
+ m.raw = data
+
+ if len(data) < 10 {
+ return false
+ }
+
+ length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3])
+ if uint32(len(data))-4 != length {
+ return false
+ }
+
+ ticketLen := int(data[8])<<8 + int(data[9])
+ if len(data)-10 != ticketLen {
+ return false
+ }
+
+ m.ticket = data[10:]
+
+ return true
+}
+
+func eqUint16s(x, y []uint16) bool {
+ if len(x) != len(y) {
+ return false
+ }
+ for i, v := range x {
+ if y[i] != v {
+ return false
+ }
+ }
+ return true
+}
+
+func eqCurveIDs(x, y []CurveID) bool {
+ if len(x) != len(y) {
+ return false
+ }
+ for i, v := range x {
+ if y[i] != v {
+ return false
+ }
+ }
+ return true
+}
+
+func eqStrings(x, y []string) bool {
+ if len(x) != len(y) {
+ return false
+ }
+ for i, v := range x {
+ if y[i] != v {
+ return false
+ }
+ }
+ return true
+}
+
+func eqByteSlices(x, y [][]byte) bool {
+ if len(x) != len(y) {
+ return false
+ }
+ for i, v := range x {
+ if !bytes.Equal(v, y[i]) {
+ return false
+ }
+ }
+ return true
+}
+
+func eqSignatureAndHashes(x, y []signatureAndHash) bool {
+ if len(x) != len(y) {
+ return false
+ }
+ for i, v := range x {
+ v2 := y[i]
+ if v.hash != v2.hash || v.signature != v2.signature {
+ return false
+ }
+ }
+ return true
+}