[Mlir-commits] [mlir] [mlir][vector] Add support for `vector.multi_reduction` and `vector.shape_cast` distribution. (PR #154438)
Adam Siemieniuk
llvmlistbot at llvm.org
Tue Sep 2 08:58:33 PDT 2025
================
@@ -1020,44 +1020,75 @@ struct WarpOpBroadcast : public WarpDistributionPattern {
/// Pattern to move shape cast out of the warp op. shape cast is basically a
/// no-op for warp distribution; we need to handle the shape though.
struct WarpOpShapeCast : public WarpDistributionPattern {
- using Base::Base;
+
+ WarpOpShapeCast(MLIRContext *ctx, DistributionMapFn fn, PatternBenefit b = 1)
+ : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {}
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
OpOperand *operand =
getWarpResult(warpOp, llvm::IsaPred<vector::ShapeCastOp>);
if (!operand)
return failure();
-
auto oldCastOp = operand->get().getDefiningOp<vector::ShapeCastOp>();
- unsigned int operandNumber = operand->getOperandNumber();
- auto castDistributedType =
+ unsigned operandNumber = operand->getOperandNumber();
+ VectorType sourceType = oldCastOp.getSourceVectorType();
+ VectorType distributedResultType =
cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
- VectorType castOriginalType = oldCastOp.getSourceVectorType();
- VectorType castResultType = castDistributedType;
-
- // We expect the distributed type to have a smaller rank than the original
- // type. Prepend with size-one dimensions to make them the same.
- unsigned castDistributedRank = castDistributedType.getRank();
- unsigned castOriginalRank = castOriginalType.getRank();
- if (castDistributedRank < castOriginalRank) {
- SmallVector<int64_t> shape(castOriginalRank - castDistributedRank, 1);
- llvm::append_range(shape, castDistributedType.getShape());
- castDistributedType =
- VectorType::get(shape, castDistributedType.getElementType());
+ VectorType distributedSourceType = sourceType;
+ bool isResultDistributed = distributedResultType.getNumElements() <
+ oldCastOp.getResultVectorType().getNumElements();
+
+ // If the result is not distributed, source distribted type is the same
----------------
adam-smnk wrote:
```suggestion
// If the result is not distributed, source distributed type is the same
```
nit: typo ("distribted" -> "distributed")
https://github.com/llvm/llvm-project/pull/154438
More information about the Mlir-commits
mailing list