[flang-commits] [flang] [flang] do not rely on existing fir.convert in TargetRewrite (PR #157413)

via flang-commits flang-commits at lists.llvm.org
Mon Sep 8 02:57:10 PDT 2025


https://github.com/jeanPerier created https://github.com/llvm/llvm-project/pull/157413

TargetRewrite does a shallow rewrite of function signatures. It only rewrites function definitions (FuncOp), calls (CallOp), and AddressOfOp. It does not try to visit every operation that may have an operand with a function type.
It therefore needs function signature casts around the operations it is rewriting.

Until now, these casts were not inserted after AddressOfOp rewrites because lowering tends to always insert a function cast to the void type after generating an AddressOfOp, so the pass relied on implicitly updating this cast's operand type to get the required cast. This is brittle because there is no guarantee that such a convert is present, and canonicalization or other passes may remove it.

Insert a cast on the results of rewritten operations. If it is redundant, it will be canonicalized away later.

>From e2f4cac1907a3c964ab908292f6e6e1ec4ee1a36 Mon Sep 17 00:00:00 2001
From: Jean Perier <jperier at nvidia.com>
Date: Mon, 8 Sep 2025 02:35:22 -0700
Subject: [PATCH] [flang] do not rely on existing fir.convert in TargetRewrite

---
 flang/lib/Optimizer/CodeGen/TargetRewrite.cpp | 10 +++++++++-
 flang/test/Fir/struct-return-x86-64.fir       | 20 +++++++++++++++----
 2 files changed, 25 insertions(+), 5 deletions(-)

diff --git a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
index fa935542d40f7..ac285b5d403df 100644
--- a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
+++ b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
@@ -1336,7 +1336,15 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
 private:
   // Replace `op` and remove it.
   void replaceOp(mlir::Operation *op, mlir::ValueRange newValues) {
-    op->replaceAllUsesWith(newValues);
+    llvm::SmallVector<mlir::Value> casts;
+    for (auto [oldValue, newValue] : llvm::zip(op->getResults(), newValues)) {
+      if (oldValue.getType() == newValue.getType())
+        casts.push_back(newValue);
+      else
+        casts.push_back(fir::ConvertOp::create(*rewriter, op->getLoc(),
+                                               oldValue.getType(), newValue));
+    }
+    op->replaceAllUsesWith(casts);
     op->dropAllReferences();
     op->erase();
   }
diff --git a/flang/test/Fir/struct-return-x86-64.fir b/flang/test/Fir/struct-return-x86-64.fir
index 5d1e6129d8f69..b45983daa97ba 100644
--- a/flang/test/Fir/struct-return-x86-64.fir
+++ b/flang/test/Fir/struct-return-x86-64.fir
@@ -17,6 +17,10 @@ module attributes {fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.data
     %1 = fir.convert %0 : (() -> !fits_in_reg) -> (() -> ())
     return %1 : () -> ()
   }
+  func.func @test_addr_of_inreg_2() -> (() -> !fits_in_reg) {
+    %0 = fir.address_of(@test_inreg) : () -> !fits_in_reg
+    return %0 : () -> !fits_in_reg
+  }
   func.func @test_dispatch_inreg(%arg0: !fir.ref<!fits_in_reg>, %arg1: !fir.class<!fir.type<somet>>) {
     %0 = fir.dispatch "bar"(%arg1 : !fir.class<!fir.type<somet>>) (%arg1 : !fir.class<!fir.type<somet>>) -> !fits_in_reg {pass_arg_pos = 0 : i32}
     fir.store %0 to %arg0 : !fir.ref<!fits_in_reg>
@@ -62,8 +66,15 @@ module attributes {fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.data
 
 // CHECK-LABEL:   func.func @test_addr_of_inreg() -> (() -> ()) {
 // CHECK:           %[[VAL_0:.*]] = fir.address_of(@test_inreg) : () -> tuple<i64, f32>
-// CHECK:           %[[VAL_1:.*]] = fir.convert %[[VAL_0]] : (() -> tuple<i64, f32>) -> (() -> ())
-// CHECK:           return %[[VAL_1]] : () -> ()
+// CHECK:           %[[VAL_1:.*]] = fir.convert %[[VAL_0]] : (() -> tuple<i64, f32>) -> (() -> !fir.type<t1{i:f32,j:i32,k:f32}>)
+// CHECK:           %[[VAL_2:.*]] = fir.convert %[[VAL_1]] : (() -> !fir.type<t1{i:f32,j:i32,k:f32}>) -> (() -> ())
+// CHECK:           return %[[VAL_2]] : () -> ()
+// CHECK:         }
+
+// CHECK-LABEL:   func.func @test_addr_of_inreg_2() -> (() -> !fir.type<t1{i:f32,j:i32,k:f32}>) {
+// CHECK:           %[[VAL_0:.*]] = fir.address_of(@test_inreg) : () -> tuple<i64, f32>
+// CHECK:           %[[VAL_1:.*]] = fir.convert %[[VAL_0]] : (() -> tuple<i64, f32>) -> (() -> !fir.type<t1{i:f32,j:i32,k:f32}>)
+// CHECK:           return %[[VAL_1]] : () -> !fir.type<t1{i:f32,j:i32,k:f32}>
 // CHECK:         }
 
 // CHECK-LABEL:   func.func @test_dispatch_inreg(
@@ -95,8 +106,9 @@ module attributes {fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.data
 
 // CHECK-LABEL:   func.func @test_addr_of_sret() -> (() -> ()) {
 // CHECK:           %[[VAL_0:.*]] = fir.address_of(@test_sret) : (!fir.ref<!fir.type<t2{i:!fir.array<5xf32>}>>) -> ()
-// CHECK:           %[[VAL_1:.*]] = fir.convert %[[VAL_0]] : ((!fir.ref<!fir.type<t2{i:!fir.array<5xf32>}>>) -> ()) -> (() -> ())
-// CHECK:           return %[[VAL_1]] : () -> ()
+// CHECK:           %[[VAL_1:.*]] = fir.convert %[[VAL_0]] : ((!fir.ref<!fir.type<t2{i:!fir.array<5xf32>}>>) -> ()) -> (() -> !fir.type<t2{i:!fir.array<5xf32>}>)
+// CHECK:           %[[VAL_2:.*]] = fir.convert %[[VAL_1]] : (() -> !fir.type<t2{i:!fir.array<5xf32>}>) -> (() -> ())
+// CHECK:           return %[[VAL_2]] : () -> ()
 // CHECK:         }
 
 // CHECK-LABEL:   func.func @test_dispatch_sret(



More information about the flang-commits mailing list