| // Copyright (c) HashiCorp, Inc. |
| // SPDX-License-Identifier: MPL-2.0 |
| |
| //go:build !race && !hsm && !fips_140_3 |
| |
| // NOTE: we can't use this with HSM. We can't set testing mode on and it's not |
| // safe to use env vars since that provides an attack vector in the real world. |
| // |
| // The server tests have a go-metrics/exp manager race condition :(. |
| |
| package command |
| |
| import ( |
| "crypto/tls" |
| "crypto/x509" |
| "fmt" |
| "io/ioutil" |
| "os" |
| "strings" |
| "sync" |
| "testing" |
| "time" |
| |
| "github.com/hashicorp/vault/sdk/physical" |
| physInmem "github.com/hashicorp/vault/sdk/physical/inmem" |
| "github.com/mitchellh/cli" |
| "github.com/stretchr/testify/require" |
| ) |
| |
| func init() { |
| if signed := os.Getenv("VAULT_LICENSE_CI"); signed != "" { |
| os.Setenv(EnvVaultLicense, signed) |
| } |
| } |
| |
// testBaseHCL returns a minimal server configuration in HCL: mlock is disabled
// and a single TCP listener is bound to 127.0.0.1 on port 0 with TLS disabled.
// listenerExtras is inserted verbatim inside the listener stanza. Leading and
// trailing whitespace is trimmed from the result.
func testBaseHCL(tb testing.TB, listenerExtras string) string {
	tb.Helper()

	const tmpl = `
disable_mlock = true
listener "tcp" {
address = "127.0.0.1:%d"
tls_disable = "true"
%s
}
`
	rendered := fmt.Sprintf(tmpl, 0, listenerExtras)
	return strings.TrimSpace(rendered)
}
| |
// Listener timeout fragments. TestServer passes each of these to testBaseHCL
// as the extra content of the listener stanza.
const (
	// goodListenerTimeouts sets all four HTTP timeouts to values the
	// "good_listener_timeout_config" case expects to be accepted.
	goodListenerTimeouts = `http_read_header_timeout = 12
http_read_timeout = "34s"
http_write_timeout = "56m"
http_idle_timeout = "78h"`

	// Each of the following sets one timeout to a duration whose unit is not
	// recognised; TestServer expects an "unknown unit" error for every one.
	badListenerReadHeaderTimeout = `http_read_header_timeout = "12km"`
	badListenerReadTimeout       = `http_read_timeout = "34日"`
	badListenerWriteTimeout      = `http_write_timeout = "56lbs"`
	badListenerIdleTimeout       = `http_idle_timeout = "78gophers"`
)

// Storage fragments, appended to the base configuration by TestServer.
const (
	// inmemHCL declares the "inmem_ha" physical backend as primary storage.
	inmemHCL = `
backend "inmem_ha" {
advertise_addr = "http://127.0.0.1:8200"
}
`
	// haInmemHCL declares "inmem_ha" as a separate HA backend.
	haInmemHCL = `
ha_backend "inmem_ha" {
redirect_addr = "http://127.0.0.1:8200"
}
`

	// badHAInmemHCL declares plain "inmem" as the HA backend; the
	// "bad_separate_ha" case expects the server to reject it.
	badHAInmemHCL = `
ha_backend "inmem" {}
`
)

