diff options
Diffstat (limited to 'src/mongo/gotools/common/db/openssl/openssl.go')
-rw-r--r-- | src/mongo/gotools/common/db/openssl/openssl.go | 38 |
1 files changed, 25 insertions, 13 deletions
diff --git a/src/mongo/gotools/common/db/openssl/openssl.go b/src/mongo/gotools/common/db/openssl/openssl.go index 1cc4c2ccd1f..db9d71c85aa 100644 --- a/src/mongo/gotools/common/db/openssl/openssl.go +++ b/src/mongo/gotools/common/db/openssl/openssl.go @@ -7,6 +7,7 @@ import ( "time" "github.com/mongodb/mongo-tools/common/db/kerberos" + "github.com/mongodb/mongo-tools/common/log" "github.com/mongodb/mongo-tools/common/options" "github.com/mongodb/mongo-tools/common/util" "github.com/spacemonkeygo/openssl" @@ -15,17 +16,14 @@ import ( // For connecting to the database over ssl type SSLDBConnector struct { - dialInfo *mgo.DialInfo - dialError error - ctx *openssl.Ctx + dialInfo *mgo.DialInfo + ctx *openssl.Ctx } // Configure the connector to connect to the server over ssl. Parses the // connection string, and sets up the correct function to dial the server // based on the ssl options passed in. func (self *SSLDBConnector) Configure(opts options.ToolOptions) error { - // create the addresses to be used to connect - connectionAddrs := util.CreateConnectionAddrs(opts.Host, opts.Port) var err error self.ctx, err = setupCtx(opts) @@ -41,15 +39,26 @@ func (self *SSLDBConnector) Configure(opts options.ToolOptions) error { // create the dialer func that will be used to connect dialer := func(addr *mgo.ServerAddr) (net.Conn, error) { conn, err := openssl.Dial("tcp", addr.String(), self.ctx, flags) - self.dialError = err - return conn, err + if err != nil { + // mgo discards dialer errors so log it now + log.Logvf(log.Always, "error dialing %v: %v", addr.String(), err) + return nil, err + } + // enable TCP keepalive + err = util.EnableTCPKeepAlive(conn.UnderlyingConn(), time.Duration(opts.TCPKeepAliveSeconds)*time.Second) + if err != nil { + // mgo discards dialer errors so log it now + log.Logvf(log.Always, "error enabling TCP keepalive on connection to %v: %v", addr.String(), err) + conn.Close() + return nil, err + } + return conn, nil } timeout := time.Duration(opts.Timeout) * time.Second // set up the dial info self.dialInfo = &mgo.DialInfo{ - Addrs: connectionAddrs, Timeout: timeout, Direct: opts.Direct, ReplicaSetName: opts.ReplicaSetName, @@ -59,6 +68,13 @@ func (self *SSLDBConnector) Configure(opts options.ToolOptions) error { Source: opts.GetAuthenticationDatabase(), Mechanism: opts.Auth.Mechanism, } + + // create or fetch the addresses to be used to connect + if opts.URI != nil && opts.URI.ConnectionString != "" { + self.dialInfo.Addrs = opts.URI.GetConnectionAddrs() + } else { + self.dialInfo.Addrs = util.CreateConnectionAddrs(opts.Host, opts.Port) + } kerberos.AddKerberosOpts(opts, self.dialInfo) return nil @@ -66,11 +82,7 @@ func (self *SSLDBConnector) Configure(opts options.ToolOptions) error { // Dial the server. func (self *SSLDBConnector) GetNewSession() (*mgo.Session, error) { - session, err := mgo.DialWithInfo(self.dialInfo) - if err != nil && self.dialError != nil { - return nil, fmt.Errorf("%v, openssl error: %v", err, self.dialError) - } - return session, err + return mgo.DialWithInfo(self.dialInfo) } // To be handed to mgo.DialInfo for connecting to the server. |