[flang-commits] [flang] [flang] add ABI argument attributes in indirect calls (PR #126896)
via flang-commits
flang-commits at lists.llvm.org
Wed Feb 12 07:49:53 PST 2025
https://github.com/jeanPerier updated https://github.com/llvm/llvm-project/pull/126896
>From 2a6e8340bcb110363e4d74fd95e765f082b24c2e Mon Sep 17 00:00:00 2001
From: Jean Perier <jperier at nvidia.com>
Date: Wed, 12 Feb 2025 02:01:39 -0800
Subject: [PATCH 1/2] [flang] add ABI argument attributes in indirect calls
---
flang/lib/Optimizer/CodeGen/CodeGen.cpp | 32 +++++++++++++-
flang/lib/Optimizer/CodeGen/TargetRewrite.cpp | 43 +++++++++++++++----
flang/test/Fir/convert-to-llvm.fir | 14 ++++++
.../Fir/target-rewrite-indirect-calls.fir | 22 ++++++++++
4 files changed, 100 insertions(+), 11 deletions(-)
create mode 100644 flang/test/Fir/target-rewrite-indirect-calls.fir
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index f938d8d377465..c76b7cde55bdd 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -593,8 +593,36 @@ struct CallOpConversion : public fir::FIROpConversion<fir::CallOp> {
call, resultTys, adaptor.getOperands(),
addLLVMOpBundleAttrs(rewriter, attrConvert.getAttrs(),
adaptor.getOperands().size()));
- if (mlir::ArrayAttr argAttrs = call.getArgAttrsAttr())
- llvmCall.setArgAttrsAttr(argAttrs);
+ if (mlir::ArrayAttr argAttrsArray = call.getArgAttrsAttr()) {
+ // sret and byval type needs to be converted.
+ auto convertTypeAttr = [&](const mlir::NamedAttribute &attr) {
+ return mlir::TypeAttr::get(convertType(
+ llvm::cast<mlir::TypeAttr>(attr.getValue()).getValue()));
+ };
+ llvm::SmallVector<mlir::Attribute> newArgAttrsArray;
+ for (auto argAttrs : argAttrsArray) {
+ llvm::SmallVector<mlir::NamedAttribute> convertedAttrs;
+ for (const mlir::NamedAttribute &attr :
+ llvm::cast<mlir::DictionaryAttr>(argAttrs)) {
+ if (attr.getName().getValue() ==
+ mlir::LLVM::LLVMDialect::getByValAttrName()) {
+ convertedAttrs.push_back(rewriter.getNamedAttr(
+ mlir::LLVM::LLVMDialect::getByValAttrName(),
+ convertTypeAttr(attr)));
+ } else if (attr.getName().getValue() ==
+ mlir::LLVM::LLVMDialect::getStructRetAttrName()) {
+ convertedAttrs.push_back(rewriter.getNamedAttr(
+ mlir::LLVM::LLVMDialect::getStructRetAttrName(),
+ convertTypeAttr(attr)));
+ } else {
+ convertedAttrs.push_back(attr);
+ }
+ }
+ newArgAttrsArray.emplace_back(
+ mlir::DictionaryAttr::get(rewriter.getContext(), convertedAttrs));
+ }
+ llvmCall.setArgAttrsAttr(rewriter.getArrayAttr(newArgAttrsArray));
+ }
if (mlir::ArrayAttr resAttrs = call.getResAttrsAttr())
llvmCall.setResAttrsAttr(resAttrs);
return mlir::success();
diff --git a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
index c099a08ffd30a..5c9da0321bcc4 100644
--- a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
+++ b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
@@ -534,19 +534,44 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
} else if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) {
fir::CallOp newCall;
if (callOp.getCallee()) {
- newCall =
- rewriter->create<A>(loc, *callOp.getCallee(), newResTys, newOpers);
+ newCall = rewriter->create<fir::CallOp>(loc, *callOp.getCallee(),
+ newResTys, newOpers);
} else {
- // TODO: llvm dialect must be updated to propagate argument on
- // attributes for indirect calls. See:
- // https://discourse.llvm.org/t/should-llvm-callop-be-able-to-carry-argument-attributes-for-indirect-calls/75431
- if (hasByValOrSRetArgs(newInTyAndAttrs))
- TODO(loc,
- "passing argument or result on the stack in indirect calls");
newOpers[0].setType(mlir::FunctionType::get(
callOp.getContext(),
mlir::TypeRange{newInTypes}.drop_front(dropFront), newResTys));
- newCall = rewriter->create<A>(loc, newResTys, newOpers);
+ newCall = rewriter->create<fir::CallOp>(loc, newResTys, newOpers);
+ // Set ABI argument attributes on call operation since they are not
+ // accessible via a FuncOp in indirect calls.
+ if (hasByValOrSRetArgs(newInTyAndAttrs)) {
+ llvm::SmallVector<mlir::Attribute> argAttrsArray;
+ for (const auto &arg :
+ llvm::ArrayRef<fir::CodeGenSpecifics::TypeAndAttr>(
+ newInTyAndAttrs)
+ .drop_front(dropFront)) {
+ mlir::NamedAttrList argAttrs;
+ const auto &attr = std::get<fir::CodeGenSpecifics::Attributes>(arg);
+ if (attr.isByVal()) {
+ mlir::Type elemType =
+ fir::dyn_cast_ptrOrBoxEleTy(std::get<mlir::Type>(arg));
+ argAttrs.set(mlir::LLVM::LLVMDialect::getByValAttrName(),
+ mlir::TypeAttr::get(elemType));
+ } else if (attr.isSRet()) {
+ mlir::Type elemType =
+ fir::dyn_cast_ptrOrBoxEleTy(std::get<mlir::Type>(arg));
+ argAttrs.set(mlir::LLVM::LLVMDialect::getStructRetAttrName(),
+ mlir::TypeAttr::get(elemType));
+ if (auto align = attr.getAlignment()) {
+ argAttrs.set(mlir::LLVM::LLVMDialect::getAlignAttrName(),
+ rewriter->getIntegerAttr(
+ rewriter->getIntegerType(32), align));
+ }
+ }
+ argAttrsArray.emplace_back(
+ argAttrs.getDictionary(rewriter->getContext()));
+ }
+ newCall.setArgAttrsAttr(rewriter->getArrayAttr(argAttrsArray));
+ }
}
LLVM_DEBUG(llvm::dbgs() << "replacing call with " << newCall << '\n');
if (wrap)
diff --git a/flang/test/Fir/convert-to-llvm.fir b/flang/test/Fir/convert-to-llvm.fir
index c11cfd5d5faa1..8727c0ab08e70 100644
--- a/flang/test/Fir/convert-to-llvm.fir
+++ b/flang/test/Fir/convert-to-llvm.fir
@@ -2871,3 +2871,17 @@ func.func @test_call_arg_attrs_indirect(%arg0: i16, %arg1: (i16)-> i16) -> i16 {
%0 = fir.call %arg1(%arg0) : (i16 {llvm.noundef, llvm.signext}) -> (i16 {llvm.signext})
return %0 : i16
}
+
+// CHECK-LABEL: @test_byval
+func.func @test_byval(%arg0: (!fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>, f64) -> (), %arg1: !fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>, %arg2: f64) {
+ // CHECK: llvm.call %{{.*}}(%{{.*}}, %{{.*}}) : !llvm.ptr, (!llvm.ptr {llvm.byval = !llvm.struct<"t", (array<5 x f64>)>}, f64) -> ()
+ fir.call %arg0(%arg1, %arg2) : (!fir.ref<!fir.type<t{a:!fir.array<5xf64>}>> {llvm.byval = !fir.type<t{a:!fir.array<5xf64>}>}, f64) -> ()
+ return
+}
+
+// CHECK-LABEL: @test_sret
+func.func @test_sret(%arg0: (!fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>, f64) -> (), %arg1: !fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>, %arg2: f64) {
+ // CHECK: llvm.call %{{.*}}(%{{.*}}, %{{.*}}) : !llvm.ptr, (!llvm.ptr {llvm.sret = !llvm.struct<"t", (array<5 x f64>)>}, f64) -> ()
+ fir.call %arg0(%arg1, %arg2) : (!fir.ref<!fir.type<t{a:!fir.array<5xf64>}>> {llvm.sret = !fir.type<t{a:!fir.array<5xf64>}>}, f64) -> ()
+ return
+}
diff --git a/flang/test/Fir/target-rewrite-indirect-calls.fir b/flang/test/Fir/target-rewrite-indirect-calls.fir
new file mode 100644
index 0000000000000..dbb3d0823520c
--- /dev/null
+++ b/flang/test/Fir/target-rewrite-indirect-calls.fir
@@ -0,0 +1,22 @@
+// Test that ABI attributes are set in indirect calls to BIND(C) functions.
+// RUN: fir-opt --target-rewrite="target=x86_64-unknown-linux-gnu" %s | FileCheck %s
+
+func.func @test(%callee: () -> (), %ref: !fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>, %x: f64) {
+  %val = fir.load %ref : !fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>
+  %fptr = fir.convert %callee : (() -> ()) -> ((!fir.type<t{a:!fir.array<5xf64>}>, f64) -> ())
+  fir.call %fptr(%val, %x) proc_attrs<bind_c> : (!fir.type<t{a:!fir.array<5xf64>}>, f64) -> ()
+  return
+}
+// CHECK-LABEL: func.func @test(
+// CHECK-SAME: %[[VAL_0:.*]]: () -> (),
+// CHECK-SAME: %[[VAL_1:.*]]: !fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>,
+// CHECK-SAME: %[[VAL_2:.*]]: f64) {
+// CHECK: %[[VAL_3:.*]] = fir.load %[[VAL_1]] : !fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>
+// CHECK: %[[VAL_4:.*]] = fir.convert %[[VAL_0]] : (() -> ()) -> ((!fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>, f64) -> ())
+// CHECK: %[[VAL_5:.*]] = llvm.intr.stacksave : !llvm.ptr
+// CHECK: %[[VAL_6:.*]] = fir.alloca !fir.type<t{a:!fir.array<5xf64>}>
+// CHECK: fir.store %[[VAL_3]] to %[[VAL_6]] : !fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>
+// CHECK: fir.call %[[VAL_4]](%[[VAL_6]], %[[VAL_2]]) : (!fir.ref<!fir.type<t{a:!fir.array<5xf64>}>> {llvm.byval = !fir.type<t{a:!fir.array<5xf64>}>}, f64) -> ()
+// CHECK: llvm.intr.stackrestore %[[VAL_5]] : !llvm.ptr
+// CHECK: return
+// CHECK: }
>From 11d72af6f3a51a50fbf93b9d684656096a5a48b2 Mon Sep 17 00:00:00 2001
From: Jean Perier <jperier at nvidia.com>
Date: Wed, 12 Feb 2025 07:45:40 -0800
Subject: [PATCH 2/2] add integration test
---
flang/test/Integration/abi-indirect-call.f90 | 15 +++++++++++++++
1 file changed, 15 insertions(+)
create mode 100644 flang/test/Integration/abi-indirect-call.f90
diff --git a/flang/test/Integration/abi-indirect-call.f90 b/flang/test/Integration/abi-indirect-call.f90
new file mode 100644
index 0000000000000..54a6adfb2c14a
--- /dev/null
+++ b/flang/test/Integration/abi-indirect-call.f90
@@ -0,0 +1,15 @@
+!REQUIRES: x86-registered-target
+!REQUIRES: flang-supports-f128-math
+!RUN: %flang_fc1 -emit-llvm -triple x86_64-unknown-linux-gnu %s -o - | FileCheck %s
+
+! Test that the ABI of indirect calls is properly implemented in the LLVM IR.
+
+subroutine foo(func_ptr, z)
+  complex(16) :: z
+  interface
+    complex(16) function func_ptr()
+    end function func_ptr
+  end interface
+  ! CHECK: call void %{{.*}}(ptr sret({ fp128, fp128 }) align 16 %{{.*}})
+  z = func_ptr() ! indirect call: the complex(16) result is passed back via the sret argument
+end subroutine foo
More information about the flang-commits
mailing list