| // Copyright (c) HashiCorp, Inc. |
| // SPDX-License-Identifier: MPL-2.0 |
| |
| package raft |
| |
| import ( |
| "bytes" |
| "context" |
| "crypto/ecdsa" |
| "crypto/elliptic" |
| "crypto/rand" |
| "crypto/tls" |
| "crypto/x509" |
| "crypto/x509/pkix" |
| "errors" |
| fmt "fmt" |
| "io" |
| "math/big" |
| mathrand "math/rand" |
| "net" |
| "net/url" |
| "sync" |
| "time" |
| |
| log "github.com/hashicorp/go-hclog" |
| uuid "github.com/hashicorp/go-uuid" |
| "github.com/hashicorp/raft" |
| "github.com/hashicorp/vault/sdk/helper/certutil" |
| "github.com/hashicorp/vault/sdk/helper/consts" |
| "github.com/hashicorp/vault/vault/cluster" |
| ) |
| |
// TLSKey is a single TLS keypair in the Keyring. The exported fields are the
// serialized form (see the json tags); parsedCert and parsedKey are in-memory
// caches derived from CertBytes and KeyParams.
type TLSKey struct {
	// ID is a unique identifier for this Key
	ID string `json:"id"`

	// KeyType defines the algorithm used to generate the private keys
	KeyType string `json:"key_type"`

	// AppliedIndex is the earliest known raft index that safely contains this
	// key.
	AppliedIndex uint64 `json:"applied_index"`

	// CertBytes is the marshaled (DER) certificate.
	CertBytes []byte `json:"cluster_cert"`

	// KeyParams is the marshaled private key.
	KeyParams *certutil.ClusterKeyParams `json:"cluster_key_params"`

	// CreatedTime is the time this key was generated. This value is useful in
	// determining when the next rotation should be.
	CreatedTime time.Time `json:"created_time"`

	// parsedCert and parsedKey are populated by raftLayer.setTLSKeyring from
	// CertBytes and KeyParams; they are unexported and therefore never
	// serialized.
	parsedCert *x509.Certificate
	parsedKey  *ecdsa.PrivateKey
}
| |
// TLSKeyring is the set of keys that raft uses for network communication.
// Only the active key (ActiveKeyID) is used to dial, but every key in the
// keyring can be used to accept connections.
type TLSKeyring struct {
	// Keys is the set of available key pairs
	Keys []*TLSKey `json:"keys"`

	// AppliedIndex is the earliest known raft index that safely contains the
	// latest key in the keyring.
	AppliedIndex uint64 `json:"applied_index"`

	// Term is an incrementing identifier value used to quickly determine if two
	// states of the keyring are different.
	Term uint64 `json:"term"`

	// ActiveKeyID is the key ID to track the active key in the keyring. Only
	// the active key is used for dialing.
	ActiveKeyID string `json:"active_key_id"`
}
| |
| // GetActive returns the active key. |
| func (k *TLSKeyring) GetActive() *TLSKey { |
| if k.ActiveKeyID == "" { |
| return nil |
| } |
| |
| for _, key := range k.Keys { |
| if key.ID == k.ActiveKeyID { |
| return key |
| } |
| } |
| return nil |
| } |
| |
| func GenerateTLSKey(reader io.Reader) (*TLSKey, error) { |
| key, err := ecdsa.GenerateKey(elliptic.P521(), reader) |
| if err != nil { |
| return nil, err |
| } |
| |
| host, err := uuid.GenerateUUID() |
| if err != nil { |
| return nil, err |
| } |
| host = fmt.Sprintf("raft-%s", host) |
| template := &x509.Certificate{ |
| Subject: pkix.Name{ |
| CommonName: host, |
| }, |
| DNSNames: []string{host}, |
| ExtKeyUsage: []x509.ExtKeyUsage{ |
| x509.ExtKeyUsageServerAuth, |
| x509.ExtKeyUsageClientAuth, |
| }, |
| KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageKeyAgreement | x509.KeyUsageCertSign, |
| SerialNumber: big.NewInt(mathrand.Int63()), |
| NotBefore: time.Now().Add(-30 * time.Second), |
| // 30 years ought to be enough for anybody |
| NotAfter: time.Now().Add(262980 * time.Hour), |
| BasicConstraintsValid: true, |
| IsCA: true, |
| } |
| |
| certBytes, err := x509.CreateCertificate(rand.Reader, template, template, key.Public(), key) |
| if err != nil { |
| return nil, fmt.Errorf("unable to generate local cluster certificate: %w", err) |
| } |
| |
| return &TLSKey{ |
| ID: host, |
| KeyType: certutil.PrivateKeyTypeP521, |
| CertBytes: certBytes, |
| KeyParams: &certutil.ClusterKeyParams{ |
| Type: certutil.PrivateKeyTypeP521, |
| X: key.PublicKey.X, |
| Y: key.PublicKey.Y, |
| D: key.D, |
| }, |
| CreatedTime: time.Now(), |
| }, nil |
| } |
| |
| var ( |
| // Make sure raftLayer satisfies the raft.StreamLayer interface |
| _ raft.StreamLayer = (*raftLayer)(nil) |
| |
| // Make sure raftLayer satisfies the cluster.Handler and cluster.Client |
| // interfaces |
| _ cluster.Handler = (*raftLayer)(nil) |
| _ cluster.Client = (*raftLayer)(nil) |
| ) |
| |
// raftLayer implements the raft.StreamLayer interface, so that we can use a
// single RPC layer for Raft and Vault. It also satisfies cluster.Handler and
// cluster.Client.
type raftLayer struct {
	// addr is the listener address returned by Addr
	addr net.Addr

	// connCh carries connections from Handoff to Accept
	connCh chan net.Conn

	// Tracks if we are closed. closed is guarded by closeLock; closeCh is
	// closed exactly once (by Close) to unblock Accept and pending handoffs.
	closed    bool
	closeCh   chan struct{}
	closeLock sync.Mutex

	logger log.Logger

	// NOTE(review): dialerFunc is not referenced in this file (Dial goes
	// through clusterListener.GetDialerFunc); confirm whether it is still
	// needed.
	dialerFunc func(string, time.Duration) (net.Conn, error)

	// TLS config. keyring is installed by setTLSKeyring; clusterListener
	// supplies the listen address and the outgoing dialer.
	keyring         *TLSKeyring
	clusterListener cluster.ClusterHook
}
| |
| // NewRaftLayer creates a new raftLayer object. It parses the TLS information |
| // from the network config. |
| func NewRaftLayer(logger log.Logger, raftTLSKeyring *TLSKeyring, clusterListener cluster.ClusterHook) (*raftLayer, error) { |
| clusterAddr := clusterListener.Addr() |
| if clusterAddr == nil { |
| return nil, errors.New("no raft addr found") |
| } |
| |
| { |
| // Test the advertised address to make sure it's not an unspecified IP |
| u := url.URL{ |
| Host: clusterAddr.String(), |
| } |
| ip := net.ParseIP(u.Hostname()) |
| if ip != nil && ip.IsUnspecified() { |
| return nil, fmt.Errorf("cannot use unspecified IP with raft storage: %s", clusterAddr.String()) |
| } |
| } |
| |
| layer := &raftLayer{ |
| addr: clusterAddr, |
| connCh: make(chan net.Conn), |
| closeCh: make(chan struct{}), |
| logger: logger, |
| clusterListener: clusterListener, |
| } |
| |
| if err := layer.setTLSKeyring(raftTLSKeyring); err != nil { |
| return nil, err |
| } |
| |
| return layer, nil |
| } |
| |
| func (l *raftLayer) setTLSKeyring(keyring *TLSKeyring) error { |
| // Fast path a noop update |
| if l.keyring != nil && l.keyring.Term == keyring.Term { |
| return nil |
| } |
| |
| for _, key := range keyring.Keys { |
| switch { |
| case key.KeyParams == nil: |
| return errors.New("no raft cluster key params found") |
| |
| case key.KeyParams.X == nil, key.KeyParams.Y == nil, key.KeyParams.D == nil: |
| return errors.New("failed to parse raft cluster key") |
| |
| case key.KeyParams.Type != certutil.PrivateKeyTypeP521: |
| return errors.New("failed to find valid raft cluster key type") |
| |
| case len(key.CertBytes) == 0: |
| return errors.New("no cluster cert found") |
| } |
| |
| parsedCert, err := x509.ParseCertificate(key.CertBytes) |
| if err != nil { |
| return fmt.Errorf("error parsing raft cluster certificate: %w", err) |
| } |
| |
| key.parsedCert = parsedCert |
| key.parsedKey = &ecdsa.PrivateKey{ |
| PublicKey: ecdsa.PublicKey{ |
| Curve: elliptic.P521(), |
| X: key.KeyParams.X, |
| Y: key.KeyParams.Y, |
| }, |
| D: key.KeyParams.D, |
| } |
| } |
| |
| if keyring.GetActive() == nil { |
| return errors.New("expected one active key to be present in the keyring") |
| } |
| |
| l.keyring = keyring |
| |
| return nil |
| } |
| |
| func (l *raftLayer) ServerName() string { |
| key := l.keyring.GetActive() |
| if key == nil { |
| return "" |
| } |
| |
| return key.parsedCert.Subject.CommonName |
| } |
| |
| func (l *raftLayer) CACert(ctx context.Context) *x509.Certificate { |
| key := l.keyring.GetActive() |
| if key == nil { |
| return nil |
| } |
| |
| return key.parsedCert |
| } |
| |
| func (l *raftLayer) ClientLookup(ctx context.Context, requestInfo *tls.CertificateRequestInfo) (*tls.Certificate, error) { |
| for _, subj := range requestInfo.AcceptableCAs { |
| for _, key := range l.keyring.Keys { |
| if bytes.Equal(subj, key.parsedCert.RawIssuer) { |
| localCert := make([]byte, len(key.CertBytes)) |
| copy(localCert, key.CertBytes) |
| |
| return &tls.Certificate{ |
| Certificate: [][]byte{localCert}, |
| PrivateKey: key.parsedKey, |
| Leaf: key.parsedCert, |
| }, nil |
| } |
| } |
| } |
| |
| return nil, nil |
| } |
| |
| func (l *raftLayer) ServerLookup(ctx context.Context, clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) { |
| if l.keyring == nil { |
| return nil, errors.New("got raft connection but no local cert") |
| } |
| |
| for _, key := range l.keyring.Keys { |
| if clientHello.ServerName == key.ID { |
| localCert := make([]byte, len(key.CertBytes)) |
| copy(localCert, key.CertBytes) |
| |
| return &tls.Certificate{ |
| Certificate: [][]byte{localCert}, |
| PrivateKey: key.parsedKey, |
| Leaf: key.parsedCert, |
| }, nil |
| } |
| } |
| |
| return nil, nil |
| } |
| |
| // CALookup returns the CA to use when validating this connection. |
| func (l *raftLayer) CALookup(context.Context) ([]*x509.Certificate, error) { |
| ret := make([]*x509.Certificate, len(l.keyring.Keys)) |
| for i, key := range l.keyring.Keys { |
| ret[i] = key.parsedCert |
| } |
| return ret, nil |
| } |
| |
| // Stop shuts down the raft layer. |
| func (l *raftLayer) Stop() error { |
| l.Close() |
| return nil |
| } |
| |
| // Handoff is used to hand off a connection to the |
| // RaftLayer. This allows it to be Accept()'ed |
| func (l *raftLayer) Handoff(ctx context.Context, wg *sync.WaitGroup, quit chan struct{}, conn *tls.Conn) error { |
| l.closeLock.Lock() |
| closed := l.closed |
| l.closeLock.Unlock() |
| |
| if closed { |
| return errors.New("raft is shutdown") |
| } |
| |
| wg.Add(1) |
| go func() { |
| defer wg.Done() |
| select { |
| case l.connCh <- conn: |
| case <-l.closeCh: |
| case <-ctx.Done(): |
| case <-quit: |
| } |
| }() |
| |
| return nil |
| } |
| |
| // Accept is used to return connection which are |
| // dialed to be used with the Raft layer |
| func (l *raftLayer) Accept() (net.Conn, error) { |
| select { |
| case conn := <-l.connCh: |
| return conn, nil |
| case <-l.closeCh: |
| return nil, fmt.Errorf("Raft RPC layer closed") |
| } |
| } |
| |
| // Close is used to stop listening for Raft connections |
| func (l *raftLayer) Close() error { |
| l.closeLock.Lock() |
| defer l.closeLock.Unlock() |
| |
| if !l.closed { |
| l.closed = true |
| close(l.closeCh) |
| } |
| return nil |
| } |
| |
| // Addr is used to return the address of the listener |
| func (l *raftLayer) Addr() net.Addr { |
| return l.addr |
| } |
| |
| // Dial is used to create a new outgoing connection |
| func (l *raftLayer) Dial(address raft.ServerAddress, timeout time.Duration) (net.Conn, error) { |
| dialFunc := l.clusterListener.GetDialerFunc(context.Background(), consts.RaftStorageALPN) |
| return dialFunc(string(address), timeout) |
| } |