| // Copyright (c) HashiCorp, Inc. |
| // SPDX-License-Identifier: MPL-2.0 |
| |
| package transit_test |
| |
| import ( |
| "encoding/hex" |
| "encoding/json" |
| "fmt" |
| "testing" |
| "time" |
| |
| uuid "github.com/hashicorp/go-uuid" |
| "github.com/hashicorp/vault/api" |
| "github.com/hashicorp/vault/audit" |
| "github.com/hashicorp/vault/builtin/audit/file" |
| "github.com/hashicorp/vault/builtin/logical/transit" |
| vaulthttp "github.com/hashicorp/vault/http" |
| "github.com/hashicorp/vault/sdk/logical" |
| "github.com/hashicorp/vault/vault" |
| ) |
| |
| func TestTransit_Issue_2958(t *testing.T) { |
| coreConfig := &vault.CoreConfig{ |
| LogicalBackends: map[string]logical.Factory{ |
| "transit": transit.Factory, |
| }, |
| AuditBackends: map[string]audit.Factory{ |
| "file": file.Factory, |
| }, |
| } |
| |
| cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{ |
| HandlerFunc: vaulthttp.Handler, |
| }) |
| cluster.Start() |
| defer cluster.Cleanup() |
| |
| cores := cluster.Cores |
| |
| vault.TestWaitActive(t, cores[0].Core) |
| |
| client := cores[0].Client |
| |
| err := client.Sys().EnableAuditWithOptions("file", &api.EnableAuditOptions{ |
| Type: "file", |
| Options: map[string]string{ |
| "file_path": "/dev/null", |
| }, |
| }) |
| if err != nil { |
| t.Fatal(err) |
| } |
| |
| err = client.Sys().Mount("transit", &api.MountInput{ |
| Type: "transit", |
| }) |
| if err != nil { |
| t.Fatal(err) |
| } |
| |
| _, err = client.Logical().Write("transit/keys/foo", map[string]interface{}{ |
| "type": "ecdsa-p256", |
| }) |
| if err != nil { |
| t.Fatal(err) |
| } |
| |
| _, err = client.Logical().Write("transit/keys/foobar", map[string]interface{}{ |
| "type": "ecdsa-p384", |
| }) |
| if err != nil { |
| t.Fatal(err) |
| } |
| |
| _, err = client.Logical().Write("transit/keys/bar", map[string]interface{}{ |
| "type": "ed25519", |
| }) |
| if err != nil { |
| t.Fatal(err) |
| } |
| |
| _, err = client.Logical().Read("transit/keys/foo") |
| if err != nil { |
| t.Fatal(err) |
| } |
| |
| _, err = client.Logical().Read("transit/keys/foobar") |
| if err != nil { |
| t.Fatal(err) |
| } |
| |
| _, err = client.Logical().Read("transit/keys/bar") |
| if err != nil { |
| t.Fatal(err) |
| } |
| } |
| |
| func TestTransit_CreateKeyWithAutorotation(t *testing.T) { |
| tests := map[string]struct { |
| autoRotatePeriod interface{} |
| shouldError bool |
| expectedValue time.Duration |
| }{ |
| "default (no value)": { |
| shouldError: false, |
| }, |
| "0 (int)": { |
| autoRotatePeriod: 0, |
| shouldError: false, |
| expectedValue: 0, |
| }, |
| "0 (string)": { |
| autoRotatePeriod: "0", |
| shouldError: false, |
| expectedValue: 0, |
| }, |
| "5 seconds": { |
| autoRotatePeriod: "5s", |
| shouldError: true, |
| }, |
| "5 hours": { |
| autoRotatePeriod: "5h", |
| shouldError: false, |
| expectedValue: 5 * time.Hour, |
| }, |
| "negative value": { |
| autoRotatePeriod: "-1800s", |
| shouldError: true, |
| }, |
| "invalid string": { |
| autoRotatePeriod: "this shouldn't work", |
| shouldError: true, |
| }, |
| } |
| |
| coreConfig := &vault.CoreConfig{ |
| LogicalBackends: map[string]logical.Factory{ |
| "transit": transit.Factory, |
| }, |
| } |
| cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{ |
| HandlerFunc: vaulthttp.Handler, |
| }) |
| cluster.Start() |
| defer cluster.Cleanup() |
| cores := cluster.Cores |
| vault.TestWaitActive(t, cores[0].Core) |
| client := cores[0].Client |
| err := client.Sys().Mount("transit", &api.MountInput{ |
| Type: "transit", |
| }) |
| if err != nil { |
| t.Fatal(err) |
| } |
| |
| for name, test := range tests { |
| t.Run(name, func(t *testing.T) { |
| keyNameBytes, err := uuid.GenerateRandomBytes(16) |
| if err != nil { |
| t.Fatal(err) |
| } |
| keyName := hex.EncodeToString(keyNameBytes) |
| |
| _, err = client.Logical().Write(fmt.Sprintf("transit/keys/%s", keyName), map[string]interface{}{ |
| "auto_rotate_period": test.autoRotatePeriod, |
| }) |
| switch { |
| case test.shouldError && err == nil: |
| t.Fatal("expected non-nil error") |
| case !test.shouldError && err != nil: |
| t.Fatal(err) |
| } |
| |
| if !test.shouldError { |
| resp, err := client.Logical().Read(fmt.Sprintf("transit/keys/%s", keyName)) |
| if err != nil { |
| t.Fatal(err) |
| } |
| if resp == nil { |
| t.Fatal("expected non-nil response") |
| } |
| gotRaw, ok := resp.Data["auto_rotate_period"].(json.Number) |
| if !ok { |
| t.Fatal("returned value is of unexpected type") |
| } |
| got, err := gotRaw.Int64() |
| if err != nil { |
| t.Fatal(err) |
| } |
| want := int64(test.expectedValue.Seconds()) |
| if got != want { |
| t.Fatalf("incorrect auto_rotate_period returned, got: %d, want: %d", got, want) |
| } |
| } |
| }) |
| } |
| } |