[llvm] [X86] matchAddressRecursively - move ZERO_EXTEND patterns into matchIndexRecursively (PR #85081)
via llvm-commits
llvm-commits at lists.llvm.org
Wed Mar 13 20:21:14 PDT 2024
================
@@ -2346,54 +2352,137 @@ SDValue X86DAGToDAGISel::matchIndexRecursively(SDValue N,
}
}
- // index: zext(add_nuw(x,c)) -> index: zext(x), disp + zext(c)
- // index: zext(addlike(x,c)) -> index: zext(x), disp + zext(c)
- // TODO: call matchIndexRecursively(AddSrc) if we won't corrupt sext?
- if (Opc == ISD::ZERO_EXTEND && !VT.isVector() && N.hasOneUse()) {
+ if (Opc == ISD::ZERO_EXTEND) {
+ // index: zext(add_nuw(x,c)) -> index: zext(x), disp + zext(c)
+ // index: zext(addlike(x,c)) -> index: zext(x), disp + zext(c)
SDValue Src = N.getOperand(0);
- unsigned SrcOpc = Src.getOpcode();
- if (((SrcOpc == ISD::ADD && Src->getFlags().hasNoUnsignedWrap()) ||
- CurDAG->isADDLike(Src)) &&
- Src.hasOneUse()) {
- if (CurDAG->isBaseWithConstantOffset(Src)) {
- SDValue AddSrc = Src.getOperand(0);
- uint64_t Offset = Src.getConstantOperandVal(1);
- if (!foldOffsetIntoAddress(Offset * AM.Scale, AM)) {
- SDLoc DL(N);
- SDValue Res;
- // If we're also scaling, see if we can use that as well.
- if (AddSrc.getOpcode() == ISD::SHL &&
- isa<ConstantSDNode>(AddSrc.getOperand(1))) {
- SDValue ShVal = AddSrc.getOperand(0);
- uint64_t ShAmt = AddSrc.getConstantOperandVal(1);
- APInt HiBits =
- APInt::getHighBitsSet(AddSrc.getScalarValueSizeInBits(), ShAmt);
- uint64_t ScaleAmt = 1ULL << ShAmt;
- if ((AM.Scale * ScaleAmt) <= 8 &&
- (AddSrc->getFlags().hasNoUnsignedWrap() ||
- CurDAG->MaskedValueIsZero(ShVal, HiBits))) {
- AM.Scale *= ScaleAmt;
- SDValue ExtShVal = CurDAG->getNode(Opc, DL, VT, ShVal);
- SDValue ExtShift = CurDAG->getNode(ISD::SHL, DL, VT, ExtShVal,
- AddSrc.getOperand(1));
- insertDAGNode(*CurDAG, N, ExtShVal);
- insertDAGNode(*CurDAG, N, ExtShift);
- AddSrc = ExtShift;
- Res = ExtShVal;
+ if (!VT.isVector() && N.hasOneUse()) {
+ unsigned SrcOpc = Src.getOpcode();
+ if (((SrcOpc == ISD::ADD && Src->getFlags().hasNoUnsignedWrap()) ||
+ CurDAG->isADDLike(Src)) &&
+ Src.hasOneUse()) {
+ if (CurDAG->isBaseWithConstantOffset(Src)) {
+ SDValue AddSrc = Src.getOperand(0);
+ uint64_t Offset = Src.getConstantOperandVal(1);
+ if (!foldOffsetIntoAddress(Offset * AM.Scale, AM)) {
+ SDLoc DL(N);
+ SDValue Res;
+ // If we're also scaling, see if we can use that as well.
+ if (AddSrc.getOpcode() == ISD::SHL &&
+ isa<ConstantSDNode>(AddSrc.getOperand(1))) {
+ SDValue ShVal = AddSrc.getOperand(0);
+ uint64_t ShAmt = AddSrc.getConstantOperandVal(1);
+ APInt HiBits = APInt::getHighBitsSet(
+ AddSrc.getScalarValueSizeInBits(), ShAmt);
+ uint64_t ScaleAmt = 1ULL << ShAmt;
+ if ((AM.Scale * ScaleAmt) <= 8 &&
+ (AddSrc->getFlags().hasNoUnsignedWrap() ||
+ CurDAG->MaskedValueIsZero(ShVal, HiBits))) {
+ AM.Scale *= ScaleAmt;
+ SDValue ExtShVal = CurDAG->getNode(Opc, DL, VT, ShVal);
+ SDValue ExtShift = CurDAG->getNode(ISD::SHL, DL, VT, ExtShVal,
+ AddSrc.getOperand(1));
+ insertDAGNode(*CurDAG, N, ExtShVal);
+ insertDAGNode(*CurDAG, N, ExtShift);
+ AddSrc = ExtShift;
+ Res = adapter(ExtShVal);
+ }
}
+ SDValue ExtSrc = CurDAG->getNode(Opc, DL, VT, AddSrc);
+ SDValue ExtVal = CurDAG->getConstant(Offset, DL, VT);
+ SDValue ExtAdd = CurDAG->getNode(SrcOpc, DL, VT, ExtSrc, ExtVal);
+ insertDAGNode(*CurDAG, N, ExtSrc);
+ insertDAGNode(*CurDAG, N, ExtVal);
+ insertDAGNode(*CurDAG, N, ExtAdd);
+ CurDAG->ReplaceAllUsesWith(N, ExtAdd);
+ CurDAG->RemoveDeadNode(N.getNode());
+ // AM.IndexReg can be further picked
+ SDValue Zext = adapter(ExtSrc);
+ return Res ? Res : Zext;
}
- SDValue ExtSrc = CurDAG->getNode(Opc, DL, VT, AddSrc);
- SDValue ExtVal = CurDAG->getConstant(Offset, DL, VT);
- SDValue ExtAdd = CurDAG->getNode(SrcOpc, DL, VT, ExtSrc, ExtVal);
- insertDAGNode(*CurDAG, N, ExtSrc);
- insertDAGNode(*CurDAG, N, ExtVal);
- insertDAGNode(*CurDAG, N, ExtAdd);
- CurDAG->ReplaceAllUsesWith(N, ExtAdd);
- CurDAG->RemoveDeadNode(N.getNode());
- return Res ? Res : ExtSrc;
}
}
}
+
+ // Peek through mask: zext(and(shl(x,c1),c2))
+ APInt Mask = APInt::getAllOnes(Src.getScalarValueSizeInBits());
+ if (Src.getOpcode() == ISD::AND && Src.hasOneUse())
+ if (auto *MaskC = dyn_cast<ConstantSDNode>(Src.getOperand(1))) {
+ Mask = MaskC->getAPIntValue();
+ Src = Src.getOperand(0);
+ }
+
+ if (Src.getOpcode() == ISD::SHL && Src.hasOneUse()) {
+ // Give up if the shift is not a valid scale factor [1,2,3].
+ SDValue ShlSrc = Src.getOperand(0);
+ SDValue ShlAmt = Src.getOperand(1);
+ auto *ShAmtC = dyn_cast<ConstantSDNode>(ShlAmt);
+ if (!ShAmtC)
+ return std::nullopt;
+ unsigned ShAmtV = ShAmtC->getZExtValue();
+ if (ShAmtV > 3 || (1 << ShAmtV) * AM.Scale > 8)
----------------
RicoAfoat wrote:
Yes, I think we can remove `AM.Scale != 1` in `matchAddressRecursively`. I relaxed the restriction here, but not for the last case, where we try to deal with the mask and shift.
https://github.com/llvm/llvm-project/pull/85081
More information about the llvm-commits
mailing list