[Mlir-commits] [mlir] [mlir][vector] Support complete folding in single pass for vector.insert/vector.extract (PR #142124)
Yang Bai
llvmlistbot at llvm.org
Fri May 30 04:14:13 PDT 2025
https://github.com/yangtetris created https://github.com/llvm/llvm-project/pull/142124
### Description
This patch improves the folding efficiency of `vector.insert` and `vector.extract` operations by not returning early after successfully converting dynamic indices to static indices.
### Motivation
Since the `OpBuilder::createOrFold` function only calls `fold` **once**, the current `fold` methods of `vector.insert` and `vector.extract` may leave the op in a state that can be folded further. For example, consider the following un-folded IR:
```
%v1 = vector.insert %e1, %v0 [0] : f32 into vector<128xf32>
%c0 = arith.constant 0 : index
%e2 = vector.extract %v1[%c0] : f32 from vector<128xf32>
```
If we use `createOrFold` to create the `vector.extract` op, then the result will be:
```
%v1 = vector.insert %e1, %v0 [0] : f32 into vector<128xf32>
%e2 = vector.extract %v1[0] : f32 from vector<128xf32>
```
But this is not the optimal result. `createOrFold` should have returned `%e1`.
The reason is that `fold` returns immediately after `extractInsertFoldConstantOp` succeeds, so the subsequent folding logic is skipped.
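The control-flow difference can be illustrated with a small standalone model. This is only a sketch and does not use any MLIR API: `ToyExtract`, `foldConstantIndex`, `foldExtractOfInsert`, `foldEarlyReturn`, and `foldContinue` are hypothetical names invented for illustration, with `foldConstantIndex` playing the role of `extractInsertFoldConstantOp`.

```cpp
#include <optional>
#include <string>

// Toy stand-in for a vector.extract fed by a vector.insert.
// Hypothetical types and names; this is not the MLIR API.
struct ToyExtract {
  std::optional<int> staticIndex;     // set once the index is static
  std::optional<int> constantOperand; // dynamic index known to be a constant
  int insertIndex;                    // static index used by the producing insert
  std::string insertedValue;          // value stored by the insert, e.g. "%e1"
};

enum class FoldKind { None, InPlace, Replaced };

struct FoldOutcome {
  FoldKind kind;
  std::string value; // only meaningful when kind == Replaced
};

// Plays the role of extractInsertFoldConstantOp: turns a constant dynamic
// index into a static one. Returns true if the op was changed in place.
bool foldConstantIndex(ToyExtract &op) {
  if (op.staticIndex || !op.constantOperand)
    return false;
  op.staticIndex = op.constantOperand;
  op.constantOperand.reset();
  return true;
}

// Extract-of-insert at the same static index folds to the inserted value.
std::optional<std::string> foldExtractOfInsert(const ToyExtract &op) {
  if (op.staticIndex && *op.staticIndex == op.insertIndex)
    return op.insertedValue;
  return std::nullopt;
}

// Behaviour before the patch: stop as soon as the index became static.
FoldOutcome foldEarlyReturn(ToyExtract &op) {
  if (foldConstantIndex(op))
    return {FoldKind::InPlace, ""};
  if (auto v = foldExtractOfInsert(op))
    return {FoldKind::Replaced, *v};
  return {FoldKind::None, ""};
}

// Behaviour after the patch: remember the in-place change, keep folding,
// and report the in-place change only if nothing better was found.
FoldOutcome foldContinue(ToyExtract &op) {
  bool inPlace = foldConstantIndex(op);
  if (auto v = foldExtractOfInsert(op))
    return {FoldKind::Replaced, *v};
  return {inPlace ? FoldKind::InPlace : FoldKind::None, ""};
}
```

With an op whose dynamic index is the constant 0 and whose producing insert also writes index 0, a single call to `foldEarlyReturn` only reports an in-place update, while a single call to `foldContinue` yields the inserted value directly. That mirrors what a single `fold` invocation from `createOrFold` would observe.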
From 32b9f8cc3b07770ca409f93206485e949d856de8 Mon Sep 17 00:00:00 2001
From: Yang Bai <yangb at nvidia.com>
Date: Fri, 30 May 2025 03:42:09 -0700
Subject: [PATCH] [mlir][vector] Support complete folding in single pass for
vector.insert/vector.extract
After successfully converting dynamic indices to static indices, continue folding
instead of returning early, allowing subsequent fold operations to be executed.
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 14 ++++++++------
1 file changed, 8 insertions(+), 6 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 890a5e9e5c9b4..2e0c917b2139d 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2062,6 +2062,7 @@ static Value extractInsertFoldConstantOp(OpType op, AdaptorType adaptor,
if (opChange) {
op.setStaticPosition(staticPosition);
op.getOperation()->setOperands(operands);
+ // Return the original result to indicate an in-place folding happened.
return op.getResult();
}
return {};
@@ -2148,8 +2149,8 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
// Fold `arith.constant` indices into the `vector.extract` operation. Make
// sure that patterns requiring constant indices are added after this fold.
SmallVector<Value> operands = {getVector()};
- if (auto val = extractInsertFoldConstantOp(*this, adaptor, operands))
- return val;
+ auto inplaceFolded = extractInsertFoldConstantOp(*this, adaptor, operands);
+
if (auto res = foldPoisonIndexInsertExtractOp(
getContext(), adaptor.getStaticPosition(), kPoisonIndex))
return res;
@@ -2171,7 +2172,8 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
return val;
if (auto val = foldScalarExtractFromFromElements(*this))
return val;
- return OpFoldResult();
+
+ return inplaceFolded;
}
namespace {
@@ -3150,8 +3152,8 @@ OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
// Fold `arith.constant` indices into the `vector.insert` operation. Make
// sure that patterns requiring constant indices are added after this fold.
SmallVector<Value> operands = {getValueToStore(), getDest()};
- if (auto val = extractInsertFoldConstantOp(*this, adaptor, operands))
- return val;
+ auto inplaceFolded = extractInsertFoldConstantOp(*this, adaptor, operands);
+
if (auto res = foldPoisonIndexInsertExtractOp(
getContext(), adaptor.getStaticPosition(), kPoisonIndex))
return res;
@@ -3161,7 +3163,7 @@ OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
return res;
}
- return {};
+ return inplaceFolded;
}
//===----------------------------------------------------------------------===//