[Mlir-commits] [mlir] [MLIR][XeGPU] Support round-robin layout for constant and broadcast in wg-to-sg distribution (PR #189798)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Mar 31 22:38:19 PDT 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-gpu

Author: Jianhui Li (Jianhui-Li)

<details>
<summary>Changes</summary>

Support round-robin layout for constant and broadcast in wg-to-sg distribution.

---
Full diff: https://github.com/llvm/llvm-project/pull/189798.diff


3 Files Affected:

- (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp (+18-8) 
- (modified) mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir (+1-1) 
- (modified) mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir (+26) 


``````````diff
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 0aead9172858f..d04933423ecd0 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -492,7 +492,9 @@ struct WgToSgVectorBroadcastOp
     if (!layout || !layout.isForWorkgroup())
       return failure();
 
-    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
+    SmallVector<int64_t> sgShape;
+    int count;
+    std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
     VectorType newResultType =
         VectorType::get(sgShape, resultType.getElementType());
 
@@ -500,11 +502,15 @@ struct WgToSgVectorBroadcastOp
       return failure();
 
     SmallVector<Value> newBroadcastOps;
-    for (auto operand : adaptor.getOperands().front()) {
-      auto newBroadcast = vector::BroadcastOp::create(rewriter, op.getLoc(),
-                                                      newResultType, operand);
-
-      newBroadcastOps.push_back(newBroadcast.getResult());
+    auto distSource = adaptor.getOperands().front();
+    int numDistributions = count / distSource.size();
+    for (int i = 0; i < numDistributions; ++i) {
+      for (auto operand : distSource) {
+        auto newBroadcast = vector::BroadcastOp::create(rewriter, op.getLoc(),
+                                                        newResultType, operand);
+
+        newBroadcastOps.push_back(newBroadcast.getResult());
+      }
     }
     rewriter.replaceOpWithMultiple(op, {newBroadcastOps});
     return success();
@@ -816,8 +822,12 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
       // Splat: single value for all subgroups
       Attribute singleVal = vecAttr.getSplatValue<Attribute>();
       auto sgAttr = DenseElementsAttr::get(newType, singleVal);
-      auto cstOp = arith::ConstantOp::create(rewriter, loc, newType, sgAttr);
-      rewriter.replaceOp(op, cstOp);
+      SmallVector<Value> newConstOps;
+      for (int i = 0; i < count; ++i) {
+        auto cstOp = arith::ConstantOp::create(rewriter, loc, newType, sgAttr);
+        newConstOps.push_back(cstOp);
+      }
+      rewriter.replaceOpWithMultiple(op, {newConstOps});
       return success();
     } else if (sgShape == wgShape) { // if the entire vector is shared by all
                                      // subgroups, don't distribute
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
index e89cb52ee02f5..e4bf3b6c3bf1d 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
@@ -116,7 +116,7 @@ gpu.module @test_round_robin_assignment {
     %load =  xegpu.load_nd %tdesc {layout =  #xegpu.layout<sg_layout = [4, 1], sg_data = [16, 1], lane_layout = [8, 1], lane_data = [1, 1]>}
       : !xegpu.tensor_desc<128x1xf32, #xegpu.layout<sg_layout = [4, 1], sg_data = [16, 1], lane_layout = [8, 1], lane_data = [1, 1]>>
       -> vector<128x1xf32>
-    // CHECK-COUNT-2: vector.broadcast {{.*}} : vector<16x1xf32> to vector<16x32xf32>
+    // CHECK-COUNT-4: vector.broadcast {{.*}} : vector<16x1xf32> to vector<16x32xf32>
     // CHECK-NOT: vector.broadcast
     %broadcast = vector.broadcast %load
       {layout_result_0 = #xegpu.layout<sg_layout = [4, 1], sg_data = [16, 32], lane_layout = [8, 1], lane_data = [1, 1]>}
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
index 068dd6d865ead..320a2fb1f72ac 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
@@ -165,4 +165,30 @@ gpu.module @test_distribution {
     gpu.return
   }
 
+  // CHECK-LABEL: splat_constant
+  gpu.func @splat_constant() {
+    // CHECK-COUNT-2: %[[CST:.*]] = arith.constant dense<0> : vector<4xindex>
+    %cst_2 = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [16, 1], sg_data = [16, 4], order = [0, 1]>, dims = [0]>}  dense<0> : vector<8xindex>
+    gpu.return
+  }
+
+  // CHECK-LABEL: gpu.func @step_broadcast
+  gpu.func @step_broadcast() {
+    // CHECK: %[[SGID:.*]] = gpu.subgroup_id : index
+    // CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index
+    // CHECK: %[[REM:.*]] = arith.remui %[[SGID]], %[[C16]] : index
+    // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+    // CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+    // CHECK: %[[STEP:.*]] = vector.step : vector<4xindex>
+    // CHECK: %[[BCST0:.*]] = vector.broadcast %[[C0:.*]] : index to vector<4xindex>
+    // CHECK: %[[ADD0:.*]] = arith.addi %[[STEP]], %[[BCST0]] : vector<4xindex>
+    // CHECK: %[[BCST4:.*]] = vector.broadcast %[[C4:.*]] : index to vector<4xindex>
+    // CHECK: %[[ADD4:.*]] = arith.addi %[[STEP]], %[[BCST4]] : vector<4xindex>
+    // CHECK: %[[RES0:.*]] = vector.broadcast %[[ADD0]] : vector<4xindex> to vector<16x4xindex>
+    // CHECK: %[[RES1:.*]] = vector.broadcast %[[ADD4]] : vector<4xindex> to vector<16x4xindex>
+    %2 = vector.step {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [16, 1], sg_data = [16, 4]>, dims = [0]>} : vector<8xindex>
+    %bcast = vector.broadcast %2 {layout_result_0 = #xegpu.layout<sg_layout = [16, 1], sg_data = [16, 4]>} : vector<8xindex> to vector<256x8xindex>
+    gpu.return
+  }
+
 }

``````````

</details>


https://github.com/llvm/llvm-project/pull/189798


More information about the Mlir-commits mailing list