[llvm-branch-commits] [llvm] [AMDGPU] New RegBankSelect: Add Ptr32/Ptr64/Ptr128 (PR #142602)

Pierre van Houtryve via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Tue Jun 3 06:12:24 PDT 2025


https://github.com/Pierre-vh created https://github.com/llvm/llvm-project/pull/142602

There are quite a few opcodes that do not care about the exact address space (AS) of a pointer, only about its size.
Adding generic pointer types for these will help reduce duplication in the rule definitions.

I also switched the usual B types over to the new `isAnyPtr` helper, so that they are guaranteed to be supersets of the `Ptr` cases.

>From b4da53a420bcb39934f64deff26624a022dccc2b Mon Sep 17 00:00:00 2001
From: pvanhout <pierre.vanhoutryve at amd.com>
Date: Tue, 3 Jun 2025 14:40:38 +0200
Subject: [PATCH] [AMDGPU] New RegBankSelect: Add Ptr32/Ptr64/Ptr128

There are quite a few opcodes that do not care about the exact address space (AS) of a pointer, only about its size.
Adding generic pointer types for these will help reduce duplication in the rule definitions.

I also switched the usual B types over to the new `isAnyPtr` helper, so that they are guaranteed to be supersets of the `Ptr` cases.
---
 .../AMDGPU/AMDGPURegBankLegalizeHelper.cpp    | 42 +++++++++++++++----
 .../AMDGPU/AMDGPURegBankLegalizeRules.cpp     | 29 +++++++++++--
 .../AMDGPU/AMDGPURegBankLegalizeRules.h       | 19 +++++++++
 3 files changed, 77 insertions(+), 13 deletions(-)

diff --git a/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeHelper.cpp b/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeHelper.cpp
index 12af7233ffad6..26aa3cf36c87a 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeHelper.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeHelper.cpp
@@ -605,17 +605,23 @@ LLT RegBankLegalizeHelper::getBTyFromID(RegBankLLTMappingApplyID ID, LLT Ty) {
   case VgprB32:
   case UniInVgprB32:
     if (Ty == LLT::scalar(32) || Ty == LLT::fixed_vector(2, 16) ||
-        Ty == LLT::pointer(3, 32) || Ty == LLT::pointer(5, 32) ||
-        Ty == LLT::pointer(6, 32))
+        isAnyPtr(Ty, 32))
       return Ty;
     return LLT();
+  case SgprPtr32:
+  case VgprPtr32:
+    return isAnyPtr(Ty, 32) ? Ty : LLT();
+  case SgprPtr64:
+  case VgprPtr64:
+    return isAnyPtr(Ty, 64) ? Ty : LLT();
+  case SgprPtr128:
+  case VgprPtr128:
+    return isAnyPtr(Ty, 128) ? Ty : LLT();
   case SgprB64:
   case VgprB64:
   case UniInVgprB64:
     if (Ty == LLT::scalar(64) || Ty == LLT::fixed_vector(2, 32) ||
-        Ty == LLT::fixed_vector(4, 16) || Ty == LLT::pointer(0, 64) ||
-        Ty == LLT::pointer(1, 64) || Ty == LLT::pointer(4, 64) ||
-        (Ty.isPointer() && Ty.getAddressSpace() > AMDGPUAS::MAX_AMDGPU_ADDRESS))
+        Ty == LLT::fixed_vector(4, 16) || isAnyPtr(Ty, 64))
       return Ty;
     return LLT();
   case SgprB96:
@@ -629,7 +635,7 @@ LLT RegBankLegalizeHelper::getBTyFromID(RegBankLLTMappingApplyID ID, LLT Ty) {
   case VgprB128:
   case UniInVgprB128:
     if (Ty == LLT::scalar(128) || Ty == LLT::fixed_vector(4, 32) ||
-        Ty == LLT::fixed_vector(2, 64))
+        Ty == LLT::fixed_vector(2, 64) || isAnyPtr(Ty, 128))
       return Ty;
     return LLT();
   case SgprB256:
