[llvm] [RISCV] RISCV vector calling convention (2/2) (PR #79096)

Brandon Wu via llvm-commits llvm-commits at lists.llvm.org
Sat Mar 30 04:51:25 PDT 2024


https://github.com/4vtomat updated https://github.com/llvm/llvm-project/pull/79096

From 1f2cde3985078e83aa8b1d1509216ca5d905bea0 Mon Sep 17 00:00:00 2001
From: Brandon Wu <brandon.wu at sifive.com>
Date: Mon, 22 Jan 2024 21:35:46 -0800
Subject: [PATCH] [RISCV] RISCV vector calling convention (2/2)

This commit handles vector arguments and return values for function
definitions and calls. A new class, RVVArgDispatcher, performs all vector
register assignment, covering mask types, data types, and tuple types.
It precomputes the register number for each argument according to
https://github.com/riscv-non-isa/riscv-elf-psabi-doc/blob/master/riscv-cc.adoc#standard-vector-calling-convention-variant
and is passed to the calling-convention functions to handle all vector
arguments.

Depends on: #78550
---
 .../Target/RISCV/GISel/RISCVCallLowering.cpp  |  55 +++---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp   | 182 ++++++++++++++----
 llvm/lib/Target/RISCV/RISCVISelLowering.h     |  56 +++++-
 llvm/test/CodeGen/RISCV/rvv/calling-conv.ll   |  87 +++++++++
 .../RISCV/rvv/vector-deinterleave-load.ll     |   6 +-
 .../CodeGen/RISCV/rvv/vector-deinterleave.ll  |  19 +-
 6 files changed, 322 insertions(+), 83 deletions(-)

diff --git a/llvm/lib/Target/RISCV/GISel/RISCVCallLowering.cpp b/llvm/lib/Target/RISCV/GISel/RISCVCallLowering.cpp
index 45e19cdea300b1..8af4bc658409d4 100644
--- a/llvm/lib/Target/RISCV/GISel/RISCVCallLowering.cpp
+++ b/llvm/lib/Target/RISCV/GISel/RISCVCallLowering.cpp
@@ -34,14 +34,15 @@ struct RISCVOutgoingValueAssigner : public CallLowering::OutgoingValueAssigner {
   // Whether this is assigning args for a return.
   bool IsRet;
 
-  // true if assignArg has been called for a mask argument, false otherwise.
-  bool AssignedFirstMaskArg = false;
+  RVVArgDispatcher &RVVDispatcher;
 
 public:
   RISCVOutgoingValueAssigner(
-      RISCVTargetLowering::RISCVCCAssignFn *RISCVAssignFn_, bool IsRet)
+      RISCVTargetLowering::RISCVCCAssignFn *RISCVAssignFn_, bool IsRet,
+      RVVArgDispatcher &RVVDispatcher)
       : CallLowering::OutgoingValueAssigner(nullptr),
-        RISCVAssignFn(RISCVAssignFn_), IsRet(IsRet) {}
+        RISCVAssignFn(RISCVAssignFn_), IsRet(IsRet),
+        RVVDispatcher(RVVDispatcher) {}
 
   bool assignArg(unsigned ValNo, EVT OrigVT, MVT ValVT, MVT LocVT,
                  CCValAssign::LocInfo LocInfo,
@@ -51,16 +52,9 @@ struct RISCVOutgoingValueAssigner : public CallLowering::OutgoingValueAssigner {
     const DataLayout &DL = MF.getDataLayout();
     const RISCVSubtarget &Subtarget = MF.getSubtarget<RISCVSubtarget>();
 
-    std::optional<unsigned> FirstMaskArgument;
-    if (Subtarget.hasVInstructions() && !AssignedFirstMaskArg &&
-        ValVT.isVector() && ValVT.getVectorElementType() == MVT::i1) {
-      FirstMaskArgument = ValNo;
-      AssignedFirstMaskArg = true;
-    }
-
     if (RISCVAssignFn(DL, Subtarget.getTargetABI(), ValNo, ValVT, LocVT,
                       LocInfo, Flags, State, Info.IsFixed, IsRet, Info.Ty,
-                      *Subtarget.getTargetLowering(), FirstMaskArgument))
+                      *Subtarget.getTargetLowering(), RVVDispatcher))
       return true;
 
     StackSize = State.getStackSize();
@@ -181,14 +175,15 @@ struct RISCVIncomingValueAssigner : public CallLowering::IncomingValueAssigner {
   // Whether this is assigning args from a return.
   bool IsRet;
 
-  // true if assignArg has been called for a mask argument, false otherwise.
-  bool AssignedFirstMaskArg = false;
+  RVVArgDispatcher &RVVDispatcher;
 
 public:
   RISCVIncomingValueAssigner(
-      RISCVTargetLowering::RISCVCCAssignFn *RISCVAssignFn_, bool IsRet)
+      RISCVTargetLowering::RISCVCCAssignFn *RISCVAssignFn_, bool IsRet,
+      RVVArgDispatcher &RVVDispatcher)
       : CallLowering::IncomingValueAssigner(nullptr),
-        RISCVAssignFn(RISCVAssignFn_), IsRet(IsRet) {}
+        RISCVAssignFn(RISCVAssignFn_), IsRet(IsRet),
+        RVVDispatcher(RVVDispatcher) {}
 
   bool assignArg(unsigned ValNo, EVT OrigVT, MVT ValVT, MVT LocVT,
                  CCValAssign::LocInfo LocInfo,
@@ -201,16 +196,9 @@ struct RISCVIncomingValueAssigner : public CallLowering::IncomingValueAssigner {
     if (LocVT.isScalableVector())
       MF.getInfo<RISCVMachineFunctionInfo>()->setIsVectorCall();
 
-    std::optional<unsigned> FirstMaskArgument;
-    if (Subtarget.hasVInstructions() && !AssignedFirstMaskArg &&
-        ValVT.isVector() && ValVT.getVectorElementType() == MVT::i1) {
-      FirstMaskArgument = ValNo;
-      AssignedFirstMaskArg = true;
-    }
-
     if (RISCVAssignFn(DL, Subtarget.getTargetABI(), ValNo, ValVT, LocVT,
                       LocInfo, Flags, State, /*IsFixed=*/true, IsRet, Info.Ty,
-                      *Subtarget.getTargetLowering(), FirstMaskArgument))
+                      *Subtarget.getTargetLowering(), RVVDispatcher))
       return true;
 
     StackSize = State.getStackSize();
@@ -420,9 +408,11 @@ bool RISCVCallLowering::lowerReturnVal(MachineIRBuilder &MIRBuilder,
   SmallVector<ArgInfo, 4> SplitRetInfos;
   splitToValueTypes(OrigRetInfo, SplitRetInfos, DL, CC);
 
+  RVVArgDispatcher Dispatcher{&MF, getTLI<RISCVTargetLowering>(),
+                              F.getReturnType()};
   RISCVOutgoingValueAssigner Assigner(
       CC == CallingConv::Fast ? RISCV::CC_RISCV_FastCC : RISCV::CC_RISCV,
-      /*IsRet=*/true);
+      /*IsRet=*/true, Dispatcher);
   RISCVOutgoingValueHandler Handler(MIRBuilder, MF.getRegInfo(), Ret);
   return determineAndHandleAssignments(Handler, Assigner, SplitRetInfos,
                                        MIRBuilder, CC, F.isVarArg());
@@ -531,6 +521,7 @@ bool RISCVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
   CallingConv::ID CC = F.getCallingConv();
 
   SmallVector<ArgInfo, 32> SplitArgInfos;
+  SmallVector<Type *, 4> TypeList;
   unsigned Index = 0;
   for (auto &Arg : F.args()) {
     // Construct the ArgInfo object from destination register and argument type.
@@ -542,12 +533,15 @@ bool RISCVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
     // correspondingly and appended to SplitArgInfos.
     splitToValueTypes(AInfo, SplitArgInfos, DL, CC);
 
+    TypeList.push_back(Arg.getType());
+
     ++Index;
   }
 
+  RVVArgDispatcher Dispatcher{&MF, getTLI<RISCVTargetLowering>(), TypeList};
   RISCVIncomingValueAssigner Assigner(
       CC == CallingConv::Fast ? RISCV::CC_RISCV_FastCC : RISCV::CC_RISCV,
-      /*IsRet=*/false);
+      /*IsRet=*/false, Dispatcher);
   RISCVFormalArgHandler Handler(MIRBuilder, MF.getRegInfo());
 
   SmallVector<CCValAssign, 16> ArgLocs;
