| // Copyright (c) HashiCorp, Inc. |
| // SPDX-License-Identifier: MPL-2.0 |
| |
| package database |
| |
| import ( |
| "context" |
| "strings" |
| "testing" |
| |
| "github.com/hashicorp/vault/helper/namespace" |
| "github.com/hashicorp/vault/helper/versions" |
| "github.com/hashicorp/vault/sdk/helper/consts" |
| "github.com/hashicorp/vault/sdk/logical" |
| ) |
| |
| func TestWriteConfig_PluginVersionInStorage(t *testing.T) { |
| cluster, sys := getCluster(t) |
| t.Cleanup(cluster.Cleanup) |
| |
| config := logical.TestBackendConfig() |
| config.StorageView = &logical.InmemStorage{} |
| config.System = sys |
| |
| b, err := Factory(context.Background(), config) |
| if err != nil { |
| t.Fatal(err) |
| } |
| defer b.Cleanup(context.Background()) |
| |
| const hdb = "hana-database-plugin" |
| hdbBuiltin := versions.GetBuiltinVersion(consts.PluginTypeDatabase, hdb) |
| |
| // Configure a connection |
| writePluginVersion := func() { |
| t.Helper() |
| req := &logical.Request{ |
| Operation: logical.UpdateOperation, |
| Path: "config/plugin-test", |
| Storage: config.StorageView, |
| Data: map[string]interface{}{ |
| "connection_url": "test", |
| "plugin_name": hdb, |
| "plugin_version": hdbBuiltin, |
| "verify_connection": false, |
| }, |
| } |
| resp, err := b.HandleRequest(namespace.RootContext(nil), req) |
| if err != nil || (resp != nil && resp.IsError()) { |
| t.Fatalf("err:%s resp:%#v\n", err, resp) |
| } |
| } |
| writePluginVersion() |
| |
| getPluginVersionFromAPI := func() string { |
| t.Helper() |
| req := &logical.Request{ |
| Operation: logical.ReadOperation, |
| Path: "config/plugin-test", |
| Storage: config.StorageView, |
| } |
| |
| resp, err := b.HandleRequest(namespace.RootContext(nil), req) |
| if err != nil || (resp != nil && resp.IsError()) { |
| t.Fatalf("err:%s resp:%#v\n", err, resp) |
| } |
| |
| return resp.Data["plugin_version"].(string) |
| } |
| pluginVersion := getPluginVersionFromAPI() |
| if pluginVersion != "" { |
| t.Fatalf("expected plugin_version empty but got %s", pluginVersion) |
| } |
| |
| // Directly store config to get the builtin plugin version into storage, |
| // simulating a write that happened before upgrading to 1.12.2+ |
| err = storeConfig(context.Background(), config.StorageView, "plugin-test", &DatabaseConfig{ |
| PluginName: hdb, |
| PluginVersion: hdbBuiltin, |
| }) |
| if err != nil { |
| t.Fatal(err) |
| } |
| |
| // Now replay the read request, and we still shouldn't get the builtin version back. |
| pluginVersion = getPluginVersionFromAPI() |
| if pluginVersion != "" { |
| t.Fatalf("expected plugin_version empty but got %s", pluginVersion) |
| } |
| |
| // Check the underlying data, which should still have the version in storage. |
| getPluginVersionFromStorage := func() string { |
| t.Helper() |
| entry, err := config.StorageView.Get(context.Background(), "config/plugin-test") |
| if err != nil { |
| t.Fatal(err) |
| } |
| if entry == nil { |
| t.Fatal() |
| } |
| |
| var config DatabaseConfig |
| if err := entry.DecodeJSON(&config); err != nil { |
| t.Fatal(err) |
| } |
| return config.PluginVersion |
| } |
| |
| storagePluginVersion := getPluginVersionFromStorage() |
| if storagePluginVersion != hdbBuiltin { |
| t.Fatalf("Expected %s, got: %s", hdbBuiltin, storagePluginVersion) |
| } |
| |
| // Trigger a write to storage, which should clean up plugin version in the storage entry. |
| writePluginVersion() |
| |
| storagePluginVersion = getPluginVersionFromStorage() |
| if storagePluginVersion != "" { |
| t.Fatalf("Expected empty, got: %s", storagePluginVersion) |
| } |
| |
| // Finally, confirm API requests still return empty plugin version too |
| pluginVersion = getPluginVersionFromAPI() |
| if pluginVersion != "" { |
| t.Fatalf("expected plugin_version empty but got %s", pluginVersion) |
| } |
| } |
| |
| func TestWriteConfig_HelpfulErrorMessageWhenBuiltinOverridden(t *testing.T) { |
| cluster, sys := getCluster(t) |
| t.Cleanup(cluster.Cleanup) |
| |
| config := logical.TestBackendConfig() |
| config.StorageView = &logical.InmemStorage{} |
| config.System = sys |
| |
| b, err := Factory(context.Background(), config) |
| if err != nil { |
| t.Fatal(err) |
| } |
| defer b.Cleanup(context.Background()) |
| |
| const pg = "postgresql-database-plugin" |
| pgBuiltin := versions.GetBuiltinVersion(consts.PluginTypeDatabase, pg) |
| |
| // Configure a connection |
| data := map[string]interface{}{ |
| "connection_url": "test", |
| "plugin_name": pg, |
| "plugin_version": pgBuiltin, |
| "verify_connection": false, |
| } |
| req := &logical.Request{ |
| Operation: logical.UpdateOperation, |
| Path: "config/plugin-test", |
| Storage: config.StorageView, |
| Data: data, |
| } |
| resp, err := b.HandleRequest(namespace.RootContext(nil), req) |
| if err != nil { |
| t.Fatal(err) |
| } |
| if resp == nil || !resp.IsError() { |
| t.Fatalf("resp:%#v", resp) |
| } |
| if !strings.Contains(resp.Error().Error(), "overridden by an unversioned plugin") { |
| t.Fatalf("expected overridden error but got: %s", resp.Error()) |
| } |
| } |