[Mlir-commits] [mlir] [MLIR][XeGPU] Add support for cross-subgroup reduction from wg to sg (PR #170936)

Nishant Patel llvmlistbot at llvm.org
Fri Dec 5 17:25:59 PST 2025


https://github.com/nbpatel updated https://github.com/llvm/llvm-project/pull/170936

>From 91de10671a4eb3d2ac8876bd45e347d2d62d7015 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Thu, 4 Dec 2025 00:08:26 +0000
Subject: [PATCH] Add support for cross-subgroup reduction from wg to sg

---
 .../Transforms/XeGPUWgToSgDistribute.cpp      | 232 +++++++++++++++---
 .../XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir    |   4 +-
 .../XeGPU/xegpu-wg-to-sg-unify-ops.mlir       |  85 +++++++
 3 files changed, 287 insertions(+), 34 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 48bd0662b03ff..173a798fa9021 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -1152,11 +1152,8 @@ struct WgToSgVectorShapeCastOp
   }
 };
 
-/// Pattern for lowering vector.multi_reduction op to subgroup level.
-/// Current limitation: the sg_layout in the reduced dimension being 1
-/// so that reduction is local to subgroup & no cross-subgroup communication is
-/// needed.
-/// TODO: Add cases to handle more general situations which require SLM access.
+// This pattern transforms vector.multi_reduction ops to work at subgroup
+// level.
 struct WgToSgMultiDimReductionOp
     : public OpConversionPattern<vector::MultiDimReductionOp> {
   using OpConversionPattern<vector::MultiDimReductionOp>::OpConversionPattern;
@@ -1164,52 +1161,221 @@ struct WgToSgMultiDimReductionOp
   LogicalResult
   matchAndRewrite(vector::MultiDimReductionOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+
     VectorType srcType = op.getSourceVectorType();
     VectorType dstType = dyn_cast<VectorType>(op.getResult().getType());
     if (!dstType)
       return failure();
 
-    auto srcShape = srcType.getShape();
+    auto originalSrcShape = srcType.getShape();
     xegpu::DistributeLayoutAttr layout =
         xegpu::getDistributeLayoutAttr(op.getResult());
+
     if (!layout || !layout.isForWorkgroup())
       return failure();
 
     auto reductionDims = llvm::to_vector(op.getReductionDims());
+    if (reductionDims.size() != 1)
+      return rewriter.notifyMatchFailure(
+          op, "Only single dimension reduction is supported");
+
+    // Get sg_layout and sg_data from the parent layout
+    SmallVector<int64_t> sgLayout;
+    SmallVector<int64_t> sgData;
+    if (auto sliceAttr = dyn_cast<xegpu::SliceAttr>(layout)) {
+      sgLayout = sliceAttr.getParent().getEffectiveSgLayoutAsInt();
+      sgData = sliceAttr.getParent().getEffectiveSgDataAsInt();
+    } else
+      return rewriter.notifyMatchFailure(
+          op, "Reduction should have SliceAttr layout");
+
+    Type elemTy = dstType.getElementType();
+
+    // Step 1: perform local subgroup reductions with ZERO accumulator
+    SmallVector<Value> localReductions;
+    auto sources = adaptor.getSource();
+    auto accs = adaptor.getAcc();
+
+    SmallVector<Value> expandedAccs;
+    if (accs.size() == 1 && sources.size() > 1) {
+      for (size_t i = 0; i < sources.size(); ++i)
+        expandedAccs.push_back(accs[0]);
+    } else
+      expandedAccs = llvm::to_vector(accs);
+
+    SmallVector<int64_t> sgShape =
+        getSgShapeAndCount(originalSrcShape, layout).first;
+    VectorType newDstType = VectorType::get({sgShape}, elemTy);
+    for (auto [sgSrc, sgAcc] : llvm::zip(sources, expandedAccs)) {
+      // Create ZERO accumulator for local reduction
+      auto zeroLocalAcc = arith::ConstantOp::create(
+          rewriter, loc, newDstType,
+          DenseElementsAttr::get(newDstType, rewriter.getZeroAttr(elemTy)));
+      // Local reduction with ZERO accumulator
+      auto localReduce = vector::MultiDimReductionOp::create(
+          rewriter, loc, newDstType, op.getKind(), sgSrc,
+          zeroLocalAcc.getResult(), reductionDims);
+      localReductions.push_back(localReduce.getResult());
+    }
 
-    SmallVector<int64_t> sgLayout = llvm::cast<xegpu::SliceAttr>(layout)
-                                        .getParent()
-                                        .getEffectiveSgLayoutAsInt();
-    SmallVector<int64_t> sgData = llvm::cast<xegpu::SliceAttr>(layout)
-                                      .getParent()
-                                      .getEffectiveSgDataAsInt();
-
-    // Check that the sgLayout in the reduced dimension is 1 and
-    // each sg gets the entire slice to reduce.
-    for (int64_t dim : reductionDims) {
-      if (sgLayout[dim] != 1 || sgData[dim] != srcShape[dim])
-        return rewriter.notifyMatchFailure(
-            op,
-            "sgLayout in each reduced dimension must be 1 and sgData in the "
-            "reduced dim must match srcShape in that dim");
+    // Check if cross-subgroup reduction is needed
+    int64_t reductionDim = reductionDims[0];
+    bool needsCrossSubgroupReduction = (sgLayout[reductionDim] > 1);
+
+    // If no cross-subgroup reduction needed, add accumulator and return
+    if (!needsCrossSubgroupReduction) {
+      SmallVector<Value> results;
+      for (auto localResult : localReductions) {
+        auto finalResult = arith::AddFOp::create(rewriter, loc, localResult,
+                                                 adaptor.getAcc()[0]);
+        if (auto defOp = finalResult.getResult().getDefiningOp())
+          xegpu::setDistributeLayoutAttr(defOp->getResult(0),
+                                         layout.dropSgLayoutAndData());
+        results.push_back(finalResult.getResult());
+      }
+      rewriter.replaceOpWithMultiple(op, {results});
+      return success();
     }
 
-    SmallVector<int64_t> sgShape = getSgShapeAndCount(srcShape, layout).first;
+    // Step 2: Cross-subgroup reduction using SLM
 
-    VectorType newDstType =
-        VectorType::get({sgShape}, dstType.getElementType());
+    // Calculate total elements in local result
+    int64_t localElements = computeProduct(sgShape);
 
-    SmallVector<Value> newReductions;
-    for (auto sgSrc : adaptor.getSource()) {
-      auto newOp = vector::MultiDimReductionOp::create(
-          rewriter, op.getLoc(), newDstType, op.getKind(), sgSrc,
-          adaptor.getAcc()[0], op.getReductionDims());
-      xegpu::setDistributeLayoutAttr(newOp->getResult(0),
-                                     layout.dropSgLayoutAndData());
-      newReductions.push_back(newOp.getResult());
+    // Shape cast for SLM storage - store as [1, localElements]
+    SmallVector<int64_t> storeShape2D = {1, localElements};
+    VectorType storeType2D = VectorType::get(storeShape2D, elemTy);
+    auto storeShapeCast = vector::ShapeCastOp::create(
+        rewriter, loc, storeType2D, localReductions[0]);
+    Value storeData = storeShapeCast.getResult();
+
+    // Calculate SLM shape
+    int64_t totalReductionSubgroups =
+        sgLayout[static_cast<size_t>(reductionDims[0])];
+
+    // Total result elements across all subgroups in non-reduction dimensions
+    int64_t totalResultElements = localElements;
+    for (size_t i = 0; i < sgLayout.size(); ++i) {
+      if (!llvm::is_contained(reductionDims, static_cast<int64_t>(i)))
+        totalResultElements *= sgLayout[i];
+    }
+
+    SmallVector<int64_t> slmShape2D = {totalReductionSubgroups,
+                                       totalResultElements};
+
+    // Step 3: Allocate SLM
+    auto bitWidth = elemTy.getIntOrFloatBitWidth();
+    auto bytesPerElement = bitWidth / 8;
+    int64_t slmElements = slmShape2D[0] * slmShape2D[1];
+    auto slmSize = slmElements * bytesPerElement;
+    auto slmTy = MemRefType::get({slmSize}, rewriter.getI8Type(), {}, 3);
+    auto slm = memref::AllocaOp::create(rewriter, loc, slmTy);
+
+    auto memDescType = xegpu::MemDescType::get(rewriter.getContext(),
+                                               slmShape2D, elemTy, nullptr);
+    auto memDesc =
+        xegpu::CreateMemDescOp::create(rewriter, loc, memDescType, slm);
+
+    // Step 4: Store local results to SLM
+    auto sgId = gpu::SubgroupIdOp::create(rewriter, loc,
+                                          rewriter.getIndexType(), nullptr);
+
+    // Convert sgLayout to Values for delinearizeIndex
+    SmallVector<Value> sgLayoutValues;
+    for (int64_t dim : sgLayout)
+      sgLayoutValues.push_back(
+          arith::ConstantIndexOp::create(rewriter, loc, dim));
+
+    auto sgIdsResult = affine::delinearizeIndex(rewriter, loc, sgId.getResult(),
+                                                sgLayoutValues);
+    if (failed(sgIdsResult))
+      return failure();
+    SmallVector<Value> sgIds = *sgIdsResult;
+
+    // Row offset is simply the subgroup ID along the reduction dimension
+    Value rowOffset = sgIds[reductionDim];
+
+    // Column offset: linearize all non-reduction dimensions and multiply by
+    // localElements
+    Value colOffset = arith::ConstantIndexOp::create(rewriter, loc, 0);
+    int64_t currentStride = 1;
+    for (size_t i = 0; i < sgLayout.size(); ++i) {
+      if (static_cast<int64_t>(i) != reductionDim) { // Skip reduction dimension
+        Value dimVal = sgIds[i];
+        Value strideVal =
+            arith::ConstantIndexOp::create(rewriter, loc, currentStride);
+        Value term = arith::MulIOp::create(rewriter, loc, dimVal, strideVal);
+        colOffset = arith::AddIOp::create(rewriter, loc, colOffset, term);
+        currentStride *= sgLayout[i];
+      }
+    }
+    Value localElementsVal =
+        arith::ConstantIndexOp::create(rewriter, loc, localElements);
+    colOffset =
+        arith::MulIOp::create(rewriter, loc, colOffset, localElementsVal);
+
+    SmallVector<OpFoldResult> storeOffsets2D = {rowOffset, colOffset};
+
+    xegpu::StoreMatrixOp::create(rewriter, loc, storeData, memDesc.getResult(),
+                                 storeOffsets2D, /*layout=*/nullptr);
+
+    gpu::BarrierOp::create(rewriter, loc);
+
+    // Step 5: Load from SLM for final reduction
+    SmallVector<int64_t> loadShape2D = {totalReductionSubgroups, localElements};
+    VectorType loadType2D = VectorType::get(loadShape2D, elemTy);
+
+    // Load offsets - each subgroup loads its column based on non-reduction
+    // position
+    Value loadOffsetY = arith::ConstantIndexOp::create(rewriter, loc, 0);
+    Value loadOffsetX = colOffset;
+
+    SmallVector<OpFoldResult> loadOffsets2D = {loadOffsetY, loadOffsetX};
+
+    auto loadOp = xegpu::LoadMatrixOp::create(
+        rewriter, loc, loadType2D, memDesc.getResult(), loadOffsets2D,
+        /*layout=*/nullptr);
+
+    // Step 6: Perform final reduction with ZERO accumulator
+    SmallVector<int64_t> finalReductionDims = {0};
+    SmallVector<int64_t> finalResultShape = {localElements};
+    VectorType finalResultType = VectorType::get(finalResultShape, elemTy);
+
+    // Create ZERO accumulator for final reduction
+    auto zeroFinalAcc = arith::ConstantOp::create(
+        rewriter, loc, finalResultType,
+        DenseElementsAttr::get(finalResultType, rewriter.getZeroAttr(elemTy)));
+
+    auto finalReduce = vector::MultiDimReductionOp::create(
+        rewriter, loc, finalResultType, op.getKind(), loadOp.getResult(),
+        zeroFinalAcc.getResult(), finalReductionDims);
+
+    // Step 7: Add the original accumulator at the end
+    Value originalAcc = adaptor.getAcc()[0];
+    Value accToAdd = originalAcc;
+
+    // Handle shape mismatch by shape casting
+    if (originalAcc.getType() != finalReduce.getResult().getType()) {
+      auto originalAccType = cast<VectorType>(originalAcc.getType());
+      auto finalResultType =
+          cast<VectorType>(finalReduce.getResult().getType());
+
+      // Same element count: reshape the accumulator to the reduced result type
+      if (originalAccType.getNumElements() == finalResultType.getNumElements())
+        accToAdd = vector::ShapeCastOp::create(rewriter, loc, finalResultType,
+                                               originalAcc)
+                       .getResult();
     }
 
-    rewriter.replaceOpWithMultiple(op, {newReductions});
+    auto finalResult =
+        arith::AddFOp::create(rewriter, loc, finalReduce.getResult(), accToAdd);
+
+    if (auto defOp = finalResult.getResult().getDefiningOp())
+      xegpu::setDistributeLayoutAttr(defOp->getResult(0),
+                                     layout.dropSgLayoutAndData());
+
+    rewriter.replaceOp(op, finalResult.getResult());
     return success();
   }
 };
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
index 1cddccb5fbbd1..ff792d809a090 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
@@ -83,8 +83,10 @@ gpu.module @test_distribution {
     %load =  xegpu.load_nd %tdesc[0, 0]
       : !xegpu.tensor_desc<256x64xf32, #xegpu.layout<sg_layout = [8, 1], sg_data = [16, 64]>>
       -> vector<256x64xf32>
-    // CHECK-COUNT-2: vector.multi_reduction <add>, {{.*}}, %[[CST]] [1] : vector<16x64xf32> to vector<16xf32>
+    // CHECK-COUNT-2: vector.multi_reduction <add>, {{.*}}, %[[C0:.*]] [1] : vector<16x64xf32> to vector<16xf32>
     // CHECK-NOT: vector.multi_reduction
+    // CHECK-COUNT-2: arith.addf {{.*}}, {{.*}} : vector<16xf32>
+    // CHECK-NOT: arith.addf
     %reduce = vector.multi_reduction <add>, %load, %cst {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 1], sg_data = [16, 64]>, dims = [1]>} [1]
       : vector<256x64xf32> to vector<256xf32>
     gpu.return
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index 574b365443a0a..bbacc527984e7 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -1,5 +1,10 @@
 // RUN: mlir-opt --xegpu-wg-to-sg-distribute -split-input-file %s | FileCheck %s
 
+// CHECK-DAG: #map = affine_map<()[s0] -> (s0 floordiv 32)>
+// CHECK-DAG: #map1 = affine_map<()[s0] -> (s0 mod 32)>
+// CHECK-DAG: #map2 = affine_map<()[s0] -> (0)>
+// CHECK-DAG: #map3 = affine_map<()[s0] -> (s0 floordiv 4)>
+// CHECK-DAG: #map4 = affine_map<()[s0] -> (s0 mod 4)>
 gpu.module @test_distribution {
   // CHECK-LABEL: create_nd_tdesc_no_offset
   // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
@@ -599,4 +604,84 @@ gpu.module @test_distribution {
         #xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32], inst_data = [8, 16]>} : vector<256xf32> to vector<256x256xf32>
     gpu.return
   }
+
+  // CHECK-LABEL: gpu.func @vector_reduce_cross_sg_dim_1
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?xf32>)
+gpu.func @vector_reduce_cross_sg_dim_1(%src: memref<?xf32>) {
+  // CHECK-DAG: %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<1x32xf32>
+  // CHECK-DAG: %[[CST_0:.*]] = arith.constant dense<0> : vector<1x1x32xindex>
+  // CHECK-DAG: %[[CST_1:.*]] = arith.constant dense<true> : vector<1x1x32xi1>
+  // CHECK-DAG: %[[LOAD:.*]] = xegpu.load %[[ARG0]][%[[CST_0]]], %[[CST_1]] <{chunk_size = 1 : i64}> : memref<?xf32>, vector<1x1x32xindex>, vector<1x1x32xi1> -> vector<1x1x32xf32>
+  // CHECK-DAG: %[[CST_2:.*]] = arith.constant dense<0.000000e+00> : vector<1x32xf32>
+  // CHECK-DAG: %[[LOCAL_REDUCE:.*]] = vector.multi_reduction <add>, %[[LOAD]], %[[CST_2]] [1] : vector<1x1x32xf32> to vector<1x32xf32>
+  // CHECK-DAG: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[LOCAL_REDUCE]] : vector<1x32xf32> to vector<1x32xf32>
+  // CHECK-DAG: %[[ALLOCA:.*]] = memref.alloca() : memref<4096xi8, 3>
+  // CHECK-DAG: %[[MEM_DESC:.*]] = xegpu.create_mem_desc %[[ALLOCA]] : memref<4096xi8, 3> -> !xegpu.mem_desc<32x32xf32>
+  // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
+  // CHECK-DAG: %[[AFFINE1:.*]] = affine.apply #map()[%[[SGID]]]
+  // CHECK-DAG: %[[AFFINE2:.*]] = affine.apply #map1()[%[[SGID]]]
+  // CHECK-DAG: %[[AFFINE3:.*]] = affine.apply #map2()[%[[SGID]]]
+  // CHECK-DAG: %[[MUL1:.*]] = arith.muli {{.*}}, %[[C1:.*]] : index
+  // CHECK-DAG: %[[ADD1:.*]] = arith.addi %[[C0:.*]], %[[MUL1]] : index
+  // CHECK-DAG: %[[MUL2:.*]] = arith.muli %[[AFFINE3]], %[[C1:.*]] : index
+  // CHECK-DAG: %[[ADD2:.*]] = arith.addi %[[ADD1]], %[[MUL2]] : index
+  // CHECK-DAG: %[[COL_OFFSET:.*]] = arith.muli %[[ADD2]], %[[C32:.*]] : index
+  // CHECK-DAG: xegpu.store_matrix %[[SHAPE_CAST]], %[[MEM_DESC]][{{.*}}, %[[COL_OFFSET]]] : vector<1x32xf32>, !xegpu.mem_desc<32x32xf32>, index, index
+  // CHECK-DAG: gpu.barrier
+  // CHECK-DAG: %[[LOAD_SLM:.*]] = xegpu.load_matrix %[[MEM_DESC]][%[[C0:.*]], %[[COL_OFFSET]]] : !xegpu.mem_desc<32x32xf32>, index, index -> vector<32x32xf32>
+  // CHECK-DAG: %[[CST_3:.*]] = arith.constant dense<0.000000e+00> : vector<32xf32>
+  // CHECK-DAG: %[[FINAL_REDUCE:.*]] = vector.multi_reduction <add>, %[[LOAD_SLM]], %[[CST_3]] [0] : vector<32x32xf32> to vector<32xf32>
+  // CHECK-DAG: %[[SHAPE_CAST_FINAL:.*]] = vector.shape_cast %[[CST]] : vector<1x32xf32> to vector<32xf32>
+  // CHECK-DAG: arith.addf %[[FINAL_REDUCE]], %[[SHAPE_CAST_FINAL]] : vector<32xf32>
+  %cst_3 = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [1, 32, 1], sg_data = [1, 1, 32]>, dims = [1]>} dense<1.0> : vector<1x32xf32>
+  %offset = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [1, 32, 1], sg_data = [1, 1, 32]>} dense<0> : vector<1x32x32xindex>
+  %mask = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [1, 32, 1], sg_data = [1, 1, 32]>} dense<true> : vector<1x32x32xi1>
+  %14 = xegpu.load %src[%offset], %mask {chunk_size = 1, layout_result_0 = #xegpu.layout<sg_layout = [1, 32, 1], sg_data = [1, 1, 32]>} : memref<?xf32>, vector<1x32x32xindex>, vector<1x32x32xi1> -> vector<1x32x32xf32>
+  %15 = vector.multi_reduction <add>, %14, %cst_3 {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [1, 32, 1], sg_data = [1, 1, 32]>, dims = [1]>} [1] : vector<1x32x32xf32> to vector<1x32xf32>
+  // CHECK-DAG: gpu.return
+  gpu.return
+}
+
+  // CHECK-LABEL: gpu.func @vector_reduce_cross_sg_dim_0
+  // CHECK-SAME: (%[[ARG0:.*]]: memref<256x128xf32>)
+  gpu.func @vector_reduce_cross_sg_dim_0(%src: memref<256x128xf32>) {
+    // CHECK-DAG: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<32xf32>
+    // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
+    // CHECK-DAG: %[[REM4:.*]] = index.remu %[[SGID]], %[[C4:.*]]
+    // CHECK-DAG: %[[DIV4:.*]] = index.divu %[[SGID]], %[[C4:.*]]
+    // CHECK-DAG: %[[REM8:.*]] = index.remu %[[DIV4]], %[[C8:.*]]
+    // CHECK-DAG: %[[MUL1:.*]] = index.mul %[[REM8]], %[[C32:.*]]
+    // CHECK-DAG: %[[MUL2:.*]] = index.mul %[[REM4]], %[[C32:.*]]
+    // CHECK-DAG: %[[REM256:.*]] = index.remu %[[MUL1]], %[[C256:.*]]
+    // CHECK-DAG: %[[REM128:.*]] = index.remu %[[MUL2]], %[[C128:.*]]
+    // CHECK-DAG: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%[[REM256]], %[[REM128]]] : memref<256x128xf32> -> !xegpu.tensor_desc<32x32xf32>
+    // CHECK-DAG: %[[LOAD:.*]] = xegpu.load_nd %[[TDESC]] : !xegpu.tensor_desc<32x32xf32> -> vector<32x32xf32>
+    // CHECK-DAG: %[[CST_LOCAL:.*]] = arith.constant dense<0.000000e+00> : vector<32xf32>
+    // CHECK-DAG: %[[LOCAL_REDUCE:.*]] = vector.multi_reduction <add>, %[[LOAD]], %[[CST_LOCAL]] [0] : vector<32x32xf32> to vector<32xf32>
+    // CHECK-DAG: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[LOCAL_REDUCE]] : vector<32xf32> to vector<1x32xf32>
+    // CHECK-DAG: %[[ALLOCA:.*]] = memref.alloca() : memref<4096xi8, 3>
+    // CHECK-DAG: %[[MEM_DESC:.*]] = xegpu.create_mem_desc %[[ALLOCA]] : memref<4096xi8, 3> -> !xegpu.mem_desc<8x128xf32>
+    // CHECK-DAG: %[[SGID2:.*]] = gpu.subgroup_id : index
+    // CHECK-DAG: %[[AFFINE1:.*]] = affine.apply #map3()[%[[SGID2]]]
+    // CHECK-DAG: %[[AFFINE2:.*]] = affine.apply #map4()[%[[SGID2]]]
+    // CHECK-DAG: %[[MUL_AFFINE:.*]] = arith.muli {{.*}}, %[[C1:.*]] : index
+    // CHECK-DAG: %[[ADD_OFFSET:.*]] = arith.addi %[[C0:.*]], %[[MUL_AFFINE]] : index
+    // CHECK-DAG: %[[COL_OFFSET:.*]] = arith.muli %[[ADD_OFFSET]], %[[C32:.*]] : index
+    // CHECK-DAG: xegpu.store_matrix %[[SHAPE_CAST]], %[[MEM_DESC]][{{.*}}, %[[COL_OFFSET]]] : vector<1x32xf32>, !xegpu.mem_desc<8x128xf32>, index, index
+    // CHECK-DAG: gpu.barrier
+    // CHECK-DAG: %[[LOAD_SLM:.*]] = xegpu.load_matrix %[[MEM_DESC]][%[[C0:.*]], %[[COL_OFFSET]]] : !xegpu.mem_desc<8x128xf32>, index, index -> vector<8x32xf32>
+    // CHECK-DAG: %[[CST_CROSS_SG_1:.*]] = arith.constant dense<0.000000e+00> : vector<32xf32>
+    // CHECK-DAG: %[[FINAL_REDUCE:.*]] = vector.multi_reduction <add>, %[[LOAD_SLM]], %[[CST_CROSS_SG_1]] [0] : vector<8x32xf32> to vector<32xf32>
+    // CHECK-DAG: arith.addf %[[FINAL_REDUCE]], %[[CST:.*]] : vector<32xf32>
+    %cst = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>, dims = [0]>} dense<0.0> : vector<128xf32>
+    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+      -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>>
+    %load =  xegpu.load_nd %tdesc
+      : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>>
+      -> vector<256x128xf32>
+    %reduce = vector.multi_reduction <add>, %load, %cst {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>, dims = [0]>} [0]
+      : vector<256x128xf32> to vector<128xf32>
+    // CHECK-DAG: gpu.return
+    gpu.return
+  }
 }



More information about the Mlir-commits mailing list