// Stand-alone configuration fragments.
const (
	// reloadHCL is the config used by TestServer_ReloadListener. The literal
	// TMPDIR is a placeholder that the test replaces with a temp directory.
	reloadHCL = `
backend "inmem" {}
disable_mlock = true
listener "tcp" {
address = "127.0.0.1:8203"
tls_cert_file = "TMPDIR/reload_cert.pem"
tls_key_file = "TMPDIR/reload_key.pem"
}
`
	// cloudHCL is a cloud stanza; the "cloud_config" case checks that the
	// organization ID embedded in resource_id appears in the server output.
	cloudHCL = `
cloud {
resource_id = "organization/bc58b3d0-2eab-4ab8-abf4-f61d3c9975ff/project/1c78e888-2142-4000-8918-f933bbbc7690/hashicorp.example.resource/example"
client_id = "J2TtcSYOyPUkPV2z0mSyDtvitxLVjJmu"
client_secret = "N9JtHZyOnHrIvJZs82pqa54vd4jnkyU3xCcqhFXuQKJZZuxqxxbP1xCfBZVB82vY"
}
`
)
| |
| func testServerCommand(tb testing.TB) (*cli.MockUi, *ServerCommand) { |
| tb.Helper() |
| |
| ui := cli.NewMockUi() |
| return ui, &ServerCommand{ |
| BaseCommand: &BaseCommand{ |
| UI: ui, |
| }, |
| ShutdownCh: MakeShutdownCh(), |
| SighupCh: MakeSighupCh(), |
| SigUSR2Ch: MakeSigUSR2Ch(), |
| PhysicalBackends: map[string]physical.Factory{ |
| "inmem": physInmem.NewInmem, |
| "inmem_ha": physInmem.NewInmemHA, |
| }, |
| |
| // These prevent us from random sleep guessing... |
| startedCh: make(chan struct{}, 5), |
| reloadedCh: make(chan struct{}, 5), |
| licenseReloadedCh: make(chan error), |
| } |
| } |
| |
| func TestServer_ReloadListener(t *testing.T) { |
| t.Parallel() |
| |
| wd, _ := os.Getwd() |
| wd += "/server/test-fixtures/reload/" |
| |
| td, err := ioutil.TempDir("", "vault-test-") |
| if err != nil { |
| t.Fatal(err) |
| } |
| defer os.RemoveAll(td) |
| |
| wg := &sync.WaitGroup{} |
| // Setup initial certs |
| inBytes, _ := ioutil.ReadFile(wd + "reload_foo.pem") |
| ioutil.WriteFile(td+"/reload_cert.pem", inBytes, 0o777) |
| inBytes, _ = ioutil.ReadFile(wd + "reload_foo.key") |
| ioutil.WriteFile(td+"/reload_key.pem", inBytes, 0o777) |
| |
| relhcl := strings.ReplaceAll(reloadHCL, "TMPDIR", td) |
| ioutil.WriteFile(td+"/reload.hcl", []byte(relhcl), 0o777) |
| |
| inBytes, _ = ioutil.ReadFile(wd + "reload_ca.pem") |
| certPool := x509.NewCertPool() |
| ok := certPool.AppendCertsFromPEM(inBytes) |
| if !ok { |
| t.Fatal("not ok when appending CA cert") |
| } |
| |
| ui, cmd := testServerCommand(t) |
| _ = ui |
| |
| wg.Add(1) |
| args := []string{"-config", td + "/reload.hcl"} |
| go func() { |
| if code := cmd.Run(args); code != 0 { |
| output := ui.ErrorWriter.String() + ui.OutputWriter.String() |
| t.Errorf("got a non-zero exit status: %s", output) |
| } |
| wg.Done() |
| }() |
| |
| testCertificateName := func(cn string) error { |
| conn, err := tls.Dial("tcp", "127.0.0.1:8203", &tls.Config{ |
| RootCAs: certPool, |
| }) |
| if err != nil { |
| return err |
| } |
| defer conn.Close() |
| if err = conn.Handshake(); err != nil { |
| return err |
| } |
| servName := conn.ConnectionState().PeerCertificates[0].Subject.CommonName |
| if servName != cn { |
| return fmt.Errorf("expected %s, got %s", cn, servName) |
| } |
| return nil |
| } |
| |
| select { |
| case <-cmd.startedCh: |
| case <-time.After(5 * time.Second): |
| t.Fatalf("timeout") |
| } |
| |
| if err := testCertificateName("foo.example.com"); err != nil { |
| t.Fatalf("certificate name didn't check out: %s", err) |
| } |
| |
| relhcl = strings.ReplaceAll(reloadHCL, "TMPDIR", td) |
| inBytes, _ = ioutil.ReadFile(wd + "reload_bar.pem") |
| ioutil.WriteFile(td+"/reload_cert.pem", inBytes, 0o777) |
| inBytes, _ = ioutil.ReadFile(wd + "reload_bar.key") |
| ioutil.WriteFile(td+"/reload_key.pem", inBytes, 0o777) |
| ioutil.WriteFile(td+"/reload.hcl", []byte(relhcl), 0o777) |
| |
| cmd.SighupCh <- struct{}{} |
| select { |
| case <-cmd.reloadedCh: |
| case <-time.After(5 * time.Second): |
| t.Fatalf("timeout") |
| } |
| |
| if err := testCertificateName("bar.example.com"); err != nil { |
| t.Fatalf("certificate name didn't check out: %s", err) |
| } |
| |
| cmd.ShutdownCh <- struct{}{} |
| |
| wg.Wait() |
| } |
| |
| func TestServer(t *testing.T) { |
| t.Parallel() |
| |
| cases := []struct { |
| name string |
| contents string |
| exp string |
| code int |
| args []string |
| }{ |
| { |
| "common_ha", |
| testBaseHCL(t, "") + inmemHCL, |
| "(HA available)", |
| 0, |
| []string{"-test-verify-only"}, |
| }, |
| { |
| "separate_ha", |
| testBaseHCL(t, "") + inmemHCL + haInmemHCL, |
| "HA Storage:", |
| 0, |
| []string{"-test-verify-only"}, |
| }, |
| { |
| "bad_separate_ha", |
| testBaseHCL(t, "") + inmemHCL + badHAInmemHCL, |
| "Specified HA storage does not support HA", |
| 1, |
| []string{"-test-verify-only"}, |
| }, |
| { |
| "good_listener_timeout_config", |
| testBaseHCL(t, goodListenerTimeouts) + inmemHCL, |
| "", |
| 0, |
| []string{"-test-server-config"}, |
| }, |
| { |
| "bad_listener_read_header_timeout_config", |
| testBaseHCL(t, badListenerReadHeaderTimeout) + inmemHCL, |
| "unknown unit \"km\" in duration \"12km\"", |
| 1, |
| []string{"-test-server-config"}, |
| }, |
| { |
| "bad_listener_read_timeout_config", |
| testBaseHCL(t, badListenerReadTimeout) + inmemHCL, |
| "unknown unit \"\\xe6\\x97\\xa5\" in duration", |
| 1, |
| []string{"-test-server-config"}, |
| }, |
| { |
| "bad_listener_write_timeout_config", |
| testBaseHCL(t, badListenerWriteTimeout) + inmemHCL, |
| "unknown unit \"lbs\" in duration \"56lbs\"", |
| 1, |
| []string{"-test-server-config"}, |
| }, |
| { |
| "bad_listener_idle_timeout_config", |
| testBaseHCL(t, badListenerIdleTimeout) + inmemHCL, |
| "unknown unit \"gophers\" in duration \"78gophers\"", |
| 1, |
| []string{"-test-server-config"}, |
| }, |
| { |
| "environment_variables_logged", |
| testBaseHCL(t, "") + inmemHCL, |
| "Environment Variables", |
| 0, |
| []string{"-test-verify-only"}, |
| }, |
| { |
| "cloud_config", |
| testBaseHCL(t, "") + inmemHCL + cloudHCL, |
| "HCP Organization: bc58b3d0-2eab-4ab8-abf4-f61d3c9975ff", |
| 0, |
| []string{"-test-verify-only"}, |
| }, |
| { |
| "recovery_mode", |
| testBaseHCL(t, "") + inmemHCL, |
| "", |
| 0, |
| []string{"-test-verify-only", "-recovery"}, |
| }, |
| } |
| |
| for _, tc := range cases { |
| tc := tc |
| |
| t.Run(tc.name, func(t *testing.T) { |
| t.Parallel() |
| |
| ui, cmd := testServerCommand(t) |
| |
| f, err := os.CreateTemp(t.TempDir(), "") |
| require.NoErrorf(t, err, "error creating temp dir: %v", err) |
| |
| _, err = f.WriteString(tc.contents) |
| require.NoErrorf(t, err, "cannot write temp file contents") |
| |
| err = f.Close() |
| require.NoErrorf(t, err, "unable to close temp file") |
| |
| args := append(tc.args, "-config", f.Name()) |
| code := cmd.Run(args) |
| output := ui.ErrorWriter.String() + ui.OutputWriter.String() |
| require.Equal(t, tc.code, code, "expected %d to be %d: %s", code, tc.code, output) |
| require.Contains(t, output, tc.exp, "expected %q to contain %q", output, tc.exp) |
| }) |
| } |
| } |
| |
| // TestServer_DevTLS verifies that a vault server starts up correctly with the -dev-tls flag |
| func TestServer_DevTLS(t *testing.T) { |
| ui, cmd := testServerCommand(t) |
| args := []string{"-dev-tls", "-dev-listen-address=127.0.0.1:0", "-test-server-config"} |
| retCode := cmd.Run(args) |
| output := ui.ErrorWriter.String() + ui.OutputWriter.String() |
| require.Equal(t, 0, retCode, output) |
| require.Contains(t, output, `tls: "enabled"`) |
| } |
| |
| // TestConfigureDevTLS verifies the various logic paths that flow through the |
| // configureDevTLS function. |
| func TestConfigureDevTLS(t *testing.T) { |
| testcases := []struct { |
| ServerCommand *ServerCommand |
| DeferFuncNotNil bool |
| ConfigNotNil bool |
| TLSDisable bool |
| CertPathEmpty bool |
| ErrNotNil bool |
| TestDescription string |
| }{ |
| { |
| ServerCommand: &ServerCommand{ |
| flagDevTLS: false, |
| }, |
| ConfigNotNil: true, |
| TLSDisable: true, |
| CertPathEmpty: true, |
| ErrNotNil: false, |
| TestDescription: "flagDev is false, nothing will be configured", |
| }, |
| { |
| ServerCommand: &ServerCommand{ |
| flagDevTLS: true, |
| flagDevTLSCertDir: "", |
| }, |
| DeferFuncNotNil: true, |
| ConfigNotNil: true, |
| ErrNotNil: false, |
| TestDescription: "flagDevTLSCertDir is empty", |
| }, |
| { |
| ServerCommand: &ServerCommand{ |
| flagDevTLS: true, |
| flagDevTLSCertDir: "@/#", |
| }, |
| CertPathEmpty: true, |
| ErrNotNil: true, |
| TestDescription: "flagDevTLSCertDir is set to something invalid", |
| }, |
| } |
| |
| for _, testcase := range testcases { |
| fun, cfg, certPath, err := configureDevTLS(testcase.ServerCommand) |
| if fun != nil { |
| // If a function is returned, call it right away to clean up |
| // files created in the temporary directory before anything else has |
| // a chance to fail this test. |
| fun() |
| } |
| |
| require.Equal(t, testcase.DeferFuncNotNil, (fun != nil), "test description %s", testcase.TestDescription) |
| require.Equal(t, testcase.ConfigNotNil, cfg != nil, "test description %s", testcase.TestDescription) |
| if testcase.ConfigNotNil { |
| require.True(t, len(cfg.Listeners) > 0, "test description %s", testcase.TestDescription) |
| require.Equal(t, testcase.TLSDisable, cfg.Listeners[0].TLSDisable, "test description %s", testcase.TestDescription) |
| } |
| require.Equal(t, testcase.CertPathEmpty, len(certPath) == 0, "test description %s", testcase.TestDescription) |
| require.Equal(t, testcase.ErrNotNil, (err != nil), "test description %s", testcase.TestDescription) |
| } |
| } |