blob: 335607c3b0e1d5739d1022737de79fe9b4e13c36 [file] [log] [blame]
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package transit
import (
"context"
"encoding/hex"
"encoding/json"
"fmt"
"strconv"
"strings"
"testing"
"time"
uuid "github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/api"
vaulthttp "github.com/hashicorp/vault/http"
"github.com/hashicorp/vault/sdk/logical"
"github.com/hashicorp/vault/vault"
)
// TestTransit_ConfigSettings exercises keys/<name>/config on six keys:
// "aes256" (created without an explicit type, so it gets the backend's
// default), "aes128" (aes128-gcm96) and "ed" (ed25519), all three created with
// derived=true, plus the non-derived "p256", "p384" and "p521" ECDSA keys.
// Every key is rotated to version 5 and configured with
// min_decryption_version=2 and min_encryption_version=3. The test then checks
// that encrypt/decrypt, sign/verify and hmac/verify succeed at key versions
// 5, 4 and 3 and are rejected at version 2.
//
// A single *logical.Request is reused throughout and its Data map is mutated
// in place between steps, so the order of the steps below is significant.
func TestTransit_ConfigSettings(t *testing.T) {
	b, storage := createBackendWithSysView(t)

	// doReq sends req and fails the test on a returned error or an error
	// response. The returned response may be nil.
	doReq := func(req *logical.Request) *logical.Response {
		t.Helper()
		resp, err := b.HandleRequest(context.Background(), req)
		if err != nil || (resp != nil && resp.IsError()) {
			t.Fatalf("got err:\n%#v\nresp:\n%#v\nreq:\n%#v\n", err, resp, *req)
		}
		return resp
	}

	// doErrReq sends req and fails the test unless it produced either a
	// returned error or an error response.
	doErrReq := func(req *logical.Request) {
		t.Helper()
		resp, err := b.HandleRequest(context.Background(), req)
		if err == nil && (resp == nil || !resp.IsError()) {
			t.Fatalf("expected error; req:\n%#v\n", *req)
		}
	}

	// stringField returns resp.Data[field] as a string, failing the test
	// (instead of panicking) if the response is nil or the field is missing
	// or not a string.
	stringField := func(resp *logical.Response, field string) string {
		t.Helper()
		if resp == nil {
			t.Fatalf("nil response; wanted field %q", field)
		}
		s, ok := resp.Data[field].(string)
		if !ok {
			t.Fatalf("response field %q missing or not a string: %#v", field, resp.Data[field])
		}
		return s
	}

	// checkVersion fails the test unless the second colon-separated segment
	// of value is "v<ver>".
	checkVersion := func(value string, ver int, what string) {
		t.Helper()
		parts := strings.SplitN(value, ":", 3)
		if len(parts) < 2 || parts[1] != "v"+strconv.Itoa(ver) {
			t.Fatalf("wrong %s version: got %q, want v%d", what, value, ver)
		}
	}

	// requireValid fails the test unless a verify response reports
	// valid=true. Without it, a verify call that succeeded but returned
	// valid=false would go unnoticed.
	requireValid := func(resp *logical.Response, what string) {
		t.Helper()
		if resp == nil {
			t.Fatalf("%s: nil verify response", what)
		}
		if valid, ok := resp.Data["valid"].(bool); !ok || !valid {
			t.Fatalf("%s: expected valid=true, got %#v", what, resp.Data["valid"])
		}
	}

	// Create the keys. "derived" stays set for aes256, aes128 and ed only.
	req := &logical.Request{
		Storage:   storage,
		Operation: logical.UpdateOperation,
		Path:      "keys/aes256",
		Data: map[string]interface{}{
			"derived": true,
		},
	}
	doReq(req)
	req.Path = "keys/aes128"
	req.Data["type"] = "aes128-gcm96"
	doReq(req)
	req.Path = "keys/ed"
	req.Data["type"] = "ed25519"
	doReq(req)
	delete(req.Data, "derived")
	req.Path = "keys/p256"
	req.Data["type"] = "ecdsa-p256"
	doReq(req)
	req.Path = "keys/p384"
	req.Data["type"] = "ecdsa-p384"
	doReq(req)
	req.Path = "keys/p521"
	req.Data["type"] = "ecdsa-p521"
	doReq(req)
	delete(req.Data, "type")

	// Rotate each key four times so every key ends at version 5.
	for _, name := range []string{"aes128", "aes256", "ed", "p256", "p384", "p521"} {
		req.Path = "keys/" + name + "/rotate"
		for i := 0; i < 4; i++ {
			doReq(req)
		}
	}

	req.Path = "keys/aes256/config"
	// min_decryption_version: too high (above the latest version, 5).
	req.Data["min_decryption_version"] = 7
	doErrReq(req)
	// min_decryption_version: too low (negative).
	req.Data["min_decryption_version"] = -1
	doErrReq(req)
	delete(req.Data, "min_decryption_version")
	// min_encryption_version: too high (above the latest version, 5).
	req.Data["min_encryption_version"] = 7
	doErrReq(req)
	// min_encryption_version: too low (negative).
	req.Data["min_encryption_version"] = -1
	doErrReq(req)
	// Not allowed: ciphertexts produced at the minimum encryption version
	// could not be decrypted.
	req.Data["min_decryption_version"] = 3
	req.Data["min_encryption_version"] = 2
	doErrReq(req)
	// Allowed: apply the same settings to every key.
	req.Data["min_decryption_version"] = 2
	req.Data["min_encryption_version"] = 3
	for _, name := range []string{"aes256", "aes128", "ed", "p256", "p384", "p521"} {
		req.Path = "keys/" + name + "/config"
		doReq(req)
	}

	req.Data = map[string]interface{}{
		"plaintext": "abcd",
		"input":     "abcd",
		"context":   "abcd",
	}
	maxKeyVersion := 5
	key := "aes256"

	// applyVersion pins the request to key version ver. When ver is the
	// latest version it clears key_version instead, so the endpoint's default
	// version is exercised.
	applyVersion := func(ver int) {
		if ver == maxKeyVersion {
			delete(req.Data, "key_version")
		} else {
			req.Data["key_version"] = ver
		}
	}

	// testHMAC computes an HMAC of "input" with key at version ver. When
	// valid is false the request must be rejected. Otherwise the HMAC must
	// carry version ver and verify successfully.
	testHMAC := func(ver int, valid bool) {
		t.Helper()
		req.Path = "hmac/" + key
		delete(req.Data, "hmac")
		applyVersion(ver)
		if !valid {
			doErrReq(req)
			return
		}
		mac := stringField(doReq(req), "hmac")
		checkVersion(mac, ver, "hmac")
		req.Path = "verify/" + key
		delete(req.Data, "key_version")
		req.Data["hmac"] = mac
		requireValid(doReq(req), "hmac verify")
	}

	// testEncryptDecrypt encrypts "plaintext" with key at version ver. When
	// valid is false the request must be rejected. Otherwise the ciphertext
	// must carry version ver and decrypt back to the original plaintext.
	testEncryptDecrypt := func(ver int, valid bool) {
		t.Helper()
		req.Path = "encrypt/" + key
		delete(req.Data, "ciphertext")
		applyVersion(ver)
		if !valid {
			doErrReq(req)
			return
		}
		ct := stringField(doReq(req), "ciphertext")
		checkVersion(ct, ver, "encryption")
		req.Path = "decrypt/" + key
		delete(req.Data, "key_version")
		req.Data["ciphertext"] = ct
		pt := stringField(doReq(req), "plaintext")
		if want, _ := req.Data["plaintext"].(string); pt != want {
			t.Fatalf("decrypted plaintext mismatch: got %q, want %q", pt, want)
		}
	}

	// testSignVerify signs "input" with key at version ver. When valid is
	// false the request must be rejected. Otherwise the signature must carry
	// version ver and verify successfully.
	testSignVerify := func(ver int, valid bool) {
		t.Helper()
		req.Path = "sign/" + key
		delete(req.Data, "signature")
		applyVersion(ver)
		if !valid {
			doErrReq(req)
			return
		}
		sig := stringField(doReq(req), "signature")
		checkVersion(sig, ver, "signature")
		req.Path = "verify/" + key
		delete(req.Data, "key_version")
		req.Data["signature"] = sig
		requireValid(doReq(req), "signature verify")
	}

	// runVersions drives fn at versions 5, 4 and 3 (at or above
	// min_encryption_version, so they must succeed) and then at 2 (must be
	// rejected). Ending on the rejected version also matters for state: each
	// driver deletes its result field from req.Data before returning early,
	// so no stale hmac/ciphertext/signature leaks into the next driver.
	runVersions := func(fn func(ver int, valid bool)) {
		t.Helper()
		fn(5, true)
		fn(4, true)
		fn(3, true)
		fn(2, false)
	}

	for _, k := range []string{"aes256", "aes128"} {
		key = k
		runVersions(testEncryptDecrypt)
		runVersions(testHMAC)
	}

	// Signing keys work on "input"; plaintext is no longer needed.
	delete(req.Data, "plaintext")
	req.Data["input"] = "abcd"
	key = "ed"
	runVersions(testSignVerify)
	runVersions(testHMAC)

	// The ECDSA keys were created without derived=true, so drop the context.
	delete(req.Data, "context")
	for _, k := range []string{"p256", "p384", "p521"} {
		key = k
		runVersions(testSignVerify)
		runVersions(testHMAC)
	}
}
// TestTransit_UpdateKeyConfigWithAutorotation checks that auto_rotate_period
// can be updated through transit/keys/<name>/config on a real test cluster.
// The accepted values are 0 (as an int and as a string), 5h, and no value at
// all, which keeps the period given at creation. Each accepted value must be
// read back from transit/keys/<name> as a whole number of seconds. The
// values 5s, -1800s and a non-duration string must be rejected.
func TestTransit_UpdateKeyConfigWithAutorotation(t *testing.T) {
	tests := map[string]struct {
		// initialAutoRotatePeriod is sent when the key is created.
		initialAutoRotatePeriod interface{}
		// newAutoRotatePeriod is sent to the config endpoint. nil exercises
		// the "no value" case.
		newAutoRotatePeriod interface{}
		shouldError         bool
		// expectedValue is the period expected on read-back. It is only
		// checked when shouldError is false.
		expectedValue time.Duration
	}{
		"default (no value)": {
			initialAutoRotatePeriod: "5h",
			shouldError:             false,
			expectedValue:           5 * time.Hour,
		},
		"0 (int)": {
			initialAutoRotatePeriod: "5h",
			newAutoRotatePeriod:     0,
			shouldError:             false,
			expectedValue:           0,
		},
		"0 (string)": {
			initialAutoRotatePeriod: "5h",
			newAutoRotatePeriod:     "0",
			shouldError:             false,
			expectedValue:           0,
		},
		"5 seconds": {
			newAutoRotatePeriod: "5s",
			shouldError:         true,
		},
		"5 hours": {
			newAutoRotatePeriod: "5h",
			shouldError:         false,
			expectedValue:       5 * time.Hour,
		},
		"negative value": {
			newAutoRotatePeriod: "-1800s",
			shouldError:         true,
		},
		"invalid string": {
			newAutoRotatePeriod: "this shouldn't work",
			shouldError:         true,
		},
	}

	coreConfig := &vault.CoreConfig{
		LogicalBackends: map[string]logical.Factory{
			"transit": Factory,
		},
	}
	cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{
		HandlerFunc: vaulthttp.Handler,
	})
	cluster.Start()
	defer cluster.Cleanup()

	cores := cluster.Cores
	vault.TestWaitActive(t, cores[0].Core)
	client := cores[0].Client

	err := client.Sys().Mount("transit", &api.MountInput{
		Type: "transit",
	})
	if err != nil {
		t.Fatal(err)
	}

	for name, test := range tests {
		t.Run(name, func(t *testing.T) {
			// A random key name keeps the subtests independent of each other.
			keyNameBytes, err := uuid.GenerateRandomBytes(16)
			if err != nil {
				t.Fatal(err)
			}
			keyName := hex.EncodeToString(keyNameBytes)

			_, err = client.Logical().Write(fmt.Sprintf("transit/keys/%s", keyName), map[string]interface{}{
				"auto_rotate_period": test.initialAutoRotatePeriod,
			})
			if err != nil {
				t.Fatal(err)
			}

			_, err = client.Logical().Write(fmt.Sprintf("transit/keys/%s/config", keyName), map[string]interface{}{
				"auto_rotate_period": test.newAutoRotatePeriod,
			})
			if test.shouldError {
				if err == nil {
					t.Fatal("expected non-nil error")
				}
				return
			}
			if err != nil {
				t.Fatal(err)
			}

			resp, err := client.Logical().Read(fmt.Sprintf("transit/keys/%s", keyName))
			if err != nil {
				t.Fatal(err)
			}
			if resp == nil {
				t.Fatal("expected non-nil response")
			}
			gotRaw, ok := resp.Data["auto_rotate_period"].(json.Number)
			if !ok {
				t.Fatalf("returned value is of unexpected type %T", resp.Data["auto_rotate_period"])
			}
			got, err := gotRaw.Int64()
			if err != nil {
				t.Fatal(err)
			}
			want := int64(test.expectedValue.Seconds())
			if got != want {
				t.Fatalf("incorrect auto_rotate_period returned, got: %d, want: %d", got, want)
			}
		})
	}
}