| // Copyright (c) HashiCorp, Inc. |
| // SPDX-License-Identifier: MPL-2.0 |
| |
| package dynamodb |
| |
| import ( |
| "context" |
| "fmt" |
| "math/rand" |
| "net/http" |
| "net/url" |
| "os" |
| "testing" |
| "time" |
| |
| "github.com/go-test/deep" |
| log "github.com/hashicorp/go-hclog" |
| "github.com/hashicorp/vault/sdk/helper/docker" |
| "github.com/hashicorp/vault/sdk/helper/logging" |
| "github.com/hashicorp/vault/sdk/physical" |
| |
| "github.com/aws/aws-sdk-go/aws" |
| "github.com/aws/aws-sdk-go/aws/credentials" |
| "github.com/aws/aws-sdk-go/aws/session" |
| "github.com/aws/aws-sdk-go/service/dynamodb" |
| "github.com/aws/aws-sdk-go/service/dynamodb/dynamodbattribute" |
| ) |
| |
| func TestDynamoDBBackend(t *testing.T) { |
| cleanup, svccfg := prepareDynamoDBTestContainer(t) |
| defer cleanup() |
| |
| creds, err := svccfg.Credentials.Get() |
| if err != nil { |
| t.Fatalf("err: %v", err) |
| } |
| |
| region := os.Getenv("AWS_DEFAULT_REGION") |
| if region == "" { |
| region = "us-east-1" |
| } |
| |
| awsSession, err := session.NewSession(&aws.Config{ |
| Credentials: svccfg.Credentials, |
| Endpoint: aws.String(svccfg.URL().String()), |
| Region: aws.String(region), |
| }) |
| if err != nil { |
| t.Fatalf("err: %v", err) |
| } |
| |
| conn := dynamodb.New(awsSession) |
| |
| randInt := rand.New(rand.NewSource(time.Now().UnixNano())).Int() |
| table := fmt.Sprintf("vault-dynamodb-testacc-%d", randInt) |
| |
| defer func() { |
| conn.DeleteTable(&dynamodb.DeleteTableInput{ |
| TableName: aws.String(table), |
| }) |
| }() |
| |
| logger := logging.NewVaultLogger(log.Debug) |
| |
| b, err := NewDynamoDBBackend(map[string]string{ |
| "access_key": creds.AccessKeyID, |
| "secret_key": creds.SecretAccessKey, |
| "session_token": creds.SessionToken, |
| "table": table, |
| "region": region, |
| "endpoint": svccfg.URL().String(), |
| }, logger) |
| if err != nil { |
| t.Fatalf("err: %s", err) |
| } |
| |
| physical.ExerciseBackend(t, b) |
| physical.ExerciseBackend_ListPrefix(t, b) |
| |
| t.Run("Marshalling upgrade", func(t *testing.T) { |
| path := "test_key" |
| |
| // Manually write to DynamoDB using the old ConvertTo function |
| // for marshalling data |
| inputEntry := &physical.Entry{ |
| Key: path, |
| Value: []byte{0x0f, 0xcf, 0x4a, 0x0f, 0xba, 0x2b, 0x15, 0xf0, 0xaa, 0x75, 0x09}, |
| } |
| |
| record := DynamoDBRecord{ |
| Path: recordPathForVaultKey(inputEntry.Key), |
| Key: recordKeyForVaultKey(inputEntry.Key), |
| Value: inputEntry.Value, |
| } |
| |
| item, err := dynamodbattribute.ConvertToMap(record) |
| if err != nil { |
| t.Fatalf("err: %s", err) |
| } |
| |
| request := &dynamodb.PutItemInput{ |
| Item: item, |
| TableName: &table, |
| } |
| conn.PutItem(request) |
| |
| // Read back the data using the normal interface which should |
| // handle the old marshalling format gracefully |
| entry, err := b.Get(context.Background(), path) |
| if err != nil { |
| t.Fatalf("err: %s", err) |
| } |
| if diff := deep.Equal(inputEntry, entry); diff != nil { |
| t.Fatal(diff) |
| } |
| }) |
| } |
| |
| func TestDynamoDBHABackend(t *testing.T) { |
| cleanup, svccfg := prepareDynamoDBTestContainer(t) |
| defer cleanup() |
| |
| creds, err := svccfg.Credentials.Get() |
| if err != nil { |
| t.Fatalf("err: %v", err) |
| } |
| |
| region := os.Getenv("AWS_DEFAULT_REGION") |
| if region == "" { |
| region = "us-east-1" |
| } |
| |
| awsSession, err := session.NewSession(&aws.Config{ |
| Credentials: svccfg.Credentials, |
| Endpoint: aws.String(svccfg.URL().String()), |
| Region: aws.String(region), |
| }) |
| if err != nil { |
| t.Fatalf("err: %v", err) |
| } |
| |
| conn := dynamodb.New(awsSession) |
| |
| randInt := rand.New(rand.NewSource(time.Now().UnixNano())).Int() |
| table := fmt.Sprintf("vault-dynamodb-testacc-%d", randInt) |
| |
| defer func() { |
| conn.DeleteTable(&dynamodb.DeleteTableInput{ |
| TableName: aws.String(table), |
| }) |
| }() |
| |
| logger := logging.NewVaultLogger(log.Debug) |
| config := map[string]string{ |
| "access_key": creds.AccessKeyID, |
| "secret_key": creds.SecretAccessKey, |
| "session_token": creds.SessionToken, |
| "table": table, |
| "region": region, |
| "endpoint": svccfg.URL().String(), |
| } |
| |
| b, err := NewDynamoDBBackend(config, logger) |
| if err != nil { |
| t.Fatalf("err: %s", err) |
| } |
| |
| b2, err := NewDynamoDBBackend(config, logger) |
| if err != nil { |
| t.Fatalf("err: %s", err) |
| } |
| |
| physical.ExerciseHABackend(t, b.(physical.HABackend), b2.(physical.HABackend)) |
| testDynamoDBLockTTL(t, b.(physical.HABackend)) |
| testDynamoDBLockRenewal(t, b.(physical.HABackend)) |
| } |
| |
| // Similar to testHABackend, but using internal implementation details to |
| // trigger the lock failure scenario by setting the lock renew period for one |
| // of the locks to a higher value than the lock TTL. |
| func testDynamoDBLockTTL(t *testing.T, ha physical.HABackend) { |
| // Set much smaller lock times to speed up the test. |
| lockTTL := time.Second * 3 |
| renewInterval := time.Second * 1 |
| watchInterval := time.Second * 1 |
| |
| // Get the lock |
| origLock, err := ha.LockWith("dynamodbttl", "bar") |
| if err != nil { |
| t.Fatalf("err: %v", err) |
| } |
| // set the first lock renew period to double the expected TTL. |
| lock := origLock.(*DynamoDBLock) |
| lock.renewInterval = lockTTL * 2 |
| lock.ttl = lockTTL |
| lock.watchRetryInterval = watchInterval |
| |
| // Attempt to lock |
| leaderCh, err := lock.Lock(nil) |
| if err != nil { |
| t.Fatalf("err: %v", err) |
| } |
| if leaderCh == nil { |
| t.Fatalf("failed to get leader ch") |
| } |
| |
| // Check the value |
| held, val, err := lock.Value() |
| if err != nil { |
| t.Fatalf("err: %v", err) |
| } |
| if !held { |
| t.Fatalf("should be held") |
| } |
| if val != "bar" { |
| t.Fatalf("bad value: %v", err) |
| } |
| |
| // Second acquisition should succeed because the first lock should |
| // not renew within the 3 sec TTL. |
| origLock2, err := ha.LockWith("dynamodbttl", "baz") |
| if err != nil { |
| t.Fatalf("err: %v", err) |
| } |
| |
| lock2 := origLock2.(*DynamoDBLock) |
| lock2.renewInterval = renewInterval |
| lock2.ttl = lockTTL |
| lock2.watchRetryInterval = watchInterval |
| |
| // Cancel attempt eventually so as not to block unit tests forever |
| stopCh := make(chan struct{}) |
| time.AfterFunc(lockTTL*10, func() { |
| close(stopCh) |
| }) |
| |
| // Attempt to lock should work |
| leaderCh2, err := lock2.Lock(stopCh) |
| if err != nil { |
| t.Fatalf("err: %v", err) |
| } |
| if leaderCh2 == nil { |
| t.Fatalf("should get leader ch") |
| } |
| |
| // Check the value |
| held, val, err = lock2.Value() |
| if err != nil { |
| t.Fatalf("err: %v", err) |
| } |
| if !held { |
| t.Fatalf("should be held") |
| } |
| if val != "baz" { |
| t.Fatalf("bad value: %v", err) |
| } |
| |
| // The first lock should have lost the leader channel |
| leaderChClosed := false |
| blocking := make(chan struct{}) |
| // Attempt to read from the leader or the blocking channel, which ever one |
| // happens first. |
| go func() { |
| select { |
| case <-time.After(watchInterval * 3): |
| return |
| case <-leaderCh: |
| leaderChClosed = true |
| close(blocking) |
| case <-blocking: |
| return |
| } |
| }() |
| |
| <-blocking |
| if !leaderChClosed { |
| t.Fatalf("original lock did not have its leader channel closed.") |
| } |
| |
| // Cleanup |
| lock2.Unlock() |
| } |
| |
| // Similar to testHABackend, but using internal implementation details to |
| // trigger a renewal before a "watch" check, which has been a source of |
| // race conditions. |
| func testDynamoDBLockRenewal(t *testing.T, ha physical.HABackend) { |
| renewInterval := time.Second * 1 |
| watchInterval := time.Second * 5 |
| |
| // Get the lock |
| origLock, err := ha.LockWith("dynamodbrenewal", "bar") |
| if err != nil { |
| t.Fatalf("err: %v", err) |
| } |
| |
| // customize the renewal and watch intervals |
| lock := origLock.(*DynamoDBLock) |
| lock.renewInterval = renewInterval |
| lock.watchRetryInterval = watchInterval |
| |
| // Attempt to lock |
| leaderCh, err := lock.Lock(nil) |
| if err != nil { |
| t.Fatalf("err: %v", err) |
| } |
| if leaderCh == nil { |
| t.Fatalf("failed to get leader ch") |
| } |
| |
| // Check the value |
| held, val, err := lock.Value() |
| if err != nil { |
| t.Fatalf("err: %v", err) |
| } |
| if !held { |
| t.Fatalf("should be held") |
| } |
| if val != "bar" { |
| t.Fatalf("bad value: %v", err) |
| } |
| |
| // Release the lock, which will delete the stored item |
| if err := lock.Unlock(); err != nil { |
| t.Fatalf("err: %v", err) |
| } |
| |
| // Wait longer than the renewal time, but less than the watch time |
| time.Sleep(1500 * time.Millisecond) |
| |
| // Attempt to lock with new lock |
| newLock, err := ha.LockWith("dynamodbrenewal", "baz") |
| if err != nil { |
| t.Fatalf("err: %v", err) |
| } |
| |
| // Cancel attempt in 6 sec so as not to block unit tests forever |
| stopCh := make(chan struct{}) |
| time.AfterFunc(6*time.Second, func() { |
| close(stopCh) |
| }) |
| |
| // Attempt to lock should work |
| leaderCh2, err := newLock.Lock(stopCh) |
| if err != nil { |
| t.Fatalf("err: %v", err) |
| } |
| if leaderCh2 == nil { |
| t.Fatalf("should get leader ch") |
| } |
| |
| // Check the value |
| held, val, err = newLock.Value() |
| if err != nil { |
| t.Fatalf("err: %v", err) |
| } |
| if !held { |
| t.Fatalf("should be held") |
| } |
| if val != "baz" { |
| t.Fatalf("bad value: %v", err) |
| } |
| |
| // Cleanup |
| newLock.Unlock() |
| } |
| |
| type Config struct { |
| docker.ServiceURL |
| Credentials *credentials.Credentials |
| } |
| |
| var _ docker.ServiceConfig = &Config{} |
| |
| func prepareDynamoDBTestContainer(t *testing.T) (func(), *Config) { |
| // If environment variable is set, assume caller wants to target a real |
| // DynamoDB. |
| if endpoint := os.Getenv("AWS_DYNAMODB_ENDPOINT"); endpoint != "" { |
| s, err := docker.NewServiceURLParse(endpoint) |
| if err != nil { |
| t.Fatal(err) |
| } |
| return func() {}, &Config{*s, credentials.NewEnvCredentials()} |
| } |
| |
| runner, err := docker.NewServiceRunner(docker.RunOptions{ |
| ImageRepo: "docker.mirror.hashicorp.services/cnadiminti/dynamodb-local", |
| ImageTag: "latest", |
| ContainerName: "dynamodb", |
| Ports: []string{"8000/tcp"}, |
| }) |
| if err != nil { |
| t.Fatalf("Could not start local DynamoDB: %s", err) |
| } |
| |
| svc, err := runner.StartService(context.Background(), connectDynamoDB) |
| if err != nil { |
| t.Fatalf("Could not start local DynamoDB: %s", err) |
| } |
| |
| return svc.Cleanup, svc.Config.(*Config) |
| } |
| |
| func connectDynamoDB(ctx context.Context, host string, port int) (docker.ServiceConfig, error) { |
| u := url.URL{ |
| Scheme: "http", |
| Host: fmt.Sprintf("%s:%d", host, port), |
| } |
| resp, err := http.Get(u.String()) |
| if err != nil { |
| return nil, err |
| } |
| if resp.StatusCode != 400 { |
| return nil, err |
| } |
| |
| return &Config{ |
| ServiceURL: *docker.NewServiceURL(u), |
| Credentials: credentials.NewStaticCredentials("fake", "fake", ""), |
| }, nil |
| } |