[flang-commits] [flang] 5836d91 - [flang] add ABI argument attributes in indirect calls (#126896)

via flang-commits flang-commits at lists.llvm.org
Wed Feb 12 08:31:38 PST 2025


Author: jeanPerier
Date: 2025-02-12T17:31:34+01:00
New Revision: 5836d918450b07886556c519a81776db9ac91eea

URL: https://github.com/llvm/llvm-project/commit/5836d918450b07886556c519a81776db9ac91eea
DIFF: https://github.com/llvm/llvm-project/commit/5836d918450b07886556c519a81776db9ac91eea.diff

LOG: [flang] add ABI argument attributes in indirect calls (#126896)

This is the last piece implementing the TODO for setting sret and byval
attributes on indirect calls.

This includes a fix to the last codegen patch. I thought types in type
attributes were automatically converted by dialect conversion passes,
but that is not the case. The sret and byval types need to be converted
to LLVM types in codegen (the MLIR FuncOp conversion performs a similar
conversion).

Added: 
    flang/test/Fir/target-rewrite-indirect-calls.fir
    flang/test/Integration/abi-indirect-call.f90

Modified: 
    flang/lib/Optimizer/CodeGen/CodeGen.cpp
    flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
    flang/test/Fir/convert-to-llvm.fir

Removed: 
    


################################################################################
diff  --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index f938d8d377465..c76b7cde55bdd 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -593,8 +593,36 @@ struct CallOpConversion : public fir::FIROpConversion<fir::CallOp> {
         call, resultTys, adaptor.getOperands(),
         addLLVMOpBundleAttrs(rewriter, attrConvert.getAttrs(),
                              adaptor.getOperands().size()));
-    if (mlir::ArrayAttr argAttrs = call.getArgAttrsAttr())
-      llvmCall.setArgAttrsAttr(argAttrs);
+    if (mlir::ArrayAttr argAttrsArray = call.getArgAttrsAttr()) {
+      // sret and byval type needs to be converted.
+      auto convertTypeAttr = [&](const mlir::NamedAttribute &attr) {
+        return mlir::TypeAttr::get(convertType(
+            llvm::cast<mlir::TypeAttr>(attr.getValue()).getValue()));
+      };
+      llvm::SmallVector<mlir::Attribute> newArgAttrsArray;
+      for (auto argAttrs : argAttrsArray) {
+        llvm::SmallVector<mlir::NamedAttribute> convertedAttrs;
+        for (const mlir::NamedAttribute &attr :
+             llvm::cast<mlir::DictionaryAttr>(argAttrs)) {
+          if (attr.getName().getValue() ==
+              mlir::LLVM::LLVMDialect::getByValAttrName()) {
+            convertedAttrs.push_back(rewriter.getNamedAttr(
+                mlir::LLVM::LLVMDialect::getByValAttrName(),
+                convertTypeAttr(attr)));
+          } else if (attr.getName().getValue() ==
+                     mlir::LLVM::LLVMDialect::getStructRetAttrName()) {
+            convertedAttrs.push_back(rewriter.getNamedAttr(
+                mlir::LLVM::LLVMDialect::getStructRetAttrName(),
+                convertTypeAttr(attr)));
+          } else {
+            convertedAttrs.push_back(attr);
+          }
+        }
+        newArgAttrsArray.emplace_back(
+            mlir::DictionaryAttr::get(rewriter.getContext(), convertedAttrs));
+      }
+      llvmCall.setArgAttrsAttr(rewriter.getArrayAttr(newArgAttrsArray));
+    }
     if (mlir::ArrayAttr resAttrs = call.getResAttrsAttr())
       llvmCall.setResAttrsAttr(resAttrs);
     return mlir::success();

diff  --git a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
index c099a08ffd30a..5c9da0321bcc4 100644
--- a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
+++ b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
@@ -534,19 +534,44 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
     } else if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) {
       fir::CallOp newCall;
       if (callOp.getCallee()) {
-        newCall =
-            rewriter->create<A>(loc, *callOp.getCallee(), newResTys, newOpers);
+        newCall = rewriter->create<fir::CallOp>(loc, *callOp.getCallee(),
+                                                newResTys, newOpers);
       } else {
-        // TODO: llvm dialect must be updated to propagate argument on
-        // attributes for indirect calls. See:
-        // https://discourse.llvm.org/t/should-llvm-callop-be-able-to-carry-argument-attributes-for-indirect-calls/75431
-        if (hasByValOrSRetArgs(newInTyAndAttrs))
-          TODO(loc,
-               "passing argument or result on the stack in indirect calls");
         newOpers[0].setType(mlir::FunctionType::get(
             callOp.getContext(),
             mlir::TypeRange{newInTypes}.drop_front(dropFront), newResTys));
-        newCall = rewriter->create<A>(loc, newResTys, newOpers);
+        newCall = rewriter->create<fir::CallOp>(loc, newResTys, newOpers);
+        // Set ABI argument attributes on call operation since they are not
+        // accessible via a FuncOp in indirect calls.
+        if (hasByValOrSRetArgs(newInTyAndAttrs)) {
+          llvm::SmallVector<mlir::Attribute> argAttrsArray;
+          for (const auto &arg :
+               llvm::ArrayRef<fir::CodeGenSpecifics::TypeAndAttr>(
+                   newInTyAndAttrs)
+                   .drop_front(dropFront)) {
+            mlir::NamedAttrList argAttrs;
+            const auto &attr = std::get<fir::CodeGenSpecifics::Attributes>(arg);
+            if (attr.isByVal()) {
+              mlir::Type elemType =
+                  fir::dyn_cast_ptrOrBoxEleTy(std::get<mlir::Type>(arg));
+              argAttrs.set(mlir::LLVM::LLVMDialect::getByValAttrName(),
+                           mlir::TypeAttr::get(elemType));
+            } else if (attr.isSRet()) {
+              mlir::Type elemType =
+                  fir::dyn_cast_ptrOrBoxEleTy(std::get<mlir::Type>(arg));
+              argAttrs.set(mlir::LLVM::LLVMDialect::getStructRetAttrName(),
+                           mlir::TypeAttr::get(elemType));
+              if (auto align = attr.getAlignment()) {
+                argAttrs.set(mlir::LLVM::LLVMDialect::getAlignAttrName(),
+                             rewriter->getIntegerAttr(
+                                 rewriter->getIntegerType(32), align));
+              }
+            }
+            argAttrsArray.emplace_back(
+                argAttrs.getDictionary(rewriter->getContext()));
+          }
+          newCall.setArgAttrsAttr(rewriter->getArrayAttr(argAttrsArray));
+        }
       }
       LLVM_DEBUG(llvm::dbgs() << "replacing call with " << newCall << '\n');
       if (wrap)

diff  --git a/flang/test/Fir/convert-to-llvm.fir b/flang/test/Fir/convert-to-llvm.fir
index c11cfd5d5faa1..8727c0ab08e70 100644
--- a/flang/test/Fir/convert-to-llvm.fir
+++ b/flang/test/Fir/convert-to-llvm.fir
@@ -2871,3 +2871,17 @@ func.func @test_call_arg_attrs_indirect(%arg0: i16, %arg1: (i16)-> i16) -> i16 {
   %0 = fir.call %arg1(%arg0) : (i16 {llvm.noundef, llvm.signext}) -> (i16 {llvm.signext})
   return %0 : i16
 }
+
+// CHECK-LABEL: @test_byval
+func.func @test_byval(%arg0: (!fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>, f64) -> (), %arg1: !fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>, %arg2: f64) {
+  // CHECK: llvm.call %{{.*}}(%{{.*}}, %{{.*}}) : !llvm.ptr, (!llvm.ptr {llvm.byval = !llvm.struct<"t", (array<5 x f64>)>}, f64) -> ()
+  fir.call %arg0(%arg1, %arg2) : (!fir.ref<!fir.type<t{a:!fir.array<5xf64>}>> {llvm.byval = !fir.type<t{a:!fir.array<5xf64>}>}, f64) -> ()
+  return
+}
+
+// CHECK-LABEL: @test_sret
+func.func @test_sret(%arg0: (!fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>, f64) -> (), %arg1: !fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>, %arg2: f64) {
+  // CHECK: llvm.call %{{.*}}(%{{.*}}, %{{.*}}) : !llvm.ptr, (!llvm.ptr {llvm.sret = !llvm.struct<"t", (array<5 x f64>)>}, f64) -> ()
+  fir.call %arg0(%arg1, %arg2) : (!fir.ref<!fir.type<t{a:!fir.array<5xf64>}>> {llvm.sret = !fir.type<t{a:!fir.array<5xf64>}>}, f64) -> ()
+  return
+}

diff  --git a/flang/test/Fir/target-rewrite-indirect-calls.fir b/flang/test/Fir/target-rewrite-indirect-calls.fir
new file mode 100644
index 0000000000000..dbb3d0823520c
--- /dev/null
+++ b/flang/test/Fir/target-rewrite-indirect-calls.fir
@@ -0,0 +1,22 @@
+// Test that target-rewrite sets ABI argument attributes (llvm.byval here) on indirect calls to BIND(C) functions.
+// RUN: fir-opt --target-rewrite="target=x86_64-unknown-linux-gnu" %s | FileCheck %s
+
+func.func @test(%arg0: () -> (), %arg1: !fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>, %arg2: f64) {
+  %0 = fir.load %arg1 : !fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>
+  %1 = fir.convert %arg0 : (() -> ()) -> ((!fir.type<t{a:!fir.array<5xf64>}>, f64) -> ())
+  fir.call %1(%0, %arg2) proc_attrs<bind_c> : (!fir.type<t{a:!fir.array<5xf64>}>, f64) -> ()
+  return
+}
+// CHECK-LABEL:   func.func @test(
+// CHECK-SAME:                    %[[VAL_0:.*]]: () -> (),
+// CHECK-SAME:                    %[[VAL_1:.*]]: !fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>,
+// CHECK-SAME:                    %[[VAL_2:.*]]: f64) {
+// CHECK:           %[[VAL_3:.*]] = fir.load %[[VAL_1]] : !fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>
+// CHECK:           %[[VAL_4:.*]] = fir.convert %[[VAL_0]] : (() -> ()) -> ((!fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>, f64) -> ())
+// CHECK:           %[[VAL_5:.*]] = llvm.intr.stacksave : !llvm.ptr
+// CHECK:           %[[VAL_6:.*]] = fir.alloca !fir.type<t{a:!fir.array<5xf64>}>
+// CHECK:           fir.store %[[VAL_3]] to %[[VAL_6]] : !fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>
+// CHECK:           fir.call %[[VAL_4]](%[[VAL_6]], %[[VAL_2]]) : (!fir.ref<!fir.type<t{a:!fir.array<5xf64>}>> {llvm.byval = !fir.type<t{a:!fir.array<5xf64>}>}, f64) -> ()
+// CHECK:           llvm.intr.stackrestore %[[VAL_5]] : !llvm.ptr
+// CHECK:           return
+// CHECK:         }

diff  --git a/flang/test/Integration/abi-indirect-call.f90 b/flang/test/Integration/abi-indirect-call.f90
new file mode 100644
index 0000000000000..54a6adfb2c14a
--- /dev/null
+++ b/flang/test/Integration/abi-indirect-call.f90
@@ -0,0 +1,15 @@
+!REQUIRES: x86-registered-target
+!REQUIRES: flang-supports-f128-math
+!RUN: %flang_fc1 -emit-llvm -triple x86_64-unknown-linux-gnu %s -o - | FileCheck  %s
+
+! Test that indirect calls get the right ABI in LLVM IR (here: sret result for complex(16)).
+
+subroutine foo(func_ptr, z)
+  interface
+    complex(16) function func_ptr()
+    end function
+  end interface
+  complex(16) :: z
+  ! CHECK: call void %{{.*}}(ptr sret({ fp128, fp128 }) align 16 %{{.*}})
+  z = func_ptr()
+end subroutine


        


More information about the flang-commits mailing list