| // Copyright (c) HashiCorp, Inc. |
| // SPDX-License-Identifier: MPL-2.0 |
| |
| package ssh |
| |
import (
	"context"
	"crypto/rand"
	"crypto/rsa"
	"crypto/x509"
	"encoding/base64"
	"encoding/pem"
	"fmt"
	"net"
	"sort"
	"strings"

	"github.com/hashicorp/go-secure-stdlib/parseutil"
	"github.com/hashicorp/vault/sdk/logical"
	"golang.org/x/crypto/ssh"
)
| |
| // Creates a new RSA key pair with the given key length. The private key will be |
| // of pem format and the public key will be of OpenSSH format. |
| func generateRSAKeys(keyBits int) (publicKeyRsa string, privateKeyRsa string, err error) { |
| privateKey, err := rsa.GenerateKey(rand.Reader, keyBits) |
| if err != nil { |
| return "", "", fmt.Errorf("error generating RSA key-pair: %w", err) |
| } |
| |
| privateKeyRsa = string(pem.EncodeToMemory(&pem.Block{ |
| Type: "RSA PRIVATE KEY", |
| Bytes: x509.MarshalPKCS1PrivateKey(privateKey), |
| })) |
| |
| sshPublicKey, err := ssh.NewPublicKey(privateKey.Public()) |
| if err != nil { |
| return "", "", fmt.Errorf("error generating RSA key-pair: %w", err) |
| } |
| publicKeyRsa = "ssh-rsa " + base64.StdEncoding.EncodeToString(sshPublicKey.Marshal()) |
| return |
| } |
| |
| // Takes an IP address and role name and checks if the IP is part |
| // of CIDR blocks belonging to the role. |
| func roleContainsIP(ctx context.Context, s logical.Storage, roleName string, ip string) (bool, error) { |
| if roleName == "" { |
| return false, fmt.Errorf("missing role name") |
| } |
| |
| if ip == "" { |
| return false, fmt.Errorf("missing ip") |
| } |
| |
| roleEntry, err := s.Get(ctx, fmt.Sprintf("roles/%s", roleName)) |
| if err != nil { |
| return false, fmt.Errorf("error retrieving role %w", err) |
| } |
| if roleEntry == nil { |
| return false, fmt.Errorf("role %q not found", roleName) |
| } |
| |
| var role sshRole |
| if err := roleEntry.DecodeJSON(&role); err != nil { |
| return false, fmt.Errorf("error decoding role %q", roleName) |
| } |
| |
| if matched, err := cidrListContainsIP(ip, role.CIDRList); err != nil { |
| return false, err |
| } else { |
| return matched, nil |
| } |
| } |
| |
// cidrListContainsIP reports whether ip lies inside any of the CIDR blocks in
// the comma-separated cidrList (for example "10.0.0.0/8,192.168.1.0/24").
// Whitespace surrounding each entry is ignored.
//
// Entries are checked in order, and the function returns true at the first
// block that contains ip, so entries after a match are not validated. An
// error is returned if cidrList is empty, or if an entry reached before a
// match is not valid CIDR notation. An ip that does not parse as an IP
// address is contained in no block, so the result is (false, nil).
func cidrListContainsIP(ip, cidrList string) (bool, error) {
	if len(cidrList) == 0 {
		return false, fmt.Errorf("IP does not belong to role")
	}

	// Parse the candidate address once rather than on every iteration.
	// net.ParseIP yields nil for malformed input, and IPNet.Contains(nil)
	// is false, so a bad ip simply never matches.
	parsedIP := net.ParseIP(ip)

	for _, item := range strings.Split(cidrList, ",") {
		entry := strings.TrimSpace(item)
		_, cidrIPNet, err := net.ParseCIDR(entry)
		if err != nil {
			return false, fmt.Errorf("invalid CIDR entry %q: %w", entry, err)
		}
		if cidrIPNet.Contains(parsedIP) {
			return true, nil
		}
	}
	return false, nil
}
| |
| func parsePublicSSHKey(key string) (ssh.PublicKey, error) { |
| keyParts := strings.Split(key, " ") |
| if len(keyParts) > 1 { |
| // Someone has sent the 'full' public key rather than just the base64 encoded part that the ssh library wants |
| key = keyParts[1] |
| } |
| |
| decodedKey, err := base64.StdEncoding.DecodeString(key) |
| if err != nil { |
| return nil, err |
| } |
| |
| return ssh.ParsePublicKey([]byte(decodedKey)) |
| } |
| |
// convertMapToStringValue returns a new map with the same keys as initial,
// where each value is rendered using fmt's default "%v" formatting. The input
// map is left untouched; a nil or empty input yields an empty, non-nil map.
func convertMapToStringValue(initial map[string]interface{}) map[string]string {
	out := make(map[string]string, len(initial))
	for k := range initial {
		out[k] = fmt.Sprintf("%v", initial[k])
	}
	return out
}
| |
| func convertMapToIntSlice(initial map[string]interface{}) (map[string][]int, error) { |
| var err error |
| result := map[string][]int{} |
| |
| for key, value := range initial { |
| result[key], err = parseutil.SafeParseIntSlice(value, 0 /* no upper bound on number of keys lengths per key type */) |
| if err != nil { |
| return nil, err |
| } |
| } |
| |
| return result, nil |
| } |
| |
// substQuery expands a simple template: every occurrence of "{{key}}" in tpl
// is replaced by data[key], for each key present in data. Placeholders whose
// key is not in data are left untouched.
//
// All placeholders are substituted in a single left-to-right pass, so text
// inserted from data is never itself re-scanned for placeholders. The previous
// sequential ReplaceAll over the map made the result depend on random map
// iteration order whenever a value contained another "{{key}}". Keys are
// sorted so that the choice is deterministic even when two placeholders could
// match at the same position (one key extending another).
func substQuery(tpl string, data map[string]string) string {
	if len(data) == 0 {
		return tpl
	}

	keys := make([]string, 0, len(data))
	for k := range data {
		keys = append(keys, k)
	}
	sort.Strings(keys)

	pairs := make([]string, 0, 2*len(keys))
	for _, k := range keys {
		pairs = append(pairs, "{{"+k+"}}", data[k])
	}

	return strings.NewReplacer(pairs...).Replace(tpl)
}