[Mlir-commits] [mlir] [MLIR][XeGPU] Lower vector.multi_reduction before linearization in XeGPUVectorLinearize (PR #190272)

Nishant Patel llvmlistbot at llvm.org
Tue Apr 7 09:10:08 PDT 2026


https://github.com/nbpatel updated https://github.com/llvm/llvm-project/pull/190272

>From a310208876a5abd5adcefa009c91a9bfd19cad0e Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Thu, 2 Apr 2026 22:03:32 +0000
Subject: [PATCH 1/2] Linearize vector.multi_reduction

---
 .../XeGPU/Transforms/XeGPUVectorLinearize.cpp | 40 +++++++++++++++++++
 .../Dialect/XeGPU/xegpu-vector-linearize.mlir | 18 +++++++++
 2 files changed, 58 insertions(+)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUVectorLinearize.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUVectorLinearize.cpp
index e31c37a2459ad..0b18f2ef49736 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUVectorLinearize.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUVectorLinearize.cpp
@@ -51,6 +51,46 @@ struct XeGPUVectorLinearizePass final
         return signalPassFailure();
     }
 
+    // Lower vector.multi_reduction before linearization. Linearization flattens
+    // nD vectors to 1D, destroying axis information that multi_reduction relies
+    // on to know which elements to group together. By unrolling multi_reduction
+    // into a sequence of element-wise arith ops first, the IR contains only
+    // shape-agnostic ops by the time linearization runs.
+    //
+    // Two pattern sets are combined and applied in one greedy rewrite:
+    //   1. ReorderPatterns (InnerOuterDimReductionConversion): inserts
+    //      vector.transpose to move all reduction dims to either the innermost
+    //      or outermost positions. This normalizes arbitrary reductions into a
+    //      canonical 2-D form that the unrolling patterns can handle.
+    //   2. UnrollingPatterns: with InnerParallel mode, the reduction dims are
+    //      outermost, so the inner (parallel) dims are treated as rows and the
+    //      outer loop is unrolled into a sequence of element-wise arith ops
+    //      (TwoDimMultiReductionToElementWise). Any remaining 1-D
+    //      multi_reduction is converted to vector.reduction
+    //      (OneDimMultiReductionToReduction).
+    // Example: reduce a 4x8 matrix over dim 0 (combining its 4 rows):
+    //   %0 = vector.multi_reduction <add>, %arg0, %acc [0]
+    //          : vector<4x8xf32> to vector<8xf32>
+    // is unrolled and then linearized into:
+    //   %flat = vector.shape_cast %arg0 : vector<4x8xf32> to vector<32xf32>
+    //   %s0 = vector.shuffle %flat, %flat [0, 1, 2, 3, 4, 5, 6, 7]
+    //           : vector<32xf32>, vector<32xf32>
+    //   %r0 = arith.addf %s0, %acc : vector<8xf32>            // row 0 + acc
+    //   %s1 = vector.shuffle %flat, %flat [8, 9, 10, 11, 12, 13, 14, 15]
+    //           : vector<32xf32>, vector<32xf32>
+    //   %r1 = arith.addf %s1, %r0 : vector<8xf32>             // row 1 + r0
+    //   ...                                                    // rows 2, 3
+    // Because unrolling leaves only shape-agnostic ops, linearization is safe.
+    //
+    {
+      auto options = vector::VectorMultiReductionLowering::InnerParallel;
+      RewritePatternSet patterns(&getContext());
+      vector::populateVectorMultiReductionReorderPatterns(patterns, options);
+      vector::populateVectorMultiReductionUnrollingPatterns(patterns, options);
+      if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
+        return signalPassFailure();
+    }
+
     // Unroll load/store from <d1xd2x...xdk> to (d1*d2*...*d(k-1)) slices of
     // <1x1x...x1xdk>.
     {
diff --git a/mlir/test/Dialect/XeGPU/xegpu-vector-linearize.mlir b/mlir/test/Dialect/XeGPU/xegpu-vector-linearize.mlir
index 94205a6c26ba2..20b555d768337 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-vector-linearize.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-vector-linearize.mlir
@@ -265,4 +265,22 @@ gpu.module @test_kernel {
   }
 }
 
