| // Copyright (c) HashiCorp, Inc. |
| // SPDX-License-Identifier: MPL-2.0 |
| |
| package kubernetes |
| |
| import ( |
| "fmt" |
| "os" |
| "strconv" |
| "sync" |
| |
| "github.com/hashicorp/go-hclog" |
| sr "github.com/hashicorp/vault/serviceregistration" |
| "github.com/hashicorp/vault/serviceregistration/kubernetes/client" |
| ) |
| |
| const ( |
| // Labels are placed in a pod's metadata. |
| labelVaultVersion = "vault-version" |
| labelActive = "vault-active" |
| labelSealed = "vault-sealed" |
| labelPerfStandby = "vault-perf-standby" |
| labelInitialized = "vault-initialized" |
| |
| // This is the path to where these labels are applied. |
| pathToLabels = "/metadata/labels/" |
| ) |
| |
| func NewServiceRegistration(config map[string]string, logger hclog.Logger, state sr.State) (sr.ServiceRegistration, error) { |
| namespace, err := getRequiredField(logger, config, client.EnvVarKubernetesNamespace, "namespace") |
| if err != nil { |
| return nil, err |
| } |
| podName, err := getRequiredField(logger, config, client.EnvVarKubernetesPodName, "pod_name") |
| if err != nil { |
| return nil, err |
| } |
| |
| c, err := client.New(logger) |
| if err != nil { |
| return nil, err |
| } |
| |
| // The Vault version must be sanitized because it can contain special |
| // characters like "+" which aren't acceptable by the Kube API. |
| state.VaultVersion = client.Sanitize(state.VaultVersion) |
| return &serviceRegistration{ |
| logger: logger, |
| namespace: namespace, |
| podName: podName, |
| retryHandler: &retryHandler{ |
| logger: logger, |
| namespace: namespace, |
| podName: podName, |
| initialState: state, |
| patchesToRetry: make(map[string]*client.Patch), |
| client: c, |
| }, |
| }, nil |
| } |
| |
| type serviceRegistration struct { |
| logger hclog.Logger |
| namespace, podName string |
| retryHandler *retryHandler |
| } |
| |
| func (r *serviceRegistration) Run(shutdownCh <-chan struct{}, wait *sync.WaitGroup, _ string) error { |
| r.retryHandler.Run(shutdownCh, wait) |
| return nil |
| } |
| |
| func (r *serviceRegistration) NotifyActiveStateChange(isActive bool) error { |
| r.retryHandler.Notify(&client.Patch{ |
| Operation: client.Replace, |
| Path: pathToLabels + labelActive, |
| Value: strconv.FormatBool(isActive), |
| }) |
| return nil |
| } |
| |
| func (r *serviceRegistration) NotifySealedStateChange(isSealed bool) error { |
| r.retryHandler.Notify(&client.Patch{ |
| Operation: client.Replace, |
| Path: pathToLabels + labelSealed, |
| Value: strconv.FormatBool(isSealed), |
| }) |
| return nil |
| } |
| |
| func (r *serviceRegistration) NotifyPerformanceStandbyStateChange(isStandby bool) error { |
| r.retryHandler.Notify(&client.Patch{ |
| Operation: client.Replace, |
| Path: pathToLabels + labelPerfStandby, |
| Value: strconv.FormatBool(isStandby), |
| }) |
| return nil |
| } |
| |
| func (r *serviceRegistration) NotifyInitializedStateChange(isInitialized bool) error { |
| r.retryHandler.Notify(&client.Patch{ |
| Operation: client.Replace, |
| Path: pathToLabels + labelInitialized, |
| Value: strconv.FormatBool(isInitialized), |
| }) |
| return nil |
| } |
| |
| func getRequiredField(logger hclog.Logger, config map[string]string, envVar, configParam string) (string, error) { |
| value := "" |
| switch { |
| case os.Getenv(envVar) != "": |
| value = os.Getenv(envVar) |
| case config[configParam] != "": |
| value = config[configParam] |
| default: |
| return "", fmt.Errorf(`%s must be provided via %q or the %q config parameter`, configParam, envVar, configParam) |
| } |
| if logger.IsDebug() { |
| logger.Debug(fmt.Sprintf("%q: %q", configParam, value)) |
| } |
| return value, nil |
| } |