[llvm] [mlir] [EquivalenceClasses] Use DenseMap instead of std::set. (NFC) (PR #134264)

via llvm-commits llvm-commits at lists.llvm.org
Thu Apr 3 08:58:33 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Florian Hahn (fhahn)

<details>
<summary>Changes</summary>

Replace the std::set with a DenseMap, which removes the requirement for an ordering predicate. This also requires allocating the ECValue objects separately; this patch uses a BumpPtrAllocator for them.

Follow-up to https://github.com/llvm/llvm-project/pull/134075.

I'm not sure there's a big benefit; the compile-time impact is mostly neutral or slightly positive:
https://llvm-compile-time-tracker.com/compare.php?from=ee4e8197fa67dd1ed6e9470e00708e7feeaacd97&to=242e6a8e42889eebfc0bb5d433a4de7dd9e224a7&stat=instructions:u

---
Full diff: https://github.com/llvm/llvm-project/pull/134264.diff


4 Files Affected:

- (modified) llvm/include/llvm/ADT/EquivalenceClasses.h (+24-42) 
- (modified) llvm/lib/Transforms/Vectorize/VectorCombine.cpp (+1) 
- (modified) llvm/unittests/ADT/EquivalenceClassesTest.cpp (-26) 
- (modified) mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h (+4-13) 


``````````diff
diff --git a/llvm/include/llvm/ADT/EquivalenceClasses.h b/llvm/include/llvm/ADT/EquivalenceClasses.h
index 906971baf74af..22322ecb98cb9 100644
--- a/llvm/include/llvm/ADT/EquivalenceClasses.h
+++ b/llvm/include/llvm/ADT/EquivalenceClasses.h
@@ -15,12 +15,13 @@
 #ifndef LLVM_ADT_EQUIVALENCECLASSES_H
 #define LLVM_ADT_EQUIVALENCECLASSES_H
 
+#include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Allocator.h"
 #include <cassert>
 #include <cstddef>
 #include <cstdint>
 #include <iterator>
-#include <set>
 
 namespace llvm {
 
@@ -32,8 +33,7 @@ namespace llvm {
 ///
 /// This implementation is an efficient implementation that only stores one copy
 /// of the element being indexed per entry in the set, and allows any arbitrary
-/// type to be indexed (as long as it can be ordered with operator< or a
-/// comparator is provided).
+/// type to be indexed (as long as it has a DenseMapInfo specialization).
 ///
 /// Here is a simple example using integers:
 ///
@@ -57,18 +57,17 @@ namespace llvm {
 ///   4
 ///   5 1 2
 ///
-template <class ElemTy, class Compare = std::less<ElemTy>>
-class EquivalenceClasses {
+template <class ElemTy> class EquivalenceClasses {
   /// ECValue - The EquivalenceClasses data structure is just a set of these.
   /// Each of these represents a relation for a value.  First it stores the
-  /// value itself, which provides the ordering that the set queries.  Next, it
-  /// provides a "next pointer", which is used to enumerate all of the elements
-  /// in the unioned set.  Finally, it defines either a "end of list pointer" or
-  /// "leader pointer" depending on whether the value itself is a leader.  A
-  /// "leader pointer" points to the node that is the leader for this element,
-  /// if the node is not a leader.  A "end of list pointer" points to the last
-  /// node in the list of members of this list.  Whether or not a node is a
-  /// leader is determined by a bit stolen from one of the pointers.
+  /// value itself. Next, it provides a "next pointer", which is used to
+  /// enumerate all of the elements in the unioned set.  Finally, it defines
+  /// either an "end of list pointer" or "leader pointer" depending on whether
+  /// the value itself is a leader. A "leader pointer" points to the node that
+  /// is the leader for this element, if the node is not a leader.  An "end of
+  /// list pointer" points to the last node in the list of members of this list.
+  /// Whether or not a node is a leader is determined by a bit stolen from one
+  /// of the pointers.
   class ECValue {
     friend class EquivalenceClasses;
 
@@ -112,36 +111,15 @@ class EquivalenceClasses {
     }
   };
 
-  /// A wrapper of the comparator, to be passed to the set.
-  struct ECValueComparator {
-    using is_transparent = void;
-
-    ECValueComparator() : compare(Compare()) {}
-
-    bool operator()(const ECValue &lhs, const ECValue &rhs) const {
-      return compare(lhs.Data, rhs.Data);
-    }
-
-    template <typename T>
-    bool operator()(const T &lhs, const ECValue &rhs) const {
-      return compare(lhs, rhs.Data);
-    }
-
-    template <typename T>
-    bool operator()(const ECValue &lhs, const T &rhs) const {
-      return compare(lhs.Data, rhs);
-    }
-
-    const Compare compare;
-  };
-
   /// TheMapping - This implicitly provides a mapping from ElemTy values to the
   /// ECValues, it just keeps the key as part of the value.
-  std::set<ECValue, ECValueComparator> TheMapping;
+  DenseMap<ElemTy, ECValue *> TheMapping;
 
   /// List of all members, used to provide a determinstic iteration order.
   SmallVector<const ECValue *> Members;
 
+  mutable BumpPtrAllocator ECValueAllocator;
+
 public:
   EquivalenceClasses() = default;
   EquivalenceClasses(const EquivalenceClasses &RHS) {
@@ -223,10 +201,14 @@ class EquivalenceClasses {
   /// insert - Insert a new value into the union/find set, ignoring the request
   /// if the value already exists.
   const ECValue &insert(const ElemTy &Data) {
-    auto I = TheMapping.insert(ECValue(Data));
-    if (I.second)
-      Members.push_back(&*I.first);
-    return *I.first;
+    auto I = TheMapping.insert({Data, nullptr});
+    if (!I.second)
+      return *I.first->second;
+
+    auto *ECV = new (ECValueAllocator) ECValue(Data);
+    I.first->second = ECV;
+    Members.push_back(ECV);
+    return *ECV;
   }
 
   /// findLeader - Given a value in the set, return a member iterator for the
@@ -237,7 +219,7 @@ class EquivalenceClasses {
     auto I = TheMapping.find(V);
     if (I == TheMapping.end())
       return member_iterator(nullptr);
-    return findLeader(*I);
+    return findLeader(*I->second);
   }
   member_iterator findLeader(const ECValue &ECV) const {
     return member_iterator(ECV.getLeader());
diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index 4bfe41a5ed00d..ffb82bd5baf4e 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -33,6 +33,7 @@
 #include "llvm/Transforms/Utils/LoopUtils.h"
 #include <numeric>
 #include <queue>
+#include <set>
 
 #define DEBUG_TYPE "vector-combine"
 #include "llvm/Transforms/Utils/InstructionWorklist.h"
diff --git a/llvm/unittests/ADT/EquivalenceClassesTest.cpp b/llvm/unittests/ADT/EquivalenceClassesTest.cpp
index bfb7c8d185fc8..4391351743551 100644
--- a/llvm/unittests/ADT/EquivalenceClassesTest.cpp
+++ b/llvm/unittests/ADT/EquivalenceClassesTest.cpp
@@ -96,30 +96,4 @@ TYPED_TEST_P(ParameterizedTest, MultipleSets) {
         EXPECT_FALSE(EqClasses.isEquivalent(i, j));
 }
 
-namespace {
-// A dummy struct for testing EquivalenceClasses with a comparator.
-struct TestStruct {
-  TestStruct(int value) : value(value) {}
-
-  bool operator==(const TestStruct &other) const {
-    return value == other.value;
-  }
-
-  int value;
-};
-// Comparator to be used in test case.
-struct TestStructComparator {
-  bool operator()(const TestStruct &lhs, const TestStruct &rhs) const {
-    return lhs.value < rhs.value;
-  }
-};
-} // namespace
-
-REGISTER_TYPED_TEST_SUITE_P(ParameterizedTest, MultipleSets);
-using ParamTypes =
-    testing::Types<EquivalenceClasses<int>,
-                   EquivalenceClasses<TestStruct, TestStructComparator>>;
-INSTANTIATE_TYPED_TEST_SUITE_P(EquivalenceClassesTest, ParameterizedTest,
-                               ParamTypes, );
-
 } // llvm
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
index bd23a19f74728..673027f76190d 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
@@ -224,17 +224,8 @@ class OneShotAnalysisState : public AnalysisState {
   }
 
 private:
-  /// llvm::EquivalenceClasses wants comparable elements. This comparator uses
-  /// pointer comparison on the defining op. This is a poor man's comparison
-  /// but it's not like UnionFind needs ordering anyway.
-  struct ValueComparator {
-    bool operator()(const Value &lhs, const Value &rhs) const {
-      return lhs.getImpl() < rhs.getImpl();
-    }
-  };
-
-  using EquivalenceClassRangeType = llvm::iterator_range<
-      llvm::EquivalenceClasses<Value, ValueComparator>::member_iterator>;
+  using EquivalenceClassRangeType =
+      llvm::iterator_range<llvm::EquivalenceClasses<Value>::member_iterator>;
   /// Check that aliasInfo for `v` exists and return a reference to it.
   EquivalenceClassRangeType getAliases(Value v) const;
 
@@ -249,7 +240,7 @@ class OneShotAnalysisState : public AnalysisState {
   /// value may alias with one of multiple other values. The concrete aliasing
   /// value may not even be known at compile time. All such values are
   /// considered to be aliases.
-  llvm::EquivalenceClasses<Value, ValueComparator> aliasInfo;
+  llvm::EquivalenceClasses<Value> aliasInfo;
 
   /// Auxiliary structure to store all the equivalent buffer classes. Equivalent
   /// buffer information is "must be" conservative: Only if two values are
@@ -257,7 +248,7 @@ class OneShotAnalysisState : public AnalysisState {
   /// possible that, in the presence of branches, it cannot be determined
   /// statically if two values are equivalent. In that case, the values are
   /// considered to be not equivalent.
-  llvm::EquivalenceClasses<Value, ValueComparator> equivalentInfo;
+  llvm::EquivalenceClasses<Value> equivalentInfo;
 
   // Bufferization statistics.
   int64_t statNumTensorOutOfPlace = 0;

``````````

</details>


https://github.com/llvm/llvm-project/pull/134264


More information about the llvm-commits mailing list