[llvm] ac51cf1 - [RISCV] Refactor RISCV::hasAllWUsers to hasAllNBitUsers similar to RISCVISelDAGToDAG's version. NFC

Craig Topper via llvm-commits llvm-commits at lists.llvm.org
Wed Dec 28 12:49:31 PST 2022


Author: Craig Topper
Date: 2022-12-28T12:49:23-08:00
New Revision: ac51cf19604630919f54131ba1d9e3d9443f715c

URL: https://github.com/llvm/llvm-project/commit/ac51cf19604630919f54131ba1d9e3d9443f715c
DIFF: https://github.com/llvm/llvm-project/commit/ac51cf19604630919f54131ba1d9e3d9443f715c.diff

LOG: [RISCV] Refactor RISCV::hasAllWUsers to hasAllNBitUsers similar to RISCVISelDAGToDAG's version. NFC

Move it into RISCVInstrInfo as a member, since it now needs RISCVSubtarget.

Instead of asking whether only the lower 32 bits are used, we can now
ask whether only the lower N bits are used; hasAllWUsers remains as a
wrapper that passes N=32. This will be needed by a future patch.

Added: 
    

Modified: 
    llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
    llvm/lib/Target/RISCV/RISCVInstrInfo.h
    llvm/lib/Target/RISCV/RISCVSExtWRemoval.cpp
    llvm/lib/Target/RISCV/RISCVStripWSuffix.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
index 9eaef3cd3739..5f705f296445 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
@@ -2490,135 +2490,24 @@ void RISCVInstrInfo::getVLENFactoredAmount(MachineFunction &MF,
   }
 }
 
