[Mlir-commits] [mlir] cf74b7e - [mlir][Vector] Fold InsertOp(SplatOp(X), SplatOp(X)) to SplatOp(X).
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Jul 5 20:36:19 PDT 2022
Author: jacquesguan
Date: 2022-07-06T11:27:23+08:00
New Revision: cf74b7ec80a89720f8e24394718d34c4436016cf
URL: https://github.com/llvm/llvm-project/commit/cf74b7ec80a89720f8e24394718d34c4436016cf
DIFF: https://github.com/llvm/llvm-project/commit/cf74b7ec80a89720f8e24394718d34c4436016cf.diff
LOG: [mlir][Vector] Fold InsertOp(SplatOp(X), SplatOp(X)) to SplatOp(X).
This patch folds InsertOp(SplatOp(X), SplatOp(X)) to SplatOp(X).
Differential Revision: https://reviews.llvm.org/D129058
Added:
Modified:
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/test/Dialect/Vector/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index a3874795571d5..00db4650f1206 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2031,11 +2031,32 @@ class InsertToBroadcast final : public OpRewritePattern<InsertOp> {
}
};
+/// Pattern to rewrite an InsertOp whose destination is SplatOp(X) and whose
+/// inserted value consists entirely of X:
+///   InsertOp(SplatOp(X), SplatOp(X)) -> SplatOp(X)   (vector source)
+///   InsertOp(X, SplatOp(X))          -> SplatOp(X)   (scalar source)
+/// Every element of the result equals X regardless of the insert position, so
+/// the result is exactly the destination splat. That existing value is reused
+/// instead of materializing a duplicate SplatOp.
+class InsertSplatToSplat final : public OpRewritePattern<InsertOp> {
+public:
+  using OpRewritePattern<InsertOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(InsertOp op,
+                                PatternRewriter &rewriter) const override {
+    auto dstSplat = op.getDest().getDefiningOp<SplatOp>();
+    if (!dstSplat)
+      return failure();
+
+    Value splatInput = dstSplat.getInput();
+    Value source = op.getSource();
+
+    // A scalar insert must write X itself; a vector insert must come from a
+    // splat of the same X. Anything else may change some element.
+    bool writesSplatValue = source == splatInput;
+    if (!writesSplatValue) {
+      auto srcSplat = source.getDefiningOp<SplatOp>();
+      writesSplatValue = srcSplat && srcSplat.getInput() == splatInput;
+    }
+    if (!writesSplatValue)
+      return failure();
+
+    // vector.insert produces a value of its destination's type, so the
+    // destination splat can directly replace the insert's result.
+    rewriter.replaceOp(op, dstSplat.getResult());
+    return success();
+  }
+};
+
} // namespace
+/// Populates `results` with the canonicalization patterns for InsertOp.
void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<InsertToBroadcast, BroadcastFolder>(context);
+ // InsertSplatToSplat folds an insert of a splat of X into a splat of X.
+ results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat>(context);
}
// Eliminates insert operations that produce values identical to their source
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index e7747c736867f..84b5a45f19e65 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1669,3 +1669,16 @@ func.func @shuffle_splat(%x : i32) -> vector<4xi32> {
return %shuffle : vector<4xi32>
}
+
+// -----
+
+// CHECK-LABEL: func @insert_splat
+// CHECK-SAME: (%[[ARG:.*]]: i32)
+// CHECK-NEXT: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<2x4x3xi32>
+// CHECK-NEXT: return %[[SPLAT]] : vector<2x4x3xi32>
+func.func @insert_splat(%x : i32) -> vector<2x4x3xi32> {
+ %v0 = vector.splat %x : vector<4x3xi32>
+ %v1 = vector.splat %x : vector<2x4x3xi32>
+ %insert = vector.insert %v0, %v1[0] : vector<4x3xi32> into vector<2x4x3xi32>
+ return %insert : vector<2x4x3xi32>
+}
+
+// -----
+
+// Negative case: the two splats use different scalars, so the insert changes
+// the slice at position [0] and must not be folded away.
+// CHECK-LABEL: func @insert_mismatched_splats
+//  CHECK-SAME: (%[[X:.*]]: i32, %[[Y:.*]]: i32)
+//       CHECK: %[[SRC:.*]] = vector.splat %[[X]] : vector<4x3xi32>
+//       CHECK: %[[DST:.*]] = vector.splat %[[Y]] : vector<2x4x3xi32>
+//       CHECK: %[[INS:.*]] = vector.insert %[[SRC]], %[[DST]]
+//       CHECK: return %[[INS]] : vector<2x4x3xi32>
+func.func @insert_mismatched_splats(%x : i32, %y : i32) -> vector<2x4x3xi32> {
+  %v0 = vector.splat %x : vector<4x3xi32>
+  %v1 = vector.splat %y : vector<2x4x3xi32>
+  %insert = vector.insert %v0, %v1[0] : vector<4x3xi32> into vector<2x4x3xi32>
+  return %insert : vector<2x4x3xi32>
+}
More information about the Mlir-commits
mailing list