[Mlir-commits] [mlir] 7f8557c - [mlir] Always update ExtractValue to use last container in insert chain (#176588)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Jan 19 11:22:54 PST 2026
Author: neildhar
Date: 2026-01-19T20:22:50+01:00
New Revision: 7f8557cb2ef3065394d844cfd2c3592fcbaf4b90
URL: https://github.com/llvm/llvm-project/commit/7f8557cb2ef3065394d844cfd2c3592fcbaf4b90
DIFF: https://github.com/llvm/llvm-project/commit/7f8557cb2ef3065394d844cfd2c3592fcbaf4b90.diff
LOG: [mlir] Always update ExtractValue to use last container in insert chain (#176588)
The current logic only updates the container operand to the last
`InsertValueOp` in a chain if we haven't switched to a nested insert
chain. Instead, keep track of the current container value and extraction
position at all times, and always update the op whenever we have found a
point higher up in the chain to extract from.
This allows us to bypass more insertions (see the updated test) when we
are accessing nested struct members. It also allows us to move the
constant check back to the top, because it can simply be performed on a
subsequent call to fold.
Also added a test for a previously untested case (its behavior is unchanged by this PR).
Added:
Modified:
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
mlir/test/Dialect/LLVMIR/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index e6addfc4359a9..91fbc53c5eb32 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -1928,11 +1928,19 @@ OpFoldResult LLVM::ExtractValueOp::fold(FoldAdaptor adaptor) {
return getResult();
}
- Operation *container = getContainer().getDefiningOp();
- OpFoldResult result = {};
+ Attribute containerAttr;
+ if (matchPattern(getContainer(), m_Constant(&containerAttr))) {
+ for (int64_t pos : getPosition()) {
+ containerAttr = extractElementAt(containerAttr, pos);
+ if (!containerAttr)
+ return nullptr;
+ }
+ return containerAttr;
+ }
+
+ Value container = getContainer();
ArrayRef<int64_t> extractPos = getPosition();
- bool switchedToInsertedValue = false;
- while (auto insertValueOp = dyn_cast_if_present<InsertValueOp>(container)) {
+ while (auto insertValueOp = container.getDefiningOp<InsertValueOp>()) {
ArrayRef<int64_t> insertPos = insertValueOp.getPosition();
auto extractPosSize = extractPos.size();
auto insertPosSize = insertPos.size();
@@ -1955,9 +1963,8 @@ OpFoldResult LLVM::ExtractValueOp::fold(FoldAdaptor adaptor) {
// In the above example, %4 is folded to %arg1.
if (extractPosSize > insertPosSize &&
extractPos.take_front(insertPosSize) == insertPos) {
- container = insertValueOp.getValue().getDefiningOp();
+ container = insertValueOp.getValue();
extractPos = extractPos.drop_front(insertPosSize);
- switchedToInsertedValue = true;
continue;
}
@@ -1975,30 +1982,21 @@ OpFoldResult LLVM::ExtractValueOp::fold(FoldAdaptor adaptor) {
// ```
if (insertPosSize > extractPosSize &&
extractPos == insertPos.take_front(extractPosSize))
- return result;
+ break;
// If neither a prefix, nor the exact position, we can extract out of the
// value being inserted into. Moreover, we can try again if that operand
// is itself an insertvalue expression.
- if (!switchedToInsertedValue) {
- // Do not swap out the container operand if we decided earlier to
- // continue the traversal with the inserted value (Case 2).
- getContainerMutable().assign(insertValueOp.getContainer());
- result = getResult();
- }
- container = insertValueOp.getContainer().getDefiningOp();
+ container = insertValueOp.getContainer();
}
- if (!container)
- return result;
- Attribute containerAttr;
- if (!matchPattern(container, m_Constant(&containerAttr)))
- return nullptr;
- for (int64_t pos : extractPos) {
- containerAttr = extractElementAt(containerAttr, pos);
- if (!containerAttr)
- return nullptr;
- }
- return containerAttr;
+ // We failed to resolve past this container either because it is not an
+ // InsertValueOp, or it is an InsertValueOp that partially overlaps with the
+ // value being extracted. Update to read from this container instead.
+ if (container == getContainer())
+ return {};
+ setPosition(extractPos);
+ getContainerMutable().assign(container);
+ return getResult();
}
LogicalResult ExtractValueOp::verify() {
diff --git a/mlir/test/Dialect/LLVMIR/canonicalize.mlir b/mlir/test/Dialect/LLVMIR/canonicalize.mlir
index 8303afc9eb033..b1c2df87d4867 100644
--- a/mlir/test/Dialect/LLVMIR/canonicalize.mlir
+++ b/mlir/test/Dialect/LLVMIR/canonicalize.mlir
@@ -78,8 +78,7 @@ llvm.func @no_fold_extractvalue(%arr: !llvm.array<4 x f32>) -> f32 {
%f0 = arith.constant 0.0 : f32
%0 = llvm.mlir.undef : !llvm.array<4 x !llvm.array<4 x f32>>
- // CHECK: insertvalue
- // CHECK: insertvalue
+ // CHECK-NOT: insertvalue
// CHECK: extractvalue
%1 = llvm.insertvalue %f0, %0[0, 0] : !llvm.array<4 x !llvm.array<4 x f32>>
%2 = llvm.insertvalue %arr, %1[0] : !llvm.array<4 x !llvm.array<4 x f32>>
@@ -90,6 +89,21 @@ llvm.func @no_fold_extractvalue(%arr: !llvm.array<4 x f32>) -> f32 {
// -----
+// CHECK-LABEL: fold_nested_extractvalue
+// CHECK-SAME: %[[arg1:.*]]: i32, %[[arg2:.*]]: i32)
+// CHECK-NOT: insertvalue
+// CHECK-NOT: extractvalue
+// CHECK: llvm.return %[[arg1]] : i32
+llvm.func @fold_nested_extractvalue(%arg1: i32, %arg2: i32) -> i32 {
+ %0 = llvm.mlir.undef : !llvm.struct<(struct<(i32, i32)>, i32)>
+ %1 = llvm.insertvalue %arg1, %0[0, 0] : !llvm.struct<(struct<(i32, i32)>, i32)>
+ %2 = llvm.insertvalue %arg2, %1[0, 1] : !llvm.struct<(struct<(i32, i32)>, i32)>
+ %3 = llvm.extractvalue %2[0, 0] : !llvm.struct<(struct<(i32, i32)>, i32)>
+ llvm.return %3 : i32
+}
+
+// -----
+
// CHECK-LABEL: fold_unrelated_extractvalue
llvm.func @fold_unrelated_extractvalue(%arr: !llvm.array<4 x f32>) -> f32 {
%f0 = arith.constant 0.0 : f32
@@ -103,10 +117,10 @@ llvm.func @fold_unrelated_extractvalue(%arr: !llvm.array<4 x f32>) -> f32 {
// -----
// CHECK-LABEL: fold_extract_extractvalue
llvm.func @fold_extract_extractvalue(%arr: !llvm.struct<(i64, array<1 x ptr<1>>)>) -> !llvm.ptr<1> {
- // CHECK: llvm.extractvalue %{{.*}}[1, 0]
+ // CHECK: llvm.extractvalue %{{.*}}[1, 0]
// CHECK-NOT: extractvalue
- %a = llvm.extractvalue %arr[1] : !llvm.struct<(i64, array<1 x ptr<1>>)>
- %b = llvm.extractvalue %a[0] : !llvm.array<1 x ptr<1>>
+ %a = llvm.extractvalue %arr[1] : !llvm.struct<(i64, array<1 x ptr<1>>)>
+ %b = llvm.extractvalue %a[0] : !llvm.array<1 x ptr<1>>
llvm.return %b : !llvm.ptr<1>
}
More information about the Mlir-commits
mailing list