blob: a9ee8079557be3852a050e8559d6010b42372347 [file] [log] [blame]
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package rpcapi
import (
"context"
"strings"
"sync"
"testing"
"github.com/hashicorp/terraform/internal/rpcapi/terraform1/setup"
)
// TestSetupServer_Handshake checks that the first Handshake call runs the
// initOthers callback exactly once with the request's credentials, and that a
// second Handshake is rejected with a "handshake already completed" error
// without running the callback again.
func TestSetupServer_Handshake(t *testing.T) {
	called := 0
	server := newSetupServer(func(ctx context.Context, req *setup.Handshake_Request, stopper *stopper) (*setup.ServerCapabilities, error) {
		called++
		// The credentials sent in the handshake request should reach initOthers.
		// NOTE(review): t.Fatalf here assumes Handshake invokes initOthers on the
		// test goroutine (as the testing package requires) — confirm if that changes.
		if got, want := req.Config.Credentials["localterraform.com"].Token, "boop"; got != want {
			t.Fatalf("incorrect token. got %q, want %q", got, want)
		}
		return &setup.ServerCapabilities{}, nil
	})

	req := &setup.Handshake_Request{
		Capabilities: &setup.ClientCapabilities{},
		Config: &setup.Config{
			Credentials: map[string]*setup.HostCredential{
				"localterraform.com": {
					Token: "boop",
				},
			},
		},
	}

	_, err := server.Handshake(context.Background(), req)
	if err != nil {
		t.Fatalf("unexpected error: %s", err)
	}
	if called != 1 {
		t.Errorf("unexpected initOthers call count %d, want 1", called)
	}

	// A repeated handshake must fail and must not re-run initOthers. The nil
	// case is reported separately so the failure message stays meaningful
	// instead of formatting a nil error.
	_, err = server.Handshake(context.Background(), req)
	if err == nil {
		t.Fatal("second handshake succeeded; want \"handshake already completed\" error")
	}
	if !strings.Contains(err.Error(), "handshake already completed") {
		t.Fatalf("unexpected error: %s", err)
	}
	if called != 1 {
		t.Errorf("unexpected initOthers call count %d, want 1", called)
	}
}
// TestSetupServer_Stop checks that calling Stop on the setup server signals
// every stop channel registered on the stopper that Handshake passed to the
// initOthers callback.
func TestSetupServer_Stop(t *testing.T) {
	var s *stopper
	server := newSetupServer(func(ctx context.Context, req *setup.Handshake_Request, stopper *stopper) (*setup.ServerCapabilities, error) {
		// Capture the stopper so the test can register stop channels on it.
		s = stopper
		return &setup.ServerCapabilities{}, nil
	})

	// This initOthers callback ignores its request, so a nil request is passed.
	_, err := server.Handshake(context.Background(), nil)
	if err != nil {
		t.Fatalf("unexpected error: %s", err)
	}
	if s == nil {
		t.Fatal("stopper not passed to initOthers")
	}

	// Register two stop channels, each waited on by its own goroutine; wg.Wait
	// returns only once every channel has been signalled.
	// NOTE(review): if Stop fails to signal a channel, wg.Wait blocks until the
	// go test timeout rather than failing immediately.
	var wg sync.WaitGroup
	for range 2 {
		stop := s.add()
		wg.Add(1)
		// Go 1.22+ loop semantics (this file already uses range-over-int), so
		// the closure safely captures this iteration's stop channel.
		go func() {
			defer wg.Done()
			<-stop
		}()
	}

	if _, err := server.Stop(context.Background(), nil); err != nil {
		t.Fatalf("unexpected error from Stop: %s", err)
	}
	wg.Wait()
}