+// -----
+// CHECK-LABEL: func.func @test_vector_multi_reduction_add
+// CHECK-SAME: (%[[ARG0:.*]]: vector<4x8xf32>, %[[ARG1:.*]]: vector<8xf32>) -> vector<8xf32>
+// CHECK: %[[FLAT:.*]] = vector.shape_cast %[[ARG0]] : vector<4x8xf32> to vector<32xf32>
+// CHECK: %[[S0:.*]] = vector.shuffle %[[FLAT]], %[[FLAT]] [0, 1, 2, 3, 4, 5, 6, 7]
+// CHECK: %[[R0:.*]] = arith.addf %[[S0]], %[[ARG1]] : vector<8xf32>
+// CHECK: %[[S1:.*]] = vector.shuffle %[[FLAT]], %[[FLAT]] [8, 9, 10, 11, 12, 13, 14, 15]
+// CHECK: %[[R1:.*]] = arith.addf %[[S1]], %[[R0]] : vector<8xf32>
+// CHECK: %[[S2:.*]] = vector.shuffle %[[FLAT]], %[[FLAT]] [16, 17, 18, 19, 20, 21, 22, 23]
+// CHECK: %[[R2:.*]] = arith.addf %[[S2]], %[[R1]] : vector<8xf32>
+// CHECK: %[[S3:.*]] = vector.shuffle %[[FLAT]], %[[FLAT]] [24, 25, 26, 27, 28, 29, 30, 31]
+// CHECK: %[[R3:.*]] = arith.addf %[[S3]], %[[R2]] : vector<8xf32>
+// CHECK: return %[[R3]] : vector<8xf32>
+func.func @test_vector_multi_reduction_add(%arg0: vector<4x8xf32>, %arg1: vector<8xf32>) -> vector<8xf32> {
+  %0 = vector.multi_reduction <add>, %arg0, %arg1 [0] : vector<4x8xf32> to vector<8xf32>
+  return %0 : vector<8xf32>
+}
+
 

>From d8889f9163e2dee93f874a7261af970065a1f8a3 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Tue, 7 Apr 2026 16:09:41 +0000
Subject: [PATCH 2/2] Update test

---
 .../Dialect/XeGPU/xegpu-vector-linearize.mlir | 27 +++++++++----------
 1 file changed, 13 insertions(+), 14 deletions(-)

diff --git a/mlir/test/Dialect/XeGPU/xegpu-vector-linearize.mlir b/mlir/test/Dialect/XeGPU/xegpu-vector-linearize.mlir
index 20b555d768337..bac1e838b038e 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-vector-linearize.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-vector-linearize.mlir
@@ -267,20 +267,19 @@ gpu.module @test_kernel {
 
 // -----
 // CHECK-LABEL: func.func @test_vector_multi_reduction_add
-// CHECK-SAME: (%[[ARG0:.*]]: vector<4x8xf32>, %[[ARG1:.*]]: vector<8xf32>) -> vector<8xf32>
-// CHECK: %[[FLAT:.*]] = vector.shape_cast %[[ARG0]] : vector<4x8xf32> to vector<32xf32>
-// CHECK: %[[S0:.*]] = vector.shuffle %[[FLAT]], %[[FLAT]] [0, 1, 2, 3, 4, 5, 6, 7]
-// CHECK: %[[R0:.*]] = arith.addf %[[S0]], %[[ARG1]] : vector<8xf32>
-// CHECK: %[[S1:.*]] = vector.shuffle %[[FLAT]], %[[FLAT]] [8, 9, 10, 11, 12, 13, 14, 15]
-// CHECK: %[[R1:.*]] = arith.addf %[[S1]], %[[R0]] : vector<8xf32>
-// CHECK: %[[S2:.*]] = vector.shuffle %[[FLAT]], %[[FLAT]] [16, 17, 18, 19, 20, 21, 22, 23]
-// CHECK: %[[R2:.*]] = arith.addf %[[S2]], %[[R1]] : vector<8xf32>
-// CHECK: %[[S3:.*]] = vector.shuffle %[[FLAT]], %[[FLAT]] [24, 25, 26, 27, 28, 29, 30, 31]
-// CHECK: %[[R3:.*]] = arith.addf %[[S3]], %[[R2]] : vector<8xf32>
-// CHECK: return %[[R3]] : vector<8xf32>
-func.func @test_vector_multi_reduction_add(%arg0: vector<4x8xf32>, %arg1: vector<8xf32>) -> vector<8xf32> {
-  %0 = vector.multi_reduction <add>, %arg0, %arg1 [0] : vector<4x8xf32> to vector<8xf32>
-  return %0 : vector<8xf32>
+// CHECK-SAME: (%[[ARG0:.*]]: vector<16x1xf16>, %[[ARG1:.*]]: vector<1xf16>) -> vector<1xf16>
+// CHECK:      %[[FLAT:.*]] = vector.shape_cast %[[ARG0]] : vector<16x1xf16> to vector<16xf16>
+// CHECK:      vector.shuffle %[[FLAT]], %[[FLAT]] [0] : vector<16xf16>, vector<16xf16>
+// CHECK:      arith.addf {{.*}}, %[[ARG1]] : vector<1xf16>
+// Shuffles for indices 1..14 (each followed by an addf, not checked here)
+// CHECK-COUNT-14: vector.shuffle %[[FLAT]], %[[FLAT]] {{.*}} : vector<16xf16>, vector<16xf16>
+// Final shuffle (index 15) + addf + return
+// CHECK:      vector.shuffle %[[FLAT]], %[[FLAT]] [15] : vector<16xf16>, vector<16xf16>
+// CHECK:      %[[LAST:.*]] = arith.addf
+// CHECK:      return %[[LAST]] : vector<1xf16>
+func.func @test_vector_multi_reduction_add(%arg0: vector<16x1xf16>, %arg1: vector<1xf16>) -> vector<1xf16> {
+  %0 = vector.multi_reduction <add>, %arg0, %arg1 [0] : vector<16x1xf16> to vector<1xf16>
+  return %0 : vector<1xf16>
 }
 
 



More information about the Mlir-commits mailing list