| package http |
| |
| //go:generate go run github.com/golang/mock/mockgen -package $GOPACKAGE -source $GOFILE -destination mock_$GOFILE |
| |
| import ( |
| "context" |
| "crypto/tls" |
| "crypto/x509" |
| "encoding/json" |
| "io" |
| "net/http" |
| "net/http/httptest" |
| "os" |
| "os/signal" |
| "path/filepath" |
| "reflect" |
| "strings" |
| "sync" |
| "syscall" |
| "testing" |
| |
| "github.com/golang/mock/gomock" |
| "github.com/hashicorp/terraform/internal/addrs" |
| "github.com/hashicorp/terraform/internal/backend" |
| "github.com/hashicorp/terraform/internal/configs" |
| "github.com/hashicorp/terraform/internal/states" |
| "github.com/zclconf/go-cty/cty" |
| ) |
| |
// sampleState is a minimal version-4 state document. newHttpServer stores it
// under the "sample" resource key so that tests start with existing remote
// state to read.
| const sampleState = ` |
| { |
| "version": 4, |
| "serial": 0, |
| "lineage": "666f9301-7e65-4b19-ae23-71184bb19b03", |
| "remote": { |
| "type": "http", |
| "config": { |
| "path": "local-state.tfstate" |
| } |
| } |
| } |
| ` |
| |
// HttpServerCallback observes the requests handled by the test server. Each
// state handler invokes the method matching its HTTP verb, so a mock
// implementation can assert which endpoints were reached.
type HttpServerCallback interface {
	StateGET(req *http.Request)
	StatePOST(req *http.Request)
	StateDELETE(req *http.Request)
	StateLOCK(req *http.Request)
	StateUNLOCK(req *http.Request)
}

// httpServer is an in-memory state server used by the tests in this file.
type httpServer struct {
	// r routes incoming requests to the state handlers.
	r *http.ServeMux
	// data maps a resource name to its stored state document.
	data map[string]string
	// locks maps a resource name to the lock info currently held on it.
	locks map[string]string
	// lock guards data and locks.
	lock sync.RWMutex

	// httpServerCallback, when non-nil, is notified by every state handler.
	httpServerCallback HttpServerCallback
}

