blob: 10f02b7bec259a79707432fa237b53ada52ff196 [file] [log] [blame]
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package dbplugin
import (
"context"
"errors"
"reflect"
"testing"
"time"
log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/sdk/database/dbplugin/v5/proto"
"github.com/hashicorp/vault/sdk/helper/consts"
"github.com/hashicorp/vault/sdk/helper/pluginutil"
"github.com/hashicorp/vault/sdk/helper/wrapping"
"github.com/hashicorp/vault/sdk/logical"
"github.com/stretchr/testify/mock"
"google.golang.org/grpc"
)
// TestNewPluginClient drives NewPluginClient through a mocked RunnerUtil whose
// NewPluginClient call hands back a fakePluginClient. Each case controls what
// the fake's Dispense returns and then checks either the constructed
// *DatabasePluginClient or the returned error.
func TestNewPluginClient(t *testing.T) {
	type testCase struct {
		config       pluginutil.PluginClientConfig
		pluginClient pluginutil.PluginClient
		expectedResp *DatabasePluginClient
		expectedErr  error
	}

	tests := map[string]testCase{
		"happy path": {
			config: testPluginClientConfig(),
			pluginClient: &fakePluginClient{
				connResp:     nil,
				dispenseResp: gRPCClient{client: fakeClient{}},
				dispenseErr:  nil,
			},
			expectedResp: &DatabasePluginClient{
				client: &fakePluginClient{
					connResp:     nil,
					dispenseResp: gRPCClient{client: fakeClient{}},
					dispenseErr:  nil,
				},
				Database: gRPCClient{client: proto.NewDatabaseClient(nil), versionClient: logical.NewPluginVersionClient(nil), doneCtx: context.Context(nil)},
			},
			expectedErr: nil,
		},
		"dispense error": {
			config: testPluginClientConfig(),
			pluginClient: &fakePluginClient{
				connResp:     nil,
				dispenseResp: gRPCClient{},
				dispenseErr:  errors.New("dispense error"),
			},
			expectedResp: nil,
			expectedErr:  errors.New("dispense error"),
		},
		"error unsupported client type": {
			config: testPluginClientConfig(),
			pluginClient: &fakePluginClient{
				connResp:     nil,
				dispenseResp: nil,
				dispenseErr:  nil,
			},
			expectedResp: nil,
			expectedErr:  errors.New("unsupported client type"),
		},
	}

	for name, test := range tests {
		t.Run(name, func(t *testing.T) {
			ctx := context.Background()

			mockWrapper := new(mockRunnerUtil)
			mockWrapper.On("NewPluginClient", ctx, mock.Anything).
				Return(test.pluginClient, nil)
			defer mockWrapper.AssertNumberOfCalls(t, "NewPluginClient", 1)

			resp, err := NewPluginClient(ctx, mockWrapper, test.config)

			if test.expectedErr != nil {
				if err == nil {
					t.Fatal("err expected, got nil")
				}
				// The expected errors are freshly built with errors.New, so
				// identity/errors.Is can never match; compare messages instead
				// so each case proves it failed for the intended reason.
				if err.Error() != test.expectedErr.Error() {
					t.Fatalf("expected error %q, got %q", test.expectedErr, err)
				}
				if resp != nil {
					t.Fatalf("expected nil response on error, got: %#v", resp)
				}
				return
			}

			if err != nil {
				t.Fatalf("no error expected, got: %s", err)
			}
			if !reflect.DeepEqual(resp, test.expectedResp) {
				t.Fatalf("Actual response: %#v\nExpected response: %#v", resp, test.expectedResp)
			}
		})
	}
}
// testPluginClientConfig returns the PluginClientConfig shared by every
// TestNewPluginClient case. It names a database-type plugin "test-plugin",
// wires in the package's PluginSets and HandshakeConfig, discards logs via a
// null logger, and enables both metadata mode and AutoMTLS.
func testPluginClientConfig() pluginutil.PluginClientConfig {
	cfg := pluginutil.PluginClientConfig{
		Name:       "test-plugin",
		PluginType: consts.PluginTypeDatabase,

		PluginSets:      PluginSets,
		HandshakeConfig: HandshakeConfig,

		Logger: log.NewNullLogger(),

		AutoMTLS:       true,
		IsMetadataMode: true,
	}
	return cfg
}
// Compile-time check that *fakePluginClient satisfies pluginutil.PluginClient.
var _ pluginutil.PluginClient = (*fakePluginClient)(nil)

