[Mlir-commits] [mlir] 9bddf47 - [MLIR][XeGPU] Extend Wg-to-Sg Distribution of Multi-Reduction Op for round-robin layout (#189988)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Apr 6 14:07:54 PDT 2026


Author: Jianhui Li
Date: 2026-04-06T14:07:50-07:00
New Revision: 9bddf47198582459f954e10fd4920f2afac10e15

URL: https://github.com/llvm/llvm-project/commit/9bddf47198582459f954e10fd4920f2afac10e15
DIFF: https://github.com/llvm/llvm-project/commit/9bddf47198582459f954e10fd4920f2afac10e15.diff

LOG: [MLIR][XeGPU] Extend Wg-to-Sg Distribution of Multi-Reduction Op for round-robin layout (#189988)

This PR enhances the multi-reduction op pattern of the wg-to-sg distribution
pass:
1. Allows each sg to have multiple distributions of sg_data tiles.
2. Expands the SLM buffer size.
3. Constructs the layout based on the partially reduced vector and uses
layout.computeDistributedCoords() to compute coordinates. The layout is
constructed so that the store is cooperative, and the load overlaps with
neighbouring threads.
4. Performs the save and load.

Added: 
    

Modified: 
    mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
    mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
    mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index d04933423ecd0..a5b1df0f93f57 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -1255,7 +1255,6 @@ struct WgToSgMultiDimReductionOp
     bool isScalarResult = !dstVecType;
 
     auto originalSrcShape = srcType.getShape();
-    int srcVecRank = originalSrcShape.size();
     Type elemTy = srcType.getElementType();
 
     xegpu::DistributeLayoutAttr layout =
@@ -1268,9 +1267,11 @@ struct WgToSgMultiDimReductionOp
     // Get sg_layout and sg_data from the parent layout
     SmallVector<int64_t> sgLayout;
     SmallVector<int64_t> sgData;
+    xegpu::DistributeLayoutAttr parentLayout;
     if (auto sliceAttr = dyn_cast<xegpu::SliceAttr>(layout)) {
-      sgLayout = sliceAttr.getParent().getEffectiveSgLayoutAsInt();
-      sgData = sliceAttr.getParent().getEffectiveSgDataAsInt();
+      parentLayout = sliceAttr.getParent();
+      sgLayout = parentLayout.getEffectiveSgLayoutAsInt();
+      sgData = parentLayout.getEffectiveSgDataAsInt();
     } else
       return rewriter.notifyMatchFailure(
           op, "Reduction should have SliceAttr layout");
@@ -1330,26 +1331,33 @@ struct WgToSgMultiDimReductionOp
       return success();
     }
 
-    // Step 2: cross-subgroup reduction using SLM
+    // Step 2: cross-subgroup reduction using SLM - allocating slm memory
     auto slmStoreDataShape = sgSrcShape;
     for (int64_t dim : reductionDims)
       slmStoreDataShape[dim] = 1;
     VectorType slmStoreDataType = VectorType::get(slmStoreDataShape, elemTy);
-    Value slmStoreData;
-    if (isScalarResult) {
-      // Scalar result: broadcast scalar to vector<1x...x1> for SLM store
-      slmStoreData = vector::BroadcastOp::create(
-          rewriter, loc, slmStoreDataType, localReductions[0]);
-    } else {
-      slmStoreData = vector::ShapeCastOp::create(
-          rewriter, loc, slmStoreDataType, localReductions[0]);
+    SmallVector<Value> slmStoreData;
+    for (auto localResult : localReductions) {
+      if (isScalarResult) {
+        // Scalar result: broadcast scalar to vector<1x...x1> for SLM store
+        slmStoreData.push_back(vector::BroadcastOp::create(
+            rewriter, loc, slmStoreDataType, localResult));
+      } else {
+        slmStoreData.push_back(vector::ShapeCastOp::create(
+            rewriter, loc, slmStoreDataType, localResult));
+      }
     }
-
+    // for reduction dimension, SLM stores partial results from each subgroup
     SmallVector<int64_t> slmShape(originalSrcShape.begin(),
                                   originalSrcShape.end());
-    // for reduction dimension, SLM stores partial results from each subgroup
-    for (int64_t dim : reductionDims)
+    SmallVector<int> slmSgData(sgData.begin(), sgData.end());
+    SmallVector<int> slmSgLayout(sgLayout.begin(), sgLayout.end());
+    for (int dim : reductionDims) {
       slmShape[dim] = sgLayout[dim];
+      slmSgData[dim] = 1;
+    }
+    xegpu::LayoutAttr slmStoreLayout =
+        xegpu::LayoutAttr::get(rewriter.getContext(), slmSgLayout, slmSgData);
 
     // Allocate SLM
     auto bitWidth = elemTy.getIntOrFloatBitWidth();
@@ -1363,82 +1371,61 @@ struct WgToSgMultiDimReductionOp
     auto memDesc =
         xegpu::CreateMemDescOp::create(rewriter, loc, memDescType, slm);
 
-    // if localReductions have more than 1 result, not support
-    if (localReductions.size() > 1) {
-      return rewriter.notifyMatchFailure(
-          op,
-          "Multiple local reductions not supported in current implementation.");
-    }
-
-    // Step 4: Store local results to SLM
+    // Step 3: Store local results to SLM
     auto sgId = gpu::SubgroupIdOp::create(rewriter, loc,
                                           rewriter.getIndexType(), nullptr);
 
-    // Convert sgLayout to Values for delinearizeIndex
-    SmallVector<Value> sgLayoutValues;
-    for (int64_t dim : sgLayout)
-      sgLayoutValues.push_back(
-          arith::ConstantIndexOp::create(rewriter, loc, dim));
-
-    auto sgIdsResult = affine::delinearizeIndex(rewriter, loc, sgId.getResult(),
-                                                sgLayoutValues);
-    if (failed(sgIdsResult))
+    auto slmStoreCoords =
+        slmStoreLayout.computeDistributedCoords(rewriter, loc, sgId, slmShape);
+    if (failed(slmStoreCoords))
       return failure();
-    SmallVector<Value> sgIds = *sgIdsResult;
-
-    auto getSlmOffsets = [&](int64_t reductionDimStride) {
-      SmallVector<OpFoldResult> offsets;
-      offsets.reserve(srcVecRank);
-      for (int i = 0; i < srcVecRank; ++i) {
-        Value dimVal = sgIds[i];
-        int64_t sgDataStride = (llvm::is_contained(reductionDims, i))
-                                   ? reductionDimStride
-                                   : sgSrcShape[i];
-        Value strideVal =
-            arith::ConstantIndexOp::create(rewriter, loc, sgDataStride);
-        Value offsetVal =
-            arith::MulIOp::create(rewriter, loc, dimVal, strideVal);
-        offsets.push_back(offsetVal);
-      }
-      return offsets;
-    };
-
-    SmallVector<OpFoldResult> slmStoreOffsets =
-        getSlmOffsets(/*reductionDimStride=*/1);
-
-    xegpu::StoreMatrixOp::create(rewriter, loc, slmStoreData,
-                                 memDesc.getResult(), slmStoreOffsets,
-                                 /*layout=*/nullptr);
+    for (auto [data, coord] : llvm::zip(slmStoreData, *slmStoreCoords)) {
+      SmallVector<OpFoldResult> coordOfr(coord.begin(), coord.end());
+      xegpu::StoreMatrixOp::create(rewriter, loc, data, memDesc.getResult(),
+                                   coordOfr,
+                                   /*layout=*/nullptr);
+    }
 
     gpu::BarrierOp::create(rewriter, loc);
 
-    // Step 5: Load from SLM for final reduction
+    // Step 4: Load from SLM for final reduction
     SmallVector<int64_t> slmLoadDataShape(sgSrcShape.begin(), sgSrcShape.end());
-    for (int64_t dim : reductionDims)
+    for (int64_t dim : reductionDims) {
       slmLoadDataShape[dim] = slmShape[dim];
-
-    SmallVector<OpFoldResult> slmLoadOffsets =
-        getSlmOffsets(/*reductionDimStride=*/0);
+      slmSgData[dim] = slmShape[dim];
+    }
+    xegpu::LayoutAttr slmLoadLayout =
+        xegpu::LayoutAttr::get(rewriter.getContext(), slmSgLayout, slmSgData);
+    auto slmLoadCoords =
+        slmLoadLayout.computeDistributedCoords(rewriter, loc, sgId, slmShape);
+    if (failed(slmLoadCoords))
+      return failure();
 
     VectorType slmLoadType = VectorType::get(slmLoadDataShape, elemTy);
-    auto slmLoadOp = xegpu::LoadMatrixOp::create(
-        rewriter, loc, slmLoadType, memDesc.getResult(), slmLoadOffsets,
-        /*layout=*/nullptr);
+    SmallVector<Value> slmLoadData;
+    for (auto coord : *slmLoadCoords) {
+      SmallVector<OpFoldResult> coordOfr(coord.begin(), coord.end());
+      slmLoadData.push_back(xegpu::LoadMatrixOp::create(
+          rewriter, loc, slmLoadType, memDesc.getResult(), coordOfr,
+          /*layout=*/nullptr));
+    }
 
-    // Step 6: Perform final reduction with neutral accumulator
+    // Step 5: Perform final reduction with neutral accumulator and add the
+    // original accumulator at the end
     Value neutralFinalAcc = xegpu::createReductionNeutralValue(
         rewriter, loc, sgDstType, op.getKind());
 
-    auto finalReduce = vector::MultiDimReductionOp::create(
-        rewriter, loc, sgDstType, op.getKind(), slmLoadOp.getResult(),
-        neutralFinalAcc, reductionDims);
-
-    // Step 7: Add the original accumulator at the end
-    auto finalResult = vector::makeArithReduction(rewriter, loc, op.getKind(),
-                                                  finalReduce.getResult(),
-                                                  adaptor.getAcc()[0]);
-
-    rewriter.replaceOp(op, finalResult);
+    SmallVector<Value> finalResults;
+    for (size_t i = 0; i < slmLoadData.size(); ++i) {
+      auto loaded = slmLoadData[i];
+      auto finalReduce = vector::MultiDimReductionOp::create(
+          rewriter, loc, sgDstType, op.getKind(), loaded, neutralFinalAcc,
+          reductionDims);
+      finalResults.push_back(vector::makeArithReduction(
+          rewriter, loc, op.getKind(), finalReduce.getResult(),
+          adaptor.getAcc()[i]));
+    }
+    rewriter.replaceOpWithMultiple(op, {finalResults});
     return success();
   }
 };

