diff options
Diffstat (limited to 'libgo/go/crypto/tls/handshake_messages.go')
-rw-r--r-- | libgo/go/crypto/tls/handshake_messages.go | 98 |
1 files changed, 96 insertions, 2 deletions
diff --git a/libgo/go/crypto/tls/handshake_messages.go b/libgo/go/crypto/tls/handshake_messages.go index 7bcaa5eb929..5d14871a348 100644 --- a/libgo/go/crypto/tls/handshake_messages.go +++ b/libgo/go/crypto/tls/handshake_messages.go @@ -22,6 +22,7 @@ type clientHelloMsg struct { sessionTicket []uint8 signatureAndHashes []signatureAndHash secureRenegotiation bool + alpnProtocols []string } func (m *clientHelloMsg) equal(i interface{}) bool { @@ -44,7 +45,8 @@ func (m *clientHelloMsg) equal(i interface{}) bool { m.ticketSupported == m1.ticketSupported && bytes.Equal(m.sessionTicket, m1.sessionTicket) && eqSignatureAndHashes(m.signatureAndHashes, m1.signatureAndHashes) && - m.secureRenegotiation == m1.secureRenegotiation + m.secureRenegotiation == m1.secureRenegotiation && + eqStrings(m.alpnProtocols, m1.alpnProtocols) } func (m *clientHelloMsg) marshal() []byte { @@ -86,6 +88,17 @@ func (m *clientHelloMsg) marshal() []byte { extensionsLength += 1 numExtensions++ } + if len(m.alpnProtocols) > 0 { + extensionsLength += 2 + for _, s := range m.alpnProtocols { + if l := len(s); l == 0 || l > 255 { + panic("invalid ALPN protocol") + } + extensionsLength++ + extensionsLength += len(s) + } + numExtensions++ + } if numExtensions > 0 { extensionsLength += 4 * numExtensions length += 2 + extensionsLength @@ -237,6 +250,27 @@ func (m *clientHelloMsg) marshal() []byte { z[3] = 1 z = z[5:] } + if len(m.alpnProtocols) > 0 { + z[0] = byte(extensionALPN >> 8) + z[1] = byte(extensionALPN & 0xff) + lengths := z[2:] + z = z[6:] + + stringsLength := 0 + for _, s := range m.alpnProtocols { + l := len(s) + z[0] = byte(l) + copy(z[1:], s) + z = z[1+l:] + stringsLength += 1 + l + } + + lengths[2] = byte(stringsLength >> 8) + lengths[3] = byte(stringsLength) + stringsLength += 2 + lengths[0] = byte(stringsLength >> 8) + lengths[1] = byte(stringsLength) + } m.raw = x @@ -291,6 +325,7 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool { m.ticketSupported = false m.sessionTicket = nil m.signatureAndHashes = nil + m.alpnProtocols = nil if len(data) == 0 { // ClientHello is optionally followed by extension data @@ -400,6 +435,24 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool { return false } m.secureRenegotiation = true + case extensionALPN: + if length < 2 { + return false + } + l := int(data[0])<<8 | int(data[1]) + if l != length-2 { + return false + } + d := data[2:length] + for len(d) != 0 { + stringLen := int(d[0]) + d = d[1:] + if stringLen == 0 || stringLen > len(d) { + return false + } + m.alpnProtocols = append(m.alpnProtocols, string(d[:stringLen])) + d = d[stringLen:] + } } data = data[length:] } @@ -419,6 +472,7 @@ type serverHelloMsg struct { ocspStapling bool ticketSupported bool secureRenegotiation bool + alpnProtocol string } func (m *serverHelloMsg) equal(i interface{}) bool { @@ -437,7 +491,8 @@ func (m *serverHelloMsg) equal(i interface{}) bool { eqStrings(m.nextProtos, m1.nextProtos) && m.ocspStapling == m1.ocspStapling && m.ticketSupported == m1.ticketSupported && - m.secureRenegotiation == m1.secureRenegotiation + m.secureRenegotiation == m1.secureRenegotiation && + m.alpnProtocol == m1.alpnProtocol } func (m *serverHelloMsg) marshal() []byte { @@ -468,6 +523,14 @@ func (m *serverHelloMsg) marshal() []byte { extensionsLength += 1 numExtensions++ } + if alpnLen := len(m.alpnProtocol); alpnLen > 0 { + if alpnLen >= 256 { + panic("invalid ALPN protocol") + } + extensionsLength += 2 + 1 + alpnLen + numExtensions++ + } + if numExtensions > 0 { extensionsLength += 4 * numExtensions length += 2 + extensionsLength @@ -528,6 +591,20 @@ func (m *serverHelloMsg) marshal() []byte { z[3] = 1 z = z[5:] } + if alpnLen := len(m.alpnProtocol); alpnLen > 0 { + z[0] = byte(extensionALPN >> 8) + z[1] = byte(extensionALPN & 0xff) + l := 2 + 1 + alpnLen + z[2] = byte(l >> 8) + z[3] = byte(l) + l -= 2 + z[4] = byte(l >> 8) + z[5] = byte(l) + l -= 1 + z[6] = byte(l) + copy(z[7:], []byte(m.alpnProtocol)) + z = z[7+alpnLen:] + } m.raw = x @@ -558,6 +635,7 @@ func (m *serverHelloMsg) unmarshal(data []byte) bool { m.nextProtos = nil m.ocspStapling = false m.ticketSupported = false + m.alpnProtocol = "" if len(data) == 0 { // ServerHello is optionally followed by extension data @@ -612,6 +690,22 @@ func (m *serverHelloMsg) unmarshal(data []byte) bool { return false } m.secureRenegotiation = true + case extensionALPN: + d := data[:length] + if len(d) < 3 { + return false + } + l := int(d[0])<<8 | int(d[1]) + if l != len(d)-2 { + return false + } + d = d[2:] + l = int(d[0]) + if l != len(d)-1 { + return false + } + d = d[1:] + m.alpnProtocol = string(d) } data = data[length:] } |