[Mlir-commits] [mlir] [mlir][vector] Add foldInsertUseChain folder function to insert op (PR #147045)

lonely eagle llvmlistbot at llvm.org
Mon Jul 7 06:12:23 PDT 2025


https://github.com/linuxlonelyeagle updated https://github.com/llvm/llvm-project/pull/147045

>From 227158dc0ac46ff78a2a56d2c79082640164b4b1 Mon Sep 17 00:00:00 2001
From: linuxlonelyeagle <2020382038 at qq.com>
Date: Fri, 4 Jul 2025 12:15:10 +0000
Subject: [PATCH 1/3] Add InsertInsertToInsert to the insert op canonicalization
 patterns

---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp   | 25 +++++++++++++++++++++-
 mlir/test/Dialect/Vector/canonicalize.mlir | 14 ++++++++++++
 2 files changed, 38 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 1fb8c7a928e06..2d090bcce45ab 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -3335,6 +3335,28 @@ class InsertSplatToSplat final : public OpRewritePattern<InsertOp> {
   }
 };
 
+/// Pattern to rewrite a InsertOp(InsertOp) to InsertOp.
+class InsertInsertToInsert final : public OpRewritePattern<InsertOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(InsertOp op,
+                                PatternRewriter &rewriter) const override {
+    auto destInsert = op.getDest().getDefiningOp<InsertOp>();
+    if (!destInsert)
+      return failure();
+
+    if (!destInsert->hasOneUse())
+      return failure();
+
+    if (op.getMixedPosition() != destInsert.getMixedPosition())
+      return failure();
+
+    rewriter.replaceOpWithNewOp<InsertOp>(
+        op, op.getValueToStore(), destInsert.getDest(), op.getMixedPosition());
+    return success();
+  }
+};
+
 } // namespace
 
 static Attribute
@@ -3389,7 +3411,8 @@ foldDenseElementsAttrDestInsertOp(InsertOp insertOp, Attribute srcAttr,
 
 void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
-  results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat>(context);
+  results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat,
+              InsertInsertToInsert>(context);
 }
 
 OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 0282e9cac5e02..73bced6149ff6 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -3446,3 +3446,17 @@ func.func @fold_insert_constant_indices(%arg : vector<4x1xi32>) -> vector<4x1xi3
   %res = vector.insert %1, %arg[%0, %0] : i32 into vector<4x1xi32>
   return %res : vector<4x1xi32>
 }
+
+// -----
+
+// CHECK-LABEL: @insert_insert_to_insert(
+//  CHECK-SAME:   %[[ARG:.*]]: vector<4xf32>,
+//  CHECK-SAME:   %[[VAL:.*]]: f32) -> vector<4xf32> {
+//       CHECK:   %[[RES:.*]] = vector.insert %[[VAL]], %[[ARG]] [0] : f32 into vector<4xf32>
+//       CHECK:    return %[[RES]] : vector<4xf32>
+func.func @insert_insert_to_insert(%v : vector<4xf32>, %value : f32) -> vector<4xf32> {
+  %v_0 = vector.insert %value, %v[0] : f32 into vector<4xf32>
+  %v_1 = vector.insert %value, %v_0[0] : f32 into vector<4xf32>
+  %v_2 = vector.insert %value, %v_1[0] : f32 into vector<4xf32>
+  return %v_2 : vector<4xf32>  
+}

>From ddd3df95de72b3b8ad83ed11ed260db95aa8e768 Mon Sep 17 00:00:00 2001
From: linuxlonelyeagle <2020382038 at qq.com>
Date: Fri, 4 Jul 2025 15:57:45 +0000
Subject: [PATCH 2/3] Implement the pattern as a folder function, drop the
 one-use condition, and update tests.

---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp   | 44 +++++++----------
 mlir/test/Dialect/Vector/canonicalize.mlir | 55 +++++++++++++++++++++-
 2 files changed, 71 insertions(+), 28 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 2d090bcce45ab..7ce770c55e875 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -3334,29 +3334,6 @@ class InsertSplatToSplat final : public OpRewritePattern<InsertOp> {
     return success();
   }
 };
