[llvm] [RISCV] Porting hasAllNBitUsers to RISCV GISel for instruction select (PR #124678)
Michael Maitland via llvm-commits
llvm-commits at lists.llvm.org
Tue Jan 28 07:08:05 PST 2025
================
@@ -186,6 +201,169 @@ RISCVInstructionSelector::RISCVInstructionSelector(
{
}
+bool RISCVInstructionSelector::hasAllNBitUsers(const MachineInstr &MI,
+ unsigned Bits,
+ const unsigned Depth) const {
+
+ assert((MI.getOpcode() == TargetOpcode::G_ADD ||
+ MI.getOpcode() == TargetOpcode::G_SUB ||
+ MI.getOpcode() == TargetOpcode::G_MUL ||
+ MI.getOpcode() == TargetOpcode::G_SHL ||
+ MI.getOpcode() == TargetOpcode::G_LSHR ||
+ MI.getOpcode() == TargetOpcode::G_AND ||
+ MI.getOpcode() == TargetOpcode::G_OR ||
+ MI.getOpcode() == TargetOpcode::G_XOR ||
+ MI.getOpcode() == TargetOpcode::G_SEXT_INREG || Depth != 0) &&
+ "Unexpected opcode");
+
+ if (Depth >= RISCVInstructionSelector::MaxRecursionDepth)
+ return false;
+
+ // Skip Vectors
+ // if(Depth == 0 && !MI.getOperand(0).isScalar())
+ // return false;
+
+ for (MachineInstr &Use : MRI->use_instructions(MI.getOperand(0).getReg())) {
+
+ switch (Use.getOpcode()) {
+ default:
+ // if (vectorPseudoHasAllNBitUsers(User, Use.getNumOperands(), Bits, TII))
+ // break;
+ return false;
+ case RISCV::ADDW:
+ case RISCV::ADDIW:
+ case RISCV::SUBW:
+ case RISCV::MULW:
+ case RISCV::SLLW:
+ case RISCV::SLLIW:
+ case RISCV::SRAW:
+ case RISCV::SRAIW:
+ case RISCV::SRLW:
+ case RISCV::SRLIW:
+ case RISCV::DIVW:
+ case RISCV::DIVUW:
+ case RISCV::REMW:
+ case RISCV::REMUW:
+ case RISCV::ROLW:
+ case RISCV::RORW:
+ case RISCV::RORIW:
+ case RISCV::CLZW:
+ case RISCV::CTZW:
+ case RISCV::CPOPW:
+ case RISCV::SLLI_UW:
+ case RISCV::FMV_W_X:
+ case RISCV::FCVT_H_W:
+ case RISCV::FCVT_H_W_INX:
+ case RISCV::FCVT_H_WU:
+ case RISCV::FCVT_H_WU_INX:
+ case RISCV::FCVT_S_W:
+ case RISCV::FCVT_S_W_INX:
+ case RISCV::FCVT_S_WU:
+ case RISCV::FCVT_S_WU_INX:
+ case RISCV::FCVT_D_W:
+ case RISCV::FCVT_D_W_INX:
+ case RISCV::FCVT_D_WU:
+ case RISCV::FCVT_D_WU_INX:
+ case RISCV::TH_REVW:
+ case RISCV::TH_SRRIW:
+ if (Bits >= 32)
+ break;
+ return false;
+ case RISCV::SLL:
+ case RISCV::SRA:
+ case RISCV::SRL:
+ case RISCV::ROL:
+ case RISCV::ROR:
+ case RISCV::BSET:
+ case RISCV::BCLR:
+ case RISCV::BINV:
+ // Shift amount operands only use log2(Xlen) bits.
+ if (Use.getNumOperands() == 1 && Bits >= Log2_32(Subtarget->getXLen()))
+ break;
+ return false;
+ case RISCV::SLLI:
+ // SLLI only uses the lower (XLen - ShAmt) bits.
+ if (Bits >= Subtarget->getXLen() - Use.getOperand(2).getImm())
+ break;
+ return false;
+ case RISCV::ANDI:
+ if (Bits >= (unsigned)llvm::bit_width<uint64_t>(
+ (uint64_t)Use.getOperand(2).getImm()))
+ break;
+ goto RecCheck;
+ case RISCV::ORI: {
+ uint64_t Imm = Use.getOperand(2).getImm();
+ if (Bits >= (unsigned)llvm::bit_width<uint64_t>(~Imm))
+ break;
+ [[fallthrough]];
+ }
+ case RISCV::AND:
+ case RISCV::OR:
+ case RISCV::XOR:
+ case RISCV::XORI:
+ case RISCV::ANDN:
+ case RISCV::ORN:
+ case RISCV::XNOR:
+ case RISCV::SH1ADD:
+ case RISCV::SH2ADD:
+ case RISCV::SH3ADD:
+ RecCheck:
+ if (hasAllNBitUsers(Use, Bits, Depth + 1))
+ break;
+ return false;
+ case RISCV::SRLI: {
+ unsigned ShAmt = Use.getOperand(2).getImm();
+ // If we are shifting right by less than Bits, and users don't demand any
+ // bits that were shifted into [Bits-1:0], then we can consider this as an
+ // N-Bit user.
+ if (Bits > ShAmt && hasAllNBitUsers(Use, Bits - ShAmt, Depth + 1))
+ break;
+ return false;
+ }
+ case RISCV::SEXT_B:
+ case RISCV::PACKH:
+ if (Bits >= 8)
+ break;
+ return false;
+ case RISCV::SEXT_H:
+ case RISCV::FMV_H_X:
+ case RISCV::ZEXT_H_RV32:
+ case RISCV::ZEXT_H_RV64:
+ case RISCV::PACKW:
+ if (Bits >= 16)
+ break;
+ return false;
+ case RISCV::PACK:
+ if (Bits >= (Subtarget->getXLen() / 2))
+ break;
+ return false;
+ case RISCV::ADD_UW:
+ case RISCV::SH1ADD_UW:
+ case RISCV::SH2ADD_UW:
+ case RISCV::SH3ADD_UW:
+ // The first operand to add.uw/shXadd.uw is implicitly zero extended from
+ // 32 bits.
+ if (Use.getNumOperands() == 0 && Bits >= 32)
+ break;
+ return false;
+ case RISCV::SB:
+ if (Use.getNumOperands() == 0 && Bits >= 8)
+ break;
+ return false;
+ case RISCV::SH:
+ if (Use.getNumOperands() == 0 && Bits >= 16)
+ break;
+ return false;
+ case RISCV::SW:
+ if (Use.getNumOperands() == 0 && Bits >= 32)
----------------
michaelmaitland wrote:
Do we have test coverage for all of these instructions? I don't see any sw/sh/sb-related test changes in the diff below.
https://github.com/llvm/llvm-project/pull/124678
More information about the llvm-commits mailing list