[llvm] [AMDGPU] Implement vop3p complex pattern optimization for gisel (PR #130234)

via llvm-commits llvm-commits at lists.llvm.org
Tue Mar 11 19:46:21 PDT 2025


================
@@ -4282,44 +4282,362 @@ AMDGPUInstructionSelector::selectVOP3NoMods(MachineOperand &Root) const {
   }};
 }
 
-std::pair<Register, unsigned>
-AMDGPUInstructionSelector::selectVOP3PModsImpl(
-  Register Src, const MachineRegisterInfo &MRI, bool IsDOT) const {
-  unsigned Mods = 0;
-  MachineInstr *MI = MRI.getVRegDef(Src);
+enum srcStatus { // Status codes; calcNextStatus stores one in the int half of its (operand, status) pair.
+  IS_SAME, // Status getSrcStats assigns to the starting operand.
+  IS_UPPER_HALF,
+  IS_LOWER_HALF,
+  IS_NEG, // From here on, each enumerator is the one three places earlier with _NEG appended.
+  IS_UPPER_HALF_NEG,
+  IS_LOWER_HALF_NEG,
+  LAST_STAT = IS_LOWER_HALF_NEG // Highest value; the G_FNEG case in calcNextStatus adds (LAST_STAT + 1) / 2 modulo LAST_STAT + 1.
+};
+
+static bool isTruncHalf(const MachineInstr *MI, // Returns true when the G_TRUNC result is exactly half as wide as its source.
+                        const MachineRegisterInfo &MRI) {
+  assert(MI->getOpcode() == AMDGPU::G_TRUNC); // Callers must pass a G_TRUNC.
+  unsigned dstSize = MRI.getType(MI->getOperand(0).getReg()).getSizeInBits(); // Bit width of operand 0's type.
+  unsigned srcSize = MRI.getType(MI->getOperand(1).getReg()).getSizeInBits(); // Bit width of operand 1's type.
+  return dstSize * 2 == srcSize;
+}
+
+static bool isLshrHalf(const MachineInstr *MI, const MachineRegisterInfo &MRI) { // Returns true when MI shifts right by an amount equal to half the bit width of operand 1.
+  assert(MI->getOpcode() == AMDGPU::G_LSHR); // Callers must pass a G_LSHR.
+  Register ShiftSrc; // Bound by the matcher below; not read afterwards.
+  std::optional<ValueAndVReg> ShiftAmt;
+  if (mi_match(MI->getOperand(0).getReg(), MRI,
+               m_GLShr(m_Reg(ShiftSrc), m_GCst(ShiftAmt)))) { // Matches on MI's result register; ShiftAmt is only dereferenced inside this branch.
+    unsigned srcSize = MRI.getType(MI->getOperand(1).getReg()).getSizeInBits(); // Bit width of operand 1's type.
+    unsigned shift = ShiftAmt->Value.getZExtValue();
+    return shift * 2 == srcSize;
+  }
+  return false; // Pattern did not match (presumably a non-constant shift amount).
+}
 
-  if (MI->getOpcode() == AMDGPU::G_FNEG &&
-      // It's possible to see an f32 fneg here, but unlikely.
-      // TODO: Treat f32 fneg as only high bit.
-      MRI.getType(Src) == LLT::fixed_vector(2, 16)) {
-    Mods ^= (SISrcMods::NEG | SISrcMods::NEG_HI);
-    Src = MI->getOperand(1).getReg();
-    MI = MRI.getVRegDef(Src);
+static bool isShlHalf(const MachineInstr *MI, const MachineRegisterInfo &MRI) { // Returns true when MI shifts left by an amount equal to half the bit width of operand 1.
+  assert(MI->getOpcode() == AMDGPU::G_SHL); // Callers must pass a G_SHL.
+  Register ShiftSrc; // Bound by the matcher below; not read afterwards.
+  std::optional<ValueAndVReg> ShiftAmt;
+  if (mi_match(MI->getOperand(0).getReg(), MRI,
+               m_GShl(m_Reg(ShiftSrc), m_GCst(ShiftAmt)))) { // Matches on MI's result register; ShiftAmt is only dereferenced inside this branch.
+    unsigned srcSize = MRI.getType(MI->getOperand(1).getReg()).getSizeInBits(); // Bit width of operand 1's type.
+    unsigned shift = ShiftAmt->Value.getZExtValue();
+    return shift * 2 == srcSize;
+  }
+  return false; // Pattern did not match (presumably a non-constant shift amount).
+}
+
+static bool retOpStat(const MachineOperand *Op, int stat, // Stores {Op, stat} into curr and returns true if Op is an accepted operand kind; otherwise returns false and leaves curr untouched.
+                      std::pair<const MachineOperand *, int> &curr) {
+  if ((Op->isReg() && !(Op->getReg().isPhysical())) || Op->isImm() || // Accepted: non-physical registers and Imm / CImm / FPImm operands.
+      Op->isCImm() || Op->isFPImm()) {
+    curr = {Op, stat};
+    return true;
+  }
+  return false; // Physical registers and every other operand kind are rejected.
+}
+
+static bool calcNextStatus(std::pair<const MachineOperand *, int> &curr, // Advances curr one step to operand 1 of its defining instruction with an updated status; returns false and leaves curr unchanged when no step applies.
+                           const MachineRegisterInfo &MRI) {
+  if (!curr.first->isReg()) { // Only register operands can be followed to an instruction.
+    return false;
+  }
+  const MachineInstr *MI = nullptr;
+
+  if (!curr.first->isDef()) { // A use: look up the instruction that defines the register.
+    MI = MRI.getVRegDef(curr.first->getReg());
+  } else { // A def: the operand's own instruction is the one to inspect.
+    MI = curr.first->getParent();
+  }
+  if (!MI) {
+    return false;
+  }
+
+  unsigned Opc = MI->getOpcode();
+
+  // Handle general Opc cases
+  switch (Opc) {
+  case AMDGPU::G_BITCAST: // These four opcodes pass the current status through unchanged to operand 1.
+  case AMDGPU::G_CONSTANT:
+  case AMDGPU::G_FCONSTANT:
+  case AMDGPU::COPY:
+    return retOpStat(&MI->getOperand(1), curr.second, curr);
+  case AMDGPU::G_FNEG:
+    // XXXX + 3 = XXXX_NEG, (XXXX_NEG + 3) mod 6 = XXXX
+    return retOpStat(&MI->getOperand(1),
+                     (curr.second + ((LAST_STAT + 1) / 2)) % (LAST_STAT + 1),
+                     curr);
+  }
+
+  // Calc next stat from current stat
+  switch (curr.second) {
+  case IS_SAME: // A half-width G_TRUNC moves IS_SAME to IS_LOWER_HALF of the truncated source.
+    switch (Opc) {
+    case AMDGPU::G_TRUNC: {
+      if (isTruncHalf(MI, MRI)) {
+        return retOpStat(&MI->getOperand(1), IS_LOWER_HALF, curr);
+      }
+      break;
+    }
+    }
+    break;
+  case IS_NEG: // Same as IS_SAME, but the result keeps the _NEG marker.
+    switch (Opc) {
+    case AMDGPU::G_TRUNC: {
+      if (isTruncHalf(MI, MRI)) {
+        return retOpStat(&MI->getOperand(1), IS_LOWER_HALF_NEG, curr);
+      }
+      break;
+    }
+    }
+    break;
+  case IS_UPPER_HALF: // A half-width G_SHL moves IS_UPPER_HALF to IS_LOWER_HALF of the shifted source.
+    switch (Opc) {
+    case AMDGPU::G_SHL: {
+      if (isShlHalf(MI, MRI)) {
+        return retOpStat(&MI->getOperand(1), IS_LOWER_HALF, curr);
+      }
+      break;
+    }
+    }
+    break;
+  case IS_LOWER_HALF: // A half-width G_LSHR moves IS_LOWER_HALF to IS_UPPER_HALF of the shifted source.
+    switch (Opc) {
+    case AMDGPU::G_LSHR: {
+      if (isLshrHalf(MI, MRI)) {
+        return retOpStat(&MI->getOperand(1), IS_UPPER_HALF, curr);
+      }
+      break;
+    }
+    }
+    break;
+  case IS_UPPER_HALF_NEG: // Same as IS_UPPER_HALF, but the result keeps the _NEG marker.
+    switch (Opc) {
+    case AMDGPU::G_SHL: {
+      if (isShlHalf(MI, MRI)) {
+        return retOpStat(&MI->getOperand(1), IS_LOWER_HALF_NEG, curr);
+      }
+      break;
+    }
+    }
+    break;
+  case IS_LOWER_HALF_NEG: // Same as IS_LOWER_HALF, but the result keeps the _NEG marker.
+    switch (Opc) {
+    case AMDGPU::G_LSHR: {
+      if (isLshrHalf(MI, MRI)) {
+        return retOpStat(&MI->getOperand(1), IS_UPPER_HALF_NEG, curr);
+      }
+      break;
+    }
+    }
+    break;
+  }
+  return false; // No opcode/status combination matched; curr has not been modified.
+}
+
+SmallVector<std::pair<const MachineOperand *, int>>
+getSrcStats(const MachineOperand *Op, const MachineRegisterInfo &MRI,
+            bool onlyLastSameOrNeg = false, int maxDepth = 6) {
+  int depth = 0;
+  std::pair<const MachineOperand *, int> curr = {Op, IS_SAME};
+  SmallVector<std::pair<const MachineOperand *, int>> statList;
+
+  while (true) {
+    depth++;
+    if (depth > maxDepth) {
+      break;
+    }
+    bool ret = calcNextStatus(curr, MRI);
----------------
Shoreshen wrote:

Hi @jmmartinez, this is mainly for onlyLastSameOrNeg. Because curr is passed by reference, if calcNextStatus fails, curr keeps its last value.

onlyLastSameOrNeg means picking the last op that is the same as, or the negation of, the original op. This is used to determine whether the current operand of a VOP3P instruction needs an fneg (e.g. v_pk_fadd fneg(build_vector op1, op2), op3).

To find that out, I have to keep searching until the max depth is exceeded or calcNextStatus fails. When calcNextStatus fails, the last curr is the only term to push into the result.

But the loop can be simplified as you said; I moved some of the conditions into the while. Thanks a lot.


https://github.com/llvm/llvm-project/pull/130234


More information about the llvm-commits mailing list