// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0

package testcluster

import (
	"context"
	"encoding/base64"
	"encoding/hex"
	"fmt"
	"sync/atomic"
	"time"

	"github.com/hashicorp/go-multierror"
	"github.com/hashicorp/go-uuid"
	"github.com/hashicorp/vault/api"
	"github.com/hashicorp/vault/sdk/helper/xor"
)
// SealNode seals the node at nodeIdx and waits for it to report itself
// sealed. Note that OSS standbys will not accept seal requests, and
// Enterprise perf standbys may reject them as well if they haven't yet been
// "elected" as perf standbys.
func SealNode(ctx context.Context, cluster VaultCluster, nodeIdx int) error {
	if nodeIdx < 0 || nodeIdx >= len(cluster.Nodes()) {
		return fmt.Errorf("invalid nodeIdx %d for cluster", nodeIdx)
	}
	node := cluster.Nodes()[nodeIdx]
	client := node.APIClient()

	err := client.Sys().SealWithContext(ctx)
	if err != nil {
		return err
	}

	return NodeSealed(ctx, cluster, nodeIdx)
}

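// SealAllNodes seals every node in the cluster in order, waiting for each to
// report itself sealed before moving on.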
func SealAllNodes(ctx context.Context, cluster VaultCluster) error {
	for i := range cluster.Nodes() {
		if err := SealNode(ctx, cluster, i); err != nil {
			return err
		}
	}
	return nil
}

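// UnsealNode submits all of the cluster's barrier (or recovery) keys to the
// node at nodeIdx, then waits for the node to report itself healthy.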
func UnsealNode(ctx context.Context, cluster VaultCluster, nodeIdx int) error {
	if nodeIdx < 0 || nodeIdx >= len(cluster.Nodes()) {
		return fmt.Errorf("invalid nodeIdx %d for cluster", nodeIdx)
	}
	node := cluster.Nodes()[nodeIdx]
	client := node.APIClient()

	for _, key := range cluster.GetBarrierOrRecoveryKeys() {
		_, err := client.Sys().UnsealWithContext(ctx, hex.EncodeToString(key))
		if err != nil {
			return err
		}
	}

	return NodeHealthy(ctx, cluster, nodeIdx)
}

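// UnsealAllNodes unseals every node in the cluster in order.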
func UnsealAllNodes(ctx context.Context, cluster VaultCluster) error {
	for i := range cluster.Nodes() {
		if err := UnsealNode(ctx, cluster, i); err != nil {
			return err
		}
	}
	return nil
}

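// NodeSealed polls the health endpoint of the node at nodeIdx until it
// reports itself sealed, or until ctx expires.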
func NodeSealed(ctx context.Context, cluster VaultCluster, nodeIdx int) error {
	if nodeIdx < 0 || nodeIdx >= len(cluster.Nodes()) {
		return fmt.Errorf("invalid nodeIdx %d for cluster", nodeIdx)
	}
	node := cluster.Nodes()[nodeIdx]
	client := node.APIClient()

	var health *api.HealthResponse
	var err error
	for ctx.Err() == nil {
		health, err = client.Sys().HealthWithContext(ctx)
		switch {
		case err != nil:
			// Transient error; retry below.
		case !health.Sealed:
			err = fmt.Errorf("unsealed: %#v", health)
		default:
			return nil
		}
		time.Sleep(500 * time.Millisecond)
	}
	return fmt.Errorf("node %d is not sealed: %v", nodeIdx, err)
}

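// WaitForNCoresSealed polls every node in parallel until at least n of them
// report themselves sealed, returning an error if ctx expires first.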
func WaitForNCoresSealed(ctx context.Context, cluster VaultCluster, n int) error {
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	// Buffer the channel so that goroutines still running after an early
	// return don't block forever on their final send.
	errs := make(chan error, len(cluster.Nodes()))
	for i := range cluster.Nodes() {
		go func(i int) {
			var err error
			for ctx.Err() == nil {
				err = NodeSealed(ctx, cluster, i)
				if err == nil {
					errs <- nil
					return
				}
				time.Sleep(100 * time.Millisecond)
			}
			if err == nil {
				err = ctx.Err()
			}
			errs <- err
		}(i)
	}

	var merr *multierror.Error
	var sealed int
	for range cluster.Nodes() {
		err := <-errs
		if err != nil {
			merr = multierror.Append(merr, err)
		} else {
			sealed++
			if sealed == n {
				return nil
			}
		}
	}

	return fmt.Errorf("expected %d cores to seal, only %d did; errs: %v", n, sealed, merr.ErrorOrNil())
}

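// NodeHealthy polls the health endpoint of the node at nodeIdx until it
// reports itself unsealed and healthy, or until ctx expires.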
func NodeHealthy(ctx context.Context, cluster VaultCluster, nodeIdx int) error {
	if nodeIdx < 0 || nodeIdx >= len(cluster.Nodes()) {
		return fmt.Errorf("invalid nodeIdx %d for cluster", nodeIdx)
	}
	node := cluster.Nodes()[nodeIdx]
	client := node.APIClient()

	var health *api.HealthResponse
	var err error
	for ctx.Err() == nil {
		health, err = client.Sys().HealthWithContext(ctx)
		switch {
		case err != nil:
			// Transient error; retry below.
		case health == nil:
			err = fmt.Errorf("nil response to health check")
		case health.Sealed:
			err = fmt.Errorf("sealed: %#v", health)
		default:
			return nil
		}
		time.Sleep(500 * time.Millisecond)
	}
	return fmt.Errorf("node %d is unhealthy: %v", nodeIdx, err)
}

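// LeaderNode returns the index of the node that currently believes itself to
// be the active node, or -1 and an error if none does. If several nodes claim
// leadership at once, the one with the most recent ActiveTime wins; see the
// comment below.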
func LeaderNode(ctx context.Context, cluster VaultCluster) (int, error) {
	// Be robust to multiple nodes thinking they are active. This is possible
	// in certain network-partition situations, where the old leader hasn't
	// yet discovered that it lost leadership. In tests this is only likely to
	// come up when we are specifically provoking it, but it could happen at
	// any point if leadership flaps or connectivity suffers transient errors,
	// so be robust against it. The best solution would be some sort of epoch,
	// like the raft term, that is guaranteed to increase monotonically
	// through elections; however, we don't have that abstraction for all
	// HABackends in general. The best we have is ActiveTime. In a
	// distributed-systems textbook, relying on it would be frowned upon
	// because of clock-sync issues, but for our tests it's likely fine:
	// even if we are running separate Vault containers, they all share the
	// same hardware clock.
	leaderActiveTimes := make(map[int]time.Time)
	for i, node := range cluster.Nodes() {
		client := node.APIClient()
		ctx, cancel := context.WithTimeout(ctx, 100*time.Millisecond)
		resp, err := client.Sys().LeaderWithContext(ctx)
		cancel()
		if err != nil || resp == nil || !resp.IsSelf {
			continue
		}
		leaderActiveTimes[i] = resp.ActiveTime
	}
	if len(leaderActiveTimes) == 0 {
		return -1, fmt.Errorf("no leader found")
	}
	// At least one node thinks it is active. If there are multiple, pick the
	// one with the most recent ActiveTime; if there's only one, this simply
	// returns it.
	var newestLeaderIdx int
	var newestActiveTime time.Time
	for i, at := range leaderActiveTimes {
		if at.After(newestActiveTime) {
			newestActiveTime = at
			newestLeaderIdx = i
		}
	}
	return newestLeaderIdx, nil
}

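// WaitForActiveNode polls the cluster until some node reports itself active
// and returns its index, or -1 and ctx.Err() once the context expires.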
func WaitForActiveNode(ctx context.Context, cluster VaultCluster) (int, error) {
	for ctx.Err() == nil {
		if idx, _ := LeaderNode(ctx, cluster); idx != -1 {
			return idx, nil
		}
		time.Sleep(500 * time.Millisecond)
	}
	return -1, ctx.Err()
}

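// WaitForActiveNodeAndPerfStandbys waits until the cluster has exactly one
// active node and every other node is a perf standby that is demonstrably
// replicating. It proves replication by mounting a throwaway local KV engine
// on the leader, writing to it repeatedly, and watching each standby's
// remote WAL position advance.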
func WaitForActiveNodeAndPerfStandbys(ctx context.Context, cluster VaultCluster) error {
	logger := cluster.NamedLogger("WaitForActiveNodeAndPerfStandbys")
	// This WaitForActiveNode call was added because after a Raft cluster is
	// sealed and then unsealed, it may come back up with a different leader
	// than Core0, making this helper fail. A sleep before calling
	// WaitForActiveNodeAndPerfStandbys seems to sort things out, but so,
	// apparently, does this. We should be able to eliminate this call to
	// WaitForActiveNode by reworking the logic in this method.
	leaderIdx, err := WaitForActiveNode(ctx, cluster)
	if err != nil {
		return err
	}

	if len(cluster.Nodes()) == 1 {
		return nil
	}

	expectedStandbys := len(cluster.Nodes()) - 1

	mountPoint, err := uuid.GenerateUUID()
	if err != nil {
		return err
	}
	leaderClient := cluster.Nodes()[leaderIdx].APIClient()

	for ctx.Err() == nil {
		err = leaderClient.Sys().MountWithContext(ctx, mountPoint, &api.MountInput{
			Type:  "kv",
			Local: true,
		})
		if err == nil {
			break
		}
		time.Sleep(1 * time.Second)
	}
	if err != nil {
		return fmt.Errorf("unable to mount KV engine: %v", err)
	}
	path := mountPoint + "/waitforactivenodeandperfstandbys"
	var standbys, actives int64
	errchan := make(chan error, len(cluster.Nodes()))
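	// One goroutine per node: repeatedly write to the leader, then check the
	// node's view of leadership until the node either claims to be active or
	// its perf-standby remote WAL position advances past a previously
	// observed value, which shows that replication to it is flowing.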
	for i := range cluster.Nodes() {
		go func(coreNo int) {
			node := cluster.Nodes()[coreNo]
			client := node.APIClient()
			val := 1
			var err error
			defer func() {
				errchan <- err
			}()

			var lastWAL uint64
			for ctx.Err() == nil {
				_, err = leaderClient.Logical().WriteWithContext(ctx, path, map[string]interface{}{
					"bar": val,
				})
				val++
				time.Sleep(250 * time.Millisecond)
				if err != nil {
					continue
				}
				var leader *api.LeaderResponse
				leader, err = client.Sys().LeaderWithContext(ctx)
				if err != nil {
					logger.Trace("waiting for core", "core", coreNo, "err", err)
					continue
				}
				switch {
				case leader.IsSelf:
					logger.Trace("waiting for core", "core", coreNo, "isLeader", true)
					atomic.AddInt64(&actives, 1)
					return
				case leader.PerfStandby && leader.PerfStandbyLastRemoteWAL > 0:
					switch {
					case lastWAL == 0:
						lastWAL = leader.PerfStandbyLastRemoteWAL
						logger.Trace("waiting for core", "core", coreNo, "lastRemoteWAL", leader.PerfStandbyLastRemoteWAL, "lastWAL", lastWAL)
					case lastWAL < leader.PerfStandbyLastRemoteWAL:
						logger.Trace("waiting for core", "core", coreNo, "lastRemoteWAL", leader.PerfStandbyLastRemoteWAL, "lastWAL", lastWAL)
						atomic.AddInt64(&standbys, 1)
						return
					}
				default:
					logger.Trace("waiting for core", "core", coreNo,
						"ha_enabled", leader.HAEnabled,
						"is_self", leader.IsSelf,
						"perf_standby", leader.PerfStandby,
						"perf_standby_remote_wal", leader.PerfStandbyLastRemoteWAL)
				}
			}
		}(i)
	}

	errs := make([]error, 0, len(cluster.Nodes()))
	for range cluster.Nodes() {
		errs = append(errs, <-errchan)
	}
	if actives != 1 || int(standbys) != expectedStandbys {
		return fmt.Errorf("expected 1 active core and %d standbys, got %d active and %d standbys, errs: %v",
			expectedStandbys, actives, standbys, errs)
	}

	for ctx.Err() == nil {
		err = leaderClient.Sys().UnmountWithContext(ctx, mountPoint)
		if err == nil {
			break
		}
		time.Sleep(time.Second)
	}
	if err != nil {
		return fmt.Errorf("unable to unmount KV engine on primary: %v", err)
	}
	return nil
}

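// GenerateRootKind selects which kind of token the generate-root flow
// produces: a regular root token, a DR operation token, or a recovery
// operation token.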
type GenerateRootKind int

const (
	GenerateRootRegular GenerateRootKind = iota
	GenerateRootDR
	GenerateRecovery
)

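// GenerateRoot runs the generate-root flow against node 0, feeding it the
// cluster's barrier or recovery keys, and returns the resulting decoded
// token.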
func GenerateRoot(cluster VaultCluster, kind GenerateRootKind) (string, error) {
	// If recovery keys are supported, use those for root token generation
	// instead of barrier keys.
	keys := cluster.GetBarrierOrRecoveryKeys()

	client := cluster.Nodes()[0].APIClient()

	var err error
	var status *api.GenerateRootStatusResponse
	switch kind {
	case GenerateRootRegular:
		status, err = client.Sys().GenerateRootInit("", "")
	case GenerateRootDR:
		status, err = client.Sys().GenerateDROperationTokenInit("", "")
	case GenerateRecovery:
		status, err = client.Sys().GenerateRecoveryOperationTokenInit("", "")
	}
	if err != nil {
		return "", err
	}

	if status.Required > len(keys) {
		return "", fmt.Errorf("generate root requires %d keys, but only %d are available", status.Required, len(keys))
	}

	otp := status.OTP

	for i, key := range keys {
		if i >= status.Required {
			break
		}

		strKey := base64.StdEncoding.EncodeToString(key)
		switch kind {
		case GenerateRootRegular:
			status, err = client.Sys().GenerateRootUpdate(strKey, status.Nonce)
		case GenerateRootDR:
			status, err = client.Sys().GenerateDROperationTokenUpdate(strKey, status.Nonce)
		case GenerateRecovery:
			status, err = client.Sys().GenerateRecoveryOperationTokenUpdate(strKey, status.Nonce)
		}
		if err != nil {
			return "", err
		}
	}
	if !status.Complete {
		return "", fmt.Errorf("generate root operation did not end successfully")
	}

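	// The server returns the token base64-encoded and XOR'd with the OTP;
	// reverse both steps to recover the plaintext token.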
	tokenBytes, err := base64.RawStdEncoding.DecodeString(status.EncodedToken)
	if err != nil {
		return "", err
	}
	tokenBytes, err = xor.XORBytes(tokenBytes, []byte(otp))
	if err != nil {
		return "", err
	}
	return string(tokenBytes), nil
}