blob: 4b3303988eed515e5b064824fe92c84287c098cf [file] [log] [blame]
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package transit_test
import (
"encoding/hex"
"encoding/json"
"fmt"
"testing"
"time"
uuid "github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/audit"
"github.com/hashicorp/vault/builtin/audit/file"
"github.com/hashicorp/vault/builtin/logical/transit"
vaulthttp "github.com/hashicorp/vault/http"
"github.com/hashicorp/vault/sdk/logical"
"github.com/hashicorp/vault/vault"
)
// TestTransit_Issue_2958 is the regression test named after issue 2958. With
// a file audit device enabled, it creates asymmetric transit keys (ECDSA P-256,
// ECDSA P-384 and Ed25519) and then reads each one back. The test fails if any
// request errors or if a read returns no data.
func TestTransit_Issue_2958(t *testing.T) {
	coreConfig := &vault.CoreConfig{
		LogicalBackends: map[string]logical.Factory{
			"transit": transit.Factory,
		},
		AuditBackends: map[string]audit.Factory{
			"file": file.Factory,
		},
	}
	cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{
		HandlerFunc: vaulthttp.Handler,
	})
	cluster.Start()
	defer cluster.Cleanup()
	cores := cluster.Cores
	vault.TestWaitActive(t, cores[0].Core)
	client := cores[0].Client

	// Route every request/response through a file audit device. The output is
	// discarded to /dev/null; only the audit code path matters here.
	err := client.Sys().EnableAuditWithOptions("file", &api.EnableAuditOptions{
		Type: "file",
		Options: map[string]string{
			"file_path": "/dev/null",
		},
	})
	if err != nil {
		t.Fatal(err)
	}
	err = client.Sys().Mount("transit", &api.MountInput{
		Type: "transit",
	})
	if err != nil {
		t.Fatal(err)
	}

	keys := []struct {
		name    string
		keyType string
	}{
		{"foo", "ecdsa-p256"},
		{"foobar", "ecdsa-p384"},
		{"bar", "ed25519"},
	}

	// Create every key first, then read them all back, matching the original
	// request order.
	for _, k := range keys {
		_, err := client.Logical().Write("transit/keys/"+k.name, map[string]interface{}{
			"type": k.keyType,
		})
		if err != nil {
			t.Fatalf("creating key %q (type %s): %v", k.name, k.keyType, err)
		}
	}
	for _, k := range keys {
		resp, err := client.Logical().Read("transit/keys/" + k.name)
		if err != nil {
			t.Fatalf("reading key %q: %v", k.name, err)
		}
		// Read reports a missing path as (nil, nil), so a nil error alone does
		// not prove the key exists.
		if resp == nil {
			t.Fatalf("reading key %q: expected non-nil response", k.name)
		}
	}
}
// TestTransit_CreateKeyWithAutorotation creates transit keys with various
// auto_rotate_period values. Values that must be rejected have to make the
// create request fail. For accepted values, reading the key back must return
// auto_rotate_period as a whole number of seconds equal to the expected
// duration.
func TestTransit_CreateKeyWithAutorotation(t *testing.T) {
	type autoRotateCase struct {
		period  interface{}   // value sent as auto_rotate_period (nil = omitted)
		wantErr bool          // whether key creation must fail
		want    time.Duration // period expected on read-back when creation succeeds
	}
	cases := map[string]autoRotateCase{
		"default (no value)": {wantErr: false},
		"0 (int)":            {period: 0, wantErr: false, want: 0},
		"0 (string)":         {period: "0", wantErr: false, want: 0},
		"5 seconds":          {period: "5s", wantErr: true},
		"5 hours":            {period: "5h", wantErr: false, want: 5 * time.Hour},
		"negative value":     {period: "-1800s", wantErr: true},
		"invalid string":     {period: "this shouldn't work", wantErr: true},
	}

	coreConfig := &vault.CoreConfig{
		LogicalBackends: map[string]logical.Factory{
			"transit": transit.Factory,
		},
	}
	cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{
		HandlerFunc: vaulthttp.Handler,
	})
	cluster.Start()
	defer cluster.Cleanup()

	active := cluster.Cores[0]
	vault.TestWaitActive(t, active.Core)
	client := active.Client

	if err := client.Sys().Mount("transit", &api.MountInput{Type: "transit"}); err != nil {
		t.Fatal(err)
	}

	for caseName, tc := range cases {
		t.Run(caseName, func(t *testing.T) {
			// A random hex name keeps each subtest's key independent.
			raw, err := uuid.GenerateRandomBytes(16)
			if err != nil {
				t.Fatal(err)
			}
			keyPath := fmt.Sprintf("transit/keys/%s", hex.EncodeToString(raw))

			_, err = client.Logical().Write(keyPath, map[string]interface{}{
				"auto_rotate_period": tc.period,
			})
			if tc.wantErr {
				if err == nil {
					t.Fatal("expected non-nil error")
				}
				return
			}
			if err != nil {
				t.Fatal(err)
			}

			resp, err := client.Logical().Read(keyPath)
			if err != nil {
				t.Fatal(err)
			}
			if resp == nil {
				t.Fatal("expected non-nil response")
			}

			// The API client decodes numbers as json.Number.
			number, ok := resp.Data["auto_rotate_period"].(json.Number)
			if !ok {
				t.Fatal("returned value is of unexpected type")
			}
			seconds, err := number.Int64()
			if err != nil {
				t.Fatal(err)
			}
			if expected := int64(tc.want.Seconds()); seconds != expected {
				t.Fatalf("incorrect auto_rotate_period returned, got: %d, want: %d", seconds, expected)
			}
		})
	}
}