blob: de306595e5a3f06aafa5608b55751345052d61b1 [file] [log] [blame]
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package command
import (
"context"
"fmt"
"os"
"strings"
"testing"
"github.com/hashicorp/vault/physical/raft"
"github.com/hashicorp/vault/sdk/physical"
"github.com/mitchellh/cli"
)
// testOperatorRaftSnapshotInspectCommand builds an OperatorRaftSnapshotInspectCommand
// whose BaseCommand writes to a fresh cli.MockUi. The mock UI is returned
// alongside the command so tests can inspect what the command printed.
func testOperatorRaftSnapshotInspectCommand(tb testing.TB) (*cli.MockUi, *OperatorRaftSnapshotInspectCommand) {
	tb.Helper()

	mockUI := cli.NewMockUi()
	base := &BaseCommand{UI: mockUI}
	cmd := &OperatorRaftSnapshotInspectCommand{BaseCommand: base}

	return mockUI, cmd
}
// createSnapshot writes 100 key/value entries ("key-N" -> "value-N") into a
// throwaway raft backend, snapshots that backend into a new temporary file,
// and returns the file together with a cleanup func that closes and deletes it.
//
// The raft data directory is removed when this function returns; only the
// snapshot file outlives the call. On error, the returned file and cleanup are
// nil, and any temporary snapshot file already created has been closed and
// removed (best effort), so callers have nothing to release.
func createSnapshot(tb testing.TB) (*os.File, func(), error) {
	tb.Helper()

	// Create new raft backend
	r, raftDir := raft.GetRaft(tb, true, false)
	defer os.RemoveAll(raftDir)

	// Write some data
	for i := 0; i < 100; i++ {
		err := r.Put(context.Background(), &physical.Entry{
			Key:   fmt.Sprintf("key-%d", i),
			Value: []byte(fmt.Sprintf("value-%d", i)),
		})
		if err != nil {
			return nil, nil, fmt.Errorf("error adding data to snapshot: %w", err)
		}
	}

	// Create temporary file to save snapshot to
	snap, err := os.CreateTemp("", "temp_snapshot.snap")
	if err != nil {
		return nil, nil, fmt.Errorf("error creating temporary file: %w", err)
	}

	// removeSnap releases the handle before deleting the path. Without the
	// Close, the descriptor leaks, and on Windows the delete fails while the
	// file is still open.
	removeSnap := func() error {
		// The Close error is deliberately ignored: the caller may already
		// have closed the returned *os.File, and the file is about to be
		// deleted anyway.
		_ = snap.Close()
		return os.RemoveAll(snap.Name())
	}

	// Save snapshot
	if err := r.Snapshot(snap, nil); err != nil {
		// Best effort: the caller receives no cleanup on error, so release
		// the temp file here. The snapshot error is the one worth reporting.
		_ = removeSnap()
		return nil, nil, fmt.Errorf("error saving raft snapshot: %w", err)
	}

	cleanup := func() {
		if err := removeSnap(); err != nil {
			tb.Errorf("Error deleting temporary snapshot %s", err)
		}
	}
	return snap, cleanup, nil
}
// TestOperatorRaftSnapshotInspectCommand_Run runs the snapshot inspect command
// against an argument-validation case and two freshly generated snapshots.
// Each case checks the exit code, and checks that the combined stdout/stderr
// of the mock UI contains an expected substring.
func TestOperatorRaftSnapshotInspectCommand_Run(t *testing.T) {
	t.Parallel()

	// Snapshot cleanups are registered on the parent test immediately after
	// creation. t.Cleanup runs only once this test and all of its (parallel)
	// subtests have finished, and it still runs if a later setup step calls
	// t.Fatalf. Without this, a failure creating file2 would leak file1.
	file1, cleanup1, err := createSnapshot(t)
	if err != nil {
		t.Fatalf("Error creating snapshot %s", err)
	}
	t.Cleanup(cleanup1)

	file2, cleanup2, err := createSnapshot(t)
	if err != nil {
		t.Fatalf("Error creating snapshot %s", err)
	}
	t.Cleanup(cleanup2)

	cases := []struct {
		name string
		args []string
		out  string
		code int
	}{
		{
			name: "too_many_args",
			args: []string{"test.snap", "test"},
			out:  "Too many arguments",
			code: 1,
		},
		{
			name: "default",
			args: []string{file1.Name()},
			out:  "ID bolt-snapshot",
			code: 0,
		},
		{
			name: "all_flags",
			args: []string{"-details", "-depth", "10", "-filter", "key", file2.Name()},
			out:  "Key Name",
			code: 0,
		},
	}

	t.Run("validations", func(t *testing.T) {
		t.Parallel()

		for _, tc := range cases {
			tc := tc // capture per-iteration copy for the parallel closure (pre-Go 1.22 semantics)

			t.Run(tc.name, func(t *testing.T) {
				t.Parallel()

				client, closer := testVaultServer(t)
				defer closer()

				ui, cmd := testOperatorRaftSnapshotInspectCommand(t)
				cmd.client = client

				code := cmd.Run(tc.args)
				if code != tc.code {
					t.Errorf("expected %d to be %d", code, tc.code)
				}

				combined := ui.OutputWriter.String() + ui.ErrorWriter.String()
				if !strings.Contains(combined, tc.out) {
					t.Errorf("expected %q to contain %q", combined, tc.out)
				}
			})
		}
	})
}