[Mlir-commits] [mlir] [mlir][vector] Support complete folding in single pass for vector.insert/vector.extract (PR #142124)

Yang Bai llvmlistbot at llvm.org
Fri May 30 04:14:13 PDT 2025


https://github.com/yangtetris created https://github.com/llvm/llvm-project/pull/142124

### Description

This patch lets `vector.insert` and `vector.extract` fold completely in a single `fold` call: after successfully converting dynamic indices to static indices, `fold` no longer returns early and instead continues with the remaining folds.

### Motivation

Since the `OpBuilder::createOrFold` function only calls `fold` **once**, the current `fold` methods of `vector.insert` and `vector.extract` may leave the op in a state that can be folded further. For example, consider the following un-folded IR:
```
%v1 = vector.insert %e1, %v0 [0] : f32 into vector<128xf32>
%c0 = arith.constant 0 : index
%e2 = vector.extract %v1[%c0] : f32 from vector<128xf32>
```
If we use `createOrFold` to create the `vector.extract` op, then the result will be:
```
%v1 = vector.insert %e1, %v0 [0] : f32 into vector<128xf32>
%e2 = vector.extract %v1[0] : f32 from vector<128xf32>
```
This is not the optimal result: the extract reads back the element that the insert wrote at the same position, so `createOrFold` should have returned `%e1`.
The reason is that `fold` returns immediately after `extractInsertFoldConstantOp` succeeds, so the subsequent folding logic is skipped.


From 32b9f8cc3b07770ca409f93206485e949d856de8 Mon Sep 17 00:00:00 2001
From: Yang Bai <yangb at nvidia.com>
Date: Fri, 30 May 2025 03:42:09 -0700
Subject: [PATCH] [mlir][vector] Support complete folding in single pass for
 vector.insert/vector.extract

After successfully converting dynamic indices to static indices, continue folding
instead of returning early, allowing subsequent fold operations to be executed.
---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 14 ++++++++------
 1 file changed, 8 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 890a5e9e5c9b4..2e0c917b2139d 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2062,6 +2062,7 @@ static Value extractInsertFoldConstantOp(OpType op, AdaptorType adaptor,
   if (opChange) {
     op.setStaticPosition(staticPosition);
     op.getOperation()->setOperands(operands);
+    // Return the original result to indicate an in-place folding happened.
     return op.getResult();
   }
   return {};
@@ -2148,8 +2149,8 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
   // Fold `arith.constant` indices into the `vector.extract` operation. Make
   // sure that patterns requiring constant indices are added after this fold.
   SmallVector<Value> operands = {getVector()};
-  if (auto val = extractInsertFoldConstantOp(*this, adaptor, operands))
-    return val;
+  auto inplaceFolded = extractInsertFoldConstantOp(*this, adaptor, operands);
+
   if (auto res = foldPoisonIndexInsertExtractOp(
           getContext(), adaptor.getStaticPosition(), kPoisonIndex))
     return res;
@@ -2171,7 +2172,8 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
     return val;
   if (auto val = foldScalarExtractFromFromElements(*this))
     return val;
-  return OpFoldResult();
+
+  return inplaceFolded;
 }
 
 namespace {
@@ -3150,8 +3152,8 @@ OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
   // Fold `arith.constant` indices into the `vector.insert` operation. Make
   // sure that patterns requiring constant indices are added after this fold.
   SmallVector<Value> operands = {getValueToStore(), getDest()};
-  if (auto val = extractInsertFoldConstantOp(*this, adaptor, operands))
-    return val;
+  auto inplaceFolded = extractInsertFoldConstantOp(*this, adaptor, operands);
+
   if (auto res = foldPoisonIndexInsertExtractOp(
           getContext(), adaptor.getStaticPosition(), kPoisonIndex))
     return res;
@@ -3161,7 +3163,7 @@ OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
     return res;
   }
 
-  return {};
+  return inplaceFolded;
 }
 
 //===----------------------------------------------------------------------===//