-
-/// Pattern to rewrite a InsertOp(InsertOp) to InsertOp.
-class InsertInsertToInsert final : public OpRewritePattern<InsertOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-  LogicalResult matchAndRewrite(InsertOp op,
-                                PatternRewriter &rewriter) const override {
-    auto destInsert = op.getDest().getDefiningOp<InsertOp>();
-    if (!destInsert)
-      return failure();
-
-    if (!destInsert->hasOneUse())
-      return failure();
-
-    if (op.getMixedPosition() != destInsert.getMixedPosition())
-      return failure();
-
-    rewriter.replaceOpWithNewOp<InsertOp>(
-        op, op.getValueToStore(), destInsert.getDest(), op.getMixedPosition());
-    return success();
-  }
-};
-
 } // namespace
 
 static Attribute
@@ -3409,13 +3386,26 @@ foldDenseElementsAttrDestInsertOp(InsertOp insertOp, Attribute srcAttr,
   return newAttr;
 }
 
+/// In-place fold: `insert %a, (insert %b, %d[p])[p]` -> `insert %a, %d[p]`.
+/// Returns the op's result when folded; repeated folds reach the root dest.
+static Value foldInsertUseChain(InsertOp insertOp) {
+  auto destInsert = insertOp.getDest().getDefiningOp<InsertOp>();
+  if (!destInsert)
+    return {};
+  // Same position into the same dest type means the same stored type, so the
+  // outer insert fully overwrites whatever `destInsert` wrote.
+  if (insertOp.getMixedPosition() != destInsert.getMixedPosition())
+    return {};
+  insertOp.getDestMutable().set(destInsert.getDest());
+  return insertOp.getResult();
+}
+
 void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
-  results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat,
-              InsertInsertToInsert>(context);
+  results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat>(context);
 }
 
-OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
+OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
   // Do not create constants with more than `vectorSizeFoldThreashold` elements,
   // unless the source vector constant has a single use.
   constexpr int64_t vectorSizeFoldThreshold = 256;
@@ -3430,6 +3420,8 @@ OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
   SmallVector<Value> operands = {getValueToStore(), getDest()};
   auto inplaceFolded = extractInsertFoldConstantOp(*this, adaptor, operands);
 
+  if (auto res = foldInsertUseChain(*this))
+    return res;
   if (auto res = foldPoisonIndexInsertExtractOp(
           getContext(), adaptor.getStaticPosition(), kPoisonIndex))
     return res;
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 73bced6149ff6..71aee79f19e46 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -3449,14 +3449,65 @@ func.func @fold_insert_constant_indices(%arg : vector<4x1xi32>) -> vector<4x1xi3
 
 // -----
 
-// CHECK-LABEL: @insert_insert_to_insert(
+// CHECK-LABEL: @fold_insert_use_chain_static_pos(
 //  CHECK-SAME:   %[[ARG:.*]]: vector<4xf32>,
 //  CHECK-SAME:   %[[VAL:.*]]: f32) -> vector<4xf32> {
 //       CHECK:   %[[RES:.*]] = vector.insert %[[VAL]], %[[ARG]] [0] : f32 into vector<4xf32>
 //       CHECK:    return %[[RES]] : vector<4xf32>
