[Mlir-commits] [mlir] [MLIR][XeGPU] Add 2D `vector.multi_reduction` optimization (PR #171154)

Jianhui Li llvmlistbot at llvm.org
Mon Dec 8 23:15:02 PST 2025


================
@@ -416,12 +416,131 @@ class VectorExtractOpPattern final
   }
 };
 
+class MultiRed2dOp : public OpConversionPattern<vector::MultiDimReductionOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(vector::MultiDimReductionOp reductionOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (reductionOp.getReductionDims().size() != 2)
+      return rewriter.notifyMatchFailure(reductionOp,
+                                         "Expected 2D multi reduction");
+
+    auto layout = xegpu::getDistributeLayoutAttr(reductionOp.getResult());
+
+    auto dims = llvm::to_vector(reductionOp.getReductionDims());
+    auto [intraLaneDim, crossLaneDim] = getReductionDimOrder(dims, layout);
+    // Order does not matter
+    if (intraLaneDim == -1 || crossLaneDim == -1) {
+      intraLaneDim = dims[0];
+      crossLaneDim = dims[1];
+    }
+    auto loc = reductionOp.getLoc();
+    // XeGPU transforms expect vector types
+    auto sourceVecType = reductionOp.getSourceVectorType();
+    auto acc = reductionOp.getAcc();
+    bool scalarAcc = !isa<VectorType>(acc.getType());
+    if (scalarAcc)
+      acc = vector::FromElementsOp::create(
+          rewriter, loc, VectorType::get({1}, sourceVecType.getElementType()),
+          acc);
+
+    // Preserve layout in the intermediate reduction (apart from the reduced
+    // dim)
+    auto sourceSliceLayoutAttr = cast<xegpu::SliceAttr>(layout);
+    SmallVector<int64_t> sliceDims{
+        sourceSliceLayoutAttr.getDims().asArrayRef()};
+    auto foundIt = std::find(sliceDims.begin(), sliceDims.end(), crossLaneDim);
+    assert(foundIt != sliceDims.end() &&
+           "Expected to find reduction dim in slice dims");
+    sliceDims.erase(foundIt);
+    auto intraLaneLayout = xegpu::SliceAttr::get(
+        reductionOp.getContext(), sourceSliceLayoutAttr.getParent(),
+        DenseI64ArrayAttr::get(getContext(), sliceDims));
+
+    // First we do intra-lane reduction
----------------
Jianhui-Li wrote:

I think this PR should also handle the wg level, so that we functionally support the multi-dim reduction case first.
The order is: intra-lane, inter-lane, inter-sg. So when there is an sg_layout, we should first reduce the dims with sg_layout[x]==1 (regardless of lane layouts). On Xe2/Xe3, intra-lane is the -2 dim and inter-lane is the -1 dim; inter-sg depends on the sg layout but should be dims to the left of -2, like (-2, -3, ....)
For Xe4, intra-lane is -1, inter-lane is -2, and inter-sg is -2, -3, ...
If the reduction crosses multiple sg layout dims, we may pick the order so that the rightmost dims are reduced first (smaller strides).
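
A minimal sketch of the sg-level part of this ordering, written as standalone C++ rather than MLIR. `orderReductionDims` is a hypothetical helper, not something in the PR or in XeGPU; it assumes non-negative dim indices and an `sgLayout` vector indexed by dim. The lane-level order (intra-lane vs. inter-lane) is architecture-dependent and is deliberately left out: dims that stay inside one subgroup simply keep their input order here.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <functional>
#include <vector>

// Hypothetical helper (illustration only, not part of the PR).
// Returns the reduction dims in the order they would be reduced:
//   1. dims with sgLayout[d] == 1 (no cross-subgroup traffic), input order kept;
//   2. dims with sgLayout[d] > 1, rightmost first (smaller strides).
std::vector<int64_t> orderReductionDims(const std::vector<int64_t> &reductionDims,
                                        const std::vector<int64_t> &sgLayout) {
  std::vector<int64_t> local;   // reduced within a single subgroup
  std::vector<int64_t> crossSg; // reduced across subgroups
  for (int64_t d : reductionDims) {
    assert(d >= 0 && d < static_cast<int64_t>(sgLayout.size()) &&
           "reduction dim out of range");
    if (sgLayout[d] == 1)
      local.push_back(d);
    else
      crossSg.push_back(d);
  }
  // Rightmost (largest index) cross-subgroup dim goes first.
  std::sort(crossSg.begin(), crossSg.end(), std::greater<int64_t>());
  local.insert(local.end(), crossSg.begin(), crossSg.end());
  return local;
}
```

For example, with `sgLayout = {2, 4, 1}` and reduction dims `{0, 1, 2}`, dim 2 is reduced first because it is local to a subgroup, followed by dim 1 and then dim 0.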

https://github.com/llvm/llvm-project/pull/171154

