[Mlir-commits] [mlir] [MLIR][XeGPU] Add handling for unit-dim expansion in ShapeCast workgroup-to-subgroup distribution (PR #171758)
Jianhui Li
llvmlistbot at llvm.org
Tue Dec 16 09:59:07 PST 2025
https://github.com/Jianhui-Li updated https://github.com/llvm/llvm-project/pull/171758
>From a8f6c51683164884bd744d6a6b978572d821a235 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Thu, 11 Dec 2025 03:13:43 +0000
Subject: [PATCH 1/2] adjust the layout for expandedUnitDims and wg-to-sg
distribution shapecast op
---
.../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 4 -
.../Transforms/XeGPUSubgroupDistribute.cpp | 2 +-
.../Transforms/XeGPUWgToSgDistribute.cpp | 75 ++++++++++++-------
mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir | 17 +++++
4 files changed, 64 insertions(+), 34 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index dc9eb96c169b4..12d1c494a0b72 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -609,10 +609,6 @@ void LayoutInfoPropagation::visitVectorBroadCastOp(
propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(sliceLayout)));
return;
}
-
- SetVector<int64_t> broadcastUnitDims = broadcast.computeBroadcastedUnitDims();
- resultLayout = cast<xegpu::DistributeLayoutAttr>(resultLayout.get())
- .setUnitDimData(broadcastUnitDims);
propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
}
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index ca81c3cd7be42..27273ee245cf2 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -1530,12 +1530,12 @@ struct VectorBroadcastDistribution : public gpu::WarpDistributionPattern {
if (rankDiff == 0) {
SetVector<int64_t> broadcastUnitDims =
broadcastOp.computeBroadcastedUnitDims();
- resultLayout = resultLayout.setUnitDimData(broadcastUnitDims);
bool isEqualTo = sourceLayout.isEqualTo(resultLayout);
if (!isEqualTo)
return rewriter.notifyMatchFailure(
warpOp, "For same-rank broadcast, source must be identical to "
"adjusted result layouts with unit dims.");
+ resultLayout = resultLayout.setUnitDimData(broadcastUnitDims);
sourceLayout = sourceLayout.setUnitDimLayout(broadcastUnitDims);
}
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index be82cda574f1e..ef5da57c5f3b4 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -1111,41 +1111,58 @@ struct WgToSgVectorShapeCastOp
if (!layout || !layout.isForWorkgroup())
return failure();
- SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
- VectorType newResultType =
- VectorType::get(sgShape, resultType.getElementType());
-
- // TODO: Add check for compatible layouts in layout attr.
- auto srcType = dyn_cast<VectorType>(adaptor.getSource()[0].getType());
+ // Check that srcShape and destShape, if they differ, differ only by the
+ // expansion of unit dimensions.
+ auto srcType = dyn_cast<VectorType>(op.getSource().getType());
if (!srcType)
return failure();
- // Check that shape_cast only adds/removes unit dimensions,
- auto onlyUnitDims = [](ArrayRef<int64_t> src, ArrayRef<int64_t> dst) {
- // Remove all 1s from both shapes and compare the rest.
- SmallVector<int64_t> srcNonUnit, dstNonUnit;
- for (int64_t d : src)
- if (d != 1)
- srcNonUnit.push_back(d);
- for (int64_t d : dst)
- if (d != 1)
- dstNonUnit.push_back(d);
- return srcNonUnit == dstNonUnit;
+ ArrayRef<int64_t> srcShape = srcType.getShape();
+ llvm::SetVector<int64_t> expandedUnitDims;
+
+ // Check if shapes only differ by expanding unit dimensions (like
+ // expand_dims)
+ auto checkOnlyExpandUnitDims = [&](ArrayRef<int64_t> src,
+ ArrayRef<int64_t> dst) -> bool {
+ // All unit dimensions in dst that don't appear in src are the expanded
+ // unit dimensions
+ size_t srcIdx = 0;
+ for (size_t dstIdx = 0; dstIdx < dst.size(); ++dstIdx)
+ if (srcIdx < src.size() && src[srcIdx] == dst[dstIdx])
+ srcIdx++;
+ else if (dst[dstIdx] == 1)
+ expandedUnitDims.insert(dstIdx);
+ else
+ return false;
+ return srcIdx == src.size();
};
- if (!onlyUnitDims(srcType.getShape(), sgShape))
- return failure();
+ if (checkOnlyExpandUnitDims(srcShape, wgShape)) {
+ xegpu::DistributeLayoutAttr sourceLayout =
+ xegpu::getDistributeLayoutAttr(op.getSource());
- // For rank reducing or increasing shape_cast ops, the lower rank layout
- // must be a slice of higher rank layout.
- int64_t sourceRank = srcType.getRank();
- int64_t resultRank = sgShape.size();
- xegpu::DistributeLayoutAttr sourceLayout =
- xegpu::getDistributeLayoutAttr(op.getSource());
- if (sourceRank < resultRank && !sourceLayout.isSliceOf(layout))
- return failure();
- if (sourceRank > resultRank && !layout.isSliceOf(sourceLayout))
- return failure();
+ auto usedByBroadcastOp = [](vector::ShapeCastOp op) {
+ return llvm::all_of(op.getResult().getUsers(), [](Operation *user) {
+ return isa<vector::BroadcastOp>(user);
+ });
+ };
+
+ if (!usedByBroadcastOp(op)) {
+ return rewriter.notifyMatchFailure(
+ op, "ShapeCast ops that expand unit dimensions and are used by "
+ "non-broadcast operations are not supported.");
+ }
+ if (!sourceLayout.isSliceOf(layout))
+ return rewriter.notifyMatchFailure(
+ op, "The ShapeCast op only expands dimensions, the result layout "
+ "must be a slice of the input layout, or vice versa.");
+ layout = layout.setUnitDimData(expandedUnitDims);
+ layout = layout.setUnitDimLayout(expandedUnitDims);
+ }
+
+ SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
+ VectorType newResultType =
+ VectorType::get(sgShape, resultType.getElementType());
SmallVector<Value> newShapeCastOps;
for (auto src : adaptor.getSource()) {
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
index a8015cced7eb4..7f651ef5fdc14 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
@@ -380,4 +380,21 @@ gpu.module @test_1_1_assignment {
%cst = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>} dense<1.0> : vector<256x128xf32>
gpu.return
}
+
+ // CHECK-LABEL: distribute_shapecast_expandunitdims_broadcast
+ // CHECK: %[[CAST:.*]] = vector.shape_cast %[[REDUCE:.*]] {layout_result_0 = #xegpu.layout<inst_data = [8, 1]>} : vector<8xf32> to vector<8x1xf32>
+ // CHECK: %[[BCAST:.*]] = vector.broadcast %[[CAST]] {layout_result_0 = #xegpu.layout<inst_data = [8, 16]>} : vector<8x1xf32> to vector<8x128xf32>
+ gpu.func @distribute_shapecast_expandunitdims_broadcast(%arg0: memref<4096x128xf32>, %arg1: memref<4096x128xf32>) {
+ %cst_0 = arith.constant {layout_result_0=#xegpu.slice<#xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>, dims = [1]>} dense<0xFF800000> : vector<256xf32>
+ %block_id_x = gpu.block_id x
+ %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x128xf32> -> !xegpu.tensor_desc<256x128xf32, #xegpu.block_tdesc_attr<boundary_check = false>, #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>>
+ %1 = xegpu.load_nd %0[%block_id_x, 0] : !xegpu.tensor_desc<256x128xf32, #xegpu.block_tdesc_attr<boundary_check = false>, #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>> -> vector<256x128xf32>
+ %2 = vector.multi_reduction <maximumf>, %1, %cst_0 {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>, dims = [1]>} [1] : vector<256x128xf32> to vector<256xf32>
+ %3 = vector.shape_cast %2 {layout_result_0 = #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>} : vector<256xf32> to vector<256x1xf32>
+ %4 = vector.broadcast %3 {layout_result_0 = #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>} : vector<256x1xf32>to vector<256x128xf32>
+ %9 = xegpu.create_nd_tdesc %arg0 : memref<4096x128xf32> -> !xegpu.tensor_desc<256x128xf32, #xegpu.block_tdesc_attr<boundary_check = false>, #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>>
+ xegpu.store_nd %4, %9[%block_id_x, 0] : vector<256x128xf32>, !xegpu.tensor_desc<256x128xf32, #xegpu.block_tdesc_attr<boundary_check = false>, #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>>
+ gpu.return
+ }
}
+
>From 49d66980f904638c83aa47b3637900a851d8d3dc Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Tue, 16 Dec 2025 17:58:07 +0000
Subject: [PATCH 2/2] address feedback
---
.../XeGPU/Transforms/XeGPUWgToSgDistribute.cpp | 14 +++++++++-----
.../XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir | 17 +++++++++++++++++
mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir | 17 -----------------
3 files changed, 26 insertions(+), 22 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index ef5da57c5f3b4..95e27e46d90ab 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -1137,6 +1137,8 @@ struct WgToSgVectorShapeCastOp
return srcIdx == src.size();
};
+ xegpu::DistributeLayoutAttr layoutToDistribute = layout;
+
if (checkOnlyExpandUnitDims(srcShape, wgShape)) {
xegpu::DistributeLayoutAttr sourceLayout =
xegpu::getDistributeLayoutAttr(op.getSource());
@@ -1147,20 +1149,22 @@ struct WgToSgVectorShapeCastOp
});
};
- if (!usedByBroadcastOp(op)) {
+ if (!usedByBroadcastOp(op))
return rewriter.notifyMatchFailure(
op, "ShapeCast ops that expand unit dimensions and are used by "
"non-broadcast operations are not supported.");
- }
+
if (!sourceLayout.isSliceOf(layout))
return rewriter.notifyMatchFailure(
op, "The ShapeCast op only expands dimensions, the result layout "
"must be a slice of the input layout, or vice versa.");
- layout = layout.setUnitDimData(expandedUnitDims);
- layout = layout.setUnitDimLayout(expandedUnitDims);
+ layoutToDistribute = layoutToDistribute.setUnitDimData(expandedUnitDims);
+ layoutToDistribute =
+ layoutToDistribute.setUnitDimLayout(expandedUnitDims);
}
- SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
+ SmallVector<int64_t> sgShape =
+ getSgShapeAndCount(wgShape, layoutToDistribute).first;
VectorType newResultType =
VectorType::get(sgShape, resultType.getElementType());
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
index c95c64084f3f8..37a76f316e75a 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
@@ -143,4 +143,21 @@ gpu.module @test_distribution {
%constant_mask = vector.create_mask %cst16, %cst16 {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16]>} : vector<256x128xi1>
gpu.return
}
+
+ // CHECK-LABEL: distribute_shapecast_expandunitdims_broadcast
+ // CHECK: %[[CAST:.*]] = vector.shape_cast %[[REDUCE:.*]] {layout_result_0 = #xegpu.layout<inst_data = [8, 16]>} : vector<8xf32> to vector<8x1xf32>
+ // CHECK: %[[BCAST:.*]] = vector.broadcast %[[CAST]] {layout_result_0 = #xegpu.layout<inst_data = [8, 16]>} : vector<8x1xf32> to vector<8x128xf32>
+ gpu.func @distribute_shapecast_expandunitdims_broadcast(%arg0: memref<4096x128xf32>, %arg1: memref<4096x128xf32>) {
+ %cst_0 = arith.constant {layout_result_0=#xegpu.slice<#xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>, dims = [1]>} dense<0xFF800000> : vector<256xf32>
+ %block_id_x = gpu.block_id x
+ %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x128xf32> -> !xegpu.tensor_desc<256x128xf32, #xegpu.block_tdesc_attr<boundary_check = false>, #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>>
+ %1 = xegpu.load_nd %0[%block_id_x, 0] : !xegpu.tensor_desc<256x128xf32, #xegpu.block_tdesc_attr<boundary_check = false>, #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>> -> vector<256x128xf32>
+ %2 = vector.multi_reduction <maximumf>, %1, %cst_0 {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>, dims = [1]>} [1] : vector<256x128xf32> to vector<256xf32>
+ %3 = vector.shape_cast %2 {layout_result_0 = #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>} : vector<256xf32> to vector<256x1xf32>
+ %4 = vector.broadcast %3 {layout_result_0 = #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>} : vector<256x1xf32>to vector<256x128xf32>
+ %9 = xegpu.create_nd_tdesc %arg0 : memref<4096x128xf32> -> !xegpu.tensor_desc<256x128xf32, #xegpu.block_tdesc_attr<boundary_check = false>, #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>>
+ xegpu.store_nd %4, %9[%block_id_x, 0] : vector<256x128xf32>, !xegpu.tensor_desc<256x128xf32, #xegpu.block_tdesc_attr<boundary_check = false>, #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>>
+ gpu.return
+ }
+
}
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
index 7f651ef5fdc14..a8015cced7eb4 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
@@ -380,21 +380,4 @@ gpu.module @test_1_1_assignment {
%cst = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>} dense<1.0> : vector<256x128xf32>
gpu.return
}
-
- // CHECK-LABEL: distribute_shapecast_expandunitdims_broadcast
- // CHECK: %[[CAST:.*]] = vector.shape_cast %[[REDUCE:.*]] {layout_result_0 = #xegpu.layout<inst_data = [8, 1]>} : vector<8xf32> to vector<8x1xf32>
- // CHECK: %[[BCAST:.*]] = vector.broadcast %[[CAST]] {layout_result_0 = #xegpu.layout<inst_data = [8, 16]>} : vector<8x1xf32> to vector<8x128xf32>
- gpu.func @distribute_shapecast_expandunitdims_broadcast(%arg0: memref<4096x128xf32>, %arg1: memref<4096x128xf32>) {
- %cst_0 = arith.constant {layout_result_0=#xegpu.slice<#xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>, dims = [1]>} dense<0xFF800000> : vector<256xf32>
- %block_id_x = gpu.block_id x
- %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x128xf32> -> !xegpu.tensor_desc<256x128xf32, #xegpu.block_tdesc_attr<boundary_check = false>, #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>>
- %1 = xegpu.load_nd %0[%block_id_x, 0] : !xegpu.tensor_desc<256x128xf32, #xegpu.block_tdesc_attr<boundary_check = false>, #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>> -> vector<256x128xf32>
- %2 = vector.multi_reduction <maximumf>, %1, %cst_0 {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>, dims = [1]>} [1] : vector<256x128xf32> to vector<256xf32>
- %3 = vector.shape_cast %2 {layout_result_0 = #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>} : vector<256xf32> to vector<256x1xf32>
- %4 = vector.broadcast %3 {layout_result_0 = #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>} : vector<256x1xf32>to vector<256x128xf32>
- %9 = xegpu.create_nd_tdesc %arg0 : memref<4096x128xf32> -> !xegpu.tensor_desc<256x128xf32, #xegpu.block_tdesc_attr<boundary_check = false>, #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>>
- xegpu.store_nd %4, %9[%block_id_x, 0] : vector<256x128xf32>, !xegpu.tensor_desc<256x128xf32, #xegpu.block_tdesc_attr<boundary_check = false>, #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>>
- gpu.return
- }
}
-
More information about the Mlir-commits
mailing list