[Mlir-commits] [mlir] [mlir][vector] Add support for `vector.multi_reduction` and `vector.shape_cast` distribution. (PR #154438)

Adam Siemieniuk llvmlistbot at llvm.org
Tue Sep 2 08:58:33 PDT 2025


================
@@ -1020,44 +1020,75 @@ struct WarpOpBroadcast : public WarpDistributionPattern {
 /// Pattern to move shape cast out of the warp op. shape cast is basically a
 /// no-op for warp distribution; we need to handle the shape though.
 struct WarpOpShapeCast : public WarpDistributionPattern {
-  using Base::Base;
+
+  WarpOpShapeCast(MLIRContext *ctx, DistributionMapFn fn, PatternBenefit b = 1)
+      : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {}
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
     OpOperand *operand =
         getWarpResult(warpOp, llvm::IsaPred<vector::ShapeCastOp>);
     if (!operand)
       return failure();
-
     auto oldCastOp = operand->get().getDefiningOp<vector::ShapeCastOp>();
 
-    unsigned int operandNumber = operand->getOperandNumber();
-    auto castDistributedType =
+    unsigned operandNumber = operand->getOperandNumber();
+    VectorType sourceType = oldCastOp.getSourceVectorType();
+    VectorType distributedResultType =
         cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
-    VectorType castOriginalType = oldCastOp.getSourceVectorType();
-    VectorType castResultType = castDistributedType;
-
-    // We expect the distributed type to have a smaller rank than the original
-    // type. Prepend with size-one dimensions to make them the same.
-    unsigned castDistributedRank = castDistributedType.getRank();
-    unsigned castOriginalRank = castOriginalType.getRank();
-    if (castDistributedRank < castOriginalRank) {
-      SmallVector<int64_t> shape(castOriginalRank - castDistributedRank, 1);
-      llvm::append_range(shape, castDistributedType.getShape());
-      castDistributedType =
-          VectorType::get(shape, castDistributedType.getElementType());
+    VectorType distributedSourceType = sourceType;
+    bool isResultDistributed = distributedResultType.getNumElements() <
+                               oldCastOp.getResultVectorType().getNumElements();
+
+    // If the result is not distributed, source distribted type is the same
----------------
adam-smnk wrote:

```suggestion
    // If the result is not distributed, source distributed type is the same
```
nit: typo

https://github.com/llvm/llvm-project/pull/154438


More information about the Mlir-commits mailing list