[llvm] [RegAlloc] Don't call always-true ShouldAllocClass (PR #96296)

via llvm-commits llvm-commits at lists.llvm.org
Fri Jun 21 03:35:46 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-regalloc

Author: Alexis Engelke (aengelke)

<details>
<summary>Changes</summary>

Previously, every allocated register incurred at least one indirect call through the `std::function` register-class filter, even though the default filter always returned true. The only users of this feature are AMDGPU and RISC-V (RVV); other targets do not use it. To make the unfiltered case easy to identify, change the default filter to `nullptr` and skip the filter call entirely when no filter is set.

---

This should be compatible with PR #93525.

http://llvm-compile-time-tracker.com/compare.php?from=90779fdc19dc15099231d6ebc39d9d76991d2d43&to=1054b2da3d05828a670480234560bfca2326a6af&stat=instructions:u

---
Full diff: https://github.com/llvm/llvm-project/pull/96296.diff


10 Files Affected:

- (modified) llvm/include/llvm/CodeGen/RegAllocCommon.h (+4-9) 
- (modified) llvm/include/llvm/CodeGen/RegAllocFast.h (+1-1) 
- (modified) llvm/include/llvm/Passes/PassBuilder.h (+3-1) 
- (modified) llvm/lib/CodeGen/RegAllocBase.cpp (+1-2) 
- (modified) llvm/lib/CodeGen/RegAllocBase.h (+13-2) 
- (modified) llvm/lib/CodeGen/RegAllocBasic.cpp (+1-1) 
- (modified) llvm/lib/CodeGen/RegAllocFast.cpp (+4-3) 
- (modified) llvm/lib/CodeGen/RegAllocGreedy.cpp (+2-2) 
- (modified) llvm/lib/CodeGen/RegAllocGreedy.h (+1-1) 
- (modified) llvm/lib/Passes/PassBuilder.cpp (+7-5) 


``````````diff
diff --git a/llvm/include/llvm/CodeGen/RegAllocCommon.h b/llvm/include/llvm/CodeGen/RegAllocCommon.h
index 757ca8e112eec..ad533eab1861c 100644
--- a/llvm/include/llvm/CodeGen/RegAllocCommon.h
+++ b/llvm/include/llvm/CodeGen/RegAllocCommon.h
@@ -16,16 +16,11 @@ namespace llvm {
 class TargetRegisterClass;
 class TargetRegisterInfo;
 
+/// Filter function for register classes during regalloc. The default filter is
+/// nullptr, which means that all registers are allocated.
 typedef std::function<bool(const TargetRegisterInfo &TRI,
-                           const TargetRegisterClass &RC)> RegClassFilterFunc;
-
-/// Default register class filter function for register allocation. All virtual
-/// registers should be allocated.
-static inline bool allocateAllRegClasses(const TargetRegisterInfo &,
-                                         const TargetRegisterClass &) {
-  return true;
-}
-
+                           const TargetRegisterClass &RC)>
+    RegClassFilterFunc;
 }
 
 #endif // LLVM_CODEGEN_REGALLOCCOMMON_H
