[llvm] r322210 - [SelectionDAG][X86] Explicitly store the scale in the gather/scatter ISD nodes

Craig Topper via llvm-commits llvm-commits at lists.llvm.org
Wed Jan 10 11:16:05 PST 2018


Author: ctopper
Date: Wed Jan 10 11:16:05 2018
New Revision: 322210

URL: http://llvm.org/viewvc/llvm-project?rev=322210&view=rev
Log:
[SelectionDAG][X86] Explicitly store the scale in the gather/scatter ISD nodes

Currently we infer the scale at isel time by checking whether the base is a constant 0. If it is, we assume the scale is 1; otherwise we take it from the element size of the passthru or stored value. This seems a little weird, and I think it makes more sense to make the scale explicit in the DAG rather than doing tricky things in the backend.

Most of this patch is just making sure we copy the scale around everywhere.

Differential Revision: https://reviews.llvm.org/D40055

Modified:
    llvm/trunk/include/llvm/CodeGen/SelectionDAGNodes.h
    llvm/trunk/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
    llvm/trunk/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
    llvm/trunk/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
    llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
    llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
    llvm/trunk/lib/Target/X86/X86ISelDAGToDAG.cpp
    llvm/trunk/lib/Target/X86/X86ISelLowering.cpp
    llvm/trunk/lib/Target/X86/X86ISelLowering.h
    llvm/trunk/test/CodeGen/X86/masked_gather_scatter.ll

Modified: llvm/trunk/include/llvm/CodeGen/SelectionDAGNodes.h
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/include/llvm/CodeGen/SelectionDAGNodes.h?rev=322210&r1=322209&r2=322210&view=diff
==============================================================================
--- llvm/trunk/include/llvm/CodeGen/SelectionDAGNodes.h (original)
+++ llvm/trunk/include/llvm/CodeGen/SelectionDAGNodes.h Wed Jan 10 11:16:05 2018
@@ -2120,13 +2120,14 @@ public:
       : MemSDNode(NodeTy, Order, dl, VTs, MemVT, MMO) {}
 
   // In the both nodes address is Op1, mask is Op2:
-  // MaskedGatherSDNode  (Chain, src0, mask, base, index), src0 is a passthru value
-  // MaskedScatterSDNode (Chain, value, mask, base, index)
+  // MaskedGatherSDNode  (Chain, passthru, mask, base, index, scale)
+  // MaskedScatterSDNode (Chain, value, mask, base, index, scale)
   // Mask is a vector of i1 elements
   const SDValue &getBasePtr() const { return getOperand(3); }
   const SDValue &getIndex()   const { return getOperand(4); }
   const SDValue &getMask()    const { return getOperand(2); }
   const SDValue &getValue()   const { return getOperand(1); }
+  const SDValue &getScale()   const { return getOperand(5); }
 
   static bool classof(const SDNode *N) {
     return N->getOpcode() == ISD::MGATHER ||

Modified: llvm/trunk/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/CodeGen/SelectionDAG/DAGCombiner.cpp?rev=322210&r1=322209&r2=322210&view=diff
==============================================================================
--- llvm/trunk/lib/CodeGen/SelectionDAG/DAGCombiner.cpp (original)
+++ llvm/trunk/lib/CodeGen/SelectionDAG/DAGCombiner.cpp Wed Jan 10 11:16:05 2018
@@ -6726,6 +6726,7 @@ SDValue DAGCombiner::visitMSCATTER(SDNod
   SDValue DataLo, DataHi;
   std::tie(DataLo, DataHi) = DAG.SplitVector(Data, DL);
 
+  SDValue Scale = MSC->getScale();
   SDValue BasePtr = MSC->getBasePtr();
   SDValue IndexLo, IndexHi;
   std::tie(IndexLo, IndexHi) = DAG.SplitVector(MSC->getIndex(), DL);
@@ -6735,11 +6736,11 @@ SDValue DAGCombiner::visitMSCATTER(SDNod
                           MachineMemOperand::MOStore,  LoMemVT.getStoreSize(),
                           Alignment, MSC->getAAInfo(), MSC->getRanges());
 
-  SDValue OpsLo[] = { Chain, DataLo, MaskLo, BasePtr, IndexLo };
+  SDValue OpsLo[] = { Chain, DataLo, MaskLo, BasePtr, IndexLo, Scale };
   Lo = DAG.getMaskedScatter(DAG.getVTList(MVT::Other), DataLo.getValueType(),
                             DL, OpsLo, MMO);
 
-  SDValue OpsHi[] = {Chain, DataHi, MaskHi, BasePtr, IndexHi};
+  SDValue OpsHi[] = { Chain, DataHi, MaskHi, BasePtr, IndexHi, Scale };
   Hi = DAG.getMaskedScatter(DAG.getVTList(MVT::Other), DataHi.getValueType(),
                             DL, OpsHi, MMO);
 
@@ -6859,6 +6860,7 @@ SDValue DAGCombiner::visitMGATHER(SDNode
   EVT LoMemVT, HiMemVT;
   std::tie(LoMemVT, HiMemVT) = DAG.GetSplitDestVTs(MemoryVT);
 
+  SDValue Scale = MGT->getScale();
   SDValue BasePtr = MGT->getBasePtr();
   SDValue Index = MGT->getIndex();
   SDValue IndexLo, IndexHi;
@@ -6869,13 +6871,13 @@ SDValue DAGCombiner::visitMGATHER(SDNode
                           MachineMemOperand::MOLoad,  LoMemVT.getStoreSize(),
                           Alignment, MGT->getAAInfo(), MGT->getRanges());
 
-  SDValue OpsLo[] = { Chain, Src0Lo, MaskLo, BasePtr, IndexLo };
+  SDValue OpsLo[] = { Chain, Src0Lo, MaskLo, BasePtr, IndexLo, Scale };
   Lo = DAG.getMaskedGather(DAG.getVTList(LoVT, MVT::Other), LoVT, DL, OpsLo,
-                            MMO);
+                           MMO);
 
-  SDValue OpsHi[] = {Chain, Src0Hi, MaskHi, BasePtr, IndexHi};
+  SDValue OpsHi[] = { Chain, Src0Hi, MaskHi, BasePtr, IndexHi, Scale };
   Hi = DAG.getMaskedGather(DAG.getVTList(HiVT, MVT::Other), HiVT, DL, OpsHi,
-                            MMO);
+                           MMO);
 
   AddToWorklist(Lo.getNode());
   AddToWorklist(Hi.getNode());

Modified: llvm/trunk/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp?rev=322210&r1=322209&r2=322210&view=diff
==============================================================================
--- llvm/trunk/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp (original)
+++ llvm/trunk/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp Wed Jan 10 11:16:05 2018
@@ -501,7 +501,7 @@ SDValue DAGTypeLegalizer::PromoteIntRes_
 
   SDLoc dl(N);
   SDValue Ops[] = {N->getChain(), ExtSrc0, N->getMask(), N->getBasePtr(),
-                   N->getIndex()};
+                   N->getIndex(), N->getScale() };
   SDValue Res = DAG.getMaskedGather(DAG.getVTList(NVT, MVT::Other),
                                     N->getMemoryVT(), dl, Ops,
                                     N->getMemOperand());

