package objectstore

import (
	"bytes"
	"context"
	"encoding/xml"
	"errors"
	"fmt"
	"io"
	"net/http"
	"os"

	"gitlab.com/gitlab-org/labkit/mask"

	"gitlab.com/gitlab-org/gitlab/workhorse/internal/upload/destination/objectstore/s3api"
)

// ErrNotEnoughParts is returned when the data to upload is larger than partSize * len(PartURLs),
// i.e. there are not enough presigned part URLs to cover it
var ErrNotEnoughParts = errors.New("not enough Parts")

// Multipart represents a MultipartUpload on an S3-compatible object store service.
// It can be used as an io.WriteCloser for uploading an object
type Multipart struct {
	// PartURLs are presigned URLs for the individual UploadPart requests, in part order
	PartURLs []string
	// CompleteURL is a presigned URL for CompleteMultipartUpload
	CompleteURL string
	// AbortURL is a presigned URL for AbortMultipartUpload
	AbortURL string
	// DeleteURL is a presigned URL for RemoveObject
	DeleteURL  string
	PutHeaders map[string]string
	partSize   int64
	etag       string

	*uploader
}

// NewMultipart provides a Multipart pointer that can be used for uploading. Data written is buffered
// on disk in chunks of up to partSize bytes and each chunk is uploaded with S3 UploadPart. Once the
// Multipart is Closed, a final call to CompleteMultipartUpload is sent. In case of any error,
// AbortMultipartUpload is called to clean up all the resources. See exampleMultipartUpload below for
// a minimal usage sketch.
func NewMultipart(partURLs []string, completeURL, abortURL, deleteURL string, putHeaders map[string]string, partSize int64) (*Multipart, error) {
	m := &Multipart{
		PartURLs:    partURLs,
		CompleteURL: completeURL,
		AbortURL:    abortURL,
		DeleteURL:   deleteURL,
		PutHeaders:  putHeaders,
		partSize:    partSize,
	}

	m.uploader = newUploader(m)
	return m, nil
}
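
// exampleMultipartUpload is a minimal usage sketch and is not part of the production code path.
// The presigned URLs, headers and part size below are hypothetical placeholders; error handling is
// reduced to plain returns. Note that ctx must carry a deadline, because uploadPart refuses to run
// without one.
func exampleMultipartUpload(ctx context.Context, src io.Reader) error {
	// Hypothetical presigned UploadPart URLs, one per part, in part order.
	partURLs := []string{
		"https://bucket.example.com/object?partNumber=1&uploadId=ID",
		"https://bucket.example.com/object?partNumber=2&uploadId=ID",
	}

	m, err := NewMultipart(
		partURLs,
		"https://bucket.example.com/object?uploadId=ID&complete", // presigned CompleteMultipartUpload URL
		"https://bucket.example.com/object?uploadId=ID&abort",    // presigned AbortMultipartUpload URL
		"https://bucket.example.com/object?delete",               // presigned RemoveObject URL
		map[string]string{"Content-Type": "application/octet-stream"},
		16*1024*1024, // 16 MiB per part
	)
	if err != nil {
		return err
	}

	// Drive the upload strategy directly; when Multipart is used as an io.WriteCloser the embedded
	// uploader does this instead.
	return m.Upload(ctx, src)
}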

// Upload reads the data from r, uploads it in parts of up to partSize bytes to the presigned
// PartURLs and then calls CompleteMultipartUpload with the collected part ETags.
func (m *Multipart) Upload(ctx context.Context, r io.Reader) error {
	cmu := &s3api.CompleteMultipartUpload{}
	for i, partURL := range m.PartURLs {
		src := io.LimitReader(r, m.partSize)
		part, err := m.readAndUploadOnePart(ctx, partURL, m.PutHeaders, src, i+1)
		if err != nil {
			return err
		}
		if part == nil {
			break
		}
		cmu.Part = append(cmu.Part, part)
	}

	n, err := io.Copy(io.Discard, r)
	if err != nil {
		return fmt.Errorf("drain pipe: %v", err)
	}
	if n > 0 {
		return ErrNotEnoughParts
	}

	if err := m.complete(ctx, cmu); err != nil {
		return err
	}

	return nil
}

// ETag returns the ETag of the uploaded object as reported by CompleteMultipartUpload
func (m *Multipart) ETag() string {
	return m.etag
}

// Abort aborts the multipart upload via the presigned AbortMultipartUpload URL
func (m *Multipart) Abort() {
	deleteURL(m.AbortURL)
}

// Delete removes the uploaded object via the presigned RemoveObject URL
func (m *Multipart) Delete() {
	deleteURL(m.DeleteURL)
}

func (m *Multipart) readAndUploadOnePart(ctx context.Context, partURL string, putHeaders map[string]string, src io.Reader, partNumber int) (*s3api.CompleteMultipartUploadPart, error) {
	file, err := os.CreateTemp("", "part-buffer")
	if err != nil {
		return nil, fmt.Errorf("create temporary buffer file: %v", err)
	}
	defer file.Close()

	// Unlink the file right away so the kernel reclaims the space as soon as the descriptor is
	// closed, even if the process exits abnormally.
	if err := os.Remove(file.Name()); err != nil {
		return nil, fmt.Errorf("remove temporary buffer file: %v", err)
	}

	n, err := io.Copy(file, src)
	if err != nil {
		return nil, fmt.Errorf("copy to temporary buffer file: %v", err)
	}
	if n == 0 {
		return nil, nil
	}

	if _, err = file.Seek(0, io.SeekStart); err != nil {
		return nil, fmt.Errorf("rewind part %d temporary dump : %v", partNumber, err)
	}

	etag, err := m.uploadPart(ctx, partURL, putHeaders, file, n)
	if err != nil {
		return nil, fmt.Errorf("upload part %d: %v", partNumber, err)
	}
	return &s3api.CompleteMultipartUploadPart{PartNumber: partNumber, ETag: etag}, nil
}

func (m *Multipart) uploadPart(ctx context.Context, url string, headers map[string]string, body io.Reader, size int64) (string, error) {
	deadline, ok := ctx.Deadline()
	if !ok {
		return "", fmt.Errorf("missing deadline")
	}

	part, err := newObject(url, "", headers, size, false)
	if err != nil {
		return "", err
	}

	if n, err := part.Consume(ctx, io.LimitReader(body, size), deadline); err != nil || n < size {
		if err == nil {
			err = io.ErrUnexpectedEOF
		}
		return "", err
	}

	return part.ETag(), nil
}

func (m *Multipart) complete(ctx context.Context, cmu *s3api.CompleteMultipartUpload) error {
	body, err := xml.Marshal(cmu)
	if err != nil {
		return fmt.Errorf("marshal CompleteMultipartUpload request: %v", err)
	}
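
	// With the standard S3 CompleteMultipartUpload schema (the exact XML tags come from the s3api
	// struct definitions, which live outside this file), the marshaled body is expected to look
	// roughly like:
	//
	//	<CompleteMultipartUpload>
	//	  <Part><PartNumber>1</PartNumber><ETag>"etag-of-part-1"</ETag></Part>
	//	  <Part><PartNumber>2</PartNumber><ETag>"etag-of-part-2"</ETag></Part>
	//	</CompleteMultipartUpload>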

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, m.CompleteURL, bytes.NewReader(body))
	if err != nil {
		return fmt.Errorf("create CompleteMultipartUpload request: %v", err)
	}
	req.ContentLength = int64(len(body))
	req.Header.Set("Content-Type", "application/xml")

	resp, err := httpClient.Do(req)
	if err != nil {
		return fmt.Errorf("CompleteMultipartUpload request %q: %v", mask.URL(m.CompleteURL), err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("CompleteMultipartUpload request %v returned: %s", mask.URL(m.CompleteURL), resp.Status)
	}

	result := &compoundCompleteMultipartUploadResult{}
	decoder := xml.NewDecoder(resp.Body)
	if err := decoder.Decode(result); err != nil {
		return fmt.Errorf("decode CompleteMultipartUpload answer: %v", err)
	}

	if result.isError() {
		return result
	}

	if result.CompleteMultipartUploadResult == nil {
		return fmt.Errorf("empty CompleteMultipartUploadResult")
	}

	m.etag = extractETag(result.ETag)

	return nil
}