[Mlir-commits] [mlir] [MLIR][XeGPU] Fix Multi-Reduction Layout Rule to Preserve sg_data from consumer layout for non-reduced dims (PR #189795)

Jianhui Li llvmlistbot at llvm.org
Tue Mar 31 22:07:28 PDT 2026


https://github.com/Jianhui-Li created https://github.com/llvm/llvm-project/pull/189795

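As the title says: when the source layout of a vector.multi_reduction is derived from its consumer's layout, keep the consumer's sg_data for the non-reduced dimensions instead of recomputing it as source shape / sg_layout. Patch 1 fixes this for a slice-layout consumer and adds a regression test; patch 2 fixes the case where the consumer is not a slice layout.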

>From 28d2de53e3d0284a1f0d5faa041b23bfd101b0bc Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Wed, 1 Apr 2026 04:47:13 +0000
Subject: [PATCH 1/2] fix the multi-reduction rule to reuse sg_data from the
 consumer layout

---
 .../XeGPU/Transforms/XeGPULayoutImpl.cpp      |  4 +++-
 .../XeGPU/propagate-layout-subgroup.mlir      | 19 +++++++++++++++++++
 2 files changed, 22 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
index ec5751634fdff..5010be3678fd8 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
@@ -479,7 +479,9 @@ xegpu::SliceAttr xegpu::setupMultiReductionResultLayout(
       auto srcSgData = computeShapeRatio(srcShape, sgLayoutFromConsumer);
       if (srcSgData)
         for (int dim = 0; dim < srcRank; dim++) {
-          srcLayout = srcLayout.setDimData(dim, srcSgData.value()[dim], -1, -1);
+          if (llvm::is_contained(reductionDims, dim))
+            srcLayout =
+                srcLayout.setDimData(dim, srcSgData.value()[dim], -1, -1);
         }
     } else {
 
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir b/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir
index e4e6d61b92fda..a8ce58aa53a2f 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir
@@ -196,6 +196,25 @@ gpu.module @test {
   }
 }
 
+// -----
+#data_layout = #xegpu.layout<sg_layout = [1, 16], sg_data = [4, 16]>  // on 8x256: dim 0 = 8 rows / sg_data 4 -> 2 tiles per subgroup (round-robin); dim 1 = 16 sgs x 16 = 256
+gpu.module @test {  // regression test: multi_reduction must keep the consumer's sg_data (4) for non-reduced dim 0
+  // CHECK-LABEL: gpu.func @vector_reduction_preserve_round_robin_layout
+  gpu.func @vector_reduction_preserve_round_robin_layout(%arg0: memref<2048xf32, 1>) kernel {
+    // CHECK: %[[LOAD:.*]] = xegpu.load %arg0{{.*}} <{layout = #xegpu.layout<sg_layout = [1, 16], sg_data = [4, 16]>}> {{.*}} -> vector<8x256xf32>
+    // CHECK: %[[RED:.*]] = vector.multi_reduction <add>, %[[LOAD]], {{.*}} {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [1, 16], sg_data = [4, 16]>, dims = [1]>} [1] : vector<8x256xf32> to vector<8xf32>
+    %offset = arith.constant dense<0> : vector<8x256xindex>
+    %offset2 = arith.constant dense<1024> : vector<8xindex>
+    %acc = arith.constant dense<0.000000e+00> : vector<8xf32>
+    %mask = arith.constant dense<true> : vector<8x256xi1>
+    %mask2 = arith.constant dense<true> : vector<8xi1>
+    %val = xegpu.load %arg0[%offset], %mask : memref<2048xf32, 1>, vector<8x256xindex>, vector<8x256xi1> -> vector<8x256xf32>
+    %reduce = vector.multi_reduction <add>, %val, %acc [1] : vector<8x256xf32> to vector<8xf32>  // reduces dim 1; expected sg_data[0] stays 4, not 8 / sg_layout[0] = 8
+    xegpu.store %reduce, %arg0[%offset2], %mask2 { layout = #xegpu.slice<#data_layout, dims = [1]> } : vector<8xf32>, memref<2048xf32, 1>, vector<8xindex>, vector<8xi1>  // only explicit layout; the layouts checked above are propagated from it
+    gpu.return
+  }
+}
+
 // -----
 gpu.module @test {
   // CHECK-LABEL: for_loop_dpas

>From fab8091e2ef3ef1a2fb45856607bad43e6ba272d Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Wed, 1 Apr 2026 05:04:35 +0000
Subject: [PATCH 2/2] fix the multi-reduction rule to reuse sg_data from the
 consumer when the consumer is not a slice layout

---
 mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
index 5010be3678fd8..278414c6581c7 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
@@ -462,6 +462,8 @@ xegpu::SliceAttr xegpu::setupMultiReductionResultLayout(
 
   SmallVector<int64_t> consumerSgLayout =
       consumerLayout.getEffectiveSgLayoutAsInt();
+  SmallVector<int64_t> consumerSgData =
+      consumerLayout.getEffectiveSgDataAsInt();
   SmallVector<int64_t> consumerLaneLayout =
       consumerLayout.getEffectiveLaneLayoutAsInt();
   SmallVector<int64_t> consumerOrder = consumerLayout.getEffectiveOrderAsInt();
@@ -494,9 +496,7 @@ xegpu::SliceAttr xegpu::setupMultiReductionResultLayout(
         if (!llvm::is_contained(reductionDims, i) &&
             consumerIdx < static_cast<int>(consumerSgLayout.size())) {
           sgLayout[i] = consumerSgLayout[consumerIdx];
-          assert((srcShape[i] % sgLayout[i] == 0) &&
-                 "source shape not divisible by consumer sg_layout");
-          sgData[i] = srcShape[i] / sgLayout[i];
+          sgData[i] = consumerSgData[consumerIdx];
           remainingSgCount /= sgLayout[i];
           order[i] = consumerOrder[consumerIdx];
           consumerIdx++;



More information about the Mlir-commits mailing list