[Mlir-commits] [mlir] [MLIR][XeGPU] Lower vector.multi_reduction before linearization in XeGPUVectorLinearize (PR #190272)
Nishant Patel
llvmlistbot at llvm.org
Tue Apr 7 09:10:08 PDT 2026
https://github.com/nbpatel updated https://github.com/llvm/llvm-project/pull/190272
>From a310208876a5abd5adcefa009c91a9bfd19cad0e Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Thu, 2 Apr 2026 22:03:32 +0000
Subject: [PATCH 1/2] Linearize vector.multi_reduction
---
.../XeGPU/Transforms/XeGPUVectorLinearize.cpp | 40 +++++++++++++++++++
.../Dialect/XeGPU/xegpu-vector-linearize.mlir | 18 +++++++++
2 files changed, 58 insertions(+)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUVectorLinearize.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUVectorLinearize.cpp
index e31c37a2459ad..0b18f2ef49736 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUVectorLinearize.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUVectorLinearize.cpp
@@ -51,6 +51,46 @@ struct XeGPUVectorLinearizePass final
return signalPassFailure();
}
+ // Lower vector.multi_reduction before linearization. Linearization flattens
+ // nD vectors to 1D, destroying axis information that multi_reduction relies
+ // on to know which elements to group together. By first unrolling it into
+ // per-slice vector.extract + element-wise arith ops (vector.reduction for
+ // 1-D cases), the IR contains only ops that linearization handles safely.
+ //
+ // Both sets run in one greedy rewrite, but logically act in two stages:
+ // 1. ReorderPatterns (InnerOuterDimReductionConversion): inserts
+ // vector.transpose so that all reduction dims become the outermost
+ // dims (the layout InnerParallel mode expects), giving the unrolling
+ // patterns a canonical form to work on.
+ // 2. UnrollingPatterns: with the reduction dim outermost, each slice
+ // along it is a vector of parallel elements. The slices are combined
+ // by an unrolled chain of element-wise arith ops seeded with the
+ // accumulator (TwoDimMultiReductionToElementWise). Any remaining 1-D
+ // multi_reduction is converted to vector.reduction
+ // (OneDimMultiReductionToReduction).
+ // Example: reduce a 4x8 matrix over dim 0 (i.e. sum its 4 rows):
+ // %0 = vector.multi_reduction <add>, %arg0, %acc [0]
+ // : vector<4x8xf32> to vector<8xf32>
+ // becomes, after unrolling and the subsequent linearization:
+ // %flat = vector.shape_cast %arg0 : vector<4x8xf32> to vector<32xf32>
+ // %s0 = vector.shuffle %flat, %flat [0, 1, 2, 3, 4, 5, 6, 7]
+ // : vector<32xf32>, vector<32xf32>
+ // %r0 = arith.addf %s0, %acc : vector<8xf32> // row 0 + acc
+ // %s1 = vector.shuffle %flat, %flat [8, 9, 10, 11, 12, 13, 14, 15]
+ // : vector<32xf32>, vector<32xf32>
+ // %r1 = arith.addf %s1, %r0 : vector<8xf32> // row 1 + r0
+ // ... // rows 2, 3
+ // (Unrolling emits per-row vector.extract; linearization yields the shuffles.)
+ //
+ {
+ auto options = vector::VectorMultiReductionLowering::InnerParallel;
+ RewritePatternSet patterns(&getContext());
+ vector::populateVectorMultiReductionReorderPatterns(patterns, options);
+ vector::populateVectorMultiReductionUnrollingPatterns(patterns, options);
+ if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
+ return signalPassFailure();
+ }
+
// Unroll load/store from <d1xd2x...xdk> to (d1*d2*...*d(k-1)) slices of
// <1x1x...x1xdk>.
{
diff --git a/mlir/test/Dialect/XeGPU/xegpu-vector-linearize.mlir b/mlir/test/Dialect/XeGPU/xegpu-vector-linearize.mlir
index 94205a6c26ba2..20b555d768337 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-vector-linearize.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-vector-linearize.mlir
@@ -265,4 +265,22 @@ gpu.module @test_kernel {
}
}
+// -----
+// CHECK-LABEL: func.func @test_vector_multi_reduction_add
+// CHECK-SAME: (%[[ARG0:.*]]: vector<4x8xf32>, %[[ARG1:.*]]: vector<8xf32>) -> vector<8xf32>
+// CHECK: %[[FLAT:.*]] = vector.shape_cast %[[ARG0]] : vector<4x8xf32> to vector<32xf32>
+// CHECK: %[[S0:.*]] = vector.shuffle %[[FLAT]], %[[FLAT]] [0, 1, 2, 3, 4, 5, 6, 7]
+// CHECK: %[[R0:.*]] = arith.addf %[[S0]], %[[ARG1]] : vector<8xf32>
+// CHECK: %[[S1:.*]] = vector.shuffle %[[FLAT]], %[[FLAT]] [8, 9, 10, 11, 12, 13, 14, 15]
+// CHECK: %[[R1:.*]] = arith.addf %[[S1]], %[[R0]] : vector<8xf32>
+// CHECK: %[[S2:.*]] = vector.shuffle %[[FLAT]], %[[FLAT]] [16, 17, 18, 19, 20, 21, 22, 23]
+// CHECK: %[[R2:.*]] = arith.addf %[[S2]], %[[R1]] : vector<8xf32>
+// CHECK: %[[S3:.*]] = vector.shuffle %[[FLAT]], %[[FLAT]] [24, 25, 26, 27, 28, 29, 30, 31]
+// CHECK: %[[R3:.*]] = arith.addf %[[S3]], %[[R2]] : vector<8xf32>
+// CHECK: return %[[R3]] : vector<8xf32>
+func.func @test_vector_multi_reduction_add(%arg0: vector<4x8xf32>, %arg1: vector<8xf32>) -> vector<8xf32> {
+ %0 = vector.multi_reduction <add>, %arg0, %arg1 [0] : vector<4x8xf32> to vector<8xf32>
+ return %0 : vector<8xf32>
+}
+
>From d8889f9163e2dee93f874a7261af970065a1f8a3 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Tue, 7 Apr 2026 16:09:41 +0000
Subject: [PATCH 2/2] Update test
---
.../Dialect/XeGPU/xegpu-vector-linearize.mlir | 27 +++++++++----------
1 file changed, 13 insertions(+), 14 deletions(-)
diff --git a/mlir/test/Dialect/XeGPU/xegpu-vector-linearize.mlir b/mlir/test/Dialect/XeGPU/xegpu-vector-linearize.mlir
index 20b555d768337..bac1e838b038e 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-vector-linearize.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-vector-linearize.mlir
@@ -267,20 +267,19 @@ gpu.module @test_kernel {
// -----
// CHECK-LABEL: func.func @test_vector_multi_reduction_add
-// CHECK-SAME: (%[[ARG0:.*]]: vector<4x8xf32>, %[[ARG1:.*]]: vector<8xf32>) -> vector<8xf32>
-// CHECK: %[[FLAT:.*]] = vector.shape_cast %[[ARG0]] : vector<4x8xf32> to vector<32xf32>
-// CHECK: %[[S0:.*]] = vector.shuffle %[[FLAT]], %[[FLAT]] [0, 1, 2, 3, 4, 5, 6, 7]
-// CHECK: %[[R0:.*]] = arith.addf %[[S0]], %[[ARG1]] : vector<8xf32>
-// CHECK: %[[S1:.*]] = vector.shuffle %[[FLAT]], %[[FLAT]] [8, 9, 10, 11, 12, 13, 14, 15]
-// CHECK: %[[R1:.*]] = arith.addf %[[S1]], %[[R0]] : vector<8xf32>
-// CHECK: %[[S2:.*]] = vector.shuffle %[[FLAT]], %[[FLAT]] [16, 17, 18, 19, 20, 21, 22, 23]
-// CHECK: %[[R2:.*]] = arith.addf %[[S2]], %[[R1]] : vector<8xf32>
-// CHECK: %[[S3:.*]] = vector.shuffle %[[FLAT]], %[[FLAT]] [24, 25, 26, 27, 28, 29, 30, 31]
-// CHECK: %[[R3:.*]] = arith.addf %[[S3]], %[[R2]] : vector<8xf32>
-// CHECK: return %[[R3]] : vector<8xf32>
-func.func @test_vector_multi_reduction_add(%arg0: vector<4x8xf32>, %arg1: vector<8xf32>) -> vector<8xf32> {
- %0 = vector.multi_reduction <add>, %arg0, %arg1 [0] : vector<4x8xf32> to vector<8xf32>
- return %0 : vector<8xf32>
+// CHECK-SAME: (%[[ARG0:.*]]: vector<16x1xf16>, %[[ARG1:.*]]: vector<1xf16>) -> vector<1xf16>
+// CHECK: %[[FLAT:.*]] = vector.shape_cast %[[ARG0]] : vector<16x1xf16> to vector<16xf16>
+// CHECK: vector.shuffle %[[FLAT]], %[[FLAT]] [0] : vector<16xf16>, vector<16xf16>
+// CHECK: arith.addf {{.*}}, %[[ARG1]] : vector<1xf16>
+// 14 more shuffle+addf pairs for indices 1..14 (only the shuffles are checked)
+// CHECK-COUNT-14: vector.shuffle %[[FLAT]], %[[FLAT]] {{.*}} : vector<16xf16>, vector<16xf16>
+// Final shuffle (index 15) + addf + return
+// CHECK: vector.shuffle %[[FLAT]], %[[FLAT]] [15] : vector<16xf16>, vector<16xf16>
+// CHECK: %[[LAST:.*]] = arith.addf
+// CHECK: return %[[LAST]] : vector<1xf16>
+func.func @test_vector_multi_reduction_add(%arg0: vector<16x1xf16>, %arg1: vector<1xf16>) -> vector<1xf16> {
+ %0 = vector.multi_reduction <add>, %arg0, %arg1 [0] : vector<16x1xf16> to vector<1xf16>
+ return %0 : vector<1xf16>
}
More information about the Mlir-commits
mailing list