[Mlir-commits] [mlir] [MLIR][XeGPU] Enhancing insert_strided_slice layout setup and infer rules (PR #184742)

Jianhui Li llvmlistbot at llvm.org
Fri Mar 6 06:59:37 PST 2026


https://github.com/Jianhui-Li updated https://github.com/llvm/llvm-project/pull/184742

>From 58264107ef1d5a942d73c5b951a6c70daf06f11e Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Thu, 5 Mar 2026 01:42:54 +0000
Subject: [PATCH 1/7] refactor insert_strided_slice layout rules

---
 .../mlir/Dialect/XeGPU/IR/XeGPUAttrs.td       |  13 ++
 .../XeGPU/Transforms/XeGPULayoutImpl.h        |   6 +-
 mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp    | 103 ++++++++++++--
 .../XeGPU/Transforms/XeGPULayoutImpl.cpp      | 128 ++++++++----------
 .../XeGPU/Transforms/XeGPUPropagateLayout.cpp |  10 +-
 .../XeGPU/propagate-layout-inst-data.mlir     |  12 +-
 6 files changed, 177 insertions(+), 95 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index 6f667f4801673..1f53c8b3b93fa 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -248,6 +248,11 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> {
                           "int64_t": $sgData,
                           "int64_t": $instData,
                           "int64_t": $laneData)>,
+    InterfaceMethod<[{Derive a new layout by removing dimensions.
+                      `dimGroup` specifies a group of dimensions to be removed in the derived layout.}],
+                    "xegpu::DistributeLayoutAttr",
+                    "dropDims",
+                    (ins "SmallVector<int64_t>": $dimGroup)>,
     InterfaceMethod<[{Derive a new layout by collapsing dimensions.
                       `dimGroup` specifies a group of adjacent dimensions that are collapsed into
                        a single dimension in the derived layout.}],
@@ -562,6 +567,10 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> {
     // preserves its original value.
     DistributeLayoutAttr setDimData(int64_t dim, int64_t sgData, int64_t instData, int64_t laneData);
 
+    // Derive a new layout by removing dimensions.
+    // `dimGroup` specifies a group of dimensions to be removed in the derived layout.
+    DistributeLayoutAttr dropDims(SmallVector<int64_t> dimGroup);
+
     // Derive a new layout by collapsing dimensions.
     // `dimGroup` specifies a group of adjacent dimensions
     // that are collapsed into a single dimension in the derived layout.
@@ -762,6 +771,10 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> {
     // preserves its original value.
     DistributeLayoutAttr setDimData(int64_t dim, int64_t sgData, int64_t instData, int64_t laneData);
 
+    // Derive a new layout by removing dimensions.
+    // `dimGroup` specifies a group of dimensions to be removed in the derived layout.
+    DistributeLayoutAttr dropDims(SmallVector<int64_t> dimGroup);
+
     // Derive a new layout by collapsing dimensions.
     // `dimGroup` specifies a group of adjacent dimensions
     // that are collapsed into a single dimension in the derived layout.
diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h b/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h
index 3482d1b9401bb..5ccb69e0a89dd 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h
@@ -132,9 +132,9 @@ DistributeLayoutAttr setupBitCastResultLayout(
 /// Sets up the result layout for an insert strided slice operation.
 /// Creates a result layout based on the specified layout kind (InstData or
 /// Lane).
-DistributeLayoutAttr setupInsertStridedSliceResultLayout(
-    LayoutKind layoutKind, VectorType srcVectorTy, VectorType resVectorTy,
-    DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch);
+// DistributeLayoutAttr setupInsertStridedSliceResultLayout(
+//     LayoutKind layoutKind, VectorType srcVectorTy, VectorType resVectorTy,
+//     DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch);
 
 /// Sets up the anchor layout for a load gather operation.
 DistributeLayoutAttr
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index c082600ec27d7..76b7a690288a3 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -501,6 +501,76 @@ DistributeLayoutAttr LayoutAttr::setDimData(int64_t dim, int64_t sgData,
       getOrder());
 }
 
+// Derive a new layout by removing dimensions.
+// `dimGroup` specifies a group of dimensions to be removed in the derived
+// layout.
+DistributeLayoutAttr LayoutAttr::dropDims(SmallVector<int64_t> dimGroup) {
+
+  SmallVector<int64_t> sgLayout = getEffectiveSgLayoutAsInt();
+  SmallVector<int64_t> sgData = getEffectiveSgDataAsInt();
+  SmallVector<int64_t> instData = getEffectiveInstDataAsInt();
+  SmallVector<int64_t> laneLayout = getEffectiveLaneLayoutAsInt();
+  SmallVector<int64_t> laneData = getEffectiveLaneDataAsInt();
+  SmallVector<int64_t> origOrder = getEffectiveOrderAsInt();
+
+  SmallVector<int64_t> sortedDimGroup = dimGroup;
+  llvm::sort(sortedDimGroup);
+
+  if (!sgLayout.empty()) {
+    for (auto dimIdx : llvm::reverse(sortedDimGroup)) {
+      sgLayout.erase(sgLayout.begin() + dimIdx, sgLayout.begin() + dimIdx + 1);
+      sgData.erase(sgData.begin() + dimIdx, sgData.begin() + dimIdx + 1);
+    }
+  }
+
+  if (!instData.empty()) {
+    for (auto dimIdx : llvm::reverse(sortedDimGroup))
+      instData.erase(instData.begin() + dimIdx, instData.begin() + dimIdx + 1);
+  }
+
+  if (!laneLayout.empty()) {
+    for (auto dimIdx : llvm::reverse(sortedDimGroup)) {
+      laneLayout.erase(laneLayout.begin() + dimIdx,
+                       laneLayout.begin() + dimIdx + 1);
+      laneData.erase(laneData.begin() + dimIdx, laneData.begin() + dimIdx + 1);
+    }
+  }
+
+  SmallVector<int64_t> newOrder;
+  for (int64_t d : origOrder) {
+    if (llvm::is_contained(dimGroup, d))
+      continue;
+    int64_t offset = llvm::count_if(dimGroup, [&](int64_t s) { return s < d; });
+    newOrder.push_back(d - offset);
+  }
+
+  // Create dropped layout
+  SmallVector<int32_t> sgLayout32(sgLayout.begin(), sgLayout.end());
+  SmallVector<int32_t> sgData32(sgData.begin(), sgData.end());
+  SmallVector<int32_t> instData32(instData.begin(), instData.end());
+  SmallVector<int32_t> laneLayout32(laneLayout.begin(), laneLayout.end());
+  SmallVector<int32_t> laneData32(laneData.begin(), laneData.end());
+  SmallVector<int32_t> newOrder32(newOrder.begin(), newOrder.end());
+
+  DenseI32ArrayAttr orderAttr = getOrder();
+  auto droppedLayout = xegpu::LayoutAttr::get(
+      getContext(),
+      sgLayout32.empty() ? DenseI32ArrayAttr()
+                         : DenseI32ArrayAttr::get(getContext(), sgLayout32),
+      sgData32.empty() ? DenseI32ArrayAttr()
+                       : DenseI32ArrayAttr::get(getContext(), sgData32),
+      instData32.empty() ? DenseI32ArrayAttr()
+                         : DenseI32ArrayAttr::get(getContext(), instData32),
+      laneLayout32.empty() ? DenseI32ArrayAttr()
+                           : DenseI32ArrayAttr::get(getContext(), laneLayout32),
+      laneData32.empty() ? DenseI32ArrayAttr()
+                         : DenseI32ArrayAttr::get(getContext(), laneData32),
+      (!orderAttr || orderAttr.empty())
+          ? DenseI32ArrayAttr()
+          : DenseI32ArrayAttr::get(getContext(), newOrder32));
+  return droppedLayout;
+}
+
 // Derive a new layout by collapsing dimensions.
 // `dimGroup` specifies a group of adjacent dimensions
 // that are collapsed into a single dimension in the derived layout.
@@ -524,8 +594,8 @@ DistributeLayoutAttr LayoutAttr::collapseDims(SmallVector<int64_t> dimGroup) {
   llvm::sort(sortedDimGroup);
   int64_t dimBeforeCurrent = -1;
   for (auto dimIdx : sortedDimGroup) {
-    // when order is present, adjacency dims are on order values like [3, 2, 1,
-    // 0] in decreasing order otherwise based on dim indices like [0, 1, 2, 3]
+    // when order is present, adjacency dims are values like [3, 2, 1, 0]
+    // in decreasing order; otherwise based on dim indices like [0, 1, 2, 3]
     // in increasing order
     if (dimBeforeCurrent >= 0) {
       if (!orderVec.empty()) {
@@ -586,10 +656,10 @@ DistributeLayoutAttr LayoutAttr::collapseDims(SmallVector<int64_t> dimGroup) {
     laneData.insert(laneData.begin() + firstDim, collapsedLaneData);
   }
 
-  // go through the values inside collapsedOrder, and re-map the order values
-  // to be in range of [0, N-1] where N is the number of dimensions in
-  // collapsed shape for exmaple, collapse dim group {2, 3} of order[1, 2, 3,
-  // 4] to new order[1, 3, 4]. the loop below remaps it to [1, 2, 3].
+  // After collapsing the order vector, re-map the order values to be in range
+  // of [0, N-1] where N is the number of dimensions in collapsed shape. For
+  // example, collapse dim group {2, 3} of order[1, 2, 3, 4] to new order[1, 3,
+  // 4]. The loop below remaps it to [1, 2, 3].
   SmallVector<int32_t> collapsedOrder;
   if (!orderVec.empty()) {
 
@@ -598,7 +668,6 @@ DistributeLayoutAttr LayoutAttr::collapseDims(SmallVector<int64_t> dimGroup) {
         orderVec.erase(orderVec.begin() + dimIdx,
                        orderVec.begin() + dimIdx + 1);
     }
-
     // say we have orderVec = {5, 3, 2, 1, 0}
     // Create indices [0, 1, 2, 3, 4]
     SmallVector<size_t> indices =
@@ -864,6 +933,20 @@ DistributeLayoutAttr SliceAttr::setDimData(int64_t dim, int64_t sgData,
       parent.setDimData(adjustDims[0], sgData, instData, laneData), getDims());
 }
 
+// Derive a new layout by removing dimensions.
+// `dimGroup` specifies a group of dimensions to be removed in the derived
+// layout.
+DistributeLayoutAttr SliceAttr::dropDims(SmallVector<int64_t> dimGroup) {
+  // Map the dims to be dropped from the sliced space to the parent space
+  SmallVector<int64_t> sliceDims = llvm::to_vector(getDims().asArrayRef());
+  SmallVector<int64_t> dimsInParentSpace =
+      mapSlicedDimsToParentSpace(dimGroup, sliceDims);
+
+  auto droppedParent = getParent().dropDims(dimsInParentSpace);
+  return SliceAttr::get(getContext(), droppedParent,
+                        DenseI64ArrayAttr::get(getContext(), sliceDims));
+}
+
 // Derive a new layout by collapsing dimensions.
 // `dimGroup` specifies a group of adjacent dimensions
 // that are collapsed into a single dimension in the derived layout.
@@ -871,12 +954,14 @@ DistributeLayoutAttr SliceAttr::collapseDims(SmallVector<int64_t> dimGroup) {
 
   // Map the sliced dims from parent space to collapsed space
   SmallVector<int64_t> sliceDims = llvm::to_vector(getDims().asArrayRef());
-
+  assert("expect sliceDims not being collapsed" &&
+         llvm::none_of(dimGroup, [&](int64_t dim) {
+           return llvm::is_contained(sliceDims, dim);
+         }));
   SmallVector<int64_t> dimsInParentSpace =
       mapSlicedDimsToParentSpace(dimGroup, sliceDims);
 
   auto collapsedParent = getParent().collapseDims(dimsInParentSpace);
-
   return SliceAttr::get(getContext(), collapsedParent,
                         DenseI64ArrayAttr::get(getContext(), sliceDims));
 }
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
index 7aa186bb22224..097c850e77b91 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
@@ -241,31 +241,19 @@ xegpu::DistributeLayoutAttr xegpu::inferInsertStridedSliceSourceLayout(
   int resShapeSize = resShape.size();
   int dimDiff = resShapeSize - srcShapeSize;
 
-  assert(isa<xegpu::LayoutAttr>(resLayout) &&
-         "insertStridedSlice result layout must be plain layout");
-  auto context = resLayout.getContext();
-  auto resInstData = resLayout.getEffectiveInstDataAsInt();
-  auto resLaneLayout = resLayout.getEffectiveLaneLayoutAsInt();
-  auto resLaneData = resLayout.getEffectiveLaneDataAsInt();
-
-  if (resInstData.size() != 0) {
-    SmallVector<int> inferredInstData(srcShapeSize);
-    for (int i = 0; i < srcShapeSize; i++)
-      inferredInstData[i] = resInstData[i + dimDiff];
-    return xegpu::LayoutAttr::get(context, inferredInstData);
-  }
-
-  if (resLaneLayout.size() != 0) {
-    SmallVector<int> inferredLaneLayout(srcShapeSize);
-    SmallVector<int> inferredLaneData(srcShapeSize);
-    for (int i = 0; i < srcShapeSize; i++) {
-      inferredLaneLayout[i] = resLaneLayout[i + dimDiff];
-      inferredLaneData[i] = resLaneData[i + dimDiff];
+  if (dimDiff > 0) {
+    // assert that the leading dimensions being sliced off are not distributed
+    // (i.e. sg_layout and lane_layout for those dimensions are all 1)
+    auto resSgLayout = resLayout.getEffectiveSgLayoutAsInt();
+    auto resLaneLayout = resLayout.getEffectiveLaneLayoutAsInt();
+    for (int i = 0; i < dimDiff; i++) {
+      assert((resSgLayout.size() == 0 || resSgLayout[i] == 1) &&
+             (resLaneLayout.size() == 0 || resLaneLayout[i] == 1) &&
+             "Leading dimensions being sliced off must not be distributed");
     }
-    return xegpu::LayoutAttr::get(context, inferredLaneLayout,
-                                  inferredLaneData);
+    return resLayout.dropDims(llvm::to_vector(llvm::seq<int64_t>(0, dimDiff)));
   }
-  return nullptr;
+  return resLayout;
 }
 
 /// Infers the source layout attribute for a shape cast operation given the
@@ -619,52 +607,54 @@ xegpu::DistributeLayoutAttr xegpu::setupBitCastResultLayout(
 ///      resShape=[4, 64], subgroupSize=16, bitwidth=16, packingFactor=2
 ///      consumerLayout: laneLayout=[1, 16], laneData=[1, 2]
 ///      Result: laneLayout=[1, 16], laneData=[1, 2] (adjusted for packed data)
-xegpu::DistributeLayoutAttr xegpu::setupInsertStridedSliceResultLayout(
-    xegpu::LayoutKind layoutKind, VectorType srcVectorTy,
-    VectorType resVectorTy, xegpu::DistributeLayoutAttr consumerLayout,
-    const xegpu::uArch::uArch *uArch) {
-
-  xegpu::DistributeLayoutAttr requiredResLayout;
-  auto subgroupSize = uArch->getSubgroupSize();
-  auto context = resVectorTy.getContext();
-  auto resShape = resVectorTy.getShape();
-  int resShapeSize = resShape.size();
-  auto srcShape = srcVectorTy.getShape();
-  SmallVector<int64_t> consumerInstData =
-      consumerLayout.getEffectiveInstDataAsInt();
-  SmallVector<int64_t> consumerLaneData =
-      consumerLayout.getEffectiveLaneDataAsInt();
-
-  SmallVector<int> instData(resShapeSize, 1);
-  SmallVector<int> laneLayout(resShapeSize, 1);
-  SmallVector<int> laneData(resShapeSize, 1);
-
-  const unsigned packingSize{uArch->getGeneralPackedFormatBitSize()};
-  unsigned bitwidth = resVectorTy.getElementType().getIntOrFloatBitWidth();
-  int packingFactor = bitwidth < packingSize ? packingSize / bitwidth : 1;
-  int packedDataSize = subgroupSize * packingFactor;
-
-  if (layoutKind == xegpu::LayoutKind::Subgroup) {
-    assert(true &&
-           "subgroup layout assignment not supported for insertStridedSlice.");
-  } else if (layoutKind == xegpu::LayoutKind::InstData) {
-    assert(srcShape.back() >= subgroupSize &&
-           "source innermost dim must be >= subgroupSize");
-    instData.back() = subgroupSize;
-    if (consumerInstData.back() == packedDataSize &&
-        srcShape.back() >= packedDataSize)
-      instData.back() = packedDataSize;
-    requiredResLayout = xegpu::LayoutAttr::get(context, instData);
-  } else if (layoutKind == xegpu::LayoutKind::Lane) {
-    laneLayout.back() = subgroupSize;
-    laneData.back() = 1;
-    if (consumerLaneData.back() == packingFactor &&
-        srcShape.back() >= packedDataSize)
-      laneData.back() = packingFactor;
-    requiredResLayout = xegpu::LayoutAttr::get(context, laneLayout, laneData);
-  }
-  return requiredResLayout;
-}
+// xegpu::DistributeLayoutAttr xegpu::setupInsertStridedSliceResultLayout(
+//     xegpu::LayoutKind layoutKind, VectorType srcVectorTy,
+//     VectorType resVectorTy, xegpu::DistributeLayoutAttr consumerLayout,
+//     const xegpu::uArch::uArch *uArch) {
+
+//   xegpu::DistributeLayoutAttr requiredResLayout;
+//   auto subgroupSize = uArch->getSubgroupSize();
+//   auto context = resVectorTy.getContext();
+//   auto resShape = resVectorTy.getShape();
+//   int resShapeSize = resShape.size();
+//   auto srcShape = srcVectorTy.getShape();
+//   SmallVector<int64_t> consumerInstData =
+//       consumerLayout.getEffectiveInstDataAsInt();
+//   SmallVector<int64_t> consumerLaneData =
+//       consumerLayout.getEffectiveLaneDataAsInt();
+
+//   SmallVector<int> instData(resShapeSize, 1);
+//   SmallVector<int> laneLayout(resShapeSize, 1);
+//   SmallVector<int> laneData(resShapeSize, 1);
+
+//   const unsigned packingSize{uArch->getGeneralPackedFormatBitSize()};
+//   unsigned bitwidth = resVectorTy.getElementType().getIntOrFloatBitWidth();
+//   int packingFactor = bitwidth < packingSize ? packingSize / bitwidth : 1;
+//   int packedDataSize = subgroupSize * packingFactor;
+
+//   if (layoutKind == xegpu::LayoutKind::Subgroup) {
+//     assert(true &&
+//            "subgroup layout assignment not supported for
+//            insertStridedSlice.");
+//   } else if (layoutKind == xegpu::LayoutKind::InstData) {
+//     assert(srcShape.back() >= subgroupSize &&
+//            "source innermost dim must be >= subgroupSize");
+//     instData.back() = subgroupSize;
+//     if (consumerInstData.back() == packedDataSize &&
+//         srcShape.back() >= packedDataSize)
+//       instData.back() = packedDataSize;
+//     requiredResLayout = xegpu::LayoutAttr::get(context, instData);
+//   } else if (layoutKind == xegpu::LayoutKind::Lane) {
+//     laneLayout.back() = subgroupSize;
+//     laneData.back() = 1;
+//     if (consumerLaneData.back() == packingFactor &&
+//         srcShape.back() >= packedDataSize)
+//       laneData.back() = packingFactor;
+//     requiredResLayout = xegpu::LayoutAttr::get(context, laneLayout,
+//     laneData);
+//   }
+//   return requiredResLayout;
+// }
 
 /// Sets up the anchor layout for load gather and load matrix operation.
 /// load matrix lowers to load gather and 1d block load. All of them share the
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 7f7e8d6ad7734..ea9faf1438033 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -969,18 +969,12 @@ void LayoutInfoPropagation::visitInsertStridedSliceOp(
   if (!uArch)
     return;
 
-  auto requiredResLayoutAttr = xegpu::setupInsertStridedSliceResultLayout(
-      layoutKind, srcVecType, resVecType, consumerLayoutAttr, uArch);
-
-  xegpu::setTemporaryLayout(insertStridedSlice->getResult(0),
-                            requiredResLayoutAttr);
-
   auto srcLayoutAttr = xegpu::inferInsertStridedSliceSourceLayout(
-      requiredResLayoutAttr, resVecType.getShape(), srcVecType.getShape());
+      consumerLayoutAttr, resVecType.getShape(), srcVecType.getShape());
 
   propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
   propagateIfChanged(operands[1],
-                     operands[1]->meet(LayoutInfo(requiredResLayoutAttr)));
+                     operands[1]->meet(LayoutInfo(consumerLayoutAttr)));
 }
 
 /// Propagate the layout of the result to the tensor descriptor, mask and offset
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir b/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
index 5dd05e6cb0001..bfffcf75c7306 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
@@ -238,9 +238,9 @@ func.func @scatter_ops_chunksize_slice(%src: memref<1024xf32>) {
 gpu.module @test {
 // CHECK-LABEL: func.func @insert_strided_slice_inst_data_no_packing(
 // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<8x32xf32>) {
-// CHECK: %[[CST_SMALL:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [1, 16]>} dense<1.000000e+00> : vector<4x16xf32>
-// CHECK: %[[CST_LARGE:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [1, 16]>} dense<0.000000e+00> : vector<8x32xf32>
-// CHECK: %[[INSERT:.*]] = vector.insert_strided_slice %[[CST_SMALL]], %[[CST_LARGE]] {layout_result_0 = #xegpu.layout<inst_data = [1, 16]>, offsets = [0, 0], strides = [1, 1]} : vector<4x16xf32> into vector<8x32xf32>
+// CHECK: %[[CST_SMALL:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [8, 16]>} dense<1.000000e+00> : vector<4x16xf32>
+// CHECK: %[[CST_LARGE:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [8, 16]>} dense<0.000000e+00> : vector<8x32xf32>
+// CHECK: %[[INSERT:.*]] = vector.insert_strided_slice %[[CST_SMALL]], %[[CST_LARGE]] {layout_result_0 = #xegpu.layout<inst_data = [8, 16]>, offsets = [0, 0], strides = [1, 1]} : vector<4x16xf32> into vector<8x32xf32>
 // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG0]][{{.*}}] : memref<8x32xf32> -> !xegpu.tensor_desc<8x32xf32, #xegpu.layout<inst_data = [8, 16]>>
 // CHECK: xegpu.store_nd %[[INSERT]], %[[TDESC]] <{layout = #xegpu.layout<inst_data = [8, 16]>}> : vector<8x32xf32>, !xegpu.tensor_desc<8x32xf32, #xegpu.layout<inst_data = [8, 16]>>
 func.func @insert_strided_slice_inst_data_no_packing(%arg0: memref<8x32xf32>) {
@@ -258,9 +258,9 @@ func.func @insert_strided_slice_inst_data_no_packing(%arg0: memref<8x32xf32>) {
 gpu.module @test {
 // CHECK-LABEL: func.func @insert_strided_slice_inst_data_with_packing(
 // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<8x64xi8>) {
-// CHECK: %[[CST_SMALL:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [1, 64]>} dense<1> : vector<4x64xi8>
-// CHECK: %[[CST_LARGE:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [1, 64]>} dense<0> : vector<8x64xi8>
-// CHECK: %[[INSERT:.*]] = vector.insert_strided_slice %[[CST_SMALL]], %[[CST_LARGE]] {layout_result_0 = #xegpu.layout<inst_data = [1, 64]>, offsets = [0, 0], strides = [1, 1]} : vector<4x64xi8> into vector<8x64xi8>
+// CHECK: %[[CST_SMALL:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [8, 64]>} dense<1> : vector<4x64xi8>
+// CHECK: %[[CST_LARGE:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [8, 64]>} dense<0> : vector<8x64xi8>
+// CHECK: %[[INSERT:.*]] = vector.insert_strided_slice %[[CST_SMALL]], %[[CST_LARGE]] {layout_result_0 = #xegpu.layout<inst_data = [8, 64]>, offsets = [0, 0], strides = [1, 1]} : vector<4x64xi8> into vector<8x64xi8>
 func.func @insert_strided_slice_inst_data_with_packing(%arg0: memref<8x64xi8>) {
   %c0 = arith.constant 0 : index
   %cst_small = arith.constant dense<1> : vector<4x64xi8>

>From bcb1e63828e230ca7dd36581db4cc25947ca0293 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Thu, 5 Mar 2026 02:43:57 +0000
Subject: [PATCH 2/7] add test

---
 mlir/test/Dialect/XeGPU/propagate-layout.mlir | 24 +++++++++++++++++++
 1 file changed, 24 insertions(+)

diff --git a/mlir/test/Dialect/XeGPU/propagate-layout.mlir b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
index 4f2349a89b1ed..6ddb922ccd78a 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
@@ -792,6 +792,30 @@ func.func @insert_strided_slice_lane_layout_with_packing(%arg0: memref<4x64xf16>
 }
 }
 
+// -----
+gpu.module @test {
+// CHECK-LABEL: func.func @insert_strided_slice_with_slice_layout(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<8x16xf32>) {
+// CHECK: %[[CST:.*]] = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>, dims = [0]>} dense<1.000000e+00> : vector<1xf32>
+// CHECK: %[[CST_0:.*]] = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>, dims = [0]>} dense<1.000000e+00> : vector<16xf32>
+// CHECK: %[[INSERT:.*]] = vector.insert_strided_slice %[[CST]], %[[CST_0]] {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>, dims = [0]>, offsets = [15], strides = [1]} : vector<1xf32> into vector<16xf32>
+// CHECK: %[[EXTRACT:.*]] = vector.extract_strided_slice %[[INSERT]] {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>, dims = [0]>, offsets = [0], sizes = [8], strides = [1]} : vector<16xf32> to vector<8xf32>
+// CHECK: %[[BROADCAST:.*]] = vector.broadcast %[[EXTRACT]] {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>} : vector<8xf32> to vector<16x8xf32>
+// CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[BROADCAST]], [1, 0] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16x8xf32> to vector<8x16xf32>
+func.func @insert_strided_slice_with_slice_layout(%arg0: memref<8x16xf32>) {
+  %c0 = arith.constant 0 : index
+  %cst_small = arith.constant dense<1.0> : vector<1xf32>
+  %cst_large = arith.constant dense<1.0> : vector<16xf32>
+  %cst_large_new = vector.insert_strided_slice %cst_small, %cst_large {offsets = [15], strides = [1]} : vector<1xf32> into vector<16xf32>
+  %cst_small8 = vector.extract_strided_slice %cst_large_new {offsets = [0], sizes = [8], strides = [1]} : vector<16xf32> to vector<8xf32>
+  %cst_small16x8 = vector.broadcast %cst_small8 : vector<8xf32> to vector<16x8xf32>
+  %cst_small8x16 = vector.transpose %cst_small16x8, [1, 0] : vector<16x8xf32> to vector<8x16xf32>
+  %tdesc = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+  xegpu.store_nd %cst_small8x16, %tdesc <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}>: vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+  return
+}
+}
+
 // -----
 gpu.module @test{
   // CHECK-LABEL: load_store_matrix

>From 8288b7e8b04d7a92e6d277d2838c4529d952e71f Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Thu, 5 Mar 2026 04:53:36 +0000
Subject: [PATCH 3/7] rewrite setupInsertStridedSliceResultLayout

---
 .../XeGPU/Transforms/XeGPULayoutImpl.h        |   6 +-
 .../XeGPU/Transforms/XeGPULayoutImpl.cpp      | 142 +++++++-----------
 .../XeGPU/Transforms/XeGPUPropagateLayout.cpp |  10 +-
 .../XeGPU/propagate-layout-inst-data.mlir     |   4 +-
 4 files changed, 64 insertions(+), 98 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h b/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h
index 5ccb69e0a89dd..3482d1b9401bb 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h
@@ -132,9 +132,9 @@ DistributeLayoutAttr setupBitCastResultLayout(
 /// Sets up the result layout for an insert strided slice operation.
 /// Creates a result layout based on the specified layout kind (InstData or
 /// Lane).
-// DistributeLayoutAttr setupInsertStridedSliceResultLayout(
-//     LayoutKind layoutKind, VectorType srcVectorTy, VectorType resVectorTy,
-//     DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch);
+DistributeLayoutAttr setupInsertStridedSliceResultLayout(
+    LayoutKind layoutKind, VectorType srcVectorTy, VectorType resVectorTy,
+    DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch);
 
 /// Sets up the anchor layout for a load gather operation.
 DistributeLayoutAttr
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
index 097c850e77b91..89445a9885b7c 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
@@ -532,11 +532,12 @@ xegpu::DistributeLayoutAttr xegpu::setupBitCastResultLayout(
   SmallVector<int64_t> sgData = consumerLayout.getEffectiveSgDataAsInt();
   SmallVector<int64_t> instData = consumerLayout.getEffectiveInstDataAsInt();
   SmallVector<int64_t> laneData = consumerLayout.getEffectiveLaneDataAsInt();
+  assert(consumerLayout.getRank() == static_cast<int64_t>(srcShape.size()) &&
+         "laneData must be available for all dimensions");
   size_t dim = srcShape.size() - 1;
   int64_t sgDataValue = -1;
   int64_t instDataValue = -1;
   int64_t laneDataValue = -1;
-
   const int subgroupSize = uArch->getSubgroupSize();
 
   if (srcElemTyBitWidth > resElemTyBitWidth) {
@@ -546,12 +547,8 @@ xegpu::DistributeLayoutAttr xegpu::setupBitCastResultLayout(
     int bitWidthRatio = srcElemTyBitWidth / resElemTyBitWidth;
     int innermostDimLaneLayout = subgroupSize;
     if (layoutKind == xegpu::LayoutKind::Subgroup) {
-      assert(sgData.size() == srcShape.size() &&
-             "sgData must be available for all dimensions");
       sgDataValue = sgData[dim];
     } else if (layoutKind == xegpu::LayoutKind::InstData) {
-      assert(instData.size() == srcShape.size() &&
-             "instData must be available for all dimensions");
       instDataValue = instData[dim];
       // Adjust instDataValue so it still fits within an instruction after
       // dividing by bitWidthRatio
@@ -561,8 +558,6 @@ xegpu::DistributeLayoutAttr xegpu::setupBitCastResultLayout(
       assert((srcShape[dim] % instDataValue) == 0 &&
              "srcShape, instData, and lanelayout for innermost must be 2^n !");
     } else if (layoutKind == xegpu::LayoutKind::Lane) {
-      assert(laneData.size() == srcShape.size() &&
-             "laneData must be available for all dimensions");
       laneDataValue = laneData[dim];
       while ((laneDataValue <= srcShape[dim]) &&
              (laneDataValue % bitWidthRatio != 0))
@@ -580,81 +575,47 @@ xegpu::DistributeLayoutAttr xegpu::setupBitCastResultLayout(
 /// Sets up the result layout for an insert strided slice operation.
 /// Creates a result layout based on the specified layout kind (InstData or
 /// Lane).
-/// Subgroup layout is currently not supported for this operation.
-/// InstData layout is first set to be {1, .., subgroupSize}.
-/// Lane layout is first set to be {1, ..., subgroupSize} with lane data {1,
-/// ..., 1}. The instData and laneData is then adjusted to contain packed data,
-/// by checking if the consumerLayout's innermost dimension.
-///
-/// Examples:
-///   1. InstData layout without packing:
-///      resShape=[8, 32], subgroupSize=16, bitwidth=32
-///      packingFactor=1, packedDataSize=16
-///      consumerLayout: instData=[1, 16]
-///      Result: instData=[1, 16]
-///
-///   2. InstData layout with packing:
-///      resShape=[8, 64], subgroupSize=16, bitwidth=8, packingFactor=4
-///      consumerLayout: instData=[1, 64]
-///      Result: instData=[1, 64] (adjusted for packed data)
-///
-///   3. Lane layout without packing:
-///      resShape=[4, 64], subgroupSize=16, bitwidth=32
-///      consumerLayout: laneLayout=[1, 16], laneData=[1, 1]
-///      Result: laneLayout=[1, 16], laneData=[1, 1]
-///
-///   4. Lane layout with packing:
-///      resShape=[4, 64], subgroupSize=16, bitwidth=16, packingFactor=2
-///      consumerLayout: laneLayout=[1, 16], laneData=[1, 2]
-///      Result: laneLayout=[1, 16], laneData=[1, 2] (adjusted for packed data)
-// xegpu::DistributeLayoutAttr xegpu::setupInsertStridedSliceResultLayout(
-//     xegpu::LayoutKind layoutKind, VectorType srcVectorTy,
-//     VectorType resVectorTy, xegpu::DistributeLayoutAttr consumerLayout,
-//     const xegpu::uArch::uArch *uArch) {
-
-//   xegpu::DistributeLayoutAttr requiredResLayout;
-//   auto subgroupSize = uArch->getSubgroupSize();
-//   auto context = resVectorTy.getContext();
-//   auto resShape = resVectorTy.getShape();
-//   int resShapeSize = resShape.size();
-//   auto srcShape = srcVectorTy.getShape();
-//   SmallVector<int64_t> consumerInstData =
-//       consumerLayout.getEffectiveInstDataAsInt();
-//   SmallVector<int64_t> consumerLaneData =
-//       consumerLayout.getEffectiveLaneDataAsInt();
-
-//   SmallVector<int> instData(resShapeSize, 1);
-//   SmallVector<int> laneLayout(resShapeSize, 1);
-//   SmallVector<int> laneData(resShapeSize, 1);
-
-//   const unsigned packingSize{uArch->getGeneralPackedFormatBitSize()};
-//   unsigned bitwidth = resVectorTy.getElementType().getIntOrFloatBitWidth();
-//   int packingFactor = bitwidth < packingSize ? packingSize / bitwidth : 1;
-//   int packedDataSize = subgroupSize * packingFactor;
-
-//   if (layoutKind == xegpu::LayoutKind::Subgroup) {
-//     assert(true &&
-//            "subgroup layout assignment not supported for
-//            insertStridedSlice.");
-//   } else if (layoutKind == xegpu::LayoutKind::InstData) {
-//     assert(srcShape.back() >= subgroupSize &&
-//            "source innermost dim must be >= subgroupSize");
-//     instData.back() = subgroupSize;
-//     if (consumerInstData.back() == packedDataSize &&
-//         srcShape.back() >= packedDataSize)
-//       instData.back() = packedDataSize;
-//     requiredResLayout = xegpu::LayoutAttr::get(context, instData);
-//   } else if (layoutKind == xegpu::LayoutKind::Lane) {
-//     laneLayout.back() = subgroupSize;
-//     laneData.back() = 1;
-//     if (consumerLaneData.back() == packingFactor &&
-//         srcShape.back() >= packedDataSize)
-//       laneData.back() = packingFactor;
-//     requiredResLayout = xegpu::LayoutAttr::get(context, laneLayout,
-//     laneData);
-//   }
-//   return requiredResLayout;
-// }
+xegpu::DistributeLayoutAttr xegpu::setupInsertStridedSliceResultLayout(
+    xegpu::LayoutKind layoutKind, VectorType srcVectorTy,
+    VectorType resVectorTy, xegpu::DistributeLayoutAttr consumerLayout,
+    const xegpu::uArch::uArch *uArch) {
+  // Start from the consumer layout; only inst/lane data is clamped below.
+  xegpu::DistributeLayoutAttr requiredResLayout = consumerLayout;
+  ArrayRef<int64_t> srcShape = srcVectorTy.getShape();
+  int srcRank = srcShape.size();
+  // vector.insert_strided_slice allows a source of lower rank than the result;
+  // source dim `d` then corresponds to result (and layout) dim `d + rankDiff`.
+  int rankDiff = static_cast<int>(resVectorTy.getRank()) - srcRank;
+  assert(rankDiff >= 0 && "source rank must not exceed result rank");
+  if (layoutKind == xegpu::LayoutKind::Subgroup) {
+    assert(true &&
+           "subgroup layout assignment not supported for insertStridedSlice.");
+  } else if (layoutKind == xegpu::LayoutKind::InstData) {
+    SmallVector<int64_t> consumerInstData =
+        consumerLayout.getEffectiveInstDataAsInt();
+    for (int dim = 0; dim < srcRank; dim++) {
+      int64_t instDataValue =
+          std::min(srcShape[dim], consumerInstData[dim + rankDiff]);
+      requiredResLayout = requiredResLayout.setDimData(dim + rankDiff, -1,
+                                                       instDataValue, -1);
+    }
+  } else if (layoutKind == xegpu::LayoutKind::Lane) {
+    SmallVector<int64_t> consumerLaneData =
+        consumerLayout.getEffectiveLaneDataAsInt();
+    SmallVector<int64_t> consumerLaneLayout =
+        consumerLayout.getEffectiveLaneLayoutAsInt();
+    for (int dim = 0; dim < srcRank; dim++) {
+      int64_t laneLayout = consumerLaneLayout[dim + rankDiff];
+      assert(srcShape[dim] % laneLayout == 0 &&
+             "srcShape must be divisible by laneLayout for all dimensions");
+      int64_t laneDataValue = std::min(srcShape[dim] / laneLayout,
+                                       consumerLaneData[dim + rankDiff]);
+      requiredResLayout = requiredResLayout.setDimData(dim + rankDiff, -1, -1,
+                                                       laneDataValue);
+    }
+  }
+  return requiredResLayout;
+}
 
 /// Sets up the anchor layout for load gather and load matrix operation.
 /// load matrix lowers to load gather and 1d block load. All of them share the
@@ -849,8 +810,8 @@ xegpu::setupStoreMatrixAnchorLayout(xegpu::LayoutKind layoutKind,
 }
 
 // This function returns the default lane layout for a given vector type.
-// - `packingSize` means multiple consecutive elements can be accessed together
-// as a single unit.
+// - `packingSize` means multiple consecutive elements can be accessed
+// together as a single unit.
 // - `vnni` means data packing is column-wise (i.e., 2x1xf16 with vnni vs.
 // 1x2xf16 w/o vnni).
 template <typename RankedTy>
@@ -915,7 +876,8 @@ getValidLayouts(ArrayRef<int64_t> wgShape, ArrayRef<int64_t> instData,
 }
 
 /// Sets up the anchor layouts for dpas operands (A, B, and C/D).
-/// The numSg and consumerLayout (optional) are only used by sg layout creation.
+/// The numSg and consumerLayout (optional) are only used by sg layout
+/// creation.
 std::optional<
     std::tuple<xegpu::DistributeLayoutAttr, xegpu::DistributeLayoutAttr,
                xegpu::DistributeLayoutAttr>>
@@ -1001,9 +963,9 @@ xegpu::setupDpasLayout(xegpu::LayoutKind layoutKind, VectorType aTy,
           break;
         }
         // Is in (A and B and CD) layoutsB is ordered from most
-        // balanced to least. So the first one we see is the most balanced one,
-        // remember it and later only update if there is one that matches the
-        // consumer.
+        // balanced to least. So the first one we see is the most balanced
+        // one, remember it and later only update if there is one that matches
+        // the consumer.
         if (!bestPick)
           bestPick = sgLayout;
       }
@@ -1142,7 +1104,7 @@ xegpu::DistributeLayoutAttr xegpu::getConsumerLayoutAt(OpOperand &operand) {
     return resLayout;
   }
   // TODO: Handle more cases as needed here.
-  // By default, assume no layout conflict and return the current layout of the
-  // operand.
+  // By default, assume no layout conflict and return the current layout of
+  // the operand.
   return xegpu::getDistributeLayoutAttr(operand.get());
 }
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index ea9faf1438033..4eac49525b832 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -969,12 +969,16 @@ void LayoutInfoPropagation::visitInsertStridedSliceOp(
   if (!uArch)
     return;
 
-  auto srcLayoutAttr = xegpu::inferInsertStridedSliceSourceLayout(
-      consumerLayoutAttr, resVecType.getShape(), srcVecType.getShape());
+  auto requiredResLayoutAttr = xegpu::setupInsertStridedSliceResultLayout(
+      layoutKind, srcVecType, resVecType, consumerLayoutAttr, uArch);
+  xegpu::setTemporaryLayout(insertStridedSlice->getResult(0),
+                            requiredResLayoutAttr);
 
+  auto srcLayoutAttr = xegpu::inferInsertStridedSliceSourceLayout(
+      requiredResLayoutAttr, resVecType.getShape(), srcVecType.getShape());
   propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
   propagateIfChanged(operands[1],
-                     operands[1]->meet(LayoutInfo(consumerLayoutAttr)));
+                     operands[1]->meet(LayoutInfo(requiredResLayoutAttr)));
 }
 
 /// Propagate the layout of the result to the tensor descriptor, mask and offset
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir b/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
index bfffcf75c7306..b0523229a4640 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
@@ -238,7 +238,7 @@ func.func @scatter_ops_chunksize_slice(%src: memref<1024xf32>) {
 gpu.module @test {
 // CHECK-LABEL: func.func @insert_strided_slice_inst_data_no_packing(
 // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<8x32xf32>) {
-// CHECK: %[[CST_SMALL:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [8, 16]>} dense<1.000000e+00> : vector<4x16xf32>
+// CHECK: %[[CST_SMALL:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [4, 16]>} dense<1.000000e+00> : vector<4x16xf32>
 // CHECK: %[[CST_LARGE:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [8, 16]>} dense<0.000000e+00> : vector<8x32xf32>
 // CHECK: %[[INSERT:.*]] = vector.insert_strided_slice %[[CST_SMALL]], %[[CST_LARGE]] {layout_result_0 = #xegpu.layout<inst_data = [8, 16]>, offsets = [0, 0], strides = [1, 1]} : vector<4x16xf32> into vector<8x32xf32>
 // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG0]][{{.*}}] : memref<8x32xf32> -> !xegpu.tensor_desc<8x32xf32, #xegpu.layout<inst_data = [8, 16]>>
@@ -258,7 +258,7 @@ func.func @insert_strided_slice_inst_data_no_packing(%arg0: memref<8x32xf32>) {
 gpu.module @test {
 // CHECK-LABEL: func.func @insert_strided_slice_inst_data_with_packing(
 // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<8x64xi8>) {
-// CHECK: %[[CST_SMALL:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [8, 64]>} dense<1> : vector<4x64xi8>
+// CHECK: %[[CST_SMALL:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [4, 64]>} dense<1> : vector<4x64xi8>
 // CHECK: %[[CST_LARGE:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [8, 64]>} dense<0> : vector<8x64xi8>
 // CHECK: %[[INSERT:.*]] = vector.insert_strided_slice %[[CST_SMALL]], %[[CST_LARGE]] {layout_result_0 = #xegpu.layout<inst_data = [8, 64]>, offsets = [0, 0], strides = [1, 1]} : vector<4x64xi8> into vector<8x64xi8>
 func.func @insert_strided_slice_inst_data_with_packing(%arg0: memref<8x64xi8>) {

>From 36eb877956f3b504cae2dd64b06b8cfd6ec51efb Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Thu, 5 Mar 2026 05:39:07 +0000
Subject: [PATCH 4/7] fix test

---
 mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir b/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
index b0523229a4640..c510a1d5f0fdf 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
@@ -239,8 +239,8 @@ gpu.module @test {
 // CHECK-LABEL: func.func @insert_strided_slice_inst_data_no_packing(
 // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<8x32xf32>) {
 // CHECK: %[[CST_SMALL:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [4, 16]>} dense<1.000000e+00> : vector<4x16xf32>
-// CHECK: %[[CST_LARGE:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [8, 16]>} dense<0.000000e+00> : vector<8x32xf32>
-// CHECK: %[[INSERT:.*]] = vector.insert_strided_slice %[[CST_SMALL]], %[[CST_LARGE]] {layout_result_0 = #xegpu.layout<inst_data = [8, 16]>, offsets = [0, 0], strides = [1, 1]} : vector<4x16xf32> into vector<8x32xf32>
+// CHECK: %[[CST_LARGE:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [4, 16]>} dense<0.000000e+00> : vector<8x32xf32>
+// CHECK: %[[INSERT:.*]] = vector.insert_strided_slice %[[CST_SMALL]], %[[CST_LARGE]] {layout_result_0 = #xegpu.layout<inst_data = [4, 16]>, offsets = [0, 0], strides = [1, 1]} : vector<4x16xf32> into vector<8x32xf32>
 // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG0]][{{.*}}] : memref<8x32xf32> -> !xegpu.tensor_desc<8x32xf32, #xegpu.layout<inst_data = [8, 16]>>
 // CHECK: xegpu.store_nd %[[INSERT]], %[[TDESC]] <{layout = #xegpu.layout<inst_data = [8, 16]>}> : vector<8x32xf32>, !xegpu.tensor_desc<8x32xf32, #xegpu.layout<inst_data = [8, 16]>>
 func.func @insert_strided_slice_inst_data_no_packing(%arg0: memref<8x32xf32>) {
@@ -259,8 +259,8 @@ gpu.module @test {
 // CHECK-LABEL: func.func @insert_strided_slice_inst_data_with_packing(
 // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<8x64xi8>) {
 // CHECK: %[[CST_SMALL:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [4, 64]>} dense<1> : vector<4x64xi8>
-// CHECK: %[[CST_LARGE:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [8, 64]>} dense<0> : vector<8x64xi8>
-// CHECK: %[[INSERT:.*]] = vector.insert_strided_slice %[[CST_SMALL]], %[[CST_LARGE]] {layout_result_0 = #xegpu.layout<inst_data = [8, 64]>, offsets = [0, 0], strides = [1, 1]} : vector<4x64xi8> into vector<8x64xi8>
+// CHECK: %[[CST_LARGE:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [4, 64]>} dense<0> : vector<8x64xi8>
+// CHECK: %[[INSERT:.*]] = vector.insert_strided_slice %[[CST_SMALL]], %[[CST_LARGE]] {layout_result_0 = #xegpu.layout<inst_data = [4, 64]>, offsets = [0, 0], strides = [1, 1]} : vector<4x64xi8> into vector<8x64xi8>
 func.func @insert_strided_slice_inst_data_with_packing(%arg0: memref<8x64xi8>) {
   %c0 = arith.constant 0 : index
   %cst_small = arith.constant dense<1> : vector<4x64xi8>

>From dee54e3c0e8ee7d3b3d67b9873a94e209d67852c Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Fri, 6 Mar 2026 03:00:54 +0000
Subject: [PATCH 5/7] address feedback

---
 mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 127 +++++++--------------
 1 file changed, 42 insertions(+), 85 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 76b7a690288a3..4c434a35755a8 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -516,23 +516,16 @@ DistributeLayoutAttr LayoutAttr::dropDims(SmallVector<int64_t> dimGroup) {
   SmallVector<int64_t> sortedDimGroup = dimGroup;
   llvm::sort(sortedDimGroup);
 
-  if (!sgLayout.empty()) {
-    for (auto dimIdx : llvm::reverse(sortedDimGroup)) {
-      sgLayout.erase(sgLayout.begin() + dimIdx, sgLayout.begin() + dimIdx + 1);
-      sgData.erase(sgData.begin() + dimIdx, sgData.begin() + dimIdx + 1);
+  for (auto dimIdx : llvm::reverse(sortedDimGroup)) {
+    if (!sgLayout.empty()) {
+      sgLayout.erase(sgLayout.begin() + dimIdx);
+      sgData.erase(sgData.begin() + dimIdx);
     }
-  }
-
-  if (!instData.empty()) {
-    for (auto dimIdx : llvm::reverse(sortedDimGroup))
-      instData.erase(instData.begin() + dimIdx, instData.begin() + dimIdx + 1);
-  }
-
-  if (!laneLayout.empty()) {
-    for (auto dimIdx : llvm::reverse(sortedDimGroup)) {
-      laneLayout.erase(laneLayout.begin() + dimIdx,
-                       laneLayout.begin() + dimIdx + 1);
-      laneData.erase(laneData.begin() + dimIdx, laneData.begin() + dimIdx + 1);
+    if (!instData.empty())
+      instData.erase(instData.begin() + dimIdx);
+    if (!laneLayout.empty()) {
+      laneLayout.erase(laneLayout.begin() + dimIdx);
+      laneData.erase(laneData.begin() + dimIdx);
     }
   }
 
@@ -543,31 +536,18 @@ DistributeLayoutAttr LayoutAttr::dropDims(SmallVector<int64_t> dimGroup) {
     int64_t offset = llvm::count_if(dimGroup, [&](int64_t s) { return s < d; });
     newOrder.push_back(d - offset);
   }
-
-  // Create dropped layout
-  SmallVector<int32_t> sgLayout32(sgLayout.begin(), sgLayout.end());
-  SmallVector<int32_t> sgData32(sgData.begin(), sgData.end());
-  SmallVector<int32_t> instData32(instData.begin(), instData.end());
-  SmallVector<int32_t> laneLayout32(laneLayout.begin(), laneLayout.end());
-  SmallVector<int32_t> laneData32(laneData.begin(), laneData.end());
-  SmallVector<int32_t> newOrder32(newOrder.begin(), newOrder.end());
-
-  DenseI32ArrayAttr orderAttr = getOrder();
+  if (sgLayout.empty() && laneLayout.empty())
+    newOrder.clear();
+
+  auto toAttr = [&](ArrayRef<int64_t> v) -> DenseI32ArrayAttr {
+    if (v.empty())
+      return DenseI32ArrayAttr();
+    SmallVector<int32_t> v32(v.begin(), v.end());
+    return DenseI32ArrayAttr::get(getContext(), v32);
+  };
   auto droppedLayout = xegpu::LayoutAttr::get(
-      getContext(),
-      sgLayout32.empty() ? DenseI32ArrayAttr()
-                         : DenseI32ArrayAttr::get(getContext(), sgLayout32),
-      sgData32.empty() ? DenseI32ArrayAttr()
-                       : DenseI32ArrayAttr::get(getContext(), sgData32),
-      instData32.empty() ? DenseI32ArrayAttr()
-                         : DenseI32ArrayAttr::get(getContext(), instData32),
-      laneLayout32.empty() ? DenseI32ArrayAttr()
-                           : DenseI32ArrayAttr::get(getContext(), laneLayout32),
-      laneData32.empty() ? DenseI32ArrayAttr()
-                         : DenseI32ArrayAttr::get(getContext(), laneData32),
-      (!orderAttr || orderAttr.empty())
-          ? DenseI32ArrayAttr()
-          : DenseI32ArrayAttr::get(getContext(), newOrder32));
+      getContext(), toAttr(sgLayout), toAttr(sgData), toAttr(instData),
+      toAttr(laneLayout), toAttr(laneData), toAttr(newOrder));
   return droppedLayout;
 }
 
@@ -581,26 +561,19 @@ DistributeLayoutAttr LayoutAttr::collapseDims(SmallVector<int64_t> dimGroup) {
   SmallVector<int64_t> instData = getEffectiveInstDataAsInt();
   SmallVector<int64_t> laneLayout = getEffectiveLaneLayoutAsInt();
   SmallVector<int64_t> laneData = getEffectiveLaneDataAsInt();
-
-  DenseI32ArrayAttr orderAttr = getOrder();
-  SmallVector<int32_t> orderVec;
-  if (orderAttr && !orderAttr.empty()) {
-    orderVec = llvm::to_vector(
-        llvm::map_range(orderAttr.asArrayRef(),
-                        [](int32_t idx) { return static_cast<int32_t>(idx); }));
-  }
+  SmallVector<int64_t> origOrder = getEffectiveOrderAsInt();
 
   SmallVector<int64_t> sortedDimGroup = dimGroup;
   llvm::sort(sortedDimGroup);
   int64_t dimBeforeCurrent = -1;
   for (auto dimIdx : sortedDimGroup) {
-    // when order is present, adjacency dims are values like [3, 2, 1, 0]
+    // when order attr is present, adjacency dims are values like [3, 2, 1, 0]
     // in decreasing order; otherwise based on dim indices like [0, 1, 2, 3]
     // in increasing order
     if (dimBeforeCurrent >= 0) {
-      if (!orderVec.empty()) {
-        int64_t orderBefore = orderVec[dimBeforeCurrent];
-        int64_t orderCurrent = orderVec[dimIdx];
+      if (getOrder() && !getOrder().empty()) {
+        int64_t orderBefore = origOrder[dimBeforeCurrent];
+        int64_t orderCurrent = origOrder[dimIdx];
         if (orderBefore != (orderCurrent - 1))
           llvm::report_fatal_error(
               "dimensions being collapsed must be adjacent in order");
@@ -656,52 +629,36 @@ DistributeLayoutAttr LayoutAttr::collapseDims(SmallVector<int64_t> dimGroup) {
     laneData.insert(laneData.begin() + firstDim, collapsedLaneData);
   }
 
-  // After collapsing the order vector, re-map the order values to be in range
-  // of [0, N-1] where N is the number of dimensions in collapsed shape. For
-  // exmaple, collapse dim group {2, 3} of order[1, 2, 3, 4] to new order[1, 3,
-  // 4]. the loop below remaps it to [1, 2, 3].
-  SmallVector<int32_t> collapsedOrder;
-  if (!orderVec.empty()) {
+  SmallVector<int64_t> newOrder;
+  DenseI32ArrayAttr orderAttr = getOrder();
+  if (orderAttr && !orderAttr.empty()) {
 
     for (auto dimIdx : llvm::reverse(sortedDimGroup)) {
       if (dimIdx != firstDim)
-        orderVec.erase(orderVec.begin() + dimIdx,
-                       orderVec.begin() + dimIdx + 1);
+        origOrder.erase(origOrder.begin() + dimIdx);
     }
     // say we have orderVec = {5, 3, 2, 1, 0}
     // Create indices [0, 1, 2, 3, 4]
     SmallVector<size_t> indices =
-        llvm::to_vector(llvm::seq<size_t>(0, orderVec.size()));
+        llvm::to_vector(llvm::seq<size_t>(0, origOrder.size()));
 
     // Sort indices based on corresponding values
     llvm::sort(indices,
-               [&](size_t a, size_t b) { return orderVec[a] < orderVec[b]; });
-    collapsedOrder = llvm::to_vector(llvm::map_range(
-        indices, [&](size_t i) { return static_cast<int32_t>(i); }));
-  }
+               [&](size_t a, size_t b) { return origOrder[a] < origOrder[b]; });
 
-  // Create collapsed layout
-  SmallVector<int32_t> sgLayout32(sgLayout.begin(), sgLayout.end());
-  SmallVector<int32_t> sgData32(sgData.begin(), sgData.end());
-  SmallVector<int32_t> instData32(instData.begin(), instData.end());
-  SmallVector<int32_t> laneLayout32(laneLayout.begin(), laneLayout.end());
-  SmallVector<int32_t> laneData32(laneData.begin(), laneData.end());
+    newOrder = llvm::to_vector(llvm::map_range(
+        indices, [&](size_t i) { return static_cast<int64_t>(i); }));
+  }
 
+  auto toAttr = [&](ArrayRef<int64_t> v) -> DenseI32ArrayAttr {
+    if (v.empty())
+      return DenseI32ArrayAttr();
+    SmallVector<int32_t> v32(v.begin(), v.end());
+    return DenseI32ArrayAttr::get(getContext(), v32);
+  };
   auto collapsedLayout = xegpu::LayoutAttr::get(
-      getContext(),
-      sgLayout32.empty() ? DenseI32ArrayAttr()
-                         : DenseI32ArrayAttr::get(getContext(), sgLayout32),
-      sgData32.empty() ? DenseI32ArrayAttr()
-                       : DenseI32ArrayAttr::get(getContext(), sgData32),
-      instData32.empty() ? DenseI32ArrayAttr()
-                         : DenseI32ArrayAttr::get(getContext(), instData32),
-      laneLayout32.empty() ? DenseI32ArrayAttr()
-                           : DenseI32ArrayAttr::get(getContext(), laneLayout32),
-      laneData32.empty() ? DenseI32ArrayAttr()
-                         : DenseI32ArrayAttr::get(getContext(), laneData32),
-      collapsedOrder.empty()
-          ? DenseI32ArrayAttr()
-          : DenseI32ArrayAttr::get(getContext(), collapsedOrder));
+      getContext(), toAttr(sgLayout), toAttr(sgData), toAttr(instData),
+      toAttr(laneLayout), toAttr(laneData), toAttr(newOrder));
   return collapsedLayout;
 }
 
 

>From d2eaeba6fa48f1ccc70b0895b121724d46941e42 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Fri, 6 Mar 2026 03:41:59 +0000
Subject: [PATCH 6/7] fix issue in sliceAttr::dropdims

---
 mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 29 +++++++++++++++++++---
 1 file changed, 25 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 4c434a35755a8..9d99d402637fb 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -890,9 +890,23 @@ DistributeLayoutAttr SliceAttr::setDimData(int64_t dim, int64_t sgData,
       parent.setDimData(adjustDims[0], sgData, instData, laneData), getDims());
 }
 
-// Derive a new layout by removing dimensions.
-// `dimGroup` specifies a group of dimensions to be removed in the derived
-// layout.
+// Derive a new layout by removing dimensions. `dimGroup` specifies a group of
+// dimensions to be removed in the derived layout.
+//
+// Example: drop the 2nd dimension from a rank-3 sliced view.
+//
+// Suppose:
+//   xegpu.layout = slice<layout<[V0, V1, V2, V3, V4]>, [1, 3]>
+//
+// The slice removes parent dims [1, 3], so the sliced-space dims map to
+// parent dims [V0, V2, V4].
+//
+// If we drop sliced-space dim 1 (the 2nd dim), that corresponds to dropping
+// parent dim 2, result in parent layout [V0, V1, V3, V4] after dropping.
+// After parent dim 2 is removed, sliced dims [1, 3] must be reindexed to [1, 2].
+//
+// Result:
+//   xegpu.layout = slice<layout<[0, 1, 3, 4]>, [1, 2]>
 DistributeLayoutAttr SliceAttr::dropDims(SmallVector<int64_t> dimGroup) {
   // Map the sliced dims from parent space to collapsed space
   SmallVector<int64_t> sliceDims = llvm::to_vector(getDims().asArrayRef());
@@ -900,8 +914,15 @@ DistributeLayoutAttr SliceAttr::dropDims(SmallVector<int64_t> dimGroup) {
       mapSlicedDimsToParentSpace(dimGroup, sliceDims);
 
   auto droppedParent = getParent().dropDims(dimsInParentSpace);
+
+  SmallVector<int64_t> newSliceDims;
+  for (int64_t d : sliceDims) {
+    int64_t offset = llvm::count_if(dimsInParentSpace, [&](int64_t s) { return s < d; });
+    newSliceDims.push_back(d - offset);
+  }
+
   return SliceAttr::get(getContext(), droppedParent,
-                        DenseI64ArrayAttr::get(getContext(), sliceDims));
+                        DenseI64ArrayAttr::get(getContext(), newSliceDims));
 }
 
 // Derive a new layout by collapsing dimensions.

>From 4253006d91cdb718376986587c8d38dc99a29794 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Fri, 6 Mar 2026 14:59:07 +0000
Subject: [PATCH 7/7] add comments

---
 mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 9d99d402637fb..a6b91bb8febc9 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -903,7 +903,8 @@ DistributeLayoutAttr SliceAttr::setDimData(int64_t dim, int64_t sgData,
 //
 // If we drop sliced-space dim 1 (the 2nd dim), that corresponds to dropping
 // parent dim 2, result in parent layout [V0, V1, V3, V4] after dropping.
-// After parent dim 2 is removed, sliced dims [1, 3] must be reindexed to [1, 2].
+// After parent dim 2 is removed, sliced dims [1, 3] must be reindexed to [1,
+// 2].
 //
 // Result:
 //   xegpu.layout = slice<layout<[0, 1, 3, 4]>, [1, 2]>
@@ -915,9 +916,13 @@ DistributeLayoutAttr SliceAttr::dropDims(SmallVector<int64_t> dimGroup) {
 
   auto droppedParent = getParent().dropDims(dimsInParentSpace);
 
+  // Adjust the sliced dims after dropping dims in parent space. For example, if
+  // we drop dim 2 in parent space, the dims after dim 2 will all be shifted by
+  // 1, so sliced dim 3 will be adjusted to 2.
   SmallVector<int64_t> newSliceDims;
   for (int64_t d : sliceDims) {
-    int64_t offset = llvm::count_if(dimsInParentSpace, [&](int64_t s) { return s < d; });
+    int64_t offset =
+        llvm::count_if(dimsInParentSpace, [&](int64_t s) { return s < d; });
     newSliceDims.push_back(d - offset);
   }
 



More information about the Mlir-commits mailing list