summaryrefslogtreecommitdiff
path: root/src/mongo/gotools/common/db/openssl/openssl.go
blob: d938cf5d5327eeb5f252c65479480e005cef243a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
// Package openssl implements connections to MongoDB over SSL.
package openssl

import (
	"fmt"
	"net"
	"time"

	"github.com/mongodb/mongo-tools/common/db/kerberos"
	"github.com/mongodb/mongo-tools/common/options"
	"github.com/mongodb/mongo-tools/common/util"
	"github.com/spacemonkeygo/openssl"
	"gopkg.in/mgo.v2"
)

// SSLDBConnector is used for connecting to the database over ssl.
//
// Its fields are populated by Configure:
//   - dialInfo is the mgo dial configuration passed to mgo.DialWithInfo by
//     GetNewSession.
//   - dialError holds the error (possibly nil) returned by the most recent
//     openssl.Dial call made by the dialer; GetNewSession appends it to the
//     mgo error when both are non-nil.
//   - ctx is the openssl context returned by setupCtx and used for every dial.
type SSLDBConnector struct {
	dialInfo  *mgo.DialInfo
	dialError error
	ctx       *openssl.Ctx
}

// Configure prepares the connector for dialing the server over ssl. It builds
// the openssl context from the ssl options in opts, installs a dial function
// that connects through that context, and fills in the mgo dial info
// (timeout, auth settings, addresses, kerberos options).
//
// It returns an error only when the openssl context cannot be set up.
func (c *SSLDBConnector) Configure(opts options.ToolOptions) error {
	sslCtx, err := setupCtx(opts)
	c.ctx = sslCtx
	if err != nil {
		return fmt.Errorf("openssl configuration: %v", err)
	}

	// The zero value means "no special dial flags"; host verification is
	// skipped when either invalid certs or invalid hosts are allowed.
	var flags openssl.DialFlags
	if opts.SSLAllowInvalidCert || opts.SSLAllowInvalidHost {
		flags = openssl.InsecureSkipHostVerification
	}

	info := &mgo.DialInfo{
		Timeout:        time.Duration(opts.Timeout) * time.Second,
		Direct:         opts.Direct,
		ReplicaSetName: opts.ReplicaSetName,
		// Every dial records its result in dialError so that GetNewSession
		// can report the underlying openssl failure.
		DialServer: func(addr *mgo.ServerAddr) (net.Conn, error) {
			conn, dialErr := openssl.Dial("tcp", addr.String(), c.ctx, flags)
			c.dialError = dialErr
			return conn, dialErr
		},
		Username:  opts.Auth.Username,
		Password:  opts.Auth.Password,
		Source:    opts.GetAuthenticationDatabase(),
		Mechanism: opts.Auth.Mechanism,
	}

	// Addresses come from the connection string when one was supplied,
	// otherwise they are built from the host and port options.
	if opts.URI == nil || opts.URI.ConnectionString == "" {
		info.Addrs = util.CreateConnectionAddrs(opts.Host, opts.Port)
	} else {
		info.Addrs = opts.URI.GetConnectionAddrs()
	}

	c.dialInfo = info
	kerberos.AddKerberosOpts(opts, c.dialInfo)
	return nil
}

// GetNewSession dials the server using the dial info prepared by Configure.
// When dialing fails and the dialer recorded an openssl error, the returned
// error carries both messages and the session is nil; otherwise the session
// and error from mgo are returned unchanged.
func (c *SSLDBConnector) GetNewSession() (*mgo.Session, error) {
	session, err := mgo.DialWithInfo(c.dialInfo)
	if err == nil || c.dialError == nil {
		return session, err
	}
	return nil, fmt.Errorf("%v, openssl error: %v", err, c.dialError)
}

// dialerFunc is the signature of a function to be handed to mgo.DialInfo for
// connecting to the server: it takes the server address to dial and returns
// the established connection or an error. The dialer closure that Configure
// assigns to DialInfo.DialServer has this signature.
type dialerFunc func(addr *mgo.ServerAddr) (net.Conn, error)

// sslInitializationFunction is the signature of an optionally compiled SSL
// initialization function (e.g. one that sets fips mode). It receives the
// tool options and reports failure through its error result.
type sslInitializationFunction func(options.ToolOptions) error

// sslInitializationFunctions lists the initialization functions that setupCtx
// calls, in order, before it creates the openssl context. Nothing in this
// file appends to it.
// NOTE(review): presumably populated by optionally compiled files in this
// package — confirm.
var sslInitializationFunctions []sslInitializationFunction

