[Mlir-commits] [mlir] [MLIR][XeGPU] Fix Multi-Reduction Layout Rule to Preserve sg_data from consumer layout for non-reduced dims (PR #189795)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Mar 31 22:07:58 PDT 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-gpu

Author: Jianhui Li (Jianhui-Li)

<details>
<summary>Changes</summary>

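As the title says: this patch fixes the multi-reduction layout rule so that, for non-reduced dimensions, the source layout keeps the sg_data taken from the consumer layout. Previously it recomputed sg_data as the source shape divided by sg_layout and asserted divisibility. That recomputation lost round-robin layouts, where sg_data times sg_layout is smaller than the dimension size.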

---
Full diff: https://github.com/llvm/llvm-project/pull/189795.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp (+6-4) 
- (modified) mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir (+19) 


``````````diff
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
index ec5751634fdff..278414c6581c7 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
@@ -462,6 +462,8 @@ xegpu::SliceAttr xegpu::setupMultiReductionResultLayout(
 
   SmallVector<int64_t> consumerSgLayout =
       consumerLayout.getEffectiveSgLayoutAsInt();
+  SmallVector<int64_t> consumerSgData =
+      consumerLayout.getEffectiveSgDataAsInt();
   SmallVector<int64_t> consumerLaneLayout =
       consumerLayout.getEffectiveLaneLayoutAsInt();
   SmallVector<int64_t> consumerOrder = consumerLayout.getEffectiveOrderAsInt();
@@ -479,7 +481,9 @@ xegpu::SliceAttr xegpu::setupMultiReductionResultLayout(
       auto srcSgData = computeShapeRatio(srcShape, sgLayoutFromConsumer);
       if (srcSgData)
         for (int dim = 0; dim < srcRank; dim++) {
-          srcLayout = srcLayout.setDimData(dim, srcSgData.value()[dim], -1, -1);
+          if (llvm::is_contained(reductionDims, dim))
+            srcLayout =
+                srcLayout.setDimData(dim, srcSgData.value()[dim], -1, -1);
         }
     } else {
 
@@ -492,9 +496,7 @@ xegpu::SliceAttr xegpu::setupMultiReductionResultLayout(
         if (!llvm::is_contained(reductionDims, i) &&
             consumerIdx < static_cast<int>(consumerSgLayout.size())) {
           sgLayout[i] = consumerSgLayout[consumerIdx];
-          assert((srcShape[i] % sgLayout[i] == 0) &&
-                 "source shape not divisible by consumer sg_layout");
-          sgData[i] = srcShape[i] / sgLayout[i];
+          sgData[i] = consumerSgData[consumerIdx];
           remainingSgCount /= sgLayout[i];
           order[i] = consumerOrder[consumerIdx];
           consumerIdx++;
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir b/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir
index e4e6d61b92fda..a8ce58aa53a2f 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir
@@ -196,6 +196,25 @@ gpu.module @test {
   }
 }
 
+// -----
+#data_layout = #xegpu.layout<sg_layout = [1, 16], sg_data = [4, 16]>
+gpu.module @test {
+  // CHECK-LABEL: gpu.func @vector_reduction_preserve_round_robin_layout
+  gpu.func @vector_reduction_preserve_round_robin_layout(%arg0: memref<2048xf32, 1>) kernel {
+    // CHECK: %[[LOAD:.*]] = xegpu.load %arg0{{.*}} <{layout = #xegpu.layout<sg_layout = [1, 16], sg_data = [4, 16]>}> {{.*}} -> vector<8x256xf32>
+    // CHECK: %[[RED:.*]] = vector.multi_reduction <add>, %[[LOAD]], {{.*}} {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [1, 16], sg_data = [4, 16]>, dims = [1]>} [1] : vector<8x256xf32> to vector<8xf32>
+    %offset = arith.constant dense<0> : vector<8x256xindex>
+    %offset2 = arith.constant dense<1024> : vector<8xindex>
+    %acc = arith.constant dense<0.000000e+00> : vector<8xf32>
+    %mask = arith.constant dense<true> : vector<8x256xi1>
+    %mask2 = arith.constant dense<true> : vector<8xi1>
+    %val = xegpu.load %arg0[%offset], %mask : memref<2048xf32, 1>, vector<8x256xindex>, vector<8x256xi1> -> vector<8x256xf32>
+    %reduce = vector.multi_reduction <add>, %val, %acc [1] : vector<8x256xf32> to vector<8xf32>
+    xegpu.store %reduce, %arg0[%offset2], %mask2 { layout = #xegpu.slice<#data_layout, dims = [1]> } : vector<8xf32>, memref<2048xf32, 1>, vector<8xindex>, vector<8xi1>
+    gpu.return
+  }
+}
+
 // -----
 gpu.module @test {
   // CHECK-LABEL: for_loop_dpas

``````````

</details>


https://github.com/llvm/llvm-project/pull/189795


More information about the Mlir-commits mailing list