[Mlir-commits] [mlir] c333f7d - [mlir][xegpu] Add layout based SIMT distribution support for `vector.extract/insert_strided_slice` (#168626)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Nov 26 08:10:47 PST 2025


Author: Charitha Saumya
Date: 2025-11-26T10:10:36-06:00
New Revision: c333f7dab9f89734777f7d19bc7b68c86f393216

URL: https://github.com/llvm/llvm-project/commit/c333f7dab9f89734777f7d19bc7b68c86f393216
DIFF: https://github.com/llvm/llvm-project/commit/c333f7dab9f89734777f7d19bc7b68c86f393216.diff

LOG: [mlir][xegpu] Add layout based SIMT distribution support for `vector.extract/insert_strided_slice` (#168626)

This PR adds general SIMT distribution support for
`vector.extract/insert_strided_slice`. Currently, vector distribution
already has support for these operations but imposes restrictions to avoid
requiring layouts in the distribution logic. For example,
`extract_strided_slice` requires that the distributed dimension is fully
extracted. However, more complex cases may require extracting partially
from the distributed dimension (e.g. 8x16xf16 extraction from 8x32xf16).
These types of cases need the layouts to reason about how the data is
spread across SIMT lanes.

Currently, we don't have layout access in vector distribution, so these
new patterns are placed on the XeGPU side. They have a higher pattern
benefit so that they are tried before the regular vector
distribution based patterns.

Added: 
    

Modified: 
    mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
    mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index b64eb5b29ccb0..0d1c5eeeff711 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -174,6 +174,21 @@ static bool requireTranspose(const xegpu::LayoutAttr layout,
   return laneLayout[0] == uArch->getSubgroupSize() && laneLayout[1] == 1;
 }
 
+/// Given a vector type and its distributed vector type (of the same rank),
+/// return the dimensions whose sizes differ, i.e. the distributed dimensions.
+static SmallVector<int64_t> getDistributedDims(VectorType originalType,
+                                               VectorType distributedType) {
+  assert(originalType.getRank() == distributedType.getRank() &&
+         "sequential and distributed vector types must have the same rank");
+  SmallVector<int64_t> distributedDims;
+  for (int64_t i = 0; i < originalType.getRank(); ++i) {
+    if (distributedType.getDimSize(i) != originalType.getDimSize(i)) {
+      distributedDims.push_back(i);
+    }
+  }
+  return distributedDims;
+}
+
 /// Given a GPUFuncOp, this pattern creates a new GPUFuncOp and moves the body
 /// of the original GPUFuncOp to the new GPUFuncOp such that entire body is
 /// contained within a WarpExecuteOnLane0Op.
@@ -1469,6 +1484,226 @@ struct VectorShapeCastDistribution : public gpu::WarpDistributionPattern {
   }
 };
 
+/// Distribute a `vector.extract_strided_slice` op feeding into yield op of an
+/// enclosing `gpu.warp_execute_on_lane_0` region. This pattern covers
+/// advanced cases where the distributed dimension is partially extracted,
+/// which the generic vector distribution patterns currently do not support.
+struct VectorExtractStridedSliceDistribution
+    : public gpu::WarpDistributionPattern {
+  using gpu::WarpDistributionPattern::WarpDistributionPattern;
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *operand =
+        getWarpResult(warpOp, llvm::IsaPred<vector::ExtractStridedSliceOp>);
+    if (!operand)
+      return failure();
+    auto extractOp =
+        cast<vector::ExtractStridedSliceOp>(operand->get().getDefiningOp());
+    unsigned operandIdx = operand->getOperandNumber();
+    auto distributedType =
+        cast<VectorType>(warpOp.getResult(operandIdx).getType());
+    // Find the distributed dimensions of the extract result.
+    auto extractResultType = cast<VectorType>(operand->get().getType());
+    auto distributedDims =
+        getDistributedDims(extractResultType, distributedType);
+    // Collect updated source type, sizes and offsets. They may be adjusted
+    // later if the data is distributed to lanes (as opposed to being owned by
+    // all lanes uniformly).
+    VectorType updatedSourceType = extractOp.getSourceVectorType();
+    SmallVector<Attribute> updatedSizes = llvm::map_to_vector(
+        extractOp.getSizes(), [](Attribute attr) { return attr; });
+    SmallVector<Attribute> updatedOffsets = llvm::map_to_vector(
+        extractOp.getOffsets(), [](Attribute attr) { return attr; });
+    // If the result is distributed, it must be distributed in exactly one
+    // dimension. In this case, we adjust updatedSourceType, updatedSizes and
+    // updatedOffsets accordingly.
+    if (distributedDims.size() > 0) {
+      if (distributedDims.size() != 1)
+        return rewriter.notifyMatchFailure(
+            warpOp, "Source can not be distributed in multiple dimensions.");
+      int64_t distributedDim = distributedDims[0];
+      int sourceDistrDimSize =
+          extractOp.getSourceVectorType().getShape()[distributedDim];
+      auto sourceLayout =
+          xegpu::getDistributeLayoutAttr(extractOp->getOpOperand(0));
+      if (!sourceLayout || sourceLayout.getEffectiveLaneLayoutAsInt().empty())
+        return rewriter.notifyMatchFailure(
+            warpOp, "the source of extract_strided_slice op lacks distribution "
+                    "layout");
+      auto sourceLaneLayout = sourceLayout.getEffectiveLaneLayoutAsInt();
+      // Because only single dimension distribution is supported, lane layout
+      // size at the distributed dim must be the subgroup size.
+      int subgroupSize = sourceLaneLayout[distributedDim];
+      // Check if the source size in the distributed dimension is a multiple of
+      // subgroup size.
+      if (sourceDistrDimSize % subgroupSize != 0)
+        return rewriter.notifyMatchFailure(
+            warpOp,
+            "Source size along distributed dimension is not a multiple of "
+            "subgroup size.");
+      auto sourceLaneData = sourceLayout.getEffectiveLaneDataAsInt();
+      // We expect lane data to be all ones in this case.
+      if (!llvm::all_of(sourceLaneData, [](int64_t v) { return v == 1; }))
+        return rewriter.notifyMatchFailure(
+            warpOp, "Expecting unit lane data in source layout");
+      // The offset in the distributed dimension must be a multiple of subgroup
+      // size.
+      int64_t distrDimOffset =
+          cast<IntegerAttr>(extractOp.getOffsets()[distributedDim]).getInt();
+      if (distrDimOffset % subgroupSize != 0)
+        return rewriter.notifyMatchFailure(
+            warpOp, "Offset along distributed dimension "
+                    "is not a multiple of subgroup size.");
+      updatedSourceType = getDistVecTypeBasedOnLaneLayout(
+                              sourceLayout, extractOp.getSourceVectorType())
+                              .value();
+      // Update the distributed sizes to match the distributed type.
+      updatedSizes[distributedDim] = rewriter.getI64IntegerAttr(
+          distributedType.getDimSize(distributedDim));
+      // Update the distributed offsets to match round robin distribution (i.e.
+      // each lane owns data at `subgroupSize` stride given unit lane data).
+      updatedOffsets[distributedDim] =
+          rewriter.getI64IntegerAttr(distrDimOffset / subgroupSize);
+    }
+    // Do the distribution by yielding the source of the extract op from
+    // the warp op and creating a new extract op outside the warp op.
+    SmallVector<size_t> newRetIndices;
+    auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, {extractOp.getSource()}, {updatedSourceType},
+        newRetIndices);
+    rewriter.setInsertionPointAfter(newWarpOp);
+    Value source = newWarpOp.getResult(newRetIndices[0]);
+    // Create a new extract op outside the warp op.
+    Value newExtractOp = vector::ExtractStridedSliceOp::create(
+        rewriter, extractOp.getLoc(), distributedType, source,
+        ArrayAttr::get(rewriter.getContext(), updatedOffsets),
+        ArrayAttr::get(rewriter.getContext(), updatedSizes),
+        extractOp.getStrides());
+    rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIdx), newExtractOp);
+    return success();
+  }
+};
+
+/// Distribute a `vector.insert_strided_slice` op feeding into yield op of an
+/// enclosing `gpu.warp_execute_on_lane_0` region. This pattern covers
+/// advanced cases where the distributed dimension is partially inserted,
+/// which the generic vector distribution patterns currently do not support.
+struct VectorInsertStridedSliceDistribution
+    : public gpu::WarpDistributionPattern {
+  using gpu::WarpDistributionPattern::WarpDistributionPattern;
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *operand =
+        getWarpResult(warpOp, llvm::IsaPred<vector::InsertStridedSliceOp>);
+    if (!operand)
+      return failure();
+    unsigned int operandNumber = operand->getOperandNumber();
+    auto insertOp =
+        operand->get().getDefiningOp<vector::InsertStridedSliceOp>();
+    auto distributedType =
+        cast<VectorType>(warpOp.getResult(operandNumber).getType());
+    // Find the distributed dimensions of the dest vector.
+    auto insertResultType = cast<VectorType>(operand->get().getType());
+    auto destDistributedDims =
+        getDistributedDims(insertResultType, distributedType);
+    // Collect updated offsets, source type and dest type. They may be adjusted
+    // later if the data is distributed to lanes (as opposed to being owned by
+    // all lanes uniformly).
+    SmallVector<Attribute> updatedOffsets = llvm::map_to_vector(
+        insertOp.getOffsets(), [](Attribute attr) { return attr; });
+    VectorType updatedSourceType = insertOp.getSourceVectorType();
+    VectorType updatedDestType = insertOp.getDestVectorType();
+    if (destDistributedDims.size() > 0) {
+      // Only single dimension distribution of the dest is supported.
+      if (destDistributedDims.size() != 1)
+        return rewriter.notifyMatchFailure(
+            warpOp,
+            "Expecting source to be distributed in a single dimension.");
+      int64_t destDistributedDim = destDistributedDims[0];
+
+      VectorType srcType = insertOp.getSourceVectorType();
+      VectorType destType = insertOp.getDestVectorType();
+      // Currently we require that both source (kD) and dest (nD) vectors are
+      // distributed. This requires that distributedDim (d) is contained in the
+      // last k dims of the dest vector (d >= n - k).
+      int64_t sourceDistributedDim =
+          destDistributedDim - (destType.getRank() - srcType.getRank());
+      if (sourceDistributedDim < 0)
+        return rewriter.notifyMatchFailure(
+            insertOp,
+            "distributed dimension must be in the last k (i.e. source "
+            "rank) dims of dest vector");
+      int64_t srcDistrDimSize = srcType.getDimSize(sourceDistributedDim);
+      // Obtain the source and dest layouts.
+      auto destLayout =
+          xegpu::getDistributeLayoutAttr(insertOp->getOpOperand(1));
+      auto sourceLayout =
+          xegpu::getDistributeLayoutAttr(insertOp->getOpOperand(0));
+      if (!destLayout || !sourceLayout ||
+          destLayout.getEffectiveLaneLayoutAsInt().empty() ||
+          sourceLayout.getEffectiveLaneLayoutAsInt().empty())
+        return rewriter.notifyMatchFailure(
+            warpOp, "the source or dest of insert_strided_slice op lacks "
+                    "distribution layout");
+      // Because only single dimension distribution is supported, lane layout
+      // size at the distributed dim must be the subgroup size.
+      int subgroupSize =
+          destLayout.getEffectiveLaneLayoutAsInt()[destDistributedDim];
+      // We require that source and dest lane data are all ones to ensure
+      // uniform round robin distribution.
+      auto destLaneData = destLayout.getEffectiveLaneDataAsInt();
+      auto sourceLaneData = sourceLayout.getEffectiveLaneDataAsInt();
+      if (!llvm::all_of(destLaneData, [](int64_t v) { return v == 1; }) ||
+          !llvm::all_of(sourceLaneData, [](int64_t v) { return v == 1; }))
+        return rewriter.notifyMatchFailure(
+            warpOp, "Expecting unit lane data in source and dest layouts");
+      // Source distributed dim size must be a multiple of subgroup size.
+      if (srcDistrDimSize % subgroupSize != 0)
+        return rewriter.notifyMatchFailure(
+            warpOp, "Distributed dimension size in source is not a multiple of "
+                    "subgroup size.");
+      // The dest offset in the distributed dimension must be a multiple of
+      // subgroup size.
+      int64_t destDistrDimOffset =
+          cast<IntegerAttr>(insertOp.getOffsets()[destDistributedDim]).getInt();
+      if (destDistrDimOffset % subgroupSize != 0)
+        return rewriter.notifyMatchFailure(
+            warpOp,
+            "Offset along distributed dimension in dest is not a multiple of "
+            "subgroup size.");
+      // Update the source and dest types based on their layouts.
+      updatedSourceType = getDistVecTypeBasedOnLaneLayout(
+                              sourceLayout, insertOp.getSourceVectorType())
+                              .value();
+      updatedDestType = getDistVecTypeBasedOnLaneLayout(
+                            destLayout, insertOp.getDestVectorType())
+                            .value();
+      // Update the distributed offsets to match round robin distribution (i.e.
+      // each lane owns data at `subgroupSize` stride given unit lane data).
+      updatedOffsets[destDistributedDim] =
+          rewriter.getI64IntegerAttr(destDistrDimOffset / subgroupSize);
+    }
+    // Do the distribution by yielding the source and dest of the insert op
+    // from the warp op and creating a new insert op outside the warp op.
+    SmallVector<size_t> newRetIndices;
+    auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
+        {updatedSourceType, updatedDestType}, newRetIndices);
+    rewriter.setInsertionPointAfter(newWarpOp);
+
+    Value valueToStore = newWarpOp.getResult(newRetIndices[0]);
+    Value dest = newWarpOp.getResult(newRetIndices[1]);
+    // Create a new insert op outside the warp op.
+    Value newInsertOp = vector::InsertStridedSliceOp::create(
+        rewriter, insertOp.getLoc(), updatedDestType, valueToStore, dest,
+        ArrayAttr::get(rewriter.getContext(), updatedOffsets),
+        insertOp.getStrides());
+    rewriter.replaceAllUsesWith(newWarpOp.getResult(operandNumber),
+                                newInsertOp);
+    return success();
+  }
+};
+
 /// Sink a memref::ExtractAlignedPointerAsIndex op feeding into yield op of an
 /// enclosing `gpu.warp_execute_on_lane_0` region. This will simply move the op
 /// outside of the warp op.
