[llvm-branch-commits] [llvm] [AMDGPU] New RegBankSelect: Add Ptr32/Ptr64/Ptr128 (PR #142602)

Pierre van Houtryve via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Wed Jun 4 01:09:32 PDT 2025


https://github.com/Pierre-vh updated https://github.com/llvm/llvm-project/pull/142602

>From c69258d78459b8dcc89bec38a8a795763cd3dc80 Mon Sep 17 00:00:00 2001
From: pvanhout <pierre.vanhoutryve at amd.com>
Date: Tue, 3 Jun 2025 14:40:38 +0200
Subject: [PATCH] [AMDGPU] New RegBankSelect: Add Ptr32/Ptr64/Ptr128

There are quite a few opcodes that do not care about the exact address space (AS) of the pointer, just its size.
Adding generic types for these will help reduce duplication in the rule definitions.

I also moved the usual B types over to the new `isAnyPtr` helper I added, to make sure they are supersets of the `Ptr` cases.
---
 .../AMDGPU/AMDGPURegBankLegalizeHelper.cpp    | 42 +++++++++++++++----
 .../AMDGPU/AMDGPURegBankLegalizeRules.cpp     | 29 +++++++++++--
 .../AMDGPU/AMDGPURegBankLegalizeRules.h       | 19 +++++++++
 3 files changed, 77 insertions(+), 13 deletions(-)

diff --git a/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeHelper.cpp b/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeHelper.cpp
index 89af982636590..b2ddc6e88966b 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeHelper.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeHelper.cpp
@@ -595,17 +595,23 @@ LLT RegBankLegalizeHelper::getBTyFromID(RegBankLLTMappingApplyID ID, LLT Ty) {
   case VgprB32:
   case UniInVgprB32:
     if (Ty == LLT::scalar(32) || Ty == LLT::fixed_vector(2, 16) ||
-        Ty == LLT::pointer(3, 32) || Ty == LLT::pointer(5, 32) ||
-        Ty == LLT::pointer(6, 32))
+        isAnyPtr(Ty, 32))
       return Ty;
     return LLT();
+  case SgprPtr32:
+  case VgprPtr32:
+    return isAnyPtr(Ty, 32) ? Ty : LLT();
+  case SgprPtr64:
+  case VgprPtr64:
+    return isAnyPtr(Ty, 64) ? Ty : LLT();
+  case SgprPtr128:
+  case VgprPtr128:
+    return isAnyPtr(Ty, 128) ? Ty : LLT();
   case SgprB64:
   case VgprB64:
   case UniInVgprB64:
     if (Ty == LLT::scalar(64) || Ty == LLT::fixed_vector(2, 32) ||
-        Ty == LLT::fixed_vector(4, 16) || Ty == LLT::pointer(0, 64) ||
-        Ty == LLT::pointer(1, 64) || Ty == LLT::pointer(4, 64) ||
-        (Ty.isPointer() && Ty.getAddressSpace() > AMDGPUAS::MAX_AMDGPU_ADDRESS))
+        Ty == LLT::fixed_vector(4, 16) || isAnyPtr(Ty, 64))
       return Ty;
     return LLT();
   case SgprB96:
@@ -619,7 +625,7 @@ LLT RegBankLegalizeHelper::getBTyFromID(RegBankLLTMappingApplyID ID, LLT Ty) {
   case VgprB128:
   case UniInVgprB128:
     if (Ty == LLT::scalar(128) || Ty == LLT::fixed_vector(4, 32) ||
-        Ty == LLT::fixed_vector(2, 64))
+        Ty == LLT::fixed_vector(2, 64) || isAnyPtr(Ty, 128))
       return Ty;
     return LLT();
   case SgprB256:
