[llvm] DenseMap: support enum class keys (PR #95972)

Ramkumar Ramachandra via llvm-commits llvm-commits at lists.llvm.org
Tue Jun 18 12:10:03 PDT 2024


https://github.com/artagnon created https://github.com/llvm/llvm-project/pull/95972

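Implement DenseMapInfo for enum class keys by forwarding to the DenseMapInfo of the enum's underlying type (obtained with std::underlying_type).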

>From e579faa93faa5f38ebc7d016791a7c368ffd585f Mon Sep 17 00:00:00 2001
From: Ramkumar Ramachandra <ramkumar.ramachandra at codasip.com>
Date: Tue, 18 Jun 2024 19:27:45 +0100
Subject: [PATCH] DenseMap: support enum class keys

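Implement DenseMapInfo for enum class keys by forwarding to the DenseMapInfo of the enum's underlying type (obtained with std::underlying_type).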
---
 llvm/include/llvm/ADT/DenseMapInfo.h | 21 +++++++++++++++++++++
 llvm/unittests/ADT/MapVectorTest.cpp | 24 ++++++++++++++++++++++++
 2 files changed, 45 insertions(+)

diff --git a/llvm/include/llvm/ADT/DenseMapInfo.h b/llvm/include/llvm/ADT/DenseMapInfo.h
index 5b7dce7b53c62..14440d123289e 100644
--- a/llvm/include/llvm/ADT/DenseMapInfo.h
+++ b/llvm/include/llvm/ADT/DenseMapInfo.h
@@ -297,6 +297,42 @@ template <typename... Ts> struct DenseMapInfo<std::tuple<Ts...>> {
   }
 };
 
+// Provide DenseMapInfo for scoped enumerations (enum class), forwarding to the
+// DenseMapInfo of the enum's underlying type, which must therefore exist.
+//
+// The empty and tombstone keys are the underlying type's sentinels cast to the
+// enum. That cast is only well-defined if the sentinel is a value of the enum.
+// An enum with a fixed underlying type (every scoped enum, and unscoped enums
+// declared as `enum E : T`) has all values of that type, but an unscoped enum
+// without one (e.g. `enum E { A, B };`) only has the values needed by its
+// enumerators, so casting a sentinel to it is undefined behavior. Only scoped
+// enums are accepted: they are exactly the enumerations that do not implicitly
+// convert to int. std::conjunction short-circuits, so is_convertible is only
+// instantiated for enumeration types.
+template <typename Enum>
+struct DenseMapInfo<
+    Enum, std::enable_if_t<std::conjunction_v<
+              std::is_enum<Enum>,
+              std::negation<std::is_convertible<Enum, int>>>>> {
+  using UnderlyingType = std::underlying_type_t<Enum>;
+  using Info = DenseMapInfo<UnderlyingType>;
+
+  static Enum getEmptyKey() { return static_cast<Enum>(Info::getEmptyKey()); }
+
+  static Enum getTombstoneKey() {
+    return static_cast<Enum>(Info::getTombstoneKey());
+  }
+
+  static unsigned getHashValue(const Enum &Val) {
+    return Info::getHashValue(static_cast<UnderlyingType>(Val));
+  }
+
+  static bool isEqual(const Enum &LHS, const Enum &RHS) {
+    return Info::isEqual(static_cast<UnderlyingType>(LHS),
+                         static_cast<UnderlyingType>(RHS));
+  }
+};
+
 } // end namespace llvm
 
 #endif // LLVM_ADT_DENSEMAPINFO_H
diff --git a/llvm/unittests/ADT/MapVectorTest.cpp b/llvm/unittests/ADT/MapVectorTest.cpp
index e0f11b60a0223..2c61acb29ce20 100644
--- a/llvm/unittests/ADT/MapVectorTest.cpp
+++ b/llvm/unittests/ADT/MapVectorTest.cpp
@@ -267,6 +267,44 @@ TEST(MapVectorTest, NonCopyable) {
   ASSERT_EQ(*MV.find(2)->second, 2);
 }
 
+// DenseMapInfo for enum classes forwards to the DenseMapInfo of the enum's
+// underlying type. Use one enum with the default (int) underlying type and one
+// with a narrow fixed underlying type so two different Info specializations
+// are exercised, and check that no enumerator collides with a sentinel key.
+TEST(MapVectorTest, EnumClassKey) {
+  enum class EC1 { ValA, ValB };
+  enum class EC2 : unsigned char { ValA, ValB };
+  MapVector<EC1, int> MV1;
+  MapVector<EC2, int> MV2;
+
+  ASSERT_TRUE(MV1.empty());
+  ASSERT_TRUE(MV2.empty());
+  MV1.insert({EC1::ValA, 13});
+  MV1.insert({EC1::ValB, 7});
+  MV2.insert({EC2::ValA, 42});
+  MV2.insert({EC2::ValB, 43});
+
+  ASSERT_EQ(MV1.size(), 2u);
+  ASSERT_EQ(MV2.size(), 2u);
+  ASSERT_EQ(MV1.count(EC1::ValA), 1u);
+  ASSERT_EQ(MV2.count(EC2::ValA), 1u);
+  ASSERT_EQ(MV1[EC1::ValB], 7);
+  ASSERT_EQ(MV2[EC2::ValB], 43);
+
+  ASSERT_NE(DenseMapInfo<EC1>::getHashValue(EC1::ValA),
+            DenseMapInfo<EC1>::getHashValue(EC1::ValB));
+  ASSERT_NE(DenseMapInfo<EC2>::getHashValue(EC2::ValA),
+            DenseMapInfo<EC2>::getHashValue(EC2::ValB));
+  for (EC1 V : {EC1::ValA, EC1::ValB}) {
+    ASSERT_NE(V, DenseMapInfo<EC1>::getEmptyKey());
+    ASSERT_NE(V, DenseMapInfo<EC1>::getTombstoneKey());
+  }
+  for (EC2 V : {EC2::ValA, EC2::ValB}) {
+    ASSERT_NE(V, DenseMapInfo<EC2>::getEmptyKey());
+    ASSERT_NE(V, DenseMapInfo<EC2>::getTombstoneKey());
+  }
+}
+
 template <class IntType> struct MapVectorMappedTypeTest : ::testing::Test {
   using int_type = IntType;
 };



More information about the llvm-commits mailing list