| // Copyright (c) 2017-2022 Snowflake Computing Inc. All rights reserved. |
| |
| package ocsp |
| |
| import ( |
| "bytes" |
| "context" |
| "crypto" |
| "crypto/tls" |
| "crypto/x509" |
| "errors" |
| "fmt" |
| "io" |
| "io/ioutil" |
| "net" |
| "net/http" |
| "net/url" |
| "testing" |
| "time" |
| |
| "github.com/hashicorp/go-hclog" |
| "github.com/hashicorp/go-retryablehttp" |
| lru "github.com/hashicorp/golang-lru" |
| "golang.org/x/crypto/ocsp" |
| ) |
| |
| func TestOCSP(t *testing.T) { |
| targetURL := []string{ |
| "https://sfcdev1.blob.core.windows.net/", |
| "https://sfctest0.snowflakecomputing.com/", |
| "https://s3-us-west-2.amazonaws.com/sfc-snowsql-updates/?prefix=1.1/windows_x86_64", |
| } |
| |
| conf := VerifyConfig{ |
| OcspFailureMode: FailOpenFalse, |
| } |
| c := New(testLogFactory, 10) |
| transports := []*http.Transport{ |
| newInsecureOcspTransport(nil), |
| c.NewTransport(&conf), |
| } |
| |
| for _, tgt := range targetURL { |
| c.ocspResponseCache, _ = lru.New2Q(10) |
| for _, tr := range transports { |
| c := &http.Client{ |
| Transport: tr, |
| Timeout: 30 * time.Second, |
| } |
| req, err := http.NewRequest("GET", tgt, bytes.NewReader(nil)) |
| if err != nil { |
| t.Fatalf("fail to create a request. err: %v", err) |
| } |
| res, err := c.Do(req) |
| if err != nil { |
| t.Fatalf("failed to GET contents. err: %v", err) |
| } |
| defer res.Body.Close() |
| _, err = ioutil.ReadAll(res.Body) |
| if err != nil { |
| t.Fatalf("failed to read content body for %v", tgt) |
| } |
| |
| } |
| } |
| } |
| |
| /** |
| // Used for development, requires an active Vault with PKI setup |
| func TestMultiOCSP(t *testing.T) { |
| |
| targetURL := []string{ |
| "https://localhost:8200/v1/pki/ocsp", |
| "https://localhost:8200/v1/pki/ocsp", |
| "https://localhost:8200/v1/pki/ocsp", |
| } |
| |
| b, _ := pem.Decode([]byte(vaultCert)) |
| caCert, _ := x509.ParseCertificate(b.Bytes) |
| conf := VerifyConfig{ |
| OcspFailureMode: FailOpenFalse, |
| QueryAllServers: true, |
| OcspServersOverride: targetURL, |
| ExtraCas: []*x509.Certificate{caCert}, |
| } |
| c := New(testLogFactory, 10) |
| transports := []*http.Transport{ |
| newInsecureOcspTransport(conf.ExtraCas), |
| c.NewTransport(&conf), |
| } |
| |
| tgt := "https://localhost:8200/v1/pki/ca/pem" |
| c.ocspResponseCache, _ = lru.New2Q(10) |
| for _, tr := range transports { |
| c := &http.Client{ |
| Transport: tr, |
| Timeout: 30 * time.Second, |
| } |
| req, err := http.NewRequest("GET", tgt, bytes.NewReader(nil)) |
| if err != nil { |
| t.Fatalf("fail to create a request. err: %v", err) |
| } |
| res, err := c.Do(req) |
| if err != nil { |
| t.Fatalf("failed to GET contents. err: %v", err) |
| } |
| defer res.Body.Close() |
| _, err = ioutil.ReadAll(res.Body) |
| if err != nil { |
| t.Fatalf("failed to read content body for %v", tgt) |
| } |
| } |
| } |
| */ |
| |
| func TestUnitEncodeCertIDGood(t *testing.T) { |
| targetURLs := []string{ |
| "faketestaccount.snowflakecomputing.com:443", |
| "s3-us-west-2.amazonaws.com:443", |
| "sfcdev1.blob.core.windows.net:443", |
| } |
| for _, tt := range targetURLs { |
| chainedCerts := getCert(tt) |
| for i := 0; i < len(chainedCerts)-1; i++ { |
| subject := chainedCerts[i] |
| issuer := chainedCerts[i+1] |
| ocspServers := subject.OCSPServer |
| if len(ocspServers) == 0 { |
| t.Fatalf("no OCSP server is found. cert: %v", subject.Subject) |
| } |
| ocspReq, err := ocsp.CreateRequest(subject, issuer, &ocsp.RequestOptions{}) |
| if err != nil { |
| t.Fatalf("failed to create OCSP request. err: %v", err) |
| } |
| var ost *ocspStatus |
| _, ost = extractCertIDKeyFromRequest(ocspReq) |
| if ost.err != nil { |
| t.Fatalf("failed to extract cert ID from the OCSP request. err: %v", ost.err) |
| } |
| // better hash. Not sure if the actual OCSP server accepts this, though. |
| ocspReq, err = ocsp.CreateRequest(subject, issuer, &ocsp.RequestOptions{Hash: crypto.SHA512}) |
| if err != nil { |
| t.Fatalf("failed to create OCSP request. err: %v", err) |
| } |
| _, ost = extractCertIDKeyFromRequest(ocspReq) |
| if ost.err != nil { |
| t.Fatalf("failed to extract cert ID from the OCSP request. err: %v", ost.err) |
| } |
| // tweaked request binary |
| ocspReq, err = ocsp.CreateRequest(subject, issuer, &ocsp.RequestOptions{Hash: crypto.SHA512}) |
| if err != nil { |
| t.Fatalf("failed to create OCSP request. err: %v", err) |
| } |
| ocspReq[10] = 0 // random change |
| _, ost = extractCertIDKeyFromRequest(ocspReq) |
| if ost.err == nil { |
| t.Fatal("should have failed") |
| } |
| } |
| } |
| } |
| |
| func TestUnitCheckOCSPResponseCache(t *testing.T) { |
| c := New(testLogFactory, 10) |
| dummyKey0 := certIDKey{ |
| NameHash: "dummy0", |
| IssuerKeyHash: "dummy0", |
| SerialNumber: "dummy0", |
| } |
| dummyKey := certIDKey{ |
| NameHash: "dummy1", |
| IssuerKeyHash: "dummy1", |
| SerialNumber: "dummy1", |
| } |
| currentTime := float64(time.Now().UTC().Unix()) |
| c.ocspResponseCache.Add(dummyKey0, &ocspCachedResponse{time: currentTime}) |
| subject := &x509.Certificate{} |
| issuer := &x509.Certificate{} |
| ost, err := c.checkOCSPResponseCache(&dummyKey, subject, issuer) |
| if err != nil { |
| t.Fatal(err) |
| } |
| if ost.code != ocspMissedCache { |
| t.Fatalf("should have failed. expected: %v, got: %v", ocspMissedCache, ost.code) |
| } |
| // old timestamp |
| c.ocspResponseCache.Add(dummyKey, &ocspCachedResponse{time: float64(1395054952)}) |
| ost, err = c.checkOCSPResponseCache(&dummyKey, subject, issuer) |
| if err != nil { |
| t.Fatal(err) |
| } |
| if ost.code != ocspCacheExpired { |
| t.Fatalf("should have failed. expected: %v, got: %v", ocspCacheExpired, ost.code) |
| } |
| |
| // invalid validity |
| c.ocspResponseCache.Add(dummyKey, &ocspCachedResponse{time: float64(currentTime - 1000)}) |
| ost, err = c.checkOCSPResponseCache(&dummyKey, subject, nil) |
| if err == nil && isValidOCSPStatus(ost.code) { |
| t.Fatalf("should have failed.") |
| } |
| } |
| |
| func TestUnitValidateOCSP(t *testing.T) { |
| ocspRes := &ocsp.Response{} |
| ost, err := validateOCSP(ocspRes) |
| if err == nil && isValidOCSPStatus(ost.code) { |
| t.Fatalf("should have failed.") |
| } |
| |
| currentTime := time.Now() |
| ocspRes.ThisUpdate = currentTime.Add(-2 * time.Hour) |
| ocspRes.NextUpdate = currentTime.Add(2 * time.Hour) |
| ocspRes.Status = ocsp.Revoked |
| ost, err = validateOCSP(ocspRes) |
| if err != nil { |
| t.Fatal(err) |
| } |
| |
| if ost.code != ocspStatusRevoked { |
| t.Fatalf("should have failed. expected: %v, got: %v", ocspStatusRevoked, ost.code) |
| } |
| ocspRes.Status = ocsp.Good |
| ost, err = validateOCSP(ocspRes) |
| if err != nil { |
| t.Fatal(err) |
| } |
| |
| if ost.code != ocspStatusGood { |
| t.Fatalf("should have success. expected: %v, got: %v", ocspStatusGood, ost.code) |
| } |
| ocspRes.Status = ocsp.Unknown |
| ost, err = validateOCSP(ocspRes) |
| if err != nil { |
| t.Fatal(err) |
| } |
| if ost.code != ocspStatusUnknown { |
| t.Fatalf("should have failed. expected: %v, got: %v", ocspStatusUnknown, ost.code) |
| } |
| ocspRes.Status = ocsp.ServerFailed |
| ost, err = validateOCSP(ocspRes) |
| if err != nil { |
| t.Fatal(err) |
| } |
| if ost.code != ocspStatusOthers { |
| t.Fatalf("should have failed. expected: %v, got: %v", ocspStatusOthers, ost.code) |
| } |
| } |
| |
| func TestUnitEncodeCertID(t *testing.T) { |
| var st *ocspStatus |
| _, st = extractCertIDKeyFromRequest([]byte{0x1, 0x2}) |
| if st.code != ocspFailedDecomposeRequest { |
| t.Fatalf("failed to get OCSP status. expected: %v, got: %v", ocspFailedDecomposeRequest, st.code) |
| } |
| } |
| |
| func getCert(addr string) []*x509.Certificate { |
| tcpConn, err := net.DialTimeout("tcp", addr, 40*time.Second) |
| if err != nil { |
| panic(err) |
| } |
| defer tcpConn.Close() |
| |
| err = tcpConn.SetDeadline(time.Now().Add(10 * time.Second)) |
| if err != nil { |
| panic(err) |
| } |
| config := tls.Config{InsecureSkipVerify: true, ServerName: addr} |
| |
| conn := tls.Client(tcpConn, &config) |
| defer conn.Close() |
| |
| err = conn.Handshake() |
| if err != nil { |
| panic(err) |
| } |
| |
| state := conn.ConnectionState() |
| |
| return state.PeerCertificates |
| } |
| |
| func TestOCSPRetry(t *testing.T) { |
| c := New(testLogFactory, 10) |
| certs := getCert("s3-us-west-2.amazonaws.com:443") |
| dummyOCSPHost := &url.URL{ |
| Scheme: "https", |
| Host: "dummyOCSPHost", |
| } |
| client := &fakeHTTPClient{ |
| cnt: 3, |
| success: true, |
| body: []byte{1, 2, 3}, |
| logger: hclog.New(hclog.DefaultOptions), |
| t: t, |
| } |
| res, b, st, err := c.retryOCSP( |
| context.TODO(), |
| client, fakeRequestFunc, |
| dummyOCSPHost, |
| make(map[string]string), []byte{0}, certs[len(certs)-1]) |
| if err == nil { |
| fmt.Printf("should fail: %v, %v, %v\n", res, b, st) |
| } |
| client = &fakeHTTPClient{ |
| cnt: 30, |
| success: true, |
| body: []byte{1, 2, 3}, |
| logger: hclog.New(hclog.DefaultOptions), |
| t: t, |
| } |
| res, b, st, err = c.retryOCSP( |
| context.TODO(), |
| client, fakeRequestFunc, |
| dummyOCSPHost, |
| make(map[string]string), []byte{0}, certs[len(certs)-1]) |
| if err == nil { |
| fmt.Printf("should fail: %v, %v, %v\n", res, b, st) |
| } |
| } |
| |
| type tcCanEarlyExit struct { |
| results []*ocspStatus |
| resultLen int |
| retFailOpen *ocspStatus |
| retFailClosed *ocspStatus |
| } |
| |
| func TestCanEarlyExitForOCSP(t *testing.T) { |
| testcases := []tcCanEarlyExit{ |
| { // 0 |
| results: []*ocspStatus{ |
| { |
| code: ocspStatusGood, |
| }, |
| { |
| code: ocspStatusGood, |
| }, |
| { |
| code: ocspStatusGood, |
| }, |
| }, |
| retFailOpen: nil, |
| retFailClosed: nil, |
| }, |
| { // 1 |
| results: []*ocspStatus{ |
| { |
| code: ocspStatusRevoked, |
| err: errors.New("revoked"), |
| }, |
| { |
| code: ocspStatusGood, |
| }, |
| { |
| code: ocspStatusGood, |
| }, |
| }, |
| retFailOpen: &ocspStatus{ocspStatusRevoked, errors.New("revoked")}, |
| retFailClosed: &ocspStatus{ocspStatusRevoked, errors.New("revoked")}, |
| }, |
| { // 2 |
| results: []*ocspStatus{ |
| { |
| code: ocspStatusUnknown, |
| err: errors.New("unknown"), |
| }, |
| { |
| code: ocspStatusGood, |
| }, |
| { |
| code: ocspStatusGood, |
| }, |
| }, |
| retFailOpen: nil, |
| retFailClosed: &ocspStatus{ocspStatusUnknown, errors.New("unknown")}, |
| }, |
| { // 3: not taken as revoked if any invalid OCSP response (ocspInvalidValidity) is included. |
| results: []*ocspStatus{ |
| { |
| code: ocspStatusRevoked, |
| err: errors.New("revoked"), |
| }, |
| { |
| code: ocspInvalidValidity, |
| }, |
| { |
| code: ocspStatusGood, |
| }, |
| }, |
| retFailOpen: nil, |
| retFailClosed: &ocspStatus{ocspStatusRevoked, errors.New("revoked")}, |
| }, |
| { // 4: not taken as revoked if the number of results don't match the expected results. |
| results: []*ocspStatus{ |
| { |
| code: ocspStatusRevoked, |
| err: errors.New("revoked"), |
| }, |
| { |
| code: ocspStatusGood, |
| }, |
| }, |
| resultLen: 3, |
| retFailOpen: nil, |
| retFailClosed: &ocspStatus{ocspStatusRevoked, errors.New("revoked")}, |
| }, |
| } |
| c := New(testLogFactory, 10) |
| for idx, tt := range testcases { |
| expectedLen := len(tt.results) |
| if tt.resultLen > 0 { |
| expectedLen = tt.resultLen |
| } |
| r := c.canEarlyExitForOCSP(tt.results, expectedLen, &VerifyConfig{OcspFailureMode: FailOpenTrue}) |
| if !(tt.retFailOpen == nil && r == nil) && !(tt.retFailOpen != nil && r != nil && tt.retFailOpen.code == r.code) { |
| t.Fatalf("%d: failed to match return. expected: %v, got: %v", idx, tt.retFailOpen, r) |
| } |
| r = c.canEarlyExitForOCSP(tt.results, expectedLen, &VerifyConfig{OcspFailureMode: FailOpenFalse}) |
| if !(tt.retFailClosed == nil && r == nil) && !(tt.retFailClosed != nil && r != nil && tt.retFailClosed.code == r.code) { |
| t.Fatalf("%d: failed to match return. expected: %v, got: %v", idx, tt.retFailClosed, r) |
| } |
| } |
| } |
| |
| var testLogger = hclog.New(hclog.DefaultOptions) |
| |
| func testLogFactory() hclog.Logger { |
| return testLogger |
| } |
| |
| type fakeHTTPClient struct { |
| cnt int // number of retry |
| success bool // return success after retry in cnt times |
| timeout bool // timeout |
| body []byte // return body |
| t *testing.T |
| logger hclog.Logger |
| redirected bool |
| } |
| |
| func (c *fakeHTTPClient) Do(_ *retryablehttp.Request) (*http.Response, error) { |
| c.cnt-- |
| if c.cnt < 0 { |
| c.cnt = 0 |
| } |
| c.t.Log("fakeHTTPClient.cnt", c.cnt) |
| |
| var retcode int |
| if !c.redirected { |
| c.redirected = true |
| c.cnt++ |
| retcode = 405 |
| } else if c.success && c.cnt == 1 { |
| retcode = 200 |
| } else { |
| if c.timeout { |
| // simulate timeout |
| time.Sleep(time.Second * 1) |
| return nil, &fakeHTTPError{ |
| err: "Whatever reason (Client.Timeout exceeded while awaiting headers)", |
| timeout: true, |
| } |
| } |
| retcode = 0 |
| } |
| |
| ret := &http.Response{ |
| StatusCode: retcode, |
| Body: &fakeResponseBody{body: c.body}, |
| } |
| return ret, nil |
| } |
| |
| type fakeHTTPError struct { |
| err string |
| timeout bool |
| } |
| |
| func (e *fakeHTTPError) Error() string { return e.err } |
| func (e *fakeHTTPError) Timeout() bool { return e.timeout } |
| func (e *fakeHTTPError) Temporary() bool { return true } |
| |
| type fakeResponseBody struct { |
| body []byte |
| cnt int |
| } |
| |
| func (b *fakeResponseBody) Read(p []byte) (n int, err error) { |
| if b.cnt == 0 { |
| copy(p, b.body) |
| b.cnt = 1 |
| return len(b.body), nil |
| } |
| b.cnt = 0 |
| return 0, io.EOF |
| } |
| |
| func (b *fakeResponseBody) Close() error { |
| return nil |
| } |
| |
| func fakeRequestFunc(_, _ string, _ interface{}) (*retryablehttp.Request, error) { |
| return nil, nil |
| } |
| |
| const vaultCert = `-----BEGIN CERTIFICATE----- |
| MIIDuTCCAqGgAwIBAgIUA6VeVD1IB5rXcCZRAqPO4zr/GAMwDQYJKoZIhvcNAQEL |
| BQAwcjELMAkGA1UEBhMCVVMxCzAJBgNVBAgMAlZBMREwDwYDVQQHDAhTb21lQ2l0 |
| eTESMBAGA1UECgwJTXlDb21wYW55MRMwEQYDVQQLDApNeURpdmlzaW9uMRowGAYD |
| VQQDDBF3d3cuY29uaHVnZWNvLmNvbTAeFw0yMjA5MDcxOTA1MzdaFw0yNDA5MDYx |
| OTA1MzdaMHIxCzAJBgNVBAYTAlVTMQswCQYDVQQIDAJWQTERMA8GA1UEBwwIU29t |
| ZUNpdHkxEjAQBgNVBAoMCU15Q29tcGFueTETMBEGA1UECwwKTXlEaXZpc2lvbjEa |
| MBgGA1UEAwwRd3d3LmNvbmh1Z2Vjby5jb20wggEiMA0GCSqGSIb3DQEBAQUAA4IB |
| DwAwggEKAoIBAQDL9qzEXi4PIafSAqfcwcmjujFvbG1QZbI8swxnD+w8i4ufAQU5 |
| LDmvMrGo3ZbhJ0mCihYmFxpjhRdP2raJQ9TysHlPXHtDRpr9ckWTKBz2oIfqVtJ2 |
| qzteQkWCkDAO7kPqzgCFsMeoMZeONRkeGib0lEzQAbW/Rqnphg8zVVkyQ71DZ7Pc |
| d5WkC2E28kKcSramhWfVFpxG3hSIrLOX2esEXteLRzKxFPf+gi413JZFKYIWrebP |
| u5t0++MLNpuX322geoki4BWMjQsd47XILmxZ4aj33ScZvdrZESCnwP76hKIxg9mO |
| lMxrqSWKVV5jHZrElSEj9LYJgDO1Y6eItn7hAgMBAAGjRzBFMAsGA1UdDwQEAwIE |
| MDATBgNVHSUEDDAKBggrBgEFBQcDATAhBgNVHREEGjAYggtleGFtcGxlLmNvbYIJ |
| bG9jYWxob3N0MA0GCSqGSIb3DQEBCwUAA4IBAQA5dPdf5SdtMwe2uSspO/EuWqbM |
| 497vMQBW1Ey8KRKasJjhvOVYMbe7De5YsnW4bn8u5pl0zQGF4hEtpmifAtVvziH/ |
| K+ritQj9VVNbLLCbFcg+b0kfjt4yrDZ64vWvIeCgPjG1Kme8gdUUWgu9dOud5gdx |
| qg/tIFv4TRS/eIIymMlfd9owOD3Ig6S5fy4NaAJFAwXf8+3Rzuc+e7JSAPgAufjh |
| tOTWinxvoiOLuYwo9CyGgq4qKBFsrY0aE0gdA7oTQkpbEbo2EbqiWUl/PTCl1Y4Z |
| nSZ0n+4q9QC9RLrWwYTwh838d5RVLUst2mBKSA+vn7YkqmBJbdBC6nkd7n7H |
| -----END CERTIFICATE----- |
| ` |