| // Copyright (c) HashiCorp, Inc. |
| // SPDX-License-Identifier: MPL-2.0 |
| |
| package testing |
| |
| import ( |
| _ "embed" |
| "encoding/json" |
| "fmt" |
| "io/ioutil" |
| "net/http" |
| "net/http/httptest" |
| "os" |
| "path" |
| "strings" |
| "sync" |
| "testing" |
| |
| "go.uber.org/atomic" |
| ) |
| |
// Identity of the single pod the test server knows about. Requests that
// name any other namespace or pod are answered with a 404.
const (
	// ExpectedNamespace is the namespace the test server serves.
	ExpectedNamespace = "default"

	// ExpectedPodName is the pod name the test server serves.
	ExpectedPodName = "shell-demo"
)
| |
| // Pull real-life-based testing data in from files at compile time. |
| // We decided to embed them in the test binary because of past issues |
| // with reading files that we encountered on CI workers. |
| |
| //go:embed ca.crt |
| var caCrt string |
| |
| //go:embed resp-get-pod.json |
| var getPodResponse string |
| |
| //go:embed resp-not-found.json |
| var notFoundResponse string |
| |
| //go:embed resp-update-pod.json |
| var updatePodTagsResponse string |
| |
| //go:embed token |
| var token string |
| |
| var ( |
| // ReturnGatewayTimeouts toggles whether the test server should return, |
| // well, gateway timeouts... |
| ReturnGatewayTimeouts = atomic.NewBool(false) |
| |
| pathToFiles = func() string { |
| wd, _ := os.Getwd() |
| repoName := "vault-enterprise" |
| if !strings.Contains(wd, repoName) { |
| repoName = "vault" |
| } |
| pathParts := strings.Split(wd, repoName) |
| return pathParts[0] + "vault/serviceregistration/kubernetes/testing/" |
| }() |
| ) |
| |
// Conf carries everything a caller needs to point a client at the test
// server. Applying it is left to the caller so that the client and the
// test server do not import each other. Example usage:
//
//	client.Scheme = testConf.ClientScheme
//	client.TokenFile = testConf.PathToTokenFile
//	client.RootCAFile = testConf.PathToRootCAFile
//	if err := os.Setenv(client.EnvVarKubernetesServiceHost, testConf.ServiceHost); err != nil {
//		t.Fatal(err)
//	}
//	if err := os.Setenv(client.EnvVarKubernetesServicePort, testConf.ServicePort); err != nil {
//		t.Fatal(err)
//	}
type Conf struct {
	// ClientScheme is the URL scheme prefix, including "://".
	ClientScheme string
	// PathToTokenFile is the path of the temporary file holding the token.
	PathToTokenFile string
	// PathToRootCAFile is the path of the temporary file holding the CA
	// certificate.
	PathToRootCAFile string
	// ServiceHost is the host part of the test server's URL.
	ServiceHost string
	// ServicePort is the port part of the test server's URL.
	ServicePort string
}
| |
| // Server returns an http test server that can be used to test |
| // Kubernetes client code. It also retains the current state, |
| // and a func to close the server and to clean up any temporary |
| // files. |
| func Server(t *testing.T) (testState *State, testConf *Conf, closeFunc func()) { |
| testState = &State{m: &sync.Map{}} |
| testConf = &Conf{ |
| ClientScheme: "http://", |
| } |
| |
| // We're going to have multiple close funcs to call. |
| var closers []func() |
| closeFunc = func() { |
| for _, closer := range closers { |
| closer() |
| } |
| } |
| |
| // Plant our token in a place where it can be read for the config. |
| tmpToken, err := ioutil.TempFile("", "token") |
| if err != nil { |
| t.Fatal(err) |
| } |
| closers = append(closers, func() { |
| os.Remove(tmpToken.Name()) |
| }) |
| if _, err = tmpToken.WriteString(token); err != nil { |
| closeFunc() |
| t.Fatal(err) |
| } |
| if err := tmpToken.Close(); err != nil { |
| closeFunc() |
| t.Fatal(err) |
| } |
| testConf.PathToTokenFile = tmpToken.Name() |
| |
| tmpCACrt, err := ioutil.TempFile("", "ca.crt") |
| if err != nil { |
| closeFunc() |
| t.Fatal(err) |
| } |
| closers = append(closers, func() { |
| os.Remove(tmpCACrt.Name()) |
| }) |
| if _, err = tmpCACrt.WriteString(caCrt); err != nil { |
| closeFunc() |
| t.Fatal(err) |
| } |
| if err := tmpCACrt.Close(); err != nil { |
| closeFunc() |
| t.Fatal(err) |
| } |
| testConf.PathToRootCAFile = tmpCACrt.Name() |
| |
| ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| if ReturnGatewayTimeouts.Load() { |
| w.WriteHeader(504) |
| return |
| } |
| namespace, podName, err := parsePath(r.URL.Path) |
| if err != nil { |
| w.WriteHeader(400) |
| w.Write([]byte(fmt.Sprintf("unable to parse %s: %s", r.URL.Path, err.Error()))) |
| return |
| } |
| |
| switch { |
| case namespace != ExpectedNamespace, podName != ExpectedPodName: |
| w.WriteHeader(404) |
| w.Write([]byte(notFoundResponse)) |
| return |
| case r.Method == http.MethodGet: |
| w.WriteHeader(200) |
| w.Write([]byte(getPodResponse)) |
| return |
| case r.Method == http.MethodPatch: |
| var patches []interface{} |
| if err := json.NewDecoder(r.Body).Decode(&patches); err != nil { |
| w.WriteHeader(400) |
| w.Write([]byte(fmt.Sprintf("unable to decode patches %s: %s", r.URL.Path, err.Error()))) |
| return |
| } |
| for _, patch := range patches { |
| patchMap := patch.(map[string]interface{}) |
| p := patchMap["path"].(string) |
| testState.store(p, patchMap) |
| } |
| w.WriteHeader(200) |
| w.Write([]byte(updatePodTagsResponse)) |
| return |
| default: |
| w.WriteHeader(400) |
| w.Write([]byte(fmt.Sprintf("unexpected request method: %s", r.Method))) |
| } |
| })) |
| closers = append(closers, ts.Close) |
| |
| // ts.URL example: http://127.0.0.1:35681 |
| urlFields := strings.Split(ts.URL, "://") |
| if len(urlFields) != 2 { |
| closeFunc() |
| t.Fatal("received unexpected test url: " + ts.URL) |
| } |
| urlFields = strings.Split(urlFields[1], ":") |
| if len(urlFields) != 2 { |
| closeFunc() |
| t.Fatal("received unexpected test url: " + ts.URL) |
| } |
| testConf.ServiceHost = urlFields[0] |
| testConf.ServicePort = urlFields[1] |
| return testState, testConf, closeFunc |
| } |
| |
// State holds the patches recorded by the test server. Each patch is kept
// in a sync.Map under the key it was stored with; storing again under the
// same key replaces the earlier patch.
type State struct {
	m *sync.Map
}

