[llvm] [EquivClasses] Shorten members_{begin,end} idiom (PR #134373)

via llvm-commits llvm-commits at lists.llvm.org
Fri Apr 4 05:44:51 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-amdgpu

Author: Ramkumar Ramachandra (artagnon)

<details>
<summary>Changes</summary>

Introduce members() iterator-helper to shorten the members_{begin,end} idiom. A previous attempt at this patch was #<!-- -->130319, which had to be reverted due to unit-test failures when attempting to call members() on the end iterator. In this patch, members() accepts either an ECValue or an ElemTy, which is more intuitive and doesn't suffer from the same issue.

---
Full diff: https://github.com/llvm/llvm-project/pull/134373.diff


7 Files Affected:

- (modified) llvm/include/llvm/ADT/EquivalenceClasses.h (+9) 
- (modified) llvm/lib/Analysis/LoopAccessAnalysis.cpp (+2-3) 
- (modified) llvm/lib/Analysis/VectorUtils.cpp (+3-3) 
- (modified) llvm/lib/Target/AMDGPU/AMDGPUSplitModule.cpp (+2-3) 
- (modified) llvm/lib/Transforms/IPO/LowerTypeTests.cpp (+6-7) 
- (modified) llvm/lib/Transforms/Scalar/Float2Int.cpp (+4-6) 
- (modified) llvm/unittests/ADT/EquivalenceClassesTest.cpp (+13) 


``````````diff
diff --git a/llvm/include/llvm/ADT/EquivalenceClasses.h b/llvm/include/llvm/ADT/EquivalenceClasses.h
index 906971baf74af..ad1f385cd9414 100644
--- a/llvm/include/llvm/ADT/EquivalenceClasses.h
+++ b/llvm/include/llvm/ADT/EquivalenceClasses.h
@@ -16,6 +16,7 @@
 #define LLVM_ADT_EQUIVALENCECLASSES_H
 
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/iterator_range.h"
 #include <cassert>
 #include <cstddef>
 #include <cstdint>
@@ -184,6 +185,14 @@ class EquivalenceClasses {
     return member_iterator(nullptr);
   }
 
+  iterator_range<member_iterator> members(const ECValue &ECV) const {
+    return make_range(member_begin(ECV), member_end());
+  }
+
+  iterator_range<member_iterator> members(const ElemTy &V) const {
+    return make_range(findLeader(V), member_end());
+  }
+
   /// Returns true if \p V is contained an equivalence class.
   bool contains(const ElemTy &V) const {
     return TheMapping.find(V) != TheMapping.end();
diff --git a/llvm/lib/Analysis/LoopAccessAnalysis.cpp b/llvm/lib/Analysis/LoopAccessAnalysis.cpp
index 47ff31b9a0525..a37ed5c706bdb 100644
--- a/llvm/lib/Analysis/LoopAccessAnalysis.cpp
+++ b/llvm/lib/Analysis/LoopAccessAnalysis.cpp
@@ -526,9 +526,8 @@ void RuntimePointerChecking::groupChecks(
     // iteration order within an equivalence class member is only dependent on
     // the order in which unions and insertions are performed on the
     // equivalence class, the iteration order is deterministic.
-    for (auto MI = DepCands.findLeader(Access), ME = DepCands.member_end();
-         MI != ME; ++MI) {
-      auto PointerI = PositionMap.find(MI->getPointer());
+    for (auto M : DepCands.members(Access)) {
+      auto PointerI = PositionMap.find(M.getPointer());
       assert(PointerI != PositionMap.end() &&
              "pointer in equivalence class not found in PositionMap");
       for (unsigned Pointer : PointerI->second) {
diff --git a/llvm/lib/Analysis/VectorUtils.cpp b/llvm/lib/Analysis/VectorUtils.cpp
index 663b961da848d..46f588f4c6705 100644
--- a/llvm/lib/Analysis/VectorUtils.cpp
+++ b/llvm/lib/Analysis/VectorUtils.cpp
@@ -847,7 +847,7 @@ llvm::computeMinimumValueSizes(ArrayRef<BasicBlock *> Blocks, DemandedBits &DB,
     if (!E->isLeader())
       continue;
     uint64_t LeaderDemandedBits = 0;
-    for (Value *M : make_range(ECs.member_begin(*E), ECs.member_end()))
+    for (Value *M : ECs.members(*E))
       LeaderDemandedBits |= DBits[M];
 
     uint64_t MinBW = llvm::bit_width(LeaderDemandedBits);
@@ -859,7 +859,7 @@ llvm::computeMinimumValueSizes(ArrayRef<BasicBlock *> Blocks, DemandedBits &DB,
     // indvars.
     // If we are required to shrink a PHI, abandon this entire equivalence class.
     bool Abort = false;
-    for (Value *M : make_range(ECs.member_begin(*E), ECs.member_end()))
+    for (Value *M : ECs.members(*E))
       if (isa<PHINode>(M) && MinBW < M->getType()->getScalarSizeInBits()) {
         Abort = true;
         break;
@@ -867,7 +867,7 @@ llvm::computeMinimumValueSizes(ArrayRef<BasicBlock *> Blocks, DemandedBits &DB,
     if (Abort)
       continue;
 
-    for (Value *M : make_range(ECs.member_begin(*E), ECs.member_end())) {
+    for (Value *M : ECs.members(*E)) {
       auto *MI = dyn_cast<Instruction>(M);
       if (!MI)
         continue;
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUSplitModule.cpp b/llvm/lib/Target/AMDGPU/AMDGPUSplitModule.cpp
index 32472201cf9c2..dd3bec774ec67 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUSplitModule.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUSplitModule.cpp
@@ -1021,9 +1021,8 @@ void RecursiveSearchSplitting::setupWorkList() {
       continue;
 
     BitVector Cluster = SG.createNodesBitVector();
-    for (auto MI = NodeEC.member_begin(*Node); MI != NodeEC.member_end();
-         ++MI) {
-      const SplitGraph::Node &N = SG.getNode(*MI);
+    for (unsigned M : NodeEC.members(*Node)) {
+      const SplitGraph::Node &N = SG.getNode(M);
       if (N.isGraphEntryPoint())
         N.getDependencies(Cluster);
     }
diff --git a/llvm/lib/Transforms/IPO/LowerTypeTests.cpp b/llvm/lib/Transforms/IPO/LowerTypeTests.cpp
index fcd8918f1d9d7..7cf7d74acfcfa 100644
--- a/llvm/lib/Transforms/IPO/LowerTypeTests.cpp
+++ b/llvm/lib/Transforms/IPO/LowerTypeTests.cpp
@@ -2349,14 +2349,13 @@ bool LowerTypeTestsModule::lower() {
     std::vector<Metadata *> TypeIds;
     std::vector<GlobalTypeMember *> Globals;
     std::vector<ICallBranchFunnel *> ICallBranchFunnels;
-    for (GlobalClassesTy::member_iterator MI = GlobalClasses.member_begin(*C);
-         MI != GlobalClasses.member_end(); ++MI) {
-      if (isa<Metadata *>(*MI))
-        TypeIds.push_back(cast<Metadata *>(*MI));
-      else if (isa<GlobalTypeMember *>(*MI))
-        Globals.push_back(cast<GlobalTypeMember *>(*MI));
+    for (auto M : GlobalClasses.members(*C)) {
+      if (isa<Metadata *>(M))
+        TypeIds.push_back(cast<Metadata *>(M));
+      else if (isa<GlobalTypeMember *>(M))
+        Globals.push_back(cast<GlobalTypeMember *>(M));
       else
-        ICallBranchFunnels.push_back(cast<ICallBranchFunnel *>(*MI));
+        ICallBranchFunnels.push_back(cast<ICallBranchFunnel *>(M));
     }
 
     // Order type identifiers by unique ID for determinism. This ordering is
diff --git a/llvm/lib/Transforms/Scalar/Float2Int.cpp b/llvm/lib/Transforms/Scalar/Float2Int.cpp
index 927877b3135e5..14686ce8c2ab6 100644
--- a/llvm/lib/Transforms/Scalar/Float2Int.cpp
+++ b/llvm/lib/Transforms/Scalar/Float2Int.cpp
@@ -320,10 +320,8 @@ bool Float2IntPass::validateAndTransform(const DataLayout &DL) {
     Type *ConvertedToTy = nullptr;
 
     // For every member of the partition, union all the ranges together.
-    for (auto MI = ECs.member_begin(*E), ME = ECs.member_end(); MI != ME;
-         ++MI) {
-      Instruction *I = *MI;
-      auto SeenI = SeenInsts.find(I);
+    for (Instruction *I : ECs.members(*E)) {
+      auto *SeenI = SeenInsts.find(I);
       if (SeenI == SeenInsts.end())
         continue;
 
@@ -391,8 +389,8 @@ bool Float2IntPass::validateAndTransform(const DataLayout &DL) {
       }
     }
 
-    for (auto MI = ECs.member_begin(*E), ME = ECs.member_end(); MI != ME; ++MI)
-      convert(*MI, Ty);
+    for (Instruction *I : ECs.members(*E))
+      convert(I, Ty);
     MadeChange = true;
   }
 
diff --git a/llvm/unittests/ADT/EquivalenceClassesTest.cpp b/llvm/unittests/ADT/EquivalenceClassesTest.cpp
index bfb7c8d185fc8..2f9c441cde5c7 100644
--- a/llvm/unittests/ADT/EquivalenceClassesTest.cpp
+++ b/llvm/unittests/ADT/EquivalenceClassesTest.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "llvm/ADT/EquivalenceClasses.h"
+#include "gmock/gmock.h"
 #include "gtest/gtest.h"
 
 using namespace llvm;
@@ -75,6 +76,18 @@ TEST(EquivalenceClassesTest, TwoSets) {
         EXPECT_FALSE(EqClasses.isEquivalent(i, j));
 }
 
+TEST(EquivalenceClassesTest, MembersIterator) {
+  EquivalenceClasses<int> EC;
+  EC.unionSets(1, 2);
+  EC.insert(4);
+  EC.insert(5);
+  EC.unionSets(5, 1);
+  EXPECT_EQ(EC.getNumClasses(), 2u);
+
+  EXPECT_THAT(EC.members(4), testing::ElementsAre(4));
+  EXPECT_THAT(EC.members(1), testing::ElementsAre(5, 1, 2));
+}
+
 // Type-parameterized tests: Run the same test cases with different element
 // types.
 template <typename T> class ParameterizedTest : public testing::Test {};

``````````

</details>


https://github.com/llvm/llvm-project/pull/134373


More information about the llvm-commits mailing list