[Mlir-commits] [mlir] 91ab4d4 - [mlir][Vector] Fold InsertStridedSliceOp of two splat with the same input to splat.

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Jun 30 19:56:34 PDT 2022


Author: jacquesguan
Date: 2022-07-01T10:46:47+08:00
New Revision: 91ab4d4231e5b7456d012776c5eeb69fa61ab994

URL: https://github.com/llvm/llvm-project/commit/91ab4d4231e5b7456d012776c5eeb69fa61ab994
DIFF: https://github.com/llvm/llvm-project/commit/91ab4d4231e5b7456d012776c5eeb69fa61ab994.diff

LOG: [mlir][Vector] Fold InsertStridedSliceOp of two splats with the same input to a splat.

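This patch adds a canonicalization pattern that folds InsertStridedSliceOp(SplatOp(X):src_type, SplatOp(X):dst_type) to SplatOp(X):dst_type. When the source and the destination are both splats of the same value X, every element written by the insertion already equals X, so the op is replaced by its destination splat.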

Reviewed By: Mogball

Differential Revision: https://reviews.llvm.org/D128891

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/test/Dialect/Vector/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 57c02c9a35ba3..6f68f83ed05f9 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -886,6 +886,7 @@ def Vector_InsertStridedSliceOp :
 
   let hasFolder = 1;
   let hasVerifier = 1;
+  let hasCanonicalizer = 1;
 }
 
 def Vector_OuterProductOp :

diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index c3b4d1e13de47..3edd23fef6242 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2180,6 +2180,38 @@ LogicalResult InsertStridedSliceOp::verify() {
   return success();
 }
 
+namespace {
+/// Pattern to rewrite an InsertStridedSliceOp(SplatOp(X):src_type,
+/// SplatOp(X):dst_type) to SplatOp(X):dst_type.
+class FoldInsertStridedSliceSplat final
+    : public OpRewritePattern<InsertStridedSliceOp> {
+public:
+  using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
+                                PatternRewriter &rewriter) const override {
+    auto srcSplatOp =
+        insertStridedSliceOp.getSource().getDefiningOp<vector::SplatOp>();
+    auto destSplatOp =
+        insertStridedSliceOp.getDest().getDefiningOp<vector::SplatOp>();
+
+    if (!srcSplatOp || !destSplatOp)
+      return failure();
+
+    if (srcSplatOp.getInput() != destSplatOp.getInput())
+      return failure();
+
+    rewriter.replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
+    return success();
+  }
+};
+} // namespace
+
+void vector::InsertStridedSliceOp::getCanonicalizationPatterns(
+    RewritePatternSet &results, MLIRContext *context) {
+  results.add<FoldInsertStridedSliceSplat>(context);
+}
+
 OpFoldResult InsertStridedSliceOp::fold(ArrayRef<Attribute> operands) {
   if (getSourceVectorType() == getDestVectorType())
     return getSource();

diff  --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index d16a6fa2c7e11..7f50d90380452 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1627,3 +1627,17 @@ func.func @bitcast(%a: vector<4x8xf32>) -> vector<4x16xi16> {
   %1 = vector.bitcast %0 : vector<4x8xi32> to vector<4x16xi16>
   return %1 : vector<4x16xi16>
 }
+
+// -----
+
+// CHECK-LABEL: @insert_strided_slice_splat
+//  CHECK-SAME: (%[[ARG:.*]]: f32)
+//  CHECK-NEXT:   %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<8x16xf32>
+//  CHECK-NEXT:   return %[[SPLAT]] : vector<8x16xf32>
+func.func @insert_strided_slice_splat(%x: f32) -> (vector<8x16xf32>) {
+  %splat0 = vector.splat %x : vector<4x4xf32>
+  %splat1 = vector.splat %x : vector<8x16xf32>
+  %0 = vector.insert_strided_slice %splat0, %splat1 {offsets = [2, 2], strides = [1, 1]}
+    : vector<4x4xf32> into vector<8x16xf32>
+  return %0 : vector<8x16xf32>
+}


        


More information about the Mlir-commits mailing list