[Mlir-commits] [mlir] [MLIR][XeGPU] Add 2D `vector.multi_reduction` optimization (PR #171154)

Jianhui Li llvmlistbot at llvm.org
Wed Dec 17 11:59:18 PST 2025


================
@@ -416,12 +416,102 @@ class VectorExtractOpPattern final
   }
 };
 
+/// Decomposes a `vector.multi_reduction` that reduces both dimensions of a
+/// 2-D source into two steps:
+///   1. a `vector.multi_reduction` over a single dim (preferably one whose
+///      lane layout is 1, i.e. no cross-lane communication is needed), and
+///   2. a `vector.reduction` of the remaining 1-D vector.
+/// The new ops are annotated with distribute layouts derived from the layout
+/// attached to the original result, which must be an `xegpu::SliceAttr`.
+class MultiRed2dOp : public OpConversionPattern<vector::MultiDimReductionOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(vector::MultiDimReductionOp reductionOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto sourceVecType = reductionOp.getSourceVectorType();
+    if (reductionOp.getReductionDims().size() != 2 ||
+        sourceVecType.getRank() != 2)
+      return rewriter.notifyMatchFailure(
+          reductionOp, "Expected 2D multi reduction of a 2D source");
+
+    // The result layout drives both the dim ordering and the layouts attached
+    // to the new ops. Bail out (rather than crash) when it is absent or is not
+    // a slice layout.
+    auto resLayout = xegpu::getDistributeLayoutAttr(reductionOp.getResult());
+    if (!resLayout)
+      return rewriter.notifyMatchFailure(
+          reductionOp, "Expected a distribute layout on the result");
+    auto resSliceLayoutAttr = dyn_cast<xegpu::SliceAttr>(resLayout);
+    if (!resSliceLayoutAttr)
+      return rewriter.notifyMatchFailure(
+          reductionOp, "Expected the result layout to be a slice layout");
+
+    // Retrieve and order dims for 1D decomposition (prefer intra-lane first).
+    auto dims = llvm::to_vector(reductionOp.getReductionDims());
+    auto [intraLaneDim, crossLaneDim] = getReductionDimOrder(dims, resLayout);
+    // The lane layout does not single out an intra-lane dim: any order works.
+    if (intraLaneDim == -1 || crossLaneDim == -1) {
+      intraLaneDim = dims[0];
+      crossLaneDim = dims[1];
+    }
+    Location loc = reductionOp.getLoc();
+    MLIRContext *ctx = reductionOp.getContext();
+    Value acc = reductionOp.getAcc();
+
+    // The intra-lane (first) reduction still carries `crossLaneDim`, so its
+    // layout is the result slice layout with `crossLaneDim` removed from the
+    // sliced dims. Checked before any IR is created so failing is safe.
+    SmallVector<int64_t> sliceDims{resSliceLayoutAttr.getDims().asArrayRef()};
+    auto foundIt = std::find(sliceDims.begin(), sliceDims.end(), crossLaneDim);
+    if (foundIt == sliceDims.end())
+      return rewriter.notifyMatchFailure(
+          reductionOp, "Expected to find reduction dim in slice dims");
+    sliceDims.erase(foundIt);
+    auto intraLaneRedResLayout =
+        xegpu::SliceAttr::get(ctx, resSliceLayoutAttr.getParent(),
+                              DenseI64ArrayAttr::get(ctx, sliceDims));
+
+    // Shape of the intra-lane reduction result: source shape minus
+    // `intraLaneDim`.
+    SmallVector<int64_t> accShape(sourceVecType.getShape());
+    accShape.erase(accShape.begin() + intraLaneDim);
+    // NOTE(review): `acc` is broadcast to every element of the intermediate
+    // vector and the final `vector.reduction` below gets no accumulator, so
+    // `acc` is combined accShape[0] times instead of once. That matches the
+    // original op only for idempotent kinds (min/max/and/or) or when `acc` is
+    // the neutral element of the kind (e.g. add with -- presumably -- zero).
+    // Confirm that invariant, or use a neutral accumulator for step 1 and pass
+    // `acc` to the `vector.reduction` instead.
+    if (acc) {
+      acc = vector::BroadcastOp::create(
+          rewriter, loc,
+          VectorType::get(accShape, sourceVecType.getElementType()), acc);
+      xegpu::setDistributeLayoutAttr(
+          cast<OpResult>(acc),
+          cast<xegpu::DistributeLayoutAttr>(intraLaneRedResLayout));
+    }
+    // Step 1: reduce along `intraLaneDim` only.
+    Value intraLaneReduced = vector::MultiDimReductionOp::create(
+        rewriter, loc, reductionOp.getKind(), reductionOp.getSource(), acc,
+        ArrayRef<int64_t>(intraLaneDim));
+    xegpu::setDistributeLayoutAttr(
+        cast<OpResult>(intraLaneReduced),
+        cast<xegpu::DistributeLayoutAttr>(intraLaneRedResLayout));
+
+    // Step 2: reduce the remaining 1-D vector; this result takes over the
+    // original result's layout.
+    Value crossLaneReduced = vector::ReductionOp::create(
+        rewriter, loc, reductionOp.getKind(), intraLaneReduced, nullptr);
+    xegpu::setDistributeLayoutAttr(
+        cast<OpResult>(crossLaneReduced),
+        cast<xegpu::DistributeLayoutAttr>(resLayout));
+    assert(crossLaneReduced.getType() == reductionOp.getResult().getType() &&
+           "Type mismatch");
+    rewriter.replaceOp(reductionOp, crossLaneReduced);
+    return success();
+  }
+
+private:
+  /// Classifies the two reduction dims using the lane layout of `layout`
+  /// (following slice parents up to the underlying LayoutAttr). Returns
+  /// {intraLaneDim, crossLaneDim}: a dim whose lane layout is 1 is intra-lane,
+  /// any other dim is cross-lane. At least one entry is -1 when the lane
+  /// layout does not distinguish the dims (both 1, or both != 1).
+  std::pair<int64_t, int64_t>
+  getReductionDimOrder(ArrayRef<int64_t> reductionDims,
+                       xegpu::DistributeLayoutAttr layout) const {
+    assert(layout.isForSubgroup() && "Must know the lane layout");
+    assert(reductionDims.size() == 2 && "Expected 2D reduction");
+    // Walk slice parents up to the underlying (non-slice) layout.
+    Attribute rootLayout = layout;
+    while (auto sliceAttr = dyn_cast<xegpu::SliceAttr>(rootLayout))
+      rootLayout = sliceAttr.getParent();
+    auto layoutAttr = dyn_cast<xegpu::LayoutAttr>(rootLayout);
+    assert(layoutAttr && "Expected the slice chain to end in a LayoutAttr");
+    SmallVector<int64_t> laneLayout = layoutAttr.getEffectiveLaneLayoutAsInt();
+    assert(!laneLayout.empty() && "Expected a non-empty layout");
+
+    // Both start as -1 ("undecided") so the caller can detect the fallback
+    // case without reading an uninitialized value.
+    int64_t intra = -1;
+    int64_t cross = -1;
+    // Try to pick a dim that does not need cross-lane communication.
+    for (int64_t dim : reductionDims) {
+      assert(dim >= 0 && dim < static_cast<int64_t>(laneLayout.size()) &&
+             "Reduction dim out of range of the lane layout");
+      if (laneLayout[dim] == 1)
+        intra = dim;
+      else
+        cross = dim;
+    }
+    return {intra, cross};
+  }
+};
+
 } // namespace
 
 void xegpu::populateXeGPUOptimizeBlockLoadsPatterns(
     RewritePatternSet &patterns) {
   patterns.add<XeGPUCreateNdDescOpPattern, XeGPULoadNdDescOpPattern,
-               VectorExtractOpPattern>(patterns.getContext());
+               VectorExtractOpPattern, MultiRed2dOp>(patterns.getContext());
----------------
Jianhui-Li wrote:

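nit: the class name `MultiRed2dOp` doesn't follow the `*OpPattern` naming style used by the other patterns registered here (e.g. `VectorExtractOpPattern`).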

https://github.com/llvm/llvm-project/pull/171154


More information about the Mlir-commits mailing list