[llvm-branch-commits] [llvm] AMDGPU/GlobalISel: RBLegalize rules for load (PR #112882)

via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Fri Oct 18 04:12:35 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-amdgpu

Author: Petar Avramovic (petar-avramovic)

<details>
<summary>Changes</summary>

Add IDs for bit widths that cover multiple LLTs: B32, B64, etc.
Add a "Predicate" wrapper class for boolean predicate functions, used to
write readable rules. Predicates can be combined using &&, || and !.
Add lowering for splitting and widening loads.
Write rules for loads so that existing MIR tests from the old
regbankselect remain unchanged.

---

Patch is 81.23 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/112882.diff


6 Files Affected:

- (modified) llvm/lib/Target/AMDGPU/AMDGPURBLegalizeHelper.cpp (+297-5) 
- (modified) llvm/lib/Target/AMDGPU/AMDGPURBLegalizeHelper.h (+4-3) 
- (modified) llvm/lib/Target/AMDGPU/AMDGPURBLegalizeRules.cpp (+300-7) 
- (modified) llvm/lib/Target/AMDGPU/AMDGPURBLegalizeRules.h (+63-2) 
- (modified) llvm/test/CodeGen/AMDGPU/GlobalISel/regbankselect-load.mir (+271-49) 
- (modified) llvm/test/CodeGen/AMDGPU/GlobalISel/regbankselect-zextload.mir (+7-2) 


``````````diff
diff --git a/llvm/lib/Target/AMDGPU/AMDGPURBLegalizeHelper.cpp b/llvm/lib/Target/AMDGPU/AMDGPURBLegalizeHelper.cpp
index a0f6ecedab7a83..f58f0a315096d2 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPURBLegalizeHelper.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPURBLegalizeHelper.cpp
@@ -37,6 +37,97 @@ bool RegBankLegalizeHelper::findRuleAndApplyMapping(MachineInstr &MI) {
   return true;
 }
 
+void RegBankLegalizeHelper::splitLoad(MachineInstr &MI,
+                                      ArrayRef<LLT> LLTBreakdown, LLT MergeTy) {
+  // Replace MI with one load per entry of LLTBreakdown. Each load reads at the
+  // running byte offset from the original base pointer. The pieces are then
+  // reassembled into the original destination register and MI is erased.
+  // With the default (invalid) MergeTy the pieces are merged directly.
+  // Otherwise every piece is first unmerged into MergeTy-typed parts.
+  MachineFunction &MF = B.getMF();
+  assert(MI.getNumMemOperands() == 1);
+  MachineMemOperand &OrigMMO = **MI.memoperands_begin();
+  const Register DstReg = MI.getOperand(0).getReg();
+  const RegisterBank *DstBank = MRI.getRegBankOrNull(DstReg);
+  const Register BaseReg = MI.getOperand(1).getReg();
+  const LLT PtrTy = MRI.getType(BaseReg);
+  const RegisterBank *PtrBank = MRI.getRegBankOrNull(BaseReg);
+  const LLT OffsetTy = LLT::scalar(PtrTy.getSizeInBits());
+
+  SmallVector<Register, 4> Parts;
+  unsigned Offset = 0;
+  for (LLT PartTy : LLTBreakdown) {
+    // The first piece reads through the base pointer itself; every later
+    // piece gets its own G_PTR_ADD of a constant byte offset.
+    Register Addr = BaseReg;
+    if (Offset != 0) {
+      Addr = MRI.createVirtualRegister({PtrBank, PtrTy});
+      Register OffsetReg = MRI.createVirtualRegister({PtrBank, OffsetTy});
+      B.buildConstant(OffsetReg, Offset);
+      B.buildPtrAdd(Addr, BaseReg, OffsetReg);
+    }
+    MachineMemOperand *PartMMO =
+        MF.getMachineMemOperand(&OrigMMO, Offset, PartTy);
+    Register PartReg = MRI.createVirtualRegister({DstBank, PartTy});
+    B.buildLoad(PartReg, Addr, *PartMMO);
+    Parts.push_back(PartReg);
+    Offset += PartTy.getSizeInBytes();
+  }
+
+  // All pieces share one type: a single merge/concat rebuilds the result.
+  if (!MergeTy.isValid()) {
+    B.buildMergeLikeInstr(DstReg, Parts);
+    MI.eraseFromParent();
+    return;
+  }
+
+  // Pieces differ in size: break every piece that is not already MergeTy
+  // into MergeTy parts, so the final merge sees uniformly typed inputs.
+  SmallVector<Register, 4> Pieces;
+  for (Register PartReg : Parts) {
+    if (MRI.getType(PartReg) == MergeTy) {
+      Pieces.push_back(PartReg);
+      continue;
+    }
+    auto Unmerge = B.buildUnmerge(MergeTy, PartReg);
+    for (unsigned Idx = 0, End = Unmerge->getNumOperands() - 1; Idx != End;
+         ++Idx) {
+      Register Piece = Unmerge->getOperand(Idx).getReg();
+      MRI.setRegBank(Piece, *DstBank);
+      Pieces.push_back(Piece);
+    }
+  }
+  B.buildMergeLikeInstr(DstReg, Pieces);
+  MI.eraseFromParent();
+}
+
+void RegBankLegalizeHelper::widenLoad(MachineInstr &MI, LLT WideTy,
+                                      LLT MergeTy) {
+  // Replace MI with a single load of the wider type WideTy from the same base
+  // pointer, then narrow the loaded value back to the destination type:
+  //  - scalar WideTy: G_TRUNC into Dst;
+  //  - vector WideTy: unmerge into MergeTy pieces, then merge only the low
+  //    pieces that cover Dst.
+  // MI is erased afterwards.
+  MachineFunction &MF = B.getMF();
+  assert(MI.getNumMemOperands() == 1);
+  MachineMemOperand &BaseMMO = **MI.memoperands_begin();
+  Register Dst = MI.getOperand(0).getReg();
+  const RegisterBank *DstRB = MRI.getRegBankOrNull(Dst);
+  Register BasePtrReg = MI.getOperand(1).getReg();
+
+  // The widened access starts at the original address (offset 0).
+  MachineMemOperand *WideMMO = MF.getMachineMemOperand(&BaseMMO, 0, WideTy);
+  Register WideLoad = MRI.createVirtualRegister({DstRB, WideTy});
+  B.buildLoad(WideLoad, BasePtrReg, *WideMMO);
+
+  if (WideTy.isScalar()) {
+    B.buildTrunc(Dst, WideLoad);
+  } else {
+    // Vector widening needs an explicit piece type. With the default
+    // (invalid) MergeTy the piece count computed below is meaningless.
+    assert(MergeTy.isValid() && "vector widenLoad requires a valid MergeTy");
+    unsigned NumEltsMerge =
+        MRI.getType(Dst).getSizeInBits() / MergeTy.getSizeInBits();
+    auto Unmerge = B.buildUnmerge(MergeTy, WideLoad);
+    SmallVector<Register, 4> MergeTyParts;
+    for (unsigned i = 0, e = Unmerge->getNumOperands() - 1; i != e; ++i) {
+      Register UnmergeReg = Unmerge->getOperand(i).getReg();
+      // Every unmerge result gets a bank, including the unused high pieces.
+      MRI.setRegBank(UnmergeReg, *DstRB);
+      if (i < NumEltsMerge)
+        MergeTyParts.push_back(UnmergeReg);
+    }
+    B.buildMergeLikeInstr(Dst, MergeTyParts);
+  }
+  MI.eraseFromParent();
+}
+
 void RegBankLegalizeHelper::lower(MachineInstr &MI,
                                   const RegBankLLTMapping &Mapping,
                                   SmallSet<Register, 4> &WaterfallSGPRs) {
@@ -119,6 +210,53 @@ void RegBankLegalizeHelper::lower(MachineInstr &MI,
     MI.eraseFromParent();
     break;
   }
+  case SplitLoad: {
+    // Split the load according to its destination type: 256/512-bit types
+    // become 128-bit pieces. 96-bit types become a 64-bit and a 32-bit part,
+    // which are then reassembled through 32-bit (or <2 x s16>) pieces.
+    LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
+    LLT V8S16 = LLT::fixed_vector(8, S16);
+    LLT V4S32 = LLT::fixed_vector(4, S32);
+    LLT V2S64 = LLT::fixed_vector(2, S64);
+
+    if (DstTy == LLT::fixed_vector(8, S32))
+      splitLoad(MI, {V4S32, V4S32});
+    else if (DstTy == LLT::fixed_vector(16, S32))
+      splitLoad(MI, {V4S32, V4S32, V4S32, V4S32});
+    else if (DstTy == LLT::fixed_vector(4, S64))
+      splitLoad(MI, {V2S64, V2S64});
+    else if (DstTy == LLT::fixed_vector(8, S64))
+      splitLoad(MI, {V2S64, V2S64, V2S64, V2S64});
+    else if (DstTy == LLT::fixed_vector(16, S16))
+      splitLoad(MI, {V8S16, V8S16});
+    else if (DstTy == LLT::fixed_vector(32, S16))
+      splitLoad(MI, {V8S16, V8S16, V8S16, V8S16});
+    else if (DstTy == LLT::scalar(256))
+      splitLoad(MI, {LLT::scalar(128), LLT::scalar(128)});
+    else if (DstTy == LLT::scalar(96))
+      splitLoad(MI, {S64, S32}, S32);
+    else if (DstTy == LLT::fixed_vector(3, S32))
+      splitLoad(MI, {LLT::fixed_vector(2, S32), S32}, S32);
+    else if (DstTy == LLT::fixed_vector(6, S16))
+      splitLoad(MI, {LLT::fixed_vector(4, S16), LLT::fixed_vector(2, S16)},
+                LLT::fixed_vector(2, S16));
+    else {
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+      // MachineInstr::dump() is only compiled in when dumping is enabled;
+      // an unguarded call fails to link in release builds.
+      MI.dump();
+#endif
+      llvm_unreachable("SplitLoad type not supported\n");
+    }
+    break;
+  }
+  case WidenLoad: {
+    // Widen 96-bit loads to a 128-bit load and narrow the result back.
+    LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
+    if (DstTy == LLT::scalar(96))
+      widenLoad(MI, LLT::scalar(128));
+    else if (DstTy == LLT::fixed_vector(3, S32))
+      widenLoad(MI, LLT::fixed_vector(4, S32), S32);
+    else if (DstTy == LLT::fixed_vector(6, S16))
+      widenLoad(MI, LLT::fixed_vector(8, S16), LLT::fixed_vector(2, S16));
+    else {
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+      // See SplitLoad above: dump() is unavailable in release builds.
+      MI.dump();
+#endif
+      llvm_unreachable("WidenLoad type not supported\n");
+    }
+    break;
+  }
   }
 
   // TODO: executeInWaterfallLoop(... WaterfallSGPRs)
@@ -142,13 +280,73 @@ LLT RegBankLegalizeHelper::getTyFromID(RegBankLLTMapingApplyID ID) {
   case Sgpr64:
   case Vgpr64:
     return LLT::scalar(64);
-
+  case SgprP1:
+  case VgprP1:
+    return LLT::pointer(1, 64);
+  case SgprP3:
+  case VgprP3:
+    return LLT::pointer(3, 32);
+  case SgprP4:
+  case VgprP4:
+    return LLT::pointer(4, 64);
+  case SgprP5:
+  case VgprP5:
+    return LLT::pointer(5, 32);
   case SgprV4S32:
   case VgprV4S32:
   case UniInVgprV4S32:
     return LLT::fixed_vector(4, 32);
-  case VgprP1:
-    return LLT::pointer(1, 64);
+  default:
+    return LLT();
+  }
+}
+
+LLT RegBankLegalizeHelper::getBTyFromID(RegBankLLTMapingApplyID ID, LLT Ty) {
+  switch (ID) {
+  case SgprB32:
+  case VgprB32:
+  case UniInVgprB32:
+    if (Ty == LLT::scalar(32) || Ty == LLT::fixed_vector(2, 16) ||
+        Ty == LLT::pointer(3, 32) || Ty == LLT::pointer(5, 32) ||
+        Ty == LLT::pointer(6, 32))
+      return Ty;
+    return LLT();
+  case SgprB64:
+  case VgprB64:
+  case UniInVgprB64:
+    if (Ty == LLT::scalar(64) || Ty == LLT::fixed_vector(2, 32) ||
+        Ty == LLT::fixed_vector(4, 16) || Ty == LLT::pointer(0, 64) ||
+        Ty == LLT::pointer(1, 64) || Ty == LLT::pointer(4, 64))
+      return Ty;
+    return LLT();
+  case SgprB96:
+  case VgprB96:
+  case UniInVgprB96:
+    if (Ty == LLT::scalar(96) || Ty == LLT::fixed_vector(3, 32) ||
+        Ty == LLT::fixed_vector(6, 16))
+      return Ty;
+    return LLT();
+  case SgprB128:
+  case VgprB128:
+  case UniInVgprB128:
+    if (Ty == LLT::scalar(128) || Ty == LLT::fixed_vector(4, 32) ||
+        Ty == LLT::fixed_vector(2, 64))
+      return Ty;
+    return LLT();
+  case SgprB256:
+  case VgprB256:
+  case UniInVgprB256:
+    if (Ty == LLT::scalar(256) || Ty == LLT::fixed_vector(8, 32) ||
+        Ty == LLT::fixed_vector(4, 64) || Ty == LLT::fixed_vector(16, 16))
+      return Ty;
+    return LLT();
+  case SgprB512:
+  case VgprB512:
+  case UniInVgprB512:
+    if (Ty == LLT::scalar(512) || Ty == LLT::fixed_vector(16, 32) ||
+        Ty == LLT::fixed_vector(8, 64))
+      return Ty;
+    return LLT();
   default:
     return LLT();
   }
@@ -163,10 +361,26 @@ RegBankLegalizeHelper::getRBFromID(RegBankLLTMapingApplyID ID) {
   case Sgpr16:
   case Sgpr32:
   case Sgpr64:
+  case SgprP1:
+  case SgprP3:
+  case SgprP4:
+  case SgprP5:
   case SgprV4S32:
+  case SgprB32:
+  case SgprB64:
+  case SgprB96:
+  case SgprB128:
+  case SgprB256:
+  case SgprB512:
   case UniInVcc:
   case UniInVgprS32:
   case UniInVgprV4S32:
+  case UniInVgprB32:
+  case UniInVgprB64:
+  case UniInVgprB96:
+  case UniInVgprB128:
+  case UniInVgprB256:
+  case UniInVgprB512:
   case Sgpr32Trunc:
   case Sgpr32AExt:
   case Sgpr32AExtBoolInReg:
@@ -176,7 +390,16 @@ RegBankLegalizeHelper::getRBFromID(RegBankLLTMapingApplyID ID) {
   case Vgpr32:
   case Vgpr64:
   case VgprP1:
+  case VgprP3:
+  case VgprP4:
+  case VgprP5:
   case VgprV4S32:
+  case VgprB32:
+  case VgprB64:
+  case VgprB96:
+  case VgprB128:
+  case VgprB256:
+  case VgprB512:
     return VgprRB;
 
   default:
@@ -202,17 +425,42 @@ void RegBankLegalizeHelper::applyMappingDst(
     case Sgpr16:
     case Sgpr32:
     case Sgpr64:
+    case SgprP1:
+    case SgprP3:
+    case SgprP4:
+    case SgprP5:
     case SgprV4S32:
     case Vgpr32:
     case Vgpr64:
     case VgprP1:
+    case VgprP3:
+    case VgprP4:
+    case VgprP5:
     case VgprV4S32: {
       assert(Ty == getTyFromID(MethodIDs[OpIdx]));
       assert(RB == getRBFromID(MethodIDs[OpIdx]));
       break;
     }
 
-    // uniform in vcc/vgpr: scalars and vectors
+    // sgpr and vgpr B-types
+    case SgprB32:
+    case SgprB64:
+    case SgprB96:
+    case SgprB128:
+    case SgprB256:
+    case SgprB512:
+    case VgprB32:
+    case VgprB64:
+    case VgprB96:
+    case VgprB128:
+    case VgprB256:
+    case VgprB512: {
+      assert(Ty == getBTyFromID(MethodIDs[OpIdx], Ty));
+      assert(RB == getRBFromID(MethodIDs[OpIdx]));
+      break;
+    }
+
+    // uniform in vcc/vgpr: scalars, vectors and B-types
     case UniInVcc: {
       assert(Ty == S1);
       assert(RB == SgprRB);
@@ -229,6 +477,17 @@ void RegBankLegalizeHelper::applyMappingDst(
       AMDGPU::buildReadAnyLaneDst(B, MI, RBI);
       break;
     }
+    case UniInVgprB32:
+    case UniInVgprB64:
+    case UniInVgprB96:
+    case UniInVgprB128:
+    case UniInVgprB256:
+    case UniInVgprB512: {
+      assert(Ty == getBTyFromID(MethodIDs[OpIdx], Ty));
+      assert(RB == SgprRB);
+      AMDGPU::buildReadAnyLaneDst(B, MI, RBI);
+      break;
+    }
 
     // sgpr trunc
     case Sgpr32Trunc: {
@@ -279,16 +538,34 @@ void RegBankLegalizeHelper::applyMappingSrc(
     case Sgpr16:
     case Sgpr32:
     case Sgpr64:
+    case SgprP1:
+    case SgprP3:
+    case SgprP4:
+    case SgprP5:
     case SgprV4S32: {
       assert(Ty == getTyFromID(MethodIDs[i]));
       assert(RB == getRBFromID(MethodIDs[i]));
       break;
     }
+    // sgpr B-types
+    case SgprB32:
+    case SgprB64:
+    case SgprB96:
+    case SgprB128:
+    case SgprB256:
+    case SgprB512: {
+      assert(Ty == getBTyFromID(MethodIDs[i], Ty));
+      assert(RB == getRBFromID(MethodIDs[i]));
+      break;
+    }
 
     // vgpr scalars, pointers and vectors
     case Vgpr32:
     case Vgpr64:
     case VgprP1:
+    case VgprP3:
+    case VgprP4:
+    case VgprP5:
     case VgprV4S32: {
       assert(Ty == getTyFromID(MethodIDs[i]));
       if (RB != VgprRB) {
@@ -298,6 +575,21 @@ void RegBankLegalizeHelper::applyMappingSrc(
       }
       break;
     }
+    // vgpr B-types
+    case VgprB32:
+    case VgprB64:
+    case VgprB96:
+    case VgprB128:
+    case VgprB256:
+    case VgprB512: {
+      assert(Ty == getBTyFromID(MethodIDs[i], Ty));
+      if (RB != VgprRB) {
+        auto CopyToVgpr =
+            B.buildCopy(createVgpr(getBTyFromID(MethodIDs[i], Ty)), Reg);
+        Op.setReg(CopyToVgpr.getReg(0));
+      }
+      break;
+    }
 
     // sgpr and vgpr scalars with extend
     case Sgpr32AExt: {
@@ -372,7 +664,7 @@ void RegBankLegalizeHelper::applyMappingPHI(MachineInstr &MI) {
   // We accept all types that can fit in some register class.
   // Uniform G_PHIs have all sgpr registers.
   // Divergent G_PHIs have vgpr dst but inputs can be sgpr or vgpr.
-  if (Ty == LLT::scalar(32)) {
+  if (Ty == LLT::scalar(32) || Ty == LLT::pointer(4, 64)) {
     return;
   }
 
diff --git a/llvm/lib/Target/AMDGPU/AMDGPURBLegalizeHelper.h b/llvm/lib/Target/AMDGPU/AMDGPURBLegalizeHelper.h
index e23dfcebe3fe3f..c409df54519c5c 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPURBLegalizeHelper.h
+++ b/llvm/lib/Target/AMDGPU/AMDGPURBLegalizeHelper.h
@@ -92,6 +92,7 @@ class RegBankLegalizeHelper {
                               SmallSet<Register, 4> &SGPROperandRegs);
 
   LLT getTyFromID(RegBankLLTMapingApplyID ID);
+  LLT getBTyFromID(RegBankLLTMapingApplyID ID, LLT Ty);
 
   const RegisterBank *getRBFromID(RegBankLLTMapingApplyID ID);
 
@@ -104,9 +105,9 @@ class RegBankLegalizeHelper {
                   const SmallVectorImpl<RegBankLLTMapingApplyID> &MethodIDs,
                   SmallSet<Register, 4> &SGPRWaterfallOperandRegs);
 
-  unsigned setBufferOffsets(MachineIRBuilder &B, Register CombinedOffset,
-                            Register &VOffsetReg, Register &SOffsetReg,
-                            int64_t &InstOffsetVal, Align Alignment);
+  void splitLoad(MachineInstr &MI, ArrayRef<LLT> LLTBreakdown,
+                 LLT MergeTy = LLT());
+  void widenLoad(MachineInstr &MI, LLT WideTy, LLT MergeTy = LLT());
 
   void lower(MachineInstr &MI, const RegBankLLTMapping &Mapping,
              SmallSet<Register, 4> &SGPRWaterfallOperandRegs);
diff --git a/llvm/lib/Target/AMDGPU/AMDGPURBLegalizeRules.cpp b/llvm/lib/Target/AMDGPU/AMDGPURBLegalizeRules.cpp
index 1266f99c79c395..895a596cf84f40 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPURBLegalizeRules.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPURBLegalizeRules.cpp
@@ -14,9 +14,11 @@
 //===----------------------------------------------------------------------===//
 
 #include "AMDGPURBLegalizeRules.h"
+#include "AMDGPUInstrInfo.h"
 #include "GCNSubtarget.h"
 #include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
 #include "llvm/IR/IntrinsicsAMDGPU.h"
+#include "llvm/Support/AMDGPUAddrSpace.h"
 
 using namespace llvm;
 using namespace AMDGPU;
@@ -47,6 +49,24 @@ bool matchUniformityAndLLT(Register Reg, UniformityLLTOpPredicateID UniID,
     return MRI.getType(Reg) == LLT::scalar(64);
   case P1:
     return MRI.getType(Reg) == LLT::pointer(1, 64);
+  case P3:
+    return MRI.getType(Reg) == LLT::pointer(3, 32);
+  case P4:
+    return MRI.getType(Reg) == LLT::pointer(4, 64);
+  case P5:
+    return MRI.getType(Reg) == LLT::pointer(5, 32);
+  case B32:
+    return MRI.getType(Reg).getSizeInBits() == 32;
+  case B64:
+    return MRI.getType(Reg).getSizeInBits() == 64;
+  case B96:
+    return MRI.getType(Reg).getSizeInBits() == 96;
+  case B128:
+    return MRI.getType(Reg).getSizeInBits() == 128;
+  case B256:
+    return MRI.getType(Reg).getSizeInBits() == 256;
+  case B512:
+    return MRI.getType(Reg).getSizeInBits() == 512;
 
   case UniS1:
     return MRI.getType(Reg) == LLT::scalar(1) && MUI.isUniform(Reg);
@@ -56,6 +76,26 @@ bool matchUniformityAndLLT(Register Reg, UniformityLLTOpPredicateID UniID,
     return MRI.getType(Reg) == LLT::scalar(32) && MUI.isUniform(Reg);
   case UniS64:
     return MRI.getType(Reg) == LLT::scalar(64) && MUI.isUniform(Reg);
+  case UniP1:
+    return MRI.getType(Reg) == LLT::pointer(1, 64) && MUI.isUniform(Reg);
+  case UniP3:
+    return MRI.getType(Reg) == LLT::pointer(3, 32) && MUI.isUniform(Reg);
+  case UniP4:
+    return MRI.getType(Reg) == LLT::pointer(4, 64) && MUI.isUniform(Reg);
+  case UniP5:
+    return MRI.getType(Reg) == LLT::pointer(5, 32) && MUI.isUniform(Reg);
+  case UniB32:
+    return MRI.getType(Reg).getSizeInBits() == 32 && MUI.isUniform(Reg);
+  case UniB64:
+    return MRI.getType(Reg).getSizeInBits() == 64 && MUI.isUniform(Reg);
+  case UniB96:
+    return MRI.getType(Reg).getSizeInBits() == 96 && MUI.isUniform(Reg);
+  case UniB128:
+    return MRI.getType(Reg).getSizeInBits() == 128 && MUI.isUniform(Reg);
+  case UniB256:
+    return MRI.getType(Reg).getSizeInBits() == 256 && MUI.isUniform(Reg);
+  case UniB512:
+    return MRI.getType(Reg).getSizeInBits() == 512 && MUI.isUniform(Reg);
 
   case DivS1:
     return MRI.getType(Reg) == LLT::scalar(1) && MUI.isDivergent(Reg);
@@ -65,6 +105,24 @@ bool matchUniformityAndLLT(Register Reg, UniformityLLTOpPredicateID UniID,
     return MRI.getType(Reg) == LLT::scalar(64) && MUI.isDivergent(Reg);
   case DivP1:
     return MRI.getType(Reg) == LLT::pointer(1, 64) && MUI.isDivergent(Reg);
+  case DivP3:
+    return MRI.getType(Reg) == LLT::pointer(3, 32) && MUI.isDivergent(Reg);
+  case DivP4:
+    return MRI.getType(Reg) == LLT::pointer(4, 64) && MUI.isDivergent(Reg);
+  case DivP5:
+    return MRI.getType(Reg) == LLT::pointer(5, 32) && MUI.isDivergent(Reg);
+  case DivB32:
+    return MRI.getType(Reg).getSizeInBits() == 32 && MUI.isDivergent(Reg);
+  case DivB64:
+    return MRI.getType(Reg).getSizeInBits() == 64 && MUI.isDivergent(Reg);
+  case DivB96:
+    return MRI.getType(Reg).getSizeInBits() == 96 && MUI.isDivergent(Reg);
+  case DivB128:
+    return MRI.getType(Reg).getSizeInBits() == 128 && MUI.isDivergent(Reg);
+  case DivB256:
+    return MRI.getType(Reg).getSizeInBits() == 256 && MUI.isDivergent(Reg);
+  case DivB512:
+    return MRI.getType(Reg).getSizeInBits() == 512 && MUI.isDivergent(Reg);
 
   case _:
     return true;
@@ -124,6 +182,22 @@ UniformityLLTOpPredicateID LLTToId(LLT Ty) {
   return _;
 }
 
+// Map Ty to the bit-width ID (B32, B64, B96 or B128) used to pick a StandardB
+// fast-rule slot. Any other type yields `_`, for which getFastPredicateSlot
+// returns -1 (no fast slot).
+// NOTE(review): this recognizes fewer types than getBTyFromID in
+// AMDGPURBLegalizeHelper.cpp. For example, s96, <6 x s16>, s128, <2 x s64>
+// and p0 are not classified here; confirm that is intended.
+UniformityLLTOpPredicateID LLTToBId(LLT Ty) {
+  const LLT Types32[] = {LLT::scalar(32), LLT::fixed_vector(2, 16),
+                         LLT::pointer(3, 32), LLT::pointer(5, 32),
+                         LLT::pointer(6, 32)};
+  const LLT Types64[] = {LLT::scalar(64), LLT::fixed_vector(2, 32),
+                         LLT::fixed_vector(4, 16), LLT::pointer(1, 64),
+                         LLT::pointer(4, 64)};
+
+  for (const LLT &Candidate : Types32)
+    if (Candidate == Ty)
+      return B32;
+  for (const LLT &Candidate : Types64)
+    if (Candidate == Ty)
+      return B64;
+
+  // Only one vector layout is recognized for each of the wider widths.
+  if (Ty == LLT::fixed_vector(3, 32))
+    return B96;
+  return Ty == LLT::fixed_vector(4, 32) ? B128 : _;
+}
+
 const RegBankLLTMapping &
 SetOfRulesForOpcode::findMappingForMI(const MachineInstr &MI,
                                       const MachineRegisterInfo &MRI,
@@ -134,7 +208,12 @@ SetOfRulesForOpcode::findMappingForMI(const MachineInstr &MI,
   // returned which results in failure, does not search "Slow Rules".
   if (FastTypes != No) {
     Register Reg = MI.getOperand(0).getReg();
-    int Slot = getFastPredicateSlot(LLTToId(MRI.getType(Reg)));
+    int Slot;
+    if (FastTypes == StandardB)
+      Slot = getFastPredicateSlot(LLTToBId(MRI.getType(Reg)));
+    else
+      Slot = getFastPredicateSlot(LLTToId(MRI.getType(Reg)));
+
     if (Slot != -1) {
       if (MUI.isUniform(Reg))
         return Uni[Slot];
@@ -184,6 +263,19 @@ int SetOfRulesForOpcode::getFastPredicateSlot(
     default:
       return -1;
     }
+  case StandardB:
+    switch (Ty) {
+    case B32:
+      return 0;
+    case B64:
+      return 1;
+    case B96:
+      return 2;
+    case B128:
+      return 3;
+    default:
+      return -1;
+    }
   case Vector:
     switch (Ty) {
     case S32:
@@ -236,6 +328,127 @@ RegBankLegalizeRules::getRulesForOpc(MachineInstr &MI) const {
   return GRules.at(GRulesAlias.at(Opc));
 }
 
+// Syntactic sugar wrapper for predicate lambda that enables '&&', '||' and '!'.
+class Predicate {
+public:
+  struct Elt {
+    // Save formula composed of Pred, '&&', '||' and '!' as a jump table.
+    // Sink ! to Pred. For example !((A && !B) || C) -> (!A || B) && !C
+    // Sequences of && and || will be represented by jumps, for example:
+    // (A && B && ... X) or (A && B && ... X) || Y
+    //   A == true jump to B
+    //   A == false jump to end or Y, result is A(false) or Y
+    // (A || B || ... X) or (A || B || ... X) && Y
+    //   A == true jump to end or Y, result is B(true) or Y
+    //   A == false jump B
+    // Notice that when negating expression, we apply simply flip Neg on each
+    // Pred and swap TJumpOffset and FJumpOffset (&& becomes ||, || becomes &&)....
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/112882


More information about the llvm-branch-commits mailing list