summaryrefslogtreecommitdiff
path: root/workhorse/internal/gitaly/gitaly.go
diff options
context:
space:
mode:
Diffstat (limited to 'workhorse/internal/gitaly/gitaly.go')
-rw-r--r--workhorse/internal/gitaly/gitaly.go88
1 files changed, 51 insertions, 37 deletions
diff --git a/workhorse/internal/gitaly/gitaly.go b/workhorse/internal/gitaly/gitaly.go
index 362f380dc4d..db1fd3f8abb 100644
--- a/workhorse/internal/gitaly/gitaly.go
+++ b/workhorse/internal/gitaly/gitaly.go
@@ -19,24 +19,18 @@ import (
gitalyclient "gitlab.com/gitlab-org/gitaly/v14/client"
"gitlab.com/gitlab-org/gitaly/v14/proto/go/gitalypb"
+ "gitlab.com/gitlab-org/gitlab/workhorse/internal/api"
+
grpccorrelation "gitlab.com/gitlab-org/labkit/correlation/grpc"
grpctracing "gitlab.com/gitlab-org/labkit/tracing/grpc"
)
-type Server struct {
- Address string `json:"address"`
- Token string `json:"token"`
- Features map[string]string `json:"features"`
- Sidechannel bool `json:"sidechannel"`
-}
-
type cacheKey struct {
address, token string
- sidechannel bool
}
-func (server Server) cacheKey() cacheKey {
- return cacheKey{address: server.Address, token: server.Token, sidechannel: server.Sidechannel}
+func getCacheKey(server api.GitalyServer) cacheKey {
+ return cacheKey{address: server.Address, token: server.Token}
}
type connectionsCache struct {
@@ -73,19 +67,42 @@ func InitializeSidechannelRegistry(logger *logrus.Logger) {
}
}
-func withOutgoingMetadata(ctx context.Context, features map[string]string) context.Context {
- md := metadata.New(nil)
- for k, v := range features {
- if !strings.HasPrefix(k, "gitaly-feature-") {
- continue
+type MetadataFunc func(metadata.MD)
+
+func WithUserID(userID string) MetadataFunc {
+ return func(md metadata.MD) {
+ md.Append("user_id", userID)
+ }
+}
+
+func WithUsername(username string) MetadataFunc {
+ return func(md metadata.MD) {
+ md.Append("username", username)
+ }
+}
+
+func WithFeatures(features map[string]string) MetadataFunc {
+ return func(md metadata.MD) {
+ for k, v := range features {
+ if !strings.HasPrefix(k, "gitaly-feature-") {
+ continue
+ }
+ md.Append(k, v)
}
- md.Append(k, v)
+ }
+}
+
+func withOutgoingMetadata(ctx context.Context, addMetadataFuncs ...MetadataFunc) context.Context {
+ md := metadata.New(nil)
+
+ for _, f := range addMetadataFuncs {
+ f(md)
}
return metadata.NewOutgoingContext(ctx, md)
}
-func NewSmartHTTPClient(ctx context.Context, server Server) (context.Context, *SmartHTTPClient, error) {
+func NewSmartHTTPClient(ctx context.Context, server api.GitalyServer, metadataFuncs ...MetadataFunc) (context.Context, *SmartHTTPClient, error) {
conn, err := getOrCreateConnection(server)
if err != nil {
return nil, nil, err
@@ -94,50 +111,53 @@ func NewSmartHTTPClient(ctx context.Context, server Server) (context.Context, *S
smartHTTPClient := &SmartHTTPClient{
SmartHTTPServiceClient: grpcClient,
sidechannelRegistry: sidechannelRegistry,
- useSidechannel: server.Sidechannel,
}
- return withOutgoingMetadata(ctx, server.Features), smartHTTPClient, nil
+
+ return withOutgoingMetadata(
+ ctx,
+ metadataFuncs...,
+ ), smartHTTPClient, nil
}
-func NewBlobClient(ctx context.Context, server Server) (context.Context, *BlobClient, error) {
+func NewBlobClient(ctx context.Context, server api.GitalyServer, addMetadataFuncs ...MetadataFunc) (context.Context, *BlobClient, error) {
conn, err := getOrCreateConnection(server)
if err != nil {
return nil, nil, err
}
grpcClient := gitalypb.NewBlobServiceClient(conn)
- return withOutgoingMetadata(ctx, server.Features), &BlobClient{grpcClient}, nil
+ return withOutgoingMetadata(ctx, addMetadataFuncs...), &BlobClient{grpcClient}, nil
}
-func NewRepositoryClient(ctx context.Context, server Server) (context.Context, *RepositoryClient, error) {
+func NewRepositoryClient(ctx context.Context, server api.GitalyServer, addMetadataFuncs ...MetadataFunc) (context.Context, *RepositoryClient, error) {
conn, err := getOrCreateConnection(server)
if err != nil {
return nil, nil, err
}
grpcClient := gitalypb.NewRepositoryServiceClient(conn)
- return withOutgoingMetadata(ctx, server.Features), &RepositoryClient{grpcClient}, nil
+ return withOutgoingMetadata(ctx, addMetadataFuncs...), &RepositoryClient{grpcClient}, nil
}
// NewNamespaceClient is only used by the Gitaly integration tests at present
-func NewNamespaceClient(ctx context.Context, server Server) (context.Context, *NamespaceClient, error) {
+func NewNamespaceClient(ctx context.Context, server api.GitalyServer, addMetadataFuncs ...MetadataFunc) (context.Context, *NamespaceClient, error) {
conn, err := getOrCreateConnection(server)
if err != nil {
return nil, nil, err
}
grpcClient := gitalypb.NewNamespaceServiceClient(conn)
- return withOutgoingMetadata(ctx, server.Features), &NamespaceClient{grpcClient}, nil
+ return withOutgoingMetadata(ctx, addMetadataFuncs...), &NamespaceClient{grpcClient}, nil
}
-func NewDiffClient(ctx context.Context, server Server) (context.Context, *DiffClient, error) {
+func NewDiffClient(ctx context.Context, server api.GitalyServer, addMetadataFuncs ...MetadataFunc) (context.Context, *DiffClient, error) {
conn, err := getOrCreateConnection(server)
if err != nil {
return nil, nil, err
}
grpcClient := gitalypb.NewDiffServiceClient(conn)
- return withOutgoingMetadata(ctx, server.Features), &DiffClient{grpcClient}, nil
+ return withOutgoingMetadata(ctx, addMetadataFuncs...), &DiffClient{grpcClient}, nil
}
-func getOrCreateConnection(server Server) (*grpc.ClientConn, error) {
- key := server.cacheKey()
+func getOrCreateConnection(server api.GitalyServer) (*grpc.ClientConn, error) {
+ key := getCacheKey(server)
cache.RLock()
conn := cache.connections[key]
@@ -173,7 +193,7 @@ func CloseConnections() {
}
}
-func newConnection(server Server) (*grpc.ClientConn, error) {
+func newConnection(server api.GitalyServer) (*grpc.ClientConn, error) {
connOpts := append(gitalyclient.DefaultDialOpts,
grpc.WithPerRPCCredentials(gitalyauth.RPCCredentialsV2(server.Token)),
grpc.WithStreamInterceptor(
@@ -197,13 +217,7 @@ func newConnection(server Server) (*grpc.ClientConn, error) {
),
)
- var conn *grpc.ClientConn
- var connErr error
- if server.Sidechannel {
- conn, connErr = gitalyclient.DialSidechannel(context.Background(), server.Address, sidechannelRegistry, connOpts) // lint:allow context.Background
- } else {
- conn, connErr = gitalyclient.Dial(server.Address, connOpts)
- }
+ conn, connErr := gitalyclient.DialSidechannel(context.Background(), server.Address, sidechannelRegistry, connOpts) // lint:allow context.Background
label := "ok"
if connErr != nil {