[llvm] [RISCV] RISCV vector calling convention (2/2) (PR #79096)

Brandon Wu via llvm-commits llvm-commits at lists.llvm.org
Sun Feb 25 22:34:05 PST 2024


https://github.com/4vtomat updated https://github.com/llvm/llvm-project/pull/79096

>From b5b8465e925b5bcd883571ffce34cdbbc79c6cfa Mon Sep 17 00:00:00 2001
From: Brandon Wu <brandon.wu at sifive.com>
Date: Mon, 22 Jan 2024 21:35:46 -0800
Subject: [PATCH 1/5] [RISCV] RISCV vector calling convention (2/2)

This commit handles vector arguments/returns for function definitions and
calls. The new class RVVArgDispatcher performs all vector register
assignment, covering mask types, data types, and tuple types. It precomputes
the register number for each argument as per
https://github.com/riscv-non-isa/riscv-elf-psabi-doc/blob/master/riscv-cc.adoc#standard-vector-calling-convention-variant
and is passed to the calling convention functions to handle all vector
arguments.

Depends on: #78550
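
For context, the precomputation is a first-fit scan over the 16 argument
registers v8-v23: each RVV argument needs NF consecutive register groups of
LMUL registers, the starting register number must be a multiple of LMUL, the
first vector mask argument goes to v0, and anything that does not fit is
marked for indirect passing. The following is a minimal standalone model of
that scan for illustration only; ArgInfo, computeAssignment and the printed
names are simplified stand-ins, not the types added by this patch.

#include <cstdint>
#include <iostream>
#include <vector>

// Simplified stand-in for the patch's RVVArgInfo: NF register groups of
// LMul registers each; FirstVMask marks the first vector mask argument.
struct ArgInfo {
  unsigned NF;
  unsigned LMul;
  bool FirstVMask;
};

// First-fit scan over the 16 argument registers v8-v23, one bit per register.
// Returns, per argument, the first register number (8-23), 0 for "pass by
// reference", or -1 for the v0 mask register.
std::vector<int> computeAssignment(const std::vector<ArgInfo> &Args) {
  constexpr unsigned NumArgVRs = 16;
  uint32_t AssignedMap = 0;
  std::vector<int> StartRegs;

  for (const ArgInfo &A : Args) {
    if (A.FirstVMask) { // the first mask argument is passed in v0
      StartRegs.push_back(-1);
      continue;
    }
    unsigned Total = A.NF * A.LMul;
    int Assigned = 0; // 0 means the argument is passed indirectly
    // Candidate starting positions must be multiples of LMUL.
    for (unsigned Start = 0; Start + Total <= NumArgVRs; Start += A.LMul) {
      uint32_t Map = ((1u << Total) - 1) << Start;
      if ((AssignedMap & Map) == 0) {
        AssignedMap |= Map;
        Assigned = 8 + Start; // argument registers begin at v8
        break;
      }
    }
    StartRegs.push_back(Assigned);
  }
  return StartRegs;
}

int main() {
  // Mirrors case2_1 from the added test: nxv1i64, nxv2i64, nxv1i64 are
  // expected to land in v8, v10-v11 and v9 respectively.
  std::vector<ArgInfo> Args = {{1, 1, false}, {1, 2, false}, {1, 1, false}};
  for (int R : computeAssignment(Args)) {
    if (R == -1)
      std::cout << "v0\n";
    else if (R == 0)
      std::cout << "pass by reference\n";
    else
      std::cout << "v" << R << "\n";
  }
}
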
---
 .../Target/RISCV/GISel/RISCVCallLowering.cpp  |  49 ++---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp   | 186 ++++++++++++++----
 llvm/lib/Target/RISCV/RISCVISelLowering.h     |  49 ++++-
 llvm/test/CodeGen/RISCV/rvv/calling-conv.ll   |  87 ++++++++
 .../RISCV/rvv/vector-deinterleave-load.ll     |   6 +-
 .../CodeGen/RISCV/rvv/vector-deinterleave.ll  |  17 +-
 6 files changed, 311 insertions(+), 83 deletions(-)

diff --git a/llvm/lib/Target/RISCV/GISel/RISCVCallLowering.cpp b/llvm/lib/Target/RISCV/GISel/RISCVCallLowering.cpp
index 45e19cdea300b1..ff87601046dc20 100644
--- a/llvm/lib/Target/RISCV/GISel/RISCVCallLowering.cpp
+++ b/llvm/lib/Target/RISCV/GISel/RISCVCallLowering.cpp
@@ -34,14 +34,15 @@ struct RISCVOutgoingValueAssigner : public CallLowering::OutgoingValueAssigner {
   // Whether this is assigning args for a return.
   bool IsRet;
 
-  // true if assignArg has been called for a mask argument, false otherwise.
-  bool AssignedFirstMaskArg = false;
+  RVVArgDispatcher RVVDispatcher;
 
 public:
   RISCVOutgoingValueAssigner(
-      RISCVTargetLowering::RISCVCCAssignFn *RISCVAssignFn_, bool IsRet)
+      RISCVTargetLowering::RISCVCCAssignFn *RISCVAssignFn_, bool IsRet,
+      RVVArgDispatcher &RVVDispatcher)
       : CallLowering::OutgoingValueAssigner(nullptr),
-        RISCVAssignFn(RISCVAssignFn_), IsRet(IsRet) {}
+        RISCVAssignFn(RISCVAssignFn_), IsRet(IsRet),
+        RVVDispatcher(RVVDispatcher) {}
 
   bool assignArg(unsigned ValNo, EVT OrigVT, MVT ValVT, MVT LocVT,
                  CCValAssign::LocInfo LocInfo,
@@ -51,16 +52,9 @@ struct RISCVOutgoingValueAssigner : public CallLowering::OutgoingValueAssigner {
     const DataLayout &DL = MF.getDataLayout();
     const RISCVSubtarget &Subtarget = MF.getSubtarget<RISCVSubtarget>();
 
-    std::optional<unsigned> FirstMaskArgument;
-    if (Subtarget.hasVInstructions() && !AssignedFirstMaskArg &&
-        ValVT.isVector() && ValVT.getVectorElementType() == MVT::i1) {
-      FirstMaskArgument = ValNo;
-      AssignedFirstMaskArg = true;
-    }
-
     if (RISCVAssignFn(DL, Subtarget.getTargetABI(), ValNo, ValVT, LocVT,
                       LocInfo, Flags, State, Info.IsFixed, IsRet, Info.Ty,
-                      *Subtarget.getTargetLowering(), FirstMaskArgument))
+                      *Subtarget.getTargetLowering(), RVVDispatcher))
       return true;
 
     StackSize = State.getStackSize();
@@ -181,14 +175,15 @@ struct RISCVIncomingValueAssigner : public CallLowering::IncomingValueAssigner {
   // Whether this is assigning args from a return.
   bool IsRet;
 
-  // true if assignArg has been called for a mask argument, false otherwise.
-  bool AssignedFirstMaskArg = false;
+  RVVArgDispatcher &RVVDispatcher;
 
 public:
   RISCVIncomingValueAssigner(
-      RISCVTargetLowering::RISCVCCAssignFn *RISCVAssignFn_, bool IsRet)
+      RISCVTargetLowering::RISCVCCAssignFn *RISCVAssignFn_, bool IsRet,
+      RVVArgDispatcher &RVVDispatcher)
       : CallLowering::IncomingValueAssigner(nullptr),
-        RISCVAssignFn(RISCVAssignFn_), IsRet(IsRet) {}
+        RISCVAssignFn(RISCVAssignFn_), IsRet(IsRet),
+        RVVDispatcher(RVVDispatcher) {}
 
   bool assignArg(unsigned ValNo, EVT OrigVT, MVT ValVT, MVT LocVT,
                  CCValAssign::LocInfo LocInfo,
@@ -201,16 +196,9 @@ struct RISCVIncomingValueAssigner : public CallLowering::IncomingValueAssigner {
     if (LocVT.isScalableVector())
       MF.getInfo<RISCVMachineFunctionInfo>()->setIsVectorCall();
 
-    std::optional<unsigned> FirstMaskArgument;
-    if (Subtarget.hasVInstructions() && !AssignedFirstMaskArg &&
-        ValVT.isVector() && ValVT.getVectorElementType() == MVT::i1) {
-      FirstMaskArgument = ValNo;
-      AssignedFirstMaskArg = true;
-    }
-
     if (RISCVAssignFn(DL, Subtarget.getTargetABI(), ValNo, ValVT, LocVT,
                       LocInfo, Flags, State, /*IsFixed=*/true, IsRet, Info.Ty,
-                      *Subtarget.getTargetLowering(), FirstMaskArgument))
+                      *Subtarget.getTargetLowering(), RVVDispatcher))
       return true;
 
     StackSize = State.getStackSize();
@@ -420,9 +408,10 @@ bool RISCVCallLowering::lowerReturnVal(MachineIRBuilder &MIRBuilder,
   SmallVector<ArgInfo, 4> SplitRetInfos;
   splitToValueTypes(OrigRetInfo, SplitRetInfos, DL, CC);
 
+  RVVArgDispatcher Dispatcher{&MF, getTLI<RISCVTargetLowering>(), true};
   RISCVOutgoingValueAssigner Assigner(
       CC == CallingConv::Fast ? RISCV::CC_RISCV_FastCC : RISCV::CC_RISCV,
-      /*IsRet=*/true);
+      /*IsRet=*/true, Dispatcher);
   RISCVOutgoingValueHandler Handler(MIRBuilder, MF.getRegInfo(), Ret);
   return determineAndHandleAssignments(Handler, Assigner, SplitRetInfos,
                                        MIRBuilder, CC, F.isVarArg());
@@ -545,9 +534,10 @@ bool RISCVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
     ++Index;
   }
 
+  RVVArgDispatcher Dispatcher{&MF, getTLI<RISCVTargetLowering>(), false};
   RISCVIncomingValueAssigner Assigner(
       CC == CallingConv::Fast ? RISCV::CC_RISCV_FastCC : RISCV::CC_RISCV,
-      /*IsRet=*/false);
+      /*IsRet=*/false, Dispatcher);
   RISCVFormalArgHandler Handler(MIRBuilder, MF.getRegInfo());
 
   SmallVector<CCValAssign, 16> ArgLocs;
@@ -607,9 +597,11 @@ bool RISCVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
   const TargetRegisterInfo *TRI = Subtarget.getRegisterInfo();
   Call.addRegMask(TRI->getCallPreservedMask(MF, Info.CallConv));
 
+  RVVArgDispatcher ArgDispatcher{&MF, getTLI<RISCVTargetLowering>(), false,
+                                 nullptr, &Info};
   RISCVOutgoingValueAssigner ArgAssigner(
       CC == CallingConv::Fast ? RISCV::CC_RISCV_FastCC : RISCV::CC_RISCV,
-      /*IsRet=*/false);
+      /*IsRet=*/false, ArgDispatcher);
   RISCVOutgoingValueHandler ArgHandler(MIRBuilder, MF.getRegInfo(), Call);
   if (!determineAndHandleAssignments(ArgHandler, ArgAssigner, SplitArgInfos,
                                      MIRBuilder, CC, Info.IsVarArg))
@@ -637,9 +629,10 @@ bool RISCVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
   SmallVector<ArgInfo, 4> SplitRetInfos;
   splitToValueTypes(Info.OrigRet, SplitRetInfos, DL, CC);
 
+  RVVArgDispatcher RetDispatcher{&MF, getTLI<RISCVTargetLowering>(), true};
   RISCVIncomingValueAssigner RetAssigner(
       CC == CallingConv::Fast ? RISCV::CC_RISCV_FastCC : RISCV::CC_RISCV,
-      /*IsRet=*/true);
+      /*IsRet=*/true, RetDispatcher);
   RISCVCallReturnHandler RetHandler(MIRBuilder, MF.getRegInfo(), Call);
   if (!determineAndHandleAssignments(RetHandler, RetAssigner, SplitRetInfos,
                                      MIRBuilder, CC, Info.IsVarArg))
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 540c2e7476dc18..e999a8b910576a 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -22,6 +22,7 @@
 #include "llvm/ADT/Statistic.h"
 #include "llvm/Analysis/MemoryLocation.h"
 #include "llvm/Analysis/VectorUtils.h"
+#include "llvm/CodeGen/Analysis.h"
 #include "llvm/CodeGen/MachineFrameInfo.h"
 #include "llvm/CodeGen/MachineFunction.h"
 #include "llvm/CodeGen/MachineInstrBuilder.h"
@@ -17868,33 +17869,12 @@ static bool CC_RISCVAssign2XLen(unsigned XLen, CCState &State, CCValAssign VA1,
   return false;
 }
 
-static unsigned allocateRVVReg(MVT ValVT, unsigned ValNo,
-                               std::optional<unsigned> FirstMaskArgument,
-                               CCState &State, const RISCVTargetLowering &TLI) {
-  const TargetRegisterClass *RC = TLI.getRegClassFor(ValVT);
-  if (RC == &RISCV::VRRegClass) {
-    // Assign the first mask argument to V0.
-    // This is an interim calling convention and it may be changed in the
-    // future.
-    if (FirstMaskArgument && ValNo == *FirstMaskArgument)
-      return State.AllocateReg(RISCV::V0);
-    return State.AllocateReg(ArgVRs);
-  }
-  if (RC == &RISCV::VRM2RegClass)
-    return State.AllocateReg(ArgVRM2s);
-  if (RC == &RISCV::VRM4RegClass)
-    return State.AllocateReg(ArgVRM4s);
-  if (RC == &RISCV::VRM8RegClass)
-    return State.AllocateReg(ArgVRM8s);
-  llvm_unreachable("Unhandled register class for ValueType");
-}
-
 // Implements the RISC-V calling convention. Returns true upon failure.
 bool RISCV::CC_RISCV(const DataLayout &DL, RISCVABI::ABI ABI, unsigned ValNo,
                      MVT ValVT, MVT LocVT, CCValAssign::LocInfo LocInfo,
                      ISD::ArgFlagsTy ArgFlags, CCState &State, bool IsFixed,
                      bool IsRet, Type *OrigTy, const RISCVTargetLowering &TLI,
-                     std::optional<unsigned> FirstMaskArgument) {
+                     RVVArgDispatcher &RVVDispatcher) {
   unsigned XLen = DL.getLargestLegalIntTypeSizeInBits();
   assert(XLen == 32 || XLen == 64);
   MVT XLenVT = XLen == 32 ? MVT::i32 : MVT::i64;
@@ -18052,7 +18032,7 @@ bool RISCV::CC_RISCV(const DataLayout &DL, RISCVABI::ABI ABI, unsigned ValNo,
   }
 
   // Allocate to a register if possible, or else a stack slot.
-  Register Reg;
+  Register Reg = MCRegister();
   unsigned StoreSizeBytes = XLen / 8;
   Align StackAlign = Align(XLen / 8);
 
@@ -18063,7 +18043,7 @@ bool RISCV::CC_RISCV(const DataLayout &DL, RISCVABI::ABI ABI, unsigned ValNo,
   else if (ValVT == MVT::f64 && !UseGPRForF64)
     Reg = State.AllocateReg(ArgFPR64s);
   else if (ValVT.isVector()) {
-    Reg = allocateRVVReg(ValVT, ValNo, FirstMaskArgument, State, TLI);
+    Reg = RVVDispatcher.getNextPhysReg();
     if (!Reg) {
       // For return values, the vector must be passed fully via registers or
       // via the stack.
@@ -18149,9 +18129,7 @@ void RISCVTargetLowering::analyzeInputArgs(
   unsigned NumArgs = Ins.size();
   FunctionType *FType = MF.getFunction().getFunctionType();
 
-  std::optional<unsigned> FirstMaskArgument;
-  if (Subtarget.hasVInstructions())
-    FirstMaskArgument = preAssignMask(Ins);
+  RVVArgDispatcher Dispatcher{&MF, this, IsRet};
 
   for (unsigned i = 0; i != NumArgs; ++i) {
     MVT ArgVT = Ins[i].VT;
@@ -18166,7 +18144,7 @@ void RISCVTargetLowering::analyzeInputArgs(
     RISCVABI::ABI ABI = MF.getSubtarget<RISCVSubtarget>().getTargetABI();
     if (Fn(MF.getDataLayout(), ABI, i, ArgVT, ArgVT, CCValAssign::Full,
            ArgFlags, CCInfo, /*IsFixed=*/true, IsRet, ArgTy, *this,
-           FirstMaskArgument)) {
+           Dispatcher)) {
       LLVM_DEBUG(dbgs() << "InputArg #" << i << " has unhandled type "
                         << ArgVT << '\n');
       llvm_unreachable(nullptr);
@@ -18180,9 +18158,7 @@ void RISCVTargetLowering::analyzeOutputArgs(
     CallLoweringInfo *CLI, RISCVCCAssignFn Fn) const {
   unsigned NumArgs = Outs.size();
 
-  std::optional<unsigned> FirstMaskArgument;
-  if (Subtarget.hasVInstructions())
-    FirstMaskArgument = preAssignMask(Outs);
+  RVVArgDispatcher Dispatcher{&MF, this, IsRet, CLI};
 
   for (unsigned i = 0; i != NumArgs; i++) {
     MVT ArgVT = Outs[i].VT;
@@ -18192,7 +18168,7 @@ void RISCVTargetLowering::analyzeOutputArgs(
     RISCVABI::ABI ABI = MF.getSubtarget<RISCVSubtarget>().getTargetABI();
     if (Fn(MF.getDataLayout(), ABI, i, ArgVT, ArgVT, CCValAssign::Full,
            ArgFlags, CCInfo, Outs[i].IsFixed, IsRet, OrigTy, *this,
-           FirstMaskArgument)) {
+           Dispatcher)) {
       LLVM_DEBUG(dbgs() << "OutputArg #" << i << " has unhandled type "
                         << ArgVT << "\n");
       llvm_unreachable(nullptr);
@@ -18373,7 +18349,7 @@ bool RISCV::CC_RISCV_FastCC(const DataLayout &DL, RISCVABI::ABI ABI,
                             ISD::ArgFlagsTy ArgFlags, CCState &State,
                             bool IsFixed, bool IsRet, Type *OrigTy,
                             const RISCVTargetLowering &TLI,
-                            std::optional<unsigned> FirstMaskArgument) {
+                            RVVArgDispatcher &RVVDispatcher) {
   if (LocVT == MVT::i32 || LocVT == MVT::i64) {
     if (unsigned Reg = State.AllocateReg(getFastCCArgGPRs(ABI))) {
       State.addLoc(CCValAssign::getReg(ValNo, ValVT, Reg, LocVT, LocInfo));
@@ -18451,13 +18427,14 @@ bool RISCV::CC_RISCV_FastCC(const DataLayout &DL, RISCVABI::ABI ABI,
   }
 
   if (LocVT.isVector()) {
-    if (unsigned Reg =
-            allocateRVVReg(ValVT, ValNo, FirstMaskArgument, State, TLI)) {
+    MCPhysReg AllocatedVReg = RVVDispatcher.getNextPhysReg();
+    if (AllocatedVReg) {
       // Fixed-length vectors are located in the corresponding scalable-vector
       // container types.
       if (ValVT.isFixedLengthVector())
         LocVT = TLI.getContainerForFixedLengthVector(LocVT);
-      State.addLoc(CCValAssign::getReg(ValNo, ValVT, Reg, LocVT, LocInfo));
+      State.addLoc(
+          CCValAssign::getReg(ValNo, ValVT, AllocatedVReg, LocVT, LocInfo));
     } else {
       // Try and pass the address via a "fast" GPR.
       if (unsigned GPRReg = State.AllocateReg(getFastCCArgGPRs(ABI))) {
@@ -19084,17 +19061,15 @@ bool RISCVTargetLowering::CanLowerReturn(
   SmallVector<CCValAssign, 16> RVLocs;
   CCState CCInfo(CallConv, IsVarArg, MF, RVLocs, Context);
 
-  std::optional<unsigned> FirstMaskArgument;
-  if (Subtarget.hasVInstructions())
-    FirstMaskArgument = preAssignMask(Outs);
+  RVVArgDispatcher Dispatcher{&MF, this, true};
 
   for (unsigned i = 0, e = Outs.size(); i != e; ++i) {
     MVT VT = Outs[i].VT;
     ISD::ArgFlagsTy ArgFlags = Outs[i].Flags;
     RISCVABI::ABI ABI = MF.getSubtarget<RISCVSubtarget>().getTargetABI();
     if (RISCV::CC_RISCV(MF.getDataLayout(), ABI, i, VT, VT, CCValAssign::Full,
-                 ArgFlags, CCInfo, /*IsFixed=*/true, /*IsRet=*/true, nullptr,
-                 *this, FirstMaskArgument))
+                        ArgFlags, CCInfo, /*IsFixed=*/true, /*IsRet=*/true,
+                        nullptr, *this, Dispatcher))
       return false;
   }
   return true;
@@ -20889,6 +20864,135 @@ unsigned RISCVTargetLowering::getMinimumJumpTableEntries() const {
   return Subtarget.getMinimumJumpTableEntries();
 }
 
+void RVVArgDispatcher::constructHelper(Type *Ty) {
+  const DataLayout &DL = MF->getDataLayout();
+  const Function &F = MF->getFunction();
+  LLVMContext &Context = F.getContext();
+
+  StructType *STy = dyn_cast<StructType>(Ty);
+  if (STy && STy->containsHomogeneousScalableVectorTypes()) {
+    Type *ElemTy = STy->getTypeAtIndex(0U);
+    EVT VT = TLI->getValueType(DL, ElemTy);
+    MVT RegisterVT =
+        TLI->getRegisterTypeForCallingConv(Context, F.getCallingConv(), VT);
+
+    RVVArgInfos.push_back({STy->getNumElements(), RegisterVT, false});
+  } else {
+    SmallVector<EVT, 4> ValueVTs;
+    ComputeValueVTs(*TLI, DL, Ty, ValueVTs);
+
+    for (unsigned Value = 0, NumValues = ValueVTs.size(); Value != NumValues;
+         ++Value) {
+      EVT VT = ValueVTs[Value];
+      MVT RegisterVT =
+          TLI->getRegisterTypeForCallingConv(Context, F.getCallingConv(), VT);
+      unsigned NumRegs =
+          TLI->getNumRegistersForCallingConv(Context, F.getCallingConv(), VT);
+
+      // Skip non-RVV register type
+      if (!RegisterVT.isVector())
+        continue;
+
+      if (RegisterVT.isFixedLengthVector())
+        RegisterVT = TLI->getContainerForFixedLengthVector(RegisterVT);
+
+      RVVArgInfo Info{1, RegisterVT, false};
+
+      while (NumRegs--)
+        RVVArgInfos.push_back(Info);
+    }
+  }
+}
+
+void RVVArgDispatcher::construct(bool IsRet) {
+  const Function &F = MF->getFunction();
+
+  if (IsRet)
+    constructHelper(F.getReturnType());
+  else if (CLI)
+    for (const TargetLowering::ArgListEntry &Arg : CLI->getArgs())
+      constructHelper(Arg.Ty);
+  else if (GISelCLI)
+    for (const CallLowering::ArgInfo &Arg : GISelCLI->OrigArgs)
+      constructHelper(Arg.Ty);
+  else
+    for (const Argument &Arg : F.args())
+      constructHelper(Arg.getType());
+
+  for (auto &Info : RVVArgInfos)
+    if (Info.NF == 1 && Info.VT.getVectorElementType() == MVT::i1) {
+      Info.FirstVMask = true;
+      break;
+    }
+}
+
+void RVVArgDispatcher::allocatePhysReg(unsigned NF, unsigned LMul,
+                                       unsigned StartReg) {
+  assert((StartReg % LMul) == 0 &&
+         "Start register number should be multiple of lmul");
+  const MCPhysReg *VRArrays;
+  switch (LMul) {
+  default:
+    report_fatal_error("Invalid lmul");
+  case 1:
+    VRArrays = ArgVRs;
+    break;
+  case 2:
+    VRArrays = ArgVRM2s;
+    break;
+  case 4:
+    VRArrays = ArgVRM4s;
+    break;
+  case 8:
+    VRArrays = ArgVRM8s;
+    break;
+  }
+
+  for (unsigned i = 0; i < NF; ++i)
+    if (StartReg)
+      AllocatedPhysRegs.push_back(VRArrays[(StartReg - 8) / LMul + i]);
+    else
+      AllocatedPhysRegs.push_back(0);
+}
+
+// This function determines if each RVV argument is passed by register.
+void RVVArgDispatcher::compute() {
+  unsigned ToBeAssigned = RVVArgInfos.size();
+  uint64_t AssignedMap = 0;
+  auto tryAllocate = [&](const RVVArgInfo &ArgInfo) {
+    // Allocate first vector mask argument to V0.
+    if (ArgInfo.FirstVMask) {
+      AllocatedPhysRegs.push_back(RISCV::V0);
+      return;
+    }
+
+    unsigned RegsNeeded = std::max(
+        ArgInfo.VT.getSizeInBits().getKnownMinValue() / RISCV::RVVBitsPerBlock,
+        1UL);
+    unsigned TotalRegsNeeded = ArgInfo.NF * RegsNeeded;
+    for (unsigned StartReg = 0; StartReg + TotalRegsNeeded <= NumArgVRs;
+         StartReg += RegsNeeded) {
+      unsigned Map = ((1 << TotalRegsNeeded) - 1) << StartReg;
+      if ((AssignedMap & Map) == 0) {
+        allocatePhysReg(ArgInfo.NF, RegsNeeded, StartReg + 8);
+        AssignedMap |= Map;
+        return;
+      }
+    }
+
+    allocatePhysReg(ArgInfo.NF, RegsNeeded, 0);
+    return;
+  };
+
+  for (unsigned i = 0; i < ToBeAssigned; ++i)
+    tryAllocate(RVVArgInfos[i]);
+}
+
+MCPhysReg RVVArgDispatcher::getNextPhysReg() {
+  assert(CurIdx < AllocatedPhysRegs.size() && "Index out of range");
+  return AllocatedPhysRegs[CurIdx++];
+}
+
 namespace llvm::RISCVVIntrinsicsTable {
 
 #define GET_RISCVVIntrinsicsTable_IMPL
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index a38463f810270a..c5437a13b6446e 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -16,6 +16,7 @@
 
 #include "RISCV.h"
 #include "llvm/CodeGen/CallingConvLower.h"
+#include "llvm/CodeGen/GlobalISel/CallLowering.h"
 #include "llvm/CodeGen/SelectionDAG.h"
 #include "llvm/CodeGen/TargetLowering.h"
 #include "llvm/TargetParser/RISCVTargetParser.h"
@@ -25,6 +26,7 @@ namespace llvm {
 class InstructionCost;
 class RISCVSubtarget;
 struct RISCVRegisterInfo;
+class RVVArgDispatcher;
 
 namespace RISCVISD {
 // clang-format off
@@ -870,7 +872,7 @@ class RISCVTargetLowering : public TargetLowering {
                                ISD::ArgFlagsTy ArgFlags, CCState &State,
                                bool IsFixed, bool IsRet, Type *OrigTy,
                                const RISCVTargetLowering &TLI,
-                               std::optional<unsigned> FirstMaskArgument);
+                               RVVArgDispatcher &RVVDispatcher);
 
 private:
   void analyzeInputArgs(MachineFunction &MF, CCState &CCInfo,
@@ -1010,19 +1012,60 @@ class RISCVTargetLowering : public TargetLowering {
   unsigned getMinimumJumpTableEntries() const override;
 };
 
+class RVVArgDispatcher {
+public:
+  static constexpr unsigned NumArgVRs = 16;
+
+  struct RVVArgInfo {
+    unsigned NF;
+    MVT VT;
+    bool FirstVMask = false;
+  };
+
+  RVVArgDispatcher() = default;
+  RVVArgDispatcher(const MachineFunction *MF, const RISCVTargetLowering *TLI,
+                   bool IsRet, TargetLowering::CallLoweringInfo *CLI = nullptr,
+                   CallLowering::CallLoweringInfo *GISelCLI = nullptr)
+      : MF(MF), TLI(TLI), CLI(CLI), GISelCLI(GISelCLI) {
+    assert((!CLI || !GISelCLI) &&
+           "ISel and Global ISel can't co-exist at the same time");
+    construct(IsRet);
+    compute();
+  }
+
+  MCPhysReg getNextPhysReg();
+
+private:
+  std::vector<RVVArgInfo> RVVArgInfos;
+  std::vector<MCPhysReg> AllocatedPhysRegs;
+
+  const MachineFunction *MF = nullptr;
+  const RISCVTargetLowering *TLI = nullptr;
+  TargetLowering::CallLoweringInfo *CLI = nullptr;
+  CallLowering::CallLoweringInfo *GISelCLI = nullptr;
+
+  unsigned CurIdx = 0;
+
+  void construct(bool IsRet);
+  void constructHelper(Type *Ty);
+  void compute();
+  void allocatePhysReg(unsigned NF = 1, unsigned LMul = 1,
+                       unsigned StartReg = 0);
+};
+
 namespace RISCV {
 
 bool CC_RISCV(const DataLayout &DL, RISCVABI::ABI ABI, unsigned ValNo,
               MVT ValVT, MVT LocVT, CCValAssign::LocInfo LocInfo,
               ISD::ArgFlagsTy ArgFlags, CCState &State, bool IsFixed,
               bool IsRet, Type *OrigTy, const RISCVTargetLowering &TLI,
-              std::optional<unsigned> FirstMaskArgument);
+              RVVArgDispatcher &RVVDispatcher);
 
 bool CC_RISCV_FastCC(const DataLayout &DL, RISCVABI::ABI ABI, unsigned ValNo,
                      MVT ValVT, MVT LocVT, CCValAssign::LocInfo LocInfo,
                      ISD::ArgFlagsTy ArgFlags, CCState &State, bool IsFixed,
                      bool IsRet, Type *OrigTy, const RISCVTargetLowering &TLI,
-                     std::optional<unsigned> FirstMaskArgument);
+                     RVVArgDispatcher &RVVDispatcher);
 
 bool CC_RISCV_GHC(unsigned ValNo, MVT ValVT, MVT LocVT,
                   CCValAssign::LocInfo LocInfo, ISD::ArgFlagsTy ArgFlags,
diff --git a/llvm/test/CodeGen/RISCV/rvv/calling-conv.ll b/llvm/test/CodeGen/RISCV/rvv/calling-conv.ll
index 78385a80b47ebd..005bc5b60ca805 100644
--- a/llvm/test/CodeGen/RISCV/rvv/calling-conv.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/calling-conv.ll
@@ -86,3 +86,90 @@ define <vscale x 32 x i32> @caller_scalable_vector_split_indirect(<vscale x 32 x
   %a = call <vscale x 32 x i32> @callee_scalable_vector_split_indirect(<vscale x 32 x i32> zeroinitializer, <vscale x 32 x i32> %x)
   ret <vscale x 32 x i32> %a
 }
+
+; %0 -> v8
+; %1 -> v9
+define <vscale x 1 x i64> @case1(<vscale x 1 x i64> %0, <vscale x 1 x i64> %1) {
+; CHECK-LABEL: case1:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli a0, zero, e64, m1, ta, ma
+; CHECK-NEXT:    vadd.vv v8, v8, v9
+; CHECK-NEXT:    ret
+  %a = add <vscale x 1 x i64> %0, %1
+  ret <vscale x 1 x i64> %a
+}
+
+; %0 -> v8
+; %1 -> v10-v11
+; %2 -> v9
+define <vscale x 1 x i64> @case2_1(<vscale x 1 x i64> %0, <vscale x 2 x i64> %1, <vscale x 1 x i64> %2) {
+; CHECK-LABEL: case2_1:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli a0, zero, e64, m1, ta, ma
+; CHECK-NEXT:    vadd.vv v8, v8, v9
+; CHECK-NEXT:    ret
+  %a = add <vscale x 1 x i64> %0, %2
+  ret <vscale x 1 x i64> %a
+}
+define <vscale x 2 x i64> @case2_2(<vscale x 1 x i64> %0, <vscale x 2 x i64> %1, <vscale x 1 x i64> %2) {
+; CHECK-LABEL: case2_2:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli a0, zero, e64, m2, ta, ma
+; CHECK-NEXT:    vadd.vv v8, v10, v10
+; CHECK-NEXT:    ret
+  %a = add <vscale x 2 x i64> %1, %1
+  ret <vscale x 2 x i64> %a
+}
+
+; %0 -> v8
+; %1 -> {v10-v11, v12-v13}
+; %2 -> v9
+define <vscale x 1 x i64> @case3_1(<vscale x 1 x i64> %0, {<vscale x 2 x i64>, <vscale x 2 x i64>} %1, <vscale x 1 x i64> %2) {
+; CHECK-LABEL: case3_1:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli a0, zero, e64, m1, ta, ma
+; CHECK-NEXT:    vadd.vv v8, v8, v9
+; CHECK-NEXT:    ret
+  %add = add <vscale x 1 x i64> %0, %2
+  ret <vscale x 1 x i64> %add
+}
+define <vscale x 2 x i64> @case3_2(<vscale x 1 x i64> %0, {<vscale x 2 x i64>, <vscale x 2 x i64>} %1, <vscale x 1 x i64> %2) {
+; CHECK-LABEL: case3_2:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli a0, zero, e64, m2, ta, ma
+; CHECK-NEXT:    vadd.vv v8, v10, v12
+; CHECK-NEXT:    ret
+  %a = extractvalue { <vscale x 2 x i64>, <vscale x 2 x i64> } %1, 0
+  %b = extractvalue { <vscale x 2 x i64>, <vscale x 2 x i64> } %1, 1
+  %add = add <vscale x 2 x i64> %a, %b
+  ret <vscale x 2 x i64> %add
+}
+
+; %0 -> v8
+; %1 -> {by-ref, by-ref}
+; %2 -> v9
+define <vscale x 8 x i64> @case4_1(<vscale x 1 x i64> %0, {<vscale x 8 x i64>, <vscale x 8 x i64>} %1, <vscale x 1 x i64> %2) {
+; CHECK-LABEL: case4_1:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    csrr a1, vlenb
+; CHECK-NEXT:    slli a1, a1, 3
+; CHECK-NEXT:    add a1, a0, a1
+; CHECK-NEXT:    vl8re64.v v8, (a1)
+; CHECK-NEXT:    vl8re64.v v16, (a0)
+; CHECK-NEXT:    vsetvli a0, zero, e64, m8, ta, ma
+; CHECK-NEXT:    vadd.vv v8, v16, v8
+; CHECK-NEXT:    ret
+  %a = extractvalue { <vscale x 8 x i64>, <vscale x 8 x i64> } %1, 0
+  %b = extractvalue { <vscale x 8 x i64>, <vscale x 8 x i64> } %1, 1
+  %add = add <vscale x 8 x i64> %a, %b
+  ret <vscale x 8 x i64> %add
+}
+define <vscale x 1 x i64> @case4_2(<vscale x 1 x i64> %0, {<vscale x 8 x i64>, <vscale x 8 x i64>} %1, <vscale x 1 x i64> %2) {
+; CHECK-LABEL: case4_2:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli a0, zero, e64, m1, ta, ma
+; CHECK-NEXT:    vadd.vv v8, v8, v9
+; CHECK-NEXT:    ret
+  %add = add <vscale x 1 x i64> %0, %2
+  ret <vscale x 1 x i64> %add
+}
diff --git a/llvm/test/CodeGen/RISCV/rvv/vector-deinterleave-load.ll b/llvm/test/CodeGen/RISCV/rvv/vector-deinterleave-load.ll
index a320aecc6fce49..6a712080fda74a 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vector-deinterleave-load.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vector-deinterleave-load.ll
@@ -18,10 +18,10 @@ define {<vscale x 16 x i1>, <vscale x 16 x i1>} @vector_deinterleave_load_nxv16i
 ; CHECK-NEXT:    vmerge.vim v14, v10, 1, v0
 ; CHECK-NEXT:    vmv1r.v v0, v8
 ; CHECK-NEXT:    vmerge.vim v12, v10, 1, v0
-; CHECK-NEXT:    vnsrl.wi v8, v12, 0
-; CHECK-NEXT:    vmsne.vi v0, v8, 0
-; CHECK-NEXT:    vnsrl.wi v10, v12, 8
+; CHECK-NEXT:    vnsrl.wi v10, v12, 0
 ; CHECK-NEXT:    vmsne.vi v8, v10, 0
+; CHECK-NEXT:    vnsrl.wi v10, v12, 8
+; CHECK-NEXT:    vmsne.vi v9, v10, 0
 ; CHECK-NEXT:    ret
   %vec = load <vscale x 32 x i1>, ptr %p
   %retval = call {<vscale x 16 x i1>, <vscale x 16 x i1>} @llvm.experimental.vector.deinterleave2.nxv32i1(<vscale x 32 x i1> %vec)
diff --git a/llvm/test/CodeGen/RISCV/rvv/vector-deinterleave.ll b/llvm/test/CodeGen/RISCV/rvv/vector-deinterleave.ll
index ef4baf34d23f03..d5a818e0e38205 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vector-deinterleave.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vector-deinterleave.ll
@@ -8,18 +8,18 @@ define {<vscale x 16 x i1>, <vscale x 16 x i1>} @vector_deinterleave_nxv16i1_nxv
 ; CHECK-LABEL: vector_deinterleave_nxv16i1_nxv32i1:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    vsetvli a0, zero, e8, m2, ta, ma
-; CHECK-NEXT:    vmv.v.i v10, 0
-; CHECK-NEXT:    vmerge.vim v8, v10, 1, v0
+; CHECK-NEXT:    vmv.v.i v8, 0
+; CHECK-NEXT:    vmerge.vim v12, v8, 1, v0
 ; CHECK-NEXT:    csrr a0, vlenb
 ; CHECK-NEXT:    srli a0, a0, 2
 ; CHECK-NEXT:    vsetvli a1, zero, e8, mf2, ta, ma
 ; CHECK-NEXT:    vslidedown.vx v0, v0, a0
 ; CHECK-NEXT:    vsetvli a0, zero, e8, m2, ta, ma
-; CHECK-NEXT:    vmerge.vim v10, v10, 1, v0
-; CHECK-NEXT:    vnsrl.wi v12, v8, 0
-; CHECK-NEXT:    vmsne.vi v0, v12, 0
-; CHECK-NEXT:    vnsrl.wi v12, v8, 8
-; CHECK-NEXT:    vmsne.vi v8, v12, 0
+; CHECK-NEXT:    vmerge.vim v14, v8, 1, v0
+; CHECK-NEXT:    vnsrl.wi v10, v12, 0
+; CHECK-NEXT:    vmsne.vi v8, v10, 0
+; CHECK-NEXT:    vnsrl.wi v10, v12, 8
+; CHECK-NEXT:    vmsne.vi v9, v10, 0
 ; CHECK-NEXT:    ret
 %retval = call {<vscale x 16 x i1>, <vscale x 16 x i1>} @llvm.experimental.vector.deinterleave2.nxv32i1(<vscale x 32 x i1> %vec)
 ret {<vscale x 16 x i1>, <vscale x 16 x i1>} %retval
@@ -107,7 +107,8 @@ define {<vscale x 64 x i1>, <vscale x 64 x i1>} @vector_deinterleave_nxv64i1_nxv
 ; CHECK-NEXT:    vnsrl.wi v24, v16, 8
 ; CHECK-NEXT:    vnsrl.wi v28, v8, 8
 ; CHECK-NEXT:    vsetvli a0, zero, e8, m8, ta, ma
-; CHECK-NEXT:    vmsne.vi v8, v24, 0
+; CHECK-NEXT:    vmsne.vi v9, v24, 0
+; CHECK-NEXT:    vmv1r.v v8, v0
 ; CHECK-NEXT:    ret
 %retval = call {<vscale x 64 x i1>, <vscale x 64 x i1>} @llvm.experimental.vector.deinterleave2.nxv128i1(<vscale x 128 x i1> %vec)
 ret {<vscale x 64 x i1>, <vscale x 64 x i1>} %retval
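
As the hunks above show, CC_RISCV and CC_RISCV_FastCC now consume the
precomputed registers strictly in argument order via getNextPhysReg(), falling
back to indirect passing when it returns an invalid (zero) register. Below is
a rough, self-contained sketch of that consumption pattern; DispatcherModel
and the register numbers are illustrative placeholders, not the patch's actual
classes.

#include <cassert>
#include <iostream>
#include <vector>

// Placeholder for the precomputed per-argument registers; 0 plays the role of
// the invalid MCPhysReg the patch uses for "no register available". The
// numbers below are illustrative, not real register encodings.
class DispatcherModel {
  std::vector<unsigned> AllocatedPhysRegs;
  unsigned CurIdx = 0;

public:
  explicit DispatcherModel(std::vector<unsigned> Regs)
      : AllocatedPhysRegs(std::move(Regs)) {}

  // Same contract as RVVArgDispatcher::getNextPhysReg(): hand out the
  // precomputed registers strictly in argument order.
  unsigned getNextPhysReg() {
    assert(CurIdx < AllocatedPhysRegs.size() && "Index out of range");
    return AllocatedPhysRegs[CurIdx++];
  }
};

int main() {
  // Three vector arguments: the first two fit in registers, the third does
  // not and would be passed indirectly through a GPR-held address.
  DispatcherModel Dispatcher({8, 10, 0});
  for (int Arg = 0; Arg < 3; ++Arg) {
    unsigned Reg = Dispatcher.getNextPhysReg();
    if (Reg)
      std::cout << "arg " << Arg << " -> v" << Reg << "\n";
    else
      std::cout << "arg " << Arg << " -> indirect (address in a GPR)\n";
  }
}
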

>From c94b8465b5b377c91961de0a3abdaff937295ba9 Mon Sep 17 00:00:00 2001
From: Brandon Wu <brandon.wu at sifive.com>
Date: Mon, 22 Jan 2024 23:39:13 -0800
Subject: [PATCH 2/5] fixup! [RISCV] RISCV vector calling convention (2/2)

---
 .../Target/RISCV/GISel/RISCVCallLowering.cpp  | 18 ++++++----
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp   | 36 ++++++++++---------
 llvm/lib/Target/RISCV/RISCVISelLowering.h     | 20 ++++++-----
 3 files changed, 43 insertions(+), 31 deletions(-)

diff --git a/llvm/lib/Target/RISCV/GISel/RISCVCallLowering.cpp b/llvm/lib/Target/RISCV/GISel/RISCVCallLowering.cpp
index ff87601046dc20..63ee73a05f74af 100644
--- a/llvm/lib/Target/RISCV/GISel/RISCVCallLowering.cpp
+++ b/llvm/lib/Target/RISCV/GISel/RISCVCallLowering.cpp
@@ -34,7 +34,7 @@ struct RISCVOutgoingValueAssigner : public CallLowering::OutgoingValueAssigner {
   // Whether this is assigning args for a return.
   bool IsRet;
 
-  RVVArgDispatcher RVVDispatcher;
+  RVVArgDispatcher &RVVDispatcher;
 
 public:
   RISCVOutgoingValueAssigner(
@@ -408,7 +408,8 @@ bool RISCVCallLowering::lowerReturnVal(MachineIRBuilder &MIRBuilder,
   SmallVector<ArgInfo, 4> SplitRetInfos;
   splitToValueTypes(OrigRetInfo, SplitRetInfos, DL, CC);
 
-  RVVArgDispatcher Dispatcher{&MF, getTLI<RISCVTargetLowering>(), true};
+  RVVArgDispatcher Dispatcher{&MF, getTLI<RISCVTargetLowering>(),
+                              F.getReturnType()};
   RISCVOutgoingValueAssigner Assigner(
       CC == CallingConv::Fast ? RISCV::CC_RISCV_FastCC : RISCV::CC_RISCV,
       /*IsRet=*/true, Dispatcher);
@@ -520,6 +521,7 @@ bool RISCVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
   CallingConv::ID CC = F.getCallingConv();
 
   SmallVector<ArgInfo, 32> SplitArgInfos;
+  std::vector<Type *> TypeList;
   unsigned Index = 0;
   for (auto &Arg : F.args()) {
     // Construct the ArgInfo object from destination register and argument type.
@@ -531,10 +533,12 @@ bool RISCVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
     // correspondingly and appended to SplitArgInfos.
     splitToValueTypes(AInfo, SplitArgInfos, DL, CC);
 
+    TypeList.push_back(Arg.getType());
+
     ++Index;
   }
 
-  RVVArgDispatcher Dispatcher{&MF, getTLI<RISCVTargetLowering>(), false};
+  RVVArgDispatcher Dispatcher{&MF, getTLI<RISCVTargetLowering>(), TypeList};
   RISCVIncomingValueAssigner Assigner(
       CC == CallingConv::Fast ? RISCV::CC_RISCV_FastCC : RISCV::CC_RISCV,
       /*IsRet=*/false, Dispatcher);
@@ -575,11 +579,13 @@ bool RISCVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
 
   SmallVector<ArgInfo, 32> SplitArgInfos;
   SmallVector<ISD::OutputArg, 8> Outs;
+  std::vector<Type *> TypeList;
   for (auto &AInfo : Info.OrigArgs) {
     // Handle any required unmerging of split value types from a given VReg into
     // physical registers. ArgInfo objects are constructed correspondingly and
     // appended to SplitArgInfos.
     splitToValueTypes(AInfo, SplitArgInfos, DL, CC);
+    TypeList.push_back(AInfo.Ty);
   }
 
   // TODO: Support tail calls.
@@ -597,8 +603,7 @@ bool RISCVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
   const TargetRegisterInfo *TRI = Subtarget.getRegisterInfo();
   Call.addRegMask(TRI->getCallPreservedMask(MF, Info.CallConv));
 
-  RVVArgDispatcher ArgDispatcher{&MF, getTLI<RISCVTargetLowering>(), false,
-                                 nullptr, &Info};
+  RVVArgDispatcher ArgDispatcher{&MF, getTLI<RISCVTargetLowering>(), TypeList};
   RISCVOutgoingValueAssigner ArgAssigner(
       CC == CallingConv::Fast ? RISCV::CC_RISCV_FastCC : RISCV::CC_RISCV,
       /*IsRet=*/false, ArgDispatcher);
@@ -629,7 +634,8 @@ bool RISCVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
   SmallVector<ArgInfo, 4> SplitRetInfos;
   splitToValueTypes(Info.OrigRet, SplitRetInfos, DL, CC);
 
-  RVVArgDispatcher RetDispatcher{&MF, getTLI<RISCVTargetLowering>(), true};
+  RVVArgDispatcher RetDispatcher{&MF, getTLI<RISCVTargetLowering>(),
+                                 F.getReturnType()};
   RISCVIncomingValueAssigner RetAssigner(
       CC == CallingConv::Fast ? RISCV::CC_RISCV_FastCC : RISCV::CC_RISCV,
       /*IsRet=*/true, RetDispatcher);
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index e999a8b910576a..23bb092c7015bc 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -18129,7 +18129,13 @@ void RISCVTargetLowering::analyzeInputArgs(
   unsigned NumArgs = Ins.size();
   FunctionType *FType = MF.getFunction().getFunctionType();
 
-  RVVArgDispatcher Dispatcher{&MF, this, IsRet};
+  std::vector<Type *> TypeList;
+  if (IsRet)
+    TypeList.push_back(MF.getFunction().getReturnType());
+  else
+    for (const Argument &Arg : MF.getFunction().args())
+      TypeList.push_back(Arg.getType());
+  RVVArgDispatcher Dispatcher{&MF, this, TypeList};
 
   for (unsigned i = 0; i != NumArgs; ++i) {
     MVT ArgVT = Ins[i].VT;
@@ -18158,7 +18164,13 @@ void RISCVTargetLowering::analyzeOutputArgs(
     CallLoweringInfo *CLI, RISCVCCAssignFn Fn) const {
   unsigned NumArgs = Outs.size();
 
-  RVVArgDispatcher Dispatcher{&MF, this, IsRet, CLI};
+  std::vector<Type *> TypeList;
+  if (IsRet)
+    TypeList.push_back(MF.getFunction().getReturnType());
+  else if (CLI)
+    for (const TargetLowering::ArgListEntry &Arg : CLI->getArgs())
+      TypeList.push_back(Arg.Ty);
+  RVVArgDispatcher Dispatcher{&MF, this, TypeList};
 
   for (unsigned i = 0; i != NumArgs; i++) {
     MVT ArgVT = Outs[i].VT;
@@ -19061,7 +19073,8 @@ bool RISCVTargetLowering::CanLowerReturn(
   SmallVector<CCValAssign, 16> RVLocs;
   CCState CCInfo(CallConv, IsVarArg, MF, RVLocs, Context);
 
-  RVVArgDispatcher Dispatcher{&MF, this, true};
+  std::vector<Type *> TypeList = {MF.getFunction().getReturnType()};
+  RVVArgDispatcher Dispatcher{&MF, this, TypeList};
 
   for (unsigned i = 0, e = Outs.size(); i != e; ++i) {
     MVT VT = Outs[i].VT;
@@ -20904,20 +20917,9 @@ void RVVArgDispatcher::constructHelper(Type *Ty) {
   }
 }
 
-void RVVArgDispatcher::construct(bool IsRet) {
-  const Function &F = MF->getFunction();
-
-  if (IsRet)
-    constructHelper(F.getReturnType());
-  else if (CLI)
-    for (const TargetLowering::ArgListEntry &Arg : CLI->getArgs())
-      constructHelper(Arg.Ty);
-  else if (GISelCLI)
-    for (const CallLowering::ArgInfo &Arg : GISelCLI->OrigArgs)
-      constructHelper(Arg.Ty);
-  else
-    for (const Argument &Arg : F.args())
-      constructHelper(Arg.getType());
+void RVVArgDispatcher::construct(std::vector<Type *> &TypeList) {
+  for (Type *Ty : TypeList)
+    constructHelper(Ty);
 
   for (auto &Info : RVVArgInfos)
     if (Info.NF == 1 && Info.VT.getVectorElementType() == MVT::i1) {
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index c5437a13b6446e..98bef01ad619d0 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -1022,14 +1022,18 @@ class RVVArgDispatcher {
     bool FirstVMask = false;
   };
 
-  RVVArgDispatcher() = default;
   RVVArgDispatcher(const MachineFunction *MF, const RISCVTargetLowering *TLI,
-                   bool IsRet, TargetLowering::CallLoweringInfo *CLI = nullptr,
-                   CallLowering::CallLoweringInfo *GISelCLI = nullptr)
-      : MF(MF), TLI(TLI), CLI(CLI), GISelCLI(GISelCLI) {
-    assert((!CLI || !GISelCLI) &&
-           "ISel and Global ISel can't co-exist at the same time");
-    construct(IsRet);
+                   std::vector<Type *> &TypeList)
+      : MF(MF), TLI(TLI) {
+    construct(TypeList);
+    compute();
+  }
+
+  RVVArgDispatcher(const MachineFunction *MF, const RISCVTargetLowering *TLI,
+                   Type *Ty)
+      : MF(MF), TLI(TLI) {
+    std::vector<Type *> TypeList = {Ty};
+    construct(TypeList);
     compute();
   }
 
@@ -1046,7 +1050,7 @@ class RVVArgDispatcher {
 
   unsigned CurIdx = 0;
 
-  void construct(bool IsRet);
+  void construct(std::vector<Type *> &TypeList);
   void constructHelper(Type *Ty);
   void compute();
   void allocatePhysReg(unsigned NF = 1, unsigned LMul = 1,

>From e294356574b922cb5cf5d9cd734c6eb3d9f62174 Mon Sep 17 00:00:00 2001
From: Brandon Wu <brandon.wu at sifive.com>
Date: Tue, 23 Jan 2024 19:34:39 -0800
Subject: [PATCH 3/5] fixup! [RISCV] RISCV vector calling convention (2/2)

---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 23bb092c7015bc..7a4531e0cabf6e 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -20968,9 +20968,10 @@ void RVVArgDispatcher::compute() {
       return;
     }
 
-    unsigned RegsNeeded = std::max(
-        ArgInfo.VT.getSizeInBits().getKnownMinValue() / RISCV::RVVBitsPerBlock,
-        1UL);
+    unsigned RegsNeeded =
+        std::max((unsigned)ArgInfo.VT.getSizeInBits().getKnownMinValue() /
+                     RISCV::RVVBitsPerBlock,
+                 (unsigned)1);
     unsigned TotalRegsNeeded = ArgInfo.NF * RegsNeeded;
     for (unsigned StartReg = 0; StartReg + TotalRegsNeeded <= NumArgVRs;
          StartReg += RegsNeeded) {

>From 3cb56d30516e5654db135691d2018d5559da3adc Mon Sep 17 00:00:00 2001
From: Brandon Wu <brandon.wu at sifive.com>
Date: Wed, 24 Jan 2024 23:46:28 -0800
Subject: [PATCH 4/5] fixup! [RISCV] RISCV vector calling convention (2/2)

---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 34 +++++++++------------
 llvm/lib/Target/RISCV/RISCVISelLowering.h   | 21 ++++++++++---
 2 files changed, 32 insertions(+), 23 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 7a4531e0cabf6e..4956417ea4b3e8 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -20877,7 +20877,7 @@ unsigned RISCVTargetLowering::getMinimumJumpTableEntries() const {
   return Subtarget.getMinimumJumpTableEntries();
 }
 
-void RVVArgDispatcher::constructHelper(Type *Ty) {
+void RVVArgDispatcher::constructArgInfos(Type *Ty) {
   const DataLayout &DL = MF->getDataLayout();
   const Function &F = MF->getFunction();
   LLVMContext &Context = F.getContext();
@@ -20910,16 +20910,14 @@ void RVVArgDispatcher::constructHelper(Type *Ty) {
         RegisterVT = TLI->getContainerForFixedLengthVector(RegisterVT);
 
       RVVArgInfo Info{1, RegisterVT, false};
-
-      while (NumRegs--)
-        RVVArgInfos.push_back(Info);
+      RVVArgInfos.insert(RVVArgInfos.end(), NumRegs, Info);
     }
   }
 }
 
-void RVVArgDispatcher::construct(std::vector<Type *> &TypeList) {
+void RVVArgDispatcher::construct(const std::vector<Type *> &TypeList) {
   for (Type *Ty : TypeList)
-    constructHelper(Ty);
+    constructArgInfos(Ty);
 
   for (auto &Info : RVVArgInfos)
     if (Info.NF == 1 && Info.VT.getVectorElementType() == MVT::i1) {
@@ -20954,28 +20952,27 @@ void RVVArgDispatcher::allocatePhysReg(unsigned NF, unsigned LMul,
     if (StartReg)
       AllocatedPhysRegs.push_back(VRArrays[(StartReg - 8) / LMul + i]);
     else
-      AllocatedPhysRegs.push_back(0);
+      AllocatedPhysRegs.push_back(MCPhysReg());
 }
 
-// This function determines if each RVV argument is passed by register.
+/// This function determines whether each RVV argument is passed by register.
+/// If the argument can be assigned to a VR, give it a specific register;
+/// otherwise, assign it 0, which is an invalid MCPhysReg.
 void RVVArgDispatcher::compute() {
-  unsigned ToBeAssigned = RVVArgInfos.size();
-  uint64_t AssignedMap = 0;
-  auto tryAllocate = [&](const RVVArgInfo &ArgInfo) {
+  uint32_t AssignedMap = 0;
+  auto allocate = [&](const RVVArgInfo &ArgInfo) {
     // Allocate first vector mask argument to V0.
     if (ArgInfo.FirstVMask) {
       AllocatedPhysRegs.push_back(RISCV::V0);
       return;
     }
 
-    unsigned RegsNeeded =
-        std::max((unsigned)ArgInfo.VT.getSizeInBits().getKnownMinValue() /
-                     RISCV::RVVBitsPerBlock,
-                 (unsigned)1);
+    unsigned RegsNeeded = divideCeil(
+        ArgInfo.VT.getSizeInBits().getKnownMinValue(), RISCV::RVVBitsPerBlock);
     unsigned TotalRegsNeeded = ArgInfo.NF * RegsNeeded;
     for (unsigned StartReg = 0; StartReg + TotalRegsNeeded <= NumArgVRs;
          StartReg += RegsNeeded) {
-      unsigned Map = ((1 << TotalRegsNeeded) - 1) << StartReg;
+      uint32_t Map = ((1 << TotalRegsNeeded) - 1) << StartReg;
       if ((AssignedMap & Map) == 0) {
         allocatePhysReg(ArgInfo.NF, RegsNeeded, StartReg + 8);
         AssignedMap |= Map;
@@ -20984,11 +20981,10 @@ void RVVArgDispatcher::compute() {
     }
 
     allocatePhysReg(ArgInfo.NF, RegsNeeded, 0);
-    return;
   };
 
-  for (unsigned i = 0; i < ToBeAssigned; ++i)
-    tryAllocate(RVVArgInfos[i]);
+  for (unsigned i = 0; i < RVVArgInfos.size(); ++i)
+    allocate(RVVArgInfos[i]);
 }
 
 MCPhysReg RVVArgDispatcher::getNextPhysReg() {
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index 98bef01ad619d0..1ef09c5b9dcda6 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -1012,6 +1012,21 @@ class RISCVTargetLowering : public TargetLowering {
   unsigned getMinimumJumpTableEntries() const override;
 };
 
+/// As per the spec, the rules for passing vector arguments are as follows:
+///
+/// 1. For the first vector mask argument, use v0 to pass it.
+/// 2. For vector data arguments or the remaining vector mask arguments,
+/// starting from the v8 register, if a vector register group between v8-v23
+/// that has not been allocated can be found and the first register number is
+/// a multiple of LMUL, then allocate this vector register group to the
+/// argument and mark these registers as allocated. Otherwise, the argument is
+/// passed by reference and is replaced in the argument list with the address.
+/// 3. For tuple vector data arguments, starting from the v8 register, if
+/// NFIELDS consecutive vector register groups between v8-v23 that have not
+/// been allocated can be found and the first register number is a multiple of
+/// LMUL, then allocate these vector register groups to the argument and mark
+/// these registers as allocated. Otherwise, the argument is passed by
+/// reference and is replaced in the argument list with the address.
 class RVVArgDispatcher {
 public:
   static constexpr unsigned NumArgVRs = 16;
@@ -1045,13 +1060,11 @@ class RVVArgDispatcher {
 
   const MachineFunction *MF = nullptr;
   const RISCVTargetLowering *TLI = nullptr;
-  TargetLowering::CallLoweringInfo *CLI = nullptr;
-  CallLowering::CallLoweringInfo *GISelCLI = nullptr;
 
   unsigned CurIdx = 0;
 
-  void construct(std::vector<Type *> &TypeList);
-  void constructHelper(Type *Ty);
+  void construct(const std::vector<Type *> &TypeList);
+  void constructArgInfos(Type *Ty);
   void compute();
   void allocatePhysReg(unsigned NF = 1, unsigned LMul = 1,
                        unsigned StartReg = 0);
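
To make rule 3 above concrete, here is a small self-contained check of the
"NFIELDS consecutive groups starting at a multiple of LMUL" constraint against
an allocation bitmap; tupleFitsAt is a hypothetical helper written for
illustration, not part of the patch.

#include <cstdint>
#include <iostream>

// Illustrates rule 3: a tuple argument with NFIELDS fields of LMUL registers
// each fits at register v(8 + Start) only if Start is a multiple of LMUL and
// all NFIELDS * LMUL registers up to v23 are still free in the bitmap.
bool tupleFitsAt(uint32_t AssignedMap, unsigned Start, unsigned NFields,
                 unsigned LMul) {
  constexpr unsigned NumArgVRs = 16; // v8-v23
  unsigned Total = NFields * LMul;
  if (Start % LMul != 0 || Start + Total > NumArgVRs)
    return false;
  uint32_t Map = ((1u << Total) - 1) << Start;
  return (AssignedMap & Map) == 0;
}

int main() {
  // v8 already taken (bit 0), as in case3_1 of the added test: the
  // {nxv2i64, nxv2i64} tuple cannot start at v8, but fits at v10-v13.
  uint32_t AssignedMap = 0b1;
  std::cout << tupleFitsAt(AssignedMap, 0, 2, 2) << "\n"; // 0: v8 occupied
  std::cout << tupleFitsAt(AssignedMap, 2, 2, 2) << "\n"; // 1: v10-v13 free
}
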

>From db17c2b6ffeb1213496eb0a795a6b720eb5f16c8 Mon Sep 17 00:00:00 2001
From: Brandon Wu <brandon.wu at sifive.com>
Date: Wed, 7 Feb 2024 00:22:08 -0800
Subject: [PATCH 5/5] fixup! [RISCV] RISCV vector calling convention (2/2)

---
 .../Target/RISCV/GISel/RISCVCallLowering.cpp  |  4 +--
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp   | 33 ++++++++++---------
 llvm/lib/Target/RISCV/RISCVISelLowering.h     | 16 ++++-----
 3 files changed, 27 insertions(+), 26 deletions(-)

diff --git a/llvm/lib/Target/RISCV/GISel/RISCVCallLowering.cpp b/llvm/lib/Target/RISCV/GISel/RISCVCallLowering.cpp
index 63ee73a05f74af..8af4bc658409d4 100644
--- a/llvm/lib/Target/RISCV/GISel/RISCVCallLowering.cpp
+++ b/llvm/lib/Target/RISCV/GISel/RISCVCallLowering.cpp
@@ -521,7 +521,7 @@ bool RISCVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
   CallingConv::ID CC = F.getCallingConv();
 
   SmallVector<ArgInfo, 32> SplitArgInfos;
-  std::vector<Type *> TypeList;
+  SmallVector<Type *, 4> TypeList;
   unsigned Index = 0;
   for (auto &Arg : F.args()) {
     // Construct the ArgInfo object from destination register and argument type.
@@ -579,7 +579,7 @@ bool RISCVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
 
   SmallVector<ArgInfo, 32> SplitArgInfos;
   SmallVector<ISD::OutputArg, 8> Outs;
-  std::vector<Type *> TypeList;
+  SmallVector<Type *, 4> TypeList;
   for (auto &AInfo : Info.OrigArgs) {
     // Handle any required unmerging of split value types from a given VReg into
     // physical registers. ArgInfo objects are constructed correspondingly and
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 4956417ea4b3e8..d300d33a5a7bea 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -18032,7 +18032,7 @@ bool RISCV::CC_RISCV(const DataLayout &DL, RISCVABI::ABI ABI, unsigned ValNo,
   }
 
   // Allocate to a register if possible, or else a stack slot.
-  Register Reg = MCRegister();
+  Register Reg;
   unsigned StoreSizeBytes = XLen / 8;
   Align StackAlign = Align(XLen / 8);
 
@@ -18129,7 +18129,7 @@ void RISCVTargetLowering::analyzeInputArgs(
   unsigned NumArgs = Ins.size();
   FunctionType *FType = MF.getFunction().getFunctionType();
 
-  std::vector<Type *> TypeList;
+  SmallVector<Type *, 4> TypeList;
   if (IsRet)
     TypeList.push_back(MF.getFunction().getReturnType());
   else
@@ -18164,7 +18164,7 @@ void RISCVTargetLowering::analyzeOutputArgs(
     CallLoweringInfo *CLI, RISCVCCAssignFn Fn) const {
   unsigned NumArgs = Outs.size();
 
-  std::vector<Type *> TypeList;
+  SmallVector<Type *, 4> TypeList;
   if (IsRet)
     TypeList.push_back(MF.getFunction().getReturnType());
   else if (CLI)
@@ -19073,8 +19073,7 @@ bool RISCVTargetLowering::CanLowerReturn(
   SmallVector<CCValAssign, 16> RVLocs;
   CCState CCInfo(CallConv, IsVarArg, MF, RVLocs, Context);
 
-  std::vector<Type *> TypeList = {MF.getFunction().getReturnType()};
-  RVVArgDispatcher Dispatcher{&MF, this, TypeList};
+  RVVArgDispatcher Dispatcher{&MF, this, MF.getFunction().getReturnType()};
 
   for (unsigned i = 0, e = Outs.size(); i != e; ++i) {
     MVT VT = Outs[i].VT;
@@ -20877,7 +20876,7 @@ unsigned RISCVTargetLowering::getMinimumJumpTableEntries() const {
   return Subtarget.getMinimumJumpTableEntries();
 }
 
-void RVVArgDispatcher::constructArgInfos(Type *Ty) {
+void RVVArgDispatcher::constructArgInfos(Type *Ty, bool &FirstMaskAssigned) {
   const DataLayout &DL = MF->getDataLayout();
   const Function &F = MF->getFunction();
   LLVMContext &Context = F.getContext();
@@ -20909,21 +20908,23 @@ void RVVArgDispatcher::constructArgInfos(Type *Ty) {
       if (RegisterVT.isFixedLengthVector())
         RegisterVT = TLI->getContainerForFixedLengthVector(RegisterVT);
 
-      RVVArgInfo Info{1, RegisterVT, false};
-      RVVArgInfos.insert(RVVArgInfos.end(), NumRegs, Info);
+      if (!FirstMaskAssigned && RegisterVT.getVectorElementType() == MVT::i1) {
+        RVVArgInfos.push_back({1, RegisterVT, true});
+        FirstMaskAssigned = true;
+      } else {
+        RVVArgInfos.push_back({1, RegisterVT, false});
+      }
+
+      RVVArgInfos.insert(RVVArgInfos.end(), --NumRegs, {1, RegisterVT, false});
     }
   }
 }
 
-void RVVArgDispatcher::construct(const std::vector<Type *> &TypeList) {
+void RVVArgDispatcher::constructArgInfos(
+    const SmallVectorImpl<Type *> &TypeList) {
+  bool FirstVMaskAssigned = false;
   for (Type *Ty : TypeList)
-    constructArgInfos(Ty);
-
-  for (auto &Info : RVVArgInfos)
-    if (Info.NF == 1 && Info.VT.getVectorElementType() == MVT::i1) {
-      Info.FirstVMask = true;
-      break;
-    }
+    constructArgInfos(Ty, FirstVMaskAssigned);
 }
 
 void RVVArgDispatcher::allocatePhysReg(unsigned NF, unsigned LMul,
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index 1ef09c5b9dcda6..ea02ff0aec93fa 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -1038,33 +1038,33 @@ class RVVArgDispatcher {
   };
 
   RVVArgDispatcher(const MachineFunction *MF, const RISCVTargetLowering *TLI,
-                   std::vector<Type *> &TypeList)
+                   SmallVectorImpl<Type *> &TypeList)
       : MF(MF), TLI(TLI) {
-    construct(TypeList);
+    constructArgInfos(TypeList);
     compute();
   }
 
   RVVArgDispatcher(const MachineFunction *MF, const RISCVTargetLowering *TLI,
                    Type *Ty)
       : MF(MF), TLI(TLI) {
-    std::vector<Type *> TypeList = {Ty};
-    construct(TypeList);
+    SmallVector<Type *, 4> TypeList = {Ty};
+    constructArgInfos(TypeList);
     compute();
   }
 
   MCPhysReg getNextPhysReg();
 
 private:
-  std::vector<RVVArgInfo> RVVArgInfos;
-  std::vector<MCPhysReg> AllocatedPhysRegs;
+  SmallVector<RVVArgInfo, 4> RVVArgInfos;
+  SmallVector<MCPhysReg, 4> AllocatedPhysRegs;
 
   const MachineFunction *MF = nullptr;
   const RISCVTargetLowering *TLI = nullptr;
 
   unsigned CurIdx = 0;
 
-  void construct(const std::vector<Type *> &TypeList);
-  void constructArgInfos(Type *Ty);
+  void constructArgInfos(const SmallVectorImpl<Type *> &TypeList);
+  void constructArgInfos(Type *Ty, bool &FirstMaskAssigned);
   void compute();
   void allocatePhysReg(unsigned NF = 1, unsigned LMul = 1,
                        unsigned StartReg = 0);


