[Mlir-commits] [mlir] [mlir][xegpu] Add layout based SIMT distribution support for `vector.extract/insert_strided_slice` (PR #168626)

Charitha Saumya llvmlistbot at llvm.org
Mon Nov 24 14:46:04 PST 2025


https://github.com/charithaintc updated https://github.com/llvm/llvm-project/pull/168626

>From 4bef60fe9e55eba963179888e28c77568aeccf7e Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Mon, 17 Nov 2025 19:31:51 +0000
Subject: [PATCH 01/14] save work

---
 .../Transforms/XeGPUSubgroupDistribute.cpp    | 34 +++++++++++++++++++
 1 file changed, 34 insertions(+)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index bbd7733e89c29..a125ed18119be 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -1472,6 +1472,40 @@ struct VectorShapeCastDistribution : public gpu::WarpDistributionPattern {
   }
 };
 
+static SmallVector<int64_t> getDistributedDims(VectorType sequentialType,
+                                               VectorType distributedType) {
+  assert(sequentialType.getRank() == distributedType.getRank() &&
+         "sequential and distributed vector types must have the same rank");
+  SmallVector<int64_t> distributedDims;
+  for (int64_t i = 0; i < sequentialType.getRank(); ++i) {
+    if (distributedType.getDimSize(i) != sequentialType.getDimSize(i)) {
+      distributedDims.push_back(i);
+    }
+  }
+  return distributedDims;
+}
+
+struct VectorExtractStridedSliceDistribution
+    : public gpu::WarpDistributionPattern {
+  using gpu::WarpDistributionPattern::WarpDistributionPattern;
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *operand =
+        getWarpResult(warpOp, llvm::IsaPred<vector::ExtractStridedSliceOp>);
+    if (!operand)
+      return failure();
+    auto extractOp =
+        cast<vector::ExtractStridedSliceOp>(operand->get().getDefiningOp());
+    unsigned operandIdx = operand->getOperandNumber();
+    auto distributedType =
+        cast<VectorType>(warpOp.getResult(operandIdx).getType());
+    // Find the distributed dimension. There should be exactly one.
+    auto yieldedType = cast<VectorType>(operand->get().getType());
+    auto distributedDims = getDistributedDims(yieldedType, distributedType);
+    return success();
+  }
+};
+
 /// Sink a memref::ExtractAlignedPointerAsIndex op feeding into yield op of an
 /// enclosing `gpu.warp_execute_on_lane_0` region. This will simply move the op
 /// outside of the warp op.

>From a261edcc92360fd7d6679bffe6923cdf1b271a11 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Mon, 17 Nov 2025 23:41:52 +0000
Subject: [PATCH 02/14] save work

---
 .../Transforms/XeGPUSubgroupDistribute.cpp    | 124 +++++++++++++++---
 1 file changed, 108 insertions(+), 16 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 9865f31e2cbcd..e4f17f0abdc6b 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -35,6 +35,7 @@
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/LogicalResult.h"
 
 namespace mlir {
 namespace xegpu {
@@ -174,6 +175,19 @@ static bool requireTranspose(const xegpu::LayoutAttr layout,
   return laneLayout[0] == uArch->getSubgroupSize() && laneLayout[1] == 1;
 }
 
+static SmallVector<int64_t> getDistributedDims(VectorType sequentialType,
+                                               VectorType distributedType) {
+  assert(sequentialType.getRank() == distributedType.getRank() &&
+         "sequential and distributed vector types must have the same rank");
+  SmallVector<int64_t> distributedDims;
+  for (int64_t i = 0; i < sequentialType.getRank(); ++i) {
+    if (distributedType.getDimSize(i) != sequentialType.getDimSize(i)) {
+      distributedDims.push_back(i);
+    }
+  }
+  return distributedDims;
+}
+
 /// Given a GPUFuncOp, this pattern creates a new GPUFuncOp and moves the body
 /// of the original GPUFuncOp to the new GPUFuncOp such that entire body is
 /// contained within a WarpExecuteOnLane0Op.
@@ -1471,19 +1485,6 @@ struct VectorShapeCastDistribution : public gpu::WarpDistributionPattern {
   }
 };
 
-static SmallVector<int64_t> getDistributedDims(VectorType sequentialType,
-                                               VectorType distributedType) {
-  assert(sequentialType.getRank() == distributedType.getRank() &&
-         "sequential and distributed vector types must have the same rank");
-  SmallVector<int64_t> distributedDims;
-  for (int64_t i = 0; i < sequentialType.getRank(); ++i) {
-    if (distributedType.getDimSize(i) != sequentialType.getDimSize(i)) {
-      distributedDims.push_back(i);
-    }
-  }
-  return distributedDims;
-}
-
 struct VectorExtractStridedSliceDistribution
     : public gpu::WarpDistributionPattern {
   using gpu::WarpDistributionPattern::WarpDistributionPattern;
@@ -1501,6 +1502,96 @@ struct VectorExtractStridedSliceDistribution
     // Find the distributed dimension. There should be exactly one.
     auto yieldedType = cast<VectorType>(operand->get().getType());
     auto distributedDims = getDistributedDims(yieldedType, distributedType);
+    // Only single dimension distribution is supported.
+    if (distributedDims.size() != 1)
+      return rewriter.notifyMatchFailure(
+          warpOp, "Expecting source to be distributed in a single dimension.");
+    int64_t distributedDim = distributedDims[0];
+    // Check if the distributed dimension is fully extracted. If so, we exit
+    // early becuase this case already handled by vector distribution patterns.
+    // Distributed dimension is fully extracted if:
+    //  1) Distributed dim comes after all the extracted dimensions.
+    //  2) Or, the size extacted along the distributed dimension is equal the
+    //  size of that dim in source vector.
+    auto extractedSizes = extractOp.getSizes();
+    if (distributedDim >= static_cast<int64_t>(extractedSizes.size()))
+      return rewriter.notifyMatchFailure(
+          warpOp, "Distributed dimension is fully extracted, skipping.");
+
+    int distrDimExtractedSize =
+        cast<IntegerAttr>(extractOp.getSizes()[distributedDim]).getInt();
+    if (distrDimExtractedSize ==
+        extractOp.getSourceVectorType().getShape()[distributedDim])
+      return rewriter.notifyMatchFailure(
+          warpOp, "Distributed dimension is fully extracted, skipping.");
+
+    // Check if the size extracted along the distributed dimension is a multiple
+    // of the source dim size and should be distributable to lanes.
+    int64_t sourceDisrDimSize = yieldedType.getShape()[distributedDim];
+    if (sourceDisrDimSize % distrDimExtractedSize != 0)
+      return rewriter.notifyMatchFailure(
+          warpOp,
+          "Extracted size along distributed dimension is not a multiple of "
+          "source dim size.");
+    auto sourceLayout =
+        xegpu::getDistributeLayoutAttr(extractOp->getOpOperand(0));
+    auto sourceLaneLayout = sourceLayout.getEffectiveLaneLayoutAsInt();
+    // Because only single dimension distribution is supported, lane layout size
+    // at the distributed dim must be the subgroup size.
+    int subgroupSize = sourceLaneLayout[distributedDim];
+    // Check if the distributed extracted dim is a multiple of the lane size.
+    if (distrDimExtractedSize % subgroupSize != 0)
+      return rewriter.notifyMatchFailure(
+          warpOp,
+          "Extracted size along distributed dimension is not a multiple of "
+          "lane size in source layout.");
+    auto sourceLaneData = sourceLayout.getEffectiveLaneDataAsInt();
+    // We expect lane data to be all ones in this case.
+    if (!llvm::all_of(sourceLaneData, [](int64_t v) { return v == 1; }))
+      return rewriter.notifyMatchFailure(
+          warpOp, "Expecting unit lane data in source layout");
+    // The offsets in the distributed dimention must be a multiple of subgroup
+    // size.
+    int64_t distrDimOffset =
+        cast<IntegerAttr>(extractOp.getOffsets()[distributedDim]).getInt();
+    if (distrDimOffset % subgroupSize != 0)
+      return rewriter.notifyMatchFailure(warpOp,
+                                         "Offset along distributed dimension "
+                                         "is not a multiple of subgroup size.");
+    // Do the distribution by yielding the source of the extract op from
+    // the warp op and creating a new extract op outside the warp op.
+    FailureOr<VectorType> sourceDistTypeOrFailure =
+        getDistVecTypeBasedOnLaneLayout(sourceLayout,
+                                        extractOp.getSourceVectorType());
+    if (failed(sourceDistTypeOrFailure))
+      return rewriter.notifyMatchFailure(
+          warpOp, "failed to get distributed vector type for source");
+    VectorType sourceDistType = sourceDistTypeOrFailure.value();
+    // Create a new warp op that yields the source of the extract op.
+    SmallVector<size_t> newRetIndices;
+    auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, {extractOp.getSource()}, {sourceDistType},
+        newRetIndices);
+    rewriter.setInsertionPointAfter(newWarpOp);
+    // Distributed sizes and offsets must be adjusted.
+    SmallVector<Attribute> distributedSizes = llvm::map_to_vector(
+        extractOp.getSizes(), [](Attribute attr) { return attr; });
+    SmallVector<Attribute> distributedOffsets = llvm::map_to_vector(
+        extractOp.getOffsets(), [](Attribute attr) { return attr; });
+    // Update the distributed sizes to match the distributed type.
+    distributedSizes[distributedDim] =
+        rewriter.getI64IntegerAttr(distributedType.getDimSize(distributedDim));
+    // Update the distributed offsets to match round robin distribution.
+    distributedOffsets[distributedDim] = rewriter.getI64IntegerAttr(
+        distrDimOffset / subgroupSize); // because lane data is 1
+    Value source = newWarpOp.getResult(newRetIndices[0]);
+    // Create a new extract op outside the warp op.
+    Value newExtractOp = vector::ExtractStridedSliceOp::create(
+        rewriter, extractOp.getLoc(), distributedType, source,
+        ArrayAttr::get(rewriter.getContext(), distributedOffsets),
+        ArrayAttr::get(rewriter.getContext(), distributedSizes),
+        extractOp.getStrides());
+    rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIdx), newExtractOp);
     return success();
   }
 };
@@ -1662,9 +1753,10 @@ void xegpu::populateXeGPUSubgroupDistributePatterns(
                MemrefExtractAlignedPointerAsIndexDistribution>(
       patterns.getContext(),
       /*pattern benefit=*/regularPatternBenefit);
-  patterns.add<VectorShapeCastDistribution>(
-      patterns.getContext(),
-      /*pattern benefit=*/highPatternBenefit);
+  patterns
+      .add<VectorShapeCastDistribution, VectorExtractStridedSliceDistribution>(
+          patterns.getContext(),
+          /*pattern benefit=*/highPatternBenefit);
 }
 
 void xegpu::populateXeGPUMoveFuncBodyToWarpOpPatterns(

>From 59f90b4bedf8684948c2851a199bec95b1aaecd1 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Tue, 18 Nov 2025 19:30:51 +0000
Subject: [PATCH 03/14] save work

---
 .../Transforms/XeGPUSubgroupDistribute.cpp    | 143 ++++++++++++++++--
 1 file changed, 131 insertions(+), 12 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index e4f17f0abdc6b..a3e7f8469cd93 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -1535,16 +1535,21 @@ struct VectorExtractStridedSliceDistribution
           "source dim size.");
     auto sourceLayout =
         xegpu::getDistributeLayoutAttr(extractOp->getOpOperand(0));
+    if (!sourceLayout || sourceLayout.getEffectiveLaneLayoutAsInt().empty())
+      return rewriter.notifyMatchFailure(
+          warpOp, "the source of extract_strided_slice op lacks distribution "
+                  "layout");
     auto sourceLaneLayout = sourceLayout.getEffectiveLaneLayoutAsInt();
     // Because only single dimension distribution is supported, lane layout size
     // at the distributed dim must be the subgroup size.
     int subgroupSize = sourceLaneLayout[distributedDim];
-    // Check if the distributed extracted dim is a multiple of the lane size.
+    // Check if the distributed extracted dim is a multiple of the subgroup
+    // size.
     if (distrDimExtractedSize % subgroupSize != 0)
       return rewriter.notifyMatchFailure(
           warpOp,
           "Extracted size along distributed dimension is not a multiple of "
-          "lane size in source layout.");
+          "subgroup size in source layout.");
     auto sourceLaneData = sourceLayout.getEffectiveLaneDataAsInt();
     // We expect lane data to be all ones in this case.
     if (!llvm::all_of(sourceLaneData, [](int64_t v) { return v == 1; }))
@@ -1560,13 +1565,10 @@ struct VectorExtractStridedSliceDistribution
                                          "is not a multiple of subgroup size.");
     // Do the distribution by yielding the source of the extract op from
     // the warp op and creating a new extract op outside the warp op.
-    FailureOr<VectorType> sourceDistTypeOrFailure =
+    VectorType sourceDistType =
         getDistVecTypeBasedOnLaneLayout(sourceLayout,
-                                        extractOp.getSourceVectorType());
-    if (failed(sourceDistTypeOrFailure))
-      return rewriter.notifyMatchFailure(
-          warpOp, "failed to get distributed vector type for source");
-    VectorType sourceDistType = sourceDistTypeOrFailure.value();
+                                        extractOp.getSourceVectorType())
+            .value();
     // Create a new warp op that yields the source of the extract op.
     SmallVector<size_t> newRetIndices;
     auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
@@ -1581,9 +1583,10 @@ struct VectorExtractStridedSliceDistribution
     // Update the distributed sizes to match the distributed type.
     distributedSizes[distributedDim] =
         rewriter.getI64IntegerAttr(distributedType.getDimSize(distributedDim));
-    // Update the distributed offsets to match round robin distribution.
-    distributedOffsets[distributedDim] = rewriter.getI64IntegerAttr(
-        distrDimOffset / subgroupSize); // because lane data is 1
+    // Update the distributed offsets to match round robin distribution (i.e.
+    // each lane owns data at `subgroupSize` stride given unit lane data).
+    distributedOffsets[distributedDim] =
+        rewriter.getI64IntegerAttr(distrDimOffset / subgroupSize);
     Value source = newWarpOp.getResult(newRetIndices[0]);
     // Create a new extract op outside the warp op.
     Value newExtractOp = vector::ExtractStridedSliceOp::create(
@@ -1596,6 +1599,121 @@ struct VectorExtractStridedSliceDistribution
   }
 };
 
