blob: 192646a15f4d2d0620ca9359d85d0b71e9e2b7b6 [file] [log] [blame]
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package forwarding
import (
"bufio"
"bytes"
"net/http"
"os"
"reflect"
"testing"
)
// Test_ForwardedRequest_GenerateParse verifies that a request survives being
// wrapped by GenerateForwardedHTTPRequest and unwrapped again by
// ParseForwardedHTTPRequest. Unlike the benchmarks below, it does not set
// VAULT_MESSAGE_TYPE itself.
func Test_ForwardedRequest_GenerateParse(t *testing.T) {
	// Only the benchmarks care about the serialized message size.
	_ = testForwardedRequestGenerateParse(t)
}
// Benchmark_ForwardedRequest_GenerateParse_JSON measures the forwarding round
// trip with VAULT_MESSAGE_TYPE=json.
func Benchmark_ForwardedRequest_GenerateParse_JSON(b *testing.B) {
	benchmarkForwardedRequestGenerateParse(b, "json")
}

// Benchmark_ForwardedRequest_GenerateParse_JSON_Compressed measures the
// forwarding round trip with VAULT_MESSAGE_TYPE=json_compress.
func Benchmark_ForwardedRequest_GenerateParse_JSON_Compressed(b *testing.B) {
	benchmarkForwardedRequestGenerateParse(b, "json_compress")
}

// Benchmark_ForwardedRequest_GenerateParse_Proto3 measures the forwarding
// round trip with VAULT_MESSAGE_TYPE=proto3.
func Benchmark_ForwardedRequest_GenerateParse_Proto3(b *testing.B) {
	benchmarkForwardedRequestGenerateParse(b, "proto3")
}

// benchmarkForwardedRequestGenerateParse sets VAULT_MESSAGE_TYPE to msgType,
// runs testForwardedRequestGenerateParse b.N times and logs the average size
// of the serialized forwarded request. The previous value of the variable is
// restored (or the variable is unset) on return so the setting does not leak
// into later tests or benchmarks in the same process.
func benchmarkForwardedRequestGenerateParse(b *testing.B, msgType string) {
	const envKey = "VAULT_MESSAGE_TYPE"

	prev, hadPrev := os.LookupEnv(envKey)
	if err := os.Setenv(envKey, msgType); err != nil {
		b.Fatal(err)
	}
	defer func() {
		// Best-effort restore: a failure here cannot affect this benchmark's
		// result, and there is no test failure to attach it to.
		if hadPrev {
			_ = os.Setenv(envKey, prev)
		} else {
			_ = os.Unsetenv(envKey)
		}
	}()

	var totalSize int64
	for i := 0; i < b.N; i++ {
		totalSize += testForwardedRequestGenerateParse(b)
	}
	// b.N is always >= 1 when a benchmark function is invoked.
	b.Logf("message size per op: %d", totalSize/int64(b.N))
}
// testForwardedRequestGenerateParse builds an HTTP request, wraps it with
// GenerateForwardedHTTPRequest, unwraps it with ParseForwardedHTTPRequest and
// fails t unless the method, remote address, host, URL, headers and body all
// match the original.
//
// It returns the size in bytes of the serialized forwarding request, which
// the benchmarks use to report the wire size of each message encoding.
func testForwardedRequestGenerateParse(t testing.TB) int64 {
	// The body is treated as opaque bytes: it is only compared, never decoded.
	origBody := []byte(`{ "foo": "bar", "zip": { "argle": "bargle", neet: 0 } }`)
	req, err := http.NewRequest("FOOBAR", "https://pushit.real.good:9281/snicketysnack?furbleburble=bloopetybloop", bytes.NewReader(origBody))
	if err != nil {
		t.Fatal(err)
	}

	// roundTrip writes r out in wire format and parses it back the way a
	// server would. It returns the parsed request and the serialized size,
	// which is measured before parsing consumes the buffer.
	roundTrip := func(r *http.Request) (*http.Request, int64) {
		var wire bytes.Buffer
		if err := r.Write(&wire); err != nil {
			t.Fatal(err)
		}
		size := int64(wire.Len())
		parsed, err := http.ReadRequest(bufio.NewReader(&wire))
		if err != nil {
			t.Fatal(err)
		}
		return parsed, size
	}

	// We want the fields we would expect from an incoming request, so write
	// the request out and read it back in.
	initialReq, _ := roundTrip(req)

	// Generate the request with the forwarded request in the body.
	fwdReq, err := GenerateForwardedHTTPRequest(initialReq, "https://bloopety.bloop:8201")
	if err != nil {
		t.Fatal(err)
	}

	// Perform another round trip, this time recording the message size.
	intreq, size := roundTrip(fwdReq)

	// Extract the forwarded request to produce the final request for processing.
	finalReq, err := ParseForwardedHTTPRequest(intreq)
	if err != nil {
		t.Fatal(err)
	}

	switch {
	case initialReq.Method != finalReq.Method:
		t.Fatalf("bad method:\ninitialReq:\n%#v\nfinalReq:\n%#v\n", *initialReq, *finalReq)
	case initialReq.RemoteAddr != finalReq.RemoteAddr:
		t.Fatalf("bad remoteaddr:\ninitialReq:\n%#v\nfinalReq:\n%#v\n", *initialReq, *finalReq)
	case initialReq.Host != finalReq.Host:
		t.Fatalf("bad host:\ninitialReq:\n%#v\nfinalReq:\n%#v\n", *initialReq, *finalReq)
	case !reflect.DeepEqual(initialReq.URL, finalReq.URL):
		t.Fatalf("bad url:\ninitialReq:\n%#v\nfinalReq:\n%#v\n", *initialReq.URL, *finalReq.URL)
	case !reflect.DeepEqual(initialReq.Header, finalReq.Header):
		t.Fatalf("bad header:\ninitialReq:\n%#v\nfinalReq:\n%#v\n", *initialReq, *finalReq)
	default:
		// Compare bodies against the original bytes directly. There is no
		// need to rewind and re-read the reader that was consumed by Write.
		var finBody bytes.Buffer
		if _, err := finBody.ReadFrom(finalReq.Body); err != nil {
			t.Fatal(err)
		}
		if !bytes.Equal(origBody, finBody.Bytes()) {
			t.Fatalf("bad body:\ninitialReq:\n%#v\nfinalReq:\n%#v\n", origBody, finBody.Bytes())
		}
	}
	return size
}