[Mlir-commits] [mlir] 91ab4d4 - [mlir][Vector] Fold InsertStridedSliceOp of two splat with the same input to splat.
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Jun 30 19:56:34 PDT 2022
Author: jacquesguan
Date: 2022-07-01T10:46:47+08:00
New Revision: 91ab4d4231e5b7456d012776c5eeb69fa61ab994
URL: https://github.com/llvm/llvm-project/commit/91ab4d4231e5b7456d012776c5eeb69fa61ab994
DIFF: https://github.com/llvm/llvm-project/commit/91ab4d4231e5b7456d012776c5eeb69fa61ab994.diff
LOG: [mlir][Vector] Fold InsertStridedSliceOp of two splats with the same input to a splat.
This patch folds InsertStridedSliceOp(SplatOp(X):src_type, SplatOp(X):dst_type) to SplatOp(X):dst_type.
Reviewed By: Mogball
Differential Revision: https://reviews.llvm.org/D128891
Added:
Modified:
mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/test/Dialect/Vector/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 57c02c9a35ba3..6f68f83ed05f9 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -886,6 +886,7 @@ def Vector_InsertStridedSliceOp :
let hasFolder = 1;
let hasVerifier = 1;
+ let hasCanonicalizer = 1;
}
def Vector_OuterProductOp :
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index c3b4d1e13de47..3edd23fef6242 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2180,6 +2180,38 @@ LogicalResult InsertStridedSliceOp::verify() {
return success();
}
+namespace {
+/// Canonicalization pattern for vector.insert_strided_slice whose source and
+/// destination are both vector.splat ops of the same scalar `X`:
+///
+///   InsertStridedSliceOp(SplatOp(X):src_type, SplatOp(X):dst_type)
+///     ==> SplatOp(X):dst_type
+///
+/// Every element of the destination splat already equals `X`, so overwriting
+/// any strided slice of it with more copies of `X` leaves it unchanged; the op
+/// is therefore replaced by its destination operand.
+class FoldInsertStridedSliceSplat final
+    : public OpRewritePattern<InsertStridedSliceOp> {
+public:
+  using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
+                                PatternRewriter &rewriter) const override {
+    // The inserted vector must come from a vector.splat ...
+    auto inserted = op.getSource().getDefiningOp<vector::SplatOp>();
+    if (!inserted)
+      return failure();
+
+    // ... and so must the vector being inserted into.
+    auto target = op.getDest().getDefiningOp<vector::SplatOp>();
+    if (!target)
+      return failure();
+
+    // Only a no-op when both splats broadcast the very same SSA scalar.
+    if (inserted.getInput() != target.getInput())
+      return failure();
+
+    // The result is exactly the destination splat.
+    rewriter.replaceOp(op, op.getDest());
+    return success();
+  }
+};
+} // namespace
+
+/// Adds the vector.insert_strided_slice canonicalization patterns to
+/// `patterns`; currently this is only FoldInsertStridedSliceSplat.
+void vector::InsertStridedSliceOp::getCanonicalizationPatterns(
+    RewritePatternSet &patterns, MLIRContext *ctx) {
+  patterns.add<FoldInsertStridedSliceSplat>(ctx);
+}
+
OpFoldResult InsertStridedSliceOp::fold(ArrayRef<Attribute> operands) {
if (getSourceVectorType() == getDestVectorType())
return getSource();
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index d16a6fa2c7e11..7f50d90380452 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1627,3 +1627,17 @@ func.func @bitcast(%a: vector<4x8xf32>) -> vector<4x16xi16> {
%1 = vector.bitcast %0 : vector<4x8xi32> to vector<4x16xi16>
return %1 : vector<4x16xi16>
}
+
+// -----
+
+// CHECK-LABEL: @insert_strided_slice_splat
+// CHECK-SAME: (%[[ARG:.*]]: f32)
+// CHECK-NEXT: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<8x16xf32>
+// CHECK-NEXT: return %[[SPLAT]] : vector<8x16xf32>
+
+// Inserting a 4x4 splat of %scalar into an 8x16 splat of the same %scalar
+// should canonicalize to the 8x16 splat alone, as the CHECK lines expect.
+func.func @insert_strided_slice_splat(%scalar: f32) -> (vector<8x16xf32>) {
+  %tile = vector.splat %scalar : vector<4x4xf32>
+  %full = vector.splat %scalar : vector<8x16xf32>
+  %result = vector.insert_strided_slice %tile, %full
+      {offsets = [2, 2], strides = [1, 1]} : vector<4x4xf32> into vector<8x16xf32>
+  return %result : vector<8x16xf32>
+}
More information about the Mlir-commits mailing list