[PATCH] D157067: [RISCV] Set the vector calling convention if any input or return type is a vector

Brandon Wu via Phabricator via llvm-commits llvm-commits at lists.llvm.org
Tue Aug 15 04:09:33 PDT 2023


4vtomat updated this revision to Diff 550246.
4vtomat added a comment.

Add an option to deduce the vector calling convention (default off).


Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D157067/new/

https://reviews.llvm.org/D157067

Files:
  llvm/lib/Target/RISCV/RISCVISelLowering.cpp
  llvm/lib/Target/RISCV/RISCVRegisterInfo.cpp


Index: llvm/lib/Target/RISCV/RISCVRegisterInfo.cpp
===================================================================
--- llvm/lib/Target/RISCV/RISCVRegisterInfo.cpp
+++ llvm/lib/Target/RISCV/RISCVRegisterInfo.cpp
@@ -29,6 +29,8 @@
 
 using namespace llvm;
 
+extern cl::opt<bool> DeduceVectorCC;
+
 static cl::opt<bool>
     DisableRegAllocHints("riscv-disable-regalloc-hints", cl::Hidden,
                          cl::init(false),
@@ -67,7 +69,9 @@
   }
 
   bool HasVectorCSR =
-      MF->getFunction().getCallingConv() == CallingConv::RISCV_VectorCall;
+      MF->getFunction().getCallingConv() == CallingConv::RISCV_VectorCall ||
+      (MF->getInfo<RISCVMachineFunctionInfo>()->isVectorCall() &&
+       DeduceVectorCC);
 
   switch (Subtarget.getTargetABI()) {
   default:
Index: llvm/lib/Target/RISCV/RISCVISelLowering.cpp
===================================================================
--- llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -73,6 +73,13 @@
                        "use for creating a floating-point immediate value"),
               cl::init(2));
 
+cl::opt<bool>
+    DeduceVectorCC(DEBUG_TYPE "-deduce-vector-cc", cl::Hidden,
+                   cl::desc("Automatically turn on vector calling convention "
+                            "for every function that has RVV argument/return "
+                            "type."),
+                   cl::init(false));
+
 RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
                                          const RISCVSubtarget &STI)
     : TargetLowering(TM), Subtarget(STI) {
@@ -15713,7 +15720,7 @@
   SDValue Chain = CLI.Chain;
   SDValue Callee = CLI.Callee;
   bool &IsTailCall = CLI.IsTailCall;
-  CallingConv::ID CallConv = CLI.CallConv;
+  CallingConv::ID &CallConv = CLI.CallConv;
   bool IsVarArg = CLI.IsVarArg;
   EVT PtrVT = getPointerTy(DAG.getDataLayout());
   MVT XLenVT = Subtarget.getXLenVT();
@@ -15731,6 +15738,28 @@
                       CallConv == CallingConv::Fast ? RISCV::CC_RISCV_FastCC
                                                     : RISCV::CC_RISCV);
 
+  // Assign locations to each value returned by this call.
+  SmallVector<CCValAssign, 16> RVLocs;
+  CCState RetCCInfo(CallConv, IsVarArg, MF, RVLocs, *DAG.getContext());
+  analyzeInputArgs(MF, RetCCInfo, Ins, /*IsRet=*/true, RISCV::CC_RISCV);
+
+  // Check callee args/returns for RVV registers and set calling convention
+  // accordingly.
+  if (DeduceVectorCC && (CallConv == CallingConv::C || CallConv == CallingConv::Fast)) {
+    auto HasRVVRegLoc = [](CCValAssign &Loc) {
+      if (!Loc.isRegLoc())
+        return false;
+
+      const auto RegClasses = {&RISCV::VRRegClass, &RISCV::VRM2RegClass,
+                               &RISCV::VRM4RegClass, &RISCV::VRM8RegClass};
+      return any_of(RegClasses, [&](const auto *RC)
+                                    { return RC->contains(Loc.getLocReg()); });
+    };
+    if (any_of(RVLocs, HasRVVRegLoc) || any_of(ArgLocs, HasRVVRegLoc)) {
+      CallConv = CallingConv::RISCV_VectorCall;
+    }
+  }
+
   // Check if it's really possible to do a tail call.
   if (IsTailCall)
     IsTailCall = isEligibleForTailCallOptimization(ArgCCInfo, CLI, MF, ArgLocs);
@@ -15977,11 +16006,6 @@
   Chain = DAG.getCALLSEQ_END(Chain, NumBytes, 0, Glue, DL);
   Glue = Chain.getValue(1);
 
-  // Assign locations to each value returned by this call.
-  SmallVector<CCValAssign, 16> RVLocs;
-  CCState RetCCInfo(CallConv, IsVarArg, MF, RVLocs, *DAG.getContext());
-  analyzeInputArgs(MF, RetCCInfo, Ins, /*IsRet=*/true, RISCV::CC_RISCV);
-
   // Copy all of the result registers out of their specified physreg.
   for (auto &VA : RVLocs) {
     // Copy the value out


-------------- next part --------------
A non-text attachment was scrubbed...
Name: D157067.550246.patch
Type: text/x-patch
Size: 3751 bytes
Desc: not available
URL: <http://lists.llvm.org/pipermail/llvm-commits/attachments/20230815/c9b3a27c/attachment.bin>


More information about the llvm-commits mailing list