// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0

package testcluster

import (
"context"
"encoding/base64"
"encoding/hex"
"fmt"
"sync/atomic"
"time"
"github.com/hashicorp/go-multierror"
"github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/sdk/helper/xor"
)

// SealNode seals the node at nodeIdx and waits for it to report itself sealed.
// Note that OSS standbys will not accept seal requests, and Enterprise perf
// standbys may fail them as well if they haven't yet been able to get
// "elected" as perf standbys.
func SealNode(ctx context.Context, cluster VaultCluster, nodeIdx int) error {
if nodeIdx >= len(cluster.Nodes()) {
return fmt.Errorf("invalid nodeIdx %d for cluster", nodeIdx)
}
node := cluster.Nodes()[nodeIdx]
client := node.APIClient()
err := client.Sys().SealWithContext(ctx)
if err != nil {
return err
}
return NodeSealed(ctx, cluster, nodeIdx)
}
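
// SealAllNodes seals every node in the cluster in turn, returning the first
// error encountered.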
func SealAllNodes(ctx context.Context, cluster VaultCluster) error {
for i := range cluster.Nodes() {
if err := SealNode(ctx, cluster, i); err != nil {
return err
}
}
return nil
}
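
// UnsealNode submits each of the cluster's barrier (or recovery) keys to the
// node at nodeIdx and then waits for it to report itself healthy.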
func UnsealNode(ctx context.Context, cluster VaultCluster, nodeIdx int) error {
if nodeIdx >= len(cluster.Nodes()) {
return fmt.Errorf("invalid nodeIdx %d for cluster", nodeIdx)
}
node := cluster.Nodes()[nodeIdx]
client := node.APIClient()
for _, key := range cluster.GetBarrierOrRecoveryKeys() {
_, err := client.Sys().UnsealWithContext(ctx, hex.EncodeToString(key))
if err != nil {
return err
}
}
return NodeHealthy(ctx, cluster, nodeIdx)
}
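
// UnsealAllNodes unseals every node in the cluster in turn, returning the
// first error encountered.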
func UnsealAllNodes(ctx context.Context, cluster VaultCluster) error {
for i := range cluster.Nodes() {
if err := UnsealNode(ctx, cluster, i); err != nil {
return err
}
}
return nil
}
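
// NodeSealed polls the given node's health endpoint until it reports itself
// sealed, returning an error if the context expires first.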
func NodeSealed(ctx context.Context, cluster VaultCluster, nodeIdx int) error {
if nodeIdx >= len(cluster.Nodes()) {
return fmt.Errorf("invalid nodeIdx %d for cluster", nodeIdx)
}
node := cluster.Nodes()[nodeIdx]
client := node.APIClient()
var health *api.HealthResponse
var err error
for ctx.Err() == nil {
health, err = client.Sys().HealthWithContext(ctx)
switch {
case err != nil:
case !health.Sealed:
err = fmt.Errorf("unsealed: %#v", health)
default:
return nil
}
time.Sleep(500 * time.Millisecond)
}
return fmt.Errorf("node %d is not sealed: %v", nodeIdx, err)
}
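
// WaitForNCoresSealed waits until at least n nodes in the cluster report
// themselves sealed, returning an error if the context expires first.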
func WaitForNCoresSealed(ctx context.Context, cluster VaultCluster, n int) error {
ctx, cancel := context.WithCancel(ctx)
defer cancel()
errs := make(chan error)
for i := range cluster.Nodes() {
go func(i int) {
var err error
for ctx.Err() == nil {
err = NodeSealed(ctx, cluster, i)
if err == nil {
errs <- nil
return
}
time.Sleep(100 * time.Millisecond)
}
if err == nil {
err = ctx.Err()
}
errs <- err
}(i)
}
var merr *multierror.Error
var sealed int
for range cluster.Nodes() {
err := <-errs
if err != nil {
merr = multierror.Append(merr, err)
} else {
sealed++
if sealed == n {
return nil
}
}
}
return fmt.Errorf("%d cores were not sealed, errs: %v", n, merr.ErrorOrNil())
}
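
// NodeHealthy polls the given node's health endpoint until it responds and
// reports itself unsealed, returning an error if the context expires first.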
func NodeHealthy(ctx context.Context, cluster VaultCluster, nodeIdx int) error {
if nodeIdx >= len(cluster.Nodes()) {
return fmt.Errorf("invalid nodeIdx %d for cluster", nodeIdx)
}
node := cluster.Nodes()[nodeIdx]
client := node.APIClient()
var health *api.HealthResponse
var err error
for ctx.Err() == nil {
health, err = client.Sys().HealthWithContext(ctx)
switch {
case err != nil:
case health == nil:
err = fmt.Errorf("nil response to health check")
case health.Sealed:
err = fmt.Errorf("sealed: %#v", health)
default:
return nil
}
time.Sleep(500 * time.Millisecond)
}
return fmt.Errorf("node %d is unhealthy: %v", nodeIdx, err)
}
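
// LeaderNode returns the index of the node that currently believes it is the
// active node, or -1 and an error if no node does.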
func LeaderNode(ctx context.Context, cluster VaultCluster) (int, error) {
// Be robust to multiple nodes thinking they are active. This is possible in
// certain network partition situations where the old leader has not yet
// discovered that it has lost leadership. In tests this is only likely to come
// up when we are specifically provoking it, but it could happen at any point
// if leadership flaps or connectivity suffers transient errors, so be robust
// against it. The best solution would be to have some sort of epoch, like the
// raft term, that is guaranteed to be monotonically increasing through
// elections; however, we don't have that abstraction for all HABackends in
// general. The best we have is ActiveTime. In a distributed-systems textbook
// relying on this would be discouraged due to clock-sync issues, but for our
// tests it's likely fine because even if we are running separate Vault
// containers, they all share the same hardware clock.
leaderActiveTimes := make(map[int]time.Time)
for i, node := range cluster.Nodes() {
client := node.APIClient()
ctx, cancel := context.WithTimeout(ctx, 100*time.Millisecond)
resp, err := client.Sys().LeaderWithContext(ctx)
cancel()
if err != nil || resp == nil || !resp.IsSelf {
continue
}
leaderActiveTimes[i] = resp.ActiveTime
}
if len(leaderActiveTimes) == 0 {
return -1, fmt.Errorf("no leader found")
}
// At least one node thinks it is active. If multiple, pick the one with the
// most recent ActiveTime. Note if there is only one then this just returns
// it.
var newestLeaderIdx int
var newestActiveTime time.Time
for i, at := range leaderActiveTimes {
if at.After(newestActiveTime) {
newestActiveTime = at
newestLeaderIdx = i
}
}
return newestLeaderIdx, nil
}
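
// WaitForActiveNode polls the cluster until some node reports itself active,
// returning its index, or the context's error if none does in time.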
func WaitForActiveNode(ctx context.Context, cluster VaultCluster) (int, error) {
for ctx.Err() == nil {
if idx, _ := LeaderNode(ctx, cluster); idx != -1 {
return idx, nil
}
time.Sleep(500 * time.Millisecond)
}
return -1, ctx.Err()
}
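
// WaitForActiveNodeAndPerfStandbys waits until the cluster has exactly one
// active node and every other node is a perf standby that is receiving WALs.
// It does this by mounting a throwaway local KV engine on the leader, writing
// to it repeatedly, and watching each node's leader status until its last
// remote WAL advances.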
func WaitForActiveNodeAndPerfStandbys(ctx context.Context, cluster VaultCluster) error {
logger := cluster.NamedLogger("WaitForActiveNodeAndPerfStandbys")
// This WaitForActiveNode was added because after a Raft cluster is sealed
// and then unsealed, when it comes up it may have a different leader than
// Core0, making this helper fail.
// A sleep before calling WaitForActiveNodeAndPerfStandbys seems to sort
// things out, but so apparently does this. We should be able to eliminate
// this call to WaitForActiveNode by reworking the logic in this method.
leaderIdx, err := WaitForActiveNode(ctx, cluster)
if err != nil {
return err
}
if len(cluster.Nodes()) == 1 {
return nil
}
expectedStandbys := len(cluster.Nodes()) - 1
mountPoint, err := uuid.GenerateUUID()
if err != nil {
return err
}
leaderClient := cluster.Nodes()[leaderIdx].APIClient()
for ctx.Err() == nil {
err = leaderClient.Sys().MountWithContext(ctx, mountPoint, &api.MountInput{
Type: "kv",
Local: true,
})
if err == nil {
break
}
time.Sleep(1 * time.Second)
}
if err != nil {
return fmt.Errorf("unable to mount KV engine: %v", err)
}
path := mountPoint + "/waitforactivenodeandperfstandbys"
var standbys, actives int64
errchan := make(chan error, len(cluster.Nodes()))
for i := range cluster.Nodes() {
go func(coreNo int) {
node := cluster.Nodes()[coreNo]
client := node.APIClient()
val := 1
var err error
defer func() {
errchan <- err
}()
var lastWAL uint64
for ctx.Err() == nil {
_, err = leaderClient.Logical().WriteWithContext(ctx, path, map[string]interface{}{
"bar": val,
})
val++
time.Sleep(250 * time.Millisecond)
if err != nil {
continue
}
var leader *api.LeaderResponse
leader, err = client.Sys().LeaderWithContext(ctx)
if err != nil {
logger.Trace("waiting for core", "core", coreNo, "err", err)
continue
}
switch {
case leader.IsSelf:
logger.Trace("waiting for core", "core", coreNo, "isLeader", true)
atomic.AddInt64(&actives, 1)
return
case leader.PerfStandby && leader.PerfStandbyLastRemoteWAL > 0:
switch {
case lastWAL == 0:
lastWAL = leader.PerfStandbyLastRemoteWAL
logger.Trace("waiting for core", "core", coreNo, "lastRemoteWAL", leader.PerfStandbyLastRemoteWAL, "lastWAL", lastWAL)
case lastWAL < leader.PerfStandbyLastRemoteWAL:
logger.Trace("waiting for core", "core", coreNo, "lastRemoteWAL", leader.PerfStandbyLastRemoteWAL, "lastWAL", lastWAL)
atomic.AddInt64(&standbys, 1)
return
}
default:
logger.Trace("waiting for core", "core", coreNo,
"ha_enabled", leader.HAEnabled,
"is_self", leader.IsSelf,
"perf_standby", leader.PerfStandby,
"perf_standby_remote_wal", leader.PerfStandbyLastRemoteWAL)
}
}
}(i)
}
errs := make([]error, 0, len(cluster.Nodes()))
for range cluster.Nodes() {
errs = append(errs, <-errchan)
}
if actives != 1 || int(standbys) != expectedStandbys {
return fmt.Errorf("expected 1 active core and %d standbys, got %d active and %d standbys, errs: %v",
expectedStandbys, actives, standbys, errs)
}
for ctx.Err() == nil {
err = leaderClient.Sys().UnmountWithContext(ctx, mountPoint)
if err == nil {
break
}
time.Sleep(time.Second)
}
if err != nil {
return fmt.Errorf("unable to unmount KV engine on primary")
}
return nil
}
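
// GenerateRootKind selects which kind of token GenerateRoot produces: a
// regular root token, a DR operation token, or a recovery operation token.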
type GenerateRootKind int

const (
GenerateRootRegular GenerateRootKind = iota
GenerateRootDR
GenerateRecovery
)
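
// GenerateRoot generates a token of the requested kind on the cluster's first
// node, supplying the cluster's barrier or recovery keys. The encoded token
// returned by Vault is base64-decoded and XORed with the OTP before being
// returned.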
func GenerateRoot(cluster VaultCluster, kind GenerateRootKind) (string, error) {
// If recovery keys are supported, use those to perform root token generation instead
keys := cluster.GetBarrierOrRecoveryKeys()
client := cluster.Nodes()[0].APIClient()
var err error
var status *api.GenerateRootStatusResponse
switch kind {
case GenerateRootRegular:
status, err = client.Sys().GenerateRootInit("", "")
case GenerateRootDR:
status, err = client.Sys().GenerateDROperationTokenInit("", "")
case GenerateRecovery:
status, err = client.Sys().GenerateRecoveryOperationTokenInit("", "")
}
if err != nil {
return "", err
}
if status.Required > len(keys) {
return "", fmt.Errorf("need more keys than have, need %d have %d", status.Required, len(keys))
}
otp := status.OTP
for i, key := range keys {
if i >= status.Required {
break
}
strKey := base64.StdEncoding.EncodeToString(key)
switch kind {
case GenerateRootRegular:
status, err = client.Sys().GenerateRootUpdate(strKey, status.Nonce)
case GenerateRootDR:
status, err = client.Sys().GenerateDROperationTokenUpdate(strKey, status.Nonce)
case GenerateRecovery:
status, err = client.Sys().GenerateRecoveryOperationTokenUpdate(strKey, status.Nonce)
}
if err != nil {
return "", err
}
}
if !status.Complete {
return "", fmt.Errorf("generate root operation did not end successfully")
}
tokenBytes, err := base64.RawStdEncoding.DecodeString(status.EncodedToken)
if err != nil {
return "", err
}
tokenBytes, err = xor.XORBytes(tokenBytes, []byte(otp))
if err != nil {
return "", err
}
return string(tokenBytes), nil
}
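
// A typical caller drives these helpers from a test, roughly like the sketch
// below (hypothetical; it assumes a concrete VaultCluster implementation named
// cluster and a *testing.T named t):
//
//	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
//	defer cancel()
//	if err := SealAllNodes(ctx, cluster); err != nil {
//		t.Fatal(err)
//	}
//	if err := UnsealAllNodes(ctx, cluster); err != nil {
//		t.Fatal(err)
//	}
//	if _, err := WaitForActiveNode(ctx, cluster); err != nil {
//		t.Fatal(err)
//	}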