| // Copyright (c) HashiCorp, Inc. |
| // SPDX-License-Identifier: MPL-2.0 |
| |
| package mysql |
| |
| import ( |
| "context" |
| "crypto/tls" |
| "crypto/x509" |
| "database/sql" |
| "errors" |
| "fmt" |
| "io/ioutil" |
| "math" |
| "net/url" |
| "sort" |
| "strconv" |
| "strings" |
| "sync" |
| "time" |
| "unicode" |
| |
| log "github.com/hashicorp/go-hclog" |
| "github.com/hashicorp/go-multierror" |
| |
| metrics "github.com/armon/go-metrics" |
| mysql "github.com/go-sql-driver/mysql" |
| "github.com/hashicorp/go-secure-stdlib/strutil" |
| "github.com/hashicorp/vault/sdk/physical" |
| ) |
| |
| // Verify MySQLBackend satisfies the correct interfaces |
| var ( |
| _ physical.Backend = (*MySQLBackend)(nil) |
| _ physical.HABackend = (*MySQLBackend)(nil) |
| _ physical.Lock = (*MySQLHALock)(nil) |
| ) |
| |
// mysqlTLSKey is the name under which setupMySQLTLSConfig registers the custom
// tls.Config and which NewMySQLClient passes as the DSN "tls" parameter. The
// driver reserves "true", "false" and "skip-verify" for built-in modes, so an
// unreserved name is used.
const (
	mysqlTLSKey = "default"
)
| |
| // MySQLBackend is a physical backend that stores data |
| // within MySQL database. |
| type MySQLBackend struct { |
| dbTable string |
| dbLockTable string |
| client *sql.DB |
| statements map[string]*sql.Stmt |
| logger log.Logger |
| permitPool *physical.PermitPool |
| conf map[string]string |
| redirectHost string |
| redirectPort int64 |
| haEnabled bool |
| } |
| |
| // NewMySQLBackend constructs a MySQL backend using the given API client and |
| // server address and credential for accessing mysql database. |
| func NewMySQLBackend(conf map[string]string, logger log.Logger) (physical.Backend, error) { |
| var err error |
| |
| db, err := NewMySQLClient(conf, logger) |
| if err != nil { |
| return nil, err |
| } |
| |
| database := conf["database"] |
| if database == "" { |
| database = "vault" |
| } |
| table := conf["table"] |
| if table == "" { |
| table = "vault" |
| } |
| |
| err = validateDBTable(database, table) |
| if err != nil { |
| return nil, err |
| } |
| |
| dbTable := fmt.Sprintf("`%s`.`%s`", database, table) |
| |
| maxParStr, ok := conf["max_parallel"] |
| var maxParInt int |
| if ok { |
| maxParInt, err = strconv.Atoi(maxParStr) |
| if err != nil { |
| return nil, fmt.Errorf("failed parsing max_parallel parameter: %w", err) |
| } |
| if logger.IsDebug() { |
| logger.Debug("max_parallel set", "max_parallel", maxParInt) |
| } |
| } else { |
| maxParInt = physical.DefaultParallelOperations |
| } |
| |
| // Check schema exists |
| var schemaExist bool |
| schemaRows, err := db.Query("SELECT SCHEMA_NAME FROM information_schema.SCHEMATA WHERE SCHEMA_NAME = ?", database) |
| if err != nil { |
| return nil, fmt.Errorf("failed to check mysql schema exist: %w", err) |
| } |
| defer schemaRows.Close() |
| schemaExist = schemaRows.Next() |
| |
| // Check table exists |
| var tableExist bool |
| tableRows, err := db.Query("SELECT TABLE_NAME FROM information_schema.TABLES WHERE TABLE_NAME = ? AND TABLE_SCHEMA = ?", table, database) |
| if err != nil { |
| return nil, fmt.Errorf("failed to check mysql table exist: %w", err) |
| } |
| defer tableRows.Close() |
| tableExist = tableRows.Next() |
| |
| // Create the required database if it doesn't exists. |
| if !schemaExist { |
| if _, err := db.Exec("CREATE DATABASE IF NOT EXISTS `" + database + "`"); err != nil { |
| return nil, fmt.Errorf("failed to create mysql database: %w", err) |
| } |
| } |
| |
| // Create the required table if it doesn't exists. |
| if !tableExist { |
| create_query := "CREATE TABLE IF NOT EXISTS " + dbTable + |
| " (vault_key varbinary(3072), vault_value mediumblob, PRIMARY KEY (vault_key))" |
| if _, err := db.Exec(create_query); err != nil { |
| return nil, fmt.Errorf("failed to create mysql table: %w", err) |
| } |
| } |
| |
| // Default value for ha_enabled |
| haEnabledStr, ok := conf["ha_enabled"] |
| if !ok { |
| haEnabledStr = "false" |
| } |
| haEnabled, err := strconv.ParseBool(haEnabledStr) |
| if err != nil { |
| return nil, fmt.Errorf("value [%v] of 'ha_enabled' could not be understood", haEnabledStr) |
| } |
| |
| locktable, ok := conf["lock_table"] |
| if !ok { |
| locktable = table + "_lock" |
| } |
| |
| dbLockTable := "`" + database + "`.`" + locktable + "`" |
| |
| // Only create lock table if ha_enabled is true |
| if haEnabled { |
| // Check table exists |
| var lockTableExist bool |
| lockTableRows, err := db.Query("SELECT TABLE_NAME FROM information_schema.TABLES WHERE TABLE_NAME = ? AND TABLE_SCHEMA = ?", locktable, database) |
| if err != nil { |
| return nil, fmt.Errorf("failed to check mysql table exist: %w", err) |
| } |
| defer lockTableRows.Close() |
| lockTableExist = lockTableRows.Next() |
| |
| // Create the required table if it doesn't exists. |
| if !lockTableExist { |
| create_query := "CREATE TABLE IF NOT EXISTS " + dbLockTable + |
| " (node_job varbinary(512), current_leader varbinary(512), PRIMARY KEY (node_job))" |
| if _, err := db.Exec(create_query); err != nil { |
| return nil, fmt.Errorf("failed to create mysql table: %w", err) |
| } |
| } |
| } |
| |
| // Setup the backend. |
| m := &MySQLBackend{ |
| dbTable: dbTable, |
| dbLockTable: dbLockTable, |
| client: db, |
| statements: make(map[string]*sql.Stmt), |
| logger: logger, |
| permitPool: physical.NewPermitPool(maxParInt), |
| conf: conf, |
| haEnabled: haEnabled, |
| } |
| |
| // Prepare all the statements required |
| statements := map[string]string{ |
| "put": "INSERT INTO " + dbTable + |
| " VALUES( ?, ? ) ON DUPLICATE KEY UPDATE vault_value=VALUES(vault_value)", |
| "get": "SELECT vault_value FROM " + dbTable + " WHERE vault_key = ?", |
| "delete": "DELETE FROM " + dbTable + " WHERE vault_key = ?", |
| "list": "SELECT vault_key FROM " + dbTable + " WHERE vault_key LIKE ?", |
| } |
| |
| // Only prepare ha-related statements if we need them |
| if haEnabled { |
| statements["get_lock"] = "SELECT current_leader FROM " + dbLockTable + " WHERE node_job = ?" |
| statements["used_lock"] = "SELECT IS_USED_LOCK(?)" |
| } |
| |
| for name, query := range statements { |
| if err := m.prepare(name, query); err != nil { |
| return nil, err |
| } |
| } |
| |
| return m, nil |
| } |
| |
| // validateDBTable to prevent SQL injection attacks. This ensures that the database and table names only have valid |
| // characters in them. MySQL allows for more characters that this will allow, but there isn't an easy way of |
| // representing the full Unicode Basic Multilingual Plane to check against. |
| // https://dev.mysql.com/doc/refman/5.7/en/identifiers.html |
| func validateDBTable(db, table string) (err error) { |
| merr := &multierror.Error{} |
| merr = multierror.Append(merr, wrapErr("invalid database: %w", validate(db))) |
| merr = multierror.Append(merr, wrapErr("invalid table: %w", validate(table))) |
| return merr.ErrorOrNil() |
| } |
| |
// validate checks a single MySQL identifier that will be wrapped in
// backticks. It returns an error for an empty name or for the first
// offending rune, checked in this order: U+0000, anything above U+FFFF, a
// backtick, a single or double quote, and then any non-printable or space
// character.
//
// From https://dev.mysql.com/doc/refman/5.7/en/identifiers.html:
//   - Quoted identifiers may use the full BMP except U+0000
//     (ASCII U+0001..U+007F, extended U+0080..U+FFFF).
//   - U+0000 and supplementary characters (U+10000 and above) are never allowed.
//   - Identifiers may begin with a digit but, unless quoted, may not consist
//     solely of digits.
//   - Database, table, and column names cannot end with space characters.
//
// All space characters are rejected outright because that is simpler. The
// name is always quoted, so the all-digit rule does not apply.
func validate(name string) (err error) {
	if len(name) == 0 {
		return fmt.Errorf("missing name")
	}

	for _, c := range name {
		switch {
		case c == 0x0000:
			return fmt.Errorf("invalid character: cannot include 0x0000")
		case c > 0xFFFF:
			return fmt.Errorf("invalid character: cannot include any characters above 0xFFFF")
		case c == '`':
			return fmt.Errorf("invalid character: cannot include '`' character")
		case c == '\'', c == '"':
			return fmt.Errorf("invalid character: cannot include quotes")
		case !unicode.IsPrint(c), unicode.IsSpace(c):
			// Neither restriction comes from the MySQL docs; both are
			// deliberately stricter.
			return fmt.Errorf("invalid character found: can only include printable, non-space characters between [0x0001-0xFFFF]")
		}
	}
	return nil
}
| |
// wrapErr returns nil when err is nil, and otherwise fmt.Errorf(message, err).
// message is expected to contain a %w verb. This lets callers wrap a result
// unconditionally.
func wrapErr(message string, err error) error {
	if err != nil {
		return fmt.Errorf(message, err)
	}
	return nil
}
| |
| func NewMySQLClient(conf map[string]string, logger log.Logger) (*sql.DB, error) { |
| var err error |
| |
| // Get the MySQL credentials to perform read/write operations. |
| username, ok := conf["username"] |
| if !ok || username == "" { |
| return nil, fmt.Errorf("missing username") |
| } |
| password, ok := conf["password"] |
| if !ok || password == "" { |
| return nil, fmt.Errorf("missing password") |
| } |
| |
| // Get or set MySQL server address. Defaults to localhost and default port(3306) |
| address, ok := conf["address"] |
| if !ok { |
| address = "127.0.0.1:3306" |
| } |
| |
| maxIdleConnStr, ok := conf["max_idle_connections"] |
| var maxIdleConnInt int |
| if ok { |
| maxIdleConnInt, err = strconv.Atoi(maxIdleConnStr) |
| if err != nil { |
| return nil, fmt.Errorf("failed parsing max_idle_connections parameter: %w", err) |
| } |
| if logger.IsDebug() { |
| logger.Debug("max_idle_connections set", "max_idle_connections", maxIdleConnInt) |
| } |
| } |
| |
| maxConnLifeStr, ok := conf["max_connection_lifetime"] |
| var maxConnLifeInt int |
| if ok { |
| maxConnLifeInt, err = strconv.Atoi(maxConnLifeStr) |
| if err != nil { |
| return nil, fmt.Errorf("failed parsing max_connection_lifetime parameter: %w", err) |
| } |
| if logger.IsDebug() { |
| logger.Debug("max_connection_lifetime set", "max_connection_lifetime", maxConnLifeInt) |
| } |
| } |
| |
| maxParStr, ok := conf["max_parallel"] |
| var maxParInt int |
| if ok { |
| maxParInt, err = strconv.Atoi(maxParStr) |
| if err != nil { |
| return nil, fmt.Errorf("failed parsing max_parallel parameter: %w", err) |
| } |
| if logger.IsDebug() { |
| logger.Debug("max_parallel set", "max_parallel", maxParInt) |
| } |
| } else { |
| maxParInt = physical.DefaultParallelOperations |
| } |
| |
| dsnParams := url.Values{} |
| tlsCaFile, tlsOk := conf["tls_ca_file"] |
| if tlsOk { |
| if err := setupMySQLTLSConfig(tlsCaFile); err != nil { |
| return nil, fmt.Errorf("failed register TLS config: %w", err) |
| } |
| |
| dsnParams.Add("tls", mysqlTLSKey) |
| } |
| ptAllowed, ptOk := conf["plaintext_connection_allowed"] |
| if !(ptOk && strings.ToLower(ptAllowed) == "true") && !tlsOk { |
| logger.Warn("No TLS specified, credentials will be sent in plaintext. To mute this warning add 'plaintext_connection_allowed' with a true value to your MySQL configuration in your config file.") |
| } |
| |
| // Create MySQL handle for the database. |
| dsn := username + ":" + password + "@tcp(" + address + ")/?" + dsnParams.Encode() |
| db, err := sql.Open("mysql", dsn) |
| if err != nil { |
| return nil, fmt.Errorf("failed to connect to mysql: %w", err) |
| } |
| db.SetMaxOpenConns(maxParInt) |
| if maxIdleConnInt != 0 { |
| db.SetMaxIdleConns(maxIdleConnInt) |
| } |
| if maxConnLifeInt != 0 { |
| db.SetConnMaxLifetime(time.Duration(maxConnLifeInt) * time.Second) |
| } |
| |
| return db, err |
| } |
| |
| // prepare is a helper to prepare a query for future execution |
| func (m *MySQLBackend) prepare(name, query string) error { |
| stmt, err := m.client.Prepare(query) |
| if err != nil { |
| return fmt.Errorf("failed to prepare %q: %w", name, err) |
| } |
| m.statements[name] = stmt |
| return nil |
| } |
| |
| // Put is used to insert or update an entry. |
| func (m *MySQLBackend) Put(ctx context.Context, entry *physical.Entry) error { |
| defer metrics.MeasureSince([]string{"mysql", "put"}, time.Now()) |
| |
| m.permitPool.Acquire() |
| defer m.permitPool.Release() |
| |
| _, err := m.statements["put"].Exec(entry.Key, entry.Value) |
| if err != nil { |
| return err |
| } |
| return nil |
| } |
| |
| // Get is used to fetch an entry. |
| func (m *MySQLBackend) Get(ctx context.Context, key string) (*physical.Entry, error) { |
| defer metrics.MeasureSince([]string{"mysql", "get"}, time.Now()) |
| |
| m.permitPool.Acquire() |
| defer m.permitPool.Release() |
| |
| var result []byte |
| err := m.statements["get"].QueryRow(key).Scan(&result) |
| if err == sql.ErrNoRows { |
| return nil, nil |
| } |
| if err != nil { |
| return nil, err |
| } |
| |
| ent := &physical.Entry{ |
| Key: key, |
| Value: result, |
| } |
| return ent, nil |
| } |
| |
| // Delete is used to permanently delete an entry |
| func (m *MySQLBackend) Delete(ctx context.Context, key string) error { |
| defer metrics.MeasureSince([]string{"mysql", "delete"}, time.Now()) |
| |
| m.permitPool.Acquire() |
| defer m.permitPool.Release() |
| |
| _, err := m.statements["delete"].Exec(key) |
| if err != nil { |
| return err |
| } |
| return nil |
| } |
| |
| // List is used to list all the keys under a given |
| // prefix, up to the next prefix. |
| func (m *MySQLBackend) List(ctx context.Context, prefix string) ([]string, error) { |
| defer metrics.MeasureSince([]string{"mysql", "list"}, time.Now()) |
| |
| m.permitPool.Acquire() |
| defer m.permitPool.Release() |
| |
| // Add the % wildcard to the prefix to do the prefix search |
| likePrefix := prefix + "%" |
| rows, err := m.statements["list"].Query(likePrefix) |
| if err != nil { |
| return nil, fmt.Errorf("failed to execute statement: %w", err) |
| } |
| |
| var keys []string |
| for rows.Next() { |
| var key string |
| err = rows.Scan(&key) |
| if err != nil { |
| return nil, fmt.Errorf("failed to scan rows: %w", err) |
| } |
| |
| key = strings.TrimPrefix(key, prefix) |
| if i := strings.Index(key, "/"); i == -1 { |
| // Add objects only from the current 'folder' |
| keys = append(keys, key) |
| } else if i != -1 { |
| // Add truncated 'folder' paths |
| keys = strutil.AppendIfMissing(keys, string(key[:i+1])) |
| } |
| } |
| |
| sort.Strings(keys) |
| return keys, nil |
| } |
| |
| // LockWith is used for mutual exclusion based on the given key. |
| func (m *MySQLBackend) LockWith(key, value string) (physical.Lock, error) { |
| l := &MySQLHALock{ |
| in: m, |
| key: key, |
| value: value, |
| logger: m.logger, |
| } |
| return l, nil |
| } |
| |
| func (m *MySQLBackend) HAEnabled() bool { |
| return m.haEnabled |
| } |
| |
| // MySQLHALock is a MySQL Lock implementation for the HABackend |
| type MySQLHALock struct { |
| in *MySQLBackend |
| key string |
| value string |
| logger log.Logger |
| |
| held bool |
| localLock sync.Mutex |
| leaderCh chan struct{} |
| stopCh <-chan struct{} |
| lock *MySQLLock |
| } |
| |
| func (i *MySQLHALock) Lock(stopCh <-chan struct{}) (<-chan struct{}, error) { |
| i.localLock.Lock() |
| defer i.localLock.Unlock() |
| if i.held { |
| return nil, fmt.Errorf("lock already held") |
| } |
| |
| // Attempt an async acquisition |
| didLock := make(chan struct{}) |
| failLock := make(chan error, 1) |
| releaseCh := make(chan bool, 1) |
| go i.attemptLock(i.key, i.value, didLock, failLock, releaseCh) |
| |
| // Wait for lock acquisition, failure, or shutdown |
| select { |
| case <-didLock: |
| releaseCh <- false |
| case err := <-failLock: |
| return nil, err |
| case <-stopCh: |
| releaseCh <- true |
| return nil, nil |
| } |
| |
| // Create the leader channel |
| i.held = true |
| i.leaderCh = make(chan struct{}) |
| |
| go i.monitorLock(i.leaderCh) |
| |
| i.stopCh = stopCh |
| |
| return i.leaderCh, nil |
| } |
| |
| func (i *MySQLHALock) attemptLock(key, value string, didLock chan struct{}, failLock chan error, releaseCh chan bool) { |
| lock, err := NewMySQLLock(i.in, i.logger, key, value) |
| if err != nil { |
| failLock <- err |
| return |
| } |
| |
| // Set node value |
| i.lock = lock |
| |
| err = lock.Lock() |
| if err != nil { |
| failLock <- err |
| return |
| } |
| |
| // Signal that lock is held |
| close(didLock) |
| |
| // Handle an early abort |
| release := <-releaseCh |
| if release { |
| lock.Unlock() |
| } |
| } |
| |
| func (i *MySQLHALock) monitorLock(leaderCh chan struct{}) { |
| for { |
| // The only way to lose this lock is if someone is |
| // logging into the DB and altering system tables or you lose a connection in |
| // which case you will lose the lock anyway. |
| err := i.hasLock(i.key) |
| if err != nil { |
| // Somehow we lost the lock.... likely because the connection holding |
| // the lock was closed or someone was playing around with the locks in the DB. |
| close(leaderCh) |
| return |
| } |
| |
| time.Sleep(5 * time.Second) |
| } |
| } |
| |
| func (i *MySQLHALock) Unlock() error { |
| i.localLock.Lock() |
| defer i.localLock.Unlock() |
| if !i.held { |
| return nil |
| } |
| |
| err := i.lock.Unlock() |
| |
| if err == nil { |
| i.held = false |
| return nil |
| } |
| |
| return err |
| } |
| |
| // hasLock will check if a lock is held by checking the current lock id against our known ID. |
| func (i *MySQLHALock) hasLock(key string) error { |
| var result sql.NullInt64 |
| err := i.in.statements["used_lock"].QueryRow(key).Scan(&result) |
| if err == sql.ErrNoRows || !result.Valid { |
| // This is not an error to us since it just means the lock isn't held |
| return nil |
| } |
| |
| if err != nil { |
| return err |
| } |
| |
| // IS_USED_LOCK will return the ID of the connection that created the lock. |
| if result.Int64 != GlobalLockID { |
| return ErrLockHeld |
| } |
| |
| return nil |
| } |
| |
| func (i *MySQLHALock) GetLeader() (string, error) { |
| defer metrics.MeasureSince([]string{"mysql", "lock_get"}, time.Now()) |
| var result string |
| err := i.in.statements["get_lock"].QueryRow("leader").Scan(&result) |
| if err == sql.ErrNoRows { |
| return "", err |
| } |
| |
| return result, nil |
| } |
| |
| func (i *MySQLHALock) Value() (bool, string, error) { |
| leaderkey, err := i.GetLeader() |
| if err != nil { |
| return false, "", err |
| } |
| |
| return true, leaderkey, err |
| } |
| |
| // MySQLLock provides an easy way to grab and release mysql |
| // locks using the built in GET_LOCK function. Note that these |
| // locks are released when you lose connection to the server. |
| type MySQLLock struct { |
| parentConn *MySQLBackend |
| in *sql.DB |
| logger log.Logger |
| statements map[string]*sql.Stmt |
| key string |
| value string |
| } |
| |
// GlobalLockID holds the MySQL connection ID that owned the lock after the most
// recent successful MySQLLock.Lock. hasLock compares IS_USED_LOCK against it to
// notice that the lock has moved to another session.
// NOTE(review): this is package-level state shared by every lock in the process.
var GlobalLockID int64

