| // Copyright (c) HashiCorp, Inc. |
| // SPDX-License-Identifier: MPL-2.0 |
| |
| package raft |
| |
| import ( |
| "context" |
| "fmt" |
| "io/ioutil" |
| "math/rand" |
| "os" |
| "sort" |
| "testing" |
| |
| "github.com/go-test/deep" |
| "github.com/golang/protobuf/proto" |
| "github.com/hashicorp/go-hclog" |
| "github.com/hashicorp/raft" |
| "github.com/hashicorp/vault/sdk/physical" |
| ) |
| |
| func getFSM(t testing.TB) (*FSM, string) { |
| raftDir, err := ioutil.TempDir("", "vault-raft-") |
| if err != nil { |
| t.Fatal(err) |
| } |
| t.Logf("raft dir: %s", raftDir) |
| |
| logger := hclog.New(&hclog.LoggerOptions{ |
| Name: "raft", |
| Level: hclog.Trace, |
| }) |
| |
| fsm, err := NewFSM(raftDir, "", logger) |
| if err != nil { |
| t.Fatal(err) |
| } |
| |
| return fsm, raftDir |
| } |
| |
| func TestFSM_Batching(t *testing.T) { |
| fsm, dir := getFSM(t) |
| defer func() { _ = os.RemoveAll(dir) }() |
| |
| var index uint64 |
| var term uint64 = 1 |
| |
| getLog := func(i uint64) (int, *raft.Log) { |
| if rand.Intn(10) >= 8 { |
| term += 1 |
| return 0, &raft.Log{ |
| Index: i, |
| Term: term, |
| Type: raft.LogConfiguration, |
| Data: raft.EncodeConfiguration(raft.Configuration{ |
| Servers: []raft.Server{ |
| { |
| Address: "test", |
| ID: "test", |
| }, |
| }, |
| }), |
| } |
| } |
| |
| command := &LogData{ |
| Operations: make([]*LogOperation, rand.Intn(10)), |
| } |
| |
| for j := range command.Operations { |
| command.Operations[j] = &LogOperation{ |
| OpType: putOp, |
| Key: fmt.Sprintf("key-%d-%d", i, j), |
| Value: []byte(fmt.Sprintf("value-%d-%d", i, j)), |
| } |
| } |
| commandBytes, err := proto.Marshal(command) |
| if err != nil { |
| t.Fatal(err) |
| } |
| return len(command.Operations), &raft.Log{ |
| Index: i, |
| Term: term, |
| Type: raft.LogCommand, |
| Data: commandBytes, |
| } |
| } |
| |
| totalKeys := 0 |
| for i := 0; i < 100; i++ { |
| batchSize := rand.Intn(64) |
| batch := make([]*raft.Log, batchSize) |
| for j := 0; j < batchSize; j++ { |
| var keys int |
| index++ |
| keys, batch[j] = getLog(index) |
| totalKeys += keys |
| } |
| |
| resp := fsm.ApplyBatch(batch) |
| if len(resp) != batchSize { |
| t.Fatalf("incorrect response length: got %d expected %d", len(resp), batchSize) |
| } |
| |
| for _, r := range resp { |
| if _, ok := r.(*FSMApplyResponse); !ok { |
| t.Fatal("bad response type") |
| } |
| } |
| } |
| |
| keys, err := fsm.List(context.Background(), "") |
| if err != nil { |
| t.Fatal(err) |
| } |
| |
| if len(keys) != totalKeys { |
| t.Fatalf("incorrect number of keys: got %d expected %d", len(keys), totalKeys) |
| } |
| |
| latestIndex, latestConfig := fsm.LatestState() |
| if latestIndex.Index != index { |
| t.Fatalf("bad latest index: got %d expected %d", latestIndex.Index, index) |
| } |
| if latestIndex.Term != term { |
| t.Fatalf("bad latest term: got %d expected %d", latestIndex.Term, term) |
| } |
| |
| if latestConfig == nil && term > 1 { |
| t.Fatal("config wasn't updated") |
| } |
| } |
| |
| func TestFSM_List(t *testing.T) { |
| fsm, dir := getFSM(t) |
| defer func() { _ = os.RemoveAll(dir) }() |
| |
| ctx := context.Background() |
| count := 100 |
| keys := rand.Perm(count) |
| var sorted []string |
| for _, k := range keys { |
| err := fsm.Put(ctx, &physical.Entry{Key: fmt.Sprintf("foo/%d/bar", k)}) |
| if err != nil { |
| t.Fatal(err) |
| } |
| err = fsm.Put(ctx, &physical.Entry{Key: fmt.Sprintf("foo/%d/baz", k)}) |
| if err != nil { |
| t.Fatal(err) |
| } |
| sorted = append(sorted, fmt.Sprintf("%d/", k)) |
| } |
| sort.Strings(sorted) |
| |
| got, err := fsm.List(ctx, "foo/") |
| if err != nil { |
| t.Fatal(err) |
| } |
| sort.Strings(got) |
| if diff := deep.Equal(sorted, got); len(diff) > 0 { |
| t.Fatal(diff) |
| } |
| } |