blob: 6c8bfc2e3b1434abe0e6061c251930b3242e54c0 [file] [log] [blame]
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package promising
import (
"context"
"fmt"
"sync"
"sync/atomic"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
)
// promise represents a result that will become available at some point
// in the future, delivered by an asynchronous [Task].
type promise struct {
	// name is the human-readable name passed to NewPromise.
	name string

	// responsible is the task currently responsible for resolving this
	// promise. Responsibility can move between tasks, and resolvePromise
	// clears it to nil.
	responsible atomic.Pointer[task]

	// result stays nil until the promise is resolved, at which point it is
	// set by compare-and-swap against nil.
	result atomic.Pointer[promiseResult]

	// traceSpan is the span that NewPromise starts for this promise.
	traceSpan trace.Span

	// waitingMu guards waiting.
	waitingMu sync.Mutex

	// waiting holds one channel per blocked getter. Each channel is closed
	// once result has been set.
	waiting []chan<- struct{}
}
// promiseID returns the opaque, comparable identifier for p that is exposed
// through the exported PromiseID type.
func (p *promise) promiseID() PromiseID {
	id := PromiseID{promise: p}
	return id
}
// promiseResult is the final outcome stored for a resolved promise.
type promiseResult struct {
	val any   // value passed when resolving; nil for forced results
	err error // error passed when resolving, or the internal error for forced results

	// forced reports whether this result was generated by the promise
	// machinery itself (see resolvePromiseInternalFailure) rather than by
	// the responsible task. It lets a resolution from the responsible task
	// that races with such an internal error be ignored gracefully, instead
	// of being treated like the task resolving the same promise twice.
	forced bool
}
// getResolvedPromiseResult unpacks a stored result into the typed pair that a
// PromiseGet[T] returns.
//
// The stored value is not guaranteed to be a T. For example, it is nil when
// the responsible task exited without resolving the promise. In that case the
// zero value of T is returned together with the stored error.
func getResolvedPromiseResult[T any](result *promiseResult) (T, error) {
	var typed T
	if candidate, isT := result.val.(T); isT {
		typed = candidate
	}
	return typed, result.err
}
// PromiseID is an opaque, comparable unique identifier for a promise, which
// can therefore be used by callers to produce a lookup table of metadata for
// each active promise they are interested in.
//
// The identifier for a promise follows it as the responsibility to resolve it
// transfers between tasks.
//
// For example, this can be useful for retaining contextual information that
// can help explain which work was implicated in a dependency cycle between
// tasks.
//
// The zero value, [NoPromise], does not identify any promise.
type PromiseID struct {
	promise *promise
}
// FriendlyName returns the human-readable name that was given to the promise
// when it was created by [NewPromise].
//
// For [NoPromise], which identifies no promise at all, it returns the empty
// string instead of panicking on the missing promise.
func (id PromiseID) FriendlyName() string {
	if id.promise == nil {
		return ""
	}
	return id.promise.name
}
// NoPromise is the zero value of [PromiseID] and is used to represent the
// absence of a promise.
var NoPromise PromiseID
// NewPromise creates a new promise that the calling task is initially
// responsible for and returns both its resolver and its getter.
//
// The given context must be a task context or this function will panic.
//
// The caller should retain the resolver for its own use and pass the getter
// to any other tasks that will consume the result of the promise.
//
// The getter blocks until the promise is resolved. If the context passed to
// the getter is cancelled or reaches its deadline first, the getter instead
// returns the zero value of T along with the context's error, as documented
// on [PromiseGet].
func NewPromise[T any](ctx context.Context, name string) (PromiseResolver[T], PromiseGet[T]) {
	callerSpan := trace.SpanFromContext(ctx)
	initialResponsible := mustTaskFromContext(ctx)
	p := &promise{name: name}
	p.responsible.Store(initialResponsible)
	initialResponsible.responsible[p] = struct{}{}

	// The promise gets its own root span, linked back to the span found in
	// the caller's context.
	ctx, span := tracer.Start(
		ctx, fmt.Sprintf("promise(%s)", name),
		trace.WithNewRoot(),
		trace.WithLinks(trace.Link{
			SpanContext: trace.SpanContextFromContext(ctx),
		}),
	)
	_ = ctx // prevent staticcheck from complaining until we have something actually using this
	p.traceSpan = span
	promiseSpanContext := span.SpanContext()
	callerSpan.AddEvent("new promise", trace.WithAttributes(
		attribute.String("promising.responsible_for", promiseSpanContext.SpanID().String()),
	))

	resolver := PromiseResolver[T]{p}
	getter := PromiseGet[T](func(ctx context.Context) (T, error) {
		reqT := mustTaskFromContext(ctx)
		waiterSpan := trace.SpanFromContext(ctx)

		// Record which promise this task is awaiting. The self-dependency
		// check below follows these records through other waiting tasks.
		ok := reqT.awaiting.CompareAndSwap(nil, p)
		if !ok {
			// If we get here then the task seems to have forked into two
			// goroutines that are trying to await promises concurrently,
			// which is illegal per the contract for tasks.
			panic("racing promise get")
		}
		defer func() {
			ok := reqT.awaiting.CompareAndSwap(p, nil)
			if !ok {
				panic("racing promise get")
			}
		}()

		// We'll first test whether waiting for this promise is possible
		// without creating a deadlock, by following the awaiting->responsible
		// chain.
		checkP := p
		checkT := p.responsible.Load()
		steps := 1
		for checkT != reqT {
			steps++
			if checkT == nil {
				break
			}
			nextCheckP := checkT.awaiting.Load()
			if nextCheckP == nil {
				break
			}
			if checkP.responsible.Load() != checkT {
				break
			}
			checkP = nextCheckP
			checkT = checkP.responsible.Load()
		}
		if checkT == reqT {
			// We've found a self-dependency, but to report it in a useful
			// way we need to collect up all of the promises, so we'll
			// repeat the above and collect up all of the promises we find
			// along the way this time, instead of just counting them.
			err := make(ErrSelfDependent, 0, steps)
			var affectedPromises []*promise
			checkP := p
			checkT := p.responsible.Load()
			err = append(err, checkP.promiseID())
			affectedPromises = append(affectedPromises, checkP)
			for checkT != reqT {
				if checkT == nil {
					break
				}
				nextCheckP := checkT.awaiting.Load()
				if nextCheckP == nil {
					break
				}
				if checkP.responsible.Load() != checkT {
					break
				}
				checkP = nextCheckP
				checkT = checkP.responsible.Load()
				err = append(err, checkP.promiseID())
				affectedPromises = append(affectedPromises, checkP)
			}
			waiterSpan.AddEvent(
				"task is self-dependent",
				trace.WithAttributes(
					attribute.String("promise.waiting_for_id", promiseSpanContext.SpanID().String()),
				),
			)
			// All waiters for this promise need to see this error, because
			// otherwise the other waiters might stall forever waiting for
			// a result that will never come.
			for _, affected := range affectedPromises {
				resolvePromiseInternalFailure(affected, err)
			}
			// The current promise is one of the "affected promises" that
			// were resolved above, so we can now fall through to the check
			// below for whether the promise is already resolved and have
			// it return the error.
		}

		// If we get here then it's safe to actually await.
		p.waitingMu.Lock()
		if result := p.result.Load(); result != nil {
			// No need to wait because the result is already available.
			p.waitingMu.Unlock()
			waiterSpan.AddEvent(
				"promise is already resolved",
				trace.WithAttributes(
					attribute.String("promise.waiting_for_id", promiseSpanContext.SpanID().String()),
				),
			)
			return getResolvedPromiseResult[T](result)
		}
		ch := make(chan struct{})
		p.waiting = append(p.waiting, ch)
		waiterCount := len(p.waiting)
		p.waitingMu.Unlock()
		waiterSpan.AddEvent(
			"waiting for promise result",
			trace.WithAttributes(
				attribute.String("promise.waiting_for_id", promiseSpanContext.SpanID().String()),
				attribute.Int("promise.waiter_count", waiterCount),
			),
		)
		p.traceSpan.AddEvent(
			"new task waiting",
			trace.WithAttributes(
				attribute.String("promise.waiter_id", waiterSpan.SpanContext().SpanID().String()),
				attribute.Int("promise.waiter_count", waiterCount),
			),
		)

		select {
		case <-ch:
			// The channel is closed once the promise is resolved.
		case <-ctx.Done():
			p.waitingMu.Lock()
			if result := p.result.Load(); result != nil {
				// Resolution raced with cancellation. The result is already
				// available, so prefer it over the context error.
				p.waitingMu.Unlock()
				return getResolvedPromiseResult[T](result)
			}
			// Results are only stored while holding waitingMu, so our
			// channel is still registered. Unregister it so that getters
			// which repeatedly give up don't leave stale channels behind.
			// Waiter order is irrelevant because all of them are released
			// together, so a swap-remove is fine.
			for i, waitingCh := range p.waiting {
				if waitingCh == ch {
					last := len(p.waiting) - 1
					p.waiting[i] = p.waiting[last]
					p.waiting[last] = nil
					p.waiting = p.waiting[:last]
					break
				}
			}
			p.waitingMu.Unlock()
			waiterSpan.AddEvent(
				"promise wait cancelled",
				trace.WithAttributes(
					attribute.String("promise.waiting_for_id", promiseSpanContext.SpanID().String()),
				),
			)
			var zero T
			return zero, ctx.Err()
		}

		waiterSpan.AddEvent(
			"promise resolved",
			trace.WithAttributes(
				attribute.String("promise.waiting_for_id", promiseSpanContext.SpanID().String()),
			),
		)
		result := p.result.Load()
		if result == nil {
			// If we get here then there's a bug in resolvePromise or
			// resolvePromiseInternalFailure, which must store a result
			// before closing the waiting channels.
			panic("promise signaled resolved but has no result")
		}
		return getResolvedPromiseResult[T](result)
	})
	return resolver, getter
}
// resolvePromise stores the final value and error for promise p. It releases
// the responsible task's claim on p and wakes every getter currently waiting
// for it.
//
// It panics with "promise resolved more than once" if p was already resolved
// through this function. There is one exception: if the existing result was
// forced by resolvePromiseInternalFailure, the call is a quiet no-op, because
// the waiters have already received the forced error.
func resolvePromise(p *promise, v any, err error) {
	p.waitingMu.Lock()
	defer p.waitingMu.Unlock()

	// resolvePromise clears the responsible task on every call, and
	// resolvePromiseInternalFailure deliberately leaves it in place. A nil
	// value here therefore means this promise was already resolved through
	// this function. Check for it up front so the caller gets the intended
	// diagnostic rather than a nil pointer dereference on respT.
	respT := p.responsible.Swap(nil)
	if respT == nil {
		panic("promise resolved more than once")
	}
	respT.responsible.Remove(p)

	ok := p.result.CompareAndSwap(nil, &promiseResult{
		val: v,
		err: err,
	})
	if !ok {
		// The result that's now present might be a "forced error" generated
		// through promiseInternalFailure, in which case we just quietly
		// ignore the attempt to actually resolve it since all of the
		// waiters will already have received the error.
		r := p.result.Load()
		if r != nil && r.forced {
			return
		}
		// Any other conflict indicates a bug in the calling task.
		panic("promise resolved more than once")
	}
	for _, waitingCh := range p.waiting {
		close(waitingCh)
	}
	p.waiting = nil
}
// resolvePromiseInternalFailure is a variant of resolvePromise for errors that
// originate in the promise machinery rather than in the task responsible for
// the promise. For example, when tasks become self-dependent, every promise in
// the chain must fail immediately. Otherwise a waiter could stay stuck forever
// waiting for a result that may never come, or see an incorrect result while
// the failure propagates through a different return path.
//
// The stored result is marked as forced, so a later resolution attempt by the
// responsible task is quietly ignored rather than treated as a double
// resolution.
func resolvePromiseInternalFailure(p *promise, err error) {
	p.waitingMu.Lock()
	defer p.waitingMu.Unlock()

	p.traceSpan.AddEvent("internal promise failure", trace.WithAttributes(
		attribute.String("error", err.Error()),
	))

	// Responsibility is deliberately left untouched. The responsible task
	// typically doesn't call the getter for its own promise, so it won't
	// learn about this failure. It must still be able to call the resolver,
	// which then becomes a no-op.
	forcedResult := &promiseResult{err: err, forced: true}
	if swapped := p.result.CompareAndSwap(nil, forcedResult); !swapped {
		// Either the responsible task resolved the promise first, or the
		// promise is part of two self-dependency reports at once and the
		// other one was recorded first. Either way the promise has a
		// result, so we only note the conflict to help with debugging.
		p.traceSpan.AddEvent("internal promise failure conflict")
	}

	// Release every waiter. When the compare-and-swap above lost, the
	// earlier resolution should already have drained this list.
	waiters := p.waiting
	p.waiting = nil
	for _, waitingCh := range waiters {
		close(waitingCh)
	}
}
// PromiseGet is the type of a promise "getter": a function that blocks until
// its promise has been resolved and then returns the resolved value and
// error.
//
// A PromiseGet may be called only from within a task, passing a context that
// descends from that task's context.
//
// If that context is cancelled or passes its deadline first, the getter
// returns the corresponding context-related error instead.
type PromiseGet[T any] func(ctx context.Context) (T, error)