Skip to content

Commit

Permalink
fixed test after refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
nbys committed Feb 8, 2022
1 parent 180de74 commit 6a7f778
Show file tree
Hide file tree
Showing 9 changed files with 588 additions and 729 deletions.
129 changes: 56 additions & 73 deletions app/acme/acme.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,103 +2,86 @@ package acme

import (
"context"
"log"
"time"

"github.com/go-pkgz/repeater"
log "github.com/go-pkgz/lgr"
)

// var acmeOpTimeout = 5 * time.Minute
var (
attemptInterval = time.Minute * 1
maxAttemps = 5
)

// Solver is an interface for solving ACME DNS challenge
type Solver interface {
// PreSolve is called before solving the challenge. ACME Order will be created and DNS record will be added.
PreSolve(ctx context.Context) error
// Solve is called to present TXT record and accept challenge.
Solve(ctx context.Context) error
// PostSolve is called after obtaining the certificate.
PostSolve(ctx context.Context) error
// GetCertificateExpiration returns certificate expiration date
GetCertificateExpiration(certPath string) (time.Time, error)
}
PreSolve() error

// fqdns []string, provider string, nameservers []string
// Solve is called to accept the challenge and pull the certificate.
Solve() error

// ScheduleCertificateRenewal schedules certificate renewal
func ScheduleCertificateRenewal(solver Solver, timeout time.Duration) {
certPath := getEnvOptionalString("SSL_CERT", "./var/acme/cert.pem")
// ObtainCertificate is called to obtain the certificate.
// Certificate will be saved to the file path specified by flag (env: SSL_CERT). //TODO add proper descr
ObtainCertificate() error
}

