1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
|
package gitaly
import (
"context"
"fmt"
"sync"
grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
grpc_prometheus "github.com/grpc-ecosystem/go-grpc-prometheus"
"google.golang.org/grpc"
gitalyauth "gitlab.com/gitlab-org/gitaly/v15/auth"
"gitlab.com/gitlab-org/gitaly/v15/client"
gitalyclient "gitlab.com/gitlab-org/gitaly/v15/client"
"gitlab.com/gitlab-org/labkit/correlation"
grpccorrelation "gitlab.com/gitlab-org/labkit/correlation/grpc"
"gitlab.com/gitlab-org/labkit/log"
grpctracing "gitlab.com/gitlab-org/labkit/tracing/grpc"
"gitlab.com/gitlab-org/gitlab-shell/v14/internal/metrics"
)
// Command describes which Gitaly endpoint to talk to and how to identify
// against it. GetConnection uses it directly as the connection-cache map key,
// so every field must remain comparable.
type Command struct {
	// ServiceName is appended (as "<client>-<ServiceName>") to the client name
	// reported through correlation metadata.
	// Address is the Gitaly address to dial; dialing fails when it is empty.
	// Token, when non-empty, is attached as per-RPC V2 credentials.
	ServiceName, Address, Token string
}
// connectionsCache memoizes dialed gRPC connections per Command. The embedded
// RWMutex guards connections. The map is allocated lazily on the first insert
// in GetConnection, so the zero value is ready to use. Because it embeds a
// mutex, a connectionsCache must not be copied after first use.
type connectionsCache struct {
	sync.RWMutex
	// connections maps each Command to its connection; nil until the first
	// successful dial.
	connections map[Command]*grpc.ClientConn
}
// Client dials connections to Gitaly servers and reuses them across calls to
// GetConnection. It contains a mutex (via cache), so use it through a pointer
// and do not copy it after first use.
type Client struct {
	// SidechannelRegistry is set by InitSidechannelRegistry and handed to
	// DialSidechannel for every new connection.
	// NOTE(review): presumably it must be initialized before the first
	// GetConnection call — confirm how a nil registry is handled.
	SidechannelRegistry *gitalyclient.SidechannelRegistry
	// cache holds the connections already dialed, keyed by Command.
	cache connectionsCache
}
// InitSidechannelRegistry builds the sidechannel registry that new Gitaly
// connections are dialed with. The registry logs through the logger derived
// from ctx. Calling it again replaces any registry already assigned to c.
func (c *Client) InitSidechannelRegistry(ctx context.Context) {
	logger := log.ContextLogger(ctx)
	c.SidechannelRegistry = gitalyclient.NewSidechannelRegistry(logger)
}
// GetConnection returns the connection cached for cmd, dialing and caching a
// new one when none exists yet. Connections are keyed by the full Command
// value, so commands that differ only in Token or ServiceName get separate
// connections. Dial errors are returned unchanged and nothing is cached for
// them.
//
// Lookups first take the shared lock. On a miss, the exclusive lock is taken
// and the cache is re-checked before dialing, because another goroutine may
// have stored a connection in between. The exclusive lock stays held while
// dialing, so concurrent first-time callers are serialized.
func (c *Client) GetConnection(ctx context.Context, cmd Command) (*grpc.ClientConn, error) {
	c.cache.RLock()
	cached := c.cache.connections[cmd]
	c.cache.RUnlock()
	if cached != nil {
		return cached, nil
	}

	c.cache.Lock()
	defer c.cache.Unlock()

	// Re-check under the exclusive lock: we may have lost a race.
	if existing := c.cache.connections[cmd]; existing != nil {
		return existing, nil
	}

	fresh, err := c.newConnection(ctx, cmd)
	if err != nil {
		return nil, err
	}

	// Lazily allocate the map so the zero-value Client is usable.
	if c.cache.connections == nil {
		c.cache.connections = make(map[Command]*grpc.ClientConn)
	}
	c.cache.connections[cmd] = fresh

	return fresh, nil
}
// newConnection dials a new sidechannel-capable gRPC connection to the Gitaly
// server at cmd.Address. It does not consult or update the cache.
//
// The client name sent through correlation metadata is the name extracted
// from ctx (or "gitlab-shell-unknown", with a warning, when absent), suffixed
// with "-" + cmd.ServiceName. Tracing, Prometheus and correlation interceptors
// are installed for both unary and streaming calls. When cmd.Token is
// non-empty, per-RPC V2 credentials are attached.
//
// It returns an error if cmd.Address is empty, if the sidechannel registry
// has not been initialized, or if dialing fails. Every attempt, successful or
// not, increments metrics.GitalyConnectionsTotal with label "ok" or "fail".
func (c *Client) newConnection(ctx context.Context, cmd Command) (conn *grpc.ClientConn, err error) {
	// The deferred closure reads the named result err, so the metric reflects
	// the outcome of whichever return path is taken below.
	defer func() {
		label := "ok"
		if err != nil {
			label = "fail"
		}
		metrics.GitalyConnectionsTotal.WithLabelValues(label).Inc()
	}()

	if cmd.Address == "" {
		return nil, fmt.Errorf("no gitaly_address given")
	}

	// Fail with a clear error instead of handing a nil registry to
	// DialSidechannel when InitSidechannelRegistry was never called.
	if c.SidechannelRegistry == nil {
		return nil, fmt.Errorf("gitaly sidechannel registry is not initialized")
	}

	serviceName := correlation.ExtractClientNameFromContext(ctx)
	if serviceName == "" {
		serviceName = "gitlab-shell-unknown"
		log.WithContextFields(ctx, log.Fields{"service_name": serviceName}).Warn("No gRPC service name specified, defaulting to gitlab-shell-unknown")
	}
	serviceName = fmt.Sprintf("%s-%s", serviceName, cmd.ServiceName)

	// Copy client.DefaultDialOpts into a slice we own rather than appending to
	// the package-level slice directly. If it had spare capacity, append would
	// write into its shared backing array and leak options across calls.
	connOpts := make([]grpc.DialOption, 0, len(client.DefaultDialOpts)+3)
	connOpts = append(connOpts, client.DefaultDialOpts...)
	connOpts = append(
		connOpts,
		grpc.WithStreamInterceptor(
			grpc_middleware.ChainStreamClient(
				grpctracing.StreamClientTracingInterceptor(),
				grpc_prometheus.StreamClientInterceptor,
				grpccorrelation.StreamClientCorrelationInterceptor(
					grpccorrelation.WithClientName(serviceName),
				),
			),
		),
		grpc.WithUnaryInterceptor(
			grpc_middleware.ChainUnaryClient(
				grpctracing.UnaryClientTracingInterceptor(),
				grpc_prometheus.UnaryClientInterceptor,
				grpccorrelation.UnaryClientCorrelationInterceptor(
					grpccorrelation.WithClientName(serviceName),
				),
			),
		),
	)

	if cmd.Token != "" {
		connOpts = append(connOpts,
			grpc.WithPerRPCCredentials(gitalyauth.RPCCredentialsV2(cmd.Token)),
		)
	}

	return client.DialSidechannel(ctx, cmd.Address, c.SidechannelRegistry, connOpts)
}
|