[Mlir-commits] [mlir] ce288f4 - [MLIR][XeGPU] Add distribution patterns for vector insert & extract ops in sg to wi pass (#184665)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Mar 22 15:05:28 PDT 2026


Author: Nishant Patel
Date: 2026-03-22T15:05:23-07:00
New Revision: ce288f444102a0d8ca1c1223b4f85bf4d1b6300e

URL: https://github.com/llvm/llvm-project/commit/ce288f444102a0d8ca1c1223b4f85bf4d1b6300e
DIFF: https://github.com/llvm/llvm-project/commit/ce288f444102a0d8ca1c1223b4f85bf4d1b6300e.diff

LOG: [MLIR][XeGPU] Add distribution patterns for vector insert & extract ops in sg to wi pass (#184665)

This PR adds patterns for the following vector ops in the new sg-to-wi pass:
1. ExtractOp
2. ExtractStridedSliceOp
3. InsertStridedSliceOp
4. InsertOp

Added: 
    

Modified: 
    mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
    mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
    mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
index 82e8bbd107abe..0961ddfb92040 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
@@ -122,6 +122,20 @@ static bool isReductionLaneLocal(vector::MultiDimReductionOp op) {
   return resTy != resDistTypeOrFailure.value();
 }
 
+/// Returns, in increasing order, the indices of the dimensions along which
+/// `distributedType` differs in size from `originalType`, i.e. the dimensions
+/// that are distributed. Both types must have the same rank (asserted).
+static SmallVector<int64_t> getDistributedDims(VectorType originalType,
+                                               VectorType distributedType) {
+  assert(originalType.getRank() == distributedType.getRank() &&
+         "original and distributed vector types must have the same rank");
+  ArrayRef<int64_t> sgShape = originalType.getShape();
+  ArrayRef<int64_t> wiShape = distributedType.getShape();
+  SmallVector<int64_t> result;
+  for (int64_t dim = 0, rank = static_cast<int64_t>(sgShape.size());
+       dim < rank; ++dim)
+    if (sgShape[dim] != wiShape[dim])
+      result.push_back(dim);
+  return result;
+}
+
 /// Distributes a subgroup-level CreateNdDesc op to workitem-level CreateNdDesc
 /// op. This simply drops the layout attribute from the tensor descriptor type.
 struct SgToWiCreateNdDesc : public OpConversionPattern<xegpu::CreateNdDescOp> {
@@ -828,6 +842,262 @@ struct SgToWiStoreScatter : public OpConversionPattern<xegpu::StoreScatterOp> {
   }
 };
 
+/// Distributes a subgroup-level vector.extract op to workitem-level. Only
+/// handles sub-vector extraction (result is VectorType, not scalar).
+///
+/// The extraction positions are reused unchanged. They index the leading
+/// dimensions of the source, so they remain valid only if none of the
+/// source's non-innermost dimensions is distributed across lanes. The result
+/// layout has lower rank than the source and does not describe the dimensions
+/// the positions select. The source operand layout is therefore checked as
+/// well, when it is available.
+struct SgToWiVectorExtract : public OpConversionPattern<vector::ExtractOp> {
+  using OpConversionPattern<vector::ExtractOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::ExtractOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Only handle vector results (not scalar extraction).
+    auto resultType = dyn_cast<VectorType>(op.getType());
+    if (!resultType)
+      return rewriter.notifyMatchFailure(op, "scalar extract not supported");
+
+    xegpu::DistributeLayoutAttr layout =
+        xegpu::getTemporaryLayout(op->getOpResult(0));
+    if (!layout || !layout.isForSubgroup())
+      return failure();
+
+    // True if any lane_layout entry other than the innermost one is non-unit.
+    auto distributesOuterDims = [](auto layoutAttr) {
+      auto laneLayout = layoutAttr.getEffectiveLaneLayoutAsInt();
+      return llvm::any_of(ArrayRef<int64_t>(laneLayout).drop_back(1),
+                          [](int64_t v) { return v != 1; });
+    };
+
+    // This implementation assumes distribution only happens on the innermost
+    // dimension. Verify that lane_layout[0...n-2] are all unit.
+    if (distributesOuterDims(layout))
+      return rewriter.notifyMatchFailure(
+          op, "only innermost dimension distribution is supported for "
+              "vector.extract");
+
+    // The positions select along the source's leading dimensions, which the
+    // result layout does not cover (for a 1-D result the check above is
+    // vacuous). Apply the same restriction to the source layout when known.
+    auto sourceLayout = xegpu::getTemporaryLayout(op->getOpOperand(0));
+    if (sourceLayout && distributesOuterDims(sourceLayout))
+      return rewriter.notifyMatchFailure(
+          op, "only innermost dimension distribution is supported for "
+              "vector.extract");
+
+    auto newOp = vector::ExtractOp::create(
+        rewriter, op.getLoc(), adaptor.getSource(), op.getMixedPosition());
+    rewriter.replaceOp(op, newOp.getResult());
+    return success();
+  }
+};
+
+/// Distributes a subgroup-level vector.extract_strided_slice op to
+/// workitem-level. If the result is distributed, the offsets and sizes are
+/// adjusted to match the distributed types.
+///
+/// When the result is distributed, the pattern requires all of the following:
+///   - exactly one distributed dimension;
+///   - a known target chip, used to query the subgroup size;
+///   - a lane layout on the source operand;
+///   - unit source lane_data along the distributed dimension;
+///   - a source size and slice offset along that dimension that are both
+///     multiples of the subgroup size.
+///
+/// The per-lane offset is the original offset divided by the subgroup size.
+/// The per-lane size is taken from the distributed result type.
+/// NOTE(review): the division presumes lanes own elements cyclically along
+/// the distributed dimension; confirm against
+/// xegpu::getDistVecTypeBasedOnLaneLayout.
+struct SgToWiVectorExtractStridedSlice
+    : public OpConversionPattern<vector::ExtractStridedSliceOp> {
+  using OpConversionPattern<vector::ExtractStridedSliceOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::ExtractStridedSliceOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    xegpu::DistributeLayoutAttr resultLayout =
+        xegpu::getTemporaryLayout(op->getOpResult(0));
+    if (!resultLayout || !resultLayout.isForSubgroup())
+      return failure();
+
+    // Per-lane type of the result, derived from the result's lane layout.
+    VectorType resultType = op.getType();
+    auto distResultTyOrFailure =
+        xegpu::getDistVecTypeBasedOnLaneLayout(resultLayout, resultType);
+    if (failed(distResultTyOrFailure))
+      return rewriter.notifyMatchFailure(
+          op, "unable to compute distributed vector type from lane layout");
+    VectorType distResultTy = *distResultTyOrFailure;
+
+    // Dimensions whose extent differs between the subgroup-level and per-lane
+    // result types.
+    SmallVector<int64_t> distributedDims =
+        getDistributedDims(resultType, distResultTy);
+
+    // Collect updated sizes, offsets, strides. Pad to full source rank.
+    // extract_strided_slice may list fewer entries than the source rank. The
+    // missing trailing dimensions are taken whole (offset 0, full size,
+    // stride 1). After padding, a trailing distributed dimension can be
+    // addressed by index below.
+    int64_t sourceRank = op.getSourceVectorType().getRank();
+    SmallVector<Attribute> updatedSizes =
+        llvm::map_to_vector(op.getSizes(), [](Attribute attr) { return attr; });
+    SmallVector<Attribute> updatedOffsets = llvm::map_to_vector(
+        op.getOffsets(), [](Attribute attr) { return attr; });
+    SmallVector<Attribute> updatedStrides = llvm::map_to_vector(
+        op.getStrides(), [](Attribute attr) { return attr; });
+    for (int64_t i = op.getSizes().size(); i < sourceRank; ++i) {
+      updatedSizes.push_back(
+          rewriter.getI64IntegerAttr(op.getSourceVectorType().getDimSize(i)));
+      updatedOffsets.push_back(rewriter.getI64IntegerAttr(0));
+      updatedStrides.push_back(rewriter.getI64IntegerAttr(1));
+    }
+
+    // If the result is distributed, adjust offsets and sizes in the
+    // distributed dimension. Otherwise the attributes are forwarded unchanged.
+    // NOTE(review): that path assumes the converted source is not distributed
+    // either.
+    if (!distributedDims.empty()) {
+      if (distributedDims.size() != 1)
+        return rewriter.notifyMatchFailure(
+            op, "only single dimension distribution is supported");
+      int64_t distDim = distributedDims[0];
+      // The subgroup size comes from the target's uArch description.
+      const uArch *uArch = getUArch(xegpu::getChipStr(op).value_or(""));
+      if (!uArch)
+        return rewriter.notifyMatchFailure(
+            op, "target attribute required to determine subgroup size");
+      int subgroupSize = uArch->getSubgroupSize();
+      auto sourceLayout = xegpu::getTemporaryLayout(op->getOpOperand(0));
+      if (!sourceLayout || sourceLayout.getEffectiveLaneLayoutAsInt().empty())
+        return rewriter.notifyMatchFailure(
+            op, "source of extract_strided_slice lacks distribution layout");
+      // NOTE(review): shape extents are int64_t; storing into `int` narrows.
+      int sourceDistrDimSize = op.getSourceVectorType().getShape()[distDim];
+      if (sourceDistrDimSize % subgroupSize != 0)
+        return rewriter.notifyMatchFailure(
+            op, "source size along distributed dim is not a multiple of "
+                "subgroup size");
+      auto sourceLaneData = sourceLayout.getEffectiveLaneDataAsInt();
+      // Only check lane_data for the distributed dimension. Non-distributed
+      // dimensions may have non-unit lane_data (e.g., packed layouts).
+      if (distDim < static_cast<int64_t>(sourceLaneData.size()) &&
+          sourceLaneData[distDim] != 1)
+        return rewriter.notifyMatchFailure(
+            op, "expecting unit lane data along the distributed dimension");
+      int64_t distrDimOffset =
+          cast<IntegerAttr>(updatedOffsets[distDim]).getInt();
+      if (distrDimOffset % subgroupSize != 0)
+        return rewriter.notifyMatchFailure(
+            op, "offset along distributed dim is not a multiple of "
+                "subgroup size");
+      // Adjust sizes and offsets for the distributed dimension.
+      updatedSizes[distDim] =
+          rewriter.getI64IntegerAttr(distResultTy.getDimSize(distDim));
+      updatedOffsets[distDim] =
+          rewriter.getI64IntegerAttr(distrDimOffset / subgroupSize);
+    }
+
+    // Rebuild the op on the converted source with the per-lane result type.
+    auto newOp = vector::ExtractStridedSliceOp::create(
+        rewriter, op.getLoc(), distResultTy, adaptor.getSource(),
+        ArrayAttr::get(rewriter.getContext(), updatedOffsets),
+        ArrayAttr::get(rewriter.getContext(), updatedSizes),
+        ArrayAttr::get(rewriter.getContext(), updatedStrides));
+    rewriter.replaceOp(op, newOp.getResult());
+    return success();
+  }
+};
+
+/// Distributes a subgroup-level vector.insert_strided_slice op to
+/// workitem-level. If the dest is distributed, the offsets are adjusted to
+/// match the distributed types.
+///
+/// The offsets attribute is indexed with destination dimension numbers below.
+/// The strides attribute is forwarded unchanged. The source is aligned with
+/// the trailing dimensions of the destination, so the matching source
+/// dimension is destDistDim - (destRank - srcRank).
+///
+/// When the destination is distributed, the pattern requires all of the
+/// following:
+///   - exactly one distributed dimension, lying within those trailing dims;
+///   - a known target chip, used to query the subgroup size;
+///   - lane layouts on both the source and destination operands;
+///   - unit lane_data along the distributed dimension on both operands;
+///   - a source extent and destination offset along that dimension that are
+///     both multiples of the subgroup size.
+///
+/// NOTE(review): the destination extent's divisibility is not checked here;
+/// presumably it is implied by getDistVecTypeBasedOnLaneLayout succeeding.
+/// Confirm this.
+struct SgToWiVectorInsertStridedSlice
+    : public OpConversionPattern<vector::InsertStridedSliceOp> {
+  using OpConversionPattern<vector::InsertStridedSliceOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::InsertStridedSliceOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    xegpu::DistributeLayoutAttr resultLayout =
+        xegpu::getTemporaryLayout(op->getOpResult(0));
+    if (!resultLayout || !resultLayout.isForSubgroup())
+      return failure();
+
+    // The result has the dest type, so the result layout distributes it.
+    VectorType destType = op.getDestVectorType();
+    auto distDestTyOrFailure =
+        xegpu::getDistVecTypeBasedOnLaneLayout(resultLayout, destType);
+    if (failed(distDestTyOrFailure))
+      return rewriter.notifyMatchFailure(
+          op, "unable to compute distributed vector type from lane layout");
+    VectorType distDestTy = *distDestTyOrFailure;
+
+    SmallVector<int64_t> destDistributedDims =
+        getDistributedDims(destType, distDestTy);
+
+    SmallVector<Attribute> updatedOffsets = llvm::map_to_vector(
+        op.getOffsets(), [](Attribute attr) { return attr; });
+
+    // If the dest is not distributed, the offsets are forwarded unchanged.
+    if (!destDistributedDims.empty()) {
+      if (destDistributedDims.size() != 1)
+        return rewriter.notifyMatchFailure(
+            op, "only single dimension distribution is supported");
+      int64_t destDistDim = destDistributedDims[0];
+
+      const uArch *uArch = getUArch(xegpu::getChipStr(op).value_or(""));
+      if (!uArch)
+        return rewriter.notifyMatchFailure(
+            op, "target attribute required to determine subgroup size");
+      int subgroupSize = uArch->getSubgroupSize();
+
+      VectorType srcType = op.getSourceVectorType();
+      // The distributed dim must be in the last k (source rank) dims of dest.
+      int64_t sourceDistDim =
+          destDistDim - (destType.getRank() - srcType.getRank());
+      if (sourceDistDim < 0)
+        return rewriter.notifyMatchFailure(
+            op, "distributed dimension must be in the last k dims of dest");
+
+      // Operand 0 is the value being inserted, operand 1 is the dest.
+      auto destLayout = xegpu::getTemporaryLayout(op->getOpOperand(1));
+      auto sourceLayout = xegpu::getTemporaryLayout(op->getOpOperand(0));
+      if (!destLayout || !sourceLayout ||
+          destLayout.getEffectiveLaneLayoutAsInt().empty() ||
+          sourceLayout.getEffectiveLaneLayoutAsInt().empty())
+        return rewriter.notifyMatchFailure(
+            op, "source or dest of insert_strided_slice lacks distribution "
+                "layout");
+
+      auto destLaneData = destLayout.getEffectiveLaneDataAsInt();
+      auto sourceLaneData = sourceLayout.getEffectiveLaneDataAsInt();
+      // Only check lane_data for the distributed dimension. Non-distributed
+      // dimensions may have non-unit lane_data (e.g., packed layouts).
+      if ((destDistDim < static_cast<int64_t>(destLaneData.size()) &&
+           destLaneData[destDistDim] != 1) ||
+          (sourceDistDim < static_cast<int64_t>(sourceLaneData.size()) &&
+           sourceLaneData[sourceDistDim] != 1))
+        return rewriter.notifyMatchFailure(
+            op, "expecting unit lane data along the distributed dimension");
+
+      int64_t srcDistrDimSize = srcType.getDimSize(sourceDistDim);
+      if (srcDistrDimSize % subgroupSize != 0)
+        return rewriter.notifyMatchFailure(
+            op, "source distributed dim size is not a multiple of "
+                "subgroup size");
+
+      int64_t destDistrDimOffset =
+          cast<IntegerAttr>(op.getOffsets()[destDistDim]).getInt();
+      if (destDistrDimOffset % subgroupSize != 0)
+        return rewriter.notifyMatchFailure(
+            op, "offset along distributed dim is not a multiple of "
+                "subgroup size");
+      // Adjust offset for the distributed dimension.
+      updatedOffsets[destDistDim] =
+          rewriter.getI64IntegerAttr(destDistrDimOffset / subgroupSize);
+    }
+
+    // Rebuild the op on the converted operands with the per-lane dest type.
+    auto newOp = vector::InsertStridedSliceOp::create(
+        rewriter, op.getLoc(), distDestTy, adaptor.getValueToStore(),
+        adaptor.getDest(),
+        ArrayAttr::get(rewriter.getContext(), updatedOffsets), op.getStrides());
+    rewriter.replaceOp(op, newOp.getResult());
+    return success();
+  }
+};
+
+/// Distributes a subgroup-level vector.insert op to workitem-level. Only
+/// sub-vector insertion is handled: the value to store must be a VectorType,
+/// and scalar insertion is rejected. The static/dynamic positions are reused
+/// unchanged. This is valid only when no dimension other than the innermost
+/// one of the destination (result) layout is distributed across lanes.
+struct SgToWiVectorInsert : public OpConversionPattern<vector::InsertOp> {
+  using OpConversionPattern<vector::InsertOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::InsertOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Scalar insertion is not handled by this pattern.
+    if (!isa<VectorType>(op.getValueToStoreType()))
+      return rewriter.notifyMatchFailure(op, "scalar insert not supported");
+
+    xegpu::DistributeLayoutAttr resultLayout =
+        xegpu::getTemporaryLayout(op->getOpResult(0));
+    if (!resultLayout || !resultLayout.isForSubgroup())
+      return failure();
+
+    // The positions index the outer dimensions of the destination, so every
+    // lane_layout entry except the innermost one must be 1.
+    auto effectiveLaneLayout = resultLayout.getEffectiveLaneLayoutAsInt();
+    ArrayRef<int64_t> outerLaneLayout =
+        ArrayRef<int64_t>(effectiveLaneLayout).drop_back(1);
+    const bool onlyInnermostDistributed = llvm::all_of(
+        outerLaneLayout, [](int64_t lanes) { return lanes == 1; });
+    if (!onlyInnermostDistributed)
+      return rewriter.notifyMatchFailure(
+          op, "only innermost dimension distribution is supported for "
+              "vector.insert");
+
+    Value distributedResult =
+        vector::InsertOp::create(rewriter, op.getLoc(),
+                                 adaptor.getValueToStore(), adaptor.getDest(),
+                                 op.getMixedPosition())
+            .getResult();
+    rewriter.replaceOp(op, distributedResult);
+    return success();
+  }
+};
+
 /// Folds a subgroup-level ConvertLayout op with compatible lane layouts.
 struct SgToWiConvertLayout
     : public OpConversionPattern<xegpu::ConvertLayoutOp> {
@@ -1052,10 +1322,30 @@ void xegpu::populateXeGPUSgToWiDistributeTypeConversionAndLegality(
       [=](vector::MultiDimReductionOp op) -> bool {
         return !isValidSubgroupMultiReductionOp(op);
       });
+  target.addDynamicallyLegalOp<vector::ExtractOp>(
+      [=](vector::ExtractOp op) -> bool {
+        if (!isa<VectorType>(op.getType()))
+          return true;
+        return !xegpu::getTemporaryLayout(op->getOpResult(0));
+      });
+  target.addDynamicallyLegalOp<vector::InsertOp>(
+      [=](vector::InsertOp op) -> bool {
+        return !xegpu::getTemporaryLayout(op->getOpResult(0));
+      });
+  target.addDynamicallyLegalOp<vector::ExtractStridedSliceOp>(
+      [=](vector::ExtractStridedSliceOp op) -> bool {
+        return !xegpu::getTemporaryLayout(op->getOpResult(0));
+      });
+  target.addDynamicallyLegalOp<vector::InsertStridedSliceOp>(
+      [=](vector::InsertStridedSliceOp op) -> bool {
+        return !xegpu::getTemporaryLayout(op->getOpResult(0));
+      });
   target.markUnknownOpDynamicallyLegal([](Operation *op) { return true; });
   patterns.add<SgToWiCreateNdDesc, SgToWiLoadNd, SgToWiStoreNd, SgToWiDpas,
                SgToWiElementWise, SgToWiArithConstant, SgToWiPrefetchNd,
                SgToWiLoadGather, SgToWiStoreScatter, SgToWiVectorReduction,
-               SgToWiMultiDimReduction, SgToWiLoadMatrix, SgToWiStoreMatrix,
-               SgToWiConvertLayout>(typeConverter, patterns.getContext());
+               SgToWiMultiDimReduction, SgToWiVectorExtract, SgToWiVectorInsert,
+               SgToWiVectorExtractStridedSlice, SgToWiVectorInsertStridedSlice,
+               SgToWiLoadMatrix, SgToWiStoreMatrix, SgToWiConvertLayout>(
+      typeConverter, patterns.getContext());
 }

