[llvm] [RISCV] RISCV vector calling convention (2/2) (PR #79096)

Brandon Wu via llvm-commits llvm-commits at lists.llvm.org
Sat Feb 3 07:46:18 PST 2024


https://github.com/4vtomat updated https://github.com/llvm/llvm-project/pull/79096

>From 136a47987792938ae48896ff6695faa40213ce07 Mon Sep 17 00:00:00 2001
From: Brandon Wu <brandon.wu at sifive.com>
Date: Mon, 22 Jan 2024 21:35:46 -0800
Subject: [PATCH 1/4] [RISCV] RISCV vector calling convention (2/2)

This commit handles vector arguments and return values for function
definitions and calls. A new class, RVVArgDispatcher, is added to perform all
vector register assignment, covering mask types, data types, and tuple types.
It precomputes the register number for each argument as per
https://github.com/riscv-non-isa/riscv-elf-psabi-doc/blob/master/riscv-cc.adoc#standard-vector-calling-convention-variant
and is passed to the calling convention functions to handle all vector arguments.

Depends on: #78550
---
 .../Target/RISCV/GISel/RISCVCallLowering.cpp  |  49 ++---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp   | 186 ++++++++++++++----
 llvm/lib/Target/RISCV/RISCVISelLowering.h     |  49 ++++-
 llvm/test/CodeGen/RISCV/rvv/calling-conv.ll   |  87 ++++++++
 .../RISCV/rvv/vector-deinterleave-load.ll     |   6 +-
 .../CodeGen/RISCV/rvv/vector-deinterleave.ll  |  17 +-
 6 files changed, 311 insertions(+), 83 deletions(-)

diff --git a/llvm/lib/Target/RISCV/GISel/RISCVCallLowering.cpp b/llvm/lib/Target/RISCV/GISel/RISCVCallLowering.cpp
index 45e19cdea300b..ff87601046dc2 100644
--- a/llvm/lib/Target/RISCV/GISel/RISCVCallLowering.cpp
+++ b/llvm/lib/Target/RISCV/GISel/RISCVCallLowering.cpp
@@ -34,14 +34,15 @@ struct RISCVOutgoingValueAssigner : public CallLowering::OutgoingValueAssigner {
   // Whether this is assigning args for a return.
   bool IsRet;
 
-  // true if assignArg has been called for a mask argument, false otherwise.
-  bool AssignedFirstMaskArg = false;
+  RVVArgDispatcher RVVDispatcher;
 
 public:
   RISCVOutgoingValueAssigner(
-      RISCVTargetLowering::RISCVCCAssignFn *RISCVAssignFn_, bool IsRet)
+      RISCVTargetLowering::RISCVCCAssignFn *RISCVAssignFn_, bool IsRet,
+      RVVArgDispatcher &RVVDispatcher)
       : CallLowering::OutgoingValueAssigner(nullptr),
-        RISCVAssignFn(RISCVAssignFn_), IsRet(IsRet) {}
+        RISCVAssignFn(RISCVAssignFn_), IsRet(IsRet),
+        RVVDispatcher(RVVDispatcher) {}
 
   bool assignArg(unsigned ValNo, EVT OrigVT, MVT ValVT, MVT LocVT,
                  CCValAssign::LocInfo LocInfo,
@@ -51,16 +52,9 @@ struct RISCVOutgoingValueAssigner : public CallLowering::OutgoingValueAssigner {
     const DataLayout &DL = MF.getDataLayout();
     const RISCVSubtarget &Subtarget = MF.getSubtarget<RISCVSubtarget>();
 
-    std::optional<unsigned> FirstMaskArgument;
-    if (Subtarget.hasVInstructions() && !AssignedFirstMaskArg &&
-        ValVT.isVector() && ValVT.getVectorElementType() == MVT::i1) {
-      FirstMaskArgument = ValNo;
-      AssignedFirstMaskArg = true;
-    }
-
     if (RISCVAssignFn(DL, Subtarget.getTargetABI(), ValNo, ValVT, LocVT,
                       LocInfo, Flags, State, Info.IsFixed, IsRet, Info.Ty,
-                      *Subtarget.getTargetLowering(), FirstMaskArgument))
+                      *Subtarget.getTargetLowering(), RVVDispatcher))
       return true;
 
     StackSize = State.getStackSize();
@@ -181,14 +175,15 @@ struct RISCVIncomingValueAssigner : public CallLowering::IncomingValueAssigner {
   // Whether this is assigning args from a return.
   bool IsRet;
 
-  // true if assignArg has been called for a mask argument, false otherwise.
-  bool AssignedFirstMaskArg = false;
+  RVVArgDispatcher &RVVDispatcher;
 
 public:
   RISCVIncomingValueAssigner(
-      RISCVTargetLowering::RISCVCCAssignFn *RISCVAssignFn_, bool IsRet)
+      RISCVTargetLowering::RISCVCCAssignFn *RISCVAssignFn_, bool IsRet,
+      RVVArgDispatcher &RVVDispatcher)
       : CallLowering::IncomingValueAssigner(nullptr),
-        RISCVAssignFn(RISCVAssignFn_), IsRet(IsRet) {}
+        RISCVAssignFn(RISCVAssignFn_), IsRet(IsRet),
+        RVVDispatcher(RVVDispatcher) {}
 
   bool assignArg(unsigned ValNo, EVT OrigVT, MVT ValVT, MVT LocVT,
                  CCValAssign::LocInfo LocInfo,
@@ -201,16 +196,9 @@ struct RISCVIncomingValueAssigner : public CallLowering::IncomingValueAssigner {
     if (LocVT.isScalableVector())
       MF.getInfo<RISCVMachineFunctionInfo>()->setIsVectorCall();
 
-    std::optional<unsigned> FirstMaskArgument;
-    if (Subtarget.hasVInstructions() && !AssignedFirstMaskArg &&
-        ValVT.isVector() && ValVT.getVectorElementType() == MVT::i1) {
-      FirstMaskArgument = ValNo;
-      AssignedFirstMaskArg = true;
-    }
-
     if (RISCVAssignFn(DL, Subtarget.getTargetABI(), ValNo, ValVT, LocVT,
                       LocInfo, Flags, State, /*IsFixed=*/true, IsRet, Info.Ty,
-                      *Subtarget.getTargetLowering(), FirstMaskArgument))
+                      *Subtarget.getTargetLowering(), RVVDispatcher))
       return true;
 
     StackSize = State.getStackSize();
@@ -420,9 +408,10 @@ bool RISCVCallLowering::lowerReturnVal(MachineIRBuilder &MIRBuilder,
   SmallVector<ArgInfo, 4> SplitRetInfos;
   splitToValueTypes(OrigRetInfo, SplitRetInfos, DL, CC);
 
+  RVVArgDispatcher Dispatcher{&MF, getTLI<RISCVTargetLowering>(), true};
   RISCVOutgoingValueAssigner Assigner(
       CC == CallingConv::Fast ? RISCV::CC_RISCV_FastCC : RISCV::CC_RISCV,
-      /*IsRet=*/true);
+      /*IsRet=*/true, Dispatcher);
   RISCVOutgoingValueHandler Handler(MIRBuilder, MF.getRegInfo(), Ret);
   return determineAndHandleAssignments(Handler, Assigner, SplitRetInfos,
                                        MIRBuilder, CC, F.isVarArg());
@@ -545,9 +534,10 @@ bool RISCVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
     ++Index;
   }
 
