[Mlir-commits] [mlir] cfec3b8 - [MLIR][XeGPU] Fix Multi-Reduction Layout Rule to Preserve sg_data from consumer layout for non-reduced dims (#189795)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Apr 7 10:46:53 PDT 2026


Author: Jianhui Li
Date: 2026-04-07T10:46:48-07:00
New Revision: cfec3b8efb6b04edce719ac3f2150330cf82aaff

URL: https://github.com/llvm/llvm-project/commit/cfec3b8efb6b04edce719ac3f2150330cf82aaff
DIFF: https://github.com/llvm/llvm-project/commit/cfec3b8efb6b04edce719ac3f2150330cf82aaff.diff

LOG: [MLIR][XeGPU] Fix Multi-Reduction Layout Rule to Preserve sg_data from consumer layout for non-reduced dims (#189795)

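When setting up the result layout for a multi-reduction, sg_data for the
non-reduced dimensions is now taken from the consumer layout instead of being
recomputed as source shape / sg_layout. In the path that derives sg_data from
the source shape and the consumer's sg_layout, the derived value is now applied
only to the reduction dimensions. This drops the assertion that the source
shape is divisible by the consumer sg_layout.

A new test, vector_reduction_preserve_round_robin_layout, checks that a
consumer layout of sg_layout = [1, 16], sg_data = [4, 16] is preserved on a
vector<8x256xf32> source reduced along dimension 1.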

Added: 
    

Modified: 
    mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
    mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
index 55cd6ec04970c..c3c40ceb4c6ae 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
@@ -471,12 +471,17 @@ xegpu::SliceAttr xegpu::setupMultiReductionResultLayout(
       auto srcSgData = computeShapeRatio(srcShape, sgLayoutFromConsumer);
       if (srcSgData)
         for (int dim = 0; dim < srcRank; dim++) {
-          srcLayout = srcLayout.setDimData(dim, srcSgData.value()[dim], -1, -1);
+          if (llvm::is_contained(reductionDims, dim))
+            srcLayout =
+                srcLayout.setDimData(dim, srcSgData.value()[dim], -1, -1);
         }
     } else {
       SmallVector<int64_t> consumerSgLayout =
           consumerLayout ? consumerLayout.getEffectiveSgLayoutAsInt()
                          : SmallVector<int64_t>();
+      SmallVector<int64_t> consumerSgData =
+          consumerLayout ? consumerLayout.getEffectiveSgDataAsInt()
+                         : SmallVector<int64_t>();
       SmallVector<int64_t> consumerOrder =
           consumerLayout ? consumerLayout.getEffectiveOrderAsInt()
                          : SmallVector<int64_t>();
@@ -492,9 +497,7 @@ xegpu::SliceAttr xegpu::setupMultiReductionResultLayout(
         if (!llvm::is_contained(reductionDims, i) &&
             consumerIdx < static_cast<int>(consumerSgLayout.size())) {
           sgLayout[i] = consumerSgLayout[consumerIdx];
-          assert((srcShape[i] % sgLayout[i] == 0) &&
-                 "source shape not divisible by consumer sg_layout");
-          sgData[i] = srcShape[i] / sgLayout[i];
+          sgData[i] = consumerSgData[consumerIdx];
           remainingSgCount /= sgLayout[i];
           order[i] = consumerOrder[consumerIdx];
           consumerIdx++;

diff  --git a/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir b/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir
index bb387b4cfb093..831d1e05967f8 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir
@@ -209,6 +209,25 @@ gpu.module @test {
   }
 }
 
+// -----
+#data_layout = #xegpu.layout<sg_layout = [1, 16], sg_data = [4, 16]>
+gpu.module @test {
+  // CHECK-LABEL: gpu.func @vector_reduction_preserve_round_robin_layout
+  gpu.func @vector_reduction_preserve_round_robin_layout(%arg0: memref<2048xf32, 1>) kernel {
+    // CHECK: %[[LOAD:.*]] = xegpu.load %arg0{{.*}} <{layout = #xegpu.layout<sg_layout = [1, 16], sg_data = [4, 16]>}> {{.*}} -> vector<8x256xf32>
+    // CHECK: %[[RED:.*]] = vector.multi_reduction <add>, %[[LOAD]], {{.*}} {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [1, 16], sg_data = [4, 16]>, dims = [1]>} [1] : vector<8x256xf32> to vector<8xf32>
+    %offset = arith.constant dense<0> : vector<8x256xindex>
+    %offset2 = arith.constant dense<1024> : vector<8xindex>
+    %acc = arith.constant dense<0.000000e+00> : vector<8xf32>
+    %mask = arith.constant dense<true> : vector<8x256xi1>
+    %mask2 = arith.constant dense<true> : vector<8xi1>
+    %val = xegpu.load %arg0[%offset], %mask : memref<2048xf32, 1>, vector<8x256xindex>, vector<8x256xi1> -> vector<8x256xf32>
+    %reduce = vector.multi_reduction <add>, %val, %acc [1] : vector<8x256xf32> to vector<8xf32>
+    xegpu.store %reduce, %arg0[%offset2], %mask2 { layout = #xegpu.slice<#data_layout, dims = [1]> } : vector<8xf32>, memref<2048xf32, 1>, vector<8xindex>, vector<8xi1>
+    gpu.return
+  }
+}
+
 // -----
 gpu.module @test {
   // CHECK-LABEL: for_loop_dpas


        


More information about the Mlir-commits mailing list