Modified: llvm/trunk/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp?rev=322210&r1=322209&r2=322210&view=diff
==============================================================================
--- llvm/trunk/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp (original)
+++ llvm/trunk/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp Wed Jan 10 11:16:05 2018
@@ -1238,6 +1238,7 @@ void DAGTypeLegalizer::SplitVecRes_MGATH
   SDValue Mask = MGT->getMask();
   SDValue Src0 = MGT->getValue();
   SDValue Index = MGT->getIndex();
+  SDValue Scale = MGT->getScale();
   unsigned Alignment = MGT->getOriginalAlignment();
 
   // Split Mask operand
@@ -1269,11 +1270,11 @@ void DAGTypeLegalizer::SplitVecRes_MGATH
                          MachineMemOperand::MOLoad,  LoMemVT.getStoreSize(),
                          Alignment, MGT->getAAInfo(), MGT->getRanges());
 
-  SDValue OpsLo[] = {Ch, Src0Lo, MaskLo, Ptr, IndexLo};
+  SDValue OpsLo[] = {Ch, Src0Lo, MaskLo, Ptr, IndexLo, Scale};
   Lo = DAG.getMaskedGather(DAG.getVTList(LoVT, MVT::Other), LoVT, dl, OpsLo,
                            MMO);
 
-  SDValue OpsHi[] = {Ch, Src0Hi, MaskHi, Ptr, IndexHi};
+  SDValue OpsHi[] = {Ch, Src0Hi, MaskHi, Ptr, IndexHi, Scale};
   Hi = DAG.getMaskedGather(DAG.getVTList(HiVT, MVT::Other), HiVT, dl, OpsHi,
                            MMO);
 
@@ -1816,6 +1817,7 @@ SDValue DAGTypeLegalizer::SplitVecOp_MGA
   SDValue Ch = MGT->getChain();
   SDValue Ptr = MGT->getBasePtr();
   SDValue Index = MGT->getIndex();
+  SDValue Scale = MGT->getScale();
   SDValue Mask = MGT->getMask();
   SDValue Src0 = MGT->getValue();
   unsigned Alignment = MGT->getOriginalAlignment();
@@ -1848,7 +1850,7 @@ SDValue DAGTypeLegalizer::SplitVecOp_MGA
                          MachineMemOperand::MOLoad,  LoMemVT.getStoreSize(),
                          Alignment, MGT->getAAInfo(), MGT->getRanges());
 
-  SDValue OpsLo[] = {Ch, Src0Lo, MaskLo, Ptr, IndexLo};
+  SDValue OpsLo[] = {Ch, Src0Lo, MaskLo, Ptr, IndexLo, Scale};
   SDValue Lo = DAG.getMaskedGather(DAG.getVTList(LoVT, MVT::Other), LoVT, dl,
                                    OpsLo, MMO);
 
@@ -1858,7 +1860,7 @@ SDValue DAGTypeLegalizer::SplitVecOp_MGA
                          Alignment, MGT->getAAInfo(),
                          MGT->getRanges());
 
-  SDValue OpsHi[] = {Ch, Src0Hi, MaskHi, Ptr, IndexHi};
+  SDValue OpsHi[] = {Ch, Src0Hi, MaskHi, Ptr, IndexHi, Scale};
   SDValue Hi = DAG.getMaskedGather(DAG.getVTList(HiVT, MVT::Other), HiVT, dl,
                                    OpsHi, MMO);
 
