diff options
Diffstat (limited to 'src/mongo/gotools/src/github.com/mongodb/mongo-tools/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/connection.go')
-rw-r--r-- | src/mongo/gotools/src/github.com/mongodb/mongo-tools/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/connection.go | 109 |
1 files changed, 62 insertions, 47 deletions
diff --git a/src/mongo/gotools/src/github.com/mongodb/mongo-tools/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/connection.go b/src/mongo/gotools/src/github.com/mongodb/mongo-tools/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/connection.go index fd227fca7d9..db1a2ab514a 100644 --- a/src/mongo/gotools/src/github.com/mongodb/mongo-tools/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/connection.go +++ b/src/mongo/gotools/src/github.com/mongodb/mongo-tools/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/connection.go @@ -37,7 +37,7 @@ type connection struct { nc net.Conn // When nil, the connection is closed. addr address.Address idleTimeout time.Duration - idleDeadline time.Time + idleDeadline atomic.Value // Stores a time.Time lifetimeDeadline time.Time readTimeout time.Duration writeTimeout time.Duration @@ -87,6 +87,7 @@ func newConnection(ctx context.Context, addr address.Address, opts ...Connection // connect handles the I/O for a connection. It will dial, configure TLS, and perform // initialization handshakes. func (c *connection) connect(ctx context.Context) { + if !atomic.CompareAndSwapInt32(&c.connected, initialized, connected) { return } @@ -113,46 +114,56 @@ func (c *connection) connect(ctx context.Context) { c.bumpIdleDeadline() // running isMaster and authentication is handled by a handshaker on the configuration instance. - if c.config.handshaker != nil { - c.desc, err = c.config.handshaker.Handshake(ctx, c.addr, initConnection{c}) - if err != nil { - if c.nc != nil { - _ = c.nc.Close() - } - atomic.StoreInt32(&c.connected, disconnected) - c.connectErr = ConnectionError{Wrapped: err, init: true} - return - } - if c.config.descCallback != nil { - c.config.descCallback(c.desc) + handshaker := c.config.handshaker + if handshaker == nil { + return + } + + handshakeConn := initConnection{c} + c.desc, err = handshaker.GetDescription(ctx, c.addr, handshakeConn) + if err == nil { + err = handshaker.FinishHandshake(ctx, handshakeConn) + } + if err != nil { + if c.nc != nil { + _ = c.nc.Close() } - if len(c.desc.Compression) > 0 { - clientMethodLoop: - for _, method := range c.config.compressors { - for _, serverMethod := range c.desc.Compression { - if method != serverMethod { - continue - } + atomic.StoreInt32(&c.connected, disconnected) + c.connectErr = ConnectionError{Wrapped: err, init: true} + return + } - switch strings.ToLower(method) { - case "snappy": - c.compressor = wiremessage.CompressorSnappy - case "zlib": - c.compressor = wiremessage.CompressorZLib - c.zliblevel = wiremessage.DefaultZlibLevel - if c.config.zlibLevel != nil { - c.zliblevel = *c.config.zlibLevel - } + if c.config.descCallback != nil { + c.config.descCallback(c.desc) + } + if len(c.desc.Compression) > 0 { + clientMethodLoop: + for _, method := range c.config.compressors { + for _, serverMethod := range c.desc.Compression { + if method != serverMethod { + continue + } + + switch strings.ToLower(method) { + case "snappy": + c.compressor = wiremessage.CompressorSnappy + case "zlib": + c.compressor = wiremessage.CompressorZLib + c.zliblevel = wiremessage.DefaultZlibLevel + if c.config.zlibLevel != nil { + c.zliblevel = *c.config.zlibLevel } - break clientMethodLoop } + break clientMethodLoop } } } } -func (c *connection) connectWait() error { - <-c.connectDone +func (c *connection) wait() error { + if c.connectDone != nil { + <-c.connectDone + } return c.connectErr } @@ -259,7 +270,11 @@ func (c *connection) close() error { return nil } if c.pool == nil { - err := c.nc.Close() + var err error + + if c.nc != nil { + err = c.nc.Close() + } atomic.StoreInt32(&c.connected, disconnected) return err } @@ -268,7 +283,8 @@ func (c *connection) close() error { func (c *connection) expired() bool { now := time.Now() - if !c.idleDeadline.IsZero() && now.After(c.idleDeadline) { + idleDeadline, ok := c.idleDeadline.Load().(time.Time) + if ok && now.After(idleDeadline) { return true } @@ -281,7 +297,7 @@ func (c *connection) expired() bool { func (c *connection) bumpIdleDeadline() { if c.idleTimeout > 0 { - c.idleDeadline = time.Now().Add(c.idleTimeout) + c.idleDeadline.Store(time.Now().Add(c.idleTimeout)) } } @@ -292,10 +308,15 @@ type initConnection struct{ *connection } var _ driver.Connection = initConnection{} -func (c initConnection) Description() description.Server { return description.Server{} } -func (c initConnection) Close() error { return nil } -func (c initConnection) ID() string { return c.id } -func (c initConnection) Address() address.Address { return c.addr } +func (c initConnection) Description() description.Server { + if c.connection == nil { + return description.Server{} + } + return c.connection.desc +} +func (c initConnection) Close() error { return nil } +func (c initConnection) ID() string { return c.id } +func (c initConnection) Address() address.Address { return c.addr } func (c initConnection) LocalAddress() address.Address { if c.connection == nil || c.nc == nil { return address.Address("0.0.0.0") @@ -410,11 +431,8 @@ func (c *Connection) Close() error { defer c.s.sem.Release(1) } err := c.pool.put(c.connection) - if err != nil { - return err - } c.connection = nil - return nil + return err } // Expire closes this connection and will closeConnection the underlying socket. @@ -428,11 +446,8 @@ func (c *Connection) Expire() error { c.s.sem.Release(1) } err := c.close() - if err != nil { - return err - } c.connection = nil - return nil + return err } // Alive returns if the connection is still alive. |