diff --git a/llvm/include/llvm/CodeGen/RegAllocFast.h b/llvm/include/llvm/CodeGen/RegAllocFast.h
index c50deccabd995..c62bd14d0b4cb 100644
--- a/llvm/include/llvm/CodeGen/RegAllocFast.h
+++ b/llvm/include/llvm/CodeGen/RegAllocFast.h
@@ -15,7 +15,7 @@
 namespace llvm {
 
 struct RegAllocFastPassOptions {
-  RegClassFilterFunc Filter = allocateAllRegClasses;
+  RegClassFilterFunc Filter = nullptr;
   StringRef FilterName = "all";
   bool ClearVRegs = true;
 };
diff --git a/llvm/include/llvm/Passes/PassBuilder.h b/llvm/include/llvm/Passes/PassBuilder.h
index ed817127c3db1..551d297e0c089 100644
--- a/llvm/include/llvm/Passes/PassBuilder.h
+++ b/llvm/include/llvm/Passes/PassBuilder.h
@@ -27,6 +27,7 @@
 #include "llvm/Transforms/IPO/ModuleInliner.h"
 #include "llvm/Transforms/Instrumentation.h"
 #include "llvm/Transforms/Scalar/LoopPassManager.h"
+#include <optional>
 #include <vector>
 
 namespace llvm {
@@ -390,7 +391,8 @@ class PassBuilder {
   Error parseAAPipeline(AAManager &AA, StringRef PipelineText);
 
   /// Parse RegClassFilterName to get RegClassFilterFunc.
-  RegClassFilterFunc parseRegAllocFilter(StringRef RegClassFilterName);
+  std::optional<RegClassFilterFunc>
+  parseRegAllocFilter(StringRef RegClassFilterName);
 
   /// Print pass names.
   void printPassNames(raw_ostream &OS);
diff --git a/llvm/lib/CodeGen/RegAllocBase.cpp b/llvm/lib/CodeGen/RegAllocBase.cpp
index d0dec372f6896..71288469b8f0f 100644
--- a/llvm/lib/CodeGen/RegAllocBase.cpp
+++ b/llvm/lib/CodeGen/RegAllocBase.cpp
@@ -181,8 +181,7 @@ void RegAllocBase::enqueue(const LiveInterval *LI) {
   if (VRM->hasPhys(Reg))
     return;
 
-  const TargetRegisterClass &RC = *MRI->getRegClass(Reg);
-  if (ShouldAllocateClass(*TRI, RC)) {
+  if (shouldAllocateRegister(Reg)) {
     LLVM_DEBUG(dbgs() << "Enqueuing " << printReg(Reg, TRI) << '\n');
     enqueueImpl(LI);
   } else {
diff --git a/llvm/lib/CodeGen/RegAllocBase.h b/llvm/lib/CodeGen/RegAllocBase.h
index a8bf305a50c98..643094671d682 100644
--- a/llvm/lib/CodeGen/RegAllocBase.h
+++ b/llvm/lib/CodeGen/RegAllocBase.h
@@ -37,6 +37,7 @@
 #define LLVM_LIB_CODEGEN_REGALLOCBASE_H
 
 #include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/CodeGen/MachineRegisterInfo.h"
 #include "llvm/CodeGen/RegAllocCommon.h"
 #include "llvm/CodeGen/RegisterClassInfo.h"
 
@@ -68,22 +69,32 @@ class RegAllocBase {
   LiveIntervals *LIS = nullptr;
   LiveRegMatrix *Matrix = nullptr;
   RegisterClassInfo RegClassInfo;
+
+private:
+  /// Private; callers should go through shouldAllocateRegister.
   const RegClassFilterFunc ShouldAllocateClass;
 
+protected:
   /// Inst which is a def of an original reg and whose defs are already all
   /// dead after remat is saved in DeadRemats. The deletion of such inst is
   /// postponed till all the allocations are done, so its remat expr is
   /// always available for the remat of all the siblings of the original reg.
   SmallPtrSet<MachineInstr *, 32> DeadRemats;
 
-  RegAllocBase(const RegClassFilterFunc F = allocateAllRegClasses) :
-    ShouldAllocateClass(F) {}
+  RegAllocBase(const RegClassFilterFunc F = nullptr) : ShouldAllocateClass(F) {}
 
   virtual ~RegAllocBase() = default;
 
   // A RegAlloc pass should call this before allocatePhysRegs.
   void init(VirtRegMap &vrm, LiveIntervals &lis, LiveRegMatrix &mat);
 
+  /// Get whether a given register should be allocated
+  bool shouldAllocateRegister(Register Reg) {
+    if (!ShouldAllocateClass)
+      return true;
+    return ShouldAllocateClass(*TRI, *MRI->getRegClass(Reg));
+  }
+
   // The top-level driver. The output is a VirtRegMap that us updated with
   // physical register assignments.
   void allocatePhysRegs();
diff --git a/llvm/lib/CodeGen/RegAllocBasic.cpp b/llvm/lib/CodeGen/RegAllocBasic.cpp
index 181337ca4d60f..5d84e1e39e27c 100644
--- a/llvm/lib/CodeGen/RegAllocBasic.cpp
+++ b/llvm/lib/CodeGen/RegAllocBasic.cpp
@@ -74,7 +74,7 @@ class RABasic : public MachineFunctionPass,
   void LRE_WillShrinkVirtReg(Register) override;
 
 public:
-  RABasic(const RegClassFilterFunc F = allocateAllRegClasses);
+  RABasic(const RegClassFilterFunc F = nullptr);
 
   /// Return the pass name.
   StringRef getPassName() const override { return "Basic Register Allocator"; }
diff --git a/llvm/lib/CodeGen/RegAllocFast.cpp b/llvm/lib/CodeGen/RegAllocFast.cpp
index 09ce8c42a3850..dddc004be9293 100644
--- a/llvm/lib/CodeGen/RegAllocFast.cpp
+++ b/llvm/lib/CodeGen/RegAllocFast.cpp
@@ -177,7 +177,7 @@ class InstrPosIndexes {
 
 class RegAllocFastImpl {
 public:
-  RegAllocFastImpl(const RegClassFilterFunc F = allocateAllRegClasses,
+  RegAllocFastImpl(const RegClassFilterFunc F = nullptr,
                    bool ClearVirtRegs_ = true)
       : ShouldAllocateClass(F), StackSlotForVirtReg(-1),
         ClearVirtRegs(ClearVirtRegs_) {}
@@ -387,8 +387,7 @@ class RegAllocFast : public MachineFunctionPass {
 public:
   static char ID;
 
-  RegAllocFast(const RegClassFilterFunc F = allocateAllRegClasses,
-               bool ClearVirtRegs_ = true)
+  RegAllocFast(const RegClassFilterFunc F = nullptr, bool ClearVirtRegs_ = true)
       : MachineFunctionPass(ID), Impl(F, ClearVirtRegs_) {}
 
   bool runOnMachineFunction(MachineFunction &MF) override {
@@ -431,6 +430,8 @@ INITIALIZE_PASS(RegAllocFast, "regallocfast", "Fast Register Allocator", false,
 
 bool RegAllocFastImpl::shouldAllocateRegister(const Register Reg) const {
   assert(Reg.isVirtual());
+  if (!ShouldAllocateClass)
+    return true;
   const TargetRegisterClass &RC = *MRI->getRegClass(Reg);
   return ShouldAllocateClass(*TRI, RC);
 }
diff --git a/llvm/lib/CodeGen/RegAllocGreedy.cpp b/llvm/lib/CodeGen/RegAllocGreedy.cpp
index 500ceb3d8b700..19c1ee23af858 100644
--- a/llvm/lib/CodeGen/RegAllocGreedy.cpp
+++ b/llvm/lib/CodeGen/RegAllocGreedy.cpp
@@ -2308,7 +2308,7 @@ void RAGreedy::tryHintRecoloring(const LiveInterval &VirtReg) {
 
     // This may be a skipped class
     if (!VRM->hasPhys(Reg)) {
-      assert(!ShouldAllocateClass(*TRI, *MRI->getRegClass(Reg)) &&
+      assert(!shouldAllocateRegister(Reg) &&
              "We have an unallocated variable which should have been handled");
       continue;
     }
@@ -2698,7 +2698,7 @@ bool RAGreedy::hasVirtRegAlloc() {
     const TargetRegisterClass *RC = MRI->getRegClass(Reg);
     if (!RC)
       continue;
-    if (ShouldAllocateClass(*TRI, *RC))
+    if (shouldAllocateRegister(Reg))
       return true;
   }
 
diff --git a/llvm/lib/CodeGen/RegAllocGreedy.h b/llvm/lib/CodeGen/RegAllocGreedy.h
index 06cf0828ea79b..ac300c0024f5a 100644
--- a/llvm/lib/CodeGen/RegAllocGreedy.h
+++ b/llvm/lib/CodeGen/RegAllocGreedy.h
@@ -281,7 +281,7 @@ class LLVM_LIBRARY_VISIBILITY RAGreedy : public MachineFunctionPass,
   bool ReverseLocalAssignment = false;
 
 public:
-  RAGreedy(const RegClassFilterFunc F = allocateAllRegClasses);
+  RAGreedy(const RegClassFilterFunc F = nullptr);
 
   /// Return the pass name.
   StringRef getPassName() const override { return "Greedy Register Allocator"; }
diff --git a/llvm/lib/Passes/PassBuilder.cpp b/llvm/lib/Passes/PassBuilder.cpp
index 19e8a8ab68a73..b1488f9b86886 100644
--- a/llvm/lib/Passes/PassBuilder.cpp
+++ b/llvm/lib/Passes/PassBuilder.cpp
@@ -1173,14 +1173,15 @@ parseRegAllocFastPassOptions(PassBuilder &PB, StringRef Params) {
     std::tie(ParamName, Params) = Params.split(';');
 
     if (ParamName.consume_front("filter=")) {
-      RegClassFilterFunc Filter = PB.parseRegAllocFilter(ParamName);
+      std::optional<RegClassFilterFunc> Filter =
+          PB.parseRegAllocFilter(ParamName);
       if (!Filter) {
         return make_error<StringError>(
             formatv("invalid regallocfast register filter '{0}' ", ParamName)
                 .str(),
             inconvertibleErrorCode());
       }
-      Opts.Filter = Filter;
+      Opts.Filter = *Filter;
       Opts.FilterName = ParamName;
       continue;
     }
@@ -2220,13 +2221,14 @@ Error PassBuilder::parseAAPipeline(AAManager &AA, StringRef PipelineText) {
   return Error::success();
 }
 
-RegClassFilterFunc PassBuilder::parseRegAllocFilter(StringRef FilterName) {
+std::optional<RegClassFilterFunc>
+PassBuilder::parseRegAllocFilter(StringRef FilterName) {
   if (FilterName == "all")
-    return allocateAllRegClasses;
+    return nullptr;
   for (auto &C : RegClassFilterParsingCallbacks)
     if (auto F = C(FilterName))
       return F;
-  return nullptr;
+  return std::nullopt;
 }
 
 static void printPassName(StringRef PassName, raw_ostream &OS) {

``````````

</details>


https://github.com/llvm/llvm-project/pull/96296


More information about the llvm-commits mailing list