[llvm-branch-commits] [llvm] AMDGPU/GlobalISel: RBLegalize rules for load (PR #112882)
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Fri Oct 18 04:12:35 PDT 2024
llvmbot wrote:
@llvm/pr-subscribers-backend-amdgpu
Author: Petar Avramovic (petar-avramovic)
Changes:
Add IDs for bit widths that cover multiple LLTs: B32, B64, etc.
Add a "Predicate" wrapper class for bool predicate functions, used to
write pretty rules; predicates can be combined using &&, || and !.
Add lowering for splitting and widening loads.
Write rules for loads so that existing MIR tests from the old
regbankselect do not change.
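
The Predicate class body itself lands past the 20.00 KiB truncation point
below; the comment that survives describes a jump-table encoding of the
formula. Purely as a hedged illustration of the surface syntax the summary
describes, here is a minimal std::function-based combinator sketch; it is
not the jump-table implementation this patch actually adds:

```cpp
#include <functional>
#include <utility>

struct MachineInstr; // stand-in forward declaration for llvm::MachineInstr

// Hypothetical sketch only: the real class in this patch encodes the
// formula as a jump table (see the truncated comment in
// AMDGPURBLegalizeRules.cpp below).
class Predicate {
  std::function<bool(const MachineInstr &)> Fn;

public:
  explicit Predicate(std::function<bool(const MachineInstr &)> F)
      : Fn(std::move(F)) {}

  bool operator()(const MachineInstr &MI) const { return Fn(MI); }

  // Each operator captures copies of both operands and returns a new
  // Predicate, so rule guards compose like boolean formulas. Overloaded
  // '&&'/'||' cannot short-circuit while *building* the formula, but
  // per-MI evaluation inside the lambdas still short-circuits.
  Predicate operator&&(const Predicate &RHS) const {
    return Predicate([LHS = *this, RHS](const MachineInstr &MI) {
      return LHS(MI) && RHS(MI);
    });
  }
  Predicate operator||(const Predicate &RHS) const {
    return Predicate([LHS = *this, RHS](const MachineInstr &MI) {
      return LHS(MI) || RHS(MI);
    });
  }
  Predicate operator!() const {
    return Predicate([P = *this](const MachineInstr &MI) { return !P(MI); });
  }
};
```

With leaf predicates wrapping the real queries, a rule guard could then
read like `isAlign4 && (isUniformMMO || !isAtomic)`; those names are made
up here for illustration.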
---
Patch is 81.23 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/112882.diff
6 Files Affected:
- (modified) llvm/lib/Target/AMDGPU/AMDGPURBLegalizeHelper.cpp (+297-5)
- (modified) llvm/lib/Target/AMDGPU/AMDGPURBLegalizeHelper.h (+4-3)
- (modified) llvm/lib/Target/AMDGPU/AMDGPURBLegalizeRules.cpp (+300-7)
- (modified) llvm/lib/Target/AMDGPU/AMDGPURBLegalizeRules.h (+63-2)
- (modified) llvm/test/CodeGen/AMDGPU/GlobalISel/regbankselect-load.mir (+271-49)
- (modified) llvm/test/CodeGen/AMDGPU/GlobalISel/regbankselect-zextload.mir (+7-2)
``````````diff
diff --git a/llvm/lib/Target/AMDGPU/AMDGPURBLegalizeHelper.cpp b/llvm/lib/Target/AMDGPU/AMDGPURBLegalizeHelper.cpp
index a0f6ecedab7a83..f58f0a315096d2 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPURBLegalizeHelper.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPURBLegalizeHelper.cpp
@@ -37,6 +37,97 @@ bool RegBankLegalizeHelper::findRuleAndApplyMapping(MachineInstr &MI) {
return true;
}
+void RegBankLegalizeHelper::splitLoad(MachineInstr &MI,
+ ArrayRef<LLT> LLTBreakdown, LLT MergeTy) {
+ MachineFunction &MF = B.getMF();
+ assert(MI.getNumMemOperands() == 1);
+ MachineMemOperand &BaseMMO = **MI.memoperands_begin();
+ Register Dst = MI.getOperand(0).getReg();
+ const RegisterBank *DstRB = MRI.getRegBankOrNull(Dst);
+ Register BasePtrReg = MI.getOperand(1).getReg();
+ LLT PtrTy = MRI.getType(BasePtrReg);
+ const RegisterBank *PtrRB = MRI.getRegBankOrNull(BasePtrReg);
+ LLT OffsetTy = LLT::scalar(PtrTy.getSizeInBits());
+ SmallVector<Register, 4> LoadPartRegs;
+
+ unsigned ByteOffset = 0;
+ for (LLT PartTy : LLTBreakdown) {
+ Register BasePtrPlusOffsetReg;
+ if (ByteOffset == 0) {
+ BasePtrPlusOffsetReg = BasePtrReg;
+ } else {
+ BasePtrPlusOffsetReg = MRI.createVirtualRegister({PtrRB, PtrTy});
+ Register OffsetReg = MRI.createVirtualRegister({PtrRB, OffsetTy});
+ B.buildConstant(OffsetReg, ByteOffset);
+ B.buildPtrAdd(BasePtrPlusOffsetReg, BasePtrReg, OffsetReg);
+ }
+ MachineMemOperand *BasePtrPlusOffsetMMO =
+ MF.getMachineMemOperand(&BaseMMO, ByteOffset, PartTy);
+ Register PartLoad = MRI.createVirtualRegister({DstRB, PartTy});
+ B.buildLoad(PartLoad, BasePtrPlusOffsetReg, *BasePtrPlusOffsetMMO);
+ LoadPartRegs.push_back(PartLoad);
+ ByteOffset += PartTy.getSizeInBytes();
+ }
+
+ if (!MergeTy.isValid()) {
+ // Loads are all the same size; concat or merge them together.
+ B.buildMergeLikeInstr(Dst, LoadPartRegs);
+ } else {
+ // Loads are not all the same size; unmerge them into smaller pieces
+ // of MergeTy type, then merge them all together into Dst.
+ SmallVector<Register, 4> MergeTyParts;
+ for (Register Reg : LoadPartRegs) {
+ if (MRI.getType(Reg) == MergeTy) {
+ MergeTyParts.push_back(Reg);
+ } else {
+ auto Unmerge = B.buildUnmerge(MergeTy, Reg);
+ for (unsigned i = 0; i < Unmerge->getNumOperands() - 1; ++i) {
+ Register UnmergeReg = Unmerge->getOperand(i).getReg();
+ MRI.setRegBank(UnmergeReg, *DstRB);
+ MergeTyParts.push_back(UnmergeReg);
+ }
+ }
+ }
+ B.buildMergeLikeInstr(Dst, MergeTyParts);
+ }
+ MI.eraseFromParent();
+}
+
+void RegBankLegalizeHelper::widenLoad(MachineInstr &MI, LLT WideTy,
+ LLT MergeTy) {
+ MachineFunction &MF = B.getMF();
+ assert(MI.getNumMemOperands() == 1);
+ MachineMemOperand &BaseMMO = **MI.memoperands_begin();
+ Register Dst = MI.getOperand(0).getReg();
+ const RegisterBank *DstRB = MRI.getRegBankOrNull(Dst);
+ Register BasePtrReg = MI.getOperand(1).getReg();
+
+ Register BasePtrPlusOffsetReg = BasePtrReg;
+
+ MachineMemOperand *BasePtrPlusOffsetMMO =
+ MF.getMachineMemOperand(&BaseMMO, 0, WideTy);
+ Register WideLoad = MRI.createVirtualRegister({DstRB, WideTy});
+ B.buildLoad(WideLoad, BasePtrPlusOffsetReg, *BasePtrPlusOffsetMMO);
+
+ if (WideTy.isScalar()) {
+ B.buildTrunc(Dst, WideLoad);
+ } else {
+ SmallVector<Register, 4> MergeTyParts;
+ unsigned NumEltsMerge =
+ MRI.getType(Dst).getSizeInBits() / MergeTy.getSizeInBits();
+ auto Unmerge = B.buildUnmerge(MergeTy, WideLoad);
+ for (unsigned i = 0; i < Unmerge->getNumOperands() - 1; ++i) {
+ Register UnmergeReg = Unmerge->getOperand(i).getReg();
+ MRI.setRegBank(UnmergeReg, *DstRB);
+ if (i < NumEltsMerge)
+ MergeTyParts.push_back(UnmergeReg);
+ }
+ B.buildMergeLikeInstr(Dst, MergeTyParts);
+ }
+ MI.eraseFromParent();
+}
+
void RegBankLegalizeHelper::lower(MachineInstr &MI,
const RegBankLLTMapping &Mapping,
SmallSet<Register, 4> &WaterfallSGPRs) {
@@ -119,6 +210,53 @@ void RegBankLegalizeHelper::lower(MachineInstr &MI,
MI.eraseFromParent();
break;
}
+ case SplitLoad: {
+ LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
+ LLT V8S16 = LLT::fixed_vector(8, S16);
+ LLT V4S32 = LLT::fixed_vector(4, S32);
+ LLT V2S64 = LLT::fixed_vector(2, S64);
+
+ if (DstTy == LLT::fixed_vector(8, S32))
+ splitLoad(MI, {V4S32, V4S32});
+ else if (DstTy == LLT::fixed_vector(16, S32))
+ splitLoad(MI, {V4S32, V4S32, V4S32, V4S32});
+ else if (DstTy == LLT::fixed_vector(4, S64))
+ splitLoad(MI, {V2S64, V2S64});
+ else if (DstTy == LLT::fixed_vector(8, S64))
+ splitLoad(MI, {V2S64, V2S64, V2S64, V2S64});
+ else if (DstTy == LLT::fixed_vector(16, S16))
+ splitLoad(MI, {V8S16, V8S16});
+ else if (DstTy == LLT::fixed_vector(32, S16))
+ splitLoad(MI, {V8S16, V8S16, V8S16, V8S16});
+ else if (DstTy == LLT::scalar(256))
+ splitLoad(MI, {LLT::scalar(128), LLT::scalar(128)});
+ else if (DstTy == LLT::scalar(96))
+ splitLoad(MI, {S64, S32}, S32);
+ else if (DstTy == LLT::fixed_vector(3, S32))
+ splitLoad(MI, {LLT::fixed_vector(2, S32), S32}, S32);
+ else if (DstTy == LLT::fixed_vector(6, S16))
+ splitLoad(MI, {LLT::fixed_vector(4, S16), LLT::fixed_vector(2, S16)},
+ LLT::fixed_vector(2, S16));
+ else {
+ MI.dump();
+ llvm_unreachable("SplitLoad type not supported\n");
+ }
+ break;
+ }
+ case WidenLoad: {
+ LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
+ if (DstTy == LLT::scalar(96))
+ widenLoad(MI, LLT::scalar(128));
+ else if (DstTy == LLT::fixed_vector(3, S32))
+ widenLoad(MI, LLT::fixed_vector(4, S32), S32);
+ else if (DstTy == LLT::fixed_vector(6, S16))
+ widenLoad(MI, LLT::fixed_vector(8, S16), LLT::fixed_vector(2, S16));
+ else {
+ MI.dump();
+ llvm_unreachable("WidenLoad type not supported\n");
+ }
+ break;
+ }
}
// TODO: executeInWaterfallLoop(... WaterfallSGPRs)
@@ -142,13 +280,73 @@ LLT RegBankLegalizeHelper::getTyFromID(RegBankLLTMapingApplyID ID) {
case Sgpr64:
case Vgpr64:
return LLT::scalar(64);
-
+ case SgprP1:
+ case VgprP1:
+ return LLT::pointer(1, 64);
+ case SgprP3:
+ case VgprP3:
+ return LLT::pointer(3, 32);
+ case SgprP4:
+ case VgprP4:
+ return LLT::pointer(4, 64);
+ case SgprP5:
+ case VgprP5:
+ return LLT::pointer(5, 32);
case SgprV4S32:
case VgprV4S32:
case UniInVgprV4S32:
return LLT::fixed_vector(4, 32);
- case VgprP1:
- return LLT::pointer(1, 64);
+ default:
+ return LLT();
+ }
+}
+
+LLT RegBankLegalizeHelper::getBTyFromID(RegBankLLTMapingApplyID ID, LLT Ty) {
+ switch (ID) {
+ case SgprB32:
+ case VgprB32:
+ case UniInVgprB32:
+ if (Ty == LLT::scalar(32) || Ty == LLT::fixed_vector(2, 16) ||
+ Ty == LLT::pointer(3, 32) || Ty == LLT::pointer(5, 32) ||
+ Ty == LLT::pointer(6, 32))
+ return Ty;
+ return LLT();
+ case SgprB64:
+ case VgprB64:
+ case UniInVgprB64:
+ if (Ty == LLT::scalar(64) || Ty == LLT::fixed_vector(2, 32) ||
+ Ty == LLT::fixed_vector(4, 16) || Ty == LLT::pointer(0, 64) ||
+ Ty == LLT::pointer(1, 64) || Ty == LLT::pointer(4, 64))
+ return Ty;
+ return LLT();
+ case SgprB96:
+ case VgprB96:
+ case UniInVgprB96:
+ if (Ty == LLT::scalar(96) || Ty == LLT::fixed_vector(3, 32) ||
+ Ty == LLT::fixed_vector(6, 16))
+ return Ty;
+ return LLT();
+ case SgprB128:
+ case VgprB128:
+ case UniInVgprB128:
+ if (Ty == LLT::scalar(128) || Ty == LLT::fixed_vector(4, 32) ||
+ Ty == LLT::fixed_vector(2, 64))
+ return Ty;
+ return LLT();
+ case SgprB256:
+ case VgprB256:
+ case UniInVgprB256:
+ if (Ty == LLT::scalar(256) || Ty == LLT::fixed_vector(8, 32) ||
+ Ty == LLT::fixed_vector(4, 64) || Ty == LLT::fixed_vector(16, 16))
+ return Ty;
+ return LLT();
+ case SgprB512:
+ case VgprB512:
+ case UniInVgprB512:
+ if (Ty == LLT::scalar(512) || Ty == LLT::fixed_vector(16, 32) ||
+ Ty == LLT::fixed_vector(8, 64))
+ return Ty;
+ return LLT();
default:
return LLT();
}
@@ -163,10 +361,26 @@ RegBankLegalizeHelper::getRBFromID(RegBankLLTMapingApplyID ID) {
case Sgpr16:
case Sgpr32:
case Sgpr64:
+ case SgprP1:
+ case SgprP3:
+ case SgprP4:
+ case SgprP5:
case SgprV4S32:
+ case SgprB32:
+ case SgprB64:
+ case SgprB96:
+ case SgprB128:
+ case SgprB256:
+ case SgprB512:
case UniInVcc:
case UniInVgprS32:
case UniInVgprV4S32:
+ case UniInVgprB32:
+ case UniInVgprB64:
+ case UniInVgprB96:
+ case UniInVgprB128:
+ case UniInVgprB256:
+ case UniInVgprB512:
case Sgpr32Trunc:
case Sgpr32AExt:
case Sgpr32AExtBoolInReg:
@@ -176,7 +390,16 @@ RegBankLegalizeHelper::getRBFromID(RegBankLLTMapingApplyID ID) {
case Vgpr32:
case Vgpr64:
case VgprP1:
+ case VgprP3:
+ case VgprP4:
+ case VgprP5:
case VgprV4S32:
+ case VgprB32:
+ case VgprB64:
+ case VgprB96:
+ case VgprB128:
+ case VgprB256:
+ case VgprB512:
return VgprRB;
default:
@@ -202,17 +425,42 @@ void RegBankLegalizeHelper::applyMappingDst(
case Sgpr16:
case Sgpr32:
case Sgpr64:
+ case SgprP1:
+ case SgprP3:
+ case SgprP4:
+ case SgprP5:
case SgprV4S32:
case Vgpr32:
case Vgpr64:
case VgprP1:
+ case VgprP3:
+ case VgprP4:
+ case VgprP5:
case VgprV4S32: {
assert(Ty == getTyFromID(MethodIDs[OpIdx]));
assert(RB == getRBFromID(MethodIDs[OpIdx]));
break;
}
- // uniform in vcc/vgpr: scalars and vectors
+ // sgpr and vgpr B-types
+ case SgprB32:
+ case SgprB64:
+ case SgprB96:
+ case SgprB128:
+ case SgprB256:
+ case SgprB512:
+ case VgprB32:
+ case VgprB64:
+ case VgprB96:
+ case VgprB128:
+ case VgprB256:
+ case VgprB512: {
+ assert(Ty == getBTyFromID(MethodIDs[OpIdx], Ty));
+ assert(RB == getRBFromID(MethodIDs[OpIdx]));
+ break;
+ }
+
+ // uniform in vcc/vgpr: scalars, vectors and B-types
case UniInVcc: {
assert(Ty == S1);
assert(RB == SgprRB);
@@ -229,6 +477,17 @@ void RegBankLegalizeHelper::applyMappingDst(
AMDGPU::buildReadAnyLaneDst(B, MI, RBI);
break;
}
+ case UniInVgprB32:
+ case UniInVgprB64:
+ case UniInVgprB96:
+ case UniInVgprB128:
+ case UniInVgprB256:
+ case UniInVgprB512: {
+ assert(Ty == getBTyFromID(MethodIDs[OpIdx], Ty));
+ assert(RB == SgprRB);
+ AMDGPU::buildReadAnyLaneDst(B, MI, RBI);
+ break;
+ }
// sgpr trunc
case Sgpr32Trunc: {
@@ -279,16 +538,34 @@ void RegBankLegalizeHelper::applyMappingSrc(
case Sgpr16:
case Sgpr32:
case Sgpr64:
+ case SgprP1:
+ case SgprP3:
+ case SgprP4:
+ case SgprP5:
case SgprV4S32: {
assert(Ty == getTyFromID(MethodIDs[i]));
assert(RB == getRBFromID(MethodIDs[i]));
break;
}
+ // sgpr B-types
+ case SgprB32:
+ case SgprB64:
+ case SgprB96:
+ case SgprB128:
+ case SgprB256:
+ case SgprB512: {
+ assert(Ty == getBTyFromID(MethodIDs[i], Ty));
+ assert(RB == getRBFromID(MethodIDs[i]));
+ break;
+ }
// vgpr scalars, pointers and vectors
case Vgpr32:
case Vgpr64:
case VgprP1:
+ case VgprP3:
+ case VgprP4:
+ case VgprP5:
case VgprV4S32: {
assert(Ty == getTyFromID(MethodIDs[i]));
if (RB != VgprRB) {
@@ -298,6 +575,21 @@ void RegBankLegalizeHelper::applyMappingSrc(
}
break;
}
+ // vgpr B-types
+ case VgprB32:
+ case VgprB64:
+ case VgprB96:
+ case VgprB128:
+ case VgprB256:
+ case VgprB512: {
+ assert(Ty == getBTyFromID(MethodIDs[i], Ty));
+ if (RB != VgprRB) {
+ auto CopyToVgpr =
+ B.buildCopy(createVgpr(getBTyFromID(MethodIDs[i], Ty)), Reg);
+ Op.setReg(CopyToVgpr.getReg(0));
+ }
+ break;
+ }
// sgpr and vgpr scalars with extend
case Sgpr32AExt: {
@@ -372,7 +664,7 @@ void RegBankLegalizeHelper::applyMappingPHI(MachineInstr &MI) {
// We accept all types that can fit in some register class.
// Uniform G_PHIs have all sgpr registers.
// Divergent G_PHIs have vgpr dst but inputs can be sgpr or vgpr.
- if (Ty == LLT::scalar(32)) {
+ if (Ty == LLT::scalar(32) || Ty == LLT::pointer(4, 64)) {
return;
}
diff --git a/llvm/lib/Target/AMDGPU/AMDGPURBLegalizeHelper.h b/llvm/lib/Target/AMDGPU/AMDGPURBLegalizeHelper.h
index e23dfcebe3fe3f..c409df54519c5c 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPURBLegalizeHelper.h
+++ b/llvm/lib/Target/AMDGPU/AMDGPURBLegalizeHelper.h
@@ -92,6 +92,7 @@ class RegBankLegalizeHelper {
SmallSet<Register, 4> &SGPROperandRegs);
LLT getTyFromID(RegBankLLTMapingApplyID ID);
+ LLT getBTyFromID(RegBankLLTMapingApplyID ID, LLT Ty);
const RegisterBank *getRBFromID(RegBankLLTMapingApplyID ID);
@@ -104,9 +105,9 @@ class RegBankLegalizeHelper {
const SmallVectorImpl<RegBankLLTMapingApplyID> &MethodIDs,
SmallSet<Register, 4> &SGPRWaterfallOperandRegs);
- unsigned setBufferOffsets(MachineIRBuilder &B, Register CombinedOffset,
- Register &VOffsetReg, Register &SOffsetReg,
- int64_t &InstOffsetVal, Align Alignment);
+ void splitLoad(MachineInstr &MI, ArrayRef<LLT> LLTBreakdown,
+ LLT MergeTy = LLT());
+ void widenLoad(MachineInstr &MI, LLT WideTy, LLT MergeTy = LLT());
void lower(MachineInstr &MI, const RegBankLLTMapping &Mapping,
SmallSet<Register, 4> &SGPRWaterfallOperandRegs);
diff --git a/llvm/lib/Target/AMDGPU/AMDGPURBLegalizeRules.cpp b/llvm/lib/Target/AMDGPU/AMDGPURBLegalizeRules.cpp
index 1266f99c79c395..895a596cf84f40 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPURBLegalizeRules.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPURBLegalizeRules.cpp
@@ -14,9 +14,11 @@
//===----------------------------------------------------------------------===//
#include "AMDGPURBLegalizeRules.h"
+#include "AMDGPUInstrInfo.h"
#include "GCNSubtarget.h"
#include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
#include "llvm/IR/IntrinsicsAMDGPU.h"
+#include "llvm/Support/AMDGPUAddrSpace.h"
using namespace llvm;
using namespace AMDGPU;
@@ -47,6 +49,24 @@ bool matchUniformityAndLLT(Register Reg, UniformityLLTOpPredicateID UniID,
return MRI.getType(Reg) == LLT::scalar(64);
case P1:
return MRI.getType(Reg) == LLT::pointer(1, 64);
+ case P3:
+ return MRI.getType(Reg) == LLT::pointer(3, 32);
+ case P4:
+ return MRI.getType(Reg) == LLT::pointer(4, 64);
+ case P5:
+ return MRI.getType(Reg) == LLT::pointer(5, 32);
+ case B32:
+ return MRI.getType(Reg).getSizeInBits() == 32;
+ case B64:
+ return MRI.getType(Reg).getSizeInBits() == 64;
+ case B96:
+ return MRI.getType(Reg).getSizeInBits() == 96;
+ case B128:
+ return MRI.getType(Reg).getSizeInBits() == 128;
+ case B256:
+ return MRI.getType(Reg).getSizeInBits() == 256;
+ case B512:
+ return MRI.getType(Reg).getSizeInBits() == 512;
case UniS1:
return MRI.getType(Reg) == LLT::scalar(1) && MUI.isUniform(Reg);
@@ -56,6 +76,26 @@ bool matchUniformityAndLLT(Register Reg, UniformityLLTOpPredicateID UniID,
return MRI.getType(Reg) == LLT::scalar(32) && MUI.isUniform(Reg);
case UniS64:
return MRI.getType(Reg) == LLT::scalar(64) && MUI.isUniform(Reg);
+ case UniP1:
+ return MRI.getType(Reg) == LLT::pointer(1, 64) && MUI.isUniform(Reg);
+ case UniP3:
+ return MRI.getType(Reg) == LLT::pointer(3, 32) && MUI.isUniform(Reg);
+ case UniP4:
+ return MRI.getType(Reg) == LLT::pointer(4, 64) && MUI.isUniform(Reg);
+ case UniP5:
+ return MRI.getType(Reg) == LLT::pointer(5, 32) && MUI.isUniform(Reg);
+ case UniB32:
+ return MRI.getType(Reg).getSizeInBits() == 32 && MUI.isUniform(Reg);
+ case UniB64:
+ return MRI.getType(Reg).getSizeInBits() == 64 && MUI.isUniform(Reg);
+ case UniB96:
+ return MRI.getType(Reg).getSizeInBits() == 96 && MUI.isUniform(Reg);
+ case UniB128:
+ return MRI.getType(Reg).getSizeInBits() == 128 && MUI.isUniform(Reg);
+ case UniB256:
+ return MRI.getType(Reg).getSizeInBits() == 256 && MUI.isUniform(Reg);
+ case UniB512:
+ return MRI.getType(Reg).getSizeInBits() == 512 && MUI.isUniform(Reg);
case DivS1:
return MRI.getType(Reg) == LLT::scalar(1) && MUI.isDivergent(Reg);
@@ -65,6 +105,24 @@ bool matchUniformityAndLLT(Register Reg, UniformityLLTOpPredicateID UniID,
return MRI.getType(Reg) == LLT::scalar(64) && MUI.isDivergent(Reg);
case DivP1:
return MRI.getType(Reg) == LLT::pointer(1, 64) && MUI.isDivergent(Reg);
+ case DivP3:
+ return MRI.getType(Reg) == LLT::pointer(3, 32) && MUI.isDivergent(Reg);
+ case DivP4:
+ return MRI.getType(Reg) == LLT::pointer(4, 64) && MUI.isDivergent(Reg);
+ case DivP5:
+ return MRI.getType(Reg) == LLT::pointer(5, 32) && MUI.isDivergent(Reg);
+ case DivB32:
+ return MRI.getType(Reg).getSizeInBits() == 32 && MUI.isDivergent(Reg);
+ case DivB64:
+ return MRI.getType(Reg).getSizeInBits() == 64 && MUI.isDivergent(Reg);
+ case DivB96:
+ return MRI.getType(Reg).getSizeInBits() == 96 && MUI.isDivergent(Reg);
+ case DivB128:
+ return MRI.getType(Reg).getSizeInBits() == 128 && MUI.isDivergent(Reg);
+ case DivB256:
+ return MRI.getType(Reg).getSizeInBits() == 256 && MUI.isDivergent(Reg);
+ case DivB512:
+ return MRI.getType(Reg).getSizeInBits() == 512 && MUI.isDivergent(Reg);
case _:
return true;
@@ -124,6 +182,22 @@ UniformityLLTOpPredicateID LLTToId(LLT Ty) {
return _;
}
+UniformityLLTOpPredicateID LLTToBId(LLT Ty) {
+ if (Ty == LLT::scalar(32) || Ty == LLT::fixed_vector(2, 16) ||
+ Ty == LLT::pointer(3, 32) || Ty == LLT::pointer(5, 32) ||
+ Ty == LLT::pointer(6, 32))
+ return B32;
+ if (Ty == LLT::scalar(64) || Ty == LLT::fixed_vector(2, 32) ||
+ Ty == LLT::fixed_vector(4, 16) || Ty == LLT::pointer(1, 64) ||
+ Ty == LLT::pointer(4, 64))
+ return B64;
+ if (Ty == LLT::fixed_vector(3, 32))
+ return B96;
+ if (Ty == LLT::fixed_vector(4, 32))
+ return B128;
+ return _;
+}
+
const RegBankLLTMapping &
SetOfRulesForOpcode::findMappingForMI(const MachineInstr &MI,
const MachineRegisterInfo &MRI,
@@ -134,7 +208,12 @@ SetOfRulesForOpcode::findMappingForMI(const MachineInstr &MI,
// returned which results in failure, does not search "Slow Rules".
if (FastTypes != No) {
Register Reg = MI.getOperand(0).getReg();
- int Slot = getFastPredicateSlot(LLTToId(MRI.getType(Reg)));
+ int Slot;
+ if (FastTypes == StandardB)
+ Slot = getFastPredicateSlot(LLTToBId(MRI.getType(Reg)));
+ else
+ Slot = getFastPredicateSlot(LLTToId(MRI.getType(Reg)));
+
if (Slot != -1) {
if (MUI.isUniform(Reg))
return Uni[Slot];
@@ -184,6 +263,19 @@ int SetOfRulesForOpcode::getFastPredicateSlot(
default:
return -1;
}
+ case StandardB:
+ switch (Ty) {
+ case B32:
+ return 0;
+ case B64:
+ return 1;
+ case B96:
+ return 2;
+ case B128:
+ return 3;
+ default:
+ return -1;
+ }
case Vector:
switch (Ty) {
case S32:
@@ -236,6 +328,127 @@ RegBankLegalizeRules::getRulesForOpc(MachineInstr &MI) const {
return GRules.at(GRulesAlias.at(Opc));
}
+// Syntactic sugar wrapper for predicate lambdas that enables '&&', '||' and '!'.
+class Predicate {
+public:
+ struct Elt {
+ // Save formula composed of Pred, '&&', '||' and '!' as a jump table.
+ // Sink ! to Pred. For example !((A && !B) || C) -> (!A || B) && !C
+ // Sequences of && and || will be represented by jumps, for example:
+ // (A && B && ... X) or (A && B && ... X) || Y
+ // A == true jump to B
+ // A == false jump to end or Y, result is A(false) or Y
+ // (A || B || ... X) or (A || B || ... X) && Y
+ // A == true jump to end or Y, result is A(true) or Y
+ // A == false jump to B
+ // Notice that when negating an expression, we simply flip Neg on each
+ // Pred and swap TJumpOffset and FJumpOffset (&& becomes ||, || becomes &&)....
[truncated]
``````````
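
For readers of the truncated splitLoad hunk above: its core is a byte-offset
walk over the requested LLT breakdown. A self-contained toy of just that
arithmetic, with toy bit widths standing in for real LLTs and none of the
helper's actual signature:

```cpp
#include <cstdio>
#include <vector>

int main() {
  // Hypothetical breakdown for an s96 load: {s64, s32}, matching the
  // SplitLoad lowering in the diff above.
  std::vector<unsigned> PartSizeBits = {64, 32};
  unsigned ByteOffset = 0;
  for (unsigned Bits : PartSizeBits) {
    // At ByteOffset == 0 the helper reuses the base pointer; for later
    // parts it emits G_CONSTANT + G_PTR_ADD and narrows the machine
    // memory operand to the part type at this offset.
    std::printf("part load: %u bits at base + %u bytes\n", Bits, ByteOffset);
    ByteOffset += Bits / 8;
  }
  return 0;
}
```

For the s96 case ({S64, S32} with S32 MergeTy), the loaded parts are then
unmerged into s32 pieces and re-merged into the s96 Dst, since an s64 and
an s32 cannot be merged into an s96 directly.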
https://github.com/llvm/llvm-project/pull/112882