// Errors specific to trying to grab a lock in MySQL.
var (
	// ErrLockHeld is returned when another Vault instance already holds the
	// lock for the given key.
	ErrLockHeld = errors.New("mysql: lock already held")

	// ErrUnlockFailed is returned when closing the lock's connection fails.
	ErrUnlockFailed = errors.New("mysql: unable to release lock, already released or not held by this session")

	// ErrClaimFailed means the DB could not be updated with the new leader.
	ErrClaimFailed = errors.New("mysql: unable to update DB with new leader information")

	// ErrSettingGlobalID is returned when the lock's owning connection ID could
	// not be determined right after acquiring it.
	ErrSettingGlobalID = errors.New("mysql: getting global lock id failed")
)
| |
| // NewMySQLLock helper function |
| func NewMySQLLock(in *MySQLBackend, l log.Logger, key, value string) (*MySQLLock, error) { |
| // Create a new MySQL connection so we can close this and have no effect on |
| // the rest of the MySQL backend and any cleanup that might need to be done. |
| conn, _ := NewMySQLClient(in.conf, in.logger) |
| |
| m := &MySQLLock{ |
| parentConn: in, |
| in: conn, |
| logger: l, |
| statements: make(map[string]*sql.Stmt), |
| key: key, |
| value: value, |
| } |
| |
| statements := map[string]string{ |
| "put": "INSERT INTO " + in.dbLockTable + |
| " VALUES( ?, ? ) ON DUPLICATE KEY UPDATE current_leader=VALUES(current_leader)", |
| } |
| |
| for name, query := range statements { |
| if err := m.prepare(name, query); err != nil { |
| return nil, err |
| } |
| } |
| |
| return m, nil |
| } |
| |
| // prepare is a helper to prepare a query for future execution |
| func (m *MySQLLock) prepare(name, query string) error { |
| stmt, err := m.in.Prepare(query) |
| if err != nil { |
| return fmt.Errorf("failed to prepare %q: %w", name, err) |
| } |
| m.statements[name] = stmt |
| return nil |
| } |
| |
| // update the current cluster leader in the DB. This is used so |
| // we can tell the servers in standby who the active leader is. |
| func (i *MySQLLock) becomeLeader() error { |
| _, err := i.statements["put"].Exec("leader", i.value) |
| if err != nil { |
| return err |
| } |
| |
| return nil |
| } |
| |
| // Lock will try to get a lock for an indefinite amount of time |
| // based on the given key that has been requested. |
| func (i *MySQLLock) Lock() error { |
| defer metrics.MeasureSince([]string{"mysql", "get_lock"}, time.Now()) |
| |
| // Lock timeout math.MaxInt32 instead of -1 solves compatibility issues with |
| // different MySQL flavours i.e. MariaDB |
| rows, err := i.in.Query("SELECT GET_LOCK(?, ?), IS_USED_LOCK(?)", i.key, math.MaxInt32, i.key) |
| if err != nil { |
| return err |
| } |
| |
| defer rows.Close() |
| rows.Next() |
| var lock sql.NullInt64 |
| var connectionID sql.NullInt64 |
| rows.Scan(&lock, &connectionID) |
| |
| if rows.Err() != nil { |
| return rows.Err() |
| } |
| |
| // 1 is returned from GET_LOCK if it was able to get the lock |
| // 0 if it failed and NULL if some strange error happened. |
| // https://dev.mysql.com/doc/refman/8.0/en/miscellaneous-functions.html#function_get-lock |
| if !lock.Valid || lock.Int64 != 1 { |
| return ErrLockHeld |
| } |
| |
| // Since we have the lock alert the rest of the cluster |
| // that we are now the active leader. |
| err = i.becomeLeader() |
| if err != nil { |
| return ErrLockHeld |
| } |
| |
| // This will return the connection ID of NULL if an error happens |
| // https://dev.mysql.com/doc/refman/8.0/en/miscellaneous-functions.html#function_is-used-lock |
| if !connectionID.Valid { |
| return ErrSettingGlobalID |
| } |
| |
| GlobalLockID = connectionID.Int64 |
| |
| return nil |
| } |
| |
| // Unlock just closes the connection. This is because closing the MySQL connection |
| // is a 100% reliable way to close the lock. If you just release the lock you must |
| // do it from the same mysql connection_id that you originally created it from. This |
| // is a huge hastle and I actually couldn't find a clean way to do this although one |
| // likely does exist. Closing the connection however ensures we don't ever get into a |
| // state where we try to release the lock and it hangs it is also much less code. |
| func (i *MySQLLock) Unlock() error { |
| err := i.in.Close() |
| if err != nil { |
| return ErrUnlockFailed |
| } |
| |
| return nil |
| } |
| |
| // Establish a TLS connection with a given CA certificate |
| // Register a tsl.Config associated with the same key as the dns param from sql.Open |
| // foo:bar@tcp(127.0.0.1:3306)/dbname?tls=default |
| func setupMySQLTLSConfig(tlsCaFile string) error { |
| rootCertPool := x509.NewCertPool() |
| |
| pem, err := ioutil.ReadFile(tlsCaFile) |
| if err != nil { |
| return err |
| } |
| |
| if ok := rootCertPool.AppendCertsFromPEM(pem); !ok { |
| return err |
| } |
| |
| err = mysql.RegisterTLSConfig(mysqlTLSKey, &tls.Config{ |
| RootCAs: rootCertPool, |
| }) |
| if err != nil { |
| return err |
| } |
| |
| return nil |
| } |