diff options
Diffstat (limited to 'workhorse/internal/gitaly/gitaly.go')
-rw-r--r-- | workhorse/internal/gitaly/gitaly.go | 88 |
1 files changed, 51 insertions, 37 deletions
diff --git a/workhorse/internal/gitaly/gitaly.go b/workhorse/internal/gitaly/gitaly.go index 362f380dc4d..db1fd3f8abb 100644 --- a/workhorse/internal/gitaly/gitaly.go +++ b/workhorse/internal/gitaly/gitaly.go @@ -19,24 +19,18 @@ import ( gitalyclient "gitlab.com/gitlab-org/gitaly/v14/client" "gitlab.com/gitlab-org/gitaly/v14/proto/go/gitalypb" + "gitlab.com/gitlab-org/gitlab/workhorse/internal/api" + grpccorrelation "gitlab.com/gitlab-org/labkit/correlation/grpc" grpctracing "gitlab.com/gitlab-org/labkit/tracing/grpc" ) -type Server struct { - Address string `json:"address"` - Token string `json:"token"` - Features map[string]string `json:"features"` - Sidechannel bool `json:"sidechannel"` -} - type cacheKey struct { address, token string - sidechannel bool } -func (server Server) cacheKey() cacheKey { - return cacheKey{address: server.Address, token: server.Token, sidechannel: server.Sidechannel} +func getCacheKey(server api.GitalyServer) cacheKey { + return cacheKey{address: server.Address, token: server.Token} } type connectionsCache struct { @@ -73,19 +67,42 @@ func InitializeSidechannelRegistry(logger *logrus.Logger) { } } -func withOutgoingMetadata(ctx context.Context, features map[string]string) context.Context { - md := metadata.New(nil) - for k, v := range features { - if !strings.HasPrefix(k, "gitaly-feature-") { - continue +type MetadataFunc func(metadata.MD) + +func WithUserID(userID string) MetadataFunc { + return func(md metadata.MD) { + md.Append("user_id", userID) + } +} + +func WithUsername(username string) MetadataFunc { + return func(md metadata.MD) { + md.Append("username", username) + } +} + +func WithFeatures(features map[string]string) MetadataFunc { + return func(md metadata.MD) { + for k, v := range features { + if !strings.HasPrefix(k, "gitaly-feature-") { + continue + } + md.Append(k, v) } - md.Append(k, v) + } +} + +func withOutgoingMetadata(ctx context.Context, addMetadataFuncs ...MetadataFunc) context.Context { + md := metadata.New(nil) + + for _, f := range addMetadataFuncs { + f(md) } return metadata.NewOutgoingContext(ctx, md) } -func NewSmartHTTPClient(ctx context.Context, server Server) (context.Context, *SmartHTTPClient, error) { +func NewSmartHTTPClient(ctx context.Context, server api.GitalyServer, metadataFuncs ...MetadataFunc) (context.Context, *SmartHTTPClient, error) { conn, err := getOrCreateConnection(server) if err != nil { return nil, nil, err @@ -94,50 +111,53 @@ func NewSmartHTTPClient(ctx context.Context, server Server) (context.Context, *S smartHTTPClient := &SmartHTTPClient{ SmartHTTPServiceClient: grpcClient, sidechannelRegistry: sidechannelRegistry, - useSidechannel: server.Sidechannel, } - return withOutgoingMetadata(ctx, server.Features), smartHTTPClient, nil + + return withOutgoingMetadata( + ctx, + metadataFuncs..., + ), smartHTTPClient, nil } -func NewBlobClient(ctx context.Context, server Server) (context.Context, *BlobClient, error) { +func NewBlobClient(ctx context.Context, server api.GitalyServer, addMetadataFuncs ...MetadataFunc) (context.Context, *BlobClient, error) { conn, err := getOrCreateConnection(server) if err != nil { return nil, nil, err } grpcClient := gitalypb.NewBlobServiceClient(conn) - return withOutgoingMetadata(ctx, server.Features), &BlobClient{grpcClient}, nil + return withOutgoingMetadata(ctx, addMetadataFuncs...), &BlobClient{grpcClient}, nil } -func NewRepositoryClient(ctx context.Context, server Server) (context.Context, *RepositoryClient, error) { +func NewRepositoryClient(ctx context.Context, server api.GitalyServer, addMetadataFuncs ...MetadataFunc) (context.Context, *RepositoryClient, error) { conn, err := getOrCreateConnection(server) if err != nil { return nil, nil, err } grpcClient := gitalypb.NewRepositoryServiceClient(conn) - return withOutgoingMetadata(ctx, server.Features), &RepositoryClient{grpcClient}, nil + return withOutgoingMetadata(ctx, addMetadataFuncs...), &RepositoryClient{grpcClient}, nil } // NewNamespaceClient is only used by the Gitaly integration tests at present -func NewNamespaceClient(ctx context.Context, server Server) (context.Context, *NamespaceClient, error) { +func NewNamespaceClient(ctx context.Context, server api.GitalyServer, addMetadataFuncs ...MetadataFunc) (context.Context, *NamespaceClient, error) { conn, err := getOrCreateConnection(server) if err != nil { return nil, nil, err } grpcClient := gitalypb.NewNamespaceServiceClient(conn) - return withOutgoingMetadata(ctx, server.Features), &NamespaceClient{grpcClient}, nil + return withOutgoingMetadata(ctx, addMetadataFuncs...), &NamespaceClient{grpcClient}, nil } -func NewDiffClient(ctx context.Context, server Server) (context.Context, *DiffClient, error) { +func NewDiffClient(ctx context.Context, server api.GitalyServer, addMetadataFuncs ...MetadataFunc) (context.Context, *DiffClient, error) { conn, err := getOrCreateConnection(server) if err != nil { return nil, nil, err } grpcClient := gitalypb.NewDiffServiceClient(conn) - return withOutgoingMetadata(ctx, server.Features), &DiffClient{grpcClient}, nil + return withOutgoingMetadata(ctx, addMetadataFuncs...), &DiffClient{grpcClient}, nil } -func getOrCreateConnection(server Server) (*grpc.ClientConn, error) { - key := server.cacheKey() +func getOrCreateConnection(server api.GitalyServer) (*grpc.ClientConn, error) { + key := getCacheKey(server) cache.RLock() conn := cache.connections[key] @@ -173,7 +193,7 @@ func CloseConnections() { } } -func newConnection(server Server) (*grpc.ClientConn, error) { +func newConnection(server api.GitalyServer) (*grpc.ClientConn, error) { connOpts := append(gitalyclient.DefaultDialOpts, grpc.WithPerRPCCredentials(gitalyauth.RPCCredentialsV2(server.Token)), grpc.WithStreamInterceptor( @@ -197,13 +217,7 @@ func newConnection(server Server) (*grpc.ClientConn, error) { ), ) - var conn *grpc.ClientConn - var connErr error - if server.Sidechannel { - conn, connErr = gitalyclient.DialSidechannel(context.Background(), server.Address, sidechannelRegistry, connOpts) // lint:allow context.Background - } else { - conn, connErr = gitalyclient.Dial(server.Address, connOpts) - } + conn, connErr := gitalyclient.DialSidechannel(context.Background(), server.Address, sidechannelRegistry, connOpts) // lint:allow context.Background label := "ok" if connErr != nil { |