[Mlir-commits] [mlir] [mlir][vector] Add support for `vector.multi_reduction` and `vector.shape_cast` distribution. (PR #154438)

Chao Chen llvmlistbot at llvm.org
Fri Aug 29 07:00:55 PDT 2025


================
@@ -1996,6 +2027,134 @@ struct WarpOpReduction : public WarpDistributionPattern {
   DistributedReductionFn distributedReductionFn;
 };
 
+// This pattern distributes the `vector.multi_reduction` operation across
+// lanes in a warp. Currently only 2D to 1D reductions are supported, and the
+// pattern assumes that the source vector is distributed in the column dimension
+// (i.e., each lane owns complete column(s) of the source vector).
+// TODO: Add support for the case where source rows are distributed across
+// lanes. Requires `DistributionMapFn` to express the data distribution.
+struct WarpOpMultiReduction : public WarpDistributionPattern {
+  using Base::Base;
+  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *yieldOperand =
+        getWarpResult(warpOp, llvm::IsaPred<vector::MultiDimReductionOp>);
+    if (!yieldOperand)
+      return failure();
+    auto reductionOp =
+        cast<vector::MultiDimReductionOp>(yieldOperand->get().getDefiningOp());
+    unsigned operandNumber = yieldOperand->getOperandNumber();
+    VectorType sourceType = reductionOp.getSourceVectorType();
+
+    // Only 2D vectors are supported.
+    if (sourceType.getRank() != 2)
+      return rewriter.notifyMatchFailure(warpOp,
+                                         "Only 2D reductions are supported.");
+    ArrayRef<int64_t> reductionDims = reductionOp.getReductionDims();
+    // Only one reduction dimension is supported. This also ensures that the
+    // result is a vector type.
+    if (reductionDims.size() != 1)
+      return rewriter.notifyMatchFailure(
+          warpOp, "Only 1 reduction dimension is supported.");
+    int64_t reductionDim = reductionDims[0];
+    auto resultType = cast<VectorType>(reductionOp.getType());
+    auto distributedResultType =
+        cast<VectorType>(warpOp.getResult(operandNumber).getType());
+    Type elementType = distributedResultType.getElementType();
+
+    // Currently we make the following assumptions.
+    // 1. The source vector is distributed in the column dimension. Each lane
+    // owns complete column(s) of the source vector.
+    // 2. If the reduction dim == 0, it's a lane-local column reduction. In this
+    // case each lane owns its portion of the result (i.e., the result is also
+    // distributed).
+    // 3. If reduction dim == 1, it's a row reduction that requires cross-lane
+    // shuffles. In this case result is not distributed and broadcasted instead.
----------------
chencha3 wrote:

Sorry for the confusion. I meant that the comment should say "not distributed but broadcasted instead", not "not distributed and broadcasted instead".
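The following host-side C++ sketch illustrates the two cases from the code comment. It is a simulation, not MLIR code. Lanes are modeled as array indices, and lane `l` owns columns `[l * colsPerLane, (l + 1) * colsPerLane)`. Addition is the only combining kind. The names `laneSlice`, `reduceDim0Local`, and `reduceDim1Shuffle` are hypothetical. The cross-lane step uses an XOR butterfly, which is one common strategy and is assumed here. It is not necessarily what the pattern's lowering emits.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical model of a 2D source vector: [row][col].
using Matrix = std::vector<std::vector<float>>;

// Columns owned by one lane: all rows, colsPerLane columns starting at
// lane * colsPerLane (assumption 1: columns are distributed across lanes).
Matrix laneSlice(const Matrix &src, int lane, int colsPerLane) {
  Matrix slice(src.size(), std::vector<float>(colsPerLane));
  for (std::size_t r = 0; r < src.size(); ++r)
    for (int c = 0; c < colsPerLane; ++c)
      slice[r][c] = src[r][lane * colsPerLane + c];
  return slice;
}

// Reduction over dim 0 (column reduction): needs only the lane's own data.
// Each lane yields colsPerLane values, so the 1D result stays distributed.
// Assumes the slice has at least one row.
std::vector<float> reduceDim0Local(const Matrix &slice) {
  std::vector<float> out(slice[0].size(), 0.0f);
  for (const auto &row : slice)
    for (std::size_t c = 0; c < row.size(); ++c)
      out[c] += row[c];
  return out;
}

// Reduction over dim 1 (row reduction): each lane sums its own columns per
// row, then partial sums are combined across lanes with a simulated XOR
// butterfly shuffle. Afterwards every lane holds all row sums, so the result
// is not distributed but broadcast to every lane. Returns [lane][row].
std::vector<std::vector<float>> reduceDim1Shuffle(const Matrix &src,
                                                  int numLanes,
                                                  int colsPerLane) {
  // The butterfly pattern below requires a power-of-two lane count.
  assert(numLanes > 0 && (numLanes & (numLanes - 1)) == 0);
  std::size_t rows = src.size();
  std::vector<std::vector<float>> partial(numLanes,
                                          std::vector<float>(rows, 0.0f));
  // Lane-local partial sums.
  for (int lane = 0; lane < numLanes; ++lane) {
    Matrix slice = laneSlice(src, lane, colsPerLane);
    for (std::size_t r = 0; r < rows; ++r)
      for (float v : slice[r])
        partial[lane][r] += v;
  }
  // Cross-lane combination: each step adds the value held by lane ^ offset.
  for (int offset = numLanes / 2; offset > 0; offset /= 2) {
    auto next = partial;
    for (int lane = 0; lane < numLanes; ++lane)
      for (std::size_t r = 0; r < rows; ++r)
        next[lane][r] = partial[lane][r] + partial[lane ^ offset][r];
    partial = next;
  }
  return partial;
}
```

Consider a 2 x 4 source `{{1, 2, 3, 4}, {5, 6, 7, 8}}` split across 2 lanes with 2 columns each. The dim-0 path gives `{6, 8}` on lane 0 and `{10, 12}` on lane 1, which is the distributed result. The dim-1 path leaves `{10, 26}` on both lanes, which is the broadcast result.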

https://github.com/llvm/llvm-project/pull/154438

