[Mlir-commits] [mlir] [mlir][Vector] Move insert/extractelement distribution patterns to insert/extract (PR #116425)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Nov 15 22:06:14 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-vector
Author: Kunwar Grover (Groverkss)
<details>
<summary>Changes</summary>
This is an NFC-ish change that moves vector.extractelement/vector.insertelement vector distribution patterns to vector.insert/vector.extract.
Before:
0-d/1-d vector.extract -> vector.extractelement -> distributed vector.extractelement
2-d+ vector.extract -> distributed vector.extract
After:
scalar input vector.extract -> distributed vector.extract
vector.extractelement -> distributed vector.extract
2d+ vector.extract -> distributed vector.extract
The same changes are done for insertelement/insert. The change allows us to remove reliance on vector.extractelement/vector.insertelement, which are soon to be deprecated: https://discourse.llvm.org/t/rfc-psa-remove-vector-extractelement-and-vector-insertelement-ops-in-favor-of-vector-extract-and-vector-insert-ops/71116/8
No extra tests are included because this patch doesn't introduce / remove any functionality. It only changes the chain of lowerings. This change can be completely NFC if we make the distributed operation vector.extractelement/vector.insertelement, but that is slightly weird, because you are going from extractelement -> extract -> extractelement.
---
Patch is 27.90 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/116425.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp (+137-104)
- (modified) mlir/test/Dialect/Vector/vector-warp-distribute.mlir (+12-14)
``````````diff
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 682eb82ac58408..bad26b0cc4b383 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1229,28 +1229,9 @@ struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
VectorType extractSrcType = extractOp.getSourceVectorType();
Location loc = extractOp.getLoc();
- // "vector.extract %v[] : vector<f32> from vector<f32>" is an invalid op.
- assert(extractSrcType.getRank() > 0 &&
- "vector.extract does not support rank 0 sources");
-
- // "vector.extract %v[] : vector<...xf32> from vector<...xf32>" can be
- // canonicalized to %v.
- if (extractOp.getNumIndices() == 0)
+ // For 1-d or 0-d source cases, we rely on WarpOpExtractScalar pattern.
+ if (extractSrcType.getRank() <= 1) {
return failure();
-
- // Rewrite vector.extract with 1d source to vector.extractelement.
- if (extractSrcType.getRank() == 1) {
- if (extractOp.hasDynamicPosition())
- // TODO: Dinamic position not supported yet.
- return failure();
-
- assert(extractOp.getNumIndices() == 1 && "expected 1 index");
- int64_t pos = extractOp.getStaticPosition()[0];
- rewriter.setInsertionPoint(extractOp);
- rewriter.replaceOpWithNewOp<vector::ExtractElementOp>(
- extractOp, extractOp.getVector(),
- rewriter.create<arith::ConstantIndexOp>(loc, pos));
- return success();
}
// All following cases are 2d or higher dimensional source vectors.
@@ -1313,22 +1294,27 @@ struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
}
};
-/// Pattern to move out vector.extractelement of 0-D tensors. Those don't
-/// need to be distributed and can just be propagated outside of the region.
-struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
- WarpOpExtractElement(MLIRContext *ctx, WarpShuffleFromIdxFn fn,
- PatternBenefit b = 1)
+/// Pattern to move out vector.extract with a scalar result.
+/// Only supports 1-D and 0-D sources for now.
+struct WarpOpExtractScalar : public OpRewritePattern<WarpExecuteOnLane0Op> {
+ WarpOpExtractScalar(MLIRContext *ctx, WarpShuffleFromIdxFn fn,
+ PatternBenefit b = 1)
: OpRewritePattern<WarpExecuteOnLane0Op>(ctx, b),
warpShuffleFromIdxFn(std::move(fn)) {}
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
OpOperand *operand =
- getWarpResult(warpOp, llvm::IsaPred<vector::ExtractElementOp>);
+ getWarpResult(warpOp, llvm::IsaPred<vector::ExtractOp>);
if (!operand)
return failure();
unsigned int operandNumber = operand->getOperandNumber();
- auto extractOp = operand->get().getDefiningOp<vector::ExtractElementOp>();
+ auto extractOp = operand->get().getDefiningOp<vector::ExtractOp>();
VectorType extractSrcType = extractOp.getSourceVectorType();
+ // Only supports 1-D or 0-D sources for now.
+ if (extractSrcType.getRank() > 1) {
+ return rewriter.notifyMatchFailure(
+ extractOp, "only 0-D or 1-D source supported for now");
+ }
// TODO: Supported shuffle types should be parameterizable, similar to
// `WarpShuffleFromIdxFn`.
if (!extractSrcType.getElementType().isF32() &&
@@ -1340,7 +1326,7 @@ struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
VectorType distributedVecType;
if (!is0dOrVec1Extract) {
assert(extractSrcType.getRank() == 1 &&
- "expected that extractelement src rank is 0 or 1");
+ "expected that extract src rank is 0 or 1");
if (extractSrcType.getShape()[0] % warpOp.getWarpSize() != 0)
return failure();
int64_t elementsPerLane =
@@ -1352,10 +1338,11 @@ struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
// Yield source vector and position (if present) from warp op.
SmallVector<Value> additionalResults{extractOp.getVector()};
SmallVector<Type> additionalResultTypes{distributedVecType};
- if (static_cast<bool>(extractOp.getPosition())) {
- additionalResults.push_back(extractOp.getPosition());
- additionalResultTypes.push_back(extractOp.getPosition().getType());
- }
+ additionalResults.append(
+ SmallVector<Value>(extractOp.getDynamicPosition()));
+ additionalResultTypes.append(
+ SmallVector<Type>(extractOp.getDynamicPosition().getTypes()));
+
Location loc = extractOp.getLoc();
SmallVector<size_t> newRetIndices;
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
@@ -1369,38 +1356,35 @@ struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
if (is0dOrVec1Extract) {
Value newExtract;
if (extractSrcType.getRank() == 1) {
- newExtract = rewriter.create<vector::ExtractElementOp>(
- loc, distributedVec,
- rewriter.create<arith::ConstantIndexOp>(loc, 0));
-
+ newExtract = rewriter.create<vector::ExtractOp>(loc, distributedVec, 0);
} else {
- newExtract =
- rewriter.create<vector::ExtractElementOp>(loc, distributedVec);
+ newExtract = rewriter.create<vector::ExtractOp>(loc, distributedVec,
+ ArrayRef<int64_t>{});
}
rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
newExtract);
return success();
}
+ int64_t staticPos = extractOp.getStaticPosition()[0];
+ OpFoldResult pos = ShapedType::isDynamic(staticPos)
+ ? (newWarpOp->getResult(newRetIndices[1]))
+ : OpFoldResult(rewriter.getIndexAttr(staticPos));
// 1d extract: Distribute the source vector. One lane extracts and shuffles
// the value to all other lanes.
int64_t elementsPerLane = distributedVecType.getShape()[0];
AffineExpr sym0 = getAffineSymbolExpr(0, rewriter.getContext());
// tid of extracting thread: pos / elementsPerLane
- Value broadcastFromTid = rewriter.create<affine::AffineApplyOp>(
- loc, sym0.ceilDiv(elementsPerLane),
- newWarpOp->getResult(newRetIndices[1]));
+ Value broadcastFromTid = affine::makeComposedAffineApply(
+ rewriter, loc, sym0.ceilDiv(elementsPerLane), pos);
// Extract at position: pos % elementsPerLane
- Value pos =
+ Value newPos =
elementsPerLane == 1
? rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult()
- : rewriter
- .create<affine::AffineApplyOp>(
- loc, sym0 % elementsPerLane,
- newWarpOp->getResult(newRetIndices[1]))
- .getResult();
+ : affine::makeComposedAffineApply(rewriter, loc,
+ sym0 % elementsPerLane, pos);
Value extracted =
- rewriter.create<vector::ExtractElementOp>(loc, distributedVec, pos);
+ rewriter.create<vector::ExtractOp>(loc, distributedVec, newPos);
// Shuffle the extracted value to all lanes.
Value shuffled = warpShuffleFromIdxFn(
@@ -1413,31 +1397,60 @@ struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
WarpShuffleFromIdxFn warpShuffleFromIdxFn;
};
-struct WarpOpInsertElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
+/// Pattern to convert vector.extractelement to vector.extract.
+struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
+ WarpOpExtractElement(MLIRContext *ctx, PatternBenefit b = 1)
+ : OpRewritePattern<WarpExecuteOnLane0Op>(ctx, b) {}
+ LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
+ PatternRewriter &rewriter) const override {
+ OpOperand *operand =
+ getWarpResult(warpOp, llvm::IsaPred<vector::ExtractElementOp>);
+ if (!operand)
+ return failure();
+ auto extractOp = operand->get().getDefiningOp<vector::ExtractElementOp>();
+ rewriter.setInsertionPoint(extractOp);
+ if (auto pos = extractOp.getPosition()) {
+ rewriter.replaceOpWithNewOp<vector::ExtractOp>(
+ extractOp, extractOp.getVector(), pos);
+ } else {
+ rewriter.replaceOpWithNewOp<vector::ExtractOp>(
+ extractOp, extractOp.getVector(), ArrayRef<int64_t>{});
+ }
+ return success();
+ }
+};
+
+/// Pattern to move out vector.insert with a scalar input.
+/// Only supports 1-D and 0-D destinations for now.
+struct WarpOpInsertScalar : public OpRewritePattern<WarpExecuteOnLane0Op> {
using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
- OpOperand *operand =
- getWarpResult(warpOp, llvm::IsaPred<vector::InsertElementOp>);
+ OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<vector::InsertOp>);
if (!operand)
return failure();
unsigned int operandNumber = operand->getOperandNumber();
- auto insertOp = operand->get().getDefiningOp<vector::InsertElementOp>();
+ auto insertOp = operand->get().getDefiningOp<vector::InsertOp>();
VectorType vecType = insertOp.getDestVectorType();
VectorType distrType =
cast<VectorType>(warpOp.getResult(operandNumber).getType());
- bool hasPos = static_cast<bool>(insertOp.getPosition());
+
+ // Only supports 1-D or 0-D destinations for now.
+ if (vecType.getRank() > 1) {
+ return rewriter.notifyMatchFailure(
+ insertOp, "only 0-D or 1-D source supported for now");
+ }
// Yield destination vector, source scalar and position from warp op.
SmallVector<Value> additionalResults{insertOp.getDest(),
insertOp.getSource()};
SmallVector<Type> additionalResultTypes{distrType,
insertOp.getSource().getType()};
- if (hasPos) {
- additionalResults.push_back(insertOp.getPosition());
- additionalResultTypes.push_back(insertOp.getPosition().getType());
- }
+ additionalResults.append(SmallVector<Value>(insertOp.getDynamicPosition()));
+ additionalResultTypes.append(
+ SmallVector<Type>(insertOp.getDynamicPosition().getTypes()));
+
Location loc = insertOp.getLoc();
SmallVector<size_t> newRetIndices;
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
@@ -1446,13 +1459,27 @@ struct WarpOpInsertElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
rewriter.setInsertionPointAfter(newWarpOp);
Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
Value newSource = newWarpOp->getResult(newRetIndices[1]);
- Value newPos = hasPos ? newWarpOp->getResult(newRetIndices[2]) : Value();
rewriter.setInsertionPointAfter(newWarpOp);
+ OpFoldResult pos;
+ if (vecType.getRank() != 0) {
+ int64_t staticPos = insertOp.getStaticPosition()[0];
+ pos = ShapedType::isDynamic(staticPos)
+ ? (newWarpOp->getResult(newRetIndices[2]))
+ : OpFoldResult(rewriter.getIndexAttr(staticPos));
+ }
+
+ // This condition is always true for 0-d vectors.
if (vecType == distrType) {
- // Broadcast: Simply move the vector.inserelement op out.
- Value newInsert = rewriter.create<vector::InsertElementOp>(
- loc, newSource, distributedVec, newPos);
+ Value newInsert;
+ if (pos) {
+ newInsert = rewriter.create<vector::InsertOp>(loc, newSource,
+ distributedVec, pos);
+ } else {
+ newInsert = rewriter.create<vector::InsertOp>(
+ loc, newSource, distributedVec, ArrayRef<int64_t>{});
+ }
+ // Broadcast: Simply move the vector.insert op out.
rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
newInsert);
return success();
@@ -1462,16 +1489,11 @@ struct WarpOpInsertElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
int64_t elementsPerLane = distrType.getShape()[0];
AffineExpr sym0 = getAffineSymbolExpr(0, rewriter.getContext());
// tid of extracting thread: pos / elementsPerLane
- Value insertingLane = rewriter.create<affine::AffineApplyOp>(
- loc, sym0.ceilDiv(elementsPerLane), newPos);
+ Value insertingLane = affine::makeComposedAffineApply(
+ rewriter, loc, sym0.ceilDiv(elementsPerLane), pos);
// Insert position: pos % elementsPerLane
- Value pos =
- elementsPerLane == 1
- ? rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult()
- : rewriter
- .create<affine::AffineApplyOp>(loc, sym0 % elementsPerLane,
- newPos)
- .getResult();
+ OpFoldResult newPos = affine::makeComposedFoldedAffineApply(
+ rewriter, loc, sym0 % elementsPerLane, pos);
Value isInsertingLane = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::eq, newWarpOp.getLaneid(), insertingLane);
Value newResult =
@@ -1480,8 +1502,8 @@ struct WarpOpInsertElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
loc, isInsertingLane,
/*thenBuilder=*/
[&](OpBuilder &builder, Location loc) {
- Value newInsert = builder.create<vector::InsertElementOp>(
- loc, newSource, distributedVec, pos);
+ Value newInsert = builder.create<vector::InsertOp>(
+ loc, newSource, distributedVec, newPos);
builder.create<scf::YieldOp>(loc, newInsert);
},
/*elseBuilder=*/
@@ -1506,25 +1528,13 @@ struct WarpOpInsert : public OpRewritePattern<WarpExecuteOnLane0Op> {
auto insertOp = operand->get().getDefiningOp<vector::InsertOp>();
Location loc = insertOp.getLoc();
- // "vector.insert %v, %v[] : ..." can be canonicalized to %v.
- if (insertOp.getNumIndices() == 0)
+ // For 1-d or 0-d destination cases, we rely on WarpOpInsertScalar pattern.
+ if (insertOp.getDestVectorType().getRank() <= 1) {
return failure();
-
- // Rewrite vector.insert with 1d dest to vector.insertelement.
- if (insertOp.getDestVectorType().getRank() == 1) {
- if (insertOp.hasDynamicPosition())
- // TODO: Dinamic position not supported yet.
- return failure();
-
- assert(insertOp.getNumIndices() == 1 && "expected 1 index");
- int64_t pos = insertOp.getStaticPosition()[0];
- rewriter.setInsertionPoint(insertOp);
- rewriter.replaceOpWithNewOp<vector::InsertElementOp>(
- insertOp, insertOp.getSource(), insertOp.getDest(),
- rewriter.create<arith::ConstantIndexOp>(loc, pos));
- return success();
}
+ // All following cases are 2d or higher dimensional source vectors.
+
if (warpOp.getResult(operandNumber).getType() == operand->get().getType()) {
// There is no distribution, this is a broadcast. Simply move the insert
// out of the warp op.
@@ -1620,9 +1630,32 @@ struct WarpOpInsert : public OpRewritePattern<WarpExecuteOnLane0Op> {
}
};
+struct WarpOpInsertElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
+ using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
+ PatternRewriter &rewriter) const override {
+ OpOperand *operand =
+ getWarpResult(warpOp, llvm::IsaPred<vector::InsertElementOp>);
+ if (!operand)
+ return failure();
+ auto insertOp = operand->get().getDefiningOp<vector::InsertElementOp>();
+ rewriter.setInsertionPoint(insertOp);
+ if (auto pos = insertOp.getPosition()) {
+ rewriter.replaceOpWithNewOp<vector::InsertOp>(
+ insertOp, insertOp.getSource(), insertOp.getDest(), pos);
+ } else {
+ rewriter.replaceOpWithNewOp<vector::InsertOp>(
+ insertOp, insertOp.getSource(), insertOp.getDest(),
+ ArrayRef<int64_t>{});
+ }
+ return success();
+ }
+};
+
/// Sink scf.for region out of WarpExecuteOnLane0Op. This can be done only if
-/// the scf.ForOp is the last operation in the region so that it doesn't change
-/// the order of execution. This creates a new scf.for region after the
+/// the scf.ForOp is the last operation in the region so that it doesn't
+/// change the order of execution. This creates a new scf.for region after the
/// WarpExecuteOnLane0Op. The new scf.for region will contain a new
/// WarpExecuteOnLane0Op region. Example:
/// ```
@@ -1668,8 +1701,8 @@ struct WarpOpScfForOp : public OpRewritePattern<WarpExecuteOnLane0Op> {
if (!forOp)
return failure();
// Collect Values that come from the warp op but are outside the forOp.
- // Those Value needs to be returned by the original warpOp and passed to the
- // new op.
+ // Those Value needs to be returned by the original warpOp and passed to
+ // the new op.
llvm::SmallSetVector<Value, 32> escapingValues;
SmallVector<Type> inputTypes;
SmallVector<Type> distTypes;
@@ -1715,8 +1748,8 @@ struct WarpOpScfForOp : public OpRewritePattern<WarpExecuteOnLane0Op> {
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPointAfter(newWarpOp);
- // Create a new for op outside the region with a WarpExecuteOnLane0Op region
- // inside.
+ // Create a new for op outside the region with a WarpExecuteOnLane0Op
+ // region inside.
auto newForOp = rewriter.create<scf::ForOp>(
forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
forOp.getStep(), newOperands);
@@ -1778,8 +1811,8 @@ struct WarpOpScfForOp : public OpRewritePattern<WarpExecuteOnLane0Op> {
};
/// A pattern that extracts vector.reduction ops from a WarpExecuteOnLane0Op.
-/// The vector is reduced in parallel. Currently limited to vector size matching
-/// the warpOp size. E.g.:
+/// The vector is reduced in parallel. Currently limited to vector size
+/// matching the warpOp size. E.g.:
/// ```
/// %r = vector_ext.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
/// %0 = "some_def"() : () -> (vector<32xf32>)
@@ -1880,13 +1913,13 @@ void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit,
PatternBenefit readBenefit) {
patterns.add<WarpOpTransferRead>(patterns.getContext(), readBenefit);
- patterns
- .add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
- WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand, WarpOpConstant,
- WarpOpInsertElement, WarpOpInsert, WarpOpCreateMask>(
- patterns.getContext(), benefit);
- patterns.add<WarpOpExtractElement>(patterns.getContext(),
- warpShuffleFromIdxFn, benefit);
+ patterns.add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
+ WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand,
+ WarpOpConstant, WarpOpExtractElement, WarpOpInsertElement,
+ WarpOpInsertScalar, WarpOpInsert, WarpOpCreateMask>(
+ patterns.getContext(), benefit);
+ patterns.add<WarpOpExtractScalar>(patterns.getContext(), warpShuffleFromIdxFn,
+ benefit);
patterns.add<WarpOpScfForOp>(patterns.getContext(), distributionMapFn,
benefit);
}
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 3acddd6e54639e..b4491812dc26cb 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -783,7 +783,7 @@ func.func @warp_constant(%laneid: index) -> (vector<1xf32>) {
// CHECK-PROP: %[[V:.*]] = "some_def"() : () -> vector<64xf32>
// CHECK-PROP: vector.yield %[[V]] : vector<64xf32>
// CHECK-PROP: }
-// CHECK-PROP: %[[E:.*]] = vector.extractelement %[[R]][%[[C1]] : index] : vector<2xf32>
+// CHE...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/116425
More information about the Mlir-commits
mailing list