| // Copyright (c) HashiCorp, Inc. |
| // SPDX-License-Identifier: MPL-2.0 |
| |
| package transit |
| |
| import ( |
| "context" |
| "encoding/base64" |
| "encoding/hex" |
| "fmt" |
| "reflect" |
| "testing" |
| |
| "github.com/hashicorp/vault/helper/random" |
| "github.com/hashicorp/vault/sdk/logical" |
| ) |
| |
| func TestTransit_Random(t *testing.T) { |
| var b *backend |
| sysView := logical.TestSystemView() |
| storage := &logical.InmemStorage{} |
| sysView.CachingDisabledVal = true |
| |
| b, _ = Backend(context.Background(), &logical.BackendConfig{ |
| StorageView: storage, |
| System: sysView, |
| }) |
| |
| req := &logical.Request{ |
| Storage: storage, |
| Operation: logical.UpdateOperation, |
| Path: "random", |
| Data: map[string]interface{}{}, |
| } |
| |
| doRequest := func(req *logical.Request, errExpected bool, format string, numBytes int) { |
| getResponse := func() []byte { |
| resp, err := b.HandleRequest(context.Background(), req) |
| if err != nil && !errExpected { |
| t.Fatal(err) |
| } |
| if resp == nil { |
| t.Fatal("expected non-nil response") |
| } |
| if errExpected { |
| if !resp.IsError() { |
| t.Fatalf("bad: got error response: %#v", *resp) |
| } |
| return nil |
| } |
| if resp.IsError() { |
| t.Fatalf("bad: got error response: %#v", *resp) |
| } |
| if _, ok := resp.Data["random_bytes"]; !ok { |
| t.Fatal("no random_bytes found in response") |
| } |
| |
| outputStr := resp.Data["random_bytes"].(string) |
| var outputBytes []byte |
| switch format { |
| case "base64": |
| outputBytes, err = base64.StdEncoding.DecodeString(outputStr) |
| case "hex": |
| outputBytes, err = hex.DecodeString(outputStr) |
| default: |
| t.Fatal("unknown format") |
| } |
| if err != nil { |
| t.Fatal(err) |
| } |
| |
| return outputBytes |
| } |
| |
| rand1 := getResponse() |
| // Expected error |
| if rand1 == nil { |
| return |
| } |
| rand2 := getResponse() |
| if len(rand1) != numBytes || len(rand2) != numBytes { |
| t.Fatal("length of output random bytes not what is expected") |
| } |
| if reflect.DeepEqual(rand1, rand2) { |
| t.Fatal("found identical ouputs") |
| } |
| } |
| |
| for _, source := range []string{"", "platform", "seal", "all"} { |
| req.Data["source"] = source |
| req.Data["bytes"] = 32 |
| req.Data["format"] = "base64" |
| req.Path = "random" |
| // Test defaults |
| doRequest(req, false, "base64", 32) |
| |
| // Test size selection in the path |
| req.Path = "random/24" |
| req.Data["format"] = "hex" |
| doRequest(req, false, "hex", 24) |
| |
| if source != "" { |
| // Test source selection in the path |
| req.Path = fmt.Sprintf("random/%s", source) |
| req.Data["format"] = "hex" |
| doRequest(req, false, "hex", 32) |
| |
| req.Path = fmt.Sprintf("random/%s/24", source) |
| req.Data["format"] = "hex" |
| doRequest(req, false, "hex", 24) |
| } |
| |
| // Test bad input/format |
| req.Path = "random" |
| req.Data["format"] = "base92" |
| doRequest(req, true, "", 0) |
| |
| req.Data["format"] = "hex" |
| req.Data["bytes"] = -1 |
| doRequest(req, true, "", 0) |
| |
| req.Data["format"] = "hex" |
| req.Data["bytes"] = random.APIMaxBytes + 1 |
| |
| doRequest(req, true, "", 0) |
| } |
| } |