| // Copyright (c) HashiCorp, Inc. |
| // SPDX-License-Identifier: MPL-2.0 |
| |
| package mysql |
| |
| import ( |
| "context" |
| "crypto/tls" |
| "crypto/x509" |
| "database/sql" |
| "fmt" |
| "net/url" |
| "sync" |
| "time" |
| |
| "github.com/go-sql-driver/mysql" |
| "github.com/hashicorp/go-secure-stdlib/parseutil" |
| "github.com/hashicorp/go-uuid" |
| "github.com/hashicorp/vault/sdk/database/helper/connutil" |
| "github.com/hashicorp/vault/sdk/database/helper/dbutil" |
| "github.com/mitchellh/mapstructure" |
| ) |
| |
// mySQLConnectionProducer implements ConnectionProducer for MySQL. It carries
// the decoded connection settings, the optional TLS material, and the *sql.DB
// handle that Connection hands out.
type mySQLConnectionProducer struct {
	// ConnectionURL is the DSN used to reach the database. Init rewrites it
	// with the username and password substituted in.
	ConnectionURL string `json:"connection_url" mapstructure:"connection_url" structs:"connection_url"`
	// MaxOpenConnections caps the pool size; Init defaults it to 4 when zero.
	MaxOpenConnections int `json:"max_open_connections" mapstructure:"max_open_connections" structs:"max_open_connections"`
	// MaxIdleConnections caps idle connections; Init defaults it to, and
	// clamps it at, MaxOpenConnections.
	MaxIdleConnections int `json:"max_idle_connections" mapstructure:"max_idle_connections" structs:"max_idle_connections"`
	// MaxConnectionLifetimeRaw is the undecoded lifetime value; Init parses it
	// into maxConnectionLifetime and treats nil as "0s".
	MaxConnectionLifetimeRaw interface{} `json:"max_connection_lifetime" mapstructure:"max_connection_lifetime" structs:"max_connection_lifetime"`
	// Username is substituted (path-escaped) into ConnectionURL by Init.
	Username string `json:"username" mapstructure:"username" structs:"username"`
	// Password is substituted verbatim into ConnectionURL by Init.
	Password string `json:"password" mapstructure:"password" structs:"password"`

	// TLSCertificateKeyData is PEM data used as both the certificate and the
	// key input when building the client key pair.
	TLSCertificateKeyData []byte `json:"tls_certificate_key" mapstructure:"tls_certificate_key" structs:"-"`
	// TLSCAData is PEM data appended to the root pool used to verify the server.
	TLSCAData []byte `json:"tls_ca" mapstructure:"tls_ca" structs:"-"`
	// TLSServerName is copied to tls.Config.ServerName.
	TLSServerName string `json:"tls_server_name" mapstructure:"tls_server_name" structs:"tls_server_name"`
	// TLSSkipVerify is copied to tls.Config.InsecureSkipVerify.
	TLSSkipVerify bool `json:"tls_skip_verify" mapstructure:"tls_skip_verify" structs:"tls_skip_verify"`

	// tlsConfigName is a globally unique name that references the TLS config
	// for this instance in the mysql driver.
	tlsConfigName string

	// RawConfig is the configuration map exactly as it was passed to Init.
	RawConfig map[string]interface{}
	// maxConnectionLifetime is the parsed form of MaxConnectionLifetimeRaw.
	maxConnectionLifetime time.Duration
	// Initialized is set by Init once all fields are populated; Connection
	// refuses to run until it is true.
	Initialized bool
	// db is the pooled handle created by Connection and released by Close.
	db *sql.DB
	// Mutex is taken by Init and Close.
	sync.Mutex
}
| |
| func (c *mySQLConnectionProducer) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) error { |
| _, err := c.Init(ctx, conf, verifyConnection) |
| return err |
| } |
| |
| func (c *mySQLConnectionProducer) Init(ctx context.Context, conf map[string]interface{}, verifyConnection bool) (map[string]interface{}, error) { |
| c.Lock() |
| defer c.Unlock() |
| |
| c.RawConfig = conf |
| |
| err := mapstructure.WeakDecode(conf, &c) |
| if err != nil { |
| return nil, err |
| } |
| |
| if len(c.ConnectionURL) == 0 { |
| return nil, fmt.Errorf("connection_url cannot be empty") |
| } |
| |
| // Don't escape special characters for MySQL password |
| password := c.Password |
| |
| // QueryHelper doesn't do any SQL escaping, but if it starts to do so |
| // then maybe we won't be able to use it to do URL substitution any more. |
| c.ConnectionURL = dbutil.QueryHelper(c.ConnectionURL, map[string]string{ |
| "username": url.PathEscape(c.Username), |
| "password": password, |
| }) |
| |
| if c.MaxOpenConnections == 0 { |
| c.MaxOpenConnections = 4 |
| } |
| |
| if c.MaxIdleConnections == 0 { |
| c.MaxIdleConnections = c.MaxOpenConnections |
| } |
| if c.MaxIdleConnections > c.MaxOpenConnections { |
| c.MaxIdleConnections = c.MaxOpenConnections |
| } |
| if c.MaxConnectionLifetimeRaw == nil { |
| c.MaxConnectionLifetimeRaw = "0s" |
| } |
| |
| c.maxConnectionLifetime, err = parseutil.ParseDurationSecond(c.MaxConnectionLifetimeRaw) |
| if err != nil { |
| return nil, fmt.Errorf("invalid max_connection_lifetime: %w", err) |
| } |
| |
| tlsConfig, err := c.getTLSAuth() |
| if err != nil { |
| return nil, err |
| } |
| |
| if tlsConfig != nil { |
| if c.tlsConfigName == "" { |
| c.tlsConfigName, err = uuid.GenerateUUID() |
| if err != nil { |
| return nil, fmt.Errorf("unable to generate UUID for TLS configuration: %w", err) |
| } |
| } |
| |
| mysql.RegisterTLSConfig(c.tlsConfigName, tlsConfig) |
| } |
| |
| // Set initialized to true at this point since all fields are set, |
| // and the connection can be established at a later time. |
| c.Initialized = true |
| |
| if verifyConnection { |
| if _, err = c.Connection(ctx); err != nil { |
| return nil, fmt.Errorf("error verifying - connection: %w", err) |
| } |
| |
| if err := c.db.PingContext(ctx); err != nil { |
| return nil, fmt.Errorf("error verifying - ping: %w", err) |
| } |
| } |
| |
| return c.RawConfig, nil |
| } |
| |
| func (c *mySQLConnectionProducer) Connection(ctx context.Context) (interface{}, error) { |
| if !c.Initialized { |
| return nil, connutil.ErrNotInitialized |
| } |
| |
| // If we already have a DB, test it and return |
| if c.db != nil { |
| if err := c.db.PingContext(ctx); err == nil { |
| return c.db, nil |
| } |
| // If the ping was unsuccessful, close it and ignore errors as we'll be |
| // reestablishing anyways |
| c.db.Close() |
| } |
| |
| connURL, err := c.addTLStoDSN() |
| if err != nil { |
| return nil, err |
| } |
| |
| c.db, err = sql.Open("mysql", connURL) |
| if err != nil { |
| return nil, err |
| } |
| |
| // Set some connection pool settings. We don't need much of this, |
| // since the request rate shouldn't be high. |
| c.db.SetMaxOpenConns(c.MaxOpenConnections) |
| c.db.SetMaxIdleConns(c.MaxIdleConnections) |
| c.db.SetConnMaxLifetime(c.maxConnectionLifetime) |
| |
| return c.db, nil |
| } |
| |
| func (c *mySQLConnectionProducer) SecretValues() map[string]string { |
| return map[string]string{ |
| c.Password: "[password]", |
| } |
| } |
| |
| // Close attempts to close the connection |
| func (c *mySQLConnectionProducer) Close() error { |
| // Grab the write lock |
| c.Lock() |
| defer c.Unlock() |
| |
| if c.db != nil { |
| c.db.Close() |
| } |
| |
| c.db = nil |
| |
| return nil |
| } |
| |
| func (c *mySQLConnectionProducer) getTLSAuth() (tlsConfig *tls.Config, err error) { |
| if len(c.TLSCAData) == 0 && |
| len(c.TLSCertificateKeyData) == 0 { |
| return nil, nil |
| } |
| |
| rootCertPool := x509.NewCertPool() |
| if len(c.TLSCAData) > 0 { |
| ok := rootCertPool.AppendCertsFromPEM(c.TLSCAData) |
| if !ok { |
| return nil, fmt.Errorf("failed to append CA to client options") |
| } |
| } |
| |
| clientCert := make([]tls.Certificate, 0, 1) |
| |
| if len(c.TLSCertificateKeyData) > 0 { |
| certificate, err := tls.X509KeyPair(c.TLSCertificateKeyData, c.TLSCertificateKeyData) |
| if err != nil { |
| return nil, fmt.Errorf("unable to load tls_certificate_key_data: %w", err) |
| } |
| |
| clientCert = append(clientCert, certificate) |
| } |
| |
| tlsConfig = &tls.Config{ |
| RootCAs: rootCertPool, |
| Certificates: clientCert, |
| ServerName: c.TLSServerName, |
| InsecureSkipVerify: c.TLSSkipVerify, |
| } |
| |
| return tlsConfig, nil |
| } |
| |
| func (c *mySQLConnectionProducer) addTLStoDSN() (connURL string, err error) { |
| config, err := mysql.ParseDSN(c.ConnectionURL) |
| if err != nil { |
| return "", fmt.Errorf("unable to parse connectionURL: %s", err) |
| } |
| |
| if len(c.tlsConfigName) > 0 { |
| config.TLSConfig = c.tlsConfigName |
| } |
| |
| connURL = config.FormatDSN() |
| return connURL, nil |
| } |