@@ -585,11 +579,13 @@ bool RISCVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
 
   SmallVector<ArgInfo, 32> SplitArgInfos;
   SmallVector<ISD::OutputArg, 8> Outs;
+  SmallVector<Type *, 4> TypeList;
   for (auto &AInfo : Info.OrigArgs) {
     // Handle any required unmerging of split value types from a given VReg into
     // physical registers. ArgInfo objects are constructed correspondingly and
     // appended to SplitArgInfos.
     splitToValueTypes(AInfo, SplitArgInfos, DL, CC);
+    TypeList.push_back(AInfo.Ty);
   }
 
   // TODO: Support tail calls.
@@ -607,9 +603,10 @@ bool RISCVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
   const TargetRegisterInfo *TRI = Subtarget.getRegisterInfo();
   Call.addRegMask(TRI->getCallPreservedMask(MF, Info.CallConv));
 
+  RVVArgDispatcher ArgDispatcher{&MF, getTLI<RISCVTargetLowering>(), TypeList};
   RISCVOutgoingValueAssigner ArgAssigner(
       CC == CallingConv::Fast ? RISCV::CC_RISCV_FastCC : RISCV::CC_RISCV,
-      /*IsRet=*/false);
+      /*IsRet=*/false, ArgDispatcher);
   RISCVOutgoingValueHandler ArgHandler(MIRBuilder, MF.getRegInfo(), Call);
   if (!determineAndHandleAssignments(ArgHandler, ArgAssigner, SplitArgInfos,
                                      MIRBuilder, CC, Info.IsVarArg))
