[llvm] [RISCV] Handle RVV return type in calling convention correctly (PR #87736)

Brandon Wu via llvm-commits llvm-commits at lists.llvm.org
Fri Apr 12 06:28:11 PDT 2024


https://github.com/4vtomat updated https://github.com/llvm/llvm-project/pull/87736

From 5258b46479a2e612623072e1b8d7dd00780203e4 Mon Sep 17 00:00:00 2001
From: Brandon Wu <brandon.wu at sifive.com>
Date: Fri, 12 Apr 2024 06:19:38 -0700
Subject: [PATCH 1/2] Recommit [RISCV] RISCV vector calling convention (2/2)
 (#79096)

---
 .../Target/RISCV/GISel/RISCVCallLowering.cpp  |  55 +++---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp   | 180 ++++++++++++++----
 llvm/lib/Target/RISCV/RISCVISelLowering.h     |  56 +++++-
 llvm/test/CodeGen/RISCV/rvv/calling-conv.ll   |  87 +++++++++
 .../RISCV/rvv/vector-deinterleave-load.ll     |   6 +-
 .../CodeGen/RISCV/rvv/vector-deinterleave.ll  |  19 +-
 6 files changed, 320 insertions(+), 83 deletions(-)

diff --git a/llvm/lib/Target/RISCV/GISel/RISCVCallLowering.cpp b/llvm/lib/Target/RISCV/GISel/RISCVCallLowering.cpp
index 45e19cdea300b1..8af4bc658409d4 100644
--- a/llvm/lib/Target/RISCV/GISel/RISCVCallLowering.cpp
+++ b/llvm/lib/Target/RISCV/GISel/RISCVCallLowering.cpp
@@ -34,14 +34,15 @@ struct RISCVOutgoingValueAssigner : public CallLowering::OutgoingValueAssigner {
   // Whether this is assigning args for a return.
   bool IsRet;
 
-  // true if assignArg has been called for a mask argument, false otherwise.
-  bool AssignedFirstMaskArg = false;
+  RVVArgDispatcher &RVVDispatcher;
 
 public:
   RISCVOutgoingValueAssigner(
-      RISCVTargetLowering::RISCVCCAssignFn *RISCVAssignFn_, bool IsRet)
+      RISCVTargetLowering::RISCVCCAssignFn *RISCVAssignFn_, bool IsRet,
+      RVVArgDispatcher &RVVDispatcher)
       : CallLowering::OutgoingValueAssigner(nullptr),
-        RISCVAssignFn(RISCVAssignFn_), IsRet(IsRet) {}
+        RISCVAssignFn(RISCVAssignFn_), IsRet(IsRet),
+        RVVDispatcher(RVVDispatcher) {}
 
   bool assignArg(unsigned ValNo, EVT OrigVT, MVT ValVT, MVT LocVT,
                  CCValAssign::LocInfo LocInfo,
@@ -51,16 +52,9 @@ struct RISCVOutgoingValueAssigner : public CallLowering::OutgoingValueAssigner {
     const DataLayout &DL = MF.getDataLayout();
     const RISCVSubtarget &Subtarget = MF.getSubtarget<RISCVSubtarget>();
 
-    std::optional<unsigned> FirstMaskArgument;
-    if (Subtarget.hasVInstructions() && !AssignedFirstMaskArg &&
-        ValVT.isVector() && ValVT.getVectorElementType() == MVT::i1) {
-      FirstMaskArgument = ValNo;
-      AssignedFirstMaskArg = true;
-    }
-
     if (RISCVAssignFn(DL, Subtarget.getTargetABI(), ValNo, ValVT, LocVT,
                       LocInfo, Flags, State, Info.IsFixed, IsRet, Info.Ty,
-                      *Subtarget.getTargetLowering(), FirstMaskArgument))
+                      *Subtarget.getTargetLowering(), RVVDispatcher))
       return true;
 
     StackSize = State.getStackSize();
@@ -181,14 +175,15 @@ struct RISCVIncomingValueAssigner : public CallLowering::IncomingValueAssigner {
   // Whether this is assigning args from a return.
   bool IsRet;
 
-  // true if assignArg has been called for a mask argument, false otherwise.
-  bool AssignedFirstMaskArg = false;
+  RVVArgDispatcher &RVVDispatcher;
 
 public:
   RISCVIncomingValueAssigner(
-      RISCVTargetLowering::RISCVCCAssignFn *RISCVAssignFn_, bool IsRet)
+      RISCVTargetLowering::RISCVCCAssignFn *RISCVAssignFn_, bool IsRet,
+      RVVArgDispatcher &RVVDispatcher)
       : CallLowering::IncomingValueAssigner(nullptr),
-        RISCVAssignFn(RISCVAssignFn_), IsRet(IsRet) {}
+        RISCVAssignFn(RISCVAssignFn_), IsRet(IsRet),
+        RVVDispatcher(RVVDispatcher) {}
 
   bool assignArg(unsigned ValNo, EVT OrigVT, MVT ValVT, MVT LocVT,
                  CCValAssign::LocInfo LocInfo,
@@ -201,16 +196,9 @@ struct RISCVIncomingValueAssigner : public CallLowering::IncomingValueAssigner {
     if (LocVT.isScalableVector())
       MF.getInfo<RISCVMachineFunctionInfo>()->setIsVectorCall();
 
-    std::optional<unsigned> FirstMaskArgument;
-    if (Subtarget.hasVInstructions() && !AssignedFirstMaskArg &&
-        ValVT.isVector() && ValVT.getVectorElementType() == MVT::i1) {
-      FirstMaskArgument = ValNo;
-      AssignedFirstMaskArg = true;
-    }
-
     if (RISCVAssignFn(DL, Subtarget.getTargetABI(), ValNo, ValVT, LocVT,
                       LocInfo, Flags, State, /*IsFixed=*/true, IsRet, Info.Ty,
-                      *Subtarget.getTargetLowering(), FirstMaskArgument))
+                      *Subtarget.getTargetLowering(), RVVDispatcher))
       return true;
 
     StackSize = State.getStackSize();
@@ -420,9 +408,11 @@ bool RISCVCallLowering::lowerReturnVal(MachineIRBuilder &MIRBuilder,
   SmallVector<ArgInfo, 4> SplitRetInfos;
   splitToValueTypes(OrigRetInfo, SplitRetInfos, DL, CC);
 
+  RVVArgDispatcher Dispatcher{&MF, getTLI<RISCVTargetLowering>(),
+                              F.getReturnType()};
   RISCVOutgoingValueAssigner Assigner(
       CC == CallingConv::Fast ? RISCV::CC_RISCV_FastCC : RISCV::CC_RISCV,
-      /*IsRet=*/true);
+      /*IsRet=*/true, Dispatcher);
   RISCVOutgoingValueHandler Handler(MIRBuilder, MF.getRegInfo(), Ret);
   return determineAndHandleAssignments(Handler, Assigner, SplitRetInfos,
                                        MIRBuilder, CC, F.isVarArg());
@@ -531,6 +521,7 @@ bool RISCVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
   CallingConv::ID CC = F.getCallingConv();
 
   SmallVector<ArgInfo, 32> SplitArgInfos;
+  SmallVector<Type *, 4> TypeList;
   unsigned Index = 0;
   for (auto &Arg : F.args()) {
     // Construct the ArgInfo object from destination register and argument type.
@@ -542,12 +533,15 @@ bool RISCVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
     // correspondingly and appended to SplitArgInfos.
     splitToValueTypes(AInfo, SplitArgInfos, DL, CC);
 
+    TypeList.push_back(Arg.getType());
+
     ++Index;
   }
 
+  RVVArgDispatcher Dispatcher{&MF, getTLI<RISCVTargetLowering>(), TypeList};
   RISCVIncomingValueAssigner Assigner(
       CC == CallingConv::Fast ? RISCV::CC_RISCV_FastCC : RISCV::CC_RISCV,
-      /*IsRet=*/false);
+      /*IsRet=*/false, Dispatcher);
   RISCVFormalArgHandler Handler(MIRBuilder, MF.getRegInfo());
 
   SmallVector<CCValAssign, 16> ArgLocs;
@@ -585,11 +579,13 @@ bool RISCVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
 
   SmallVector<ArgInfo, 32> SplitArgInfos;
   SmallVector<ISD::OutputArg, 8> Outs;
+  SmallVector<Type *, 4> TypeList;
   for (auto &AInfo : Info.OrigArgs) {
     // Handle any required unmerging of split value types from a given VReg into
     // physical registers. ArgInfo objects are constructed correspondingly and
     // appended to SplitArgInfos.
     splitToValueTypes(AInfo, SplitArgInfos, DL, CC);
+    TypeList.push_back(AInfo.Ty);
   }
 
   // TODO: Support tail calls.
@@ -607,9 +603,10 @@ bool RISCVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
   const TargetRegisterInfo *TRI = Subtarget.getRegisterInfo();
   Call.addRegMask(TRI->getCallPreservedMask(MF, Info.CallConv));
 
+  RVVArgDispatcher ArgDispatcher{&MF, getTLI<RISCVTargetLowering>(), TypeList};
   RISCVOutgoingValueAssigner ArgAssigner(
       CC == CallingConv::Fast ? RISCV::CC_RISCV_FastCC : RISCV::CC_RISCV,
-      /*IsRet=*/false);
+      /*IsRet=*/false, ArgDispatcher);
   RISCVOutgoingValueHandler ArgHandler(MIRBuilder, MF.getRegInfo(), Call);
   if (!determineAndHandleAssignments(ArgHandler, ArgAssigner, SplitArgInfos,
                                      MIRBuilder, CC, Info.IsVarArg))
@@ -637,9 +634,11 @@ bool RISCVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
   SmallVector<ArgInfo, 4> SplitRetInfos;
   splitToValueTypes(Info.OrigRet, SplitRetInfos, DL, CC);
 
+  RVVArgDispatcher RetDispatcher{&MF, getTLI<RISCVTargetLowering>(),
+                                 F.getReturnType()};
   RISCVIncomingValueAssigner RetAssigner(
       CC == CallingConv::Fast ? RISCV::CC_RISCV_FastCC : RISCV::CC_RISCV,
-      /*IsRet=*/true);
+      /*IsRet=*/true, RetDispatcher);
   RISCVCallReturnHandler RetHandler(MIRBuilder, MF.getRegInfo(), Call);
   if (!determineAndHandleAssignments(RetHandler, RetAssigner, SplitRetInfos,
                                      MIRBuilder, CC, Info.IsVarArg))
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 5a572002091ff3..3e7bc8c2367de6 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -22,6 +22,7 @@
 #include "llvm/ADT/Statistic.h"
 #include "llvm/Analysis/MemoryLocation.h"
 #include "llvm/Analysis/VectorUtils.h"
+#include "llvm/CodeGen/Analysis.h"
 #include "llvm/CodeGen/MachineFrameInfo.h"
 #include "llvm/CodeGen/MachineFunction.h"
 #include "llvm/CodeGen/MachineInstrBuilder.h"
@@ -18078,33 +18079,12 @@ static bool CC_RISCVAssign2XLen(unsigned XLen, CCState &State, CCValAssign VA1,
   return false;
 }
 
-static unsigned allocateRVVReg(MVT ValVT, unsigned ValNo,
-                               std::optional<unsigned> FirstMaskArgument,
-                               CCState &State, const RISCVTargetLowering &TLI) {
-  const TargetRegisterClass *RC = TLI.getRegClassFor(ValVT);
-  if (RC == &RISCV::VRRegClass) {
-    // Assign the first mask argument to V0.
-    // This is an interim calling convention and it may be changed in the
-    // future.
-    if (FirstMaskArgument && ValNo == *FirstMaskArgument)
-      return State.AllocateReg(RISCV::V0);
-    return State.AllocateReg(ArgVRs);
-  }
-  if (RC == &RISCV::VRM2RegClass)
-    return State.AllocateReg(ArgVRM2s);
-  if (RC == &RISCV::VRM4RegClass)
-    return State.AllocateReg(ArgVRM4s);
-  if (RC == &RISCV::VRM8RegClass)
-    return State.AllocateReg(ArgVRM8s);
-  llvm_unreachable("Unhandled register class for ValueType");
-}
-
 // Implements the RISC-V calling convention. Returns true upon failure.
 bool RISCV::CC_RISCV(const DataLayout &DL, RISCVABI::ABI ABI, unsigned ValNo,
                      MVT ValVT, MVT LocVT, CCValAssign::LocInfo LocInfo,
                      ISD::ArgFlagsTy ArgFlags, CCState &State, bool IsFixed,
                      bool IsRet, Type *OrigTy, const RISCVTargetLowering &TLI,
-                     std::optional<unsigned> FirstMaskArgument) {
+                     RVVArgDispatcher &RVVDispatcher) {
   unsigned XLen = DL.getLargestLegalIntTypeSizeInBits();
   assert(XLen == 32 || XLen == 64);
   MVT XLenVT = XLen == 32 ? MVT::i32 : MVT::i64;
@@ -18273,7 +18253,7 @@ bool RISCV::CC_RISCV(const DataLayout &DL, RISCVABI::ABI ABI, unsigned ValNo,
   else if (ValVT == MVT::f64 && !UseGPRForF64)
     Reg = State.AllocateReg(ArgFPR64s);
   else if (ValVT.isVector()) {
-    Reg = allocateRVVReg(ValVT, ValNo, FirstMaskArgument, State, TLI);
+    Reg = RVVDispatcher.getNextPhysReg();
     if (!Reg) {
       // For return values, the vector must be passed fully via registers or
       // via the stack.
@@ -18359,9 +18339,13 @@ void RISCVTargetLowering::analyzeInputArgs(
   unsigned NumArgs = Ins.size();
   FunctionType *FType = MF.getFunction().getFunctionType();
 
-  std::optional<unsigned> FirstMaskArgument;
-  if (Subtarget.hasVInstructions())
-    FirstMaskArgument = preAssignMask(Ins);
+  SmallVector<Type *, 4> TypeList;
+  if (IsRet)
+    TypeList.push_back(MF.getFunction().getReturnType());
+  else
+    for (const Argument &Arg : MF.getFunction().args())
+      TypeList.push_back(Arg.getType());
+  RVVArgDispatcher Dispatcher{&MF, this, TypeList};
 
   for (unsigned i = 0; i != NumArgs; ++i) {
     MVT ArgVT = Ins[i].VT;
@@ -18376,7 +18360,7 @@ void RISCVTargetLowering::analyzeInputArgs(
     RISCVABI::ABI ABI = MF.getSubtarget<RISCVSubtarget>().getTargetABI();
     if (Fn(MF.getDataLayout(), ABI, i, ArgVT, ArgVT, CCValAssign::Full,
            ArgFlags, CCInfo, /*IsFixed=*/true, IsRet, ArgTy, *this,
-           FirstMaskArgument)) {
+           Dispatcher)) {
       LLVM_DEBUG(dbgs() << "InputArg #" << i << " has unhandled type "
                         << ArgVT << '\n');
       llvm_unreachable(nullptr);
@@ -18390,9 +18374,13 @@ void RISCVTargetLowering::analyzeOutputArgs(
     CallLoweringInfo *CLI, RISCVCCAssignFn Fn) const {
   unsigned NumArgs = Outs.size();
 
-  std::optional<unsigned> FirstMaskArgument;
-  if (Subtarget.hasVInstructions())
-    FirstMaskArgument = preAssignMask(Outs);
+  SmallVector<Type *, 4> TypeList;
+  if (IsRet)
+    TypeList.push_back(MF.getFunction().getReturnType());
+  else if (CLI)
+    for (const TargetLowering::ArgListEntry &Arg : CLI->getArgs())
+      TypeList.push_back(Arg.Ty);
+  RVVArgDispatcher Dispatcher{&MF, this, TypeList};
 
   for (unsigned i = 0; i != NumArgs; i++) {
     MVT ArgVT = Outs[i].VT;
@@ -18402,7 +18390,7 @@ void RISCVTargetLowering::analyzeOutputArgs(
     RISCVABI::ABI ABI = MF.getSubtarget<RISCVSubtarget>().getTargetABI();
     if (Fn(MF.getDataLayout(), ABI, i, ArgVT, ArgVT, CCValAssign::Full,
            ArgFlags, CCInfo, Outs[i].IsFixed, IsRet, OrigTy, *this,
-           FirstMaskArgument)) {
+           Dispatcher)) {
       LLVM_DEBUG(dbgs() << "OutputArg #" << i << " has unhandled type "
                         << ArgVT << "\n");
       llvm_unreachable(nullptr);
@@ -18583,7 +18571,7 @@ bool RISCV::CC_RISCV_FastCC(const DataLayout &DL, RISCVABI::ABI ABI,
                             ISD::ArgFlagsTy ArgFlags, CCState &State,
                             bool IsFixed, bool IsRet, Type *OrigTy,
                             const RISCVTargetLowering &TLI,
-                            std::optional<unsigned> FirstMaskArgument) {
+                            RVVArgDispatcher &RVVDispatcher) {
   if (LocVT == MVT::i32 || LocVT == MVT::i64) {
     if (unsigned Reg = State.AllocateReg(getFastCCArgGPRs(ABI))) {
       State.addLoc(CCValAssign::getReg(ValNo, ValVT, Reg, LocVT, LocInfo));
@@ -18661,13 +18649,14 @@ bool RISCV::CC_RISCV_FastCC(const DataLayout &DL, RISCVABI::ABI ABI,
   }
 
   if (LocVT.isVector()) {
-    if (unsigned Reg =
-            allocateRVVReg(ValVT, ValNo, FirstMaskArgument, State, TLI)) {
+    MCPhysReg AllocatedVReg = RVVDispatcher.getNextPhysReg();
+    if (AllocatedVReg) {
       // Fixed-length vectors are located in the corresponding scalable-vector
       // container types.
       if (ValVT.isFixedLengthVector())
         LocVT = TLI.getContainerForFixedLengthVector(LocVT);
-      State.addLoc(CCValAssign::getReg(ValNo, ValVT, Reg, LocVT, LocInfo));
+      State.addLoc(
+          CCValAssign::getReg(ValNo, ValVT, AllocatedVReg, LocVT, LocInfo));
     } else {
       // Try and pass the address via a "fast" GPR.
       if (unsigned GPRReg = State.AllocateReg(getFastCCArgGPRs(ABI))) {
@@ -19295,17 +19284,15 @@ bool RISCVTargetLowering::CanLowerReturn(
   SmallVector<CCValAssign, 16> RVLocs;
   CCState CCInfo(CallConv, IsVarArg, MF, RVLocs, Context);
 
-  std::optional<unsigned> FirstMaskArgument;
-  if (Subtarget.hasVInstructions())
-    FirstMaskArgument = preAssignMask(Outs);
+  RVVArgDispatcher Dispatcher{&MF, this, MF.getFunction().getReturnType()};
 
   for (unsigned i = 0, e = Outs.size(); i != e; ++i) {
     MVT VT = Outs[i].VT;
     ISD::ArgFlagsTy ArgFlags = Outs[i].Flags;
     RISCVABI::ABI ABI = MF.getSubtarget<RISCVSubtarget>().getTargetABI();
     if (RISCV::CC_RISCV(MF.getDataLayout(), ABI, i, VT, VT, CCValAssign::Full,
-                 ArgFlags, CCInfo, /*IsFixed=*/true, /*IsRet=*/true, nullptr,
-                 *this, FirstMaskArgument))
+                        ArgFlags, CCInfo, /*IsFixed=*/true, /*IsRet=*/true,
+                        nullptr, *this, Dispatcher))
       return false;
   }
   return true;
@@ -21102,6 +21089,119 @@ unsigned RISCVTargetLowering::getMinimumJumpTableEntries() const {
   return Subtarget.getMinimumJumpTableEntries();
 }
 
+void RVVArgDispatcher::constructArgInfos(ArrayRef<Type *> TypeList) {
+  const DataLayout &DL = MF->getDataLayout();
+  const Function &F = MF->getFunction();
+  LLVMContext &Context = F.getContext();
+
+  bool FirstVMaskAssigned = false;
+  for (Type *Ty : TypeList) {
+    StructType *STy = dyn_cast<StructType>(Ty);
+    if (STy && STy->containsHomogeneousScalableVectorTypes()) {
+      Type *ElemTy = STy->getTypeAtIndex(0U);
+      EVT VT = TLI->getValueType(DL, ElemTy);
+      MVT RegisterVT =
+          TLI->getRegisterTypeForCallingConv(Context, F.getCallingConv(), VT);
+
+      RVVArgInfos.push_back({STy->getNumElements(), RegisterVT, false});
+    } else {
+      SmallVector<EVT, 4> ValueVTs;
+      ComputeValueVTs(*TLI, DL, Ty, ValueVTs);
+
+      for (unsigned Value = 0, NumValues = ValueVTs.size(); Value != NumValues;
+           ++Value) {
+        EVT VT = ValueVTs[Value];
+        MVT RegisterVT =
+            TLI->getRegisterTypeForCallingConv(Context, F.getCallingConv(), VT);
+        unsigned NumRegs =
+            TLI->getNumRegistersForCallingConv(Context, F.getCallingConv(), VT);
+
+        // Skip non-RVV register type
+        if (!RegisterVT.isVector())
+          continue;
+
+        if (RegisterVT.isFixedLengthVector())
+          RegisterVT = TLI->getContainerForFixedLengthVector(RegisterVT);
+
+        if (!FirstVMaskAssigned &&
+            RegisterVT.getVectorElementType() == MVT::i1) {
+          RVVArgInfos.push_back({1, RegisterVT, true});
+          FirstVMaskAssigned = true;
+          --NumRegs;
+        }
+
+        RVVArgInfos.insert(RVVArgInfos.end(), NumRegs, {1, RegisterVT, false});
+      }
+    }
+  }
+}
+
+void RVVArgDispatcher::allocatePhysReg(unsigned NF, unsigned LMul,
+                                       unsigned StartReg) {
+  assert((StartReg % LMul) == 0 &&
+         "Start register number should be multiple of lmul");
+  const MCPhysReg *VRArrays;
+  switch (LMul) {
+  default:
+    report_fatal_error("Invalid lmul");
+  case 1:
+    VRArrays = ArgVRs;
+    break;
+  case 2:
+    VRArrays = ArgVRM2s;
+    break;
+  case 4:
+    VRArrays = ArgVRM4s;
+    break;
+  case 8:
+    VRArrays = ArgVRM8s;
+    break;
+  }
+
+  for (unsigned i = 0; i < NF; ++i)
+    if (StartReg)
+      AllocatedPhysRegs.push_back(VRArrays[(StartReg - 8) / LMul + i]);
+    else
+      AllocatedPhysRegs.push_back(MCPhysReg());
+}
+
+/// This function determines whether each RVV argument is passed by register.
+/// If the argument can be assigned to a VR, give it a specific register;
+/// otherwise, assign it 0, which is an invalid MCPhysReg.
+void RVVArgDispatcher::compute() {
+  uint32_t AssignedMap = 0;
+  auto allocate = [&](const RVVArgInfo &ArgInfo) {
+    // Allocate first vector mask argument to V0.
+    if (ArgInfo.FirstVMask) {
+      AllocatedPhysRegs.push_back(RISCV::V0);
+      return;
+    }
+
+    unsigned RegsNeeded = divideCeil(
+        ArgInfo.VT.getSizeInBits().getKnownMinValue(), RISCV::RVVBitsPerBlock);
+    unsigned TotalRegsNeeded = ArgInfo.NF * RegsNeeded;
+    for (unsigned StartReg = 0; StartReg + TotalRegsNeeded <= NumArgVRs;
+         StartReg += RegsNeeded) {
+      uint32_t Map = ((1 << TotalRegsNeeded) - 1) << StartReg;
+      if ((AssignedMap & Map) == 0) {
+        allocatePhysReg(ArgInfo.NF, RegsNeeded, StartReg + 8);
+        AssignedMap |= Map;
+        return;
+      }
+    }
+
+    allocatePhysReg(ArgInfo.NF, RegsNeeded, 0);
+  };
+
+  for (unsigned i = 0; i < RVVArgInfos.size(); ++i)
+    allocate(RVVArgInfos[i]);
+}
+
+MCPhysReg RVVArgDispatcher::getNextPhysReg() {
+  assert(CurIdx < AllocatedPhysRegs.size() && "Index out of range");
+  return AllocatedPhysRegs[CurIdx++];
+}
+
 namespace llvm::RISCVVIntrinsicsTable {
 
 #define GET_RISCVVIntrinsicsTable_IMPL
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index ace5b3fd2b95b4..c28552354bf422 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -24,6 +24,7 @@ namespace llvm {
 class InstructionCost;
 class RISCVSubtarget;
 struct RISCVRegisterInfo;
+class RVVArgDispatcher;
 
 namespace RISCVISD {
 // clang-format off
@@ -875,7 +876,7 @@ class RISCVTargetLowering : public TargetLowering {
                                ISD::ArgFlagsTy ArgFlags, CCState &State,
                                bool IsFixed, bool IsRet, Type *OrigTy,
                                const RISCVTargetLowering &TLI,
-                               std::optional<unsigned> FirstMaskArgument);
+                               RVVArgDispatcher &RVVDispatcher);
 
 private:
   void analyzeInputArgs(MachineFunction &MF, CCState &CCInfo,
@@ -1015,19 +1016,68 @@ class RISCVTargetLowering : public TargetLowering {
   unsigned getMinimumJumpTableEntries() const override;
 };
 
+/// As per the spec, the rules for passing vector arguments are as follows:
+///
+/// 1. For the first vector mask argument, use v0 to pass it.
+/// 2. For vector data arguments or the remaining vector mask arguments,
+/// starting from the v8 register, if a vector register group between v8-v23
+/// that has not been allocated can be found and the first register number is
+/// a multiple of LMUL, then allocate this vector register group to the
+/// argument and mark these registers as allocated. Otherwise, pass it by
+/// reference and replace it in the argument list with the address.
+/// 3. For tuple vector data arguments, starting from the v8 register, if
+/// NFIELDS consecutive vector register groups between v8-v23 that have not
+/// been allocated can be found and the first register number is a multiple of
+/// LMUL, then allocate these vector register groups to the argument and mark
+/// these registers as allocated. Otherwise, pass it by reference and replace
+/// it in the argument list with the address.
+class RVVArgDispatcher {
+public:
+  static constexpr unsigned NumArgVRs = 16;
+
+  struct RVVArgInfo {
+    unsigned NF;
+    MVT VT;
+    bool FirstVMask = false;
+  };
+
+  RVVArgDispatcher(const MachineFunction *MF, const RISCVTargetLowering *TLI,
+                   ArrayRef<Type *> TypeList)
+      : MF(MF), TLI(TLI) {
+    constructArgInfos(TypeList);
+    compute();
+  }
+
+  MCPhysReg getNextPhysReg();
+
+private:
+  SmallVector<RVVArgInfo, 4> RVVArgInfos;
+  SmallVector<MCPhysReg, 4> AllocatedPhysRegs;
+
+  const MachineFunction *MF = nullptr;
+  const RISCVTargetLowering *TLI = nullptr;
+
+  unsigned CurIdx = 0;
+
+  void constructArgInfos(ArrayRef<Type *> TypeList);
+  void compute();
+  void allocatePhysReg(unsigned NF = 1, unsigned LMul = 1,
+                       unsigned StartReg = 0);
+};
+
 namespace RISCV {
 
 bool CC_RISCV(const DataLayout &DL, RISCVABI::ABI ABI, unsigned ValNo,
               MVT ValVT, MVT LocVT, CCValAssign::LocInfo LocInfo,
               ISD::ArgFlagsTy ArgFlags, CCState &State, bool IsFixed,
               bool IsRet, Type *OrigTy, const RISCVTargetLowering &TLI,
-              std::optional<unsigned> FirstMaskArgument);
+              RVVArgDispatcher &RVVDispatcher);
 
 bool CC_RISCV_FastCC(const DataLayout &DL, RISCVABI::ABI ABI, unsigned ValNo,
                      MVT ValVT, MVT LocVT, CCValAssign::LocInfo LocInfo,
                      ISD::ArgFlagsTy ArgFlags, CCState &State, bool IsFixed,
                      bool IsRet, Type *OrigTy, const RISCVTargetLowering &TLI,
-                     std::optional<unsigned> FirstMaskArgument);
+                     RVVArgDispatcher &RVVDispatcher);
 
 bool CC_RISCV_GHC(unsigned ValNo, MVT ValVT, MVT LocVT,
                   CCValAssign::LocInfo LocInfo, ISD::ArgFlagsTy ArgFlags,
diff --git a/llvm/test/CodeGen/RISCV/rvv/calling-conv.ll b/llvm/test/CodeGen/RISCV/rvv/calling-conv.ll
index 78e8700a9feff8..90edb994ce8222 100644
--- a/llvm/test/CodeGen/RISCV/rvv/calling-conv.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/calling-conv.ll
@@ -162,3 +162,90 @@ define void @caller_tuple_argument({<vscale x 4 x i32>, <vscale x 4 x i32>} %x)
 }
 
 declare void @callee_tuple_argument({<vscale x 4 x i32>, <vscale x 4 x i32>})
+
+; %0 -> v8
+; %1 -> v9
+define <vscale x 1 x i64> @case1(<vscale x 1 x i64> %0, <vscale x 1 x i64> %1) {
+; CHECK-LABEL: case1:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli a0, zero, e64, m1, ta, ma
+; CHECK-NEXT:    vadd.vv v8, v8, v9
+; CHECK-NEXT:    ret
+  %a = add <vscale x 1 x i64> %0, %1
+  ret <vscale x 1 x i64> %a
+}
+
+; %0 -> v8
+; %1 -> v10-v11
+; %2 -> v9
+define <vscale x 1 x i64> @case2_1(<vscale x 1 x i64> %0, <vscale x 2 x i64> %1, <vscale x 1 x i64> %2) {
+; CHECK-LABEL: case2_1:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli a0, zero, e64, m1, ta, ma
+; CHECK-NEXT:    vadd.vv v8, v8, v9
+; CHECK-NEXT:    ret
+  %a = add <vscale x 1 x i64> %0, %2
+  ret <vscale x 1 x i64> %a
+}
+define <vscale x 2 x i64> @case2_2(<vscale x 1 x i64> %0, <vscale x 2 x i64> %1, <vscale x 1 x i64> %2) {
+; CHECK-LABEL: case2_2:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli a0, zero, e64, m2, ta, ma
+; CHECK-NEXT:    vadd.vv v8, v10, v10
+; CHECK-NEXT:    ret
+  %a = add <vscale x 2 x i64> %1, %1
+  ret <vscale x 2 x i64> %a
+}
+
+; %0 -> v8
+; %1 -> {v10-v11, v12-v13}
+; %2 -> v9
+define <vscale x 1 x i64> @case3_1(<vscale x 1 x i64> %0, {<vscale x 2 x i64>, <vscale x 2 x i64>} %1, <vscale x 1 x i64> %2) {
+; CHECK-LABEL: case3_1:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli a0, zero, e64, m1, ta, ma
+; CHECK-NEXT:    vadd.vv v8, v8, v9
+; CHECK-NEXT:    ret
+  %add = add <vscale x 1 x i64> %0, %2
+  ret <vscale x 1 x i64> %add
+}
+define <vscale x 2 x i64> @case3_2(<vscale x 1 x i64> %0, {<vscale x 2 x i64>, <vscale x 2 x i64>} %1, <vscale x 1 x i64> %2) {
+; CHECK-LABEL: case3_2:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli a0, zero, e64, m2, ta, ma
+; CHECK-NEXT:    vadd.vv v8, v10, v12
+; CHECK-NEXT:    ret
+  %a = extractvalue { <vscale x 2 x i64>, <vscale x 2 x i64> } %1, 0
+  %b = extractvalue { <vscale x 2 x i64>, <vscale x 2 x i64> } %1, 1
+  %add = add <vscale x 2 x i64> %a, %b
+  ret <vscale x 2 x i64> %add
+}
+
+; %0 -> v8
+; %1 -> {by-ref, by-ref}
+; %2 -> v9
+define <vscale x 8 x i64> @case4_1(<vscale x 1 x i64> %0, {<vscale x 8 x i64>, <vscale x 8 x i64>} %1, <vscale x 1 x i64> %2) {
+; CHECK-LABEL: case4_1:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    csrr a1, vlenb
+; CHECK-NEXT:    slli a1, a1, 3
+; CHECK-NEXT:    add a1, a0, a1
+; CHECK-NEXT:    vl8re64.v v8, (a1)
+; CHECK-NEXT:    vl8re64.v v16, (a0)
+; CHECK-NEXT:    vsetvli a0, zero, e64, m8, ta, ma
+; CHECK-NEXT:    vadd.vv v8, v16, v8
+; CHECK-NEXT:    ret
+  %a = extractvalue { <vscale x 8 x i64>, <vscale x 8 x i64> } %1, 0
+  %b = extractvalue { <vscale x 8 x i64>, <vscale x 8 x i64> } %1, 1
+  %add = add <vscale x 8 x i64> %a, %b
+  ret <vscale x 8 x i64> %add
+}
+define <vscale x 1 x i64> @case4_2(<vscale x 1 x i64> %0, {<vscale x 8 x i64>, <vscale x 8 x i64>} %1, <vscale x 1 x i64> %2) {
+; CHECK-LABEL: case4_2:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli a0, zero, e64, m1, ta, ma
+; CHECK-NEXT:    vadd.vv v8, v8, v9
+; CHECK-NEXT:    ret
+  %add = add <vscale x 1 x i64> %0, %2
+  ret <vscale x 1 x i64> %add
+}
diff --git a/llvm/test/CodeGen/RISCV/rvv/vector-deinterleave-load.ll b/llvm/test/CodeGen/RISCV/rvv/vector-deinterleave-load.ll
index a320aecc6fce49..6a712080fda74a 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vector-deinterleave-load.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vector-deinterleave-load.ll
@@ -18,10 +18,10 @@ define {<vscale x 16 x i1>, <vscale x 16 x i1>} @vector_deinterleave_load_nxv16i
 ; CHECK-NEXT:    vmerge.vim v14, v10, 1, v0
 ; CHECK-NEXT:    vmv1r.v v0, v8
 ; CHECK-NEXT:    vmerge.vim v12, v10, 1, v0
-; CHECK-NEXT:    vnsrl.wi v8, v12, 0
-; CHECK-NEXT:    vmsne.vi v0, v8, 0
-; CHECK-NEXT:    vnsrl.wi v10, v12, 8
+; CHECK-NEXT:    vnsrl.wi v10, v12, 0
 ; CHECK-NEXT:    vmsne.vi v8, v10, 0
+; CHECK-NEXT:    vnsrl.wi v10, v12, 8
+; CHECK-NEXT:    vmsne.vi v9, v10, 0
 ; CHECK-NEXT:    ret
   %vec = load <vscale x 32 x i1>, ptr %p
   %retval = call {<vscale x 16 x i1>, <vscale x 16 x i1>} @llvm.experimental.vector.deinterleave2.nxv32i1(<vscale x 32 x i1> %vec)
diff --git a/llvm/test/CodeGen/RISCV/rvv/vector-deinterleave.ll b/llvm/test/CodeGen/RISCV/rvv/vector-deinterleave.ll
index ef4baf34d23f03..d98597fabcd953 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vector-deinterleave.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vector-deinterleave.ll
@@ -8,18 +8,18 @@ define {<vscale x 16 x i1>, <vscale x 16 x i1>} @vector_deinterleave_nxv16i1_nxv
 ; CHECK-LABEL: vector_deinterleave_nxv16i1_nxv32i1:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    vsetvli a0, zero, e8, m2, ta, ma
-; CHECK-NEXT:    vmv.v.i v10, 0
-; CHECK-NEXT:    vmerge.vim v8, v10, 1, v0
+; CHECK-NEXT:    vmv.v.i v8, 0
+; CHECK-NEXT:    vmerge.vim v12, v8, 1, v0
 ; CHECK-NEXT:    csrr a0, vlenb
 ; CHECK-NEXT:    srli a0, a0, 2
 ; CHECK-NEXT:    vsetvli a1, zero, e8, mf2, ta, ma
 ; CHECK-NEXT:    vslidedown.vx v0, v0, a0
 ; CHECK-NEXT:    vsetvli a0, zero, e8, m2, ta, ma
-; CHECK-NEXT:    vmerge.vim v10, v10, 1, v0
-; CHECK-NEXT:    vnsrl.wi v12, v8, 0
-; CHECK-NEXT:    vmsne.vi v0, v12, 0
-; CHECK-NEXT:    vnsrl.wi v12, v8, 8
-; CHECK-NEXT:    vmsne.vi v8, v12, 0
+; CHECK-NEXT:    vmerge.vim v14, v8, 1, v0
+; CHECK-NEXT:    vnsrl.wi v10, v12, 0
+; CHECK-NEXT:    vmsne.vi v8, v10, 0
+; CHECK-NEXT:    vnsrl.wi v10, v12, 8
+; CHECK-NEXT:    vmsne.vi v9, v10, 0
 ; CHECK-NEXT:    ret
 %retval = call {<vscale x 16 x i1>, <vscale x 16 x i1>} @llvm.experimental.vector.deinterleave2.nxv32i1(<vscale x 32 x i1> %vec)
 ret {<vscale x 16 x i1>, <vscale x 16 x i1>} %retval
@@ -102,12 +102,13 @@ define {<vscale x 64 x i1>, <vscale x 64 x i1>} @vector_deinterleave_nxv64i1_nxv
 ; CHECK-NEXT:    vsetvli a0, zero, e8, m4, ta, ma
 ; CHECK-NEXT:    vnsrl.wi v28, v8, 0
 ; CHECK-NEXT:    vsetvli a0, zero, e8, m8, ta, ma
-; CHECK-NEXT:    vmsne.vi v0, v24, 0
+; CHECK-NEXT:    vmsne.vi v7, v24, 0
 ; CHECK-NEXT:    vsetvli a0, zero, e8, m4, ta, ma
 ; CHECK-NEXT:    vnsrl.wi v24, v16, 8
 ; CHECK-NEXT:    vnsrl.wi v28, v8, 8
 ; CHECK-NEXT:    vsetvli a0, zero, e8, m8, ta, ma
-; CHECK-NEXT:    vmsne.vi v8, v24, 0
+; CHECK-NEXT:    vmsne.vi v9, v24, 0
+; CHECK-NEXT:    vmv1r.v v8, v7
 ; CHECK-NEXT:    ret
 %retval = call {<vscale x 64 x i1>, <vscale x 64 x i1>} @llvm.experimental.vector.deinterleave2.nxv128i1(<vscale x 128 x i1> %vec)
 ret {<vscale x 64 x i1>, <vscale x 64 x i1>} %retval

From cad2c892dad83d4f3974a662dfde9c9e5896e062 Mon Sep 17 00:00:00 2001
From: Brandon Wu <brandon.wu at sifive.com>
Date: Thu, 4 Apr 2024 10:26:49 -0700
Subject: [PATCH 2/2] [RISCV] Handle RVV return type in calling convention
 correctly

Return values are handled in the same way as function arguments.
One thing to note is that if a type can be broken down into homogeneous
vector types, e.g. {<vscale x 4 x i32>, {<vscale x 4 x i32>, <vscale x 4 x i32>}},
it is treated as a vector tuple type and needs to be handled by the tuple
type rule.
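
For illustration, here is a minimal IR sketch mirroring the caller_nested
test added in this patch (function names are taken from that test; register
numbers assume the default ABI): the nested return type flattens into three
<vscale x 4 x i32> parts, so it is dispatched as one tuple with NFIELDS=3,
each field occupying an LMUL=2 register group, i.e. v8-v9, v10-v11 and
v12-v13.

  declare {<vscale x 4 x i32>, {<vscale x 4 x i32>, <vscale x 4 x i32>}} @callee_nested()

  define <vscale x 4 x i32> @use_nested() {
    %ret = call {<vscale x 4 x i32>, {<vscale x 4 x i32>, <vscale x 4 x i32>}} @callee_nested()
    %f0 = extractvalue {<vscale x 4 x i32>, {<vscale x 4 x i32>, <vscale x 4 x i32>}} %ret, 0
    ret <vscale x 4 x i32> %f0
  }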
---
 llvm/lib/CodeGen/TargetLoweringBase.cpp       |  12 +-
 .../Target/RISCV/GISel/RISCVCallLowering.cpp  |  10 +-
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp   | 107 ++++++++++++++--
 llvm/lib/Target/RISCV/RISCVISelLowering.h     |   9 +-
 llvm/test/CodeGen/RISCV/rvv/calling-conv.ll   | 116 ++++++++++++++++++
 5 files changed, 236 insertions(+), 18 deletions(-)

diff --git a/llvm/lib/CodeGen/TargetLoweringBase.cpp b/llvm/lib/CodeGen/TargetLoweringBase.cpp
index f64ded4f2cf965..6e7b67ded23c84 100644
--- a/llvm/lib/CodeGen/TargetLoweringBase.cpp
+++ b/llvm/lib/CodeGen/TargetLoweringBase.cpp
@@ -1809,8 +1809,16 @@ void llvm::GetReturnInfo(CallingConv::ID CC, Type *ReturnType,
     else if (attr.hasRetAttr(Attribute::ZExt))
       Flags.setZExt();
 
-    for (unsigned i = 0; i < NumParts; ++i)
-      Outs.push_back(ISD::OutputArg(Flags, PartVT, VT, /*isfixed=*/true, 0, 0));
+    for (unsigned i = 0; i < NumParts; ++i) {
+      ISD::ArgFlagsTy OutFlags = Flags;
+      if (NumParts > 1 && i == 0)
+        OutFlags.setSplit();
+      else if (i == NumParts - 1 && i != 0)
+        OutFlags.setSplitEnd();
+
+      Outs.push_back(
+          ISD::OutputArg(OutFlags, PartVT, VT, /*isfixed=*/true, 0, 0));
+    }
   }
 }
 
diff --git a/llvm/lib/Target/RISCV/GISel/RISCVCallLowering.cpp b/llvm/lib/Target/RISCV/GISel/RISCVCallLowering.cpp
index 8af4bc658409d4..c18892ac62f247 100644
--- a/llvm/lib/Target/RISCV/GISel/RISCVCallLowering.cpp
+++ b/llvm/lib/Target/RISCV/GISel/RISCVCallLowering.cpp
@@ -409,7 +409,7 @@ bool RISCVCallLowering::lowerReturnVal(MachineIRBuilder &MIRBuilder,
   splitToValueTypes(OrigRetInfo, SplitRetInfos, DL, CC);
 
   RVVArgDispatcher Dispatcher{&MF, getTLI<RISCVTargetLowering>(),
-                              F.getReturnType()};
+                              ArrayRef(F.getReturnType())};
   RISCVOutgoingValueAssigner Assigner(
       CC == CallingConv::Fast ? RISCV::CC_RISCV_FastCC : RISCV::CC_RISCV,
       /*IsRet=*/true, Dispatcher);
@@ -538,7 +538,8 @@ bool RISCVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
     ++Index;
   }
 
-  RVVArgDispatcher Dispatcher{&MF, getTLI<RISCVTargetLowering>(), TypeList};
+  RVVArgDispatcher Dispatcher{&MF, getTLI<RISCVTargetLowering>(),
+                              ArrayRef(TypeList)};
   RISCVIncomingValueAssigner Assigner(
       CC == CallingConv::Fast ? RISCV::CC_RISCV_FastCC : RISCV::CC_RISCV,
       /*IsRet=*/false, Dispatcher);
@@ -603,7 +604,8 @@ bool RISCVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
   const TargetRegisterInfo *TRI = Subtarget.getRegisterInfo();
   Call.addRegMask(TRI->getCallPreservedMask(MF, Info.CallConv));
 
-  RVVArgDispatcher ArgDispatcher{&MF, getTLI<RISCVTargetLowering>(), TypeList};
+  RVVArgDispatcher ArgDispatcher{&MF, getTLI<RISCVTargetLowering>(),
+                                 ArrayRef(TypeList)};
   RISCVOutgoingValueAssigner ArgAssigner(
       CC == CallingConv::Fast ? RISCV::CC_RISCV_FastCC : RISCV::CC_RISCV,
       /*IsRet=*/false, ArgDispatcher);
@@ -635,7 +637,7 @@ bool RISCVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
   splitToValueTypes(Info.OrigRet, SplitRetInfos, DL, CC);
 
   RVVArgDispatcher RetDispatcher{&MF, getTLI<RISCVTargetLowering>(),
-                                 F.getReturnType()};
+                                 ArrayRef(F.getReturnType())};
   RISCVIncomingValueAssigner RetAssigner(
       CC == CallingConv::Fast ? RISCV::CC_RISCV_FastCC : RISCV::CC_RISCV,
       /*IsRet=*/true, RetDispatcher);
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 3e7bc8c2367de6..9f3efdcce74e96 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -18339,13 +18339,15 @@ void RISCVTargetLowering::analyzeInputArgs(
   unsigned NumArgs = Ins.size();
   FunctionType *FType = MF.getFunction().getFunctionType();
 
-  SmallVector<Type *, 4> TypeList;
-  if (IsRet)
-    TypeList.push_back(MF.getFunction().getReturnType());
-  else
+  RVVArgDispatcher Dispatcher;
+  if (IsRet) {
+    Dispatcher = RVVArgDispatcher{&MF, this, ArrayRef(Ins)};
+  } else {
+    SmallVector<Type *, 4> TypeList;
     for (const Argument &Arg : MF.getFunction().args())
       TypeList.push_back(Arg.getType());
-  RVVArgDispatcher Dispatcher{&MF, this, TypeList};
+    Dispatcher = RVVArgDispatcher{&MF, this, ArrayRef(TypeList)};
+  }
 
   for (unsigned i = 0; i != NumArgs; ++i) {
     MVT ArgVT = Ins[i].VT;
@@ -18380,7 +18382,7 @@ void RISCVTargetLowering::analyzeOutputArgs(
   else if (CLI)
     for (const TargetLowering::ArgListEntry &Arg : CLI->getArgs())
       TypeList.push_back(Arg.Ty);
-  RVVArgDispatcher Dispatcher{&MF, this, TypeList};
+  RVVArgDispatcher Dispatcher{&MF, this, ArrayRef(TypeList)};
 
   for (unsigned i = 0; i != NumArgs; i++) {
     MVT ArgVT = Outs[i].VT;
@@ -19284,7 +19286,7 @@ bool RISCVTargetLowering::CanLowerReturn(
   SmallVector<CCValAssign, 16> RVLocs;
   CCState CCInfo(CallConv, IsVarArg, MF, RVLocs, Context);
 
-  RVVArgDispatcher Dispatcher{&MF, this, MF.getFunction().getReturnType()};
+  RVVArgDispatcher Dispatcher{&MF, this, ArrayRef(Outs)};
 
   for (unsigned i = 0, e = Outs.size(); i != e; ++i) {
     MVT VT = Outs[i].VT;
@@ -21089,7 +21091,91 @@ unsigned RISCVTargetLowering::getMinimumJumpTableEntries() const {
   return Subtarget.getMinimumJumpTableEntries();
 }
 
-void RVVArgDispatcher::constructArgInfos(ArrayRef<Type *> TypeList) {
+// Handle single arg such as return value.
+template <typename Arg>
+void RVVArgDispatcher::constructArgInfos(ArrayRef<Arg> ArgList) {
+  // This lambda determines whether an array of types is composed of
+  // homogeneous vector types.
+  auto isHomogeneousScalableVectorType = [](ArrayRef<Arg> ArgList) {
+    // First, extract the first element in the argument type.
+    MVT FirstArgRegType;
+    unsigned FirstArgElements = 0;
+    auto It = ArgList.begin();
+    bool IsPart = false;
+
+    if (It == ArgList.end())
+      return false;
+
+    for (; It != ArgList.end(); ++It) {
+      FirstArgRegType = It->VT;
+      ++FirstArgElements;
+      if ((!It->Flags.isSplit() && !IsPart) || It->Flags.isSplitEnd())
+        break;
+
+      IsPart = true;
+    }
+
+    assert(It != ArgList.end() && "It shouldn't reach the end of ArgList.");
+    ++It;
+
+    // Return false if this argument type contains only 1 element, or it's not
+    // a scalable vector type.
+    if (It == ArgList.end() || !FirstArgRegType.isScalableVector())
+      return false;
+
+    // Second, check if the following elements in this argument type are all the
+    // same.
+    MVT ArgRegType;
+    unsigned ArgElements = 0;
+    IsPart = false;
+    for (; It != ArgList.end(); ++It) {
+      ArgRegType = It->VT;
+      ++ArgElements;
+      if ((!It->Flags.isSplit() && !IsPart) || It->Flags.isSplitEnd()) {
+        if (ArgRegType != FirstArgRegType || ArgElements != FirstArgElements)
+          return false;
+
+        IsPart = false;
+        ArgElements = 0;
+        continue;
+      }
+
+      IsPart = true;
+    }
+
+    return true;
+  };
+
+  if (isHomogeneousScalableVectorType(ArgList)) {
+    // Handle as tuple type
+    RVVArgInfos.push_back({(unsigned)ArgList.size(), ArgList[0].VT, false});
+  } else {
+    // Handle as normal vector type
+    bool FirstVMaskAssigned = false;
+    for (const auto &OutArg : ArgList) {
+      MVT RegisterVT = OutArg.VT;
+
+      // Skip non-RVV register type
+      if (!RegisterVT.isVector())
+        continue;
+
+      if (RegisterVT.isFixedLengthVector())
+        RegisterVT = TLI->getContainerForFixedLengthVector(RegisterVT);
+
+      if (!FirstVMaskAssigned && RegisterVT.getVectorElementType() == MVT::i1) {
+        RVVArgInfos.push_back({1, RegisterVT, true});
+        FirstVMaskAssigned = true;
+        continue;
+      }
+
+      RVVArgInfos.push_back({1, RegisterVT, false});
+    }
+  }
+}
+
+// Handle multiple args.
+template <>
+void RVVArgDispatcher::constructArgInfos<Type *>(ArrayRef<Type *> TypeList) {
   const DataLayout &DL = MF->getDataLayout();
   const Function &F = MF->getFunction();
   LLVMContext &Context = F.getContext();
@@ -21102,8 +21188,11 @@ void RVVArgDispatcher::constructArgInfos(ArrayRef<Type *> TypeList) {
       EVT VT = TLI->getValueType(DL, ElemTy);
       MVT RegisterVT =
           TLI->getRegisterTypeForCallingConv(Context, F.getCallingConv(), VT);
+      unsigned NumRegs =
+          TLI->getNumRegistersForCallingConv(Context, F.getCallingConv(), VT);
 
-      RVVArgInfos.push_back({STy->getNumElements(), RegisterVT, false});
+      RVVArgInfos.push_back(
+          {NumRegs * STy->getNumElements(), RegisterVT, false});
     } else {
       SmallVector<EVT, 4> ValueVTs;
       ComputeValueVTs(*TLI, DL, Ty, ValueVTs);
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index c28552354bf422..a2456f2fab66b1 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -1041,13 +1041,16 @@ class RVVArgDispatcher {
     bool FirstVMask = false;
   };
 
+  template <typename Arg>
   RVVArgDispatcher(const MachineFunction *MF, const RISCVTargetLowering *TLI,
-                   ArrayRef<Type *> TypeList)
+                   ArrayRef<Arg> ArgList)
       : MF(MF), TLI(TLI) {
-    constructArgInfos(TypeList);
+    constructArgInfos(ArgList);
     compute();
   }
 
+  RVVArgDispatcher() = default;
+
   MCPhysReg getNextPhysReg();
 
 private:
@@ -1059,7 +1062,7 @@ class RVVArgDispatcher {
 
   unsigned CurIdx = 0;
 
-  void constructArgInfos(ArrayRef<Type *> TypeList);
+  template <typename Arg> void constructArgInfos(ArrayRef<Arg> Ret);
   void compute();
   void allocatePhysReg(unsigned NF = 1, unsigned LMul = 1,
                        unsigned StartReg = 0);
diff --git a/llvm/test/CodeGen/RISCV/rvv/calling-conv.ll b/llvm/test/CodeGen/RISCV/rvv/calling-conv.ll
index 90edb994ce8222..647d3158b6167f 100644
--- a/llvm/test/CodeGen/RISCV/rvv/calling-conv.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/calling-conv.ll
@@ -249,3 +249,119 @@ define <vscale x 1 x i64> @case4_2(<vscale x 1 x i64> %0, {<vscale x 8 x i64>, <
   %add = add <vscale x 1 x i64> %0, %2
   ret <vscale x 1 x i64> %add
 }
+
+declare <vscale x 1 x i64> @callee1()
+declare void @callee2(<vscale x 1 x i64>)
+declare void @callee3(<vscale x 4 x i32>)
+define void @caller() {
+; RV32-LABEL: caller:
+; RV32:       # %bb.0:
+; RV32-NEXT:    addi sp, sp, -16
+; RV32-NEXT:    .cfi_def_cfa_offset 16
+; RV32-NEXT:    sw ra, 12(sp) # 4-byte Folded Spill
+; RV32-NEXT:    .cfi_offset ra, -4
+; RV32-NEXT:    call callee1
+; RV32-NEXT:    vsetvli a0, zero, e64, m1, ta, ma
+; RV32-NEXT:    vadd.vv v8, v8, v8
+; RV32-NEXT:    call callee2
+; RV32-NEXT:    lw ra, 12(sp) # 4-byte Folded Reload
+; RV32-NEXT:    addi sp, sp, 16
+; RV32-NEXT:    ret
+;
+; RV64-LABEL: caller:
+; RV64:       # %bb.0:
+; RV64-NEXT:    addi sp, sp, -16
+; RV64-NEXT:    .cfi_def_cfa_offset 16
+; RV64-NEXT:    sd ra, 8(sp) # 8-byte Folded Spill
+; RV64-NEXT:    .cfi_offset ra, -8
+; RV64-NEXT:    call callee1
+; RV64-NEXT:    vsetvli a0, zero, e64, m1, ta, ma
+; RV64-NEXT:    vadd.vv v8, v8, v8
+; RV64-NEXT:    call callee2
+; RV64-NEXT:    ld ra, 8(sp) # 8-byte Folded Reload
+; RV64-NEXT:    addi sp, sp, 16
+; RV64-NEXT:    ret
+  %a = call <vscale x 1 x i64> @callee1()
+  %add = add <vscale x 1 x i64> %a, %a
+  call void @callee2(<vscale x 1 x i64> %add)
+  ret void
+}
+
+declare {<vscale x 4 x i32>, <vscale x 4 x i32>} @callee_tuple()
+define void @caller_tuple() {
+; RV32-LABEL: caller_tuple:
+; RV32:       # %bb.0:
+; RV32-NEXT:    addi sp, sp, -16
+; RV32-NEXT:    .cfi_def_cfa_offset 16
+; RV32-NEXT:    sw ra, 12(sp) # 4-byte Folded Spill
+; RV32-NEXT:    .cfi_offset ra, -4
+; RV32-NEXT:    call callee_tuple
+; RV32-NEXT:    vsetvli a0, zero, e32, m2, ta, ma
+; RV32-NEXT:    vadd.vv v8, v8, v10
+; RV32-NEXT:    call callee3
+; RV32-NEXT:    lw ra, 12(sp) # 4-byte Folded Reload
+; RV32-NEXT:    addi sp, sp, 16
+; RV32-NEXT:    ret
+;
+; RV64-LABEL: caller_tuple:
+; RV64:       # %bb.0:
+; RV64-NEXT:    addi sp, sp, -16
+; RV64-NEXT:    .cfi_def_cfa_offset 16
+; RV64-NEXT:    sd ra, 8(sp) # 8-byte Folded Spill
+; RV64-NEXT:    .cfi_offset ra, -8
+; RV64-NEXT:    call callee_tuple
+; RV64-NEXT:    vsetvli a0, zero, e32, m2, ta, ma
+; RV64-NEXT:    vadd.vv v8, v8, v10
+; RV64-NEXT:    call callee3
+; RV64-NEXT:    ld ra, 8(sp) # 8-byte Folded Reload
+; RV64-NEXT:    addi sp, sp, 16
+; RV64-NEXT:    ret
+  %a = call {<vscale x 4 x i32>, <vscale x 4 x i32>} @callee_tuple()
+  %b = extractvalue {<vscale x 4 x i32>, <vscale x 4 x i32>} %a, 0
+  %c = extractvalue {<vscale x 4 x i32>, <vscale x 4 x i32>} %a, 1
+  %add = add <vscale x 4 x i32> %b, %c
+  call void @callee3(<vscale x 4 x i32> %add)
+  ret void
+}
+
+declare {<vscale x 4 x i32>, {<vscale x 4 x i32>, <vscale x 4 x i32>}} @callee_nested()
+define void @caller_nested() {
+; RV32-LABEL: caller_nested:
+; RV32:       # %bb.0:
+; RV32-NEXT:    addi sp, sp, -16
+; RV32-NEXT:    .cfi_def_cfa_offset 16
+; RV32-NEXT:    sw ra, 12(sp) # 4-byte Folded Spill
+; RV32-NEXT:    .cfi_offset ra, -4
+; RV32-NEXT:    call callee_nested
+; RV32-NEXT:    vsetvli a0, zero, e32, m2, ta, ma
+; RV32-NEXT:    vadd.vv v8, v8, v10
+; RV32-NEXT:    vadd.vv v8, v8, v12
+; RV32-NEXT:    call callee3
+; RV32-NEXT:    lw ra, 12(sp) # 4-byte Folded Reload
+; RV32-NEXT:    addi sp, sp, 16
+; RV32-NEXT:    ret
+;
+; RV64-LABEL: caller_nested:
+; RV64:       # %bb.0:
+; RV64-NEXT:    addi sp, sp, -16
+; RV64-NEXT:    .cfi_def_cfa_offset 16
+; RV64-NEXT:    sd ra, 8(sp) # 8-byte Folded Spill
+; RV64-NEXT:    .cfi_offset ra, -8
+; RV64-NEXT:    call callee_nested
+; RV64-NEXT:    vsetvli a0, zero, e32, m2, ta, ma
+; RV64-NEXT:    vadd.vv v8, v8, v10
+; RV64-NEXT:    vadd.vv v8, v8, v12
+; RV64-NEXT:    call callee3
+; RV64-NEXT:    ld ra, 8(sp) # 8-byte Folded Reload
+; RV64-NEXT:    addi sp, sp, 16
+; RV64-NEXT:    ret
+  %a = call {<vscale x 4 x i32>, {<vscale x 4 x i32>, <vscale x 4 x i32>}} @callee_nested()
+  %b = extractvalue {<vscale x 4 x i32>, {<vscale x 4 x i32>, <vscale x 4 x i32>}} %a, 0
+  %c = extractvalue {<vscale x 4 x i32>, {<vscale x 4 x i32>, <vscale x 4 x i32>}} %a, 1
+  %c0 = extractvalue {<vscale x 4 x i32>, <vscale x 4 x i32>} %c, 0
+  %c1 = extractvalue {<vscale x 4 x i32>, <vscale x 4 x i32>} %c, 1
+  %add0 = add <vscale x 4 x i32> %b, %c0
+  %add1 = add <vscale x 4 x i32> %add0, %c1
+  call void @callee3(<vscale x 4 x i32> %add1)
+  ret void
+}