-// Returns true if this is the sext.w pattern, addiw rd, rs1, 0.
-bool RISCV::isSEXT_W(const MachineInstr &MI) {
-  return MI.getOpcode() == RISCV::ADDIW && MI.getOperand(1).isReg() &&
-         MI.getOperand(2).isImm() && MI.getOperand(2).getImm() == 0;
-}
-
-// Returns true if this is the zext.w pattern, adduw rd, rs1, x0.
-bool RISCV::isZEXT_W(const MachineInstr &MI) {
-  return MI.getOpcode() == RISCV::ADD_UW && MI.getOperand(1).isReg() &&
-         MI.getOperand(2).isReg() && MI.getOperand(2).getReg() == RISCV::X0;
-}
-
-// Returns true if this is the zext.b pattern, andi rd, rs1, 255.
-bool RISCV::isZEXT_B(const MachineInstr &MI) {
-  return MI.getOpcode() == RISCV::ANDI && MI.getOperand(1).isReg() &&
-         MI.getOperand(2).isImm() && MI.getOperand(2).getImm() == 255;
-}
-
-static bool isRVVWholeLoadStore(unsigned Opcode) {
-  switch (Opcode) {
-  default:
-    return false;
-  case RISCV::VS1R_V:
-  case RISCV::VS2R_V:
-  case RISCV::VS4R_V:
-  case RISCV::VS8R_V:
-  case RISCV::VL1RE8_V:
-  case RISCV::VL2RE8_V:
-  case RISCV::VL4RE8_V:
-  case RISCV::VL8RE8_V:
-  case RISCV::VL1RE16_V:
-  case RISCV::VL2RE16_V:
-  case RISCV::VL4RE16_V:
-  case RISCV::VL8RE16_V:
-  case RISCV::VL1RE32_V:
-  case RISCV::VL2RE32_V:
-  case RISCV::VL4RE32_V:
-  case RISCV::VL8RE32_V:
-  case RISCV::VL1RE64_V:
-  case RISCV::VL2RE64_V:
-  case RISCV::VL4RE64_V:
-  case RISCV::VL8RE64_V:
-    return true;
-  }
-}
-
-bool RISCV::isRVVSpill(const MachineInstr &MI) {
-  // RVV lacks any support for immediate addressing for stack addresses, so be
-  // conservative.
-  unsigned Opcode = MI.getOpcode();
-  if (!RISCVVPseudosTable::getPseudoInfo(Opcode) &&
-      !isRVVWholeLoadStore(Opcode) && !isRVVSpillForZvlsseg(Opcode))
-    return false;
-  return true;
-}
-
-std::optional<std::pair<unsigned, unsigned>>
-RISCV::isRVVSpillForZvlsseg(unsigned Opcode) {
-  switch (Opcode) {
-  default:
-    return std::nullopt;
-  case RISCV::PseudoVSPILL2_M1:
-  case RISCV::PseudoVRELOAD2_M1:
-    return std::make_pair(2u, 1u);
-  case RISCV::PseudoVSPILL2_M2:
-  case RISCV::PseudoVRELOAD2_M2:
-    return std::make_pair(2u, 2u);
-  case RISCV::PseudoVSPILL2_M4:
-  case RISCV::PseudoVRELOAD2_M4:
-    return std::make_pair(2u, 4u);
-  case RISCV::PseudoVSPILL3_M1:
-  case RISCV::PseudoVRELOAD3_M1:
-    return std::make_pair(3u, 1u);
-  case RISCV::PseudoVSPILL3_M2:
-  case RISCV::PseudoVRELOAD3_M2:
-    return std::make_pair(3u, 2u);
-  case RISCV::PseudoVSPILL4_M1:
-  case RISCV::PseudoVRELOAD4_M1:
-    return std::make_pair(4u, 1u);
-  case RISCV::PseudoVSPILL4_M2:
-  case RISCV::PseudoVRELOAD4_M2:
-    return std::make_pair(4u, 2u);
-  case RISCV::PseudoVSPILL5_M1:
-  case RISCV::PseudoVRELOAD5_M1:
-    return std::make_pair(5u, 1u);
-  case RISCV::PseudoVSPILL6_M1:
-  case RISCV::PseudoVRELOAD6_M1:
-    return std::make_pair(6u, 1u);
-  case RISCV::PseudoVSPILL7_M1:
-  case RISCV::PseudoVRELOAD7_M1:
-    return std::make_pair(7u, 1u);
-  case RISCV::PseudoVSPILL8_M1:
-  case RISCV::PseudoVRELOAD8_M1:
-    return std::make_pair(8u, 1u);
-  }
-}
-
-bool RISCV::isFaultFirstLoad(const MachineInstr &MI) {
-  return MI.getNumExplicitDefs() == 2 && MI.modifiesRegister(RISCV::VL) &&
-         !MI.isInlineAsm();
-}
-
-bool RISCV::hasEqualFRM(const MachineInstr &MI1, const MachineInstr &MI2) {
-  int16_t MI1FrmOpIdx =
-      RISCV::getNamedOperandIdx(MI1.getOpcode(), RISCV::OpName::frm);
-  int16_t MI2FrmOpIdx =
-      RISCV::getNamedOperandIdx(MI2.getOpcode(), RISCV::OpName::frm);
-  if (MI1FrmOpIdx < 0 || MI2FrmOpIdx < 0)
-    return false;
-  MachineOperand FrmOp1 = MI1.getOperand(MI1FrmOpIdx);
-  MachineOperand FrmOp2 = MI2.getOperand(MI2FrmOpIdx);
-  return FrmOp1.getImm() == FrmOp2.getImm();
-}
-
-// Checks if all users only demand the lower word of the original instruction's
-// result.
+// Checks if all users only demand the lower \p OrigBits of the original
+// instruction's result.
 // TODO: handle multiple interdependent transformations
