[llvm] [CodeGen] change prototype of RegClassFilterFunc (PR #90907)

Gang Chen via llvm-commits llvm-commits at lists.llvm.org
Thu May 2 14:45:39 PDT 2024


https://github.com/cmc-rep created https://github.com/llvm/llvm-project/pull/90907

Change the prototype of RegClassFilterFunc so that we can filter by more than just the register class. We need to implement a more complicated filter based on other information associated with each register.

>From e3aacdff49e331a1cd95376405b3172ec734a6aa Mon Sep 17 00:00:00 2001
From: gangc <gangc at amd.com>
Date: Tue, 30 Apr 2024 13:38:25 -0700
Subject: [PATCH] [CodeGen] change prototype of RegClassFilterFunc

Change the prototype of RegClassFilterFunc so that we can filter
by more than just the register class. We need to implement a more
complicated filter based on other information associated with each register.

Signed-off-by: gangc <gangc at amd.com>
---
 llvm/include/llvm/CodeGen/RegAllocCommon.h     | 10 ++++++----
 llvm/lib/CodeGen/RegAllocBase.cpp              |  3 +--
 llvm/lib/CodeGen/RegAllocFast.cpp              |  3 +--
 llvm/lib/CodeGen/RegAllocGreedy.cpp            |  6 +++---
 llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.cpp | 13 ++++++++-----
 llvm/lib/Target/RISCV/RISCVTargetMachine.cpp   |  6 ++++--
 llvm/lib/Target/X86/X86TargetMachine.cpp       |  6 ++++--
 7 files changed, 27 insertions(+), 20 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/RegAllocCommon.h b/llvm/include/llvm/CodeGen/RegAllocCommon.h
index 757ca8e112eec3..943c9dd9e9bf5d 100644
--- a/llvm/include/llvm/CodeGen/RegAllocCommon.h
+++ b/llvm/include/llvm/CodeGen/RegAllocCommon.h
@@ -10,22 +10,24 @@
 #define LLVM_CODEGEN_REGALLOCCOMMON_H
 
 #include <functional>
+#include <llvm/CodeGen/Register.h>
 
 namespace llvm {
 
-class TargetRegisterClass;
 class TargetRegisterInfo;
+class MachineRegisterInfo;
 
 typedef std::function<bool(const TargetRegisterInfo &TRI,
-                           const TargetRegisterClass &RC)> RegClassFilterFunc;
+                           const MachineRegisterInfo &MRI, const Register Reg)>
+    RegClassFilterFunc;
 
 /// Default register class filter function for register allocation. All virtual
 /// registers should be allocated.
 static inline bool allocateAllRegClasses(const TargetRegisterInfo &,
-                                         const TargetRegisterClass &) {
+                                         const MachineRegisterInfo &,
+                                         const Register) {
   return true;
 }
-
 }
 
 #endif // LLVM_CODEGEN_REGALLOCCOMMON_H
