[llvm] [AMDGPU]: Rewrite mbcnt_lo/mbcnt_hi to work item ID where applicable (PR #160496)

Teja Alaghari via llvm-commits llvm-commits at lists.llvm.org
Wed Oct 29 01:52:09 PDT 2025


================
@@ -2113,6 +2119,172 @@ INITIALIZE_PASS_DEPENDENCY(UniformityInfoWrapperPass)
 INITIALIZE_PASS_END(AMDGPUCodeGenPrepare, DEBUG_TYPE, "AMDGPU IR optimizations",
                     false, false)
 
+/// On wave32 targets, replace the lane-ID idiom mbcnt.lo(~0, 0) with an
+/// equivalent computation based on llvm.amdgcn.workitem.id.x.
+///
+/// mbcnt.lo(Mask, Base) returns Base plus the number of set bits in Mask that
+/// belong to lanes below the current lane. It counts bits of the Mask operand,
+/// not of exec. With Mask = ~0 and Base = 0 the result is therefore the
+/// current lane index, whichever lanes happen to be active. That lane index
+/// can be derived from workitem.id.x when the required X work-group size is
+/// known:
+///  - X size == wave size: lane = workitem.id.x.
+///  - otherwise, if hasWavefrontsEvenlySplittingXDim (with RequiresUniformYZ)
+///    holds and the wave size is a power of two:
+///    lane = workitem.id.x & (wave size - 1).
+///
+/// \returns true if \p I was replaced and erased, false if left untouched.
+bool AMDGPUCodeGenPrepareImpl::visitMbcntLo(IntrinsicInst &I) {
+  // Only wave32 is handled here: there the low 32 lanes are the whole wave,
+  // so mbcnt.lo alone yields the full lane index.
+  if (!ST.isWave32())
+    return false;
+
+  // Only the lane-ID idiom mbcnt.lo(~0, 0) is rewritten.
+  auto *MaskC = dyn_cast<ConstantInt>(I.getArgOperand(0));
+  auto *BaseC = dyn_cast<ConstantInt>(I.getArgOperand(1));
+  if (!MaskC || !BaseC || !MaskC->isAllOnesValue() || !BaseC->isZero())
+    return false;
+
+  // Abort if wave size is not known at compile time.
+  if (!ST.isWaveSizeKnown())
+    return false;
+
+  // The X extent of the work group must be fixed (reqd_work_group_size).
+  auto MaybeX = ST.getReqdWorkGroupSize(F, 0);
+  if (!MaybeX)
+    return false;
+
+  const unsigned Wave = ST.getWavefrontSize();
+
+  // If one wave spans exactly the X dimension, workitem.id.x already is the
+  // lane index. Otherwise the rewrite needs waves to tile X evenly and a
+  // power-of-two wave size, so that the lane index is a cheap bit mask.
+  const bool XIsOneWave = *MaybeX == Wave;
+  if (!XIsOneWave &&
+      !(ST.hasWavefrontsEvenlySplittingXDim(F, /*RequiresUniformYZ=*/true) &&
+        isPowerOf2_32(Wave)))
+    return false;
+
+  IRBuilder<> B(&I);
+  CallInst *Tid = B.CreateIntrinsic(Intrinsic::amdgcn_workitem_id_x, {});
+  ST.makeLIDRangeMetadata(Tid);
+
+  // Keep the result as a plain Value: IRBuilder may hand back a folded value,
+  // and takeName/RAUW do not need an Instruction.
+  Value *LaneId = Tid;
+  if (!XIsOneWave)
+    LaneId = B.CreateAnd(Tid, ConstantInt::get(Tid->getType(), Wave - 1));
+
+  LaneId->takeName(&I);
+  I.replaceAllUsesWith(LaneId);
+  I.eraseFromParent();
+  return true;
+}
+
+/// Optimize mbcnt.hi calls for lane ID computation.
+bool AMDGPUCodeGenPrepareImpl::visitMbcntHi(IntrinsicInst &I) {
+  // On wave32, the upper 32 bits of exec are always 0, so mbcnt.hi(mask, val)
+  // always returns val unchanged.
+  if (ST.isWave32()) {
+    // Abort if wave size is not known at compile time.
+    if (!ST.isWaveSizeKnown())
+      return false;
+
+    unsigned Wave = ST.getWavefrontSize();
+
+    if (auto MaybeX = ST.getReqdWorkGroupSize(F, 0)) {
+      unsigned XLen = *MaybeX;
+
+      // Replace mbcnt.hi(mask, val) with val only when work group size matches
+      // wave size (single wave per work group).
+      if (XLen == Wave) {
+        I.replaceAllUsesWith(I.getArgOperand(1));
+        I.eraseFromParent();
+        return true;
+      }
+    }
+  }
+
+  // Optimize the complete lane ID computation pattern:
+  // mbcnt.hi(~0, mbcnt.lo(~0, 0)) which counts all active lanes with lower IDs
+  // across the full execution mask.
+  auto *HiArg1 = dyn_cast<CallInst>(I.getArgOperand(1));
+  if (!HiArg1)
+    return false;
+
+  Function *CalledF = HiArg1->getCalledFunction();
+  if (!CalledF || CalledF->getIntrinsicID() != Intrinsic::amdgcn_mbcnt_lo)
+    return false;
+
+  // mbcnt.hi mask must be all-ones (count from upper 32 bits)
+  auto *HiArg0C = dyn_cast<ConstantInt>(I.getArgOperand(0));
+  if (!HiArg0C || !HiArg0C->isAllOnesValue())
+    return false;
+
+  // mbcnt.lo mask must be all-ones (mask=~0, all lanes) and base must be 0.
+  Value *Lo0 = HiArg1->getArgOperand(0);
+  Value *Lo1 = HiArg1->getArgOperand(1);
+  auto *Lo0C = dyn_cast<ConstantInt>(Lo0);
+  auto *Lo1C = dyn_cast<ConstantInt>(Lo1);
+  if (!Lo0C || !Lo1C || !Lo0C->isAllOnesValue() || !Lo1C->isZero())
+    return false;
----------------
TejaX-Alaghari wrote:

I now use the `match` function to perform the same check.

https://github.com/llvm/llvm-project/pull/160496


More information about the llvm-commits mailing list