[llvm] 6608908 - [ADT] Extend EnumeratedArray

Jannik Silvanus via llvm-commits llvm-commits at lists.llvm.org
Tue Oct 18 08:09:30 PDT 2022


Author: Jannik Silvanus
Date: 2022-10-18T17:08:38+02:00
New Revision: 6608908b1b7fd9146f632b040c0b48c2fb661966

URL: https://github.com/llvm/llvm-project/commit/6608908b1b7fd9146f632b040c0b48c2fb661966
DIFF: https://github.com/llvm/llvm-project/commit/6608908b1b7fd9146f632b040c0b48c2fb661966.diff

LOG: [ADT] Extend EnumeratedArray

EnumeratedArray is essentially a wrapper around a fixed-size
array that uses enum values instead of integers as indices.

 * Add iterator support (begin/end/rbegin/rend), which enables
   the use of iterator/range based algorithms on EnumeratedArrays.
 * Add common container typedefs (value_type etc.), allowing
   EnumeratedArray to serve as a drop-in replacement for other
   containers in code that relies on these typedefs.
 * Add a constructor that takes an std::initializer_list<T>.
 * Make the size() function const.
 * Add empty().

Iterator support slightly weakens the protection against non-type-safe
accesses, because iterator arithmetic is not enum-based: one can now
write *(begin() + IntIndex). However, it has always been possible to
simply cast an arbitrary integer index to the enum type.

Differential Revision: https://reviews.llvm.org/D135594

Added: 
    

Modified: 
    llvm/include/llvm/ADT/EnumeratedArray.h
    llvm/unittests/ADT/EnumeratedArrayTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/ADT/EnumeratedArray.h b/llvm/include/llvm/ADT/EnumeratedArray.h
index cece71d1be245..fd0700c8e408a 100644
--- a/llvm/include/llvm/ADT/EnumeratedArray.h
+++ b/llvm/include/llvm/ADT/EnumeratedArray.h
@@ -16,6 +16,7 @@
 #define LLVM_ADT_ENUMERATEDARRAY_H
 
 #include <cassert>
