blob: 35782ec3eadc8b2c11cf2e85acfb6dd62968a574 [file] [log] [blame] [edit]
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package transit
import (
"context"
"encoding/base64"
"encoding/hex"
"fmt"
"reflect"
"testing"
"github.com/hashicorp/vault/helper/random"
"github.com/hashicorp/vault/sdk/logical"
)
func TestTransit_Random(t *testing.T) {
var b *backend
sysView := logical.TestSystemView()
storage := &logical.InmemStorage{}
sysView.CachingDisabledVal = true
b, _ = Backend(context.Background(), &logical.BackendConfig{
StorageView: storage,
System: sysView,
})
req := &logical.Request{
Storage: storage,
Operation: logical.UpdateOperation,
Path: "random",
Data: map[string]interface{}{},
}
doRequest := func(req *logical.Request, errExpected bool, format string, numBytes int) {
getResponse := func() []byte {
resp, err := b.HandleRequest(context.Background(), req)
if err != nil && !errExpected {
t.Fatal(err)
}
if resp == nil {
t.Fatal("expected non-nil response")
}
if errExpected {
if !resp.IsError() {
t.Fatalf("bad: got error response: %#v", *resp)
}
return nil
}
if resp.IsError() {
t.Fatalf("bad: got error response: %#v", *resp)
}
if _, ok := resp.Data["random_bytes"]; !ok {
t.Fatal("no random_bytes found in response")
}
outputStr := resp.Data["random_bytes"].(string)
var outputBytes []byte
switch format {
case "base64":
outputBytes, err = base64.StdEncoding.DecodeString(outputStr)
case "hex":
outputBytes, err = hex.DecodeString(outputStr)
default:
t.Fatal("unknown format")
}
if err != nil {
t.Fatal(err)
}
return outputBytes
}
rand1 := getResponse()
// Expected error
if rand1 == nil {
return
}
rand2 := getResponse()
if len(rand1) != numBytes || len(rand2) != numBytes {
t.Fatal("length of output random bytes not what is expected")
}
if reflect.DeepEqual(rand1, rand2) {
t.Fatal("found identical ouputs")
}
}
for _, source := range []string{"", "platform", "seal", "all"} {
req.Data["source"] = source
req.Data["bytes"] = 32
req.Data["format"] = "base64"
req.Path = "random"
// Test defaults
doRequest(req, false, "base64", 32)
// Test size selection in the path
req.Path = "random/24"
req.Data["format"] = "hex"
doRequest(req, false, "hex", 24)
if source != "" {
// Test source selection in the path
req.Path = fmt.Sprintf("random/%s", source)
req.Data["format"] = "hex"
doRequest(req, false, "hex", 32)
req.Path = fmt.Sprintf("random/%s/24", source)
req.Data["format"] = "hex"
doRequest(req, false, "hex", 24)
}
// Test bad input/format
req.Path = "random"
req.Data["format"] = "base92"
doRequest(req, true, "", 0)
req.Data["format"] = "hex"
req.Data["bytes"] = -1
doRequest(req, true, "", 0)
req.Data["format"] = "hex"
req.Data["bytes"] = random.APIMaxBytes + 1
doRequest(req, true, "", 0)
}
}