[llvm] [X86] matchAddressRecursively - move ZERO_EXTEND patterns into matchIndexRecursively (PR #85081)

via llvm-commits llvm-commits at lists.llvm.org
Wed Mar 13 20:21:14 PDT 2024


================
@@ -2346,54 +2352,137 @@ SDValue X86DAGToDAGISel::matchIndexRecursively(SDValue N,
     }
   }
 
-  // index: zext(add_nuw(x,c)) -> index: zext(x), disp + zext(c)
-  // index: zext(addlike(x,c)) -> index: zext(x), disp + zext(c)
-  // TODO: call matchIndexRecursively(AddSrc) if we won't corrupt sext?
-  if (Opc == ISD::ZERO_EXTEND && !VT.isVector() && N.hasOneUse()) {
+  if (Opc == ISD::ZERO_EXTEND) {
+    // index: zext(add_nuw(x,c)) -> index: zext(x), disp + zext(c)
+    // index: zext(addlike(x,c)) -> index: zext(x), disp + zext(c)
     SDValue Src = N.getOperand(0);
-    unsigned SrcOpc = Src.getOpcode();
-    if (((SrcOpc == ISD::ADD && Src->getFlags().hasNoUnsignedWrap()) ||
-         CurDAG->isADDLike(Src)) &&
-        Src.hasOneUse()) {
-      if (CurDAG->isBaseWithConstantOffset(Src)) {
-        SDValue AddSrc = Src.getOperand(0);
-        uint64_t Offset = Src.getConstantOperandVal(1);
-        if (!foldOffsetIntoAddress(Offset * AM.Scale, AM)) {
-          SDLoc DL(N);
-          SDValue Res;
-          // If we're also scaling, see if we can use that as well.
-          if (AddSrc.getOpcode() == ISD::SHL &&
-              isa<ConstantSDNode>(AddSrc.getOperand(1))) {
-            SDValue ShVal = AddSrc.getOperand(0);
-            uint64_t ShAmt = AddSrc.getConstantOperandVal(1);
-            APInt HiBits =
-                APInt::getHighBitsSet(AddSrc.getScalarValueSizeInBits(), ShAmt);
-            uint64_t ScaleAmt = 1ULL << ShAmt;
-            if ((AM.Scale * ScaleAmt) <= 8 &&
-                (AddSrc->getFlags().hasNoUnsignedWrap() ||
-                 CurDAG->MaskedValueIsZero(ShVal, HiBits))) {
-              AM.Scale *= ScaleAmt;
-              SDValue ExtShVal = CurDAG->getNode(Opc, DL, VT, ShVal);
-              SDValue ExtShift = CurDAG->getNode(ISD::SHL, DL, VT, ExtShVal,
-                                                 AddSrc.getOperand(1));
-              insertDAGNode(*CurDAG, N, ExtShVal);
-              insertDAGNode(*CurDAG, N, ExtShift);
-              AddSrc = ExtShift;
-              Res = ExtShVal;
+    if (!VT.isVector() && N.hasOneUse()) {
+      unsigned SrcOpc = Src.getOpcode();
+      if (((SrcOpc == ISD::ADD && Src->getFlags().hasNoUnsignedWrap()) ||
+           CurDAG->isADDLike(Src)) &&
+          Src.hasOneUse()) {
+        if (CurDAG->isBaseWithConstantOffset(Src)) {
+          SDValue AddSrc = Src.getOperand(0);
+          uint64_t Offset = Src.getConstantOperandVal(1);
+          if (!foldOffsetIntoAddress(Offset * AM.Scale, AM)) {
+            SDLoc DL(N);
+            SDValue Res;
+            // If we're also scaling, see if we can use that as well.
+            if (AddSrc.getOpcode() == ISD::SHL &&
+                isa<ConstantSDNode>(AddSrc.getOperand(1))) {
+              SDValue ShVal = AddSrc.getOperand(0);
+              uint64_t ShAmt = AddSrc.getConstantOperandVal(1);
+              APInt HiBits = APInt::getHighBitsSet(
+                  AddSrc.getScalarValueSizeInBits(), ShAmt);
+              uint64_t ScaleAmt = 1ULL << ShAmt;
+              if ((AM.Scale * ScaleAmt) <= 8 &&
+                  (AddSrc->getFlags().hasNoUnsignedWrap() ||
+                   CurDAG->MaskedValueIsZero(ShVal, HiBits))) {
+                AM.Scale *= ScaleAmt;
+                SDValue ExtShVal = CurDAG->getNode(Opc, DL, VT, ShVal);
+                SDValue ExtShift = CurDAG->getNode(ISD::SHL, DL, VT, ExtShVal,
+                                                   AddSrc.getOperand(1));
+                insertDAGNode(*CurDAG, N, ExtShVal);
+                insertDAGNode(*CurDAG, N, ExtShift);
+                AddSrc = ExtShift;
+                Res = adapter(ExtShVal);
+              }
             }
+            SDValue ExtSrc = CurDAG->getNode(Opc, DL, VT, AddSrc);
+            SDValue ExtVal = CurDAG->getConstant(Offset, DL, VT);
+            SDValue ExtAdd = CurDAG->getNode(SrcOpc, DL, VT, ExtSrc, ExtVal);
+            insertDAGNode(*CurDAG, N, ExtSrc);
+            insertDAGNode(*CurDAG, N, ExtVal);
+            insertDAGNode(*CurDAG, N, ExtAdd);
+            CurDAG->ReplaceAllUsesWith(N, ExtAdd);
+            CurDAG->RemoveDeadNode(N.getNode());
+            // AM.IndexReg can be further picked
+            SDValue Zext = adapter(ExtSrc);
+            return Res ? Res : Zext;
           }
-          SDValue ExtSrc = CurDAG->getNode(Opc, DL, VT, AddSrc);
-          SDValue ExtVal = CurDAG->getConstant(Offset, DL, VT);
-          SDValue ExtAdd = CurDAG->getNode(SrcOpc, DL, VT, ExtSrc, ExtVal);
-          insertDAGNode(*CurDAG, N, ExtSrc);
-          insertDAGNode(*CurDAG, N, ExtVal);
-          insertDAGNode(*CurDAG, N, ExtAdd);
-          CurDAG->ReplaceAllUsesWith(N, ExtAdd);
-          CurDAG->RemoveDeadNode(N.getNode());
-          return Res ? Res : ExtSrc;
         }
       }
     }
+
+    // Peek through mask: zext(and(shl(x,c1),c2))
+    APInt Mask = APInt::getAllOnes(Src.getScalarValueSizeInBits());
+    if (Src.getOpcode() == ISD::AND && Src.hasOneUse())
+      if (auto *MaskC = dyn_cast<ConstantSDNode>(Src.getOperand(1))) {
+        Mask = MaskC->getAPIntValue();
+        Src = Src.getOperand(0);
+      }
+
+    if (Src.getOpcode() == ISD::SHL && Src.hasOneUse()) {
+      // Give up if the shift is not a valid scale factor [1,2,3].
+      SDValue ShlSrc = Src.getOperand(0);
+      SDValue ShlAmt = Src.getOperand(1);
+      auto *ShAmtC = dyn_cast<ConstantSDNode>(ShlAmt);
+      if (!ShAmtC)
+        return std::nullopt;
+      unsigned ShAmtV = ShAmtC->getZExtValue();
+      if (ShAmtV > 3 || (1 << ShAmtV) * AM.Scale > 8)
----------------
RicoAfoat wrote:

Yes, I think we can remove `AM.Scale!=1` in `matchAddressRecursively`. I relaxed the restriction here, but not for the last case, where we try to deal with the mask and shift.

https://github.com/llvm/llvm-project/pull/85081


More information about the llvm-commits mailing list