[llvm] [AMDGPU] Implement vop3p complex pattern optimization for gisel (PR #130234)
Matt Arsenault via llvm-commits
llvm-commits at lists.llvm.org
Sun Mar 9 20:06:11 PDT 2025
================
@@ -4282,44 +4282,364 @@ AMDGPUInstructionSelector::selectVOP3NoMods(MachineOperand &Root) const {
}};
}
-std::pair<Register, unsigned>
-AMDGPUInstructionSelector::selectVOP3PModsImpl(
- Register Src, const MachineRegisterInfo &MRI, bool IsDOT) const {
- unsigned Mods = 0;
- MachineInstr *MI = MRI.getVRegDef(Src);
+/// Relationship between the value being tracked and the operand currently
+/// inspected while walking up a def chain: the operand itself, its upper or
+/// lower half, or the sign-negated form of one of those.
+///
+/// Each *_NEG enumerator sits exactly (LAST_STAT + 1) / 2 == 3 positions after
+/// its non-negated counterpart; the G_FNEG handling in calcNextStatus relies
+/// on this to toggle negation with (stat + 3) % 6.
+enum srcStatus {
+  IS_SAME = 0,
+  IS_UPPER_HALF = 1,
+  IS_LOWER_HALF = 2,
+  IS_NEG = 3,
+  IS_UPPER_HALF_NEG = 4,
+  IS_LOWER_HALF_NEG = 5,
+  LAST_STAT = IS_LOWER_HALF_NEG
+};
+
+/// Return true if \p MI (which must be a G_TRUNC) narrows its source to
+/// exactly half of the source's bit width.
+bool isTruncHalf(MachineInstr *MI, const MachineRegisterInfo &MRI) {
+  assert(MI->getOpcode() == AMDGPU::G_TRUNC);
+  const MachineOperand &Dst = MI->getOperand(0);
+  const MachineOperand &Src = MI->getOperand(1);
+  unsigned DstBits = MRI.getType(Dst.getReg()).getSizeInBits();
+  unsigned SrcBits = MRI.getType(Src.getReg()).getSizeInBits();
+  return SrcBits == 2 * DstBits;
+}
+
+/// Return true if \p MI (which must be a G_LSHR) shifts its source right by a
+/// constant equal to exactly half of the source's bit width, i.e. the result's
+/// lower half is the source's upper half.
+bool isLshrHalf(MachineInstr *MI, const MachineRegisterInfo &MRI) {
+  assert(MI->getOpcode() == AMDGPU::G_LSHR);
+  Register ShiftSrc;
+  std::optional<ValueAndVReg> ShiftAmt;
+  if (!mi_match(MI->getOperand(0).getReg(), MRI,
+                m_GLShr(m_Reg(ShiftSrc), m_GCst(ShiftAmt))))
+    return false;
+  uint64_t SrcSize = MRI.getType(ShiftSrc).getSizeInBits();
+  // Compare in 64 bits: narrowing the shift amount to 'unsigned' (or doubling
+  // it) could wrap and make an out-of-range constant look like a half shift.
+  uint64_t Shift = ShiftAmt->Value.getLimitedValue();
+  return SrcSize % 2 == 0 && Shift == SrcSize / 2;
+}
- if (MI->getOpcode() == AMDGPU::G_FNEG &&
- // It's possible to see an f32 fneg here, but unlikely.
- // TODO: Treat f32 fneg as only high bit.
- MRI.getType(Src) == LLT::fixed_vector(2, 16)) {
- Mods ^= (SISrcMods::NEG | SISrcMods::NEG_HI);
- Src = MI->getOperand(1).getReg();
- MI = MRI.getVRegDef(Src);
+/// Return true if \p MI (which must be a G_SHL) shifts its source left by a
+/// constant equal to exactly half of the source's bit width, i.e. the result's
+/// upper half is the source's lower half.
+bool isShlHalf(MachineInstr *MI, const MachineRegisterInfo &MRI) {
+  assert(MI->getOpcode() == AMDGPU::G_SHL);
+  Register ShiftSrc;
+  std::optional<ValueAndVReg> ShiftAmt;
+  if (!mi_match(MI->getOperand(0).getReg(), MRI,
+                m_GShl(m_Reg(ShiftSrc), m_GCst(ShiftAmt))))
+    return false;
+  uint64_t SrcSize = MRI.getType(ShiftSrc).getSizeInBits();
+  // Compare in 64 bits: narrowing the shift amount to 'unsigned' (or doubling
+  // it) could wrap and make an out-of-range constant look like a half shift.
+  uint64_t Shift = ShiftAmt->Value.getLimitedValue();
+  return SrcSize % 2 == 0 && Shift == SrcSize / 2;
+}
+
+/// Record (\p Op, \p stat) into \p curr if \p Op is something the status walk
+/// can continue from or stop at: an immediate (integer, CImm or FP) or a
+/// register operand whose register is not physical. Returns true if \p curr
+/// was updated; otherwise \p curr is left untouched and false is returned.
+bool retOpStat(MachineOperand *Op, int stat,
+               std::pair<MachineOperand *, int> &curr) {
+  bool IsTrackable = Op->isImm() || Op->isCImm() || Op->isFPImm() ||
+                     (Op->isReg() && !Op->getReg().isPhysical());
+  if (!IsTrackable)
+    return false;
+  curr.first = Op;
+  curr.second = stat;
+  return true;
+}
+
+/// Try to look through the instruction that defines the operand in \p curr.
+///
+/// \p curr pairs an operand with a srcStatus describing how the tracked value
+/// relates to that operand. On success \p curr is replaced by the defining
+/// instruction's source operand with the status updated accordingly, and true
+/// is returned. If the instruction cannot be looked through, \p curr is left
+/// unchanged and false is returned.
+bool calcNextStatus(std::pair<MachineOperand *, int> &curr,
+                    const MachineRegisterInfo &MRI) {
+  if (!curr.first->isReg())
+    return false;
+
+  MachineInstr *MI = nullptr;
+  if (curr.first->isDef()) {
+    // A def operand is produced by its own parent instruction; no lookup is
+    // needed.
+    MI = curr.first->getParent();
+  } else {
+    // getVRegDef assumes an SSA virtual register; don't query it for
+    // physical (or null) registers.
+    Register Reg = curr.first->getReg();
+    if (!Reg.isVirtual())
+      return false;
+    MI = MRI.getVRegDef(Reg);
+  }
+  if (!MI)
+    return false;
+
+  unsigned Opc = MI->getOpcode();
+
+  // Opcodes that are looked through regardless of the current status.
+  switch (Opc) {
+  case AMDGPU::G_BITCAST:
+  case AMDGPU::G_CONSTANT:
+  case AMDGPU::G_FCONSTANT:
+  case AMDGPU::COPY:
+    return retOpStat(&MI->getOperand(1), curr.second, curr);
+  case AMDGPU::G_FNEG: {
+    // G_FNEG flips the sign bit of every element. That is the same as
+    // negating each half only when the halves are the elements (a 2-element
+    // vector such as v2s16), or as negating the whole value when it is a lone
+    // s16. Otherwise (e.g. an s32 fneg flips only the high half's sign bit,
+    // leaving the low half intact) the status can't be expressed, so stop.
+    // TODO: Treat an s32 fneg as negating only the high half.
+    LLT Ty = MRI.getType(MI->getOperand(0).getReg());
+    bool IsWhole = curr.second == IS_SAME || curr.second == IS_NEG;
+    bool NegatesEachHalf = Ty.isFixedVector() && Ty.getNumElements() == 2;
+    bool NegatesWhole = IsWhole && Ty == LLT::scalar(16);
+    if (!NegatesEachHalf && !NegatesWhole)
+      return false;
+    // Toggle negation: XXXX + 3 = XXXX_NEG and (XXXX_NEG + 3) mod 6 = XXXX.
+    return retOpStat(&MI->getOperand(1),
+                     (curr.second + ((LAST_STAT + 1) / 2)) % (LAST_STAT + 1),
+                     curr);
+  }
+  default:
+    break;
+  }
+
+  // Transitions that depend on the current status.
+  switch (curr.second) {
+  case IS_SAME:
+    // trunc to half width: the value is the source's lower half.
+    if (Opc == AMDGPU::G_TRUNC && isTruncHalf(MI, MRI))
+      return retOpStat(&MI->getOperand(1), IS_LOWER_HALF, curr);
+    break;
+  case IS_NEG:
+    if (Opc == AMDGPU::G_TRUNC && isTruncHalf(MI, MRI))
+      return retOpStat(&MI->getOperand(1), IS_LOWER_HALF_NEG, curr);
+    break;
+  case IS_UPPER_HALF:
+    // The upper half of (x << half) is the lower half of x.
+    if (Opc == AMDGPU::G_SHL && isShlHalf(MI, MRI))
+      return retOpStat(&MI->getOperand(1), IS_LOWER_HALF, curr);
+    break;
+  case IS_LOWER_HALF:
+    // The lower half of (x >> half) is the upper half of x.
+    if (Opc == AMDGPU::G_LSHR && isLshrHalf(MI, MRI))
+      return retOpStat(&MI->getOperand(1), IS_UPPER_HALF, curr);
+    break;
+  case IS_UPPER_HALF_NEG:
+    if (Opc == AMDGPU::G_SHL && isShlHalf(MI, MRI))
+      return retOpStat(&MI->getOperand(1), IS_LOWER_HALF_NEG, curr);
+    break;
+  case IS_LOWER_HALF_NEG:
+    if (Opc == AMDGPU::G_LSHR && isLshrHalf(MI, MRI))
+      return retOpStat(&MI->getOperand(1), IS_UPPER_HALF_NEG, curr);
+    break;
+  }
+  return false;
+}
+
+/// Walk up the def chain of \p Op for at most \p maxDepth steps, tracking how
+/// the original value relates to each operand reached (see calcNextStatus).
+///
+/// If \p onlyLastSameOrNeg is false, every (operand, status) pair reached is
+/// returned in visiting order; the starting pair {Op, IS_SAME} is not
+/// included. If it is true, the walk stops before the first status other than
+/// IS_SAME / IS_NEG, and exactly one element is returned: the last pair whose
+/// status was IS_SAME or IS_NEG ({Op, IS_SAME} if no such step was taken).
+std::vector<std::pair<MachineOperand *, int>>
+getSrcStats(MachineOperand *Op, const MachineRegisterInfo &MRI,
+            bool onlyLastSameOrNeg = false, int maxDepth = 6) {
+  std::pair<MachineOperand *, int> Curr = {Op, IS_SAME};
+  // calcNextStatus overwrites Curr in place, so remember the last acceptable
+  // pair separately; otherwise a rejected status would leak into the result.
+  std::pair<MachineOperand *, int> LastSameOrNeg = Curr;
+  std::vector<std::pair<MachineOperand *, int>> StatList;
+
+  for (int Depth = 0; Depth < maxDepth; ++Depth) {
+    if (!calcNextStatus(Curr, MRI))
+      break;
+    if (onlyLastSameOrNeg) {
+      if (Curr.second != IS_SAME && Curr.second != IS_NEG)
+        break;
+      LastSameOrNeg = Curr;
+    } else {
+      StatList.push_back(Curr);
+    }
+  }
+
+  if (onlyLastSameOrNeg)
+    StatList.push_back(LastSameOrNeg);
+  return StatList;
+}
- // TODO: Handle G_FSUB 0 as fneg
+/// Return true if \p Op is an immediate operand (integer, CImm or FP) whose
+/// value \p TII reports as an inline constant. Any other operand kind,
+/// including registers, is not inlinable.
+bool isInlinableConstant(MachineOperand *Op, const SIInstrInfo &TII) {
+  switch (Op->getType()) {
+  case MachineOperand::MO_Immediate:
+    return TII.isInlineConstant(*Op);
+  case MachineOperand::MO_CImmediate:
+    return TII.isInlineConstant(Op->getCImm()->getValue());
+  case MachineOperand::MO_FPImmediate:
+    return TII.isInlineConstant(Op->getFPImm()->getValueAPF());
+  default:
+    // Registers, blocks, symbols, etc. are never inline constants.
+    return false;
+  }
+}
- // TODO: Match op_sel through g_build_vector_trunc and g_shuffle_vector.
- (void)IsDOT; // DOTs do not use OPSEL on gfx942+, check ST.hasDOTOpSelHazard()
+/// Return true if the registers of \p Op1 and \p Op2 have types of equal size
+/// in bits. Both operands must be register operands.
+bool isSameBitWidth(MachineOperand *Op1, MachineOperand *Op2,
+                    const MachineRegisterInfo &MRI) {
+  auto WidthOf = [&MRI](const MachineOperand *Op) -> unsigned {
+    return MRI.getType(Op->getReg()).getSizeInBits();
+  };
+  return WidthOf(Op1) == WidthOf(Op2);
+}
+bool isSameOperand(MachineOperand *Op1, MachineOperand *Op2) {
+ if (Op1->isReg()) {
+ if (Op2->isReg()) {
+ return Op1->getReg() == Op2->getReg();
+ }
----------------
arsenm wrote:
This is ignoring subregister uses, but you shouldn't encounter them either
https://github.com/llvm/llvm-project/pull/130234
More information about the llvm-commits
mailing list