[Mlir-commits] [mlir] [MLIR][XeGPU] Fix Multi-Reduction Layout Rule to Preserve sg_data from consumer layout for non-reduced dims (PR #189795)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Mar 31 22:07:58 PDT 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-gpu
Author: Jianhui Li (Jianhui-Li)
<details>
<summary>Changes</summary>
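As the title says: fix the multi-reduction layout rule to preserve sg_data from the consumer layout for non-reduced dims.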
---
Full diff: https://github.com/llvm/llvm-project/pull/189795.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp (+6-4)
- (modified) mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir (+19)
``````````diff
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
index ec5751634fdff..278414c6581c7 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
@@ -462,6 +462,8 @@ xegpu::SliceAttr xegpu::setupMultiReductionResultLayout(
SmallVector<int64_t> consumerSgLayout =
consumerLayout.getEffectiveSgLayoutAsInt();
+ SmallVector<int64_t> consumerSgData =
+ consumerLayout.getEffectiveSgDataAsInt();
SmallVector<int64_t> consumerLaneLayout =
consumerLayout.getEffectiveLaneLayoutAsInt();
SmallVector<int64_t> consumerOrder = consumerLayout.getEffectiveOrderAsInt();
@@ -479,7 +481,9 @@ xegpu::SliceAttr xegpu::setupMultiReductionResultLayout(
auto srcSgData = computeShapeRatio(srcShape, sgLayoutFromConsumer);
if (srcSgData)
for (int dim = 0; dim < srcRank; dim++) {
- srcLayout = srcLayout.setDimData(dim, srcSgData.value()[dim], -1, -1);
+ if (llvm::is_contained(reductionDims, dim))
+ srcLayout =
+ srcLayout.setDimData(dim, srcSgData.value()[dim], -1, -1);
}
} else {
@@ -492,9 +496,7 @@ xegpu::SliceAttr xegpu::setupMultiReductionResultLayout(
if (!llvm::is_contained(reductionDims, i) &&
consumerIdx < static_cast<int>(consumerSgLayout.size())) {
sgLayout[i] = consumerSgLayout[consumerIdx];
- assert((srcShape[i] % sgLayout[i] == 0) &&
- "source shape not divisible by consumer sg_layout");
- sgData[i] = srcShape[i] / sgLayout[i];
+ sgData[i] = consumerSgData[consumerIdx];
remainingSgCount /= sgLayout[i];
order[i] = consumerOrder[consumerIdx];
consumerIdx++;
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir b/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir
index e4e6d61b92fda..a8ce58aa53a2f 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir
@@ -196,6 +196,25 @@ gpu.module @test {
}
}
+// -----
+#data_layout = #xegpu.layout<sg_layout = [1, 16], sg_data = [4, 16]>
+gpu.module @test {
+ // CHECK-LABEL: gpu.func @vector_reduction_preserve_round_robin_layout
+ gpu.func @vector_reduction_preserve_round_robin_layout(%arg0: memref<2048xf32, 1>) kernel {
+ // CHECK: %[[LOAD:.*]] = xegpu.load %arg0{{.*}} <{layout = #xegpu.layout<sg_layout = [1, 16], sg_data = [4, 16]>}> {{.*}} -> vector<8x256xf32>
+ // CHECK: %[[RED:.*]] = vector.multi_reduction <add>, %[[LOAD]], {{.*}} {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [1, 16], sg_data = [4, 16]>, dims = [1]>} [1] : vector<8x256xf32> to vector<8xf32>
+ %offset = arith.constant dense<0> : vector<8x256xindex>
+ %offset2 = arith.constant dense<1024> : vector<8xindex>
+ %acc = arith.constant dense<0.000000e+00> : vector<8xf32>
+ %mask = arith.constant dense<true> : vector<8x256xi1>
+ %mask2 = arith.constant dense<true> : vector<8xi1>
+ %val = xegpu.load %arg0[%offset], %mask : memref<2048xf32, 1>, vector<8x256xindex>, vector<8x256xi1> -> vector<8x256xf32>
+ %reduce = vector.multi_reduction <add>, %val, %acc [1] : vector<8x256xf32> to vector<8xf32>
+ xegpu.store %reduce, %arg0[%offset2], %mask2 { layout = #xegpu.slice<#data_layout, dims = [1]> } : vector<8xf32>, memref<2048xf32, 1>, vector<8xindex>, vector<8xi1>
+ gpu.return
+ }
+}
+
// -----
gpu.module @test {
// CHECK-LABEL: for_loop_dpas
``````````
</details>
https://github.com/llvm/llvm-project/pull/189795
More information about the Mlir-commits mailing list