[flang-commits] [flang] [flang] AArch64 ABI for BIND(C) VALUE parameters (PR #118305)

via flang-commits flang-commits at lists.llvm.org
Thu Dec 5 02:05:23 PST 2024


================
@@ -870,51 +870,143 @@ struct TargetAArch64 : public GenericTarget<TargetAArch64> {
 
   // Determine if the type is a Homogenous Floating-point Aggregate (HFA). An
   // HFA is a record type with up to 4 floating-point members of the same type.
-  static bool isHFA(fir::RecordType ty) {
+  static std::optional<int> usedRegsForHFA(fir::RecordType ty) {
     RecordType::TypeList types = ty.getTypeList();
     if (types.empty() || types.size() > 4)
-      return false;
+      return std::nullopt;
 
     std::optional<std::vector<mlir::Type>> flatTypes = flattenTypeList(types);
     if (!flatTypes || flatTypes->size() > 4) {
-      return false;
+      return std::nullopt;
     }
 
     if (!isa_real(flatTypes->front())) {
-      return false;
+      return std::nullopt;
+    }
+
+    return llvm::all_equal(*flatTypes) ? std::optional<int>{flatTypes->size()}
+                                       : std::nullopt;
+  }
+
+  struct NRegs {
+    int n{0};
+    bool isSimd{false};
+  };
+
+  NRegs usedRegsForRecordType(mlir::Location loc, fir::RecordType type) const {
+    if (std::optional<int> size = usedRegsForHFA(type))
+      return {*size, true};
+
+    auto [size, align] = fir::getTypeSizeAndAlignmentOrCrash(
+        loc, type, getDataLayout(), kindMap);
+
+    if (size <= 16)
+      return {static_cast<int>((size + 7) / 8), false};
+
+    // Pass on the stack, i.e. no registers used
+    return {};
+  }
+
+  NRegs usedRegsForType(mlir::Location loc, mlir::Type type) const {
+    return llvm::TypeSwitch<mlir::Type, NRegs>(type)
+        .Case<mlir::IntegerType>([&](auto intTy) {
+          return intTy.getWidth() == 128 ? NRegs{2, false} : NRegs{1, false};
+        })
+        .Case<mlir::FloatType>([&](auto) { return NRegs{1, true}; })
+        .Case<mlir::ComplexType>([&](auto) { return NRegs{2, true}; })
+        .Case<fir::LogicalType>([&](auto) { return NRegs{1, false}; })
+        .Case<fir::CharacterType>([&](auto) { return NRegs{1, false}; })
+        .Case<fir::SequenceType>([&](auto ty) {
+          NRegs nregs = usedRegsForType(loc, ty.getEleTy());
+          nregs.n *= ty.getShape()[0];
+          return nregs;
+        })
+        .Case<fir::RecordType>(
+            [&](auto ty) { return usedRegsForRecordType(loc, ty); })
+        .Case<fir::VectorType>([&](auto) {
+          TODO(loc, "passing vector argument to C by value is not supported");
+          return NRegs{};
+        });
+  }
+
+  bool hasEnoughRegisters(mlir::Location loc, fir::RecordType type,
+                          const Marshalling &previousArguments) const {
+    int availIntRegisters = 8;
+    int availSIMDRegisters = 8;
+
+    // Check previous arguments to see how many registers are used already
+    for (auto [type, attr] : previousArguments) {
----------------
jeanPerier wrote:

Right, it is a bit dumb, but I do not expect BIND(C) VALUE struct arguments to be common enough for this to matter, so I did not change the logic or the interface much when I added the X86-64 implementation.

The main "issue" is that we do not call the target lowering for every argument, so the target lowering cannot properly track register state across arguments.
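To make the constraint concrete: the target lowering only sees the BIND(C) VALUE struct arguments, so it has to rebuild the register state from the preceding arguments every time. The following is a minimal sketch under simplified assumptions. Each argument is reduced to the diff's `NRegs` shape. The AAPCS64 limit of eight general-purpose (x0-x7) and eight SIMD/FP (v0-v7) argument registers applies. The even-register rounding for 16-byte-aligned arguments is ignored. The diff excerpt stops before the body of `hasEnoughRegisters`, so the standalone `allocate` and `hasEnoughRegisters` functions below are hypothetical and are not flang's implementation.

```cpp
#include <cassert>
#include <vector>

// Register needs of one argument, mirroring the NRegs struct in the diff:
// `n` registers from either the general-purpose or the SIMD/FP class.
struct NRegs {
  int n{0};
  bool isSimd{false};
};

// AAPCS64 passes arguments in x0-x7 (general-purpose) and v0-v7 (SIMD/FP).
constexpr int kArgIntRegisters = 8;
constexpr int kArgSimdRegisters = 8;

// Hypothetical helper: allocates one argument against the remaining
// registers. If it does not fit, AAPCS64 sets NGRN (or NSRN) to 8, so the
// whole register class is treated as exhausted for later arguments.
void allocate(const NRegs &arg, int &availInt, int &availSimd) {
  assert(arg.n >= 0 && "register counts are non-negative");
  int &avail = arg.isSimd ? availSimd : availInt;
  avail = arg.n <= avail ? avail - arg.n : 0;
}

// Stateless check: replay every earlier argument to rebuild the register
// state, then test whether `candidate` still fits in its register class.
bool hasEnoughRegisters(const NRegs &candidate,
                        const std::vector<NRegs> &previousArguments) {
  int availInt = kArgIntRegisters;
  int availSimd = kArgSimdRegisters;
  for (const NRegs &prev : previousArguments)
    allocate(prev, availInt, availSimd);
  return candidate.n <= (candidate.isSimd ? availSimd : availInt);
}
```

For example, after seven `double` arguments only v7 remains, so `hasEnoughRegisters(NRegs{2, true}, previous)` returns false for a two-member HFA, which is then passed on the stack. The replay costs time linear in the number of preceding arguments, and that cost is paid only when a struct argument is classified.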

Adding callbacks for "normal" arguments would arguably cost more overall, because of the virtual dispatch on every callback, than redoing the register computation for the few BIND(C) VALUE struct arguments.
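For contrast, here is a sketch of the stateful, callback-based design that the comment argues against. It uses the same simplified AAPCS64 model of eight general-purpose and eight SIMD/FP argument registers. `ArgumentObserver`, `AArch64RegisterTracker`, and `lastArgumentFits` are illustrative names and are not part of flang. The point is that every argument of every call now goes through a virtual `onArgument` call, even when no BIND(C) VALUE struct is present.

```cpp
#include <cassert>
#include <vector>

struct NRegs {
  int n{0};
  bool isSimd{false};
};

// Hypothetical callback interface: the lowering would be notified of every
// argument, not just BIND(C) VALUE structs.
class ArgumentObserver {
public:
  virtual ~ArgumentObserver() = default;
  virtual void onArgument(const NRegs &arg) = 0;
  virtual bool fitsInRegisters(const NRegs &arg) const = 0;
};

// Stateful AAPCS64 tracker (8 general-purpose, 8 SIMD/FP argument registers).
class AArch64RegisterTracker final : public ArgumentObserver {
public:
  void onArgument(const NRegs &arg) override {
    assert(arg.n >= 0 && "register counts are non-negative");
    int &avail = arg.isSimd ? availSimd : availInt;
    // An argument that does not fit exhausts its register class.
    avail = arg.n <= avail ? avail - arg.n : 0;
  }
  bool fitsInRegisters(const NRegs &arg) const override {
    return arg.n <= (arg.isSimd ? availSimd : availInt);
  }

private:
  int availInt = 8;
  int availSimd = 8;
};

// Feeds a whole argument list through the observer (one virtual call per
// argument), then reports whether `last` would still fit in registers.
bool lastArgumentFits(ArgumentObserver &observer,
                      const std::vector<NRegs> &leading, const NRegs &last) {
  for (const NRegs &arg : leading)
    observer.onArgument(arg);
  return observer.fitsInRegisters(last);
}
```

The tracker answers each query in constant time, but it adds a virtual dispatch to every argument that is lowered. Recomputing the state only for BIND(C) VALUE struct arguments keeps that cost off the common path.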

https://github.com/llvm/llvm-project/pull/118305