@@ -1626,9 +1861,13 @@ void xegpu::populateXeGPUSubgroupDistributePatterns(
                MemrefExtractAlignedPointerAsIndexDistribution>(
       patterns.getContext(),
       /*pattern benefit=*/regularPatternBenefit);
-  patterns.add<VectorShapeCastDistribution>(
-      patterns.getContext(),
-      /*pattern benefit=*/highPatternBenefit);
+  // The following patterns must take precedence over the regular vector
+  // distribution patterns. Therefore, assign them a higher benefit.
+  patterns
+      .add<VectorShapeCastDistribution, VectorExtractStridedSliceDistribution,
+           VectorInsertStridedSliceDistribution>(
+          patterns.getContext(),
+          /*pattern benefit=*/highPatternBenefit);
 }
 
 void xegpu::populateXeGPUMoveFuncBodyToWarpOpPatterns(

diff  --git a/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
index f233dff609f2b..44ec21359593f 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
@@ -1,6 +1,6 @@
-// RUN: mlir-opt --xevm-attach-target='module=xevm_* chip=pvc' -test-xegpu-sg-distribute -allow-unregistered-dialect \
-// RUN: -canonicalize -cse -split-input-file %s | FileCheck %s
-
+// RUN: mlir-opt --xevm-attach-target='module=xevm_* chip=pvc' -test-xegpu-sg-distribute  \
+// RUN: -allow-unregistered-dialect -canonicalize -cse  %s | FileCheck %s
+gpu.module @xevm_module{
 // CHECK-LABEL: gpu.func @store_nd_1d
 // CHECK:         (%[[ARG0:[0-9a-zA-Z]+]]: index) {
 // CHECK:         %[[W:.*]]:3 = gpu.warp_execute_on_lane_0(%[[ARG0]])[16]
@@ -11,20 +11,17 @@
 // CHECK-NEXT:    %[[T1:.*]] = builtin.unrealized_conversion_cast %[[W]]#1 : !xegpu.tensor_desc<16xf32,
 // CHECK-SAME:      #xegpu.layout<lane_layout = [16], lane_data = [1]>> to !xegpu.tensor_desc<16xf32> {resolve_simt_type_mismatch}
 // CHECK-NEXT:    xegpu.store_nd %[[W]]#0, %[[T1]][%[[W]]#2]  : vector<1xf32>, !xegpu.tensor_desc<16xf32>
-gpu.module @xevm_module{
-  gpu.func @store_nd_1d(%laneid: index) {
-    %c0 = arith.constant 0 : index
-    gpu.warp_execute_on_lane_0(%laneid)[16] {
-      %0 = "some_op"() : () -> !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
-      %cst = "some_op"() : () -> vector<16xf32>
-      xegpu.store_nd %cst, %0 [%c0] {layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
-        : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
-    }
-    gpu.return
+gpu.func @store_nd_1d(%laneid: index) {
+  %c0 = arith.constant 0 : index
+  gpu.warp_execute_on_lane_0(%laneid)[16] {
+    %0 = "some_op"() : () -> !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
+    %cst = "some_op"() : () -> vector<16xf32>
+    xegpu.store_nd %cst, %0 [%c0] {layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
+      : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
   }
+  gpu.return
 }
 
-// -----
 // CHECK-LABEL: gpu.func @store_nd_2d
 // CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: index) {
 // CHECK:       %[[W:.*]]:4 = gpu.warp_execute_on_lane_0(%[[ARG0]])[16]
@@ -37,22 +34,18 @@ gpu.module @xevm_module{
 // CHECK-NEXT:  %[[T1:.*]] = builtin.unrealized_conversion_cast %[[W]]#1 : !xegpu.tensor_desc<16x16xf16,
 // CHECK-SAME:    #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> to !xegpu.tensor_desc<16x16xf16> {resolve_simt_type_mismatch}
 // CHECK-NEXT:  xegpu.store_nd %[[CAST]], %[[T1]][%[[W]]#2, %[[W]]#3]  : vector<16xf16>, !xegpu.tensor_desc<16x16xf16>
-gpu.module @xevm_module{
-  gpu.func @store_nd_2d(%laneid : index) {
-    %c0 = arith.constant 0 : index
-    gpu.warp_execute_on_lane_0(%laneid)[16] {
-      %0 = "some_op"() : () -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
-      %cst = "some_op"() : () -> vector<16x16xf16>
-      xegpu.store_nd %cst, %0 [%c0, %c0] {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
-        : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
-    }
-    gpu.return
+gpu.func @store_nd_2d(%laneid : index) {
+  %c0 = arith.constant 0 : index
+  gpu.warp_execute_on_lane_0(%laneid)[16] {
+    %0 = "some_op"() : () -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    %cst = "some_op"() : () -> vector<16x16xf16>
+    xegpu.store_nd %cst, %0 [%c0, %c0] {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+      : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
   }
+  gpu.return
 }
 
 
-
-// -----
 // CHECK-LABEL: gpu.func @load_nd_1d
 // CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: index) {
 // CHECK:       %[[W:.*]]:3 = gpu.warp_execute_on_lane_0(%[[ARG0]])[16] -> (vector<1xf32>,
@@ -63,21 +56,19 @@ gpu.module @xevm_module{
 // CHECK-NEXT:  %[[T1:.*]] = builtin.unrealized_conversion_cast %[[W]]#1 : !xegpu.tensor_desc<16xf32,
 // CHECK-SAME:    #xegpu.layout<lane_layout = [16], lane_data = [1]>> to !xegpu.tensor_desc<16xf32> {resolve_simt_type_mismatch}
 // CHECK-NEXT:  xegpu.load_nd %[[T1]][%[[W]]#2]  : !xegpu.tensor_desc<16xf32> -> vector<1xf32>
-gpu.module @xevm_module{
-  gpu.func @load_nd_1d(%laneid: index) {
-    %c0 = arith.constant 0 : index
-    %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1xf32>) {
-      %0 = "some_op"() : () -> !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
-      %1 = xegpu.load_nd %0 [%c0]  {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} :
-        !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>> -> vector<16xf32>
-      gpu.yield %1 : vector<16xf32>
-    }
-    "some_user_op"(%r) : (vector<1xf32>) -> ()
-    gpu.return
+gpu.func @load_nd_1d(%laneid: index) {
+  %c0 = arith.constant 0 : index
+  %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1xf32>) {
+    %0 = "some_op"() : () -> !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
+    %1 = xegpu.load_nd %0 [%c0]  {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} :
+      !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>> -> vector<16xf32>
+    gpu.yield %1 : vector<16xf32>
   }
+  "some_user_op"(%r) : (vector<1xf32>) -> ()
+  gpu.return
 }
 
-// -----
+
 // CHECK-LABEL: gpu.func @load_nd_2d
 // CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: index) {
 // CHECK:       %[[W:.*]]:4 = gpu.warp_execute_on_lane_0(%[[ARG0]])[16] -> (vector<16x1xf16>, !xegpu.tensor_desc<16x16xf16,
@@ -89,21 +80,19 @@ gpu.module @xevm_module{
 // CHECK-SAME:     #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> to !xegpu.tensor_desc<16x16xf16> {resolve_simt_type_mismatch}
 // CHECK-NEXT:  %[[T2:.*]] = xegpu.load_nd %[[T1]][%[[W]]#2, %[[W]]#3]  : !xegpu.tensor_desc<16x16xf16> -> vector<16xf16>
 // CHECK:       vector.shape_cast %[[T2]] : vector<16xf16> to vector<16x1xf16>
-gpu.module @xevm_module{
-  gpu.func @load_nd_2d(%laneid: index) {
-    %c0 = arith.constant 0 : index
-    %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<16x1xf16>) {
-      %0 = "some_op"() : () -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
-      %1 = xegpu.load_nd %0[%c0, %c0]  {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
-        : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<16x16xf16>
-      gpu.yield %1 : vector<16x16xf16>
-    }
-    "some_user_op"(%r) : (vector<16x1xf16>) -> ()
-    gpu.return
+gpu.func @load_nd_2d(%laneid: index) {
+  %c0 = arith.constant 0 : index
+  %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<16x1xf16>) {
+    %0 = "some_op"() : () -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    %1 = xegpu.load_nd %0[%c0, %c0]  {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+      : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<16x16xf16>
+    gpu.yield %1 : vector<16x16xf16>
   }
+  "some_user_op"(%r) : (vector<16x1xf16>) -> ()
+  gpu.return
 }
 
-// -----
+
 // CHECK-LABEL: gpu.func @load_nd_array_length
 // CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: index) {
 // CHECK:       %[[W:.*]]:4 = gpu.warp_execute_on_lane_0(%[[ARG0]])[16] -> (vector<2x16x1xf16>,
@@ -118,23 +107,21 @@ gpu.module @xevm_module{
 // CHECK-NEXT:  %[[T2:.*]] = xegpu.load_nd %[[T1]][%[[W]]#2, %[[W]]#3]  : !xegpu.tensor_desc<16x16xf16,
 // CHECK-SAME:    #xegpu.block_tdesc_attr<array_length = 2 : i64>> -> vector<32xf16>
 // CHECK-NEXT:  vector.shape_cast %[[T2]] : vector<32xf16> to vector<2x16x1xf16>
-gpu.module @xevm_module{
-  gpu.func @load_nd_array_length(%laneid: index) {
-    %c0 = arith.constant 0 : index
-    %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<2x16x1xf16>) {
-      %0 = "some_op"() : () -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2 : i64>,
-        #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
-      %1 = xegpu.load_nd %0[%c0, %c0]  {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
-        : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2 : i64>,
-          #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<2x16x16xf16>
-      gpu.yield %1 : vector<2x16x16xf16>
-    }
-    "some_user_op"(%r) : (vector<2x16x1xf16>) -> ()
-    gpu.return
+gpu.func @load_nd_array_length(%laneid: index) {
+  %c0 = arith.constant 0 : index
+  %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<2x16x1xf16>) {
+    %0 = "some_op"() : () -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2 : i64>,
+      #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    %1 = xegpu.load_nd %0[%c0, %c0]  {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+      : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2 : i64>,
+        #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<2x16x16xf16>
+    gpu.yield %1 : vector<2x16x16xf16>
   }
+  "some_user_op"(%r) : (vector<2x16x1xf16>) -> ()
+  gpu.return
 }
 
-// -----
+
 // CHECK-LABEL: gpu.func @dpas
 // CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: index) {
 // CHECK:       %[[W:.*]]:4 = gpu.warp_execute_on_lane_0(%[[ARG0]])[16] ->
@@ -146,29 +133,27 @@ gpu.module @xevm_module{
 // CHECK-DAG:   %[[T3:.*]] = vector.shape_cast %[[W]]#3 : vector<8x1xf32> to vector<8xf32>
 // CHECK-NEXT:  %[[T4:.*]] = xegpu.dpas %[[T1]], %[[T2]], %[[T3]] : vector<8xf16>, vector<16xf16>, vector<8xf32> -> vector<8xf32>
 // CHECK-NEXT:  vector.shape_cast %[[T4]] : vector<8xf32> to vector<8x1xf32>
-gpu.module @xevm_module{
-  gpu.func @dpas(%laneid: index) {
-    %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<8x1xf32>) {
-      %0 = "some_op"() : () -> vector<8x16xf16>
-      %1 = "some_op"() : () -> vector<16x16xf16>
-      %2 = "some_op"() : () -> vector<8x16xf32>
-      %3 = xegpu.dpas %0, %1, %2
-        {
-          layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
-          layout_operand_1 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>,
-          layout_operand_2 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
-          layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
-        }
-        : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
-      gpu.yield %3 : vector<8x16xf32>
-    }
-    "some_user_op"(%r) : (vector<8x1xf32>) -> ()
-    gpu.return
+gpu.func @dpas(%laneid: index) {
+  %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<8x1xf32>) {
+    %0 = "some_op"() : () -> vector<8x16xf16>
+    %1 = "some_op"() : () -> vector<16x16xf16>
+    %2 = "some_op"() : () -> vector<8x16xf32>
+    %3 = xegpu.dpas %0, %1, %2
+      {
+        layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+        layout_operand_1 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>,
+        layout_operand_2 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+        layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+      }
+      : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+    gpu.yield %3 : vector<8x16xf32>
   }
+  "some_user_op"(%r) : (vector<8x1xf32>) -> ()
+  gpu.return
 }
 
 
-// -----
+
 // CHECK-LABEL: gpu.func @create_nd_tdesc_non_memref
 // CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: ui64, %[[ARG1:[0-9a-zA-Z]+]]: index) {
 // CHECK:       %[[W:.*]]:2 = gpu.warp_execute_on_lane_0(%[[ARG1]])[16] -> (!xegpu.tensor_desc<16x16xf16,
@@ -178,21 +163,19 @@ gpu.module @xevm_module{
 // CHECK-NEXT:  %[[T1:.*]] = xegpu.create_nd_tdesc %[[W]]#1, shape : [64, 128], strides : [128, 1] : ui64 -> !xegpu.tensor_desc<16x16xf16>
 // CHECK-NEXT:  builtin.unrealized_conversion_cast %[[T1]] : !xegpu.tensor_desc<16x16xf16> to !xegpu.tensor_desc<16x16xf16,
 // CHECK-SAME:    #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> {resolve_simt_type_mismatch}
-gpu.module @xevm_module{
-  gpu.func @create_nd_tdesc_non_memref(%arg0: ui64, %laneid: index) {
-    %c0 = arith.constant 0 : index
-    %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (!xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>) {
-      %0 = xegpu.create_nd_tdesc %arg0, shape:[64, 128], strides:[128, 1] : ui64 ->
-        !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
-      gpu.yield %0 : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
-    }
-    "some_user_op"(%r)
-      : (!xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>) -> ()
-    gpu.return
+gpu.func @create_nd_tdesc_non_memref(%arg0: ui64, %laneid: index) {
+  %c0 = arith.constant 0 : index
+  %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (!xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>) {
+    %0 = xegpu.create_nd_tdesc %arg0, shape:[64, 128], strides:[128, 1] : ui64 ->
+      !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    gpu.yield %0 : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
   }
+  "some_user_op"(%r)
+    : (!xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>) -> ()
+  gpu.return
 }
 
-// -----
+
 // CHECK-LABEL: gpu.func @prefetch_2d
 // CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: index) {
 // CHECK:       %[[W:.*]]:3 = gpu.warp_execute_on_lane_0(%[[ARG0]])[16] -> (!xegpu.tensor_desc<16x16xf16,
@@ -204,21 +187,19 @@ gpu.module @xevm_module{
 // CHECK-SAME:    #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> to !xegpu.tensor_desc<16x16xf16> {resolve_simt_type_mismatch}
 // CHECK-NEXT:  xegpu.prefetch_nd %[[T1]][%[[W]]#1, %[[W]]#2]
 // CHECK-SAME:    <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16x16xf16>
-gpu.module @xevm_module{
-  gpu.func @prefetch_2d(%laneid: index) {
-    %c0 = arith.constant 0 : index
-    gpu.warp_execute_on_lane_0(%laneid)[16] {
-      %0 = "some_op"() : ()
-        -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
-      xegpu.prefetch_nd %0[%c0, %c0]
-        <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
-        : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
-    }
-    gpu.return
+gpu.func @prefetch_2d(%laneid: index) {
+  %c0 = arith.constant 0 : index
+  gpu.warp_execute_on_lane_0(%laneid)[16] {
+    %0 = "some_op"() : ()
+      -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    xegpu.prefetch_nd %0[%c0, %c0]
+      <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+      : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
   }
+  gpu.return
 }
 
-// -----
+
 // CHECK-LABEL: gpu.func @prefetch_1d
 // CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: index) {
 // CHECK:       %[[W:.*]]:2 = gpu.warp_execute_on_lane_0(%[[ARG0]])[16] -> (!xegpu.tensor_desc<16xf16,
@@ -229,44 +210,40 @@ gpu.module @xevm_module{
 // CHECK-SAME:    #xegpu.layout<lane_layout = [16], lane_data = [1]>> to !xegpu.tensor_desc<16xf16> {resolve_simt_type_mismatch}
 // CHECK-NEXT:  xegpu.prefetch_nd %[[T1]][%[[W]]#1] <{l1_hint = #xegpu.cache_hint<cached>,
 // CHECK-SAME:    l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16xf16>
-gpu.module @xevm_module{
-  gpu.func @prefetch_1d(%laneid: index) {
-    %c0 = arith.constant 0 : index
-    gpu.warp_execute_on_lane_0(%laneid)[16] {
-      %0 = "some_op"() : ()
-        -> !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
-      xegpu.prefetch_nd %0[%c0]
-        <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
-        : !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
-    }
-    gpu.return
+gpu.func @prefetch_1d(%laneid: index) {
+  %c0 = arith.constant 0 : index
+  gpu.warp_execute_on_lane_0(%laneid)[16] {
+    %0 = "some_op"() : ()
+      -> !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
+    xegpu.prefetch_nd %0[%c0]
+      <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+      : !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
   }
+  gpu.return
 }
 
-// -----
+
 // CHECK-LABEL: gpu.func @gpu_barrier({{.*}}) {
 // CHECK:  gpu.warp_execute_on_lane_0(%{{.*}})[16] -> ({{.*}}) {
 // CHECK:     gpu.yield %{{.*}}
 // CHECK:  }
 // CHECK:  %{{.*}} = xegpu.load_nd %{{.*}}  : !xegpu.tensor_desc<16xf16> -> vector<1xf16>
 // CHECK:  gpu.barrier
-gpu.module @xevm_module{
-  gpu.func @gpu_barrier(%laneid: index) {
-    %c0 = arith.constant 0 : index
-    %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1xf16>) {
-      %0 = "some_op"() : () -> !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
-      %1 = xegpu.load_nd %0[%c0]
-        {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
-        : !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>> -> vector<16xf16>
-      gpu.barrier
-      gpu.yield %1 : vector<16xf16>
-    }
-    "some_user_op"(%r) : (vector<1xf16>) -> ()
-    gpu.return
+gpu.func @gpu_barrier(%laneid: index) {
+  %c0 = arith.constant 0 : index
+  %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1xf16>) {
+    %0 = "some_op"() : () -> !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
+    %1 = xegpu.load_nd %0[%c0]
+      {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
+      : !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>> -> vector<16xf16>
+    gpu.barrier
+    gpu.yield %1 : vector<16xf16>
   }
+  "some_user_op"(%r) : (vector<1xf16>) -> ()
+  gpu.return
 }
 
-// -----
+
 // CHECK-LABEL: gpu.func @vector_multi_reduction_dim1_distributed_dim0_reduction
 // CHECK:       %[[ACC:.*]] = arith.constant {{.*}} dense<0.000000e+00> : vector<32xf32>
 // CHECK:       %[[W:.*]]:3 = gpu.warp_execute_on_lane_0(%{{.*}})[16]
@@ -285,7 +262,6 @@ gpu.module @xevm_module{
 // CHECK:       %[[T7:.*]] = vector.extract %[[W]]#2[1] : f32 from vector<2xf32>
 // CHECK:       %[[T8:.*]] = vector.reduction <add>, %[[T6]], %[[T7]] : vector<16xf32> into f32
 // CHECK:       %[[T9:.*]] = vector.from_elements %[[T4]], %[[T8]] : vector<2xf32>
-gpu.module @xevm_module{
 gpu.func @vector_multi_reduction_dim1_distributed_dim0_reduction(%laneid: index) {
   %c0 = arith.constant 0 : index
   %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<2xf32>) {
@@ -307,9 +283,8 @@ gpu.func @vector_multi_reduction_dim1_distributed_dim0_reduction(%laneid: index)
   "some_user_op"(%r) : (vector<2xf32>) -> ()
   gpu.return
 }
-}
 
-// -----
+
 // CHECK-LABEL: gpu.func @vector_multi_reduction_dim1_distributed_dim1_reduction
 // CHECK:      %[[W:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<2xf32>) {
 // CHECK-NEXT:   %[[SRC:.*]] = "some_def"() {{.*}} : () -> vector<2x16xf32>
@@ -320,7 +295,6 @@ gpu.func @vector_multi_reduction_dim1_distributed_dim0_reduction(%laneid: index)
 // CHECK-NEXT:   %[[T6:.*]] = vector.from_elements %[[T3]], %[[T5]] : vector<2xf32>
 // CHECK-NEXT:   gpu.yield %[[T6]] : vector<2xf32>
 // CHECK-NEXT: }
-gpu.module @xevm_module{
 gpu.func @vector_multi_reduction_dim1_distributed_dim1_reduction(%laneid: index) {
   %c0 = arith.constant 0 : index
   %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<2xf32>) {
@@ -342,9 +316,8 @@ gpu.func @vector_multi_reduction_dim1_distributed_dim1_reduction(%laneid: index)
   "some_user_op"(%r) : (vector<2xf32>) -> ()
   gpu.return
 }
-}
 
-// -----
+
 // CHECK-LABEL:   gpu.func @vector_multi_reduction_dim0_distributed_dim1_reduction
 // CHECK:       %[[ACC:.*]] = arith.constant {{.*}} dense<0.000000e+00> : vector<32xf32>
 // CHECK:       %[[W:.*]]:3 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<2xf32>, vector<2x16xf32>, vector<2xf32>) {
@@ -358,7 +331,6 @@ gpu.func @vector_multi_reduction_dim1_distributed_dim1_reduction(%laneid: index)
 // CHECK:       %[[T5:.*]] = vector.extract %[[W]]#2[1] : f32 from vector<2xf32>
 // CHECK:       %[[T6:.*]] = vector.reduction <add>, %[[T4]], %[[T5]] : vector<16xf32> into f32
 // CHECK:       %[[T7:.*]] = vector.from_elements %[[T3]], %[[T6]] : vector<2xf32>
-gpu.module @xevm_module{
 gpu.func @vector_multi_reduction_dim0_distributed_dim1_reduction(%laneid: index) {
   %c0 = arith.constant 0 : index
   %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<2xf32>) {
@@ -380,9 +352,8 @@ gpu.func @vector_multi_reduction_dim0_distributed_dim1_reduction(%laneid: index)
   "some_user_op"(%r) : (vector<2xf32>) -> ()
   gpu.return
 }
-}
 
-// -----
+
 // CHECK-LABEL: gpu.func @vector_multi_reduction_dim0_distributed_dim0_reduction
 // CHECK:     %[[W:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<2xf32>) {
 // CHECK:       %[[SRC:.*]] = "some_def"() {{.*}} : () -> vector<16x2xf32>
@@ -397,7 +368,6 @@ gpu.func @vector_multi_reduction_dim0_distributed_dim1_reduction(%laneid: index)
 // CHECK:       %[[T7:.*]] = vector.from_elements %[[T3]], %[[T6]] : vector<2xf32>
 // CHECK:       gpu.yield %[[T7]] : vector<2xf32>
 // CHECK:     }
-gpu.module @xevm_module{
 gpu.func @vector_multi_reduction_dim0_distributed_dim0_reduction(%laneid: index) {
   %c0 = arith.constant 0 : index
   %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<2xf32>) {
@@ -419,9 +389,8 @@ gpu.func @vector_multi_reduction_dim0_distributed_dim0_reduction(%laneid: index)
   "some_user_op"(%r) : (vector<2xf32>) -> ()
   gpu.return
 }
-}
 
-// -----
+
 // CHECK-LABEL: gpu.func @scatter_ops_chunksize({{.*}}) {
 // CHECK:       %[[OFFSETS:.*]] = arith.constant {{.*}} dense<12> : vector<16xindex>
 // CHECK:       %[[MASKS:.*]] = arith.constant {{.*}} dense<true> : vector<16xi1>
@@ -434,35 +403,33 @@ gpu.func @vector_multi_reduction_dim0_distributed_dim0_reduction(%laneid: index)
 // CHECK-SAME:    : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
 // CHECK-NEXT:  xegpu.store %[[T1]], %[[W]]#1[%[[W]]#2], %[[W]]#3 <{chunk_size = 8 : i64}>
 // CHECK-SAME:    : vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
-gpu.module @xevm_module{
-  gpu.func @scatter_ops_chunksize(%laneid: index, %src: memref<256xf16>) {
-    gpu.warp_execute_on_lane_0(%laneid)[16] {
-      %1 = arith.constant
-        {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
-        dense<1>: vector<16xi1>
-      %offset = arith.constant
-        {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
-        dense<12> : vector<16xindex>
-      %3 = xegpu.load %src[%offset], %1 <{chunk_size=8}>
-        {
-          layout_operand_1 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
-          layout_operand_2 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
-          layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>
-        }
-        : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
-      xegpu.store %3, %src[%offset], %1 <{chunk_size=8}>
-        {
-          layout_operand_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>,
-          layout_operand_2 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
-          layout_operand_3 = #xegpu.layout<lane_layout = [16], lane_data = [1]>
-        }
-        : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
-    }
-    gpu.return
+gpu.func @scatter_ops_chunksize(%laneid: index, %src: memref<256xf16>) {
+  gpu.warp_execute_on_lane_0(%laneid)[16] {
+    %1 = arith.constant
+      {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
+      dense<1>: vector<16xi1>
+    %offset = arith.constant
+      {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
+      dense<12> : vector<16xindex>
+    %3 = xegpu.load %src[%offset], %1 <{chunk_size=8}>
+      {
+        layout_operand_1 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
+        layout_operand_2 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
+        layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>
+      }
+      : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
+    xegpu.store %3, %src[%offset], %1 <{chunk_size=8}>
+      {
+        layout_operand_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>,
+        layout_operand_2 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
+        layout_operand_3 = #xegpu.layout<lane_layout = [16], lane_data = [1]>
+      }
+      : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
   }
+  gpu.return
 }
 
-// -----
+
 // CHECK-LABEL: gpu.func @scatter_ops({{.*}}) {
 // CHECK:       %[[OFFSETS:.*]] = arith.constant {{.*}} dense<12> : vector<16xindex>
 // CHECK:       %[[MASKS:.*]] = arith.constant {{.*}} dense<true> : vector<16xi1>
@@ -475,156 +442,144 @@ gpu.module @xevm_module{
 // CHECK-SAME:    : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<1xf16>
 // CHECK-NEXT:  xegpu.store %[[T1]], %[[W]]#1[%[[W]]#2], %[[W]]#3
 // CHECK-SAME:    : vector<1xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
-gpu.module @xevm_module{
-  gpu.func @scatter_ops(%src: memref<256xf16>, %laneid: index) {
-    gpu.warp_execute_on_lane_0(%laneid)[16] {
-      %1 = arith.constant
-        {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
-        dense<1> : vector<16xi1>
-      %offset = arith.constant
-        {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
-        dense<12> : vector<16xindex>
-      %3 = xegpu.load %src[%offset], %1
-      {
-        layout_operand_1 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
-        layout_operand_2 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
-        layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>
-      } : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16>
-      xegpu.store %3, %src[%offset], %1
-      {
-        layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
-        layout_operand_2 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
-        layout_operand_3 = #xegpu.layout<lane_layout = [16], lane_data = [1]>
-      }
-      : vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
+gpu.func @scatter_ops(%src: memref<256xf16>, %laneid: index) {
+  gpu.warp_execute_on_lane_0(%laneid)[16] {
+    %1 = arith.constant
+      {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
+      dense<1> : vector<16xi1>
+    %offset = arith.constant
+      {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
+      dense<12> : vector<16xindex>
+    %3 = xegpu.load %src[%offset], %1
+    {
+      layout_operand_1 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
+      layout_operand_2 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
+      layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>
+    } : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16>
+    xegpu.store %3, %src[%offset], %1
+    {
+      layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
+      layout_operand_2 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
+      layout_operand_3 = #xegpu.layout<lane_layout = [16], lane_data = [1]>
     }
-    gpu.return
+    : vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
   }
+  gpu.return
 }
 
-// -----
+
 // CHECK-LABEL: gpu.func @memref_extract_aligned_pointer_as_index(
 // CHECK:       %[[W:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (index, memref<256x256xf16>) {
 // CHECK:         gpu.yield %{{.*}}, %{{.*}} : index, memref<256x256xf16>
 // CHECK-NEXT:  }
 // CHECK-NEXT:  %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[W]]#1 : memref<256x256xf16> -> index
 // CHECK-NEXT:  arith.index_cast %[[INTPTR]] : index to i64
-gpu.module @xevm_module{
-  gpu.func @memref_extract_aligned_pointer_as_index(%arg0 : memref<256x256xf16>, %laneid: index) {
-    %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (index) {
-      %ptr = memref.extract_aligned_pointer_as_index %arg0 : memref<256x256xf16> -> index
-      gpu.yield %ptr : index
-    }
-    %ptr_i64 = arith.index_cast %r : index to i64
-    "some_user_op"(%ptr_i64) : (i64) -> ()
-    gpu.return
+gpu.func @memref_extract_aligned_pointer_as_index(%arg0 : memref<256x256xf16>, %laneid: index) {
+  %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (index) {
+    %ptr = memref.extract_aligned_pointer_as_index %arg0 : memref<256x256xf16> -> index
+    gpu.yield %ptr : index
   }
+  %ptr_i64 = arith.index_cast %r : index to i64
+  "some_user_op"(%ptr_i64) : (i64) -> ()
+  gpu.return
 }
 
 
-// -----
+
 // CHECK-LABEL: gpu.func @vector_transpose(
 // CHECK:       %[[W:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<2x1xf32>, vector<1x2xf32>) {
 // CHECK:         %[[SRC:.*]] = "some_op"() {{.*}} : () -> vector<16x2xf32>
 // CHECK:         gpu.yield %{{.*}}, %[[SRC]] : vector<2x16xf32>, vector<16x2xf32>
 // CHECK-NEXT:  }
 // CHECK-NEXT:  %[[T1:.*]] = vector.transpose %[[W]]#1, [1, 0] : vector<1x2xf32> to vector<2x1xf32>
-gpu.module @xevm_module{
-  gpu.func @vector_transpose(%laneid: index) {
-    %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<2x1xf32>) {
-      %cst = "some_op"()
-        {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}
-        : () -> (vector<16x2xf32>)
-      %transpose = vector.transpose %cst, [1, 0]
-        {
-          layout_operand_0 = #xegpu.layout<lane_layout = [16 , 1], lane_data = [1, 1]>,
-          layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
-        }
-        : vector<16x2xf32> to vector<2x16xf32>
-      gpu.yield %transpose : vector<2x16xf32>
-    }
-    "some_user_op"(%r) : (vector<2x1xf32>) -> ()
-    gpu.return
+gpu.func @vector_transpose(%laneid: index) {
+  %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<2x1xf32>) {
+    %cst = "some_op"()
+      {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}
+      : () -> (vector<16x2xf32>)
+    %transpose = vector.transpose %cst, [1, 0]
+      {
+        layout_operand_0 = #xegpu.layout<lane_layout = [16 , 1], lane_data = [1, 1]>,
+        layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+      }
+      : vector<16x2xf32> to vector<2x16xf32>
+    gpu.yield %transpose : vector<2x16xf32>
   }
+  "some_user_op"(%r) : (vector<2x1xf32>) -> ()
+  gpu.return
 }
 
-// -----
+
 // CHECK-LABEL: gpu.func @vector_bitcast(
 // CHECK:       %[[W:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<4x1xi16>, vector<4x2xi8>) {
 // CHECK:         %[[SRC:.*]] = "some_op"() {{.*}} : () -> vector<4x32xi8>
 // CHECK:         gpu.yield %{{.*}}, %[[SRC]] : vector<4x16xi16>, vector<4x32xi8>
 // CHECK:       }
 // CHECK:       vector.bitcast %[[W]]#1 : vector<4x2xi8> to vector<4x1xi16>
-gpu.module @xevm_module{
-  gpu.func @vector_bitcast(%laneid: index) {
-    %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<4x1xi16>) {
-      %cst = "some_op"()
-        {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>}
-        : () -> (vector<4x32xi8>)
-      %bitcast = vector.bitcast %cst
-        {
-          layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>,
-          layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
-        }
-        : vector<4x32xi8> to vector<4x16xi16>
-      gpu.yield %bitcast : vector<4x16xi16>
-    }
-    "some_user_op"(%r) : (vector<4x1xi16>) -> ()
-    gpu.return
+gpu.func @vector_bitcast(%laneid: index) {
+  %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<4x1xi16>) {
+    %cst = "some_op"()
+      {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>}
+      : () -> (vector<4x32xi8>)
+    %bitcast = vector.bitcast %cst
+      {
+        layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>,
+        layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+      }
+      : vector<4x32xi8> to vector<4x16xi16>
+    gpu.yield %bitcast : vector<4x16xi16>
   }
+  "some_user_op"(%r) : (vector<4x1xi16>) -> ()
+  gpu.return
 }
 
-// -----
+
 // CHECK-LABEL: gpu.func @vector_shapecast_rank_increasing
 // CHECK:         %{{.*}}:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<1x1xf32>, vector<1xf32>) {
 // CHECK:           gpu.yield %{{.*}} : vector<1x16xf32>, vector<16xf32>
 // CHECK:         }
 // CHECK:         %{{.*}} = vector.shape_cast %{{.*}}#1 : vector<1xf32> to vector<1x1xf32>
-gpu.module @xevm_module {
-  gpu.func @vector_shapecast_rank_increasing(%laneid: index) {
-    %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1x1xf32>) {
-      %cst = "some_op"()
-        {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>}
-        : () -> (vector<16xf32>)
-      %cast = vector.shape_cast %cst
-        {
-          layout_operand_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>,
-          layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
-        }
-        : vector<16xf32> to vector<1x16xf32>
-      gpu.yield %cast : vector<1x16xf32>
-    }
-    "some_user_op"(%r) : (vector<1x1xf32>) -> ()
-    gpu.return
+gpu.func @vector_shapecast_rank_increasing(%laneid: index) {
+  %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1x1xf32>) {
+    %cst = "some_op"()
+      {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>}
+      : () -> (vector<16xf32>)
+    %cast = vector.shape_cast %cst
+      {
+        layout_operand_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>,
+        layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+      }
+      : vector<16xf32> to vector<1x16xf32>
+    gpu.yield %cast : vector<1x16xf32>
   }
+  "some_user_op"(%r) : (vector<1x1xf32>) -> ()
+  gpu.return
 }
 
-// -----
+
 // CHECK-LABEL: gpu.func @vector_shapecast_rank_reducing(
 // CHECK:         %{{.*}}:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<1xf32>, vector<1x1xf32>) {
 // CHECK:           gpu.yield %{{.*}} : vector<16xf32>, vector<1x16xf32>
 // CHECK:         }
 // CHECK:         %{{.*}} = vector.shape_cast %{{.*}}#1 : vector<1x1xf32> to vector<1xf32>
-gpu.module @xevm_module {
-  gpu.func @vector_shapecast_rank_reducing(%laneid: index) {
-    %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1xf32>) {
-      %cst = "some_op"()
-        {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
-        : () -> (vector<1x16xf32>)
-      %cast = vector.shape_cast %cst
-        {
-          layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
-          layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>
-        }
-        : vector<1x16xf32> to vector<16xf32>
-      gpu.yield %cast : vector<16xf32>
-    }
-    "some_user_op"(%r) : (vector<1xf32>) -> ()
-    gpu.return
+gpu.func @vector_shapecast_rank_reducing(%laneid: index) {
+  %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1xf32>) {
+    %cst = "some_op"()
+      {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+      : () -> (vector<1x16xf32>)
+    %cast = vector.shape_cast %cst
+      {
+        layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+        layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>
+      }
+      : vector<1x16xf32> to vector<16xf32>
+    gpu.yield %cast : vector<16xf32>
   }
+  "some_user_op"(%r) : (vector<1xf32>) -> ()
+  gpu.return
 }
 
-// -----
+
 // NOTE: Layouts are still valid, but distribution still requires a slice layout for the operand.
 //
 // CHECK-LABEL:  gpu.func @vector_shapecast_unsupported
@@ -634,21 +589,335 @@ gpu.module @xevm_module {
 // CHECK:          }
 // CHECK:          "some_user_op"(%[[W]]) : (vector<1x1xf32>) -> ()
 // CHECK:          gpu.return
-gpu.module @xevm_module {
-  gpu.func @vector_shapecast_unsupported(%laneid: index) {
-    %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1x1xf32>) {
-      %cst = "some_op"()
-        {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]> }
-        : () -> (vector<16xf32>)
-      %cast = vector.shape_cast %cst
-        {
-          layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
-          layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
-        }
-        : vector<16xf32> to vector<1x16xf32>
-      gpu.yield %cast : vector<1x16xf32>
+gpu.func @vector_shapecast_unsupported(%laneid: index) {
+  %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1x1xf32>) {
+    %cst = "some_op"()
+      {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]> }
+      : () -> (vector<16xf32>)
+    %cast = vector.shape_cast %cst
+      {
+        layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
+        layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+      }
+      : vector<16xf32> to vector<1x16xf32>
+    gpu.yield %cast : vector<1x16xf32>
+  }
+  "some_user_op"(%r) : (vector<1x1xf32>) -> ()
+  gpu.return
+}
+
+
+// CHECK-LABEL:  gpu.func @vector_extract_strided_slice_distributed_dim_fully_extracted
+// CHECK-NEXT:     %[[W:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<8x1xf32>, vector<24x1xf32>) {
+// CHECK-NEXT:       %[[S:.*]] = "some_def"() : () -> vector<24x16xf32>
+// CHECK:            gpu.yield %{{.*}}, %[[S]] : vector<8x16xf32>, vector<24x16xf32>
+// CHECK-NEXT:     }
+// CHECK-NEXT:     %[[T1:.*]] = vector.extract_strided_slice %[[W]]#1
+// CHECK-SAME:        {offsets = [8, 0], sizes = [8, 1], strides = [1, 1]} : vector<24x1xf32> to vector<8x1xf32>
+// CHECK-NEXT:     "some_use"(%[[T1]]) : (vector<8x1xf32>) -> ()
+gpu.func @vector_extract_strided_slice_distributed_dim_fully_extracted(%laneid: index) {
+  %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<8x1xf32>) {
+    %0 = "some_def"() : () -> (vector<24x16xf32>)
+    %1 = vector.extract_strided_slice %0 { offsets = [8, 0], sizes = [8, 16], strides = [1, 1],
+        layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+        layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+      }
+      : vector<24x16xf32> to vector<8x16xf32>
+    gpu.yield %1 : vector<8x16xf32>
+  }
+  "some_use"(%r) : (vector<8x1xf32>) -> ()
+  gpu.return
+}
+
+// CHECK-LABEL: gpu.func @vector_extract_strided_slice_non_distributed
+// CHECK-NEXT:    %[[W:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<8x1xf32>, vector<24x1xf32>) {
+// CHECK-NEXT:      %[[S:.*]] = "some_def"() : () -> vector<24x1xf32>
+// CHECK:           gpu.yield %{{.*}}, %[[S]] : vector<8x1xf32>, vector<24x1xf32>
+// CHECK-NEXT:    }
+// CHECK-NEXT:    %[[T1:.*]] = vector.extract_strided_slice %[[W]]#1
+// CHECK-SAME:      {offsets = [8, 0], sizes = [8, 1], strides = [1, 1]} : vector<24x1xf32> to vector<8x1xf32>
+// CHECK-NEXT:    "some_use"(%[[T1]]) : (vector<8x1xf32>) -> ()
+gpu.func @vector_extract_strided_slice_non_distributed(%laneid: index) {
+  %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<8x1xf32>) {
+    %0 = "some_def"() : () -> (vector<24x1xf32>)
+    %1 = vector.extract_strided_slice %0 { offsets = [8, 0], sizes = [8, 1], strides = [1, 1],
+        layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+        layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+      }
+      : vector<24x1xf32> to vector<8x1xf32>
+    gpu.yield %1 : vector<8x1xf32>
+  }
+  "some_use"(%r) : (vector<8x1xf32>) -> ()
+  gpu.return
+}
+
+// CHECK-LABEL: gpu.func @vector_extract_strided_slice_inner_distributed
+// CHECK:         %[[W:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<8x1xf32>, vector<24x4xf32>) {
+// CHECK-NEXT:      %[[S:.*]] = "some_def"() : () -> vector<24x64xf32>
+// CHECK:           gpu.yield %{{.*}}, %[[S]] : vector<8x16xf32>, vector<24x64xf32>
+// CHECK-NEXT:    }
+// CHECK-NEXT:    %[[T1:.*]] = vector.extract_strided_slice %[[W]]#1
+// CHECK-SAME:      {offsets = [8, 3], sizes = [8, 1], strides = [1, 1]} : vector<24x4xf32> to vector<8x1xf32>
+// CHECK-NEXT:    "some_use"(%[[T1]]) : (vector<8x1xf32>) -> ()
+gpu.func @vector_extract_strided_slice_inner_distributed(%laneid: index) {
+  %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<8x1xf32>) {
+    %0 = "some_def"() : () -> (vector<24x64xf32>)
+    %1 = vector.extract_strided_slice %0 { offsets = [8, 48], sizes = [8, 16], strides = [1, 1],
+        layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+        layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+      }
+      : vector<24x64xf32> to vector<8x16xf32>
+    gpu.yield %1 : vector<8x16xf32>
+  }
+  "some_use"(%r) : (vector<8x1xf32>) -> ()
+  gpu.return
+}
+
+// CHECK-LABEL:  gpu.func @vector_extract_strided_slice_outer_distributed
+// CHECK:          %[[W:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<1x16xf32>, vector<2x16xf32>) {
+// CHECK-NEXT:       %[[S:.*]] = "some_def"() : () -> vector<32x16xf32>
+// CHECK:            gpu.yield %{{.*}}, %[[S]] : vector<16x16xf32>, vector<32x16xf32>
+// CHECK:          }
+// CHECK-NEXT:     %[[T1:.*]] = vector.extract %[[W]]#1[1] : vector<16xf32> from vector<2x16xf32>
+// CHECK-NEXT:     %[[T2:.*]] = vector.shape_cast %[[T1]] : vector<16xf32> to vector<1x16xf32>
+// CHECK-NEXT:     "some_use"(%[[T2]]) : (vector<1x16xf32>) -> ()
+gpu.func @vector_extract_strided_slice_outer_distributed(%laneid: index) {
+  %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1x16xf32>) {
+    %0 = "some_def"() : () -> (vector<32x16xf32>)
+    %1 = vector.extract_strided_slice %0 { offsets = [16], sizes = [16], strides = [1],
+        layout_operand_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>,
+        layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>
+      }
+      : vector<32x16xf32> to vector<16x16xf32>
+    gpu.yield %1 : vector<16x16xf32>
+  }
+  "some_use"(%r) : (vector<1x16xf32>) -> ()
+  gpu.return
+}
+
+// CHECK-LABEL: gpu.func @vector_extract_strided_slice_1d
+// CHECK:         %[[W:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<2xf32>, vector<4xf32>) {
+// CHECK:           %[[S:.*]] = "some_def"() : () -> vector<64xf32>
+// CHECK:           gpu.yield %{{.*}}, %[[S]] : vector<32xf32>, vector<64xf32>
+// CHECK-NEXT:    }
+// CHECK-NEXT:    %[[T1:.*]] = vector.extract_strided_slice %[[W]]#1
+// CHECK-SAME:      {offsets = [1], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
+// CHECK-NEXT:    "some_use"(%[[T1]]) : (vector<2xf32>) -> ()
+gpu.func @vector_extract_strided_slice_1d(%laneid: index) {
+  %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<2xf32>) {
+    %0 = "some_def"() : () -> (vector<64xf32>)
+    %1 = vector.extract_strided_slice %0 { offsets = [16], sizes = [32], strides = [1],
+        layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
+        layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>
+      }
+      : vector<64xf32> to vector<32xf32>
+    gpu.yield %1 : vector<32xf32>
+  }
+  "some_use"(%r) : (vector<2xf32>) -> ()
+  gpu.return
+}
+
+// CHECK-LABEL: gpu.func @vector_extract_strided_slice_unsopported_offset
+// CHECK:         %{{.*}} = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<2xf32>) {
+// CHECK:         }
+// CHECK-NOT:     %{{.*}} = vector.extract_strided_slice
+gpu.func @vector_extract_strided_slice_unsopported_offset(%laneid: index) {
+  %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<2xf32>) {
+    %0 = "some_def"() : () -> (vector<64xf32>)
+    %1 = vector.extract_strided_slice %0 { offsets = [3], sizes = [32], strides = [1],
+        layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
+        layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>
+      }
+      : vector<64xf32> to vector<32xf32>
+    gpu.yield %1 : vector<32xf32>
+  }
+  "some_use"(%r) : (vector<2xf32>) -> ()
+  gpu.return
+}
+
+// CHECK-LABEL: gpu.func @vector_extract_strided_slice_unsopported_source
+// CHECK:         %{{.*}} = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<2xf32>) {
+// CHECK:         }
+// CHECK-NOT:     %{{.*}} = vector.extract_strided_slice
+gpu.func @vector_extract_strided_slice_unsopported_source(%laneid: index) {
+  %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<2xf32>) {
+    %0 = "some_def"() : () -> (vector<54xf32>)
+    %1 = vector.extract_strided_slice %0 { offsets = [0], sizes = [32], strides = [1],
+        layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
+        layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>
+      }
+      : vector<54xf32> to vector<32xf32>
+    gpu.yield %1 : vector<32xf32>
+  }
+  "some_use"(%r) : (vector<2xf32>) -> ()
+  gpu.return
+}
+
+
+// CHECK-LABEL:  gpu.func @vector_insert_strided_slice_distributed_dim_fully_inserted
+// CHECK-NEXT:      %[[W:.*]]:3 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<64x1xf32>, vector<16x1xf32>, vector<64x1xf32>) {
+// CHECK-NEXT:        %[[S:.*]] = "some_def"() : () -> vector<16x16xf32>
+// CHECK-NEXT:        %[[D:.*]] = "some_def"() : () -> vector<64x16xf32>
+// CHECK:             gpu.yield %{{.*}}, %[[S]], %[[D]] : vector<64x16xf32>, vector<16x16xf32>, vector<64x16xf32>
+// CHECK-NEXT:      }
+// CHECK-NEXT:      %[[T1:.*]] = vector.insert_strided_slice %[[W]]#1, %[[W]]#2
+// CHECK-SAME:        {offsets = [24, 0], strides = [1, 1]} : vector<16x1xf32> into vector<64x1xf32>
+// CHECK-NEXT:      "some_use"(%[[T1]]) : (vector<64x1xf32>) -> ()
+gpu.func @vector_insert_strided_slice_distributed_dim_fully_inserted(%laneid: index) {
+  %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<64x1xf32>) {
+    %0 = "some_def"() : () -> (vector<16x16xf32>)
+    %1 = "some_def"() : () -> (vector<64x16xf32>)
+    %2 = vector.insert_strided_slice %0, %1 { offsets = [24, 0],  strides = [1, 1],
+      layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+      layout_operand_1 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+      layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+    }
+      : vector<16x16xf32> into vector<64x16xf32>
+    gpu.yield %2 : vector<64x16xf32>
+  }
+  "some_use"(%r) : (vector<64x1xf32>) -> ()
+  gpu.return
+}
+
+
+// Test: insert a 16x1 source into a 64x1 destination at offsets [24, 0]. The
+// vectors already have the same shape inside the warp region as the warp op's
+// result type (64x1), i.e. the types do not change across the region boundary.
+// Expected output: the insert is still moved after the warp region, operating
+// on the extra yielded values with identical types and offsets.
+// CHECK-LABEL: gpu.func @vector_insert_strided_slice_non_distributed
+// CHECK-NEXT:    %[[W:.*]]:3 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<64x1xf32>, vector<16x1xf32>, vector<64x1xf32>) {
+// CHECK-NEXT:      %[[S:.*]] = "some_def"() : () -> vector<16x1xf32>
+// CHECK-NEXT:      %[[D:.*]] = "some_def"() : () -> vector<64x1xf32>
+// CHECK:           gpu.yield %{{.*}}, %[[S]], %[[D]] : vector<64x1xf32>, vector<16x1xf32>, vector<64x1xf32>
+// CHECK-NEXT:    }
+// CHECK-NEXT:    %[[T1:.*]] = vector.insert_strided_slice %[[W]]#1, %[[W]]#2
+// CHECK-SAME:      {offsets = [24, 0], strides = [1, 1]} : vector<16x1xf32> into vector<64x1xf32>
+// CHECK-NEXT:    "some_use"(%[[T1]]) : (vector<64x1xf32>) -> ()
+gpu.func @vector_insert_strided_slice_non_distributed(%laneid: index) {
+  %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<64x1xf32>) {
+    %0 = "some_def"() : () -> (vector<16x1xf32>)
+    %1 = "some_def"() : () -> (vector<64x1xf32>)
+    %2 = vector.insert_strided_slice %0, %1 { offsets = [24, 0],  strides = [1, 1],
+      layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+      layout_operand_1 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+      layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
     }
-    "some_user_op"(%r) : (vector<1x1xf32>) -> ()
-    gpu.return
+      : vector<16x1xf32> into vector<64x1xf32>
+    gpu.yield %2 : vector<64x1xf32>
   }
+  "some_use"(%r) : (vector<64x1xf32>) -> ()
+  gpu.return
+}
+
+// Test: insert a 16x16 source into a 64x32 destination at offsets [24, 16],
+// with lane_layout = [1, 16], so only part of dim 1 of the destination is
+// written (16 of 32 columns).
+// Expected output: per-lane types are 16x1 (source) and 64x2 (destination),
+// and the dim-1 offset is rewritten from 16 to 1 while the dim-0 offset stays
+// at 24.
+// CHECK-LABEL: gpu.func @vector_insert_strided_slice_inner_distributed
+// CHECK:         %[[W:.*]]:3 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<64x2xf32>, vector<16x1xf32>, vector<64x2xf32>) {
+// CHECK-NEXT:      %[[S:.*]] = "some_def"() : () -> vector<16x16xf32>
+// CHECK-NEXT:      %[[D:.*]] = "some_def"() : () -> vector<64x32xf32>
+// CHECK:           gpu.yield %{{.*}}, %[[S]], %[[D]] : vector<64x32xf32>, vector<16x16xf32>, vector<64x32xf32>
+// CHECK-NEXT:    }
+// CHECK-NEXT:    %[[T1:.*]] = vector.insert_strided_slice %[[W]]#1, %[[W]]#2
+// CHECK-SAME:      {offsets = [24, 1], strides = [1, 1]} : vector<16x1xf32> into vector<64x2xf32>
+// CHECK-NEXT:    "some_use"(%[[T1]]) : (vector<64x2xf32>) -> ()
+gpu.func @vector_insert_strided_slice_inner_distributed(%laneid: index) {
+  %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<64x2xf32>) {
+    %0 = "some_def"() : () -> (vector<16x16xf32>)
+    %1 = "some_def"() : () -> (vector<64x32xf32>)
+    %2 = vector.insert_strided_slice %0, %1 { offsets = [24, 16],  strides = [1, 1],
+      layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+      layout_operand_1 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+      layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+    }
+      : vector<16x16xf32> into vector<64x32xf32>
+    gpu.yield %2 : vector<64x32xf32>
+  }
+  "some_use"(%r) : (vector<64x2xf32>) -> ()
+  gpu.return
+}
+
+// Test: insert a 16x16 source into a 48x32 destination at offsets [32, 4],
+// with lane_layout = [16, 1], i.e. the lanes are laid out along dim 0.
+// Expected output: per-lane types are 1x16 (source) and 3x32 (destination),
+// and the dim-0 offset is rewritten from 32 to 2 while the dim-1 offset stays
+// at 4.
+// CHECK-LABEL:   gpu.func @vector_insert_strided_slice_outer_distributed
+// CHECK:           %[[W:.*]]:3 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<3x32xf32>, vector<1x16xf32>, vector<3x32xf32>) {
+// CHECK-NEXT:        %[[S:.*]] = "some_def"() : () -> vector<16x16xf32>
+// CHECK-NEXT:        %[[D:.*]] = "some_def"() : () -> vector<48x32xf32>
+// CHECK:             gpu.yield %{{.*}}, %[[S]], %[[D]] : vector<48x32xf32>, vector<16x16xf32>, vector<48x32xf32>
+// CHECK-NEXT:      }
+// CHECK-NEXT:      %[[T1:.*]] = vector.insert_strided_slice %[[W]]#1, %[[W]]#2
+// CHECK-SAME:        {offsets = [2, 4], strides = [1, 1]} : vector<1x16xf32> into vector<3x32xf32>
+// CHECK-NEXT:      "some_use"(%[[T1]]) : (vector<3x32xf32>) -> ()
+gpu.func @vector_insert_strided_slice_outer_distributed(%laneid: index) {
+  %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<3x32xf32>) {
+    %0 = "some_def"() : () -> (vector<16x16xf32>)
+    %1 = "some_def"() : () -> (vector<48x32xf32>)
+    %2 = vector.insert_strided_slice %0, %1 { offsets = [32, 4],  strides = [1, 1],
+      layout_operand_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>,
+      layout_operand_1 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>,
+      layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>
+    }
+      : vector<16x16xf32> into vector<48x32xf32>
+    gpu.yield %2 : vector<48x32xf32>
+  }
+  "some_use"(%r) : (vector<3x32xf32>) -> ()
+  gpu.return
+}
+
+// Test: 1-D case. Insert a 16-element source into a 48-element destination at
+// offset [16], with lane_layout = [16].
+// Expected output: per-lane types are vector<1xf32> (source) and
+// vector<3xf32> (destination), and the offset is rewritten from 16 to 1.
+// CHECK-LABEL: gpu.func @vector_insert_strided_slice_1d
+// CHECK:         %[[W:.*]]:3 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<3xf32>, vector<1xf32>, vector<3xf32>) {
+// CHECK-NEXT:      %[[S:.*]] = "some_def"() : () -> vector<16xf32>
+// CHECK-NEXT:      %[[D:.*]] = "some_def"() : () -> vector<48xf32>
+// CHECK:           gpu.yield %{{.*}}, %[[S]], %[[D]] : vector<48xf32>, vector<16xf32>, vector<48xf32>
+// CHECK-NEXT:    }
+// CHECK-NEXT:    %[[T1:.*]] = vector.insert_strided_slice %[[W]]#1, %[[W]]#2
+// CHECK-SAME:      {offsets = [1], strides = [1]} : vector<1xf32> into vector<3xf32>
+// CHECK-NEXT:    "some_use"(%[[T1]]) : (vector<3xf32>) -> ()
+gpu.func @vector_insert_strided_slice_1d(%laneid: index) {
+  %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<3xf32>) {
+    %0 = "some_def"() : () -> (vector<16xf32>)
+    %1 = "some_def"() : () -> (vector<48xf32>)
+    %2 = vector.insert_strided_slice %0, %1 { offsets = [16],  strides = [1],
+      layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
+      layout_operand_1 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
+      layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>
+    }
+      : vector<16xf32> into vector<48xf32>
+    gpu.yield %2 : vector<48xf32>
+  }
+  "some_use"(%r) : (vector<3xf32>) -> ()
+  gpu.return
+}
+
+// Negative test: the source is vector<8xf32> while lane_layout = [16].
+// Expected output: the warp op keeps its single vector<3xf32> result and no
+// vector.insert_strided_slice is produced after the matched closing brace,
+// i.e. the op is not moved out of the warp region.
+// NOTE(review): presumably rejected because 8 elements cannot be spread over
+// 16 lanes -- reason inferred from the test name; confirm against the pattern.
+// CHECK-LABEL:  gpu.func @vector_insert_strided_slice_unsupported_source
+// CHECK:          %{{.*}} = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<3xf32>) {
+// CHECK:          }
+// CHECK-NOT:      %{{.*}} = vector.insert_strided_slice
+gpu.func @vector_insert_strided_slice_unsupported_source(%laneid: index) {
+  %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<3xf32>) {
+    %0 = "some_def"() : () -> (vector<8xf32>)
+    %1 = "some_def"() : () -> (vector<48xf32>)
+    %2 = vector.insert_strided_slice %0, %1 { offsets = [16],  strides = [1],
+      layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
+      layout_operand_1 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
+      layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>
+    }
+      : vector<8xf32> into vector<48xf32>
+    gpu.yield %2 : vector<48xf32>
+  }
+  "some_use"(%r) : (vector<3xf32>) -> ()
+  gpu.return
+}
+
+// Negative test: a 16-element source is inserted into a 48-element
+// destination at offset [3], with lane_layout = [16].
+// Expected output: the warp op keeps its single vector<3xf32> result and no
+// vector.insert_strided_slice is produced after the matched closing brace,
+// i.e. the op is not moved out of the warp region.
+// NOTE(review): presumably rejected because offset 3 is not a multiple of the
+// 16 lanes -- reason inferred from the test name; confirm against the pattern.
+// CHECK-LABEL:  gpu.func @vector_insert_strided_slice_unsupported_offset
+// CHECK:          %{{.*}} = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<3xf32>) {
+// CHECK:          }
+// CHECK-NOT:      %{{.*}} = vector.insert_strided_slice
+gpu.func @vector_insert_strided_slice_unsupported_offset(%laneid: index) {
+  %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<3xf32>) {
+    %0 = "some_def"() : () -> (vector<16xf32>)
+    %1 = "some_def"() : () -> (vector<48xf32>)
+    %2 = vector.insert_strided_slice %0, %1 { offsets = [3],  strides = [1],
+      layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
+      layout_operand_1 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
+      layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>
+    }
+      : vector<16xf32> into vector<48xf32>
+    gpu.yield %2 : vector<48xf32>
+  }
+  "some_use"(%r) : (vector<3xf32>) -> ()
+  gpu.return
+}
+
 }


        


More information about the Mlir-commits mailing list