summaryrefslogtreecommitdiff
path: root/src/mongo/gotools/common/db/tlsgo/tlsgo.go
diff options
context:
space:
mode:
Diffstat (limited to 'src/mongo/gotools/common/db/tlsgo/tlsgo.go')
-rw-r--r--src/mongo/gotools/common/db/tlsgo/tlsgo.go135
1 files changed, 135 insertions, 0 deletions
diff --git a/src/mongo/gotools/common/db/tlsgo/tlsgo.go b/src/mongo/gotools/common/db/tlsgo/tlsgo.go
new file mode 100644
index 00000000000..c26b7e2dc4f
--- /dev/null
+++ b/src/mongo/gotools/common/db/tlsgo/tlsgo.go
@@ -0,0 +1,135 @@
+// Copyright (C) MongoDB, Inc. 2014-present.
+//
+// Licensed under the Apache License, Version 2.0 (the "License"); you may
+// not use this file except in compliance with the License. You may obtain
+// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
+
+// Package tlsgo implements connection to MongoDB with Go native TLS.
+package tlsgo
+
+import (
+ "crypto/tls"
+ "fmt"
+ "net"
+ "strings"
+ "time"
+
+ "github.com/mongodb/mongo-tools/common/db/kerberos"
+ "github.com/mongodb/mongo-tools/common/log"
+ "github.com/mongodb/mongo-tools/common/options"
+ "github.com/mongodb/mongo-tools/common/util"
+ "gopkg.in/mgo.v2"
+)
+
+// TLSDBConnector makes a connection to the database with Go native TLS.
+type TLSDBConnector struct {
+ dialInfo *mgo.DialInfo
+ config *TLSConfig
+}
+
+// Configure the connector to connect to the server over ssl. Sets up the
+// correct function to dial the server based on the ssl options passed in.
+func (c *TLSDBConnector) Configure(opts options.ToolOptions) error {
+ if opts.SSLFipsMode {
+ return fmt.Errorf("FIPS mode not supported")
+ }
+
+ if opts.SSLCRLFile != "" {
+ return fmt.Errorf("CRL files are not supported on this platform")
+ }
+
+ c.config = NewTLSConfig()
+
+ if opts.SSLAllowInvalidCert || opts.SSLAllowInvalidHost {
+ c.config.SetInsecure(true)
+ }
+
+ if opts.SSLPEMKeyFile != "" {
+ subject, err := c.config.AddClientCertFromFile(opts.SSLPEMKeyFile, opts.SSLPEMKeyPassword)
+ if err != nil {
+ return err
+ }
+ if opts.Auth.Mechanism == "MONGODB-X509" && opts.Auth.Username == "" {
+ opts.Auth.Username = subject
+ }
+ }
+
+ if opts.SSLCAFile != "" {
+ c.config.AddCaCertFromFile(opts.SSLCAFile)
+ }
+
+ // set up the dial info
+ c.dialInfo = &mgo.DialInfo{
+ Timeout: time.Duration(opts.Timeout) * time.Second,
+ Direct: opts.Direct,
+ ReplicaSetName: opts.ReplicaSetName,
+ DialServer: c.makeDialer(opts),
+ Username: opts.Auth.Username,
+ Password: opts.Auth.Password,
+ Source: opts.GetAuthenticationDatabase(),
+ Mechanism: opts.Auth.Mechanism,
+ }
+
+ // create or fetch the addresses to be used to connect
+ if opts.URI != nil && opts.URI.ConnectionString != "" {
+ c.dialInfo.Addrs = opts.URI.GetConnectionAddrs()
+ } else {
+ c.dialInfo.Addrs = util.CreateConnectionAddrs(opts.Host, opts.Port)
+ }
+ kerberos.AddKerberosOpts(opts, c.dialInfo)
+ return nil
+}
+
+// GetNewSession dials the server.
+func (c *TLSDBConnector) GetNewSession() (*mgo.Session, error) {
+ return mgo.DialWithInfo(c.dialInfo)
+}
+
+// To be handed to mgo.DialInfo for connecting to the server.
+type dialerFunc func(addr *mgo.ServerAddr) (net.Conn, error)
+
+func (c *TLSDBConnector) makeDialer(opts options.ToolOptions) dialerFunc {
+ return func(addr *mgo.ServerAddr) (net.Conn, error) {
+ address := addr.String()
+ conn, err := net.Dial("tcp", address)
+ if err != nil {
+ // mgo discards dialer errors so log it now
+ log.Logvf(log.Always, "error dialing %v: %v", address, err)
+ return nil, err
+ }
+ // enable TCP keepalive
+ err = util.EnableTCPKeepAlive(conn, time.Duration(opts.TCPKeepAliveSeconds)*time.Second)
+ if err != nil {
+ // mgo discards dialer errors so log it now
+ log.Logvf(log.Always, "error enabling TCP keepalive on connection to %v: %v", address, err)
+ conn.Close()
+ return nil, err
+ }
+
+ tlsConfig, err := c.config.MakeConfig()
+ if err != nil {
+ return nil, err
+ }
+
+ if !tlsConfig.InsecureSkipVerify {
+ colonPos := strings.LastIndex(address, ":")
+ if colonPos == -1 {
+ colonPos = len(address)
+ }
+
+ hostname := address[:colonPos]
+ tlsConfig.ServerName = hostname
+ }
+
+ client := tls.Client(conn, tlsConfig)
+ err = client.Handshake()
+ if err != nil {
+ // mgo discards dialer errors so log it now
+ log.Logvf(log.Always, "error doing TLS handshake with %v: %v", address, err)
+ client.Close()
+ return nil, err
+ }
+
+ return client, nil
+ }
+}