[llvm] [AMDGPU]: Rewrite mbcnt_lo/mbcnt_hi to work item ID where applicable (PR #160496)
Krzysztof Drewniak via llvm-commits
llvm-commits at lists.llvm.org
Wed Sep 24 11:39:53 PDT 2025
================
@@ -1312,9 +1317,122 @@ GCNTTIImpl::instCombineIntrinsic(InstCombiner &IC, IntrinsicInst &II) const {
break;
}
case Intrinsic::amdgcn_mbcnt_hi: {
- // exec_hi is all 0, so this is just a copy.
- if (ST->isWave32())
+ // exec_hi is all 0, so this is just a copy on wave32.
+ if (ST && ST->isWave32())
return IC.replaceInstUsesWith(II, II.getArgOperand(1));
+
+ // Pattern: mbcnt.hi(~0, mbcnt.lo(~0, 0))
+ if (auto *HiArg1 = dyn_cast<CallInst>(II.getArgOperand(1))) {
+ Function *CalledF = HiArg1->getCalledFunction();
+ bool IsMbcntLo = false;
+ if (CalledF) {
+ // Fast-path: if this is a declared intrinsic, check the intrinsic ID.
+ if (CalledF->getIntrinsicID() == Intrinsic::amdgcn_mbcnt_lo) {
+ IsMbcntLo = true;
+ } else {
+ // Fallback: accept a declared function with the canonical name, but
+ // verify its signature to be safe: i32(i32,i32). Use the name
+ // comparison only when there's no intrinsic ID match.
+ if (CalledF->getName() == "llvm.amdgcn.mbcnt.lo") {
+ if (FunctionType *FT = CalledF->getFunctionType()) {
+ if (FT->getNumParams() == 2 &&
+ FT->getReturnType()->isIntegerTy(32) &&
+ FT->getParamType(0)->isIntegerTy(32) &&
+ FT->getParamType(1)->isIntegerTy(32))
+ IsMbcntLo = true;
+ }
+ }
+ }
+ }
+
+ if (!IsMbcntLo)
+ break;
+
+ // hi arg0 must be all-ones
+ if (auto *HiArg0C = dyn_cast<ConstantInt>(II.getArgOperand(0))) {
+ if (!HiArg0C->isAllOnesValue())
+ break;
+ } else
+ break;
+
+ // lo args: arg0 == ~0, arg1 == 0
+ Value *Lo0 = HiArg1->getArgOperand(0);
+ Value *Lo1 = HiArg1->getArgOperand(1);
+ auto *Lo0C = dyn_cast<ConstantInt>(Lo0);
+ auto *Lo1C = dyn_cast<ConstantInt>(Lo1);
+ if (!Lo0C || !Lo1C)
+ break;
+ if (!Lo0C->isAllOnesValue() || !Lo1C->isZero())
+ break;
+
+ // Query reqd_work_group_size via subtarget helper and compare X to wave
+ // size conservatively.
+ if (Function *F = II.getFunction()) {
+ unsigned Wave = 0;
+ if (ST && ST->isWaveSizeKnown())
+ Wave = ST->getWavefrontSize();
+
+ if (ST) {
+ if (auto MaybeX = ST->getReqdWorkGroupSize(*F, 0)) {
+ unsigned XLen = *MaybeX;
+ if (Wave == 0 && (XLen == WavefrontSize32 ||
----------------
krzysz00 wrote:
How can `getWavefrontSize()` fail?
https://github.com/llvm/llvm-project/pull/160496
More information about the llvm-commits mailing list