summaryrefslogtreecommitdiff
path: root/libgo/go/crypto/tls/conn.go
diff options
context:
space:
mode:
Diffstat (limited to 'libgo/go/crypto/tls/conn.go')
-rw-r--r--libgo/go/crypto/tls/conn.go62
1 files changed, 58 insertions, 4 deletions
diff --git a/libgo/go/crypto/tls/conn.go b/libgo/go/crypto/tls/conn.go
index 72ad52c194b..969f357834c 100644
--- a/libgo/go/crypto/tls/conn.go
+++ b/libgo/go/crypto/tls/conn.go
@@ -8,6 +8,7 @@ package tls
import (
"bytes"
+ "context"
"crypto/cipher"
"crypto/subtle"
"crypto/x509"
@@ -27,7 +28,7 @@ type Conn struct {
// constant
conn net.Conn
isClient bool
- handshakeFn func() error // (*Conn).clientHandshake or serverHandshake
+ handshakeFn func(context.Context) error // (*Conn).clientHandshake or serverHandshake
// handshakeStatus is 1 if the connection is currently transferring
// application data (i.e. is not currently processing a handshake).
@@ -1190,7 +1191,7 @@ func (c *Conn) handleRenegotiation() error {
defer c.handshakeMutex.Unlock()
atomic.StoreUint32(&c.handshakeStatus, 0)
- if c.handshakeErr = c.clientHandshake(); c.handshakeErr == nil {
+ if c.handshakeErr = c.clientHandshake(context.Background()); c.handshakeErr == nil {
c.handshakes++
}
return c.handshakeErr
@@ -1373,8 +1374,61 @@ func (c *Conn) closeNotify() error {
// first Read or Write will call it automatically.
//
// For control over canceling or setting a timeout on a handshake, use
-// the Dialer's DialContext method.
+// HandshakeContext or the Dialer's DialContext method instead.
func (c *Conn) Handshake() error {
+ return c.HandshakeContext(context.Background())
+}
+
+// HandshakeContext runs the client or server handshake
+// protocol if it has not yet been run.
+//
+// The provided Context must be non-nil. If the context is canceled before
+// the handshake is complete, the handshake is interrupted and an error is returned.
+// Once the handshake has completed, cancellation of the context will not affect the
+// connection.
+//
+// Most uses of this package need not call HandshakeContext explicitly: the
+// first Read or Write will call it automatically.
+func (c *Conn) HandshakeContext(ctx context.Context) error {
+ // Delegate to unexported method for named return
+ // without confusing documented signature.
+ return c.handshakeContext(ctx)
+}
+
+func (c *Conn) handshakeContext(ctx context.Context) (ret error) {
+ handshakeCtx, cancel := context.WithCancel(ctx)
+ // Note: defer this before starting the "interrupter" goroutine
+ // so that we can tell the difference between the input being canceled and
+ // this cancellation. In the former case, we need to close the connection.
+ defer cancel()
+
+ // Start the "interrupter" goroutine, if this context might be canceled.
+ // (The background context cannot).
+ //
+ // The interrupter goroutine waits for the input context to be done and
+ // closes the connection if this happens before the function returns.
+ if ctx.Done() != nil {
+ done := make(chan struct{})
+ interruptRes := make(chan error, 1)
+ defer func() {
+ close(done)
+ if ctxErr := <-interruptRes; ctxErr != nil {
+ // Return context error to user.
+ ret = ctxErr
+ }
+ }()
+ go func() {
+ select {
+ case <-handshakeCtx.Done():
+ // Close the connection, discarding the error
+ _ = c.conn.Close()
+ interruptRes <- handshakeCtx.Err()
+ case <-done:
+ interruptRes <- nil
+ }
+ }()
+ }
+
c.handshakeMutex.Lock()
defer c.handshakeMutex.Unlock()
@@ -1388,7 +1442,7 @@ func (c *Conn) Handshake() error {
c.in.Lock()
defer c.in.Unlock()
- c.handshakeErr = c.handshakeFn()
+ c.handshakeErr = c.handshakeFn(handshakeCtx)
if c.handshakeErr == nil {
c.handshakes++
} else {