[Mlir-commits] [mlir] [mlir][vector] Add support for `vector.multi_reduction` and `vector.shape_cast` distribution. (PR #154438)

Charitha Saumya llvmlistbot at llvm.org
Thu Aug 28 15:13:14 PDT 2025


================
@@ -977,44 +977,75 @@ struct WarpOpBroadcast : public WarpDistributionPattern {
 /// Pattern to move shape cast out of the warp op. shape cast is basically a
 /// no-op for warp distribution; we need to handle the shape though.
 struct WarpOpShapeCast : public WarpDistributionPattern {
-  using Base::Base;
+
+  WarpOpShapeCast(MLIRContext *ctx, DistributionMapFn fn, PatternBenefit b = 1)
+      : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {}
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
     OpOperand *operand =
         getWarpResult(warpOp, llvm::IsaPred<vector::ShapeCastOp>);
     if (!operand)
       return failure();
-
     auto oldCastOp = operand->get().getDefiningOp<vector::ShapeCastOp>();
 
     unsigned int operandNumber = operand->getOperandNumber();
-    auto castDistributedType =
+    VectorType sourceType = oldCastOp.getSourceVectorType();
+    VectorType distributedResultType =
         cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
-    VectorType castOriginalType = oldCastOp.getSourceVectorType();
-    VectorType castResultType = castDistributedType;
-
-    // We expect the distributed type to have a smaller rank than the original
-    // type. Prepend with size-one dimensions to make them the same.
-    unsigned castDistributedRank = castDistributedType.getRank();
-    unsigned castOriginalRank = castOriginalType.getRank();
-    if (castDistributedRank < castOriginalRank) {
-      SmallVector<int64_t> shape(castOriginalRank - castDistributedRank, 1);
-      llvm::append_range(shape, castDistributedType.getShape());
-      castDistributedType =
-          VectorType::get(shape, castDistributedType.getElementType());
+    VectorType distributedSourceType = sourceType;
+    bool isResultDistributed = distributedResultType.getNumElements() <
+                               oldCastOp.getResultVectorType().getNumElements();
+
+    // If the result is not distributed, the distributed source type is the
+    // same as the source type. If the result is distributed, we need to compute
+    // the distributed source type according to the following rules:
+    // 1. If the source type is yielded from the warp op, we can use the
+    //    matching warp result type as the distributed source type.
+    // 2. If the source type is not yielded from the warp op, we need
----------------
charithaintc wrote:

I added two examples as comments, one for each of the two cases.

For both row and column reductions, we assume that the columns of the source vector are distributed across lanes, so each lane owns whole columns (i.e., in XeGPU layouts this would be [1, 16]). The reduction logic is built on that assumption.
Given this layout:
Column reduction (reducing each column to a single value) is easy: each lane just reduces its own data.
Row reduction (reducing each row to a single value) needs to exchange data with other lanes through shuffles and combine it in a tree-like fashion (a.k.a. butterfly reduction with shuffles).
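To make the two cases concrete, here is a toy C++ model of the [1, 16]-style layout, not the MLIR patterns themselves. All names (`LaneData`, `columnReduce`, `shuffleXor`, `rowReduce`) are hypothetical, and it assumes the simplest case where lane l owns exactly column l of an R x N vector, with N a power of two. Lanes are simulated with plain arrays; the xor shuffle stands in for a hardware subgroup shuffle.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical toy model: laneData[l] is the column owned by lane l, i.e. lane l
// holds element (r, l) of an R x N source vector for every row r.
using LaneData = std::vector<std::vector<float>>;

// Column reduction: each lane sums its own column, no cross-lane traffic.
std::vector<float> columnReduce(const LaneData &laneData) {
  std::vector<float> result;
  for (const auto &column : laneData) {
    float acc = 0.0f;
    for (float v : column) acc += v;
    result.push_back(acc);  // entry l is the sum of column l (held by lane l)
  }
  return result;
}

// Simulated xor shuffle: every lane l reads the value held by lane (l ^ mask).
std::vector<float> shuffleXor(const std::vector<float> &vals, std::size_t mask) {
  std::vector<float> out(vals.size());
  for (std::size_t l = 0; l < vals.size(); ++l) out[l] = vals[l ^ mask];
  return out;
}

// Row reduction: a row is spread across all lanes, so lanes combine partial
// sums in a butterfly (mask = N/2, N/4, ..., 1). After log2(N) steps every
// lane holds the full row sum.
std::vector<float> rowReduce(const LaneData &laneData) {
  const std::size_t numLanes = laneData.size();
  assert(numLanes > 0 && (numLanes & (numLanes - 1)) == 0);  // power of two
  const std::size_t numRows = laneData[0].size();
  std::vector<float> result(numRows);
  for (std::size_t r = 0; r < numRows; ++r) {
    std::vector<float> vals(numLanes);
    for (std::size_t l = 0; l < numLanes; ++l) vals[l] = laneData[l][r];
    for (std::size_t mask = numLanes / 2; mask > 0; mask /= 2) {
      std::vector<float> partner = shuffleXor(vals, mask);
      for (std::size_t l = 0; l < numLanes; ++l) vals[l] += partner[l];
    }
    result[r] = vals[0];  // all lanes agree after the butterfly; take lane 0
  }
  return result;
}
```

For a 2 x 16 vector where lane l holds {l, 1}, `columnReduce` gives lane l the value l + 1, and `rowReduce` gives {120, 16}. The butterfly is chosen here because it leaves the result in every lane, which matches the "tree-like reduce with shuffles" described above.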

However, the source layout can also be [16, 1]. This case is not supported yet, because the vector distribution infrastructure currently cannot express such a layout (it always starts distributing the vector from the innermost dimension). I am working on a proposal to improve this.
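A toy C++ model of this limitation, under stated assumptions: `distributeInnermostFirst` is a hypothetical helper, not the actual distribution-map API, and the rule that leftover lanes spill into outer dimensions is an assumption made only for illustration. It assigns lanes starting from the innermost dimension and returns the per-lane shape.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical helper: split `laneCount` lanes over `shape`, innermost
// dimension first. Returns the per-lane shape, or std::nullopt when the lanes
// cannot be split evenly.
std::optional<std::vector<int64_t>>
distributeInnermostFirst(std::vector<int64_t> shape, int64_t laneCount) {
  assert(laneCount > 0);
  for (int i = static_cast<int>(shape.size()) - 1; i >= 0 && laneCount > 1; --i) {
    if (shape[i] % laneCount == 0) {
      // This dimension absorbs all remaining lanes.
      shape[i] /= laneCount;
      laneCount = 1;
    } else if (laneCount % shape[i] == 0) {
      // Each lane gets one element along this dimension; the rest spill outward.
      laneCount /= shape[i];
      shape[i] = 1;
    } else {
      return std::nullopt;
    }
  }
  if (laneCount != 1) return std::nullopt;
  return shape;
}
```

With a 16x16 vector and 16 lanes, this returns a per-lane shape of [16, 1]: each lane owns a column, which is the [1, 16] layout. Because the innermost dimension always absorbs the lanes first, the model never produces a per-lane shape of [1, 16] (each lane owning a row), which is what the [16, 1] layout would need.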

https://github.com/llvm/llvm-project/pull/154438

