[llvm] 8fb2160 - [RISCV] Use DenseMap to track V0 definition. NFC (#84465)
via llvm-commits
llvm-commits at lists.llvm.org
Thu Mar 21 00:38:47 PDT 2024
Author: Luke Lau
Date: 2024-03-21T15:38:43+08:00
New Revision: 8fb2160a76b5f051f4cc8f5c8c097830bc91c22c
URL: https://github.com/llvm/llvm-project/commit/8fb2160a76b5f051f4cc8f5c8c097830bc91c22c
DIFF: https://github.com/llvm/llvm-project/commit/8fb2160a76b5f051f4cc8f5c8c097830bc91c22c.diff
LOG: [RISCV] Use DenseMap to track V0 definition. NFC (#84465)
Reviving some of the progress on #71764. To recap, we explored removing
the V0 register copies to simplify the pass, but hit a limitation with
the register allocator due to our use of the vmv0 singleton reg class
and early-clobber constraints.
Since we will have to continue tracking the definition of V0 ourselves,
this patch simplifies that tracking by recording, for each use of V0, its
corresponding definition in a map. This will allow us to move copies to V0
around in #71764 without having to do extra bookkeeping.
Added:
Modified:
llvm/lib/Target/RISCV/RISCVFoldMasks.cpp
Removed:
################################################################################
diff --git a/llvm/lib/Target/RISCV/RISCVFoldMasks.cpp b/llvm/lib/Target/RISCV/RISCVFoldMasks.cpp
index fddbaa97d0638c..2089f5dda6fe52 100644
--- a/llvm/lib/Target/RISCV/RISCVFoldMasks.cpp
+++ b/llvm/lib/Target/RISCV/RISCVFoldMasks.cpp
@@ -47,10 +47,13 @@ class RISCVFoldMasks : public MachineFunctionPass {
StringRef getPassName() const override { return "RISC-V Fold Masks"; }
private:
- bool convertToUnmasked(MachineInstr &MI, MachineInstr *MaskDef) const;
- bool convertVMergeToVMv(MachineInstr &MI, MachineInstr *MaskDef) const;
+ bool convertToUnmasked(MachineInstr &MI) const;
+ bool convertVMergeToVMv(MachineInstr &MI) const;
- bool isAllOnesMask(MachineInstr *MaskDef) const;
+ bool isAllOnesMask(const MachineInstr *MaskDef) const;
+
+ /// Maps uses of V0 to the corresponding def of V0.
+ DenseMap<const MachineInstr *, const MachineInstr *> V0Defs;
};
} // namespace
@@ -59,10 +62,9 @@ char RISCVFoldMasks::ID = 0;
INITIALIZE_PASS(RISCVFoldMasks, DEBUG_TYPE, "RISC-V Fold Masks", false, false)
-bool RISCVFoldMasks::isAllOnesMask(MachineInstr *MaskDef) const {
- if (!MaskDef)
- return false;
- assert(MaskDef->isCopy() && MaskDef->getOperand(0).getReg() == RISCV::V0);
+bool RISCVFoldMasks::isAllOnesMask(const MachineInstr *MaskDef) const {
+ assert(MaskDef && MaskDef->isCopy() &&
+ MaskDef->getOperand(0).getReg() == RISCV::V0);
Register SrcReg = TRI->lookThruCopyLike(MaskDef->getOperand(1).getReg(), MRI);
if (!SrcReg.isVirtual())
return false;
@@ -89,8 +91,7 @@ bool RISCVFoldMasks::isAllOnesMask(MachineInstr *MaskDef) const {
// Transform (VMERGE_VVM_<LMUL> false, false, true, allones, vl, sew) to
// (VMV_V_V_<LMUL> false, true, vl, sew). It may decrease uses of VMSET.
-bool RISCVFoldMasks::convertVMergeToVMv(MachineInstr &MI,
- MachineInstr *V0Def) const {
+bool RISCVFoldMasks::convertVMergeToVMv(MachineInstr &MI) const {
#define CASE_VMERGE_TO_VMV(lmul) \
case RISCV::PseudoVMERGE_VVM_##lmul: \
NewOpc = RISCV::PseudoVMV_V_V_##lmul; \
@@ -116,7 +117,7 @@ bool RISCVFoldMasks::convertVMergeToVMv(MachineInstr &MI,
return false;
assert(MI.getOperand(4).isReg() && MI.getOperand(4).getReg() == RISCV::V0);
- if (!isAllOnesMask(V0Def))
+ if (!isAllOnesMask(V0Defs.lookup(&MI)))
return false;
MI.setDesc(TII->get(NewOpc));
@@ -133,14 +134,13 @@ bool RISCVFoldMasks::convertVMergeToVMv(MachineInstr &MI,
return true;
}
-bool RISCVFoldMasks::convertToUnmasked(MachineInstr &MI,
- MachineInstr *MaskDef) const {
+bool RISCVFoldMasks::convertToUnmasked(MachineInstr &MI) const {
const RISCV::RISCVMaskedPseudoInfo *I =
RISCV::getMaskedPseudoInfo(MI.getOpcode());
if (!I)
return false;
- if (!isAllOnesMask(MaskDef))
+ if (!isAllOnesMask(V0Defs.lookup(&MI)))
return false;
// There are two classes of pseudos in the table - compares and
@@ -198,20 +198,26 @@ bool RISCVFoldMasks::runOnMachineFunction(MachineFunction &MF) {
// $v0:vr = COPY %mask:vr
// %x:vr = Pseudo_MASK %a:vr, %b:br, $v0:vr
//
- // Because $v0 isn't in SSA, keep track of it so we can check the mask operand
- // on each pseudo.
- MachineInstr *CurrentV0Def;
- for (MachineBasicBlock &MBB : MF) {
- CurrentV0Def = nullptr;
- for (MachineInstr &MI : MBB) {
- Changed |= convertToUnmasked(MI, CurrentV0Def);
- Changed |= convertVMergeToVMv(MI, CurrentV0Def);
+ // Because $v0 isn't in SSA, keep track of its definition at each use so we
+ // can check mask operands.
+ for (const MachineBasicBlock &MBB : MF) {
+ const MachineInstr *CurrentV0Def = nullptr;
+ for (const MachineInstr &MI : MBB) {
+ if (MI.readsRegister(RISCV::V0, TRI))
+ V0Defs[&MI] = CurrentV0Def;
if (MI.definesRegister(RISCV::V0, TRI))
CurrentV0Def = &MI;
}
}
+ for (MachineBasicBlock &MBB : MF) {
+ for (MachineInstr &MI : MBB) {
+ Changed |= convertToUnmasked(MI);
+ Changed |= convertVMergeToVMv(MI);
+ }
+ }
+
return Changed;
}
More information about the llvm-commits
mailing list