[flang-commits] [flang] [flang] Add struct passing target rewrite hooks and partial X86-64 impl (PR #74829)

Slava Zakharin via flang-commits flang-commits at lists.llvm.org
Fri Dec 8 16:05:27 PST 2023


================
@@ -318,6 +382,251 @@ struct TargetX86_64 : public GenericTarget<TargetX86_64> {
     }
     return marshal;
   }
+
+  /// Argument classes used by the X86-64 parameter classification algorithm,
+  /// as listed in System V ABI version 1.0 section 3.2.3. The numeric value
+  /// of every enumerator is written out explicitly; note that the zero value
+  /// is Integer, not NoClass. NOTE(review): mergeClass may depend on this
+  /// ordering -- confirm before reordering enumerators.
+  enum ArgClass {
+    Integer = 0,
+    SSE = 1,
+    SSEUp = 2,
+    X87 = 3,
+    X87Up = 4,
+    ComplexX87 = 5,
+    NoClass = 6,
+    Memory = 7,
+  };
+
+  /// Classify an argument type or a field of an aggregate type argument.
+  /// See System V ABI version 1.0 section 3.2.3.
+  /// On return, \p Lo and \p Hi hold the classes of the lower eightbyte
+  /// (bytes 0-7) and the upper eightbyte (bytes 8-15); both are reset to
+  /// NoClass on entry. \p byteOffset is the offset of \p type within the
+  /// enclosing aggregate and selects which eightbyte a type that fits in a
+  /// single eightbyte is assigned to.
+  /// If this is called for an aggregate field, the caller is responsible to
+  /// do the post-merge.
+  void classify(mlir::Location loc, mlir::Type type, std::uint64_t byteOffset,
+                ArgClass &Lo, ArgClass &Hi) const {
+    Hi = Lo = ArgClass::NoClass;
+    // The eightbyte that a single-eightbyte type starting at byteOffset
+    // falls into.
+    ArgClass &current = byteOffset < 8 ? Lo : Hi;
+    // System V AMD64 ABI 3.2.3. version 1.0
+    llvm::TypeSwitch<mlir::Type>(type)
+        .template Case<mlir::IntegerType>([&](mlir::IntegerType intTy) {
+          // NOTE(review): 128-bit values set both eightbytes regardless of
+          // byteOffset; presumably a misplaced one is rejected by the
+          // caller's post-merge -- confirm.
+          if (intTy.getWidth() == 128)
+            Hi = Lo = ArgClass::Integer;
+          else
+            current = ArgClass::Integer;
+        })
+        .template Case<mlir::FloatType, fir::RealType>([&](mlir::Type floatTy) {
+          const auto *sem = &floatToSemantics(kindMap, floatTy);
+          if (sem == &llvm::APFloat::x87DoubleExtended()) {
+            // x87 extended precision uses both eightbytes: X87 then X87UP.
+            Lo = ArgClass::X87;
+            Hi = ArgClass::X87Up;
+          } else if (sem == &llvm::APFloat::IEEEquad()) {
+            // IEEE quad precision uses both eightbytes: SSE then SSEUP.
+            Lo = ArgClass::SSE;
+            Hi = ArgClass::SSEUp;
+          } else {
+            current = ArgClass::SSE;
+          }
+        })
+        .template Case<fir::ComplexType>([&](fir::ComplexType cmplx) {
+          const auto *sem = &floatToSemantics(kindMap, cmplx.getElementType());
+          if (sem == &llvm::APFloat::x87DoubleExtended()) {
+            current = ArgClass::ComplexX87;
+          } else {
+            // Any other complex type is classified like a two-element array
+            // of its component type.
+            fir::SequenceType::Shape shape{2};
+            classifyArray(loc,
+                          fir::SequenceType::get(shape, cmplx.getElementType()),
+                          byteOffset, Lo, Hi);
+          }
+        })
+        .template Case<fir::LogicalType>([&](fir::LogicalType logical) {
+          if (kindMap.getLogicalBitsize(logical.getFKind()) == 128)
+            Hi = Lo = ArgClass::Integer;
+          else
+            current = ArgClass::Integer;
+        })
+        .template Case<fir::CharacterType>(
+            [&](fir::CharacterType) { current = ArgClass::Integer; })
+        .template Case<fir::SequenceType>([&](fir::SequenceType seqTy) {
+          // Array component.
+          classifyArray(loc, seqTy, byteOffset, Lo, Hi);
+        })
+        .template Case<fir::RecordType>([&](fir::RecordType recTy) {
+          // Component that is a derived type.
+          classifyStruct(loc, recTy, byteOffset, Lo, Hi);
+        })
+        .template Case<fir::VectorType>([&](fir::VectorType vecTy) {
+          // Previously marshalled SSE eightbyte for a previous struct
+          // argument.
+          auto *sem = fir::isa_real(vecTy.getEleTy())
+                          ? &floatToSemantics(kindMap, vecTy.getEleTy())
+                          : nullptr;
+          // Not expecting to hit this TODO in standard code (it would
+          // require some vector type extension).
+          if (!(sem == &llvm::APFloat::IEEEsingle() && vecTy.getLen() <= 2) &&
+              !(sem == &llvm::APFloat::IEEEhalf() && vecTy.getLen() <= 4))
+            TODO(loc, "passing vector argument to C by value");
+          current = ArgClass::SSE;
+        })
+        .Default([&](mlir::Type ty) {
+          if (fir::conformsWithPassByRef(ty))
+            current = ArgClass::Integer; // Pointers.
+          else
+            TODO(loc, "unsupported component type for BIND(C), VALUE derived "
+                      "type argument");
+        });
+  }
+
+  /// Classify the components of derived type \p recTy, laid out starting at
+  /// \p byteOffset, merging each component's classes into \p Lo and \p Hi.
+  /// Returns the offset just past the last classified component, where each
+  /// component is placed at its alignment and occupies its size rounded up
+  /// to that alignment. Classification stops early when the running offset
+  /// exceeds 16 bytes before a component (both classes become Memory), or as
+  /// soon as a merge yields Memory. Post-merge is left to the caller.
+  std::uint64_t classifyStruct(mlir::Location loc, fir::RecordType recTy,
+                               std::uint64_t byteOffset, ArgClass &Lo,
+                               ArgClass &Hi) const {
+    // Bind by const reference: only the component type (`.second`) is used,
+    // so there is no reason to copy each component entry.
+    for (const auto &component : recTy.getTypeList()) {
+      if (byteOffset > 16) {
+        Lo = Hi = ArgClass::Memory;
+        return byteOffset;
+      }
+      mlir::Type compType = component.second;
+      auto [compSize, compAlign] =
+          getSizeAndAlignment(loc, compType, getDataLayout(), kindMap);
+      byteOffset = llvm::alignTo(byteOffset, compAlign);
+      ArgClass LoComp, HiComp;
+      classify(loc, compType, byteOffset, LoComp, HiComp);
+      Lo = mergeClass(Lo, LoComp);
+      Hi = mergeClass(Hi, HiComp);
+      byteOffset = byteOffset + llvm::alignTo(compSize, compAlign);
+      if (Lo == ArgClass::Memory || Hi == ArgClass::Memory)
+        return byteOffset;
+    }
+    return byteOffset;
+  }
+
+  /// Classify the elements of the constant-size array type \p seqTy, laid
+  /// out starting at \p byteOffset, merging each element's classes into
+  /// \p Lo and \p Hi. Unlike classifyStruct, this does not return the end
+  /// offset; callers that need the aggregate size must compute it themselves.
+  /// Classification stops early when an element would start past byte 16
+  /// (both classes become Memory), or as soon as a merge yields Memory.
+  /// Post-merge is left to the caller.
+  void classifyArray(mlir::Location loc, fir::SequenceType seqTy,
+                     std::uint64_t byteOffset, ArgClass &Lo,
+                     ArgClass &Hi) const {
+    mlir::Type eleTy = seqTy.getEleTy();
+    const std::uint64_t arraySize = seqTy.getConstantArraySize();
+    auto [eleSize, eleAlign] =
+        getSizeAndAlignment(loc, eleTy, getDataLayout(), kindMap);
+    // Elements are laid out at a stride of their size rounded up to their
+    // alignment.
+    std::uint64_t eleStorageSize = llvm::alignTo(eleSize, eleAlign);
+    for (std::uint64_t i = 0; i < arraySize; ++i) {
+      byteOffset = llvm::alignTo(byteOffset, eleAlign);
+      if (byteOffset > 16) {
+        Lo = Hi = ArgClass::Memory;
+        return;
+      }
+      ArgClass LoComp, HiComp;
+      classify(loc, eleTy, byteOffset, LoComp, HiComp);
+      Lo = mergeClass(Lo, LoComp);
+      Hi = mergeClass(Hi, HiComp);
+      byteOffset = byteOffset + eleStorageSize;
+      if (Lo == ArgClass::Memory || Hi == ArgClass::Memory)
+        return;
+    }
+  }
+
+  // Goes through the previously marshalled arguments and count the
+  // register occupancy to check if there are enough registers left.
+  bool hasEnoughRegisters(mlir::Location loc, int neededIntRegisters,
+                          int neededSSERegisters,
+                          const Marshalling &previousArguments) const {
+    int availIntRegisters = 6;
+    int availSSERegisters = 8;
+    for (auto typeAndAttr : previousArguments) {
+      const auto &attr = std::get<Attributes>(typeAndAttr);
+      if (attr.isByVal() || attr.isSRet())
----------------
vzakhari wrote:

I am not sure about the `sret` case: isn't it passed as an address in a register?

https://github.com/llvm/llvm-project/pull/74829


More information about the flang-commits mailing list