No public description
PiperOrigin-RevId: 744780506
Change-Id: I45d84a786b66e78e4de0e342cf951b5efb94b7cd
diff --git a/Eigen/src/ThreadPool/ForkJoin.h b/Eigen/src/ThreadPool/ForkJoin.h
index f67abd3..44eefaf 100644
--- a/Eigen/src/ThreadPool/ForkJoin.h
+++ b/Eigen/src/ThreadPool/ForkJoin.h
@@ -61,9 +61,9 @@
// Runs `do_func` asynchronously for the range [start, end) with a specified
// granularity. `do_func` should be of type `std::function<void(Index,
// Index)`. `done()` is called exactly once after all tasks have been executed.
- template <typename DoFnType, typename DoneFnType>
+ template <typename DoFnType, typename DoneFnType, typename ThreadPoolEnv>
static void ParallelForAsync(Index start, Index end, Index granularity, DoFnType&& do_func, DoneFnType&& done,
- ThreadPool* thread_pool) {
+ ThreadPoolTempl<ThreadPoolEnv>* thread_pool) {
if (start >= end) {
done();
return;
@@ -76,8 +76,11 @@
}
// Synchronous variant of ParallelForAsync.
- template <typename DoFnType>
- static void ParallelFor(Index start, Index end, Index granularity, DoFnType&& do_func, ThreadPool* thread_pool) {
+ // WARNING: Making nested calls to `ParallelFor`, e.g., calling `ParallelFor` inside a task passed into another
+ // `ParallelFor` call, may lead to deadlocks due to how task stealing is implemented.
+ template <typename DoFnType, typename ThreadPoolEnv>
+ static void ParallelFor(Index start, Index end, Index granularity, DoFnType&& do_func,
+ ThreadPoolTempl<ThreadPoolEnv>* thread_pool) {
if (start >= end) return;
Barrier barrier(1);
auto done = [&barrier]() { barrier.Notify(); };
@@ -87,8 +90,8 @@
private:
// Schedules `right_thunk`, runs `left_thunk`, and runs other tasks until `right_thunk` has finished.
- template <typename LeftType, typename RightType>
- static void ForkJoin(LeftType&& left_thunk, RightType&& right_thunk, ThreadPool* thread_pool) {
+ template <typename LeftType, typename RightType, typename ThreadPoolEnv>
+ static void ForkJoin(LeftType&& left_thunk, RightType&& right_thunk, ThreadPoolTempl<ThreadPoolEnv>* thread_pool) {
std::atomic<bool> right_done(false);
auto execute_right = [&right_thunk, &right_done]() {
std::forward<RightType>(right_thunk)();
@@ -114,16 +117,16 @@
return start + offset;
}
- template <typename DoFnType>
- static void RunParallelFor(Index start, Index end, Index granularity, DoFnType&& do_func, ThreadPool* thread_pool) {
+ template <typename DoFnType, typename ThreadPoolEnv>
+ static void RunParallelFor(Index start, Index end, Index granularity, DoFnType&& do_func,
+ ThreadPoolTempl<ThreadPoolEnv>* thread_pool) {
Index mid = ComputeMidpoint(start, end, granularity);
if ((end - start) < granularity || mid == start || mid == end) {
do_func(start, end);
return;
}
- ForkJoin([start, mid, granularity, &do_func, thread_pool]() {
- RunParallelFor(start, mid, granularity, do_func, thread_pool);
- },
+ ForkJoin([start, mid, granularity, &do_func,
+ thread_pool]() { RunParallelFor(start, mid, granularity, do_func, thread_pool); },
[mid, end, granularity, &do_func, thread_pool]() {
RunParallelFor(mid, end, granularity, do_func, thread_pool);
},