[Mlir-commits] [mlir] 517cda1 - [mlir][vector] Add foldInsertUseChain folder function to insert op (#147045)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Jul 8 07:39:22 PDT 2025
Author: lonely eagle
Date: 2025-07-08T22:39:18+08:00
New Revision: 517cda12e5091216645903ec85087b0b2f8239c4
URL: https://github.com/llvm/llvm-project/commit/517cda12e5091216645903ec85087b0b2f8239c4
DIFF: https://github.com/llvm/llvm-project/commit/517cda12e5091216645903ec85087b0b2f8239c4.diff
LOG: [mlir][vector] Add foldInsertUseChain folder function to insert op (#147045)
When the result of an insert op is used as the dest of another insert op,
and the subsequent insert op writes to the same position as the previous
one, replace the dest of the subsequent insert op with the dest of the
previous insert op. This is valid because the subsequent insert op
overwrites everything the previous one wrote, so the previous insert op
has no effect on the final result.
---------
Co-authored-by: Mehdi Amini <joker.eph at gmail.com>
Co-authored-by: Andrzej Warzyński <andrzej.warzynski at gmail.com>
Added:
Modified:
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/test/Dialect/Vector/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 14e626a6b23e3..214d2ba7e1b8e 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -3334,7 +3334,6 @@ class InsertSplatToSplat final : public OpRewritePattern<InsertOp> {
return success();
}
};
-
} // namespace
static Attribute
@@ -3387,12 +3386,26 @@ foldDenseElementsAttrDestInsertOp(InsertOp insertOp, Attribute srcAttr,
return newAttr;
}
+/// Rewrite `insert %a, (insert %b, %d[pos])[pos]` to `insert %a, %d[pos]` in
+/// place (%b is fully overwritten). Returns the op's result, or null if no fold.
+static Value foldInsertUseChain(InsertOp insertOp) {
+  auto destInsert = insertOp.getDest().getDefiningOp<InsertOp>();
+  if (!destInsert)
+    return {};
+  // Positions must be identical; dynamic indices compare by SSA value identity.
+  if (insertOp.getMixedPosition() != destInsert.getMixedPosition())
+    return {};
+  // Use the named `dest` accessor instead of a hard-coded operand index.
+  insertOp.getDestMutable().set(destInsert.getDest());
+  return insertOp.getResult();
+}
+
void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat>(context);
}
-OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
+OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
// Do not create constants with more than `vectorSizeFoldThreashold` elements,
// unless the source vector constant has a single use.
constexpr int64_t vectorSizeFoldThreshold = 256;
@@ -3407,6 +3420,8 @@ OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
SmallVector<Value> operands = {getValueToStore(), getDest()};
auto inplaceFolded = extractInsertFoldConstantOp(*this, adaptor, operands);
+ if (auto res = foldInsertUseChain(*this))
+ return res;
if (auto res = foldPoisonIndexInsertExtractOp(
getContext(), adaptor.getStaticPosition(), kPoisonIndex))
return res;
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 8cda8d47cb908..8a9e27378df61 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -3470,3 +3470,32 @@ func.func @fold_insert_constant_indices(%arg : vector<4x1xi32>) -> vector<4x1xi3
%res = vector.insert %1, %arg[%0, %0] : i32 into vector<4x1xi32>
return %res : vector<4x1xi32>
}
+
+// -----
+
+// CHECK-LABEL: @fold_insert_use_chain(
+// CHECK-SAME: %[[ARG:.*]]: vector<4x4xf32>,
+// CHECK-SAME: %[[VAL:.*]]: f32,
+// CHECK-SAME: %[[POS:.*]]: index) -> vector<4x4xf32> {
+// CHECK-NEXT: %[[RES:.*]] = vector.insert %[[VAL]], %[[ARG]] {{\[}}%[[POS]], 0] : f32 into vector<4x4xf32>
+// CHECK-NEXT: return %[[RES]] : vector<4x4xf32>
+func.func @fold_insert_use_chain(%arg : vector<4x4xf32>, %val : f32, %pos: index) -> vector<4x4xf32> {
+  %first = vector.insert %val, %arg[%pos, 0] : f32 into vector<4x4xf32>
+  %second = vector.insert %val, %first[%pos, 0] : f32 into vector<4x4xf32>
+  %third = vector.insert %val, %second[%pos, 0] : f32 into vector<4x4xf32>
+  return %third : vector<4x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @no_fold_insert_use_chain_mismatch_static_position(
+// CHECK-SAME: %[[ARG:.*]]: vector<4xf32>,
+// CHECK-SAME: %[[VAL:.*]]: f32) -> vector<4xf32> {
+// CHECK: %[[V_0:.*]] = vector.insert %[[VAL]], %[[ARG]] [0] : f32 into vector<4xf32>
+// CHECK: %[[V_1:.*]] = vector.insert %[[VAL]], %[[V_0]] [1] : f32 into vector<4xf32>
+// CHECK: return %[[V_1]] : vector<4xf32>
+func.func @no_fold_insert_use_chain_mismatch_static_position(%arg : vector<4xf32>, %val : f32) -> vector<4xf32> {
+  %at_zero = vector.insert %val, %arg[0] : f32 into vector<4xf32>
+  %at_one = vector.insert %val, %at_zero[1] : f32 into vector<4xf32>
+  return %at_one : vector<4xf32>
+}
More information about the Mlir-commits
mailing list