[Mlir-commits] [mlir] f48ce52 - [mlir][vector] Pattern to clean up vector.extract during distribution
Thomas Raoux
llvmlistbot at llvm.org
Thu Jul 14 10:07:48 PDT 2022
Author: Thomas Raoux
Date: 2022-07-14T17:07:32Z
New Revision: f48ce52c4c2de8dc80e1bdd5caebbdb9f9db00ce
URL: https://github.com/llvm/llvm-project/commit/f48ce52c4c2de8dc80e1bdd5caebbdb9f9db00ce
DIFF: https://github.com/llvm/llvm-project/commit/f48ce52c4c2de8dc80e1bdd5caebbdb9f9db00ce.diff
LOG: [mlir][vector] Pattern to clean up vector.extract during distribution
This prevents blocking propagation when converting between scalar and
vector<1>
Differential Revision: https://reviews.llvm.org/D129782
Added:
Modified:
mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
mlir/test/Dialect/Vector/vector-warp-distribute.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 3fcd42e981dbd..8ecb8986fd9fc 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -719,6 +719,33 @@ struct WarpOpBroadcast : public OpRewritePattern<WarpExecuteOnLane0Op> {
}
};
+/// Pattern to move a vector.extract of a single-element vector out of the warp
+/// region. It needs no distribution: the vector is yielded, then extracted.
+struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
+ using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
+ LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
+ PatternRewriter &rewriter) const override {
+ OpOperand *operand = getWarpResult(
+ warpOp, [](Operation *op) { return isa<vector::ExtractOp>(op); });
+ if (!operand) // No operand matching the vector.extract predicate was found.
+ return failure();
+ unsigned int operandNumber = operand->getOperandNumber(); // Result to replace.
+ auto extractOp = operand->get().getDefiningOp<vector::ExtractOp>();
+ if (extractOp.getVectorType().getNumElements() != 1) // Single element only.
+ return failure();
+ Location loc = extractOp.getLoc();
+ SmallVector<size_t> newRetIndices;
+ WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+ rewriter, warpOp, {extractOp.getVector()}, {extractOp.getVectorType()},
+ newRetIndices); // Presumably receives the appended result indices - confirm.
+ rewriter.setInsertionPointAfter(newWarpOp); // New extract goes after the op.
+ Value newExtract = rewriter.create<vector::ExtractOp>(
+ loc, newWarpOp->getResult(newRetIndices[0]), extractOp.getPosition());
+ newWarpOp->getResult(operandNumber).replaceAllUsesWith(newExtract);
+ return success();
+ }
+};
+
/// Sink scf.for region out of WarpExecuteOnLane0Op. This can be done only if
/// the scf.ForOp is the last operation in the region so that it doesn't change
/// the order of execution. This creates a new scf.for region after the
@@ -915,8 +942,8 @@ void mlir::vector::populateDistributeTransferWriteOpPatterns(
void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
RewritePatternSet &patterns) {
patterns.add<WarpOpElementwise, WarpOpTransferRead, WarpOpDeadResult,
- WarpOpBroadcast, WarpOpForwardOperand, WarpOpScfForOp,
- WarpOpConstant>(patterns.getContext());
+ WarpOpBroadcast, WarpOpExtract, WarpOpForwardOperand,
+ WarpOpScfForOp, WarpOpConstant>(patterns.getContext());
}
void mlir::vector::populateDistributeReduction(
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 6d377920eeca5..1b87094c39819 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -611,3 +611,21 @@ func.func @warp_constant(%laneid: index) -> (vector<1xf32>) {
}
return %r : vector<1xf32>
}
+
+// -----
+
+// CHECK-PROP-LABEL: func.func @vector_extract_simple(
+// CHECK-PROP: %[[R:.*]] = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<1xf32>) {
+// CHECK-PROP: %[[V:.*]] = "some_def"() : () -> vector<1xf32>
+// CHECK-PROP: vector.yield %[[V]] : vector<1xf32>
+// CHECK-PROP: }
+// CHECK-PROP: %[[E:.*]] = vector.extract %[[R]][0] : vector<1xf32>
+// CHECK-PROP: return %[[E]] : f32
+func.func @vector_extract_simple(%laneid: index) -> (f32) {
+ %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
+ %0 = "some_def"() : () -> (vector<1xf32>)
+ %1 = vector.extract %0[0] : vector<1xf32> // Expected to end up after the warp op.
+ vector.yield %1 : f32
+ }
+ return %r : f32
+}
More information about the Mlir-commits
mailing list