[llvm] [RISCV][MI] Support partial spill/reload for vector registers (PR #105661)

via llvm-commits llvm-commits at lists.llvm.org
Thu Aug 22 07:02:16 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-risc-v

Author: Brandon Wu (4vtomat)

<details>
<summary>Changes</summary>

RFC: https://discourse.llvm.org/t/rfc-riscv-vector-register-spill-optimization-pass/80850

Currently, RISC-V vector register spills and reloads always operate on full
vector registers, regardless of how much of the register the defining
instruction actually writes. In some cases spilling the full vector register
is unnecessary, for example:
```
vsetvli a1, a0, e8, mf2, ta, ma
vadd.vv v8, v8, v9
vs1r v8, (a2)                   <- spill
    .
    .
    .
vl1r v8, (a2)                   <- reload
vmul.vv v8, v8, v9
```

Both the spill and the reload can be replaced with `vse8.v` and `vle8.v`
respectively, as shown below:
```
vsetvli a1, a0, e8, mf2, ta, ma
vadd.vv v8, v8, v9
vse8.v v8, (a2)                 <- spill
    .
    .
    .
vsetvli a1, x0, e8, mf2, ta, ma
vle8.v v8, (a2)                 <- reload
vmul.vv v8, v8, v9
```

Note that this patch does not rewrite spills or reloads in a basic block that
contains any inline assembly, for example:
```
%0 = vadd.vv v8, v9 (e8, mf2)
vs1r %0, %stack.0
...
inline_asm("vsetvli 888, e8, m1")
%1 = vl1r %stack.0
inline_asm("vadd.vv %a, %b, %c", %a=v8, %b=%1, %c=%1)
```

If we rewrote the case above, %1 would become a vle8 with mf2, and
RISCVInsertVSETVLI would emit a vsetvli with mf2 for %1, which is
incompatible with the original semantics (m1).


---

Patch is 51.38 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/105661.diff


9 Files Affected:

- (modified) llvm/lib/Target/RISCV/CMakeLists.txt (+1) 
- (modified) llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.h (+7) 
- (modified) llvm/lib/Target/RISCV/RISCV.h (+3) 
- (modified) llvm/lib/Target/RISCV/RISCVInstrFormats.td (+4) 
- (modified) llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td (+18) 
- (modified) llvm/lib/Target/RISCV/RISCVInstrInfoZvk.td (+2-1) 
- (added) llvm/lib/Target/RISCV/RISCVSpillRewrite.cpp (+407) 
- (modified) llvm/lib/Target/RISCV/RISCVTargetMachine.cpp (+8) 
- (added) llvm/test/CodeGen/RISCV/rvv/vector-spill-rewrite.mir (+551) 


``````````diff
diff --git a/llvm/lib/Target/RISCV/CMakeLists.txt b/llvm/lib/Target/RISCV/CMakeLists.txt
index cbb4c2cedfb97e..858a82cd6f9086 100644
--- a/llvm/lib/Target/RISCV/CMakeLists.txt
+++ b/llvm/lib/Target/RISCV/CMakeLists.txt
@@ -38,6 +38,7 @@ add_llvm_target(RISCVCodeGen
   RISCVGatherScatterLowering.cpp
   RISCVIndirectBranchTracking.cpp
   RISCVInsertVSETVLI.cpp
+  RISCVSpillRewrite.cpp
   RISCVInsertReadWriteCSR.cpp
   RISCVInsertWriteVXRM.cpp
   RISCVInstrInfo.cpp
diff --git a/llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.h b/llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.h
index c65bd5b1d33631..12b2b8c79ef47f 100644
--- a/llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.h
+++ b/llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.h
@@ -126,6 +126,10 @@ enum {
 
   ActiveElementsAffectResultShift = TargetOverlapConstraintTypeShift + 2,
   ActiveElementsAffectResultMask = 1ULL << ActiveElementsAffectResultShift,
+
+  // Indicates whether the instruction produces widened result.
+  IsWidenShift = ActiveElementsAffectResultShift + 1,
+  IsWidenMask = 1 << IsWidenShift,
 };
 
 // Helper functions to read TSFlags.
@@ -149,6 +153,9 @@ static inline bool isTiedPseudo(uint64_t TSFlags) {
 static inline bool hasSEWOp(uint64_t TSFlags) {
   return TSFlags & HasSEWOpMask;
 }
+
+/// \returns true if the IsWiden bit is set in \p TSFlags, i.e. the
+/// instruction produces a widened result.
+static inline bool isWiden(uint64_t TSFlags) {
+  return TSFlags & IsWidenMask;
+}
+
 /// \returns true if there is a VL operand for the instruction.
 static inline bool hasVLOp(uint64_t TSFlags) {
   return TSFlags & HasVLOpMask;
diff --git a/llvm/lib/Target/RISCV/RISCV.h b/llvm/lib/Target/RISCV/RISCV.h
index 5a94ada8f8dd46..4779043dc207a0 100644
--- a/llvm/lib/Target/RISCV/RISCV.h
+++ b/llvm/lib/Target/RISCV/RISCV.h
@@ -68,6 +68,9 @@ FunctionPass *createRISCVInsertVSETVLIPass();
 void initializeRISCVInsertVSETVLIPass(PassRegistry &);
 extern char &RISCVInsertVSETVLIID;
 
+// Rewrites whole-register vector spills/reloads (VS1R/VL1R) into fractional
+// unit-stride stores/loads (VSE/VLE) when the spilled value only occupies part
+// of a vector register.
+FunctionPass *createRISCVSpillRewritePass();
+void initializeRISCVSpillRewritePass(PassRegistry &);
+
 FunctionPass *createRISCVPostRAExpandPseudoPass();
 void initializeRISCVPostRAExpandPseudoPass(PassRegistry &);
 FunctionPass *createRISCVInsertReadWriteCSRPass();
diff --git a/llvm/lib/Target/RISCV/RISCVInstrFormats.td b/llvm/lib/Target/RISCV/RISCVInstrFormats.td
index 95f157064d73e2..d33badc331ac9a 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrFormats.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrFormats.td
@@ -226,6 +226,10 @@ class RVInstCommon<dag outs, dag ins, string opcodestr, string argstr,
 
   bit ActiveElementsAffectResult = 0;
   let TSFlags{23} = ActiveElementsAffectResult;
+
+  // Indicates whether the instruction produces widened result.
+  bit IsWiden = 0;
+  let TSFlags{24} = IsWiden;
 }
 
 class RVInst<dag outs, dag ins, string opcodestr, string argstr,
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td
index 1b4303fbbcf809..9f022fea3b0cd6 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td
@@ -6277,6 +6277,7 @@ foreach vti = AllIntegerVectors in {
 //===----------------------------------------------------------------------===//
 // 11.2. Vector Widening Integer Add/Subtract
 //===----------------------------------------------------------------------===//
+let IsWiden = 1 in {
 defm PseudoVWADDU : VPseudoVWALU_VV_VX<Commutable=1>;
 defm PseudoVWSUBU : VPseudoVWALU_VV_VX;
 defm PseudoVWADD  : VPseudoVWALU_VV_VX<Commutable=1>;
@@ -6285,6 +6286,7 @@ defm PseudoVWADDU : VPseudoVWALU_WV_WX;
 defm PseudoVWSUBU : VPseudoVWALU_WV_WX;
 defm PseudoVWADD  : VPseudoVWALU_WV_WX;
 defm PseudoVWSUB  : VPseudoVWALU_WV_WX;
+}
 
 //===----------------------------------------------------------------------===//
 // 11.3. Vector Integer Extension
@@ -6366,9 +6368,11 @@ defm PseudoVREM  : VPseudoVDIV_VV_VX;
 //===----------------------------------------------------------------------===//
 // 11.12. Vector Widening Integer Multiply Instructions
 //===----------------------------------------------------------------------===//
+let IsWiden = 1 in {
 defm PseudoVWMUL   : VPseudoVWMUL_VV_VX<Commutable=1>;
 defm PseudoVWMULU  : VPseudoVWMUL_VV_VX<Commutable=1>;
 defm PseudoVWMULSU : VPseudoVWMUL_VV_VX;
+}
 
 //===----------------------------------------------------------------------===//
 // 11.13. Vector Single-Width Integer Multiply-Add Instructions
@@ -6381,10 +6385,12 @@ defm PseudoVNMSUB : VPseudoVMAC_VV_VX_AAXA;
 //===----------------------------------------------------------------------===//
 // 11.14. Vector Widening Integer Multiply-Add Instructions
 //===----------------------------------------------------------------------===//
+let IsWiden = 1 in {
 defm PseudoVWMACCU  : VPseudoVWMAC_VV_VX<Commutable=1>;
 defm PseudoVWMACC   : VPseudoVWMAC_VV_VX<Commutable=1>;
 defm PseudoVWMACCSU : VPseudoVWMAC_VV_VX;
 defm PseudoVWMACCUS : VPseudoVWMAC_VX;
+}
 
 //===----------------------------------------------------------------------===//
 // 11.15. Vector Integer Merge Instructions
@@ -6458,12 +6464,14 @@ defm PseudoVFRSUB : VPseudoVALU_VF_RM;
 //===----------------------------------------------------------------------===//
 // 13.3. Vector Widening Floating-Point Add/Subtract Instructions
 //===----------------------------------------------------------------------===//
+let IsWiden = 1 in {
 let mayRaiseFPException = true, hasSideEffects = 0, hasPostISelHook = 1 in {
 defm PseudoVFWADD : VPseudoVFWALU_VV_VF_RM;
 defm PseudoVFWSUB : VPseudoVFWALU_VV_VF_RM;
 defm PseudoVFWADD : VPseudoVFWALU_WV_WF_RM;
 defm PseudoVFWSUB : VPseudoVFWALU_WV_WF_RM;
 }
+}
 
 //===----------------------------------------------------------------------===//
 // 13.4. Vector Single-Width Floating-Point Multiply/Divide Instructions
@@ -6477,9 +6485,11 @@ defm PseudoVFRDIV : VPseudoVFRDIV_VF_RM;
 //===----------------------------------------------------------------------===//
 // 13.5. Vector Widening Floating-Point Multiply
 //===----------------------------------------------------------------------===//
+let IsWiden = 1 in {
 let mayRaiseFPException = true, hasSideEffects = 0 in {
 defm PseudoVFWMUL : VPseudoVWMUL_VV_VF_RM;
 }
+}
 
 //===----------------------------------------------------------------------===//
 // 13.6. Vector Single-Width Floating-Point Fused Multiply-Add Instructions
@@ -6498,6 +6508,7 @@ defm PseudoVFNMSUB : VPseudoVMAC_VV_VF_AAXA_RM;
 //===----------------------------------------------------------------------===//
 // 13.7. Vector Widening Floating-Point Fused Multiply-Add Instructions
 //===----------------------------------------------------------------------===//
+let IsWiden = 1 in {
 let mayRaiseFPException = true, hasSideEffects = 0, hasPostISelHook = 1 in {
 defm PseudoVFWMACC  : VPseudoVWMAC_VV_VF_RM;
 defm PseudoVFWNMACC : VPseudoVWMAC_VV_VF_RM;
@@ -6506,6 +6517,7 @@ defm PseudoVFWNMSAC : VPseudoVWMAC_VV_VF_RM;
 let Predicates = [HasStdExtZvfbfwma] in
 defm PseudoVFWMACCBF16  : VPseudoVWMAC_VV_VF_BF_RM;
 }
+}
 
 //===----------------------------------------------------------------------===//
 // 13.8. Vector Floating-Point Square-Root Instruction
@@ -6594,6 +6606,7 @@ defm PseudoVFCVT_RM_F_X  : VPseudoVCVTF_RM_V;
 //===----------------------------------------------------------------------===//
 // 13.18. Widening Floating-Point/Integer Type-Convert Instructions
 //===----------------------------------------------------------------------===//
+let IsWiden = 1 in {
 let mayRaiseFPException = true in {
 let hasSideEffects = 0, hasPostISelHook = 1 in {
 defm PseudoVFWCVT_XU_F     : VPseudoVWCVTI_V_RM;
@@ -6611,6 +6624,7 @@ defm PseudoVFWCVT_F_X      : VPseudoVWCVTF_V;
 defm PseudoVFWCVT_F_F      : VPseudoVWCVTD_V;
 defm PseudoVFWCVTBF16_F_F :  VPseudoVWCVTD_V;
 } // mayRaiseFPException = true
+}
 
 //===----------------------------------------------------------------------===//
 // 13.19. Narrowing Floating-Point/Integer Type-Convert Instructions
@@ -6661,11 +6675,13 @@ defm PseudoVREDMAX  : VPseudoVREDMINMAX_VS;
 //===----------------------------------------------------------------------===//
 // 14.2. Vector Widening Integer Reduction Instructions
 //===----------------------------------------------------------------------===//
+let IsWiden = 1 in {
 let IsRVVWideningReduction = 1 in {
 defm PseudoVWREDSUMU   : VPseudoVWRED_VS;
 defm PseudoVWREDSUM    : VPseudoVWRED_VS;
 }
 } // Predicates = [HasVInstructions]
+}
 
 let Predicates = [HasVInstructionsAnyF] in {
 //===----------------------------------------------------------------------===//
@@ -6684,6 +6700,7 @@ defm PseudoVFREDMAX  : VPseudoVFREDMINMAX_VS;
 //===----------------------------------------------------------------------===//
 // 14.4. Vector Widening Floating-Point Reduction Instructions
 //===----------------------------------------------------------------------===//
+let IsWiden = 1 in {
 let IsRVVWideningReduction = 1,
     hasSideEffects = 0,
     mayRaiseFPException = true in {
@@ -6692,6 +6709,7 @@ defm PseudoVFWREDOSUM  : VPseudoVFWREDO_VS_RM;
 }
 
 } // Predicates = [HasVInstructionsAnyF]
+}
 
 //===----------------------------------------------------------------------===//
 // 15. Vector Mask Instructions
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoZvk.td b/llvm/lib/Target/RISCV/RISCVInstrInfoZvk.td
index cafd259031746d..f2f69be150f4d0 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoZvk.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoZvk.td
@@ -529,7 +529,8 @@ let Predicates = [HasStdExtZvbb] in {
   defm PseudoVCLZ   : VPseudoVCLZ;
   defm PseudoVCTZ   : VPseudoVCTZ;
   defm PseudoVCPOP  : VPseudoVCPOP;
-  defm PseudoVWSLL : VPseudoVWSLL;
+  let IsWiden = 1 in
+    defm PseudoVWSLL : VPseudoVWSLL;
 } // Predicates = [HasStdExtZvbb]
 
 let Predicates = [HasStdExtZvbc] in {
diff --git a/llvm/lib/Target/RISCV/RISCVSpillRewrite.cpp b/llvm/lib/Target/RISCV/RISCVSpillRewrite.cpp
new file mode 100644
index 00000000000000..d1c70ed944da6e
--- /dev/null
+++ b/llvm/lib/Target/RISCV/RISCVSpillRewrite.cpp
@@ -0,0 +1,407 @@
+//===----------- RISCVSpillRewrite.cpp - RISC-V Spill Rewrite -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a function pass that rewrites spills and reloads to
+// reduce instruction latency, replacing full-register stores/loads
+// (VS1R/VL1R) with fractional unit-stride stores/loads (VSE/VLE) where
+// possible.
+//
+// The algorithm rewrites a spill (VS1R) into a VSE if the spilled vreg only
+// needs a fraction of a vreg, as determined by the LMUL of the last
+// instruction that writes it. If the spilled register is defined in
+// different BBs, it takes the union (maximum) of the LMULs found in each
+// defining BB. It then rewrites the reloads (VL1R) from each rewritten spill
+// slot into VLEs that match the corresponding spill. The algorithm repeats
+// until no further rewrite is made.
+//
+//===----------------------------------------------------------------------===//
+
+#include "MCTargetDesc/RISCVBaseInfo.h"
+#include "MCTargetDesc/RISCVMCTargetDesc.h"
+#include "RISCV.h"
+#include "RISCVSubtarget.h"
+#include "llvm/ADT/Statistic.h"
+#include "llvm/CodeGen/LiveDebugVariables.h"
+#include "llvm/CodeGen/LiveIntervals.h"
+#include "llvm/CodeGen/LiveStacks.h"
+#include "llvm/CodeGen/MachineBasicBlock.h"
+#include "llvm/CodeGen/MachineFrameInfo.h"
+#include "llvm/CodeGen/MachineFunctionPass.h"
+#include "llvm/CodeGen/MachineInstrBuilder.h"
+#include "llvm/CodeGen/MachineOperand.h"
+#include "llvm/TargetParser/RISCVTargetParser.h"
+#include <iterator>
+#include <limits>
+#include <map>
+
+using namespace llvm;
+
+#define DEBUG_TYPE "riscv-spill-rewrite"
+#define RISCV_SPILL_REWRITE_NAME "RISC-V Spill Rewrite pass"
+
+namespace {
+/// \returns true if \p MI is a single-register whole-register store (VS1R),
+/// the only spill form this pass considers.
+static inline bool isSpillInst(const MachineInstr &MI) {
+  switch (MI.getOpcode()) {
+  case RISCV::VS1R_V:
+    return true;
+  default:
+    return false;
+  }
+}
+
+/// \returns true if \p MI is a whole-register vector load VL<n>RE<eew>.V for
+/// any register count n in {1, 2, 4, 8} and EEW in {8, 16, 32, 64}.
+static inline bool isReloadInst(const MachineInstr &MI) {
+  switch (MI.getOpcode()) {
+  case RISCV::VL1RE8_V:
+  case RISCV::VL1RE16_V:
+  case RISCV::VL1RE32_V:
+  case RISCV::VL1RE64_V:
+  case RISCV::VL2RE8_V:
+  case RISCV::VL2RE16_V:
+  case RISCV::VL2RE32_V:
+  case RISCV::VL2RE64_V:
+  case RISCV::VL4RE8_V:
+  case RISCV::VL4RE16_V:
+  case RISCV::VL4RE32_V:
+  case RISCV::VL4RE64_V:
+  case RISCV::VL8RE8_V:
+  case RISCV::VL8RE16_V:
+  case RISCV::VL8RE32_V:
+  case RISCV::VL8RE64_V:
+    return true;
+  default:
+    return false;
+  }
+}
+
+/// \returns true if the memory operand of \p MI is a frame index that \p MFI
+/// reports as a spill slot. The frame index is looked up at operand 1, or at
+/// operand 2 when \p IsReload is set; a missing or non-FI operand yields false.
+// NOTE(review): whole-register loads (VL*R) appear to carry their frame index
+// at operand 1; operand 2 matches a PseudoVLE-style load with a passthru
+// operand. Confirm which form callers pass with IsReload = true.
+static inline bool hasSpillSlotObject(const MachineFrameInfo *MFI,
+                                      const MachineInstr &MI,
+                                      bool IsReload = false) {
+  const unsigned MemOpIdx = IsReload ? 2u : 1u;
+  if (MemOpIdx >= MI.getNumOperands())
+    return false;
+
+  const MachineOperand &MemOp = MI.getOperand(MemOpIdx);
+  return MemOp.isFI() && MFI->isSpillSlotObjectIndex(MemOp.getIndex());
+}
+
+/// \returns the larger of \p LMUL1 and \p LMUL2 (mf8 < mf4 < mf2 < m1 < ... <
+/// m8). LMUL_RESERVED ranks below every real LMUL. On a tie \p LMUL2 is
+/// returned.
+static inline RISCVII::VLMUL maxLMUL(RISCVII::VLMUL LMUL1,
+                                     RISCVII::VLMUL LMUL2) {
+  // Map an LMUL onto a signed scale so plain integer comparison orders it:
+  // fractional LMULs become negative (mf8 -> -8, mf2 -> -2) and integral
+  // LMULs stay positive.
+  auto Rank = [](RISCVII::VLMUL LMUL) -> int {
+    if (LMUL == RISCVII::LMUL_RESERVED)
+      return std::numeric_limits<int>::min();
+    auto [Magnitude, IsFractional] = RISCVVType::decodeVLMUL(LMUL);
+    return IsFractional ? -static_cast<int>(Magnitude)
+                        : static_cast<int>(Magnitude);
+  };
+
+  if (Rank(LMUL1) > Rank(LMUL2))
+    return LMUL1;
+  return LMUL2;
+}
+
+/// \returns the LMUL that is twice the fractional \p LMUL (mf8 -> mf4,
+/// mf4 -> mf2, mf2 -> m1). Any non-fractional input is a programming error.
+static inline RISCVII::VLMUL getWidenedFracLMUL(RISCVII::VLMUL LMUL) {
+  switch (LMUL) {
+  case RISCVII::LMUL_F8:
+    return RISCVII::LMUL_F4;
+  case RISCVII::LMUL_F4:
+    return RISCVII::LMUL_F2;
+  case RISCVII::LMUL_F2:
+    return RISCVII::LMUL_1;
+  default:
+    llvm_unreachable("The LMUL is supposed to be fractional.");
+  }
+}
+
+/// Machine function pass that rewrites whole-register vector spills (VS1R)
+/// into fractional unit-stride stores when the spilled value only occupies
+/// part of a vector register, and rewrites the matching reloads accordingly.
+class RISCVSpillRewrite : public MachineFunctionPass {
+  const RISCVSubtarget *ST = nullptr;
+  const TargetInstrInfo *TII = nullptr;
+  MachineRegisterInfo *MRI = nullptr;
+  MachineFrameInfo *MFI = nullptr;
+  LiveIntervals *LIS = nullptr;
+
+public:
+  static char ID;
+  RISCVSpillRewrite() : MachineFunctionPass(ID) {}
+  StringRef getPassName() const override { return RISCV_SPILL_REWRITE_NAME; }
+  bool runOnMachineFunction(MachineFunction &MF) override;
+  void getAnalysisUsage(AnalysisUsage &AU) const override;
+
+private:
+  bool tryToRewrite(MachineFunction &MF);
+
+  // Finds the LMUL occupied by Reg's value by walking backwards to its
+  // defining instruction. If Reg is live into MBB, the predecessors are
+  // searched recursively and the maximum LMUL found is returned. Copies are
+  // looked through to their source register. Reloads and tail-undisturbed
+  // results yield RISCVII::LMUL_1. RISCVII::LMUL_RESERVED is returned when no
+  // usable definition is found (e.g. the def is dead), which means the Reg
+  // can't be rewritten.
+  // BegI is the instruction to start the backward walk from; a null BegI
+  // means the end of MBB. Visited records the instructions already walked.
+  // Reaching one of them again means the walk looped back (e.g. MBB is a loop
+  // body), and that path then returns RISCVII::LMUL_F8 (don't-care).
+  RISCVII::VLMUL
+  findDefiningInstUnionLMUL(MachineBasicBlock &MBB, Register Reg,
+                            DenseMap<MachineInstr *, bool> &Visited,
+                            MachineBasicBlock::reverse_iterator BegI = nullptr);
+  // Picks a fractional PseudoVSE8_V_MF* store for the spill at I when the
+  // spilled value only needs part of a register, and records the spill
+  // slot's LMUL in SpillLMUL. Returns false when no rewrite is needed.
+  bool tryToRewriteSpill(MachineBasicBlock &MBB, MachineBasicBlock::iterator I,
+                         std::map<int, RISCVII::VLMUL> &SpillLMUL);
+  // Presumably rewrites the reload at I from frame index FI using the LMUL
+  // recorded for FI in SpillLMUL (definition not visible here; TODO confirm).
+  bool tryToRewriteReload(MachineBasicBlock &MBB, MachineBasicBlock::iterator I,
+                          int FI,
+                          const std::map<int, RISCVII::VLMUL> &SpillLMUL);
+};
+
+} // end anonymous namespace
+
+// Pass identifier; legacy passes are identified by the address of ID.
+char RISCVSpillRewrite::ID = 0;
+
+// Register the pass under the "riscv-spill-rewrite" argument (DEBUG_TYPE).
+// The trailing `false, false` mark it as neither CFG-only nor an analysis.
+INITIALIZE_PASS(RISCVSpillRewrite, DEBUG_TYPE, RISCV_SPILL_REWRITE_NAME, false,
+                false)
+
+void RISCVSpillRewrite::getAnalysisUsage(AnalysisUsage &AU) const {
+  // Live intervals and slot indexes are required and declared preserved.
+  AU.addRequired<LiveIntervalsWrapperPass>();
+  AU.addPreserved<LiveIntervalsWrapperPass>();
+  AU.addRequired<SlotIndexesWrapperPass>();
+  AU.addPreserved<SlotIndexesWrapperPass>();
+
+  // The CFG, debug-variable info and stack-slot liveness are declared
+  // preserved as well.
+  AU.setPreservesCFG();
+  AU.addPreserved<LiveDebugVariables>();
+  AU.addPreserved<LiveStacks>();
+
+  MachineFunctionPass::getAnalysisUsage(AU);
+}
+
+/// Walks backwards from \p BegI (or from the end of \p MBB when \p BegI is
+/// null) to the instruction(s) defining \p Reg and returns the LMUL that the
+/// value occupies:
+///  - LMUL_RESERVED if the def is dead, or if Reg is live into MBB and no
+///    predecessor yields a usable LMUL (the spill is then left unchanged);
+///  - LMUL_1 when the whole register must be kept: reloads, tail-undisturbed
+///    results, instructions without an SEW operand, and widening from a
+///    source LMUL of 1 or more;
+///  - LMUL_F8 when an already-visited instruction is reached again, i.e. the
+///    walk looped back and this path does not constrain the result;
+///  - otherwise the (possibly widened) LMUL of the defining instruction, or
+///    the maximum over all predecessors when Reg is live into MBB.
+/// COPY-like instructions are looked through to their source register.
+///
+/// NOTE(review): Visited is keyed by instruction only. After following a COPY
+/// to a different source register, instructions already walked while looking
+/// for the original register are treated as loop back-edges and yield
+/// LMUL_F8. This can under-estimate the LMUL when predecessor paths share an
+/// ancestor block. Keying on (MachineInstr *, Register) would avoid this.
+RISCVII::VLMUL RISCVSpillRewrite::findDefiningInstUnionLMUL(
+    MachineBasicBlock &MBB, Register Reg,
+    DenseMap<MachineInstr *, bool> &Visited,
+    MachineBasicBlock::reverse_iterator BegI) {
+  // With TRI, a def of a physical register group containing Reg (e.g. $v8m2
+  // for $v8) also counts as a def of Reg. Virtual registers are matched
+  // exactly either way.
+  const TargetRegisterInfo *TRI = ST->getRegisterInfo();
+
+  for (auto I = (BegI == nullptr ? MBB.rbegin() : BegI); I != MBB.rend(); ++I) {
+    if (I->isDebugInstr())
+      continue;
+
+    // Reaching an instruction that was already visited means the walk looped
+    // back (e.g. MBB is a loop body). The LMUL along this path is then
+    // don't-care, so report the minimum.
+    if (!Visited.try_emplace(&*I).second)
+      return RISCVII::LMUL_F8;
+
+    if (!I->definesRegister(Reg, TRI))
+      continue;
+
+    if (I->registerDefIsDead(Reg, TRI))
+      return RISCVII::LMUL_RESERVED;
+
+    // A reload restores a full register.
+    if (isReloadInst(*I))
+      return RISCVII::LMUL_1;
+
+    // Look through copies, continuing just above the copy. Pass the iterator
+    // rather than dereferencing it: when the copy is the first instruction of
+    // MBB, std::next(I) is rend(), the recursive walk is empty, and the search
+    // continues in the predecessors.
+    if (auto DstSrcPair = TII->isCopyInstr(*I))
+      return findDefiningInstUnionLMUL(MBB, DstSrcPair->Source->getReg(),
+                                       Visited, std::next(I));
+
+    const uint64_t TSFlags = I->getDesc().TSFlags;
+    // Without an SEW operand this is not an RVV pseudo whose LMUL field can be
+    // trusted (e.g. IMPLICIT_DEF), so conservatively keep the whole register.
+    if (!RISCVII::hasSEWOp(TSFlags))
+      return RISCVII::LMUL_1;
+
+    // A tail-undisturbed result must keep its tail elements. For a fractional
+    // LMUL these include the rest of the register past VLMAX, so the full
+    // vector register has to be preserved.
+    if (RISCVII::hasVecPolicyOp(TSFlags)) {
+      const MachineOperand &PolicyOp =
+          I->getOperand(I->getNumExplicitOperands() - 1);
+      if ((PolicyOp.getImm() & RISCVII::TAIL_AGNOSTIC) == 0)
+        return RISCVII::LMUL_1;
+    } else {
+      // Without a policy operand, a real (non-undef, non-$noreg) tied
+      // passthru makes the instruction tail undisturbed. This mirrors how
+      // RISCVInsertVSETVLI derives the policy; confirm if that logic changes.
+      unsigned PassthruIdx;
+      if (I->isRegTiedToUseOperand(0, &PassthruIdx)) {
+        const MachineOperand &Passthru = I->getOperand(PassthruIdx);
+        if (Passthru.getReg().isValid() && !Passthru.isUndef())
+          return RISCVII::LMUL_1;
+      }
+    }
+
+    RISCVII::VLMUL LMUL = RISCVII::getLMul(TSFlags);
+    if (RISCVII::isRVVWideningReduction(TSFlags)) {
+      // Widening reduction produces only single element result, so we just
+      // need to calculate LMUL for single element.
+      int Log2SEW = I->getOperand(RISCVII::getSEWOpNum(I->getDesc())).getImm();
+      int Log2LMUL = Log2SEW - static_cast<int>(Log2_64(ST->getELen()));
+      LMUL =
+          static_cast<RISCVII::VLMUL>(Log2LMUL < 0 ? Log2LMUL + 8 : Log2LMUL);
+    }
+
+    if (RISCVII::isWiden(TSFlags)) {
+      // Only a fractional LMUL widens into (part of) a single register. A
+      // source LMUL of 1 or more (reachable e.g. through a subregister COPY)
+      // widens into a register group, which getWidenedFracLMUL cannot
+      // represent, so keep the whole register instead.
+      if (LMUL != RISCVII::LMUL_F8 && LMUL != RISCVII::LMUL_F4 &&
+          LMUL != RISCVII::LMUL_F2)
+        return RISCVII::LMUL_1;
+      LMUL = getWidenedFracLMUL(LMUL);
+    }
+
+    return LMUL;
+  }
+
+  // Reg is not defined in MBB above the starting point, so it must be live
+  // into MBB. Block live-in lists only record physical registers.
+  assert((Reg.isVirtual() || MBB.isLiveIn(Reg)) &&
+         "Reg should be defined in MBB or be live into it");
+
+  // Take the widest LMUL over all predecessors. maxLMUL ranks LMUL_RESERVED
+  // below every real LMUL, so predecessors without a usable definition do not
+  // contribute. If none contribute, the result stays LMUL_RESERVED.
+  RISCVII::VLMUL LMUL = RISCVII::LMUL_RESERVED;
+  for (MachineBasicBlock *P : MBB.predecessors())
+    LMUL = maxLMUL(LMUL, findDefiningInstUnionLMUL(*P, Reg, Visited));
+
+  return LMUL;
+}
+
+bool RISCVSpillRewrite::tryToRewriteSpill(
+    MachineBasicBlock &MBB, MachineBasicBlock::iterator I,
+    std::map<int, RISCVII::VLMUL> &SpillLMUL) {
+  Register SrcReg = I->getOperand(0).getReg();
+  unsigned Opcode = 0;
+  DenseMap<MachineInstr *, bool> Visited;
+  // Find the nearest inst defines this spilled reg.
+  RISCVII::VLMUL LMUL = findDefiningInstUnionLMUL(MBB, SrcReg, Visited, *I);
+  // If the register's defined inst just defines partial of register, we only
+  // need to store partial register.
+  switch (LMUL) {
+  case RISCVII::LMUL_F2:
+    Opcode = RISCV::PseudoVSE8_V_MF2;
+    break;
+  case RISCVII::LMUL_F4:
+    Opcode = RISCV::PseudoVSE8_V_MF4;
+    break;
+  case RISCVII::LMUL_F8:
+    Opcode = RISCV::PseudoVSE8_V_MF8;
+    break;
+  default:
+    break;
+  }
+
+  // No need to rewrite.
+  if (!Opcode)
+    return false;
+
+  int FI = I->getOperand(1).getIndex();
+  auto updateLMUL = [&](RISCVII::VLMUL LMUL) {
+    assert(!SpillLMUL.count(FI) &&
+           "Each frame index should only be used once.");
+    SpillLMUL[FI] = LMUL;
+  };
+
+  if (Opcode == RISCV::PseudoVSE8_V_MF2)
+    updateLMUL(RISCVII::LMUL_F2);
+  else if (Opcode == RISCV::PseudoVSE8_V_MF4)
+    updateLMUL(RISCVII::LMUL_F4);
+  else if (Opcode == RISCV::PseudoVSE8_V_MF8...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/105661


More information about the llvm-commits mailing list