// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1

package oci

import (
"bytes"
"context"
"crypto/md5"
"encoding/base64"
"fmt"
"io"
"sync"
"github.com/hashicorp/go-hclog"
"github.com/oracle/oci-go-sdk/v65/objectstorage"
"github.com/oracle/oci-go-sdk/v65/common"
)

// DefaultFilePartSize is the default size of each part in a multipart upload.
var DefaultFilePartSize int64 = 128 * 1024 * 1024 // 128MB

// MaxFilePartSize is the largest size allowed for a single uploaded part.
const MaxFilePartSize int64 = 50 * 1024 * 1024 * 1024 // 50GB

// defaultNumberOfGoroutines is the number of concurrent part-upload workers.
const defaultNumberOfGoroutines = 10

// MaxCount is the maximum number of parts allowed in a single multipart upload.
const MaxCount int64 = 10000
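
// MultipartUploadData bundles the remote client, the object data to upload, and the request metadata for a multipart upload.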
type MultipartUploadData struct {
client *RemoteClient
Data []byte
RequestMetadata common.RequestMetadata
}
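
// objectStorageUploadPartResponse pairs an UploadPart response with the part number it belongs to, or carries the error for a failed part.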
type objectStorageUploadPartResponse struct {
response objectstorage.UploadPartResponse
partNumber *int
error error
}
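
// objectStorageMultiPartUploadContext holds the shared state each upload worker needs:
// the remote client, the block and response channels, the error channel, and the
// WaitGroup used to signal when every part has been handled.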
type objectStorageMultiPartUploadContext struct {
client *RemoteClient
sourceBlocks chan objectStorageSourceBlock
osUploadPartResponses chan objectStorageUploadPartResponse
wg *sync.WaitGroup
errChan chan error
multipartUploadResponse objectstorage.CreateMultipartUploadResponse
multipartUploadRequest objectstorage.CreateMultipartUploadRequest
logger hclog.Logger
}
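
// objectStorageSourceBlock is one numbered slice of the source data, uploaded as a single part.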
type objectStorageSourceBlock struct {
section *io.SectionReader
blockNumber *int
}
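
// multiPartUploadImpl uploads the data as an OCI Object Storage multipart upload: it
// splits the data into parts, creates the upload, uploads the parts concurrently, and
// commits them, aborting the upload if any part is missing. It logs through the
// hclog.Logger carried in ctx under the "logger" key.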
func (multipartUploadData MultipartUploadData) multiPartUploadImpl(ctx context.Context) error {
// Pull the logger out of the context, falling back to the default logger if none was provided.
logger, ok := ctx.Value("logger").(hclog.Logger)
if !ok {
logger = hclog.Default()
}
logger = logger.Named("multiPartUpload")
sourceBlocks, err := multipartUploadData.objectMultiPartSplit()
if err != nil {
return fmt.Errorf("error splitting source data: %s", err)
}
multipartUploadRequest := &objectstorage.CreateMultipartUploadRequest{
NamespaceName: common.String(multipartUploadData.client.namespace),
BucketName: common.String(multipartUploadData.client.bucketName),
RequestMetadata: multipartUploadData.RequestMetadata,
CreateMultipartUploadDetails: objectstorage.CreateMultipartUploadDetails{
Object: common.String(multipartUploadData.client.path),
},
}
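// Prefer the customer-managed KMS key; otherwise fall back to customer-provided (SSE-C) encryption if configured.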
if multipartUploadData.client.kmsKeyID != "" {
multipartUploadRequest.OpcSseKmsKeyId = common.String(multipartUploadData.client.kmsKeyID)
} else if multipartUploadData.client.SSECustomerKey != "" && multipartUploadData.client.SSECustomerKeySHA256 != "" {
multipartUploadRequest.OpcSseCustomerKey = common.String(multipartUploadData.client.SSECustomerKey)
multipartUploadRequest.OpcSseCustomerKeySha256 = common.String(multipartUploadData.client.SSECustomerKeySHA256)
multipartUploadRequest.OpcSseCustomerAlgorithm = common.String(multipartUploadData.client.SSECustomerAlgorithm)
}
multipartUploadResponse, err := multipartUploadData.client.objectStorageClient.CreateMultipartUpload(context.Background(), *multipartUploadRequest)
if err != nil {
return fmt.Errorf("error creating multipart upload: %s", err)
}
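
// Fan the source blocks out to a fixed pool of upload workers over buffered channels.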
workerCount := defaultNumberOfGoroutines
osUploadPartResponses := make(chan objectStorageUploadPartResponse, len(sourceBlocks))
sourceBlocksChan := make(chan objectStorageSourceBlock, len(sourceBlocks))
wg := &sync.WaitGroup{}
wg.Add(len(sourceBlocks))
// Push all source blocks into the channel
for _, sourceBlock := range sourceBlocks {
sourceBlocksChan <- sourceBlock
}
close(sourceBlocksChan)
errChan := make(chan error, workerCount)
// Start workers
for i := 0; i < workerCount; i++ {
go func() {
ctx := &objectStorageMultiPartUploadContext{
client: multipartUploadData.client,
wg: wg,
errChan: errChan,
multipartUploadResponse: multipartUploadResponse,
multipartUploadRequest: *multipartUploadRequest,
sourceBlocks: sourceBlocksChan,
osUploadPartResponses: osUploadPartResponses,
logger: logger,
}
ctx.uploadPartsWorker()
}()
}
wg.Wait()
close(osUploadPartResponses)
close(errChan)
// Collect errors from workers
for workerErr := range errChan {
if workerErr != nil {
return workerErr
}
}
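
// Collect the part number and ETag of every uploaded part for the commit request.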
commitMultipartUploadPartDetails := make([]objectstorage.CommitMultipartUploadPartDetails, 0, len(sourceBlocks))
for response := range osUploadPartResponses {
if response.error != nil || response.partNumber == nil || response.response.ETag == nil {
return fmt.Errorf("failed to upload part: %v", response.error)
}
partNumber, etag := *response.partNumber, *response.response.ETag
commitMultipartUploadPartDetails = append(commitMultipartUploadPartDetails, objectstorage.CommitMultipartUploadPartDetails{
PartNum: common.Int(partNumber),
Etag: common.String(etag),
})
}
if len(commitMultipartUploadPartDetails) != len(sourceBlocks) {
abortReq := objectstorage.AbortMultipartUploadRequest{
UploadId: multipartUploadResponse.MultipartUpload.UploadId,
NamespaceName: multipartUploadResponse.Namespace,
BucketName: multipartUploadResponse.Bucket,
ObjectName: multipartUploadResponse.Object,
}
_, abortErr := multipartUploadData.client.objectStorageClient.AbortMultipartUpload(context.Background(), abortReq)
if abortErr != nil {
logger.Error(fmt.Sprintf("Failed to abort multipart upload: %s", abortErr))
}
return fmt.Errorf("not all parts uploaded successfully, multipart upload aborted")
}
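
// All parts are accounted for; commit them so Object Storage assembles the final object.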
commitMultipartUploadRequest := objectstorage.CommitMultipartUploadRequest{
UploadId: multipartUploadResponse.MultipartUpload.UploadId,
NamespaceName: multipartUploadResponse.Namespace,
BucketName: multipartUploadResponse.Bucket,
ObjectName: multipartUploadResponse.Object,
OpcClientRequestId: multipartUploadResponse.OpcClientRequestId,
RequestMetadata: multipartUploadRequest.RequestMetadata,
CommitMultipartUploadDetails: objectstorage.CommitMultipartUploadDetails{
PartsToCommit: commitMultipartUploadPartDetails,
},
}
_, err = multipartUploadData.client.objectStorageClient.CommitMultipartUpload(context.Background(), commitMultipartUploadRequest)
if err != nil {
return fmt.Errorf("failed to commit multipart upload: %s", err)
}
return nil
}
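
// objectMultiPartSplit splits the in-memory data into fixed-size section readers, one per
// part, numbered starting at 1 as the multipart upload API requires.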
func (m MultipartUploadData) objectMultiPartSplit() ([]objectStorageSourceBlock, error) {
dataSize := int64(len(m.Data))
offsets, partSize, err := SplitSizeToOffsetsAndLimits(dataSize)
if err != nil {
return nil, fmt.Errorf("error splitting data into parts: %s", err)
}
sourceBlocks := make([]objectStorageSourceBlock, len(offsets))
for i := range offsets {
start := offsets[i]
end := start + partSize
if end > dataSize {
end = dataSize
}
sourceBlocks[i] = objectStorageSourceBlock{
section: io.NewSectionReader(bytes.NewReader(m.Data), start, end-start),
blockNumber: common.Int(i + 1),
}
}
return sourceBlocks, nil
}

/*
SplitSizeToOffsetsAndLimits splits a size in bytes into chunks of DefaultFilePartSize.
It returns the starting byte offset of each chunk and the chunk size.
It returns an error if the size would require more than MaxCount parts.
*/
func SplitSizeToOffsetsAndLimits(size int64) ([]int64, int64, error) {
partSize := DefaultFilePartSize
totalParts := (size + partSize - 1) / partSize
if totalParts > MaxCount {
return nil, 0, fmt.Errorf("size %d requires more than the maximum of %d parts", size, MaxCount)
}
offsets := make([]int64, totalParts)
for i := range offsets {
offsets[i] = int64(i) * partSize
}
return offsets, partSize, nil
}
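
// uploadPartsWorker consumes source blocks from the shared channel and uploads each one
// as a single part, reporting results on osUploadPartResponses and failures on errChan.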
func (ctx *objectStorageMultiPartUploadContext) uploadPartsWorker() {
// If this worker returns early because of an error, drain the blocks it will no longer
// process and release their WaitGroup slots so wg.Wait() cannot block forever.
defer func() {
for range ctx.sourceBlocks {
ctx.wg.Done()
}
}()
for block := range ctx.sourceBlocks {
buffer := make([]byte, block.section.Size())
// io.ReadFull guards against a short read of the section.
_, err := io.ReadFull(block.section, buffer)
if err != nil {
ctx.errChan <- fmt.Errorf("error reading source block %d: %w", *block.blockNumber, err)
ctx.wg.Done()
return
}
tmpLength := int64(len(buffer))
// The MD5 checksum lets Object Storage verify the integrity of each uploaded part.
sum := md5.Sum(buffer)
uploadPartRequest := &objectstorage.UploadPartRequest{
UploadId: ctx.multipartUploadResponse.UploadId,
ObjectName: ctx.multipartUploadResponse.Object,
NamespaceName: ctx.multipartUploadResponse.Namespace,
BucketName: ctx.multipartUploadResponse.Bucket,
ContentLength: &tmpLength,
UploadPartBody: io.NopCloser(bytes.NewReader(buffer)),
UploadPartNum: block.blockNumber,
ContentMD5: common.String(base64.StdEncoding.EncodeToString(sum[:])),
RequestMetadata: common.RequestMetadata{
RetryPolicy: getDefaultRetryPolicy(),
},
}
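// Use the same server-side encryption settings the multipart upload was created with.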
if ctx.client.kmsKeyID != "" {
uploadPartRequest.OpcSseKmsKeyId = common.String(ctx.client.kmsKeyID)
} else if ctx.client.SSECustomerKey != "" && ctx.client.SSECustomerKeySHA256 != "" {
uploadPartRequest.OpcSseCustomerKey = common.String(ctx.client.SSECustomerKey)
uploadPartRequest.OpcSseCustomerKeySha256 = common.String(ctx.client.SSECustomerKeySHA256)
uploadPartRequest.OpcSseCustomerAlgorithm = common.String(ctx.client.SSECustomerAlgorithm)
}
response, err := ctx.client.objectStorageClient.UploadPart(context.Background(), *uploadPartRequest)
if err != nil {
ctx.errChan <- fmt.Errorf("failed to upload part %d: %w", *block.blockNumber, err)
ctx.wg.Done()
return
}
ctx.osUploadPartResponses <- objectStorageUploadPartResponse{
response: response,
error: nil,
partNumber: block.blockNumber,
}
ctx.wg.Done()
}
}