[Mlir-commits] [mlir] 73ce971 - [mlir][vector] Distribute vector.insertelement op
Matthias Springer
llvmlistbot at llvm.org
Mon Jan 9 07:44:56 PST 2023
Author: Matthias Springer
Date: 2023-01-09T16:41:08+01:00
New Revision: 73ce971c630ff87e16788091097b3232847acc48
URL: https://github.com/llvm/llvm-project/commit/73ce971c630ff87e16788091097b3232847acc48
DIFF: https://github.com/llvm/llvm-project/commit/73ce971c630ff87e16788091097b3232847acc48.diff
LOG: [mlir][vector] Distribute vector.insertelement op
In case of a distribution, only one lane inserts the scalar value. In case of a broadcast, every lane inserts the scalar.
Differential Revision: https://reviews.llvm.org/D137929
Added:
Modified:
mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
mlir/test/Dialect/Vector/vector-warp-distribute.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 60ca036228dc..df7b240a4d5c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1033,8 +1033,13 @@ struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
Value broadcastFromTid = rewriter.create<AffineApplyOp>(
loc, sym0.ceilDiv(elementsPerLane), extractOp.getPosition());
// Extract at position: pos % elementsPerLane
- Value pos = rewriter.create<AffineApplyOp>(loc, sym0 % elementsPerLane,
- extractOp.getPosition());
+ Value pos =
+ elementsPerLane == 1
+ ? rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult()
+ : rewriter
+ .create<AffineApplyOp>(loc, sym0 % elementsPerLane,
+ extractOp.getPosition())
+ .getResult();
Value extracted =
rewriter.create<vector::ExtractElementOp>(loc, distributedVec, pos);
@@ -1049,6 +1054,85 @@ struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
WarpShuffleFromIdxFn warpShuffleFromIdxFn;
};
+struct WarpOpInsertElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
+ using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
+ /// Rewrites a warp op result defined by vector.insertelement: the insert is re-created after the warp op (under scf.if when the vector is distributed).
+ LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
+ PatternRewriter &rewriter) const override {
+ OpOperand *operand = getWarpResult(
+ warpOp, [](Operation *op) { return isa<vector::InsertElementOp>(op); });
+ if (!operand)
+ return failure(); // getWarpResult found no operand matching the predicate.
+ unsigned int operandNumber = operand->getOperandNumber();
+ auto insertOp = operand->get().getDefiningOp<vector::InsertElementOp>();
+ VectorType vecType = insertOp.getDestVectorType();
+ VectorType distrType =
+ warpOp.getResult(operandNumber).getType().cast<VectorType>();
+ bool hasPos = static_cast<bool>(insertOp.getPosition()); // False when the op has no position operand.
+
+ // Yield destination vector (typed as distrType), source scalar and, if present, position from warp op.
+ SmallVector<Value> additionalResults{insertOp.getDest(),
+ insertOp.getSource()};
+ SmallVector<Type> additionalResultTypes{distrType,
+ insertOp.getSource().getType()};
+ if (hasPos) {
+ additionalResults.push_back(insertOp.getPosition());
+ additionalResultTypes.push_back(insertOp.getPosition().getType());
+ }
+ Location loc = insertOp.getLoc();
+ SmallVector<size_t> newRetIndices;
+ WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+ rewriter, warpOp, additionalResults, additionalResultTypes,
+ newRetIndices);
+ rewriter.setInsertionPointAfter(newWarpOp);
+ Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
+ Value newSource = newWarpOp->getResult(newRetIndices[1]);
+ Value newPos = hasPos ? newWarpOp->getResult(newRetIndices[2]) : Value();
+ rewriter.setInsertionPointAfter(newWarpOp); // NOTE(review): redundant; the same insertion point was already set above.
+
+ if (vecType == distrType) {
+ // Broadcast: Simply move the vector.insertelement op out.
+ Value newInsert = rewriter.create<vector::InsertElementOp>(
+ loc, newSource, distributedVec, newPos);
+ newWarpOp->getResult(operandNumber).replaceAllUsesWith(newInsert);
+ return success();
+ }
+ // NOTE(review): newPos is a null Value when hasPos is false; presumably that case always takes the broadcast path above - confirm.
+ // This is a distribution. Only one lane should insert.
+ int64_t elementsPerLane = distrType.getShape()[0];
+ AffineExpr sym0 = getAffineSymbolExpr(0, rewriter.getContext());
+ // tid of inserting thread: pos / elementsPerLane
+ Value insertingLane = rewriter.create<AffineApplyOp>(
+ loc, sym0.ceilDiv(elementsPerLane), newPos); // NOTE(review): ceilDiv rounds up, unlike the "pos / elementsPerLane" stated above - confirm floorDiv is not intended.
+ // Insert position: pos % elementsPerLane (constant 0 when each lane holds a single element)
+ Value pos =
+ elementsPerLane == 1
+ ? rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult()
+ : rewriter
+ .create<AffineApplyOp>(loc, sym0 % elementsPerLane, newPos)
+ .getResult();
+ Value isInsertingLane = rewriter.create<arith::CmpIOp>(
+ loc, arith::CmpIPredicate::eq, newWarpOp.getLaneid(), insertingLane);
+ Value newResult =
+ rewriter
+ .create<scf::IfOp>(
+ loc, distrType, isInsertingLane,
+ /*thenBuilder=*/
+ [&](OpBuilder &builder, Location loc) {
+ Value newInsert = builder.create<vector::InsertElementOp>(
+ loc, newSource, distributedVec, pos);
+ builder.create<scf::YieldOp>(loc, newInsert);
+ },
+ /*elseBuilder=*/
+ [&](OpBuilder &builder, Location loc) {
+ builder.create<scf::YieldOp>(loc, distributedVec); // Non-inserting lanes forward distributedVec unchanged.
+ })
+ .getResult(0);
+ newWarpOp->getResult(operandNumber).replaceAllUsesWith(newResult);
+ return success();
+ }
+};
+
/// Sink scf.for region out of WarpExecuteOnLane0Op. This can be done only if
/// the scf.ForOp is the last operation in the region so that it doesn't change
/// the order of execution. This creates a new scf.for region after the
@@ -1303,7 +1387,8 @@ void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit) {
patterns.add<WarpOpElementwise, WarpOpTransferRead, WarpOpDeadResult,
WarpOpBroadcast, WarpOpExtract, WarpOpForwardOperand,
- WarpOpConstant>(patterns.getContext(), benefit);
+ WarpOpConstant, WarpOpInsertElement>(patterns.getContext(),
+ benefit);
patterns.add<WarpOpExtractElement>(patterns.getContext(),
warpShuffleFromIdxFn, benefit);
patterns.add<WarpOpScfForOp>(patterns.getContext(), distributionMapFn,
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 5a238c57c933..b19c3cd96a26 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -930,3 +930,67 @@ func.func @warp_execute_nd_distribute(%laneid: index, %v0: vector<1x64x1xf32>, %
// CHECK-SCF-IF: return %[[R0]], %[[R1]] : vector<1x64x1xf32>, vector<1x2x128xf32>
return %r#0, %r#1 : vector<1x64x1xf32>, vector<1x2x128xf32>
}
+
+// -----
+
+// CHECK-PROP: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 ceildiv 3)>
+// CHECK-PROP: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 mod 3)>
+// CHECK-PROP-LABEL: func @vector_insertelement_1d(
+// CHECK-PROP-SAME: %[[LANEID:.*]]: index, %[[POS:.*]]: index
+// CHECK-PROP: %[[W:.*]]:2 = vector.warp_execute_on_lane_0{{.*}} -> (vector<3xf32>, f32)
+// CHECK-PROP: %[[INSERTING_LANE:.*]] = affine.apply #[[$MAP]]()[%[[POS]]]
+// CHECK-PROP: %[[INSERTING_POS:.*]] = affine.apply #[[$MAP1]]()[%[[POS]]]
+// CHECK-PROP: %[[SHOULD_INSERT:.*]] = arith.cmpi eq, %[[LANEID]], %[[INSERTING_LANE]] : index
+// CHECK-PROP: %[[R:.*]] = scf.if %[[SHOULD_INSERT]] -> (vector<3xf32>) {
+// CHECK-PROP: %[[INSERT:.*]] = vector.insertelement %[[W]]#1, %[[W]]#0[%[[INSERTING_POS]] : index]
+// CHECK-PROP: scf.yield %[[INSERT]]
+// CHECK-PROP: } else {
+// CHECK-PROP: scf.yield %[[W]]#0
+// CHECK-PROP: }
+// CHECK-PROP: return %[[R]]
+func.func @vector_insertelement_1d(%laneid: index, %pos: index) -> (vector<3xf32>) {
+ %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<3xf32>) {
+ %0 = "some_def"() : () -> (vector<96xf32>)
+ %f = "another_def"() : () -> (f32)
+ %1 = vector.insertelement %f, %0[%pos : index] : vector<96xf32>
+ vector.yield %1 : vector<96xf32> // Yielded as vector<96xf32>, returned as vector<3xf32>: the distribution case.
+ }
+ return %r : vector<3xf32>
+}
+
+// -----
+
+// CHECK-PROP-LABEL: func @vector_insertelement_1d_broadcast(
+// CHECK-PROP-SAME: %[[LANEID:.*]]: index, %[[POS:.*]]: index
+// CHECK-PROP: %[[W:.*]]:2 = vector.warp_execute_on_lane_0{{.*}} -> (vector<96xf32>, f32)
+// CHECK-PROP: %[[VEC:.*]] = "some_def"
+// CHECK-PROP: %[[VAL:.*]] = "another_def"
+// CHECK-PROP: vector.yield %[[VEC]], %[[VAL]]
+// CHECK-PROP: vector.insertelement %[[W]]#1, %[[W]]#0[%[[POS]] : index] : vector<96xf32>
+func.func @vector_insertelement_1d_broadcast(%laneid: index, %pos: index) -> (vector<96xf32>) {
+ %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<96xf32>) {
+ %0 = "some_def"() : () -> (vector<96xf32>)
+ %f = "another_def"() : () -> (f32)
+ %1 = vector.insertelement %f, %0[%pos : index] : vector<96xf32>
+ vector.yield %1 : vector<96xf32> // Yielded type equals the result type: the broadcast case.
+ }
+ return %r : vector<96xf32>
+}
+
+// -----
+
+// CHECK-PROP-LABEL: func @vector_insertelement_0d(
+// CHECK-PROP: %[[W:.*]]:2 = vector.warp_execute_on_lane_0{{.*}} -> (vector<f32>, f32)
+// CHECK-PROP: %[[VEC:.*]] = "some_def"
+// CHECK-PROP: %[[VAL:.*]] = "another_def"
+// CHECK-PROP: vector.yield %[[VEC]], %[[VAL]]
+// CHECK-PROP: vector.insertelement %[[W]]#1, %[[W]]#0[] : vector<f32>
+func.func @vector_insertelement_0d(%laneid: index) -> (vector<f32>) {
+ %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<f32>) {
+ %0 = "some_def"() : () -> (vector<f32>)
+ %f = "another_def"() : () -> (f32)
+ %1 = vector.insertelement %f, %0[] : vector<f32> // 0-D insert: no position operand is given.
+ vector.yield %1 : vector<f32>
+ }
+ return %r : vector<f32>
+}
More information about the Mlir-commits
mailing list