[llvm] 6608908 - [ADT] Extend EnumeratedArray
Jannik Silvanus via llvm-commits
llvm-commits at lists.llvm.org
Tue Oct 18 08:09:30 PDT 2022
Author: Jannik Silvanus
Date: 2022-10-18T17:08:38+02:00
New Revision: 6608908b1b7fd9146f632b040c0b48c2fb661966
URL: https://github.com/llvm/llvm-project/commit/6608908b1b7fd9146f632b040c0b48c2fb661966
DIFF: https://github.com/llvm/llvm-project/commit/6608908b1b7fd9146f632b040c0b48c2fb661966.diff
LOG: [ADT] Extend EnumeratedArray
EnumeratedArray is essentially a wrapper around a fixed-size
array that uses enum values instead of integers as indices.
* Add iterator support (begin/end/rbegin/rend), which enables
the use of iterator/range based algorithms on EnumeratedArrays.
* Add common container typedefs (value_type etc.), allowing
drop-in replacements of other containers in cases relying on these.
* Add a constructor that takes an std::initializer_list<T>.
* Make the size() function const.
* Add empty().
Iterator support slightly weakens the protection against non-type-safe
accesses, because iterator arithmetic is not enum-based: one can now write
*(begin() + IntIndex). However, it has always been possible to simply cast
arbitrary integer indices to the enum type.
Differential Revision: https://reviews.llvm.org/D135594
Added:
Modified:
llvm/include/llvm/ADT/EnumeratedArray.h
llvm/unittests/ADT/EnumeratedArrayTest.cpp
Removed:
################################################################################
diff --git a/llvm/include/llvm/ADT/EnumeratedArray.h b/llvm/include/llvm/ADT/EnumeratedArray.h
index cece71d1be245..fd0700c8e408a 100644
--- a/llvm/include/llvm/ADT/EnumeratedArray.h
+++ b/llvm/include/llvm/ADT/EnumeratedArray.h
@@ -16,6 +16,7 @@
#define LLVM_ADT_ENUMERATEDARRAY_H
#include <cassert>
+#include <initializer_list>
+#include <iterator>
namespace llvm {
@@ -24,14 +25,33 @@ template <typename ValueType, typename Enumeration,
IndexType Size = 1 + static_cast<IndexType>(LargestEnum)>
class EnumeratedArray {
public:
+  // Element-level container typedefs.
+  using value_type = ValueType;
+  using pointer = ValueType *;
+  using const_pointer = const ValueType *;
+  using reference = ValueType &;
+  using const_reference = const ValueType &;
+
+  // Iterators are raw element pointers; the reverse variants adapt them with
+  // std::reverse_iterator.
+  using iterator = pointer;
+  using const_iterator = const_pointer;
+  using reverse_iterator = std::reverse_iterator<iterator>;
+  using const_reverse_iterator = std::reverse_iterator<const_iterator>;
+
EnumeratedArray() = default;
EnumeratedArray(ValueType V) {
for (IndexType IX = 0; IX < Size; ++IX) {
Underlying[IX] = V;
}
}
+  /// Construct from a list of element values given in index order: element
+  /// IX receives the IX-th entry of \p Init.
+  ///
+  /// \p Init is required to hold exactly Size values (asserted). When
+  /// assertions are disabled, the copy stops at whichever of Size or
+  /// Init.size() is reached first, so a mismatched list never reads past the
+  /// end of \p Init. Elements with no corresponding entry are left
+  /// default-initialized.
+  EnumeratedArray(std::initializer_list<ValueType> Init) {
+    assert(Init.size() == Size && "Incorrect initializer size");
+    auto It = Init.begin();
+    for (IndexType IX = 0; IX < Size && It != Init.end(); ++IX, ++It)
+      Underlying[IX] = *It;
+  }
+
const ValueType &operator[](Enumeration Index) const {
- auto IX = static_cast<const IndexType>(Index);
+ auto IX = static_cast<IndexType>(Index);
assert(IX >= 0 && IX < Size && "Index is out of bounds.");
return Underlying[IX];
}
@@ -40,7 +60,23 @@ class EnumeratedArray {
static_cast<const EnumeratedArray<ValueType, Enumeration, LargestEnum,
IndexType, Size> &>(*this)[Index]);
}
- IndexType size() { return Size; }
+  /// Number of elements, fixed by the Size template parameter.
+  IndexType size() const { return Size; }
+  /// Whether the array holds no elements.
+  bool empty() const { return Size == 0; }
+
+  /// Forward iteration over the elements in index order.
+  iterator begin() { return Underlying; }
+  iterator end() { return Underlying + Size; }
+  const_iterator begin() const { return Underlying; }
+  const_iterator end() const { return Underlying + Size; }
+
+  /// Reverse iteration, from the last index down to the first.
+  reverse_iterator rbegin() { return reverse_iterator{end()}; }
+  reverse_iterator rend() { return reverse_iterator{begin()}; }
+  const_reverse_iterator rbegin() const {
+    return const_reverse_iterator{end()};
+  }
+  const_reverse_iterator rend() const {
+    return const_reverse_iterator{begin()};
+  }
private:
ValueType Underlying[Size];
diff --git a/llvm/unittests/ADT/EnumeratedArrayTest.cpp b/llvm/unittests/ADT/EnumeratedArrayTest.cpp
index 29107a7b9a7fa..9975428047141 100644
--- a/llvm/unittests/ADT/EnumeratedArrayTest.cpp
+++ b/llvm/unittests/ADT/EnumeratedArrayTest.cpp
@@ -11,7 +11,10 @@
//===----------------------------------------------------------------------===//
#include "llvm/ADT/EnumeratedArray.h"
+#include "llvm/ADT/iterator_range.h"
+#include "gmock/gmock.h"
#include "gtest/gtest.h"
+#include <type_traits>
namespace llvm {
@@ -46,6 +49,73 @@ TEST(EnumeratedArray, InitAndIndex) {
EXPECT_TRUE(Array2[Colors::Red]);
EXPECT_FALSE(Array2[Colors::Blue]);
EXPECT_TRUE(Array2[Colors::Green]);
+
+ EnumeratedArray<float, Colors, Colors::Last, size_t> Array3 = {10.0, 11.0,
+ 12.0};
+ EXPECT_EQ(Array3[Colors::Red], 10.0);
+ EXPECT_EQ(Array3[Colors::Blue], 11.0);
+ EXPECT_EQ(Array3[Colors::Green], 12.0);
+}
+
+//===--------------------------------------------------------------------===//
+// Test size and empty functions
+//===--------------------------------------------------------------------===//
+
+TEST(EnumeratedArray, Size) {
+
+  enum class Colors { Red, Blue, Green, Last = Green };
+
+  EnumeratedArray<float, Colors, Colors::Last, size_t> Array;
+  const auto &ConstArray = Array;
+
+  // size() and empty() must be callable on both const and mutable arrays.
+  EXPECT_EQ(Array.size(), 3u);
+  EXPECT_EQ(ConstArray.size(), 3u);
+  EXPECT_FALSE(ConstArray.empty());
+}
+
+//===--------------------------------------------------------------------===//
+// Test iterators
+//===--------------------------------------------------------------------===//
+
+TEST(EnumeratedArray, Iterators) {
+
+  enum class Colors { Red, Blue, Green, Last = Green };
+
+  EnumeratedArray<float, Colors, Colors::Last, size_t> Array;
+  const auto &ConstArray = Array;
+
+  Array[Colors::Red] = 1.0;
+  Array[Colors::Blue] = 2.0;
+  Array[Colors::Green] = 3.0;
+
+  // Forward and reverse traversal through both const and mutable access.
+  EXPECT_THAT(Array, testing::ElementsAre(1.0, 2.0, 3.0));
+  EXPECT_THAT(ConstArray, testing::ElementsAre(1.0, 2.0, 3.0));
+
+  EXPECT_THAT(make_range(Array.rbegin(), Array.rend()),
+              testing::ElementsAre(3.0, 2.0, 1.0));
+  EXPECT_THAT(make_range(ConstArray.rbegin(), ConstArray.rend()),
+              testing::ElementsAre(3.0, 2.0, 1.0));
+
+  // Writes through mutable iterators must be visible via enum indexing.
+  *Array.begin() = 4.0;
+  *(Array.begin() + 1) = 5.0;
+  *Array.rbegin() = 6.0;
+  EXPECT_EQ(Array[Colors::Red], 4.0);
+  EXPECT_EQ(Array[Colors::Blue], 5.0);
+  EXPECT_EQ(Array[Colors::Green], 6.0);
}
+//===--------------------------------------------------------------------===//
+// Test typedefs
+//===--------------------------------------------------------------------===//
+
+namespace {
+
+enum class Colors { Red, Blue, Green, Last = Green };
+
+using Array = EnumeratedArray<float, Colors, Colors::Last, size_t>;
+
+static_assert(std::is_same<Array::value_type, float>::value,
+              "Incorrect value_type type");
+static_assert(std::is_same<Array::reference, float &>::value,
+              "Incorrect reference type!");
+static_assert(std::is_same<Array::pointer, float *>::value,
+              "Incorrect pointer type!");
+static_assert(std::is_same<Array::const_reference, const float &>::value,
+              "Incorrect const_reference type!");
+static_assert(std::is_same<Array::const_pointer, const float *>::value,
+              "Incorrect const_pointer type!");
+
+// Iterator typedefs: raw element pointers, with the reverse variants
+// adapting exactly those pointer types.
+static_assert(std::is_same<Array::iterator, float *>::value,
+              "Incorrect iterator type!");
+static_assert(std::is_same<Array::const_iterator, const float *>::value,
+              "Incorrect const_iterator type!");
+static_assert(
+    std::is_same<Array::reverse_iterator::iterator_type, float *>::value,
+    "Incorrect reverse_iterator type!");
+static_assert(std::is_same<Array::const_reverse_iterator::iterator_type,
+                           const float *>::value,
+              "Incorrect const_reverse_iterator type!");
+} // namespace
+
} // namespace llvm
More information about the llvm-commits
mailing list