[Mlir-commits] [mlir] [mlir][Vector] Move insert/extractelement distribution patterns to insert/extract (PR #116425)
Kunwar Grover
llvmlistbot at llvm.org
Mon Nov 18 02:20:10 PST 2024
https://github.com/Groverkss updated https://github.com/llvm/llvm-project/pull/116425
>From 6f867b11bb5b93474d46d849a91af7c75f288bd1 Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Fri, 15 Nov 2024 19:18:51 +0000
Subject: [PATCH 1/2] [mlir][Vector] Move insert/extractelement distribution patterns to insert/extract
---
.../Vector/Transforms/VectorDistribute.cpp | 241 ++++++++++--------
.../Vector/vector-warp-distribute.mlir | 26 +-
2 files changed, 149 insertions(+), 118 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 682eb82ac58408..bad26b0cc4b383 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1229,28 +1229,9 @@ struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
VectorType extractSrcType = extractOp.getSourceVectorType();
Location loc = extractOp.getLoc();
- // "vector.extract %v[] : vector<f32> from vector<f32>" is an invalid op.
- assert(extractSrcType.getRank() > 0 &&
- "vector.extract does not support rank 0 sources");
-
- // "vector.extract %v[] : vector<...xf32> from vector<...xf32>" can be
- // canonicalized to %v.
- if (extractOp.getNumIndices() == 0)
+ // For 1-d or 0-d source cases, we rely on WarpOpExtractScalar pattern.
+ if (extractSrcType.getRank() <= 1) {
return failure();
-
- // Rewrite vector.extract with 1d source to vector.extractelement.
- if (extractSrcType.getRank() == 1) {
- if (extractOp.hasDynamicPosition())
- // TODO: Dinamic position not supported yet.
- return failure();
-
- assert(extractOp.getNumIndices() == 1 && "expected 1 index");
- int64_t pos = extractOp.getStaticPosition()[0];
- rewriter.setInsertionPoint(extractOp);
- rewriter.replaceOpWithNewOp<vector::ExtractElementOp>(
- extractOp, extractOp.getVector(),
- rewriter.create<arith::ConstantIndexOp>(loc, pos));
- return success();
}
// All following cases are 2d or higher dimensional source vectors.
@@ -1313,22 +1294,27 @@ struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
}
};
-/// Pattern to move out vector.extractelement of 0-D tensors. Those don't
-/// need to be distributed and can just be propagated outside of the region.
-struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
- WarpOpExtractElement(MLIRContext *ctx, WarpShuffleFromIdxFn fn,
- PatternBenefit b = 1)
+/// Pattern to move out vector.extract with a scalar result.
+/// Only supports 1-D and 0-D sources for now.
+struct WarpOpExtractScalar : public OpRewritePattern<WarpExecuteOnLane0Op> {
+ WarpOpExtractScalar(MLIRContext *ctx, WarpShuffleFromIdxFn fn,
+ PatternBenefit b = 1)
: OpRewritePattern<WarpExecuteOnLane0Op>(ctx, b),
warpShuffleFromIdxFn(std::move(fn)) {}
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
OpOperand *operand =
- getWarpResult(warpOp, llvm::IsaPred<vector::ExtractElementOp>);
+ getWarpResult(warpOp, llvm::IsaPred<vector::ExtractOp>);
if (!operand)
return failure();
unsigned int operandNumber = operand->getOperandNumber();
- auto extractOp = operand->get().getDefiningOp<vector::ExtractElementOp>();
+ auto extractOp = operand->get().getDefiningOp<vector::ExtractOp>();
VectorType extractSrcType = extractOp.getSourceVectorType();
+ // Only supports 1-D or 0-D sources for now.
+ if (extractSrcType.getRank() > 1) {
+ return rewriter.notifyMatchFailure(
+ extractOp, "only 0-D or 1-D source supported for now");
+ }
// TODO: Supported shuffle types should be parameterizable, similar to
// `WarpShuffleFromIdxFn`.
if (!extractSrcType.getElementType().isF32() &&
@@ -1340,7 +1326,7 @@ struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
VectorType distributedVecType;
if (!is0dOrVec1Extract) {
assert(extractSrcType.getRank() == 1 &&
- "expected that extractelement src rank is 0 or 1");
+ "expected that extract src rank is 0 or 1");
if (extractSrcType.getShape()[0] % warpOp.getWarpSize() != 0)
return failure();
int64_t elementsPerLane =
@@ -1352,10 +1338,11 @@ struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
// Yield source vector and position (if present) from warp op.
SmallVector<Value> additionalResults{extractOp.getVector()};
SmallVector<Type> additionalResultTypes{distributedVecType};
- if (static_cast<bool>(extractOp.getPosition())) {
- additionalResults.push_back(extractOp.getPosition());
- additionalResultTypes.push_back(extractOp.getPosition().getType());
- }
+ additionalResults.append(
+ SmallVector<Value>(extractOp.getDynamicPosition()));
+ additionalResultTypes.append(
+ SmallVector<Type>(extractOp.getDynamicPosition().getTypes()));
+
Location loc = extractOp.getLoc();
SmallVector<size_t> newRetIndices;
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
@@ -1369,38 +1356,35 @@ struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
if (is0dOrVec1Extract) {
Value newExtract;
if (extractSrcType.getRank() == 1) {
- newExtract = rewriter.create<vector::ExtractElementOp>(
- loc, distributedVec,
- rewriter.create<arith::ConstantIndexOp>(loc, 0));
-
+ newExtract = rewriter.create<vector::ExtractOp>(loc, distributedVec, 0);
} else {
- newExtract =
- rewriter.create<vector::ExtractElementOp>(loc, distributedVec);
+ newExtract = rewriter.create<vector::ExtractOp>(loc, distributedVec,
+ ArrayRef<int64_t>{});
}
rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
newExtract);
return success();
}
+ int64_t staticPos = extractOp.getStaticPosition()[0];
+ OpFoldResult pos = ShapedType::isDynamic(staticPos)
+ ? (newWarpOp->getResult(newRetIndices[1]))
+ : OpFoldResult(rewriter.getIndexAttr(staticPos));
// 1d extract: Distribute the source vector. One lane extracts and shuffles
// the value to all other lanes.
int64_t elementsPerLane = distributedVecType.getShape()[0];
AffineExpr sym0 = getAffineSymbolExpr(0, rewriter.getContext());
// tid of extracting thread: pos / elementsPerLane
- Value broadcastFromTid = rewriter.create<affine::AffineApplyOp>(
- loc, sym0.ceilDiv(elementsPerLane),
- newWarpOp->getResult(newRetIndices[1]));
+ Value broadcastFromTid = affine::makeComposedAffineApply(
+ rewriter, loc, sym0.ceilDiv(elementsPerLane), pos);
// Extract at position: pos % elementsPerLane
- Value pos =
+ Value newPos =
elementsPerLane == 1
? rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult()
- : rewriter
- .create<affine::AffineApplyOp>(
- loc, sym0 % elementsPerLane,
- newWarpOp->getResult(newRetIndices[1]))
- .getResult();
+ : affine::makeComposedAffineApply(rewriter, loc,
+ sym0 % elementsPerLane, pos);
Value extracted =
- rewriter.create<vector::ExtractElementOp>(loc, distributedVec, pos);
+ rewriter.create<vector::ExtractOp>(loc, distributedVec, newPos);
// Shuffle the extracted value to all lanes.
Value shuffled = warpShuffleFromIdxFn(
@@ -1413,31 +1397,60 @@ struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
WarpShuffleFromIdxFn warpShuffleFromIdxFn;
};
-struct WarpOpInsertElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
+/// Pattern to convert vector.extractelement to vector.extract.
+struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
+ WarpOpExtractElement(MLIRContext *ctx, PatternBenefit b = 1)
+ : OpRewritePattern<WarpExecuteOnLane0Op>(ctx, b) {}
+ LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
+ PatternRewriter &rewriter) const override {
+ OpOperand *operand =
+ getWarpResult(warpOp, llvm::IsaPred<vector::ExtractElementOp>);
+ if (!operand)
+ return failure();
+ auto extractOp = operand->get().getDefiningOp<vector::ExtractElementOp>();
+ rewriter.setInsertionPoint(extractOp);
+ if (auto pos = extractOp.getPosition()) {
+ rewriter.replaceOpWithNewOp<vector::ExtractOp>(
+ extractOp, extractOp.getVector(), pos);
+ } else {
+ rewriter.replaceOpWithNewOp<vector::ExtractOp>(
+ extractOp, extractOp.getVector(), ArrayRef<int64_t>{});
+ }
+ return success();
+ }
+};
+
+/// Pattern to move out vector.insert with a scalar input.
+/// Only supports 1-D and 0-D destinations for now.
+struct WarpOpInsertScalar : public OpRewritePattern<WarpExecuteOnLane0Op> {
using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
- OpOperand *operand =
- getWarpResult(warpOp, llvm::IsaPred<vector::InsertElementOp>);
+ OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<vector::InsertOp>);
if (!operand)
return failure();
unsigned int operandNumber = operand->getOperandNumber();
- auto insertOp = operand->get().getDefiningOp<vector::InsertElementOp>();
+ auto insertOp = operand->get().getDefiningOp<vector::InsertOp>();
VectorType vecType = insertOp.getDestVectorType();
VectorType distrType =
cast<VectorType>(warpOp.getResult(operandNumber).getType());
- bool hasPos = static_cast<bool>(insertOp.getPosition());
+
+ // Only supports 1-D or 0-D destinations for now.
+ if (vecType.getRank() > 1) {
+ return rewriter.notifyMatchFailure(
+ insertOp, "only 0-D or 1-D source supported for now");
+ }
// Yield destination vector, source scalar and position from warp op.
SmallVector<Value> additionalResults{insertOp.getDest(),
insertOp.getSource()};
SmallVector<Type> additionalResultTypes{distrType,
insertOp.getSource().getType()};
- if (hasPos) {
- additionalResults.push_back(insertOp.getPosition());
- additionalResultTypes.push_back(insertOp.getPosition().getType());
- }
+ additionalResults.append(SmallVector<Value>(insertOp.getDynamicPosition()));
+ additionalResultTypes.append(
+ SmallVector<Type>(insertOp.getDynamicPosition().getTypes()));
+
Location loc = insertOp.getLoc();
SmallVector<size_t> newRetIndices;
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
@@ -1446,13 +1459,27 @@ struct WarpOpInsertElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
rewriter.setInsertionPointAfter(newWarpOp);
Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
Value newSource = newWarpOp->getResult(newRetIndices[1]);
- Value newPos = hasPos ? newWarpOp->getResult(newRetIndices[2]) : Value();
rewriter.setInsertionPointAfter(newWarpOp);
+ OpFoldResult pos;
+ if (vecType.getRank() != 0) {
+ int64_t staticPos = insertOp.getStaticPosition()[0];
+ pos = ShapedType::isDynamic(staticPos)
+ ? (newWarpOp->getResult(newRetIndices[2]))
+ : OpFoldResult(rewriter.getIndexAttr(staticPos));
+ }
+
+ // This condition is always true for 0-d vectors.
if (vecType == distrType) {
- // Broadcast: Simply move the vector.inserelement op out.
- Value newInsert = rewriter.create<vector::InsertElementOp>(
- loc, newSource, distributedVec, newPos);
+ Value newInsert;
+ if (pos) {
+ newInsert = rewriter.create<vector::InsertOp>(loc, newSource,
+ distributedVec, pos);
+ } else {
+ newInsert = rewriter.create<vector::InsertOp>(
+ loc, newSource, distributedVec, ArrayRef<int64_t>{});
+ }
+ // Broadcast: Simply move the vector.insert op out.
rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
newInsert);
return success();
@@ -1462,16 +1489,11 @@ struct WarpOpInsertElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
int64_t elementsPerLane = distrType.getShape()[0];
AffineExpr sym0 = getAffineSymbolExpr(0, rewriter.getContext());
// tid of extracting thread: pos / elementsPerLane
- Value insertingLane = rewriter.create<affine::AffineApplyOp>(
- loc, sym0.ceilDiv(elementsPerLane), newPos);
+ Value insertingLane = affine::makeComposedAffineApply(
+ rewriter, loc, sym0.ceilDiv(elementsPerLane), pos);
// Insert position: pos % elementsPerLane
- Value pos =
- elementsPerLane == 1
- ? rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult()
- : rewriter
- .create<affine::AffineApplyOp>(loc, sym0 % elementsPerLane,
- newPos)
- .getResult();
+ OpFoldResult newPos = affine::makeComposedFoldedAffineApply(
+ rewriter, loc, sym0 % elementsPerLane, pos);
Value isInsertingLane = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::eq, newWarpOp.getLaneid(), insertingLane);
Value newResult =
@@ -1480,8 +1502,8 @@ struct WarpOpInsertElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
loc, isInsertingLane,
/*thenBuilder=*/
[&](OpBuilder &builder, Location loc) {
- Value newInsert = builder.create<vector::InsertElementOp>(
- loc, newSource, distributedVec, pos);
+ Value newInsert = builder.create<vector::InsertOp>(
+ loc, newSource, distributedVec, newPos);
builder.create<scf::YieldOp>(loc, newInsert);
},
/*elseBuilder=*/
@@ -1506,25 +1528,13 @@ struct WarpOpInsert : public OpRewritePattern<WarpExecuteOnLane0Op> {
auto insertOp = operand->get().getDefiningOp<vector::InsertOp>();
Location loc = insertOp.getLoc();
- // "vector.insert %v, %v[] : ..." can be canonicalized to %v.
- if (insertOp.getNumIndices() == 0)
+ // For 1-d or 0-d destination cases, we rely on WarpOpInsertScalar pattern.
+ if (insertOp.getDestVectorType().getRank() <= 1) {
return failure();
-
- // Rewrite vector.insert with 1d dest to vector.insertelement.
- if (insertOp.getDestVectorType().getRank() == 1) {
- if (insertOp.hasDynamicPosition())
- // TODO: Dinamic position not supported yet.
- return failure();
-
- assert(insertOp.getNumIndices() == 1 && "expected 1 index");
- int64_t pos = insertOp.getStaticPosition()[0];
- rewriter.setInsertionPoint(insertOp);
- rewriter.replaceOpWithNewOp<vector::InsertElementOp>(
- insertOp, insertOp.getSource(), insertOp.getDest(),
- rewriter.create<arith::ConstantIndexOp>(loc, pos));
- return success();
}
+ // All following cases are 2d or higher dimensional source vectors.
+
if (warpOp.getResult(operandNumber).getType() == operand->get().getType()) {
// There is no distribution, this is a broadcast. Simply move the insert
// out of the warp op.
@@ -1620,9 +1630,32 @@ struct WarpOpInsert : public OpRewritePattern<WarpExecuteOnLane0Op> {
}
};
+struct WarpOpInsertElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
+ using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
+ PatternRewriter &rewriter) const override {
+ OpOperand *operand =
+ getWarpResult(warpOp, llvm::IsaPred<vector::InsertElementOp>);
+ if (!operand)
+ return failure();
+ auto insertOp = operand->get().getDefiningOp<vector::InsertElementOp>();
+ rewriter.setInsertionPoint(insertOp);
+ if (auto pos = insertOp.getPosition()) {
+ rewriter.replaceOpWithNewOp<vector::InsertOp>(
+ insertOp, insertOp.getSource(), insertOp.getDest(), pos);
+ } else {
+ rewriter.replaceOpWithNewOp<vector::InsertOp>(
+ insertOp, insertOp.getSource(), insertOp.getDest(),
+ ArrayRef<int64_t>{});
+ }
+ return success();
+ }
+};
+
/// Sink scf.for region out of WarpExecuteOnLane0Op. This can be done only if
-/// the scf.ForOp is the last operation in the region so that it doesn't change
-/// the order of execution. This creates a new scf.for region after the
+/// the scf.ForOp is the last operation in the region so that it doesn't
+/// change the order of execution. This creates a new scf.for region after the
/// WarpExecuteOnLane0Op. The new scf.for region will contain a new
/// WarpExecuteOnLane0Op region. Example:
/// ```
@@ -1668,8 +1701,8 @@ struct WarpOpScfForOp : public OpRewritePattern<WarpExecuteOnLane0Op> {
if (!forOp)
return failure();
// Collect Values that come from the warp op but are outside the forOp.
- // Those Value needs to be returned by the original warpOp and passed to the
- // new op.
+ // Those Value needs to be returned by the original warpOp and passed to
+ // the new op.
llvm::SmallSetVector<Value, 32> escapingValues;
SmallVector<Type> inputTypes;
SmallVector<Type> distTypes;
@@ -1715,8 +1748,8 @@ struct WarpOpScfForOp : public OpRewritePattern<WarpExecuteOnLane0Op> {
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPointAfter(newWarpOp);
- // Create a new for op outside the region with a WarpExecuteOnLane0Op region
- // inside.
+ // Create a new for op outside the region with a WarpExecuteOnLane0Op
+ // region inside.
auto newForOp = rewriter.create<scf::ForOp>(
forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
forOp.getStep(), newOperands);
@@ -1778,8 +1811,8 @@ struct WarpOpScfForOp : public OpRewritePattern<WarpExecuteOnLane0Op> {
};
/// A pattern that extracts vector.reduction ops from a WarpExecuteOnLane0Op.
-/// The vector is reduced in parallel. Currently limited to vector size matching
-/// the warpOp size. E.g.:
+/// The vector is reduced in parallel. Currently limited to vector size
+/// matching the warpOp size. E.g.:
/// ```
/// %r = vector_ext.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
/// %0 = "some_def"() : () -> (vector<32xf32>)
@@ -1880,13 +1913,13 @@ void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit,
PatternBenefit readBenefit) {
patterns.add<WarpOpTransferRead>(patterns.getContext(), readBenefit);
- patterns
- .add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
- WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand, WarpOpConstant,
- WarpOpInsertElement, WarpOpInsert, WarpOpCreateMask>(
- patterns.getContext(), benefit);
- patterns.add<WarpOpExtractElement>(patterns.getContext(),
- warpShuffleFromIdxFn, benefit);
+ patterns.add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
+ WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand,
+ WarpOpConstant, WarpOpExtractElement, WarpOpInsertElement,
+ WarpOpInsertScalar, WarpOpInsert, WarpOpCreateMask>(
+ patterns.getContext(), benefit);
+ patterns.add<WarpOpExtractScalar>(patterns.getContext(), warpShuffleFromIdxFn,
+ benefit);
patterns.add<WarpOpScfForOp>(patterns.getContext(), distributionMapFn,
benefit);
}
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 3acddd6e54639e..b4491812dc26cb 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -783,7 +783,7 @@ func.func @warp_constant(%laneid: index) -> (vector<1xf32>) {
// CHECK-PROP: %[[V:.*]] = "some_def"() : () -> vector<64xf32>
// CHECK-PROP: vector.yield %[[V]] : vector<64xf32>
// CHECK-PROP: }
-// CHECK-PROP: %[[E:.*]] = vector.extractelement %[[R]][%[[C1]] : index] : vector<2xf32>
+// CHECK-PROP: %[[E:.*]] = vector.extract %[[R]][%[[C1]]] : f32 from vector<2xf32>
// CHECK-PROP: %[[SHUFFLED:.*]], %{{.*}} = gpu.shuffle idx %[[E]], %[[C5_I32]]
// CHECK-PROP: return %[[SHUFFLED]] : f32
func.func @vector_extract_1d(%laneid: index) -> (f32) {
@@ -874,7 +874,7 @@ func.func @vector_extract_3d(%laneid: index) -> (vector<4x96xf32>) {
// CHECK-PROP: %[[V:.*]] = "some_def"() : () -> vector<f32>
// CHECK-PROP: vector.yield %[[V]] : vector<f32>
// CHECK-PROP: }
-// CHECK-PROP: %[[E:.*]] = vector.extractelement %[[R]][] : vector<f32>
+// CHECK-PROP: %[[E:.*]] = vector.extract %[[R]][] : f32 from vector<f32>
// CHECK-PROP: return %[[E]] : f32
func.func @vector_extractelement_0d(%laneid: index) -> (f32) {
%r = vector.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
@@ -888,12 +888,11 @@ func.func @vector_extractelement_0d(%laneid: index) -> (f32) {
// -----
// CHECK-PROP-LABEL: func.func @vector_extractelement_1element(
-// CHECK-PROP: %[[C0:.*]] = arith.constant 0 : index
// CHECK-PROP: %[[R:.*]] = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<1xf32>) {
// CHECK-PROP: %[[V:.*]] = "some_def"() : () -> vector<1xf32>
// CHECK-PROP: vector.yield %[[V]] : vector<1xf32>
// CHECK-PROP: }
-// CHECK-PROP: %[[E:.*]] = vector.extractelement %[[R]][%[[C0]] : index] : vector<1xf32>
+// CHECK-PROP: %[[E:.*]] = vector.extract %[[R]][0] : f32 from vector<1xf32>
// CHECK-PROP: return %[[E]] : f32
func.func @vector_extractelement_1element(%laneid: index) -> (f32) {
%r = vector.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
@@ -918,7 +917,7 @@ func.func @vector_extractelement_1element(%laneid: index) -> (f32) {
// CHECK-PROP: }
// CHECK-PROP: %[[FROM_LANE:.*]] = affine.apply #[[$map]]()[%[[POS]]]
// CHECK-PROP: %[[DISTR_POS:.*]] = affine.apply #[[$map1]]()[%[[POS]]]
-// CHECK-PROP: %[[EXTRACTED:.*]] = vector.extractelement %[[W]][%[[DISTR_POS]] : index] : vector<3xf32>
+// CHECK-PROP: %[[EXTRACTED:.*]] = vector.extract %[[W]][%[[DISTR_POS]]] : f32 from vector<3xf32>
// CHECK-PROP: %[[FROM_LANE_I32:.*]] = arith.index_cast %[[FROM_LANE]] : index to i32
// CHECK-PROP: %[[SHUFFLED:.*]], %{{.*}} = gpu.shuffle idx %[[EXTRACTED]], %[[FROM_LANE_I32]], %[[C32]] : f32
// CHECK-PROP: return %[[SHUFFLED]]
@@ -938,7 +937,7 @@ func.func @vector_extractelement_1d(%laneid: index, %pos: index) -> (f32) {
// CHECK-PROP-LABEL: func.func @vector_extractelement_1d_index(
// CHECK-PROP: vector.warp_execute_on_lane_0(%{{.*}})[32] -> (index) {
// CHECK-PROP: "some_def"
-// CHECK-PROP: vector.extractelement
+// CHECK-PROP: vector.extract
// CHECK-PROP: vector.yield {{.*}} : index
// CHECK-PROP: }
func.func @vector_extractelement_1d_index(%laneid: index, %pos: index) -> (index) {
@@ -1151,7 +1150,7 @@ func.func @warp_execute_nd_distribute(%laneid: index, %v0: vector<1x64x1xf32>, %
// CHECK-PROP: %[[INSERTING_POS:.*]] = affine.apply #[[$MAP1]]()[%[[POS]]]
// CHECK-PROP: %[[SHOULD_INSERT:.*]] = arith.cmpi eq, %[[LANEID]], %[[INSERTING_LANE]] : index
// CHECK-PROP: %[[R:.*]] = scf.if %[[SHOULD_INSERT]] -> (vector<3xf32>) {
-// CHECK-PROP: %[[INSERT:.*]] = vector.insertelement %[[W]]#1, %[[W]]#0[%[[INSERTING_POS]] : index]
+// CHECK-PROP: %[[INSERT:.*]] = vector.insert %[[W]]#1, %[[W]]#0 [%[[INSERTING_POS]]]
// CHECK-PROP: scf.yield %[[INSERT]]
// CHECK-PROP: } else {
// CHECK-PROP: scf.yield %[[W]]#0
@@ -1175,7 +1174,7 @@ func.func @vector_insertelement_1d(%laneid: index, %pos: index) -> (vector<3xf32
// CHECK-PROP: %[[VEC:.*]] = "some_def"
// CHECK-PROP: %[[VAL:.*]] = "another_def"
// CHECK-PROP: vector.yield %[[VEC]], %[[VAL]]
-// CHECK-PROP: vector.insertelement %[[W]]#1, %[[W]]#0[%[[POS]] : index] : vector<96xf32>
+// CHECK-PROP: vector.insert %[[W]]#1, %[[W]]#0 [%[[POS]]] : f32 into vector<96xf32>
func.func @vector_insertelement_1d_broadcast(%laneid: index, %pos: index) -> (vector<96xf32>) {
%r = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<96xf32>) {
%0 = "some_def"() : () -> (vector<96xf32>)
@@ -1193,7 +1192,7 @@ func.func @vector_insertelement_1d_broadcast(%laneid: index, %pos: index) -> (ve
// CHECK-PROP: %[[VEC:.*]] = "some_def"
// CHECK-PROP: %[[VAL:.*]] = "another_def"
// CHECK-PROP: vector.yield %[[VEC]], %[[VAL]]
-// CHECK-PROP: vector.insertelement %[[W]]#1, %[[W]]#0[] : vector<f32>
+// CHECK-PROP: vector.insert %[[W]]#1, %[[W]]#0 [] : f32 into vector<f32>
func.func @vector_insertelement_0d(%laneid: index) -> (vector<f32>) {
%r = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<f32>) {
%0 = "some_def"() : () -> (vector<f32>)
@@ -1208,7 +1207,6 @@ func.func @vector_insertelement_0d(%laneid: index) -> (vector<f32>) {
// CHECK-PROP-LABEL: func @vector_insert_1d(
// CHECK-PROP-SAME: %[[LANEID:.*]]: index
-// CHECK-PROP-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-PROP-DAG: %[[C26:.*]] = arith.constant 26 : index
// CHECK-PROP: %[[W:.*]]:2 = vector.warp_execute_on_lane_0{{.*}} -> (vector<3xf32>, f32)
// CHECK-PROP: %[[VEC:.*]] = "some_def"
@@ -1216,7 +1214,7 @@ func.func @vector_insertelement_0d(%laneid: index) -> (vector<f32>) {
// CHECK-PROP: vector.yield %[[VEC]], %[[VAL]]
// CHECK-PROP: %[[SHOULD_INSERT:.*]] = arith.cmpi eq, %[[LANEID]], %[[C26]]
// CHECK-PROP: %[[R:.*]] = scf.if %[[SHOULD_INSERT]] -> (vector<3xf32>) {
-// CHECK-PROP: %[[INSERT:.*]] = vector.insertelement %[[W]]#1, %[[W]]#0[%[[C1]] : index]
+// CHECK-PROP: %[[INSERT:.*]] = vector.insert %[[W]]#1, %[[W]]#0 [1]
// CHECK-PROP: scf.yield %[[INSERT]]
// CHECK-PROP: } else {
// CHECK-PROP: scf.yield %[[W]]#0
@@ -1316,14 +1314,14 @@ func.func @vector_insert_2d_broadcast(%laneid: index) -> (vector<4x96xf32>) {
// CHECK-PROP: %[[GATHER:.*]] = vector.gather %[[AR1]][{{.*}}]
// CHECK-PROP: %[[EXTRACT:.*]] = vector.extract %[[GATHER]][0] : vector<64xi32> from vector<1x64xi32>
// CHECK-PROP: %[[CAST:.*]] = arith.index_cast %[[EXTRACT]] : vector<64xi32> to vector<64xindex>
-// CHECK-PROP: %[[EXTRACTELT:.*]] = vector.extractelement %[[CAST]][{{.*}}: i32] : vector<64xindex>
+// CHECK-PROP: %[[EXTRACTELT:.*]] = vector.extract %[[CAST]][{{.*}}] : index from vector<64xindex>
// CHECK-PROP: vector.yield %[[EXTRACTELT]] : index
// CHECK-PROP: %[[APPLY:.*]] = affine.apply #[[$MAP]]()[%[[THREADID]]]
// CHECK-PROP: %[[TRANSFERREAD:.*]] = vector.transfer_read %[[AR2]][%[[C0]], %[[W]], %[[APPLY]]],
// CHECK-PROP: return %[[TRANSFERREAD]]
func.func @transfer_read_prop_operands(%in2: vector<1x2xindex>, %ar1 : memref<1x4x2xi32>, %ar2 : memref<1x4x1024xf32>)-> vector<2xf32> {
%0 = gpu.thread_id x
- %c0_i32 = arith.constant 0 : i32
+ %c0_i32 = arith.constant 0 : index
%c0 = arith.constant 0 : index
%cst = arith.constant dense<0> : vector<1x64xi32>
%cst_0 = arith.constant dense<true> : vector<1x64xi1>
@@ -1336,7 +1334,7 @@ func.func @transfer_read_prop_operands(%in2: vector<1x2xindex>, %ar1 : memref<1
%28 = vector.gather %ar1[%c0, %c0, %c0] [%arg4], %cst_0, %cst : memref<1x4x2xi32>, vector<1x64xindex>, vector<1x64xi1>, vector<1x64xi32> into vector<1x64xi32>
%29 = vector.extract %28[0] : vector<64xi32> from vector<1x64xi32>
%30 = arith.index_cast %29 : vector<64xi32> to vector<64xindex>
- %36 = vector.extractelement %30[%c0_i32 : i32] : vector<64xindex>
+ %36 = vector.extractelement %30[%c0_i32 : index] : vector<64xindex>
%37 = vector.transfer_read %ar2[%c0, %36, %c0], %cst_6 {in_bounds = [true]} : memref<1x4x1024xf32>, vector<64xf32>
vector.yield %37 : vector<64xf32>
}
>From 823909cab5d5732f9eda7c34981cd6a98df4f1fd Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Mon, 18 Nov 2024 10:19:39 +0000
Subject: [PATCH 2/2] Address comments
---
.../Vector/Transforms/VectorDistribute.cpp | 41 ++++++++-----------
1 file changed, 17 insertions(+), 24 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index bad26b0cc4b383..dc5eb2527f949a 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1355,12 +1355,9 @@ struct WarpOpExtractScalar : public OpRewritePattern<WarpExecuteOnLane0Op> {
// All lanes extract the scalar.
if (is0dOrVec1Extract) {
Value newExtract;
- if (extractSrcType.getRank() == 1) {
- newExtract = rewriter.create<vector::ExtractOp>(loc, distributedVec, 0);
- } else {
- newExtract = rewriter.create<vector::ExtractOp>(loc, distributedVec,
- ArrayRef<int64_t>{});
- }
+ SmallVector<int64_t> indices(extractSrcType.getRank(), 0);
+ newExtract =
+ rewriter.create<vector::ExtractOp>(loc, distributedVec, indices);
rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
newExtract);
return success();
@@ -1408,14 +1405,13 @@ struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
if (!operand)
return failure();
auto extractOp = operand->get().getDefiningOp<vector::ExtractElementOp>();
- rewriter.setInsertionPoint(extractOp);
+ SmallVector<OpFoldResult> indices;
if (auto pos = extractOp.getPosition()) {
- rewriter.replaceOpWithNewOp<vector::ExtractOp>(
- extractOp, extractOp.getVector(), pos);
- } else {
- rewriter.replaceOpWithNewOp<vector::ExtractOp>(
- extractOp, extractOp.getVector(), ArrayRef<int64_t>{});
+ indices.push_back(pos);
}
+ rewriter.setInsertionPoint(extractOp);
+ rewriter.replaceOpWithNewOp<vector::ExtractOp>(
+ extractOp, extractOp.getVector(), indices);
return success();
}
};
@@ -1472,13 +1468,12 @@ struct WarpOpInsertScalar : public OpRewritePattern<WarpExecuteOnLane0Op> {
// This condition is always true for 0-d vectors.
if (vecType == distrType) {
Value newInsert;
+ SmallVector<OpFoldResult> indices;
if (pos) {
- newInsert = rewriter.create<vector::InsertOp>(loc, newSource,
- distributedVec, pos);
- } else {
- newInsert = rewriter.create<vector::InsertOp>(
- loc, newSource, distributedVec, ArrayRef<int64_t>{});
+ indices.push_back(pos);
}
+ newInsert = rewriter.create<vector::InsertOp>(loc, newSource,
+ distributedVec, indices);
// Broadcast: Simply move the vector.insert op out.
rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
newInsert);
@@ -1640,15 +1635,13 @@ struct WarpOpInsertElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
if (!operand)
return failure();
auto insertOp = operand->get().getDefiningOp<vector::InsertElementOp>();
- rewriter.setInsertionPoint(insertOp);
+ SmallVector<OpFoldResult> indices;
if (auto pos = insertOp.getPosition()) {
- rewriter.replaceOpWithNewOp<vector::InsertOp>(
- insertOp, insertOp.getSource(), insertOp.getDest(), pos);
- } else {
- rewriter.replaceOpWithNewOp<vector::InsertOp>(
- insertOp, insertOp.getSource(), insertOp.getDest(),
- ArrayRef<int64_t>{});
+ indices.push_back(pos);
}
+ rewriter.setInsertionPoint(insertOp);
+ rewriter.replaceOpWithNewOp<vector::InsertOp>(
+ insertOp, insertOp.getSource(), insertOp.getDest(), indices);
return success();
}
};
More information about the Mlir-commits mailing list