| // Copyright (c) HashiCorp, Inc. |
| // SPDX-License-Identifier: BUSL-1.1 |
| |
| package command |
| |
| import ( |
| "context" |
| "fmt" |
| "os" |
| "strings" |
| "testing" |
| |
| "github.com/hashicorp/vault/physical/raft" |
| "github.com/hashicorp/vault/sdk/physical" |
| "github.com/mitchellh/cli" |
| ) |
| |
| func testOperatorRaftSnapshotInspectCommand(tb testing.TB) (*cli.MockUi, *OperatorRaftSnapshotInspectCommand) { |
| tb.Helper() |
| |
| ui := cli.NewMockUi() |
| return ui, &OperatorRaftSnapshotInspectCommand{ |
| BaseCommand: &BaseCommand{ |
| UI: ui, |
| }, |
| } |
| } |
| |
| func createSnapshot(tb testing.TB) (*os.File, func(), error) { |
| // Create new raft backend |
| r, raftDir := raft.GetRaft(tb, true, false) |
| defer os.RemoveAll(raftDir) |
| |
| // Write some data |
| for i := 0; i < 100; i++ { |
| err := r.Put(context.Background(), &physical.Entry{ |
| Key: fmt.Sprintf("key-%d", i), |
| Value: []byte(fmt.Sprintf("value-%d", i)), |
| }) |
| if err != nil { |
| return nil, nil, fmt.Errorf("Error adding data to snapshot %s", err) |
| } |
| } |
| |
| // Create temporary file to save snapshot to |
| snap, err := os.CreateTemp("", "temp_snapshot.snap") |
| if err != nil { |
| return nil, nil, fmt.Errorf("Error creating temporary file %s", err) |
| } |
| |
| cleanup := func() { |
| err := os.RemoveAll(snap.Name()) |
| if err != nil { |
| tb.Errorf("Error deleting temporary snapshot %s", err) |
| } |
| } |
| |
| // Save snapshot |
| err = r.Snapshot(snap, nil) |
| if err != nil { |
| return nil, nil, fmt.Errorf("Error saving raft snapshot %s", err) |
| } |
| |
| return snap, cleanup, nil |
| } |
| |
| func TestOperatorRaftSnapshotInspectCommand_Run(t *testing.T) { |
| t.Parallel() |
| |
| file1, cleanup1, err := createSnapshot(t) |
| if err != nil { |
| t.Fatalf("Error creating snapshot %s", err) |
| } |
| |
| file2, cleanup2, err := createSnapshot(t) |
| if err != nil { |
| t.Fatalf("Error creating snapshot %s", err) |
| } |
| |
| cases := []struct { |
| name string |
| args []string |
| out string |
| code int |
| cleanup func() |
| }{ |
| { |
| "too_many_args", |
| []string{"test.snap", "test"}, |
| "Too many arguments", |
| 1, |
| nil, |
| }, |
| { |
| "default", |
| []string{file1.Name()}, |
| "ID bolt-snapshot", |
| 0, |
| cleanup1, |
| }, |
| { |
| "all_flags", |
| []string{"-details", "-depth", "10", "-filter", "key", file2.Name()}, |
| "Key Name", |
| 0, |
| cleanup2, |
| }, |
| } |
| |
| t.Run("validations", func(t *testing.T) { |
| t.Parallel() |
| |
| for _, tc := range cases { |
| tc := tc |
| |
| t.Run(tc.name, func(t *testing.T) { |
| t.Parallel() |
| |
| client, closer := testVaultServer(t) |
| defer closer() |
| |
| ui, cmd := testOperatorRaftSnapshotInspectCommand(t) |
| |
| cmd.client = client |
| code := cmd.Run(tc.args) |
| if code != tc.code { |
| t.Errorf("expected %d to be %d", code, tc.code) |
| } |
| |
| combined := ui.OutputWriter.String() + ui.ErrorWriter.String() |
| if !strings.Contains(combined, tc.out) { |
| t.Errorf("expected %q to contain %q", combined, tc.out) |
| } |
| |
| if tc.cleanup != nil { |
| tc.cleanup() |
| } |
| }) |
| } |
| }) |
| } |