[llvm] r318823 - [X86] Add an X86ISD::MSCATTER node for consistency with the X86ISD::MGATHER.

Craig Topper via llvm-commits llvm-commits at lists.llvm.org
Wed Nov 22 00:10:54 PST 2017


Author: ctopper
Date: Wed Nov 22 00:10:54 2017
New Revision: 318823

URL: http://llvm.org/viewvc/llvm-project?rev=318823&view=rev
Log:
[X86] Add an X86ISD::MSCATTER node for consistency with the X86ISD::MGATHER.

This keeps X86's need for an explicit mask output out of the type constraints of the generic ISD::MSCATTER node.

This also gives the X86ISD::MGATHER/MSCATTER nodes a common base class, simplifying the address selection code in X86ISelDAGToDAG.cpp.

Modified:
    llvm/trunk/lib/Target/X86/X86ISelDAGToDAG.cpp
    llvm/trunk/lib/Target/X86/X86ISelLowering.cpp
    llvm/trunk/lib/Target/X86/X86ISelLowering.h
    llvm/trunk/lib/Target/X86/X86InstrFragmentsSIMD.td

Modified: llvm/trunk/lib/Target/X86/X86ISelDAGToDAG.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Target/X86/X86ISelDAGToDAG.cpp?rev=318823&r1=318822&r2=318823&view=diff
==============================================================================
--- llvm/trunk/lib/Target/X86/X86ISelDAGToDAG.cpp (original)
+++ llvm/trunk/lib/Target/X86/X86ISelDAGToDAG.cpp Wed Nov 22 00:10:54 2017
@@ -1522,14 +1522,9 @@ bool X86DAGToDAGISel::selectVectorAddr(S
                                        SDValue &Scale, SDValue &Index,
                                        SDValue &Disp, SDValue &Segment) {
   X86ISelAddressMode AM;
-  if (auto Mgs = dyn_cast<MaskedGatherScatterSDNode>(Parent)) {
-    AM.IndexReg = Mgs->getIndex();
-    AM.Scale = Mgs->getValue().getScalarValueSizeInBits() / 8;
-  } else {
-    auto X86Gather = cast<X86MaskedGatherSDNode>(Parent);
-    AM.IndexReg = X86Gather->getIndex();
-    AM.Scale = X86Gather->getValue().getScalarValueSizeInBits() / 8;
-  }
+  auto *Mgs = cast<X86MaskedGatherScatterSDNode>(Parent);
+  AM.IndexReg = Mgs->getIndex();
+  AM.Scale = Mgs->getValue().getScalarValueSizeInBits() / 8;
 
   unsigned AddrSpace = cast<MemSDNode>(Parent)->getPointerInfo().getAddrSpace();
   // AddrSpace 256 -> GS, 257 -> FS, 258 -> SS.

Modified: llvm/trunk/lib/Target/X86/X86ISelLowering.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Target/X86/X86ISelLowering.cpp?rev=318823&r1=318822&r2=318823&view=diff
==============================================================================
--- llvm/trunk/lib/Target/X86/X86ISelLowering.cpp (original)
+++ llvm/trunk/lib/Target/X86/X86ISelLowering.cpp Wed Nov 22 00:10:54 2017
@@ -24112,19 +24112,12 @@ static SDValue LowerMSCATTER(SDValue Op,
   assert(Subtarget.hasAVX512() &&
          "MGATHER/MSCATTER are supported on AVX-512 arch only");
 
-  // X86 scatter kills mask register, so its type should be added to
-  // the list of return values.
-  // If the "scatter" has 2 return values, it is already handled.
-  if (Op.getNode()->getNumValues() == 2)
-    return Op;
-
   MaskedScatterSDNode *N = cast<MaskedScatterSDNode>(Op.getNode());
   SDValue Src = N->getValue();
   MVT VT = Src.getSimpleValueType();
   assert(VT.getScalarSizeInBits() >= 32 && "Unsupported scatter op");
   SDLoc dl(Op);
 
-  SDValue NewScatter;
   SDValue Index = N->getIndex();
   SDValue Mask = N->getMask();
   SDValue Chain = N->getChain();
@@ -24195,8 +24188,8 @@ static SDValue LowerMSCATTER(SDValue Op,
   // The mask is killed by scatter, add it to the values
   SDVTList VTs = DAG.getVTList(BitMaskVT, MVT::Other);
   SDValue Ops[] = {Chain, Src, Mask, BasePtr, Index};
-  NewScatter = DAG.getMaskedScatter(VTs, N->getMemoryVT(), dl, Ops,
-                                    N->getMemOperand());
+  SDValue NewScatter = DAG.getTargetMemSDNode<X86MaskedScatterSDNode>(
+      VTs, Ops, dl, N->getMemoryVT(), N->getMemOperand());
   DAG.ReplaceAllUsesWith(Op, SDValue(NewScatter.getNode(), 1));
   return SDValue(NewScatter.getNode(), 1);
 }
@@ -25261,6 +25254,7 @@ const char *X86TargetLowering::getTarget
   case X86ISD::CVTS2UI_RND:        return "X86ISD::CVTS2UI_RND";
   case X86ISD::LWPINS:             return "X86ISD::LWPINS";
   case X86ISD::MGATHER:            return "X86ISD::MGATHER";
+  case X86ISD::MSCATTER:           return "X86ISD::MSCATTER";
   case X86ISD::VPDPBUSD:           return "X86ISD::VPDPBUSD";
   case X86ISD::VPDPBUSDS:          return "X86ISD::VPDPBUSDS";
   case X86ISD::VPDPWSSD:           return "X86ISD::VPDPWSSD";

Modified: llvm/trunk/lib/Target/X86/X86ISelLowering.h
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Target/X86/X86ISelLowering.h?rev=318823&r1=318822&r2=318823&view=diff
==============================================================================
--- llvm/trunk/lib/Target/X86/X86ISelLowering.h (original)
+++ llvm/trunk/lib/Target/X86/X86ISelLowering.h Wed Nov 22 00:10:54 2017
@@ -637,8 +637,8 @@ namespace llvm {
       // Vector truncating masked store with unsigned/signed saturation
       VMTRUNCSTOREUS, VMTRUNCSTORES,
 
-      // X86 specific gather
-      MGATHER
+      // X86 specific gather and scatter
+      MGATHER, MSCATTER,
 
       // WARNING: Do not add anything in the end unless you want the node to
       // have memop! In fact, starting from FIRST_TARGET_MEMORY_OPCODE all
@@ -1423,16 +1423,15 @@ namespace llvm {
     }
   };
 
-  // X86 specific Gather node.
-  // The class has the same order of operands as MaskedGatherSDNode for
+  // X86 specific Gather/Scatter nodes.
+  // The class has the same order of operands as MaskedGatherScatterSDNode for
   // convenience.
-  class X86MaskedGatherSDNode : public MemSDNode {
+  class X86MaskedGatherScatterSDNode : public MemSDNode {
   public:
-    X86MaskedGatherSDNode(unsigned Order,
-                          const DebugLoc &dl, SDVTList VTs, EVT MemVT,
-                          MachineMemOperand *MMO)
-      : MemSDNode(X86ISD::MGATHER, Order, dl, VTs, MemVT, MMO)
-    {}
+    X86MaskedGatherScatterSDNode(unsigned Opc, unsigned Order,
+                                 const DebugLoc &dl, SDVTList VTs, EVT MemVT,
+                                 MachineMemOperand *MMO)
+        : MemSDNode(Opc, Order, dl, VTs, MemVT, MMO) {}
 
     const SDValue &getBasePtr() const { return getOperand(3); }
     const SDValue &getIndex()   const { return getOperand(4); }
@@ -1440,10 +1439,35 @@ namespace llvm {
     const SDValue &getValue()   const { return getOperand(1); }
 
     static bool classof(const SDNode *N) {
+      return N->getOpcode() == X86ISD::MGATHER ||
+             N->getOpcode() == X86ISD::MSCATTER;
+    }
+  };
+
+  class X86MaskedGatherSDNode : public X86MaskedGatherScatterSDNode {
+  public:
+    X86MaskedGatherSDNode(unsigned Order, const DebugLoc &dl, SDVTList VTs,
+                          EVT MemVT, MachineMemOperand *MMO)
+        : X86MaskedGatherScatterSDNode(X86ISD::MGATHER, Order, dl, VTs, MemVT,
+                                       MMO) {}
+
+    static bool classof(const SDNode *N) {
       return N->getOpcode() == X86ISD::MGATHER;
     }
   };
 
+  class X86MaskedScatterSDNode : public X86MaskedGatherScatterSDNode {
+  public:
+    X86MaskedScatterSDNode(unsigned Order, const DebugLoc &dl, SDVTList VTs,
+                           EVT MemVT, MachineMemOperand *MMO)
+        : X86MaskedGatherScatterSDNode(X86ISD::MSCATTER, Order, dl, VTs, MemVT,
+                                       MMO) {}
+
+    static bool classof(const SDNode *N) {
+      return N->getOpcode() == X86ISD::MSCATTER;
+    }
+  };
+
   /// Generate unpacklo/unpackhi shuffle mask.
   template <typename T = int>
   void createUnpackShuffleMask(MVT VT, SmallVectorImpl<T> &Mask, bool Lo,

Modified: llvm/trunk/lib/Target/X86/X86InstrFragmentsSIMD.td
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Target/X86/X86InstrFragmentsSIMD.td?rev=318823&r1=318822&r2=318823&view=diff
==============================================================================
--- llvm/trunk/lib/Target/X86/X86InstrFragmentsSIMD.td (original)
+++ llvm/trunk/lib/Target/X86/X86InstrFragmentsSIMD.td Wed Nov 22 00:10:54 2017
@@ -781,6 +781,13 @@ def X86masked_gather : SDNode<"X86ISD::M
                                                    SDTCisPtrTy<4>]>,
                              [SDNPHasChain, SDNPMayLoad, SDNPMemOperand]>;
 
+def X86masked_scatter : SDNode<"X86ISD::MSCATTER",
+                              SDTypeProfile<1, 3, [SDTCisVec<0>, SDTCisVec<1>,
+                                                   SDTCisSameAs<0, 2>,
+                                                   SDTCVecEltisVT<0, i1>,
+                                                   SDTCisPtrTy<3>]>,
+                             [SDNPHasChain, SDNPMayStore, SDNPMemOperand]>;
+
 def mgatherv4i32 : PatFrag<(ops node:$src1, node:$src2, node:$src3),
   (X86masked_gather node:$src1, node:$src2, node:$src3) , [{
   X86MaskedGatherSDNode *Mgt = cast<X86MaskedGatherSDNode>(N);
@@ -815,37 +822,37 @@ def mgatherv16i32 : PatFrag<(ops node:$s
 }]>;
 
 def mscatterv2i64 : PatFrag<(ops node:$src1, node:$src2, node:$src3),
-  (masked_scatter node:$src1, node:$src2, node:$src3) , [{
-  MaskedScatterSDNode *Sc = cast<MaskedScatterSDNode>(N);
+  (X86masked_scatter node:$src1, node:$src2, node:$src3) , [{
+  X86MaskedScatterSDNode *Sc = cast<X86MaskedScatterSDNode>(N);
   return Sc->getIndex().getValueType() == MVT::v2i64;
 }]>;
 
 def mscatterv4i32 : PatFrag<(ops node:$src1, node:$src2, node:$src3),
-  (masked_scatter node:$src1, node:$src2, node:$src3) , [{
-  MaskedScatterSDNode *Sc = cast<MaskedScatterSDNode>(N);
+  (X86masked_scatter node:$src1, node:$src2, node:$src3) , [{
+  X86MaskedScatterSDNode *Sc = cast<X86MaskedScatterSDNode>(N);
   return Sc->getIndex().getValueType() == MVT::v4i32;
 }]>;
 
 def mscatterv4i64 : PatFrag<(ops node:$src1, node:$src2, node:$src3),
-  (masked_scatter node:$src1, node:$src2, node:$src3) , [{
-  MaskedScatterSDNode *Sc = cast<MaskedScatterSDNode>(N);
+  (X86masked_scatter node:$src1, node:$src2, node:$src3) , [{
+  X86MaskedScatterSDNode *Sc = cast<X86MaskedScatterSDNode>(N);
   return Sc->getIndex().getValueType() == MVT::v4i64;
 }]>;
 
 def mscatterv8i32 : PatFrag<(ops node:$src1, node:$src2, node:$src3),
-  (masked_scatter node:$src1, node:$src2, node:$src3) , [{
-  MaskedScatterSDNode *Sc = cast<MaskedScatterSDNode>(N);
+  (X86masked_scatter node:$src1, node:$src2, node:$src3) , [{
+  X86MaskedScatterSDNode *Sc = cast<X86MaskedScatterSDNode>(N);
   return Sc->getIndex().getValueType() == MVT::v8i32;
 }]>;
 
 def mscatterv8i64 : PatFrag<(ops node:$src1, node:$src2, node:$src3),
-  (masked_scatter node:$src1, node:$src2, node:$src3) , [{
-  MaskedScatterSDNode *Sc = cast<MaskedScatterSDNode>(N);
+  (X86masked_scatter node:$src1, node:$src2, node:$src3) , [{
+  X86MaskedScatterSDNode *Sc = cast<X86MaskedScatterSDNode>(N);
   return Sc->getIndex().getValueType() == MVT::v8i64;
 }]>;
 def mscatterv16i32 : PatFrag<(ops node:$src1, node:$src2, node:$src3),
-  (masked_scatter node:$src1, node:$src2, node:$src3) , [{
-  MaskedScatterSDNode *Sc = cast<MaskedScatterSDNode>(N);
+  (X86masked_scatter node:$src1, node:$src2, node:$src3) , [{
+  X86MaskedScatterSDNode *Sc = cast<X86MaskedScatterSDNode>(N);
   return Sc->getIndex().getValueType() == MVT::v16i32;
 }]>;
 




More information about the llvm-commits mailing list