[flang-commits] [flang] [flang] AArch64 ABI for BIND(C) VALUE parameters (PR #118305)

Tom Eccles via flang-commits flang-commits at lists.llvm.org
Wed Dec 4 10:09:01 PST 2024


================
@@ -870,51 +870,143 @@ struct TargetAArch64 : public GenericTarget<TargetAArch64> {
 
   // Determine if the type is a Homogeneous Floating-point Aggregate (HFA). An
   // HFA is a record type with up to 4 floating-point members of the same type.
-  static bool isHFA(fir::RecordType ty) {
+  static std::optional<int> usedRegsForHFA(fir::RecordType ty) {
     RecordType::TypeList types = ty.getTypeList();
     if (types.empty() || types.size() > 4)
-      return false;
+      return std::nullopt;
 
     std::optional<std::vector<mlir::Type>> flatTypes = flattenTypeList(types);
     if (!flatTypes || flatTypes->size() > 4) {
-      return false;
+      return std::nullopt;
     }
 
     if (!isa_real(flatTypes->front())) {
-      return false;
+      return std::nullopt;
+    }
+
+    return llvm::all_equal(*flatTypes) ? std::optional<int>{flatTypes->size()}
+                                       : std::nullopt;
+  }
+
+  struct NRegs {
+    int n{0};
+    bool isSimd{false};
+  };
+
+  NRegs usedRegsForRecordType(mlir::Location loc, fir::RecordType type) const {
+    if (std::optional<int> size = usedRegsForHFA(type))
+      return {*size, true};
+
+    auto [size, align] = fir::getTypeSizeAndAlignmentOrCrash(
+        loc, type, getDataLayout(), kindMap);
+
+    if (size <= 16)
+      return {static_cast<int>((size + 7) / 8), false};
+
+    // Pass on the stack, i.e. no registers used
+    return {};
+  }
+
+  NRegs usedRegsForType(mlir::Location loc, mlir::Type type) const {
+    return llvm::TypeSwitch<mlir::Type, NRegs>(type)
+        .Case<mlir::IntegerType>([&](auto intTy) {
+          return intTy.getWidth() == 128 ? NRegs{2, false} : NRegs{1, false};
+        })
+        .Case<mlir::FloatType>([&](auto) { return NRegs{1, true}; })
+        .Case<mlir::ComplexType>([&](auto) { return NRegs{2, true}; })
+        .Case<fir::LogicalType>([&](auto) { return NRegs{1, false}; })
+        .Case<fir::CharacterType>([&](auto) { return NRegs{1, false}; })
+        .Case<fir::SequenceType>([&](auto ty) {
+          NRegs nregs = usedRegsForType(loc, ty.getEleTy());
+          nregs.n *= ty.getShape()[0];
+          return nregs;
+        })
+        .Case<fir::RecordType>(
+            [&](auto ty) { return usedRegsForRecordType(loc, ty); })
+        .Case<fir::VectorType>([&](auto) {
+          TODO(loc, "passing vector argument to C by value is not supported");
+          return NRegs{};
+        });
+  }
+
+  bool hasEnoughRegisters(mlir::Location loc, fir::RecordType type,
+                          const Marshalling &previousArguments) const {
+    int availIntRegisters = 8;
+    int availSIMDRegisters = 8;
----------------
tblah wrote:

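It would be good to cite a standard reference here (for example, the parameter-passing section of the AArch64 Procedure Call Standard) so that it is clear where the two register counts of 8 come from.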

https://github.com/llvm/llvm-project/pull/118305


More information about the flang-commits mailing list