[flang-commits] [flang] [flang] AArch64 ABI for BIND(C) VALUE parameters (PR #118305)

via flang-commits flang-commits at lists.llvm.org
Mon Dec 2 06:52:48 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-flang-fir-hlfir

Author: David Truby (DavidTruby)

<details>
<summary>Changes</summary>

This patch adds handling for derived type VALUE parameters in BIND(C)
functions for AArch64.


---
Full diff: https://github.com/llvm/llvm-project/pull/118305.diff


2 Files Affected:

- (modified) flang/lib/Optimizer/CodeGen/Target.cpp (+116-24) 
- (added) flang/test/Fir/struct-passing-aarch64-byval.fir (+73) 


``````````diff
diff --git a/flang/lib/Optimizer/CodeGen/Target.cpp b/flang/lib/Optimizer/CodeGen/Target.cpp
index f7bffbf53c190e..0d865ee09535a3 100644
--- a/flang/lib/Optimizer/CodeGen/Target.cpp
+++ b/flang/lib/Optimizer/CodeGen/Target.cpp
@@ -826,7 +826,7 @@ struct TargetAArch64 : public GenericTarget<TargetAArch64> {
     return marshal;
   }
 
-  // Flatten a RecordType::TypeList containing more record types or array types
+  // Flatten a RecordType::TypeList containing more record types or array types
   static std::optional<std::vector<mlir::Type>>
   flattenTypeList(const RecordType::TypeList &types) {
     std::vector<mlir::Type> flatTypes;
@@ -870,51 +870,143 @@ struct TargetAArch64 : public GenericTarget<TargetAArch64> {
 
   // Determine if the type is a Homogenous Floating-point Aggregate (HFA). An
   // HFA is a record type with up to 4 floating-point members of the same type.
-  static bool isHFA(fir::RecordType ty) {
+  static std::optional<int> usedRegsForHFA(fir::RecordType ty) {
     RecordType::TypeList types = ty.getTypeList();
     if (types.empty() || types.size() > 4)
-      return false;
+      return std::nullopt;
 
     std::optional<std::vector<mlir::Type>> flatTypes = flattenTypeList(types);
     if (!flatTypes || flatTypes->size() > 4) {
-      return false;
+      return std::nullopt;
     }
 
     if (!isa_real(flatTypes->front())) {
-      return false;
+      return std::nullopt;
+    }
+
+    return llvm::all_equal(*flatTypes) ? std::optional<int>{flatTypes->size()}
+                                       : std::nullopt;
+  }
+
+  struct NRegs {
+    int n{0};
+    bool isSimd{false};
+  };
+
+  NRegs usedRegsForRecordType(mlir::Location loc, fir::RecordType type) const {
+    if (std::optional<int> size = usedRegsForHFA(type))
+      return {*size, true};
+
+    auto [size, align] = fir::getTypeSizeAndAlignmentOrCrash(
+        loc, type, getDataLayout(), kindMap);
+
+    if (size <= 16)
+      return {static_cast<int>((size + 7) / 8), false};
+
+    // Pass on the stack, i.e. no registers used
+    return {};
+  }
+
+  NRegs usedRegsForType(mlir::Location loc, mlir::Type type) const {
+    return llvm::TypeSwitch<mlir::Type, NRegs>(type)
+        .Case<mlir::IntegerType>([&](auto intTy) {
+          return intTy.getWidth() == 128 ? NRegs{2, false} : NRegs{1, false};
+        })
+        .Case<mlir::FloatType>([&](auto) { return NRegs{1, true}; })
+        .Case<mlir::ComplexType>([&](auto) { return NRegs{2, true}; })
+        .Case<fir::LogicalType>([&](auto) { return NRegs{1, false}; })
+        .Case<fir::CharacterType>([&](auto) { return NRegs{1, false}; })
+        .Case<fir::SequenceType>([&](auto ty) {
+          NRegs nregs = usedRegsForType(loc, ty.getEleTy());
+          nregs.n *= ty.getShape()[0];
+          return nregs;
+        })
+        .Case<fir::RecordType>(
+            [&](auto ty) { return usedRegsForRecordType(loc, ty); })
+        .Case<fir::VectorType>([&](auto) {
+          TODO(loc, "passing vector argument to C by value is not supported");
+          return NRegs{};
+        });
+  }
+
+  bool hasEnoughRegisters(mlir::Location loc, fir::RecordType type,
+                          const Marshalling &previousArguments) const {
+    int availIntRegisters = 8;
+    int availSIMDRegisters = 8;
+
+    // Check previous arguments to see how many registers are used already
+    for (auto [type, attr] : previousArguments) {
+      if (availIntRegisters <= 0 || availSIMDRegisters <= 0)
+        break;
+
+      if (attr.isByVal())
+        continue; // Previous argument passed on the stack
+
+      NRegs nregs = usedRegsForType(loc, type);
+      if (nregs.isSimd)
+        availSIMDRegisters -= nregs.n;
+      else
+        availIntRegisters -= nregs.n;
     }
 
-    return llvm::all_equal(*flatTypes);
+    NRegs nregs = usedRegsForRecordType(loc, type);
+
+    if (nregs.isSimd)
+      return nregs.n <= availSIMDRegisters;
+
+    return nregs.n <= availIntRegisters;
+  }
+
+  CodeGenSpecifics::Marshalling
+  passOnTheStack(mlir::Location loc, mlir::Type ty, bool isResult) const {
+    CodeGenSpecifics::Marshalling marshal;
+    auto sizeAndAlign =
+        fir::getTypeSizeAndAlignmentOrCrash(loc, ty, getDataLayout(), kindMap);
+    // The stack is always 8-byte aligned
+    unsigned short align =
+        std::max(sizeAndAlign.second, static_cast<unsigned short>(8));
+    marshal.emplace_back(fir::ReferenceType::get(ty),
+                         AT{align, /*byval=*/!isResult, /*sret=*/isResult});
+    return marshal;
   }
 
   // AArch64 procedure call ABI:
   // https://github.com/ARM-software/abi-aa/blob/main/aapcs64/aapcs64.rst#parameter-passing
   CodeGenSpecifics::Marshalling
-  structReturnType(mlir::Location loc, fir::RecordType ty) const override {
+  structType(mlir::Location loc, fir::RecordType type, bool isResult) const {
+    NRegs nregs = usedRegsForRecordType(loc, type);
+
+    // If the type needs no registers, it must be passed on the stack
+    if (nregs.n == 0)
+      return passOnTheStack(loc, type, isResult);
+
     CodeGenSpecifics::Marshalling marshal;
 
-    if (isHFA(ty)) {
-      // Just return the existing record type
-      marshal.emplace_back(ty, AT{});
-      return marshal;
+    mlir::Type pcsType;
+    if (nregs.isSimd) {
+      pcsType = type;
+    } else {
+      pcsType = fir::SequenceType::get(
+          nregs.n, mlir::IntegerType::get(type.getContext(), 64));
     }
 
-    auto [size, align] =
-        fir::getTypeSizeAndAlignmentOrCrash(loc, ty, getDataLayout(), kindMap);
+    marshal.emplace_back(pcsType, AT{});
+    return marshal;
+  }
 
-    // return in registers if size <= 16 bytes
-    if (size <= 16) {
-      std::size_t dwordSize = (size + 7) / 8;
-      auto newTy = fir::SequenceType::get(
-          dwordSize, mlir::IntegerType::get(ty.getContext(), 64));
-      marshal.emplace_back(newTy, AT{});
-      return marshal;
+  CodeGenSpecifics::Marshalling
+  structArgumentType(mlir::Location loc, fir::RecordType ty,
+                     const Marshalling &previousArguments) const override {
+    if (!hasEnoughRegisters(loc, ty, previousArguments)) {
+      return passOnTheStack(loc, ty, /*isResult=*/false);
     }
 
-    unsigned short stackAlign = std::max<unsigned short>(align, 8u);
-    marshal.emplace_back(fir::ReferenceType::get(ty),
-                         AT{stackAlign, false, true});
-    return marshal;
+    return structType(loc, ty, /*isResult=*/false);
+  }
+
+  CodeGenSpecifics::Marshalling
+  structReturnType(mlir::Location loc, fir::RecordType ty) const override {
+    return structType(loc, ty, /*isResult=*/true);
   }
 };
 } // namespace
diff --git a/flang/test/Fir/struct-passing-aarch64-byval.fir b/flang/test/Fir/struct-passing-aarch64-byval.fir
new file mode 100644
index 00000000000000..27143459dde2f2
--- /dev/null
+++ b/flang/test/Fir/struct-passing-aarch64-byval.fir
@@ -0,0 +1,73 @@
+// Test AArch64 ABI rewrite of struct passed by value (BIND(C), VALUE derived types).
+// RUN: fir-opt --target-rewrite="target=aarch64-unknown-linux-gnu" %s | FileCheck %s
+
+// CHECK-LABEL: func.func private @small_i32(!fir.array<2xi64>)
+func.func private @small_i32(!fir.type<small_i32{i:i32,j:i32,k:i32}>)
+// CHECK-LABEL: func.func private @small_i64(!fir.array<2xi64>)
+func.func private @small_i64(!fir.type<small_i64{i:i64,j:i64}>)
+// CHECK-LABEL: func.func private @small_mixed(!fir.array<2xi64>)
+func.func private @small_mixed(!fir.type<small_mixed{i:i64,j:f32,k:i32}>)
+// CHECK-LABEL: func.func private @small_non_hfa(!fir.array<2xi64>)
+func.func private @small_non_hfa(!fir.type<small_non_hfa{i:f64,j:f32,k:f16}>)
+
+// CHECK-LABEL: func.func private @hfa_f16(!fir.type<hfa_f16{i:f16,j:f16}>)
+func.func private @hfa_f16(!fir.type<hfa_f16{i:f16,j:f16}>)
+// CHECK-LABEL: func.func private @hfa_bf16(!fir.type<hfa_bf16{i:bf16,j:bf16,k:bf16,l:bf16}>)
+func.func private @hfa_bf16(!fir.type<hfa_bf16{i:bf16,j:bf16,k:bf16,l:bf16}>)
+// CHECK-LABEL: func.func private @hfa_f32(!fir.type<hfa_f32{i:f32,j:f32}>)
+func.func private @hfa_f32(!fir.type<hfa_f32{i:f32,j:f32}>)
+// CHECK-LABEL: func.func private @hfa_f64(!fir.type<hfa_f64{i:f64,j:f64,k:f64}>)
+func.func private @hfa_f64(!fir.type<hfa_f64{i:f64,j:f64,k:f64}>)
+// CHECK-LABEL: func.func private @hfa_f128(!fir.type<hfa_f128{i:f128,j:f128,k:f128,l:f128}>)
+func.func private @hfa_f128(!fir.type<hfa_f128{i:f128,j:f128,k:f128,l:f128}>)
+
+// CHECK-LABEL: func.func private @multi_small_integer(!fir.array<2xi64>, !fir.array<2xi64>)
+func.func private @multi_small_integer(!fir.type<small_i32{i:i32,j:i32,k:i32}>, !fir.type<small_i64{i:i64,j:i64}>)
+// CHECK-LABEL: func.func private @multi_hfas(!fir.type<hfa_f16{i:f16,j:f16}>, !fir.type<hfa_f128{i:f128,j:f128,k:f128,l:f128}>)
+func.func private @multi_hfas(!fir.type<hfa_f16{i:f16,j:f16}>, !fir.type<hfa_f128{i:f128,j:f128,k:f128,l:f128}>)
+// CHECK-LABEL: func.func private @multi_mixed(!fir.type<hfa_f64{i:f64,j:f64,k:f64}>, !fir.array<2xi64>, !fir.type<hfa_f32{i:f32,j:f32}>, !fir.array<2xi64>)
+func.func private @multi_mixed(!fir.type<hfa_f64{i:f64,j:f64,k:f64}>,!fir.type<small_non_hfa{i:f64,j:f32,k:f16}>,!fir.type<hfa_f32{i:f32,j:f32}>,!fir.type<small_i64{i:i64,j:i64}>)
+
+// CHECK-LABEL: func.func private @int_max(!fir.array<2xi64>,
+// CHECK-SAME: !fir.array<2xi64>,
+// CHECK-SAME: !fir.array<2xi64>,
+// CHECK-SAME: !fir.array<2xi64>)
+func.func private @int_max(!fir.type<int_max{i:i64,j:i64}>,
+                       !fir.type<int_max{i:i64,j:i64}>,
+                       !fir.type<int_max{i:i64,j:i64}>,
+                       !fir.type<int_max{i:i64,j:i64}>)
+// CHECK-LABEL: func.func private @hfa_max(!fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>, !fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>)
+func.func private @hfa_max(!fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>, !fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>)
+// CHECK-LABEL: func.func private @max(!fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>,
+// CHECK-SAME: !fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>,
+// CHECK-SAME: !fir.array<2xi64>,
+// CHECK-SAME: !fir.array<2xi64>,
+// CHECK-SAME: !fir.array<2xi64>,
+// CHECK-SAME: !fir.array<2xi64>)
+func.func private @max(!fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>,
+                       !fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>,
+                       !fir.type<int_max{i:i64,j:i64}>,
+                       !fir.type<int_max{i:i64,j:i64}>,
+                       !fir.type<int_max{i:i64,j:i64}>,
+                       !fir.type<int_max{i:i64,j:i64}>)
+
+
+// CHECK-LABEL: func.func private @too_many_int(!fir.array<2xi64>,
+// CHECK-SAME: !fir.array<2xi64>,
+// CHECK-SAME: !fir.array<2xi64>,
+// CHECK-SAME: !fir.array<2xi64>,
+// CHECK-SAME: !fir.ref<!fir.type<int_max{i:i64,j:i64}>> {{{.*}}, llvm.byval = !fir.type<int_max{i:i64,j:i64}>})
+func.func private @too_many_int(!fir.type<int_max{i:i64,j:i64}>,
+                       !fir.type<int_max{i:i64,j:i64}>,
+                       !fir.type<int_max{i:i64,j:i64}>,
+                       !fir.type<int_max{i:i64,j:i64}>,
+                       !fir.type<int_max{i:i64,j:i64}>)
+// CHECK-LABEL: func.func private @too_many_hfa(!fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>,
+// CHECK-SAME: !fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>,
+// CHECK-SAME: !fir.ref<!fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>> {{{.*}}, llvm.byval = !fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>})
+func.func private @too_many_hfa(!fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>,
+                           !fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>,
+                           !fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>)
+
+// CHECK-LABEL: func.func private @too_big(!fir.ref<!fir.type<too_big{i:!fir.array<5xi32>}>> {{{.*}}, llvm.byval = !fir.type<too_big{i:!fir.array<5xi32>}>})
+func.func private @too_big(!fir.type<too_big{i:!fir.array<5xi32>}>)

``````````

</details>


https://github.com/llvm/llvm-project/pull/118305


More information about the flang-commits mailing list