[llvm] [CodeGen] change prototype of regalloc filter function (PR #93525)

Christudasan Devadasan via llvm-commits llvm-commits at lists.llvm.org
Mon Jul 22 02:34:09 PDT 2024


https://github.com/cdevadas updated https://github.com/llvm/llvm-project/pull/93525

>From 1635872d9d8e370355a38d782e8d4a8d0235690c Mon Sep 17 00:00:00 2001
From: Christudasan Devadasan <Christudasan.Devadasan at amd.com>
Date: Mon, 27 May 2024 10:19:36 +0530
Subject: [PATCH] [CodeGen] change prototype of regalloc filter function

Change the prototype of the filter function so that we can
filter by more than just the register class. We need to implement
more complicated filters based on other information associated
with each register.

Patch provided by: Gang Chen (gangc at amd.com)
---
 llvm/include/llvm/CodeGen/Passes.h             |  6 +++---
 llvm/include/llvm/CodeGen/RegAllocCommon.h     |  6 ++++--
 llvm/include/llvm/CodeGen/RegAllocFast.h       |  2 +-
 llvm/include/llvm/Passes/PassBuilder.h         | 10 +++++-----
 llvm/lib/CodeGen/RegAllocBase.h                |  9 +++++----
 llvm/lib/CodeGen/RegAllocBasic.cpp             | 10 ++++------
 llvm/lib/CodeGen/RegAllocFast.cpp              | 16 ++++++++--------
 llvm/lib/CodeGen/RegAllocGreedy.cpp            | 10 ++++------
 llvm/lib/CodeGen/RegAllocGreedy.h              |  2 +-
 llvm/lib/Passes/PassBuilder.cpp                |  4 ++--
 llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.cpp | 15 +++++++++------
 llvm/lib/Target/RISCV/RISCVTargetMachine.cpp   |  6 ++++--
 llvm/lib/Target/X86/X86TargetMachine.cpp       |  6 ++++--
 13 files changed, 54 insertions(+), 48 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/Passes.h b/llvm/include/llvm/CodeGen/Passes.h
index f850767270a4f..cafb9781698a2 100644
--- a/llvm/include/llvm/CodeGen/Passes.h
+++ b/llvm/include/llvm/CodeGen/Passes.h
@@ -205,20 +205,20 @@ namespace llvm {
   /// possible. It is best suited for debug code where live ranges are short.
   ///
   FunctionPass *createFastRegisterAllocator();
-  FunctionPass *createFastRegisterAllocator(RegClassFilterFunc F,
+  FunctionPass *createFastRegisterAllocator(RegAllocFilterFunc F,
                                             bool ClearVirtRegs);
 
   /// BasicRegisterAllocation Pass - This pass implements a degenerate global
   /// register allocator using the basic regalloc framework.
   ///
   FunctionPass *createBasicRegisterAllocator();
-  FunctionPass *createBasicRegisterAllocator(RegClassFilterFunc F);
+  FunctionPass *createBasicRegisterAllocator(RegAllocFilterFunc F);
 
   /// Greedy register allocation pass - This pass implements a global register
   /// allocator for optimized builds.
   ///
   FunctionPass *createGreedyRegisterAllocator();
-  FunctionPass *createGreedyRegisterAllocator(RegClassFilterFunc F);
+  FunctionPass *createGreedyRegisterAllocator(RegAllocFilterFunc F);
 
   /// PBQPRegisterAllocation Pass - This pass implements the Partitioned Boolean
   /// Quadratic Prograaming (PBQP) based register allocator.
diff --git a/llvm/include/llvm/CodeGen/RegAllocCommon.h b/llvm/include/llvm/CodeGen/RegAllocCommon.h
index ad533eab1861c..6c5cb295f5290 100644
--- a/llvm/include/llvm/CodeGen/RegAllocCommon.h
+++ b/llvm/include/llvm/CodeGen/RegAllocCommon.h
@@ -9,18 +9,20 @@
 #ifndef LLVM_CODEGEN_REGALLOCCOMMON_H
 #define LLVM_CODEGEN_REGALLOCCOMMON_H
 
+#include "llvm/CodeGen/Register.h"
 #include <functional>
 
 namespace llvm {
 
 class TargetRegisterClass;
 class TargetRegisterInfo;
+class MachineRegisterInfo;
 
 /// Filter function for register classes during regalloc. Default register class
 /// filter is nullptr, where all registers should be allocated.
 typedef std::function<bool(const TargetRegisterInfo &TRI,
-                           const TargetRegisterClass &RC)>
-    RegClassFilterFunc;
+                           const MachineRegisterInfo &MRI, const Register Reg)>
+    RegAllocFilterFunc;
 }
 
 #endif // LLVM_CODEGEN_REGALLOCCOMMON_H
diff --git a/llvm/include/llvm/CodeGen/RegAllocFast.h b/llvm/include/llvm/CodeGen/RegAllocFast.h
index c62bd14d0b4cb..c99c715daacf9 100644
--- a/llvm/include/llvm/CodeGen/RegAllocFast.h
+++ b/llvm/include/llvm/CodeGen/RegAllocFast.h
@@ -15,7 +15,7 @@
 namespace llvm {
 
 struct RegAllocFastPassOptions {
-  RegClassFilterFunc Filter = nullptr;
+  RegAllocFilterFunc Filter = nullptr;
   StringRef FilterName = "all";
   bool ClearVRegs = true;
 };
diff --git a/llvm/include/llvm/Passes/PassBuilder.h b/llvm/include/llvm/Passes/PassBuilder.h
index 551d297e0c089..474a19531ff5d 100644
--- a/llvm/include/llvm/Passes/PassBuilder.h
+++ b/llvm/include/llvm/Passes/PassBuilder.h
@@ -390,9 +390,9 @@ class PassBuilder {
   /// returns false.
   Error parseAAPipeline(AAManager &AA, StringRef PipelineText);
 
-  /// Parse RegClassFilterName to get RegClassFilterFunc.
-  std::optional<RegClassFilterFunc>
-  parseRegAllocFilter(StringRef RegClassFilterName);
+  /// Parse RegAllocFilterName to get RegAllocFilterFunc.
+  std::optional<RegAllocFilterFunc>
+  parseRegAllocFilter(StringRef RegAllocFilterName);
 
   /// Print pass names.
   void printPassNames(raw_ostream &OS);
@@ -586,7 +586,7 @@ class PassBuilder {
   /// needs it. E.g. AMDGPU requires regalloc passes can handle sgpr and vgpr
   /// separately.
   void registerRegClassFilterParsingCallback(
-      const std::function<RegClassFilterFunc(StringRef)> &C) {
+      const std::function<RegAllocFilterFunc(StringRef)> &C) {
     RegClassFilterParsingCallbacks.push_back(C);
   }
 
@@ -807,7 +807,7 @@ class PassBuilder {
               2>
       MachineFunctionPipelineParsingCallbacks;
   // Callbacks to parse `filter` parameter in register allocation passes
-  SmallVector<std::function<RegClassFilterFunc(StringRef)>, 2>
+  SmallVector<std::function<RegAllocFilterFunc(StringRef)>, 2>
       RegClassFilterParsingCallbacks;
 };
 
diff --git a/llvm/lib/CodeGen/RegAllocBase.h b/llvm/lib/CodeGen/RegAllocBase.h
index 643094671d682..a1ede08a15356 100644
--- a/llvm/lib/CodeGen/RegAllocBase.h
+++ b/llvm/lib/CodeGen/RegAllocBase.h
@@ -72,7 +72,7 @@ class RegAllocBase {
 
 private:
   /// Private, callees should go through shouldAllocateRegister
-  const RegClassFilterFunc ShouldAllocateClass;
+  const RegAllocFilterFunc shouldAllocateRegisterImpl;
 
 protected:
   /// Inst which is a def of an original reg and whose defs are already all
@@ -81,7 +81,8 @@ class RegAllocBase {
   /// always available for the remat of all the siblings of the original reg.
   SmallPtrSet<MachineInstr *, 32> DeadRemats;
 
-  RegAllocBase(const RegClassFilterFunc F = nullptr) : ShouldAllocateClass(F) {}
+  RegAllocBase(const RegAllocFilterFunc F = nullptr)
+      : shouldAllocateRegisterImpl(F) {}
 
   virtual ~RegAllocBase() = default;
 
@@ -90,9 +91,9 @@ class RegAllocBase {
 
   /// Get whether a given register should be allocated
   bool shouldAllocateRegister(Register Reg) {
-    if (!ShouldAllocateClass)
+    if (!shouldAllocateRegisterImpl)
       return true;
-    return ShouldAllocateClass(*TRI, *MRI->getRegClass(Reg));
+    return shouldAllocateRegisterImpl(*TRI, *MRI, Reg);
   }
 
   // The top-level driver. The output is a VirtRegMap that us updated with
diff --git a/llvm/lib/CodeGen/RegAllocBasic.cpp b/llvm/lib/CodeGen/RegAllocBasic.cpp
index e78a447969fb2..caf9c32a5a349 100644
--- a/llvm/lib/CodeGen/RegAllocBasic.cpp
+++ b/llvm/lib/CodeGen/RegAllocBasic.cpp
@@ -74,7 +74,7 @@ class RABasic : public MachineFunctionPass,
   void LRE_WillShrinkVirtReg(Register) override;
 
 public:
-  RABasic(const RegClassFilterFunc F = nullptr);
+  RABasic(const RegAllocFilterFunc F = nullptr);
 
   /// Return the pass name.
   StringRef getPassName() const override { return "Basic Register Allocator"; }
@@ -168,10 +168,8 @@ void RABasic::LRE_WillShrinkVirtReg(Register VirtReg) {
   enqueue(&LI);
 }
 
-RABasic::RABasic(RegClassFilterFunc F):
-  MachineFunctionPass(ID),
-  RegAllocBase(F) {
-}
+RABasic::RABasic(RegAllocFilterFunc F)
+    : MachineFunctionPass(ID), RegAllocBase(F) {}
 
 void RABasic::getAnalysisUsage(AnalysisUsage &AU) const {
   AU.setPreservesCFG();
@@ -333,6 +331,6 @@ FunctionPass* llvm::createBasicRegisterAllocator() {
   return new RABasic();
 }
 
-FunctionPass* llvm::createBasicRegisterAllocator(RegClassFilterFunc F) {
+FunctionPass *llvm::createBasicRegisterAllocator(RegAllocFilterFunc F) {
   return new RABasic(F);
 }
diff --git a/llvm/lib/CodeGen/RegAllocFast.cpp b/llvm/lib/CodeGen/RegAllocFast.cpp
index b4660d4085ffd..6e5ce72240d27 100644
--- a/llvm/lib/CodeGen/RegAllocFast.cpp
+++ b/llvm/lib/CodeGen/RegAllocFast.cpp
@@ -177,9 +177,9 @@ class InstrPosIndexes {
 
 class RegAllocFastImpl {
 public:
-  RegAllocFastImpl(const RegClassFilterFunc F = nullptr,
+  RegAllocFastImpl(const RegAllocFilterFunc F = nullptr,
                    bool ClearVirtRegs_ = true)
-      : ShouldAllocateClass(F), StackSlotForVirtReg(-1),
+      : ShouldAllocateRegisterImpl(F), StackSlotForVirtReg(-1),
         ClearVirtRegs(ClearVirtRegs_) {}
 
 private:
@@ -188,7 +188,7 @@ class RegAllocFastImpl {
   const TargetRegisterInfo *TRI = nullptr;
   const TargetInstrInfo *TII = nullptr;
   RegisterClassInfo RegClassInfo;
-  const RegClassFilterFunc ShouldAllocateClass;
+  const RegAllocFilterFunc ShouldAllocateRegisterImpl;
 
   /// Basic block currently being allocated.
   MachineBasicBlock *MBB = nullptr;
@@ -397,7 +397,7 @@ class RegAllocFast : public MachineFunctionPass {
 public:
   static char ID;
 
-  RegAllocFast(const RegClassFilterFunc F = nullptr, bool ClearVirtRegs_ = true)
+  RegAllocFast(const RegAllocFilterFunc F = nullptr, bool ClearVirtRegs_ = true)
       : MachineFunctionPass(ID), Impl(F, ClearVirtRegs_) {}
 
   bool runOnMachineFunction(MachineFunction &MF) override {
@@ -440,10 +440,10 @@ INITIALIZE_PASS(RegAllocFast, "regallocfast", "Fast Register Allocator", false,
 
 bool RegAllocFastImpl::shouldAllocateRegister(const Register Reg) const {
   assert(Reg.isVirtual());
-  if (!ShouldAllocateClass)
+  if (!ShouldAllocateRegisterImpl)
     return true;
-  const TargetRegisterClass &RC = *MRI->getRegClass(Reg);
-  return ShouldAllocateClass(*TRI, RC);
+
+  return ShouldAllocateRegisterImpl(*TRI, *MRI, Reg);
 }
 
 void RegAllocFastImpl::setPhysRegState(MCPhysReg PhysReg, unsigned NewState) {
@@ -1841,7 +1841,7 @@ void RegAllocFastPass::printPipeline(
 
 FunctionPass *llvm::createFastRegisterAllocator() { return new RegAllocFast(); }
 
-FunctionPass *llvm::createFastRegisterAllocator(RegClassFilterFunc Ftor,
+FunctionPass *llvm::createFastRegisterAllocator(RegAllocFilterFunc Ftor,
                                                 bool ClearVirtRegs) {
   return new RegAllocFast(Ftor, ClearVirtRegs);
 }
diff --git a/llvm/lib/CodeGen/RegAllocGreedy.cpp b/llvm/lib/CodeGen/RegAllocGreedy.cpp
index 310050d92d744..5001b4fec58f2 100644
--- a/llvm/lib/CodeGen/RegAllocGreedy.cpp
+++ b/llvm/lib/CodeGen/RegAllocGreedy.cpp
@@ -192,14 +192,12 @@ FunctionPass* llvm::createGreedyRegisterAllocator() {
   return new RAGreedy();
 }
 
-FunctionPass *llvm::createGreedyRegisterAllocator(RegClassFilterFunc Ftor) {
+FunctionPass *llvm::createGreedyRegisterAllocator(RegAllocFilterFunc Ftor) {
   return new RAGreedy(Ftor);
 }
 
-RAGreedy::RAGreedy(RegClassFilterFunc F):
-  MachineFunctionPass(ID),
-  RegAllocBase(F) {
-}
+RAGreedy::RAGreedy(RegAllocFilterFunc F)
+    : MachineFunctionPass(ID), RegAllocBase(F) {}
 
 void RAGreedy::getAnalysisUsage(AnalysisUsage &AU) const {
   AU.setPreservesCFG();
@@ -2306,7 +2304,7 @@ void RAGreedy::tryHintRecoloring(const LiveInterval &VirtReg) {
     if (Reg.isPhysical())
       continue;
 
-    // This may be a skipped class
+    // This may be a skipped register.
     if (!VRM->hasPhys(Reg)) {
       assert(!shouldAllocateRegister(Reg) &&
              "We have an unallocated variable which should have been handled");
diff --git a/llvm/lib/CodeGen/RegAllocGreedy.h b/llvm/lib/CodeGen/RegAllocGreedy.h
index ac300c0024f5a..2e7608a53e9ce 100644
--- a/llvm/lib/CodeGen/RegAllocGreedy.h
+++ b/llvm/lib/CodeGen/RegAllocGreedy.h
@@ -281,7 +281,7 @@ class LLVM_LIBRARY_VISIBILITY RAGreedy : public MachineFunctionPass,
   bool ReverseLocalAssignment = false;
 
 public:
-  RAGreedy(const RegClassFilterFunc F = nullptr);
+  RAGreedy(const RegAllocFilterFunc F = nullptr);
 
   /// Return the pass name.
   StringRef getPassName() const override { return "Greedy Register Allocator"; }
diff --git a/llvm/lib/Passes/PassBuilder.cpp b/llvm/lib/Passes/PassBuilder.cpp
index a9d3f8ec3a4ec..63d2f633784e9 100644
--- a/llvm/lib/Passes/PassBuilder.cpp
+++ b/llvm/lib/Passes/PassBuilder.cpp
@@ -1180,7 +1180,7 @@ parseRegAllocFastPassOptions(PassBuilder &PB, StringRef Params) {
     std::tie(ParamName, Params) = Params.split(';');
 
     if (ParamName.consume_front("filter=")) {
-      std::optional<RegClassFilterFunc> Filter =
+      std::optional<RegAllocFilterFunc> Filter =
           PB.parseRegAllocFilter(ParamName);
       if (!Filter) {
         return make_error<StringError>(
@@ -2190,7 +2190,7 @@ Error PassBuilder::parseAAPipeline(AAManager &AA, StringRef PipelineText) {
   return Error::success();
 }
 
-std::optional<RegClassFilterFunc>
+std::optional<RegAllocFilterFunc>
 PassBuilder::parseRegAllocFilter(StringRef FilterName) {
   if (FilterName == "all")
     return nullptr;
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.cpp b/llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.cpp
index 192996f84c5f3..c8fb68d1c0b0c 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.cpp
@@ -85,16 +85,19 @@ class VGPRRegisterRegAlloc : public RegisterRegAllocBase<VGPRRegisterRegAlloc> {
 };
 
 static bool onlyAllocateSGPRs(const TargetRegisterInfo &TRI,
-                              const TargetRegisterClass &RC) {
-  return static_cast<const SIRegisterInfo &>(TRI).isSGPRClass(&RC);
+                              const MachineRegisterInfo &MRI,
+                              const Register Reg) {
+  const TargetRegisterClass *RC = MRI.getRegClass(Reg);
+  return static_cast<const SIRegisterInfo &>(TRI).isSGPRClass(RC);
 }
 
 static bool onlyAllocateVGPRs(const TargetRegisterInfo &TRI,
-                              const TargetRegisterClass &RC) {
-  return !static_cast<const SIRegisterInfo &>(TRI).isSGPRClass(&RC);
+                              const MachineRegisterInfo &MRI,
+                              const Register Reg) {
+  const TargetRegisterClass *RC = MRI.getRegClass(Reg);
+  return !static_cast<const SIRegisterInfo &>(TRI).isSGPRClass(RC);
 }
 
-
 /// -{sgpr|vgpr}-regalloc=... command line option.
 static FunctionPass *useDefaultRegisterAllocator() { return nullptr; }
 
@@ -749,7 +752,7 @@ void AMDGPUTargetMachine::registerPassBuilderCallbacks(PassBuilder &PB) {
       });
 
   PB.registerRegClassFilterParsingCallback(
-      [](StringRef FilterName) -> RegClassFilterFunc {
+      [](StringRef FilterName) -> RegAllocFilterFunc {
         if (FilterName == "sgpr")
           return onlyAllocateSGPRs;
         if (FilterName == "vgpr")
diff --git a/llvm/lib/Target/RISCV/RISCVTargetMachine.cpp b/llvm/lib/Target/RISCV/RISCVTargetMachine.cpp
index bc3699db6f91e..21fbf47875e68 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetMachine.cpp
+++ b/llvm/lib/Target/RISCV/RISCVTargetMachine.cpp
@@ -279,8 +279,10 @@ class RVVRegisterRegAlloc : public RegisterRegAllocBase<RVVRegisterRegAlloc> {
 };
 
 static bool onlyAllocateRVVReg(const TargetRegisterInfo &TRI,
-                               const TargetRegisterClass &RC) {
-  return RISCVRegisterInfo::isRVVRegClass(&RC);
+                               const MachineRegisterInfo &MRI,
+                               const Register Reg) {
+  const TargetRegisterClass *RC = MRI.getRegClass(Reg);
+  return RISCVRegisterInfo::isRVVRegClass(RC);
 }
 
 static FunctionPass *useDefaultRegisterAllocator() { return nullptr; }
diff --git a/llvm/lib/Target/X86/X86TargetMachine.cpp b/llvm/lib/Target/X86/X86TargetMachine.cpp
index 4c77f40fd32a3..d2d59ff3b93cf 100644
--- a/llvm/lib/Target/X86/X86TargetMachine.cpp
+++ b/llvm/lib/Target/X86/X86TargetMachine.cpp
@@ -676,8 +676,10 @@ std::unique_ptr<CSEConfigBase> X86PassConfig::getCSEConfig() const {
 }
 
 static bool onlyAllocateTileRegisters(const TargetRegisterInfo &TRI,
-                                      const TargetRegisterClass &RC) {
-  return static_cast<const X86RegisterInfo &>(TRI).isTileRegisterClass(&RC);
+                                      const MachineRegisterInfo &MRI,
+                                      const Register Reg) {
+  const TargetRegisterClass *RC = MRI.getRegClass(Reg);
+  return static_cast<const X86RegisterInfo &>(TRI).isTileRegisterClass(RC);
 }
 
 bool X86PassConfig::addRegAssignAndRewriteOptimized() {



More information about the llvm-commits mailing list