// setupCtx creates an openssl.Ctx and configures it from the ssl options in
// opts. In order, it:
//
//   - runs every registered sslInitializationFunction,
//   - creates the context and sets its protocol options and cipher list,
//   - loads the PEM key file (certificate chain and private key) if given,
//   - loads the CA file if given, otherwise the system certificate authority,
//   - selects peer verification unless invalid certificates are allowed,
//   - loads the CRL file if given.
//
// It returns the configured context, or a nil context and an error describing
// the first step that failed. Errors from the initialization functions, from
// setting the cipher list and from loading the CRL file are reported rather
// than ignored, so a requested security setting can never be dropped silently.
func setupCtx(opts options.ToolOptions) (*openssl.Ctx, error) {
	// Run the optionally compiled initialization functions first; if one of
	// them fails, the requested mode is not in effect and we must not go on.
	for _, sslInitFunc := range sslInitializationFunctions {
		if err := sslInitFunc(opts); err != nil {
			return nil, fmt.Errorf("ssl initialization: %v", err)
		}
	}

	ctx, err := openssl.NewCtxWithVersion(openssl.AnyVersion)
	if err != nil {
		return nil, fmt.Errorf("failure creating new openssl context with "+
			"NewCtxWithVersion(AnyVersion): %v", err)
	}

	// OpAll - Activate all bug workaround options, to support buggy client SSL's.
	// NoSSLv2 - Disable SSL v2 support
	ctx.SetOptions(openssl.OpAll | openssl.NoSSLv2)

	// HIGH - Enable strong ciphers
	// !EXPORT - Disable export ciphers (40/56 bit)
	// !aNULL - Disable anonymous auth ciphers
	// @STRENGTH - Sort ciphers based on strength
	//
	// If the list is rejected the context would keep its default ciphers, so
	// treat that as a configuration failure.
	if err = ctx.SetCipherList("HIGH:!EXPORT:!aNULL@STRENGTH"); err != nil {
		return nil, fmt.Errorf("SetCipherList: %v", err)
	}

	// add the PEM key file with the cert and private key, if specified
	if opts.SSLPEMKeyFile != "" {
		if err = ctx.UseCertificateChainFile(opts.SSLPEMKeyFile); err != nil {
			return nil, fmt.Errorf("UseCertificateChainFile: %v", err)
		}
		if opts.SSLPEMKeyPassword != "" {
			if err = ctx.UsePrivateKeyFileWithPassword(
				opts.SSLPEMKeyFile, openssl.FiletypePEM, opts.SSLPEMKeyPassword); err != nil {
				return nil, fmt.Errorf("UsePrivateKeyFile: %v", err)
			}
		} else {
			if err = ctx.UsePrivateKeyFile(opts.SSLPEMKeyFile, openssl.FiletypePEM); err != nil {
				return nil, fmt.Errorf("UsePrivateKeyFile: %v", err)
			}
		}
		// Verify that the certificate and the key go together.
		if err = ctx.CheckPrivateKey(); err != nil {
			return nil, fmt.Errorf("CheckPrivateKey: %v", err)
		}
	}

	// If renegotiation is needed, don't return from recv() or send() until it's successful.
	// Note: this is for blocking sockets only.
	ctx.SetMode(openssl.AutoRetry)

	// Disable session caching (see SERVER-10261)
	ctx.SetSessionCacheMode(openssl.SessionCacheOff)

	// Trust either the user-supplied CA file or the system CA.
	if opts.SSLCAFile != "" {
		calist, err := openssl.LoadClientCAFile(opts.SSLCAFile)
		if err != nil {
			return nil, fmt.Errorf("LoadClientCAFile: %v", err)
		}
		ctx.SetClientCAList(calist)
		if err = ctx.LoadVerifyLocations(opts.SSLCAFile, ""); err != nil {
			return nil, fmt.Errorf("LoadVerifyLocations: %v", err)
		}
	} else {
		if err = ctx.SetupSystemCA(); err != nil {
			return nil, fmt.Errorf("Error setting up system certificate authority: %v", err)
		}
	}

	// Verify the peer's certificate unless invalid certificates are allowed.
	verifyOption := openssl.VerifyPeer
	if opts.SSLAllowInvalidCert {
		verifyOption = openssl.VerifyNone
	}
	ctx.SetVerify(verifyOption, nil)

	if opts.SSLCRLFile != "" {
		store := ctx.GetCertificateStore()
		store.SetFlags(openssl.CRLCheck)
		lookup, err := store.AddLookup(openssl.X509LookupFile())
		if err != nil {
			return nil, fmt.Errorf("AddLookup(X509LookupFile()): %v", err)
		}
		// A CRL file that cannot be loaded would otherwise leave the
		// connection without the revocation list the caller asked for.
		if err = lookup.LoadCRLFile(opts.SSLCRLFile); err != nil {
			return nil, fmt.Errorf("LoadCRLFile: %v", err)
		}
	}

	return ctx, nil
}