[llvm] [RISCV] Porting hasAllNBitUsers to RISCV GISel for instruction select (PR #124678)

Craig Topper via llvm-commits llvm-commits at lists.llvm.org
Wed Jan 29 08:33:12 PST 2025


================
@@ -186,6 +201,169 @@ RISCVInstructionSelector::RISCVInstructionSelector(
 {
 }
 
+bool RISCVInstructionSelector::hasAllNBitUsers(const MachineInstr &MI,
+                                               unsigned Bits,
+                                               const unsigned Depth) const {
+
+  assert((MI.getOpcode() == TargetOpcode::G_ADD ||
+          MI.getOpcode() == TargetOpcode::G_SUB ||
+          MI.getOpcode() == TargetOpcode::G_MUL ||
+          MI.getOpcode() == TargetOpcode::G_SHL ||
+          MI.getOpcode() == TargetOpcode::G_LSHR ||
+          MI.getOpcode() == TargetOpcode::G_AND ||
+          MI.getOpcode() == TargetOpcode::G_OR ||
+          MI.getOpcode() == TargetOpcode::G_XOR ||
+          MI.getOpcode() == TargetOpcode::G_SEXT_INREG || Depth != 0) &&
+         "Unexpected opcode");
+
+  if (Depth >= RISCVInstructionSelector::MaxRecursionDepth)
+    return false;
+
+  // Skip Vectors
+  // if(Depth == 0 && !MI.getOperand(0).isScalar())
+  //    return false;
+
+  for (MachineInstr &Use : MRI->use_instructions(MI.getOperand(0).getReg())) {
+
+    switch (Use.getOpcode()) {
+    default:
+      // if (vectorPseudoHasAllNBitUsers(User, Use.getNumOperands(), Bits, TII))
+      //   break;
+      return false;
+    case RISCV::ADDW:
+    case RISCV::ADDIW:
+    case RISCV::SUBW:
+    case RISCV::MULW:
+    case RISCV::SLLW:
+    case RISCV::SLLIW:
+    case RISCV::SRAW:
+    case RISCV::SRAIW:
+    case RISCV::SRLW:
+    case RISCV::SRLIW:
+    case RISCV::DIVW:
+    case RISCV::DIVUW:
+    case RISCV::REMW:
+    case RISCV::REMUW:
+    case RISCV::ROLW:
+    case RISCV::RORW:
+    case RISCV::RORIW:
+    case RISCV::CLZW:
+    case RISCV::CTZW:
+    case RISCV::CPOPW:
+    case RISCV::SLLI_UW:
+    case RISCV::FMV_W_X:
+    case RISCV::FCVT_H_W:
+    case RISCV::FCVT_H_W_INX:
+    case RISCV::FCVT_H_WU:
+    case RISCV::FCVT_H_WU_INX:
+    case RISCV::FCVT_S_W:
+    case RISCV::FCVT_S_W_INX:
+    case RISCV::FCVT_S_WU:
+    case RISCV::FCVT_S_WU_INX:
+    case RISCV::FCVT_D_W:
+    case RISCV::FCVT_D_W_INX:
+    case RISCV::FCVT_D_WU:
+    case RISCV::FCVT_D_WU_INX:
+    case RISCV::TH_REVW:
+    case RISCV::TH_SRRIW:
+      if (Bits >= 32)
+        break;
+      return false;
+    case RISCV::SLL:
+    case RISCV::SRA:
+    case RISCV::SRL:
+    case RISCV::ROL:
+    case RISCV::ROR:
+    case RISCV::BSET:
+    case RISCV::BCLR:
+    case RISCV::BINV:
+      // Shift amount operands only use log2(Xlen) bits.
+      if (Use.getNumOperands() == 1 && Bits >= Log2_32(Subtarget->getXLen()))
+        break;
+      return false;
+    case RISCV::SLLI:
+      // SLLI only uses the lower (XLen - ShAmt) bits.
+      if (Bits >= Subtarget->getXLen() - Use.getOperand(2).getImm())
+        break;
+      return false;
+    case RISCV::ANDI:
+      if (Bits >= (unsigned)llvm::bit_width<uint64_t>(
+                      (uint64_t)Use.getOperand(2).getImm()))
+        break;
+      goto RecCheck;
+    case RISCV::ORI: {
+      uint64_t Imm = Use.getOperand(2).getImm();
+      if (Bits >= (unsigned)llvm::bit_width<uint64_t>(~Imm))
+        break;
+      [[fallthrough]];
+    }
+    case RISCV::AND:
+    case RISCV::OR:
+    case RISCV::XOR:
+    case RISCV::XORI:
+    case RISCV::ANDN:
+    case RISCV::ORN:
+    case RISCV::XNOR:
+    case RISCV::SH1ADD:
+    case RISCV::SH2ADD:
+    case RISCV::SH3ADD:
+    RecCheck:
+      if (hasAllNBitUsers(Use, Bits, Depth + 1))
+        break;
+      return false;
+    case RISCV::SRLI: {
+      unsigned ShAmt = Use.getOperand(2).getImm();
+      // If we are shifting right by less than Bits, and users don't demand any
+      // bits that were shifted into [Bits-1:0], then we can consider this as an
+      // N-Bit user.
+      if (Bits > ShAmt && hasAllNBitUsers(Use, Bits - ShAmt, Depth + 1))
+        break;
+      return false;
+    }
+    case RISCV::SEXT_B:
+    case RISCV::PACKH:
+      if (Bits >= 8)
+        break;
+      return false;
+    case RISCV::SEXT_H:
+    case RISCV::FMV_H_X:
+    case RISCV::ZEXT_H_RV32:
+    case RISCV::ZEXT_H_RV64:
+    case RISCV::PACKW:
+      if (Bits >= 16)
+        break;
+      return false;
+    case RISCV::PACK:
+      if (Bits >= (Subtarget->getXLen() / 2))
+        break;
+      return false;
+    case RISCV::ADD_UW:
+    case RISCV::SH1ADD_UW:
+    case RISCV::SH2ADD_UW:
+    case RISCV::SH3ADD_UW:
+      // The first operand to add.uw/shXadd.uw is implicitly zero extended from
+      // 32 bits.
+      if (Use.getNumOperands() == 0 && Bits >= 32)
----------------
topperc wrote:

SB/SH/SW don't write to any registers, so they only have input operands; as a result, their operand numbering doesn't change between SelectionDAG and MachineInstr.

https://github.com/llvm/llvm-project/pull/124678


More information about the llvm-commits mailing list