// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0

package docker

import (
	"bufio"
	"bytes"
	"context"
	"crypto/ecdsa"
	"crypto/elliptic"
	"crypto/rand"
	"crypto/tls"
	"crypto/x509"
	"crypto/x509/pkix"
	"encoding/hex"
	"encoding/json"
	"encoding/pem"
	"fmt"
	"io"
	"math/big"
	mathrand "math/rand"
	"net"
	"net/http"
	"os"
	"path/filepath"
	"strings"
	"sync"
	"testing"
	"time"

	"github.com/docker/docker/api/types"
	"github.com/docker/docker/api/types/volume"
	docker "github.com/docker/docker/client"
	"github.com/hashicorp/go-cleanhttp"
	log "github.com/hashicorp/go-hclog"
	"github.com/hashicorp/go-multierror"
	"github.com/hashicorp/vault/api"
	dockhelper "github.com/hashicorp/vault/sdk/helper/docker"
	"github.com/hashicorp/vault/sdk/helper/logging"
	"github.com/hashicorp/vault/sdk/helper/testcluster"
	uberAtomic "go.uber.org/atomic"
	"golang.org/x/net/http2"
)

var (
	_ testcluster.VaultCluster     = &DockerCluster{}
	_ testcluster.VaultClusterNode = &DockerClusterNode{}
)

// MaxClusterNameLength is the maximum supported length of a cluster name.
const MaxClusterNameLength = 52

// DockerCluster is used to manage the lifecycle of the test Vault cluster
type DockerCluster struct {
	ClusterName string

	ClusterNodes []*DockerClusterNode

	// Certificate fields
	*testcluster.CA
	RootCAs *x509.CertPool

	barrierKeys  [][]byte
	recoveryKeys [][]byte
	tmpDir       string

	// rootToken is the initial root token created when the Vault cluster is
	// created.
	rootToken string
	DockerAPI *docker.Client
	ID        string
	Logger    log.Logger
	builtTags map[string]struct{}

	storage testcluster.ClusterStorage
}

func (dc *DockerCluster) NamedLogger(s string) log.Logger {
	return dc.Logger.Named(s)
}

func (dc *DockerCluster) ClusterID() string {
	return dc.ID
}

func (dc *DockerCluster) Nodes() []testcluster.VaultClusterNode {
	ret := make([]testcluster.VaultClusterNode, len(dc.ClusterNodes))
	for i := range dc.ClusterNodes {
		ret[i] = dc.ClusterNodes[i]
	}
	return ret
}
// GetBarrierKeys returns a copy of the cluster's barrier keys, mirroring
// GetRecoveryKeys so callers can't mutate the cluster's internal state.
func (dc *DockerCluster) GetBarrierKeys() [][]byte {
	ret := make([][]byte, len(dc.barrierKeys))
	for i, k := range dc.barrierKeys {
		ret[i] = testKeyCopy(k)
	}
	return ret
}

func testKeyCopy(key []byte) []byte {
	result := make([]byte, len(key))
	copy(result, key)
	return result
}

func (dc *DockerCluster) GetRecoveryKeys() [][]byte {
	ret := make([][]byte, len(dc.recoveryKeys))
	for i, k := range dc.recoveryKeys {
		ret[i] = testKeyCopy(k)
	}
	return ret
}

func (dc *DockerCluster) GetBarrierOrRecoveryKeys() [][]byte {
	return dc.GetBarrierKeys()
}

func (dc *DockerCluster) SetBarrierKeys(keys [][]byte) {
	dc.barrierKeys = make([][]byte, len(keys))
	for i, k := range keys {
		dc.barrierKeys[i] = testKeyCopy(k)
	}
}

func (dc *DockerCluster) SetRecoveryKeys(keys [][]byte) {
	dc.recoveryKeys = make([][]byte, len(keys))
	for i, k := range keys {
		dc.recoveryKeys[i] = testKeyCopy(k)
	}
}

func (dc *DockerCluster) GetCACertPEMFile() string {
	return dc.CACertPEMFile
}

func (dc *DockerCluster) Cleanup() {
	if err := dc.cleanup(); err != nil {
		dc.Logger.Error("error during cleanup", "error", err)
	}
}

func (dc *DockerCluster) cleanup() error {
	var result *multierror.Error
	for _, node := range dc.ClusterNodes {
		if err := node.cleanup(); err != nil {
			result = multierror.Append(result, err)
		}
	}

	return result.ErrorOrNil()
}

// GetRootToken returns the root token of the cluster, if set
func (dc *DockerCluster) GetRootToken() string {
	return dc.rootToken
}

func (dc *DockerCluster) SetRootToken(s string) {
	dc.Logger.Trace("cluster root token changed", "helpful_env", fmt.Sprintf("VAULT_TOKEN=%s VAULT_CACERT=/vault/config/ca.pem", s))
	dc.rootToken = s
}

func (n *DockerClusterNode) Name() string {
	return n.Cluster.ClusterName + "-" + n.NodeID
}

func (dc *DockerCluster) setupNode0(ctx context.Context) error {
	client := dc.ClusterNodes[0].client

	var resp *api.InitResponse
	var err error
	for ctx.Err() == nil {
		resp, err = client.Sys().Init(&api.InitRequest{
			SecretShares:    3,
			SecretThreshold: 3,
		})
		if err == nil && resp != nil {
			break
		}
		time.Sleep(500 * time.Millisecond)
	}
	if err != nil {
		return err
	}
	if resp == nil {
		return fmt.Errorf("nil response to init request")
	}

	for _, k := range resp.Keys {
		raw, err := hex.DecodeString(k)
		if err != nil {
			return err
		}
		dc.barrierKeys = append(dc.barrierKeys, raw)
	}

	for _, k := range resp.RecoveryKeys {
		raw, err := hex.DecodeString(k)
		if err != nil {
			return err
		}
		dc.recoveryKeys = append(dc.recoveryKeys, raw)
	}

	dc.rootToken = resp.RootToken
	client.SetToken(dc.rootToken)
	dc.ClusterNodes[0].client = client

	err = testcluster.UnsealNode(ctx, dc, 0)
	if err != nil {
		return err
	}

	err = ensureLeaderMatches(ctx, client, func(leader *api.LeaderResponse) error {
		if !leader.IsSelf {
			return fmt.Errorf("node %d leader=%v, expected=%v", 0, leader.IsSelf, true)
		}

		return nil
	})
	if err != nil {
		return err
	}

	status, err := client.Sys().SealStatusWithContext(ctx)
	if err != nil {
		return err
	}
	dc.ID = status.ClusterID
	return nil
}

func (dc *DockerCluster) clusterReady(ctx context.Context) error {
	for i, node := range dc.ClusterNodes {
		expectLeader := i == 0
		err := ensureLeaderMatches(ctx, node.client, func(leader *api.LeaderResponse) error {
			if expectLeader != leader.IsSelf {
				return fmt.Errorf("node %d leader=%v, expected=%v", i, leader.IsSelf, expectLeader)
			}

			return nil
		})
		if err != nil {
			return err
		}
	}

	return nil
}

func (dc *DockerCluster) setupCA(opts *DockerClusterOptions) error {
	var err error
	var ca testcluster.CA

	if opts != nil && opts.CAKey != nil {
		ca.CAKey = opts.CAKey
	} else {
		ca.CAKey, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
		if err != nil {
			return err
		}
	}

	var caBytes []byte
	if opts != nil && len(opts.CACert) > 0 {
		caBytes = opts.CACert
	} else {
		serialNumber := mathrand.New(mathrand.NewSource(time.Now().UnixNano())).Int63()
		CACertTemplate := &x509.Certificate{
			Subject: pkix.Name{
				CommonName: "localhost",
			},
			KeyUsage:              x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
			SerialNumber:          big.NewInt(serialNumber),
			NotBefore:             time.Now().Add(-30 * time.Second),
			NotAfter:              time.Now().Add(262980 * time.Hour),
			BasicConstraintsValid: true,
			IsCA:                  true,
		}
		caBytes, err = x509.CreateCertificate(rand.Reader, CACertTemplate, CACertTemplate, ca.CAKey.Public(), ca.CAKey)
		if err != nil {
			return err
		}
	}
	CACert, err := x509.ParseCertificate(caBytes)
	if err != nil {
		return err
	}
	ca.CACert = CACert
	ca.CACertBytes = caBytes

	CACertPEMBlock := &pem.Block{
		Type:  "CERTIFICATE",
		Bytes: caBytes,
	}
	ca.CACertPEM = pem.EncodeToMemory(CACertPEMBlock)

	ca.CACertPEMFile = filepath.Join(dc.tmpDir, "ca", "ca.pem")
	err = os.WriteFile(ca.CACertPEMFile, ca.CACertPEM, 0o755)
	if err != nil {
		return err
	}

	marshaledCAKey, err := x509.MarshalECPrivateKey(ca.CAKey)
	if err != nil {
		return err
	}
	CAKeyPEMBlock := &pem.Block{
		Type:  "EC PRIVATE KEY",
		Bytes: marshaledCAKey,
	}
	ca.CAKeyPEM = pem.EncodeToMemory(CAKeyPEMBlock)

	dc.CA = &ca

	return nil
}

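// exampleSharedCA is a hedged usage sketch, not called anywhere in this
// package: it shows how a caller might reuse one cluster's generated CA when
// building a second cluster, so nodes of both clusters present certificates
// from the same root. The image repo/tag values are assumptions; substitute
// whatever Vault image your environment provides.
func exampleSharedCA(ctx context.Context) error {
	c1, err := NewDockerCluster(ctx, &DockerClusterOptions{
		ImageRepo: "hashicorp/vault", // assumed image
		ImageTag:  "latest",
	})
	if err != nil {
		return err
	}
	defer c1.Cleanup()

	// Passing c1's CA via opts.CA makes setupCA a no-op for the new cluster.
	c2, err := NewDockerCluster(ctx, &DockerClusterOptions{
		ImageRepo: "hashicorp/vault",
		ImageTag:  "latest",
		CA:        c1.CA,
	})
	if err != nil {
		return err
	}
	defer c2.Cleanup()
	return nil
}
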
func (n *DockerClusterNode) setupCert(ip string) error {
	var err error

	n.ServerKey, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		return err
	}

	serialNumber := mathrand.New(mathrand.NewSource(time.Now().UnixNano())).Int63()
	certTemplate := &x509.Certificate{
		Subject: pkix.Name{
			CommonName: n.Name(),
		},
		DNSNames:    []string{"localhost", n.Name()},
		IPAddresses: []net.IP{net.IPv6loopback, net.ParseIP("127.0.0.1"), net.ParseIP(ip)},
		ExtKeyUsage: []x509.ExtKeyUsage{
			x509.ExtKeyUsageServerAuth,
			x509.ExtKeyUsageClientAuth,
		},
		KeyUsage:     x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageKeyAgreement,
		SerialNumber: big.NewInt(serialNumber),
		NotBefore:    time.Now().Add(-30 * time.Second),
		NotAfter:     time.Now().Add(262980 * time.Hour),
	}
	n.ServerCertBytes, err = x509.CreateCertificate(rand.Reader, certTemplate, n.Cluster.CACert, n.ServerKey.Public(), n.Cluster.CAKey)
	if err != nil {
		return err
	}
	n.ServerCert, err = x509.ParseCertificate(n.ServerCertBytes)
	if err != nil {
		return err
	}
	n.ServerCertPEM = pem.EncodeToMemory(&pem.Block{
		Type:  "CERTIFICATE",
		Bytes: n.ServerCertBytes,
	})

	marshaledKey, err := x509.MarshalECPrivateKey(n.ServerKey)
	if err != nil {
		return err
	}
	n.ServerKeyPEM = pem.EncodeToMemory(&pem.Block{
		Type:  "EC PRIVATE KEY",
		Bytes: marshaledKey,
	})

	n.ServerCertPEMFile = filepath.Join(n.WorkDir, "cert.pem")
	err = os.WriteFile(n.ServerCertPEMFile, n.ServerCertPEM, 0o755)
	if err != nil {
		return err
	}

	n.ServerKeyPEMFile = filepath.Join(n.WorkDir, "key.pem")
	err = os.WriteFile(n.ServerKeyPEMFile, n.ServerKeyPEM, 0o755)
	if err != nil {
		return err
	}

	tlsCert, err := tls.X509KeyPair(n.ServerCertPEM, n.ServerKeyPEM)
	if err != nil {
		return err
	}

	certGetter := NewCertificateGetter(n.ServerCertPEMFile, n.ServerKeyPEMFile, "")
	if err := certGetter.Reload(); err != nil {
		return err
	}
	tlsConfig := &tls.Config{
		Certificates:   []tls.Certificate{tlsCert},
		RootCAs:        n.Cluster.RootCAs,
		ClientCAs:      n.Cluster.RootCAs,
		ClientAuth:     tls.RequestClientCert,
		NextProtos:     []string{"h2", "http/1.1"},
		GetCertificate: certGetter.GetCertificate,
	}

	n.tlsConfig = tlsConfig

	err = os.WriteFile(filepath.Join(n.WorkDir, "ca.pem"), n.Cluster.CACertPEM, 0o755)
	if err != nil {
		return err
	}
	return nil
}

func NewTestDockerCluster(t *testing.T, opts *DockerClusterOptions) *DockerCluster {
	if opts == nil {
		opts = &DockerClusterOptions{}
	}
	if opts.ClusterName == "" {
		opts.ClusterName = strings.ReplaceAll(t.Name(), "/", "-")
	}
	if opts.Logger == nil {
		opts.Logger = logging.NewVaultLogger(log.Trace).Named(t.Name())
	}
	if opts.NetworkName == "" {
		opts.NetworkName = os.Getenv("TEST_DOCKER_NETWORK_NAME")
	}

	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
	t.Cleanup(cancel)

	dc, err := NewDockerCluster(ctx, opts)
	if err != nil {
		t.Fatal(err)
	}
	dc.Logger.Trace("cluster started", "helpful_env", fmt.Sprintf("VAULT_TOKEN=%s VAULT_CACERT=/vault/config/ca.pem", dc.GetRootToken()))
	return dc
}

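// exampleNewTestDockerCluster is a hedged usage sketch, not called anywhere
// in this package, showing the typical test flow: stand up a cluster, take a
// root-token client for node 0, and make one API call. The image repo/tag
// values are assumptions; point them at whatever Vault image you use.
func exampleNewTestDockerCluster(t *testing.T) {
	cluster := NewTestDockerCluster(t, &DockerClusterOptions{
		ImageRepo: "hashicorp/vault", // assumed image
		ImageTag:  "latest",
	})
	defer cluster.Cleanup()

	client := cluster.ClusterNodes[0].APIClient()
	health, err := client.Sys().Health()
	if err != nil {
		t.Fatal(err)
	}
	t.Logf("cluster %s initialized=%v", cluster.ClusterID(), health.Initialized)
}
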
func NewDockerCluster(ctx context.Context, opts *DockerClusterOptions) (*DockerCluster, error) {
	api, err := dockhelper.NewDockerAPI()
	if err != nil {
		return nil, err
	}

	if opts == nil {
		opts = &DockerClusterOptions{}
	}
	if opts.Logger == nil {
		opts.Logger = log.NewNullLogger()
	}
	if opts.VaultLicense == "" {
		opts.VaultLicense = os.Getenv(testcluster.EnvVaultLicenseCI)
	}

	dc := &DockerCluster{
		DockerAPI:   api,
		ClusterName: opts.ClusterName,
		Logger:      opts.Logger,
		builtTags:   map[string]struct{}{},
		CA:          opts.CA,
		storage:     opts.Storage,
	}

	if err := dc.setupDockerCluster(ctx, opts); err != nil {
		dc.Cleanup()
		return nil, err
	}

	return dc, nil
}

// DockerClusterNode represents a single instance of Vault in a cluster
type DockerClusterNode struct {
	NodeID               string
	HostPort             string
	client               *api.Client
	ServerCert           *x509.Certificate
	ServerCertBytes      []byte
	ServerCertPEM        []byte
	ServerCertPEMFile    string
	ServerKey            *ecdsa.PrivateKey
	ServerKeyPEM         []byte
	ServerKeyPEMFile     string
	tlsConfig            *tls.Config
	WorkDir              string
	Cluster              *DockerCluster
	Container            *types.ContainerJSON
	DockerAPI            *docker.Client
	runner               *dockhelper.Runner
	Logger               log.Logger
	cleanupContainer     func()
	RealAPIAddr          string
	ContainerNetworkName string
	ContainerIPAddress   string
	ImageRepo            string
	ImageTag             string
	DataVolumeName       string
	cleanupVolume        func()
}

func (n *DockerClusterNode) TLSConfig() *tls.Config {
	return n.tlsConfig.Clone()
}

func (n *DockerClusterNode) APIClient() *api.Client {
	// We clone to ensure that whenever this method is called, the caller gets
	// back a pristine client, without e.g. any namespace or token changes that
	// might pollute a shared client. We clone the config instead of the
	// client because (1) Client.clone propagates the replicationStateStore and
	// the httpClient pointers, (2) it doesn't copy the tlsConfig at all, and
	// (3) if clone returns an error, it doesn't feel as appropriate to panic
	// below. Who knows why clone might return an error?
	cfg := n.client.CloneConfig()
	client, err := api.NewClient(cfg)
	if err != nil {
		// It seems fine to panic here, since this should be the same input
		// we provided to NewClient when we were set up, and we didn't panic then.
		// Better not to completely ignore the error though, suppose there's a
		// bug in CloneConfig?
		panic(fmt.Sprintf("NewClient error on cloned config: %v", err))
	}
	client.SetToken(n.Cluster.rootToken)
	return client
}

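// exampleClientIsolation is a hedged sketch, unused by the package, of the
// isolation property APIClient documents above: mutating one returned client
// must not leak into clients handed to other callers. The token literal is
// hypothetical.
func exampleClientIsolation(n *DockerClusterNode) {
	a := n.APIClient()
	b := n.APIClient()
	a.SetToken("hvs.example") // hypothetical token; only affects a
	_ = b.Token()             // b still carries the cluster root token
}
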
// apiConfig returns an api.Config for talking to this node over its
// host-mapped port, using the cluster's TLS material and disallowing
// redirects and retries.
func (n *DockerClusterNode) apiConfig() (*api.Config, error) {
	transport := cleanhttp.DefaultPooledTransport()
	transport.TLSClientConfig = n.TLSConfig()
	if err := http2.ConfigureTransport(transport); err != nil {
		return nil, err
	}
	client := &http.Client{
		Transport: transport,
		CheckRedirect: func(*http.Request, []*http.Request) error {
			// This can of course be overridden per-test by using its own client
			return fmt.Errorf("redirects not allowed in these tests")
		},
	}
	config := api.DefaultConfig()
	if config.Error != nil {
		return nil, config.Error
	}
	config.Address = fmt.Sprintf("https://%s", n.HostPort)
	config.HttpClient = client
	config.MaxRetries = 0
	return config, nil
}

// newAPIClient creates and configures a Vault API client to communicate with
// the running Vault server for this DockerClusterNode
func (n *DockerClusterNode) newAPIClient() (*api.Client, error) {
	config, err := n.apiConfig()
	if err != nil {
		return nil, err
	}
	client, err := api.NewClient(config)
	if err != nil {
		return nil, err
	}
	client.SetToken(n.Cluster.GetRootToken())
	return client, nil
}

// Cleanup kills the container of the node and deletes its data volume
func (n *DockerClusterNode) Cleanup() {
	n.cleanup()
}

// Stop kills the container of the node
func (n *DockerClusterNode) Stop() {
	n.cleanupContainer()
}

func (n *DockerClusterNode) cleanup() error {
	if n.Container == nil || n.Container.ID == "" {
		return nil
	}
	n.cleanupContainer()
	n.cleanupVolume()
	return nil
}

func (n *DockerClusterNode) Start(ctx context.Context, opts *DockerClusterOptions) error {
	if n.DataVolumeName == "" {
		vol, err := n.DockerAPI.VolumeCreate(ctx, volume.CreateOptions{})
		if err != nil {
			return err
		}
		n.DataVolumeName = vol.Name
		n.cleanupVolume = func() {
			_ = n.DockerAPI.VolumeRemove(ctx, vol.Name, false)
		}
	}
	vaultCfg := map[string]interface{}{}
	vaultCfg["listener"] = map[string]interface{}{
		"tcp": map[string]interface{}{
			"address":       fmt.Sprintf("%s:%d", "0.0.0.0", 8200),
			"tls_cert_file": "/vault/config/cert.pem",
			"tls_key_file":  "/vault/config/key.pem",
			"telemetry": map[string]interface{}{
				"unauthenticated_metrics_access": true,
			},
		},
	}
	vaultCfg["telemetry"] = map[string]interface{}{
		"disable_hostname": true,
	}

	// Setup storage. Default is raft.
	storageType := "raft"
	storageOpts := map[string]interface{}{
		// TODO add options from vnc
		"path":    "/vault/file",
		"node_id": n.NodeID,
	}

	if opts.Storage != nil {
		storageType = opts.Storage.Type()
		storageOpts = opts.Storage.Opts()
	}

	// opts has already been dereferenced above, so only VaultNodeConfig
	// needs a nil check here.
	if opts.VaultNodeConfig != nil {
		for k, v := range opts.VaultNodeConfig.StorageOptions {
			// Only fill in keys that don't already carry a string value, so
			// the defaults above (e.g. path and node_id) take precedence
			// over user-supplied storage options.
			if _, ok := storageOpts[k].(string); !ok {
				storageOpts[k] = v
			}
		}
	}
	vaultCfg["storage"] = map[string]interface{}{
		storageType: storageOpts,
	}

	// disable_mlock is required for working in the Docker environment with
	// custom plugins
	vaultCfg["disable_mlock"] = true
	vaultCfg["api_addr"] = `https://{{- GetAllInterfaces | exclude "flags" "loopback" | attr "address" -}}:8200`
	vaultCfg["cluster_addr"] = `https://{{- GetAllInterfaces | exclude "flags" "loopback" | attr "address" -}}:8201`

	vaultCfg["administrative_namespace_path"] = opts.AdministrativeNamespacePath

	systemJSON, err := json.Marshal(vaultCfg)
	if err != nil {
		return err
	}
	err = os.WriteFile(filepath.Join(n.WorkDir, "system.json"), systemJSON, 0o644)
	if err != nil {
		return err
	}

	if opts.VaultNodeConfig != nil {
		localCfg := *opts.VaultNodeConfig
		if opts.VaultNodeConfig.LicensePath != "" {
			b, err := os.ReadFile(opts.VaultNodeConfig.LicensePath)
			if err != nil {
				return fmt.Errorf("unable to read LicensePath at %q: %w", opts.VaultNodeConfig.LicensePath, err)
			}
			if len(b) == 0 {
				return fmt.Errorf("license file at %q is empty", opts.VaultNodeConfig.LicensePath)
			}
			localCfg.LicensePath = "/vault/config/license"
			dest := filepath.Join(n.WorkDir, "license")
			err = os.WriteFile(dest, b, 0o644)
			if err != nil {
				return fmt.Errorf("error writing license to %q: %w", dest, err)
			}
		}
		userJSON, err := json.Marshal(localCfg)
		if err != nil {
			return err
		}
		err = os.WriteFile(filepath.Join(n.WorkDir, "user.json"), userJSON, 0o644)
		if err != nil {
			return err
		}
	}

	// Create a temporary cert so vault will start up
	err = n.setupCert("127.0.0.1")
	if err != nil {
		return err
	}

	caDir := filepath.Join(n.Cluster.tmpDir, "ca")

	// setup plugin bin copy if needed
	copyFromTo := map[string]string{
		n.WorkDir: "/vault/config",
		caDir:     "/usr/local/share/ca-certificates/",
	}

	var wg sync.WaitGroup
	wg.Add(1)
	var seenLogs uberAtomic.Bool
	logConsumer := func(s string) {
		if seenLogs.CAS(false, true) {
			wg.Done()
		}
		n.Logger.Trace(s)
	}
	logStdout := &LogConsumerWriter{logConsumer}
	logStderr := &LogConsumerWriter{func(s string) {
		if seenLogs.CAS(false, true) {
			wg.Done()
		}
		testcluster.JSONLogNoTimestamp(n.Logger, s)
	}}

	r, err := dockhelper.NewServiceRunner(dockhelper.RunOptions{
		ImageRepo: n.ImageRepo,
		ImageTag:  n.ImageTag,
		// We don't need to run update-ca-certificates in the container, because
		// we're providing the CA in the raft join call, and otherwise Vault
		// servers don't talk to one another on the API port.
		Cmd: append([]string{"server"}, opts.Args...),
		Env: []string{
			// For now we're using disable_mlock, because this is for testing
			// anyway, and because enabling mlock would prevent us from using
			// external plugins.
			"SKIP_SETCAP=true",
			"VAULT_LOG_FORMAT=json",
			"VAULT_LICENSE=" + opts.VaultLicense,
		},
		Ports:           []string{"8200/tcp", "8201/tcp"},
		ContainerName:   n.Name(),
		NetworkName:     opts.NetworkName,
		CopyFromTo:      copyFromTo,
		LogConsumer:     logConsumer,
		LogStdout:       logStdout,
		LogStderr:       logStderr,
		PreDelete:       true,
		DoNotAutoRemove: true,
		PostStart: func(containerID string, realIP string) error {
			err := n.setupCert(realIP)
			if err != nil {
				return err
			}

			// If we signal Vault before it installs its sighup handler, it'll die.
			wg.Wait()
			n.Logger.Trace("running poststart", "containerID", containerID, "IP", realIP)
			return n.runner.RefreshFiles(ctx, containerID)
		},
		Capabilities:      []string{"NET_ADMIN"},
		OmitLogTimestamps: true,
		VolumeNameToMountPoint: map[string]string{
			n.DataVolumeName: "/vault/file",
		},
	})
	if err != nil {
		return err
	}

	n.runner = r

	probe := opts.StartProbe
	if probe == nil {
		probe = func(c *api.Client) error {
			// Use a locally scoped error so the closure doesn't race with
			// the enclosing function's err variable.
			_, err := c.Sys().SealStatus()
			return err
		}
	}
	svc, _, err := r.StartNewService(ctx, false, false, func(ctx context.Context, host string, port int) (dockhelper.ServiceConfig, error) {
		config, err := n.apiConfig()
		if err != nil {
			return nil, err
		}
		config.Address = fmt.Sprintf("https://%s:%d", host, port)
		client, err := api.NewClient(config)
		if err != nil {
			return nil, err
		}
		err = probe(client)
		if err != nil {
			return nil, err
		}

		return dockhelper.NewServiceHostPort(host, port), nil
	})
	if err != nil {
		return err
	}

	n.HostPort = svc.Config.Address()
	n.Container = svc.Container
	netName := opts.NetworkName
	if netName == "" {
		if len(svc.Container.NetworkSettings.Networks) > 1 {
			return fmt.Errorf("set DockerClusterOptions.NetworkName for a container with multiple networks: %v", svc.Container.NetworkSettings.Networks)
		}
		for netName = range svc.Container.NetworkSettings.Networks {
			// Networks above is a map; we just need to find the first and
			// only key of this map (network name). The range handles this
			// for us, but we need a loop construction in order to use range.
		}
	}
	n.ContainerNetworkName = netName
	n.ContainerIPAddress = svc.Container.NetworkSettings.Networks[netName].IPAddress
	n.RealAPIAddr = "https://" + n.ContainerIPAddress + ":8200"
	n.cleanupContainer = svc.Cleanup

	client, err := n.newAPIClient()
	if err != nil {
		return err
	}
	client.SetToken(n.Cluster.rootToken)
	n.client = client
	return nil
}

func (n *DockerClusterNode) Pause(ctx context.Context) error {
	return n.DockerAPI.ContainerPause(ctx, n.Container.ID)
}

func (n *DockerClusterNode) AddNetworkDelay(ctx context.Context, delay time.Duration, targetIP string) error {
	ip := net.ParseIP(targetIP)
	if ip == nil {
		return fmt.Errorf("targetIP %q is not an IP address", targetIP)
	}
	// Let's attempt to get a unique handle for the filter rule; we'll assume that
	// every targetIP has a unique last octet, which is true currently for how
	// we're doing docker networking.
	ip4 := ip.To4()
	if ip4 == nil {
		return fmt.Errorf("targetIP %q is not an IPv4 address", targetIP)
	}
	lastOctet := ip4[3]

	stdout, stderr, exitCode, err := n.runner.RunCmdWithOutput(ctx, n.Container.ID, []string{
		"/bin/sh",
		"-xec", strings.Join([]string{
			fmt.Sprintf("echo isolating node %s", targetIP),
			"apk add iproute2",
			// If we're running this script a second time on the same node,
			// the add dev will fail; since we only want to run the netem
			// command once, we'll do so in the case where the add dev doesn't fail.
			"tc qdisc add dev eth0 root handle 1: prio && " +
				fmt.Sprintf("tc qdisc add dev eth0 parent 1:1 handle 2: netem delay %dms", delay/time.Millisecond),
			// Here we create a u32 filter as per https://man7.org/linux/man-pages/man8/tc-u32.8.html
			// Its parent is 1:0 (which I guess is the root?)
			// Its handle must be unique, so we base it on targetIP
			fmt.Sprintf("tc filter add dev eth0 parent 1:0 protocol ip pref 55 handle ::%x u32 match ip dst %s flowid 2:1", lastOctet, targetIP),
		}, "; "),
	})
	if err != nil {
		return err
	}

	n.Logger.Trace(string(stdout))
	n.Logger.Trace(string(stderr))
	if exitCode != 0 {
		return fmt.Errorf("got nonzero exit code from tc script: %d", exitCode)
	}
	return nil
}

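// exampleAddDelay is a hedged sketch, unused by the package, of injecting
// latency between two nodes: every packet node 0 sends toward node 1's
// container IP is delayed by 500ms. The duration is an illustrative choice.
func exampleAddDelay(ctx context.Context, dc *DockerCluster) error {
	target := dc.ClusterNodes[1].ContainerIPAddress
	return dc.ClusterNodes[0].AddNetworkDelay(ctx, 500*time.Millisecond, target)
}
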
// PartitionFromCluster will cause the node to be disconnected at the network
// level from the rest of the docker cluster. It does so in a way that the node
// will not see TCP RSTs and all packets it sends will be "black holed". It
// attempts to keep packets to and from the host intact, which allows the docker
// daemon to continue streaming logs and any test code to continue making
// requests from the host to the partitioned node.
func (n *DockerClusterNode) PartitionFromCluster(ctx context.Context) error {
	stdout, stderr, exitCode, err := n.runner.RunCmdWithOutput(ctx, n.Container.ID, []string{
		"/bin/sh",
		"-xec", strings.Join([]string{
			"echo partitioning container from network",
			"apk add iproute2",
			// Get the gateway address for the bridge so we can allow host to
			// container traffic still.
			"GW=$(ip r | grep default | grep eth0 | cut -f 3 -d' ')",
			// First delete the rules in case this is called twice otherwise we'll add
			// multiple copies and only remove one in Unpartition (yay iptables).
			// Ignore the error if it didn't exist.
			"iptables -D INPUT -i eth0 ! -s \"$GW\" -j DROP || true",
			"iptables -D OUTPUT -o eth0 ! -d \"$GW\" -j DROP || true",
			// Add rules to drop all packets in and out of the docker network
			// connection.
			"iptables -I INPUT -i eth0 ! -s \"$GW\" -j DROP",
			"iptables -I OUTPUT -o eth0 ! -d \"$GW\" -j DROP",
		}, "; "),
	})
	if err != nil {
		return err
	}

	n.Logger.Trace(string(stdout))
	n.Logger.Trace(string(stderr))
	if exitCode != 0 {
		return fmt.Errorf("got nonzero exit code from iptables: %d", exitCode)
	}
	return nil
}

// UnpartitionFromCluster reverses a previous call to PartitionFromCluster and
// restores full connectivity. Currently assumes the default "bridge" network.
func (n *DockerClusterNode) UnpartitionFromCluster(ctx context.Context) error {
	stdout, stderr, exitCode, err := n.runner.RunCmdWithOutput(ctx, n.Container.ID, []string{
		"/bin/sh",
		"-xec", strings.Join([]string{
			"echo un-partitioning container from network",
			// Get the gateway address for the bridge so we can allow host to
			// container traffic still.
			"GW=$(ip r | grep default | grep eth0 | cut -f 3 -d' ')",
			// Remove the rules, ignore if they are not present or iptables wasn't
			// installed yet (i.e. no-one called PartitionFromCluster yet).
			"iptables -D INPUT -i eth0 ! -s \"$GW\" -j DROP || true",
			"iptables -D OUTPUT -o eth0 ! -d \"$GW\" -j DROP || true",
		}, "; "),
	})
	if err != nil {
		return err
	}

	n.Logger.Trace(string(stdout))
	n.Logger.Trace(string(stderr))
	if exitCode != 0 {
		return fmt.Errorf("got nonzero exit code from iptables: %d", exitCode)
	}
	return nil
}

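// examplePartition is a hedged sketch, unused by the package, of a partition
// test: isolate the current leader, give the remaining nodes time to elect a
// new one, then heal the partition. The fixed sleep is illustrative only; a
// real test should poll for the new leader instead.
func examplePartition(ctx context.Context, dc *DockerCluster) error {
	leaderIdx, err := testcluster.LeaderNode(ctx, dc)
	if err != nil {
		return err
	}
	leader := dc.ClusterNodes[leaderIdx]
	if err := leader.PartitionFromCluster(ctx); err != nil {
		return err
	}
	time.Sleep(10 * time.Second) // assumption: long enough for an election
	return leader.UnpartitionFromCluster(ctx)
}
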
type LogConsumerWriter struct {
	consumer func(string)
}

func (l LogConsumerWriter) Write(p []byte) (n int, err error) {
	// TODO this assumes that we're never passed partial log lines, which
	// seems a safe assumption for now based on how docker looks to implement
	// logging, but might change in the future.
	scanner := bufio.NewScanner(bytes.NewReader(p))
	scanner.Buffer(make([]byte, 64*1024), bufio.MaxScanTokenSize)
	for scanner.Scan() {
		l.consumer(scanner.Text())
	}
	return len(p), nil
}

// DockerClusterOptions has options for setting up the docker cluster
type DockerClusterOptions struct {
	testcluster.ClusterOptions
	CAKey       *ecdsa.PrivateKey
	NetworkName string
	ImageRepo   string
	ImageTag    string
	CA          *testcluster.CA
	VaultBinary string
	Args        []string
	StartProbe  func(*api.Client) error
	Storage     testcluster.ClusterStorage
}

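// exampleOptions is a hedged sketch, unused by the package, of wiring the
// options above: a custom StartProbe that waits for the health endpoint
// rather than the default seal-status check, plus extra server arguments.
// All concrete values shown are assumptions.
func exampleOptions() *DockerClusterOptions {
	return &DockerClusterOptions{
		ImageRepo: "hashicorp/vault", // assumed image
		ImageTag:  "latest",
		Args:      []string{"-log-level=debug"},
		StartProbe: func(c *api.Client) error {
			_, err := c.Sys().Health()
			return err
		},
	}
}
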
func ensureLeaderMatches(ctx context.Context, client *api.Client, ready func(response *api.LeaderResponse) error) error {
	var leader *api.LeaderResponse
	var err error
	for ctx.Err() == nil {
		leader, err = client.Sys().Leader()
		switch {
		case err != nil:
		case leader == nil:
			err = fmt.Errorf("nil response to leader check")
		default:
			err = ready(leader)
			if err == nil {
				return nil
			}
		}
		time.Sleep(500 * time.Millisecond)
	}
	return fmt.Errorf("error checking leader: %v", err)
}

const DefaultNumCores = 3

// setupDockerCluster creates a managed docker container for each node of the
// Vault cluster and, unless SkipInit is set, initializes, unseals, and joins
// the nodes.
func (dc *DockerCluster) setupDockerCluster(ctx context.Context, opts *DockerClusterOptions) error {
	if opts.TmpDir != "" {
		if _, err := os.Stat(opts.TmpDir); os.IsNotExist(err) {
			if err := os.MkdirAll(opts.TmpDir, 0o700); err != nil {
				return err
			}
		}
		dc.tmpDir = opts.TmpDir
	} else {
		tempDir, err := os.MkdirTemp("", "vault-test-cluster-")
		if err != nil {
			return err
		}
		dc.tmpDir = tempDir
	}
	caDir := filepath.Join(dc.tmpDir, "ca")
	if err := os.MkdirAll(caDir, 0o755); err != nil {
		return err
	}

	var numCores int
	if opts.NumCores == 0 {
		numCores = DefaultNumCores
	} else {
		numCores = opts.NumCores
	}

	if dc.CA == nil {
		if err := dc.setupCA(opts); err != nil {
			return err
		}
	}
	dc.RootCAs = x509.NewCertPool()
	dc.RootCAs.AddCert(dc.CA.CACert)

	if dc.storage != nil {
		if err := dc.storage.Start(ctx, &opts.ClusterOptions); err != nil {
			return err
		}
	}

	for i := 0; i < numCores; i++ {
		if err := dc.addNode(ctx, opts); err != nil {
			return err
		}
		if opts.SkipInit {
			continue
		}
		if i == 0 {
			if err := dc.setupNode0(ctx); err != nil {
				return err
			}
		} else {
			if err := dc.joinNode(ctx, i, 0); err != nil {
				return err
			}
		}
	}

	return nil
}

// AddNode adds a fresh node to an already-running cluster and joins it to
// the current raft leader.
func (dc *DockerCluster) AddNode(ctx context.Context, opts *DockerClusterOptions) error {
	leaderIdx, err := testcluster.LeaderNode(ctx, dc)
	if err != nil {
		return err
	}
	if err := dc.addNode(ctx, opts); err != nil {
		return err
	}

	return dc.joinNode(ctx, len(dc.ClusterNodes)-1, leaderIdx)
}

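// exampleGrowCluster is a hedged sketch, unused by the package: start with
// the default three nodes, then add a fourth voter using the same options
// that built the original cluster.
func exampleGrowCluster(ctx context.Context, dc *DockerCluster, opts *DockerClusterOptions) error {
	before := len(dc.ClusterNodes)
	if err := dc.AddNode(ctx, opts); err != nil {
		return err
	}
	dc.Logger.Trace("grew cluster", "from", before, "to", len(dc.ClusterNodes))
	return nil
}
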
func (dc *DockerCluster) addNode(ctx context.Context, opts *DockerClusterOptions) error {
	tag, err := dc.setupImage(ctx, opts)
	if err != nil {
		return err
	}
	i := len(dc.ClusterNodes)
	nodeID := fmt.Sprintf("core-%d", i)
	node := &DockerClusterNode{
		DockerAPI: dc.DockerAPI,
		NodeID:    nodeID,
		Cluster:   dc,
		WorkDir:   filepath.Join(dc.tmpDir, nodeID),
		Logger:    dc.Logger.Named(nodeID),
		ImageRepo: opts.ImageRepo,
		ImageTag:  tag,
	}
	dc.ClusterNodes = append(dc.ClusterNodes, node)
	if err := os.MkdirAll(node.WorkDir, 0o755); err != nil {
		return err
	}
	if err := node.Start(ctx, opts); err != nil {
		return err
	}
	return nil
}

func (dc *DockerCluster) joinNode(ctx context.Context, nodeIdx int, leaderIdx int) error {
	if dc.storage != nil && dc.storage.Type() != "raft" {
		// Storage is not raft so nothing to do but unseal.
		return testcluster.UnsealNode(ctx, dc, nodeIdx)
	}

	if nodeIdx >= len(dc.ClusterNodes) {
		return fmt.Errorf("invalid node %d", nodeIdx)
	}
	leader := dc.ClusterNodes[leaderIdx]
	node := dc.ClusterNodes[nodeIdx]
	client := node.APIClient()

	var resp *api.RaftJoinResponse
	resp, err := client.Sys().RaftJoinWithContext(ctx, &api.RaftJoinRequest{
		// When running locally on a bridge network, the containers must use their
		// actual (private) IP to talk to one another. Our code must instead use
		// the portmapped address since we're not on their network in that case.
		LeaderAPIAddr:    leader.RealAPIAddr,
		LeaderCACert:     string(dc.CACertPEM),
		LeaderClientCert: string(node.ServerCertPEM),
		LeaderClientKey:  string(node.ServerKeyPEM),
	})
	if err != nil {
		return fmt.Errorf("failed to join cluster: %w", err)
	}
	if resp == nil || !resp.Joined {
		return fmt.Errorf("nil or unsuccessful response from raft join request: %v", resp)
	}

	return testcluster.UnsealNode(ctx, dc, nodeIdx)
}

func (dc *DockerCluster) setupImage(ctx context.Context, opts *DockerClusterOptions) (string, error) {
	if opts == nil {
		opts = &DockerClusterOptions{}
	}
	sourceTag := opts.ImageTag
	if sourceTag == "" {
		sourceTag = "latest"
	}

	if opts.VaultBinary == "" {
		return sourceTag, nil
	}

	suffix := "testing"
	if sha := os.Getenv("COMMIT_SHA"); sha != "" {
		suffix = sha
	}
	tag := sourceTag + "-" + suffix
	if _, ok := dc.builtTags[tag]; ok {
		return tag, nil
	}

	f, err := os.Open(opts.VaultBinary)
	if err != nil {
		return "", err
	}
	defer f.Close()
	data, err := io.ReadAll(f)
	if err != nil {
		return "", err
	}
	bCtx := dockhelper.NewBuildContext()
	bCtx["vault"] = &dockhelper.FileContents{
		Data: data,
		Mode: 0o755,
	}

	containerFile := fmt.Sprintf(`
FROM %s:%s
COPY vault /bin/vault
`, opts.ImageRepo, sourceTag)

	_, err = dockhelper.BuildImage(ctx, dc.DockerAPI, containerFile, bCtx,
		dockhelper.BuildRemove(true), dockhelper.BuildForceRemove(true),
		dockhelper.BuildPullParent(true),
		dockhelper.BuildTags([]string{opts.ImageRepo + ":" + tag}))
	if err != nil {
		return "", err
	}
	dc.builtTags[tag] = struct{}{}
	return tag, nil
}

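// exampleLocalBinary is a hedged sketch, unused by the package, of running a
// locally built vault binary: setupImage layers the binary over the base
// image and tags the result "<tag>-testing" (or "<tag>-<COMMIT_SHA>"),
// building it at most once per cluster. The binary path is an assumption.
func exampleLocalBinary(t *testing.T) *DockerCluster {
	return NewTestDockerCluster(t, &DockerClusterOptions{
		ImageRepo:   "hashicorp/vault", // assumed base image
		ImageTag:    "latest",
		VaultBinary: "/tmp/vault", // assumed path to a linux build of vault
	})
}
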
/* Notes on testing the non-bridge network case:
- you need the test itself to be running in a container so that it can use
  the network; create the network using
    docker network create testvault
- this means that you need to mount the docker socket in that test container,
  but on macOS mounting /var/run/docker.sock directly doesn't work; as a
  workaround, on the host run
    sudo ln -s "$HOME/Library/Containers/com.docker.docker/Data/docker.raw.sock" /var/run/docker.sock.raw
- run the test container like
    docker run --rm -it --network testvault \
      -v /var/run/docker.sock.raw:/var/run/docker.sock \
      -v $(pwd):/home/circleci/go/src/github.com/hashicorp/vault/ \
      -w /home/circleci/go/src/github.com/hashicorp/vault/ \
      "docker.mirror.hashicorp.services/cimg/go:1.19.2" /bin/bash
- in the container you may need to chown/chmod /var/run/docker.sock; use
  `docker ps` to test if it's working
*/