[Mlir-commits] [mlir] [mlir][vector] Support complete folding in single pass for vector.insert/vector.extract (PR #142124)

Yang Bai llvmlistbot at llvm.org
Sat May 31 08:34:09 PDT 2025


https://github.com/yangtetris updated https://github.com/llvm/llvm-project/pull/142124

>From 32b9f8cc3b07770ca409f93206485e949d856de8 Mon Sep 17 00:00:00 2001
From: Yang Bai <yangb at nvidia.com>
Date: Fri, 30 May 2025 03:42:09 -0700
Subject: [PATCH 1/3] [mlir][vector] Support complete folding in single pass
 for vector.insert/vector.extract

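After folding constant dynamic indices into the static position of the op,
continue with the remaining folds instead of returning early, so that folds
requiring static indices can complete in the same pass. If none of those folds
applies, the in-place update is still reported as a successful fold.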
---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 14 ++++++++------
 1 file changed, 8 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 890a5e9e5c9b4..2e0c917b2139d 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2062,6 +2062,7 @@ static Value extractInsertFoldConstantOp(OpType op, AdaptorType adaptor,
   if (opChange) {
     op.setStaticPosition(staticPosition);
     op.getOperation()->setOperands(operands);
+    // Return the op's own result to signal that an in-place fold occurred.
     return op.getResult();
   }
   return {};
@@ -2148,8 +2149,8 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
   // Fold `arith.constant` indices into the `vector.extract` operation. Make
   // sure that patterns requiring constant indices are added after this fold.
   SmallVector<Value> operands = {getVector()};
-  if (auto val = extractInsertFoldConstantOp(*this, adaptor, operands))
-    return val;
+  auto inplaceFolded = extractInsertFoldConstantOp(*this, adaptor, operands);
+
   if (auto res = foldPoisonIndexInsertExtractOp(
           getContext(), adaptor.getStaticPosition(), kPoisonIndex))
     return res;
@@ -2171,7 +2172,8 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
     return val;
   if (auto val = foldScalarExtractFromFromElements(*this))
     return val;
-  return OpFoldResult();
+
+  return inplaceFolded;
 }
 
 namespace {
@@ -3150,8 +3152,8 @@ OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
   // Fold `arith.constant` indices into the `vector.insert` operation. Make
   // sure that patterns requiring constant indices are added after this fold.
   SmallVector<Value> operands = {getValueToStore(), getDest()};
-  if (auto val = extractInsertFoldConstantOp(*this, adaptor, operands))
-    return val;
+  auto inplaceFolded = extractInsertFoldConstantOp(*this, adaptor, operands);
+
   if (auto res = foldPoisonIndexInsertExtractOp(
           getContext(), adaptor.getStaticPosition(), kPoisonIndex))
     return res;
@@ -3161,7 +3163,7 @@ OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
     return res;
   }
 
-  return {};
+  return inplaceFolded;
 }
 
 //===----------------------------------------------------------------------===//

>From 7586aa1727bcba311ed4a6c0d1092da428c7f163 Mon Sep 17 00:00:00 2001
From: Yang Bai <yangb at nvidia.com>
Date: Sat, 31 May 2025 08:28:43 -0700
Subject: [PATCH 2/3] Refine comments on constant-index folding

---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 2e0c917b2139d..efcd75c5c1bb7 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2146,8 +2146,9 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
     return getVector();
   if (auto res = foldPoisonSrcExtractOp(adaptor.getVector()))
     return res;
-  // Fold `arith.constant` indices into the `vector.extract` operation. Make
-  // sure that patterns requiring constant indices are added after this fold.
+  // Fold `arith.constant` indices into the `vector.extract` operation.
+  // Do not stop here as this fold may enable subsequent folds that require
+  // constant indices.
   SmallVector<Value> operands = {getVector()};
   auto inplaceFolded = extractInsertFoldConstantOp(*this, adaptor, operands);
 
@@ -3149,8 +3150,9 @@ OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
   // (type mismatch).
   if (getNumIndices() == 0 && getValueToStoreType() == getType())
     return getValueToStore();
-  // Fold `arith.constant` indices into the `vector.insert` operation. Make
-  // sure that patterns requiring constant indices are added after this fold.
+  // Fold `arith.constant` indices into the `vector.insert` operation.
+  // Do not stop here as this fold may enable subsequent folds that require
+  // constant indices.
   SmallVector<Value> operands = {getValueToStore(), getDest()};
   auto inplaceFolded = extractInsertFoldConstantOp(*this, adaptor, operands);
 

>From 22453c9f6309b721bbd747c2194c88b126ee78ae Mon Sep 17 00:00:00 2001
From: Yang Bai <yangb at nvidia.com>
Date: Sat, 31 May 2025 08:33:37 -0700
Subject: [PATCH 3/3] Add tests for single-pass insert/extract folding

---
 mlir/test/Dialect/Vector/constant-fold.mlir | 26 +++++++++++++++++++++
 1 file changed, 26 insertions(+)

diff --git a/mlir/test/Dialect/Vector/constant-fold.mlir b/mlir/test/Dialect/Vector/constant-fold.mlir
index 66c91d6b2041b..b23e76f590d78 100644
--- a/mlir/test/Dialect/Vector/constant-fold.mlir
+++ b/mlir/test/Dialect/Vector/constant-fold.mlir
@@ -11,3 +11,29 @@ func.func @fold_extract_transpose_negative(%arg0: vector<4x4xf16>) -> vector<4x4
   %2 = vector.extract %1[0] : vector<4x4xf16> from vector<1x4x4xf16>
   return %2 : vector<4x4xf16>
 }
+
+// -----
+
+// CHECK-LABEL: fold_extract_in_single_pass
+// CHECK-SAME: (%{{.*}}: vector<4xf16>, %[[ARG1:.+]]: f16)
+func.func @fold_extract_in_single_pass(%arg0: vector<4xf16>, %arg1: f16) -> f16 {
+  %0 = vector.insert %arg1, %arg0 [1] : f16 into vector<4xf16>
+  %c1 = arith.constant 1 : index
+  // Verify that folding completes in a single pass even when the position is a dynamic (SSA-value) index.
+  %1 = vector.extract %0[%c1] : f16 from vector<4xf16>
+  // CHECK: return %[[ARG1]] : f16
+  return %1 : f16
+}
+
+// -----
+
+// CHECK-LABEL: fold_insert_in_single_pass
+func.func @fold_insert_in_single_pass() -> vector<2xf16> {
+  %cst = arith.constant dense<0.000000e+00> : vector<2xf16>
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2.5 : f16
+  // Verify that folding completes in a single pass even when the position is a dynamic (SSA-value) index.
+  // CHECK: arith.constant dense<[0.000000e+00, 2.500000e+00]> : vector<2xf16>
+  %0 = vector.insert %c2, %cst [%c1] : f16 into vector<2xf16>
+  return %0 : vector<2xf16>
+}



More information about the Mlir-commits mailing list