summaryrefslogtreecommitdiff
path: root/src/mongo/gotools/src/github.com/mongodb/mongo-tools/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/connection.go
diff options
context:
space:
mode:
Diffstat (limited to 'src/mongo/gotools/src/github.com/mongodb/mongo-tools/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/connection.go')
-rw-r--r--src/mongo/gotools/src/github.com/mongodb/mongo-tools/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/connection.go109
1 files changed, 62 insertions, 47 deletions
diff --git a/src/mongo/gotools/src/github.com/mongodb/mongo-tools/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/connection.go b/src/mongo/gotools/src/github.com/mongodb/mongo-tools/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/connection.go
index fd227fca7d9..db1a2ab514a 100644
--- a/src/mongo/gotools/src/github.com/mongodb/mongo-tools/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/connection.go
+++ b/src/mongo/gotools/src/github.com/mongodb/mongo-tools/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/connection.go
@@ -37,7 +37,7 @@ type connection struct {
nc net.Conn // When nil, the connection is closed.
addr address.Address
idleTimeout time.Duration
- idleDeadline time.Time
+ idleDeadline atomic.Value // Stores a time.Time
lifetimeDeadline time.Time
readTimeout time.Duration
writeTimeout time.Duration
@@ -87,6 +87,7 @@ func newConnection(ctx context.Context, addr address.Address, opts ...Connection
// connect handles the I/O for a connection. It will dial, configure TLS, and perform
// initialization handshakes.
func (c *connection) connect(ctx context.Context) {
+
if !atomic.CompareAndSwapInt32(&c.connected, initialized, connected) {
return
}
@@ -113,46 +114,56 @@ func (c *connection) connect(ctx context.Context) {
c.bumpIdleDeadline()
// running isMaster and authentication is handled by a handshaker on the configuration instance.
- if c.config.handshaker != nil {
- c.desc, err = c.config.handshaker.Handshake(ctx, c.addr, initConnection{c})
- if err != nil {
- if c.nc != nil {
- _ = c.nc.Close()
- }
- atomic.StoreInt32(&c.connected, disconnected)
- c.connectErr = ConnectionError{Wrapped: err, init: true}
- return
- }
- if c.config.descCallback != nil {
- c.config.descCallback(c.desc)
+ handshaker := c.config.handshaker
+ if handshaker == nil {
+ return
+ }
+
+ handshakeConn := initConnection{c}
+ c.desc, err = handshaker.GetDescription(ctx, c.addr, handshakeConn)
+ if err == nil {
+ err = handshaker.FinishHandshake(ctx, handshakeConn)
+ }
+ if err != nil {
+ if c.nc != nil {
+ _ = c.nc.Close()
}
- if len(c.desc.Compression) > 0 {
- clientMethodLoop:
- for _, method := range c.config.compressors {
- for _, serverMethod := range c.desc.Compression {
- if method != serverMethod {
- continue
- }
+ atomic.StoreInt32(&c.connected, disconnected)
+ c.connectErr = ConnectionError{Wrapped: err, init: true}
+ return
+ }
- switch strings.ToLower(method) {
- case "snappy":
- c.compressor = wiremessage.CompressorSnappy
- case "zlib":
- c.compressor = wiremessage.CompressorZLib
- c.zliblevel = wiremessage.DefaultZlibLevel
- if c.config.zlibLevel != nil {
- c.zliblevel = *c.config.zlibLevel
- }
+ if c.config.descCallback != nil {
+ c.config.descCallback(c.desc)
+ }
+ if len(c.desc.Compression) > 0 {
+ clientMethodLoop:
+ for _, method := range c.config.compressors {
+ for _, serverMethod := range c.desc.Compression {
+ if method != serverMethod {
+ continue
+ }
+
+ switch strings.ToLower(method) {
+ case "snappy":
+ c.compressor = wiremessage.CompressorSnappy
+ case "zlib":
+ c.compressor = wiremessage.CompressorZLib
+ c.zliblevel = wiremessage.DefaultZlibLevel
+ if c.config.zlibLevel != nil {
+ c.zliblevel = *c.config.zlibLevel
}
- break clientMethodLoop
}
+ break clientMethodLoop
}
}
}
}
-func (c *connection) connectWait() error {
- <-c.connectDone
+func (c *connection) wait() error {
+ if c.connectDone != nil {
+ <-c.connectDone
+ }
return c.connectErr
}
@@ -259,7 +270,11 @@ func (c *connection) close() error {
return nil
}
if c.pool == nil {
- err := c.nc.Close()
+ var err error
+
+ if c.nc != nil {
+ err = c.nc.Close()
+ }
atomic.StoreInt32(&c.connected, disconnected)
return err
}
@@ -268,7 +283,8 @@ func (c *connection) close() error {
func (c *connection) expired() bool {
now := time.Now()
- if !c.idleDeadline.IsZero() && now.After(c.idleDeadline) {
+ idleDeadline, ok := c.idleDeadline.Load().(time.Time)
+ if ok && now.After(idleDeadline) {
return true
}
@@ -281,7 +297,7 @@ func (c *connection) expired() bool {
func (c *connection) bumpIdleDeadline() {
if c.idleTimeout > 0 {
- c.idleDeadline = time.Now().Add(c.idleTimeout)
+ c.idleDeadline.Store(time.Now().Add(c.idleTimeout))
}
}
@@ -292,10 +308,15 @@ type initConnection struct{ *connection }
var _ driver.Connection = initConnection{}
-func (c initConnection) Description() description.Server { return description.Server{} }
-func (c initConnection) Close() error { return nil }
-func (c initConnection) ID() string { return c.id }
-func (c initConnection) Address() address.Address { return c.addr }
+func (c initConnection) Description() description.Server {
+ if c.connection == nil {
+ return description.Server{}
+ }
+ return c.connection.desc
+}
+func (c initConnection) Close() error { return nil }
+func (c initConnection) ID() string { return c.id }
+func (c initConnection) Address() address.Address { return c.addr }
func (c initConnection) LocalAddress() address.Address {
if c.connection == nil || c.nc == nil {
return address.Address("0.0.0.0")
@@ -410,11 +431,8 @@ func (c *Connection) Close() error {
defer c.s.sem.Release(1)
}
err := c.pool.put(c.connection)
- if err != nil {
- return err
- }
c.connection = nil
- return nil
+ return err
}
// Expire closes this connection and will closeConnection the underlying socket.
@@ -428,11 +446,8 @@ func (c *Connection) Expire() error {
c.s.sem.Release(1)
}
err := c.close()
- if err != nil {
- return err
- }
c.connection = nil
- return nil
+ return err
}
// Alive returns if the connection is still alive.