No public description

PiperOrigin-RevId: 654364452
Change-Id: I7189202a6135d5504068bbacd24be42adb689f7d
diff --git a/Eigen/src/ThreadPool/NonBlockingThreadPool.h b/Eigen/src/ThreadPool/NonBlockingThreadPool.h
index 0939dde..b3575c7 100644
--- a/Eigen/src/ThreadPool/NonBlockingThreadPool.h
+++ b/Eigen/src/ThreadPool/NonBlockingThreadPool.h
@@ -31,8 +31,8 @@
         all_coprimes_(num_threads),
         waiters_(num_threads),
         global_steal_partition_(EncodePartition(0, num_threads_)),
+        spinning_state_(0),
         blocked_(0),
-        spinning_(0),
         done_(false),
         cancelled_(false),
         ec_(waiters_) {
@@ -125,7 +125,9 @@
     // this. We expect that such scenario is prevented by program, that is,
     // this is kept alive while any threads can potentially be in Schedule.
     if (!t.f) {
-      ec_.Notify(false);
+      if (IsNotifyParkedThreadRequired()) {
+        ec_.Notify(false);
+      }
     } else {
       env_.ExecuteTask(t);  // Push failed, execute directly.
     }
@@ -165,8 +167,8 @@
   // Exposed publicly as static functions so that external callers can reuse
   // this encode/decode logic for maintaining their own thread-safe copies of
   // scheduling and steal domain(s).
-  static const int kMaxPartitionBits = 16;
-  static const int kMaxThreads = 1 << kMaxPartitionBits;
+  static constexpr int kMaxPartitionBits = 16;
+  static constexpr int kMaxThreads = 1 << kMaxPartitionBits;
 
   inline unsigned EncodePartition(unsigned start, unsigned limit) { return (start << kMaxPartitionBits) | limit; }
 
@@ -220,6 +222,59 @@
     Queue queue;
   };
 
+  // Maximum number of threads that can spin in steal loop.
+  static constexpr int kMaxSpinningThreads = 1;
+
+  // The number of steal loop spin iterations before parking (this number is
+  // divided by the number of threads, to get spin count for each thread).
+  static constexpr int kSpinCount = 5000;
+
+  // If there are enough active threads with empty pending-task queues, a thread
+  // that runs out of work can just be parked without spinning, because these
+  // active threads will go into a steal loop after finishing their current
+  // tasks.
+  //
+  // In the worst case when all active threads are executing long/expensive
+  // tasks, the next Schedule() will have to wait until one of the parked
+  // threads will be unparked, however this should be very rare in practice.
+  static constexpr int kMinActiveThreadsToStartSpinning = 4;
+
+  struct SpinningState {
+    // Spinning state layout:
+    //
+    // - Low 32 bits encode the number of threads that are spinning in steal
+    //   loop.
+    //
+    // - High 32 bits encode the number of tasks that were submitted to the pool
+    //   without a call to `ec_.Notify()`. This number can't be larger than
+    //   the number of spinning threads. Each spinning thread, when it exits the
+    //   spin loop must check if this number is greater than zero, and maybe
+    //   make another attempt to steal a task and decrement it by one.
+    static constexpr uint64_t kNumSpinningMask = 0x00000000FFFFFFFF;
+    static constexpr uint64_t kNumNoNotifyMask = 0xFFFFFFFF00000000;
+    static constexpr uint64_t kNumNoNotifyShift = 32;
+
+    uint64_t num_spinning;         // number of spinning threads
+    uint64_t num_no_notification;  // number of tasks submitted without
+                                   // notifying waiting threads
+
+    // Decodes `spinning_state_` value.
+    static SpinningState Decode(uint64_t state) {
+      uint64_t num_spinning = (state & kNumSpinningMask);
+      uint64_t num_no_notification =
+          (state & kNumNoNotifyMask) >> kNumNoNotifyShift;
+
+      assert(num_no_notification <= num_spinning);
+      return {num_spinning, num_no_notification};
+    }
+
+    // Encodes as `spinning_state_` value.
+    uint64_t Encode() const {
+      assert(num_no_notification <= num_spinning);
+      return (num_no_notification << kNumNoNotifyShift) | num_spinning;
+    }
+  };
+
   Environment env_;
   const int num_threads_;
   const bool allow_spinning_;
