| // Copyright (c) HashiCorp, Inc. |
| // SPDX-License-Identifier: MPL-2.0 |
| |
| package framework |
| |
| import ( |
| "encoding/base64" |
| "encoding/json" |
| "errors" |
| "fmt" |
| "net/http" |
| "regexp" |
| "strings" |
| |
| "github.com/hashicorp/errwrap" |
| "github.com/hashicorp/go-secure-stdlib/parseutil" |
| "github.com/hashicorp/go-secure-stdlib/strutil" |
| "github.com/hashicorp/vault/sdk/helper/jsonutil" |
| "github.com/mitchellh/mapstructure" |
| ) |
| |
| // FieldData is the structure passed to the callback to handle a path |
| // containing the populated parameters for fields. This should be used |
| // instead of the raw (*vault.Request).Data to access data in a type-safe |
| // way. |
| type FieldData struct { |
| Raw map[string]interface{} |
| Schema map[string]*FieldSchema |
| } |
| |
| // Validate cycles through raw data and validates conversions in |
| // the schema, so we don't get an error/panic later when |
| // trying to get data out. Data not in the schema is not |
| // an error at this point, so we don't worry about it. |
| func (d *FieldData) Validate() error { |
| for field, value := range d.Raw { |
| |
| schema, ok := d.Schema[field] |
| if !ok { |
| continue |
| } |
| |
| switch schema.Type { |
| case TypeBool, TypeInt, TypeInt64, TypeMap, TypeDurationSecond, TypeSignedDurationSecond, TypeString, |
| TypeLowerCaseString, TypeNameString, TypeSlice, TypeStringSlice, TypeCommaStringSlice, |
| TypeKVPairs, TypeCommaIntSlice, TypeHeader, TypeFloat, TypeTime: |
| _, _, err := d.getPrimitive(field, schema) |
| if err != nil { |
| return errwrap.Wrapf(fmt.Sprintf("error converting input %v for field %q: {{err}}", value, field), err) |
| } |
| default: |
| return fmt.Errorf("unknown field type %q for field %q", schema.Type, field) |
| } |
| } |
| |
| return nil |
| } |
| |
| // ValidateStrict cycles through raw data and validates conversions in the |
| // schema. In addition to the checks done by Validate, this function ensures |
| // that the raw data has all of the schema's required fields and does not |
| // have any fields outside of the schema. It will return a non-nil error if: |
| // |
| // 1. a conversion (parsing of the field's value) fails |
| // 2. a raw field does not exist in the schema (unless the schema is nil) |
| // 3. a required schema field is missing from the raw data |
| // |
| // This function is currently used for validating response schemas in tests. |
| func (d *FieldData) ValidateStrict() error { |
| // the schema is nil, nothing to validate |
| if d.Schema == nil { |
| return nil |
| } |
| |
| for field := range d.Raw { |
| if _, _, err := d.GetOkErr(field); err != nil { |
| return fmt.Errorf("field %q: %w", field, err) |
| } |
| } |
| |
| for field, schema := range d.Schema { |
| if !schema.Required { |
| continue |
| } |
| if _, ok := d.Raw[field]; !ok { |
| return fmt.Errorf("missing required field %q", field) |
| } |
| } |
| |
| return nil |
| } |
| |
| // Get gets the value for the given field. If the key is an invalid field, |
| // FieldData will panic. If you want a safer version of this method, use |
| // GetOk. If the field k is not set, the default value (if set) will be |
| // returned, otherwise the zero value will be returned. |
| func (d *FieldData) Get(k string) interface{} { |
| schema, ok := d.Schema[k] |
| if !ok { |
| panic(fmt.Sprintf("field %s not in the schema", k)) |
| } |
| |
| // If the value can't be decoded, use the zero or default value for the field |
| // type |
| value, ok := d.GetOk(k) |
| if !ok || value == nil { |
| value = schema.DefaultOrZero() |
| } |
| |
| return value |
| } |
| |
| // GetDefaultOrZero gets the default value set on the schema for the given |
| // field. If there is no default value set, the zero value of the type |
| // will be returned. |
| func (d *FieldData) GetDefaultOrZero(k string) interface{} { |
| schema, ok := d.Schema[k] |
| if !ok { |
| panic(fmt.Sprintf("field %s not in the schema", k)) |
| } |
| |
| return schema.DefaultOrZero() |
| } |
| |
| // GetFirst gets the value for the given field names, in order from first |
| // to last. This can be useful for fields with a current name, and one or |
| // more deprecated names. The second return value will be false if the keys |
| // are invalid or the keys are not set at all. |
| func (d *FieldData) GetFirst(k ...string) (interface{}, bool) { |
| for _, v := range k { |
| if result, ok := d.GetOk(v); ok { |
| return result, ok |
| } |
| } |
| return nil, false |
| } |
| |
| // GetOk gets the value for the given field. The second return value will be |
| // false if the key is invalid or the key is not set at all. If the field k is |
| // set and the decoded value is nil, the default or zero value |
| // will be returned instead. |
| func (d *FieldData) GetOk(k string) (interface{}, bool) { |
| schema, ok := d.Schema[k] |
| if !ok { |
| return nil, false |
| } |
| |
| result, ok, err := d.GetOkErr(k) |
| if err != nil { |
| panic(fmt.Sprintf("error reading %s: %s", k, err)) |
| } |
| |
| if ok && result == nil { |
| result = schema.DefaultOrZero() |
| } |
| |
| return result, ok |
| } |
| |
| // GetOkErr is the most conservative of all the Get methods. It returns |
| // whether key is set or not, but also an error value. The error value is |
| // non-nil if the field doesn't exist or there was an error parsing the |
| // field value. |
| func (d *FieldData) GetOkErr(k string) (interface{}, bool, error) { |
| schema, ok := d.Schema[k] |
| if !ok { |
| return nil, false, fmt.Errorf("unknown field: %q", k) |
| } |
| |
| switch schema.Type { |
| case TypeBool, TypeInt, TypeInt64, TypeMap, TypeDurationSecond, TypeSignedDurationSecond, TypeString, |
| TypeLowerCaseString, TypeNameString, TypeSlice, TypeStringSlice, TypeCommaStringSlice, |
| TypeKVPairs, TypeCommaIntSlice, TypeHeader, TypeFloat, TypeTime: |
| return d.getPrimitive(k, schema) |
| default: |
| return nil, false, |
| fmt.Errorf("unknown field type %q for field %q", schema.Type, k) |
| } |
| } |
| |
| func (d *FieldData) getPrimitive(k string, schema *FieldSchema) (interface{}, bool, error) { |
| raw, ok := d.Raw[k] |
| if !ok { |
| return nil, false, nil |
| } |
| |
| switch t := schema.Type; t { |
| case TypeBool: |
| var result bool |
| if err := mapstructure.WeakDecode(raw, &result); err != nil { |
| return nil, false, err |
| } |
| return result, true, nil |
| |
| case TypeInt: |
| var result int |
| if err := mapstructure.WeakDecode(raw, &result); err != nil { |
| return nil, false, err |
| } |
| return result, true, nil |
| |
| case TypeInt64: |
| var result int64 |
| if err := mapstructure.WeakDecode(raw, &result); err != nil { |
| return nil, false, err |
| } |
| return result, true, nil |
| |
| case TypeFloat: |
| var result float64 |
| if err := mapstructure.WeakDecode(raw, &result); err != nil { |
| return nil, false, err |
| } |
| return result, true, nil |
| |
| case TypeString: |
| var result string |
| if err := mapstructure.WeakDecode(raw, &result); err != nil { |
| return nil, false, err |
| } |
| return result, true, nil |
| |
| case TypeLowerCaseString: |
| var result string |
| if err := mapstructure.WeakDecode(raw, &result); err != nil { |
| return nil, false, err |
| } |
| return strings.ToLower(result), true, nil |
| |
| case TypeNameString: |
| var result string |
| if err := mapstructure.WeakDecode(raw, &result); err != nil { |
| return nil, false, err |
| } |
| matched, err := regexp.MatchString("^\\w(([\\w-.]+)?\\w)?$", result) |
| if err != nil { |
| return nil, false, err |
| } |
| if !matched { |
| return nil, false, errors.New("field does not match the formatting rules") |
| } |
| return result, true, nil |
| |
| case TypeMap: |
| var result map[string]interface{} |
| if err := mapstructure.WeakDecode(raw, &result); err != nil { |
| return nil, false, err |
| } |
| return result, true, nil |
| |
| case TypeDurationSecond, TypeSignedDurationSecond: |
| var result int |
| switch inp := raw.(type) { |
| case nil: |
| return nil, false, nil |
| default: |
| dur, err := parseutil.ParseDurationSecond(inp) |
| if err != nil { |
| return nil, false, err |
| } |
| result = int(dur.Seconds()) |
| } |
| if t == TypeDurationSecond && result < 0 { |
| return nil, false, fmt.Errorf("cannot provide negative value '%d'", result) |
| } |
| return result, true, nil |
| |
| case TypeTime: |
| switch inp := raw.(type) { |
| case nil: |
| // Handle nil interface{} as a non-error case |
| return nil, false, nil |
| default: |
| time, err := parseutil.ParseAbsoluteTime(inp) |
| if err != nil { |
| return nil, false, err |
| } |
| return time.UTC(), true, nil |
| } |
| |
| case TypeCommaIntSlice: |
| var result []int |
| |
| jsonIn, ok := raw.(json.Number) |
| if ok { |
| raw = jsonIn.String() |
| } |
| |
| config := &mapstructure.DecoderConfig{ |
| Result: &result, |
| WeaklyTypedInput: true, |
| DecodeHook: mapstructure.StringToSliceHookFunc(","), |
| } |
| decoder, err := mapstructure.NewDecoder(config) |
| if err != nil { |
| return nil, false, err |
| } |
| if err := decoder.Decode(raw); err != nil { |
| return nil, false, err |
| } |
| if len(result) == 0 { |
| return make([]int, 0), true, nil |
| } |
| return result, true, nil |
| |
| case TypeSlice: |
| var result []interface{} |
| if err := mapstructure.WeakDecode(raw, &result); err != nil { |
| return nil, false, err |
| } |
| if len(result) == 0 { |
| return make([]interface{}, 0), true, nil |
| } |
| return result, true, nil |
| |
| case TypeStringSlice: |
| rawString, ok := raw.(string) |
| if ok && rawString == "" { |
| return []string{}, true, nil |
| } |
| |
| var result []string |
| if err := mapstructure.WeakDecode(raw, &result); err != nil { |
| return nil, false, err |
| } |
| if len(result) == 0 { |
| return make([]string, 0), true, nil |
| } |
| return strutil.TrimStrings(result), true, nil |
| |
| case TypeCommaStringSlice: |
| res, err := parseutil.ParseCommaStringSlice(raw) |
| if err != nil { |
| return nil, false, err |
| } |
| return res, true, nil |
| |
| case TypeKVPairs: |
| // First try to parse this as a map |
| var mapResult map[string]string |
| if err := mapstructure.WeakDecode(raw, &mapResult); err == nil { |
| return mapResult, true, nil |
| } |
| |
| // If map parse fails, parse as a string list of = delimited pairs |
| var listResult []string |
| if err := mapstructure.WeakDecode(raw, &listResult); err != nil { |
| return nil, false, err |
| } |
| |
| result := make(map[string]string, len(listResult)) |
| for _, keyPair := range listResult { |
| keyPairSlice := strings.SplitN(keyPair, "=", 2) |
| if len(keyPairSlice) != 2 || keyPairSlice[0] == "" { |
| return nil, false, fmt.Errorf("invalid key pair %q", keyPair) |
| } |
| result[keyPairSlice[0]] = keyPairSlice[1] |
| } |
| return result, true, nil |
| |
| case TypeHeader: |
| /* |
| |
| There are multiple ways a header could be provided: |
| |
| 1. As a map[string]interface{} that resolves to a map[string]string or map[string][]string, or a mix of both |
| because that's permitted for headers. |
| This mainly comes from the API. |
| |
| 2. As a string... |
| a. That contains JSON that originally was JSON, but then was base64 encoded. |
| b. That contains JSON, ex. `{"content-type":"text/json","accept":["encoding/json"]}`. |
| This mainly comes from the API and is used to save space while sending in the header. |
| |
| 3. As an array of strings that contains comma-delimited key-value pairs associated via a colon, |
| ex: `content-type:text/json`,`accept:encoding/json`. |
| This mainly comes from the CLI. |
| |
| We go through these sequentially below. |
| |
| */ |
| result := http.Header{} |
| |
| toHeader := func(resultMap map[string]interface{}) (http.Header, error) { |
| header := http.Header{} |
| for headerKey, headerValGroup := range resultMap { |
| switch typedHeader := headerValGroup.(type) { |
| case string: |
| header.Add(headerKey, typedHeader) |
| case []string: |
| for _, headerVal := range typedHeader { |
| header.Add(headerKey, headerVal) |
| } |
| case json.Number: |
| header.Add(headerKey, typedHeader.String()) |
| case []interface{}: |
| for _, headerVal := range typedHeader { |
| switch typedHeader := headerVal.(type) { |
| case string: |
| header.Add(headerKey, typedHeader) |
| case json.Number: |
| header.Add(headerKey, typedHeader.String()) |
| default: |
| // All header values should already be strings when they're being sent in. |
| // Even numbers and booleans will be treated as strings. |
| return nil, fmt.Errorf("received non-string value for header key:%s, val:%s", headerKey, headerValGroup) |
| } |
| } |
| default: |
| return nil, fmt.Errorf("unrecognized type for %s", headerValGroup) |
| } |
| } |
| return header, nil |
| } |
| |
| resultMap := make(map[string]interface{}) |
| |
| // 1. Are we getting a map from the API? |
| if err := mapstructure.WeakDecode(raw, &resultMap); err == nil { |
| result, err = toHeader(resultMap) |
| if err != nil { |
| return nil, false, err |
| } |
| return result, true, nil |
| } |
| |
| // 2. Are we getting a JSON string? |
| if headerStr, ok := raw.(string); ok { |
| // a. Is it base64 encoded? |
| headerBytes, err := base64.StdEncoding.DecodeString(headerStr) |
| if err != nil { |
| // b. It's not base64 encoded, it's a straight-out JSON string. |
| headerBytes = []byte(headerStr) |
| } |
| if err := jsonutil.DecodeJSON(headerBytes, &resultMap); err != nil { |
| return nil, false, err |
| } |
| result, err = toHeader(resultMap) |
| if err != nil { |
| return nil, false, err |
| } |
| return result, true, nil |
| } |
| |
| // 3. Are we getting an array of fields like "content-type:encoding/json" from the CLI? |
| var keyPairs []interface{} |
| if err := mapstructure.WeakDecode(raw, &keyPairs); err == nil { |
| for _, keyPairIfc := range keyPairs { |
| keyPair, ok := keyPairIfc.(string) |
| if !ok { |
| return nil, false, fmt.Errorf("invalid key pair %q", keyPair) |
| } |
| keyPairSlice := strings.SplitN(keyPair, ":", 2) |
| if len(keyPairSlice) != 2 || keyPairSlice[0] == "" { |
| return nil, false, fmt.Errorf("invalid key pair %q", keyPair) |
| } |
| result.Add(keyPairSlice[0], keyPairSlice[1]) |
| } |
| return result, true, nil |
| } |
| return nil, false, fmt.Errorf("%s not provided an expected format", raw) |
| |
| default: |
| panic(fmt.Sprintf("Unknown type: %s", schema.Type)) |
| } |
| } |