[Mlir-commits] [mlir] [MLIR][XeGPU] Extend Wg-to-Sg Distribution of Multi-Reduction Op for round-robin layout (PR #189988)
Jianhui Li
llvmlistbot at llvm.org
Wed Apr 1 08:59:37 PDT 2026
https://github.com/Jianhui-Li created https://github.com/llvm/llvm-project/pull/189988
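As the title says: this extends the workgroup-to-subgroup distribution of the multi-reduction op to handle round-robin layouts, where a subgroup may hold more than one local reduction result.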
>From 19566dfb2af650beb64dfaaa2d1cc9e1c414ce18 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Wed, 1 Apr 2026 15:57:59 +0000
Subject: [PATCH] extend multi-reduction wg-to-sg distribution for round-robin
layout
---
.../Transforms/XeGPUWgToSgDistribute.cpp | 138 ++++++++----------
.../XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir | 52 +++++++
2 files changed, 115 insertions(+), 75 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 0aead9172858f..6d2e7514aaff8 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -1258,9 +1258,11 @@ struct WgToSgMultiDimReductionOp
// Get sg_layout and sg_data from the parent layout
SmallVector<int64_t> sgLayout;
SmallVector<int64_t> sgData;
+ xegpu::DistributeLayoutAttr parentLayout;
if (auto sliceAttr = dyn_cast<xegpu::SliceAttr>(layout)) {
- sgLayout = sliceAttr.getParent().getEffectiveSgLayoutAsInt();
- sgData = sliceAttr.getParent().getEffectiveSgDataAsInt();
+ parentLayout = sliceAttr.getParent();
+ sgLayout = parentLayout.getEffectiveSgLayoutAsInt();
+ sgData = parentLayout.getEffectiveSgDataAsInt();
} else
return rewriter.notifyMatchFailure(
op, "Reduction should have SliceAttr layout");
@@ -1320,26 +1322,33 @@ struct WgToSgMultiDimReductionOp
return success();
}
- // Step 2: cross-subgroup reduction using SLM
+ // Step 2: cross-subgroup reduction using SLM - allocating slm memory
auto slmStoreDataShape = sgSrcShape;
for (int64_t dim : reductionDims)
slmStoreDataShape[dim] = 1;
VectorType slmStoreDataType = VectorType::get(slmStoreDataShape, elemTy);
- Value slmStoreData;
- if (isScalarResult) {
- // Scalar result: broadcast scalar to vector<1x...x1> for SLM store
- slmStoreData = vector::BroadcastOp::create(
- rewriter, loc, slmStoreDataType, localReductions[0]);
- } else {
- slmStoreData = vector::ShapeCastOp::create(
- rewriter, loc, slmStoreDataType, localReductions[0]);
+ SmallVector<Value> slmStoreData;
+ for (auto localResult : localReductions) {
+ if (isScalarResult) {
+ // Scalar result: broadcast scalar to vector<1x...x1> for SLM store
+ slmStoreData.push_back(vector::BroadcastOp::create(
+ rewriter, loc, slmStoreDataType, localResult));
+ } else {
+ slmStoreData.push_back(vector::ShapeCastOp::create(
+ rewriter, loc, slmStoreDataType, localResult));
+ }
}
-
+ // for reduction dimension, SLM stores partial results from each subgroup
SmallVector<int64_t> slmShape(originalSrcShape.begin(),
originalSrcShape.end());
- // for reduction dimension, SLM stores partial results from each subgroup
- for (int64_t dim : reductionDims)
+ SmallVector<int> slmSgData(sgData.begin(), sgData.end());
+ SmallVector<int> slmSgLayout(sgLayout.begin(), sgLayout.end());
+ for (int dim : reductionDims) {
slmShape[dim] = sgLayout[dim];
+ slmSgData[dim] = sgLayout[dim];
+ }
+ xegpu::LayoutAttr slmStoreLayout =
+ xegpu::LayoutAttr::get(rewriter.getContext(), slmSgLayout, slmSgData);
// Allocate SLM
auto bitWidth = elemTy.getIntOrFloatBitWidth();
@@ -1353,82 +1362,61 @@ struct WgToSgMultiDimReductionOp
auto memDesc =
xegpu::CreateMemDescOp::create(rewriter, loc, memDescType, slm);
- // if localReductions have more than 1 result, not support
- if (localReductions.size() > 1) {
- return rewriter.notifyMatchFailure(
- op,
- "Multiple local reductions not supported in current implementation.");
- }
-
- // Step 4: Store local results to SLM
+ // Step 3: Store local results to SLM
auto sgId = gpu::SubgroupIdOp::create(rewriter, loc,
rewriter.getIndexType(), nullptr);
- // Convert sgLayout to Values for delinearizeIndex
- SmallVector<Value> sgLayoutValues;
- for (int64_t dim : sgLayout)
- sgLayoutValues.push_back(
- arith::ConstantIndexOp::create(rewriter, loc, dim));
-
- auto sgIdsResult = affine::delinearizeIndex(rewriter, loc, sgId.getResult(),
- sgLayoutValues);
- if (failed(sgIdsResult))
+ auto slmStoreCoords =
+ slmStoreLayout.computeDistributedCoords(rewriter, loc, sgId, slmShape);
+ if (failed(slmStoreCoords))
return failure();
- SmallVector<Value> sgIds = *sgIdsResult;
-
- auto getSlmOffsets = [&](int64_t reductionDimStride) {
- SmallVector<OpFoldResult> offsets;
- offsets.reserve(srcVecRank);
- for (int i = 0; i < srcVecRank; ++i) {
- Value dimVal = sgIds[i];
- int64_t sgDataStride = (llvm::is_contained(reductionDims, i))
- ? reductionDimStride
- : sgSrcShape[i];
- Value strideVal =
- arith::ConstantIndexOp::create(rewriter, loc, sgDataStride);
- Value offsetVal =
- arith::MulIOp::create(rewriter, loc, dimVal, strideVal);
- offsets.push_back(offsetVal);
- }
- return offsets;
- };
-
- SmallVector<OpFoldResult> slmStoreOffsets =
- getSlmOffsets(/*reductionDimStride=*/1);
-
- xegpu::StoreMatrixOp::create(rewriter, loc, slmStoreData,
- memDesc.getResult(), slmStoreOffsets,
- /*layout=*/nullptr);
+ for (auto [data, coord] : llvm::zip(slmStoreData, *slmStoreCoords)) {
+ SmallVector<OpFoldResult> coordOfr(coord.begin(), coord.end());
+ xegpu::StoreMatrixOp::create(rewriter, loc, data, memDesc.getResult(),
+ coordOfr,
+ /*layout=*/nullptr);
+ }
gpu::BarrierOp::create(rewriter, loc);
- // Step 5: Load from SLM for final reduction
+ // Step 4: Load from SLM for final reduction
SmallVector<int64_t> slmLoadDataShape(sgSrcShape.begin(), sgSrcShape.end());
- for (int64_t dim : reductionDims)
+ for (int64_t dim : reductionDims) {
slmLoadDataShape[dim] = slmShape[dim];
-
- SmallVector<OpFoldResult> slmLoadOffsets =
- getSlmOffsets(/*reductionDimStride=*/0);
+ slmSgData[dim] = slmShape[dim];
+ }
+ xegpu::LayoutAttr slmLoadLayout =
+ xegpu::LayoutAttr::get(rewriter.getContext(), slmSgLayout, slmSgData);
+ auto slmLoadCoords =
+ slmLoadLayout.computeDistributedCoords(rewriter, loc, sgId, slmShape);
+ if (failed(slmLoadCoords))
+ return failure();
VectorType slmLoadType = VectorType::get(slmLoadDataShape, elemTy);
- auto slmLoadOp = xegpu::LoadMatrixOp::create(
- rewriter, loc, slmLoadType, memDesc.getResult(), slmLoadOffsets,
- /*layout=*/nullptr);
+ SmallVector<Value> slmLoadData;
+ for (auto coord : *slmLoadCoords) {
+ SmallVector<OpFoldResult> coordOfr(coord.begin(), coord.end());
+ slmLoadData.push_back(xegpu::LoadMatrixOp::create(
+ rewriter, loc, slmLoadType, memDesc.getResult(), coordOfr,
+ /*layout=*/nullptr));
+ }
- // Step 6: Perform final reduction with neutral accumulator
+ // Step 5: Perform final reduction with neutral accumulator and add the
+ // original accumulator at the end
Value neutralFinalAcc = xegpu::createReductionNeutralValue(
rewriter, loc, sgDstType, op.getKind());
- auto finalReduce = vector::MultiDimReductionOp::create(
- rewriter, loc, sgDstType, op.getKind(), slmLoadOp.getResult(),
- neutralFinalAcc, reductionDims);
-
- // Step 7: Add the original accumulator at the end
- auto finalResult = vector::makeArithReduction(rewriter, loc, op.getKind(),
- finalReduce.getResult(),
- adaptor.getAcc()[0]);
-
- rewriter.replaceOp(op, finalResult);
+ SmallVector<Value> finalResults;
+ for (size_t i = 0; i < slmLoadData.size(); ++i) {
+ auto loaded = slmLoadData[i];
+ auto finalReduce = vector::MultiDimReductionOp::create(
+ rewriter, loc, sgDstType, op.getKind(), loaded, neutralFinalAcc,
+ reductionDims);
+ finalResults.push_back(vector::makeArithReduction(
+ rewriter, loc, op.getKind(), finalReduce.getResult(),
+ adaptor.getAcc()[i]));
+ }
+ rewriter.replaceOpWithMultiple(op, {finalResults});
return success();
}
};
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
index 068dd6d865ead..4a74b66afcd6a 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
@@ -165,4 +165,56 @@ gpu.module @test_distribution {
gpu.return
}
+ // CHECK-LABEL: gpu.func @reduction_cross_sg_rr
+ gpu.func @reduction_cross_sg_rr(%arg0: memref<2048xf32, 1>) kernel {
+ // CHECK: %[[CST_OFFSETS0:.*]] = arith.constant dense<0> : vector<4x16xindex>
+ // CHECK: %[[CST_OFFSETS1:.*]] = arith.constant dense<0> : vector<4x16xindex>
+ // CHECK: %[[CST_ACC0:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
+ // CHECK: %[[CST_ACC1:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
+ // CHECK: %[[CST_MASK0:.*]] = arith.constant dense<true> : vector<4x16xi1>
+ // CHECK: %[[CST_MASK1:.*]] = arith.constant dense<true> : vector<4x16xi1>
+ //
+ // CHECK: %[[LOAD0:.*]] = xegpu.load %arg0[%[[CST_OFFSETS0]]], %[[CST_MASK0]]
+ // CHECK-SAME: -> vector<4x16xf32>
+ // CHECK: %[[LOAD1:.*]] = xegpu.load %arg0[%[[CST_OFFSETS1]]], %[[CST_MASK1]]
+ // CHECK-SAME: -> vector<4x16xf32>
+ //
+ // Local reductions
+ // CHECK: %[[NEUTRAL0:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
+ // CHECK: %[[LOCAL_RED0:.*]] = vector.multi_reduction <add>, %[[LOAD0]], %[[NEUTRAL0]] [1] : vector<4x16xf32> to vector<4xf32>
+ // CHECK: %[[NEUTRAL1:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
+ // CHECK: %[[LOCAL_RED1:.*]] = vector.multi_reduction <add>, %[[LOAD1]], %[[NEUTRAL1]] [1] : vector<4x16xf32> to vector<4xf32>
+ //
+ // Shape cast for SLM store
+ // CHECK: %[[SC0:.*]] = vector.shape_cast %[[LOCAL_RED0]] : vector<4xf32> to vector<4x1xf32>
+ // CHECK: %[[SC1:.*]] = vector.shape_cast %[[LOCAL_RED1]] : vector<4xf32> to vector<4x1xf32>
+ //
+ // SLM allocation and mem_desc
+ // CHECK: %[[SLM:.*]] = memref.alloca() : memref<512xi8, 3>
+ // CHECK: %[[MEMDESC:.*]] = xegpu.create_mem_desc %[[SLM]] : memref<512xi8, 3> -> !xegpu.mem_desc<8x16xf32>
+ //
+ // Store to SLM
+ // CHECK: xegpu.store_matrix %[[SC0]], %[[MEMDESC]]{{.*}} : vector<4x1xf32>, !xegpu.mem_desc<8x16xf32>
+ // CHECK: xegpu.store_matrix %[[SC1]], %[[MEMDESC]]{{.*}} : vector<4x1xf32>, !xegpu.mem_desc<8x16xf32>
+ // CHECK: gpu.barrier
+ //
+ // Load from SLM
+ // CHECK: %[[SLM_LOAD0:.*]] = xegpu.load_matrix %[[MEMDESC]]{{.*}} -> vector<4x16xf32>
+ // CHECK: %[[SLM_LOAD1:.*]] = xegpu.load_matrix %[[MEMDESC]]{{.*}} -> vector<4x16xf32>
+ //
+ // Final reduction
+ // CHECK: %[[FINAL_NEUTRAL:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
+ // CHECK: %[[FINAL_RED0:.*]] = vector.multi_reduction <add>, %[[SLM_LOAD0]], %[[FINAL_NEUTRAL]] [1] : vector<4x16xf32> to vector<4xf32>
+ // CHECK: %[[RES0:.*]] = arith.addf %[[FINAL_RED0]], %[[CST_ACC0]] : vector<4xf32>
+ // CHECK: %[[FINAL_RED1:.*]] = vector.multi_reduction <add>, %[[SLM_LOAD1]], %[[FINAL_NEUTRAL]] [1] : vector<4x16xf32> to vector<4xf32>
+ // CHECK: %[[RES1:.*]] = arith.addf %[[FINAL_RED1]], %[[CST_ACC1]] : vector<4xf32>
+
+ %offset = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [1, 16], sg_data = [4, 16]>} dense<0> : vector<8x256xindex>
+ %acc = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [1, 16], sg_data = [4, 16]>, dims = [1]>} dense<0.000000e+00> : vector<8xf32>
+ %mask = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [1, 16], sg_data = [4, 16]>} dense<true> : vector<8x256xi1>
+ %val = xegpu.load %arg0[%offset], %mask <{layout = #xegpu.layout<sg_layout = [1, 16], sg_data = [4, 16]>}> : memref<2048xf32, 1>, vector<8x256xindex>, vector<8x256xi1> -> vector<8x256xf32>
+ %reduce = vector.multi_reduction <add>, %val, %acc {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [1, 16], sg_data = [4, 16]>, dims = [1]>} [1] : vector<8x256xf32> to vector<8xf32>
+ gpu.return
+ }
+
}
More information about the Mlir-commits
mailing list