diff --git a/llvm/lib/CodeGen/RegAllocBase.cpp b/llvm/lib/CodeGen/RegAllocBase.cpp
index d0dec372f68961..a4645ed93029d0 100644
--- a/llvm/lib/CodeGen/RegAllocBase.cpp
+++ b/llvm/lib/CodeGen/RegAllocBase.cpp
@@ -181,8 +181,7 @@ void RegAllocBase::enqueue(const LiveInterval *LI) {
   if (VRM->hasPhys(Reg))
     return;
 
-  const TargetRegisterClass &RC = *MRI->getRegClass(Reg);
-  if (ShouldAllocateClass(*TRI, RC)) {
+  if (ShouldAllocateClass(*TRI, *MRI, Reg)) {
     LLVM_DEBUG(dbgs() << "Enqueuing " << printReg(Reg, TRI) << '\n');
     enqueueImpl(LI);
   } else {
diff --git a/llvm/lib/CodeGen/RegAllocFast.cpp b/llvm/lib/CodeGen/RegAllocFast.cpp
index 6740e1f0edb4f4..f6419daba6a2d8 100644
--- a/llvm/lib/CodeGen/RegAllocFast.cpp
+++ b/llvm/lib/CodeGen/RegAllocFast.cpp
@@ -417,8 +417,7 @@ INITIALIZE_PASS(RegAllocFast, "regallocfast", "Fast Register Allocator", false,
 
 bool RegAllocFast::shouldAllocateRegister(const Register Reg) const {
   assert(Reg.isVirtual());
-  const TargetRegisterClass &RC = *MRI->getRegClass(Reg);
-  return ShouldAllocateClass(*TRI, RC);
+  return ShouldAllocateClass(*TRI, *MRI, Reg);
 }
 
 void RegAllocFast::setPhysRegState(MCPhysReg PhysReg, unsigned NewState) {
diff --git a/llvm/lib/CodeGen/RegAllocGreedy.cpp b/llvm/lib/CodeGen/RegAllocGreedy.cpp
index 348277224c7aee..c3d5984b46f51a 100644
--- a/llvm/lib/CodeGen/RegAllocGreedy.cpp
+++ b/llvm/lib/CodeGen/RegAllocGreedy.cpp
@@ -2306,9 +2306,9 @@ void RAGreedy::tryHintRecoloring(const LiveInterval &VirtReg) {
     if (Reg.isPhysical())
       continue;
 
-    // This may be a skipped class
+    // This may be a skipped register
     if (!VRM->hasPhys(Reg)) {
-      assert(!ShouldAllocateClass(*TRI, *MRI->getRegClass(Reg)) &&
+      assert(!ShouldAllocateClass(*TRI, *MRI, Reg) &&
              "We have an unallocated variable which should have been handled");
       continue;
     }
@@ -2698,7 +2698,7 @@ bool RAGreedy::hasVirtRegAlloc() {
     const TargetRegisterClass *RC = MRI->getRegClass(Reg);
     if (!RC)
       continue;
-    if (ShouldAllocateClass(*TRI, *RC))
+    if (ShouldAllocateClass(*TRI, *MRI, Reg))
       return true;
   }
 
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.cpp b/llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.cpp
index 305a6c8c3b9262..3d6965fa9876ca 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.cpp
@@ -81,16 +81,19 @@ class VGPRRegisterRegAlloc : public RegisterRegAllocBase<VGPRRegisterRegAlloc> {
 };
 
 static bool onlyAllocateSGPRs(const TargetRegisterInfo &TRI,
-                              const TargetRegisterClass &RC) {
-  return static_cast<const SIRegisterInfo &>(TRI).isSGPRClass(&RC);
+                              const MachineRegisterInfo &MRI,
+                              const Register Reg) {
+  const TargetRegisterClass *RC = MRI.getRegClass(Reg);
+  return static_cast<const SIRegisterInfo &>(TRI).isSGPRClass(RC);
 }
 
 static bool onlyAllocateVGPRs(const TargetRegisterInfo &TRI,
-                              const TargetRegisterClass &RC) {
-  return !static_cast<const SIRegisterInfo &>(TRI).isSGPRClass(&RC);
+                              const MachineRegisterInfo &MRI,
+                              const Register Reg) {
+  const TargetRegisterClass *RC = MRI.getRegClass(Reg);
+  return !static_cast<const SIRegisterInfo &>(TRI).isSGPRClass(RC);
 }
 
-
 /// -{sgpr|vgpr}-regalloc=... command line option.
 static FunctionPass *useDefaultRegisterAllocator() { return nullptr; }
 
diff --git a/llvm/lib/Target/RISCV/RISCVTargetMachine.cpp b/llvm/lib/Target/RISCV/RISCVTargetMachine.cpp
index 0876f46728a10c..44a26c48c63e02 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetMachine.cpp
+++ b/llvm/lib/Target/RISCV/RISCVTargetMachine.cpp
@@ -273,8 +273,10 @@ class RVVRegisterRegAlloc : public RegisterRegAllocBase<RVVRegisterRegAlloc> {
 };
 
 static bool onlyAllocateRVVReg(const TargetRegisterInfo &TRI,
-                               const TargetRegisterClass &RC) {
-  return RISCVRegisterInfo::isRVVRegClass(&RC);
+                               const MachineRegisterInfo &MRI,
+                               const Register Reg) {
+  const TargetRegisterClass *RC = MRI.getRegClass(Reg);
+  return RISCVRegisterInfo::isRVVRegClass(RC);
 }
 
 static FunctionPass *useDefaultRegisterAllocator() { return nullptr; }
diff --git a/llvm/lib/Target/X86/X86TargetMachine.cpp b/llvm/lib/Target/X86/X86TargetMachine.cpp
index 86b456019c4e56..eab537e8a5f8b9 100644
--- a/llvm/lib/Target/X86/X86TargetMachine.cpp
+++ b/llvm/lib/Target/X86/X86TargetMachine.cpp
@@ -652,8 +652,10 @@ std::unique_ptr<CSEConfigBase> X86PassConfig::getCSEConfig() const {
 }
 
 static bool onlyAllocateTileRegisters(const TargetRegisterInfo &TRI,
-                                      const TargetRegisterClass &RC) {
-  return static_cast<const X86RegisterInfo &>(TRI).isTileRegisterClass(&RC);
+                                      const MachineRegisterInfo &MRI,
+                                      const Register Reg) {
+  const TargetRegisterClass *RC = MRI.getRegClass(Reg);
+  return static_cast<const X86RegisterInfo &>(TRI).isTileRegisterClass(RC);
 }
 
 bool X86PassConfig::addRegAssignAndRewriteOptimized() {



More information about the llvm-commits mailing list