-func.func @insert_insert_to_insert(%v : vector<4xf32>, %value : f32) -> vector<4xf32> {
+func.func @fold_insert_use_chain_static_pos(%v : vector<4xf32>, %value : f32) -> vector<4xf32> {
   %v_0 = vector.insert %value, %v[0] : f32 into vector<4xf32>
   %v_1 = vector.insert %value, %v_0[0] : f32 into vector<4xf32>
   %v_2 = vector.insert %value, %v_1[0] : f32 into vector<4xf32>
   return %v_2 : vector<4xf32>  
 }
+
+// -----
+
+// CHECK-LABEL: @fold_insert_use_chain_dynamic_pos(
+//  CHECK-SAME:   %[[ARG:.*]]: vector<4x4xf32>,
+//  CHECK-SAME:   %[[VAL:.*]]: f32,
+//  CHECK-SAME:   %[[POS:.*]]: index) -> vector<4x4xf32> {
+//       CHECK:   %[[RES:.*]] = vector.insert %[[VAL]], %[[ARG]] {{\[}}%[[POS]], 0] : f32 into vector<4x4xf32>
+//       CHECK:   return %[[RES]] : vector<4x4xf32>
+func.func @fold_insert_use_chain_dynamic_pos(%arg : vector<4x4xf32>, %value : f32, %pos: index) -> vector<4x4xf32> {
+  %v_0 = vector.insert %value, %arg[%pos, 0] : f32 into vector<4x4xf32>
+  %v_1 = vector.insert %value, %v_0[%pos, 0] : f32 into vector<4x4xf32>
+  %v_2 = vector.insert %value, %v_1[%pos, 0] : f32 into vector<4x4xf32>
+  return %v_2 : vector<4x4xf32>  
+}
+
+// -----
+
+// CHECK-LABEL: @fold_insert_use_chain_add_float(
+//  CHECK-SAME:   %[[VEC_0:.*]]: vector<4xf32>,
+//  CHECK-SAME:   %[[VAL:.*]]: f32) -> vector<4xf32> {
+//       CHECK:   %[[VEC_1:.*]] = vector.insert %[[VAL]], %[[VEC_0]] [0] : f32 into vector<4xf32>
+//       CHECK:   %[[VEC_2:.*]] = arith.addf %[[VEC_1]], %[[VEC_1]] : vector<4xf32>
+//       CHECK:   %[[VEC_3:.*]] = vector.insert %[[VAL]], %[[VEC_0]] [0] : f32 into vector<4xf32>
+//       CHECK:   %[[VEC_4:.*]] = arith.addf %[[VEC_2]], %[[VEC_3]] : vector<4xf32>
+//       CHECK:   return %[[VEC_4]] : vector<4xf32>
+func.func @fold_insert_use_chain_add_float(%v : vector<4xf32>, %value : f32) -> vector<4xf32> {
+  %v_0 = vector.insert %value, %v[0] : f32 into vector<4xf32>
+  %v_1 = arith.addf %v_0, %v_0 : vector<4xf32>
+  %v_2 = vector.insert %value, %v_0[0] : f32 into vector<4xf32>
+  %v_3 = arith.addf %v_1, %v_2 : vector<4xf32>
+  return %v_3 : vector<4xf32>  
+}
+
+// -----
+
+// CHECK-LABEL: @fold_insert_use_chain_add_float_pos_mismatch(
+//  CHECK-SAME:   %[[VEC_0:.*]]: vector<4xf32>,
+//  CHECK-SAME:   %[[VAL:.*]]: f32) -> vector<4xf32> {
+//       CHECK:   %[[VEC_1:.*]] = vector.insert %[[VAL]], %[[VEC_0]] [0] : f32 into vector<4xf32>
+//       CHECK:   %[[VEC_2:.*]] = arith.addf %[[VEC_1]], %[[VEC_1]] : vector<4xf32>
+//       CHECK:   %[[VEC_3:.*]] = vector.insert %[[VAL]], %[[VEC_1]] [1] : f32 into vector<4xf32>
+//       CHECK:   %[[VEC_4:.*]] = arith.addf %[[VEC_2]], %[[VEC_3]] : vector<4xf32>
+//       CHECK:   return %[[VEC_4]] : vector<4xf32>
+func.func @fold_insert_use_chain_add_float_pos_mismatch(%v : vector<4xf32>, %value : f32) -> vector<4xf32> {
+  %v_0 = vector.insert %value, %v[0] : f32 into vector<4xf32>
+  %v_1 = arith.addf %v_0, %v_0 : vector<4xf32>
+  %v_2 = vector.insert %value, %v_0[1] : f32 into vector<4xf32>
+  %v_3 = arith.addf %v_1, %v_2 : vector<4xf32>
+  return %v_3 : vector<4xf32>  
+}

>From 4c81344f1847c2214652ff7002aa2c7745ecd6aa Mon Sep 17 00:00:00 2001
From: linuxlonelyeagle <2020382038 at qq.com>
Date: Mon, 7 Jul 2025 13:12:07 +0000
Subject: [PATCH 3/3] Update tests.

---
 mlir/test/Dialect/Vector/canonicalize.mlir | 64 ++++++++--------------
 1 file changed, 23 insertions(+), 41 deletions(-)

diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 71aee79f19e46..129d21acd42df 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -3449,27 +3449,13 @@ func.func @fold_insert_constant_indices(%arg : vector<4x1xi32>) -> vector<4x1xi3
 
 // -----
 
-// CHECK-LABEL: @fold_insert_use_chain_static_pos(
-//  CHECK-SAME:   %[[ARG:.*]]: vector<4xf32>,
-//  CHECK-SAME:   %[[VAL:.*]]: f32) -> vector<4xf32> {
-//       CHECK:   %[[RES:.*]] = vector.insert %[[VAL]], %[[ARG]] [0] : f32 into vector<4xf32>
-//       CHECK:    return %[[RES]] : vector<4xf32>
-func.func @fold_insert_use_chain_static_pos(%v : vector<4xf32>, %value : f32) -> vector<4xf32> {
-  %v_0 = vector.insert %value, %v[0] : f32 into vector<4xf32>
-  %v_1 = vector.insert %value, %v_0[0] : f32 into vector<4xf32>
-  %v_2 = vector.insert %value, %v_1[0] : f32 into vector<4xf32>
-  return %v_2 : vector<4xf32>  
-}
-
-// -----
-
-// CHECK-LABEL: @fold_insert_use_chain_dynamic_pos(
-//  CHECK-SAME:   %[[ARG:.*]]: vector<4x4xf32>,
+// CHECK-LABEL: @fold_insert_use_chain(
+//  CHECK-SAME:   %[[DEST:.*]]: vector<4x4xf32>,
 //  CHECK-SAME:   %[[VAL:.*]]: f32,
 //  CHECK-SAME:   %[[POS:.*]]: index) -> vector<4x4xf32> {
-//       CHECK:   %[[RES:.*]] = vector.insert %[[VAL]], %[[ARG]] {{\[}}%[[POS]], 0] : f32 into vector<4x4xf32>
-//       CHECK:   return %[[RES]] : vector<4x4xf32>
-func.func @fold_insert_use_chain_dynamic_pos(%arg : vector<4x4xf32>, %value : f32, %pos: index) -> vector<4x4xf32> {
+//  CHECK-NEXT:   %[[RES:.*]] = vector.insert %[[VAL]], %[[DEST]] {{\[}}%[[POS]], 0] : f32 into vector<4x4xf32>
+//  CHECK-NEXT:   return %[[RES]] : vector<4x4xf32>
+func.func @fold_insert_use_chain(%arg : vector<4x4xf32>, %value : f32, %pos: index) -> vector<4x4xf32> {
   %v_0 = vector.insert %value, %arg[%pos, 0] : f32 into vector<4x4xf32>
   %v_1 = vector.insert %value, %v_0[%pos, 0] : f32 into vector<4x4xf32>
   %v_2 = vector.insert %value, %v_1[%pos, 0] : f32 into vector<4x4xf32>
@@ -3478,36 +3464,32 @@ func.func @fold_insert_use_chain_dynamic_pos(%arg : vector<4x4xf32>, %value : f3
 
 // -----
 
-// CHECK-LABEL: @fold_insert_use_chain_add_float(
-//  CHECK-SAME:   %[[VEC_0:.*]]: vector<4xf32>,
+// CHECK-LABEL: @no_fold_insert_use_chain(
+//  CHECK-SAME:   %[[DEST_0:.*]]: vector<4xf32>,
 //  CHECK-SAME:   %[[VAL:.*]]: f32) -> vector<4xf32> {
-//       CHECK:   %[[VEC_1:.*]] = vector.insert %[[VAL]], %[[VEC_0]] [0] : f32 into vector<4xf32>
-//       CHECK:   %[[VEC_2:.*]] = arith.addf %[[VEC_1]], %[[VEC_1]] : vector<4xf32>
-//       CHECK:   %[[VEC_3:.*]] = vector.insert %[[VAL]], %[[VEC_0]] [0] : f32 into vector<4xf32>
-//       CHECK:   %[[VEC_4:.*]] = arith.addf %[[VEC_2]], %[[VEC_3]] : vector<4xf32>
-//       CHECK:   return %[[VEC_4]] : vector<4xf32>
-func.func @fold_insert_use_chain_add_float(%v : vector<4xf32>, %value : f32) -> vector<4xf32> {
+//  CHECK-NEXT:   %[[DEST_1:.*]] = vector.insert %[[VAL]], %[[DEST_0]] [0] : f32 into vector<4xf32>
+//  CHECK-NEXT:   %[[RES:.*]] = vector.insert %[[VAL]], %[[DEST_1]] [1] : f32 into vector<4xf32>
+//  CHECK-NEXT:   return %[[RES]] : vector<4xf32>
+func.func @no_fold_insert_use_chain(%v : vector<4xf32>, %value : f32) -> vector<4xf32> {
   %v_0 = vector.insert %value, %v[0] : f32 into vector<4xf32>
-  %v_1 = arith.addf %v_0, %v_0 : vector<4xf32>
-  %v_2 = vector.insert %value, %v_0[0] : f32 into vector<4xf32>
-  %v_3 = arith.addf %v_1, %v_2 : vector<4xf32>
-  return %v_3 : vector<4xf32>  
+  %v_1 = vector.insert %value, %v_0[1] : f32 into vector<4xf32>
+  return %v_1 : vector<4xf32>
 }
 
 // -----
 
-// CHECK-LABEL: @fold_insert_use_chain_add_float_pos_mismatch(
-//  CHECK-SAME:   %[[VEC_0:.*]]: vector<4xf32>,
+// CHECK-LABEL: @fold_insert_use_chain_add_float(
+//  CHECK-SAME:   %[[DEST:.*]]: vector<4xf32>,
 //  CHECK-SAME:   %[[VAL:.*]]: f32) -> vector<4xf32> {
-//       CHECK:   %[[VEC_1:.*]] = vector.insert %[[VAL]], %[[VEC_0]] [0] : f32 into vector<4xf32>
-//       CHECK:   %[[VEC_2:.*]] = arith.addf %[[VEC_1]], %[[VEC_1]] : vector<4xf32>
-//       CHECK:   %[[VEC_3:.*]] = vector.insert %[[VAL]], %[[VEC_1]] [1] : f32 into vector<4xf32>
-//       CHECK:   %[[VEC_4:.*]] = arith.addf %[[VEC_2]], %[[VEC_3]] : vector<4xf32>
-//       CHECK:   return %[[VEC_4]] : vector<4xf32>
-func.func @fold_insert_use_chain_add_float_pos_mismatch(%v : vector<4xf32>, %value : f32) -> vector<4xf32> {
+//       CHECK:   %[[INS:.*]] = vector.insert %[[VAL]], %[[DEST]] [0] : f32 into vector<4xf32>
+//       CHECK:   %[[LHS:.*]] = arith.addf %[[INS]], %[[INS]] : vector<4xf32>
+//       CHECK:   %[[RHS:.*]] = vector.insert %[[VAL]], %[[DEST]] [0] : f32 into vector<4xf32>
+//       CHECK:   %[[RES:.*]] = arith.addf %[[LHS]], %[[RHS]] : vector<4xf32>
+//       CHECK:   return %[[RES]] : vector<4xf32>
+func.func @fold_insert_use_chain_add_float(%v : vector<4xf32>, %value : f32) -> vector<4xf32> {
   %v_0 = vector.insert %value, %v[0] : f32 into vector<4xf32>
   %v_1 = arith.addf %v_0, %v_0 : vector<4xf32>
-  %v_2 = vector.insert %value, %v_0[1] : f32 into vector<4xf32>
+  %v_2 = vector.insert %value, %v_0[0] : f32 into vector<4xf32>
   %v_3 = arith.addf %v_1, %v_2 : vector<4xf32>
-  return %v_3 : vector<4xf32>  
+  return %v_3 : vector<4xf32>
 }



More information about the Mlir-commits mailing list