[flang-commits] [flang] 3beec2f - [flang] do not rely on existing fir.convert in TargetRewrite (#157413)
via flang-commits
flang-commits at lists.llvm.org
Mon Sep 8 08:22:29 PDT 2025
Author: jeanPerier
Date: 2025-09-08T17:22:25+02:00
New Revision: 3beec2f6875a9a41c4010db7d3ace5acdad48e5d
URL: https://github.com/llvm/llvm-project/commit/3beec2f6875a9a41c4010db7d3ace5acdad48e5d
DIFF: https://github.com/llvm/llvm-project/commit/3beec2f6875a9a41c4010db7d3ace5acdad48e5d.diff
LOG: [flang] do not rely on existing fir.convert in TargetRewrite (#157413)
TargetRewrite performs a shallow rewrite of function signatures. It only
rewrites function definitions (FuncOp), calls (CallOp), and
AddressOfOp. It does not visit every operation that may have an
operand with a function type.
It therefore needs function signature casts around the operations it is
rewriting.
Currently, these casts were not inserted after AddressOfOp rewrites.
Lowering tends to always insert a function cast to the void type right
after generating an AddressOfOp, so the pass relied on that cast's
operand type being implicitly updated to obtain the required cast. This
is brittle because there is no guarantee that such a convert is present,
and canonicalization and other passes may remove it.
Insert a cast on the result of rewritten operations. If the cast is
redundant, it will be canonicalized away later.
Added:
Modified:
flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
flang/test/Fir/struct-return-x86-64.fir
Removed:
################################################################################
diff --git a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
index fa935542d40f7..ac285b5d403df 100644
--- a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
+++ b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
@@ -1336,7 +1336,15 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
private:
// Replace `op` and remove it.
void replaceOp(mlir::Operation *op, mlir::ValueRange newValues) {
- op->replaceAllUsesWith(newValues);
+ llvm::SmallVector<mlir::Value> casts;
+ for (auto [oldValue, newValue] : llvm::zip(op->getResults(), newValues)) {
+ if (oldValue.getType() == newValue.getType())
+ casts.push_back(newValue);
+ else
+ casts.push_back(fir::ConvertOp::create(*rewriter, op->getLoc(),
+ oldValue.getType(), newValue));
+ }
+ op->replaceAllUsesWith(casts);
op->dropAllReferences();
op->erase();
}
diff --git a/flang/test/Fir/struct-return-x86-64.fir b/flang/test/Fir/struct-return-x86-64.fir
index 5d1e6129d8f69..b45983daa97ba 100644
--- a/flang/test/Fir/struct-return-x86-64.fir
+++ b/flang/test/Fir/struct-return-x86-64.fir
@@ -17,6 +17,10 @@ module attributes {fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.data
%1 = fir.convert %0 : (() -> !fits_in_reg) -> (() -> ())
return %1 : () -> ()
}
+ func.func @test_addr_of_inreg_2() -> (() -> !fits_in_reg) {
+ %0 = fir.address_of(@test_inreg) : () -> !fits_in_reg
+ return %0 : () -> !fits_in_reg
+ }
func.func @test_dispatch_inreg(%arg0: !fir.ref<!fits_in_reg>, %arg1: !fir.class<!fir.type<somet>>) {
%0 = fir.dispatch "bar"(%arg1 : !fir.class<!fir.type<somet>>) (%arg1 : !fir.class<!fir.type<somet>>) -> !fits_in_reg {pass_arg_pos = 0 : i32}
fir.store %0 to %arg0 : !fir.ref<!fits_in_reg>
@@ -62,8 +66,15 @@ module attributes {fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.data
// CHECK-LABEL: func.func @test_addr_of_inreg() -> (() -> ()) {
// CHECK: %[[VAL_0:.*]] = fir.address_of(@test_inreg) : () -> tuple<i64, f32>
-// CHECK: %[[VAL_1:.*]] = fir.convert %[[VAL_0]] : (() -> tuple<i64, f32>) -> (() -> ())
-// CHECK: return %[[VAL_1]] : () -> ()
+// CHECK: %[[VAL_1:.*]] = fir.convert %[[VAL_0]] : (() -> tuple<i64, f32>) -> (() -> !fir.type<t1{i:f32,j:i32,k:f32}>)
+// CHECK: %[[VAL_2:.*]] = fir.convert %[[VAL_1]] : (() -> !fir.type<t1{i:f32,j:i32,k:f32}>) -> (() -> ())
+// CHECK: return %[[VAL_2]] : () -> ()
+// CHECK: }
+
+// CHECK-LABEL: func.func @test_addr_of_inreg_2() -> (() -> !fir.type<t1{i:f32,j:i32,k:f32}>) {
+// CHECK: %[[VAL_0:.*]] = fir.address_of(@test_inreg) : () -> tuple<i64, f32>
+// CHECK: %[[VAL_1:.*]] = fir.convert %[[VAL_0]] : (() -> tuple<i64, f32>) -> (() -> !fir.type<t1{i:f32,j:i32,k:f32}>)
+// CHECK: return %[[VAL_1]] : () -> !fir.type<t1{i:f32,j:i32,k:f32}>
// CHECK: }
// CHECK-LABEL: func.func @test_dispatch_inreg(
@@ -95,8 +106,9 @@ module attributes {fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.data
// CHECK-LABEL: func.func @test_addr_of_sret() -> (() -> ()) {
// CHECK: %[[VAL_0:.*]] = fir.address_of(@test_sret) : (!fir.ref<!fir.type<t2{i:!fir.array<5xf32>}>>) -> ()
-// CHECK: %[[VAL_1:.*]] = fir.convert %[[VAL_0]] : ((!fir.ref<!fir.type<t2{i:!fir.array<5xf32>}>>) -> ()) -> (() -> ())
-// CHECK: return %[[VAL_1]] : () -> ()
+// CHECK: %[[VAL_1:.*]] = fir.convert %[[VAL_0]] : ((!fir.ref<!fir.type<t2{i:!fir.array<5xf32>}>>) -> ()) -> (() -> !fir.type<t2{i:!fir.array<5xf32>}>)
+// CHECK: %[[VAL_2:.*]] = fir.convert %[[VAL_1]] : (() -> !fir.type<t2{i:!fir.array<5xf32>}>) -> (() -> ())
+// CHECK: return %[[VAL_2]] : () -> ()
// CHECK: }
// CHECK-LABEL: func.func @test_dispatch_sret(
More information about the flang-commits
mailing list