| // Copyright (c) HashiCorp, Inc. |
| // SPDX-License-Identifier: MPL-2.0 |
| |
| package testhelpers |
| |
| import ( |
| "context" |
| "encoding/base64" |
| "encoding/json" |
| "errors" |
| "fmt" |
| "io/ioutil" |
| "math/rand" |
| "net/url" |
| "os" |
| "strings" |
| "sync/atomic" |
| "time" |
| |
| "github.com/armon/go-metrics" |
| raftlib "github.com/hashicorp/raft" |
| "github.com/hashicorp/vault/api" |
| "github.com/hashicorp/vault/helper/metricsutil" |
| "github.com/hashicorp/vault/helper/namespace" |
| "github.com/hashicorp/vault/physical/raft" |
| "github.com/hashicorp/vault/sdk/helper/xor" |
| "github.com/hashicorp/vault/vault" |
| "github.com/mitchellh/go-testing-interface" |
| ) |
| |
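| // GenerateRootKind selects which kind of operation token GenerateRoot |
| // produces: a regular root token, a DR operation token, or a recovery |
| // operation token. |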
| type GenerateRootKind int |
| |
| const ( |
| GenerateRootRegular GenerateRootKind = iota |
| GenerateRootDR |
| GenerateRecovery |
| ) |
| |
| // GenerateRoot generates a root token on the target cluster. |
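| // |
| // A minimal usage sketch (assuming a test cluster built elsewhere in the test, |
| // e.g. with vault.NewTestCluster): |
| // |
| //    token := testhelpers.GenerateRoot(t, cluster, testhelpers.GenerateRootRegular) |
| //    client := cluster.Cores[0].Client |
| //    client.SetToken(token) |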
| func GenerateRoot(t testing.T, cluster *vault.TestCluster, kind GenerateRootKind) string { |
| t.Helper() |
| token, err := GenerateRootWithError(t, cluster, kind) |
| if err != nil { |
| t.Fatal(err) |
| } |
| return token |
| } |
| |
| func GenerateRootWithError(t testing.T, cluster *vault.TestCluster, kind GenerateRootKind) (string, error) { |
| t.Helper() |
| // If recovery keys are supported, use them for root token generation instead of barrier keys |
| var keys [][]byte |
| if cluster.Cores[0].SealAccess().RecoveryKeySupported() { |
| keys = cluster.RecoveryKeys |
| } else { |
| keys = cluster.BarrierKeys |
| } |
| client := cluster.Cores[0].Client |
| oldNS := client.Namespace() |
| defer client.SetNamespace(oldNS) |
| client.ClearNamespace() |
| |
| var err error |
| var status *api.GenerateRootStatusResponse |
| switch kind { |
| case GenerateRootRegular: |
| status, err = client.Sys().GenerateRootInit("", "") |
| case GenerateRootDR: |
| status, err = client.Sys().GenerateDROperationTokenInit("", "") |
| case GenerateRecovery: |
| status, err = client.Sys().GenerateRecoveryOperationTokenInit("", "") |
| } |
| if err != nil { |
| return "", err |
| } |
| |
| if status.Required > len(keys) { |
| return "", fmt.Errorf("need more keys than have, need %d have %d", status.Required, len(keys)) |
| } |
| |
| otp := status.OTP |
| |
| for i, key := range keys { |
| if i >= status.Required { |
| break |
| } |
| |
| strKey := base64.StdEncoding.EncodeToString(key) |
| switch kind { |
| case GenerateRootRegular: |
| status, err = client.Sys().GenerateRootUpdate(strKey, status.Nonce) |
| case GenerateRootDR: |
| status, err = client.Sys().GenerateDROperationTokenUpdate(strKey, status.Nonce) |
| case GenerateRecovery: |
| status, err = client.Sys().GenerateRecoveryOperationTokenUpdate(strKey, status.Nonce) |
| } |
| if err != nil { |
| return "", err |
| } |
| } |
| if !status.Complete { |
| return "", errors.New("generate root operation did not end successfully") |
| } |
| |
| tokenBytes, err := base64.RawStdEncoding.DecodeString(status.EncodedToken) |
| if err != nil { |
| return "", err |
| } |
| tokenBytes, err = xor.XORBytes(tokenBytes, []byte(otp)) |
| if err != nil { |
| return "", err |
| } |
| return string(tokenBytes), nil |
| } |
| |
| // RandomWithPrefix generates a unique name with the given prefix, for |
| // randomizing names in acceptance tests. |
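| // |
| // For example, a test might create a uniquely named mount: |
| // |
| //    mountPath := testhelpers.RandomWithPrefix("kv") |
| //    err := client.Sys().Mount(mountPath, &api.MountInput{Type: "kv"}) |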
| func RandomWithPrefix(name string) string { |
| return fmt.Sprintf("%s-%d", name, rand.New(rand.NewSource(time.Now().UnixNano())).Int()) |
| } |
| |
| func EnsureCoresSealed(t testing.T, c *vault.TestCluster) { |
| t.Helper() |
| for _, core := range c.Cores { |
| EnsureCoreSealed(t, core) |
| } |
| } |
| |
| func EnsureCoreSealed(t testing.T, core *vault.TestClusterCore) { |
| t.Helper() |
| core.Seal(t) |
| timeout := time.Now().Add(60 * time.Second) |
| for { |
| if time.Now().After(timeout) { |
| t.Fatal("timeout waiting for core to seal") |
| } |
| if core.Core.Sealed() { |
| break |
| } |
| time.Sleep(250 * time.Millisecond) |
| } |
| } |
| |
| func EnsureCoresUnsealed(t testing.T, c *vault.TestCluster) { |
| t.Helper() |
| for i, core := range c.Cores { |
| err := AttemptUnsealCore(c, core) |
| if err != nil { |
| t.Fatalf("failed to unseal core %d: %v", i, err) |
| } |
| } |
| } |
| |
| func EnsureCoreUnsealed(t testing.T, c *vault.TestCluster, core *vault.TestClusterCore) { |
| t.Helper() |
| err := AttemptUnsealCore(c, core) |
| if err != nil { |
| t.Fatalf("failed to unseal core: %v", err) |
| } |
| } |
| |
| func AttemptUnsealCores(c *vault.TestCluster) error { |
| for i, core := range c.Cores { |
| err := AttemptUnsealCore(c, core) |
| if err != nil { |
| return fmt.Errorf("failed to unseal core %d: %v", i, err) |
| } |
| } |
| return nil |
| } |
| |
| func AttemptUnsealCore(c *vault.TestCluster, core *vault.TestClusterCore) error { |
| if !core.Sealed() { |
| return nil |
| } |
| |
| core.SealAccess().ClearCaches(context.Background()) |
| if err := core.UnsealWithStoredKeys(context.Background()); err != nil { |
| return err |
| } |
| |
| client := core.Client |
| oldNS := client.Namespace() |
| defer client.SetNamespace(oldNS) |
| client.ClearNamespace() |
| |
| client.Sys().ResetUnsealProcess() |
| for j := 0; j < len(c.BarrierKeys); j++ { |
| statusResp, err := client.Sys().Unseal(base64.StdEncoding.EncodeToString(c.BarrierKeys[j])) |
| if err != nil { |
| // Sometimes the core has already unsealed on its own by the time we get |
| // here (e.g. on DR secondaries), which makes this call fail; check the |
| // seal status again before treating it as an error. |
| if core.Sealed() { |
| return err |
| } |
| return nil |
| } |
| if statusResp == nil { |
| return fmt.Errorf("nil status response during unseal") |
| } |
| if !statusResp.Sealed { |
| break |
| } |
| } |
| if core.Sealed() { |
| return fmt.Errorf("core is still sealed") |
| } |
| return nil |
| } |
| |
| func EnsureStableActiveNode(t testing.T, cluster *vault.TestCluster) { |
| t.Helper() |
| deriveStableActiveCore(t, cluster) |
| } |
| |
| func DeriveStableActiveCore(t testing.T, cluster *vault.TestCluster) *vault.TestClusterCore { |
| t.Helper() |
| return deriveStableActiveCore(t, cluster) |
| } |
| |
| func deriveStableActiveCore(t testing.T, cluster *vault.TestCluster) *vault.TestClusterCore { |
| t.Helper() |
| activeCore := DeriveActiveCore(t, cluster) |
| minDuration := time.NewTimer(3 * time.Second) |
| |
| for i := 0; i < 60; i++ { |
| leaderResp, err := activeCore.Client.Sys().Leader() |
| if err != nil { |
| t.Fatal(err) |
| } |
| if !leaderResp.IsSelf { |
| minDuration.Reset(3 * time.Second) |
| } |
| time.Sleep(200 * time.Millisecond) |
| } |
| |
| select { |
| case <-minDuration.C: |
| default: |
| if stopped := minDuration.Stop(); stopped { |
| t.Fatal("unstable active node") |
| } |
| // Drain the value |
| <-minDuration.C |
| } |
| |
| return activeCore |
| } |
| |
| func DeriveActiveCore(t testing.T, cluster *vault.TestCluster) *vault.TestClusterCore { |
| t.Helper() |
| for i := 0; i < 60; i++ { |
| for _, core := range cluster.Cores { |
| oldNS := core.Client.Namespace() |
| core.Client.ClearNamespace() |
| leaderResp, err := core.Client.Sys().Leader() |
| core.Client.SetNamespace(oldNS) |
| if err != nil { |
| t.Fatal(err) |
| } |
| if leaderResp.IsSelf { |
| return core |
| } |
| } |
| time.Sleep(1 * time.Second) |
| } |
| t.Fatal("could not derive the active core") |
| return nil |
| } |
| |
| func DeriveStandbyCores(t testing.T, cluster *vault.TestCluster) []*vault.TestClusterCore { |
| t.Helper() |
| cores := make([]*vault.TestClusterCore, 0, 2) |
| for _, core := range cluster.Cores { |
| oldNS := core.Client.Namespace() |
| core.Client.ClearNamespace() |
| leaderResp, err := core.Client.Sys().Leader() |
| core.Client.SetNamespace(oldNS) |
| if err != nil { |
| t.Fatal(err) |
| } |
| if !leaderResp.IsSelf { |
| cores = append(cores, core) |
| } |
| } |
| |
| return cores |
| } |
| |
| func WaitForNCoresUnsealed(t testing.T, cluster *vault.TestCluster, n int) { |
| t.Helper() |
| for i := 0; i < 30; i++ { |
| unsealed := 0 |
| for _, core := range cluster.Cores { |
| if !core.Core.Sealed() { |
| unsealed++ |
| } |
| } |
| |
| if unsealed >= n { |
| return |
| } |
| time.Sleep(time.Second) |
| } |
| |
| t.Fatalf("%d cores were not unsealed", n) |
| } |
| |
| func SealCores(t testing.T, cluster *vault.TestCluster) { |
| t.Helper() |
| for _, core := range cluster.Cores { |
| if err := core.Shutdown(); err != nil { |
| t.Fatal(err) |
| } |
| timeout := time.Now().Add(3 * time.Second) |
| for { |
| if time.Now().After(timeout) { |
| t.Fatal("timeout waiting for core to seal") |
| } |
| if core.Sealed() { |
| break |
| } |
| time.Sleep(100 * time.Millisecond) |
| } |
| } |
| } |
| |
| func WaitForNCoresSealed(t testing.T, cluster *vault.TestCluster, n int) { |
| t.Helper() |
| for i := 0; i < 60; i++ { |
| sealed := 0 |
| for _, core := range cluster.Cores { |
| if core.Core.Sealed() { |
| sealed++ |
| } |
| } |
| |
| if sealed >= n { |
| return |
| } |
| time.Sleep(time.Second) |
| } |
| |
| t.Fatalf("%d cores were not sealed", n) |
| } |
| |
| func WaitForActiveNode(t testing.T, cluster *vault.TestCluster) *vault.TestClusterCore { |
| t.Helper() |
| for i := 0; i < 60; i++ { |
| for _, core := range cluster.Cores { |
| if standby, _ := core.Core.Standby(); !standby { |
| return core |
| } |
| } |
| |
| time.Sleep(time.Second) |
| } |
| |
| t.Fatalf("node did not become active") |
| return nil |
| } |
| |
| func WaitForStandbyNode(t testing.T, core *vault.TestClusterCore) { |
| t.Helper() |
| for i := 0; i < 30; i++ { |
| if isLeader, _, clusterAddr, _ := core.Core.Leader(); !isLeader && clusterAddr != "" { |
| return |
| } |
| if core.Core.ActiveNodeReplicationState() == 0 { |
| return |
| } |
| |
| time.Sleep(time.Second) |
| } |
| |
| t.Fatalf("node did not become standby") |
| } |
| |
| func RekeyCluster(t testing.T, cluster *vault.TestCluster, recovery bool) [][]byte { |
| t.Helper() |
| cluster.Logger.Info("rekeying cluster", "recovery", recovery) |
| client := cluster.Cores[0].Client |
| |
| initFunc := client.Sys().RekeyInit |
| if recovery { |
| initFunc = client.Sys().RekeyRecoveryKeyInit |
| } |
| init, err := initFunc(&api.RekeyInitRequest{ |
| SecretShares: 5, |
| SecretThreshold: 3, |
| }) |
| if err != nil { |
| t.Fatal(err) |
| } |
| |
| var statusResp *api.RekeyUpdateResponse |
| keys := cluster.BarrierKeys |
| if cluster.Cores[0].Core.SealAccess().RecoveryKeySupported() { |
| keys = cluster.RecoveryKeys |
| } |
| |
| updateFunc := client.Sys().RekeyUpdate |
| if recovery { |
| updateFunc = client.Sys().RekeyRecoveryKeyUpdate |
| } |
| for j := 0; j < len(keys); j++ { |
| statusResp, err = updateFunc(base64.StdEncoding.EncodeToString(keys[j]), init.Nonce) |
| if err != nil { |
| t.Fatal(err) |
| } |
| if statusResp == nil { |
| t.Fatal("nil status response during unseal") |
| } |
| if statusResp.Complete { |
| break |
| } |
| } |
| cluster.Logger.Info("cluster rekeyed", "recovery", recovery) |
| |
| if cluster.Cores[0].Core.SealAccess().RecoveryKeySupported() && !recovery { |
| return nil |
| } |
| if len(statusResp.KeysB64) != 5 { |
| t.Fatal("wrong number of keys") |
| } |
| |
| newKeys := make([][]byte, 5) |
| for i, key := range statusResp.KeysB64 { |
| newKeys[i], err = base64.StdEncoding.DecodeString(key) |
| if err != nil { |
| t.Fatal(err) |
| } |
| } |
| return newKeys |
| } |
| |
| // TestRaftServerAddressProvider is a ServerAddressProvider that uses the |
| // ClusterAddr() of each node to provide raft addresses. |
| // |
| // Note that TestRaftServerAddressProvider should only be used in cases where |
| // cores that are part of a raft configuration have already had |
| // startClusterListener() called (via either unsealing or raft joining). |
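| // |
| // A typical wiring (a sketch of what RaftClusterJoinNodes below does) installs |
| // the provider on each core's raft backend before joining: |
| // |
| //    provider := &testhelpers.TestRaftServerAddressProvider{Cluster: cluster} |
| //    core.UnderlyingRawStorage.(*raft.RaftBackend).SetServerAddressProvider(provider) |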
| type TestRaftServerAddressProvider struct { |
| Cluster *vault.TestCluster |
| } |
| |
| func (p *TestRaftServerAddressProvider) ServerAddr(id raftlib.ServerID) (raftlib.ServerAddress, error) { |
| for _, core := range p.Cluster.Cores { |
| if core.NodeID == string(id) { |
| parsed, err := url.Parse(core.ClusterAddr()) |
| if err != nil { |
| return "", err |
| } |
| |
| return raftlib.ServerAddress(parsed.Host), nil |
| } |
| } |
| |
| return "", errors.New("could not find cluster addr") |
| } |
| |
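| // RaftClusterJoinNodes installs a TestRaftServerAddressProvider on every core, |
| // joins all non-leader cores to the raft cluster led by cluster.Cores[0], and |
| // unseals them. |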
| func RaftClusterJoinNodes(t testing.T, cluster *vault.TestCluster) { |
| addressProvider := &TestRaftServerAddressProvider{Cluster: cluster} |
| |
| atomic.StoreUint32(&vault.TestingUpdateClusterAddr, 1) |
| |
| leader := cluster.Cores[0] |
| |
| // Seal the leader so we can install an address provider |
| { |
| EnsureCoreSealed(t, leader) |
| leader.UnderlyingRawStorage.(*raft.RaftBackend).SetServerAddressProvider(addressProvider) |
| cluster.UnsealCore(t, leader) |
| vault.TestWaitActive(t, leader.Core) |
| } |
| |
| leaderInfos := []*raft.LeaderJoinInfo{ |
| { |
| LeaderAPIAddr: leader.Client.Address(), |
| TLSConfig: leader.TLSConfig(), |
| }, |
| } |
| |
| // Join followers |
| for i := 1; i < len(cluster.Cores); i++ { |
| core := cluster.Cores[i] |
| core.UnderlyingRawStorage.(*raft.RaftBackend).SetServerAddressProvider(addressProvider) |
| _, err := core.JoinRaftCluster(namespace.RootContext(context.Background()), leaderInfos, false) |
| if err != nil { |
| t.Fatal(err) |
| } |
| |
| cluster.UnsealCore(t, core) |
| } |
| |
| WaitForNCoresUnsealed(t, cluster, len(cluster.Cores)) |
| } |
| |
| // HardcodedServerAddressProvider is a ServerAddressProvider that uses |
| // a hardcoded map of raft node addresses. |
| // |
| // It is useful in cases where the raft configuration is known ahead of time, |
| // but some of the cores have not yet had startClusterListener() called (via |
| // either unsealing or raft joining), and thus do not yet have a ClusterAddr() |
| // assigned. |
| type HardcodedServerAddressProvider struct { |
| Entries map[raftlib.ServerID]raftlib.ServerAddress |
| } |
| |
| func (p *HardcodedServerAddressProvider) ServerAddr(id raftlib.ServerID) (raftlib.ServerAddress, error) { |
| if addr, ok := p.Entries[id]; ok { |
| return addr, nil |
| } |
| return "", errors.New("could not find cluster addr") |
| } |
| |
| // NewHardcodedServerAddressProvider is a convenience function that makes a |
| // ServerAddressProvider from a given cluster address base port. |
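| // |
| // For example, with three cores whose cluster listeners start at port 8201 |
| // (an assumed base port, for illustration only): |
| // |
| //    provider := testhelpers.NewHardcodedServerAddressProvider(3, 8201) |
| //    core.UnderlyingRawStorage.(*raft.RaftBackend).SetServerAddressProvider(provider) |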
| func NewHardcodedServerAddressProvider(numCores, baseClusterPort int) raftlib.ServerAddressProvider { |
| entries := make(map[raftlib.ServerID]raftlib.ServerAddress) |
| |
| for i := 0; i < numCores; i++ { |
| id := fmt.Sprintf("core-%d", i) |
| addr := fmt.Sprintf("127.0.0.1:%d", baseClusterPort+i) |
| entries[raftlib.ServerID(id)] = raftlib.ServerAddress(addr) |
| } |
| |
| return &HardcodedServerAddressProvider{ |
| entries, |
| } |
| } |
| |
| // VerifyRaftConfiguration checks that we have a valid raft configuration, i.e. |
| // the correct number of servers, having the correct NodeIDs, and exactly one |
| // leader. |
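| // |
| // Since raft membership can lag, a caller can poll until the configuration |
| // converges, for example: |
| // |
| //    testhelpers.RetryUntil(t, 15*time.Second, func() error { |
| //        return testhelpers.VerifyRaftConfiguration(core, len(cluster.Cores)) |
| //    }) |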
| func VerifyRaftConfiguration(core *vault.TestClusterCore, numCores int) error { |
| backend := core.UnderlyingRawStorage.(*raft.RaftBackend) |
| ctx := namespace.RootContext(context.Background()) |
| config, err := backend.GetConfiguration(ctx) |
| if err != nil { |
| return err |
| } |
| |
| servers := config.Servers |
| if len(servers) != numCores { |
| return fmt.Errorf("Found %d servers, not %d", len(servers), numCores) |
| } |
| |
| leaders := 0 |
| for i, s := range servers { |
| if s.NodeID != fmt.Sprintf("core-%d", i) { |
| return fmt.Errorf("Found unexpected node ID %q", s.NodeID) |
| } |
| if s.Leader { |
| leaders++ |
| } |
| } |
| |
| if leaders != 1 { |
| return fmt.Errorf("Found %d leaders", leaders) |
| } |
| |
| return nil |
| } |
| |
| func RaftAppliedIndex(core *vault.TestClusterCore) uint64 { |
| return core.UnderlyingRawStorage.(*raft.RaftBackend).AppliedIndex() |
| } |
| |
| func WaitForRaftApply(t testing.T, core *vault.TestClusterCore, index uint64) { |
| t.Helper() |
| |
| backend := core.UnderlyingRawStorage.(*raft.RaftBackend) |
| for i := 0; i < 30; i++ { |
| if backend.AppliedIndex() >= index { |
| return |
| } |
| |
| time.Sleep(time.Second) |
| } |
| |
| t.Fatalf("node did not apply index") |
| } |
| |
| // AwaitLeader waits for one of the cluster's nodes to become leader. |
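| // |
| // It returns the index of the leader within cluster.Cores, e.g.: |
| // |
| //    leaderIdx, err := testhelpers.AwaitLeader(t, cluster) |
| //    if err != nil { |
| //        t.Fatal(err) |
| //    } |
| //    leader := cluster.Cores[leaderIdx] |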
| func AwaitLeader(t testing.T, cluster *vault.TestCluster) (int, error) { |
| timeout := time.Now().Add(60 * time.Second) |
| for { |
| if time.Now().After(timeout) { |
| break |
| } |
| |
| for i, core := range cluster.Cores { |
| if core.Core.Sealed() { |
| continue |
| } |
| |
| isLeader, _, _, _ := core.Leader() |
| if isLeader { |
| return i, nil |
| } |
| } |
| |
| time.Sleep(time.Second) |
| } |
| |
| return 0, fmt.Errorf("timeout waiting leader") |
| } |
| |
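| // GenerateDebugLogs spawns a goroutine that mounts and unmounts a kv engine at |
| // "foo" once per second, generating request traffic (and therefore log and audit |
| // activity) until the returned channel is closed: |
| // |
| //    stop := testhelpers.GenerateDebugLogs(t, client) |
| //    defer close(stop) |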
| func GenerateDebugLogs(t testing.T, client *api.Client) chan struct{} { |
| t.Helper() |
| |
| stopCh := make(chan struct{}) |
| |
| go func() { |
| ticker := time.NewTicker(time.Second) |
| defer ticker.Stop() |
| for { |
| select { |
| case <-stopCh: |
| return |
| case <-ticker.C: |
| err := client.Sys().Mount("foo", &api.MountInput{ |
| Type: "kv", |
| Options: map[string]string{ |
| "version": "1", |
| }, |
| }) |
| if err != nil { |
| // t.Fatal must not be called from a non-test goroutine; report and stop. |
| t.Errorf("failed to mount kv engine: %v", err) |
| return |
| } |
| |
| err = client.Sys().Unmount("foo") |
| if err != nil { |
| t.Errorf("failed to unmount kv engine: %v", err) |
| return |
| } |
| } |
| } |
| }() |
| |
| return stopCh |
| } |
| |
| // VerifyRaftPeers verifies that the raft configuration contains a given set of peers. |
| // The `expected` map is keyed by the expected node IDs. Every node ID found in the |
| // raft configuration is deleted from the map; any remaining entries produce an error |
| // return so that the caller can poll for an expected configuration. |
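| // |
| // For example, to wait until three specific nodes are present: |
| // |
| //    testhelpers.RetryUntil(t, 15*time.Second, func() error { |
| //        return testhelpers.VerifyRaftPeers(t, client, map[string]bool{ |
| //            "core-0": true, |
| //            "core-1": true, |
| //            "core-2": true, |
| //        }) |
| //    }) |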
| func VerifyRaftPeers(t testing.T, client *api.Client, expected map[string]bool) error { |
| t.Helper() |
| |
| resp, err := client.Logical().Read("sys/storage/raft/configuration") |
| if err != nil { |
| t.Fatalf("error reading raft config: %v", err) |
| } |
| |
| if resp == nil || resp.Data == nil { |
| t.Fatal("missing response data") |
| } |
| |
| config, ok := resp.Data["config"].(map[string]interface{}) |
| if !ok { |
| t.Fatal("missing config in response data") |
| } |
| |
| servers, ok := config["servers"].([]interface{}) |
| if !ok { |
| t.Fatal("missing servers in response data config") |
| } |
| |
| // Iterate through the servers and remove the node found in the response |
| // from the expected collection |
| for _, s := range servers { |
| server := s.(map[string]interface{}) |
| delete(expected, server["node_id"].(string)) |
| } |
| |
| // If the collection is non-empty, it means that the peer was not found in |
| // the response. |
| if len(expected) != 0 { |
| return fmt.Errorf("failed to read configuration successfully, expected peers not found in configuration list: %v", expected) |
| } |
| |
| return nil |
| } |
| |
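| // TestMetricSinkProvider returns a provider function that builds an in-memory |
| // metrics sink (with effectively non-expiring intervals) wrapped in a |
| // ClusterMetricSink for the named cluster, plus a MetricsHelper for inspecting |
| // the recorded metrics in tests. |
| // |
| // A minimal sketch of calling the provider directly: |
| // |
| //    provider := testhelpers.TestMetricSinkProvider(time.Minute) |
| //    sink, helper := provider("test-cluster") |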
| func TestMetricSinkProvider(gaugeInterval time.Duration) func(string) (*metricsutil.ClusterMetricSink, *metricsutil.MetricsHelper) { |
| return func(clusterName string) (*metricsutil.ClusterMetricSink, *metricsutil.MetricsHelper) { |
| inm := metrics.NewInmemSink(1000000*time.Hour, 2000000*time.Hour) |
| clusterSink := metricsutil.NewClusterMetricSink(clusterName, inm) |
| clusterSink.GaugeInterval = gaugeInterval |
| return clusterSink, metricsutil.NewMetricsHelper(inm, false) |
| } |
| } |
| |
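| // SysMetricsReq issues a GET to /v1/sys/metrics and unmarshals the JSON response. |
| // If unauth is true the request is sent without a token, which is useful for |
| // testing unauthenticated access to the metrics endpoint. |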
| func SysMetricsReq(client *api.Client, cluster *vault.TestCluster, unauth bool) (*SysMetricsJSON, error) { |
| r := client.NewRequest("GET", "/v1/sys/metrics") |
| if !unauth { |
| r.Headers.Set("X-Vault-Token", cluster.RootToken) |
| } |
| var data SysMetricsJSON |
| resp, err := client.RawRequestWithContext(context.Background(), r) |
| if err != nil { |
| return nil, err |
| } |
| defer resp.Body.Close() |
| bodyBytes, err := io.ReadAll(resp.Body) |
| if err != nil { |
| return nil, err |
| } |
| if err := json.Unmarshal(bodyBytes, &data); err != nil { |
| return nil, fmt.Errorf("failed to unmarshal sys/metrics response: %w", err) |
| } |
| return &data, nil |
| } |
| |
| type SysMetricsJSON struct { |
| Gauges []gaugeJSON `json:"Gauges"` |
| Counters []counterJSON `json:"Counters"` |
| |
| // note: this is referred to as a "Summary" type in our telemetry docs, but |
| // the field name in the JSON is "Samples" |
| Summaries []summaryJSON `json:"Samples"` |
| } |
| |
| type baseInfoJSON struct { |
| Name string `json:"Name"` |
| Labels map[string]interface{} `json:"Labels"` |
| } |
| |
| type gaugeJSON struct { |
| baseInfoJSON |
| Value int `json:"Value"` |
| } |
| |
| type counterJSON struct { |
| baseInfoJSON |
| Count int `json:"Count"` |
| Rate float64 `json:"Rate"` |
| Sum int `json:"Sum"` |
| Min int `json:"Min"` |
| Max int `json:"Max"` |
| Mean float64 `json:"Mean"` |
| Stddev float64 `json:"Stddev"` |
| } |
| |
| type summaryJSON struct { |
| baseInfoJSON |
| Count int `json:"Count"` |
| Rate float64 `json:"Rate"` |
| Sum float64 `json:"Sum"` |
| Min float64 `json:"Min"` |
| Max float64 `json:"Max"` |
| Mean float64 `json:"Mean"` |
| Stddev float64 `json:"Stddev"` |
| } |
| |
| // SetNonRootToken creates a token with a fairly generic policy and sets it on the |
| // given client. This is useful if a test needs to examine differing behavior based |
| // on whether a root token is passed with the request. |
| func SetNonRootToken(client *api.Client) error { |
| policy := `path "*" { capabilities = ["create", "update", "read"] }` |
| if err := client.Sys().PutPolicy("policy", policy); err != nil { |
| return fmt.Errorf("error putting policy: %v", err) |
| } |
| |
| secret, err := client.Auth().Token().Create(&api.TokenCreateRequest{ |
| Policies: []string{"policy"}, |
| TTL: "30m", |
| }) |
| if err != nil { |
| return fmt.Errorf("error creating token secret: %v", err) |
| } |
| |
| if secret == nil || secret.Auth == nil || secret.Auth.ClientToken == "" { |
| return fmt.Errorf("missing token auth data") |
| } |
| |
| client.SetToken(secret.Auth.ClientToken) |
| return nil |
| } |
| |
| // RetryUntilAtCadence runs f every sleepTime interval until it returns nil or the |
| // timeout is reached. If f has not returned nil by then, it calls t.Fatal. |
| func RetryUntilAtCadence(t testing.T, timeout, sleepTime time.Duration, f func() error) { |
| t.Helper() |
| deadline := time.Now().Add(timeout) |
| var err error |
| for time.Now().Before(deadline) { |
| if err = f(); err == nil { |
| return |
| } |
| time.Sleep(sleepTime) |
| } |
| t.Fatalf("did not complete before deadline, err: %v", err) |
| } |
| |
| // RetryUntil runs f every 100ms until it returns nil or the timeout is reached. |
| // If f has not returned nil by then, it calls t.Fatal. |
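| // |
| // A typical use is waiting for some cluster state to converge, for example |
| // (assuming a raftIndex captured earlier in the test): |
| // |
| //    testhelpers.RetryUntil(t, 10*time.Second, func() error { |
| //        if testhelpers.RaftAppliedIndex(core) < raftIndex { |
| //            return fmt.Errorf("index not yet applied") |
| //        } |
| //        return nil |
| //    }) |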
| func RetryUntil(t testing.T, timeout time.Duration, f func() error) { |
| t.Helper() |
| deadline := time.Now().Add(timeout) |
| var err error |
| for time.Now().Before(deadline) { |
| if err = f(); err == nil { |
| return |
| } |
| time.Sleep(100 * time.Millisecond) |
| } |
| t.Fatalf("did not complete before deadline, err: %v", err) |
| } |
| |
| // CreateEntityAndAlias clones an existing client and creates an entity/alias. |
| // It returns the cloned client, entityID, and aliasID. |
| func CreateEntityAndAlias(t testing.T, client *api.Client, mountAccessor, entityName, aliasName string) (*api.Client, string, string) { |
| t.Helper() |
| userClient, err := client.Clone() |
| if err != nil { |
| t.Fatalf("failed to clone the client:%v", err) |
| } |
| userClient.SetToken(client.Token()) |
| |
| resp, err := client.Logical().WriteWithContext(context.Background(), "identity/entity", map[string]interface{}{ |
| "name": entityName, |
| }) |
| if err != nil { |
| t.Fatalf("failed to create an entity:%v", err) |
| } |
| entityID := resp.Data["id"].(string) |
| |
| aliasResp, err := client.Logical().WriteWithContext(context.Background(), "identity/entity-alias", map[string]interface{}{ |
| "name": aliasName, |
| "canonical_id": entityID, |
| "mount_accessor": mountAccessor, |
| }) |
| if err != nil { |
| t.Fatalf("failed to create an entity alias:%v", err) |
| } |
| aliasID := aliasResp.Data["id"].(string) |
| if aliasID == "" { |
| t.Fatal("Alias ID not present in response") |
| } |
| _, err = client.Logical().WriteWithContext(context.Background(), fmt.Sprintf("auth/userpass/users/%s", aliasName), map[string]interface{}{ |
| "password": "testpassword", |
| }) |
| if err != nil { |
| t.Fatalf("failed to configure userpass backend: %v", err) |
| } |
| |
| return userClient, entityID, aliasID |
| } |
| |
| // SetupTOTPMount enables the totp secrets engine by mounting it. This requires |
| // that the test cluster has a totp backend available. |
| func SetupTOTPMount(t testing.T, client *api.Client) { |
| t.Helper() |
| // Mount the TOTP backend |
| mountInfo := &api.MountInput{ |
| Type: "totp", |
| } |
| if err := client.Sys().Mount("totp", mountInfo); err != nil { |
| t.Fatalf("failed to mount totp backend: %v", err) |
| } |
| } |
| |
| // SetupTOTPMethod configures the TOTP secrets engine with a provided config map. |
| func SetupTOTPMethod(t testing.T, client *api.Client, config map[string]interface{}) string { |
| t.Helper() |
| |
| resp1, err := client.Logical().Write("identity/mfa/method/totp", config) |
| |
| if err != nil || resp1 == nil { |
| t.Fatalf("bad: resp: %#v\n err: %v", resp1, err) |
| } |
| |
| methodID := resp1.Data["method_id"].(string) |
| if methodID == "" { |
| t.Fatalf("method ID is empty") |
| } |
| |
| return methodID |
| } |
| |
| // SetupMFALoginEnforcement configures a single enforcement method using the |
| // provided config map. The "name" field is required in the config map. |
| func SetupMFALoginEnforcement(t testing.T, client *api.Client, config map[string]interface{}) { |
| t.Helper() |
| enfName, ok := config["name"] |
| if !ok { |
| t.Fatalf("couldn't find name in login-enforcement config") |
| } |
| _, err := client.Logical().WriteWithContext(context.Background(), fmt.Sprintf("identity/mfa/login-enforcement/%s", enfName), config) |
| if err != nil { |
| t.Fatalf("failed to configure MFAEnforcementConfig: %v", err) |
| } |
| } |
| |
| // SetupUserpassMountAccessor sets up userpass auth and returns its mount |
| // accessor. This requires that the test cluster has a "userpass" auth method |
| // available. |
| func SetupUserpassMountAccessor(t testing.T, client *api.Client) string { |
| t.Helper() |
| // Enable Userpass authentication |
| err := client.Sys().EnableAuthWithOptions("userpass", &api.EnableAuthOptions{ |
| Type: "userpass", |
| }) |
| if err != nil { |
| t.Fatalf("failed to enable userpass auth: %v", err) |
| } |
| |
| auths, err := client.Sys().ListAuthWithContext(context.Background()) |
| if err != nil { |
| t.Fatalf("failed to list auth methods: %v", err) |
| } |
| if auths == nil || auths["userpass/"] == nil { |
| t.Fatalf("failed to get userpass mount accessor") |
| } |
| |
| return auths["userpass/"].Accessor |
| } |
| |
| // RegisterEntityInTOTPEngine registers an entity with a methodID and returns |
| // the generated name. |
| func RegisterEntityInTOTPEngine(t testing.T, client *api.Client, entityID, methodID string) string { |
| t.Helper() |
| totpGenName := fmt.Sprintf("%s-%s", entityID, methodID) |
| secret, err := client.Logical().WriteWithContext(context.Background(), "identity/mfa/method/totp/admin-generate", map[string]interface{}{ |
| "entity_id": entityID, |
| "method_id": methodID, |
| }) |
| if err != nil { |
| t.Fatalf("failed to generate a TOTP secret on an entity: %v", err) |
| } |
| totpURL := secret.Data["url"].(string) |
| if totpURL == "" { |
| t.Fatalf("failed to get TOTP url in secret response: %+v", secret) |
| } |
| _, err = client.Logical().WriteWithContext(context.Background(), fmt.Sprintf("totp/keys/%s", totpGenName), map[string]interface{}{ |
| "url": totpURL, |
| }) |
| if err != nil { |
| t.Fatalf("failed to register a TOTP URL: %v", err) |
| } |
| enfPath := fmt.Sprintf("identity/mfa/login-enforcement/%s", methodID[0:4]) |
| _, err = client.Logical().WriteWithContext(context.Background(), enfPath, map[string]interface{}{ |
| "name": methodID[0:4], |
| "identity_entity_ids": []string{entityID}, |
| "mfa_method_ids": []string{methodID}, |
| }) |
| if err != nil { |
| t.Fatalf("failed to create login enforcement") |
| } |
| |
| return totpGenName |
| } |
| |
| // GetTOTPCodeFromEngine requests a TOTP code from the specified enginePath. |
| func GetTOTPCodeFromEngine(t testing.T, client *api.Client, enginePath string) string { |
| t.Helper() |
| totpPath := fmt.Sprintf("totp/code/%s", enginePath) |
| secret, err := client.Logical().ReadWithContext(context.Background(), totpPath) |
| if err != nil { |
| t.Fatalf("failed to create totp passcode: %v", err) |
| } |
| if secret == nil || secret.Data == nil { |
| t.Fatalf("bad secret returned from %s", totpPath) |
| } |
| return secret.Data["code"].(string) |
| } |
| |
| // SetupLoginMFATOTP sets up TOTP login MFA using some basic configuration and |
| // returns the entity client, the entity ID, and the MFA method ID. |
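| // |
| // A typical flow (a sketch; "totp-method" and the 30-second period are arbitrary |
| // example values) registers the entity in the totp engine and fetches a passcode: |
| // |
| //    entityClient, entityID, methodID := testhelpers.SetupLoginMFATOTP(t, client, "totp-method", 30) |
| //    engineName := testhelpers.RegisterEntityInTOTPEngine(t, client, entityID, methodID) |
| //    passcode := testhelpers.GetTOTPCodeFromEngine(t, client, engineName) |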
| func SetupLoginMFATOTP(t testing.T, client *api.Client, methodName string, waitPeriod int) (*api.Client, string, string) { |
| t.Helper() |
| // Mount the totp secrets engine |
| SetupTOTPMount(t, client) |
| |
| // Create a mount accessor to associate with an entity |
| mountAccessor := SetupUserpassMountAccessor(t, client) |
| |
| // Create a test entity and alias |
| entityClient, entityID, _ := CreateEntityAndAlias(t, client, mountAccessor, "entity1", "testuser1") |
| |
| // Configure a default TOTP method |
| totpConfig := map[string]interface{}{ |
| "issuer": "yCorp", |
| "period": waitPeriod, |
| "algorithm": "SHA256", |
| "digits": 6, |
| "skew": 1, |
| "key_size": 20, |
| "qr_size": 200, |
| "max_validation_attempts": 5, |
| "method_name": methodName, |
| } |
| methodID := SetupTOTPMethod(t, client, totpConfig) |
| |
| // Configure a default login enforcement |
| enforcementConfig := map[string]interface{}{ |
| "auth_method_types": []string{"userpass"}, |
| "name": methodID[0:4], |
| "mfa_method_ids": []string{methodID}, |
| } |
| |
| SetupMFALoginEnforcement(t, client, enforcementConfig) |
| return entityClient, entityID, methodID |
| } |
| |
| func SkipUnlessEnvVarsSet(t testing.T, envVars []string) { |
| t.Helper() |
| |
| for _, i := range envVars { |
| if os.Getenv(i) == "" { |
| t.Skipf("%s must be set for this test to run", strings.Join(envVars, " ")) |
| } |
| } |
| } |
| |
| // WaitForNodesExcludingSelectedStandbys is a variation on WaitForActiveNodeAndStandbys. |
| // It waits for the active node before waiting for standby nodes, but skips any cores |
| // whose indexes are passed as arguments. Specifying index 0 (most likely the leader) is |
| // redundant, since the leader is always checked first regardless of the skipped indexes. |
| // The intended use case is to let a cluster start and become active with one or more |
| // nodes left unjoined, so that tests can exercise scenarios where a node joins later, |
| // e.g. a 4-node cluster where only 3 nodes join initially and the 4th joins mid-test. |
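| // |
| // For example, to wait on everything except core index 3 (which will join later): |
| // |
| //    testhelpers.WaitForNodesExcludingSelectedStandbys(t, cluster, 3) |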
| func WaitForNodesExcludingSelectedStandbys(t testing.T, cluster *vault.TestCluster, indexesToSkip ...int) { |
| WaitForActiveNode(t, cluster) |
| |
| contains := func(elems []int, e int) bool { |
| for _, v := range elems { |
| if v == e { |
| return true |
| } |
| } |
| |
| return false |
| } |
| for i, core := range cluster.Cores { |
| if contains(indexesToSkip, i) { |
| continue |
| } |
| |
| if standby, _ := core.Core.Standby(); standby { |
| WaitForStandbyNode(t, core) |
| } |
| } |
| } |
| |
| // IsLocalOrRegressionTests returns true when the tests are running locally (not in CI), or when |
| // the regression test env var (VAULT_REGRESSION_TESTS) is provided. |
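| // |
| // For example, an expensive test can skip itself unless it is running locally or |
| // regression tests are explicitly enabled: |
| // |
| //    if !testhelpers.IsLocalOrRegressionTests() { |
| //        t.Skip("skipping: set VAULT_REGRESSION_TESTS=true to run this test in CI") |
| //    } |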
| func IsLocalOrRegressionTests() bool { |
| return os.Getenv("CI") == "" || os.Getenv("VAULT_REGRESSION_TESTS") == "true" |
| } |