[llvm] 9238dfb - [X86] Remove mask output from X86 gather/scatter ISD opcodes.

Craig Topper via llvm-commits llvm-commits at lists.llvm.org
Mon Feb 24 23:56:45 PST 2020


Author: Craig Topper
Date: 2020-02-24T23:56:28-08:00
New Revision: 9238dfb4d80dc93a3187a2d3b030c6f7867afd50

URL: https://github.com/llvm/llvm-project/commit/9238dfb4d80dc93a3187a2d3b030c6f7867afd50
DIFF: https://github.com/llvm/llvm-project/commit/9238dfb4d80dc93a3187a2d3b030c6f7867afd50.diff

LOG: [X86] Remove mask output from X86 gather/scatter ISD opcodes.

Instead, add it when we create the machine nodes during instruction
selection.

This makes this ISD node closer to ISD::MGATHER. Trying to see
if we can remove the X86-specific ones.

Added: 
    

Modified: 
    llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
    llvm/lib/Target/X86/X86ISelLowering.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp b/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
index 3c99423fa637..76a903278960 100644
--- a/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
+++ b/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
@@ -5474,7 +5474,8 @@ void X86DAGToDAGISel::Select(SDNode *Node) {
 
     SDValue PassThru = Mgt->getPassThru();
     SDValue Chain = Mgt->getChain();
-    SDVTList VTs = Mgt->getVTList();
+    // Gather instructions have a mask output not in the ISD node.
+    SDVTList VTs = CurDAG->getVTList(ValueVT, MaskVT, MVT::Other);
 
     MachineSDNode *NewNode;
     if (AVX512Gather) {
@@ -5487,7 +5488,9 @@ void X86DAGToDAGISel::Select(SDNode *Node) {
       NewNode = CurDAG->getMachineNode(Opc, SDLoc(dl), VTs, Ops);
     }
     CurDAG->setNodeMemRefs(NewNode, {Mgt->getMemOperand()});
-    ReplaceNode(Node, NewNode);
+    ReplaceUses(SDValue(Node, 0), SDValue(NewNode, 0));
+    ReplaceUses(SDValue(Node, 1), SDValue(NewNode, 2));
+    CurDAG->RemoveDeadNode(Node);
     return;
   }
   case X86ISD::MSCATTER: {
@@ -5544,12 +5547,14 @@ void X86DAGToDAGISel::Select(SDNode *Node) {
 
     SDValue Mask = Sc->getMask();
     SDValue Chain = Sc->getChain();
-    SDVTList VTs = Sc->getVTList();
+    // Scatter instructions have a mask output not in the ISD node.
+    SDVTList VTs = CurDAG->getVTList(Mask.getValueType(), MVT::Other);
     SDValue Ops[] = {Base, Scale, Index, Disp, Segment, Mask, Value, Chain};
 
     MachineSDNode *NewNode = CurDAG->getMachineNode(Opc, SDLoc(dl), VTs, Ops);
     CurDAG->setNodeMemRefs(NewNode, {Sc->getMemOperand()});
-    ReplaceNode(Node, NewNode);
+    ReplaceUses(SDValue(Node, 0), SDValue(NewNode, 1));
+    CurDAG->RemoveDeadNode(Node);
     return;
   }
   }

diff  --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index bd9e7f075918..be06e8eb7b63 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -24771,7 +24771,7 @@ static SDValue getAVX2GatherNode(unsigned Opc, SDValue Op, SelectionDAG &DAG,
   SDValue Scale = DAG.getTargetConstant(C->getZExtValue(), dl,
                                         TLI.getPointerTy(DAG.getDataLayout()));
   EVT MaskVT = Mask.getValueType().changeVectorElementTypeToInteger();
-  SDVTList VTs = DAG.getVTList(Op.getValueType(), MaskVT, MVT::Other);
+  SDVTList VTs = DAG.getVTList(Op.getValueType(), MVT::Other);
   // If source is undef or we know it won't be used, use a zero vector
   // to break register dependency.
   // TODO: use undef instead and let BreakFalseDeps deal with it?
@@ -24787,7 +24787,7 @@ static SDValue getAVX2GatherNode(unsigned Opc, SDValue Op, SelectionDAG &DAG,
   SDValue Res =
       DAG.getMemIntrinsicNode(X86ISD::MGATHER, dl, VTs, Ops,
                               MemIntr->getMemoryVT(), MemIntr->getMemOperand());
-  return DAG.getMergeValues({ Res, Res.getValue(2) }, dl);
+  return DAG.getMergeValues({Res, Res.getValue(1)}, dl);
 }
 
 static SDValue getGatherNode(SDValue Op, SelectionDAG &DAG,
@@ -24812,7 +24812,7 @@ static SDValue getGatherNode(SDValue Op, SelectionDAG &DAG,
   if (Mask.getValueType() != MaskVT)
     Mask = getMaskNode(Mask, MaskVT, Subtarget, DAG, dl);
 
-  SDVTList VTs = DAG.getVTList(Op.getValueType(), MaskVT, MVT::Other);
+  SDVTList VTs = DAG.getVTList(Op.getValueType(), MVT::Other);
   // If source is undef or we know it won't be used, use a zero vector
   // to break register dependency.
   // TODO: use undef instead and let BreakFalseDeps deal with it?
@@ -24825,7 +24825,7 @@ static SDValue getGatherNode(SDValue Op, SelectionDAG &DAG,
   SDValue Res =
       DAG.getMemIntrinsicNode(X86ISD::MGATHER, dl, VTs, Ops,
                               MemIntr->getMemoryVT(), MemIntr->getMemOperand());
-  return DAG.getMergeValues({ Res, Res.getValue(2) }, dl);
+  return DAG.getMergeValues({Res, Res.getValue(1)}, dl);
 }
 
 static SDValue getScatterNode(unsigned Opc, SDValue Op, SelectionDAG &DAG,
@@ -24851,12 +24851,12 @@ static SDValue getScatterNode(unsigned Opc, SDValue Op, SelectionDAG &DAG,
 
   MemIntrinsicSDNode *MemIntr = cast<MemIntrinsicSDNode>(Op);
 
-  SDVTList VTs = DAG.getVTList(MaskVT, MVT::Other);
+  SDVTList VTs = DAG.getVTList(MVT::Other);
   SDValue Ops[] = {Chain, Src, Mask, Base, Index, Scale};
   SDValue Res =
       DAG.getMemIntrinsicNode(X86ISD::MSCATTER, dl, VTs, Ops,
                               MemIntr->getMemoryVT(), MemIntr->getMemOperand());
-  return Res.getValue(1);
+  return Res;
 }
 
 static SDValue getPrefetchNode(unsigned Opc, SDValue Op, SelectionDAG &DAG,
@@ -28523,11 +28523,10 @@ static SDValue LowerMSCATTER(SDValue Op, const X86Subtarget &Subtarget,
       const TargetLowering &TLI = DAG.getTargetLoweringInfo();
       EVT WideVT = TLI.getTypeToTransformTo(*DAG.getContext(), VT);
       Src = DAG.getNode(ISD::CONCAT_VECTORS, dl, WideVT, Src, DAG.getUNDEF(VT));
-      SDVTList VTs = DAG.getVTList(MVT::v2i1, MVT::Other);
+      SDVTList VTs = DAG.getVTList(MVT::Other);
       SDValue Ops[] = {Chain, Src, Mask, BasePtr, Index, Scale};
-      SDValue NewScatter = DAG.getMemIntrinsicNode(
-          X86ISD::MSCATTER, dl, VTs, Ops, N->getMemoryVT(), N->getMemOperand());
-      return SDValue(NewScatter.getNode(), 1);
+      return DAG.getMemIntrinsicNode(X86ISD::MSCATTER, dl, VTs, Ops,
+                                     N->getMemoryVT(), N->getMemOperand());
     }
     return SDValue();
   }
@@ -28558,11 +28557,10 @@ static SDValue LowerMSCATTER(SDValue Op, const X86Subtarget &Subtarget,
     Mask = ExtendToType(Mask, MaskVT, DAG, true);
   }
 
-  SDVTList VTs = DAG.getVTList(MaskVT, MVT::Other);
+  SDVTList VTs = DAG.getVTList(MVT::Other);
   SDValue Ops[] = {Chain, Src, Mask, BasePtr, Index, Scale};
-  SDValue NewScatter = DAG.getMemIntrinsicNode(
-      X86ISD::MSCATTER, dl, VTs, Ops, N->getMemoryVT(), N->getMemOperand());
-  return SDValue(NewScatter.getNode(), 1);
+  return DAG.getMemIntrinsicNode(X86ISD::MSCATTER, dl, VTs, Ops,
+                                 N->getMemoryVT(), N->getMemOperand());
 }
 
 static SDValue LowerMLOAD(SDValue Op, const X86Subtarget &Subtarget,
@@ -28717,11 +28715,11 @@ static SDValue LowerMGATHER(SDValue Op, const X86Subtarget &Subtarget,
   SDValue Ops[] = { N->getChain(), PassThru, Mask, N->getBasePtr(), Index,
                     N->getScale() };
   SDValue NewGather = DAG.getMemIntrinsicNode(
-      X86ISD::MGATHER, dl, DAG.getVTList(VT, MaskVT, MVT::Other), Ops,
-      N->getMemoryVT(), N->getMemOperand());
+      X86ISD::MGATHER, dl, DAG.getVTList(VT, MVT::Other), Ops, N->getMemoryVT(),
+      N->getMemOperand());
   SDValue Extract = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, OrigVT,
                                 NewGather, DAG.getIntPtrConstant(0, dl));
-  return DAG.getMergeValues({Extract, NewGather.getValue(2)}, dl);
+  return DAG.getMergeValues({Extract, NewGather.getValue(1)}, dl);
 }
 
 static SDValue LowerADDRSPACECAST(SDValue Op, SelectionDAG &DAG) {
@@ -29833,11 +29831,10 @@ void X86TargetLowering::ReplaceNodeResults(SDNode *N,
       SDValue Ops[] = { Gather->getChain(), PassThru, Mask,
                         Gather->getBasePtr(), Index, Gather->getScale() };
       SDValue Res = DAG.getMemIntrinsicNode(
-          X86ISD::MGATHER, dl,
-          DAG.getVTList(WideVT, Mask.getValueType(), MVT::Other), Ops,
+          X86ISD::MGATHER, dl, DAG.getVTList(WideVT, MVT::Other), Ops,
           Gather->getMemoryVT(), Gather->getMemOperand());
       Results.push_back(Res);
-      Results.push_back(Res.getValue(2));
+      Results.push_back(Res.getValue(1));
       return;
     }
     return;


        


More information about the llvm-commits mailing list