[Mlir-commits] [mlir] cf74b7e - [mlir][Vector] Fold InsertOp(SplatOp(X), SplatOp(X)) to SplatOp(X).

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Jul 5 20:36:19 PDT 2022


Author: jacquesguan
Date: 2022-07-06T11:27:23+08:00
New Revision: cf74b7ec80a89720f8e24394718d34c4436016cf

URL: https://github.com/llvm/llvm-project/commit/cf74b7ec80a89720f8e24394718d34c4436016cf
DIFF: https://github.com/llvm/llvm-project/commit/cf74b7ec80a89720f8e24394718d34c4436016cf.diff

LOG: [mlir][Vector] Fold InsertOp(SplatOp(X), SplatOp(X)) to SplatOp(X).

This patch adds a canonicalization pattern that rewrites InsertOp(SplatOp(X), SplatOp(X)) to SplatOp(X).

Differential Revision: https://reviews.llvm.org/D129058

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/test/Dialect/Vector/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index a3874795571d5..00db4650f1206 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2031,11 +2031,32 @@ class InsertToBroadcast final : public OpRewritePattern<InsertOp> {
   }
 };
 
+/// Pattern to rewrite an InsertOp(SplatOp(X), SplatOp(X)) to SplatOp(X).
+class InsertSplatToSplat final : public OpRewritePattern<InsertOp> {
+public:
+  using OpRewritePattern<InsertOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(InsertOp op,
+                                PatternRewriter &rewriter) const override {
+    auto srcSplat = op.getSource().getDefiningOp<SplatOp>();
+    auto dstSplat = op.getDest().getDefiningOp<SplatOp>();
+    // Only match when both operands are produced by vector.splat.
+    if (!srcSplat || !dstSplat)
+      return failure();
+    // Both splats must broadcast the same scalar SSA value.
+    if (srcSplat.getInput() != dstSplat.getInput())
+      return failure();
+    // Every result element is that scalar: emit one splat of the result type.
+    rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), srcSplat.getInput());
+    return success();
+  }
+};
+
 } // namespace
 
 void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
-  results.add<InsertToBroadcast, BroadcastFolder>(context);
+  results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat>(context);
 }
 
 // Eliminates insert operations that produce values identical to their source

diff  --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index e7747c736867f..84b5a45f19e65 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1669,3 +1669,16 @@ func.func @shuffle_splat(%x : i32) -> vector<4xi32> {
   return %shuffle : vector<4xi32>
 }
 
+
+// -----
+// Inserting splat(%x) into splat(%x) should canonicalize to a single splat.
+// CHECK-LABEL: func @insert_splat
+//  CHECK-SAME:   (%[[ARG:.*]]: i32)
+//  CHECK-NEXT:   %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<2x4x3xi32>
+//  CHECK-NEXT:   return %[[SPLAT]] : vector<2x4x3xi32>
+func.func @insert_splat(%x : i32) -> vector<2x4x3xi32> {
+  %v0 = vector.splat %x : vector<4x3xi32>
+  %v1 = vector.splat %x : vector<2x4x3xi32>
+  %insert = vector.insert %v0, %v1[0] : vector<4x3xi32> into vector<2x4x3xi32>
+  return %insert : vector<2x4x3xi32>
+}


        


More information about the Mlir-commits mailing list