[Mlir-commits] [mlir] [mlir] Always update ExtractValue to use last container in insert chain (PR #176588)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Jan 18 16:33:26 PST 2026


https://github.com/neildhar updated https://github.com/llvm/llvm-project/pull/176588

>From 4f2e5ccbe5e6fdbce069e6e9302888001d773b9c Mon Sep 17 00:00:00 2001
From: Neil Dhar <neildhar at meta.com>
Date: Sat, 17 Jan 2026 09:46:25 -0800
Subject: [PATCH 1/2] [NFC][mlir] Clarify bail condition in
 ExtractValueOp::fold

---
 mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index f9162b35966c1..957f5348c9572 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -1962,11 +1962,10 @@ OpFoldResult LLVM::ExtractValueOp::fold(FoldAdaptor adaptor) {
     }
 
     // Case 3: Try to continue the traversal with the container value.
-    unsigned min = std::min(extractPosSize, insertPosSize);
 
-    // If one is fully prefix of the other, stop propagating back as it will
-    // miss dependencies. For instance, %3 should not fold to %f0 in the
-    // following example:
+    // If extract position is a prefix of insert position, stop propagating back
+    // as it will miss dependencies. For instance, %3 should not fold to %f0 in
+    // the following example:
     // ```
     //   %1 = llvm.insertvalue %f0, %0[0, 0] :
     //     !llvm.array<4 x !llvm.array<4 x f32>>
@@ -1974,7 +1973,8 @@ OpFoldResult LLVM::ExtractValueOp::fold(FoldAdaptor adaptor) {
     //     !llvm.array<4 x !llvm.array<4 x f32>>
     //   %3 = llvm.extractvalue %2[0, 0] : !llvm.array<4 x !llvm.array<4 x f32>>
     // ```
-    if (extractPos.take_front(min) == insertPos.take_front(min))
+    if (insertPosSize > extractPosSize &&
+        extractPos == insertPos.take_front(extractPosSize))
       return result;
     // If neither a prefix, nor the exact position, we can extract out of the
     // value being inserted into. Moreover, we can try again if that operand

>From 40a5214f930abd0b21afd0a2236fff3612c20edf Mon Sep 17 00:00:00 2001
From: Neil Dhar <neildhar at meta.com>
Date: Sat, 17 Jan 2026 10:48:02 -0800
Subject: [PATCH 2/2] [mlir] Always update ExtractValue to use last container
 in insert chain

---
 mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 48 +++++++++++-----------
 mlir/test/Dialect/LLVMIR/canonicalize.mlir | 24 ++++++++---
 2 files changed, 42 insertions(+), 30 deletions(-)

diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 957f5348c9572..4bd840b852fb6 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -1928,11 +1928,19 @@ OpFoldResult LLVM::ExtractValueOp::fold(FoldAdaptor adaptor) {
     return getResult();
   }
 
-  Operation *container = getContainer().getDefiningOp();
-  OpFoldResult result = {};
+  Attribute containerAttr;
+  if (matchPattern(getContainer(), m_Constant(&containerAttr))) {
+    for (int64_t pos : getPosition()) {
+      containerAttr = extractElementAt(containerAttr, pos);
+      if (!containerAttr)
+        return nullptr;
+    }
+    return containerAttr;
+  }
+
+  Value container = getContainer();
   ArrayRef<int64_t> extractPos = getPosition();
-  bool switchedToInsertedValue = false;
-  while (auto insertValueOp = dyn_cast_if_present<InsertValueOp>(container)) {
+  while (auto insertValueOp = container.getDefiningOp<InsertValueOp>()) {
     ArrayRef<int64_t> insertPos = insertValueOp.getPosition();
     auto extractPosSize = extractPos.size();
     auto insertPosSize = insertPos.size();
@@ -1955,9 +1963,8 @@ OpFoldResult LLVM::ExtractValueOp::fold(FoldAdaptor adaptor) {
     // In the above example, %4 is folded to %arg1.
     if (extractPosSize > insertPosSize &&
         extractPos.take_front(insertPosSize) == insertPos) {
-      container = insertValueOp.getValue().getDefiningOp();
+      container = insertValueOp.getValue();
       extractPos = extractPos.drop_front(insertPosSize);
-      switchedToInsertedValue = true;
       continue;
     }
 
@@ -1975,30 +1982,21 @@ OpFoldResult LLVM::ExtractValueOp::fold(FoldAdaptor adaptor) {
     // ```
     if (insertPosSize > extractPosSize &&
         extractPos == insertPos.take_front(extractPosSize))
-      return result;
+      break;
     // If neither a prefix, nor the exact position, we can extract out of the
     // value being inserted into. Moreover, we can try again if that operand
     // is itself an insertvalue expression.
-    if (!switchedToInsertedValue) {
-      // Do not swap out the container operand if we decided earlier to
-      // continue the traversal with the inserted value (Case 2).
-      getContainerMutable().assign(insertValueOp.getContainer());
-      result = getResult();
-    }
-    container = insertValueOp.getContainer().getDefiningOp();
+    container = insertValueOp.getContainer();
   }
-  if (!container)
-    return result;
 
-  Attribute containerAttr;
-  if (!matchPattern(container, m_Constant(&containerAttr)))
-    return nullptr;
-  for (int64_t pos : extractPos) {
-    containerAttr = extractElementAt(containerAttr, pos);
-    if (!containerAttr)
-      return nullptr;
-  }
-  return containerAttr;
+  // We failed to resolve past this container either because it is not an
+  // InsertValueOp, or it is an InsertValueOp that partially overlaps with the
+  // value being extracted. Update to read from this container instead.
+  if (container == getContainer())
+    return {};
+  setPosition(extractPos);
+  getContainerMutable().assign(container);
+  return getResult();
 }
 
 LogicalResult ExtractValueOp::verify() {
diff --git a/mlir/test/Dialect/LLVMIR/canonicalize.mlir b/mlir/test/Dialect/LLVMIR/canonicalize.mlir
index 8303afc9eb033..b1c2df87d4867 100644
--- a/mlir/test/Dialect/LLVMIR/canonicalize.mlir
+++ b/mlir/test/Dialect/LLVMIR/canonicalize.mlir
@@ -78,8 +78,7 @@ llvm.func @no_fold_extractvalue(%arr: !llvm.array<4 x f32>) -> f32 {
   %f0 = arith.constant 0.0 : f32
   %0 = llvm.mlir.undef : !llvm.array<4 x !llvm.array<4 x f32>>
 
-  // CHECK: insertvalue
-  // CHECK: insertvalue
+  // CHECK-NOT: insertvalue
   // CHECK: extractvalue
   %1 = llvm.insertvalue %f0, %0[0, 0] : !llvm.array<4 x !llvm.array<4 x f32>>
   %2 = llvm.insertvalue %arr, %1[0] : !llvm.array<4 x !llvm.array<4 x f32>>
@@ -90,6 +89,21 @@ llvm.func @no_fold_extractvalue(%arr: !llvm.array<4 x f32>) -> f32 {
 
 // -----
 
+// CHECK-LABEL: fold_nested_extractvalue
+// CHECK-SAME:     %[[arg1:.*]]: i32, %[[arg2:.*]]: i32)
+// CHECK-NOT: insertvalue
+// CHECK-NOT: extractvalue
+// CHECK: llvm.return %[[arg1]] : i32
+llvm.func @fold_nested_extractvalue(%arg1: i32, %arg2: i32) -> i32 {
+  %0 = llvm.mlir.undef : !llvm.struct<(struct<(i32, i32)>, i32)>
+  %1 = llvm.insertvalue %arg1, %0[0, 0] : !llvm.struct<(struct<(i32, i32)>, i32)>
+  %2 = llvm.insertvalue %arg2, %1[0, 1] : !llvm.struct<(struct<(i32, i32)>, i32)>
+  %3 = llvm.extractvalue %2[0, 0] : !llvm.struct<(struct<(i32, i32)>, i32)>
+  llvm.return %3 : i32
+}
+
+// -----
+
 // CHECK-LABEL: fold_unrelated_extractvalue
 llvm.func @fold_unrelated_extractvalue(%arr: !llvm.array<4 x f32>) -> f32 {
   %f0 = arith.constant 0.0 : f32
@@ -103,10 +117,10 @@ llvm.func @fold_unrelated_extractvalue(%arr: !llvm.array<4 x f32>) -> f32 {
 // -----
 // CHECK-LABEL: fold_extract_extractvalue
 llvm.func @fold_extract_extractvalue(%arr: !llvm.struct<(i64, array<1 x ptr<1>>)>) -> !llvm.ptr<1> {
-  // CHECK: llvm.extractvalue %{{.*}}[1, 0] 
+  // CHECK: llvm.extractvalue %{{.*}}[1, 0]
   // CHECK-NOT: extractvalue
-  %a = llvm.extractvalue %arr[1] : !llvm.struct<(i64, array<1 x ptr<1>>)> 
-  %b = llvm.extractvalue %a[0] : !llvm.array<1 x ptr<1>> 
+  %a = llvm.extractvalue %arr[1] : !llvm.struct<(i64, array<1 x ptr<1>>)>
+  %b = llvm.extractvalue %a[0] : !llvm.array<1 x ptr<1>>
   llvm.return %b : !llvm.ptr<1>
 }
 



More information about the Mlir-commits mailing list