[Mlir-commits] [mlir] [mlir] Always update ExtractValue to use last container in insert chain (PR #176588)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sun Jan 18 16:33:26 PST 2026
https://github.com/neildhar updated https://github.com/llvm/llvm-project/pull/176588
>From 4f2e5ccbe5e6fdbce069e6e9302888001d773b9c Mon Sep 17 00:00:00 2001
From: Neil Dhar <neildhar at meta.com>
Date: Sat, 17 Jan 2026 09:46:25 -0800
Subject: [PATCH 1/2] [NFC][mlir] Clarify bail condition in
ExtractValueOp::fold
---
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 10 +++++-----
1 file changed, 5 insertions(+), 5 deletions(-)
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index f9162b35966c1..957f5348c9572 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -1962,11 +1962,10 @@ OpFoldResult LLVM::ExtractValueOp::fold(FoldAdaptor adaptor) {
}
// Case 3: Try to continue the traversal with the container value.
- unsigned min = std::min(extractPosSize, insertPosSize);
- // If one is fully prefix of the other, stop propagating back as it will
- // miss dependencies. For instance, %3 should not fold to %f0 in the
- // following example:
+ // Stop propagating back if the extract position is a strict prefix of the
+ // insert position: skipping an overlapping insert misses dependencies. For
+ // instance, %3 should not fold to %f0 in the following example:
// ```
// %1 = llvm.insertvalue %f0, %0[0, 0] :
// !llvm.array<4 x !llvm.array<4 x f32>>
@@ -1974,7 +1973,8 @@ OpFoldResult LLVM::ExtractValueOp::fold(FoldAdaptor adaptor) {
// !llvm.array<4 x !llvm.array<4 x f32>>
// %3 = llvm.extractvalue %2[0, 0] : !llvm.array<4 x !llvm.array<4 x f32>>
// ```
- if (extractPos.take_front(min) == insertPos.take_front(min))
+ if (insertPosSize > extractPosSize &&
+ extractPos == insertPos.take_front(extractPosSize))
return result;
// If neither a prefix, nor the exact position, we can extract out of the
// value being inserted into. Moreover, we can try again if that operand
>From 40a5214f930abd0b21afd0a2236fff3612c20edf Mon Sep 17 00:00:00 2001
From: Neil Dhar <neildhar at meta.com>
Date: Sat, 17 Jan 2026 10:48:02 -0800
Subject: [PATCH 2/2] [mlir] Always update ExtractValue to use last container
in insert chain
---
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 48 +++++++++++-----------
mlir/test/Dialect/LLVMIR/canonicalize.mlir | 24 ++++++++---
2 files changed, 42 insertions(+), 30 deletions(-)
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 957f5348c9572..4bd840b852fb6 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -1928,11 +1928,19 @@ OpFoldResult LLVM::ExtractValueOp::fold(FoldAdaptor adaptor) {
return getResult();
}
- Operation *container = getContainer().getDefiningOp();
- OpFoldResult result = {};
+ Attribute containerAttr;
+ if (matchPattern(getContainer(), m_Constant(&containerAttr))) {
+ for (int64_t pos : getPosition()) {
+ containerAttr = extractElementAt(containerAttr, pos);
+ if (!containerAttr)
+ return nullptr;
+ }
+ return containerAttr;
+ }
+
+ Value container = getContainer();
ArrayRef<int64_t> extractPos = getPosition();
- bool switchedToInsertedValue = false;
- while (auto insertValueOp = dyn_cast_if_present<InsertValueOp>(container)) {
+ while (auto insertValueOp = container.getDefiningOp<InsertValueOp>()) {
ArrayRef<int64_t> insertPos = insertValueOp.getPosition();
auto extractPosSize = extractPos.size();
auto insertPosSize = insertPos.size();
@@ -1955,9 +1963,8 @@ OpFoldResult LLVM::ExtractValueOp::fold(FoldAdaptor adaptor) {
// In the above example, %4 is folded to %arg1.
if (extractPosSize > insertPosSize &&
extractPos.take_front(insertPosSize) == insertPos) {
- container = insertValueOp.getValue().getDefiningOp();
+ container = insertValueOp.getValue();
extractPos = extractPos.drop_front(insertPosSize);
- switchedToInsertedValue = true;
continue;
}
@@ -1975,30 +1982,21 @@ OpFoldResult LLVM::ExtractValueOp::fold(FoldAdaptor adaptor) {
// ```
if (insertPosSize > extractPosSize &&
extractPos == insertPos.take_front(extractPosSize))
- return result;
+ break;
// If neither a prefix, nor the exact position, we can extract out of the
// value being inserted into. Moreover, we can try again if that operand
// is itself an insertvalue expression.
- if (!switchedToInsertedValue) {
- // Do not swap out the container operand if we decided earlier to
- // continue the traversal with the inserted value (Case 2).
- getContainerMutable().assign(insertValueOp.getContainer());
- result = getResult();
- }
- container = insertValueOp.getContainer().getDefiningOp();
+ container = insertValueOp.getContainer();
}
- if (!container)
- return result;
- Attribute containerAttr;
- if (!matchPattern(container, m_Constant(&containerAttr)))
- return nullptr;
- for (int64_t pos : extractPos) {
- containerAttr = extractElementAt(containerAttr, pos);
- if (!containerAttr)
- return nullptr;
- }
- return containerAttr;
+ // The walk stopped at `container`: it is either not an InsertValueOp, or an
+ // InsertValueOp whose inserted value covers only part of the extracted value.
+ // If the walk made progress, extract `extractPos` from `container` directly.
+ if (container == getContainer())
+ return {};
+ setPosition(extractPos);
+ getContainerMutable().assign(container);
+ return getResult();
}
LogicalResult ExtractValueOp::verify() {
diff --git a/mlir/test/Dialect/LLVMIR/canonicalize.mlir b/mlir/test/Dialect/LLVMIR/canonicalize.mlir
index 8303afc9eb033..b1c2df87d4867 100644
--- a/mlir/test/Dialect/LLVMIR/canonicalize.mlir
+++ b/mlir/test/Dialect/LLVMIR/canonicalize.mlir
@@ -78,8 +78,7 @@ llvm.func @no_fold_extractvalue(%arr: !llvm.array<4 x f32>) -> f32 {
%f0 = arith.constant 0.0 : f32
%0 = llvm.mlir.undef : !llvm.array<4 x !llvm.array<4 x f32>>
- // CHECK: insertvalue
- // CHECK: insertvalue
+ // CHECK-NOT: insertvalue
// CHECK: extractvalue
%1 = llvm.insertvalue %f0, %0[0, 0] : !llvm.array<4 x !llvm.array<4 x f32>>
%2 = llvm.insertvalue %arr, %1[0] : !llvm.array<4 x !llvm.array<4 x f32>>
@@ -90,6 +89,21 @@ llvm.func @no_fold_extractvalue(%arr: !llvm.array<4 x f32>) -> f32 {
// -----
+// CHECK-LABEL: fold_nested_extractvalue
+// CHECK-SAME: %[[arg1:.*]]: i32, %[[arg2:.*]]: i32)
+// CHECK-NOT: insertvalue
+// CHECK-NOT: extractvalue
+// CHECK: llvm.return %[[arg1]] : i32
+llvm.func @fold_nested_extractvalue(%arg1: i32, %arg2: i32) -> i32 {
+ %0 = llvm.mlir.undef : !llvm.struct<(struct<(i32, i32)>, i32)>
+ %1 = llvm.insertvalue %arg1, %0[0, 0] : !llvm.struct<(struct<(i32, i32)>, i32)>
+ %2 = llvm.insertvalue %arg2, %1[0, 1] : !llvm.struct<(struct<(i32, i32)>, i32)>
+ %3 = llvm.extractvalue %2[0, 0] : !llvm.struct<(struct<(i32, i32)>, i32)>
+ llvm.return %3 : i32
+}
+
+// -----
+
// CHECK-LABEL: fold_unrelated_extractvalue
llvm.func @fold_unrelated_extractvalue(%arr: !llvm.array<4 x f32>) -> f32 {
%f0 = arith.constant 0.0 : f32
@@ -103,10 +117,10 @@ llvm.func @fold_unrelated_extractvalue(%arr: !llvm.array<4 x f32>) -> f32 {
// -----
// CHECK-LABEL: fold_extract_extractvalue
llvm.func @fold_extract_extractvalue(%arr: !llvm.struct<(i64, array<1 x ptr<1>>)>) -> !llvm.ptr<1> {
- // CHECK: llvm.extractvalue %{{.*}}[1, 0]
+ // CHECK: llvm.extractvalue %{{.*}}[1, 0]
// CHECK-NOT: extractvalue
- %a = llvm.extractvalue %arr[1] : !llvm.struct<(i64, array<1 x ptr<1>>)>
- %b = llvm.extractvalue %a[0] : !llvm.array<1 x ptr<1>>
+ %a = llvm.extractvalue %arr[1] : !llvm.struct<(i64, array<1 x ptr<1>>)>
+ %b = llvm.extractvalue %a[0] : !llvm.array<1 x ptr<1>>
llvm.return %b : !llvm.ptr<1>
}
More information about the Mlir-commits
mailing list