[Mlir-commits] [mlir] [MLIR][XeGPU] Add 2D `vector.multi_reduction` optimization (PR #171154)
Artem Kroviakov
llvmlistbot at llvm.org
Thu Dec 18 02:53:23 PST 2025
================
@@ -416,12 +416,102 @@ class VectorExtractOpPattern final
}
};
+/// Rewrites a 2D `vector.multi_reduction` over a 2D source as two chained 1D
+/// reductions: a `vector.multi_reduction` over one dimension followed by a
+/// `vector.reduction` over the remaining one. When the result layout allows
+/// it, the dimension whose lane-layout entry is 1 is reduced first.
+class MultiRed2dOp : public OpConversionPattern<vector::MultiDimReductionOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  /// Matches reductions with exactly two reduction dims on a rank-2 source
+  /// whose result carries an `xegpu::SliceAttr` layout. Every precondition is
+  /// checked before any IR is created, so a failed match leaves the IR
+  /// untouched.
+  LogicalResult
+  matchAndRewrite(vector::MultiDimReductionOp reductionOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto sourceVecType = reductionOp.getSourceVectorType();
+    if (reductionOp.getReductionDims().size() != 2 ||
+        sourceVecType.getRank() != 2)
+      return rewriter.notifyMatchFailure(
+          reductionOp, "Expected 2D multi reduction of a 2D source");
+
+    // The rewrite is driven by the result layout; report a match failure
+    // rather than dereferencing a missing or unexpected attribute.
+    auto resLayout = xegpu::getDistributeLayoutAttr(reductionOp.getResult());
+    if (!resLayout)
+      return rewriter.notifyMatchFailure(
+          reductionOp, "Expected a layout on the reduction result");
+    auto resSliceLayoutAttr = dyn_cast<xegpu::SliceAttr>(resLayout);
+    if (!resSliceLayoutAttr)
+      return rewriter.notifyMatchFailure(
+          reductionOp, "Expected a slice layout on the reduction result");
+
+    // Retrieve and order dims for 1D decomposition (prefer intra-lane first).
+    auto dims = llvm::to_vector(reductionOp.getReductionDims());
+    auto [intraLaneDim, crossLaneDim] = getReductionDimOrder(dims, resLayout);
+    // If the helper could not classify both dims the order does not matter;
+    // keep the order in which the op lists them.
+    if (intraLaneDim == -1 || crossLaneDim == -1) {
+      intraLaneDim = dims[0];
+      crossLaneDim = dims[1];
+    }
+
+    // The first reduction's dist attribute does not have the cross lane dim:
+    // it is the result slice layout with `crossLaneDim` removed from the
+    // sliced dims.
+    SmallVector<int64_t> sliceDims{resSliceLayoutAttr.getDims().asArrayRef()};
+    auto foundIt = std::find(sliceDims.begin(), sliceDims.end(), crossLaneDim);
+    if (foundIt == sliceDims.end())
+      return rewriter.notifyMatchFailure(
+          reductionOp, "Expected to find reduction dim in slice dims");
+    sliceDims.erase(foundIt);
+    auto *ctx = reductionOp.getContext();
+    auto intraLaneRedResLayout =
+        xegpu::SliceAttr::get(ctx, resSliceLayoutAttr.getParent(),
+                              DenseI64ArrayAttr::get(ctx, sliceDims));
+
+    auto loc = reductionOp.getLoc();
+    auto acc = reductionOp.getAcc();
+
+    // Shape of the intermediate value: the source shape without the dim that
+    // the first reduction removes.
+    SmallVector<int64_t> accShape(sourceVecType.getShape());
+    accShape.erase(accShape.begin() + intraLaneDim);
+    if (acc) {
+      // Broadcast the original accumulator to the intermediate shape so it
+      // can serve as the accumulator of the first (1D) reduction.
+      acc = vector::BroadcastOp::create(
+          rewriter, loc,
+          VectorType::get(accShape, sourceVecType.getElementType()), acc);
+      xegpu::setDistributeLayoutAttr(
+          llvm::dyn_cast<OpResult>(acc),
+          cast<xegpu::DistributeLayoutAttr>(intraLaneRedResLayout));
+    }
+
+    // Step 1: reduce along `intraLaneDim` only.
+    Value intraLaneReduced = vector::MultiDimReductionOp::create(
+        rewriter, loc, reductionOp.getKind(), reductionOp.getSource(), acc,
+        ArrayRef<int64_t>(intraLaneDim));
+    xegpu::setDistributeLayoutAttr(
+        llvm::dyn_cast<OpResult>(intraLaneReduced),
+        cast<xegpu::DistributeLayoutAttr>(intraLaneRedResLayout));
+
+    // Step 2: reduce the remaining 1D vector. No accumulator is passed here
+    // because the original one was already consumed by step 1.
+    Value crossLaneReduced = vector::ReductionOp::create(
+        rewriter, loc, reductionOp.getKind(), intraLaneReduced, nullptr);
+    xegpu::setDistributeLayoutAttr(
+        llvm::dyn_cast<OpResult>(crossLaneReduced),
+        cast<xegpu::DistributeLayoutAttr>(resLayout));
+    assert(crossLaneReduced.getType() == reductionOp.getResult().getType() &&
+           "Type mismatch");
+    rewriter.replaceOp(reductionOp, crossLaneReduced);
+    return success();
+  }
+
+private:
+  /// Splits the two reduction dims into {intra-lane dim, cross-lane dim}
+  /// using the effective lane layout of the root `xegpu::LayoutAttr` reached
+  /// by walking through any enclosing `xegpu::SliceAttr` parents.
+  ///
+  /// A dim whose lane-layout entry equals 1 is reported as the intra-lane
+  /// dim, any other dim as the cross-lane dim. A slot that no dim fills is
+  /// returned as -1 (e.g. when both entries are 1 or both differ from 1), so
+  /// the caller can detect that the order is irrelevant.
+  std::pair<int64_t, int64_t>
+  getReductionDimOrder(ArrayRef<int64_t> reductionDims,
+                       xegpu::DistributeLayoutAttr layout) const {
+    assert(layout.isForSubgroup() && "Must know the lane layout");
+    assert(reductionDims.size() == 2 && "Expected 2D reduction");
+    // Both slots start at the -1 sentinel. Leaving `intra` uninitialized
+    // would make the caller's `== -1` check read an indeterminate value.
+    int64_t intra = -1;
+    int64_t cross = -1;
+    xegpu::LayoutAttr layoutAttr = dyn_cast<xegpu::LayoutAttr>(layout);
+    if (auto layoutSliceAttr = dyn_cast<xegpu::SliceAttr>(layout)) {
+      // Peel nested slice attributes until the parent is no longer a slice.
+      while (dyn_cast<xegpu::SliceAttr>(layoutSliceAttr.getParent()))
+        layoutSliceAttr =
+            dyn_cast<xegpu::SliceAttr>(layoutSliceAttr.getParent());
+      layoutAttr = dyn_cast<xegpu::LayoutAttr>(layoutSliceAttr.getParent());
+    }
+    assert(layoutAttr);
+    SmallVector<int64_t> laneLayout = layoutAttr.getEffectiveLaneLayoutAsInt();
+
+    assert(laneLayout.size() && "Expected a non-empty layout");
+    // Try to pick a dim that does not communicate across lanes.
+    for (auto dim : reductionDims) {
+      if (laneLayout[dim] == 1)
+        intra = dim;
+      else
+        cross = dim;
+    }
+    return {intra, cross};
+  }
+};
+
} // namespace
void xegpu::populateXeGPUOptimizeBlockLoadsPatterns(
RewritePatternSet &patterns) {
patterns.add<XeGPUCreateNdDescOpPattern, XeGPULoadNdDescOpPattern,
- VectorExtractOpPattern>(patterns.getContext());
+ VectorExtractOpPattern, MultiRed2dOp>(patterns.getContext());
----------------
akroviakov wrote:
Renamed
https://github.com/llvm/llvm-project/pull/171154
More information about the Mlir-commits
mailing list