[Mlir-commits] [mlir] [MLIR][XeGPU] Add support for vector.multi_reduction in wg to sg pass (PR #156731)
    Nishant Patel 
    llvmlistbot at llvm.org
       
    Thu Sep  4 10:34:52 PDT 2025
    
    
  
https://github.com/nbpatel updated https://github.com/llvm/llvm-project/pull/156731
From c3e598690dc120de36bbe14500936dc4202299e5 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Fri, 29 Aug 2025 20:45:52 +0000
Subject: [PATCH 1/4] Add pattern for reduction
---
 .../Transforms/XeGPUWgToSgDistribute.cpp      | 211 +++++++++++++++++-
 .../XeGPU/xegpu-wg-to-sg-unify-ops.mlir       |  13 ++
 2 files changed, 219 insertions(+), 5 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 0b7fe81facfce..54a98970a0fc0 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -726,7 +726,6 @@ struct UnrealizedConversionCastOpPattern
   }
 };
 
-// This pattern distributes arith.constant op into subgroup-level constants
 struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
   using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;
 
@@ -756,8 +755,15 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
     auto sgAttr = DenseElementsAttr::get(newType, singleVal);
     auto cstOp =
         arith::ConstantOp::create(rewriter, op.getLoc(), newType, sgAttr);
-    if (auto newLayout = layout.dropSgLayoutAndData())
-      xegpu::setDistributeLayoutAttr(cstOp->getResult(0), newLayout);
+    if (auto sliceAttr = dyn_cast_if_present<xegpu::SliceAttr>(layout)) {
+      if (sliceAttr.isForSubgroup())
+        xegpu::setDistributeLayoutAttr(cstOp->getResult(0),
+                                       sliceAttr.dropSgLayoutAndData());
+    } else if (auto layoutAttr =
+                   dyn_cast_if_present<xegpu::LayoutAttr>(layout)) {
+      if (auto newLayout = layoutAttr.dropSgLayoutAndData())
+        xegpu::setDistributeLayoutAttr(cstOp->getResult(0), newLayout);
+    }
     SmallVector<Value> newConsts(count, cstOp);
 
     rewriter.replaceOpWithMultiple(op, {newConsts});
@@ -815,6 +821,191 @@ struct WgToSgStoreMatrixOp : public OpConversionPattern<xegpu::StoreMatrixOp> {
   }
 };
 
+// Pattern to distribute vector.multi_reduction op to subgroup level.
+struct WgToSgMultiDimReductionOp
+    : public OpConversionPattern<vector::MultiDimReductionOp> {
+  using OpConversionPattern<vector::MultiDimReductionOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::MultiDimReductionOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    // Only support reduction with layout and on a single dimension for now.
+    VectorType srcType = dyn_cast<VectorType>(op.getSource().getType());
+    VectorType accType = dyn_cast<VectorType>(op.getAcc().getType());
+    VectorType resType = dyn_cast<VectorType>(op.getResult().getType());
+    Type elemTy = srcType.getElementType();
+    if (!srcType || !accType || !resType)
+      return failure();
+
+    ArrayRef<int64_t> wgShape = resType.getShape();
+    // Handle both LayoutAttr and SliceAttr for the op result.
+    auto layoutName = xegpu::getLayoutName(op->getResult(0));
+    auto sliceAttr = op->getAttrOfType<xegpu::SliceAttr>(layoutName);
+    if (!sliceAttr || sliceAttr.getRank() != 1)
+      return failure();
+
+    SmallVector<int64_t> dims =
+        llvm::to_vector(sliceAttr.getDims().asArrayRef());
+    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, sliceAttr).first;
+
+    int64_t reduceDim = dims[0];
+
+    // Step 1: Subgroup-level reduction
+    // Each subgroup reduces its local tile.
+    SmallVector<Value> newReductions;
+    VectorType newType = VectorType::get(sgShape, srcType.getElementType());
+    SmallVector<int64_t> shapeCastShape = sgShape;
+    if (reduceDim == 0)
+      shapeCastShape.insert(shapeCastShape.begin(), 1);
+    else
+      shapeCastShape.push_back(1);
+    for (auto [sgSrc, sgAcc] :
+         llvm::zip(adaptor.getSource(), adaptor.getAcc())) {
+      auto sgReduce = rewriter.create<vector::MultiDimReductionOp>(
+          op.getLoc(), newType, op.getKind(), sgSrc, sgAcc,
+          op.getReductionDims());
+      // Compute the shape for the shape cast: set the reduced dim to 1 and
+      // keep the other dims as sgShape.
+      auto shapeCastTy =
+          VectorType::get(shapeCastShape, srcType.getElementType());
+      auto shapeCast = rewriter.create<vector::ShapeCastOp>(
+          op.getLoc(), shapeCastTy, sgReduce.getResult());
+      // TODO: Change it to shapeCast
+      newReductions.push_back(shapeCast.getResult());
+    }
+
+    rewriter.setInsertionPoint(op);
+
+    // Get layout of the source tensor
+    SmallVector<int64_t> sgLayoutParent =
+        sliceAttr.getParent().getSgLayoutAsInt();
+
+    // Allocate SLM
+    auto bitWidth = elemTy.getIntOrFloatBitWidth();
+    auto flattenFactor = bitWidth / 8;
+    auto slmSize =
+        resType.getNumElements() * sgLayoutParent[reduceDim] * flattenFactor;
+    auto slmTy = MemRefType::get(slmSize, rewriter.getI8Type(), {}, 3);
+    auto slm = rewriter.create<memref::AllocaOp>(loc, slmTy);
+
+    // Create a view for the SLM buffer using xegpu.create_mem_desc
+    SmallVector<int64_t> viewShape;
+    auto srcVecType = dyn_cast<VectorType>(adaptor.getSource()[0].getType());
+    ArrayRef<int64_t> srcShape =
+        srcVecType ? srcVecType.getShape() : ArrayRef<int64_t>();
+    for (size_t i = 0; i < srcShape.size(); ++i) {
+      if (static_cast<int64_t>(i) == reduceDim) {
+        // For the reduced dimension, use sgLayoutParent[i]
+        viewShape.push_back(sgLayoutParent[i]);
+      } else {
+        // For other dimensions, multiply sgLayoutParent[i] by sgShape[i]
+        viewShape.push_back(sgLayoutParent[i] * srcShape[i]);
+      }
+    }
+
+    auto memDescType = xegpu::MemDescType::get(rewriter.getContext(), viewShape,
+                                               elemTy, nullptr);
+    auto memDesc =
+        rewriter.create<xegpu::CreateMemDescOp>(loc, memDescType, slm);
+
+    // Step 2: Store subgroup results to SLM (shared local memory)
+    // SLM layout: sgLayout same as srcLayout, sgData is shapeCastShape
+    SmallVector<int64_t> slmSgData = shapeCastShape;
+
+    // Get subgroup id and delinearize
+    auto sgId = rewriter.create<gpu::SubgroupIdOp>(loc, rewriter.getIndexType(),
+                                                   nullptr);
+
+    SmallVector<Value> srcSgLayoutDim(sgLayoutParent.size());
+
+    for (size_t i = 0; i < sgLayoutParent.size(); i++) {
+      srcSgLayoutDim[i] =
+          arith::ConstantIndexOp::create(rewriter, loc, sgLayoutParent[i]);
+    }
+
+    auto sgIdVec =
+        affine::delinearizeIndex(rewriter, loc, sgId, srcSgLayoutDim);
+    if (failed(sgIdVec))
+      return failure();
+    SmallVector<Value> sgIds = *sgIdVec;
+
+    // Calculate offsets for store_matrix
+    SmallVector<OpFoldResult> slmStoreOffsets;
+    for (size_t i = 0; i < sgLayoutParent.size(); ++i) {
+      Value offset = rewriter.createOrFold<index::MulOp>(
+          loc, sgIds[i],
+          arith::ConstantIndexOp::create(rewriter, loc, slmSgData[i]));
+      slmStoreOffsets.push_back(offset);
+    }
+
+    // Store subgroup result to SLM
+    rewriter.create<xegpu::StoreMatrixOp>(
+        loc, newReductions[0], memDesc.getResult(),
+        ArrayRef<OpFoldResult>(slmStoreOffsets),
+        /*layout=*/nullptr);
+
+    // Barrier to synchronize subgroups
+    rewriter.create<gpu::BarrierOp>(loc);
+
+    // Step 3: Load from SLM for the second reduction
+    SmallVector<int64_t> slmLoadShape;
+
+    for (size_t i = 0; i < viewShape.size(); ++i) {
+      if (static_cast<int64_t>(i) == reduceDim) {
+        slmLoadShape.push_back(viewShape[i]);
+      } else {
+        int64_t divisor = computeProduct(sgLayoutParent);
+        slmLoadShape.push_back(viewShape[i] / divisor);
+      }
+    }
+
+    // Calculate offsets for create_nd_desc
+    SmallVector<OpFoldResult> slmLoadOffsets;
+    for (size_t i = 0; i < sgLayoutParent.size(); ++i) {
+      Value offset = rewriter.createOrFold<index::MulOp>(
+          loc, sgIds[i],
+          arith::ConstantIndexOp::create(rewriter, loc, slmLoadShape[i]));
+      slmLoadOffsets.push_back(offset);
+    }
+
+    auto load = rewriter.create<xegpu::LoadMatrixOp>(
+        loc, VectorType::get(slmLoadShape, elemTy), memDesc,
+        llvm::ArrayRef<OpFoldResult>({slmLoadOffsets}),
+        /*layout=*/nullptr);
+
+    // Step 4: Create a constant accumulator for the second reduction
+    // with same vallue as adaptor.getAcc()[0] and shape set to
+    // the non reduce dimension of shapeCastLoad
+    auto accShape = load.getType().getShape();
+    SmallVector<int64_t> accShapeWithoutReduceDim;
+    for (size_t i = 0; i < accShape.size(); ++i) {
+      if (static_cast<int64_t>(i) != reduceDim)
+        accShapeWithoutReduceDim.push_back(accShape[i]);
+    }
+    auto accTy = VectorType::get(accShapeWithoutReduceDim, elemTy);
+    auto accConstOp = adaptor.getAcc()[0].getDefiningOp<arith::ConstantOp>();
+    Attribute accSplatValue;
+    if (auto denseAttr = dyn_cast_or_null<DenseElementsAttr>(
+            accConstOp ? accConstOp.getValue() : nullptr)) {
+      accSplatValue =
+          denseAttr.isSplat() ? denseAttr.getSplatValue<Attribute>() : nullptr;
+    }
+    if (!accSplatValue)
+      return failure();
+    auto accValue = rewriter.create<arith::ConstantOp>(
+        loc, accTy, DenseElementsAttr::get(accTy, accSplatValue));
+    // Step 5: Perform the second reduction
+    VectorType secondReduceVecType =
+        VectorType::get(accShapeWithoutReduceDim, srcType.getElementType());
+    auto secondReduce = rewriter.create<vector::MultiDimReductionOp>(
+        loc, secondReduceVecType, op.getKind(), load, accValue,
+        op.getReductionDims());
+    rewriter.replaceOpWithMultiple(op, {secondReduce.getResult()});
+    return success();
+  }
+};
+
 } // namespace
 
 namespace mlir {
@@ -826,8 +1017,8 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
            WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
            WgToSgPrefetchNdOpWithOffset, UnrealizedConversionCastOpPattern,
            WgToSgElementwiseOp, WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp,
-           WgToSgArithConstantOp, WgToSgLoadMatrixOp, WgToSgStoreMatrixOp>(
-          patterns.getContext());
+           WgToSgArithConstantOp, WgToSgLoadMatrixOp, WgToSgStoreMatrixOp,
+           WgToSgMultiDimReductionOp>(patterns.getContext());
 }
 } // namespace xegpu
 } // namespace mlir
@@ -987,6 +1178,16 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
         return isLegal(layout);
       });
 
+  target.addDynamicallyLegalOp<vector::MultiDimReductionOp>(
+      [=](vector::MultiDimReductionOp op) -> bool {
+        // Only allow MultiDimReductionOp with a single reduction dimension
+        if (op.getReductionDims().size() != 1)
+          return true;
+
+        // Check if the layout is legal
+        return isLegal(xegpu::getDistributeLayoutAttr(op.getResult()));
+      });
+
   target.addDynamicallyLegalOp<UnrealizedConversionCastOp>(
       [=](UnrealizedConversionCastOp op) {
         return llvm::is_contained(existingCastOps, op.getOperation());
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index 32157a7911f62..1b417d752edcc 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -321,4 +321,17 @@ gpu.module @test_distribution {
     xegpu.store_matrix %cst, %mdesc[0, 0] {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [32, 32]>} : vector<64x128xf32>, !xegpu.mem_desc<64x128xf32>
     gpu.return
   }
+
+  //CHECK-LABEL: vector_reduce
+  gpu.func @vector_reduce(%src: memref<256x128xf32>) {
+    %cst = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>, dims = [0]>} dense<1.0> : vector<128xf32>
+    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+      -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [8, 4], lane_data = [1, 1]>>
+    %load =  xegpu.load_nd %tdesc
+      : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [8, 4], lane_data = [1, 1]>>
+      -> vector<256x128xf32>
+    %reduce = vector.multi_reduction <add>, %load, %cst {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>, dims = [0]>} [0]
+      : vector<256x128xf32> to vector<128xf32>
+    gpu.return
+  }
 }
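
Not part of the patch: a minimal, standalone C++ sketch (no MLIR dependencies, illustrative variable names) that walks through the shape and SLM-size arithmetic the new pattern performs for the test case above, i.e. a workgroup source vector<256x128xf32> with sg_layout = [8, 4], sg_data = [32, 32], reducing along dim 0. The concrete numbers come from the patch and its test; the helper names are assumptions made for the sketch.

#include <cassert>
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

int main() {
  const std::vector<int64_t> sgLayout = {8, 4}; // subgroups per dimension
  const std::vector<int64_t> sgData = {32, 32}; // per-subgroup source tile
  const int64_t reduceDim = 0;
  const int64_t elemBytes = 4;                  // f32

  // Step 1: each subgroup reduces its 32x32 tile along dim 0 to vector<32xf32>,
  // then shape_casts it to vector<1x32xf32> so it can be stored as a 2-D tile.
  std::vector<int64_t> shapeCastShape = sgData;
  shapeCastShape[reduceDim] = 1;                // {1, 32}

  // The SLM view (xegpu.mem_desc) keeps sgLayout on the reduced dim and
  // sgLayout * sgData on the other dims: {8, 128}.
  std::vector<int64_t> memDescShape(2);
  for (int64_t i = 0; i < 2; ++i)
    memDescShape[i] = (i == reduceDim) ? sgLayout[i] : sgLayout[i] * sgData[i];
  assert(memDescShape[0] == 8 && memDescShape[1] == 128);

  // SLM allocation in bytes: workgroup result elements (128) times
  // sgLayout[reduceDim] (8) times bytes per element (4) = 4096,
  // i.e. memref<4096xi8, 3>.
  const int64_t wgResultElems = 128;
  const int64_t slmBytes = wgResultElems * sgLayout[reduceDim] * elemBytes;
  assert(slmBytes == 4096);

  // Step 3: each subgroup loads a slab spanning the whole reduced dim and a
  // 1/32 slice of the other dim: {8, 128 / 32} = {8, 4}, i.e. vector<8x4xf32>.
  const int64_t numSubgroups =
      std::accumulate(sgLayout.begin(), sgLayout.end(), int64_t{1},
                      std::multiplies<int64_t>());
  std::vector<int64_t> loadShape(2);
  for (int64_t i = 0; i < 2; ++i)
    loadShape[i] = (i == reduceDim) ? memDescShape[i]
                                    : memDescShape[i] / numSubgroups;
  assert(loadShape[0] == 8 && loadShape[1] == 4);

  // Step 5: the second multi_reduction collapses dim 0 of the loaded slab,
  // leaving a vector<4xf32> per subgroup.
  return 0;
}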
From 100341dec307fbab0612abc49fe38d27239d177a Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Wed, 3 Sep 2025 17:23:28 +0000
Subject: [PATCH 2/4] Add CHECKS
---
 .../Transforms/XeGPUWgToSgDistribute.cpp      | 34 ++++++++-----------
 .../XeGPU/xegpu-wg-to-sg-unify-ops.mlir       | 26 ++++++++++++++
 2 files changed, 41 insertions(+), 19 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 54a98970a0fc0..fe5026203ad34 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -830,7 +830,6 @@ struct WgToSgMultiDimReductionOp
   matchAndRewrite(vector::MultiDimReductionOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
-    // Only support reduction with layout and on a single dimension for now.
     VectorType srcType = dyn_cast<VectorType>(op.getSource().getType());
     VectorType accType = dyn_cast<VectorType>(op.getAcc().getType());
     VectorType resType = dyn_cast<VectorType>(op.getResult().getType());
@@ -838,8 +837,10 @@ struct WgToSgMultiDimReductionOp
     if (!srcType || !accType || !resType)
       return failure();
 
+    // Only 2-D source vectors (reduced to a 1-D result) are supported for now.
+    if (srcType.getShape().size() != 2 || resType.getShape().size() != 1)
+      return failure();
     ArrayRef<int64_t> wgShape = resType.getShape();
-    // Handle both LayoutAttr and SliceAttr for the op result.
     auto layoutName = xegpu::getLayoutName(op->getResult(0));
     auto sliceAttr = op->getAttrOfType<xegpu::SliceAttr>(layoutName);
     if (!sliceAttr || sliceAttr.getRank() != 1)
@@ -871,7 +872,6 @@ struct WgToSgMultiDimReductionOp
           VectorType::get(shapeCastShape, srcType.getElementType());
       auto shapeCast = rewriter.create<vector::ShapeCastOp>(
           op.getLoc(), shapeCastTy, sgReduce.getResult());
-      // TODO: Change it to shapeCast
       newReductions.push_back(shapeCast.getResult());
     }
 
@@ -889,23 +889,23 @@ struct WgToSgMultiDimReductionOp
     auto slmTy = MemRefType::get(slmSize, rewriter.getI8Type(), {}, 3);
     auto slm = rewriter.create<memref::AllocaOp>(loc, slmTy);
 
-    // Create a view for the SLM buffer using xegpu.create_mem_desc
-    SmallVector<int64_t> viewShape;
+    // Create an SLM buffer using xegpu.create_mem_desc
+    SmallVector<int64_t> memDescShape;
     auto srcVecType = dyn_cast<VectorType>(adaptor.getSource()[0].getType());
     ArrayRef<int64_t> srcShape =
         srcVecType ? srcVecType.getShape() : ArrayRef<int64_t>();
     for (size_t i = 0; i < srcShape.size(); ++i) {
       if (static_cast<int64_t>(i) == reduceDim) {
         // For the reduced dimension, use sgLayoutParent[i]
-        viewShape.push_back(sgLayoutParent[i]);
+        memDescShape.push_back(sgLayoutParent[i]);
       } else {
         // For other dimensions, multiply sgLayoutParent[i] by sgShape[i]
-        viewShape.push_back(sgLayoutParent[i] * srcShape[i]);
+        memDescShape.push_back(sgLayoutParent[i] * srcShape[i]);
       }
     }
 
-    auto memDescType = xegpu::MemDescType::get(rewriter.getContext(), viewShape,
-                                               elemTy, nullptr);
+    auto memDescType = xegpu::MemDescType::get(rewriter.getContext(),
+                                               memDescShape, elemTy, nullptr);
     auto memDesc =
         rewriter.create<xegpu::CreateMemDescOp>(loc, memDescType, slm);
 
@@ -951,16 +951,16 @@ struct WgToSgMultiDimReductionOp
     // Step 3: Load from SLM for the second reduction
     SmallVector<int64_t> slmLoadShape;
 
-    for (size_t i = 0; i < viewShape.size(); ++i) {
+    for (size_t i = 0; i < memDescShape.size(); ++i) {
       if (static_cast<int64_t>(i) == reduceDim) {
-        slmLoadShape.push_back(viewShape[i]);
+        slmLoadShape.push_back(memDescShape[i]);
       } else {
         int64_t divisor = computeProduct(sgLayoutParent);
-        slmLoadShape.push_back(viewShape[i] / divisor);
+        slmLoadShape.push_back(memDescShape[i] / divisor);
       }
     }
 
-    // Calculate offsets for create_nd_desc
+    // Calculate offsets for load_matrix op
     SmallVector<OpFoldResult> slmLoadOffsets;
     for (size_t i = 0; i < sgLayoutParent.size(); ++i) {
       Value offset = rewriter.createOrFold<index::MulOp>(
@@ -975,8 +975,8 @@ struct WgToSgMultiDimReductionOp
         /*layout=*/nullptr);
 
     // Step 4: Create a constant accumulator for the second reduction
-    // with same vallue as adaptor.getAcc()[0] and shape set to
-    // the non reduce dimension of shapeCastLoad
+    // with the same value as adaptor.getAcc()[0] and shape set to
+    // the non-reduced dimension of the load.
     auto accShape = load.getType().getShape();
     SmallVector<int64_t> accShapeWithoutReduceDim;
     for (size_t i = 0; i < accShape.size(); ++i) {
@@ -1180,10 +1180,6 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
 
   target.addDynamicallyLegalOp<vector::MultiDimReductionOp>(
       [=](vector::MultiDimReductionOp op) -> bool {
-        // Only allow MultiDimReductionOp with a single reduction dimension
-        if (op.getReductionDims().size() != 1)
-          return true;
-
         // Check if the layout is legal
         return isLegal(xegpu::getDistributeLayoutAttr(op.getResult()));
       });
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index 1b417d752edcc..50bcc4341291e 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -323,7 +323,33 @@ gpu.module @test_distribution {
   }
 
   //CHECK-LABEL: vector_reduce
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
   gpu.func @vector_reduce(%src: memref<256x128xf32>) {
+    // CHECK: %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<32xf32>
+    // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<256x128xf32> -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [8, 4], lane_data = [1, 1]>>
+    // CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[TDESC]] : !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [8, 4], lane_data = [1, 1]>> -> vector<32x32xf32>
+    // CHECK: %[[REDUCE:.*]] = vector.multi_reduction <add>, {{%.*}}, %[[CST]] [0] : vector<32x32xf32> to vector<32xf32>
+    // CHECK: %[[SHAPECAST:.*]] = vector.shape_cast %[[REDUCE]] : vector<32xf32> to vector<1x32xf32>
+    // CHECK: %[[ALLOCA:.*]] = memref.alloca() : memref<4096xi8, 3>
+    // CHECK: %[[MDESC:.*]] = xegpu.create_mem_desc %[[ALLOCA]] : memref<4096xi8, 3> -> !xegpu.mem_desc<8x128xf32>
+    // CHECK: %[[SGID:.*]] = gpu.subgroup_id : index
+    // CHECK: %[[C8:.*]] = arith.constant 8 : index
+    // CHECK: %[[C4:.*]] = arith.constant 4 : index
+    // CHECK: %[[C4_1:.*]] = arith.constant 4 : index
+    // CHECK: %[[ID_Y:.*]] = affine.apply #map()[%[[SGID]]]
+    // CHECK: %[[ID_X:.*]] = affine.apply #map1()[%[[SGID]]]
+    // CHECK: %[[C1:.*]] = arith.constant 1 : index
+    // CHECK: %[[C32:.*]] = arith.constant 32 : index
+    // CHECK: %[[L_OFF_X:.*]] = index.mul %[[ID_X]], %[[C32]]
+    // CHECK: xegpu.store_matrix {{.*}}, %[[MDESC]][%[[ID_Y]], %[[L_OFF_X]]] : vector<1x32xf32>, !xegpu.mem_desc<8x128xf32>, index, index
+    // CHECK: gpu.barrier
+    // CHECK: %[[C8_1:.*]] = arith.constant 8 : index
+    // CHECK: %[[OFF_Y:.*]] = index.mul %[[ID_Y]], %[[C8_1]]
+    // CHECK: %[[C4_2:.*]] = arith.constant 4 : index
+    // CHECK: %[[OFF_X:.*]] = index.mul %[[ID_X]], %[[C4_2]]
+    // CHECK: %[[LOAD:.*]] = xegpu.load_matrix %[[MDESC]][%[[OFF_Y]], %[[OFF_X]]] : !xegpu.mem_desc<8x128xf32>, index, index -> vector<8x4xf32>
+    // CHECK: %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<4xf32>
+    // CHECK: %[[REDUCE:.*]] = vector.multi_reduction <add>, %[[LOAD]], %[[CST]] [0] : vector<8x4xf32> to vector<4xf32>
     %cst = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>, dims = [0]>} dense<1.0> : vector<128xf32>
     %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
       -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [8, 4], lane_data = [1, 1]>>
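
Also not from the patch: a small self-contained C++ sketch of how the 32 per-subgroup partials land in shared local memory before the barrier, for the same test case (sg_layout = [8, 4], partial tiles vector<1x32xf32>, SLM view !xegpu.mem_desc<8x128xf32>). The delinearize() helper is an illustrative stand-in for the row-major result of affine::delinearizeIndex; the offsets follow the store_matrix formula in the pattern (subgroup coordinate times the per-dimension tile size), matching the %ID_Y and %L_OFF_X values checked above.

#include <cassert>
#include <cstdint>
#include <cstdio>

struct SgCoord {
  int64_t y, x;
};

// Row-major delinearization of the flat gpu.subgroup_id over sg_layout = [8, 4].
static SgCoord delinearize(int64_t sgId) { return {sgId / 4, sgId % 4}; }

int main() {
  const int64_t slmSgData[2] = {1, 32}; // shape of each stored partial
  bool covered[8][128] = {};

  for (int64_t sgId = 0; sgId < 8 * 4; ++sgId) {
    SgCoord c = delinearize(sgId);
    // xegpu.store_matrix offsets: subgroup coordinate times the tile shape,
    // i.e. row = id_y * 1 and column = id_x * 32.
    int64_t rowOff = c.y * slmSgData[0];
    int64_t colOff = c.x * slmSgData[1];
    std::printf("sg %2lld stores its vector<1x32xf32> at [%lld, %3lld]\n",
                (long long)sgId, (long long)rowOff, (long long)colOff);
    for (int64_t j = 0; j < 32; ++j) {
      assert(!covered[rowOff][colOff + j] && "partials must not overlap");
      covered[rowOff][colOff + j] = true;
    }
  }

  // Together the 32 partials tile the 8x128 SLM view exactly once.
  for (auto &row : covered)
    for (bool cell : row)
      assert(cell && "every SLM element is written before the barrier");
  return 0;
}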
From b1e202693d5c665800e67d7ff6cd8b0fe2146d82 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Wed, 3 Sep 2025 21:09:11 +0000
Subject: [PATCH 3/4] Clean up
---
 .../Transforms/XeGPUWgToSgDistribute.cpp      | 26 +++++++++----------
 1 file changed, 12 insertions(+), 14 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 81dcb5d931473..685a9da92e54c 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -981,15 +981,13 @@ struct WgToSgMultiDimReductionOp
 
     rewriter.setInsertionPoint(op);
 
-    // Get layout of the source tensor
-    SmallVector<int64_t> sgLayoutParent =
-        sliceAttr.getParent().getSgLayoutAsInt();
+    SmallVector<int64_t> sgLayout = sliceAttr.getParent().getSgLayoutAsInt();
 
     // Allocate SLM
     auto bitWidth = elemTy.getIntOrFloatBitWidth();
     auto flattenFactor = bitWidth / 8;
     auto slmSize =
-        resType.getNumElements() * sgLayoutParent[reduceDim] * flattenFactor;
+        resType.getNumElements() * sgLayout[reduceDim] * flattenFactor;
     auto slmTy = MemRefType::get(slmSize, rewriter.getI8Type(), {}, 3);
     auto slm = rewriter.create<memref::AllocaOp>(loc, slmTy);
 
@@ -1000,11 +998,11 @@ struct WgToSgMultiDimReductionOp
         srcVecType ? srcVecType.getShape() : ArrayRef<int64_t>();
     for (size_t i = 0; i < srcShape.size(); ++i) {
       if (static_cast<int64_t>(i) == reduceDim) {
-        // For the reduced dimension, use sgLayoutParent[i]
-        memDescShape.push_back(sgLayoutParent[i]);
+        // For the reduced dimension, use sgLayout[i]
+        memDescShape.push_back(sgLayout[i]);
       } else {
-        // For other dimensions, multiply sgLayoutParent[i] by sgShape[i]
-        memDescShape.push_back(sgLayoutParent[i] * srcShape[i]);
+        // For other dimensions, multiply sgLayout[i] by sgShape[i]
+        memDescShape.push_back(sgLayout[i] * srcShape[i]);
       }
     }
 
@@ -1021,11 +1019,11 @@ struct WgToSgMultiDimReductionOp
     auto sgId = rewriter.create<gpu::SubgroupIdOp>(loc, rewriter.getIndexType(),
                                                    nullptr);
 
-    SmallVector<Value> srcSgLayoutDim(sgLayoutParent.size());
+    SmallVector<Value> srcSgLayoutDim(sgLayout.size());
 
-    for (size_t i = 0; i < sgLayoutParent.size(); i++) {
+    for (size_t i = 0; i < sgLayout.size(); i++) {
       srcSgLayoutDim[i] =
-          arith::ConstantIndexOp::create(rewriter, loc, sgLayoutParent[i]);
+          arith::ConstantIndexOp::create(rewriter, loc, sgLayout[i]);
     }
 
     auto sgIdVec =
@@ -1036,7 +1034,7 @@ struct WgToSgMultiDimReductionOp
 
     // Calculate offsets for store_matrix
     SmallVector<OpFoldResult> slmStoreOffsets;
-    for (size_t i = 0; i < sgLayoutParent.size(); ++i) {
+    for (size_t i = 0; i < sgLayout.size(); ++i) {
       Value offset = rewriter.createOrFold<index::MulOp>(
           loc, sgIds[i],
           arith::ConstantIndexOp::create(rewriter, loc, slmSgData[i]));
@@ -1059,14 +1057,14 @@ struct WgToSgMultiDimReductionOp
       if (static_cast<int64_t>(i) == reduceDim) {
         slmLoadShape.push_back(memDescShape[i]);
       } else {
-        int64_t divisor = computeProduct(sgLayoutParent);
+        int64_t divisor = computeProduct(sgLayout);
         slmLoadShape.push_back(memDescShape[i] / divisor);
       }
     }
 
     // Calculate offsets for load_matrix op
     SmallVector<OpFoldResult> slmLoadOffsets;
-    for (size_t i = 0; i < sgLayoutParent.size(); ++i) {
+    for (size_t i = 0; i < sgLayout.size(); ++i) {
       Value offset = rewriter.createOrFold<index::MulOp>(
           loc, sgIds[i],
           arith::ConstantIndexOp::create(rewriter, loc, slmLoadShape[i]));
From ff3baed020510463408539c6e17a942f4a0f2353 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Thu, 4 Sep 2025 17:34:29 +0000
Subject: [PATCH 4/4] clean up
---
 mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index 8934865c49cd3..fb1eff1ae8c07 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -370,8 +370,8 @@ gpu.module @test_distribution {
   // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
   gpu.func @vector_reduce(%src: memref<256x128xf32>) {
     // CHECK: %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<32xf32>
-    // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<256x128xf32> -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [8, 4], lane_data = [1, 1]>>
-    // CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[TDESC]] : !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [8, 4], lane_data = [1, 1]>> -> vector<32x32xf32>
+    // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<256x128xf32> -> !xegpu.tensor_desc<32x32xf32>
+    // CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[TDESC]] : !xegpu.tensor_desc<32x32xf32>
     // CHECK: %[[REDUCE:.*]] = vector.multi_reduction <add>, {{%.*}}, %[[CST]] [0] : vector<32x32xf32> to vector<32xf32>
     // CHECK: %[[SHAPECAST:.*]] = vector.shape_cast %[[REDUCE]] : vector<32xf32> to vector<1x32xf32>
     // CHECK: %[[ALLOCA:.*]] = memref.alloca() : memref<4096xi8, 3>
@@ -396,9 +396,9 @@ gpu.module @test_distribution {
     // CHECK: %[[REDUCE:.*]] = vector.multi_reduction <add>, %[[LOAD]], %[[CST]] [0] : vector<8x4xf32> to vector<4xf32>
     %cst = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>, dims = [0]>} dense<1.0> : vector<128xf32>
     %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
-      -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [8, 4], lane_data = [1, 1]>>
+      -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>>
     %load =  xegpu.load_nd %tdesc
-      : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [8, 4], lane_data = [1, 1]>>
+      : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>>
       -> vector<256x128xf32>
     %reduce = vector.multi_reduction <add>, %load, %cst {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>, dims = [0]>} [0]
       : vector<256x128xf32> to vector<128xf32>
    
    