[llvm] r373137 - [X86] Stop using UpdateNodeOperands in combineGatherScatter. Create new nodes like most other DAG combines.

Craig Topper via llvm-commits llvm-commits at lists.llvm.org
Fri Sep 27 18:08:48 PDT 2019


Author: ctopper
Date: Fri Sep 27 18:08:46 2019
New Revision: 373137

URL: http://llvm.org/viewvc/llvm-project?rev=373137&view=rev
Log:
[X86] Stop using UpdateNodeOperands in combineGatherScatter. Create new nodes like most other DAG combines.

Creating new nodes is what we usually do. With UpdateNodeOperands we
have to explicitly check that we don't update to an existing node, and
having to manually manage the worklist is unusual.

We can probably add a helper function to reduce the duplicated code
that checks whether to create a gather or a scatter, but I wanted to
get the simple change done first.

Modified:
    llvm/trunk/lib/Target/X86/X86ISelLowering.cpp

Modified: llvm/trunk/lib/Target/X86/X86ISelLowering.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Target/X86/X86ISelLowering.cpp?rev=373137&r1=373136&r2=373137&view=diff
==============================================================================
--- llvm/trunk/lib/Target/X86/X86ISelLowering.cpp (original)
+++ llvm/trunk/lib/Target/X86/X86ISelLowering.cpp Fri Sep 27 18:08:46 2019
@@ -43381,26 +43381,36 @@ static SDValue combineX86GatherScatter(S
 static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
                                     TargetLowering::DAGCombinerInfo &DCI) {
   SDLoc DL(N);
+  auto *GorS = cast<MaskedGatherScatterSDNode>(N);
+  SDValue Chain = GorS->getChain();
+  SDValue Index = GorS->getIndex();
+  SDValue Mask = GorS->getMask();
+  SDValue Base = GorS->getBasePtr();
+  SDValue Scale = GorS->getScale();
 
   if (DCI.isBeforeLegalizeOps()) {
-    SDValue Index = N->getOperand(4);
     // Remove any sign extends from 32 or smaller to larger than 32.
     // Only do this before LegalizeOps in case we need the sign extend for
     // legalization.
-    if (Index.getOpcode() == ISD::SIGN_EXTEND) {
-      if (Index.getScalarValueSizeInBits() > 32 &&
-          Index.getOperand(0).getScalarValueSizeInBits() <= 32) {
-        SmallVector<SDValue, 5> NewOps(N->op_begin(), N->op_end());
-        NewOps[4] = Index.getOperand(0);
-        SDNode *Res = DAG.UpdateNodeOperands(N, NewOps);
-        if (Res == N) {
-          // The original sign extend has less users, add back to worklist in
-          // case it needs to be removed
-          DCI.AddToWorklist(Index.getNode());
-          DCI.AddToWorklist(N);
-        }
-        return SDValue(Res, 0);
+    if (Index.getOpcode() == ISD::SIGN_EXTEND &&
+        Index.getScalarValueSizeInBits() > 32 &&
+        Index.getOperand(0).getScalarValueSizeInBits() <= 32) {
+      Index = Index.getOperand(0);
+      if (auto *Gather = dyn_cast<MaskedGatherSDNode>(GorS)) {
+        SDValue Ops[] = { Chain, Gather->getPassThru(),
+                          Mask, Base, Index, Scale } ;
+        return DAG.getMaskedGather(Gather->getVTList(),
+                                   Gather->getMemoryVT(), DL, Ops,
+                                   Gather->getMemOperand(),
+                                   Gather->getIndexType());
       }
+      auto *Scatter = cast<MaskedScatterSDNode>(GorS);
+      SDValue Ops[] = { Chain, Scatter->getValue(),
+                        Mask, Base, Index, Scale };
+      return DAG.getMaskedScatter(Scatter->getVTList(),
+                                  Scatter->getMemoryVT(), DL,
+                                  Ops, Scatter->getMemOperand(),
+                                  Scatter->getIndexType());
     }
 
     // Make sure the index is either i32 or i64
@@ -43410,36 +43420,49 @@ static SDValue combineGatherScatter(SDNo
       EVT IndexVT = EVT::getVectorVT(*DAG.getContext(), EltVT,
                                    Index.getValueType().getVectorNumElements());
       Index = DAG.getSExtOrTrunc(Index, DL, IndexVT);
-      SmallVector<SDValue, 5> NewOps(N->op_begin(), N->op_end());
-      NewOps[4] = Index;
-      SDNode *Res = DAG.UpdateNodeOperands(N, NewOps);
-      if (Res == N)
-        DCI.AddToWorklist(N);
-      return SDValue(Res, 0);
+      if (auto *Gather = dyn_cast<MaskedGatherSDNode>(GorS)) {
+        SDValue Ops[] = { Chain, Gather->getPassThru(),
+                          Mask, Base, Index, Scale } ;
+        return DAG.getMaskedGather(Gather->getVTList(),
+                                   Gather->getMemoryVT(), DL, Ops,
+                                   Gather->getMemOperand(),
+                                   Gather->getIndexType());
+      }
+      auto *Scatter = cast<MaskedScatterSDNode>(GorS);
+      SDValue Ops[] = { Chain, Scatter->getValue(),
+                        Mask, Base, Index, Scale };
+      return DAG.getMaskedScatter(Scatter->getVTList(),
+                                  Scatter->getMemoryVT(), DL,
+                                  Ops, Scatter->getMemOperand(),
+                                  Scatter->getIndexType());
     }
 
     // Try to remove zero extends from 32->64 if we know the sign bit of
     // the input is zero.
     if (Index.getOpcode() == ISD::ZERO_EXTEND &&
         Index.getScalarValueSizeInBits() == 64 &&
-        Index.getOperand(0).getScalarValueSizeInBits() == 32) {
-      if (DAG.SignBitIsZero(Index.getOperand(0))) {
-        SmallVector<SDValue, 5> NewOps(N->op_begin(), N->op_end());
-        NewOps[4] = Index.getOperand(0);
-        SDNode *Res = DAG.UpdateNodeOperands(N, NewOps);
-        if (Res == N) {
-          // The original sign extend has less users, add back to worklist in
-          // case it needs to be removed
-          DCI.AddToWorklist(Index.getNode());
-          DCI.AddToWorklist(N);
-        }
-        return SDValue(Res, 0);
+        Index.getOperand(0).getScalarValueSizeInBits() == 32 &&
+        DAG.SignBitIsZero(Index.getOperand(0))) {
+      Index = Index.getOperand(0);
+      if (auto *Gather = dyn_cast<MaskedGatherSDNode>(GorS)) {
+        SDValue Ops[] = { Chain, Gather->getPassThru(),
+                          Mask, Base, Index, Scale } ;
+        return DAG.getMaskedGather(Gather->getVTList(),
+                                   Gather->getMemoryVT(), DL, Ops,
+                                   Gather->getMemOperand(),
+                                   Gather->getIndexType());
       }
+      auto *Scatter = cast<MaskedScatterSDNode>(GorS);
+      SDValue Ops[] = { Chain, Scatter->getValue(),
+                        Mask, Base, Index, Scale };
+      return DAG.getMaskedScatter(Scatter->getVTList(),
+                                  Scatter->getMemoryVT(), DL,
+                                  Ops, Scatter->getMemOperand(),
+                                  Scatter->getIndexType());
     }
   }
 
   // With vector masks we only demand the upper bit of the mask.
-  SDValue Mask = cast<MaskedGatherScatterSDNode>(N)->getMask();
   if (Mask.getScalarValueSizeInBits() != 1) {
     const TargetLowering &TLI = DAG.getTargetLoweringInfo();
     APInt DemandedMask(APInt::getSignMask(Mask.getScalarValueSizeInBits()));




More information about the llvm-commits mailing list