blob: 1771489d63a7c093289320ad136f2c711b4715a5 [file] [log] [blame] [edit]
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
// Package corehelpers contains test helpers that don't depend on package vault,
// and thus can be used within vault (as well as elsewhere).
package corehelpers
import (
// externalPlugins names the plugins the mock registry treats as builtin
// enterprise plugins; they are also reported among the builtin secrets plugins.
var externalPlugins = []string{
	"transform",
	"kmip",
	"keymgmt",
}
// RetryUntil runs f until it returns a nil result or the timeout is reached.
// If a nil result hasn't been obtained by timeout, calls t.Fatal.
func RetryUntil(t testing.T, timeout time.Duration, f func() error) {
deadline := time.Now().Add(timeout)
var err error
for time.Now().Before(deadline) {
if err = f(); err == nil {
time.Sleep(100 * time.Millisecond)
t.Fatalf("did not complete before deadline, err: %v", err)
// MakeTestPluginDir creates a temporary directory suitable for holding plugins.
// This helper also resolves symlinks to make tests happy on OS X, where the
// temp dir lives under /var, which is a symlink to /private/var.
//
// It returns the resolved directory path and a cleanup function that removes
// it. The t passed to either function may be nil; failures then panic instead
// of calling t.Fatal.
func MakeTestPluginDir(t testing.T) (string, func(t testing.T)) {
	if t != nil {
		t.Helper()
	}

	// fail reports err through t when one is available and panics otherwise.
	fail := func(t testing.T, err error) {
		if t == nil {
			panic(err)
		}
		t.Helper()
		t.Fatal(err)
	}

	tmpDir, err := os.MkdirTemp("", "")
	if err != nil {
		fail(t, err)
	}

	// Resolve into a separate variable: on error EvalSymlinks returns "",
	// and the original path is still needed to remove the directory.
	dir, err := filepath.EvalSymlinks(tmpDir)
	if err != nil {
		// Best-effort cleanup; the EvalSymlinks error is the one worth reporting.
		_ = os.RemoveAll(tmpDir)
		fail(t, err)
	}

	return dir, func(t testing.T) {
		if err := os.RemoveAll(dir); err != nil {
			fail(t, err)
		}
	}
}
func NewMockBuiltinRegistry() *mockBuiltinRegistry {
return &mockBuiltinRegistry{
forTesting: map[string]mockBackend{
"mysql-database-plugin": {PluginType: consts.PluginTypeDatabase},
"postgresql-database-plugin": {PluginType: consts.PluginTypeDatabase},
"approle": {PluginType: consts.PluginTypeCredential},
"pending-removal-test-plugin": {
PluginType: consts.PluginTypeCredential,
DeprecationStatus: consts.PendingRemoval,
"aws": {PluginType: consts.PluginTypeCredential},
"consul": {PluginType: consts.PluginTypeSecrets},
type mockBackend struct {
type mockBuiltinRegistry struct {
forTesting map[string]mockBackend
func toFunc(f logical.Factory) func() (interface{}, error) {
return func() (interface{}, error) {
return f, nil
// Get returns a constructor for the named plugin when name is registered in
// forTesting under pluginType; the bool reports whether one was found.
//
// approle and pending-removal-test-plugin use the real approle factory and
// mysql-database-plugin the real mysql constructor; aws,
// postgresql-database-plugin, and consul are bare framework.Backend stubs.
func (m *mockBuiltinRegistry) Get(name string, pluginType consts.PluginType) (func() (interface{}, error), bool) {
	testBackend, ok := m.forTesting[name]
	if !ok || testBackend.PluginType != pluginType {
		return nil, false
	}

	// stub builds a constructor for an otherwise empty framework.Backend
	// reporting backendType.
	stub := func(backendType logical.BackendType) func() (interface{}, error) {
		return toFunc(func(ctx context.Context, config *logical.BackendConfig) (logical.Backend, error) {
			b := new(framework.Backend)
			// Surface Setup failures instead of returning a half-initialized backend.
			if err := b.Setup(ctx, config); err != nil {
				return nil, err
			}
			b.BackendType = backendType
			return b, nil
		})
	}

	switch name {
	case "approle", "pending-removal-test-plugin":
		return toFunc(approle.Factory), true
	case "aws":
		return stub(logical.TypeCredential), true
	case "postgresql-database-plugin", "consul":
		return stub(logical.TypeLogical), true
	case "mysql-database-plugin":
		return mysql.New(mysql.DefaultUserNameTemplate), true
	}
	return nil, false
}
// Keys returns the plugin names the mock registry reports for pluginType.
// It only supports getting a realistic list of the keys for database plugins
// and approle. Secrets plugins are reported as externalPlugins plus "kv", and
// any other plugin type yields an empty, non-nil slice.
func (m *mockBuiltinRegistry) Keys(pluginType consts.PluginType) []string {
switch pluginType {
case consts.PluginTypeDatabase:
// This is a hard-coded reproduction of the db plugin keys in
// helper/builtinplugins/registry.go. The registry isn't directly used
// because it causes import cycles.
// NOTE(review): the list entries and closing braces are absent from this
// copy of the file; restore them from the registry before relying on it.
return []string{
case consts.PluginTypeCredential:
return []string{
case consts.PluginTypeSecrets:
// NOTE(review): relies on externalPlugins having no spare capacity (true
// for its slice-literal declaration), so append copies instead of writing
// "kv" into the shared backing array.
return append(externalPlugins, "kv")
return []string{}
func (r *mockBuiltinRegistry) IsBuiltinEntPlugin(name string, pluginType consts.PluginType) bool {
for _, i := range externalPlugins {
if i == name {
return true
return false
func (m *mockBuiltinRegistry) Contains(name string, pluginType consts.PluginType) bool {
for _, key := range m.Keys(pluginType) {
if key == name {
return true
return false
func (m *mockBuiltinRegistry) DeprecationStatus(name string, pluginType consts.PluginType) (consts.DeprecationStatus, bool) {
if m.Contains(name, pluginType) {
return m.forTesting[name].DeprecationStatus, true
return consts.Unknown, false
// TestNoopAudit builds a NoopAudit from config via NewNoopAudit and fails the
// test immediately if construction returns an error, so callers never get a
// nil backend back.
func TestNoopAudit(t testing.T, config map[string]string) *NoopAudit {
	t.Helper()
	n, err := NewNoopAudit(config)
	if err != nil {
		t.Fatal(err)
	}
	return n
}
// NewNoopAudit returns a NoopAudit whose salt view is an in-memory storage
// pre-seeded with a "salt" entry and whose salt uses HMAC-SHA256. config is
// kept as the backend's Config map, and records are rendered by a JSON format
// writer that salts through the backend's own Salt method.
func NewNoopAudit(config map[string]string) (*NoopAudit, error) {
	storage := &logical.InmemStorage{}
	entry := &logical.StorageEntry{Key: "salt", Value: []byte("foo")}
	if err := storage.Put(context.Background(), entry); err != nil {
		return nil, err
	}

	backendConfig := &audit.BackendConfig{
		SaltView: storage,
		SaltConfig: &salt.Config{
			HMAC:     sha256.New,
			HMACType: "hmac-sha256",
		},
		Config: config,
	}

	noop := &NoopAudit{Config: backendConfig}
	noop.formatter.AuditFormatWriter = &audit.JSONFormatWriter{SaltFunc: noop.Salt}
	return noop, nil
}
// NoopAuditFactory returns an audit.Factory that builds NoopAudit backends
// from the backend config's Config map. When records is non-nil, each
// construction repoints *records at the new backend's formatted records.
//
// NOTE(review): the backend appends to that slice under its own lock, but the
// caller reads it through this pointer unlocked — confirm callers only read
// after logging has stopped.
func NoopAuditFactory(records **[][]byte) audit.Factory {
	factory := func(_ context.Context, config *audit.BackendConfig) (audit.Backend, error) {
		backend, err := NewNoopAudit(config.Config)
		if err != nil {
			return nil, err
		}
		if records != nil {
			*records = &backend.records
		}
		return backend, nil
	}
	return factory
}
type NoopAudit struct {
Config *audit.BackendConfig
ReqErr error
ReqAuth []*logical.Auth
Req []*logical.Request
ReqHeaders []map[string][]string
ReqNonHMACKeys []string
ReqErrs []error
RespErr error
RespAuth []*logical.Auth
RespReq []*logical.Request
Resp []*logical.Response
RespNonHMACKeys [][]string
RespReqNonHMACKeys [][]string
RespErrs []error
formatter audit.AuditFormatter
records [][]byte
l sync.RWMutex
salt *salt.Salt
saltMutex sync.RWMutex
// LogRequest records in while holding n.l: the formatted request (when a
// format writer is configured) plus its auth, request, headers, non-HMAC
// request keys, and outer error. It returns n.ReqErr.
func (n *NoopAudit) LogRequest(ctx context.Context, in *logical.LogInput) error {
	n.l.Lock()
	defer n.l.Unlock()

	if n.formatter.AuditFormatWriter != nil {
		var w bytes.Buffer
		if err := n.formatter.FormatRequest(ctx, &w, audit.FormatterConfig{}, in); err != nil {
			return err
		}
		n.records = append(n.records, w.Bytes())
	}

	// Keep ReqHeaders index-aligned with Req without dereferencing a nil request.
	var headers map[string][]string
	if in.Request != nil {
		headers = in.Request.Headers
	}

	n.ReqAuth = append(n.ReqAuth, in.Auth)
	n.Req = append(n.Req, in.Request)
	n.ReqHeaders = append(n.ReqHeaders, headers)
	// Only the most recent request's non-HMAC keys are kept.
	n.ReqNonHMACKeys = in.NonHMACReqDataKeys
	n.ReqErrs = append(n.ReqErrs, in.OuterErr)

	return n.ReqErr
}
// LogResponse records in while holding n.l: the formatted response (when a
// format writer is configured) plus its auth, request, response, and outer
// error. Non-HMAC key lists are appended only when in.Response is non-nil, so
// RespNonHMACKeys and RespReqNonHMACKeys may be shorter than Resp. It returns
// n.RespErr.
func (n *NoopAudit) LogResponse(ctx context.Context, in *logical.LogInput) error {
	n.l.Lock()
	defer n.l.Unlock()

	if n.formatter.AuditFormatWriter != nil {
		var w bytes.Buffer
		if err := n.formatter.FormatResponse(ctx, &w, audit.FormatterConfig{}, in); err != nil {
			return err
		}
		n.records = append(n.records, w.Bytes())
	}

	n.RespAuth = append(n.RespAuth, in.Auth)
	n.RespReq = append(n.RespReq, in.Request)
	n.Resp = append(n.Resp, in.Response)
	n.RespErrs = append(n.RespErrs, in.OuterErr)

	if in.Response != nil {
		n.RespNonHMACKeys = append(n.RespNonHMACKeys, in.NonHMACRespDataKeys)
		n.RespReqNonHMACKeys = append(n.RespReqNonHMACKeys, in.NonHMACReqDataKeys)
	}

	return n.RespErr
}
// LogTestMessage formats in with a temporary formatter built from config's
// "format" and "prefix" entries and appends the result to n.records.
func (n *NoopAudit) LogTestMessage(ctx context.Context, in *logical.LogInput, config map[string]string) error {
	var w bytes.Buffer
	tempFormatter := audit.NewTemporaryFormatter(config["format"], config["prefix"])
	if err := tempFormatter.FormatResponse(ctx, &w, audit.FormatterConfig{}, in); err != nil {
		return err
	}

	// Only the shared records slice needs the lock; formatting uses local state.
	n.l.Lock()
	defer n.l.Unlock()
	n.records = append(n.records, w.Bytes())
	return nil
}
// Salt returns the backend's salt, creating it from n.Config on first use.
// A read-locked fast path is followed by a re-check under the write lock, so
// concurrent callers create the salt at most once per invalidation.
func (n *NoopAudit) Salt(ctx context.Context) (*salt.Salt, error) {
	n.saltMutex.RLock()
	if s := n.salt; s != nil {
		n.saltMutex.RUnlock()
		return s, nil
	}
	n.saltMutex.RUnlock()

	n.saltMutex.Lock()
	defer n.saltMutex.Unlock()
	// Another caller may have created the salt between releasing the read
	// lock and acquiring the write lock.
	if n.salt != nil {
		return n.salt, nil
	}
	newSalt, err := salt.NewSalt(ctx, n.Config.SaltView, n.Config.SaltConfig)
	if err != nil {
		return nil, err
	}
	n.salt = newSalt
	return newSalt, nil
}
// GetHash returns the identified HMAC of data computed with the backend's
// salt, propagating any error from obtaining the salt.
func (n *NoopAudit) GetHash(ctx context.Context, data string) (string, error) {
	s, err := n.Salt(ctx)
	if err != nil {
		return "", err
	}
	return s.GetIdentifiedHMAC(data), nil
}
// Reload does nothing for NoopAudit and always succeeds.
func (n *NoopAudit) Reload(_ context.Context) error {
	return nil
}
// Invalidate discards the cached salt under saltMutex so the next Salt call
// recreates it.
func (n *NoopAudit) Invalidate(_ context.Context) {
	n.saltMutex.Lock()
	defer n.saltMutex.Unlock()
	n.salt = nil
}
type TestLogger struct {
Path string
File *os.File
sink hclog.SinkAdapter
// NewTestLogger returns a TestLogger named after t. The logger itself writes
// to io.Discard; all output flows through a Trace-level sink registered on
// it, so the sink can later be deregistered to stop logging.
//
// When VAULT_TEST_LOG_DIR is set, the sink writes to <dir>/<t.Name()>.log
// (creating parent directories, since t.Name may contain slashes) and the
// file is exposed via Path and File; otherwise it writes to stderr.
func NewTestLogger(t testing.T) *TestLogger {
	t.Helper()

	var logFile *os.File
	var logPath string
	output := os.Stderr

	if logDir := os.Getenv("VAULT_TEST_LOG_DIR"); logDir != "" {
		logPath = filepath.Join(logDir, t.Name()+".log")
		// t.Name may include slashes.
		dir, _ := filepath.Split(logPath)
		if err := os.MkdirAll(dir, 0o755); err != nil {
			t.Fatal(err)
		}
		var err error
		logFile, err = os.Create(logPath)
		if err != nil {
			t.Fatal(err)
		}
		output = logFile
	}

	// We send nothing on the regular logger, that way we can later deregister
	// the sink to stop logging during cluster cleanup.
	logger := hclog.NewInterceptLogger(&hclog.LoggerOptions{
		Output:            io.Discard,
		IndependentLevels: true,
		Name:              t.Name(),
	})
	sink := hclog.NewSinkAdapter(&hclog.LoggerOptions{
		Output:            output,
		Level:             hclog.Trace,
		IndependentLevels: true,
	})
	// Without registering the sink, every message would go only to io.Discard.
	logger.RegisterSink(sink)

	return &TestLogger{
		Path:            logPath,
		File:            logFile,
		InterceptLogger: logger,
		sink:            sink,
	}
}
func (tl *TestLogger) StopLogging() {