[Mlir-commits] [mlir] 8f45c58 - [mlir][Vector] Fold InsertStridedSliceOp of ExtractStridedSliceOp.
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Jun 30 20:45:49 PDT 2022
Author: jacquesguan
Date: 2022-07-01T11:43:35+08:00
New Revision: 8f45c5862f82ca7e32c09ef84317daf66d278757
URL: https://github.com/llvm/llvm-project/commit/8f45c5862f82ca7e32c09ef84317daf66d278757
DIFF: https://github.com/llvm/llvm-project/commit/8f45c5862f82ca7e32c09ef84317daf66d278757.diff
LOG: [mlir][Vector] Fold InsertStridedSliceOp of ExtractStridedSliceOp.
This patch adds support for folding InsertStridedSliceOp(ExtractStridedSliceOp(dst), dst) to dst.
Differential Revision: https://reviews.llvm.org/D128903
Added:
Modified:
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/test/Dialect/Vector/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 3edd23fef624..38f38f886705 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2205,11 +2205,43 @@ class FoldInsertStridedSliceSplat final
return success();
}
};
+
+/// Pattern to fold InsertStridedSliceOp(ExtractStridedSliceOp(dst), dst) to
+/// dst when both ops use the same offsets and strides.
+class FoldInsertStridedSliceOfExtract final
+ : public OpRewritePattern<InsertStridedSliceOp> {
+public:
+ using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
+ PatternRewriter &rewriter) const override {
+ auto extractStridedSliceOp =
+ insertStridedSliceOp.getSource()
+ .getDefiningOp<vector::ExtractStridedSliceOp>();
+
+ if (!extractStridedSliceOp)
+ return failure();
+
+ if (extractStridedSliceOp.getOperand() != insertStridedSliceOp.getDest())
+ return failure();
+
+ // Only fold when both ops use the same strides and offsets.
+ if (extractStridedSliceOp.getStrides() !=
+ insertStridedSliceOp.getStrides() ||
+ extractStridedSliceOp.getOffsets() != insertStridedSliceOp.getOffsets())
+ return failure();
+
+ rewriter.replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
+ return success();
+ }
+};
+
} // namespace
void vector::InsertStridedSliceOp::getCanonicalizationPatterns(
RewritePatternSet &results, MLIRContext *context) {
- results.add<FoldInsertStridedSliceSplat>(context);
+ results.add<FoldInsertStridedSliceSplat, FoldInsertStridedSliceOfExtract>(
+ context);
}
OpFoldResult InsertStridedSliceOp::fold(ArrayRef<Attribute> operands) {
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 7f50d9038045..515a2d1726b6 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1641,3 +1641,17 @@ func.func @insert_strided_slice_splat(%x: f32) -> (vector<8x16xf32>) {
: vector<4x4xf32> into vector<8x16xf32>
return %0 : vector<8x16xf32>
}
+
+
+// -----
+
+// CHECK-LABEL: @insert_extract_strided_slice
+// CHECK-SAME: (%[[ARG:.*]]: vector<8x16xf32>)
+// CHECK-NEXT: return %[[ARG]] : vector<8x16xf32>
+func.func @insert_extract_strided_slice(%x: vector<8x16xf32>) -> (vector<8x16xf32>) {
+ %0 = vector.extract_strided_slice %x {offsets = [0, 8], sizes = [2, 4], strides = [1, 1]}
+ : vector<8x16xf32> to vector<2x4xf32>
+ %1 = vector.insert_strided_slice %0, %x {offsets = [0, 8], strides = [1, 1]}
+ : vector<2x4xf32> into vector<8x16xf32>
+ return %1 : vector<8x16xf32>
+}
More information about the Mlir-commits
mailing list