[Mlir-commits] [mlir] [mlir][GPU] Fixes subgroup reduce lowering (PR #141825)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed May 28 12:04:22 PDT 2025


https://github.com/Muzammiluddin-Syed-ECE created https://github.com/llvm/llvm-project/pull/141825

Fixes issues with subgroup reduce lowering producing erroneous results.

This aligns the lowering more closely with the Triton implementation: https://github.com/triton-lang/triton/pull/5019

>From 38dca73ce9d7121272cc1386c57bd883b0ca666a Mon Sep 17 00:00:00 2001
From: Muzammiluddin Syed <muzasyed at amd.com>
Date: Wed, 28 May 2025 18:53:24 +0000
Subject: [PATCH] [mlir][GPU] Fixes subgroup reduce lowering

Signed-off-by: Muzammiluddin Syed <muzasyed at amd.com>
---
 .../GPU/Transforms/SubgroupReduceLowering.cpp | 29 ++++++++++++-------
 1 file changed, 18 insertions(+), 11 deletions(-)

diff --git a/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp b/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp
index 74face4291353..b7a12763057f0 100644
--- a/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp
@@ -447,29 +447,36 @@ createSubgroupDPPReduction(PatternRewriter &rewriter, gpu::SubgroupReduceOp op,
   if (ci.clusterSize >= 64) {
     if (chipset.majorVersion <= 9) {
       // Broadcast 31st lane value to rows 2 and 3.
-      // Use row mask to avoid polluting rows 0 and 1.
       dpp = rewriter.create<amdgpu::DPPOp>(
           loc, res.getType(), res, res, amdgpu::DPPPerm::row_bcast_31,
-          rewriter.getUnitAttr(), 0xc, allBanks,
-          /*bound_ctrl*/ false);
+          rewriter.getUnitAttr(), 0xf, allBanks,
+          /*bound_ctrl*/ true);
+      res = vector::makeArithReduction(
+          rewriter, loc, gpu::convertReductionKind(mode), dpp, res);
+      // Obtain reduction from last rows, the previous rows are polluted.
+      Value lane63 = rewriter.create<arith::ConstantOp>(
+          loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(63));
+      res = rewriter.create<ROCDL::ReadlaneOp>(loc, res.getType(), res, lane63);
 
     } else if (chipset.majorVersion <= 12) {
       // Assume reduction across 32 lanes has been done.
       // Perform final reduction manually by summing values in lane 0 and
       // lane 32.
-      Value lane0 = rewriter.create<arith::ConstantOp>(
-          loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0));
-      Value lane32 = rewriter.create<arith::ConstantOp>(
-          loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(32));
-      dpp = rewriter.create<ROCDL::ReadlaneOp>(loc, res.getType(), res, lane32);
-      res = rewriter.create<ROCDL::ReadlaneOp>(loc, res.getType(), res, lane0);
+      Value lane31 = rewriter.create<arith::ConstantOp>(
+          loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(31));
+      Value lane63 = rewriter.create<arith::ConstantOp>(
+          loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(63));
+      lane0 =
+          rewriter.create<ROCDL::ReadlaneOp>(loc, res.getType(), res, lane31);
+      lane32 =
+          rewriter.create<ROCDL::ReadlaneOp>(loc, res.getType(), res, lane63);
+      res = vector::makeArithReduction(
+          rewriter, loc, gpu::convertReductionKind(mode), lane31, lane63);
     } else {
       return rewriter.notifyMatchFailure(
           op, "Subgroup reduce lowering to DPP not currently supported for "
               "this device.");
     }
-    res = vector::makeArithReduction(rewriter, loc,
-                                     gpu::convertReductionKind(mode), res, dpp);
   }
   assert(res.getType() == input.getType());
   return res;



More information about the Mlir-commits mailing list