[Mlir-commits] [mlir] 1757164 - [mlir][vector] Add distribution for extract from 0d vector
Thomas Raoux
llvmlistbot at llvm.org
Fri Oct 14 16:17:26 PDT 2022
Author: Thomas Raoux
Date: 2022-10-14T23:06:42Z
New Revision: 1757164eed244b221c6c078baa7c836e4809e133
URL: https://github.com/llvm/llvm-project/commit/1757164eed244b221c6c078baa7c836e4809e133
DIFF: https://github.com/llvm/llvm-project/commit/1757164eed244b221c6c078baa7c836e4809e133.diff
LOG: [mlir][vector] Add distribution for extract from 0d vector
Differential Revision: https://reviews.llvm.org/D135994
Added:
Modified:
mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
mlir/test/Dialect/Vector/vector-warp-distribute.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 3c4f20fd5f9c2..f730044abcf85 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -895,6 +895,34 @@ struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
}
};
+/// Pattern to move out vector.extractelement of 0-D vectors. Those don't
+/// need to be distributed and can just be propagated outside of the region.
+/// The 0-D source vector is yielded from a rebuilt warp op with its original
+/// type, and the scalar extract is re-created right after that warp op.
+struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
+  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
+  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    // Check the rank inside the predicate instead of after the lookup:
+    // getWarpResult presumably returns only the first yielded value the
+    // predicate accepts (helper not shown here), so a post-lookup rank check
+    // could hide a 0-D extractelement yielded after a higher-rank one.
+    OpOperand *operand = getWarpResult(warpOp, [](Operation *op) {
+      auto extract = dyn_cast<vector::ExtractElementOp>(op);
+      return extract && extract.getVectorType().getRank() == 0;
+    });
+    if (!operand)
+      return failure();
+    unsigned int operandNumber = operand->getOperandNumber();
+    auto extractOp = operand->get().getDefiningOp<vector::ExtractElementOp>();
+    Location loc = extractOp.getLoc();
+    // Yield the 0-D source vector out of the region, keeping its type: a 0-D
+    // vector is not distributed across lanes.
+    SmallVector<size_t> newRetIndices;
+    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, {extractOp.getVector()}, {extractOp.getVectorType()},
+        newRetIndices);
+    // Re-create the scalar extract after the warp op and forward its value to
+    // all users of the old warp result.
+    rewriter.setInsertionPointAfter(newWarpOp);
+    Value newExtract = rewriter.create<vector::ExtractElementOp>(
+        loc, newWarpOp->getResult(newRetIndices[0]));
+    newWarpOp->getResult(operandNumber).replaceAllUsesWith(newExtract);
+    return success();
+  }
+};
+
/// Sink scf.for region out of WarpExecuteOnLane0Op. This can be done only if
/// the scf.ForOp is the last operation in the region so that it doesn't change
/// the order of execution. This creates a new scf.for region after the
@@ -1093,8 +1121,9 @@ void mlir::vector::populateDistributeTransferWriteOpPatterns(
+/// Adds the WarpOp* propagation patterns listed below to `patterns`, all
+/// registered with the same caller-supplied `benefit`. WarpOpExtractElement
+/// moves vector.extractelement ops on 0-D vectors out of the warp region.
void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
RewritePatternSet &patterns, PatternBenefit benefit) {
patterns.add<WarpOpElementwise, WarpOpTransferRead, WarpOpDeadResult,
- WarpOpBroadcast, WarpOpExtract, WarpOpForwardOperand,
- WarpOpScfForOp, WarpOpConstant>(patterns.getContext(), benefit);
+ WarpOpBroadcast, WarpOpExtract, WarpOpExtractElement,
+ WarpOpForwardOperand, WarpOpScfForOp, WarpOpConstant>(
+ patterns.getContext(), benefit);
}
void mlir::vector::populateDistributeReduction(
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 3978d94f377b4..49c36fe18c90d 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -632,6 +632,24 @@ func.func @vector_extract_simple(%laneid: index) -> (f32) {
// -----
+// The 0-D source vector is yielded from the warp op unchanged and the scalar
+// extractelement is materialized after the region.
+// CHECK-PROP-LABEL: func.func @vector_extractelement_simple(
+//       CHECK-PROP:   %[[WARP:.*]] = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<f32>) {
+//       CHECK-PROP:     %[[SRC:.*]] = "some_def"() : () -> vector<f32>
+//       CHECK-PROP:     vector.yield %[[SRC]] : vector<f32>
+//       CHECK-PROP:   }
+//       CHECK-PROP:   %[[ELEM:.*]] = vector.extractelement %[[WARP]][] : vector<f32>
+//       CHECK-PROP:   return %[[ELEM]] : f32
+func.func @vector_extractelement_simple(%lane: index) -> f32 {
+  %res = vector.warp_execute_on_lane_0(%lane)[32] -> (f32) {
+    %src = "some_def"() : () -> vector<f32>
+    %elem = vector.extractelement %src[] : vector<f32>
+    vector.yield %elem : f32
+  }
+  return %res : f32
+}
+
+// -----
+
// CHECK-PROP: func @lane_dependent_warp_propagate_read
// CHECK-PROP-SAME: %[[ID:.*]]: index
func.func @lane_dependent_warp_propagate_read(
More information about the Mlir-commits
mailing list