// httpServerOpt customises an httpServer while it is being constructed.
type httpServerOpt func(*httpServer)
| |
| func withHttpServerCallback(callback HttpServerCallback) httpServerOpt { |
| return func(s *httpServer) { |
| s.httpServerCallback = callback |
| } |
| } |
| |
| func newHttpServer(opts ...httpServerOpt) *httpServer { |
| r := http.NewServeMux() |
| s := &httpServer{ |
| r: r, |
| data: make(map[string]string), |
| locks: make(map[string]string), |
| } |
| for _, opt := range opts { |
| opt(s) |
| } |
| s.data["sample"] = sampleState |
| r.HandleFunc("/state/", s.handleState) |
| return s |
| } |
| |
| func (h *httpServer) getResource(req *http.Request) string { |
| switch pathParts := strings.SplitN(req.URL.Path, string(filepath.Separator), 3); len(pathParts) { |
| case 3: |
| return pathParts[2] |
| default: |
| return "" |
| } |
| } |
| |
| func (h *httpServer) handleState(writer http.ResponseWriter, req *http.Request) { |
| switch req.Method { |
| case "GET": |
| h.handleStateGET(writer, req) |
| case "POST": |
| h.handleStatePOST(writer, req) |
| case "DELETE": |
| h.handleStateDELETE(writer, req) |
| case "LOCK": |
| h.handleStateLOCK(writer, req) |
| case "UNLOCK": |
| h.handleStateUNLOCK(writer, req) |
| } |
| } |
| |
| func (h *httpServer) handleStateGET(writer http.ResponseWriter, req *http.Request) { |
| if h.httpServerCallback != nil { |
| defer h.httpServerCallback.StateGET(req) |
| } |
| resource := h.getResource(req) |
| |
| h.lock.RLock() |
| defer h.lock.RUnlock() |
| |
| if state, ok := h.data[resource]; ok { |
| _, _ = io.WriteString(writer, state) |
| } else { |
| writer.WriteHeader(http.StatusNotFound) |
| } |
| } |
| |
| func (h *httpServer) handleStatePOST(writer http.ResponseWriter, req *http.Request) { |
| if h.httpServerCallback != nil { |
| defer h.httpServerCallback.StatePOST(req) |
| } |
| defer req.Body.Close() |
| resource := h.getResource(req) |
| |
| data, err := io.ReadAll(req.Body) |
| if err != nil { |
| writer.WriteHeader(http.StatusBadRequest) |
| return |
| } |
| |
| h.lock.Lock() |
| defer h.lock.Unlock() |
| |
| h.data[resource] = string(data) |
| writer.WriteHeader(http.StatusOK) |
| } |
| |
| func (h *httpServer) handleStateDELETE(writer http.ResponseWriter, req *http.Request) { |
| if h.httpServerCallback != nil { |
| defer h.httpServerCallback.StateDELETE(req) |
| } |
| resource := h.getResource(req) |
| |
| h.lock.Lock() |
| defer h.lock.Unlock() |
| |
| delete(h.data, resource) |
| writer.WriteHeader(http.StatusOK) |
| } |
| |
| func (h *httpServer) handleStateLOCK(writer http.ResponseWriter, req *http.Request) { |
| if h.httpServerCallback != nil { |
| defer h.httpServerCallback.StateLOCK(req) |
| } |
| defer req.Body.Close() |
| resource := h.getResource(req) |
| |
| data, err := io.ReadAll(req.Body) |
| if err != nil { |
| writer.WriteHeader(http.StatusBadRequest) |
| return |
| } |
| |
| h.lock.Lock() |
| defer h.lock.Unlock() |
| |
| if existingLock, ok := h.locks[resource]; ok { |
| writer.WriteHeader(http.StatusLocked) |
| _, _ = io.WriteString(writer, existingLock) |
| } else { |
| h.locks[resource] = string(data) |
| _, _ = io.WriteString(writer, existingLock) |
| } |
| } |
| |
| func (h *httpServer) handleStateUNLOCK(writer http.ResponseWriter, req *http.Request) { |
| if h.httpServerCallback != nil { |
| defer h.httpServerCallback.StateUNLOCK(req) |
| } |
| defer req.Body.Close() |
| resource := h.getResource(req) |
| |
| data, err := io.ReadAll(req.Body) |
| if err != nil { |
| writer.WriteHeader(http.StatusBadRequest) |
| return |
| } |
| var lockInfo map[string]interface{} |
| if err = json.Unmarshal(data, &lockInfo); err != nil { |
| writer.WriteHeader(http.StatusInternalServerError) |
| return |
| } |
| |
| h.lock.Lock() |
| defer h.lock.Unlock() |
| |
| if existingLock, ok := h.locks[resource]; ok { |
| var existingLockInfo map[string]interface{} |
| if err = json.Unmarshal([]byte(existingLock), &existingLockInfo); err != nil { |
| writer.WriteHeader(http.StatusInternalServerError) |
| return |
| } |
| lockID := lockInfo["ID"].(string) |
| existingID := existingLockInfo["ID"].(string) |
| if lockID != existingID { |
| writer.WriteHeader(http.StatusConflict) |
| _, _ = io.WriteString(writer, existingLock) |
| } else { |
| delete(h.locks, resource) |
| _, _ = io.WriteString(writer, existingLock) |
| } |
| } else { |
| writer.WriteHeader(http.StatusConflict) |
| } |
| } |
| |
| func (h *httpServer) handler() http.Handler { |
| return h.r |
| } |
| |
| func NewHttpTestServer(opts ...httpServerOpt) (*httptest.Server, error) { |
| clientCAData, err := os.ReadFile("testdata/certs/ca.cert.pem") |
| if err != nil { |
| return nil, err |
| } |
| clientCAs := x509.NewCertPool() |
| clientCAs.AppendCertsFromPEM(clientCAData) |
| |
| cert, err := tls.LoadX509KeyPair("testdata/certs/server.crt", "testdata/certs/server.key") |
| if err != nil { |
| return nil, err |
| } |
| |
| h := newHttpServer(opts...) |
| s := httptest.NewUnstartedServer(h.handler()) |
| s.TLS = &tls.Config{ |
| ClientAuth: tls.RequireAndVerifyClientCert, |
| ClientCAs: clientCAs, |
| Certificates: []tls.Certificate{cert}, |
| } |
| |
| s.StartTLS() |
| return s, nil |
| } |
| |
| func TestMTLSServer_NoCertFails(t *testing.T) { |
| // Ensure that no calls are made to the server - everything is blocked by the tls.RequireAndVerifyClientCert |
| ctrl := gomock.NewController(t) |
| defer ctrl.Finish() |
| mockCallback := NewMockHttpServerCallback(ctrl) |
| |
| // Fire up a test server |
| ts, err := NewHttpTestServer(withHttpServerCallback(mockCallback)) |
| if err != nil { |
| t.Fatalf("unexpected error creating test server: %v", err) |
| } |
| defer ts.Close() |
| |
| // Configure the backend to the pre-populated sample state |
| url := ts.URL + "/state/sample" |
| conf := map[string]cty.Value{ |
| "address": cty.StringVal(url), |
| "skip_cert_verification": cty.BoolVal(true), |
| } |
| b := backend.TestBackendConfig(t, New(), configs.SynthBody("synth", conf)).(*Backend) |
| if nil == b { |
| t.Fatal("nil backend") |
| } |
| |
| // Now get a state manager and check that it fails to refresh the state |
| sm, err := b.StateMgr(backend.DefaultStateName) |
| if err != nil { |
| t.Fatalf("unexpected error fetching StateMgr with %s: %v", backend.DefaultStateName, err) |
| } |
| err = sm.RefreshState() |
| if nil == err { |
| t.Error("expected error when refreshing state without a client cert") |
| } else if !strings.Contains(err.Error(), "remote error: tls: bad certificate") { |
| t.Errorf("expected the error to report missing tls credentials: %v", err) |
| } |
| } |
| |
| func TestMTLSServer_WithCertPasses(t *testing.T) { |
| // Ensure that the expected amount of calls is made to the server |
| ctrl := gomock.NewController(t) |
| defer ctrl.Finish() |
| mockCallback := NewMockHttpServerCallback(ctrl) |
| |
| // Two or three (not testing the caching here) calls to GET |
| mockCallback.EXPECT(). |
| StateGET(gomock.Any()). |
| MinTimes(2). |
| MaxTimes(3) |
| // One call to the POST to write the data |
| mockCallback.EXPECT(). |
| StatePOST(gomock.Any()) |
| |
| // Fire up a test server |
| ts, err := NewHttpTestServer(withHttpServerCallback(mockCallback)) |
| if err != nil { |
| t.Fatalf("unexpected error creating test server: %v", err) |
| } |
| defer ts.Close() |
| |
| // Configure the backend to the pre-populated sample state, and with all the test certs lined up |
| url := ts.URL + "/state/sample" |
| caData, err := os.ReadFile("testdata/certs/ca.cert.pem") |
| if err != nil { |
| t.Fatalf("error reading ca certs: %v", err) |
| } |
| clientCertData, err := os.ReadFile("testdata/certs/client.crt") |
| if err != nil { |
| t.Fatalf("error reading client cert: %v", err) |
| } |
| clientKeyData, err := os.ReadFile("testdata/certs/client.key") |
| if err != nil { |
| t.Fatalf("error reading client key: %v", err) |
| } |
| conf := map[string]cty.Value{ |
| "address": cty.StringVal(url), |
| "lock_address": cty.StringVal(url), |
| "unlock_address": cty.StringVal(url), |
| "client_ca_certificate_pem": cty.StringVal(string(caData)), |
| "client_certificate_pem": cty.StringVal(string(clientCertData)), |
| "client_private_key_pem": cty.StringVal(string(clientKeyData)), |
| } |
| b := backend.TestBackendConfig(t, New(), configs.SynthBody("synth", conf)).(*Backend) |
| if nil == b { |
| t.Fatal("nil backend") |
| } |
| |
| // Now get a state manager, fetch the state, and ensure that the "foo" output is not set |
| sm, err := b.StateMgr(backend.DefaultStateName) |
| if err != nil { |
| t.Fatalf("unexpected error fetching StateMgr with %s: %v", backend.DefaultStateName, err) |
| } |
| if err = sm.RefreshState(); err != nil { |
| t.Fatalf("unexpected error calling RefreshState: %v", err) |
| } |
| state := sm.State() |
| if nil == state { |
| t.Fatal("nil state") |
| } |
| stateFoo := state.OutputValue(addrs.OutputValue{Name: "foo"}.Absolute(addrs.RootModuleInstance)) |
| if stateFoo != nil { |
| t.Errorf("expected nil foo from state; got %v", stateFoo) |
| } |
| |
| // Create a new state that has "foo" set to "bar" and ensure that state is as expected |
| state = states.BuildState(func(ss *states.SyncState) { |
| ss.SetOutputValue( |
| addrs.OutputValue{Name: "foo"}.Absolute(addrs.RootModuleInstance), |
| cty.StringVal("bar"), |
| false) |
| }) |
| stateFoo = state.OutputValue(addrs.OutputValue{Name: "foo"}.Absolute(addrs.RootModuleInstance)) |
| if nil == stateFoo { |
| t.Fatal("nil foo after building state with foo populated") |
| } |
| if foo := stateFoo.Value.AsString(); foo != "bar" { |
| t.Errorf("Expected built state foo value to be bar; got %s", foo) |
| } |
| |
| // Ensure the change hasn't altered the current state manager state by checking "foo" and comparing states |
| curState := sm.State() |
| curStateFoo := curState.OutputValue(addrs.OutputValue{Name: "foo"}.Absolute(addrs.RootModuleInstance)) |
| if curStateFoo != nil { |
| t.Errorf("expected session manager state to be unaltered and still nil, but got: %v", curStateFoo) |
| } |
| if reflect.DeepEqual(state, curState) { |
| t.Errorf("expected %v != %v; but they were equal", state, curState) |
| } |
| |
| // Write the new state, persist, and refresh |
| if err = sm.WriteState(state); err != nil { |
| t.Errorf("error writing state: %v", err) |
| } |
| if err = sm.PersistState(nil); err != nil { |
| t.Errorf("error persisting state: %v", err) |
| } |
| if err = sm.RefreshState(); err != nil { |
| t.Errorf("error refreshing state: %v", err) |
| } |
| |
| // Get the state again and verify that is now the same as state and has the "foo" value set to "bar" |
| curState = sm.State() |
| if !reflect.DeepEqual(state, curState) { |
| t.Errorf("expected %v == %v; but they were unequal", state, curState) |
| } |
| curStateFoo = curState.OutputValue(addrs.OutputValue{Name: "foo"}.Absolute(addrs.RootModuleInstance)) |
| if nil == curStateFoo { |
| t.Fatal("nil foo") |
| } |
| if foo := curStateFoo.Value.AsString(); foo != "bar" { |
| t.Errorf("expected foo to be bar, but got: %s", foo) |
| } |
| } |
| |
| // TestRunServer allows running the server for local debugging; it runs until ctl-c is received |
| func TestRunServer(t *testing.T) { |
| if _, ok := os.LookupEnv("TEST_RUN_SERVER"); !ok { |
| t.Skip("TEST_RUN_SERVER not set") |
| } |
| s, err := NewHttpTestServer() |
| if err != nil { |
| t.Fatalf("unexpected error creating test server: %v", err) |
| } |
| defer s.Close() |
| |
| t.Log(s.URL) |
| |
| ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) |
| defer cancel() |
| // wait until signal |
| <-ctx.Done() |
| } |