@@ -654,6 +660,9 @@ RegBankLegalizeHelper::getRegBankFromID(RegBankLLTMappingApplyID ID) {
   case SgprP3:
   case SgprP4:
   case SgprP5:
+  case SgprPtr32:
+  case SgprPtr64:
+  case SgprPtr128:
   case SgprV2S16:
   case SgprV2S32:
   case SgprV4S32:
@@ -688,6 +697,9 @@ RegBankLegalizeHelper::getRegBankFromID(RegBankLLTMappingApplyID ID) {
   case VgprP3:
   case VgprP4:
   case VgprP5:
+  case VgprPtr32:
+  case VgprPtr64:
+  case VgprPtr128:
   case VgprV2S16:
   case VgprV2S32:
   case VgprV4S32:
@@ -754,12 +766,18 @@ void RegBankLegalizeHelper::applyMappingDst(
     case SgprB128:
     case SgprB256:
     case SgprB512:
+    case SgprPtr32:
+    case SgprPtr64:
+    case SgprPtr128:
     case VgprB32:
     case VgprB64:
     case VgprB96:
     case VgprB128:
     case VgprB256:
-    case VgprB512: {
+    case VgprB512:
+    case VgprPtr32:
+    case VgprPtr64:
+    case VgprPtr128: {
       assert(Ty == getBTyFromID(MethodIDs[OpIdx], Ty));
       assert(RB == getRegBankFromID(MethodIDs[OpIdx]));
       break;
@@ -864,7 +882,10 @@ void RegBankLegalizeHelper::applyMappingSrc(
     case SgprB96:
     case SgprB128:
     case SgprB256:
-    case SgprB512: {
+    case SgprB512:
+    case SgprPtr32:
+    case SgprPtr64:
+    case SgprPtr128: {
       assert(Ty == getBTyFromID(MethodIDs[i], Ty));
       assert(RB == getRegBankFromID(MethodIDs[i]));
       break;
@@ -895,7 +916,10 @@ void RegBankLegalizeHelper::applyMappingSrc(
     case VgprB96:
     case VgprB128:
     case VgprB256:
-    case VgprB512: {
+    case VgprB512:
+    case VgprPtr32:
+    case VgprPtr64:
+    case VgprPtr128: {
       assert(Ty == getBTyFromID(MethodIDs[i], Ty));
       if (RB != VgprRB) {
         auto CopyToVgpr = B.buildCopy({VgprRB, Ty}, Reg);
diff --git a/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.cpp b/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.cpp
index 672fc5b79abc2..5402129e41887 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.cpp
@@ -26,6 +26,10 @@
 using namespace llvm;
 using namespace AMDGPU;
 
+bool AMDGPU::isAnyPtr(LLT Ty, unsigned Width) {
+  return Ty.isPointer() && Ty.getSizeInBits() == Width;
+}
+
 RegBankLLTMapping::RegBankLLTMapping(
     std::initializer_list<RegBankLLTMappingApplyID> DstOpMappingList,
     std::initializer_list<RegBankLLTMappingApplyID> SrcOpMappingList,
@@ -62,6 +66,12 @@ bool matchUniformityAndLLT(Register Reg, UniformityLLTOpPredicateID UniID,
     return MRI.getType(Reg) == LLT::pointer(4, 64);
   case P5:
     return MRI.getType(Reg) == LLT::pointer(5, 32);
+  case Ptr32:
+    return isAnyPtr(MRI.getType(Reg), 32);
+  case Ptr64:
+    return isAnyPtr(MRI.getType(Reg), 64);
+  case Ptr128:
+    return isAnyPtr(MRI.getType(Reg), 128);
   case V2S32:
     return MRI.getType(Reg) == LLT::fixed_vector(2, 32);
   case V4S32:
@@ -98,6 +108,12 @@ bool matchUniformityAndLLT(Register Reg, UniformityLLTOpPredicateID UniID,
     return MRI.getType(Reg) == LLT::pointer(4, 64) && MUI.isUniform(Reg);
   case UniP5:
     return MRI.getType(Reg) == LLT::pointer(5, 32) && MUI.isUniform(Reg);
+  case UniPtr32:
+    return isAnyPtr(MRI.getType(Reg), 32) && MUI.isUniform(Reg);
+  case UniPtr64:
+    return isAnyPtr(MRI.getType(Reg), 64) && MUI.isUniform(Reg);
+  case UniPtr128:
+    return isAnyPtr(MRI.getType(Reg), 128) && MUI.isUniform(Reg);
   case UniV2S16:
     return MRI.getType(Reg) == LLT::fixed_vector(2, 16) && MUI.isUniform(Reg);
   case UniB32:
@@ -132,6 +148,12 @@ bool matchUniformityAndLLT(Register Reg, UniformityLLTOpPredicateID UniID,
     return MRI.getType(Reg) == LLT::pointer(4, 64) && MUI.isDivergent(Reg);
   case DivP5:
     return MRI.getType(Reg) == LLT::pointer(5, 32) && MUI.isDivergent(Reg);
+  case DivPtr32:
+    return isAnyPtr(MRI.getType(Reg), 32) && MUI.isDivergent(Reg);
+  case DivPtr64:
+    return isAnyPtr(MRI.getType(Reg), 64) && MUI.isDivergent(Reg);
+  case DivPtr128:
+    return isAnyPtr(MRI.getType(Reg), 128) && MUI.isDivergent(Reg);
   case DivV2S16:
     return MRI.getType(Reg) == LLT::fixed_vector(2, 16) && MUI.isDivergent(Reg);
   case DivB32:
@@ -205,15 +227,14 @@ UniformityLLTOpPredicateID LLTToId(LLT Ty) {
 
 UniformityLLTOpPredicateID LLTToBId(LLT Ty) {
   if (Ty == LLT::scalar(32) || Ty == LLT::fixed_vector(2, 16) ||
-      (Ty.isPointer() && Ty.getSizeInBits() == 32))
+      isAnyPtr(Ty, 32))
     return B32;
   if (Ty == LLT::scalar(64) || Ty == LLT::fixed_vector(2, 32) ||
-      Ty == LLT::fixed_vector(4, 16) ||
-      (Ty.isPointer() && Ty.getSizeInBits() == 64))
+      Ty == LLT::fixed_vector(4, 16) || isAnyPtr(Ty, 64))
     return B64;
   if (Ty == LLT::fixed_vector(3, 32))
     return B96;
-  if (Ty == LLT::fixed_vector(4, 32))
+  if (Ty == LLT::fixed_vector(4, 32) || isAnyPtr(Ty, 128))
     return B128;
   return _;
 }
diff --git a/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.h b/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.h
index 30b900d871f3c..7243d75aa830c 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.h
+++ b/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.h
@@ -15,6 +15,7 @@
 
 namespace llvm {
 
+class LLT;
 class MachineRegisterInfo;
 class MachineInstr;
 class GCNSubtarget;
@@ -26,6 +27,9 @@ using MachineUniformityInfo = GenericUniformityInfo<MachineSSAContext>;
 
 namespace AMDGPU {
 
+/// \returns true if \p Ty is a pointer type whose size is \p Width bits.
+bool isAnyPtr(LLT Ty, unsigned Width);
+
 // IDs used to build predicate for RegBankLegalizeRule. Predicate can have one
 // or more IDs and each represents a check for 'uniform or divergent' + LLT or
 // just LLT on register operand.
@@ -59,18 +63,27 @@ enum UniformityLLTOpPredicateID {
   P3,
   P4,
   P5,
+  Ptr32,
+  Ptr64,
+  Ptr128,
 
   UniP0,
   UniP1,
   UniP3,
   UniP4,
   UniP5,
+  UniPtr32,
+  UniPtr64,
+  UniPtr128,
 
   DivP0,
   DivP1,
   DivP3,
   DivP4,
   DivP5,
+  DivPtr32,
+  DivPtr64,
+  DivPtr128,
 
   // vectors
   V2S16,
@@ -125,6 +138,9 @@ enum RegBankLLTMappingApplyID {
   SgprP3,
   SgprP4,
   SgprP5,
+  SgprPtr32,
+  SgprPtr64,
+  SgprPtr128,
   SgprV2S16,
   SgprV4S32,
   SgprV2S32,
@@ -145,6 +161,9 @@ enum RegBankLLTMappingApplyID {
   VgprP3,
   VgprP4,
   VgprP5,
+  VgprPtr32,
+  VgprPtr64,
+  VgprPtr128,
   VgprV2S16,
   VgprV2S32,
   VgprB32,



More information about the llvm-branch-commits mailing list