[Mlir-commits] [mlir] [MLIR][XeGPU] Lower vector.multi_reduction before linearization in XeGPUVectorLinearize (PR #190272)

Charitha Saumya llvmlistbot at llvm.org
Fri Apr 3 15:06:48 PDT 2026


================
@@ -51,6 +51,46 @@ struct XeGPUVectorLinearizePass final
         return signalPassFailure();
     }
 
+    // Lower vector.multi_reduction before linearization. Linearization flattens
+    // nD vectors to 1D, destroying axis information that multi_reduction relies
+    // on to know which elements to group together. By unrolling multi_reduction
+    // into per-row element-wise arith ops (and vector.reduction for any 1-D
+    // case) first, the IR contains only shape-agnostic ops by the time
+    // linearization runs.
+    //
+    // Two pattern sets are applied in order:
+    //   1. ReorderPatterns (InnerOuterDimReductionConversion): inserts
+    //      vector.transpose to move all reduction dims to either the innermost
+    //      or outermost positions, depending on the lowering mode. This
+    //      normalizes arbitrary reductions into a canonical 2-D form that the
+    //      unrolling patterns can handle.
+    //   2. UnrollingPatterns: with InnerParallel mode, the reduction dims are
+    //      outermost, so each slice along them is a row spanning the inner
+    //      (parallel) dims. The outer (reduction) dim is unrolled into a
+    //      sequence of element-wise arith ops, one per row
+    //      (TwoDimMultiReductionToElementWise). Any remaining 1-D
+    //      multi_reduction is converted to vector.reduction
+    //      (OneDimMultiReductionToReduction).
+    // Example: reduce a 4x8 matrix over dim 0 (summing its 4 rows):
+    //   %0 = vector.multi_reduction <add>, %arg0, %acc [0]
+    //          : vector<4x8xf32> to vector<8xf32>
+    // is unrolled into one vector.extract + arith.addf per row. After
+    // linearization, each row extract becomes a vector.shuffle on the
+    // flattened source:
+    //   %flat = vector.shape_cast %arg0 : vector<4x8xf32> to vector<32xf32>
+    //   %s0 = vector.shuffle %flat, %flat [0, 1, 2, 3, 4, 5, 6, 7]
+    //           : vector<32xf32>, vector<32xf32>
+    //   %r0 = arith.addf %s0, %acc : vector<8xf32>            // row 0 + acc
+    //   %s1 = vector.shuffle %flat, %flat [8, 9, 10, 11, 12, 13, 14, 15]
+    //           : vector<32xf32>, vector<32xf32>
+    //   %r1 = arith.addf %s1, %r0 : vector<8xf32>             // row 1 + r0
+    //   ...                                                    // rows 2, 3
+    // None of the unrolled ops depend on the 2-D axis structure, so they
+    // linearize safely.
+    //
+    {
+      auto options = vector::VectorMultiReductionLowering::InnerParallel;
+      RewritePatternSet patterns(&getContext());
+      vector::populateVectorMultiReductionReorderPatterns(patterns, options);
+      vector::populateVectorMultiReductionUnrollingPatterns(patterns, options);
----------------
charithaintc wrote:

I think this should be a separate pattern in `VectorLinearize.cpp`. Vector linearization should not depend on unrolling multi_reduction first, because the other linearization patterns already live in `VectorLinearize.cpp`.

https://github.com/llvm/llvm-project/pull/190272
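For readers following the thread, here is a small plain-C++ model of the two steps the diff comment describes. It is not MLIR code and not the patterns from this PR; it assumes an `add` reduction and a row-major source. `innerParallelPermutation` and `unrollDim0AddReduction` are hypothetical names, and the permutation assumes the reorder step keeps the relative order within the reduction group and within the parallel group.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper (not an MLIR API): the dim permutation an InnerParallel
// reorder would apply, moving reduction dims outermost and parallel dims
// innermost. Relative order inside each group is preserved (assumption).
std::vector<int64_t> innerParallelPermutation(const std::vector<bool> &isReductionDim) {
  std::vector<int64_t> reduction, parallel;
  for (int64_t d = 0; d < static_cast<int64_t>(isReductionDim.size()); ++d)
    (isReductionDim[d] ? reduction : parallel).push_back(d);
  // Reduction dims first (outermost), then parallel dims (innermost).
  reduction.insert(reduction.end(), parallel.begin(), parallel.end());
  return reduction;
}

// Hypothetical helper: models unrolling a dim-0 <add> multi_reduction of a
// rows x cols matrix that has been flattened row-major (like the shape_cast to
// vector<32xf32>). Row r is the contiguous slice [r*cols, (r+1)*cols), which is
// the index range each vector.shuffle in the example selects.
std::vector<float> unrollDim0AddReduction(const std::vector<float> &flat,
                                          const std::vector<float> &acc,
                                          int64_t rows, int64_t cols) {
  assert(static_cast<int64_t>(flat.size()) == rows * cols);
  assert(static_cast<int64_t>(acc.size()) == cols);
  std::vector<float> result = acc;       // running value starts at %acc
  for (int64_t r = 0; r < rows; ++r)     // one unrolled step per row
    for (int64_t c = 0; c < cols; ++c)   // element-wise arith.addf
      result[c] += flat[r * cols + c];
  return result;
}
```

For the 4x8 example, `unrollDim0AddReduction(flat, acc, 4, 8)` reads the same index ranges as the shuffles ([0..7], [8..15], and so on). The helper has to be passed `rows` and `cols` explicitly: once the value is a flat 32-element vector, that grouping is no longer visible in the type, which is the motivation the comment gives for unrolling before linearization.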

