[llvm] [AMDGPU] Implement vop3p complex pattern optimization for gisel (PR #130234)

via llvm-commits llvm-commits at lists.llvm.org
Thu Mar 27 20:59:39 PDT 2025


================
@@ -4314,44 +4314,584 @@ AMDGPUInstructionSelector::selectVOP3NoMods(MachineOperand &Root) const {
   }};
 }
 
-std::pair<Register, unsigned>
-AMDGPUInstructionSelector::selectVOP3PModsImpl(
-  Register Src, const MachineRegisterInfo &MRI, bool IsDOT) const {
+/// Relationship between the value being selected (Src) and the operand (Op)
+/// reached while walking up its def chain. The derivations are spelled out in
+/// getNegStatus and calcNextStatus.
+enum class SrcStatus {
+  // Src is Op unchanged.
+  IS_SAME,
+  // Src is the upper half of Op.
+  IS_UPPER_HALF,
+  // Src is the lower half of Op.
+  IS_LOWER_HALF,
+  // Src is the negated upper half of Op.
+  IS_UPPER_HALF_NEG,
+  // Src is the negated lower half of Op.
+  IS_LOWER_HALF_NEG,
+  // [SrcHi, SrcLo] = [-OpHi, OpLo].
+  IS_HI_NEG,
+  // [SrcHi, SrcLo] = [OpHi, -OpLo].
+  IS_LO_NEG,
+  // [SrcHi, SrcLo] = [-OpHi, -OpLo].
+  IS_BOTH_NEG,
+  // No usable relationship; retOpStat turns this into std::nullopt.
+  INVALID,
+  // Inclusive bounds of the statuses that involve a negation.
+  NEG_START = IS_UPPER_HALF_NEG,
+  NEG_END = IS_BOTH_NEG,
+  // Inclusive bounds of the statuses that select one half of Op.
+  HALF_START = IS_UPPER_HALF,
+  HALF_END = IS_LOWER_HALF_NEG
+};
+
+/// Returns true when \p MI is a G_TRUNC whose result is exactly half as wide
+/// as its source.
+static bool isTruncHalf(const MachineInstr *MI,
+                        const MachineRegisterInfo &MRI) {
+  if (MI->getOpcode() == AMDGPU::G_TRUNC) {
+    const unsigned ResultBits =
+        MRI.getType(MI->getOperand(0).getReg()).getSizeInBits();
+    const unsigned InputBits =
+        MRI.getType(MI->getOperand(1).getReg()).getSizeInBits();
+    return InputBits == ResultBits * 2;
+  }
+  return false;
+}
+
+/// Returns true when \p MI is a G_LSHR whose matched constant shift amount is
+/// half the bit width of the shifted value.
+static bool isLshrHalf(const MachineInstr *MI, const MachineRegisterInfo &MRI) {
+  if (MI->getOpcode() != AMDGPU::G_LSHR)
+    return false;
+
+  Register Shifted;
+  std::optional<ValueAndVReg> Amount;
+  const bool Matched = mi_match(MI->getOperand(0).getReg(), MRI,
+                                m_GLShr(m_Reg(Shifted), m_GCst(Amount)));
+  if (!Matched)
+    return false;
+
+  const unsigned ValueBits =
+      MRI.getType(MI->getOperand(1).getReg()).getSizeInBits();
+  const unsigned ShiftBits = Amount->Value.getZExtValue();
+  return ValueBits == ShiftBits * 2;
+}
+
+/// Returns true when \p MI is a G_SHL whose matched constant shift amount is
+/// half the bit width of the shifted value.
+static bool isShlHalf(const MachineInstr *MI, const MachineRegisterInfo &MRI) {
+  if (MI->getOpcode() != AMDGPU::G_SHL)
+    return false;
+
+  Register Shifted;
+  std::optional<ValueAndVReg> Amount;
+  const bool Matched = mi_match(MI->getOperand(0).getReg(), MRI,
+                                m_GShl(m_Reg(Shifted), m_GCst(Amount)));
+  if (!Matched)
+    return false;
+
+  const unsigned ValueBits =
+      MRI.getType(MI->getOperand(1).getReg()).getSizeInBits();
+  const unsigned ShiftBits = Amount->Value.getZExtValue();
+  return ValueBits == ShiftBits * 2;
+}
+
+/// Wraps (\p Op, \p Stat) as the next search state, or returns std::nullopt
+/// when \p Stat is INVALID or \p Op is neither a non-physical register nor an
+/// immediate (Imm, CImm or FPImm). \p Curr is not used.
+static std::optional<std::pair<const MachineOperand *, SrcStatus>>
+retOpStat(const MachineOperand *Op, SrcStatus Stat,
+          std::pair<const MachineOperand *, SrcStatus> &Curr) {
+  if (Stat == SrcStatus::INVALID)
+    return std::nullopt;
+
+  const bool IsNonPhysReg = Op->isReg() && !Op->getReg().isPhysical();
+  const bool IsImmediate = Op->isImm() || Op->isCImm() || Op->isFPImm();
+  if (!IsNonPhysReg && !IsImmediate)
+    return std::nullopt;
+
+  return std::optional<std::pair<const MachineOperand *, SrcStatus>>(
+      {Op, Stat});
+}
+
+/// Coarse type classification returned by isVectorOfTwoOrScalar:
+/// a two-element vector, a scalar, or neither of those.
+enum class TypeClass { VECTOR_OF_TWO, SCALAR, NON_OF_LISTED };
+
+/// Classifies the type of \p Op as scalar, two-element vector, or neither.
+/// Operands that are not registers, or are physical registers, are always
+/// classified as NON_OF_LISTED.
+static TypeClass isVectorOfTwoOrScalar(const MachineOperand *Op,
+                                       const MachineRegisterInfo &MRI) {
+  const bool IsNonPhysReg = Op->isReg() && !Op->getReg().isPhysical();
+  if (IsNonPhysReg) {
+    const LLT Ty = MRI.getType(Op->getReg());
+    if (Ty.isScalar())
+      return TypeClass::SCALAR;
+    if (Ty.isVector() && Ty.getNumElements() == 2)
+      return TypeClass::VECTOR_OF_TWO;
+  }
+  return TypeClass::NON_OF_LISTED;
+}
+
+/// Computes the status obtained by looking through a negate, given the status
+/// \p S accumulated so far. \p Op is the operand whose type decides how the
+/// negate distributes; the result is INVALID when that type is neither a
+/// scalar nor a two-element vector.
+///
+/// Negating a <2 x Type> vector negates both halves:
+///   [CurrHi, CurrLo] = fneg [OpHi, OpLo] = [-OpHi, -OpLo]
+/// Negating a scalar only affects the upper half:
+///   [CurrHi, CurrLo] = fneg [OpHi, OpLo] = [-OpHi, OpLo]
+///
+/// Resulting transitions:
+///   S                  | <2 x Type>         | scalar
+///   -------------------+--------------------+-------------------
+///   IS_SAME            | IS_BOTH_NEG        | IS_HI_NEG
+///   IS_HI_NEG          | IS_LO_NEG          | IS_SAME
+///   IS_LO_NEG          | IS_HI_NEG          | IS_BOTH_NEG
+///   IS_BOTH_NEG        | IS_SAME            | IS_LO_NEG
+///   IS_UPPER_HALF      | IS_UPPER_HALF_NEG  | IS_UPPER_HALF_NEG
+///   IS_LOWER_HALF      | IS_LOWER_HALF_NEG  | IS_LOWER_HALF
+///   IS_UPPER_HALF_NEG  | IS_UPPER_HALF      | IS_UPPER_HALF
+///   IS_LOWER_HALF_NEG  | IS_LOWER_HALF      | IS_LOWER_HALF_NEG
+static SrcStatus getNegStatus(const MachineOperand *Op, SrcStatus S,
+                              const MachineRegisterInfo &MRI) {
+  const TypeClass OpClass = isVectorOfTwoOrScalar(Op, MRI);
+  const bool IsVector = OpClass == TypeClass::VECTOR_OF_TWO;
+  const bool IsScalar = OpClass == TypeClass::SCALAR;
+  if (!IsVector && !IsScalar)
+    return SrcStatus::INVALID;
+
+  // From here on exactly one of IsVector / IsScalar holds, so "!IsVector"
+  // means "scalar".
+  switch (S) {
+  case SrcStatus::IS_SAME:
+    return IsVector ? SrcStatus::IS_BOTH_NEG : SrcStatus::IS_HI_NEG;
+  case SrcStatus::IS_HI_NEG:
+    return IsVector ? SrcStatus::IS_LO_NEG : SrcStatus::IS_SAME;
+  case SrcStatus::IS_LO_NEG:
+    return IsVector ? SrcStatus::IS_HI_NEG : SrcStatus::IS_BOTH_NEG;
+  case SrcStatus::IS_BOTH_NEG:
+    return IsVector ? SrcStatus::IS_SAME : SrcStatus::IS_LO_NEG;
+  case SrcStatus::IS_UPPER_HALF:
+    // The upper half is negated for both vectors and scalars.
+    return SrcStatus::IS_UPPER_HALF_NEG;
+  case SrcStatus::IS_LOWER_HALF:
+    // A scalar negate leaves the lower half untouched.
+    return IsVector ? SrcStatus::IS_LOWER_HALF_NEG : SrcStatus::IS_LOWER_HALF;
+  case SrcStatus::IS_UPPER_HALF_NEG:
+    // -(-OpUpper) = OpUpper for both vectors and scalars.
+    return SrcStatus::IS_UPPER_HALF;
+  case SrcStatus::IS_LOWER_HALF_NEG:
+    return IsVector ? SrcStatus::IS_LOWER_HALF : SrcStatus::IS_LOWER_HALF_NEG;
+  case SrcStatus::INVALID:
+    // Not a status a caller should pass in; reported below.
+    break;
+  }
+  llvm_unreachable("unexpected SrcStatus");
+}
+
+/// Takes one step up the def chain from \p Curr: finds the instruction that
+/// defines (or owns, when the operand is itself a def) the current operand and
+/// returns its source operand 1 together with the updated status. Returns
+/// std::nullopt when the operand is not a register, has no defining
+/// instruction, or the instruction cannot be looked through in this status.
+static std::optional<std::pair<const MachineOperand *, SrcStatus>>
+calcNextStatus(std::pair<const MachineOperand *, SrcStatus> Curr,
+               const MachineRegisterInfo &MRI) {
+  const MachineOperand *CurrOp = Curr.first;
+  if (!CurrOp->isReg())
+    return std::nullopt;
+
+  const MachineInstr *MI = CurrOp->isDef() ? CurrOp->getParent()
+                                           : MRI.getVRegDef(CurrOp->getReg());
+  if (!MI)
+    return std::nullopt;
+
+  // Opcodes handled the same way whatever the current status is.
+  const unsigned Opc = MI->getOpcode();
+  if (Opc == AMDGPU::G_BITCAST || Opc == AMDGPU::G_CONSTANT ||
+      Opc == AMDGPU::G_FCONSTANT || Opc == AMDGPU::COPY)
+    return retOpStat(&MI->getOperand(1), Curr.second, Curr);
+  if (Opc == AMDGPU::G_FNEG)
+    return retOpStat(&MI->getOperand(1),
+                     getNegStatus(CurrOp, Curr.second, MRI), Curr);
+
+  // Status-dependent steps. Next stays INVALID when MI does not match.
+  SrcStatus Next = SrcStatus::INVALID;
+  switch (Curr.second) {
+  case SrcStatus::IS_SAME:
+    if (isTruncHalf(MI, MRI))
+      Next = SrcStatus::IS_LOWER_HALF;
+    break;
+  case SrcStatus::IS_HI_NEG:
+    // [SrcHi, SrcLo]   = [-CurrHi, CurrLo]
+    // [CurrHi, CurrLo] = trunc [OpUpper, OpLower] = OpLower
+    //                  = [OpLowerHi, OpLowerLo]
+    // Src = [-OpLowerHi, OpLowerLo] = -OpLower
+    if (isTruncHalf(MI, MRI))
+      Next = SrcStatus::IS_LOWER_HALF_NEG;
+    break;
+  case SrcStatus::IS_UPPER_HALF:
+    if (isShlHalf(MI, MRI))
+      Next = SrcStatus::IS_LOWER_HALF;
+    break;
+  case SrcStatus::IS_LOWER_HALF:
+    if (isLshrHalf(MI, MRI))
+      Next = SrcStatus::IS_UPPER_HALF;
+    break;
+  case SrcStatus::IS_UPPER_HALF_NEG:
+    if (isShlHalf(MI, MRI))
+      Next = SrcStatus::IS_LOWER_HALF_NEG;
+    break;
+  case SrcStatus::IS_LOWER_HALF_NEG:
+    if (isLshrHalf(MI, MRI))
+      Next = SrcStatus::IS_UPPER_HALF_NEG;
+    break;
+  default:
+    break;
+  }
+
+  // Only touch operand 1 once an instruction shape has been confirmed.
+  if (Next == SrcStatus::INVALID)
+    return std::nullopt;
+  return retOpStat(&MI->getOperand(1), Next, Curr);
+}
+
+/// Records which VOP3P source modifiers may be used for a given root operand
+/// and filters SrcStatus values against them.
+class statOptions {
+private:
+  bool HasNeg;
+  bool HasOpsel;
+
+public:
+  /// Derives the options from the opcode of the instruction that owns
+  /// \p RootOp. \p MRI is currently unused.
+  statOptions(const MachineOperand *RootOp, const MachineRegisterInfo &MRI) {
+    const MachineInstr *MI = RootOp->getParent();
+    unsigned Opc = MI->getOpcode();
+    HasNeg = false;
+    HasOpsel = false;
+    if (Opc < TargetOpcode::GENERIC_OP_END) {
+      // Keep same for generic op.
+      HasNeg = true;
+    } else if (Opc == TargetOpcode::G_INTRINSIC) {
+      // NOTE(review): if G_INTRINSIC itself compares below GENERIC_OP_END this
+      // branch is never reached and every intrinsic gets HasNeg = true;
+      // confirm the intended ordering of these two checks.
+      Intrinsic::ID IntrinsicID = cast<GIntrinsic>(*MI).getIntrinsicID();
+      // Only floating-point intrinsics have neg & neg_hi bits.
+      if (IntrinsicID == Intrinsic::amdgcn_fdot2)
+        HasNeg = true;
+    }
+
+    // Assume all complex patterns of VOP3P have opsel.
+    HasOpsel = true;
+  }
+
+  /// Returns false when \p Stat needs a modifier that is not available:
+  /// a status in [NEG_START, NEG_END] without neg support, or a status in
+  /// [HALF_START, HALF_END] without op_sel support.
+  bool checkOptions(SrcStatus Stat) const {
+    // Both bounds must hold for a range test; with "||" the condition was
+    // always true and every status was rejected whenever HasNeg was false.
+    if (!HasNeg &&
+        (Stat >= SrcStatus::NEG_START && Stat <= SrcStatus::NEG_END)) {
+      return false;
+    }
+    if (!HasOpsel &&
+        (Stat >= SrcStatus::HALF_START && Stat <= SrcStatus::HALF_END)) {
+      return false;
+    }
+    return true;
+  }
+};
+
+/// Walks up the def chain starting at \p Op and collects, in visiting order,
+/// every (operand, status) state accepted by \p StatOptions. The walk stops
+/// when calcNextStatus cannot step further or after MaxDepth + 1 steps.
+static SmallVector<std::pair<const MachineOperand *, SrcStatus>>
+getSrcStats(const MachineOperand *Op, const MachineRegisterInfo &MRI,
+            statOptions StatOptions, int MaxDepth = 6) {
+  SmallVector<std::pair<const MachineOperand *, SrcStatus>, 4> Accepted;
+
+  auto Step = calcNextStatus({Op, SrcStatus::IS_SAME}, MRI);
+  for (int Visited = 0; Visited <= MaxDepth && Step.has_value(); ++Visited) {
+    if (StatOptions.checkOptions(Step->second))
+      Accepted.push_back(*Step);
+    Step = calcNextStatus(*Step, MRI);
+  }
+
+  return Accepted;
+}
+
+/// Walks up the def chain starting at \p Op and returns the last visited state
+/// that is accepted by \p StatOptions and still describes the whole value
+/// (IS_SAME, IS_HI_NEG, IS_LO_NEG or IS_BOTH_NEG). Falls back to
+/// {Op, IS_SAME} when no such state is found.
+static std::pair<const MachineOperand *, SrcStatus>
+getLastSameOrNeg(const MachineOperand *Op, const MachineRegisterInfo &MRI,
+                 statOptions StatOptions, int MaxDepth = 6) {
+  const auto IsWholeValueStatus = [](SrcStatus S) {
+    return S == SrcStatus::IS_SAME || S == SrcStatus::IS_HI_NEG ||
+           S == SrcStatus::IS_LO_NEG || S == SrcStatus::IS_BOTH_NEG;
+  };
+
+  std::pair<const MachineOperand *, SrcStatus> Best = {Op,
+                                                       SrcStatus::IS_SAME};
+  auto Step = calcNextStatus(Best, MRI);
+  for (int Visited = 0; Visited <= MaxDepth && Step.has_value(); ++Visited) {
+    const SrcStatus StepStatus = Step->second;
+    if (StatOptions.checkOptions(StepStatus) && IsWholeValueStatus(StepStatus))
+      Best = *Step;
+    Step = calcNextStatus(*Step, MRI);
+  }
+
+  return Best;
+}
+
+/// Returns true when \p Op is an FP immediate that \p TII reports as an inline
+/// constant.
+static bool isInlinableFPConstant(const MachineOperand &Op,
+                                  const SIInstrInfo &TII) {
+  if (!Op.isFPImm())
+    return false;
+  return TII.isInlineConstant(Op.getFPImm()->getValueAPF());
+}
+
+/// Returns true when the registers of \p Op1 and \p Op2 have types of equal
+/// size in bits. Both operands must be registers.
+static bool isSameBitWidth(const MachineOperand *Op1, const MachineOperand *Op2,
+                           const MachineRegisterInfo &MRI) {
+  const auto BitsOf = [&MRI](const MachineOperand *Op) -> unsigned {
+    return MRI.getType(Op->getReg()).getSizeInBits();
+  };
+  return BitsOf(Op1) == BitsOf(Op2);
+}
+
+/// Returns true when both operands refer to the same thing: the same register
+/// when \p Op1 is a register, otherwise operands that compare identical.
+static bool isSameOperand(const MachineOperand *Op1,
+                          const MachineOperand *Op2) {
+  if (!Op1->isReg())
+    return Op1->isIdenticalTo(*Op2);
+  return Op2->isReg() && Op1->getReg() == Op2->getReg();
+}
+
+/// Folds the statuses of the high (\p HiStat) and low (\p LoStat) elements
+/// into the modifier mask \p Mods and returns the result.
+///
+/// The high element is described by OP_SEL_1 / NEG_HI and the low element by
+/// OP_SEL_0 / NEG. Selecting an upper half sets the op_sel bit; selecting a
+/// lower half leaves it clear. Every negation toggles (XOR) its neg bit so
+/// that it cancels a negation already recorded in \p Mods.
+static unsigned updateMods(SrcStatus HiStat, SrcStatus LoStat, unsigned Mods) {
+  // SrcStatus::IS_LOWER_HALF remain 0.
+  if (HiStat == SrcStatus::IS_UPPER_HALF_NEG) {
+    Mods ^= SISrcMods::NEG_HI;
+    Mods |= SISrcMods::OP_SEL_1;
+  } else if (HiStat == SrcStatus::IS_UPPER_HALF)
+    Mods |= SISrcMods::OP_SEL_1;
+  else if (HiStat == SrcStatus::IS_LOWER_HALF_NEG)
+    Mods ^= SISrcMods::NEG_HI;
+  else if (HiStat == SrcStatus::IS_HI_NEG)
+    Mods ^= SISrcMods::NEG_HI;
+
+  if (LoStat == SrcStatus::IS_UPPER_HALF_NEG) {
+    Mods ^= SISrcMods::NEG;
+    Mods |= SISrcMods::OP_SEL_0;
+  } else if (LoStat == SrcStatus::IS_UPPER_HALF)
+    Mods |= SISrcMods::OP_SEL_0;
+  else if (LoStat == SrcStatus::IS_LOWER_HALF_NEG)
+    // Toggle like every other negation here; "|=" would leave NEG set when
+    // the incoming mask already carries it, losing the double negation.
+    Mods ^= SISrcMods::NEG;
+  else if (LoStat == SrcStatus::IS_HI_NEG)
+    Mods ^= SISrcMods::NEG;
+
+  return Mods;
+}
+
+/// Decides whether \p NewOp can replace the packed source rooted at
+/// \p RootOp, given the statuses of its high and low elements.
+///
+/// A register qualifies when it has the same bit width as \p RootOp and both
+/// elements are (possibly negated) halves of it. A non-register qualifies
+/// when both statuses are IS_SAME or IS_HI_NEG and it is an inlinable FP
+/// constant.
+static bool isValidToPack(SrcStatus HiStat, SrcStatus LoStat,
+                          const MachineOperand *NewOp,
+                          const MachineOperand *RootOp, const SIInstrInfo &TII,
+                          const MachineRegisterInfo &MRI) {
+  if (!NewOp->isReg()) {
+    const auto IsSameOrHiNeg = [](SrcStatus S) {
+      return S == SrcStatus::IS_SAME || S == SrcStatus::IS_HI_NEG;
+    };
+    return IsSameOrHiNeg(HiStat) && IsSameOrHiNeg(LoStat) &&
+           isInlinableFPConstant(*NewOp, TII);
+  }
+
+  const auto IsHalfStatus = [](SrcStatus S) {
+    switch (S) {
+    case SrcStatus::IS_UPPER_HALF:
+    case SrcStatus::IS_UPPER_HALF_NEG:
+    case SrcStatus::IS_LOWER_HALF:
+    case SrcStatus::IS_LOWER_HALF_NEG:
+      return true;
+    default:
+      return false;
+    }
+  };
+  return isSameBitWidth(NewOp, RootOp, MRI) && IsHalfStatus(LoStat) &&
+         IsHalfStatus(HiStat);
+}
+
+/// Selects the source operand and SISrcMods mask for a VOP3P source rooted at
+/// \p RootOp.
+///
+/// Sign changes found while walking up the def chain are folded into NEG /
+/// NEG_HI. When the resulting operand is defined by a two-source
+/// G_BUILD_VECTOR whose elements are both halves of one register (or the same
+/// inlinable FP constant), that operand is returned with the op_sel / neg
+/// bits computed by updateMods. In every other case the operand reached so
+/// far is returned with OP_SEL_1 set.
+std::pair<const MachineOperand *, unsigned>
+AMDGPUInstructionSelector::selectVOP3PModsImpl(const MachineOperand *RootOp,
+                                               const MachineRegisterInfo &MRI,
+                                               bool IsDOT) const {
   unsigned Mods = 0;
-  MachineInstr *MI = MRI.getVRegDef(Src);
+  const MachineOperand *Op = RootOp;
+  // No modification if the Root type is not of the form <2 x Type>.
+  if (isVectorOfTwoOrScalar(Op, MRI) != TypeClass::VECTOR_OF_TWO) {
+    Mods |= SISrcMods::OP_SEL_1;
+    return {Op, Mods};
+  }
 
-  if (MI->getOpcode() == AMDGPU::G_FNEG &&
-      // It's possible to see an f32 fneg here, but unlikely.
-      // TODO: Treat f32 fneg as only high bit.
-      MRI.getType(Src) == LLT::fixed_vector(2, 16)) {
+  statOptions StatOptions(Op, MRI);
+
+  // Find the last operand up the chain that is still the whole value, up to
+  // sign changes of its halves.
+  std::pair<const MachineOperand *, SrcStatus> Stat =
+      getLastSameOrNeg(Op, MRI, StatOptions);
+  if (!Stat.first->isReg()) {
+    Mods |= SISrcMods::OP_SEL_1;
+    return {Op, Mods};
+  }
+  // Record the sign changes found on the way as neg modifiers.
+  if (Stat.second == SrcStatus::IS_BOTH_NEG)
     Mods ^= (SISrcMods::NEG | SISrcMods::NEG_HI);
-    Src = MI->getOperand(1).getReg();
-    MI = MRI.getVRegDef(Src);
+  else if (Stat.second == SrcStatus::IS_HI_NEG)
+    Mods ^= SISrcMods::NEG_HI;
+  else if (Stat.second == SrcStatus::IS_LO_NEG)
+    Mods ^= SISrcMods::NEG;
+
+  Op = Stat.first;
+  MachineInstr *MI = MRI.getVRegDef(Op->getReg());
+
+  // Only a two-source G_BUILD_VECTOR can be repacked, and not for DOT
+  // instructions on subtargets with the DOT op_sel hazard.
+  if (MI->getOpcode() != AMDGPU::G_BUILD_VECTOR || MI->getNumOperands() != 3 ||
+      (IsDOT && Subtarget->hasDOTOpSelHazard())) {
+    Mods |= SISrcMods::OP_SEL_1;
+    return {Op, Mods};
+  }
+
+  // Candidate sources for the second build-vector element (operand 2), used
+  // below as the high element.
+  SmallVector<std::pair<const MachineOperand *, SrcStatus>> StatlistHi =
+      getSrcStats(&MI->getOperand(2), MRI, StatOptions);
+
+  if (StatlistHi.size() == 0) {
+    Mods |= SISrcMods::OP_SEL_1;
+    return {Op, Mods};
   }
 
-  // TODO: Handle G_FSUB 0 as fneg
+  // Candidate sources for the first build-vector element (operand 1), used
+  // below as the low element.
+  SmallVector<std::pair<const MachineOperand *, SrcStatus>> StatlistLo =
+      getSrcStats(&MI->getOperand(1), MRI, StatOptions);
 
-  // TODO: Match op_sel through g_build_vector_trunc and g_shuffle_vector.
-  (void)IsDOT; // DOTs do not use OPSEL on gfx942+, check ST.hasDOTOpSelHazard()
+  if (StatlistLo.size() == 0) {
+    Mods |= SISrcMods::OP_SEL_1;
+    return {Op, Mods};
+  }
 
+  // Starting from the candidates found last (furthest up the chain), look for
+  // one operand that can supply both elements.
+  for (int i = StatlistHi.size() - 1; i >= 0; i--) {
+    for (int j = StatlistLo.size() - 1; j >= 0; j--) {
+      if (isSameOperand(StatlistHi[i].first, StatlistLo[j].first) &&
+          isValidToPack(StatlistHi[i].second, StatlistLo[j].second,
+                        StatlistHi[i].first, RootOp, TII, MRI))
+        return {StatlistHi[i].first,
+                updateMods(StatlistHi[i].second, StatlistLo[j].second, Mods)};
+    }
+  }
   // Packed instructions do not have abs modifiers.
   Mods |= SISrcMods::OP_SEL_1;
 
-  return std::pair(Src, Mods);
+  return {Op, Mods};
+}
+
+/// Returns the value of an immediate operand as a signed 64-bit integer:
+/// the plain immediate, the sign-extended CImm value, or the sign-extended
+/// bit pattern of an FPImm. Any other operand kind is unreachable.
+int64_t getAllKindImm(const MachineOperand *Op) {
+  const auto Kind = Op->getType();
+  if (Kind == MachineOperand::MachineOperandType::MO_Immediate)
+    return Op->getImm();
+  if (Kind == MachineOperand::MachineOperandType::MO_CImmediate)
+    return Op->getCImm()->getSExtValue();
+  if (Kind == MachineOperand::MachineOperandType::MO_FPImmediate)
+    return Op->getFPImm()->getValueAPF().bitcastToAPInt().getSExtValue();
+  llvm_unreachable("not an imm type");
+}
+
+/// Returns true when the register of \p Op is assigned to the register bank
+/// whose ID is \p RBNo.
+static bool checkRB(const MachineOperand *Op, unsigned int RBNo,
+                    const AMDGPURegisterBankInfo &RBI,
+                    const MachineRegisterInfo &MRI,
+                    const TargetRegisterInfo &TRI) {
+  const Register Reg = Op->getReg();
+  const RegisterBank *Bank = RBI.getRegBank(Reg, MRI, TRI);
+  return Bank->getID() == RBNo;
+}
+
+/// Returns an operand to use in place of \p NewOp given the register bank of
+/// \p RootOp:
+///  - \p NewOp itself when \p RootOp is in the SGPR bank or \p NewOp is
+///    already in the VGPR bank;
+///  - \p RootOp when it is defined by a COPY of \p NewOp;
+///  - otherwise the def operand of a new COPY of \p NewOp, inserted before the
+///    instruction that defines \p RootOp, into a fresh virtual register of the
+///    class constrained for \p RootOp.
+const MachineOperand *
+getVReg(const MachineOperand *NewOp, const MachineOperand *RootOp,
+        const AMDGPURegisterBankInfo &RBI, MachineRegisterInfo &MRI,
+        const TargetRegisterInfo &TRI, const SIInstrInfo &TII) {
+  // RootOp can only be VGPR or SGPR (some hand written cases such as
+  // inst-select-ashr.v2s16.mir::ashr_v2s16_vs).
+  if (checkRB(RootOp, AMDGPU::SGPRRegBankID, RBI, MRI, TRI) ||
+      checkRB(NewOp, AMDGPU::VGPRRegBankID, RBI, MRI, TRI))
+    return NewOp;
+
+  MachineInstr *MI = MRI.getVRegDef(RootOp->getReg());
+  if (MI->getOpcode() == AMDGPU::COPY &&
+      isSameOperand(NewOp, &MI->getOperand(1))) {
+    // RootOp is VGPR, NewOp is not VGPR, but RootOp = COPY NewOp.
+    return RootOp;
+  }
+
+  // No existing copy can be reused: build one right before RootOp's def.
+  MachineBasicBlock *BB = MI->getParent();
+  const TargetRegisterClass *DstRC =
+      TRI.getConstrainedRegClassForOperand(*RootOp, MRI);
+  Register DstReg = MRI.createVirtualRegister(DstRC);
+
+  MachineInstrBuilder MIB =
+      BuildMI(*BB, MI, MI->getDebugLoc(), TII.get(AMDGPU::COPY), DstReg)
+          .addReg(NewOp->getReg());
+
+  // Only accept VGPR: hand back the def operand of the new COPY.
+  return &MIB->getOperand(0);
 }
 
 InstructionSelector::ComplexRendererFns
 AMDGPUInstructionSelector::selectVOP3PMods(MachineOperand &Root) const {
   MachineRegisterInfo &MRI
     = Root.getParent()->getParent()->getParent()->getRegInfo();
 
-  Register Src;
-  unsigned Mods;
-  std::tie(Src, Mods) = selectVOP3PModsImpl(Root.getReg(), MRI);
+  auto [Op, Mods] = selectVOP3PModsImpl(&Root, MRI);
+  if (!(Op->isReg()))
+    return {{
+        [=](MachineInstrBuilder &MIB) { MIB.addImm(getAllKindImm(Op)); },
+        [=](MachineInstrBuilder &MIB) { MIB.addImm(Mods); } // src_mods
+    }};
 
+  Op = getVReg(Op, &Root, RBI, MRI, TRI, TII);
   return {{
-      [=](MachineInstrBuilder &MIB) { MIB.addReg(Src); },
-      [=](MachineInstrBuilder &MIB) { MIB.addImm(Mods); }  // src_mods
+      [=](MachineInstrBuilder &MIB) { MIB.addReg(Op->getReg()); },
----------------
Shoreshen wrote:

Hi @rovka, when I use add(Op) instead, the related tests crash or fail T_T. I'll have a look into it when I have time.

https://github.com/llvm/llvm-project/pull/130234


More information about the llvm-commits mailing list