// NumPatches returns the number of entries currently held in the state.
func (s *State) NumPatches() int {
	count := 0
	s.m.Range(func(_, _ interface{}) bool {
		count++
		return true // keep iterating so every entry is counted
	})
	return count
}

// Get returns the patch stored under key. It returns nil when nothing is
// stored under key or when the stored value is not a
// map[string]interface{}.
func (s *State) Get(key string) map[string]interface{} {
	if stored, found := s.m.Load(key); found {
		if patch, isPatch := stored.(map[string]interface{}); isPatch {
			return patch
		}
	}
	return nil
}

// store records patch p under key k, replacing any previous value.
func (s *State) store(k string, p map[string]interface{}) {
	s.m.Store(k, p)
}
| |
// parsePath extracts the namespace and pod name from a pod URL path.
// The path should be formatted like this:
//
//	fmt.Sprintf("/api/v1/namespaces/%s/pods/%s", namespace, podName)
//
// Any other layout yields an error that quotes the offending path, and
// both returned names are empty.
func parsePath(urlPath string) (namespace, podName string, err error) {
	pod := path.Base(urlPath)
	ns := path.Base(strings.TrimSuffix(urlPath, "/pods/"+pod))

	// Rebuild the canonical path from the two candidates; only an exact
	// match shows that the input had the expected layout.
	if canonical := "/api/v1/namespaces/" + ns + "/pods/" + pod; urlPath != canonical {
		return "", "", fmt.Errorf("received unexpected path: %s", urlPath)
	}
	return ns, pod, nil
}
| |
| func readFile(fileName string) (string, error) { |
| b, err := ioutil.ReadFile(pathToFiles + fileName) |
| if err != nil { |
| return "", err |
| } |
| return string(b), nil |
| } |