blob: f634be2f39b84e1af2ae8eb4ca07f0d611c3b276 [file] [log] [blame]
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package testing
import (
"context"
"crypto/tls"
"fmt"
"os"
"reflect"
"sort"
"testing"
"github.com/hashicorp/errwrap"
log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/helper/namespace"
"github.com/hashicorp/vault/helper/testhelpers/corehelpers"
"github.com/hashicorp/vault/http"
"github.com/hashicorp/vault/sdk/helper/logging"
"github.com/hashicorp/vault/sdk/logical"
"github.com/hashicorp/vault/sdk/physical/inmem"
"github.com/hashicorp/vault/vault"
)
const (
	// TestEnvVar names the environment variable that must be set to a
	// non-empty value for acceptance test cases (TestCase.AcceptanceTest) to
	// run; otherwise they are skipped.
	TestEnvVar = "VAULT_ACC"
)
// TestCase is a single set of tests to run for a backend. A TestCase
// should generally map 1:1 to each test method for your acceptance
// tests.
//
// Provide either a secrets backend (LogicalBackend or LogicalFactory) or an
// auth backend (CredentialBackend or CredentialFactory), but not both:
// Test fails if neither family is set or if both families are set.
type TestCase struct {
// PreCheck, if non-nil, will be called once before the test case
// runs at all. This can be used for some validation prior to the
// test running.
PreCheck func()
// LogicalBackend is the backend that will be mounted (at "mnt/").
// If both LogicalBackend and LogicalFactory are set, LogicalBackend is used.
LogicalBackend logical.Backend
// LogicalFactory can be used instead of LogicalBackend if the
// backend requires more construction.
LogicalFactory logical.Factory
// CredentialBackend is the backend that will be enabled as an auth method
// (at "auth/mnt/"). If both CredentialBackend and CredentialFactory are
// set, CredentialBackend is used.
CredentialBackend logical.Backend
// CredentialFactory can be used instead of CredentialBackend if the
// backend requires more construction.
CredentialFactory logical.Factory
// Steps are the set of operations that are run for this test case, in
// order. Execution stops at the first failing step.
Steps []TestStep
// Teardown, if non-nil, is deferred so that it runs before the test case
// is over, regardless of whether the test succeeded or failed. It should
// return an error if the test can't guarantee that all resources were
// properly cleaned up.
Teardown TestTeardownFunc
// AcceptanceTest, if true, makes the test case run only when the
// environment variable VAULT_ACC is set, and requires the -v flag.
// Otherwise the test case runs as a unit test.
AcceptanceTest bool
}
// TestStep is a single step within a TestCase. Each step is turned into a
// logical.Request that is sent directly to the in-memory core.
type TestStep struct {
// Operation is the operation to execute.
Operation logical.Operation
// Path is the request path relative to the mount. The mount prefix
// ("mnt/", or "auth/mnt/" for credential backends) will be automatically
// added after PreFlight has run.
Path string
// Data holds the request arguments; it is passed as the request's Data.
Data map[string]interface{}
// Check is called after this step is executed in order to test that
// the step executed successfully. If this is not set, the step passes as
// long as the request did not fail, and the next step will be run.
Check TestCheckFunc
// PreFlight is called directly before execution of the request, allowing
// modification of the request parameters (e.g. Path) with dynamic values.
// The request's ClientToken is blank while PreFlight runs and is restored
// afterwards; the mount prefix has not yet been added to Path.
PreFlight PreFlightFunc
// ErrorOk, if true, lets both errors returned by the request and
// erroneous responses through to Check instead of failing the step.
ErrorOk bool
// Unauthenticated, if true, will make the request unauthenticated (no
// client token or token entry is attached).
Unauthenticated bool
// RemoteAddr, if set, will set the remote addr on the request.
RemoteAddr string
// ConnState, if set, will set the TLS connection state on the request.
ConnState *tls.ConnectionState
}
// Callback types used by TestCase and TestStep.
type (
	// TestCheckFunc is the callback used for Check in TestStep. It receives
	// the response to the step's request; a non-nil error fails the step.
	TestCheckFunc func(*logical.Response) error

	// PreFlightFunc is used to modify request parameters directly before
	// execution in each TestStep. A non-nil error fails the step.
	PreFlightFunc func(*logical.Request) error

	// TestTeardownFunc is the callback used for Teardown in TestCase.
	TestTeardownFunc func() error
)
// Test performs an acceptance test on a backend with the given test case.
//
// When c.AcceptanceTest is set, the test case is skipped unless the
// environment variable named by TestEnvVar ("VAULT_ACC") is set to a
// non-empty value. This is to avoid test cases surprising a user by creating
// real resources. Such tests will also fail unless the verbose flag
// (`go test -v`, or explicitly the "-test.v" flag) is set. Because some
// acceptance tests take quite long, we require the verbose flag so users are
// able to see progress output. Test cases without AcceptanceTest always run.
//
// Test starts an in-memory Vault core and mounts the backend under test at
// "mnt/" (or enables it as an auth method at "auth/mnt/"). It then executes
// each TestStep in order, stopping at the first failing step. Afterwards it
// revokes every lease returned by a step and requests an immediate rollback
// on the mount. Finally it runs c.Teardown (if set), reporting any error it
// returns through tt.Error.
func Test(tt TestT, c TestCase) {
	// We only run acceptance tests if an env var is set because they're
	// slow and generally require some outside configuration.
	if c.AcceptanceTest && os.Getenv(TestEnvVar) == "" {
		tt.Skip(fmt.Sprintf(
			"Acceptance tests skipped unless env %q set",
			TestEnvVar))
		return
	}

	// We require verbose mode so that the user knows what is going on.
	if c.AcceptanceTest && !testTesting && !testing.Verbose() {
		tt.Fatal("Acceptance tests must be run with the -v flag on tests")
		return
	}

	// Run the PreCheck if we have it
	if c.PreCheck != nil {
		c.PreCheck()
	}

	// Defer on the teardown, regardless of pass/fail at this point. Its error
	// means cleanup could not be guaranteed, so report it rather than drop it.
	if c.Teardown != nil {
		defer func() {
			if err := c.Teardown(); err != nil {
				tt.Error(fmt.Sprintf("Teardown error: %s", err))
			}
		}()
	}

	hasLogical := c.LogicalBackend != nil || c.LogicalFactory != nil
	isAuthBackend := c.CredentialBackend != nil || c.CredentialFactory != nil

	// Check that something is provided
	if !hasLogical && !isAuthBackend {
		tt.Fatal("Must provide either Backend or Factory")
		return
	}

	// We currently only support doing one logical OR one credential test at a time.
	if hasLogical && isAuthBackend {
		tt.Fatal("Must provide only one backend or factory")
		return
	}

	// Create an in-memory Vault core
	logger := logging.NewVaultLogger(log.Trace)
	phys, err := inmem.NewInmem(nil, logger)
	if err != nil {
		tt.Fatal(err)
		return
	}

	config := &vault.CoreConfig{
		Physical:        phys,
		DisableMlock:    true,
		BuiltinRegistry: corehelpers.NewMockBuiltinRegistry(),
	}

	if hasLogical {
		config.LogicalBackends = map[string]logical.Factory{
			"test": func(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) {
				if c.LogicalBackend != nil {
					return c.LogicalBackend, nil
				}
				return c.LogicalFactory(ctx, conf)
			},
		}
	}
	if isAuthBackend {
		config.CredentialBackends = map[string]logical.Factory{
			"test": func(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) {
				if c.CredentialBackend != nil {
					return c.CredentialBackend, nil
				}
				return c.CredentialFactory(ctx, conf)
			},
		}
	}

	core, err := vault.NewCore(config)
	if err != nil {
		tt.Fatal("error initializing core: ", err)
		return
	}

	// Initialize the core
	initResult, err := core.Initialize(context.Background(), &vault.InitParams{
		BarrierConfig: &vault.SealConfig{
			SecretShares:    1,
			SecretThreshold: 1,
		},
		RecoveryConfig: nil,
	})
	if err != nil {
		tt.Fatal("error initializing core: ", err)
		return
	}

	// Unseal the core
	if unsealed, err := core.Unseal(initResult.SecretShares[0]); err != nil {
		tt.Fatal("error unsealing core: ", err)
		return
	} else if !unsealed {
		tt.Fatal("vault shouldn't be sealed")
		return
	}

	// Create an HTTP API server and client
	ln, addr := http.TestServer(nil, core)
	defer ln.Close()
	clientConfig := api.DefaultConfig()
	clientConfig.Address = addr
	client, err := api.NewClient(clientConfig)
	if err != nil {
		tt.Fatal("error initializing HTTP client: ", err)
		return
	}

	// Set the token so we're authenticated
	client.SetToken(initResult.RootToken)

	prefix := "mnt"
	if hasLogical {
		// Mount the backend
		mountInfo := &api.MountInput{
			Type:        "test",
			Description: "acceptance test",
		}
		if err := client.Sys().Mount(prefix, mountInfo); err != nil {
			tt.Fatal("error mounting backend: ", err)
			return
		}
	}
	if isAuthBackend {
		// Enable the test auth method
		opts := &api.EnableAuthOptions{
			Type: "test",
		}
		if err := client.Sys().EnableAuthWithOptions(prefix, opts); err != nil {
			tt.Fatal("error enabling backend: ", err)
			return
		}
	}

	tokenInfo, err := client.Auth().Token().LookupSelf()
	if err != nil {
		tt.Fatal("error looking up token: ", err)
		return
	}

	// Collect the root token's policies; non-string entries are skipped
	// rather than panicking on a failed type assertion.
	var tokenPolicies []string
	if tokenPoliciesRaw, ok := tokenInfo.Data["policies"]; ok {
		if tokenPoliciesSliceRaw, ok := tokenPoliciesRaw.([]interface{}); ok {
			for _, p := range tokenPoliciesSliceRaw {
				if policy, ok := p.(string); ok {
					tokenPolicies = append(tokenPolicies, policy)
				}
			}
		}
	}

	// Loop-invariant; the comma-ok form avoids a panic if the lookup response
	// lacks a string display_name.
	displayName, _ := tokenInfo.Data["display_name"].(string)

	// Make requests. Every lease returned by a step is remembered so it can be
	// revoked once all steps have run.
	var leaseIDs []string
	for i, s := range c.Steps {
		if logger.IsWarn() {
			logger.Warn("Executing test step", "step_number", i+1)
		}

		// Create the request
		req := &logical.Request{
			Operation: s.Operation,
			Path:      s.Path,
			Data:      s.Data,
		}
		if !s.Unauthenticated {
			req.ClientToken = client.Token()
			req.SetTokenEntry(&logical.TokenEntry{
				ID:          req.ClientToken,
				NamespaceID: namespace.RootNamespaceID,
				Policies:    tokenPolicies,
				DisplayName: displayName,
			})
		}
		req.Connection = &logical.Connection{RemoteAddr: s.RemoteAddr}
		if s.ConnState != nil {
			req.Connection.ConnState = s.ConnState
		}

		if s.PreFlight != nil {
			// PreFlight sees the request without its token; restore it after.
			ct := req.ClientToken
			req.ClientToken = ""
			if err := s.PreFlight(req); err != nil {
				tt.Error(fmt.Sprintf("Failed preflight for step %d: %s", i+1, err))
				break
			}
			req.ClientToken = ct
		}

		// Make sure to prefix the path with where we mounted the thing
		req.Path = fmt.Sprintf("%s/%s", prefix, req.Path)
		if isAuthBackend {
			// Prepend the path with "auth"
			req.Path = "auth/" + req.Path
		}

		// Make the request
		resp, err := core.HandleRequest(namespace.RootContext(context.Background()), req)
		if resp != nil && resp.Secret != nil {
			// Revoke this secret later
			leaseIDs = append(leaseIDs, resp.Secret.LeaseID)
		}

		// Test step returned an error.
		if err != nil {
			// But if an error is expected, do not fail the test step,
			// regardless of whether the error is a 'logical.ErrorResponse'
			// or not. Set the err to nil. If the error is a logical.ErrorResponse,
			// it will be handled later.
			if s.ErrorOk {
				err = nil
			} else {
				// If the error is not expected, fail right away.
				tt.Error(fmt.Sprintf("Failed step %d: %s", i+1, err))
				break
			}
		}

		// If the error is a 'logical.ErrorResponse' and if error was not expected,
		// set the error so that this can be caught below.
		if resp.IsError() && !s.ErrorOk {
			err = fmt.Errorf("erroneous response:\n\n%#v", resp)
		}

		// Either the 'err' was nil or if an error was expected, it was set to nil.
		// Call the 'Check' function if there is one.
		//
		// TODO: This works perfectly for now, but it would be better if 'Check'
		// function takes in both the response object and the error, and decide on
		// the action on its own.
		if err == nil && s.Check != nil {
			// Call the test method
			err = s.Check(resp)
		}

		if err != nil {
			tt.Error(fmt.Sprintf("Failed step %d: %s", i+1, err))
			break
		}
	}

	// Revoke any secrets we might have. Lease IDs (not the always-nil
	// Request.Secret) are recorded so failures name the lease to verify.
	var failedRevokes []string
	for _, leaseID := range leaseIDs {
		if logger.IsWarn() {
			logger.Warn("Revoking secret", "lease_id", leaseID)
		}
		req := &logical.Request{
			Operation:   logical.UpdateOperation,
			Path:        "sys/revoke/" + leaseID,
			ClientToken: client.Token(),
		}
		resp, err := core.HandleRequest(namespace.RootContext(context.Background()), req)
		if err == nil && resp.IsError() {
			err = fmt.Errorf("erroneous response:\n\n%#v", resp)
		}
		if err != nil {
			failedRevokes = append(failedRevokes, leaseID)
			tt.Error(fmt.Sprintf("Revoke error: %s", err))
		}
	}

	// Perform any rollbacks. This should no-op if there aren't any.
	// We set the "immediate" flag here that any backend can pick up on
	// to do all rollbacks immediately even if the WAL entries are new.
	logger.Warn("Requesting RollbackOperation")
	rollbackPath := prefix + "/"
	if isAuthBackend {
		rollbackPath = "auth/" + rollbackPath
	}
	req := logical.RollbackRequest(rollbackPath)
	req.Data["immediate"] = true
	req.ClientToken = client.Token()
	resp, err := core.HandleRequest(namespace.RootContext(context.Background()), req)
	if err == nil && resp.IsError() {
		err = fmt.Errorf("erroneous response:\n\n%#v", resp)
	}
	// Backends that don't implement rollback are not an error.
	if err != nil && !errwrap.Contains(err, logical.ErrUnsupportedOperation.Error()) {
		tt.Error(fmt.Sprintf("[ERR] Rollback error: %s", err))
	}

	// If we have any failed revokes, log it.
	for _, leaseID := range failedRevokes {
		tt.Error(fmt.Sprintf(
			"WARNING: Revoking the following secret failed. It may\n"+
				"still exist. Please verify lease ID:\n\n%q",
			leaseID))
	}
}
// TestCheckMulti combines several checks into one. The checks run in the
// order given and the first non-nil error is returned; later checks are not
// run once one fails. With no checks, the combined check always succeeds.
func TestCheckMulti(fs ...TestCheckFunc) TestCheckFunc {
	return func(resp *logical.Response) error {
		for idx := 0; idx < len(fs); idx++ {
			check := fs[idx]
			if checkErr := check(resp); checkErr != nil {
				return checkErr
			}
		}
		return nil
	}
}
// TestCheckAuth returns a check asserting that the response carries auth
// whose policies equal the given policies, ignoring order. Both sides are
// sorted on copies, so neither the caller's slice nor the response is
// reordered.
func TestCheckAuth(policies []string) TestCheckFunc {
	return func(resp *logical.Response) error {
		if resp == nil || resp.Auth == nil {
			return fmt.Errorf("no auth in response")
		}

		want := append(make([]string, 0, len(policies)), policies...)
		got := append(make([]string, 0, len(resp.Auth.Policies)), resp.Auth.Policies...)
		sort.Strings(want)
		sort.Strings(got)

		if reflect.DeepEqual(got, want) {
			return nil
		}
		return fmt.Errorf("invalid policies: expected %#v, got %#v", want, got)
	}
}
// TestCheckAuthEntityId is a helper to check that a request generated an
// auth token with the expected entity ID.
//
// entityID is shared state between checks. If it points to an empty string,
// the response's entity ID is saved into it, so a later step can assert that
// the same entity is returned. Otherwise the response's entity ID must equal
// *entityID. A nil entityID produces an error instead of a panic.
func TestCheckAuthEntityId(entityID *string) TestCheckFunc {
	return func(resp *logical.Response) error {
		if resp == nil || resp.Auth == nil {
			return fmt.Errorf("no auth in response")
		}
		if entityID == nil {
			return fmt.Errorf("expected entity_id pointer must not be nil")
		}

		if *entityID == "" {
			// If we don't know what the entity_id should be, just save it
			*entityID = resp.Auth.EntityID
			return nil
		}
		if resp.Auth.EntityID != *entityID {
			return fmt.Errorf("entity_id %s does not match the expected value of %s", resp.Auth.EntityID, *entityID)
		}
		return nil
	}
}
// TestCheckAuthEntityAliasMetadataName is a helper to check that a request
// generated an auth token whose alias metadata maps key to value.
//
// Both key and value must be non-empty. A response without auth, or whose
// auth carries no alias, is reported as an error rather than causing a
// nil-pointer panic.
func TestCheckAuthEntityAliasMetadataName(key string, value string) TestCheckFunc {
	return func(resp *logical.Response) error {
		if resp == nil || resp.Auth == nil {
			return fmt.Errorf("no auth in response")
		}
		if key == "" || value == "" {
			return fmt.Errorf("alias metadata key and value required")
		}
		if resp.Auth.Alias == nil {
			return fmt.Errorf("no alias in response auth")
		}

		name, ok := resp.Auth.Alias.Metadata[key]
		if !ok {
			return fmt.Errorf("metadata key %s does not exist, it should", key)
		}
		if name != value {
			return fmt.Errorf("expected map value %s, got %s", value, name)
		}
		return nil
	}
}
// TestCheckAuthDisplayName is a helper to check that a request generated a
// valid display name.
//
// When n is non-empty, the auth's display name must be "mnt-"+n. The "mnt"
// prefix matches the path at which Test mounts the backend. A nil response
// is reported as missing auth, consistent with the other auth checks.
func TestCheckAuthDisplayName(n string) TestCheckFunc {
	return func(resp *logical.Response) error {
		if resp == nil || resp.Auth == nil {
			return fmt.Errorf("no auth in response")
		}
		if n != "" && resp.Auth.DisplayName != "mnt-"+n {
			return fmt.Errorf("invalid display name: %#v", resp.Auth.DisplayName)
		}
		return nil
	}
}
// TestCheckError returns a check that passes only when the response is an
// error response, as reported by resp.IsError().
func TestCheckError() TestCheckFunc {
	return func(resp *logical.Response) error {
		if resp.IsError() {
			return nil
		}
		return fmt.Errorf("response should be error")
	}
}
// TestT is the subset of the test lifecycle that Test relies on: reporting
// failures (Error, Fatal) and skipping (Skip).
//
// A *testing.T satisfies this interface and is what callers should normally
// pass.
type TestT interface {
	Skip(args ...interface{})
	Fatal(args ...interface{})
	Error(args ...interface{})
}
// testTesting, when true, lets acceptance test cases run without the -v
// flag. It is presumably flipped by this package's own tests — TODO confirm.
var testTesting bool