| // Copyright (c) HashiCorp, Inc. |
| // SPDX-License-Identifier: MPL-2.0 |
| |
| package spanner |
| |
| import ( |
| "context" |
| "fmt" |
| "sync" |
| "time" |
| |
| "cloud.google.com/go/spanner" |
| metrics "github.com/armon/go-metrics" |
| uuid "github.com/hashicorp/go-uuid" |
| "github.com/hashicorp/vault/sdk/physical" |
| "github.com/pkg/errors" |
| "google.golang.org/grpc/codes" |
| ) |
| |
| // Verify Backend satisfies the correct interfaces |
| var ( |
| _ physical.HABackend = (*Backend)(nil) |
| _ physical.Lock = (*Lock)(nil) |
| ) |
| |
// Default timings for the HA lock. LockWith copies each of these into the
// Lock it returns.
const (
	// LockRenewInterval is how often a held lock is rewritten to refresh it.
	LockRenewInterval = 5 * time.Second

	// LockRetryInterval is the pause between successive attempts to acquire
	// the lock.
	LockRetryInterval = 5 * time.Second

	// LockTTL is the default age after which a lock record written by another
	// identity is treated as expired and may be taken over.
	LockTTL = 15 * time.Second

	// LockWatchRetryInterval is the pause between successive checks of the
	// lock record by the watcher.
	LockWatchRetryInterval = 5 * time.Second

	// LockWatchRetryMax bounds how many failed watch reads are tolerated
	// before the watcher signals that leadership is lost.
	LockWatchRetryMax = 5
)
| |
// Metric keys used to time the public lock operations.
var (
	// metricLockUnlock is the key under which Unlock calls are timed.
	metricLockUnlock = []string{"spanner", "lock", "unlock"}

	// metricLockLock is the key under which Lock calls are timed.
	metricLockLock = []string{"spanner", "lock", "lock"}

	// metricLockValue is the key under which Value calls are timed.
	metricLockValue = []string{"spanner", "lock", "value"}
)
| |
| // Lock is the HA lock. |
| type Lock struct { |
| // backend is the underlying physical backend. |
| backend *Backend |
| |
| // key is the name of the key. value is the value of the key. |
| key, value string |
| |
| // held is a boolean indicating if the lock is currently held. |
| held bool |
| |
| // identity is the internal identity of this key (unique to this server |
| // instance). |
| identity string |
| |
| // lock is an internal lock |
| lock sync.Mutex |
| |
| // stopCh is the channel that stops all operations. It may be closed in the |
| // event of a leader loss or graceful shutdown. stopped is a boolean |
| // indicating if we are stopped - it exists to prevent double closing the |
| // channel. stopLock is a mutex around the locks. |
| stopCh chan struct{} |
| stopped bool |
| stopLock sync.Mutex |
| |
| // Allow modifying the Lock durations for ease of unit testing. |
| renewInterval time.Duration |
| retryInterval time.Duration |
| ttl time.Duration |
| watchRetryInterval time.Duration |
| watchRetryMax int |
| } |
| |
// LockRecord mirrors one row of the HA table. Rows are decoded into it with
// Row.ToStruct and written from it with InsertOrUpdateStruct, so the field
// names correspond to the column names.
type LockRecord struct {
	// Key is the lock's row key.
	Key string

	// Value is the payload stored with the lock.
	Value string

	// Identity is the identity of the Lock instance that wrote the row.
	Identity string

	// Timestamp is when the row was last written; it is compared against the
	// TTL to decide whether the lock has expired.
	Timestamp time.Time
}
| |
| // HAEnabled implements HABackend and indicates that this backend supports high |
| // availability. |
| func (b *Backend) HAEnabled() bool { |
| return b.haEnabled |
| } |
| |
| // LockWith acquires a mutual exclusion based on the given key. |
| func (b *Backend) LockWith(key, value string) (physical.Lock, error) { |
| identity, err := uuid.GenerateUUID() |
| if err != nil { |
| return nil, fmt.Errorf("lock with: %w", err) |
| } |
| return &Lock{ |
| backend: b, |
| key: key, |
| value: value, |
| identity: identity, |
| stopped: true, |
| |
| renewInterval: LockRenewInterval, |
| retryInterval: LockRetryInterval, |
| ttl: LockTTL, |
| watchRetryInterval: LockWatchRetryInterval, |
| watchRetryMax: LockWatchRetryMax, |
| }, nil |
| } |
| |
| // Lock acquires the given lock. The stopCh is optional. If closed, it |
| // interrupts the lock acquisition attempt. The returned channel should be |
| // closed when leadership is lost. |
| func (l *Lock) Lock(stopCh <-chan struct{}) (<-chan struct{}, error) { |
| defer metrics.MeasureSince(metricLockLock, time.Now()) |
| |
| l.lock.Lock() |
| defer l.lock.Unlock() |
| if l.held { |
| return nil, errors.New("lock already held") |
| } |
| |
| // Attempt to lock - this function blocks until a lock is acquired or an error |
| // occurs. |
| acquired, err := l.attemptLock(stopCh) |
| if err != nil { |
| return nil, fmt.Errorf("lock: %w", err) |
| } |
| if !acquired { |
| return nil, nil |
| } |
| |
| // We have the lock now |
| l.held = true |
| |
| // Build the locks |
| l.stopLock.Lock() |
| l.stopCh = make(chan struct{}) |
| l.stopped = false |
| l.stopLock.Unlock() |
| |
| // Periodically renew and watch the lock |
| go l.renewLock() |
| go l.watchLock() |
| |
| return l.stopCh, nil |
| } |
| |
| // Unlock releases the lock. |
| func (l *Lock) Unlock() error { |
| defer metrics.MeasureSince(metricLockUnlock, time.Now()) |
| |
| l.lock.Lock() |
| defer l.lock.Unlock() |
| if !l.held { |
| return nil |
| } |
| |
| // Stop any existing locking or renewal attempts |
| l.stopLock.Lock() |
| if !l.stopped { |
| l.stopped = true |
| close(l.stopCh) |
| } |
| l.stopLock.Unlock() |
| |
| // Delete |
| ctx := context.Background() |
| if _, err := l.backend.haClient.ReadWriteTransaction(ctx, func(ctx context.Context, txn *spanner.ReadWriteTransaction) error { |
| row, err := txn.ReadRow(ctx, l.backend.haTable, spanner.Key{l.key}, []string{"Identity"}) |
| if err != nil { |
| if spanner.ErrCode(err) != codes.NotFound { |
| return nil |
| } |
| return err |
| } |
| |
| var r LockRecord |
| if derr := row.ToStruct(&r); derr != nil { |
| return fmt.Errorf("failed to decode to struct: %w", derr) |
| } |
| |
| // If the identity is different, that means that between the time that after |
| // we stopped acquisition, the TTL expired and someone else grabbed the |
| // lock. We do not want to delete a lock that is not our own. |
| if r.Identity != l.identity { |
| return nil |
| } |
| |
| return txn.BufferWrite([]*spanner.Mutation{ |
| spanner.Delete(l.backend.haTable, spanner.Key{l.key}), |
| }) |
| }); err != nil { |
| return fmt.Errorf("unlock: %w", err) |
| } |
| |
| // We are no longer holding the lock |
| l.held = false |
| |
| return nil |
| } |
| |
| // Value returns the value of the lock and if it is held. |
| func (l *Lock) Value() (bool, string, error) { |
| defer metrics.MeasureSince(metricLockValue, time.Now()) |
| |
| r, err := l.get(context.Background()) |
| if err != nil { |
| return false, "", err |
| } |
| if r == nil { |
| return false, "", err |
| } |
| return true, string(r.Value), nil |
| } |
| |
| // attemptLock attempts to acquire a lock. If the given channel is closed, the |
| // acquisition attempt stops. This function returns when a lock is acquired or |
| // an error occurs. |
| func (l *Lock) attemptLock(stopCh <-chan struct{}) (bool, error) { |
| ticker := time.NewTicker(l.retryInterval) |
| defer ticker.Stop() |
| |
| for { |
| select { |
| case <-ticker.C: |
| acquired, err := l.writeLock() |
| if err != nil { |
| return false, fmt.Errorf("attempt lock: %w", err) |
| } |
| if !acquired { |
| continue |
| } |
| |
| return true, nil |
| case <-stopCh: |
| return false, nil |
| } |
| } |
| } |
| |
| // renewLock renews the given lock until the channel is closed. |
| func (l *Lock) renewLock() { |
| ticker := time.NewTicker(l.renewInterval) |
| defer ticker.Stop() |
| |
| for { |
| select { |
| case <-ticker.C: |
| l.writeLock() |
| case <-l.stopCh: |
| return |
| } |
| } |
| } |
| |
| // watchLock checks whether the lock has changed in the table and closes the |
| // leader channel accordingly. If an error occurs during the check, watchLock |
| // will retry the operation and then close the leader channel if it can't |
| // succeed after retries. |
| func (l *Lock) watchLock() { |
| retries := 0 |
| ticker := time.NewTicker(l.watchRetryInterval) |
| |
| OUTER: |
| for { |
| // Check if the channel is already closed |
| select { |
| case <-l.stopCh: |
| break OUTER |
| default: |
| } |
| |
| // Check if we've exceeded retries |
| if retries >= l.watchRetryMax-1 { |
| break OUTER |
| } |
| |
| // Wait for the timer |
| select { |
| case <-ticker.C: |
| case <-l.stopCh: |
| break OUTER |
| } |
| |
| // Attempt to read the key |
| r, err := l.get(context.Background()) |
| if err != nil { |
| retries++ |
| continue |
| } |
| |
| // Verify the identity is the same |
| if r == nil || r.Identity != l.identity { |
| break OUTER |
| } |
| } |
| |
| l.stopLock.Lock() |
| defer l.stopLock.Unlock() |
| if !l.stopped { |
| l.stopped = true |
| close(l.stopCh) |
| } |
| } |
| |
| // writeLock writes the given lock using the following algorithm: |
| // |
| // - lock does not exist |
| // - write the lock |
| // |
| // - lock exists |
| // - if key is empty or identity is the same or timestamp exceeds TTL |
| // - update the lock to self |
| func (l *Lock) writeLock() (bool, error) { |
| // Keep track of whether the lock was written |
| lockWritten := false |
| |
| // Create a transaction to read and the update (maybe) |
| ctx, cancel := context.WithCancel(context.Background()) |
| defer cancel() |
| |
| // The transaction will be retried, and it could sit in a queue behind, say, |
| // the delete operation. To stop the transaction, we close the context when |
| // the associated stopCh is received. |
| go func() { |
| select { |
| case <-l.stopCh: |
| cancel() |
| case <-ctx.Done(): |
| } |
| }() |
| |
| _, err := l.backend.haClient.ReadWriteTransaction(ctx, func(ctx context.Context, txn *spanner.ReadWriteTransaction) error { |
| row, err := txn.ReadRow(ctx, l.backend.haTable, spanner.Key{l.key}, []string{"Key", "Identity", "Timestamp"}) |
| if err != nil && spanner.ErrCode(err) != codes.NotFound { |
| return err |
| } |
| |
| // If there was a record, verify that the record is still trustable. |
| if row != nil { |
| var r LockRecord |
| if derr := row.ToStruct(&r); derr != nil { |
| return fmt.Errorf("failed to decode to struct: %w", derr) |
| } |
| |
| // If the key is empty or the identity is ours or the ttl expired, we can |
| // write. Otherwise, return now because we cannot. |
| if r.Key != "" && r.Identity != l.identity && time.Now().UTC().Sub(r.Timestamp) < l.ttl { |
| return nil |
| } |
| } |
| |
| m, err := spanner.InsertOrUpdateStruct(l.backend.haTable, &LockRecord{ |
| Key: l.key, |
| Value: l.value, |
| Identity: l.identity, |
| Timestamp: time.Now().UTC(), |
| }) |
| if err != nil { |
| return fmt.Errorf("failed to generate struct: %w", err) |
| } |
| if err := txn.BufferWrite([]*spanner.Mutation{m}); err != nil { |
| return fmt.Errorf("failed to write: %w", err) |
| } |
| |
| // Mark that the lock was acquired |
| lockWritten = true |
| |
| return nil |
| }) |
| if err != nil { |
| return false, fmt.Errorf("write lock: %w", err) |
| } |
| |
| return lockWritten, nil |
| } |
| |
| // get retrieves the value for the lock. |
| func (l *Lock) get(ctx context.Context) (*LockRecord, error) { |
| // Read |
| row, err := l.backend.haClient.Single().ReadRow(ctx, l.backend.haTable, spanner.Key{l.key}, []string{"Key", "Value", "Timestamp", "Identity"}) |
| if spanner.ErrCode(err) == codes.NotFound { |
| return nil, nil |
| } |
| if err != nil { |
| return nil, fmt.Errorf("failed to read value for %q: %w", l.key, err) |
| } |
| |
| var r LockRecord |
| if err := row.ToStruct(&r); err != nil { |
| return nil, fmt.Errorf("failed to decode lock: %w", err) |
| } |
| return &r, nil |
| } |