[Mlir-commits] [mlir] [mlir][vector] Add support for `vector.multi_reduction` and `vector.shape_cast` distribution. (PR #154438)
Charitha Saumya
llvmlistbot at llvm.org
Thu Aug 28 15:21:21 PDT 2025
================
@@ -1996,6 +2027,134 @@ struct WarpOpReduction : public WarpDistributionPattern {
DistributedReductionFn distributedReductionFn;
};
+// This pattern distributes the `vector.multi_reduction` operation across
+// lanes in a warp. Currently only 2D-to-1D reductions are supported, and the
+// source vector is assumed to be distributed along the column dimension
+// (i.e. each lane owns complete column(s) of the source vector).
+// TODO: Add support for the case where source rows are distributed across
+// lanes. Requires `DistributionMapFn` to express the data distribution.
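+// As an illustrative sketch (hypothetical shapes, assuming a warp size of 4),
+// the supported column-reduction case looks like:
+//   %0 = gpu.warp_execute_on_lane_0(%laneid)[4] -> (vector<2xf32>) {
+//     %src = ... : vector<16x8xf32>
+//     %acc = ... : vector<8xf32>
+//     %r = vector.multi_reduction <add>, %src, %acc [0]
+//            : vector<16x8xf32> to vector<8xf32>
+//     gpu.yield %r : vector<8xf32>
+//   }
+// Each lane owns two complete columns (a vector<16x2xf32> slice), so the
+// column reduction is lane-local and the result distributes to vector<2xf32>.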
+struct WarpOpMultiReduction : public WarpDistributionPattern {
+ using Base::Base;
+ LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
+ PatternRewriter &rewriter) const override {
+ OpOperand *yieldOperand =
+ getWarpResult(warpOp, llvm::IsaPred<vector::MultiDimReductionOp>);
+ if (!yieldOperand)
+ return failure();
+ auto reductionOp =
+ cast<vector::MultiDimReductionOp>(yieldOperand->get().getDefiningOp());
+ unsigned operandNumber = yieldOperand->getOperandNumber();
+ VectorType sourceType = reductionOp.getSourceVectorType();
+
+ // Only 2D vectors are supported.
+ if (sourceType.getRank() != 2)
+ return rewriter.notifyMatchFailure(warpOp,
+ "Only 2D reductions are supported.");
+ ArrayRef<int64_t> reductionDims = reductionOp.getReductionDims();
+ // Only 1 reduction dimension is supported. This also ensures that the
+ // result is a vector type.
+ if (reductionDims.size() != 1)
+ return rewriter.notifyMatchFailure(
+ warpOp, "Only 1 reduction dimension is supported.");
+ int64_t reductionDim = reductionDims[0];
+ auto resultType = cast<VectorType>(reductionOp.getType());
+ auto distributedResultType =
+ cast<VectorType>(warpOp.getResult(operandNumber).getType());
+ Type elementType = distributedResultType.getElementType();
+
+ // Currently we make the following assumptions.
+ // 1. The source vector is distributed along the column dimension; each
+ // lane owns complete column(s) of the source vector.
+ // 2. If the reduction dim is 0, it is a lane-local column reduction. In
+ // this case each lane owns its portion of the result (i.e. the result is
+ // also distributed).
+ // 3. If the reduction dim is 1, it is a row reduction that requires
+ // cross-lane shuffles. In this case the result is not distributed but
+ // broadcast instead.
+ // TODO: These assumptions are fairly restrictive. For example, the source
+ // vector can have a row-distributed layout. Improve support for such cases.
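+ // For example (hypothetical shapes, warp size 4): reducing dim 0 of a
+ // vector<16x8xf32> yields vector<8xf32>, distributed as vector<2xf32> per
+ // lane; reducing dim 1 yields vector<16xf32>, which every lane holds in
+ // full.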
+ if (sourceType.getShape()[1] % warpOp.getWarpSize() != 0)
+ return rewriter.notifyMatchFailure(
+ warpOp, "Source vector dimension must be divisible by warp size.");
+ bool isResultDistributed =
+ distributedResultType.getNumElements() < resultType.getNumElements();
+ if (reductionDim == 0 && !isResultDistributed)
+ return rewriter.notifyMatchFailure(
+ warpOp,
+ "Expecting result vector to be distributed in a col reduction.");
+ if (reductionDim == 1 && isResultDistributed)
+ return rewriter.notifyMatchFailure(
+ warpOp,
+ "Expecting result vector to be broadcasted in a row reduction.");
+
+ // Create a constant vector to store the result of the reduction per lane.
+ TypedAttr zeroAttr =
+ rewriter.getZeroAttr(distributedResultType.getElementType());
+ Value result = arith::ConstantOp::create(
+ rewriter, reductionOp->getLoc(), distributedResultType,
+ DenseElementsAttr::get(distributedResultType, zeroAttr));
+ // Col reduction.
+ if (reductionDim == 0) {
+ // Compute the distributed source type, assuming each lane owns columns.
+ SmallVector<int64_t> shape(sourceType.getShape());
+ shape[1] = shape[1] / warpOp.getWarpSize();
+ auto sourceDistributedType = VectorType::get(shape, elementType);
+
+ // Yield the source and acc vectors from the WarpOp.
+ SmallVector<size_t> newRetIndices;
+ auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+ rewriter, warpOp, {reductionOp.getSource(), reductionOp.getAcc()},
+ {sourceDistributedType, distributedResultType}, newRetIndices);
+ rewriter.setInsertionPointAfter(newWarpOp);
+
+ int nCols = sourceDistributedType.getShape()[1];
+ Value source = newWarpOp.getResult(newRetIndices[0]);
+ Value acc = newWarpOp.getResult(newRetIndices[1]);
+ // For each column owned by a lane, extract the column (of size nRows x
+ // 1), shape cast it to 1D (nRows), do a vector.reduction, and insert the
+ // result back into the result vector.
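+ // As a sketch (shapes assume a vector<16x2xf32> per-lane slab), iteration
+ // i emits:
+ //   %col  = vector.extract_strided_slice %source
+ //     {offsets = [0, i], sizes = [16, 1], strides = [1, 1]}
+ //     : vector<16x2xf32> to vector<16x1xf32>
+ //   %col1 = vector.shape_cast %col : vector<16x1xf32> to vector<16xf32>
+ //   %acc1 = vector.extract %acc[i] : f32 from vector<2xf32>
+ //   %red  = vector.reduction <add>, %col1, %acc1 : vector<16xf32> into f32
+ //   %res  = vector.insert %red, %result[i] : f32 into vector<2xf32>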
+ for (int i = 0; i < nCols; ++i) {
+ Value col = vector::ExtractStridedSliceOp::create(
+ rewriter, reductionOp.getLoc(), source, {0, i},
+ {sourceDistributedType.getShape()[0], 1}, {1, 1});
+ col = vector::ShapeCastOp::create(
+ rewriter, reductionOp.getLoc(),
+ VectorType::get({sourceDistributedType.getShape()[0]}, elementType),
+ col);
+ Value accCol =
+ vector::ExtractOp::create(rewriter, reductionOp.getLoc(), acc, i);
+ Value colReduce = vector::ReductionOp::create(
+ rewriter, reductionOp.getLoc(), reductionOp.getKind(), col, accCol);
+ result = vector::InsertOp::create(rewriter, reductionOp.getLoc(),
+ colReduce, result, i);
+ }
+ // Replace the warp op result with the new reduction op.
+ rewriter.replaceAllUsesWith(newWarpOp.getResult(operandNumber), result);
+ return success();
+ }
+ // For row reductions, we simply rewrite the MultiDimReductionOp in terms
+ // of multiple ReductionOps. The actual distribution is done by the
+ // WarpOpReduction pattern.
+ rewriter.setInsertionPointAfter(reductionOp);
+ int nRows = sourceType.getShape()[0];
+ // For each row of the source, extract the row vector, do a reduction, and
+ // insert the result back into the result vector.
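+ // As a sketch (hypothetical shapes, warp size 32), a vector<2x32xf32>
+ // source is rewritten, per row i, into:
+ //   %row  = vector.extract %src[i] : vector<32xf32> from vector<2x32xf32>
+ //   %acci = vector.extract %acc[i] : f32 from vector<2xf32>
+ //   %red  = vector.reduction <add>, %row, %acci : vector<32xf32> into f32
+ //   %res  = vector.insert %red, %result[i] : f32 into vector<2xf32>
+ // which the WarpOpReduction pattern subsequently distributes across lanes.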
+ for (int i = 0; i < nRows; ++i) {
+ Value source = vector::ExtractOp::create(rewriter, reductionOp.getLoc(),
+ reductionOp.getSource(), i);
+ Value acc = vector::ExtractOp::create(rewriter, reductionOp.getLoc(),
+ reductionOp.getAcc(), i);
+ Value rowReduce = vector::ReductionOp::create(
----------------
charithaintc wrote:
This is lowered progressively. Here we lower it to a bunch of `vector.reduction` ops. Then the `WarpOpReduction` pattern kicks in and does the actual distribution to shuffle ops.
`WarpOpReduction` is free to use any reduction strategy (specified by `distributedReductionFn`). Currently, by default, it uses the one defined here:
https://github.com/llvm/llvm-project/blob/main/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp#L566
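To make the progressive lowering concrete, here is a sketch with hypothetical shapes (warp size 32). A row reduction such as

```mlir
%r = vector.multi_reduction <add>, %src, %acc [1] : vector<2x32xf32> to vector<2xf32>
```

is first rewritten by this pattern into per-row reductions:

```mlir
%row0 = vector.extract %src[0] : vector<32xf32> from vector<2x32xf32>
%acc0 = vector.extract %acc[0] : f32 from vector<2xf32>
%r0 = vector.reduction <add>, %row0, %acc0 : vector<32xf32> into f32
// ... and likewise for row 1 ...
```

Each `vector.reduction` is then distributed to shuffle-based ops by `WarpOpReduction`, using whatever strategy `distributedReductionFn` supplies.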
https://github.com/llvm/llvm-project/pull/154438