[llvm] [ADT] Make set_subtract more efficient when subtrahend is larger (NFC) (PR #98702)

Teresa Johnson via llvm-commits llvm-commits at lists.llvm.org
Mon Jul 15 11:29:59 PDT 2024


https://github.com/teresajohnson updated https://github.com/llvm/llvm-project/pull/98702

>From 0d947a26aedc57deef7935896dadddf081022f40 Mon Sep 17 00:00:00 2001
From: Teresa Johnson <tejohnson at google.com>
Date: Fri, 12 Jul 2024 16:26:54 -0700
Subject: [PATCH 1/2] [ADT] Make set_subtract more efficient when subtrahend is
 larger (NFC)

If the subtrahend is larger, iterate the minuend set instead.

This was noticed when subtracting a large set from a number of other smaller
sets for an upcoming MemProf change; this change makes that much faster.

I subsequently found a couple of callsites in one file that were calling
set_subtract with a vector subtrahend, which doesn't have the "count()"
interface. Add a separate helper for subtracting a vector.
---
 llvm/include/llvm/ADT/SetOperations.h | 18 ++++++++++++++++++
 llvm/lib/CodeGen/MachineVerifier.cpp  |  6 ++++--
 2 files changed, 22 insertions(+), 2 deletions(-)

diff --git a/llvm/include/llvm/ADT/SetOperations.h b/llvm/include/llvm/ADT/SetOperations.h
index 1a911b239f4c6..7be7a55ae13c8 100644
--- a/llvm/include/llvm/ADT/SetOperations.h
+++ b/llvm/include/llvm/ADT/SetOperations.h
@@ -92,9 +92,27 @@ S1Ty set_difference(const S1Ty &S1, const S2Ty &S2) {
   return Result;
 }
 
+/// set_subtract_vec(A, B) - Compute A := A - B, where B can be a vector.
+///
+template <class S1Ty, class S2Ty>
+void set_subtract_vec(S1Ty &S1, const S2Ty &S2) {
+  for (typename S2Ty::const_iterator SI = S2.begin(), SE = S2.end(); SI != SE;
+       ++SI)
+    S1.erase(*SI);
+}
+
 /// set_subtract(A, B) - Compute A := A - B
 ///
+/// Selects the set to iterate based on the relative sizes of A and B for better
+/// efficiency.
+///
 template <class S1Ty, class S2Ty> void set_subtract(S1Ty &S1, const S2Ty &S2) {
+  if (S1.size() < S2.size()) {
+    for (typename S1Ty::iterator SI = S1.begin(), SE = S1.end(); SI != SE; ++SI)
+      if (S2.count(*SI))
+        S1.erase(SI);
+    return;
+  }
   for (typename S2Ty::const_iterator SI = S2.begin(), SE = S2.end(); SI != SE;
        ++SI)
     S1.erase(*SI);
diff --git a/llvm/lib/CodeGen/MachineVerifier.cpp b/llvm/lib/CodeGen/MachineVerifier.cpp
index d0d3af0e5e4fc..c80edf974a67e 100644
--- a/llvm/lib/CodeGen/MachineVerifier.cpp
+++ b/llvm/lib/CodeGen/MachineVerifier.cpp
@@ -2909,7 +2909,8 @@ void MachineVerifier::checkLiveness(const MachineOperand *MO, unsigned MONum) {
 void MachineVerifier::visitMachineBundleAfter(const MachineInstr *MI) {
   BBInfo &MInfo = MBBInfoMap[MI->getParent()];
   set_union(MInfo.regsKilled, regsKilled);
-  set_subtract(regsLive, regsKilled); regsKilled.clear();
+  set_subtract_vec(regsLive, regsKilled);
+  regsKilled.clear();
   // Kill any masked registers.
   while (!regMasks.empty()) {
     const uint32_t *Mask = regMasks.pop_back_val();
@@ -2918,7 +2919,8 @@ void MachineVerifier::visitMachineBundleAfter(const MachineInstr *MI) {
           MachineOperand::clobbersPhysReg(Mask, Reg.asMCReg()))
         regsDead.push_back(Reg);
   }
-  set_subtract(regsLive, regsDead);   regsDead.clear();
+  set_subtract_vec(regsLive, regsDead);
+  regsDead.clear();
   set_union(regsLive, regsDefined);   regsDefined.clear();
 }
 

>From bc0335378fa646fa6e386537450cfb3b69bfc9b6 Mon Sep 17 00:00:00 2001
From: Teresa Johnson <tejohnson at google.com>
Date: Mon, 15 Jul 2024 11:23:59 -0700
Subject: [PATCH 2/2] Address comments

---
 llvm/include/llvm/ADT/SetOperations.h | 27 +++++++++++++--------------
 llvm/lib/CodeGen/MachineVerifier.cpp  |  6 ++----
 2 files changed, 15 insertions(+), 18 deletions(-)

diff --git a/llvm/include/llvm/ADT/SetOperations.h b/llvm/include/llvm/ADT/SetOperations.h
index 7be7a55ae13c8..f74b48d290248 100644
--- a/llvm/include/llvm/ADT/SetOperations.h
+++ b/llvm/include/llvm/ADT/SetOperations.h
@@ -92,26 +92,25 @@ S1Ty set_difference(const S1Ty &S1, const S2Ty &S2) {
   return Result;
 }
 
-/// set_subtract_vec(A, B) - Compute A := A - B, where B can be a vector.
-///
-template <class S1Ty, class S2Ty>
-void set_subtract_vec(S1Ty &S1, const S2Ty &S2) {
-  for (typename S2Ty::const_iterator SI = S2.begin(), SE = S2.end(); SI != SE;
-       ++SI)
-    S1.erase(*SI);
-}
-
 /// set_subtract(A, B) - Compute A := A - B
 ///
 /// Selects the set to iterate based on the relative sizes of A and B for better
 /// efficiency.
 ///
 template <class S1Ty, class S2Ty> void set_subtract(S1Ty &S1, const S2Ty &S2) {
-  if (S1.size() < S2.size()) {
-    for (typename S1Ty::iterator SI = S1.begin(), SE = S1.end(); SI != SE; ++SI)
-      if (S2.count(*SI))
-        S1.erase(SI);
-    return;
+  using ElemTy = decltype(*S1.begin());
+  // A couple callers pass a vector for S2, which doesn't support count(), and
+  // wouldn't be efficient if it did. In the absence of a more direct check,
+  // ensure the type supports the contains or find interfaces.
+  if constexpr (detail::HasMemberContains<S2Ty, ElemTy> ||
+                detail::HasMemberFind<S2Ty, ElemTy>) {
+    if (S1.size() < S2.size()) {
+      for (typename S1Ty::iterator SI = S1.begin(), SE = S1.end(); SI != SE;
+           ++SI)
+        if (S2.count(*SI))
+          S1.erase(SI);
+      return;
+    }
   }
   for (typename S2Ty::const_iterator SI = S2.begin(), SE = S2.end(); SI != SE;
        ++SI)
diff --git a/llvm/lib/CodeGen/MachineVerifier.cpp b/llvm/lib/CodeGen/MachineVerifier.cpp
index c80edf974a67e..d0d3af0e5e4fc 100644
--- a/llvm/lib/CodeGen/MachineVerifier.cpp
+++ b/llvm/lib/CodeGen/MachineVerifier.cpp
@@ -2909,8 +2909,7 @@ void MachineVerifier::checkLiveness(const MachineOperand *MO, unsigned MONum) {
 void MachineVerifier::visitMachineBundleAfter(const MachineInstr *MI) {
   BBInfo &MInfo = MBBInfoMap[MI->getParent()];
   set_union(MInfo.regsKilled, regsKilled);
-  set_subtract_vec(regsLive, regsKilled);
-  regsKilled.clear();
+  set_subtract(regsLive, regsKilled); regsKilled.clear();
   // Kill any masked registers.
   while (!regMasks.empty()) {
     const uint32_t *Mask = regMasks.pop_back_val();
@@ -2919,8 +2918,7 @@ void MachineVerifier::visitMachineBundleAfter(const MachineInstr *MI) {
           MachineOperand::clobbersPhysReg(Mask, Reg.asMCReg()))
         regsDead.push_back(Reg);
   }
-  set_subtract_vec(regsLive, regsDead);
-  regsDead.clear();
+  set_subtract(regsLive, regsDead);   regsDead.clear();
   set_union(regsLive, regsDefined);   regsDefined.clear();
 }
 



More information about the llvm-commits mailing list