[Mlir-commits] [mlir] [mlir][LLVM] Improve `llvm.extractvalue` folder (PR #136861)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Apr 23 06:19:47 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

<details>
<summary>Changes</summary>

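When the position of an `llvm.insertvalue` is a strict prefix of the `llvm.extractvalue` position, continue the traversal along the SSA chain of the inserted value (with the remaining position suffix) to expose additional folding opportunities.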


---
Full diff: https://github.com/llvm/llvm-project/pull/136861.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp (+39-8) 
- (modified) mlir/test/Dialect/LLVMIR/canonicalize.mlir (+16) 


``````````diff
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 0022be84c212e..26c3ef1e8b8bf 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -1885,11 +1885,40 @@ OpFoldResult LLVM::ExtractValueOp::fold(FoldAdaptor adaptor) {
 
   auto insertValueOp = getContainer().getDefiningOp<InsertValueOp>();
   OpFoldResult result = {};
+  ArrayRef<int64_t> extractPos = getPosition();
+  bool switchedToInsertedValue = false;
   while (insertValueOp) {
-    if (getPosition() == insertValueOp.getPosition())
+    ArrayRef<int64_t> insertPos = insertValueOp.getPosition();
+    auto extractPosSize = extractPos.size();
+    auto insertPosSize = insertPos.size();
+
+    // Case 1: Exact match of positions.
+    if (extractPos == insertPos)
       return insertValueOp.getValue();
-    unsigned min =
-        std::min(getPosition().size(), insertValueOp.getPosition().size());
+
+    // Case 2: Insert position is a prefix of extract position. Continue
+    // traversal with the inserted value. Example:
+    // ```
+    // %0 = llvm.insertvalue %arg1, %undef[0] : !llvm.struct<(i32, i32, i32)>
+    // %1 = llvm.insertvalue %arg2, %0[1] : !llvm.struct<(i32, i32, i32)>
+    // %2 = llvm.insertvalue %arg3, %1[2] : !llvm.struct<(i32, i32, i32)>
+    // %3 = llvm.insertvalue %2, %foo[0]
+    //     : !llvm.struct<(struct<(i32, i32, i32)>, i64)>
+    // %4 = llvm.extractvalue %3[0, 0]
+    //     : !llvm.struct<(struct<(i32, i32, i32)>, i64)>
+    // ```
+    // In the above example, %4 is folded to %arg1.
+    if (extractPosSize > insertPosSize &&
+        extractPos.take_front(insertPosSize) == insertPos) {
+      insertValueOp = insertValueOp.getValue().getDefiningOp<InsertValueOp>();
+      extractPos = extractPos.drop_front(insertPosSize);
+      switchedToInsertedValue = true;
+      continue;
+    }
+
+    // Case 3: Try to continue the traversal with the container value.
+    unsigned min = std::min(extractPosSize, insertPosSize);
+
     // If one is fully prefix of the other, stop propagating back as it will
     // miss dependencies. For instance, %3 should not fold to %f0 in the
     // following example:
@@ -1900,15 +1929,17 @@ OpFoldResult LLVM::ExtractValueOp::fold(FoldAdaptor adaptor) {
     //     !llvm.array<4 x !llvm.array<4 x f32>>
     //   %3 = llvm.extractvalue %2[0, 0] : !llvm.array<4 x !llvm.array<4 x f32>>
     // ```
-    if (getPosition().take_front(min) ==
-        insertValueOp.getPosition().take_front(min))
+    if (extractPos.take_front(min) == insertPos.take_front(min))
       return result;
-
     // If neither a prefix, nor the exact position, we can extract out of the
     // value being inserted into. Moreover, we can try again if that operand
     // is itself an insertvalue expression.
-    getContainerMutable().assign(insertValueOp.getContainer());
-    result = getResult();
+    if (!switchedToInsertedValue) {
+      // Do not swap out the container operand if we decided earlier to
+      // continue the traversal with the inserted value (Case 2).
+      getContainerMutable().assign(insertValueOp.getContainer());
+      result = getResult();
+    }
     insertValueOp = insertValueOp.getContainer().getDefiningOp<InsertValueOp>();
   }
   return result;
diff --git a/mlir/test/Dialect/LLVMIR/canonicalize.mlir b/mlir/test/Dialect/LLVMIR/canonicalize.mlir
index a793caca064ec..8accf6e263863 100644
--- a/mlir/test/Dialect/LLVMIR/canonicalize.mlir
+++ b/mlir/test/Dialect/LLVMIR/canonicalize.mlir
@@ -57,6 +57,22 @@ llvm.func @fold_extractvalue() -> i32 {
 
 // -----
 
+// CHECK-LABEL: fold_extractvalue_nested(
+//  CHECK-SAME:     %[[arg1:.*]]: i32, %[[arg2:.*]]: i32, %[[arg3:.*]]: i32)
+//  CHECK-NEXT:   llvm.return %[[arg1]] : i32
+llvm.func @fold_extractvalue_nested(%arg1: i32, %arg2: i32, %arg3: i32) -> i32 {
+  %3 = llvm.mlir.undef : !llvm.struct<(struct<(i32, i32, i32)>, struct<(i32, i32)>)>
+  %5 = llvm.mlir.undef : !llvm.struct<(i32, i32, i32)>
+  %6 = llvm.insertvalue %arg1, %5[0] : !llvm.struct<(i32, i32, i32)>
+  %7 = llvm.insertvalue %arg2, %6[1] : !llvm.struct<(i32, i32, i32)>
+  %8 = llvm.insertvalue %arg3, %7[2] : !llvm.struct<(i32, i32, i32)>
+  %11 = llvm.insertvalue %8, %3[0] : !llvm.struct<(struct<(i32, i32, i32)>, struct<(i32, i32)>)>
+  %13 = llvm.extractvalue %11[0, 0] : !llvm.struct<(struct<(i32, i32, i32)>, struct<(i32, i32)>)> // [0] is a prefix of [0, 0]: fold via %8 down to %arg1 (distinct args pin the index).
+  llvm.return %13 : i32
+}
+
+// -----
+
 // CHECK-LABEL: no_fold_extractvalue
 llvm.func @no_fold_extractvalue(%arr: !llvm.array<4 x f32>) -> f32 {
   %f0 = arith.constant 0.0 : f32

``````````

</details>


https://github.com/llvm/llvm-project/pull/136861


More information about the Mlir-commits mailing list