blob: a77596313eafa61abcdd079601cf890bb73a9b89 [file] [log] [blame] [edit]
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package promising
import (
"context"
"sync/atomic"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
)
// task represents one of a set of collaborating tasks that are communicating
// in terms of promises.
type task struct {
awaiting atomic.Pointer[promise]
responsible promiseSet
}
// MainTask runs the given function as a "main task", which is a task
// that blocks execution of the caller until it is complete and can create
// the promises and other async tasks required to produce its result.
func MainTask[T any](ctx context.Context, impl func(ctx context.Context) (T, error)) (T, error) {
mainT := &task{
responsible: make(promiseSet),
}
ctx = contextWithTask(ctx, mainT)
v, err := impl(ctx)
// The implementation function must have either resolved all of its
// promises or transferred responsibility for them to another task
// before it returns.
var unresolvedErr ErrUnresolved
for unresolved := range mainT.responsible {
oneErr := ErrUnresolved{unresolved.promiseID()}
resolvePromise(unresolved, nil, oneErr)
unresolvedErr = append(unresolvedErr, unresolved.promiseID())
}
if err == nil && len(unresolvedErr) != 0 {
// If the task wasn't already returning its own error then we'll
// make it return our ErrUnresolved so the caller will know that the
// task behaved incorrectly.
err = unresolvedErr
}
return v, err
}
// AsyncTask runs the given function as a new task, passing responsibility
// for the promises in the given [PromiseContainer] to the new task.
//
// The new task runs concurrently with the caller as a new goroutine. It must
// either resolve all of the given promises or delegate responsibilty for
// them to another task before returning.
//
// The context passed to the implementation function carries the identity of
// the new task, and so the task must use that context for any calls to
// [PromiseGet] functions and for resolving any promises.
//
// A task should typically be a single thread of execution and not spawn
// any new goroutines unless doing so indirectly through another call to
// [AsyncTask]. If a particular task _does_ spawn additional goroutines then
// it's the task implementer's responsibility to prevent concurrent calls to
// any promise getters or resolvers from multiple goroutines. In particular,
// each task is allowed to await only one promise at a time and violating
// this invariant will cause undefined behavior.
func AsyncTask[P PromiseContainer](ctx context.Context, promises P, impl func(ctx context.Context, promises P)) {
callerT := mustTaskFromContext(ctx)
newT := &task{
responsible: make(promiseSet),
}
// We treat async tasks as disconnected from their caller when tracing,
// because each task has an independent lifetime, but we do still track
// the causal relationship between the two using span links.
callerSpanContext := trace.SpanFromContext(ctx).SpanContext()
childCtx, childSpan := tracer.Start(
ctx, "async task",
trace.WithNewRoot(),
trace.WithLinks(
trace.Link{
SpanContext: callerSpanContext,
},
),
)
promises.AnnounceContainedPromises(func(apr AnyPromiseResolver) {
p := apr.promise()
if p.responsible.Load() != callerT {
// TODO: a better error message that gives some information
// about what mismatched?
panic("promise responsibility mismatch")
}
newT.responsible.Add(p)
callerT.responsible.Remove(p)
p.responsible.Store(newT)
p.traceSpan.AddEvent("delegated to new task", trace.WithAttributes(
attribute.String("promising.delegated_from", callerSpanContext.SpanID().String()),
attribute.String("promising.delegated_to", childSpan.SpanContext().SpanID().String()),
))
childSpan.AddEvent("inherited promise responsibility", trace.WithAttributes(
attribute.String("promising.responsible_for", p.traceSpan.SpanContext().SpanID().String()),
))
})
go func() {
ctx := childCtx
defer childSpan.End()
ctx = contextWithTask(ctx, newT)
impl(ctx, promises)
// The implementation function must have either resolved all of its
// promises or transferred responsibility for them to another task
// before it returns.
for unresolved := range newT.responsible {
err := ErrUnresolved{unresolved.promiseID()}
resolvePromise(unresolved, nil, err)
}
}()
}
// AssertContextInTask panics if the given context does not belong to a
// promising task, or does nothing at all if it does.
//
// This is here just as a helper for clearly marking functions that only
// make sense to call when in a task context, typically because they are
// going to rely on promises.
func AssertContextInTask(ctx context.Context) {
_, ok := ctx.Value(taskContextKey).(*task)
if !ok {
panic("function requires an active task, but the given context does not belong to one")
}
}
func mustTaskFromContext(ctx context.Context) *task {
ret, ok := ctx.Value(taskContextKey).(*task)
if !ok {
panic("cannot interact with promises or tasks from non-task context")
}
return ret
}
func contextWithTask(ctx context.Context, t *task) context.Context {
return context.WithValue(ctx, taskContextKey, t)
}
type taskContextKeyType int
const taskContextKey taskContextKeyType = 0