[Mlir-commits] [mlir] [MLIR][XeGPU] Improve workgroup to subgroup distribution pattern for multi-reduction op (PR #182178)

Jianhui Li llvmlistbot at llvm.org
Wed Feb 18 15:11:28 PST 2026


https://github.com/Jianhui-Li updated https://github.com/llvm/llvm-project/pull/182178

>From 61e382c4d7609c194ed71da159f2a0a0bea80929 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Wed, 18 Feb 2026 23:00:41 +0000
Subject: [PATCH 1/3] improve workgroup distribution pattern for
 cross-subgroup reduction using nd-slm

---
 .../Transforms/XeGPUWgToSgDistribute.cpp      | 193 +++++++++---------
 .../XeGPU/xegpu-wg-to-sg-unify-ops.mlir       | 155 +++++++-------
 2 files changed, 178 insertions(+), 170 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 66eb7fc97aa1a..8ab0b50da03ad 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -1304,10 +1304,10 @@ static Value createAccumulator(ConversionPatternRewriter &rewriter,
 /// This gives us a unique linear index for each combination of subgroup
 /// positions in the specified dimensions, which is used for SLM row/column
 /// addressing.
-static Value linearizeSubgroupIndices(ConversionPatternRewriter &rewriter,
-                                      Location loc, ArrayRef<Value> sgIds,
-                                      ArrayRef<int64_t> dims,
-                                      ArrayRef<int64_t> sgLayout) {
+[[maybe_unused]] static Value
+linearizeSubgroupIndices(ConversionPatternRewriter &rewriter, Location loc,
+                         ArrayRef<Value> sgIds, ArrayRef<int64_t> dims,
+                         ArrayRef<int64_t> sgLayout) {
   Value linearizedOffset = arith::ConstantIndexOp::create(rewriter, loc, 0);
   int64_t stride = 1;
 
@@ -1370,6 +1370,9 @@ struct WgToSgMultiDimReductionOp
       return failure();
 
     auto originalSrcShape = srcType.getShape();
+    auto originalDstShape = dstType.getShape();
+    int srcVecRank = originalSrcShape.size();
+
     xegpu::DistributeLayoutAttr layout =
         xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getResult()));
     if (!layout || !layout.isForWorkgroup())
@@ -1391,10 +1394,21 @@ struct WgToSgMultiDimReductionOp
 
     // Step 1: perform local subgroup reductions with ZERO accumulator
     SmallVector<Value> localReductions;
-    SmallVector<int64_t> sgShape =
-        getSgShapeAndCount(originalSrcShape, layout).first;
-    VectorType newDstType = VectorType::get(sgShape, elemTy);
-    for (auto sgSrc : adaptor.getSource()) {
+    SmallVector<int64_t> sgDstShape =
+        getSgShapeAndCount(originalDstShape, layout).first;
+    auto sgSrcs = adaptor.getSource();
+    auto sgSrcType = dyn_cast<VectorType>(sgSrcs.front().getType());
+    SmallVector<int64_t> sgSrcShape(sgSrcType.getShape().begin(),
+                                    sgSrcType.getShape().end());
+
+    // debug print sgDstShape
+    llvm::errs() << "[WgToSgMultiDimReductionOp] sgDstShape = [";
+    for (auto [i, v] : llvm::enumerate(sgDstShape))
+      llvm::errs() << (i ? ", " : "") << v;
+    llvm::errs() << "]\n";
+
+    VectorType newDstType = VectorType::get(sgDstShape, elemTy);
+    for (auto sgSrc : sgSrcs) {
       // Create ZERO accumulator for local reduction
       auto neutralLocalAcc =
           createAccumulator(rewriter, loc, newDstType, op.getKind());
@@ -1430,44 +1444,53 @@ struct WgToSgMultiDimReductionOp
     }
 
     // Step 2: cross-subgroup reduction using SLM
-
-    // Calculate total elements in local result
-    int64_t localElements = computeProduct(sgShape);
-
-    // Shape cast for SLM storage - store as [1, localElements]
-    SmallVector<int64_t> storeShape2D = {1, localElements};
-    VectorType storeType2D = VectorType::get(storeShape2D, elemTy);
-    auto storeShapeCast = vector::ShapeCastOp::create(
-        rewriter, loc, storeType2D, localReductions[0]);
-    Value storeData = storeShapeCast.getResult();
-
-    // Calculate SLM shape - rows for sg's in reduction dims, cols for total
-    // result elements across all subgroups in non-reduction dimensions
-    int64_t totalReductionSubgroups = 1;
-    for (int64_t dim : crossSgReductionDims) {
-      totalReductionSubgroups *= sgLayout[dim];
-    }
-
-    // Total result elements across all subgroups in non-reduction dimensions
-    int64_t totalResultElements =
-        localElements * computeProduct(sgLayout) / totalReductionSubgroups;
-
-    SmallVector<int64_t> slmShape2D = {totalReductionSubgroups,
-                                       totalResultElements};
+    auto slmStoreDataShape = sgSrcShape;
+    for (int64_t dim : reductionDims)
+      slmStoreDataShape[dim] = 1;
+    VectorType slmStoreDataType = VectorType::get(slmStoreDataShape, elemTy);
+    Value slmStoreData = vector::ShapeCastOp::create(
+        rewriter, loc, slmStoreDataType, localReductions[0]);
+
+    SmallVector<int64_t> slmShape(originalSrcShape.begin(),
+                                  originalSrcShape.end());
+    // for reduction dimension, SLM stores partial results from each subgroup
+    for (int64_t dim : reductionDims)
+      slmShape[dim] = originalSrcShape[dim] / sgSrcShape[dim];
+
+    // Debug print.
+    llvm::errs() << "[WgToSgMultiDimReductionOp] originalSrcShape = [";
+    for (auto [i, v] : llvm::enumerate(originalSrcShape))
+      llvm::errs() << (i ? ", " : "") << v;
+    llvm::errs() << "], sgSrcShape = [";
+    for (auto [i, v] : llvm::enumerate(sgSrcShape))
+      llvm::errs() << (i ? ", " : "") << v;
+    llvm::errs() << "], slmShape = [";
+    for (auto [i, v] : llvm::enumerate(slmShape))
+      llvm::errs() << (i ? ", " : "") << v;
+    llvm::errs() << "], reductionDims = [";
+    for (auto [i, v] : llvm::enumerate(reductionDims))
+      llvm::errs() << (i ? ", " : "") << v;
+    llvm::errs() << "]\n";
 
     // Allocate SLM
     auto bitWidth = elemTy.getIntOrFloatBitWidth();
     auto bytesPerElement = bitWidth / 8;
-    int64_t slmElements = slmShape2D[0] * slmShape2D[1];
-    auto slmSize = slmElements * bytesPerElement;
+    auto slmSize = computeProduct(slmShape) * bytesPerElement;
     auto slmTy = MemRefType::get({slmSize}, rewriter.getI8Type(), {}, 3);
     auto slm = memref::AllocaOp::create(rewriter, loc, slmTy);
 
-    auto memDescType = xegpu::MemDescType::get(rewriter.getContext(),
-                                               slmShape2D, elemTy, nullptr);
+    auto memDescType = xegpu::MemDescType::get(rewriter.getContext(), slmShape,
+                                               elemTy, nullptr);
     auto memDesc =
         xegpu::CreateMemDescOp::create(rewriter, loc, memDescType, slm);
 
+    // if localReductions have more than 1 result, not support
+    if (localReductions.size() > 1) {
+      return rewriter.notifyMatchFailure(
+          op,
+          "Multiple local reductions not supported in current implementation.");
+    }
+
     // Step 4: Store local results to SLM
     auto sgId = gpu::SubgroupIdOp::create(rewriter, loc,
                                           rewriter.getIndexType(), nullptr);
@@ -1484,78 +1507,66 @@ struct WgToSgMultiDimReductionOp
       return failure();
     SmallVector<Value> sgIds = *sgIdsResult;
 
-    // Row offset: linearize reduction dimension indices
-    Value rowOffsetStore = linearizeSubgroupIndices(
-        rewriter, loc, sgIds, crossSgReductionDims, sgLayout);
-
-    // Column offset: linearize non-reduction dimension indices
-    SmallVector<int64_t> nonReductionDims;
-    for (size_t i = 0; i < sgLayout.size(); ++i) {
-      if (!llvm::is_contained(reductionDims, static_cast<int64_t>(i))) {
-        nonReductionDims.push_back(static_cast<int64_t>(i));
-      }
+    SmallVector<OpFoldResult> slmStoreOffsets;
+    for (int i = 0; i < srcVecRank; ++i) {
+      Value dimVal = sgIds[i];
+      int64_t stride =
+          (llvm::is_contained(reductionDims, i)) ? 1 : sgSrcShape[i];
+      Value strideVal = arith::ConstantIndexOp::create(rewriter, loc, stride);
+      Value offsetVal = arith::MulIOp::create(rewriter, loc, dimVal, strideVal);
+      slmStoreOffsets.push_back(offsetVal);
     }
-
-    Value colOffset = linearizeSubgroupIndices(rewriter, loc, sgIds,
-                                               nonReductionDims, sgLayout);
-
-    Value localElementsVal =
-        arith::ConstantIndexOp::create(rewriter, loc, localElements);
-    colOffset =
-        arith::MulIOp::create(rewriter, loc, colOffset, localElementsVal);
-
-    SmallVector<OpFoldResult> storeOffsets2D = {rowOffsetStore, colOffset};
-
-    xegpu::StoreMatrixOp::create(rewriter, loc, storeData, memDesc.getResult(),
-                                 storeOffsets2D, /*layout=*/nullptr);
+    // debug print
+    llvm::errs() << "[WgToSgMultiDimReductionOp] sgIds = [";
+    for (auto [i, v] : llvm::enumerate(sgIds))
+      llvm::errs() << (i ? ", " : "") << v;
+    llvm::errs() << "], slmStoreOffsets = [";
+    for (auto [i, v] : llvm::enumerate(slmStoreOffsets))
+      llvm::errs() << (i ? ", " : "") << v;
+    llvm::errs() << "]\n";
+    xegpu::StoreMatrixOp::create(rewriter, loc, slmStoreData,
+                                 memDesc.getResult(), slmStoreOffsets,
+                                 /*layout=*/nullptr);
 
     gpu::BarrierOp::create(rewriter, loc);
 
     // Step 5: Load from SLM for final reduction
-    SmallVector<int64_t> loadShape2D = {totalReductionSubgroups, localElements};
-    VectorType loadType2D = VectorType::get(loadShape2D, elemTy);
-
-    // Load offsets - each subgroup loads its column based on non-reduction
-    // position
-    Value rowOffsetLoad = arith::ConstantIndexOp::create(rewriter, loc, 0);
-
-    SmallVector<OpFoldResult> loadOffsets2D = {rowOffsetLoad, colOffset};
+    SmallVector<int64_t> slmLoadDataShape(sgSrcShape.begin(), sgSrcShape.end());
+    for (int64_t dim : reductionDims)
+      slmLoadDataShape[dim] = slmShape[dim];
+
+    llvm::errs() << "[WgToSgMultiDimReductionOp] slmLoadDataShape = [";
+    for (auto [i, v] : llvm::enumerate(slmLoadDataShape))
+      llvm::errs() << (i ? ", " : "") << v;
+    llvm::errs() << "]\n";
+
+    SmallVector<OpFoldResult> slmLoadOffsets;
+    for (int i = 0; i < srcVecRank; ++i) {
+      Value dimVal = sgIds[i];
+      int64_t stride =
+          (llvm::is_contained(reductionDims, i)) ? 0 : sgSrcShape[i];
+      Value strideVal = arith::ConstantIndexOp::create(rewriter, loc, stride);
+      Value offsetVal = arith::MulIOp::create(rewriter, loc, dimVal, strideVal);
+      slmLoadOffsets.push_back(offsetVal);
+    }
 
-    auto loadOp = xegpu::LoadMatrixOp::create(
-        rewriter, loc, loadType2D, memDesc.getResult(), loadOffsets2D,
+    VectorType slmLoadType = VectorType::get(slmLoadDataShape, elemTy);
+    auto slmLoadOp = xegpu::LoadMatrixOp::create(
+        rewriter, loc, slmLoadType, memDesc.getResult(), slmLoadOffsets,
         /*layout=*/nullptr);
 
     // Step 6: Perform final reduction with ZERO accumulator
-    SmallVector<int64_t> finalReductionDims = {0};
-    SmallVector<int64_t> finalResultShape = {localElements};
-    VectorType finalResultType = VectorType::get(finalResultShape, elemTy);
-
     auto neutralFinalAcc =
-        createAccumulator(rewriter, loc, finalResultType, op.getKind());
+        createAccumulator(rewriter, loc, newDstType, op.getKind());
 
     auto finalReduce = vector::MultiDimReductionOp::create(
-        rewriter, loc, finalResultType, op.getKind(), loadOp.getResult(),
-        neutralFinalAcc, finalReductionDims);
+        rewriter, loc, newDstType, op.getKind(), slmLoadOp.getResult(),
+        neutralFinalAcc, reductionDims);
 
     // Step 7: Add the original accumulator at the end
     Value originalAcc = adaptor.getAcc()[0];
     Value accToAdd = originalAcc;
 
-    // Handle shape mismatch by shape casting
-    if (originalAcc.getType() != finalReduce.getResult().getType()) {
-      auto originalAccType = cast<VectorType>(originalAcc.getType());
-      auto finalResultType =
-          cast<VectorType>(finalReduce.getResult().getType());
-
-      // If they have the same number of elements, just shape cast
-      if (originalAccType.getNumElements() ==
-          finalResultType.getNumElements()) {
-        auto shapeCast = vector::ShapeCastOp::create(
-            rewriter, loc, finalResultType, originalAcc);
-        accToAdd = shapeCast.getResult();
-      }
-    }
-
     auto finalResult = vector::makeArithReduction(
         rewriter, loc, op.getKind(), finalReduce.getResult(), accToAdd);
 
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index e2e94c5f0300f..9407f7f2357a2 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -646,30 +646,28 @@ gpu.module @test_distribution {
     // CHECK-DAG: %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<1x32xf32>
     // CHECK-DAG: %[[CST_0:.*]] = arith.constant dense<0> : vector<1x1x32xindex>
     // CHECK-DAG: %[[CST_1:.*]] = arith.constant dense<true> : vector<1x1x32xi1>
-    // CHECK-DAG: %[[LOAD:.*]] = xegpu.load %{{.*}}[%[[CST_0]]], %[[CST_1]] <{chunk_size = 1 : i64}> : memref<?xf32>, vector<1x1x32xindex>, vector<1x1x32xi1> -> vector<1x1x32xf32>
+    // CHECK-DAG: %[[LOAD:.*]] = xegpu.load %[[ARG0:.*]][%[[CST_0]]], %[[CST_1]] <{chunk_size = 1 : i64}> : memref<?xf32>, vector<1x1x32xindex>, vector<1x1x32xi1> -> vector<1x1x32xf32>
     // CHECK-DAG: %[[CST_2:.*]] = arith.constant dense<0.000000e+00> : vector<1x32xf32>
     // CHECK-DAG: %[[LOCAL_REDUCE:.*]] = vector.multi_reduction <add>, %[[LOAD]], %[[CST_2]] [1] : vector<1x1x32xf32> to vector<1x32xf32>
-    // CHECK-DAG: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[LOCAL_REDUCE]] : vector<1x32xf32> to vector<1x32xf32>
+    // CHECK-DAG: %[[CAST:.*]] = vector.shape_cast %[[LOCAL_REDUCE]] : vector<1x32xf32> to vector<1x1x32xf32>
     // CHECK-DAG: %[[ALLOCA:.*]] = memref.alloca() : memref<4096xi8, 3>
-    // CHECK-DAG: %[[MEM_DESC:.*]] = xegpu.create_mem_desc %[[ALLOCA]] : memref<4096xi8, 3> -> !xegpu.mem_desc<32x32xf32>
+    // CHECK-DAG: %[[MEM_DESC:.*]] = xegpu.create_mem_desc %[[ALLOCA]] : memref<4096xi8, 3> -> !xegpu.mem_desc<1x32x32xf32>
     // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
-    // CHECK-DAG: %[[AFFINE1:.*]] = affine.apply #map()[%[[SGID]]]
-    // CHECK-DAG: %[[AFFINE2:.*]] = affine.apply #map1()[%[[SGID]]]
-    // CHECK-DAG: %[[AFFINE3:.*]] = affine.apply #map2()[%[[SGID]]]
-    // CHECK-DAG: %[[MUL1:.*]] = arith.muli {{.*}}, %[[C1:.*]] : index
-    // CHECK-DAG: %[[ROW_OFFSET:.*]] = arith.addi %[[C0:.*]], %[[MUL1]] : index
-    // CHECK-DAG: %[[MUL2:.*]] = arith.muli %[[AFFINE1]], %[[C1:.*]] : index
-    // CHECK-DAG: %[[ADD1:.*]] = arith.addi %[[C0:.*]], %[[MUL2]] : index
-    // CHECK-DAG: %[[MUL3:.*]] = arith.muli %[[AFFINE3]], %[[C1:.*]] : index
-    // CHECK-DAG: %[[ADD2:.*]] = arith.addi %[[ADD1]], %[[MUL3]] : index
-    // CHECK-DAG: %[[COL_OFFSET:.*]] = arith.muli %[[ADD2]], %[[C32:.*]] : index
-    // CHECK-DAG: xegpu.store_matrix %[[SHAPE_CAST]], %[[MEM_DESC]][%[[ROW_OFFSET]], %[[COL_OFFSET]]] : vector<1x32xf32>, !xegpu.mem_desc<32x32xf32>, index, index
+    // CHECK-DAG: %[[AFF0:.*]] = affine.apply #map()[%[[SGID]]]
+    // CHECK-DAG: %[[AFF1:.*]] = affine.apply #map1()[%[[SGID]]]
+    // CHECK-DAG: %[[AFF2:.*]] = affine.apply #map2()[%[[SGID]]]
+    // CHECK-DAG: %[[ROW:.*]] = arith.muli %[[AFF0]], %[[C1A:.*]] : index
+    // CHECK-DAG: %[[COL0:.*]] = arith.muli %[[AFF1:.*]], %[[C1B:.*]] : index
+    // CHECK-DAG: %[[COL1:.*]] = arith.muli %[[AFF2]], %[[C32A:.*]] : index
+    // CHECK-DAG: xegpu.store_matrix %[[CAST]], %[[MEM_DESC]][%[[ROW]], %[[COL0]], %[[COL1]]] : vector<1x1x32xf32>, !xegpu.mem_desc<1x32x32xf32>, index, index, index
     // CHECK-DAG: gpu.barrier
-    // CHECK-DAG: %[[LOAD_SLM:.*]] = xegpu.load_matrix %[[MEM_DESC]][%[[C0:.*]], %[[COL_OFFSET]]] : !xegpu.mem_desc<32x32xf32>, index, index -> vector<32x32xf32>
-    // CHECK-DAG: %[[CST_3:.*]] = arith.constant dense<0.000000e+00> : vector<32xf32>
-    // CHECK-DAG: %[[FINAL_REDUCE:.*]] = vector.multi_reduction <add>, %[[LOAD_SLM]], %[[CST_3]] [0] : vector<32x32xf32> to vector<32xf32>
-    // CHECK-DAG: %[[SHAPE_CAST_FINAL:.*]] = vector.shape_cast %[[CST]] : vector<1x32xf32> to vector<32xf32>
-    // CHECK-DAG: %{{.*}} = arith.addf %[[FINAL_REDUCE]], %[[SHAPE_CAST_FINAL]] : vector<32xf32>
+    // CHECK-DAG: %[[ROW_L:.*]] = arith.muli %[[AFF0]], %[[C1C:.*]] : index
+    // CHECK-DAG: %[[COL0_L:.*]] = arith.muli %[[AFF1]], %[[C0:.*]] : index
+    // CHECK-DAG: %[[COL1_L:.*]] = arith.muli %[[AFF2]], %[[C32B:.*]] : index
+    // CHECK-DAG: %[[LOAD_SLM:.*]] = xegpu.load_matrix %[[MEM_DESC]][%[[ROW_L]], %[[COL0_L]], %[[COL1_L]]] : !xegpu.mem_desc<1x32x32xf32>, index, index, index -> vector<1x32x32xf32>
+    // CHECK-DAG: %[[CST_3:.*]] = arith.constant dense<0.000000e+00> : vector<1x32xf32>
+    // CHECK-DAG: %[[FINAL_REDUCE:.*]] = vector.multi_reduction <add>, %[[LOAD_SLM]], %[[CST_3]] [1] : vector<1x32x32xf32> to vector<1x32xf32>
+    // CHECK-DAG: %[[ADD:.*]] = arith.addf %[[FINAL_REDUCE]], %[[CST]] : vector<1x32xf32>
     // CHECK-DAG: gpu.return
     %cst_3 = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [1, 32, 1], sg_data = [1, 1, 32]>, dims = [1]>} dense<1.0> : vector<1x32xf32>
     %offset = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [1, 32, 1], sg_data = [1, 1, 32]>} dense<0> : vector<1x32x32xindex>
@@ -688,7 +686,7 @@ gpu.module @test_distribution {
     // CHECK-DAG: %[[DIV1:.*]] = arith.divui %[[SGID]], %[[C4:.*]] : index
     // CHECK-DAG: %[[REM2:.*]] = arith.remui %[[DIV1]], %[[C8:.*]] : index
     // CHECK-DAG: %[[MUL1:.*]] = arith.muli %[[REM2]], %[[C32:.*]] : index
-    // CHECK-DAG: %[[MUL2:.*]] = arith.muli %[[REM1]], %[[C32:.*]] : index
+    // CHECK-DAG: %[[MUL2:.*]] = arith.muli %[[REM1]], %[[C32_0:.*]] : index
     // CHECK-DAG: %[[REM3:.*]] = arith.remui %[[MUL1]], %[[C256:.*]] : index
     // CHECK-DAG: %[[REM4:.*]] = arith.remui %[[MUL2]], %[[C128:.*]] : index
     // CHECK-DAG: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%[[REM3]], %[[REM4]]] : memref<256x128xf32> -> !xegpu.tensor_desc<32x32xf32>
@@ -699,19 +697,18 @@ gpu.module @test_distribution {
     // CHECK-DAG: %[[ALLOCA:.*]] = memref.alloca() : memref<4096xi8, 3>
     // CHECK-DAG: %[[MEM_DESC:.*]] = xegpu.create_mem_desc %[[ALLOCA]] : memref<4096xi8, 3> -> !xegpu.mem_desc<8x128xf32>
     // CHECK-DAG: %[[SGID2:.*]] = gpu.subgroup_id : index
-    // CHECK-DAG: %[[AFFINE1:.*]] = affine.apply #map3()[%[[SGID2]]]
-    // CHECK-DAG: %[[AFFINE2:.*]] = affine.apply #map4()[%[[SGID2]]]
-    // CHECK-DAG: %[[MUL3:.*]] = arith.muli {{.*}}, %[[C1:.*]] : index
-    // CHECK-DAG: %[[ROW_OFFSET:.*]] = arith.addi %[[C0:.*]], %[[MUL3]] : index
-    // CHECK-DAG: %[[MUL4:.*]] = arith.muli {{.*}}, %[[C1:.*]] : index
-    // CHECK-DAG: %[[ADD1:.*]] = arith.addi %[[C0:.*]], %[[MUL4]] : index
-    // CHECK-DAG: %[[COL_OFFSET:.*]] = arith.muli %[[ADD1]], %[[C32:.*]] : index
+    // CHECK-DAG: %[[AFFINE1:.*]] = affine.apply {{#map[[:alnum:]_]*}}()[%[[SGID2]]]
+    // CHECK-DAG: %[[AFFINE2:.*]] = affine.apply {{#map[[:alnum:]_]*}}()[%[[SGID2]]]
+    // CHECK-DAG: %[[ROW_OFFSET:.*]] = arith.muli %[[AFFINE1]], %[[C1:.*]] : index
+    // CHECK-DAG: %[[COL_OFFSET:.*]] = arith.muli %[[AFFINE2]], %[[C32_1:.*]] : index
     // CHECK-DAG: xegpu.store_matrix %[[SHAPE_CAST]], %[[MEM_DESC]][%[[ROW_OFFSET]], %[[COL_OFFSET]]] : vector<1x32xf32>, !xegpu.mem_desc<8x128xf32>, index, index
     // CHECK-DAG: gpu.barrier
-    // CHECK-DAG: %[[LOAD_SLM:.*]] = xegpu.load_matrix %[[MEM_DESC]][%[[C0:.*]], %[[COL_OFFSET]]] : !xegpu.mem_desc<8x128xf32>, index, index -> vector<8x32xf32>
+    // CHECK-DAG: %[[ZERO_ROW:.*]] = arith.muli %[[AFFINE1]], %[[C0:.*]] : index
+    // CHECK-DAG: %[[COL_OFFSET2:.*]] = arith.muli %[[AFFINE2]], %[[C32_2:.*]] : index
+    // CHECK-DAG: %[[LOAD_SLM:.*]] = xegpu.load_matrix %[[MEM_DESC]][%[[ZERO_ROW]], %[[COL_OFFSET2]]] : !xegpu.mem_desc<8x128xf32>, index, index -> vector<8x32xf32>
     // CHECK-DAG: %[[CST_CROSS_SG_1:.*]] = arith.constant dense<0.000000e+00> : vector<32xf32>
     // CHECK-DAG: %[[FINAL_REDUCE:.*]] = vector.multi_reduction <add>, %[[LOAD_SLM]], %[[CST_CROSS_SG_1]] [0] : vector<8x32xf32> to vector<32xf32>
-    // CHECK-DAG: arith.addf %[[FINAL_REDUCE]], %[[CST:.*]] : vector<32xf32>
+    // CHECK-DAG: arith.addf %[[FINAL_REDUCE]], %[[CST]] : vector<32xf32>
     %cst = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>, dims = [0]>} dense<0.0> : vector<128xf32>
     %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
       -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>>
@@ -732,36 +729,38 @@ gpu.module @test_distribution {
     // CHECK-DAG: %[[LOAD:.*]] = xegpu.load %{{.*}}[%[[CST_0]]], %[[CST_1]] <{chunk_size = 1 : i64}> : memref<?xf32>, vector<1x1x32x32xindex>, vector<1x1x32x32xi1> -> vector<1x1x32x32xf32>
     // CHECK-DAG: %[[CST_2:.*]] = arith.constant dense<0.000000e+00> : vector<1x1xf32>
     // CHECK-DAG: %[[LOCAL_REDUCE:.*]] = vector.multi_reduction <add>, %[[LOAD]], %[[CST_2]] [2, 3] : vector<1x1x32x32xf32> to vector<1x1xf32>
-    // CHECK-DAG: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[LOCAL_REDUCE]] : vector<1x1xf32> to vector<1x1xf32>
+    // CHECK-DAG: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[LOCAL_REDUCE]] : vector<1x1xf32> to vector<1x1x1x1xf32>
     // CHECK-DAG: %[[ALLOCA:.*]] = memref.alloca() : memref<256xi8, 3>
-    // CHECK-DAG: %[[MEM_DESC:.*]] = xegpu.create_mem_desc %[[ALLOCA]] : memref<256xi8, 3> -> !xegpu.mem_desc<16x4xf32>
+    // CHECK-DAG: %[[MEM_DESC:.*]] = xegpu.create_mem_desc %[[ALLOCA]] : memref<256xi8, 3> -> !xegpu.mem_desc<2x2x4x4xf32>
     // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
-    // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
-    // CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
-    // CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index
-    // CHECK-DAG: %[[C32:.*]] = arith.constant 32 : index
-    // CHECK-DAG: %[[AFFINE1:.*]] = affine.apply #map()[%[[SGID]]]
-    // CHECK-DAG: %[[AFFINE2:.*]] = affine.apply #map1()[%[[SGID]]]
-    // CHECK-DAG: %[[AFFINE3:.*]] = affine.apply #map5()[%[[SGID]]]
-    // CHECK-DAG: %[[AFFINE4:.*]] = affine.apply #map6()[%[[SGID]]]
-    // CHECK-DAG: %[[AFFINE5:.*]] = affine.apply #map7()[%[[SGID]]]
-    // CHECK-DAG: %[[AFFINE6:.*]] = affine.apply #map4()[%[[SGID]]]
-    // CHECK-DAG: %[[MUL1:.*]] = arith.muli {{.*}}, %[[C1:.*]] : index
-    // CHECK-DAG: %[[ADD1:.*]] = arith.addi %[[C0:.*]], %[[MUL1]] : index
-    // CHECK-DAG: %[[MUL2:.*]] = arith.muli {{.*}}, %[[C4:.*]] : index
-    // CHECK-DAG: %[[ROW_OFFSET:.*]] = arith.addi %[[ADD1]], %[[MUL2]] : index
-    // CHECK-DAG: %[[MUL3:.*]] = arith.muli {{.*}}, %[[C1:.*]] : index
-    // CHECK-DAG: %[[ADD2:.*]] = arith.addi %[[C0:.*]], %[[MUL3]] : index
-    // CHECK-DAG: %[[MUL4:.*]] = arith.muli {{.*}}, %[[C2:.*]] : index
-    // CHECK-DAG: %[[ADD3:.*]] = arith.addi %[[ADD2]], %[[MUL4]] : index
-    // CHECK-DAG: %[[COL_OFFSET:.*]] = arith.muli %[[ADD3]], %[[C1:.*]] : index
-    // CHECK-DAG: xegpu.store_matrix %[[SHAPE_CAST]], %[[MEM_DESC]][%[[ROW_OFFSET]], %[[COL_OFFSET]]] : vector<1x1xf32>, !xegpu.mem_desc<16x4xf32>, index, index
+    // CHECK-DAG: %[[AFFINE0:.*]] = affine.apply {{#map[[:alnum:]_]*}}()[%[[SGID]]]
+    // CHECK-DAG: %[[AFFINE1:.*]] = affine.apply {{#map[[:alnum:]_]*}}()[%[[SGID]]]
+    // CHECK-DAG: %[[AFFINE2:.*]] = affine.apply {{#map[[:alnum:]_]*}}()[%[[SGID]]]
+    // CHECK-DAG: %[[AFFINE3:.*]] = affine.apply {{#map[[:alnum:]_]*}}()[%[[SGID]]]
+    // CHECK-DAG: %[[AFFINE4:.*]] = affine.apply {{#map[[:alnum:]_]*}}()[%[[SGID]]]
+    // CHECK-DAG: %[[AFFINE5:.*]] = affine.apply {{#map[[:alnum:]_]*}}()[%[[SGID]]]
+    // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+    // CHECK-DAG: %[[I0:.*]] = arith.muli %[[AFFINE0]], %[[C1]] : index
+    // CHECK-DAG: %[[C1_0:.*]] = arith.constant 1 : index
+    // CHECK-DAG: %[[I1:.*]] = arith.muli %[[AFFINE2]], %[[C1_0]] : index
+    // CHECK-DAG: %[[C1_1:.*]] = arith.constant 1 : index
+    // CHECK-DAG: %[[I2:.*]] = arith.muli %[[AFFINE4]], %[[C1_1]] : index
+    // CHECK-DAG: %[[C1_2:.*]] = arith.constant 1 : index
+    // CHECK-DAG: %[[I3:.*]] = arith.muli %[[AFFINE5]], %[[C1_2]] : index
+    // CHECK-DAG: xegpu.store_matrix %[[SHAPE_CAST]], %[[MEM_DESC]][%[[I0]], %[[I1]], %[[I2]], %[[I3]]] : vector<1x1x1x1xf32>, !xegpu.mem_desc<2x2x4x4xf32>, index, index, index, index
     // CHECK-DAG: gpu.barrier
-    // CHECK-DAG: %[[LOAD_SLM:.*]] = xegpu.load_matrix %[[MEM_DESC]][%[[C0:.*]], %[[COL_OFFSET]]] : !xegpu.mem_desc<16x4xf32>, index, index -> vector<16x1xf32>
-    // CHECK-DAG: %[[CST_3:.*]] = arith.constant dense<0.000000e+00> : vector<1xf32>
-    // CHECK-DAG: %[[FINAL_REDUCE:.*]] = vector.multi_reduction <add>, %[[LOAD_SLM]], %[[CST_3]] [0] : vector<16x1xf32> to vector<1xf32>
-    // CHECK-DAG: %[[SHAPE_CAST_FINAL:.*]] = vector.shape_cast %[[CST]] : vector<1x1xf32> to vector<1xf32>
-    // CHECK-DAG: %[[FINAL_ADD:.*]] = arith.addf %[[FINAL_REDUCE]], %[[SHAPE_CAST_FINAL]] : vector<1xf32>
+    // CHECK-DAG: %[[C1_3:.*]] = arith.constant 1 : index
+    // CHECK-DAG: %[[L0:.*]] = arith.muli %[[AFFINE0]], %[[C1_3]] : index
+    // CHECK-DAG: %[[C1_4:.*]] = arith.constant 1 : index
+    // CHECK-DAG: %[[L1:.*]] = arith.muli %[[AFFINE2]], %[[C1_4]] : index
+    // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+    // CHECK-DAG: %[[L2:.*]] = arith.muli %[[AFFINE4]], %[[C0]] : index
+    // CHECK-DAG: %[[C0_0:.*]] = arith.constant 0 : index
+    // CHECK-DAG: %[[L3:.*]] = arith.muli %[[AFFINE5]], %[[C0_0]] : index
+    // CHECK-DAG: %[[LOAD_SLM:.*]] = xegpu.load_matrix %[[MEM_DESC]][%[[L0]], %[[L1]], %[[L2]], %[[L3]]] : !xegpu.mem_desc<2x2x4x4xf32>, index, index, index, index -> vector<1x1x4x4xf32>
+    // CHECK-DAG: %[[CST_3:.*]] = arith.constant dense<0.000000e+00> : vector<1x1xf32>
+    // CHECK-DAG: %[[FINAL_REDUCE:.*]] = vector.multi_reduction <add>, %[[LOAD_SLM]], %[[CST_3]] [2, 3] : vector<1x1x4x4xf32> to vector<1x1xf32>
+    // CHECK-DAG: %[[FINAL_ADD:.*]] = arith.addf %[[FINAL_REDUCE]], %[[CST]] : vector<1x1xf32>
     // CHECK-DAG: gpu.return
     %cst = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [2, 2, 4, 4], sg_data = [1, 1, 32, 32]>, dims = [2, 3]>} dense<0.0> : vector<2x2xf32>
     %offset = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [2, 2, 4, 4], sg_data = [1, 1, 32, 32]>} dense<0> : vector<2x2x128x128xindex>
@@ -780,32 +779,30 @@ gpu.module @test_distribution {
     // CHECK-DAG: %[[LOAD:.*]] = xegpu.load %[[ARG0]][%[[CST_0]]], %[[CST_1]] <{chunk_size = 1 : i64}> : memref<?xf32>, vector<16x16x32x32xindex>, vector<16x16x32x32xi1> -> vector<16x16x32x32xf32>
     // CHECK-DAG: %[[CST_2:.*]] = arith.constant dense<0.000000e+00> : vector<16x16xf32>
     // CHECK-DAG: %[[LOCAL_REDUCE:.*]] = vector.multi_reduction <add>, %[[LOAD]], %[[CST_2]] [2, 3] : vector<16x16x32x32xf32> to vector<16x16xf32>
-    // CHECK-DAG: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[LOCAL_REDUCE]] : vector<16x16xf32> to vector<1x256xf32>
+    // CHECK-DAG: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[LOCAL_REDUCE]] : vector<16x16xf32> to vector<16x16x1x1xf32>
     // CHECK-DAG: %[[ALLOCA:.*]] = memref.alloca() : memref<65536xi8, 3>
-    // CHECK-DAG: %[[MEM_DESC:.*]] = xegpu.create_mem_desc %[[ALLOCA]] : memref<65536xi8, 3> -> !xegpu.mem_desc<16x1024xf32>
+    // CHECK-DAG: %[[MEM_DESC:.*]] = xegpu.create_mem_desc %[[ALLOCA]] : memref<65536xi8, 3> -> !xegpu.mem_desc<32x32x4x4xf32>
     // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
-    // CHECK-DAG: %[[AFFINE1:.*]] = affine.apply #map()[%[[SGID]]]
-    // CHECK-DAG: %[[AFFINE2:.*]] = affine.apply #map1()[%[[SGID]]]
-    // CHECK-DAG: %[[AFFINE3:.*]] = affine.apply #map5()[%[[SGID]]]
-    // CHECK-DAG: %[[AFFINE4:.*]] = affine.apply #map6()[%[[SGID]]]
-    // CHECK-DAG: %[[AFFINE5:.*]] = affine.apply #map7()[%[[SGID]]]
-    // CHECK-DAG: %[[AFFINE6:.*]] = affine.apply #map4()[%[[SGID]]]
-    // CHECK-DAG: %[[MUL1:.*]] = arith.muli {{.*}}, %[[C1:.*]] : index
-    // CHECK-DAG: %[[ADD1:.*]] = arith.addi %[[C0:.*]], %[[MUL1]] : index
-    // CHECK-DAG: %[[MUL2:.*]] = arith.muli {{.*}}, %[[C4:.*]] : index
-    // CHECK-DAG: %[[ROW_OFFSET:.*]] = arith.addi %[[ADD1]], %[[MUL2]] : index
-    // CHECK-DAG: %[[MUL3:.*]] = arith.muli {{.*}}, %[[C1:.*]] : index
-    // CHECK-DAG: %[[ADD2:.*]] = arith.addi %[[C0:.*]], %[[MUL3]] : index
-    // CHECK-DAG: %[[MUL4:.*]] = arith.muli {{.*}}, %[[C2:.*]] : index
-    // CHECK-DAG: %[[ADD3:.*]] = arith.addi %[[ADD2]], %[[MUL4]] : index
-    // CHECK-DAG: %[[COL_OFFSET:.*]] = arith.muli %[[ADD3]], %[[C256:.*]] : index
-    // CHECK-DAG: xegpu.store_matrix %[[SHAPE_CAST]], %[[MEM_DESC]][%[[ROW_OFFSET]], %[[COL_OFFSET]]] : vector<1x256xf32>, !xegpu.mem_desc<16x1024xf32>, index, index
+    // CHECK-DAG: %[[AFFINE0:.*]] = affine.apply {{#map[[:alnum:]_]*}}()[%[[SGID]]]
+    // CHECK-DAG: %[[AFFINE1:.*]] = affine.apply {{#map[[:alnum:]_]*}}()[%[[SGID]]]
+    // CHECK-DAG: %[[AFFINE2:.*]] = affine.apply {{#map[[:alnum:]_]*}}()[%[[SGID]]]
+    // CHECK-DAG: %[[AFFINE3:.*]] = affine.apply {{#map[[:alnum:]_]*}}()[%[[SGID]]]
+    // CHECK-DAG: %[[AFFINE4:.*]] = affine.apply {{#map[[:alnum:]_]*}}()[%[[SGID]]]
+    // CHECK-DAG: %[[AFFINE5:.*]] = affine.apply {{#map[[:alnum:]_]*}}()[%[[SGID]]]
+    // CHECK-DAG: %[[R0:.*]] = arith.muli %[[AFFINE0]], %[[C16_0:.*]] : index
+    // CHECK-DAG: %[[R1:.*]] = arith.muli %[[AFFINE2]], %[[C16_1:.*]] : index
+    // CHECK-DAG: %[[R2:.*]] = arith.muli %[[AFFINE4]], %[[C1_0:.*]] : index
+    // CHECK-DAG: %[[R3:.*]] = arith.muli %[[AFFINE5]], %[[C1_1:.*]] : index
+    // CHECK-DAG: xegpu.store_matrix %[[SHAPE_CAST]], %[[MEM_DESC]][%[[R0]], %[[R1]], %[[R2]], %[[R3]]] : vector<16x16x1x1xf32>, !xegpu.mem_desc<32x32x4x4xf32>, index, index, index, index
     // CHECK-DAG: gpu.barrier
-    // CHECK-DAG: %[[LOAD_SLM:.*]] = xegpu.load_matrix %[[MEM_DESC]][%[[C0:.*]], %[[COL_OFFSET]]] : !xegpu.mem_desc<16x1024xf32>, index, index -> vector<16x256xf32>
-    // CHECK-DAG: %[[CST_3:.*]] = arith.constant dense<0.000000e+00> : vector<256xf32>
-    // CHECK-DAG: %[[FINAL_REDUCE:.*]] = vector.multi_reduction <add>, %[[LOAD_SLM]], %[[CST_3]] [0] : vector<16x256xf32> to vector<256xf32>
-    // CHECK-DAG: %[[SHAPE_CAST_FINAL:.*]] = vector.shape_cast %[[CST]] : vector<16x16xf32> to vector<256xf32>
-    // CHECK-DAG: %[[FINAL_ADD:.*]] = arith.addf %[[FINAL_REDUCE]], %[[SHAPE_CAST_FINAL]] : vector<256xf32>
+    // CHECK-DAG: %[[L0:.*]] = arith.muli %[[AFFINE0]], %[[C16_2:.*]] : index
+    // CHECK-DAG: %[[L1:.*]] = arith.muli %[[AFFINE2]], %[[C16_3:.*]] : index
+    // CHECK-DAG: %[[L2:.*]] = arith.muli %[[AFFINE4]], %[[C0_0:.*]] : index
+    // CHECK-DAG: %[[L3:.*]] = arith.muli %[[AFFINE5]], %[[C0_1:.*]] : index
+    // CHECK-DAG: %[[LOAD_SLM:.*]] = xegpu.load_matrix %[[MEM_DESC]][%[[L0]], %[[L1]], %[[L2]], %[[L3]]] : !xegpu.mem_desc<32x32x4x4xf32>, index, index, index, index -> vector<16x16x4x4xf32>
+    // CHECK-DAG: %[[CST_3:.*]] = arith.constant dense<0.000000e+00> : vector<16x16xf32>
+    // CHECK-DAG: %[[FINAL_REDUCE:.*]] = vector.multi_reduction <add>, %[[LOAD_SLM]], %[[CST_3]] [2, 3] : vector<16x16x4x4xf32> to vector<16x16xf32>
+    // CHECK-DAG: %[[FINAL_ADD:.*]] = arith.addf %[[FINAL_REDUCE]], %[[CST]] : vector<16x16xf32>
     // CHECK-DAG: gpu.return
     %cst = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [2, 2, 4, 4], sg_data = [16, 16, 32, 32]>, dims = [2, 3]>} dense<0.0> : vector<32x32xf32>
     %offset = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [2, 2, 4, 4], sg_data = [16, 16, 32, 32]>} dense<0> : vector<32x32x128x128xindex>

>From 0a40a853711726999072c842bfee0b0fc8208922 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Wed, 18 Feb 2026 23:03:50 +0000
Subject: [PATCH 2/3] remove debug print

---
 .../Transforms/XeGPUWgToSgDistribute.cpp      | 35 +------------------
 1 file changed, 1 insertion(+), 34 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 8ab0b50da03ad..a4172bdc9d136 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -1401,12 +1401,6 @@ struct WgToSgMultiDimReductionOp
     SmallVector<int64_t> sgSrcShape(sgSrcType.getShape().begin(),
                                     sgSrcType.getShape().end());
 
-    // debug print sgDstShape
-    llvm::errs() << "[WgToSgMultiDimReductionOp] sgDstShape = [";
-    for (auto [i, v] : llvm::enumerate(sgDstShape))
-      llvm::errs() << (i ? ", " : "") << v;
-    llvm::errs() << "]\n";
-
     VectorType newDstType = VectorType::get(sgDstShape, elemTy);
     for (auto sgSrc : sgSrcs) {
       // Create ZERO accumulator for local reduction
@@ -1457,21 +1451,6 @@ struct WgToSgMultiDimReductionOp
     for (int64_t dim : reductionDims)
       slmShape[dim] = originalSrcShape[dim] / sgSrcShape[dim];
 
-    // Debug print.
-    llvm::errs() << "[WgToSgMultiDimReductionOp] originalSrcShape = [";
-    for (auto [i, v] : llvm::enumerate(originalSrcShape))
-      llvm::errs() << (i ? ", " : "") << v;
-    llvm::errs() << "], sgSrcShape = [";
-    for (auto [i, v] : llvm::enumerate(sgSrcShape))
-      llvm::errs() << (i ? ", " : "") << v;
-    llvm::errs() << "], slmShape = [";
-    for (auto [i, v] : llvm::enumerate(slmShape))
-      llvm::errs() << (i ? ", " : "") << v;
-    llvm::errs() << "], reductionDims = [";
-    for (auto [i, v] : llvm::enumerate(reductionDims))
-      llvm::errs() << (i ? ", " : "") << v;
-    llvm::errs() << "]\n";
-
     // Allocate SLM
     auto bitWidth = elemTy.getIntOrFloatBitWidth();
     auto bytesPerElement = bitWidth / 8;
@@ -1516,14 +1495,7 @@ struct WgToSgMultiDimReductionOp
       Value offsetVal = arith::MulIOp::create(rewriter, loc, dimVal, strideVal);
       slmStoreOffsets.push_back(offsetVal);
     }
-    // debug print
-    llvm::errs() << "[WgToSgMultiDimReductionOp] sgIds = [";
-    for (auto [i, v] : llvm::enumerate(sgIds))
-      llvm::errs() << (i ? ", " : "") << v;
-    llvm::errs() << "], slmStoreOffsets = [";
-    for (auto [i, v] : llvm::enumerate(slmStoreOffsets))
-      llvm::errs() << (i ? ", " : "") << v;
-    llvm::errs() << "]\n";
+
     xegpu::StoreMatrixOp::create(rewriter, loc, slmStoreData,
                                  memDesc.getResult(), slmStoreOffsets,
                                  /*layout=*/nullptr);
@@ -1535,11 +1507,6 @@ struct WgToSgMultiDimReductionOp
     for (int64_t dim : reductionDims)
       slmLoadDataShape[dim] = slmShape[dim];
 
-    llvm::errs() << "[WgToSgMultiDimReductionOp] slmLoadDataShape = [";
-    for (auto [i, v] : llvm::enumerate(slmLoadDataShape))
-      llvm::errs() << (i ? ", " : "") << v;
-    llvm::errs() << "]\n";
-
     SmallVector<OpFoldResult> slmLoadOffsets;
     for (int i = 0; i < srcVecRank; ++i) {
       Value dimVal = sgIds[i];

>From c2eb9ff5f4227d0fdc8891b821fcab91652c8732 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Wed, 18 Feb 2026 23:11:14 +0000
Subject: [PATCH 3/3] remove unused linearizeSubgroupIndices helper

---
 .../Transforms/XeGPUWgToSgDistribute.cpp      | 46 -------------------
 1 file changed, 46 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index a4172bdc9d136..5330217972bf9 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -1277,52 +1277,6 @@ static Value createAccumulator(ConversionPatternRewriter &rewriter,
   return nullptr;
 }
 
-/// This function converts multi-dimensional subgroup indices into a single
-/// linear offset. It's used to calculate memory offsets in SLM for
-/// cross-subgroup reduction coordination.
-///
-/// Parameters:
-/// - sgIds: Multi-dimensional subgroup indices (e.g., [sgId_x, sgId_y, sgId_z])
-/// - dims: Which dimensions to include in linearization (e.g., [0, 2] for x and
-///   z dims)
-/// - sgLayout: Subgroup layout sizes for each dimension (e.g., [4, 8, 2] means
-///   4x8x2 subgroups)
-///
-/// It uses row-major linearization formula:
-///    offset = sum(sgIds[dim] * stride[dim])
-///    where stride[dim] = product of all sgLayout sizes in dimensions after
-///    'dim'
-///
-/// Example:
-/// - sgLayout = [4, 8, 2], dims = [0, 2] (linearize x and z dimensions)
-/// - sgIds = [1, 3, 1] (subgroup at position x=1, y=3, z=1)
-/// - Calculation:
-///   * dim=0: stride=1, term = sgIds[0] * 1 = 1 * 1 = 1
-///   * dim=2: stride=sgLayout[0]=4, term = sgIds[2] * 4 = 1 * 4 = 4
-///   * linearizedOffset = 1 + 4 = 5
-///
-/// This gives us a unique linear index for each combination of subgroup
-/// positions in the specified dimensions, which is used for SLM row/column
-/// addressing.
-[[maybe_unused]] static Value
-linearizeSubgroupIndices(ConversionPatternRewriter &rewriter, Location loc,
-                         ArrayRef<Value> sgIds, ArrayRef<int64_t> dims,
-                         ArrayRef<int64_t> sgLayout) {
-  Value linearizedOffset = arith::ConstantIndexOp::create(rewriter, loc, 0);
-  int64_t stride = 1;
-
-  for (int64_t dim : dims) {
-    Value dimVal = sgIds[dim];
-    Value strideVal = arith::ConstantIndexOp::create(rewriter, loc, stride);
-    Value term = arith::MulIOp::create(rewriter, loc, dimVal, strideVal);
-    linearizedOffset =
-        arith::AddIOp::create(rewriter, loc, linearizedOffset, term);
-    stride *= sgLayout[dim];
-  }
-
-  return linearizedOffset;
-}
-
 /// This pattern transforms vector.multi_dim_reduction operations from
 /// workgroup-level to subgroup-level execution with support for multiple
 /// reduction dimensions.



More information about the Mlir-commits mailing list