[clang] [llvm] Add support for flag output operand "=@cc" for SystemZ. (PR #125970)
Ulrich Weigand via llvm-commits
llvm-commits at lists.llvm.org
Wed Aug 20 06:34:04 PDT 2025
================
@@ -8701,95 +8734,341 @@ SDValue SystemZTargetLowering::combineSETCC(
return SDValue();
}
-static bool combineCCMask(SDValue &CCReg, int &CCValid, int &CCMask) {
+/// Search the operand tree of \p Val for the value whose condition code
+/// \p Val was computed from.
+///
+/// \returns the CC-producing value paired with its CC-valid mask, or an empty
+/// SDValue paired with SystemZ::CCMASK_NONE when no CC producer is found.
+static std::pair<SDValue, int> findCCUse(const SDValue &Val) {
+  auto *N = Val.getNode();
+  if (!N)
+    return std::make_pair(SDValue(), SystemZ::CCMASK_NONE);
+  switch (N->getOpcode()) {
+  default:
+    return std::make_pair(SDValue(), SystemZ::CCMASK_NONE);
+  case SystemZISD::IPM: {
+    // IPM's operand is the CC producer itself.  Only SystemZISD opcodes are
+    // compared here: SDNode::getOpcode() is never a MachineInstr opcode such
+    // as SystemZ::CLST, so testing against one would match an unrelated node.
+    SDValue CCProducer = N->getOperand(0);
+    unsigned ProducerOpc = CCProducer.getOpcode();
+    if (ProducerOpc == SystemZISD::CLC || ProducerOpc == SystemZISD::STRCMP)
+      return std::make_pair(CCProducer, SystemZ::CCMASK_ICMP);
+    return std::make_pair(CCProducer, SystemZ::CCMASK_ANY);
+  }
+  case ISD::SHL:
+  case ISD::SRA:
+  case ISD::SRL:
+    // Only the shifted value (operand 0) is searched.
+    return findCCUse(N->getOperand(0));
+  case SystemZISD::SELECT_CCMASK: {
+    // Operand 2 is the CC-valid mask and operand 4 the CC input of the select.
+    SDValue Op4CCReg = N->getOperand(4);
+    auto *Op4CCNode = Op4CCReg.getNode();
+    auto *CCValid = dyn_cast<ConstantSDNode>(N->getOperand(2));
+    if (!CCValid || !Op4CCNode)
+      return std::make_pair(SDValue(), SystemZ::CCMASK_NONE);
+    int CCValidVal = CCValid->getZExtValue();
+    // When the select tests the result of an ICMP / TM, first look through
+    // the value being compared for an underlying CC producer.
+    if (Op4CCNode->getOpcode() == SystemZISD::ICMP ||
+        Op4CCNode->getOpcode() == SystemZISD::TM) {
+      auto [OpCC, OpCCValid] = findCCUse(Op4CCNode->getOperand(0));
+      if (OpCC != SDValue())
+        return std::make_pair(OpCC, OpCCValid);
+    }
+    // Otherwise search the CC input itself, falling back to using it directly
+    // together with the select's own CC-valid mask.
+    auto [OpCC, OpCCValid] = findCCUse(Op4CCReg);
+    return OpCC != SDValue() ? std::make_pair(OpCC, OpCCValid)
+                             : std::make_pair(Op4CCReg, CCValidVal);
+  }
+  case ISD::ADD:
+  case ISD::AND:
+  case ISD::OR:
+  case ISD::XOR: {
+    // Either operand may carry the CC-derived value; prefer the first.
+    auto [Op0CC, Op0CCValid] = findCCUse(N->getOperand(0));
+    if (Op0CC != SDValue())
+      return std::make_pair(Op0CC, Op0CCValid);
+    return findCCUse(N->getOperand(1));
+  }
+  }
+}
+
+/// Evaluate \p Val under the assumption that the condition code produced by
+/// \p CC is 0, 1, 2 and 3 in turn.
+///
+/// \param Val the value to evaluate.
+/// \param CC  the CC-producing value that \p Val is assumed to depend on.
+/// \param DCI combiner state; new nodes are created in DCI.DAG.
+/// \returns four values indexed by the assumed CC value, or an empty vector
+/// when \p Val cannot be expressed in terms of \p CC.
+SmallVector<SDValue, 4>
+SystemZTargetLowering::simplifyAssumingCCVal(SDValue &Val, SDValue &CC,
+                                             DAGCombinerInfo &DCI) const {
+  // Guard against an empty value before inspecting the node in any way.
+  auto *N = Val.getNode();
+  if (!N)
+    return {};
+  // A constant is the same for every CC value.
+  if (isa<ConstantSDNode>(N))
+    return {Val, Val, Val, Val};
+  auto *CCNode = CC.getNode();
+  if (!CCNode)
+    return {};
+  SelectionDAG &DAG = DCI.DAG;
+  SDLoc DL(N);
+  const unsigned Opcode = N->getOpcode();
+
+  if (Opcode == SystemZISD::IPM) {
+    // Only an IPM that reads the CC being reasoned about may be replaced; an
+    // IPM of some other CC producer says nothing about this one.
+    if (N->getOperand(0).getNode() != CCNode)
+      return {};
+    // For each assumed CC value the IPM result is that value shifted left by
+    // SystemZ::IPM_CC.
+    SmallVector<SDValue, 4> ShiftedCCVals;
+    for (int CCVal : {0, 1, 2, 3})
+      ShiftedCCVals.emplace_back(
+          DAG.getNode(ISD::SHL, DL, MVT::i32,
+                      DAG.getConstant(CCVal, DL, MVT::i32),
+                      DAG.getConstant(SystemZ::IPM_CC, DL, MVT::i32)));
+    return ShiftedCCVals;
+  }
+
+  if (Opcode == ISD::SRL) {
+    SDValue Op0 = N->getOperand(0);
+    SDValue Count = N->getOperand(1);
+    if (!isa<ConstantSDNode>(Count))
+      return {};
+    const auto Op0Vals = simplifyAssumingCCVal(Op0, CC, DCI);
+    if (Op0Vals.empty())
+      return {};
+    // Re-apply the same shift, in the original node's type, to each value.
+    SmallVector<SDValue, 4> ShiftedVals;
+    for (const SDValue &Op0Val : Op0Vals)
+      ShiftedVals.emplace_back(
+          DAG.getNode(ISD::SRL, DL, Val.getValueType(), Op0Val, Count));
+    return ShiftedVals;
+  }
+
+  if (Opcode == ISD::SRA) {
+    // Only the (SRA (SHL X, 30 - IPM_CC), 30) pair is handled, with the two
+    // shifts kept together.
+    auto *SRACount = dyn_cast<ConstantSDNode>(N->getOperand(1));
+    if (!SRACount || SRACount->getZExtValue() != 30)
+      return {};
+    auto *SHL = N->getOperand(0).getNode();
+    if (SHL->getOpcode() != ISD::SHL)
+      return {};
+    auto *SHLCount = dyn_cast<ConstantSDNode>(SHL->getOperand(1));
+    if (!SHLCount || SHLCount->getZExtValue() != 30 - SystemZ::IPM_CC)
+      return {};
+    // Avoid introducing CC spills (because SRA would clobber CC).
+    if (!N->hasOneUse())
+      return {};
+    SDValue IPM = SHL->getOperand(0);
+    const auto IPMVals = simplifyAssumingCCVal(IPM, CC, DCI);
+    if (IPMVals.empty())
+      return {};
+    // Rebuild SHL followed by SRA on each value, reusing the original shift
+    // amounts and result types.
+    SmallVector<SDValue, 4> ShiftedVals;
+    for (const SDValue &IPMVal : IPMVals) {
+      SDValue SHLVal = DAG.getNode(ISD::SHL, DL, SHL->getValueType(0), IPMVal,
+                                   SHL->getOperand(1));
+      ShiftedVals.emplace_back(DAG.getNode(ISD::SRA, DL, Val.getValueType(),
+                                           SHLVal, N->getOperand(1)));
+    }
+    return ShiftedVals;
+  }
+
+  if (Opcode == SystemZISD::SELECT_CCMASK) {
+    SDValue TrueVal = N->getOperand(0), FalseVal = N->getOperand(1);
+    auto *CCValid = dyn_cast<ConstantSDNode>(N->getOperand(2));
+    auto *CCMask = dyn_cast<ConstantSDNode>(N->getOperand(3));
+    if (!TrueVal.getNode() || !FalseVal.getNode() || !CCValid || !CCMask)
+      return {};
+
+    int CCValidVal = CCValid->getZExtValue();
+    int CCMaskVal = CCMask->getZExtValue();
+    const auto TrueVals = simplifyAssumingCCVal(TrueVal, CC, DCI);
+    const auto FalseVals = simplifyAssumingCCVal(FalseVal, CC, DCI);
+    if (TrueVals.empty() || FalseVals.empty())
+      return {};
+    // The select must test CC itself; if it tests something else, try to
+    // rewrite its CC input (and masks) with combineCCMask first.
+    SDValue Op4CCReg = N->getOperand(4);
+    if (Op4CCReg.getNode() && Op4CCReg.getNode() != CCNode)
+      combineCCMask(Op4CCReg, CCValidVal, CCMaskVal, DCI);
+    if (Op4CCReg.getNode() != CCNode)
+      return {};
+    // CC value V corresponds to mask bit (3 - V); take the true value when
+    // that bit is set in both the valid mask and the select mask.
+    SmallVector<SDValue, 4> MergedVals;
+    for (int CCVal : {0, 1, 2, 3}) {
+      const int Bit = 1 << (3 - CCVal);
+      const bool Selected = (CCMaskVal & Bit) != 0 && (CCValidVal & Bit) != 0;
+      MergedVals.emplace_back(Selected ? TrueVals[CCVal] : FalseVals[CCVal]);
+    }
+    return MergedVals;
+  }
+
+  // Remaining supported case: a binary operation whose operands both simplify.
+  if (Opcode != ISD::ADD && Opcode != ISD::AND && Opcode != ISD::OR &&
+      Opcode != ISD::XOR)
+    return {};
+  SDValue Op0 = N->getOperand(0), Op1 = N->getOperand(1);
+  const auto Op0Vals = simplifyAssumingCCVal(Op0, CC, DCI);
+  const auto Op1Vals = simplifyAssumingCCVal(Op1, CC, DCI);
+  if (Op0Vals.empty() || Op1Vals.empty())
+    return {};
+  SmallVector<SDValue, 4> BinaryOpVals;
+  for (int CCVal : {0, 1, 2, 3})
+    BinaryOpVals.emplace_back(DAG.getNode(Opcode, DL, Val.getValueType(),
+                                          Op0Vals[CCVal], Op1Vals[CCVal]));
+  return BinaryOpVals;
+}
+
+bool SystemZTargetLowering::combineCCMask(SDValue &CCReg, int &CCValid,
+ int &CCMask,
+ DAGCombinerInfo &DCI) const {
// We have a SELECT_CCMASK or BR_CCMASK comparing the condition code
// set by the CCReg instruction using the CCValid / CCMask masks,
- // If the CCReg instruction is itself a ICMP testing the condition
+ // If the CCReg instruction is itself a ICMP / TM testing the condition
// code set by some other instruction, see whether we can directly
// use that condition code.
-
- // Verify that we have an ICMP against some constant.
- if (CCValid != SystemZ::CCMASK_ICMP)
- return false;
- auto *ICmp = CCReg.getNode();
- if (ICmp->getOpcode() != SystemZISD::ICMP)
- return false;
- auto *CompareLHS = ICmp->getOperand(0).getNode();
- auto *CompareRHS = dyn_cast<ConstantSDNode>(ICmp->getOperand(1));
- if (!CompareRHS)
+ auto *CCNode = CCReg.getNode();
+ if (!CCNode)
return false;
+ const auto getConstFromConstSDVals = [](const SmallVector<SDValue, 4> &Vals) {
+ SmallVector<int, 4> CCVals;
+ for (const auto &Val : Vals)
+ if (auto *ConstNode = dyn_cast<ConstantSDNode>(Val.getNode()))
+ CCVals.emplace_back(ConstNode->getZExtValue());
+ else
+ return SmallVector<int, 4>();
+ return CCVals;
+ };
+ const auto getMSBPosSet = [](unsigned int Mask) {
+ int NumBits = std::numeric_limits<unsigned int>::digits;
+ int count = 0;
+ // Keep target search space to the left.
+ while (NumBits > 0) {
+ NumBits /= 2;
+ // Upper half zeros.
+ if (!(Mask >> NumBits)) {
+ count += NumBits;
+ // Search lower half.
+ Mask <<= NumBits;
+ }
+ }
+ return count;
+ };
- // Optimize the case where CompareLHS is a SELECT_CCMASK.
- if (CompareLHS->getOpcode() == SystemZISD::SELECT_CCMASK) {
- // Verify that we have an appropriate mask for a EQ or NE comparison.
- bool Invert = false;
- if (CCMask == SystemZ::CCMASK_CMP_NE)
- Invert = !Invert;
- else if (CCMask != SystemZ::CCMASK_CMP_EQ)
+ if (CCNode->getOpcode() == SystemZISD::TM) {
+ if (CCValid != SystemZ::CCMASK_TM)
return false;
-
- // Verify that the ICMP compares against one of select values.
- auto *TrueVal = dyn_cast<ConstantSDNode>(CompareLHS->getOperand(0));
- if (!TrueVal)
+ const auto emulateTMCCMask = [&](int CCVal, int Mask) {
+ if (!Mask)
+ return std::numeric_limits<unsigned int>::digits;
+ int Result = CCVal & Mask;
+ bool AllOnes = Result == Mask;
+ bool AllZeros = Result == 0;
+ bool MixedZerosOnes = (!AllOnes && !AllZeros);
+ int MSBPos = getMSBPosSet(static_cast<unsigned int>(Mask));
+ bool IsLeftMostBitSet = (Result & (1 << MSBPos)) != 0;
+ return AllOnes ? 3
+ : AllZeros ? 0
+ : (MixedZerosOnes && IsLeftMostBitSet) ? 2
+ : 1;
+ };
+ SDValue Op0 = CCNode->getOperand(0);
+ SDValue Op1 = CCNode->getOperand(1);
+ auto [Op0CC, Op0CCValid] = findCCUse(Op0);
+ if (Op0CC == SDValue())
return false;
- auto *FalseVal = dyn_cast<ConstantSDNode>(CompareLHS->getOperand(1));
- if (!FalseVal)
+ const auto &&Op0SDVals = simplifyAssumingCCVal(Op0, Op0CC, DCI);
+ const auto &&Op1SDVals = simplifyAssumingCCVal(Op1, Op0CC, DCI);
+ if (Op0SDVals.empty() || Op1SDVals.empty())
return false;
- if (CompareRHS->getAPIntValue() == FalseVal->getAPIntValue())
- Invert = !Invert;
- else if (CompareRHS->getAPIntValue() != TrueVal->getAPIntValue())
+ auto &&Op0CCVals = getConstFromConstSDVals(Op0SDVals);
+ const auto &&Op1CCVals = getConstFromConstSDVals(Op1SDVals);
+ if (Op0CCVals.empty() || Op1CCVals.empty())
return false;
-
- // Compute the effective CC mask for the new branch or select.
- auto *NewCCValid = dyn_cast<ConstantSDNode>(CompareLHS->getOperand(2));
- auto *NewCCMask = dyn_cast<ConstantSDNode>(CompareLHS->getOperand(3));
- if (!NewCCValid || !NewCCMask)
+ std::transform(Op0CCVals.begin(), Op0CCVals.end(), Op1CCVals.begin(),
+ Op0CCVals.begin(), emulateTMCCMask);
+ if (std::any_of(Op0CCVals.begin(), Op0CCVals.end(),
+ [](const auto &CC) { return CC < 0 || CC > 3; }))
return false;
- CCValid = NewCCValid->getZExtValue();
- CCMask = NewCCMask->getZExtValue();
- if (Invert)
- CCMask ^= CCValid;
-
- // Return the updated CCReg link.
- CCReg = CompareLHS->getOperand(4);
+ int NewCCMask = 0;
+ for (auto CC : Op0CCVals) {
+ NewCCMask <<= 1;
+ NewCCMask |= (CCMask & (1 << (3 - CC))) != 0;
+ }
+ CCReg = Op0CC;
+ CCMask = NewCCMask;
return true;
}
+ if (CCNode->getOpcode() != SystemZISD::ICMP ||
+ CCValid != SystemZ::CCMASK_ICMP)
+ return false;
- // Optimize the case where CompareRHS is (SRA (SHL (IPM))).
- if (CompareLHS->getOpcode() == ISD::SRA) {
- auto *SRACount = dyn_cast<ConstantSDNode>(CompareLHS->getOperand(1));
- if (!SRACount || SRACount->getZExtValue() != 30)
- return false;
- auto *SHL = CompareLHS->getOperand(0).getNode();
- if (SHL->getOpcode() != ISD::SHL)
- return false;
- auto *SHLCount = dyn_cast<ConstantSDNode>(SHL->getOperand(1));
- if (!SHLCount || SHLCount->getZExtValue() != 30 - SystemZ::IPM_CC)
- return false;
- auto *IPM = SHL->getOperand(0).getNode();
- if (IPM->getOpcode() != SystemZISD::IPM)
- return false;
-
- // Avoid introducing CC spills (because SRA would clobber CC).
- if (!CompareLHS->hasOneUse())
+ SDValue CmpOp0 = CCNode->getOperand(0);
+ SDValue CmpOp1 = CCNode->getOperand(1);
+ SDValue CmpOp2 = CCNode->getOperand(2);
+ auto [Op0CC, Op0CCValid] = findCCUse(CmpOp0);
+ if (Op0CC != SDValue()) {
+ const auto &&Op0SDVals = simplifyAssumingCCVal(CmpOp0, Op0CC, DCI);
+ const auto &&Op1SDVals = simplifyAssumingCCVal(CmpOp1, Op0CC, DCI);
+ if (Op0SDVals.empty() || Op1SDVals.empty())
return false;
- // Verify that the ICMP compares against zero.
- if (CompareRHS->getZExtValue() != 0)
+ auto &&Op0CCVals = getConstFromConstSDVals(Op0SDVals);
+ const auto &&Op1CCVals = getConstFromConstSDVals(Op1SDVals);
----------------
uweigand wrote:
Again, this has the `int` truncation problem: this code should also operate on the full `APInt` values.
https://github.com/llvm/llvm-project/pull/125970
More information about the llvm-commits
mailing list