[Mlir-commits] [mlir] [MLIR][XeGPU] Improve workgroup to subgroup distribution pattern for multi-reduction op (PR #182178)
Jianhui Li
llvmlistbot at llvm.org
Wed Feb 18 15:11:28 PST 2026
https://github.com/Jianhui-Li updated https://github.com/llvm/llvm-project/pull/182178
>From 61e382c4d7609c194ed71da159f2a0a0bea80929 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Wed, 18 Feb 2026 23:00:41 +0000
Subject: [PATCH 1/3] improve workgroup distribution pattern for
cross-subgroup reduction using n-D SLM
---
.../Transforms/XeGPUWgToSgDistribute.cpp | 193 +++++++++---------
.../XeGPU/xegpu-wg-to-sg-unify-ops.mlir | 155 +++++++-------
2 files changed, 178 insertions(+), 170 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 66eb7fc97aa1a..8ab0b50da03ad 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -1304,10 +1304,10 @@ static Value createAccumulator(ConversionPatternRewriter &rewriter,
/// This gives us a unique linear index for each combination of subgroup
/// positions in the specified dimensions, which is used for SLM row/column
/// addressing.
-static Value linearizeSubgroupIndices(ConversionPatternRewriter &rewriter,
- Location loc, ArrayRef<Value> sgIds,
- ArrayRef<int64_t> dims,
- ArrayRef<int64_t> sgLayout) {
+[[maybe_unused]] static Value
+linearizeSubgroupIndices(ConversionPatternRewriter &rewriter, Location loc,
+ ArrayRef<Value> sgIds, ArrayRef<int64_t> dims,
+ ArrayRef<int64_t> sgLayout) {
Value linearizedOffset = arith::ConstantIndexOp::create(rewriter, loc, 0);
int64_t stride = 1;
@@ -1370,6 +1370,9 @@ struct WgToSgMultiDimReductionOp
return failure();
auto originalSrcShape = srcType.getShape();
+ auto originalDstShape = dstType.getShape();
+ int srcVecRank = originalSrcShape.size();
+
xegpu::DistributeLayoutAttr layout =
xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getResult()));
if (!layout || !layout.isForWorkgroup())
@@ -1391,10 +1394,21 @@ struct WgToSgMultiDimReductionOp
// Step 1: perform local subgroup reductions with ZERO accumulator
SmallVector<Value> localReductions;
- SmallVector<int64_t> sgShape =
- getSgShapeAndCount(originalSrcShape, layout).first;
- VectorType newDstType = VectorType::get(sgShape, elemTy);
- for (auto sgSrc : adaptor.getSource()) {
+ SmallVector<int64_t> sgDstShape =
+ getSgShapeAndCount(originalDstShape, layout).first;
+ auto sgSrcs = adaptor.getSource();
+ auto sgSrcType = dyn_cast<VectorType>(sgSrcs.front().getType());
+ SmallVector<int64_t> sgSrcShape(sgSrcType.getShape().begin(),
+ sgSrcType.getShape().end());
+
+ // debug print sgDstShape
+ llvm::errs() << "[WgToSgMultiDimReductionOp] sgDstShape = [";
+ for (auto [i, v] : llvm::enumerate(sgDstShape))
+ llvm::errs() << (i ? ", " : "") << v;
+ llvm::errs() << "]\n";
+
+ VectorType newDstType = VectorType::get(sgDstShape, elemTy);
+ for (auto sgSrc : sgSrcs) {
// Create ZERO accumulator for local reduction
auto neutralLocalAcc =
createAccumulator(rewriter, loc, newDstType, op.getKind());
@@ -1430,44 +1444,53 @@ struct WgToSgMultiDimReductionOp
}
// Step 2: cross-subgroup reduction using SLM
-
- // Calculate total elements in local result
- int64_t localElements = computeProduct(sgShape);
-
- // Shape cast for SLM storage - store as [1, localElements]
- SmallVector<int64_t> storeShape2D = {1, localElements};
- VectorType storeType2D = VectorType::get(storeShape2D, elemTy);
- auto storeShapeCast = vector::ShapeCastOp::create(
- rewriter, loc, storeType2D, localReductions[0]);
- Value storeData = storeShapeCast.getResult();
-
- // Calculate SLM shape - rows for sg's in reduction dims, cols for total
- // result elements across all subgroups in non-reduction dimensions
- int64_t totalReductionSubgroups = 1;
- for (int64_t dim : crossSgReductionDims) {
- totalReductionSubgroups *= sgLayout[dim];
- }
-
- // Total result elements across all subgroups in non-reduction dimensions
- int64_t totalResultElements =
- localElements * computeProduct(sgLayout) / totalReductionSubgroups;
-
- SmallVector<int64_t> slmShape2D = {totalReductionSubgroups,
- totalResultElements};
+ auto slmStoreDataShape = sgSrcShape;
+ for (int64_t dim : reductionDims)
+ slmStoreDataShape[dim] = 1;
+ VectorType slmStoreDataType = VectorType::get(slmStoreDataShape, elemTy);
+ Value slmStoreData = vector::ShapeCastOp::create(
+ rewriter, loc, slmStoreDataType, localReductions[0]);
+
+ SmallVector<int64_t> slmShape(originalSrcShape.begin(),
+ originalSrcShape.end());
+ // for reduction dimension, SLM stores partial results from each subgroup
+ for (int64_t dim : reductionDims)
+ slmShape[dim] = originalSrcShape[dim] / sgSrcShape[dim];
+
+ // Debug print.
+ llvm::errs() << "[WgToSgMultiDimReductionOp] originalSrcShape = [";
+ for (auto [i, v] : llvm::enumerate(originalSrcShape))
+ llvm::errs() << (i ? ", " : "") << v;
+ llvm::errs() << "], sgSrcShape = [";
+ for (auto [i, v] : llvm::enumerate(sgSrcShape))
+ llvm::errs() << (i ? ", " : "") << v;
+ llvm::errs() << "], slmShape = [";
+ for (auto [i, v] : llvm::enumerate(slmShape))
+ llvm::errs() << (i ? ", " : "") << v;
+ llvm::errs() << "], reductionDims = [";
+ for (auto [i, v] : llvm::enumerate(reductionDims))
+ llvm::errs() << (i ? ", " : "") << v;
+ llvm::errs() << "]\n";
// Allocate SLM
auto bitWidth = elemTy.getIntOrFloatBitWidth();
auto bytesPerElement = bitWidth / 8;
- int64_t slmElements = slmShape2D[0] * slmShape2D[1];
- auto slmSize = slmElements * bytesPerElement;
+ auto slmSize = computeProduct(slmShape) * bytesPerElement;
auto slmTy = MemRefType::get({slmSize}, rewriter.getI8Type(), {}, 3);
auto slm = memref::AllocaOp::create(rewriter, loc, slmTy);
- auto memDescType = xegpu::MemDescType::get(rewriter.getContext(),
- slmShape2D, elemTy, nullptr);
+ auto memDescType = xegpu::MemDescType::get(rewriter.getContext(), slmShape,
+ elemTy, nullptr);
auto memDesc =
xegpu::CreateMemDescOp::create(rewriter, loc, memDescType, slm);
+ // if localReductions have more than 1 result, not support
+ if (localReductions.size() > 1) {
+ return rewriter.notifyMatchFailure(
+ op,
+ "Multiple local reductions not supported in current implementation.");
+ }
+
// Step 4: Store local results to SLM
auto sgId = gpu::SubgroupIdOp::create(rewriter, loc,
rewriter.getIndexType(), nullptr);
@@ -1484,78 +1507,66 @@ struct WgToSgMultiDimReductionOp
return failure();
SmallVector<Value> sgIds = *sgIdsResult;
- // Row offset: linearize reduction dimension indices
- Value rowOffsetStore = linearizeSubgroupIndices(
- rewriter, loc, sgIds, crossSgReductionDims, sgLayout);
-
- // Column offset: linearize non-reduction dimension indices
- SmallVector<int64_t> nonReductionDims;
- for (size_t i = 0; i < sgLayout.size(); ++i) {
- if (!llvm::is_contained(reductionDims, static_cast<int64_t>(i))) {
- nonReductionDims.push_back(static_cast<int64_t>(i));
- }
+ SmallVector<OpFoldResult> slmStoreOffsets;
+ for (int i = 0; i < srcVecRank; ++i) {
+ Value dimVal = sgIds[i];
+ int64_t stride =
+ (llvm::is_contained(reductionDims, i)) ? 1 : sgSrcShape[i];
+ Value strideVal = arith::ConstantIndexOp::create(rewriter, loc, stride);
+ Value offsetVal = arith::MulIOp::create(rewriter, loc, dimVal, strideVal);
+ slmStoreOffsets.push_back(offsetVal);
}
-
- Value colOffset = linearizeSubgroupIndices(rewriter, loc, sgIds,
- nonReductionDims, sgLayout);
-
- Value localElementsVal =
- arith::ConstantIndexOp::create(rewriter, loc, localElements);
- colOffset =
- arith::MulIOp::create(rewriter, loc, colOffset, localElementsVal);
-
- SmallVector<OpFoldResult> storeOffsets2D = {rowOffsetStore, colOffset};
-
- xegpu::StoreMatrixOp::create(rewriter, loc, storeData, memDesc.getResult(),
- storeOffsets2D, /*layout=*/nullptr);
+ // debug print
+ llvm::errs() << "[WgToSgMultiDimReductionOp] sgIds = [";
+ for (auto [i, v] : llvm::enumerate(sgIds))
+ llvm::errs() << (i ? ", " : "") << v;
+ llvm::errs() << "], slmStoreOffsets = [";
+ for (auto [i, v] : llvm::enumerate(slmStoreOffsets))
+ llvm::errs() << (i ? ", " : "") << v;
+ llvm::errs() << "]\n";
+ xegpu::StoreMatrixOp::create(rewriter, loc, slmStoreData,
+ memDesc.getResult(), slmStoreOffsets,
+ /*layout=*/nullptr);
gpu::BarrierOp::create(rewriter, loc);
// Step 5: Load from SLM for final reduction
- SmallVector<int64_t> loadShape2D = {totalReductionSubgroups, localElements};
- VectorType loadType2D = VectorType::get(loadShape2D, elemTy);
-
- // Load offsets - each subgroup loads its column based on non-reduction
- // position
- Value rowOffsetLoad = arith::ConstantIndexOp::create(rewriter, loc, 0);
-
- SmallVector<OpFoldResult> loadOffsets2D = {rowOffsetLoad, colOffset};
+ SmallVector<int64_t> slmLoadDataShape(sgSrcShape.begin(), sgSrcShape.end());
+ for (int64_t dim : reductionDims)
+ slmLoadDataShape[dim] = slmShape[dim];
+
+ llvm::errs() << "[WgToSgMultiDimReductionOp] slmLoadDataShape = [";
+ for (auto [i, v] : llvm::enumerate(slmLoadDataShape))
+ llvm::errs() << (i ? ", " : "") << v;
+ llvm::errs() << "]\n";
+
+ SmallVector<OpFoldResult> slmLoadOffsets;
+ for (int i = 0; i < srcVecRank; ++i) {
+ Value dimVal = sgIds[i];
+ int64_t stride =
+ (llvm::is_contained(reductionDims, i)) ? 0 : sgSrcShape[i];
+ Value strideVal = arith::ConstantIndexOp::create(rewriter, loc, stride);
+ Value offsetVal = arith::MulIOp::create(rewriter, loc, dimVal, strideVal);
+ slmLoadOffsets.push_back(offsetVal);
+ }
- auto loadOp = xegpu::LoadMatrixOp::create(
- rewriter, loc, loadType2D, memDesc.getResult(), loadOffsets2D,
+ VectorType slmLoadType = VectorType::get(slmLoadDataShape, elemTy);
+ auto slmLoadOp = xegpu::LoadMatrixOp::create(
+ rewriter, loc, slmLoadType, memDesc.getResult(), slmLoadOffsets,
/*layout=*/nullptr);
// Step 6: Perform final reduction with ZERO accumulator
- SmallVector<int64_t> finalReductionDims = {0};
- SmallVector<int64_t> finalResultShape = {localElements};
- VectorType finalResultType = VectorType::get(finalResultShape, elemTy);
-
auto neutralFinalAcc =
- createAccumulator(rewriter, loc, finalResultType, op.getKind());
+ createAccumulator(rewriter, loc, newDstType, op.getKind());
auto finalReduce = vector::MultiDimReductionOp::create(
- rewriter, loc, finalResultType, op.getKind(), loadOp.getResult(),
- neutralFinalAcc, finalReductionDims);
+ rewriter, loc, newDstType, op.getKind(), slmLoadOp.getResult(),
+ neutralFinalAcc, reductionDims);
// Step 7: Add the original accumulator at the end
Value originalAcc = adaptor.getAcc()[0];
Value accToAdd = originalAcc;
- // Handle shape mismatch by shape casting
- if (originalAcc.getType() != finalReduce.getResult().getType()) {
- auto originalAccType = cast<VectorType>(originalAcc.getType());
- auto finalResultType =
- cast<VectorType>(finalReduce.getResult().getType());
-
- // If they have the same number of elements, just shape cast
- if (originalAccType.getNumElements() ==
- finalResultType.getNumElements()) {
- auto shapeCast = vector::ShapeCastOp::create(
- rewriter, loc, finalResultType, originalAcc);
- accToAdd = shapeCast.getResult();
- }
- }
-
auto finalResult = vector::makeArithReduction(
rewriter, loc, op.getKind(), finalReduce.getResult(), accToAdd);
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index e2e94c5f0300f..9407f7f2357a2 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -646,30 +646,28 @@ gpu.module @test_distribution {
// CHECK-DAG: %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<1x32xf32>
// CHECK-DAG: %[[CST_0:.*]] = arith.constant dense<0> : vector<1x1x32xindex>
// CHECK-DAG: %[[CST_1:.*]] = arith.constant dense<true> : vector<1x1x32xi1>
- // CHECK-DAG: %[[LOAD:.*]] = xegpu.load %{{.*}}[%[[CST_0]]], %[[CST_1]] <{chunk_size = 1 : i64}> : memref<?xf32>, vector<1x1x32xindex>, vector<1x1x32xi1> -> vector<1x1x32xf32>
+ // CHECK-DAG: %[[LOAD:.*]] = xegpu.load %[[ARG0:.*]][%[[CST_0]]], %[[CST_1]] <{chunk_size = 1 : i64}> : memref<?xf32>, vector<1x1x32xindex>, vector<1x1x32xi1> -> vector<1x1x32xf32>
// CHECK-DAG: %[[CST_2:.*]] = arith.constant dense<0.000000e+00> : vector<1x32xf32>
// CHECK-DAG: %[[LOCAL_REDUCE:.*]] = vector.multi_reduction <add>, %[[LOAD]], %[[CST_2]] [1] : vector<1x1x32xf32> to vector<1x32xf32>
- // CHECK-DAG: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[LOCAL_REDUCE]] : vector<1x32xf32> to vector<1x32xf32>
+ // CHECK-DAG: %[[CAST:.*]] = vector.shape_cast %[[LOCAL_REDUCE]] : vector<1x32xf32> to vector<1x1x32xf32>
// CHECK-DAG: %[[ALLOCA:.*]] = memref.alloca() : memref<4096xi8, 3>
- // CHECK-DAG: %[[MEM_DESC:.*]] = xegpu.create_mem_desc %[[ALLOCA]] : memref<4096xi8, 3> -> !xegpu.mem_desc<32x32xf32>
+ // CHECK-DAG: %[[MEM_DESC:.*]] = xegpu.create_mem_desc %[[ALLOCA]] : memref<4096xi8, 3> -> !xegpu.mem_desc<1x32x32xf32>
// CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
- // CHECK-DAG: %[[AFFINE1:.*]] = affine.apply #map()[%[[SGID]]]
- // CHECK-DAG: %[[AFFINE2:.*]] = affine.apply #map1()[%[[SGID]]]
- // CHECK-DAG: %[[AFFINE3:.*]] = affine.apply #map2()[%[[SGID]]]
- // CHECK-DAG: %[[MUL1:.*]] = arith.muli {{.*}}, %[[C1:.*]] : index
- // CHECK-DAG: %[[ROW_OFFSET:.*]] = arith.addi %[[C0:.*]], %[[MUL1]] : index
- // CHECK-DAG: %[[MUL2:.*]] = arith.muli %[[AFFINE1]], %[[C1:.*]] : index
- // CHECK-DAG: %[[ADD1:.*]] = arith.addi %[[C0:.*]], %[[MUL2]] : index
- // CHECK-DAG: %[[MUL3:.*]] = arith.muli %[[AFFINE3]], %[[C1:.*]] : index
- // CHECK-DAG: %[[ADD2:.*]] = arith.addi %[[ADD1]], %[[MUL3]] : index
- // CHECK-DAG: %[[COL_OFFSET:.*]] = arith.muli %[[ADD2]], %[[C32:.*]] : index
- // CHECK-DAG: xegpu.store_matrix %[[SHAPE_CAST]], %[[MEM_DESC]][%[[ROW_OFFSET]], %[[COL_OFFSET]]] : vector<1x32xf32>, !xegpu.mem_desc<32x32xf32>, index, index
+ // CHECK-DAG: %[[AFF0:.*]] = affine.apply #map()[%[[SGID]]]
+ // CHECK-DAG: %[[AFF1:.*]] = affine.apply #map1()[%[[SGID]]]
+ // CHECK-DAG: %[[AFF2:.*]] = affine.apply #map2()[%[[SGID]]]
+ // CHECK-DAG: %[[ROW:.*]] = arith.muli %[[AFF0]], %[[C1A:.*]] : index
+ // CHECK-DAG: %[[COL0:.*]] = arith.muli %[[AFF1:.*]], %[[C1B:.*]] : index
+ // CHECK-DAG: %[[COL1:.*]] = arith.muli %[[AFF2]], %[[C32A:.*]] : index
+ // CHECK-DAG: xegpu.store_matrix %[[CAST]], %[[MEM_DESC]][%[[ROW]], %[[COL0]], %[[COL1]]] : vector<1x1x32xf32>, !xegpu.mem_desc<1x32x32xf32>, index, index, index
// CHECK-DAG: gpu.barrier
- // CHECK-DAG: %[[LOAD_SLM:.*]] = xegpu.load_matrix %[[MEM_DESC]][%[[C0:.*]], %[[COL_OFFSET]]] : !xegpu.mem_desc<32x32xf32>, index, index -> vector<32x32xf32>
- // CHECK-DAG: %[[CST_3:.*]] = arith.constant dense<0.000000e+00> : vector<32xf32>
- // CHECK-DAG: %[[FINAL_REDUCE:.*]] = vector.multi_reduction <add>, %[[LOAD_SLM]], %[[CST_3]] [0] : vector<32x32xf32> to vector<32xf32>
- // CHECK-DAG: %[[SHAPE_CAST_FINAL:.*]] = vector.shape_cast %[[CST]] : vector<1x32xf32> to vector<32xf32>
- // CHECK-DAG: %{{.*}} = arith.addf %[[FINAL_REDUCE]], %[[SHAPE_CAST_FINAL]] : vector<32xf32>
+ // CHECK-DAG: %[[ROW_L:.*]] = arith.muli %[[AFF0]], %[[C1C:.*]] : index
+ // CHECK-DAG: %[[COL0_L:.*]] = arith.muli %[[AFF1]], %[[C0:.*]] : index
+ // CHECK-DAG: %[[COL1_L:.*]] = arith.muli %[[AFF2]], %[[C32B:.*]] : index
+ // CHECK-DAG: %[[LOAD_SLM:.*]] = xegpu.load_matrix %[[MEM_DESC]][%[[ROW_L]], %[[COL0_L]], %[[COL1_L]]] : !xegpu.mem_desc<1x32x32xf32>, index, index, index -> vector<1x32x32xf32>
+ // CHECK-DAG: %[[CST_3:.*]] = arith.constant dense<0.000000e+00> : vector<1x32xf32>
+ // CHECK-DAG: %[[FINAL_REDUCE:.*]] = vector.multi_reduction <add>, %[[LOAD_SLM]], %[[CST_3]] [1] : vector<1x32x32xf32> to vector<1x32xf32>
+ // CHECK-DAG: %[[ADD:.*]] = arith.addf %[[FINAL_REDUCE]], %[[CST]] : vector<1x32xf32>
// CHECK-DAG: gpu.return
%cst_3 = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [1, 32, 1], sg_data = [1, 1, 32]>, dims = [1]>} dense<1.0> : vector<1x32xf32>
%offset = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [1, 32, 1], sg_data = [1, 1, 32]>} dense<0> : vector<1x32x32xindex>
@@ -688,7 +686,7 @@ gpu.module @test_distribution {
// CHECK-DAG: %[[DIV1:.*]] = arith.divui %[[SGID]], %[[C4:.*]] : index
// CHECK-DAG: %[[REM2:.*]] = arith.remui %[[DIV1]], %[[C8:.*]] : index
// CHECK-DAG: %[[MUL1:.*]] = arith.muli %[[REM2]], %[[C32:.*]] : index
- // CHECK-DAG: %[[MUL2:.*]] = arith.muli %[[REM1]], %[[C32:.*]] : index
+ // CHECK-DAG: %[[MUL2:.*]] = arith.muli %[[REM1]], %[[C32_0:.*]] : index
// CHECK-DAG: %[[REM3:.*]] = arith.remui %[[MUL1]], %[[C256:.*]] : index
// CHECK-DAG: %[[REM4:.*]] = arith.remui %[[MUL2]], %[[C128:.*]] : index
// CHECK-DAG: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%[[REM3]], %[[REM4]]] : memref<256x128xf32> -> !xegpu.tensor_desc<32x32xf32>
@@ -699,19 +697,18 @@ gpu.module @test_distribution {
// CHECK-DAG: %[[ALLOCA:.*]] = memref.alloca() : memref<4096xi8, 3>
// CHECK-DAG: %[[MEM_DESC:.*]] = xegpu.create_mem_desc %[[ALLOCA]] : memref<4096xi8, 3> -> !xegpu.mem_desc<8x128xf32>
// CHECK-DAG: %[[SGID2:.*]] = gpu.subgroup_id : index
- // CHECK-DAG: %[[AFFINE1:.*]] = affine.apply #map3()[%[[SGID2]]]
- // CHECK-DAG: %[[AFFINE2:.*]] = affine.apply #map4()[%[[SGID2]]]
- // CHECK-DAG: %[[MUL3:.*]] = arith.muli {{.*}}, %[[C1:.*]] : index
- // CHECK-DAG: %[[ROW_OFFSET:.*]] = arith.addi %[[C0:.*]], %[[MUL3]] : index
- // CHECK-DAG: %[[MUL4:.*]] = arith.muli {{.*}}, %[[C1:.*]] : index
- // CHECK-DAG: %[[ADD1:.*]] = arith.addi %[[C0:.*]], %[[MUL4]] : index
- // CHECK-DAG: %[[COL_OFFSET:.*]] = arith.muli %[[ADD1]], %[[C32:.*]] : index
+ // CHECK-DAG: %[[AFFINE1:.*]] = affine.apply {{#map[[:alnum:]_]*}}()[%[[SGID2]]]
+ // CHECK-DAG: %[[AFFINE2:.*]] = affine.apply {{#map[[:alnum:]_]*}}()[%[[SGID2]]]
+ // CHECK-DAG: %[[ROW_OFFSET:.*]] = arith.muli %[[AFFINE1]], %[[C1:.*]] : index
+ // CHECK-DAG: %[[COL_OFFSET:.*]] = arith.muli %[[AFFINE2]], %[[C32_1:.*]] : index
// CHECK-DAG: xegpu.store_matrix %[[SHAPE_CAST]], %[[MEM_DESC]][%[[ROW_OFFSET]], %[[COL_OFFSET]]] : vector<1x32xf32>, !xegpu.mem_desc<8x128xf32>, index, index
// CHECK-DAG: gpu.barrier
- // CHECK-DAG: %[[LOAD_SLM:.*]] = xegpu.load_matrix %[[MEM_DESC]][%[[C0:.*]], %[[COL_OFFSET]]] : !xegpu.mem_desc<8x128xf32>, index, index -> vector<8x32xf32>
+ // CHECK-DAG: %[[ZERO_ROW:.*]] = arith.muli %[[AFFINE1]], %[[C0:.*]] : index
+ // CHECK-DAG: %[[COL_OFFSET2:.*]] = arith.muli %[[AFFINE2]], %[[C32_2:.*]] : index
+ // CHECK-DAG: %[[LOAD_SLM:.*]] = xegpu.load_matrix %[[MEM_DESC]][%[[ZERO_ROW]], %[[COL_OFFSET2]]] : !xegpu.mem_desc<8x128xf32>, index, index -> vector<8x32xf32>
// CHECK-DAG: %[[CST_CROSS_SG_1:.*]] = arith.constant dense<0.000000e+00> : vector<32xf32>
// CHECK-DAG: %[[FINAL_REDUCE:.*]] = vector.multi_reduction <add>, %[[LOAD_SLM]], %[[CST_CROSS_SG_1]] [0] : vector<8x32xf32> to vector<32xf32>
- // CHECK-DAG: arith.addf %[[FINAL_REDUCE]], %[[CST:.*]] : vector<32xf32>
+ // CHECK-DAG: arith.addf %[[FINAL_REDUCE]], %[[CST]] : vector<32xf32>
%cst = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>, dims = [0]>} dense<0.0> : vector<128xf32>
%tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
-> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>>
@@ -732,36 +729,38 @@ gpu.module @test_distribution {
// CHECK-DAG: %[[LOAD:.*]] = xegpu.load %{{.*}}[%[[CST_0]]], %[[CST_1]] <{chunk_size = 1 : i64}> : memref<?xf32>, vector<1x1x32x32xindex>, vector<1x1x32x32xi1> -> vector<1x1x32x32xf32>
// CHECK-DAG: %[[CST_2:.*]] = arith.constant dense<0.000000e+00> : vector<1x1xf32>
// CHECK-DAG: %[[LOCAL_REDUCE:.*]] = vector.multi_reduction <add>, %[[LOAD]], %[[CST_2]] [2, 3] : vector<1x1x32x32xf32> to vector<1x1xf32>
- // CHECK-DAG: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[LOCAL_REDUCE]] : vector<1x1xf32> to vector<1x1xf32>
+ // CHECK-DAG: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[LOCAL_REDUCE]] : vector<1x1xf32> to vector<1x1x1x1xf32>
// CHECK-DAG: %[[ALLOCA:.*]] = memref.alloca() : memref<256xi8, 3>
- // CHECK-DAG: %[[MEM_DESC:.*]] = xegpu.create_mem_desc %[[ALLOCA]] : memref<256xi8, 3> -> !xegpu.mem_desc<16x4xf32>
+ // CHECK-DAG: %[[MEM_DESC:.*]] = xegpu.create_mem_desc %[[ALLOCA]] : memref<256xi8, 3> -> !xegpu.mem_desc<2x2x4x4xf32>
// CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
- // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
- // CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
- // CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index
- // CHECK-DAG: %[[C32:.*]] = arith.constant 32 : index
- // CHECK-DAG: %[[AFFINE1:.*]] = affine.apply #map()[%[[SGID]]]
- // CHECK-DAG: %[[AFFINE2:.*]] = affine.apply #map1()[%[[SGID]]]
- // CHECK-DAG: %[[AFFINE3:.*]] = affine.apply #map5()[%[[SGID]]]
- // CHECK-DAG: %[[AFFINE4:.*]] = affine.apply #map6()[%[[SGID]]]
- // CHECK-DAG: %[[AFFINE5:.*]] = affine.apply #map7()[%[[SGID]]]
- // CHECK-DAG: %[[AFFINE6:.*]] = affine.apply #map4()[%[[SGID]]]
- // CHECK-DAG: %[[MUL1:.*]] = arith.muli {{.*}}, %[[C1:.*]] : index
- // CHECK-DAG: %[[ADD1:.*]] = arith.addi %[[C0:.*]], %[[MUL1]] : index
- // CHECK-DAG: %[[MUL2:.*]] = arith.muli {{.*}}, %[[C4:.*]] : index
- // CHECK-DAG: %[[ROW_OFFSET:.*]] = arith.addi %[[ADD1]], %[[MUL2]] : index
- // CHECK-DAG: %[[MUL3:.*]] = arith.muli {{.*}}, %[[C1:.*]] : index
- // CHECK-DAG: %[[ADD2:.*]] = arith.addi %[[C0:.*]], %[[MUL3]] : index
- // CHECK-DAG: %[[MUL4:.*]] = arith.muli {{.*}}, %[[C2:.*]] : index
- // CHECK-DAG: %[[ADD3:.*]] = arith.addi %[[ADD2]], %[[MUL4]] : index
- // CHECK-DAG: %[[COL_OFFSET:.*]] = arith.muli %[[ADD3]], %[[C1:.*]] : index
- // CHECK-DAG: xegpu.store_matrix %[[SHAPE_CAST]], %[[MEM_DESC]][%[[ROW_OFFSET]], %[[COL_OFFSET]]] : vector<1x1xf32>, !xegpu.mem_desc<16x4xf32>, index, index
+ // CHECK-DAG: %[[AFFINE0:.*]] = affine.apply {{#map[[:alnum:]_]*}}()[%[[SGID]]]
+ // CHECK-DAG: %[[AFFINE1:.*]] = affine.apply {{#map[[:alnum:]_]*}}()[%[[SGID]]]
+ // CHECK-DAG: %[[AFFINE2:.*]] = affine.apply {{#map[[:alnum:]_]*}}()[%[[SGID]]]
+ // CHECK-DAG: %[[AFFINE3:.*]] = affine.apply {{#map[[:alnum:]_]*}}()[%[[SGID]]]
+ // CHECK-DAG: %[[AFFINE4:.*]] = affine.apply {{#map[[:alnum:]_]*}}()[%[[SGID]]]
+ // CHECK-DAG: %[[AFFINE5:.*]] = affine.apply {{#map[[:alnum:]_]*}}()[%[[SGID]]]
+ // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+ // CHECK-DAG: %[[I0:.*]] = arith.muli %[[AFFINE0]], %[[C1]] : index
+ // CHECK-DAG: %[[C1_0:.*]] = arith.constant 1 : index
+ // CHECK-DAG: %[[I1:.*]] = arith.muli %[[AFFINE2]], %[[C1_0]] : index
+ // CHECK-DAG: %[[C1_1:.*]] = arith.constant 1 : index
+ // CHECK-DAG: %[[I2:.*]] = arith.muli %[[AFFINE4]], %[[C1_1]] : index
+ // CHECK-DAG: %[[C1_2:.*]] = arith.constant 1 : index
+ // CHECK-DAG: %[[I3:.*]] = arith.muli %[[AFFINE5]], %[[C1_2]] : index
+ // CHECK-DAG: xegpu.store_matrix %[[SHAPE_CAST]], %[[MEM_DESC]][%[[I0]], %[[I1]], %[[I2]], %[[I3]]] : vector<1x1x1x1xf32>, !xegpu.mem_desc<2x2x4x4xf32>, index, index, index, index
// CHECK-DAG: gpu.barrier
- // CHECK-DAG: %[[LOAD_SLM:.*]] = xegpu.load_matrix %[[MEM_DESC]][%[[C0:.*]], %[[COL_OFFSET]]] : !xegpu.mem_desc<16x4xf32>, index, index -> vector<16x1xf32>
- // CHECK-DAG: %[[CST_3:.*]] = arith.constant dense<0.000000e+00> : vector<1xf32>
- // CHECK-DAG: %[[FINAL_REDUCE:.*]] = vector.multi_reduction <add>, %[[LOAD_SLM]], %[[CST_3]] [0] : vector<16x1xf32> to vector<1xf32>
- // CHECK-DAG: %[[SHAPE_CAST_FINAL:.*]] = vector.shape_cast %[[CST]] : vector<1x1xf32> to vector<1xf32>
- // CHECK-DAG: %[[FINAL_ADD:.*]] = arith.addf %[[FINAL_REDUCE]], %[[SHAPE_CAST_FINAL]] : vector<1xf32>
+ // CHECK-DAG: %[[C1_3:.*]] = arith.constant 1 : index
+ // CHECK-DAG: %[[L0:.*]] = arith.muli %[[AFFINE0]], %[[C1_3]] : index
+ // CHECK-DAG: %[[C1_4:.*]] = arith.constant 1 : index
+ // CHECK-DAG: %[[L1:.*]] = arith.muli %[[AFFINE2]], %[[C1_4]] : index
+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK-DAG: %[[L2:.*]] = arith.muli %[[AFFINE4]], %[[C0]] : index
+ // CHECK-DAG: %[[C0_0:.*]] = arith.constant 0 : index
+ // CHECK-DAG: %[[L3:.*]] = arith.muli %[[AFFINE5]], %[[C0_0]] : index
+ // CHECK-DAG: %[[LOAD_SLM:.*]] = xegpu.load_matrix %[[MEM_DESC]][%[[L0]], %[[L1]], %[[L2]], %[[L3]]] : !xegpu.mem_desc<2x2x4x4xf32>, index, index, index, index -> vector<1x1x4x4xf32>
+ // CHECK-DAG: %[[CST_3:.*]] = arith.constant dense<0.000000e+00> : vector<1x1xf32>
+ // CHECK-DAG: %[[FINAL_REDUCE:.*]] = vector.multi_reduction <add>, %[[LOAD_SLM]], %[[CST_3]] [2, 3] : vector<1x1x4x4xf32> to vector<1x1xf32>
+ // CHECK-DAG: %[[FINAL_ADD:.*]] = arith.addf %[[FINAL_REDUCE]], %[[CST]] : vector<1x1xf32>
// CHECK-DAG: gpu.return
%cst = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [2, 2, 4, 4], sg_data = [1, 1, 32, 32]>, dims = [2, 3]>} dense<0.0> : vector<2x2xf32>
%offset = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [2, 2, 4, 4], sg_data = [1, 1, 32, 32]>} dense<0> : vector<2x2x128x128xindex>
@@ -780,32 +779,30 @@ gpu.module @test_distribution {
// CHECK-DAG: %[[LOAD:.*]] = xegpu.load %[[ARG0]][%[[CST_0]]], %[[CST_1]] <{chunk_size = 1 : i64}> : memref<?xf32>, vector<16x16x32x32xindex>, vector<16x16x32x32xi1> -> vector<16x16x32x32xf32>
// CHECK-DAG: %[[CST_2:.*]] = arith.constant dense<0.000000e+00> : vector<16x16xf32>
// CHECK-DAG: %[[LOCAL_REDUCE:.*]] = vector.multi_reduction <add>, %[[LOAD]], %[[CST_2]] [2, 3] : vector<16x16x32x32xf32> to vector<16x16xf32>
- // CHECK-DAG: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[LOCAL_REDUCE]] : vector<16x16xf32> to vector<1x256xf32>
+ // CHECK-DAG: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[LOCAL_REDUCE]] : vector<16x16xf32> to vector<16x16x1x1xf32>
// CHECK-DAG: %[[ALLOCA:.*]] = memref.alloca() : memref<65536xi8, 3>
- // CHECK-DAG: %[[MEM_DESC:.*]] = xegpu.create_mem_desc %[[ALLOCA]] : memref<65536xi8, 3> -> !xegpu.mem_desc<16x1024xf32>
+ // CHECK-DAG: %[[MEM_DESC:.*]] = xegpu.create_mem_desc %[[ALLOCA]] : memref<65536xi8, 3> -> !xegpu.mem_desc<32x32x4x4xf32>
// CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
- // CHECK-DAG: %[[AFFINE1:.*]] = affine.apply #map()[%[[SGID]]]
- // CHECK-DAG: %[[AFFINE2:.*]] = affine.apply #map1()[%[[SGID]]]
- // CHECK-DAG: %[[AFFINE3:.*]] = affine.apply #map5()[%[[SGID]]]
- // CHECK-DAG: %[[AFFINE4:.*]] = affine.apply #map6()[%[[SGID]]]
- // CHECK-DAG: %[[AFFINE5:.*]] = affine.apply #map7()[%[[SGID]]]
- // CHECK-DAG: %[[AFFINE6:.*]] = affine.apply #map4()[%[[SGID]]]
- // CHECK-DAG: %[[MUL1:.*]] = arith.muli {{.*}}, %[[C1:.*]] : index
- // CHECK-DAG: %[[ADD1:.*]] = arith.addi %[[C0:.*]], %[[MUL1]] : index
- // CHECK-DAG: %[[MUL2:.*]] = arith.muli {{.*}}, %[[C4:.*]] : index
- // CHECK-DAG: %[[ROW_OFFSET:.*]] = arith.addi %[[ADD1]], %[[MUL2]] : index
- // CHECK-DAG: %[[MUL3:.*]] = arith.muli {{.*}}, %[[C1:.*]] : index
- // CHECK-DAG: %[[ADD2:.*]] = arith.addi %[[C0:.*]], %[[MUL3]] : index
- // CHECK-DAG: %[[MUL4:.*]] = arith.muli {{.*}}, %[[C2:.*]] : index
- // CHECK-DAG: %[[ADD3:.*]] = arith.addi %[[ADD2]], %[[MUL4]] : index
- // CHECK-DAG: %[[COL_OFFSET:.*]] = arith.muli %[[ADD3]], %[[C256:.*]] : index
- // CHECK-DAG: xegpu.store_matrix %[[SHAPE_CAST]], %[[MEM_DESC]][%[[ROW_OFFSET]], %[[COL_OFFSET]]] : vector<1x256xf32>, !xegpu.mem_desc<16x1024xf32>, index, index
+ // CHECK-DAG: %[[AFFINE0:.*]] = affine.apply {{#map[[:alnum:]_]*}}()[%[[SGID]]]
+ // CHECK-DAG: %[[AFFINE1:.*]] = affine.apply {{#map[[:alnum:]_]*}}()[%[[SGID]]]
+ // CHECK-DAG: %[[AFFINE2:.*]] = affine.apply {{#map[[:alnum:]_]*}}()[%[[SGID]]]
+ // CHECK-DAG: %[[AFFINE3:.*]] = affine.apply {{#map[[:alnum:]_]*}}()[%[[SGID]]]
+ // CHECK-DAG: %[[AFFINE4:.*]] = affine.apply {{#map[[:alnum:]_]*}}()[%[[SGID]]]
+ // CHECK-DAG: %[[AFFINE5:.*]] = affine.apply {{#map[[:alnum:]_]*}}()[%[[SGID]]]
+ // CHECK-DAG: %[[R0:.*]] = arith.muli %[[AFFINE0]], %[[C16_0:.*]] : index
+ // CHECK-DAG: %[[R1:.*]] = arith.muli %[[AFFINE2]], %[[C16_1:.*]] : index
+ // CHECK-DAG: %[[R2:.*]] = arith.muli %[[AFFINE4]], %[[C1_0:.*]] : index
+ // CHECK-DAG: %[[R3:.*]] = arith.muli %[[AFFINE5]], %[[C1_1:.*]] : index
+ // CHECK-DAG: xegpu.store_matrix %[[SHAPE_CAST]], %[[MEM_DESC]][%[[R0]], %[[R1]], %[[R2]], %[[R3]]] : vector<16x16x1x1xf32>, !xegpu.mem_desc<32x32x4x4xf32>, index, index, index, index
// CHECK-DAG: gpu.barrier
- // CHECK-DAG: %[[LOAD_SLM:.*]] = xegpu.load_matrix %[[MEM_DESC]][%[[C0:.*]], %[[COL_OFFSET]]] : !xegpu.mem_desc<16x1024xf32>, index, index -> vector<16x256xf32>
- // CHECK-DAG: %[[CST_3:.*]] = arith.constant dense<0.000000e+00> : vector<256xf32>
- // CHECK-DAG: %[[FINAL_REDUCE:.*]] = vector.multi_reduction <add>, %[[LOAD_SLM]], %[[CST_3]] [0] : vector<16x256xf32> to vector<256xf32>
- // CHECK-DAG: %[[SHAPE_CAST_FINAL:.*]] = vector.shape_cast %[[CST]] : vector<16x16xf32> to vector<256xf32>
- // CHECK-DAG: %[[FINAL_ADD:.*]] = arith.addf %[[FINAL_REDUCE]], %[[SHAPE_CAST_FINAL]] : vector<256xf32>
+ // CHECK-DAG: %[[L0:.*]] = arith.muli %[[AFFINE0]], %[[C16_2:.*]] : index
+ // CHECK-DAG: %[[L1:.*]] = arith.muli %[[AFFINE2]], %[[C16_3:.*]] : index
+ // CHECK-DAG: %[[L2:.*]] = arith.muli %[[AFFINE4]], %[[C0_0:.*]] : index
+ // CHECK-DAG: %[[L3:.*]] = arith.muli %[[AFFINE5]], %[[C0_1:.*]] : index
+ // CHECK-DAG: %[[LOAD_SLM:.*]] = xegpu.load_matrix %[[MEM_DESC]][%[[L0]], %[[L1]], %[[L2]], %[[L3]]] : !xegpu.mem_desc<32x32x4x4xf32>, index, index, index, index -> vector<16x16x4x4xf32>
+ // CHECK-DAG: %[[CST_3:.*]] = arith.constant dense<0.000000e+00> : vector<16x16xf32>
+ // CHECK-DAG: %[[FINAL_REDUCE:.*]] = vector.multi_reduction <add>, %[[LOAD_SLM]], %[[CST_3]] [2, 3] : vector<16x16x4x4xf32> to vector<16x16xf32>
+ // CHECK-DAG: %[[FINAL_ADD:.*]] = arith.addf %[[FINAL_REDUCE]], %[[CST]] : vector<16x16xf32>
// CHECK-DAG: gpu.return
%cst = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [2, 2, 4, 4], sg_data = [16, 16, 32, 32]>, dims = [2, 3]>} dense<0.0> : vector<32x32xf32>
%offset = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [2, 2, 4, 4], sg_data = [16, 16, 32, 32]>} dense<0> : vector<32x32x128x128xindex>
>From 0a40a853711726999072c842bfee0b0fc8208922 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Wed, 18 Feb 2026 23:03:50 +0000
Subject: [PATCH 2/3] remove debug prints
---
.../Transforms/XeGPUWgToSgDistribute.cpp | 35 +------------------
1 file changed, 1 insertion(+), 34 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 8ab0b50da03ad..a4172bdc9d136 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -1401,12 +1401,6 @@ struct WgToSgMultiDimReductionOp
SmallVector<int64_t> sgSrcShape(sgSrcType.getShape().begin(),
sgSrcType.getShape().end());
- // debug print sgDstShape
- llvm::errs() << "[WgToSgMultiDimReductionOp] sgDstShape = [";
- for (auto [i, v] : llvm::enumerate(sgDstShape))
- llvm::errs() << (i ? ", " : "") << v;
- llvm::errs() << "]\n";
-
VectorType newDstType = VectorType::get(sgDstShape, elemTy);
for (auto sgSrc : sgSrcs) {
// Create ZERO accumulator for local reduction
@@ -1457,21 +1451,6 @@ struct WgToSgMultiDimReductionOp
for (int64_t dim : reductionDims)
slmShape[dim] = originalSrcShape[dim] / sgSrcShape[dim];
- // Debug print.
- llvm::errs() << "[WgToSgMultiDimReductionOp] originalSrcShape = [";
- for (auto [i, v] : llvm::enumerate(originalSrcShape))
- llvm::errs() << (i ? ", " : "") << v;
- llvm::errs() << "], sgSrcShape = [";
- for (auto [i, v] : llvm::enumerate(sgSrcShape))
- llvm::errs() << (i ? ", " : "") << v;
- llvm::errs() << "], slmShape = [";
- for (auto [i, v] : llvm::enumerate(slmShape))
- llvm::errs() << (i ? ", " : "") << v;
- llvm::errs() << "], reductionDims = [";
- for (auto [i, v] : llvm::enumerate(reductionDims))
- llvm::errs() << (i ? ", " : "") << v;
- llvm::errs() << "]\n";
-
// Allocate SLM
auto bitWidth = elemTy.getIntOrFloatBitWidth();
auto bytesPerElement = bitWidth / 8;
@@ -1516,14 +1495,7 @@ struct WgToSgMultiDimReductionOp
Value offsetVal = arith::MulIOp::create(rewriter, loc, dimVal, strideVal);
slmStoreOffsets.push_back(offsetVal);
}
- // debug print
- llvm::errs() << "[WgToSgMultiDimReductionOp] sgIds = [";
- for (auto [i, v] : llvm::enumerate(sgIds))
- llvm::errs() << (i ? ", " : "") << v;
- llvm::errs() << "], slmStoreOffsets = [";
- for (auto [i, v] : llvm::enumerate(slmStoreOffsets))
- llvm::errs() << (i ? ", " : "") << v;
- llvm::errs() << "]\n";
+
xegpu::StoreMatrixOp::create(rewriter, loc, slmStoreData,
memDesc.getResult(), slmStoreOffsets,
/*layout=*/nullptr);
@@ -1535,11 +1507,6 @@ struct WgToSgMultiDimReductionOp
for (int64_t dim : reductionDims)
slmLoadDataShape[dim] = slmShape[dim];
- llvm::errs() << "[WgToSgMultiDimReductionOp] slmLoadDataShape = [";
- for (auto [i, v] : llvm::enumerate(slmLoadDataShape))
- llvm::errs() << (i ? ", " : "") << v;
- llvm::errs() << "]\n";
-
SmallVector<OpFoldResult> slmLoadOffsets;
for (int i = 0; i < srcVecRank; ++i) {
Value dimVal = sgIds[i];
>From c2eb9ff5f4227d0fdc8891b821fcab91652c8732 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Wed, 18 Feb 2026 23:11:14 +0000
Subject: [PATCH 3/3] polish
---
.../Transforms/XeGPUWgToSgDistribute.cpp | 46 -------------------
1 file changed, 46 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index a4172bdc9d136..5330217972bf9 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -1277,52 +1277,6 @@ static Value createAccumulator(ConversionPatternRewriter &rewriter,
return nullptr;
}
-/// This function converts multi-dimensional subgroup indices into a single
-/// linear offset. It's used to calculate memory offsets in SLM for
-/// cross-subgroup reduction coordination.
-///
-/// Parameters:
-/// - sgIds: Multi-dimensional subgroup indices (e.g., [sgId_x, sgId_y, sgId_z])
-/// - dims: Which dimensions to include in linearization (e.g., [0, 2] for x and
-/// z dims)
-/// - sgLayout: Subgroup layout sizes for each dimension (e.g., [4, 8, 2] means
-/// 4x8x2 subgroups)
-///
-/// It uses row-major linearization formula:
-/// offset = sum(sgIds[dim] * stride[dim])
-/// where stride[dim] = product of all sgLayout sizes in dimensions after
-/// 'dim'
-///
-/// Example:
-/// - sgLayout = [4, 8, 2], dims = [0, 2] (linearize x and z dimensions)
-/// - sgIds = [1, 3, 1] (subgroup at position x=1, y=3, z=1)
-/// - Calculation:
-/// * dim=0: stride=1, term = sgIds[0] * 1 = 1 * 1 = 1
-/// * dim=2: stride=sgLayout[0]=4, term = sgIds[2] * 4 = 1 * 4 = 4
-/// * linearizedOffset = 1 + 4 = 5
-///
-/// This gives us a unique linear index for each combination of subgroup
-/// positions in the specified dimensions, which is used for SLM row/column
-/// addressing.
-[[maybe_unused]] static Value
-linearizeSubgroupIndices(ConversionPatternRewriter &rewriter, Location loc,
- ArrayRef<Value> sgIds, ArrayRef<int64_t> dims,
- ArrayRef<int64_t> sgLayout) {
- Value linearizedOffset = arith::ConstantIndexOp::create(rewriter, loc, 0);
- int64_t stride = 1;
-
- for (int64_t dim : dims) {
- Value dimVal = sgIds[dim];
- Value strideVal = arith::ConstantIndexOp::create(rewriter, loc, stride);
- Value term = arith::MulIOp::create(rewriter, loc, dimVal, strideVal);
- linearizedOffset =
- arith::AddIOp::create(rewriter, loc, linearizedOffset, term);
- stride *= sgLayout[dim];
- }
-
- return linearizedOffset;
-}
-
/// This pattern transforms vector.multi_dim_reduction operations from
/// workgroup-level to subgroup-level execution with support for multiple
/// reduction dimensions.
More information about the Mlir-commits
mailing list