[Mlir-commits] [mlir] [mlir][vector] Add support for `vector.multi_reduction` and `vector.shape_cast` distribution. (PR #154438)

Charitha Saumya llvmlistbot at llvm.org
Thu Aug 28 15:21:21 PDT 2025


================
@@ -1996,6 +2027,134 @@ struct WarpOpReduction : public WarpDistributionPattern {
   DistributedReductionFn distributedReductionFn;
 };
 
+// This pattern distributes the `vector.multi_reduction` operation across
+// lanes in a warp. Currently only 2D to 1D reductions are supported, and the
+// source vector is assumed to be distributed in the column dimension (i.e.,
+// each lane owns complete column(s) of the source vector).
+// TODO: Add support for the case where source rows are distributed across
+// lanes. Requires `DistributionMapFn` to express the data distribution.
+struct WarpOpMultiReduction : public WarpDistributionPattern {
+  using Base::Base;
+  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *yieldOperand =
+        getWarpResult(warpOp, llvm::IsaPred<vector::MultiDimReductionOp>);
+    if (!yieldOperand)
+      return failure();
+    auto reductionOp =
+        cast<vector::MultiDimReductionOp>(yieldOperand->get().getDefiningOp());
+    unsigned operandNumber = yieldOperand->getOperandNumber();
+    VectorType sourceType = reductionOp.getSourceVectorType();
+
+    // Only 2D vectors are supported.
+    if (sourceType.getRank() != 2)
+      return rewriter.notifyMatchFailure(warpOp,
+                                         "Only 2D reductions are supported.");
+    ArrayRef<int64_t> reductionDims = reductionOp.getReductionDims();
+    // Only 1 reduction dimension supported. This also ensures that result is
+    // also vector type.
+    if (reductionDims.size() != 1)
+      return rewriter.notifyMatchFailure(
+          warpOp, "Only 1 reduction dimension is supported.");
+    int64_t reductionDim = reductionDims[0];
+    auto resultType = cast<VectorType>(reductionOp.getType());
+    auto distributedResultType =
+        cast<VectorType>(warpOp.getResult(operandNumber).getType());
+    Type elementType = distributedResultType.getElementType();
+
+    // Currently we make the following assumptions.
+    // 1. The source vector is distributed in the column dimension. Each lane
+    // owns complete column(s) of the source vector.
+    // 2. If the reduction dim == 0, its a lane-local col reduction. In this
+    // case each lane owns its portion of the result (i.e. result is also
+    // distributed).
+    // 3. If reduction dim == 1, its a row reduction that require cross lanes
+    // shuffles. In this case result is not distributed and broadcasted instead.
+    // TODO: These assumptions are fairly restrictive. For example, source
+    // vector can have row distributed layout. Improve support for such cases.
+    if (sourceType.getShape()[1] % warpOp.getWarpSize() != 0)
+      return rewriter.notifyMatchFailure(
+          warpOp, "Source vector dimension must be divisible by warp size.");
+    bool isResultDistributed =
+        distributedResultType.getNumElements() < resultType.getNumElements();
+    if (reductionDim == 0 && !isResultDistributed)
+      return rewriter.notifyMatchFailure(
+          warpOp,
+          "Expecting result vector to be distributed in a col reduction.");
+    if (reductionDim == 1 && isResultDistributed)
+      return rewriter.notifyMatchFailure(
+          warpOp,
+          "Expecting result vector to be broadcasted in a row reduction.");
+
+    // Create a constant vector to store the result of the reduction per lane.
+    TypedAttr zeroAttr =
+        rewriter.getZeroAttr(distributedResultType.getElementType());
+    Value result = arith::ConstantOp::create(
+        rewriter, reductionOp->getLoc(), distributedResultType,
+        DenseElementsAttr::get(distributedResultType, zeroAttr));
+    // Col reduction.
+    if (reductionDim == 0) {
+      // Compute source distributed type assuming each lane owns cols.
+      SmallVector<int64_t> shape(sourceType.getShape());
+      shape[1] = shape[1] / warpOp.getWarpSize();
+      auto sourceDistributedType = VectorType::get(shape, elementType);
+
+      // Yield the source and acc vectors from the WarpOp.
+      SmallVector<size_t> newRetIndices;
+      auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+          rewriter, warpOp, {reductionOp.getSource(), reductionOp.getAcc()},
+          {sourceDistributedType, distributedResultType}, newRetIndices);
+      rewriter.setInsertionPointAfter(newWarpOp);
+
+      int nCols = sourceDistributedType.getShape()[1];
+      Value source = newWarpOp.getResult(newRetIndices[0]);
+      Value acc = newWarpOp.getResult(newRetIndices[1]);
+      // For each column owned by a lane, extract the column (of size nRows x
+      // 1), shape cast to 1D (nRows), do a vector.reduction and, insert the
+      // result back to the result vector.
+      for (int i = 0; i < nCols; ++i) {
+        Value col = vector::ExtractStridedSliceOp::create(
+            rewriter, reductionOp.getLoc(), source, {0, i},
+            {sourceDistributedType.getShape()[0], 1}, {1, 1});
+        col = vector::ShapeCastOp::create(
+            rewriter, reductionOp.getLoc(),
+            VectorType::get({sourceDistributedType.getShape()[0]}, elementType),
+            col);
+        Value accCol =
+            vector::ExtractOp::create(rewriter, reductionOp.getLoc(), acc, i);
+        Value colReduce = vector::ReductionOp::create(
+            rewriter, reductionOp.getLoc(), reductionOp.getKind(), col, accCol);
+        result = vector::InsertOp::create(rewriter, reductionOp.getLoc(),
+                                          colReduce, result, i);
+      }
+      // Replace the warp op result with the new reduction op.
+      rewriter.replaceAllUsesWith(newWarpOp.getResult(operandNumber), result);
+      return success();
+    }
+    // For row reductions, we simply rewrite the MultiReductionOp in terms of
+    // multiple ReductionOps. Actual distribution is done by the WarpOpReduction
+    // pattern.
+    rewriter.setInsertionPointAfter(reductionOp);
+    int nRows = sourceType.getShape()[0];
+    // For each row of the source, extract the row vector, do a reduction and,
+    // insert the result back to the result.
+    for (int i = 0; i < nRows; ++i) {
+      Value source = vector::ExtractOp::create(rewriter, reductionOp.getLoc(),
+                                               reductionOp.getSource(), i);
+      Value acc = vector::ExtractOp::create(rewriter, reductionOp.getLoc(),
+                                            reductionOp.getAcc(), i);
+      Value rowReduce = vector::ReductionOp::create(
----------------
charithaintc wrote:

This is lowered progressively. Here we lower it to a bunch of `vector.reduction` ops. Then the `WarpOpReduction` pattern kicks in and does the actual distribution to shuffle ops.

`WarpOpReduction` is free to use any reduction strategy (specified by `distributedReductionFn`). By default, it currently uses the one defined here:
https://github.com/llvm/llvm-project/blob/main/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp#L566


https://github.com/llvm/llvm-project/pull/154438


More information about the Mlir-commits mailing list