summaryrefslogtreecommitdiff
path: root/src/pkg
diff options
context:
space:
mode:
authorPercy Wegmann <ox.to.a.cart@gmail.com>2014-08-06 11:22:00 -0700
committerPercy Wegmann <ox.to.a.cart@gmail.com>2014-08-06 11:22:00 -0700
commitfa1b7e0aa2db2ac548824595fe7b88f455d2bcfa (patch)
tree30d8f05ccf529932d40120f055a1dc54fb34f3d1 /src/pkg
parent05c70485e187e3668b7ddc2fd59860e22dd96201 (diff)
downloadgo-fa1b7e0aa2db2ac548824595fe7b88f455d2bcfa.tar.gz
crypto/tls: Added dynamic alternative to NameToCertificate map for SNI
Revised version of https://codereview.appspot.com/81260045/ LGTM=agl R=golang-codereviews, gobot, agl, ox CC=golang-codereviews https://codereview.appspot.com/107400043 Committer: Adam Langley <agl@golang.org>
Diffstat (limited to 'src/pkg')
-rw-r--r--src/pkg/crypto/tls/common.go57
-rw-r--r--src/pkg/crypto/tls/conn_test.go22
-rw-r--r--src/pkg/crypto/tls/handshake_server.go11
-rw-r--r--src/pkg/crypto/tls/handshake_server_test.go78
4 files changed, 147 insertions, 21 deletions
diff --git a/src/pkg/crypto/tls/common.go b/src/pkg/crypto/tls/common.go
index ce3038591..2b59136e6 100644
--- a/src/pkg/crypto/tls/common.go
+++ b/src/pkg/crypto/tls/common.go
@@ -202,6 +202,32 @@ type ClientSessionCache interface {
Put(sessionKey string, cs *ClientSessionState)
}
+// ClientHelloInfo contains information from a ClientHello message in order to
+// guide certificate selection in the GetCertificate callback.
+type ClientHelloInfo struct {
+ // CipherSuites lists the CipherSuites supported by the client (e.g.
+ // TLS_RSA_WITH_RC4_128_SHA).
+ CipherSuites []uint16
+
+ // ServerName indicates the name of the server requested by the client
+ // in order to support virtual hosting. ServerName is only set if the
+ // client is using SNI (see
+ // http://tools.ietf.org/html/rfc4366#section-3.1).
+ ServerName string
+
+ // SupportedCurves lists the elliptic curves supported by the client.
+ // SupportedCurves is set only if the Supported Elliptic Curves
+ // Extension is being used (see
+ // http://tools.ietf.org/html/rfc4492#section-5.1.1).
+ SupportedCurves []CurveID
+
+ // SupportedPoints lists the point formats supported by the client.
+ // SupportedPoints is set only if the Supported Point Formats Extension
+ // is being used (see
+ // http://tools.ietf.org/html/rfc4492#section-5.1.2).
+ SupportedPoints []uint8
+}
+
// A Config structure is used to configure a TLS client or server.
// After one has been passed to a TLS function it must not be
// modified. A Config may be reused; the tls package will also not
@@ -230,6 +256,13 @@ type Config struct {
// for all connections.
NameToCertificate map[string]*Certificate
+ // GetCertificate returns a Certificate based on the given
+ // ClientHelloInfo. If GetCertificate is nil or returns nil, then the
+ // certificate is retrieved from NameToCertificate. If
+ // NameToCertificate is nil, the first element of Certificates will be
+ // used.
+ GetCertificate func(clientHello *ClientHelloInfo) (*Certificate, error)
+
// RootCAs defines the set of root certificate authorities
// that clients use when verifying server certificates.
// If RootCAs is nil, TLS uses the host's root CA set.
@@ -384,22 +417,28 @@ func (c *Config) mutualVersion(vers uint16) (uint16, bool) {
return vers, true
}
-// getCertificateForName returns the best certificate for the given name,
-// defaulting to the first element of c.Certificates if there are no good
-// options.
-func (c *Config) getCertificateForName(name string) *Certificate {
+// getCertificate returns the best certificate for the given ClientHelloInfo,
+// defaulting to the first element of c.Certificates.
+func (c *Config) getCertificate(clientHello *ClientHelloInfo) (*Certificate, error) {
+ if c.GetCertificate != nil {
+ cert, err := c.GetCertificate(clientHello)
+ if cert != nil || err != nil {
+ return cert, err
+ }
+ }
+
if len(c.Certificates) == 1 || c.NameToCertificate == nil {
// There's only one choice, so no point doing any work.
- return &c.Certificates[0]
+ return &c.Certificates[0], nil
}
- name = strings.ToLower(name)
+ name := strings.ToLower(clientHello.ServerName)
for len(name) > 0 && name[len(name)-1] == '.' {
name = name[:len(name)-1]
}
if cert, ok := c.NameToCertificate[name]; ok {
- return cert
+ return cert, nil
}
// try replacing labels in the name with wildcards until we get a
@@ -409,12 +448,12 @@ func (c *Config) getCertificateForName(name string) *Certificate {
labels[i] = "*"
candidate := strings.Join(labels, ".")
if cert, ok := c.NameToCertificate[candidate]; ok {
- return cert
+ return cert, nil
}
}
// If nothing matches, return the first certificate.
- return &c.Certificates[0]
+ return &c.Certificates[0], nil
}
// BuildNameToCertificate parses c.Certificates and builds c.NameToCertificate
diff --git a/src/pkg/crypto/tls/conn_test.go b/src/pkg/crypto/tls/conn_test.go
index 5c555147c..ec802cad7 100644
--- a/src/pkg/crypto/tls/conn_test.go
+++ b/src/pkg/crypto/tls/conn_test.go
@@ -88,19 +88,31 @@ func TestCertificateSelection(t *testing.T) {
return -1
}
- if n := pointerToIndex(config.getCertificateForName("example.com")); n != 0 {
+ certificateForName := func(name string) *Certificate {
+ clientHello := &ClientHelloInfo{
+ ServerName: name,
+ }
+ if cert, err := config.getCertificate(clientHello); err != nil {
+ t.Errorf("unable to get certificate for name '%s': %s", name, err)
+ return nil
+ } else {
+ return cert
+ }
+ }
+
+ if n := pointerToIndex(certificateForName("example.com")); n != 0 {
t.Errorf("example.com returned certificate %d, not 0", n)
}
- if n := pointerToIndex(config.getCertificateForName("bar.example.com")); n != 1 {
+ if n := pointerToIndex(certificateForName("bar.example.com")); n != 1 {
t.Errorf("bar.example.com returned certificate %d, not 1", n)
}
- if n := pointerToIndex(config.getCertificateForName("foo.example.com")); n != 2 {
+ if n := pointerToIndex(certificateForName("foo.example.com")); n != 2 {
t.Errorf("foo.example.com returned certificate %d, not 2", n)
}
- if n := pointerToIndex(config.getCertificateForName("foo.bar.example.com")); n != 3 {
+ if n := pointerToIndex(certificateForName("foo.bar.example.com")); n != 3 {
t.Errorf("foo.bar.example.com returned certificate %d, not 3", n)
}
- if n := pointerToIndex(config.getCertificateForName("foo.bar.baz.example.com")); n != 0 {
+ if n := pointerToIndex(certificateForName("foo.bar.baz.example.com")); n != 0 {
t.Errorf("foo.bar.baz.example.com returned certificate %d, not 0", n)
}
}
diff --git a/src/pkg/crypto/tls/handshake_server.go b/src/pkg/crypto/tls/handshake_server.go
index 910f3d6ef..39eeb363c 100644
--- a/src/pkg/crypto/tls/handshake_server.go
+++ b/src/pkg/crypto/tls/handshake_server.go
@@ -186,7 +186,16 @@ Curves:
}
hs.cert = &config.Certificates[0]
if len(hs.clientHello.serverName) > 0 {
- hs.cert = config.getCertificateForName(hs.clientHello.serverName)
+ chi := &ClientHelloInfo{
+ CipherSuites: hs.clientHello.cipherSuites,
+ ServerName: hs.clientHello.serverName,
+ SupportedCurves: hs.clientHello.supportedCurves,
+ SupportedPoints: hs.clientHello.supportedPoints,
+ }
+ if hs.cert, err = config.getCertificate(chi); err != nil {
+ c.sendAlert(alertInternalError)
+ return false, err
+ }
}
_, hs.ecdsaOk = hs.cert.PrivateKey.(*ecdsa.PrivateKey)
diff --git a/src/pkg/crypto/tls/handshake_server_test.go b/src/pkg/crypto/tls/handshake_server_test.go
index e94dc9f99..36c79a9b0 100644
--- a/src/pkg/crypto/tls/handshake_server_test.go
+++ b/src/pkg/crypto/tls/handshake_server_test.go
@@ -257,6 +257,9 @@ type serverTest struct {
expectedPeerCerts []string
// config, if not nil, contains a custom Config to use for this test.
config *Config
+ // expectAlert, if true, indicates that a fatal alert should be returned
+ // when handshaking with the server.
+ expectAlert bool
// validate, if not nil, is a function that will be called with the
// ConnectionState of the resulting connection. It returns false if the
// ConnectionState is unacceptable.
@@ -370,7 +373,9 @@ func (test *serverTest) run(t *testing.T, write bool) {
if !write {
flows, err := test.loadData()
if err != nil {
- t.Fatalf("%s: failed to load data from %s", test.name, test.dataPath())
+ if !test.expectAlert {
+ t.Fatalf("%s: failed to load data from %s", test.name, test.dataPath())
+ }
}
for i, b := range flows {
if i%2 == 0 {
@@ -379,11 +384,17 @@ func (test *serverTest) run(t *testing.T, write bool) {
}
bb := make([]byte, len(b))
n, err := io.ReadFull(clientConn, bb)
- if err != nil {
- t.Fatalf("%s #%d: %s\nRead %d, wanted %d, got %x, wanted %x\n", test.name, i+1, err, n, len(bb), bb[:n], b)
- }
- if !bytes.Equal(b, bb) {
- t.Fatalf("%s #%d: mismatch on read: got:%x want:%x", test.name, i+1, bb, b)
+ if test.expectAlert {
+ if err == nil {
+ t.Fatal("Expected read failure but read succeeded")
+ }
+ } else {
+ if err != nil {
+ t.Fatalf("%s #%d: %s\nRead %d, wanted %d, got %x, wanted %x\n", test.name, i+1, err, n, len(bb), bb[:n], b)
+ }
+ if !bytes.Equal(b, bb) {
+ t.Fatalf("%s #%d: mismatch on read: got:%x want:%x", test.name, i+1, bb, b)
+ }
}
}
clientConn.Close()
@@ -562,6 +573,61 @@ func TestHandshakeServerSNI(t *testing.T) {
runServerTestTLS12(t, test)
}
+// TestHandshakeServerSNICertForName is similar to TestHandshakeServerSNI, but
+// tests the dynamic GetCertificate method
+func TestHandshakeServerSNIGetCertificate(t *testing.T) {
+ config := *testConfig
+
+ // Replace the NameToCertificate map with a GetCertificate function
+ nameToCert := config.NameToCertificate
+ config.NameToCertificate = nil
+ config.GetCertificate = func(clientHello *ClientHelloInfo) (*Certificate, error) {
+ cert, _ := nameToCert[clientHello.ServerName]
+ return cert, nil
+ }
+ test := &serverTest{
+ name: "SNI",
+ command: []string{"openssl", "s_client", "-no_ticket", "-cipher", "AES128-SHA", "-servername", "snitest.com"},
+ config: &config,
+ }
+ runServerTestTLS12(t, test)
+}
+
+// TestHandshakeServerSNICertForNameNotFound is similar to
+// TestHandshakeServerSNICertForName, but tests to make sure that when the
+// GetCertificate method doesn't return a cert, we fall back to what's in
+// the NameToCertificate map.
+func TestHandshakeServerSNIGetCertificateNotFound(t *testing.T) {
+ config := *testConfig
+
+ config.GetCertificate = func(clientHello *ClientHelloInfo) (*Certificate, error) {
+ return nil, nil
+ }
+ test := &serverTest{
+ name: "SNI",
+ command: []string{"openssl", "s_client", "-no_ticket", "-cipher", "AES128-SHA", "-servername", "snitest.com"},
+ config: &config,
+ }
+ runServerTestTLS12(t, test)
+}
+
+// TestHandshakeServerSNICertForNameError tests to make sure that errors in
+// GetCertificate result in a tls alert.
+func TestHandshakeServerSNIGetCertificateError(t *testing.T) {
+ config := *testConfig
+
+ config.GetCertificate = func(clientHello *ClientHelloInfo) (*Certificate, error) {
+ return nil, fmt.Errorf("Test error in GetCertificate")
+ }
+ test := &serverTest{
+ name: "SNI",
+ command: []string{"openssl", "s_client", "-no_ticket", "-cipher", "AES128-SHA", "-servername", "snitest.com"},
+ config: &config,
+ expectAlert: true,
+ }
+ runServerTestTLS12(t, test)
+}
+
// TestCipherSuiteCertPreferance ensures that we select an RSA ciphersuite with
// an RSA certificate and an ECDSA ciphersuite with an ECDSA certificate.
func TestCipherSuiteCertPreferenceECDSA(t *testing.T) {