| // Copyright (c) HashiCorp, Inc. |
| // SPDX-License-Identifier: MPL-2.0 |
| |
| package cassandra |
| |
| import ( |
| "context" |
| "fmt" |
| "strings" |
| |
| "github.com/hashicorp/vault/sdk/helper/template" |
| |
| "github.com/gocql/gocql" |
| multierror "github.com/hashicorp/go-multierror" |
| "github.com/hashicorp/go-secure-stdlib/strutil" |
| dbplugin "github.com/hashicorp/vault/sdk/database/dbplugin/v5" |
| "github.com/hashicorp/vault/sdk/database/helper/dbutil" |
| ) |
| |
// Default CQL statements used when a request supplies no statements of its
// own. The {{username}} and {{password}} placeholders are replaced with the
// request's values before each statement is executed.
const (
	defaultUserCreationCQL   = `CREATE USER '{{username}}' WITH PASSWORD '{{password}}' NOSUPERUSER;`
	defaultUserDeletionCQL   = `DROP USER '{{username}}';`
	defaultChangePasswordCQL = `ALTER USER '{{username}}' WITH PASSWORD '{{password}}';`
)

// cassandraTypeName is the database type reported by Type and recorded on the
// connection producer.
const cassandraTypeName = "cassandra"

// defaultUserNameTemplate is used when no "username_template" is configured.
// It renders "v_<display name>_<role name>_<20 random chars>_<unix time>",
// with the display and role names truncated to 15 characters each, the whole
// result truncated to 100 characters, hyphens replaced by underscores, and
// everything lowercased.
const defaultUserNameTemplate = `{{ printf "v_%s_%s_%s_%s" (.DisplayName | truncate 15) (.RoleName | truncate 15) (random 20) (unix_time) | truncate 100 | replace "-" "_" | lowercase }}`
| |
| var _ dbplugin.Database = &Cassandra{} |
| |
| // Cassandra is an implementation of Database interface |
| type Cassandra struct { |
| *cassandraConnectionProducer |
| |
| usernameProducer template.StringTemplate |
| } |
| |
| // New returns a new Cassandra instance |
| func New() (interface{}, error) { |
| db := new() |
| dbType := dbplugin.NewDatabaseErrorSanitizerMiddleware(db, db.secretValues) |
| |
| return dbType, nil |
| } |
| |
| func new() *Cassandra { |
| connProducer := &cassandraConnectionProducer{} |
| connProducer.Type = cassandraTypeName |
| |
| return &Cassandra{ |
| cassandraConnectionProducer: connProducer, |
| } |
| } |
| |
| // Type returns the TypeName for this backend |
| func (c *Cassandra) Type() (string, error) { |
| return cassandraTypeName, nil |
| } |
| |
| func (c *Cassandra) getConnection(ctx context.Context) (*gocql.Session, error) { |
| session, err := c.Connection(ctx) |
| if err != nil { |
| return nil, err |
| } |
| |
| return session.(*gocql.Session), nil |
| } |
| |
| func (c *Cassandra) Initialize(ctx context.Context, req dbplugin.InitializeRequest) (dbplugin.InitializeResponse, error) { |
| usernameTemplate, err := strutil.GetString(req.Config, "username_template") |
| if err != nil { |
| return dbplugin.InitializeResponse{}, fmt.Errorf("failed to retrieve username_template: %w", err) |
| } |
| if usernameTemplate == "" { |
| usernameTemplate = defaultUserNameTemplate |
| } |
| |
| up, err := template.NewTemplate(template.Template(usernameTemplate)) |
| if err != nil { |
| return dbplugin.InitializeResponse{}, fmt.Errorf("unable to initialize username template: %w", err) |
| } |
| c.usernameProducer = up |
| |
| _, err = c.usernameProducer.Generate(dbplugin.UsernameMetadata{}) |
| if err != nil { |
| return dbplugin.InitializeResponse{}, fmt.Errorf("invalid username template: %w", err) |
| } |
| |
| err = c.cassandraConnectionProducer.Initialize(ctx, req) |
| if err != nil { |
| return dbplugin.InitializeResponse{}, fmt.Errorf("failed to initialize: %w", err) |
| } |
| |
| resp := dbplugin.InitializeResponse{ |
| Config: req.Config, |
| } |
| return resp, nil |
| } |
| |
| // NewUser generates the username/password on the underlying Cassandra secret backend as instructed by |
| // the statements provided. |
| func (c *Cassandra) NewUser(ctx context.Context, req dbplugin.NewUserRequest) (dbplugin.NewUserResponse, error) { |
| c.Lock() |
| defer c.Unlock() |
| |
| session, err := c.getConnection(ctx) |
| if err != nil { |
| return dbplugin.NewUserResponse{}, err |
| } |
| |
| creationCQL := req.Statements.Commands |
| if len(creationCQL) == 0 { |
| creationCQL = []string{defaultUserCreationCQL} |
| } |
| |
| rollbackCQL := req.RollbackStatements.Commands |
| if len(rollbackCQL) == 0 { |
| rollbackCQL = []string{defaultUserDeletionCQL} |
| } |
| |
| username, err := c.usernameProducer.Generate(req.UsernameConfig) |
| if err != nil { |
| return dbplugin.NewUserResponse{}, err |
| } |
| |
| for _, stmt := range creationCQL { |
| for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") { |
| query = strings.TrimSpace(query) |
| if len(query) == 0 { |
| continue |
| } |
| |
| m := map[string]string{ |
| "username": username, |
| "password": req.Password, |
| } |
| err = session. |
| Query(dbutil.QueryHelper(query, m)). |
| WithContext(ctx). |
| Exec() |
| if err != nil { |
| rollbackErr := rollbackUser(ctx, session, username, rollbackCQL) |
| if rollbackErr != nil { |
| err = multierror.Append(err, rollbackErr) |
| } |
| return dbplugin.NewUserResponse{}, err |
| } |
| } |
| } |
| |
| resp := dbplugin.NewUserResponse{ |
| Username: username, |
| } |
| return resp, nil |
| } |
| |
| func rollbackUser(ctx context.Context, session *gocql.Session, username string, rollbackCQL []string) error { |
| for _, stmt := range rollbackCQL { |
| for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") { |
| query = strings.TrimSpace(query) |
| if len(query) == 0 { |
| continue |
| } |
| |
| m := map[string]string{ |
| "username": username, |
| } |
| err := session. |
| Query(dbutil.QueryHelper(query, m)). |
| WithContext(ctx). |
| Exec() |
| if err != nil { |
| return fmt.Errorf("failed to roll back user %s: %w", username, err) |
| } |
| } |
| } |
| return nil |
| } |
| |
| func (c *Cassandra) UpdateUser(ctx context.Context, req dbplugin.UpdateUserRequest) (dbplugin.UpdateUserResponse, error) { |
| if req.Password == nil && req.Expiration == nil { |
| return dbplugin.UpdateUserResponse{}, fmt.Errorf("no changes requested") |
| } |
| |
| if req.Password != nil { |
| err := c.changeUserPassword(ctx, req.Username, req.Password) |
| return dbplugin.UpdateUserResponse{}, err |
| } |
| // Expiration is no-op |
| return dbplugin.UpdateUserResponse{}, nil |
| } |
| |
| func (c *Cassandra) changeUserPassword(ctx context.Context, username string, changePass *dbplugin.ChangePassword) error { |
| session, err := c.getConnection(ctx) |
| if err != nil { |
| return err |
| } |
| |
| rotateCQL := changePass.Statements.Commands |
| if len(rotateCQL) == 0 { |
| rotateCQL = []string{defaultChangePasswordCQL} |
| } |
| |
| var result *multierror.Error |
| for _, stmt := range rotateCQL { |
| for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") { |
| query = strings.TrimSpace(query) |
| if len(query) == 0 { |
| continue |
| } |
| |
| m := map[string]string{ |
| "username": username, |
| "password": changePass.NewPassword, |
| } |
| err := session. |
| Query(dbutil.QueryHelper(query, m)). |
| WithContext(ctx). |
| Exec() |
| result = multierror.Append(result, err) |
| } |
| } |
| |
| return result.ErrorOrNil() |
| } |
| |
| // DeleteUser attempts to drop the specified user. |
| func (c *Cassandra) DeleteUser(ctx context.Context, req dbplugin.DeleteUserRequest) (dbplugin.DeleteUserResponse, error) { |
| c.Lock() |
| defer c.Unlock() |
| |
| session, err := c.getConnection(ctx) |
| if err != nil { |
| return dbplugin.DeleteUserResponse{}, err |
| } |
| |
| revocationCQL := req.Statements.Commands |
| if len(revocationCQL) == 0 { |
| revocationCQL = []string{defaultUserDeletionCQL} |
| } |
| |
| var result *multierror.Error |
| for _, stmt := range revocationCQL { |
| for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") { |
| query = strings.TrimSpace(query) |
| if len(query) == 0 { |
| continue |
| } |
| |
| m := map[string]string{ |
| "username": req.Username, |
| } |
| err := session. |
| Query(dbutil.QueryHelper(query, m)). |
| WithContext(ctx). |
| Exec() |
| |
| result = multierror.Append(result, err) |
| } |
| } |
| |
| return dbplugin.DeleteUserResponse{}, result.ErrorOrNil() |
| } |