diff options
Diffstat (limited to 'src/mongo/gotools/common')
22 files changed, 348 insertions, 233 deletions
diff --git a/src/mongo/gotools/common/db/buffered_bulk_test.go b/src/mongo/gotools/common/db/buffered_bulk_test.go index d4664dadd20..304bd2cc1e1 100644 --- a/src/mongo/gotools/common/db/buffered_bulk_test.go +++ b/src/mongo/gotools/common/db/buffered_bulk_test.go @@ -22,6 +22,8 @@ func TestBufferedBulkInserterInserts(t *testing.T) { Auth: &options.Auth{}, } provider, err := NewSessionProvider(opts) + So(provider, ShouldNotBeNil) + So(err, ShouldBeNil) session, err := provider.GetSession() So(session, ShouldNotBeNil) So(err, ShouldBeNil) @@ -102,6 +104,8 @@ func TestBufferedBulkInserterInserts(t *testing.T) { Reset(func() { session.DB("tools-test").DropDatabase() + session.Close() + provider.Close() }) }) diff --git a/src/mongo/gotools/common/db/connector.go b/src/mongo/gotools/common/db/connector.go index 85e8d6e653c..c3e50ddfa5d 100644 --- a/src/mongo/gotools/common/db/connector.go +++ b/src/mongo/gotools/common/db/connector.go @@ -3,6 +3,7 @@ package db import ( "time" + "github.com/mongodb/mongo-tools/common/db/kerberos" "github.com/mongodb/mongo-tools/common/options" "github.com/mongodb/mongo-tools/common/util" "gopkg.in/mgo.v2" @@ -42,6 +43,7 @@ func (self *VanillaDBConnector) Configure(opts options.ToolOptions) error { Source: opts.GetAuthenticationDatabase(), Mechanism: opts.Auth.Mechanism, } + kerberos.AddKerberosOpts(opts, self.dialInfo) return nil } diff --git a/src/mongo/gotools/common/db/connector_sasl_test.go b/src/mongo/gotools/common/db/connector_sasl_test.go index c585c92f842..35f7090ce6b 100644 --- a/src/mongo/gotools/common/db/connector_sasl_test.go +++ b/src/mongo/gotools/common/db/connector_sasl_test.go @@ -6,14 +6,14 @@ package db import ( "fmt" - "github.com/mongodb/mongo-tools/common/db/kerberos" + "os" + "runtime" + "testing" + "github.com/mongodb/mongo-tools/common/options" "github.com/mongodb/mongo-tools/common/testutil" . "github.com/smartystreets/goconvey/convey" "gopkg.in/mgo.v2/bson" - "os" - "runtime" - "testing" ) var ( @@ -21,9 +21,9 @@ var ( KERBEROS_USER = "drivers@LDAPTEST.10GEN.CC" ) -func TestKerberosDBConnector(t *testing.T) { +func TestKerberosAuthMechanism(t *testing.T) { Convey("should be able to successfully connect", t, func() { - connector := &kerberos.KerberosDBConnector{} + connector := &VanillaDBConnector{} opts := options.ToolOptions{ Connection: &options.Connection{ @@ -31,7 +31,8 @@ func TestKerberosDBConnector(t *testing.T) { Port: "27017", }, Auth: &options.Auth{ - Username: KERBEROS_USER, + Username: KERBEROS_USER, + Mechanism: "GSSAPI", }, Kerberos: &options.Kerberos{ Service: "mongodb", diff --git a/src/mongo/gotools/common/db/db.go b/src/mongo/gotools/common/db/db.go index a3207c4b467..5372cb6981f 100644 --- a/src/mongo/gotools/common/db/db.go +++ b/src/mongo/gotools/common/db/db.go @@ -116,6 +116,13 @@ func (self *SessionProvider) GetSession() (*mgo.Session, error) { return self.masterSession.Copy(), nil } +// Close closes the master session in the connection pool +func (self *SessionProvider) Close() { + self.masterSessionLock.Lock() + defer self.masterSessionLock.Unlock() + self.masterSession.Close() +} + // refresh is a helper for modifying the session based on the // session provider flags passed in with SetFlags. // This helper assumes a lock is already taken. diff --git a/src/mongo/gotools/common/db/db_gssapi.go b/src/mongo/gotools/common/db/db_gssapi.go deleted file mode 100644 index 656e81987a9..00000000000 --- a/src/mongo/gotools/common/db/db_gssapi.go +++ /dev/null @@ -1,20 +0,0 @@ -// +build sasl - -package db - -import ( - "github.com/mongodb/mongo-tools/common/db/kerberos" - "github.com/mongodb/mongo-tools/common/options" -) - -func init() { - GetConnectorFuncs = append(GetConnectorFuncs, getGSSAPIConnector) -} - -// return the Kerberos DB connector if using SSL, otherwise return nil. -func getGSSAPIConnector(opts options.ToolOptions) DBConnector { - if opts.Auth.Mechanism == "GSSAPI" { - return &kerberos.KerberosDBConnector{} - } - return nil -} diff --git a/src/mongo/gotools/common/db/db_test.go b/src/mongo/gotools/common/db/db_test.go index 59d1c53b929..b266d8619d8 100644 --- a/src/mongo/gotools/common/db/db_test.go +++ b/src/mongo/gotools/common/db/db_test.go @@ -45,7 +45,14 @@ func TestNewSessionProvider(t *testing.T) { session, err := provider.GetSession() So(err, ShouldBeNil) So(session, ShouldNotBeNil) + session.Close() So(provider.masterSession, ShouldNotBeNil) + err = provider.masterSession.Ping() + So(err, ShouldBeNil) + provider.Close() + So(func() { + provider.masterSession.Ping() + }, ShouldPanic) }) diff --git a/src/mongo/gotools/common/db/kerberos/gssapi.go b/src/mongo/gotools/common/db/kerberos/gssapi.go index e9827b04109..c2b93ef6fc9 100644 --- a/src/mongo/gotools/common/db/kerberos/gssapi.go +++ b/src/mongo/gotools/common/db/kerberos/gssapi.go @@ -1,4 +1,4 @@ -// Package kerberos implements connection to MongoDB using kerberos. +// Package kerberos implements authentication to MongoDB using kerberos package kerberos // #cgo windows CFLAGS: -Ic:/sasl/include @@ -6,53 +6,24 @@ package kerberos import ( "github.com/mongodb/mongo-tools/common/options" - "github.com/mongodb/mongo-tools/common/util" "gopkg.in/mgo.v2" - "time" ) -const ( - KERBEROS_AUTHENTICATION_MECHANISM = "GSSAPI" -) - -type KerberosDBConnector struct { - dialInfo *mgo.DialInfo -} - -// Configure the db connector. Parses the connection string and sets up -// the dial info with the default dial timeout. -func (self *KerberosDBConnector) Configure(opts options.ToolOptions) error { - - // create the addresses to be used to connect - connectionAddrs := util.CreateConnectionAddrs(opts.Host, opts.Port) - - timeout := time.Duration(opts.Timeout) * time.Second +const authMechanism = "GSSAPI" - // set up the dial info - self.dialInfo = &mgo.DialInfo{ - Addrs: connectionAddrs, - Timeout: timeout, - Direct: opts.Direct, - ReplicaSetName: opts.ReplicaSetName, - - // Kerberos principal - Username: opts.Auth.Username, - // Note: Password is only used on Windows. SASL doesn't allow you to specify - // a password, so this field is ignored on Linux and OSX. Run the kinit - // command to get a ticket first. - Password: opts.Auth.Password, - // This should always be '$external', but legacy tools still allow you to - // specify a source DB - Source: opts.Auth.Source, - Service: opts.Kerberos.Service, - ServiceHost: opts.Kerberos.ServiceHost, - Mechanism: KERBEROS_AUTHENTICATION_MECHANISM, +func AddKerberosOpts(opts options.ToolOptions, dialInfo *mgo.DialInfo) { + if dialInfo == nil { + return } - - return nil -} - -// Dial the database. -func (self *KerberosDBConnector) GetNewSession() (*mgo.Session, error) { - return mgo.DialWithInfo(self.dialInfo) + if opts.Kerberos == nil || opts.Kerberos.Service == "" || + opts.Kerberos.ServiceHost == "" { + return + } + if opts.Auth == nil || (opts.Auth.Mechanism != authMechanism && + dialInfo.Mechanism != authMechanism) { + return + } + dialInfo.Service = opts.Kerberos.Service + dialInfo.ServiceHost = opts.Kerberos.ServiceHost + dialInfo.Mechanism = authMechanism } diff --git a/src/mongo/gotools/common/db/openssl/openssl.go b/src/mongo/gotools/common/db/openssl/openssl.go index 9b3c50c0e90..1cc4c2ccd1f 100644 --- a/src/mongo/gotools/common/db/openssl/openssl.go +++ b/src/mongo/gotools/common/db/openssl/openssl.go @@ -6,11 +6,11 @@ import ( "net" "time" - "gopkg.in/mgo.v2" - + "github.com/mongodb/mongo-tools/common/db/kerberos" "github.com/mongodb/mongo-tools/common/options" "github.com/mongodb/mongo-tools/common/util" "github.com/spacemonkeygo/openssl" + "gopkg.in/mgo.v2" ) // For connecting to the database over ssl @@ -24,7 +24,6 @@ type SSLDBConnector struct { // connection string, and sets up the correct function to dial the server // based on the ssl options passed in. func (self *SSLDBConnector) Configure(opts options.ToolOptions) error { - // create the addresses to be used to connect connectionAddrs := util.CreateConnectionAddrs(opts.Host, opts.Port) @@ -36,7 +35,7 @@ func (self *SSLDBConnector) Configure(opts options.ToolOptions) error { var flags openssl.DialFlags flags = 0 - if opts.SSLAllowInvalidCert || opts.SSLAllowInvalidHost || opts.SSLCAFile == "" { + if opts.SSLAllowInvalidCert || opts.SSLAllowInvalidHost { flags = openssl.InsecureSkipHostVerification } // create the dialer func that will be used to connect @@ -60,7 +59,7 @@ func (self *SSLDBConnector) Configure(opts options.ToolOptions) error { Source: opts.GetAuthenticationDatabase(), Mechanism: opts.Auth.Mechanism, } - + kerberos.AddKerberosOpts(opts, self.dialInfo) return nil } @@ -140,20 +139,24 @@ func setupCtx(opts options.ToolOptions) (*openssl.Ctx, error) { return nil, fmt.Errorf("LoadClientCAFile: %v", err) } ctx.SetClientCAList(calist) - if err = ctx.LoadVerifyLocations(opts.SSLCAFile, ""); err != nil { return nil, fmt.Errorf("LoadVerifyLocations: %v", err) } - - var verifyOption openssl.VerifyOptions - if opts.SSLAllowInvalidCert { - verifyOption = openssl.VerifyNone - } else { - verifyOption = openssl.VerifyPeer + } else { + err = ctx.SetupSystemCA() + if err != nil { + return nil, fmt.Errorf("Error setting up system certificate authority: %v", err) } - ctx.SetVerify(verifyOption, nil) } + var verifyOption openssl.VerifyOptions + if opts.SSLAllowInvalidCert { + verifyOption = openssl.VerifyNone + } else { + verifyOption = openssl.VerifyPeer + } + ctx.SetVerify(verifyOption, nil) + if opts.SSLCRLFile != "" { store := ctx.GetCertificateStore() store.SetFlags(openssl.CRLCheck) diff --git a/src/mongo/gotools/common/failpoint/failpoint.go b/src/mongo/gotools/common/failpoint/failpoint.go new file mode 100644 index 00000000000..ae8a99b7395 --- /dev/null +++ b/src/mongo/gotools/common/failpoint/failpoint.go @@ -0,0 +1,41 @@ +// +build failpoints + +// Package failpoint implements triggers for custom debugging behavoir +package failpoint + +import ( + "strings" +) + +var values map[string]string + +func init() { + values = make(map[string]string) +} + +// ParseFailpoints registers a comma-separated list of failpoint=value pairs +func ParseFailpoints(arg string) { + args := strings.Split(arg, ",") + for _, fp := range args { + if sep := strings.Index(fp, "="); sep != -1 { + key := fp[:sep] + val := fp[sep+1:] + values[key] = val + continue + } + values[fp] = "" + } +} + +// Get returns the value of the given failpoint and true, if it exists, and +// false otherwise +func Get(fp string) (string, bool) { + val, ok := values[fp] + return val, ok +} + +// Enabled returns true iff the given failpoint has been turned on +func Enabled(fp string) bool { + _, ok := Get(fp) + return ok +} diff --git a/src/mongo/gotools/common/failpoint/failpoint_disabled.go b/src/mongo/gotools/common/failpoint/failpoint_disabled.go new file mode 100644 index 00000000000..7fb3aa2a27d --- /dev/null +++ b/src/mongo/gotools/common/failpoint/failpoint_disabled.go @@ -0,0 +1,14 @@ +// +build !failpoints + +package failpoint + +func ParseFailpoints(_ string) { +} + +func Get(fp string) (string, bool) { + return "", false +} + +func Enabled(fp string) bool { + return false +} diff --git a/src/mongo/gotools/common/failpoint/failpoint_test.go b/src/mongo/gotools/common/failpoint/failpoint_test.go new file mode 100644 index 00000000000..8154e348dde --- /dev/null +++ b/src/mongo/gotools/common/failpoint/failpoint_test.go @@ -0,0 +1,42 @@ +// +build failpoints + +package failpoint + +import ( + "testing" + + "github.com/mongodb/mongo-tools/common/testutil" + . "github.com/smartystreets/goconvey/convey" +) + +func TestFailpointParsing(t *testing.T) { + testutil.VerifyTestType(t, testutil.UnitTestType) + + Convey("With test args", t, func() { + args := "foo=bar,baz,biz=,=a" + ParseFailpoints(args) + + So(Enabled("foo"), ShouldBeTrue) + So(Enabled("baz"), ShouldBeTrue) + So(Enabled("biz"), ShouldBeTrue) + So(Enabled(""), ShouldBeTrue) + So(Enabled("bar"), ShouldBeFalse) + + var val string + var ok bool + val, ok = Get("foo") + So(val, ShouldEqual, "bar") + So(ok, ShouldBeTrue) + val, ok = Get("baz") + So(val, ShouldEqual, "") + So(ok, ShouldBeTrue) + val, ok = Get("biz") + So(val, ShouldEqual, "") + So(ok, ShouldBeTrue) + val, ok = Get("") + So(val, ShouldEqual, "a") + So(ok, ShouldBeTrue) + val, ok = Get("bar") + So(ok, ShouldBeFalse) + }) +} diff --git a/src/mongo/gotools/common/failpoint/failpoints.go b/src/mongo/gotools/common/failpoint/failpoints.go new file mode 100644 index 00000000000..dbf78c776cd --- /dev/null +++ b/src/mongo/gotools/common/failpoint/failpoints.go @@ -0,0 +1,6 @@ +package failpoint + +// Supported failpoint names +const ( + PauseBeforeDumping = "PauseBeforeDumping" +) diff --git a/src/mongo/gotools/common/intents/intent.go b/src/mongo/gotools/common/intents/intent.go index 42999806744..8f317a3716f 100644 --- a/src/mongo/gotools/common/intents/intent.go +++ b/src/mongo/gotools/common/intents/intent.go @@ -109,6 +109,14 @@ func (intent *Intent) IsSpecialCollection() bool { return intent.IsSystemIndexes() || intent.IsUsers() || intent.IsRoles() || intent.IsAuthVersion() } +func (it *Intent) IsView() bool { + if it.Options == nil { + return false + } + _, isView := it.Options.Map()["viewOn"] + return isView +} + func (existing *Intent) MergeIntent(intent *Intent) { // merge new intent into old intent if existing.BSONFile == nil { diff --git a/src/mongo/gotools/common/options/options.go b/src/mongo/gotools/common/options/options.go index 8962be7f8d7..5e52c733a7f 100644 --- a/src/mongo/gotools/common/options/options.go +++ b/src/mongo/gotools/common/options/options.go @@ -4,6 +4,7 @@ package options import ( "github.com/jessevdk/go-flags" + "github.com/mongodb/mongo-tools/common/failpoint" "github.com/mongodb/mongo-tools/common/log" "fmt" @@ -64,7 +65,8 @@ type General struct { Help bool `long:"help" description:"print usage"` Version bool `long:"version" description:"print the tool version and exit"` - MaxProcs int `long:"numThreads" default:"0" hidden:"true"` + MaxProcs int `long:"numThreads" default:"0" hidden:"true"` + Failpoints string `long:"failpoints" hidden:"true"` } // Struct holding verbosity-related options @@ -177,6 +179,9 @@ func New(appName, usageStr string, enabled EnabledOptions) *ToolOptions { panic(fmt.Errorf("couldn't register verbosity options: %v", err)) } + // this call disables failpoints if compiled without failpoint support + EnableFailpoints(opts) + if enabled.Connection { if _, err := opts.parser.AddGroup("connection options", "", opts.Connection); err != nil { panic(fmt.Errorf("couldn't register connection options: %v", err)) @@ -298,7 +303,9 @@ func (o *ToolOptions) AddOptions(opts ExtraOptions) { // Parse the command line args. Returns any extra args not accounted for by // parsing, as well as an error if the parsing returns an error. func (o *ToolOptions) Parse() ([]string, error) { - return o.parser.Parse() + args, err := o.parser.Parse() + failpoint.ParseFailpoints(o.Failpoints) + return args, err } func (opts *ToolOptions) handleUnknownOption(option string, arg flags.SplitArgument, args []string) ([]string, error) { diff --git a/src/mongo/gotools/common/options/options_fp.go b/src/mongo/gotools/common/options/options_fp.go new file mode 100644 index 00000000000..f0254646708 --- /dev/null +++ b/src/mongo/gotools/common/options/options_fp.go @@ -0,0 +1,7 @@ +// +build failpoints + +package options + +// EnableFailpoints does nothing if we've compiled with failpoints enabled +func EnableFailpoints(opts *ToolOptions) { +} diff --git a/src/mongo/gotools/common/options/options_fp_disabled.go b/src/mongo/gotools/common/options/options_fp_disabled.go new file mode 100644 index 00000000000..3411cd8ad33 --- /dev/null +++ b/src/mongo/gotools/common/options/options_fp_disabled.go @@ -0,0 +1,9 @@ +// +build !failpoints + +package options + +// EnableFailpoints removes the failpoints options +func EnableFailpoints(opts *ToolOptions) { + opt := opts.FindOptionByLongName("failpoints") + opt.LongName = "" +} diff --git a/src/mongo/gotools/common/progress/manager.go b/src/mongo/gotools/common/progress/manager.go index e1c5f0db5f7..7322c0d45a3 100644 --- a/src/mongo/gotools/common/progress/manager.go +++ b/src/mongo/gotools/common/progress/manager.go @@ -2,70 +2,92 @@ package progress import ( "fmt" - "github.com/mongodb/mongo-tools/common/text" "io" "sync" "time" + + "github.com/mongodb/mongo-tools/common/text" ) +// Manager is an interface which tools can use to registers progressors which +// track the progress of any arbitrary operation. +type Manager interface { + // Attach registers the progressor with the manager under the given name. + // Any call to Attach must have a matching call to Detach. + Attach(name string, progressor Progressor) + + // Detach removes the progressor with the given name from the manager + Detach(name string) +} + const GridPadding = 2 -// Manager handles thread-safe synchronized progress bar writing, so that all -// given progress bars are written in a group at a given interval. -// The current implementation maintains insert order when printing, -// such that new bars appear at the bottom of the group. -type Manager struct { - waitTime time.Duration - writer io.Writer - bars []*Bar - barsLock *sync.Mutex - stopChan chan struct{} +// BarWriter implements Manager. It periodically prints the status of all of its +// progressors in the form of pretty progress bars. It handles thread-safe +// synchronized progress bar writing, so that its progressors are written in a +// group at a given interval. It maintains insertion order when printing, such +// that new bars appear at the bottom of the group. +type BarWriter struct { + sync.Mutex + + waitTime time.Duration + writer io.Writer + bars []*Bar + stopChan chan struct{} + barLength int + isBytes bool } -// NewProgressBarManager returns an initialized Manager with the given -// time.Duration to wait between writes -func NewProgressBarManager(w io.Writer, waitTime time.Duration) *Manager { - return &Manager{ - waitTime: waitTime, - writer: w, - barsLock: &sync.Mutex{}, - stopChan: make(chan struct{}), +// NewBarWriter returns an initialized BarWriter with the given bar length and +// byte-formatting toggle, waiting the given duration between writes +func NewBarWriter(w io.Writer, waitTime time.Duration, barLength int, isBytes bool) *BarWriter { + return &BarWriter{ + waitTime: waitTime, + writer: w, + stopChan: make(chan struct{}), + barLength: barLength, + isBytes: isBytes, } } -// Attach registers the given progress bar with the manager. Should be used as -// myManager.Attach(myBar) -// defer myManager.Detach(myBar) -func (manager *Manager) Attach(pb *Bar) { - // first some quick error checks - if pb.Name == "" { - panic("cannot attach a nameless bar to a progress bar manager") +// Attach registers the given progressor with the manager +func (manager *BarWriter) Attach(name string, progressor Progressor) { + pb := &Bar{ + Name: name, + Watching: progressor, + BarLength: manager.barLength, + IsBytes: manager.isBytes, } pb.validate() - manager.barsLock.Lock() - defer manager.barsLock.Unlock() + manager.Lock() + defer manager.Unlock() // make sure we are not adding the same bar again for _, bar := range manager.bars { - if bar.Name == pb.Name { - panic(fmt.Sprintf("progress bar with name '%v' already exists in manager", pb.Name)) + if bar.Name == name { + panic(fmt.Sprintf("progress bar with name '%s' already exists in manager", name)) } } manager.bars = append(manager.bars, pb) } -// Detach removes the given progress bar from the manager. -// Insert order is maintained for consistent ordering of the printed bars. -// Note: the manager removes progress bars by "Name" not by memory location -func (manager *Manager) Detach(pb *Bar) { - if pb.Name == "" { - panic("cannot detach a nameless bar from a progress bar manager") +// Detach removes the progressor with the given name from the manager. Insert +// order is maintained for consistent ordering of the printed bars. +func (manager *BarWriter) Detach(name string) { + manager.Lock() + defer manager.Unlock() + var pb *Bar + for _, bar := range manager.bars { + if bar.Name == name { + pb = bar + break + } + } + if pb == nil { + panic("could not find progressor") } - - manager.barsLock.Lock() - defer manager.barsLock.Unlock() grid := &text.GridWriter{ ColumnPadding: GridPadding, @@ -88,9 +110,9 @@ func (manager *Manager) Detach(pb *Bar) { } // helper to render all bars in order -func (manager *Manager) renderAllBars() { - manager.barsLock.Lock() - defer manager.barsLock.Unlock() +func (manager *BarWriter) renderAllBars() { + manager.Lock() + defer manager.Unlock() grid := &text.GridWriter{ ColumnPadding: GridPadding, } @@ -107,14 +129,14 @@ func (manager *Manager) renderAllBars() { } // Start kicks of the timed batch writing of progress bars. -func (manager *Manager) Start() { +func (manager *BarWriter) Start() { if manager.writer == nil { - panic("Cannot use a progress.Manager with an unset Writer") + panic("Cannot use a progress.BarWriter with an unset Writer") } go manager.start() } -func (manager *Manager) start() { +func (manager *BarWriter) start() { if manager.waitTime <= 0 { manager.waitTime = DefaultWaitTime } @@ -133,6 +155,6 @@ func (manager *Manager) start() { // Stop ends the main manager goroutine, stopping the manager's bars // from being rendered. -func (manager *Manager) Stop() { +func (manager *BarWriter) Stop() { manager.stopChan <- struct{}{} } diff --git a/src/mongo/gotools/common/progress/manager_test.go b/src/mongo/gotools/common/progress/manager_test.go index e997404e886..b881f039169 100644 --- a/src/mongo/gotools/common/progress/manager_test.go +++ b/src/mongo/gotools/common/progress/manager_test.go @@ -35,33 +35,18 @@ func (b *safeBuffer) Reset() { func TestManagerAttachAndDetach(t *testing.T) { writeBuffer := new(safeBuffer) - var manager *Manager + var manager *BarWriter - Convey("With an empty progress.Manager", t, func() { - manager = NewProgressBarManager(writeBuffer, time.Second) + Convey("With an empty progress.BarWriter", t, func() { + manager = NewBarWriter(writeBuffer, time.Second, 10, false) So(manager, ShouldNotBeNil) Convey("adding 3 bars", func() { - watching := NewCounter(10) - watching.Inc(5) - pbar1 := &Bar{ - Name: "\nTEST1", - Watching: watching, - BarLength: 10, - } - manager.Attach(pbar1) - pbar2 := &Bar{ - Name: "\nTEST2", - Watching: watching, - BarLength: 10, - } - manager.Attach(pbar2) - pbar3 := &Bar{ - Name: "\nTEST3", - Watching: watching, - BarLength: 10, - } - manager.Attach(pbar3) + progressor := NewCounter(10) + progressor.Inc(5) + manager.Attach("TEST1", progressor) + manager.Attach("TEST2", progressor) + manager.Attach("TEST3", progressor) So(len(manager.bars), ShouldEqual, 3) @@ -74,7 +59,7 @@ func TestManagerAttachAndDetach(t *testing.T) { }) Convey("detaching the second bar", func() { - manager.Detach(pbar2) + manager.Detach("TEST2") So(len(manager.bars), ShouldEqual, 2) Convey("should print 1,3", func() { @@ -91,13 +76,7 @@ func TestManagerAttachAndDetach(t *testing.T) { }) Convey("but adding a new bar should print 1,2,4", func() { - watching := NewCounter(10) - pbar4 := &Bar{ - Name: "\nTEST4", - Watching: watching, - BarLength: 10, - } - manager.Attach(pbar4) + manager.Attach("TEST4", progressor) So(len(manager.bars), ShouldEqual, 3) manager.renderAllBars() @@ -127,19 +106,14 @@ func TestManagerAttachAndDetach(t *testing.T) { func TestManagerStartAndStop(t *testing.T) { writeBuffer := new(safeBuffer) - var manager *Manager + var manager *BarWriter - Convey("With a progress.Manager with a waitTime of 10 ms and one bar", t, func() { - manager = NewProgressBarManager(writeBuffer, time.Millisecond*10) + Convey("With a progress.BarWriter with a waitTime of 10 ms and one bar", t, func() { + manager = NewBarWriter(writeBuffer, time.Millisecond*10, 10, false) So(manager, ShouldNotBeNil) watching := NewCounter(10) watching.Inc(5) - pbar := &Bar{ - Name: "\nTEST", - Watching: watching, - BarLength: 10, - } - manager.Attach(pbar) + manager.Attach("TEST", watching) So(manager.waitTime, ShouldEqual, time.Millisecond*10) So(len(manager.bars), ShouldEqual, 1) @@ -164,13 +138,13 @@ func TestManagerStartAndStop(t *testing.T) { func TestNumberOfWrites(t *testing.T) { var cw *CountWriter - var manager *Manager + var manager *BarWriter Convey("With a test manager and counting writer", t, func() { cw = new(CountWriter) - manager = NewProgressBarManager(cw, time.Millisecond*10) + manager = NewBarWriter(cw, time.Millisecond*10, 10, false) So(manager, ShouldNotBeNil) - manager.Attach(&Bar{Name: "1", Watching: NewCounter(10), BarLength: 10}) + manager.Attach("1", NewCounter(10)) Convey("with one attached bar", func() { So(len(manager.bars), ShouldEqual, 1) @@ -182,7 +156,7 @@ func TestNumberOfWrites(t *testing.T) { }) Convey("with two bars attached", func() { - manager.Attach(&Bar{Name: "2", Watching: NewCounter(10), BarLength: 10}) + manager.Attach("2", NewCounter(10)) So(len(manager.bars), ShouldEqual, 2) Convey("three writes should be made per render, since an empty write is added", func() { @@ -193,7 +167,7 @@ func TestNumberOfWrites(t *testing.T) { Convey("with 57 bars attached", func() { for i := 2; i <= 57; i++ { - manager.Attach(&Bar{Name: strconv.Itoa(i), Watching: NewCounter(10), BarLength: 10}) + manager.Attach(strconv.Itoa(i), NewCounter(10)) } So(len(manager.bars), ShouldEqual, 57) diff --git a/src/mongo/gotools/common/progress/progress_bar.go b/src/mongo/gotools/common/progress/progress_bar.go index 3b196b7f471..7727db4ea3e 100644 --- a/src/mongo/gotools/common/progress/progress_bar.go +++ b/src/mongo/gotools/common/progress/progress_bar.go @@ -4,10 +4,10 @@ package progress import ( "bytes" "fmt" - "github.com/mongodb/mongo-tools/common/text" "io" - "sync/atomic" "time" + + "github.com/mongodb/mongo-tools/common/text" ) const ( @@ -18,49 +18,6 @@ const ( BarRight = "]" ) -// countProgressor is an implementation of Progressor that uses -type countProgressor struct { - max int64 - current int64 -} - -func (c *countProgressor) Progress() (int64, int64) { - current := atomic.LoadInt64(&c.current) - return c.max, current -} - -func (c *countProgressor) Inc(amount int64) { - atomic.AddInt64(&c.current, amount) -} - -func (c *countProgressor) Set(amount int64) { - atomic.StoreInt64(&c.current, amount) -} - -func NewCounter(max int64) *countProgressor { - return &countProgressor{max, 0} -} - -// Progressor can be implemented to allow an object to hook up to a progress.Bar. -type Progressor interface { - // Progress returns a pair of integers: the total amount to reach 100%, and - // the amount completed. This method is called by progress.Bar to - // determine what percentage to display. - Progress() (int64, int64) -} - -// Updateable is a Progressor which also exposes the ability for the progressing -// value to be incremented, or reset. -type Updateable interface { - // Inc increments the current progress counter by the given amount. - Inc(amount int64) - - // Set resets the progress counter to the given amount. - Set(amount int64) - - Progressor -} - // Bar is a tool for concurrently monitoring the progress // of a task with a simple linear ASCII visualization type Bar struct { @@ -130,7 +87,7 @@ func (pb *Bar) Stop() { } func (pb *Bar) formatCounts() (string, string) { - maxCount, currentCount := pb.Watching.Progress() + currentCount, maxCount := pb.Watching.Progress() if pb.IsBytes { return text.FormatByteAmount(maxCount), text.FormatByteAmount(currentCount) } @@ -140,7 +97,7 @@ func (pb *Bar) formatCounts() (string, string) { // computes all necessary values renders to the bar's Writer func (pb *Bar) renderToWriter() { pb.hasRendered = true - maxCount, currentCount := pb.Watching.Progress() + currentCount, maxCount := pb.Watching.Progress() maxStr, currentStr := pb.formatCounts() if maxCount == 0 { // if we have no max amount, just print a count @@ -160,7 +117,7 @@ func (pb *Bar) renderToWriter() { func (pb *Bar) renderToGridRow(grid *text.GridWriter) { pb.hasRendered = true - maxCount, currentCount := pb.Watching.Progress() + currentCount, maxCount := pb.Watching.Progress() maxStr, currentStr := pb.formatCounts() if maxCount == 0 { // if we have no max amount, just print a count diff --git a/src/mongo/gotools/common/progress/progressor.go b/src/mongo/gotools/common/progress/progressor.go new file mode 100644 index 00000000000..72a8c85b8be --- /dev/null +++ b/src/mongo/gotools/common/progress/progressor.go @@ -0,0 +1,48 @@ +// Package progress exposes utilities to asynchronously monitor and display processing progress. +package progress + +import ( + "sync/atomic" +) + +// Progressor can be implemented to allow an object to hook up to a progress.Bar. +type Progressor interface { + // Progress returns a pair of integers: the amount completed and the total + // amount to reach 100%. This method is called by progress.Bar to determine + // what percentage to display. + Progress() (current, max int64) +} + +// Updateable is a Progressor which also exposes the ability for the progressing +// value to be updated. +type Updateable interface { + Progressor + + // Inc increments the current progress counter by the given amount. + Inc(amount int64) + + // Set sets the progress counter to the given amount. + Set(amount int64) +} + +// countProgressor is an implementation of Progressor that uses +type countProgressor struct { + max, current int64 +} + +func (c *countProgressor) Progress() (int64, int64) { + current := atomic.LoadInt64(&c.current) + return current, c.max +} + +func (c *countProgressor) Inc(amount int64) { + atomic.AddInt64(&c.current, amount) +} + +func (c *countProgressor) Set(amount int64) { + atomic.StoreInt64(&c.current, amount) +} + +func NewCounter(max int64) *countProgressor { + return &countProgressor{max, 0} +} diff --git a/src/mongo/gotools/common/signals/signals.go b/src/mongo/gotools/common/signals/signals.go index b7ce20b1bea..51ea1be9655 100644 --- a/src/mongo/gotools/common/signals/signals.go +++ b/src/mongo/gotools/common/signals/signals.go @@ -15,10 +15,11 @@ func Handle() chan struct{} { return HandleWithInterrupt(nil) } -// HandleWithInterrupt starts a goroutine which listens for SIGTERM, SIGINT, -// SIGKILL, and SIGPIPE. It calls the finalizer function when the first signal -// is received and forcibly terminates the program after the second. If a nil -// function is provided, the program will exit after the first signal. +// HandleWithInterrupt starts a goroutine which listens for SIGTERM, SIGINT, and +// SIGKILL and explicitly ignores SIGPIPE. It calls the finalizer function when +// the first signal is received and forcibly terminates the program after the +// second. If a nil function is provided, the program will exit after the first +// signal. func HandleWithInterrupt(finalizer func()) chan struct{} { finishedChan := make(chan struct{}) go handleSignals(finalizer, finishedChan) @@ -26,9 +27,13 @@ func HandleWithInterrupt(finalizer func()) chan struct{} { } func handleSignals(finalizer func(), finishedChan chan struct{}) { - log.Logv(log.DebugLow, "will listen for SIGTERM, SIGINT, SIGKILL, and SIGPIPE") + // explicitly ignore SIGPIPE; the tools should deal with write errors + noopChan := make(chan os.Signal) + signal.Notify(noopChan, syscall.SIGPIPE) + + log.Logv(log.DebugLow, "will listen for SIGTERM, SIGINT, and SIGKILL") sigChan := make(chan os.Signal, 2) - signal.Notify(sigChan, syscall.SIGTERM, syscall.SIGINT, syscall.SIGKILL, syscall.SIGPIPE) + signal.Notify(sigChan, syscall.SIGTERM, syscall.SIGINT, syscall.SIGKILL) defer signal.Stop(sigChan) if finalizer != nil { select { diff --git a/src/mongo/gotools/common/util/mongo.go b/src/mongo/gotools/common/util/mongo.go index 628bbc1d7f4..81904b4b37a 100644 --- a/src/mongo/gotools/common/util/mongo.go +++ b/src/mongo/gotools/common/util/mongo.go @@ -145,7 +145,7 @@ func ValidateDBName(database string) error { // check for illegal characters for _, illegalRune := range InvalidDBChars { if strings.ContainsRune(database, illegalRune) { - return fmt.Errorf("illegal character '%c' found in '%v'", illegalRune, database) + return fmt.Errorf("illegal character '%c' found in db name '%v'", illegalRune, database) } } |