blob: e655a9084dd6c8983b23db9c3aa58b6213b3f12a [file] [log] [blame]
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package fairshare
import (
"fmt"
"io/ioutil"
"sync"
log "github.com/hashicorp/go-hclog"
uuid "github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/sdk/helper/logging"
)
// Job is an interface for jobs used with this job manager.
// A worker calls Execute and, when it returns a non-nil error, passes that
// error to OnFailure before running any cleanup function.
type Job interface {
// Execute performs the work.
// It should be synchronous if a cleanupFn is provided, because the worker
// invokes the cleanupFn as soon as Execute (and OnFailure, on error) returns.
Execute() error
// OnFailure handles the error resulting from a failed Execute().
// It should be synchronous if a cleanupFn is provided, for the same reason:
// the worker runs the cleanupFn immediately after OnFailure returns.
OnFailure(err error)
}
// initFn is an optional hook that a worker invokes just before it executes
// a job.
type initFn func()

// cleanupFn is an optional hook that a worker invokes once a job's Execute
// (and OnFailure, if Execute failed) has returned.
type cleanupFn func()
// wrappedJob bundles a Job with the optional hooks that a worker runs around
// it: init before Execute, cleanup after Execute/OnFailure. Either hook may
// be nil, in which case the worker skips it.
type wrappedJob struct {
job Job
init initFn
cleanup cleanupFn
}
// worker represents a single worker in a pool.
// Its channels, logger, and wait group are handed to it by the dispatcher
// that creates it.
type worker struct {
// name identifies the worker; the dispatcher assigns "worker-<index>".
name string
// jobCh is the receive-only end of the job channel the worker reads from.
jobCh <-chan wrappedJob
// quit signals the worker to exit once it is closed.
quit chan struct{}
// logger is the logger passed down from the dispatcher.
logger log.Logger
// waitgroup for testing stop functionality; start calls Add(1) and the
// worker goroutine calls Done when it exits.
wg *sync.WaitGroup
}
// start launches the worker's goroutine, which services jobs from jobCh until
// the quit channel is closed. The worker registers itself with the wait group
// before the goroutine is spawned and marks itself done when the goroutine
// exits.
func (w *worker) start() {
	w.wg.Add(1)

	// process runs a single job together with its optional hooks: init first,
	// then Execute (reporting any error through OnFailure), then cleanup.
	process := func(item wrappedJob) {
		if item.init != nil {
			item.init()
		}

		if execErr := item.job.Execute(); execErr != nil {
			item.job.OnFailure(execErr)
		}

		if item.cleanup != nil {
			item.cleanup()
		}
	}

	go func() {
		for {
			// When both channels are ready, select picks one at random, so a
			// pending job may or may not be taken after quit is closed.
			select {
			case item := <-w.jobCh:
				process(item)
			case <-w.quit:
				w.wg.Done()
				return
			}
		}
	}()
}
// dispatcher represents a worker pool: a set of workers that all receive
// from one shared job channel and stop when one shared quit channel is closed.
type dispatcher struct {
// name identifies the dispatcher; createDispatcher generates one when empty.
name string
// numWorkers is the target pool size (createDispatcher raises it to at least 1).
numWorkers int
// workers holds the pool's workers; init fills it up to numWorkers.
workers []worker
// jobCh carries dispatched jobs to the workers.
jobCh chan wrappedJob
// onceStart and onceStop make start and stop take effect only once each.
onceStart sync.Once
onceStop sync.Once
// quit is closed by stop; workers and blocked dispatch calls watch it.
quit chan struct{}
// logger receives the dispatcher's trace, info and warning messages.
logger log.Logger
// wg is shared with every worker so callers can wait for them to exit.
wg *sync.WaitGroup
}
// newDispatcher builds a dispatcher via createDispatcher and then fills its
// worker pool. The workers are created but not started; call start to begin
// processing jobs.
func newDispatcher(name string, numWorkers int, l log.Logger) *dispatcher {
	pool := createDispatcher(name, numWorkers, l)
	pool.init()

	return pool
}
// dispatch hands a job to the worker pool, together with optional
// initialization and cleanup functions (useful for tracking job progress).
// The call blocks until the job is accepted on the job channel or the quit
// channel is closed; in the latter case the job is not sent and the shutdown
// is logged.
func (d *dispatcher) dispatch(job Job, init initFn, cleanup cleanupFn) {
	select {
	case <-d.quit:
		d.logger.Info("shutting down during dispatch")
	case d.jobCh <- wrappedJob{job: job, init: init, cleanup: cleanup}:
		// Job accepted by a worker; nothing more to do here.
	}
}
// start sets every worker in the pool listening on the job channel.
// It is guarded by onceStart, so repeated calls start the workers only once.
func (d *dispatcher) start() {
	launch := func() {
		d.logger.Trace("starting dispatcher")

		for i := range d.workers {
			// Each worker goroutine runs against its own copy of the worker
			// value rather than the slice element itself.
			w := d.workers[i]
			w.start()
		}
	}

	d.onceStart.Do(launch)
}
// stop shuts the worker pool down asynchronously: it closes the quit channel
// and returns without waiting for the workers to exit. It is guarded by
// onceStop, so the channel is closed at most once.
func (d *dispatcher) stop() {
	shutdown := func() {
		d.logger.Trace("terminating dispatcher")
		close(d.quit)
	}

	d.onceStop.Do(shutdown)
}
// createDispatcher generates a new dispatcher object, but does not initialize
// the worker pool.
//
// Arguments are normalized before use:
//   - a nil logger is replaced by one that discards all output;
//   - a numWorkers value of zero or less is raised to 1 (with a warning);
//   - an empty name is replaced by "dispatcher-<uuid>", or
//     "dispatcher-no-uuid" if UUID generation fails.
//
// The returned dispatcher has an empty worker slice whose capacity is already
// reserved for numWorkers entries, an unbuffered job channel, and an open quit
// channel.
func createDispatcher(name string, numWorkers int, l log.Logger) *dispatcher {
	if l == nil {
		l = logging.NewVaultLoggerWithWriter(ioutil.Discard, log.NoLevel)
	}

	if numWorkers <= 0 {
		numWorkers = 1
		l.Warn("must have 1 or more workers. setting number of workers to 1")
	}

	if name == "" {
		guid, err := uuid.GenerateUUID()
		if err != nil {
			l.Warn("uuid generator failed, using 'no-uuid'", "err", err)
			guid = "no-uuid"
		}

		name = fmt.Sprintf("dispatcher-%s", guid)
	}

	d := &dispatcher{
		name:       name,
		numWorkers: numWorkers,
		// The final pool size is known here (numWorkers >= 1), so reserve the
		// capacity up front and let init() append without reallocating.
		workers: make([]worker, 0, numWorkers),
		jobCh:   make(chan wrappedJob),
		quit:    make(chan struct{}),
		logger:  l,
		wg:      &sync.WaitGroup{},
	}

	d.logger.Trace("created dispatcher", "name", d.name, "num_workers", d.numWorkers)

	return d
}
// init tops the worker pool up to numWorkers entries, adding one worker per
// iteration, and then logs the configured pool size. Workers are only
// created here, not started.
func (d *dispatcher) init() {
	for n := len(d.workers); n < d.numWorkers; n++ {
		d.initializeWorker()
	}

	d.logger.Trace("initialized dispatcher", "num_workers", d.numWorkers)
}
// initializeWorker creates one new worker and appends it to the pool. The
// worker is named "worker-<index>" after its position in the pool and shares
// the dispatcher's job channel, quit channel, logger, and wait group.
func (d *dispatcher) initializeWorker() {
	idx := len(d.workers)

	d.workers = append(d.workers, worker{
		name:   fmt.Sprintf("worker-%d", idx),
		jobCh:  d.jobCh,
		quit:   d.quit,
		logger: d.logger,
		wg:     d.wg,
	})
}