[Mlir-commits] [mlir] fe3933d - [mlir][vector] Support complete folding in single pass for vector.insert/vector.extract (#142124)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jun 18 09:26:07 PDT 2025


Author: Yang Bai
Date: 2025-06-18T09:26:04-07:00
New Revision: fe3933da15b5bc635bce156f1f8d11a784316a07

URL: https://github.com/llvm/llvm-project/commit/fe3933da15b5bc635bce156f1f8d11a784316a07
DIFF: https://github.com/llvm/llvm-project/commit/fe3933da15b5bc635bce156f1f8d11a784316a07.diff

LOG: [mlir][vector] Support complete folding in single pass for vector.insert/vector.extract (#142124)

### Description

This patch improves the folding efficiency of `vector.insert` and
`vector.extract` operations by not returning early after successfully
converting dynamic indices to static indices.

This PR also renames the test pass `TestConstantFold` to
`TestSingleFold` and adds comprehensive documentation explaining the
single-pass folding behavior.

### Motivation

Since the `OpBuilder::createOrFold` function only calls `fold` **once**,
the current `fold` methods of `vector.insert` and `vector.extract` may
leave the op in a state that can be folded further. For example,
consider the following un-folded IR:
```
%v1 = vector.insert %e1, %v0 [0] : f32 into vector<128xf32>
%c0 = arith.constant 0 : index
%e2 = vector.extract %v1[%c0] : f32 from vector<128xf32>
```
If we use `createOrFold` to create the `vector.extract` op, then the
result will be:
```
%v1 = vector.insert %e1, %v0 [0] : f32 into vector<128xf32>
%e2 = vector.extract %v1[0] : f32 from vector<128xf32>
```
But this is not the optimal result: `createOrFold` should have returned
`%e1`.
The reason is that `fold` returns immediately once
`extractInsertFoldConstantOp` succeeds, so the subsequent folding logic
(which needs the now-static indices) is skipped.

---------

Co-authored-by: Yang Bai <yangb at nvidia.com>

Added: 
    mlir/test/Dialect/Vector/single-fold.mlir
    mlir/test/lib/Transforms/TestSingleFold.cpp

Modified: 
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/test/Dialect/Affine/constant-fold.mlir
    mlir/test/Dialect/Linalg/mesh-spmdization.mlir
    mlir/test/Dialect/Mesh/spmdization.mlir
    mlir/test/Dialect/Tensor/mesh-spmdization.mlir
    mlir/test/Dialect/Tosa/constant_folding.mlir
    mlir/test/Dialect/Vector/constant-fold.mlir
    mlir/test/Transforms/constant-fold-debuginfo.mlir
    mlir/test/Transforms/constant-fold.mlir
    mlir/test/lib/Transforms/CMakeLists.txt
    mlir/tools/mlir-opt/mlir-opt.cpp

Removed: 
    mlir/test/lib/Transforms/TestConstantFold.cpp


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 2a2357319bd23..e576eeac23656 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2063,6 +2063,7 @@ static Value extractInsertFoldConstantOp(OpType op, AdaptorType adaptor,
   if (opChange) {
     op.setStaticPosition(staticPosition);
     op.getOperation()->setOperands(operands);
+    // Return the original result to indicate an in-place folding happened.
     return op.getResult();
   }
   return {};
@@ -2146,11 +2147,12 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
     return getVector();
   if (auto res = foldPoisonSrcExtractOp(adaptor.getVector()))
     return res;
-  // Fold `arith.constant` indices into the `vector.extract` operation. Make
-  // sure that patterns requiring constant indices are added after this fold.
+  // Fold `arith.constant` indices into the `vector.extract` operation.
+  // Do not stop here as this fold may enable subsequent folds that require
+  // constant indices.
   SmallVector<Value> operands = {getVector()};
-  if (auto val = extractInsertFoldConstantOp(*this, adaptor, operands))
-    return val;
+  auto inplaceFolded = extractInsertFoldConstantOp(*this, adaptor, operands);
+
   if (auto res = foldPoisonIndexInsertExtractOp(
           getContext(), adaptor.getStaticPosition(), kPoisonIndex))
     return res;
@@ -2172,7 +2174,8 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
     return val;
   if (auto val = foldScalarExtractFromFromElements(*this))
     return val;
-  return OpFoldResult();
+
+  return inplaceFolded;
 }
 
 namespace {
@@ -3272,11 +3275,12 @@ OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
   // (type mismatch).
   if (getNumIndices() == 0 && getValueToStoreType() == getType())
     return getValueToStore();
-  // Fold `arith.constant` indices into the `vector.insert` operation. Make
-  // sure that patterns requiring constant indices are added after this fold.
+  // Fold `arith.constant` indices into the `vector.insert` operation.
+  // Do not stop here as this fold may enable subsequent folds that require
+  // constant indices.
   SmallVector<Value> operands = {getValueToStore(), getDest()};
-  if (auto val = extractInsertFoldConstantOp(*this, adaptor, operands))
-    return val;
+  auto inplaceFolded = extractInsertFoldConstantOp(*this, adaptor, operands);
+
   if (auto res = foldPoisonIndexInsertExtractOp(
           getContext(), adaptor.getStaticPosition(), kPoisonIndex))
     return res;
@@ -3286,7 +3290,7 @@ OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
     return res;
   }
 
-  return {};
+  return inplaceFolded;
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Affine/constant-fold.mlir b/mlir/test/Dialect/Affine/constant-fold.mlir
index ffc3946db08df..8bddacc024751 100644
--- a/mlir/test/Dialect/Affine/constant-fold.mlir
+++ b/mlir/test/Dialect/Affine/constant-fold.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -test-constant-fold -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -test-single-fold -split-input-file %s | FileCheck %s
 
 // CHECK-LABEL: func @affine_apply
 func.func @affine_apply(%variable : index) -> (index, index, index) {

diff  --git a/mlir/test/Dialect/Linalg/mesh-spmdization.mlir b/mlir/test/Dialect/Linalg/mesh-spmdization.mlir
index 487cec00de16a..9805ee4ea5525 100644
--- a/mlir/test/Dialect/Linalg/mesh-spmdization.mlir
+++ b/mlir/test/Dialect/Linalg/mesh-spmdization.mlir
@@ -1,5 +1,5 @@
 // RUN: mlir-opt \
-// RUN:  --pass-pipeline="builtin.module(func.func(mesh-spmdization,test-constant-fold))" \
+// RUN:  --pass-pipeline="builtin.module(func.func(mesh-spmdization,test-single-fold))" \
 // RUN:  --split-input-file \
 // RUN:  %s | FileCheck %s
 

diff  --git a/mlir/test/Dialect/Mesh/spmdization.mlir b/mlir/test/Dialect/Mesh/spmdization.mlir
index 5c9fd29444f04..af4ab58ea50a3 100644
--- a/mlir/test/Dialect/Mesh/spmdization.mlir
+++ b/mlir/test/Dialect/Mesh/spmdization.mlir
@@ -1,5 +1,5 @@
 // RUN: mlir-opt \
-// RUN:   --pass-pipeline="builtin.module(func.func(mesh-spmdization,test-constant-fold))" \
+// RUN:   --pass-pipeline="builtin.module(func.func(mesh-spmdization,test-single-fold))" \
 // RUN:   %s | FileCheck %s
 
 mesh.mesh @mesh_1d(shape = 2)

diff  --git a/mlir/test/Dialect/Tensor/mesh-spmdization.mlir b/mlir/test/Dialect/Tensor/mesh-spmdization.mlir
index 3fb8424745501..8598d81ff6cfa 100644
--- a/mlir/test/Dialect/Tensor/mesh-spmdization.mlir
+++ b/mlir/test/Dialect/Tensor/mesh-spmdization.mlir
@@ -1,5 +1,5 @@
 // RUN: mlir-opt \
-// RUN:   --pass-pipeline="builtin.module(func.func(mesh-spmdization,test-constant-fold))" \
+// RUN:   --pass-pipeline="builtin.module(func.func(mesh-spmdization,test-single-fold))" \
 // RUN:   %s | FileCheck %s
 
 mesh.mesh @mesh_1d_4(shape = 4)

diff  --git a/mlir/test/Dialect/Tosa/constant_folding.mlir b/mlir/test/Dialect/Tosa/constant_folding.mlir
index 9b6ccdb54c107..d477a2479e913 100644
--- a/mlir/test/Dialect/Tosa/constant_folding.mlir
+++ b/mlir/test/Dialect/Tosa/constant_folding.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt --test-constant-fold %s | FileCheck %s
+// RUN: mlir-opt --test-single-fold %s | FileCheck %s
 
 // CHECK-LABEL: func @test_const
 func.func @test_const(%arg0 : index) -> tensor<4xi32> {

diff  --git a/mlir/test/Dialect/Vector/constant-fold.mlir b/mlir/test/Dialect/Vector/constant-fold.mlir
index 66c91d6b2041b..cbb159fd59ffc 100644
--- a/mlir/test/Dialect/Vector/constant-fold.mlir
+++ b/mlir/test/Dialect/Vector/constant-fold.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -split-input-file -test-constant-fold | FileCheck %s
+// RUN: mlir-opt %s -split-input-file -test-single-fold | FileCheck %s
 
 // CHECK-LABEL: fold_extract_transpose_negative
 func.func @fold_extract_transpose_negative(%arg0: vector<4x4xf16>) -> vector<4x4xf16> {
@@ -11,3 +11,5 @@ func.func @fold_extract_transpose_negative(%arg0: vector<4x4xf16>) -> vector<4x4
   %2 = vector.extract %1[0] : vector<4x4xf16> from vector<1x4x4xf16>
   return %2 : vector<4x4xf16>
 }
+
+

diff  --git a/mlir/test/Dialect/Vector/single-fold.mlir b/mlir/test/Dialect/Vector/single-fold.mlir
new file mode 100644
index 0000000000000..baccdc3f51c05
--- /dev/null
+++ b/mlir/test/Dialect/Vector/single-fold.mlir
@@ -0,0 +1,38 @@
+// RUN: mlir-opt %s -split-input-file -test-single-fold | FileCheck %s
+
+// The tests in this file verify that fold() methods can handle complex
+// optimization scenarios without requiring multiple folding iterations.
+// This is important because:
+//
+// 1. OpBuilder::createOrFold() only calls fold() once, so operations must
+//    be fully optimized in that single call
+// 2. Multiple rounds of folding would incur higher performance costs,
+//    so it's more efficient to complete all optimizations in one pass
+//
+// These tests ensure that folding implementations are robust and complete,
+// avoiding situations where operations are left in intermediate states
+// that could be further optimized.
+
+// CHECK-LABEL: fold_extract_in_single_pass
+// CHECK-SAME: (%{{.*}}: vector<4xf16>, %[[ARG1:.+]]: f16)
+func.func @fold_extract_in_single_pass(%arg0: vector<4xf16>, %arg1: f16) -> f16 {
+  %0 = vector.insert %arg1, %arg0 [1] : f16 into vector<4xf16>
+  %c1 = arith.constant 1 : index
+  // Verify that the fold is finished in a single pass even if the index is dynamic.
+  %1 = vector.extract %0[%c1] : f16 from vector<4xf16>
+  // CHECK: return %[[ARG1]] : f16
+  return %1 : f16
+}
+
+// -----
+
+// CHECK-LABEL: fold_insert_in_single_pass
+func.func @fold_insert_in_single_pass() -> vector<2xf16> {
+  %cst = arith.constant dense<0.000000e+00> : vector<2xf16>
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2.5 : f16
+  // Verify that the fold is finished in a single pass even if the index is dynamic.
+  // CHECK: arith.constant dense<[0.000000e+00, 2.500000e+00]> : vector<2xf16>
+  %0 = vector.insert %c2, %cst [%c1] : f16 into vector<2xf16>
+  return %0 : vector<2xf16>
+} 
\ No newline at end of file

diff  --git a/mlir/test/Transforms/constant-fold-debuginfo.mlir b/mlir/test/Transforms/constant-fold-debuginfo.mlir
index c308bc477bee4..4fa7fb6698a2b 100644
--- a/mlir/test/Transforms/constant-fold-debuginfo.mlir
+++ b/mlir/test/Transforms/constant-fold-debuginfo.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -split-input-file -test-constant-fold -mlir-print-debuginfo | FileCheck %s
+// RUN: mlir-opt %s -split-input-file -test-single-fold -mlir-print-debuginfo | FileCheck %s
 
 // CHECK-LABEL: func @fold_and_merge
 func.func @fold_and_merge() -> (i32, i32) {

diff  --git a/mlir/test/Transforms/constant-fold.mlir b/mlir/test/Transforms/constant-fold.mlir
index 981757aed9b1d..0b393bf0556b9 100644
--- a/mlir/test/Transforms/constant-fold.mlir
+++ b/mlir/test/Transforms/constant-fold.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -allow-unregistered-dialect %s -split-input-file -test-constant-fold | FileCheck %s
+// RUN: mlir-opt -allow-unregistered-dialect %s -split-input-file -test-single-fold | FileCheck %s
 
 // -----
 

diff  --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt
index 76041cd6cd791..ddc0a779e8f69 100644
--- a/mlir/test/lib/Transforms/CMakeLists.txt
+++ b/mlir/test/lib/Transforms/CMakeLists.txt
@@ -26,11 +26,11 @@ endif()
 add_mlir_library(MLIRTestTransforms
   TestCommutativityUtils.cpp
   TestCompositePass.cpp
-  TestConstantFold.cpp
   TestControlFlowSink.cpp
   TestInlining.cpp
   TestInliningCallback.cpp
   TestMakeIsolatedFromAbove.cpp
+  TestSingleFold.cpp
   TestTransformsOps.cpp
   ${MLIRTestTransformsPDLSrc}
 

diff  --git a/mlir/test/lib/Transforms/TestConstantFold.cpp b/mlir/test/lib/Transforms/TestSingleFold.cpp
similarity index 62%
rename from mlir/test/lib/Transforms/TestConstantFold.cpp
rename to mlir/test/lib/Transforms/TestSingleFold.cpp
index c97ab9091cb66..5bd9dd2a1f075 100644
--- a/mlir/test/lib/Transforms/TestConstantFold.cpp
+++ b/mlir/test/lib/Transforms/TestSingleFold.cpp
@@ -1,4 +1,4 @@
-//===- TestConstantFold.cpp - Pass to test constant folding ---------------===//
+//===- TestSingleFold.cpp - Pass to test single-pass folding --------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -12,14 +12,23 @@
 using namespace mlir;
 
 namespace {
-/// Simple constant folding pass.
-struct TestConstantFold : public PassWrapper<TestConstantFold, OperationPass<>>,
-                          public RewriterBase::Listener {
-  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestConstantFold)
+/// Test pass for single-pass constant folding.
+///
+/// This pass tests the behavior of operations when folded exactly once. Unlike
+/// canonicalization passes that may apply multiple rounds of folding, this pass
+/// ensures that each operation is folded at most once, which is useful for
+/// testing scenarios where the fold implementation should handle complex cases
+/// without requiring multiple iterations.
+///
+/// The pass also removes dead constants after folding to clean up unused
+/// intermediate results.
+struct TestSingleFold : public PassWrapper<TestSingleFold, OperationPass<>>,
+                        public RewriterBase::Listener {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSingleFold)
 
-  StringRef getArgument() const final { return "test-constant-fold"; }
+  StringRef getArgument() const final { return "test-single-fold"; }
   StringRef getDescription() const final {
-    return "Test operation constant folding";
+    return "Test single-pass operation folding and dead constant elimination";
   }
   // All constants in the operation post folding.
   SmallVector<Operation *> existingConstants;
@@ -39,18 +48,19 @@ struct TestConstantFold : public PassWrapper<TestConstantFold, OperationPass<>>,
 };
 } // namespace
 
-void TestConstantFold::foldOperation(Operation *op, OperationFolder &helper) {
+void TestSingleFold::foldOperation(Operation *op, OperationFolder &helper) {
   // Attempt to fold the specified operation, including handling unused or
   // duplicated constants.
   (void)helper.tryToFold(op);
 }
 
-void TestConstantFold::runOnOperation() {
+void TestSingleFold::runOnOperation() {
   existingConstants.clear();
 
   // Collect and fold the operations within the operation.
   SmallVector<Operation *, 8> ops;
-  getOperation()->walk<mlir::WalkOrder::PreOrder>([&](Operation *op) { ops.push_back(op); });
+  getOperation()->walk<mlir::WalkOrder::PreOrder>(
+      [&](Operation *op) { ops.push_back(op); });
 
   // Fold the constants in reverse so that the last generated constants from
   // folding are at the beginning. This creates somewhat of a linear ordering to
@@ -70,6 +80,6 @@ void TestConstantFold::runOnOperation() {
 
 namespace mlir {
 namespace test {
-void registerTestConstantFold() { PassRegistration<TestConstantFold>(); }
+void registerTestSingleFold() { PassRegistration<TestSingleFold>(); }
 } // namespace test
 } // namespace mlir

diff  --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 6ef9ff8e84545..143a5e8e8f8dd 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -87,7 +87,6 @@ void registerTestCfAssertPass();
 void registerTestCFGLoopInfoPass();
 void registerTestComposeSubView();
 void registerTestCompositePass();
-void registerTestConstantFold();
 void registerTestControlFlowSink();
 void registerTestConvertToSPIRVPass();
 void registerTestDataLayoutPropagation();
@@ -145,6 +144,7 @@ void registerTestSCFUtilsPass();
 void registerTestSCFWhileOpBuilderPass();
 void registerTestSCFWrapInZeroTripCheckPasses();
 void registerTestShapeMappingPass();
+void registerTestSingleFold();
 void registerTestSliceAnalysisPass();
 void registerTestSPIRVCPURunnerPipeline();
 void registerTestSPIRVFuncSignatureConversion();
@@ -233,7 +233,6 @@ void registerTestPasses() {
   mlir::test::registerTestCFGLoopInfoPass();
   mlir::test::registerTestComposeSubView();
   mlir::test::registerTestCompositePass();
-  mlir::test::registerTestConstantFold();
   mlir::test::registerTestControlFlowSink();
   mlir::test::registerTestConvertToSPIRVPass();
   mlir::test::registerTestDataLayoutPropagation();
@@ -291,6 +290,7 @@ void registerTestPasses() {
   mlir::test::registerTestSCFWhileOpBuilderPass();
   mlir::test::registerTestSCFWrapInZeroTripCheckPasses();
   mlir::test::registerTestShapeMappingPass();
+  mlir::test::registerTestSingleFold();
   mlir::test::registerTestSliceAnalysisPass();
   mlir::test::registerTestSPIRVCPURunnerPipeline();
   mlir::test::registerTestSPIRVFuncSignatureConversion();


        


More information about the Mlir-commits mailing list