| // Copyright (c) HashiCorp, Inc. |
| // SPDX-License-Identifier: MPL-2.0 |
| |
| package rabbitmq |
| |
| import ( |
| "context" |
| "fmt" |
| "log" |
| "os" |
| "testing" |
| |
| "github.com/hashicorp/go-secure-stdlib/base62" |
| logicaltest "github.com/hashicorp/vault/helper/testhelpers/logical" |
| "github.com/hashicorp/vault/sdk/helper/docker" |
| "github.com/hashicorp/vault/sdk/helper/jsonutil" |
| "github.com/hashicorp/vault/sdk/logical" |
| rabbithole "github.com/michaelklishin/rabbit-hole/v2" |
| "github.com/mitchellh/mapstructure" |
| ) |
| |
// Names of the environment variables that let the tests target an existing
// RabbitMQ instance instead of a freshly started container.
const (
	// envRabbitMQConnectionURI names the variable holding the URI of an
	// existing RabbitMQ instance; when it is set, no container is started.
	envRabbitMQConnectionURI = "RABBITMQ_CONNECTION_URI"

	// envRabbitMQUsername names the variable holding the username written to
	// the connection config; "guest" is used when it is unset.
	envRabbitMQUsername = "RABBITMQ_USERNAME"

	// envRabbitMQPassword names the variable holding the password written to
	// the connection config; "guest" is used when it is unset.
	envRabbitMQPassword = "RABBITMQ_PASSWORD"
)
| |
// Fixture values shared by the role-related test steps.
const (
	// testTags is the value written to the "tags" field of the test role.
	testTags = "administrator"

	// testVHosts is the JSON written to the "vhosts" field of the test role:
	// configure, write and read patterns of ".*" on the "/" vhost.
	testVHosts = `{"/": {"configure": ".*", "write": ".*", "read": ".*"}}`

	// testVHostTopics is the JSON written to the "vhost_topics" field of the
	// test role: write and read patterns of ".*" for the "amq.topic" exchange
	// on the "/" vhost.
	testVHostTopics = `{"/": {"amq.topic": {"write": ".*", "read": ".*"}}}`

	// roleName is the name of the role the tests create, read and delete.
	roleName = "web"
)
| |
| func prepareRabbitMQTestContainer(t *testing.T) (func(), string) { |
| if os.Getenv(envRabbitMQConnectionURI) != "" { |
| return func() {}, os.Getenv(envRabbitMQConnectionURI) |
| } |
| |
| runner, err := docker.NewServiceRunner(docker.RunOptions{ |
| ImageRepo: "docker.mirror.hashicorp.services/library/rabbitmq", |
| ImageTag: "3-management", |
| ContainerName: "rabbitmq", |
| Ports: []string{"15672/tcp"}, |
| }) |
| if err != nil { |
| t.Fatalf("could not start docker rabbitmq: %s", err) |
| } |
| |
| svc, err := runner.StartService(context.Background(), func(ctx context.Context, host string, port int) (docker.ServiceConfig, error) { |
| connURL := fmt.Sprintf("http://%s:%d", host, port) |
| rmqc, err := rabbithole.NewClient(connURL, "guest", "guest") |
| if err != nil { |
| return nil, err |
| } |
| |
| _, err = rmqc.Overview() |
| if err != nil { |
| return nil, err |
| } |
| |
| return docker.NewServiceURLParse(connURL) |
| }) |
| if err != nil { |
| t.Fatalf("could not start docker rabbitmq: %s", err) |
| } |
| return svc.Cleanup, svc.Config.URL().String() |
| } |
| |
| func TestBackend_basic(t *testing.T) { |
| b, _ := Factory(context.Background(), logical.TestBackendConfig()) |
| |
| cleanup, uri := prepareRabbitMQTestContainer(t) |
| defer cleanup() |
| |
| logicaltest.Test(t, logicaltest.TestCase{ |
| PreCheck: testAccPreCheckFunc(t, uri), |
| LogicalBackend: b, |
| Steps: []logicaltest.TestStep{ |
| testAccStepConfig(t, uri, ""), |
| testAccStepRole(t), |
| testAccStepReadCreds(t, b, uri, roleName), |
| }, |
| }) |
| } |
| |
| func TestBackend_returnsErrs(t *testing.T) { |
| b, _ := Factory(context.Background(), logical.TestBackendConfig()) |
| |
| cleanup, uri := prepareRabbitMQTestContainer(t) |
| defer cleanup() |
| |
| logicaltest.Test(t, logicaltest.TestCase{ |
| PreCheck: testAccPreCheckFunc(t, uri), |
| LogicalBackend: b, |
| Steps: []logicaltest.TestStep{ |
| testAccStepConfig(t, uri, ""), |
| { |
| Operation: logical.CreateOperation, |
| Path: fmt.Sprintf("roles/%s", roleName), |
| Data: map[string]interface{}{ |
| "tags": testTags, |
| "vhosts": `{"invalid":{"write": ".*", "read": ".*"}}`, |
| "vhost_topics": testVHostTopics, |
| }, |
| }, |
| { |
| Operation: logical.ReadOperation, |
| Path: fmt.Sprintf("creds/%s", roleName), |
| ErrorOk: true, |
| }, |
| }, |
| }) |
| } |
| |
| func TestBackend_roleCrud(t *testing.T) { |
| b, _ := Factory(context.Background(), logical.TestBackendConfig()) |
| |
| cleanup, uri := prepareRabbitMQTestContainer(t) |
| defer cleanup() |
| |
| logicaltest.Test(t, logicaltest.TestCase{ |
| PreCheck: testAccPreCheckFunc(t, uri), |
| LogicalBackend: b, |
| Steps: []logicaltest.TestStep{ |
| testAccStepConfig(t, uri, ""), |
| testAccStepRole(t), |
| testAccStepReadRole(t, roleName, testTags, testVHosts, testVHostTopics), |
| testAccStepDeleteRole(t, roleName), |
| testAccStepReadRole(t, roleName, "", "", ""), |
| }, |
| }) |
| } |
| |
| func TestBackend_roleWithPasswordPolicy(t *testing.T) { |
| if os.Getenv(logicaltest.TestEnvVar) == "" { |
| t.Skip(fmt.Sprintf("Acceptance tests skipped unless env %q set", logicaltest.TestEnvVar)) |
| return |
| } |
| |
| backendConfig := logical.TestBackendConfig() |
| passGen := func() (password string, err error) { |
| return base62.Random(30) |
| } |
| backendConfig.System.(*logical.StaticSystemView).SetPasswordPolicy("testpolicy", passGen) |
| b, _ := Factory(context.Background(), backendConfig) |
| |
| cleanup, uri := prepareRabbitMQTestContainer(t) |
| defer cleanup() |
| |
| logicaltest.Test(t, logicaltest.TestCase{ |
| PreCheck: testAccPreCheckFunc(t, uri), |
| LogicalBackend: b, |
| Steps: []logicaltest.TestStep{ |
| testAccStepConfig(t, uri, "testpolicy"), |
| testAccStepRole(t), |
| testAccStepReadCreds(t, b, uri, roleName), |
| }, |
| }) |
| } |
| |
// testAccPreCheckFunc builds the pre-check callback for a test case. When
// invoked, the callback fails the test via t.Fatal if uri is empty and does
// nothing otherwise.
func testAccPreCheckFunc(t *testing.T, uri string) func() {
	preCheck := func() {
		if uri != "" {
			return
		}
		t.Fatal("RabbitMQ URI must be set for acceptance tests")
	}
	return preCheck
}
| |
| func testAccStepConfig(t *testing.T, uri string, passwordPolicy string) logicaltest.TestStep { |
| username := os.Getenv(envRabbitMQUsername) |
| if len(username) == 0 { |
| username = "guest" |
| } |
| password := os.Getenv(envRabbitMQPassword) |
| if len(password) == 0 { |
| password = "guest" |
| } |
| |
| return logicaltest.TestStep{ |
| Operation: logical.UpdateOperation, |
| Path: "config/connection", |
| Data: map[string]interface{}{ |
| "connection_uri": uri, |
| "username": username, |
| "password": password, |
| "password_policy": passwordPolicy, |
| }, |
| } |
| } |
| |
| func testAccStepRole(t *testing.T) logicaltest.TestStep { |
| return logicaltest.TestStep{ |
| Operation: logical.UpdateOperation, |
| Path: fmt.Sprintf("roles/%s", roleName), |
| Data: map[string]interface{}{ |
| "tags": testTags, |
| "vhosts": testVHosts, |
| "vhost_topics": testVHostTopics, |
| }, |
| } |
| } |
| |
| func testAccStepDeleteRole(t *testing.T, n string) logicaltest.TestStep { |
| return logicaltest.TestStep{ |
| Operation: logical.DeleteOperation, |
| Path: "roles/" + n, |
| } |
| } |
| |
| func testAccStepReadCreds(t *testing.T, b logical.Backend, uri, name string) logicaltest.TestStep { |
| return logicaltest.TestStep{ |
| Operation: logical.ReadOperation, |
| Path: "creds/" + name, |
| Check: func(resp *logical.Response) error { |
| var d struct { |
| Username string `mapstructure:"username"` |
| Password string `mapstructure:"password"` |
| } |
| if err := mapstructure.Decode(resp.Data, &d); err != nil { |
| return err |
| } |
| log.Printf("[WARN] Generated credentials: %v", d) |
| |
| client, err := rabbithole.NewClient(uri, d.Username, d.Password) |
| if err != nil { |
| t.Fatal(err) |
| } |
| |
| _, err = client.ListVhosts() |
| if err != nil { |
| t.Fatalf("unable to list vhosts with generated credentials: %s", err) |
| } |
| |
| resp, err = b.HandleRequest(context.Background(), &logical.Request{ |
| Operation: logical.RevokeOperation, |
| Secret: &logical.Secret{ |
| InternalData: map[string]interface{}{ |
| "secret_type": "creds", |
| "username": d.Username, |
| }, |
| }, |
| }) |
| if err != nil { |
| return err |
| } |
| if resp != nil { |
| if resp.IsError() { |
| return fmt.Errorf("error on resp: %#v", *resp) |
| } |
| } |
| |
| client, err = rabbithole.NewClient(uri, d.Username, d.Password) |
| if err != nil { |
| t.Fatal(err) |
| } |
| |
| _, err = client.ListVhosts() |
| if err == nil { |
| t.Fatalf("expected to fail listing vhosts: %s", err) |
| } |
| |
| return nil |
| }, |
| } |
| } |
| |
| func testAccStepReadRole(t *testing.T, name, tags, rawVHosts string, rawVHostTopics string) logicaltest.TestStep { |
| return logicaltest.TestStep{ |
| Operation: logical.ReadOperation, |
| Path: "roles/" + name, |
| Check: func(resp *logical.Response) error { |
| if resp == nil { |
| if tags == "" && rawVHosts == "" && rawVHostTopics == "" { |
| return nil |
| } |
| |
| return fmt.Errorf("bad: %#v", resp) |
| } |
| |
| var d struct { |
| Tags string `mapstructure:"tags"` |
| VHosts map[string]vhostPermission `mapstructure:"vhosts"` |
| VHostTopics map[string]map[string]vhostTopicPermission `mapstructure:"vhost_topics"` |
| } |
| if err := mapstructure.Decode(resp.Data, &d); err != nil { |
| return err |
| } |
| |
| if d.Tags != tags { |
| return fmt.Errorf("bad: %#v", resp) |
| } |
| |
| var vhosts map[string]vhostPermission |
| if err := jsonutil.DecodeJSON([]byte(rawVHosts), &vhosts); err != nil { |
| return fmt.Errorf("bad expected vhosts %#v: %s", vhosts, err) |
| } |
| |
| for host, permission := range vhosts { |
| actualPermission, ok := d.VHosts[host] |
| if !ok { |
| return fmt.Errorf("expected vhost: %s", host) |
| } |
| |
| if actualPermission.Configure != permission.Configure { |
| return fmt.Errorf("expected permission %s to be %s, got %s", "configure", permission.Configure, actualPermission.Configure) |
| } |
| |
| if actualPermission.Write != permission.Write { |
| return fmt.Errorf("expected permission %s to be %s, got %s", "write", permission.Write, actualPermission.Write) |
| } |
| |
| if actualPermission.Read != permission.Read { |
| return fmt.Errorf("expected permission %s to be %s, got %s", "read", permission.Read, actualPermission.Read) |
| } |
| } |
| |
| var vhostTopics map[string]map[string]vhostTopicPermission |
| if err := jsonutil.DecodeJSON([]byte(rawVHostTopics), &vhostTopics); err != nil { |
| return fmt.Errorf("bad expected vhostTopics %#v: %s", vhostTopics, err) |
| } |
| |
| for host, permissions := range vhostTopics { |
| for exchange, permission := range permissions { |
| actualPermissions, ok := d.VHostTopics[host] |
| if !ok { |
| return fmt.Errorf("expected vhost topics: %s", host) |
| } |
| |
| actualPermission, ok := actualPermissions[exchange] |
| if !ok { |
| return fmt.Errorf("expected vhost topic exchange: %s", exchange) |
| } |
| |
| if actualPermission.Write != permission.Write { |
| return fmt.Errorf("expected permission %s to be %s, got %s", "write", permission.Write, actualPermission.Write) |
| } |
| |
| if actualPermission.Read != permission.Read { |
| return fmt.Errorf("expected permission %s to be %s, got %s", "read", permission.Read, actualPermission.Read) |
| } |
| } |
| } |
| |
| return nil |
| }, |
| } |
| } |