// ScheduleCertificateRenewal schedules certificate renewal
func ScheduleCertificateRenewal(ctx context.Context, solver Solver, certPath string) {
go func(certPath string) {
var (
expiredAt time.Time
err error
)
var nextAttemptAfter time.Duration

expiredAt, err = solver.GetCertificateExpiration(certPath)
if err != nil {
expiredAt = time.Now()
log.Printf("[INFO] failed to get certificate expiration date, probably not obtained yet: %v", err)
if expiredAt, err := getCertificateExpiration(certPath); err == nil {
nextAttemptAfter = time.Until(expiredAt.Add(time.Hour * 24 * -5))
log.Printf("[INFO] certificate will expire in %v, next attempt in %v", expiredAt, nextAttemptAfter)
}

attempted := 0
for {
<-time.After(time.Until(expiredAt.Add(time.Hour * 24 * -5)))
select {
case <-ctx.Done():
return
case <-time.After(nextAttemptAfter):
}
attempted++

// add DNS record and wait for propagation
{
ctx, cancel := context.WithTimeout(context.Background(), timeout)
err = repeater.NewDefault(10, timeout>>12).Do(ctx, func() error {
if errc := solver.PreSolve(ctx); errc != nil {
log.Printf("[INFO] error in ACME DNS Challenge Presolve: %v", errc)
return errc
}
return nil
})
cancel()
if err != nil {
log.Printf("[ERROR] ACME DNS Challenge Presolve failed. Last error %v", err)
return
}
if attempted > maxAttemps {
log.Printf("[ERROR] Certificate renewal failed after %d attempts", attempted-1)
return
}
log.Printf("[INFO] renewing certificate attempt %d", attempted)

// present TXT record and accept challenge
{
ctx, cancel := context.WithTimeout(context.Background(), timeout)
err = repeater.NewDefault(10, timeout>>12).Do(ctx, func() error {
if errc := solver.Solve(ctx); errc != nil {
log.Printf("[INFO] error in ACME DNS Challenge Solve: %v", errc)
return errc
}
return nil
})
cancel()
if err != nil {
log.Printf("[ERROR] retry limit reached ACME DNS Challenge Solve failed. Last error: %v", err)
return
}
// create ACME order and add TXT record for the challenge
if err := solver.PreSolve(); err != nil {
nextAttemptAfter = time.Duration(attempted) * attemptInterval
log.Printf("[WARN] error during preparing ACME order: %v, next attempt in %v", err, nextAttemptAfter)
continue
}

// pull the certificate
{
ctx, cancel := context.WithTimeout(context.Background(), timeout)
err = repeater.NewDefault(10, timeout>>12).Do(ctx, func() error {
if errc := solver.PostSolve(ctx); errc != nil {
log.Printf("[INFO] error in ACME DNS Challenge PostSolve: %v", errc)
return errc
}
return nil
})
cancel()
if err != nil {
log.Printf("[ERROR] retry limit reached, ACME DNS Challenge PostSolve failed. Last error: %v", err)
return
}
// solve the challenge
if err := solver.Solve(); err != nil {
nextAttemptAfter = time.Duration(attempted) * attemptInterval
log.Printf("[WARN] error during solving ACME DNS Challenge: %v, next attempt in %v", err, nextAttemptAfter)
continue
}

expiredAt, err = solver.GetCertificateExpiration(certPath)
if err != nil {
log.Printf("[ERROR] failed to get certificate expiration date: %v", err)
return
// obtain certificate
if err := solver.ObtainCertificate(); err != nil {
nextAttemptAfter = time.Duration(attempted) * attemptInterval
log.Printf("[WARN] error during certificate obtaining: %v, next attempt in %v", err, nextAttemptAfter)
continue
}

expiredAt, err := getCertificateExpiration(certPath)
if err == nil {
// 5 days earlier than the certificate expiration
nextAttemptAfter = time.Until(expiredAt.Add(time.Hour * 24 * -5))
log.Printf("[INFO] certificate will expire in %v, next attempt in %v", expiredAt, nextAttemptAfter)
attempted = 0
continue
}

log.Printf("[WARN] failed to get certificate expiration date, probably not obtained yet: %v", err)
nextAttemptAfter = time.Duration(attempted) * attemptInterval
}
}(certPath)
}
127 changes: 86 additions & 41 deletions app/acme/acme_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,31 @@ package acme

import (
"context"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"fmt"
"math/big"
"os"
"testing"
"time"

"github.com/stretchr/testify/assert"
)

const certPath = "./TestScheduleCertificateRenewal.pem"

type mockSolver struct {
domain string
expires time.Time
preSolvedCalled int
solveCalled int
postSolvedCalled int
obtainCertCalled int
}

func (s *mockSolver) PreSolve(ctx context.Context) error {
func (s *mockSolver) PreSolve() error {
s.preSolvedCalled++
switch s.domain {
case "mycompany1.com":
Expand All @@ -26,39 +35,31 @@ func (s *mockSolver) PreSolve(ctx context.Context) error {
return nil
}

func (s *mockSolver) Solve(ctx context.Context) error {
func (s *mockSolver) Solve() error {
s.solveCalled++
switch s.domain {
case "mycompany2.com":
return fmt.Errorf("solve failed")
return fmt.Errorf("postSolved failed")
}
return nil
}

func (s *mockSolver) PostSolve(ctx context.Context) error {
s.postSolvedCalled++
func (s *mockSolver) ObtainCertificate() error {
s.obtainCertCalled++
switch s.domain {
case "mycompany3.com":
return fmt.Errorf("postSolved failed")
return fmt.Errorf("obtainCertificate failed")
case "mycompany5.com":
return nil
default:
return createCert(time.Now().Add(time.Hour*24*365), s.domain)
}
return nil
}

func (s *mockSolver) GetCertificateExpiration(certPath string) (time.Time, error) {
// check called before loop starts
if s.preSolvedCalled == 0 {
switch s.domain {
case "mycompany4.com":
return time.Now().Add(time.Hour * 24 * 670), nil
default:
return time.Time{}, fmt.Errorf("certificate does not exist")
}
}
return time.Now().Add(time.Hour * 24 * 365), nil
}

func TestScheduleCertificateRenewal(t *testing.T) {
timeout := 15 * time.Second
testMaxAttemps := 10
maxAttemps = testMaxAttemps

attemptInterval = time.Microsecond * 10

type args struct {
domain string
Expand All @@ -69,42 +70,86 @@ func TestScheduleCertificateRenewal(t *testing.T) {
type expected struct {
preSolvedCalled int
solveCalled int
postSolvedCalled int
obtainCertCalled int
}

tests := []struct {
name string
args args
expected expected
}{
// {"certificate not existed before",
// args{"example.com", false, time.Now().Add(time.Hour * 100 * 24)},
// expected{1, 1, 1}},
// {"presolve failed",
// args{"mycompany1.com", false, time.Time{}},
// expected{10, 0, 0}},
// {"solve failed",
// args{"mycompany2.com", false, time.Time{}},
// expected{1, 10, 0}},
{"postsolve failed",
{"certificate not existed before",
args{"example.com", false, time.Time{}},
expected{1, 1, 1}},
{"presolve always fails",
args{"mycompany1.com", false, time.Time{}},
expected{testMaxAttemps, 0, 0}},
{"solve always fails",
args{"mycompany2.com", false, time.Time{}},
expected{testMaxAttemps, testMaxAttemps, 0}},
{"obtain cert failed",
args{"mycompany3.com", false, time.Time{}},
expected{1, 1, 10}},
// {"certificate valid for a long time",
// args{"mycompany4.com", false, time.Time{}},
// expected{0, 0, 0}},
expected{maxAttemps, maxAttemps, maxAttemps}},
{"certificate valid for a long time",
args{"mycompany4.com", true, time.Now().Add(time.Hour * 100 * 24)},
expected{0, 0, 0}},
{"obtain cert success, but file not created",
args{"mycompany5.com", false, time.Time{}},
expected{maxAttemps, maxAttemps, maxAttemps}},
}

for _, tt := range tests {
if tt.args.certExistedBefore {
if err := createCert(tt.args.expiryTime, tt.args.domain); err != nil {
t.Fatal(err)
}
}

s := &mockSolver{
domain: tt.args.domain,
expires: tt.args.expiryTime,
}

ScheduleCertificateRenewal(s, timeout)
ctx, cancel := context.WithTimeout(context.Background(), time.Second*15)
ScheduleCertificateRenewal(ctx, s, certPath)
time.Sleep(time.Second * 2)

time.Sleep(timeout)
assert.Equal(t, tt.expected.preSolvedCalled, s.preSolvedCalled, fmt.Sprintf("[case %s] preSolvedCalled not match", tt.name))
assert.Equal(t, tt.expected.solveCalled, s.solveCalled, fmt.Sprintf("[case %s] solveCalled not match", tt.name))
assert.Equal(t, tt.expected.postSolvedCalled, s.postSolvedCalled, fmt.Sprintf("[case %s] postSolvedCalled not match", tt.name))
assert.Equal(t, tt.expected.obtainCertCalled, s.obtainCertCalled, fmt.Sprintf("[case %s] postSolvedCalled not match", tt.name))

os.Remove(certPath)
cancel()
}
}

func createCert(expireAt time.Time, domain string) error {
priv, _ := rsa.GenerateKey(rand.Reader, 2048)
template := &x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{
Organization: []string{"Acme Co"},
},
NotBefore: time.Now(),
NotAfter: expireAt,

KeyUsage: x509.KeyUsageKeyEncipherment,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
DNSNames: []string{domain},
}
// write cert to file
certBytes, err := x509.CreateCertificate(rand.Reader, template, template, &priv.PublicKey, priv)
if err != nil {
return err
}
certFile, err := os.Create(certPath)
if err != nil {
return err
}

if _, err := certFile.Write(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certBytes})); err != nil {
return err
}
return certFile.Close()
}
Loading

0 comments on commit 6a7f778

Please sign in to comment.