diff  --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
index 320a2fb1f72ac..897eab12329e2 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
@@ -165,6 +165,58 @@ gpu.module @test_distribution {
     gpu.return
   }
 
+  // CHECK-LABEL: gpu.func @reduction_cross_sg_rr
+  gpu.func @reduction_cross_sg_rr(%arg0: memref<2048xf32, 1>) kernel {
+    // CHECK: %[[CST_OFFSETS0:.*]] = arith.constant dense<0> : vector<4x16xindex>
+    // CHECK: %[[CST_OFFSETS1:.*]] = arith.constant dense<0> : vector<4x16xindex>
+    // CHECK: %[[CST_ACC0:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
+    // CHECK: %[[CST_ACC1:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
+    // CHECK: %[[CST_MASK0:.*]] = arith.constant dense<true> : vector<4x16xi1>
+    // CHECK: %[[CST_MASK1:.*]] = arith.constant dense<true> : vector<4x16xi1>
+    //
+    // CHECK: %[[LOAD0:.*]] = xegpu.load %arg0[%[[CST_OFFSETS0]]], %[[CST_MASK0]]
+    // CHECK-SAME: -> vector<4x16xf32>
+    // CHECK: %[[LOAD1:.*]] = xegpu.load %arg0[%[[CST_OFFSETS1]]], %[[CST_MASK1]]
+    // CHECK-SAME: -> vector<4x16xf32>
+    //
+    // Local reductions
+    // CHECK: %[[NEUTRAL0:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
+    // CHECK: %[[LOCAL_RED0:.*]] = vector.multi_reduction <add>, %[[LOAD0]], %[[NEUTRAL0]] [1] : vector<4x16xf32> to vector<4xf32>
+    // CHECK: %[[NEUTRAL1:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
+    // CHECK: %[[LOCAL_RED1:.*]] = vector.multi_reduction <add>, %[[LOAD1]], %[[NEUTRAL1]] [1] : vector<4x16xf32> to vector<4xf32>
+    //
+    // Shape cast for SLM store
+    // CHECK: %[[SC0:.*]] = vector.shape_cast %[[LOCAL_RED0]] : vector<4xf32> to vector<4x1xf32>
+    // CHECK: %[[SC1:.*]] = vector.shape_cast %[[LOCAL_RED1]] : vector<4xf32> to vector<4x1xf32>
+    //
+    // SLM allocation and mem_desc
+    // CHECK: %[[SLM:.*]] = memref.alloca() : memref<512xi8, 3>
+    // CHECK: %[[MEMDESC:.*]] = xegpu.create_mem_desc %[[SLM]] : memref<512xi8, 3> -> !xegpu.mem_desc<8x16xf32>
+    //
+    // Store to SLM
+    // CHECK: xegpu.store_matrix %[[SC0]], %[[MEMDESC]]{{.*}} : vector<4x1xf32>, !xegpu.mem_desc<8x16xf32>
+    // CHECK: xegpu.store_matrix %[[SC1]], %[[MEMDESC]]{{.*}} : vector<4x1xf32>, !xegpu.mem_desc<8x16xf32>
+    // CHECK: gpu.barrier
+    //
+    // Load from SLM
+    // CHECK: %[[SLM_LOAD0:.*]] = xegpu.load_matrix %[[MEMDESC]]{{.*}} -> vector<4x16xf32>
+    // CHECK: %[[SLM_LOAD1:.*]] = xegpu.load_matrix %[[MEMDESC]]{{.*}} -> vector<4x16xf32>
+    //
+    // Final reduction
+    // CHECK: %[[FINAL_NEUTRAL:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
+    // CHECK: %[[FINAL_RED0:.*]] = vector.multi_reduction <add>, %[[SLM_LOAD0]], %[[FINAL_NEUTRAL]] [1] : vector<4x16xf32> to vector<4xf32>
+    // CHECK: %[[RES0:.*]] = arith.addf %[[FINAL_RED0]], %[[CST_ACC0]] : vector<4xf32>
+    // CHECK: %[[FINAL_RED1:.*]] = vector.multi_reduction <add>, %[[SLM_LOAD1]], %[[FINAL_NEUTRAL]] [1] : vector<4x16xf32> to vector<4xf32>
+    // CHECK: %[[RES1:.*]] = arith.addf %[[FINAL_RED1]], %[[CST_ACC1]] : vector<4xf32>
+
+    %offset = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [1, 16], sg_data = [4, 16]>} dense<0> : vector<8x256xindex>
+    %acc = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [1, 16], sg_data = [4, 16]>, dims = [1]>} dense<0.000000e+00> : vector<8xf32>
+    %mask = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [1, 16], sg_data = [4, 16]>} dense<true> : vector<8x256xi1>
+    %val = xegpu.load %arg0[%offset], %mask <{layout = #xegpu.layout<sg_layout = [1, 16], sg_data = [4, 16]>}> : memref<2048xf32, 1>, vector<8x256xindex>, vector<8x256xi1> -> vector<8x256xf32>
+    %reduce = vector.multi_reduction <add>, %val, %acc {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [1, 16], sg_data = [4, 16]>, dims = [1]>} [1] : vector<8x256xf32> to vector<8xf32>
+    gpu.return
+  }
+
   // CHECK-LABEL: splat_constant
   gpu.func @splat_constant() {
     // CHECK-COUNT-2: %[[CST:.*]] = arith.constant dense<0> : vector<4xindex>

diff  --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index 90c6a73497630..bbdffa0986962 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -1,13 +1,4 @@
 // RUN: mlir-opt --xegpu-wg-to-sg-distribute -split-input-file %s | FileCheck %s
-
-// CHECK-DAG: #map = affine_map<()[s0] -> (s0 floordiv 4)>
-// CHECK-DAG: #map1 = affine_map<()[s0] -> (s0 mod 4)>
-// CHECK-DAG: #map2 = affine_map<()[s0] -> (s0 floordiv 32)>
-// CHECK-DAG: #map3 = affine_map<()[s0] -> (s0 mod 32)>
-// CHECK-DAG: #map4 = affine_map<()[s0] -> (0)>
-// CHECK-DAG: #map5 = affine_map<()[s0] -> ((s0 mod 32) floordiv 16)>
-// CHECK-DAG: #map6 = affine_map<()[s0] -> (s0 mod 16)>
-// CHECK-DAG: #map7 = affine_map<()[s0] -> ((s0 mod 16) floordiv 4)>
 gpu.module @test_distribution {
   // CHECK-LABEL: create_nd_tdesc_no_offset
   // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
@@ -681,18 +672,9 @@ gpu.module @test_distribution {
     // CHECK-DAG: %[[ALLOCA:.*]] = memref.alloca() : memref<4096xi8, 3>
     // CHECK-DAG: %[[MEM_DESC:.*]] = xegpu.create_mem_desc %[[ALLOCA]] : memref<4096xi8, 3> -> !xegpu.mem_desc<1x32x32xf32>
     // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
-    // CHECK-DAG: %[[AFF0:.*]] = affine.apply #map2()[%[[SGID]]]
-    // CHECK-DAG: %[[AFF1:.*]] = affine.apply #map3()[%[[SGID]]]
-    // CHECK-DAG: %[[AFF2:.*]] = affine.apply #map4()[%[[SGID]]]
-    // CHECK-DAG: %[[ROW:.*]] = arith.muli %[[AFF0]], %[[C1A:.*]] : index
-    // CHECK-DAG: %[[COL0:.*]] = arith.muli %[[AFF1:.*]], %[[C1B:.*]] : index
-    // CHECK-DAG: %[[COL1:.*]] = arith.muli %[[AFF2]], %[[C32A:.*]] : index
-    // CHECK-DAG: xegpu.store_matrix %[[CAST]], %[[MEM_DESC]][%[[ROW]], %[[COL0]], %[[COL1]]] : vector<1x1x32xf32>, !xegpu.mem_desc<1x32x32xf32>, index, index, index
+    // CHECK-DAG: xegpu.store_matrix %[[CAST]], %[[MEM_DESC]]{{.*}} : vector<1x1x32xf32>, !xegpu.mem_desc<1x32x32xf32>, index, index, index
     // CHECK-DAG: gpu.barrier
-    // CHECK-DAG: %[[ROW_L:.*]] = arith.muli %[[AFF0]], %[[C1C:.*]] : index
-    // CHECK-DAG: %[[COL0_L:.*]] = arith.muli %[[AFF1]], %[[C0:.*]] : index
-    // CHECK-DAG: %[[COL1_L:.*]] = arith.muli %[[AFF2]], %[[C32B:.*]] : index
-    // CHECK-DAG: %[[LOAD_SLM:.*]] = xegpu.load_matrix %[[MEM_DESC]][%[[ROW_L]], %[[COL0_L]], %[[COL1_L]]] : !xegpu.mem_desc<1x32x32xf32>, index, index, index -> vector<1x32x32xf32>
+    // CHECK-DAG: %[[LOAD_SLM:.*]] = xegpu.load_matrix %[[MEM_DESC]]{{.*}} : !xegpu.mem_desc<1x32x32xf32>, index, index, index -> vector<1x32x32xf32>
     // CHECK-DAG: %[[CST_3:.*]] = arith.constant dense<0.000000e+00> : vector<1x32xf32>
     // CHECK-DAG: %[[FINAL_REDUCE:.*]] = vector.multi_reduction <add>, %[[LOAD_SLM]], %[[CST_3]] [1] : vector<1x32x32xf32> to vector<1x32xf32>
     // CHECK-DAG: %[[ADD:.*]] = arith.addf %[[FINAL_REDUCE]], %[[CST]] : vector<1x32xf32>
@@ -725,15 +707,9 @@ gpu.module @test_distribution {
     // CHECK-DAG: %[[ALLOCA:.*]] = memref.alloca() : memref<4096xi8, 3>
     // CHECK-DAG: %[[MEM_DESC:.*]] = xegpu.create_mem_desc %[[ALLOCA]] : memref<4096xi8, 3> -> !xegpu.mem_desc<8x128xf32>
     // CHECK-DAG: %[[SGID2:.*]] = gpu.subgroup_id : index
-    // CHECK-DAG: %[[AFFINE1:.*]] = affine.apply {{#map[[:alnum:]_]*}}()[%[[SGID2]]]
-    // CHECK-DAG: %[[AFFINE2:.*]] = affine.apply {{#map[[:alnum:]_]*}}()[%[[SGID2]]]
-    // CHECK-DAG: %[[ROW_OFFSET:.*]] = arith.muli %[[AFFINE1]], %[[C1:.*]] : index
-    // CHECK-DAG: %[[COL_OFFSET:.*]] = arith.muli %[[AFFINE2]], %[[C32_1:.*]] : index
-    // CHECK-DAG: xegpu.store_matrix %[[SHAPE_CAST]], %[[MEM_DESC]][%[[ROW_OFFSET]], %[[COL_OFFSET]]] : vector<1x32xf32>, !xegpu.mem_desc<8x128xf32>, index, index
+    // CHECK-DAG: xegpu.store_matrix %[[SHAPE_CAST]], %[[MEM_DESC]]{{.*}} : vector<1x32xf32>, !xegpu.mem_desc<8x128xf32>, index, index
     // CHECK-DAG: gpu.barrier
-    // CHECK-DAG: %[[ZERO_ROW:.*]] = arith.muli %[[AFFINE1]], %[[C0:.*]] : index
-    // CHECK-DAG: %[[COL_OFFSET2:.*]] = arith.muli %[[AFFINE2]], %[[C32_2:.*]] : index
-    // CHECK-DAG: %[[LOAD_SLM:.*]] = xegpu.load_matrix %[[MEM_DESC]][%[[ZERO_ROW]], %[[COL_OFFSET2]]] : !xegpu.mem_desc<8x128xf32>, index, index -> vector<8x32xf32>
+    // CHECK-DAG: %[[LOAD_SLM:.*]] = xegpu.load_matrix %[[MEM_DESC]]{{.*}} : !xegpu.mem_desc<8x128xf32>, index, index -> vector<8x32xf32>
     // CHECK-DAG: %[[CST_CROSS_SG_1:.*]] = arith.constant dense<0.000000e+00> : vector<32xf32>
     // CHECK-DAG: %[[FINAL_REDUCE:.*]] = vector.multi_reduction <add>, %[[LOAD_SLM]], %[[CST_CROSS_SG_1]] [0] : vector<8x32xf32> to vector<32xf32>
     // CHECK-DAG: arith.addf %[[FINAL_REDUCE]], %[[CST]] : vector<32xf32>
@@ -761,31 +737,9 @@ gpu.module @test_distribution {
     // CHECK-DAG: %[[ALLOCA:.*]] = memref.alloca() : memref<256xi8, 3>
     // CHECK-DAG: %[[MEM_DESC:.*]] = xegpu.create_mem_desc %[[ALLOCA]] : memref<256xi8, 3> -> !xegpu.mem_desc<2x2x4x4xf32>
     // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
-    // CHECK-DAG: %[[AFFINE0:.*]] = affine.apply {{#map[[:alnum:]_]*}}()[%[[SGID]]]
-    // CHECK-DAG: %[[AFFINE1:.*]] = affine.apply {{#map[[:alnum:]_]*}}()[%[[SGID]]]
-    // CHECK-DAG: %[[AFFINE2:.*]] = affine.apply {{#map[[:alnum:]_]*}}()[%[[SGID]]]
-    // CHECK-DAG: %[[AFFINE3:.*]] = affine.apply {{#map[[:alnum:]_]*}}()[%[[SGID]]]
-    // CHECK-DAG: %[[AFFINE4:.*]] = affine.apply {{#map[[:alnum:]_]*}}()[%[[SGID]]]
-    // CHECK-DAG: %[[AFFINE5:.*]] = affine.apply {{#map[[:alnum:]_]*}}()[%[[SGID]]]
-    // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-    // CHECK-DAG: %[[I0:.*]] = arith.muli %[[AFFINE0]], %[[C1]] : index
-    // CHECK-DAG: %[[C1_0:.*]] = arith.constant 1 : index
-    // CHECK-DAG: %[[I1:.*]] = arith.muli %[[AFFINE2]], %[[C1_0]] : index
-    // CHECK-DAG: %[[C1_1:.*]] = arith.constant 1 : index
-    // CHECK-DAG: %[[I2:.*]] = arith.muli %[[AFFINE4]], %[[C1_1]] : index
-    // CHECK-DAG: %[[C1_2:.*]] = arith.constant 1 : index
-    // CHECK-DAG: %[[I3:.*]] = arith.muli %[[AFFINE5]], %[[C1_2]] : index
-    // CHECK-DAG: xegpu.store_matrix %[[SHAPE_CAST]], %[[MEM_DESC]][%[[I0]], %[[I1]], %[[I2]], %[[I3]]] : vector<1x1x1x1xf32>, !xegpu.mem_desc<2x2x4x4xf32>, index, index, index, index
+    // CHECK-DAG: xegpu.store_matrix %[[SHAPE_CAST]], %[[MEM_DESC]]{{.*}} : vector<1x1x1x1xf32>, !xegpu.mem_desc<2x2x4x4xf32>, index, index, index, index
     // CHECK-DAG: gpu.barrier
-    // CHECK-DAG: %[[C1_3:.*]] = arith.constant 1 : index
-    // CHECK-DAG: %[[L0:.*]] = arith.muli %[[AFFINE0]], %[[C1_3]] : index
-    // CHECK-DAG: %[[C1_4:.*]] = arith.constant 1 : index
-    // CHECK-DAG: %[[L1:.*]] = arith.muli %[[AFFINE2]], %[[C1_4]] : index
-    // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-    // CHECK-DAG: %[[L2:.*]] = arith.muli %[[AFFINE4]], %[[C0]] : index
-    // CHECK-DAG: %[[C0_0:.*]] = arith.constant 0 : index
-    // CHECK-DAG: %[[L3:.*]] = arith.muli %[[AFFINE5]], %[[C0_0]] : index
-    // CHECK-DAG: %[[LOAD_SLM:.*]] = xegpu.load_matrix %[[MEM_DESC]][%[[L0]], %[[L1]], %[[L2]], %[[L3]]] : !xegpu.mem_desc<2x2x4x4xf32>, index, index, index, index -> vector<1x1x4x4xf32>
+    // CHECK-DAG: %[[LOAD_SLM:.*]] = xegpu.load_matrix %[[MEM_DESC]]{{.*}} : !xegpu.mem_desc<2x2x4x4xf32>, index, index, index, index -> vector<1x1x4x4xf32>
     // CHECK-DAG: %[[CST_3:.*]] = arith.constant dense<0.000000e+00> : vector<1x1xf32>
     // CHECK-DAG: %[[FINAL_REDUCE:.*]] = vector.multi_reduction <add>, %[[LOAD_SLM]], %[[CST_3]] [2, 3] : vector<1x1x4x4xf32> to vector<1x1xf32>
     // CHECK-DAG: %[[FINAL_ADD:.*]] = arith.addf %[[FINAL_REDUCE]], %[[CST]] : vector<1x1xf32>
@@ -811,23 +765,9 @@ gpu.module @test_distribution {
     // CHECK-DAG: %[[ALLOCA:.*]] = memref.alloca() : memref<65536xi8, 3>
     // CHECK-DAG: %[[MEM_DESC:.*]] = xegpu.create_mem_desc %[[ALLOCA]] : memref<65536xi8, 3> -> !xegpu.mem_desc<32x32x4x4xf32>
     // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
-    // CHECK-DAG: %[[AFFINE0:.*]] = affine.apply {{#map[[:alnum:]_]*}}()[%[[SGID]]]
-    // CHECK-DAG: %[[AFFINE1:.*]] = affine.apply {{#map[[:alnum:]_]*}}()[%[[SGID]]]
-    // CHECK-DAG: %[[AFFINE2:.*]] = affine.apply {{#map[[:alnum:]_]*}}()[%[[SGID]]]
-    // CHECK-DAG: %[[AFFINE3:.*]] = affine.apply {{#map[[:alnum:]_]*}}()[%[[SGID]]]
-    // CHECK-DAG: %[[AFFINE4:.*]] = affine.apply {{#map[[:alnum:]_]*}}()[%[[SGID]]]
-    // CHECK-DAG: %[[AFFINE5:.*]] = affine.apply {{#map[[:alnum:]_]*}}()[%[[SGID]]]
-    // CHECK-DAG: %[[R0:.*]] = arith.muli %[[AFFINE0]], %[[C16_0:.*]] : index
-    // CHECK-DAG: %[[R1:.*]] = arith.muli %[[AFFINE2]], %[[C16_1:.*]] : index
-    // CHECK-DAG: %[[R2:.*]] = arith.muli %[[AFFINE4]], %[[C1_0:.*]] : index
-    // CHECK-DAG: %[[R3:.*]] = arith.muli %[[AFFINE5]], %[[C1_1:.*]] : index
-    // CHECK-DAG: xegpu.store_matrix %[[SHAPE_CAST]], %[[MEM_DESC]][%[[R0]], %[[R1]], %[[R2]], %[[R3]]] : vector<16x16x1x1xf32>, !xegpu.mem_desc<32x32x4x4xf32>, index, index, index, index
+    // CHECK-DAG: xegpu.store_matrix %[[SHAPE_CAST]], %[[MEM_DESC]]{{.*}} : vector<16x16x1x1xf32>, !xegpu.mem_desc<32x32x4x4xf32>, index, index, index, index
     // CHECK-DAG: gpu.barrier
-    // CHECK-DAG: %[[L0:.*]] = arith.muli %[[AFFINE0]], %[[C16_2:.*]] : index
-    // CHECK-DAG: %[[L1:.*]] = arith.muli %[[AFFINE2]], %[[C16_3:.*]] : index
-    // CHECK-DAG: %[[L2:.*]] = arith.muli %[[AFFINE4]], %[[C0_0:.*]] : index
-    // CHECK-DAG: %[[L3:.*]] = arith.muli %[[AFFINE5]], %[[C0_1:.*]] : index
-    // CHECK-DAG: %[[LOAD_SLM:.*]] = xegpu.load_matrix %[[MEM_DESC]][%[[L0]], %[[L1]], %[[L2]], %[[L3]]] : !xegpu.mem_desc<32x32x4x4xf32>, index, index, index, index -> vector<16x16x4x4xf32>
+    // CHECK-DAG: %[[LOAD_SLM:.*]] = xegpu.load_matrix %[[MEM_DESC]]{{.*}} : !xegpu.mem_desc<32x32x4x4xf32>, index, index, index, index -> vector<16x16x4x4xf32>
     // CHECK-DAG: %[[CST_3:.*]] = arith.constant dense<0.000000e+00> : vector<16x16xf32>
     // CHECK-DAG: %[[FINAL_REDUCE:.*]] = vector.multi_reduction <add>, %[[LOAD_SLM]], %[[CST_3]] [2, 3] : vector<16x16x4x4xf32> to vector<16x16xf32>
     // CHECK-DAG: %[[FINAL_ADD:.*]] = arith.addf %[[FINAL_REDUCE]], %[[CST]] : vector<16x16xf32>


        


More information about the Mlir-commits mailing list