[clang] [Clang][AArch64] Fix Pure Scalables Types argument passing and return (PR #112747)

Jonathan Thackray via cfe-commits cfe-commits at lists.llvm.org
Thu Oct 17 13:04:12 PDT 2024


================
@@ -533,11 +638,158 @@ bool AArch64ABIInfo::isZeroLengthBitfieldPermittedInHomogeneousAggregate()
   return true;
 }
 
+// Check if a type is a Pure Scalable Type (PST) as defined by AAPCS64.
+//
+// On success, the number of data vectors and the number of predicate vectors
+// in the type are *added* to `NVec` and `NPred`, respectively. The counts are
+// accumulated across the recursion, so initialise them to zero before the
+// outermost call. `CoerceToSeq` is extended with an expanded sequence of LLVM
+// IR types, one element for each non-composite member; an SVE tuple type such
+// as svint32x2_t contributes one element per vector. For practical purposes,
+// `CoerceToSeq` stops growing once it holds about 12 elements, the maximum
+// that could possibly fit in registers; `NVec` and `NPred` are still counted
+// in full.
+//
+// On failure (return value false) the outputs may already have been partially
+// updated and should be discarded.
+bool AArch64ABIInfo::isPureScalableType(
+    QualType Ty, unsigned &NVec, unsigned &NPred,
+    SmallVectorImpl<llvm::Type *> &CoerceToSeq) const {
+  // Soft cap on the length of CoerceToSeq (see the comment above).
+  constexpr size_t MaxCoerceToSeqLen = 12;
+
+  if (const ConstantArrayType *AT = getContext().getAsConstantArrayType(Ty)) {
+    uint64_t NElt = AT->getZExtSize();
+    if (NElt == 0)
+      return false;
+
+    // Classify a single element, then replicate its contribution NElt times.
+    unsigned NV = 0, NP = 0;
+    SmallVector<llvm::Type *> EltCoerceToSeq;
+    if (!isPureScalableType(AT->getElementType(), NV, NP, EltCoerceToSeq))
+      return false;
+
+    for (uint64_t I = 0; CoerceToSeq.size() < MaxCoerceToSeqLen && I < NElt;
+         ++I)
+      llvm::copy(EltCoerceToSeq, std::back_inserter(CoerceToSeq));
+
+    NVec += NElt * NV;
+    NPred += NElt * NP;
+    return true;
+  }
+
+  if (const RecordType *RT = Ty->getAs<RecordType>()) {
+    // If the record cannot be passed in registers, then it's not a PST.
+    if (CGCXXABI::RecordArgABI RAA = getRecordArgABI(RT, getCXXABI());
+        RAA != CGCXXABI::RAA_Default)
+      return false;
+
+    // Pure scalable types are never unions and never contain unions.
+    const RecordDecl *RD = RT->getDecl();
+    if (RD->isUnion())
+      return false;
+
+    // If this is a C++ record, check the bases.
+    if (const CXXRecordDecl *CXXRD = dyn_cast<CXXRecordDecl>(RD)) {
+      for (const auto &I : CXXRD->bases()) {
+        if (isEmptyRecord(getContext(), I.getType(), true))
+          continue;
+        if (!isPureScalableType(I.getType(), NVec, NPred, CoerceToSeq))
+          return false;
+      }
+    }
+
+    // Check members; empty records contribute nothing and are skipped.
+    for (const auto *FD : RD->fields()) {
+      QualType FT = FD->getType();
+      if (isEmptyRecord(getContext(), FT, true))
+        continue;
+      if (!isPureScalableType(FT, NVec, NPred, CoerceToSeq))
+        return false;
+    }
+
+    return true;
+  }
+
+  // Fixed-length SVE vectors/predicates (arm_sve_vector_bits) are VectorTypes.
+  if (const auto *VT = Ty->getAs<VectorType>()) {
+    if (VT->getVectorKind() == VectorKind::SveFixedLengthPredicate) {
+      ++NPred;
+      if (CoerceToSeq.size() < MaxCoerceToSeqLen)
+        CoerceToSeq.push_back(convertFixedToScalableVectorType(VT));
+      return true;
+    }
+
+    if (VT->getVectorKind() == VectorKind::SveFixedLengthData) {
+      ++NVec;
+      if (CoerceToSeq.size() < MaxCoerceToSeqLen)
+        CoerceToSeq.push_back(convertFixedToScalableVectorType(VT));
+      return true;
+    }
+
+    // Any other kind of vector is not part of a PST.
+    return false;
+  }
+
+  // Sizeless SVE ACLE types (svint32_t, svbool_t, svint32x2_t, ...) are
+  // BuiltinTypes, not VectorTypes, so they must be looked up separately.
+  // getAs<> sees through type sugar such as typedefs.
+  const auto *BT = Ty->getAs<BuiltinType>();
+  if (!BT)
+    return false;
+
+  bool IsPredicate = false;
+  switch (BT->getKind()) {
+#define SVE_VECTOR_TYPE(Name, MangledName, Id, SingletonId)                    \
+  case BuiltinType::Id:                                                        \
+    IsPredicate = false;                                                       \
+    break;
+#define SVE_PREDICATE_TYPE(Name, MangledName, Id, SingletonId)                 \
+  case BuiltinType::Id:                                                        \
+    IsPredicate = true;                                                        \
+    break;
+#define SVE_TYPE(Name, Id, SingletonId)
+#include "clang/Basic/AArch64SVEACLETypes.def"
+  default:
+    return false;
+  }
+
+  ASTContext::BuiltinVectorTypeInfo Info =
+      getContext().getBuiltinVectorTypeInfo(BT);
+  assert(Info.NumVectors > 0 && Info.NumVectors <= 4 &&
+         "Expected 1, 2, 3 or 4 vectors!");
+
+  // A tuple type consists of NumVectors vectors. Each one counts towards the
+  // totals, consistent with the NumVectors entries appended to CoerceToSeq.
+  if (IsPredicate)
+    NPred += Info.NumVectors;
+  else
+    NVec += Info.NumVectors;
+
+  auto *VTy = llvm::ScalableVectorType::get(CGT.ConvertType(Info.ElementType),
+                                            Info.EC.getKnownMinValue());
+
+  if (CoerceToSeq.size() < MaxCoerceToSeqLen)
+    std::fill_n(std::back_inserter(CoerceToSeq), Info.NumVectors, VTy);
+
+  return true;
+}
+
+// Expand an LLVM IR type into a flat sequence containing one entry for every
+// member that is neither a struct nor an array. Types recognised as padding by
+// ABIArgInfo::isPaddingForCoerceAndExpand are appended as-is, without being
+// expanded. Entries are appended to `Flattened`; existing contents are kept.
+void AArch64ABIInfo::flattenType(
+    llvm::Type *Ty, SmallVectorImpl<llvm::Type *> &Flattened) const {
+  // Padding is kept verbatim.
+  if (ABIArgInfo::isPaddingForCoerceAndExpand(Ty)) {
+    Flattened.push_back(Ty);
+    return;
+  }
+
+  // Structs: flatten every element, in order.
+  if (auto *ST = dyn_cast<llvm::StructType>(Ty)) {
+    for (llvm::Type *Member : ST->elements())
+      flattenType(Member, Flattened);
+    return;
+  }
+
+  // Arrays: flatten the element type once and repeat the result once per
+  // array element. Zero-length arrays contribute nothing.
+  if (auto *AT = dyn_cast<llvm::ArrayType>(Ty)) {
+    const uint64_t Count = AT->getNumElements();
+    if (Count != 0) {
+      SmallVector<llvm::Type *> ElemSeq;
+      flattenType(AT->getElementType(), ElemSeq);
+      for (uint64_t Rep = 0; Rep != Count; ++Rep)
+        Flattened.append(ElemSeq.begin(), ElemSeq.end());
+    }
+    return;
+  }
+
+  // Anything else is a leaf.
+  Flattened.push_back(Ty);
+}
+
 RValue AArch64ABIInfo::EmitAAPCSVAArg(Address VAListAddr, QualType Ty,
                                       CodeGenFunction &CGF, AArch64ABIKind Kind,
                                       AggValueSlot Slot) const {
-  ABIArgInfo AI = classifyArgumentType(Ty, /*IsVariadic=*/true,
-                                       CGF.CurFnInfo->getCallingConvention());
+  // These numbers are not used for variadic arguments, hence it doesn't matter
+  // they don't retain their values accross multiple calls to
----------------
jthackray wrote:

nit: s/accross/across/

https://github.com/llvm/llvm-project/pull/112747


More information about the cfe-commits mailing list