[llvm] [AMDGPU] Implement vop3p complex pattern optimization for gisel (PR #130234)

Diana Picus via llvm-commits llvm-commits at lists.llvm.org
Thu Mar 27 06:18:20 PDT 2025


================
@@ -4314,44 +4314,584 @@ AMDGPUInstructionSelector::selectVOP3NoMods(MachineOperand &Root) const {
   }};
 }
 
-std::pair<Register, unsigned>
-AMDGPUInstructionSelector::selectVOP3PModsImpl(
-  Register Src, const MachineRegisterInfo &MRI, bool IsDOT) const {
+/// How the original source value ("Src") relates to the operand currently
+/// being inspected ("Curr") while walking up its def chain. "Hi"/"Lo" and
+/// "Upper"/"Lower" name the two halves of a value, as in the [Hi, Lo]
+/// diagrams in getNegStatus.
+enum class SrcStatus {
+  IS_SAME,           // [SrcHi, SrcLo] = [CurrHi, CurrLo].
+  IS_UPPER_HALF,     // Src = CurrUpper.
+  IS_LOWER_HALF,     // Src = CurrLower.
+  IS_UPPER_HALF_NEG, // Src = -CurrUpper.
+  IS_LOWER_HALF_NEG, // Src = -CurrLower.
+  IS_HI_NEG,         // [SrcHi, SrcLo] = [-CurrHi, CurrLo].
+  IS_LO_NEG,         // [SrcHi, SrcLo] = [CurrHi, -CurrLo].
+  IS_BOTH_NEG,       // [SrcHi, SrcLo] = [-CurrHi, -CurrLo].
+  INVALID,           // The walk cannot continue (e.g. unsupported type).
+  // Range markers. From the enumerator order they appear to be inclusive
+  // bounds: [NEG_START, NEG_END] covers exactly the negated statuses and
+  // [HALF_START, HALF_END] the half-extraction statuses, so the enumerators
+  // above must not be reordered. NOTE(review): their users are outside this
+  // hunk; confirm they treat the bounds as inclusive.
+  NEG_START = IS_UPPER_HALF_NEG,
+  NEG_END = IS_BOTH_NEG,
+  HALF_START = IS_UPPER_HALF,
+  HALF_END = IS_LOWER_HALF_NEG
+};
+
+/// Returns true if \p MI is a G_TRUNC whose result is exactly half as wide as
+/// its input, i.e. it keeps the lower half of the input value.
+static bool isTruncHalf(const MachineInstr *MI,
+                        const MachineRegisterInfo &MRI) {
+  if (MI->getOpcode() == AMDGPU::G_TRUNC) {
+    Register Result = MI->getOperand(0).getReg();
+    Register Input = MI->getOperand(1).getReg();
+    unsigned ResultBits = MRI.getType(Result).getSizeInBits();
+    unsigned InputBits = MRI.getType(Input).getSizeInBits();
+    return InputBits == ResultBits * 2;
+  }
+  return false;
+}
+
+/// Returns true if the shift amount of the shift instruction \p MI (operand 2)
+/// is a constant, matched with m_GCst, equal to half the bit width of the
+/// shifted value (operand 1). The caller is responsible for checking that
+/// \p MI is a shift.
+static bool isShiftByHalfWidth(const MachineInstr *MI,
+                               const MachineRegisterInfo &MRI) {
+  std::optional<ValueAndVReg> ShiftAmt;
+  if (!mi_match(MI->getOperand(2).getReg(), MRI, m_GCst(ShiftAmt)))
+    return false;
+
+  unsigned SrcSize = MRI.getType(MI->getOperand(1).getReg()).getSizeInBits();
+  // Compare the APInt directly instead of going through getZExtValue(): that
+  // asserts on constants wider than 64 bits, and narrowing its result to
+  // `unsigned` could make an out-of-range amount look like a half-width
+  // shift. The parity check keeps an odd width (where SrcSize / 2 rounds
+  // down) from matching.
+  return SrcSize % 2 == 0 && ShiftAmt->Value == SrcSize / 2;
+}
+
+/// Returns true if \p MI is a G_LSHR by exactly half the width of its source,
+/// which moves the upper half of the source into the lower half.
+static bool isLshrHalf(const MachineInstr *MI, const MachineRegisterInfo &MRI) {
+  return MI->getOpcode() == AMDGPU::G_LSHR && isShiftByHalfWidth(MI, MRI);
+}
+
+/// Returns true if \p MI is a G_SHL by exactly half the width of its source,
+/// which moves the lower half of the source into the upper half.
+static bool isShlHalf(const MachineInstr *MI, const MachineRegisterInfo &MRI) {
+  return MI->getOpcode() == AMDGPU::G_SHL && isShiftByHalfWidth(MI, MRI);
+}
+
+/// Packages \p Op and \p Stat as the next step of the def-chain walk.
+///
+/// Returns std::nullopt when that step is unusable: \p Stat is
+/// SrcStatus::INVALID, or \p Op is neither a non-physical register nor an
+/// immediate (Imm, CImm or FPImm). \p Curr is currently not read.
+static std::optional<std::pair<const MachineOperand *, SrcStatus>>
+retOpStat(const MachineOperand *Op, SrcStatus Stat,
+          std::pair<const MachineOperand *, SrcStatus> &Curr) {
+  if (Stat == SrcStatus::INVALID)
+    return std::nullopt;
+
+  const bool IsNonPhysReg = Op->isReg() && !Op->getReg().isPhysical();
+  const bool IsImmediate = Op->isImm() || Op->isCImm() || Op->isFPImm();
+  if (!IsNonPhysReg && !IsImmediate)
+    return std::nullopt;
+
+  return std::make_pair(Op, Stat);
+}
+
+/// Classification returned by isVectorOfTwoOrScalar(): a two-element vector,
+/// a scalar, or anything else. NOTE(review): NON_OF_LISTED reads like a typo
+/// for NONE_OF_LISTED; renaming it must update every user.
+enum class TypeClass { VECTOR_OF_TWO, SCALAR, NON_OF_LISTED };
+
+/// Classifies the LLT of \p Op. Non-register operands and physical registers
+/// have no LLT to inspect here and are reported as NON_OF_LISTED, as is any
+/// type that is neither a scalar nor a two-element vector.
+static TypeClass isVectorOfTwoOrScalar(const MachineOperand *Op,
+                                       const MachineRegisterInfo &MRI) {
+  const bool IsNonPhysReg = Op->isReg() && !Op->getReg().isPhysical();
+  if (IsNonPhysReg) {
+    const LLT Ty = MRI.getType(Op->getReg());
+    if (Ty.isScalar())
+      return TypeClass::SCALAR;
+    if (Ty.isVector() && Ty.getNumElements() == 2)
+      return TypeClass::VECTOR_OF_TWO;
+  }
+  return TypeClass::NON_OF_LISTED;
+}
+
+/// Maps status \p S, which the source has relative to a negated value, onto
+/// the status the source has relative to the negation's input.
+///
+/// \p Op is the result of the negation ("Curr" in the diagrams below). The
+/// diagrams write the negation's input as Op. Negating a two-element vector
+/// flips the sign of both halves. Negating a scalar flips only its sign bit,
+/// i.e. the sign of its upper half. Returns SrcStatus::INVALID when \p Op is
+/// neither a scalar nor a two-element vector.
+static SrcStatus getNegStatus(const MachineOperand *Op, SrcStatus S,
+                              const MachineRegisterInfo &MRI) {
+  TypeClass NegType = isVectorOfTwoOrScalar(Op, MRI);
+  if (NegType != TypeClass::VECTOR_OF_TWO && NegType != TypeClass::SCALAR)
+    return SrcStatus::INVALID;
+
+  // Past the check above NegType is one of exactly two classes, so a single
+  // flag selects between the vector and scalar result of every case.
+  const bool IsVector = NegType == TypeClass::VECTOR_OF_TWO;
+
+  switch (S) {
+  case SrcStatus::IS_SAME:
+    // Vector of 2:
+    // [SrcHi, SrcLo]   = [CurrHi, CurrLo]
+    // [CurrHi, CurrLo] = fneg [OpHi, OpLo](2 x Type)
+    // [CurrHi, CurrLo] = [-OpHi, -OpLo](2 x Type)
+    // [SrcHi, SrcLo]   = [-OpHi, -OpLo]
+    //
+    // Scalar:
+    // [SrcHi, SrcLo]   = [CurrHi, CurrLo]
+    // [CurrHi, CurrLo] = fneg [OpHi, OpLo](Type)
+    // [CurrHi, CurrLo] = [-OpHi, OpLo](Type)
+    // [SrcHi, SrcLo]   = [-OpHi, OpLo]
+    return IsVector ? SrcStatus::IS_BOTH_NEG : SrcStatus::IS_HI_NEG;
+  case SrcStatus::IS_HI_NEG:
+    // Vector of 2:
+    // [SrcHi, SrcLo]   = [-CurrHi, CurrLo]
+    // [CurrHi, CurrLo] = fneg [OpHi, OpLo](2 x Type)
+    // [CurrHi, CurrLo] = [-OpHi, -OpLo](2 x Type)
+    // [SrcHi, SrcLo]   = [-(-OpHi), -OpLo] = [OpHi, -OpLo]
+    //
+    // Scalar:
+    // [SrcHi, SrcLo]   = [-CurrHi, CurrLo]
+    // [CurrHi, CurrLo] = fneg [OpHi, OpLo](Type)
+    // [CurrHi, CurrLo] = [-OpHi, OpLo](Type)
+    // [SrcHi, SrcLo]   = [-(-OpHi), OpLo] = [OpHi, OpLo]
+    return IsVector ? SrcStatus::IS_LO_NEG : SrcStatus::IS_SAME;
+  case SrcStatus::IS_LO_NEG:
+    // Vector of 2:
+    // [SrcHi, SrcLo]   = [CurrHi, -CurrLo]
+    // [CurrHi, CurrLo] = fneg [OpHi, OpLo](2 x Type)
+    // [CurrHi, CurrLo] = [-OpHi, -OpLo](2 x Type)
+    // [SrcHi, SrcLo]   = [-OpHi, -(-OpLo)] = [-OpHi, OpLo]
+    //
+    // Scalar:
+    // [SrcHi, SrcLo]   = [CurrHi, -CurrLo]
+    // [CurrHi, CurrLo] = fneg [OpHi, OpLo](Type)
+    // [CurrHi, CurrLo] = [-OpHi, OpLo](Type)
+    // [SrcHi, SrcLo]   = [-OpHi, -OpLo]
+    return IsVector ? SrcStatus::IS_HI_NEG : SrcStatus::IS_BOTH_NEG;
+  case SrcStatus::IS_BOTH_NEG:
+    // Vector of 2:
+    // [SrcHi, SrcLo]   = [-CurrHi, -CurrLo]
+    // [CurrHi, CurrLo] = fneg [OpHi, OpLo](2 x Type)
+    // [CurrHi, CurrLo] = [-OpHi, -OpLo](2 x Type)
+    // [SrcHi, SrcLo]   = [OpHi, OpLo]
+    //
+    // Scalar:
+    // [SrcHi, SrcLo]   = [-CurrHi, -CurrLo]
+    // [CurrHi, CurrLo] = fneg [OpHi, OpLo](Type)
+    // [CurrHi, CurrLo] = [-OpHi, OpLo](Type)
+    // [SrcHi, SrcLo]   = [OpHi, -OpLo]
+    return IsVector ? SrcStatus::IS_SAME : SrcStatus::IS_LO_NEG;
+  case SrcStatus::IS_UPPER_HALF:
+    // Both vector and scalar negation flip the sign of the upper half:
+    // Src = CurrUpper
+    // Curr = [CurrUpper, CurrLower]
+    // [CurrUpper, CurrLower] = fneg [OpUpper, OpLower]
+    //   Vector of 2: [-OpUpper, -OpLower]; Scalar: [-OpUpper, OpLower]
+    // Src = -OpUpper
+    return SrcStatus::IS_UPPER_HALF_NEG;
+  case SrcStatus::IS_LOWER_HALF:
+    // Vector of 2:
+    // Src = CurrLower
+    // Curr = [CurrUpper, CurrLower]
+    // [CurrUpper, CurrLower] = fneg [OpUpper, OpLower](2 x Type)
+    // [CurrUpper, CurrLower] = [-OpUpper, -OpLower](2 x Type)
+    // Src = -OpLower
+    //
+    // Scalar:
+    // Src = CurrLower
+    // Curr = [CurrUpper, CurrLower]
+    // [CurrUpper, CurrLower] = fneg [OpUpper, OpLower](Type)
+    // [CurrUpper, CurrLower] = [-OpUpper, OpLower](Type)
+    // Src = OpLower
+    return IsVector ? SrcStatus::IS_LOWER_HALF_NEG : SrcStatus::IS_LOWER_HALF;
+  case SrcStatus::IS_UPPER_HALF_NEG:
+    // Both vector and scalar negation flip the sign of the upper half:
+    // Src = -CurrUpper
+    // Curr = [CurrUpper, CurrLower]
+    // [CurrUpper, CurrLower] = fneg [OpUpper, OpLower]
+    //   Vector of 2: [-OpUpper, -OpLower]; Scalar: [-OpUpper, OpLower]
+    // Src = -(-OpUpper) = OpUpper
+    return SrcStatus::IS_UPPER_HALF;
+  case SrcStatus::IS_LOWER_HALF_NEG:
+    // Vector of 2:
+    // Src = -CurrLower
+    // Curr = [CurrUpper, CurrLower]
+    // [CurrUpper, CurrLower] = fneg [OpUpper, OpLower](2 x Type)
+    // [CurrUpper, CurrLower] = [-OpUpper, -OpLower](2 x Type)
+    // Src = -(-OpLower) = OpLower
+    //
+    // Scalar:
+    // Src = -CurrLower
+    // Curr = [CurrUpper, CurrLower]
+    // [CurrUpper, CurrLower] = fneg [OpUpper, OpLower](Type)
+    // [CurrUpper, CurrLower] = [-OpUpper, OpLower](Type)
+    // Src = -OpLower
+    return IsVector ? SrcStatus::IS_LOWER_HALF : SrcStatus::IS_LOWER_HALF_NEG;
+  case SrcStatus::INVALID:
+    // An invalid status is not expected here. It is listed so the switch
+    // covers every enumerator (-Wswitch) and still reaches the unreachable.
+    break;
+  }
+  llvm_unreachable("unexpected SrcStatus");
+}
+
+/// Takes one step up the def chain from \p Curr.
+///
+/// Curr.first is an operand whose value relates to the original source as
+/// described by Curr.second. The instruction that defines that value is
+/// inspected: the operand's parent if Curr.first is itself a def, otherwise
+/// the vreg's def found through \p MRI. If it is one of the instructions
+/// handled below, the result is its operand 1 paired with the source's status
+/// relative to that operand. Returns std::nullopt when the walk cannot
+/// continue.
+static std::optional<std::pair<const MachineOperand *, SrcStatus>>
+calcNextStatus(std::pair<const MachineOperand *, SrcStatus> Curr,
+               const MachineRegisterInfo &MRI) {
+  using NextStep = std::optional<std::pair<const MachineOperand *, SrcStatus>>;
+
+  const MachineOperand *CurOp = Curr.first;
+  const SrcStatus CurStat = Curr.second;
+  if (!CurOp->isReg())
+    return std::nullopt;
+
+  const MachineInstr *MI = CurOp->isDef() ? CurOp->getParent()
+                                          : MRI.getVRegDef(CurOp->getReg());
+  if (!MI)
+    return std::nullopt;
+
+  // Every transition continues from operand 1 of MI. It is only fetched once
+  // MI is known to be one of the handled instructions.
+  auto StepTo = [&](SrcStatus Next) -> NextStep {
+    return retOpStat(&MI->getOperand(1), Next, Curr);
+  };
+  auto StepIf = [&](bool Matches, SrcStatus Next) -> NextStep {
+    if (!Matches)
+      return std::nullopt;
+    return StepTo(Next);
+  };
+
+  // Instructions that pass the status through unchanged, or negate it.
+  switch (MI->getOpcode()) {
+  case AMDGPU::G_BITCAST:
+  case AMDGPU::G_CONSTANT:
+  case AMDGPU::G_FCONSTANT:
+  case AMDGPU::COPY:
+    return StepTo(CurStat);
+  case AMDGPU::G_FNEG:
+    return StepTo(getNegStatus(CurOp, CurStat, MRI));
+  default:
+    break;
+  }
+
+  // Otherwise only a half-width trunc/shift can be looked through, and which
+  // one depends on the current status.
+  switch (CurStat) {
+  case SrcStatus::IS_SAME:
+    // Curr = trunc Op is the lower half of Op.
+    return StepIf(isTruncHalf(MI, MRI), SrcStatus::IS_LOWER_HALF);
+  case SrcStatus::IS_HI_NEG:
+    // [SrcHi, SrcLo]   = [-CurrHi, CurrLo]
+    // [CurrHi, CurrLo] = trunc [OpUpper, OpLower] = OpLower
+    //                  = [OpLowerHi, OpLowerLo]
+    // Src = [SrcHi, SrcLo] = [-CurrHi, CurrLo]
+    //     = [-OpLowerHi, OpLowerLo]
+    //     = -OpLower
+    return StepIf(isTruncHalf(MI, MRI), SrcStatus::IS_LOWER_HALF_NEG);
+  case SrcStatus::IS_UPPER_HALF:
+    // Curr = Op << half: the upper half of Curr is the lower half of Op.
+    return StepIf(isShlHalf(MI, MRI), SrcStatus::IS_LOWER_HALF);
+  case SrcStatus::IS_LOWER_HALF:
+    // Curr = Op >> half: the lower half of Curr is the upper half of Op.
+    return StepIf(isLshrHalf(MI, MRI), SrcStatus::IS_UPPER_HALF);
+  case SrcStatus::IS_UPPER_HALF_NEG:
+    // Same as IS_UPPER_HALF, carrying the negation along.
+    return StepIf(isShlHalf(MI, MRI), SrcStatus::IS_LOWER_HALF_NEG);
+  case SrcStatus::IS_LOWER_HALF_NEG:
+    // Same as IS_LOWER_HALF, carrying the negation along.
+    return StepIf(isLshrHalf(MI, MRI), SrcStatus::IS_UPPER_HALF_NEG);
+  default:
+    break;
+  }
+  return std::nullopt;
+}
+
+class statOptions {
----------------
rovka wrote:

Can you find a different name for this? `stat` suggests statistics, but that's not what this is.

https://github.com/llvm/llvm-project/pull/130234


More information about the llvm-commits mailing list