blob: 8cf06062890f5cf2b306a09d35d9c2ed8ff3cfc7 [file] [log] [blame]
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package database
import (
"context"
"strings"
"testing"
"github.com/hashicorp/vault/helper/namespace"
"github.com/hashicorp/vault/helper/versions"
"github.com/hashicorp/vault/sdk/helper/consts"
"github.com/hashicorp/vault/sdk/logical"
)
// TestWriteConfig_PluginVersionInStorage checks how a builtin plugin version
// is handled on the "config/plugin-test" connection:
//   - a write that requests the builtin version reads back an empty
//     plugin_version through the API;
//   - a storage entry that still carries the builtin version (written directly,
//     simulating a write made before upgrading to 1.12.2+) also reads back
//     empty through the API;
//   - the next API write removes the builtin version from the storage entry.
func TestWriteConfig_PluginVersionInStorage(t *testing.T) {
	cluster, sys := getCluster(t)
	t.Cleanup(cluster.Cleanup)

	config := logical.TestBackendConfig()
	config.StorageView = &logical.InmemStorage{}
	config.System = sys

	b, err := Factory(context.Background(), config)
	if err != nil {
		t.Fatal(err)
	}
	defer b.Cleanup(context.Background())

	const hdb = "hana-database-plugin"
	hdbBuiltin := versions.GetBuiltinVersion(consts.PluginTypeDatabase, hdb)

	// writePluginVersion configures the "plugin-test" connection through the
	// API, explicitly requesting the builtin version of the HANA plugin.
	writePluginVersion := func() {
		t.Helper()
		req := &logical.Request{
			Operation: logical.UpdateOperation,
			Path:      "config/plugin-test",
			Storage:   config.StorageView,
			Data: map[string]interface{}{
				"connection_url":    "test",
				"plugin_name":       hdb,
				"plugin_version":    hdbBuiltin,
				"verify_connection": false,
			},
		}
		// Pass a real Context rather than nil; the context package forbids nil.
		resp, err := b.HandleRequest(namespace.RootContext(context.Background()), req)
		if err != nil || (resp != nil && resp.IsError()) {
			t.Fatalf("err:%v resp:%#v", err, resp)
		}
	}
	writePluginVersion()

	// getPluginVersionFromAPI reads the connection config back through the API
	// and returns the plugin_version field of the response.
	getPluginVersionFromAPI := func() string {
		t.Helper()
		req := &logical.Request{
			Operation: logical.ReadOperation,
			Path:      "config/plugin-test",
			Storage:   config.StorageView,
		}
		resp, err := b.HandleRequest(namespace.RootContext(context.Background()), req)
		if err != nil || (resp != nil && resp.IsError()) {
			t.Fatalf("err:%v resp:%#v", err, resp)
		}
		// Fail with a message instead of panicking on a nil response or a
		// missing / non-string plugin_version field.
		if resp == nil {
			t.Fatal("expected a response from config read, got nil")
		}
		version, ok := resp.Data["plugin_version"].(string)
		if !ok {
			t.Fatalf("expected string plugin_version in response data, got %#v", resp.Data["plugin_version"])
		}
		return version
	}

	pluginVersion := getPluginVersionFromAPI()
	if pluginVersion != "" {
		t.Fatalf("expected plugin_version empty but got %s", pluginVersion)
	}

	// Directly store config to get the builtin plugin version into storage,
	// simulating a write that happened before upgrading to 1.12.2+
	err = storeConfig(context.Background(), config.StorageView, "plugin-test", &DatabaseConfig{
		PluginName:    hdb,
		PluginVersion: hdbBuiltin,
	})
	if err != nil {
		t.Fatal(err)
	}

	// Now replay the read request, and we still shouldn't get the builtin version back.
	pluginVersion = getPluginVersionFromAPI()
	if pluginVersion != "" {
		t.Fatalf("expected plugin_version empty but got %s", pluginVersion)
	}

	// getPluginVersionFromStorage decodes the raw "config/plugin-test" storage
	// entry and returns the plugin version recorded in it.
	getPluginVersionFromStorage := func() string {
		t.Helper()
		entry, err := config.StorageView.Get(context.Background(), "config/plugin-test")
		if err != nil {
			t.Fatal(err)
		}
		if entry == nil {
			t.Fatal("expected storage entry at config/plugin-test, got nil")
		}
		// Named dbConfig so it does not shadow the outer backend config.
		var dbConfig DatabaseConfig
		if err := entry.DecodeJSON(&dbConfig); err != nil {
			t.Fatal(err)
		}
		return dbConfig.PluginVersion
	}

	// Check the underlying data, which should still have the version in storage.
	storagePluginVersion := getPluginVersionFromStorage()
	if storagePluginVersion != hdbBuiltin {
		t.Fatalf("Expected %s, got: %s", hdbBuiltin, storagePluginVersion)
	}

	// Trigger a write to storage, which should clean up plugin version in the storage entry.
	writePluginVersion()
	storagePluginVersion = getPluginVersionFromStorage()
	if storagePluginVersion != "" {
		t.Fatalf("Expected empty, got: %s", storagePluginVersion)
	}

	// Finally, confirm API requests still return empty plugin version too
	pluginVersion = getPluginVersionFromAPI()
	if pluginVersion != "" {
		t.Fatalf("expected plugin_version empty but got %s", pluginVersion)
	}
}
// TestWriteConfig_HelpfulErrorMessageWhenBuiltinOverridden writes a connection
// config that requests the builtin PostgreSQL plugin version with
// verify_connection disabled. It asserts that the request returns an error
// response, not a transport error, and that the message contains
// "overridden by an unversioned plugin".
//
// NOTE(review): the override itself appears to come from the environment set
// up by getCluster, which is not visible here; confirm against that helper.
func TestWriteConfig_HelpfulErrorMessageWhenBuiltinOverridden(t *testing.T) {
	cluster, sys := getCluster(t)
	t.Cleanup(cluster.Cleanup)

	config := logical.TestBackendConfig()
	config.StorageView = &logical.InmemStorage{}
	config.System = sys

	b, err := Factory(context.Background(), config)
	if err != nil {
		t.Fatal(err)
	}
	defer b.Cleanup(context.Background())

	const pg = "postgresql-database-plugin"
	pgBuiltin := versions.GetBuiltinVersion(consts.PluginTypeDatabase, pg)

	// Configure a connection
	data := map[string]interface{}{
		"connection_url":    "test",
		"plugin_name":       pg,
		"plugin_version":    pgBuiltin,
		"verify_connection": false,
	}
	req := &logical.Request{
		Operation: logical.UpdateOperation,
		Path:      "config/plugin-test",
		Storage:   config.StorageView,
		Data:      data,
	}
	// Pass a real Context rather than nil; the context package forbids nil.
	resp, err := b.HandleRequest(namespace.RootContext(context.Background()), req)
	if err != nil {
		t.Fatal(err)
	}
	if resp == nil || !resp.IsError() {
		t.Fatalf("expected an error response, got resp:%#v", resp)
	}
	if !strings.Contains(resp.Error().Error(), "overridden by an unversioned plugin") {
		t.Fatalf("expected overridden error but got: %s", resp.Error())
	}
}