@@ -668,6 +674,9 @@ RegBankLegalizeHelper::getRegBankFromID(RegBankLLTMappingApplyID ID) {
   case SgprP5:
   case SgprP6:
   case SgprP8:
+  case SgprPtr32:
+  case SgprPtr64:
+  case SgprPtr128:
   case SgprV2S16:
   case SgprV2S32:
   case SgprV4S32:
@@ -705,6 +714,9 @@ RegBankLegalizeHelper::getRegBankFromID(RegBankLLTMappingApplyID ID) {
   case VgprP5:
   case VgprP6:
   case VgprP8:
+  case VgprPtr32:
+  case VgprPtr64:
+  case VgprPtr128:
   case VgprV2S16:
   case VgprV2S32:
   case VgprV4S32:
@@ -778,12 +790,18 @@ void RegBankLegalizeHelper::applyMappingDst(
     case SgprB128:
     case SgprB256:
     case SgprB512:
+    case SgprPtr32:
+    case SgprPtr64:
+    case SgprPtr128:
     case VgprB32:
     case VgprB64:
     case VgprB96:
     case VgprB128:
     case VgprB256:
-    case VgprB512: {
+    case VgprB512:
+    case VgprPtr32:
+    case VgprPtr64:
+    case VgprPtr128: {
       assert(Ty == getBTyFromID(MethodIDs[OpIdx], Ty));
       assert(RB == getRegBankFromID(MethodIDs[OpIdx]));
       break;
@@ -892,7 +910,10 @@ void RegBankLegalizeHelper::applyMappingSrc(
     case SgprB96:
     case SgprB128:
     case SgprB256:
-    case SgprB512: {
+    case SgprB512:
+    case SgprPtr32:
+    case SgprPtr64:
+    case SgprPtr128: {
       assert(Ty == getBTyFromID(MethodIDs[i], Ty));
       assert(RB == getRegBankFromID(MethodIDs[i]));
       break;
@@ -926,7 +947,10 @@ void RegBankLegalizeHelper::applyMappingSrc(
     case VgprB96:
     case VgprB128:
     case VgprB256:
-    case VgprB512: {
+    case VgprB512:
+    case VgprPtr32:
+    case VgprPtr64:
+    case VgprPtr128: {
       assert(Ty == getBTyFromID(MethodIDs[i], Ty));
       if (RB != VgprRB) {
         auto CopyToVgpr = B.buildCopy({VgprRB, Ty}, Reg);
diff --git a/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.cpp b/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.cpp
index 08a35b9794344..b6260076731ba 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.cpp
@@ -26,6 +26,10 @@
 using namespace llvm;
 using namespace AMDGPU;
 
+bool AMDGPU::isAnyPtr(LLT Ty, unsigned Width) {
+  return Ty.isPointer() && Ty.getSizeInBits() == Width; // AS is ignored.
+}
+
 RegBankLLTMapping::RegBankLLTMapping(
     std::initializer_list<RegBankLLTMappingApplyID> DstOpMappingList,
     std::initializer_list<RegBankLLTMappingApplyID> SrcOpMappingList,
@@ -68,6 +72,12 @@ bool matchUniformityAndLLT(Register Reg, UniformityLLTOpPredicateID UniID,
     return MRI.getType(Reg) == LLT::pointer(6, 32);
   case P8:
     return MRI.getType(Reg) == LLT::pointer(8, 128);
+  case Ptr32:
+    return isAnyPtr(MRI.getType(Reg), 32);
+  case Ptr64:
+    return isAnyPtr(MRI.getType(Reg), 64);
+  case Ptr128:
+    return isAnyPtr(MRI.getType(Reg), 128);
   case V2S32:
     return MRI.getType(Reg) == LLT::fixed_vector(2, 32);
   case V4S32:
@@ -110,6 +120,12 @@ bool matchUniformityAndLLT(Register Reg, UniformityLLTOpPredicateID UniID,
     return MRI.getType(Reg) == LLT::pointer(6, 32) && MUI.isUniform(Reg);
   case UniP8:
     return MRI.getType(Reg) == LLT::pointer(8, 128) && MUI.isUniform(Reg);
+  case UniPtr32:
+    return isAnyPtr(MRI.getType(Reg), 32) && MUI.isUniform(Reg);
+  case UniPtr64:
+    return isAnyPtr(MRI.getType(Reg), 64) && MUI.isUniform(Reg);
+  case UniPtr128:
+    return isAnyPtr(MRI.getType(Reg), 128) && MUI.isUniform(Reg);
   case UniV2S16:
     return MRI.getType(Reg) == LLT::fixed_vector(2, 16) && MUI.isUniform(Reg);
   case UniB32:
@@ -150,6 +166,12 @@ bool matchUniformityAndLLT(Register Reg, UniformityLLTOpPredicateID UniID,
     return MRI.getType(Reg) == LLT::pointer(6, 32) && MUI.isDivergent(Reg);
   case DivP8:
     return MRI.getType(Reg) == LLT::pointer(8, 128) && MUI.isDivergent(Reg);
+  case DivPtr32:
+    return isAnyPtr(MRI.getType(Reg), 32) && MUI.isDivergent(Reg);
+  case DivPtr64:
+    return isAnyPtr(MRI.getType(Reg), 64) && MUI.isDivergent(Reg);
+  case DivPtr128:
+    return isAnyPtr(MRI.getType(Reg), 128) && MUI.isDivergent(Reg);
   case DivV2S16:
     return MRI.getType(Reg) == LLT::fixed_vector(2, 16) && MUI.isDivergent(Reg);
   case DivB32:
@@ -223,15 +245,14 @@ UniformityLLTOpPredicateID LLTToId(LLT Ty) {
 
 UniformityLLTOpPredicateID LLTToBId(LLT Ty) {
   if (Ty == LLT::scalar(32) || Ty == LLT::fixed_vector(2, 16) ||
-      (Ty.isPointer() && Ty.getSizeInBits() == 32))
+      isAnyPtr(Ty, 32))
     return B32;
   if (Ty == LLT::scalar(64) || Ty == LLT::fixed_vector(2, 32) ||
-      Ty == LLT::fixed_vector(4, 16) ||
-      (Ty.isPointer() && Ty.getSizeInBits() == 64))
+      Ty == LLT::fixed_vector(4, 16) || isAnyPtr(Ty, 64))
     return B64;
   if (Ty == LLT::fixed_vector(3, 32))
     return B96;
-  if (Ty == LLT::fixed_vector(4, 32))
+  if (Ty == LLT::fixed_vector(4, 32) || isAnyPtr(Ty, 128))
     return B128;
   return _;
 }
diff --git a/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.h b/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.h
index 14be873b6ce19..1d429f711fbf6 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.h
+++ b/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.h
@@ -15,6 +15,7 @@
 
 namespace llvm {
 
+class LLT;
 class MachineRegisterInfo;
 class MachineInstr;
 class GCNSubtarget;
@@ -26,6 +27,9 @@ using MachineUniformityInfo = GenericUniformityInfo<MachineSSAContext>;
 
 namespace AMDGPU {
 
+/// \returns true if \p Ty is a pointer type with size \p Width.
+bool isAnyPtr(LLT Ty, unsigned Width);
+
 // IDs used to build predicate for RegBankLegalizeRule. Predicate can have one
 // or more IDs and each represents a check for 'uniform or divergent' + LLT or
 // just LLT on register operand.
@@ -62,6 +66,9 @@ enum UniformityLLTOpPredicateID {
   P5,
   P6,
   P8,
+  Ptr32,
+  Ptr64,
+  Ptr128,
 
   UniP0,
   UniP1,
@@ -71,6 +78,9 @@ enum UniformityLLTOpPredicateID {
   UniP5,
   UniP6,
   UniP8,
+  UniPtr32,
+  UniPtr64,
+  UniPtr128,
 
   DivP0,
   DivP1,
@@ -80,6 +90,9 @@ enum UniformityLLTOpPredicateID {
   DivP5,
   DivP6,
   DivP8,
+  DivPtr32,
+  DivPtr64,
+  DivPtr128,
 
   // vectors
   V2S16,
@@ -138,6 +151,9 @@ enum RegBankLLTMappingApplyID {
   SgprP5,
   SgprP6,
   SgprP8,
+  SgprPtr32,
+  SgprPtr64,
+  SgprPtr128,
   SgprV2S16,
   SgprV4S32,
   SgprV2S32,
@@ -161,6 +177,9 @@ enum RegBankLLTMappingApplyID {
   VgprP5,
   VgprP6,
   VgprP8,
+  VgprPtr32,
+  VgprPtr64,
+  VgprPtr128,
   VgprV2S16,
   VgprV2S32,
   VgprB32,



More information about the llvm-branch-commits mailing list