// fakePluginClient is a canned pluginutil.PluginClient for driving
// NewPluginClient without starting a real plugin process.
//
//   - connResp is returned by Conn.
//   - dispenseResp and dispenseErr are returned verbatim by Dispense.
//
// Reload, Ping and Close are no-ops that always succeed.
type fakePluginClient struct {
	connResp     grpc.ClientConnInterface
	dispenseResp interface{}
	dispenseErr  error
}

// Conn returns the configured connResp. This is nil unless a test sets it.
// Previously the field was ignored and nil was hard-coded.
func (f *fakePluginClient) Conn() grpc.ClientConnInterface {
	return f.connResp
}

// Reload is a no-op that always succeeds.
func (f *fakePluginClient) Reload() error {
	return nil
}

// Dispense ignores name and returns the configured dispenseResp/dispenseErr.
func (f *fakePluginClient) Dispense(name string) (interface{}, error) {
	return f.dispenseResp, f.dispenseErr
}

// Ping is a no-op that always succeeds.
func (f *fakePluginClient) Ping() error {
	return nil
}

// Close is a no-op that always succeeds.
func (f *fakePluginClient) Close() error {
	return nil
}
// Compile-time check that *mockRunnerUtil satisfies pluginutil.RunnerUtil.
var _ pluginutil.RunnerUtil = (*mockRunnerUtil)(nil)

// mockRunnerUtil is a testify-backed pluginutil.RunnerUtil. NewPluginClient,
// ResponseWrapData and MlockEnabled are routed through mock.Mock, so tests must
// register expectations with On(...) before calling them. VaultVersion is
// hard-coded and is not recorded by the mock.
type mockRunnerUtil struct {
	mock.Mock
}

// VaultVersion always returns "dummyversion" and a nil error. It bypasses the
// mock, so no expectation is needed.
func (m *mockRunnerUtil) VaultVersion(ctx context.Context) (string, error) {
	return "dummyversion", nil
}

// NewPluginClient records the call and returns the values programmed via
// On("NewPluginClient", ...).Return(client, err).
func (m *mockRunnerUtil) NewPluginClient(ctx context.Context, config pluginutil.PluginClientConfig) (pluginutil.PluginClient, error) {
	args := m.Called(ctx, config)
	// Comma-ok so that Return(nil, err) yields a nil client instead of
	// panicking on the assertion; the ok flag is intentionally discarded.
	client, _ := args.Get(0).(pluginutil.PluginClient)
	return client, args.Error(1)
}

// ResponseWrapData records the call and returns the values programmed via
// On("ResponseWrapData", ...).Return(info, err).
func (m *mockRunnerUtil) ResponseWrapData(ctx context.Context, data map[string]interface{}, ttl time.Duration, jwt bool) (*wrapping.ResponseWrapInfo, error) {
	args := m.Called(ctx, data, ttl, jwt)
	// Comma-ok for the same reason as NewPluginClient: tolerate Return(nil, err).
	info, _ := args.Get(0).(*wrapping.ResponseWrapInfo)
	return info, args.Error(1)
}

// MlockEnabled records the call and returns the bool programmed via
// On("MlockEnabled").Return(b).
func (m *mockRunnerUtil) MlockEnabled() bool {
	args := m.Called()
	return args.Bool(0)
}