[flang-commits] [flang] [flang] do not rely on existing fir.convert in TargetRewrite (PR #157413)
via flang-commits
flang-commits at lists.llvm.org
Mon Sep 8 02:57:40 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-flang-fir-hlfir
Author: None (jeanPerier)
<details>
<summary>Changes</summary>
TargetRewrite performs a shallow rewrite of function signatures. It only rewrites function definitions (FuncOp), calls (CallOp), and AddressOfOp. It does not try to visit every operation that may have an operand with a function type.
It therefore needs function signature casts around the operations it is rewriting.
Until now, these casts were not inserted after AddressOfOp rewrites: lowering tends to always insert a function cast to the void type after generating an AddressOfOp, so the pass relied on implicitly updating the operand type of that cast to obtain the required cast. This is brittle because there is no guarantee that such a convert is present, and canonicalization or other passes may remove it.
Insert a cast on the result of the rewritten operations. If it is redundant, it will be canonicalized away later.
---
Full diff: https://github.com/llvm/llvm-project/pull/157413.diff
2 Files Affected:
- (modified) flang/lib/Optimizer/CodeGen/TargetRewrite.cpp (+9-1)
- (modified) flang/test/Fir/struct-return-x86-64.fir (+16-4)
``````````diff
diff --git a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
index fa935542d40f7..ac285b5d403df 100644
--- a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
+++ b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
@@ -1336,7 +1336,15 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
private:
// Replace `op` and remove it.
void replaceOp(mlir::Operation *op, mlir::ValueRange newValues) {
- op->replaceAllUsesWith(newValues);
+ llvm::SmallVector<mlir::Value> casts;
+ for (auto [oldValue, newValue] : llvm::zip(op->getResults(), newValues)) {
+ if (oldValue.getType() == newValue.getType())
+ casts.push_back(newValue);
+ else
+ casts.push_back(fir::ConvertOp::create(*rewriter, op->getLoc(),
+ oldValue.getType(), newValue));
+ }
+ op->replaceAllUsesWith(casts);
op->dropAllReferences();
op->erase();
}
diff --git a/flang/test/Fir/struct-return-x86-64.fir b/flang/test/Fir/struct-return-x86-64.fir
index 5d1e6129d8f69..b45983daa97ba 100644
--- a/flang/test/Fir/struct-return-x86-64.fir
+++ b/flang/test/Fir/struct-return-x86-64.fir
@@ -17,6 +17,10 @@ module attributes {fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.data
%1 = fir.convert %0 : (() -> !fits_in_reg) -> (() -> ())
return %1 : () -> ()
}
+ func.func @test_addr_of_inreg_2() -> (() -> !fits_in_reg) {
+ %0 = fir.address_of(@test_inreg) : () -> !fits_in_reg
+ return %0 : () -> !fits_in_reg
+ }
func.func @test_dispatch_inreg(%arg0: !fir.ref<!fits_in_reg>, %arg1: !fir.class<!fir.type<somet>>) {
%0 = fir.dispatch "bar"(%arg1 : !fir.class<!fir.type<somet>>) (%arg1 : !fir.class<!fir.type<somet>>) -> !fits_in_reg {pass_arg_pos = 0 : i32}
fir.store %0 to %arg0 : !fir.ref<!fits_in_reg>
@@ -62,8 +66,15 @@ module attributes {fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.data
// CHECK-LABEL: func.func @test_addr_of_inreg() -> (() -> ()) {
// CHECK: %[[VAL_0:.*]] = fir.address_of(@test_inreg) : () -> tuple<i64, f32>
-// CHECK: %[[VAL_1:.*]] = fir.convert %[[VAL_0]] : (() -> tuple<i64, f32>) -> (() -> ())
-// CHECK: return %[[VAL_1]] : () -> ()
+// CHECK: %[[VAL_1:.*]] = fir.convert %[[VAL_0]] : (() -> tuple<i64, f32>) -> (() -> !fir.type<t1{i:f32,j:i32,k:f32}>)
+// CHECK: %[[VAL_2:.*]] = fir.convert %[[VAL_1]] : (() -> !fir.type<t1{i:f32,j:i32,k:f32}>) -> (() -> ())
+// CHECK: return %[[VAL_2]] : () -> ()
+// CHECK: }
+
+// CHECK-LABEL: func.func @test_addr_of_inreg_2() -> (() -> !fir.type<t1{i:f32,j:i32,k:f32}>) {
+// CHECK: %[[VAL_0:.*]] = fir.address_of(@test_inreg) : () -> tuple<i64, f32>
+// CHECK: %[[VAL_1:.*]] = fir.convert %[[VAL_0]] : (() -> tuple<i64, f32>) -> (() -> !fir.type<t1{i:f32,j:i32,k:f32}>)
+// CHECK: return %[[VAL_1]] : () -> !fir.type<t1{i:f32,j:i32,k:f32}>
// CHECK: }
// CHECK-LABEL: func.func @test_dispatch_inreg(
@@ -95,8 +106,9 @@ module attributes {fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.data
// CHECK-LABEL: func.func @test_addr_of_sret() -> (() -> ()) {
// CHECK: %[[VAL_0:.*]] = fir.address_of(@test_sret) : (!fir.ref<!fir.type<t2{i:!fir.array<5xf32>}>>) -> ()
-// CHECK: %[[VAL_1:.*]] = fir.convert %[[VAL_0]] : ((!fir.ref<!fir.type<t2{i:!fir.array<5xf32>}>>) -> ()) -> (() -> ())
-// CHECK: return %[[VAL_1]] : () -> ()
+// CHECK: %[[VAL_1:.*]] = fir.convert %[[VAL_0]] : ((!fir.ref<!fir.type<t2{i:!fir.array<5xf32>}>>) -> ()) -> (() -> !fir.type<t2{i:!fir.array<5xf32>}>)
+// CHECK: %[[VAL_2:.*]] = fir.convert %[[VAL_1]] : (() -> !fir.type<t2{i:!fir.array<5xf32>}>) -> (() -> ())
+// CHECK: return %[[VAL_2]] : () -> ()
// CHECK: }
// CHECK-LABEL: func.func @test_dispatch_sret(
``````````
</details>
https://github.com/llvm/llvm-project/pull/157413
More information about the flang-commits
mailing list