[clang] [Clang][AArch64] Fix Pure Scalable Types argument passing and return (PR #112747)

via cfe-commits cfe-commits at lists.llvm.org
Thu Oct 17 09:59:50 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-clang

@llvm/pr-subscribers-clang-codegen

Author: Momchil Velikov (momchil-velikov)

<details>
<summary>Changes</summary>

Pure Scalable Types are defined in AAPCS64 here:
  https://github.com/ARM-software/abi-aa/blob/main/aapcs64/aapcs64.rst#pure-scalable-types-psts

PSTs should be passed according to Rule C.7 here:
  https://github.com/ARM-software/abi-aa/blob/main/aapcs64/aapcs64.rst#682parameter-passing-rules

This part of the ABI is completely unimplemented in Clang; instead, Clang treats PSTs sometimes as HFAs/HVAs and sometimes as general composite types.

This patch implements the rules for passing PSTs by employing the `CoerceAndExpand` method and extending it to:
  * allow array types in the `coerceToType`; now only `[N x i8]` arrays are treated as padding.
  * allow mismatch between the elements of the `coerceToType` and the elements of the `unpaddedCoerceToType`; AArch64 uses this to map fixed-length vector types to SVE vector types.

Correctly passing a PST argument requires Clang to decide whether to pass it in memory or in registers or, equivalently, whether to use the `Indirect` or the `Expand`/`CoerceAndExpand` method. Making that decision in the AArch64 backend was considered harder (or not practically possible).
Hence this patch implements the register counting from AAPCS64 (cf. `NSRN`, `NPRN`) to guide Clang's decision.

---

Patch is 42.50 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/112747.diff


4 Files Affected:

- (modified) clang/include/clang/CodeGen/CGFunctionInfo.h (+2-11) 
- (modified) clang/lib/CodeGen/CGCall.cpp (+64-21) 
- (modified) clang/lib/CodeGen/Targets/AArch64.cpp (+289-37) 
- (added) clang/test/CodeGen/aarch64-pure-scalable-args.c (+314) 


``````````diff
diff --git a/clang/include/clang/CodeGen/CGFunctionInfo.h b/clang/include/clang/CodeGen/CGFunctionInfo.h
index d19f84d198876f..915f676d7d3905 100644
--- a/clang/include/clang/CodeGen/CGFunctionInfo.h
+++ b/clang/include/clang/CodeGen/CGFunctionInfo.h
@@ -272,11 +272,6 @@ class ABIArgInfo {
     unsigned unpaddedIndex = 0;
     for (auto eltType : coerceToType->elements()) {
       if (isPaddingForCoerceAndExpand(eltType)) continue;
-      if (unpaddedStruct) {
-        assert(unpaddedStruct->getElementType(unpaddedIndex) == eltType);
-      } else {
-        assert(unpaddedIndex == 0 && unpaddedCoerceToType == eltType);
-      }
       unpaddedIndex++;
     }
 
@@ -295,12 +290,8 @@ class ABIArgInfo {
   }
 
   static bool isPaddingForCoerceAndExpand(llvm::Type *eltType) {
-    if (eltType->isArrayTy()) {
-      assert(eltType->getArrayElementType()->isIntegerTy(8));
-      return true;
-    } else {
-      return false;
-    }
+    return eltType->isArrayTy() &&
+           eltType->getArrayElementType()->isIntegerTy(8);
   }
 
   Kind getKind() const { return TheKind; }
diff --git a/clang/lib/CodeGen/CGCall.cpp b/clang/lib/CodeGen/CGCall.cpp
index 4ae981e4013e9c..3c75dae9918af9 100644
--- a/clang/lib/CodeGen/CGCall.cpp
+++ b/clang/lib/CodeGen/CGCall.cpp
@@ -1410,6 +1410,30 @@ static Address emitAddressAtOffset(CodeGenFunction &CGF, Address addr,
   return addr;
 }
 
+static std::pair<llvm::Value *, bool>
+CoerceScalableToFixed(CodeGenFunction &CGF, llvm::FixedVectorType *ToTy,
+                      llvm::ScalableVectorType *FromTy, llvm::Value *V,
+                      StringRef Name = "") {
+  // If we are casting a scalable i1 predicate vector to a fixed i8
+  // vector, first bitcast the source.
+  if (FromTy->getElementType()->isIntegerTy(1) &&
+      FromTy->getElementCount().isKnownMultipleOf(8) &&
+      ToTy->getElementType() == CGF.Builder.getInt8Ty()) {
+    FromTy = llvm::ScalableVectorType::get(
+        ToTy->getElementType(),
+        FromTy->getElementCount().getKnownMinValue() / 8);
+    V = CGF.Builder.CreateBitCast(V, FromTy);
+  }
+  if (FromTy->getElementType() == ToTy->getElementType()) {
+    llvm::Value *Zero = llvm::Constant::getNullValue(CGF.CGM.Int64Ty);
+
+    V->setName(Name + ".coerce");
+    V = CGF.Builder.CreateExtractVector(ToTy, V, Zero, "cast.fixed");
+    return {V, true};
+  }
+  return {V, false};
+}
+
 namespace {
 
 /// Encapsulates information about the way function arguments from
@@ -3196,26 +3220,14 @@ void CodeGenFunction::EmitFunctionProlog(const CGFunctionInfo &FI,
       // a VLAT at the function boundary and the types match up, use
       // llvm.vector.extract to convert back to the original VLST.
       if (auto *VecTyTo = dyn_cast<llvm::FixedVectorType>(ConvertType(Ty))) {
-        llvm::Value *Coerced = Fn->getArg(FirstIRArg);
+        llvm::Value *ArgVal = Fn->getArg(FirstIRArg);
         if (auto *VecTyFrom =
-                dyn_cast<llvm::ScalableVectorType>(Coerced->getType())) {
-          // If we are casting a scalable i1 predicate vector to a fixed i8
-          // vector, bitcast the source and use a vector extract.
-          if (VecTyFrom->getElementType()->isIntegerTy(1) &&
-              VecTyFrom->getElementCount().isKnownMultipleOf(8) &&
-              VecTyTo->getElementType() == Builder.getInt8Ty()) {
-            VecTyFrom = llvm::ScalableVectorType::get(
-                VecTyTo->getElementType(),
-                VecTyFrom->getElementCount().getKnownMinValue() / 8);
-            Coerced = Builder.CreateBitCast(Coerced, VecTyFrom);
-          }
-          if (VecTyFrom->getElementType() == VecTyTo->getElementType()) {
-            llvm::Value *Zero = llvm::Constant::getNullValue(CGM.Int64Ty);
-
+                dyn_cast<llvm::ScalableVectorType>(ArgVal->getType())) {
+          auto [Coerced, Extracted] = CoerceScalableToFixed(
+              *this, VecTyTo, VecTyFrom, ArgVal, Arg->getName());
+          if (Extracted) {
             assert(NumIRArgs == 1);
-            Coerced->setName(Arg->getName() + ".coerce");
-            ArgVals.push_back(ParamValue::forDirect(Builder.CreateExtractVector(
-                VecTyTo, Coerced, Zero, "cast.fixed")));
+            ArgVals.push_back(ParamValue::forDirect(Coerced));
             break;
           }
         }
@@ -3326,16 +3338,33 @@ void CodeGenFunction::EmitFunctionProlog(const CGFunctionInfo &FI,
       ArgVals.push_back(ParamValue::forIndirect(alloca));
 
       auto coercionType = ArgI.getCoerceAndExpandType();
+      auto unpaddedCoercionType = ArgI.getUnpaddedCoerceAndExpandType();
+      auto *unpaddedStruct = dyn_cast<llvm::StructType>(unpaddedCoercionType);
+
       alloca = alloca.withElementType(coercionType);
 
       unsigned argIndex = FirstIRArg;
+      unsigned unpaddedIndex = 0;
       for (unsigned i = 0, e = coercionType->getNumElements(); i != e; ++i) {
         llvm::Type *eltType = coercionType->getElementType(i);
         if (ABIArgInfo::isPaddingForCoerceAndExpand(eltType))
           continue;
 
         auto eltAddr = Builder.CreateStructGEP(alloca, i);
-        auto elt = Fn->getArg(argIndex++);
+        llvm::Value *elt = Fn->getArg(argIndex++);
+
+        auto paramType = unpaddedStruct
+                             ? unpaddedStruct->getElementType(unpaddedIndex++)
+                             : unpaddedCoercionType;
+
+        if (auto *VecTyTo = dyn_cast<llvm::FixedVectorType>(eltType)) {
+          if (auto *VecTyFrom = dyn_cast<llvm::ScalableVectorType>(paramType)) {
+            bool Extracted;
+            std::tie(elt, Extracted) = CoerceScalableToFixed(
+                *this, VecTyTo, VecTyFrom, elt, elt->getName());
+            assert(Extracted && "Unexpected scalable to fixed vector coercion");
+          }
+        }
         Builder.CreateStore(elt, eltAddr);
       }
       assert(argIndex == FirstIRArg + NumIRArgs);
@@ -3930,17 +3959,24 @@ void CodeGenFunction::EmitFunctionEpilog(const CGFunctionInfo &FI,
 
   case ABIArgInfo::CoerceAndExpand: {
     auto coercionType = RetAI.getCoerceAndExpandType();
+    auto unpaddedCoercionType = RetAI.getUnpaddedCoerceAndExpandType();
+    auto *unpaddedStruct = dyn_cast<llvm::StructType>(unpaddedCoercionType);
 
     // Load all of the coerced elements out into results.
     llvm::SmallVector<llvm::Value*, 4> results;
     Address addr = ReturnValue.withElementType(coercionType);
+    unsigned unpaddedIndex = 0;
     for (unsigned i = 0, e = coercionType->getNumElements(); i != e; ++i) {
       auto coercedEltType = coercionType->getElementType(i);
       if (ABIArgInfo::isPaddingForCoerceAndExpand(coercedEltType))
         continue;
 
       auto eltAddr = Builder.CreateStructGEP(addr, i);
-      auto elt = Builder.CreateLoad(eltAddr);
+      llvm::Value *elt = CreateCoercedLoad(
+          eltAddr,
+          unpaddedStruct ? unpaddedStruct->getElementType(unpaddedIndex++)
+                         : unpaddedCoercionType,
+          *this);
       results.push_back(elt);
     }
 
@@ -5468,6 +5504,8 @@ RValue CodeGenFunction::EmitCall(const CGFunctionInfo &CallInfo,
     case ABIArgInfo::CoerceAndExpand: {
       auto coercionType = ArgInfo.getCoerceAndExpandType();
       auto layout = CGM.getDataLayout().getStructLayout(coercionType);
+      auto unpaddedCoercionType = ArgInfo.getUnpaddedCoerceAndExpandType();
+      auto *unpaddedStruct = dyn_cast<llvm::StructType>(unpaddedCoercionType);
 
       llvm::Value *tempSize = nullptr;
       Address addr = Address::invalid();
@@ -5498,11 +5536,16 @@ RValue CodeGenFunction::EmitCall(const CGFunctionInfo &CallInfo,
       addr = addr.withElementType(coercionType);
 
       unsigned IRArgPos = FirstIRArg;
+      unsigned unpaddedIndex = 0;
       for (unsigned i = 0, e = coercionType->getNumElements(); i != e; ++i) {
         llvm::Type *eltType = coercionType->getElementType(i);
         if (ABIArgInfo::isPaddingForCoerceAndExpand(eltType)) continue;
         Address eltAddr = Builder.CreateStructGEP(addr, i);
-        llvm::Value *elt = Builder.CreateLoad(eltAddr);
+        llvm::Value *elt = CreateCoercedLoad(
+            eltAddr,
+            unpaddedStruct ? unpaddedStruct->getElementType(unpaddedIndex++)
+                           : unpaddedCoercionType,
+            *this);
         if (ArgHasMaybeUndefAttr)
           elt = Builder.CreateFreeze(elt);
         IRCallArgs[IRArgPos++] = elt;
diff --git a/clang/lib/CodeGen/Targets/AArch64.cpp b/clang/lib/CodeGen/Targets/AArch64.cpp
index ec617eec67192c..269b5b352bfd84 100644
--- a/clang/lib/CodeGen/Targets/AArch64.cpp
+++ b/clang/lib/CodeGen/Targets/AArch64.cpp
@@ -36,8 +36,15 @@ class AArch64ABIInfo : public ABIInfo {
 
   ABIArgInfo classifyReturnType(QualType RetTy, bool IsVariadic) const;
   ABIArgInfo classifyArgumentType(QualType RetTy, bool IsVariadic,
-                                  unsigned CallingConvention) const;
-  ABIArgInfo coerceIllegalVector(QualType Ty) const;
+                                  unsigned CallingConvention, unsigned &NSRN,
+                                  unsigned &NPRN) const;
+  llvm::Type *convertFixedToScalableVectorType(const VectorType *VT) const;
+  ABIArgInfo coerceIllegalVector(QualType Ty, unsigned &NSRN,
+                                 unsigned &NPRN) const;
+  ABIArgInfo coerceAndExpandPureScalableAggregate(
+      QualType Ty, unsigned NVec, unsigned NPred,
+      const SmallVectorImpl<llvm::Type *> &UnpaddedCoerceToSeq, unsigned &NSRN,
+      unsigned &NPRN) const;
   bool isHomogeneousAggregateBaseType(QualType Ty) const override;
   bool isHomogeneousAggregateSmallEnough(const Type *Ty,
                                          uint64_t Members) const override;
@@ -45,14 +52,21 @@ class AArch64ABIInfo : public ABIInfo {
 
   bool isIllegalVectorType(QualType Ty) const;
 
+  bool isPureScalableType(QualType Ty, unsigned &NV, unsigned &NP,
+                          SmallVectorImpl<llvm::Type *> &CoerceToSeq) const;
+
+  void flattenType(llvm::Type *Ty,
+                   SmallVectorImpl<llvm::Type *> &Flattened) const;
+
   void computeInfo(CGFunctionInfo &FI) const override {
     if (!::classifyReturnType(getCXXABI(), FI, *this))
       FI.getReturnInfo() =
           classifyReturnType(FI.getReturnType(), FI.isVariadic());
 
+    unsigned NSRN = 0, NPRN = 0;
     for (auto &it : FI.arguments())
       it.info = classifyArgumentType(it.type, FI.isVariadic(),
-                                     FI.getCallingConvention());
+                                     FI.getCallingConvention(), NSRN, NPRN);
   }
 
   RValue EmitDarwinVAArg(Address VAListAddr, QualType Ty, CodeGenFunction &CGF,
@@ -201,65 +215,83 @@ void WindowsAArch64TargetCodeGenInfo::setTargetAttributes(
 }
 }
 
-ABIArgInfo AArch64ABIInfo::coerceIllegalVector(QualType Ty) const {
-  assert(Ty->isVectorType() && "expected vector type!");
+llvm::Type *
+AArch64ABIInfo::convertFixedToScalableVectorType(const VectorType *VT) const {
+  assert(VT->getElementType()->isBuiltinType() && "expected builtin type!");
 
-  const auto *VT = Ty->castAs<VectorType>();
   if (VT->getVectorKind() == VectorKind::SveFixedLengthPredicate) {
-    assert(VT->getElementType()->isBuiltinType() && "expected builtin type!");
     assert(VT->getElementType()->castAs<BuiltinType>()->getKind() ==
                BuiltinType::UChar &&
            "unexpected builtin type for SVE predicate!");
-    return ABIArgInfo::getDirect(llvm::ScalableVectorType::get(
-        llvm::Type::getInt1Ty(getVMContext()), 16));
+    return llvm::ScalableVectorType::get(llvm::Type::getInt1Ty(getVMContext()),
+                                         16);
   }
 
   if (VT->getVectorKind() == VectorKind::SveFixedLengthData) {
-    assert(VT->getElementType()->isBuiltinType() && "expected builtin type!");
-
     const auto *BT = VT->getElementType()->castAs<BuiltinType>();
-    llvm::ScalableVectorType *ResType = nullptr;
     switch (BT->getKind()) {
     default:
       llvm_unreachable("unexpected builtin type for SVE vector!");
+
     case BuiltinType::SChar:
     case BuiltinType::UChar:
-      ResType = llvm::ScalableVectorType::get(
+      return llvm::ScalableVectorType::get(
           llvm::Type::getInt8Ty(getVMContext()), 16);
-      break;
+
     case BuiltinType::Short:
     case BuiltinType::UShort:
-      ResType = llvm::ScalableVectorType::get(
+      return llvm::ScalableVectorType::get(
           llvm::Type::getInt16Ty(getVMContext()), 8);
-      break;
+
     case BuiltinType::Int:
     case BuiltinType::UInt:
-      ResType = llvm::ScalableVectorType::get(
+      return llvm::ScalableVectorType::get(
           llvm::Type::getInt32Ty(getVMContext()), 4);
-      break;
+
     case BuiltinType::Long:
     case BuiltinType::ULong:
-      ResType = llvm::ScalableVectorType::get(
+      return llvm::ScalableVectorType::get(
           llvm::Type::getInt64Ty(getVMContext()), 2);
-      break;
+
     case BuiltinType::Half:
-      ResType = llvm::ScalableVectorType::get(
+      return llvm::ScalableVectorType::get(
           llvm::Type::getHalfTy(getVMContext()), 8);
-      break;
+
     case BuiltinType::Float:
-      ResType = llvm::ScalableVectorType::get(
+      return llvm::ScalableVectorType::get(
           llvm::Type::getFloatTy(getVMContext()), 4);
-      break;
+
     case BuiltinType::Double:
-      ResType = llvm::ScalableVectorType::get(
+      return llvm::ScalableVectorType::get(
           llvm::Type::getDoubleTy(getVMContext()), 2);
-      break;
+
     case BuiltinType::BFloat16:
-      ResType = llvm::ScalableVectorType::get(
+      return llvm::ScalableVectorType::get(
           llvm::Type::getBFloatTy(getVMContext()), 8);
-      break;
     }
-    return ABIArgInfo::getDirect(ResType);
+  }
+
+  llvm_unreachable("expected fixed-length SVE vector");
+}
+
+ABIArgInfo AArch64ABIInfo::coerceIllegalVector(QualType Ty, unsigned &NSRN,
+                                               unsigned &NPRN) const {
+  assert(Ty->isVectorType() && "expected vector type!");
+
+  const auto *VT = Ty->castAs<VectorType>();
+  if (VT->getVectorKind() == VectorKind::SveFixedLengthPredicate) {
+    assert(VT->getElementType()->isBuiltinType() && "expected builtin type!");
+    assert(VT->getElementType()->castAs<BuiltinType>()->getKind() ==
+               BuiltinType::UChar &&
+           "unexpected builtin type for SVE predicate!");
+    NPRN = std::min(NPRN + 1, 4u);
+    return ABIArgInfo::getDirect(llvm::ScalableVectorType::get(
+        llvm::Type::getInt1Ty(getVMContext()), 16));
+  }
+
+  if (VT->getVectorKind() == VectorKind::SveFixedLengthData) {
+    NSRN = std::min(NSRN + 1, 8u);
+    return ABIArgInfo::getDirect(convertFixedToScalableVectorType(VT));
   }
 
   uint64_t Size = getContext().getTypeSize(Ty);
@@ -273,26 +305,53 @@ ABIArgInfo AArch64ABIInfo::coerceIllegalVector(QualType Ty) const {
     return ABIArgInfo::getDirect(ResType);
   }
   if (Size == 64) {
+    NSRN = std::min(NSRN + 1, 8u);
     auto *ResType =
         llvm::FixedVectorType::get(llvm::Type::getInt32Ty(getVMContext()), 2);
     return ABIArgInfo::getDirect(ResType);
   }
   if (Size == 128) {
+    NSRN = std::min(NSRN + 1, 8u);
     auto *ResType =
         llvm::FixedVectorType::get(llvm::Type::getInt32Ty(getVMContext()), 4);
     return ABIArgInfo::getDirect(ResType);
   }
+
   return getNaturalAlignIndirect(Ty, /*ByVal=*/false);
 }
 
-ABIArgInfo
-AArch64ABIInfo::classifyArgumentType(QualType Ty, bool IsVariadic,
-                                     unsigned CallingConvention) const {
+ABIArgInfo AArch64ABIInfo::coerceAndExpandPureScalableAggregate(
+    QualType Ty, unsigned NVec, unsigned NPred,
+    const SmallVectorImpl<llvm::Type *> &UnpaddedCoerceToSeq, unsigned &NSRN,
+    unsigned &NPRN) const {
+  if (NSRN + NVec > 8 || NPRN + NPred > 4)
+    return getNaturalAlignIndirect(Ty, /*ByVal=*/false);
+  NSRN += NVec;
+  NPRN += NPred;
+
+  llvm::Type *UnpaddedCoerceToType =
+      UnpaddedCoerceToSeq.size() == 1
+          ? UnpaddedCoerceToSeq[0]
+          : llvm::StructType::get(CGT.getLLVMContext(), UnpaddedCoerceToSeq,
+                                  true);
+
+  SmallVector<llvm::Type *> CoerceToSeq;
+  flattenType(CGT.ConvertType(Ty), CoerceToSeq);
+  auto *CoerceToType =
+      llvm::StructType::get(CGT.getLLVMContext(), CoerceToSeq, false);
+
+  return ABIArgInfo::getCoerceAndExpand(CoerceToType, UnpaddedCoerceToType);
+}
+
+ABIArgInfo AArch64ABIInfo::classifyArgumentType(QualType Ty, bool IsVariadic,
+                                                unsigned CallingConvention,
+                                                unsigned &NSRN,
+                                                unsigned &NPRN) const {
   Ty = useFirstFieldIfTransparentUnion(Ty);
 
   // Handle illegal vector types here.
   if (isIllegalVectorType(Ty))
-    return coerceIllegalVector(Ty);
+    return coerceIllegalVector(Ty, NSRN, NPRN);
 
   if (!isAggregateTypeForABI(Ty)) {
     // Treat an enum type as its underlying type.
@@ -303,6 +362,20 @@ AArch64ABIInfo::classifyArgumentType(QualType Ty, bool IsVariadic,
       if (EIT->getNumBits() > 128)
         return getNaturalAlignIndirect(Ty, false);
 
+    if (const BuiltinType *BT = Ty->getAs<BuiltinType>()) {
+      if (BT->isSVEBool() || BT->isSVECount())
+        NPRN = std::min(NPRN + 1, 4u);
+      else if (BT->getKind() == BuiltinType::SveBoolx2)
+        NPRN = std::min(NPRN + 2, 4u);
+      else if (BT->getKind() == BuiltinType::SveBoolx4)
+        NPRN = std::min(NPRN + 4, 4u);
+      else if (BT->isFloatingPoint() || BT->isVectorType())
+        NSRN = std::min(NSRN + 1, 8u);
+      else if (BT->isSVESizelessBuiltinType())
+        NSRN = std::min(
+            NSRN + getContext().getBuiltinVectorTypeInfo(BT).NumVectors, 8u);
+    }
+
     return (isPromotableIntegerTypeForABI(Ty) && isDarwinPCS()
                 ? ABIArgInfo::getExtend(Ty, CGT.ConvertType(Ty))
                 : ABIArgInfo::getDirect());
@@ -339,6 +412,7 @@ AArch64ABIInfo::classifyArgumentType(QualType Ty, bool IsVariadic,
   // In variadic functions on Windows, all composite types are treated alike,
   // no special handling of HFAs/HVAs.
   if (!IsWinVariadic && isHomogeneousAggregate(Ty, Base, Members)) {
+    NSRN = std::min(NSRN + Members, uint64_t(8));
     if (Kind != AArch64ABIKind::AAPCS)
       return ABIArgInfo::getDirect(
           llvm::ArrayType::get(CGT.ConvertType(QualType(Base, 0)), Members));
@@ -353,6 +427,17 @@ AArch64ABIInfo::classifyArgumentType(QualType Ty, bool IsVariadic,
         nullptr, true, Align);
   }
 
+  // In AAPCS named arguments of a Pure Scalable Type are passed expanded in
+  // registers, or indirectly if there are not enough registers.
+  if (Kind == AArch64ABIKind::AAPCS && !IsVariadic) {
+    unsigned NVec = 0, NPred = 0;
+    SmallVector<llvm::Type *> UnpaddedCoerceToSeq;
+    if (isPureScalableType(Ty, NVec, NPred, UnpaddedCoerceToSeq) &&
+        (NVec + NPred) > 0)
+      return coerceAndExpandPureScalableAggregate(
+          Ty, NVec, NPred, UnpaddedCoerceToSeq, NSRN, NPRN);
+  }
+
   // Aggregates <= 16 bytes are passed directly in registers or on the stack.
   if (Size <= 128) {
     // On RenderScript, coerce Aggregates <= 16 bytes to an integer array of
@@ -389,8 +474,10 @@ ABIArgInfo AArch64ABIInfo::classifyReturnType(QualType RetTy,
 
   if (const auto *VT = RetTy->getAs<VectorType>()) {
     if (VT->getVectorKind() == VectorKind::SveFixedLengthData ||
-        VT->getVectorKind() == VectorKind::SveFixedLengthPredicate)
-      return coerceIllegalVector(RetTy);
+        VT->getVectorKind() == VectorKind::SveFixedLengthPredicate) {
+      unsigned NSRN = 0, NPRN = 0;
+      return coerceIllegalVector(RetTy, NSRN, NPRN);
+    }
   }
 
   // Large vector types should be returned via memory.
@@ -423,6 +510,19 @@ ABIArgInfo AArch64ABIInfo::classifyReturnType(QualType RetTy,
     // Homogeneous Floating-point Aggregates (HFAs) are returned directly.
     return ABIArgInfo::getDirect();
 
+  // In AAPCS return values of a Pure Scalable type are treated as a first named
+  // argument and passed expanded in registers, or indirectly if there are not
+  // enough registers.
+  if (Kind == AArch64ABIKind::AAPCS) {
+    unsigned NSRN = 0, NPRN = 0;
+    unsigned NVec = 0, NPred = 0;
+    SmallVector<llvm::Type *> UnpaddedCoerceToSeq;
+    if (isPureScalableType(RetTy, NVec, NPred, UnpaddedCoerceToSeq) &&
+        (NVec + NPred) > 0)
+      return...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/112747


More information about the cfe-commits mailing list