[Mlir-commits] [mlir] [mlir] Always update ExtractValue to use last container in insert chain (PR #176588)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Jan 17 11:22:10 PST 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: None (neildhar)

<details>
<summary>Changes</summary>

Note: This PR is stacked on top of #176583

The current logic updates the container operand to the last `InsertValueOp` in a chain only if we haven't switched to a nested insert chain. Instead, keep track of the new container value and extract position at all times, and always update both operands whenever we have found a point higher up in the chain to extract from.

This allows us to bypass more insertions (see the updated test) when we are accessing nested struct members. It also allows us to move the constant check back to the top. If the traversal ends at a constant, the op is first rewritten to extract directly from it, and a subsequent call to fold then performs the constant extraction.

Also adds a test for a previously untested case (its behavior is unchanged by this PR).

---
Full diff: https://github.com/llvm/llvm-project/pull/176588.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp (+28-30) 
- (modified) mlir/test/Dialect/LLVMIR/canonicalize.mlir (+19-5) 


``````````diff
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index f9162b35966c1..1df0285c91c42 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -1928,11 +1928,20 @@ OpFoldResult LLVM::ExtractValueOp::fold(FoldAdaptor adaptor) {
     return getResult();
   }
 
-  Operation *container = getContainer().getDefiningOp();
-  OpFoldResult result = {};
+  Attribute containerAttr;
+  if (matchPattern(getContainer(), m_Constant(&containerAttr))) {
+    for (int64_t pos : getPosition()) {
+      containerAttr = extractElementAt(containerAttr, pos);
+      if (!containerAttr)
+        return nullptr;
+    }
+    return containerAttr;
+  }
+
+  Value container = getContainer();
   ArrayRef<int64_t> extractPos = getPosition();
-  bool switchedToInsertedValue = false;
-  while (auto insertValueOp = dyn_cast_if_present<InsertValueOp>(container)) {
+  while (auto insertValueOp =
+             dyn_cast_if_present<InsertValueOp>(container.getDefiningOp())) {
     ArrayRef<int64_t> insertPos = insertValueOp.getPosition();
     auto extractPosSize = extractPos.size();
     auto insertPosSize = insertPos.size();
@@ -1955,18 +1964,16 @@ OpFoldResult LLVM::ExtractValueOp::fold(FoldAdaptor adaptor) {
     // In the above example, %4 is folded to %arg1.
     if (extractPosSize > insertPosSize &&
         extractPos.take_front(insertPosSize) == insertPos) {
-      container = insertValueOp.getValue().getDefiningOp();
+      container = insertValueOp.getValue();
       extractPos = extractPos.drop_front(insertPosSize);
-      switchedToInsertedValue = true;
       continue;
     }
 
     // Case 3: Try to continue the traversal with the container value.
-    unsigned min = std::min(extractPosSize, insertPosSize);
 
-    // If one is fully prefix of the other, stop propagating back as it will
-    // miss dependencies. For instance, %3 should not fold to %f0 in the
-    // following example:
+    // If extract position is a prefix of insert position, stop propagating back
+    // as it will miss dependencies. For instance, %3 should not fold to %f0 in
+    // the following example:
     // ```
     //   %1 = llvm.insertvalue %f0, %0[0, 0] :
     //     !llvm.array<4 x !llvm.array<4 x f32>>
@@ -1974,31 +1981,22 @@ OpFoldResult LLVM::ExtractValueOp::fold(FoldAdaptor adaptor) {
     //     !llvm.array<4 x !llvm.array<4 x f32>>
     //   %3 = llvm.extractvalue %2[0, 0] : !llvm.array<4 x !llvm.array<4 x f32>>
     // ```
-    if (extractPos.take_front(min) == insertPos.take_front(min))
-      return result;
+    if (insertPosSize > extractPosSize &&
+        extractPos == insertPos.take_front(extractPosSize))
+      break;
     // If neither a prefix, nor the exact position, we can extract out of the
     // value being inserted into. Moreover, we can try again if that operand
     // is itself an insertvalue expression.
-    if (!switchedToInsertedValue) {
-      // Do not swap out the container operand if we decided earlier to
-      // continue the traversal with the inserted value (Case 2).
-      getContainerMutable().assign(insertValueOp.getContainer());
-      result = getResult();
-    }
-    container = insertValueOp.getContainer().getDefiningOp();
+    container = insertValueOp.getContainer();
   }
-  if (!container)
-    return result;
 
-  Attribute containerAttr;
-  if (!matchPattern(container, m_Constant(&containerAttr)))
-    return nullptr;
-  for (int64_t pos : extractPos) {
-    containerAttr = extractElementAt(containerAttr, pos);
-    if (!containerAttr)
-      return nullptr;
-  }
-  return containerAttr;
+  // If we identified a container higher up in the chain, update the position
+  // and container operands.
+  if (container == getContainer())
+    return {};
+  setPosition(extractPos);
+  getContainerMutable().assign(container);
+  return getResult();
 }
 
 LogicalResult ExtractValueOp::verify() {
diff --git a/mlir/test/Dialect/LLVMIR/canonicalize.mlir b/mlir/test/Dialect/LLVMIR/canonicalize.mlir
index 8303afc9eb033..b1c2df87d4867 100644
--- a/mlir/test/Dialect/LLVMIR/canonicalize.mlir
+++ b/mlir/test/Dialect/LLVMIR/canonicalize.mlir
@@ -78,8 +78,7 @@ llvm.func @no_fold_extractvalue(%arr: !llvm.array<4 x f32>) -> f32 {
   %f0 = arith.constant 0.0 : f32
   %0 = llvm.mlir.undef : !llvm.array<4 x !llvm.array<4 x f32>>
 
-  // CHECK: insertvalue
-  // CHECK: insertvalue
+  // CHECK-NOT: insertvalue
   // CHECK: extractvalue
   %1 = llvm.insertvalue %f0, %0[0, 0] : !llvm.array<4 x !llvm.array<4 x f32>>
   %2 = llvm.insertvalue %arr, %1[0] : !llvm.array<4 x !llvm.array<4 x f32>>
@@ -90,6 +89,21 @@ llvm.func @no_fold_extractvalue(%arr: !llvm.array<4 x f32>) -> f32 {
 
 // -----
 
+// CHECK-LABEL: fold_nested_extractvalue
+// CHECK-SAME:     %[[arg1:.*]]: i32, %[[arg2:.*]]: i32)
+// CHECK-NOT: insertvalue
+// CHECK-NOT: extractvalue
+// CHECK: llvm.return %[[arg1]] : i32
+llvm.func @fold_nested_extractvalue(%arg1: i32, %arg2: i32) -> i32 {
+  %0 = llvm.mlir.undef : !llvm.struct<(struct<(i32, i32)>, i32)>
+  %1 = llvm.insertvalue %arg1, %0[0, 0] : !llvm.struct<(struct<(i32, i32)>, i32)>
+  %2 = llvm.insertvalue %arg2, %1[0, 1] : !llvm.struct<(struct<(i32, i32)>, i32)>
+  %3 = llvm.extractvalue %2[0, 0] : !llvm.struct<(struct<(i32, i32)>, i32)>
+  llvm.return %3 : i32
+}
+
+// -----
+
 // CHECK-LABEL: fold_unrelated_extractvalue
 llvm.func @fold_unrelated_extractvalue(%arr: !llvm.array<4 x f32>) -> f32 {
   %f0 = arith.constant 0.0 : f32
@@ -103,10 +117,10 @@ llvm.func @fold_unrelated_extractvalue(%arr: !llvm.array<4 x f32>) -> f32 {
 // -----
 // CHECK-LABEL: fold_extract_extractvalue
 llvm.func @fold_extract_extractvalue(%arr: !llvm.struct<(i64, array<1 x ptr<1>>)>) -> !llvm.ptr<1> {
-  // CHECK: llvm.extractvalue %{{.*}}[1, 0] 
+  // CHECK: llvm.extractvalue %{{.*}}[1, 0]
   // CHECK-NOT: extractvalue
-  %a = llvm.extractvalue %arr[1] : !llvm.struct<(i64, array<1 x ptr<1>>)> 
-  %b = llvm.extractvalue %a[0] : !llvm.array<1 x ptr<1>> 
+  %a = llvm.extractvalue %arr[1] : !llvm.struct<(i64, array<1 x ptr<1>>)>
+  %b = llvm.extractvalue %a[0] : !llvm.array<1 x ptr<1>>
   llvm.return %b : !llvm.ptr<1>
 }
 

``````````

</details>


https://github.com/llvm/llvm-project/pull/176588


More information about the Mlir-commits mailing list