diff options
author | Russ Cox <rsc@golang.org> | 2014-09-08 00:08:51 -0400 |
---|---|---|
committer | Russ Cox <rsc@golang.org> | 2014-09-08 00:08:51 -0400 |
commit | 8528da672cc093d4dd06732819abc1f7b6b5a46e (patch) | |
tree | 334be80d4a4c85b77db6f6fdb67cbf0528cba5f5 /src/crypto/tls/handshake_client_test.go | |
parent | 73bcb69f272cbf34ddcc9daa56427a8683b5a95d (diff) | |
download | go-8528da672cc093d4dd06732819abc1f7b6b5a46e.tar.gz |
build: move package sources from src/pkg to src
Preparation was in CL 134570043.
This CL contains only the effect of 'hg mv src/pkg/* src'.
For more about the move, see golang.org/s/go14nopkg.
Diffstat (limited to 'src/crypto/tls/handshake_client_test.go')
-rw-r--r-- | src/crypto/tls/handshake_client_test.go | 490 |
1 files changed, 490 insertions, 0 deletions
diff --git a/src/crypto/tls/handshake_client_test.go b/src/crypto/tls/handshake_client_test.go new file mode 100644 index 000000000..e5eaa7de2 --- /dev/null +++ b/src/crypto/tls/handshake_client_test.go @@ -0,0 +1,490 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tls + +import ( + "bytes" + "crypto/ecdsa" + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "fmt" + "io" + "net" + "os" + "os/exec" + "path/filepath" + "strconv" + "testing" + "time" +) + +// Note: see comment in handshake_test.go for details of how the reference +// tests work. + +// blockingSource is an io.Reader that blocks a Read call until it's closed. +type blockingSource chan bool + +func (b blockingSource) Read([]byte) (n int, err error) { + <-b + return 0, io.EOF +} + +// clientTest represents a test of the TLS client handshake against a reference +// implementation. +type clientTest struct { + // name is a freeform string identifying the test and the file in which + // the expected results will be stored. + name string + // command, if not empty, contains a series of arguments for the + // command to run for the reference server. + command []string + // config, if not nil, contains a custom Config to use for this test. + config *Config + // cert, if not empty, contains a DER-encoded certificate for the + // reference server. + cert []byte + // key, if not nil, contains either a *rsa.PrivateKey or + // *ecdsa.PrivateKey which is the private key for the reference server. + key interface{} + // validate, if not nil, is a function that will be called with the + // ConnectionState of the resulting connection. It returns a non-nil + // error if the ConnectionState is unacceptable. + validate func(ConnectionState) error +} + +var defaultServerCommand = []string{"openssl", "s_server"} + +// connFromCommand starts the reference server process, connects to it and +// returns a recordingConn for the connection. The stdin return value is a +// blockingSource for the stdin of the child process. It must be closed before +// Waiting for child. +func (test *clientTest) connFromCommand() (conn *recordingConn, child *exec.Cmd, stdin blockingSource, err error) { + cert := testRSACertificate + if len(test.cert) > 0 { + cert = test.cert + } + certPath := tempFile(string(cert)) + defer os.Remove(certPath) + + var key interface{} = testRSAPrivateKey + if test.key != nil { + key = test.key + } + var pemType string + var derBytes []byte + switch key := key.(type) { + case *rsa.PrivateKey: + pemType = "RSA" + derBytes = x509.MarshalPKCS1PrivateKey(key) + case *ecdsa.PrivateKey: + pemType = "EC" + var err error + derBytes, err = x509.MarshalECPrivateKey(key) + if err != nil { + panic(err) + } + default: + panic("unknown key type") + } + + var pemOut bytes.Buffer + pem.Encode(&pemOut, &pem.Block{Type: pemType + " PRIVATE KEY", Bytes: derBytes}) + + keyPath := tempFile(string(pemOut.Bytes())) + defer os.Remove(keyPath) + + var command []string + if len(test.command) > 0 { + command = append(command, test.command...) + } else { + command = append(command, defaultServerCommand...) + } + command = append(command, "-cert", certPath, "-certform", "DER", "-key", keyPath) + // serverPort contains the port that OpenSSL will listen on. OpenSSL + // can't take "0" as an argument here so we have to pick a number and + // hope that it's not in use on the machine. Since this only occurs + // when -update is given and thus when there's a human watching the + // test, this isn't too bad. + const serverPort = 24323 + command = append(command, "-accept", strconv.Itoa(serverPort)) + + cmd := exec.Command(command[0], command[1:]...) + stdin = blockingSource(make(chan bool)) + cmd.Stdin = stdin + var out bytes.Buffer + cmd.Stdout = &out + cmd.Stderr = &out + if err := cmd.Start(); err != nil { + return nil, nil, nil, err + } + + // OpenSSL does print an "ACCEPT" banner, but it does so *before* + // opening the listening socket, so we can't use that to wait until it + // has started listening. Thus we are forced to poll until we get a + // connection. + var tcpConn net.Conn + for i := uint(0); i < 5; i++ { + var err error + tcpConn, err = net.DialTCP("tcp", nil, &net.TCPAddr{ + IP: net.IPv4(127, 0, 0, 1), + Port: serverPort, + }) + if err == nil { + break + } + time.Sleep((1 << i) * 5 * time.Millisecond) + } + if tcpConn == nil { + close(stdin) + out.WriteTo(os.Stdout) + cmd.Process.Kill() + return nil, nil, nil, cmd.Wait() + } + + record := &recordingConn{ + Conn: tcpConn, + } + + return record, cmd, stdin, nil +} + +func (test *clientTest) dataPath() string { + return filepath.Join("testdata", "Client-"+test.name) +} + +func (test *clientTest) loadData() (flows [][]byte, err error) { + in, err := os.Open(test.dataPath()) + if err != nil { + return nil, err + } + defer in.Close() + return parseTestData(in) +} + +func (test *clientTest) run(t *testing.T, write bool) { + var clientConn, serverConn net.Conn + var recordingConn *recordingConn + var childProcess *exec.Cmd + var stdin blockingSource + + if write { + var err error + recordingConn, childProcess, stdin, err = test.connFromCommand() + if err != nil { + t.Fatalf("Failed to start subcommand: %s", err) + } + clientConn = recordingConn + } else { + clientConn, serverConn = net.Pipe() + } + + config := test.config + if config == nil { + config = testConfig + } + client := Client(clientConn, config) + + doneChan := make(chan bool) + go func() { + if _, err := client.Write([]byte("hello\n")); err != nil { + t.Logf("Client.Write failed: %s", err) + } + if test.validate != nil { + if err := test.validate(client.ConnectionState()); err != nil { + t.Logf("validate callback returned error: %s", err) + } + } + client.Close() + clientConn.Close() + doneChan <- true + }() + + if !write { + flows, err := test.loadData() + if err != nil { + t.Fatalf("%s: failed to load data from %s: %v", test.name, test.dataPath(), err) + } + for i, b := range flows { + if i%2 == 1 { + serverConn.Write(b) + continue + } + bb := make([]byte, len(b)) + _, err := io.ReadFull(serverConn, bb) + if err != nil { + t.Fatalf("%s #%d: %s", test.name, i, err) + } + if !bytes.Equal(b, bb) { + t.Fatalf("%s #%d: mismatch on read: got:%x want:%x", test.name, i, bb, b) + } + } + serverConn.Close() + } + + <-doneChan + + if write { + path := test.dataPath() + out, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644) + if err != nil { + t.Fatalf("Failed to create output file: %s", err) + } + defer out.Close() + recordingConn.Close() + close(stdin) + childProcess.Process.Kill() + childProcess.Wait() + if len(recordingConn.flows) < 3 { + childProcess.Stdout.(*bytes.Buffer).WriteTo(os.Stdout) + t.Fatalf("Client connection didn't work") + } + recordingConn.WriteTo(out) + fmt.Printf("Wrote %s\n", path) + } +} + +func runClientTestForVersion(t *testing.T, template *clientTest, prefix, option string) { + test := *template + test.name = prefix + test.name + if len(test.command) == 0 { + test.command = defaultClientCommand + } + test.command = append([]string(nil), test.command...) + test.command = append(test.command, option) + test.run(t, *update) +} + +func runClientTestTLS10(t *testing.T, template *clientTest) { + runClientTestForVersion(t, template, "TLSv10-", "-tls1") +} + +func runClientTestTLS11(t *testing.T, template *clientTest) { + runClientTestForVersion(t, template, "TLSv11-", "-tls1_1") +} + +func runClientTestTLS12(t *testing.T, template *clientTest) { + runClientTestForVersion(t, template, "TLSv12-", "-tls1_2") +} + +func TestHandshakeClientRSARC4(t *testing.T) { + test := &clientTest{ + name: "RSA-RC4", + command: []string{"openssl", "s_server", "-cipher", "RC4-SHA"}, + } + runClientTestTLS10(t, test) + runClientTestTLS11(t, test) + runClientTestTLS12(t, test) +} + +func TestHandshakeClientECDHERSAAES(t *testing.T) { + test := &clientTest{ + name: "ECDHE-RSA-AES", + command: []string{"openssl", "s_server", "-cipher", "ECDHE-RSA-AES128-SHA"}, + } + runClientTestTLS10(t, test) + runClientTestTLS11(t, test) + runClientTestTLS12(t, test) +} + +func TestHandshakeClientECDHEECDSAAES(t *testing.T) { + test := &clientTest{ + name: "ECDHE-ECDSA-AES", + command: []string{"openssl", "s_server", "-cipher", "ECDHE-ECDSA-AES128-SHA"}, + cert: testECDSACertificate, + key: testECDSAPrivateKey, + } + runClientTestTLS10(t, test) + runClientTestTLS11(t, test) + runClientTestTLS12(t, test) +} + +func TestHandshakeClientECDHEECDSAAESGCM(t *testing.T) { + test := &clientTest{ + name: "ECDHE-ECDSA-AES-GCM", + command: []string{"openssl", "s_server", "-cipher", "ECDHE-ECDSA-AES128-GCM-SHA256"}, + cert: testECDSACertificate, + key: testECDSAPrivateKey, + } + runClientTestTLS12(t, test) +} + +func TestHandshakeClientCertRSA(t *testing.T) { + config := *testConfig + cert, _ := X509KeyPair([]byte(clientCertificatePEM), []byte(clientKeyPEM)) + config.Certificates = []Certificate{cert} + + test := &clientTest{ + name: "ClientCert-RSA-RSA", + command: []string{"openssl", "s_server", "-cipher", "RC4-SHA", "-verify", "1"}, + config: &config, + } + + runClientTestTLS10(t, test) + runClientTestTLS12(t, test) + + test = &clientTest{ + name: "ClientCert-RSA-ECDSA", + command: []string{"openssl", "s_server", "-cipher", "ECDHE-ECDSA-AES128-SHA", "-verify", "1"}, + config: &config, + cert: testECDSACertificate, + key: testECDSAPrivateKey, + } + + runClientTestTLS10(t, test) + runClientTestTLS12(t, test) +} + +func TestHandshakeClientCertECDSA(t *testing.T) { + config := *testConfig + cert, _ := X509KeyPair([]byte(clientECDSACertificatePEM), []byte(clientECDSAKeyPEM)) + config.Certificates = []Certificate{cert} + + test := &clientTest{ + name: "ClientCert-ECDSA-RSA", + command: []string{"openssl", "s_server", "-cipher", "RC4-SHA", "-verify", "1"}, + config: &config, + } + + runClientTestTLS10(t, test) + runClientTestTLS12(t, test) + + test = &clientTest{ + name: "ClientCert-ECDSA-ECDSA", + command: []string{"openssl", "s_server", "-cipher", "ECDHE-ECDSA-AES128-SHA", "-verify", "1"}, + config: &config, + cert: testECDSACertificate, + key: testECDSAPrivateKey, + } + + runClientTestTLS10(t, test) + runClientTestTLS12(t, test) +} + +func TestClientResumption(t *testing.T) { + serverConfig := &Config{ + CipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA}, + Certificates: testConfig.Certificates, + } + clientConfig := &Config{ + CipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA}, + InsecureSkipVerify: true, + ClientSessionCache: NewLRUClientSessionCache(32), + } + + testResumeState := func(test string, didResume bool) { + hs, err := testHandshake(clientConfig, serverConfig) + if err != nil { + t.Fatalf("%s: handshake failed: %s", test, err) + } + if hs.DidResume != didResume { + t.Fatalf("%s resumed: %v, expected: %v", test, hs.DidResume, didResume) + } + } + + testResumeState("Handshake", false) + testResumeState("Resume", true) + + if _, err := io.ReadFull(serverConfig.rand(), serverConfig.SessionTicketKey[:]); err != nil { + t.Fatalf("Failed to invalidate SessionTicketKey") + } + testResumeState("InvalidSessionTicketKey", false) + testResumeState("ResumeAfterInvalidSessionTicketKey", true) + + clientConfig.CipherSuites = []uint16{TLS_ECDHE_RSA_WITH_RC4_128_SHA} + testResumeState("DifferentCipherSuite", false) + testResumeState("DifferentCipherSuiteRecovers", true) + + clientConfig.ClientSessionCache = nil + testResumeState("WithoutSessionCache", false) +} + +func TestLRUClientSessionCache(t *testing.T) { + // Initialize cache of capacity 4. + cache := NewLRUClientSessionCache(4) + cs := make([]ClientSessionState, 6) + keys := []string{"0", "1", "2", "3", "4", "5", "6"} + + // Add 4 entries to the cache and look them up. + for i := 0; i < 4; i++ { + cache.Put(keys[i], &cs[i]) + } + for i := 0; i < 4; i++ { + if s, ok := cache.Get(keys[i]); !ok || s != &cs[i] { + t.Fatalf("session cache failed lookup for added key: %s", keys[i]) + } + } + + // Add 2 more entries to the cache. First 2 should be evicted. + for i := 4; i < 6; i++ { + cache.Put(keys[i], &cs[i]) + } + for i := 0; i < 2; i++ { + if s, ok := cache.Get(keys[i]); ok || s != nil { + t.Fatalf("session cache should have evicted key: %s", keys[i]) + } + } + + // Touch entry 2. LRU should evict 3 next. + cache.Get(keys[2]) + cache.Put(keys[0], &cs[0]) + if s, ok := cache.Get(keys[3]); ok || s != nil { + t.Fatalf("session cache should have evicted key 3") + } + + // Update entry 0 in place. + cache.Put(keys[0], &cs[3]) + if s, ok := cache.Get(keys[0]); !ok || s != &cs[3] { + t.Fatalf("session cache failed update for key 0") + } + + // Adding a nil entry is valid. + cache.Put(keys[0], nil) + if s, ok := cache.Get(keys[0]); !ok || s != nil { + t.Fatalf("failed to add nil entry to cache") + } +} + +func TestHandshakeClientALPNMatch(t *testing.T) { + config := *testConfig + config.NextProtos = []string{"proto2", "proto1"} + + test := &clientTest{ + name: "ALPN", + // Note that this needs OpenSSL 1.0.2 because that is the first + // version that supports the -alpn flag. + command: []string{"openssl", "s_server", "-alpn", "proto1,proto2"}, + config: &config, + validate: func(state ConnectionState) error { + // The server's preferences should override the client. + if state.NegotiatedProtocol != "proto1" { + return fmt.Errorf("Got protocol %q, wanted proto1", state.NegotiatedProtocol) + } + return nil + }, + } + runClientTestTLS12(t, test) +} + +func TestHandshakeClientALPNNoMatch(t *testing.T) { + config := *testConfig + config.NextProtos = []string{"proto3"} + + test := &clientTest{ + name: "ALPN-NoMatch", + // Note that this needs OpenSSL 1.0.2 because that is the first + // version that supports the -alpn flag. + command: []string{"openssl", "s_server", "-alpn", "proto1,proto2"}, + config: &config, + validate: func(state ConnectionState) error { + // There's no overlap so OpenSSL will not select a protocol. + if state.NegotiatedProtocol != "" { + return fmt.Errorf("Got protocol %q, wanted ''", state.NegotiatedProtocol) + } + return nil + }, + } + runClientTestTLS12(t, test) +} |