summaryrefslogtreecommitdiff
path: root/src/mongo/gotools/common/db
diff options
context:
space:
mode:
authorRamon Fernandez <ramon@mongodb.com>2016-08-25 16:34:34 -0400
committerRamon Fernandez <ramon@mongodb.com>2016-08-25 16:54:18 -0400
commitc330c9991ab45e7d0685d53e699ef26dba065660 (patch)
tree3dc5cd06b5f6c7eaaa4cb20cbe763504c14a772b /src/mongo/gotools/common/db
parenteb62b862d5ebf179a1bcd9f394070e69c30188ab (diff)
downloadmongo-c330c9991ab45e7d0685d53e699ef26dba065660.tar.gz
Import tools: 5b883d86fdb4df55036d5dba2ca6f9dfa0750b44 from branch v3.3
ref: 1ac1389bda..5b883d86fd for: 3.3.12 SERVER-25814 Initial vendor import: tools
Diffstat (limited to 'src/mongo/gotools/common/db')
-rw-r--r--src/mongo/gotools/common/db/bson_stream.go140
-rw-r--r--src/mongo/gotools/common/db/bson_stream_test.go41
-rw-r--r--src/mongo/gotools/common/db/buffered_bulk.go79
-rw-r--r--src/mongo/gotools/common/db/buffered_bulk_test.go108
-rw-r--r--src/mongo/gotools/common/db/command.go210
-rw-r--r--src/mongo/gotools/common/db/connector.go52
-rw-r--r--src/mongo/gotools/common/db/connector_sasl_test.go60
-rw-r--r--src/mongo/gotools/common/db/connector_test.go134
-rw-r--r--src/mongo/gotools/common/db/db.go243
-rw-r--r--src/mongo/gotools/common/db/db_gssapi.go20
-rw-r--r--src/mongo/gotools/common/db/db_ssl.go20
-rw-r--r--src/mongo/gotools/common/db/db_test.go63
-rw-r--r--src/mongo/gotools/common/db/kerberos/gssapi.go58
-rw-r--r--src/mongo/gotools/common/db/namespaces.go159
-rw-r--r--src/mongo/gotools/common/db/openssl/openssl.go168
-rw-r--r--src/mongo/gotools/common/db/openssl/openssl_fips.go15
-rw-r--r--src/mongo/gotools/common/db/openssl/testdata/ca.pem34
-rw-r--r--src/mongo/gotools/common/db/openssl/testdata/server.pem32
-rw-r--r--src/mongo/gotools/common/db/read_preferences.go51
-rw-r--r--src/mongo/gotools/common/db/testdata/testdata.bsonbin0 -> 1800 bytes
-rw-r--r--src/mongo/gotools/common/db/write_concern.go123
-rw-r--r--src/mongo/gotools/common/db/write_concern_test.go166
22 files changed, 1976 insertions, 0 deletions
diff --git a/src/mongo/gotools/common/db/bson_stream.go b/src/mongo/gotools/common/db/bson_stream.go
new file mode 100644
index 00000000000..780d76cca19
--- /dev/null
+++ b/src/mongo/gotools/common/db/bson_stream.go
@@ -0,0 +1,140 @@
+package db
+
+import (
+ "fmt"
+ "gopkg.in/mgo.v2/bson"
+ "io"
+)
+
+// BSONSource reads documents from the underlying io.ReadCloser, Stream which
+// wraps a stream of BSON documents.
+type BSONSource struct {
+ reusableBuf []byte
+ Stream io.ReadCloser
+ err error
+}
+
+// DecodedBSONSource reads documents from the underlying io.ReadCloser, Stream which
+// wraps a stream of BSON documents.
+type DecodedBSONSource struct {
+ RawDocSource
+ err error
+}
+
+// RawDocSource wraps basic functions for reading a BSON source file.
+type RawDocSource interface {
+ LoadNext() []byte
+ Close() error
+ Err() error
+}
+
+// NewBSONSource creates a BSONSource with a reusable I/O buffer
+func NewBSONSource(in io.ReadCloser) *BSONSource {
+ return &BSONSource{make([]byte, MaxBSONSize), in, nil}
+}
+
+// NewBufferlessBSONSource creates a BSONSource without a reusable I/O buffer
+func NewBufferlessBSONSource(in io.ReadCloser) *BSONSource {
+ return &BSONSource{nil, in, nil}
+}
+
+// Close closes the BSONSource, rendering it unusable for I/O.
+// It returns an error, if any.
+func (bs *BSONSource) Close() error {
+ return bs.Stream.Close()
+}
+
+func NewDecodedBSONSource(ds RawDocSource) *DecodedBSONSource {
+ return &DecodedBSONSource{ds, nil}
+}
+
+// Err returns any error in the DecodedBSONSource or its RawDocSource.
+func (dbs *DecodedBSONSource) Err() error {
+ if dbs.err != nil {
+ return dbs.err
+ }
+ return dbs.RawDocSource.Err()
+}
+
+// Next unmarshals the next BSON document into result. Returns true if no errors
+// are encountered and false otherwise.
+func (dbs *DecodedBSONSource) Next(result interface{}) bool {
+ doc := dbs.LoadNext()
+ if doc == nil {
+ return false
+ }
+ if err := bson.Unmarshal(doc, result); err != nil {
+ dbs.err = err
+ return false
+ }
+ dbs.err = nil
+ return true
+}
+
+// LoadNext reads and returns the next BSON document in the stream. If the
+// BSONSource was created with NewBSONSource then each returned []byte will be
+// a slice of a single reused I/O buffer. If the BSONSource was created with
+// NewBufferlessBSONSource then each returend []byte will be individually
+// allocated
+func (bs *BSONSource) LoadNext() []byte {
+ var into []byte
+ if bs.reusableBuf == nil {
+ into = make([]byte, 4)
+ } else {
+ into = bs.reusableBuf
+ }
+ // read the bson object size (a 4 byte integer)
+ _, err := io.ReadAtLeast(bs.Stream, into[0:4], 4)
+ if err != nil {
+ if err != io.EOF {
+ bs.err = err
+ return nil
+ }
+ // we hit EOF right away, so we're at the end of the stream.
+ bs.err = nil
+ return nil
+ }
+
+ bsonSize := int32(
+ (uint32(into[0]) << 0) |
+ (uint32(into[1]) << 8) |
+ (uint32(into[2]) << 16) |
+ (uint32(into[3]) << 24),
+ )
+
+ // Verify that the size of the BSON object we are about to read can
+ // actually fit into the buffer that was provided. If not, either the BSON is
+ // invalid, or the buffer passed in is too small.
+ // Verify that we do not have an invalid BSON document with size < 5.
+ if bsonSize > MaxBSONSize || bsonSize < 5 {
+ bs.err = fmt.Errorf("invalid BSONSize: %v bytes", bsonSize)
+ return nil
+ }
+ if int(bsonSize) > cap(into) {
+ bigInto := make([]byte, bsonSize)
+ copy(bigInto, into)
+ into = bigInto
+ if bs.reusableBuf != nil {
+ bs.reusableBuf = bigInto
+ }
+ }
+ into = into[:int(bsonSize)]
+ _, err = io.ReadAtLeast(bs.Stream, into[4:], int(bsonSize-4))
+ if err != nil {
+ if err != io.EOF {
+ bs.err = err
+ return nil
+ }
+ // this case means we hit EOF but read a partial document,
+ // so there's a broken doc in the stream. Treat this as error.
+ bs.err = fmt.Errorf("invalid bson: %v", err)
+ return nil
+ }
+
+ bs.err = nil
+ return into
+}
+
+func (bs *BSONSource) Err() error {
+ return bs.err
+}
diff --git a/src/mongo/gotools/common/db/bson_stream_test.go b/src/mongo/gotools/common/db/bson_stream_test.go
new file mode 100644
index 00000000000..657c038e0ea
--- /dev/null
+++ b/src/mongo/gotools/common/db/bson_stream_test.go
@@ -0,0 +1,41 @@
+package db
+
+import (
+ "bytes"
+ . "github.com/smartystreets/goconvey/convey"
+ "gopkg.in/mgo.v2/bson"
+ "io/ioutil"
+ "testing"
+)
+
+func TestBufferlessBSONSource(t *testing.T) {
+ var testValues = []bson.M{
+ {"_": bson.Binary{Kind: 0x80, Data: []byte("apples")}},
+ {"_": bson.Binary{Kind: 0x80, Data: []byte("bananas")}},
+ {"_": bson.Binary{Kind: 0x80, Data: []byte("cherries")}},
+ }
+ Convey("with a buffer containing several bson documents with binary fields", t, func() {
+ writeBuf := bytes.NewBuffer(make([]byte, 0, 1024))
+ for _, tv := range testValues {
+ data, err := bson.Marshal(&tv)
+ So(err, ShouldBeNil)
+ _, err = writeBuf.Write(data)
+ So(err, ShouldBeNil)
+ }
+ Convey("that we parse correctly with a BufferlessBSONSource", func() {
+ bsonSource := NewDecodedBSONSource(
+ NewBufferlessBSONSource(ioutil.NopCloser(writeBuf)))
+ docs := []bson.M{}
+ count := 0
+ doc := &bson.M{}
+ for bsonSource.Next(doc) {
+ count++
+ docs = append(docs, *doc)
+ doc = &bson.M{}
+ }
+ So(bsonSource.Err(), ShouldBeNil)
+ So(count, ShouldEqual, len(testValues))
+ So(docs, ShouldResemble, testValues)
+ })
+ })
+}
diff --git a/src/mongo/gotools/common/db/buffered_bulk.go b/src/mongo/gotools/common/db/buffered_bulk.go
new file mode 100644
index 00000000000..be2673b5876
--- /dev/null
+++ b/src/mongo/gotools/common/db/buffered_bulk.go
@@ -0,0 +1,79 @@
+package db
+
+import (
+ "fmt"
+ "gopkg.in/mgo.v2"
+ "gopkg.in/mgo.v2/bson"
+)
+
+// BufferedBulkInserter implements a bufio.Writer-like design for queuing up
+// documents and inserting them in bulk when the given doc limit (or max
+// message size) is reached. Must be flushed at the end to ensure that all
+// documents are written.
+type BufferedBulkInserter struct {
+ bulk *mgo.Bulk
+ collection *mgo.Collection
+ continueOnError bool
+ docLimit int
+ byteCount int
+ docCount int
+ unordered bool
+}
+
+// NewBufferedBulkInserter returns an initialized BufferedBulkInserter
+// for writing.
+func NewBufferedBulkInserter(collection *mgo.Collection, docLimit int,
+ continueOnError bool) *BufferedBulkInserter {
+ bb := &BufferedBulkInserter{
+ collection: collection,
+ continueOnError: continueOnError,
+ docLimit: docLimit,
+ }
+ bb.resetBulk()
+ return bb
+}
+
+func (bb *BufferedBulkInserter) Unordered() {
+ bb.unordered = true
+ bb.bulk.Unordered()
+}
+
+// throw away the old bulk and init a new one
+func (bb *BufferedBulkInserter) resetBulk() {
+ bb.bulk = bb.collection.Bulk()
+ if bb.continueOnError || bb.unordered {
+ bb.bulk.Unordered()
+ }
+ bb.byteCount = 0
+ bb.docCount = 0
+}
+
+// Insert adds a document to the buffer for bulk insertion. If the buffer is
+// full, the bulk insert is made, returning any error that occurs.
+func (bb *BufferedBulkInserter) Insert(doc interface{}) error {
+ rawBytes, err := bson.Marshal(doc)
+ if err != nil {
+ return fmt.Errorf("bson encoding error: %v", err)
+ }
+ // flush if we are full
+ if bb.docCount >= bb.docLimit || bb.byteCount+len(rawBytes) > MaxBSONSize {
+ err = bb.Flush()
+ }
+ // buffer the document
+ bb.docCount++
+ bb.byteCount += len(rawBytes)
+ bb.bulk.Insert(bson.Raw{Data: rawBytes})
+ return err
+}
+
+// Flush writes all buffered documents in one bulk insert then resets the buffer.
+func (bb *BufferedBulkInserter) Flush() error {
+ if bb.docCount == 0 {
+ return nil
+ }
+ defer bb.resetBulk()
+ if _, err := bb.bulk.Run(); err != nil {
+ return err
+ }
+ return nil
+}
diff --git a/src/mongo/gotools/common/db/buffered_bulk_test.go b/src/mongo/gotools/common/db/buffered_bulk_test.go
new file mode 100644
index 00000000000..d4664dadd20
--- /dev/null
+++ b/src/mongo/gotools/common/db/buffered_bulk_test.go
@@ -0,0 +1,108 @@
+package db
+
+import (
+ "github.com/mongodb/mongo-tools/common/options"
+ "github.com/mongodb/mongo-tools/common/testutil"
+ . "github.com/smartystreets/goconvey/convey"
+ "gopkg.in/mgo.v2/bson"
+ "testing"
+)
+
+func TestBufferedBulkInserterInserts(t *testing.T) {
+ var bufBulk *BufferedBulkInserter
+
+ testutil.VerifyTestType(t, "db")
+
+ Convey("With a valid session", t, func() {
+ opts := options.ToolOptions{
+ Connection: &options.Connection{
+ Port: DefaultTestPort,
+ },
+ SSL: &options.SSL{},
+ Auth: &options.Auth{},
+ }
+ provider, err := NewSessionProvider(opts)
+ session, err := provider.GetSession()
+ So(session, ShouldNotBeNil)
+ So(err, ShouldBeNil)
+
+ Convey("using a test collection and a doc limit of 3", func() {
+ testCol := session.DB("tools-test").C("bulk1")
+ bufBulk = NewBufferedBulkInserter(testCol, 3, false)
+ So(bufBulk, ShouldNotBeNil)
+
+ Convey("inserting 10 documents into the BufferedBulkInserter", func() {
+ flushCount := 0
+ for i := 0; i < 10; i++ {
+ So(bufBulk.Insert(bson.D{}), ShouldBeNil)
+ if bufBulk.docCount%3 == 0 {
+ flushCount++
+ }
+ }
+
+ Convey("should have flushed 3 times with one doc still buffered", func() {
+ So(flushCount, ShouldEqual, 3)
+ So(bufBulk.byteCount, ShouldBeGreaterThan, 0)
+ So(bufBulk.docCount, ShouldEqual, 1)
+ })
+ })
+ })
+
+ Convey("using a test collection and a doc limit of 1", func() {
+ testCol := session.DB("tools-test").C("bulk2")
+ bufBulk = NewBufferedBulkInserter(testCol, 1, false)
+ So(bufBulk, ShouldNotBeNil)
+
+ Convey("inserting 10 documents into the BufferedBulkInserter and flushing", func() {
+ for i := 0; i < 10; i++ {
+ So(bufBulk.Insert(bson.D{}), ShouldBeNil)
+ }
+ So(bufBulk.Flush(), ShouldBeNil)
+
+ Convey("should have no docs buffered", func() {
+ So(bufBulk.docCount, ShouldEqual, 0)
+ So(bufBulk.byteCount, ShouldEqual, 0)
+ })
+ })
+ })
+
+ Convey("using a test collection and a doc limit of 1000", func() {
+ testCol := session.DB("tools-test").C("bulk3")
+ bufBulk = NewBufferedBulkInserter(testCol, 100, false)
+ So(bufBulk, ShouldNotBeNil)
+
+ Convey("inserting 1,000,000 documents into the BufferedBulkInserter and flushing", func() {
+ session.SetSocketTimeout(0)
+
+ for i := 0; i < 1000000; i++ {
+ bufBulk.Insert(bson.M{"_id": i})
+ }
+ So(bufBulk.Flush(), ShouldBeNil)
+
+ Convey("should have inserted all of the documents", func() {
+ count, err := testCol.Count()
+ So(err, ShouldBeNil)
+ So(count, ShouldEqual, 1000000)
+
+ // test values
+ testDoc := bson.M{}
+ err = testCol.Find(bson.M{"_id": 477232}).One(&testDoc)
+ So(err, ShouldBeNil)
+ So(testDoc["_id"], ShouldEqual, 477232)
+ err = testCol.Find(bson.M{"_id": 999999}).One(&testDoc)
+ So(err, ShouldBeNil)
+ So(testDoc["_id"], ShouldEqual, 999999)
+ err = testCol.Find(bson.M{"_id": 1}).One(&testDoc)
+ So(err, ShouldBeNil)
+ So(testDoc["_id"], ShouldEqual, 1)
+
+ })
+ })
+ })
+
+ Reset(func() {
+ session.DB("tools-test").DropDatabase()
+ })
+ })
+
+}
diff --git a/src/mongo/gotools/common/db/command.go b/src/mongo/gotools/common/db/command.go
new file mode 100644
index 00000000000..6016d25f5fb
--- /dev/null
+++ b/src/mongo/gotools/common/db/command.go
@@ -0,0 +1,210 @@
+package db
+
+import (
+ "fmt"
+ "gopkg.in/mgo.v2"
+ "gopkg.in/mgo.v2/bson"
+ "strings"
+)
+
+// Query flags
+const (
+ Snapshot = 1 << iota
+ LogReplay
+ Prefetch
+)
+
+type NodeType string
+
+const (
+ Mongos NodeType = "mongos"
+ Standalone = "standalone"
+ ReplSet = "replset"
+ Unknown = "unknown"
+)
+
+// CommandRunner exposes functions that can be run against a server
+type CommandRunner interface {
+ Run(command interface{}, out interface{}, database string) error
+ FindOne(db, collection string, skip int, query interface{}, sort []string, into interface{}, opts int) error
+ Remove(db, collection string, query interface{}) error
+ DatabaseNames() ([]string, error)
+ CollectionNames(db string) ([]string, error)
+}
+
+// Remove removes all documents matched by query q in the db database and c collection.
+func (sp *SessionProvider) Remove(db, c string, q interface{}) error {
+ session, err := sp.GetSession()
+ if err != nil {
+ return err
+ }
+ defer session.Close()
+ _, err = session.DB(db).C(c).RemoveAll(q)
+ return err
+}
+
+// Run issues the provided command on the db database and unmarshals its result
+// into out.
+func (sp *SessionProvider) Run(command interface{}, out interface{}, db string) error {
+ session, err := sp.GetSession()
+ if err != nil {
+ return err
+ }
+ defer session.Close()
+ return session.DB(db).Run(command, out)
+}
+
+// DatabaseNames returns a slice containing the names of all the databases on the
+// connected server.
+func (sp *SessionProvider) DatabaseNames() ([]string, error) {
+ session, err := sp.GetSession()
+ if err != nil {
+ return nil, err
+ }
+ session.SetSocketTimeout(0)
+ defer session.Close()
+ return session.DatabaseNames()
+}
+
+// CollectionNames returns the names of all the collections in the dbName database.
+func (sp *SessionProvider) CollectionNames(dbName string) ([]string, error) {
+ session, err := sp.GetSession()
+ if err != nil {
+ return nil, err
+ }
+ defer session.Close()
+ session.SetSocketTimeout(0)
+ return session.DB(dbName).CollectionNames()
+}
+
+// GetNodeType checks if the connected SessionProvider is a mongos, standalone, or replset,
+// by looking at the result of calling isMaster.
+func (sp *SessionProvider) GetNodeType() (NodeType, error) {
+ session, err := sp.GetSession()
+ if err != nil {
+ return Unknown, err
+ }
+ session.SetSocketTimeout(0)
+ defer session.Close()
+ masterDoc := struct {
+ SetName interface{} `bson:"setName"`
+ Hosts interface{} `bson:"hosts"`
+ Msg string `bson:"msg"`
+ }{}
+ err = session.Run("isMaster", &masterDoc)
+ if err != nil {
+ return Unknown, err
+ }
+
+ if masterDoc.SetName != nil || masterDoc.Hosts != nil {
+ return ReplSet, nil
+ } else if masterDoc.Msg == "isdbgrid" {
+ // isdbgrid is always the msg value when calling isMaster on a mongos
+ // see http://docs.mongodb.org/manual/core/sharded-cluster-query-router/
+ return Mongos, nil
+ }
+ return Standalone, nil
+}
+
+// IsReplicaSet returns a boolean which is true if the connected server is part
+// of a replica set.
+func (sp *SessionProvider) IsReplicaSet() (bool, error) {
+ nodeType, err := sp.GetNodeType()
+ if err != nil {
+ return false, err
+ }
+ return nodeType == ReplSet, nil
+}
+
+// IsMongos returns true if the connected server is a mongos.
+func (sp *SessionProvider) IsMongos() (bool, error) {
+ nodeType, err := sp.GetNodeType()
+ if err != nil {
+ return false, err
+ }
+ return nodeType == Mongos, nil
+}
+
+// SupportsRepairCursor takes in an example db and collection name and
+// returns true if the connected server supports the repairCursor command.
+// It returns false and the error that occurred if it is not supported.
+func (sp *SessionProvider) SupportsRepairCursor(db, collection string) (bool, error) {
+ session, err := sp.GetSession()
+ if err != nil {
+ return false, err
+ }
+ session.SetSocketTimeout(0)
+ defer session.Close()
+
+ // This check is slightly hacky, but necessary to allow users to run repair without
+ // permissions to all collections. There are multiple reasons a repair command could fail,
+ // but we are only interested in the ones that imply that the repair command is not
+ // usable by the connected server. If we do not get one of these specific error messages,
+ // we will let the error happen again later.
+ repairIter := session.DB(db).C(collection).Repair()
+ repairIter.Next(bson.D{})
+ err = repairIter.Err()
+ if err == nil {
+ return true, nil
+ }
+ if strings.Index(err.Error(), "no such cmd: repairCursor") > -1 {
+ // return a helpful error message for early server versions
+ return false, fmt.Errorf("--repair flag cannot be used on mongodb versions before 2.7.8")
+ }
+ if strings.Index(err.Error(), "repair iterator not supported") > -1 {
+ // helpful error message if the storage engine does not support repair (WiredTiger)
+ return false, fmt.Errorf("--repair is not supported by the connected storage engine")
+ }
+
+ return true, nil
+}
+
+// SupportsWriteCommands returns true if the connected server supports write
+// commands, returns false otherwise.
+func (sp *SessionProvider) SupportsWriteCommands() (bool, error) {
+ session, err := sp.GetSession()
+ if err != nil {
+ return false, err
+ }
+ session.SetSocketTimeout(0)
+ defer session.Close()
+ masterDoc := struct {
+ Ok int `bson:"ok"`
+ MaxWire int `bson:"maxWireVersion"`
+ }{}
+ err = session.Run("isMaster", &masterDoc)
+ if err != nil {
+ return false, err
+ }
+ // the connected server supports write commands if
+ // the maxWriteVersion field is present
+ return (masterDoc.Ok == 1 && masterDoc.MaxWire >= 2), nil
+}
+
+// FindOne retuns the first document in the collection and database that matches
+// the query after skip, sort and query flags are applied.
+func (sp *SessionProvider) FindOne(db, collection string, skip int, query interface{}, sort []string, into interface{}, flags int) error {
+ session, err := sp.GetSession()
+ if err != nil {
+ return err
+ }
+ defer session.Close()
+
+ q := session.DB(db).C(collection).Find(query).Sort(sort...).Skip(skip)
+ q = ApplyFlags(q, session, flags)
+ return q.One(into)
+}
+
+// ApplyFlags applies flags to the given query session.
+func ApplyFlags(q *mgo.Query, session *mgo.Session, flags int) *mgo.Query {
+ if flags&Snapshot > 0 {
+ q = q.Snapshot()
+ }
+ if flags&LogReplay > 0 {
+ q = q.LogReplay()
+ }
+ if flags&Prefetch > 0 {
+ session.SetPrefetch(1.0)
+ }
+ return q
+}
diff --git a/src/mongo/gotools/common/db/connector.go b/src/mongo/gotools/common/db/connector.go
new file mode 100644
index 00000000000..85e8d6e653c
--- /dev/null
+++ b/src/mongo/gotools/common/db/connector.go
@@ -0,0 +1,52 @@
+package db
+
+import (
+ "time"
+
+ "github.com/mongodb/mongo-tools/common/options"
+ "github.com/mongodb/mongo-tools/common/util"
+ "gopkg.in/mgo.v2"
+)
+
+// Interface type for connecting to the database.
+type DBConnector interface {
+ // configure, based on the options passed in
+ Configure(options.ToolOptions) error
+
+ // dial the database and get a fresh new session
+ GetNewSession() (*mgo.Session, error)
+}
+
+// Basic connector for dialing the database, with no authentication.
+type VanillaDBConnector struct {
+ dialInfo *mgo.DialInfo
+}
+
+// Configure sets up the db connector using the options in opts. It parses the
+// connection string and then sets up the dial information using the default
+// dial timeout.
+func (self *VanillaDBConnector) Configure(opts options.ToolOptions) error {
+ // create the addresses to be used to connect
+ connectionAddrs := util.CreateConnectionAddrs(opts.Host, opts.Port)
+
+ timeout := time.Duration(opts.Timeout) * time.Second
+
+ // set up the dial info
+ self.dialInfo = &mgo.DialInfo{
+ Addrs: connectionAddrs,
+ Timeout: timeout,
+ Direct: opts.Direct,
+ ReplicaSetName: opts.ReplicaSetName,
+ Username: opts.Auth.Username,
+ Password: opts.Auth.Password,
+ Source: opts.GetAuthenticationDatabase(),
+ Mechanism: opts.Auth.Mechanism,
+ }
+ return nil
+}
+
+// GetNewSession connects to the server and returns the established session and any
+// error encountered.
+func (self *VanillaDBConnector) GetNewSession() (*mgo.Session, error) {
+ return mgo.DialWithInfo(self.dialInfo)
+}
diff --git a/src/mongo/gotools/common/db/connector_sasl_test.go b/src/mongo/gotools/common/db/connector_sasl_test.go
new file mode 100644
index 00000000000..c585c92f842
--- /dev/null
+++ b/src/mongo/gotools/common/db/connector_sasl_test.go
@@ -0,0 +1,60 @@
+// +build sasl
+
+package db
+
+// This file runs Kerberos tests if build with sasl is enabled
+
+import (
+ "fmt"
+ "github.com/mongodb/mongo-tools/common/db/kerberos"
+ "github.com/mongodb/mongo-tools/common/options"
+ "github.com/mongodb/mongo-tools/common/testutil"
+ . "github.com/smartystreets/goconvey/convey"
+ "gopkg.in/mgo.v2/bson"
+ "os"
+ "runtime"
+ "testing"
+)
+
+var (
+ KERBEROS_HOST = "ldaptest.10gen.cc"
+ KERBEROS_USER = "drivers@LDAPTEST.10GEN.CC"
+)
+
+func TestKerberosDBConnector(t *testing.T) {
+ Convey("should be able to successfully connect", t, func() {
+ connector := &kerberos.KerberosDBConnector{}
+
+ opts := options.ToolOptions{
+ Connection: &options.Connection{
+ Host: KERBEROS_HOST,
+ Port: "27017",
+ },
+ Auth: &options.Auth{
+ Username: KERBEROS_USER,
+ },
+ Kerberos: &options.Kerberos{
+ Service: "mongodb",
+ ServiceHost: KERBEROS_HOST,
+ },
+ }
+
+ if runtime.GOOS == "windows" {
+ opts.Auth.Password = os.Getenv(testutil.WinKerberosPwdEnv)
+ if opts.Auth.Password == "" {
+ panic(fmt.Sprintf("Need to set %v environment variable to run kerberos tests on windows",
+ testutil.WinKerberosPwdEnv))
+ }
+ }
+
+ So(connector.Configure(opts), ShouldBeNil)
+
+ session, err := connector.GetNewSession()
+ So(err, ShouldBeNil)
+ So(session, ShouldNotBeNil)
+
+ n, err := session.DB("kerberos").C("test").Find(bson.M{}).Count()
+ So(err, ShouldBeNil)
+ So(n, ShouldEqual, 1)
+ })
+}
diff --git a/src/mongo/gotools/common/db/connector_test.go b/src/mongo/gotools/common/db/connector_test.go
new file mode 100644
index 00000000000..a04cd9cfb4e
--- /dev/null
+++ b/src/mongo/gotools/common/db/connector_test.go
@@ -0,0 +1,134 @@
+package db
+
+import (
+ "github.com/mongodb/mongo-tools/common/options"
+ "github.com/mongodb/mongo-tools/common/testutil"
+ . "github.com/smartystreets/goconvey/convey"
+ "gopkg.in/mgo.v2"
+ "testing"
+ "time"
+)
+
+func TestVanillaDBConnector(t *testing.T) {
+
+ testutil.VerifyTestType(t, "db")
+
+ Convey("With a vanilla db connector", t, func() {
+
+ var connector *VanillaDBConnector
+
+ Convey("calling Configure should populate the addrs and dial timeout"+
+ " appropriately with no error", func() {
+
+ connector = &VanillaDBConnector{}
+
+ opts := options.ToolOptions{
+ Connection: &options.Connection{
+ Host: "host1,host2",
+ Port: "20000",
+ },
+ Auth: &options.Auth{},
+ }
+ So(connector.Configure(opts), ShouldBeNil)
+ So(connector.dialInfo.Addrs, ShouldResemble,
+ []string{"host1:20000", "host2:20000"})
+ So(connector.dialInfo.Timeout, ShouldResemble, time.Duration(opts.Timeout)*time.Second)
+
+ })
+
+ Convey("calling GetNewSession with a running mongod should connect"+
+ " successfully", func() {
+
+ connector = &VanillaDBConnector{}
+
+ opts := options.ToolOptions{
+ Connection: &options.Connection{
+ Host: "localhost",
+ Port: DefaultTestPort,
+ },
+ Auth: &options.Auth{},
+ }
+ So(connector.Configure(opts), ShouldBeNil)
+
+ session, err := connector.GetNewSession()
+ So(err, ShouldBeNil)
+ So(session, ShouldNotBeNil)
+ session.Close()
+
+ })
+
+ })
+
+}
+
+func TestVanillaDBConnectorWithAuth(t *testing.T) {
+ testutil.VerifyTestType(t, "auth")
+ session, err := mgo.Dial("localhost:33333")
+ if err != nil {
+ t.Fatalf("error dialing server: %v", err)
+ }
+
+ err = testutil.CreateUserAdmin(session)
+ So(err, ShouldBeNil)
+ err = testutil.CreateUserWithRole(session, "cAdmin", "password",
+ mgo.RoleClusterAdmin, true)
+ So(err, ShouldBeNil)
+ session.Close()
+
+ Convey("With a vanilla db connector and a mongod running with"+
+ " auth", t, func() {
+
+ var connector *VanillaDBConnector
+
+ Convey("connecting without authentication should not be able"+
+ " to run commands", func() {
+
+ connector = &VanillaDBConnector{}
+
+ opts := options.ToolOptions{
+ Connection: &options.Connection{
+ Host: "localhost",
+ Port: DefaultTestPort,
+ },
+ Auth: &options.Auth{},
+ }
+ So(connector.Configure(opts), ShouldBeNil)
+
+ session, err := connector.GetNewSession()
+ So(err, ShouldBeNil)
+ So(session, ShouldNotBeNil)
+
+ So(session.DB("admin").Run("top", &struct{}{}), ShouldNotBeNil)
+ session.Close()
+
+ })
+
+ Convey("connecting with authentication should succeed and"+
+ " authenticate properly", func() {
+
+ connector = &VanillaDBConnector{}
+
+ opts := options.ToolOptions{
+ Connection: &options.Connection{
+ Host: "localhost",
+ Port: DefaultTestPort,
+ },
+ Auth: &options.Auth{
+ Username: "cAdmin",
+ Password: "password",
+ },
+ }
+ So(connector.Configure(opts), ShouldBeNil)
+
+ session, err := connector.GetNewSession()
+ So(err, ShouldBeNil)
+ So(session, ShouldNotBeNil)
+
+ So(session.DB("admin").Run("top", &struct{}{}), ShouldBeNil)
+ session.Close()
+
+ })
+
+ })
+
+}
diff --git a/src/mongo/gotools/common/db/db.go b/src/mongo/gotools/common/db/db.go
new file mode 100644
index 00000000000..a3207c4b467
--- /dev/null
+++ b/src/mongo/gotools/common/db/db.go
@@ -0,0 +1,243 @@
+// Package db implements generic connection to MongoDB, and contains
+// subpackages for specific methods of connection.
+package db
+
+import (
+ "github.com/mongodb/mongo-tools/common/options"
+ "github.com/mongodb/mongo-tools/common/password"
+ "gopkg.in/mgo.v2"
+ "gopkg.in/mgo.v2/bson"
+
+ "fmt"
+ "io"
+ "strings"
+ "sync"
+)
+
+type (
+ sessionFlag uint32
+ // Used to get appropriate the DBConnector(s) based on opts
+ GetConnectorFunc func(opts options.ToolOptions) DBConnector
+)
+
+// Session flags.
+const (
+ None sessionFlag = 0
+ Monotonic sessionFlag = 1 << iota
+ DisableSocketTimeout
+)
+
+// MongoDB enforced limits.
+const (
+ MaxBSONSize = 16 * 1024 * 1024 // 16MB - maximum BSON document size
+)
+
+// Default port for integration tests
+const (
+ DefaultTestPort = "33333"
+)
+
+const (
+ ErrLostConnection = "lost connection to server"
+ ErrNoReachableServers = "no reachable servers"
+ ErrNsNotFound = "ns not found"
+ // replication errors list the replset name if we are talking to a mongos,
+ // so we can only check for this universal prefix
+ ErrReplTimeoutPrefix = "waiting for replication timed out"
+ ErrCouldNotContactPrimaryPrefix = "could not contact primary for replica set"
+ ErrCouldNotFindPrimaryPrefix = `could not find host matching read preference { mode: "primary"`
+ ErrUnableToTargetPrefix = "unable to target"
+ ErrNotMaster = "not master"
+ ErrConnectionRefusedSuffix = "Connection refused"
+)
+
+var (
+ GetConnectorFuncs = []GetConnectorFunc{}
+)
+
+// Used to manage database sessions
+type SessionProvider struct {
+
+ // For connecting to the database
+ connector DBConnector
+
+ // used to avoid a race condition around creating the master session
+ masterSessionLock sync.Mutex
+
+ // the master session to use for connection pooling
+ masterSession *mgo.Session
+
+ // flags for generating the master session
+ bypassDocumentValidation bool
+ flags sessionFlag
+ readPreference mgo.Mode
+ tags bson.D
+}
+
+// ApplyOpsResponse represents the response from an 'applyOps' command.
+type ApplyOpsResponse struct {
+ Ok bool `bson:"ok"`
+ ErrMsg string `bson:"errmsg"`
+}
+
+// Oplog represents a MongoDB oplog document.
+type Oplog struct {
+ Timestamp bson.MongoTimestamp `bson:"ts"`
+ HistoryID int64 `bson:"h"`
+ Version int `bson:"v"`
+ Operation string `bson:"op"`
+ Namespace string `bson:"ns"`
+ Object bson.D `bson:"o"`
+ Query bson.D `bson:"o2"`
+}
+
+// Returns a session connected to the database server for which the
+// session provider is configured.
+func (self *SessionProvider) GetSession() (*mgo.Session, error) {
+ self.masterSessionLock.Lock()
+ defer self.masterSessionLock.Unlock()
+
+ // The master session is initialized
+ if self.masterSession != nil {
+ return self.masterSession.Copy(), nil
+ }
+
+ // initialize the provider's master session
+ var err error
+ self.masterSession, err = self.connector.GetNewSession()
+ if err != nil {
+ return nil, fmt.Errorf("error connecting to db server: %v", err)
+ }
+
+ // update masterSession based on flags
+ self.refresh()
+
+ // copy the provider's master session, for connection pooling
+ return self.masterSession.Copy(), nil
+}
+
+// refresh is a helper for modifying the session based on the
+// session provider flags passed in with SetFlags.
+// This helper assumes a lock is already taken.
+func (self *SessionProvider) refresh() {
+ // handle bypassDocumentValidation
+ self.masterSession.SetBypassValidation(self.bypassDocumentValidation)
+
+ // handle readPreference
+ self.masterSession.SetMode(self.readPreference, true)
+
+ // disable timeouts
+ if (self.flags & DisableSocketTimeout) > 0 {
+ self.masterSession.SetSocketTimeout(0)
+ }
+ if self.tags != nil {
+ self.masterSession.SelectServers(self.tags)
+ }
+}
+
+// SetFlags allows certain modifications to the masterSession after initial creation.
+func (self *SessionProvider) SetFlags(flagBits sessionFlag) {
+ self.masterSessionLock.Lock()
+ defer self.masterSessionLock.Unlock()
+
+ self.flags = flagBits
+
+ // make sure we update the master session if one already exists
+ if self.masterSession != nil {
+ self.refresh()
+ }
+}
+
+// SetReadPreference sets the read preference mode in the SessionProvider
+// and eventually in the masterSession
+func (self *SessionProvider) SetReadPreference(pref mgo.Mode) {
+ self.masterSessionLock.Lock()
+ defer self.masterSessionLock.Unlock()
+
+ self.readPreference = pref
+
+ if self.masterSession != nil {
+ self.refresh()
+ }
+}
+
+// SetBypassDocumentValidation sets whether to bypass document validation in the SessionProvider
+// and eventually in the masterSession
+func (self *SessionProvider) SetBypassDocumentValidation(bypassDocumentValidation bool) {
+ self.masterSessionLock.Lock()
+ defer self.masterSessionLock.Unlock()
+
+ self.bypassDocumentValidation = bypassDocumentValidation
+
+ if self.masterSession != nil {
+ self.refresh()
+ }
+}
+
+// SetTags sets the server selection tags in the SessionProvider
+// and eventually in the masterSession
+func (self *SessionProvider) SetTags(tags bson.D) {
+ self.masterSessionLock.Lock()
+ defer self.masterSessionLock.Unlock()
+
+ self.tags = tags
+
+ if self.masterSession != nil {
+ self.refresh()
+ }
+}
+
+// NewSessionProvider constructs a session provider but does not attempt to
+// create the initial session.
+func NewSessionProvider(opts options.ToolOptions) (*SessionProvider, error) {
+ // create the provider
+ provider := &SessionProvider{
+ readPreference: mgo.Primary,
+ bypassDocumentValidation: false,
+ }
+
+ // finalize auth options, filling in missing passwords
+ if opts.Auth.ShouldAskForPassword() {
+ opts.Auth.Password = password.Prompt()
+ }
+
+ // create the connector for dialing the database
+ provider.connector = getConnector(opts)
+
+ // configure the connector
+ err := provider.connector.Configure(opts)
+ if err != nil {
+ return nil, fmt.Errorf("error configuring the connector: %v", err)
+ }
+ return provider, nil
+}
+
+// IsConnectionError returns a boolean indicating if a given error is due to
+// an error in an underlying DB connection (as opposed to some other write
+// failure such as a duplicate key error)
+func IsConnectionError(err error) bool {
+ if err == nil {
+ return false
+ }
+ if err.Error() == ErrNoReachableServers ||
+ err.Error() == io.EOF.Error() ||
+ strings.HasPrefix(err.Error(), ErrReplTimeoutPrefix) ||
+ strings.HasPrefix(err.Error(), ErrCouldNotContactPrimaryPrefix) ||
+ strings.HasPrefix(err.Error(), ErrCouldNotFindPrimaryPrefix) ||
+ strings.HasPrefix(err.Error(), ErrUnableToTargetPrefix) ||
+ err.Error() == ErrNotMaster ||
+ strings.HasSuffix(err.Error(), ErrConnectionRefusedSuffix) {
+ return true
+ }
+ return false
+}
+
+// Get the right type of connector, based on the options
+func getConnector(opts options.ToolOptions) DBConnector {
+ for _, getConnectorFunc := range GetConnectorFuncs {
+ if connector := getConnectorFunc(opts); connector != nil {
+ return connector
+ }
+ }
+ return &VanillaDBConnector{}
+}
diff --git a/src/mongo/gotools/common/db/db_gssapi.go b/src/mongo/gotools/common/db/db_gssapi.go
new file mode 100644
index 00000000000..656e81987a9
--- /dev/null
+++ b/src/mongo/gotools/common/db/db_gssapi.go
@@ -0,0 +1,20 @@
+// +build sasl
+
+package db
+
+import (
+ "github.com/mongodb/mongo-tools/common/db/kerberos"
+ "github.com/mongodb/mongo-tools/common/options"
+)
+
+func init() {
+ GetConnectorFuncs = append(GetConnectorFuncs, getGSSAPIConnector)
+}
+
+// return the Kerberos DB connector if using SSL, otherwise return nil.
+func getGSSAPIConnector(opts options.ToolOptions) DBConnector {
+ if opts.Auth.Mechanism == "GSSAPI" {
+ return &kerberos.KerberosDBConnector{}
+ }
+ return nil
+}
diff --git a/src/mongo/gotools/common/db/db_ssl.go b/src/mongo/gotools/common/db/db_ssl.go
new file mode 100644
index 00000000000..68d3850b525
--- /dev/null
+++ b/src/mongo/gotools/common/db/db_ssl.go
@@ -0,0 +1,20 @@
+// +build ssl
+
+package db
+
+import (
+ "github.com/mongodb/mongo-tools/common/db/openssl"
+ "github.com/mongodb/mongo-tools/common/options"
+)
+
+func init() {
+ GetConnectorFuncs = append(GetConnectorFuncs, getSSLConnector)
+}
+
+// return the SSL DB connector if using SSL, otherwise, return nil.
+func getSSLConnector(opts options.ToolOptions) DBConnector {
+ if opts.SSL.UseSSL {
+ return &openssl.SSLDBConnector{}
+ }
+ return nil
+}
diff --git a/src/mongo/gotools/common/db/db_test.go b/src/mongo/gotools/common/db/db_test.go
new file mode 100644
index 00000000000..59d1c53b929
--- /dev/null
+++ b/src/mongo/gotools/common/db/db_test.go
@@ -0,0 +1,63 @@
+package db
+
+import (
+ "github.com/mongodb/mongo-tools/common/options"
+ "github.com/mongodb/mongo-tools/common/testutil"
+ . "github.com/smartystreets/goconvey/convey"
+ "reflect"
+ "testing"
+)
+
+func TestNewSessionProvider(t *testing.T) {
+
+ testutil.VerifyTestType(t, "db")
+
+ Convey("When initializing a session provider", t, func() {
+
+ Convey("with the standard options, a provider with a standard"+
+ " connector should be returned", func() {
+ opts := options.ToolOptions{
+ Connection: &options.Connection{
+ Port: DefaultTestPort,
+ },
+ SSL: &options.SSL{},
+ Auth: &options.Auth{},
+ }
+ provider, err := NewSessionProvider(opts)
+ So(err, ShouldBeNil)
+ So(reflect.TypeOf(provider.connector), ShouldEqual,
+ reflect.TypeOf(&VanillaDBConnector{}))
+
+ })
+
+ Convey("the master session should be successfully "+
+ " initialized", func() {
+ opts := options.ToolOptions{
+ Connection: &options.Connection{
+ Port: DefaultTestPort,
+ },
+ SSL: &options.SSL{},
+ Auth: &options.Auth{},
+ }
+ provider, err := NewSessionProvider(opts)
+ So(err, ShouldBeNil)
+ So(provider.masterSession, ShouldBeNil)
+ session, err := provider.GetSession()
+ So(err, ShouldBeNil)
+ So(session, ShouldNotBeNil)
+ So(provider.masterSession, ShouldNotBeNil)
+
+ })
+
+ })
+
+}
+
+type listDatabasesCommand struct {
+ Databases []map[string]interface{} `json:"databases"`
+ Ok bool `json:"ok"`
+}
+
+func (self *listDatabasesCommand) AsRunnable() interface{} {
+ return "listDatabases"
+}
diff --git a/src/mongo/gotools/common/db/kerberos/gssapi.go b/src/mongo/gotools/common/db/kerberos/gssapi.go
new file mode 100644
index 00000000000..e9827b04109
--- /dev/null
+++ b/src/mongo/gotools/common/db/kerberos/gssapi.go
@@ -0,0 +1,58 @@
+// Package kerberos implements connection to MongoDB using kerberos.
+package kerberos
+
+// #cgo windows CFLAGS: -Ic:/sasl/include
+// #cgo windows LDFLAGS: -Lc:/sasl/lib
+
+import (
+ "github.com/mongodb/mongo-tools/common/options"
+ "github.com/mongodb/mongo-tools/common/util"
+ "gopkg.in/mgo.v2"
+ "time"
+)
+
+const (
+ KERBEROS_AUTHENTICATION_MECHANISM = "GSSAPI"
+)
+
+type KerberosDBConnector struct {
+ dialInfo *mgo.DialInfo
+}
+
+// Configure the db connector. Parses the connection string and sets up
+// the dial info with the default dial timeout.
+func (self *KerberosDBConnector) Configure(opts options.ToolOptions) error {
+
+ // create the addresses to be used to connect
+ connectionAddrs := util.CreateConnectionAddrs(opts.Host, opts.Port)
+
+ timeout := time.Duration(opts.Timeout) * time.Second
+
+ // set up the dial info
+ self.dialInfo = &mgo.DialInfo{
+ Addrs: connectionAddrs,
+ Timeout: timeout,
+ Direct: opts.Direct,
+ ReplicaSetName: opts.ReplicaSetName,
+
+ // Kerberos principal
+ Username: opts.Auth.Username,
+ // Note: Password is only used on Windows. SASL doesn't allow you to specify
+ // a password, so this field is ignored on Linux and OSX. Run the kinit
+ // command to get a ticket first.
+ Password: opts.Auth.Password,
+ // This should always be '$external', but legacy tools still allow you to
+ // specify a source DB
+ Source: opts.Auth.Source,
+ Service: opts.Kerberos.Service,
+ ServiceHost: opts.Kerberos.ServiceHost,
+ Mechanism: KERBEROS_AUTHENTICATION_MECHANISM,
+ }
+
+ return nil
+}
+
+// Dial the database.
+func (self *KerberosDBConnector) GetNewSession() (*mgo.Session, error) {
+ return mgo.DialWithInfo(self.dialInfo)
+}
diff --git a/src/mongo/gotools/common/db/namespaces.go b/src/mongo/gotools/common/db/namespaces.go
new file mode 100644
index 00000000000..149400543ef
--- /dev/null
+++ b/src/mongo/gotools/common/db/namespaces.go
@@ -0,0 +1,159 @@
+package db
+
+import (
+ "fmt"
+ "github.com/mongodb/mongo-tools/common/bsonutil"
+ "github.com/mongodb/mongo-tools/common/log"
+ "gopkg.in/mgo.v2"
+ "gopkg.in/mgo.v2/bson"
+ "strings"
+)
+
+// IsNoCmd reeturns true if err indicates a query command is not supported,
+// otherwise, returns false.
+func IsNoCmd(err error) bool {
+ e, ok := err.(*mgo.QueryError)
+ return ok && strings.HasPrefix(e.Message, "no such cmd:")
+}
+
+// IsNoCollection returns true if err indicates a query resulted in a "no collection" error
+// otherwise, returns false.
+func IsNoCollection(err error) bool {
+ e, ok := err.(*mgo.QueryError)
+ return ok && e.Message == "no collection"
+}
+
+// buildBsonArray takes a cursor iterator and returns an array of
+// all of its documents as bson.D objects.
+func buildBsonArray(iter *mgo.Iter) ([]bson.D, error) {
+ ret := make([]bson.D, 0, 0)
+ index := new(bson.D)
+ for iter.Next(index) {
+ ret = append(ret, *index)
+ index = new(bson.D)
+ }
+
+ if iter.Err() != nil {
+ return nil, iter.Err()
+ }
+ return ret, nil
+
+}
+
+// GetIndexes returns an iterator to thethe raw index info for a collection by
+// using the listIndexes command if available, or by falling back to querying
+// against system.indexes (pre-3.0 systems). nil is returned if the collection
+// does not exist.
+func GetIndexes(coll *mgo.Collection) (*mgo.Iter, error) {
+ var cmdResult struct {
+ Cursor struct {
+ FirstBatch []bson.Raw `bson:"firstBatch"`
+ NS string
+ Id int64
+ }
+ }
+
+ err := coll.Database.Run(bson.D{{"listIndexes", coll.Name}, {"cursor", bson.M{}}}, &cmdResult)
+ switch {
+ case err == nil:
+ ns := strings.SplitN(cmdResult.Cursor.NS, ".", 2)
+ if len(ns) < 2 {
+ return nil, fmt.Errorf("server returned invalid cursor.ns `%v` on listIndexes for `%v`: %v",
+ cmdResult.Cursor.NS, coll.FullName, err)
+ }
+
+ ses := coll.Database.Session
+ return ses.DB(ns[0]).C(ns[1]).NewIter(ses, cmdResult.Cursor.FirstBatch, cmdResult.Cursor.Id, nil), nil
+ case IsNoCmd(err):
+ log.Logvf(log.DebugLow, "No support for listIndexes command, falling back to querying system.indexes")
+ return getIndexesPre28(coll)
+ case IsNoCollection(err):
+ return nil, nil
+ default:
+ return nil, fmt.Errorf("error running `listIndexes`. Collection: `%v` Err: %v", coll.FullName, err)
+ }
+}
+
+func getIndexesPre28(coll *mgo.Collection) (*mgo.Iter, error) {
+ indexColl := coll.Database.C("system.indexes")
+ iter := indexColl.Find(&bson.M{"ns": coll.FullName}).Iter()
+ return iter, nil
+}
+
+func GetCollections(database *mgo.Database, name string) (*mgo.Iter, bool, error) {
+ var cmdResult struct {
+ Cursor struct {
+ FirstBatch []bson.Raw `bson:"firstBatch"`
+ NS string
+ Id int64
+ }
+ }
+
+ command := bson.D{{"listCollections", 1}, {"cursor", bson.M{}}}
+ if len(name) > 0 {
+ command = bson.D{{"listCollections", 1}, {"filter", bson.M{"name": name}}, {"cursor", bson.M{}}}
+ }
+
+ err := database.Run(command, &cmdResult)
+ switch {
+ case err == nil:
+ ns := strings.SplitN(cmdResult.Cursor.NS, ".", 2)
+ if len(ns) < 2 {
+ return nil, false, fmt.Errorf("server returned invalid cursor.ns `%v` on listCollections for `%v`: %v",
+ cmdResult.Cursor.NS, database.Name, err)
+ }
+
+ return database.Session.DB(ns[0]).C(ns[1]).NewIter(database.Session, cmdResult.Cursor.FirstBatch, cmdResult.Cursor.Id, nil), false, nil
+ case IsNoCmd(err):
+ log.Logvf(log.DebugLow, "No support for listCollections command, falling back to querying system.namespaces")
+ iter, err := getCollectionsPre28(database, name)
+ return iter, true, err
+ default:
+ return nil, false, fmt.Errorf("error running `listCollections`. Database: `%v` Err: %v",
+ database.Name, err)
+ }
+}
+
+func getCollectionsPre28(database *mgo.Database, name string) (*mgo.Iter, error) {
+ indexColl := database.C("system.namespaces")
+ selector := bson.M{}
+ if len(name) > 0 {
+ selector["name"] = database.Name + "." + name
+ }
+ iter := indexColl.Find(selector).Iter()
+ return iter, nil
+}
+
+func GetCollectionOptions(coll *mgo.Collection) (*bson.D, error) {
+ iter, useFullName, err := GetCollections(coll.Database, coll.Name)
+ if err != nil {
+ return nil, err
+ }
+ comparisonName := coll.Name
+ if useFullName {
+ comparisonName = coll.FullName
+ }
+ collInfo := &bson.D{}
+ for iter.Next(collInfo) {
+ name, err := bsonutil.FindValueByKey("name", collInfo)
+ if err != nil {
+ collInfo = nil
+ continue
+ }
+ if nameStr, ok := name.(string); ok {
+ if nameStr == comparisonName {
+ // we've found the collection we're looking for
+ return collInfo, nil
+ }
+ } else {
+ collInfo = nil
+ continue
+ }
+ }
+ err = iter.Err()
+ if err != nil {
+ return nil, err
+ }
+ // The given collection was not found, but no error encountered.
+ return nil, nil
+}
diff --git a/src/mongo/gotools/common/db/openssl/openssl.go b/src/mongo/gotools/common/db/openssl/openssl.go
new file mode 100644
index 00000000000..9b3c50c0e90
--- /dev/null
+++ b/src/mongo/gotools/common/db/openssl/openssl.go
@@ -0,0 +1,168 @@
+// Package openssl implements connection to MongoDB over ssl.
+package openssl
+
+import (
+ "fmt"
+ "net"
+ "time"
+
+ "gopkg.in/mgo.v2"
+
+ "github.com/mongodb/mongo-tools/common/options"
+ "github.com/mongodb/mongo-tools/common/util"
+ "github.com/spacemonkeygo/openssl"
+)
+
+// For connecting to the database over ssl
+type SSLDBConnector struct {
+ dialInfo *mgo.DialInfo
+ dialError error
+ ctx *openssl.Ctx
+}
+
+// Configure the connector to connect to the server over ssl. Parses the
+// connection string, and sets up the correct function to dial the server
+// based on the ssl options passed in.
+func (self *SSLDBConnector) Configure(opts options.ToolOptions) error {
+
+ // create the addresses to be used to connect
+ connectionAddrs := util.CreateConnectionAddrs(opts.Host, opts.Port)
+
+ var err error
+ self.ctx, err = setupCtx(opts)
+ if err != nil {
+ return fmt.Errorf("openssl configuration: %v", err)
+ }
+
+ var flags openssl.DialFlags
+ flags = 0
+ if opts.SSLAllowInvalidCert || opts.SSLAllowInvalidHost || opts.SSLCAFile == "" {
+ flags = openssl.InsecureSkipHostVerification
+ }
+ // create the dialer func that will be used to connect
+ dialer := func(addr *mgo.ServerAddr) (net.Conn, error) {
+ conn, err := openssl.Dial("tcp", addr.String(), self.ctx, flags)
+ self.dialError = err
+ return conn, err
+ }
+
+ timeout := time.Duration(opts.Timeout) * time.Second
+
+ // set up the dial info
+ self.dialInfo = &mgo.DialInfo{
+ Addrs: connectionAddrs,
+ Timeout: timeout,
+ Direct: opts.Direct,
+ ReplicaSetName: opts.ReplicaSetName,
+ DialServer: dialer,
+ Username: opts.Auth.Username,
+ Password: opts.Auth.Password,
+ Source: opts.GetAuthenticationDatabase(),
+ Mechanism: opts.Auth.Mechanism,
+ }
+
+ return nil
+
+}
+
+// Dial the server.
+func (self *SSLDBConnector) GetNewSession() (*mgo.Session, error) {
+ session, err := mgo.DialWithInfo(self.dialInfo)
+ if err != nil && self.dialError != nil {
+ return nil, fmt.Errorf("%v, openssl error: %v", err, self.dialError)
+ }
+ return session, err
+}
+
+// To be handed to mgo.DialInfo for connecting to the server.
+type dialerFunc func(addr *mgo.ServerAddr) (net.Conn, error)
+
+// Handle optionally compiled SSL initialization functions (fips mode set)
+type sslInitializationFunction func(options.ToolOptions) error
+
+var sslInitializationFunctions []sslInitializationFunction
+
+// Creates and configures an openssl.Ctx
+func setupCtx(opts options.ToolOptions) (*openssl.Ctx, error) {
+ var ctx *openssl.Ctx
+ var err error
+
+ for _, sslInitFunc := range sslInitializationFunctions {
+ sslInitFunc(opts)
+ }
+
+ if ctx, err = openssl.NewCtxWithVersion(openssl.AnyVersion); err != nil {
+ return nil, fmt.Errorf("failure creating new openssl context with "+
+ "NewCtxWithVersion(AnyVersion): %v", err)
+ }
+
+ // OpAll - Activate all bug workaround options, to support buggy client SSL's.
+ // NoSSLv2 - Disable SSL v2 support
+ ctx.SetOptions(openssl.OpAll | openssl.NoSSLv2)
+
+ // HIGH - Enable strong ciphers
+ // !EXPORT - Disable export ciphers (40/56 bit)
+ // !aNULL - Disable anonymous auth ciphers
+ // @STRENGTH - Sort ciphers based on strength
+ ctx.SetCipherList("HIGH:!EXPORT:!aNULL@STRENGTH")
+
+ // add the PEM key file with the cert and private key, if specified
+ if opts.SSLPEMKeyFile != "" {
+ if err = ctx.UseCertificateChainFile(opts.SSLPEMKeyFile); err != nil {
+ return nil, fmt.Errorf("UseCertificateChainFile: %v", err)
+ }
+ if opts.SSLPEMKeyPassword != "" {
+ if err = ctx.UsePrivateKeyFileWithPassword(
+ opts.SSLPEMKeyFile, openssl.FiletypePEM, opts.SSLPEMKeyPassword); err != nil {
+ return nil, fmt.Errorf("UsePrivateKeyFile: %v", err)
+ }
+ } else {
+ if err = ctx.UsePrivateKeyFile(opts.SSLPEMKeyFile, openssl.FiletypePEM); err != nil {
+ return nil, fmt.Errorf("UsePrivateKeyFile: %v", err)
+ }
+ }
+ // Verify that the certificate and the key go together.
+ if err = ctx.CheckPrivateKey(); err != nil {
+ return nil, fmt.Errorf("CheckPrivateKey: %v", err)
+ }
+ }
+
+ // If renegotiation is needed, don't return from recv() or send() until it's successful.
+ // Note: this is for blocking sockets only.
+ ctx.SetMode(openssl.AutoRetry)
+
+ // Disable session caching (see SERVER-10261)
+ ctx.SetSessionCacheMode(openssl.SessionCacheOff)
+
+ if opts.SSLCAFile != "" {
+ calist, err := openssl.LoadClientCAFile(opts.SSLCAFile)
+ if err != nil {
+ return nil, fmt.Errorf("LoadClientCAFile: %v", err)
+ }
+ ctx.SetClientCAList(calist)
+
+ if err = ctx.LoadVerifyLocations(opts.SSLCAFile, ""); err != nil {
+ return nil, fmt.Errorf("LoadVerifyLocations: %v", err)
+ }
+
+ var verifyOption openssl.VerifyOptions
+ if opts.SSLAllowInvalidCert {
+ verifyOption = openssl.VerifyNone
+ } else {
+ verifyOption = openssl.VerifyPeer
+ }
+ ctx.SetVerify(verifyOption, nil)
+ }
+
+ if opts.SSLCRLFile != "" {
+ store := ctx.GetCertificateStore()
+ store.SetFlags(openssl.CRLCheck)
+ lookup, err := store.AddLookup(openssl.X509LookupFile())
+ if err != nil {
+ return nil, fmt.Errorf("AddLookup(X509LookupFile()): %v", err)
+ }
+ lookup.LoadCRLFile(opts.SSLCRLFile)
+ }
+
+ return ctx, nil
+}
diff --git a/src/mongo/gotools/common/db/openssl/openssl_fips.go b/src/mongo/gotools/common/db/openssl/openssl_fips.go
new file mode 100644
index 00000000000..2c4705e23ff
--- /dev/null
+++ b/src/mongo/gotools/common/db/openssl/openssl_fips.go
@@ -0,0 +1,15 @@
+// +build ssl
+// +build -darwin
+
+package openssl
+
+import "github.com/spacemonkeygo/openssl"
+
+func init() { sslInitializationFunctions = append(sslInitializationFunctions, SetUpFIPSMode) }
+
+func SetUpFIPSMode(opts *ToolOptions) error {
+ if err := openssl.FIPSModeSet(opts.SSLFipsMode); err != nil {
+ return fmt.Errorf("couldn't set FIPS mode to %v: %v", opts.SSLFipsMode, err)
+ }
+ return nil
+}
diff --git a/src/mongo/gotools/common/db/openssl/testdata/ca.pem b/src/mongo/gotools/common/db/openssl/testdata/ca.pem
new file mode 100644
index 00000000000..b1b6f2628da
--- /dev/null
+++ b/src/mongo/gotools/common/db/openssl/testdata/ca.pem
@@ -0,0 +1,34 @@
+-----BEGIN PRIVATE KEY-----
+MIICdgIBADANBgkqhkiG9w0BAQEFAASCAmAwggJcAgEAAoGBAMbN8D5Au+xWdY+s
+GpUuSFSbHGzYfHmw0yajA9J8PiwDePRMl71OMMsByNsykjzXEr0BBOn4PNO6KW7K
+HdDicRavuC/iFucVpILUiJoLOUCPKb/EyAHUk0r2fdr3Ypd2ZXkD1EXmM9WTQnyW
+PEWqr1T7MmM9PhsD0r8ZbQVu8R49AgMBAAECgYBbC+mguQjXfektOUabV6zsgnUM
+LEElgiPRqAqSFTBr+9MjHwjHO84Ayvpv2MM8dcsxIAxeEr/Yv4NGJ+5rwajESir6
+/7UzqzhXmj6ylqTfbMRJCRsqnwvSfNwpsxtMSYieCxtdYqTLaJLAItBjuZPAYL8W
+9Tf/NMc4AjLLHx7PyQJBAOyOcIS/i23td6ZX+QtppXL1fF/JMiKooE9m/npAT5K/
+hQEaAatdLyQ669id181KY9F0JR1TEbzb0A1yo73soRsCQQDXJSG4ID8lfR9SXnEE
+y/RqYv0eKneER+V7e1Cy7bYHvJxZK0sWXYzIZhTl8PABh3PCoLdxjY0IM7UNWlwU
+dAuHAkAOUaTv9CQ9eDVY5VRW44M3TTLFHYmiXXCuvb5Dqibm7B7h7TASrmZPHB3w
+k8VfUNRv9kbU2pVlSCz0026j7XHnAkEAk/qZP8EGTe3K3mfRCsCSA57EhLwm6phd
+ElrWPcvc2WN0kqyBgAembqwwEZxwKE0XZTYQFw2KhKq0DFQrY3IR/wJAIAnLtabL
+aF819WI/VYlMmwb3GAO2w5KQilGhYl7tv1BghH+Qmg7HZEcIRmSwPKEQveT3YpCH
+nCu38jgPXhhqdg==
+-----END PRIVATE KEY-----
+-----BEGIN CERTIFICATE-----
+MIIC3DCCAkWgAwIBAgIJAKwksc/otf2iMA0GCSqGSIb3DQEBCwUAMIGGMQswCQYD
+VQQGEwJVUzERMA8GA1UECAwITmV3IFlvcmsxFjAUBgNVBAcMDU5ldyBZb3JrIENp
+dHkxHTAbBgNVBAoMFE1vbmdvREIgS2VybmVsIFRvb2xzMRkwFwYDVQQLDBBUb29s
+cyBUZXN0aW5nIENBMRIwEAYDVQQDDAlsb2NhbGhvc3QwHhcNMTUwNjA1MTU1MTQ1
+WhcNMzUwNjA0MTU1MTQ1WjCBhjELMAkGA1UEBhMCVVMxETAPBgNVBAgMCE5ldyBZ
+b3JrMRYwFAYDVQQHDA1OZXcgWW9yayBDaXR5MR0wGwYDVQQKDBRNb25nb0RCIEtl
+cm5lbCBUb29sczEZMBcGA1UECwwQVG9vbHMgVGVzdGluZyBDQTESMBAGA1UEAwwJ
+bG9jYWxob3N0MIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQDGzfA+QLvsVnWP
+rBqVLkhUmxxs2Hx5sNMmowPSfD4sA3j0TJe9TjDLAcjbMpI81xK9AQTp+DzTuilu
+yh3Q4nEWr7gv4hbnFaSC1IiaCzlAjym/xMgB1JNK9n3a92KXdmV5A9RF5jPVk0J8
+ljxFqq9U+zJjPT4bA9K/GW0FbvEePQIDAQABo1AwTjAdBgNVHQ4EFgQU+QOiCHTF
+8At8aMOBvHF6wWZpcZUwHwYDVR0jBBgwFoAU+QOiCHTF8At8aMOBvHF6wWZpcZUw
+DAYDVR0TBAUwAwEB/zANBgkqhkiG9w0BAQsFAAOBgQCbbIAjdV+M8RR3ZF1WMBYD
+8aMr55kgtnCWn4mTCDdombCYgtbaPq5sy8Hb/2wLQ9Zl4UuFL5wKWcx3kOLo3cw/
+boj8jnUDnwrsBd2nN7sYdjF+M7FLp6U1AxrE5ejijtg2KCl+p4b7jJgJBSFIQD45
+7CAJVjIrajY4LlJj3x+caQ==
+-----END CERTIFICATE-----
diff --git a/src/mongo/gotools/common/db/openssl/testdata/server.pem b/src/mongo/gotools/common/db/openssl/testdata/server.pem
new file mode 100644
index 00000000000..d2aaa930ff5
--- /dev/null
+++ b/src/mongo/gotools/common/db/openssl/testdata/server.pem
@@ -0,0 +1,32 @@
+-----BEGIN PRIVATE KEY-----
+MIICeAIBADANBgkqhkiG9w0BAQEFAASCAmIwggJeAgEAAoGBALOkdwU9Qx4FRn+z
+coBkeYYpVRg0pknPMDo4Q50TqZPfVhroTynx2Or+cjl5csd5hMKxWQpdzGq8JzH9
+2BCLcDz/51vG3tPrpLIB50ABqa0wRGGDOO+XN0h+VkdqJvKReWOsNRoMT3s0Lh78
+BqvRUomYXnbc1RBaxwWa+UoLCFgnAgMBAAECgYBd9XmjLeW6//tds5gB+4tsVpYB
+cRhAprOM3/zNXYlmpHu+2x78y1gvoSJRWWplVvPPeT8fIuxWL0844JJwJN5wyCwN
+nnrA28l6+Tcde+NlzCxwED+QDjAH20BRxCs0BLvnx3WAXRDmUbWAjOl/qPn9H6m1
+nmUQ7H/f6dxZ0vVMQQJBAOl3xeVLyZZ828P/p3PvYkaeIxxVK1QDGOWi/3vC0DrY
+WK8xAoopjj0RHHZ1fL5bG31G3OR9Vc/rfk4a5XPIlRECQQDE+teCTiwV5Wwzdpg3
+r440qOLCmpMXwJr/Jlh+C4c8ebnIQ9P5sSe4wQNHyeEZ2t7SGvPfjr7glpPhAkXy
+JTm3AkEAvNPgvVoUy6Bk5xuJRl2hMNiKMUo5ZxOyOVkiJeklHdMJt3h+Q1zk7ENA
+sBbKM/PgQezkj/FHTIl9eJKMbp8W4QJBAL4aXHyslw12wisUrKkpa7PUviwT5BvL
+TYsrZcIXvCeYTr1BAMX8vBopZNIWuoEqY1sgmfZKnFrB1+wTNpAQbxcCQQCHbtvQ
+1U2p5Pz5XYyaoK2OEZhPMuLnOBMpzjSxRLxKyhb4k+ssIA0IeAiT4RIECtHJ8DJX
+4aZK/qg9WmBH+zbO
+-----END PRIVATE KEY-----
+-----BEGIN CERTIFICATE-----
+MIICbzCCAdgCAQEwDQYJKoZIhvcNAQEFBQAwgYYxCzAJBgNVBAYTAlVTMREwDwYD
+VQQIDAhOZXcgWW9yazEWMBQGA1UEBwwNTmV3IFlvcmsgQ2l0eTEdMBsGA1UECgwU
+TW9uZ29EQiBLZXJuZWwgVG9vbHMxGTAXBgNVBAsMEFRvb2xzIFRlc3RpbmcgQ0Ex
+EjAQBgNVBAMMCWxvY2FsaG9zdDAeFw0xNTA2MDUxNTUxNDVaFw0zNTA2MDQxNTUx
+NDVaMHkxCzAJBgNVBAYTAlVTMREwDwYDVQQIDAhOZXcgWW9yazEWMBQGA1UEBwwN
+TmV3IFlvcmsgQ2l0eTEUMBIGA1UECgwLTW9uZ29EQiBJbmMxFTATBgNVBAsMDEtl
+cm5lbCBUb29sczESMBAGA1UEAwwJbG9jYWxob3N0MIGfMA0GCSqGSIb3DQEBAQUA
+A4GNADCBiQKBgQCzpHcFPUMeBUZ/s3KAZHmGKVUYNKZJzzA6OEOdE6mT31Ya6E8p
+8djq/nI5eXLHeYTCsVkKXcxqvCcx/dgQi3A8/+dbxt7T66SyAedAAamtMERhgzjv
+lzdIflZHaibykXljrDUaDE97NC4e/Aar0VKJmF523NUQWscFmvlKCwhYJwIDAQAB
+MA0GCSqGSIb3DQEBBQUAA4GBACJiTnC3nksZsmMyD88+DuV8IA1DHSby4X/qtDYT
+eSuNbxRKnihXkm2KE+MGn7YeKg4a7FaYiH3ejk0ZBlY3TZXK3I1uh/zIhC9aMnSL
+z0z4OLcqp46F8PpYF7ARtXXWQuOEWe6k+VKy5XP1NX60sEJ0KwGBQjUw3Ys41JE8
+iigw
+-----END CERTIFICATE-----
diff --git a/src/mongo/gotools/common/db/read_preferences.go b/src/mongo/gotools/common/db/read_preferences.go
new file mode 100644
index 00000000000..9dec319ca48
--- /dev/null
+++ b/src/mongo/gotools/common/db/read_preferences.go
@@ -0,0 +1,51 @@
+package db
+
+import (
+ "fmt"
+
+ "github.com/mongodb/mongo-tools/common/json"
+ "gopkg.in/mgo.v2"
+ "gopkg.in/mgo.v2/bson"
+)
+
+type readPrefDoc struct {
+ Mode string
+ Tags bson.D
+}
+
+const (
+ WarningNonPrimaryMongosConnection = "Warning: using a non-primary readPreference with a " +
+ "connection to mongos may produce inconsistent duplicates or miss some documents."
+)
+
+func ParseReadPreference(rp string) (mgo.Mode, bson.D, error) {
+ var mode string
+ var tags bson.D
+ if rp == "" {
+ return mgo.Nearest, nil, nil
+ }
+ if rp[0] != '{' {
+ mode = rp
+ } else {
+ var doc readPrefDoc
+ err := json.Unmarshal([]byte(rp), &doc)
+ if err != nil {
+ return 0, nil, fmt.Errorf("invalid --ReadPreferences json object: %v", err)
+ }
+ tags = doc.Tags
+ mode = doc.Mode
+ }
+ switch mode {
+ case "primary":
+ return mgo.Primary, tags, nil
+ case "primaryPreferred":
+ return mgo.PrimaryPreferred, tags, nil
+ case "secondary":
+ return mgo.Secondary, tags, nil
+ case "secondaryPreferred":
+ return mgo.SecondaryPreferred, tags, nil
+ case "nearest":
+ return mgo.Nearest, tags, nil
+ }
+ return 0, nil, fmt.Errorf("invalid readPreference mode '%v'", mode)
+}
diff --git a/src/mongo/gotools/common/db/testdata/testdata.bson b/src/mongo/gotools/common/db/testdata/testdata.bson
new file mode 100644
index 00000000000..5157dc1158f
--- /dev/null
+++ b/src/mongo/gotools/common/db/testdata/testdata.bson
Binary files differ
diff --git a/src/mongo/gotools/common/db/write_concern.go b/src/mongo/gotools/common/db/write_concern.go
new file mode 100644
index 00000000000..0a9a16214c8
--- /dev/null
+++ b/src/mongo/gotools/common/db/write_concern.go
@@ -0,0 +1,123 @@
+package db
+
+import (
+ "fmt"
+ "github.com/mongodb/mongo-tools/common/json"
+ "github.com/mongodb/mongo-tools/common/log"
+ "github.com/mongodb/mongo-tools/common/util"
+ "gopkg.in/mgo.v2"
+ "strconv"
+)
+
+// write concern fields
+const (
+ j = "j"
+ w = "w"
+ fSync = "fsync"
+ wTimeout = "wtimeout"
+)
+
+// constructWCObject takes in a write concern and attempts to construct an
+// mgo.Safe object from it. It returns an error if it is unable to parse the
+// string or if a parsed write concern field value is invalid.
+func constructWCObject(writeConcern string) (sessionSafety *mgo.Safe, err error) {
+ sessionSafety = &mgo.Safe{}
+ defer func() {
+ // If the user passes a w value of 0, we set the session to use the
+ // unacknowledged write concern but only if journal commit acknowledgment,
+ // is not required. If commit acknowledgment is required, it prevails,
+ // and the server will require that mongod acknowledge the write operation
+ if sessionSafety.WMode == "" && sessionSafety.W == 0 && !sessionSafety.J {
+ sessionSafety = nil
+ }
+ }()
+ jsonWriteConcern := map[string]interface{}{}
+
+ if err = json.Unmarshal([]byte(writeConcern), &jsonWriteConcern); err != nil {
+ // if the writeConcern string can not be unmarshaled into JSON, this
+ // allows a default to the old behavior wherein the entire argument
+ // passed in is assigned to the 'w' field - thus allowing users pass
+ // a write concern that looks like: "majority", 0, "4", etc.
+ wValue, err := strconv.Atoi(writeConcern)
+ if err != nil {
+ sessionSafety.WMode = writeConcern
+ } else {
+ sessionSafety.W = wValue
+ if wValue < 0 {
+ return sessionSafety, fmt.Errorf("invalid '%v' argument: %v", w, wValue)
+ }
+ }
+ return sessionSafety, nil
+ }
+
+ if jVal, ok := jsonWriteConcern[j]; ok && util.IsTruthy(jVal) {
+ sessionSafety.J = true
+ }
+
+ if fsyncVal, ok := jsonWriteConcern[fSync]; ok && util.IsTruthy(fsyncVal) {
+ sessionSafety.FSync = true
+ }
+
+ if wtimeout, ok := jsonWriteConcern[wTimeout]; ok {
+ wtimeoutValue, err := util.ToInt(wtimeout)
+ if err != nil {
+ return sessionSafety, fmt.Errorf("invalid '%v' argument: %v", wTimeout, wtimeout)
+ }
+ sessionSafety.WTimeout = wtimeoutValue
+ }
+
+ if wInterface, ok := jsonWriteConcern[w]; ok {
+ wValue, err := util.ToInt(wInterface)
+ if err != nil {
+ // if the argument is neither a string nor int, error out
+ wStrVal, ok := wInterface.(string)
+ if !ok {
+ return sessionSafety, fmt.Errorf("invalid '%v' argument: %v", w, wInterface)
+ }
+ sessionSafety.WMode = wStrVal
+ } else {
+ sessionSafety.W = wValue
+ if wValue < 0 {
+ return sessionSafety, fmt.Errorf("invalid '%v' argument: %v", w, wValue)
+ }
+ }
+ }
+ return sessionSafety, nil
+}
+
+// BuildWriteConcern takes a string and a NodeType indicating the type of node the write concern
+// is intended to be used against, and converts the write concern string argument into an
+// mgo.Safe object that's usable on sessions for that node type.
+func BuildWriteConcern(writeConcern string, nodeType NodeType) (*mgo.Safe, error) {
+ sessionSafety, err := constructWCObject(writeConcern)
+ if err != nil {
+ return nil, err
+ }
+
+ if sessionSafety == nil {
+ log.Logvf(log.DebugLow, "using unacknowledged write concern")
+ return nil, nil
+ }
+
+ // for standalone mongods, set the default write concern to 1
+ if nodeType == Standalone {
+ log.Logvf(log.DebugLow, "standalone server: setting write concern %v to 1", w)
+ sessionSafety.W = 1
+ sessionSafety.WMode = ""
+ }
+
+ var writeConcernStr interface{}
+
+ if sessionSafety.WMode != "" {
+ writeConcernStr = sessionSafety.WMode
+ } else {
+ writeConcernStr = sessionSafety.W
+ }
+ log.Logvf(log.Info, "using write concern: %v='%v', %v=%v, %v=%v, %v=%v",
+ w, writeConcernStr,
+ j, sessionSafety.J,
+ fSync, sessionSafety.FSync,
+ wTimeout, sessionSafety.WTimeout,
+ )
+ return sessionSafety, nil
+}
diff --git a/src/mongo/gotools/common/db/write_concern_test.go b/src/mongo/gotools/common/db/write_concern_test.go
new file mode 100644
index 00000000000..96bd8e0ed89
--- /dev/null
+++ b/src/mongo/gotools/common/db/write_concern_test.go
@@ -0,0 +1,166 @@
+package db
+
+import (
+ . "github.com/smartystreets/goconvey/convey"
+ "testing"
+)
+
+func TestBuildWriteConcern(t *testing.T) {
+ Convey("Given a write concern string value, and a boolean indicating if the "+
+ "write concern is to be used on a replica set, on calling BuildWriteConcern...", t, func() {
+ Convey("no error should be returned if the write concern is valid", func() {
+ writeConcern, err := BuildWriteConcern(`{w:34}`, ReplSet)
+ So(err, ShouldBeNil)
+ So(writeConcern.W, ShouldEqual, 34)
+ writeConcern, err = BuildWriteConcern(`{w:"majority"}`, ReplSet)
+ So(err, ShouldBeNil)
+ So(writeConcern.WMode, ShouldEqual, "majority")
+ writeConcern, err = BuildWriteConcern(`majority`, ReplSet)
+ So(err, ShouldBeNil)
+ So(writeConcern.WMode, ShouldEqual, "majority")
+ writeConcern, err = BuildWriteConcern(`tagset`, ReplSet)
+ So(err, ShouldBeNil)
+ So(writeConcern.WMode, ShouldEqual, "tagset")
+ })
+ Convey("on replica sets, only a write concern of 1 or 0 should be returned", func() {
+ writeConcern, err := BuildWriteConcern(`{w:34}`, Standalone)
+ So(err, ShouldBeNil)
+ So(writeConcern.W, ShouldEqual, 1)
+ writeConcern, err = BuildWriteConcern(`{w:"majority"}`, Standalone)
+ So(err, ShouldBeNil)
+ So(writeConcern.W, ShouldEqual, 1)
+ writeConcern, err = BuildWriteConcern(`tagset`, Standalone)
+ So(err, ShouldBeNil)
+ So(writeConcern.W, ShouldEqual, 1)
+ })
+ Convey("with a w value of 0, without j set, a nil write concern should be returned", func() {
+ writeConcern, err := BuildWriteConcern(`{w:0}`, Standalone)
+ So(err, ShouldBeNil)
+ So(writeConcern, ShouldBeNil)
+ })
+ Convey("with a negative w value, an error should be returned", func() {
+ _, err := BuildWriteConcern(`{w:-1}`, ReplSet)
+ So(err, ShouldNotBeNil)
+ _, err = BuildWriteConcern(`{w:-2}`, ReplSet)
+ So(err, ShouldNotBeNil)
+ })
+ Convey("with a w value of 0, with j set, a non-nil write concern should be returned", func() {
+ writeConcern, err := BuildWriteConcern(`{w:0, j:true}`, Standalone)
+ So(err, ShouldBeNil)
+ So(writeConcern.J, ShouldBeTrue)
+ })
+ })
+}
+
+func TestConstructWCObject(t *testing.T) {
+ Convey("Given a write concern string value, on calling constructWCObject...", t, func() {
+
+ Convey("non-JSON string values should be assigned to the 'WMode' "+
+ "field in their entirety", func() {
+ writeConcernString := "majority"
+ writeConcern, err := constructWCObject(writeConcernString)
+ So(err, ShouldBeNil)
+ So(writeConcern.WMode, ShouldEqual, writeConcernString)
+ })
+
+ Convey("non-JSON int values should be assigned to the 'w' field "+
+ "in their entirety", func() {
+ writeConcernString := `{w: 4}`
+ writeConcern, err := constructWCObject(writeConcernString)
+ So(err, ShouldBeNil)
+ So(writeConcern.W, ShouldEqual, 4)
+ })
+
+ Convey("JSON strings with valid j, wtimeout, fsync and w, should be "+
+ "assigned accordingly", func() {
+ writeConcernString := `{w: 3, j: true, fsync: false, wtimeout: 43}`
+ expectedW := 3
+ expectedWTimeout := 43
+ writeConcern, err := constructWCObject(writeConcernString)
+ So(err, ShouldBeNil)
+ So(writeConcern.W, ShouldEqual, expectedW)
+ So(writeConcern.J, ShouldBeTrue)
+ So(writeConcern.FSync, ShouldBeFalse)
+ So(writeConcern.WTimeout, ShouldEqual, expectedWTimeout)
+ })
+
+ Convey("JSON strings with an argument for j that is not false should set j true", func() {
+ writeConcernString := `{w: 3, j: "rue"}`
+ writeConcern, err := constructWCObject(writeConcernString)
+ So(err, ShouldBeNil)
+ So(writeConcern.W, ShouldEqual, 3)
+ So(writeConcern.J, ShouldBeTrue)
+ })
+
+ Convey("JSON strings with an argument for fsync that is not false should set fsync true", func() {
+ writeConcernString := `{w: 3, fsync: "rue"}`
+ writeConcern, err := constructWCObject(writeConcernString)
+ So(err, ShouldBeNil)
+ So(writeConcern.W, ShouldEqual, 3)
+ So(writeConcern.FSync, ShouldBeTrue)
+ })
+
+ Convey("JSON strings with an invalid wtimeout argument should error out", func() {
+ writeConcernString := `{w: 3, wtimeout: "rue"}`
+ _, err := constructWCObject(writeConcernString)
+ So(err, ShouldNotBeNil)
+ writeConcernString = `{w: 3, wtimeout: "43"}`
+ _, err = constructWCObject(writeConcernString)
+ So(err, ShouldNotBeNil)
+ })
+
+ Convey("JSON strings with any non-false j argument should not error out", func() {
+ writeConcernString := `{w: 3, j: "t"}`
+ writeConcern, err := constructWCObject(writeConcernString)
+ So(err, ShouldBeNil)
+ So(writeConcern.J, ShouldBeTrue)
+ writeConcernString = `{w: 3, j: "f"}`
+ writeConcern, err = constructWCObject(writeConcernString)
+ So(err, ShouldBeNil)
+ So(writeConcern.J, ShouldBeTrue)
+ writeConcernString = `{w: 3, j: false}`
+ writeConcern, err = constructWCObject(writeConcernString)
+ So(err, ShouldBeNil)
+ So(writeConcern.J, ShouldBeFalse)
+ writeConcernString = `{w: 3, j: 0}`
+ writeConcern, err = constructWCObject(writeConcernString)
+ So(err, ShouldBeNil)
+ So(writeConcern.J, ShouldBeFalse)
+ })
+
+ Convey("JSON strings with a shorthand fsync argument should not error out", func() {
+ writeConcernString := `{w: 3, fsync: "t"}`
+ writeConcern, err := constructWCObject(writeConcernString)
+ So(err, ShouldBeNil)
+ So(writeConcern.FSync, ShouldBeTrue)
+ writeConcernString = `{w: "3", fsync: "f"}`
+ writeConcern, err = constructWCObject(writeConcernString)
+ So(err, ShouldBeNil)
+ So(writeConcern.FSync, ShouldBeTrue)
+ writeConcernString = `{w: "3", fsync: false}`
+ writeConcern, err = constructWCObject(writeConcernString)
+ So(err, ShouldBeNil)
+ So(writeConcern.FSync, ShouldBeFalse)
+ writeConcernString = `{w: "3", fsync: 0}`
+ writeConcern, err = constructWCObject(writeConcernString)
+ So(err, ShouldBeNil)
+ So(writeConcern.FSync, ShouldBeFalse)
+ })
+
+ Convey("Unacknowledge write concern strings should return a nil object "+
+ "if journaling is not required", func() {
+ writeConcernString := `{w: 0}`
+ writeConcern, err := constructWCObject(writeConcernString)
+ So(err, ShouldBeNil)
+ So(writeConcern, ShouldBeNil)
+ writeConcernString = `{w: 0}`
+ writeConcern, err = constructWCObject(writeConcernString)
+ So(err, ShouldBeNil)
+ So(writeConcern, ShouldBeNil)
+ writeConcernString = `0`
+ writeConcern, err = constructWCObject(writeConcernString)
+ So(err, ShouldBeNil)
+ So(writeConcern, ShouldBeNil)
+ })
+ })
+}