diff options
Diffstat (limited to 'libgo/go/crypto/tls/conn.go')
-rw-r--r-- | libgo/go/crypto/tls/conn.go | 62 |
1 files changed, 58 insertions, 4 deletions
diff --git a/libgo/go/crypto/tls/conn.go b/libgo/go/crypto/tls/conn.go index 72ad52c194b..969f357834c 100644 --- a/libgo/go/crypto/tls/conn.go +++ b/libgo/go/crypto/tls/conn.go @@ -8,6 +8,7 @@ package tls import ( "bytes" + "context" "crypto/cipher" "crypto/subtle" "crypto/x509" @@ -27,7 +28,7 @@ type Conn struct { // constant conn net.Conn isClient bool - handshakeFn func() error // (*Conn).clientHandshake or serverHandshake + handshakeFn func(context.Context) error // (*Conn).clientHandshake or serverHandshake // handshakeStatus is 1 if the connection is currently transferring // application data (i.e. is not currently processing a handshake). @@ -1190,7 +1191,7 @@ func (c *Conn) handleRenegotiation() error { defer c.handshakeMutex.Unlock() atomic.StoreUint32(&c.handshakeStatus, 0) - if c.handshakeErr = c.clientHandshake(); c.handshakeErr == nil { + if c.handshakeErr = c.clientHandshake(context.Background()); c.handshakeErr == nil { c.handshakes++ } return c.handshakeErr @@ -1373,8 +1374,61 @@ func (c *Conn) closeNotify() error { // first Read or Write will call it automatically. // // For control over canceling or setting a timeout on a handshake, use -// the Dialer's DialContext method. +// HandshakeContext or the Dialer's DialContext method instead. func (c *Conn) Handshake() error { + return c.HandshakeContext(context.Background()) +} + +// HandshakeContext runs the client or server handshake +// protocol if it has not yet been run. +// +// The provided Context must be non-nil. If the context is canceled before +// the handshake is complete, the handshake is interrupted and an error is returned. +// Once the handshake has completed, cancellation of the context will not affect the +// connection. +// +// Most uses of this package need not call HandshakeContext explicitly: the +// first Read or Write will call it automatically. +func (c *Conn) HandshakeContext(ctx context.Context) error { + // Delegate to unexported method for named return + // without confusing documented signature. + return c.handshakeContext(ctx) +} + +func (c *Conn) handshakeContext(ctx context.Context) (ret error) { + handshakeCtx, cancel := context.WithCancel(ctx) + // Note: defer this before starting the "interrupter" goroutine + // so that we can tell the difference between the input being canceled and + // this cancellation. In the former case, we need to close the connection. + defer cancel() + + // Start the "interrupter" goroutine, if this context might be canceled. + // (The background context cannot). + // + // The interrupter goroutine waits for the input context to be done and + // closes the connection if this happens before the function returns. + if ctx.Done() != nil { + done := make(chan struct{}) + interruptRes := make(chan error, 1) + defer func() { + close(done) + if ctxErr := <-interruptRes; ctxErr != nil { + // Return context error to user. + ret = ctxErr + } + }() + go func() { + select { + case <-handshakeCtx.Done(): + // Close the connection, discarding the error + _ = c.conn.Close() + interruptRes <- handshakeCtx.Err() + case <-done: + interruptRes <- nil + } + }() + } + c.handshakeMutex.Lock() defer c.handshakeMutex.Unlock() @@ -1388,7 +1442,7 @@ func (c *Conn) Handshake() error { c.in.Lock() defer c.in.Unlock() - c.handshakeErr = c.handshakeFn() + c.handshakeErr = c.handshakeFn(handshakeCtx) if c.handshakeErr == nil { c.handshakes++ } else { |