[Mlir-commits] [mlir] [mlir][vector] Add support for `vector.multi_reduction` and `vector.shape_cast` distribution. (PR #154438)
Charitha Saumya
llvmlistbot at llvm.org
Thu Aug 28 14:35:15 PDT 2025
https://github.com/charithaintc updated https://github.com/llvm/llvm-project/pull/154438
>From eaaca7f54a9333b1841283b4483cb9c8f91f9f6b Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Tue, 19 Aug 2025 16:42:52 +0000
Subject: [PATCH 1/8] save
---
.../Vector/Transforms/VectorDistribute.cpp | 242 +++++++++++++++---
1 file changed, 213 insertions(+), 29 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index be0d28a91cba7..2d9fcaee37282 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -15,13 +15,19 @@
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Value.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Transforms/RegionUtils.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/raw_ostream.h"
+#include <cstddef>
#include <utility>
using namespace mlir;
@@ -939,8 +945,40 @@ struct WarpOpForwardOperand : public WarpDistributionPattern {
}
};
+static VectorType
+tryFindDistributedType(TypedValue<VectorType> source,
+ WarpExecuteOnLane0Op warpOp,
+ const DistributionMapFn &distributionMapFn) {
+ VectorType distributedType = source.getType();
+ // Check if the source is yielded from the warp op.
+ gpu::YieldOp yieldOp = cast<gpu::YieldOp>(
+ warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+ auto *it = llvm::find_if(yieldOp->getOpOperands(), [&](OpOperand &operand) {
+ return operand.get() == source;
+ });
+
+ if (it != yieldOp->getOpOperands().end()) {
+ // If the source is yielded from the warp op, we can use the matching
+ // warp result type as the distributed source type.
+ distributedType =
+ cast<VectorType>(warpOp->getResultTypes()[it->getOperandNumber()]);
+ } else {
+ // If the source is not yielded from the warp op, we need to compute
+ // the distributed source type based on the distribution map and the
+ // warp size.
+ AffineMap map = distributionMapFn(source);
+ VectorType computed =
+ getDistributedType(source.getType(), map, warpOp.getWarpSize());
+ if (!computed)
+ return source.getType();
+ distributedType = computed;
+ }
+ return distributedType;
+}
+
struct WarpOpBroadcast : public WarpDistributionPattern {
- using Base::Base;
+ WarpOpBroadcast(MLIRContext *ctx, DistributionMapFn fn, PatternBenefit b = 1)
+ : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {}
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
OpOperand *operand =
@@ -953,18 +991,23 @@ struct WarpOpBroadcast : public WarpDistributionPattern {
auto destVecType =
cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
Value broadcastSrc = broadcastOp.getSource();
- Type broadcastSrcType = broadcastSrc.getType();
+ Type srcDistributedType = broadcastSrc.getType();
+
+ if (isa<VectorType>(srcDistributedType))
+ srcDistributedType =
+ tryFindDistributedType(cast<TypedValue<VectorType>>(broadcastSrc),
+ warpOp, distributionMapFn);
// Check that the broadcast actually spans a set of values uniformly across
// all threads. In other words, check that each thread can reconstruct
// their own broadcast.
// For that we simply check that the broadcast we want to build makes sense.
- if (vector::isBroadcastableTo(broadcastSrcType, destVecType) !=
+ if (vector::isBroadcastableTo(srcDistributedType, destVecType) !=
vector::BroadcastableToResult::Success)
return failure();
SmallVector<size_t> newRetIndices;
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
- rewriter, warpOp, {broadcastSrc}, {broadcastSrcType}, newRetIndices);
+ rewriter, warpOp, {broadcastSrc}, {srcDistributedType}, newRetIndices);
rewriter.setInsertionPointAfter(newWarpOp);
Value broadcasted = vector::BroadcastOp::create(
rewriter, loc, destVecType, newWarpOp->getResult(newRetIndices[0]));
@@ -972,49 +1015,83 @@ struct WarpOpBroadcast : public WarpDistributionPattern {
broadcasted);
return success();
}
+
+private:
+ DistributionMapFn distributionMapFn;
};
/// Pattern to move shape cast out of the warp op. shape cast is basically a
/// no-op for warp distribution; we need to handle the shape though.
struct WarpOpShapeCast : public WarpDistributionPattern {
- using Base::Base;
+
+ WarpOpShapeCast(MLIRContext *ctx, DistributionMapFn fn, PatternBenefit b = 1)
+ : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {}
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
OpOperand *operand =
getWarpResult(warpOp, llvm::IsaPred<vector::ShapeCastOp>);
if (!operand)
return failure();
-
auto oldCastOp = operand->get().getDefiningOp<vector::ShapeCastOp>();
unsigned int operandNumber = operand->getOperandNumber();
- auto castDistributedType =
+ VectorType sourceType = oldCastOp.getSourceVectorType();
+ VectorType distributedResultType =
cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
- VectorType castOriginalType = oldCastOp.getSourceVectorType();
- VectorType castResultType = castDistributedType;
-
- // We expect the distributed type to have a smaller rank than the original
- // type. Prepend with size-one dimensions to make them the same.
- unsigned castDistributedRank = castDistributedType.getRank();
- unsigned castOriginalRank = castOriginalType.getRank();
- if (castDistributedRank < castOriginalRank) {
- SmallVector<int64_t> shape(castOriginalRank - castDistributedRank, 1);
- llvm::append_range(shape, castDistributedType.getShape());
- castDistributedType =
- VectorType::get(shape, castDistributedType.getElementType());
+ VectorType distributedSourceType = sourceType;
+ bool isResultDistributed = distributedResultType.getNumElements() <
+ oldCastOp.getResultVectorType().getNumElements();
+
+ // If the result is not distributed, the distributed source type is the same
+ // as the source type. If the result is distributed, we need to compute the
+ // distributed source type according to the following rules:
+ // 1. If the source type is yielded from the warp op, we can use the
+ // matching warp result type as the distributed source type.
+ // 2. If the source type is not yielded from the warp op, we need
+ // to compute the distributed source type based on the distribution map
+ // and the warp size.
+ if (isResultDistributed) {
+ // Check if the source is yielded from the warp op.
+ gpu::YieldOp yieldOp = cast<gpu::YieldOp>(
+ warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+ auto *it =
+ llvm::find_if(yieldOp->getOpOperands(), [&](OpOperand &operand) {
+ return operand.get() == oldCastOp.getSource();
+ });
+
+ if (it != yieldOp->getOpOperands().end()) {
+ // If the source is yielded from the warp op, we can use the matching
+ // warp result type as the distributed source type.
+ distributedSourceType =
+ cast<VectorType>(warpOp->getResultTypes()[it->getOperandNumber()]);
+ } else {
+ // If the source is not yielded from the warp op, we need to compute
+ // the distributed source type based on the distribution map and the
+ // warp size.
+ AffineMap map = distributionMapFn(oldCastOp.getSource());
+ distributedSourceType =
+ getDistributedType(sourceType, map, warpOp.getWarpSize());
+ if (!distributedSourceType)
+ return rewriter.notifyMatchFailure(
+ oldCastOp,
+ "cannot compute distributed source type for shape cast");
+ }
}
SmallVector<size_t> newRetIndices;
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
- rewriter, warpOp, {oldCastOp.getSource()}, {castDistributedType},
+ rewriter, warpOp, {oldCastOp.getSource()}, {distributedSourceType},
newRetIndices);
rewriter.setInsertionPointAfter(newWarpOp);
Value newCast = vector::ShapeCastOp::create(
- rewriter, oldCastOp.getLoc(), castResultType,
+ rewriter, oldCastOp.getLoc(), distributedResultType,
newWarpOp->getResult(newRetIndices[0]));
rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newCast);
return success();
}
+
+private:
+ DistributionMapFn distributionMapFn;
};
/// Sink out vector.create_mask op feeding into a warp op yield.
@@ -1996,6 +2073,114 @@ struct WarpOpReduction : public WarpDistributionPattern {
DistributedReductionFn distributedReductionFn;
};
+struct VectorMultiDimReductionDistribution : public WarpDistributionPattern {
+ VectorMultiDimReductionDistribution(MLIRContext *context,
+ PatternBenefit benefit = 1)
+ : WarpDistributionPattern(context, benefit) {}
+ LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
+ PatternRewriter &rewriter) const override {
+ OpOperand *yieldOperand =
+ getWarpResult(warpOp, llvm::IsaPred<vector::MultiDimReductionOp>);
+ if (!yieldOperand)
+ return failure();
+ auto reductionOp =
+ cast<vector::MultiDimReductionOp>(yieldOperand->get().getDefiningOp());
+ unsigned operandNumber = yieldOperand->getOperandNumber();
+ VectorType sourceType = reductionOp.getSourceVectorType();
+ VectorType distributedResultType =
+ cast<VectorType>(warpOp.getResult(operandNumber).getType());
+ Type elementType = distributedResultType.getElementType();
+ // Only 2D vectors are supported.
+ if (sourceType.getRank() != 2)
+ return rewriter.notifyMatchFailure(warpOp,
+ "Only 2D reductions are supported.");
+ ArrayRef<int64_t> reductionDims = reductionOp.getReductionDims();
+ // Only 1 reduction dimension supported.
+ if (reductionDims.size() != 1)
+ return rewriter.notifyMatchFailure(
+ warpOp, "Only 1 reduction dimension is supported.");
+
+ // Col reduction.
+ if (reductionDims[0] == 0) {
+ // Yield the source vector and the accumulator.
+ if (sourceType.getShape()[1] % warpOp.getWarpSize() != 0)
+ return rewriter.notifyMatchFailure(
+ warpOp, "Source vector dimension must be divisible by warp size.");
+ SmallVector<int64_t> shape(sourceType.getShape());
+ shape[1] = shape[1] / warpOp.getWarpSize();
+ auto sourceDistributedType = VectorType::get(shape, elementType);
+ SmallVector<size_t> newRetIndices;
+ auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+ rewriter, warpOp, {reductionOp.getSource(), reductionOp.getAcc()},
+ {sourceDistributedType, distributedResultType}, newRetIndices);
+ rewriter.setInsertionPointAfter(newWarpOp);
+ // Create new reduction op.
+ // auto newOp = vector::MultiDimReductionOp::create(
+ // rewriter, reductionOp.getLoc(), distributedResultType,
+ // reductionOp.getKind(),
+ // /** source = **/ newWarpOp.getResult(newRetIndices[0]),
+ // /** accumulator = **/ newWarpOp.getResult(newRetIndices[1]),
+ // reductionDims);
+ // Create a constant zero value for storing the reduction result.
+ // rewriter.setInsertionPointAfter(reductionOp);
+ auto zeroAttr =
+ rewriter.getZeroAttr(distributedResultType.getElementType());
+ Value result = arith::ConstantOp::create(
+ rewriter, reductionOp->getLoc(), distributedResultType,
+ DenseElementsAttr::get(distributedResultType, zeroAttr));
+ int nCols = sourceDistributedType.getShape()[1];
+ Value source = newWarpOp.getResult(newRetIndices[0]);
+ Value acc = newWarpOp.getResult(newRetIndices[1]);
+ for (int i = 0; i < nCols; ++i) {
+ Value col = vector::ExtractStridedSliceOp::create(
+ rewriter, reductionOp.getLoc(), source, {0, i},
+ {sourceDistributedType.getShape()[0], 1}, {1, 1});
+ col = vector::ShapeCastOp::create(
+ rewriter, reductionOp.getLoc(),
+ VectorType::get({sourceDistributedType.getShape()[0]}, elementType),
+ col);
+ Value accCol =
+ vector::ExtractOp::create(rewriter, reductionOp.getLoc(), acc, i);
+ Value colReduce = vector::ReductionOp::create(
+ rewriter, reductionOp.getLoc(), reductionOp.getKind(), col, accCol);
+ // Insert the reduced column into the result.
+ result = vector::InsertOp::create(rewriter, reductionOp.getLoc(),
+ colReduce, result, i);
+ }
+ // Replace the warp op result with the new reduction op.
+ rewriter.replaceAllUsesWith(newWarpOp.getResult(operandNumber), result);
+ return success();
+ }
+ // Row reduction.
+ // Create a constant zero value for storing the reduction result.
+ rewriter.setInsertionPointAfter(reductionOp);
+ auto zeroAttr =
+ rewriter.getZeroAttr(distributedResultType.getElementType());
+ Value result = arith::ConstantOp::create(
+ rewriter, reductionOp->getLoc(), distributedResultType,
+ DenseElementsAttr::get(distributedResultType, zeroAttr));
+ // Value result = arith::ConstantOp::create(
+ // rewriter, reductionOp.getLoc(),
+ // rewriter.getIntegerAttr(reductionOp.getType(), 0));
+ int nRows = sourceType.getShape()[0];
+ // For each row, do a vector reduction.
+ for (int i = 0; i < nRows; ++i) {
+ Value source = vector::ExtractOp::create(rewriter, reductionOp.getLoc(),
+ reductionOp.getSource(), i);
+ Value acc = vector::ExtractOp::create(rewriter, reductionOp.getLoc(),
+ reductionOp.getAcc(), i);
+ Value rowReduce = vector::ReductionOp::create(
+ rewriter, reductionOp.getLoc(), reductionOp.getKind(), source, acc);
+ result = vector::InsertOp::create(rewriter, reductionOp.getLoc(),
+ rowReduce, result, i);
+ }
+ // Replace the warp op result with the final result.
+ rewriter.replaceAllUsesWith(reductionOp.getResult(), result);
+
+ return success();
+ }
+};
+
} // namespace
void mlir::vector::populateWarpExecuteOnLane0OpToScfForPattern(
@@ -2016,16 +2201,15 @@ void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit,
PatternBenefit readBenefit) {
patterns.add<WarpOpTransferRead>(patterns.getContext(), readBenefit);
- patterns
- .add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
- WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand, WarpOpConstant,
- WarpOpInsertScalar, WarpOpInsert, WarpOpCreateMask,
- WarpOpExtractStridedSlice, WarpOpInsertStridedSlice>(
- patterns.getContext(), benefit);
+ patterns.add<WarpOpElementwise, WarpOpDeadResult, WarpOpExtract,
+ WarpOpForwardOperand, WarpOpConstant, WarpOpInsertScalar,
+ WarpOpInsert, WarpOpCreateMask, WarpOpExtractStridedSlice,
+ WarpOpInsertStridedSlice, VectorMultiDimReductionDistribution>(
+ patterns.getContext(), benefit);
patterns.add<WarpOpExtractScalar>(patterns.getContext(), warpShuffleFromIdxFn,
benefit);
- patterns.add<WarpOpScfForOp>(patterns.getContext(), distributionMapFn,
- benefit);
+ patterns.add<WarpOpScfForOp, WarpOpShapeCast, WarpOpBroadcast>(
+ patterns.getContext(), distributionMapFn, benefit);
}
void mlir::vector::populateDistributeReduction(
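For reference, a minimal sketch of the broadcast case the change above enables
(hypothetical IR in the style of the existing tests, not taken from the patch;
with a 32-lane warp and a distribution map that divides the trailing dimension
by the warp size, the vector<32xf32> source distributes to vector<1xf32> per
lane):

func.func @broadcast_distributed_src(%laneid: index) -> vector<2x1xf32> {
  %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2x1xf32>) {
    %src = "some_def"() : () -> (vector<32xf32>)
    %b = vector.broadcast %src : vector<32xf32> to vector<2x32xf32>
    gpu.yield %b : vector<2x32xf32>
  }
  return %r : vector<2x1xf32>
}

With tryFindDistributedType, the source is yielded from the new warp op with
its distributed type, and the broadcast to the distributed result type
vector<2x1xf32> is rebuilt after the warp op.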
>From 56c3441e9443660788e51064f8206c5e4ac9fbaf Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Tue, 19 Aug 2025 23:07:10 +0000
Subject: [PATCH 2/8] save
---
.../Vector/Transforms/VectorDistribute.cpp | 110 +++++------------
.../Vector/vector-warp-distribute.mlir | 111 ++++++++++++++++++
2 files changed, 143 insertions(+), 78 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 2d9fcaee37282..6410a895fc9ae 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -945,40 +945,8 @@ struct WarpOpForwardOperand : public WarpDistributionPattern {
}
};
-static VectorType
-tryFindDistributedType(TypedValue<VectorType> source,
- WarpExecuteOnLane0Op warpOp,
- const DistributionMapFn &distributionMapFn) {
- VectorType distributedType = source.getType();
- // Check if the source is yielded from the warp op.
- gpu::YieldOp yieldOp = cast<gpu::YieldOp>(
- warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
- auto *it = llvm::find_if(yieldOp->getOpOperands(), [&](OpOperand &operand) {
- return operand.get() == source;
- });
-
- if (it != yieldOp->getOpOperands().end()) {
- // If the source is yielded from the warp op, we can use the matching
- // warp result type as the distributed source type.
- distributedType =
- cast<VectorType>(warpOp->getResultTypes()[it->getOperandNumber()]);
- } else {
- // If the source is not yielded from the warp op, we need to compute
- // the distributed source type based on the distribution map and the
- // warp size.
- AffineMap map = distributionMapFn(source);
- VectorType computed =
- getDistributedType(source.getType(), map, warpOp.getWarpSize());
- if (!computed)
- return source.getType();
- distributedType = computed;
- }
- return distributedType;
-}
-
struct WarpOpBroadcast : public WarpDistributionPattern {
- WarpOpBroadcast(MLIRContext *ctx, DistributionMapFn fn, PatternBenefit b = 1)
- : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {}
+ using Base::Base;
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
OpOperand *operand =
@@ -991,23 +959,18 @@ struct WarpOpBroadcast : public WarpDistributionPattern {
auto destVecType =
cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
Value broadcastSrc = broadcastOp.getSource();
- Type srcDistributedType = broadcastSrc.getType();
-
- if (isa<VectorType>(srcDistributedType))
- srcDistributedType =
- tryFindDistributedType(cast<TypedValue<VectorType>>(broadcastSrc),
- warpOp, distributionMapFn);
+ Type broadcastSrcType = broadcastSrc.getType();
// Check that the broadcast actually spans a set of values uniformly across
// all threads. In other words, check that each thread can reconstruct
// their own broadcast.
// For that we simply check that the broadcast we want to build makes sense.
- if (vector::isBroadcastableTo(srcDistributedType, destVecType) !=
+ if (vector::isBroadcastableTo(broadcastSrcType, destVecType) !=
vector::BroadcastableToResult::Success)
return failure();
SmallVector<size_t> newRetIndices;
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
- rewriter, warpOp, {broadcastSrc}, {srcDistributedType}, newRetIndices);
+ rewriter, warpOp, {broadcastSrc}, {broadcastSrcType}, newRetIndices);
rewriter.setInsertionPointAfter(newWarpOp);
Value broadcasted = vector::BroadcastOp::create(
rewriter, loc, destVecType, newWarpOp->getResult(newRetIndices[0]));
@@ -1015,9 +978,6 @@ struct WarpOpBroadcast : public WarpDistributionPattern {
broadcasted);
return success();
}
-
-private:
- DistributionMapFn distributionMapFn;
};
/// Pattern to move shape cast out of the warp op. shape cast is basically a
@@ -2100,37 +2060,37 @@ struct VectorMultiDimReductionDistribution : public WarpDistributionPattern {
return rewriter.notifyMatchFailure(
warpOp, "Only 1 reduction dimension is supported.");
+ // Create a constant vector to store the result of the reduction per lane.
+ TypedAttr zeroAttr =
+ rewriter.getZeroAttr(distributedResultType.getElementType());
+ Value result = arith::ConstantOp::create(
+ rewriter, reductionOp->getLoc(), distributedResultType,
+ DenseElementsAttr::get(distributedResultType, zeroAttr));
+
// Col reduction.
if (reductionDims[0] == 0) {
- // Yield the source vector and the accumulator.
+ // Source vector must be distributable to lanes in the col dimension.
if (sourceType.getShape()[1] % warpOp.getWarpSize() != 0)
return rewriter.notifyMatchFailure(
warpOp, "Source vector dimension must be divisible by warp size.");
+ // Compute source distributed type.
SmallVector<int64_t> shape(sourceType.getShape());
shape[1] = shape[1] / warpOp.getWarpSize();
auto sourceDistributedType = VectorType::get(shape, elementType);
+
+ // Yield the source and acc vectors from the WarpOp.
SmallVector<size_t> newRetIndices;
auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
rewriter, warpOp, {reductionOp.getSource(), reductionOp.getAcc()},
{sourceDistributedType, distributedResultType}, newRetIndices);
rewriter.setInsertionPointAfter(newWarpOp);
- // Create new reduction op.
- // auto newOp = vector::MultiDimReductionOp::create(
- // rewriter, reductionOp.getLoc(), distributedResultType,
- // reductionOp.getKind(),
- // /** source = **/ newWarpOp.getResult(newRetIndices[0]),
- // /** accumulator = **/ newWarpOp.getResult(newRetIndices[1]),
- // reductionDims);
- // Create a constant zero value for storing the reduction result.
- // rewriter.setInsertionPointAfter(reductionOp);
- auto zeroAttr =
- rewriter.getZeroAttr(distributedResultType.getElementType());
- Value result = arith::ConstantOp::create(
- rewriter, reductionOp->getLoc(), distributedResultType,
- DenseElementsAttr::get(distributedResultType, zeroAttr));
+
int nCols = sourceDistributedType.getShape()[1];
Value source = newWarpOp.getResult(newRetIndices[0]);
Value acc = newWarpOp.getResult(newRetIndices[1]);
+ // For each column owned by a lane, extract the column (of size nRows x
+ // 1), shape cast it to 1D (nRows), do a vector.reduction, and insert the
+ // result back into the result vector.
for (int i = 0; i < nCols; ++i) {
Value col = vector::ExtractStridedSliceOp::create(
rewriter, reductionOp.getLoc(), source, {0, i},
@@ -2143,7 +2103,6 @@ struct VectorMultiDimReductionDistribution : public WarpDistributionPattern {
vector::ExtractOp::create(rewriter, reductionOp.getLoc(), acc, i);
Value colReduce = vector::ReductionOp::create(
rewriter, reductionOp.getLoc(), reductionOp.getKind(), col, accCol);
- // Insert the reduced column into the result.
result = vector::InsertOp::create(rewriter, reductionOp.getLoc(),
colReduce, result, i);
}
@@ -2151,19 +2110,13 @@ struct VectorMultiDimReductionDistribution : public WarpDistributionPattern {
rewriter.replaceAllUsesWith(newWarpOp.getResult(operandNumber), result);
return success();
}
- // Row reduction.
- // Create a constant zero value for storing the reduction result.
+ // For row reductions, we simply rewrite the MultiReductionOp in terms of
+ // multiple ReductionOps. Actual distribution is done by the WarpOpReduction
+ // pattern.
rewriter.setInsertionPointAfter(reductionOp);
- auto zeroAttr =
- rewriter.getZeroAttr(distributedResultType.getElementType());
- Value result = arith::ConstantOp::create(
- rewriter, reductionOp->getLoc(), distributedResultType,
- DenseElementsAttr::get(distributedResultType, zeroAttr));
- // Value result = arith::ConstantOp::create(
- // rewriter, reductionOp.getLoc(),
- // rewriter.getIntegerAttr(reductionOp.getType(), 0));
int nRows = sourceType.getShape()[0];
- // For each row, do a vector reduction.
+ // For each row of the source, extract the row vector, do a reduction, and
+ // insert the result back into the result vector.
for (int i = 0; i < nRows; ++i) {
Value source = vector::ExtractOp::create(rewriter, reductionOp.getLoc(),
reductionOp.getSource(), i);
@@ -2201,15 +2154,16 @@ void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit,
PatternBenefit readBenefit) {
patterns.add<WarpOpTransferRead>(patterns.getContext(), readBenefit);
- patterns.add<WarpOpElementwise, WarpOpDeadResult, WarpOpExtract,
- WarpOpForwardOperand, WarpOpConstant, WarpOpInsertScalar,
- WarpOpInsert, WarpOpCreateMask, WarpOpExtractStridedSlice,
- WarpOpInsertStridedSlice, VectorMultiDimReductionDistribution>(
- patterns.getContext(), benefit);
+ patterns
+ .add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast, WarpOpExtract,
+ WarpOpForwardOperand, WarpOpConstant, WarpOpInsertScalar,
+ WarpOpInsert, WarpOpCreateMask, WarpOpExtractStridedSlice,
+ WarpOpInsertStridedSlice, VectorMultiDimReductionDistribution>(
+ patterns.getContext(), benefit);
patterns.add<WarpOpExtractScalar>(patterns.getContext(), warpShuffleFromIdxFn,
benefit);
- patterns.add<WarpOpScfForOp, WarpOpShapeCast, WarpOpBroadcast>(
- patterns.getContext(), distributionMapFn, benefit);
+ patterns.add<WarpOpScfForOp, WarpOpShapeCast>(patterns.getContext(),
+ distributionMapFn, benefit);
}
void mlir::vector::populateDistributeReduction(
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 4d2c964a6df3c..bf70fbbd27244 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -850,6 +850,83 @@ func.func @vector_reduction_acc(%laneid: index) -> (f32) {
return %r : f32
}
+// -----
+// CHECK-PROP-LABEL: func.func @vector_multi_reduction_col_reduce
+// CHECK-PROP: %[[W:.*]]:2 = gpu.warp_execute_on_lane_0({{.*}})[32] -> (vector<32x2xf32>, vector<2xf32>) {
+// CHECK-PROP: %[[SOURCE:.*]] = "some_def"() : () -> vector<32x64xf32>
+// CHECK-PROP: %[[ACC:.*]] = "some_def"() : () -> vector<64xf32>
+// CHECK-PROP: gpu.yield %[[SOURCE]], %[[ACC]] : vector<32x64xf32>, vector<64xf32>
+// CHECK-PROP: }
+// CHECK-PROP: %[[COL0:.*]] = vector.extract_strided_slice %[[W]]#0 {offsets = [0, 0], sizes = [32, 1], strides = [1, 1]} : vector<32x2xf32> to vector<32x1xf32>
+// CHECK-PROP: %[[COL0CAST:.*]] = vector.shape_cast %[[COL0]] : vector<32x1xf32> to vector<32xf32>
+// CHECK-PROP: %[[ACC0:.*]] = vector.extract %[[W]]#1[0] : f32 from vector<2xf32>
+// CHECK-PROP: %[[REDUCE0:.*]] = vector.reduction <add>, %[[COL0CAST]], %[[ACC0]] : vector<32xf32> into f32
+// CHECK-PROP: %[[COL1:.*]] = vector.extract_strided_slice %[[W]]#0 {offsets = [0, 1], sizes = [32, 1], strides = [1, 1]} : vector<32x2xf32> to vector<32x1xf32>
+// CHECK-PROP: %[[COL1CAST:.*]] = vector.shape_cast %[[COL1]] : vector<32x1xf32> to vector<32xf32>
+// CHECK-PROP: %[[ACC1:.*]] = vector.extract %[[W]]#1[1] : f32 from vector<2xf32>
+// CHECK-PROP: %[[REDUCE1:.*]] = vector.reduction <add>, %[[COL1CAST]], %[[ACC1]] : vector<32xf32> into f32
+// CHECK-PROP: %[[R:.*]] = vector.from_elements %[[REDUCE0]], %[[REDUCE1]] : vector<2xf32>
+// CHECK-PROP: return %[[R]] : vector<2xf32>
+func.func @vector_multi_reduction_col_reduce(%laneid: index) -> vector<2xf32> {
+ %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) {
+ %0 = "some_def"() : () -> (vector<32x64xf32>)
+ %acc = "some_def"() : () -> (vector<64xf32>)
+ %1 = vector.multi_reduction <add>, %0, %acc [0] : vector<32x64xf32> to vector<64xf32>
+ gpu.yield %1 : vector<64xf32>
+ }
+ return %r : vector<2xf32>
+}
+
+// -----
+// CHECK-PROP-LABEL: func.func @vector_multi_reduction_row_reduce
+// CHECK-PROP: %[[C16:.*]] = arith.constant 16 : i32
+// CHECK-PROP: %[[C8:.*]] = arith.constant 8 : i32
+// CHECK-PROP: %[[C4:.*]] = arith.constant 4 : i32
+// CHECK-PROP: %[[C2:.*]] = arith.constant 2 : i32
+// CHECK-PROP: %[[C1:.*]] = arith.constant 1 : i32
+// CHECK-PROP: %[[C32:.*]] = arith.constant 32 : i32
+// CHECK-PROP: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-PROP: %[[W:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<2x1xf32>) {
+// CHECK-PROP: %[[SRC:.*]] = "some_def"() : () -> vector<2x32xf32>
+// CHECK-PROP: gpu.yield %[[SRC]] : vector<2x32xf32>
+// CHECK-PROP: }
+// CHECK-PROP: %[[T1:.*]] = vector.extract %[[W]][0, 0] : f32 from vector<2x1xf32>
+// CHECK-PROP: %[[SR:.*]], %{{.*}} = gpu.shuffle xor %[[T1]], %[[C1]], %[[C32]] : f32
+// CHECK-PROP: %[[T2:.*]] = arith.addf %[[T1]], %[[SR]] : f32
+// CHECK-PROP: %[[SR0:.*]], %{{.*}} = gpu.shuffle xor %[[T2]], %[[C2]], %[[C32]] : f32
+// CHECK-PROP: %[[T3:.*]] = arith.addf %[[T2]], %[[SR0]] : f32
+// CHECK-PROP: %[[SR2:.*]], %{{.*}} = gpu.shuffle xor %[[T3]], %[[C4]], %[[C32]] : f32
+// CHECK-PROP: %[[T4:.*]] = arith.addf %[[T3]], %[[SR2]] : f32
+// CHECK-PROP: %[[SR4:.*]], %{{.*}} = gpu.shuffle xor %[[T4]], %[[C8]], %[[C32]] : f32
+// CHECK-PROP: %[[T5:.*]] = arith.addf %[[T4]], %[[SR4]] : f32
+// CHECK-PROP: %[[SR6:.*]], %{{.*}} = gpu.shuffle xor %[[T5]], %[[C16]], %[[C32]] : f32
+// CHECK-PROP: %[[T6:.*]] = arith.addf %[[T5]], %[[SR6]] : f32
+// CHECK-PROP: %[[R0:.*]] = arith.addf %[[T6]], %[[CST]] : f32
+//
+// CHECK-PROP: %[[T8:.*]] = vector.extract %[[W]][1, 0] : f32 from vector<2x1xf32>
+// CHECK-PROP: %[[SR8:.*]], %{{.*}} = gpu.shuffle xor %[[T8]], %[[C1]], %[[C32]] : f32
+// CHECK-PROP: %[[T9:.*]] = arith.addf %[[T8]], %[[SR8]] : f32
+// CHECK-PROP: %[[SR10:.*]], %{{.*}} = gpu.shuffle xor %[[T9]], %[[C2]], %[[C32]] : f32
+// CHECK-PROP: %[[T10:.*]] = arith.addf %[[T9]], %[[SR10]] : f32
+// CHECK-PROP: %[[SR12:.*]], %{{.*}} = gpu.shuffle xor %[[T10]], %[[C4]], %[[C32]] : f32
+// CHECK-PROP: %[[T11:.*]] = arith.addf %[[T10]], %[[SR12]] : f32
+// CHECK-PROP: %[[SR14:.*]], %{{.*}} = gpu.shuffle xor %[[T11]], %[[C8]], %[[C32]] : f32
+// CHECK-PROP: %[[T12:.*]] = arith.addf %[[T11]], %[[SR14]] : f32
+// CHECK-PROP: %[[SR16:.*]], %{{.*}} = gpu.shuffle xor %[[T12]], %[[C16]], %[[C32]] : f32
+// CHECK-PROP: %[[T13:.*]] = arith.addf %[[T12]], %[[SR16]] : f32
+// CHECK-PROP: %[[R1:.*]] = arith.addf %[[T13]], %[[CST]] : f32
+// CHECK-PROP: %[[R:.*]] = vector.from_elements %[[R0]], %[[R1]] : vector<2xf32>
+// CHECK-PROP: return %[[R]] : vector<2xf32>
+func.func @vector_multi_reduction_row_reduce(%laneid: index) -> vector<2xf32> {
+ %zero = arith.constant dense<0.0> : vector<2xf32>
+ %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) {
+ %0 = "some_def"() : () -> (vector<2x32xf32>)
+ %1 = vector.multi_reduction <add>, %0, %zero [1] : vector<2x32xf32> to vector<2xf32>
+ gpu.yield %1 : vector<2xf32>
+ }
+ return %r : vector<2xf32>
+}
+
// -----
// CHECK-PROP-LABEL: func @warp_duplicate_yield(
@@ -1567,6 +1644,40 @@ func.func @warp_propagate_shape_cast(%laneid: index, %src: memref<32x4x32xf32>)
// CHECK-PROP: %[[CAST:.+]] = vector.shape_cast %[[READ]] : vector<1x1x4xf32> to vector<4xf32>
// CHECK-PROP: return %[[CAST]] : vector<4xf32>
+// -----
+func.func @warp_propagate_shape_cast_2d_to_2d(%laneid: index, %src: memref<64x32xf32>) -> vector<32x2xf32> {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0.000000e+00 : f32
+ %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<32x2xf32>) {
+ %2 = vector.transfer_read %src[%c0, %c0], %cst : memref<64x32xf32>, vector<64x32xf32>
+ %3 = vector.shape_cast %2 : vector<64x32xf32> to vector<32x64xf32>
+ gpu.yield %3 : vector<32x64xf32>
+ }
+ return %r : vector<32x2xf32>
+}
+
+// CHECK-PROP-LABEL: func.func @warp_propagate_shape_cast_2d_to_2d
+// CHECK-PROP: %[[READ:.*]] = vector.transfer_read {{.*}} {in_bounds = [false, true]} : memref<64x32xf32>, vector<2x32xf32>
+// CHECK-PROP: %[[CAST:.*]] = vector.shape_cast %[[READ]] : vector<2x32xf32> to vector<32x2xf32>
+// CHECK-PROP: return %[[CAST]] : vector<32x2xf32>
+
+// -----
+func.func @warp_propagate_shape_cast_non_distributed_result(%laneid: index, %src: memref<64xf32>) -> vector<8x4x2xf32> {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0.000000e+00 : f32
+ %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<8x4x2xf32>) {
+ %2 = vector.transfer_read %src[%c0], %cst : memref<64xf32>, vector<64xf32>
+ %3 = vector.shape_cast %2 : vector<64xf32> to vector<8x4x2xf32>
+ gpu.yield %3 : vector<8x4x2xf32>
+ }
+ return %r : vector<8x4x2xf32>
+}
+
+// CHECK-PROP-LABEL: func.func @warp_propagate_shape_cast_non_distributed_result
+// CHECK-PROP: %[[READ:.*]] = vector.transfer_read {{.*}} {in_bounds = [true]} : memref<64xf32>, vector<64xf32>
+// CHECK-PROP: %[[CAST:.*]] = vector.shape_cast %[[READ]] : vector<64xf32> to vector<8x4x2xf32>
+// CHECK-PROP: return %[[CAST]] : vector<8x4x2xf32>
+
// -----
func.func @warp_propagate_uniform_transfer_read(%laneid: index, %src: memref<4096xf32>, %index: index) -> vector<1xf32> {
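As a sketch of the row-reduction path added here (hypothetical intermediate
IR; %src, %acc, and %zero stand for the reduction source, the accumulator, and
the per-lane zero constant created by the pattern), a vector<2x32xf32>
reduction over dim [1] is rewritten into per-row vector.reduction ops:

%row0 = vector.extract %src[0] : vector<32xf32> from vector<2x32xf32>
%acc0 = vector.extract %acc[0] : f32 from vector<2xf32>
%r0 = vector.reduction <add>, %row0, %acc0 : vector<32xf32> into f32
%t = vector.insert %r0, %zero [0] : f32 into vector<2xf32>
%row1 = vector.extract %src[1] : vector<32xf32> from vector<2x32xf32>
%acc1 = vector.extract %acc[1] : f32 from vector<2xf32>
%r1 = vector.reduction <add>, %row1, %acc1 : vector<32xf32> into f32
%res = vector.insert %r1, %t [1] : f32 into vector<2xf32>

The existing WarpOpReduction pattern then distributes each vector.reduction
using gpu.shuffle, which is what the row-reduce test above checks.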
>From 01880b561e94c6cb752e6eddb16957e00dbdc97f Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Tue, 19 Aug 2025 23:26:49 +0000
Subject: [PATCH 3/8] save
---
mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp | 6 ++++++
1 file changed, 6 insertions(+)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 6410a895fc9ae..8dc1418e09006 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -2033,6 +2033,12 @@ struct WarpOpReduction : public WarpDistributionPattern {
DistributedReductionFn distributedReductionFn;
};
+// This pattern distributes the `vector.multi_reduction` operation across
+// lanes in a warp. Currently only 2D to 1D reductions are supported, assuming
+// that the source vector is distributed in the column dimension (i.e., each lane owns
+// complete column(s) of the source vector.
+// TODO: Add support for the case where source rows are distributed accross
+// lanes. Requires DistributionMapFn to express the data distribution.
struct VectorMultiDimReductionDistribution : public WarpDistributionPattern {
VectorMultiDimReductionDistribution(MLIRContext *context,
PatternBenefit benefit = 1)
>From 53da9928117634d6eb929f81cbfa59ed4c06d884 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Tue, 19 Aug 2025 23:28:13 +0000
Subject: [PATCH 4/8] save
---
mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 8dc1418e09006..c88c001f34843 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -2036,9 +2036,9 @@ struct WarpOpReduction : public WarpDistributionPattern {
// This pattern distributes the `vector.multi_reduction` operation across
// lanes in a warp. Currently only 2D to 1D reductions are supported, assuming
// that the source vector is distributed in the column dimension (i.e., each lane owns
-// complete column(s) of the source vector.
-// TODO: Add support for the case where source rows are distributed accross
-// lanes. Requires DistributionMapFn to express the data distribution.
+// complete column(s) of the source vector).
+// TODO: Add support for the case where source rows are distributed across
+// lanes. Requires `DistributionMapFn` to express the data distribution.
struct VectorMultiDimReductionDistribution : public WarpDistributionPattern {
VectorMultiDimReductionDistribution(MLIRContext *context,
PatternBenefit benefit = 1)
>From affd4aadb2e0f3f7cd19b0805b34067c1fa65371 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Tue, 19 Aug 2025 23:48:08 +0000
Subject: [PATCH 5/8] save
---
.../Vector/vector-warp-distribute.mlir | 74 +++++++++----------
1 file changed, 37 insertions(+), 37 deletions(-)
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index bf70fbbd27244..bf0191655d654 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -879,44 +879,44 @@ func.func @vector_multi_reduction_col_reduce(%laneid: index) -> vector<2xf32> {
// -----
// CHECK-PROP-LABEL: func.func @vector_multi_reduction_row_reduce
-// CHECK-PROP: %[[C16:.*]] = arith.constant 16 : i32
-// CHECK-PROP: %[[C8:.*]] = arith.constant 8 : i32
-// CHECK-PROP: %[[C4:.*]] = arith.constant 4 : i32
-// CHECK-PROP: %[[C2:.*]] = arith.constant 2 : i32
-// CHECK-PROP: %[[C1:.*]] = arith.constant 1 : i32
-// CHECK-PROP: %[[C32:.*]] = arith.constant 32 : i32
-// CHECK-PROP: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK-PROP: %[[W:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<2x1xf32>) {
-// CHECK-PROP: %[[SRC:.*]] = "some_def"() : () -> vector<2x32xf32>
-// CHECK-PROP: gpu.yield %[[SRC]] : vector<2x32xf32>
-// CHECK-PROP: }
-// CHECK-PROP: %[[T1:.*]] = vector.extract %[[W]][0, 0] : f32 from vector<2x1xf32>
-// CHECK-PROP: %[[SR:.*]], %{{.*}} = gpu.shuffle xor %[[T1]], %[[C1]], %[[C32]] : f32
-// CHECK-PROP: %[[T2:.*]] = arith.addf %[[T1]], %[[SR]] : f32
-// CHECK-PROP: %[[SR0:.*]], %{{.*}} = gpu.shuffle xor %[[T2]], %[[C2]], %[[C32]] : f32
-// CHECK-PROP: %[[T3:.*]] = arith.addf %[[T2]], %[[SR0]] : f32
-// CHECK-PROP: %[[SR2:.*]], %{{.*}} = gpu.shuffle xor %[[T3]], %[[C4]], %[[C32]] : f32
-// CHECK-PROP: %[[T4:.*]] = arith.addf %[[T3]], %[[SR2]] : f32
-// CHECK-PROP: %[[SR4:.*]], %{{.*}} = gpu.shuffle xor %[[T4]], %[[C8]], %[[C32]] : f32
-// CHECK-PROP: %[[T5:.*]] = arith.addf %[[T4]], %[[SR4]] : f32
-// CHECK-PROP: %[[SR6:.*]], %{{.*}} = gpu.shuffle xor %[[T5]], %[[C16]], %[[C32]] : f32
-// CHECK-PROP: %[[T6:.*]] = arith.addf %[[T5]], %[[SR6]] : f32
-// CHECK-PROP: %[[R0:.*]] = arith.addf %[[T6]], %[[CST]] : f32
+// CHECK-PROP-DAG: %[[C16:.*]] = arith.constant 16 : i32
+// CHECK-PROP-DAG: %[[C8:.*]] = arith.constant 8 : i32
+// CHECK-PROP-DAG: %[[C4:.*]] = arith.constant 4 : i32
+// CHECK-PROP-DAG: %[[C2:.*]] = arith.constant 2 : i32
+// CHECK-PROP-DAG: %[[C1:.*]] = arith.constant 1 : i32
+// CHECK-PROP-DAG: %[[C32:.*]] = arith.constant 32 : i32
+// CHECK-PROP-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-PROP: %[[W:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<2x1xf32>) {
+// CHECK-PROP: %[[SRC:.*]] = "some_def"() : () -> vector<2x32xf32>
+// CHECK-PROP: gpu.yield %[[SRC]] : vector<2x32xf32>
+// CHECK-PROP: }
+// CHECK-PROP: %[[T1:.*]] = vector.extract %[[W]][0, 0] : f32 from vector<2x1xf32>
+// CHECK-PROP: %[[SR:.*]], %{{.*}} = gpu.shuffle xor %[[T1]], %[[C1]], %[[C32]] : f32
+// CHECK-PROP: %[[T2:.*]] = arith.addf %[[T1]], %[[SR]] : f32
+// CHECK-PROP: %[[SR0:.*]], %{{.*}} = gpu.shuffle xor %[[T2]], %[[C2]], %[[C32]] : f32
+// CHECK-PROP: %[[T3:.*]] = arith.addf %[[T2]], %[[SR0]] : f32
+// CHECK-PROP: %[[SR2:.*]], %{{.*}} = gpu.shuffle xor %[[T3]], %[[C4]], %[[C32]] : f32
+// CHECK-PROP: %[[T4:.*]] = arith.addf %[[T3]], %[[SR2]] : f32
+// CHECK-PROP: %[[SR4:.*]], %{{.*}} = gpu.shuffle xor %[[T4]], %[[C8]], %[[C32]] : f32
+// CHECK-PROP: %[[T5:.*]] = arith.addf %[[T4]], %[[SR4]] : f32
+// CHECK-PROP: %[[SR6:.*]], %{{.*}} = gpu.shuffle xor %[[T5]], %[[C16]], %[[C32]] : f32
+// CHECK-PROP: %[[T6:.*]] = arith.addf %[[T5]], %[[SR6]] : f32
+// CHECK-PROP: %[[R0:.*]] = arith.addf %[[T6]], %[[CST]] : f32
//
-// CHECK-PROP: %[[T8:.*]] = vector.extract %[[W]][1, 0] : f32 from vector<2x1xf32>
-// CHECK-PROP: %[[SR8:.*]], %{{.*}} = gpu.shuffle xor %[[T8]], %[[C1]], %[[C32]] : f32
-// CHECK-PROP: %[[T9:.*]] = arith.addf %[[T8]], %[[SR8]] : f32
-// CHECK-PROP: %[[SR10:.*]], %{{.*}} = gpu.shuffle xor %[[T9]], %[[C2]], %[[C32]] : f32
-// CHECK-PROP: %[[T10:.*]] = arith.addf %[[T9]], %[[SR10]] : f32
-// CHECK-PROP: %[[SR12:.*]], %{{.*}} = gpu.shuffle xor %[[T10]], %[[C4]], %[[C32]] : f32
-// CHECK-PROP: %[[T11:.*]] = arith.addf %[[T10]], %[[SR12]] : f32
-// CHECK-PROP: %[[SR14:.*]], %{{.*}} = gpu.shuffle xor %[[T11]], %[[C8]], %[[C32]] : f32
-// CHECK-PROP: %[[T12:.*]] = arith.addf %[[T11]], %[[SR14]] : f32
-// CHECK-PROP: %[[SR16:.*]], %{{.*}} = gpu.shuffle xor %[[T12]], %[[C16]], %[[C32]] : f32
-// CHECK-PROP: %[[T13:.*]] = arith.addf %[[T12]], %[[SR16]] : f32
-// CHECK-PROP: %[[R1:.*]] = arith.addf %[[T13]], %[[CST]] : f32
-// CHECK-PROP: %[[R:.*]] = vector.from_elements %[[R0]], %[[R1]] : vector<2xf32>
-// CHECK-PROP: return %[[R]] : vector<2xf32>
+// CHECK-PROP: %[[T8:.*]] = vector.extract %[[W]][1, 0] : f32 from vector<2x1xf32>
+// CHECK-PROP: %[[SR8:.*]], %{{.*}} = gpu.shuffle xor %[[T8]], %[[C1]], %[[C32]] : f32
+// CHECK-PROP: %[[T9:.*]] = arith.addf %[[T8]], %[[SR8]] : f32
+// CHECK-PROP: %[[SR10:.*]], %{{.*}} = gpu.shuffle xor %[[T9]], %[[C2]], %[[C32]] : f32
+// CHECK-PROP: %[[T10:.*]] = arith.addf %[[T9]], %[[SR10]] : f32
+// CHECK-PROP: %[[SR12:.*]], %{{.*}} = gpu.shuffle xor %[[T10]], %[[C4]], %[[C32]] : f32
+// CHECK-PROP: %[[T11:.*]] = arith.addf %[[T10]], %[[SR12]] : f32
+// CHECK-PROP: %[[SR14:.*]], %{{.*}} = gpu.shuffle xor %[[T11]], %[[C8]], %[[C32]] : f32
+// CHECK-PROP: %[[T12:.*]] = arith.addf %[[T11]], %[[SR14]] : f32
+// CHECK-PROP: %[[SR16:.*]], %{{.*}} = gpu.shuffle xor %[[T12]], %[[C16]], %[[C32]] : f32
+// CHECK-PROP: %[[T13:.*]] = arith.addf %[[T12]], %[[SR16]] : f32
+// CHECK-PROP: %[[R1:.*]] = arith.addf %[[T13]], %[[CST]] : f32
+// CHECK-PROP: %[[R:.*]] = vector.from_elements %[[R0]], %[[R1]] : vector<2xf32>
+// CHECK-PROP: return %[[R]] : vector<2xf32>
func.func @vector_multi_reduction_row_reduce(%laneid: index) -> vector<2xf32> {
%zero = arith.constant dense<0.0> : vector<2xf32>
%r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) {
>From df59c20f5d8020ab9ba78f1c360334c738a60404 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Wed, 20 Aug 2025 00:01:32 +0000
Subject: [PATCH 6/8] save
---
mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp | 6 ------
1 file changed, 6 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index c88c001f34843..b0b52919c69ce 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -15,19 +15,13 @@
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
#include "mlir/IR/AffineExpr.h"
-#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/Value.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Transforms/RegionUtils.h"
-#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/Support/FormatVariadic.h"
-#include "llvm/Support/raw_ostream.h"
-#include <cstddef>
#include <utility>
using namespace mlir;
>From 55797318492b6a38801aa27bf9ec97d26523322e Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Wed, 20 Aug 2025 23:31:35 +0000
Subject: [PATCH 7/8] save
---
.../Vector/Transforms/VectorDistribute.cpp | 52 +++++++++++++------
1 file changed, 37 insertions(+), 15 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index b0b52919c69ce..ab0f1b55d04da 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -2033,9 +2033,8 @@ struct WarpOpReduction : public WarpDistributionPattern {
// complete column(s) of the source vector).
// TODO: Add support for the case where source rows are distributed across
// lanes. Requires `DistributionMapFn` to express the data distribution.
-struct VectorMultiDimReductionDistribution : public WarpDistributionPattern {
- VectorMultiDimReductionDistribution(MLIRContext *context,
- PatternBenefit benefit = 1)
+struct WarpOpMultiReduction : public WarpDistributionPattern {
+ WarpOpMultiReduction(MLIRContext *context, PatternBenefit benefit = 1)
: WarpDistributionPattern(context, benefit) {}
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
@@ -2047,18 +2046,46 @@ struct VectorMultiDimReductionDistribution : public WarpDistributionPattern {
cast<vector::MultiDimReductionOp>(yieldOperand->get().getDefiningOp());
unsigned operandNumber = yieldOperand->getOperandNumber();
VectorType sourceType = reductionOp.getSourceVectorType();
- VectorType distributedResultType =
- cast<VectorType>(warpOp.getResult(operandNumber).getType());
- Type elementType = distributedResultType.getElementType();
+
// Only 2D vectors are supported.
if (sourceType.getRank() != 2)
return rewriter.notifyMatchFailure(warpOp,
"Only 2D reductions are supported.");
ArrayRef<int64_t> reductionDims = reductionOp.getReductionDims();
- // Only 1 reduction dimension supported.
+ // Only 1 reduction dimension supported. This also ensures that the result
+ // is a vector type.
if (reductionDims.size() != 1)
return rewriter.notifyMatchFailure(
warpOp, "Only 1 reduction dimension is supported.");
+ int64_t reductionDim = reductionDims[0];
+ auto resultType = cast<VectorType>(reductionOp.getType());
+ auto distributedResultType =
+ cast<VectorType>(warpOp.getResult(operandNumber).getType());
+ Type elementType = distributedResultType.getElementType();
+
+ // Currently we make the following assumptions:
+ // 1. The source vector is distributed in the column dimension. Each lane
+ // owns complete column(s) of the source vector.
+ // 2. If the reduction dim == 0, it is a lane-local col reduction. In this
+ // case each lane owns its portion of the result (i.e., the result is also
+ // distributed).
+ // 3. If the reduction dim == 1, it is a row reduction that requires
+ // cross-lane shuffles. In this case the result is not distributed but
+ // broadcast instead.
+ // TODO: These assumptions are fairly restrictive. For example, the source
+ // vector can have a row-distributed layout. Improve support for such cases.
+ if (sourceType.getShape()[1] % warpOp.getWarpSize() != 0)
+ return rewriter.notifyMatchFailure(
+ warpOp, "Source vector dimension must be divisible by warp size.");
+ bool isResultDistributed =
+ distributedResultType.getNumElements() < resultType.getNumElements();
+ if (reductionDim == 0 && !isResultDistributed)
+ return rewriter.notifyMatchFailure(
+ warpOp,
+ "Expecting result vector to be distributed in a col reduction.");
+ if (reductionDim == 1 && isResultDistributed)
+ return rewriter.notifyMatchFailure(
+ warpOp,
+ "Expecting result vector to be broadcasted in a row reduction.");
// Create a constant vector to store the result of the reduction per lane.
TypedAttr zeroAttr =
@@ -2066,14 +2093,9 @@ struct VectorMultiDimReductionDistribution : public WarpDistributionPattern {
Value result = arith::ConstantOp::create(
rewriter, reductionOp->getLoc(), distributedResultType,
DenseElementsAttr::get(distributedResultType, zeroAttr));
-
// Col reduction.
- if (reductionDims[0] == 0) {
- // Source vector must be distributable to lanes in the col dimension.
- if (sourceType.getShape()[1] % warpOp.getWarpSize() != 0)
- return rewriter.notifyMatchFailure(
- warpOp, "Source vector dimension must be divisible by warp size.");
- // Compute source distributed type.
+ if (reductionDim == 0) {
+ // Compute source distributed type assuming each lane owns cols.
SmallVector<int64_t> shape(sourceType.getShape());
shape[1] = shape[1] / warpOp.getWarpSize();
auto sourceDistributedType = VectorType::get(shape, elementType);
@@ -2158,7 +2180,7 @@ void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
.add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast, WarpOpExtract,
WarpOpForwardOperand, WarpOpConstant, WarpOpInsertScalar,
WarpOpInsert, WarpOpCreateMask, WarpOpExtractStridedSlice,
- WarpOpInsertStridedSlice, VectorMultiDimReductionDistribution>(
+ WarpOpInsertStridedSlice, WarpOpMultiReduction>(
patterns.getContext(), benefit);
patterns.add<WarpOpExtractScalar>(patterns.getContext(), warpShuffleFromIdxFn,
benefit);
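A short sketch of the two result conventions the new guards encode, mirroring
the tests from the earlier commit (hypothetical fragments; %laneid, %0, %acc,
and %zero are assumed to be defined as in those tests):

// Col reduction: the result is distributed, vector<2xf32> per lane (64 / 32 = 2).
%c = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) {
  %1 = vector.multi_reduction <add>, %0, %acc [0] : vector<32x64xf32> to vector<64xf32>
  gpu.yield %1 : vector<64xf32>
}

// Row reduction: the result is broadcast; the warp result type equals the full
// vector<2xf32> result, so isResultDistributed is false.
%r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) {
  %1 = vector.multi_reduction <add>, %0, %zero [1] : vector<2x32xf32> to vector<2xf32>
  gpu.yield %1 : vector<2xf32>
}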
>From 07c0364d64109faf740023107ab68ec0f242d9ca Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Wed, 20 Aug 2025 23:36:26 +0000
Subject: [PATCH 8/8] save
---
.../Vector/Transforms/VectorDistribute.cpp | 3 +-
.../Vector/vector-warp-distribute.mlir | 32 ++++++++++---------
2 files changed, 18 insertions(+), 17 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index ab0f1b55d04da..aecb6a11a7b36 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -2034,8 +2034,7 @@ struct WarpOpReduction : public WarpDistributionPattern {
// TODO: Add support for the case where source rows are distributed across
// lanes. Requires `DistributionMapFn` to express the data distribution.
struct WarpOpMultiReduction : public WarpDistributionPattern {
- WarpOpMultiReduction(MLIRContext *context, PatternBenefit benefit = 1)
- : WarpDistributionPattern(context, benefit) {}
+ using Base::Base;
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
OpOperand *yieldOperand =
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index bf0191655d654..95b8a48404f20 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -852,21 +852,23 @@ func.func @vector_reduction_acc(%laneid: index) -> (f32) {
// -----
// CHECK-PROP-LABEL: func.func @vector_multi_reduction_col_reduce
-// CHECK-PROP: %[[W:.*]]:2 = gpu.warp_execute_on_lane_0({{.*}})[32] -> (vector<32x2xf32>, vector<2xf32>) {
-// CHECK-PROP: %[[SOURCE:.*]] = "some_def"() : () -> vector<32x64xf32>
-// CHECK-PROP: %[[ACC:.*]] = "some_def"() : () -> vector<64xf32>
-// CHECK-PROP: gpu.yield %[[SOURCE]], %[[ACC]] : vector<32x64xf32>, vector<64xf32>
-// CHECK-PROP: }
-// CHECK-PROP: %[[COL0:.*]] = vector.extract_strided_slice %[[W]]#0 {offsets = [0, 0], sizes = [32, 1], strides = [1, 1]} : vector<32x2xf32> to vector<32x1xf32>
-// CHECK-PROP: %[[COL0CAST:.*]] = vector.shape_cast %[[COL0]] : vector<32x1xf32> to vector<32xf32>
-// CHECK-PROP: %[[ACC0:.*]] = vector.extract %[[W]]#1[0] : f32 from vector<2xf32>
-// CHECK-PROP: %[[REDUCE0:.*]] = vector.reduction <add>, %[[COL0CAST]], %[[ACC0]] : vector<32xf32> into f32
-// CHECK-PROP: %[[COL1:.*]] = vector.extract_strided_slice %[[W]]#0 {offsets = [0, 1], sizes = [32, 1], strides = [1, 1]} : vector<32x2xf32> to vector<32x1xf32>
-// CHECK-PROP: %[[COL1CAST:.*]] = vector.shape_cast %[[COL1]] : vector<32x1xf32> to vector<32xf32>
-// CHECK-PROP: %[[ACC1:.*]] = vector.extract %[[W]]#1[1] : f32 from vector<2xf32>
-// CHECK-PROP: %[[REDUCE1:.*]] = vector.reduction <add>, %[[COL1CAST]], %[[ACC1]] : vector<32xf32> into f32
-// CHECK-PROP: %[[R:.*]] = vector.from_elements %[[REDUCE0]], %[[REDUCE1]] : vector<2xf32>
-// CHECK-PROP: return %[[R]] : vector<2xf32>
+// CHECK-PROP: %[[W:.*]]:2 = gpu.warp_execute_on_lane_0({{.*}})[32] -> (vector<32x2xf32>, vector<2xf32>) {
+// CHECK-PROP: %[[SOURCE:.*]] = "some_def"() : () -> vector<32x64xf32>
+// CHECK-PROP: %[[ACC:.*]] = "some_def"() : () -> vector<64xf32>
+// CHECK-PROP: gpu.yield %[[SOURCE]], %[[ACC]] : vector<32x64xf32>, vector<64xf32>
+// CHECK-PROP: }
+// CHECK-PROP: %[[COL0:.*]] = vector.extract_strided_slice %[[W]]#0
+// CHECK-PROP-SAME: {offsets = [0, 0], sizes = [32, 1], strides = [1, 1]} : vector<32x2xf32> to vector<32x1xf32>
+// CHECK-PROP: %[[COL0CAST:.*]] = vector.shape_cast %[[COL0]] : vector<32x1xf32> to vector<32xf32>
+// CHECK-PROP: %[[ACC0:.*]] = vector.extract %[[W]]#1[0] : f32 from vector<2xf32>
+// CHECK-PROP: %[[REDUCE0:.*]] = vector.reduction <add>, %[[COL0CAST]], %[[ACC0]] : vector<32xf32> into f32
+// CHECK-PROP: %[[COL1:.*]] = vector.extract_strided_slice %[[W]]#0
+// CHECK-PROP-SAME: {offsets = [0, 1], sizes = [32, 1], strides = [1, 1]} : vector<32x2xf32> to vector<32x1xf32>
+// CHECK-PROP: %[[COL1CAST:.*]] = vector.shape_cast %[[COL1]] : vector<32x1xf32> to vector<32xf32>
+// CHECK-PROP: %[[ACC1:.*]] = vector.extract %[[W]]#1[1] : f32 from vector<2xf32>
+// CHECK-PROP: %[[REDUCE1:.*]] = vector.reduction <add>, %[[COL1CAST]], %[[ACC1]] : vector<32xf32> into f32
+// CHECK-PROP: %[[R:.*]] = vector.from_elements %[[REDUCE0]], %[[REDUCE1]] : vector<2xf32>
+// CHECK-PROP: return %[[R]] : vector<2xf32>
func.func @vector_multi_reduction_col_reduce(%laneid: index) -> vector<2xf32> {
%r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) {
%0 = "some_def"() : () -> (vector<32x64xf32>)