[Mlir-commits] [mlir] [mlir][GPU] Fixes subgroup reduce lowering (PR #141825)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed May 28 12:04:57 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-gpu

Author: Muzammil (Muzammiluddin-Syed-ECE)

<details>
<summary>Changes</summary>

Fixes issues with subgroup reduce lowering producing erroneous results.

This aligns the lowering more closely with the Triton implementation: https://github.com/triton-lang/triton/pull/5019

---
Full diff: https://github.com/llvm/llvm-project/pull/141825.diff


1 File Affected:

- (modified) mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp (+18-11) 


``````````diff
diff --git a/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp b/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp
index 74face4291353..b7a12763057f0 100644
--- a/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp
@@ -447,29 +447,36 @@ createSubgroupDPPReduction(PatternRewriter &rewriter, gpu::SubgroupReduceOp op,
   if (ci.clusterSize >= 64) {
     if (chipset.majorVersion <= 9) {
       // Broadcast 31st lane value to rows 2 and 3.
-      // Use row mask to avoid polluting rows 0 and 1.
       dpp = rewriter.create<amdgpu::DPPOp>(
           loc, res.getType(), res, res, amdgpu::DPPPerm::row_bcast_31,
-          rewriter.getUnitAttr(), 0xc, allBanks,
-          /*bound_ctrl*/ false);
+          rewriter.getUnitAttr(), 0xf, allBanks,
+          /*bound_ctrl*/ true);
+      res = vector::makeArithReduction(
+          rewriter, loc, gpu::convertReductionKind(mode), dpp, res);
+      // Obtain reduction from last rows, the previous rows are polluted.
+      Value lane63 = rewriter.create<arith::ConstantOp>(
+          loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(63));
+      res = rewriter.create<ROCDL::ReadlaneOp>(loc, res.getType(), res, lane63);
 
     } else if (chipset.majorVersion <= 12) {
       // Assume reduction across 32 lanes has been done.
       // Perform final reduction manually by summing values in lane 0 and
       // lane 32.
-      Value lane0 = rewriter.create<arith::ConstantOp>(
-          loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0));
-      Value lane32 = rewriter.create<arith::ConstantOp>(
-          loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(32));
-      dpp = rewriter.create<ROCDL::ReadlaneOp>(loc, res.getType(), res, lane32);
-      res = rewriter.create<ROCDL::ReadlaneOp>(loc, res.getType(), res, lane0);
+      Value lane31 = rewriter.create<arith::ConstantOp>(
+          loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(31));
+      Value lane63 = rewriter.create<arith::ConstantOp>(
+          loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(63));
+      lane0 =
+          rewriter.create<ROCDL::ReadlaneOp>(loc, res.getType(), res, lane31);
+      lane32 =
+          rewriter.create<ROCDL::ReadlaneOp>(loc, res.getType(), res, lane63);
+      res = vector::makeArithReduction(
+          rewriter, loc, gpu::convertReductionKind(mode), lane31, lane63);
     } else {
       return rewriter.notifyMatchFailure(
           op, "Subgroup reduce lowering to DPP not currently supported for "
               "this device.");
     }
-    res = vector::makeArithReduction(rewriter, loc,
-                                     gpu::convertReductionKind(mode), res, dpp);
   }
   assert(res.getType() == input.getType());
   return res;

``````````

</details>


https://github.com/llvm/llvm-project/pull/141825


More information about the Mlir-commits mailing list