blob: 0cb8e0af3552effeb6f2ce6fb7d29d62edf4e229 [file] [log] [blame]
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package s3
import (
"bytes"
"context"
"fmt"
"io"
"net/http"
"os"
"path"
"sort"
"strconv"
"strings"
"time"
"github.com/armon/go-metrics"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/hashicorp/go-cleanhttp"
log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/go-secure-stdlib/awsutil"
"github.com/hashicorp/go-secure-stdlib/parseutil"
"github.com/hashicorp/vault/sdk/helper/consts"
"github.com/hashicorp/vault/sdk/physical"
)
// Compile-time assertion that *S3Backend implements physical.Backend.
var _ physical.Backend = (*S3Backend)(nil)

// S3Backend is a physical backend that keeps every entry as an individual
// object inside a single, pre-existing S3 bucket.
type S3Backend struct {
	// Target bucket, and the key prefix joined in front of every entry key.
	bucket, path string
	// When non-empty, Put requests SSE-KMS ("aws:kms") with this key ID.
	kmsKeyId string

	// client issues every S3 request made by this backend.
	client *s3.S3
	logger log.Logger
	// permitPool is acquired around each S3 operation; its size comes from
	// the max_parallel configuration value.
	permitPool *physical.PermitPool
}
// NewS3Backend constructs a S3 backend using a pre-existing
// bucket. Credentials can be provided to the backend, sourced
// from the environment, AWS credential files or by IAM role.
//
// Recognized conf keys: bucket, path, access_key, secret_key,
// session_token, endpoint, region, s3_force_path_style, disable_ssl,
// max_parallel and kms_key_id. For bucket, endpoint and region the
// environment (AWS_S3_BUCKET, AWS_S3_ENDPOINT, and AWS_REGION followed by
// AWS_DEFAULT_REGION) takes precedence over conf; region finally defaults
// to "us-east-1".
//
// All configuration values are parsed and validated before any network
// traffic. The constructor then issues a ListObjects call against the
// bucket so that unusable credentials or an inaccessible bucket fail at
// startup rather than on first use.
func NewS3Backend(conf map[string]string, logger log.Logger) (physical.Backend, error) {
	bucket := os.Getenv("AWS_S3_BUCKET")
	if bucket == "" {
		bucket = conf["bucket"]
		if bucket == "" {
			return nil, fmt.Errorf("'bucket' must be set")
		}
	}

	endpoint := os.Getenv("AWS_S3_ENDPOINT")
	if endpoint == "" {
		endpoint = conf["endpoint"]
	}

	// First non-empty source wins.
	region := os.Getenv("AWS_REGION")
	if region == "" {
		region = os.Getenv("AWS_DEFAULT_REGION")
	}
	if region == "" {
		region = conf["region"]
	}
	if region == "" {
		region = "us-east-1"
	}

	s3ForcePathStyleStr, ok := conf["s3_force_path_style"]
	if !ok {
		s3ForcePathStyleStr = "false"
	}
	s3ForcePathStyleBool, err := parseutil.ParseBool(s3ForcePathStyleStr)
	if err != nil {
		return nil, fmt.Errorf("invalid boolean set for s3_force_path_style: %q", s3ForcePathStyleStr)
	}

	disableSSLStr, ok := conf["disable_ssl"]
	if !ok {
		disableSSLStr = "false"
	}
	disableSSLBool, err := parseutil.ParseBool(disableSSLStr)
	if err != nil {
		return nil, fmt.Errorf("invalid boolean set for disable_ssl: %q", disableSSLStr)
	}

	// Validate max_parallel before touching the network, so a malformed
	// value is reported directly instead of after (or masked by) the
	// bucket probe below.
	var maxParInt int
	if maxParStr, ok := conf["max_parallel"]; ok {
		maxParInt, err = strconv.Atoi(maxParStr)
		if err != nil {
			return nil, fmt.Errorf("failed parsing max_parallel parameter: %w", err)
		}
		if logger.IsDebug() {
			logger.Debug("max_parallel set", "max_parallel", maxParInt)
		}
	}

	// Absent credential keys read as the empty string.
	credsConfig := &awsutil.CredentialsConfig{
		AccessKey:    conf["access_key"],
		SecretKey:    conf["secret_key"],
		SessionToken: conf["session_token"],
		Logger:       logger,
	}
	creds, err := credsConfig.GenerateCredentialChain()
	if err != nil {
		return nil, err
	}

	// NOTE(review): idle connections per host are sized to
	// consts.ExpirationRestoreWorkerCount, presumably so concurrent
	// expiration-restore reads can reuse pooled connections — confirm.
	pooledTransport := cleanhttp.DefaultPooledTransport()
	pooledTransport.MaxIdleConnsPerHost = consts.ExpirationRestoreWorkerCount

	sess, err := session.NewSession(&aws.Config{
		Credentials: creds,
		HTTPClient: &http.Client{
			Transport: pooledTransport,
		},
		Endpoint:         aws.String(endpoint),
		Region:           aws.String(region),
		S3ForcePathStyle: aws.Bool(s3ForcePathStyleBool),
		DisableSSL:       aws.Bool(disableSSLBool),
	})
	if err != nil {
		return nil, err
	}

	s3conn := s3.New(sess)
	// Probe the bucket once so misconfiguration surfaces at construction.
	if _, err := s3conn.ListObjects(&s3.ListObjectsInput{Bucket: &bucket}); err != nil {
		return nil, fmt.Errorf("unable to access bucket %q in region %q: %w", bucket, region, err)
	}

	return &S3Backend{
		client:     s3conn,
		bucket:     bucket,
		path:       conf["path"],
		kmsKeyId:   conf["kms_key_id"],
		logger:     logger,
		permitPool: physical.NewPermitPool(maxParInt),
	}, nil
}
// Put is used to insert or update an entry.
//
// The value is written to the object path.Join(s.path, entry.Key). When a
// KMS key ID is configured, the object is stored with server-side
// encryption "aws:kms" using that key. The request is bound to ctx, so
// cancelling ctx (or hitting its deadline) aborts the upload.
func (s *S3Backend) Put(ctx context.Context, entry *physical.Entry) error {
	defer metrics.MeasureSince([]string{"s3", "put"}, time.Now())

	s.permitPool.Acquire()
	defer s.permitPool.Release()

	input := &s3.PutObjectInput{
		Bucket: aws.String(s.bucket),
		Key:    aws.String(path.Join(s.path, entry.Key)),
		Body:   bytes.NewReader(entry.Value),
	}
	if s.kmsKeyId != "" {
		input.ServerSideEncryption = aws.String("aws:kms")
		input.SSEKMSKeyId = aws.String(s.kmsKeyId)
	}

	// Context-aware call: the ctx parameter was previously ignored.
	_, err := s.client.PutObjectWithContext(ctx, input)
	return err
}
// Get is used to fetch an entry.
//
// It reads the object path.Join(s.path, key) and returns (nil, nil) when S3
// answers with HTTP 404. Any other request failure is returned unchanged.
// The request is bound to ctx, so cancelling ctx aborts the download.
func (s *S3Backend) Get(ctx context.Context, key string) (*physical.Entry, error) {
	defer metrics.MeasureSince([]string{"s3", "get"}, time.Now())

	s.permitPool.Acquire()
	defer s.permitPool.Release()

	resp, err := s.client.GetObjectWithContext(ctx, &s3.GetObjectInput{
		Bucket: aws.String(s.bucket),
		Key:    aws.String(path.Join(s.path, key)),
	})
	if resp != nil && resp.Body != nil {
		defer resp.Body.Close()
	}
	if err != nil {
		// A missing object is "no entry", not an error.
		if reqErr, ok := err.(awserr.RequestFailure); ok && reqErr.StatusCode() == http.StatusNotFound {
			return nil, nil
		}
		return nil, err
	}
	if resp == nil {
		return nil, fmt.Errorf("got nil response from S3 but no error")
	}

	// Pre-size only for a positive length: make() panics on a negative
	// capacity.
	data := bytes.NewBuffer(nil)
	if n := aws.Int64Value(resp.ContentLength); n > 0 {
		data = bytes.NewBuffer(make([]byte, 0, n))
	}
	if _, err := io.Copy(data, resp.Body); err != nil {
		return nil, err
	}

	// Report the key exactly as requested. Deriving it by trimming
	// s.path+"/" from the joined object name fails when the configured
	// path ends in "/", because path.Join collapses the doubled separator
	// and the prefix would leak into Entry.Key.
	return &physical.Entry{
		Key:   key,
		Value: data.Bytes(),
	}, nil
}
// Delete is used to permanently delete an entry.
//
// It removes the object path.Join(s.path, key). The request is bound to
// ctx, so cancelling ctx aborts the call. Any S3 error is returned as is.
func (s *S3Backend) Delete(ctx context.Context, key string) error {
	defer metrics.MeasureSince([]string{"s3", "delete"}, time.Now())

	s.permitPool.Acquire()
	defer s.permitPool.Release()

	// Context-aware call: the ctx parameter was previously ignored.
	_, err := s.client.DeleteObjectWithContext(ctx, &s3.DeleteObjectInput{
		Bucket: aws.String(s.bucket),
		Key:    aws.String(path.Join(s.path, key)),
	})
	return err
}
// List is used to list all the keys under a given
// prefix, up to the next prefix.
//
// Results are relative to prefix. Immediate "sub-folders" (S3 common
// prefixes under the "/" delimiter) appear with their trailing "/", and
// objects directly under prefix appear by name. The slice is sorted and is
// never nil on success. Pagination is bound to ctx, so cancelling ctx
// aborts the listing.
func (s *S3Backend) List(ctx context.Context, prefix string) ([]string, error) {
	defer metrics.MeasureSince([]string{"s3", "list"}, time.Now())

	s.permitPool.Acquire()
	defer s.permitPool.Release()

	// Objects live under s.path. A non-empty prefix must end in "/" so the
	// delimiter groups deeper keys into folders.
	fullPrefix := path.Join(s.path, prefix)
	if fullPrefix != "" && !strings.HasSuffix(fullPrefix, "/") {
		fullPrefix += "/"
	}

	params := &s3.ListObjectsV2Input{
		Bucket:    aws.String(s.bucket),
		Prefix:    aws.String(fullPrefix),
		Delimiter: aws.String("/"),
	}

	keys := []string{}
	addKey := func(name *string) {
		// aws.StringValue maps a nil pointer to "" instead of panicking.
		rel := strings.TrimPrefix(aws.StringValue(name), fullPrefix)
		// An object named exactly fullPrefix (e.g. a folder placeholder
		// created outside Vault), or a nil name, yields "", which is not a
		// usable key.
		if rel != "" {
			keys = append(keys, rel)
		}
	}

	err := s.client.ListObjectsV2PagesWithContext(ctx, params,
		func(page *s3.ListObjectsV2Output, _ bool) bool {
			if page == nil {
				return true
			}
			// Truncated 'folder' paths.
			for _, cp := range page.CommonPrefixes {
				if cp != nil {
					addKey(cp.Prefix)
				}
			}
			// Objects in the current 'folder' only.
			for _, obj := range page.Contents {
				if obj != nil {
					addKey(obj.Key)
				}
			}
			return true
		})
	if err != nil {
		return nil, err
	}

	sort.Strings(keys)
	return keys, nil
}