summaryrefslogtreecommitdiff
path: root/libgo/go/crypto/tls/handshake_messages.go
diff options
context:
space:
mode:
Diffstat (limited to 'libgo/go/crypto/tls/handshake_messages.go')
-rw-r--r--libgo/go/crypto/tls/handshake_messages.go98
1 files changed, 96 insertions, 2 deletions
diff --git a/libgo/go/crypto/tls/handshake_messages.go b/libgo/go/crypto/tls/handshake_messages.go
index 7bcaa5eb929..5d14871a348 100644
--- a/libgo/go/crypto/tls/handshake_messages.go
+++ b/libgo/go/crypto/tls/handshake_messages.go
@@ -22,6 +22,7 @@ type clientHelloMsg struct {
sessionTicket []uint8
signatureAndHashes []signatureAndHash
secureRenegotiation bool
+ alpnProtocols []string
}
func (m *clientHelloMsg) equal(i interface{}) bool {
@@ -44,7 +45,8 @@ func (m *clientHelloMsg) equal(i interface{}) bool {
m.ticketSupported == m1.ticketSupported &&
bytes.Equal(m.sessionTicket, m1.sessionTicket) &&
eqSignatureAndHashes(m.signatureAndHashes, m1.signatureAndHashes) &&
- m.secureRenegotiation == m1.secureRenegotiation
+ m.secureRenegotiation == m1.secureRenegotiation &&
+ eqStrings(m.alpnProtocols, m1.alpnProtocols)
}
func (m *clientHelloMsg) marshal() []byte {
@@ -86,6 +88,17 @@ func (m *clientHelloMsg) marshal() []byte {
extensionsLength += 1
numExtensions++
}
+ if len(m.alpnProtocols) > 0 {
+ extensionsLength += 2
+ for _, s := range m.alpnProtocols {
+ if l := len(s); l == 0 || l > 255 {
+ panic("invalid ALPN protocol")
+ }
+ extensionsLength++
+ extensionsLength += len(s)
+ }
+ numExtensions++
+ }
if numExtensions > 0 {
extensionsLength += 4 * numExtensions
length += 2 + extensionsLength
@@ -237,6 +250,27 @@ func (m *clientHelloMsg) marshal() []byte {
z[3] = 1
z = z[5:]
}
+ if len(m.alpnProtocols) > 0 {
+ z[0] = byte(extensionALPN >> 8)
+ z[1] = byte(extensionALPN & 0xff)
+ lengths := z[2:]
+ z = z[6:]
+
+ stringsLength := 0
+ for _, s := range m.alpnProtocols {
+ l := len(s)
+ z[0] = byte(l)
+ copy(z[1:], s)
+ z = z[1+l:]
+ stringsLength += 1 + l
+ }
+
+ lengths[2] = byte(stringsLength >> 8)
+ lengths[3] = byte(stringsLength)
+ stringsLength += 2
+ lengths[0] = byte(stringsLength >> 8)
+ lengths[1] = byte(stringsLength)
+ }
m.raw = x
@@ -291,6 +325,7 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool {
m.ticketSupported = false
m.sessionTicket = nil
m.signatureAndHashes = nil
+ m.alpnProtocols = nil
if len(data) == 0 {
// ClientHello is optionally followed by extension data
@@ -400,6 +435,24 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool {
return false
}
m.secureRenegotiation = true
+ case extensionALPN:
+ if length < 2 {
+ return false
+ }
+ l := int(data[0])<<8 | int(data[1])
+ if l != length-2 {
+ return false
+ }
+ d := data[2:length]
+ for len(d) != 0 {
+ stringLen := int(d[0])
+ d = d[1:]
+ if stringLen == 0 || stringLen > len(d) {
+ return false
+ }
+ m.alpnProtocols = append(m.alpnProtocols, string(d[:stringLen]))
+ d = d[stringLen:]
+ }
}
data = data[length:]
}
@@ -419,6 +472,7 @@ type serverHelloMsg struct {
ocspStapling bool
ticketSupported bool
secureRenegotiation bool
+ alpnProtocol string
}
func (m *serverHelloMsg) equal(i interface{}) bool {
@@ -437,7 +491,8 @@ func (m *serverHelloMsg) equal(i interface{}) bool {
eqStrings(m.nextProtos, m1.nextProtos) &&
m.ocspStapling == m1.ocspStapling &&
m.ticketSupported == m1.ticketSupported &&
- m.secureRenegotiation == m1.secureRenegotiation
+ m.secureRenegotiation == m1.secureRenegotiation &&
+ m.alpnProtocol == m1.alpnProtocol
}
func (m *serverHelloMsg) marshal() []byte {
@@ -468,6 +523,14 @@ func (m *serverHelloMsg) marshal() []byte {
extensionsLength += 1
numExtensions++
}
+ if alpnLen := len(m.alpnProtocol); alpnLen > 0 {
+ if alpnLen >= 256 {
+ panic("invalid ALPN protocol")
+ }
+ extensionsLength += 2 + 1 + alpnLen
+ numExtensions++
+ }
+
if numExtensions > 0 {
extensionsLength += 4 * numExtensions
length += 2 + extensionsLength
@@ -528,6 +591,20 @@ func (m *serverHelloMsg) marshal() []byte {
z[3] = 1
z = z[5:]
}
+ if alpnLen := len(m.alpnProtocol); alpnLen > 0 {
+ z[0] = byte(extensionALPN >> 8)
+ z[1] = byte(extensionALPN & 0xff)
+ l := 2 + 1 + alpnLen
+ z[2] = byte(l >> 8)
+ z[3] = byte(l)
+ l -= 2
+ z[4] = byte(l >> 8)
+ z[5] = byte(l)
+ l -= 1
+ z[6] = byte(l)
+ copy(z[7:], []byte(m.alpnProtocol))
+ z = z[7+alpnLen:]
+ }
m.raw = x
@@ -558,6 +635,7 @@ func (m *serverHelloMsg) unmarshal(data []byte) bool {
m.nextProtos = nil
m.ocspStapling = false
m.ticketSupported = false
+ m.alpnProtocol = ""
if len(data) == 0 {
// ServerHello is optionally followed by extension data
@@ -612,6 +690,22 @@ func (m *serverHelloMsg) unmarshal(data []byte) bool {
return false
}
m.secureRenegotiation = true
+ case extensionALPN:
+ d := data[:length]
+ if len(d) < 3 {
+ return false
+ }
+ l := int(d[0])<<8 | int(d[1])
+ if l != len(d)-2 {
+ return false
+ }
+ d = d[2:]
+ l = int(d[0])
+ if l != len(d)-1 {
+ return false
+ }
+ d = d[1:]
+ m.alpnProtocol = string(d)
}
data = data[length:]
}