| // Copyright (c) HashiCorp, Inc. |
| // SPDX-License-Identifier: MPL-2.0 |
| |
| package configutil |
| |
| import ( |
| "fmt" |
| "net/textproto" |
| "strconv" |
| "strings" |
| |
| "github.com/hashicorp/go-secure-stdlib/strutil" |
| ) |
| |
| var ValidCustomStatusCodeCollection = []string{ |
| "default", |
| "1xx", |
| "2xx", |
| "3xx", |
| "4xx", |
| "5xx", |
| } |
| |
| const StrictTransportSecurity = "max-age=31536000; includeSubDomains" |
| |
| // ParseCustomResponseHeaders takes a raw config values for the |
| // "custom_response_headers". It makes sure the config entry is passed in |
| // as a map of status code to a map of header name and header values. It |
| // verifies the validity of the status codes, and header values. It also |
| // adds the default headers values. |
| func ParseCustomResponseHeaders(responseHeaders interface{}) (map[string]map[string]string, error) { |
| h := make(map[string]map[string]string) |
| // if r is nil, we still should set the default custom headers |
| if responseHeaders == nil { |
| h["default"] = map[string]string{"Strict-Transport-Security": StrictTransportSecurity} |
| return h, nil |
| } |
| |
| customResponseHeader, ok := responseHeaders.([]map[string]interface{}) |
| if !ok { |
| return nil, fmt.Errorf("response headers were not configured correctly. please make sure they're in a slice of maps") |
| } |
| |
| for _, crh := range customResponseHeader { |
| for statusCode, responseHeader := range crh { |
| headerValList, ok := responseHeader.([]map[string]interface{}) |
| if !ok { |
| return nil, fmt.Errorf("response headers were not configured correctly. please make sure they're in a slice of maps") |
| } |
| |
| if !IsValidStatusCode(statusCode) { |
| return nil, fmt.Errorf("invalid status code found in the server configuration: %v", statusCode) |
| } |
| |
| if len(headerValList) != 1 { |
| return nil, fmt.Errorf("invalid number of response headers exist") |
| } |
| headerValMap := headerValList[0] |
| headerVal, err := parseHeaders(headerValMap) |
| if err != nil { |
| return nil, err |
| } |
| |
| h[statusCode] = headerVal |
| } |
| } |
| |
| // setting Strict-Transport-Security as a default header |
| if h["default"] == nil { |
| h["default"] = make(map[string]string) |
| } |
| if _, ok := h["default"]["Strict-Transport-Security"]; !ok { |
| h["default"]["Strict-Transport-Security"] = StrictTransportSecurity |
| } |
| |
| return h, nil |
| } |
| |
| // IsValidStatusCode checking for status codes outside the boundary |
| func IsValidStatusCode(sc string) bool { |
| if strutil.StrListContains(ValidCustomStatusCodeCollection, sc) { |
| return true |
| } |
| |
| i, err := strconv.Atoi(sc) |
| if err != nil { |
| return false |
| } |
| |
| if i >= 600 || i < 100 { |
| return false |
| } |
| |
| return true |
| } |
| |
| func parseHeaders(in map[string]interface{}) (map[string]string, error) { |
| hvMap := make(map[string]string) |
| for k, v := range in { |
| // parsing header name |
| headerName := textproto.CanonicalMIMEHeaderKey(k) |
| // parsing header values |
| s, err := parseHeaderValues(v) |
| if err != nil { |
| return nil, err |
| } |
| hvMap[headerName] = s |
| } |
| return hvMap, nil |
| } |
| |
| func parseHeaderValues(header interface{}) (string, error) { |
| var sl []string |
| if _, ok := header.([]interface{}); !ok { |
| return "", fmt.Errorf("headers must be given in a list of strings") |
| } |
| headerValList := header.([]interface{}) |
| for _, vh := range headerValList { |
| if _, ok := vh.(string); !ok { |
| return "", fmt.Errorf("found a non-string header value: %v", vh) |
| } |
| headerVal := strings.TrimSpace(vh.(string)) |
| if headerVal == "" { |
| continue |
| } |
| sl = append(sl, headerVal) |
| |
| } |
| s := strings.Join(sl, "; ") |
| |
| return s, nil |
| } |