summaryrefslogtreecommitdiff
path: root/workhorse/internal/secret
diff options
context:
space:
mode:
authorGitLab Bot <gitlab-bot@gitlab.com>2020-12-17 11:59:07 +0000
committerGitLab Bot <gitlab-bot@gitlab.com>2020-12-17 11:59:07 +0000
commit8b573c94895dc0ac0e1d9d59cf3e8745e8b539ca (patch)
tree544930fb309b30317ae9797a9683768705d664c4 /workhorse/internal/secret
parent4b1de649d0168371549608993deac953eb692019 (diff)
downloadgitlab-ce-8b573c94895dc0ac0e1d9d59cf3e8745e8b539ca.tar.gz
Add latest changes from gitlab-org/gitlab@13-7-stable-eev13.7.0-rc42
Diffstat (limited to 'workhorse/internal/secret')
-rw-r--r--workhorse/internal/secret/jwt.go25
-rw-r--r--workhorse/internal/secret/roundtripper.go35
-rw-r--r--workhorse/internal/secret/secret.go77
3 files changed, 137 insertions, 0 deletions
diff --git a/workhorse/internal/secret/jwt.go b/workhorse/internal/secret/jwt.go
new file mode 100644
index 00000000000..04335e58f76
--- /dev/null
+++ b/workhorse/internal/secret/jwt.go
@@ -0,0 +1,25 @@
+package secret
+
+import (
+ "fmt"
+
+ "github.com/dgrijalva/jwt-go"
+)
+
+var (
+ DefaultClaims = jwt.StandardClaims{Issuer: "gitlab-workhorse"}
+)
+
+func JWTTokenString(claims jwt.Claims) (string, error) {
+ secretBytes, err := Bytes()
+ if err != nil {
+ return "", fmt.Errorf("secret.JWTTokenString: %v", err)
+ }
+
+ tokenString, err := jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString(secretBytes)
+ if err != nil {
+ return "", fmt.Errorf("secret.JWTTokenString: sign JWT: %v", err)
+ }
+
+ return tokenString, nil
+}
diff --git a/workhorse/internal/secret/roundtripper.go b/workhorse/internal/secret/roundtripper.go
new file mode 100644
index 00000000000..50bf7fff5b8
--- /dev/null
+++ b/workhorse/internal/secret/roundtripper.go
@@ -0,0 +1,35 @@
+package secret
+
+import (
+ "net/http"
+)
+
+const (
+ // This header carries the JWT token for gitlab-rails
+ RequestHeader = "Gitlab-Workhorse-Api-Request"
+)
+
+type roundTripper struct {
+ next http.RoundTripper
+ version string
+}
+
+// NewRoundTripper creates a RoundTripper that adds the JWT token header to a
+// request. This is used to verify that a request came from workhorse
+func NewRoundTripper(next http.RoundTripper, version string) http.RoundTripper {
+ return &roundTripper{next: next, version: version}
+}
+
+func (r *roundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
+ tokenString, err := JWTTokenString(DefaultClaims)
+ if err != nil {
+ return nil, err
+ }
+
+ // Set a custom header for the request. This can be used in some
+ // configurations (Passenger) to solve auth request routing problems.
+ req.Header.Set("Gitlab-Workhorse", r.version)
+ req.Header.Set(RequestHeader, tokenString)
+
+ return r.next.RoundTrip(req)
+}
diff --git a/workhorse/internal/secret/secret.go b/workhorse/internal/secret/secret.go
new file mode 100644
index 00000000000..e8c7c25393c
--- /dev/null
+++ b/workhorse/internal/secret/secret.go
@@ -0,0 +1,77 @@
+package secret
+
+import (
+ "encoding/base64"
+ "fmt"
+ "io/ioutil"
+ "sync"
+)
+
+const numSecretBytes = 32
+
+type sec struct {
+ path string
+ bytes []byte
+ sync.RWMutex
+}
+
+var (
+ theSecret = &sec{}
+)
+
+func SetPath(path string) {
+ theSecret.Lock()
+ defer theSecret.Unlock()
+ theSecret.path = path
+ theSecret.bytes = nil
+}
+
+// Lazy access to the HMAC secret key. We must be lazy because if the key
+// is not already there, it will be generated by gitlab-rails, and
+// gitlab-rails is slow.
+func Bytes() ([]byte, error) {
+ if bytes := getBytes(); bytes != nil {
+ return copyBytes(bytes), nil
+ }
+
+ return setBytes()
+}
+
+func getBytes() []byte {
+ theSecret.RLock()
+ defer theSecret.RUnlock()
+ return theSecret.bytes
+}
+
+func copyBytes(bytes []byte) []byte {
+ out := make([]byte, len(bytes))
+ copy(out, bytes)
+ return out
+}
+
+func setBytes() ([]byte, error) {
+ theSecret.Lock()
+ defer theSecret.Unlock()
+
+ if theSecret.bytes != nil {
+ return theSecret.bytes, nil
+ }
+
+ base64Bytes, err := ioutil.ReadFile(theSecret.path)
+ if err != nil {
+ return nil, fmt.Errorf("secret.setBytes: read %q: %v", theSecret.path, err)
+ }
+
+ secretBytes := make([]byte, base64.StdEncoding.DecodedLen(len(base64Bytes)))
+ n, err := base64.StdEncoding.Decode(secretBytes, base64Bytes)
+ if err != nil {
+ return nil, fmt.Errorf("secret.setBytes: decode secret: %v", err)
+ }
+
+ if n != numSecretBytes {
+ return nil, fmt.Errorf("secret.setBytes: expected %d secretBytes in %s, found %d", numSecretBytes, theSecret.path, n)
+ }
+
+ theSecret.bytes = secretBytes
+ return copyBytes(theSecret.bytes), nil
+}