[flang-commits] [flang] [flang] Add struct passing target rewrite hooks and partial X86-64 impl (PR #74829)

via flang-commits flang-commits at lists.llvm.org
Fri Dec 8 03:33:35 PST 2023


llvmbot wrote:



@llvm/pr-subscribers-flang-fir-hlfir

Author: None (jeanPerier)

<details>
<summary>Changes</summary>

In the context of C/Fortran interoperability (BIND(C)), it is possible to give the VALUE attribute to a BIND(C) derived type dummy argument. According to Fortran 2018 18.3.6, paragraph 2, item (4), this implies that the argument must be passed like the equivalent C structure value. How C structure values are passed is ABI-dependent.

LLVM does not implement the C struct passing ABI for LLVM aggregate type arguments. It is up to the front end, as clang does, to split the struct into registers or pass it on the stack (LLVM "byval") as required by the target ABI.
So the logic for C struct passing lives in clang. Using it from flang would require setting up a lot of clang context and bridging the FIR/MLIR representation of function signatures to the clang AST representation (in both directions), which is a non-trivial task.
See https://stackoverflow.com/questions/39438033/passing-structs-by-value-in-llvm-ir/75002581#75002581.

Since BIND(C) structs are more limited than generic C structs (e.g., no bit fields), it is easier to provide a limited implementation for the cases that matter to Fortran.

This patch:
- Updates the generic target rewrite pass to keep track of both the new argument type and its attributes. The motivation is to be able to tell whether a previously marshalled argument is passed in memory (it is a C pointer) or on the stack (it has the LLVM byval attribute).
- Adds an entry point in the target-specific codegen to marshal struct arguments, and uses it in the generic target rewrite pass.
- Implements limited support for the X86-64 case. So far, it can tell whether a struct must be passed in registers or on the stack, and it handles the stack case. The register case is left as a TODO in this patch.
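Why the pass needs the attributes and not just the types can be illustrated with a small standalone model: when checking whether a struct still fits in argument registers, earlier arguments that were turned into `byval` stack copies must not be counted. This is a sketch under stated assumptions, not the flang code: `RegClass`, `MarshalledArg` and this `hasEnoughRegisters` signature are hypothetical, and every previous argument is assumed to be already reduced to at most two eightbyte classes. The limits of 6 integer and 8 SSE argument registers are those of the System V AMD64 ABI.

```cpp
#include <cassert>
#include <initializer_list>
#include <vector>

// Hypothetical subset of the System V eightbyte classes.
enum class RegClass { NoClass, Integer, SSE };

// Hypothetical model of an already marshalled argument: the classes of its
// low and high eightbytes, plus whether the pass marked it "byval".
struct MarshalledArg {
  RegClass lo = RegClass::NoClass;
  RegClass hi = RegClass::NoClass;
  bool byVal = false; // copied to the stack by the caller
};

// Returns true if, after the previous arguments, enough integer and SSE
// argument registers remain for a new argument needing the given counts.
bool hasEnoughRegisters(int neededInt, int neededSSE,
                        const std::vector<MarshalledArg>& previous) {
  int availInt = 6; // rdi, rsi, rdx, rcx, r8, r9
  int availSSE = 8; // xmm0 to xmm7
  for (const MarshalledArg& arg : previous) {
    if (arg.byVal)
      continue; // passed on the stack: consumes no argument register
    for (RegClass c : {arg.lo, arg.hi}) {
      if (c == RegClass::Integer)
        --availInt;
      else if (c == RegClass::SSE)
        --availSSE;
    }
  }
  return availInt >= neededInt && availSSE >= neededSSE;
}
```

With five integer arguments already assigned, one more integer eightbyte still fits, and adding a `byval` argument does not change that; a plain pointer argument in its place would have used the last integer register.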

The X86-64 ABI implemented is the System V ABI for AMD64, version 1.0.
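The register-versus-stack decision rests on the eightbyte classification from section 3.2.3 of that ABI. The sketch below is a heavily simplified, standalone model of it, not the patch's `classify`/`classifyStruct`: it only knows the NO_CLASS, INTEGER, SSE and MEMORY classes, assumes every field is a naturally aligned scalar of at most 8 bytes, and ignores X87, SSEUP and unaligned fields. `Field`, `classifyStruct` and `mergeClass` here are hypothetical names.

```cpp
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

enum class ArgClass { NoClass, Integer, SSE, Memory };

// Hypothetical flattened field: byte offset, size, and whether it is a
// floating-point scalar (SSE) or an integer/pointer (INTEGER).
struct Field {
  std::uint64_t offset;
  std::uint64_t size;
  bool isFloat;
};

// Merge rule for two classes sharing an eightbyte (subset of the ABI rules).
ArgClass mergeClass(ArgClass a, ArgClass b) {
  if (a == b) return a;
  if (a == ArgClass::NoClass) return b;
  if (b == ArgClass::NoClass) return a;
  if (a == ArgClass::Memory || b == ArgClass::Memory) return ArgClass::Memory;
  if (a == ArgClass::Integer || b == ArgClass::Integer) return ArgClass::Integer;
  return ArgClass::SSE;
}

// Returns the classes of the low and high eightbytes of a struct.
std::pair<ArgClass, ArgClass> classifyStruct(const std::vector<Field>& fields,
                                             std::uint64_t structSize) {
  // Aggregates larger than two eightbytes are always passed in memory.
  if (structSize > 16)
    return {ArgClass::Memory, ArgClass::Memory};
  ArgClass lo = ArgClass::NoClass, hi = ArgClass::NoClass;
  for (const Field& f : fields) {
    ArgClass fieldClass = f.isFloat ? ArgClass::SSE : ArgClass::Integer;
    ArgClass& current = f.offset < 8 ? lo : hi; // eightbyte holding the field
    current = mergeClass(current, fieldClass);
  }
  // Post-merge: if one eightbyte is MEMORY, the whole argument is.
  if (lo == ArgClass::Memory || hi == ArgClass::Memory)
    return {ArgClass::Memory, ArgClass::Memory};
  return {lo, hi};
}
```

For example, a BIND(C) type with a `c_int` followed by a `c_double` (16 bytes) classifies as INTEGER/SSE and can go in one general-purpose and one SSE register if enough remain, while a type with three `c_double` components (24 bytes) is MEMORY and is passed byval.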

---

Patch is 77.99 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/74829.diff


13 Files Affected:

- (modified) flang/include/flang/Optimizer/CodeGen/CGPasses.td (+3-2) 
- (modified) flang/include/flang/Optimizer/CodeGen/Target.h (+29-5) 
- (modified) flang/include/flang/Optimizer/CodeGen/TypeConverter.h (+5-1) 
- (modified) flang/include/flang/Optimizer/Dialect/FIRTypes.td (+10) 
- (modified) flang/include/flang/Optimizer/Support/DataLayout.h (+12) 
- (modified) flang/lib/Optimizer/CodeGen/CodeGen.cpp (+13-1) 
- (modified) flang/lib/Optimizer/CodeGen/Target.cpp (+322-14) 
- (modified) flang/lib/Optimizer/CodeGen/TargetRewrite.cpp (+296-163) 
- (modified) flang/lib/Optimizer/CodeGen/TypeConverter.cpp (+3-2) 
- (modified) flang/lib/Optimizer/Dialect/FIRType.cpp (+6) 
- (modified) flang/lib/Optimizer/Support/DataLayout.cpp (+13) 
- (modified) flang/test/Fir/recursive-type-tco.fir (+2-2) 
- (added) flang/test/Fir/struct-passing-x86-64-byval.fir (+103) 


``````````diff
diff --git a/flang/include/flang/Optimizer/CodeGen/CGPasses.td b/flang/include/flang/Optimizer/CodeGen/CGPasses.td
index 0014298a27a22..5e47119582776 100644
--- a/flang/include/flang/Optimizer/CodeGen/CGPasses.td
+++ b/flang/include/flang/Optimizer/CodeGen/CGPasses.td
@@ -23,7 +23,7 @@ def FIRToLLVMLowering : Pass<"fir-to-llvm-ir", "mlir::ModuleOp"> {
     will also convert ops in the standard and FIRCG dialects.
   }];
   let constructor = "::fir::createFIRToLLVMPass()";
-  let dependentDialects = ["mlir::LLVM::LLVMDialect"];
+  let dependentDialects = ["mlir::LLVM::LLVMDialect", "mlir::DLTIDialect"];
   let options = [
     Option<"forcedTargetTriple", "target", "std::string", /*default=*/"",
            "Override module's target triple.">,
@@ -53,7 +53,8 @@ def TargetRewritePass : Pass<"target-rewrite", "mlir::ModuleOp"> {
       representations that may differ based on the target machine.
   }];
   let constructor = "::fir::createFirTargetRewritePass()";
-  let dependentDialects = [ "fir::FIROpsDialect", "mlir::func::FuncDialect" ];
+  let dependentDialects = [ "fir::FIROpsDialect", "mlir::func::FuncDialect",
+                            "mlir::DLTIDialect" ];
   let options = [
     Option<"forcedTargetTriple", "target", "std::string", /*default=*/"",
            "Override module's target triple.">,
diff --git a/flang/include/flang/Optimizer/CodeGen/Target.h b/flang/include/flang/Optimizer/CodeGen/Target.h
index acffe6c1cec9c..c3ef521ced120 100644
--- a/flang/include/flang/Optimizer/CodeGen/Target.h
+++ b/flang/include/flang/Optimizer/CodeGen/Target.h
@@ -13,6 +13,7 @@
 #ifndef FORTRAN_OPTMIZER_CODEGEN_TARGET_H
 #define FORTRAN_OPTMIZER_CODEGEN_TARGET_H
 
+#include "flang/Optimizer/Dialect/FIRType.h"
 #include "flang/Optimizer/Dialect/Support/KindMapping.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "llvm/TargetParser/Triple.h"
@@ -20,6 +21,10 @@
 #include <tuple>
 #include <vector>
 
+namespace mlir {
+class DataLayout;
+}
+
 namespace fir {
 
 namespace details {
@@ -62,14 +67,20 @@ class Attributes {
 class CodeGenSpecifics {
 public:
   using Attributes = details::Attributes;
-  using Marshalling = std::vector<std::tuple<mlir::Type, Attributes>>;
+  using TypeAndAttr = std::tuple<mlir::Type, Attributes>;
+  using Marshalling = std::vector<TypeAndAttr>;
+
+  static std::unique_ptr<CodeGenSpecifics> get(mlir::MLIRContext *ctx,
+                                               llvm::Triple &&trp,
+                                               KindMapping &&kindMap,
+                                               const mlir::DataLayout &dl);
 
-  static std::unique_ptr<CodeGenSpecifics>
-  get(mlir::MLIRContext *ctx, llvm::Triple &&trp, KindMapping &&kindMap);
+  static TypeAndAttr getTypeAndAttr(mlir::Type t) { return TypeAndAttr{t, {}}; }
 
   CodeGenSpecifics(mlir::MLIRContext *ctx, llvm::Triple &&trp,
-                   KindMapping &&kindMap)
-      : context{*ctx}, triple{std::move(trp)}, kindMap{std::move(kindMap)} {}
+                   KindMapping &&kindMap, const mlir::DataLayout &dl)
+      : context{*ctx}, triple{std::move(trp)}, kindMap{std::move(kindMap)},
+        dataLayout{&dl} {}
   CodeGenSpecifics() = delete;
   virtual ~CodeGenSpecifics() {}
 
@@ -90,6 +101,13 @@ class CodeGenSpecifics {
   /// Type presentation of a `boxchar<n>` type value in memory.
   virtual mlir::Type boxcharMemoryType(mlir::Type eleTy) const = 0;
 
+  /// Type representation of a `fir.type<T>` type argument when passed by
+  /// value. It may have to be split into several arguments, or be passed
+  /// as a byval reference argument (on the stack).
+  virtual Marshalling
+  structArgumentType(mlir::Location loc, fir::RecordType recTy,
+                     const Marshalling &previousArguments) const = 0;
+
   /// Type representation of a `boxchar<n>` type argument when passed by value.
   /// An argument value may need to be passed as a (safe) reference argument.
   ///
@@ -143,10 +161,16 @@ class CodeGenSpecifics {
   // Returns width in bits of C/C++ 'int' type size.
   virtual unsigned char getCIntTypeWidth() const = 0;
 
+  const mlir::DataLayout &getDataLayout() const {
+    assert(dataLayout && "dataLayout must be set");
+    return *dataLayout;
+  }
+
 protected:
   mlir::MLIRContext &context;
   llvm::Triple triple;
   KindMapping kindMap;
+  const mlir::DataLayout *dataLayout = nullptr;
 };
 
 } // namespace fir
diff --git a/flang/include/flang/Optimizer/CodeGen/TypeConverter.h b/flang/include/flang/Optimizer/CodeGen/TypeConverter.h
index 9ce756bdfd966..396c136392555 100644
--- a/flang/include/flang/Optimizer/CodeGen/TypeConverter.h
+++ b/flang/include/flang/Optimizer/CodeGen/TypeConverter.h
@@ -39,6 +39,10 @@ static constexpr unsigned kDimLowerBoundPos = 0;
 static constexpr unsigned kDimExtentPos = 1;
 static constexpr unsigned kDimStridePos = 2;
 
+namespace mlir {
+class DataLayout;
+}
+
 namespace fir {
 
 /// FIR type converter
@@ -46,7 +50,7 @@ namespace fir {
 class LLVMTypeConverter : public mlir::LLVMTypeConverter {
 public:
   LLVMTypeConverter(mlir::ModuleOp module, bool applyTBAA,
-                    bool forceUnifiedTBAATree);
+                    bool forceUnifiedTBAATree, const mlir::DataLayout &);
 
   // i32 is used here because LLVM wants i32 constants when indexing into struct
   // types. Indexing into other aggregate types is more flexible.
diff --git a/flang/include/flang/Optimizer/Dialect/FIRTypes.td b/flang/include/flang/Optimizer/Dialect/FIRTypes.td
index 51608e3c1d63e..2a2f50720859e 100644
--- a/flang/include/flang/Optimizer/Dialect/FIRTypes.td
+++ b/flang/include/flang/Optimizer/Dialect/FIRTypes.td
@@ -326,6 +326,8 @@ def fir_RealType : FIR_Type<"Real", "real"> {
 
   let extraClassDeclaration = [{
     using KindTy = unsigned;
+    // Get MLIR float type with same semantics.
+    mlir::Type getFloatType(const fir::KindMapping &kindMap) const;
   }];
 
   let genVerifyDecl = 1;
@@ -495,6 +497,14 @@ def fir_SequenceType : FIR_Type<"Sequence", "array"> {
     static constexpr Extent getUnknownExtent() {
       return mlir::ShapedType::kDynamic;
     }
+
+    std::uint64_t getConstantArraySize() {
+      assert(!hasDynamicExtents() && "array type must have constant shape");
+      std::uint64_t size = 1;
+      for (Extent extent : getShape())
+        size = size * static_cast<std::uint64_t>(extent);
+      return size;
+    }
   }];
 }
 
diff --git a/flang/include/flang/Optimizer/Support/DataLayout.h b/flang/include/flang/Optimizer/Support/DataLayout.h
index 88ff575a8ff08..d21576bb95f79 100644
--- a/flang/include/flang/Optimizer/Support/DataLayout.h
+++ b/flang/include/flang/Optimizer/Support/DataLayout.h
@@ -13,6 +13,9 @@
 #ifndef FORTRAN_OPTIMIZER_SUPPORT_DATALAYOUT_H
 #define FORTRAN_OPTIMIZER_SUPPORT_DATALAYOUT_H
 
+#include "mlir/Interfaces/DataLayoutInterfaces.h"
+#include <optional>
+
 namespace mlir {
 class ModuleOp;
 }
@@ -34,6 +37,15 @@ void setMLIRDataLayout(mlir::ModuleOp mlirModule, const llvm::DataLayout &dl);
 /// nothing.
 void setMLIRDataLayoutFromAttributes(mlir::ModuleOp mlirModule,
                                      bool allowDefaultLayout);
+
+/// Create mlir::DataLayout from the data layout information on the
+/// mlir::Module. Creates the data layout information attributes with
+/// setMLIRDataLayoutFromAttributes if the DLTI attribute is not yet set. If no
+/// information is present at all and \p allowDefaultLayout is false, returns
+/// std::nullopt.
+std::optional<mlir::DataLayout>
+getOrSetDataLayout(mlir::ModuleOp mlirModule, bool allowDefaultLayout = false);
+
 } // namespace fir::support
 
 #endif // FORTRAN_OPTIMIZER_SUPPORT_DATALAYOUT_H
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index bf175c8ebadee..c0f3ea3241a77 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -16,6 +16,7 @@
 #include "flang/Optimizer/Dialect/FIRAttr.h"
 #include "flang/Optimizer/Dialect/FIROps.h"
 #include "flang/Optimizer/Dialect/FIRType.h"
+#include "flang/Optimizer/Support/DataLayout.h"
 #include "flang/Optimizer/Support/InternalNames.h"
 #include "flang/Optimizer/Support/TypeCode.h"
 #include "flang/Optimizer/Support/Utils.h"
@@ -34,6 +35,7 @@
 #include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/DLTI/DLTI.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/Transforms/AddComdats.h"
 #include "mlir/Dialect/OpenACC/OpenACC.h"
@@ -3820,10 +3822,20 @@ class FIRToLLVMLowering
     if (mlir::failed(runPipeline(mathConvertionPM, mod)))
       return signalPassFailure();
 
+    std::optional<mlir::DataLayout> dl =
+        fir::support::getOrSetDataLayout(mod, /*allowDefaultLayout=*/true);
+    if (!dl) {
+      mlir::emitError(mod.getLoc(),
+                      "module operation must carry a data layout attribute "
+                      "to generate llvm IR from FIR");
+      signalPassFailure();
+      return;
+    }
+
     auto *context = getModule().getContext();
     fir::LLVMTypeConverter typeConverter{getModule(),
                                          options.applyTBAA || applyTBAA,
-                                         options.forceUnifiedTBAATree};
+                                         options.forceUnifiedTBAATree, *dl};
     mlir::RewritePatternSet pattern(context);
     pattern.insert<
         AbsentOpConversion, AddcOpConversion, AddrOfOpConversion,
diff --git a/flang/lib/Optimizer/CodeGen/Target.cpp b/flang/lib/Optimizer/CodeGen/Target.cpp
index 112f56e268c3b..ea10486a6b34c 100644
--- a/flang/lib/Optimizer/CodeGen/Target.cpp
+++ b/flang/lib/Optimizer/CodeGen/Target.cpp
@@ -18,6 +18,7 @@
 #include "flang/Optimizer/Support/Utils.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/TypeRange.h"
+#include "llvm/ADT/TypeSwitch.h"
 
 #define DEBUG_TYPE "flang-codegen-target"
 
@@ -58,6 +59,62 @@ static void typeTodo(const llvm::fltSemantics *sem, mlir::Location loc,
   }
 }
 
+/// Return the size and alignment of FIR types.
+/// TODO: consider moving this to a DataLayoutTypeInterface implementation
+/// for FIR types. It should first be ensured that it is OK to open the gate of
+/// target dependent type size inquiries in lowering. It would also not be
+/// straightforward given the need for a kind map that would need to be
+/// converted in terms of mlir::DataLayoutEntryKey.
+static std::pair<std::uint64_t, unsigned short>
+getSizeAndAlignment(mlir::Location loc, mlir::Type ty,
+                    const mlir::DataLayout &dl,
+                    const fir::KindMapping &kindMap) {
+  if (mlir::isa<mlir::IntegerType, mlir::FloatType, mlir::ComplexType>(ty)) {
+    llvm::TypeSize size = dl.getTypeSize(ty);
+    unsigned short alignment = dl.getTypeABIAlignment(ty);
+    return {size, alignment};
+  }
+  if (auto firCmplx = mlir::dyn_cast<fir::ComplexType>(ty)) {
+    auto [floatSize, floatAlign] =
+        getSizeAndAlignment(loc, firCmplx.getEleType(kindMap), dl, kindMap);
+    return {llvm::alignTo(floatSize, floatAlign) + floatSize, floatAlign};
+  }
+  if (auto real = mlir::dyn_cast<fir::RealType>(ty))
+    return getSizeAndAlignment(loc, real.getFloatType(kindMap), dl, kindMap);
+
+  if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(ty)) {
+    auto [eleSize, eleAlign] =
+        getSizeAndAlignment(loc, seqTy.getEleTy(), dl, kindMap);
+
+    std::uint64_t size =
+        llvm::alignTo(eleSize, eleAlign) * seqTy.getConstantArraySize();
+    return {size, eleAlign};
+  }
+  if (auto recTy = mlir::dyn_cast<fir::RecordType>(ty)) {
+    std::uint64_t size = 0;
+    unsigned short align = 8;
+    for (auto component : recTy.getTypeList()) {
+      auto [compSize, compAlign] =
+          getSizeAndAlignment(loc, component.second, dl, kindMap);
+      size =
+          llvm::alignTo(size, compAlign) + llvm::alignTo(compSize, compAlign);
+      align = std::max(align, compAlign);
+    }
+    return {size, align};
+  }
+  if (auto logical = mlir::dyn_cast<fir::LogicalType>(ty)) {
+    mlir::Type intTy = mlir::IntegerType::get(
+        logical.getContext(), kindMap.getLogicalBitsize(logical.getFKind()));
+    return getSizeAndAlignment(loc, intTy, dl, kindMap);
+  }
+  if (auto character = mlir::dyn_cast<fir::CharacterType>(ty)) {
+    mlir::Type intTy = mlir::IntegerType::get(
+        character.getContext(), kindMap.getCharacterBitsize(character.getFKind()));
+    return getSizeAndAlignment(loc, intTy, dl, kindMap);
+  }
+  TODO(loc, "computing size of a component");
+}
+
 namespace {
 template <typename S>
 struct GenericTarget : public CodeGenSpecifics {
@@ -95,6 +152,12 @@ struct GenericTarget : public CodeGenSpecifics {
     return marshal;
   }
 
+  CodeGenSpecifics::Marshalling
+  structArgumentType(mlir::Location loc, fir::RecordType,
+                     const Marshalling &) const override {
+    TODO(loc, "passing VALUE BIND(C) derived type for this target");
+  }
+
   CodeGenSpecifics::Marshalling
   integerArgumentType(mlir::Location loc,
                       mlir::IntegerType argTy) const override {
@@ -318,6 +381,251 @@ struct TargetX86_64 : public GenericTarget<TargetX86_64> {
     }
     return marshal;
   }
+
+  /// X86-64 argument classes from System V ABI version 1.0 section 3.2.3.
+  enum ArgClass {
+    Integer = 0,
+    SSE,
+    SSEUp,
+    X87,
+    X87Up,
+    ComplexX87,
+    NoClass,
+    Memory
+  };
+
+  /// Classify an argument type or a field of an aggregate type argument.
+  /// See System V ABI version 1.0 section 3.2.3.
+  /// The Lo and Hi classes are set to the class of the lower eightbyte
+  /// and upper eightbyte on return.
+  /// If this is called for an aggregate field, the caller is responsible for
+  /// doing the post-merge.
+  void classify(mlir::Location loc, mlir::Type type, std::uint64_t byteOffset,
+                ArgClass &Lo, ArgClass &Hi) const {
+    Hi = Lo = ArgClass::NoClass;
+    ArgClass &current = byteOffset < 8 ? Lo : Hi;
+    // System V AMD64 ABI 3.2.3. version 1.0
+    llvm::TypeSwitch<mlir::Type>(type)
+        .template Case<mlir::IntegerType>([&](mlir::IntegerType intTy) {
+          if (intTy.getWidth() == 128)
+            Hi = Lo = ArgClass::Integer;
+          else
+            current = ArgClass::Integer;
+        })
+        .template Case<mlir::FloatType, fir::RealType>([&](mlir::Type floatTy) {
+          const auto *sem = &floatToSemantics(kindMap, floatTy);
+          if (sem == &llvm::APFloat::x87DoubleExtended()) {
+            Lo = ArgClass::X87;
+            Hi = ArgClass::X87Up;
+          } else if (sem == &llvm::APFloat::IEEEquad()) {
+            Lo = ArgClass::SSE;
+            Hi = ArgClass::SSEUp;
+          } else {
+            current = ArgClass::SSE;
+          }
+        })
+        .template Case<fir::ComplexType>([&](fir::ComplexType cmplx) {
+          const auto *sem = &floatToSemantics(kindMap, cmplx.getElementType());
+          if (sem == &llvm::APFloat::x87DoubleExtended()) {
+            current = ArgClass::ComplexX87;
+          } else {
+            fir::SequenceType::Shape shape{2};
+            classifyArray(loc,
+                          fir::SequenceType::get(shape, cmplx.getElementType()),
+                          byteOffset, Lo, Hi);
+          }
+        })
+        .template Case<fir::LogicalType>([&](fir::LogicalType logical) {
+          if (kindMap.getLogicalBitsize(logical.getFKind()) == 128)
+            Hi = Lo = ArgClass::Integer;
+          else
+            current = ArgClass::Integer;
+        })
+        .template Case<fir::CharacterType>(
+            [&](fir::CharacterType character) { current = ArgClass::Integer; })
+        .template Case<fir::SequenceType>([&](fir::SequenceType seqTy) {
+          // Array component.
+          classifyArray(loc, seqTy, byteOffset, Lo, Hi);
+        })
+        .template Case<fir::RecordType>([&](fir::RecordType recTy) {
+          // Component that is a derived type.
+          classifyStruct(loc, recTy, byteOffset, Lo, Hi);
+        })
+        .template Case<fir::VectorType>([&](fir::VectorType vecTy) {
+          // Previously marshalled SSE eight byte for a previous struct
+          // argument.
+          auto *sem = fir::isa_real(vecTy.getEleTy())
+                          ? &floatToSemantics(kindMap, vecTy.getEleTy())
+                          : nullptr;
+          // Not expecting to hit this TODO in standard code (it would
+          // require some vector type extension).
+          if (!(sem == &llvm::APFloat::IEEEsingle() && vecTy.getLen() <= 2) &&
+              !(sem == &llvm::APFloat::IEEEhalf() && vecTy.getLen() <= 4))
+            TODO(loc, "passing vector argument to C by value");
+          current = SSE;
+        })
+        .Default([&](mlir::Type ty) {
+          if (fir::conformsWithPassByRef(ty))
+            current = ArgClass::Integer; // Pointers.
+          else
+            TODO(loc, "unsupported component type for BIND(C), VALUE derived "
+                      "type argument");
+        });
+  }
+
+  // Classify fields of a derived type starting at \p byteOffset. Returns the
+  // new offset. Post-merge is left to the caller.
+  std::uint64_t classifyStruct(mlir::Location loc, fir::RecordType recTy,
+                               std::uint64_t byteOffset, ArgClass &Lo,
+                               ArgClass &Hi) const {
+    for (auto component : recTy.getTypeList()) {
+      if (byteOffset > 16) {
+        Lo = Hi = ArgClass::Memory;
+        return byteOffset;
+      }
+      mlir::Type compType = component.second;
+      auto [compSize, compAlign] =
+          getSizeAndAlignment(loc, compType, getDataLayout(), kindMap);
+      byteOffset = llvm::alignTo(byteOffset, compAlign);
+      ArgClass LoComp, HiComp;
+      classify(loc, compType, byteOffset, LoComp, HiComp);
+      Lo = mergeClass(Lo, LoComp);
+      Hi = mergeClass(Hi, HiComp);
+      byteOffset = byteOffset + llvm::alignTo(compSize, compAlign);
+      if (Lo == ArgClass::Memory || Hi == ArgClass::Memory)
+        return byteOffset;
+    }
+    return byteOffset;
+  }
+
+  // Classify elements of a constant size array type starting at
+  // \p byteOffset. Post-merge is left to the caller.
+  void classifyArray(mlir::Location loc, fir::SequenceType seqTy,
+                     std::uint64_t byteOffset, ArgClass &Lo,
+                     ArgClass &Hi) const {
+    mlir::Type eleTy = seqTy.getEleTy();
+    const std::uint64_t arraySize = seqTy.getConstantArraySize();
+    auto [eleSize, eleAlign] =
+        getSizeAndAlignment(loc, eleTy, getDataLayout(), kindMap);
+    std::uint64_t eleStorageSize = llvm::alignTo(eleSize, eleAlign);
+    for (std::uint64_t i = 0; i < arraySize; ++i) {
+      byteOffset = llvm::alignTo(byteOffset, eleAlign);
+      if (byteOffset > 16) {
+        Lo = Hi = ArgClass::Memory;
+        return;
+      }
+      ArgClass LoComp, HiComp;
+      classify(loc, eleTy, byteOffset, LoComp, HiComp);
+      Lo = mergeClass(Lo, LoComp);
+      Hi = mergeClass(Hi, HiComp);
+      byteOffset = byteOffset + eleStorageSize;
+      if (Lo == ArgClass::Memory || Hi == ArgClass::Memory)
+        return;
+    }
+  }
+
+  // Goes through the previously marshalled arguments and counts the
+  // register occupancy to check if there are enough registers left.
+  bool hasEnoughRegisters(mlir::Location loc, int neededIntRegisters,
+                          int neededSSERegisters,
+                          const Marshalling &previousArguments) const {
+    int availIntRegisters = 6;
+    int availSSERegisters = 8;
+    for (auto typeAndAttr : previousArguments) {
+      const auto &attr = std::get<Attributes>(typeAndAttr);
+      if (attr.isByVal() || attr.isSRet())
+        continue; // Previous argument passed on the stack.
+      ArgClass Lo, Hi;
+      Lo = Hi = ArgClass::NoClass;
+      classify(loc, std::get<mlir::Type>(typeAndAttr), 0, Lo, Hi);
+      // post merge is not needed here since previous aggregate arguments
+      // were marshalled into simpler arguments.
+      if (Lo == ArgClass::Integer)
+        --availIntRegisters;
+      else if (Lo == SSE)
+        --availSSERegisters;
+      if (Hi == ArgClass::Integer)
+        --availIntRegisters;
+      else if (Hi == ArgClass::SSE)
+        --availSSERegisters;
+    }
+    return availSSERegisters...
[truncated]

``````````
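`getSizeAndAlignment` and `classifyStruct` in the diff walk record components by rounding the running offset up to each component's alignment. A standalone sketch of that layout computation, following the usual C rules (a record's alignment is its largest component alignment and its size is padded up to that alignment), might look like this. The `Type` model and the `makeScalar`/`makeArray`/`makeRecord` helpers are hypothetical stand-ins for FIR types and `mlir::DataLayout` queries.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

// Round `value` up to a multiple of `align` (align > 0), like llvm::alignTo.
std::uint64_t alignTo(std::uint64_t value, std::uint64_t align) {
  return (value + align - 1) / align * align;
}

// Hypothetical minimal type model: a scalar with known size and ABI
// alignment, a constant-size array, or a record of components.
struct Type {
  enum class Kind { Scalar, Array, Record };
  Kind kind = Kind::Scalar;
  std::uint64_t size = 0;    // Scalar only
  std::uint64_t align = 1;   // Scalar only
  std::uint64_t count = 0;   // Array only: constant number of elements
  std::vector<Type> members; // Array: the element type; Record: components
};

Type makeScalar(std::uint64_t size, std::uint64_t align) {
  Type t;
  t.size = size;
  t.align = align;
  return t;
}

Type makeArray(const Type& element, std::uint64_t count) {
  Type t;
  t.kind = Type::Kind::Array;
  t.count = count;
  t.members = {element};
  return t;
}

Type makeRecord(std::vector<Type> components) {
  Type t;
  t.kind = Type::Kind::Record;
  t.members = std::move(components);
  return t;
}

// Returns {size, alignment} in bytes.
std::pair<std::uint64_t, std::uint64_t> sizeAndAlignment(const Type& t) {
  switch (t.kind) {
  case Type::Kind::Scalar:
    return {t.size, t.align};
  case Type::Kind::Array: {
    auto [eleSize, eleAlign] = sizeAndAlignment(t.members.front());
    // Each element occupies its size rounded up to its alignment.
    return {alignTo(eleSize, eleAlign) * t.count, eleAlign};
  }
  case Type::Kind::Record: {
    std::uint64_t offset = 0, align = 1;
    for (const Type& c : t.members) {
      auto [compSize, compAlign] = sizeAndAlignment(c);
      offset = alignTo(offset, compAlign) + compSize; // place the component
      align = std::max(align, compAlign);
    }
    // Tail padding keeps arrays of this record correctly aligned.
    return {alignTo(offset, align), align};
  }
  }
  return {0, 1}; // unreachable: all kinds handled above
}
```

With this model, a `c_int` followed by a `c_double` occupies 16 bytes with 4 bytes of padding after the integer, which determines the eightbyte each component falls into during classification.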

</details>


https://github.com/llvm/llvm-project/pull/74829

