[llvm] e8900df - [X86] Move gather/scatter index shl(x, c) -> index:x, scale:c fold into X86DAGToDAGISel::matchIndexRecursively

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Tue Aug 22 07:29:59 PDT 2023


Author: Simon Pilgrim
Date: 2023-08-22T15:29:44+01:00
New Revision: e8900df8a769c40661daeb226dfbcd7ec9939a85

URL: https://github.com/llvm/llvm-project/commit/e8900df8a769c40661daeb226dfbcd7ec9939a85
DIFF: https://github.com/llvm/llvm-project/commit/e8900df8a769c40661daeb226dfbcd7ec9939a85.diff

LOG: [X86] Move gather/scatter index shl(x,c) -> index:x, scale:c fold into X86DAGToDAGISel::matchIndexRecursively
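
For illustration, a minimal standalone C++ sketch of the scale-folding arithmetic this patch moves into matchIndexRecursively (foldShiftIntoScale is a hypothetical helper name, not code from the patch): an index of the form shl(x, c) multiplies x by 2^c, so the shift can be absorbed into the addressing-mode scale whenever scale * (1 << c) is still a legal x86 scale (1, 2, 4 or 8).

  #include <cstdint>
  #include <cstdio>
  #include <optional>

  // Hypothetical helper modelling the check performed by the new code in
  // matchIndexRecursively: fold the constant shift amount of an index
  // shl(x, c) into the addressing-mode scale, if the result is still a
  // legal x86 scale.
  std::optional<uint64_t> foldShiftIntoScale(uint64_t CurScale, uint64_t ShiftAmt) {
    if (ShiftAmt > 3)                      // 1 << 4 = 16 already exceeds the max scale
      return std::nullopt;
    uint64_t NewScale = CurScale * (1ULL << ShiftAmt);
    if (NewScale <= 8)                     // legal scales are 1, 2, 4 and 8
      return NewScale;                     // caller keeps x itself as the index
    return std::nullopt;
  }

  int main() {
    // index = shl(x, 2) with scale 1 becomes index = x with scale 4.
    if (auto S = foldShiftIntoScale(1, 2))
      printf("folded scale: %llu\n", (unsigned long long)*S);
    return 0;
  }

Doing this during address matching also lets the shift compose with the other index folds in the same recursive walk (the add(x,c) constant-offset and add(x,x) cases visible in the diff below).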

Added: 
    

Modified: 
    llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
    llvm/lib/Target/X86/X86ISelLowering.cpp

Removed: 
    


################################################################################
diff --git a/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp b/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
index d5b7fe3aa6cb02..1784a8103a66bb 100644
--- a/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
+++ b/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
@@ -2216,6 +2216,7 @@ SDValue X86DAGToDAGISel::matchIndexRecursively(SDValue N,
   if (Depth >= SelectionDAG::MaxRecursionDepth)
     return N;
 
+  // index: add(x,c) -> index: x, disp + c
   if (CurDAG->isBaseWithConstantOffset(N)) {
     auto *AddVal = cast<ConstantSDNode>(N.getOperand(1));
     uint64_t Offset = (uint64_t)AddVal->getSExtValue() * AM.Scale;
@@ -2223,6 +2224,24 @@ SDValue X86DAGToDAGISel::matchIndexRecursively(SDValue N,
       return matchIndexRecursively(N.getOperand(0), AM, Depth + 1);
   }
 
+  // index: add(x,x) -> index: x, scale * 2
+  if (N.getOpcode() == ISD::ADD && N.getOperand(0) == N.getOperand(1)) {
+    if (AM.Scale <= 4) {
+      AM.Scale *= 2;
+      return matchIndexRecursively(N.getOperand(0), AM, Depth + 1);
+    }
+  }
+
+  // index: shl(x,i) -> index: x, scale * (1 << i)
+  if (N.getOpcode() == X86ISD::VSHLI) {
+    uint64_t ShiftAmt = N.getConstantOperandVal(1);
+    uint64_t ScaleAmt = 1ULL << ShiftAmt;
+    if ((AM.Scale * ScaleAmt) <= 8) {
+      AM.Scale *= ScaleAmt;
+      return matchIndexRecursively(N.getOperand(0), AM, Depth + 1);
+    }
+  }
+
   // TODO: Handle extensions, shifted masks etc.
   return N;
 }
@@ -2672,9 +2691,15 @@ bool X86DAGToDAGISel::selectVectorAddr(MemSDNode *Parent, SDValue BasePtr,
                                        SDValue &Index, SDValue &Disp,
                                        SDValue &Segment) {
   X86ISelAddressMode AM;
-  AM.IndexReg = IndexOp;
   AM.Scale = cast<ConstantSDNode>(ScaleOp)->getZExtValue();
 
+  // Attempt to match index patterns, as long as we're not relying on implicit
+  // sign-extension, which is performed BEFORE scale.
+  if (IndexOp.getScalarValueSizeInBits() == BasePtr.getScalarValueSizeInBits())
+    AM.IndexReg = matchIndexRecursively(IndexOp, AM, 0);
+  else
+    AM.IndexReg = IndexOp;
+
   unsigned AddrSpace = Parent->getPointerInfo().getAddrSpace();
   if (AddrSpace == X86AS::GS)
     AM.Segment = CurDAG->getRegister(X86::GS, MVT::i16);
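
The width check added to selectVectorAddr above matters because, as the comment notes, implicit sign-extension of a narrower index is performed BEFORE the scale is applied, so a 32-bit shl cannot simply be traded for a larger scale on the sign-extended value. A small standalone C++ illustration (not LLVM code), assuming a 32-bit index element and 64-bit pointers:

  #include <cstdint>
  #include <cstdio>

  int main() {
    // 32-bit index element, 64-bit pointers: the hardware computes
    //   base + sign_extend_64(index) * scale
    // i.e. the sign extension happens BEFORE the scale is applied.
    uint32_t x = 0x40000000u;

    // Original form: index = shl(x, 2) computed in 32 bits, scale = 1.
    int64_t Shifted32 = (int32_t)(x << 2);     // wraps to 0 in 32 bits
    uint64_t OffOriginal = (uint64_t)Shifted32 * 1;

    // Folded form: index = x, scale = 4 (sign-extend first, then scale).
    int64_t SextX = (int32_t)x;                // 0x40000000
    uint64_t OffFolded = (uint64_t)SextX * 4;  // 0x100000000

    // The two disagree, which is why the fold is only attempted when the
    // index is already pointer-width.
    printf("0x%llx vs 0x%llx\n", (unsigned long long)OffOriginal,
           (unsigned long long)OffFolded);
    return 0;
  }

Here the original form wraps to offset 0 in 32 bits while the folded form would address 0x100000000, so matchIndexRecursively is only invoked when the index element width already matches the pointer width.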

diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 0d2f3b00313a74..9f70d6cedb7611 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -52947,43 +52947,10 @@ static SDValue combineTESTP(SDNode *N, SelectionDAG &DAG,
 }
 
 static SDValue combineX86GatherScatter(SDNode *N, SelectionDAG &DAG,
-                                       TargetLowering::DAGCombinerInfo &DCI,
-                                       const X86Subtarget &Subtarget) {
+                                       TargetLowering::DAGCombinerInfo &DCI) {
   auto *MemOp = cast<X86MaskedGatherScatterSDNode>(N);
-  SDValue BasePtr = MemOp->getBasePtr();
-  SDValue Index = MemOp->getIndex();
-  SDValue Scale = MemOp->getScale();
   SDValue Mask = MemOp->getMask();
 
-  // Attempt to fold an index scale into the scale value directly.
-  // For smaller indices, implicit sext is performed BEFORE scale, preventing
-  // this fold under most circumstances.
-  // TODO: Move this into X86DAGToDAGISel::matchVectorAddressRecursively?
-  if ((Index.getOpcode() == X86ISD::VSHLI ||
-       (Index.getOpcode() == ISD::ADD &&
-        Index.getOperand(0) == Index.getOperand(1))) &&
-      isa<ConstantSDNode>(Scale) &&
-      BasePtr.getScalarValueSizeInBits() == Index.getScalarValueSizeInBits()) {
-    unsigned ShiftAmt =
-        Index.getOpcode() == ISD::ADD ? 1 : Index.getConstantOperandVal(1);
-    uint64_t ScaleAmt = cast<ConstantSDNode>(Scale)->getZExtValue();
-    uint64_t NewScaleAmt = ScaleAmt * (1ULL << ShiftAmt);
-    if (isPowerOf2_64(NewScaleAmt) && NewScaleAmt <= 8) {
-      SDValue NewIndex = Index.getOperand(0);
-      SDValue NewScale =
-          DAG.getTargetConstant(NewScaleAmt, SDLoc(N), Scale.getValueType());
-      if (N->getOpcode() == X86ISD::MGATHER)
-        return getAVX2GatherNode(N->getOpcode(), SDValue(N, 0), DAG,
-                                 MemOp->getOperand(1), Mask,
-                                 MemOp->getBasePtr(), NewIndex, NewScale,
-                                 MemOp->getChain(), Subtarget);
-      if (N->getOpcode() == X86ISD::MSCATTER)
-        return getScatterNode(N->getOpcode(), SDValue(N, 0), DAG,
-                              MemOp->getOperand(1), Mask, MemOp->getBasePtr(),
-                              NewIndex, NewScale, MemOp->getChain(), Subtarget);
-    }
-  }
-
   // With vector masks we only demand the upper bit of the mask.
   if (Mask.getScalarValueSizeInBits() != 1) {
     const TargetLowering &TLI = DAG.getTargetLoweringInfo();
@@ -55920,8 +55887,7 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N,
   case X86ISD::MOVMSK:      return combineMOVMSK(N, DAG, DCI, Subtarget);
   case X86ISD::TESTP:       return combineTESTP(N, DAG, DCI, Subtarget);
   case X86ISD::MGATHER:
-  case X86ISD::MSCATTER:
-    return combineX86GatherScatter(N, DAG, DCI, Subtarget);
+  case X86ISD::MSCATTER:    return combineX86GatherScatter(N, DAG, DCI);
   case ISD::MGATHER:
   case ISD::MSCATTER:       return combineGatherScatter(N, DAG, DCI);
   case X86ISD::PCMPEQ:
