No public description
PiperOrigin-RevId: 654364452
Change-Id: I7189202a6135d5504068bbacd24be42adb689f7d
diff --git a/Eigen/src/ThreadPool/NonBlockingThreadPool.h b/Eigen/src/ThreadPool/NonBlockingThreadPool.h
index 0939dde..b3575c7 100644
--- a/Eigen/src/ThreadPool/NonBlockingThreadPool.h
+++ b/Eigen/src/ThreadPool/NonBlockingThreadPool.h
@@ -31,8 +31,8 @@
all_coprimes_(num_threads),
waiters_(num_threads),
global_steal_partition_(EncodePartition(0, num_threads_)),
+ spinning_state_(0),
blocked_(0),
- spinning_(0),
done_(false),
cancelled_(false),
ec_(waiters_) {
@@ -125,7 +125,9 @@
// this. We expect that such scenario is prevented by program, that is,
// this is kept alive while any threads can potentially be in Schedule.
if (!t.f) {
- ec_.Notify(false);
+ if (IsNotifyParkedThreadRequired()) {
+ ec_.Notify(false);
+ }
} else {
env_.ExecuteTask(t); // Push failed, execute directly.
}
@@ -165,8 +167,8 @@
// Exposed publicly as static functions so that external callers can reuse
// this encode/decode logic for maintaining their own thread-safe copies of
// scheduling and steal domain(s).
- static const int kMaxPartitionBits = 16;
- static const int kMaxThreads = 1 << kMaxPartitionBits;
+ static constexpr int kMaxPartitionBits = 16;
+ static constexpr int kMaxThreads = 1 << kMaxPartitionBits;
inline unsigned EncodePartition(unsigned start, unsigned limit) { return (start << kMaxPartitionBits) | limit; }
@@ -220,6 +222,59 @@
Queue queue;
};
+ // Maximum number of threads that can spin in steal loop.
+ static constexpr int kMaxSpinningThreads = 1;
+
+ // The number of steal loop spin iterations before parking (this number is
+ // divided by the number of threads, to get spin count for each thread).
+ static constexpr int kSpinCount = 5000;
+
+ // If there are enough active threads with empty pending-task queues, a thread
+ // that runs out of work can just be parked without spinning, because these
+ // active threads will go into a steal loop after finishing their current
+ // tasks.
+ //
+ // In the worst case when all active threads are executing long/expensive
+ // tasks, the next Schedule() will have to wait until one of the parked
+ // threads will be unparked, however this should be very rare in practice.
+ static constexpr int kMinActiveThreadsToStartSpinning = 4;
+
+ struct SpinningState {
+ // Spinning state layout:
+ //
+ // - Low 32 bits encode the number of threads that are spinning in steal
+ // loop.
+ //
+ // - High 32 bits encode the number of tasks that were submitted to the pool
+ // without a call to `ec_.Notify()`. This number can't be larger than
+ // the number of spinning threads. Each spinning thread, when it exits the
+ // spin loop must check if this number is greater than zero, and maybe
+ // make another attempt to steal a task and decrement it by one.
+ static constexpr uint64_t kNumSpinningMask = 0x00000000FFFFFFFF;
+ static constexpr uint64_t kNumNoNotifyMask = 0xFFFFFFFF00000000;
+ static constexpr uint64_t kNumNoNotifyShift = 32;
+
+ uint64_t num_spinning; // number of spinning threads
+ uint64_t num_no_notification; // number of tasks submitted without
+ // notifying waiting threads
+
+ // Decodes `spinning_state_` value.
+ static SpinningState Decode(uint64_t state) {
+ uint64_t num_spinning = (state & kNumSpinningMask);
+ uint64_t num_no_notification =
+ (state & kNumNoNotifyMask) >> kNumNoNotifyShift;
+
+ assert(num_no_notification <= num_spinning);
+ return {num_spinning, num_no_notification};
+ }
+
+ // Encodes as `spinning_state_` value.
+ uint64_t Encode() const {
+ assert(num_no_notification <= num_spinning);
+ return (num_no_notification << kNumNoNotifyShift) | num_spinning;
+ }
+ };
+
Environment env_;
const int num_threads_;
const bool allow_spinning_;
@@ -227,8 +282,8 @@
MaxSizeVector<MaxSizeVector<unsigned>> all_coprimes_;
MaxSizeVector<EventCount::Waiter> waiters_;
unsigned global_steal_partition_;
+ std::atomic<uint64_t> spinning_state_;
std::atomic<unsigned> blocked_;
- std::atomic<bool> spinning_;
std::atomic<bool> done_;
std::atomic<bool> cancelled_;
EventCount ec_;
@@ -238,6 +293,9 @@
std::unordered_map<uint64_t, std::unique_ptr<PerThread>> per_thread_map_;
#endif
+ unsigned NumBlockedThreads() const { return blocked_.load(); }
+ unsigned NumActiveThreads() const { return num_threads_ - blocked_.load(); }
+
// Main worker thread loop.
void WorkerLoop(int thread_id) {
#ifndef EIGEN_THREAD_LOCAL
@@ -258,9 +316,12 @@
EventCount::Waiter* waiter = &waiters_[thread_id];
// TODO(dvyukov,rmlarsen): The time spent in NonEmptyQueueIndex() is
// proportional to num_threads_ and we assume that new work is scheduled at
- // a constant rate, so we set spin_count to 5000 / num_threads_. The
- // constant was picked based on a fair dice roll, tune it.
- const int spin_count = allow_spinning_ && num_threads_ > 0 ? 5000 / num_threads_ : 0;
+ // a constant rate, so we divide `kSpintCount` by number of threads and
+ // number of spinning threads. The constant was picked based on a fair dice
+ // roll, tune it.
+ const int spin_count = allow_spinning_ && num_threads_ > 0
+ ? kSpinCount / kMaxSpinningThreads / num_threads_
+ : 0;
if (num_threads_ == 1) {
// For num_threads_ == 1 there is no point in going through the expensive
// steal loop. Moreover, since NonEmptyQueueIndex() calls PopBack() on the
@@ -268,50 +329,70 @@
// compared to the order in which they are scheduled, which tends to be
// counter-productive for the types of I/O workloads the single thread
// pools tend to be used for.
- while (!cancelled_) {
+ while (!cancelled_.load(std::memory_order_relaxed)) {
Task t = q.PopFront();
- for (int i = 0; i < spin_count && !t.f; i++) {
- if (!cancelled_.load(std::memory_order_relaxed)) {
- t = q.PopFront();
- }
+
+ for (int i = 0; i < spin_count && !t.f; ++i) {
+ t = q.PopFront();
}
- if (!t.f) {
+
+ if (EIGEN_PREDICT_FALSE(!t.f)) {
if (!WaitForWork(waiter, &t)) {
return;
}
}
- if (t.f) {
+
+ if (EIGEN_PREDICT_TRUE(t.f)) {
env_.ExecuteTask(t);
}
}
+
} else {
- while (!cancelled_) {
+ while (!cancelled_.load(std::memory_order_relaxed)) {
Task t = q.PopFront();
- if (!t.f) {
+
+ // Do one round of steal loop from local thread partition.
+ if (EIGEN_PREDICT_FALSE(!t.f)) {
t = LocalSteal();
- if (!t.f) {
- t = GlobalSteal();
- if (!t.f) {
- // Leave one thread spinning. This reduces latency.
- if (allow_spinning_ && !spinning_ && !spinning_.exchange(true)) {
- for (int i = 0; i < spin_count && !t.f; i++) {
- if (!cancelled_.load(std::memory_order_relaxed)) {
- t = GlobalSteal();
- } else {
- return;
- }
- }
- spinning_ = false;
- }
- if (!t.f) {
- if (!WaitForWork(waiter, &t)) {
- return;
- }
- }
+ }
+
+ // Do one round of steal loop from global thread partition.
+ if (EIGEN_PREDICT_FALSE(!t.f)) {
+ t = GlobalSteal();
+ }
+
+ // Maybe leave a thread spinning. This improves latency.
+ if (EIGEN_PREDICT_FALSE(!t.f)) {
+ if (allow_spinning_ && StartSpinning()) {
+ for (int i = 0; i < spin_count && !t.f; ++i) {
+ t = GlobalSteal();
+ }
+
+ // Notify `spinning_state_` that we are no longer spinning.
+ bool has_no_notify_task = StopSpinning();
+
+ // If a task was submitted to the queue without a call to
+ // `ec_.Notify()` (if `IsNotifyParkedThreadRequired()` returned
+ // false), and we didn't steal anything above, we must try to
+ // steal one more time, to make sure that this task will be
+ // executed. We will not necessarily find it, because it might
+ // have been already stolen by some other thread.
+ if (has_no_notify_task && !t.f) {
+ t = GlobalSteal();
}
}
}
- if (t.f) {
+
+ // If we still don't have a task, wait for one. Return if thread pool is
+ // in cancelled state.
+ if (EIGEN_PREDICT_FALSE(!t.f)) {
+ if (!WaitForWork(waiter, &t)) {
+ return;
+ }
+ }
+
+ // Execute task if we found one.
+ if (EIGEN_PREDICT_TRUE(t.f)) {
env_.ExecuteTask(t);
}
}
@@ -433,6 +514,80 @@
return -1;
}
+ // StartSpinning() checks if the number of threads in the spin loop is less
+ // than the allowed maximum. If so, increments the number of spinning threads
+ // by one and returns true (caller must enter the spin loop). Otherwise
+ // returns false, and the caller must not enter the spin loop.
+ bool StartSpinning() {
+ if (NumActiveThreads() > kMinActiveThreadsToStartSpinning) return false;
+
+ uint64_t spinning = spinning_state_.load(std::memory_order_relaxed);
+ for (;;) {
+ SpinningState state = SpinningState::Decode(spinning);
+
+ if ((state.num_spinning - state.num_no_notification) >=
+ kMaxSpinningThreads) {
+ return false;
+ }
+
+ // Increment the number of spinning threads.
+ ++state.num_spinning;
+
+ if (spinning_state_.compare_exchange_weak(spinning, state.Encode(),
+ std::memory_order_relaxed)) {
+ return true;
+ }
+ }
+ }
+
+ // StopSpinning() decrements the number of spinning threads by one. It also
+ // checks if there were any tasks submitted into the pool without notifying
+ // parked threads, and decrements the count by one. Returns true if the number
+ // of tasks submitted without notification was decremented. In this case,
+ // caller thread might have to call Steal() one more time.
+ bool StopSpinning() {
+ uint64_t spinning = spinning_state_.load(std::memory_order_relaxed);
+ for (;;) {
+ SpinningState state = SpinningState::Decode(spinning);
+
+ // Decrement the number of spinning threads.
+ --state.num_spinning;
+
+ // Maybe decrement the number of tasks submitted without notification.
+ bool has_no_notify_task = state.num_no_notification > 0;
+ if (has_no_notify_task) --state.num_no_notification;
+
+ if (spinning_state_.compare_exchange_weak(spinning, state.Encode(),
+ std::memory_order_relaxed)) {
+ return has_no_notify_task;
+ }
+ }
+ }
+
+ // IsNotifyParkedThreadRequired() returns true if parked thread must be
+ // notified about new added task. If there are threads spinning in the steal
+ // loop, there is no need to unpark any of the waiting threads, the task will
+ // be picked up by one of the spinning threads.
+ bool IsNotifyParkedThreadRequired() {
+ uint64_t spinning = spinning_state_.load(std::memory_order_relaxed);
+ for (;;) {
+ SpinningState state = SpinningState::Decode(spinning);
+
+ // If the number of tasks submitted without notifying parked threads is
+ // equal to the number of spinning threads, we must wake up one of the
+ // parked threads.
+ if (state.num_no_notification == state.num_spinning) return true;
+
+ // Increment the number of tasks submitted without notification.
+ ++state.num_no_notification;
+
+ if (spinning_state_.compare_exchange_weak(spinning, state.Encode(),
+ std::memory_order_relaxed)) {
+ return false;
+ }
+ }
+ }
+
static EIGEN_STRONG_INLINE uint64_t GlobalThreadIdHash() {
return std::hash<std::thread::id>()(std::this_thread::get_id());
}