[clang] 53f7f8e - [Clang][AArch64] Fix Pure Scalables Types argument passing and return (#112747)

via cfe-commits cfe-commits at lists.llvm.org
Mon Oct 28 08:43:19 PDT 2024


Author: Momchil Velikov
Date: 2024-10-28T15:43:14Z
New Revision: 53f7f8eccabd6e3383edfeec312bf8671a89bc66

URL: https://github.com/llvm/llvm-project/commit/53f7f8eccabd6e3383edfeec312bf8671a89bc66
DIFF: https://github.com/llvm/llvm-project/commit/53f7f8eccabd6e3383edfeec312bf8671a89bc66.diff

LOG: [Clang][AArch64] Fix Pure Scalables Types argument passing and return (#112747)

Pure Scalable Types are defined in AAPCS64 here:

https://github.com/ARM-software/abi-aa/blob/main/aapcs64/aapcs64.rst#pure-scalable-types-psts

And should be passed according to Rule C.7 here:

https://github.com/ARM-software/abi-aa/blob/main/aapcs64/aapcs64.rst#682parameter-passing-rules

This part of the ABI is completely unimplemented in Clang; instead it
treats PSTs sometimes as HFAs/HVAs, sometimes as general composite types.

This patch implements the rules for passing PSTs by employing the
`CoerceAndExpand` method and extending it to:
* allow array types in the `coerceToType`; now only `[N x i8]` arrays are
considered padding.
* allow mismatch between the elements of the `coerceToType` and the
elements of the `unpaddedCoerceToType`; AArch64 uses this to map
fixed-length vector types to SVE vector types.

Correctly passing a PST argument needs a decision in Clang about whether
to pass it in memory or registers or, equivalently, whether to use the
`Indirect` or `Expand/CoerceAndExpand` method. It was considered
relatively harder (or not practically possible) to make that decision in
the AArch64 backend.
Hence this patch implements the register counting from AAPCS64 (cf.
`NSRN`, `NPRN`) to guide Clang's decision.

Added: 
    clang/test/CodeGen/aarch64-pure-scalable-args-empty-union.c
    clang/test/CodeGen/aarch64-pure-scalable-args.c

Modified: 
    clang/include/clang/CodeGen/CGFunctionInfo.h
    clang/lib/CodeGen/CGCall.cpp
    clang/lib/CodeGen/Targets/AArch64.cpp

Removed: 
    


################################################################################
diff  --git a/clang/include/clang/CodeGen/CGFunctionInfo.h b/clang/include/clang/CodeGen/CGFunctionInfo.h
index d19f84d198876f..9d785d878b61dc 100644
--- a/clang/include/clang/CodeGen/CGFunctionInfo.h
+++ b/clang/include/clang/CodeGen/CGFunctionInfo.h
@@ -271,12 +271,8 @@ class ABIArgInfo {
     // in the unpadded type.
     unsigned unpaddedIndex = 0;
     for (auto eltType : coerceToType->elements()) {
-      if (isPaddingForCoerceAndExpand(eltType)) continue;
-      if (unpaddedStruct) {
-        assert(unpaddedStruct->getElementType(unpaddedIndex) == eltType);
-      } else {
-        assert(unpaddedIndex == 0 && unpaddedCoerceToType == eltType);
-      }
+      if (isPaddingForCoerceAndExpand(eltType))
+        continue;
       unpaddedIndex++;
     }
 
@@ -295,12 +291,8 @@ class ABIArgInfo {
   }
 
   static bool isPaddingForCoerceAndExpand(llvm::Type *eltType) {
-    if (eltType->isArrayTy()) {
-      assert(eltType->getArrayElementType()->isIntegerTy(8));
-      return true;
-    } else {
-      return false;
-    }
+    return eltType->isArrayTy() &&
+           eltType->getArrayElementType()->isIntegerTy(8);
   }
 
   Kind getKind() const { return TheKind; }

diff  --git a/clang/lib/CodeGen/CGCall.cpp b/clang/lib/CodeGen/CGCall.cpp
index 1949b4ceb7f204..64e60f0616d77b 100644
--- a/clang/lib/CodeGen/CGCall.cpp
+++ b/clang/lib/CodeGen/CGCall.cpp
@@ -1410,6 +1410,30 @@ static Address emitAddressAtOffset(CodeGenFunction &CGF, Address addr,
   return addr;
 }
 
+static std::pair<llvm::Value *, bool>
+CoerceScalableToFixed(CodeGenFunction &CGF, llvm::FixedVectorType *ToTy,
+                      llvm::ScalableVectorType *FromTy, llvm::Value *V,
+                      StringRef Name = "") {
+  // If we are casting a scalable i1 predicate vector to a fixed i8
+  // vector, first bitcast the source.
+  if (FromTy->getElementType()->isIntegerTy(1) &&
+      FromTy->getElementCount().isKnownMultipleOf(8) &&
+      ToTy->getElementType() == CGF.Builder.getInt8Ty()) {
+    FromTy = llvm::ScalableVectorType::get(
+        ToTy->getElementType(),
+        FromTy->getElementCount().getKnownMinValue() / 8);
+    V = CGF.Builder.CreateBitCast(V, FromTy);
+  }
+  if (FromTy->getElementType() == ToTy->getElementType()) {
+    llvm::Value *Zero = llvm::Constant::getNullValue(CGF.CGM.Int64Ty);
+
+    V->setName(Name + ".coerce");
+    V = CGF.Builder.CreateExtractVector(ToTy, V, Zero, "cast.fixed");
+    return {V, true};
+  }
+  return {V, false};
+}
+
 namespace {
 
 /// Encapsulates information about the way function arguments from
@@ -3196,26 +3220,14 @@ void CodeGenFunction::EmitFunctionProlog(const CGFunctionInfo &FI,
       // a VLAT at the function boundary and the types match up, use
       // llvm.vector.extract to convert back to the original VLST.
       if (auto *VecTyTo = dyn_cast<llvm::FixedVectorType>(ConvertType(Ty))) {
-        llvm::Value *Coerced = Fn->getArg(FirstIRArg);
+        llvm::Value *ArgVal = Fn->getArg(FirstIRArg);
         if (auto *VecTyFrom =
-                dyn_cast<llvm::ScalableVectorType>(Coerced->getType())) {
-          // If we are casting a scalable i1 predicate vector to a fixed i8
-          // vector, bitcast the source and use a vector extract.
-          if (VecTyFrom->getElementType()->isIntegerTy(1) &&
-              VecTyFrom->getElementCount().isKnownMultipleOf(8) &&
-              VecTyTo->getElementType() == Builder.getInt8Ty()) {
-            VecTyFrom = llvm::ScalableVectorType::get(
-                VecTyTo->getElementType(),
-                VecTyFrom->getElementCount().getKnownMinValue() / 8);
-            Coerced = Builder.CreateBitCast(Coerced, VecTyFrom);
-          }
-          if (VecTyFrom->getElementType() == VecTyTo->getElementType()) {
-            llvm::Value *Zero = llvm::Constant::getNullValue(CGM.Int64Ty);
-
+                dyn_cast<llvm::ScalableVectorType>(ArgVal->getType())) {
+          auto [Coerced, Extracted] = CoerceScalableToFixed(
+              *this, VecTyTo, VecTyFrom, ArgVal, Arg->getName());
+          if (Extracted) {
             assert(NumIRArgs == 1);
-            Coerced->setName(Arg->getName() + ".coerce");
-            ArgVals.push_back(ParamValue::forDirect(Builder.CreateExtractVector(
-                VecTyTo, Coerced, Zero, "cast.fixed")));
+            ArgVals.push_back(ParamValue::forDirect(Coerced));
             break;
           }
         }
@@ -3326,16 +3338,33 @@ void CodeGenFunction::EmitFunctionProlog(const CGFunctionInfo &FI,
       ArgVals.push_back(ParamValue::forIndirect(alloca));
 
       auto coercionType = ArgI.getCoerceAndExpandType();
+      auto unpaddedCoercionType = ArgI.getUnpaddedCoerceAndExpandType();
+      auto *unpaddedStruct = dyn_cast<llvm::StructType>(unpaddedCoercionType);
+
       alloca = alloca.withElementType(coercionType);
 
       unsigned argIndex = FirstIRArg;
+      unsigned unpaddedIndex = 0;
       for (unsigned i = 0, e = coercionType->getNumElements(); i != e; ++i) {
         llvm::Type *eltType = coercionType->getElementType(i);
         if (ABIArgInfo::isPaddingForCoerceAndExpand(eltType))
           continue;
 
         auto eltAddr = Builder.CreateStructGEP(alloca, i);
-        auto elt = Fn->getArg(argIndex++);
+        llvm::Value *elt = Fn->getArg(argIndex++);
+
+        auto paramType = unpaddedStruct
+                             ? unpaddedStruct->getElementType(unpaddedIndex++)
+                             : unpaddedCoercionType;
+
+        if (auto *VecTyTo = dyn_cast<llvm::FixedVectorType>(eltType)) {
+          if (auto *VecTyFrom = dyn_cast<llvm::ScalableVectorType>(paramType)) {
+            bool Extracted;
+            std::tie(elt, Extracted) = CoerceScalableToFixed(
+                *this, VecTyTo, VecTyFrom, elt, elt->getName());
+            assert(Extracted && "Unexpected scalable to fixed vector coercion");
+          }
+        }
         Builder.CreateStore(elt, eltAddr);
       }
       assert(argIndex == FirstIRArg + NumIRArgs);
@@ -3930,17 +3959,24 @@ void CodeGenFunction::EmitFunctionEpilog(const CGFunctionInfo &FI,
 
   case ABIArgInfo::CoerceAndExpand: {
     auto coercionType = RetAI.getCoerceAndExpandType();
+    auto unpaddedCoercionType = RetAI.getUnpaddedCoerceAndExpandType();
+    auto *unpaddedStruct = dyn_cast<llvm::StructType>(unpaddedCoercionType);
 
     // Load all of the coerced elements out into results.
     llvm::SmallVector<llvm::Value*, 4> results;
     Address addr = ReturnValue.withElementType(coercionType);
+    unsigned unpaddedIndex = 0;
     for (unsigned i = 0, e = coercionType->getNumElements(); i != e; ++i) {
       auto coercedEltType = coercionType->getElementType(i);
       if (ABIArgInfo::isPaddingForCoerceAndExpand(coercedEltType))
         continue;
 
       auto eltAddr = Builder.CreateStructGEP(addr, i);
-      auto elt = Builder.CreateLoad(eltAddr);
+      llvm::Value *elt = CreateCoercedLoad(
+          eltAddr,
+          unpaddedStruct ? unpaddedStruct->getElementType(unpaddedIndex++)
+                         : unpaddedCoercionType,
+          *this);
       results.push_back(elt);
     }
 
@@ -5472,6 +5508,8 @@ RValue CodeGenFunction::EmitCall(const CGFunctionInfo &CallInfo,
     case ABIArgInfo::CoerceAndExpand: {
       auto coercionType = ArgInfo.getCoerceAndExpandType();
       auto layout = CGM.getDataLayout().getStructLayout(coercionType);
+      auto unpaddedCoercionType = ArgInfo.getUnpaddedCoerceAndExpandType();
+      auto *unpaddedStruct = dyn_cast<llvm::StructType>(unpaddedCoercionType);
 
       llvm::Value *tempSize = nullptr;
       Address addr = Address::invalid();
@@ -5502,11 +5540,16 @@ RValue CodeGenFunction::EmitCall(const CGFunctionInfo &CallInfo,
       addr = addr.withElementType(coercionType);
 
       unsigned IRArgPos = FirstIRArg;
+      unsigned unpaddedIndex = 0;
       for (unsigned i = 0, e = coercionType->getNumElements(); i != e; ++i) {
         llvm::Type *eltType = coercionType->getElementType(i);
         if (ABIArgInfo::isPaddingForCoerceAndExpand(eltType)) continue;
         Address eltAddr = Builder.CreateStructGEP(addr, i);
-        llvm::Value *elt = Builder.CreateLoad(eltAddr);
+        llvm::Value *elt = CreateCoercedLoad(
+            eltAddr,
+            unpaddedStruct ? unpaddedStruct->getElementType(unpaddedIndex++)
+                           : unpaddedCoercionType,
+            *this);
         if (ArgHasMaybeUndefAttr)
           elt = Builder.CreateFreeze(elt);
         IRCallArgs[IRArgPos++] = elt;

diff  --git a/clang/lib/CodeGen/Targets/AArch64.cpp b/clang/lib/CodeGen/Targets/AArch64.cpp
index ec617eec67192c..a80411971b60c3 100644
--- a/clang/lib/CodeGen/Targets/AArch64.cpp
+++ b/clang/lib/CodeGen/Targets/AArch64.cpp
@@ -34,10 +34,17 @@ class AArch64ABIInfo : public ABIInfo {
   AArch64ABIKind getABIKind() const { return Kind; }
   bool isDarwinPCS() const { return Kind == AArch64ABIKind::DarwinPCS; }
 
-  ABIArgInfo classifyReturnType(QualType RetTy, bool IsVariadic) const;
-  ABIArgInfo classifyArgumentType(QualType RetTy, bool IsVariadic,
-                                  unsigned CallingConvention) const;
-  ABIArgInfo coerceIllegalVector(QualType Ty) const;
+  ABIArgInfo classifyReturnType(QualType RetTy, bool IsVariadicFn) const;
+  ABIArgInfo classifyArgumentType(QualType RetTy, bool IsVariadicFn,
+                                  bool IsNamedArg, unsigned CallingConvention,
+                                  unsigned &NSRN, unsigned &NPRN) const;
+  llvm::Type *convertFixedToScalableVectorType(const VectorType *VT) const;
+  ABIArgInfo coerceIllegalVector(QualType Ty, unsigned &NSRN,
+                                 unsigned &NPRN) const;
+  ABIArgInfo coerceAndExpandPureScalableAggregate(
+      QualType Ty, bool IsNamedArg, unsigned NVec, unsigned NPred,
+      const SmallVectorImpl<llvm::Type *> &UnpaddedCoerceToSeq, unsigned &NSRN,
+      unsigned &NPRN) const;
   bool isHomogeneousAggregateBaseType(QualType Ty) const override;
   bool isHomogeneousAggregateSmallEnough(const Type *Ty,
                                          uint64_t Members) const override;
@@ -45,14 +52,26 @@ class AArch64ABIInfo : public ABIInfo {
 
   bool isIllegalVectorType(QualType Ty) const;
 
+  bool passAsPureScalableType(QualType Ty, unsigned &NV, unsigned &NP,
+                              SmallVectorImpl<llvm::Type *> &CoerceToSeq) const;
+
+  void flattenType(llvm::Type *Ty,
+                   SmallVectorImpl<llvm::Type *> &Flattened) const;
+
   void computeInfo(CGFunctionInfo &FI) const override {
     if (!::classifyReturnType(getCXXABI(), FI, *this))
       FI.getReturnInfo() =
           classifyReturnType(FI.getReturnType(), FI.isVariadic());
 
-    for (auto &it : FI.arguments())
-      it.info = classifyArgumentType(it.type, FI.isVariadic(),
-                                     FI.getCallingConvention());
+    unsigned ArgNo = 0;
+    unsigned NSRN = 0, NPRN = 0;
+    for (auto &it : FI.arguments()) {
+      const bool IsNamedArg =
+          !FI.isVariadic() || ArgNo < FI.getRequiredArgs().getNumRequiredArgs();
+      ++ArgNo;
+      it.info = classifyArgumentType(it.type, FI.isVariadic(), IsNamedArg,
+                                     FI.getCallingConvention(), NSRN, NPRN);
+    }
   }
 
   RValue EmitDarwinVAArg(Address VAListAddr, QualType Ty, CodeGenFunction &CGF,
@@ -201,65 +220,83 @@ void WindowsAArch64TargetCodeGenInfo::setTargetAttributes(
 }
 }
 
-ABIArgInfo AArch64ABIInfo::coerceIllegalVector(QualType Ty) const {
-  assert(Ty->isVectorType() && "expected vector type!");
+llvm::Type *
+AArch64ABIInfo::convertFixedToScalableVectorType(const VectorType *VT) const {
+  assert(VT->getElementType()->isBuiltinType() && "expected builtin type!");
 
-  const auto *VT = Ty->castAs<VectorType>();
   if (VT->getVectorKind() == VectorKind::SveFixedLengthPredicate) {
-    assert(VT->getElementType()->isBuiltinType() && "expected builtin type!");
     assert(VT->getElementType()->castAs<BuiltinType>()->getKind() ==
                BuiltinType::UChar &&
            "unexpected builtin type for SVE predicate!");
-    return ABIArgInfo::getDirect(llvm::ScalableVectorType::get(
-        llvm::Type::getInt1Ty(getVMContext()), 16));
+    return llvm::ScalableVectorType::get(llvm::Type::getInt1Ty(getVMContext()),
+                                         16);
   }
 
   if (VT->getVectorKind() == VectorKind::SveFixedLengthData) {
-    assert(VT->getElementType()->isBuiltinType() && "expected builtin type!");
-
     const auto *BT = VT->getElementType()->castAs<BuiltinType>();
-    llvm::ScalableVectorType *ResType = nullptr;
     switch (BT->getKind()) {
     default:
       llvm_unreachable("unexpected builtin type for SVE vector!");
+
     case BuiltinType::SChar:
     case BuiltinType::UChar:
-      ResType = llvm::ScalableVectorType::get(
+      return llvm::ScalableVectorType::get(
           llvm::Type::getInt8Ty(getVMContext()), 16);
-      break;
+
     case BuiltinType::Short:
     case BuiltinType::UShort:
-      ResType = llvm::ScalableVectorType::get(
+      return llvm::ScalableVectorType::get(
           llvm::Type::getInt16Ty(getVMContext()), 8);
-      break;
+
     case BuiltinType::Int:
     case BuiltinType::UInt:
-      ResType = llvm::ScalableVectorType::get(
+      return llvm::ScalableVectorType::get(
           llvm::Type::getInt32Ty(getVMContext()), 4);
-      break;
+
     case BuiltinType::Long:
     case BuiltinType::ULong:
-      ResType = llvm::ScalableVectorType::get(
+      return llvm::ScalableVectorType::get(
           llvm::Type::getInt64Ty(getVMContext()), 2);
-      break;
+
     case BuiltinType::Half:
-      ResType = llvm::ScalableVectorType::get(
+      return llvm::ScalableVectorType::get(
           llvm::Type::getHalfTy(getVMContext()), 8);
-      break;
+
     case BuiltinType::Float:
-      ResType = llvm::ScalableVectorType::get(
+      return llvm::ScalableVectorType::get(
           llvm::Type::getFloatTy(getVMContext()), 4);
-      break;
+
     case BuiltinType::Double:
-      ResType = llvm::ScalableVectorType::get(
+      return llvm::ScalableVectorType::get(
           llvm::Type::getDoubleTy(getVMContext()), 2);
-      break;
+
     case BuiltinType::BFloat16:
-      ResType = llvm::ScalableVectorType::get(
+      return llvm::ScalableVectorType::get(
           llvm::Type::getBFloatTy(getVMContext()), 8);
-      break;
     }
-    return ABIArgInfo::getDirect(ResType);
+  }
+
+  llvm_unreachable("expected fixed-length SVE vector");
+}
+
+ABIArgInfo AArch64ABIInfo::coerceIllegalVector(QualType Ty, unsigned &NSRN,
+                                               unsigned &NPRN) const {
+  assert(Ty->isVectorType() && "expected vector type!");
+
+  const auto *VT = Ty->castAs<VectorType>();
+  if (VT->getVectorKind() == VectorKind::SveFixedLengthPredicate) {
+    assert(VT->getElementType()->isBuiltinType() && "expected builtin type!");
+    assert(VT->getElementType()->castAs<BuiltinType>()->getKind() ==
+               BuiltinType::UChar &&
+           "unexpected builtin type for SVE predicate!");
+    NPRN = std::min(NPRN + 1, 4u);
+    return ABIArgInfo::getDirect(llvm::ScalableVectorType::get(
+        llvm::Type::getInt1Ty(getVMContext()), 16));
+  }
+
+  if (VT->getVectorKind() == VectorKind::SveFixedLengthData) {
+    NSRN = std::min(NSRN + 1, 8u);
+    return ABIArgInfo::getDirect(convertFixedToScalableVectorType(VT));
   }
 
   uint64_t Size = getContext().getTypeSize(Ty);
@@ -273,26 +310,54 @@ ABIArgInfo AArch64ABIInfo::coerceIllegalVector(QualType Ty) const {
     return ABIArgInfo::getDirect(ResType);
   }
   if (Size == 64) {
+    NSRN = std::min(NSRN + 1, 8u);
     auto *ResType =
         llvm::FixedVectorType::get(llvm::Type::getInt32Ty(getVMContext()), 2);
     return ABIArgInfo::getDirect(ResType);
   }
   if (Size == 128) {
+    NSRN = std::min(NSRN + 1, 8u);
     auto *ResType =
         llvm::FixedVectorType::get(llvm::Type::getInt32Ty(getVMContext()), 4);
     return ABIArgInfo::getDirect(ResType);
   }
+
   return getNaturalAlignIndirect(Ty, /*ByVal=*/false);
 }
 
-ABIArgInfo
-AArch64ABIInfo::classifyArgumentType(QualType Ty, bool IsVariadic,
-                                     unsigned CallingConvention) const {
+ABIArgInfo AArch64ABIInfo::coerceAndExpandPureScalableAggregate(
+    QualType Ty, bool IsNamedArg, unsigned NVec, unsigned NPred,
+    const SmallVectorImpl<llvm::Type *> &UnpaddedCoerceToSeq, unsigned &NSRN,
+    unsigned &NPRN) const {
+  if (!IsNamedArg || NSRN + NVec > 8 || NPRN + NPred > 4)
+    return getNaturalAlignIndirect(Ty, /*ByVal=*/false);
+  NSRN += NVec;
+  NPRN += NPred;
+
+  llvm::Type *UnpaddedCoerceToType =
+      UnpaddedCoerceToSeq.size() == 1
+          ? UnpaddedCoerceToSeq[0]
+          : llvm::StructType::get(CGT.getLLVMContext(), UnpaddedCoerceToSeq,
+                                  true);
+
+  SmallVector<llvm::Type *> CoerceToSeq;
+  flattenType(CGT.ConvertType(Ty), CoerceToSeq);
+  auto *CoerceToType =
+      llvm::StructType::get(CGT.getLLVMContext(), CoerceToSeq, false);
+
+  return ABIArgInfo::getCoerceAndExpand(CoerceToType, UnpaddedCoerceToType);
+}
+
+ABIArgInfo AArch64ABIInfo::classifyArgumentType(QualType Ty, bool IsVariadicFn,
+                                                bool IsNamedArg,
+                                                unsigned CallingConvention,
+                                                unsigned &NSRN,
+                                                unsigned &NPRN) const {
   Ty = useFirstFieldIfTransparentUnion(Ty);
 
   // Handle illegal vector types here.
   if (isIllegalVectorType(Ty))
-    return coerceIllegalVector(Ty);
+    return coerceIllegalVector(Ty, NSRN, NPRN);
 
   if (!isAggregateTypeForABI(Ty)) {
     // Treat an enum type as its underlying type.
@@ -303,6 +368,36 @@ AArch64ABIInfo::classifyArgumentType(QualType Ty, bool IsVariadic,
       if (EIT->getNumBits() > 128)
         return getNaturalAlignIndirect(Ty, false);
 
+    if (Ty->isVectorType())
+      NSRN = std::min(NSRN + 1, 8u);
+    else if (const auto *BT = Ty->getAs<BuiltinType>()) {
+      if (BT->isFloatingPoint())
+        NSRN = std::min(NSRN + 1, 8u);
+      else {
+        switch (BT->getKind()) {
+        case BuiltinType::MFloat8x8:
+        case BuiltinType::MFloat8x16:
+          NSRN = std::min(NSRN + 1, 8u);
+          break;
+        case BuiltinType::SveBool:
+        case BuiltinType::SveCount:
+          NPRN = std::min(NPRN + 1, 4u);
+          break;
+        case BuiltinType::SveBoolx2:
+          NPRN = std::min(NPRN + 2, 4u);
+          break;
+        case BuiltinType::SveBoolx4:
+          NPRN = std::min(NPRN + 4, 4u);
+          break;
+        default:
+          if (BT->isSVESizelessBuiltinType())
+            NSRN = std::min(
+                NSRN + getContext().getBuiltinVectorTypeInfo(BT).NumVectors,
+                8u);
+        }
+      }
+    }
+
     return (isPromotableIntegerTypeForABI(Ty) && isDarwinPCS()
                 ? ABIArgInfo::getExtend(Ty, CGT.ConvertType(Ty))
                 : ABIArgInfo::getDirect());
@@ -335,10 +430,11 @@ AArch64ABIInfo::classifyArgumentType(QualType Ty, bool IsVariadic,
   uint64_t Members = 0;
   bool IsWin64 = Kind == AArch64ABIKind::Win64 ||
                  CallingConvention == llvm::CallingConv::Win64;
-  bool IsWinVariadic = IsWin64 && IsVariadic;
+  bool IsWinVariadic = IsWin64 && IsVariadicFn;
   // In variadic functions on Windows, all composite types are treated alike,
   // no special handling of HFAs/HVAs.
   if (!IsWinVariadic && isHomogeneousAggregate(Ty, Base, Members)) {
+    NSRN = std::min(NSRN + Members, uint64_t(8));
     if (Kind != AArch64ABIKind::AAPCS)
       return ABIArgInfo::getDirect(
           llvm::ArrayType::get(CGT.ConvertType(QualType(Base, 0)), Members));
@@ -353,6 +449,17 @@ AArch64ABIInfo::classifyArgumentType(QualType Ty, bool IsVariadic,
         nullptr, true, Align);
   }
 
+  // In AAPCS named arguments of a Pure Scalable Type are passed expanded in
+  // registers, or indirectly if there are not enough registers.
+  if (Kind == AArch64ABIKind::AAPCS) {
+    unsigned NVec = 0, NPred = 0;
+    SmallVector<llvm::Type *> UnpaddedCoerceToSeq;
+    if (passAsPureScalableType(Ty, NVec, NPred, UnpaddedCoerceToSeq) &&
+        (NVec + NPred) > 0)
+      return coerceAndExpandPureScalableAggregate(
+          Ty, IsNamedArg, NVec, NPred, UnpaddedCoerceToSeq, NSRN, NPRN);
+  }
+
   // Aggregates <= 16 bytes are passed directly in registers or on the stack.
   if (Size <= 128) {
     // On RenderScript, coerce Aggregates <= 16 bytes to an integer array of
@@ -383,14 +490,16 @@ AArch64ABIInfo::classifyArgumentType(QualType Ty, bool IsVariadic,
 }
 
 ABIArgInfo AArch64ABIInfo::classifyReturnType(QualType RetTy,
-                                              bool IsVariadic) const {
+                                              bool IsVariadicFn) const {
   if (RetTy->isVoidType())
     return ABIArgInfo::getIgnore();
 
   if (const auto *VT = RetTy->getAs<VectorType>()) {
     if (VT->getVectorKind() == VectorKind::SveFixedLengthData ||
-        VT->getVectorKind() == VectorKind::SveFixedLengthPredicate)
-      return coerceIllegalVector(RetTy);
+        VT->getVectorKind() == VectorKind::SveFixedLengthPredicate) {
+      unsigned NSRN = 0, NPRN = 0;
+      return coerceIllegalVector(RetTy, NSRN, NPRN);
+    }
   }
 
   // Large vector types should be returned via memory.
@@ -419,10 +528,24 @@ ABIArgInfo AArch64ABIInfo::classifyReturnType(QualType RetTy,
   uint64_t Members = 0;
   if (isHomogeneousAggregate(RetTy, Base, Members) &&
       !(getTarget().getTriple().getArch() == llvm::Triple::aarch64_32 &&
-        IsVariadic))
+        IsVariadicFn))
     // Homogeneous Floating-point Aggregates (HFAs) are returned directly.
     return ABIArgInfo::getDirect();
 
+  // In AAPCS return values of a Pure Scalable type are treated as a single
+  // named argument and passed expanded in registers, or indirectly if there are
+  // not enough registers.
+  if (Kind == AArch64ABIKind::AAPCS) {
+    unsigned NSRN = 0, NPRN = 0;
+    unsigned NVec = 0, NPred = 0;
+    SmallVector<llvm::Type *> UnpaddedCoerceToSeq;
+    if (passAsPureScalableType(RetTy, NVec, NPred, UnpaddedCoerceToSeq) &&
+        (NVec + NPred) > 0)
+      return coerceAndExpandPureScalableAggregate(
+          RetTy, /* IsNamedArg */ true, NVec, NPred, UnpaddedCoerceToSeq, NSRN,
+          NPRN);
+  }
+
   // Aggregates <= 16 bytes are returned directly in registers or on the stack.
   if (Size <= 128) {
     // On RenderScript, coerce Aggregates <= 16 bytes to an integer array of
@@ -508,9 +631,15 @@ bool AArch64ABIInfo::isHomogeneousAggregateBaseType(QualType Ty) const {
   // but with the difference that any floating-point type is allowed,
   // including __fp16.
   if (const BuiltinType *BT = Ty->getAs<BuiltinType>()) {
-    if (BT->isFloatingPoint())
+    if (BT->isFloatingPoint() || BT->getKind() == BuiltinType::MFloat8x16 ||
+        BT->getKind() == BuiltinType::MFloat8x8)
       return true;
   } else if (const VectorType *VT = Ty->getAs<VectorType>()) {
+    if (auto Kind = VT->getVectorKind();
+        Kind == VectorKind::SveFixedLengthData ||
+        Kind == VectorKind::SveFixedLengthPredicate)
+      return false;
+
     unsigned VecSize = getContext().getTypeSize(VT);
     if (VecSize == 64 || VecSize == 128)
       return true;
@@ -533,11 +662,166 @@ bool AArch64ABIInfo::isZeroLengthBitfieldPermittedInHomogeneousAggregate()
   return true;
 }
 
+// Check if a type needs to be passed in registers as a Pure Scalable Type (as
+// defined by AAPCS64). Return the number of data vectors and the number of
+// predicate vectors in the type, into `NVec` and `NPred`, respectively. Upon
+// return `CoerceToSeq` contains an expanded sequence of LLVM IR types, one
+// element for each non-composite member. For practical purposes, limit the
+// length of `CoerceToSeq` to about 12 (the maximum that could possibly fit
+// in registers) and return false, the effect of which will be to pass the
+// argument under the rules for a large (> 128 bytes) composite.
+bool AArch64ABIInfo::passAsPureScalableType(
+    QualType Ty, unsigned &NVec, unsigned &NPred,
+    SmallVectorImpl<llvm::Type *> &CoerceToSeq) const {
+  if (const ConstantArrayType *AT = getContext().getAsConstantArrayType(Ty)) {
+    uint64_t NElt = AT->getZExtSize();
+    if (NElt == 0)
+      return false;
+
+    unsigned NV = 0, NP = 0;
+    SmallVector<llvm::Type *> EltCoerceToSeq;
+    if (!passAsPureScalableType(AT->getElementType(), NV, NP, EltCoerceToSeq))
+      return false;
+
+    if (CoerceToSeq.size() + NElt * EltCoerceToSeq.size() > 12)
+      return false;
+
+    for (uint64_t I = 0; I < NElt; ++I)
+      llvm::copy(EltCoerceToSeq, std::back_inserter(CoerceToSeq));
+
+    NVec += NElt * NV;
+    NPred += NElt * NP;
+    return true;
+  }
+
+  if (const RecordType *RT = Ty->getAs<RecordType>()) {
+    // If the record cannot be passed in registers, then it's not a PST.
+    if (CGCXXABI::RecordArgABI RAA = getRecordArgABI(RT, getCXXABI());
+        RAA != CGCXXABI::RAA_Default)
+      return false;
+
+    // Pure scalable types are never unions and never contain unions.
+    const RecordDecl *RD = RT->getDecl();
+    if (RD->isUnion())
+      return false;
+
+    // If this is a C++ record, check the bases.
+    if (const CXXRecordDecl *CXXRD = dyn_cast<CXXRecordDecl>(RD)) {
+      for (const auto &I : CXXRD->bases()) {
+        if (isEmptyRecord(getContext(), I.getType(), true))
+          continue;
+        if (!passAsPureScalableType(I.getType(), NVec, NPred, CoerceToSeq))
+          return false;
+      }
+    }
+
+    // Check members.
+    for (const auto *FD : RD->fields()) {
+      QualType FT = FD->getType();
+      if (isEmptyField(getContext(), FD, /* AllowArrays */ true))
+        continue;
+      if (!passAsPureScalableType(FT, NVec, NPred, CoerceToSeq))
+        return false;
+    }
+
+    return true;
+  }
+
+  const auto *VT = Ty->getAs<VectorType>();
+  if (!VT)
+    return false;
+
+  if (VT->getVectorKind() == VectorKind::SveFixedLengthPredicate) {
+    ++NPred;
+    if (CoerceToSeq.size() + 1 > 12)
+      return false;
+    CoerceToSeq.push_back(convertFixedToScalableVectorType(VT));
+    return true;
+  }
+
+  if (VT->getVectorKind() == VectorKind::SveFixedLengthData) {
+    ++NVec;
+    if (CoerceToSeq.size() + 1 > 12)
+      return false;
+    CoerceToSeq.push_back(convertFixedToScalableVectorType(VT));
+    return true;
+  }
+
+  if (!VT->isBuiltinType())
+    return false;
+
+  switch (cast<BuiltinType>(VT)->getKind()) {
+#define SVE_VECTOR_TYPE(Name, MangledName, Id, SingletonId)                    \
+  case BuiltinType::Id:                                                        \
+    ++NVec;                                                                    \
+    break;
+#define SVE_PREDICATE_TYPE(Name, MangledName, Id, SingletonId)                 \
+  case BuiltinType::Id:                                                        \
+    ++NPred;                                                                   \
+    break;
+#define SVE_TYPE(Name, Id, SingletonId)
+#include "clang/Basic/AArch64SVEACLETypes.def"
+  default:
+    return false;
+  }
+
+  ASTContext::BuiltinVectorTypeInfo Info =
+      getContext().getBuiltinVectorTypeInfo(cast<BuiltinType>(Ty));
+  assert(Info.NumVectors > 0 && Info.NumVectors <= 4 &&
+         "Expected 1, 2, 3 or 4 vectors!");
+  auto VTy = llvm::ScalableVectorType::get(CGT.ConvertType(Info.ElementType),
+                                           Info.EC.getKnownMinValue());
+
+  if (CoerceToSeq.size() + Info.NumVectors > 12)
+    return false;
+  std::fill_n(std::back_inserter(CoerceToSeq), Info.NumVectors, VTy);
+
+  return true;
+}
+
+// Expand an LLVM IR type into a sequence with an element for each non-struct,
+// non-array member of the type, with the exception of the padding types, which
+// are retained.
+void AArch64ABIInfo::flattenType(
+    llvm::Type *Ty, SmallVectorImpl<llvm::Type *> &Flattened) const {
+
+  if (ABIArgInfo::isPaddingForCoerceAndExpand(Ty)) {
+    Flattened.push_back(Ty);
+    return;
+  }
+
+  if (const auto *AT = dyn_cast<llvm::ArrayType>(Ty)) {
+    uint64_t NElt = AT->getNumElements();
+    if (NElt == 0)
+      return;
+
+    SmallVector<llvm::Type *> EltFlattened;
+    flattenType(AT->getElementType(), EltFlattened);
+
+    for (uint64_t I = 0; I < NElt; ++I)
+      llvm::copy(EltFlattened, std::back_inserter(Flattened));
+    return;
+  }
+
+  if (const auto *ST = dyn_cast<llvm::StructType>(Ty)) {
+    for (auto *ET : ST->elements())
+      flattenType(ET, Flattened);
+    return;
+  }
+
+  Flattened.push_back(Ty);
+}
+
 RValue AArch64ABIInfo::EmitAAPCSVAArg(Address VAListAddr, QualType Ty,
                                       CodeGenFunction &CGF, AArch64ABIKind Kind,
                                       AggValueSlot Slot) const {
-  ABIArgInfo AI = classifyArgumentType(Ty, /*IsVariadic=*/true,
-                                       CGF.CurFnInfo->getCallingConvention());
+  // These numbers are not used for variadic arguments, hence it doesn't matter
+  // they don't retain their values across multiple calls to
+  // `classifyArgumentType` here.
+  unsigned NSRN = 0, NPRN = 0;
+  ABIArgInfo AI =
+      classifyArgumentType(Ty, /*IsVariadicFn=*/true, /* IsNamedArg */ false,
+                           CGF.CurFnInfo->getCallingConvention(), NSRN, NPRN);
   // Empty records are ignored for parameter passing purposes.
   if (AI.isIgnore())
     return Slot.asRValue();

diff  --git a/clang/test/CodeGen/aarch64-pure-scalable-args-empty-union.c b/clang/test/CodeGen/aarch64-pure-scalable-args-empty-union.c
new file mode 100644
index 00000000000000..546910068c78a2
--- /dev/null
+++ b/clang/test/CodeGen/aarch64-pure-scalable-args-empty-union.c
@@ -0,0 +1,39 @@
+// RUN: %clang_cc1        -O3 -triple aarch64 -target-feature +sve -mvscale-min=1 -mvscale-max=1 -emit-llvm -o - %s | FileCheck %s --check-prefixes=CHECK-C
+// RUN: %clang_cc1 -x c++ -O3 -triple aarch64 -target-feature +sve -mvscale-min=1 -mvscale-max=1 -emit-llvm -o - %s | FileCheck %s --check-prefixes=CHECK-CXX
+
+typedef __SVFloat32_t fvec32 __attribute__((arm_sve_vector_bits(128)));
+
+// PST containing an empty union: when compiled as C pass it in registers,
+// when compiled as C++ - in memory.
+typedef struct {
+  fvec32 x[4];
+  union {} u;
+} S0;
+
+#ifdef __cplusplus
+extern "C"
+#endif
+void use0(S0);
+
+void f0(S0 *p) {
+  use0(*p);
+}
+// CHECK-C:   declare void @use0(<vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float>)
+// CHECK-CXX: declare void @use0(ptr noundef)
+
+#ifdef __cplusplus
+
+// PST containing an empty union with `[[no_unique_address]]` - pass in registers.
+typedef struct {
+   fvec32 x[4];
+   [[no_unique_address]]
+   union {} u;
+} S1;
+
+extern "C" void use1(S1);
+void f1(S1 *p) {
+  use1(*p);
+}
+// CHECK-CXX: declare void @use1(<vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float>)
+
+#endif // __cplusplus

diff --git a/clang/test/CodeGen/aarch64-pure-scalable-args.c b/clang/test/CodeGen/aarch64-pure-scalable-args.c
new file mode 100644
index 00000000000000..851159ada76749
--- /dev/null
+++ b/clang/test/CodeGen/aarch64-pure-scalable-args.c
@@ -0,0 +1,461 @@
+// RUN: %clang_cc1 -O3 -triple aarch64                                  -target-feature +sve -target-feature +sve2p1 -mvscale-min=1 -mvscale-max=1 -emit-llvm -o - %s | FileCheck %s --check-prefixes=CHECK,CHECK-AAPCS
+// RUN: %clang_cc1 -O3 -triple arm64-apple-ios7.0 -target-abi darwinpcs -target-feature +sve -target-feature +sve2p1 -mvscale-min=1 -mvscale-max=1 -emit-llvm -o - %s | FileCheck %s --check-prefixes=CHECK,CHECK-DARWIN
+// RUN: %clang_cc1 -O3 -triple aarch64-linux-gnu                        -target-feature +sve -target-feature +sve2p1 -mvscale-min=1 -mvscale-max=1 -emit-llvm -o - %s | FileCheck %s --check-prefixes=CHECK,CHECK-AAPCS
+
+// REQUIRES: aarch64-registered-target
+
+#include <arm_neon.h>
+#include <arm_sve.h>
+#include <stdarg.h>
+
+typedef svfloat32_t fvec32 __attribute__((arm_sve_vector_bits(128)));
+typedef svfloat64_t fvec64 __attribute__((arm_sve_vector_bits(128)));
+typedef svbool_t bvec __attribute__((arm_sve_vector_bits(128)));
+typedef svmfloat8_t mfvec8 __attribute__((arm_sve_vector_bits(128)));
+
+typedef struct {
+    float f[4];
+} HFA;
+
+typedef struct {
+    mfloat8x16_t f[4];
+} HVA;
+
+// Pure Scalable Type, needs 4 Z-regs, 2 P-regs
+typedef struct {
+     bvec a;
+     fvec64 x;
+     fvec32 y[2];
+     mfvec8 z;
+     bvec b;
+} PST;
+
+// Pure Scalable Type, 1 Z-reg
+typedef struct {
+    fvec32 x;
+} SmallPST;
+
+// Big PST, does not fit in registers.
+typedef struct {
+    struct {
+        bvec a;
+        fvec32 x[4];
+    } u[2];
+    fvec64 v;
+} BigPST;
+
+// A small aggregate type
+typedef struct  {
+    char data[16];
+} SmallAgg;
+
+// CHECK: %struct.PST = type { <2 x i8>, <2 x double>, [2 x <4 x float>], <16 x i8>, <2 x i8> }
+
+// Test argument passing of Pure Scalable Types by examining the generated
+// LLVM IR function declarations. A PST argument in C/C++ should map to:
+//   a) a `ptr` argument, if passed indirectly through memory
+//   b) a series of scalable vector arguments, if passed via registers
+
+// Simple argument passing, PST expanded into registers.
+//   a    -> p0
+//   b    -> p1
+//   x    -> q0
+//   y[0] -> q1
+//   y[1] -> q2
+//   z    -> q3
+void test_argpass_simple(PST *p) {
+    void argpass_simple_callee(PST);
+    argpass_simple_callee(*p);
+}
+// CHECK-AAPCS:      define dso_local void @test_argpass_simple(ptr nocapture noundef readonly %p)
+// CHECK-AAPCS-NEXT: entry:
+// CHECK-AAPCS-NEXT: %0 = load <2 x i8>, ptr %p, align 16
+// CHECK-AAPCS-NEXT: %cast.scalable = tail call <vscale x 2 x i8> @llvm.vector.insert.nxv2i8.v2i8(<vscale x 2 x i8> undef, <2 x i8> %0, i64 0)
+// CHECK-AAPCS-NEXT: %1 = bitcast <vscale x 2 x i8> %cast.scalable to <vscale x 16 x i1>
+// CHECK-AAPCS-NEXT: %2 = getelementptr inbounds nuw i8, ptr %p, i64 16
+// CHECK-AAPCS-NEXT: %3 = load <2 x double>, ptr %2, align 16
+// CHECK-AAPCS-NEXT: %cast.scalable1 = tail call <vscale x 2 x double> @llvm.vector.insert.nxv2f64.v2f64(<vscale x 2 x double> undef, <2 x double> %3, i64 0)
+// CHECK-AAPCS-NEXT: %4 = getelementptr inbounds nuw i8, ptr %p, i64 32
+// CHECK-AAPCS-NEXT: %5 = load <4 x float>, ptr %4, align 16
+// CHECK-AAPCS-NEXT: %cast.scalable2 = tail call <vscale x 4 x float> @llvm.vector.insert.nxv4f32.v4f32(<vscale x 4 x float> undef, <4 x float> %5, i64 0)
+// CHECK-AAPCS-NEXT: %6 = getelementptr inbounds nuw i8, ptr %p, i64 48
+// CHECK-AAPCS-NEXT: %7 = load <4 x float>, ptr %6, align 16
+// CHECK-AAPCS-NEXT: %cast.scalable3 = tail call <vscale x 4 x float> @llvm.vector.insert.nxv4f32.v4f32(<vscale x 4 x float> undef, <4 x float> %7, i64 0)
+// CHECK-AAPCS-NEXT: %8 = getelementptr inbounds nuw i8, ptr %p, i64 64
+// CHECK-AAPCS-NEXT: %9 = load <16 x i8>, ptr %8, align 16
+// CHECK-AAPCS-NEXT: %cast.scalable4 = tail call <vscale x 16 x i8> @llvm.vector.insert.nxv16i8.v16i8(<vscale x 16 x i8> undef, <16 x i8> %9, i64 0)
+// CHECK-AAPCS-NEXT: %10 = getelementptr inbounds nuw i8, ptr %p, i64 80
+// CHECK-AAPCS-NEXT: %11 = load <2 x i8>, ptr %10, align 16
+// CHECK-AAPCS-NEXT: %cast.scalable5 = tail call <vscale x 2 x i8> @llvm.vector.insert.nxv2i8.v2i8(<vscale x 2 x i8> undef, <2 x i8> %11, i64 0)
+// CHECK-AAPCS-NEXT: %12 = bitcast <vscale x 2 x i8> %cast.scalable5 to <vscale x 16 x i1>
+// CHECK-AAPCS-NEXT: tail call void @argpass_simple_callee(<vscale x 16 x i1> %1, <vscale x 2 x double> %cast.scalable1, <vscale x 4 x float> %cast.scalable2, <vscale x 4 x float> %cast.scalable3, <vscale x 16 x i8> %cast.scalable4, <vscale x 16 x i1> %12)
+// CHECK-AAPCS-NEXT: ret void
+
+// CHECK-AAPCS:  declare void @argpass_simple_callee(<vscale x 16 x i1>, <vscale x 2 x double>, <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 16 x i8>, <vscale x 16 x i1>)
+// CHECK-DARWIN: declare void @argpass_simple_callee(ptr noundef)
+
+// Boundary case of using the last available Z-reg, PST expanded.
+//   0.0  -> d0-d3
+//   a    -> p0
+//   b    -> p1
+//   x    -> q4
+//   y[0] -> q5
+//   y[1] -> q6
+//   z    -> q7
+void test_argpass_last_z(PST *p) {
+    void argpass_last_z_callee(double, double, double, double, PST);
+    argpass_last_z_callee(.0, .0, .0, .0, *p);
+}
+// CHECK-AAPCS:  declare void @argpass_last_z_callee(double noundef, double noundef, double noundef, double noundef, <vscale x 16 x i1>, <vscale x 2 x double>, <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 16 x i8>, <vscale x 16 x i1>)
+// CHECK-DARWIN: declare void @argpass_last_z_callee(double noundef, double noundef, double noundef, double noundef, ptr noundef)
+
+
+// Like the above, but using a tuple type to occupy some registers.
+//   x    -> z0.d-z3.d
+//   a    -> p0
+//   b    -> p1
+//   x    -> q4
+//   y[0] -> q5
+//   y[1] -> q6
+//   z    -> q7
+void test_argpass_last_z_tuple(PST *p, svfloat64x4_t x) {
+  void argpass_last_z_tuple_callee(svfloat64x4_t, PST);
+  argpass_last_z_tuple_callee(x, *p);
+}
+// CHECK-AAPCS:  declare void @argpass_last_z_tuple_callee(<vscale x 2 x double>, <vscale x 2 x double>, <vscale x 2 x double>, <vscale x 2 x double>, <vscale x 16 x i1>, <vscale x 2 x double>, <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 16 x i8>, <vscale x 16 x i1>)
+// CHECK-DARWIN: declare void @argpass_last_z_tuple_callee(<vscale x 2 x double>, <vscale x 2 x double>, <vscale x 2 x double>, <vscale x 2 x double>, ptr noundef)
+
+
+// Boundary case of using the last available P-reg, PST expanded.
+//   false -> p0-p1
+//   a     -> p2
+//   b     -> p3
+//   x     -> q0
+//   y[0]  -> q1
+//   y[1]  -> q2
+//   z     -> q3
+void test_argpass_last_p(PST *p) {
+    void argpass_last_p_callee(svbool_t, svcount_t, PST);
+    argpass_last_p_callee(svpfalse(), svpfalse_c(), *p);
+}
+// CHECK-AAPCS:  declare void @argpass_last_p_callee(<vscale x 16 x i1>, target("aarch64.svcount"), <vscale x 16 x i1>, <vscale x 2 x double>, <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 16 x i8>, <vscale x 16 x i1>)
+// CHECK-DARWIN: declare void @argpass_last_p_callee(<vscale x 16 x i1>, target("aarch64.svcount"), ptr noundef)
+
+
+// Not enough Z-regs, push PST to memory and pass a pointer, Z-regs and
+// P-regs still available for other arguments
+//   u     -> z0
+//   v     -> q1
+//   w     -> q2
+//   0.0   -> d3-d4
+//   1     -> w0
+//   *p    -> memory, address -> x1
+//   2     -> w2
+//   3.0   -> d5
+//   true  -> p0
+void test_argpass_no_z(PST *p, double dummy, svmfloat8_t u, int8x16_t v, mfloat8x16_t w) {
+    void argpass_no_z_callee(svmfloat8_t, int8x16_t, mfloat8x16_t, double, double, int, PST, int, double, svbool_t);
+    argpass_no_z_callee(u, v, w, .0, .0, 1, *p, 2, 3.0, svptrue_b64());
+}
+// CHECK: declare void @argpass_no_z_callee(<vscale x 16 x i8>, <16 x i8> noundef, <16 x i8>, double noundef, double noundef, i32 noundef, ptr noundef, i32 noundef, double noundef, <vscale x 16 x i1>)
+
+
+// Like the above, using a tuple to occupy some registers.
+//   x     -> z0.d-z3.d
+//   0.0   -> d4
+//   1     -> w0
+//   *p    -> memory, address -> x1
+//   2     -> w2
+//   3.0   -> d5
+//   true  -> p0
+void test_argpass_no_z_tuple_f64(PST *p, float dummy, svfloat64x4_t x) {
+  void argpass_no_z_tuple_f64_callee(svfloat64x4_t, double, int, PST, int,
+                                     double, svbool_t);
+  argpass_no_z_tuple_f64_callee(x, .0, 1, *p, 2, 3.0, svptrue_b64());
+}
+// CHECK: declare void @argpass_no_z_tuple_f64_callee(<vscale x 2 x double>, <vscale x 2 x double>, <vscale x 2 x double>, <vscale x 2 x double>, double noundef, i32 noundef, ptr noundef, i32 noundef, double noundef, <vscale x 16 x i1>)
+
+
+// Likewise, using a different tuple.
+//   x     -> z0.d-z3.d
+//   0.0   -> d4
+//   1     -> w0
+//   *p    -> memory, address -> x1
+//   2     -> w2
+//   3.0   -> d5
+//   true  -> p0
+void test_argpass_no_z_tuple_mfp8(PST *p, float dummy, svmfloat8x4_t x) {
+  void argpass_no_z_tuple_mfp8_callee(svmfloat8x4_t, double, int, PST, int,
+                                      double, svbool_t);
+  argpass_no_z_tuple_mfp8_callee(x, .0, 1, *p, 2, 3.0, svptrue_b64());
+}
+// CHECK: declare void @argpass_no_z_tuple_mfp8_callee(<vscale x 16 x i8>, <vscale x 16 x i8>, <vscale x 16 x i8>, <vscale x 16 x i8>, double noundef, i32 noundef, ptr noundef, i32 noundef, double noundef, <vscale x 16 x i1>)
+
+
+// Not enough Z-regs (consumed by a HFA), PST passed indirectly
+//   0.0  -> d0
+//   *h   -> s1-s4
+//   1    -> w0
+//   *p   -> memory, address -> x1
+//   p    -> x1
+//   2    -> w2
+//   true -> p0
+void test_argpass_no_z_hfa(HFA *h, PST *p) {
+    void argpass_no_z_hfa_callee(double, HFA, int, PST, int, svbool_t);
+    argpass_no_z_hfa_callee(.0, *h, 1, *p, 2, svptrue_b64());
+}
+// CHECK-AAPCS:  declare void @argpass_no_z_hfa_callee(double noundef, [4 x float] alignstack(8), i32 noundef, ptr noundef, i32 noundef, <vscale x 16 x i1>)
+// CHECK-DARWIN: declare void @argpass_no_z_hfa_callee(double noundef, [4 x float], i32 noundef, ptr noundef, i32 noundef, <vscale x 16 x i1>)
+
+// Not enough Z-regs (consumed by a HVA), PST passed indirectly
+//   0.0  -> d0
+//   *h   -> q1-q4
+//   1    -> w0
+//   *p   -> memory, address -> x1
+//   p    -> x1
+//   2    -> w2
+//   true -> p0
+void test_argpass_no_z_hva(HVA *h, PST *p) {
+    void argpass_no_z_hva_callee(double, HVA, int, PST, int, svbool_t);
+    argpass_no_z_hva_callee(.0, *h, 1, *p, 2, svptrue_b64());
+}
+// CHECK-AAPCS:  declare void @argpass_no_z_hva_callee(double noundef, [4 x <16 x i8>] alignstack(16), i32 noundef, ptr noundef, i32 noundef, <vscale x 16 x i1>)
+// CHECK-DARWIN: declare void @argpass_no_z_hva_callee(double noundef, [4 x <16 x i8>], i32 noundef, ptr noundef, i32 noundef, <vscale x 16 x i1>)
+
+// Not enough P-regs, PST passed indirectly, Z-regs and P-regs still available.
+//   true -> p0-p2
+//   1    -> w0
+//   *p   -> memory, address -> x1
+//   2    -> w2
+//   3.0  -> d0
+//   true -> p3
+void test_argpass_no_p(PST *p) {
+    void argpass_no_p_callee(svbool_t, svbool_t, svbool_t, int, PST, int, double, svbool_t);
+    argpass_no_p_callee(svptrue_b8(), svptrue_b16(), svptrue_b32(), 1, *p, 2, 3.0, svptrue_b64());
+}
+// CHECK: declare void @argpass_no_p_callee(<vscale x 16 x i1>, <vscale x 16 x i1>, <vscale x 16 x i1>, i32 noundef, ptr noundef, i32 noundef, double noundef, <vscale x 16 x i1>)
+
+
+// Like above, using a tuple to occupy some registers.
+// P-regs still available.
+//   v    -> p0-p1
+//   u    -> p2
+//   1    -> w0
+//   *p   -> memory, address -> x1
+//   2    -> w2
+//   3.0  -> d0
+//   true -> p3
+void test_argpass_no_p_tuple(PST *p, svbool_t u, svboolx2_t v) {
+  void argpass_no_p_tuple_callee(svboolx2_t, svbool_t, int, PST, int, double,
+                                 svbool_t);
+  argpass_no_p_tuple_callee(v, u, 1, *p, 2, 3.0, svptrue_b64());
+}
+// CHECK: declare void @argpass_no_p_tuple_callee(<vscale x 16 x i1>, <vscale x 16 x i1>, <vscale x 16 x i1>, i32 noundef, ptr noundef, i32 noundef, double noundef, <vscale x 16 x i1>)
+
+
+// HFAs go back-to-back to memory, afterwards Z-regs not available, PST passed indirectly.
+//   0.0   -> d0-d4
+//   *h    -> memory
+//   *p    -> memory, address -> x0
+//   *h    -> memory
+//   false -> p0
+void test_after_hfa(HFA *h, PST *p) {
+    void after_hfa_callee(double, double, double, double, double, HFA, PST, HFA, svbool_t);
+    after_hfa_callee(.0, .0, .0, .0, .0, *h, *p, *h, svpfalse());
+}
+// CHECK-AAPCS:  declare void @after_hfa_callee(double noundef, double noundef, double noundef, double noundef, double noundef, [4 x float] alignstack(8), ptr noundef, [4 x float] alignstack(8), <vscale x 16 x i1>)
+// CHECK-DARWIN: declare void @after_hfa_callee(double noundef, double noundef, double noundef, double noundef, double noundef, [4 x float], ptr noundef, [4 x float], <vscale x 16 x i1>)
+
+// Small PST, not enough registers, passed indirectly, unlike other small
+// aggregates.
+//   *s  -> x0-x1
+//   0.0 -> d0-d7
+//   *p  -> memory, address -> x2
+//   1.0 -> memory
+//   2.0 -> memory (next to the above)
+void test_small_pst(SmallPST *p, SmallAgg *s) {
+    void small_pst_callee(SmallAgg, double, double, double, double, double, double, double, double, double, SmallPST, double);
+    small_pst_callee(*s, .0, .0, .0, .0, .0, .0, .0, .0, 1.0, *p, 2.0);
+}
+// CHECK-AAPCS:  declare void @small_pst_callee([2 x i64], double noundef, double noundef, double noundef, double noundef, double noundef, double noundef, double noundef, double noundef, double noundef, ptr noundef, double noundef)
+// CHECK-DARWIN: declare void @small_pst_callee([2 x i64], double noundef, double noundef, double noundef, double noundef, double noundef, double noundef, double noundef, double noundef, double noundef, i128, double noundef)
+
+
+// Simple return, PST expanded to registers
+//   p->a    -> p0
+//   p->x    -> q0
+//   p->y[0] -> q1
+//   p->y[1] -> q2
+//   p->z    -> q3
+//   p->b    -> p1
+PST test_return(PST *p) {
+    return *p;
+}
+// CHECK-AAPCS:  define dso_local <{ <vscale x 16 x i1>, <vscale x 2 x double>, <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 16 x i8>, <vscale x 16 x i1> }> @test_return(ptr
+// CHECK-DARWIN: define void @test_return(ptr dead_on_unwind noalias nocapture writable writeonly sret(%struct.PST) align 16 %agg.result, ptr nocapture noundef readonly %p)
+
+// Corner case of 1-element aggregate
+//   p->x -> q0
+SmallPST test_return_small_pst(SmallPST *p) {
+    return *p;
+}
+// CHECK-AAPCS:  define dso_local <vscale x 4 x float> @test_return_small_pst(ptr
+// CHECK-DARWIN: define i128 @test_return_small_pst(ptr nocapture noundef readonly %p)
+
+
+// Big PST, returned indirectly
+//   *p -> *x8
+BigPST test_return_big_pst(BigPST *p) {
+    return *p;
+}
+// CHECK-AAPCS:  define dso_local void @test_return_big_pst(ptr dead_on_unwind noalias nocapture writable writeonly sret(%struct.BigPST) align 16 %agg.result, ptr nocapture noundef readonly %p)
+// CHECK-DARWIN: define void @test_return_big_pst(ptr dead_on_unwind noalias nocapture writable writeonly sret(%struct.BigPST) align 16 %agg.result, ptr nocapture noundef readonly %p)
+
+// Variadic arguments are unnamed, PST passed indirectly.
+// (Passing SVE types to a variadic function is currently unsupported by
+// the AArch64 backend)
+//   p->a    -> p0
+//   p->x    -> q0
+//   p->y[0] -> q1
+//   p->y[1] -> q2
+//   p->z    -> q3
+//   p->b    -> p1
+//   *q -> memory, address -> x1
+void test_pass_variadic(PST *p, PST *q) {
+    void pass_variadic_callee(PST, ...);
+    pass_variadic_callee(*p, *q);
+}
+// CHECK-AAPCS: call void @llvm.memcpy.p0.p0.i64(ptr noundef nonnull align 16 dereferenceable(96) %byval-temp, ptr noundef nonnull align 16 dereferenceable(96) %q, i64 96, i1 false)
+// CHECK-AAPCS: call void (<vscale x 16 x i1>, <vscale x 2 x double>, <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 16 x i8>, <vscale x 16 x i1>, ...) @pass_variadic_callee(<vscale x 16 x i1> %1, <vscale x 2 x double> %cast.scalable1, <vscale x 4 x float> %cast.scalable2, <vscale x 4 x float> %cast.scalable3, <vscale x 16 x i8> %cast.scalable4, <vscale x 16 x i1> %12, ptr noundef nonnull %byval-temp)
+
+// CHECK-DARWIN: call void @llvm.memcpy.p0.p0.i64(ptr noundef nonnull align 16 dereferenceable(96) %byval-temp, ptr noundef nonnull align 16 dereferenceable(96) %p, i64 96, i1 false)
+// CHECK-DARWIN: call void @llvm.lifetime.start.p0(i64 96, ptr nonnull %byval-temp1)
+// CHECK-DARWIN: call void @llvm.memcpy.p0.p0.i64(ptr noundef nonnull align 16 dereferenceable(96) %byval-temp1, ptr noundef nonnull align 16 dereferenceable(96) %q, i64 96, i1 false)
+// CHECK-DARWIN: call void (ptr, ...) @pass_variadic_callee(ptr noundef nonnull %byval-temp, ptr noundef nonnull %byval-temp1)
+
+
+// Test passing a small PST, still passed indirectly, despite being <= 128 bits
+void test_small_pst_variadic(SmallPST *p) {
+    void small_pst_variadic_callee(int, ...);
+    small_pst_variadic_callee(0, *p);
+}
+// CHECK-AAPCS: call void @llvm.memcpy.p0.p0.i64(ptr noundef nonnull align 16 dereferenceable(16) %byval-temp, ptr noundef nonnull align 16 dereferenceable(16) %p, i64 16, i1 false)
+// CHECK-AAPCS: call void (i32, ...) @small_pst_variadic_callee(i32 noundef 0, ptr noundef nonnull %byval-temp)
+
+// CHECK-DARWIN: %0 = load i128, ptr %p, align 16
+// CHECK-DARWIN: tail call void (i32, ...) @small_pst_variadic_callee(i32 noundef 0, i128 %0)
+
+// Test handling of a PST argument when passed in registers, from the callee side.
+void test_argpass_callee_side(PST v) {
+    void use(PST *p);
+    use(&v);
+}
+// CHECK-AAPCS:      define dso_local void @test_argpass_callee_side(<vscale x 16 x i1> %0, <vscale x 2 x double> %.coerce1, <vscale x 4 x float> %.coerce3, <vscale x 4 x float> %.coerce5, <vscale x 16 x i8> %.coerce7, <vscale x 16 x i1> %1)
+// CHECK-AAPCS-NEXT: entry:
+// CHECK-AAPCS-NEXT:   %v = alloca %struct.PST, align 16
+// CHECK-AAPCS-NEXT:   %.coerce = bitcast <vscale x 16 x i1> %0 to <vscale x 2 x i8>
+// CHECK-AAPCS-NEXT:   %cast.fixed = tail call <2 x i8> @llvm.vector.extract.v2i8.nxv2i8(<vscale x 2 x i8> %.coerce, i64 0)
+// CHECK-AAPCS-NEXT:   store <2 x i8> %cast.fixed, ptr %v, align 16
+// CHECK-AAPCS-NEXT:   %2 = getelementptr inbounds nuw i8, ptr %v, i64 16
+// CHECK-AAPCS-NEXT:   %cast.fixed2 = tail call <2 x double> @llvm.vector.extract.v2f64.nxv2f64(<vscale x 2 x double> %.coerce1, i64 0)
+// CHECK-AAPCS-NEXT:   store <2 x double> %cast.fixed2, ptr %2, align 16
+// CHECK-AAPCS-NEXT:   %3 = getelementptr inbounds nuw i8, ptr %v, i64 32
+// CHECK-AAPCS-NEXT:   %cast.fixed4 = tail call <4 x float> @llvm.vector.extract.v4f32.nxv4f32(<vscale x 4 x float> %.coerce3, i64 0)
+// CHECK-AAPCS-NEXT:   store <4 x float> %cast.fixed4, ptr %3, align 16
+// CHECK-AAPCS-NEXT:   %4 = getelementptr inbounds nuw i8, ptr %v, i64 48
+// CHECK-AAPCS-NEXT:   %cast.fixed6 = tail call <4 x float> @llvm.vector.extract.v4f32.nxv4f32(<vscale x 4 x float> %.coerce5, i64 0)
+// CHECK-AAPCS-NEXT:   store <4 x float> %cast.fixed6, ptr %4, align 16
+// CHECK-AAPCS-NEXT:   %5 = getelementptr inbounds nuw i8, ptr %v, i64 64
+// CHECK-AAPCS-NEXT:   %cast.fixed8 = tail call <16 x i8> @llvm.vector.extract.v16i8.nxv16i8(<vscale x 16 x i8> %.coerce7, i64 0)
+// CHECK-AAPCS-NEXT:   store <16 x i8> %cast.fixed8, ptr %5, align 16
+// CHECK-AAPCS-NEXT:   %6 = getelementptr inbounds nuw i8, ptr %v, i64 80
+// CHECK-AAPCS-NEXT:   %.coerce9 = bitcast <vscale x 16 x i1> %1 to <vscale x 2 x i8>
+// CHECK-AAPCS-NEXT:   %cast.fixed10 = tail call <2 x i8> @llvm.vector.extract.v2i8.nxv2i8(<vscale x 2 x i8> %.coerce9, i64 0)
+// CHECK-AAPCS-NEXT:   store <2 x i8> %cast.fixed10, ptr %6, align 16
+// CHECK-AAPCS-NEXT:   call void @use(ptr noundef nonnull %v)
+// CHECK-AAPCS-NEXT:   ret void
+// CHECK-AAPCS-NEXT: }
+
+// Test va_arg operation
+#ifdef __cplusplus
+ extern "C"
+#endif
+void test_va_arg(int n, ...) {
+     va_list ap;
+     va_start(ap, n);  
+     PST v = va_arg(ap, PST);
+     va_end(ap);
+
+     void use1(bvec, fvec32);
+     use1(v.a, v.y[1]);
+}
+// CHECK-AAPCS: define dso_local void @test_va_arg(i32 noundef %n, ...)
+// CHECK-AAPCS-NEXT: entry:
+// CHECK-AAPCS-NEXT:   %ap = alloca %struct.__va_list, align 8
+// CHECK-AAPCS-NEXT:   call void @llvm.lifetime.start.p0(i64 32, ptr nonnull %ap)
+// CHECK-AAPCS-NEXT:   call void @llvm.va_start.p0(ptr nonnull %ap)
+// CHECK-AAPCS-NEXT:   %gr_offs_p = getelementptr inbounds nuw i8, ptr %ap, i64 24
+// CHECK-AAPCS-NEXT:   %gr_offs = load i32, ptr %gr_offs_p, align 8
+// CHECK-AAPCS-NEXT:   %0 = icmp sgt i32 %gr_offs, -1
+// CHECK-AAPCS-NEXT:   br i1 %0, label %vaarg.on_stack, label %vaarg.maybe_reg
+// CHECK-AAPCS-EMPTY:
+// CHECK-AAPCS-NEXT: vaarg.maybe_reg:                                  ; preds = %entry
+
+// Increment by 8, size of the pointer to the argument value, not size of the argument value itself.
+
+// CHECK-AAPCS-NEXT:   %new_reg_offs = add nsw i32 %gr_offs, 8
+// CHECK-AAPCS-NEXT:   store i32 %new_reg_offs, ptr %gr_offs_p, align 8
+// CHECK-AAPCS-NEXT:   %inreg = icmp ult i32 %gr_offs, -7
+// CHECK-AAPCS-NEXT:   br i1 %inreg, label %vaarg.in_reg, label %vaarg.on_stack
+// CHECK-AAPCS-EMPTY:
+// CHECK-AAPCS-NEXT: vaarg.in_reg:                                     ; preds = %vaarg.maybe_reg
+// CHECK-AAPCS-NEXT:   %reg_top_p = getelementptr inbounds nuw i8, ptr %ap, i64 8
+// CHECK-AAPCS-NEXT:   %reg_top = load ptr, ptr %reg_top_p, align 8
+// CHECK-AAPCS-NEXT:   %1 = sext i32 %gr_offs to i64
+// CHECK-AAPCS-NEXT:   %2 = getelementptr inbounds i8, ptr %reg_top, i64 %1
+// CHECK-AAPCS-NEXT:   br label %vaarg.end
+// CHECK-AAPCS-EMPTY:
+// CHECK-AAPCS-NEXT: vaarg.on_stack:                                   ; preds = %vaarg.maybe_reg, %entry
+// CHECK-AAPCS-NEXT:   %stack = load ptr, ptr %ap, align 8
+// CHECK-AAPCS-NEXT:   %new_stack = getelementptr inbounds i8, ptr %stack, i64 8
+// CHECK-AAPCS-NEXT:   store ptr %new_stack, ptr %ap, align 8
+// CHECK-AAPCS-NEXT:   br label %vaarg.end
+// CHECK-AAPCS-EMPTY:
+// CHECK-AAPCS-NEXT: vaarg.end:                                        ; preds = %vaarg.on_stack, %vaarg.in_reg
+// CHECK-AAPCS-NEXT:   %vaargs.addr = phi ptr [ %2, %vaarg.in_reg ], [ %stack, %vaarg.on_stack ]
+
+// Extra indirection, for a composite passed indirectly.
+// CHECK-AAPCS-NEXT:   %vaarg.addr = load ptr, ptr %vaargs.addr, align 8
+
+// CHECK-AAPCS-NEXT:   %v.sroa.0.0.copyload = load <2 x i8>, ptr %vaarg.addr, align 16
+// CHECK-AAPCS-NEXT:   %v.sroa.43.0.vaarg.addr.sroa_idx = getelementptr inbounds i8, ptr %vaarg.addr, i64 48
+// CHECK-AAPCS-NEXT:   %v.sroa.43.0.copyload = load <4 x float>, ptr %v.sroa.43.0.vaarg.addr.sroa_idx, align 16
+// CHECK-AAPCS-NEXT:   call void @llvm.va_end.p0(ptr nonnull %ap)
+// CHECK-AAPCS-NEXT:   %cast.scalable = call <vscale x 2 x i8> @llvm.vector.insert.nxv2i8.v2i8(<vscale x 2 x i8> undef, <2 x i8> %v.sroa.0.0.copyload, i64 0)
+// CHECK-AAPCS-NEXT:   %3 = bitcast <vscale x 2 x i8> %cast.scalable to <vscale x 16 x i1>
+// CHECK-AAPCS-NEXT:   %cast.scalable2 = call <vscale x 4 x float> @llvm.vector.insert.nxv4f32.v4f32(<vscale x 4 x float> undef, <4 x float> %v.sroa.43.0.copyload, i64 0)
+// CHECK-AAPCS-NEXT:   call void @use1(<vscale x 16 x i1> noundef %3, <vscale x 4 x float> noundef %cast.scalable2)
+// CHECK-AAPCS-NEXT:   call void @llvm.lifetime.end.p0(i64 32, ptr nonnull %ap)
+// CHECK-AAPCS-NEXT:   ret void
+// CHECK-AAPCS-NEXT: }
+
+// CHECK-DARWIN: define void @test_va_arg(i32 noundef %n, ...)
+// CHECK-DARWIN-NEXT: entry:
+// CHECK-DARWIN-NEXT:   %ap = alloca ptr, align 8
+// CHECK-DARWIN-NEXT:   call void @llvm.lifetime.start.p0(i64 8, ptr nonnull %ap)
+// CHECK-DARWIN-NEXT:   call void @llvm.va_start.p0(ptr nonnull %ap)
+// CHECK-DARWIN-NEXT:   %argp.cur = load ptr, ptr %ap, align 8
+// CHECK-DARWIN-NEXT:   %argp.next = getelementptr inbounds i8, ptr %argp.cur, i64 8
+// CHECK-DARWIN-NEXT:   store ptr %argp.next, ptr %ap, align 8
+// CHECK-DARWIN-NEXT:   %0 = load ptr, ptr %argp.cur, align 8
+// CHECK-DARWIN-NEXT:   %v.sroa.0.0.copyload = load <2 x i8>, ptr %0, align 16
+// CHECK-DARWIN-NEXT:   %v.sroa.43.0..sroa_idx = getelementptr inbounds i8, ptr %0, i64 48
+// CHECK-DARWIN-NEXT:   %v.sroa.43.0.copyload = load <4 x float>, ptr %v.sroa.43.0..sroa_idx, align 16
+// CHECK-DARWIN-NEXT:   call void @llvm.va_end.p0(ptr nonnull %ap)
+// CHECK-DARWIN-NEXT:   %cast.scalable = call <vscale x 2 x i8> @llvm.vector.insert.nxv2i8.v2i8(<vscale x 2 x i8> undef, <2 x i8> %v.sroa.0.0.copyload, i64 0)
+// CHECK-DARWIN-NEXT:   %1 = bitcast <vscale x 2 x i8> %cast.scalable to <vscale x 16 x i1>
+// CHECK-DARWIN-NEXT:   %cast.scalable2 = call <vscale x 4 x float> @llvm.vector.insert.nxv4f32.v4f32(<vscale x 4 x float> undef, <4 x float> %v.sroa.43.0.copyload, i64 0)
+// CHECK-DARWIN-NEXT:   call void @use1(<vscale x 16 x i1> noundef %1, <vscale x 4 x float> noundef %cast.scalable2)
+// CHECK-DARWIN-NEXT:   call void @llvm.lifetime.end.p0(i64 8, ptr nonnull %ap)
+// CHECK-DARWIN-NEXT:   ret void
+// CHECK-DARWIN-NEXT: }


        


More information about the cfe-commits mailing list