| // Copyright (c) HashiCorp, Inc. |
| // SPDX-License-Identifier: MPL-2.0 |
| |
| package shamir |
| |
| import ( |
| "bytes" |
| "testing" |
| ) |
| |
| func TestSplit_invalid(t *testing.T) { |
| secret := []byte("test") |
| |
| if _, err := Split(secret, 0, 0); err == nil { |
| t.Fatalf("expect error") |
| } |
| |
| if _, err := Split(secret, 2, 3); err == nil { |
| t.Fatalf("expect error") |
| } |
| |
| if _, err := Split(secret, 1000, 3); err == nil { |
| t.Fatalf("expect error") |
| } |
| |
| if _, err := Split(secret, 10, 1); err == nil { |
| t.Fatalf("expect error") |
| } |
| |
| if _, err := Split(nil, 3, 2); err == nil { |
| t.Fatalf("expect error") |
| } |
| } |
| |
| func TestSplit(t *testing.T) { |
| secret := []byte("test") |
| |
| out, err := Split(secret, 5, 3) |
| if err != nil { |
| t.Fatalf("err: %v", err) |
| } |
| |
| if len(out) != 5 { |
| t.Fatalf("bad: %v", out) |
| } |
| |
| for _, share := range out { |
| if len(share) != len(secret)+1 { |
| t.Fatalf("bad: %v", out) |
| } |
| } |
| } |
| |
| func TestCombine_invalid(t *testing.T) { |
| // Not enough parts |
| if _, err := Combine(nil); err == nil { |
| t.Fatalf("should err") |
| } |
| |
| // Mis-match in length |
| parts := [][]byte{ |
| []byte("foo"), |
| []byte("ba"), |
| } |
| if _, err := Combine(parts); err == nil { |
| t.Fatalf("should err") |
| } |
| |
| // Too short |
| parts = [][]byte{ |
| []byte("f"), |
| []byte("b"), |
| } |
| if _, err := Combine(parts); err == nil { |
| t.Fatalf("should err") |
| } |
| |
| parts = [][]byte{ |
| []byte("foo"), |
| []byte("foo"), |
| } |
| if _, err := Combine(parts); err == nil { |
| t.Fatalf("should err") |
| } |
| } |
| |
| func TestCombine(t *testing.T) { |
| secret := []byte("test") |
| |
| out, err := Split(secret, 5, 3) |
| if err != nil { |
| t.Fatalf("err: %v", err) |
| } |
| |
| // There is 5*4*3 possible choices, |
| // we will just brute force try them all |
| for i := 0; i < 5; i++ { |
| for j := 0; j < 5; j++ { |
| if j == i { |
| continue |
| } |
| for k := 0; k < 5; k++ { |
| if k == i || k == j { |
| continue |
| } |
| parts := [][]byte{out[i], out[j], out[k]} |
| recomb, err := Combine(parts) |
| if err != nil { |
| t.Fatalf("err: %v", err) |
| } |
| |
| if !bytes.Equal(recomb, secret) { |
| t.Errorf("parts: (i:%d, j:%d, k:%d) %v", i, j, k, parts) |
| t.Fatalf("bad: %v %v", recomb, secret) |
| } |
| } |
| } |
| } |
| } |
| |
| func TestField_Add(t *testing.T) { |
| if out := add(16, 16); out != 0 { |
| t.Fatalf("Bad: %v 16", out) |
| } |
| |
| if out := add(3, 4); out != 7 { |
| t.Fatalf("Bad: %v 7", out) |
| } |
| } |
| |
| func TestField_Mult(t *testing.T) { |
| if out := mult(3, 7); out != 9 { |
| t.Fatalf("Bad: %v 9", out) |
| } |
| |
| if out := mult(3, 0); out != 0 { |
| t.Fatalf("Bad: %v 0", out) |
| } |
| |
| if out := mult(0, 3); out != 0 { |
| t.Fatalf("Bad: %v 0", out) |
| } |
| } |
| |
| func TestField_Divide(t *testing.T) { |
| if out := div(0, 7); out != 0 { |
| t.Fatalf("Bad: %v 0", out) |
| } |
| |
| if out := div(3, 3); out != 1 { |
| t.Fatalf("Bad: %v 1", out) |
| } |
| |
| if out := div(6, 3); out != 2 { |
| t.Fatalf("Bad: %v 2", out) |
| } |
| } |
| |
| func TestPolynomial_Random(t *testing.T) { |
| p, err := makePolynomial(42, 2) |
| if err != nil { |
| t.Fatalf("err: %v", err) |
| } |
| |
| if p.coefficients[0] != 42 { |
| t.Fatalf("bad: %v", p.coefficients) |
| } |
| } |
| |
| func TestPolynomial_Eval(t *testing.T) { |
| p, err := makePolynomial(42, 1) |
| if err != nil { |
| t.Fatalf("err: %v", err) |
| } |
| |
| if out := p.evaluate(0); out != 42 { |
| t.Fatalf("bad: %v", out) |
| } |
| |
| out := p.evaluate(1) |
| exp := add(42, mult(1, p.coefficients[1])) |
| if out != exp { |
| t.Fatalf("bad: %v %v %v", out, exp, p.coefficients) |
| } |
| } |
| |
| func TestInterpolate_Rand(t *testing.T) { |
| for i := 0; i < 256; i++ { |
| p, err := makePolynomial(uint8(i), 2) |
| if err != nil { |
| t.Fatalf("err: %v", err) |
| } |
| |
| x_vals := []uint8{1, 2, 3} |
| y_vals := []uint8{p.evaluate(1), p.evaluate(2), p.evaluate(3)} |
| out := interpolatePolynomial(x_vals, y_vals, 0) |
| if out != uint8(i) { |
| t.Fatalf("Bad: %v %d", out, i) |
| } |
| } |
| } |