@@ -227,8 +282,8 @@
   MaxSizeVector<MaxSizeVector<unsigned>> all_coprimes_;
   MaxSizeVector<EventCount::Waiter> waiters_;
   unsigned global_steal_partition_;
+  std::atomic<uint64_t> spinning_state_;
   std::atomic<unsigned> blocked_;
-  std::atomic<bool> spinning_;
   std::atomic<bool> done_;
   std::atomic<bool> cancelled_;
   EventCount ec_;
@@ -238,6 +293,9 @@
   std::unordered_map<uint64_t, std::unique_ptr<PerThread>> per_thread_map_;
 #endif
 
+  unsigned NumBlockedThreads() const { return blocked_.load(); }
+  unsigned NumActiveThreads() const { return num_threads_ - blocked_.load(); }
+
   // Main worker thread loop.
   void WorkerLoop(int thread_id) {
 #ifndef EIGEN_THREAD_LOCAL
@@ -258,9 +316,12 @@
     EventCount::Waiter* waiter = &waiters_[thread_id];
     // TODO(dvyukov,rmlarsen): The time spent in NonEmptyQueueIndex() is
     // proportional to num_threads_ and we assume that new work is scheduled at
-    // a constant rate, so we set spin_count to 5000 / num_threads_. The
-    // constant was picked based on a fair dice roll, tune it.
-    const int spin_count = allow_spinning_ && num_threads_ > 0 ? 5000 / num_threads_ : 0;
+    // a constant rate, so we divide `kSpintCount` by number of threads and
+    // number of spinning threads. The constant was picked based on a fair dice
+    // roll, tune it.
+    const int spin_count = allow_spinning_ && num_threads_ > 0
+                               ? kSpinCount / kMaxSpinningThreads / num_threads_
+                               : 0;
     if (num_threads_ == 1) {
       // For num_threads_ == 1 there is no point in going through the expensive
       // steal loop. Moreover, since NonEmptyQueueIndex() calls PopBack() on the
@@ -268,50 +329,70 @@
       // compared to the order in which they are scheduled, which tends to be
       // counter-productive for the types of I/O workloads the single thread
       // pools tend to be used for.
-      while (!cancelled_) {
+      while (!cancelled_.load(std::memory_order_relaxed)) {
         Task t = q.PopFront();
-        for (int i = 0; i < spin_count && !t.f; i++) {
-          if (!cancelled_.load(std::memory_order_relaxed)) {
-            t = q.PopFront();
-          }
+
+        for (int i = 0; i < spin_count && !t.f; ++i) {
+          t = q.PopFront();
         }
-        if (!t.f) {
+
+        if (EIGEN_PREDICT_FALSE(!t.f)) {
           if (!WaitForWork(waiter, &t)) {
             return;
           }
         }
-        if (t.f) {
+
+        if (EIGEN_PREDICT_TRUE(t.f)) {
           env_.ExecuteTask(t);
         }
       }
+
     } else {
-      while (!cancelled_) {
+      while (!cancelled_.load(std::memory_order_relaxed)) {
         Task t = q.PopFront();
-        if (!t.f) {
+
+        // Do one round of steal loop from local thread partition.
+        if (EIGEN_PREDICT_FALSE(!t.f)) {
           t = LocalSteal();
-          if (!t.f) {
-            t = GlobalSteal();
-            if (!t.f) {
-              // Leave one thread spinning. This reduces latency.
-              if (allow_spinning_ && !spinning_ && !spinning_.exchange(true)) {
-                for (int i = 0; i < spin_count && !t.f; i++) {
-                  if (!cancelled_.load(std::memory_order_relaxed)) {
-                    t = GlobalSteal();
-                  } else {
-                    return;
-                  }
-                }
-                spinning_ = false;
-              }
-              if (!t.f) {
-                if (!WaitForWork(waiter, &t)) {
-                  return;
-                }
-              }
+        }
+
+        // Do one round of steal loop from global thread partition.
+        if (EIGEN_PREDICT_FALSE(!t.f)) {
+          t = GlobalSteal();
+        }
+
+        // Maybe leave a thread spinning. This improves latency.
+        if (EIGEN_PREDICT_FALSE(!t.f)) {
+          if (allow_spinning_ && StartSpinning()) {
+            for (int i = 0; i < spin_count && !t.f; ++i) {
+              t = GlobalSteal();
+            }
+
+            // Notify `spinning_state_` that we are no longer spinning.
+            bool has_no_notify_task = StopSpinning();
+
+            // If a task was submitted to the queue without a call to
+            // `ec_.Notify()` (if `IsNotifyParkedThreadRequired()` returned
+            // false), and we didn't steal anything above, we must try to
+            // steal one more time, to make sure that this task will be
+            // executed. We will not necessarily find it, because it might
+            // have been already stolen by some other thread.
+            if (has_no_notify_task && !t.f) {
+              t = GlobalSteal();
             }
           }
         }
-        if (t.f) {
+
+        // If we still don't have a task, wait for one. Return if thread pool is
+        // in cancelled state.
+        if (EIGEN_PREDICT_FALSE(!t.f)) {
+          if (!WaitForWork(waiter, &t)) {
+            return;
+          }
+        }
+
+        // Execute task if we found one.
+        if (EIGEN_PREDICT_TRUE(t.f)) {
           env_.ExecuteTask(t);
         }
       }
@@ -433,6 +514,80 @@
     return -1;
   }
 
+  // StartSpinning() checks if the number of threads in the spin loop is less
+  // than the allowed maximum. If so, increments the number of spinning threads
+  // by one and returns true (caller must enter the spin loop). Otherwise
+  // returns false, and the caller must not enter the spin loop.
+  bool StartSpinning() {
+    if (NumActiveThreads() > kMinActiveThreadsToStartSpinning) return false;
+
+    uint64_t spinning = spinning_state_.load(std::memory_order_relaxed);
+    for (;;) {
+      SpinningState state = SpinningState::Decode(spinning);
+
+      if ((state.num_spinning - state.num_no_notification) >=
+          kMaxSpinningThreads) {
+        return false;
+      }
+
+      // Increment the number of spinning threads.
+      ++state.num_spinning;
+
+      if (spinning_state_.compare_exchange_weak(spinning, state.Encode(),
+                                                std::memory_order_relaxed)) {
+        return true;
+      }
+    }
+  }
+
+  // StopSpinning() decrements the number of spinning threads by one. It also
+  // checks if there were any tasks submitted into the pool without notifying
+  // parked threads, and decrements the count by one. Returns true if the number
+  // of tasks submitted without notification was decremented. In this case,
+  // caller thread might have to call Steal() one more time.
+  bool StopSpinning() {
+    uint64_t spinning = spinning_state_.load(std::memory_order_relaxed);
+    for (;;) {
+      SpinningState state = SpinningState::Decode(spinning);
+
+      // Decrement the number of spinning threads.
+      --state.num_spinning;
+
+      // Maybe decrement the number of tasks submitted without notification.
+      bool has_no_notify_task = state.num_no_notification > 0;
+      if (has_no_notify_task) --state.num_no_notification;
+
+      if (spinning_state_.compare_exchange_weak(spinning, state.Encode(),
+                                                std::memory_order_relaxed)) {
+        return has_no_notify_task;
+      }
+    }
+  }
+
+  // IsNotifyParkedThreadRequired() returns true if parked thread must be
+  // notified about new added task. If there are threads spinning in the steal
+  // loop, there is no need to unpark any of the waiting threads, the task will
+  // be picked up by one of the spinning threads.
+  bool IsNotifyParkedThreadRequired() {
+    uint64_t spinning = spinning_state_.load(std::memory_order_relaxed);
+    for (;;) {
+      SpinningState state = SpinningState::Decode(spinning);
+
+      // If the number of tasks submitted without notifying parked threads is
+      // equal to the number of spinning threads, we must wake up one of the
+      // parked threads.
+      if (state.num_no_notification == state.num_spinning) return true;
+
+      // Increment the number of tasks submitted without notification.
+      ++state.num_no_notification;
+
+      if (spinning_state_.compare_exchange_weak(spinning, state.Encode(),
+                                                std::memory_order_relaxed)) {
+        return false;
+      }
+    }
+  }
+
   static EIGEN_STRONG_INLINE uint64_t GlobalThreadIdHash() {
     return std::hash<std::thread::id>()(std::this_thread::get_id());
   }