// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0

package s3

import (
	"bytes"
	"context"
	"fmt"
	"io"
	"net/http"
	"os"
	"path"
	"sort"
	"strconv"
	"strings"
	"time"

	"github.com/armon/go-metrics"
	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/awserr"
	"github.com/aws/aws-sdk-go/aws/session"
	"github.com/aws/aws-sdk-go/service/s3"
	"github.com/hashicorp/go-cleanhttp"
	log "github.com/hashicorp/go-hclog"
	"github.com/hashicorp/go-secure-stdlib/awsutil"
	"github.com/hashicorp/go-secure-stdlib/parseutil"
	"github.com/hashicorp/vault/sdk/helper/consts"
	"github.com/hashicorp/vault/sdk/physical"
)

// Verify S3Backend satisfies the correct interfaces
var _ physical.Backend = (*S3Backend)(nil)

// S3Backend is a physical backend that stores data
// within an S3 bucket.
type S3Backend struct {
	bucket     string
	path       string
	kmsKeyId   string
	client     *s3.S3
	logger     log.Logger
	permitPool *physical.PermitPool
}
// NewS3Backend constructs an S3 backend using a pre-existing
// bucket. Credentials can be provided to the backend directly, or
// sourced from the environment, AWS credential files, or an IAM role.
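//
// A minimal configuration sketch; the bucket name and KMS key alias are
// illustrative placeholders, not defaults:
//
//	conf := map[string]string{
//		"bucket":              "my-vault-data",      // hypothetical bucket name
//		"region":              "us-east-1",          // this is the fallback when unset
//		"s3_force_path_style": "false",
//		"kms_key_id":          "alias/my-vault-key", // optional; enables SSE-KMS on writes
//	}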
func NewS3Backend(conf map[string]string, logger log.Logger) (physical.Backend, error) {
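	// The bucket can come from the AWS_S3_BUCKET environment variable or the
	// "bucket" configuration key; the environment variable takes precedence.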
	bucket := os.Getenv("AWS_S3_BUCKET")
	if bucket == "" {
		bucket = conf["bucket"]
		if bucket == "" {
			return nil, fmt.Errorf("'bucket' must be set")
		}
	}

	path := conf["path"]

	accessKey, ok := conf["access_key"]
	if !ok {
		accessKey = ""
	}
	secretKey, ok := conf["secret_key"]
	if !ok {
		secretKey = ""
	}
	sessionToken, ok := conf["session_token"]
	if !ok {
		sessionToken = ""
	}
	endpoint := os.Getenv("AWS_S3_ENDPOINT")
	if endpoint == "" {
		endpoint = conf["endpoint"]
	}
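	// Resolve the region in order of precedence: AWS_REGION, then
	// AWS_DEFAULT_REGION, then the "region" configuration key, falling back
	// to "us-east-1".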
	region := os.Getenv("AWS_REGION")
	if region == "" {
		region = os.Getenv("AWS_DEFAULT_REGION")
		if region == "" {
			region = conf["region"]
			if region == "" {
				region = "us-east-1"
			}
		}
	}
	s3ForcePathStyleStr, ok := conf["s3_force_path_style"]
	if !ok {
		s3ForcePathStyleStr = "false"
	}
	s3ForcePathStyleBool, err := parseutil.ParseBool(s3ForcePathStyleStr)
	if err != nil {
		return nil, fmt.Errorf("invalid boolean set for s3_force_path_style: %q", s3ForcePathStyleStr)
	}
	disableSSLStr, ok := conf["disable_ssl"]
	if !ok {
		disableSSLStr = "false"
	}
	disableSSLBool, err := parseutil.ParseBool(disableSSLStr)
	if err != nil {
		return nil, fmt.Errorf("invalid boolean set for disable_ssl: %q", disableSSLStr)
	}

	credsConfig := &awsutil.CredentialsConfig{
		AccessKey:    accessKey,
		SecretKey:    secretKey,
		SessionToken: sessionToken,
		Logger:       logger,
	}
	creds, err := credsConfig.GenerateCredentialChain()
	if err != nil {
		return nil, err
	}

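	// Reuse idle HTTP connections, sizing the per-host idle pool to Vault's
	// expiration restore worker count so those concurrent operations can
	// share connections.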
	pooledTransport := cleanhttp.DefaultPooledTransport()
	pooledTransport.MaxIdleConnsPerHost = consts.ExpirationRestoreWorkerCount

	sess, err := session.NewSession(&aws.Config{
		Credentials: creds,
		HTTPClient: &http.Client{
			Transport: pooledTransport,
		},
		Endpoint:         aws.String(endpoint),
		Region:           aws.String(region),
		S3ForcePathStyle: aws.Bool(s3ForcePathStyleBool),
		DisableSSL:       aws.Bool(disableSSLBool),
	})
	if err != nil {
		return nil, err
	}
	s3conn := s3.New(sess)

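	// Verify the bucket is reachable with the resolved credentials and region
	// before returning the backend.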
	_, err = s3conn.ListObjects(&s3.ListObjectsInput{Bucket: &bucket})
	if err != nil {
		return nil, fmt.Errorf("unable to access bucket %q in region %q: %w", bucket, region, err)
	}

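	// max_parallel bounds the number of concurrent S3 operations through the
	// permit pool created below.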
	maxParStr, ok := conf["max_parallel"]
	var maxParInt int
	if ok {
		maxParInt, err = strconv.Atoi(maxParStr)
		if err != nil {
			return nil, fmt.Errorf("failed parsing max_parallel parameter: %w", err)
		}
		if logger.IsDebug() {
			logger.Debug("max_parallel set", "max_parallel", maxParInt)
		}
	}

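	// When kms_key_id is set, Put requests server-side encryption (SSE-KMS)
	// with that key; otherwise no explicit encryption is requested and the
	// bucket's defaults apply.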
	kmsKeyId, ok := conf["kms_key_id"]
	if !ok {
		kmsKeyId = ""
	}

	s := &S3Backend{
		client:     s3conn,
		bucket:     bucket,
		path:       path,
		kmsKeyId:   kmsKeyId,
		logger:     logger,
		permitPool: physical.NewPermitPool(maxParInt),
	}
	return s, nil
}

// Put is used to insert or update an entry
func (s *S3Backend) Put(ctx context.Context, entry *physical.Entry) error {
	defer metrics.MeasureSince([]string{"s3", "put"}, time.Now())

	s.permitPool.Acquire()
	defer s.permitPool.Release()

	// Setup key
	key := path.Join(s.path, entry.Key)

	putObjectInput := &s3.PutObjectInput{
		Bucket: aws.String(s.bucket),
		Key:    aws.String(key),
		Body:   bytes.NewReader(entry.Value),
	}

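	// Request server-side encryption with the configured KMS key, if any.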
	if s.kmsKeyId != "" {
		putObjectInput.ServerSideEncryption = aws.String("aws:kms")
		putObjectInput.SSEKMSKeyId = aws.String(s.kmsKeyId)
	}

	_, err := s.client.PutObject(putObjectInput)
	if err != nil {
		return err
	}

	return nil
}

// Get is used to fetch an entry
func (s *S3Backend) Get(ctx context.Context, key string) (*physical.Entry, error) {
	defer metrics.MeasureSince([]string{"s3", "get"}, time.Now())

	s.permitPool.Acquire()
	defer s.permitPool.Release()

	// Setup key
	key = path.Join(s.path, key)

	resp, err := s.client.GetObject(&s3.GetObjectInput{
		Bucket: aws.String(s.bucket),
		Key:    aws.String(key),
	})
	if resp != nil && resp.Body != nil {
		defer resp.Body.Close()
	}
	if awsErr, ok := err.(awserr.RequestFailure); ok {
		// Return nil on 404s, error on anything else
		if awsErr.StatusCode() == 404 {
			return nil, nil
		}
		return nil, err
	}
	if err != nil {
		return nil, err
	}
	if resp == nil {
		return nil, fmt.Errorf("got nil response from S3 but no error")
	}

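	// Pre-size the buffer from the reported Content-Length, when available,
	// to avoid growing it while copying the body.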
	data := bytes.NewBuffer(nil)
	if resp.ContentLength != nil {
		data = bytes.NewBuffer(make([]byte, 0, *resp.ContentLength))
	}
	_, err = io.Copy(data, resp.Body)
	if err != nil {
		return nil, err
	}

	// Strip path prefix
	if s.path != "" {
		key = strings.TrimPrefix(key, s.path+"/")
	}

	ent := &physical.Entry{
		Key:   key,
		Value: data.Bytes(),
	}

	return ent, nil
}

// Delete is used to permanently delete an entry
func (s *S3Backend) Delete(ctx context.Context, key string) error {
	defer metrics.MeasureSince([]string{"s3", "delete"}, time.Now())

	s.permitPool.Acquire()
	defer s.permitPool.Release()

	// Setup key
	key = path.Join(s.path, key)

	_, err := s.client.DeleteObject(&s3.DeleteObjectInput{
		Bucket: aws.String(s.bucket),
		Key:    aws.String(key),
	})
	if err != nil {
		return err
	}

	return nil
}

// List is used to list all the keys under a given
// prefix, up to the next prefix.
func (s *S3Backend) List(ctx context.Context, prefix string) ([]string, error) {
	defer metrics.MeasureSince([]string{"s3", "list"}, time.Now())

	s.permitPool.Acquire()
	defer s.permitPool.Release()

	// Setup prefix
	prefix = path.Join(s.path, prefix)

	// Ensure a non-empty prefix ends with a "/"
	if prefix != "" && !strings.HasSuffix(prefix, "/") {
		prefix += "/"
	}

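	// Listing with a "/" delimiter makes S3 roll keys below the next "/" up
	// into CommonPrefixes, which become the "folder" entries in the result.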
	params := &s3.ListObjectsV2Input{
		Bucket:    aws.String(s.bucket),
		Prefix:    aws.String(prefix),
		Delimiter: aws.String("/"),
	}

	keys := []string{}

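	// Page through every result set; returning true from the callback tells
	// the SDK to fetch the next page.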
	err := s.client.ListObjectsV2Pages(params,
		func(page *s3.ListObjectsV2Output, lastPage bool) bool {
			if page != nil {
				// Add truncated 'folder' paths
				for _, commonPrefix := range page.CommonPrefixes {
					// Avoid panic
					if commonPrefix == nil {
						continue
					}

					commonPrefix := strings.TrimPrefix(*commonPrefix.Prefix, prefix)
					keys = append(keys, commonPrefix)
				}
				// Add objects only from the current 'folder'
				for _, key := range page.Contents {
					// Avoid panic
					if key == nil {
						continue
					}

					key := strings.TrimPrefix(*key.Key, prefix)
					keys = append(keys, key)
				}
			}
			return true
		})
	if err != nil {
		return nil, err
	}

	sort.Strings(keys)

	return keys, nil
}