summaryrefslogtreecommitdiff
path: root/workhorse/internal/redis
diff options
context:
space:
mode:
Diffstat (limited to 'workhorse/internal/redis')
-rw-r--r--workhorse/internal/redis/keywatcher.go18
-rw-r--r--workhorse/internal/redis/keywatcher_test.go55
-rw-r--r--workhorse/internal/redis/redis.go64
-rw-r--r--workhorse/internal/redis/redis_test.go16
4 files changed, 101 insertions, 52 deletions
diff --git a/workhorse/internal/redis/keywatcher.go b/workhorse/internal/redis/keywatcher.go
index 8f3e61b5e9f..10d80d13d22 100644
--- a/workhorse/internal/redis/keywatcher.go
+++ b/workhorse/internal/redis/keywatcher.go
@@ -17,6 +17,7 @@ import (
var (
keyWatcher = make(map[string][]chan string)
keyWatcherMutex sync.Mutex
+ shutdown = make(chan struct{})
redisReconnectTimeout = backoff.Backoff{
//These are the defaults
Min: 100 * time.Millisecond,
@@ -112,6 +113,20 @@ func Process() {
}
}
+func Shutdown() {
+ log.Info("keywatcher: shutting down")
+
+ keyWatcherMutex.Lock()
+ defer keyWatcherMutex.Unlock()
+
+ select {
+ case <-shutdown:
+ // already closed
+ default:
+ close(shutdown)
+ }
+}
+
func notifyChanWatchers(key, value string) {
keyWatcherMutex.Lock()
defer keyWatcherMutex.Unlock()
@@ -182,6 +197,9 @@ func WatchKey(key, value string, timeout time.Duration) (WatchKeyStatus, error)
}
select {
+ case <-shutdown:
+ log.WithFields(log.Fields{"key": key}).Info("stopping watch due to shutdown")
+ return WatchKeyStatusNoChange, nil
case currentValue := <-kw.Chan:
if currentValue == "" {
return WatchKeyStatusNoChange, fmt.Errorf("keywatcher: redis GET failed")
diff --git a/workhorse/internal/redis/keywatcher_test.go b/workhorse/internal/redis/keywatcher_test.go
index f1ee77e2194..99892bc64b8 100644
--- a/workhorse/internal/redis/keywatcher_test.go
+++ b/workhorse/internal/redis/keywatcher_test.go
@@ -160,3 +160,58 @@ func TestWatchKeyMassivelyParallel(t *testing.T) {
processMessages(runTimes, "somethingelse")
wg.Wait()
}
+
+func TestShutdown(t *testing.T) {
+ conn, td := setupMockPool()
+ defer td()
+ defer func() { shutdown = make(chan struct{}) }()
+
+ conn.Command("GET", runnerKey).Expect("something")
+
+ wg := &sync.WaitGroup{}
+ wg.Add(2)
+
+ go func() {
+ val, err := WatchKey(runnerKey, "something", 10*time.Second)
+
+ require.NoError(t, err, "Expected no error")
+ require.Equal(t, WatchKeyStatusNoChange, val, "Expected value not to change")
+ wg.Done()
+ }()
+
+ go func() {
+ for countWatchers(runnerKey) == 0 {
+ time.Sleep(time.Millisecond)
+ }
+
+ require.Equal(t, 1, countWatchers(runnerKey))
+
+ Shutdown()
+ wg.Done()
+ }()
+
+ wg.Wait()
+
+ for countWatchers(runnerKey) == 1 {
+ time.Sleep(time.Millisecond)
+ }
+
+ require.Equal(t, 0, countWatchers(runnerKey))
+
+ // Adding a key after the shutdown should result in an immediate response
+ var val WatchKeyStatus
+ var err error
+ done := make(chan struct{})
+ go func() {
+ val, err = WatchKey(runnerKey, "something", 10*time.Second)
+ close(done)
+ }()
+
+ select {
+ case <-done:
+ require.NoError(t, err, "Expected no error")
+ require.Equal(t, WatchKeyStatusNoChange, val, "Expected value not to change")
+ case <-time.After(100 * time.Millisecond):
+ t.Fatal("timeout waiting for WatchKey")
+ }
+}
diff --git a/workhorse/internal/redis/redis.go b/workhorse/internal/redis/redis.go
index 0029a2a9e2b..b11a8184bca 100644
--- a/workhorse/internal/redis/redis.go
+++ b/workhorse/internal/redis/redis.go
@@ -113,21 +113,9 @@ var poolDialFunc func() (redis.Conn, error)
var workerDialFunc func() (redis.Conn, error)
func timeoutDialOptions(cfg *config.RedisConfig) []redis.DialOption {
- readTimeout := defaultReadTimeout
- writeTimeout := defaultWriteTimeout
-
- if cfg != nil {
- if cfg.ReadTimeout != nil {
- readTimeout = cfg.ReadTimeout.Duration
- }
-
- if cfg.WriteTimeout != nil {
- writeTimeout = cfg.WriteTimeout.Duration
- }
- }
return []redis.DialOption{
- redis.DialReadTimeout(readTimeout),
- redis.DialWriteTimeout(writeTimeout),
+ redis.DialReadTimeout(defaultReadTimeout),
+ redis.DialWriteTimeout(defaultWriteTimeout),
}
}
@@ -148,47 +136,45 @@ func dialOptionsBuilder(cfg *config.RedisConfig, setTimeouts bool) []redis.DialO
return dopts
}
-func keepAliveDialer(timeout time.Duration) func(string, string) (net.Conn, error) {
- return func(network, address string) (net.Conn, error) {
- addr, err := net.ResolveTCPAddr(network, address)
- if err != nil {
- return nil, err
- }
- tc, err := net.DialTCP(network, nil, addr)
- if err != nil {
- return nil, err
- }
- if err := tc.SetKeepAlive(true); err != nil {
- return nil, err
- }
- if err := tc.SetKeepAlivePeriod(timeout); err != nil {
- return nil, err
- }
- return tc, nil
+func keepAliveDialer(network, address string) (net.Conn, error) {
+ addr, err := net.ResolveTCPAddr(network, address)
+ if err != nil {
+ return nil, err
}
+ tc, err := net.DialTCP(network, nil, addr)
+ if err != nil {
+ return nil, err
+ }
+ if err := tc.SetKeepAlive(true); err != nil {
+ return nil, err
+ }
+ if err := tc.SetKeepAlivePeriod(defaultKeepAlivePeriod); err != nil {
+ return nil, err
+ }
+ return tc, nil
}
type redisDialerFunc func() (redis.Conn, error)
-func sentinelDialer(dopts []redis.DialOption, keepAlivePeriod time.Duration) redisDialerFunc {
+func sentinelDialer(dopts []redis.DialOption) redisDialerFunc {
return func() (redis.Conn, error) {
address, err := sntnl.MasterAddr()
if err != nil {
errorCounter.WithLabelValues("master", "sentinel").Inc()
return nil, err
}
- dopts = append(dopts, redis.DialNetDial(keepAliveDialer(keepAlivePeriod)))
+ dopts = append(dopts, redis.DialNetDial(keepAliveDialer))
return redisDial("tcp", address, dopts...)
}
}
-func defaultDialer(dopts []redis.DialOption, keepAlivePeriod time.Duration, url url.URL) redisDialerFunc {
+func defaultDialer(dopts []redis.DialOption, url url.URL) redisDialerFunc {
return func() (redis.Conn, error) {
if url.Scheme == "unix" {
return redisDial(url.Scheme, url.Path, dopts...)
}
- dopts = append(dopts, redis.DialNetDial(keepAliveDialer(keepAlivePeriod)))
+ dopts = append(dopts, redis.DialNetDial(keepAliveDialer))
// redis.DialURL only works with redis[s]:// URLs
if url.Scheme == "redis" || url.Scheme == "rediss" {
@@ -231,15 +217,11 @@ func countDialer(dialer redisDialerFunc) redisDialerFunc {
// DefaultDialFunc should always used. Only exception is for unit-tests.
func DefaultDialFunc(cfg *config.RedisConfig, setReadTimeout bool) func() (redis.Conn, error) {
- keepAlivePeriod := defaultKeepAlivePeriod
- if cfg.KeepAlivePeriod != nil {
- keepAlivePeriod = cfg.KeepAlivePeriod.Duration
- }
dopts := dialOptionsBuilder(cfg, setReadTimeout)
if sntnl != nil {
- return countDialer(sentinelDialer(dopts, keepAlivePeriod))
+ return countDialer(sentinelDialer(dopts))
}
- return countDialer(defaultDialer(dopts, keepAlivePeriod, cfg.URL.URL))
+ return countDialer(defaultDialer(dopts, cfg.URL.URL))
}
// Configure redis-connection
diff --git a/workhorse/internal/redis/redis_test.go b/workhorse/internal/redis/redis_test.go
index f4b4120517d..eee2f99bbbf 100644
--- a/workhorse/internal/redis/redis_test.go
+++ b/workhorse/internal/redis/redis_test.go
@@ -96,13 +96,11 @@ func TestConfigureMinimalConfig(t *testing.T) {
func TestConfigureFullConfig(t *testing.T) {
i, a := 4, 10
- r := config.TomlDuration{Duration: 3}
cfg := &config.RedisConfig{
- URL: config.TomlURL{},
- Password: "",
- MaxIdle: &i,
- MaxActive: &a,
- ReadTimeout: &r,
+ URL: config.TomlURL{},
+ Password: "",
+ MaxIdle: &i,
+ MaxActive: &a,
}
Configure(cfg, DefaultDialFunc)
@@ -219,11 +217,7 @@ func TestDialOptionsBuildersSetTimeouts(t *testing.T) {
}
func TestDialOptionsBuildersSetTimeoutsConfig(t *testing.T) {
- cfg := &config.RedisConfig{
- ReadTimeout: &config.TomlDuration{Duration: time.Second * time.Duration(15)},
- WriteTimeout: &config.TomlDuration{Duration: time.Second * time.Duration(15)},
- }
- dopts := dialOptionsBuilder(cfg, true)
+ dopts := dialOptionsBuilder(nil, true)
require.Equal(t, 2, len(dopts))
}