blob: f4d555c5c2ca27943c4405aac753819f77f160cf [file] [log] [blame] [edit]
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package server
import (
"bytes"
"crypto/tls"
"io"
"net"
"testing"
)
// testListenerConnFn dials the supplied listener and hands back the resulting
// client-side connection, or the error that prevented it from being opened.
type testListenerConnFn func(ln net.Listener) (conn net.Conn, err error)
// testListenerImpl drives a single connection through ln and checks that data
// written by the client arrives at the accepted server connection.
//
// connFn is called with ln to obtain the client connection. When certName is
// non-empty both ends are treated as *tls.Conn: the server side performs its
// handshake, and the client side must present exactly one peer certificate
// whose subject common name equals certName (and, if expectedVersion is
// non-zero, must have negotiated that TLS version). When expectedAddr is
// non-empty the host part of the accepted connection's remote address must
// equal it. With expectError false the server must read exactly "foo"; with
// expectError true the test fails if "foo" does arrive, and an Accept or
// server handshake failure is tolerated.
func testListenerImpl(t *testing.T, ln net.Listener, connFn testListenerConnFn, certName string, expectedVersion uint16, expectedAddr string, expectError bool) {
	// acceptResult carries everything the accept goroutine observed back to
	// the test goroutine. The goroutine never calls methods on t itself, so it
	// cannot report after the test has already finished.
	type acceptResult struct {
		conn     net.Conn // nil when Accept failed
		failures []string // messages to report via t.Error
	}

	resultCh := make(chan acceptResult, 1)
	go func() {
		var res acceptResult
		// Deliver exactly one result on every path; the channel is buffered
		// so this never blocks even if the test goroutine has bailed out.
		defer func() { resultCh <- res }()

		conn, err := ln.Accept()
		if err != nil {
			if !expectError {
				res.failures = append(res.failures, "err: "+err.Error())
			}
			return
		}
		res.conn = conn

		if certName != "" {
			tlsConn, ok := conn.(*tls.Conn)
			if !ok {
				res.failures = append(res.failures, "err: accepted connection is not a *tls.Conn")
			} else if err := tlsConn.Handshake(); err != nil && !expectError {
				// A handshake failure is only tolerated when the caller
				// expects the exchange to fail.
				res.failures = append(res.failures, "err: server TLS handshake: "+err.Error())
			}
		}

		if expectedAddr == "" {
			return
		}
		addr, _, err := net.SplitHostPort(conn.RemoteAddr().String())
		if err != nil {
			res.failures = append(res.failures, err.Error())
			return
		}
		if addr != expectedAddr {
			res.failures = append(res.failures, "expected: "+expectedAddr+", got: "+addr)
		}
	}()

	client, err := connFn(ln)
	if err != nil {
		t.Fatalf("err: %s", err)
	}
	// Registered immediately so the client is released on every exit path,
	// including the t.Fatalf calls below.
	defer client.Close()

	if certName != "" {
		tlsConn, ok := client.(*tls.Conn)
		if !ok {
			t.Fatalf("err: client connection is %T, expected *tls.Conn", client)
		}
		state := tlsConn.ConnectionState()
		if expectedVersion != 0 && state.Version != expectedVersion {
			t.Fatalf("expected version %d, got %d", expectedVersion, state.Version)
		}
		if n := len(state.PeerCertificates); n != 1 {
			t.Fatalf("err: expected exactly 1 peer certificate, got %d", n)
		}
		peerName := state.PeerCertificates[0].Subject.CommonName
		if peerName != certName {
			t.Fatalf("err: bad cert name %s, expected %s", peerName, certName)
		}
	}

	res := <-resultCh
	for _, msg := range res.failures {
		t.Error(msg)
	}
	server := res.conn
	if server == nil {
		if !expectError {
			// Something failed already so we abort the test early
			t.Fatal("aborting test because the server did not accept the connection")
		}
		return
	}
	defer server.Close()

	var buf bytes.Buffer
	copyCh := make(chan struct{})
	go func() {
		defer close(copyCh)
		// The copy error is deliberately not asserted: the outcome is judged
		// solely by what ended up in buf, checked below.
		_, _ = io.Copy(&buf, server)
	}()

	if _, err := client.Write([]byte("foo")); err != nil {
		t.Fatalf("err: %s", err)
	}
	// Closing the client ends the stream so io.Copy on the server returns.
	client.Close()
	<-copyCh

	// Fail when delivery of "foo" does not match the expectation: it must
	// arrive when no error is expected and must not arrive otherwise.
	got := buf.String()
	if (got == "foo") == expectError {
		t.Fatalf("bad: %q, expectError: %t", got, expectError)
	}
}
// TestProfilingUnauthenticatedInFlightAccess loads the
// unauth_in_flight_access.hcl fixture and checks that the first listener's
// in-flight request logging settings have UnauthenticatedInFlightAccess set.
func TestProfilingUnauthenticatedInFlightAccess(t *testing.T) {
	config, err := LoadConfigFile("./test-fixtures/unauth_in_flight_access.hcl")
	if err != nil {
		t.Fatalf("Error encountered when loading config %+v", err)
	}
	// Guard the index below so a fixture that yields no listeners produces a
	// clear test failure instead of an index-out-of-range panic.
	if len(config.Listeners) == 0 {
		t.Fatalf("expected at least one listener in config, got none")
	}
	if !config.Listeners[0].InFlightRequestLogging.UnauthenticatedInFlightAccess {
		t.Fatalf("failed to read UnauthenticatedInFlightAccess")
	}
}