@@ -637,9 +634,11 @@ bool RISCVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
   SmallVector<ArgInfo, 4> SplitRetInfos;
   splitToValueTypes(Info.OrigRet, SplitRetInfos, DL, CC);
 
+  RVVArgDispatcher RetDispatcher{&MF, getTLI<RISCVTargetLowering>(),
+                                 F.getReturnType()};
   RISCVIncomingValueAssigner RetAssigner(
       CC == CallingConv::Fast ? RISCV::CC_RISCV_FastCC : RISCV::CC_RISCV,
-      /*IsRet=*/true);
+      /*IsRet=*/true, RetDispatcher);
   RISCVCallReturnHandler RetHandler(MIRBuilder, MF.getRegInfo(), Call);
   if (!determineAndHandleAssignments(RetHandler, RetAssigner, SplitRetInfos,
                                      MIRBuilder, CC, Info.IsVarArg))
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index e48ca4a905ce9e..f693cbd3bea51e 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -22,6 +22,7 @@
 #include "llvm/ADT/Statistic.h"
 #include "llvm/Analysis/MemoryLocation.h"
 #include "llvm/Analysis/VectorUtils.h"
+#include "llvm/CodeGen/Analysis.h"
 #include "llvm/CodeGen/MachineFrameInfo.h"
 #include "llvm/CodeGen/MachineFunction.h"
 #include "llvm/CodeGen/MachineInstrBuilder.h"
@@ -18057,33 +18058,12 @@ static bool CC_RISCVAssign2XLen(unsigned XLen, CCState &State, CCValAssign VA1,
   return false;
 }
 
