[Mlir-commits] [mlir] [MLIR][XeGPU] Extend Wg-to-Sg Distribution of Multi-Reduction Op for round-robin layout (PR #189988)

Jianhui Li llvmlistbot at llvm.org
Wed Apr 1 08:59:37 PDT 2026


https://github.com/Jianhui-Li created https://github.com/llvm/llvm-project/pull/189988

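As the title says: this extends the Wg-to-Sg distribution of the multi-reduction op to handle round-robin layouts.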

>From 19566dfb2af650beb64dfaaa2d1cc9e1c414ce18 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Wed, 1 Apr 2026 15:57:59 +0000
Subject: [PATCH] extend multi-reduction wg-to-sg distribution for round-robin
 layout

---
 .../Transforms/XeGPUWgToSgDistribute.cpp      | 138 ++++++++----------
 .../XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir    |  52 +++++++
 2 files changed, 115 insertions(+), 75 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 0aead9172858f..6d2e7514aaff8 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -1258,9 +1258,11 @@ struct WgToSgMultiDimReductionOp
     // Get sg_layout and sg_data from the parent layout
     SmallVector<int64_t> sgLayout;
     SmallVector<int64_t> sgData;
+    xegpu::DistributeLayoutAttr parentLayout;
     if (auto sliceAttr = dyn_cast<xegpu::SliceAttr>(layout)) {
-      sgLayout = sliceAttr.getParent().getEffectiveSgLayoutAsInt();
-      sgData = sliceAttr.getParent().getEffectiveSgDataAsInt();
+      parentLayout = sliceAttr.getParent();
+      sgLayout = parentLayout.getEffectiveSgLayoutAsInt();
+      sgData = parentLayout.getEffectiveSgDataAsInt();
     } else
       return rewriter.notifyMatchFailure(
           op, "Reduction should have SliceAttr layout");
@@ -1320,26 +1322,33 @@ struct WgToSgMultiDimReductionOp
       return success();
     }
 
-    // Step 2: cross-subgroup reduction using SLM
+    // Step 2: cross-subgroup reduction using SLM - allocating slm memory
     auto slmStoreDataShape = sgSrcShape;
     for (int64_t dim : reductionDims)
       slmStoreDataShape[dim] = 1;
     VectorType slmStoreDataType = VectorType::get(slmStoreDataShape, elemTy);
-    Value slmStoreData;
-    if (isScalarResult) {
-      // Scalar result: broadcast scalar to vector<1x...x1> for SLM store
-      slmStoreData = vector::BroadcastOp::create(
-          rewriter, loc, slmStoreDataType, localReductions[0]);
-    } else {
-      slmStoreData = vector::ShapeCastOp::create(
-          rewriter, loc, slmStoreDataType, localReductions[0]);
+    SmallVector<Value> slmStoreData;
+    for (auto localResult : localReductions) {
+      if (isScalarResult) {
+        // Scalar result: broadcast scalar to vector<1x...x1> for SLM store
+        slmStoreData.push_back(vector::BroadcastOp::create(
+            rewriter, loc, slmStoreDataType, localResult));
+      } else {
+        slmStoreData.push_back(vector::ShapeCastOp::create(
+            rewriter, loc, slmStoreDataType, localResult));
+      }
     }
-
+    // for reduction dimension, SLM stores partial results from each subgroup
     SmallVector<int64_t> slmShape(originalSrcShape.begin(),
                                   originalSrcShape.end());
-    // for reduction dimension, SLM stores partial results from each subgroup
-    for (int64_t dim : reductionDims)
+    SmallVector<int> slmSgData(sgData.begin(), sgData.end());
+    SmallVector<int> slmSgLayout(sgLayout.begin(), sgLayout.end());
+    for (int dim : reductionDims) {
       slmShape[dim] = sgLayout[dim];
+      slmSgData[dim] = sgLayout[dim];
+    }
+    xegpu::LayoutAttr slmStoreLayout =
+        xegpu::LayoutAttr::get(rewriter.getContext(), slmSgLayout, slmSgData);
 
     // Allocate SLM
     auto bitWidth = elemTy.getIntOrFloatBitWidth();
@@ -1353,82 +1362,61 @@ struct WgToSgMultiDimReductionOp
     auto memDesc =
         xegpu::CreateMemDescOp::create(rewriter, loc, memDescType, slm);
 
-    // if localReductions have more than 1 result, not support
-    if (localReductions.size() > 1) {
-      return rewriter.notifyMatchFailure(
-          op,
-          "Multiple local reductions not supported in current implementation.");
-    }
-
-    // Step 4: Store local results to SLM
+    // Step 3: Store local results to SLM
     auto sgId = gpu::SubgroupIdOp::create(rewriter, loc,
                                           rewriter.getIndexType(), nullptr);
 
-    // Convert sgLayout to Values for delinearizeIndex
-    SmallVector<Value> sgLayoutValues;
-    for (int64_t dim : sgLayout)
-      sgLayoutValues.push_back(
-          arith::ConstantIndexOp::create(rewriter, loc, dim));
-
-    auto sgIdsResult = affine::delinearizeIndex(rewriter, loc, sgId.getResult(),
-                                                sgLayoutValues);
-    if (failed(sgIdsResult))
+    auto slmStoreCoords =
+        slmStoreLayout.computeDistributedCoords(rewriter, loc, sgId, slmShape);
+    if (failed(slmStoreCoords))
       return failure();
-    SmallVector<Value> sgIds = *sgIdsResult;
-
-    auto getSlmOffsets = [&](int64_t reductionDimStride) {
-      SmallVector<OpFoldResult> offsets;
-      offsets.reserve(srcVecRank);
-      for (int i = 0; i < srcVecRank; ++i) {
-        Value dimVal = sgIds[i];
-        int64_t sgDataStride = (llvm::is_contained(reductionDims, i))
-                                   ? reductionDimStride
-                                   : sgSrcShape[i];
-        Value strideVal =
-            arith::ConstantIndexOp::create(rewriter, loc, sgDataStride);
-        Value offsetVal =
-            arith::MulIOp::create(rewriter, loc, dimVal, strideVal);
-        offsets.push_back(offsetVal);
-      }
-      return offsets;
-    };
-
-    SmallVector<OpFoldResult> slmStoreOffsets =
-        getSlmOffsets(/*reductionDimStride=*/1);
-
-    xegpu::StoreMatrixOp::create(rewriter, loc, slmStoreData,
-                                 memDesc.getResult(), slmStoreOffsets,
-                                 /*layout=*/nullptr);
+    for (auto [data, coord] : llvm::zip(slmStoreData, *slmStoreCoords)) {
+      SmallVector<OpFoldResult> coordOfr(coord.begin(), coord.end());
+      xegpu::StoreMatrixOp::create(rewriter, loc, data, memDesc.getResult(),
+                                   coordOfr,
+                                   /*layout=*/nullptr);
+    }
 
     gpu::BarrierOp::create(rewriter, loc);
 
-    // Step 5: Load from SLM for final reduction
+    // Step 4: Load from SLM for final reduction
     SmallVector<int64_t> slmLoadDataShape(sgSrcShape.begin(), sgSrcShape.end());
-    for (int64_t dim : reductionDims)
+    for (int64_t dim : reductionDims) {
       slmLoadDataShape[dim] = slmShape[dim];
-
-    SmallVector<OpFoldResult> slmLoadOffsets =
-        getSlmOffsets(/*reductionDimStride=*/0);
+      slmSgData[dim] = slmShape[dim];
+    }
+    xegpu::LayoutAttr slmLoadLayout =
+        xegpu::LayoutAttr::get(rewriter.getContext(), slmSgLayout, slmSgData);
+    auto slmLoadCoords =
+        slmLoadLayout.computeDistributedCoords(rewriter, loc, sgId, slmShape);
+    if (failed(slmLoadCoords))
+      return failure();
 
     VectorType slmLoadType = VectorType::get(slmLoadDataShape, elemTy);
-    auto slmLoadOp = xegpu::LoadMatrixOp::create(
-        rewriter, loc, slmLoadType, memDesc.getResult(), slmLoadOffsets,
-        /*layout=*/nullptr);
+    SmallVector<Value> slmLoadData;
+    for (auto coord : *slmLoadCoords) {
+      SmallVector<OpFoldResult> coordOfr(coord.begin(), coord.end());
+      slmLoadData.push_back(xegpu::LoadMatrixOp::create(
+          rewriter, loc, slmLoadType, memDesc.getResult(), coordOfr,
+          /*layout=*/nullptr));
+    }
 
-    // Step 6: Perform final reduction with neutral accumulator
+    // Step 5: Perform final reduction with neutral accumulator and add the
+    // original accumulator at the end
     Value neutralFinalAcc = xegpu::createReductionNeutralValue(
         rewriter, loc, sgDstType, op.getKind());
 
-    auto finalReduce = vector::MultiDimReductionOp::create(
-        rewriter, loc, sgDstType, op.getKind(), slmLoadOp.getResult(),
-        neutralFinalAcc, reductionDims);
-
-    // Step 7: Add the original accumulator at the end
-    auto finalResult = vector::makeArithReduction(rewriter, loc, op.getKind(),
-                                                  finalReduce.getResult(),
-                                                  adaptor.getAcc()[0]);
-
-    rewriter.replaceOp(op, finalResult);
+    SmallVector<Value> finalResults;
+    for (size_t i = 0; i < slmLoadData.size(); ++i) {
+      auto loaded = slmLoadData[i];
+      auto finalReduce = vector::MultiDimReductionOp::create(
+          rewriter, loc, sgDstType, op.getKind(), loaded, neutralFinalAcc,
+          reductionDims);
+      finalResults.push_back(vector::makeArithReduction(
+          rewriter, loc, op.getKind(), finalReduce.getResult(),
+          adaptor.getAcc()[i]));
+    }
+    rewriter.replaceOpWithMultiple(op, {finalResults});
     return success();
   }
 };
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
index 068dd6d865ead..4a74b66afcd6a 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
@@ -165,4 +165,56 @@ gpu.module @test_distribution {
     gpu.return
   }
 
+  // CHECK-LABEL: gpu.func @reduction_cross_sg_rr
+  gpu.func @reduction_cross_sg_rr(%arg0: memref<2048xf32, 1>) kernel {
+    // CHECK: %[[CST_OFFSETS0:.*]] = arith.constant dense<0> : vector<4x16xindex>
+    // CHECK: %[[CST_OFFSETS1:.*]] = arith.constant dense<0> : vector<4x16xindex>
+    // CHECK: %[[CST_ACC0:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
+    // CHECK: %[[CST_ACC1:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
+    // CHECK: %[[CST_MASK0:.*]] = arith.constant dense<true> : vector<4x16xi1>
+    // CHECK: %[[CST_MASK1:.*]] = arith.constant dense<true> : vector<4x16xi1>
+    //
+    // CHECK: %[[LOAD0:.*]] = xegpu.load %arg0[%[[CST_OFFSETS0]]], %[[CST_MASK0]]
+    // CHECK-SAME: -> vector<4x16xf32>
+    // CHECK: %[[LOAD1:.*]] = xegpu.load %arg0[%[[CST_OFFSETS1]]], %[[CST_MASK1]]
+    // CHECK-SAME: -> vector<4x16xf32>
+    //
+    // Local reductions
+    // CHECK: %[[NEUTRAL0:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
+    // CHECK: %[[LOCAL_RED0:.*]] = vector.multi_reduction <add>, %[[LOAD0]], %[[NEUTRAL0]] [1] : vector<4x16xf32> to vector<4xf32>
+    // CHECK: %[[NEUTRAL1:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
+    // CHECK: %[[LOCAL_RED1:.*]] = vector.multi_reduction <add>, %[[LOAD1]], %[[NEUTRAL1]] [1] : vector<4x16xf32> to vector<4xf32>
+    //
+    // Shape cast for SLM store
+    // CHECK: %[[SC0:.*]] = vector.shape_cast %[[LOCAL_RED0]] : vector<4xf32> to vector<4x1xf32>
+    // CHECK: %[[SC1:.*]] = vector.shape_cast %[[LOCAL_RED1]] : vector<4xf32> to vector<4x1xf32>
+    //
+    // SLM allocation and mem_desc
+    // CHECK: %[[SLM:.*]] = memref.alloca() : memref<512xi8, 3>
+    // CHECK: %[[MEMDESC:.*]] = xegpu.create_mem_desc %[[SLM]] : memref<512xi8, 3> -> !xegpu.mem_desc<8x16xf32>
+    //
+    // Store to SLM
+    // CHECK: xegpu.store_matrix %[[SC0]], %[[MEMDESC]]{{.*}} : vector<4x1xf32>, !xegpu.mem_desc<8x16xf32>
+    // CHECK: xegpu.store_matrix %[[SC1]], %[[MEMDESC]]{{.*}} : vector<4x1xf32>, !xegpu.mem_desc<8x16xf32>
+    // CHECK: gpu.barrier
+    //
+    // Load from SLM
+    // CHECK: %[[SLM_LOAD0:.*]] = xegpu.load_matrix %[[MEMDESC]]{{.*}} -> vector<4x16xf32>
+    // CHECK: %[[SLM_LOAD1:.*]] = xegpu.load_matrix %[[MEMDESC]]{{.*}} -> vector<4x16xf32>
+    //
+    // Final reduction
+    // CHECK: %[[FINAL_NEUTRAL:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
+    // CHECK: %[[FINAL_RED0:.*]] = vector.multi_reduction <add>, %[[SLM_LOAD0]], %[[FINAL_NEUTRAL]] [1] : vector<4x16xf32> to vector<4xf32>
+    // CHECK: %[[RES0:.*]] = arith.addf %[[FINAL_RED0]], %[[CST_ACC0]] : vector<4xf32>
+    // CHECK: %[[FINAL_RED1:.*]] = vector.multi_reduction <add>, %[[SLM_LOAD1]], %[[FINAL_NEUTRAL]] [1] : vector<4x16xf32> to vector<4xf32>
+    // CHECK: %[[RES1:.*]] = arith.addf %[[FINAL_RED1]], %[[CST_ACC1]] : vector<4xf32>
+    
+    %offset = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [1, 16], sg_data = [4, 16]>} dense<0> : vector<8x256xindex>
+    %acc = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [1, 16], sg_data = [4, 16]>, dims = [1]>} dense<0.000000e+00> : vector<8xf32>
+    %mask = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [1, 16], sg_data = [4, 16]>} dense<true> : vector<8x256xi1>
+    %val = xegpu.load %arg0[%offset], %mask <{layout = #xegpu.layout<sg_layout = [1, 16], sg_data = [4, 16]>}> : memref<2048xf32, 1>, vector<8x256xindex>, vector<8x256xi1> -> vector<8x256xf32>
+    %reduce = vector.multi_reduction <add>, %val, %acc {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [1, 16], sg_data = [4, 16]>, dims = [1]>} [1] : vector<8x256xf32> to vector<8xf32>
+    gpu.return
+  }
+
 }



More information about the Mlir-commits mailing list