[Mlir-commits] [mlir] cfec3b8 - [MLIR][XeGPU] Fix Multi-Reduction Layout Rule to Preserve sg_data from consumer layout for non-reduced dims (#189795)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Apr 7 10:46:53 PDT 2026
Author: Jianhui Li
Date: 2026-04-07T10:46:48-07:00
New Revision: cfec3b8efb6b04edce719ac3f2150330cf82aaff
URL: https://github.com/llvm/llvm-project/commit/cfec3b8efb6b04edce719ac3f2150330cf82aaff
DIFF: https://github.com/llvm/llvm-project/commit/cfec3b8efb6b04edce719ac3f2150330cf82aaff.diff
LOG: [MLIR][XeGPU] Fix Multi-Reduction Layout Rule to Preserve sg_data from consumer layout for non-reduced dims (#189795)
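As the title says: when setting up the multi-reduction result layout, sg_data for non-reduced dims is now taken from the consumer layout instead of being recomputed as source shape / sg_layout. In the other branch, the sg_data derived from the source shape is now applied only to the reduction dims.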
Added:
Modified:
mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
index 55cd6ec04970c..c3c40ceb4c6ae 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
@@ -471,12 +471,17 @@ xegpu::SliceAttr xegpu::setupMultiReductionResultLayout(
auto srcSgData = computeShapeRatio(srcShape, sgLayoutFromConsumer);
if (srcSgData)
for (int dim = 0; dim < srcRank; dim++) {
- srcLayout = srcLayout.setDimData(dim, srcSgData.value()[dim], -1, -1);
+ if (llvm::is_contained(reductionDims, dim))
+ srcLayout =
+ srcLayout.setDimData(dim, srcSgData.value()[dim], -1, -1);
}
} else {
SmallVector<int64_t> consumerSgLayout =
consumerLayout ? consumerLayout.getEffectiveSgLayoutAsInt()
: SmallVector<int64_t>();
+ SmallVector<int64_t> consumerSgData =
+ consumerLayout ? consumerLayout.getEffectiveSgDataAsInt()
+ : SmallVector<int64_t>();
SmallVector<int64_t> consumerOrder =
consumerLayout ? consumerLayout.getEffectiveOrderAsInt()
: SmallVector<int64_t>();
@@ -492,9 +497,7 @@ xegpu::SliceAttr xegpu::setupMultiReductionResultLayout(
if (!llvm::is_contained(reductionDims, i) &&
consumerIdx < static_cast<int>(consumerSgLayout.size())) {
sgLayout[i] = consumerSgLayout[consumerIdx];
- assert((srcShape[i] % sgLayout[i] == 0) &&
- "source shape not divisible by consumer sg_layout");
- sgData[i] = srcShape[i] / sgLayout[i];
+ sgData[i] = consumerSgData[consumerIdx];
remainingSgCount /= sgLayout[i];
order[i] = consumerOrder[consumerIdx];
consumerIdx++;
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir b/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir
index bb387b4cfb093..831d1e05967f8 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir
@@ -209,6 +209,25 @@ gpu.module @test {
}
}
+// -----
+#data_layout = #xegpu.layout<sg_layout = [1, 16], sg_data = [4, 16]>
+gpu.module @test {
+ // CHECK-LABEL: gpu.func @vector_reduction_preserve_round_robin_layout
+ gpu.func @vector_reduction_preserve_round_robin_layout(%arg0: memref<2048xf32, 1>) kernel {
+ // CHECK: %[[LOAD:.*]] = xegpu.load %arg0{{.*}} <{layout = #xegpu.layout<sg_layout = [1, 16], sg_data = [4, 16]>}> {{.*}} -> vector<8x256xf32>
+ // CHECK: %[[RED:.*]] = vector.multi_reduction <add>, %[[LOAD]], {{.*}} {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [1, 16], sg_data = [4, 16]>, dims = [1]>} [1] : vector<8x256xf32> to vector<8xf32>
+ %offset = arith.constant dense<0> : vector<8x256xindex>
+ %offset2 = arith.constant dense<1024> : vector<8xindex>
+ %acc = arith.constant dense<0.000000e+00> : vector<8xf32>
+ %mask = arith.constant dense<true> : vector<8x256xi1>
+ %mask2 = arith.constant dense<true> : vector<8xi1>
+ %val = xegpu.load %arg0[%offset], %mask : memref<2048xf32, 1>, vector<8x256xindex>, vector<8x256xi1> -> vector<8x256xf32>
+ %reduce = vector.multi_reduction <add>, %val, %acc [1] : vector<8x256xf32> to vector<8xf32>
+ xegpu.store %reduce, %arg0[%offset2], %mask2 { layout = #xegpu.slice<#data_layout, dims = [1]> } : vector<8xf32>, memref<2048xf32, 1>, vector<8xindex>, vector<8xi1>
+ gpu.return
+ }
+}
+
// -----
gpu.module @test {
// CHECK-LABEL: for_loop_dpas
More information about the Mlir-commits mailing list