[llvm] 3d7fa6d - [RISCV] Move allWUsers from RISCVInstrInfo to RISCVOptWInstrs.
Craig Topper via llvm-commits
llvm-commits at lists.llvm.org
Wed Mar 29 15:16:54 PDT 2023
Author: Craig Topper
Date: 2023-03-29T15:13:09-07:00
New Revision: 3d7fa6dc7cb0b4f6a52500e52faa360cb95f1406
URL: https://github.com/llvm/llvm-project/commit/3d7fa6dc7cb0b4f6a52500e52faa360cb95f1406
DIFF: https://github.com/llvm/llvm-project/commit/3d7fa6dc7cb0b4f6a52500e52faa360cb95f1406.diff
LOG: [RISCV] Move allWUsers from RISCVInstrInfo to RISCVOptWInstrs.
It was only in RISCVInstrInfo because it was used by two passes; those
passes have been merged in D147173, so it can now be a static helper in
RISCVOptWInstrs.
Reviewed By: asb
Differential Revision: https://reviews.llvm.org/D147174
Added:
Modified:
llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
llvm/lib/Target/RISCV/RISCVInstrInfo.h
llvm/lib/Target/RISCV/RISCVOptWInstrs.cpp
Removed:
################################################################################
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
index 523895c69794..2ad5b814ecf1 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
@@ -2614,226 +2614,6 @@ void RISCVInstrInfo::getVLENFactoredAmount(MachineFunction &MF,
}
}
-// Checks if all users only demand the lower \p OrigBits of the original
-// instruction's result.
-// TODO: handle multiple interdependent transformations
-bool RISCVInstrInfo::hasAllNBitUsers(const MachineInstr &OrigMI,
- const MachineRegisterInfo &MRI,
- unsigned OrigBits) const {
-
- SmallSet<std::pair<const MachineInstr *, unsigned>, 4> Visited;
- SmallVector<std::pair<const MachineInstr *, unsigned>, 4> Worklist;
-
- Worklist.push_back(std::make_pair(&OrigMI, OrigBits));
-
- while (!Worklist.empty()) {
- auto P = Worklist.pop_back_val();
- const MachineInstr *MI = P.first;
- unsigned Bits = P.second;
-
- if (!Visited.insert(P).second)
- continue;
-
- // Only handle instructions with one def.
- if (MI->getNumExplicitDefs() != 1)
- return false;
-
- for (auto &UserOp : MRI.use_operands(MI->getOperand(0).getReg())) {
- const MachineInstr *UserMI = UserOp.getParent();
- unsigned OpIdx = UserOp.getOperandNo();
-
- switch (UserMI->getOpcode()) {
- default:
- return false;
-
- case RISCV::ADDIW:
- case RISCV::ADDW:
- case RISCV::DIVUW:
- case RISCV::DIVW:
- case RISCV::MULW:
- case RISCV::REMUW:
- case RISCV::REMW:
- case RISCV::SLLIW:
- case RISCV::SLLW:
- case RISCV::SRAIW:
- case RISCV::SRAW:
- case RISCV::SRLIW:
- case RISCV::SRLW:
- case RISCV::SUBW:
- case RISCV::ROLW:
- case RISCV::RORW:
- case RISCV::RORIW:
- case RISCV::CLZW:
- case RISCV::CTZW:
- case RISCV::CPOPW:
- case RISCV::SLLI_UW:
- case RISCV::FMV_W_X:
- case RISCV::FCVT_H_W:
- case RISCV::FCVT_H_WU:
- case RISCV::FCVT_S_W:
- case RISCV::FCVT_S_WU:
- case RISCV::FCVT_D_W:
- case RISCV::FCVT_D_WU:
- if (Bits >= 32)
- break;
- return false;
- case RISCV::SEXT_B:
- case RISCV::PACKH:
- if (Bits >= 8)
- break;
- return false;
- case RISCV::SEXT_H:
- case RISCV::FMV_H_X:
- case RISCV::ZEXT_H_RV32:
- case RISCV::ZEXT_H_RV64:
- case RISCV::PACKW:
- if (Bits >= 16)
- break;
- return false;
-
- case RISCV::PACK:
- if (Bits >= (STI.getXLen() / 2))
- break;
- return false;
-
- case RISCV::SRLI: {
- // If we are shifting right by less than Bits, and users don't demand
- // any bits that were shifted into [Bits-1:0], then we can consider this
- // as an N-Bit user.
- unsigned ShAmt = UserMI->getOperand(2).getImm();
- if (Bits > ShAmt) {
- Worklist.push_back(std::make_pair(UserMI, Bits - ShAmt));
- break;
- }
- return false;
- }
-
- // these overwrite higher input bits, otherwise the lower word of output
- // depends only on the lower word of input. So check their uses read W.
- case RISCV::SLLI:
- if (Bits >= (STI.getXLen() - UserMI->getOperand(2).getImm()))
- break;
- Worklist.push_back(std::make_pair(UserMI, Bits));
- break;
- case RISCV::ANDI: {
- uint64_t Imm = UserMI->getOperand(2).getImm();
- if (Bits >= (unsigned)llvm::bit_width(Imm))
- break;
- Worklist.push_back(std::make_pair(UserMI, Bits));
- break;
- }
- case RISCV::ORI: {
- uint64_t Imm = UserMI->getOperand(2).getImm();
- if (Bits >= (unsigned)llvm::bit_width<uint64_t>(~Imm))
- break;
- Worklist.push_back(std::make_pair(UserMI, Bits));
- break;
- }
-
- case RISCV::SLL:
- case RISCV::BSET:
- case RISCV::BCLR:
- case RISCV::BINV:
- // Operand 2 is the shift amount which uses log2(xlen) bits.
- if (OpIdx == 2) {
- if (Bits >= Log2_32(STI.getXLen()))
- break;
- return false;
- }
- Worklist.push_back(std::make_pair(UserMI, Bits));
- break;
-
- case RISCV::SRA:
- case RISCV::SRL:
- case RISCV::ROL:
- case RISCV::ROR:
- // Operand 2 is the shift amount which uses 6 bits.
- if (OpIdx == 2 && Bits >= Log2_32(STI.getXLen()))
- break;
- return false;
-
- case RISCV::ADD_UW:
- case RISCV::SH1ADD_UW:
- case RISCV::SH2ADD_UW:
- case RISCV::SH3ADD_UW:
- // Operand 1 is implicitly zero extended.
- if (OpIdx == 1 && Bits >= 32)
- break;
- Worklist.push_back(std::make_pair(UserMI, Bits));
- break;
-
- case RISCV::BEXTI:
- if (UserMI->getOperand(2).getImm() >= Bits)
- return false;
- break;
-
- case RISCV::SB:
- // The first argument is the value to store.
- if (OpIdx == 0 && Bits >= 8)
- break;
- return false;
- case RISCV::SH:
- // The first argument is the value to store.
- if (OpIdx == 0 && Bits >= 16)
- break;
- return false;
- case RISCV::SW:
- // The first argument is the value to store.
- if (OpIdx == 0 && Bits >= 32)
- break;
- return false;
-
- // For these, lower word of output in these operations, depends only on
- // the lower word of input. So, we check all uses only read lower word.
- case RISCV::COPY:
- case RISCV::PHI:
-
- case RISCV::ADD:
- case RISCV::ADDI:
- case RISCV::AND:
- case RISCV::MUL:
- case RISCV::OR:
- case RISCV::SUB:
- case RISCV::XOR:
- case RISCV::XORI:
-
- case RISCV::ANDN:
- case RISCV::BREV8:
- case RISCV::CLMUL:
- case RISCV::ORC_B:
- case RISCV::ORN:
- case RISCV::SH1ADD:
- case RISCV::SH2ADD:
- case RISCV::SH3ADD:
- case RISCV::XNOR:
- case RISCV::BSETI:
- case RISCV::BCLRI:
- case RISCV::BINVI:
- Worklist.push_back(std::make_pair(UserMI, Bits));
- break;
-
- case RISCV::PseudoCCMOVGPR:
- // Either operand 4 or operand 5 is returned by this instruction. If
- // only the lower word of the result is used, then only the lower word
- // of operand 4 and 5 is used.
- if (OpIdx != 4 && OpIdx != 5)
- return false;
- Worklist.push_back(std::make_pair(UserMI, Bits));
- break;
-
- case RISCV::VT_MASKC:
- case RISCV::VT_MASKCN:
- if (OpIdx != 1)
- return false;
- Worklist.push_back(std::make_pair(UserMI, Bits));
- break;
- }
- }
- }
-
- return true;
-}
-
// Returns true if this is the sext.w pattern, addiw rd, rs1, 0.
bool RISCV::isSEXT_W(const MachineInstr &MI) {
return MI.getOpcode() == RISCV::ADDIW && MI.getOperand(1).isReg() &&
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.h b/llvm/lib/Target/RISCV/RISCVInstrInfo.h
index 64e0bc0cd550..01f112a386d0 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.h
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.h
@@ -227,17 +227,6 @@ class RISCVInstrInfo : public RISCVGenInstrInfo {
std::optional<unsigned> getInverseOpcode(unsigned Opcode) const override;
- // Returns true if all uses of OrigMI only depend on the lower \p NBits bits
- // of its output.
- bool hasAllNBitUsers(const MachineInstr &MI, const MachineRegisterInfo &MRI,
- unsigned NBits) const;
- // Returns true if all uses of OrigMI only depend on the lower word of its
- // output, so we can transform OrigMI to the corresponding W-version.
- bool hasAllWUsers(const MachineInstr &MI,
- const MachineRegisterInfo &MRI) const {
- return hasAllNBitUsers(MI, MRI, 32);
- }
-
protected:
const RISCVSubtarget &STI;
};
diff --git a/llvm/lib/Target/RISCV/RISCVOptWInstrs.cpp b/llvm/lib/Target/RISCV/RISCVOptWInstrs.cpp
index 40fe5f9987dd..7014755b6706 100644
--- a/llvm/lib/Target/RISCV/RISCVOptWInstrs.cpp
+++ b/llvm/lib/Target/RISCV/RISCVOptWInstrs.cpp
@@ -54,9 +54,9 @@ class RISCVOptWInstrs : public MachineFunctionPass {
bool runOnMachineFunction(MachineFunction &MF) override;
bool removeSExtWInstrs(MachineFunction &MF, const RISCVInstrInfo &TII,
- MachineRegisterInfo &MRI);
+ const RISCVSubtarget &ST, MachineRegisterInfo &MRI);
bool stripWSuffixes(MachineFunction &MF, const RISCVInstrInfo &TII,
- MachineRegisterInfo &MRI);
+ const RISCVSubtarget &ST, MachineRegisterInfo &MRI);
void getAnalysisUsage(AnalysisUsage &AU) const override {
AU.setPreservesCFG();
@@ -76,6 +76,231 @@ FunctionPass *llvm::createRISCVOptWInstrsPass() {
return new RISCVOptWInstrs();
}
+// Checks if all users only demand the lower \p OrigBits of OrigMI's result,
+// looking through users whose low bits depend only on their inputs' low bits.
+// TODO: handle multiple interdependent transformations
+static bool hasAllNBitUsers(const MachineInstr &OrigMI,
+ const RISCVSubtarget &ST,
+ const MachineRegisterInfo &MRI, unsigned OrigBits) {
+
+ SmallSet<std::pair<const MachineInstr *, unsigned>, 4> Visited;
+ SmallVector<std::pair<const MachineInstr *, unsigned>, 4> Worklist;
+
+ Worklist.push_back(std::make_pair(&OrigMI, OrigBits));
+
+ while (!Worklist.empty()) {
+ auto P = Worklist.pop_back_val();
+ const MachineInstr *MI = P.first;
+ unsigned Bits = P.second;
+
+ if (!Visited.insert(P).second)
+ continue;
+
+ // Only handle instructions with exactly one explicit def (operand 0).
+ if (MI->getNumExplicitDefs() != 1)
+ return false;
+
+ for (auto &UserOp : MRI.use_operands(MI->getOperand(0).getReg())) {
+ const MachineInstr *UserMI = UserOp.getParent();
+ unsigned OpIdx = UserOp.getOperandNo();
+
+ switch (UserMI->getOpcode()) {
+ default:
+ return false;
+
+ case RISCV::ADDIW:
+ case RISCV::ADDW:
+ case RISCV::DIVUW:
+ case RISCV::DIVW:
+ case RISCV::MULW:
+ case RISCV::REMUW:
+ case RISCV::REMW:
+ case RISCV::SLLIW:
+ case RISCV::SLLW:
+ case RISCV::SRAIW:
+ case RISCV::SRAW:
+ case RISCV::SRLIW:
+ case RISCV::SRLW:
+ case RISCV::SUBW:
+ case RISCV::ROLW:
+ case RISCV::RORW:
+ case RISCV::RORIW:
+ case RISCV::CLZW:
+ case RISCV::CTZW:
+ case RISCV::CPOPW:
+ case RISCV::SLLI_UW:
+ case RISCV::FMV_W_X:
+ case RISCV::FCVT_H_W:
+ case RISCV::FCVT_H_WU:
+ case RISCV::FCVT_S_W:
+ case RISCV::FCVT_S_WU:
+ case RISCV::FCVT_D_W:
+ case RISCV::FCVT_D_WU:
+ if (Bits >= 32)
+ break;
+ return false;
+ case RISCV::SEXT_B:
+ case RISCV::PACKH:
+ if (Bits >= 8)
+ break;
+ return false;
+ case RISCV::SEXT_H:
+ case RISCV::FMV_H_X:
+ case RISCV::ZEXT_H_RV32:
+ case RISCV::ZEXT_H_RV64:
+ case RISCV::PACKW:
+ if (Bits >= 16)
+ break;
+ return false;
+
+ case RISCV::PACK:
+ if (Bits >= (ST.getXLen() / 2))
+ break;
+ return false;
+
+ case RISCV::SRLI: {
+ // Input bits [Bits-1:ShAmt] land in result bits [Bits-ShAmt-1:0]. If
+ // ShAmt < Bits and the SRLI's own users only demand those low Bits-ShAmt
+ // bits, this is still an N-bit user of the input.
+ unsigned ShAmt = UserMI->getOperand(2).getImm();
+ if (Bits > ShAmt) {
+ Worklist.push_back(std::make_pair(UserMI, Bits - ShAmt));
+ break;
+ }
+ return false;
+ }
+
+ // SLLI/ANDI/ORI: if every input bit at or above Bits is shifted out or
+ // forced by the immediate, we are done; otherwise recurse with Bits.
+ case RISCV::SLLI:
+ if (Bits >= (ST.getXLen() - UserMI->getOperand(2).getImm()))
+ break;
+ Worklist.push_back(std::make_pair(UserMI, Bits));
+ break;
+ case RISCV::ANDI: {
+ uint64_t Imm = UserMI->getOperand(2).getImm();
+ if (Bits >= (unsigned)llvm::bit_width(Imm))
+ break;
+ Worklist.push_back(std::make_pair(UserMI, Bits));
+ break;
+ }
+ case RISCV::ORI: {
+ uint64_t Imm = UserMI->getOperand(2).getImm();
+ if (Bits >= (unsigned)llvm::bit_width<uint64_t>(~Imm))
+ break;
+ Worklist.push_back(std::make_pair(UserMI, Bits));
+ break;
+ }
+
+ case RISCV::SLL:
+ case RISCV::BSET:
+ case RISCV::BCLR:
+ case RISCV::BINV:
+ // Operand 2 is the shift amount/bit index; only log2(XLEN) bits are read.
+ if (OpIdx == 2) {
+ if (Bits >= Log2_32(ST.getXLen()))
+ break;
+ return false;
+ }
+ Worklist.push_back(std::make_pair(UserMI, Bits));
+ break;
+
+ case RISCV::SRA:
+ case RISCV::SRL:
+ case RISCV::ROL:
+ case RISCV::ROR:
+ // Only operand 2 (the shift amount, log2(XLEN) bits) is a narrow use.
+ if (OpIdx == 2 && Bits >= Log2_32(ST.getXLen()))
+ break;
+ return false;
+
+ case RISCV::ADD_UW:
+ case RISCV::SH1ADD_UW:
+ case RISCV::SH2ADD_UW:
+ case RISCV::SH3ADD_UW:
+ // Operand 1 is implicitly zero-extended from its low 32 bits.
+ if (OpIdx == 1 && Bits >= 32)
+ break;
+ Worklist.push_back(std::make_pair(UserMI, Bits));
+ break;
+
+ case RISCV::BEXTI:
+ if (UserMI->getOperand(2).getImm() >= Bits)
+ return false;
+ break;
+
+ case RISCV::SB:
+ // Operand 0 is the stored value; any other operand use is rejected.
+ if (OpIdx == 0 && Bits >= 8)
+ break;
+ return false;
+ case RISCV::SH:
+ // Operand 0 is the stored value; any other operand use is rejected.
+ if (OpIdx == 0 && Bits >= 16)
+ break;
+ return false;
+ case RISCV::SW:
+ // Operand 0 is the stored value; any other operand use is rejected.
+ if (OpIdx == 0 && Bits >= 32)
+ break;
+ return false;
+
+ // For these, the low Bits of the output depend only on the low Bits of
+ // the inputs, so recurse to check that their users only read low Bits.
+ case RISCV::COPY:
+ case RISCV::PHI:
+
+ case RISCV::ADD:
+ case RISCV::ADDI:
+ case RISCV::AND:
+ case RISCV::MUL:
+ case RISCV::OR:
+ case RISCV::SUB:
+ case RISCV::XOR:
+ case RISCV::XORI:
+
+ case RISCV::ANDN:
+ case RISCV::BREV8:
+ case RISCV::CLMUL:
+ case RISCV::ORC_B:
+ case RISCV::ORN:
+ case RISCV::SH1ADD:
+ case RISCV::SH2ADD:
+ case RISCV::SH3ADD:
+ case RISCV::XNOR:
+ case RISCV::BSETI:
+ case RISCV::BCLRI:
+ case RISCV::BINVI:
+ Worklist.push_back(std::make_pair(UserMI, Bits));
+ break;
+
+ case RISCV::PseudoCCMOVGPR:
+ // Either operand 4 or operand 5 is returned by this instruction. If only
+ // the low Bits of the result are used, only the low Bits of operands 4
+ // and 5 are used; a use as any other operand is rejected.
+ if (OpIdx != 4 && OpIdx != 5)
+ return false;
+ Worklist.push_back(std::make_pair(UserMI, Bits));
+ break;
+
+ case RISCV::VT_MASKC:
+ case RISCV::VT_MASKCN:
+ if (OpIdx != 1)
+ return false;
+ Worklist.push_back(std::make_pair(UserMI, Bits));
+ break;
+ }
+ }
+ }
+
+ return true;
+}
+// Returns true if all users of OrigMI only read the low 32 bits of its result.
+static bool hasAllWUsers(const MachineInstr &OrigMI, const RISCVSubtarget &ST,
+ const MachineRegisterInfo &MRI) {
+ return hasAllNBitUsers(OrigMI, ST, MRI, 32);
+}
+
// This function returns true if the machine instruction always outputs a value
// where bits 63:32 match bit 31.
static bool isSignExtendingOpW(const MachineInstr &MI,
@@ -110,8 +335,8 @@ static bool isSignExtendingOpW(const MachineInstr &MI,
return false;
}
-static bool isSignExtendedW(Register SrcReg, const MachineRegisterInfo &MRI,
- const RISCVInstrInfo &TII,
+static bool isSignExtendedW(Register SrcReg, const RISCVSubtarget &ST,
+ const MachineRegisterInfo &MRI,
SmallPtrSetImpl<MachineInstr *> &FixableDef) {
SmallPtrSet<const MachineInstr *, 4> Visited;
@@ -300,7 +525,7 @@ static bool isSignExtendedW(Register SrcReg, const MachineRegisterInfo &MRI,
case RISCV::LWU:
case RISCV::MUL:
case RISCV::SUB:
- if (TII.hasAllWUsers(*MI, MRI)) {
+ if (hasAllWUsers(*MI, ST, MRI)) {
FixableDef.insert(MI);
break;
}
@@ -335,6 +560,7 @@ static unsigned getWOp(unsigned Opcode) {
bool RISCVOptWInstrs::removeSExtWInstrs(MachineFunction &MF,
const RISCVInstrInfo &TII,
+ const RISCVSubtarget &ST,
MachineRegisterInfo &MRI) {
if (DisableSExtWRemoval)
return false;
@@ -355,8 +581,8 @@ bool RISCVOptWInstrs::removeSExtWInstrs(MachineFunction &MF,
// If all users only use the lower bits, this sext.w is redundant.
// Or if all definitions reaching MI sign-extend their output,
// then sext.w is redundant.
- if (!TII.hasAllWUsers(*MI, MRI) &&
- !isSignExtendedW(SrcReg, MRI, TII, FixableDefs))
+ if (!hasAllWUsers(*MI, ST, MRI) &&
+ !isSignExtendedW(SrcReg, ST, MRI, FixableDefs))
continue;
Register DstReg = MI->getOperand(0).getReg();
@@ -388,6 +614,7 @@ bool RISCVOptWInstrs::removeSExtWInstrs(MachineFunction &MF,
bool RISCVOptWInstrs::stripWSuffixes(MachineFunction &MF,
const RISCVInstrInfo &TII,
+ const RISCVSubtarget &ST,
MachineRegisterInfo &MRI) {
if (DisableStripWSuffix)
return false;
@@ -406,7 +633,7 @@ bool RISCVOptWInstrs::stripWSuffixes(MachineFunction &MF,
case RISCV::SLLIW: Opc = RISCV::SLLI; break;
}
- if (TII.hasAllWUsers(MI, MRI)) {
+ if (hasAllWUsers(MI, ST, MRI)) {
MI.setDesc(TII.get(Opc));
MadeChange = true;
}
@@ -428,8 +655,8 @@ bool RISCVOptWInstrs::runOnMachineFunction(MachineFunction &MF) {
return false;
bool MadeChange = false;
- MadeChange |= removeSExtWInstrs(MF, TII, MRI);
- MadeChange |= stripWSuffixes(MF, TII, MRI);
+ MadeChange |= removeSExtWInstrs(MF, TII, ST, MRI);
+ MadeChange |= stripWSuffixes(MF, TII, ST, MRI);
return MadeChange;
}
More information about the llvm-commits mailing list