+  RVVArgDispatcher Dispatcher{&MF, getTLI<RISCVTargetLowering>(), false};
   RISCVIncomingValueAssigner Assigner(
       CC == CallingConv::Fast ? RISCV::CC_RISCV_FastCC : RISCV::CC_RISCV,
-      /*IsRet=*/false);
+      /*IsRet=*/false, Dispatcher);
   RISCVFormalArgHandler Handler(MIRBuilder, MF.getRegInfo());
 
   SmallVector<CCValAssign, 16> ArgLocs;
@@ -607,9 +597,11 @@ bool RISCVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
   const TargetRegisterInfo *TRI = Subtarget.getRegisterInfo();
   Call.addRegMask(TRI->getCallPreservedMask(MF, Info.CallConv));
 
+  RVVArgDispatcher ArgDispatcher{&MF, getTLI<RISCVTargetLowering>(), false,
+                                 nullptr, &Info};
   RISCVOutgoingValueAssigner ArgAssigner(
       CC == CallingConv::Fast ? RISCV::CC_RISCV_FastCC : RISCV::CC_RISCV,
-      /*IsRet=*/false);
+      /*IsRet=*/false, ArgDispatcher);
   RISCVOutgoingValueHandler ArgHandler(MIRBuilder, MF.getRegInfo(), Call);
   if (!determineAndHandleAssignments(ArgHandler, ArgAssigner, SplitArgInfos,
                                      MIRBuilder, CC, Info.IsVarArg))
@@ -637,9 +629,10 @@ bool RISCVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
   SmallVector<ArgInfo, 4> SplitRetInfos;
   splitToValueTypes(Info.OrigRet, SplitRetInfos, DL, CC);
 
+  RVVArgDispatcher RetDispatcher{&MF, getTLI<RISCVTargetLowering>(), true};
   RISCVIncomingValueAssigner RetAssigner(
       CC == CallingConv::Fast ? RISCV::CC_RISCV_FastCC : RISCV::CC_RISCV,
-      /*IsRet=*/true);
+      /*IsRet=*/true, RetDispatcher);
   RISCVCallReturnHandler RetHandler(MIRBuilder, MF.getRegInfo(), Call);
   if (!determineAndHandleAssignments(RetHandler, RetAssigner, SplitRetInfos,
                                      MIRBuilder, CC, Info.IsVarArg))
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 37d8ada1f926c..5490bab930fec 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -22,6 +22,7 @@
 #include "llvm/ADT/Statistic.h"
 #include "llvm/Analysis/MemoryLocation.h"
 #include "llvm/Analysis/VectorUtils.h"
+#include "llvm/CodeGen/Analysis.h"
 #include "llvm/CodeGen/MachineFrameInfo.h"
 #include "llvm/CodeGen/MachineFunction.h"
 #include "llvm/CodeGen/MachineInstrBuilder.h"
@@ -17527,33 +17528,12 @@ static bool CC_RISCVAssign2XLen(unsigned XLen, CCState &State, CCValAssign VA1,
   return false;
 }
 
