blob: 5e65c3467068bd48a5f9d0a73bd474dc9c23f510 [file] [log] [blame]
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package dbplugin
import (
"fmt"
"reflect"
"strings"
"testing"
"time"
"unicode"
"github.com/hashicorp/vault/sdk/database/dbplugin/v5/proto"
"google.golang.org/protobuf/types/known/structpb"
"google.golang.org/protobuf/types/known/timestamppb"
)
// TestConversionsHaveAllFields verifies that the request conversion helpers
// carry every field across. Each subtest builds a fully-populated input,
// converts it, and fails if any field of the result is zero/empty, which would
// mean the conversion dropped that field.
func TestConversionsHaveAllFields(t *testing.T) {
	t.Run("initReqToProto", func(t *testing.T) {
		req := InitializeRequest{
			Config: map[string]interface{}{
				"foo": map[string]interface{}{
					"bar": "baz",
				},
			},
			VerifyConnection: true,
		}

		protoReq, err := initReqToProto(req)
		if err != nil {
			t.Fatalf("Failed to convert request to proto request: %s", err)
		}
		assertProtoGettersSet(t, "InitializeRequest", protoReq)
	})

	t.Run("newUserReqToProto", func(t *testing.T) {
		req := NewUserRequest{
			UsernameConfig: UsernameMetadata{
				DisplayName: "dispName",
				RoleName:    "roleName",
			},
			Statements: Statements{
				Commands: []string{
					"statement",
				},
			},
			RollbackStatements: Statements{
				Commands: []string{
					"rollback_statement",
				},
			},
			CredentialType: CredentialTypeRSAPrivateKey,
			PublicKey:      []byte("-----BEGIN PUBLIC KEY-----"),
			Password:       "password",
			Subject:        "subject",
			Expiration:     time.Now(),
		}

		protoReq, err := newUserReqToProto(req)
		if err != nil {
			t.Fatalf("Failed to convert request to proto request: %s", err)
		}
		assertProtoGettersSet(t, "NewUserRequest", protoReq)
	})

	t.Run("updateUserReqToProto", func(t *testing.T) {
		req := UpdateUserRequest{
			Username:       "username",
			CredentialType: CredentialTypeRSAPrivateKey,
			Password: &ChangePassword{
				NewPassword: "newpassword",
				Statements: Statements{
					Commands: []string{
						"statement",
					},
				},
			},
			PublicKey: &ChangePublicKey{
				NewPublicKey: []byte("-----BEGIN PUBLIC KEY-----"),
				Statements: Statements{
					Commands: []string{
						"statement",
					},
				},
			},
			Expiration: &ChangeExpiration{
				NewExpiration: time.Now(),
				Statements: Statements{
					Commands: []string{
						"statement",
					},
				},
			},
		}

		protoReq, err := updateUserReqToProto(req)
		if err != nil {
			t.Fatalf("Failed to convert request to proto request: %s", err)
		}
		assertProtoGettersSet(t, "UpdateUserRequest", protoReq)
	})

	t.Run("deleteUserReqToProto", func(t *testing.T) {
		req := DeleteUserRequest{
			Username: "username",
			Statements: Statements{
				Commands: []string{
					"statement",
				},
			},
		}

		protoReq, err := deleteUserReqToProto(req)
		if err != nil {
			t.Fatalf("Failed to convert request to proto request: %s", err)
		}
		assertProtoGettersSet(t, "DeleteUserRequest", protoReq)
	})

	// This direction is the reverse of the subtests above: the input is a
	// proto request and the result is checked directly (not via getters).
	t.Run("getUpdateUserRequest", func(t *testing.T) {
		req := &proto.UpdateUserRequest{
			Username:       "username",
			CredentialType: int32(CredentialTypeRSAPrivateKey),
			Password: &proto.ChangePassword{
				NewPassword: "newpass",
				Statements: &proto.Statements{
					Commands: []string{
						"statement",
					},
				},
			},
			PublicKey: &proto.ChangePublicKey{
				NewPublicKey: []byte("-----BEGIN PUBLIC KEY-----"),
				Statements: &proto.Statements{
					Commands: []string{
						"statement",
					},
				},
			},
			Expiration: &proto.ChangeExpiration{
				NewExpiration: timestamppb.Now(),
				Statements: &proto.Statements{
					Commands: []string{
						"statement",
					},
				},
			},
		}

		updateReq, err := getUpdateUserRequest(req)
		if err != nil {
			t.Fatalf("Failed to convert proto request to request: %s", err)
		}
		if err := assertAllFieldsSet("UpdateUserRequest", updateReq); err != nil {
			t.Fatalf("%s", err)
		}
	})
}

