[Mlir-commits] [mlir] [MLIR][XeGPU] Enhancing insert_strided_slice layout setup and infer rules (PR #184742)
Jianhui Li
llvmlistbot at llvm.org
Fri Mar 6 06:59:37 PST 2026
https://github.com/Jianhui-Li updated https://github.com/llvm/llvm-project/pull/184742
>From 58264107ef1d5a942d73c5b951a6c70daf06f11e Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Thu, 5 Mar 2026 01:42:54 +0000
Subject: [PATCH 1/7] refactor insert_strided_slice layout rules
---
.../mlir/Dialect/XeGPU/IR/XeGPUAttrs.td | 13 ++
.../XeGPU/Transforms/XeGPULayoutImpl.h | 6 +-
mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 103 ++++++++++++--
.../XeGPU/Transforms/XeGPULayoutImpl.cpp | 128 ++++++++----------
.../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 10 +-
.../XeGPU/propagate-layout-inst-data.mlir | 12 +-
6 files changed, 177 insertions(+), 95 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index 6f667f4801673..1f53c8b3b93fa 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -248,6 +248,11 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> {
"int64_t": $sgData,
"int64_t": $instData,
"int64_t": $laneData)>,
+ InterfaceMethod<[{Derive a new layout by removing dimensions.
+ `dimGroup` specifies a group of dimensions to be removed in the derived layout.}],
+ "xegpu::DistributeLayoutAttr",
+ "dropDims",
+ (ins "SmallVector<int64_t>": $dimGroup)>,
InterfaceMethod<[{Derive a new layout by collapsing dimensions.
`dimGroup` specifies a group of adjacent dimensions that are collapsed into
a single dimension in the derived layout.}],
@@ -562,6 +567,10 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> {
// preserves its original value.
DistributeLayoutAttr setDimData(int64_t dim, int64_t sgData, int64_t instData, int64_t laneData);
+ // Derive a new layout by removing dimensions.
+ // `dimGroup` specifies a group of dimensions to be removed in the derived layout.
+ DistributeLayoutAttr dropDims(SmallVector<int64_t> dimGroup);
+
// Derive a new layout by collapsing dimensions.
// `dimGroup` specifies a group of adjacent dimensions
// that are collapsed into a single dimension in the derived layout.
@@ -762,6 +771,10 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> {
// preserves its original value.
DistributeLayoutAttr setDimData(int64_t dim, int64_t sgData, int64_t instData, int64_t laneData);
+ // Derive a new layout by removing dimensions.
+ // `dimGroup` specifies a group of dimensions to be removed in the derived layout.
+ DistributeLayoutAttr dropDims(SmallVector<int64_t> dimGroup);
+
// Derive a new layout by collapsing dimensions.
// `dimGroup` specifies a group of adjacent dimensions
// that are collapsed into a single dimension in the derived layout.
diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h b/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h
index 3482d1b9401bb..5ccb69e0a89dd 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h
@@ -132,9 +132,9 @@ DistributeLayoutAttr setupBitCastResultLayout(
/// Sets up the result layout for an insert strided slice operation.
/// Creates a result layout based on the specified layout kind (InstData or
/// Lane).
-DistributeLayoutAttr setupInsertStridedSliceResultLayout(
- LayoutKind layoutKind, VectorType srcVectorTy, VectorType resVectorTy,
- DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch);
+// DistributeLayoutAttr setupInsertStridedSliceResultLayout(
+// LayoutKind layoutKind, VectorType srcVectorTy, VectorType resVectorTy,
+// DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch);
/// Sets up the anchor layout for a load gather operation.
DistributeLayoutAttr
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index c082600ec27d7..76b7a690288a3 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -501,6 +501,76 @@ DistributeLayoutAttr LayoutAttr::setDimData(int64_t dim, int64_t sgData,
getOrder());
}
+// Derive a new layout by removing dimensions.
+// `dimGroup` specifies a group of dimensions to be removed in the derived
+// layout.
+DistributeLayoutAttr LayoutAttr::dropDims(SmallVector<int64_t> dimGroup) {
+
+ SmallVector<int64_t> sgLayout = getEffectiveSgLayoutAsInt();
+ SmallVector<int64_t> sgData = getEffectiveSgDataAsInt();
+ SmallVector<int64_t> instData = getEffectiveInstDataAsInt();
+ SmallVector<int64_t> laneLayout = getEffectiveLaneLayoutAsInt();
+ SmallVector<int64_t> laneData = getEffectiveLaneDataAsInt();
+ SmallVector<int64_t> origOrder = getEffectiveOrderAsInt();
+
+ SmallVector<int64_t> sortedDimGroup = dimGroup;
+ llvm::sort(sortedDimGroup);
+
+ if (!sgLayout.empty()) {
+ for (auto dimIdx : llvm::reverse(sortedDimGroup)) {
+ sgLayout.erase(sgLayout.begin() + dimIdx, sgLayout.begin() + dimIdx + 1);
+ sgData.erase(sgData.begin() + dimIdx, sgData.begin() + dimIdx + 1);
+ }
+ }
+
+ if (!instData.empty()) {
+ for (auto dimIdx : llvm::reverse(sortedDimGroup))
+ instData.erase(instData.begin() + dimIdx, instData.begin() + dimIdx + 1);
+ }
+
+ if (!laneLayout.empty()) {
+ for (auto dimIdx : llvm::reverse(sortedDimGroup)) {
+ laneLayout.erase(laneLayout.begin() + dimIdx,
+ laneLayout.begin() + dimIdx + 1);
+ laneData.erase(laneData.begin() + dimIdx, laneData.begin() + dimIdx + 1);
+ }
+ }
+
+ SmallVector<int64_t> newOrder;
+ for (int64_t d : origOrder) {
+ if (llvm::is_contained(dimGroup, d))
+ continue;
+ int64_t offset = llvm::count_if(dimGroup, [&](int64_t s) { return s < d; });
+ newOrder.push_back(d - offset);
+ }
+
+ // Create dropped layout
+ SmallVector<int32_t> sgLayout32(sgLayout.begin(), sgLayout.end());
+ SmallVector<int32_t> sgData32(sgData.begin(), sgData.end());
+ SmallVector<int32_t> instData32(instData.begin(), instData.end());
+ SmallVector<int32_t> laneLayout32(laneLayout.begin(), laneLayout.end());
+ SmallVector<int32_t> laneData32(laneData.begin(), laneData.end());
+ SmallVector<int32_t> newOrder32(newOrder.begin(), newOrder.end());
+
+ DenseI32ArrayAttr orderAttr = getOrder();
+ auto droppedLayout = xegpu::LayoutAttr::get(
+ getContext(),
+ sgLayout32.empty() ? DenseI32ArrayAttr()
+ : DenseI32ArrayAttr::get(getContext(), sgLayout32),
+ sgData32.empty() ? DenseI32ArrayAttr()
+ : DenseI32ArrayAttr::get(getContext(), sgData32),
+ instData32.empty() ? DenseI32ArrayAttr()
+ : DenseI32ArrayAttr::get(getContext(), instData32),
+ laneLayout32.empty() ? DenseI32ArrayAttr()
+ : DenseI32ArrayAttr::get(getContext(), laneLayout32),
+ laneData32.empty() ? DenseI32ArrayAttr()
+ : DenseI32ArrayAttr::get(getContext(), laneData32),
+ (!orderAttr || orderAttr.empty())
+ ? DenseI32ArrayAttr()
+ : DenseI32ArrayAttr::get(getContext(), newOrder32));
+ return droppedLayout;
+}
+
// Derive a new layout by collapsing dimensions.
// `dimGroup` specifies a group of adjacent dimensions
// that are collapsed into a single dimension in the derived layout.
@@ -524,8 +594,8 @@ DistributeLayoutAttr LayoutAttr::collapseDims(SmallVector<int64_t> dimGroup) {
llvm::sort(sortedDimGroup);
int64_t dimBeforeCurrent = -1;
for (auto dimIdx : sortedDimGroup) {
- // when order is present, adjacency dims are on order values like [3, 2, 1,
- // 0] in decreasing order otherwise based on dim indices like [0, 1, 2, 3]
+ // when order is present, adjacency dims are values like [3, 2, 1, 0]
+ // in decreasing order; otherwise based on dim indices like [0, 1, 2, 3]
// in increasing order
if (dimBeforeCurrent >= 0) {
if (!orderVec.empty()) {
@@ -586,10 +656,10 @@ DistributeLayoutAttr LayoutAttr::collapseDims(SmallVector<int64_t> dimGroup) {
laneData.insert(laneData.begin() + firstDim, collapsedLaneData);
}
- // go through the values inside collapsedOrder, and re-map the order values
- // to be in range of [0, N-1] where N is the number of dimensions in
- // collapsed shape for exmaple, collapse dim group {2, 3} of order[1, 2, 3,
- // 4] to new order[1, 3, 4]. the loop below remaps it to [1, 2, 3].
+ // After collapsing the order vector, re-map the order values to be in range
+ // of [0, N-1] where N is the number of dimensions in collapsed shape. For
+ // example, collapse dim group {2, 3} of order[1, 2, 3, 4] to new order[1, 3,
+ // 4]. The loop below remaps it to [1, 2, 3].
SmallVector<int32_t> collapsedOrder;
if (!orderVec.empty()) {
@@ -598,7 +668,6 @@ DistributeLayoutAttr LayoutAttr::collapseDims(SmallVector<int64_t> dimGroup) {
orderVec.erase(orderVec.begin() + dimIdx,
orderVec.begin() + dimIdx + 1);
}
-
// say we have orderVec = {5, 3, 2, 1, 0}
// Create indices [0, 1, 2, 3, 4]
SmallVector<size_t> indices =
@@ -864,6 +933,20 @@ DistributeLayoutAttr SliceAttr::setDimData(int64_t dim, int64_t sgData,
parent.setDimData(adjustDims[0], sgData, instData, laneData), getDims());
}
+// Derive a new layout by removing dimensions.
+// `dimGroup` specifies a group of dimensions to be removed in the derived
+// layout.
+DistributeLayoutAttr SliceAttr::dropDims(SmallVector<int64_t> dimGroup) {
+ // Map the dims to be dropped from the sliced space to the parent space
+ SmallVector<int64_t> sliceDims = llvm::to_vector(getDims().asArrayRef());
+ SmallVector<int64_t> dimsInParentSpace =
+ mapSlicedDimsToParentSpace(dimGroup, sliceDims);
+
+ auto droppedParent = getParent().dropDims(dimsInParentSpace);
+ return SliceAttr::get(getContext(), droppedParent,
+ DenseI64ArrayAttr::get(getContext(), sliceDims));
+}
+
// Derive a new layout by collapsing dimensions.
// `dimGroup` specifies a group of adjacent dimensions
// that are collapsed into a single dimension in the derived layout.
@@ -871,12 +954,14 @@ DistributeLayoutAttr SliceAttr::collapseDims(SmallVector<int64_t> dimGroup) {
// Map the sliced dims from parent space to collapsed space
SmallVector<int64_t> sliceDims = llvm::to_vector(getDims().asArrayRef());
-
+ assert("expect sliceDims not being collapsed" &&
+ llvm::none_of(dimGroup, [&](int64_t dim) {
+ return llvm::is_contained(sliceDims, dim);
+ }));
SmallVector<int64_t> dimsInParentSpace =
mapSlicedDimsToParentSpace(dimGroup, sliceDims);
auto collapsedParent = getParent().collapseDims(dimsInParentSpace);
-
return SliceAttr::get(getContext(), collapsedParent,
DenseI64ArrayAttr::get(getContext(), sliceDims));
}
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
index 7aa186bb22224..097c850e77b91 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
@@ -241,31 +241,19 @@ xegpu::DistributeLayoutAttr xegpu::inferInsertStridedSliceSourceLayout(
int resShapeSize = resShape.size();
int dimDiff = resShapeSize - srcShapeSize;
- assert(isa<xegpu::LayoutAttr>(resLayout) &&
- "insertStridedSlice result layout must be plain layout");
- auto context = resLayout.getContext();
- auto resInstData = resLayout.getEffectiveInstDataAsInt();
- auto resLaneLayout = resLayout.getEffectiveLaneLayoutAsInt();
- auto resLaneData = resLayout.getEffectiveLaneDataAsInt();
-
- if (resInstData.size() != 0) {
- SmallVector<int> inferredInstData(srcShapeSize);
- for (int i = 0; i < srcShapeSize; i++)
- inferredInstData[i] = resInstData[i + dimDiff];
- return xegpu::LayoutAttr::get(context, inferredInstData);
- }
-
- if (resLaneLayout.size() != 0) {
- SmallVector<int> inferredLaneLayout(srcShapeSize);
- SmallVector<int> inferredLaneData(srcShapeSize);
- for (int i = 0; i < srcShapeSize; i++) {
- inferredLaneLayout[i] = resLaneLayout[i + dimDiff];
- inferredLaneData[i] = resLaneData[i + dimDiff];
+ if (dimDiff > 0) {
+ // assert that the leading dimensions being sliced off are not distributed
+ // (i.e. sg_layout and lane_layout for those dimensions are all 1)
+ auto resSgLayout = resLayout.getEffectiveSgLayoutAsInt();
+ auto resLaneLayout = resLayout.getEffectiveLaneLayoutAsInt();
+ for (int i = 0; i < dimDiff; i++) {
+ assert((resSgLayout.size() == 0 || resSgLayout[i] == 1) &&
+ (resLaneLayout.size() == 0 || resLaneLayout[i] == 1) &&
+ "Leading dimensions being sliced off must not be distributed");
}
- return xegpu::LayoutAttr::get(context, inferredLaneLayout,
- inferredLaneData);
+ return resLayout.dropDims(llvm::to_vector(llvm::seq<int64_t>(0, dimDiff)));
}
- return nullptr;
+ return resLayout;
}
/// Infers the source layout attribute for a shape cast operation given the
@@ -619,52 +607,54 @@ xegpu::DistributeLayoutAttr xegpu::setupBitCastResultLayout(
/// resShape=[4, 64], subgroupSize=16, bitwidth=16, packingFactor=2
/// consumerLayout: laneLayout=[1, 16], laneData=[1, 2]
/// Result: laneLayout=[1, 16], laneData=[1, 2] (adjusted for packed data)
-xegpu::DistributeLayoutAttr xegpu::setupInsertStridedSliceResultLayout(
- xegpu::LayoutKind layoutKind, VectorType srcVectorTy,
- VectorType resVectorTy, xegpu::DistributeLayoutAttr consumerLayout,
- const xegpu::uArch::uArch *uArch) {
-
- xegpu::DistributeLayoutAttr requiredResLayout;
- auto subgroupSize = uArch->getSubgroupSize();
- auto context = resVectorTy.getContext();
- auto resShape = resVectorTy.getShape();
- int resShapeSize = resShape.size();
- auto srcShape = srcVectorTy.getShape();
- SmallVector<int64_t> consumerInstData =
- consumerLayout.getEffectiveInstDataAsInt();
- SmallVector<int64_t> consumerLaneData =
- consumerLayout.getEffectiveLaneDataAsInt();
-
- SmallVector<int> instData(resShapeSize, 1);
- SmallVector<int> laneLayout(resShapeSize, 1);
- SmallVector<int> laneData(resShapeSize, 1);
-
- const unsigned packingSize{uArch->getGeneralPackedFormatBitSize()};
- unsigned bitwidth = resVectorTy.getElementType().getIntOrFloatBitWidth();
- int packingFactor = bitwidth < packingSize ? packingSize / bitwidth : 1;
- int packedDataSize = subgroupSize * packingFactor;
-
- if (layoutKind == xegpu::LayoutKind::Subgroup) {
- assert(true &&
- "subgroup layout assignment not supported for insertStridedSlice.");
- } else if (layoutKind == xegpu::LayoutKind::InstData) {
- assert(srcShape.back() >= subgroupSize &&
- "source innermost dim must be >= subgroupSize");
- instData.back() = subgroupSize;
- if (consumerInstData.back() == packedDataSize &&
- srcShape.back() >= packedDataSize)
- instData.back() = packedDataSize;
- requiredResLayout = xegpu::LayoutAttr::get(context, instData);
- } else if (layoutKind == xegpu::LayoutKind::Lane) {
- laneLayout.back() = subgroupSize;
- laneData.back() = 1;
- if (consumerLaneData.back() == packingFactor &&
- srcShape.back() >= packedDataSize)
- laneData.back() = packingFactor;
- requiredResLayout = xegpu::LayoutAttr::get(context, laneLayout, laneData);
- }
- return requiredResLayout;
-}
+// xegpu::DistributeLayoutAttr xegpu::setupInsertStridedSliceResultLayout(
+// xegpu::LayoutKind layoutKind, VectorType srcVectorTy,
+// VectorType resVectorTy, xegpu::DistributeLayoutAttr consumerLayout,
+// const xegpu::uArch::uArch *uArch) {
+
+// xegpu::DistributeLayoutAttr requiredResLayout;
+// auto subgroupSize = uArch->getSubgroupSize();
+// auto context = resVectorTy.getContext();
+// auto resShape = resVectorTy.getShape();
+// int resShapeSize = resShape.size();
+// auto srcShape = srcVectorTy.getShape();
+// SmallVector<int64_t> consumerInstData =
+// consumerLayout.getEffectiveInstDataAsInt();
+// SmallVector<int64_t> consumerLaneData =
+// consumerLayout.getEffectiveLaneDataAsInt();
+
+// SmallVector<int> instData(resShapeSize, 1);
+// SmallVector<int> laneLayout(resShapeSize, 1);
+// SmallVector<int> laneData(resShapeSize, 1);
+
+// const unsigned packingSize{uArch->getGeneralPackedFormatBitSize()};
+// unsigned bitwidth = resVectorTy.getElementType().getIntOrFloatBitWidth();
+// int packingFactor = bitwidth < packingSize ? packingSize / bitwidth : 1;
+// int packedDataSize = subgroupSize * packingFactor;
+
+// if (layoutKind == xegpu::LayoutKind::Subgroup) {
+// assert(true &&
+// "subgroup layout assignment not supported for
+// insertStridedSlice.");
+// } else if (layoutKind == xegpu::LayoutKind::InstData) {
+// assert(srcShape.back() >= subgroupSize &&
+// "source innermost dim must be >= subgroupSize");
+// instData.back() = subgroupSize;
+// if (consumerInstData.back() == packedDataSize &&
+// srcShape.back() >= packedDataSize)
+// instData.back() = packedDataSize;
+// requiredResLayout = xegpu::LayoutAttr::get(context, instData);
+// } else if (layoutKind == xegpu::LayoutKind::Lane) {
+// laneLayout.back() = subgroupSize;
+// laneData.back() = 1;
+// if (consumerLaneData.back() == packingFactor &&
+// srcShape.back() >= packedDataSize)
+// laneData.back() = packingFactor;
+// requiredResLayout = xegpu::LayoutAttr::get(context, laneLayout,
+// laneData);
+// }
+// return requiredResLayout;
+// }
/// Sets up the anchor layout for load gather and load matrix operation.
/// load matrix lowers to load gather and 1d block load. All of them share the
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 7f7e8d6ad7734..ea9faf1438033 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -969,18 +969,12 @@ void LayoutInfoPropagation::visitInsertStridedSliceOp(
if (!uArch)
return;
- auto requiredResLayoutAttr = xegpu::setupInsertStridedSliceResultLayout(
- layoutKind, srcVecType, resVecType, consumerLayoutAttr, uArch);
-
- xegpu::setTemporaryLayout(insertStridedSlice->getResult(0),
- requiredResLayoutAttr);
-
auto srcLayoutAttr = xegpu::inferInsertStridedSliceSourceLayout(
- requiredResLayoutAttr, resVecType.getShape(), srcVecType.getShape());
+ consumerLayoutAttr, resVecType.getShape(), srcVecType.getShape());
propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
propagateIfChanged(operands[1],
- operands[1]->meet(LayoutInfo(requiredResLayoutAttr)));
+ operands[1]->meet(LayoutInfo(consumerLayoutAttr)));
}
/// Propagate the layout of the result to the tensor descriptor, mask and offset
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir b/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
index 5dd05e6cb0001..bfffcf75c7306 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
@@ -238,9 +238,9 @@ func.func @scatter_ops_chunksize_slice(%src: memref<1024xf32>) {
gpu.module @test {
// CHECK-LABEL: func.func @insert_strided_slice_inst_data_no_packing(
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<8x32xf32>) {
-// CHECK: %[[CST_SMALL:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [1, 16]>} dense<1.000000e+00> : vector<4x16xf32>
-// CHECK: %[[CST_LARGE:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [1, 16]>} dense<0.000000e+00> : vector<8x32xf32>
-// CHECK: %[[INSERT:.*]] = vector.insert_strided_slice %[[CST_SMALL]], %[[CST_LARGE]] {layout_result_0 = #xegpu.layout<inst_data = [1, 16]>, offsets = [0, 0], strides = [1, 1]} : vector<4x16xf32> into vector<8x32xf32>
+// CHECK: %[[CST_SMALL:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [8, 16]>} dense<1.000000e+00> : vector<4x16xf32>
+// CHECK: %[[CST_LARGE:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [8, 16]>} dense<0.000000e+00> : vector<8x32xf32>
+// CHECK: %[[INSERT:.*]] = vector.insert_strided_slice %[[CST_SMALL]], %[[CST_LARGE]] {layout_result_0 = #xegpu.layout<inst_data = [8, 16]>, offsets = [0, 0], strides = [1, 1]} : vector<4x16xf32> into vector<8x32xf32>
// CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG0]][{{.*}}] : memref<8x32xf32> -> !xegpu.tensor_desc<8x32xf32, #xegpu.layout<inst_data = [8, 16]>>
// CHECK: xegpu.store_nd %[[INSERT]], %[[TDESC]] <{layout = #xegpu.layout<inst_data = [8, 16]>}> : vector<8x32xf32>, !xegpu.tensor_desc<8x32xf32, #xegpu.layout<inst_data = [8, 16]>>
func.func @insert_strided_slice_inst_data_no_packing(%arg0: memref<8x32xf32>) {
@@ -258,9 +258,9 @@ func.func @insert_strided_slice_inst_data_no_packing(%arg0: memref<8x32xf32>) {
gpu.module @test {
// CHECK-LABEL: func.func @insert_strided_slice_inst_data_with_packing(
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<8x64xi8>) {
-// CHECK: %[[CST_SMALL:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [1, 64]>} dense<1> : vector<4x64xi8>
-// CHECK: %[[CST_LARGE:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [1, 64]>} dense<0> : vector<8x64xi8>
-// CHECK: %[[INSERT:.*]] = vector.insert_strided_slice %[[CST_SMALL]], %[[CST_LARGE]] {layout_result_0 = #xegpu.layout<inst_data = [1, 64]>, offsets = [0, 0], strides = [1, 1]} : vector<4x64xi8> into vector<8x64xi8>
+// CHECK: %[[CST_SMALL:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [8, 64]>} dense<1> : vector<4x64xi8>
+// CHECK: %[[CST_LARGE:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [8, 64]>} dense<0> : vector<8x64xi8>
+// CHECK: %[[INSERT:.*]] = vector.insert_strided_slice %[[CST_SMALL]], %[[CST_LARGE]] {layout_result_0 = #xegpu.layout<inst_data = [8, 64]>, offsets = [0, 0], strides = [1, 1]} : vector<4x64xi8> into vector<8x64xi8>
func.func @insert_strided_slice_inst_data_with_packing(%arg0: memref<8x64xi8>) {
%c0 = arith.constant 0 : index
%cst_small = arith.constant dense<1> : vector<4x64xi8>
>From bcb1e63828e230ca7dd36581db4cc25947ca0293 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Thu, 5 Mar 2026 02:43:57 +0000
Subject: [PATCH 2/7] add test
---
mlir/test/Dialect/XeGPU/propagate-layout.mlir | 24 +++++++++++++++++++
1 file changed, 24 insertions(+)
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout.mlir b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
index 4f2349a89b1ed..6ddb922ccd78a 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
@@ -792,6 +792,30 @@ func.func @insert_strided_slice_lane_layout_with_packing(%arg0: memref<4x64xf16>
}
}
+// -----
+gpu.module @test {
+// CHECK-LABEL: func.func @insert_strided_slice_with_slice_layout(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<8x16xf32>) {
+// CHECK: %[[CST:.*]] = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>, dims = [0]>} dense<1.000000e+00> : vector<1xf32>
+// CHECK: %[[CST_0:.*]] = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>, dims = [0]>} dense<1.000000e+00> : vector<16xf32>
+// CHECK: %[[INSERT:.*]] = vector.insert_strided_slice %[[CST]], %[[CST_0]] {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>, dims = [0]>, offsets = [15], strides = [1]} : vector<1xf32> into vector<16xf32>
+// CHECK: %[[EXTRACT:.*]] = vector.extract_strided_slice %[[INSERT]] {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>, dims = [0]>, offsets = [0], sizes = [8], strides = [1]} : vector<16xf32> to vector<8xf32>
+// CHECK: %[[BROADCAST:.*]] = vector.broadcast %[[EXTRACT]] {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>} : vector<8xf32> to vector<16x8xf32>
+// CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[BROADCAST]], [1, 0] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16x8xf32> to vector<8x16xf32>
+func.func @insert_strided_slice_with_slice_layout(%arg0: memref<8x16xf32>) {
+ %c0 = arith.constant 0 : index
+ %cst_small = arith.constant dense<1.0> : vector<1xf32>
+ %cst_large = arith.constant dense<1.0> : vector<16xf32>
+ %cst_large_new = vector.insert_strided_slice %cst_small, %cst_large {offsets = [15], strides = [1]} : vector<1xf32> into vector<16xf32>
+ %cst_small8 = vector.extract_strided_slice %cst_large_new {offsets = [0], sizes = [8], strides = [1]} : vector<16xf32> to vector<8xf32>
+ %cst_small16x8 = vector.broadcast %cst_small8 : vector<8xf32> to vector<16x8xf32>
+ %cst_small8x16 = vector.transpose %cst_small16x8, [1, 0] : vector<16x8xf32> to vector<8x16xf32>
+ %tdesc = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ xegpu.store_nd %cst_small8x16, %tdesc <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}>: vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ return
+}
+}
+
// -----
gpu.module @test{
// CHECK-LABEL: load_store_matrix
>From 8288b7e8b04d7a92e6d277d2838c4529d952e71f Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Thu, 5 Mar 2026 04:53:36 +0000
Subject: [PATCH 3/7] rewrite setupInsertStridedSliceResultLayout
---
.../XeGPU/Transforms/XeGPULayoutImpl.h | 6 +-
.../XeGPU/Transforms/XeGPULayoutImpl.cpp | 142 +++++++-----------
.../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 10 +-
.../XeGPU/propagate-layout-inst-data.mlir | 4 +-
4 files changed, 64 insertions(+), 98 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h b/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h
index 5ccb69e0a89dd..3482d1b9401bb 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h
@@ -132,9 +132,9 @@ DistributeLayoutAttr setupBitCastResultLayout(
/// Sets up the result layout for an insert strided slice operation.
/// Creates a result layout based on the specified layout kind (InstData or
/// Lane).
-// DistributeLayoutAttr setupInsertStridedSliceResultLayout(
-// LayoutKind layoutKind, VectorType srcVectorTy, VectorType resVectorTy,
-// DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch);
+DistributeLayoutAttr setupInsertStridedSliceResultLayout(
+ LayoutKind layoutKind, VectorType srcVectorTy, VectorType resVectorTy,
+ DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch);
/// Sets up the anchor layout for a load gather operation.
DistributeLayoutAttr
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
index 097c850e77b91..89445a9885b7c 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
@@ -532,11 +532,12 @@ xegpu::DistributeLayoutAttr xegpu::setupBitCastResultLayout(
SmallVector<int64_t> sgData = consumerLayout.getEffectiveSgDataAsInt();
SmallVector<int64_t> instData = consumerLayout.getEffectiveInstDataAsInt();
SmallVector<int64_t> laneData = consumerLayout.getEffectiveLaneDataAsInt();
+ assert(consumerLayout.getRank() == static_cast<int64_t>(srcShape.size()) &&
+ "laneData must be available for all dimensions");
size_t dim = srcShape.size() - 1;
int64_t sgDataValue = -1;
int64_t instDataValue = -1;
int64_t laneDataValue = -1;
-
const int subgroupSize = uArch->getSubgroupSize();
if (srcElemTyBitWidth > resElemTyBitWidth) {
@@ -546,12 +547,8 @@ xegpu::DistributeLayoutAttr xegpu::setupBitCastResultLayout(
int bitWidthRatio = srcElemTyBitWidth / resElemTyBitWidth;
int innermostDimLaneLayout = subgroupSize;
if (layoutKind == xegpu::LayoutKind::Subgroup) {
- assert(sgData.size() == srcShape.size() &&
- "sgData must be available for all dimensions");
sgDataValue = sgData[dim];
} else if (layoutKind == xegpu::LayoutKind::InstData) {
- assert(instData.size() == srcShape.size() &&
- "instData must be available for all dimensions");
instDataValue = instData[dim];
// Adjust instDataValue so it still fits within an instruction after
// dividing by bitWidthRatio
@@ -561,8 +558,6 @@ xegpu::DistributeLayoutAttr xegpu::setupBitCastResultLayout(
assert((srcShape[dim] % instDataValue) == 0 &&
"srcShape, instData, and lanelayout for innermost must be 2^n !");
} else if (layoutKind == xegpu::LayoutKind::Lane) {
- assert(laneData.size() == srcShape.size() &&
- "laneData must be available for all dimensions");
laneDataValue = laneData[dim];
while ((laneDataValue <= srcShape[dim]) &&
(laneDataValue % bitWidthRatio != 0))
@@ -580,81 +575,47 @@ xegpu::DistributeLayoutAttr xegpu::setupBitCastResultLayout(
/// Sets up the result layout for an insert strided slice operation.
/// Creates a result layout based on the specified layout kind (InstData or
/// Lane).
-/// Subgroup layout is currently not supported for this operation.
-/// InstData layout is first set to be {1, .., subgroupSize}.
-/// Lane layout is first set to be {1, ..., subgroupSize} with lane data {1,
-/// ..., 1}. The instData and laneData is then adjusted to contain packed data,
-/// by checking if the consumerLayout's innermost dimension.
-///
-/// Examples:
-/// 1. InstData layout without packing:
-/// resShape=[8, 32], subgroupSize=16, bitwidth=32
-/// packingFactor=1, packedDataSize=16
-/// consumerLayout: instData=[1, 16]
-/// Result: instData=[1, 16]
-///
-/// 2. InstData layout with packing:
-/// resShape=[8, 64], subgroupSize=16, bitwidth=8, packingFactor=4
-/// consumerLayout: instData=[1, 64]
-/// Result: instData=[1, 64] (adjusted for packed data)
-///
-/// 3. Lane layout without packing:
-/// resShape=[4, 64], subgroupSize=16, bitwidth=32
-/// consumerLayout: laneLayout=[1, 16], laneData=[1, 1]
-/// Result: laneLayout=[1, 16], laneData=[1, 1]
-///
-/// 4. Lane layout with packing:
-/// resShape=[4, 64], subgroupSize=16, bitwidth=16, packingFactor=2
-/// consumerLayout: laneLayout=[1, 16], laneData=[1, 2]
-/// Result: laneLayout=[1, 16], laneData=[1, 2] (adjusted for packed data)
-// xegpu::DistributeLayoutAttr xegpu::setupInsertStridedSliceResultLayout(
-// xegpu::LayoutKind layoutKind, VectorType srcVectorTy,
-// VectorType resVectorTy, xegpu::DistributeLayoutAttr consumerLayout,
-// const xegpu::uArch::uArch *uArch) {
-
-// xegpu::DistributeLayoutAttr requiredResLayout;
-// auto subgroupSize = uArch->getSubgroupSize();
-// auto context = resVectorTy.getContext();
-// auto resShape = resVectorTy.getShape();
-// int resShapeSize = resShape.size();
-// auto srcShape = srcVectorTy.getShape();
-// SmallVector<int64_t> consumerInstData =
-// consumerLayout.getEffectiveInstDataAsInt();
-// SmallVector<int64_t> consumerLaneData =
-// consumerLayout.getEffectiveLaneDataAsInt();
-
-// SmallVector<int> instData(resShapeSize, 1);
-// SmallVector<int> laneLayout(resShapeSize, 1);
-// SmallVector<int> laneData(resShapeSize, 1);
-
-// const unsigned packingSize{uArch->getGeneralPackedFormatBitSize()};
-// unsigned bitwidth = resVectorTy.getElementType().getIntOrFloatBitWidth();
-// int packingFactor = bitwidth < packingSize ? packingSize / bitwidth : 1;
-// int packedDataSize = subgroupSize * packingFactor;
-
-// if (layoutKind == xegpu::LayoutKind::Subgroup) {
-// assert(true &&
-// "subgroup layout assignment not supported for
-// insertStridedSlice.");
-// } else if (layoutKind == xegpu::LayoutKind::InstData) {
-// assert(srcShape.back() >= subgroupSize &&
-// "source innermost dim must be >= subgroupSize");
-// instData.back() = subgroupSize;
-// if (consumerInstData.back() == packedDataSize &&
-// srcShape.back() >= packedDataSize)
-// instData.back() = packedDataSize;
-// requiredResLayout = xegpu::LayoutAttr::get(context, instData);
-// } else if (layoutKind == xegpu::LayoutKind::Lane) {
-// laneLayout.back() = subgroupSize;
-// laneData.back() = 1;
-// if (consumerLaneData.back() == packingFactor &&
-// srcShape.back() >= packedDataSize)
-// laneData.back() = packingFactor;
-// requiredResLayout = xegpu::LayoutAttr::get(context, laneLayout,
-// laneData);
-// }
-// return requiredResLayout;
-// }
+xegpu::DistributeLayoutAttr xegpu::setupInsertStridedSliceResultLayout(
+ xegpu::LayoutKind layoutKind, VectorType srcVectorTy,
+ VectorType resVectorTy, xegpu::DistributeLayoutAttr consumerLayout,
+ const xegpu::uArch::uArch *uArch) {
+
+ xegpu::DistributeLayoutAttr requiredResLayout;
+ SmallVector<int64_t> consumerInstData =
+ consumerLayout.getEffectiveInstDataAsInt();
+ SmallVector<int64_t> consumerLaneData =
+ consumerLayout.getEffectiveLaneDataAsInt();
+ SmallVector<int64_t> consumerLaneLayout =
+ consumerLayout.getEffectiveLaneLayoutAsInt();
+ ArrayRef<int64_t> srcShape = srcVectorTy.getShape();
+ int64_t instDataValue = -1;
+ int64_t laneDataValue = -1;
+
+ requiredResLayout = consumerLayout;
+ int srcRank = srcShape.size();
+
+ if (layoutKind == xegpu::LayoutKind::Subgroup) {
+ assert(true &&
+ "subgroup layout assignment not supported for insertStridedSlice.");
+ } else if (layoutKind == xegpu::LayoutKind::InstData) {
+ for (int dim = 0; dim < srcRank; dim++) {
+ instDataValue = std::min(srcShape[dim], consumerInstData[dim]);
+ requiredResLayout =
+ requiredResLayout.setDimData(dim, -1, instDataValue, -1);
+ }
+ } else if (layoutKind == xegpu::LayoutKind::Lane) {
+ for (int dim = 0; dim < srcRank; dim++) {
+ assert(srcShape[dim] % consumerLaneLayout[dim] == 0 &&
+ "srcShape must be divisible by laneLayout for all dimensions");
+ laneDataValue = std::min(srcShape[dim] / consumerLaneLayout[dim],
+ consumerLaneData[dim]);
+
+ requiredResLayout =
+ requiredResLayout.setDimData(dim, -1, -1, laneDataValue);
+ }
+ }
+ return requiredResLayout;
+}
/// Sets up the anchor layout for load gather and load matrix operation.
/// load matrix lowers to load gather and 1d block load. All of them share the
@@ -849,8 +810,8 @@ xegpu::setupStoreMatrixAnchorLayout(xegpu::LayoutKind layoutKind,
}
// This function returns the default lane layout for a given vector type.
-// - `packingSize` means multiple consecutive elements can be accessed together
-// as a single unit.
+// - `packingSize` means multiple consecutive elements can be accessed
+// together as a single unit.
// - `vnni` means data packing is column-wise (i.e., 2x1xf16 with vnni vs.
// 1x2xf16 w/o vnni).
template <typename RankedTy>
@@ -915,7 +876,8 @@ getValidLayouts(ArrayRef<int64_t> wgShape, ArrayRef<int64_t> instData,
}
/// Sets up the anchor layouts for dpas operands (A, B, and C/D).
-/// The numSg and consumerLayout (optional) are only used by sg layout creation.
+/// The numSg and consumerLayout (optional) are only used by sg layout
+/// creation.
std::optional<
std::tuple<xegpu::DistributeLayoutAttr, xegpu::DistributeLayoutAttr,
xegpu::DistributeLayoutAttr>>
@@ -1001,9 +963,9 @@ xegpu::setupDpasLayout(xegpu::LayoutKind layoutKind, VectorType aTy,
break;
}
// Is in (A and B and CD) layoutsB is ordered from most
- // balanced to least. So the first one we see is the most balanced one,
- // remember it and later only update if there is one that matches the
- // consumer.
+ // balanced to least. So the first one we see is the most balanced
+ // one, remember it and later only update if there is one that matches
+ // the consumer.
if (!bestPick)
bestPick = sgLayout;
}
@@ -1142,7 +1104,7 @@ xegpu::DistributeLayoutAttr xegpu::getConsumerLayoutAt(OpOperand &operand) {
return resLayout;
}
// TODO: Handle more cases as needed here.
- // By default, assume no layout conflict and return the current layout of the
- // operand.
+ // By default, assume no layout conflict and return the current layout of
+ // the operand.
return xegpu::getDistributeLayoutAttr(operand.get());
}
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index ea9faf1438033..4eac49525b832 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -969,12 +969,16 @@ void LayoutInfoPropagation::visitInsertStridedSliceOp(
if (!uArch)
return;
- auto srcLayoutAttr = xegpu::inferInsertStridedSliceSourceLayout(
- consumerLayoutAttr, resVecType.getShape(), srcVecType.getShape());
+ auto requiredResLayoutAttr = xegpu::setupInsertStridedSliceResultLayout(
+ layoutKind, srcVecType, resVecType, consumerLayoutAttr, uArch);
+ xegpu::setTemporaryLayout(insertStridedSlice->getResult(0),
+ requiredResLayoutAttr);
+ auto srcLayoutAttr = xegpu::inferInsertStridedSliceSourceLayout(
+ requiredResLayoutAttr, resVecType.getShape(), srcVecType.getShape());
propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
propagateIfChanged(operands[1],
- operands[1]->meet(LayoutInfo(consumerLayoutAttr)));
+ operands[1]->meet(LayoutInfo(requiredResLayoutAttr)));
}
/// Propagate the layout of the result to the tensor descriptor, mask and offset
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir b/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
index bfffcf75c7306..b0523229a4640 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
@@ -238,7 +238,7 @@ func.func @scatter_ops_chunksize_slice(%src: memref<1024xf32>) {
gpu.module @test {
// CHECK-LABEL: func.func @insert_strided_slice_inst_data_no_packing(
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<8x32xf32>) {
-// CHECK: %[[CST_SMALL:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [8, 16]>} dense<1.000000e+00> : vector<4x16xf32>
+// CHECK: %[[CST_SMALL:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [4, 16]>} dense<1.000000e+00> : vector<4x16xf32>
// CHECK: %[[CST_LARGE:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [8, 16]>} dense<0.000000e+00> : vector<8x32xf32>
// CHECK: %[[INSERT:.*]] = vector.insert_strided_slice %[[CST_SMALL]], %[[CST_LARGE]] {layout_result_0 = #xegpu.layout<inst_data = [8, 16]>, offsets = [0, 0], strides = [1, 1]} : vector<4x16xf32> into vector<8x32xf32>
// CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG0]][{{.*}}] : memref<8x32xf32> -> !xegpu.tensor_desc<8x32xf32, #xegpu.layout<inst_data = [8, 16]>>
@@ -258,7 +258,7 @@ func.func @insert_strided_slice_inst_data_no_packing(%arg0: memref<8x32xf32>) {
gpu.module @test {
// CHECK-LABEL: func.func @insert_strided_slice_inst_data_with_packing(
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<8x64xi8>) {
-// CHECK: %[[CST_SMALL:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [8, 64]>} dense<1> : vector<4x64xi8>
+// CHECK: %[[CST_SMALL:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [4, 64]>} dense<1> : vector<4x64xi8>
// CHECK: %[[CST_LARGE:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [8, 64]>} dense<0> : vector<8x64xi8>
// CHECK: %[[INSERT:.*]] = vector.insert_strided_slice %[[CST_SMALL]], %[[CST_LARGE]] {layout_result_0 = #xegpu.layout<inst_data = [8, 64]>, offsets = [0, 0], strides = [1, 1]} : vector<4x64xi8> into vector<8x64xi8>
func.func @insert_strided_slice_inst_data_with_packing(%arg0: memref<8x64xi8>) {
>From 36eb877956f3b504cae2dd64b06b8cfd6ec51efb Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Thu, 5 Mar 2026 05:39:07 +0000
Subject: [PATCH 4/7] fix test
---
mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir | 8 ++++----
1 file changed, 4 insertions(+), 4 deletions(-)
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir b/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
index b0523229a4640..c510a1d5f0fdf 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
@@ -239,8 +239,8 @@ gpu.module @test {
// CHECK-LABEL: func.func @insert_strided_slice_inst_data_no_packing(
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<8x32xf32>) {
// CHECK: %[[CST_SMALL:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [4, 16]>} dense<1.000000e+00> : vector<4x16xf32>
-// CHECK: %[[CST_LARGE:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [8, 16]>} dense<0.000000e+00> : vector<8x32xf32>
-// CHECK: %[[INSERT:.*]] = vector.insert_strided_slice %[[CST_SMALL]], %[[CST_LARGE]] {layout_result_0 = #xegpu.layout<inst_data = [8, 16]>, offsets = [0, 0], strides = [1, 1]} : vector<4x16xf32> into vector<8x32xf32>
+// CHECK: %[[CST_LARGE:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [4, 16]>} dense<0.000000e+00> : vector<8x32xf32>
+// CHECK: %[[INSERT:.*]] = vector.insert_strided_slice %[[CST_SMALL]], %[[CST_LARGE]] {layout_result_0 = #xegpu.layout<inst_data = [4, 16]>, offsets = [0, 0], strides = [1, 1]} : vector<4x16xf32> into vector<8x32xf32>
// CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG0]][{{.*}}] : memref<8x32xf32> -> !xegpu.tensor_desc<8x32xf32, #xegpu.layout<inst_data = [8, 16]>>
// CHECK: xegpu.store_nd %[[INSERT]], %[[TDESC]] <{layout = #xegpu.layout<inst_data = [8, 16]>}> : vector<8x32xf32>, !xegpu.tensor_desc<8x32xf32, #xegpu.layout<inst_data = [8, 16]>>
func.func @insert_strided_slice_inst_data_no_packing(%arg0: memref<8x32xf32>) {
@@ -259,8 +259,8 @@ gpu.module @test {
// CHECK-LABEL: func.func @insert_strided_slice_inst_data_with_packing(
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<8x64xi8>) {
// CHECK: %[[CST_SMALL:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [4, 64]>} dense<1> : vector<4x64xi8>
-// CHECK: %[[CST_LARGE:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [8, 64]>} dense<0> : vector<8x64xi8>
-// CHECK: %[[INSERT:.*]] = vector.insert_strided_slice %[[CST_SMALL]], %[[CST_LARGE]] {layout_result_0 = #xegpu.layout<inst_data = [8, 64]>, offsets = [0, 0], strides = [1, 1]} : vector<4x64xi8> into vector<8x64xi8>
+// CHECK: %[[CST_LARGE:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [4, 64]>} dense<0> : vector<8x64xi8>
+// CHECK: %[[INSERT:.*]] = vector.insert_strided_slice %[[CST_SMALL]], %[[CST_LARGE]] {layout_result_0 = #xegpu.layout<inst_data = [4, 64]>, offsets = [0, 0], strides = [1, 1]} : vector<4x64xi8> into vector<8x64xi8>
func.func @insert_strided_slice_inst_data_with_packing(%arg0: memref<8x64xi8>) {
%c0 = arith.constant 0 : index
%cst_small = arith.constant dense<1> : vector<4x64xi8>
>From dee54e3c0e8ee7d3b3d67b9873a94e209d67852c Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Fri, 6 Mar 2026 03:00:54 +0000
Subject: [PATCH 5/7] address feedback
---
mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 127 +++++++--------------
1 file changed, 42 insertions(+), 85 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 76b7a690288a3..4c434a35755a8 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -516,23 +516,16 @@ DistributeLayoutAttr LayoutAttr::dropDims(SmallVector<int64_t> dimGroup) {
SmallVector<int64_t> sortedDimGroup = dimGroup;
llvm::sort(sortedDimGroup);
- if (!sgLayout.empty()) {
- for (auto dimIdx : llvm::reverse(sortedDimGroup)) {
- sgLayout.erase(sgLayout.begin() + dimIdx, sgLayout.begin() + dimIdx + 1);
- sgData.erase(sgData.begin() + dimIdx, sgData.begin() + dimIdx + 1);
+ for (auto dimIdx : llvm::reverse(sortedDimGroup)) {
+ if (!sgLayout.empty()) {
+ sgLayout.erase(sgLayout.begin() + dimIdx);
+ sgData.erase(sgData.begin() + dimIdx);
}
- }
-
- if (!instData.empty()) {
- for (auto dimIdx : llvm::reverse(sortedDimGroup))
- instData.erase(instData.begin() + dimIdx, instData.begin() + dimIdx + 1);
- }
-
- if (!laneLayout.empty()) {
- for (auto dimIdx : llvm::reverse(sortedDimGroup)) {
- laneLayout.erase(laneLayout.begin() + dimIdx,
- laneLayout.begin() + dimIdx + 1);
- laneData.erase(laneData.begin() + dimIdx, laneData.begin() + dimIdx + 1);
+ if (!instData.empty())
+ instData.erase(instData.begin() + dimIdx);
+ if (!laneLayout.empty()) {
+ laneLayout.erase(laneLayout.begin() + dimIdx);
+ laneData.erase(laneData.begin() + dimIdx);
}
}
@@ -543,31 +536,18 @@ DistributeLayoutAttr LayoutAttr::dropDims(SmallVector<int64_t> dimGroup) {
int64_t offset = llvm::count_if(dimGroup, [&](int64_t s) { return s < d; });
newOrder.push_back(d - offset);
}
-
- // Create dropped layout
- SmallVector<int32_t> sgLayout32(sgLayout.begin(), sgLayout.end());
- SmallVector<int32_t> sgData32(sgData.begin(), sgData.end());
- SmallVector<int32_t> instData32(instData.begin(), instData.end());
- SmallVector<int32_t> laneLayout32(laneLayout.begin(), laneLayout.end());
- SmallVector<int32_t> laneData32(laneData.begin(), laneData.end());
- SmallVector<int32_t> newOrder32(newOrder.begin(), newOrder.end());
-
- DenseI32ArrayAttr orderAttr = getOrder();
+ if (sgLayout.empty() && laneLayout.empty())
+ newOrder.clear();
+
+ auto toAttr = [&](ArrayRef<int64_t> v) -> DenseI32ArrayAttr {
+ if (v.empty())
+ return DenseI32ArrayAttr();
+ SmallVector<int32_t> v32(v.begin(), v.end());
+ return DenseI32ArrayAttr::get(getContext(), v32);
+ };
auto droppedLayout = xegpu::LayoutAttr::get(
- getContext(),
- sgLayout32.empty() ? DenseI32ArrayAttr()
- : DenseI32ArrayAttr::get(getContext(), sgLayout32),
- sgData32.empty() ? DenseI32ArrayAttr()
- : DenseI32ArrayAttr::get(getContext(), sgData32),
- instData32.empty() ? DenseI32ArrayAttr()
- : DenseI32ArrayAttr::get(getContext(), instData32),
- laneLayout32.empty() ? DenseI32ArrayAttr()
- : DenseI32ArrayAttr::get(getContext(), laneLayout32),
- laneData32.empty() ? DenseI32ArrayAttr()
- : DenseI32ArrayAttr::get(getContext(), laneData32),
- (!orderAttr || orderAttr.empty())
- ? DenseI32ArrayAttr()
- : DenseI32ArrayAttr::get(getContext(), newOrder32));
+ getContext(), toAttr(sgLayout), toAttr(sgData), toAttr(instData),
+ toAttr(laneLayout), toAttr(laneData), toAttr(newOrder));
return droppedLayout;
}
@@ -581,26 +561,19 @@ DistributeLayoutAttr LayoutAttr::collapseDims(SmallVector<int64_t> dimGroup) {
SmallVector<int64_t> instData = getEffectiveInstDataAsInt();
SmallVector<int64_t> laneLayout = getEffectiveLaneLayoutAsInt();
SmallVector<int64_t> laneData = getEffectiveLaneDataAsInt();
-
- DenseI32ArrayAttr orderAttr = getOrder();
- SmallVector<int32_t> orderVec;
- if (orderAttr && !orderAttr.empty()) {
- orderVec = llvm::to_vector(
- llvm::map_range(orderAttr.asArrayRef(),
- [](int32_t idx) { return static_cast<int32_t>(idx); }));
- }
+ SmallVector<int64_t> origOrder = getEffectiveOrderAsInt();
SmallVector<int64_t> sortedDimGroup = dimGroup;
llvm::sort(sortedDimGroup);
int64_t dimBeforeCurrent = -1;
for (auto dimIdx : sortedDimGroup) {
- // when order is present, adjacency dims are values like [3, 2, 1, 0]
+ // when order attr is present, adjacency dims are values like [3, 2, 1, 0]
// in decreasing order; otherwise based on dim indices like [0, 1, 2, 3]
// in increasing order
if (dimBeforeCurrent >= 0) {
- if (!orderVec.empty()) {
- int64_t orderBefore = orderVec[dimBeforeCurrent];
- int64_t orderCurrent = orderVec[dimIdx];
+ if (getOrder() && !getOrder().empty()) {
+ int64_t orderBefore = origOrder[dimBeforeCurrent];
+ int64_t orderCurrent = origOrder[dimIdx];
if (orderBefore != (orderCurrent - 1))
llvm::report_fatal_error(
"dimensions being collapsed must be adjacent in order");
@@ -656,52 +629,36 @@ DistributeLayoutAttr LayoutAttr::collapseDims(SmallVector<int64_t> dimGroup) {
laneData.insert(laneData.begin() + firstDim, collapsedLaneData);
}
- // After collapsing the order vector, re-map the order values to be in range
- // of [0, N-1] where N is the number of dimensions in collapsed shape. For
- // exmaple, collapse dim group {2, 3} of order[1, 2, 3, 4] to new order[1, 3,
- // 4]. the loop below remaps it to [1, 2, 3].
- SmallVector<int32_t> collapsedOrder;
- if (!orderVec.empty()) {
+ SmallVector<int64_t> newOrder;
+ DenseI32ArrayAttr orderAttr = getOrder();
+ if (orderAttr && !orderAttr.empty()) {
for (auto dimIdx : llvm::reverse(sortedDimGroup)) {
if (dimIdx != firstDim)
- orderVec.erase(orderVec.begin() + dimIdx,
- orderVec.begin() + dimIdx + 1);
+ origOrder.erase(origOrder.begin() + dimIdx);
}
// say we have orderVec = {5, 3, 2, 1, 0}
// Create indices [0, 1, 2, 3, 4]
SmallVector<size_t> indices =
- llvm::to_vector(llvm::seq<size_t>(0, orderVec.size()));
+ llvm::to_vector(llvm::seq<size_t>(0, orderAttr.size()));
// Sort indices based on corresponding values
llvm::sort(indices,
- [&](size_t a, size_t b) { return orderVec[a] < orderVec[b]; });
- collapsedOrder = llvm::to_vector(llvm::map_range(
- indices, [&](size_t i) { return static_cast<int32_t>(i); }));
- }
+ [&](size_t a, size_t b) { return origOrder[a] < origOrder[b]; });
- // Create collapsed layout
- SmallVector<int32_t> sgLayout32(sgLayout.begin(), sgLayout.end());
- SmallVector<int32_t> sgData32(sgData.begin(), sgData.end());
- SmallVector<int32_t> instData32(instData.begin(), instData.end());
- SmallVector<int32_t> laneLayout32(laneLayout.begin(), laneLayout.end());
- SmallVector<int32_t> laneData32(laneData.begin(), laneData.end());
+ newOrder = llvm::to_vector(llvm::map_range(
+ indices, [&](size_t i) { return static_cast<int64_t>(i); }));
+ }
+ auto toAttr = [&](ArrayRef<int64_t> v) -> DenseI32ArrayAttr {
+ if (v.empty())
+ return DenseI32ArrayAttr();
+ SmallVector<int32_t> v32(v.begin(), v.end());
+ return DenseI32ArrayAttr::get(getContext(), v32);
+ };
auto collapsedLayout = xegpu::LayoutAttr::get(
- getContext(),
- sgLayout32.empty() ? DenseI32ArrayAttr()
- : DenseI32ArrayAttr::get(getContext(), sgLayout32),
- sgData32.empty() ? DenseI32ArrayAttr()
- : DenseI32ArrayAttr::get(getContext(), sgData32),
- instData32.empty() ? DenseI32ArrayAttr()
- : DenseI32ArrayAttr::get(getContext(), instData32),
- laneLayout32.empty() ? DenseI32ArrayAttr()
- : DenseI32ArrayAttr::get(getContext(), laneLayout32),
- laneData32.empty() ? DenseI32ArrayAttr()
- : DenseI32ArrayAttr::get(getContext(), laneData32),
- collapsedOrder.empty()
- ? DenseI32ArrayAttr()
- : DenseI32ArrayAttr::get(getContext(), collapsedOrder));
+ getContext(), toAttr(sgLayout), toAttr(sgData), toAttr(instData),
+ toAttr(laneLayout), toAttr(laneData), toAttr(newOrder));
return collapsedLayout;
}
>From d2eaeba6fa48f1ccc70b0895b121724d46941e42 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Fri, 6 Mar 2026 03:41:59 +0000
Subject: [PATCH 6/7] fix issue in sliceAttr::dropdims
---
mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 29 +++++++++++++++++++---
1 file changed, 25 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 4c434a35755a8..9d99d402637fb 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -890,9 +890,23 @@ DistributeLayoutAttr SliceAttr::setDimData(int64_t dim, int64_t sgData,
parent.setDimData(adjustDims[0], sgData, instData, laneData), getDims());
}
-// Derive a new layout by removing dimensions.
-// `dimGroup` specifies a group of dimensions to be removed in the derived
-// layout.
+// Derive a new layout by removing dimensions. `dimGroup` specifies a group of
+// dimensions to be removed in the derived layout.
+//
+// Example: drop the 2nd dimension from a rank-3 sliced view.
+//
+// Suppose:
+// xegpu.layout = slice<layout<[V0, V1, V2, V3, V4]>, [1, 3]>
+//
+// The slice removes parent dims [1, 3], so the sliced-space dims map to
+// parent dims [V0, V2, V4].
+//
+// If we drop sliced-space dim 1 (the 2nd dim), that corresponds to dropping
+// parent dim 2, result in parent layout [V0, V1, V3, V4] after dropping.
+// After parent dim 2 is removed, sliced dims [1, 3] must be reindexed to [1, 2].
+//
+// Result:
+// xegpu.layout = slice<layout<[0, 1, 3, 4]>, [1, 2]>
DistributeLayoutAttr SliceAttr::dropDims(SmallVector<int64_t> dimGroup) {
// Map the sliced dims from parent space to collapsed space
SmallVector<int64_t> sliceDims = llvm::to_vector(getDims().asArrayRef());
@@ -900,8 +914,15 @@ DistributeLayoutAttr SliceAttr::dropDims(SmallVector<int64_t> dimGroup) {
mapSlicedDimsToParentSpace(dimGroup, sliceDims);
auto droppedParent = getParent().dropDims(dimsInParentSpace);
+
+ SmallVector<int64_t> newSliceDims;
+ for (int64_t d : sliceDims) {
+ int64_t offset = llvm::count_if(dimsInParentSpace, [&](int64_t s) { return s < d; });
+ newSliceDims.push_back(d - offset);
+ }
+
return SliceAttr::get(getContext(), droppedParent,
- DenseI64ArrayAttr::get(getContext(), sliceDims));
+ DenseI64ArrayAttr::get(getContext(), newSliceDims));
}
// Derive a new layout by collapsing dimensions.
>From 4253006d91cdb718376986587c8d38dc99a29794 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Fri, 6 Mar 2026 14:59:07 +0000
Subject: [PATCH 7/7] add comments
---
mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 9 +++++++--
1 file changed, 7 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 9d99d402637fb..a6b91bb8febc9 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -903,7 +903,8 @@ DistributeLayoutAttr SliceAttr::setDimData(int64_t dim, int64_t sgData,
//
// If we drop sliced-space dim 1 (the 2nd dim), that corresponds to dropping
// parent dim 2, result in parent layout [V0, V1, V3, V4] after dropping.
-// After parent dim 2 is removed, sliced dims [1, 3] must be reindexed to [1, 2].
+// After parent dim 2 is removed, sliced dims [1, 3] must be reindexed to [1,
+// 2].
//
// Result:
// xegpu.layout = slice<layout<[0, 1, 3, 4]>, [1, 2]>
@@ -915,9 +916,13 @@ DistributeLayoutAttr SliceAttr::dropDims(SmallVector<int64_t> dimGroup) {
auto droppedParent = getParent().dropDims(dimsInParentSpace);
+ // Adjust the sliced dims after dropping dims in parent space. For example, if
+ // we drop dim 2 in parent space, the dims after dim 2 will all be shifted by
+ // 1, so sliced dim 3 will be adjusted to 2.
SmallVector<int64_t> newSliceDims;
for (int64_t d : sliceDims) {
- int64_t offset = llvm::count_if(dimsInParentSpace, [&](int64_t s) { return s < d; });
+ int64_t offset =
+ llvm::count_if(dimsInParentSpace, [&](int64_t s) { return s < d; });
newSliceDims.push_back(d - offset);
}
More information about the Mlir-commits
mailing list