diff --git a/c10/test/util/irange_test.cpp b/c10/test/util/irange_test.cpp index 8bf2f38aeb4..66fa4d4d786 100644 --- a/c10/test/util/irange_test.cpp +++ b/c10/test/util/irange_test.cpp @@ -4,6 +4,8 @@ #include +#include + using namespace ::testing; TEST(irangeTest, range_test) { @@ -56,3 +58,31 @@ TEST(irange, empty_reverse_range_one_input) { const std::vector correct = {}; ASSERT_EQ(test_vec, correct); } + +constexpr std::array toy_iota() { + std::array result = {0}; + for (const auto i : c10::irange(3)) { + result[i] = i; + } + return result; +} + +constexpr std::array toy_iota_with_start(int start) { + std::array result = {0}; + for (const auto i : c10::irange(start, start + 3)) { + result[i - start] = i; + } + return result; +} + +TEST(irange, constexpr_ok) { + constexpr auto arr = toy_iota(); + static_assert(arr[0] == 0); + static_assert(arr[1] == 1); + static_assert(arr[2] == 2); + + constexpr auto arr2 = toy_iota_with_start(4); + static_assert(arr2[0] == 4); + static_assert(arr2[1] == 5); + static_assert(arr2[2] == 6); +} diff --git a/c10/util/irange.h b/c10/util/irange.h index 2719a82075c..f5310510099 100644 --- a/c10/util/irange.h +++ b/c10/util/irange.h @@ -24,28 +24,28 @@ struct integer_iterator { using pointer = I*; using reference = I&; - explicit integer_iterator(I value) : value(value) {} + explicit constexpr integer_iterator(I value) : value(value) {} - I operator*() const { + constexpr I operator*() const { return value; } - I const* operator->() const { + constexpr I const* operator->() const { return &value; } - integer_iterator& operator++() { + constexpr integer_iterator& operator++() { ++value; return *this; } - integer_iterator operator++(int) { + constexpr integer_iterator operator++(int) { const auto copy = *this; ++*this; return copy; } - bool operator==(const integer_iterator& other) const { + constexpr bool operator==(const integer_iterator& other) const { if constexpr (one_sided) { // Range-for loops' end test is `begin != end`, not `begin < // end`. To handle `c10::irange(n)` where n < 0 (which should be @@ -64,7 +64,7 @@ struct integer_iterator { return false; // Horrible hack } - bool operator!=(const integer_iterator& other) const { + constexpr bool operator!=(const integer_iterator& other) const { return !(*this == other); } @@ -80,12 +80,12 @@ template < std::enable_if_t, bool> = true> struct integer_range { public: - integer_range(I begin, I end) : begin_(begin), end_(end) {} + constexpr integer_range(I begin, I end) : begin_(begin), end_(end) {} using iterator = detail::integer_iterator; - iterator begin() const { + constexpr iterator begin() const { return begin_; } - iterator end() const { + constexpr iterator end() const { return end_; } @@ -103,7 +103,7 @@ template < typename Integer2, std::enable_if_t, bool> = true, std::enable_if_t, bool> = true> -integer_range irange(Integer1 begin, Integer2 end) { +constexpr integer_range irange(Integer1 begin, Integer2 end) { // If end<=begin then the range is empty; we can achieve this effect by // choosing the larger of {begin, end} as the loop terminator return { @@ -116,7 +116,7 @@ integer_range irange(Integer1 begin, Integer2 end) { template < typename Integer, std::enable_if_t, bool> = true> -integer_range irange(Integer end) { +constexpr integer_range irange(Integer end) { return {Integer(), end}; }