| // Copyright (c) HashiCorp, Inc. |
| // SPDX-License-Identifier: MPL-2.0 |
| |
| package forwarding |
| |
| import ( |
| "bufio" |
| "bytes" |
| "net/http" |
| "os" |
| "reflect" |
| "testing" |
| ) |
| |
| func Test_ForwardedRequest_GenerateParse(t *testing.T) { |
| testForwardedRequestGenerateParse(t) |
| } |
| |
| func Benchmark_ForwardedRequest_GenerateParse_JSON(b *testing.B) { |
| os.Setenv("VAULT_MESSAGE_TYPE", "json") |
| var totalSize int64 |
| var numRuns int64 |
| for i := 0; i < b.N; i++ { |
| totalSize += testForwardedRequestGenerateParse(b) |
| numRuns++ |
| } |
| b.Logf("message size per op: %d", totalSize/numRuns) |
| } |
| |
| func Benchmark_ForwardedRequest_GenerateParse_JSON_Compressed(b *testing.B) { |
| os.Setenv("VAULT_MESSAGE_TYPE", "json_compress") |
| var totalSize int64 |
| var numRuns int64 |
| for i := 0; i < b.N; i++ { |
| totalSize += testForwardedRequestGenerateParse(b) |
| numRuns++ |
| } |
| b.Logf("message size per op: %d", totalSize/numRuns) |
| } |
| |
| func Benchmark_ForwardedRequest_GenerateParse_Proto3(b *testing.B) { |
| os.Setenv("VAULT_MESSAGE_TYPE", "proto3") |
| var totalSize int64 |
| var numRuns int64 |
| for i := 0; i < b.N; i++ { |
| totalSize += testForwardedRequestGenerateParse(b) |
| numRuns++ |
| } |
| b.Logf("message size per op: %d", totalSize/numRuns) |
| } |
| |
| func testForwardedRequestGenerateParse(t testing.TB) int64 { |
| bodBuf := bytes.NewReader([]byte(`{ "foo": "bar", "zip": { "argle": "bargle", neet: 0 } }`)) |
| req, err := http.NewRequest("FOOBAR", "https://pushit.real.good:9281/snicketysnack?furbleburble=bloopetybloop", bodBuf) |
| if err != nil { |
| t.Fatal(err) |
| } |
| |
| // We want to get the fields we would expect from an incoming request, so |
| // we write it out and then read it again |
| buf1 := bytes.NewBuffer(nil) |
| err = req.Write(buf1) |
| if err != nil { |
| t.Fatal(err) |
| } |
| |
| // Read it back in, parsing like a server |
| bufr1 := bufio.NewReader(buf1) |
| initialReq, err := http.ReadRequest(bufr1) |
| if err != nil { |
| t.Fatal(err) |
| } |
| |
| // Generate the request with the forwarded request in the body |
| req, err = GenerateForwardedHTTPRequest(initialReq, "https://bloopety.bloop:8201") |
| if err != nil { |
| t.Fatal(err) |
| } |
| |
| // Perform another "round trip" |
| buf2 := bytes.NewBuffer(nil) |
| err = req.Write(buf2) |
| if err != nil { |
| t.Fatal(err) |
| } |
| size := int64(buf2.Len()) |
| bufr2 := bufio.NewReader(buf2) |
| intreq, err := http.ReadRequest(bufr2) |
| if err != nil { |
| t.Fatal(err) |
| } |
| |
| // Now extract the forwarded request to generate a final request for processing |
| finalReq, err := ParseForwardedHTTPRequest(intreq) |
| if err != nil { |
| t.Fatal(err) |
| } |
| |
| switch { |
| case initialReq.Method != finalReq.Method: |
| t.Fatalf("bad method:\ninitialReq:\n%#v\nfinalReq:\n%#v\n", *initialReq, *finalReq) |
| case initialReq.RemoteAddr != finalReq.RemoteAddr: |
| t.Fatalf("bad remoteaddr:\ninitialReq:\n%#v\nfinalReq:\n%#v\n", *initialReq, *finalReq) |
| case initialReq.Host != finalReq.Host: |
| t.Fatalf("bad host:\ninitialReq:\n%#v\nfinalReq:\n%#v\n", *initialReq, *finalReq) |
| case !reflect.DeepEqual(initialReq.URL, finalReq.URL): |
| t.Fatalf("bad url:\ninitialReq:\n%#v\nfinalReq:\n%#v\n", *initialReq.URL, *finalReq.URL) |
| case !reflect.DeepEqual(initialReq.Header, finalReq.Header): |
| t.Fatalf("bad header:\ninitialReq:\n%#v\nfinalReq:\n%#v\n", *initialReq, *finalReq) |
| default: |
| // Compare bodies |
| bodBuf.Seek(0, 0) |
| initBuf := bytes.NewBuffer(nil) |
| _, err = initBuf.ReadFrom(bodBuf) |
| if err != nil { |
| t.Fatal(err) |
| } |
| finBuf := bytes.NewBuffer(nil) |
| _, err = finBuf.ReadFrom(finalReq.Body) |
| if err != nil { |
| t.Fatal(err) |
| } |
| if !reflect.DeepEqual(initBuf.Bytes(), finBuf.Bytes()) { |
| t.Fatalf("badbody :\ninitialReq:\n%#v\nfinalReq:\n%#v\n", initBuf.Bytes(), finBuf.Bytes()) |
| } |
| } |
| |
| return size |
| } |