| // Copyright (c) HashiCorp, Inc. |
| // SPDX-License-Identifier: MPL-2.0 |
| |
| package command |
| |
| import ( |
| "bytes" |
| "crypto/rand" |
| "crypto/rsa" |
| "crypto/x509" |
| "encoding/base64" |
| "testing" |
| |
| "github.com/hashicorp/vault/api" |
| |
| "github.com/stretchr/testify/require" |
| ) |
| |
| // Validate the `vault transit import` command works. |
| func TestTransitImport(t *testing.T) { |
| t.Parallel() |
| |
| client, closer := testVaultServer(t) |
| defer closer() |
| |
| if err := client.Sys().Mount("transit", &api.MountInput{ |
| Type: "transit", |
| }); err != nil { |
| t.Fatalf("transit mount error: %#v", err) |
| } |
| |
| rsa1, rsa2, aes128, aes256 := generateKeys(t) |
| |
| type testCase struct { |
| variant string |
| path string |
| key []byte |
| args []string |
| shouldFail bool |
| } |
| tests := []testCase{ |
| { |
| "import", |
| "transit/keys/rsa1", |
| rsa1, |
| []string{"type=rsa-2048"}, |
| false, /* first import */ |
| }, |
| { |
| "import", |
| "transit/keys/rsa1", |
| rsa2, |
| []string{"type=rsa-2048"}, |
| true, /* already exists */ |
| }, |
| { |
| "import-version", |
| "transit/keys/rsa1", |
| rsa2, |
| []string{"type=rsa-2048"}, |
| false, /* new version */ |
| }, |
| { |
| "import", |
| "transit/keys/rsa2", |
| rsa2, |
| []string{"type=rsa-4096"}, |
| true, /* wrong type */ |
| }, |
| { |
| "import", |
| "transit/keys/rsa2", |
| rsa2, |
| []string{"type=rsa-2048"}, |
| false, /* new name */ |
| }, |
| { |
| "import", |
| "transit/keys/aes1", |
| aes128, |
| []string{"type=aes128-gcm96"}, |
| false, /* first import */ |
| }, |
| { |
| "import", |
| "transit/keys/aes1", |
| aes256, |
| []string{"type=aes256-gcm96"}, |
| true, /* already exists */ |
| }, |
| { |
| "import-version", |
| "transit/keys/aes1", |
| aes256, |
| []string{"type=aes256-gcm96"}, |
| true, /* new version, different type */ |
| }, |
| { |
| "import-version", |
| "transit/keys/aes1", |
| aes128, |
| []string{"type=aes128-gcm96"}, |
| false, /* new version */ |
| }, |
| { |
| "import", |
| "transit/keys/aes2", |
| aes256, |
| []string{"type=aes128-gcm96"}, |
| true, /* wrong type */ |
| }, |
| { |
| "import", |
| "transit/keys/aes2", |
| aes256, |
| []string{"type=aes256-gcm96"}, |
| false, /* new name */ |
| }, |
| } |
| |
| for index, tc := range tests { |
| t.Logf("Running test case %d: %v", index, tc) |
| execTransitImport(t, client, tc.variant, tc.path, tc.key, tc.args, tc.shouldFail) |
| } |
| } |
| |
| func execTransitImport(t *testing.T, client *api.Client, method string, path string, key []byte, data []string, expectFailure bool) { |
| t.Helper() |
| |
| keyBase64 := base64.StdEncoding.EncodeToString(key) |
| |
| var args []string |
| args = append(args, "transit") |
| args = append(args, method) |
| args = append(args, path) |
| args = append(args, keyBase64) |
| args = append(args, data...) |
| |
| stdout := bytes.NewBuffer(nil) |
| stderr := bytes.NewBuffer(nil) |
| runOpts := &RunOptions{ |
| Stdout: stdout, |
| Stderr: stderr, |
| Client: client, |
| } |
| |
| code := RunCustom(args, runOpts) |
| combined := stdout.String() + stderr.String() |
| |
| if code != 0 { |
| if !expectFailure { |
| t.Fatalf("Got unexpected failure from test (ret %d): %v", code, combined) |
| } |
| } else { |
| if expectFailure { |
| t.Fatalf("Expected failure, got success from test (ret %d): %v", code, combined) |
| } |
| } |
| } |
| |
| func generateKeys(t *testing.T) (rsa1 []byte, rsa2 []byte, aes128 []byte, aes256 []byte) { |
| t.Helper() |
| |
| priv1, err := rsa.GenerateKey(rand.Reader, 2048) |
| require.NotNil(t, priv1, "failed generating RSA 1 key") |
| require.NoError(t, err, "failed generating RSA 1 key") |
| |
| rsa1, err = x509.MarshalPKCS8PrivateKey(priv1) |
| require.NotNil(t, rsa1, "failed marshaling RSA 1 key") |
| require.NoError(t, err, "failed marshaling RSA 1 key") |
| |
| priv2, err := rsa.GenerateKey(rand.Reader, 2048) |
| require.NotNil(t, priv2, "failed generating RSA 2 key") |
| require.NoError(t, err, "failed generating RSA 2 key") |
| |
| rsa2, err = x509.MarshalPKCS8PrivateKey(priv2) |
| require.NotNil(t, rsa2, "failed marshaling RSA 2 key") |
| require.NoError(t, err, "failed marshaling RSA 2 key") |
| |
| aes128 = make([]byte, 128/8) |
| _, err = rand.Read(aes128) |
| require.NoError(t, err, "failed generating AES 128 key") |
| |
| aes256 = make([]byte, 256/8) |
| _, err = rand.Read(aes256) |
| require.NoError(t, err, "failed generating AES 256 key") |
| |
| return |
| } |