[flang-commits] [flang] [flang] add ABI argument attributes in indirect calls (PR #126896)

via flang-commits flang-commits at lists.llvm.org
Wed Feb 12 07:49:53 PST 2025


https://github.com/jeanPerier updated https://github.com/llvm/llvm-project/pull/126896

>From 2a6e8340bcb110363e4d74fd95e765f082b24c2e Mon Sep 17 00:00:00 2001
From: Jean Perier <jperier at nvidia.com>
Date: Wed, 12 Feb 2025 02:01:39 -0800
Subject: [PATCH 1/2] [flang] add ABI argument attributes in indirect calls

---
 flang/lib/Optimizer/CodeGen/CodeGen.cpp       | 32 +++++++++++++-
 flang/lib/Optimizer/CodeGen/TargetRewrite.cpp | 43 +++++++++++++++----
 flang/test/Fir/convert-to-llvm.fir            | 14 ++++++
 .../Fir/target-rewrite-indirect-calls.fir     | 22 ++++++++++
 4 files changed, 100 insertions(+), 11 deletions(-)
 create mode 100644 flang/test/Fir/target-rewrite-indirect-calls.fir

diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index f938d8d377465..c76b7cde55bdd 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -593,8 +593,36 @@ struct CallOpConversion : public fir::FIROpConversion<fir::CallOp> {
         call, resultTys, adaptor.getOperands(),
         addLLVMOpBundleAttrs(rewriter, attrConvert.getAttrs(),
                              adaptor.getOperands().size()));
-    if (mlir::ArrayAttr argAttrs = call.getArgAttrsAttr())
-      llvmCall.setArgAttrsAttr(argAttrs);
+    if (mlir::ArrayAttr argAttrsArray = call.getArgAttrsAttr()) {
+      // sret and byval type needs to be converted.
+      auto convertTypeAttr = [&](const mlir::NamedAttribute &attr) {
+        return mlir::TypeAttr::get(convertType(
+            llvm::cast<mlir::TypeAttr>(attr.getValue()).getValue()));
+      };
+      llvm::SmallVector<mlir::Attribute> newArgAttrsArray;
+      for (auto argAttrs : argAttrsArray) {
+        llvm::SmallVector<mlir::NamedAttribute> convertedAttrs;
+        for (const mlir::NamedAttribute &attr :
+             llvm::cast<mlir::DictionaryAttr>(argAttrs)) {
+          if (attr.getName().getValue() ==
+              mlir::LLVM::LLVMDialect::getByValAttrName()) {
+            convertedAttrs.push_back(rewriter.getNamedAttr(
+                mlir::LLVM::LLVMDialect::getByValAttrName(),
+                convertTypeAttr(attr)));
+          } else if (attr.getName().getValue() ==
+                     mlir::LLVM::LLVMDialect::getStructRetAttrName()) {
+            convertedAttrs.push_back(rewriter.getNamedAttr(
+                mlir::LLVM::LLVMDialect::getStructRetAttrName(),
+                convertTypeAttr(attr)));
+          } else {
+            convertedAttrs.push_back(attr);
+          }
+        }
+        newArgAttrsArray.emplace_back(
+            mlir::DictionaryAttr::get(rewriter.getContext(), convertedAttrs));
+      }
+      llvmCall.setArgAttrsAttr(rewriter.getArrayAttr(newArgAttrsArray));
+    }
     if (mlir::ArrayAttr resAttrs = call.getResAttrsAttr())
       llvmCall.setResAttrsAttr(resAttrs);
     return mlir::success();
diff --git a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
index c099a08ffd30a..5c9da0321bcc4 100644
--- a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
+++ b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
@@ -534,19 +534,44 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
     } else if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) {
       fir::CallOp newCall;
       if (callOp.getCallee()) {
-        newCall =
-            rewriter->create<A>(loc, *callOp.getCallee(), newResTys, newOpers);
+        newCall = rewriter->create<fir::CallOp>(loc, *callOp.getCallee(),
+                                                newResTys, newOpers);
       } else {
-        // TODO: llvm dialect must be updated to propagate argument on
-        // attributes for indirect calls. See:
-        // https://discourse.llvm.org/t/should-llvm-callop-be-able-to-carry-argument-attributes-for-indirect-calls/75431
-        if (hasByValOrSRetArgs(newInTyAndAttrs))
-          TODO(loc,
-               "passing argument or result on the stack in indirect calls");
         newOpers[0].setType(mlir::FunctionType::get(
             callOp.getContext(),
             mlir::TypeRange{newInTypes}.drop_front(dropFront), newResTys));
-        newCall = rewriter->create<A>(loc, newResTys, newOpers);
+        newCall = rewriter->create<fir::CallOp>(loc, newResTys, newOpers);
+        // Set ABI argument attributes on call operation since they are not
+        // accessible via a FuncOp in indirect calls.
+        if (hasByValOrSRetArgs(newInTyAndAttrs)) {
+          llvm::SmallVector<mlir::Attribute> argAttrsArray;
+          for (const auto &arg :
+               llvm::ArrayRef<fir::CodeGenSpecifics::TypeAndAttr>(
+                   newInTyAndAttrs)
+                   .drop_front(dropFront)) {
+            mlir::NamedAttrList argAttrs;
+            const auto &attr = std::get<fir::CodeGenSpecifics::Attributes>(arg);
+            if (attr.isByVal()) {
+              mlir::Type elemType =
+                  fir::dyn_cast_ptrOrBoxEleTy(std::get<mlir::Type>(arg));
+              argAttrs.set(mlir::LLVM::LLVMDialect::getByValAttrName(),
+                           mlir::TypeAttr::get(elemType));
+            } else if (attr.isSRet()) {
+              mlir::Type elemType =
+                  fir::dyn_cast_ptrOrBoxEleTy(std::get<mlir::Type>(arg));
+              argAttrs.set(mlir::LLVM::LLVMDialect::getStructRetAttrName(),
+                           mlir::TypeAttr::get(elemType));
+              if (auto align = attr.getAlignment()) {
+                argAttrs.set(mlir::LLVM::LLVMDialect::getAlignAttrName(),
+                             rewriter->getIntegerAttr(
+                                 rewriter->getIntegerType(32), align));
+              }
+            }
+            argAttrsArray.emplace_back(
+                argAttrs.getDictionary(rewriter->getContext()));
+          }
+          newCall.setArgAttrsAttr(rewriter->getArrayAttr(argAttrsArray));
+        }
       }
       LLVM_DEBUG(llvm::dbgs() << "replacing call with " << newCall << '\n');
       if (wrap)
diff --git a/flang/test/Fir/convert-to-llvm.fir b/flang/test/Fir/convert-to-llvm.fir
index c11cfd5d5faa1..8727c0ab08e70 100644
--- a/flang/test/Fir/convert-to-llvm.fir
+++ b/flang/test/Fir/convert-to-llvm.fir
@@ -2871,3 +2871,17 @@ func.func @test_call_arg_attrs_indirect(%arg0: i16, %arg1: (i16)-> i16) -> i16 {
   %0 = fir.call %arg1(%arg0) : (i16 {llvm.noundef, llvm.signext}) -> (i16 {llvm.signext})
   return %0 : i16
 }
+
+// CHECK-LABEL: @test_byval
+func.func @test_byval(%arg0: (!fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>, f64) -> (), %arg1: !fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>, %arg2: f64) {
+  // CHECK: llvm.call %{{.*}}(%{{.*}}, %{{.*}}) : !llvm.ptr, (!llvm.ptr {llvm.byval = !llvm.struct<"t", (array<5 x f64>)>}, f64) -> ()
+  fir.call %arg0(%arg1, %arg2) : (!fir.ref<!fir.type<t{a:!fir.array<5xf64>}>> {llvm.byval = !fir.type<t{a:!fir.array<5xf64>}>}, f64) -> ()
+  return
+}
+
+// CHECK-LABEL: @test_sret
+func.func @test_sret(%arg0: (!fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>, f64) -> (), %arg1: !fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>, %arg2: f64) {
+  // CHECK: llvm.call %{{.*}}(%{{.*}}, %{{.*}}) : !llvm.ptr, (!llvm.ptr {llvm.sret = !llvm.struct<"t", (array<5 x f64>)>}, f64) -> ()
+  fir.call %arg0(%arg1, %arg2) : (!fir.ref<!fir.type<t{a:!fir.array<5xf64>}>> {llvm.sret = !fir.type<t{a:!fir.array<5xf64>}>}, f64) -> ()
+  return
+}
diff --git a/flang/test/Fir/target-rewrite-indirect-calls.fir b/flang/test/Fir/target-rewrite-indirect-calls.fir
new file mode 100644
index 0000000000000..dbb3d0823520c
--- /dev/null
+++ b/flang/test/Fir/target-rewrite-indirect-calls.fir
@@ -0,0 +1,22 @@
+// Test that ABI argument attributes (e.g. llvm.byval) are set on indirect calls to BIND(C) procedures.
+// RUN: fir-opt --target-rewrite="target=x86_64-unknown-linux-gnu" %s | FileCheck %s
+
+func.func @test(%arg0: () -> (), %arg1: !fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>, %arg2: f64) {
+  %0 = fir.load %arg1 : !fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>
+  %1 = fir.convert %arg0 : (() -> ()) -> ((!fir.type<t{a:!fir.array<5xf64>}>, f64) -> ())
+  fir.call %1(%0, %arg2) proc_attrs<bind_c> : (!fir.type<t{a:!fir.array<5xf64>}>, f64) -> ()
+  return
+}
+// CHECK-LABEL:   func.func @test(
+// CHECK-SAME:                    %[[VAL_0:.*]]: () -> (),
+// CHECK-SAME:                    %[[VAL_1:.*]]: !fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>,
+// CHECK-SAME:                    %[[VAL_2:.*]]: f64) {
+// CHECK:           %[[VAL_3:.*]] = fir.load %[[VAL_1]] : !fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>
+// CHECK:           %[[VAL_4:.*]] = fir.convert %[[VAL_0]] : (() -> ()) -> ((!fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>, f64) -> ())
+// CHECK:           %[[VAL_5:.*]] = llvm.intr.stacksave : !llvm.ptr
+// CHECK:           %[[VAL_6:.*]] = fir.alloca !fir.type<t{a:!fir.array<5xf64>}>
+// CHECK:           fir.store %[[VAL_3]] to %[[VAL_6]] : !fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>
+// CHECK:           fir.call %[[VAL_4]](%[[VAL_6]], %[[VAL_2]]) : (!fir.ref<!fir.type<t{a:!fir.array<5xf64>}>> {llvm.byval = !fir.type<t{a:!fir.array<5xf64>}>}, f64) -> ()
+// CHECK:           llvm.intr.stackrestore %[[VAL_5]] : !llvm.ptr
+// CHECK:           return
+// CHECK:         }

>From 11d72af6f3a51a50fbf93b9d684656096a5a48b2 Mon Sep 17 00:00:00 2001
From: Jean Perier <jperier at nvidia.com>
Date: Wed, 12 Feb 2025 07:45:40 -0800
Subject: [PATCH 2/2] add integration test

---
 flang/test/Integration/abi-indirect-call.f90 | 15 +++++++++++++++
 1 file changed, 15 insertions(+)
 create mode 100644 flang/test/Integration/abi-indirect-call.f90

diff --git a/flang/test/Integration/abi-indirect-call.f90 b/flang/test/Integration/abi-indirect-call.f90
new file mode 100644
index 0000000000000..54a6adfb2c14a
--- /dev/null
+++ b/flang/test/Integration/abi-indirect-call.f90
@@ -0,0 +1,15 @@
+!REQUIRES: x86-registered-target
+!REQUIRES: flang-supports-f128-math
+!RUN: %flang_fc1 -emit-llvm -triple x86_64-unknown-linux-gnu %s -o - | FileCheck  %s
+
+! Test that the ABI of indirect calls (here, a COMPLEX(16) result returned through an sret argument) is properly implemented in the LLVM IR.
+
+subroutine foo(func_ptr, z)
+  interface
+    complex(16) function func_ptr()
+    end function
+  end interface
+  complex(16) :: z
+  ! CHECK: call void %{{.*}}(ptr sret({ fp128, fp128 }) align 16 %{{.*}})
+  z = func_ptr()
+end subroutine



More information about the flang-commits mailing list