blob: 364b6fe7dd0195080c9f48a0cb69961f4a0c0a1c [file] [log] [blame]
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package transport
import (
"context"
"errors"
"fmt"
"log"
"strings"
"sync"
"testing"
"time"
)
// TestRequestBatcher_batchSingle runs the basic batching scenario with a
// single request, which must come back as "test-single: 1".
func TestRequestBatcher_batchSingle(t *testing.T) {
	const requestCount = 1
	testBasicCountBatches(t, "test-single", requestCount)
}
// TestRequestBatcher_batchMultiple runs the basic batching scenario with ten
// concurrent requests, each of which must receive "test-multiple: 10".
func TestRequestBatcher_batchMultiple(t *testing.T) {
	const requestCount = 10
	testBasicCountBatches(t, "test-multiple", requestCount)
}
// TestRequestBatcher_disableBatching verifies that with EnableBatching set to
// false each request is sent on its own rather than combined: three concurrent
// requests, each with Body 1, must every one receive "testNoBatching: 1"
// instead of a merged total.
func TestRequestBatcher_disableBatching(t *testing.T) {
	testBatcher := NewRequestBatcher(
		"testBatcher",
		context.Background(),
		&BatchingConfig{
			SendAfter:      time.Duration(1) * time.Second,
			EnableBatching: false,
		})

	// testCombine sums integer bodies.
	testCombine := func(currV interface{}, toAddV interface{}) (interface{}, error) {
		return currV.(int) + toAddV.(int), nil
	}
	// testSendBatch echoes the resource name and body, so the response shows
	// how many Body-1 requests ended up in the sent body.
	testSendBatch := func(name string, body interface{}) (interface{}, error) {
		return fmt.Sprintf("%s: %d", name, body), nil
	}

	const numRequests = 3
	wg := sync.WaitGroup{}
	wg.Add(numRequests)
	for i := 0; i < numRequests; i++ {
		go func(idx int) {
			defer wg.Done()
			req := &BatchRequest{
				DebugId:      fmt.Sprintf("Test Single Requests #%d", idx),
				ResourceName: "testNoBatching",
				Body:         1,
				CombineF:     testCombine,
				SendF:        testSendBatch,
			}
			respV, err := testBatcher.SendRequestWithTimeout(
				"testDisableBatching", req, time.Duration(1)*time.Second)
			if err != nil {
				t.Errorf("got unexpected error %s", err)
				return
			}
			resp, ok := respV.(string)
			if !ok {
				// Report the raw value; resp is only the zero string here.
				t.Errorf("test returned an non-string response: %v", respV)
				return
			}
			if resp != "testNoBatching: 1" {
				t.Errorf("expected single request response, got %s", resp)
			}
		}(i)
	}
	// Block until every request goroutine finishes. Without this the test
	// returns before any assertion runs, and a t.Errorf issued after the test
	// has completed panics.
	wg.Wait()
}
// TestRequestBatcher_errInCombine checks how a CombineF failure is reported.
// Two requests target the same resource one second apart. The first must
// finish without error. The second, which is combined with the first request's
// batch, must receive an error containing the combine failure text.
func TestRequestBatcher_errInCombine(t *testing.T) {
	testBatcher := NewRequestBatcher(
		"testBatcher",
		context.Background(),
		&BatchingConfig{
			SendAfter:      time.Duration(5) * time.Second,
			EnableBatching: true,
		})

	combineErrText := "this is an expected error in combine"
	// failingCombine rejects every attempt to merge two requests.
	failingCombine := func(_ interface{}, _ interface{}) (interface{}, error) {
		return nil, errors.New(combineErrText)
	}
	// noopSend does nothing; only the combine failure is under test.
	noopSend := func(_ string, _ interface{}) (interface{}, error) {
		return nil, nil
	}
	// newRequest builds a request for the shared resource with the given debug ID.
	newRequest := func(debugID string) *BatchRequest {
		return &BatchRequest{
			DebugId:      debugID,
			ResourceName: "test-resource",
			Body:         nil,
			CombineF:     failingCombine,
			SendF:        noopSend,
		}
	}
	requestTimeout := time.Duration(10) * time.Second

	var wg sync.WaitGroup
	wg.Add(2)

	// Opening request: nothing is combined into it, so it must succeed.
	go func() {
		defer wg.Done()
		first := newRequest("errInCombine first")
		if _, err := testBatcher.SendRequestWithTimeout("testCombineErr", first, requestTimeout); err != nil {
			t.Errorf("expected no error, got: %s", err)
		}
	}()

	// Follow-up request: sent one second later, well inside the 5s SendAfter
	// delay, so it is merged with the opening batch and the combine fails.
	go func() {
		defer wg.Done()
		time.Sleep(time.Second)
		second := newRequest("errInCombine second")
		_, err := testBatcher.SendRequestWithTimeout("testCombineErr", second, requestTimeout)
		switch {
		case err == nil:
			t.Errorf("expected error, got none")
		case !strings.Contains(err.Error(), combineErrText):
			t.Errorf("error does not contain expected error %s. Got: %s", combineErrText, err)
		}
	}()

	wg.Wait()
}
// TestRequestBatcher_errInSend checks how a SendF failure is reported.
// Several requests, each carrying its own index, are sent for one resource.
// SendF fails any batch containing failIdx. The test expects the failIdx
// request to get that error and every other request to succeed.
func TestRequestBatcher_errInSend(t *testing.T) {
	testBatcher := NewRequestBatcher(
		"testBatcher",
		context.Background(),
		&BatchingConfig{
			SendAfter:      time.Duration(5) * time.Second,
			EnableBatching: true,
		})

	// testCombine concatenates the index slices, so the sent body lists every
	// request index merged into the batch.
	testCombine := func(body interface{}, toAdd interface{}) (interface{}, error) {
		return append(body.([]int), toAdd.([]int)...), nil
	}

	failIdx := 0
	testResource := "RESOURCE-SEND-ERROR"
	expectedErrMsg := fmt.Sprintf("Error - batch %q contains idx %d", testResource, failIdx)
	// testSendBatch fails any batch that contains failIdx and succeeds otherwise.
	testSendBatch := func(resourceName string, body interface{}) (interface{}, error) {
		log.Printf("[DEBUG] sendBatch body: %+v", body)
		for _, v := range body.([]int) {
			if v == failIdx {
				// expectedErrMsg is already formatted; passing it to fmt.Errorf
				// as a format string is flagged by go vet and would misbehave
				// if it ever contained a '%'.
				return nil, errors.New(expectedErrMsg)
			}
		}
		return nil, nil
	}

	numRequests := 3
	wg := sync.WaitGroup{}
	wg.Add(numRequests)
	for i := 0; i < numRequests; i++ {
		go func(idx int) {
			defer wg.Done()
			req := &BatchRequest{
				DebugId:      fmt.Sprintf("sendError %d", idx),
				ResourceName: testResource,
				Body:         []int{idx},
				CombineF:     testCombine,
				SendF:        testSendBatch,
			}
			_, err := testBatcher.SendRequestWithTimeout("batchSendError", req, time.Duration(10)*time.Second)

			if idx != failIdx {
				// We shouldn't get an error for a non-failure index.
				if err != nil {
					t.Errorf("expected request %d to succeed, got error: %v", idx, err)
				}
				return
			}

			// The failIdx request must fail with the send error.
			if err == nil {
				t.Errorf("expected error for request %d, got none", idx)
				// Stop here: err.Error() below would dereference a nil error.
				return
			}
			if !strings.Contains(err.Error(), expectedErrMsg) {
				t.Errorf("expected error %q to contain %q", err, expectedErrMsg)
			}
		}(i)
	}
	wg.Wait()
}
// TestRequestBatcher_errTimeout checks that a request whose timeout (1s) is
// shorter than the batcher's SendAfter delay (5s) fails with an error
// mentioning "timed out".
func TestRequestBatcher_errTimeout(t *testing.T) {
	testBatcher := NewRequestBatcher(
		"testBatcher",
		context.Background(),
		&BatchingConfig{
			SendAfter:      time.Duration(5) * time.Second,
			EnableBatching: true,
		})
	testResource := "resource for send error"

	// Combine and send are no-ops; only the timeout path is under test.
	testCombine := func(v interface{}, _ interface{}) (interface{}, error) {
		return v, nil
	}
	testSendBatch := func(_ string, _ interface{}) (interface{}, error) {
		return nil, nil
	}

	req := &BatchRequest{
		DebugId:      "timeout test",
		ResourceName: testResource,
		Body:         1,
		CombineF:     testCombine,
		SendF:        testSendBatch,
	}
	// A single request needs no goroutine/WaitGroup; call it directly.
	_, err := testBatcher.SendRequestWithTimeout("batchTimeout", req, time.Duration(1)*time.Second)
	if err == nil {
		t.Errorf("expected error, got none")
	} else if !strings.Contains(err.Error(), "timed out") {
		t.Errorf("expected timeout error, got %v", err)
	}
}
// testBasicCountBatches sends numBatches concurrent requests (despite the
// name, this is the number of requests), each with Body 1, for the resource
// testName through a batching-enabled RequestBatcher. It checks that every
// caller receives "<testName>: <numBatches>", i.e. that all requests were
// summed into one sent batch.
func testBasicCountBatches(t *testing.T, testName string, numBatches int) {
	testBatcher := NewRequestBatcher(
		"testBatcher",
		context.Background(),
		&BatchingConfig{
			SendAfter:      time.Duration(1) * time.Second,
			EnableBatching: true,
		})

	// testCombine sums integer bodies, so the batch body counts its requests.
	testCombine := func(currV interface{}, toAddV interface{}) (interface{}, error) {
		return currV.(int) + toAddV.(int), nil
	}
	// testSendBatch echoes the resource name and the combined body.
	testSendBatch := func(name string, body interface{}) (interface{}, error) {
		return fmt.Sprintf("%s: %d", name, body), nil
	}

	expected := fmt.Sprintf("%s: %d", testName, numBatches)
	wg := sync.WaitGroup{}
	wg.Add(numBatches)
	for i := 0; i < numBatches; i++ {
		go func(idx int) {
			defer wg.Done()
			req := &BatchRequest{
				DebugId:      fmt.Sprintf("Test '%s' Request #%d", testName, idx),
				ResourceName: testName,
				Body:         1,
				CombineF:     testCombine,
				SendF:        testSendBatch,
			}
			respV, err := testBatcher.SendRequestWithTimeout("testBatching", req, time.Duration(6)*time.Second)
			if err != nil {
				t.Errorf("got unexpected error %s", err)
				return
			}
			resp, ok := respV.(string)
			if !ok {
				// Report the raw value; resp is only the zero string here.
				t.Errorf("test returned an non-string response: %v", respV)
				return
			}
			if resp != expected {
				t.Errorf("expected response %s, got %s", expected, resp)
			}
		}(i)
	}
	// Block until every request goroutine finishes. Without this the caller's
	// test returns before any assertion runs, and a t.Errorf issued after the
	// test has completed panics.
	wg.Wait()
}