diff --git a/test/cpp/c10d/ProcessGroupNCCLErrorsTest.cpp b/test/cpp/c10d/ProcessGroupNCCLErrorsTest.cpp
index 835b92aff90..322ddfabc4b 100644
--- a/test/cpp/c10d/ProcessGroupNCCLErrorsTest.cpp
+++ b/test/cpp/c10d/ProcessGroupNCCLErrorsTest.cpp
@@ -187,6 +187,15 @@ class ProcessGroupNCCLNoHeartbeatCaught
     return hasMonitorThreadCaughtError_;
   }
 
+  // Test helper: calls tryWriteDebugInfo() and, if it hands back a dump
+  // thread, blocks until that thread has finished.
+  void forceTryWriteDebugInfo() {
+    auto thread = tryWriteDebugInfo();
+    if (thread) {
+      thread->join();
+    }
+  }
+
  protected:
   // Override the heartbeat monitor function to make sure that we capture
   // the exception in the monitor thread because we cannot try-catch it in
@@ -482,6 +491,11 @@ TEST_F(ProcessGroupNCCLWatchdogTimeoutTest, testNCCLTimedoutDebugInfoFinished) {
   }
 
   ProcessGroupNCCLNoHeartbeatCaught pg(store_, 0, 1, options_);
+  // Dumping debug info makes the watchdog thread wait for 30 seconds, and
+  // that wait is hard to override. Trigger the dump up front so the watchdog
+  // skips it; otherwise the test would need a much longer heartbeat timeout
+  // and run far slower.
+  pg.forceTryWriteDebugInfo();
   watchdogTimeoutTestCommon(pg, 2);
 
   // The flag is true shows that the heartbeat monitor thread does not kill
diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
index a5e736c4ae2..73b2de2648d 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
@@ -1208,6 +1208,18 @@ void ProcessGroupNCCL::waitForPendingWorks() {
 void ProcessGroupNCCL::enableCollectivesTiming() {
   enableTiming_.store(true);
 }
+
+c10::optional<std::thread> ProcessGroupNCCL::tryWriteDebugInfo() {
+  std::lock_guard<std::mutex> lock(writeDebugInfoMutex_);
+  if (writeDebugInfo_) {
+    return c10::nullopt;
+  }
+  // First caller: latch the flag and hand back the thread running the dump.
+  writeDebugInfo_ = true;
+  return c10::optional<std::thread>(
+      std::thread(&ProcessGroupNCCL::dumpDebuggingInfo, this));
+}
+
 void abortCommsFromMap(
     std::unordered_map<std::string, std::vector<std::shared_ptr<NCCLComm>>>&
         ncclCommsMap,
@@ -1360,8 +1372,9 @@ void ProcessGroupNCCL::heartbeatMonitor() {
     }
   }
 
-  // Store debug info to storage. (By default to local disk)
-  std::thread debugInfoStoreThread(&ProcessGroupNCCL::dumpDebuggingInfo, this);
+  // Store debug info to storage if no other thread does it. (By default to
+  // local disk)
+  auto maybeWriteDebugInfo = tryWriteDebugInfo();
 
   // Create a error message reported from MonitorThread, so
   // we throw exception and make the whole process to be killed.
@@ -1382,7 +1395,7 @@ void ProcessGroupNCCL::heartbeatMonitor() {
   // destructors, we will sleep for some time before calling std::abort() to
   // kill the whole process.
   if ((terminateProcessGroup_.load() || collectiveDebugInfoMode_.load() ||
-       debugInfoStoreThread.joinable()) &&
+       (maybeWriteDebugInfo && maybeWriteDebugInfo->joinable())) &&
       !terminateHeartbeatMonitorThread_.load()) {
     // Leave another two mins for desync report generation or process group
     // destroy.
@@ -1394,8 +1407,8 @@ void ProcessGroupNCCL::heartbeatMonitor() {
   // thread, so We mark the thread detach and the dump of debug info becomes
   // "best effort". If the process exit normally, marking it detach also makes
   // sense because we don't really care about dumping the debug info.
-  if (debugInfoStoreThread.joinable()) {
-    debugInfoStoreThread.detach();
+  if (maybeWriteDebugInfo && maybeWriteDebugInfo->joinable()) {
+    maybeWriteDebugInfo->detach();
   }
 
   if (!terminateHeartbeatMonitorThread_.load()) {
@@ -1515,14 +1528,30 @@ void ProcessGroupNCCL::workCleanupLoop() {
           // rank
           abort();
         }
+
         // Report desync state in case of timeout
         if (desyncDebug_ && timedOut) {
           try {
-            // Set shutdown mode, so the heartbeat monitor thread will not abort
-            // process immediately.
+            // Set shutdown mode, so the heartbeat monitor thread will not
+            // abort process immediately.
             collectiveDebugInfoMode_.store(true);
+            // Store debug info to storage. (By default to local disk)
+            auto dumpingDebugInfo = tryWriteDebugInfo();
+            if (dumpingDebugInfo) {
+              // Detach right away: if the report below throws, unwinding would
+              // destroy a still-joinable std::thread and call std::terminate()
+              // before the catch handler could run.
+              dumpingDebugInfo->detach();
+            }
             auto desyncMsg = getNCCLWatchdogDebugInfo();
             LOG(ERROR) << desyncMsg;
+            if (dumpingDebugInfo) {
+              // The dump is "best effort": wait
+              // `kWatchdogThreadSleepMillis * 30` ms so the detached thread can
+              // make progress, then move on whether or not it has finished.
+              std::this_thread::sleep_for(
+                  std::chrono::milliseconds(kWatchdogThreadSleepMillis * 30));
+            }
           } catch (const std::exception& e) {
             LOG(ERROR) << "Failed to retrieve NCCL_DESYNC_DEBUG report. "
                        << " Please file an issue. Error: " << e.what();
diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp
index c06c7e80c4c..75bf9e4843f 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp
@@ -660,6 +660,12 @@ class TORCH_API ProcessGroupNCCL : public Backend {
   // gets terminated.
   virtual void terminateProcess(std::string errMsg);
 
+  // Claims the one-time dump of debug info. Under writeDebugInfoMutex_, checks
+  // writeDebugInfo_: if it is already set, returns c10::nullopt. Otherwise sets
+  // it and returns a thread running dumpDebuggingInfo(); the caller owns that
+  // thread and must join or detach it before it is destroyed.
+  c10::optional<std::thread> tryWriteDebugInfo();
+
   // When watchdog timeout, this function will be called and return debug info
   // for users. For now we only get information from retrieveDesyncReport.
   // We are working on enabling more useful debug information for watchdog
@@ -766,6 +772,12 @@ class TORCH_API ProcessGroupNCCL : public Backend {
   // Mutex to Guard monitorWakeUpCV_
   std::mutex monitorMutex_;
 
+  // Set once a debug-info dump has been claimed via tryWriteDebugInfo().
+  bool writeDebugInfo_ = false;
+
+  // Mutex to guard reads and writes of writeDebugInfo_
+  std::mutex writeDebugInfoMutex_;
+
   // Condition Variable for watchdog thread sleep
   std::condition_variable workMetaListCV_;
 