[llvm] [RISCV][MI] Support partial spill/reload for vector registers (PR #105661)
via llvm-commits
llvm-commits at lists.llvm.org
Thu Aug 22 07:02:16 PDT 2024
llvmbot wrote:
@llvm/pr-subscribers-backend-risc-v
Author: Brandon Wu (4vtomat)
RFC: https://discourse.llvm.org/t/rfc-riscv-vector-register-spill-optimization-pass/80850
The current RISC-V vector register spill/reload always operates on full vector registers,
regardless of how much of the register the defining instruction actually writes. In some
cases it is not necessary to spill the full vector register, for example:
```
vsetvli a1, a0, e8, mf2, ta, ma
vadd.vv v8, v8, v9
vs1r v8, (a2) <- spill
.
.
.
vl1r v8, (a2) <- reload
vmul.vv v8, v8, v9
```
Both the spill and the reload can then be replaced with `vse8.v` and `vle8.v` respectively,
as below:
```
vsetvli a1, a0, e8, mf2, ta, ma
vadd.vv v8, v8, v9
vse8.v v8, (a2) <- spill
.
.
.
vsetvli a1, x0, e8, mf2, ta, ma
vle8.v v8, (a2) <- reload
vmul.vv v8, v8, v9
```
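In other words, the rewrite chooses the store/load width from the effective LMUL of the instruction that defines the spilled register. As a rough sketch (the helper name here is hypothetical; it simply mirrors the opcode switch in `tryToRewriteSpill` in the patch below), the mapping from a fractional LMUL to a narrower store pseudo could look like this:
```c++
// Minimal sketch: pick a narrower byte-store pseudo from the defining
// instruction's effective LMUL. Returning 0 means a full register may be
// live, so the original VS1R spill is kept.
static unsigned getPartialSpillOpcode(RISCVII::VLMUL LMUL) {
  switch (LMUL) {
  case RISCVII::LMUL_F2:
    return RISCV::PseudoVSE8_V_MF2;
  case RISCVII::LMUL_F4:
    return RISCV::PseudoVSE8_V_MF4;
  case RISCVII::LMUL_F8:
    return RISCV::PseudoVSE8_V_MF8;
  default:
    return 0;
  }
}
```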
Note that this patch does not rewrite a basic block that contains any inline
assembly, for example:
```
%0 = vadd.vv v8, v9 (e8, mf2)
vs1r %0, %stack.0
...
inline_asm("vsetvli 888, e8, m1")
%1 = vl1r %stack.0
inline_asm("vadd.vv %a, %b, %c", %a=v8, %b=%1, %c=%1)
```
If we rewrote the case above, %1 would become a vle8 with mf2, and
RISCVInsertVSETVLI would then emit a vsetvli with mf2 for %1, which is
incompatible with the original semantics, where %1 is m1.
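Because of that, the pass conservatively refuses to touch any basic block containing inline assembly. A minimal sketch of such a bail-out check is below; the helper is hypothetical and not taken from the patch (the actual handling lives in RISCVSpillRewrite.cpp, which is truncated further down):
```c++
// Minimal sketch (hypothetical helper, not from the patch): skip a basic
// block entirely if it contains inline assembly, since the asm may change
// vtype in ways the rewrite cannot reason about.
static bool hasInlineAsm(const llvm::MachineBasicBlock &MBB) {
  for (const llvm::MachineInstr &MI : MBB)
    if (MI.isInlineAsm())
      return true;
  return false;
}
```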
---
Patch is 51.38 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/105661.diff
9 Files Affected:
- (modified) llvm/lib/Target/RISCV/CMakeLists.txt (+1)
- (modified) llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.h (+7)
- (modified) llvm/lib/Target/RISCV/RISCV.h (+3)
- (modified) llvm/lib/Target/RISCV/RISCVInstrFormats.td (+4)
- (modified) llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td (+18)
- (modified) llvm/lib/Target/RISCV/RISCVInstrInfoZvk.td (+2-1)
- (added) llvm/lib/Target/RISCV/RISCVSpillRewrite.cpp (+407)
- (modified) llvm/lib/Target/RISCV/RISCVTargetMachine.cpp (+8) (pipeline wiring; see the sketch after this list)
- (added) llvm/test/CodeGen/RISCV/rvv/vector-spill-rewrite.mir (+551)
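The RISCVTargetMachine.cpp change presumably wires the new pass into the codegen pipeline. The exact hook is not visible in the truncated diff; the sketch below is an assumption, placed after register allocation because vector spills only exist post-RA, and before RISCVInsertVSETVLI runs so the rewritten reloads get matching vsetvli instructions:
```c++
// Sketch only; the hook used by the actual patch is not shown in the diff.
void RISCVPassConfig::addPostRegAlloc() {
  // Narrow full-register vector spills/reloads before vsetvli insertion.
  addPass(createRISCVSpillRewritePass());
}
```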
``````````diff
diff --git a/llvm/lib/Target/RISCV/CMakeLists.txt b/llvm/lib/Target/RISCV/CMakeLists.txt
index cbb4c2cedfb97e..858a82cd6f9086 100644
--- a/llvm/lib/Target/RISCV/CMakeLists.txt
+++ b/llvm/lib/Target/RISCV/CMakeLists.txt
@@ -38,6 +38,7 @@ add_llvm_target(RISCVCodeGen
RISCVGatherScatterLowering.cpp
RISCVIndirectBranchTracking.cpp
RISCVInsertVSETVLI.cpp
+ RISCVSpillRewrite.cpp
RISCVInsertReadWriteCSR.cpp
RISCVInsertWriteVXRM.cpp
RISCVInstrInfo.cpp
diff --git a/llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.h b/llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.h
index c65bd5b1d33631..12b2b8c79ef47f 100644
--- a/llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.h
+++ b/llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.h
@@ -126,6 +126,10 @@ enum {
ActiveElementsAffectResultShift = TargetOverlapConstraintTypeShift + 2,
ActiveElementsAffectResultMask = 1ULL << ActiveElementsAffectResultShift,
+
+ // Indicates whether the instruction produces a widened result.
+ IsWidenShift = ActiveElementsAffectResultShift + 1,
+ IsWidenMask = 1 << IsWidenShift,
};
// Helper functions to read TSFlags.
@@ -149,6 +153,9 @@ static inline bool isTiedPseudo(uint64_t TSFlags) {
static inline bool hasSEWOp(uint64_t TSFlags) {
return TSFlags & HasSEWOpMask;
}
+
+/// \returns true if the instruction produces a widened result.
+static inline bool isWiden(uint64_t TSFlags) { return TSFlags & IsWidenMask; }
/// \returns true if there is a VL operand for the instruction.
static inline bool hasVLOp(uint64_t TSFlags) {
return TSFlags & HasVLOpMask;
diff --git a/llvm/lib/Target/RISCV/RISCV.h b/llvm/lib/Target/RISCV/RISCV.h
index 5a94ada8f8dd46..4779043dc207a0 100644
--- a/llvm/lib/Target/RISCV/RISCV.h
+++ b/llvm/lib/Target/RISCV/RISCV.h
@@ -68,6 +68,9 @@ FunctionPass *createRISCVInsertVSETVLIPass();
void initializeRISCVInsertVSETVLIPass(PassRegistry &);
extern char &RISCVInsertVSETVLIID;
+FunctionPass *createRISCVSpillRewritePass();
+void initializeRISCVSpillRewritePass(PassRegistry &);
+
FunctionPass *createRISCVPostRAExpandPseudoPass();
void initializeRISCVPostRAExpandPseudoPass(PassRegistry &);
FunctionPass *createRISCVInsertReadWriteCSRPass();
diff --git a/llvm/lib/Target/RISCV/RISCVInstrFormats.td b/llvm/lib/Target/RISCV/RISCVInstrFormats.td
index 95f157064d73e2..d33badc331ac9a 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrFormats.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrFormats.td
@@ -226,6 +226,10 @@ class RVInstCommon<dag outs, dag ins, string opcodestr, string argstr,
bit ActiveElementsAffectResult = 0;
let TSFlags{23} = ActiveElementsAffectResult;
+
+ // Indicates whether the instruction produces a widened result.
+ bit IsWiden = 0;
+ let TSFlags{24} = IsWiden;
}
class RVInst<dag outs, dag ins, string opcodestr, string argstr,
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td
index 1b4303fbbcf809..9f022fea3b0cd6 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td
@@ -6277,6 +6277,7 @@ foreach vti = AllIntegerVectors in {
//===----------------------------------------------------------------------===//
// 11.2. Vector Widening Integer Add/Subtract
//===----------------------------------------------------------------------===//
+let IsWiden = 1 in {
defm PseudoVWADDU : VPseudoVWALU_VV_VX<Commutable=1>;
defm PseudoVWSUBU : VPseudoVWALU_VV_VX;
defm PseudoVWADD : VPseudoVWALU_VV_VX<Commutable=1>;
@@ -6285,6 +6286,7 @@ defm PseudoVWADDU : VPseudoVWALU_WV_WX;
defm PseudoVWSUBU : VPseudoVWALU_WV_WX;
defm PseudoVWADD : VPseudoVWALU_WV_WX;
defm PseudoVWSUB : VPseudoVWALU_WV_WX;
+}
//===----------------------------------------------------------------------===//
// 11.3. Vector Integer Extension
@@ -6366,9 +6368,11 @@ defm PseudoVREM : VPseudoVDIV_VV_VX;
//===----------------------------------------------------------------------===//
// 11.12. Vector Widening Integer Multiply Instructions
//===----------------------------------------------------------------------===//
+let IsWiden = 1 in {
defm PseudoVWMUL : VPseudoVWMUL_VV_VX<Commutable=1>;
defm PseudoVWMULU : VPseudoVWMUL_VV_VX<Commutable=1>;
defm PseudoVWMULSU : VPseudoVWMUL_VV_VX;
+}
//===----------------------------------------------------------------------===//
// 11.13. Vector Single-Width Integer Multiply-Add Instructions
@@ -6381,10 +6385,12 @@ defm PseudoVNMSUB : VPseudoVMAC_VV_VX_AAXA;
//===----------------------------------------------------------------------===//
// 11.14. Vector Widening Integer Multiply-Add Instructions
//===----------------------------------------------------------------------===//
+let IsWiden = 1 in {
defm PseudoVWMACCU : VPseudoVWMAC_VV_VX<Commutable=1>;
defm PseudoVWMACC : VPseudoVWMAC_VV_VX<Commutable=1>;
defm PseudoVWMACCSU : VPseudoVWMAC_VV_VX;
defm PseudoVWMACCUS : VPseudoVWMAC_VX;
+}
//===----------------------------------------------------------------------===//
// 11.15. Vector Integer Merge Instructions
@@ -6458,12 +6464,14 @@ defm PseudoVFRSUB : VPseudoVALU_VF_RM;
//===----------------------------------------------------------------------===//
// 13.3. Vector Widening Floating-Point Add/Subtract Instructions
//===----------------------------------------------------------------------===//
+let IsWiden = 1 in {
let mayRaiseFPException = true, hasSideEffects = 0, hasPostISelHook = 1 in {
defm PseudoVFWADD : VPseudoVFWALU_VV_VF_RM;
defm PseudoVFWSUB : VPseudoVFWALU_VV_VF_RM;
defm PseudoVFWADD : VPseudoVFWALU_WV_WF_RM;
defm PseudoVFWSUB : VPseudoVFWALU_WV_WF_RM;
}
+}
//===----------------------------------------------------------------------===//
// 13.4. Vector Single-Width Floating-Point Multiply/Divide Instructions
@@ -6477,9 +6485,11 @@ defm PseudoVFRDIV : VPseudoVFRDIV_VF_RM;
//===----------------------------------------------------------------------===//
// 13.5. Vector Widening Floating-Point Multiply
//===----------------------------------------------------------------------===//
+let IsWiden = 1 in {
let mayRaiseFPException = true, hasSideEffects = 0 in {
defm PseudoVFWMUL : VPseudoVWMUL_VV_VF_RM;
}
+}
//===----------------------------------------------------------------------===//
// 13.6. Vector Single-Width Floating-Point Fused Multiply-Add Instructions
@@ -6498,6 +6508,7 @@ defm PseudoVFNMSUB : VPseudoVMAC_VV_VF_AAXA_RM;
//===----------------------------------------------------------------------===//
// 13.7. Vector Widening Floating-Point Fused Multiply-Add Instructions
//===----------------------------------------------------------------------===//
+let IsWiden = 1 in {
let mayRaiseFPException = true, hasSideEffects = 0, hasPostISelHook = 1 in {
defm PseudoVFWMACC : VPseudoVWMAC_VV_VF_RM;
defm PseudoVFWNMACC : VPseudoVWMAC_VV_VF_RM;
@@ -6506,6 +6517,7 @@ defm PseudoVFWNMSAC : VPseudoVWMAC_VV_VF_RM;
let Predicates = [HasStdExtZvfbfwma] in
defm PseudoVFWMACCBF16 : VPseudoVWMAC_VV_VF_BF_RM;
}
+}
//===----------------------------------------------------------------------===//
// 13.8. Vector Floating-Point Square-Root Instruction
@@ -6594,6 +6606,7 @@ defm PseudoVFCVT_RM_F_X : VPseudoVCVTF_RM_V;
//===----------------------------------------------------------------------===//
// 13.18. Widening Floating-Point/Integer Type-Convert Instructions
//===----------------------------------------------------------------------===//
+let IsWiden = 1 in {
let mayRaiseFPException = true in {
let hasSideEffects = 0, hasPostISelHook = 1 in {
defm PseudoVFWCVT_XU_F : VPseudoVWCVTI_V_RM;
@@ -6611,6 +6624,7 @@ defm PseudoVFWCVT_F_X : VPseudoVWCVTF_V;
defm PseudoVFWCVT_F_F : VPseudoVWCVTD_V;
defm PseudoVFWCVTBF16_F_F : VPseudoVWCVTD_V;
} // mayRaiseFPException = true
+}
//===----------------------------------------------------------------------===//
// 13.19. Narrowing Floating-Point/Integer Type-Convert Instructions
@@ -6661,11 +6675,13 @@ defm PseudoVREDMAX : VPseudoVREDMINMAX_VS;
//===----------------------------------------------------------------------===//
// 14.2. Vector Widening Integer Reduction Instructions
//===----------------------------------------------------------------------===//
+let IsWiden = 1 in {
let IsRVVWideningReduction = 1 in {
defm PseudoVWREDSUMU : VPseudoVWRED_VS;
defm PseudoVWREDSUM : VPseudoVWRED_VS;
}
} // Predicates = [HasVInstructions]
+}
let Predicates = [HasVInstructionsAnyF] in {
//===----------------------------------------------------------------------===//
@@ -6684,6 +6700,7 @@ defm PseudoVFREDMAX : VPseudoVFREDMINMAX_VS;
//===----------------------------------------------------------------------===//
// 14.4. Vector Widening Floating-Point Reduction Instructions
//===----------------------------------------------------------------------===//
+let IsWiden = 1 in {
let IsRVVWideningReduction = 1,
hasSideEffects = 0,
mayRaiseFPException = true in {
@@ -6692,6 +6709,7 @@ defm PseudoVFWREDOSUM : VPseudoVFWREDO_VS_RM;
}
} // Predicates = [HasVInstructionsAnyF]
+}
//===----------------------------------------------------------------------===//
// 15. Vector Mask Instructions
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoZvk.td b/llvm/lib/Target/RISCV/RISCVInstrInfoZvk.td
index cafd259031746d..f2f69be150f4d0 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoZvk.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoZvk.td
@@ -529,7 +529,8 @@ let Predicates = [HasStdExtZvbb] in {
defm PseudoVCLZ : VPseudoVCLZ;
defm PseudoVCTZ : VPseudoVCTZ;
defm PseudoVCPOP : VPseudoVCPOP;
- defm PseudoVWSLL : VPseudoVWSLL;
+ let IsWiden = 1 in
+ defm PseudoVWSLL : VPseudoVWSLL;
} // Predicates = [HasStdExtZvbb]
let Predicates = [HasStdExtZvbc] in {
diff --git a/llvm/lib/Target/RISCV/RISCVSpillRewrite.cpp b/llvm/lib/Target/RISCV/RISCVSpillRewrite.cpp
new file mode 100644
index 00000000000000..d1c70ed944da6e
--- /dev/null
+++ b/llvm/lib/Target/RISCV/RISCVSpillRewrite.cpp
@@ -0,0 +1,407 @@
+//===----------- RISCVSpillRewrite.cpp - RISC-V Spill Rewrite -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a function pass that rewrites spills and reloads to
+// reduce instruction latency by changing full-register stores/loads
+// (VS1R/VL1R) into fractional stores/loads (VSE/VLE) where possible.
+//
+// The algorithm finds and rewrites spills (VS1R) to VSE when the spilled vreg
+// only needs a fraction of a vector register (determined by the LMUL of the
+// last instruction that writes it); if the spilled register is defined in
+// different BBs, the union of the LMULs over all defining BBs is used. It then
+// rewrites the reloads (VL1R) to VLE to match the corresponding spills in the
+// spill slots. The algorithm runs until there are no more rewrites.
+//
+//===----------------------------------------------------------------------===//
+
+#include "MCTargetDesc/RISCVBaseInfo.h"
+#include "MCTargetDesc/RISCVMCTargetDesc.h"
+#include "RISCV.h"
+#include "RISCVSubtarget.h"
+#include "llvm/ADT/Statistic.h"
+#include "llvm/CodeGen/LiveDebugVariables.h"
+#include "llvm/CodeGen/LiveIntervals.h"
+#include "llvm/CodeGen/LiveStacks.h"
+#include "llvm/CodeGen/MachineBasicBlock.h"
+#include "llvm/CodeGen/MachineFrameInfo.h"
+#include "llvm/CodeGen/MachineFunctionPass.h"
+#include "llvm/CodeGen/MachineInstrBuilder.h"
+#include "llvm/CodeGen/MachineOperand.h"
+#include "llvm/TargetParser/RISCVTargetParser.h"
+
+using namespace llvm;
+
+#define DEBUG_TYPE "riscv-spill-rewrite"
+#define RISCV_SPILL_REWRITE_NAME "RISC-V Spill Rewrite pass"
+
+namespace {
+static inline bool isSpillInst(const MachineInstr &MI) {
+ return MI.getOpcode() == RISCV::VS1R_V;
+}
+
+static inline bool isReloadInst(const MachineInstr &MI) {
+ return MI.getOpcode() == RISCV::VL1RE8_V ||
+ MI.getOpcode() == RISCV::VL2RE8_V ||
+ MI.getOpcode() == RISCV::VL4RE8_V ||
+ MI.getOpcode() == RISCV::VL8RE8_V ||
+ MI.getOpcode() == RISCV::VL1RE16_V ||
+ MI.getOpcode() == RISCV::VL2RE16_V ||
+ MI.getOpcode() == RISCV::VL4RE16_V ||
+ MI.getOpcode() == RISCV::VL8RE16_V ||
+ MI.getOpcode() == RISCV::VL1RE32_V ||
+ MI.getOpcode() == RISCV::VL2RE32_V ||
+ MI.getOpcode() == RISCV::VL4RE32_V ||
+ MI.getOpcode() == RISCV::VL8RE32_V ||
+ MI.getOpcode() == RISCV::VL1RE64_V ||
+ MI.getOpcode() == RISCV::VL2RE64_V ||
+ MI.getOpcode() == RISCV::VL4RE64_V ||
+ MI.getOpcode() == RISCV::VL8RE64_V;
+}
+
+static inline bool hasSpillSlotObject(const MachineFrameInfo *MFI,
+ const MachineInstr &MI,
+ bool IsReload = false) {
+ unsigned MemOpIdx = IsReload ? 2 : 1;
+ if (MI.getNumOperands() <= MemOpIdx || !MI.getOperand(MemOpIdx).isFI())
+ return false;
+
+ int FI = MI.getOperand(MemOpIdx).getIndex();
+ return MFI->isSpillSlotObjectIndex(FI);
+}
+
+static inline RISCVII::VLMUL maxLMUL(RISCVII::VLMUL LMUL1,
+ RISCVII::VLMUL LMUL2) {
+ int LMUL1Val = std::numeric_limits<int>::min();
+ int LMUL2Val = std::numeric_limits<int>::min();
+
+ if (LMUL1 != RISCVII::LMUL_RESERVED) {
+ auto DecodedLMUL1 = RISCVVType::decodeVLMUL(LMUL1);
+ LMUL1Val = DecodedLMUL1.second ? -DecodedLMUL1.first : DecodedLMUL1.first;
+ }
+ if (LMUL2 != RISCVII::LMUL_RESERVED) {
+ auto DecodedLMUL2 = RISCVVType::decodeVLMUL(LMUL2);
+ LMUL2Val = DecodedLMUL2.second ? -DecodedLMUL2.first : DecodedLMUL2.first;
+ }
+
+ return LMUL1Val > LMUL2Val ? LMUL1 : LMUL2;
+}
+
+static inline RISCVII::VLMUL getWidenedFracLMUL(RISCVII::VLMUL LMUL) {
+ if (LMUL == RISCVII::LMUL_F8)
+ return RISCVII::LMUL_F4;
+ if (LMUL == RISCVII::LMUL_F4)
+ return RISCVII::LMUL_F2;
+ if (LMUL == RISCVII::LMUL_F2)
+ return RISCVII::LMUL_1;
+
+ llvm_unreachable("The LMUL is supposed to be fractional.");
+}
+
+class RISCVSpillRewrite : public MachineFunctionPass {
+ const RISCVSubtarget *ST = nullptr;
+ const TargetInstrInfo *TII = nullptr;
+ MachineRegisterInfo *MRI = nullptr;
+ MachineFrameInfo *MFI = nullptr;
+ LiveIntervals *LIS = nullptr;
+
+public:
+ static char ID;
+ RISCVSpillRewrite() : MachineFunctionPass(ID) {}
+ StringRef getPassName() const override { return RISCV_SPILL_REWRITE_NAME; }
+ bool runOnMachineFunction(MachineFunction &MF) override;
+ void getAnalysisUsage(AnalysisUsage &AU) const override;
+
+private:
+ bool tryToRewrite(MachineFunction &MF);
+
+ // This function finds Reg's LMUL from its defining instruction. If multiple
+ // instructions define Reg in different BBs, it recursively visits them and
+ // returns the maximum LMUL that is found. If the LMUL cannot be determined
+ // for any reason, such as the register being dead, it returns
+ // RISCVII::LMUL_RESERVED, which means Reg cannot be rewritten.
+ // BegI is the instruction the search starts from; it is used to detect
+ // whether a loop has been encountered, in which case the defining
+ // instruction does not exist in this MBB.
+ RISCVII::VLMUL
+ findDefiningInstUnionLMUL(MachineBasicBlock &MBB, Register Reg,
+ DenseMap<MachineInstr *, bool> &Visited,
+ MachineBasicBlock::reverse_iterator BegI = nullptr);
+ bool tryToRewriteSpill(MachineBasicBlock &MBB, MachineBasicBlock::iterator I,
+ std::map<int, RISCVII::VLMUL> &SpillLMUL);
+ bool tryToRewriteReload(MachineBasicBlock &MBB, MachineBasicBlock::iterator I,
+ int FI,
+ const std::map<int, RISCVII::VLMUL> &SpillLMUL);
+};
+
+} // end anonymous namespace
+
+char RISCVSpillRewrite::ID = 0;
+
+INITIALIZE_PASS(RISCVSpillRewrite, DEBUG_TYPE, RISCV_SPILL_REWRITE_NAME, false,
+ false)
+
+void RISCVSpillRewrite::getAnalysisUsage(AnalysisUsage &AU) const {
+ AU.setPreservesCFG();
+
+ AU.addPreserved<LiveIntervalsWrapperPass>();
+ AU.addRequired<LiveIntervalsWrapperPass>();
+ AU.addPreserved<SlotIndexesWrapperPass>();
+ AU.addRequired<SlotIndexesWrapperPass>();
+ AU.addPreserved<LiveDebugVariables>();
+ AU.addPreserved<LiveStacks>();
+
+ MachineFunctionPass::getAnalysisUsage(AU);
+}
+
+RISCVII::VLMUL RISCVSpillRewrite::findDefiningInstUnionLMUL(
+ MachineBasicBlock &MBB, Register Reg,
+ DenseMap<MachineInstr *, bool> &Visited,
+ MachineBasicBlock::reverse_iterator BegI) {
+ for (auto I = (BegI == nullptr ? MBB.rbegin() : BegI); I != MBB.rend(); ++I) {
+ if (I->isDebugInstr())
+ continue;
+
+ // If this MBB is a loop body and we reach an instruction that has already
+ // been visited, the LMUL contribution of this MBB is a don't-care, so
+ // return the minimum LMUL.
+ if (Visited.contains(&*I))
+ return RISCVII::LMUL_F8;
+
+ Visited[&*I];
+ if (I->definesRegister(Reg, nullptr)) {
+ if (I->registerDefIsDead(Reg, nullptr))
+ return RISCVII::LMUL_RESERVED;
+
+ if (isReloadInst(*I))
+ return RISCVII::LMUL_1;
+
+ if (auto DstSrcPair = TII->isCopyInstr(*I))
+ return findDefiningInstUnionLMUL(MBB, DstSrcPair->Source->getReg(),
+ Visited, *++I);
+
+ const uint64_t TSFlags = I->getDesc().TSFlags;
+ assert(RISCVII::hasSEWOp(TSFlags));
+
+ // If the instruction is tail undisturbed, we need to preserve the full
+ // vector register since the tail data might be used somewhere.
+ if (RISCVII::hasVecPolicyOp(TSFlags)) {
+ const MachineOperand &PolicyOp =
+ I->getOperand(I->getNumExplicitOperands() - 1);
+ if ((PolicyOp.getImm() & RISCVII::TAIL_AGNOSTIC) == 0)
+ return RISCVII::VLMUL::LMUL_1;
+ }
+
+ RISCVII::VLMUL LMUL = RISCVII::getLMul(TSFlags);
+ if (RISCVII::isRVVWideningReduction(TSFlags)) {
+ // A widening reduction produces only a single-element result, so we
+ // just need to calculate the LMUL for a single element.
+ int Log2SEW =
+ I->getOperand(RISCVII::getSEWOpNum(I->getDesc())).getImm();
+ int Log2LMUL = Log2SEW - Log2_64(ST->getELen());
+ LMUL =
+ static_cast<RISCVII::VLMUL>(Log2LMUL < 0 ? Log2LMUL + 8 : Log2LMUL);
+ }
+ if (RISCVII::isWiden(TSFlags))
+ LMUL = getWidenedFracLMUL(LMUL);
+
+ return LMUL;
+ }
+ }
+
+ assert(MBB.isLiveIn(Reg));
+
+ // If Reg's defining inst is not found in this BB, find it in it's
+ // predecessors.
+ RISCVII::VLMUL LMUL = RISCVII::LMUL_RESERVED;
+ for (MachineBasicBlock *P : MBB.predecessors()) {
+ RISCVII::VLMUL PredLMUL = findDefiningInstUnionLMUL(*P, Reg, Visited);
+ if (PredLMUL == RISCVII::LMUL_RESERVED)
+ continue;
+
+ if (LMUL == RISCVII::LMUL_RESERVED) {
+ LMUL = PredLMUL;
+ continue;
+ }
+
+ LMUL = maxLMUL(LMUL, PredLMUL);
+ }
+
+ return LMUL;
+}
+
+bool RISCVSpillRewrite::tryToRewriteSpill(
+ MachineBasicBlock &MBB, MachineBasicBlock::iterator I,
+ std::map<int, RISCVII::VLMUL> &SpillLMUL) {
+ Register SrcReg = I->getOperand(0).getReg();
+ unsigned Opcode = 0;
+ DenseMap<MachineInstr *, bool> Visited;
+ // Find the nearest inst that defines this spilled reg.
+ RISCVII::VLMUL LMUL = findDefiningInstUnionLMUL(MBB, SrcReg, Visited, *I);
+ // If the defining inst only writes part of the register, we only need to
+ // store that part of the register.
+ switch (LMUL) {
+ case RISCVII::LMUL_F2:
+ Opcode = RISCV::PseudoVSE8_V_MF2;
+ break;
+ case RISCVII::LMUL_F4:
+ Opcode = RISCV::PseudoVSE8_V_MF4;
+ break;
+ case RISCVII::LMUL_F8:
+ Opcode = RISCV::PseudoVSE8_V_MF8;
+ break;
+ default:
+ break;
+ }
+
+ // No need to rewrite.
+ if (!Opcode)
+ return false;
+
+ int FI = I->getOperand(1).getIndex();
+ auto updateLMUL = [&](RISCVII::VLMUL LMUL) {
+ assert(!SpillLMUL.count(FI) &&
+ "Each frame index should only be used once.");
+ SpillLMUL[FI] = LMUL;
+ };
+
+ if (Opcode == RISCV::PseudoVSE8_V_MF2)
+ updateLMUL(RISCVII::LMUL_F2);
+ else if (Opcode == RISCV::PseudoVSE8_V_MF4)
+ updateLMUL(RISCVII::LMUL_F4);
+ else if (Opcode == RISCV::PseudoVSE8_V_MF8...
[truncated]
``````````
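The diff is cut off before the pass driver that ties these helpers together. As an illustration only (this is not the truncated code), the driver declared as `tryToRewrite` presumably walks each block, narrows qualifying VS1R spills while recording the chosen LMUL per frame index, and then rewrites the matching reloads from that map:
```c++
// Illustrative sketch only; the real tryToRewrite body is not shown above.
// Spills and reloads are paired through a frame-index -> LMUL map, and per
// the file header comment the pass repeats until nothing changes. Details
// such as iterator updates after rewriting are omitted here.
bool RISCVSpillRewrite::tryToRewrite(MachineFunction &MF) {
  bool Changed = false;
  std::map<int, RISCVII::VLMUL> SpillLMUL;
  for (MachineBasicBlock &MBB : MF) {
    for (auto I = MBB.begin(), E = MBB.end(); I != E; ++I) {
      if (isSpillInst(*I) && hasSpillSlotObject(MFI, *I))
        Changed |= tryToRewriteSpill(MBB, I, SpillLMUL);
      else if (isReloadInst(*I) &&
               hasSpillSlotObject(MFI, *I, /*IsReload=*/true))
        Changed |= tryToRewriteReload(MBB, I, I->getOperand(2).getIndex(),
                                      SpillLMUL);
    }
  }
  return Changed;
}
```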
https://github.com/llvm/llvm-project/pull/105661