-bool RISCV::hasAllWUsers(const MachineInstr &OrigMI,
-                         const MachineRegisterInfo &MRI) {
+bool RISCVInstrInfo::hasAllNBitUsers(const MachineInstr &OrigMI,
+                                     const MachineRegisterInfo &MRI,
+                                     unsigned OrigBits) const {
 
-  SmallPtrSet<const MachineInstr *, 4> Visited;
-  SmallVector<const MachineInstr *, 4> Worklist;
+  SmallSet<std::pair<const MachineInstr *, unsigned>, 4> Visited;
+  SmallVector<std::pair<const MachineInstr *, unsigned>, 4> Worklist;
 
-  Worklist.push_back(&OrigMI);
+  Worklist.push_back(std::make_pair(&OrigMI, OrigBits));
 
   while (!Worklist.empty()) {
-    const MachineInstr *MI = Worklist.pop_back_val();
+    auto P = Worklist.pop_back_val();
+    const MachineInstr *MI = P.first;
+    unsigned Bits = P.second;
 
-    if (!Visited.insert(MI).second)
+    if (!Visited.insert(P).second)
       continue;
 
     // Only handle instructions with one def.
@@ -2654,7 +2543,6 @@ bool RISCV::hasAllWUsers(const MachineInstr &OrigMI,
       case RISCV::CTZW:
       case RISCV::CPOPW:
       case RISCV::SLLI_UW:
-      case RISCV::FMV_H_X:
       case RISCV::FMV_W_X:
       case RISCV::FCVT_H_W:
       case RISCV::FCVT_H_WU:
@@ -2662,40 +2550,59 @@ bool RISCV::hasAllWUsers(const MachineInstr &OrigMI,
       case RISCV::FCVT_S_WU:
       case RISCV::FCVT_D_W:
       case RISCV::FCVT_D_WU:
+        if (Bits >= 32)
+          break;
+        return false;
       case RISCV::SEXT_B:
+      case RISCV::PACKH:
+        if (Bits >= 8)
+          break;
+        return false;
       case RISCV::SEXT_H:
+      case RISCV::FMV_H_X:
+      case RISCV::ZEXT_H_RV32:
       case RISCV::ZEXT_H_RV64:
-      case RISCV::PACK:
-      case RISCV::PACKH:
       case RISCV::PACKW:
-        break;
+        if (Bits >= 16)
+          break;
+        return false;
+
+      case RISCV::PACK:
+        if (Bits >= (STI.getXLen() / 2))
+          break;
+        return false;
 
       // these overwrite higher input bits, otherwise the lower word of output
       // depends only on the lower word of input. So check their uses read W.
       case RISCV::SLLI:
-        if (UserMI->getOperand(2).getImm() >= 32)
+        if (Bits >= (STI.getXLen() - UserMI->getOperand(2).getImm()))
           break;
-        Worklist.push_back(UserMI);
+        Worklist.push_back(std::make_pair(UserMI, Bits));
         break;
       case RISCV::ANDI:
-        if (isUInt<11>(UserMI->getOperand(2).getImm()))
+        if (Bits >=
+            (64 - countLeadingZeros((uint64_t)UserMI->getOperand(2).getImm())))
           break;
-        Worklist.push_back(UserMI);
+        Worklist.push_back(std::make_pair(UserMI, Bits));
         break;
       case RISCV::ORI:
-        if (!isUInt<11>(UserMI->getOperand(2).getImm()))
+        if (Bits >=
+            (64 - countLeadingOnes((uint64_t)UserMI->getOperand(2).getImm())))
           break;
-        Worklist.push_back(UserMI);
+        Worklist.push_back(std::make_pair(UserMI, Bits));
         break;
 
       case RISCV::SLL:
       case RISCV::BSET:
       case RISCV::BCLR:
       case RISCV::BINV:
-        // Operand 2 is the shift amount which uses 6 bits.
-        if (OpIdx == 2)
-          break;
-        Worklist.push_back(UserMI);
+        // Operand 2 is the shift amount which uses log2(xlen) bits.
+        if (OpIdx == 2) {
+          if (Bits >= Log2_32(STI.getXLen()))
+            break;
+          return false;
+        }
+        Worklist.push_back(std::make_pair(UserMI, Bits));
         break;
 
       case RISCV::SRA:
@@ -2703,7 +2610,7 @@ bool RISCV::hasAllWUsers(const MachineInstr &OrigMI,
       case RISCV::ROL:
       case RISCV::ROR:
         // Operand 2 is the shift amount which uses 6 bits.
-        if (OpIdx == 2)
+        if (OpIdx == 2 && Bits >= Log2_32(STI.getXLen()))
           break;
         return false;
 
@@ -2712,23 +2619,31 @@ bool RISCV::hasAllWUsers(const MachineInstr &OrigMI,
       case RISCV::SH2ADD_UW:
       case RISCV::SH3ADD_UW:
         // Operand 1 is implicitly zero extended.
-        if (OpIdx == 1)
+        if (OpIdx == 1 && Bits >= 32)
           break;
-        Worklist.push_back(UserMI);
+        Worklist.push_back(std::make_pair(UserMI, Bits));
         break;
 
       case RISCV::BEXTI:
-        if (UserMI->getOperand(2).getImm() >= 32)
+        if (UserMI->getOperand(2).getImm() >= Bits)
           return false;
         break;
 
       case RISCV::SB:
+        // The first argument is the value to store.
+        if (OpIdx == 0 && Bits >= 8)
+          break;
+        return false;
       case RISCV::SH:
+        // The first argument is the value to store.
+        if (OpIdx == 0 && Bits >= 16)
+          break;
+        return false;
       case RISCV::SW:
         // The first argument is the value to store.
-        if (OpIdx != 0)
-          return false;
-        break;
+        if (OpIdx == 0 && Bits >= 32)
+          break;
+        return false;
 
       // For these, lower word of output in these operations, depends only on
       // the lower word of input. So, we check all uses only read lower word.
@@ -2756,7 +2671,7 @@ bool RISCV::hasAllWUsers(const MachineInstr &OrigMI,
       case RISCV::BSETI:
       case RISCV::BCLRI:
       case RISCV::BINVI:
-        Worklist.push_back(UserMI);
+        Worklist.push_back(std::make_pair(UserMI, Bits));
         break;
 
       case RISCV::PseudoCCMOVGPR:
@@ -2765,14 +2680,14 @@ bool RISCV::hasAllWUsers(const MachineInstr &OrigMI,
         // of operand 4 and 5 is used.
         if (OpIdx != 4 && OpIdx != 5)
           return false;
-        Worklist.push_back(UserMI);
+        Worklist.push_back(std::make_pair(UserMI, Bits));
         break;
 
       case RISCV::VT_MASKC:
       case RISCV::VT_MASKCN:
         if (OpIdx != 1)
           return false;
-        Worklist.push_back(UserMI);
+        Worklist.push_back(std::make_pair(UserMI, Bits));
         break;
       }
     }
@@ -2780,3 +2695,117 @@ bool RISCV::hasAllWUsers(const MachineInstr &OrigMI,
 
   return true;
 }
+
+// Returns true if MI is the sext.w idiom: addiw rd, rs1, 0 (rs1 a register).
+bool RISCV::isSEXT_W(const MachineInstr &MI) {
+  return MI.getOpcode() == RISCV::ADDIW && MI.getOperand(1).isReg() &&
+         MI.getOperand(2).isImm() && MI.getOperand(2).getImm() == 0;
+}
+
+// Returns true if MI is the zext.w idiom: add.uw rd, rs1, x0.
+bool RISCV::isZEXT_W(const MachineInstr &MI) {
+  return MI.getOpcode() == RISCV::ADD_UW && MI.getOperand(1).isReg() &&
+         MI.getOperand(2).isReg() && MI.getOperand(2).getReg() == RISCV::X0;
+}
+
+// Returns true if MI is the zext.b idiom: andi rd, rs1, 255 (rs1 a register).
+bool RISCV::isZEXT_B(const MachineInstr &MI) {
+  return MI.getOpcode() == RISCV::ANDI && MI.getOperand(1).isReg() &&
+         MI.getOperand(2).isImm() && MI.getOperand(2).getImm() == 255;
+}
+
+static bool isRVVWholeLoadStore(unsigned Opcode) {
+  switch (Opcode) {
+  default:
+    return false;
+  case RISCV::VS1R_V:
+  case RISCV::VS2R_V:
+  case RISCV::VS4R_V:
+  case RISCV::VS8R_V:
+  case RISCV::VL1RE8_V:
+  case RISCV::VL2RE8_V:
+  case RISCV::VL4RE8_V:
+  case RISCV::VL8RE8_V:
+  case RISCV::VL1RE16_V:
+  case RISCV::VL2RE16_V:
+  case RISCV::VL4RE16_V:
+  case RISCV::VL8RE16_V:
+  case RISCV::VL1RE32_V:
+  case RISCV::VL2RE32_V:
+  case RISCV::VL4RE32_V:
+  case RISCV::VL8RE32_V:
+  case RISCV::VL1RE64_V:
+  case RISCV::VL2RE64_V:
+  case RISCV::VL4RE64_V:
+  case RISCV::VL8RE64_V:
+    return true; // Whole-register VS<N>R_V / VL<N>RE<EEW>_V opcodes.
+  }
+}
+
+bool RISCV::isRVVSpill(const MachineInstr &MI) {
+  // RVV has no immediate addressing for stack addresses, so conservatively
+  // treat any RVV pseudo, whole-register load/store or Zvlsseg spill as one.
+  unsigned Opcode = MI.getOpcode();
+  if (!RISCVVPseudosTable::getPseudoInfo(Opcode) &&
+      !isRVVWholeLoadStore(Opcode) && !isRVVSpillForZvlsseg(Opcode))
+    return false;
+  return true;
+}
+
+std::optional<std::pair<unsigned, unsigned>>
+RISCV::isRVVSpillForZvlsseg(unsigned Opcode) {
+  switch (Opcode) {
+  default:
+    return std::nullopt; // Not PseudoVSPILL<N>_M<M> / PseudoVRELOAD<N>_M<M>.
+  case RISCV::PseudoVSPILL2_M1:
+  case RISCV::PseudoVRELOAD2_M1:
+    return std::make_pair(2u, 1u);
+  case RISCV::PseudoVSPILL2_M2:
+  case RISCV::PseudoVRELOAD2_M2:
+    return std::make_pair(2u, 2u);
+  case RISCV::PseudoVSPILL2_M4:
+  case RISCV::PseudoVRELOAD2_M4:
+    return std::make_pair(2u, 4u);
+  case RISCV::PseudoVSPILL3_M1:
+  case RISCV::PseudoVRELOAD3_M1:
+    return std::make_pair(3u, 1u);
+  case RISCV::PseudoVSPILL3_M2:
+  case RISCV::PseudoVRELOAD3_M2:
+    return std::make_pair(3u, 2u);
+  case RISCV::PseudoVSPILL4_M1:
+  case RISCV::PseudoVRELOAD4_M1:
+    return std::make_pair(4u, 1u);
+  case RISCV::PseudoVSPILL4_M2:
+  case RISCV::PseudoVRELOAD4_M2:
+    return std::make_pair(4u, 2u);
+  case RISCV::PseudoVSPILL5_M1:
+  case RISCV::PseudoVRELOAD5_M1:
+    return std::make_pair(5u, 1u);
+  case RISCV::PseudoVSPILL6_M1:
+  case RISCV::PseudoVRELOAD6_M1:
+    return std::make_pair(6u, 1u);
+  case RISCV::PseudoVSPILL7_M1:
+  case RISCV::PseudoVRELOAD7_M1:
+    return std::make_pair(7u, 1u);
+  case RISCV::PseudoVSPILL8_M1:
+  case RISCV::PseudoVRELOAD8_M1:
+    return std::make_pair(8u, 1u);
+  }
+}
+
+bool RISCV::isFaultFirstLoad(const MachineInstr &MI) {
+  return MI.getNumExplicitDefs() == 2 && MI.modifiesRegister(RISCV::VL) &&
+         !MI.isInlineAsm();
+}
+
+bool RISCV::hasEqualFRM(const MachineInstr &MI1, const MachineInstr &MI2) {
+  int16_t MI1FrmOpIdx =
+      RISCV::getNamedOperandIdx(MI1.getOpcode(), RISCV::OpName::frm);
+  int16_t MI2FrmOpIdx =
+      RISCV::getNamedOperandIdx(MI2.getOpcode(), RISCV::OpName::frm);
+  if (MI1FrmOpIdx < 0 || MI2FrmOpIdx < 0)
+    return false; // At least one of them has no frm operand.
+  const MachineOperand &FrmOp1 = MI1.getOperand(MI1FrmOpIdx);
+  const MachineOperand &FrmOp2 = MI2.getOperand(MI2FrmOpIdx);
+  return FrmOp1.getImm() == FrmOp2.getImm(); // Compare the frm immediates.
+}

diff  --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.h b/llvm/lib/Target/RISCV/RISCVInstrInfo.h
index e03582efc652..c663af75a557 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.h
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.h
@@ -220,6 +220,17 @@ class RISCVInstrInfo : public RISCVGenInstrInfo {
 
   std::optional<unsigned> getInverseOpcode(unsigned Opcode) const override;
 
+  // Returns true if all uses of \p MI only depend on the lower \p NBits bits
+  // of its result.
+  bool hasAllNBitUsers(const MachineInstr &MI, const MachineRegisterInfo &MRI,
+                       unsigned NBits) const;
+  // Returns true if all uses of \p MI only depend on the lower 32 bits of its
+  // result, so \p MI can be transformed to the corresponding W-version.
+  bool hasAllWUsers(const MachineInstr &MI,
+                    const MachineRegisterInfo &MRI) const {
+    return hasAllNBitUsers(MI, MRI, 32);
+  }
+
 protected:
   const RISCVSubtarget &STI;
 };
@@ -250,9 +261,6 @@ bool hasEqualFRM(const MachineInstr &MI1, const MachineInstr &MI2);
 // Special immediate for AVL operand of V pseudo instructions to indicate VLMax.
 static constexpr int64_t VLMaxSentinel = -1LL;
 
-// Returns true if all uses of OrigMI only depend on the lower word of its
-// output, so we can transform OrigMI to the corresponding W-version.
-bool hasAllWUsers(const MachineInstr &MI, const MachineRegisterInfo &MRI);
 } // namespace RISCV
 
 namespace RISCVVPseudosTable {

diff  --git a/llvm/lib/Target/RISCV/RISCVSExtWRemoval.cpp b/llvm/lib/Target/RISCV/RISCVSExtWRemoval.cpp
index 061b4defef1c..2ee228d72825 100644
--- a/llvm/lib/Target/RISCV/RISCVSExtWRemoval.cpp
+++ b/llvm/lib/Target/RISCV/RISCVSExtWRemoval.cpp
@@ -95,6 +95,7 @@ static bool isSignExtendingOpW(const MachineInstr &MI,
 }
 
 static bool isSignExtendedW(Register SrcReg, const MachineRegisterInfo &MRI,
+                            const RISCVInstrInfo &TII,
                             SmallPtrSetImpl<MachineInstr *> &FixableDef) {
 
   SmallPtrSet<const MachineInstr *, 4> Visited;
@@ -282,7 +283,7 @@ static bool isSignExtendedW(Register SrcReg, const MachineRegisterInfo &MRI,
     case RISCV::LWU:
     case RISCV::MUL:
     case RISCV::SUB:
-      if (RISCV::hasAllWUsers(*MI, MRI)) {
+      if (TII.hasAllWUsers(*MI, MRI)) {
         FixableDef.insert(MI);
         break;
       }
@@ -343,8 +344,8 @@ bool RISCVSExtWRemoval::runOnMachineFunction(MachineFunction &MF) {
       // If all users only use the lower bits, this sext.w is redundant.
       // Or if all definitions reaching MI sign-extend their output,
       // then sext.w is redundant.
-      if (!RISCV::hasAllWUsers(*MI, MRI) &&
-          !isSignExtendedW(SrcReg, MRI, FixableDefs))
+      if (!TII.hasAllWUsers(*MI, MRI) &&
+          !isSignExtendedW(SrcReg, MRI, TII, FixableDefs))
         continue;
 
       Register DstReg = MI->getOperand(0).getReg();

diff  --git a/llvm/lib/Target/RISCV/RISCVStripWSuffix.cpp b/llvm/lib/Target/RISCV/RISCVStripWSuffix.cpp
index e818cc5459d1..14ab9c2dd655 100644
--- a/llvm/lib/Target/RISCV/RISCVStripWSuffix.cpp
+++ b/llvm/lib/Target/RISCV/RISCVStripWSuffix.cpp
@@ -72,7 +72,7 @@ bool RISCVStripWSuffix::runOnMachineFunction(MachineFunction &MF) {
       switch (MI.getOpcode()) {
       case RISCV::ADDW:
       case RISCV::SLLIW:
-        if (RISCV::hasAllWUsers(MI, MRI)) {
+        if (TII.hasAllWUsers(MI, MRI)) {
           unsigned Opc =
               MI.getOpcode() == RISCV::ADDW ? RISCV::ADD : RISCV::SLLI;
           MI.setDesc(TII.get(Opc));


        


More information about the llvm-commits mailing list