diff  --git a/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir b/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
index 57e47b6794f07..016b393e3d8bc 100644
--- a/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
+++ b/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
@@ -461,6 +461,205 @@ gpu.func @vector_multi_reduction_dim0_distributed_dim1_reduction(%laneid: index)
   gpu.return
 }
 
+// vector.extract of a row from a 2-D vector distributed along the innermost
+// dim: the position is kept as-is and each lane extracts a 1-element vector.
+// CHECK-LABEL: gpu.func @vector_extract_from_2d
+// CHECK: %[[EXT:.*]] = vector.extract %{{.*}}[0] : vector<1xf32> from vector<4x1xf32>
+gpu.func @vector_extract_from_2d() {
+  %src = "some_op"()
+    {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+    : () -> vector<4x16xf32>
+  %0 = vector.extract %src[0]
+    {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
+    : vector<16xf32> from vector<4x16xf32>
+  gpu.return
+}
+
+// Same as above with a non-zero position; the position is unchanged.
+// CHECK-LABEL: gpu.func @vector_extract_from_2d_offset2
+// CHECK: %[[EXT:.*]] = vector.extract %{{.*}}[2] : vector<1xf32> from vector<8x1xf32>
+gpu.func @vector_extract_from_2d_offset2() {
+  %src = "some_op"()
+    {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+    : () -> vector<8x16xf32>
+  %0 = vector.extract %src[2]
+    {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
+    : vector<16xf32> from vector<8x16xf32>
+  gpu.return
+}
+
+// vector.insert of a row into a 2-D vector distributed along the innermost
+// dim: the position is kept as-is.
+// CHECK-LABEL: gpu.func @vector_insert_into_2d
+// CHECK: %[[INS:.*]] = vector.insert %{{.*}}, %{{.*}}[0] : vector<1xf32> into vector<4x1xf32>
+gpu.func @vector_insert_into_2d() {
+  %val = "some_op"()
+    {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
+    : () -> vector<16xf32>
+  %dst = "some_op"()
+    {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+    : () -> vector<4x16xf32>
+  %0 = vector.insert %val, %dst[0]
+    {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+    : vector<16xf32> into vector<4x16xf32>
+  gpu.return
+}
+
+// Same as above with a non-zero position; the position is unchanged.
+// CHECK-LABEL: gpu.func @vector_insert_into_2d_offset2
+// CHECK: %[[INS:.*]] = vector.insert %{{.*}}, %{{.*}}[2] : vector<1xf32> into vector<8x1xf32>
+gpu.func @vector_insert_into_2d_offset2() {
+  %val = "some_op"()
+    {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
+    : () -> vector<16xf32>
+  %dst = "some_op"()
+    {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+    : () -> vector<8x16xf32>
+  %0 = vector.insert %val, %dst[2]
+    {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+    : vector<16xf32> into vector<8x16xf32>
+  gpu.return
+}
+
+// extract_strided_slice taking the whole distributed inner dim: only the outer
+// dim is sliced; the inner size becomes the per-lane size 1.
+// CHECK-LABEL: gpu.func @vector_extract_strided_slice_distributed_dim_fully_extracted
+// CHECK: %[[ESS:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [8, 0], sizes = [8, 1], strides = [1, 1]} : vector<24x1xf32> to vector<8x1xf32>
+gpu.func @vector_extract_strided_slice_distributed_dim_fully_extracted() {
+  %0 = "some_op"()
+    {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+    : () -> vector<24x16xf32>
+  %1 = vector.extract_strided_slice %0 { offsets = [8, 0], sizes = [8, 16], strides = [1, 1],
+      layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+    }
+    : vector<24x16xf32> to vector<8x16xf32>
+  gpu.return
+}
+
+// extract_strided_slice slicing the distributed inner dim:
+// offset 48 -> 48/16 = 3, size 16 -> 1.
+// CHECK-LABEL: gpu.func @vector_extract_strided_slice_inner_distributed
+// CHECK: %[[ESS:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [8, 3], sizes = [8, 1], strides = [1, 1]} : vector<24x4xf32> to vector<8x1xf32>
+gpu.func @vector_extract_strided_slice_inner_distributed() {
+  %0 = "some_op"()
+    {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+    : () -> vector<24x64xf32>
+  %1 = vector.extract_strided_slice %0 { offsets = [8, 48], sizes = [8, 16], strides = [1, 1],
+      layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+    }
+    : vector<24x64xf32> to vector<8x16xf32>
+  gpu.return
+}
+
+// extract_strided_slice distributed along the outer dim, with offsets/sizes
+// listed for the outer dim only: offset 16 -> 1, size 16 -> 1, and the
+// omitted inner dim is padded with offset 0 and its full size 16.
+// CHECK-LABEL: gpu.func @vector_extract_strided_slice_outer_distributed
+// CHECK: %[[ESS:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [1, 0], sizes = [1, 16], strides = [1, 1]} : vector<2x16xf32> to vector<1x16xf32>
+gpu.func @vector_extract_strided_slice_outer_distributed() {
+  %0 = "some_op"()
+    {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}
+    : () -> vector<32x16xf32>
+  %1 = vector.extract_strided_slice %0 { offsets = [16], sizes = [16], strides = [1],
+      layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>
+    }
+    : vector<32x16xf32> to vector<16x16xf32>
+  gpu.return
+}
+
+// 1-D extract_strided_slice: offset 16 -> 1, size 32 -> 2.
+// CHECK-LABEL: gpu.func @vector_extract_strided_slice_1d
+// CHECK: %[[ESS:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [1], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
+gpu.func @vector_extract_strided_slice_1d() {
+  %0 = "some_op"()
+    {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
+    : () -> vector<64xf32>
+  %1 = vector.extract_strided_slice %0 { offsets = [16], sizes = [32], strides = [1],
+      layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>
+    }
+    : vector<64xf32> to vector<32xf32>
+  gpu.return
+}
+
+// Only the outer dim is listed; the omitted distributed inner dim is padded
+// with offset 0 and then given the per-lane size 1.
+// CHECK-LABEL: gpu.func @vector_extract_strided_slice_partial_offsets
+// CHECK: %[[ESS:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [8, 0], sizes = [8, 1], strides = [1, 1]} : vector<24x1xf32> to vector<8x1xf32>
+gpu.func @vector_extract_strided_slice_partial_offsets() {
+  %0 = "some_op"()
+    {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+    : () -> vector<24x16xf32>
+  %1 = vector.extract_strided_slice %0 { offsets = [8], sizes = [8], strides = [1],
+      layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+    }
+    : vector<24x16xf32> to vector<8x16xf32>
+  gpu.return
+}
+
+// insert_strided_slice whose source spans the whole distributed inner dim:
+// offsets are kept as-is (inner offset 0 / 16 = 0).
+// CHECK-LABEL: gpu.func @vector_insert_strided_slice_distributed_dim_fully_inserted
+// CHECK: %[[ISS:.*]] = vector.insert_strided_slice %{{.*}}, %{{.*}} {offsets = [24, 0], strides = [1, 1]} : vector<16x1xf32> into vector<64x1xf32>
+gpu.func @vector_insert_strided_slice_distributed_dim_fully_inserted() {
+  %0 = "some_op"()
+    {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+    : () -> vector<16x16xf32>
+  %1 = "some_op"()
+    {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+    : () -> vector<64x16xf32>
+  %2 = vector.insert_strided_slice %0, %1 { offsets = [24, 0], strides = [1, 1],
+      layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+    }
+    : vector<16x16xf32> into vector<64x16xf32>
+  gpu.return
+}
+
+// insert_strided_slice into the distributed inner dim: offset 16 -> 1.
+// CHECK-LABEL: gpu.func @vector_insert_strided_slice_inner_distributed
+// CHECK: %[[ISS:.*]] = vector.insert_strided_slice %{{.*}}, %{{.*}} {offsets = [24, 1], strides = [1, 1]} : vector<16x1xf32> into vector<64x2xf32>
+gpu.func @vector_insert_strided_slice_inner_distributed() {
+  %0 = "some_op"()
+    {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+    : () -> vector<16x16xf32>
+  %1 = "some_op"()
+    {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+    : () -> vector<64x32xf32>
+  %2 = vector.insert_strided_slice %0, %1 { offsets = [24, 16], strides = [1, 1],
+      layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+    }
+    : vector<16x16xf32> into vector<64x32xf32>
+  gpu.return
+}
+
+// insert_strided_slice distributed along the outer dim: offset 32 -> 2; the
+// undistributed inner offset 4 is kept.
+// CHECK-LABEL: gpu.func @vector_insert_strided_slice_outer_distributed
+// CHECK: %[[ISS:.*]] = vector.insert_strided_slice %{{.*}}, %{{.*}} {offsets = [2, 4], strides = [1, 1]} : vector<1x16xf32> into vector<3x32xf32>
+gpu.func @vector_insert_strided_slice_outer_distributed() {
+  %0 = "some_op"()
+    {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}
+    : () -> vector<16x16xf32>
+  %1 = "some_op"()
+    {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}
+    : () -> vector<48x32xf32>
+  %2 = vector.insert_strided_slice %0, %1 { offsets = [32, 4], strides = [1, 1],
+      layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>
+    }
+    : vector<16x16xf32> into vector<48x32xf32>
+  gpu.return
+}
+
+// 1-D insert_strided_slice: offset 16 -> 1.
+// CHECK-LABEL: gpu.func @vector_insert_strided_slice_1d
+// CHECK: %[[ISS:.*]] = vector.insert_strided_slice %{{.*}}, %{{.*}} {offsets = [1], strides = [1]} : vector<1xf32> into vector<3xf32>
+gpu.func @vector_insert_strided_slice_1d() {
+  %0 = "some_op"()
+    {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
+    : () -> vector<16xf32>
+  %1 = "some_op"()
+    {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
+    : () -> vector<48xf32>
+  %2 = vector.insert_strided_slice %0, %1 { offsets = [16], strides = [1],
+      layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>
+    }
+    : vector<16xf32> into vector<48xf32>
+  gpu.return
+}
+
+// insert_strided_slice of a 1-D source into a 2-D dest distributed along the
+// inner dim: the source maps onto the dest's trailing (distributed) dim, and
+// the offsets [13, 0] are kept (inner offset 0 / 16 = 0).
+// CHECK-LABEL: gpu.func @vector_insert_strided_slice_different_ranks
+// CHECK: %[[ISS:.*]] = vector.insert_strided_slice %{{.*}}, %{{.*}} {offsets = [13, 0], strides = [1]} : vector<1xf32> into vector<64x1xf32>
+gpu.func @vector_insert_strided_slice_different_ranks() {
+  %0 = "some_op"()
+    {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
+    : () -> vector<16xf32>
+  %1 = "some_op"()
+    {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+    : () -> vector<64x16xf32>
+  %2 = vector.insert_strided_slice %0, %1 { offsets = [13, 0], strides = [1],
+      layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+    }
+    : vector<16xf32> into vector<64x16xf32>
+  gpu.return
+}
+
 // CHECK-LABEL: gpu.func @convert_layout_removed_when_compatible
 // CHECK-NOT: xegpu.convert_layout
 gpu.func @convert_layout_removed_when_compatible() {

diff  --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
index ffd091b5154b3..0d10ab7c74da6 100644
--- a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
+++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
@@ -288,6 +288,12 @@ struct TestXeGPUSgToWiDistributeExperimental
       : PassWrapper(pass) {}
 
   void runOnOperation() override {
+    Operation *op = getOperation();
+    if (!xegpu::recoverTemporaryLayouts(op)) {
+      signalPassFailure();
+      return;
+    }
+
     MLIRContext *ctx = &getContext();
     TypeConverter typeConverter;
     // Define type materializations using UnrealizedConversionCastOp.
@@ -304,7 +310,7 @@ struct TestXeGPUSgToWiDistributeExperimental
     RewritePatternSet patterns(ctx);
     xegpu::populateXeGPUSgToWiDistributeTypeConversionAndLegality(
         typeConverter, patterns, target);
-    (void)applyPartialConversion(getOperation(), target, std::move(patterns));
+    (void)applyPartialConversion(op, target, std::move(patterns));
   }
 };
 


        


More information about the Mlir-commits mailing list