@@ -1941,6 +1943,7 @@ SDValue DAGTypeLegalizer::SplitVecOp_MSC
   SDValue Ptr = N->getBasePtr();
   SDValue Mask = N->getMask();
   SDValue Index = N->getIndex();
+  SDValue Scale = N->getScale();
   SDValue Data = N->getValue();
   EVT MemoryVT = N->getMemoryVT();
   unsigned Alignment = N->getOriginalAlignment();
@@ -1976,7 +1979,7 @@ SDValue DAGTypeLegalizer::SplitVecOp_MSC
                          MachineMemOperand::MOStore, LoMemVT.getStoreSize(),
                          Alignment, N->getAAInfo(), N->getRanges());
 
-  SDValue OpsLo[] = {Ch, DataLo, MaskLo, Ptr, IndexLo};
+  SDValue OpsLo[] = {Ch, DataLo, MaskLo, Ptr, IndexLo, Scale};
   Lo = DAG.getMaskedScatter(DAG.getVTList(MVT::Other), DataLo.getValueType(),
                             DL, OpsLo, MMO);
 
@@ -1988,7 +1991,7 @@ SDValue DAGTypeLegalizer::SplitVecOp_MSC
   // The order of the Scatter operation after split is well defined. The "Hi"
   // part comes after the "Lo". So these two operations should be chained one
   // after another.
-  SDValue OpsHi[] = {Lo, DataHi, MaskHi, Ptr, IndexHi};
+  SDValue OpsHi[] = {Lo, DataHi, MaskHi, Ptr, IndexHi, Scale};
   return DAG.getMaskedScatter(DAG.getVTList(MVT::Other), DataHi.getValueType(),
                               DL, OpsHi, MMO);
 }
@@ -2954,6 +2957,7 @@ SDValue DAGTypeLegalizer::WidenVecRes_MG
   SDValue Mask = N->getMask();
   EVT MaskVT = Mask.getValueType();
   SDValue Src0 = GetWidenedVector(N->getValue());
+  SDValue Scale = N->getScale();
   unsigned NumElts = WideVT.getVectorNumElements();
   SDLoc dl(N);
 
@@ -2969,7 +2973,7 @@ SDValue DAGTypeLegalizer::WidenVecRes_MG
                                      Index.getValueType().getScalarType(),
                                      NumElts);
   Index = ModifyToType(Index, WideIndexVT);
-  SDValue Ops[] = { N->getChain(), Src0, Mask, N->getBasePtr(), Index };
+  SDValue Ops[] = { N->getChain(), Src0, Mask, N->getBasePtr(), Index, Scale };
   SDValue Res = DAG.getMaskedGather(DAG.getVTList(WideVT, MVT::Other),
                                     N->getMemoryVT(), dl, Ops,
                                     N->getMemOperand());
@@ -3593,6 +3597,7 @@ SDValue DAGTypeLegalizer::WidenVecOp_MSC
   SDValue DataOp = MSC->getValue();
   SDValue Mask = MSC->getMask();
   EVT MaskVT = Mask.getValueType();
+  SDValue Scale = MSC->getScale();
 
   // Widen the value.
   SDValue WideVal = GetWidenedVector(DataOp);
@@ -3612,7 +3617,8 @@ SDValue DAGTypeLegalizer::WidenVecOp_MSC
                                      NumElts);
   Index = ModifyToType(Index, WideIndexVT);
 
-  SDValue Ops[] = {MSC->getChain(), WideVal, Mask, MSC->getBasePtr(), Index};
+  SDValue Ops[] = {MSC->getChain(), WideVal, Mask, MSC->getBasePtr(), Index,
+                   Scale};
   return DAG.getMaskedScatter(DAG.getVTList(MVT::Other),
                               MSC->getMemoryVT(), dl, Ops,
                               MSC->getMemOperand());

