blob: 074893fcfb88c3c30f2018e046112f61ab5de582 [file] [log] [blame]
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package server
import (
"context"
"fmt"
"net/url"
"path"
"reflect"
"testing"
"time"
"github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/internalshared/configutil"
"github.com/hashicorp/vault/sdk/helper/docker"
)
// TestTransitWrapper_Lifecycle starts a Vault container with a transit key,
// builds a transit KMS wrapper from the container's address, root token,
// mount path and key name, and checks that encrypting and then decrypting a
// small payload yields the original bytes.
func TestTransitWrapper_Lifecycle(t *testing.T) {
	cleanup, config := prepareTestContainer(t)
	defer cleanup()

	kmsConfig := &configutil.KMS{
		Config: map[string]string{
			"address":    config.URL().String(),
			"token":      config.token,
			"mount_path": config.mountPath,
			"key_name":   config.keyName,
		},
	}
	wrapper, _, err := configutil.GetTransitKMSFunc(kmsConfig)
	if err != nil {
		t.Fatalf("error setting wrapper config: %v", err)
	}

	ctx := context.Background()
	plaintext := []byte("foo")

	// Round-trip the payload through the transit engine.
	blob, err := wrapper.Encrypt(ctx, plaintext, nil)
	if err != nil {
		t.Fatalf("err: %s", err.Error())
	}
	roundTripped, err := wrapper.Decrypt(ctx, blob, nil)
	if err != nil {
		t.Fatalf("err: %s", err.Error())
	}

	if !reflect.DeepEqual(plaintext, roundTripped) {
		t.Fatalf("expected %s, got %s", plaintext, roundTripped)
	}
}
// TestTransitSeal_TokenRenewal hands a transit KMS wrapper a short-period
// token created with the root token. It then sleeps past that period before
// running an encrypt/decrypt round trip. The round trip should only succeed
// if the token was kept alive (renewed) while the test slept.
func TestTransitSeal_TokenRenewal(t *testing.T) {
	cleanup, config := prepareTestContainer(t)
	defer cleanup()

	// Period of the token given to the wrapper. The sleep below is derived
	// from it so the two values cannot drift apart.
	const tokenPeriod = 5 * time.Second

	remoteClient, err := api.NewClient(config.apiConfig())
	if err != nil {
		t.Fatalf("err: %s", err)
	}
	remoteClient.SetToken(config.token)

	req := &api.TokenCreateRequest{
		Period: tokenPeriod.String(), // "5s"
	}
	rsp, err := remoteClient.Auth().Token().Create(req)
	if err != nil {
		t.Fatalf("err: %s", err)
	}
	// Guard the dereference below: fail with a clear message rather than a
	// nil-pointer panic if the response carries no auth block.
	if rsp == nil || rsp.Auth == nil {
		t.Fatal("token create response contained no auth data")
	}

	wrapperConfig := map[string]string{
		"address":    config.URL().String(),
		"token":      rsp.Auth.ClientToken,
		"mount_path": config.mountPath,
		"key_name":   config.keyName,
	}
	kms, _, err := configutil.GetTransitKMSFunc(&configutil.KMS{Config: wrapperConfig})
	if err != nil {
		t.Fatalf("error setting wrapper config: %v", err)
	}

	// Outlive the token period so the calls below exercise renewal.
	time.Sleep(tokenPeriod + 2*time.Second)

	// Test Encrypt and Decrypt calls
	input := []byte("foo")
	swi, err := kms.Encrypt(context.Background(), input, nil)
	if err != nil {
		t.Fatalf("err: %s", err.Error())
	}
	pt, err := kms.Decrypt(context.Background(), swi, nil)
	if err != nil {
		t.Fatalf("err: %s", err.Error())
	}
	if !reflect.DeepEqual(input, pt) {
		t.Fatalf("expected %s, got %s", input, pt)
	}
}
// DockerVaultConfig holds the connection details for a Vault server running
// in a Docker container. It contains:
//   - where to reach it (the embedded docker.ServiceURL);
//   - the token used to authenticate;
//   - the transit mount path and key name created during setup;
//   - the TLS settings applied to API clients built from it.
type DockerVaultConfig struct {
	docker.ServiceURL

	// Root token, transit mount path and transit key name; all three are
	// filled in by prepareTestContainer.
	token, mountPath, keyName string

	tlsConfig *api.TLSConfig
}
// apiConfig returns a Vault API client configuration pointed at the
// container's URL. When c.tlsConfig is set, its TLS settings are applied.
// It panics if those settings cannot be applied. Callers are test setup
// code that cannot proceed without a usable client configuration.
func (c *DockerVaultConfig) apiConfig() *api.Config {
	vaultConfig := api.DefaultConfig()
	vaultConfig.Address = c.URL().String()
	// NOTE(review): ConfigureTLS is not documented to accept a nil
	// *TLSConfig, so skip it when no TLS settings were provided
	// (e.g. a zero-value DockerVaultConfig).
	if c.tlsConfig != nil {
		if err := vaultConfig.ConfigureTLS(c.tlsConfig); err != nil {
			// Keep the underlying error so the failure is diagnosable.
			panic(fmt.Sprintf("unable to configure TLS: %v", err))
		}
	}
	return vaultConfig
}
// Compile-time check that *DockerVaultConfig satisfies docker.ServiceConfig.
// This is the type the StartService callback must return.
var _ docker.ServiceConfig = (*DockerVaultConfig)(nil)
// prepareTestContainer sets up a Vault dev server for a test. It:
//   - generates a random root token, transit mount path and transit key name;
//   - starts a dev-mode Vault server in Docker using that root token;
//   - mounts a transit secrets engine at the random path;
//   - creates a transit key with the random name.
//
// It returns the container's cleanup function and the resulting connection
// details. Any failure aborts the calling test via t.Fatalf.
func prepareTestContainer(t *testing.T) (func(), *DockerVaultConfig) {
	// Report failures at the caller's line, not inside this helper.
	t.Helper()

	// Root token, mount path and key name are all random UUIDs.
	ids := make([]string, 3)
	for i := range ids {
		id, err := uuid.GenerateUUID()
		if err != nil {
			t.Fatalf("err: %s", err)
		}
		ids[i] = id
	}
	rootToken, testMountPath, testKeyName := ids[0], ids[1], ids[2]

	runner, err := docker.NewServiceRunner(docker.RunOptions{
		ContainerName: "vault",
		ImageRepo:     "docker.mirror.hashicorp.services/hashicorp/vault",
		ImageTag:      "latest",
		Cmd: []string{
			"server", "-log-level=trace", "-dev", fmt.Sprintf("-dev-root-token-id=%s", rootToken),
			"-dev-listen-address=0.0.0.0:8200",
		},
		Ports: []string{"8200/tcp"},
	})
	if err != nil {
		t.Fatalf("could not create docker vault runner: %s", err)
	}

	svc, err := runner.StartService(context.Background(), func(ctx context.Context, host string, port int) (docker.ServiceConfig, error) {
		c := &DockerVaultConfig{
			ServiceURL: *docker.NewServiceURL(url.URL{Scheme: "http", Host: fmt.Sprintf("%s:%d", host, port)}),
			tlsConfig: &api.TLSConfig{
				Insecure: true,
			},
			token:     rootToken,
			mountPath: testMountPath,
			keyName:   testKeyName,
		}

		client, err := api.NewClient(c.apiConfig())
		if err != nil {
			return nil, err
		}
		client.SetToken(rootToken)

		if err := setupTestTransitEngine(ctx, client, testMountPath, testKeyName); err != nil {
			return nil, err
		}
		return c, nil
	})
	if err != nil {
		t.Fatalf("could not start docker vault: %s", err)
	}
	return svc.Cleanup, svc.Config.(*DockerVaultConfig)
}

// setupTestTransitEngine mounts a transit secrets engine at mountPath (if it
// is not already mounted) and creates a transit key named keyName.
//
// The mount is skipped when it already exists. This makes the function safe
// to re-run after a partial success.
// NOTE(review): this assumes StartService retries its callback until it
// succeeds — confirm against sdk/helper/docker.
func setupTestTransitEngine(ctx context.Context, client *api.Client, mountPath, keyName string) error {
	mounts, err := client.Sys().ListMountsWithContext(ctx)
	if err != nil {
		return err
	}
	// Mount listings are keyed by path with a trailing slash.
	if _, ok := mounts[mountPath+"/"]; !ok {
		if err := client.Sys().MountWithContext(ctx, mountPath, &api.MountInput{
			Type: "transit",
		}); err != nil {
			return err
		}
	}

	// Create default aesgcm key
	_, err = client.Logical().WriteWithContext(ctx, path.Join(mountPath, "keys", keyName), map[string]interface{}{})
	return err
}