blob: 1abb958d34431b4ada39e23671b6c0488b24069c [file] [log] [blame]
package cose
import (
"crypto/rand"
"crypto/rsa"
"reflect"
"testing"
)
// generateTestRSAKey returns a freshly generated 2048-bit RSA private key
// drawn from crypto/rand. If key generation fails, the calling test is
// aborted via t.Fatalf.
func generateTestRSAKey(t *testing.T) *rsa.PrivateKey {
	// Mark this function as a helper so that a failure is attributed to the
	// line of the calling test rather than to this function.
	t.Helper()

	key, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		t.Fatalf("rsa.GenerateKey() error = %v", err)
	}
	return key
}
// Test_rsaSigner checks that NewSigner yields an *rsaSigner reporting the
// requested algorithm for an RSA private key, and that a signature it
// produces is accepted by a verifier built from the matching public key.
func Test_rsaSigner(t *testing.T) {
	wantAlg := AlgorithmPS256
	privateKey := generateTestRSAKey(t)

	// Construct the signer and inspect its concrete type and algorithm.
	signer, newSignerErr := NewSigner(wantAlg, privateKey)
	if newSignerErr != nil {
		t.Fatalf("NewSigner() error = %v", newSignerErr)
	}
	switch signer.(type) {
	case *rsaSigner:
		// expected concrete type
	default:
		t.Fatalf("NewSigner() type = %v, want *rsaSigner", reflect.TypeOf(signer))
	}
	gotAlg := signer.Algorithm()
	if gotAlg != wantAlg {
		t.Fatalf("Algorithm() = %v, want %v", gotAlg, wantAlg)
	}

	// Sign / verify round trip.
	// conformance_test.go holds the strict tests.
	message := []byte("hello world")
	hashed, hashErr := wantAlg.computeHash(message)
	if hashErr != nil {
		t.Fatalf("Algorithm.computeHash() error = %v", hashErr)
	}
	signature, signErr := signer.Sign(rand.Reader, hashed)
	if signErr != nil {
		t.Fatalf("Sign() error = %v", signErr)
	}

	verifier, newVerifierErr := NewVerifier(wantAlg, privateKey.Public())
	if newVerifierErr != nil {
		t.Fatalf("NewVerifier() error = %v", newVerifierErr)
	}
	verifyErr := verifier.Verify(hashed, signature)
	if verifyErr != nil {
		t.Fatalf("Verifier.Verify() error = %v", verifyErr)
	}
}
// Test_rsaVerifier_Verify_Success checks that NewVerifier yields an
// *rsaVerifier reporting the requested algorithm for an RSA public key, and
// that it accepts a signature made with the corresponding private key.
func Test_rsaVerifier_Verify_Success(t *testing.T) {
	wantAlg := AlgorithmPS256
	privateKey := generateTestRSAKey(t)

	// Produce a digest and a signature over it with the private key.
	hashed, signature := signTestData(t, wantAlg, privateKey)

	// Construct the verifier and inspect its concrete type and algorithm.
	verifier, newVerifierErr := NewVerifier(wantAlg, privateKey.Public())
	if newVerifierErr != nil {
		t.Fatalf("NewVerifier() error = %v", newVerifierErr)
	}
	switch verifier.(type) {
	case *rsaVerifier:
		// expected concrete type
	default:
		t.Fatalf("NewVerifier() type = %v, want *rsaVerifier", reflect.TypeOf(verifier))
	}
	gotAlg := verifier.Algorithm()
	if gotAlg != wantAlg {
		t.Fatalf("Algorithm() = %v, want %v", gotAlg, wantAlg)
	}

	// The signature must verify against the digest.
	verifyErr := verifier.Verify(hashed, signature)
	if verifyErr != nil {
		t.Fatalf("rsaVerifier.Verify() error = %v", verifyErr)
	}
}
// Test_rsaVerifier_Verify_AlgorithmMismatch signs test data under PS256 and
// expects an rsaVerifier configured for PS512 (with the matching public key)
// to reject the signature with ErrVerification.
func Test_rsaVerifier_Verify_AlgorithmMismatch(t *testing.T) {
	signingAlg := AlgorithmPS256
	privateKey := generateTestRSAKey(t)

	// Signature produced under the signing algorithm.
	hashed, signature := signTestData(t, signingAlg, privateKey)

	// Verifier deliberately configured with another algorithm.
	mismatched := &rsaVerifier{
		alg: AlgorithmPS512,
		key: &privateKey.PublicKey,
	}

	// Anything other than ErrVerification (including nil) is a failure.
	gotErr := mismatched.Verify(hashed, signature)
	if gotErr == ErrVerification {
		return
	}
	t.Fatalf("rsaVerifier.Verify() error = %v, wantErr %v", gotErr, ErrVerification)
}
// Test_rsaVerifier_Verify_KeyMismatch signs test data with one RSA key and
// expects an rsaVerifier holding the public half of a second, independently
// generated key to reject the signature with ErrVerification.
func Test_rsaVerifier_Verify_KeyMismatch(t *testing.T) {
	wantAlg := AlgorithmPS256

	// Signature produced with the first key.
	signingKey := generateTestRSAKey(t)
	hashed, signature := signTestData(t, wantAlg, signingKey)

	// Verifier bound to an unrelated, freshly generated key.
	unrelatedKey := generateTestRSAKey(t)
	mismatched := &rsaVerifier{
		alg: wantAlg,
		key: &unrelatedKey.PublicKey,
	}

	// Anything other than ErrVerification (including nil) is a failure.
	gotErr := mismatched.Verify(hashed, signature)
	if gotErr == ErrVerification {
		return
	}
	t.Fatalf("rsaVerifier.Verify() error = %v, wantErr %v", gotErr, ErrVerification)
}
// Test_rsaVerifier_Verify_InvalidSignature checks that an rsaVerifier
// configured with the correct algorithm and public key returns
// ErrVerification for a range of malformed signatures: nil, empty,
// truncated, tampered, and over-long.
func Test_rsaVerifier_Verify_InvalidSignature(t *testing.T) {
	// generate key
	alg := AlgorithmPS256
	key := generateTestRSAKey(t)

	// generate a valid signature, then derive a tampered copy of it
	digest, sig := signTestData(t, alg, key)
	tamperedSig := make([]byte, len(sig))
	copy(tamperedSig, sig)
	tamperedSig[0]++ // alter the first byte so the copy differs from sig

	// set up verifier with the matching algorithm and key, so that the
	// signature bytes are the only thing wrong in each case below
	verifier := &rsaVerifier{
		alg: alg,
		key: &key.PublicKey,
	}

	// verification should fail on invalid signature
	tests := []struct {
		name      string
		signature []byte
	}{
		{
			name:      "nil signature",
			signature: nil,
		},
		{
			name:      "empty signature",
			signature: []byte{},
		},
		{
			name:      "incomplete signature",
			signature: sig[:len(sig)-2],
		},
		{
			name:      "tampered signature",
			signature: tamperedSig,
		},
		{
			name: "too many signature bytes",
			// The full slice expression limits capacity to len(sig), so
			// append must allocate a new array instead of writing into
			// spare capacity of the array shared with the cases above.
			signature: append(sig[:len(sig):len(sig)], 0),
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if err := verifier.Verify(digest, tt.signature); err != ErrVerification {
				t.Errorf("rsaVerifier.Verify() error = %v, wantErr %v", err, ErrVerification)
			}
		})
	}
}