[Mlir-commits] [mlir] [mlir][LLVM] Improve `llvm.extractvalue` folder (PR #136861)
Matthias Springer
llvmlistbot at llvm.org
Wed Apr 23 06:19:11 PDT 2025
https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/136861
When the position of an `llvm.insertvalue` is a prefix of the `llvm.extractvalue` position, continue the traversal along the SSA chain of the inserted value to find additional folding opportunities.
From 1ab843addd1adc567006e86ed38702c38cfe5b7c Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Wed, 23 Apr 2025 14:56:37 +0200
Subject: [PATCH] [mlir][LLVM] Improve `llvm.extractvalue` folder
---
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 47 ++++++++++++++++++----
mlir/test/Dialect/LLVMIR/canonicalize.mlir | 16 ++++++++
2 files changed, 55 insertions(+), 8 deletions(-)
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 0022be84c212e..26c3ef1e8b8bf 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -1885,11 +1885,40 @@ OpFoldResult LLVM::ExtractValueOp::fold(FoldAdaptor adaptor) {
auto insertValueOp = getContainer().getDefiningOp<InsertValueOp>();
OpFoldResult result = {};
+ ArrayRef<int64_t> extractPos = getPosition();
+ bool switchedToInsertedValue = false;
while (insertValueOp) {
- if (getPosition() == insertValueOp.getPosition())
+ ArrayRef<int64_t> insertPos = insertValueOp.getPosition();
+ auto extractPosSize = extractPos.size();
+ auto insertPosSize = insertPos.size();
+
+ // Case 1: Exact match of positions.
+ if (extractPos == insertPos)
return insertValueOp.getValue();
- unsigned min =
- std::min(getPosition().size(), insertValueOp.getPosition().size());
+
+ // Case 2: Insert position is a prefix of extract position. Continue
+ // traversal with the inserted value. Example:
+ // ```
+ // %0 = llvm.insertvalue %arg1, %undef[0] : !llvm.struct<(i32, i32, i32)>
+ // %1 = llvm.insertvalue %arg2, %0[1] : !llvm.struct<(i32, i32, i32)>
+ // %2 = llvm.insertvalue %arg3, %1[2] : !llvm.struct<(i32, i32, i32)>
+ // %3 = llvm.insertvalue %2, %foo[0]
+ // : !llvm.struct<(struct<(i32, i32, i32)>, i64)>
+ // %4 = llvm.extractvalue %3[0, 0]
+ // : !llvm.struct<(struct<(i32, i32, i32)>, i64)>
+ // ```
+ // In the above example, %4 is folded to %arg1.
+ if (extractPosSize > insertPosSize &&
+ extractPos.take_front(insertPosSize) == insertPos) {
+ insertValueOp = insertValueOp.getValue().getDefiningOp<InsertValueOp>();
+ extractPos = extractPos.drop_front(insertPosSize);
+ switchedToInsertedValue = true;
+ continue;
+ }
+
+ // Case 3: Try to continue the traversal with the container value.
+ unsigned min = std::min(extractPosSize, insertPosSize);
+
// If one is fully prefix of the other, stop propagating back as it will
// miss dependencies. For instance, %3 should not fold to %f0 in the
// following example:
@@ -1900,15 +1929,17 @@ OpFoldResult LLVM::ExtractValueOp::fold(FoldAdaptor adaptor) {
// !llvm.array<4 x !llvm.array<4 x f32>>
// %3 = llvm.extractvalue %2[0, 0] : !llvm.array<4 x !llvm.array<4 x f32>>
// ```
- if (getPosition().take_front(min) ==
- insertValueOp.getPosition().take_front(min))
+ if (extractPos.take_front(min) == insertPos.take_front(min))
return result;
-
// If neither a prefix, nor the exact position, we can extract out of the
// value being inserted into. Moreover, we can try again if that operand
// is itself an insertvalue expression.
- getContainerMutable().assign(insertValueOp.getContainer());
- result = getResult();
+ if (!switchedToInsertedValue) {
+ // Do not swap out the container operand if we decided earlier to
+ // continue the traversal with the inserted value (Case 2).
+ getContainerMutable().assign(insertValueOp.getContainer());
+ result = getResult();
+ }
insertValueOp = insertValueOp.getContainer().getDefiningOp<InsertValueOp>();
}
return result;
diff --git a/mlir/test/Dialect/LLVMIR/canonicalize.mlir b/mlir/test/Dialect/LLVMIR/canonicalize.mlir
index a793caca064ec..8accf6e263863 100644
--- a/mlir/test/Dialect/LLVMIR/canonicalize.mlir
+++ b/mlir/test/Dialect/LLVMIR/canonicalize.mlir
@@ -57,6 +57,22 @@ llvm.func @fold_extractvalue() -> i32 {
// -----
+// CHECK-LABEL: fold_extractvalue(
+// CHECK-SAME: %[[arg1:.*]]: i32, %[[arg2:.*]]: i32, %[[arg3:.*]]: i32)
+// CHECK-NEXT: llvm.return %[[arg1]] : i32
+llvm.func @fold_extractvalue(%arg1: i32, %arg2: i32, %arg3: i32) -> i32 {
+ %3 = llvm.mlir.undef : !llvm.struct<(struct<(i32, i32, i32)>, struct<(i32, i32)>)>
+ %5 = llvm.mlir.undef : !llvm.struct<(i32, i32, i32)>
+ %6 = llvm.insertvalue %arg1, %5[0] : !llvm.struct<(i32, i32, i32)>
+ %7 = llvm.insertvalue %arg2, %6[1] : !llvm.struct<(i32, i32, i32)>
+ %8 = llvm.insertvalue %arg3, %7[2] : !llvm.struct<(i32, i32, i32)>
+ %11 = llvm.insertvalue %8, %3[0] : !llvm.struct<(struct<(i32, i32, i32)>, struct<(i32, i32)>)>
+ %13 = llvm.extractvalue %11[0, 0] : !llvm.struct<(struct<(i32, i32, i32)>, struct<(i32, i32)>)>
+ llvm.return %13 : i32
+}
+
+// -----
+
// CHECK-LABEL: no_fold_extractvalue
llvm.func @no_fold_extractvalue(%arr: !llvm.array<4 x f32>) -> f32 {
%f0 = arith.constant 0.0 : f32
More information about the Mlir-commits
mailing list