summaryrefslogtreecommitdiff
path: root/src/mongo/gotools/common/db/openssl/openssl.go
diff options
context:
space:
mode:
Diffstat (limited to 'src/mongo/gotools/common/db/openssl/openssl.go')
-rw-r--r--src/mongo/gotools/common/db/openssl/openssl.go38
1 files changed, 25 insertions, 13 deletions
diff --git a/src/mongo/gotools/common/db/openssl/openssl.go b/src/mongo/gotools/common/db/openssl/openssl.go
index 1cc4c2ccd1f..db9d71c85aa 100644
--- a/src/mongo/gotools/common/db/openssl/openssl.go
+++ b/src/mongo/gotools/common/db/openssl/openssl.go
@@ -7,6 +7,7 @@ import (
"time"
"github.com/mongodb/mongo-tools/common/db/kerberos"
+ "github.com/mongodb/mongo-tools/common/log"
"github.com/mongodb/mongo-tools/common/options"
"github.com/mongodb/mongo-tools/common/util"
"github.com/spacemonkeygo/openssl"
@@ -15,17 +16,14 @@ import (
// For connecting to the database over ssl
type SSLDBConnector struct {
- dialInfo *mgo.DialInfo
- dialError error
- ctx *openssl.Ctx
+ dialInfo *mgo.DialInfo
+ ctx *openssl.Ctx
}
// Configure the connector to connect to the server over ssl. Parses the
// connection string, and sets up the correct function to dial the server
// based on the ssl options passed in.
func (self *SSLDBConnector) Configure(opts options.ToolOptions) error {
- // create the addresses to be used to connect
- connectionAddrs := util.CreateConnectionAddrs(opts.Host, opts.Port)
var err error
self.ctx, err = setupCtx(opts)
@@ -41,15 +39,26 @@ func (self *SSLDBConnector) Configure(opts options.ToolOptions) error {
// create the dialer func that will be used to connect
dialer := func(addr *mgo.ServerAddr) (net.Conn, error) {
conn, err := openssl.Dial("tcp", addr.String(), self.ctx, flags)
- self.dialError = err
- return conn, err
+ if err != nil {
+ // mgo discards dialer errors so log it now
+ log.Logvf(log.Always, "error dialing %v: %v", addr.String(), err)
+ return nil, err
+ }
+ // enable TCP keepalive
+ err = util.EnableTCPKeepAlive(conn.UnderlyingConn(), time.Duration(opts.TCPKeepAliveSeconds)*time.Second)
+ if err != nil {
+ // mgo discards dialer errors so log it now
+ log.Logvf(log.Always, "error enabling TCP keepalive on connection to %v: %v", addr.String(), err)
+ conn.Close()
+ return nil, err
+ }
+ return conn, nil
}
timeout := time.Duration(opts.Timeout) * time.Second
// set up the dial info
self.dialInfo = &mgo.DialInfo{
- Addrs: connectionAddrs,
Timeout: timeout,
Direct: opts.Direct,
ReplicaSetName: opts.ReplicaSetName,
@@ -59,6 +68,13 @@ func (self *SSLDBConnector) Configure(opts options.ToolOptions) error {
Source: opts.GetAuthenticationDatabase(),
Mechanism: opts.Auth.Mechanism,
}
+
+ // create or fetch the addresses to be used to connect
+ if opts.URI != nil && opts.URI.ConnectionString != "" {
+ self.dialInfo.Addrs = opts.URI.GetConnectionAddrs()
+ } else {
+ self.dialInfo.Addrs = util.CreateConnectionAddrs(opts.Host, opts.Port)
+ }
kerberos.AddKerberosOpts(opts, self.dialInfo)
return nil
@@ -66,11 +82,7 @@ func (self *SSLDBConnector) Configure(opts options.ToolOptions) error {
// Dial the server.
func (self *SSLDBConnector) GetNewSession() (*mgo.Session, error) {
- session, err := mgo.DialWithInfo(self.dialInfo)
- if err != nil && self.dialError != nil {
- return nil, fmt.Errorf("%v, openssl error: %v", err, self.dialError)
- }
- return session, err
+ return mgo.DialWithInfo(self.dialInfo)
}
// To be handed to mgo.DialInfo for connecting to the server.