| // Copyright (c) HashiCorp, Inc. |
| // SPDX-License-Identifier: MPL-2.0 |
| |
| package keysutil |
| |
| import ( |
| "bytes" |
| "context" |
| "crypto/ecdsa" |
| "crypto/elliptic" |
| "crypto/rand" |
| "crypto/rsa" |
| "crypto/x509" |
| "errors" |
| "fmt" |
| mathrand "math/rand" |
| "reflect" |
| "strconv" |
| "strings" |
| "sync" |
| "testing" |
| "time" |
| |
| "golang.org/x/crypto/ed25519" |
| |
| "github.com/hashicorp/vault/sdk/helper/errutil" |
| "github.com/hashicorp/vault/sdk/helper/jsonutil" |
| "github.com/hashicorp/vault/sdk/logical" |
| "github.com/mitchellh/copystructure" |
| ) |
| |
| func TestPolicy_KeyEntryMapUpgrade(t *testing.T) { |
| now := time.Now() |
| old := map[int]KeyEntry{ |
| 1: { |
| Key: []byte("samplekey"), |
| HMACKey: []byte("samplehmackey"), |
| CreationTime: now, |
| FormattedPublicKey: "sampleformattedpublickey", |
| }, |
| 2: { |
| Key: []byte("samplekey2"), |
| HMACKey: []byte("samplehmackey2"), |
| CreationTime: now.Add(10 * time.Second), |
| FormattedPublicKey: "sampleformattedpublickey2", |
| }, |
| } |
| |
| oldEncoded, err := jsonutil.EncodeJSON(old) |
| if err != nil { |
| t.Fatal(err) |
| } |
| |
| var new keyEntryMap |
| err = jsonutil.DecodeJSON(oldEncoded, &new) |
| if err != nil { |
| t.Fatal(err) |
| } |
| |
| newEncoded, err := jsonutil.EncodeJSON(&new) |
| if err != nil { |
| t.Fatal(err) |
| } |
| |
| if string(oldEncoded) != string(newEncoded) { |
| t.Fatalf("failed to upgrade key entry map;\nold: %q\nnew: %q", string(oldEncoded), string(newEncoded)) |
| } |
| } |
| |
| func Test_KeyUpgrade(t *testing.T) { |
| lockManagerWithCache, _ := NewLockManager(true, 0) |
| lockManagerWithoutCache, _ := NewLockManager(false, 0) |
| testKeyUpgradeCommon(t, lockManagerWithCache) |
| testKeyUpgradeCommon(t, lockManagerWithoutCache) |
| } |
| |
| func testKeyUpgradeCommon(t *testing.T, lm *LockManager) { |
| ctx := context.Background() |
| |
| storage := &logical.InmemStorage{} |
| p, upserted, err := lm.GetPolicy(ctx, PolicyRequest{ |
| Upsert: true, |
| Storage: storage, |
| KeyType: KeyType_AES256_GCM96, |
| Name: "test", |
| }, rand.Reader) |
| if err != nil { |
| t.Fatal(err) |
| } |
| if p == nil { |
| t.Fatal("nil policy") |
| } |
| if !upserted { |
| t.Fatal("expected an upsert") |
| } |
| if !lm.useCache { |
| p.Unlock() |
| } |
| |
| testBytes := make([]byte, len(p.Keys["1"].Key)) |
| copy(testBytes, p.Keys["1"].Key) |
| |
| p.Key = p.Keys["1"].Key |
| p.Keys = nil |
| p.MigrateKeyToKeysMap() |
| if p.Key != nil { |
| t.Fatal("policy.Key is not nil") |
| } |
| if len(p.Keys) != 1 { |
| t.Fatal("policy.Keys is the wrong size") |
| } |
| if !reflect.DeepEqual(testBytes, p.Keys["1"].Key) { |
| t.Fatal("key mismatch") |
| } |
| } |
| |
| func Test_ArchivingUpgrade(t *testing.T) { |
| lockManagerWithCache, _ := NewLockManager(true, 0) |
| lockManagerWithoutCache, _ := NewLockManager(false, 0) |
| testArchivingUpgradeCommon(t, lockManagerWithCache) |
| testArchivingUpgradeCommon(t, lockManagerWithoutCache) |
| } |
| |
| func testArchivingUpgradeCommon(t *testing.T, lm *LockManager) { |
| ctx := context.Background() |
| |
| // First, we generate a policy and rotate it a number of times. Each time |
| // we'll ensure that we have the expected number of keys in the archive and |
| // the main keys object, which without changing the min version should be |
| // zero and latest, respectively |
| |
| storage := &logical.InmemStorage{} |
| p, _, err := lm.GetPolicy(ctx, PolicyRequest{ |
| Upsert: true, |
| Storage: storage, |
| KeyType: KeyType_AES256_GCM96, |
| Name: "test", |
| }, rand.Reader) |
| if err != nil { |
| t.Fatal(err) |
| } |
| if p == nil { |
| t.Fatal("nil policy") |
| } |
| if !lm.useCache { |
| p.Unlock() |
| } |
| |
| // Store the initial key in the archive |
| keysArchive := []KeyEntry{{}, p.Keys["1"]} |
| checkKeys(t, ctx, p, storage, keysArchive, "initial", 1, 1, 1) |
| |
| for i := 2; i <= 10; i++ { |
| err = p.Rotate(ctx, storage, rand.Reader) |
| if err != nil { |
| t.Fatal(err) |
| } |
| keysArchive = append(keysArchive, p.Keys[strconv.Itoa(i)]) |
| checkKeys(t, ctx, p, storage, keysArchive, "rotate", i, i, i) |
| } |
| |
| // Now, wipe the archive and set the archive version to zero |
| err = storage.Delete(ctx, "archive/test") |
| if err != nil { |
| t.Fatal(err) |
| } |
| p.ArchiveVersion = 0 |
| |
| // Store it, but without calling persist, so we don't trigger |
| // handleArchiving() |
| buf, err := p.Serialize() |
| if err != nil { |
| t.Fatal(err) |
| } |
| |
| // Write the policy into storage |
| err = storage.Put(ctx, &logical.StorageEntry{ |
| Key: "policy/" + p.Name, |
| Value: buf, |
| }) |
| if err != nil { |
| t.Fatal(err) |
| } |
| |
| // If we're caching, expire from the cache since we modified it |
| // under-the-hood |
| if lm.useCache { |
| lm.cache.Delete("test") |
| } |
| |
| // Now get the policy again; the upgrade should happen automatically |
| p, _, err = lm.GetPolicy(ctx, PolicyRequest{ |
| Storage: storage, |
| Name: "test", |
| }, rand.Reader) |
| if err != nil { |
| t.Fatal(err) |
| } |
| if p == nil { |
| t.Fatal("nil policy") |
| } |
| if !lm.useCache { |
| p.Unlock() |
| } |
| |
| checkKeys(t, ctx, p, storage, keysArchive, "upgrade", 10, 10, 10) |
| |
| // Let's check some deletion logic while we're at it |
| |
| // The policy should be in there |
| if lm.useCache { |
| _, ok := lm.cache.Load("test") |
| if !ok { |
| t.Fatal("nil policy in cache") |
| } |
| } |
| |
| // First we'll do this wrong, by not setting the deletion flag |
| err = lm.DeletePolicy(ctx, storage, "test") |
| if err == nil { |
| t.Fatal("got nil error, but should not have been able to delete since we didn't set the deletion flag on the policy") |
| } |
| |
| // The policy should still be in there |
| if lm.useCache { |
| _, ok := lm.cache.Load("test") |
| if !ok { |
| t.Fatal("nil policy in cache") |
| } |
| } |
| |
| p, _, err = lm.GetPolicy(ctx, PolicyRequest{ |
| Storage: storage, |
| Name: "test", |
| }, rand.Reader) |
| if err != nil { |
| t.Fatal(err) |
| } |
| if p == nil { |
| t.Fatal("policy nil after bad delete") |
| } |
| if !lm.useCache { |
| p.Unlock() |
| } |
| |
| // Now do it properly |
| p.DeletionAllowed = true |
| err = p.Persist(ctx, storage) |
| if err != nil { |
| t.Fatal(err) |
| } |
| err = lm.DeletePolicy(ctx, storage, "test") |
| if err != nil { |
| t.Fatal(err) |
| } |
| |
| // The policy should *not* be in there |
| if lm.useCache { |
| _, ok := lm.cache.Load("test") |
| if ok { |
| t.Fatal("non-nil policy in cache") |
| } |
| } |
| |
| p, _, err = lm.GetPolicy(ctx, PolicyRequest{ |
| Storage: storage, |
| Name: "test", |
| }, rand.Reader) |
| if err != nil { |
| t.Fatal(err) |
| } |
| if p != nil { |
| t.Fatal("policy not nil after delete") |
| } |
| } |
| |
| func Test_Archiving(t *testing.T) { |
| lockManagerWithCache, _ := NewLockManager(true, 0) |
| lockManagerWithoutCache, _ := NewLockManager(false, 0) |
| testArchivingUpgradeCommon(t, lockManagerWithCache) |
| testArchivingUpgradeCommon(t, lockManagerWithoutCache) |
| } |
| |
| func testArchivingCommon(t *testing.T, lm *LockManager) { |
| ctx := context.Background() |
| |
| // First, we generate a policy and rotate it a number of times. Each time |
| // we'll ensure that we have the expected number of keys in the archive and |
| // the main keys object, which without changing the min version should be |
| // zero and latest, respectively |
| |
| storage := &logical.InmemStorage{} |
| p, _, err := lm.GetPolicy(ctx, PolicyRequest{ |
| Upsert: true, |
| Storage: storage, |
| KeyType: KeyType_AES256_GCM96, |
| Name: "test", |
| }, rand.Reader) |
| if err != nil { |
| t.Fatal(err) |
| } |
| if p == nil { |
| t.Fatal("nil policy") |
| } |
| if !lm.useCache { |
| p.Unlock() |
| } |
| |
| // Store the initial key in the archive |
| keysArchive := []KeyEntry{{}, p.Keys["1"]} |
| checkKeys(t, ctx, p, storage, keysArchive, "initial", 1, 1, 1) |
| |
| for i := 2; i <= 10; i++ { |
| err = p.Rotate(ctx, storage, rand.Reader) |
| if err != nil { |
| t.Fatal(err) |
| } |
| keysArchive = append(keysArchive, p.Keys[strconv.Itoa(i)]) |
| checkKeys(t, ctx, p, storage, keysArchive, "rotate", i, i, i) |
| } |
| |
| // Move the min decryption version up |
| for i := 1; i <= 10; i++ { |
| p.MinDecryptionVersion = i |
| |
| err = p.Persist(ctx, storage) |
| if err != nil { |
| t.Fatal(err) |
| } |
| // We expect to find: |
| // * The keys in archive are the same as the latest version |
| // * The latest version is constant |
| // * The number of keys in the policy itself is from the min |
| // decryption version up to the latest version, so for e.g. 7 and |
| // 10, you'd need 7, 8, 9, and 10 -- IOW, latest version - min |
| // decryption version plus 1 (the min decryption version key |
| // itself) |
| checkKeys(t, ctx, p, storage, keysArchive, "minadd", 10, 10, p.LatestVersion-p.MinDecryptionVersion+1) |
| } |
| |
| // Move the min decryption version down |
| for i := 10; i >= 1; i-- { |
| p.MinDecryptionVersion = i |
| |
| err = p.Persist(ctx, storage) |
| if err != nil { |
| t.Fatal(err) |
| } |
| // We expect to find: |
| // * The keys in archive are never removed so same as the latest version |
| // * The latest version is constant |
| // * The number of keys in the policy itself is from the min |
| // decryption version up to the latest version, so for e.g. 7 and |
| // 10, you'd need 7, 8, 9, and 10 -- IOW, latest version - min |
| // decryption version plus 1 (the min decryption version key |
| // itself) |
| checkKeys(t, ctx, p, storage, keysArchive, "minsub", 10, 10, p.LatestVersion-p.MinDecryptionVersion+1) |
| } |
| } |
| |
| func checkKeys(t *testing.T, |
| ctx context.Context, |
| p *Policy, |
| storage logical.Storage, |
| keysArchive []KeyEntry, |
| action string, |
| archiveVer, latestVer, keysSize int, |
| ) { |
| // Sanity check |
| if len(keysArchive) != latestVer+1 { |
| t.Fatalf("latest expected key version is %d, expected test keys archive size is %d, "+ |
| "but keys archive is of size %d", latestVer, latestVer+1, len(keysArchive)) |
| } |
| |
| archive, err := p.LoadArchive(ctx, storage) |
| if err != nil { |
| t.Fatal(err) |
| } |
| |
| badArchiveVer := false |
| if archiveVer == 0 { |
| if len(archive.Keys) != 0 || p.ArchiveVersion != 0 { |
| badArchiveVer = true |
| } |
| } else { |
| // We need to subtract one because we have the indexes match key |
| // versions, which start at 1. So for an archive version of 1, we |
| // actually have two entries -- a blank 0 entry, and the key at spot 1 |
| if archiveVer != len(archive.Keys)-1 || archiveVer != p.ArchiveVersion { |
| badArchiveVer = true |
| } |
| } |
| if badArchiveVer { |
| t.Fatalf( |
| "expected archive version %d, found length of archive keys %d and policy archive version %d", |
| archiveVer, len(archive.Keys), p.ArchiveVersion, |
| ) |
| } |
| |
| if latestVer != p.LatestVersion { |
| t.Fatalf( |
| "expected latest version %d, found %d", |
| latestVer, p.LatestVersion, |
| ) |
| } |
| |
| if keysSize != len(p.Keys) { |
| t.Fatalf( |
| "expected keys size %d, found %d, action is %s, policy is \n%#v\n", |
| keysSize, len(p.Keys), action, p, |
| ) |
| } |
| |
| for i := p.MinDecryptionVersion; i <= p.LatestVersion; i++ { |
| if _, ok := p.Keys[strconv.Itoa(i)]; !ok { |
| t.Fatalf( |
| "expected key %d, did not find it in policy keys", i, |
| ) |
| } |
| } |
| |
| for i := p.MinDecryptionVersion; i <= p.LatestVersion; i++ { |
| ver := strconv.Itoa(i) |
| if !p.Keys[ver].CreationTime.Equal(keysArchive[i].CreationTime) { |
| t.Fatalf("key %d not equivalent between policy keys and test keys archive; policy keys:\n%#v\ntest keys archive:\n%#v\n", i, p.Keys[ver], keysArchive[i]) |
| } |
| polKey := p.Keys[ver] |
| polKey.CreationTime = keysArchive[i].CreationTime |
| p.Keys[ver] = polKey |
| if !reflect.DeepEqual(p.Keys[ver], keysArchive[i]) { |
| t.Fatalf("key %d not equivalent between policy keys and test keys archive; policy keys:\n%#v\ntest keys archive:\n%#v\n", i, p.Keys[ver], keysArchive[i]) |
| } |
| } |
| |
| for i := 1; i < len(archive.Keys); i++ { |
| if !reflect.DeepEqual(archive.Keys[i].Key, keysArchive[i].Key) { |
| t.Fatalf("key %d not equivalent between policy archive and test keys archive; policy archive:\n%#v\ntest keys archive:\n%#v\n", i, archive.Keys[i].Key, keysArchive[i].Key) |
| } |
| } |
| } |
| |
| func Test_StorageErrorSafety(t *testing.T) { |
| ctx := context.Background() |
| lm, _ := NewLockManager(true, 0) |
| |
| storage := &logical.InmemStorage{} |
| p, _, err := lm.GetPolicy(ctx, PolicyRequest{ |
| Upsert: true, |
| Storage: storage, |
| KeyType: KeyType_AES256_GCM96, |
| Name: "test", |
| }, rand.Reader) |
| if err != nil { |
| t.Fatal(err) |
| } |
| if p == nil { |
| t.Fatal("nil policy") |
| } |
| |
| // Store the initial key in the archive |
| keysArchive := []KeyEntry{{}, p.Keys["1"]} |
| checkKeys(t, ctx, p, storage, keysArchive, "initial", 1, 1, 1) |
| |
| // We use checkKeys here just for sanity; it doesn't really handle cases of |
| // errors below so we do more targeted testing later |
| for i := 2; i <= 5; i++ { |
| err = p.Rotate(ctx, storage, rand.Reader) |
| if err != nil { |
| t.Fatal(err) |
| } |
| keysArchive = append(keysArchive, p.Keys[strconv.Itoa(i)]) |
| checkKeys(t, ctx, p, storage, keysArchive, "rotate", i, i, i) |
| } |
| |
| underlying := storage.Underlying() |
| underlying.FailPut(true) |
| |
| priorLen := len(p.Keys) |
| |
| err = p.Rotate(ctx, storage, rand.Reader) |
| if err == nil { |
| t.Fatal("expected error") |
| } |
| |
| if len(p.Keys) != priorLen { |
| t.Fatal("length of keys should not have changed") |
| } |
| } |
| |
| func Test_BadUpgrade(t *testing.T) { |
| ctx := context.Background() |
| lm, _ := NewLockManager(true, 0) |
| storage := &logical.InmemStorage{} |
| p, _, err := lm.GetPolicy(ctx, PolicyRequest{ |
| Upsert: true, |
| Storage: storage, |
| KeyType: KeyType_AES256_GCM96, |
| Name: "test", |
| }, rand.Reader) |
| if err != nil { |
| t.Fatal(err) |
| } |
| if p == nil { |
| t.Fatal("nil policy") |
| } |
| |
| orig, err := copystructure.Copy(p) |
| if err != nil { |
| t.Fatal(err) |
| } |
| orig.(*Policy).l = p.l |
| |
| p.Key = p.Keys["1"].Key |
| p.Keys = nil |
| p.MinDecryptionVersion = 0 |
| |
| if err := p.Upgrade(ctx, storage, rand.Reader); err != nil { |
| t.Fatal(err) |
| } |
| |
| k := p.Keys["1"] |
| o := orig.(*Policy).Keys["1"] |
| k.CreationTime = o.CreationTime |
| k.HMACKey = o.HMACKey |
| p.Keys["1"] = k |
| p.versionPrefixCache = sync.Map{} |
| |
| if !reflect.DeepEqual(orig, p) { |
| t.Fatalf("not equal:\n%#v\n%#v", orig, p) |
| } |
| |
| // Do it again with a failing storage call |
| underlying := storage.Underlying() |
| underlying.FailPut(true) |
| |
| p.Key = p.Keys["1"].Key |
| p.Keys = nil |
| p.MinDecryptionVersion = 0 |
| |
| if err := p.Upgrade(ctx, storage, rand.Reader); err == nil { |
| t.Fatal("expected error") |
| } |
| |
| if p.MinDecryptionVersion == 1 { |
| t.Fatal("min decryption version was changed") |
| } |
| if p.Keys != nil { |
| t.Fatal("found upgraded keys") |
| } |
| if p.Key == nil { |
| t.Fatal("non-upgraded key not found") |
| } |
| } |
| |
| func Test_BadArchive(t *testing.T) { |
| ctx := context.Background() |
| lm, _ := NewLockManager(true, 0) |
| storage := &logical.InmemStorage{} |
| p, _, err := lm.GetPolicy(ctx, PolicyRequest{ |
| Upsert: true, |
| Storage: storage, |
| KeyType: KeyType_AES256_GCM96, |
| Name: "test", |
| }, rand.Reader) |
| if err != nil { |
| t.Fatal(err) |
| } |
| if p == nil { |
| t.Fatal("nil policy") |
| } |
| |
| for i := 2; i <= 10; i++ { |
| err = p.Rotate(ctx, storage, rand.Reader) |
| if err != nil { |
| t.Fatal(err) |
| } |
| } |
| |
| p.MinDecryptionVersion = 5 |
| if err := p.Persist(ctx, storage); err != nil { |
| t.Fatal(err) |
| } |
| if p.ArchiveVersion != 10 { |
| t.Fatalf("unexpected archive version %d", p.ArchiveVersion) |
| } |
| if len(p.Keys) != 6 { |
| t.Fatalf("unexpected key length %d", len(p.Keys)) |
| } |
| |
| // Set back |
| p.MinDecryptionVersion = 1 |
| if err := p.Persist(ctx, storage); err != nil { |
| t.Fatal(err) |
| } |
| if p.ArchiveVersion != 10 { |
| t.Fatalf("unexpected archive version %d", p.ArchiveVersion) |
| } |
| if len(p.Keys) != 10 { |
| t.Fatalf("unexpected key length %d", len(p.Keys)) |
| } |
| |
| // Run it again but we'll turn off storage along the way |
| p.MinDecryptionVersion = 5 |
| if err := p.Persist(ctx, storage); err != nil { |
| t.Fatal(err) |
| } |
| if p.ArchiveVersion != 10 { |
| t.Fatalf("unexpected archive version %d", p.ArchiveVersion) |
| } |
| if len(p.Keys) != 6 { |
| t.Fatalf("unexpected key length %d", len(p.Keys)) |
| } |
| |
| underlying := storage.Underlying() |
| underlying.FailPut(true) |
| |
| // Set back, which should cause p.Keys to be changed if the persist works, |
| // but it doesn't |
| p.MinDecryptionVersion = 1 |
| if err := p.Persist(ctx, storage); err == nil { |
| t.Fatal("expected error during put") |
| } |
| if p.ArchiveVersion != 10 { |
| t.Fatalf("unexpected archive version %d", p.ArchiveVersion) |
| } |
| // Here's the expected change |
| if len(p.Keys) != 6 { |
| t.Fatalf("unexpected key length %d", len(p.Keys)) |
| } |
| } |
| |
| func Test_Import(t *testing.T) { |
| ctx := context.Background() |
| storage := &logical.InmemStorage{} |
| testKeys, err := generateTestKeys() |
| if err != nil { |
| t.Fatalf("error generating test keys: %s", err) |
| } |
| |
| tests := map[string]struct { |
| policy Policy |
| key []byte |
| shouldError bool |
| }{ |
| "import AES key": { |
| policy: Policy{ |
| Name: "test-aes-key", |
| Type: KeyType_AES256_GCM96, |
| }, |
| key: testKeys[KeyType_AES256_GCM96], |
| shouldError: false, |
| }, |
| "import RSA key": { |
| policy: Policy{ |
| Name: "test-rsa-key", |
| Type: KeyType_RSA2048, |
| }, |
| key: testKeys[KeyType_RSA2048], |
| shouldError: false, |
| }, |
| "import ECDSA key": { |
| policy: Policy{ |
| Name: "test-ecdsa-key", |
| Type: KeyType_ECDSA_P256, |
| }, |
| key: testKeys[KeyType_ECDSA_P256], |
| shouldError: false, |
| }, |
| "import ED25519 key": { |
| policy: Policy{ |
| Name: "test-ed25519-key", |
| Type: KeyType_ED25519, |
| }, |
| key: testKeys[KeyType_ED25519], |
| shouldError: false, |
| }, |
| "import incorrect key type": { |
| policy: Policy{ |
| Name: "test-ed25519-key", |
| Type: KeyType_ED25519, |
| }, |
| key: testKeys[KeyType_AES256_GCM96], |
| shouldError: true, |
| }, |
| } |
| |
| for name, test := range tests { |
| t.Run(name, func(t *testing.T) { |
| if err := test.policy.Import(ctx, storage, test.key, rand.Reader); (err != nil) != test.shouldError { |
| t.Fatalf("error importing key: %s", err) |
| } |
| }) |
| } |
| } |
| |
| func generateTestKeys() (map[KeyType][]byte, error) { |
| keyMap := make(map[KeyType][]byte) |
| |
| rsaKey, err := rsa.GenerateKey(rand.Reader, 2048) |
| if err != nil { |
| return nil, err |
| } |
| rsaKeyBytes, err := x509.MarshalPKCS8PrivateKey(rsaKey) |
| if err != nil { |
| return nil, err |
| } |
| keyMap[KeyType_RSA2048] = rsaKeyBytes |
| |
| rsaKey, err = rsa.GenerateKey(rand.Reader, 3072) |
| if err != nil { |
| return nil, err |
| } |
| rsaKeyBytes, err = x509.MarshalPKCS8PrivateKey(rsaKey) |
| if err != nil { |
| return nil, err |
| } |
| keyMap[KeyType_RSA3072] = rsaKeyBytes |
| |
| rsaKey, err = rsa.GenerateKey(rand.Reader, 4096) |
| if err != nil { |
| return nil, err |
| } |
| rsaKeyBytes, err = x509.MarshalPKCS8PrivateKey(rsaKey) |
| if err != nil { |
| return nil, err |
| } |
| keyMap[KeyType_RSA4096] = rsaKeyBytes |
| |
| ecdsaKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) |
| if err != nil { |
| return nil, err |
| } |
| ecdsaKeyBytes, err := x509.MarshalPKCS8PrivateKey(ecdsaKey) |
| if err != nil { |
| return nil, err |
| } |
| keyMap[KeyType_ECDSA_P256] = ecdsaKeyBytes |
| |
| _, ed25519Key, err := ed25519.GenerateKey(rand.Reader) |
| if err != nil { |
| return nil, err |
| } |
| ed25519KeyBytes, err := x509.MarshalPKCS8PrivateKey(ed25519Key) |
| if err != nil { |
| return nil, err |
| } |
| keyMap[KeyType_ED25519] = ed25519KeyBytes |
| |
| aesKey := make([]byte, 32) |
| _, err = rand.Read(aesKey) |
| if err != nil { |
| return nil, err |
| } |
| keyMap[KeyType_AES256_GCM96] = aesKey |
| |
| return keyMap, nil |
| } |
| |
| func BenchmarkSymmetric(b *testing.B) { |
| ctx := context.Background() |
| lm, _ := NewLockManager(true, 0) |
| storage := &logical.InmemStorage{} |
| p, _, _ := lm.GetPolicy(ctx, PolicyRequest{ |
| Upsert: true, |
| Storage: storage, |
| KeyType: KeyType_AES256_GCM96, |
| Name: "test", |
| }, rand.Reader) |
| key, _ := p.GetKey(nil, 1, 32) |
| pt := make([]byte, 10) |
| ad := make([]byte, 10) |
| for i := 0; i < b.N; i++ { |
| ct, _ := p.SymmetricEncryptRaw(1, key, pt, |
| SymmetricOpts{ |
| AdditionalData: ad, |
| }) |
| pt2, _ := p.SymmetricDecryptRaw(key, ct, SymmetricOpts{ |
| AdditionalData: ad, |
| }) |
| if !bytes.Equal(pt, pt2) { |
| b.Fail() |
| } |
| } |
| } |
| |
| func saltOptions(options SigningOptions, saltLength int) SigningOptions { |
| return SigningOptions{ |
| HashAlgorithm: options.HashAlgorithm, |
| Marshaling: options.Marshaling, |
| SaltLength: saltLength, |
| SigAlgorithm: options.SigAlgorithm, |
| } |
| } |
| |
| func manualVerify(depth int, t *testing.T, p *Policy, input []byte, sig *SigningResult, options SigningOptions) { |
| tabs := strings.Repeat("\t", depth) |
| t.Log(tabs, "Manually verifying signature with options:", options) |
| |
| tabs = strings.Repeat("\t", depth+1) |
| verified, err := p.VerifySignatureWithOptions(nil, input, sig.Signature, &options) |
| if err != nil { |
| t.Fatal(tabs, "❌ Failed to manually verify signature:", err) |
| } |
| if !verified { |
| t.Fatal(tabs, "❌ Failed to manually verify signature") |
| } |
| } |
| |
| func autoVerify(depth int, t *testing.T, p *Policy, input []byte, sig *SigningResult, options SigningOptions) { |
| tabs := strings.Repeat("\t", depth) |
| t.Log(tabs, "Automatically verifying signature with options:", options) |
| |
| tabs = strings.Repeat("\t", depth+1) |
| verified, err := p.VerifySignature(nil, input, options.HashAlgorithm, options.SigAlgorithm, options.Marshaling, sig.Signature) |
| if err != nil { |
| t.Fatal(tabs, "❌ Failed to automatically verify signature:", err) |
| } |
| if !verified { |
| t.Fatal(tabs, "❌ Failed to automatically verify signature") |
| } |
| } |
| |
| func Test_RSA_PSS(t *testing.T) { |
| t.Log("Testing RSA PSS") |
| mathrand.Seed(time.Now().UnixNano()) |
| |
| var userError errutil.UserError |
| ctx := context.Background() |
| storage := &logical.InmemStorage{} |
| // https://crypto.stackexchange.com/a/1222 |
| input := []byte("the ancients say the longer the salt, the more provable the security") |
| sigAlgorithm := "pss" |
| |
| tabs := make(map[int]string) |
| for i := 1; i <= 6; i++ { |
| tabs[i] = strings.Repeat("\t", i) |
| } |
| |
| test_RSA_PSS := func(t *testing.T, p *Policy, rsaKey *rsa.PrivateKey, hashType HashType, |
| marshalingType MarshalingType, |
| ) { |
| unsaltedOptions := SigningOptions{ |
| HashAlgorithm: hashType, |
| Marshaling: marshalingType, |
| SigAlgorithm: sigAlgorithm, |
| } |
| cryptoHash := CryptoHashMap[hashType] |
| minSaltLength := p.minRSAPSSSaltLength() |
| maxSaltLength := p.maxRSAPSSSaltLength(rsaKey.N.BitLen(), cryptoHash) |
| hash := cryptoHash.New() |
| hash.Write(input) |
| input = hash.Sum(nil) |
| |
| // 1. Make an "automatic" signature with the given key size and hash algorithm, |
| // but an automatically chosen salt length. |
| t.Log(tabs[3], "Make an automatic signature") |
| sig, err := p.Sign(0, nil, input, hashType, sigAlgorithm, marshalingType) |
| if err != nil { |
| // A bit of a hack but FIPS go does not support some hash types |
| if isUnsupportedGoHashType(hashType, err) { |
| t.Skip(tabs[4], "skipping test as FIPS Go does not support hash type") |
| return |
| } |
| t.Fatal(tabs[4], "❌ Failed to automatically sign:", err) |
| } |
| |
| // 1.1 Verify this automatic signature using the *inferred* salt length. |
| autoVerify(4, t, p, input, sig, unsaltedOptions) |
| |
| // 1.2. Verify this automatic signature using the *correct, given* salt length. |
| manualVerify(4, t, p, input, sig, saltOptions(unsaltedOptions, maxSaltLength)) |
| |
| // 1.3. Try to verify this automatic signature using *incorrect, given* salt lengths. |
| t.Log(tabs[4], "Test incorrect salt lengths") |
| incorrectSaltLengths := []int{minSaltLength, maxSaltLength - 1} |
| for _, saltLength := range incorrectSaltLengths { |
| t.Log(tabs[5], "Salt length:", saltLength) |
| saltedOptions := saltOptions(unsaltedOptions, saltLength) |
| |
| verified, _ := p.VerifySignatureWithOptions(nil, input, sig.Signature, &saltedOptions) |
| if verified { |
| t.Fatal(tabs[6], "❌ Failed to invalidate", verified, "signature using incorrect salt length:", err) |
| } |
| } |
| |
| // 2. Rule out boundary, invalid salt lengths. |
| t.Log(tabs[3], "Test invalid salt lengths") |
| invalidSaltLengths := []int{minSaltLength - 1, maxSaltLength + 1} |
| for _, saltLength := range invalidSaltLengths { |
| t.Log(tabs[4], "Salt length:", saltLength) |
| saltedOptions := saltOptions(unsaltedOptions, saltLength) |
| |
| // 2.1. Fail to sign. |
| t.Log(tabs[5], "Try to make a manual signature") |
| _, err := p.SignWithOptions(0, nil, input, &saltedOptions) |
| if !errors.As(err, &userError) { |
| t.Fatal(tabs[6], "❌ Failed to reject invalid salt length:", err) |
| } |
| |
| // 2.2. Fail to verify. |
| t.Log(tabs[5], "Try to verify an automatic signature using an invalid salt length") |
| _, err = p.VerifySignatureWithOptions(nil, input, sig.Signature, &saltedOptions) |
| if !errors.As(err, &userError) { |
| t.Fatal(tabs[6], "❌ Failed to reject invalid salt length:", err) |
| } |
| } |
| |
| // 3. For three possible valid salt lengths... |
| t.Log(tabs[3], "Test three possible valid salt lengths") |
| midSaltLength := mathrand.Intn(maxSaltLength-1) + 1 // [1, maxSaltLength) |
| validSaltLengths := []int{minSaltLength, midSaltLength, maxSaltLength} |
| for _, saltLength := range validSaltLengths { |
| t.Log(tabs[4], "Salt length:", saltLength) |
| saltedOptions := saltOptions(unsaltedOptions, saltLength) |
| |
| // 3.1. Make a "manual" signature with the given key size, hash algorithm, and salt length. |
| t.Log(tabs[5], "Make a manual signature") |
| sig, err := p.SignWithOptions(0, nil, input, &saltedOptions) |
| if err != nil { |
| t.Fatal(tabs[6], "❌ Failed to manually sign:", err) |
| } |
| |
| // 3.2. Verify this manual signature using the *correct, given* salt length. |
| manualVerify(6, t, p, input, sig, saltedOptions) |
| |
| // 3.3. Verify this manual signature using the *inferred* salt length. |
| autoVerify(6, t, p, input, sig, unsaltedOptions) |
| } |
| } |
| |
| rsaKeyTypes := []KeyType{KeyType_RSA2048, KeyType_RSA3072, KeyType_RSA4096} |
| testKeys, err := generateTestKeys() |
| if err != nil { |
| t.Fatalf("error generating test keys: %s", err) |
| } |
| |
| // 1. For each standard RSA key size 2048, 3072, and 4096... |
| for _, rsaKeyType := range rsaKeyTypes { |
| t.Log("Key size: ", rsaKeyType) |
| p := &Policy{ |
| Name: fmt.Sprint(rsaKeyType), // NOTE: crucial to create a new key per key size |
| Type: rsaKeyType, |
| } |
| |
| rsaKeyBytes := testKeys[rsaKeyType] |
| err := p.Import(ctx, storage, rsaKeyBytes, rand.Reader) |
| if err != nil { |
| t.Fatal(tabs[1], "❌ Failed to import key:", err) |
| } |
| rsaKeyAny, err := x509.ParsePKCS8PrivateKey(rsaKeyBytes) |
| if err != nil { |
| t.Fatalf("error parsing test keys: %s", err) |
| } |
| rsaKey := rsaKeyAny.(*rsa.PrivateKey) |
| |
| // 2. For each hash algorithm... |
| for hashAlgorithm, hashType := range HashTypeMap { |
| t.Log(tabs[1], "Hash algorithm:", hashAlgorithm) |
| if hashAlgorithm == "none" { |
| continue |
| } |
| |
| // 3. For each marshaling type... |
| for marshalingName, marshalingType := range MarshalingTypeMap { |
| t.Log(tabs[2], "Marshaling type:", marshalingName) |
| testName := fmt.Sprintf("%s-%s-%s", rsaKeyType, hashAlgorithm, marshalingName) |
| t.Run(testName, func(t *testing.T) { test_RSA_PSS(t, p, rsaKey, hashType, marshalingType) }) |
| } |
| } |
| } |
| } |
| |
| func Test_RSA_PKCS1(t *testing.T) { |
| t.Log("Testing RSA PKCS#1v1.5") |
| |
| ctx := context.Background() |
| storage := &logical.InmemStorage{} |
| // https://crypto.stackexchange.com/a/1222 |
| input := []byte("Sphinx of black quartz, judge my vow") |
| sigAlgorithm := "pkcs1v15" |
| |
| tabs := make(map[int]string) |
| for i := 1; i <= 6; i++ { |
| tabs[i] = strings.Repeat("\t", i) |
| } |
| |
| test_RSA_PKCS1 := func(t *testing.T, p *Policy, rsaKey *rsa.PrivateKey, hashType HashType, |
| marshalingType MarshalingType, |
| ) { |
| unsaltedOptions := SigningOptions{ |
| HashAlgorithm: hashType, |
| Marshaling: marshalingType, |
| SigAlgorithm: sigAlgorithm, |
| } |
| cryptoHash := CryptoHashMap[hashType] |
| |
| // PKCS#1v1.5 NoOID uses a direct input and assumes it is pre-hashed. |
| if hashType != 0 { |
| hash := cryptoHash.New() |
| hash.Write(input) |
| input = hash.Sum(nil) |
| } |
| |
| // 1. Make a signature with the given key size and hash algorithm. |
| t.Log(tabs[3], "Make an automatic signature") |
| sig, err := p.Sign(0, nil, input, hashType, sigAlgorithm, marshalingType) |
| if err != nil { |
| // A bit of a hack but FIPS go does not support some hash types |
| if isUnsupportedGoHashType(hashType, err) { |
| t.Skip(tabs[4], "skipping test as FIPS Go does not support hash type") |
| return |
| } |
| t.Fatal(tabs[4], "❌ Failed to automatically sign:", err) |
| } |
| |
| // 1.1 Verify this signature using the *inferred* salt length. |
| autoVerify(4, t, p, input, sig, unsaltedOptions) |
| } |
| |
| rsaKeyTypes := []KeyType{KeyType_RSA2048, KeyType_RSA3072, KeyType_RSA4096} |
| testKeys, err := generateTestKeys() |
| if err != nil { |
| t.Fatalf("error generating test keys: %s", err) |
| } |
| |
| // 1. For each standard RSA key size 2048, 3072, and 4096... |
| for _, rsaKeyType := range rsaKeyTypes { |
| t.Log("Key size: ", rsaKeyType) |
| p := &Policy{ |
| Name: fmt.Sprint(rsaKeyType), // NOTE: crucial to create a new key per key size |
| Type: rsaKeyType, |
| } |
| |
| rsaKeyBytes := testKeys[rsaKeyType] |
| err := p.Import(ctx, storage, rsaKeyBytes, rand.Reader) |
| if err != nil { |
| t.Fatal(tabs[1], "❌ Failed to import key:", err) |
| } |
| rsaKeyAny, err := x509.ParsePKCS8PrivateKey(rsaKeyBytes) |
| if err != nil { |
| t.Fatalf("error parsing test keys: %s", err) |
| } |
| rsaKey := rsaKeyAny.(*rsa.PrivateKey) |
| |
| // 2. For each hash algorithm... |
| for hashAlgorithm, hashType := range HashTypeMap { |
| t.Log(tabs[1], "Hash algorithm:", hashAlgorithm) |
| |
| // 3. For each marshaling type... |
| for marshalingName, marshalingType := range MarshalingTypeMap { |
| t.Log(tabs[2], "Marshaling type:", marshalingName) |
| testName := fmt.Sprintf("%s-%s-%s", rsaKeyType, hashAlgorithm, marshalingName) |
| t.Run(testName, func(t *testing.T) { test_RSA_PKCS1(t, p, rsaKey, hashType, marshalingType) }) |
| } |
| } |
| } |
| } |
| |
| // Normal Go builds support all the hash functions for RSA_PSS signatures but the |
| // FIPS Go build does not support at this time the SHA3 hashes as FIPS 140_2 does |
| // not accept them. |
| func isUnsupportedGoHashType(hashType HashType, err error) bool { |
| switch hashType { |
| case HashTypeSHA3224, HashTypeSHA3256, HashTypeSHA3384, HashTypeSHA3512: |
| return strings.Contains(err.Error(), "unsupported hash function") |
| } |
| |
| return false |
| } |