[Mlir-commits] [mlir] [MLIR][XeGPU] Support round-robin layout for constant and broadcast in wg-to-sg distribution (PR #189798)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Mar 31 22:38:19 PDT 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-gpu
Author: Jianhui Li (Jianhui-Li)
<details>
<summary>Changes</summary>
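Support round-robin layout for constant and broadcast in wg-to-sg distribution.
With a round-robin layout, each subgroup can own several distributed pieces of a workgroup-level value. In that case:
- The splat `arith.constant` pattern now creates one subgroup-level constant per distribution (`count` of them) instead of a single one.
- The `vector.broadcast` pattern now repeats its distributed source operands `count / distSource.size()` times, so it emits one broadcast per result distribution.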
---
Full diff: https://github.com/llvm/llvm-project/pull/189798.diff
3 Files Affected:
- (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp (+18-8)
- (modified) mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir (+1-1)
- (modified) mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir (+26)
``````````diff
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 0aead9172858f..d04933423ecd0 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -492,7 +492,9 @@ struct WgToSgVectorBroadcastOp
if (!layout || !layout.isForWorkgroup())
return failure();
- SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
+ SmallVector<int64_t> sgShape;
+ int count;
+ std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
VectorType newResultType =
VectorType::get(sgShape, resultType.getElementType());
@@ -500,11 +502,15 @@ struct WgToSgVectorBroadcastOp
return failure();
SmallVector<Value> newBroadcastOps;
- for (auto operand : adaptor.getOperands().front()) {
- auto newBroadcast = vector::BroadcastOp::create(rewriter, op.getLoc(),
- newResultType, operand);
-
- newBroadcastOps.push_back(newBroadcast.getResult());
+ auto distSource = adaptor.getOperands().front();
+ int numDistributions = count / distSource.size();
+ for (int i = 0; i < numDistributions; ++i) {
+ for (auto operand : distSource) {
+ auto newBroadcast = vector::BroadcastOp::create(rewriter, op.getLoc(),
+ newResultType, operand);
+
+ newBroadcastOps.push_back(newBroadcast.getResult());
+ }
}
rewriter.replaceOpWithMultiple(op, {newBroadcastOps});
return success();
@@ -816,8 +822,12 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
// Splat: single value for all subgroups
Attribute singleVal = vecAttr.getSplatValue<Attribute>();
auto sgAttr = DenseElementsAttr::get(newType, singleVal);
- auto cstOp = arith::ConstantOp::create(rewriter, loc, newType, sgAttr);
- rewriter.replaceOp(op, cstOp);
+ SmallVector<Value> newConstOps;
+ for (int i = 0; i < count; ++i) {
+ auto cstOp = arith::ConstantOp::create(rewriter, loc, newType, sgAttr);
+ newConstOps.push_back(cstOp);
+ }
+ rewriter.replaceOpWithMultiple(op, {newConstOps});
return success();
} else if (sgShape == wgShape) { // if the entire vector is shared by all
// subgroups, don't distribute
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
index e89cb52ee02f5..e4bf3b6c3bf1d 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
@@ -116,7 +116,7 @@ gpu.module @test_round_robin_assignment {
%load = xegpu.load_nd %tdesc {layout = #xegpu.layout<sg_layout = [4, 1], sg_data = [16, 1], lane_layout = [8, 1], lane_data = [1, 1]>}
: !xegpu.tensor_desc<128x1xf32, #xegpu.layout<sg_layout = [4, 1], sg_data = [16, 1], lane_layout = [8, 1], lane_data = [1, 1]>>
-> vector<128x1xf32>
- // CHECK-COUNT-2: vector.broadcast {{.*}} : vector<16x1xf32> to vector<16x32xf32>
+ // CHECK-COUNT-4: vector.broadcast {{.*}} : vector<16x1xf32> to vector<16x32xf32>
// CHECK-NOT: vector.broadcast
%broadcast = vector.broadcast %load
{layout_result_0 = #xegpu.layout<sg_layout = [4, 1], sg_data = [16, 32], lane_layout = [8, 1], lane_data = [1, 1]>}
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
index 068dd6d865ead..320a2fb1f72ac 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
@@ -165,4 +165,30 @@ gpu.module @test_distribution {
gpu.return
}
+ // CHECK-LABEL: splat_constant
+ gpu.func @splat_constant() {
+ // CHECK-COUNT-2: %[[CST:.*]] = arith.constant dense<0> : vector<4xindex>
+ %cst_2 = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [16, 1], sg_data = [16, 4], order = [0, 1]>, dims = [0]>} dense<0> : vector<8xindex>
+ gpu.return
+ }
+
+ // CHECK-LABEL: gpu.func @step_broadcast
+ gpu.func @step_broadcast() {
+ // CHECK: %[[SGID:.*]] = gpu.subgroup_id : index
+ // CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index
+ // CHECK: %[[REM:.*]] = arith.remui %[[SGID]], %[[C16]] : index
+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+ // CHECK: %[[STEP:.*]] = vector.step : vector<4xindex>
+ // CHECK: %[[BCST0:.*]] = vector.broadcast %[[C0:.*]] : index to vector<4xindex>
+ // CHECK: %[[ADD0:.*]] = arith.addi %[[STEP]], %[[BCST0]] : vector<4xindex>
+ // CHECK: %[[BCST4:.*]] = vector.broadcast %[[C4:.*]] : index to vector<4xindex>
+ // CHECK: %[[ADD4:.*]] = arith.addi %[[STEP]], %[[BCST4]] : vector<4xindex>
+ // CHECK: %[[RES0:.*]] = vector.broadcast %[[ADD0]] : vector<4xindex> to vector<16x4xindex>
+ // CHECK: %[[RES1:.*]] = vector.broadcast %[[ADD4]] : vector<4xindex> to vector<16x4xindex>
+ %2 = vector.step {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [16, 1], sg_data = [16, 4]>, dims = [0]>} : vector<8xindex>
+ %bcast = vector.broadcast %2 {layout_result_0 = #xegpu.layout<sg_layout = [16, 1], sg_data = [16, 4]>} : vector<8xindex> to vector<256x8xindex>
+ gpu.return
+ }
+
}
``````````
</details>
https://github.com/llvm/llvm-project/pull/189798
More information about the Mlir-commits mailing list