blob: 8d468c9775a548217a669b6d58f30c1c6dacaa80 [file] [log] [blame] [edit]
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package retryablehttp
import (
"context"
"errors"
"io"
"net"
"net/http"
"net/http/httptest"
"net/url"
"reflect"
"sync/atomic"
"testing"
)
// TestRoundTripper_implements has no runtime assertions: it only compiles if
// *RoundTripper satisfies the http.RoundTripper interface.
func TestRoundTripper_implements(t *testing.T) {
	var _ http.RoundTripper = (*RoundTripper)(nil)
}
// TestRoundTripper_init verifies that a zero-value RoundTripper has a non-nil
// Client after its first RoundTrip call, and that a second call leaves that
// same Client in place instead of replacing it.
func TestRoundTripper_init(t *testing.T) {
	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(200)
	}))
	defer ts.Close()

	// Start with a new empty RoundTripper.
	rt := &RoundTripper{}

	// RoundTrip once.
	req, err := http.NewRequest("GET", ts.URL, nil)
	if err != nil {
		t.Fatal(err)
	}
	resp, err := rt.RoundTrip(req)
	if err != nil {
		t.Fatal(err)
	}
	// Close the body so the underlying connection is not leaked.
	defer resp.Body.Close()

	// Check that the Client was initialized.
	if rt.Client == nil {
		t.Fatal("expected rt.Client to be initialized")
	}

	// Save the Client for later comparison.
	initialClient := rt.Client

	// RoundTrip again.
	req, err = http.NewRequest("GET", ts.URL, nil)
	if err != nil {
		t.Fatal(err)
	}
	resp2, err := rt.RoundTrip(req)
	if err != nil {
		t.Fatal(err)
	}
	defer resp2.Body.Close()

	// Check that the underlying Client is unchanged.
	if rt.Client != initialClient {
		t.Fatalf("expected %v, got %v", initialClient, rt.Client)
	}
}
// TestRoundTripper_RoundTrip checks that the *http.Client returned by
// StandardClient applies the retry client's custom CheckRetry policy. The test
// server answers 404 to the first two requests and 200 with a body afterwards;
// the caller must observe only the final successful response.
func TestRoundTripper_RoundTrip(t *testing.T) {
	var reqCount int32

	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Reject the first two requests so the client is forced to retry.
		if atomic.AddInt32(&reqCount, 1) < 3 {
			w.WriteHeader(404)
			return
		}
		w.WriteHeader(200)
		if _, err := w.Write([]byte("success!")); err != nil {
			// The handler runs on a server goroutine. Fatal/FailNow may only
			// be called from the goroutine running the test, so report the
			// failure with Errorf instead.
			t.Errorf("failed to write: %v", err)
		}
	}))
	defer ts.Close()

	// Make a client with some custom settings to verify they are used.
	retryClient := NewClient()
	retryClient.CheckRetry = func(_ context.Context, resp *http.Response, err error) (bool, error) {
		// resp may be nil when the request itself failed; surface the error
		// rather than dereferencing it.
		if err != nil {
			return false, err
		}
		return resp.StatusCode == 404, nil
	}

	// Get the standard client and execute the request.
	client := retryClient.StandardClient()
	resp, err := client.Get(ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	defer resp.Body.Close()

	// Check the response to ensure the client behaved as expected.
	if resp.StatusCode != 200 {
		t.Fatalf("expected 200, got %d", resp.StatusCode)
	}
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		t.Fatal(err)
	}
	if string(body) != "success!" {
		t.Fatalf("expected %q, got %q", "success!", body)
	}
}
// TestRoundTripper_TransportFailureErrorHandling requests a host name that is
// expected not to resolve and asserts that, with PassthroughErrorHandler set,
// the error returned by the standard client has the exact
// url.Error -> net.OpError -> net.DNSError shape built below.
func TestRoundTripper_TransportFailureErrorHandling(t *testing.T) {
	const target = "http://999.999.999.999:999/"

	// Retry whenever the transport reported an error, and hand that error back.
	rc := NewClient()
	rc.CheckRetry = func(_ context.Context, resp *http.Response, err error) (bool, error) {
		return err != nil, err
	}
	rc.ErrorHandler = PassthroughErrorHandler

	// The error chain we expect to receive, innermost cause first.
	dnsFailure := &net.DNSError{
		Name:       "999.999.999.999",
		Err:        "no such host",
		IsNotFound: true,
	}
	dialFailure := &net.OpError{
		Op:  "dial",
		Net: "tcp",
		Err: dnsFailure,
	}
	expectedError := &url.Error{
		Op:  "Get",
		URL: target,
		Err: dialFailure,
	}

	// Issue the request through the standard-library client wrapper.
	_, err := rc.StandardClient().Get(target)

	// Compare after normalizing environment-dependent fields.
	if got := normalizeError(err); !reflect.DeepEqual(expectedError, got) {
		t.Fatalf("expected %q, got %q", expectedError, err)
	}
}
// normalizeError clears the Server field of the first *net.DNSError found in
// err's chain, modifying it in place, and returns err itself. Errors that do
// not contain a *net.DNSError (including nil) are returned untouched.
func normalizeError(err error) error {
	target := new(*net.DNSError)
	if !errors.As(err, target) {
		return err
	}
	// Server is populated with the DNS server address on CI but not locally,
	// so blank it to keep comparisons stable across environments.
	(*target).Server = ""
	return err
}