| // Copyright (c) HashiCorp, Inc. |
| // SPDX-License-Identifier: MPL-2.0 |
| |
| package dbplugin |
| |
| import ( |
| "context" |
| "encoding/json" |
| "errors" |
| "reflect" |
| "testing" |
| "time" |
| |
| "github.com/hashicorp/vault/sdk/database/dbplugin/v5/proto" |
| "google.golang.org/grpc" |
| ) |
| |
| func TestGRPCClient_Initialize(t *testing.T) { |
| type testCase struct { |
| client proto.DatabaseClient |
| req InitializeRequest |
| expectedResp InitializeResponse |
| assertErr errorAssertion |
| } |
| |
| tests := map[string]testCase{ |
| "bad config": { |
| client: fakeClient{}, |
| req: InitializeRequest{ |
| Config: map[string]interface{}{ |
| "foo": badJSONValue{}, |
| }, |
| }, |
| assertErr: assertErrNotNil, |
| }, |
| "database error": { |
| client: fakeClient{ |
| initErr: errors.New("initialize error"), |
| }, |
| req: InitializeRequest{ |
| Config: map[string]interface{}{ |
| "foo": "bar", |
| }, |
| }, |
| assertErr: assertErrNotNil, |
| }, |
| "happy path": { |
| client: fakeClient{ |
| initResp: &proto.InitializeResponse{ |
| ConfigData: marshal(t, map[string]interface{}{ |
| "foo": "bar", |
| "baz": "biz", |
| }), |
| }, |
| }, |
| req: InitializeRequest{ |
| Config: map[string]interface{}{ |
| "foo": "bar", |
| }, |
| }, |
| expectedResp: InitializeResponse{ |
| Config: map[string]interface{}{ |
| "foo": "bar", |
| "baz": "biz", |
| }, |
| }, |
| assertErr: assertErrNil, |
| }, |
| "JSON number type in initialize request": { |
| client: fakeClient{ |
| initResp: &proto.InitializeResponse{ |
| ConfigData: marshal(t, map[string]interface{}{ |
| "foo": "bar", |
| "max": "10", |
| }), |
| }, |
| }, |
| req: InitializeRequest{ |
| Config: map[string]interface{}{ |
| "foo": "bar", |
| "max": json.Number("10"), |
| }, |
| }, |
| expectedResp: InitializeResponse{ |
| Config: map[string]interface{}{ |
| "foo": "bar", |
| "max": "10", |
| }, |
| }, |
| assertErr: assertErrNil, |
| }, |
| } |
| |
| for name, test := range tests { |
| t.Run(name, func(t *testing.T) { |
| c := gRPCClient{ |
| client: test.client, |
| doneCtx: nil, |
| } |
| |
| // Context doesn't need to timeout since this is just passed through |
| ctx := context.Background() |
| |
| resp, err := c.Initialize(ctx, test.req) |
| test.assertErr(t, err) |
| |
| if !reflect.DeepEqual(resp, test.expectedResp) { |
| t.Fatalf("Actual response: %#v\nExpected response: %#v", resp, test.expectedResp) |
| } |
| }) |
| } |
| } |
| |
| func TestGRPCClient_NewUser(t *testing.T) { |
| runningCtx := context.Background() |
| cancelledCtx, cancel := context.WithCancel(context.Background()) |
| cancel() |
| |
| type testCase struct { |
| client proto.DatabaseClient |
| req NewUserRequest |
| doneCtx context.Context |
| expectedResp NewUserResponse |
| assertErr errorAssertion |
| } |
| |
| tests := map[string]testCase{ |
| "missing password": { |
| client: fakeClient{}, |
| req: NewUserRequest{ |
| Password: "", |
| Expiration: time.Now(), |
| }, |
| doneCtx: runningCtx, |
| assertErr: assertErrNotNil, |
| }, |
| "bad expiration": { |
| client: fakeClient{}, |
| req: NewUserRequest{ |
| Password: "njkvcb8y934u90grsnkjl", |
| Expiration: invalidExpiration, |
| }, |
| doneCtx: runningCtx, |
| assertErr: assertErrNotNil, |
| }, |
| "database error": { |
| client: fakeClient{ |
| newUserErr: errors.New("new user error"), |
| }, |
| req: NewUserRequest{ |
| Password: "njkvcb8y934u90grsnkjl", |
| Expiration: time.Now(), |
| }, |
| doneCtx: runningCtx, |
| assertErr: assertErrNotNil, |
| }, |
| "plugin shut down": { |
| client: fakeClient{ |
| newUserErr: errors.New("new user error"), |
| }, |
| req: NewUserRequest{ |
| Password: "njkvcb8y934u90grsnkjl", |
| Expiration: time.Now(), |
| }, |
| doneCtx: cancelledCtx, |
| assertErr: assertErrEquals(ErrPluginShutdown), |
| }, |
| "happy path": { |
| client: fakeClient{ |
| newUserResp: &proto.NewUserResponse{ |
| Username: "new_user", |
| }, |
| }, |
| req: NewUserRequest{ |
| Password: "njkvcb8y934u90grsnkjl", |
| Expiration: time.Now(), |
| }, |
| doneCtx: runningCtx, |
| expectedResp: NewUserResponse{ |
| Username: "new_user", |
| }, |
| assertErr: assertErrNil, |
| }, |
| } |
| |
| for name, test := range tests { |
| t.Run(name, func(t *testing.T) { |
| c := gRPCClient{ |
| client: test.client, |
| doneCtx: test.doneCtx, |
| } |
| |
| ctx := context.Background() |
| |
| resp, err := c.NewUser(ctx, test.req) |
| test.assertErr(t, err) |
| |
| if !reflect.DeepEqual(resp, test.expectedResp) { |
| t.Fatalf("Actual response: %#v\nExpected response: %#v", resp, test.expectedResp) |
| } |
| }) |
| } |
| } |
| |
| func TestGRPCClient_UpdateUser(t *testing.T) { |
| runningCtx := context.Background() |
| cancelledCtx, cancel := context.WithCancel(context.Background()) |
| cancel() |
| |
| type testCase struct { |
| client proto.DatabaseClient |
| req UpdateUserRequest |
| doneCtx context.Context |
| assertErr errorAssertion |
| } |
| |
| tests := map[string]testCase{ |
| "missing username": { |
| client: fakeClient{}, |
| req: UpdateUserRequest{}, |
| doneCtx: runningCtx, |
| assertErr: assertErrNotNil, |
| }, |
| "missing changes": { |
| client: fakeClient{}, |
| req: UpdateUserRequest{ |
| Username: "user", |
| }, |
| doneCtx: runningCtx, |
| assertErr: assertErrNotNil, |
| }, |
| "empty password": { |
| client: fakeClient{}, |
| req: UpdateUserRequest{ |
| Username: "user", |
| Password: &ChangePassword{ |
| NewPassword: "", |
| }, |
| }, |
| doneCtx: runningCtx, |
| assertErr: assertErrNotNil, |
| }, |
| "zero expiration": { |
| client: fakeClient{}, |
| req: UpdateUserRequest{ |
| Username: "user", |
| Expiration: &ChangeExpiration{ |
| NewExpiration: time.Time{}, |
| }, |
| }, |
| doneCtx: runningCtx, |
| assertErr: assertErrNotNil, |
| }, |
| "bad expiration": { |
| client: fakeClient{}, |
| req: UpdateUserRequest{ |
| Username: "user", |
| Expiration: &ChangeExpiration{ |
| NewExpiration: invalidExpiration, |
| }, |
| }, |
| doneCtx: runningCtx, |
| assertErr: assertErrNotNil, |
| }, |
| "database error": { |
| client: fakeClient{ |
| updateUserErr: errors.New("update user error"), |
| }, |
| req: UpdateUserRequest{ |
| Username: "user", |
| Password: &ChangePassword{ |
| NewPassword: "asdf", |
| }, |
| }, |
| doneCtx: runningCtx, |
| assertErr: assertErrNotNil, |
| }, |
| "plugin shut down": { |
| client: fakeClient{ |
| updateUserErr: errors.New("update user error"), |
| }, |
| req: UpdateUserRequest{ |
| Username: "user", |
| Password: &ChangePassword{ |
| NewPassword: "asdf", |
| }, |
| }, |
| doneCtx: cancelledCtx, |
| assertErr: assertErrEquals(ErrPluginShutdown), |
| }, |
| "happy path - change password": { |
| client: fakeClient{}, |
| req: UpdateUserRequest{ |
| Username: "user", |
| Password: &ChangePassword{ |
| NewPassword: "asdf", |
| }, |
| }, |
| doneCtx: runningCtx, |
| assertErr: assertErrNil, |
| }, |
| "happy path - change expiration": { |
| client: fakeClient{}, |
| req: UpdateUserRequest{ |
| Username: "user", |
| Expiration: &ChangeExpiration{ |
| NewExpiration: time.Now(), |
| }, |
| }, |
| doneCtx: runningCtx, |
| assertErr: assertErrNil, |
| }, |
| } |
| |
| for name, test := range tests { |
| t.Run(name, func(t *testing.T) { |
| c := gRPCClient{ |
| client: test.client, |
| doneCtx: test.doneCtx, |
| } |
| |
| ctx := context.Background() |
| |
| _, err := c.UpdateUser(ctx, test.req) |
| test.assertErr(t, err) |
| }) |
| } |
| } |
| |
| func TestGRPCClient_DeleteUser(t *testing.T) { |
| runningCtx := context.Background() |
| cancelledCtx, cancel := context.WithCancel(context.Background()) |
| cancel() |
| |
| type testCase struct { |
| client proto.DatabaseClient |
| req DeleteUserRequest |
| doneCtx context.Context |
| assertErr errorAssertion |
| } |
| |
| tests := map[string]testCase{ |
| "missing username": { |
| client: fakeClient{}, |
| req: DeleteUserRequest{}, |
| doneCtx: runningCtx, |
| assertErr: assertErrNotNil, |
| }, |
| "database error": { |
| client: fakeClient{ |
| deleteUserErr: errors.New("delete user error'"), |
| }, |
| req: DeleteUserRequest{ |
| Username: "user", |
| }, |
| doneCtx: runningCtx, |
| assertErr: assertErrNotNil, |
| }, |
| "plugin shut down": { |
| client: fakeClient{ |
| deleteUserErr: errors.New("delete user error'"), |
| }, |
| req: DeleteUserRequest{ |
| Username: "user", |
| }, |
| doneCtx: cancelledCtx, |
| assertErr: assertErrEquals(ErrPluginShutdown), |
| }, |
| "happy path": { |
| client: fakeClient{}, |
| req: DeleteUserRequest{ |
| Username: "user", |
| }, |
| doneCtx: runningCtx, |
| assertErr: assertErrNil, |
| }, |
| } |
| |
| for name, test := range tests { |
| t.Run(name, func(t *testing.T) { |
| c := gRPCClient{ |
| client: test.client, |
| doneCtx: test.doneCtx, |
| } |
| |
| ctx := context.Background() |
| |
| _, err := c.DeleteUser(ctx, test.req) |
| test.assertErr(t, err) |
| }) |
| } |
| } |
| |
| func TestGRPCClient_Type(t *testing.T) { |
| runningCtx := context.Background() |
| cancelledCtx, cancel := context.WithCancel(context.Background()) |
| cancel() |
| |
| type testCase struct { |
| client proto.DatabaseClient |
| doneCtx context.Context |
| expectedType string |
| assertErr errorAssertion |
| } |
| |
| tests := map[string]testCase{ |
| "database error": { |
| client: fakeClient{ |
| typeErr: errors.New("type error"), |
| }, |
| doneCtx: runningCtx, |
| assertErr: assertErrNotNil, |
| }, |
| "plugin shut down": { |
| client: fakeClient{ |
| typeErr: errors.New("type error"), |
| }, |
| doneCtx: cancelledCtx, |
| assertErr: assertErrEquals(ErrPluginShutdown), |
| }, |
| "happy path": { |
| client: fakeClient{ |
| typeResp: &proto.TypeResponse{ |
| Type: "test type", |
| }, |
| }, |
| doneCtx: runningCtx, |
| expectedType: "test type", |
| assertErr: assertErrNil, |
| }, |
| } |
| |
| for name, test := range tests { |
| t.Run(name, func(t *testing.T) { |
| c := gRPCClient{ |
| client: test.client, |
| doneCtx: test.doneCtx, |
| } |
| |
| dbType, err := c.Type() |
| test.assertErr(t, err) |
| |
| if dbType != test.expectedType { |
| t.Fatalf("Actual type: %s Expected type: %s", dbType, test.expectedType) |
| } |
| }) |
| } |
| } |
| |
| func TestGRPCClient_Close(t *testing.T) { |
| runningCtx := context.Background() |
| cancelledCtx, cancel := context.WithCancel(context.Background()) |
| cancel() |
| |
| type testCase struct { |
| client proto.DatabaseClient |
| doneCtx context.Context |
| assertErr errorAssertion |
| } |
| |
| tests := map[string]testCase{ |
| "database error": { |
| client: fakeClient{ |
| typeErr: errors.New("type error"), |
| }, |
| doneCtx: runningCtx, |
| assertErr: assertErrNotNil, |
| }, |
| "plugin shut down": { |
| client: fakeClient{ |
| typeErr: errors.New("type error"), |
| }, |
| doneCtx: cancelledCtx, |
| assertErr: assertErrEquals(ErrPluginShutdown), |
| }, |
| "happy path": { |
| client: fakeClient{}, |
| doneCtx: runningCtx, |
| assertErr: assertErrNil, |
| }, |
| } |
| |
| for name, test := range tests { |
| t.Run(name, func(t *testing.T) { |
| c := gRPCClient{ |
| client: test.client, |
| doneCtx: test.doneCtx, |
| } |
| |
| err := c.Close() |
| test.assertErr(t, err) |
| }) |
| } |
| } |
| |
| type errorAssertion func(*testing.T, error) |
| |
| func assertErrNotNil(t *testing.T, err error) { |
| t.Helper() |
| if err == nil { |
| t.Fatalf("err expected, got nil") |
| } |
| } |
| |
| func assertErrNil(t *testing.T, err error) { |
| t.Helper() |
| if err != nil { |
| t.Fatalf("no error expected, got: %s", err) |
| } |
| } |
| |
| func assertErrEquals(expectedErr error) errorAssertion { |
| return func(t *testing.T, err error) { |
| t.Helper() |
| if err != expectedErr { |
| t.Fatalf("Actual err: %#v Expected err: %#v", err, expectedErr) |
| } |
| } |
| } |
| |
| var _ proto.DatabaseClient = fakeClient{} |
| |
| type fakeClient struct { |
| initResp *proto.InitializeResponse |
| initErr error |
| |
| newUserResp *proto.NewUserResponse |
| newUserErr error |
| |
| updateUserResp *proto.UpdateUserResponse |
| updateUserErr error |
| |
| deleteUserResp *proto.DeleteUserResponse |
| deleteUserErr error |
| |
| typeResp *proto.TypeResponse |
| typeErr error |
| |
| closeErr error |
| } |
| |
| func (f fakeClient) Initialize(context.Context, *proto.InitializeRequest, ...grpc.CallOption) (*proto.InitializeResponse, error) { |
| return f.initResp, f.initErr |
| } |
| |
| func (f fakeClient) NewUser(context.Context, *proto.NewUserRequest, ...grpc.CallOption) (*proto.NewUserResponse, error) { |
| return f.newUserResp, f.newUserErr |
| } |
| |
| func (f fakeClient) UpdateUser(context.Context, *proto.UpdateUserRequest, ...grpc.CallOption) (*proto.UpdateUserResponse, error) { |
| return f.updateUserResp, f.updateUserErr |
| } |
| |
| func (f fakeClient) DeleteUser(context.Context, *proto.DeleteUserRequest, ...grpc.CallOption) (*proto.DeleteUserResponse, error) { |
| return f.deleteUserResp, f.deleteUserErr |
| } |
| |
| func (f fakeClient) Type(context.Context, *proto.Empty, ...grpc.CallOption) (*proto.TypeResponse, error) { |
| return f.typeResp, f.typeErr |
| } |
| |
| func (f fakeClient) Close(context.Context, *proto.Empty, ...grpc.CallOption) (*proto.Empty, error) { |
| return &proto.Empty{}, f.typeErr |
| } |