+#include <iterator>
 
 namespace llvm {
 
@@ -24,14 +25,33 @@ template <typename ValueType, typename Enumeration,
           IndexType Size = 1 + static_cast<IndexType>(LargestEnum)>
 class EnumeratedArray {
 public:
+  using iterator = ValueType *;
+  using const_iterator = const ValueType *;
+
+  using const_reverse_iterator = std::reverse_iterator<const_iterator>;
+  using reverse_iterator = std::reverse_iterator<iterator>;
+
+  using value_type = ValueType;
+  using reference = ValueType &;
+  using const_reference = const ValueType &;
+  using pointer = ValueType *;
+  using const_pointer = const ValueType *;
+
   EnumeratedArray() = default;
   EnumeratedArray(ValueType V) {
     for (IndexType IX = 0; IX < Size; ++IX) {
       Underlying[IX] = V;
     }
   }
+  EnumeratedArray(std::initializer_list<ValueType> Init) {
+    assert(Init.size() == Size && "Incorrect initializer size");
+    for (IndexType IX = 0; IX < Size; ++IX) {
+      Underlying[IX] = *(Init.begin() + IX);
+    }
+  }
+
   const ValueType &operator[](Enumeration Index) const {
-    auto IX = static_cast<const IndexType>(Index);
+    auto IX = static_cast<IndexType>(Index);
     assert(IX >= 0 && IX < Size && "Index is out of bounds.");
     return Underlying[IX];
   }
@@ -40,7 +60,23 @@ class EnumeratedArray {
         static_cast<const EnumeratedArray<ValueType, Enumeration, LargestEnum,
                                           IndexType, Size> &>(*this)[Index]);
   }
-  IndexType size() { return Size; }
+  IndexType size() const { return Size; }
+  bool empty() const { return size() == 0; }
+
+  iterator begin() { return Underlying; }
+  const_iterator begin() const { return Underlying; }
+
+  iterator end() { return begin() + size(); }
+  const_iterator end() const { return begin() + size(); }
+
+  reverse_iterator rbegin() { return reverse_iterator(end()); }
+  const_reverse_iterator rbegin() const {
+    return const_reverse_iterator(end());
+  }
+  reverse_iterator rend() { return reverse_iterator(begin()); }
+  const_reverse_iterator rend() const {
+    return const_reverse_iterator(begin());
+  }
 
 private:
   ValueType Underlying[Size];

diff  --git a/llvm/unittests/ADT/EnumeratedArrayTest.cpp b/llvm/unittests/ADT/EnumeratedArrayTest.cpp
index 29107a7b9a7fa..9975428047141 100644
--- a/llvm/unittests/ADT/EnumeratedArrayTest.cpp
+++ b/llvm/unittests/ADT/EnumeratedArrayTest.cpp
@@ -11,7 +11,10 @@
 //===----------------------------------------------------------------------===//
 
 #include "llvm/ADT/EnumeratedArray.h"
+#include "llvm/ADT/iterator_range.h"
+#include "gmock/gmock.h"
 #include "gtest/gtest.h"
+#include <type_traits>
 
 namespace llvm {
 
@@ -46,6 +49,73 @@ TEST(EnumeratedArray, InitAndIndex) {
   EXPECT_TRUE(Array2[Colors::Red]);
   EXPECT_FALSE(Array2[Colors::Blue]);
   EXPECT_TRUE(Array2[Colors::Green]);
+
+  EnumeratedArray<float, Colors, Colors::Last, size_t> Array3 = {10.0, 11.0,
+                                                                 12.0};
+  EXPECT_EQ(Array3[Colors::Red], 10.0);
+  EXPECT_EQ(Array3[Colors::Blue], 11.0);
+  EXPECT_EQ(Array3[Colors::Green], 12.0);
+}
+
+//===--------------------------------------------------------------------===//
+// Test size and empty function
+//===--------------------------------------------------------------------===//
+
+TEST(EnumeratedArray, Size) {
+
+  enum class Colors { Red, Blue, Green, Last = Green };
+
+  EnumeratedArray<float, Colors, Colors::Last, size_t> Array;
+  const auto &ConstArray = Array;
+
+  EXPECT_EQ(ConstArray.size(), 3u);
+  EXPECT_EQ(ConstArray.empty(), false);
+}
+
+//===--------------------------------------------------------------------===//
+// Test iterators
+//===--------------------------------------------------------------------===//
+
+TEST(EnumeratedArray, Iterators) {
+
+  enum class Colors { Red, Blue, Green, Last = Green };
+
+  EnumeratedArray<float, Colors, Colors::Last, size_t> Array;
+  const auto &ConstArray = Array;
+
+  Array[Colors::Red] = 1.0;
+  Array[Colors::Blue] = 2.0;
+  Array[Colors::Green] = 3.0;
+
+  EXPECT_THAT(Array, testing::ElementsAre(1.0, 2.0, 3.0));
+  EXPECT_THAT(ConstArray, testing::ElementsAre(1.0, 2.0, 3.0));
+
+  EXPECT_THAT(make_range(Array.rbegin(), Array.rend()),
+              testing::ElementsAre(3.0, 2.0, 1.0));
+  EXPECT_THAT(make_range(ConstArray.rbegin(), ConstArray.rend()),
+              testing::ElementsAre(3.0, 2.0, 1.0));
 }
 
+//===--------------------------------------------------------------------===//
+// Test typedefs
+//===--------------------------------------------------------------------===//
+
+namespace {
+
+enum class Colors { Red, Blue, Green, Last = Green };
+
+using Array = EnumeratedArray<float, Colors, Colors::Last, size_t>;
+
+static_assert(std::is_same<Array::value_type, float>::value,
+              "Incorrect value_type type");
+static_assert(std::is_same<Array::reference, float &>::value,
+              "Incorrect reference type!");
+static_assert(std::is_same<Array::pointer, float *>::value,
+              "Incorrect pointer type!");
+static_assert(std::is_same<Array::const_reference, const float &>::value,
+              "Incorrect const_reference type!");
+static_assert(std::is_same<Array::const_pointer, const float *>::value,
+              "Incorrect const_pointer type!");
+} // namespace
+
 } // namespace llvm


        


More information about the llvm-commits mailing list