[llvm] [RISCV] Remove pre-assignment of mask vectors during call lowering. (PR #107192)

Craig Topper via llvm-commits llvm-commits at lists.llvm.org
Wed Sep 4 00:21:24 PDT 2024


https://github.com/topperc created https://github.com/llvm/llvm-project/pull/107192

The first mask vector operand is supposed to be assigned to V0, and no other vector type is ever assigned to V0. We therefore don't need to pre-assign the mask argument; we can simply try V0 first for any mask vector during normal argument processing.

>From 265b76a3b95e4caa977f307fbc04d290b33638b6 Mon Sep 17 00:00:00 2001
From: Craig Topper <craig.topper at sifive.com>
Date: Wed, 4 Sep 2024 00:12:30 -0700
Subject: [PATCH] [RISCV] Remove pre-assignment of mask vectors during call
 lowering.

The first mask vector operand is supposed to be assigned to V0. No
other vector types will be assigned to V0, so we don't need to pre-assign;
we can just try V0 first for any mask vectors in the normal processing.
---
 .../Target/RISCV/GISel/RISCVCallLowering.cpp  | 26 ++--------
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp   | 47 ++++---------------
 llvm/lib/Target/RISCV/RISCVISelLowering.h     |  9 ++--
 3 files changed, 16 insertions(+), 66 deletions(-)

diff --git a/llvm/lib/Target/RISCV/GISel/RISCVCallLowering.cpp b/llvm/lib/Target/RISCV/GISel/RISCVCallLowering.cpp
index 6e33032384eded..31a9df53a2aa1b 100644
--- a/llvm/lib/Target/RISCV/GISel/RISCVCallLowering.cpp
+++ b/llvm/lib/Target/RISCV/GISel/RISCVCallLowering.cpp
@@ -35,9 +35,6 @@ struct RISCVOutgoingValueAssigner : public CallLowering::OutgoingValueAssigner {
   // Whether this is assigning args for a return.
   bool IsRet;
 
-  // true if assignArg has been called for a mask argument, false otherwise.
-  bool AssignedFirstMaskArg = false;
-
 public:
   RISCVOutgoingValueAssigner(
       RISCVTargetLowering::RISCVCCAssignFn *RISCVAssignFn_, bool IsRet)
@@ -52,16 +49,9 @@ struct RISCVOutgoingValueAssigner : public CallLowering::OutgoingValueAssigner {
     const DataLayout &DL = MF.getDataLayout();
     const RISCVSubtarget &Subtarget = MF.getSubtarget<RISCVSubtarget>();
 
-    std::optional<unsigned> FirstMaskArgument;
-    if (Subtarget.hasVInstructions() && !AssignedFirstMaskArg &&
-        ValVT.isVector() && ValVT.getVectorElementType() == MVT::i1) {
-      FirstMaskArgument = ValNo;
-      AssignedFirstMaskArg = true;
-    }
-
     if (RISCVAssignFn(DL, Subtarget.getTargetABI(), ValNo, ValVT, LocVT,
                       LocInfo, Flags, State, Info.IsFixed, IsRet, Info.Ty,
-                      *Subtarget.getTargetLowering(), FirstMaskArgument))
+                      *Subtarget.getTargetLowering()))
       return true;
 
     StackSize = State.getStackSize();
@@ -197,9 +187,6 @@ struct RISCVIncomingValueAssigner : public CallLowering::IncomingValueAssigner {
   // Whether this is assigning args from a return.
   bool IsRet;
 
-  // true if assignArg has been called for a mask argument, false otherwise.
-  bool AssignedFirstMaskArg = false;
-
 public:
   RISCVIncomingValueAssigner(
       RISCVTargetLowering::RISCVCCAssignFn *RISCVAssignFn_, bool IsRet)
@@ -217,16 +204,9 @@ struct RISCVIncomingValueAssigner : public CallLowering::IncomingValueAssigner {
     if (LocVT.isScalableVector())
       MF.getInfo<RISCVMachineFunctionInfo>()->setIsVectorCall();
 
-    std::optional<unsigned> FirstMaskArgument;
-    if (Subtarget.hasVInstructions() && !AssignedFirstMaskArg &&
-        ValVT.isVector() && ValVT.getVectorElementType() == MVT::i1) {
-      FirstMaskArgument = ValNo;
-      AssignedFirstMaskArg = true;
-    }
-
     if (RISCVAssignFn(DL, Subtarget.getTargetABI(), ValNo, ValVT, LocVT,
                       LocInfo, Flags, State, /*IsFixed=*/true, IsRet, Info.Ty,
-                      *Subtarget.getTargetLowering(), FirstMaskArgument))
+                      *Subtarget.getTargetLowering()))
       return true;
 
     StackSize = State.getStackSize();
@@ -483,7 +463,7 @@ bool RISCVCallLowering::canLowerReturn(MachineFunction &MF,
     MVT VT = MVT::getVT(Outs[I].Ty);
     if (RISCV::CC_RISCV(MF.getDataLayout(), ABI, I, VT, VT, CCValAssign::Full,
                         Outs[I].Flags[0], CCInfo, /*IsFixed=*/true,
-                        /*isRet=*/true, nullptr, TLI, FirstMaskArgument))
+                        /*isRet=*/true, nullptr, TLI))
       return false;
   }
   return true;
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 8f95a86ade3039..00e338acec749f 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -19123,7 +19123,6 @@ static bool CC_RISCVAssign2XLen(unsigned XLen, CCState &State, CCValAssign VA1,
 }
 
 static MCRegister allocateRVVReg(MVT ValVT, unsigned ValNo,
-                                 std::optional<unsigned> FirstMaskArgument,
                                  CCState &State,
                                  const RISCVTargetLowering &TLI) {
   const TargetRegisterClass *RC = TLI.getRegClassFor(ValVT);
@@ -19131,8 +19130,9 @@ static MCRegister allocateRVVReg(MVT ValVT, unsigned ValNo,
     // Assign the first mask argument to V0.
     // This is an interim calling convention and it may be changed in the
     // future.
-    if (FirstMaskArgument && ValNo == *FirstMaskArgument)
-      return State.AllocateReg(RISCV::V0);
+    if (ValVT.getVectorElementType() == MVT::i1)
+      if (MCRegister Reg = State.AllocateReg(RISCV::V0))
+        return Reg;
     return State.AllocateReg(ArgVRs);
   }
   if (RC == &RISCV::VRM2RegClass)
@@ -19170,8 +19170,7 @@ static MCRegister allocateRVVReg(MVT ValVT, unsigned ValNo,
 bool RISCV::CC_RISCV(const DataLayout &DL, RISCVABI::ABI ABI, unsigned ValNo,
                      MVT ValVT, MVT LocVT, CCValAssign::LocInfo LocInfo,
                      ISD::ArgFlagsTy ArgFlags, CCState &State, bool IsFixed,
-                     bool IsRet, Type *OrigTy, const RISCVTargetLowering &TLI,
-                     std::optional<unsigned> FirstMaskArgument) {
+                     bool IsRet, Type *OrigTy, const RISCVTargetLowering &TLI) {
   unsigned XLen = DL.getLargestLegalIntTypeSizeInBits();
   assert(XLen == 32 || XLen == 64);
   MVT XLenVT = XLen == 32 ? MVT::i32 : MVT::i64;
@@ -19351,7 +19350,7 @@ bool RISCV::CC_RISCV(const DataLayout &DL, RISCVABI::ABI ABI, unsigned ValNo,
   else if (ValVT == MVT::f64 && !UseGPRForF64)
     Reg = State.AllocateReg(ArgFPR64s);
   else if (ValVT.isVector() || ValVT.isRISCVVectorTuple()) {
-    Reg = allocateRVVReg(ValVT, ValNo, FirstMaskArgument, State, TLI);
+    Reg = allocateRVVReg(ValVT, ValNo, State, TLI);
     if (!Reg) {
       // For return values, the vector must be passed fully via registers or
       // via the stack.
@@ -19421,16 +19420,6 @@ bool RISCV::CC_RISCV(const DataLayout &DL, RISCVABI::ABI ABI, unsigned ValNo,
   return false;
 }
 
-template <typename ArgTy>
-static std::optional<unsigned> preAssignMask(const ArgTy &Args) {
-  for (const auto &ArgIdx : enumerate(Args)) {
-    MVT ArgVT = ArgIdx.value().VT;
-    if (ArgVT.isVector() && ArgVT.getVectorElementType() == MVT::i1)
-      return ArgIdx.index();
-  }
-  return std::nullopt;
-}
-
 void RISCVTargetLowering::analyzeInputArgs(
     MachineFunction &MF, CCState &CCInfo,
     const SmallVectorImpl<ISD::InputArg> &Ins, bool IsRet,
@@ -19438,10 +19427,6 @@ void RISCVTargetLowering::analyzeInputArgs(
   unsigned NumArgs = Ins.size();
   FunctionType *FType = MF.getFunction().getFunctionType();
 
-  std::optional<unsigned> FirstMaskArgument;
-  if (Subtarget.hasVInstructions())
-    FirstMaskArgument = preAssignMask(Ins);
-
   for (unsigned i = 0; i != NumArgs; ++i) {
     MVT ArgVT = Ins[i].VT;
     ISD::ArgFlagsTy ArgFlags = Ins[i].Flags;
@@ -19454,8 +19439,7 @@ void RISCVTargetLowering::analyzeInputArgs(
 
     RISCVABI::ABI ABI = MF.getSubtarget<RISCVSubtarget>().getTargetABI();
     if (Fn(MF.getDataLayout(), ABI, i, ArgVT, ArgVT, CCValAssign::Full,
-           ArgFlags, CCInfo, /*IsFixed=*/true, IsRet, ArgTy, *this,
-           FirstMaskArgument)) {
+           ArgFlags, CCInfo, /*IsFixed=*/true, IsRet, ArgTy, *this)) {
       LLVM_DEBUG(dbgs() << "InputArg #" << i << " has unhandled type "
                         << ArgVT << '\n');
       llvm_unreachable(nullptr);
@@ -19469,10 +19453,6 @@ void RISCVTargetLowering::analyzeOutputArgs(
     CallLoweringInfo *CLI, RISCVCCAssignFn Fn) const {
   unsigned NumArgs = Outs.size();
 
-  std::optional<unsigned> FirstMaskArgument;
-  if (Subtarget.hasVInstructions())
-    FirstMaskArgument = preAssignMask(Outs);
-
   for (unsigned i = 0; i != NumArgs; i++) {
     MVT ArgVT = Outs[i].VT;
     ISD::ArgFlagsTy ArgFlags = Outs[i].Flags;
@@ -19480,8 +19460,7 @@ void RISCVTargetLowering::analyzeOutputArgs(
 
     RISCVABI::ABI ABI = MF.getSubtarget<RISCVSubtarget>().getTargetABI();
     if (Fn(MF.getDataLayout(), ABI, i, ArgVT, ArgVT, CCValAssign::Full,
-           ArgFlags, CCInfo, Outs[i].IsFixed, IsRet, OrigTy, *this,
-           FirstMaskArgument)) {
+           ArgFlags, CCInfo, Outs[i].IsFixed, IsRet, OrigTy, *this)) {
       LLVM_DEBUG(dbgs() << "OutputArg #" << i << " has unhandled type "
                         << ArgVT << "\n");
       llvm_unreachable(nullptr);
@@ -19659,8 +19638,7 @@ bool RISCV::CC_RISCV_FastCC(const DataLayout &DL, RISCVABI::ABI ABI,
                             CCValAssign::LocInfo LocInfo,
                             ISD::ArgFlagsTy ArgFlags, CCState &State,
                             bool IsFixed, bool IsRet, Type *OrigTy,
-                            const RISCVTargetLowering &TLI,
-                            std::optional<unsigned> FirstMaskArgument) {
+                            const RISCVTargetLowering &TLI) {
   if (LocVT == MVT::i32 || LocVT == MVT::i64) {
     if (MCRegister Reg = State.AllocateReg(getFastCCArgGPRs(ABI))) {
       State.addLoc(CCValAssign::getReg(ValNo, ValVT, Reg, LocVT, LocInfo));
@@ -19744,8 +19722,7 @@ bool RISCV::CC_RISCV_FastCC(const DataLayout &DL, RISCVABI::ABI ABI,
   }
 
   if (LocVT.isVector()) {
-    if (MCRegister Reg =
-            allocateRVVReg(ValVT, ValNo, FirstMaskArgument, State, TLI)) {
+    if (MCRegister Reg = allocateRVVReg(ValVT, ValNo, State, TLI)) {
       // Fixed-length vectors are located in the corresponding scalable-vector
       // container types.
       if (ValVT.isFixedLengthVector())
@@ -20377,17 +20354,13 @@ bool RISCVTargetLowering::CanLowerReturn(
   SmallVector<CCValAssign, 16> RVLocs;
   CCState CCInfo(CallConv, IsVarArg, MF, RVLocs, Context);
 
-  std::optional<unsigned> FirstMaskArgument;
-  if (Subtarget.hasVInstructions())
-    FirstMaskArgument = preAssignMask(Outs);
-
   for (unsigned i = 0, e = Outs.size(); i != e; ++i) {
     MVT VT = Outs[i].VT;
     ISD::ArgFlagsTy ArgFlags = Outs[i].Flags;
     RISCVABI::ABI ABI = MF.getSubtarget<RISCVSubtarget>().getTargetABI();
     if (RISCV::CC_RISCV(MF.getDataLayout(), ABI, i, VT, VT, CCValAssign::Full,
                         ArgFlags, CCInfo, /*IsFixed=*/true, /*IsRet=*/true,
-                        nullptr, *this, FirstMaskArgument))
+                        nullptr, *this))
       return false;
   }
   return true;
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index f1d1cca043b358..3beee4686956ec 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -904,8 +904,7 @@ class RISCVTargetLowering : public TargetLowering {
                                CCValAssign::LocInfo LocInfo,
                                ISD::ArgFlagsTy ArgFlags, CCState &State,
                                bool IsFixed, bool IsRet, Type *OrigTy,
-                               const RISCVTargetLowering &TLI,
-                               std::optional<unsigned> FirstMaskArgument);
+                               const RISCVTargetLowering &TLI);
 
 private:
   void analyzeInputArgs(MachineFunction &MF, CCState &CCInfo,
@@ -1054,14 +1053,12 @@ namespace RISCV {
 bool CC_RISCV(const DataLayout &DL, RISCVABI::ABI ABI, unsigned ValNo,
               MVT ValVT, MVT LocVT, CCValAssign::LocInfo LocInfo,
               ISD::ArgFlagsTy ArgFlags, CCState &State, bool IsFixed,
-              bool IsRet, Type *OrigTy, const RISCVTargetLowering &TLI,
-              std::optional<unsigned> FirstMaskArgument);
+              bool IsRet, Type *OrigTy, const RISCVTargetLowering &TLI);
 
 bool CC_RISCV_FastCC(const DataLayout &DL, RISCVABI::ABI ABI, unsigned ValNo,
                      MVT ValVT, MVT LocVT, CCValAssign::LocInfo LocInfo,
                      ISD::ArgFlagsTy ArgFlags, CCState &State, bool IsFixed,
-                     bool IsRet, Type *OrigTy, const RISCVTargetLowering &TLI,
-                     std::optional<unsigned> FirstMaskArgument);
+                     bool IsRet, Type *OrigTy, const RISCVTargetLowering &TLI);
 
 bool CC_RISCV_GHC(unsigned ValNo, MVT ValVT, MVT LocVT,
                   CCValAssign::LocInfo LocInfo, ISD::ArgFlagsTy ArgFlags,



More information about the llvm-commits mailing list