-static unsigned allocateRVVReg(MVT ValVT, unsigned ValNo,
-                               std::optional<unsigned> FirstMaskArgument,
-                               CCState &State, const RISCVTargetLowering &TLI) {
-  const TargetRegisterClass *RC = TLI.getRegClassFor(ValVT);
-  if (RC == &RISCV::VRRegClass) {
-    // Assign the first mask argument to V0.
-    // This is an interim calling convention and it may be changed in the
-    // future.
-    if (FirstMaskArgument && ValNo == *FirstMaskArgument)
-      return State.AllocateReg(RISCV::V0);
-    return State.AllocateReg(ArgVRs);
-  }
-  if (RC == &RISCV::VRM2RegClass)
-    return State.AllocateReg(ArgVRM2s);
-  if (RC == &RISCV::VRM4RegClass)
-    return State.AllocateReg(ArgVRM4s);
-  if (RC == &RISCV::VRM8RegClass)
-    return State.AllocateReg(ArgVRM8s);
-  llvm_unreachable("Unhandled register class for ValueType");
-}
-
 // Implements the RISC-V calling convention. Returns true upon failure.
 bool RISCV::CC_RISCV(const DataLayout &DL, RISCVABI::ABI ABI, unsigned ValNo,
                      MVT ValVT, MVT LocVT, CCValAssign::LocInfo LocInfo,
                      ISD::ArgFlagsTy ArgFlags, CCState &State, bool IsFixed,
                      bool IsRet, Type *OrigTy, const RISCVTargetLowering &TLI,
-                     std::optional<unsigned> FirstMaskArgument) {
+                     RVVArgDispatcher &RVVDispatcher) {
   unsigned XLen = DL.getLargestLegalIntTypeSizeInBits();
   assert(XLen == 32 || XLen == 64);
   MVT XLenVT = XLen == 32 ? MVT::i32 : MVT::i64;
@@ -18252,7 +18232,7 @@ bool RISCV::CC_RISCV(const DataLayout &DL, RISCVABI::ABI ABI, unsigned ValNo,
   else if (ValVT == MVT::f64 && !UseGPRForF64)
     Reg = State.AllocateReg(ArgFPR64s);
   else if (ValVT.isVector()) {
-    Reg = allocateRVVReg(ValVT, ValNo, FirstMaskArgument, State, TLI);
+    Reg = RVVDispatcher.getNextPhysReg();
     if (!Reg) {
       // For return values, the vector must be passed fully via registers or
       // via the stack.
@@ -18338,9 +18318,13 @@ void RISCVTargetLowering::analyzeInputArgs(
   unsigned NumArgs = Ins.size();
   FunctionType *FType = MF.getFunction().getFunctionType();
 
-  std::optional<unsigned> FirstMaskArgument;
-  if (Subtarget.hasVInstructions())
-    FirstMaskArgument = preAssignMask(Ins);
+  SmallVector<Type *, 4> TypeList;
+  if (IsRet)
+    TypeList.push_back(MF.getFunction().getReturnType());
+  else
+    for (const Argument &Arg : MF.getFunction().args())
+      TypeList.push_back(Arg.getType());
+  RVVArgDispatcher Dispatcher{&MF, this, TypeList};
 
   for (unsigned i = 0; i != NumArgs; ++i) {
     MVT ArgVT = Ins[i].VT;
@@ -18355,7 +18339,7 @@ void RISCVTargetLowering::analyzeInputArgs(
     RISCVABI::ABI ABI = MF.getSubtarget<RISCVSubtarget>().getTargetABI();
     if (Fn(MF.getDataLayout(), ABI, i, ArgVT, ArgVT, CCValAssign::Full,
            ArgFlags, CCInfo, /*IsFixed=*/true, IsRet, ArgTy, *this,
-           FirstMaskArgument)) {
+           Dispatcher)) {
       LLVM_DEBUG(dbgs() << "InputArg #" << i << " has unhandled type "
                         << ArgVT << '\n');
       llvm_unreachable(nullptr);
@@ -18369,9 +18353,13 @@ void RISCVTargetLowering::analyzeOutputArgs(
     CallLoweringInfo *CLI, RISCVCCAssignFn Fn) const {
   unsigned NumArgs = Outs.size();
 
-  std::optional<unsigned> FirstMaskArgument;
-  if (Subtarget.hasVInstructions())
-    FirstMaskArgument = preAssignMask(Outs);
+  SmallVector<Type *, 4> TypeList;
+  if (IsRet)
+    TypeList.push_back(MF.getFunction().getReturnType());
+  else if (CLI)
+    for (const TargetLowering::ArgListEntry &Arg : CLI->getArgs())
+      TypeList.push_back(Arg.Ty);
+  RVVArgDispatcher Dispatcher{&MF, this, TypeList};
 
   for (unsigned i = 0; i != NumArgs; i++) {
     MVT ArgVT = Outs[i].VT;
@@ -18381,7 +18369,7 @@ void RISCVTargetLowering::analyzeOutputArgs(
     RISCVABI::ABI ABI = MF.getSubtarget<RISCVSubtarget>().getTargetABI();
     if (Fn(MF.getDataLayout(), ABI, i, ArgVT, ArgVT, CCValAssign::Full,
            ArgFlags, CCInfo, Outs[i].IsFixed, IsRet, OrigTy, *this,
-           FirstMaskArgument)) {
+           Dispatcher)) {
       LLVM_DEBUG(dbgs() << "OutputArg #" << i << " has unhandled type "
                         << ArgVT << "\n");
       llvm_unreachable(nullptr);
@@ -18562,7 +18550,7 @@ bool RISCV::CC_RISCV_FastCC(const DataLayout &DL, RISCVABI::ABI ABI,
                             ISD::ArgFlagsTy ArgFlags, CCState &State,
                             bool IsFixed, bool IsRet, Type *OrigTy,
                             const RISCVTargetLowering &TLI,
-                            std::optional<unsigned> FirstMaskArgument) {
+                            RVVArgDispatcher &RVVDispatcher) {
   if (LocVT == MVT::i32 || LocVT == MVT::i64) {
     if (unsigned Reg = State.AllocateReg(getFastCCArgGPRs(ABI))) {
       State.addLoc(CCValAssign::getReg(ValNo, ValVT, Reg, LocVT, LocInfo));
@@ -18640,13 +18628,14 @@ bool RISCV::CC_RISCV_FastCC(const DataLayout &DL, RISCVABI::ABI ABI,
   }
 
   if (LocVT.isVector()) {
-    if (unsigned Reg =
-            allocateRVVReg(ValVT, ValNo, FirstMaskArgument, State, TLI)) {
+    MCPhysReg AllocatedVReg = RVVDispatcher.getNextPhysReg();
+    if (AllocatedVReg) {
       // Fixed-length vectors are located in the corresponding scalable-vector
       // container types.
       if (ValVT.isFixedLengthVector())
         LocVT = TLI.getContainerForFixedLengthVector(LocVT);
-      State.addLoc(CCValAssign::getReg(ValNo, ValVT, Reg, LocVT, LocInfo));
+      State.addLoc(
+          CCValAssign::getReg(ValNo, ValVT, AllocatedVReg, LocVT, LocInfo));
     } else {
       // Try and pass the address via a "fast" GPR.
       if (unsigned GPRReg = State.AllocateReg(getFastCCArgGPRs(ABI))) {
@@ -19274,17 +19263,15 @@ bool RISCVTargetLowering::CanLowerReturn(
   SmallVector<CCValAssign, 16> RVLocs;
   CCState CCInfo(CallConv, IsVarArg, MF, RVLocs, Context);
 
-  std::optional<unsigned> FirstMaskArgument;
-  if (Subtarget.hasVInstructions())
-    FirstMaskArgument = preAssignMask(Outs);
+  RVVArgDispatcher Dispatcher{&MF, this, MF.getFunction().getReturnType()};
 
   for (unsigned i = 0, e = Outs.size(); i != e; ++i) {
     MVT VT = Outs[i].VT;
     ISD::ArgFlagsTy ArgFlags = Outs[i].Flags;
     RISCVABI::ABI ABI = MF.getSubtarget<RISCVSubtarget>().getTargetABI();
     if (RISCV::CC_RISCV(MF.getDataLayout(), ABI, i, VT, VT, CCValAssign::Full,
-                 ArgFlags, CCInfo, /*IsFixed=*/true, /*IsRet=*/true, nullptr,
-                 *this, FirstMaskArgument))
+                        ArgFlags, CCInfo, /*IsFixed=*/true, /*IsRet=*/true,
+                        nullptr, *this, Dispatcher))
       return false;
   }
   return true;
@@ -21081,6 +21068,121 @@ unsigned RISCVTargetLowering::getMinimumJumpTableEntries() const {
   return Subtarget.getMinimumJumpTableEntries();
 }
 
+void RVVArgDispatcher::constructArgInfos(ArrayRef<Type *> TypeList) {
+  const DataLayout &DL = MF->getDataLayout();
+  const Function &F = MF->getFunction();
+  LLVMContext &Context = F.getContext();
+
+  bool FirstVMaskAssigned = false;
+  for (Type *Ty : TypeList) {
+    StructType *STy = dyn_cast<StructType>(Ty);
+    if (STy && STy->containsHomogeneousScalableVectorTypes()) {
+      Type *ElemTy = STy->getTypeAtIndex(0U);
+      EVT VT = TLI->getValueType(DL, ElemTy);
+      MVT RegisterVT =
+          TLI->getRegisterTypeForCallingConv(Context, F.getCallingConv(), VT);
+
+      RVVArgInfos.push_back({STy->getNumElements(), RegisterVT, false});
+    } else {
+      SmallVector<EVT, 4> ValueVTs;
+      ComputeValueVTs(*TLI, DL, Ty, ValueVTs);
+
+      for (unsigned Value = 0, NumValues = ValueVTs.size(); Value != NumValues;
+           ++Value) {
+        EVT VT = ValueVTs[Value];
+        MVT RegisterVT =
+            TLI->getRegisterTypeForCallingConv(Context, F.getCallingConv(), VT);
+        unsigned NumRegs =
+            TLI->getNumRegistersForCallingConv(Context, F.getCallingConv(), VT);
+
+        // Skip non-RVV register types
+        if (!RegisterVT.isVector())
+          continue;
+
+        if (RegisterVT.isFixedLengthVector())
+          RegisterVT = TLI->getContainerForFixedLengthVector(RegisterVT);
+
+        if (!FirstVMaskAssigned &&
+            RegisterVT.getVectorElementType() == MVT::i1) {
+          RVVArgInfos.push_back({1, RegisterVT, true});
+          FirstVMaskAssigned = true;
+        } else {
+          RVVArgInfos.push_back({1, RegisterVT, false});
+        }
+
+        RVVArgInfos.insert(RVVArgInfos.end(), --NumRegs,
+                           {1, RegisterVT, false});
+      }
+    }
+  }
+}
+
+void RVVArgDispatcher::allocatePhysReg(unsigned NF, unsigned LMul,
+                                       unsigned StartReg) {
+  assert((StartReg % LMul) == 0 &&
+         "Start register number should be a multiple of LMUL");
+  const MCPhysReg *VRArrays;
+  switch (LMul) {
+  default:
+    report_fatal_error("Invalid lmul");
+  case 1:
+    VRArrays = ArgVRs;
+    break;
+  case 2:
+    VRArrays = ArgVRM2s;
+    break;
+  case 4:
+    VRArrays = ArgVRM4s;
+    break;
+  case 8:
+    VRArrays = ArgVRM8s;
+    break;
+  }
+
+  for (unsigned i = 0; i < NF; ++i)
+    if (StartReg)
+      AllocatedPhysRegs.push_back(VRArrays[(StartReg - 8) / LMul + i]);
+    else
+      AllocatedPhysRegs.push_back(MCPhysReg());
+}
+
+/// This function determines whether each RVV argument is passed by register.
+/// If the argument can be assigned to a VR, it is given a specific register;
+/// otherwise it is assigned 0, which is an invalid MCPhysReg.
+void RVVArgDispatcher::compute() {
+  uint32_t AssignedMap = 0;
+  auto allocate = [&](const RVVArgInfo &ArgInfo) {
+    // Allocate first vector mask argument to V0.
+    if (ArgInfo.FirstVMask) {
+      AllocatedPhysRegs.push_back(RISCV::V0);
+      return;
+    }
+
+    unsigned RegsNeeded = divideCeil(
+        ArgInfo.VT.getSizeInBits().getKnownMinValue(), RISCV::RVVBitsPerBlock);
+    unsigned TotalRegsNeeded = ArgInfo.NF * RegsNeeded;
+    for (unsigned StartReg = 0; StartReg + TotalRegsNeeded <= NumArgVRs;
+         StartReg += RegsNeeded) {
+      uint32_t Map = ((1 << TotalRegsNeeded) - 1) << StartReg;
+      if ((AssignedMap & Map) == 0) {
+        allocatePhysReg(ArgInfo.NF, RegsNeeded, StartReg + 8);
+        AssignedMap |= Map;
+        return;
+      }
+    }
+
+    allocatePhysReg(ArgInfo.NF, RegsNeeded, 0);
+  };
+
+  for (unsigned i = 0; i < RVVArgInfos.size(); ++i)
+    allocate(RVVArgInfos[i]);
+}
+
+MCPhysReg RVVArgDispatcher::getNextPhysReg() {
+  assert(CurIdx < AllocatedPhysRegs.size() && "Index out of range");
+  return AllocatedPhysRegs[CurIdx++];
+}
+
 namespace llvm::RISCVVIntrinsicsTable {
 
 #define GET_RISCVVIntrinsicsTable_IMPL
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index ace5b3fd2b95b4..c28552354bf422 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -24,6 +24,7 @@ namespace llvm {
 class InstructionCost;
 class RISCVSubtarget;
 struct RISCVRegisterInfo;
+class RVVArgDispatcher;
 
 namespace RISCVISD {
 // clang-format off
@@ -875,7 +876,7 @@ class RISCVTargetLowering : public TargetLowering {
                                ISD::ArgFlagsTy ArgFlags, CCState &State,
                                bool IsFixed, bool IsRet, Type *OrigTy,
                                const RISCVTargetLowering &TLI,
-                               std::optional<unsigned> FirstMaskArgument);
+                               RVVArgDispatcher &RVVDispatcher);
 
 private:
   void analyzeInputArgs(MachineFunction &MF, CCState &CCInfo,
@@ -1015,19 +1016,68 @@ class RISCVTargetLowering : public TargetLowering {
   unsigned getMinimumJumpTableEntries() const override;
 };
 
+/// As per the spec, the rules for passing vector arguments are as follows:
+///
+/// 1. The first vector mask argument is passed in v0.
+/// 2. For vector data arguments and the remaining vector mask arguments,
+/// starting from v8, if an unallocated vector register group between v8-v23
+/// can be found whose first register number is a multiple of LMUL, allocate
+/// this vector register group to the argument and mark those registers as
+/// allocated. Otherwise, the argument is passed by reference and is replaced
+/// in the argument list with the address.
+/// 3. For tuple vector data arguments, starting from v8, if NFIELDS
+/// consecutive unallocated vector register groups between v8-v23 can be found
+/// whose first register number is a multiple of LMUL, allocate these vector
+/// register groups to the argument and mark those registers as allocated.
+/// Otherwise, the argument is passed by reference and is replaced in the
+/// argument list with the address.
+class RVVArgDispatcher {
+public:
+  static constexpr unsigned NumArgVRs = 16;
+
+  struct RVVArgInfo {
+    unsigned NF;
+    MVT VT;
+    bool FirstVMask = false;
+  };
+
+  RVVArgDispatcher(const MachineFunction *MF, const RISCVTargetLowering *TLI,
+                   ArrayRef<Type *> TypeList)
+      : MF(MF), TLI(TLI) {
+    constructArgInfos(TypeList);
+    compute();
+  }
+
+  MCPhysReg getNextPhysReg();
+
+private:
+  SmallVector<RVVArgInfo, 4> RVVArgInfos;
+  SmallVector<MCPhysReg, 4> AllocatedPhysRegs;
+
+  const MachineFunction *MF = nullptr;
+  const RISCVTargetLowering *TLI = nullptr;
+
+  unsigned CurIdx = 0;
+
+  void constructArgInfos(ArrayRef<Type *> TypeList);
+  void compute();
+  void allocatePhysReg(unsigned NF = 1, unsigned LMul = 1,
+                       unsigned StartReg = 0);
+};
+
 namespace RISCV {
 
 bool CC_RISCV(const DataLayout &DL, RISCVABI::ABI ABI, unsigned ValNo,
               MVT ValVT, MVT LocVT, CCValAssign::LocInfo LocInfo,
               ISD::ArgFlagsTy ArgFlags, CCState &State, bool IsFixed,
               bool IsRet, Type *OrigTy, const RISCVTargetLowering &TLI,
-              std::optional<unsigned> FirstMaskArgument);
+              RVVArgDispatcher &RVVDispatcher);
 
 bool CC_RISCV_FastCC(const DataLayout &DL, RISCVABI::ABI ABI, unsigned ValNo,
                      MVT ValVT, MVT LocVT, CCValAssign::LocInfo LocInfo,
                      ISD::ArgFlagsTy ArgFlags, CCState &State, bool IsFixed,
                      bool IsRet, Type *OrigTy, const RISCVTargetLowering &TLI,
-                     std::optional<unsigned> FirstMaskArgument);
+                     RVVArgDispatcher &RVVDispatcher);
 
 bool CC_RISCV_GHC(unsigned ValNo, MVT ValVT, MVT LocVT,
                   CCValAssign::LocInfo LocInfo, ISD::ArgFlagsTy ArgFlags,
diff --git a/llvm/test/CodeGen/RISCV/rvv/calling-conv.ll b/llvm/test/CodeGen/RISCV/rvv/calling-conv.ll
index 78e8700a9feff8..90edb994ce8222 100644
--- a/llvm/test/CodeGen/RISCV/rvv/calling-conv.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/calling-conv.ll
@@ -162,3 +162,90 @@ define void @caller_tuple_argument({<vscale x 4 x i32>, <vscale x 4 x i32>} %x)
 }
 
 declare void @callee_tuple_argument({<vscale x 4 x i32>, <vscale x 4 x i32>})
+
+; %0 -> v8
+; %1 -> v9
+define <vscale x 1 x i64> @case1(<vscale x 1 x i64> %0, <vscale x 1 x i64> %1) {
+; CHECK-LABEL: case1:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli a0, zero, e64, m1, ta, ma
+; CHECK-NEXT:    vadd.vv v8, v8, v9
+; CHECK-NEXT:    ret
+  %a = add <vscale x 1 x i64> %0, %1
+  ret <vscale x 1 x i64> %a
+}
+
+; %0 -> v8
+; %1 -> v10-v11
+; %2 -> v9
+define <vscale x 1 x i64> @case2_1(<vscale x 1 x i64> %0, <vscale x 2 x i64> %1, <vscale x 1 x i64> %2) {
+; CHECK-LABEL: case2_1:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli a0, zero, e64, m1, ta, ma
+; CHECK-NEXT:    vadd.vv v8, v8, v9
+; CHECK-NEXT:    ret
+  %a = add <vscale x 1 x i64> %0, %2
+  ret <vscale x 1 x i64> %a
+}
+define <vscale x 2 x i64> @case2_2(<vscale x 1 x i64> %0, <vscale x 2 x i64> %1, <vscale x 1 x i64> %2) {
+; CHECK-LABEL: case2_2:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli a0, zero, e64, m2, ta, ma
+; CHECK-NEXT:    vadd.vv v8, v10, v10
+; CHECK-NEXT:    ret
+  %a = add <vscale x 2 x i64> %1, %1
+  ret <vscale x 2 x i64> %a
+}
+
+; %0 -> v8
+; %1 -> {v10-v11, v12-v13}
+; %2 -> v9
+define <vscale x 1 x i64> @case3_1(<vscale x 1 x i64> %0, {<vscale x 2 x i64>, <vscale x 2 x i64>} %1, <vscale x 1 x i64> %2) {
+; CHECK-LABEL: case3_1:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli a0, zero, e64, m1, ta, ma
+; CHECK-NEXT:    vadd.vv v8, v8, v9
+; CHECK-NEXT:    ret
+  %add = add <vscale x 1 x i64> %0, %2
+  ret <vscale x 1 x i64> %add
+}
+define <vscale x 2 x i64> @case3_2(<vscale x 1 x i64> %0, {<vscale x 2 x i64>, <vscale x 2 x i64>} %1, <vscale x 1 x i64> %2) {
+; CHECK-LABEL: case3_2:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli a0, zero, e64, m2, ta, ma
+; CHECK-NEXT:    vadd.vv v8, v10, v12
+; CHECK-NEXT:    ret
+  %a = extractvalue { <vscale x 2 x i64>, <vscale x 2 x i64> } %1, 0
+  %b = extractvalue { <vscale x 2 x i64>, <vscale x 2 x i64> } %1, 1
+  %add = add <vscale x 2 x i64> %a, %b
+  ret <vscale x 2 x i64> %add
+}
+
+; %0 -> v8
+; %1 -> {by-ref, by-ref}
+; %2 -> v9
+define <vscale x 8 x i64> @case4_1(<vscale x 1 x i64> %0, {<vscale x 8 x i64>, <vscale x 8 x i64>} %1, <vscale x 1 x i64> %2) {
+; CHECK-LABEL: case4_1:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    csrr a1, vlenb
+; CHECK-NEXT:    slli a1, a1, 3
+; CHECK-NEXT:    add a1, a0, a1
+; CHECK-NEXT:    vl8re64.v v8, (a1)
+; CHECK-NEXT:    vl8re64.v v16, (a0)
+; CHECK-NEXT:    vsetvli a0, zero, e64, m8, ta, ma
+; CHECK-NEXT:    vadd.vv v8, v16, v8
+; CHECK-NEXT:    ret
+  %a = extractvalue { <vscale x 8 x i64>, <vscale x 8 x i64> } %1, 0
+  %b = extractvalue { <vscale x 8 x i64>, <vscale x 8 x i64> } %1, 1
+  %add = add <vscale x 8 x i64> %a, %b
+  ret <vscale x 8 x i64> %add
+}
+define <vscale x 1 x i64> @case4_2(<vscale x 1 x i64> %0, {<vscale x 8 x i64>, <vscale x 8 x i64>} %1, <vscale x 1 x i64> %2) {
+; CHECK-LABEL: case4_2:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli a0, zero, e64, m1, ta, ma
+; CHECK-NEXT:    vadd.vv v8, v8, v9
+; CHECK-NEXT:    ret
+  %add = add <vscale x 1 x i64> %0, %2
+  ret <vscale x 1 x i64> %add
+}
diff --git a/llvm/test/CodeGen/RISCV/rvv/vector-deinterleave-load.ll b/llvm/test/CodeGen/RISCV/rvv/vector-deinterleave-load.ll
index a320aecc6fce49..6a712080fda74a 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vector-deinterleave-load.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vector-deinterleave-load.ll
@@ -18,10 +18,10 @@ define {<vscale x 16 x i1>, <vscale x 16 x i1>} @vector_deinterleave_load_nxv16i
 ; CHECK-NEXT:    vmerge.vim v14, v10, 1, v0
 ; CHECK-NEXT:    vmv1r.v v0, v8
 ; CHECK-NEXT:    vmerge.vim v12, v10, 1, v0
-; CHECK-NEXT:    vnsrl.wi v8, v12, 0
-; CHECK-NEXT:    vmsne.vi v0, v8, 0
-; CHECK-NEXT:    vnsrl.wi v10, v12, 8
+; CHECK-NEXT:    vnsrl.wi v10, v12, 0
 ; CHECK-NEXT:    vmsne.vi v8, v10, 0
+; CHECK-NEXT:    vnsrl.wi v10, v12, 8
+; CHECK-NEXT:    vmsne.vi v9, v10, 0
 ; CHECK-NEXT:    ret
   %vec = load <vscale x 32 x i1>, ptr %p
   %retval = call {<vscale x 16 x i1>, <vscale x 16 x i1>} @llvm.experimental.vector.deinterleave2.nxv32i1(<vscale x 32 x i1> %vec)
diff --git a/llvm/test/CodeGen/RISCV/rvv/vector-deinterleave.ll b/llvm/test/CodeGen/RISCV/rvv/vector-deinterleave.ll
index ef4baf34d23f03..d98597fabcd953 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vector-deinterleave.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vector-deinterleave.ll
@@ -8,18 +8,18 @@ define {<vscale x 16 x i1>, <vscale x 16 x i1>} @vector_deinterleave_nxv16i1_nxv
 ; CHECK-LABEL: vector_deinterleave_nxv16i1_nxv32i1:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    vsetvli a0, zero, e8, m2, ta, ma
-; CHECK-NEXT:    vmv.v.i v10, 0
-; CHECK-NEXT:    vmerge.vim v8, v10, 1, v0
+; CHECK-NEXT:    vmv.v.i v8, 0
+; CHECK-NEXT:    vmerge.vim v12, v8, 1, v0
 ; CHECK-NEXT:    csrr a0, vlenb
 ; CHECK-NEXT:    srli a0, a0, 2
 ; CHECK-NEXT:    vsetvli a1, zero, e8, mf2, ta, ma
 ; CHECK-NEXT:    vslidedown.vx v0, v0, a0
 ; CHECK-NEXT:    vsetvli a0, zero, e8, m2, ta, ma
-; CHECK-NEXT:    vmerge.vim v10, v10, 1, v0
-; CHECK-NEXT:    vnsrl.wi v12, v8, 0
-; CHECK-NEXT:    vmsne.vi v0, v12, 0
-; CHECK-NEXT:    vnsrl.wi v12, v8, 8
-; CHECK-NEXT:    vmsne.vi v8, v12, 0
+; CHECK-NEXT:    vmerge.vim v14, v8, 1, v0
+; CHECK-NEXT:    vnsrl.wi v10, v12, 0
+; CHECK-NEXT:    vmsne.vi v8, v10, 0
+; CHECK-NEXT:    vnsrl.wi v10, v12, 8
+; CHECK-NEXT:    vmsne.vi v9, v10, 0
 ; CHECK-NEXT:    ret
 %retval = call {<vscale x 16 x i1>, <vscale x 16 x i1>} @llvm.experimental.vector.deinterleave2.nxv32i1(<vscale x 32 x i1> %vec)
 ret {<vscale x 16 x i1>, <vscale x 16 x i1>} %retval
@@ -102,12 +102,13 @@ define {<vscale x 64 x i1>, <vscale x 64 x i1>} @vector_deinterleave_nxv64i1_nxv
 ; CHECK-NEXT:    vsetvli a0, zero, e8, m4, ta, ma
 ; CHECK-NEXT:    vnsrl.wi v28, v8, 0
 ; CHECK-NEXT:    vsetvli a0, zero, e8, m8, ta, ma
-; CHECK-NEXT:    vmsne.vi v0, v24, 0
+; CHECK-NEXT:    vmsne.vi v7, v24, 0
 ; CHECK-NEXT:    vsetvli a0, zero, e8, m4, ta, ma
 ; CHECK-NEXT:    vnsrl.wi v24, v16, 8
 ; CHECK-NEXT:    vnsrl.wi v28, v8, 8
 ; CHECK-NEXT:    vsetvli a0, zero, e8, m8, ta, ma
-; CHECK-NEXT:    vmsne.vi v8, v24, 0
+; CHECK-NEXT:    vmsne.vi v9, v24, 0
+; CHECK-NEXT:    vmv1r.v v8, v7
 ; CHECK-NEXT:    ret
 %retval = call {<vscale x 64 x i1>, <vscale x 64 x i1>} @llvm.experimental.vector.deinterleave2.nxv128i1(<vscale x 128 x i1> %vec)
 ret {<vscale x 64 x i1>, <vscale x 64 x i1>} %retval


