| // Copyright (c) HashiCorp, Inc. |
| // SPDX-License-Identifier: MPL-2.0 |
| |
| package etcd |
| |
| import ( |
| "context" |
| "errors" |
| "fmt" |
| "os" |
| "path" |
| "strconv" |
| "strings" |
| "sync" |
| "time" |
| |
| "github.com/armon/go-metrics" |
| log "github.com/hashicorp/go-hclog" |
| "github.com/hashicorp/go-secure-stdlib/parseutil" |
| "github.com/hashicorp/go-secure-stdlib/strutil" |
| "github.com/hashicorp/vault/sdk/physical" |
| "go.etcd.io/etcd/client/pkg/v3/transport" |
| clientv3 "go.etcd.io/etcd/client/v3" |
| "go.etcd.io/etcd/client/v3/concurrency" |
| ) |
| |
| // EtcdBackend is a physical backend that stores data at specific |
| // prefix within etcd. It is used for most production situations as |
| // it allows Vault to run on multiple machines in a highly-available manner. |
| type EtcdBackend struct { |
| logger log.Logger |
| path string |
| haEnabled bool |
| lockTimeout time.Duration |
| requestTimeout time.Duration |
| |
| permitPool *physical.PermitPool |
| |
| etcd *clientv3.Client |
| } |
| |
| // Verify EtcdBackend satisfies the correct interfaces |
| var ( |
| _ physical.Backend = (*EtcdBackend)(nil) |
| _ physical.HABackend = (*EtcdBackend)(nil) |
| _ physical.Lock = (*EtcdLock)(nil) |
| ) |
| |
| // newEtcd3Backend constructs a etcd3 backend. |
| func newEtcd3Backend(conf map[string]string, logger log.Logger) (physical.Backend, error) { |
| // Get the etcd path form the configuration. |
| path, ok := conf["path"] |
| if !ok { |
| path = "/vault" |
| } |
| |
| // Ensure path is prefixed. |
| if !strings.HasPrefix(path, "/") { |
| path = "/" + path |
| } |
| |
| endpoints, err := getEtcdEndpoints(conf) |
| if err != nil { |
| return nil, err |
| } |
| |
| cfg := clientv3.Config{ |
| Endpoints: endpoints, |
| } |
| |
| haEnabled := os.Getenv("ETCD_HA_ENABLED") |
| if haEnabled == "" { |
| haEnabled = conf["ha_enabled"] |
| } |
| if haEnabled == "" { |
| haEnabled = "false" |
| } |
| haEnabledBool, err := strconv.ParseBool(haEnabled) |
| if err != nil { |
| return nil, fmt.Errorf("value [%v] of 'ha_enabled' could not be understood", haEnabled) |
| } |
| |
| cert, hasCert := conf["tls_cert_file"] |
| key, hasKey := conf["tls_key_file"] |
| ca, hasCa := conf["tls_ca_file"] |
| if (hasCert && hasKey) || hasCa { |
| tls := transport.TLSInfo{ |
| TrustedCAFile: ca, |
| CertFile: cert, |
| KeyFile: key, |
| } |
| |
| tlscfg, err := tls.ClientConfig() |
| if err != nil { |
| return nil, err |
| } |
| cfg.TLS = tlscfg |
| } |
| |
| // Set credentials. |
| username := os.Getenv("ETCD_USERNAME") |
| if username == "" { |
| username, _ = conf["username"] |
| } |
| |
| password := os.Getenv("ETCD_PASSWORD") |
| if password == "" { |
| password, _ = conf["password"] |
| } |
| |
| if username != "" && password != "" { |
| cfg.Username = username |
| cfg.Password = password |
| } |
| |
| if maxReceive, ok := conf["max_receive_size"]; ok { |
| // grpc converts this to uint32 internally, so parse as that to avoid passing invalid values |
| val, err := strconv.ParseUint(maxReceive, 10, 32) |
| if err != nil { |
| return nil, fmt.Errorf("value of 'max_receive_size' (%v) could not be understood: %w", maxReceive, err) |
| } |
| cfg.MaxCallRecvMsgSize = int(val) |
| } |
| |
| etcd, err := clientv3.New(cfg) |
| if err != nil { |
| return nil, err |
| } |
| |
| sReqTimeout := conf["request_timeout"] |
| if sReqTimeout == "" { |
| // etcd3 default request timeout is set to 5s. It should be long enough |
| // for most cases, even with internal retry. |
| sReqTimeout = "5s" |
| } |
| reqTimeout, err := parseutil.ParseDurationSecond(sReqTimeout) |
| if err != nil { |
| return nil, fmt.Errorf("value [%v] of 'request_timeout' could not be understood: %w", sReqTimeout, err) |
| } |
| |
| ssync, ok := conf["sync"] |
| if !ok { |
| ssync = "true" |
| } |
| sync, err := strconv.ParseBool(ssync) |
| if err != nil { |
| return nil, fmt.Errorf("value of 'sync' (%v) could not be understood: %w", ssync, err) |
| } |
| |
| if sync { |
| ctx, cancel := context.WithTimeout(context.Background(), reqTimeout) |
| err := etcd.Sync(ctx) |
| cancel() |
| if err != nil { |
| return nil, err |
| } |
| } |
| |
| sLock := conf["lock_timeout"] |
| if sLock == "" { |
| // etcd3 default lease duration is 60s. set to 15s for faster recovery. |
| sLock = "15s" |
| } |
| lock, err := parseutil.ParseDurationSecond(sLock) |
| if err != nil { |
| return nil, fmt.Errorf("value [%v] of 'lock_timeout' could not be understood: %w", sLock, err) |
| } |
| |
| return &EtcdBackend{ |
| path: path, |
| etcd: etcd, |
| permitPool: physical.NewPermitPool(physical.DefaultParallelOperations), |
| logger: logger, |
| haEnabled: haEnabledBool, |
| lockTimeout: lock, |
| requestTimeout: reqTimeout, |
| }, nil |
| } |
| |
| func (c *EtcdBackend) Put(ctx context.Context, entry *physical.Entry) error { |
| defer metrics.MeasureSince([]string{"etcd", "put"}, time.Now()) |
| |
| c.permitPool.Acquire() |
| defer c.permitPool.Release() |
| |
| ctx, cancel := context.WithTimeout(context.Background(), c.requestTimeout) |
| defer cancel() |
| _, err := c.etcd.Put(ctx, path.Join(c.path, entry.Key), string(entry.Value)) |
| return err |
| } |
| |
| func (c *EtcdBackend) Get(ctx context.Context, key string) (*physical.Entry, error) { |
| defer metrics.MeasureSince([]string{"etcd", "get"}, time.Now()) |
| |
| c.permitPool.Acquire() |
| defer c.permitPool.Release() |
| |
| ctx, cancel := context.WithTimeout(context.Background(), c.requestTimeout) |
| defer cancel() |
| resp, err := c.etcd.Get(ctx, path.Join(c.path, key)) |
| if err != nil { |
| return nil, err |
| } |
| |
| if len(resp.Kvs) == 0 { |
| return nil, nil |
| } |
| if len(resp.Kvs) > 1 { |
| return nil, errors.New("unexpected number of keys from a get request") |
| } |
| return &physical.Entry{ |
| Key: key, |
| Value: resp.Kvs[0].Value, |
| }, nil |
| } |
| |
| func (c *EtcdBackend) Delete(ctx context.Context, key string) error { |
| defer metrics.MeasureSince([]string{"etcd", "delete"}, time.Now()) |
| |
| c.permitPool.Acquire() |
| defer c.permitPool.Release() |
| |
| ctx, cancel := context.WithTimeout(context.Background(), c.requestTimeout) |
| defer cancel() |
| _, err := c.etcd.Delete(ctx, path.Join(c.path, key)) |
| if err != nil { |
| return err |
| } |
| return nil |
| } |
| |
| func (c *EtcdBackend) List(ctx context.Context, prefix string) ([]string, error) { |
| defer metrics.MeasureSince([]string{"etcd", "list"}, time.Now()) |
| |
| c.permitPool.Acquire() |
| defer c.permitPool.Release() |
| |
| ctx, cancel := context.WithTimeout(context.Background(), c.requestTimeout) |
| defer cancel() |
| prefix = path.Join(c.path, prefix) + "/" |
| resp, err := c.etcd.Get(ctx, prefix, clientv3.WithPrefix(), clientv3.WithKeysOnly()) |
| if err != nil { |
| return nil, err |
| } |
| |
| keys := []string{} |
| for _, kv := range resp.Kvs { |
| key := strings.TrimPrefix(string(kv.Key), prefix) |
| key = strings.TrimPrefix(key, "/") |
| |
| if len(key) == 0 { |
| continue |
| } |
| |
| if i := strings.Index(key, "/"); i == -1 { |
| keys = append(keys, key) |
| } else if i != -1 { |
| keys = strutil.AppendIfMissing(keys, key[:i+1]) |
| } |
| } |
| return keys, nil |
| } |
| |
| func (e *EtcdBackend) HAEnabled() bool { |
| return e.haEnabled |
| } |
| |
| // EtcdLock implements a lock using and etcd backend. |
| type EtcdLock struct { |
| lock sync.Mutex |
| held bool |
| timeout time.Duration |
| requestTimeout time.Duration |
| |
| etcdSession *concurrency.Session |
| etcdMu *concurrency.Mutex |
| |
| prefix string |
| value string |
| |
| etcd *clientv3.Client |
| } |
| |
| // Lock is used for mutual exclusion based on the given key. |
| func (c *EtcdBackend) LockWith(key, value string) (physical.Lock, error) { |
| p := path.Join(c.path, key) |
| return &EtcdLock{ |
| prefix: p, |
| value: value, |
| etcd: c.etcd, |
| timeout: c.lockTimeout, |
| requestTimeout: c.requestTimeout, |
| }, nil |
| } |
| |
| func (c *EtcdLock) Lock(stopCh <-chan struct{}) (<-chan struct{}, error) { |
| c.lock.Lock() |
| defer c.lock.Unlock() |
| |
| if c.etcdMu == nil { |
| if err := c.initMu(); err != nil { |
| return nil, err |
| } |
| } |
| |
| if c.held { |
| return nil, EtcdLockHeldError |
| } |
| |
| select { |
| case _, ok := <-c.etcdSession.Done(): |
| if !ok { |
| // The session's done channel is closed, so the session is over, |
| // and we need a new lock with a new session. |
| if err := c.initMu(); err != nil { |
| return nil, err |
| } |
| } |
| default: |
| } |
| |
| ctx, cancel := context.WithCancel(context.Background()) |
| go func() { |
| <-stopCh |
| cancel() |
| }() |
| if err := c.etcdMu.Lock(ctx); err != nil { |
| if err == context.Canceled { |
| return nil, nil |
| } |
| return nil, err |
| } |
| |
| pctx, cancel := context.WithTimeout(context.Background(), c.requestTimeout) |
| defer cancel() |
| if _, err := c.etcd.Put(pctx, c.etcdMu.Key(), c.value, clientv3.WithLease(c.etcdSession.Lease())); err != nil { |
| return nil, err |
| } |
| |
| c.held = true |
| |
| return c.etcdSession.Done(), nil |
| } |
| |
| func (c *EtcdLock) Unlock() error { |
| c.lock.Lock() |
| defer c.lock.Unlock() |
| |
| if !c.held { |
| return EtcdLockNotHeldError |
| } |
| |
| ctx, cancel := context.WithTimeout(context.Background(), c.requestTimeout) |
| defer cancel() |
| return c.etcdMu.Unlock(ctx) |
| } |
| |
| func (c *EtcdLock) Value() (bool, string, error) { |
| ctx, cancel := context.WithTimeout(context.Background(), c.requestTimeout) |
| defer cancel() |
| |
| resp, err := c.etcd.Get(ctx, |
| c.prefix, clientv3.WithPrefix(), |
| clientv3.WithSort(clientv3.SortByCreateRevision, clientv3.SortAscend)) |
| if err != nil { |
| return false, "", err |
| } |
| if len(resp.Kvs) == 0 { |
| return false, "", nil |
| } |
| |
| return true, string(resp.Kvs[0].Value), nil |
| } |
| |
| func (c *EtcdLock) initMu() error { |
| session, err := concurrency.NewSession(c.etcd, concurrency.WithTTL(int(c.timeout.Seconds()))) |
| if err != nil { |
| return err |
| } |
| c.etcdSession = session |
| c.etcdMu = concurrency.NewMutex(session, c.prefix) |
| return nil |
| } |