[Mlir-commits] [mlir] [mlir] Always update ExtractValue to use last container in insert chain (PR #176588)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sat Jan 17 11:21:42 PST 2026
https://github.com/neildhar created https://github.com/llvm/llvm-project/pull/176588
Note: This PR is stacked on top of #176583
The current logic updates the container operand to the last `InsertValueOp` in a chain only if the traversal has not switched to a nested insert chain. Instead, this PR keeps track of the new container value and extraction position at all times, and always updates the op whenever it finds a point higher up in the chain to extract from.
This allows the fold to bypass more insertions when accessing nested struct members (see the updated test). It also allows the constant check to move back to the top of `fold`. A constant container reached by walking the chain is then folded on a subsequent call to `fold`.
This PR also adds a test for a previously untested case (its folding behavior is unchanged by this PR).
>From 4f2e5ccbe5e6fdbce069e6e9302888001d773b9c Mon Sep 17 00:00:00 2001
From: Neil Dhar <neildhar at meta.com>
Date: Sat, 17 Jan 2026 09:46:25 -0800
Subject: [PATCH 1/2] [NFC][mlir] Clarify bail condition in
ExtractValueOp::fold
---
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 10 +++++-----
1 file changed, 5 insertions(+), 5 deletions(-)
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index f9162b35966c1..957f5348c9572 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -1962,11 +1962,10 @@ OpFoldResult LLVM::ExtractValueOp::fold(FoldAdaptor adaptor) {
}
// Case 3: Try to continue the traversal with the container value.
- unsigned min = std::min(extractPosSize, insertPosSize);
- // If one is fully prefix of the other, stop propagating back as it will
- // miss dependencies. For instance, %3 should not fold to %f0 in the
- // following example:
+ // If extract position is a prefix of insert position, stop propagating back
+ // as it will miss dependencies. For instance, %3 should not fold to %f0 in
+ // the following example:
// ```
// %1 = llvm.insertvalue %f0, %0[0, 0] :
// !llvm.array<4 x !llvm.array<4 x f32>>
@@ -1974,7 +1973,8 @@ OpFoldResult LLVM::ExtractValueOp::fold(FoldAdaptor adaptor) {
// !llvm.array<4 x !llvm.array<4 x f32>>
// %3 = llvm.extractvalue %2[0, 0] : !llvm.array<4 x !llvm.array<4 x f32>>
// ```
- if (extractPos.take_front(min) == insertPos.take_front(min))
+ if (insertPosSize > extractPosSize &&
+ extractPos == insertPos.take_front(extractPosSize))
return result;
// If neither a prefix, nor the exact position, we can extract out of the
// value being inserted into. Moreover, we can try again if that operand
>From c8b9b1c6a308abbb10f542cfb7cc08b2b247f52f Mon Sep 17 00:00:00 2001
From: Neil Dhar <neildhar at meta.com>
Date: Sat, 17 Jan 2026 10:48:02 -0800
Subject: [PATCH 2/2] [mlir] Always update ExtractValue to use last container
in insert chain
---
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 48 +++++++++++-----------
mlir/test/Dialect/LLVMIR/canonicalize.mlir | 24 ++++++++---
2 files changed, 42 insertions(+), 30 deletions(-)
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 957f5348c9572..1df0285c91c42 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -1928,11 +1928,20 @@ OpFoldResult LLVM::ExtractValueOp::fold(FoldAdaptor adaptor) {
return getResult();
}
- Operation *container = getContainer().getDefiningOp();
- OpFoldResult result = {};
+ Attribute containerAttr;
+ if (matchPattern(getContainer(), m_Constant(&containerAttr))) {
+ for (int64_t pos : getPosition()) {
+ containerAttr = extractElementAt(containerAttr, pos);
+ if (!containerAttr)
+ return nullptr;
+ }
+ return containerAttr;
+ }
+
+ Value container = getContainer();
ArrayRef<int64_t> extractPos = getPosition();
- bool switchedToInsertedValue = false;
- while (auto insertValueOp = dyn_cast_if_present<InsertValueOp>(container)) {
+ while (auto insertValueOp =
+ dyn_cast_if_present<InsertValueOp>(container.getDefiningOp())) {
ArrayRef<int64_t> insertPos = insertValueOp.getPosition();
auto extractPosSize = extractPos.size();
auto insertPosSize = insertPos.size();
@@ -1955,9 +1964,8 @@ OpFoldResult LLVM::ExtractValueOp::fold(FoldAdaptor adaptor) {
// In the above example, %4 is folded to %arg1.
if (extractPosSize > insertPosSize &&
extractPos.take_front(insertPosSize) == insertPos) {
- container = insertValueOp.getValue().getDefiningOp();
+ container = insertValueOp.getValue();
extractPos = extractPos.drop_front(insertPosSize);
- switchedToInsertedValue = true;
continue;
}
@@ -1975,30 +1983,20 @@ OpFoldResult LLVM::ExtractValueOp::fold(FoldAdaptor adaptor) {
// ```
if (insertPosSize > extractPosSize &&
extractPos == insertPos.take_front(extractPosSize))
- return result;
+ break;
// If neither a prefix, nor the exact position, we can extract out of the
// value being inserted into. Moreover, we can try again if that operand
// is itself an insertvalue expression.
- if (!switchedToInsertedValue) {
- // Do not swap out the container operand if we decided earlier to
- // continue the traversal with the inserted value (Case 2).
- getContainerMutable().assign(insertValueOp.getContainer());
- result = getResult();
- }
- container = insertValueOp.getContainer().getDefiningOp();
+ container = insertValueOp.getContainer();
}
- if (!container)
- return result;
- Attribute containerAttr;
- if (!matchPattern(container, m_Constant(&containerAttr)))
- return nullptr;
- for (int64_t pos : extractPos) {
- containerAttr = extractElementAt(containerAttr, pos);
- if (!containerAttr)
- return nullptr;
- }
- return containerAttr;
+ // If we identified a container higher up in the chain, update the position
+ // and container operands.
+ if (container == getContainer())
+ return {};
+ setPosition(extractPos);
+ getContainerMutable().assign(container);
+ return getResult();
}
LogicalResult ExtractValueOp::verify() {
diff --git a/mlir/test/Dialect/LLVMIR/canonicalize.mlir b/mlir/test/Dialect/LLVMIR/canonicalize.mlir
index 8303afc9eb033..b1c2df87d4867 100644
--- a/mlir/test/Dialect/LLVMIR/canonicalize.mlir
+++ b/mlir/test/Dialect/LLVMIR/canonicalize.mlir
@@ -78,8 +78,7 @@ llvm.func @no_fold_extractvalue(%arr: !llvm.array<4 x f32>) -> f32 {
%f0 = arith.constant 0.0 : f32
%0 = llvm.mlir.undef : !llvm.array<4 x !llvm.array<4 x f32>>
- // CHECK: insertvalue
- // CHECK: insertvalue
+ // CHECK-NOT: insertvalue
// CHECK: extractvalue
%1 = llvm.insertvalue %f0, %0[0, 0] : !llvm.array<4 x !llvm.array<4 x f32>>
%2 = llvm.insertvalue %arr, %1[0] : !llvm.array<4 x !llvm.array<4 x f32>>
@@ -90,6 +89,21 @@ llvm.func @no_fold_extractvalue(%arr: !llvm.array<4 x f32>) -> f32 {
// -----
+// CHECK-LABEL: fold_nested_extractvalue
+// CHECK-SAME: %[[arg1:.*]]: i32, %[[arg2:.*]]: i32)
+// CHECK-NOT: insertvalue
+// CHECK-NOT: extractvalue
+// CHECK: llvm.return %[[arg1]] : i32
+llvm.func @fold_nested_extractvalue(%arg1: i32, %arg2: i32) -> i32 {
+ %0 = llvm.mlir.undef : !llvm.struct<(struct<(i32, i32)>, i32)>
+ %1 = llvm.insertvalue %arg1, %0[0, 0] : !llvm.struct<(struct<(i32, i32)>, i32)>
+ %2 = llvm.insertvalue %arg2, %1[0, 1] : !llvm.struct<(struct<(i32, i32)>, i32)>
+ %3 = llvm.extractvalue %2[0, 0] : !llvm.struct<(struct<(i32, i32)>, i32)>
+ llvm.return %3 : i32
+}
+
+// -----
+
// CHECK-LABEL: fold_unrelated_extractvalue
llvm.func @fold_unrelated_extractvalue(%arr: !llvm.array<4 x f32>) -> f32 {
%f0 = arith.constant 0.0 : f32
@@ -103,10 +117,10 @@ llvm.func @fold_unrelated_extractvalue(%arr: !llvm.array<4 x f32>) -> f32 {
// -----
// CHECK-LABEL: fold_extract_extractvalue
llvm.func @fold_extract_extractvalue(%arr: !llvm.struct<(i64, array<1 x ptr<1>>)>) -> !llvm.ptr<1> {
- // CHECK: llvm.extractvalue %{{.*}}[1, 0]
+ // CHECK: llvm.extractvalue %{{.*}}[1, 0]
// CHECK-NOT: extractvalue
- %a = llvm.extractvalue %arr[1] : !llvm.struct<(i64, array<1 x ptr<1>>)>
- %b = llvm.extractvalue %a[0] : !llvm.array<1 x ptr<1>>
+ %a = llvm.extractvalue %arr[1] : !llvm.struct<(i64, array<1 x ptr<1>>)>
+ %b = llvm.extractvalue %a[0] : !llvm.array<1 x ptr<1>>
llvm.return %b : !llvm.ptr<1>
}
More information about the Mlir-commits
mailing list