[Mlir-commits] [mlir] [MLIR][XeGPU] Preserve leading unit dimension during blocking (PR #180884)

Jianhui Li llvmlistbot at llvm.org
Tue Feb 10 20:59:53 PST 2026


https://github.com/Jianhui-Li created https://github.com/llvm/llvm-project/pull/180884

This PR preserves leading unit dimensions during blocking. This ensures that the blocking process avoids generating unnecessary insert/extract_strided_slice ops, which under certain conditions become difficult to cancel and create an extra burden for lane layout propagation and subgroup distribution.
This PR also extends subgroup distribution so that load and store support payload/mask/offsets with leading unit dimensions. The distributed load/store operates on 1D vectors only, but shape_cast ops are inserted to remove and add the leading unit dimensions for the input/output vectors. Compared to the insert/extract ops inserted at the subgroup level, a shape_cast inserted at the lane level to handle leading unit dimensions is essentially a no-op and is cheap to process.

>From 17d9b0dd3a88f52f40d0041451a84c7c1509ed12 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Tue, 10 Feb 2026 23:03:23 +0000
Subject: [PATCH 1/4] allow leading dim for scatter load in blocking and sg
 distribution

---
 .../XeGPU/Transforms/XeGPUBlocking.cpp        |  8 +-
 .../Transforms/XeGPUSubgroupDistribute.cpp    | 88 ++++++++++++++-----
 .../Dialect/XeGPU/Transforms/XeGPUUnroll.cpp  |  7 ++
 mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp   | 15 ++++
 4 files changed, 96 insertions(+), 22 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
index 6faa25cf49df9..8b8647fbc8757 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
@@ -162,13 +162,19 @@ XeGPUBlockingPass::getTileShape(const T &operandOrResult) const {
           ownerOp &&
           (isa<xegpu::CreateNdDescOp, xegpu::DpasOp, xegpu::ConvertLayoutOp,
                xegpu::LoadMatrixOp, xegpu::StoreMatrixOp, xegpu::AtomicRMWOp,
-               xegpu::LoadNdOp, xegpu::StoreNdOp, xegpu::PrefetchNdOp,
+               xegpu::LoadGatherOp, xegpu::StoreScatterOp, xegpu::LoadNdOp, xegpu::StoreNdOp, xegpu::PrefetchNdOp,
                vector::TransposeOp, vector::ShapeCastOp,
                vector::MultiDimReductionOp, vector::BroadcastOp>(ownerOp));
       if (!skipLeadingUnitDimRemoval) {
         auto it = llvm::find_if(instData, [](auto val) { return val != 1; });
         instData.erase(instData.begin(), it);
       }
+      
+      //print instData for debugging
+      llvm::errs() << "Inst data for value " << value << " in op " << *ownerOp << ": ";
+      for (auto dim : instData)        llvm::errs() << dim << " ";
+      llvm::errs() << "\n"; 
+
       return instData;
     }
 
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index aa1dfaa9e0fda..25462a93e5393 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -1058,6 +1058,15 @@ struct StoreMatrixDistribution final : public gpu::WarpDistributionPattern {
 /// To
 ///    %0 = xegpu.load %payload, %src[%offset], %mask <{chunk_size=8}> :
 ///     memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
+///
+/// Note that the load distribution pattern also handles leading unit dimensions
+/// in the payload vector to support cases where the load is distributed by
+/// the warp op with unit dimensions added to the front of the vector. In this
+/// case the load distribution will only change the dimensions corresponding to
+/// the SG distribution and keep the leading unit dimensions unchanged. For
+/// example, a load with result type vector<1x16xf16> distributed by the warp op
+/// with layout expecting vector<1x16xf16> will be transformed to have result
+/// type vector<1x1xf16>.
 struct LoadDistribution final : public gpu::WarpDistributionPattern {
   using gpu::WarpDistributionPattern::WarpDistributionPattern;
   LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
@@ -1082,19 +1091,23 @@ struct LoadDistribution final : public gpu::WarpDistributionPattern {
           "Load op must have a vector arguments for offsets and mask");
     VectorType offsetsTy = cast<VectorType>(offsets.getType());
     VectorType maskTy = cast<VectorType>(loadGatherOp.getMask().getType());
-    if (offsetsTy.getRank() != 1 || maskTy.getRank() != 1)
-      return rewriter.notifyMatchFailure(loadGatherOp,
-                                         "Expected 1D offsets and mask vector");
-    // Assume offset and mask producers will be distributed as well.
-    std::string layoutOffsetsName =
-        xegpu::getTemporaryLayoutName(loadGatherOp->getOpOperand(1));
-    std::string layoutMaskName =
-        xegpu::getTemporaryLayoutName(loadGatherOp->getOpOperand(2));
+    VectorType resultVecTy =
+        cast<VectorType>(loadGatherOp.getResult().getType());
+
+    // add handling leading unit dimensions support
+    int chunkSize = loadGatherOp.getChunkSize().value_or(1);
+    int effectiveVecRank = chunkSize > 1 ? 1 : 2;
+    for (int i = 0; i < resultVecTy.getRank() - effectiveVecRank; i++) {
+      if (resultVecTy.getShape()[i] != 1) {
+        return rewriter.notifyMatchFailure(
+            loadGatherOp, "Only unit dimensions allowed for the leading "
+                          "dimensions of the load vector!");
+      }
+    }
 
-    xegpu::LayoutAttr layoutOffsets =
-        loadGatherOp->getAttrOfType<xegpu::LayoutAttr>(layoutOffsetsName);
-    xegpu::LayoutAttr layoutMask =
-        loadGatherOp->getAttrOfType<xegpu::LayoutAttr>(layoutMaskName);
+    auto layoutOffsets =
+        xegpu::getTemporaryLayout(loadGatherOp->getOpOperand(1));
+    auto layoutMask = xegpu::getTemporaryLayout(loadGatherOp->getOpOperand(2));
 
     FailureOr<VectorType> distOffsetsByWarpOpOrFailure =
         getDistVecTypeBasedOnLaneLayout(layoutOffsets, offsetsTy);
@@ -1109,26 +1122,59 @@ struct LoadDistribution final : public gpu::WarpDistributionPattern {
 
     SmallVector<size_t> newRetIndices;
     SmallVector<Value> operands = loadGatherOp->getOperands();
-    SmallVector<Type> operandTypesToYield = {
-        operands[0].getType(), distOffsetsByWarpOpOrFailure.value(),
-        distMaskByWarpOpOrFailure.value()};
 
     const unsigned operandIdx = producedByLastLoad->getOperandNumber();
     VectorType distResultTy =
         cast<VectorType>(warpOp.getResult(operandIdx).getType());
-    // Distributed load op will always be 1D.
-    VectorType loadVecTy = VectorType::get({distResultTy.getNumElements()},
-                                           distResultTy.getElementType());
+    VectorType distOffsetsTy = distOffsetsByWarpOpOrFailure.value();
+    VectorType distMaskTy = distMaskByWarpOpOrFailure.value();
+
+    SmallVector<Type> operandTypesToYield = {operands[0].getType(),
+                                             distOffsetsTy, distMaskTy};
+
+    // Debug print
+    llvm::errs() << "LoadDistribution: operands.size() = " << operands.size()
+                 << "\n";
+    llvm::errs() << "LoadDistribution: operandTypesToYield.size() = "
+                 << operandTypesToYield.size() << "\n";
+    for (size_t i = 0; i < operands.size(); ++i) {
+      llvm::errs() << "  operand[" << i << "] type: " << operands[i].getType()
+                   << "\n";
+    }
+    for (size_t i = 0; i < operandTypesToYield.size(); ++i) {
+      llvm::errs() << "  operandTypesToYield[" << i
+                   << "]: " << operandTypesToYield[i] << "\n";
+    }
 
     gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
         rewriter, warpOp, operands, operandTypesToYield, newRetIndices);
 
-    SmallVector<Value> newLoadGatherOperands = llvm::map_to_vector(
-        newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
+    // SmallVector<Value> newLoadGatherOperands = llvm::map_to_vector(
+    //     newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
 
     rewriter.setInsertionPointAfter(newWarpOp);
+
+    // Distributed load op will always be 1D.
+    VectorType loadVecTy1D = VectorType::get({distResultTy.getNumElements()},
+                                             distResultTy.getElementType());
+
+    VectorType distOffsetsTy1D =
+        VectorType::get({distOffsetsByWarpOpOrFailure.value().getNumElements()},
+                        distOffsetsByWarpOpOrFailure.value().getElementType());
+    VectorType distMaskTy1D =
+        VectorType::get({distMaskByWarpOpOrFailure.value().getNumElements()},
+                        distMaskByWarpOpOrFailure.value().getElementType());
+
+    Value distOffsetVal = resolveDistributedTy(
+        newWarpOp.getResult(newRetIndices[1]), distOffsetsTy1D, rewriter);
+    Value distmaskVal = resolveDistributedTy(
+        newWarpOp.getResult(newRetIndices[2]), distMaskTy1D, rewriter);
+
+    SmallVector<Value> newLoadGatherOperands = {
+        newWarpOp.getResult(newRetIndices[0]), distOffsetVal, distmaskVal};
+
     xegpu::LoadGatherOp newOp = xegpu::LoadGatherOp::create(
-        rewriter, newWarpOp.getLoc(), loadVecTy, newLoadGatherOperands,
+        rewriter, newWarpOp.getLoc(), loadVecTy1D, newLoadGatherOperands,
         loadGatherOp->getAttrs());
     xegpu::removeLayoutAttrs(newOp);
     Value distributedVal = newWarpOp.getResult(operandIdx);
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index 2b1bd4d73a576..792fe12b88397 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -728,6 +728,13 @@ struct UnrollStoreScatterOpWithOffsets
       return failure();
 
     std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
+    //print target shape for debugging  
+    llvm::errs() << "Target shape: ";
+    if (targetShape) {
+      for (auto dim : *targetShape)
+        llvm::errs() << dim << " ";
+    }
+    llvm::errs() << "\n";
     if (!targetShape)
       return failure();
 
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index c47fd92fe46d7..fb5dff8981c9d 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -390,11 +390,26 @@ xegpu::extractVectorsWithShapeFromValue(OpBuilder &builder, Location loc,
   for (SmallVector<int64_t> offsets :
        StaticTileOffsetRange(srcShape, adjustedTargetShape)) {
     SmallVector<int64_t> staticStrides(offsets.size(), 1);
+    
+    // Debug print
+    llvm::errs() << "Extracting slice with offsets: [";
+    for (size_t i = 0; i < offsets.size(); ++i) {
+      llvm::errs() << offsets[i];
+      if (i + 1 < offsets.size()) llvm::errs() << ", ";
+    }
+    llvm::errs() << "], shape: [";
+    for (size_t i = 0; i < adjustedTargetShape.size(); ++i) {
+      llvm::errs() << adjustedTargetShape[i];
+      if (i + 1 < adjustedTargetShape.size()) llvm::errs() << ", ";
+    }
+    llvm::errs() << "]\n";
+    
     Value slice = vector::ExtractStridedSliceOp::create(
         builder, loc, value, offsets, adjustedTargetShape, staticStrides);
 
     // Reshape to remove leading unit dims if needed
     if (srcShapeRank > targetShapeRank) {
+      llvm::errs() << "Reshaping from rank " << srcShapeRank << " to rank " << targetShapeRank << "\n";
       auto targetTy = VectorType::get(shape, vecTy.getElementType());
       slice = vector::ShapeCastOp::create(builder, loc, targetTy, slice);
     }

>From 21517cf3bab285af7bf50f940fe5e16453f75e54 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Tue, 10 Feb 2026 23:29:22 +0000
Subject: [PATCH 2/4] minor fix of chunk_size

---
 .../Transforms/XeGPUSubgroupDistribute.cpp      | 17 +----------------
 1 file changed, 1 insertion(+), 16 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 25462a93e5393..4ec01d4517a65 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -1093,10 +1093,9 @@ struct LoadDistribution final : public gpu::WarpDistributionPattern {
     VectorType maskTy = cast<VectorType>(loadGatherOp.getMask().getType());
     VectorType resultVecTy =
         cast<VectorType>(loadGatherOp.getResult().getType());
-
     // add handling leading unit dimensions support
     int chunkSize = loadGatherOp.getChunkSize().value_or(1);
-    int effectiveVecRank = chunkSize > 1 ? 1 : 2;
+    int effectiveVecRank = (chunkSize == 1) ? 1 : 2;
     for (int i = 0; i < resultVecTy.getRank() - effectiveVecRank; i++) {
       if (resultVecTy.getShape()[i] != 1) {
         return rewriter.notifyMatchFailure(
@@ -1132,20 +1131,6 @@ struct LoadDistribution final : public gpu::WarpDistributionPattern {
     SmallVector<Type> operandTypesToYield = {operands[0].getType(),
                                              distOffsetsTy, distMaskTy};
 
-    // Debug print
-    llvm::errs() << "LoadDistribution: operands.size() = " << operands.size()
-                 << "\n";
-    llvm::errs() << "LoadDistribution: operandTypesToYield.size() = "
-                 << operandTypesToYield.size() << "\n";
-    for (size_t i = 0; i < operands.size(); ++i) {
-      llvm::errs() << "  operand[" << i << "] type: " << operands[i].getType()
-                   << "\n";
-    }
-    for (size_t i = 0; i < operandTypesToYield.size(); ++i) {
-      llvm::errs() << "  operandTypesToYield[" << i
-                   << "]: " << operandTypesToYield[i] << "\n";
-    }
-
     gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
         rewriter, warpOp, operands, operandTypesToYield, newRetIndices);
 

>From bbc3d41cd679079061de22ffe80a0d5d9edf870e Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Wed, 11 Feb 2026 04:42:23 +0000
Subject: [PATCH 3/4] preserve leading dimension during blocking

---
 .../XeGPU/Transforms/XeGPUBlocking.cpp        |  31 -----
 .../Transforms/XeGPUSubgroupDistribute.cpp    | 109 ++++++++++--------
 mlir/test/Dialect/XeGPU/xegpu-blocking.mlir   |  68 +++++------
 3 files changed, 95 insertions(+), 113 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
index 8b8647fbc8757..4eb6ad51ee9bf 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
@@ -152,32 +152,8 @@ XeGPUBlockingPass::getTileShape(const T &operandOrResult) const {
   if (layout && layout.isForSubgroup()) {
     if (!layout.getEffectiveInstDataAsInt().empty()) {
       SmallVector<int64_t> instData = layout.getEffectiveInstDataAsInt();
-      // Remove leading unit dimensions from inst_data for non-rank-sensitive
-      // ops. For example, if the inst_data is [1, 1, 32] it will pass [32] as
-      // the unroll/blocking size.
-      // Skip it for rank-sensitive ops, whose semantics depend on the tensor
-      // rank (and consequently its shape), and therefore must not alter the
-      // input tile rank or shape, such as by dropping leading dimensions.
-      bool skipLeadingUnitDimRemoval =
-          ownerOp &&
-          (isa<xegpu::CreateNdDescOp, xegpu::DpasOp, xegpu::ConvertLayoutOp,
-               xegpu::LoadMatrixOp, xegpu::StoreMatrixOp, xegpu::AtomicRMWOp,
-               xegpu::LoadGatherOp, xegpu::StoreScatterOp, xegpu::LoadNdOp, xegpu::StoreNdOp, xegpu::PrefetchNdOp,
-               vector::TransposeOp, vector::ShapeCastOp,
-               vector::MultiDimReductionOp, vector::BroadcastOp>(ownerOp));
-      if (!skipLeadingUnitDimRemoval) {
-        auto it = llvm::find_if(instData, [](auto val) { return val != 1; });
-        instData.erase(instData.begin(), it);
-      }
-      
-      //print instData for debugging
-      llvm::errs() << "Inst data for value " << value << " in op " << *ownerOp << ": ";
-      for (auto dim : instData)        llvm::errs() << dim << " ";
-      llvm::errs() << "\n"; 
-
       return instData;
     }
-
     if (auto type = dyn_cast<ShapedType>(value.getType()))
       return llvm::to_vector(type.getShape());
   }
@@ -356,13 +332,6 @@ void XeGPUBlockingPass::runOnOperation() {
 
   xegpu::doSCFStructuralTypeConversionWithTensorType(op, converter);
 
-  // Remove leading unit dimensions from vector ops and then
-  // do the unrolling.
-  {
-    RewritePatternSet patterns(ctx);
-    vector::populateCastAwayVectorLeadingOneDimPatterns(patterns);
-    (void)applyPatternsGreedily(op, std::move(patterns));
-  }
   xegpu::UnrollOptions options;
   options.setFilterConstraint(
       [&](Operation *op) -> LogicalResult { return success(needsUnroll(op)); });
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 4ec01d4517a65..ce5f4f887e910 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -778,6 +778,15 @@ struct GpuBarrierDistribution final : public gpu::WarpDistributionPattern {
 /// To
 ///    xegpu.store %payload, %src[%offset], %mask <{chunk_size=8}> :
 ///     vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
+///
+/// Note that the store distribution pattern also handles leading unit
+/// dimensions in the payload, mask and offsets vectors. In this case the store
+/// distribution will only change the dimensions corresponding to the SG
+/// distribution and keep the leading unit dimensions unchanged.
+/// For example, a store with payload vector<1x16xf16> with lane layout [1, 16 ]
+/// will be distributed as vector<1x1xf16>. Shapecast ops are inserted for the
+/// offset/mask/payload when necessary so that the distributed store is workign
+/// on 1D shape vector to match the HW capability.
 struct StoreDistribution final : public gpu::WarpDistributionPattern {
   using gpu::WarpDistributionPattern::WarpDistributionPattern;
   LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
@@ -792,30 +801,27 @@ struct StoreDistribution final : public gpu::WarpDistributionPattern {
           storeScatterOp, "Store op must have a vector of offsets argument");
     VectorType offsetsTy = cast<VectorType>(offsets.getType());
     VectorType maskTy = cast<VectorType>(storeScatterOp.getMask().getType());
-    if (offsetsTy.getRank() != 1 || maskTy.getRank() != 1)
-      return rewriter.notifyMatchFailure(storeScatterOp,
-                                         "Expected 1D offsets and mask vector");
     VectorType storeVecTy = cast<VectorType>(storeScatterOp.getValueType());
-    if (storeVecTy.getRank() > 2)
-      return rewriter.notifyMatchFailure(
-          storeScatterOp, "Expected at most 2D result at SG level");
-
-    std::string layoutPayloadName =
-        xegpu::getTemporaryLayoutName(storeScatterOp->getOpOperand(0));
-    std::string layoutOffsetsName =
-        xegpu::getTemporaryLayoutName(storeScatterOp->getOpOperand(2));
-    std::string layoutMaskName =
-        xegpu::getTemporaryLayoutName(storeScatterOp->getOpOperand(3));
-
-    xegpu::DistributeLayoutAttr layoutPayload =
-        storeScatterOp->getAttrOfType<xegpu::DistributeLayoutAttr>(
-            layoutPayloadName);
-    xegpu::DistributeLayoutAttr layoutOffsets =
-        storeScatterOp->getAttrOfType<xegpu::DistributeLayoutAttr>(
-            layoutOffsetsName);
-    xegpu::DistributeLayoutAttr layoutMask =
-        storeScatterOp->getAttrOfType<xegpu::DistributeLayoutAttr>(
-            layoutMaskName);
+
+    // Add handling for leading unit dimensions support
+    int chunkSize = storeScatterOp.getChunkSize().value_or(1);
+    int effectiveVecRank = (chunkSize == 1) ? 1 : 2;
+
+    // Check that all leading dimensions are unit dimensions
+    for (int i = 0; i < storeVecTy.getRank() - effectiveVecRank; i++) {
+      if (storeVecTy.getShape()[i] != 1) {
+        return rewriter.notifyMatchFailure(
+            storeScatterOp, "Only unit dimensions allowed for the leading "
+                            "dimensions of the store vector!");
+      }
+    }
+
+    auto layoutPayload =
+        xegpu::getTemporaryLayout(storeScatterOp->getOpOperand(0));
+    auto layoutOffsets =
+        xegpu::getTemporaryLayout(storeScatterOp->getOpOperand(2));
+    auto layoutMask =
+        xegpu::getTemporaryLayout(storeScatterOp->getOpOperand(3));
 
     FailureOr<VectorType> distStoreVecByWarpOpOrFailure =
         getDistVecTypeBasedOnLaneLayout(layoutPayload, storeVecTy);
@@ -830,29 +836,42 @@ struct StoreDistribution final : public gpu::WarpDistributionPattern {
           storeScatterOp,
           "Some vector operands have no layouts, using defaults instead.");
     }
-    // Distributed store payload type according to the lane layout.
-    VectorType distPayloadTyByWarpOp = distStoreVecByWarpOpOrFailure.value();
-    // Expected distributed payload type is always 1D.
-    VectorType expectedPayloadTy =
-        VectorType::get({distPayloadTyByWarpOp.getNumElements()},
-                        distPayloadTyByWarpOp.getElementType());
+
+    VectorType distPayloadTy = distStoreVecByWarpOpOrFailure.value();
+    VectorType distOffsetsTy = distOffsetsByWarpOpOrFailure.value();
+    VectorType distMaskTy = distMaskByWarpOpOrFailure.value();
 
     SmallVector<size_t> newRetIndices;
     SmallVector<Value> operands = storeScatterOp->getOperands();
     SmallVector<Type> operandTypesToYield = {
-        distPayloadTyByWarpOp, operands[1].getType(),
-        distOffsetsByWarpOpOrFailure.value(),
-        distMaskByWarpOpOrFailure.value()};
+        distPayloadTy, operands[1].getType(), distOffsetsTy, distMaskTy};
 
     gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
         rewriter, warpOp, operands, operandTypesToYield, newRetIndices);
-    SmallVector<Value> newStoreScatterOpOperands = llvm::map_to_vector(
-        newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
-    // The payload operand may need type adjustment due to mismatch between warp
-    // distributed type and expected SIMT type.
+
     rewriter.setInsertionPointAfter(newWarpOp);
-    newStoreScatterOpOperands[0] = resolveDistributedTy(
-        newStoreScatterOpOperands[0], expectedPayloadTy, rewriter);
+
+    // Distributed store payload type is always 1D without leading unit dims
+    VectorType payloadTy1D = VectorType::get({distPayloadTy.getNumElements()},
+                                             distPayloadTy.getElementType());
+
+    VectorType distOffsetsTy1D = VectorType::get(
+        {distOffsetsTy.getNumElements()}, distOffsetsTy.getElementType());
+    VectorType distMaskTy1D = VectorType::get({distMaskTy.getNumElements()},
+                                              distMaskTy.getElementType());
+
+    // Resolve distributed types to 1D for SIMT execution
+    Value distPayloadVal = resolveDistributedTy(
+        newWarpOp.getResult(newRetIndices[0]), payloadTy1D, rewriter);
+    Value distOffsetVal = resolveDistributedTy(
+        newWarpOp.getResult(newRetIndices[2]), distOffsetsTy1D, rewriter);
+    Value distMaskVal = resolveDistributedTy(
+        newWarpOp.getResult(newRetIndices[3]), distMaskTy1D, rewriter);
+
+    SmallVector<Value> newStoreScatterOpOperands = {
+        distPayloadVal, newWarpOp.getResult(newRetIndices[1]), distOffsetVal,
+        distMaskVal};
+
     xegpu::StoreScatterOp newOp = xegpu::StoreScatterOp::create(
         rewriter, newWarpOp.getLoc(), TypeRange{}, newStoreScatterOpOperands,
         storeScatterOp->getAttrs());
@@ -1060,13 +1079,13 @@ struct StoreMatrixDistribution final : public gpu::WarpDistributionPattern {
 ///     memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
 ///
 /// Note that the load distribution pattern also handles leading unit dimensions
-/// in the payload vector to support cases where the load is distributed by
-/// the warp op with unit dimensions added to the front of the vector. In this
-/// case the load distribution will only change the dimensions corresponding to
-/// the SG distribution and keep the leading unit dimensions unchanged. For
-/// example, a load with result type vector<1x16xf16> distributed by the warp op
-/// with layout expecting vector<1x16xf16> will be transformed to have result
-/// type vector<1x1xf16>.
+/// in the payload, mask, and offsets vector.The load distribution will only
+/// change the dimensions corresponding to the SG distribution and keep the
+/// leading unit dimensions unchanged. For example, a load with result type
+/// vector<1x16xf16> with lane layout [1, 16 ] will be distributed
+/// as result type vector<1x1xf16>. Shapecast ops are inserted for the
+/// offset/mask/payload when necessary so that the distributed load is workign
+/// on 1D shape vector to match the HW capability.
 struct LoadDistribution final : public gpu::WarpDistributionPattern {
   using gpu::WarpDistributionPattern::WarpDistributionPattern;
   LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
diff --git a/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir b/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
index 68f6e8e1ec955..e80a9144b9674 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
@@ -740,17 +740,17 @@ gpu.module @test_kernel {
 
 // -----
 gpu.module @test_kernel {
-  // CHECK-LABEL: remove_unit_dim_inst_data
+  // CHECK-LABEL: preserve_unit_dim_of_load_inst_data
   // CHECK-SAME: [[arg0:%.+]]: ui64
   // CHECK: [[cst:%.+]] = arith.constant dense<0.000000e+00> : vector<1x1x32xf32>
-  // CHECK: [[cst_0:%.+]] = arith.constant dense<true> : vector<16xi1>
-  // CHECK: [[cst_1:%.+]] = arith.constant dense<[0, 8, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120]> : vector<16xindex>
-  // CHECK: [[cst_2:%.+]] = arith.constant dense<[128, 136, 144, 152, 160, 168, 176, 184, 192, 200, 208, 216, 224, 232, 240, 248]> : vector<16xindex>
-  // CHECK: [[ld_0:%.+]] = xegpu.load [[arg0]][[[cst_1]]], [[cst_0]] <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}> : ui64, vector<16xindex>, vector<16xi1> -> vector<16xf32>
-  // CHECK: [[ld_1:%.+]] = xegpu.load [[arg0]][[[cst_2]]], [[cst_0]] <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}> : ui64, vector<16xindex>, vector<16xi1> -> vector<16xf32>
-  // CHECK: [[ins_0:%.+]] = vector.insert_strided_slice [[ld_0]], [[cst]] {offsets = [0, 0, 0], strides = [1]} : vector<16xf32> into vector<1x1x32xf32>
-  // CHECK: [[ins_1:%.+]] = vector.insert_strided_slice [[ld_1]], [[ins_0]] {offsets = [0, 0, 16], strides = [1]} : vector<16xf32> into vector<1x1x32xf32>
-  gpu.func @remove_unit_dim_inst_data(%src: ui64) -> vector<1x1x32xf32> {
+  // CHECK: [[cst_0:%.+]] = arith.constant dense<true> : vector<1x1x16xi1>
+  // CHECK: [[cst_1:%.+]] = arith.constant dense<{{.*}}> : vector<1x1x16xindex>
+  // CHECK: [[cst_2:%.+]] = arith.constant dense<{{.*}}> : vector<1x1x16xindex>
+  // CHECK: [[ld_0:%.+]] = xegpu.load [[arg0]][[[cst_1]]], [[cst_0]] <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}> : ui64, vector<1x1x16xindex>, vector<1x1x16xi1> -> vector<1x1x16xf32>
+  // CHECK: [[ld_1:%.+]] = xegpu.load [[arg0]][[[cst_2]]], [[cst_0]] <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}> : ui64, vector<1x1x16xindex>, vector<1x1x16xi1> -> vector<1x1x16xf32>
+  // CHECK: [[ins_0:%.+]] = vector.insert_strided_slice [[ld_0]], [[cst]] {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<1x1x16xf32> into vector<1x1x32xf32>
+  // CHECK: [[ins_1:%.+]] = vector.insert_strided_slice [[ld_1]], [[ins_0]] {offsets = [0, 0, 16], strides = [1, 1, 1]} : vector<1x1x16xf32> into vector<1x1x32xf32>
+  gpu.func @preserve_unit_dim_of_load_inst_data(%src: ui64) -> vector<1x1x32xf32> {
       %cst = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [1, 1, 16]>} dense<[[
       [0,   8,  16,  24,  32,  40,  48,  56,
       64,  72,  80,  88,  96, 104, 112, 120,
@@ -770,8 +770,6 @@ gpu.module @test_kernel {
 gpu.module @test_kernel {
   // CHECK-LABEL: load_store_nd_with_offsets
   // CHECK-SAME: [[arg0:%.+]]: memref<1024x1024xf32>, [[arg1:%.+]]: memref<1024x1024xf32>, [[arg2:%.+]]: memref<1024x1024xf32>
-  // CHECK-DAG: [[cst:%.+]] = arith.constant dense<0.000000e+00> : vector<32xf32>
-  // CHECK-DAG: [[cst_0:%.+]] = arith.constant dense<0.000000e+00> : vector<1x32xf32>
   // CHECK-DAG: [[c16:%.+]] = arith.constant 16 : index
   // CHECK-DAG: [[c0:%.+]] = arith.constant 0 : index
   // CHECK: [[tdesc_a:%.+]] = xegpu.create_nd_tdesc [[arg0]] : memref<1024x1024xf32> -> !xegpu.tensor_desc<1x16xf32>
@@ -779,27 +777,12 @@ gpu.module @test_kernel {
   // CHECK: [[tdesc_c:%.+]] = xegpu.create_nd_tdesc [[arg2]] : memref<1024x1024xf32> -> !xegpu.tensor_desc<1x16xf32>
   // CHECK: [[ld_a0:%.+]] = xegpu.load_nd [[tdesc_a]][[[c0]], [[c0]]]  : !xegpu.tensor_desc<1x16xf32> -> vector<1x16xf32>
   // CHECK: [[ld_a1:%.+]] = xegpu.load_nd [[tdesc_a]][[[c0]], [[c16]]]  : !xegpu.tensor_desc<1x16xf32> -> vector<1x16xf32>
-  // CHECK: [[ins_a0:%.+]] = vector.insert_strided_slice [[ld_a0]], [[cst_0]] {offsets = [0, 0], strides = [1, 1]} : vector<1x16xf32> into vector<1x32xf32>
-  // CHECK: [[ins_a1:%.+]] = vector.insert_strided_slice [[ld_a1]], [[ins_a0]] {offsets = [0, 16], strides = [1, 1]} : vector<1x16xf32> into vector<1x32xf32>
   // CHECK: [[ld_b0:%.+]] = xegpu.load_nd [[tdesc_b]][[[c0]], [[c0]]]  : !xegpu.tensor_desc<1x16xf32> -> vector<1x16xf32>
   // CHECK: [[ld_b1:%.+]] = xegpu.load_nd [[tdesc_b]][[[c0]], [[c16]]]  : !xegpu.tensor_desc<1x16xf32> -> vector<1x16xf32>
-  // CHECK: [[ins_b0:%.+]] = vector.insert_strided_slice [[ld_b0]], [[cst_0]] {offsets = [0, 0], strides = [1, 1]} : vector<1x16xf32> into vector<1x32xf32>
-  // CHECK: [[ins_b1:%.+]] = vector.insert_strided_slice [[ld_b1]], [[ins_b0]] {offsets = [0, 16], strides = [1, 1]} : vector<1x16xf32> into vector<1x32xf32>
-  // CHECK: [[ext_a:%.+]] = vector.extract [[ins_a1]][0] : vector<32xf32> from vector<1x32xf32>
-  // CHECK: [[ext_b:%.+]] = vector.extract [[ins_b1]][0] : vector<32xf32> from vector<1x32xf32>
-  // CHECK: [[slice_a0:%.+]] = vector.extract_strided_slice [[ext_a]] {offsets = [0], sizes = [16], strides = [1]} : vector<32xf32> to vector<16xf32>
-  // CHECK: [[slice_b0:%.+]] = vector.extract_strided_slice [[ext_b]] {offsets = [0], sizes = [16], strides = [1]} : vector<32xf32> to vector<16xf32>
-  // CHECK: [[add0:%.+]] = arith.addf [[slice_a0]], [[slice_b0]] : vector<16xf32>
-  // CHECK: [[ins_add0:%.+]] = vector.insert_strided_slice [[add0]], [[cst]] {offsets = [0], strides = [1]} : vector<16xf32> into vector<32xf32>
-  // CHECK: [[slice_a1:%.+]] = vector.extract_strided_slice [[ext_a]] {offsets = [16], sizes = [16], strides = [1]} : vector<32xf32> to vector<16xf32>
-  // CHECK: [[slice_b1:%.+]] = vector.extract_strided_slice [[ext_b]] {offsets = [16], sizes = [16], strides = [1]} : vector<32xf32> to vector<16xf32>
-  // CHECK: [[add1:%.+]] = arith.addf [[slice_a1]], [[slice_b1]] : vector<16xf32>
-  // CHECK: [[ins_add1:%.+]] = vector.insert_strided_slice [[add1]], [[ins_add0]] {offsets = [16], strides = [1]} : vector<16xf32> into vector<32xf32>
-  // CHECK: [[broadcast:%.+]] = vector.broadcast [[ins_add1]] : vector<32xf32> to vector<1x32xf32>
-  // CHECK: [[ext_result0:%.+]] = vector.extract_strided_slice [[broadcast]] {offsets = [0, 0], sizes = [1, 16], strides = [1, 1]} : vector<1x32xf32> to vector<1x16xf32>
-  // CHECK: [[ext_result1:%.+]] = vector.extract_strided_slice [[broadcast]] {offsets = [0, 16], sizes = [1, 16], strides = [1, 1]} : vector<1x32xf32> to vector<1x16xf32>
-  // CHECK: xegpu.store_nd [[ext_result0]], [[tdesc_c]][[[c0]], [[c0]]]  : vector<1x16xf32>, !xegpu.tensor_desc<1x16xf32>
-  // CHECK: xegpu.store_nd [[ext_result1]], [[tdesc_c]][[[c0]], [[c16]]]  : vector<1x16xf32>, !xegpu.tensor_desc<1x16xf32>
+  // CHECK: [[add0:%.+]] = arith.addf [[ld_a0]], [[ld_b0]] : vector<1x16xf32>
+  // CHECK: [[add1:%.+]] = arith.addf [[ld_a1]], [[ld_b1]] : vector<1x16xf32>
+  // CHECK: xegpu.store_nd [[add0]], [[tdesc_c]][[[c0]], [[c0]]]  : vector<1x16xf32>, !xegpu.tensor_desc<1x16xf32>
+  // CHECK: xegpu.store_nd [[add1]], [[tdesc_c]][[[c0]], [[c16]]]  : vector<1x16xf32>, !xegpu.tensor_desc<1x16xf32>
   gpu.func @load_store_nd_with_offsets(%A: memref<1024x1024xf32>, %B: memref<1024x1024xf32>, %C: memref<1024x1024xf32>) {
     %c0 = arith.constant 0 : index
 
@@ -817,16 +800,27 @@ gpu.module @test_kernel {
 }
 
 // -----
-#inst_data = #xegpu.layout<inst_data = [1, 1, 32]>
+#inst_data = #xegpu.layout<inst_data = [1, 1, 16]>
 gpu.module @test_kernel {
   // CHECK-LABEL: load_add_store_leading_unit_dims
   // CHECK-SAME: [[arg0:%.+]]: ui64, [[arg1:%.+]]: ui64, [[arg2:%.+]]: ui64
-  // CHECK: [[mask:%.+]] = arith.constant dense<true> : vector<32xi1>
-  // CHECK: [[offsets:%.+]] = arith.constant dense<[0, 8, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120, 128, 136, 144, 152, 160, 168, 176, 184, 192, 200, 208, 216, 224, 232, 240, 248]> : vector<32xindex>
-  // CHECK: [[a:%.+]] = xegpu.load [[arg0]][[[offsets]]], [[mask]] <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}> : ui64, vector<32xindex>, vector<32xi1> -> vector<32xf32>
-  // CHECK: [[b:%.+]] = xegpu.load [[arg1]][[[offsets]]], [[mask]] <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}> : ui64, vector<32xindex>, vector<32xi1> -> vector<32xf32>
-  // CHECK: [[add:%.+]] = arith.addf [[a]], [[b]] : vector<32xf32>
-  // CHECK: xegpu.store [[add]], [[arg2]][[[offsets]]], [[mask]] <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}> : vector<32xf32>, ui64, vector<32xindex>, vector<32xi1>
+    // CHECK: [[c0:%.+]] = arith.constant dense<true> : vector<1x1x16xi1>
+    // CHECK: [[c1:%.+]] = arith.constant dense<[{{\[\[}}0, 8, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120]]]> : vector<1x1x16xindex>
+    // CHECK: [[c2:%.+]] = arith.constant dense<[{{\[\[}}128, 136, 144, 152, 160, 168, 176, 184, 192, 200, 208, 216, 224, 232, 240, 248]]]> : vector<1x1x16xindex>
+    // CHECK: [[v0:%.+]] = xegpu.load [[arg0]]{{\[}}[[c1]]], [[c0]]
+    // CHECK-SAME: ui64, vector<1x1x16xindex>, vector<1x1x16xi1> -> vector<1x1x16xf32>
+    // CHECK: [[v1:%.+]] = xegpu.load [[arg0]]{{\[}}[[c2]]], [[c0]]
+    // CHECK-SAME: ui64, vector<1x1x16xindex>, vector<1x1x16xi1> -> vector<1x1x16xf32>
+    // CHECK: [[v2:%.+]] = xegpu.load [[arg1]]{{\[}}[[c1]]], [[c0]]
+    // CHECK-SAME: ui64, vector<1x1x16xindex>, vector<1x1x16xi1> -> vector<1x1x16xf32>
+    // CHECK: [[v3:%.+]] = xegpu.load [[arg1]]{{\[}}[[c2]]], [[c0]]
+    // CHECK-SAME: ui64, vector<1x1x16xindex>, vector<1x1x16xi1> -> vector<1x1x16xf32>
+    // CHECK: [[v4:%.+]] = arith.addf [[v0]], [[v2]] : vector<1x1x16xf32>
+    // CHECK: [[v5:%.+]] = arith.addf [[v1]], [[v3]] : vector<1x1x16xf32>
+    // CHECK: xegpu.store [[v4]], [[arg2]]{{\[}}[[c1]]], [[c0]]
+    // CHECK-SAME: vector<1x1x16xf32>, ui64, vector<1x1x16xindex>, vector<1x1x16xi1>
+    // CHECK: xegpu.store [[v5]], [[arg2]]{{\[}}[[c2]]], [[c0]]
+    // CHECK-SAME: vector<1x1x16xf32>, ui64, vector<1x1x16xindex>, vector<1x1x16xi1>
   gpu.func @load_add_store_leading_unit_dims(%A: ui64, %B: ui64, %C: ui64) {
     %cst = arith.constant {layout_result_0 = #inst_data} dense<[
       [[0, 8, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120,

>From bd81a4bfc1da3ca412ee59ed0f6c6c1b8988c9ec Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Wed, 11 Feb 2026 04:45:10 +0000
Subject: [PATCH 4/4] clean up

---
 .../lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp |  7 -------
 mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp      | 16 +---------------
 2 files changed, 1 insertion(+), 22 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index 792fe12b88397..2b1bd4d73a576 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -728,13 +728,6 @@ struct UnrollStoreScatterOpWithOffsets
       return failure();
 
     std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
-    //print target shape for debugging  
-    llvm::errs() << "Target shape: ";
-    if (targetShape) {
-      for (auto dim : *targetShape)
-        llvm::errs() << dim << " ";
-    }
-    llvm::errs() << "\n";
     if (!targetShape)
       return failure();
 
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index fb5dff8981c9d..fa5810ad7f828 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -390,26 +390,12 @@ xegpu::extractVectorsWithShapeFromValue(OpBuilder &builder, Location loc,
   for (SmallVector<int64_t> offsets :
        StaticTileOffsetRange(srcShape, adjustedTargetShape)) {
     SmallVector<int64_t> staticStrides(offsets.size(), 1);
-    
-    // Debug print
-    llvm::errs() << "Extracting slice with offsets: [";
-    for (size_t i = 0; i < offsets.size(); ++i) {
-      llvm::errs() << offsets[i];
-      if (i + 1 < offsets.size()) llvm::errs() << ", ";
-    }
-    llvm::errs() << "], shape: [";
-    for (size_t i = 0; i < adjustedTargetShape.size(); ++i) {
-      llvm::errs() << adjustedTargetShape[i];
-      if (i + 1 < adjustedTargetShape.size()) llvm::errs() << ", ";
-    }
-    llvm::errs() << "]\n";
-    
+
     Value slice = vector::ExtractStridedSliceOp::create(
         builder, loc, value, offsets, adjustedTargetShape, staticStrides);
 
     // Reshape to remove leading unit dims if needed
     if (srcShapeRank > targetShapeRank) {
-      llvm::errs() << "Reshaping from rank " << srcShapeRank << " to rank " << targetShapeRank << "\n";
       auto targetTy = VectorType::get(shape, vecTy.getElementType());
       slice = vector::ShapeCastOp::create(builder, loc, targetTy, slice);
     }



More information about the Mlir-commits mailing list