// assertProtoGettersSet calls every Get* method on protoReq (via
// getAllGetterValues) and fails the test if there are none, or if any returned
// value is not fully populated according to assertAllFieldsSet. typeName is
// used as the prefix of the reported field path.
func assertProtoGettersSet(t *testing.T, typeName string, protoReq interface{}) {
	t.Helper()

	values := getAllGetterValues(protoReq)
	if len(values) == 0 {
		// Probably a test failure - the protos used in these tests should have Get functions on them
		t.Fatalf("No values found from Get functions!")
	}
	for _, gtr := range values {
		if err := assertAllFieldsSet(fmt.Sprintf("%s.%s", typeName, gtr.name), gtr.value); err != nil {
			t.Fatalf("%s", err)
		}
	}
}
// getter records one result collected by getAllGetterValues: the name of a
// Get* method with its "Get" prefix removed, and the value it returned.
type getter struct {
	name  string      // method name without the leading "Get"
	value interface{} // whatever the method returned
}
// getAllGetterValues calls every exported method on value whose name starts
// with "Get" and which takes no arguments and returns exactly one result. It
// returns one getter per such method, holding the method name with the "Get"
// prefix trimmed and the returned value, in reflect's method order (sorted by
// name). A nil value yields no getters.
func getAllGetterValues(value interface{}) (values []getter) {
	if value == nil {
		// reflect.TypeOf(nil) is nil; calling NumMethod on it would panic.
		return nil
	}

	typ := reflect.TypeOf(value)
	val := reflect.ValueOf(value)
	for i := 0; i < typ.NumMethod(); i++ {
		method := typ.Method(i)
		if !strings.HasPrefix(method.Name, "Get") {
			continue
		}
		// For a concrete type, method.Type includes the receiver as its first
		// input, so a zero-argument getter has exactly one input. Anything
		// else would panic in Call(nil) or when reading resp[0].
		if method.Type.NumIn() != 1 || method.Type.NumOut() != 1 {
			continue
		}

		resp := val.Method(i).Call(nil)
		values = append(values, getter{
			name:  strings.TrimPrefix(method.Name, "Get"),
			value: resp[0].Interface(),
		})
	}
	return values
}
// TestAssertAllFieldsSet pins the behavior of assertAllFieldsSet, which the
// conversion tests rely on: zero values, nil/empty containers and containers
// holding zero elements are errors, booleans are always accepted, and
// protobuf Timestamps are checked as a single value.
func TestAssertAllFieldsSet(t *testing.T) {
	type testCase struct {
		value     interface{}
		expectErr bool
	}

	tests := map[string]testCase{
		"zero int": {
			value:     0,
			expectErr: true,
		},
		"non-zero int": {
			value:     1,
			expectErr: false,
		},
		"zero float64": {
			value:     0.0,
			expectErr: true,
		},
		"non-zero float64": {
			value:     1.0,
			expectErr: false,
		},
		"empty string": {
			value:     "",
			expectErr: true,
		},
		"true boolean": {
			value:     true,
			expectErr: false,
		},
		"false boolean": { // False is an exception to the "is zero" rule
			value:     false,
			expectErr: false,
		},
		"blank struct": {
			value:     struct{}{},
			expectErr: true,
		},
		"non-blank but empty struct": {
			value: struct {
				str string
			}{
				str: "",
			},
			expectErr: true,
		},
		"non-empty string": {
			value:     "foo",
			expectErr: false,
		},
		"non-empty struct": {
			value: struct {
				str string
			}{
				str: "foo",
			},
			expectErr: false,
		},
		"empty nested struct": {
			value: struct {
				Str       string
				Substruct struct {
					Substr string
				}
			}{
				Str: "foo",
				Substruct: struct {
					Substr string
				}{}, // Empty sub-field
			},
			expectErr: true,
		},
		"filled nested struct": {
			value: struct {
				str       string
				substruct struct {
					substr string
				}
			}{
				str: "foo",
				substruct: struct {
					substr string
				}{
					substr: "sub-foo",
				},
			},
			expectErr: false,
		},
		"nil map": {
			value:     map[string]string(nil),
			expectErr: true,
		},
		"empty map": {
			value:     map[string]string{},
			expectErr: true,
		},
		"filled map": {
			value: map[string]string{
				"foo": "bar",
				"int": "42",
			},
			expectErr: false,
		},
		"map with empty string value": {
			value: map[string]string{
				"foo": "",
			},
			expectErr: true,
		},
		"nested map with empty string value": {
			value: map[string]interface{}{
				"bar": "baz",
				"foo": map[string]interface{}{
					"subfoo": "",
				},
			},
			expectErr: true,
		},
		"map with nil interface value": {
			value: map[string]interface{}{
				"foo": nil,
			},
			expectErr: true,
		},
		"map with false interface value": { // booleans are accepted even behind an interface
			value: map[string]interface{}{
				"foo": false,
			},
			expectErr: false,
		},
		"nil slice": {
			value:     []string(nil),
			expectErr: true,
		},
		"empty slice": {
			value:     []string{},
			expectErr: true,
		},
		"filled slice": {
			value: []string{
				"foo",
			},
			expectErr: false,
		},
		"slice with empty string value": {
			value: []string{
				"",
			},
			expectErr: true,
		},
		"empty structpb": {
			value:     newStructPb(t, map[string]interface{}{}),
			expectErr: true,
		},
		"filled structpb": {
			value: newStructPb(t, map[string]interface{}{
				"foo": "bar",
				"int": 42,
			}),
			expectErr: false,
		},
		"zero timestamppb": {
			value:     &timestamppb.Timestamp{},
			expectErr: true,
		},
		"timestamppb from zero time.Time": {
			value:     timestamppb.New(time.Time{}),
			expectErr: true,
		},
		"current timestamppb": {
			value:     timestamppb.Now(),
			expectErr: false,
		},
		"nil pointer": {
			value:     (*int)(nil),
			expectErr: true,
		},
		"pointer to zero int": {
			value:     intPtr(0),
			expectErr: true,
		},
		"pointer to non-zero int": {
			value:     intPtr(1),
			expectErr: false,
		},
		"pointer to zero float64": {
			value:     float64Ptr(0.0),
			expectErr: true,
		},
		"pointer to non-zero float64": {
			value:     float64Ptr(1.0),
			expectErr: false,
		},
		"pointer to empty string": {
			value:     new(string),
			expectErr: true,
		},
		"pointer to non-nil string": {
			value:     strPtr("foo"),
			expectErr: false,
		},
	}

	for name, test := range tests {
		t.Run(name, func(t *testing.T) {
			err := assertAllFieldsSet("", test.value)
			if test.expectErr && err == nil {
				t.Fatalf("err expected, got nil")
			}
			if !test.expectErr && err != nil {
				t.Fatalf("no error expected, got: %s", err)
			}
		})
	}
}
// assertAllFieldsSet returns an error if val, or anything reachable from it,
// is unset, as judged recursively by assertAllFieldsSetValue. name is the
// path of val and prefixes the field path in any returned error.
func assertAllFieldsSet(name string, val interface{}) error {
	if val == nil {
		// reflect.ValueOf(nil) is the invalid Value, on which IsZero panics,
		// so an untyped nil is rejected here. Report the path like every other
		// error from these helpers.
		return fmt.Errorf("%s is nil", name)
	}
	return assertAllFieldsSetValue(name, reflect.ValueOf(val))
}
// assertAllFieldsSetValue recursively checks that rVal and everything
// reachable from it is set. It returns an error naming the first unset path it
// finds; name is the path of rVal itself.
//
// Rules:
//   - booleans are always accepted;
//   - any other zero value (0, "", nil pointer/map/slice/interface, struct
//     with all fields zero) is rejected;
//   - pointers and interfaces are dereferenced;
//   - structs are delegated to assertAllFieldsSetStruct;
//   - maps, slices and arrays must be non-empty and every element must be
//     set, except byte slices/arrays, which only need to be non-empty.
func assertAllFieldsSetValue(name string, rVal reflect.Value) error {
	// All booleans are allowed - we don't have a way of differentiating between
	// an intentional false and a missing false
	if rVal.Kind() == reflect.Bool {
		return nil
	}

	// Primitives fall through here
	if rVal.IsZero() {
		return fmt.Errorf("%s is zero", name)
	}

	switch rVal.Kind() {
	case reflect.Ptr, reflect.Interface:
		return assertAllFieldsSetValue(name, rVal.Elem())
	case reflect.Struct:
		return assertAllFieldsSetStruct(name, rVal)
	case reflect.Map:
		if rVal.Len() == 0 {
			return fmt.Errorf("%s (map type) is empty", name)
		}
		iter := rVal.MapRange()
		for iter.Next() {
			// %v rather than %s so non-string keys render as their value.
			keyName := fmt.Sprintf("%s[%v]", name, iter.Key())
			if err := assertAllFieldsSetValue(keyName, iter.Value()); err != nil {
				return err
			}
		}
	case reflect.Slice, reflect.Array:
		if rVal.Len() == 0 {
			// Kind().String() is "slice" for slices, preserving the original message.
			return fmt.Errorf("%s (%s type) is empty", name, rVal.Kind())
		}
		// A byte string (e.g. an encoded key) may legitimately contain 0x00
		// bytes, so treat it as one opaque value instead of checking each byte.
		if rVal.Type().Elem().Kind() == reflect.Uint8 {
			return nil
		}
		for i := 0; i < rVal.Len(); i++ {
			if err := assertAllFieldsSetValue(fmt.Sprintf("%s[%d]", name, i), rVal.Index(i)); err != nil {
				return err
			}
		}
	}
	return nil
}
// assertAllFieldsSetStruct checks that the struct value rVal is fully
// populated and returns an error naming the first unset path otherwise.
//
// A protobuf Timestamp is treated as a single scalar. It is unset if it holds
// the proto3 default (0s, 0ns) or the encoding of Go's zero time.Time. Its
// Nanos field may legitimately be 0, so it must not be checked field by field.
// Any other struct must have every exported field set, as judged by
// assertAllFieldsSetValue. Unexported fields are skipped.
func assertAllFieldsSetStruct(name string, rVal reflect.Value) error {
	typ := rVal.Type()

	if typ == reflect.TypeOf((*timestamppb.Timestamp)(nil)).Elem() {
		// Read the fields through reflection instead of copying the message out
		// with Interface(): copying a generated proto message by value is
		// flagged by go vet's copylocks check.
		secs := rVal.FieldByName("Seconds").Int()
		nanos := rVal.FieldByName("Nanos").Int()
		if (secs == 0 && nanos == 0) || time.Unix(secs, nanos).IsZero() {
			return fmt.Errorf("%s is zero", name)
		}
		return nil
	}

	for i := 0; i < typ.NumField(); i++ {
		field := typ.Field(i)
		// Skip fields that aren't exported. An identifier is exported only if
		// its first rune is an upper-case letter; checking IsLower instead
		// would wrongly treat "_" and uncased letters as exported.
		if !unicode.IsUpper([]rune(field.Name)[0]) {
			continue
		}
		if err := assertAllFieldsSetValue(fmt.Sprintf("%s.%s", name, field.Name), rVal.Field(i)); err != nil {
			return err
		}
	}
	return nil
}
// intPtr returns a pointer to a newly allocated int holding i.
func intPtr(i int) *int {
	p := new(int)
	*p = i
	return p
}
// float64Ptr returns a pointer to a newly allocated float64 holding f.
func float64Ptr(f float64) *float64 {
	p := new(float64)
	*p = f
	return p
}
// strPtr returns a pointer to a newly allocated string holding str.
func strPtr(str string) *string {
	p := new(string)
	*p = str
	return p
}
// newStructPb converts m into a *structpb.Struct via structpb.NewStruct,
// failing the calling test immediately if the conversion returns an error.
func newStructPb(t *testing.T, m map[string]interface{}) *structpb.Struct {
	t.Helper()

	pb, convErr := structpb.NewStruct(m)
	if convErr == nil {
		return pb
	}
	t.Fatalf("Failed to convert map to struct: %s", convErr)
	return nil // unreachable: Fatalf stops the test goroutine
}