blob: e01c03fa00bb6431d053b8ea336a082f839c205a [file] [log] [blame]
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package command
import (
"bytes"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/base64"
"testing"
"github.com/hashicorp/vault/api"
"github.com/stretchr/testify/require"
)
// TestTransitImport validates that the `vault transit import` and
// `vault transit import-version` commands work against a test Vault server.
//
// The cases run sequentially and are order-dependent: later cases rely on
// keys created by earlier ones (e.g. "already exists", "new version"), so they
// must not be reordered or run as independent parallel subtests.
func TestTransitImport(t *testing.T) {
	t.Parallel()

	client, closer := testVaultServer(t)
	defer closer()

	if err := client.Sys().Mount("transit", &api.MountInput{
		Type: "transit",
	}); err != nil {
		t.Fatalf("transit mount error: %#v", err)
	}

	rsa1, rsa2, aes128, aes256 := generateKeys(t)

	type testCase struct {
		variant    string   // CLI subcommand passed after "transit": "import" or "import-version"
		path       string   // transit key path to import into
		key        []byte   // raw key material; execTransitImport base64-encodes it
		args       []string // extra arguments appended to the command line
		shouldFail bool     // whether the command is expected to exit non-zero
	}

	tests := []testCase{
		{
			variant:    "import",
			path:       "transit/keys/rsa1",
			key:        rsa1,
			args:       []string{"type=rsa-2048"},
			shouldFail: false, // first import
		},
		{
			variant:    "import",
			path:       "transit/keys/rsa1",
			key:        rsa2,
			args:       []string{"type=rsa-2048"},
			shouldFail: true, // already exists
		},
		{
			variant:    "import-version",
			path:       "transit/keys/rsa1",
			key:        rsa2,
			args:       []string{"type=rsa-2048"},
			shouldFail: false, // new version
		},
		{
			variant:    "import",
			path:       "transit/keys/rsa2",
			key:        rsa2,
			args:       []string{"type=rsa-4096"},
			shouldFail: true, // wrong type
		},
		{
			variant:    "import",
			path:       "transit/keys/rsa2",
			key:        rsa2,
			args:       []string{"type=rsa-2048"},
			shouldFail: false, // new name
		},
		{
			variant:    "import",
			path:       "transit/keys/aes1",
			key:        aes128,
			args:       []string{"type=aes128-gcm96"},
			shouldFail: false, // first import
		},
		{
			variant:    "import",
			path:       "transit/keys/aes1",
			key:        aes256,
			args:       []string{"type=aes256-gcm96"},
			shouldFail: true, // already exists
		},
		{
			variant:    "import-version",
			path:       "transit/keys/aes1",
			key:        aes256,
			args:       []string{"type=aes256-gcm96"},
			shouldFail: true, // new version, different type
		},
		{
			variant:    "import-version",
			path:       "transit/keys/aes1",
			key:        aes128,
			args:       []string{"type=aes128-gcm96"},
			shouldFail: false, // new version
		},
		{
			variant:    "import",
			path:       "transit/keys/aes2",
			key:        aes256,
			args:       []string{"type=aes128-gcm96"},
			shouldFail: true, // wrong type
		},
		{
			variant:    "import",
			path:       "transit/keys/aes2",
			key:        aes256,
			args:       []string{"type=aes256-gcm96"},
			shouldFail: false, // new name
		},
	}

	for index, tc := range tests {
		// Log only the descriptive fields: tc.key is private key material and
		// would otherwise be dumped into the test log as a long byte slice.
		t.Logf("Running test case %d: variant=%s path=%s args=%v shouldFail=%v",
			index, tc.variant, tc.path, tc.args, tc.shouldFail)
		execTransitImport(t, client, tc.variant, tc.path, tc.key, tc.args, tc.shouldFail)
	}
}
// execTransitImport runs `transit <method> <path> <base64(key)> <data...>`
// through RunCustom using client.
//
// It fails the test in two cases:
//   - the command exits non-zero when expectFailure is false;
//   - it exits zero when expectFailure is true.
//
// The failure message includes the exit code and the command's stdout
// followed by its stderr.
func execTransitImport(t *testing.T, client *api.Client, method string, path string, key []byte, data []string, expectFailure bool) {
	t.Helper()

	// Positional arguments come first; any extra options from data follow.
	cliArgs := append([]string{
		"transit",
		method,
		path,
		base64.StdEncoding.EncodeToString(key),
	}, data...)

	var stdout, stderr bytes.Buffer
	code := RunCustom(cliArgs, &RunOptions{
		Stdout: &stdout,
		Stderr: &stderr,
		Client: client,
	})
	output := stdout.String() + stderr.String()

	switch {
	case code != 0 && !expectFailure:
		t.Fatalf("Got unexpected failure from test (ret %d): %v", code, output)
	case code == 0 && expectFailure:
		t.Fatalf("Expected failure, got success from test (ret %d): %v", code, output)
	}
}
// generateKeys returns fresh key material for the transit import tests.
//
// It returns:
//   - rsa1 and rsa2: two independently generated 2048-bit RSA private keys,
//     each encoded as PKCS#8 DER;
//   - aes128 and aes256: 16 and 32 bytes read from crypto/rand.
//
// It fails the test immediately if any key cannot be generated or marshaled.
func generateKeys(t *testing.T) (rsa1 []byte, rsa2 []byte, aes128 []byte, aes256 []byte) {
	t.Helper()

	// newRSAKey generates a 2048-bit RSA key and marshals it as PKCS#8 DER.
	// The error is asserted before the value so that a failure reports the
	// underlying cause instead of a bare "expected non-nil" message.
	newRSAKey := func(label string) []byte {
		t.Helper()
		priv, err := rsa.GenerateKey(rand.Reader, 2048)
		require.NoError(t, err, "failed generating "+label+" key")
		require.NotNil(t, priv, "failed generating "+label+" key")

		der, err := x509.MarshalPKCS8PrivateKey(priv)
		require.NoError(t, err, "failed marshaling "+label+" key")
		require.NotNil(t, der, "failed marshaling "+label+" key")
		return der
	}

	// newRandomKey returns bits/8 bytes read from crypto/rand.
	newRandomKey := func(bits int, label string) []byte {
		t.Helper()
		key := make([]byte, bits/8)
		_, err := rand.Read(key)
		require.NoError(t, err, "failed generating "+label+" key")
		return key
	}

	rsa1 = newRSAKey("RSA 1")
	rsa2 = newRSAKey("RSA 2")
	aes128 = newRandomKey(128, "AES 128")
	aes256 = newRandomKey(256, "AES 256")
	return rsa1, rsa2, aes128, aes256
}