[Mlir-commits] [mlir] [MLIR][XeGPU] Enhance Layout Propagation for broadcasting both leading dimensions and inner unit dimensions (PR #185583)
Jianhui Li
llvmlistbot at llvm.org
Tue Mar 10 17:47:44 PDT 2026
https://github.com/Jianhui-Li updated https://github.com/llvm/llvm-project/pull/185583
>From ae1c313b5fe22cabe08de33e9e0775103e4e7185 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Fri, 6 Mar 2026 14:00:56 +0000
Subject: [PATCH 1/4] enhance broadcast layout inference
---
.../XeGPU/Transforms/XeGPULayoutImpl.cpp | 20 +++++++++++--------
.../XeGPU/resolve-layout-conflicts.mlir | 2 +-
2 files changed, 13 insertions(+), 9 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
index 7aa186bb22224..a278b87070bf2 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
@@ -143,20 +143,24 @@ xegpu::inferBroadcastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
ArrayRef<int64_t> srcShape) {
SmallVector<int64_t> bcastDims;
- auto returnLayout = resLayout;
// Handling broadcast from low-rank to high-rank (e.g., 1D to 2D) case.
- int dimDiff = resShape.size() - srcShape.size();
+ size_t dimDiff = resShape.size() - srcShape.size();
+ for (size_t i = 0; i < dimDiff; i++)
+ bcastDims.push_back(i);
- if (dimDiff > 0) {
- // Adding the missing leading dims
- for (int i = 0; i < dimDiff; i++)
+ for (size_t i = 0; i < resShape.size(); i++)
+ if ((i < dimDiff) || ((srcShape[i - dimDiff] == 1) && (resShape[i] != 1)))
bcastDims.push_back(i);
- // Create a slice layout for the source
+ auto returnLayout = resLayout.setUnitDimData(bcastDims);
+ if (dimDiff > 0) {
+ SmallVector<int64_t> sliceDims;
+ for (size_t i = 0; i < dimDiff; i++)
+ sliceDims.push_back(i);
returnLayout = xegpu::SliceAttr::get(
- resLayout.getContext(), resLayout,
- DenseI64ArrayAttr::get(resLayout.getContext(), bcastDims));
+ resLayout.getContext(), returnLayout,
+ DenseI64ArrayAttr::get(resLayout.getContext(), sliceDims));
}
return returnLayout;
}
diff --git a/mlir/test/Dialect/XeGPU/resolve-layout-conflicts.mlir b/mlir/test/Dialect/XeGPU/resolve-layout-conflicts.mlir
index c73a9990542de..033ad188e5347 100644
--- a/mlir/test/Dialect/XeGPU/resolve-layout-conflicts.mlir
+++ b/mlir/test/Dialect/XeGPU/resolve-layout-conflicts.mlir
@@ -115,7 +115,7 @@ func.func @elementwise_conflict_uniform() -> vector<2xf16> {
// CHECK-LABEL: func.func @broadcast_source_conflict
// CHECK: %[[V0:.*]] = "some_op"() {layout_result_0 = #xegpu.layout<inst_data = [16]>} : () -> vector<16xf16>
// CHECK: %[[CVT:.*]] = xegpu.convert_layout %[[V0]]
-// CHECK-SAME: <{input_layout = #xegpu.layout<inst_data = [16]>, target_layout = #xegpu.slice<#xegpu.layout<inst_data = [16, 16]>, dims = [0]>}>
+// CHECK-SAME: <{input_layout = #xegpu.layout<inst_data = [16]>, target_layout = #xegpu.slice<#xegpu.layout<inst_data = [1, 16]>, dims = [0]>}>
// CHECK-SAME: : vector<16xf16>
// CHECK: %[[BC:.*]] = vector.broadcast %[[CVT]]
// CHECK-SAME: {layout_result_0 = #xegpu.layout<inst_data = [16, 16]>} : vector<16xf16> to vector<16x16xf16>
>From 2c52353c52768dc1128f195f812f376d8ec35c71 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Sat, 7 Mar 2026 04:53:31 +0000
Subject: [PATCH 2/4] improve broadcast layout inference
---
.../Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp | 14 +++++++++-----
.../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 12 +++++++-----
.../Dialect/XeGPU/resolve-layout-conflicts.mlir | 2 +-
3 files changed, 17 insertions(+), 11 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
index a278b87070bf2..c244a27b988b6 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
@@ -146,14 +146,18 @@ xegpu::inferBroadcastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
// Handling broadcast from low-rank to high-rank (e.g., 1D to 2D) case.
size_t dimDiff = resShape.size() - srcShape.size();
- for (size_t i = 0; i < dimDiff; i++)
- bcastDims.push_back(i);
+ // for (size_t i = 0; i < dimDiff; i++)
+ // bcastDims.push_back(i);
- for (size_t i = 0; i < resShape.size(); i++)
- if ((i < dimDiff) || ((srcShape[i - dimDiff] == 1) && (resShape[i] != 1)))
+ auto returnLayout = resLayout;
+ for (size_t i = dimDiff; i < resShape.size(); i++) {
+ if ((srcShape[i - dimDiff] == 1) && (resShape[i] != 1))
bcastDims.push_back(i);
+ }
+
+ if (!bcastDims.empty())
+ returnLayout = returnLayout.setUnitDimData(bcastDims);
- auto returnLayout = resLayout.setUnitDimData(bcastDims);
if (dimDiff > 0) {
SmallVector<int64_t> sliceDims;
for (size_t i = 0; i < dimDiff; i++)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 7f7e8d6ad7734..5abc40b3aea48 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -676,11 +676,13 @@ void LayoutInfoPropagation::visitVectorBroadCastOp(
auto srcShape = sourceTy.getShape();
auto resShape = resultTy.getShape();
- size_t dimDiff = resultTy.getRank() - sourceTy.getRank();
- for (size_t i = 0; i < srcShape.size(); i++)
- if ((srcShape[i] == 1) && (resShape[i + dimDiff] != 1))
- broadcast.emitWarning("broadcast must either from low-rank or same-rank "
- "with unit-dim, mixed scenario is not supported!");
+ // size_t dimDiff = resultTy.getRank() - sourceTy.getRank();
+ // for (size_t i = 0; i < srcShape.size(); i++)
+ // if ((srcShape[i] == 1) && (resShape[i + dimDiff] != 1))
+ // broadcast.emitWarning("broadcast must either from low-rank or same-rank
+ // "
+ // "with unit-dim, mixed scenario is not
+ // supported!");
auto resultLayoutAttr =
dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
diff --git a/mlir/test/Dialect/XeGPU/resolve-layout-conflicts.mlir b/mlir/test/Dialect/XeGPU/resolve-layout-conflicts.mlir
index 033ad188e5347..c73a9990542de 100644
--- a/mlir/test/Dialect/XeGPU/resolve-layout-conflicts.mlir
+++ b/mlir/test/Dialect/XeGPU/resolve-layout-conflicts.mlir
@@ -115,7 +115,7 @@ func.func @elementwise_conflict_uniform() -> vector<2xf16> {
// CHECK-LABEL: func.func @broadcast_source_conflict
// CHECK: %[[V0:.*]] = "some_op"() {layout_result_0 = #xegpu.layout<inst_data = [16]>} : () -> vector<16xf16>
// CHECK: %[[CVT:.*]] = xegpu.convert_layout %[[V0]]
-// CHECK-SAME: <{input_layout = #xegpu.layout<inst_data = [16]>, target_layout = #xegpu.slice<#xegpu.layout<inst_data = [1, 16]>, dims = [0]>}>
+// CHECK-SAME: <{input_layout = #xegpu.layout<inst_data = [16]>, target_layout = #xegpu.slice<#xegpu.layout<inst_data = [16, 16]>, dims = [0]>}>
// CHECK-SAME: : vector<16xf16>
// CHECK: %[[BC:.*]] = vector.broadcast %[[CVT]]
// CHECK-SAME: {layout_result_0 = #xegpu.layout<inst_data = [16, 16]>} : vector<16xf16> to vector<16x16xf16>
>From a00bd8a30e4f80824123f09ccbabda261ebfe35a Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Mon, 9 Mar 2026 23:21:16 +0000
Subject: [PATCH 3/4] polish the implementation: change warning to assert
---
.../XeGPU/Transforms/XeGPULayoutImpl.cpp | 2 ++
.../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 17 ++++++++++-------
.../Transforms/XeGPUSubgroupDistribute.cpp | 9 +++------
.../XeGPU/Transforms/XeGPUWgToSgDistribute.cpp | 11 ++++++-----
.../XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir | 4 ++--
.../Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir | 14 +++++++-------
6 files changed, 30 insertions(+), 27 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
index 2f232bcba3496..b94c4b37fce0a 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
@@ -155,6 +155,8 @@ xegpu::inferBroadcastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
bcastDims.push_back(i);
}
+ // The sg_layout and lane_layout of unit dimensions are preserved so they can
+ // be propagated to the producer op and potentially used by the multi-reduction op.
if (!bcastDims.empty())
returnLayout = returnLayout.setUnitDimData(bcastDims);
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 8141e9b87ac95..28da5226f1413 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -676,13 +676,16 @@ void LayoutInfoPropagation::visitVectorBroadCastOp(
auto srcShape = sourceTy.getShape();
auto resShape = resultTy.getShape();
- // size_t dimDiff = resultTy.getRank() - sourceTy.getRank();
- // for (size_t i = 0; i < srcShape.size(); i++)
- // if ((srcShape[i] == 1) && (resShape[i + dimDiff] != 1))
- // broadcast.emitWarning("broadcast must either from low-rank or same-rank
- // "
- // "with unit-dim, mixed scenario is not
- // supported!");
+ size_t dimDiff = resultTy.getRank() - sourceTy.getRank();
+ if (dimDiff == 0) {
+ bool hasUnitDim =
+ llvm::any_of(srcShape, [](int64_t dim) { return dim == 1; });
+ Operation *srcOp = broadcast.getSource().getDefiningOp();
+ bool produceByShapeCast = srcOp || isa<vector::ShapeCastOp>(srcOp);
+ assert(
+ hasUnitDim && produceByShapeCast &&
+ "When broadcasting from unit-dim, the producer op must be shape_cast!");
+ }
auto resultLayoutAttr =
dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 38bc95d39c2c6..d7db33a0913b8 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -1520,12 +1520,9 @@ struct VectorBroadcastDistribution : public gpu::WarpDistributionPattern {
auto broadcastUnitDimsSet = broadcastOp.computeBroadcastedUnitDims();
SmallVector<int64_t> broadcastUnitDims(broadcastUnitDimsSet.begin(),
broadcastUnitDimsSet.end());
- bool isEqualTo = sourceLayout.isEqualTo(resultLayout);
- if (!isEqualTo)
- return rewriter.notifyMatchFailure(
- warpOp, "For same-rank broadcast, source must be identical to "
- "adjusted result layouts with unit dims.");
- resultLayout = resultLayout.setUnitDimData(broadcastUnitDims);
+ assert(sourceLayout.isEqualTo(
+ sourceLayout.setUnitDimData(broadcastUnitDims)) &&
+ "The sg_data for unit dimensions should be set as 1");
sourceLayout = sourceLayout.setUnitDimLayout(broadcastUnitDims);
}
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 139a30e76854f..5a380685d9e7b 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -1173,11 +1173,12 @@ struct WgToSgVectorShapeCastOp
if (!sourceLayout.isSliceOf(layout))
return rewriter.notifyMatchFailure(
- op, "The ShapeCast op only expands dimensions, the result layout "
- "must be a slice of the input layout, or vice versa.");
- layoutToDistribute = layoutToDistribute.setUnitDimData(expandedUnitDims);
- layoutToDistribute =
- layoutToDistribute.setUnitDimLayout(expandedUnitDims);
+ op, "The ShapeCast op only expands dimensions, the input layout "
+ "must be a slice of the result layout.");
+
+ assert(layoutToDistribute.isEqualTo(
+ layoutToDistribute.setUnitDimData(expandedUnitDims)) &&
+ "The sg_data for unit dimensions should be set as 1");
}
SmallVector<int64_t> sgShape =
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
index 28865c2ec50bb..068dd6d865ead 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
@@ -157,8 +157,8 @@ gpu.module @test_distribution {
%block_id_x = gpu.block_id x
%0 = xegpu.create_nd_tdesc %arg0 : memref<4096x128xf32> -> !xegpu.tensor_desc<256x128xf32, #xegpu.block_tdesc_attr<boundary_check = false>, #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>>
%1 = xegpu.load_nd %0[%block_id_x, 0] {layout = #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>} : !xegpu.tensor_desc<256x128xf32, #xegpu.block_tdesc_attr<boundary_check = false>, #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>> -> vector<256x128xf32>
- %2 = vector.multi_reduction <maximumf>, %1, %cst_0 {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>, dims = [1]>} [1] : vector<256x128xf32> to vector<256xf32>
- %3 = vector.shape_cast %2 {layout_result_0 = #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>, layout_operand_0 = #xegpu.slice<#xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>, dims = [1]>} : vector<256xf32> to vector<256x1xf32>
+ %2 = vector.multi_reduction <maximumf>, %1, %cst_0 {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [32, 1], sg_data = [8, 1], inst_data = [8, 1]>, dims = [1]>} [1] : vector<256x128xf32> to vector<256xf32>
+ %3 = vector.shape_cast %2 {layout_result_0 = #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 1], inst_data = [8, 1]>, layout_operand_0 = #xegpu.slice<#xegpu.layout<sg_layout = [32, 1], sg_data = [8, 1], inst_data = [8, 1]>, dims = [1]>} : vector<256xf32> to vector<256x1xf32>
%4 = vector.broadcast %3 {layout_result_0 = #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>} : vector<256x1xf32>to vector<256x128xf32>
%9 = xegpu.create_nd_tdesc %arg0 : memref<4096x128xf32> -> !xegpu.tensor_desc<256x128xf32, #xegpu.block_tdesc_attr<boundary_check = false>, #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>>
xegpu.store_nd %4, %9[%block_id_x, 0] : vector<256x128xf32>, !xegpu.tensor_desc<256x128xf32, #xegpu.block_tdesc_attr<boundary_check = false>, #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>>
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index 9a7ea4d1c2f43..9e0ae881c8a7e 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -948,22 +948,22 @@ gpu.module @test_distribution {
: !xegpu.tensor_desc<256x256xf32, #xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32]>>
-> vector<256x256xf32>
- %load2 = xegpu.convert_layout %load <{input_layout = #xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32]>, target_layout = #xegpu.slice<#xegpu.slice<#xegpu.slice<#xegpu.layout<sg_layout = [8, 1, 1, 8, 1, 1], sg_data = [32, 16, 16, 32, 16, 16]>, dims=[2]>, dims=[4]>, dims=[1, 3]>}> : vector<256x256xf32>
+ %load2 = xegpu.convert_layout %load <{input_layout = #xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32]>, target_layout = #xegpu.slice<#xegpu.slice<#xegpu.slice<#xegpu.layout<sg_layout = [8, 1, 1, 8, 1, 1], sg_data = [32, 1, 16, 32, 1, 1]>, dims=[2]>, dims=[4]>, dims=[1, 3]>}> : vector<256x256xf32>
- %scast = vector.shape_cast %load2 {layout_result_0 = #xegpu.slice<#xegpu.slice<#xegpu.layout<sg_layout = [8, 1, 1, 8, 1, 1], sg_data = [32, 16, 16, 32, 16, 16]>, dims=[2]>, dims=[4]>, layout_operand_0 = #xegpu.slice<#xegpu.slice<#xegpu.slice<#xegpu.layout<sg_layout = [8, 1, 1, 8, 1, 1], sg_data = [32, 16, 16, 32, 16, 16]>, dims=[2]>, dims=[4]>, dims=[1, 3]>} : vector<256x256xf32> to vector<256x1x256x1xf32>
+ %scast = vector.shape_cast %load2 {layout_result_0 = #xegpu.slice<#xegpu.slice<#xegpu.layout<sg_layout = [8, 1, 1, 8, 1, 1], sg_data = [32, 1, 16, 32, 1, 1]>, dims=[2]>, dims=[4]>, layout_operand_0 = #xegpu.slice<#xegpu.slice<#xegpu.slice<#xegpu.layout<sg_layout = [8, 1, 1, 8, 1, 1], sg_data = [32, 1, 16, 32, 1, 1]>, dims=[2]>, dims=[4]>, dims=[1, 3]>} : vector<256x256xf32> to vector<256x1x256x1xf32>
- %bcast = vector.broadcast %scast {layout_result_0 = #xegpu.slice<#xegpu.slice<#xegpu.layout<sg_layout = [8, 1, 1, 8, 1, 1], sg_data = [32, 16, 16, 32, 16, 16]>, dims=[2]>, dims=[4]>, layout_operand_0 = #xegpu.slice<#xegpu.slice<#xegpu.layout<sg_layout = [8, 1, 1, 8, 1, 1], sg_data = [32, 16, 16, 32, 16, 16]>, dims=[2]>, dims=[4]>} : vector<256x1x256x1xf32> to vector<256x16x256x16xf32>
+ %bcast = vector.broadcast %scast {layout_result_0 = #xegpu.slice<#xegpu.slice<#xegpu.layout<sg_layout = [8, 1, 1, 8, 1, 1], sg_data = [32, 16, 16, 32, 16, 1]>, dims=[2]>, dims=[4]>, layout_operand_0 = #xegpu.slice<#xegpu.slice<#xegpu.layout<sg_layout = [8, 1, 1, 8, 1, 1], sg_data = [32, 1, 16, 32, 1, 1]>, dims=[2]>, dims=[4]>} : vector<256x1x256x1xf32> to vector<256x16x256x16xf32>
- %scast1 = vector.shape_cast %bcast {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 1, 1, 8, 1, 1], sg_data = [32, 16, 16, 32, 16, 16]>, dims=[2]>, layout_operand_0 = #xegpu.slice<#xegpu.slice<#xegpu.layout<sg_layout = [8, 1, 1, 8, 1, 1], sg_data = [32, 16, 16, 32, 16, 16]>, dims=[2]>, dims=[4]>} : vector<256x16x256x16xf32> to vector<256x16x256x16x1xf32>
+ %scast1 = vector.shape_cast %bcast {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 1, 1, 8, 1, 1], sg_data = [32, 16, 16, 32, 16, 1]>, dims=[2]>, layout_operand_0 = #xegpu.slice<#xegpu.slice<#xegpu.layout<sg_layout = [8, 1, 1, 8, 1, 1], sg_data = [32, 16, 16, 32, 16, 1]>, dims=[2]>, dims=[4]>} : vector<256x16x256x16xf32> to vector<256x16x256x16x1xf32>
- %bcast1 = vector.broadcast %scast1 {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 1, 1, 8, 1, 1], sg_data = [32, 16, 16, 32, 16, 16]>, dims=[2]>, layout_operand_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 1, 1, 8, 1, 1], sg_data = [32, 16, 16, 32, 16, 16]>, dims=[2]>} : vector<256x16x256x16x1xf32> to vector<256x16x256x16x16xf32>
+ %bcast1 = vector.broadcast %scast1 {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 1, 1, 8, 1, 1], sg_data = [32, 16, 16, 32, 16, 16]>, dims=[2]>, layout_operand_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 1, 1, 8, 1, 1], sg_data = [32, 16, 16, 32, 16, 1]>, dims=[2]>} : vector<256x16x256x16x1xf32> to vector<256x16x256x16x16xf32>
%scast2 = vector.shape_cast %bcast1 {layout_result_0 =
- #xegpu.layout<sg_layout = [8, 1, 1, 8, 1, 1], sg_data = [32, 16, 16, 32, 16, 16]>, layout_operand_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 1, 1, 8, 1, 1], sg_data = [32, 16, 16, 32, 16, 16]>, dims=[2]>} : vector<256x16x256x16x16xf32> to vector<256x16x1x256x16x16xf32>
+ #xegpu.layout<sg_layout = [8, 1, 1, 8, 1, 1], sg_data = [32, 16, 1, 32, 16, 16]>, layout_operand_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 1, 1, 8, 1, 1], sg_data = [32, 16, 1, 32, 16, 16]>, dims=[2]>} : vector<256x16x256x16x16xf32> to vector<256x16x1x256x16x16xf32>
%bcast2 = vector.broadcast %scast2 {layout_result_0 =
#xegpu.layout<sg_layout = [8, 1, 1, 8, 1, 1], sg_data = [32, 16, 16, 32, 16, 16]>, layout_operand_0 =
- #xegpu.layout<sg_layout = [8, 1, 1, 8, 1, 1], sg_data = [32, 16, 16, 32, 16, 16]>} : vector<256x16x1x256x16x16xf32> to vector<256x16x16x256x16x16xf32>
+ #xegpu.layout<sg_layout = [8, 1, 1, 8, 1, 1], sg_data = [32, 16, 1, 32, 16, 16]>} : vector<256x16x1x256x16x16xf32> to vector<256x16x16x256x16x16xf32>
gpu.return
}
>From 55daece69d3519cb095d8db2605a38ef3c61e389 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Tue, 10 Mar 2026 06:07:59 +0000
Subject: [PATCH 4/4] add tests
---
.../XeGPU/propagate-layout-inst-data.mlir | 21 +++++++++++++++
.../XeGPU/propagate-layout-subgroup.mlir | 21 +++++++++++++++
mlir/test/Dialect/XeGPU/propagate-layout.mlir | 27 +++++++++++++++++++
3 files changed, 69 insertions(+)
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir b/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
index c510a1d5f0fdf..503fb25deb151 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
@@ -334,3 +334,24 @@ gpu.module @test{
return
}
}
+
+// -----
+gpu.module @test{
+ // CHECK-LABEL: broadcast_both_leadingdims_innerdims
+ // CHECK: arith.constant {layout_result_0 = #xegpu.layout<inst_data = [1, 1, 1, 16]>} dense<true> : vector<2x2x6x32xi1>
+ // CHECK: arith.constant {layout_result_0 = #xegpu.layout<inst_data = [1, 1, 1, 16]>} dense<1.000000e+00> : vector<2x2x6x32xf32>
+ // CHECK: vector.step {layout_result_0 = #xegpu.slice<#xegpu.slice<#xegpu.layout<inst_data = [1, 1, 1, 1]>, dims = [0, 1]>, dims = [1]>} : vector<6xindex>
+ // CHECK: vector.shape_cast {{.*}} {layout_result_0 = #xegpu.slice<#xegpu.layout<inst_data = [1, 1, 1, 1]>, dims = [0, 1]>} : vector<6xindex> to vector<6x1xindex>
+ // CHECK: vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<inst_data = [1, 1, 1, 16]>} : vector<6x1xindex> to vector<2x2x6x32xindex>
+ gpu.func @broadcast_both_leadingdims_innerdims(%arg0: memref<32x2x192xf32>, %arg1: memref<32x2x192xf32>, %arg2: memref<32x2x192xf32>) kernel attributes {known_block_size = array<i32: 768, 1, 1>, known_grid_size = array<i32: 16, 1, 1>} {
+ %cst = arith.constant dense<true> : vector<2x2x6x32xi1>
+ %cst_0 = arith.constant dense<1.000000e+00> : vector<2x2x6x32xf32>
+ %intptr = memref.extract_aligned_pointer_as_index %arg2 : memref<32x2x192xf32> -> index
+ %0 = arith.index_cast %intptr : index to i64
+ %1 = vector.step : vector<6xindex>
+ %2 = vector.shape_cast %1 : vector<6xindex> to vector<6x1xindex>
+ %3 = vector.broadcast %2 : vector<6x1xindex> to vector<2x2x6x32xindex>
+ xegpu.store %cst_0, %0[%3], %cst <{layout = #xegpu.layout<inst_data = [1, 1, 1, 16]>}> : vector<2x2x6x32xf32>, i64, vector<2x2x6x32xindex>, vector<2x2x6x32xi1>
+ gpu.return
+ }
+}
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir b/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir
index 39fd815b1b380..9ee3de4490727 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir
@@ -165,6 +165,27 @@ gpu.module @test {
}
}
+// -----
+gpu.module @test {
+ // CHECK-LABEL: broadcast_both_leadingdims_innerdims
+ gpu.func @broadcast_both_leadingdims_innerdims(%arg0: memref<32x2x192xf32>, %arg1: memref<32x2x192xf32>, %arg2: memref<32x2x192xf32>) kernel attributes {known_block_size = array<i32: 768, 1, 1>, known_grid_size = array<i32: 16, 1, 1>} {
+ // CHECK: arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [2, 2, 6, 1], sg_data = [1, 1, 1, 32]>} dense<true> : vector<2x2x6x32xi1>
+ %cst = arith.constant dense<true> : vector<2x2x6x32xi1>
+ // CHECK: arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [2, 2, 6, 1], sg_data = [1, 1, 1, 32]>} dense<1.000000e+00> : vector<2x2x6x32xf32>
+ %cst_0 = arith.constant dense<1.000000e+00> : vector<2x2x6x32xf32>
+ %intptr = memref.extract_aligned_pointer_as_index %arg2 : memref<32x2x192xf32> -> index
+ %0 = arith.index_cast %intptr : index to i64
+ // CHECK: vector.step {layout_result_0 = #xegpu.slice<#xegpu.slice<#xegpu.layout<sg_layout = [2, 2, 6, 1], sg_data = [1, 1, 1, 1]>, dims = [0, 1]>, dims = [1]>} : vector<6xindex>
+ %1 = vector.step : vector<6xindex>
+ // CHECK: vector.shape_cast {{.*}} {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [2, 2, 6, 1], sg_data = [1, 1, 1, 1]>, dims = [0, 1]>} : vector<6xindex> to vector<6x1xindex>
+ %2 = vector.shape_cast %1 : vector<6xindex> to vector<6x1xindex>
+ // CHECK: vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<sg_layout = [2, 2, 6, 1], sg_data = [1, 1, 1, 32]>} : vector<6x1xindex> to vector<2x2x6x32xindex>
+ %3 = vector.broadcast %2 : vector<6x1xindex> to vector<2x2x6x32xindex>
+ xegpu.store %cst_0, %0[%3], %cst <{layout = #xegpu.layout<sg_layout = [2, 2, 6, 1], sg_data = [1, 1, 1, 32]>}> : vector<2x2x6x32xf32>, i64, vector<2x2x6x32xindex>, vector<2x2x6x32xi1>
+ gpu.return
+ }
+}
+
// -----
gpu.module @test {
// CHECK-LABEL: for_loop_dpas
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout.mlir b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
index 43c01445876cb..ddd2d22108d1f 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
@@ -638,6 +638,33 @@ func.func @vector_shape_cast_1d_to_2d_dim0_broadcasted(%arg0: !xegpu.tensor_desc
}
}
// -----
+gpu.module @test{
+ // CHECK-LABEL: broadcast_both_leadingdims_innerdims
+ // CHECK: arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [1, 1, 1, 16], lane_data = [1, 1, 1, 1]>} dense<1.000000e+00> : vector<1x1x1x16xf32>
+ // CHECK: arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [1, 1, 1, 16], lane_data = [1, 1, 1, 1]>} dense<true> : vector<1x1x1x16xi1>
+ // CHECK: vector.step {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 1, 1, 16], lane_data = [1, 1, 1, 1]>, dims = [0, 1, 2]>} : vector<1xindex>
+ // CHECK: vector.broadcast {{.*}} {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 1, 1, 16], lane_data = [1, 1, 1, 1]>, dims = [0, 1, 2]>} : index to vector<1xindex>
+ // CHECK: arith.addi {{.*}} {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 1, 1, 16], lane_data = [1, 1, 1, 1]>, dims = [0, 1, 2]>} : vector<1xindex>
+ // CHECK: vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 1, 1, 16], lane_data = [1, 1, 1, 1]>} : vector<1xindex> to vector<1x1x1x16xindex>
+ gpu.func @broadcast_both_leadingdims_innerdims(%arg0: memref<32x2x192xf32>, %arg1: memref<32x2x192xf32>, %arg2: memref<32x2x192xf32>) kernel attributes {known_block_size = array<i32: 768, 1, 1>, known_grid_size = array<i32: 16, 1, 1>} {
+ %cst = arith.constant dense<1.000000e+00> : vector<1x1x1x16xf32>
+ %cst_0 = arith.constant dense<true> : vector<1x1x1x16xi1>
+ %c6 = arith.constant 6 : index
+ %intptr = memref.extract_aligned_pointer_as_index %arg2 : memref<32x2x192xf32> -> index
+ %0 = arith.index_cast %intptr : index to i64
+ %1 = gpu.subgroup_id : index
+ %2 = arith.remui %1, %c6 : index
+ %3 = arith.remui %2, %c6 : index
+ %4 = vector.step : vector<1xindex>
+ %5 = vector.broadcast %3 : index to vector<1xindex>
+ %6 = arith.addi %4, %5 : vector<1xindex>
+ %7 = vector.broadcast %6 : vector<1xindex> to vector<1x1x1x16xindex>
+ xegpu.store %cst, %0[%7], %cst_0 <{chunk_size = 1 : i64}> : vector<1x1x1x16xf32>, i64, vector<1x1x1x16xindex>, vector<1x1x1x16xi1>
+ xegpu.store %cst, %0[%7], %cst_0 <{chunk_size = 1 : i64}> : vector<1x1x1x16xf32>, i64, vector<1x1x1x16xindex>, vector<1x1x1x16xi1>
+ gpu.return
+ }
+}
+// -----
gpu.module @test {
// CHECK-LABEL: func.func @vector_shape_cast_expand_non_unit_dims(
// CHECK: %[[LOAD:.*]] = xegpu.load %arg0[%[[STEP:.*]]], %[[CST:.*]] <{layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>}> : memref<1024xf16>, vector<1024xindex>, vector<1024xi1> -> vector<1024xf16>
More information about the Mlir-commits
mailing list