| // Copyright (c) HashiCorp, Inc. |
| // SPDX-License-Identifier: MPL-2.0 |
| |
| package transit |
| |
| import ( |
| "context" |
| "encoding/hex" |
| "encoding/json" |
| "fmt" |
| "strconv" |
| "strings" |
| "testing" |
| "time" |
| |
| uuid "github.com/hashicorp/go-uuid" |
| "github.com/hashicorp/vault/api" |
| vaulthttp "github.com/hashicorp/vault/http" |
| "github.com/hashicorp/vault/sdk/logical" |
| "github.com/hashicorp/vault/vault" |
| ) |
| |
| func TestTransit_ConfigSettings(t *testing.T) { |
| b, storage := createBackendWithSysView(t) |
| |
| doReq := func(req *logical.Request) *logical.Response { |
| resp, err := b.HandleRequest(context.Background(), req) |
| if err != nil || (resp != nil && resp.IsError()) { |
| t.Fatalf("got err:\n%#v\nreq:\n%#v\n", err, *req) |
| } |
| return resp |
| } |
| doErrReq := func(req *logical.Request) { |
| resp, err := b.HandleRequest(context.Background(), req) |
| if err == nil { |
| if resp == nil || !resp.IsError() { |
| t.Fatalf("expected error; req:\n%#v\n", *req) |
| } |
| } |
| } |
| |
| // First create a key |
| req := &logical.Request{ |
| Storage: storage, |
| Operation: logical.UpdateOperation, |
| Path: "keys/aes256", |
| Data: map[string]interface{}{ |
| "derived": true, |
| }, |
| } |
| doReq(req) |
| |
| req.Path = "keys/aes128" |
| req.Data["type"] = "aes128-gcm96" |
| doReq(req) |
| |
| req.Path = "keys/ed" |
| req.Data["type"] = "ed25519" |
| doReq(req) |
| |
| delete(req.Data, "derived") |
| |
| req.Path = "keys/p256" |
| req.Data["type"] = "ecdsa-p256" |
| doReq(req) |
| |
| req.Path = "keys/p384" |
| req.Data["type"] = "ecdsa-p384" |
| doReq(req) |
| |
| req.Path = "keys/p521" |
| req.Data["type"] = "ecdsa-p521" |
| doReq(req) |
| |
| delete(req.Data, "type") |
| |
| req.Path = "keys/aes128/rotate" |
| doReq(req) |
| doReq(req) |
| doReq(req) |
| doReq(req) |
| |
| req.Path = "keys/aes256/rotate" |
| doReq(req) |
| doReq(req) |
| doReq(req) |
| doReq(req) |
| |
| req.Path = "keys/ed/rotate" |
| doReq(req) |
| doReq(req) |
| doReq(req) |
| doReq(req) |
| |
| req.Path = "keys/p256/rotate" |
| doReq(req) |
| doReq(req) |
| doReq(req) |
| doReq(req) |
| |
| req.Path = "keys/p384/rotate" |
| doReq(req) |
| doReq(req) |
| doReq(req) |
| doReq(req) |
| |
| req.Path = "keys/p521/rotate" |
| doReq(req) |
| doReq(req) |
| doReq(req) |
| doReq(req) |
| |
| req.Path = "keys/aes256/config" |
| // Too high |
| req.Data["min_decryption_version"] = 7 |
| doErrReq(req) |
| // Too low |
| req.Data["min_decryption_version"] = -1 |
| doErrReq(req) |
| |
| delete(req.Data, "min_decryption_version") |
| // Too high |
| req.Data["min_encryption_version"] = 7 |
| doErrReq(req) |
| // Too low |
| req.Data["min_encryption_version"] = 7 |
| doErrReq(req) |
| |
| // Not allowed, cannot decrypt |
| req.Data["min_decryption_version"] = 3 |
| req.Data["min_encryption_version"] = 2 |
| doErrReq(req) |
| |
| // Allowed |
| req.Data["min_decryption_version"] = 2 |
| req.Data["min_encryption_version"] = 3 |
| doReq(req) |
| req.Path = "keys/aes128/config" |
| doReq(req) |
| req.Path = "keys/ed/config" |
| doReq(req) |
| req.Path = "keys/p256/config" |
| doReq(req) |
| req.Path = "keys/p384/config" |
| doReq(req) |
| |
| req.Path = "keys/p521/config" |
| doReq(req) |
| |
| req.Data = map[string]interface{}{ |
| "plaintext": "abcd", |
| "input": "abcd", |
| "context": "abcd", |
| } |
| |
| maxKeyVersion := 5 |
| key := "aes256" |
| |
| testHMAC := func(ver int, valid bool) { |
| req.Path = "hmac/" + key |
| delete(req.Data, "hmac") |
| if ver == maxKeyVersion { |
| delete(req.Data, "key_version") |
| } else { |
| req.Data["key_version"] = ver |
| } |
| |
| if !valid { |
| doErrReq(req) |
| return |
| } |
| |
| resp := doReq(req) |
| ct := resp.Data["hmac"].(string) |
| if strings.Split(ct, ":")[1] != "v"+strconv.Itoa(ver) { |
| t.Fatal("wrong hmac version") |
| } |
| |
| req.Path = "verify/" + key |
| delete(req.Data, "key_version") |
| req.Data["hmac"] = resp.Data["hmac"] |
| doReq(req) |
| } |
| |
| testEncryptDecrypt := func(ver int, valid bool) { |
| req.Path = "encrypt/" + key |
| delete(req.Data, "ciphertext") |
| if ver == maxKeyVersion { |
| delete(req.Data, "key_version") |
| } else { |
| req.Data["key_version"] = ver |
| } |
| |
| if !valid { |
| doErrReq(req) |
| return |
| } |
| |
| resp := doReq(req) |
| ct := resp.Data["ciphertext"].(string) |
| if strings.Split(ct, ":")[1] != "v"+strconv.Itoa(ver) { |
| t.Fatal("wrong encryption version") |
| } |
| |
| req.Path = "decrypt/" + key |
| delete(req.Data, "key_version") |
| req.Data["ciphertext"] = resp.Data["ciphertext"] |
| doReq(req) |
| } |
| testEncryptDecrypt(5, true) |
| testEncryptDecrypt(4, true) |
| testEncryptDecrypt(3, true) |
| testEncryptDecrypt(2, false) |
| testHMAC(5, true) |
| testHMAC(4, true) |
| testHMAC(3, true) |
| testHMAC(2, false) |
| |
| key = "aes128" |
| testEncryptDecrypt(5, true) |
| testEncryptDecrypt(4, true) |
| testEncryptDecrypt(3, true) |
| testEncryptDecrypt(2, false) |
| testHMAC(5, true) |
| testHMAC(4, true) |
| testHMAC(3, true) |
| testHMAC(2, false) |
| |
| delete(req.Data, "plaintext") |
| req.Data["input"] = "abcd" |
| key = "ed" |
| testSignVerify := func(ver int, valid bool) { |
| req.Path = "sign/" + key |
| delete(req.Data, "signature") |
| if ver == maxKeyVersion { |
| delete(req.Data, "key_version") |
| } else { |
| req.Data["key_version"] = ver |
| } |
| |
| if !valid { |
| doErrReq(req) |
| return |
| } |
| |
| resp := doReq(req) |
| ct := resp.Data["signature"].(string) |
| if strings.Split(ct, ":")[1] != "v"+strconv.Itoa(ver) { |
| t.Fatal("wrong signature version") |
| } |
| |
| req.Path = "verify/" + key |
| delete(req.Data, "key_version") |
| req.Data["signature"] = resp.Data["signature"] |
| doReq(req) |
| } |
| testSignVerify(5, true) |
| testSignVerify(4, true) |
| testSignVerify(3, true) |
| testSignVerify(2, false) |
| testHMAC(5, true) |
| testHMAC(4, true) |
| testHMAC(3, true) |
| testHMAC(2, false) |
| |
| delete(req.Data, "context") |
| key = "p256" |
| testSignVerify(5, true) |
| testSignVerify(4, true) |
| testSignVerify(3, true) |
| testSignVerify(2, false) |
| testHMAC(5, true) |
| testHMAC(4, true) |
| testHMAC(3, true) |
| testHMAC(2, false) |
| |
| key = "p384" |
| testSignVerify(5, true) |
| testSignVerify(4, true) |
| testSignVerify(3, true) |
| testSignVerify(2, false) |
| testHMAC(5, true) |
| testHMAC(4, true) |
| testHMAC(3, true) |
| testHMAC(2, false) |
| |
| key = "p521" |
| testSignVerify(5, true) |
| testSignVerify(4, true) |
| testSignVerify(3, true) |
| testSignVerify(2, false) |
| testHMAC(5, true) |
| testHMAC(4, true) |
| testHMAC(3, true) |
| testHMAC(2, false) |
| } |
| |
| func TestTransit_UpdateKeyConfigWithAutorotation(t *testing.T) { |
| tests := map[string]struct { |
| initialAutoRotatePeriod interface{} |
| newAutoRotatePeriod interface{} |
| shouldError bool |
| expectedValue time.Duration |
| }{ |
| "default (no value)": { |
| initialAutoRotatePeriod: "5h", |
| shouldError: false, |
| expectedValue: 5 * time.Hour, |
| }, |
| "0 (int)": { |
| initialAutoRotatePeriod: "5h", |
| newAutoRotatePeriod: 0, |
| shouldError: false, |
| expectedValue: 0, |
| }, |
| "0 (string)": { |
| initialAutoRotatePeriod: "5h", |
| newAutoRotatePeriod: 0, |
| shouldError: false, |
| expectedValue: 0, |
| }, |
| "5 seconds": { |
| newAutoRotatePeriod: "5s", |
| shouldError: true, |
| }, |
| "5 hours": { |
| newAutoRotatePeriod: "5h", |
| shouldError: false, |
| expectedValue: 5 * time.Hour, |
| }, |
| "negative value": { |
| newAutoRotatePeriod: "-1800s", |
| shouldError: true, |
| }, |
| "invalid string": { |
| newAutoRotatePeriod: "this shouldn't work", |
| shouldError: true, |
| }, |
| } |
| |
| coreConfig := &vault.CoreConfig{ |
| LogicalBackends: map[string]logical.Factory{ |
| "transit": Factory, |
| }, |
| } |
| cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{ |
| HandlerFunc: vaulthttp.Handler, |
| }) |
| cluster.Start() |
| defer cluster.Cleanup() |
| cores := cluster.Cores |
| vault.TestWaitActive(t, cores[0].Core) |
| client := cores[0].Client |
| err := client.Sys().Mount("transit", &api.MountInput{ |
| Type: "transit", |
| }) |
| if err != nil { |
| t.Fatal(err) |
| } |
| |
| for name, test := range tests { |
| t.Run(name, func(t *testing.T) { |
| keyNameBytes, err := uuid.GenerateRandomBytes(16) |
| if err != nil { |
| t.Fatal(err) |
| } |
| keyName := hex.EncodeToString(keyNameBytes) |
| |
| _, err = client.Logical().Write(fmt.Sprintf("transit/keys/%s", keyName), map[string]interface{}{ |
| "auto_rotate_period": test.initialAutoRotatePeriod, |
| }) |
| if err != nil { |
| t.Fatal(err) |
| } |
| resp, err := client.Logical().Write(fmt.Sprintf("transit/keys/%s/config", keyName), map[string]interface{}{ |
| "auto_rotate_period": test.newAutoRotatePeriod, |
| }) |
| switch { |
| case test.shouldError && err == nil: |
| t.Fatal("expected non-nil error") |
| case !test.shouldError && err != nil: |
| t.Fatal(err) |
| } |
| |
| if !test.shouldError { |
| resp, err = client.Logical().Read(fmt.Sprintf("transit/keys/%s", keyName)) |
| if err != nil { |
| t.Fatal(err) |
| } |
| if resp == nil { |
| t.Fatal("expected non-nil response") |
| } |
| gotRaw, ok := resp.Data["auto_rotate_period"].(json.Number) |
| if !ok { |
| t.Fatal("returned value is of unexpected type") |
| } |
| got, err := gotRaw.Int64() |
| if err != nil { |
| t.Fatal(err) |
| } |
| want := int64(test.expectedValue.Seconds()) |
| if got != want { |
| t.Fatalf("incorrect auto_rotate_period returned, got: %d, want: %d", got, want) |
| } |
| } |
| }) |
| } |
| } |