[llvm] [AMDGPU] Implement vop3p complex pattern optimization for gisel (PR #130234)

via llvm-commits llvm-commits at lists.llvm.org
Mon Mar 17 22:17:06 PDT 2025


================
@@ -4310,44 +4310,349 @@ AMDGPUInstructionSelector::selectVOP3NoMods(MachineOperand &Root) const {
   }};
 }
 
-std::pair<Register, unsigned>
-AMDGPUInstructionSelector::selectVOP3PModsImpl(
-  Register Src, const MachineRegisterInfo &MRI, bool IsDOT) const {
-  unsigned Mods = 0;
-  MachineInstr *MI = MRI.getVRegDef(Src);
+// Relationship between the value a source-operand walk started from and the
+// operand the walk has currently reached. For a pair (Op, Status) produced by
+// the walk, the starting value equals:
+// NOTE: enumerator order determines the implicit values; do not reorder.
+enum SrcStatus {
+  IS_SAME,           // Op itself.
+  IS_UPPER_HALF,     // the upper half of Op's bits.
+  IS_LOWER_HALF,     // the lower half of Op's bits.
+  IS_NEG,            // Op negated (reached through a G_FNEG).
+  IS_UPPER_HALF_NEG, // the upper half of Op's bits, negated.
+  IS_LOWER_HALF_NEG  // the lower half of Op's bits, negated.
+};
+
+// Returns true if MI is a G_TRUNC whose result is exactly half as wide as its
+// source, i.e. it keeps only the lower half of the source bits.
+static bool isTruncHalf(const MachineInstr *MI,
+                        const MachineRegisterInfo &MRI) {
+  if (MI->getOpcode() == AMDGPU::G_TRUNC) {
+    const unsigned NarrowBits =
+        MRI.getType(MI->getOperand(0).getReg()).getSizeInBits();
+    const unsigned WideBits =
+        MRI.getType(MI->getOperand(1).getReg()).getSizeInBits();
+    return WideBits == 2 * NarrowBits;
+  }
+  return false;
+}
+
+// Returns true if the shift amount of MI (operand 2) is a constant equal to
+// half the bit width of the shifted value (operand 1). The comparison is done
+// on the APInt itself, so an oversized shift amount cannot wrap around in
+// 32-bit arithmetic and spuriously match, and an odd bit width never matches.
+static bool isShiftByHalfWidth(const MachineInstr *MI,
+                               const MachineRegisterInfo &MRI) {
+  std::optional<ValueAndVReg> ShiftAmt;
+  if (!mi_match(MI->getOperand(2).getReg(), MRI, m_GCst(ShiftAmt)))
+    return false;
+
+  unsigned SrcSize = MRI.getType(MI->getOperand(1).getReg()).getSizeInBits();
+  return SrcSize % 2 == 0 && ShiftAmt->Value == SrcSize / 2;
+}
+
+// Returns true if MI is a G_LSHR by a constant of half the source width, i.e.
+// it moves the upper half of the source into the lower half of the result.
+static bool isLshrHalf(const MachineInstr *MI, const MachineRegisterInfo &MRI) {
+  return MI->getOpcode() == AMDGPU::G_LSHR && isShiftByHalfWidth(MI, MRI);
+}
+
+// Returns true if MI is a G_SHL by a constant of half the source width, i.e.
+// it moves the lower half of the source into the upper half of the result.
+static bool isShlHalf(const MachineInstr *MI, const MachineRegisterInfo &MRI) {
+  return MI->getOpcode() == AMDGPU::G_SHL && isShiftByHalfWidth(MI, MRI);
+}
+
+// Tries to make (Op, Stat) the new current position of the source walk.
+// Op is accepted only if it is a virtual register or an immediate operand
+// (integer, CImm or FP immediate). Physical registers are rejected because
+// the walk follows register definitions with MRI.getVRegDef.
+// Returns true and updates Curr on success. Returns false and leaves Curr
+// untouched otherwise, so the caller stops instead of revisiting the same
+// node.
+static bool retOpStat(const MachineOperand *Op, SrcStatus Stat,
+                      std::pair<const MachineOperand *, SrcStatus> &Curr) {
+  if ((Op->isReg() && !Op->getReg().isPhysical()) || Op->isImm() ||
+      Op->isCImm() || Op->isFPImm()) {
+    Curr = {Op, Stat};
+    return true;
+  }
+  return false;
+}
+
+// Returns the status that results from negating a value whose status is S.
+// Plain statuses gain the matching _NEG variant and _NEG statuses lose it, so
+// applying this twice yields S again.
+// File-local helper: static for internal linkage, like the other helpers.
+static SrcStatus getNegStatus(SrcStatus S) {
+  switch (S) {
+  case IS_SAME:
+    return IS_NEG;
+  case IS_UPPER_HALF:
+    return IS_UPPER_HALF_NEG;
+  case IS_LOWER_HALF:
+    return IS_LOWER_HALF_NEG;
+  case IS_NEG:
+    return IS_SAME;
+  case IS_UPPER_HALF_NEG:
+    return IS_UPPER_HALF;
+  case IS_LOWER_HALF_NEG:
+    return IS_LOWER_HALF;
+  }
+  llvm_unreachable("unexpected SrcStatus");
+}
 
-  if (MI->getOpcode() == AMDGPU::G_FNEG &&
-      // It's possible to see an f32 fneg here, but unlikely.
-      // TODO: Treat f32 fneg as only high bit.
-      MRI.getType(Src) == LLT::fixed_vector(2, 16)) {
+// Advances the source walk by one step. Curr is (operand, status): the value
+// the walk started from equals Curr.second applied to Curr.first.
+// Inspects the instruction that defines Curr.first (or, if Curr.first is a
+// def operand, its parent). If that instruction is a pattern the walk
+// understands, Curr is replaced with the instruction's source operand and the
+// correspondingly updated status.
+// Returns false when no further step is possible.
+static bool calcNextStatus(std::pair<const MachineOperand *, SrcStatus> &Curr,
+                           const MachineRegisterInfo &MRI) {
+  if (!Curr.first->isReg())
+    return false;
+
+  const MachineInstr *MI = nullptr;
+  if (Curr.first->isDef()) {
+    MI = Curr.first->getParent();
+  } else {
+    // getVRegDef expects a virtual register (a physical register can have
+    // several definitions), so stop at physical register uses.
+    if (Curr.first->getReg().isPhysical())
+      return false;
+    MI = MRI.getVRegDef(Curr.first->getReg());
+  }
+  if (!MI)
+    return false;
+
+  unsigned Opc = MI->getOpcode();
+
+  // Opcodes handled independently of the current status.
+  switch (Opc) {
+  case AMDGPU::G_BITCAST:
+  case AMDGPU::G_CONSTANT:
+  case AMDGPU::G_FCONSTANT:
+  case AMDGPU::COPY:
+    // Operand 1 carries the same bits, so the status is unchanged.
+    return retOpStat(&MI->getOperand(1), Curr.second, Curr);
+  case AMDGPU::G_FNEG: {
+    // On a scalar type G_FNEG only flips the sign bit, which lies in the
+    // upper half, so the lower half passes through un-negated. On a vector
+    // every element is negated, so either half is negated as well.
+    // NOTE(review): for IS_SAME on a scalar of 32 bits or more, the resulting
+    // IS_NEG still only means the top bit was flipped; confirm consumers do
+    // not read it as "both packed halves negated".
+    bool IsLowerHalf =
+        Curr.second == IS_LOWER_HALF || Curr.second == IS_LOWER_HALF_NEG;
+    bool IsScalar = !MRI.getType(MI->getOperand(1).getReg()).isVector();
+    SrcStatus Next =
+        (IsLowerHalf && IsScalar) ? Curr.second : getNegStatus(Curr.second);
+    return retOpStat(&MI->getOperand(1), Next, Curr);
+  }
+  }
+
+  // Status-dependent steps through half-width truncates and shifts.
+  switch (Curr.second) {
+  case IS_SAME:
+    // X = trunc Y (to half width): X is the lower half of Y.
+    if (isTruncHalf(MI, MRI))
+      return retOpStat(&MI->getOperand(1), IS_LOWER_HALF, Curr);
+    break;
+  case IS_NEG:
+    if (isTruncHalf(MI, MRI))
+      return retOpStat(&MI->getOperand(1), IS_LOWER_HALF_NEG, Curr);
+    break;
+  case IS_UPPER_HALF:
+    // X = shl Y, half: the upper half of X is the lower half of Y.
+    if (isShlHalf(MI, MRI))
+      return retOpStat(&MI->getOperand(1), IS_LOWER_HALF, Curr);
+    break;
+  case IS_LOWER_HALF:
+    // X = lshr Y, half: the lower half of X is the upper half of Y.
+    if (isLshrHalf(MI, MRI))
+      return retOpStat(&MI->getOperand(1), IS_UPPER_HALF, Curr);
+    break;
+  case IS_UPPER_HALF_NEG:
+    if (isShlHalf(MI, MRI))
+      return retOpStat(&MI->getOperand(1), IS_LOWER_HALF_NEG, Curr);
+    break;
+  case IS_LOWER_HALF_NEG:
+    if (isLshrHalf(MI, MRI))
+      return retOpStat(&MI->getOperand(1), IS_UPPER_HALF_NEG, Curr);
+    break;
+  }
+  return false;
+}
+
+// Walks the def chain upward from Op, calling calcNextStatus at most
+// maxDepth + 1 times, and records what each reached operand represents
+// relative to Op.
+//
+// onlyLastSameOrNeg == false: every successful step's (operand, status) pair
+// is returned in visit order. Op itself is not included.
+// onlyLastSameOrNeg == true: the walk stops at the first step whose status is
+// neither IS_SAME nor IS_NEG. The result holds exactly one entry, the deepest
+// operand still equal to Op or its negation ({Op, IS_SAME} if none).
+static SmallVector<std::pair<const MachineOperand *, SrcStatus>>
+getSrcStats(const MachineOperand *Op, const MachineRegisterInfo &MRI,
+            bool onlyLastSameOrNeg = false, int maxDepth = 6) {
+  std::pair<const MachineOperand *, SrcStatus> Curr = {Op, IS_SAME};
+  // calcNextStatus advances Curr in place, so the last IS_SAME/IS_NEG entry
+  // must be remembered separately; otherwise breaking out of the loop would
+  // report the first non-matching step instead.
+  std::pair<const MachineOperand *, SrcStatus> LastSameOrNeg = Curr;
+  SmallVector<std::pair<const MachineOperand *, SrcStatus>> Statlist;
+
+  for (int Depth = 0; Depth <= maxDepth && calcNextStatus(Curr, MRI);
+       ++Depth) {
+    if (!onlyLastSameOrNeg) {
+      Statlist.push_back(Curr);
+      continue;
+    }
+    if (Curr.second != IS_SAME && Curr.second != IS_NEG)
+      break;
+    LastSameOrNeg = Curr;
+  }
+
+  if (onlyLastSameOrNeg)
+    Statlist.push_back(LastSameOrNeg);
+  return Statlist;
+}
+
+// Returns true if Op is a floating-point immediate that TII reports as an
+// inline constant. Any other operand kind, including integer immediates, is
+// reported as not inlinable.
+static bool isInlinableConstant(const MachineOperand &Op,
+                                const SIInstrInfo &TII) {
+  return Op.isFPImm() && TII.isInlineConstant(Op.getFPImm()->getValueAPF());
+}
+
+// Returns true if the types of the two operands have the same size in bits.
+// Both operands are expected to be registers, since getReg() is called on
+// each.
+static bool isSameBitWidth(const MachineOperand *Op1, const MachineOperand *Op2,
+                           const MachineRegisterInfo &MRI) {
+  auto BitsOf = [&MRI](const MachineOperand *MO) -> unsigned {
+    return MRI.getType(MO->getReg()).getSizeInBits();
+  };
+  return BitsOf(Op1) == BitsOf(Op2);
+}
+
+static bool isSameOperand(const MachineOperand *Op1,
+                          const MachineOperand *Op2) {
+  if (Op1->isReg()) {
+    if (Op2->isReg())
+      return Op1->getReg() == Op2->getReg();
----------------
Shoreshen wrote:

Hi @shiltian, using `isIdenticalTo` directly would also distinguish a register use from a def and compare subregister indices, which is stricter than the register-only comparison needed here.

https://github.com/llvm/llvm-project/pull/130234


More information about the llvm-commits mailing list