[Mlir-commits] [mlir] [MLIR][XeGPU] setUnitDim bug fix and add documentation (PR #173521)
Jianhui Li
llvmlistbot at llvm.org
Thu Jan 15 16:55:16 PST 2026
https://github.com/Jianhui-Li updated https://github.com/llvm/llvm-project/pull/173521
>From daf352dd17e61c34ff69c5fa220c93cd6458f3e7 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Mon, 22 Dec 2025 03:22:02 +0000
Subject: [PATCH 1/2] fix-setUnitDim
---
.../mlir/Dialect/XeGPU/IR/XeGPUAttrs.td | 16 ++---
mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 59 ++++++++++---------
.../Transforms/XeGPUWgToSgDistribute.cpp | 17 +++---
mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp | 3 +-
.../XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir | 2 +-
.../XeGPU/xegpu-wg-to-sg-unify-ops.mlir | 38 +++++++++++-
6 files changed, 87 insertions(+), 48 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index 446f64fffa468..d4d5c7d58b37e 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -454,7 +454,7 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> {
return 0;
}
- LayoutAttr dropSgLayoutAndData() {
+ LayoutAttr dropSgLayoutAndData() const{
// avoid every field of the attribute is nullptr, which may lead to segment fault
if (!getInstData() && !getLaneLayout())
return nullptr;
@@ -462,7 +462,7 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> {
getLaneLayout(), getLaneData(), getOrder());
}
- LayoutAttr dropInstData() {
+ LayoutAttr dropInstData() const{
// avoid every field of the attribute is nullptr, which may lead to segment fault
if (!getSgLayout() && !getLaneLayout())
return nullptr;
@@ -501,10 +501,10 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> {
}
//set the layout for the sepcified unit dims: sg_data, inst_data and lane_data to 1
- DistributeLayoutAttr setUnitDimData(SetVector<int64_t> unitDims);
+ DistributeLayoutAttr setUnitDimData(SetVector<int64_t> unitDims) const;
//set the layout for the sepcified unit dims: sg_lane and lane_layout to 1
- DistributeLayoutAttr setUnitDimLayout(SetVector<int64_t> unitDims);
+ DistributeLayoutAttr setUnitDimLayout(SetVector<int64_t> unitDims) const;
/// Delinearizes a linear ID into its multidimensional indices
/// based on the effective level of the layout.
@@ -653,7 +653,7 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> {
return {};
}
- SliceAttr dropSgLayoutAndData() {
+ SliceAttr dropSgLayoutAndData() const{
SliceAttr attr = flatten();
auto parent = dyn_cast<LayoutAttr>(attr.getParent());
parent = parent.dropSgLayoutAndData();
@@ -662,7 +662,7 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> {
return SliceAttr::get(getContext(), parent, attr.getDims());
}
- SliceAttr dropInstData() {
+ SliceAttr dropInstData() const{
SliceAttr attr = flatten();
auto parent = dyn_cast<LayoutAttr>(attr.getParent());
parent = parent.dropInstData();
@@ -672,10 +672,10 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> {
}
//set the layout for the sepcified unit dims: sg_data, inst_data and lane_data to 1
- DistributeLayoutAttr setUnitDimData(SetVector<int64_t> unitDims);
+ DistributeLayoutAttr setUnitDimData(SetVector<int64_t> unitDims) const;
//set the layout for the sepcified unit dims: sg_lane and lane_layout to 1
- DistributeLayoutAttr setUnitDimLayout(SetVector<int64_t> unitDims);
+ DistributeLayoutAttr setUnitDimLayout(SetVector<int64_t> unitDims) const;
/// flatten a nested SliceAttr, e.g., for 2-level nested SliceAttr
/// #xegpu.slice<#xegpu.slice<#xegpu.layout<sg_layout = [4, 8, 12]>, dims = [0]>, dims = [0]>
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index ccf17da26c942..13bf475920075 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -400,7 +400,7 @@ bool LayoutAttr::isEqualTo(const xegpu::DistributeLayoutAttr &other) {
}
// set the layout for unit dims: sg_data, inst_data and lane_data to 1
-DistributeLayoutAttr LayoutAttr::setUnitDimData(SetVector<int64_t> unitDims) {
+DistributeLayoutAttr LayoutAttr::setUnitDimData(SetVector<int64_t> unitDims) const{
auto sgDataOpt = getSgData();
auto instDataOpt = getInstData();
auto laneDataOpt = getLaneData();
@@ -441,7 +441,7 @@ DistributeLayoutAttr LayoutAttr::setUnitDimData(SetVector<int64_t> unitDims) {
}
// set the layout for the sepcified unit dims: sg_lane and lane_layout to 1
-DistributeLayoutAttr LayoutAttr::setUnitDimLayout(SetVector<int64_t> unitDims) {
+DistributeLayoutAttr LayoutAttr::setUnitDimLayout(SetVector<int64_t> unitDims) const{
auto sgLayoutOpt = getSgLayout();
auto laneLayoutOpt = getLaneLayout();
@@ -587,7 +587,7 @@ bool SliceAttr::isSliceOf(const xegpu::DistributeLayoutAttr &other) {
// dims.
llvm::SmallDenseSet<int64_t> thisDims(
flattenedThis.getDims().asArrayRef().begin(),
- flattenedThis.getDims().asArrayRef().end());
+ flattenedThis.getDims().asArrayRef().end());
return llvm::all_of(flattenedOther.getDims().asArrayRef(),
[&](int64_t dim) { return thisDims.contains(dim); });
}
@@ -604,55 +604,60 @@ bool SliceAttr::isEqualTo(const xegpu::DistributeLayoutAttr &other) {
}
// Helper function to adjust unit dimensions from sliced space to parent space
+// say we have a parent shape of rank 4, and slice dims [1,3], so the sliced
+// shape is of rank 2, if we want to set unit dim [0] in sliced space, it maps to
+// dim [0] in parent space; if we want to set unit dim [1] in sliced space, it maps to
+// dim [2] in parent space.
static SetVector<int64_t>
adjustUnitDimsWithSliceDims(const SetVector<int64_t> &unitDims,
ArrayRef<int64_t> sliceDims) {
- // Reconstruct parent's non-sliced dimensions
+ // get max number from sliceDims and unitDims to determine parent space rank
+ int64_t maxDim = -1;
+ maxDim = std::max(maxDim, *std::max_element(sliceDims.begin(), sliceDims.end()));
+ maxDim = std::max(maxDim, *std::max_element(unitDims.begin(), unitDims.end()));
+ int64_t parentSpaceRank = maxDim + sliceDims.size() + 1;
- int64_t parentRank = sliceDims.size() + unitDims.size();
+ // get remaining dims in parent space after applying slicing with parent's slice Dims
llvm::SmallDenseSet<int64_t> slicedDimsSet(sliceDims.begin(),
sliceDims.end());
- SmallVector<int64_t> nonSlicedDims;
- for (int64_t i = 0; i < parentRank; ++i) {
+ SmallVector<int64_t> remainingDims;
+ for (int64_t i = 0; i < parentSpaceRank; ++i) {
if (!slicedDimsSet.contains(i))
- nonSlicedDims.push_back(i);
+ remainingDims.push_back(i);
}
// Map unit dims from sliced space to parent space
SetVector<int64_t> adjustUnitDims;
for (auto dim : unitDims) {
- if (dim < static_cast<int64_t>(nonSlicedDims.size())) {
- adjustUnitDims.insert(nonSlicedDims[dim]);
- }
+ int64_t mappedDim = remainingDims[dim];
+ adjustUnitDims.insert(mappedDim);
}
return adjustUnitDims;
}
// set the layout for unit dims: sg_data, inst_data and lane_data to 1
-DistributeLayoutAttr SliceAttr::setUnitDimData(SetVector<int64_t> unitDims) {
- SliceAttr attr = flatten();
- ArrayRef<int64_t> sliceDims = attr.getDims().asArrayRef();
- auto parent = dyn_cast<LayoutAttr>(attr.getParent());
+DistributeLayoutAttr SliceAttr::setUnitDimData(SetVector<int64_t> unitDims) const{
+ DistributeLayoutAttr parentLayout = getParent();
+
+ ArrayRef<int64_t> sliceDims = getDims().asArrayRef();
- SetVector<int64_t> adjustUnitDims =
- adjustUnitDimsWithSliceDims(unitDims, sliceDims);
+ SetVector<int64_t> adjustUnitDims = adjustUnitDimsWithSliceDims(unitDims, sliceDims);
- return SliceAttr::get(getContext(), parent.setUnitDimData(adjustUnitDims),
- attr.getDims());
+ return SliceAttr::get(getContext(), parentLayout.setUnitDimData(adjustUnitDims),
+ getDims());
}
// set the layout for the sepcified unit dims: sg_lane and lane_layout to 1
-DistributeLayoutAttr SliceAttr::setUnitDimLayout(SetVector<int64_t> unitDims) {
- SliceAttr attr = flatten();
- ArrayRef<int64_t> sliceDims = attr.getDims().asArrayRef();
- auto parent = dyn_cast<LayoutAttr>(attr.getParent());
+DistributeLayoutAttr SliceAttr::setUnitDimLayout(SetVector<int64_t> unitDims) const{
+ DistributeLayoutAttr parentLayout = getParent();
+
+ ArrayRef<int64_t> sliceDims = getDims().asArrayRef();
- SetVector<int64_t> adjustUnitDims =
- adjustUnitDimsWithSliceDims(unitDims, sliceDims);
+ SetVector<int64_t> adjustUnitDims = adjustUnitDimsWithSliceDims(unitDims, sliceDims);
- return SliceAttr::get(getContext(), parent.setUnitDimLayout(adjustUnitDims),
- attr.getDims());
+ return SliceAttr::get(getContext(), parentLayout.setUnitDimLayout(adjustUnitDims),
+ getDims());
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 07572a4950760..0dcabf0fccf6a 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -614,20 +614,20 @@ struct WgToSgConvertLayoutOp
LogicalResult
matchAndRewrite(xegpu::ConvertLayoutOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- // TODO: currently, we only support LayoutAttr
- auto input = dyn_cast<xegpu::LayoutAttr>(op.getInputLayout());
- auto target = dyn_cast<xegpu::LayoutAttr>(op.getTargetLayout());
+
+ auto input = op.getInputLayout();
+ auto target = op.getTargetLayout();
if (!input || !target || !input.isForWorkgroup() ||
!target.isForWorkgroup())
return rewriter.notifyMatchFailure(
op, "Input and target layouts must have subgroup layout");
- DenseI32ArrayAttr inputSgLayout = input.getSgLayout();
- DenseI32ArrayAttr inputSgData = input.getSgData();
+ SmallVector<int64_t> inputSgLayout = input.getEffectiveSgLayoutAsInt();
+ SmallVector<int64_t> inputSgData = input.getEffectiveSgDataAsInt();
DenseI32ArrayAttr inputOrder = input.getOrder();
- DenseI32ArrayAttr targetSgLayout = target.getSgLayout();
- DenseI32ArrayAttr targetSgData = target.getSgData();
+ SmallVector<int64_t> targetSgLayout = target.getEffectiveSgLayoutAsInt();
+ SmallVector<int64_t> targetSgData = target.getEffectiveSgDataAsInt();
DenseI32ArrayAttr targetOrder = target.getOrder();
// TODO: currently we only support for optimal case, where input and
@@ -1138,12 +1138,11 @@ struct WgToSgVectorShapeCastOp
return false;
return srcIdx == src.size();
};
-
xegpu::DistributeLayoutAttr layoutToDistribute = layout;
if (checkOnlyExpandUnitDims(srcShape, wgShape)) {
xegpu::DistributeLayoutAttr sourceLayout =
- xegpu::getDistributeLayoutAttr(op.getSource());
+ xegpu::getTemporaryLayout(op->getOpOperand(0));
auto usedByBroadcastOp = [](vector::ShapeCastOp op) {
return llvm::all_of(op.getResult().getUsers(), [](Operation *user) {
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index d3906e37ffbf1..10fe06ddb756f 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -186,8 +186,7 @@ xegpu::getDistributeLayoutAttr(const OpOperand &opr) {
return layout;
}
- auto layout = getDistributeLayoutAttr(opr.get());
- return layout;
+ return nullptr;
}
// Returns the permanent layout attribute for the given result if it's
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
index ad346307437e4..5508e6e938f67 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
@@ -155,7 +155,7 @@ gpu.module @test_distribution {
%0 = xegpu.create_nd_tdesc %arg0 : memref<4096x128xf32> -> !xegpu.tensor_desc<256x128xf32, #xegpu.block_tdesc_attr<boundary_check = false>, #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>>
%1 = xegpu.load_nd %0[%block_id_x, 0] : !xegpu.tensor_desc<256x128xf32, #xegpu.block_tdesc_attr<boundary_check = false>, #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>> -> vector<256x128xf32>
%2 = vector.multi_reduction <maximumf>, %1, %cst_0 {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>, dims = [1]>} [1] : vector<256x128xf32> to vector<256xf32>
- %3 = vector.shape_cast %2 {layout_result_0 = #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>} : vector<256xf32> to vector<256x1xf32>
+ %3 = vector.shape_cast %2 {layout_result_0 = #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>, layout_operand_0 = #xegpu.slice<#xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>, dims = [1]>} : vector<256xf32> to vector<256x1xf32>
%4 = vector.broadcast %3 {layout_result_0 = #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>} : vector<256x1xf32>to vector<256x128xf32>
%9 = xegpu.create_nd_tdesc %arg0 : memref<4096x128xf32> -> !xegpu.tensor_desc<256x128xf32, #xegpu.block_tdesc_attr<boundary_check = false>, #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>>
xegpu.store_nd %4, %9[%block_id_x, 0] : vector<256x128xf32>, !xegpu.tensor_desc<256x128xf32, #xegpu.block_tdesc_attr<boundary_check = false>, #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>>
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index da6ad976d3730..d7145ab91ac2f 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -462,7 +462,7 @@ gpu.module @test_distribution {
%step = vector.step {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 1, 1, 4], sg_data = [1, 1, 1, 32]>, dims = [0, 1, 2]>} : vector<128xindex>
%muli = arith.muli %cst, %step {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 1, 1, 4], sg_data = [1, 1, 1, 32]>, dims = [0, 1, 2]>} : vector<128xindex>
//CHECK: vector.shape_cast {{.*}} : vector<32xindex> to vector<1x1x1x32xindex>
- %shape_cast = vector.shape_cast %muli {layout_result_0 = #xegpu.layout<sg_layout = [8, 1, 1, 4], sg_data = [1, 1, 1, 32]>} : vector<128xindex> to vector<1x1x1x128xindex>
+ %shape_cast = vector.shape_cast %muli {layout_result_0 = #xegpu.layout<sg_layout = [8, 1, 1, 4], sg_data = [1, 1, 1, 32]>, layout_operand_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 1, 1, 4], sg_data = [1, 1, 1, 32]>, dims = [0, 1, 2]>} : vector<128xindex> to vector<1x1x1x128xindex>
gpu.return
}
@@ -655,4 +655,40 @@ gpu.module @test_distribution {
-> vector<256x128xf32>
gpu.return
}
+
+ // CHECK-LABEL: distribute_nested_slice
+ // CHECK: %[[V0:.*]] = vector.shape_cast %{{.*}} : vector<32x32xf32> to vector<32x1x32x1xf32>
+ // CHECK: %[[V1:.*]] = vector.broadcast %[[V0]] : vector<32x1x32x1xf32> to vector<32x16x32x16xf32>
+ // CHECK: %[[V2:.*]] = vector.shape_cast %[[V1]] : vector<32x16x32x16xf32> to vector<32x16x32x16x1xf32>
+ // CHECK: %[[V3:.*]] = vector.broadcast %[[V2]] : vector<32x16x32x16x1xf32> to vector<32x16x32x16x16xf32>
+ // CHECK: %[[V4:.*]] = vector.shape_cast %[[V3]] : vector<32x16x32x16x16xf32> to vector<32x16x1x32x16x16xf32>
+ // CHECK: %[[V5:.*]] = vector.broadcast %[[V4]] : vector<32x16x1x32x16x16xf32> to vector<32x16x16x32x16x16xf32>
+ gpu.func @distribute_nested_slice(%src: memref<256x256xf32>) {
+
+ %tdesc = xegpu.create_nd_tdesc %src : memref<256x256xf32>
+ -> !xegpu.tensor_desc<256x256xf32, #xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32]>>
+
+ %load = xegpu.load_nd %tdesc[0, 0] {layout = #xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32]>}
+ : !xegpu.tensor_desc<256x256xf32, #xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32]>>
+ -> vector<256x256xf32>
+
+ %load2 = xegpu.convert_layout %load <{input_layout = #xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32]>, target_layout = #xegpu.slice<#xegpu.slice<#xegpu.slice<#xegpu.layout<sg_layout = [8, 1, 1, 8, 1, 1], sg_data = [32, 16, 16, 32, 16, 16]>, dims=[2]>, dims=[4]>, dims=[1, 3]>}> : vector<256x256xf32>
+
+ %scast = vector.shape_cast %load2 {layout_result_0 = #xegpu.slice<#xegpu.slice<#xegpu.layout<sg_layout = [8, 1, 1, 8, 1, 1], sg_data = [32, 16, 16, 32, 16, 16]>, dims=[2]>, dims=[4]>, layout_operand_0 = #xegpu.slice<#xegpu.slice<#xegpu.slice<#xegpu.layout<sg_layout = [8, 1, 1, 8, 1, 1], sg_data = [32, 16, 16, 32, 16, 16]>, dims=[2]>, dims=[4]>, dims=[1, 3]>} : vector<256x256xf32> to vector<256x1x256x1xf32>
+
+ %bcast = vector.broadcast %scast {layout_result_0 = #xegpu.slice<#xegpu.slice<#xegpu.layout<sg_layout = [8, 1, 1, 8, 1, 1], sg_data = [32, 16, 16, 32, 16, 16]>, dims=[2]>, dims=[4]>, layout_operand_0 = #xegpu.slice<#xegpu.slice<#xegpu.layout<sg_layout = [8, 1, 1, 8, 1, 1], sg_data = [32, 16, 16, 32, 16, 16]>, dims=[2]>, dims=[4]>} : vector<256x1x256x1xf32> to vector<256x16x256x16xf32>
+
+ %scast1 = vector.shape_cast %bcast {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 1, 1, 8, 1, 1], sg_data = [32, 16, 16, 32, 16, 16]>, dims=[2]>, layout_operand_0 = #xegpu.slice<#xegpu.slice<#xegpu.layout<sg_layout = [8, 1, 1, 8, 1, 1], sg_data = [32, 16, 16, 32, 16, 16]>, dims=[2]>, dims=[4]>} : vector<256x16x256x16xf32> to vector<256x16x256x16x1xf32>
+
+ %bcast1 = vector.broadcast %scast1 {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 1, 1, 8, 1, 1], sg_data = [32, 16, 16, 32, 16, 16]>, dims=[2]>, layout_operand_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 1, 1, 8, 1, 1], sg_data = [32, 16, 16, 32, 16, 16]>, dims=[2]>} : vector<256x16x256x16x1xf32> to vector<256x16x256x16x16xf32>
+
+ %scast2 = vector.shape_cast %bcast1 {layout_result_0 =
+ #xegpu.layout<sg_layout = [8, 1, 1, 8, 1, 1], sg_data = [32, 16, 16, 32, 16, 16]>, layout_operand_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 1, 1, 8, 1, 1], sg_data = [32, 16, 16, 32, 16, 16]>, dims=[2]>} : vector<256x16x256x16x16xf32> to vector<256x16x1x256x16x16xf32>
+
+ %bcast2 = vector.broadcast %scast2 {layout_result_0 =
+ #xegpu.layout<sg_layout = [8, 1, 1, 8, 1, 1], sg_data = [32, 16, 16, 32, 16, 16]>, layout_operand_0 =
+ #xegpu.layout<sg_layout = [8, 1, 1, 8, 1, 1], sg_data = [32, 16, 16, 32, 16, 16]>} : vector<256x16x1x256x16x16xf32> to vector<256x16x16x256x16x16xf32>
+ gpu.return
+ }
+
}
>From 1fd70d9648c446be3a552b81ed5dacec20b78b7f Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Fri, 16 Jan 2026 00:53:32 +0000
Subject: [PATCH 2/2] address feedback
---
mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 32 +++++++++++++---------
1 file changed, 19 insertions(+), 13 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 13bf475920075..abcfadf9b6fc6 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -587,7 +587,7 @@ bool SliceAttr::isSliceOf(const xegpu::DistributeLayoutAttr &other) {
// dims.
llvm::SmallDenseSet<int64_t> thisDims(
flattenedThis.getDims().asArrayRef().begin(),
- flattenedThis.getDims().asArrayRef().end());
+ flattenedThis.getDims().asArrayRef().end());
return llvm::all_of(flattenedOther.getDims().asArrayRef(),
[&](int64_t dim) { return thisDims.contains(dim); });
}
@@ -603,18 +603,22 @@ bool SliceAttr::isEqualTo(const xegpu::DistributeLayoutAttr &other) {
(flattenedThis.getDims() == flattenedOther.getDims()));
}
-// Helper function to adjust unit dimensions from sliced space to parent space
+// Helper function to adjust dimensions from sliced space to parent space
// say we have a parent shape of rank 4, and slice dims [1,3], so the sliced
-// shape is of rank 2, if we want to set unit dim [0] in sliced space, it maps to
-// dim [0] in parent space; if we want to set unit dim [1] in sliced space, it maps to
-// dim [2] in parent space.
+// shape is of rank 2. Sliced dim 0 then maps to parent dim 0, and sliced
+// dim 1 maps to parent dim 2 (parent dims 1 and 3 are the ones sliced
+// away).
static SetVector<int64_t>
-adjustUnitDimsWithSliceDims(const SetVector<int64_t> &unitDims,
- ArrayRef<int64_t> sliceDims) {
+mapSlicedDimsToParentSpace(const SetVector<int64_t> &dimsToMap,
+ ArrayRef<int64_t> sliceDims) {
// get max number from sliceDims and unitDims to determine parent space rank
+ // The rank computed here (maxDim + #sliceDims + 1) need not equal the real
+ // parent rank: it only has to exclude every slice dim and leave more than
+ // maxDim remaining dims. NOTE(review): both ranges must be non-empty.
int64_t maxDim = -1;
maxDim = std::max(maxDim, *std::max_element(sliceDims.begin(), sliceDims.end()));
- maxDim = std::max(maxDim, *std::max_element(unitDims.begin(), unitDims.end()));
+ maxDim =
+ std::max(maxDim, *std::max_element(dimsToMap.begin(), dimsToMap.end()));
int64_t parentSpaceRank = maxDim + sliceDims.size() + 1;
// get remaining dims in parent space after applying slicing with parent's slice Dims
@@ -628,9 +632,9 @@ adjustUnitDimsWithSliceDims(const SetVector<int64_t> &unitDims,
// Map unit dims from sliced space to parent space
SetVector<int64_t> adjustUnitDims;
- for (auto dim : unitDims) {
- int64_t mappedDim = remainingDims[dim];
- adjustUnitDims.insert(mappedDim);
+ for (auto dim : dimsToMap) {
+ int64_t mappedDim = remainingDims[dim];
+ adjustUnitDims.insert(mappedDim);
}
return adjustUnitDims;
@@ -642,7 +646,8 @@ DistributeLayoutAttr SliceAttr::setUnitDimData(SetVector<int64_t> unitDims) cons
ArrayRef<int64_t> sliceDims = getDims().asArrayRef();
- SetVector<int64_t> adjustUnitDims = adjustUnitDimsWithSliceDims(unitDims, sliceDims);
+ SetVector<int64_t> adjustUnitDims =
+ mapSlicedDimsToParentSpace(unitDims, sliceDims);
return SliceAttr::get(getContext(), parentLayout.setUnitDimData(adjustUnitDims),
getDims());
@@ -654,7 +659,8 @@ DistributeLayoutAttr SliceAttr::setUnitDimLayout(SetVector<int64_t> unitDims) co
ArrayRef<int64_t> sliceDims = getDims().asArrayRef();
- SetVector<int64_t> adjustUnitDims = adjustUnitDimsWithSliceDims(unitDims, sliceDims);
+ SetVector<int64_t> adjustUnitDims =
+ mapSlicedDimsToParentSpace(unitDims, sliceDims);
return SliceAttr::get(getContext(), parentLayout.setUnitDimLayout(adjustUnitDims),
getDims());
More information about the Mlir-commits
mailing list