[llvm] 89ab5c6 - [X86] Add a helper function to pull some repeated code out of combineGatherScatter. NFC

Craig Topper via llvm-commits llvm-commits at lists.llvm.org
Tue Feb 18 11:21:47 PST 2020


Author: Craig Topper
Date: 2020-02-18T11:10:40-08:00
New Revision: 89ab5c69c8514bd1768beb4f8c058192770aa05d

URL: https://github.com/llvm/llvm-project/commit/89ab5c69c8514bd1768beb4f8c058192770aa05d
DIFF: https://github.com/llvm/llvm-project/commit/89ab5c69c8514bd1768beb4f8c058192770aa05d.diff

LOG: [X86] Add a helper function to pull some repeated code out of combineGatherScatter. NFC

Added: 
    

Modified: 
    llvm/lib/Target/X86/X86ISelLowering.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index fee2b19794dd..385cb754731b 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -44661,13 +44661,33 @@ static SDValue combineX86GatherScatter(SDNode *N, SelectionDAG &DAG,
   return SDValue();
 }
 
+static SDValue rebuildGatherScatter(MaskedGatherScatterSDNode *GorS,
+                                    SDValue Index, SDValue Base, SDValue Scale,
+                                    SelectionDAG &DAG) {
+  SDLoc DL(GorS); // The new node reuses GorS's SDLoc.
+  // Rebuild GorS with the given Index/Base/Scale; all else is copied from GorS.
+  if (auto *Gather = dyn_cast<MaskedGatherSDNode>(GorS)) {
+    SDValue Ops[] = { Gather->getChain(), Gather->getPassThru(), // kept as-is
+                      Gather->getMask(), Base, Index, Scale } ;
+    return DAG.getMaskedGather(Gather->getVTList(),
+                               Gather->getMemoryVT(), DL, Ops,
+                               Gather->getMemOperand(),
+                               Gather->getIndexType());
+  }
+  auto *Scatter = cast<MaskedScatterSDNode>(GorS); // Non-gathers must be scatters.
+  SDValue Ops[] = { Scatter->getChain(), Scatter->getValue(), // kept as-is
+                    Scatter->getMask(), Base, Index, Scale };
+  return DAG.getMaskedScatter(Scatter->getVTList(),
+                              Scatter->getMemoryVT(), DL,
+                              Ops, Scatter->getMemOperand(),
+                              Scatter->getIndexType());
+}
+
 static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
                                     TargetLowering::DAGCombinerInfo &DCI) {
   SDLoc DL(N);
   auto *GorS = cast<MaskedGatherScatterSDNode>(N);
-  SDValue Chain = GorS->getChain();
   SDValue Index = GorS->getIndex();
-  SDValue Mask = GorS->getMask();
   SDValue Base = GorS->getBasePtr();
   SDValue Scale = GorS->getScale();
 
@@ -44687,21 +44707,7 @@ static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
         unsigned NumElts = Index.getValueType().getVectorNumElements();
         EVT NewVT = EVT::getVectorVT(*DAG.getContext(), MVT::i32, NumElts);
         Index = DAG.getNode(ISD::TRUNCATE, DL, NewVT, Index);
-        if (auto *Gather = dyn_cast<MaskedGatherSDNode>(GorS)) {
-          SDValue Ops[] = { Chain, Gather->getPassThru(),
-                            Mask, Base, Index, Scale } ;
-          return DAG.getMaskedGather(Gather->getVTList(),
-                                     Gather->getMemoryVT(), DL, Ops,
-                                     Gather->getMemOperand(),
-                                     Gather->getIndexType());
-        }
-        auto *Scatter = cast<MaskedScatterSDNode>(GorS);
-        SDValue Ops[] = { Chain, Scatter->getValue(),
-                          Mask, Base, Index, Scale };
-        return DAG.getMaskedScatter(Scatter->getVTList(),
-                                    Scatter->getMemoryVT(), DL,
-                                    Ops, Scatter->getMemOperand(),
-                                    Scatter->getIndexType());
+        return rebuildGatherScatter(GorS, Index, Base, Scale, DAG);
       }
     }
 
@@ -44716,21 +44722,7 @@ static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
       unsigned NumElts = Index.getValueType().getVectorNumElements();
       EVT NewVT = EVT::getVectorVT(*DAG.getContext(), MVT::i32, NumElts);
       Index = DAG.getNode(ISD::TRUNCATE, DL, NewVT, Index);
-      if (auto *Gather = dyn_cast<MaskedGatherSDNode>(GorS)) {
-        SDValue Ops[] = { Chain, Gather->getPassThru(),
-                          Mask, Base, Index, Scale } ;
-        return DAG.getMaskedGather(Gather->getVTList(),
-                                   Gather->getMemoryVT(), DL, Ops,
-                                   Gather->getMemOperand(),
-                                   Gather->getIndexType());
-      }
-      auto *Scatter = cast<MaskedScatterSDNode>(GorS);
-      SDValue Ops[] = { Chain, Scatter->getValue(),
-                        Mask, Base, Index, Scale };
-      return DAG.getMaskedScatter(Scatter->getVTList(),
-                                  Scatter->getMemoryVT(), DL,
-                                  Ops, Scatter->getMemOperand(),
-                                  Scatter->getIndexType());
+      return rebuildGatherScatter(GorS, Index, Base, Scale, DAG);
     }
   }
 
@@ -44743,25 +44735,12 @@ static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
       EVT IndexVT = EVT::getVectorVT(*DAG.getContext(), EltVT,
                                    Index.getValueType().getVectorNumElements());
       Index = DAG.getSExtOrTrunc(Index, DL, IndexVT);
-      if (auto *Gather = dyn_cast<MaskedGatherSDNode>(GorS)) {
-        SDValue Ops[] = { Chain, Gather->getPassThru(),
-                          Mask, Base, Index, Scale } ;
-        return DAG.getMaskedGather(Gather->getVTList(),
-                                   Gather->getMemoryVT(), DL, Ops,
-                                   Gather->getMemOperand(),
-                                   Gather->getIndexType());
-      }
-      auto *Scatter = cast<MaskedScatterSDNode>(GorS);
-      SDValue Ops[] = { Chain, Scatter->getValue(),
-                        Mask, Base, Index, Scale };
-      return DAG.getMaskedScatter(Scatter->getVTList(),
-                                  Scatter->getMemoryVT(), DL,
-                                  Ops, Scatter->getMemOperand(),
-                                  Scatter->getIndexType());
+      return rebuildGatherScatter(GorS, Index, Base, Scale, DAG);
     }
   }
 
   // With vector masks we only demand the upper bit of the mask.
+  SDValue Mask = GorS->getMask();
   if (Mask.getScalarValueSizeInBits() != 1) {
     const TargetLowering &TLI = DAG.getTargetLoweringInfo();
     APInt DemandedMask(APInt::getSignMask(Mask.getScalarValueSizeInBits()));


        


More information about the llvm-commits mailing list