[clang] [Clang][AArch64] Fix Pure Scalables Types argument passing and return (PR #112747)

Momchil Velikov via cfe-commits cfe-commits at lists.llvm.org
Fri Oct 18 08:31:56 PDT 2024


https://github.com/momchil-velikov updated https://github.com/llvm/llvm-project/pull/112747

>From c2f223d84c18498f3cbe1582b006b0d4c52999aa Mon Sep 17 00:00:00 2001
From: Momchil Velikov <momchil.velikov at arm.com>
Date: Thu, 17 Oct 2024 14:04:05 +0100
Subject: [PATCH 1/3] [Clang][AArch64] Fix Pure Scalables Types argument
 passing and return

Pure Scalable Types are defined in AAPCS64 here:
  https://github.com/ARM-software/abi-aa/blob/main/aapcs64/aapcs64.rst#pure-scalable-types-psts

And should be passed according to Rule C.7 here:
  https://github.com/ARM-software/abi-aa/blob/main/aapcs64/aapcs64.rst#682parameter-passing-rules

This part of the ABI is completely unimplemented in Clang; instead,
it treats PSTs sometimes as HFAs/HVAs, sometimes as general composite
types.

This patch implements the rules for passing PSTs by employing the
`CoerceAndExpand` method and extending it to:
  * allow array types in the `coerceToType`; now only `[N x i8]` arrays
    are considered padding.
  * allow mismatch between the elements of the `coerceToType` and the
    elements of the `unpaddedCoerceToType`; AArch64 uses this to map
    fixed-length vector types to SVE vector types.

Correctly passing a PST argument requires a decision in Clang about
whether to pass it in memory or in registers or, equivalently, whether
to use the `Indirect` or the `Expand/CoerceAndExpand` method.
Making that decision in the AArch64 backend was considered harder
(or not practically possible), hence this patch implements the
register counting from AAPCS64 (cf. "NSRN", "NPRN") to guide Clang's
decision.
---
 clang/include/clang/CodeGen/CGFunctionInfo.h  |  13 +-
 clang/lib/CodeGen/CGCall.cpp                  |  85 +++--
 clang/lib/CodeGen/Targets/AArch64.cpp         | 326 ++++++++++++++++--
 .../test/CodeGen/aarch64-pure-scalable-args.c | 314 +++++++++++++++++
 4 files changed, 669 insertions(+), 69 deletions(-)
 create mode 100644 clang/test/CodeGen/aarch64-pure-scalable-args.c

diff --git a/clang/include/clang/CodeGen/CGFunctionInfo.h b/clang/include/clang/CodeGen/CGFunctionInfo.h
index d19f84d198876f..915f676d7d3905 100644
--- a/clang/include/clang/CodeGen/CGFunctionInfo.h
+++ b/clang/include/clang/CodeGen/CGFunctionInfo.h
@@ -272,11 +272,6 @@ class ABIArgInfo {
     unsigned unpaddedIndex = 0;
     for (auto eltType : coerceToType->elements()) {
       if (isPaddingForCoerceAndExpand(eltType)) continue;
-      if (unpaddedStruct) {
-        assert(unpaddedStruct->getElementType(unpaddedIndex) == eltType);
-      } else {
-        assert(unpaddedIndex == 0 && unpaddedCoerceToType == eltType);
-      }
       unpaddedIndex++;
     }
 
@@ -295,12 +290,8 @@ class ABIArgInfo {
   }
 
   static bool isPaddingForCoerceAndExpand(llvm::Type *eltType) {
-    if (eltType->isArrayTy()) {
-      assert(eltType->getArrayElementType()->isIntegerTy(8));
-      return true;
-    } else {
-      return false;
-    }
+    return eltType->isArrayTy() &&
+           eltType->getArrayElementType()->isIntegerTy(8);
   }
 
   Kind getKind() const { return TheKind; }
diff --git a/clang/lib/CodeGen/CGCall.cpp b/clang/lib/CodeGen/CGCall.cpp
index 4ae981e4013e9c..3c75dae9918af9 100644
--- a/clang/lib/CodeGen/CGCall.cpp
+++ b/clang/lib/CodeGen/CGCall.cpp
@@ -1410,6 +1410,30 @@ static Address emitAddressAtOffset(CodeGenFunction &CGF, Address addr,
   return addr;
 }
 
+static std::pair<llvm::Value *, bool>
+CoerceScalableToFixed(CodeGenFunction &CGF, llvm::FixedVectorType *ToTy,
+                      llvm::ScalableVectorType *FromTy, llvm::Value *V,
+                      StringRef Name = "") {
+  // If we are casting a scalable i1 predicate vector to a fixed i8
+  // vector, first bitcast the source.
+  if (FromTy->getElementType()->isIntegerTy(1) &&
+      FromTy->getElementCount().isKnownMultipleOf(8) &&
+      ToTy->getElementType() == CGF.Builder.getInt8Ty()) {
+    FromTy = llvm::ScalableVectorType::get(
+        ToTy->getElementType(),
+        FromTy->getElementCount().getKnownMinValue() / 8);
+    V = CGF.Builder.CreateBitCast(V, FromTy);
+  }
+  if (FromTy->getElementType() == ToTy->getElementType()) {
+    llvm::Value *Zero = llvm::Constant::getNullValue(CGF.CGM.Int64Ty);
+
+    V->setName(Name + ".coerce");
+    V = CGF.Builder.CreateExtractVector(ToTy, V, Zero, "cast.fixed");
+    return {V, true};
+  }
+  return {V, false};
+}
+
 namespace {
 
 /// Encapsulates information about the way function arguments from
@@ -3196,26 +3220,14 @@ void CodeGenFunction::EmitFunctionProlog(const CGFunctionInfo &FI,
       // a VLAT at the function boundary and the types match up, use
       // llvm.vector.extract to convert back to the original VLST.
       if (auto *VecTyTo = dyn_cast<llvm::FixedVectorType>(ConvertType(Ty))) {
-        llvm::Value *Coerced = Fn->getArg(FirstIRArg);
+        llvm::Value *ArgVal = Fn->getArg(FirstIRArg);
         if (auto *VecTyFrom =
-                dyn_cast<llvm::ScalableVectorType>(Coerced->getType())) {
-          // If we are casting a scalable i1 predicate vector to a fixed i8
-          // vector, bitcast the source and use a vector extract.
-          if (VecTyFrom->getElementType()->isIntegerTy(1) &&
-              VecTyFrom->getElementCount().isKnownMultipleOf(8) &&
-              VecTyTo->getElementType() == Builder.getInt8Ty()) {
-            VecTyFrom = llvm::ScalableVectorType::get(
-                VecTyTo->getElementType(),
-                VecTyFrom->getElementCount().getKnownMinValue() / 8);
-            Coerced = Builder.CreateBitCast(Coerced, VecTyFrom);
-          }
-          if (VecTyFrom->getElementType() == VecTyTo->getElementType()) {
-            llvm::Value *Zero = llvm::Constant::getNullValue(CGM.Int64Ty);
-
+                dyn_cast<llvm::ScalableVectorType>(ArgVal->getType())) {
+          auto [Coerced, Extracted] = CoerceScalableToFixed(
+              *this, VecTyTo, VecTyFrom, ArgVal, Arg->getName());
+          if (Extracted) {
             assert(NumIRArgs == 1);
-            Coerced->setName(Arg->getName() + ".coerce");
-            ArgVals.push_back(ParamValue::forDirect(Builder.CreateExtractVector(
-                VecTyTo, Coerced, Zero, "cast.fixed")));
+            ArgVals.push_back(ParamValue::forDirect(Coerced));
             break;
           }
         }
@@ -3326,16 +3338,33 @@ void CodeGenFunction::EmitFunctionProlog(const CGFunctionInfo &FI,
       ArgVals.push_back(ParamValue::forIndirect(alloca));
 
       auto coercionType = ArgI.getCoerceAndExpandType();
+      auto unpaddedCoercionType = ArgI.getUnpaddedCoerceAndExpandType();
+      auto *unpaddedStruct = dyn_cast<llvm::StructType>(unpaddedCoercionType);
+
       alloca = alloca.withElementType(coercionType);
 
       unsigned argIndex = FirstIRArg;
+      unsigned unpaddedIndex = 0;
       for (unsigned i = 0, e = coercionType->getNumElements(); i != e; ++i) {
         llvm::Type *eltType = coercionType->getElementType(i);
         if (ABIArgInfo::isPaddingForCoerceAndExpand(eltType))
           continue;
 
         auto eltAddr = Builder.CreateStructGEP(alloca, i);
-        auto elt = Fn->getArg(argIndex++);
+        llvm::Value *elt = Fn->getArg(argIndex++);
+
+        auto paramType = unpaddedStruct
+                             ? unpaddedStruct->getElementType(unpaddedIndex++)
+                             : unpaddedCoercionType;
+
+        if (auto *VecTyTo = dyn_cast<llvm::FixedVectorType>(eltType)) {
+          if (auto *VecTyFrom = dyn_cast<llvm::ScalableVectorType>(paramType)) {
+            bool Extracted;
+            std::tie(elt, Extracted) = CoerceScalableToFixed(
+                *this, VecTyTo, VecTyFrom, elt, elt->getName());
+            assert(Extracted && "Unexpected scalable to fixed vector coercion");
+          }
+        }
         Builder.CreateStore(elt, eltAddr);
       }
       assert(argIndex == FirstIRArg + NumIRArgs);
@@ -3930,17 +3959,24 @@ void CodeGenFunction::EmitFunctionEpilog(const CGFunctionInfo &FI,
 
   case ABIArgInfo::CoerceAndExpand: {
     auto coercionType = RetAI.getCoerceAndExpandType();
+    auto unpaddedCoercionType = RetAI.getUnpaddedCoerceAndExpandType();
+    auto *unpaddedStruct = dyn_cast<llvm::StructType>(unpaddedCoercionType);
 
     // Load all of the coerced elements out into results.
     llvm::SmallVector<llvm::Value*, 4> results;
     Address addr = ReturnValue.withElementType(coercionType);
+    unsigned unpaddedIndex = 0;
     for (unsigned i = 0, e = coercionType->getNumElements(); i != e; ++i) {
       auto coercedEltType = coercionType->getElementType(i);
       if (ABIArgInfo::isPaddingForCoerceAndExpand(coercedEltType))
         continue;
 
       auto eltAddr = Builder.CreateStructGEP(addr, i);
-      auto elt = Builder.CreateLoad(eltAddr);
+      llvm::Value *elt = CreateCoercedLoad(
+          eltAddr,
+          unpaddedStruct ? unpaddedStruct->getElementType(unpaddedIndex++)
+                         : unpaddedCoercionType,
+          *this);
       results.push_back(elt);
     }
 
@@ -5468,6 +5504,8 @@ RValue CodeGenFunction::EmitCall(const CGFunctionInfo &CallInfo,
     case ABIArgInfo::CoerceAndExpand: {
       auto coercionType = ArgInfo.getCoerceAndExpandType();
       auto layout = CGM.getDataLayout().getStructLayout(coercionType);
+      auto unpaddedCoercionType = ArgInfo.getUnpaddedCoerceAndExpandType();
+      auto *unpaddedStruct = dyn_cast<llvm::StructType>(unpaddedCoercionType);
 
       llvm::Value *tempSize = nullptr;
       Address addr = Address::invalid();
@@ -5498,11 +5536,16 @@ RValue CodeGenFunction::EmitCall(const CGFunctionInfo &CallInfo,
       addr = addr.withElementType(coercionType);
 
       unsigned IRArgPos = FirstIRArg;
+      unsigned unpaddedIndex = 0;
       for (unsigned i = 0, e = coercionType->getNumElements(); i != e; ++i) {
         llvm::Type *eltType = coercionType->getElementType(i);
         if (ABIArgInfo::isPaddingForCoerceAndExpand(eltType)) continue;
         Address eltAddr = Builder.CreateStructGEP(addr, i);
-        llvm::Value *elt = Builder.CreateLoad(eltAddr);
+        llvm::Value *elt = CreateCoercedLoad(
+            eltAddr,
+            unpaddedStruct ? unpaddedStruct->getElementType(unpaddedIndex++)
+                           : unpaddedCoercionType,
+            *this);
         if (ArgHasMaybeUndefAttr)
           elt = Builder.CreateFreeze(elt);
         IRCallArgs[IRArgPos++] = elt;
diff --git a/clang/lib/CodeGen/Targets/AArch64.cpp b/clang/lib/CodeGen/Targets/AArch64.cpp
index ec617eec67192c..269b5b352bfd84 100644
--- a/clang/lib/CodeGen/Targets/AArch64.cpp
+++ b/clang/lib/CodeGen/Targets/AArch64.cpp
@@ -36,8 +36,15 @@ class AArch64ABIInfo : public ABIInfo {
 
   ABIArgInfo classifyReturnType(QualType RetTy, bool IsVariadic) const;
   ABIArgInfo classifyArgumentType(QualType RetTy, bool IsVariadic,
-                                  unsigned CallingConvention) const;
-  ABIArgInfo coerceIllegalVector(QualType Ty) const;
+                                  unsigned CallingConvention, unsigned &NSRN,
+                                  unsigned &NPRN) const;
+  llvm::Type *convertFixedToScalableVectorType(const VectorType *VT) const;
+  ABIArgInfo coerceIllegalVector(QualType Ty, unsigned &NSRN,
+                                 unsigned &NPRN) const;
+  ABIArgInfo coerceAndExpandPureScalableAggregate(
+      QualType Ty, unsigned NVec, unsigned NPred,
+      const SmallVectorImpl<llvm::Type *> &UnpaddedCoerceToSeq, unsigned &NSRN,
+      unsigned &NPRN) const;
   bool isHomogeneousAggregateBaseType(QualType Ty) const override;
   bool isHomogeneousAggregateSmallEnough(const Type *Ty,
                                          uint64_t Members) const override;
@@ -45,14 +52,21 @@ class AArch64ABIInfo : public ABIInfo {
 
   bool isIllegalVectorType(QualType Ty) const;
 
+  bool isPureScalableType(QualType Ty, unsigned &NV, unsigned &NP,
+                          SmallVectorImpl<llvm::Type *> &CoerceToSeq) const;
+
+  void flattenType(llvm::Type *Ty,
+                   SmallVectorImpl<llvm::Type *> &Flattened) const;
+
   void computeInfo(CGFunctionInfo &FI) const override {
     if (!::classifyReturnType(getCXXABI(), FI, *this))
       FI.getReturnInfo() =
           classifyReturnType(FI.getReturnType(), FI.isVariadic());
 
+    unsigned NSRN = 0, NPRN = 0;
     for (auto &it : FI.arguments())
       it.info = classifyArgumentType(it.type, FI.isVariadic(),
-                                     FI.getCallingConvention());
+                                     FI.getCallingConvention(), NSRN, NPRN);
   }
 
   RValue EmitDarwinVAArg(Address VAListAddr, QualType Ty, CodeGenFunction &CGF,
@@ -201,65 +215,83 @@ void WindowsAArch64TargetCodeGenInfo::setTargetAttributes(
 }
 }
 
-ABIArgInfo AArch64ABIInfo::coerceIllegalVector(QualType Ty) const {
-  assert(Ty->isVectorType() && "expected vector type!");
+llvm::Type *
+AArch64ABIInfo::convertFixedToScalableVectorType(const VectorType *VT) const {
+  assert(VT->getElementType()->isBuiltinType() && "expected builtin type!");
 
-  const auto *VT = Ty->castAs<VectorType>();
   if (VT->getVectorKind() == VectorKind::SveFixedLengthPredicate) {
-    assert(VT->getElementType()->isBuiltinType() && "expected builtin type!");
     assert(VT->getElementType()->castAs<BuiltinType>()->getKind() ==
                BuiltinType::UChar &&
            "unexpected builtin type for SVE predicate!");
-    return ABIArgInfo::getDirect(llvm::ScalableVectorType::get(
-        llvm::Type::getInt1Ty(getVMContext()), 16));
+    return llvm::ScalableVectorType::get(llvm::Type::getInt1Ty(getVMContext()),
+                                         16);
   }
 
   if (VT->getVectorKind() == VectorKind::SveFixedLengthData) {
-    assert(VT->getElementType()->isBuiltinType() && "expected builtin type!");
-
     const auto *BT = VT->getElementType()->castAs<BuiltinType>();
-    llvm::ScalableVectorType *ResType = nullptr;
     switch (BT->getKind()) {
     default:
       llvm_unreachable("unexpected builtin type for SVE vector!");
+
     case BuiltinType::SChar:
     case BuiltinType::UChar:
-      ResType = llvm::ScalableVectorType::get(
+      return llvm::ScalableVectorType::get(
           llvm::Type::getInt8Ty(getVMContext()), 16);
-      break;
+
     case BuiltinType::Short:
     case BuiltinType::UShort:
-      ResType = llvm::ScalableVectorType::get(
+      return llvm::ScalableVectorType::get(
           llvm::Type::getInt16Ty(getVMContext()), 8);
-      break;
+
     case BuiltinType::Int:
     case BuiltinType::UInt:
-      ResType = llvm::ScalableVectorType::get(
+      return llvm::ScalableVectorType::get(
           llvm::Type::getInt32Ty(getVMContext()), 4);
-      break;
+
     case BuiltinType::Long:
     case BuiltinType::ULong:
-      ResType = llvm::ScalableVectorType::get(
+      return llvm::ScalableVectorType::get(
           llvm::Type::getInt64Ty(getVMContext()), 2);
-      break;
+
     case BuiltinType::Half:
-      ResType = llvm::ScalableVectorType::get(
+      return llvm::ScalableVectorType::get(
           llvm::Type::getHalfTy(getVMContext()), 8);
-      break;
+
     case BuiltinType::Float:
-      ResType = llvm::ScalableVectorType::get(
+      return llvm::ScalableVectorType::get(
           llvm::Type::getFloatTy(getVMContext()), 4);
-      break;
+
     case BuiltinType::Double:
-      ResType = llvm::ScalableVectorType::get(
+      return llvm::ScalableVectorType::get(
           llvm::Type::getDoubleTy(getVMContext()), 2);
-      break;
+
     case BuiltinType::BFloat16:
-      ResType = llvm::ScalableVectorType::get(
+      return llvm::ScalableVectorType::get(
           llvm::Type::getBFloatTy(getVMContext()), 8);
-      break;
     }
-    return ABIArgInfo::getDirect(ResType);
+  }
+
+  llvm_unreachable("expected fixed-length SVE vector");
+}
+
+ABIArgInfo AArch64ABIInfo::coerceIllegalVector(QualType Ty, unsigned &NSRN,
+                                               unsigned &NPRN) const {
+  assert(Ty->isVectorType() && "expected vector type!");
+
+  const auto *VT = Ty->castAs<VectorType>();
+  if (VT->getVectorKind() == VectorKind::SveFixedLengthPredicate) {
+    assert(VT->getElementType()->isBuiltinType() && "expected builtin type!");
+    assert(VT->getElementType()->castAs<BuiltinType>()->getKind() ==
+               BuiltinType::UChar &&
+           "unexpected builtin type for SVE predicate!");
+    NPRN = std::min(NPRN + 1, 4u);
+    return ABIArgInfo::getDirect(llvm::ScalableVectorType::get(
+        llvm::Type::getInt1Ty(getVMContext()), 16));
+  }
+
+  if (VT->getVectorKind() == VectorKind::SveFixedLengthData) {
+    NSRN = std::min(NSRN + 1, 8u);
+    return ABIArgInfo::getDirect(convertFixedToScalableVectorType(VT));
   }
 
   uint64_t Size = getContext().getTypeSize(Ty);
@@ -273,26 +305,53 @@ ABIArgInfo AArch64ABIInfo::coerceIllegalVector(QualType Ty) const {
     return ABIArgInfo::getDirect(ResType);
   }
   if (Size == 64) {
+    NSRN = std::min(NSRN + 1, 8u);
     auto *ResType =
         llvm::FixedVectorType::get(llvm::Type::getInt32Ty(getVMContext()), 2);
     return ABIArgInfo::getDirect(ResType);
   }
   if (Size == 128) {
+    NSRN = std::min(NSRN + 1, 8u);
     auto *ResType =
         llvm::FixedVectorType::get(llvm::Type::getInt32Ty(getVMContext()), 4);
     return ABIArgInfo::getDirect(ResType);
   }
+
   return getNaturalAlignIndirect(Ty, /*ByVal=*/false);
 }
 
-ABIArgInfo
-AArch64ABIInfo::classifyArgumentType(QualType Ty, bool IsVariadic,
-                                     unsigned CallingConvention) const {
+ABIArgInfo AArch64ABIInfo::coerceAndExpandPureScalableAggregate(
+    QualType Ty, unsigned NVec, unsigned NPred,
+    const SmallVectorImpl<llvm::Type *> &UnpaddedCoerceToSeq, unsigned &NSRN,
+    unsigned &NPRN) const {
+  if (NSRN + NVec > 8 || NPRN + NPred > 4)
+    return getNaturalAlignIndirect(Ty, /*ByVal=*/false);
+  NSRN += NVec;
+  NPRN += NPred;
+
+  llvm::Type *UnpaddedCoerceToType =
+      UnpaddedCoerceToSeq.size() == 1
+          ? UnpaddedCoerceToSeq[0]
+          : llvm::StructType::get(CGT.getLLVMContext(), UnpaddedCoerceToSeq,
+                                  true);
+
+  SmallVector<llvm::Type *> CoerceToSeq;
+  flattenType(CGT.ConvertType(Ty), CoerceToSeq);
+  auto *CoerceToType =
+      llvm::StructType::get(CGT.getLLVMContext(), CoerceToSeq, false);
+
+  return ABIArgInfo::getCoerceAndExpand(CoerceToType, UnpaddedCoerceToType);
+}
+
+ABIArgInfo AArch64ABIInfo::classifyArgumentType(QualType Ty, bool IsVariadic,
+                                                unsigned CallingConvention,
+                                                unsigned &NSRN,
+                                                unsigned &NPRN) const {
   Ty = useFirstFieldIfTransparentUnion(Ty);
 
   // Handle illegal vector types here.
   if (isIllegalVectorType(Ty))
-    return coerceIllegalVector(Ty);
+    return coerceIllegalVector(Ty, NSRN, NPRN);
 
   if (!isAggregateTypeForABI(Ty)) {
     // Treat an enum type as its underlying type.
@@ -303,6 +362,20 @@ AArch64ABIInfo::classifyArgumentType(QualType Ty, bool IsVariadic,
       if (EIT->getNumBits() > 128)
         return getNaturalAlignIndirect(Ty, false);
 
+    if (const BuiltinType *BT = Ty->getAs<BuiltinType>()) {
+      if (BT->isSVEBool() || BT->isSVECount())
+        NPRN = std::min(NPRN + 1, 4u);
+      else if (BT->getKind() == BuiltinType::SveBoolx2)
+        NPRN = std::min(NPRN + 2, 4u);
+      else if (BT->getKind() == BuiltinType::SveBoolx4)
+        NPRN = std::min(NPRN + 4, 4u);
+      else if (BT->isFloatingPoint() || BT->isVectorType())
+        NSRN = std::min(NSRN + 1, 8u);
+      else if (BT->isSVESizelessBuiltinType())
+        NSRN = std::min(
+            NSRN + getContext().getBuiltinVectorTypeInfo(BT).NumVectors, 8u);
+    }
+
     return (isPromotableIntegerTypeForABI(Ty) && isDarwinPCS()
                 ? ABIArgInfo::getExtend(Ty, CGT.ConvertType(Ty))
                 : ABIArgInfo::getDirect());
@@ -339,6 +412,7 @@ AArch64ABIInfo::classifyArgumentType(QualType Ty, bool IsVariadic,
   // In variadic functions on Windows, all composite types are treated alike,
   // no special handling of HFAs/HVAs.
   if (!IsWinVariadic && isHomogeneousAggregate(Ty, Base, Members)) {
+    NSRN = std::min(NSRN + Members, uint64_t(8));
     if (Kind != AArch64ABIKind::AAPCS)
       return ABIArgInfo::getDirect(
           llvm::ArrayType::get(CGT.ConvertType(QualType(Base, 0)), Members));
@@ -353,6 +427,17 @@ AArch64ABIInfo::classifyArgumentType(QualType Ty, bool IsVariadic,
         nullptr, true, Align);
   }
 
+  // In AAPCS named arguments of a Pure Scalable Type are passed expanded in
+  // registers, or indirectly if there are not enough registers.
+  if (Kind == AArch64ABIKind::AAPCS && !IsVariadic) {
+    unsigned NVec = 0, NPred = 0;
+    SmallVector<llvm::Type *> UnpaddedCoerceToSeq;
+    if (isPureScalableType(Ty, NVec, NPred, UnpaddedCoerceToSeq) &&
+        (NVec + NPred) > 0)
+      return coerceAndExpandPureScalableAggregate(
+          Ty, NVec, NPred, UnpaddedCoerceToSeq, NSRN, NPRN);
+  }
+
   // Aggregates <= 16 bytes are passed directly in registers or on the stack.
   if (Size <= 128) {
     // On RenderScript, coerce Aggregates <= 16 bytes to an integer array of
@@ -389,8 +474,10 @@ ABIArgInfo AArch64ABIInfo::classifyReturnType(QualType RetTy,
 
   if (const auto *VT = RetTy->getAs<VectorType>()) {
     if (VT->getVectorKind() == VectorKind::SveFixedLengthData ||
-        VT->getVectorKind() == VectorKind::SveFixedLengthPredicate)
-      return coerceIllegalVector(RetTy);
+        VT->getVectorKind() == VectorKind::SveFixedLengthPredicate) {
+      unsigned NSRN = 0, NPRN = 0;
+      return coerceIllegalVector(RetTy, NSRN, NPRN);
+    }
   }
 
   // Large vector types should be returned via memory.
@@ -423,6 +510,19 @@ ABIArgInfo AArch64ABIInfo::classifyReturnType(QualType RetTy,
     // Homogeneous Floating-point Aggregates (HFAs) are returned directly.
     return ABIArgInfo::getDirect();
 
+  // In AAPCS return values of a Pure Scalable type are treated as a first named
+  // argument and passed expanded in registers, or indirectly if there are not
+  // enough registers.
+  if (Kind == AArch64ABIKind::AAPCS) {
+    unsigned NSRN = 0, NPRN = 0;
+    unsigned NVec = 0, NPred = 0;
+    SmallVector<llvm::Type *> UnpaddedCoerceToSeq;
+    if (isPureScalableType(RetTy, NVec, NPred, UnpaddedCoerceToSeq) &&
+        (NVec + NPred) > 0)
+      return coerceAndExpandPureScalableAggregate(
+          RetTy, NVec, NPred, UnpaddedCoerceToSeq, NSRN, NPRN);
+  }
+
   // Aggregates <= 16 bytes are returned directly in registers or on the stack.
   if (Size <= 128) {
     // On RenderScript, coerce Aggregates <= 16 bytes to an integer array of
@@ -511,6 +611,11 @@ bool AArch64ABIInfo::isHomogeneousAggregateBaseType(QualType Ty) const {
     if (BT->isFloatingPoint())
       return true;
   } else if (const VectorType *VT = Ty->getAs<VectorType>()) {
+    if (auto Kind = VT->getVectorKind();
+        Kind == VectorKind::SveFixedLengthData ||
+        Kind == VectorKind::SveFixedLengthPredicate)
+      return false;
+
     unsigned VecSize = getContext().getTypeSize(VT);
     if (VecSize == 64 || VecSize == 128)
       return true;
@@ -533,11 +638,158 @@ bool AArch64ABIInfo::isZeroLengthBitfieldPermittedInHomogeneousAggregate()
   return true;
 }
 
+// Check if a type is a Pure Scalable Type as defined by AAPCS64. Return the
+// number of data vectors and the number of predicate vectors in the type, into
+// `NVec` and `NPred`, respectively. Upon return `CoerceToSeq` contains an
+// expanded sequence of LLVM IR types, one element for each non-composite
+// member. For practical purposes, limit the length of `CoerceToSeq` to about
+// 12, the maximum size that could possibly fit in registers.
+bool AArch64ABIInfo::isPureScalableType(
+    QualType Ty, unsigned &NVec, unsigned &NPred,
+    SmallVectorImpl<llvm::Type *> &CoerceToSeq) const {
+  if (const ConstantArrayType *AT = getContext().getAsConstantArrayType(Ty)) {
+    uint64_t NElt = AT->getZExtSize();
+    if (NElt == 0)
+      return false;
+
+    unsigned NV = 0, NP = 0;
+    SmallVector<llvm::Type *> EltCoerceToSeq;
+    if (!isPureScalableType(AT->getElementType(), NV, NP, EltCoerceToSeq))
+      return false;
+
+    for (uint64_t I = 0; CoerceToSeq.size() < 12 && I < NElt; ++I)
+      llvm::copy(EltCoerceToSeq, std::back_inserter(CoerceToSeq));
+
+    NVec += NElt * NV;
+    NPred += NElt * NP;
+    return true;
+  }
+
+  if (const RecordType *RT = Ty->getAs<RecordType>()) {
+    // If the record cannot be passed in registers, then it's not a PST.
+    if (CGCXXABI::RecordArgABI RAA = getRecordArgABI(RT, getCXXABI());
+        RAA != CGCXXABI::RAA_Default)
+      return false;
+
+    // Pure scalable types are never unions and never contain unions.
+    const RecordDecl *RD = RT->getDecl();
+    if (RD->isUnion())
+      return false;
+
+    // If this is a C++ record, check the bases.
+    if (const CXXRecordDecl *CXXRD = dyn_cast<CXXRecordDecl>(RD)) {
+      for (const auto &I : CXXRD->bases()) {
+        if (isEmptyRecord(getContext(), I.getType(), true))
+          continue;
+        if (!isPureScalableType(I.getType(), NVec, NPred, CoerceToSeq))
+          return false;
+      }
+    }
+
+    // Check members.
+    for (const auto *FD : RD->fields()) {
+      QualType FT = FD->getType();
+      if (isEmptyRecord(getContext(), FT, true))
+        continue;
+      if (!isPureScalableType(FT, NVec, NPred, CoerceToSeq))
+        return false;
+    }
+
+    return true;
+  }
+
+  const auto *VT = Ty->getAs<VectorType>();
+  if (!VT)
+    return false;
+
+  if (VT->getVectorKind() == VectorKind::SveFixedLengthPredicate) {
+    ++NPred;
+    if (CoerceToSeq.size() < 12)
+      CoerceToSeq.push_back(convertFixedToScalableVectorType(VT));
+    return true;
+  }
+
+  if (VT->getVectorKind() == VectorKind::SveFixedLengthData) {
+    ++NVec;
+    if (CoerceToSeq.size() < 12)
+      CoerceToSeq.push_back(convertFixedToScalableVectorType(VT));
+    return true;
+  }
+
+  if (!VT->isBuiltinType())
+    return false;
+
+  switch (cast<BuiltinType>(VT)->getKind()) {
+#define SVE_VECTOR_TYPE(Name, MangledName, Id, SingletonId)                    \
+  case BuiltinType::Id:                                                        \
+    ++NVec;                                                                    \
+    break;
+#define SVE_PREDICATE_TYPE(Name, MangledName, Id, SingletonId)                 \
+  case BuiltinType::Id:                                                        \
+    ++NPred;                                                                   \
+    break;
+#define SVE_TYPE(Name, Id, SingletonId)
+#include "clang/Basic/AArch64SVEACLETypes.def"
+  default:
+    return false;
+  }
+
+  ASTContext::BuiltinVectorTypeInfo Info =
+      getContext().getBuiltinVectorTypeInfo(cast<BuiltinType>(Ty));
+  assert(Info.NumVectors > 0 && Info.NumVectors <= 4 &&
+         "Expected 1, 2, 3 or 4 vectors!");
+  auto VTy = llvm::ScalableVectorType::get(CGT.ConvertType(Info.ElementType),
+                                           Info.EC.getKnownMinValue());
+
+  if (CoerceToSeq.size() < 12)
+    std::fill_n(std::back_inserter(CoerceToSeq), Info.NumVectors, VTy);
+
+  return true;
+}
+
+// Expand an LLVM IR type into a sequence with an element for each non-struct,
+// non-array member of the type, with the exception of the padding types, which
+// are retained.
+void AArch64ABIInfo::flattenType(
+    llvm::Type *Ty, SmallVectorImpl<llvm::Type *> &Flattened) const {
+
+  if (ABIArgInfo::isPaddingForCoerceAndExpand(Ty)) {
+    Flattened.push_back(Ty);
+    return;
+  }
+
+  if (const auto *AT = dyn_cast<llvm::ArrayType>(Ty)) {
+    uint64_t NElt = AT->getNumElements();
+    if (NElt == 0)
+      return;
+
+    SmallVector<llvm::Type *> EltFlattened;
+    flattenType(AT->getElementType(), EltFlattened);
+
+    for (uint64_t I = 0; I < NElt; ++I)
+      llvm::copy(EltFlattened, std::back_inserter(Flattened));
+    return;
+  }
+
+  if (const auto *ST = dyn_cast<llvm::StructType>(Ty)) {
+    for (auto *ET : ST->elements())
+      flattenType(ET, Flattened);
+    return;
+  }
+
+  Flattened.push_back(Ty);
+}
+
 RValue AArch64ABIInfo::EmitAAPCSVAArg(Address VAListAddr, QualType Ty,
                                       CodeGenFunction &CGF, AArch64ABIKind Kind,
                                       AggValueSlot Slot) const {
-  ABIArgInfo AI = classifyArgumentType(Ty, /*IsVariadic=*/true,
-                                       CGF.CurFnInfo->getCallingConvention());
+  // These numbers are not used for variadic arguments, hence it doesn't matter
+  // that they don't retain their values across multiple calls to
+  // `classifyArgumentType` here.
+  unsigned NSRN = 0, NPRN = 0;
+  ABIArgInfo AI =
+      classifyArgumentType(Ty, /*IsVariadic=*/true,
+                           CGF.CurFnInfo->getCallingConvention(), NSRN, NPRN);
   // Empty records are ignored for parameter passing purposes.
   if (AI.isIgnore())
     return Slot.asRValue();
diff --git a/clang/test/CodeGen/aarch64-pure-scalable-args.c b/clang/test/CodeGen/aarch64-pure-scalable-args.c
new file mode 100644
index 00000000000000..e5e50e85d9b93a
--- /dev/null
+++ b/clang/test/CodeGen/aarch64-pure-scalable-args.c
@@ -0,0 +1,314 @@
+// RUN: %clang_cc1 -O3 -triple aarch64                                  -target-feature +sve -target-feature +sve2p1 -mvscale-min=1 -mvscale-max=1 -emit-llvm -o - %s | FileCheck %s --check-prefixes=CHECK,CHECK-AAPCS
+// RUN: %clang_cc1 -O3 -triple arm64-apple-ios7.0 -target-abi darwinpcs -target-feature +sve -target-feature +sve2p1 -mvscale-min=1 -mvscale-max=1 -emit-llvm -o - %s | FileCheck %s --check-prefixes=CHECK,CHECK-DARWIN
+// RUN: %clang_cc1 -O3 -triple aarch64-linux-gnu                        -target-feature +sve -target-feature +sve2p1 -mvscale-min=1 -mvscale-max=1 -emit-llvm -o - %s | FileCheck %s --check-prefixes=CHECK,CHECK-AAPCS
+
+// REQUIRES: aarch64-registered-target
+
+#include <arm_sve.h>
+
+typedef svfloat32_t fvec32 __attribute__((arm_sve_vector_bits(128)));
+typedef svfloat64_t fvec64 __attribute__((arm_sve_vector_bits(128)));
+typedef svbool_t bvec __attribute__((arm_sve_vector_bits(128)));
+typedef svmfloat8_t mfvec8 __attribute__((arm_sve_vector_bits(128)));
+
+typedef struct {
+    float f[4];
+} HFA; // Homogeneous Floating-point Aggregate of four floats.
+
+// Pure Scalable Type, needs 4 Z-regs (x, y[0], y[1], z), 2 P-regs (a, b)
+typedef struct {
+     bvec a;
+     fvec64 x;
+     fvec32 y[2];
+     mfvec8 z;
+     bvec b;
+} PST;
+
+// Pure Scalable Type with a single member, needs 1 Z-reg
+typedef struct {
+    fvec32 x;
+} SmallPST;
+
+// Big PST, needs 9 Z-regs and 2 P-regs, so does not fit in registers.
+typedef struct {
+    struct {
+        bvec a;
+        fvec32 x[4];
+    } u[2];
+    fvec64 v;
+} BigPST;
+
+// A small (16-byte) aggregate that is not a PST
+typedef struct  {
+    char data[16];
+} SmallAgg;
+
+// CHECK: %struct.PST = type { <2 x i8>, <2 x double>, [2 x <4 x float>], <16 x i8>, <2 x i8> }
+
+// Test argument passing of Pure Scalable Types by examining the generated
+// LLVM IR function declarations. A PST argument in C/C++ should map to:
+//   a) a `ptr` argument, if passed indirectly through memory
+//   b) a series of scalable vector arguments, if passed via registers
+
+// Simple argument passing: PST is the only argument, expanded into registers.
+//   a    -> p0
+//   b    -> p1
+//   x    -> q0
+//   y[0] -> q1
+//   y[1] -> q2
+//   z    -> q3
+void test_argpass_simple(PST *p) {
+    void argpass_simple_callee(PST);
+    argpass_simple_callee(*p);
+}
+// CHECK-AAPCS:      define dso_local void @test_argpass_simple(ptr nocapture noundef readonly %p)
+// CHECK-AAPCS-NEXT: entry:
+// CHECK-AAPCS-NEXT: %0 = load <2 x i8>, ptr %p, align 16
+// CHECK-AAPCS-NEXT: %cast.scalable = tail call <vscale x 2 x i8> @llvm.vector.insert.nxv2i8.v2i8(<vscale x 2 x i8> undef, <2 x i8> %0, i64 0)
+// CHECK-AAPCS-NEXT: %1 = bitcast <vscale x 2 x i8> %cast.scalable to <vscale x 16 x i1>
+// CHECK-AAPCS-NEXT: %2 = getelementptr inbounds nuw i8, ptr %p, i64 16
+// CHECK-AAPCS-NEXT: %3 = load <2 x double>, ptr %2, align 16
+// CHECK-AAPCS-NEXT: %cast.scalable1 = tail call <vscale x 2 x double> @llvm.vector.insert.nxv2f64.v2f64(<vscale x 2 x double> undef, <2 x double> %3, i64 0)
+// CHECK-AAPCS-NEXT: %4 = getelementptr inbounds nuw i8, ptr %p, i64 32
+// CHECK-AAPCS-NEXT: %5 = load <4 x float>, ptr %4, align 16
+// CHECK-AAPCS-NEXT: %cast.scalable2 = tail call <vscale x 4 x float> @llvm.vector.insert.nxv4f32.v4f32(<vscale x 4 x float> undef, <4 x float> %5, i64 0)
+// CHECK-AAPCS-NEXT: %6 = getelementptr inbounds nuw i8, ptr %p, i64 48
+// CHECK-AAPCS-NEXT: %7 = load <4 x float>, ptr %6, align 16
+// CHECK-AAPCS-NEXT: %cast.scalable3 = tail call <vscale x 4 x float> @llvm.vector.insert.nxv4f32.v4f32(<vscale x 4 x float> undef, <4 x float> %7, i64 0)
+// CHECK-AAPCS-NEXT: %8 = getelementptr inbounds nuw i8, ptr %p, i64 64
+// CHECK-AAPCS-NEXT: %9 = load <16 x i8>, ptr %8, align 16
+// CHECK-AAPCS-NEXT: %cast.scalable4 = tail call <vscale x 16 x i8> @llvm.vector.insert.nxv16i8.v16i8(<vscale x 16 x i8> undef, <16 x i8> %9, i64 0)
+// CHECK-AAPCS-NEXT: %10 = getelementptr inbounds nuw i8, ptr %p, i64 80
+// CHECK-AAPCS-NEXT: %11 = load <2 x i8>, ptr %10, align 16
+// CHECK-AAPCS-NEXT: %cast.scalable5 = tail call <vscale x 2 x i8> @llvm.vector.insert.nxv2i8.v2i8(<vscale x 2 x i8> undef, <2 x i8> %11, i64 0)
+// CHECK-AAPCS-NEXT: %12 = bitcast <vscale x 2 x i8> %cast.scalable5 to <vscale x 16 x i1>
+// CHECK-AAPCS-NEXT: tail call void @argpass_simple_callee(<vscale x 16 x i1> %1, <vscale x 2 x double> %cast.scalable1, <vscale x 4 x float> %cast.scalable2, <vscale x 4 x float> %cast.scalable3, <vscale x 16 x i8> %cast.scalable4, <vscale x 16 x i1> %12)
+// CHECK-AAPCS-NEXT: ret void
+
+// CHECK-AAPCS:  declare void @argpass_simple_callee(<vscale x 16 x i1>, <vscale x 2 x double>, <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 16 x i8>, <vscale x 16 x i1>)
+// CHECK-DARWIN: declare void @argpass_simple_callee(ptr noundef)
+
+// Boundary case: the PST takes the last available Z-reg (z7), so is expanded.
+//   0.0  -> d0-d3
+//   a    -> p0
+//   b    -> p1
+//   x    -> q4
+//   y[0] -> q5
+//   y[1] -> q6
+//   z    -> q7
+void test_argpass_last_z(PST *p) {
+    void argpass_last_z_callee(double, double, double, double, PST);
+    argpass_last_z_callee(.0, .0, .0, .0, *p);
+}
+// CHECK-AAPCS:  declare void @argpass_last_z_callee(double noundef, double noundef, double noundef, double noundef, <vscale x 16 x i1>, <vscale x 2 x double>, <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 16 x i8>, <vscale x 16 x i1>)
+// CHECK-DARWIN: declare void @argpass_last_z_callee(double noundef, double noundef, double noundef, double noundef, ptr noundef)
+
+
+// Like the above, but using an SVE tuple argument to occupy z0-z3.
+//   tuple -> z0.d-z3.d
+//   a    -> p0
+//   b    -> p1
+//   x    -> q4
+//   y[0] -> q5
+//   y[1] -> q6
+//   z    -> q7
+void test_argpass_last_z_tuple(PST *p, svfloat64x4_t x) {
+  void argpass_last_z_tuple_callee(svfloat64x4_t, PST);
+  argpass_last_z_tuple_callee(x, *p);
+}
+// CHECK-AAPCS:  declare void @argpass_last_z_tuple_callee(<vscale x 2 x double>, <vscale x 2 x double>, <vscale x 2 x double>, <vscale x 2 x double>, <vscale x 16 x i1>, <vscale x 2 x double>, <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 16 x i8>, <vscale x 16 x i1>)
+// CHECK-DARWIN: declare void @argpass_last_z_tuple_callee(<vscale x 2 x double>, <vscale x 2 x double>, <vscale x 2 x double>, <vscale x 2 x double>, ptr noundef)
+
+
+// Boundary case: the PST takes the last available P-reg (p3), so is expanded.
+//   false -> p0-p1
+//   a     -> p2
+//   b     -> p3
+//   x     -> q0
+//   y[0]  -> q1
+//   y[1]  -> q2
+//   z     -> q3
+void test_argpass_last_p(PST *p) {
+    void argpass_last_p_callee(svbool_t, svcount_t, PST);
+    argpass_last_p_callee(svpfalse(), svpfalse_c(), *p);
+}
+// CHECK-AAPCS:  declare void @argpass_last_p_callee(<vscale x 16 x i1>, target("aarch64.svcount"), <vscale x 16 x i1>, <vscale x 2 x double>, <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 16 x i8>, <vscale x 16 x i1>)
+// CHECK-DARWIN: declare void @argpass_last_p_callee(<vscale x 16 x i1>, target("aarch64.svcount"), ptr noundef)
+
+
+// Not enough Z-regs (only z5-z7 left, 4 needed): the PST is copied to memory
+// and passed by pointer; Z-regs and P-regs stay available for later arguments.
+//   u     -> z0
+//   0.0   -> d1-d4
+//   1     -> w0
+//   *p    -> memory, address -> x1
+//   2     -> w2
+//   3.0   -> d5
+//   true  -> p0
+void test_argpass_no_z(PST *p, double dummy, svmfloat8_t u) {
+    void argpass_no_z_callee(svmfloat8_t, double, double, double, double, int, PST, int, double, svbool_t);
+    argpass_no_z_callee(u, .0, .0, .0, .0, 1, *p, 2, 3.0, svptrue_b64());
+}
+// CHECK: declare void @argpass_no_z_callee(<vscale x 16 x i8>, double noundef, double noundef, double noundef, double noundef, i32 noundef, ptr noundef, i32 noundef, double noundef, <vscale x 16 x i1>)
+
+
+// Like the above, but using an SVE tuple argument to occupy z0-z3.
+//   x     -> z0.d-z3.d
+//   0.0   -> d4
+//   1     -> w0
+//   *p    -> memory, address -> x1
+//   2     -> w2
+//   3.0   -> d5
+//   true  -> p0
+void test_argpass_no_z_tuple(PST *p, float dummy, svfloat64x4_t x) {
+  void argpass_no_z_tuple_callee(svfloat64x4_t, double, int, PST, int,
+                                 double, svbool_t);
+  argpass_no_z_tuple_callee(x, .0, 1, *p, 2, 3.0, svptrue_b64());
+}
+// CHECK: declare void @argpass_no_z_tuple_callee(<vscale x 2 x double>, <vscale x 2 x double>, <vscale x 2 x double>, <vscale x 2 x double>, double noundef, i32 noundef, ptr noundef, i32 noundef, double noundef, <vscale x 16 x i1>)
+
+
+// Not enough Z-regs (consumed by 0.0 and an HFA), PST passed indirectly
+//   0.0  -> d0
+//   *h   -> s1-s4
+//   1    -> w0
+//   *p   -> memory, address -> x1
+//           (only z5-z7 left, but the PST needs 4 Z-regs)
+//   2    -> w2
+//   true -> p0
+void test_argpass_no_z_hfa(HFA *h, PST *p) {
+    void argpass_no_z_hfa_callee(double, HFA, int, PST, int, svbool_t);
+    argpass_no_z_hfa_callee(.0, *h, 1, *p, 2, svptrue_b64());
+}
+// CHECK-AAPCS:  declare void @argpass_no_z_hfa_callee(double noundef, [4 x float] alignstack(8), i32 noundef, ptr noundef, i32 noundef, <vscale x 16 x i1>)
+// CHECK-DARWIN: declare void @argpass_no_z_hfa_callee(double noundef, [4 x float], i32 noundef, ptr noundef, i32 noundef, <vscale x 16 x i1>)
+
+
+// Not enough P-regs (only p3 left), PST passed indirectly, Z-regs and P-regs still available.
+//   true -> p0-p2
+//   1    -> w0
+//   *p   -> memory, address -> x1
+//   2    -> w2
+//   3.0  -> d0
+//   true -> p3
+void test_argpass_no_p(PST *p) {
+    void argpass_no_p_callee(svbool_t, svbool_t, svbool_t, int, PST, int, double, svbool_t);
+    argpass_no_p_callee(svptrue_b8(), svptrue_b16(), svptrue_b32(), 1, *p, 2, 3.0, svptrue_b64());
+}
+// CHECK: declare void @argpass_no_p_callee(<vscale x 16 x i1>, <vscale x 16 x i1>, <vscale x 16 x i1>, i32 noundef, ptr noundef, i32 noundef, double noundef, <vscale x 16 x i1>)
+
+
+// Like above, but using an svboolx2_t tuple to occupy p0-p1.
+// The last P-reg (p3) is still available for later arguments.
+//   v    -> p0-p1
+//   u    -> p2
+//   1    -> w0
+//   *p   -> memory, address -> x1
+//   2    -> w2
+//   3.0  -> d0
+//   true -> p3
+void test_argpass_no_p_tuple(PST *p, svbool_t u, svboolx2_t v) {
+  void argpass_no_p_tuple_callee(svboolx2_t, svbool_t, int, PST, int, double,
+                                 svbool_t);
+  argpass_no_p_tuple_callee(v, u, 1, *p, 2, 3.0, svptrue_b64());
+}
+// CHECK: declare void @argpass_no_p_tuple_callee(<vscale x 16 x i1>, <vscale x 16 x i1>, <vscale x 16 x i1>, i32 noundef, ptr noundef, i32 noundef, double noundef, <vscale x 16 x i1>)
+
+
+// HFAs go back-to-back to memory; afterwards no Z-regs are available, so the PST is passed indirectly.
+//   0.0   -> d0-d4
+//   *h    -> memory
+//   *p    -> memory, address -> x0
+//   *h    -> memory
+//   false -> p0
+void test_after_hfa(HFA *h, PST *p) {
+    void after_hfa_callee(double, double, double, double, double, HFA, PST, HFA, svbool_t);
+    after_hfa_callee(.0, .0, .0, .0, .0, *h, *p, *h, svpfalse());
+}
+// CHECK-AAPCS:  declare void @after_hfa_callee(double noundef, double noundef, double noundef, double noundef, double noundef, [4 x float] alignstack(8), ptr noundef, [4 x float] alignstack(8), <vscale x 16 x i1>)
+// CHECK-DARWIN: declare void @after_hfa_callee(double noundef, double noundef, double noundef, double noundef, double noundef, [4 x float], ptr noundef, [4 x float], <vscale x 16 x i1>)
+
+// Small PST, not enough Z-regs (d0-d7 taken), passed indirectly, unlike other
+// small aggregates (e.g. SmallAgg), which are passed by value.
+//   *s  -> x0-x1
+//   0.0 -> d0-d7
+//   *p  -> memory, address -> x2
+//   1.0 -> memory
+//   2.0 -> memory (next to the above)
+void test_small_pst(SmallPST *p, SmallAgg *s) {
+    void small_pst_callee(SmallAgg, double, double, double, double, double, double, double, double, double, SmallPST, double);
+    small_pst_callee(*s, .0, .0, .0, .0, .0, .0, .0, .0, 1.0, *p, 2.0);
+}
+// CHECK-AAPCS:  declare void @small_pst_callee([2 x i64], double noundef, double noundef, double noundef, double noundef, double noundef, double noundef, double noundef, double noundef, double noundef, ptr noundef, double noundef)
+// CHECK-DARWIN: declare void @small_pst_callee([2 x i64], double noundef, double noundef, double noundef, double noundef, double noundef, double noundef, double noundef, double noundef, double noundef, i128, double noundef)
+
+// Simple return: the PST is returned expanded in registers.
+//   p->a    -> p0
+//   p->x    -> q0
+//   p->y[0] -> q1
+//   p->y[1] -> q2
+//   p->z    -> q3
+//   p->b    -> p1
+PST test_return(PST *p) {
+    return *p;
+}
+// CHECK-AAPCS:  define dso_local <{ <vscale x 16 x i1>, <vscale x 2 x double>, <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 16 x i8>, <vscale x 16 x i1> }> @test_return(ptr
+// CHECK-DARWIN: define void @test_return(ptr dead_on_unwind noalias nocapture writable writeonly sret(%struct.PST) align 16 %agg.result, ptr nocapture noundef readonly %p)
+
+// Corner case of a 1-element PST, returned as a single scalable vector
+//   p->x -> q0
+SmallPST test_return_small_pst(SmallPST *p) {
+    return *p;
+}
+// CHECK-AAPCS:  define dso_local <vscale x 4 x float> @test_return_small_pst(ptr
+// CHECK-DARWIN: define i128 @test_return_small_pst(ptr nocapture noundef readonly %p)
+
+
+// Big PST, does not fit in registers, returned indirectly (via sret)
+//   *p -> *x8
+BigPST test_return_big_pst(BigPST *p) {
+    return *p;
+}
+// CHECK-AAPCS:  define dso_local void @test_return_big_pst(ptr dead_on_unwind noalias nocapture writable writeonly sret(%struct.BigPST) align 16 %agg.result, ptr nocapture noundef readonly %p)
+// CHECK-DARWIN: define void @test_return_big_pst(ptr dead_on_unwind noalias nocapture writable writeonly sret(%struct.BigPST) align 16 %agg.result, ptr nocapture noundef readonly %p)
+
+// Variadic arguments are unnamed, so the PST is passed indirectly
+//   0  -> w0
+//   *p -> memory, address -> x1
+void test_pass_variadic(PST *p) {
+    void pass_variadic_callee(int n, ...);
+    pass_variadic_callee(0, *p);
+}
+// CHECK: declare void @pass_variadic_callee(i32 noundef, ...)
+
+
+// Callee side of a PST passed in registers: the parts are stored to a local copy.
+void argpass_callee_side(PST v) {
+    void use(PST *p);
+    use(&v);
+}
+// CHECK-AAPCS:      define dso_local void @argpass_callee_side(<vscale x 16 x i1> %0, <vscale x 2 x double> %.coerce1, <vscale x 4 x float> %.coerce3, <vscale x 4 x float> %.coerce5, <vscale x 16 x i8> %.coerce7, <vscale x 16 x i1> %1) local_unnamed_addr #0 {
+// CHECK-AAPCS-NEXT: entry:
+// CHECK-AAPCS-NEXT:   %v = alloca %struct.PST, align 16
+// CHECK-AAPCS-NEXT:   %.coerce = bitcast <vscale x 16 x i1> %0 to <vscale x 2 x i8>
+// CHECK-AAPCS-NEXT:   %cast.fixed = tail call <2 x i8> @llvm.vector.extract.v2i8.nxv2i8(<vscale x 2 x i8> %.coerce, i64 0)
+// CHECK-AAPCS-NEXT:   store <2 x i8> %cast.fixed, ptr %v, align 16
+// CHECK-AAPCS-NEXT:   %2 = getelementptr inbounds nuw i8, ptr %v, i64 16
+// CHECK-AAPCS-NEXT:   %cast.fixed2 = tail call <2 x double> @llvm.vector.extract.v2f64.nxv2f64(<vscale x 2 x double> %.coerce1, i64 0)
+// CHECK-AAPCS-NEXT:   store <2 x double> %cast.fixed2, ptr %2, align 16
+// CHECK-AAPCS-NEXT:   %3 = getelementptr inbounds nuw i8, ptr %v, i64 32
+// CHECK-AAPCS-NEXT:   %cast.fixed4 = tail call <4 x float> @llvm.vector.extract.v4f32.nxv4f32(<vscale x 4 x float> %.coerce3, i64 0)
+// CHECK-AAPCS-NEXT:   store <4 x float> %cast.fixed4, ptr %3, align 16
+// CHECK-AAPCS-NEXT:   %4 = getelementptr inbounds nuw i8, ptr %v, i64 48
+// CHECK-AAPCS-NEXT:   %cast.fixed6 = tail call <4 x float> @llvm.vector.extract.v4f32.nxv4f32(<vscale x 4 x float> %.coerce5, i64 0)
+// CHECK-AAPCS-NEXT:   store <4 x float> %cast.fixed6, ptr %4, align 16
+// CHECK-AAPCS-NEXT:   %5 = getelementptr inbounds nuw i8, ptr %v, i64 64
+// CHECK-AAPCS-NEXT:   %cast.fixed8 = tail call <16 x i8> @llvm.vector.extract.v16i8.nxv16i8(<vscale x 16 x i8> %.coerce7, i64 0)
+// CHECK-AAPCS-NEXT:   store <16 x i8> %cast.fixed8, ptr %5, align 16
+// CHECK-AAPCS-NEXT:   %6 = getelementptr inbounds nuw i8, ptr %v, i64 80
+// CHECK-AAPCS-NEXT:   %.coerce9 = bitcast <vscale x 16 x i1> %1 to <vscale x 2 x i8>
+// CHECK-AAPCS-NEXT:   %cast.fixed10 = tail call <2 x i8> @llvm.vector.extract.v2i8.nxv2i8(<vscale x 2 x i8> %.coerce9, i64 0)
+// CHECK-AAPCS-NEXT:   store <2 x i8> %cast.fixed10, ptr %6, align 16
+// CHECK-AAPCS-NEXT:   call void @use(ptr noundef nonnull %v) #8
+// CHECK-AAPCS-NEXT:   ret void
+// CHECK-AAPCS-NEXT: }

>From 21854c1e163da841d9e84fec566d809c163b6419 Mon Sep 17 00:00:00 2001
From: Momchil Velikov <momchil.velikov at arm.com>
Date: Fri, 18 Oct 2024 11:12:09 +0100
Subject: [PATCH 2/3] [fixup] Add test using FP8 tuple type

---
 .../test/CodeGen/aarch64-pure-scalable-args.c | 26 +++++++++++++++----
 1 file changed, 21 insertions(+), 5 deletions(-)

diff --git a/clang/test/CodeGen/aarch64-pure-scalable-args.c b/clang/test/CodeGen/aarch64-pure-scalable-args.c
index e5e50e85d9b93a..631f6a5a1321b2 100644
--- a/clang/test/CodeGen/aarch64-pure-scalable-args.c
+++ b/clang/test/CodeGen/aarch64-pure-scalable-args.c
@@ -160,12 +160,28 @@ void test_argpass_no_z(PST *p, double dummy, svmfloat8_t u) {
 //   2     -> w2
 //   3.0   -> d5
 //   true  -> p0
-void test_argpass_no_z_tuple(PST *p, float dummy, svfloat64x4_t x) {
-  void argpass_no_z_tuple_callee(svfloat64x4_t, double, int, PST, int,
-                                 double, svbool_t);
-  argpass_no_z_tuple_callee(x, .0, 1, *p, 2, 3.0, svptrue_b64());
+void test_argpass_no_z_tuple_f64(PST *p, float dummy, svfloat64x4_t x) {
+  void argpass_no_z_tuple_f64_callee(svfloat64x4_t, double, int, PST, int,
+                                     double, svbool_t);
+  argpass_no_z_tuple_f64_callee(x, .0, 1, *p, 2, 3.0, svptrue_b64());
 }
-// CHECK: declare void @argpass_no_z_tuple_callee(<vscale x 2 x double>, <vscale x 2 x double>, <vscale x 2 x double>, <vscale x 2 x double>, double noundef, i32 noundef, ptr noundef, i32 noundef, double noundef, <vscale x 16 x i1>)
+// CHECK: declare void @argpass_no_z_tuple_f64_callee(<vscale x 2 x double>, <vscale x 2 x double>, <vscale x 2 x double>, <vscale x 2 x double>, double noundef, i32 noundef, ptr noundef, i32 noundef, double noundef, <vscale x 16 x i1>)
+
+
+// Likewise, using an FP8 tuple (svmfloat8x4_t).
+//   x     -> z0.b-z3.b
+//   0.0   -> d4
+//   1     -> w0
+//   *p    -> memory, address -> x1
+//   2     -> w2
+//   3.0   -> d5
+//   true  -> p0
+void test_argpass_no_z_tuple_mfp8(PST *p, float dummy, svmfloat8x4_t x) {
+  void argpass_no_z_tuple_mfp8_callee(svmfloat8x4_t, double, int, PST, int,
+                                      double, svbool_t);
+  argpass_no_z_tuple_mfp8_callee(x, .0, 1, *p, 2, 3.0, svptrue_b64());
+}
+// CHECK: declare void @argpass_no_z_tuple_mfp8_callee(<vscale x 16 x i8>, <vscale x 16 x i8>, <vscale x 16 x i8>, <vscale x 16 x i8>, double noundef, i32 noundef, ptr noundef, i32 noundef, double noundef, <vscale x 16 x i1>)
 
 
 // Not enough Z-regs (consumed by a HFA), PST passed indirectly

>From 00e3be6e5fe60c25b4e3a5ff699d1e2bf0f0a097 Mon Sep 17 00:00:00 2001
From: Momchil Velikov <momchil.velikov at arm.com>
Date: Fri, 18 Oct 2024 16:25:26 +0100
Subject: [PATCH 3/3] [fixup] Correctly handle empty unions and misc other
 fixes

---
 clang/include/clang/CodeGen/CGFunctionInfo.h  |  3 +-
 clang/lib/CodeGen/Targets/AArch64.cpp         | 62 +++++++++++--------
 .../aarch64-pure-scalable-args-empty-union.c  | 39 ++++++++++++
 3 files changed, 76 insertions(+), 28 deletions(-)
 create mode 100644 clang/test/CodeGen/aarch64-pure-scalable-args-empty-union.c

diff --git a/clang/include/clang/CodeGen/CGFunctionInfo.h b/clang/include/clang/CodeGen/CGFunctionInfo.h
index 915f676d7d3905..9d785d878b61dc 100644
--- a/clang/include/clang/CodeGen/CGFunctionInfo.h
+++ b/clang/include/clang/CodeGen/CGFunctionInfo.h
@@ -271,7 +271,8 @@ class ABIArgInfo {
     // in the unpadded type.
     unsigned unpaddedIndex = 0;
     for (auto eltType : coerceToType->elements()) {
-      if (isPaddingForCoerceAndExpand(eltType)) continue;
+      if (isPaddingForCoerceAndExpand(eltType))
+        continue;
       unpaddedIndex++;
     }
 
diff --git a/clang/lib/CodeGen/Targets/AArch64.cpp b/clang/lib/CodeGen/Targets/AArch64.cpp
index 269b5b352bfd84..85b40f267bfa04 100644
--- a/clang/lib/CodeGen/Targets/AArch64.cpp
+++ b/clang/lib/CodeGen/Targets/AArch64.cpp
@@ -52,8 +52,8 @@ class AArch64ABIInfo : public ABIInfo {
 
   bool isIllegalVectorType(QualType Ty) const;
 
-  bool isPureScalableType(QualType Ty, unsigned &NV, unsigned &NP,
-                          SmallVectorImpl<llvm::Type *> &CoerceToSeq) const;
+  bool passAsPureScalableType(QualType Ty, unsigned &NV, unsigned &NP,
+                              SmallVectorImpl<llvm::Type *> &CoerceToSeq) const;
 
   void flattenType(llvm::Type *Ty,
                    SmallVectorImpl<llvm::Type *> &Flattened) const;
@@ -432,7 +432,7 @@ ABIArgInfo AArch64ABIInfo::classifyArgumentType(QualType Ty, bool IsVariadic,
   if (Kind == AArch64ABIKind::AAPCS && !IsVariadic) {
     unsigned NVec = 0, NPred = 0;
     SmallVector<llvm::Type *> UnpaddedCoerceToSeq;
-    if (isPureScalableType(Ty, NVec, NPred, UnpaddedCoerceToSeq) &&
+    if (passAsPureScalableType(Ty, NVec, NPred, UnpaddedCoerceToSeq) &&
         (NVec + NPred) > 0)
       return coerceAndExpandPureScalableAggregate(
           Ty, NVec, NPred, UnpaddedCoerceToSeq, NSRN, NPRN);
@@ -510,14 +510,14 @@ ABIArgInfo AArch64ABIInfo::classifyReturnType(QualType RetTy,
     // Homogeneous Floating-point Aggregates (HFAs) are returned directly.
     return ABIArgInfo::getDirect();
 
-  // In AAPCS return values of a Pure Scalable type are treated is a first named
-  // argument and passed expanded in registers, or indirectly if there are not
-  // enough registers.
+  // In AAPCS return values of a Pure Scalable type are treated as a single
+  // named argument and passed expanded in registers, or indirectly if there are
+  // not enough registers.
   if (Kind == AArch64ABIKind::AAPCS) {
     unsigned NSRN = 0, NPRN = 0;
     unsigned NVec = 0, NPred = 0;
     SmallVector<llvm::Type *> UnpaddedCoerceToSeq;
-    if (isPureScalableType(RetTy, NVec, NPred, UnpaddedCoerceToSeq) &&
+    if (passAsPureScalableType(RetTy, NVec, NPred, UnpaddedCoerceToSeq) &&
         (NVec + NPred) > 0)
       return coerceAndExpandPureScalableAggregate(
           RetTy, NVec, NPred, UnpaddedCoerceToSeq, NSRN, NPRN);
@@ -638,13 +638,15 @@ bool AArch64ABIInfo::isZeroLengthBitfieldPermittedInHomogeneousAggregate()
   return true;
 }
 
-// Check if a type is a Pure Scalable Type as defined by AAPCS64. Return the
-// number of data vectors and the number of predicate vectors in the types, into
-// `NVec` and `NPred`, respectively. Upon return `CoerceToSeq` contains an
-// expanded sequence of LLVM IR types, one element for each non-composite
-// member. For practical purposes, limit the length of `CoerceToSeq` to about
-// 12, the maximum size that could possibly fit in registers.
-bool AArch64ABIInfo::isPureScalableType(
+// Check if a type needs to be passed in registers as a Pure Scalable Type (as
+// defined by AAPCS64). Add the number of data vectors and the number of
+// predicate vectors in the type to `NVec` and `NPred`, respectively. Upon
+// return `CoerceToSeq` contains an expanded sequence of LLVM IR types, one
+// element for each non-composite member. For practical purposes, limit the
+// length of `CoerceToSeq` to 12 (the maximum that could possibly fit in
+// registers): if it would grow longer, return false, the effect of which is
+// to pass the argument under the rules for a large (> 128 bytes) composite.
+bool AArch64ABIInfo::passAsPureScalableType(
     QualType Ty, unsigned &NVec, unsigned &NPred,
     SmallVectorImpl<llvm::Type *> &CoerceToSeq) const {
   if (const ConstantArrayType *AT = getContext().getAsConstantArrayType(Ty)) {
@@ -654,10 +656,13 @@ bool AArch64ABIInfo::isPureScalableType(
 
     unsigned NV = 0, NP = 0;
     SmallVector<llvm::Type *> EltCoerceToSeq;
-    if (!isPureScalableType(AT->getElementType(), NV, NP, EltCoerceToSeq))
+    if (!passAsPureScalableType(AT->getElementType(), NV, NP, EltCoerceToSeq))
       return false;
 
-    for (uint64_t I = 0; CoerceToSeq.size() < 12 && I < NElt; ++I)
+    if (CoerceToSeq.size() + NElt * EltCoerceToSeq.size() > 12)
+      return false;
+
+    for (uint64_t I = 0; I < NElt; ++I)
       llvm::copy(EltCoerceToSeq, std::back_inserter(CoerceToSeq));
 
     NVec += NElt * NV;
@@ -676,12 +681,12 @@ bool AArch64ABIInfo::isPureScalableType(
     if (RD->isUnion())
       return false;
 
-    // If this is a C++ record, check the bases bases.
+    // If this is a C++ record, check the bases.
     if (const CXXRecordDecl *CXXRD = dyn_cast<CXXRecordDecl>(RD)) {
       for (const auto &I : CXXRD->bases()) {
         if (isEmptyRecord(getContext(), I.getType(), true))
           continue;
-        if (!isPureScalableType(I.getType(), NVec, NPred, CoerceToSeq))
+        if (!passAsPureScalableType(I.getType(), NVec, NPred, CoerceToSeq))
           return false;
       }
     }
@@ -689,9 +694,9 @@ bool AArch64ABIInfo::isPureScalableType(
     // Check members.
     for (const auto *FD : RD->fields()) {
       QualType FT = FD->getType();
-      if (isEmptyRecord(getContext(), FT, true))
+      if (isEmptyField(getContext(), FD, /* AllowArrays */ true))
         continue;
-      if (!isPureScalableType(FT, NVec, NPred, CoerceToSeq))
+      if (!passAsPureScalableType(FT, NVec, NPred, CoerceToSeq))
         return false;
     }
 
@@ -704,15 +709,17 @@ bool AArch64ABIInfo::isPureScalableType(
 
   if (VT->getVectorKind() == VectorKind::SveFixedLengthPredicate) {
     ++NPred;
-    if (CoerceToSeq.size() < 12)
-      CoerceToSeq.push_back(convertFixedToScalableVectorType(VT));
+    if (CoerceToSeq.size() + 1 > 12)
+      return false;
+    CoerceToSeq.push_back(convertFixedToScalableVectorType(VT));
     return true;
   }
 
   if (VT->getVectorKind() == VectorKind::SveFixedLengthData) {
     ++NVec;
-    if (CoerceToSeq.size() < 12)
-      CoerceToSeq.push_back(convertFixedToScalableVectorType(VT));
+    if (CoerceToSeq.size() + 1 > 12)
+      return false;
+    CoerceToSeq.push_back(convertFixedToScalableVectorType(VT));
     return true;
   }
 
@@ -741,8 +748,9 @@ bool AArch64ABIInfo::isPureScalableType(
   auto VTy = llvm::ScalableVectorType::get(CGT.ConvertType(Info.ElementType),
                                            Info.EC.getKnownMinValue());
 
-  if (CoerceToSeq.size() < 12)
-    std::fill_n(std::back_inserter(CoerceToSeq), Info.NumVectors, VTy);
+  if (CoerceToSeq.size() + Info.NumVectors > 12)
+    return false;
+  std::fill_n(std::back_inserter(CoerceToSeq), Info.NumVectors, VTy);
 
   return true;
 }
@@ -784,7 +792,7 @@ RValue AArch64ABIInfo::EmitAAPCSVAArg(Address VAListAddr, QualType Ty,
                                       CodeGenFunction &CGF, AArch64ABIKind Kind,
                                       AggValueSlot Slot) const {
   // These numbers are not used for variadic arguments, hence it doesn't matter
-  // they don't retain their values accross multiple calls to
+  // they don't retain their values across multiple calls to
   // `classifyArgumentType` here.
   unsigned NSRN = 0, NPRN = 0;
   ABIArgInfo AI =
diff --git a/clang/test/CodeGen/aarch64-pure-scalable-args-empty-union.c b/clang/test/CodeGen/aarch64-pure-scalable-args-empty-union.c
new file mode 100644
index 00000000000000..546910068c78a2
--- /dev/null
+++ b/clang/test/CodeGen/aarch64-pure-scalable-args-empty-union.c
@@ -0,0 +1,39 @@
+// RUN: %clang_cc1        -O3 -triple aarch64 -target-feature +sve -mvscale-min=1 -mvscale-max=1 -emit-llvm -o - %s | FileCheck %s --check-prefixes=CHECK-C
+// RUN: %clang_cc1 -x c++ -O3 -triple aarch64 -target-feature +sve -mvscale-min=1 -mvscale-max=1 -emit-llvm -o - %s | FileCheck %s --check-prefixes=CHECK-CXX
+
+typedef __SVFloat32_t fvec32 __attribute__((arm_sve_vector_bits(128)));
+
+// PST containing an empty union: when compiled as C, pass it in registers;
+// when compiled as C++, pass it in memory.
+typedef struct {
+  fvec32 x[4];
+  union {} u;
+} S0;
+
+#ifdef __cplusplus
+extern "C"
+#endif
+void use0(S0);
+
+void f0(S0 *p) {
+  use0(*p);
+}
+// CHECK-C:   declare void @use0(<vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float>)
+// CHECK-CXX: declare void @use0(ptr noundef)
+
+#ifdef __cplusplus
+
+// PST containing an empty union with `[[no_unique_address]]`: pass in registers.
+typedef struct {
+   fvec32 x[4];
+   [[no_unique_address]]
+   union {} u;
+} S1;
+
+extern "C" void use1(S1);
+void f1(S1 *p) {
+  use1(*p);
+}
+// CHECK-CXX: declare void @use1(<vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float>)
+
+#endif // __cplusplus



More information about the cfe-commits mailing list