[llvm] [AMDGPU] Implement vop3p complex pattern optimization for gisel (PR #130234)
Juan Manuel Martinez Caamaño via llvm-commits
llvm-commits at lists.llvm.org
Tue Mar 25 09:22:05 PDT 2025
================
@@ -4314,44 +4314,352 @@ AMDGPUInstructionSelector::selectVOP3NoMods(MachineOperand &Root) const {
}};
}
-std::pair<Register, unsigned>
-AMDGPUInstructionSelector::selectVOP3PModsImpl(
- Register Src, const MachineRegisterInfo &MRI, bool IsDOT) const {
- unsigned Mods = 0;
- MachineInstr *MI = MRI.getVRegDef(Src);
+enum class SrcStatus {
+ IS_SAME,
+ IS_UPPER_HALF,
+ IS_LOWER_HALF,
+ IS_NEG,
+ IS_UPPER_HALF_NEG,
+ IS_LOWER_HALF_NEG
+};
+
+static bool isTruncHalf(const MachineInstr *MI,
+ const MachineRegisterInfo &MRI) {
+ if (MI->getOpcode() != AMDGPU::G_TRUNC)
+ return false;
+
+ unsigned DstSize = MRI.getType(MI->getOperand(0).getReg()).getSizeInBits();
+ unsigned SrcSize = MRI.getType(MI->getOperand(1).getReg()).getSizeInBits();
+ return DstSize * 2 == SrcSize;
+}
+
+static bool isLshrHalf(const MachineInstr *MI, const MachineRegisterInfo &MRI) {
+ if (MI->getOpcode() != AMDGPU::G_LSHR)
+ return false;
+
+ Register ShiftSrc;
+ std::optional<ValueAndVReg> ShiftAmt;
+ if (mi_match(MI->getOperand(0).getReg(), MRI,
+ m_GLShr(m_Reg(ShiftSrc), m_GCst(ShiftAmt)))) {
+ unsigned SrcSize = MRI.getType(MI->getOperand(1).getReg()).getSizeInBits();
+ unsigned Shift = ShiftAmt->Value.getZExtValue();
+ return Shift * 2 == SrcSize;
+ }
+ return false;
+}
+
+static bool isShlHalf(const MachineInstr *MI, const MachineRegisterInfo &MRI) {
+ if (MI->getOpcode() != AMDGPU::G_SHL)
+ return false;
+
+ Register ShiftSrc;
+ std::optional<ValueAndVReg> ShiftAmt;
+ if (mi_match(MI->getOperand(0).getReg(), MRI,
+ m_GShl(m_Reg(ShiftSrc), m_GCst(ShiftAmt)))) {
+ unsigned SrcSize = MRI.getType(MI->getOperand(1).getReg()).getSizeInBits();
+ unsigned Shift = ShiftAmt->Value.getZExtValue();
+ return Shift * 2 == SrcSize;
+ }
+ return false;
+}
+
+static bool retOpStat(const MachineOperand *Op, SrcStatus Stat,
+ std::pair<const MachineOperand *, SrcStatus> &Curr) {
+ if ((Op->isReg() && !(Op->getReg().isPhysical())) || Op->isImm() ||
+ Op->isCImm() || Op->isFPImm()) {
+ Curr = {Op, Stat};
+ return true;
+ }
+
+ return false;
+}
+
+SrcStatus getNegStatus(SrcStatus S) {
+ switch (S) {
+ case SrcStatus::IS_SAME:
+ return SrcStatus::IS_NEG;
+ case SrcStatus::IS_UPPER_HALF:
+ return SrcStatus::IS_UPPER_HALF_NEG;
+ case SrcStatus::IS_LOWER_HALF:
+ return SrcStatus::IS_LOWER_HALF_NEG;
+ case SrcStatus::IS_NEG:
+ return SrcStatus::IS_SAME;
+ case SrcStatus::IS_UPPER_HALF_NEG:
+ return SrcStatus::IS_UPPER_HALF;
+ case SrcStatus::IS_LOWER_HALF_NEG:
+ return SrcStatus::IS_LOWER_HALF;
+ }
+ llvm_unreachable("unexpected SrcStatus");
+}
- if (MI->getOpcode() == AMDGPU::G_FNEG &&
- // It's possible to see an f32 fneg here, but unlikely.
- // TODO: Treat f32 fneg as only high bit.
- MRI.getType(Src) == LLT::fixed_vector(2, 16)) {
+static bool calcNextStatus(std::pair<const MachineOperand *, SrcStatus> &Curr,
+ const MachineRegisterInfo &MRI) {
+ if (!Curr.first->isReg())
+ return false;
+
+ const MachineInstr *MI = nullptr;
+
+ if (!Curr.first->isDef())
+ MI = MRI.getVRegDef(Curr.first->getReg());
+ else
+ MI = Curr.first->getParent();
+
+ if (!MI)
+ return false;
+
+ unsigned Opc = MI->getOpcode();
+
+ // Handle general Opc cases.
+ switch (Opc) {
+ case AMDGPU::G_BITCAST:
+ case AMDGPU::G_CONSTANT:
+ case AMDGPU::G_FCONSTANT:
+ case AMDGPU::COPY:
+ return retOpStat(&MI->getOperand(1), Curr.second, Curr);
+ case AMDGPU::G_FNEG:
+ return retOpStat(&MI->getOperand(1), getNegStatus(Curr.second), Curr);
+ default:
+ break;
+ }
+
+ // Calc next Stat from current Stat.
+ switch (Curr.second) {
+ case SrcStatus::IS_SAME:
+ if (isTruncHalf(MI, MRI))
+ return retOpStat(&MI->getOperand(1), SrcStatus::IS_LOWER_HALF, Curr);
+ break;
+ case SrcStatus::IS_NEG:
+ if (isTruncHalf(MI, MRI))
+ return retOpStat(&MI->getOperand(1), SrcStatus::IS_LOWER_HALF_NEG, Curr);
+ break;
+ case SrcStatus::IS_UPPER_HALF:
+ if (isShlHalf(MI, MRI))
+ return retOpStat(&MI->getOperand(1), SrcStatus::IS_LOWER_HALF, Curr);
+ break;
+ case SrcStatus::IS_LOWER_HALF:
+ if (isLshrHalf(MI, MRI))
+ return retOpStat(&MI->getOperand(1), SrcStatus::IS_UPPER_HALF, Curr);
+ break;
+ case SrcStatus::IS_UPPER_HALF_NEG:
+ if (isShlHalf(MI, MRI))
+ return retOpStat(&MI->getOperand(1), SrcStatus::IS_LOWER_HALF_NEG, Curr);
+ break;
+ case SrcStatus::IS_LOWER_HALF_NEG:
+ if (isLshrHalf(MI, MRI))
+ return retOpStat(&MI->getOperand(1), SrcStatus::IS_UPPER_HALF_NEG, Curr);
+ break;
+ default:
+ break;
+ }
+ return false;
+}
+
+SmallVector<std::pair<const MachineOperand *, SrcStatus>>
+getSrcStats(const MachineOperand *Op, const MachineRegisterInfo &MRI,
+ bool OnlyLastSameOrNeg = false, int MaxDepth = 6) {
+ int Depth = 0;
+ std::pair<const MachineOperand *, SrcStatus> Curr = {Op, SrcStatus::IS_SAME};
+ SmallVector<std::pair<const MachineOperand *, SrcStatus>> Statlist;
+
+ while (Depth <= MaxDepth && calcNextStatus(Curr, MRI)) {
+ Depth++;
+ if ((OnlyLastSameOrNeg && (Curr.second != SrcStatus::IS_SAME &&
+ Curr.second != SrcStatus::IS_NEG)))
+ break;
+
+ if (!OnlyLastSameOrNeg)
+ Statlist.push_back(Curr);
+ }
+ if (OnlyLastSameOrNeg)
+ Statlist.push_back(Curr);
+ return Statlist;
+}
+
+static bool isInlinableConstant(const MachineOperand &Op,
+ const SIInstrInfo &TII) {
+ return Op.isFPImm() && TII.isInlineConstant(Op.getFPImm()->getValueAPF());
+}
+
+static bool isSameBitWidth(const MachineOperand *Op1, const MachineOperand *Op2,
+ const MachineRegisterInfo &MRI) {
+ unsigned Width1 = MRI.getType(Op1->getReg()).getSizeInBits();
+ unsigned Width2 = MRI.getType(Op2->getReg()).getSizeInBits();
+ return Width1 == Width2;
+}
+
+static bool isSameOperand(const MachineOperand *Op1,
+ const MachineOperand *Op2) {
+ if (Op1->isReg()) {
+ if (Op2->isReg())
+ return Op1->getReg() == Op2->getReg();
+ return false;
+ }
+ return Op1->isIdenticalTo(*Op2);
+}
+
+static bool isValidToPack(SrcStatus HiStat, SrcStatus LoStat,
+ unsigned int &Mods, const MachineOperand *NewOp,
+ const MachineOperand *RootOp, const SIInstrInfo &TII,
+ const MachineRegisterInfo &MRI) {
+ if (NewOp->isReg()) {
+ if (isSameBitWidth(NewOp, RootOp, MRI)) {
+ // SrcStatus::IS_LOWER_HALF remain 0.
+ if (HiStat == SrcStatus::IS_UPPER_HALF_NEG) {
+ Mods ^= SISrcMods::NEG_HI;
+ Mods |= SISrcMods::OP_SEL_1;
+ } else if (HiStat == SrcStatus::IS_UPPER_HALF) {
+ Mods |= SISrcMods::OP_SEL_1;
+ } else if (HiStat == SrcStatus::IS_LOWER_HALF_NEG) {
+ Mods ^= SISrcMods::NEG_HI;
+ }
+ if (LoStat == SrcStatus::IS_UPPER_HALF_NEG) {
+ Mods ^= SISrcMods::NEG;
+ Mods |= SISrcMods::OP_SEL_0;
+ } else if (LoStat == SrcStatus::IS_UPPER_HALF) {
+ Mods |= SISrcMods::OP_SEL_0;
+ } else if (LoStat == SrcStatus::IS_UPPER_HALF_NEG) {
+ Mods |= SISrcMods::NEG;
+ }
+ return true;
+ }
+ } else {
+ if ((HiStat == SrcStatus::IS_SAME || HiStat == SrcStatus::IS_NEG) &&
+ (LoStat == SrcStatus::IS_SAME || LoStat == SrcStatus::IS_NEG) &&
+ isInlinableConstant(*NewOp, TII)) {
+ if (HiStat == SrcStatus::IS_NEG)
+ Mods ^= SISrcMods::NEG_HI;
+ if (LoStat == SrcStatus::IS_NEG)
+ Mods ^= SISrcMods::NEG;
+ // opsel = opsel_hi = 0, since the upper half and lower half both
+ // the same as the target inlinable constant.
+ return true;
+ }
+ }
+ return false;
+}
+
+std::pair<const MachineOperand *, unsigned>
+AMDGPUInstructionSelector::selectVOP3PModsImpl(const MachineOperand *Op,
+ const MachineRegisterInfo &MRI,
+ bool IsDOT) const {
+ unsigned Mods = 0;
+ const MachineOperand *RootOp = Op;
+ std::pair<const MachineOperand *, SrcStatus> Stat =
+ getSrcStats(Op, MRI, true)[0];
+ if (!Stat.first->isReg()) {
+ Mods |= SISrcMods::OP_SEL_1;
+ return {Op, Mods};
+ }
+ if (Stat.second == SrcStatus::IS_NEG)
Mods ^= (SISrcMods::NEG | SISrcMods::NEG_HI);
- Src = MI->getOperand(1).getReg();
- MI = MRI.getVRegDef(Src);
+
+ Op = Stat.first;
+ MachineInstr *MI = MRI.getVRegDef(Op->getReg());
+
+ if (MI->getOpcode() != AMDGPU::G_BUILD_VECTOR || MI->getNumOperands() != 3 ||
+ (IsDOT && Subtarget->hasDOTOpSelHazard())) {
+ Mods |= SISrcMods::OP_SEL_1;
+ return {Op, Mods};
}
- // TODO: Handle G_FSUB 0 as fneg
+ SmallVector<std::pair<const MachineOperand *, SrcStatus>> StatlistHi;
+ StatlistHi = getSrcStats(&MI->getOperand(2), MRI);
----------------
jmmartinez wrote:
```suggestion
SmallVector<std::pair<const MachineOperand *, SrcStatus>> StatlistHi = getSrcStats(&MI->getOperand(2), MRI);
```
https://github.com/llvm/llvm-project/pull/130234
More information about the llvm-commits mailing list