diff options
Diffstat (limited to 'libgo/go/crypto/tls/conn.go')
-rw-r--r-- | libgo/go/crypto/tls/conn.go | 530 |
1 files changed, 413 insertions, 117 deletions
diff --git a/libgo/go/crypto/tls/conn.go b/libgo/go/crypto/tls/conn.go index 03775685fb..03895a723f 100644 --- a/libgo/go/crypto/tls/conn.go +++ b/libgo/go/crypto/tls/conn.go @@ -28,34 +28,71 @@ type Conn struct { isClient bool // constant after handshake; protected by handshakeMutex - handshakeMutex sync.Mutex // handshakeMutex < in.Mutex, out.Mutex, errMutex - handshakeErr error // error resulting from handshake - vers uint16 // TLS version - haveVers bool // version has been negotiated - config *Config // configuration passed to constructor + handshakeMutex sync.Mutex // handshakeMutex < in.Mutex, out.Mutex, errMutex + // handshakeCond, if not nil, indicates that a goroutine is committed + // to running the handshake for this Conn. Other goroutines that need + // to wait for the handshake can wait on this, under handshakeMutex. + handshakeCond *sync.Cond + handshakeErr error // error resulting from handshake + vers uint16 // TLS version + haveVers bool // version has been negotiated + config *Config // configuration passed to constructor + // handshakeComplete is true if the connection is currently transferring + // application data (i.e. is not currently processing a handshake). handshakeComplete bool - didResume bool // whether this connection was a session resumption - cipherSuite uint16 - ocspResponse []byte // stapled OCSP response - scts [][]byte // signed certificate timestamps from server - peerCertificates []*x509.Certificate + // handshakes counts the number of handshakes performed on the + // connection so far. If renegotiation is disabled then this is either + // zero or one. + handshakes int + didResume bool // whether this connection was a session resumption + cipherSuite uint16 + ocspResponse []byte // stapled OCSP response + scts [][]byte // signed certificate timestamps from server + peerCertificates []*x509.Certificate // verifiedChains contains the certificate chains that we built, as // opposed to the ones presented by the server. verifiedChains [][]*x509.Certificate // serverName contains the server name indicated by the client, if any. serverName string - // firstFinished contains the first Finished hash sent during the - // handshake. This is the "tls-unique" channel binding value. - firstFinished [12]byte + // secureRenegotiation is true if the server echoed the secure + // renegotiation extension. (This is meaningless as a server because + // renegotiation is not supported in that case.) + secureRenegotiation bool + + // clientFinishedIsFirst is true if the client sent the first Finished + // message during the most recent handshake. This is recorded because + // the first transmitted Finished message is the tls-unique + // channel-binding value. + clientFinishedIsFirst bool + + // closeNotifyErr is any error from sending the alertCloseNotify record. + closeNotifyErr error + // closeNotifySent is true if the Conn attempted to send an + // alertCloseNotify record. + closeNotifySent bool + + // clientFinished and serverFinished contain the Finished message sent + // by the client or server in the most recent handshake. This is + // retained to support the renegotiation extension and tls-unique + // channel-binding. + clientFinished [12]byte + serverFinished [12]byte clientProtocol string clientProtocolFallback bool // input/output - in, out halfConn // in.Mutex < out.Mutex - rawInput *block // raw input, right off the wire - input *block // application data waiting to be read - hand bytes.Buffer // handshake data waiting to be read + in, out halfConn // in.Mutex < out.Mutex + rawInput *block // raw input, right off the wire + input *block // application data waiting to be read + hand bytes.Buffer // handshake data waiting to be read + buffering bool // whether records are buffered in sendBuf + sendBuf []byte // a buffer of records waiting to be sent + + // bytesSent counts the bytes of application data sent. + // packetsSent counts packets. + bytesSent int64 + packetsSent int64 // activeCall is an atomic int32; the low bit is whether Close has // been called. the rest of the bits are the number of goroutines @@ -124,13 +161,6 @@ func (hc *halfConn) setErrorLocked(err error) error { return err } -func (hc *halfConn) error() error { - hc.Lock() - err := hc.err - hc.Unlock() - return err -} - // prepareCipherSpec sets the encryption and MAC states // that a subsequent changeCipherSpec will use. func (hc *halfConn) prepareCipherSpec(version uint16, cipher interface{}, mac macFunction) { @@ -170,25 +200,18 @@ func (hc *halfConn) incSeq() { panic("TLS: sequence number wraparound") } -// resetSeq resets the sequence number to zero. -func (hc *halfConn) resetSeq() { - for i := range hc.seq { - hc.seq[i] = 0 - } -} - -// removePadding returns an unpadded slice, in constant time, which is a prefix -// of the input. It also returns a byte which is equal to 255 if the padding -// was valid and 0 otherwise. See RFC 2246, section 6.2.3.2 -func removePadding(payload []byte) ([]byte, byte) { +// extractPadding returns, in constant time, the length of the padding to remove +// from the end of payload. It also returns a byte which is equal to 255 if the +// padding was valid and 0 otherwise. See RFC 2246, section 6.2.3.2 +func extractPadding(payload []byte) (toRemove int, good byte) { if len(payload) < 1 { - return payload, 0 + return 0, 0 } paddingLen := payload[len(payload)-1] t := uint(len(payload)-1) - uint(paddingLen) // if len(payload) >= (paddingLen - 1) then the MSB of t is zero - good := byte(int32(^t) >> 31) + good = byte(int32(^t) >> 31) toCheck := 255 // the maximum possible padding length // The length of the padded data is public, so we can use an if here @@ -211,24 +234,24 @@ func removePadding(payload []byte) ([]byte, byte) { good &= good << 1 good = uint8(int8(good) >> 7) - toRemove := good&paddingLen + 1 - return payload[:len(payload)-int(toRemove)], good + toRemove = int(paddingLen) + 1 + return } -// removePaddingSSL30 is a replacement for removePadding in the case that the +// extractPaddingSSL30 is a replacement for extractPadding in the case that the // protocol version is SSLv3. In this version, the contents of the padding // are random and cannot be checked. -func removePaddingSSL30(payload []byte) ([]byte, byte) { +func extractPaddingSSL30(payload []byte) (toRemove int, good byte) { if len(payload) < 1 { - return payload, 0 + return 0, 0 } paddingLen := int(payload[len(payload)-1]) + 1 if paddingLen > len(payload) { - return payload, 0 + return 0, 0 } - return payload[:len(payload)-paddingLen], 255 + return paddingLen, 255 } func roundUp(a, b int) int { @@ -254,6 +277,7 @@ func (hc *halfConn) decrypt(b *block) (ok bool, prefixLen int, alertValue alert) } paddingGood := byte(255) + paddingLen := 0 explicitIVLen := 0 // decrypt @@ -261,13 +285,17 @@ func (hc *halfConn) decrypt(b *block) (ok bool, prefixLen int, alertValue alert) switch c := hc.cipher.(type) { case cipher.Stream: c.XORKeyStream(payload, payload) - case cipher.AEAD: - explicitIVLen = 8 + case aead: + explicitIVLen = c.explicitNonceLen() if len(payload) < explicitIVLen { return false, 0, alertBadRecordMAC } - nonce := payload[:8] - payload = payload[8:] + nonce := payload[:explicitIVLen] + payload = payload[explicitIVLen:] + + if len(nonce) == 0 { + nonce = hc.seq[:] + } copy(hc.additionalData[:], hc.seq[:]) copy(hc.additionalData[8:], b.data[:3]) @@ -296,22 +324,17 @@ func (hc *halfConn) decrypt(b *block) (ok bool, prefixLen int, alertValue alert) } c.CryptBlocks(payload, payload) if hc.version == VersionSSL30 { - payload, paddingGood = removePaddingSSL30(payload) + paddingLen, paddingGood = extractPaddingSSL30(payload) } else { - payload, paddingGood = removePadding(payload) + paddingLen, paddingGood = extractPadding(payload) + + // To protect against CBC padding oracles like Lucky13, the data + // past paddingLen (which is secret) is passed to the MAC + // function as extra data, to be fed into the HMAC after + // computing the digest. This makes the MAC constant time as + // long as the digest computation is constant time and does not + // affect the subsequent write. } - b.resize(recordHeaderLen + explicitIVLen + len(payload)) - - // note that we still have a timing side-channel in the - // MAC check, below. An attacker can align the record - // so that a correct padding will cause one less hash - // block to be calculated. Then they can iteratively - // decrypt a record by breaking each byte. See - // "Password Interception in a SSL/TLS Channel", Brice - // Canvel et al. - // - // However, our behavior matches OpenSSL, so we leak - // only as much as they do. default: panic("unknown cipher type") } @@ -324,17 +347,19 @@ func (hc *halfConn) decrypt(b *block) (ok bool, prefixLen int, alertValue alert) } // strip mac off payload, b.data - n := len(payload) - macSize + n := len(payload) - macSize - paddingLen + n = subtle.ConstantTimeSelect(int(uint32(n)>>31), 0, n) // if n < 0 { n = 0 } b.data[3] = byte(n >> 8) b.data[4] = byte(n) - b.resize(recordHeaderLen + explicitIVLen + n) - remoteMAC := payload[n:] - localMAC := hc.mac.MAC(hc.inDigestBuf, hc.seq[0:], b.data[:recordHeaderLen], payload[:n]) + remoteMAC := payload[n : n+macSize] + localMAC := hc.mac.MAC(hc.inDigestBuf, hc.seq[0:], b.data[:recordHeaderLen], payload[:n], payload[n+macSize:]) if subtle.ConstantTimeCompare(localMAC, remoteMAC) != 1 || paddingGood != 255 { return false, 0, alertBadRecordMAC } hc.inDigestBuf = localMAC + + b.resize(recordHeaderLen + explicitIVLen + n) } hc.incSeq() @@ -362,7 +387,7 @@ func padToBlockSize(payload []byte, blockSize int) (prefix, finalBlock []byte) { func (hc *halfConn) encrypt(b *block, explicitIVLen int) (bool, alert) { // mac if hc.mac != nil { - mac := hc.mac.MAC(hc.outDigestBuf, hc.seq[0:], b.data[:recordHeaderLen], b.data[recordHeaderLen+explicitIVLen:]) + mac := hc.mac.MAC(hc.outDigestBuf, hc.seq[0:], b.data[:recordHeaderLen], b.data[recordHeaderLen+explicitIVLen:], nil) n := len(b.data) b.resize(n + len(mac)) @@ -377,10 +402,13 @@ func (hc *halfConn) encrypt(b *block, explicitIVLen int) (bool, alert) { switch c := hc.cipher.(type) { case cipher.Stream: c.XORKeyStream(payload, payload) - case cipher.AEAD: + case aead: payloadLen := len(b.data) - recordHeaderLen - explicitIVLen b.resize(len(b.data) + c.Overhead()) nonce := b.data[recordHeaderLen : recordHeaderLen+explicitIVLen] + if len(nonce) == 0 { + nonce = hc.seq[:] + } payload := b.data[recordHeaderLen+explicitIVLen:] payload = payload[:payloadLen] @@ -535,7 +563,7 @@ func (c *Conn) newRecordHeaderError(msg string) (err RecordHeaderError) { func (c *Conn) readRecord(want recordType) error { // Caller must be in sync with connection: // handshake data if handshake not yet completed, - // else application data. (We don't support renegotiation.) + // else application data. switch want { default: c.sendAlert(alertInternalError) @@ -543,12 +571,12 @@ func (c *Conn) readRecord(want recordType) error { case recordTypeHandshake, recordTypeChangeCipherSpec: if c.handshakeComplete { c.sendAlert(alertInternalError) - return c.in.setErrorLocked(errors.New("tls: handshake or ChangeCipherSpec requested after handshake complete")) + return c.in.setErrorLocked(errors.New("tls: handshake or ChangeCipherSpec requested while not in handshake")) } case recordTypeApplicationData: if !c.handshakeComplete { c.sendAlert(alertInternalError) - return c.in.setErrorLocked(errors.New("tls: application data record requested before handshake complete")) + return c.in.setErrorLocked(errors.New("tls: application data record requested while in handshake")) } } @@ -616,9 +644,10 @@ Again: // Process message. b, c.rawInput = c.in.splitBlock(b, recordHeaderLen+n) - ok, off, err := c.in.decrypt(b) + ok, off, alertValue := c.in.decrypt(b) if !ok { - c.in.setErrorLocked(c.sendAlert(err)) + c.in.freeBlock(b) + return c.in.setErrorLocked(c.sendAlert(alertValue)) } b.off = off data := b.data[b.off:] @@ -672,7 +701,7 @@ Again: case recordTypeHandshake: // TODO(rsc): Should at least pick off connection close. - if typ != want { + if typ != want && !(c.isClient && c.config.Renegotiation != RenegotiateNever) { return c.in.setErrorLocked(c.sendAlert(alertNoRenegotiation)) } c.hand.Write(data) @@ -694,12 +723,14 @@ func (c *Conn) sendAlertLocked(err alert) error { c.tmp[0] = alertLevelError } c.tmp[1] = byte(err) - c.writeRecord(recordTypeAlert, c.tmp[0:2]) - // closeNotify is a special case in that it isn't an error: - if err != alertCloseNotify { - return c.out.setErrorLocked(&net.OpError{Op: "local error", Err: err}) + + _, writeErr := c.writeRecordLocked(recordTypeAlert, c.tmp[0:2]) + if err == alertCloseNotify { + // closeNotify is a special case in that it isn't an error. + return writeErr } - return nil + + return c.out.setErrorLocked(&net.OpError{Op: "local error", Err: err}) } // sendAlert sends a TLS alert message. @@ -710,16 +741,120 @@ func (c *Conn) sendAlert(err alert) error { return c.sendAlertLocked(err) } -// writeRecord writes a TLS record with the given type and payload -// to the connection and updates the record layer state. +const ( + // tcpMSSEstimate is a conservative estimate of the TCP maximum segment + // size (MSS). A constant is used, rather than querying the kernel for + // the actual MSS, to avoid complexity. The value here is the IPv6 + // minimum MTU (1280 bytes) minus the overhead of an IPv6 header (40 + // bytes) and a TCP header with timestamps (32 bytes). + tcpMSSEstimate = 1208 + + // recordSizeBoostThreshold is the number of bytes of application data + // sent after which the TLS record size will be increased to the + // maximum. + recordSizeBoostThreshold = 128 * 1024 +) + +// maxPayloadSizeForWrite returns the maximum TLS payload size to use for the +// next application data record. There is the following trade-off: +// +// - For latency-sensitive applications, such as web browsing, each TLS +// record should fit in one TCP segment. +// - For throughput-sensitive applications, such as large file transfers, +// larger TLS records better amortize framing and encryption overheads. +// +// A simple heuristic that works well in practice is to use small records for +// the first 1MB of data, then use larger records for subsequent data, and +// reset back to smaller records after the connection becomes idle. See "High +// Performance Web Networking", Chapter 4, or: +// https://www.igvita.com/2013/10/24/optimizing-tls-record-size-and-buffering-latency/ +// +// In the interests of simplicity and determinism, this code does not attempt +// to reset the record size once the connection is idle, however. +// // c.out.Mutex <= L. -func (c *Conn) writeRecord(typ recordType, data []byte) (n int, err error) { +func (c *Conn) maxPayloadSizeForWrite(typ recordType, explicitIVLen int) int { + if c.config.DynamicRecordSizingDisabled || typ != recordTypeApplicationData { + return maxPlaintext + } + + if c.bytesSent >= recordSizeBoostThreshold { + return maxPlaintext + } + + // Subtract TLS overheads to get the maximum payload size. + macSize := 0 + if c.out.mac != nil { + macSize = c.out.mac.Size() + } + + payloadBytes := tcpMSSEstimate - recordHeaderLen - explicitIVLen + if c.out.cipher != nil { + switch ciph := c.out.cipher.(type) { + case cipher.Stream: + payloadBytes -= macSize + case cipher.AEAD: + payloadBytes -= ciph.Overhead() + case cbcMode: + blockSize := ciph.BlockSize() + // The payload must fit in a multiple of blockSize, with + // room for at least one padding byte. + payloadBytes = (payloadBytes & ^(blockSize - 1)) - 1 + // The MAC is appended before padding so affects the + // payload size directly. + payloadBytes -= macSize + default: + panic("unknown cipher type") + } + } + + // Allow packet growth in arithmetic progression up to max. + pkt := c.packetsSent + c.packetsSent++ + if pkt > 1000 { + return maxPlaintext // avoid overflow in multiply below + } + + n := payloadBytes * int(pkt+1) + if n > maxPlaintext { + n = maxPlaintext + } + return n +} + +// c.out.Mutex <= L. +func (c *Conn) write(data []byte) (int, error) { + if c.buffering { + c.sendBuf = append(c.sendBuf, data...) + return len(data), nil + } + + n, err := c.conn.Write(data) + c.bytesSent += int64(n) + return n, err +} + +func (c *Conn) flush() (int, error) { + if len(c.sendBuf) == 0 { + return 0, nil + } + + n, err := c.conn.Write(c.sendBuf) + c.bytesSent += int64(n) + c.sendBuf = nil + c.buffering = false + return n, err +} + +// writeRecordLocked writes a TLS record with the given type and payload to the +// connection and updates the record layer state. +// c.out.Mutex <= L. +func (c *Conn) writeRecordLocked(typ recordType, data []byte) (int, error) { b := c.out.newBlock() + defer c.out.freeBlock(b) + + var n int for len(data) > 0 { - m := len(data) - if m > maxPlaintext { - m = maxPlaintext - } explicitIVLen := 0 explicitIVIsSeq := false @@ -731,17 +866,22 @@ func (c *Conn) writeRecord(typ recordType, data []byte) (n int, err error) { } } if explicitIVLen == 0 { - if _, ok := c.out.cipher.(cipher.AEAD); ok { - explicitIVLen = 8 + if c, ok := c.out.cipher.(aead); ok { + explicitIVLen = c.explicitNonceLen() + // The AES-GCM construction in TLS has an // explicit nonce so that the nonce can be // random. However, the nonce is only 8 bytes // which is too small for a secure, random // nonce. Therefore we use the sequence number // as the nonce. - explicitIVIsSeq = true + explicitIVIsSeq = explicitIVLen > 0 } } + m := len(data) + if maxPayload := c.maxPayloadSizeForWrite(typ, explicitIVLen); m > maxPayload { + m = maxPayload + } b.resize(recordHeaderLen + explicitIVLen + m) b.data[0] = byte(typ) vers := c.vers @@ -759,34 +899,37 @@ func (c *Conn) writeRecord(typ recordType, data []byte) (n int, err error) { if explicitIVIsSeq { copy(explicitIV, c.out.seq[:]) } else { - if _, err = io.ReadFull(c.config.rand(), explicitIV); err != nil { - break + if _, err := io.ReadFull(c.config.rand(), explicitIV); err != nil { + return n, err } } } copy(b.data[recordHeaderLen+explicitIVLen:], data) c.out.encrypt(b, explicitIVLen) - _, err = c.conn.Write(b.data) - if err != nil { - break + if _, err := c.write(b.data); err != nil { + return n, err } n += m data = data[m:] } - c.out.freeBlock(b) if typ == recordTypeChangeCipherSpec { - err = c.out.changeCipherSpec() - if err != nil { - // Cannot call sendAlert directly, - // because we already hold c.out.Mutex. - c.tmp[0] = alertLevelError - c.tmp[1] = byte(err.(alert)) - c.writeRecord(recordTypeAlert, c.tmp[0:2]) - return n, c.out.setErrorLocked(&net.OpError{Op: "local error", Err: err}) + if err := c.out.changeCipherSpec(); err != nil { + return n, c.sendAlertLocked(err.(alert)) } } - return + + return n, nil +} + +// writeRecord writes a TLS record with the given type and payload to the +// connection and updates the record layer state. +// L < c.out.Mutex. +func (c *Conn) writeRecord(typ recordType, data []byte) (int, error) { + c.out.Lock() + defer c.out.Unlock() + + return c.writeRecordLocked(typ, data) } // readHandshake reads the next handshake message from @@ -805,7 +948,8 @@ func (c *Conn) readHandshake() (interface{}, error) { data := c.hand.Bytes() n := int(data[1])<<16 | int(data[2])<<8 | int(data[3]) if n > maxHandshake { - return nil, c.in.setErrorLocked(c.sendAlert(alertInternalError)) + c.sendAlertLocked(alertInternalError) + return nil, c.in.setErrorLocked(fmt.Errorf("tls: handshake message of length %d bytes exceeds maximum of %d bytes", n, maxHandshake)) } for c.hand.Len() < 4+n { if err := c.in.err; err != nil { @@ -818,6 +962,8 @@ func (c *Conn) readHandshake() (interface{}, error) { data = c.hand.Next(4 + n) var m handshakeMessage switch data[0] { + case typeHelloRequest: + m = new(helloRequestMsg) case typeClientHello: m = new(clientHelloMsg) case typeServerHello: @@ -850,7 +996,7 @@ func (c *Conn) readHandshake() (interface{}, error) { return nil, c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) } - // The handshake message unmarshallers + // The handshake message unmarshalers // expect to be able to keep references to data, // so pass in a fresh copy that won't be overwritten. data = append([]byte(nil), data...) @@ -861,7 +1007,10 @@ func (c *Conn) readHandshake() (interface{}, error) { return m, nil } -var errClosed = errors.New("crypto/tls: use of closed connection") +var ( + errClosed = errors.New("tls: use of closed connection") + errShutdown = errors.New("tls: protocol is shutdown") +) // Write writes data to the connection. func (c *Conn) Write(b []byte) (int, error) { @@ -892,6 +1041,10 @@ func (c *Conn) Write(b []byte) (int, error) { return 0, alertInternalError } + if c.closeNotifySent { + return 0, errShutdown + } + // SSL 3.0 and TLS 1.0 are susceptible to a chosen-plaintext // attack when using block mode ciphers due to predictable IVs. // This can be prevented by splitting each Application Data @@ -904,7 +1057,7 @@ func (c *Conn) Write(b []byte) (int, error) { var m int if len(b) > 1 && c.vers <= VersionTLS10 { if _, ok := c.out.cipher.(cipher.BlockMode); ok { - n, err := c.writeRecord(recordTypeApplicationData, b[:1]) + n, err := c.writeRecordLocked(recordTypeApplicationData, b[:1]) if err != nil { return n, c.out.setErrorLocked(err) } @@ -912,10 +1065,52 @@ func (c *Conn) Write(b []byte) (int, error) { } } - n, err := c.writeRecord(recordTypeApplicationData, b) + n, err := c.writeRecordLocked(recordTypeApplicationData, b) return n + m, c.out.setErrorLocked(err) } +// handleRenegotiation processes a HelloRequest handshake message. +// c.in.Mutex <= L +func (c *Conn) handleRenegotiation() error { + msg, err := c.readHandshake() + if err != nil { + return err + } + + _, ok := msg.(*helloRequestMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return alertUnexpectedMessage + } + + if !c.isClient { + return c.sendAlert(alertNoRenegotiation) + } + + switch c.config.Renegotiation { + case RenegotiateNever: + return c.sendAlert(alertNoRenegotiation) + case RenegotiateOnceAsClient: + if c.handshakes > 1 { + return c.sendAlert(alertNoRenegotiation) + } + case RenegotiateFreelyAsClient: + // Ok. + default: + c.sendAlert(alertInternalError) + return errors.New("tls: unknown Renegotiation value") + } + + c.handshakeMutex.Lock() + defer c.handshakeMutex.Unlock() + + c.handshakeComplete = false + if c.handshakeErr = c.clientHandshake(); c.handshakeErr == nil { + c.handshakes++ + } + return c.handshakeErr +} + // Read can be made to time out and return a net.Error with Timeout() == true // after a fixed time limit; see SetDeadline and SetReadDeadline. func (c *Conn) Read(b []byte) (n int, err error) { @@ -940,6 +1135,13 @@ func (c *Conn) Read(b []byte) (n int, err error) { // Soft error, like EAGAIN return 0, err } + if c.hand.Len() > 0 { + // We received handshake bytes, indicating the + // start of a renegotiation. + if err := c.handleRenegotiation(); err != nil { + return 0, err + } + } } if err := c.in.err; err != nil { return 0, err @@ -1006,7 +1208,7 @@ func (c *Conn) Close() error { c.handshakeMutex.Lock() defer c.handshakeMutex.Unlock() if c.handshakeComplete { - alertErr = c.sendAlert(alertCloseNotify) + alertErr = c.closeNotify() } if err := c.conn.Close(); err != nil { @@ -1015,18 +1217,90 @@ func (c *Conn) Close() error { return alertErr } +var errEarlyCloseWrite = errors.New("tls: CloseWrite called before handshake complete") + +// CloseWrite shuts down the writing side of the connection. It should only be +// called once the handshake has completed and does not call CloseWrite on the +// underlying connection. Most callers should just use Close. +func (c *Conn) CloseWrite() error { + c.handshakeMutex.Lock() + defer c.handshakeMutex.Unlock() + if !c.handshakeComplete { + return errEarlyCloseWrite + } + + return c.closeNotify() +} + +func (c *Conn) closeNotify() error { + c.out.Lock() + defer c.out.Unlock() + + if !c.closeNotifySent { + c.closeNotifyErr = c.sendAlertLocked(alertCloseNotify) + c.closeNotifySent = true + } + return c.closeNotifyErr +} + // Handshake runs the client or server handshake // protocol if it has not yet been run. // Most uses of this package need not call Handshake // explicitly: the first Read or Write will call it automatically. func (c *Conn) Handshake() error { + // c.handshakeErr and c.handshakeComplete are protected by + // c.handshakeMutex. In order to perform a handshake, we need to lock + // c.in also and c.handshakeMutex must be locked after c.in. + // + // However, if a Read() operation is hanging then it'll be holding the + // lock on c.in and so taking it here would cause all operations that + // need to check whether a handshake is pending (such as Write) to + // block. + // + // Thus we first take c.handshakeMutex to check whether a handshake is + // needed. + // + // If so then, previously, this code would unlock handshakeMutex and + // then lock c.in and handshakeMutex in the correct order to run the + // handshake. The problem was that it was possible for a Read to + // complete the handshake once handshakeMutex was unlocked and then + // keep c.in while waiting for network data. Thus a concurrent + // operation could be blocked on c.in. + // + // Thus handshakeCond is used to signal that a goroutine is committed + // to running the handshake and other goroutines can wait on it if they + // need. handshakeCond is protected by handshakeMutex. c.handshakeMutex.Lock() defer c.handshakeMutex.Unlock() - if err := c.handshakeErr; err != nil { - return err + + for { + if err := c.handshakeErr; err != nil { + return err + } + if c.handshakeComplete { + return nil + } + if c.handshakeCond == nil { + break + } + + c.handshakeCond.Wait() } - if c.handshakeComplete { - return nil + + // Set handshakeCond to indicate that this goroutine is committing to + // running the handshake. + c.handshakeCond = sync.NewCond(&c.handshakeMutex) + c.handshakeMutex.Unlock() + + c.in.Lock() + defer c.in.Unlock() + + c.handshakeMutex.Lock() + + // The handshake cannot have completed when handshakeMutex was unlocked + // because this goroutine set handshakeCond. + if c.handshakeErr != nil || c.handshakeComplete { + panic("handshake should not have been able to complete after handshakeCond was set") } if c.isClient { @@ -1034,6 +1308,23 @@ func (c *Conn) Handshake() error { } else { c.handshakeErr = c.serverHandshake() } + if c.handshakeErr == nil { + c.handshakes++ + } else { + // If an error occurred during the hadshake try to flush the + // alert that might be left in the buffer. + c.flush() + } + + if c.handshakeErr == nil && !c.handshakeComplete { + panic("handshake should have had a result.") + } + + // Wake any other goroutines that are waiting for this handshake to + // complete. + c.handshakeCond.Broadcast() + c.handshakeCond = nil + return c.handshakeErr } @@ -1044,6 +1335,8 @@ func (c *Conn) ConnectionState() ConnectionState { var state ConnectionState state.HandshakeComplete = c.handshakeComplete + state.ServerName = c.serverName + if c.handshakeComplete { state.Version = c.vers state.NegotiatedProtocol = c.clientProtocol @@ -1052,11 +1345,14 @@ func (c *Conn) ConnectionState() ConnectionState { state.CipherSuite = c.cipherSuite state.PeerCertificates = c.peerCertificates state.VerifiedChains = c.verifiedChains - state.ServerName = c.serverName state.SignedCertificateTimestamps = c.scts state.OCSPResponse = c.ocspResponse if !c.didResume { - state.TLSUnique = c.firstFinished[:] + if c.clientFinishedIsFirst { + state.TLSUnique = c.clientFinished[:] + } else { + state.TLSUnique = c.serverFinished[:] + } } } @@ -1073,7 +1369,7 @@ func (c *Conn) OCSPResponse() []byte { } // VerifyHostname checks that the peer certificate chain is valid for -// connecting to host. If so, it returns nil; if not, it returns an error +// connecting to host. If so, it returns nil; if not, it returns an error // describing the problem. func (c *Conn) VerifyHostname(host string) error { c.handshakeMutex.Lock() |