mirror of
https://github.com/zebrajr/pytorch.git
synced 2026-01-15 12:15:51 +00:00
Fix DataLoaderTest.EnforcesOrderingAmongThreadsWhenConfigured (#14038)
Summary: I think this will be it. So for one, the previous test was bullshit because it was returning the thread id instead of the sample index (which is the thing whose ordering is enforced). Just turning up the number of threads to 10 from 4 made this very obvious. I also think there is a race condition, which may or may not have surfaced, in that there was nothing stopping one worker to get multiple batches, which would screw with the whole ordering logic. I've added a barrier struct such that workers wait for all workers to be in the `get_batch` function before actually doing something. Fixes https://github.com/pytorch/pytorch/issues/14002 ezyang Pull Request resolved: https://github.com/pytorch/pytorch/pull/14038 Differential Revision: D13088132 Pulled By: goldsborough fbshipit-source-id: 4bded63756c6a49502ee07ef8709a03073e7e05f
This commit is contained in:
committed by
Facebook Github Bot
parent
f930c4307c
commit
8f4dc192b6
@@ -13,6 +13,7 @@
|
||||
#include <chrono>
|
||||
#include <future>
|
||||
#include <iostream>
|
||||
#include <iterator>
|
||||
#include <limits>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
@@ -833,57 +834,128 @@ TEST(DataLoaderTest, RespectsTimeout) {
|
||||
ASSERT_LT(duration.count(), 1);
|
||||
}
|
||||
|
||||
// https://stackoverflow.com/questions/24465533/implementing-boostbarrier-in-c11
|
||||
struct Barrier {
|
||||
explicit Barrier(size_t target) : counter_(target) {}
|
||||
void wait() {
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
if (--counter_ == 0) {
|
||||
cv_.notify_all();
|
||||
} else {
|
||||
cv_.wait(lock, [this] { return this->counter_ == 0; });
|
||||
}
|
||||
}
|
||||
|
||||
size_t counter_;
|
||||
std::condition_variable cv_;
|
||||
std::mutex mutex_;
|
||||
};
|
||||
|
||||
// On the OrderingTest: This test is intended to verify that the
|
||||
// `enforce_ordering` option of the dataloader works correctly. The reason this
|
||||
// flag exists is because when the dataloader has multiple workers (threads)
|
||||
// enabled and this flag is not set, the order in which worker threads finish
|
||||
// loading their respective batch and push it back to the dataloader's main
|
||||
// thread (for outside consumption) is not deterministic. Imagine the sampler is
|
||||
// a SequentialSampler with indices 0, 1, 2, 3. With batch size 1, each index
|
||||
// will be a single "job". Inside the dataloader, worker threads block until a
|
||||
// job is available. It is not deterministic which worker thread wakes up first
|
||||
// to dequeue a particular batch. Further, some worker threads may take longer
|
||||
// than others to read the data for their index. As such, it could be that
|
||||
// worker thread 2 finishes before all other threads and returns its batch to
|
||||
// the main thread. In that case, the dataloader iterator would return the datum
|
||||
// at index 2 first, and afterwards the datum from whatever thread finishes
|
||||
// next. As such, the user may see data from indices 2, 0, 3, 1. On another run
|
||||
// of the same dataloader on the same data, threads may be scheduled differently
|
||||
// and return in order 0, 2, 3, 1. To force this ordering to deterministically
|
||||
// be 0, 1, 2, 3, the `enforce_ordering` flag can be set to true. In that case,
|
||||
// the dataloader will use a *sequencer* internally which keeps track of which
|
||||
// datum is expected next, and buffers any other results until that next
|
||||
// expected value arrives. For example, workers 1, 2, 3 may finish before worker
|
||||
// 0. If `enforce_ordering` is true, the sequencer will internally buffer the
|
||||
// results from 1, 2, 3 until worker 0 finishes. Only then does the dataloader
|
||||
// return the datum from worker 0 to the user (and then datum 1 the next time,
|
||||
// then 2 and so on).
|
||||
//
|
||||
// The way the test works is that we start
|
||||
// `kNumberOfWorkers` workers in the dataloader, which each get an index from a
|
||||
// `SequentialSampler` in the range `0...kNumberOfWorkers-1`. Each worker thread
|
||||
// has a copy of the dataset, and thus `get_batch()` is called on the
|
||||
// thread-local copy in each worker. We want to simulate out-of-order completion
|
||||
// of these threads. For this, we first set a barrier in the `get_batch()`
|
||||
// method to make sure every worker has some index to fetch assigned. Further,
|
||||
// each worker thread has a unique ID in `0...kNumberOfWorkers-1`.
|
||||
// There is a hard-coded ordering, `kOrderInWhichWorkersReturnTheirBatch`, in
|
||||
// which we want the worker threads to return. For this, an iterator into this
|
||||
// order is maintained. When the derferenced iterator (the current order index)
|
||||
// matches the thread ID of a worker, it knows it can now return its index as
|
||||
// well as progress the iterator. Inside the dataloader, the sequencer should
|
||||
// buffer these indices such that they are ultimately returned in order.
|
||||
|
||||
namespace ordering_test {
|
||||
namespace {
|
||||
std::atomic<size_t> ordering_test_counter{0};
|
||||
std::condition_variable ordering_test_cv;
|
||||
std::mutex ordering_test_mutex;
|
||||
const std::array<size_t, 4> ordering_test_order = {3, 1, 0, 2};
|
||||
std::atomic<size_t> ordering_test_index{0};
|
||||
const size_t kNumberOfWorkers = 10;
|
||||
const std::vector<size_t> kOrderInWhichWorkersReturnTheirBatch =
|
||||
{3, 7, 0, 5, 4, 8, 2, 1, 9, 6};
|
||||
} // namespace
|
||||
|
||||
struct OrderingTestDataset : datasets::BatchDataset<DummyDataset, int> {
|
||||
OrderingTestDataset() = default;
|
||||
struct Dataset : datasets::BatchDataset<Dataset, size_t> {
|
||||
Dataset() = default;
|
||||
|
||||
// This copy constructor will be called when we copy the dataset into a
|
||||
// particular thread.
|
||||
OrderingTestDataset(const OrderingTestDataset& other)
|
||||
: id(ordering_test_counter++) {}
|
||||
Dataset(const Dataset& other) {
|
||||
static std::atomic<size_t> counter{0};
|
||||
thread_id_ = counter.fetch_add(1);
|
||||
}
|
||||
|
||||
OrderingTestDataset(OrderingTestDataset&& other) noexcept = default;
|
||||
OrderingTestDataset& operator=(const OrderingTestDataset& other) = delete;
|
||||
OrderingTestDataset& operator=(OrderingTestDataset&& other) noexcept = delete;
|
||||
Dataset(Dataset&& other) noexcept = default;
|
||||
Dataset& operator=(const Dataset& other) = delete;
|
||||
Dataset& operator=(Dataset&& other) noexcept = delete;
|
||||
|
||||
int get_batch(torch::ArrayRef<size_t> indices) override {
|
||||
std::unique_lock<std::mutex> lock(ordering_test_mutex);
|
||||
// block until order.at(index) == my_thread_id (until it's this thread's
|
||||
// turn)
|
||||
ordering_test_cv.wait(lock, [this] {
|
||||
return ordering_test_order.at(ordering_test_index.load()) == this->id;
|
||||
});
|
||||
// Make one step in the order.
|
||||
++ordering_test_index;
|
||||
size_t get_batch(torch::ArrayRef<size_t> indices) override {
|
||||
static Barrier barrier(kNumberOfWorkers);
|
||||
static auto order_iterator = kOrderInWhichWorkersReturnTheirBatch.begin();
|
||||
static std::condition_variable cv;
|
||||
static std::mutex mutex;
|
||||
|
||||
// Wait for all threads to get an index batch and arrive here.
|
||||
barrier.wait();
|
||||
|
||||
std::unique_lock<std::mutex> lock(mutex);
|
||||
cv.wait(lock, [this] { return *order_iterator == this->thread_id_; });
|
||||
++order_iterator;
|
||||
lock.unlock();
|
||||
// Wake up the other threads to check if it's their turn to return.
|
||||
ordering_test_cv.notify_all();
|
||||
cv.notify_all();
|
||||
|
||||
return id;
|
||||
return indices.front();
|
||||
}
|
||||
|
||||
torch::optional<size_t> size() const {
|
||||
return 4;
|
||||
torch::optional<size_t> size() const override {
|
||||
return kNumberOfWorkers;
|
||||
}
|
||||
|
||||
size_t id = 0;
|
||||
size_t thread_id_ = 0;
|
||||
};
|
||||
|
||||
} // namespace ordering_test
|
||||
|
||||
TEST(DataLoaderTest, EnforcesOrderingAmongThreadsWhenConfigured) {
|
||||
auto data_loader = torch::data::make_data_loader(
|
||||
OrderingTestDataset{},
|
||||
DataLoaderOptions().batch_size(1).workers(4).enforce_ordering(true));
|
||||
size_t index = 0;
|
||||
for (int value : *data_loader) {
|
||||
ASSERT_EQ(value, index++);
|
||||
ordering_test::Dataset{},
|
||||
DataLoaderOptions()
|
||||
.batch_size(1)
|
||||
.workers(ordering_test::kNumberOfWorkers)
|
||||
.enforce_ordering(true),
|
||||
torch::data::samplers::SequentialSampler(
|
||||
ordering_test::kNumberOfWorkers));
|
||||
std::vector<size_t> output;
|
||||
for (size_t value : *data_loader) {
|
||||
output.push_back(value);
|
||||
}
|
||||
std::vector<size_t> expected(ordering_test::kNumberOfWorkers);
|
||||
std::iota(expected.begin(), expected.end(), size_t(0));
|
||||
ASSERT_EQ(expected, output);
|
||||
}
|
||||
|
||||
TEST(DataLoaderTest, Reset) {
|
||||
|
||||
Reference in New Issue
Block a user