[Mlir-commits] [mlir] [mlir][vector] Add support for `vector.multi_reduction` and `vector.shape_cast` distribution. (PR #154438)
Charitha Saumya
llvmlistbot at llvm.org
Thu Aug 28 15:13:14 PDT 2025
================
@@ -977,44 +977,75 @@ struct WarpOpBroadcast : public WarpDistributionPattern {
/// Pattern to move shape cast out of the warp op. shape cast is basically a
/// no-op for warp distribution; we need to handle the shape though.
struct WarpOpShapeCast : public WarpDistributionPattern {
- using Base::Base;
+
+ WarpOpShapeCast(MLIRContext *ctx, DistributionMapFn fn, PatternBenefit b = 1)
+ : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {}
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
OpOperand *operand =
getWarpResult(warpOp, llvm::IsaPred<vector::ShapeCastOp>);
if (!operand)
return failure();
-
auto oldCastOp = operand->get().getDefiningOp<vector::ShapeCastOp>();
unsigned int operandNumber = operand->getOperandNumber();
- auto castDistributedType =
+ VectorType sourceType = oldCastOp.getSourceVectorType();
+ VectorType distributedResultType =
cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
- VectorType castOriginalType = oldCastOp.getSourceVectorType();
- VectorType castResultType = castDistributedType;
-
- // We expect the distributed type to have a smaller rank than the original
- // type. Prepend with size-one dimensions to make them the same.
- unsigned castDistributedRank = castDistributedType.getRank();
- unsigned castOriginalRank = castOriginalType.getRank();
- if (castDistributedRank < castOriginalRank) {
- SmallVector<int64_t> shape(castOriginalRank - castDistributedRank, 1);
- llvm::append_range(shape, castDistributedType.getShape());
- castDistributedType =
- VectorType::get(shape, castDistributedType.getElementType());
+ VectorType distributedSourceType = sourceType;
+ bool isResultDistributed = distributedResultType.getNumElements() <
+ oldCastOp.getResultVectorType().getNumElements();
+
+ // If the result is not distributed, the distributed source type is the
+ // same as the source type. If the result is distributed, we need to
+ // compute the distributed source type according to the following rules:
+ // 1. If the source type is yielded from the warp op, we can use the
+ // matching warp result type as the distributed source type.
+ // 2. If the source type is not yielded from the warp op, we need
----------------
charithaintc wrote:
I added two examples for the two cases as comments.
For both row and column reduction we assume that the columns of the source vector are owned by the lanes (i.e. in xegpu layouts this would be [1, 16]), and the reduction logic is handled based on that layout.
Given this layout:
Column reduction is easy: each lane just reduces its own data.
Row reduction needs to shuffle data with neighboring lanes and do a tree-like reduce (aka butterfly reduction with shuffles); see the sketch below.
However, the source layout can also be [16, 1]. That case is not supported because the vector distribution infra does not currently let me express such a layout (it always starts distributing the vector from the innermost dim). I am working on a proposal to improve this.
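Here is a minimal host-side sketch of the two cases, assuming the [1, 16] layout above (each of the 16 lanes owns one column of a 16x16 source). It is only a simulation, not the actual lowering: the lane arrays and the xor indexing are stand-ins for the real warp lanes and for `gpu.shuffle` in xor mode.

```cpp
// Host-side simulation of the two reduction paths, assuming a 16x16 source
// whose columns are distributed across 16 lanes ([1, 16] layout).
#include <array>
#include <cstdio>

constexpr unsigned kWarpSize = 16;

int main() {
  // Each lane owns one column of a 16x16 matrix filled with 1s.
  std::array<std::array<int, kWarpSize>, kWarpSize> laneData{};
  for (auto &col : laneData)
    col.fill(1);

  // Column reduction: every lane reduces the column it already owns; no
  // cross-lane communication is required.
  std::array<int, kWarpSize> colSums{};
  for (unsigned lane = 0; lane < kWarpSize; ++lane)
    for (int v : laneData[lane])
      colSums[lane] += v;

  // Row reduction (shown for row 0): each lane starts with its own element of
  // the row and repeatedly exchanges partial sums with the lane whose id
  // differs in one bit -- the butterfly pattern done with xor shuffles.
  std::array<int, kWarpSize> partial{};
  for (unsigned lane = 0; lane < kWarpSize; ++lane)
    partial[lane] = laneData[lane][0];
  for (unsigned offset = kWarpSize / 2; offset > 0; offset /= 2) {
    std::array<int, kWarpSize> shuffled{};
    for (unsigned lane = 0; lane < kWarpSize; ++lane)
      shuffled[lane] = partial[lane ^ offset]; // gpu.shuffle xor analogue
    for (unsigned lane = 0; lane < kWarpSize; ++lane)
      partial[lane] += shuffled[lane];
  }

  // Every lane now holds its own column sum (16) and the same row sum (16).
  std::printf("colSum[0] = %d, rowSum (any lane) = %d\n", colSums[0],
              partial[0]);
  return 0;
}
```

After log2(16) = 4 exchange steps every lane ends up with the full row sum, which is the behavior the shuffle-based butterfly reduction is meant to achieve.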
https://github.com/llvm/llvm-project/pull/154438