+struct VectorInsertStridedSliceDistribution
+    : public gpu::WarpDistributionPattern {
+  using gpu::WarpDistributionPattern::WarpDistributionPattern;
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *operand =
+        getWarpResult(warpOp, llvm::IsaPred<vector::InsertStridedSliceOp>);
+    if (!operand)
+      return failure();
+    unsigned int operandNumber = operand->getOperandNumber();
+    auto insertOp =
+        operand->get().getDefiningOp<vector::InsertStridedSliceOp>();
+    auto distributedType =
+        cast<VectorType>(warpOp.getResult(operandNumber).getType());
+    // Find the distributed dimension of the dest vector. There should be
+    // exactly one.
+    auto yieldedType = cast<VectorType>(operand->get().getType());
+    auto destDistributedDims = getDistributedDims(yieldedType, distributedType);
+    // Only single dimension distribution is supported.
+    if (destDistributedDims.size() != 1)
+      return rewriter.notifyMatchFailure(
+          warpOp, "Expecting source to be distributed in a single dimension.");
+    int64_t destDistributedDim = destDistributedDims[0];
+
+    VectorType srcType = insertOp.getSourceVectorType();
+    VectorType destType = insertOp.getDestVectorType();
+    // Currently we require that both source (kD) and dest (nD) vectors are
+    // distributed. This requires that distributedDim (d) is contained in the
+    // last k dims of the dest vector (d >= n - k).
+    int64_t sourceDistributedDim =
+        destDistributedDim - (destType.getRank() - srcType.getRank());
+    if (sourceDistributedDim < 0)
+      return rewriter.notifyMatchFailure(
+          insertOp, "distributed dimension must be in the last k (i.e. source "
+                    "rank) dims of dest vector");
+    // If the distributed dimension is fully inserted, skip. This case is
+    // already handled by vector distribution patterns.
+    int64_t destDistrDimSize = destType.getDimSize(destDistributedDim);
+    int64_t srcDistrDimSize = srcType.getDimSize(sourceDistributedDim);
+    if (srcDistrDimSize == destDistrDimSize)
+      return rewriter.notifyMatchFailure(
+          insertOp, "distributed dimension is fully inserted. This case "
+                    "is handled by vector distribution.");
+    // Obtain the source and dest layouts.
+    auto destLayout = xegpu::getDistributeLayoutAttr(insertOp->getOpOperand(1));
+    auto sourceLayout =
+        xegpu::getDistributeLayoutAttr(insertOp->getOpOperand(0));
+    if (!destLayout || !sourceLayout ||
+        destLayout.getEffectiveLaneLayoutAsInt().empty() ||
+        sourceLayout.getEffectiveLaneLayoutAsInt().empty())
+      return rewriter.notifyMatchFailure(
+          warpOp, "the source or dest of insert_strided_slice op lacks "
+                  "distribution layout");
+    // Because only single dimension distribution is supported, lane layout
+    // size at the distributed dim must be the subgroup size.
+    int subgroupSize =
+        destLayout.getEffectiveLaneLayoutAsInt()[destDistributedDim];
+    // We require that source and dest lane data are all ones to ensure uniform
+    // round robin distribution.
+    auto destLaneData = destLayout.getEffectiveLaneDataAsInt();
+    auto sourceLaneData = sourceLayout.getEffectiveLaneDataAsInt();
+    if (!llvm::all_of(destLaneData, [](int64_t v) { return v == 1; }) ||
+        !llvm::all_of(sourceLaneData, [](int64_t v) { return v == 1; }))
+      return rewriter.notifyMatchFailure(
+          warpOp, "Expecting unit lane data in source and dest layouts");
+    // Distributed dim sizes must be multiples of subgroup size.
+    if (destDistrDimSize % subgroupSize != 0 ||
+        srcDistrDimSize % subgroupSize != 0)
+      return rewriter.notifyMatchFailure(
+          warpOp,
+          "Distributed dimension size in source or dest is not a multiple of "
+          "subgroup size.");
+    // Offsets in the distributed dimension must be multiples of subgroup size.
+    int64_t destDistrDimOffset =
+        cast<IntegerAttr>(insertOp.getOffsets()[destDistributedDim]).getInt();
+    if (destDistrDimOffset % subgroupSize != 0)
+      return rewriter.notifyMatchFailure(
+          warpOp,
+          "Offset along distributed dimension in dest is not a multiple of "
+          "subgroup size.");
+    // Do the distribution by yielding the source and dest of the insert op from
+    // the warp op and creating a new insert op outside the warp op.
+    VectorType sourceDistType =
+        getDistVecTypeBasedOnLaneLayout(sourceLayout,
+                                        insertOp.getSourceVectorType())
+            .value();
+    VectorType destDistType = getDistVecTypeBasedOnLaneLayout(
+                                  destLayout, insertOp.getDestVectorType())
+                                  .value();
+    // Create a new warp op that yields the source and dest of the insert op.
+    SmallVector<size_t> newRetIndices;
+    auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
+        {sourceDistType, destDistType}, newRetIndices);
+    rewriter.setInsertionPointAfter(newWarpOp);
+    // Distributed offsets must be adjusted.
+    SmallVector<Attribute> distributedOffsets = llvm::map_to_vector(
+        insertOp.getOffsets(), [](Attribute attr) { return attr; });
+    // Update the distributed offsets to match round robin distribution (i.e.
+    // each lane owns data at `subgroupSize` stride given unit lane data).
+    distributedOffsets[destDistributedDim] =
+        rewriter.getI64IntegerAttr(destDistrDimOffset / subgroupSize);
+    Value valueToStore = newWarpOp.getResult(newRetIndices[0]);
+    Value dest = newWarpOp.getResult(newRetIndices[1]);
+    // Create a new insert op outside the warp op.
+    Value newInsertOp = vector::InsertStridedSliceOp::create(
+        rewriter, insertOp.getLoc(), destDistType, valueToStore, dest,
+        ArrayAttr::get(rewriter.getContext(), distributedOffsets),
+        insertOp.getStrides());
+    rewriter.replaceAllUsesWith(newWarpOp.getResult(operandNumber),
+                                newInsertOp);
+    return success();
+  }
+};
+
 /// Sink a memref::ExtractAlignedPointerAsIndex op feeding into yield op of an
 /// enclosing `gpu.warp_execute_on_lane_0` region. This will simply move the op
 /// outside of the warp op.
@@ -1754,7 +1872,8 @@ void xegpu::populateXeGPUSubgroupDistributePatterns(
       patterns.getContext(),
       /*pattern benefit=*/regularPatternBenefit);
   patterns
-      .add<VectorShapeCastDistribution, VectorExtractStridedSliceDistribution>(
+      .add<VectorShapeCastDistribution, VectorExtractStridedSliceDistribution,
+           VectorInsertStridedSliceDistribution>(
           patterns.getContext(),
           /*pattern benefit=*/highPatternBenefit);
 }

>From f748b80547e00fbaf60688b5696a8f40cd29cbe5 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Tue, 18 Nov 2025 22:34:11 +0000
Subject: [PATCH 04/14] save work

---
 .../XeGPU/subgroup-distribute-unit.mlir       | 674 ++++++++++--------
 1 file changed, 386 insertions(+), 288 deletions(-)

diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
index f233dff609f2b..d8ed46646810d 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
@@ -1,6 +1,6 @@
-// RUN: mlir-opt --xevm-attach-target='module=xevm_* chip=pvc' -test-xegpu-sg-distribute -allow-unregistered-dialect \
-// RUN: -canonicalize -cse -split-input-file %s | FileCheck %s
-
+// RUN: mlir-opt --xevm-attach-target='module=xevm_* chip=pvc' -test-xegpu-sg-distribute  \
+// RUN: -allow-unregistered-dialect -canonicalize -cse  %s | FileCheck %s
+gpu.module @xevm_module{
 // CHECK-LABEL: gpu.func @store_nd_1d
 // CHECK:         (%[[ARG0:[0-9a-zA-Z]+]]: index) {
 // CHECK:         %[[W:.*]]:3 = gpu.warp_execute_on_lane_0(%[[ARG0]])[16]
@@ -11,20 +11,17 @@
 // CHECK-NEXT:    %[[T1:.*]] = builtin.unrealized_conversion_cast %[[W]]#1 : !xegpu.tensor_desc<16xf32,
 // CHECK-SAME:      #xegpu.layout<lane_layout = [16], lane_data = [1]>> to !xegpu.tensor_desc<16xf32> {resolve_simt_type_mismatch}
 // CHECK-NEXT:    xegpu.store_nd %[[W]]#0, %[[T1]][%[[W]]#2]  : vector<1xf32>, !xegpu.tensor_desc<16xf32>
-gpu.module @xevm_module{
-  gpu.func @store_nd_1d(%laneid: index) {
-    %c0 = arith.constant 0 : index
-    gpu.warp_execute_on_lane_0(%laneid)[16] {
-      %0 = "some_op"() : () -> !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
-      %cst = "some_op"() : () -> vector<16xf32>
-      xegpu.store_nd %cst, %0 [%c0] {layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
-        : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
-    }
-    gpu.return
+gpu.func @store_nd_1d(%laneid: index) {
+  %c0 = arith.constant 0 : index
+  gpu.warp_execute_on_lane_0(%laneid)[16] {
+    %0 = "some_op"() : () -> !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
+    %cst = "some_op"() : () -> vector<16xf32>
+    xegpu.store_nd %cst, %0 [%c0] {layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
+      : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
   }
+  gpu.return
 }
 
-// -----
 // CHECK-LABEL: gpu.func @store_nd_2d
 // CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: index) {
 // CHECK:       %[[W:.*]]:4 = gpu.warp_execute_on_lane_0(%[[ARG0]])[16]
@@ -37,22 +34,18 @@ gpu.module @xevm_module{
 // CHECK-NEXT:  %[[T1:.*]] = builtin.unrealized_conversion_cast %[[W]]#1 : !xegpu.tensor_desc<16x16xf16,
 // CHECK-SAME:    #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> to !xegpu.tensor_desc<16x16xf16> {resolve_simt_type_mismatch}
 // CHECK-NEXT:  xegpu.store_nd %[[CAST]], %[[T1]][%[[W]]#2, %[[W]]#3]  : vector<16xf16>, !xegpu.tensor_desc<16x16xf16>
-gpu.module @xevm_module{
-  gpu.func @store_nd_2d(%laneid : index) {
-    %c0 = arith.constant 0 : index
-    gpu.warp_execute_on_lane_0(%laneid)[16] {
-      %0 = "some_op"() : () -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
-      %cst = "some_op"() : () -> vector<16x16xf16>
-      xegpu.store_nd %cst, %0 [%c0, %c0] {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
-        : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
-    }
-    gpu.return
+gpu.func @store_nd_2d(%laneid : index) {
+  %c0 = arith.constant 0 : index
+  gpu.warp_execute_on_lane_0(%laneid)[16] {
+    %0 = "some_op"() : () -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    %cst = "some_op"() : () -> vector<16x16xf16>
+    xegpu.store_nd %cst, %0 [%c0, %c0] {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+      : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
   }
+  gpu.return
 }
 
 
-
-// -----
 // CHECK-LABEL: gpu.func @load_nd_1d
 // CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: index) {
 // CHECK:       %[[W:.*]]:3 = gpu.warp_execute_on_lane_0(%[[ARG0]])[16] -> (vector<1xf32>,
@@ -63,21 +56,19 @@ gpu.module @xevm_module{
 // CHECK-NEXT:  %[[T1:.*]] = builtin.unrealized_conversion_cast %[[W]]#1 : !xegpu.tensor_desc<16xf32,
 // CHECK-SAME:    #xegpu.layout<lane_layout = [16], lane_data = [1]>> to !xegpu.tensor_desc<16xf32> {resolve_simt_type_mismatch}
 // CHECK-NEXT:  xegpu.load_nd %[[T1]][%[[W]]#2]  : !xegpu.tensor_desc<16xf32> -> vector<1xf32>
-gpu.module @xevm_module{
-  gpu.func @load_nd_1d(%laneid: index) {
-    %c0 = arith.constant 0 : index
-    %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1xf32>) {
-      %0 = "some_op"() : () -> !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
-      %1 = xegpu.load_nd %0 [%c0]  {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} :
-        !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>> -> vector<16xf32>
-      gpu.yield %1 : vector<16xf32>
-    }
-    "some_user_op"(%r) : (vector<1xf32>) -> ()
-    gpu.return
+gpu.func @load_nd_1d(%laneid: index) {
+  %c0 = arith.constant 0 : index
+  %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1xf32>) {
+    %0 = "some_op"() : () -> !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
+    %1 = xegpu.load_nd %0 [%c0]  {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} :
+      !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>> -> vector<16xf32>
+    gpu.yield %1 : vector<16xf32>
   }
+  "some_user_op"(%r) : (vector<1xf32>) -> ()
+  gpu.return
 }
 
-// -----
+
 // CHECK-LABEL: gpu.func @load_nd_2d
 // CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: index) {
 // CHECK:       %[[W:.*]]:4 = gpu.warp_execute_on_lane_0(%[[ARG0]])[16] -> (vector<16x1xf16>, !xegpu.tensor_desc<16x16xf16,
@@ -89,21 +80,19 @@ gpu.module @xevm_module{
 // CHECK-SAME:     #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> to !xegpu.tensor_desc<16x16xf16> {resolve_simt_type_mismatch}
 // CHECK-NEXT:  %[[T2:.*]] = xegpu.load_nd %[[T1]][%[[W]]#2, %[[W]]#3]  : !xegpu.tensor_desc<16x16xf16> -> vector<16xf16>
 // CHECK:       vector.shape_cast %[[T2]] : vector<16xf16> to vector<16x1xf16>
-gpu.module @xevm_module{
-  gpu.func @load_nd_2d(%laneid: index) {
-    %c0 = arith.constant 0 : index
-    %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<16x1xf16>) {
-      %0 = "some_op"() : () -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
-      %1 = xegpu.load_nd %0[%c0, %c0]  {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
-        : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<16x16xf16>
-      gpu.yield %1 : vector<16x16xf16>
-    }
-    "some_user_op"(%r) : (vector<16x1xf16>) -> ()
-    gpu.return
+gpu.func @load_nd_2d(%laneid: index) {
+  %c0 = arith.constant 0 : index
+  %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<16x1xf16>) {
+    %0 = "some_op"() : () -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    %1 = xegpu.load_nd %0[%c0, %c0]  {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+      : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<16x16xf16>
+    gpu.yield %1 : vector<16x16xf16>
   }
+  "some_user_op"(%r) : (vector<16x1xf16>) -> ()
+  gpu.return
 }
 
-// -----
+
 // CHECK-LABEL: gpu.func @load_nd_array_length
 // CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: index) {
 // CHECK:       %[[W:.*]]:4 = gpu.warp_execute_on_lane_0(%[[ARG0]])[16] -> (vector<2x16x1xf16>,
@@ -118,23 +107,21 @@ gpu.module @xevm_module{
 // CHECK-NEXT:  %[[T2:.*]] = xegpu.load_nd %[[T1]][%[[W]]#2, %[[W]]#3]  : !xegpu.tensor_desc<16x16xf16,
 // CHECK-SAME:    #xegpu.block_tdesc_attr<array_length = 2 : i64>> -> vector<32xf16>
 // CHECK-NEXT:  vector.shape_cast %[[T2]] : vector<32xf16> to vector<2x16x1xf16>
-gpu.module @xevm_module{
-  gpu.func @load_nd_array_length(%laneid: index) {
-    %c0 = arith.constant 0 : index
-    %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<2x16x1xf16>) {
-      %0 = "some_op"() : () -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2 : i64>,
-        #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
-      %1 = xegpu.load_nd %0[%c0, %c0]  {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
-        : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2 : i64>,
-          #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<2x16x16xf16>
-      gpu.yield %1 : vector<2x16x16xf16>
-    }
-    "some_user_op"(%r) : (vector<2x16x1xf16>) -> ()
-    gpu.return
+gpu.func @load_nd_array_length(%laneid: index) {
+  %c0 = arith.constant 0 : index
+  %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<2x16x1xf16>) {
+    %0 = "some_op"() : () -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2 : i64>,
+      #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    %1 = xegpu.load_nd %0[%c0, %c0]  {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+      : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2 : i64>,
+        #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<2x16x16xf16>
+    gpu.yield %1 : vector<2x16x16xf16>
   }
+  "some_user_op"(%r) : (vector<2x16x1xf16>) -> ()
+  gpu.return
 }
 
-// -----
+
 // CHECK-LABEL: gpu.func @dpas
 // CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: index) {
 // CHECK:       %[[W:.*]]:4 = gpu.warp_execute_on_lane_0(%[[ARG0]])[16] ->
@@ -146,29 +133,27 @@ gpu.module @xevm_module{
 // CHECK-DAG:   %[[T3:.*]] = vector.shape_cast %[[W]]#3 : vector<8x1xf32> to vector<8xf32>
 // CHECK-NEXT:  %[[T4:.*]] = xegpu.dpas %[[T1]], %[[T2]], %[[T3]] : vector<8xf16>, vector<16xf16>, vector<8xf32> -> vector<8xf32>
 // CHECK-NEXT:  vector.shape_cast %[[T4]] : vector<8xf32> to vector<8x1xf32>
-gpu.module @xevm_module{
-  gpu.func @dpas(%laneid: index) {
-    %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<8x1xf32>) {
-      %0 = "some_op"() : () -> vector<8x16xf16>
-      %1 = "some_op"() : () -> vector<16x16xf16>
-      %2 = "some_op"() : () -> vector<8x16xf32>
-      %3 = xegpu.dpas %0, %1, %2
-        {
-          layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
-          layout_operand_1 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>,
-          layout_operand_2 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
-          layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
-        }
-        : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
-      gpu.yield %3 : vector<8x16xf32>
-    }
-    "some_user_op"(%r) : (vector<8x1xf32>) -> ()
-    gpu.return
+gpu.func @dpas(%laneid: index) {
+  %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<8x1xf32>) {
+    %0 = "some_op"() : () -> vector<8x16xf16>
+    %1 = "some_op"() : () -> vector<16x16xf16>
+    %2 = "some_op"() : () -> vector<8x16xf32>
+    %3 = xegpu.dpas %0, %1, %2
+      {
+        layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+        layout_operand_1 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>,
+        layout_operand_2 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+        layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+      }
+      : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+    gpu.yield %3 : vector<8x16xf32>
   }
+  "some_user_op"(%r) : (vector<8x1xf32>) -> ()
+  gpu.return
 }
 
 
-// -----
+
 // CHECK-LABEL: gpu.func @create_nd_tdesc_non_memref
 // CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: ui64, %[[ARG1:[0-9a-zA-Z]+]]: index) {
 // CHECK:       %[[W:.*]]:2 = gpu.warp_execute_on_lane_0(%[[ARG1]])[16] -> (!xegpu.tensor_desc<16x16xf16,
@@ -178,21 +163,19 @@ gpu.module @xevm_module{
 // CHECK-NEXT:  %[[T1:.*]] = xegpu.create_nd_tdesc %[[W]]#1, shape : [64, 128], strides : [128, 1] : ui64 -> !xegpu.tensor_desc<16x16xf16>
 // CHECK-NEXT:  builtin.unrealized_conversion_cast %[[T1]] : !xegpu.tensor_desc<16x16xf16> to !xegpu.tensor_desc<16x16xf16,
 // CHECK-SAME:    #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> {resolve_simt_type_mismatch}
-gpu.module @xevm_module{
-  gpu.func @create_nd_tdesc_non_memref(%arg0: ui64, %laneid: index) {
-    %c0 = arith.constant 0 : index
-    %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (!xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>) {
-      %0 = xegpu.create_nd_tdesc %arg0, shape:[64, 128], strides:[128, 1] : ui64 ->
-        !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
-      gpu.yield %0 : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
-    }
-    "some_user_op"(%r)
-      : (!xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>) -> ()
-    gpu.return
+gpu.func @create_nd_tdesc_non_memref(%arg0: ui64, %laneid: index) {
+  %c0 = arith.constant 0 : index
+  %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (!xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>) {
+    %0 = xegpu.create_nd_tdesc %arg0, shape:[64, 128], strides:[128, 1] : ui64 ->
+      !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    gpu.yield %0 : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
   }
+  "some_user_op"(%r)
+    : (!xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>) -> ()
+  gpu.return
 }
 
-// -----
+
 // CHECK-LABEL: gpu.func @prefetch_2d
 // CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: index) {
 // CHECK:       %[[W:.*]]:3 = gpu.warp_execute_on_lane_0(%[[ARG0]])[16] -> (!xegpu.tensor_desc<16x16xf16,
@@ -204,21 +187,19 @@ gpu.module @xevm_module{
 // CHECK-SAME:    #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> to !xegpu.tensor_desc<16x16xf16> {resolve_simt_type_mismatch}
 // CHECK-NEXT:  xegpu.prefetch_nd %[[T1]][%[[W]]#1, %[[W]]#2]
 // CHECK-SAME:    <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16x16xf16>
-gpu.module @xevm_module{
-  gpu.func @prefetch_2d(%laneid: index) {
-    %c0 = arith.constant 0 : index
-    gpu.warp_execute_on_lane_0(%laneid)[16] {
-      %0 = "some_op"() : ()
-        -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
-      xegpu.prefetch_nd %0[%c0, %c0]
-        <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
-        : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
-    }
-    gpu.return
+gpu.func @prefetch_2d(%laneid: index) {
+  %c0 = arith.constant 0 : index
+  gpu.warp_execute_on_lane_0(%laneid)[16] {
+    %0 = "some_op"() : ()
+      -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    xegpu.prefetch_nd %0[%c0, %c0]
+      <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+      : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
   }
+  gpu.return
 }
 
-// -----
+
 // CHECK-LABEL: gpu.func @prefetch_1d
 // CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: index) {
 // CHECK:       %[[W:.*]]:2 = gpu.warp_execute_on_lane_0(%[[ARG0]])[16] -> (!xegpu.tensor_desc<16xf16,
@@ -229,44 +210,40 @@ gpu.module @xevm_module{
 // CHECK-SAME:    #xegpu.layout<lane_layout = [16], lane_data = [1]>> to !xegpu.tensor_desc<16xf16> {resolve_simt_type_mismatch}
 // CHECK-NEXT:  xegpu.prefetch_nd %[[T1]][%[[W]]#1] <{l1_hint = #xegpu.cache_hint<cached>,
 // CHECK-SAME:    l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16xf16>
-gpu.module @xevm_module{
-  gpu.func @prefetch_1d(%laneid: index) {
-    %c0 = arith.constant 0 : index
-    gpu.warp_execute_on_lane_0(%laneid)[16] {
-      %0 = "some_op"() : ()
-        -> !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
-      xegpu.prefetch_nd %0[%c0]
-        <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
-        : !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
-    }
-    gpu.return
+gpu.func @prefetch_1d(%laneid: index) {
+  %c0 = arith.constant 0 : index
+  gpu.warp_execute_on_lane_0(%laneid)[16] {
+    %0 = "some_op"() : ()
+      -> !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
+    xegpu.prefetch_nd %0[%c0]
+      <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+      : !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
   }
+  gpu.return
 }
 
-// -----
+
 // CHECK-LABEL: gpu.func @gpu_barrier({{.*}}) {
 // CHECK:  gpu.warp_execute_on_lane_0(%{{.*}})[16] -> ({{.*}}) {
 // CHECK:     gpu.yield %{{.*}}
 // CHECK:  }
 // CHECK:  %{{.*}} = xegpu.load_nd %{{.*}}  : !xegpu.tensor_desc<16xf16> -> vector<1xf16>
 // CHECK:  gpu.barrier
-gpu.module @xevm_module{
-  gpu.func @gpu_barrier(%laneid: index) {
-    %c0 = arith.constant 0 : index
-    %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1xf16>) {
-      %0 = "some_op"() : () -> !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
-      %1 = xegpu.load_nd %0[%c0]
-        {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
-        : !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>> -> vector<16xf16>
-      gpu.barrier
-      gpu.yield %1 : vector<16xf16>
-    }
-    "some_user_op"(%r) : (vector<1xf16>) -> ()
-    gpu.return
+gpu.func @gpu_barrier(%laneid: index) {
+  %c0 = arith.constant 0 : index
+  %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1xf16>) {
+    %0 = "some_op"() : () -> !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
+    %1 = xegpu.load_nd %0[%c0]
+      {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
+      : !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>> -> vector<16xf16>
+    gpu.barrier
+    gpu.yield %1 : vector<16xf16>
   }
+  "some_user_op"(%r) : (vector<1xf16>) -> ()
+  gpu.return
 }
 
-// -----
+
 // CHECK-LABEL: gpu.func @vector_multi_reduction_dim1_distributed_dim0_reduction
 // CHECK:       %[[ACC:.*]] = arith.constant {{.*}} dense<0.000000e+00> : vector<32xf32>
 // CHECK:       %[[W:.*]]:3 = gpu.warp_execute_on_lane_0(%{{.*}})[16]
@@ -285,7 +262,6 @@ gpu.module @xevm_module{
 // CHECK:       %[[T7:.*]] = vector.extract %[[W]]#2[1] : f32 from vector<2xf32>
 // CHECK:       %[[T8:.*]] = vector.reduction <add>, %[[T6]], %[[T7]] : vector<16xf32> into f32
 // CHECK:       %[[T9:.*]] = vector.from_elements %[[T4]], %[[T8]] : vector<2xf32>
-gpu.module @xevm_module{
 gpu.func @vector_multi_reduction_dim1_distributed_dim0_reduction(%laneid: index) {
   %c0 = arith.constant 0 : index
   %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<2xf32>) {
@@ -307,9 +283,8 @@ gpu.func @vector_multi_reduction_dim1_distributed_dim0_reduction(%laneid: index)
   "some_user_op"(%r) : (vector<2xf32>) -> ()
   gpu.return
 }
-}
 
-// -----
+
 // CHECK-LABEL: gpu.func @vector_multi_reduction_dim1_distributed_dim1_reduction
 // CHECK:      %[[W:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<2xf32>) {
 // CHECK-NEXT:   %[[SRC:.*]] = "some_def"() {{.*}} : () -> vector<2x16xf32>
@@ -320,7 +295,6 @@ gpu.func @vector_multi_reduction_dim1_distributed_dim0_reduction(%laneid: index)
 // CHECK-NEXT:   %[[T6:.*]] = vector.from_elements %[[T3]], %[[T5]] : vector<2xf32>
 // CHECK-NEXT:   gpu.yield %[[T6]] : vector<2xf32>
 // CHECK-NEXT: }
-gpu.module @xevm_module{
 gpu.func @vector_multi_reduction_dim1_distributed_dim1_reduction(%laneid: index) {
   %c0 = arith.constant 0 : index
   %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<2xf32>) {
@@ -342,9 +316,8 @@ gpu.func @vector_multi_reduction_dim1_distributed_dim1_reduction(%laneid: index)
   "some_user_op"(%r) : (vector<2xf32>) -> ()
   gpu.return
 }
-}
 
-// -----
+
 // CHECK-LABEL:   gpu.func @vector_multi_reduction_dim0_distributed_dim1_reduction
 // CHECK:       %[[ACC:.*]] = arith.constant {{.*}} dense<0.000000e+00> : vector<32xf32>
 // CHECK:       %[[W:.*]]:3 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<2xf32>, vector<2x16xf32>, vector<2xf32>) {
@@ -358,7 +331,6 @@ gpu.func @vector_multi_reduction_dim1_distributed_dim1_reduction(%laneid: index)
 // CHECK:       %[[T5:.*]] = vector.extract %[[W]]#2[1] : f32 from vector<2xf32>
 // CHECK:       %[[T6:.*]] = vector.reduction <add>, %[[T4]], %[[T5]] : vector<16xf32> into f32
 // CHECK:       %[[T7:.*]] = vector.from_elements %[[T3]], %[[T6]] : vector<2xf32>
-gpu.module @xevm_module{
 gpu.func @vector_multi_reduction_dim0_distributed_dim1_reduction(%laneid: index) {
   %c0 = arith.constant 0 : index
   %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<2xf32>) {
@@ -380,9 +352,8 @@ gpu.func @vector_multi_reduction_dim0_distributed_dim1_reduction(%laneid: index)
   "some_user_op"(%r) : (vector<2xf32>) -> ()
   gpu.return
 }
-}
 
-// -----
+
 // CHECK-LABEL: gpu.func @vector_multi_reduction_dim0_distributed_dim0_reduction
 // CHECK:     %[[W:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<2xf32>) {
 // CHECK:       %[[SRC:.*]] = "some_def"() {{.*}} : () -> vector<16x2xf32>
@@ -397,7 +368,6 @@ gpu.func @vector_multi_reduction_dim0_distributed_dim1_reduction(%laneid: index)
 // CHECK:       %[[T7:.*]] = vector.from_elements %[[T3]], %[[T6]] : vector<2xf32>
 // CHECK:       gpu.yield %[[T7]] : vector<2xf32>
 // CHECK:     }
-gpu.module @xevm_module{
 gpu.func @vector_multi_reduction_dim0_distributed_dim0_reduction(%laneid: index) {
   %c0 = arith.constant 0 : index
   %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<2xf32>) {
@@ -419,9 +389,8 @@ gpu.func @vector_multi_reduction_dim0_distributed_dim0_reduction(%laneid: index)
   "some_user_op"(%r) : (vector<2xf32>) -> ()
   gpu.return
 }
-}
 
-// -----
+
 // CHECK-LABEL: gpu.func @scatter_ops_chunksize({{.*}}) {
 // CHECK:       %[[OFFSETS:.*]] = arith.constant {{.*}} dense<12> : vector<16xindex>
 // CHECK:       %[[MASKS:.*]] = arith.constant {{.*}} dense<true> : vector<16xi1>
@@ -434,35 +403,33 @@ gpu.func @vector_multi_reduction_dim0_distributed_dim0_reduction(%laneid: index)
 // CHECK-SAME:    : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
 // CHECK-NEXT:  xegpu.store %[[T1]], %[[W]]#1[%[[W]]#2], %[[W]]#3 <{chunk_size = 8 : i64}>
 // CHECK-SAME:    : vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
-gpu.module @xevm_module{
-  gpu.func @scatter_ops_chunksize(%laneid: index, %src: memref<256xf16>) {
-    gpu.warp_execute_on_lane_0(%laneid)[16] {
-      %1 = arith.constant
-        {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
-        dense<1>: vector<16xi1>
-      %offset = arith.constant
-        {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
-        dense<12> : vector<16xindex>
-      %3 = xegpu.load %src[%offset], %1 <{chunk_size=8}>
-        {
-          layout_operand_1 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
-          layout_operand_2 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
-          layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>
-        }
-        : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
-      xegpu.store %3, %src[%offset], %1 <{chunk_size=8}>
-        {
-          layout_operand_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>,
-          layout_operand_2 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
-          layout_operand_3 = #xegpu.layout<lane_layout = [16], lane_data = [1]>
-        }
-        : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
-    }
-    gpu.return
+gpu.func @scatter_ops_chunksize(%laneid: index, %src: memref<256xf16>) {
+  gpu.warp_execute_on_lane_0(%laneid)[16] {
+    %1 = arith.constant
+      {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
+      dense<1>: vector<16xi1>
+    %offset = arith.constant
+      {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
+      dense<12> : vector<16xindex>
+    %3 = xegpu.load %src[%offset], %1 <{chunk_size=8}>
+      {
+        layout_operand_1 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
+        layout_operand_2 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
+        layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>
+      }
+      : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
+    xegpu.store %3, %src[%offset], %1 <{chunk_size=8}>
+      {
+        layout_operand_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>,
+        layout_operand_2 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
+        layout_operand_3 = #xegpu.layout<lane_layout = [16], lane_data = [1]>
+      }
+      : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
   }
+  gpu.return
 }
 
-// -----
+
 // CHECK-LABEL: gpu.func @scatter_ops({{.*}}) {
 // CHECK:       %[[OFFSETS:.*]] = arith.constant {{.*}} dense<12> : vector<16xindex>
 // CHECK:       %[[MASKS:.*]] = arith.constant {{.*}} dense<true> : vector<16xi1>
@@ -475,156 +442,144 @@ gpu.module @xevm_module{
 // CHECK-SAME:    : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<1xf16>
 // CHECK-NEXT:  xegpu.store %[[T1]], %[[W]]#1[%[[W]]#2], %[[W]]#3
 // CHECK-SAME:    : vector<1xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
-gpu.module @xevm_module{
-  gpu.func @scatter_ops(%src: memref<256xf16>, %laneid: index) {
-    gpu.warp_execute_on_lane_0(%laneid)[16] {
-      %1 = arith.constant
-        {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
-        dense<1> : vector<16xi1>
-      %offset = arith.constant
-        {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
-        dense<12> : vector<16xindex>
-      %3 = xegpu.load %src[%offset], %1
-      {
-        layout_operand_1 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
-        layout_operand_2 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
-        layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>
-      } : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16>
-      xegpu.store %3, %src[%offset], %1
-      {
-        layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
-        layout_operand_2 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
-        layout_operand_3 = #xegpu.layout<lane_layout = [16], lane_data = [1]>
-      }
-      : vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
+gpu.func @scatter_ops(%src: memref<256xf16>, %laneid: index) {
+  gpu.warp_execute_on_lane_0(%laneid)[16] {
+    %1 = arith.constant
+      {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
+      dense<1> : vector<16xi1>
+    %offset = arith.constant
+      {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
+      dense<12> : vector<16xindex>
+    %3 = xegpu.load %src[%offset], %1
+    {
+      layout_operand_1 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
+      layout_operand_2 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
+      layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>
+    } : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16>
+    xegpu.store %3, %src[%offset], %1
+    {
+      layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
+      layout_operand_2 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
+      layout_operand_3 = #xegpu.layout<lane_layout = [16], lane_data = [1]>
     }
-    gpu.return
+    : vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
   }
+  gpu.return
 }
 
-// -----
+
 // CHECK-LABEL: gpu.func @memref_extract_aligned_pointer_as_index(
 // CHECK:       %[[W:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (index, memref<256x256xf16>) {
 // CHECK:         gpu.yield %{{.*}}, %{{.*}} : index, memref<256x256xf16>
 // CHECK-NEXT:  }
 // CHECK-NEXT:  %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[W]]#1 : memref<256x256xf16> -> index
 // CHECK-NEXT:  arith.index_cast %[[INTPTR]] : index to i64
-gpu.module @xevm_module{
-  gpu.func @memref_extract_aligned_pointer_as_index(%arg0 : memref<256x256xf16>, %laneid: index) {
-    %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (index) {
-      %ptr = memref.extract_aligned_pointer_as_index %arg0 : memref<256x256xf16> -> index
-      gpu.yield %ptr : index
-    }
-    %ptr_i64 = arith.index_cast %r : index to i64
-    "some_user_op"(%ptr_i64) : (i64) -> ()
-    gpu.return
+gpu.func @memref_extract_aligned_pointer_as_index(%arg0 : memref<256x256xf16>, %laneid: index) {
+  %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (index) {
+    %ptr = memref.extract_aligned_pointer_as_index %arg0 : memref<256x256xf16> -> index
+    gpu.yield %ptr : index
   }
+  %ptr_i64 = arith.index_cast %r : index to i64
+  "some_user_op"(%ptr_i64) : (i64) -> ()
+  gpu.return
 }
 
 
-// -----
+
 // CHECK-LABEL: gpu.func @vector_transpose(
 // CHECK:       %[[W:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<2x1xf32>, vector<1x2xf32>) {
 // CHECK:         %[[SRC:.*]] = "some_op"() {{.*}} : () -> vector<16x2xf32>
 // CHECK:         gpu.yield %{{.*}}, %[[SRC]] : vector<2x16xf32>, vector<16x2xf32>
 // CHECK-NEXT:  }
 // CHECK-NEXT:  %[[T1:.*]] = vector.transpose %[[W]]#1, [1, 0] : vector<1x2xf32> to vector<2x1xf32>
-gpu.module @xevm_module{
-  gpu.func @vector_transpose(%laneid: index) {
-    %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<2x1xf32>) {
-      %cst = "some_op"()
-        {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}
-        : () -> (vector<16x2xf32>)
-      %transpose = vector.transpose %cst, [1, 0]
-        {
-          layout_operand_0 = #xegpu.layout<lane_layout = [16 , 1], lane_data = [1, 1]>,
-          layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
-        }
-        : vector<16x2xf32> to vector<2x16xf32>
-      gpu.yield %transpose : vector<2x16xf32>
-    }
-    "some_user_op"(%r) : (vector<2x1xf32>) -> ()
-    gpu.return
+gpu.func @vector_transpose(%laneid: index) {
+  %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<2x1xf32>) {
+    %cst = "some_op"()
+      {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}
+      : () -> (vector<16x2xf32>)
+    %transpose = vector.transpose %cst, [1, 0]
+      {
+        layout_operand_0 = #xegpu.layout<lane_layout = [16 , 1], lane_data = [1, 1]>,
+        layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+      }
+      : vector<16x2xf32> to vector<2x16xf32>
+    gpu.yield %transpose : vector<2x16xf32>
   }
+  "some_user_op"(%r) : (vector<2x1xf32>) -> ()
+  gpu.return
 }
 
-// -----
+
 // CHECK-LABEL: gpu.func @vector_bitcast(
 // CHECK:       %[[W:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<4x1xi16>, vector<4x2xi8>) {
 // CHECK:         %[[SRC:.*]] = "some_op"() {{.*}} : () -> vector<4x32xi8>
 // CHECK:         gpu.yield %{{.*}}, %[[SRC]] : vector<4x16xi16>, vector<4x32xi8>
 // CHECK:       }
 // CHECK:       vector.bitcast %[[W]]#1 : vector<4x2xi8> to vector<4x1xi16>
-gpu.module @xevm_module{
-  gpu.func @vector_bitcast(%laneid: index) {
-    %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<4x1xi16>) {
-      %cst = "some_op"()
-        {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>}
-        : () -> (vector<4x32xi8>)
-      %bitcast = vector.bitcast %cst
-        {
-          layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>,
-          layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
-        }
-        : vector<4x32xi8> to vector<4x16xi16>
-      gpu.yield %bitcast : vector<4x16xi16>
-    }
-    "some_user_op"(%r) : (vector<4x1xi16>) -> ()
-    gpu.return
+gpu.func @vector_bitcast(%laneid: index) {
+  %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<4x1xi16>) {
+    %cst = "some_op"()
+      {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>}
+      : () -> (vector<4x32xi8>)
+    %bitcast = vector.bitcast %cst
+      {
+        layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>,
+        layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+      }
+      : vector<4x32xi8> to vector<4x16xi16>
+    gpu.yield %bitcast : vector<4x16xi16>
   }
+  "some_user_op"(%r) : (vector<4x1xi16>) -> ()
+  gpu.return
 }
 
-// -----
+
 // CHECK-LABEL: gpu.func @vector_shapecast_rank_increasing
 // CHECK:         %{{.*}}:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<1x1xf32>, vector<1xf32>) {
 // CHECK:           gpu.yield %{{.*}} : vector<1x16xf32>, vector<16xf32>
 // CHECK:         }
 // CHECK:         %{{.*}} = vector.shape_cast %{{.*}}#1 : vector<1xf32> to vector<1x1xf32>
-gpu.module @xevm_module {
-  gpu.func @vector_shapecast_rank_increasing(%laneid: index) {
-    %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1x1xf32>) {
-      %cst = "some_op"()
-        {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>}
-        : () -> (vector<16xf32>)
-      %cast = vector.shape_cast %cst
-        {
-          layout_operand_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>,
-          layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
-        }
-        : vector<16xf32> to vector<1x16xf32>
-      gpu.yield %cast : vector<1x16xf32>
-    }
-    "some_user_op"(%r) : (vector<1x1xf32>) -> ()
-    gpu.return
+gpu.func @vector_shapecast_rank_increasing(%laneid: index) {
+  %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1x1xf32>) {
+    %cst = "some_op"()
+      {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>}
+      : () -> (vector<16xf32>)
+    %cast = vector.shape_cast %cst
+      {
+        layout_operand_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>,
+        layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+      }
+      : vector<16xf32> to vector<1x16xf32>
+    gpu.yield %cast : vector<1x16xf32>
   }
+  "some_user_op"(%r) : (vector<1x1xf32>) -> ()
+  gpu.return
 }
 
-// -----
+
 // CHECK-LABEL: gpu.func @vector_shapecast_rank_reducing(
 // CHECK:         %{{.*}}:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<1xf32>, vector<1x1xf32>) {
 // CHECK:           gpu.yield %{{.*}} : vector<16xf32>, vector<1x16xf32>
 // CHECK:         }
 // CHECK:         %{{.*}} = vector.shape_cast %{{.*}}#1 : vector<1x1xf32> to vector<1xf32>
-gpu.module @xevm_module {
-  gpu.func @vector_shapecast_rank_reducing(%laneid: index) {
-    %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1xf32>) {
-      %cst = "some_op"()
-        {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
-        : () -> (vector<1x16xf32>)
-      %cast = vector.shape_cast %cst
-        {
-          layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
-          layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>
-        }
-        : vector<1x16xf32> to vector<16xf32>
-      gpu.yield %cast : vector<16xf32>
-    }
-    "some_user_op"(%r) : (vector<1xf32>) -> ()
-    gpu.return
+gpu.func @vector_shapecast_rank_reducing(%laneid: index) {
+  %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1xf32>) {
+    %cst = "some_op"()
+      {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+      : () -> (vector<1x16xf32>)
+    %cast = vector.shape_cast %cst
+      {
+        layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+        layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>
+      }
+      : vector<1x16xf32> to vector<16xf32>
+    gpu.yield %cast : vector<16xf32>
   }
+  "some_user_op"(%r) : (vector<1xf32>) -> ()
+  gpu.return
 }
 
-// -----
+
 // NOTE: Layouts are still valid, but distribution still requires a slice layout for the operand.
 //
 // CHECK-LABEL:  gpu.func @vector_shapecast_unsupported
@@ -634,21 +589,164 @@ gpu.module @xevm_module {
 // CHECK:          }
 // CHECK:          "some_user_op"(%[[W]]) : (vector<1x1xf32>) -> ()
 // CHECK:          gpu.return
-gpu.module @xevm_module {
-  gpu.func @vector_shapecast_unsupported(%laneid: index) {
-    %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1x1xf32>) {
-      %cst = "some_op"()
-        {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]> }
-        : () -> (vector<16xf32>)
-      %cast = vector.shape_cast %cst
-        {
-          layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
-          layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
-        }
-        : vector<16xf32> to vector<1x16xf32>
-      gpu.yield %cast : vector<1x16xf32>
+gpu.func @vector_shapecast_unsupported(%laneid: index) {
+  %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1x1xf32>) {
+    %cst = "some_op"()
+      {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]> }
+      : () -> (vector<16xf32>)
+    %cast = vector.shape_cast %cst
+      {
+        layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
+        layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+      }
+      : vector<16xf32> to vector<1x16xf32>
+    gpu.yield %cast : vector<1x16xf32>
+  }
+  "some_user_op"(%r) : (vector<1x1xf32>) -> ()
+  gpu.return
+}
+
+
+// CHECK-LABEL:  gpu.func @vector_extract_strided_slice_outer_distributed
+// CHECK:          %[[W:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<1x16xf32>, vector<2x16xf32>) {
+// CHECK-NEXT:       %[[S:.*]] = "some_def"() : () -> vector<32x16xf32>
+// CHECK:            gpu.yield %{{.*}}, %[[S]] : vector<16x16xf32>, vector<32x16xf32>
+// CHECK:          }
+// CHECK-NEXT:     %[[T1:.*]] = vector.extract %[[W]]#1[1] : vector<16xf32> from vector<2x16xf32>
+// CHECK-NEXT:     %[[T2:.*]] = vector.shape_cast %[[T1]] : vector<16xf32> to vector<1x16xf32>
+// CHECK-NEXT:     "some_use"(%[[T2]]) : (vector<1x16xf32>) -> ()
+gpu.func @vector_extract_strided_slice_outer_distributed(%laneid: index) {
+  %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1x16xf32>) {
+    %0 = "some_def"() : () -> (vector<32x16xf32>)
+    %1 = vector.extract_strided_slice %0 { offsets = [16], sizes = [16], strides = [1],
+        layout_operand_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>,
+        layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>
+      }
+      : vector<32x16xf32> to vector<16x16xf32>
+    gpu.yield %1 : vector<16x16xf32>
+  }
+  "some_use"(%r) : (vector<1x16xf32>) -> ()
+  gpu.return
+}
+
+// CHECK-LABEL: gpu.func @vector_extract_strided_slice_inner_distributed
+// CHECK:         %[[W:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<8x1xf32>, vector<24x4xf32>) {
+// CHECK-NEXT:      %[[S:.*]] = "some_def"() : () -> vector<24x64xf32>
+// CHECK:           gpu.yield %{{.*}}, %[[S]] : vector<8x16xf32>, vector<24x64xf32>
+// CHECK-NEXT:    }
+// CHECK-NEXT:    %[[T1:.*]] = vector.extract_strided_slice %[[W]]#1
+// CHECK-SAME:      {offsets = [8, 3], sizes = [8, 1], strides = [1, 1]} : vector<24x4xf32> to vector<8x1xf32>
+// CHECK-NEXT:    "some_use"(%[[T1]]) : (vector<8x1xf32>) -> ()
+gpu.func @vector_extract_strided_slice_inner_distributed(%laneid: index) {
+  %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<8x1xf32>) {
+    %0 = "some_def"() : () -> (vector<24x64xf32>)
+    %1 = vector.extract_strided_slice %0 { offsets = [8, 48], sizes = [8, 16], strides = [1, 1],
+        layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+        layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+      }
+      : vector<24x64xf32> to vector<8x16xf32>
+    gpu.yield %1 : vector<8x16xf32>
+  }
+  "some_use"(%r) : (vector<8x1xf32>) -> ()
+  gpu.return
+}
+
+
+// CHECK-LABEL: gpu.func @vector_extract_strided_slice_1d
+// CHECK:         %[[W:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<2xf32>, vector<4xf32>) {
+// CHECK:           %[[S:.*]] = "some_def"() : () -> vector<64xf32>
+// CHECK:           gpu.yield %{{.*}}, %[[S]] : vector<32xf32>, vector<64xf32>
+// CHECK-NEXT:    }
+// CHECK-NEXT:    %[[T1:.*]] = vector.extract_strided_slice %[[W]]#1
+// CHECK-SAME:      {offsets = [1], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
+// CHECK-NEXT:    "some_use"(%[[T1]]) : (vector<2xf32>) -> ()
+gpu.func @vector_extract_strided_slice_1d(%laneid: index) {
+  %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<2xf32>) {
+    %0 = "some_def"() : () -> (vector<64xf32>)
+    %1 = vector.extract_strided_slice %0 { offsets = [16], sizes = [32], strides = [1],
+        layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
+        layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>
+      }
+      : vector<64xf32> to vector<32xf32>
+    gpu.yield %1 : vector<32xf32>
+  }
+  "some_use"(%r) : (vector<2xf32>) -> ()
+  gpu.return
+}
+
+// CHECK-LABEL: gpu.func @vector_insert_strided_slice_inner_distributed
+// CHECK:         %[[W]]:3 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<64x2xf32>, vector<16x1xf32>, vector<64x2xf32>) {
+// CHECK-NEXT:      %[[S:.*]] = "some_def"() : () -> vector<16x16xf32>
+// CHECK-NEXT:      %[[D:.*]] = "some_def"() : () -> vector<64x32xf32>
+// CHECK:           gpu.yield %{{.*}}, %[[S]], %[[D]] : vector<64x32xf32>, vector<16x16xf32>, vector<64x32xf32>
+// CHECK-NEXT:    }
+// CHECK-NEXT:    %[[T1:.*]] = vector.insert_strided_slice %[[W]]#1, %[[W]]#2
+// CHECK-SAME:      {offsets = [24, 1], strides = [1, 1]} : vector<16x1xf32> into vector<64x2xf32>
+// CHECK-NEXT:    "some_use"(%[[T1]]) : (vector<64x2xf32>) -> ()
+gpu.func @vector_insert_strided_slice_inner_distributed(%laneid: index) {
+  %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<64x2xf32>) {
+    %0 = "some_def"() : () -> (vector<16x16xf32>)
+    %1 = "some_def"() : () -> (vector<64x32xf32>)
+    %2 = vector.insert_strided_slice %0, %1 { offsets = [24, 16],  strides = [1, 1],
+      layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+      layout_operand_1 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+      layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
     }
-    "some_user_op"(%r) : (vector<1x1xf32>) -> ()
-    gpu.return
+      : vector<16x16xf32> into vector<64x32xf32>
+    gpu.yield %2 : vector<64x32xf32>
   }
+  "some_use"(%r) : (vector<64x2xf32>) -> ()
+  gpu.return
+}
+
+// CHECK-LABEL:   gpu.func @vector_insert_strided_slice_outer_distributed
+// CHECK:           %[[W:.*]]:3 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<3x32xf32>, vector<1x16xf32>, vector<3x32xf32>) {
+// CHECK-NEXT:        %[[S:.*]] = "some_def"() : () -> vector<16x16xf32>
+// CHECK-NEXT:        %[[D:.*]] = "some_def"() : () -> vector<48x32xf32>
+// CHECK:             gpu.yield %{{.*}}, %[[S]], %[[D]] : vector<48x32xf32>, vector<16x16xf32>, vector<48x32xf32>
+// CHECK-NEXT:      }
+// CHECK-NEXT:      %[[T1:.*]] = vector.insert_strided_slice %[[W]]#1, %[[W]]#2
+// CHECK-SAME:        {offsets = [2, 4], strides = [1, 1]} : vector<1x16xf32> into vector<3x32xf32>
+// CHECK-NEXT:      "some_use"(%[[T1]]) : (vector<3x32xf32>) -> ()
+gpu.func @vector_insert_strided_slice_outer_distributed(%laneid: index) {
+  %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<3x32xf32>) {
+    %0 = "some_def"() : () -> (vector<16x16xf32>)
+    %1 = "some_def"() : () -> (vector<48x32xf32>)
+    %2 = vector.insert_strided_slice %0, %1 { offsets = [32, 4],  strides = [1, 1],
+      layout_operand_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>,
+      layout_operand_1 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>,
+      layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>
+    }
+      : vector<16x16xf32> into vector<48x32xf32>
+    gpu.yield %2 : vector<48x32xf32>
+  }
+  "some_use"(%r) : (vector<3x32xf32>) -> ()
+  gpu.return
+}
+
+// CHECK-LABEL: gpu.func @vector_insert_strided_slice_1d
+// CHECK:         %[[W:.*]]:3 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<3xf32>, vector<1xf32>, vector<3xf32>) {
+// CHECK-NEXT:      %[[S:.*]] = "some_def"() : () -> vector<16xf32>
+// CHECK-NEXT:      %[[D:.*]] = "some_def"() : () -> vector<48xf32>
+// CHECK:           gpu.yield %{{.*}}, %[[S]], %[[D]] : vector<48xf32>, vector<16xf32>, vector<48xf32>
+// CHECK-NEXT:    }
+// CHECK-NEXT:    %[[T1:.*]] = vector.insert_strided_slice %[[W]]#1, %[[W]]#2
+// CHECK-SAME:      {offsets = [1], strides = [1]} : vector<1xf32> into vector<3xf32>
+// CHECK-NEXT:    "some_use"(%[[T1]]) : (vector<3xf32>) -> ()
+gpu.func @vector_insert_strided_slice_1d(%laneid: index) {
+  %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<3xf32>) {
+    %0 = "some_def"() : () -> (vector<16xf32>)
+    %1 = "some_def"() : () -> (vector<48xf32>)
+    %2 = vector.insert_strided_slice %0, %1 { offsets = [16],  strides = [1],
+      layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
+      layout_operand_1 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
+      layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>
+    }
+      : vector<16xf32> into vector<48xf32>
+    gpu.yield %2 : vector<48xf32>
+  }
+  "some_use"(%r) : (vector<3xf32>) -> ()
+  gpu.return
+}
+
 }

>From 4905450fc9011677993797c3dbda5550cc7c9e17 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Tue, 18 Nov 2025 23:04:34 +0000
Subject: [PATCH 05/14] save work

---
 .../Transforms/XeGPUSubgroupDistribute.cpp    | 23 +++++-------
 .../XeGPU/subgroup-distribute-unit.mlir       | 36 +++++++++++++++++++
 2 files changed, 44 insertions(+), 15 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index a3e7f8469cd93..71df8d4fcbf7d 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -1520,19 +1520,12 @@ struct VectorExtractStridedSliceDistribution
 
     int distrDimExtractedSize =
         cast<IntegerAttr>(extractOp.getSizes()[distributedDim]).getInt();
-    if (distrDimExtractedSize ==
-        extractOp.getSourceVectorType().getShape()[distributedDim])
+    int sourceDistrDimSize =
+        extractOp.getSourceVectorType().getShape()[distributedDim];
+    if (distrDimExtractedSize == sourceDistrDimSize)
       return rewriter.notifyMatchFailure(
           warpOp, "Distributed dimension is fully extracted, skipping.");
 
-    // Check if the size extracted along the distributed dimension is a multiple
-    // of the source dim size and should be distributable to lanes.
-    int64_t sourceDisrDimSize = yieldedType.getShape()[distributedDim];
-    if (sourceDisrDimSize % distrDimExtractedSize != 0)
-      return rewriter.notifyMatchFailure(
-          warpOp,
-          "Extracted size along distributed dimension is not a multiple of "
-          "source dim size.");
     auto sourceLayout =
         xegpu::getDistributeLayoutAttr(extractOp->getOpOperand(0));
     if (!sourceLayout || sourceLayout.getEffectiveLaneLayoutAsInt().empty())
@@ -1543,13 +1536,13 @@ struct VectorExtractStridedSliceDistribution
     // Because only single dimension distribution is supported, lane layout size
     // at the distributed dim must be the subgroup size.
     int subgroupSize = sourceLaneLayout[distributedDim];
-    // Check if the distributed extracted dim is a multiple of the subgroup
-    // size.
-    if (distrDimExtractedSize % subgroupSize != 0)
+    // Check if the source size in the distributed dimension is a multiple of
+    // subgroup size.
+    if (sourceDistrDimSize % subgroupSize != 0)
       return rewriter.notifyMatchFailure(
           warpOp,
-          "Extracted size along distributed dimension is not a multiple of "
-          "subgroup size in source layout.");
+          "Source size along distributed dimension is not a multiple of "
+          "subgroup size.");
     auto sourceLaneData = sourceLayout.getEffectiveLaneDataAsInt();
     // We expect lane data to be all ones in this case.
     if (!llvm::all_of(sourceLaneData, [](int64_t v) { return v == 1; }))
diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
index d8ed46646810d..4681b0958958c 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
@@ -749,4 +749,40 @@ gpu.func @vector_insert_strided_slice_1d(%laneid: index) {
   gpu.return
 }
 
+// CHECK-LABEL: gpu.func @vector_extract_strided_slice_unsopported_offset
+// CHECK:         %{{.*}} = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<2xf32>) {
+// CHECK:         }
+// CHECK-NOT:     %{{.*}} = vector.extract_strided_slice
+gpu.func @vector_extract_strided_slice_unsopported_offset(%laneid: index) {
+  %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<2xf32>) {
+    %0 = "some_def"() : () -> (vector<64xf32>)
+    %1 = vector.extract_strided_slice %0 { offsets = [3], sizes = [32], strides = [1],
+        layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
+        layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>
+      }
+      : vector<64xf32> to vector<32xf32>
+    gpu.yield %1 : vector<32xf32>
+  }
+  "some_use"(%r) : (vector<2xf32>) -> ()
+  gpu.return
+}
+
+// CHECK-LABEL: gpu.func @vector_extract_strided_slice_unsopported_source
+// CHECK:         %{{.*}} = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<2xf32>) {
+// CHECK:         }
+// CHECK-NOT:     %{{.*}} = vector.extract_strided_slice
+gpu.func @vector_extract_strided_slice_unsopported_source(%laneid: index) {
+  %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<2xf32>) {
+    %0 = "some_def"() : () -> (vector<54xf32>)
+    %1 = vector.extract_strided_slice %0 { offsets = [0], sizes = [32], strides = [1],
+        layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
+        layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>
+      }
+      : vector<54xf32> to vector<32xf32>
+    gpu.yield %1 : vector<32xf32>
+  }
+  "some_use"(%r) : (vector<2xf32>) -> ()
+  gpu.return
+}
+
 }

>From 36b27c44f1b697534a659e1f306d7d7c45c20a6e Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Tue, 18 Nov 2025 23:40:11 +0000
Subject: [PATCH 06/14] save work

---
 mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
index 4681b0958958c..4575a981e2986 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
@@ -675,7 +675,7 @@ gpu.func @vector_extract_strided_slice_1d(%laneid: index) {
 }
 
 // CHECK-LABEL: gpu.func @vector_insert_strided_slice_inner_distributed
-// CHECK:         %[[W]]:3 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<64x2xf32>, vector<16x1xf32>, vector<64x2xf32>) {
+// CHECK:         %[[W:.*]]:3 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<64x2xf32>, vector<16x1xf32>, vector<64x2xf32>) {
 // CHECK-NEXT:      %[[S:.*]] = "some_def"() : () -> vector<16x16xf32>
 // CHECK-NEXT:      %[[D:.*]] = "some_def"() : () -> vector<64x32xf32>
 // CHECK:           gpu.yield %{{.*}}, %[[S]], %[[D]] : vector<64x32xf32>, vector<16x16xf32>, vector<64x32xf32>

>From c1e9eb4f0ed1cdb9e940a6e1150fa8f3691f0465 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Tue, 18 Nov 2025 23:43:05 +0000
Subject: [PATCH 07/14] save work

---
 .../Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp  | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 71df8d4fcbf7d..7ecbd9226b43f 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -1485,6 +1485,10 @@ struct VectorShapeCastDistribution : public gpu::WarpDistributionPattern {
   }
 };
 
+/// Distribute a `vector.extract_strided_slice` op feeding into yield op of an
+/// enclosing `gpu.warp_execute_on_lane_0` region. This pattern only handles
+/// advanced cases where the distributed dimension is partially extracted and
+/// currently not supported by the generic vector distribution patterns.
 struct VectorExtractStridedSliceDistribution
     : public gpu::WarpDistributionPattern {
   using gpu::WarpDistributionPattern::WarpDistributionPattern;
@@ -1592,6 +1596,10 @@ struct VectorExtractStridedSliceDistribution
   }
 };
 
+/// Distribute a `vector.insert_strided_slice` op feeding into yield op of an
+/// enclosing `gpu.warp_execute_on_lane_0` region. This pattern only handles
+/// advanced cases where the distributed dimension is partially inserted and
+/// currently not supported by the generic vector distribution patterns.
 struct VectorInsertStridedSliceDistribution
     : public gpu::WarpDistributionPattern {
   using gpu::WarpDistributionPattern::WarpDistributionPattern;

>From 8975d6a46ccc47bc017d2a3844e4b6ae43a5da51 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Tue, 18 Nov 2025 23:44:24 +0000
Subject: [PATCH 08/14] save work

---
 mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 7ecbd9226b43f..8e261a8f88d2b 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -1872,6 +1872,8 @@ void xegpu::populateXeGPUSubgroupDistributePatterns(
                MemrefExtractAlignedPointerAsIndexDistribution>(
       patterns.getContext(),
       /*pattern benefit=*/regularPatternBenefit);
+  // The following patterns must override the regular vector distribution
+  // patterns, so they are assigned a higher benefit.
   patterns
       .add<VectorShapeCastDistribution, VectorExtractStridedSliceDistribution,
            VectorInsertStridedSliceDistribution>(

>From 2324fd3e5f8472ac0e52e02c70bc58ce324a2529 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Wed, 19 Nov 2025 18:32:34 +0000
Subject: [PATCH 09/14] add negative cases

---
 .../Transforms/XeGPUSubgroupDistribute.cpp    | 10 +-
 .../XeGPU/subgroup-distribute-unit.mlir       | 97 +++++++++++++------
 2 files changed, 72 insertions(+), 35 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 8e261a8f88d2b..62904f13d61c8 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -1665,13 +1665,11 @@ struct VectorInsertStridedSliceDistribution
         !llvm::all_of(sourceLaneData, [](int64_t v) { return v == 1; }))
       return rewriter.notifyMatchFailure(
           warpOp, "Expecting unit lane data in source and dest layouts");
-    // Distributed dim sizes must be multiples of subgroup size.
-    if (destDistrDimSize % subgroupSize != 0 ||
-        srcDistrDimSize % subgroupSize != 0)
+    // Source distributed dim size must be a multiple of subgroup size.
+    if (srcDistrDimSize % subgroupSize != 0)
       return rewriter.notifyMatchFailure(
-          warpOp,
-          "Distributed dimension size in source or dest is not a multiple of "
-          "subgroup size.");
+          warpOp, "Distributed dimension size in source is not a multiple of "
+                  "subgroup size.");
     // Offsets in the distributed dimension must be multiples of subgroup size.
     int64_t destDistrDimOffset =
         cast<IntegerAttr>(insertOp.getOffsets()[destDistributedDim]).getInt();
diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
index 4575a981e2986..93d9b1ea9904a 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
@@ -651,7 +651,6 @@ gpu.func @vector_extract_strided_slice_inner_distributed(%laneid: index) {
   gpu.return
 }
 
-
 // CHECK-LABEL: gpu.func @vector_extract_strided_slice_1d
 // CHECK:         %[[W:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<2xf32>, vector<4xf32>) {
 // CHECK:           %[[S:.*]] = "some_def"() : () -> vector<64xf32>
@@ -674,6 +673,42 @@ gpu.func @vector_extract_strided_slice_1d(%laneid: index) {
   gpu.return
 }
 
+// CHECK-LABEL: gpu.func @vector_extract_strided_slice_unsopported_offset
+// CHECK:         %{{.*}} = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<2xf32>) {
+// CHECK:         }
+// CHECK-NOT:     %{{.*}} = vector.extract_strided_slice
+gpu.func @vector_extract_strided_slice_unsopported_offset(%laneid: index) {
+  %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<2xf32>) {
+    %0 = "some_def"() : () -> (vector<64xf32>)
+    %1 = vector.extract_strided_slice %0 { offsets = [3], sizes = [32], strides = [1],
+        layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
+        layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>
+      }
+      : vector<64xf32> to vector<32xf32>
+    gpu.yield %1 : vector<32xf32>
+  }
+  "some_use"(%r) : (vector<2xf32>) -> ()
+  gpu.return
+}
+
+// CHECK-LABEL: gpu.func @vector_extract_strided_slice_unsopported_source
+// CHECK:         %{{.*}} = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<2xf32>) {
+// CHECK:         }
+// CHECK-NOT:     %{{.*}} = vector.extract_strided_slice
+gpu.func @vector_extract_strided_slice_unsopported_source(%laneid: index) {
+  %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<2xf32>) {
+    %0 = "some_def"() : () -> (vector<54xf32>)
+    %1 = vector.extract_strided_slice %0 { offsets = [0], sizes = [32], strides = [1],
+        layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
+        layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>
+      }
+      : vector<54xf32> to vector<32xf32>
+    gpu.yield %1 : vector<32xf32>
+  }
+  "some_use"(%r) : (vector<2xf32>) -> ()
+  gpu.return
+}
+
 // CHECK-LABEL: gpu.func @vector_insert_strided_slice_inner_distributed
 // CHECK:         %[[W:.*]]:3 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<64x2xf32>, vector<16x1xf32>, vector<64x2xf32>) {
 // CHECK-NEXT:      %[[S:.*]] = "some_def"() : () -> vector<16x16xf32>
@@ -749,39 +784,43 @@ gpu.func @vector_insert_strided_slice_1d(%laneid: index) {
   gpu.return
 }
 
-// CHECK-LABEL: gpu.func @vector_extract_strided_slice_unsopported_offset
-// CHECK:         %{{.*}} = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<2xf32>) {
-// CHECK:         }
-// CHECK-NOT:     %{{.*}} = vector.extract_strided_slice
-gpu.func @vector_extract_strided_slice_unsopported_offset(%laneid: index) {
-  %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<2xf32>) {
-    %0 = "some_def"() : () -> (vector<64xf32>)
-    %1 = vector.extract_strided_slice %0 { offsets = [3], sizes = [32], strides = [1],
-        layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
-        layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>
-      }
-      : vector<64xf32> to vector<32xf32>
-    gpu.yield %1 : vector<32xf32>
+// CHECK-LABEL:  gpu.func @vector_insert_strided_slice_unsupported_source
+// CHECK:          %{{.*}} = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<3xf32>) {
+// CHECK:          }
+// CHECK-NOT:      %{{.*}} = vector.insert_strided_slice
+gpu.func @vector_insert_strided_slice_unsupported_source(%laneid: index) {
+  %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<3xf32>) {
+    %0 = "some_def"() : () -> (vector<8xf32>)
+    %1 = "some_def"() : () -> (vector<48xf32>)
+    %2 = vector.insert_strided_slice %0, %1 { offsets = [16],  strides = [1],
+      layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
+      layout_operand_1 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
+      layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>
+    }
+      : vector<8xf32> into vector<48xf32>
+    gpu.yield %2 : vector<48xf32>
   }
-  "some_use"(%r) : (vector<2xf32>) -> ()
+  "some_use"(%r) : (vector<3xf32>) -> ()
   gpu.return
 }
 
-// CHECK-LABEL: gpu.func @vector_extract_strided_slice_unsopported_source
-// CHECK:         %{{.*}} = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<2xf32>) {
-// CHECK:         }
-// CHECK-NOT:     %{{.*}} = vector.extract_strided_slice
-gpu.func @vector_extract_strided_slice_unsopported_source(%laneid: index) {
-  %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<2xf32>) {
-    %0 = "some_def"() : () -> (vector<54xf32>)
-    %1 = vector.extract_strided_slice %0 { offsets = [0], sizes = [32], strides = [1],
-        layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
-        layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>
-      }
-      : vector<54xf32> to vector<32xf32>
-    gpu.yield %1 : vector<32xf32>
+// CHECK-LABEL:  gpu.func @vector_insert_strided_slice_unsupported_offset
+// CHECK:          %{{.*}} = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<3xf32>) {
+// CHECK:          }
+// CHECK-NOT:      %{{.*}} = vector.insert_strided_slice
+gpu.func @vector_insert_strided_slice_unsupported_offset(%laneid: index) {
+  %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<3xf32>) {
+    %0 = "some_def"() : () -> (vector<16xf32>)
+    %1 = "some_def"() : () -> (vector<48xf32>)
+    %2 = vector.insert_strided_slice %0, %1 { offsets = [3],  strides = [1],
+      layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
+      layout_operand_1 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
+      layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>
+    }
+      : vector<16xf32> into vector<48xf32>
+    gpu.yield %2 : vector<48xf32>
   }
-  "some_use"(%r) : (vector<2xf32>) -> ()
+  "some_use"(%r) : (vector<3xf32>) -> ()
   gpu.return
 }
 

>From 5d33f841cffbd20bd5170a1aa9c352899f680903 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Thu, 20 Nov 2025 22:42:07 +0000
Subject: [PATCH 10/14] feedback

---
 .../XeGPU/Transforms/XeGPUSubgroupDistribute.cpp       | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 62904f13d61c8..b6c3e6b2d43ab 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -1504,8 +1504,9 @@ struct VectorExtractStridedSliceDistribution
     auto distributedType =
         cast<VectorType>(warpOp.getResult(operandIdx).getType());
     // Find the distributed dimension. There should be exactly one.
-    auto yieldedType = cast<VectorType>(operand->get().getType());
-    auto distributedDims = getDistributedDims(yieldedType, distributedType);
+    auto extractResultType = cast<VectorType>(operand->get().getType());
+    auto distributedDims =
+        getDistributedDims(extractResultType, distributedType);
     // Only single dimension distribution is supported.
     if (distributedDims.size() != 1)
       return rewriter.notifyMatchFailure(
@@ -1616,8 +1617,9 @@ struct VectorInsertStridedSliceDistribution
         cast<VectorType>(warpOp.getResult(operandNumber).getType());
     // Find the distributed dimension of the dest vector. There should be
     // exactly one.
-    auto yieldedType = cast<VectorType>(operand->get().getType());
-    auto destDistributedDims = getDistributedDims(yieldedType, distributedType);
+    auto insertResultType = cast<VectorType>(operand->get().getType());
+    auto destDistributedDims =
+        getDistributedDims(insertResultType, distributedType);
     // Only single dimension distribution is supported.
     if (destDistributedDims.size() != 1)
       return rewriter.notifyMatchFailure(

>From 85e0c4e432040b01a87b6c944d597f749c6be07a Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Mon, 24 Nov 2025 21:21:10 +0000
Subject: [PATCH 11/14] handle simple cases

---
 .../Transforms/XeGPUSubgroupDistribute.cpp    | 23 ------
 .../XeGPU/subgroup-distribute-unit.mlir       | 82 +++++++++++++++----
 2 files changed, 65 insertions(+), 40 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 5a91983bf20e3..860ad29ba5dad 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -1510,24 +1510,8 @@ struct VectorExtractStridedSliceDistribution
       return rewriter.notifyMatchFailure(
           warpOp, "Expecting source to be distributed in a single dimension.");
     int64_t distributedDim = distributedDims[0];
-    // Check if the distributed dimension is fully extracted. If so, we exit
-    // early becuase this case already handled by vector distribution patterns.
-    // Distributed dimension is fully extracted if:
-    //  1) Distributed dim comes after all the extracted dimensions.
-    //  2) Or, the size extacted along the distributed dimension is equal the
-    //  size of that dim in source vector.
-    auto extractedSizes = extractOp.getSizes();
-    if (distributedDim >= static_cast<int64_t>(extractedSizes.size()))
-      return rewriter.notifyMatchFailure(
-          warpOp, "Distributed dimension is fully extracted, skipping.");
-
-    int distrDimExtractedSize =
-        cast<IntegerAttr>(extractOp.getSizes()[distributedDim]).getInt();
     int sourceDistrDimSize =
         extractOp.getSourceVectorType().getShape()[distributedDim];
-    if (distrDimExtractedSize == sourceDistrDimSize)
-      return rewriter.notifyMatchFailure(
-          warpOp, "Distributed dimension is fully extracted, skipping.");
 
     auto sourceLayout =
         xegpu::getDistributeLayoutAttr(extractOp->getOpOperand(0));
@@ -1635,14 +1619,7 @@ struct VectorInsertStridedSliceDistribution
       return rewriter.notifyMatchFailure(
           insertOp, "distributed dimension must be in the last k (i.e. source "
                     "rank) dims of dest vector");
-    // If the distributed dimension is fully inserted, skip. This case is
-    // already handled by vector distribution patterns.
-    int64_t destDistrDimSize = destType.getDimSize(destDistributedDim);
     int64_t srcDistrDimSize = srcType.getDimSize(sourceDistributedDim);
-    if (srcDistrDimSize == destDistrDimSize)
-      return rewriter.notifyMatchFailure(
-          insertOp, "distributed dimension is fully inserted. This case "
-                    "is handled by vector distribution.");
     // Obtain the source and dest layouts.
     auto destLayout = xegpu::getDistributeLayoutAttr(insertOp->getOpOperand(1));
     auto sourceLayout =
diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
index 93d9b1ea9904a..fd969634cd544 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
@@ -607,25 +607,25 @@ gpu.func @vector_shapecast_unsupported(%laneid: index) {
 }
 
 
-// CHECK-LABEL:  gpu.func @vector_extract_strided_slice_outer_distributed
-// CHECK:          %[[W:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<1x16xf32>, vector<2x16xf32>) {
-// CHECK-NEXT:       %[[S:.*]] = "some_def"() : () -> vector<32x16xf32>
-// CHECK:            gpu.yield %{{.*}}, %[[S]] : vector<16x16xf32>, vector<32x16xf32>
-// CHECK:          }
-// CHECK-NEXT:     %[[T1:.*]] = vector.extract %[[W]]#1[1] : vector<16xf32> from vector<2x16xf32>
-// CHECK-NEXT:     %[[T2:.*]] = vector.shape_cast %[[T1]] : vector<16xf32> to vector<1x16xf32>
-// CHECK-NEXT:     "some_use"(%[[T2]]) : (vector<1x16xf32>) -> ()
-gpu.func @vector_extract_strided_slice_outer_distributed(%laneid: index) {
-  %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1x16xf32>) {
-    %0 = "some_def"() : () -> (vector<32x16xf32>)
-    %1 = vector.extract_strided_slice %0 { offsets = [16], sizes = [16], strides = [1],
-        layout_operand_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>,
-        layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>
+// CHECK-LABEL:  gpu.func @vector_extract_strided_slice_distributed_dim_fully_extracted
+// CHECK-NEXT:     %[[W:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<8x1xf32>, vector<24x1xf32>) {
+// CHECK-NEXT:       %[[S:.*]] = "some_def"() : () -> vector<24x16xf32>
+// CHECK:            gpu.yield %{{.*}}, %[[S]] : vector<8x16xf32>, vector<24x16xf32>
+// CHECK-NEXT:     }
+// CHECK-NEXT:     %[[T1:.*]] = vector.extract_strided_slice %[[W]]#1
+// CHECK-SAME:        {offsets = [8, 0], sizes = [8, 1], strides = [1, 1]} : vector<24x1xf32> to vector<8x1xf32>
+// CHECK-NEXT:     "some_use"(%[[T1]]) : (vector<8x1xf32>) -> ()
+gpu.func @vector_extract_strided_slice_distributed_dim_fully_extracted(%laneid: index) {
+  %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<8x1xf32>) {
+    %0 = "some_def"() : () -> (vector<24x16xf32>)
+    %1 = vector.extract_strided_slice %0 { offsets = [8, 0], sizes = [8, 16], strides = [1, 1],
+        layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+        layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
       }
-      : vector<32x16xf32> to vector<16x16xf32>
-    gpu.yield %1 : vector<16x16xf32>
+      : vector<24x16xf32> to vector<8x16xf32>
+    gpu.yield %1 : vector<8x16xf32>
   }
-  "some_use"(%r) : (vector<1x16xf32>) -> ()
+  "some_use"(%r) : (vector<8x1xf32>) -> ()
   gpu.return
 }
 
@@ -651,6 +651,28 @@ gpu.func @vector_extract_strided_slice_inner_distributed(%laneid: index) {
   gpu.return
 }
 
+// CHECK-LABEL:  gpu.func @vector_extract_strided_slice_outer_distributed
+// CHECK:          %[[W:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<1x16xf32>, vector<2x16xf32>) {
+// CHECK-NEXT:       %[[S:.*]] = "some_def"() : () -> vector<32x16xf32>
+// CHECK:            gpu.yield %{{.*}}, %[[S]] : vector<16x16xf32>, vector<32x16xf32>
+// CHECK:          }
+// CHECK-NEXT:     %[[T1:.*]] = vector.extract %[[W]]#1[1] : vector<16xf32> from vector<2x16xf32>
+// CHECK-NEXT:     %[[T2:.*]] = vector.shape_cast %[[T1]] : vector<16xf32> to vector<1x16xf32>
+// CHECK-NEXT:     "some_use"(%[[T2]]) : (vector<1x16xf32>) -> ()
+gpu.func @vector_extract_strided_slice_outer_distributed(%laneid: index) {
+  %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1x16xf32>) {
+    %0 = "some_def"() : () -> (vector<32x16xf32>)
+    %1 = vector.extract_strided_slice %0 { offsets = [16], sizes = [16], strides = [1],
+        layout_operand_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>,
+        layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>
+      }
+      : vector<32x16xf32> to vector<16x16xf32>
+    gpu.yield %1 : vector<16x16xf32>
+  }
+  "some_use"(%r) : (vector<1x16xf32>) -> ()
+  gpu.return
+}
+
 // CHECK-LABEL: gpu.func @vector_extract_strided_slice_1d
 // CHECK:         %[[W:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<2xf32>, vector<4xf32>) {
 // CHECK:           %[[S:.*]] = "some_def"() : () -> vector<64xf32>
@@ -709,6 +731,32 @@ gpu.func @vector_extract_strided_slice_unsopported_source(%laneid: index) {
   gpu.return
 }
 
+
+// CHECK-LABEL:  gpu.func @vector_insert_strided_slice_distributed_dim_fully_inserted
+// CHECK-NEXT:      %[[W:.*]]:3 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<64x1xf32>, vector<16x1xf32>, vector<64x1xf32>) {
+// CHECK-NEXT:        %[[S:.*]] = "some_def"() : () -> vector<16x16xf32>
+// CHECK-NEXT:        %[[D:.*]] = "some_def"() : () -> vector<64x16xf32>
+// CHECK:             gpu.yield %{{.*}}, %[[S]], %[[D]] : vector<64x16xf32>, vector<16x16xf32>, vector<64x16xf32>
+// CHECK-NEXT:      }
+// CHECK-NEXT:      %[[T1:.*]] = vector.insert_strided_slice %[[W]]#1, %[[W]]#2
+// CHECK-SAME:        {offsets = [24, 0], strides = [1, 1]} : vector<16x1xf32> into vector<64x1xf32>
+// CHECK-NEXT:      "some_use"(%[[T1]]) : (vector<64x1xf32>) -> ()
+gpu.func @vector_insert_strided_slice_distributed_dim_fully_inserted(%laneid: index) {
+  %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<64x1xf32>) {
+    %0 = "some_def"() : () -> (vector<16x16xf32>)
+    %1 = "some_def"() : () -> (vector<64x16xf32>)
+    %2 = vector.insert_strided_slice %0, %1 { offsets = [24, 0],  strides = [1, 1],
+      layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+      layout_operand_1 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+      layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+    }
+      : vector<16x16xf32> into vector<64x16xf32>
+    gpu.yield %2 : vector<64x16xf32>
+  }
+  "some_use"(%r) : (vector<64x1xf32>) -> ()
+  gpu.return
+}
+
 // CHECK-LABEL: gpu.func @vector_insert_strided_slice_inner_distributed
 // CHECK:         %[[W:.*]]:3 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<64x2xf32>, vector<16x1xf32>, vector<64x2xf32>) {
 // CHECK-NEXT:      %[[S:.*]] = "some_def"() : () -> vector<16x16xf32>

>From 6a0db8851796647ba7c09a9545e4bb66d42a0f48 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Mon, 24 Nov 2025 22:39:29 +0000
Subject: [PATCH 12/14] handle simple cases

---
 .../Transforms/XeGPUSubgroupDistribute.cpp    | 272 +++++++++---------
 .../XeGPU/subgroup-distribute-unit.mlir       |  48 ++++
 2 files changed, 191 insertions(+), 129 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 860ad29ba5dad..1e01e86dba85b 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -1505,68 +1505,72 @@ struct VectorExtractStridedSliceDistribution
     auto extractResultType = cast<VectorType>(operand->get().getType());
     auto distributedDims =
         getDistributedDims(extractResultType, distributedType);
-    // Only single dimension distribution is supported.
-    if (distributedDims.size() != 1)
-      return rewriter.notifyMatchFailure(
-          warpOp, "Expecting source to be distributed in a single dimension.");
-    int64_t distributedDim = distributedDims[0];
-    int sourceDistrDimSize =
-        extractOp.getSourceVectorType().getShape()[distributedDim];
-
-    auto sourceLayout =
-        xegpu::getDistributeLayoutAttr(extractOp->getOpOperand(0));
-    if (!sourceLayout || sourceLayout.getEffectiveLaneLayoutAsInt().empty())
-      return rewriter.notifyMatchFailure(
-          warpOp, "the source of extract_strided_slice op lacks distribution "
-                  "layout");
-    auto sourceLaneLayout = sourceLayout.getEffectiveLaneLayoutAsInt();
-    // Because only single dimension distribution is supported, lane layout size
-    // at the distributed dim must be the subgroup size.
-    int subgroupSize = sourceLaneLayout[distributedDim];
-    // Check if the source size in the distributed dimension is a multiple of
-    // subgroup size.
-    if (sourceDistrDimSize % subgroupSize != 0)
-      return rewriter.notifyMatchFailure(
-          warpOp,
-          "Source size along distributed dimension is not a multiple of "
-          "subgroup size.");
-    auto sourceLaneData = sourceLayout.getEffectiveLaneDataAsInt();
-    // We expect lane data to be all ones in this case.
-    if (!llvm::all_of(sourceLaneData, [](int64_t v) { return v == 1; }))
-      return rewriter.notifyMatchFailure(
-          warpOp, "Expecting unit lane data in source layout");
-    // The offsets in the distributed dimention must be a multiple of subgroup
-    // size.
-    int64_t distrDimOffset =
-        cast<IntegerAttr>(extractOp.getOffsets()[distributedDim]).getInt();
-    if (distrDimOffset % subgroupSize != 0)
-      return rewriter.notifyMatchFailure(warpOp,
-                                         "Offset along distributed dimension "
-                                         "is not a multiple of subgroup size.");
-    // Do the distribution by yielding the source of the extract op from
-    // the warp op and creating a new extract op outside the warp op.
-    VectorType sourceDistType =
-        getDistVecTypeBasedOnLaneLayout(sourceLayout,
-                                        extractOp.getSourceVectorType())
-            .value();
+    // Source distributed type must be adjusted for the distributed case.
+    VectorType sourceDistType = extractOp.getSourceVectorType();
+    // Distributed sizes and offsets must be adjusted for distributed case.
+    SmallVector<Attribute> distributedSizes = llvm::map_to_vector(
+        extractOp.getSizes(), [](Attribute attr) { return attr; });
+    SmallVector<Attribute> distributedOffsets = llvm::map_to_vector(
+        extractOp.getOffsets(), [](Attribute attr) { return attr; });
+    // If the result is distributed, it must be distributed in exactly one
+    // dimension. In this case, we adjust the sourceDistType, distributedSizes
+    // and distributedOffsets accordingly.
+    if (distributedDims.size() > 0) {
+      if (distributedDims.size() != 1)
+        return rewriter.notifyMatchFailure(
+            warpOp, "Source can not be distributed in multiple dimensions.");
+      int64_t distributedDim = distributedDims[0];
+      int sourceDistrDimSize =
+          extractOp.getSourceVectorType().getShape()[distributedDim];
+      auto sourceLayout =
+          xegpu::getDistributeLayoutAttr(extractOp->getOpOperand(0));
+      if (!sourceLayout || sourceLayout.getEffectiveLaneLayoutAsInt().empty())
+        return rewriter.notifyMatchFailure(
+            warpOp, "the source of extract_strided_slice op lacks distribution "
+                    "layout");
+      auto sourceLaneLayout = sourceLayout.getEffectiveLaneLayoutAsInt();
+      // Because only single dimension distribution is supported, lane layout
+      // size at the distributed dim must be the subgroup size.
+      int subgroupSize = sourceLaneLayout[distributedDim];
+      // Check if the source size in the distributed dimension is a multiple of
+      // subgroup size.
+      if (sourceDistrDimSize % subgroupSize != 0)
+        return rewriter.notifyMatchFailure(
+            warpOp,
+            "Source size along distributed dimension is not a multiple of "
+            "subgroup size.");
+      auto sourceLaneData = sourceLayout.getEffectiveLaneDataAsInt();
+      // We expect lane data to be all ones in this case.
+      if (!llvm::all_of(sourceLaneData, [](int64_t v) { return v == 1; }))
+        return rewriter.notifyMatchFailure(
+            warpOp, "Expecting unit lane data in source layout");
+      // The offsets in the distributed dimension must be a multiple of subgroup
+      // size.
+      int64_t distrDimOffset =
+          cast<IntegerAttr>(extractOp.getOffsets()[distributedDim]).getInt();
+      if (distrDimOffset % subgroupSize != 0)
+        return rewriter.notifyMatchFailure(
+            warpOp, "Offset along distributed dimension "
+                    "is not a multiple of subgroup size.");
+      // Do the distribution by yielding the source of the extract op from
+      // the warp op and creating a new extract op outside the warp op.
+      sourceDistType = getDistVecTypeBasedOnLaneLayout(
+                           sourceLayout, extractOp.getSourceVectorType())
+                           .value();
+      // Update the distributed sizes to match the distributed type.
+      distributedSizes[distributedDim] = rewriter.getI64IntegerAttr(
+          distributedType.getDimSize(distributedDim));
+      // Update the distributed offsets to match round robin distribution (i.e.
+      // each lane owns data at `subgroupSize` stride given unit lane data).
+      distributedOffsets[distributedDim] =
+          rewriter.getI64IntegerAttr(distrDimOffset / subgroupSize);
+    }
     // Create a new warp op that yields the source of the extract op.
     SmallVector<size_t> newRetIndices;
     auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
         rewriter, warpOp, {extractOp.getSource()}, {sourceDistType},
         newRetIndices);
     rewriter.setInsertionPointAfter(newWarpOp);
-    // Distributed sizes and offsets must be adjusted.
-    SmallVector<Attribute> distributedSizes = llvm::map_to_vector(
-        extractOp.getSizes(), [](Attribute attr) { return attr; });
-    SmallVector<Attribute> distributedOffsets = llvm::map_to_vector(
-        extractOp.getOffsets(), [](Attribute attr) { return attr; });
-    // Update the distributed sizes to match the distributed type.
-    distributedSizes[distributedDim] =
-        rewriter.getI64IntegerAttr(distributedType.getDimSize(distributedDim));
-    // Update the distributed offsets to match round robin distribution (i.e.
-    // each lane owns data at `subgroupSize` stride given unit lane data).
-    distributedOffsets[distributedDim] =
-        rewriter.getI64IntegerAttr(distrDimOffset / subgroupSize);
     Value source = newWarpOp.getResult(newRetIndices[0]);
     // Create a new extract op outside the warp op.
     Value newExtractOp = vector::ExtractStridedSliceOp::create(
@@ -1602,87 +1606,97 @@ struct VectorInsertStridedSliceDistribution
     auto insertResultType = cast<VectorType>(operand->get().getType());
     auto destDistributedDims =
         getDistributedDims(insertResultType, distributedType);
-    // Only single dimension distribution is supported.
-    if (destDistributedDims.size() != 1)
-      return rewriter.notifyMatchFailure(
-          warpOp, "Expecting source to be distributed in a single dimension.");
-    int64_t destDistributedDim = destDistributedDims[0];
-
-    VectorType srcType = insertOp.getSourceVectorType();
-    VectorType destType = insertOp.getDestVectorType();
-    // Currently we require that both source (kD) and dest (nD) vectors are
-    // distributed. This requires that distributedDim (d) is contained in the
-    // last k dims of the dest vector (d >= n - k).
-    int64_t sourceDistributedDim =
-        destDistributedDim - (destType.getRank() - srcType.getRank());
-    if (sourceDistributedDim < 0)
-      return rewriter.notifyMatchFailure(
-          insertOp, "distributed dimension must be in the last k (i.e. source "
-                    "rank) dims of dest vector");
-    int64_t srcDistrDimSize = srcType.getDimSize(sourceDistributedDim);
-    // Obtain the source and dest layouts.
-    auto destLayout = xegpu::getDistributeLayoutAttr(insertOp->getOpOperand(1));
-    auto sourceLayout =
-        xegpu::getDistributeLayoutAttr(insertOp->getOpOperand(0));
-    if (!destLayout || !sourceLayout ||
-        destLayout.getEffectiveLaneLayoutAsInt().empty() ||
-        sourceLayout.getEffectiveLaneLayoutAsInt().empty())
-      return rewriter.notifyMatchFailure(
-          warpOp, "the source or dest of insert_strided_slice op lacks "
-                  "distribution layout");
-    // Because only single dimension distribution is supported, lane layout
-    // size at the distributed dim must be the subgroup size.
-    int subgroupSize =
-        destLayout.getEffectiveLaneLayoutAsInt()[destDistributedDim];
-    // We require that source and dest lane data are all ones to ensure uniform
-    // round robin distribution.
-    auto destLaneData = destLayout.getEffectiveLaneDataAsInt();
-    auto sourceLaneData = sourceLayout.getEffectiveLaneDataAsInt();
-    if (!llvm::all_of(destLaneData, [](int64_t v) { return v == 1; }) ||
-        !llvm::all_of(sourceLaneData, [](int64_t v) { return v == 1; }))
-      return rewriter.notifyMatchFailure(
-          warpOp, "Expecting unit lane data in source and dest layouts");
-    // Source distributed dim size must be multiples of subgroup size.
-    if (srcDistrDimSize % subgroupSize != 0)
-      return rewriter.notifyMatchFailure(
-          warpOp, "Distributed dimension size in source is not a multiple of "
-                  "subgroup size.");
-    // Offsets in the distributed dimension must be multiples of subgroup size.
-    int64_t destDistrDimOffset =
-        cast<IntegerAttr>(insertOp.getOffsets()[destDistributedDim]).getInt();
-    if (destDistrDimOffset % subgroupSize != 0)
-      return rewriter.notifyMatchFailure(
-          warpOp,
-          "Offset along distributed dimension in dest is not a multiple of "
-          "subgroup size.");
-    // Do the distribution by yielding the source and dest of the insert op from
-    // the warp op and creating a new insert op outside the warp op.
-    VectorType sourceDistType =
-        getDistVecTypeBasedOnLaneLayout(sourceLayout,
-                                        insertOp.getSourceVectorType())
-            .value();
-    VectorType destDistType = getDistVecTypeBasedOnLaneLayout(
-                                  destLayout, insertOp.getDestVectorType())
-                                  .value();
-    // Create a new warp op that yields the source and dest of the insert op.
+    // Collect updated offsets, source type and dest type. They may be updated
+    // later if the data is distributed to lanes (as opposed to being owned by
+    // all lanes uniformly).
+    SmallVector<Attribute> updatedOffsets = llvm::map_to_vector(
+        insertOp.getOffsets(), [](Attribute attr) { return attr; });
+    VectorType updatedSourceType = insertOp.getSourceVectorType();
+    VectorType updatedDestType = insertOp.getDestVectorType();
+    if (destDistributedDims.size() > 0) {
+      // Only single dimension distribution is supported.
+      if (destDistributedDims.size() != 1)
+        return rewriter.notifyMatchFailure(
+            warpOp,
+            "Expecting source to be distributed in a single dimension.");
+      int64_t destDistributedDim = destDistributedDims[0];
+
+      VectorType srcType = insertOp.getSourceVectorType();
+      VectorType destType = insertOp.getDestVectorType();
+      // Currently we require that both source (kD) and dest (nD) vectors are
+      // distributed. This requires that distributedDim (d) is contained in the
+      // last k dims of the dest vector (d >= n - k).
+      int64_t sourceDistributedDim =
+          destDistributedDim - (destType.getRank() - srcType.getRank());
+      if (sourceDistributedDim < 0)
+        return rewriter.notifyMatchFailure(
+            insertOp,
+            "distributed dimension must be in the last k (i.e. source "
+            "rank) dims of dest vector");
+      int64_t srcDistrDimSize = srcType.getDimSize(sourceDistributedDim);
+      // Obtain the source and dest layouts.
+      auto destLayout =
+          xegpu::getDistributeLayoutAttr(insertOp->getOpOperand(1));
+      auto sourceLayout =
+          xegpu::getDistributeLayoutAttr(insertOp->getOpOperand(0));
+      if (!destLayout || !sourceLayout ||
+          destLayout.getEffectiveLaneLayoutAsInt().empty() ||
+          sourceLayout.getEffectiveLaneLayoutAsInt().empty())
+        return rewriter.notifyMatchFailure(
+            warpOp, "the source or dest of insert_strided_slice op lacks "
+                    "distribution layout");
+      // Because only single dimension distribution is supported, lane layout
+      // size at the distributed dim must be the subgroup size.
+      int subgroupSize =
+          destLayout.getEffectiveLaneLayoutAsInt()[destDistributedDim];
+      // We require that source and dest lane data are all ones to ensure
+      // uniform round robin distribution.
+      auto destLaneData = destLayout.getEffectiveLaneDataAsInt();
+      auto sourceLaneData = sourceLayout.getEffectiveLaneDataAsInt();
+      if (!llvm::all_of(destLaneData, [](int64_t v) { return v == 1; }) ||
+          !llvm::all_of(sourceLaneData, [](int64_t v) { return v == 1; }))
+        return rewriter.notifyMatchFailure(
+            warpOp, "Expecting unit lane data in source and dest layouts");
+      // Source distributed dim size must be multiples of subgroup size.
+      if (srcDistrDimSize % subgroupSize != 0)
+        return rewriter.notifyMatchFailure(
+            warpOp, "Distributed dimension size in source is not a multiple of "
+                    "subgroup size.");
+      // Offsets in the distributed dimension must be multiples of subgroup
+      // size.
+      int64_t destDistrDimOffset =
+          cast<IntegerAttr>(insertOp.getOffsets()[destDistributedDim]).getInt();
+      if (destDistrDimOffset % subgroupSize != 0)
+        return rewriter.notifyMatchFailure(
+            warpOp,
+            "Offset along distributed dimension in dest is not a multiple of "
+            "subgroup size.");
+      // Update the source and dest types based on their layouts.
+      updatedSourceType = getDistVecTypeBasedOnLaneLayout(
+                              sourceLayout, insertOp.getSourceVectorType())
+                              .value();
+      updatedDestType = getDistVecTypeBasedOnLaneLayout(
+                            destLayout, insertOp.getDestVectorType())
+                            .value();
+      // Update the distributed offsets to match round robin distribution (i.e.
+      // each lane owns data at `subgroupSize` stride given unit lane data).
+      updatedOffsets[destDistributedDim] =
+          rewriter.getI64IntegerAttr(destDistrDimOffset / subgroupSize);
+    }
+    // Do the distribution by yielding the source and dest of the insert op
+    // from the warp op and creating a new insert op outside the warp op.
     SmallVector<size_t> newRetIndices;
     auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
         rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
-        {sourceDistType, destDistType}, newRetIndices);
+        {updatedSourceType, updatedDestType}, newRetIndices);
     rewriter.setInsertionPointAfter(newWarpOp);
-    // Distributed offsets must be adjusted.
-    SmallVector<Attribute> distributedOffsets = llvm::map_to_vector(
-        insertOp.getOffsets(), [](Attribute attr) { return attr; });
-    // Update the distributed offsets to match round robin distribution (i.e.
-    // each lane owns data at `subgroupSize` stride given unit lane data).
-    distributedOffsets[destDistributedDim] =
-        rewriter.getI64IntegerAttr(destDistrDimOffset / subgroupSize);
+
     Value valueToStore = newWarpOp.getResult(newRetIndices[0]);
     Value dest = newWarpOp.getResult(newRetIndices[1]);
     // Create a new insert op outside the warp op.
     Value newInsertOp = vector::InsertStridedSliceOp::create(
-        rewriter, insertOp.getLoc(), destDistType, valueToStore, dest,
-        ArrayAttr::get(rewriter.getContext(), distributedOffsets),
+        rewriter, insertOp.getLoc(), updatedDestType, valueToStore, dest,
+        ArrayAttr::get(rewriter.getContext(), updatedOffsets),
         insertOp.getStrides());
     rewriter.replaceAllUsesWith(newWarpOp.getResult(operandNumber),
                                 newInsertOp);
diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
index fd969634cd544..44ec21359593f 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
@@ -629,6 +629,28 @@ gpu.func @vector_extract_strided_slice_distributed_dim_fully_extracted(%laneid:
   gpu.return
 }
 
+// CHECK-LABEL: gpu.func @vector_extract_strided_slice_non_distributed
+// CHECK-NEXT:    %[[W:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<8x1xf32>, vector<24x1xf32>) {
+// CHECK-NEXT:      %[[S:.*]] = "some_def"() : () -> vector<24x1xf32>
+// CHECK:           gpu.yield %{{.*}}, %[[S]] : vector<8x1xf32>, vector<24x1xf32>
+// CHECK-NEXT:    }
+// CHECK-NEXT:    %[[T1:.*]] = vector.extract_strided_slice %[[W]]#1
+// CHECK-SAME:      {offsets = [8, 0], sizes = [8, 1], strides = [1, 1]} : vector<24x1xf32> to vector<8x1xf32>
+// CHECK-NEXT:    "some_use"(%[[T1]]) : (vector<8x1xf32>) -> ()
+gpu.func @vector_extract_strided_slice_non_distributed(%laneid: index) {
+  %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<8x1xf32>) {
+    %0 = "some_def"() : () -> (vector<24x1xf32>)
+    %1 = vector.extract_strided_slice %0 { offsets = [8, 0], sizes = [8, 1], strides = [1, 1],
+        layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+        layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+      }
+      : vector<24x1xf32> to vector<8x1xf32>
+    gpu.yield %1 : vector<8x1xf32>
+  }
+  "some_use"(%r) : (vector<8x1xf32>) -> ()
+  gpu.return
+}
+
 // CHECK-LABEL: gpu.func @vector_extract_strided_slice_inner_distributed
 // CHECK:         %[[W:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<8x1xf32>, vector<24x4xf32>) {
 // CHECK-NEXT:      %[[S:.*]] = "some_def"() : () -> vector<24x64xf32>
@@ -757,6 +779,32 @@ gpu.func @vector_insert_strided_slice_distributed_dim_fully_inserted(%laneid: in
   gpu.return
 }
 
+
+// CHECK-LABEL: gpu.func @vector_insert_strided_slice_non_distributed
+// CHECK-NEXT:    %[[W:.*]]:3 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<64x1xf32>, vector<16x1xf32>, vector<64x1xf32>) {
+// CHECK-NEXT:      %[[S:.*]] = "some_def"() : () -> vector<16x1xf32>
+// CHECK-NEXT:      %[[D:.*]] = "some_def"() : () -> vector<64x1xf32>
+// CHECK:           gpu.yield %{{.*}}, %[[S]], %[[D]] : vector<64x1xf32>, vector<16x1xf32>, vector<64x1xf32>
+// CHECK-NEXT:    }
+// CHECK-NEXT:    %[[T1:.*]] = vector.insert_strided_slice %[[W]]#1, %[[W]]#2
+// CHECK-SAME:      {offsets = [24, 0], strides = [1, 1]} : vector<16x1xf32> into vector<64x1xf32>
+// CHECK-NEXT:    "some_use"(%[[T1]]) : (vector<64x1xf32>) -> ()
+gpu.func @vector_insert_strided_slice_non_distributed(%laneid: index) {
+  %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<64x1xf32>) {
+    %0 = "some_def"() : () -> (vector<16x1xf32>)
+    %1 = "some_def"() : () -> (vector<64x1xf32>)
+    %2 = vector.insert_strided_slice %0, %1 { offsets = [24, 0],  strides = [1, 1],
+      layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+      layout_operand_1 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+      layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+    }
+      : vector<16x1xf32> into vector<64x1xf32>
+    gpu.yield %2 : vector<64x1xf32>
+  }
+  "some_use"(%r) : (vector<64x1xf32>) -> ()
+  gpu.return
+}
+
 // CHECK-LABEL: gpu.func @vector_insert_strided_slice_inner_distributed
 // CHECK:         %[[W:.*]]:3 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<64x2xf32>, vector<16x1xf32>, vector<64x2xf32>) {
 // CHECK-NEXT:      %[[S:.*]] = "some_def"() : () -> vector<16x16xf32>

>From 4dfbe7c2cabc5673c925eebe471ec5dea5cb446f Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Mon, 24 Nov 2025 22:41:16 +0000
Subject: [PATCH 13/14] add comment

---
 mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 1e01e86dba85b..448fb48d23879 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -175,6 +175,8 @@ static bool requireTranspose(const xegpu::LayoutAttr layout,
   return laneLayout[0] == uArch->getSubgroupSize() && laneLayout[1] == 1;
 }
 
+/// Given a sequential and distributed vector type, return the list of
+/// dimensions that are distributed.
 static SmallVector<int64_t> getDistributedDims(VectorType sequentialType,
                                                VectorType distributedType) {
   assert(sequentialType.getRank() == distributedType.getRank() &&

>From b96538cc97100092aa927f100248a6a005851d13 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Mon, 24 Nov 2025 22:45:42 +0000
Subject: [PATCH 14/14] add comment

---
 .../Transforms/XeGPUSubgroupDistribute.cpp    | 34 +++++++++----------
 1 file changed, 17 insertions(+), 17 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 448fb48d23879..fe7f10908d986 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -1507,12 +1507,13 @@ struct VectorExtractStridedSliceDistribution
     auto extractResultType = cast<VectorType>(operand->get().getType());
     auto distributedDims =
         getDistributedDims(extractResultType, distributedType);
-    // Source distributed type must be adjusted for the distributed case.
-    VectorType sourceDistType = extractOp.getSourceVectorType();
-    // Distributed sizes and offsets must be adjusted for distributed case.
-    SmallVector<Attribute> distributedSizes = llvm::map_to_vector(
+    // Collect updated source type, sizes and offsets. They may be adjusted
+    // later if the data is distributed to lanes (as opposed to being owned by
+    // all lanes uniformly).
+    VectorType updatedSourceType = extractOp.getSourceVectorType();
+    SmallVector<Attribute> updatedSizes = llvm::map_to_vector(
         extractOp.getSizes(), [](Attribute attr) { return attr; });
-    SmallVector<Attribute> distributedOffsets = llvm::map_to_vector(
+    SmallVector<Attribute> updatedOffsets = llvm::map_to_vector(
         extractOp.getOffsets(), [](Attribute attr) { return attr; });
     // If the result is distributed, it must be distributed in exactly one
     // dimension. In this case, we adjust the sourceDistType, distributedSizes
@@ -1554,31 +1555,30 @@ struct VectorExtractStridedSliceDistribution
         return rewriter.notifyMatchFailure(
             warpOp, "Offset along distributed dimension "
                     "is not a multiple of subgroup size.");
-      // Do the distribution by yielding the source of the extract op from
-      // the warp op and creating a new extract op outside the warp op.
-      sourceDistType = getDistVecTypeBasedOnLaneLayout(
-                           sourceLayout, extractOp.getSourceVectorType())
-                           .value();
+      updatedSourceType = getDistVecTypeBasedOnLaneLayout(
+                              sourceLayout, extractOp.getSourceVectorType())
+                              .value();
       // Update the distributed sizes to match the distributed type.
-      distributedSizes[distributedDim] = rewriter.getI64IntegerAttr(
+      updatedSizes[distributedDim] = rewriter.getI64IntegerAttr(
           distributedType.getDimSize(distributedDim));
       // Update the distributed offsets to match round robin distribution (i.e.
       // each lane owns data at `subgroupSize` stride given unit lane data).
-      distributedOffsets[distributedDim] =
+      updatedOffsets[distributedDim] =
           rewriter.getI64IntegerAttr(distrDimOffset / subgroupSize);
     }
-    // Create a new warp op that yields the source of the extract op.
+    // Do the distribution by yielding the source of the extract op from
+    // the warp op and creating a new extract op outside the warp op.
     SmallVector<size_t> newRetIndices;
     auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, warpOp, {extractOp.getSource()}, {sourceDistType},
+        rewriter, warpOp, {extractOp.getSource()}, {updatedSourceType},
         newRetIndices);
     rewriter.setInsertionPointAfter(newWarpOp);
     Value source = newWarpOp.getResult(newRetIndices[0]);
     // Create a new extract op outside the warp op.
     Value newExtractOp = vector::ExtractStridedSliceOp::create(
         rewriter, extractOp.getLoc(), distributedType, source,
-        ArrayAttr::get(rewriter.getContext(), distributedOffsets),
-        ArrayAttr::get(rewriter.getContext(), distributedSizes),
+        ArrayAttr::get(rewriter.getContext(), updatedOffsets),
+        ArrayAttr::get(rewriter.getContext(), updatedSizes),
         extractOp.getStrides());
     rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIdx), newExtractOp);
     return success();
@@ -1608,7 +1608,7 @@ struct VectorInsertStridedSliceDistribution
     auto insertResultType = cast<VectorType>(operand->get().getType());
     auto destDistributedDims =
         getDistributedDims(insertResultType, distributedType);
-    // Collect updated offsets, source type and dest type. They may be updated
+    // Collect updated offsets, source type and dest type. They may be adjusted
     // later if the data is distributed to lanes (as opposed to being owned by
     // all lanes uniformly).
     SmallVector<Attribute> updatedOffsets = llvm::map_to_vector(



More information about the Mlir-commits mailing list