diff options
Diffstat (limited to 'workhorse/internal/redis')
-rw-r--r-- | workhorse/internal/redis/keywatcher.go | 18 | ||||
-rw-r--r-- | workhorse/internal/redis/keywatcher_test.go | 55 | ||||
-rw-r--r-- | workhorse/internal/redis/redis.go | 64 | ||||
-rw-r--r-- | workhorse/internal/redis/redis_test.go | 16 |
4 files changed, 101 insertions, 52 deletions
diff --git a/workhorse/internal/redis/keywatcher.go b/workhorse/internal/redis/keywatcher.go index 8f3e61b5e9f..10d80d13d22 100644 --- a/workhorse/internal/redis/keywatcher.go +++ b/workhorse/internal/redis/keywatcher.go @@ -17,6 +17,7 @@ import ( var ( keyWatcher = make(map[string][]chan string) keyWatcherMutex sync.Mutex + shutdown = make(chan struct{}) redisReconnectTimeout = backoff.Backoff{ //These are the defaults Min: 100 * time.Millisecond, @@ -112,6 +113,20 @@ func Process() { } } +func Shutdown() { + log.Info("keywatcher: shutting down") + + keyWatcherMutex.Lock() + defer keyWatcherMutex.Unlock() + + select { + case <-shutdown: + // already closed + default: + close(shutdown) + } +} + func notifyChanWatchers(key, value string) { keyWatcherMutex.Lock() defer keyWatcherMutex.Unlock() @@ -182,6 +197,9 @@ func WatchKey(key, value string, timeout time.Duration) (WatchKeyStatus, error) } select { + case <-shutdown: + log.WithFields(log.Fields{"key": key}).Info("stopping watch due to shutdown") + return WatchKeyStatusNoChange, nil case currentValue := <-kw.Chan: if currentValue == "" { return WatchKeyStatusNoChange, fmt.Errorf("keywatcher: redis GET failed") diff --git a/workhorse/internal/redis/keywatcher_test.go b/workhorse/internal/redis/keywatcher_test.go index f1ee77e2194..99892bc64b8 100644 --- a/workhorse/internal/redis/keywatcher_test.go +++ b/workhorse/internal/redis/keywatcher_test.go @@ -160,3 +160,58 @@ func TestWatchKeyMassivelyParallel(t *testing.T) { processMessages(runTimes, "somethingelse") wg.Wait() } + +func TestShutdown(t *testing.T) { + conn, td := setupMockPool() + defer td() + defer func() { shutdown = make(chan struct{}) }() + + conn.Command("GET", runnerKey).Expect("something") + + wg := &sync.WaitGroup{} + wg.Add(2) + + go func() { + val, err := WatchKey(runnerKey, "something", 10*time.Second) + + require.NoError(t, err, "Expected no error") + require.Equal(t, WatchKeyStatusNoChange, val, "Expected value not to change") + wg.Done() + }() + + go func() { + for countWatchers(runnerKey) == 0 { + time.Sleep(time.Millisecond) + } + + require.Equal(t, 1, countWatchers(runnerKey)) + + Shutdown() + wg.Done() + }() + + wg.Wait() + + for countWatchers(runnerKey) == 1 { + time.Sleep(time.Millisecond) + } + + require.Equal(t, 0, countWatchers(runnerKey)) + + // Adding a key after the shutdown should result in an immediate response + var val WatchKeyStatus + var err error + done := make(chan struct{}) + go func() { + val, err = WatchKey(runnerKey, "something", 10*time.Second) + close(done) + }() + + select { + case <-done: + require.NoError(t, err, "Expected no error") + require.Equal(t, WatchKeyStatusNoChange, val, "Expected value not to change") + case <-time.After(100 * time.Millisecond): + t.Fatal("timeout waiting for WatchKey") + } +} diff --git a/workhorse/internal/redis/redis.go b/workhorse/internal/redis/redis.go index 0029a2a9e2b..b11a8184bca 100644 --- a/workhorse/internal/redis/redis.go +++ b/workhorse/internal/redis/redis.go @@ -113,21 +113,9 @@ var poolDialFunc func() (redis.Conn, error) var workerDialFunc func() (redis.Conn, error) func timeoutDialOptions(cfg *config.RedisConfig) []redis.DialOption { - readTimeout := defaultReadTimeout - writeTimeout := defaultWriteTimeout - - if cfg != nil { - if cfg.ReadTimeout != nil { - readTimeout = cfg.ReadTimeout.Duration - } - - if cfg.WriteTimeout != nil { - writeTimeout = cfg.WriteTimeout.Duration - } - } return []redis.DialOption{ - redis.DialReadTimeout(readTimeout), - redis.DialWriteTimeout(writeTimeout), + redis.DialReadTimeout(defaultReadTimeout), + redis.DialWriteTimeout(defaultWriteTimeout), } } @@ -148,47 +136,45 @@ func dialOptionsBuilder(cfg *config.RedisConfig, setTimeouts bool) []redis.DialO return dopts } -func keepAliveDialer(timeout time.Duration) func(string, string) (net.Conn, error) { - return func(network, address string) (net.Conn, error) { - addr, err := net.ResolveTCPAddr(network, address) - if err != nil { - return nil, err - } - tc, err := net.DialTCP(network, nil, addr) - if err != nil { - return nil, err - } - if err := tc.SetKeepAlive(true); err != nil { - return nil, err - } - if err := tc.SetKeepAlivePeriod(timeout); err != nil { - return nil, err - } - return tc, nil +func keepAliveDialer(network, address string) (net.Conn, error) { + addr, err := net.ResolveTCPAddr(network, address) + if err != nil { + return nil, err } + tc, err := net.DialTCP(network, nil, addr) + if err != nil { + return nil, err + } + if err := tc.SetKeepAlive(true); err != nil { + return nil, err + } + if err := tc.SetKeepAlivePeriod(defaultKeepAlivePeriod); err != nil { + return nil, err + } + return tc, nil } type redisDialerFunc func() (redis.Conn, error) -func sentinelDialer(dopts []redis.DialOption, keepAlivePeriod time.Duration) redisDialerFunc { +func sentinelDialer(dopts []redis.DialOption) redisDialerFunc { return func() (redis.Conn, error) { address, err := sntnl.MasterAddr() if err != nil { errorCounter.WithLabelValues("master", "sentinel").Inc() return nil, err } - dopts = append(dopts, redis.DialNetDial(keepAliveDialer(keepAlivePeriod))) + dopts = append(dopts, redis.DialNetDial(keepAliveDialer)) return redisDial("tcp", address, dopts...) } } -func defaultDialer(dopts []redis.DialOption, keepAlivePeriod time.Duration, url url.URL) redisDialerFunc { +func defaultDialer(dopts []redis.DialOption, url url.URL) redisDialerFunc { return func() (redis.Conn, error) { if url.Scheme == "unix" { return redisDial(url.Scheme, url.Path, dopts...) } - dopts = append(dopts, redis.DialNetDial(keepAliveDialer(keepAlivePeriod))) + dopts = append(dopts, redis.DialNetDial(keepAliveDialer)) // redis.DialURL only works with redis[s]:// URLs if url.Scheme == "redis" || url.Scheme == "rediss" { @@ -231,15 +217,11 @@ func countDialer(dialer redisDialerFunc) redisDialerFunc { // DefaultDialFunc should always used. Only exception is for unit-tests. func DefaultDialFunc(cfg *config.RedisConfig, setReadTimeout bool) func() (redis.Conn, error) { - keepAlivePeriod := defaultKeepAlivePeriod - if cfg.KeepAlivePeriod != nil { - keepAlivePeriod = cfg.KeepAlivePeriod.Duration - } dopts := dialOptionsBuilder(cfg, setReadTimeout) if sntnl != nil { - return countDialer(sentinelDialer(dopts, keepAlivePeriod)) + return countDialer(sentinelDialer(dopts)) } - return countDialer(defaultDialer(dopts, keepAlivePeriod, cfg.URL.URL)) + return countDialer(defaultDialer(dopts, cfg.URL.URL)) } // Configure redis-connection diff --git a/workhorse/internal/redis/redis_test.go b/workhorse/internal/redis/redis_test.go index f4b4120517d..eee2f99bbbf 100644 --- a/workhorse/internal/redis/redis_test.go +++ b/workhorse/internal/redis/redis_test.go @@ -96,13 +96,11 @@ func TestConfigureMinimalConfig(t *testing.T) { func TestConfigureFullConfig(t *testing.T) { i, a := 4, 10 - r := config.TomlDuration{Duration: 3} cfg := &config.RedisConfig{ - URL: config.TomlURL{}, - Password: "", - MaxIdle: &i, - MaxActive: &a, - ReadTimeout: &r, + URL: config.TomlURL{}, + Password: "", + MaxIdle: &i, + MaxActive: &a, } Configure(cfg, DefaultDialFunc) @@ -219,11 +217,7 @@ func TestDialOptionsBuildersSetTimeouts(t *testing.T) { } func TestDialOptionsBuildersSetTimeoutsConfig(t *testing.T) { - cfg := &config.RedisConfig{ - ReadTimeout: &config.TomlDuration{Duration: time.Second * time.Duration(15)}, - WriteTimeout: &config.TomlDuration{Duration: time.Second * time.Duration(15)}, - } - dopts := dialOptionsBuilder(cfg, true) + dopts := dialOptionsBuilder(nil, true) require.Equal(t, 2, len(dopts)) } |