Modified: llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAG.cpp?rev=322210&r1=322209&r2=322210&view=diff
==============================================================================
--- llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAG.cpp (original)
+++ llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAG.cpp Wed Jan 10 11:16:05 2018
@@ -6208,7 +6208,7 @@ SDValue SelectionDAG::getMaskedStore(SDV
 SDValue SelectionDAG::getMaskedGather(SDVTList VTs, EVT VT, const SDLoc &dl,
                                       ArrayRef<SDValue> Ops,
                                       MachineMemOperand *MMO) {
-  assert(Ops.size() == 5 && "Incompatible number of operands");
+  assert(Ops.size() == 6 && "Incompatible number of operands");
 
   FoldingSetNodeID ID;
   AddNodeIDNode(ID, ISD::MGATHER, VTs, Ops);
@@ -6234,6 +6234,9 @@ SDValue SelectionDAG::getMaskedGather(SD
   assert(N->getIndex().getValueType().getVectorNumElements() ==
              N->getValueType(0).getVectorNumElements() &&
          "Vector width mismatch between index and data");
+  assert(isa<ConstantSDNode>(N->getScale()) &&
+         cast<ConstantSDNode>(N->getScale())->getAPIntValue().isPowerOf2() &&
+         "Scale should be a constant power of 2");
 
   CSEMap.InsertNode(N, IP);
   InsertNode(N);
@@ -6245,7 +6248,7 @@ SDValue SelectionDAG::getMaskedGather(SD
 SDValue SelectionDAG::getMaskedScatter(SDVTList VTs, EVT VT, const SDLoc &dl,
                                        ArrayRef<SDValue> Ops,
                                        MachineMemOperand *MMO) {
-  assert(Ops.size() == 5 && "Incompatible number of operands");
+  assert(Ops.size() == 6 && "Incompatible number of operands");
 
   FoldingSetNodeID ID;
   AddNodeIDNode(ID, ISD::MSCATTER, VTs, Ops);
@@ -6268,6 +6271,9 @@ SDValue SelectionDAG::getMaskedScatter(S
   assert(N->getIndex().getValueType().getVectorNumElements() ==
              N->getValue().getValueType().getVectorNumElements() &&
          "Vector width mismatch between index and data");
+  assert(isa<ConstantSDNode>(N->getScale()) &&
+         cast<ConstantSDNode>(N->getScale())->getAPIntValue().isPowerOf2() &&
+         "Scale should be a constant power of 2");
 
   CSEMap.InsertNode(N, IP);
   InsertNode(N);

Modified: llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp?rev=322210&r1=322209&r2=322210&view=diff
==============================================================================
--- llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp (original)
+++ llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp Wed Jan 10 11:16:05 2018
@@ -3867,7 +3867,7 @@ void SelectionDAGBuilder::visitMaskedSto
 // extract the splat value and use it as a uniform base.
 // In all other cases the function returns 'false'.
 static bool getUniformBase(const Value* &Ptr, SDValue& Base, SDValue& Index,
-                           SelectionDAGBuilder* SDB) {
+                           SDValue &Scale, SelectionDAGBuilder* SDB) {
   SelectionDAG& DAG = SDB->DAG;
   LLVMContext &Context = *DAG.getContext();
 
@@ -3897,6 +3897,10 @@ static bool getUniformBase(const Value*
   if (!SDB->findValue(Ptr) || !SDB->findValue(IndexVal))
     return false;
 
+  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+  const DataLayout &DL = DAG.getDataLayout();
+  Scale = DAG.getTargetConstant(DL.getTypeAllocSize(GEP->getResultElementType()),
+                                SDB->getCurSDLoc(), TLI.getPointerTy(DL));
   Base = SDB->getValue(Ptr);
   Index = SDB->getValue(IndexVal);
 
@@ -3926,8 +3930,9 @@ void SelectionDAGBuilder::visitMaskedSca
 
   SDValue Base;
   SDValue Index;
+  SDValue Scale;
   const Value *BasePtr = Ptr;
-  bool UniformBase = getUniformBase(BasePtr, Base, Index, this);
+  bool UniformBase = getUniformBase(BasePtr, Base, Index, Scale, this);
 
   const Value *MemOpBasePtr = UniformBase ? BasePtr : nullptr;
   MachineMemOperand *MMO = DAG.getMachineFunction().
@@ -3935,10 +3940,11 @@ void SelectionDAGBuilder::visitMaskedSca
                          MachineMemOperand::MOStore,  VT.getStoreSize(),
                          Alignment, AAInfo);
   if (!UniformBase) {
-    Base = DAG.getTargetConstant(0, sdl, TLI.getPointerTy(DAG.getDataLayout()));
+    Base = DAG.getConstant(0, sdl, TLI.getPointerTy(DAG.getDataLayout()));
     Index = getValue(Ptr);
+    Scale = DAG.getTargetConstant(1, sdl, TLI.getPointerTy(DAG.getDataLayout()));
   }
-  SDValue Ops[] = { getRoot(), Src0, Mask, Base, Index };
+  SDValue Ops[] = { getRoot(), Src0, Mask, Base, Index, Scale };
   SDValue Scatter = DAG.getMaskedScatter(DAG.getVTList(MVT::Other), VT, sdl,
                                          Ops, MMO);
   DAG.setRoot(Scatter);
@@ -4025,8 +4031,9 @@ void SelectionDAGBuilder::visitMaskedGat
   SDValue Root = DAG.getRoot();
   SDValue Base;
   SDValue Index;
+  SDValue Scale;
   const Value *BasePtr = Ptr;
-  bool UniformBase = getUniformBase(BasePtr, Base, Index, this);
+  bool UniformBase = getUniformBase(BasePtr, Base, Index, Scale, this);
   bool ConstantMemory = false;
   if (UniformBase &&
       AA && AA->pointsToConstantMemory(MemoryLocation(
@@ -4044,10 +4051,11 @@ void SelectionDAGBuilder::visitMaskedGat
                          Alignment, AAInfo, Ranges);
 
   if (!UniformBase) {
-    Base = DAG.getTargetConstant(0, sdl, TLI.getPointerTy(DAG.getDataLayout()));
+    Base = DAG.getConstant(0, sdl, TLI.getPointerTy(DAG.getDataLayout()));
     Index = getValue(Ptr);
+    Scale = DAG.getTargetConstant(1, sdl, TLI.getPointerTy(DAG.getDataLayout()));
   }
-  SDValue Ops[] = { Root, Src0, Mask, Base, Index };
+  SDValue Ops[] = { Root, Src0, Mask, Base, Index, Scale };
   SDValue Gather = DAG.getMaskedGather(DAG.getVTList(VT, MVT::Other), VT, sdl,
                                        Ops, MMO);
 

Modified: llvm/trunk/lib/Target/X86/X86ISelDAGToDAG.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Target/X86/X86ISelDAGToDAG.cpp?rev=322210&r1=322209&r2=322210&view=diff
==============================================================================
--- llvm/trunk/lib/Target/X86/X86ISelDAGToDAG.cpp (original)
+++ llvm/trunk/lib/Target/X86/X86ISelDAGToDAG.cpp Wed Jan 10 11:16:05 2018
@@ -1508,6 +1508,12 @@ bool X86DAGToDAGISel::matchAddressBase(S
 bool X86DAGToDAGISel::matchVectorAddress(SDValue N, X86ISelAddressMode &AM) {
   // TODO: Support other operations.
   switch (N.getOpcode()) {
+  case ISD::Constant: {
+    uint64_t Val = cast<ConstantSDNode>(N)->getSExtValue();
+    if (!foldOffsetIntoAddress(Val, AM))
+      return false;
+    break;
+  }
   case X86ISD::Wrapper:
     if (!matchWrapper(N, AM))
       return false;
@@ -1523,7 +1529,7 @@ bool X86DAGToDAGISel::selectVectorAddr(S
   X86ISelAddressMode AM;
   auto *Mgs = cast<X86MaskedGatherScatterSDNode>(Parent);
   AM.IndexReg = Mgs->getIndex();
-  AM.Scale = Mgs->getValue().getScalarValueSizeInBits() / 8;
+  AM.Scale = cast<ConstantSDNode>(Mgs->getScale())->getZExtValue();
 
   unsigned AddrSpace = cast<MemSDNode>(Parent)->getPointerInfo().getAddrSpace();
   // AddrSpace 256 -> GS, 257 -> FS, 258 -> SS.
@@ -1534,14 +1540,8 @@ bool X86DAGToDAGISel::selectVectorAddr(S
   if (AddrSpace == 258)
     AM.Segment = CurDAG->getRegister(X86::SS, MVT::i16);
 
-  // If Base is 0, the whole address is in index and the Scale is 1
-  if (isa<ConstantSDNode>(N)) {
-    assert(cast<ConstantSDNode>(N)->isNullValue() &&
-           "Unexpected base in gather/scatter");
-    AM.Scale = 1;
-  }
-  // Otherwise, try to match into the base and displacement fields.
-  else if (matchVectorAddress(N, AM))
+  // Try to match into the base and displacement fields.
+  if (matchVectorAddress(N, AM))
     return false;
 
   MVT VT = N.getSimpleValueType();

Modified: llvm/trunk/lib/Target/X86/X86ISelLowering.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Target/X86/X86ISelLowering.cpp?rev=322210&r1=322209&r2=322210&view=diff
==============================================================================
--- llvm/trunk/lib/Target/X86/X86ISelLowering.cpp (original)
+++ llvm/trunk/lib/Target/X86/X86ISelLowering.cpp Wed Jan 10 11:16:05 2018
@@ -24317,6 +24317,7 @@ static SDValue LowerMSCATTER(SDValue Op,
   assert(VT.getScalarSizeInBits() >= 32 && "Unsupported scatter op");
   SDLoc dl(Op);
 
+  SDValue Scale = N->getScale();
   SDValue Index = N->getIndex();
   SDValue Mask = N->getMask();
   SDValue Chain = N->getChain();
@@ -24383,7 +24384,7 @@ static SDValue LowerMSCATTER(SDValue Op,
 
   // The mask is killed by scatter, add it to the values
   SDVTList VTs = DAG.getVTList(Mask.getValueType(), MVT::Other);
-  SDValue Ops[] = {Chain, Src, Mask, BasePtr, Index};
+  SDValue Ops[] = {Chain, Src, Mask, BasePtr, Index, Scale};
   SDValue NewScatter = DAG.getTargetMemSDNode<X86MaskedScatterSDNode>(
       VTs, Ops, dl, N->getMemoryVT(), N->getMemOperand());
   DAG.ReplaceAllUsesWith(Op, SDValue(NewScatter.getNode(), 1));
@@ -24489,6 +24490,7 @@ static SDValue LowerMGATHER(SDValue Op,
   MaskedGatherSDNode *N = cast<MaskedGatherSDNode>(Op.getNode());
   SDLoc dl(Op);
   MVT VT = Op.getSimpleValueType();
+  SDValue Scale = N->getScale();
   SDValue Index = N->getIndex();
   SDValue Mask = N->getMask();
   SDValue Src0 = N->getValue();
@@ -24509,7 +24511,8 @@ static SDValue LowerMGATHER(SDValue Op,
     // the vector contains 8 elements, we just sign-extend the index
     if (NumElts == 8) {
       Index = DAG.getNode(ISD::SIGN_EXTEND, dl, MVT::v8i64, Index);
-      SDValue Ops[] = { N->getChain(), Src0, Mask, N->getBasePtr(), Index };
+      SDValue Ops[] = { N->getChain(), Src0, Mask, N->getBasePtr(), Index,
+                        Scale };
       SDValue NewGather = DAG.getTargetMemSDNode<X86MaskedGatherSDNode>(
           DAG.getVTList(VT, MaskVT, MVT::Other), Ops, dl, N->getMemoryVT(),
           N->getMemOperand());
@@ -24533,7 +24536,7 @@ static SDValue LowerMGATHER(SDValue Op,
     MVT NewVT = MVT::getVectorVT(VT.getScalarType(), NumElts);
     Src0 = ExtendToType(Src0, NewVT, DAG);
 
-    SDValue Ops[] = { N->getChain(), Src0, Mask, N->getBasePtr(), Index };
+    SDValue Ops[] = { N->getChain(), Src0, Mask, N->getBasePtr(), Index, Scale };
     SDValue NewGather = DAG.getTargetMemSDNode<X86MaskedGatherSDNode>(
         DAG.getVTList(NewVT, MaskVT, MVT::Other), Ops, dl, N->getMemoryVT(),
         N->getMemOperand());
@@ -24544,7 +24547,7 @@ static SDValue LowerMGATHER(SDValue Op,
     return DAG.getMergeValues(RetOps, dl);
   }
 
-  SDValue Ops[] = { N->getChain(), Src0, Mask, N->getBasePtr(), Index };
+  SDValue Ops[] = { N->getChain(), Src0, Mask, N->getBasePtr(), Index, Scale };
   SDValue NewGather = DAG.getTargetMemSDNode<X86MaskedGatherSDNode>(
       DAG.getVTList(VT, MaskVT, MVT::Other), Ops, dl, N->getMemoryVT(),
       N->getMemOperand());
@@ -25080,7 +25083,7 @@ void X86TargetLowering::ReplaceNodeResul
         Mask = DAG.getNode(ISD::SIGN_EXTEND, dl, MVT::v4i32, Mask);
       }
       SDValue Ops[] = { Gather->getChain(), Src0, Mask, Gather->getBasePtr(),
-                        Index };
+                        Index, Gather->getScale() };
       SDValue Res = DAG.getTargetMemSDNode<X86MaskedGatherSDNode>(
         DAG.getVTList(MVT::v4f32, Mask.getValueType(), MVT::Other), Ops, dl,
         Gather->getMemoryVT(), Gather->getMemOperand());
@@ -25107,7 +25110,7 @@ void X86TargetLowering::ReplaceNodeResul
           Mask = DAG.getNode(ISD::SIGN_EXTEND, dl, MVT::v4i32, Mask);
         }
         SDValue Ops[] = { Gather->getChain(), Src0, Mask, Gather->getBasePtr(),
-                          Index };
+                          Index, Gather->getScale() };
         SDValue Res = DAG.getTargetMemSDNode<X86MaskedGatherSDNode>(
           DAG.getVTList(MVT::v4i32, Mask.getValueType(), MVT::Other), Ops, dl,
           Gather->getMemoryVT(), Gather->getMemOperand());
@@ -25128,7 +25131,7 @@ void X86TargetLowering::ReplaceNodeResul
       Mask = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v4i1, Mask,
                          DAG.getConstant(0, dl, MVT::v2i1));
       SDValue Ops[] = { Gather->getChain(), Src0, Mask, Gather->getBasePtr(),
-                        Index };
+                        Index, Gather->getScale() };
       SDValue Res = DAG.getMaskedGather(DAG.getVTList(MVT::v4i32, MVT::Other),
                                         Gather->getMemoryVT(), dl, Ops,
                                         Gather->getMemOperand());

Modified: llvm/trunk/lib/Target/X86/X86ISelLowering.h
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Target/X86/X86ISelLowering.h?rev=322210&r1=322209&r2=322210&view=diff
==============================================================================
--- llvm/trunk/lib/Target/X86/X86ISelLowering.h (original)
+++ llvm/trunk/lib/Target/X86/X86ISelLowering.h Wed Jan 10 11:16:05 2018
@@ -1442,6 +1442,7 @@ namespace llvm {
     const SDValue &getIndex()   const { return getOperand(4); }
     const SDValue &getMask()    const { return getOperand(2); }
     const SDValue &getValue()   const { return getOperand(1); }
+    const SDValue &getScale()   const { return getOperand(5); }
 
     static bool classof(const SDNode *N) {
       return N->getOpcode() == X86ISD::MGATHER ||

Modified: llvm/trunk/test/CodeGen/X86/masked_gather_scatter.ll
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/CodeGen/X86/masked_gather_scatter.ll?rev=322210&r1=322209&r2=322210&view=diff
==============================================================================
--- llvm/trunk/test/CodeGen/X86/masked_gather_scatter.ll (original)
+++ llvm/trunk/test/CodeGen/X86/masked_gather_scatter.ll Wed Jan 10 11:16:05 2018
@@ -2782,3 +2782,163 @@ define <16 x float> @zext_index(float* %
   %res = call <16 x float> @llvm.masked.gather.v16f32.v16p0f32(<16 x float*> %gep.random, i32 4, <16 x i1> <i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true>, <16 x float> undef)
   ret <16 x float>%res
 }
+
+define <16 x double> @test_gather_setcc_split(double* %base, <16 x i32> %ind, <16 x i32> %cmp, <16 x double> %passthru) {
+; KNL_64-LABEL: test_gather_setcc_split:
+; KNL_64:       # %bb.0:
+; KNL_64-NEXT:    vextractf64x4 $1, %zmm0, %ymm4
+; KNL_64-NEXT:    vpxor %xmm5, %xmm5, %xmm5
+; KNL_64-NEXT:    vextracti64x4 $1, %zmm1, %ymm6
+; KNL_64-NEXT:    vpcmpeqd %zmm5, %zmm6, %k1
+; KNL_64-NEXT:    vpcmpeqd %zmm5, %zmm1, %k2
+; KNL_64-NEXT:    vgatherdpd (%rdi,%ymm0,8), %zmm2 {%k2}
+; KNL_64-NEXT:    vgatherdpd (%rdi,%ymm4,8), %zmm3 {%k1}
+; KNL_64-NEXT:    vmovapd %zmm2, %zmm0
+; KNL_64-NEXT:    vmovapd %zmm3, %zmm1
+; KNL_64-NEXT:    retq
+;
+; KNL_32-LABEL: test_gather_setcc_split:
+; KNL_32:       # %bb.0:
+; KNL_32-NEXT:    pushl %ebp
+; KNL_32-NEXT:    .cfi_def_cfa_offset 8
+; KNL_32-NEXT:    .cfi_offset %ebp, -8
+; KNL_32-NEXT:    movl %esp, %ebp
+; KNL_32-NEXT:    .cfi_def_cfa_register %ebp
+; KNL_32-NEXT:    andl $-64, %esp
+; KNL_32-NEXT:    subl $64, %esp
+; KNL_32-NEXT:    vmovapd 72(%ebp), %zmm3
+; KNL_32-NEXT:    movl 8(%ebp), %eax
+; KNL_32-NEXT:    vextractf64x4 $1, %zmm0, %ymm4
+; KNL_32-NEXT:    vpxor %xmm5, %xmm5, %xmm5
+; KNL_32-NEXT:    vextracti64x4 $1, %zmm1, %ymm6
+; KNL_32-NEXT:    vpcmpeqd %zmm5, %zmm6, %k1
+; KNL_32-NEXT:    vpcmpeqd %zmm5, %zmm1, %k2
+; KNL_32-NEXT:    vgatherdpd (%eax,%ymm0,8), %zmm2 {%k2}
+; KNL_32-NEXT:    vgatherdpd (%eax,%ymm4,8), %zmm3 {%k1}
+; KNL_32-NEXT:    vmovapd %zmm2, %zmm0
+; KNL_32-NEXT:    vmovapd %zmm3, %zmm1
+; KNL_32-NEXT:    movl %ebp, %esp
+; KNL_32-NEXT:    popl %ebp
+; KNL_32-NEXT:    retl
+;
+; SKX-LABEL: test_gather_setcc_split:
+; SKX:       # %bb.0:
+; SKX-NEXT:    vextractf64x4 $1, %zmm0, %ymm4
+; SKX-NEXT:    vextracti64x4 $1, %zmm1, %ymm5
+; SKX-NEXT:    vpxor %xmm6, %xmm6, %xmm6
+; SKX-NEXT:    vpcmpeqd %ymm6, %ymm5, %k1
+; SKX-NEXT:    vpcmpeqd %ymm6, %ymm1, %k2
+; SKX-NEXT:    vgatherdpd (%rdi,%ymm0,8), %zmm2 {%k2}
+; SKX-NEXT:    vgatherdpd (%rdi,%ymm4,8), %zmm3 {%k1}
+; SKX-NEXT:    vmovapd %zmm2, %zmm0
+; SKX-NEXT:    vmovapd %zmm3, %zmm1
+; SKX-NEXT:    retq
+;
+; SKX_32-LABEL: test_gather_setcc_split:
+; SKX_32:       # %bb.0:
+; SKX_32-NEXT:    pushl %ebp
+; SKX_32-NEXT:    .cfi_def_cfa_offset 8
+; SKX_32-NEXT:    .cfi_offset %ebp, -8
+; SKX_32-NEXT:    movl %esp, %ebp
+; SKX_32-NEXT:    .cfi_def_cfa_register %ebp
+; SKX_32-NEXT:    andl $-64, %esp
+; SKX_32-NEXT:    subl $64, %esp
+; SKX_32-NEXT:    vmovapd 72(%ebp), %zmm3
+; SKX_32-NEXT:    movl 8(%ebp), %eax
+; SKX_32-NEXT:    vextractf64x4 $1, %zmm0, %ymm4
+; SKX_32-NEXT:    vextracti64x4 $1, %zmm1, %ymm5
+; SKX_32-NEXT:    vpxor %xmm6, %xmm6, %xmm6
+; SKX_32-NEXT:    vpcmpeqd %ymm6, %ymm5, %k1
+; SKX_32-NEXT:    vpcmpeqd %ymm6, %ymm1, %k2
+; SKX_32-NEXT:    vgatherdpd (%eax,%ymm0,8), %zmm2 {%k2}
+; SKX_32-NEXT:    vgatherdpd (%eax,%ymm4,8), %zmm3 {%k1}
+; SKX_32-NEXT:    vmovapd %zmm2, %zmm0
+; SKX_32-NEXT:    vmovapd %zmm3, %zmm1
+; SKX_32-NEXT:    movl %ebp, %esp
+; SKX_32-NEXT:    popl %ebp
+; SKX_32-NEXT:    retl
+  %sext_ind = sext <16 x i32> %ind to <16 x i64>
+  %gep.random = getelementptr double, double *%base, <16 x i64> %sext_ind
+
+  %mask = icmp eq <16 x i32> %cmp, zeroinitializer
+  %res = call <16 x double> @llvm.masked.gather.v16f64.v16p0f64(<16 x double*> %gep.random, i32 4, <16 x i1> %mask, <16 x double> %passthru)
+  ret <16 x double>%res
+}
+
+define void @test_scatter_setcc_split(double* %base, <16 x i32> %ind, <16 x i32> %cmp, <16 x double> %src0)  {
+; KNL_64-LABEL: test_scatter_setcc_split:
+; KNL_64:       # %bb.0:
+; KNL_64-NEXT:    vextractf64x4 $1, %zmm0, %ymm4
+; KNL_64-NEXT:    vpxor %xmm5, %xmm5, %xmm5
+; KNL_64-NEXT:    vpcmpeqd %zmm5, %zmm1, %k1
+; KNL_64-NEXT:    vextracti64x4 $1, %zmm1, %ymm1
+; KNL_64-NEXT:    vpcmpeqd %zmm5, %zmm1, %k2
+; KNL_64-NEXT:    vscatterdpd %zmm3, (%rdi,%ymm4,8) {%k2}
+; KNL_64-NEXT:    vscatterdpd %zmm2, (%rdi,%ymm0,8) {%k1}
+; KNL_64-NEXT:    vzeroupper
+; KNL_64-NEXT:    retq
+;
+; KNL_32-LABEL: test_scatter_setcc_split:
+; KNL_32:       # %bb.0:
+; KNL_32-NEXT:    pushl %ebp
+; KNL_32-NEXT:    .cfi_def_cfa_offset 8
+; KNL_32-NEXT:    .cfi_offset %ebp, -8
+; KNL_32-NEXT:    movl %esp, %ebp
+; KNL_32-NEXT:    .cfi_def_cfa_register %ebp
+; KNL_32-NEXT:    andl $-64, %esp
+; KNL_32-NEXT:    subl $64, %esp
+; KNL_32-NEXT:    vmovapd 72(%ebp), %zmm3
+; KNL_32-NEXT:    movl 8(%ebp), %eax
+; KNL_32-NEXT:    vextractf64x4 $1, %zmm0, %ymm4
+; KNL_32-NEXT:    vpxor %xmm5, %xmm5, %xmm5
+; KNL_32-NEXT:    vpcmpeqd %zmm5, %zmm1, %k1
+; KNL_32-NEXT:    vextracti64x4 $1, %zmm1, %ymm1
+; KNL_32-NEXT:    vpcmpeqd %zmm5, %zmm1, %k2
+; KNL_32-NEXT:    vscatterdpd %zmm3, (%eax,%ymm4,8) {%k2}
+; KNL_32-NEXT:    vscatterdpd %zmm2, (%eax,%ymm0,8) {%k1}
+; KNL_32-NEXT:    movl %ebp, %esp
+; KNL_32-NEXT:    popl %ebp
+; KNL_32-NEXT:    vzeroupper
+; KNL_32-NEXT:    retl
+;
+; SKX-LABEL: test_scatter_setcc_split:
+; SKX:       # %bb.0:
+; SKX-NEXT:    vextractf64x4 $1, %zmm0, %ymm4
+; SKX-NEXT:    vpxor %xmm5, %xmm5, %xmm5
+; SKX-NEXT:    vpcmpeqd %ymm5, %ymm1, %k1
+; SKX-NEXT:    vextracti64x4 $1, %zmm1, %ymm1
+; SKX-NEXT:    vpcmpeqd %ymm5, %ymm1, %k2
+; SKX-NEXT:    vscatterdpd %zmm3, (%rdi,%ymm4,8) {%k2}
+; SKX-NEXT:    vscatterdpd %zmm2, (%rdi,%ymm0,8) {%k1}
+; SKX-NEXT:    vzeroupper
+; SKX-NEXT:    retq
+;
+; SKX_32-LABEL: test_scatter_setcc_split:
+; SKX_32:       # %bb.0:
+; SKX_32-NEXT:    pushl %ebp
+; SKX_32-NEXT:    .cfi_def_cfa_offset 8
+; SKX_32-NEXT:    .cfi_offset %ebp, -8
+; SKX_32-NEXT:    movl %esp, %ebp
+; SKX_32-NEXT:    .cfi_def_cfa_register %ebp
+; SKX_32-NEXT:    andl $-64, %esp
+; SKX_32-NEXT:    subl $64, %esp
+; SKX_32-NEXT:    vmovapd 72(%ebp), %zmm3
+; SKX_32-NEXT:    movl 8(%ebp), %eax
+; SKX_32-NEXT:    vextractf64x4 $1, %zmm0, %ymm4
+; SKX_32-NEXT:    vpxor %xmm5, %xmm5, %xmm5
+; SKX_32-NEXT:    vpcmpeqd %ymm5, %ymm1, %k1
+; SKX_32-NEXT:    vextracti64x4 $1, %zmm1, %ymm1
+; SKX_32-NEXT:    vpcmpeqd %ymm5, %ymm1, %k2
+; SKX_32-NEXT:    vscatterdpd %zmm3, (%eax,%ymm4,8) {%k2}
+; SKX_32-NEXT:    vscatterdpd %zmm2, (%eax,%ymm0,8) {%k1}
+; SKX_32-NEXT:    movl %ebp, %esp
+; SKX_32-NEXT:    popl %ebp
+; SKX_32-NEXT:    vzeroupper
+; SKX_32-NEXT:    retl
+  %sext_ind = sext <16 x i32> %ind to <16 x i64>
+  %gep.random = getelementptr double, double *%base, <16 x i64> %sext_ind
+
+  %mask = icmp eq <16 x i32> %cmp, zeroinitializer
+  call void @llvm.masked.scatter.v16f64.v16p0f64(<16 x double> %src0, <16 x double*> %gep.random, i32 4, <16 x i1> %mask)
+  ret void
+}




More information about the llvm-commits mailing list