| // Copyright (c) HashiCorp, Inc. |
| // SPDX-License-Identifier: BUSL-1.1 |
| |
| package promising |
| |
| import ( |
| "context" |
| "fmt" |
| "sync" |
| "sync/atomic" |
| |
| "go.opentelemetry.io/otel/attribute" |
| "go.opentelemetry.io/otel/trace" |
| ) |
| |
| // promise represents a result that will become available at some point |
| // in the future, delivered by an asynchronous [Task]. |
| type promise struct { |
| name string |
| |
| responsible atomic.Pointer[task] |
| result atomic.Pointer[promiseResult] |
| traceSpan trace.Span |
| |
| waiting []chan<- struct{} |
| waitingMu sync.Mutex |
| } |
| |
| func (p *promise) promiseID() PromiseID { |
| return PromiseID{p} |
| } |
| |
// promiseResult is the final outcome of a promise: the value and error that
// every getter of that promise will return.
type promiseResult struct {
	val any
	err error

	// forced reports whether the promise machinery generated this result
	// itself, rather than the responsible task.
	//
	// It lets the responsible task's own resolution race gracefully with an
	// internal error. That case is treated differently from a task that
	// tries to resolve the same promise more than once.
	forced bool
}

// getResolvedPromiseResult unpacks result into the getter's result type T.
//
// When result.val does not hold a T, the zero value of T is returned
// together with result.err. This happens, for example, when the value was
// left nil because the responsible task exited without resolving the
// promise, or because the result was forced (forced results carry no value).
func getResolvedPromiseResult[T any](result *promiseResult) (T, error) {
	var out T
	if typed, ok := result.val.(T); ok {
		out = typed
	}
	return out, result.err
}
| |
| // PromiseID is an opaque, comparable unique identifier for a promise, which |
| // can therefore be used by callers to produce a lookup table of metadata for |
| // each active promise they are interested in. |
| // |
| // The identifier for a promise follows it as the responsibility to resolve it |
| // transfers beween tasks. |
| // |
| // For example, this can be useful for retaining contextual information that |
| // can help explain which work was implicated in a dependency cycle between |
| // tasks. |
| type PromiseID struct { |
| promise *promise |
| } |
| |
| func (id PromiseID) FriendlyName() string { |
| return id.promise.name |
| } |
| |
| // NoPromise is the zero value of [PromiseID] and used to represent the absense |
| // of a promise. |
| var NoPromise PromiseID |
| |
| // NewPromise creates a new promise that the calling task is initially |
| // responsible for and returns both its resolver and its getter. |
| // |
| // The given context must be a task context or this function will panic. |
| // |
| // The caller should retain the resolver for its own use and pass the getter |
| // to any other tasks that will consume the result of the promise. |
| func NewPromise[T any](ctx context.Context, name string) (PromiseResolver[T], PromiseGet[T]) { |
| callerSpan := trace.SpanFromContext(ctx) |
| initialResponsible := mustTaskFromContext(ctx) |
| p := &promise{name: name} |
| p.responsible.Store(initialResponsible) |
| initialResponsible.responsible[p] = struct{}{} |
| |
| ctx, span := tracer.Start( |
| ctx, fmt.Sprintf("promise(%s)", name), |
| trace.WithNewRoot(), |
| trace.WithLinks(trace.Link{ |
| SpanContext: trace.SpanContextFromContext(ctx), |
| }), |
| ) |
| _ = ctx // prevent staticcheck from complaining until we have something actually using this |
| p.traceSpan = span |
| promiseSpanContext := span.SpanContext() |
| |
| callerSpan.AddEvent("new promise", trace.WithAttributes( |
| attribute.String("promising.responsible_for", promiseSpanContext.SpanID().String()), |
| )) |
| |
| resolver := PromiseResolver[T]{p} |
| getter := PromiseGet[T](func(ctx context.Context) (T, error) { |
| reqT := mustTaskFromContext(ctx) |
| |
| waiterSpan := trace.SpanFromContext(ctx) |
| |
| ok := reqT.awaiting.CompareAndSwap(nil, p) |
| if !ok { |
| // If we get here then the task seems to have forked into two |
| // goroutines that are trying to await promises concurrently, |
| // which is illegal per the contract for tasks. |
| panic("racing promise get") |
| } |
| defer func() { |
| ok := reqT.awaiting.CompareAndSwap(p, nil) |
| if !ok { |
| panic("racing promise get") |
| } |
| }() |
| |
| // We'll first test whether waiting for this promise is possible |
| // without creating a deadlock, by following the awaiting->responsible |
| // chain. |
| checkP := p |
| checkT := p.responsible.Load() |
| steps := 1 |
| for checkT != reqT { |
| steps++ |
| if checkT == nil { |
| break |
| } |
| nextCheckP := checkT.awaiting.Load() |
| if nextCheckP == nil { |
| break |
| } |
| if checkP.responsible.Load() != checkT { |
| break |
| } |
| checkP = nextCheckP |
| checkT = checkP.responsible.Load() |
| } |
| if checkT == reqT { |
| // We've found a self-dependency, but to report it in a useful |
| // way we need to collect up all of the promises, so we'll |
| // repeat the above and collect up all of the promises we find |
| // along the way this time, instead of just counting them. |
| err := make(ErrSelfDependent, 0, steps) |
| var affectedPromises []*promise |
| checkP := p |
| checkT := p.responsible.Load() |
| err = append(err, checkP.promiseID()) |
| affectedPromises = append(affectedPromises, checkP) |
| for checkT != reqT { |
| if checkT == nil { |
| break |
| } |
| nextCheckP := checkT.awaiting.Load() |
| if nextCheckP == nil { |
| break |
| } |
| if checkP.responsible.Load() != checkT { |
| break |
| } |
| checkP = nextCheckP |
| checkT = checkP.responsible.Load() |
| err = append(err, checkP.promiseID()) |
| affectedPromises = append(affectedPromises, checkP) |
| } |
| waiterSpan.AddEvent( |
| "task is self-dependent", |
| trace.WithAttributes( |
| attribute.String("promise.waiting_for_id", promiseSpanContext.SpanID().String()), |
| ), |
| ) |
| |
| // All waiters for this promise need to see this error, because |
| // otherwise the other waiters might stall forever waiting for |
| // a result that will never come. |
| for _, affected := range affectedPromises { |
| resolvePromiseInternalFailure(affected, err) |
| } |
| // The current promise is one of the "affected promises" that |
| // were resolved above, so we can now fall through to the check |
| // below for whether the promise is already resolved and have |
| // it return the error. |
| } |
| |
| // If we get here then it's safe to actually await. |
| p.waitingMu.Lock() |
| if result := p.result.Load(); result != nil { |
| // No need to wait because the result is already available. |
| p.waitingMu.Unlock() |
| waiterSpan.AddEvent( |
| "promise is already resolved", |
| trace.WithAttributes( |
| attribute.String("promise.waiting_for_id", promiseSpanContext.SpanID().String()), |
| ), |
| ) |
| return getResolvedPromiseResult[T](result) |
| } |
| |
| ch := make(chan struct{}) |
| p.waiting = append(p.waiting, ch) |
| waiterCount := len(p.waiting) |
| p.waitingMu.Unlock() |
| |
| waiterSpan.AddEvent( |
| "waiting for promise result", |
| trace.WithAttributes( |
| attribute.String("promise.waiting_for_id", promiseSpanContext.SpanID().String()), |
| attribute.Int("promise.waiter_count", waiterCount), |
| ), |
| ) |
| p.traceSpan.AddEvent( |
| "new task waiting", |
| trace.WithAttributes( |
| attribute.String("promise.waiter_id", waiterSpan.SpanContext().SpanID().String()), |
| attribute.Int("promise.waiter_count", waiterCount), |
| ), |
| ) |
| <-ch // channel will be closed once promise is resolved |
| waiterSpan.AddEvent( |
| "promise resolved", |
| trace.WithAttributes( |
| attribute.String("promise.waiting_for_id", promiseSpanContext.SpanID().String()), |
| ), |
| ) |
| if result := p.result.Load(); result != nil { |
| return getResolvedPromiseResult[T](result) |
| } else { |
| // If we get here then there's a bug in resolvePromise below |
| panic("promise signaled resolved but has no result") |
| } |
| }) |
| |
| return resolver, getter |
| } |
| |
| func resolvePromise(p *promise, v any, err error) { |
| p.waitingMu.Lock() |
| defer p.waitingMu.Unlock() |
| |
| respT := p.responsible.Load() |
| p.responsible.Store(nil) |
| respT.responsible.Remove(p) |
| |
| ok := p.result.CompareAndSwap(nil, &promiseResult{ |
| val: v, |
| err: err, |
| }) |
| if !ok { |
| // The result that's now present might be a "forced error" generated |
| // through promiseInternalFailure, in which case we just quietly |
| // ignore the attempt to actually resolve it since all of the |
| // waiters will already have received the error. |
| r := p.result.Load() |
| if r != nil && r.forced { |
| return |
| } |
| // Any other conflict indicates a bug in the calling task. |
| panic("promise resolved more than once") |
| } |
| |
| for _, waitingCh := range p.waiting { |
| close(waitingCh) |
| } |
| p.waiting = nil |
| } |
| |
| // resolvePromiseInternalFailure is a variant of resolvePromise that we use for |
| // internal errors that aren't produced by the task responsible for the |
| // promise, such as when tasks become self-dependent and so we need to |
| // immediately fail all of the promises in the chain to prevent any of |
| // the waiters from potentially stuck forever waiting for completion that |
| // might never come, or might see an incorrect result while the failures |
| // propagate through a different return path. |
| func resolvePromiseInternalFailure(p *promise, err error) { |
| p.waitingMu.Lock() |
| defer p.waitingMu.Unlock() |
| |
| p.traceSpan.AddEvent("internal promise failure", trace.WithAttributes( |
| attribute.String("error", err.Error()), |
| )) |
| |
| // For internal failures we leave the responsibility data in place so |
| // that the responsible task can still try to resolve the promise and |
| // have it be a no-op, since the task that's responsible for resolving |
| // will not typically also call the promise getter, and so it won't |
| // know about the failure. |
| |
| ok := p.result.CompareAndSwap(nil, &promiseResult{ |
| err: err, |
| forced: true, |
| }) |
| if !ok { |
| // This suggests either that the responsible task beat us to the punch |
| // and resolved first, or that this promise was involved in two |
| // different self-dependence situations simultaneously and a different |
| // one got recorded already. |
| // |
| // Both situations are no big deal -- the promise got resolved one |
| // way or another -- but we'll record a tracing event for it just |
| // in case it's helpful while debugging something. |
| p.traceSpan.AddEvent("internal promise failure conflict") |
| } |
| |
| for _, waitingCh := range p.waiting { |
| close(waitingCh) |
| } |
| p.waiting = nil |
| } |
| |
// PromiseGet is the type of a promise's "getter" function. Calling it blocks
// until the promise is resolved and then returns the promise's result value
// and error.
//
// A PromiseGet function may be called only from within a task, and the
// context passed to it must descend from that task's context.
//
// Should that context be cancelled or reach its deadline before a result is
// available, the getter returns the corresponding context error instead.
type PromiseGet[T any] func(context.Context) (T, error)