[llvm] [AMDGPU] fix up vop3p gisel errors (PR #136262)
Matt Arsenault via llvm-commits
llvm-commits at lists.llvm.org
Mon Jun 16 07:01:14 PDT 2025
================
@@ -4318,60 +4318,586 @@ AMDGPUInstructionSelector::selectVOP3NoMods(MachineOperand &Root) const {
}};
}
-std::pair<Register, unsigned>
-AMDGPUInstructionSelector::selectVOP3PModsImpl(
- Register Src, const MachineRegisterInfo &MRI, bool IsDOT) const {
+enum class SrcStatus { // Relation of the original source (src) to the operand currently traced (op).
+ IS_SAME, // src = op.
+ IS_UPPER_HALF, // src = op_upper.
+ IS_LOWER_HALF, // src = op_lower.
+ IS_UPPER_HALF_NEG, // src = -op_upper.
+ // This means current op = [op_upper, op_lower] and src = -op_lower.
+ IS_LOWER_HALF_NEG,
+ IS_HI_NEG, // src = [-op_upper, op_lower].
+ // This means current op = [op_upper, op_lower] and src = [op_upper,
+ // -op_lower].
+ IS_LO_NEG,
+ IS_BOTH_NEG, // src = [-op_upper, -op_lower].
+ INVALID, // Returned by getNegStatus for a register that is neither a scalar nor a 2-element vector.
+ NEG_START = IS_UPPER_HALF_NEG, // Inclusive lower bound of the range searchOptions::checkOptions rejects when HasNeg is false.
+ NEG_END = IS_BOTH_NEG, // Inclusive upper bound of that range; both aliases depend on the enumerator order above.
+ HALF_START = IS_UPPER_HALF, // Inclusive lower bound of the range searchOptions::checkOptions rejects when HasOpsel is false.
+ HALF_END = IS_LOWER_HALF_NEG // Inclusive upper bound of that range; also order-dependent.
+};
+// Returns true when MI is a G_TRUNC whose result is exactly half as wide as
+// its input, e.g. `%dst:n = G_TRUNC %src:2n`.
+static bool isTruncHalf(const MachineInstr *MI,
+                        const MachineRegisterInfo &MRI) {
+  if (MI->getOpcode() != AMDGPU::G_TRUNC)
+    return false;
+
+  const Register Dst = MI->getOperand(0).getReg();
+  const Register Src = MI->getOperand(1).getReg();
+  const unsigned DstBits = MRI.getType(Dst).getSizeInBits();
+  const unsigned SrcBits = MRI.getType(Src).getSizeInBits();
+  return SrcBits == 2 * DstBits;
+}
+
+// Returns true when MI is a logical shift right by a constant equal to half
+// of the shifted operand's width, e.g. `%dst:2n = G_LSHR %src:2n, CONST(n)`.
+static bool isLshrHalf(const MachineInstr *MI, const MachineRegisterInfo &MRI) {
+  if (MI->getOpcode() != AMDGPU::G_LSHR)
+    return false;
+
+  Register Shifted;
+  std::optional<ValueAndVReg> Amount;
+  const bool HasConstAmount =
+      mi_match(MI->getOperand(0).getReg(), MRI,
+               m_GLShr(m_Reg(Shifted), m_GCst(Amount)));
+  if (!HasConstAmount)
+    return false;
+
+  const unsigned SrcBits =
+      MRI.getType(MI->getOperand(1).getReg()).getSizeInBits();
+  const unsigned ShiftBits = Amount->Value.getZExtValue();
+  return SrcBits == 2 * ShiftBits;
+}
+
+// Returns true when MI is a shift left by a constant equal to half of the
+// shifted operand's width, e.g. `%dst:2n = G_SHL %src:2n, CONST(n)`.
+static bool isShlHalf(const MachineInstr *MI, const MachineRegisterInfo &MRI) {
+  if (MI->getOpcode() != AMDGPU::G_SHL)
+    return false;
+
+  Register Shifted;
+  std::optional<ValueAndVReg> Amount;
+  const bool HasConstAmount =
+      mi_match(MI->getOperand(0).getReg(), MRI,
+               m_GShl(m_Reg(Shifted), m_GCst(Amount)));
+  if (!HasConstAmount)
+    return false;
+
+  const unsigned SrcBits =
+      MRI.getType(MI->getOperand(1).getReg()).getSizeInBits();
+  const unsigned ShiftBits = Amount->Value.getZExtValue();
+  return SrcBits == 2 * ShiftBits;
+}
+
+// Returns true when MI is a G_UNMERGE_VALUES with exactly two results and one
+// source, i.e. `%reg0:n, %reg1:n = G_UNMERGE_VALUES %reg2:2n`. Only the
+// operand layout is inspected; MRI is not consulted.
+static bool isUnmergeHalf(const MachineInstr *MI,
+                          const MachineRegisterInfo &MRI) {
+  if (MI->getOpcode() != AMDGPU::G_UNMERGE_VALUES)
+    return false;
+  if (MI->getNumOperands() != 3)
+    return false;
+
+  const bool TwoDefs = MI->getOperand(0).isDef() && MI->getOperand(1).isDef();
+  return TwoDefs && !MI->getOperand(2).isDef();
+}
+
+enum class TypeClass { VECTOR_OF_TWO, SCALAR, NONE_OF_LISTED }; // Coarse shape of an LLT: 2-element vector, scalar, or anything else.
+
+// Classifies the LLT of Reg as a scalar, a vector of exactly two elements, or
+// neither of those.
+static TypeClass isVectorOfTwoOrScalar(Register Reg,
+                                       const MachineRegisterInfo &MRI) {
+  const LLT Ty = MRI.getType(Reg);
+  if (Ty.isScalar())
+    return TypeClass::SCALAR;
+
+  const bool IsPair = Ty.isVector() && Ty.getNumElements() == 2;
+  return IsPair ? TypeClass::VECTOR_OF_TWO : TypeClass::NONE_OF_LISTED;
+}
+
+// Computes the status of the original source relative to the *input* of an
+// fneg, given its status S relative to the fneg's *result* Reg.
+//
+// Notation: the fneg result is Curr = [CurrHi, CurrLo] and its input is
+// Op = [OpHi, OpLo].
+//  - Reg is a vector of two: the fneg negates both elements, so
+//    Curr = [-OpHi, -OpLo].
+//  - Reg is a scalar: the fneg only changes the upper half, so
+//    Curr = [-OpHi, OpLo].
+//
+// Substituting into what S says about the source gives:
+//
+//   S                   vector of two         scalar
+//   IS_SAME             IS_BOTH_NEG           IS_HI_NEG
+//   IS_HI_NEG           IS_LO_NEG             IS_SAME
+//   IS_LO_NEG           IS_HI_NEG             IS_BOTH_NEG
+//   IS_BOTH_NEG         IS_SAME               IS_LO_NEG
+//   IS_UPPER_HALF       IS_UPPER_HALF_NEG     IS_UPPER_HALF_NEG
+//   IS_LOWER_HALF       IS_LOWER_HALF_NEG     IS_LOWER_HALF
+//   IS_UPPER_HALF_NEG   IS_UPPER_HALF         IS_UPPER_HALF
+//   IS_LOWER_HALF_NEG   IS_LOWER_HALF         IS_LOWER_HALF_NEG
+//
+// Returns SrcStatus::INVALID when Reg is neither a scalar nor a vector of two.
+// Any other value of S (e.g. INVALID) is a caller bug and hits
+// llvm_unreachable.
+static SrcStatus getNegStatus(Register Reg, SrcStatus S,
+                              const MachineRegisterInfo &MRI) {
+  const TypeClass NegType = isVectorOfTwoOrScalar(Reg, MRI);
+  if (NegType != TypeClass::VECTOR_OF_TWO && NegType != TypeClass::SCALAR)
+    return SrcStatus::INVALID;
+
+  // From here on NegType is exactly one of the two handled classes, so every
+  // case returns directly and no path can fall off the end of the function.
+  const bool IsVec = NegType == TypeClass::VECTOR_OF_TWO;
+
+  switch (S) {
+  case SrcStatus::IS_SAME:
+    return IsVec ? SrcStatus::IS_BOTH_NEG : SrcStatus::IS_HI_NEG;
+  case SrcStatus::IS_HI_NEG:
+    return IsVec ? SrcStatus::IS_LO_NEG : SrcStatus::IS_SAME;
+  case SrcStatus::IS_LO_NEG:
+    return IsVec ? SrcStatus::IS_HI_NEG : SrcStatus::IS_BOTH_NEG;
+  case SrcStatus::IS_BOTH_NEG:
+    return IsVec ? SrcStatus::IS_SAME : SrcStatus::IS_LO_NEG;
+  case SrcStatus::IS_UPPER_HALF:
+    // The upper half is negated for both the vector and the scalar form.
+    return SrcStatus::IS_UPPER_HALF_NEG;
+  case SrcStatus::IS_LOWER_HALF:
+    // A scalar fneg leaves the lower half untouched.
+    return IsVec ? SrcStatus::IS_LOWER_HALF_NEG : SrcStatus::IS_LOWER_HALF;
+  case SrcStatus::IS_UPPER_HALF_NEG:
+    // -(-OpUpper) = OpUpper for both the vector and the scalar form.
+    return SrcStatus::IS_UPPER_HALF;
+  case SrcStatus::IS_LOWER_HALF_NEG:
+    // A scalar fneg leaves the lower half untouched, so the negation stays.
+    return IsVec ? SrcStatus::IS_LOWER_HALF : SrcStatus::IS_LOWER_HALF_NEG;
+  default:
+    break;
+  }
+  llvm_unreachable("unexpected SrcStatus");
+}
+
+// Walks one step up the def chain of Curr.first. On success, returns the input
+// register of the defining instruction together with the status the original
+// source has relative to that input. Returns std::nullopt when the defining
+// instruction is not a recognized pattern for the current status, when a
+// COPY/G_BITCAST reads a physical register, or when getNegStatus reports
+// INVALID for a G_FNEG.
+static std::optional<std::pair<Register, SrcStatus>>
+calcNextStatus(std::pair<Register, SrcStatus> Curr,
+               const MachineRegisterInfo &MRI) {
+  using RegStat = std::pair<Register, SrcStatus>;
+  const MachineInstr *DefMI = MRI.getVRegDef(Curr.first);
+
+  // Pairs operand OpIdx of the defining instruction with a new status.
+  auto StepTo = [DefMI](unsigned OpIdx,
+                        SrcStatus Stat) -> std::optional<RegStat> {
+    return RegStat(DefMI->getOperand(OpIdx).getReg(), Stat);
+  };
+  // For a two-result unmerge: result 0 maps to Lower, the other result maps
+  // to Upper, both relative to the unmerge source (operand 2).
+  auto StepThroughUnmerge = [&](SrcStatus Lower, SrcStatus Upper) {
+    const bool IsFirstResult = Curr.first == DefMI->getOperand(0).getReg();
+    return StepTo(2, IsFirstResult ? Lower : Upper);
+  };
+
+  // Opcodes that are handled the same way for every status.
+  const unsigned Opc = DefMI->getOpcode();
+  if (Opc == AMDGPU::G_BITCAST || Opc == AMDGPU::COPY) {
+    if (DefMI->getOperand(1).getReg().isPhysical())
+      return std::nullopt;
+    return StepTo(1, Curr.second);
+  }
+  if (Opc == AMDGPU::G_FNEG) {
+    const SrcStatus Negated = getNegStatus(Curr.first, Curr.second, MRI);
+    if (Negated == SrcStatus::INVALID)
+      return std::nullopt;
+    return StepTo(1, Negated);
+  }
+
+  // The remaining patterns depend on the current status.
+  switch (Curr.second) {
+  case SrcStatus::IS_SAME:
+    if (isTruncHalf(DefMI, MRI))
+      return StepTo(1, SrcStatus::IS_LOWER_HALF);
+    if (isUnmergeHalf(DefMI, MRI))
+      return StepThroughUnmerge(SrcStatus::IS_LOWER_HALF,
+                                SrcStatus::IS_UPPER_HALF);
+    break;
+  case SrcStatus::IS_HI_NEG:
+    // [SrcHi, SrcLo] = [-CurrHi, CurrLo]
+    // [CurrHi, CurrLo] = trunc [OpUpper, OpLower] = OpLower
+    //                  = [OpLowerHi, OpLowerLo]
+    // Src = [-OpLowerHi, OpLowerLo] = -OpLower
+    if (isTruncHalf(DefMI, MRI))
+      return StepTo(1, SrcStatus::IS_LOWER_HALF_NEG);
+    if (isUnmergeHalf(DefMI, MRI))
+      return StepThroughUnmerge(SrcStatus::IS_LOWER_HALF_NEG,
+                                SrcStatus::IS_UPPER_HALF_NEG);
+    break;
+  case SrcStatus::IS_UPPER_HALF:
+    // The upper half of (Op << half) is the lower half of Op.
+    if (isShlHalf(DefMI, MRI))
+      return StepTo(1, SrcStatus::IS_LOWER_HALF);
+    break;
+  case SrcStatus::IS_LOWER_HALF:
+    // The lower half of (Op >> half) is the upper half of Op.
+    if (isLshrHalf(DefMI, MRI))
+      return StepTo(1, SrcStatus::IS_UPPER_HALF);
+    break;
+  case SrcStatus::IS_UPPER_HALF_NEG:
+    if (isShlHalf(DefMI, MRI))
+      return StepTo(1, SrcStatus::IS_LOWER_HALF_NEG);
+    break;
+  case SrcStatus::IS_LOWER_HALF_NEG:
+    if (isLshrHalf(DefMI, MRI))
+      return StepTo(1, SrcStatus::IS_UPPER_HALF_NEG);
+    break;
+  default:
+    break;
+  }
+  return std::nullopt;
+}
+
+// Describes which source modifiers may be folded for the instruction that
+// defines a given register, and filters SrcStatus values accordingly.
+class searchOptions {
+private:
+  // Whether statuses in [NEG_START, NEG_END] are acceptable.
+  bool HasNeg = false;
+  // Whether statuses in [HALF_START, HALF_END] are acceptable.
+  // Assume all complex pattern of VOP3P has opsel.
+  bool HasOpsel = true;
+
+  // Inclusive range test on the enumerator order of SrcStatus.
+  static bool isInRange(SrcStatus Stat, SrcStatus First, SrcStatus Last) {
+    return Stat >= First && Stat <= Last;
+  }
+
+public:
+  // Derives the options from the instruction that defines Reg.
+  searchOptions(Register Reg, const MachineRegisterInfo &MRI) {
+    const MachineInstr *MI = MRI.getVRegDef(Reg);
+    const unsigned Opc = MI->getOpcode();
+
+    // G_INTRINSIC is itself a generic opcode, so it has to be tested before
+    // the generic-opcode range check; otherwise the intrinsic case could
+    // never be reached and every intrinsic would be treated as having neg.
+    if (Opc == TargetOpcode::G_INTRINSIC) {
+      const Intrinsic::ID IntrinsicID =
+          cast<GIntrinsic>(*MI).getIntrinsicID();
+      // Only float point intrinsic has neg & neg_hi bits.
+      HasNeg = IntrinsicID == Intrinsic::amdgcn_fdot2;
+    } else if (Opc < TargetOpcode::GENERIC_OP_END) {
+      // Keep same for generic op.
+      HasNeg = true;
+    }
+  }
+
+  // Returns false when Stat needs a modifier this instruction does not have.
+  bool checkOptions(SrcStatus Stat) const {
+    if (!HasNeg &&
+        isInRange(Stat, SrcStatus::NEG_START, SrcStatus::NEG_END))
+      return false;
+    if (!HasOpsel &&
+        isInRange(Stat, SrcStatus::HALF_START, SrcStatus::HALF_END))
+      return false;
+    return true;
+  }
+};
+
+// Repeatedly applies calcNextStatus starting from {Reg, IS_SAME} and collects
+// every visited (register, status) pair that SearchOptions accepts, in the
+// order visited. The walk ends when calcNextStatus yields no further step or
+// after the depth counter exceeds MaxDepth.
+static SmallVector<std::pair<Register, SrcStatus>>
+getSrcStats(Register Reg, const MachineRegisterInfo &MRI,
+            searchOptions SearchOptions, int MaxDepth = 6) {
+  SmallVector<std::pair<Register, SrcStatus>> Accepted;
+  auto Step = calcNextStatus({Reg, SrcStatus::IS_SAME}, MRI);
+
+  for (int Depth = 0; Depth <= MaxDepth && Step; ++Depth) {
+    if (SearchOptions.checkOptions(Step->second))
+      Accepted.push_back(*Step);
+    Step = calcNextStatus(*Step, MRI);
+  }
+
+  return Accepted;
+}
+
+static std::pair<Register, SrcStatus>
+getLastSameOrNeg(Register Reg, const MachineRegisterInfo &MRI,
+ searchOptions SearchOptions, int MaxDepth = 6) {
----------------
arsenm wrote:
I still think this whole thing is overcomplicating the problem. You shouldn't need to do an arbitrary-depth analysis. At most it's probably 3. Anything deeper is running into missed combine opportunities.
https://github.com/llvm/llvm-project/pull/136262
More information about the llvm-commits
mailing list