-static unsigned allocateRVVReg(MVT ValVT, unsigned ValNo,
-                               std::optional<unsigned> FirstMaskArgument,
-                               CCState &State, const RISCVTargetLowering &TLI) {
-  const TargetRegisterClass *RC = TLI.getRegClassFor(ValVT);
-  if (RC == &RISCV::VRRegClass) {
-    // Assign the first mask argument to V0.
-    // This is an interim calling convention and it may be changed in the
-    // future.
-    if (FirstMaskArgument && ValNo == *FirstMaskArgument)
-      return State.AllocateReg(RISCV::V0);
-    return State.AllocateReg(ArgVRs);
-  }
-  if (RC == &RISCV::VRM2RegClass)
-    return State.AllocateReg(ArgVRM2s);
-  if (RC == &RISCV::VRM4RegClass)
-    return State.AllocateReg(ArgVRM4s);
-  if (RC == &RISCV::VRM8RegClass)
-    return State.AllocateReg(ArgVRM8s);
-  llvm_unreachable("Unhandled register class for ValueType");
-}
-
 // Implements the RISC-V calling convention. Returns true upon failure.
 bool RISCV::CC_RISCV(const DataLayout &DL, RISCVABI::ABI ABI, unsigned ValNo,
                      MVT ValVT, MVT LocVT, CCValAssign::LocInfo LocInfo,
                      ISD::ArgFlagsTy ArgFlags, CCState &State, bool IsFixed,
                      bool IsRet, Type *OrigTy, const RISCVTargetLowering &TLI,
-                     std::optional<unsigned> FirstMaskArgument) {
+                     RVVArgDispatcher &RVVDispatcher) {
   unsigned XLen = DL.getLargestLegalIntTypeSizeInBits();
   assert(XLen == 32 || XLen == 64);
   MVT XLenVT = XLen == 32 ? MVT::i32 : MVT::i64;
@@ -17711,7 +17691,7 @@ bool RISCV::CC_RISCV(const DataLayout &DL, RISCVABI::ABI ABI, unsigned ValNo,
   }
 
   // Allocate to a register if possible, or else a stack slot.
-  Register Reg;
+  Register Reg = MCRegister();
   unsigned StoreSizeBytes = XLen / 8;
   Align StackAlign = Align(XLen / 8);
 
@@ -17722,7 +17702,7 @@ bool RISCV::CC_RISCV(const DataLayout &DL, RISCVABI::ABI ABI, unsigned ValNo,
   else if (ValVT == MVT::f64 && !UseGPRForF64)
     Reg = State.AllocateReg(ArgFPR64s);
   else if (ValVT.isVector()) {
-    Reg = allocateRVVReg(ValVT, ValNo, FirstMaskArgument, State, TLI);
+    Reg = RVVDispatcher.getNextPhysReg();
     if (!Reg) {
       // For return values, the vector must be passed fully via registers or
       // via the stack.
@@ -17808,9 +17788,7 @@ void RISCVTargetLowering::analyzeInputArgs(
   unsigned NumArgs = Ins.size();
   FunctionType *FType = MF.getFunction().getFunctionType();
 
-  std::optional<unsigned> FirstMaskArgument;
-  if (Subtarget.hasVInstructions())
-    FirstMaskArgument = preAssignMask(Ins);
+  RVVArgDispatcher Dispatcher{&MF, this, IsRet};
 
   for (unsigned i = 0; i != NumArgs; ++i) {
     MVT ArgVT = Ins[i].VT;
@@ -17825,7 +17803,7 @@ void RISCVTargetLowering::analyzeInputArgs(
     RISCVABI::ABI ABI = MF.getSubtarget<RISCVSubtarget>().getTargetABI();
     if (Fn(MF.getDataLayout(), ABI, i, ArgVT, ArgVT, CCValAssign::Full,
            ArgFlags, CCInfo, /*IsFixed=*/true, IsRet, ArgTy, *this,
-           FirstMaskArgument)) {
+           Dispatcher)) {
       LLVM_DEBUG(dbgs() << "InputArg #" << i << " has unhandled type "
                         << ArgVT << '\n');
       llvm_unreachable(nullptr);
@@ -17839,9 +17817,7 @@ void RISCVTargetLowering::analyzeOutputArgs(
     CallLoweringInfo *CLI, RISCVCCAssignFn Fn) const {
   unsigned NumArgs = Outs.size();
 
-  std::optional<unsigned> FirstMaskArgument;
-  if (Subtarget.hasVInstructions())
-    FirstMaskArgument = preAssignMask(Outs);
+  RVVArgDispatcher Dispatcher{&MF, this, IsRet, CLI};
 
   for (unsigned i = 0; i != NumArgs; i++) {
     MVT ArgVT = Outs[i].VT;
@@ -17851,7 +17827,7 @@ void RISCVTargetLowering::analyzeOutputArgs(
     RISCVABI::ABI ABI = MF.getSubtarget<RISCVSubtarget>().getTargetABI();
     if (Fn(MF.getDataLayout(), ABI, i, ArgVT, ArgVT, CCValAssign::Full,
            ArgFlags, CCInfo, Outs[i].IsFixed, IsRet, OrigTy, *this,
-           FirstMaskArgument)) {
+           Dispatcher)) {
       LLVM_DEBUG(dbgs() << "OutputArg #" << i << " has unhandled type "
                         << ArgVT << "\n");
       llvm_unreachable(nullptr);
@@ -18032,7 +18008,7 @@ bool RISCV::CC_RISCV_FastCC(const DataLayout &DL, RISCVABI::ABI ABI,
                             ISD::ArgFlagsTy ArgFlags, CCState &State,
                             bool IsFixed, bool IsRet, Type *OrigTy,
                             const RISCVTargetLowering &TLI,
-                            std::optional<unsigned> FirstMaskArgument) {
+                            RVVArgDispatcher &RVVDispatcher) {
   if (LocVT == MVT::i32 || LocVT == MVT::i64) {
     if (unsigned Reg = State.AllocateReg(getFastCCArgGPRs(ABI))) {
       State.addLoc(CCValAssign::getReg(ValNo, ValVT, Reg, LocVT, LocInfo));
@@ -18110,13 +18086,14 @@ bool RISCV::CC_RISCV_FastCC(const DataLayout &DL, RISCVABI::ABI ABI,
   }
 
   if (LocVT.isVector()) {
-    if (unsigned Reg =
-            allocateRVVReg(ValVT, ValNo, FirstMaskArgument, State, TLI)) {
+    MCPhysReg AllocatedVReg = RVVDispatcher.getNextPhysReg();
+    if (AllocatedVReg) {
       // Fixed-length vectors are located in the corresponding scalable-vector
       // container types.
       if (ValVT.isFixedLengthVector())
         LocVT = TLI.getContainerForFixedLengthVector(LocVT);
-      State.addLoc(CCValAssign::getReg(ValNo, ValVT, Reg, LocVT, LocInfo));
+      State.addLoc(
+          CCValAssign::getReg(ValNo, ValVT, AllocatedVReg, LocVT, LocInfo));
     } else {
       // Try and pass the address via a "fast" GPR.
       if (unsigned GPRReg = State.AllocateReg(getFastCCArgGPRs(ABI))) {
@@ -18743,17 +18720,15 @@ bool RISCVTargetLowering::CanLowerReturn(
   SmallVector<CCValAssign, 16> RVLocs;
   CCState CCInfo(CallConv, IsVarArg, MF, RVLocs, Context);
 
-  std::optional<unsigned> FirstMaskArgument;
-  if (Subtarget.hasVInstructions())
-    FirstMaskArgument = preAssignMask(Outs);
+  RVVArgDispatcher Dispatcher{&MF, this, true};
 
   for (unsigned i = 0, e = Outs.size(); i != e; ++i) {
     MVT VT = Outs[i].VT;
     ISD::ArgFlagsTy ArgFlags = Outs[i].Flags;
     RISCVABI::ABI ABI = MF.getSubtarget<RISCVSubtarget>().getTargetABI();
     if (RISCV::CC_RISCV(MF.getDataLayout(), ABI, i, VT, VT, CCValAssign::Full,
-                 ArgFlags, CCInfo, /*IsFixed=*/true, /*IsRet=*/true, nullptr,
-                 *this, FirstMaskArgument))
+                        ArgFlags, CCInfo, /*IsFixed=*/true, /*IsRet=*/true,
+                        nullptr, *this, Dispatcher))
       return false;
   }
   return true;
@@ -20543,6 +20518,135 @@ unsigned RISCVTargetLowering::getMinimumJumpTableEntries() const {
   return Subtarget.getMinimumJumpTableEntries();
 }
 
+void RVVArgDispatcher::constructHelper(Type *Ty) {
+  const DataLayout &DL = MF->getDataLayout();
+  const Function &F = MF->getFunction();
+  LLVMContext &Context = F.getContext();
+
+  StructType *STy = dyn_cast<StructType>(Ty);
+  if (STy && STy->containsHomogeneousScalableVectorTypes()) {
+    Type *ElemTy = STy->getTypeAtIndex(0U);
+    EVT VT = TLI->getValueType(DL, ElemTy);
+    MVT RegisterVT =
+        TLI->getRegisterTypeForCallingConv(Context, F.getCallingConv(), VT);
+
+    RVVArgInfos.push_back({STy->getNumElements(), RegisterVT, false});
+  } else {
+    SmallVector<EVT, 4> ValueVTs;
+    ComputeValueVTs(*TLI, DL, Ty, ValueVTs);
+
+    for (unsigned Value = 0, NumValues = ValueVTs.size(); Value != NumValues;
+         ++Value) {
+      EVT VT = ValueVTs[Value];
+      MVT RegisterVT =
+          TLI->getRegisterTypeForCallingConv(Context, F.getCallingConv(), VT);
+      unsigned NumRegs =
+          TLI->getNumRegistersForCallingConv(Context, F.getCallingConv(), VT);
+
+      // Skip non-RVV register type
+      if (!RegisterVT.isVector())
+        continue;
+
+      if (RegisterVT.isFixedLengthVector())
+        RegisterVT = TLI->getContainerForFixedLengthVector(RegisterVT);
+
+      RVVArgInfo Info{1, RegisterVT, false};
+
+      while (NumRegs--)
+        RVVArgInfos.push_back(Info);
+    }
+  }
+}
+
+void RVVArgDispatcher::construct(bool IsRet) {
+  const Function &F = MF->getFunction();
+
+  if (IsRet)
+    constructHelper(F.getReturnType());
+  else if (CLI)
+    for (const TargetLowering::ArgListEntry &Arg : CLI->getArgs())
+      constructHelper(Arg.Ty);
+  else if (GISelCLI)
+    for (const CallLowering::ArgInfo &Arg : GISelCLI->OrigArgs)
+      constructHelper(Arg.Ty);
+  else
+    for (const Argument &Arg : F.args())
+      constructHelper(Arg.getType());
+
+  for (auto &Info : RVVArgInfos)
+    if (Info.NF == 1 && Info.VT.getVectorElementType() == MVT::i1) {
+      Info.FirstVMask = true;
+      break;
+    }
+}
+
+void RVVArgDispatcher::allocatePhysReg(unsigned NF, unsigned LMul,
+                                       unsigned StartReg) {
+  assert((StartReg % LMul) == 0 &&
+         "Start register number should be multiple of lmul");
+  const MCPhysReg *VRArrays;
+  switch (LMul) {
+  default:
+    report_fatal_error("Invalid lmul");
+  case 1:
+    VRArrays = ArgVRs;
+    break;
+  case 2:
+    VRArrays = ArgVRM2s;
+    break;
+  case 4:
+    VRArrays = ArgVRM4s;
+    break;
+  case 8:
+    VRArrays = ArgVRM8s;
+    break;
+  }
+
+  for (unsigned i = 0; i < NF; ++i)
+    if (StartReg)
+      AllocatedPhysRegs.push_back(VRArrays[(StartReg - 8) / LMul + i]);
+    else
+      AllocatedPhysRegs.push_back(0);
+}
+
+// This function determines if each RVV argument is passed by register.
+void RVVArgDispatcher::compute() {
+  unsigned ToBeAssigned = RVVArgInfos.size();
+  uint64_t AssignedMap = 0;
+  auto tryAllocate = [&](const RVVArgInfo &ArgInfo) {
+    // Allocate first vector mask argument to V0.
+    if (ArgInfo.FirstVMask) {
+      AllocatedPhysRegs.push_back(RISCV::V0);
+      return;
+    }
+
+    unsigned RegsNeeded = std::max(
+        ArgInfo.VT.getSizeInBits().getKnownMinValue() / RISCV::RVVBitsPerBlock,
+        1UL);
+    unsigned TotalRegsNeeded = ArgInfo.NF * RegsNeeded;
+    for (unsigned StartReg = 0; StartReg + TotalRegsNeeded <= NumArgVRs;
+         StartReg += RegsNeeded) {
+      unsigned Map = ((1 << TotalRegsNeeded) - 1) << StartReg;
+      if ((AssignedMap & Map) == 0) {
+        allocatePhysReg(ArgInfo.NF, RegsNeeded, StartReg + 8);
+        AssignedMap |= Map;
+        return;
+      }
+    }
+
+    allocatePhysReg(ArgInfo.NF, RegsNeeded, 0);
+    return;
+  };
+
+  for (unsigned i = 0; i < ToBeAssigned; ++i)
+    tryAllocate(RVVArgInfos[i]);
+}
+
+MCPhysReg RVVArgDispatcher::getNextPhysReg() {
+  assert(CurIdx < AllocatedPhysRegs.size() && "Index out of range");
+  return AllocatedPhysRegs[CurIdx++];
+}
+
 namespace llvm::RISCVVIntrinsicsTable {
 
 #define GET_RISCVVIntrinsicsTable_IMPL
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index 255b1d0e15eed..3224faf3b34a9 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -16,6 +16,7 @@
 
 #include "RISCV.h"
 #include "llvm/CodeGen/CallingConvLower.h"
+#include "llvm/CodeGen/GlobalISel/CallLowering.h"
 #include "llvm/CodeGen/SelectionDAG.h"
 #include "llvm/CodeGen/TargetLowering.h"
 #include "llvm/TargetParser/RISCVTargetParser.h"
@@ -25,6 +26,7 @@ namespace llvm {
 class InstructionCost;
 class RISCVSubtarget;
 struct RISCVRegisterInfo;
+class RVVArgDispatcher;
 
 namespace RISCVISD {
 // clang-format off
@@ -868,7 +870,7 @@ class RISCVTargetLowering : public TargetLowering {
                                ISD::ArgFlagsTy ArgFlags, CCState &State,
                                bool IsFixed, bool IsRet, Type *OrigTy,
                                const RISCVTargetLowering &TLI,
-                               std::optional<unsigned> FirstMaskArgument);
+                               RVVArgDispatcher &RVVDispatcher);
 
 private:
   void analyzeInputArgs(MachineFunction &MF, CCState &CCInfo,
@@ -1008,19 +1010,60 @@ class RISCVTargetLowering : public TargetLowering {
   unsigned getMinimumJumpTableEntries() const override;
 };
 
+class RVVArgDispatcher {
+public:
+  static constexpr unsigned NumArgVRs = 16;
+
+  struct RVVArgInfo {
+    unsigned NF;
+    MVT VT;
+    bool FirstVMask = false;
+  };
+
+  RVVArgDispatcher() = default;
+  RVVArgDispatcher(const MachineFunction *MF, const RISCVTargetLowering *TLI,
+                   bool IsRet, TargetLowering::CallLoweringInfo *CLI = nullptr,
+                   CallLowering::CallLoweringInfo *GISelCLI = nullptr)
+      : MF(MF), TLI(TLI), CLI(CLI), GISelCLI(GISelCLI) {
+    assert((!CLI || !GISelCLI) &&
+           "ISel and Global ISel can't co-exist at the same time");
+    construct(IsRet);
+    compute();
+  }
+
+  MCPhysReg getNextPhysReg();
+
+private:
+  std::vector<RVVArgInfo> RVVArgInfos;
+  std::vector<MCPhysReg> AllocatedPhysRegs;
+
+  const MachineFunction *MF = nullptr;
+  const RISCVTargetLowering *TLI = nullptr;
+  TargetLowering::CallLoweringInfo *CLI = nullptr;
+  CallLowering::CallLoweringInfo *GISelCLI = nullptr;
+
+  unsigned CurIdx = 0;
+
+  void construct(bool IsRet);
+  void constructHelper(Type *Ty);
+  void compute();
+  void allocatePhysReg(unsigned NF = 1, unsigned LMul = 1,
+                       unsigned StartReg = 0);
+};
+
 namespace RISCV {
 
 bool CC_RISCV(const DataLayout &DL, RISCVABI::ABI ABI, unsigned ValNo,
               MVT ValVT, MVT LocVT, CCValAssign::LocInfo LocInfo,
               ISD::ArgFlagsTy ArgFlags, CCState &State, bool IsFixed,
               bool IsRet, Type *OrigTy, const RISCVTargetLowering &TLI,
-              std::optional<unsigned> FirstMaskArgument);
+              RVVArgDispatcher &RVVDispatcher);
 
 bool CC_RISCV_FastCC(const DataLayout &DL, RISCVABI::ABI ABI, unsigned ValNo,
                      MVT ValVT, MVT LocVT, CCValAssign::LocInfo LocInfo,
                      ISD::ArgFlagsTy ArgFlags, CCState &State, bool IsFixed,
                      bool IsRet, Type *OrigTy, const RISCVTargetLowering &TLI,
-                     std::optional<unsigned> FirstMaskArgument);
+                     RVVArgDispatcher &RVVDispatcher);
 
 bool CC_RISCV_GHC(unsigned ValNo, MVT ValVT, MVT LocVT,
                   CCValAssign::LocInfo LocInfo, ISD::ArgFlagsTy ArgFlags,
diff --git a/llvm/test/CodeGen/RISCV/rvv/calling-conv.ll b/llvm/test/CodeGen/RISCV/rvv/calling-conv.ll
index 78385a80b47eb..005bc5b60ca80 100644
--- a/llvm/test/CodeGen/RISCV/rvv/calling-conv.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/calling-conv.ll
@@ -86,3 +86,90 @@ define <vscale x 32 x i32> @caller_scalable_vector_split_indirect(<vscale x 32 x
   %a = call <vscale x 32 x i32> @callee_scalable_vector_split_indirect(<vscale x 32 x i32> zeroinitializer, <vscale x 32 x i32> %x)
   ret <vscale x 32 x i32> %a
 }
+
+; %0 -> v8
+; %1 -> v9
+define <vscale x 1 x i64> @case1(<vscale x 1 x i64> %0, <vscale x 1 x i64> %1) {
+; CHECK-LABEL: case1:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli a0, zero, e64, m1, ta, ma
+; CHECK-NEXT:    vadd.vv v8, v8, v9
+; CHECK-NEXT:    ret
+  %a = add <vscale x 1 x i64> %0, %1
+  ret <vscale x 1 x i64> %a
+}
+
+; %0 -> v8
+; %1 -> v10-v11
+; %2 -> v9
+define <vscale x 1 x i64> @case2_1(<vscale x 1 x i64> %0, <vscale x 2 x i64> %1, <vscale x 1 x i64> %2) {
+; CHECK-LABEL: case2_1:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli a0, zero, e64, m1, ta, ma
+; CHECK-NEXT:    vadd.vv v8, v8, v9
+; CHECK-NEXT:    ret
+  %a = add <vscale x 1 x i64> %0, %2
+  ret <vscale x 1 x i64> %a
+}
+define <vscale x 2 x i64> @case2_2(<vscale x 1 x i64> %0, <vscale x 2 x i64> %1, <vscale x 1 x i64> %2) {
+; CHECK-LABEL: case2_2:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli a0, zero, e64, m2, ta, ma
+; CHECK-NEXT:    vadd.vv v8, v10, v10
+; CHECK-NEXT:    ret
+  %a = add <vscale x 2 x i64> %1, %1
+  ret <vscale x 2 x i64> %a
+}
+
+; %0 -> v8
+; %1 -> {v10-v11, v12-v13}
+; %2 -> v9
+define <vscale x 1 x i64> @case3_1(<vscale x 1 x i64> %0, {<vscale x 2 x i64>, <vscale x 2 x i64>} %1, <vscale x 1 x i64> %2) {
+; CHECK-LABEL: case3_1:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli a0, zero, e64, m1, ta, ma
+; CHECK-NEXT:    vadd.vv v8, v8, v9
+; CHECK-NEXT:    ret
+  %add = add <vscale x 1 x i64> %0, %2
+  ret <vscale x 1 x i64> %add
+}
+define <vscale x 2 x i64> @case3_2(<vscale x 1 x i64> %0, {<vscale x 2 x i64>, <vscale x 2 x i64>} %1, <vscale x 1 x i64> %2) {
+; CHECK-LABEL: case3_2:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli a0, zero, e64, m2, ta, ma
+; CHECK-NEXT:    vadd.vv v8, v10, v12
+; CHECK-NEXT:    ret
+  %a = extractvalue { <vscale x 2 x i64>, <vscale x 2 x i64> } %1, 0
+  %b = extractvalue { <vscale x 2 x i64>, <vscale x 2 x i64> } %1, 1
+  %add = add <vscale x 2 x i64> %a, %b
+  ret <vscale x 2 x i64> %add
+}
+
+; %0 -> v8
+; %1 -> {by-ref, by-ref}
+; %2 -> v9
+define <vscale x 8 x i64> @case4_1(<vscale x 1 x i64> %0, {<vscale x 8 x i64>, <vscale x 8 x i64>} %1, <vscale x 1 x i64> %2) {
+; CHECK-LABEL: case4_1:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    csrr a1, vlenb
+; CHECK-NEXT:    slli a1, a1, 3
+; CHECK-NEXT:    add a1, a0, a1
+; CHECK-NEXT:    vl8re64.v v8, (a1)
+; CHECK-NEXT:    vl8re64.v v16, (a0)
+; CHECK-NEXT:    vsetvli a0, zero, e64, m8, ta, ma
+; CHECK-NEXT:    vadd.vv v8, v16, v8
+; CHECK-NEXT:    ret
+  %a = extractvalue { <vscale x 8 x i64>, <vscale x 8 x i64> } %1, 0
+  %b = extractvalue { <vscale x 8 x i64>, <vscale x 8 x i64> } %1, 1
+  %add = add <vscale x 8 x i64> %a, %b
+  ret <vscale x 8 x i64> %add
+}
+define <vscale x 1 x i64> @case4_2(<vscale x 1 x i64> %0, {<vscale x 8 x i64>, <vscale x 8 x i64>} %1, <vscale x 1 x i64> %2) {
+; CHECK-LABEL: case4_2:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli a0, zero, e64, m1, ta, ma
+; CHECK-NEXT:    vadd.vv v8, v8, v9
+; CHECK-NEXT:    ret
+  %add = add <vscale x 1 x i64> %0, %2
+  ret <vscale x 1 x i64> %add
+}
diff --git a/llvm/test/CodeGen/RISCV/rvv/vector-deinterleave-load.ll b/llvm/test/CodeGen/RISCV/rvv/vector-deinterleave-load.ll
index a320aecc6fce4..6a712080fda74 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vector-deinterleave-load.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vector-deinterleave-load.ll
@@ -18,10 +18,10 @@ define {<vscale x 16 x i1>, <vscale x 16 x i1>} @vector_deinterleave_load_nxv16i
 ; CHECK-NEXT:    vmerge.vim v14, v10, 1, v0
 ; CHECK-NEXT:    vmv1r.v v0, v8
 ; CHECK-NEXT:    vmerge.vim v12, v10, 1, v0
-; CHECK-NEXT:    vnsrl.wi v8, v12, 0
-; CHECK-NEXT:    vmsne.vi v0, v8, 0
-; CHECK-NEXT:    vnsrl.wi v10, v12, 8
+; CHECK-NEXT:    vnsrl.wi v10, v12, 0
 ; CHECK-NEXT:    vmsne.vi v8, v10, 0
+; CHECK-NEXT:    vnsrl.wi v10, v12, 8
+; CHECK-NEXT:    vmsne.vi v9, v10, 0
 ; CHECK-NEXT:    ret
   %vec = load <vscale x 32 x i1>, ptr %p
   %retval = call {<vscale x 16 x i1>, <vscale x 16 x i1>} @llvm.experimental.vector.deinterleave2.nxv32i1(<vscale x 32 x i1> %vec)
diff --git a/llvm/test/CodeGen/RISCV/rvv/vector-deinterleave.ll b/llvm/test/CodeGen/RISCV/rvv/vector-deinterleave.ll
index ef4baf34d23f0..d5a818e0e3820 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vector-deinterleave.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vector-deinterleave.ll
@@ -8,18 +8,18 @@ define {<vscale x 16 x i1>, <vscale x 16 x i1>} @vector_deinterleave_nxv16i1_nxv
 ; CHECK-LABEL: vector_deinterleave_nxv16i1_nxv32i1:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    vsetvli a0, zero, e8, m2, ta, ma
-; CHECK-NEXT:    vmv.v.i v10, 0
-; CHECK-NEXT:    vmerge.vim v8, v10, 1, v0
+; CHECK-NEXT:    vmv.v.i v8, 0
+; CHECK-NEXT:    vmerge.vim v12, v8, 1, v0
 ; CHECK-NEXT:    csrr a0, vlenb
 ; CHECK-NEXT:    srli a0, a0, 2
 ; CHECK-NEXT:    vsetvli a1, zero, e8, mf2, ta, ma
 ; CHECK-NEXT:    vslidedown.vx v0, v0, a0
 ; CHECK-NEXT:    vsetvli a0, zero, e8, m2, ta, ma
-; CHECK-NEXT:    vmerge.vim v10, v10, 1, v0
-; CHECK-NEXT:    vnsrl.wi v12, v8, 0
-; CHECK-NEXT:    vmsne.vi v0, v12, 0
-; CHECK-NEXT:    vnsrl.wi v12, v8, 8
-; CHECK-NEXT:    vmsne.vi v8, v12, 0
+; CHECK-NEXT:    vmerge.vim v14, v8, 1, v0
+; CHECK-NEXT:    vnsrl.wi v10, v12, 0
+; CHECK-NEXT:    vmsne.vi v8, v10, 0
+; CHECK-NEXT:    vnsrl.wi v10, v12, 8
+; CHECK-NEXT:    vmsne.vi v9, v10, 0
 ; CHECK-NEXT:    ret
 %retval = call {<vscale x 16 x i1>, <vscale x 16 x i1>} @llvm.experimental.vector.deinterleave2.nxv32i1(<vscale x 32 x i1> %vec)
 ret {<vscale x 16 x i1>, <vscale x 16 x i1>} %retval
@@ -107,7 +107,8 @@ define {<vscale x 64 x i1>, <vscale x 64 x i1>} @vector_deinterleave_nxv64i1_nxv
 ; CHECK-NEXT:    vnsrl.wi v24, v16, 8
 ; CHECK-NEXT:    vnsrl.wi v28, v8, 8
 ; CHECK-NEXT:    vsetvli a0, zero, e8, m8, ta, ma
-; CHECK-NEXT:    vmsne.vi v8, v24, 0
+; CHECK-NEXT:    vmsne.vi v9, v24, 0
+; CHECK-NEXT:    vmv1r.v v8, v0
 ; CHECK-NEXT:    ret
 %retval = call {<vscale x 64 x i1>, <vscale x 64 x i1>} @llvm.experimental.vector.deinterleave2.nxv128i1(<vscale x 128 x i1> %vec)
 ret {<vscale x 64 x i1>, <vscale x 64 x i1>} %retval

>From c06a2ed69cd9eab9314c1d10db35ac10f49f0216 Mon Sep 17 00:00:00 2001
From: Brandon Wu <brandon.wu at sifive.com>
Date: Mon, 22 Jan 2024 23:39:13 -0800
Subject: [PATCH 2/4] fixup! [RISCV] RISCV vector calling convention (2/2)

---
 .../Target/RISCV/GISel/RISCVCallLowering.cpp  | 18 ++++++----
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp   | 36 ++++++++++---------
 llvm/lib/Target/RISCV/RISCVISelLowering.h     | 20 ++++++-----
 3 files changed, 43 insertions(+), 31 deletions(-)

diff --git a/llvm/lib/Target/RISCV/GISel/RISCVCallLowering.cpp b/llvm/lib/Target/RISCV/GISel/RISCVCallLowering.cpp
index ff87601046dc2..63ee73a05f74a 100644
--- a/llvm/lib/Target/RISCV/GISel/RISCVCallLowering.cpp
+++ b/llvm/lib/Target/RISCV/GISel/RISCVCallLowering.cpp
@@ -34,7 +34,7 @@ struct RISCVOutgoingValueAssigner : public CallLowering::OutgoingValueAssigner {
   // Whether this is assigning args for a return.
   bool IsRet;
 
-  RVVArgDispatcher RVVDispatcher;
+  RVVArgDispatcher &RVVDispatcher;
 
 public:
   RISCVOutgoingValueAssigner(
@@ -408,7 +408,8 @@ bool RISCVCallLowering::lowerReturnVal(MachineIRBuilder &MIRBuilder,
   SmallVector<ArgInfo, 4> SplitRetInfos;
   splitToValueTypes(OrigRetInfo, SplitRetInfos, DL, CC);
 
-  RVVArgDispatcher Dispatcher{&MF, getTLI<RISCVTargetLowering>(), true};
+  RVVArgDispatcher Dispatcher{&MF, getTLI<RISCVTargetLowering>(),
+                              F.getReturnType()};
   RISCVOutgoingValueAssigner Assigner(
       CC == CallingConv::Fast ? RISCV::CC_RISCV_FastCC : RISCV::CC_RISCV,
       /*IsRet=*/true, Dispatcher);
@@ -520,6 +521,7 @@ bool RISCVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
   CallingConv::ID CC = F.getCallingConv();
 
   SmallVector<ArgInfo, 32> SplitArgInfos;
+  std::vector<Type *> TypeList;
   unsigned Index = 0;
   for (auto &Arg : F.args()) {
     // Construct the ArgInfo object from destination register and argument type.
@@ -531,10 +533,12 @@ bool RISCVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
     // correspondingly and appended to SplitArgInfos.
     splitToValueTypes(AInfo, SplitArgInfos, DL, CC);
 
+    TypeList.push_back(Arg.getType());
+
     ++Index;
   }
 
-  RVVArgDispatcher Dispatcher{&MF, getTLI<RISCVTargetLowering>(), false};
+  RVVArgDispatcher Dispatcher{&MF, getTLI<RISCVTargetLowering>(), TypeList};
   RISCVIncomingValueAssigner Assigner(
       CC == CallingConv::Fast ? RISCV::CC_RISCV_FastCC : RISCV::CC_RISCV,
       /*IsRet=*/false, Dispatcher);
@@ -575,11 +579,13 @@ bool RISCVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
 
   SmallVector<ArgInfo, 32> SplitArgInfos;
   SmallVector<ISD::OutputArg, 8> Outs;
+  std::vector<Type *> TypeList;
   for (auto &AInfo : Info.OrigArgs) {
     // Handle any required unmerging of split value types from a given VReg into
     // physical registers. ArgInfo objects are constructed correspondingly and
     // appended to SplitArgInfos.
     splitToValueTypes(AInfo, SplitArgInfos, DL, CC);
+    TypeList.push_back(AInfo.Ty);
   }
 
   // TODO: Support tail calls.
@@ -597,8 +603,7 @@ bool RISCVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
   const TargetRegisterInfo *TRI = Subtarget.getRegisterInfo();
   Call.addRegMask(TRI->getCallPreservedMask(MF, Info.CallConv));
 
-  RVVArgDispatcher ArgDispatcher{&MF, getTLI<RISCVTargetLowering>(), false,
-                                 nullptr, &Info};
+  RVVArgDispatcher ArgDispatcher{&MF, getTLI<RISCVTargetLowering>(), TypeList};
   RISCVOutgoingValueAssigner ArgAssigner(
       CC == CallingConv::Fast ? RISCV::CC_RISCV_FastCC : RISCV::CC_RISCV,
       /*IsRet=*/false, ArgDispatcher);
@@ -629,7 +634,8 @@ bool RISCVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
   SmallVector<ArgInfo, 4> SplitRetInfos;
   splitToValueTypes(Info.OrigRet, SplitRetInfos, DL, CC);
 
-  RVVArgDispatcher RetDispatcher{&MF, getTLI<RISCVTargetLowering>(), true};
+  RVVArgDispatcher RetDispatcher{&MF, getTLI<RISCVTargetLowering>(),
+                                 F.getReturnType()};
   RISCVIncomingValueAssigner RetAssigner(
       CC == CallingConv::Fast ? RISCV::CC_RISCV_FastCC : RISCV::CC_RISCV,
       /*IsRet=*/true, RetDispatcher);
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 5490bab930fec..2eb643d7a71ad 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -17788,7 +17788,13 @@ void RISCVTargetLowering::analyzeInputArgs(
   unsigned NumArgs = Ins.size();
   FunctionType *FType = MF.getFunction().getFunctionType();
 
-  RVVArgDispatcher Dispatcher{&MF, this, IsRet};
+  std::vector<Type *> TypeList;
+  if (IsRet)
+    TypeList.push_back(MF.getFunction().getReturnType());
+  else
+    for (const Argument &Arg : MF.getFunction().args())
+      TypeList.push_back(Arg.getType());
+  RVVArgDispatcher Dispatcher{&MF, this, TypeList};
 
   for (unsigned i = 0; i != NumArgs; ++i) {
     MVT ArgVT = Ins[i].VT;
@@ -17817,7 +17823,13 @@ void RISCVTargetLowering::analyzeOutputArgs(
     CallLoweringInfo *CLI, RISCVCCAssignFn Fn) const {
   unsigned NumArgs = Outs.size();
 
-  RVVArgDispatcher Dispatcher{&MF, this, IsRet, CLI};
+  std::vector<Type *> TypeList;
+  if (IsRet)
+    TypeList.push_back(MF.getFunction().getReturnType());
+  else if (CLI)
+    for (const TargetLowering::ArgListEntry &Arg : CLI->getArgs())
+      TypeList.push_back(Arg.Ty);
+  RVVArgDispatcher Dispatcher{&MF, this, TypeList};
 
   for (unsigned i = 0; i != NumArgs; i++) {
     MVT ArgVT = Outs[i].VT;
@@ -18720,7 +18732,8 @@ bool RISCVTargetLowering::CanLowerReturn(
   SmallVector<CCValAssign, 16> RVLocs;
   CCState CCInfo(CallConv, IsVarArg, MF, RVLocs, Context);
 
-  RVVArgDispatcher Dispatcher{&MF, this, true};
+  std::vector<Type *> TypeList = {MF.getFunction().getReturnType()};
+  RVVArgDispatcher Dispatcher{&MF, this, TypeList};
 
   for (unsigned i = 0, e = Outs.size(); i != e; ++i) {
     MVT VT = Outs[i].VT;
@@ -20558,20 +20571,9 @@ void RVVArgDispatcher::constructHelper(Type *Ty) {
   }
 }
 
-void RVVArgDispatcher::construct(bool IsRet) {
-  const Function &F = MF->getFunction();
-
-  if (IsRet)
-    constructHelper(F.getReturnType());
-  else if (CLI)
-    for (const TargetLowering::ArgListEntry &Arg : CLI->getArgs())
-      constructHelper(Arg.Ty);
-  else if (GISelCLI)
-    for (const CallLowering::ArgInfo &Arg : GISelCLI->OrigArgs)
-      constructHelper(Arg.Ty);
-  else
-    for (const Argument &Arg : F.args())
-      constructHelper(Arg.getType());
+void RVVArgDispatcher::construct(std::vector<Type *> &TypeList) {
+  for (Type *Ty : TypeList)
+    constructHelper(Ty);
 
   for (auto &Info : RVVArgInfos)
     if (Info.NF == 1 && Info.VT.getVectorElementType() == MVT::i1) {
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index 3224faf3b34a9..1b6ebab86fd76 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -1020,14 +1020,18 @@ class RVVArgDispatcher {
     bool FirstVMask = false;
   };
 
-  RVVArgDispatcher() = default;
   RVVArgDispatcher(const MachineFunction *MF, const RISCVTargetLowering *TLI,
-                   bool IsRet, TargetLowering::CallLoweringInfo *CLI = nullptr,
-                   CallLowering::CallLoweringInfo *GISelCLI = nullptr)
-      : MF(MF), TLI(TLI), CLI(CLI), GISelCLI(GISelCLI) {
-    assert((!CLI || !GISelCLI) &&
-           "ISel and Global ISel can't co-exist at the same time");
-    construct(IsRet);
+                   std::vector<Type *> &TypeList)
+      : MF(MF), TLI(TLI) {
+    construct(TypeList);
+    compute();
+  }
+
+  RVVArgDispatcher(const MachineFunction *MF, const RISCVTargetLowering *TLI,
+                   Type *Ty)
+      : MF(MF), TLI(TLI) {
+    std::vector<Type *> TypeList = {Ty};
+    construct(TypeList);
     compute();
   }
 
@@ -1044,7 +1048,7 @@ class RVVArgDispatcher {
 
   unsigned CurIdx = 0;
 
-  void construct(bool IsRet);
+  void construct(std::vector<Type *> &TypeList);
   void constructHelper(Type *Ty);
   void compute();
   void allocatePhysReg(unsigned NF = 1, unsigned LMul = 1,

>From f8cc21956fcaf12a438c5174e4e5906c6d5d2ab4 Mon Sep 17 00:00:00 2001
From: Brandon Wu <brandon.wu at sifive.com>
Date: Tue, 23 Jan 2024 19:34:39 -0800
Subject: [PATCH 3/4] fixup! [RISCV] RISCV vector calling convention (2/2)

---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 2eb643d7a71ad..ba3d4baec1dbe 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -20622,9 +20622,10 @@ void RVVArgDispatcher::compute() {
       return;
     }
 
-    unsigned RegsNeeded = std::max(
-        ArgInfo.VT.getSizeInBits().getKnownMinValue() / RISCV::RVVBitsPerBlock,
-        1UL);
+    unsigned RegsNeeded =
+        std::max((unsigned)ArgInfo.VT.getSizeInBits().getKnownMinValue() /
+                     RISCV::RVVBitsPerBlock,
+                 (unsigned)1);
     unsigned TotalRegsNeeded = ArgInfo.NF * RegsNeeded;
     for (unsigned StartReg = 0; StartReg + TotalRegsNeeded <= NumArgVRs;
          StartReg += RegsNeeded) {

>From fd763f8d201a15a890812172f540d59c558d6209 Mon Sep 17 00:00:00 2001
From: Brandon Wu <brandon.wu at sifive.com>
Date: Wed, 24 Jan 2024 23:46:28 -0800
Subject: [PATCH 4/4] fixup! [RISCV] RISCV vector calling convention (2/2)

---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 34 +++++++++------------
 llvm/lib/Target/RISCV/RISCVISelLowering.h   | 21 ++++++++++---
 2 files changed, 32 insertions(+), 23 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index ba3d4baec1dbe..de9bd4beb6669 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -20531,7 +20531,7 @@ unsigned RISCVTargetLowering::getMinimumJumpTableEntries() const {
   return Subtarget.getMinimumJumpTableEntries();
 }
 
-void RVVArgDispatcher::constructHelper(Type *Ty) {
+void RVVArgDispatcher::constructArgInfos(Type *Ty) {
   const DataLayout &DL = MF->getDataLayout();
   const Function &F = MF->getFunction();
   LLVMContext &Context = F.getContext();
@@ -20564,16 +20564,14 @@ void RVVArgDispatcher::constructHelper(Type *Ty) {
         RegisterVT = TLI->getContainerForFixedLengthVector(RegisterVT);
 
       RVVArgInfo Info{1, RegisterVT, false};
-
-      while (NumRegs--)
-        RVVArgInfos.push_back(Info);
+      RVVArgInfos.insert(RVVArgInfos.end(), NumRegs, Info);
     }
   }
 }
 
-void RVVArgDispatcher::construct(std::vector<Type *> &TypeList) {
+void RVVArgDispatcher::construct(const std::vector<Type *> &TypeList) {
   for (Type *Ty : TypeList)
-    constructHelper(Ty);
+    constructArgInfos(Ty);
 
   for (auto &Info : RVVArgInfos)
     if (Info.NF == 1 && Info.VT.getVectorElementType() == MVT::i1) {
@@ -20608,28 +20606,27 @@ void RVVArgDispatcher::allocatePhysReg(unsigned NF, unsigned LMul,
     if (StartReg)
       AllocatedPhysRegs.push_back(VRArrays[(StartReg - 8) / LMul + i]);
     else
-      AllocatedPhysRegs.push_back(0);
+      AllocatedPhysRegs.push_back(MCPhysReg());
 }
 
-// This function determines if each RVV argument is passed by register.
+/// This function determines if each RVV argument is passed by register. If the
+/// argument can be assigned to a VR, then give it a specific register.
+/// Otherwise, assign the argument to 0, which is an invalid MCPhysReg.
 void RVVArgDispatcher::compute() {
-  unsigned ToBeAssigned = RVVArgInfos.size();
-  uint64_t AssignedMap = 0;
-  auto tryAllocate = [&](const RVVArgInfo &ArgInfo) {
+  uint32_t AssignedMap = 0;
+  auto allocate = [&](const RVVArgInfo &ArgInfo) {
     // Allocate first vector mask argument to V0.
     if (ArgInfo.FirstVMask) {
       AllocatedPhysRegs.push_back(RISCV::V0);
       return;
     }
 
-    unsigned RegsNeeded =
-        std::max((unsigned)ArgInfo.VT.getSizeInBits().getKnownMinValue() /
-                     RISCV::RVVBitsPerBlock,
-                 (unsigned)1);
+    unsigned RegsNeeded = divideCeil(
+        ArgInfo.VT.getSizeInBits().getKnownMinValue(), RISCV::RVVBitsPerBlock);
     unsigned TotalRegsNeeded = ArgInfo.NF * RegsNeeded;
     for (unsigned StartReg = 0; StartReg + TotalRegsNeeded <= NumArgVRs;
          StartReg += RegsNeeded) {
-      unsigned Map = ((1 << TotalRegsNeeded) - 1) << StartReg;
+      uint32_t Map = ((1 << TotalRegsNeeded) - 1) << StartReg;
       if ((AssignedMap & Map) == 0) {
         allocatePhysReg(ArgInfo.NF, RegsNeeded, StartReg + 8);
         AssignedMap |= Map;
@@ -20638,11 +20635,10 @@ void RVVArgDispatcher::compute() {
     }
 
     allocatePhysReg(ArgInfo.NF, RegsNeeded, 0);
-    return;
   };
 
-  for (unsigned i = 0; i < ToBeAssigned; ++i)
-    tryAllocate(RVVArgInfos[i]);
+  for (unsigned i = 0; i < RVVArgInfos.size(); ++i)
+    allocate(RVVArgInfos[i]);
 }
 
 MCPhysReg RVVArgDispatcher::getNextPhysReg() {
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index 1b6ebab86fd76..7200483f0c936 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -1010,6 +1010,21 @@ class RISCVTargetLowering : public TargetLowering {
   unsigned getMinimumJumpTableEntries() const override;
 };
 
+/// As per the spec, the rules for passing vector arguments are as follows:
+///
+/// 1. For the first vector mask argument, use v0 to pass it.
+/// 2. For vector data arguments or other vector mask arguments, starting from
+/// the v8 register, if a vector register group between v8-v23 that has not been
+/// allocated can be found and the first register number is a multiple of LMUL,
+/// then allocate this vector register group to the argument and mark these
+/// registers as allocated. Otherwise, it is passed by reference and replaced in
+/// the argument list with its address.
+/// 3. For tuple vector data arguments, starting from the v8 register, if
+/// NFIELDS consecutive vector register groups between v8-v23 that have not been
+/// allocated can be found and the first register number is a multiple of LMUL,
+/// then allocate these vector register groups to the argument and mark these
+/// registers as allocated. Otherwise, it is passed by reference and replaced in
+/// the argument list with its address.
 class RVVArgDispatcher {
 public:
   static constexpr unsigned NumArgVRs = 16;
@@ -1043,13 +1058,11 @@ class RVVArgDispatcher {
 
   const MachineFunction *MF = nullptr;
   const RISCVTargetLowering *TLI = nullptr;
-  TargetLowering::CallLoweringInfo *CLI = nullptr;
-  CallLowering::CallLoweringInfo *GISelCLI = nullptr;
 
   unsigned CurIdx = 0;
 
-  void construct(std::vector<Type *> &TypeList);
-  void constructHelper(Type *Ty);
+  void construct(const std::vector<Type *> &TypeList);
+  void constructArgInfos(Type *Ty);
   void compute();
   void allocatePhysReg(unsigned NF = 1, unsigned LMul = 1,
                        unsigned StartReg = 0);



More information about the llvm-commits mailing list