[llvm] ffbbfc7 - [SVE][CodeGen] Add the isTruncatingStore flag to MSCATTER

Kerry McLaughlin via llvm-commits llvm-commits at lists.llvm.org
Wed Nov 11 03:14:11 PST 2020


Author: Kerry McLaughlin
Date: 2020-11-11T10:58:24Z
New Revision: ffbbfc76ca209934321f6a8b26793119598354fe

URL: https://github.com/llvm/llvm-project/commit/ffbbfc76ca209934321f6a8b26793119598354fe
DIFF: https://github.com/llvm/llvm-project/commit/ffbbfc76ca209934321f6a8b26793119598354fe.diff

LOG: [SVE][CodeGen] Add the isTruncatingStore flag to MSCATTER

This patch adds the IsTruncatingStore flag to MaskedScatterSDNode, set by getMaskedScatter().
It also updates SelectionDAGDumper::print_details for MaskedScatterSDNode to print
the details of masked scatters (whether the store is truncating, and whether the
index is signed or unsigned and scaled or unscaled).

This is the first in a series of patches that add support for scalable masked scatters.

Reviewed By: sdesmalen

Differential Revision: https://reviews.llvm.org/D90939

Added: 
    

Modified: 
    llvm/include/llvm/CodeGen/SelectionDAG.h
    llvm/include/llvm/CodeGen/SelectionDAGNodes.h
    llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
    llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
    llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
    llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
    llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
    llvm/lib/Target/X86/X86ISelLowering.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h
index efbb8bdb5f9b..42fc296a682c 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAG.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAG.h
@@ -1361,7 +1361,8 @@ class SelectionDAG {
                           ISD::MemIndexType IndexType);
   SDValue getMaskedScatter(SDVTList VTs, EVT VT, const SDLoc &dl,
                            ArrayRef<SDValue> Ops, MachineMemOperand *MMO,
-                           ISD::MemIndexType IndexType);
+                           ISD::MemIndexType IndexType,
+                           bool IsTruncating = false);
 
   /// Construct a node to track a Value* through the backend.
   SDValue getSrcValue(const Value *v);

diff  --git a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
index fb81718e119b..34f824366206 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
@@ -523,6 +523,7 @@ BEGIN_TWO_BYTE_PACK()
   class StoreSDNodeBitfields {
     friend class StoreSDNode;
     friend class MaskedStoreSDNode;
+    friend class MaskedScatterSDNode;
 
     uint16_t : NumLSBaseSDNodeBits;
 
@@ -2441,9 +2442,16 @@ class MaskedScatterSDNode : public MaskedGatherScatterSDNode {
 
   MaskedScatterSDNode(unsigned Order, const DebugLoc &dl, SDVTList VTs,
                       EVT MemVT, MachineMemOperand *MMO,
-                      ISD::MemIndexType IndexType)
+                      ISD::MemIndexType IndexType, bool IsTrunc)
       : MaskedGatherScatterSDNode(ISD::MSCATTER, Order, dl, VTs, MemVT, MMO,
-                                  IndexType) {}
+                                  IndexType) {
+    StoreSDNodeBits.IsTruncating = IsTrunc;
+  }
+
+  /// Return true if the op does a truncation before store.
+  /// For integers this is the same as doing a TRUNCATE and storing the result.
+  /// For floats, it is the same as doing an FP_ROUND and storing the result.
+  bool isTruncatingStore() const { return StoreSDNodeBits.IsTruncating; }
 
   const SDValue &getValue() const { return getOperand(1); }
 

diff  --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
index 1c5253369ef9..b4f6d21e44ea 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
@@ -1851,6 +1851,7 @@ SDValue DAGTypeLegalizer::PromoteIntOp_MGATHER(MaskedGatherSDNode *N,
 
 SDValue DAGTypeLegalizer::PromoteIntOp_MSCATTER(MaskedScatterSDNode *N,
                                                 unsigned OpNo) {
+  bool TruncateStore = N->isTruncatingStore();
   SmallVector<SDValue, 5> NewOps(N->op_begin(), N->op_end());
   if (OpNo == 2) {
     // The Mask
@@ -1863,9 +1864,15 @@ SDValue DAGTypeLegalizer::PromoteIntOp_MSCATTER(MaskedScatterSDNode *N,
       NewOps[OpNo] = SExtPromotedInteger(N->getOperand(OpNo));
     else
       NewOps[OpNo] = ZExtPromotedInteger(N->getOperand(OpNo));
-  } else
+
+  } else {
     NewOps[OpNo] = GetPromotedInteger(N->getOperand(OpNo));
-  return SDValue(DAG.UpdateNodeOperands(N, NewOps), 0);
+    TruncateStore = true;
+  }
+
+  return DAG.getMaskedScatter(DAG.getVTList(MVT::Other), N->getMemoryVT(),
+                              SDLoc(N), NewOps, N->getMemOperand(),
+                              N->getIndexType(), TruncateStore);
 }
 
 SDValue DAGTypeLegalizer::PromoteIntOp_TRUNCATE(SDNode *N) {

diff  --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
index e84d8e90e382..634a37b4c706 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -2496,11 +2496,15 @@ SDValue DAGTypeLegalizer::SplitVecOp_MSCATTER(MaskedScatterSDNode *N,
   SDValue Index = N->getIndex();
   SDValue Scale = N->getScale();
   SDValue Data = N->getValue();
+  EVT MemoryVT = N->getMemoryVT();
   Align Alignment = N->getOriginalAlign();
   SDLoc DL(N);
 
   // Split all operands
 
+  EVT LoMemVT, HiMemVT;
+  std::tie(LoMemVT, HiMemVT) = DAG.GetSplitDestVTs(MemoryVT);
+
   SDValue DataLo, DataHi;
   if (getTypeAction(Data.getValueType()) == TargetLowering::TypeSplitVector)
     // Split Data operand
@@ -2531,15 +2535,17 @@ SDValue DAGTypeLegalizer::SplitVecOp_MSCATTER(MaskedScatterSDNode *N,
       MemoryLocation::UnknownSize, Alignment, N->getAAInfo(), N->getRanges());
 
   SDValue OpsLo[] = {Ch, DataLo, MaskLo, Ptr, IndexLo, Scale};
-  Lo = DAG.getMaskedScatter(DAG.getVTList(MVT::Other), DataLo.getValueType(),
-                            DL, OpsLo, MMO, N->getIndexType());
+  Lo = DAG.getMaskedScatter(DAG.getVTList(MVT::Other), LoMemVT,
+                            DL, OpsLo, MMO, N->getIndexType(),
+                            N->isTruncatingStore());
 
   // The order of the Scatter operation after split is well defined. The "Hi"
   // part comes after the "Lo". So these two operations should be chained one
   // after another.
   SDValue OpsHi[] = {Lo, DataHi, MaskHi, Ptr, IndexHi, Scale};
-  return DAG.getMaskedScatter(DAG.getVTList(MVT::Other), DataHi.getValueType(),
-                              DL, OpsHi, MMO, N->getIndexType());
+  return DAG.getMaskedScatter(DAG.getVTList(MVT::Other), HiMemVT,
+                              DL, OpsHi, MMO, N->getIndexType(),
+                              N->isTruncatingStore());
 }
 
 SDValue DAGTypeLegalizer::SplitVecOp_STORE(StoreSDNode *N, unsigned OpNo) {
@@ -4717,7 +4723,8 @@ SDValue DAGTypeLegalizer::WidenVecOp_MSCATTER(SDNode *N, unsigned OpNo) {
                    Scale};
   return DAG.getMaskedScatter(DAG.getVTList(MVT::Other),
                               MSC->getMemoryVT(), SDLoc(N), Ops,
-                              MSC->getMemOperand(), MSC->getIndexType());
+                              MSC->getMemOperand(), MSC->getIndexType(),
+                              MSC->isTruncatingStore());
 }
 
 SDValue DAGTypeLegalizer::WidenVecOp_SETCC(SDNode *N) {

diff  --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 9ae04adf1b85..2b7da1f60bd0 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -7340,22 +7340,24 @@ SDValue SelectionDAG::getMaskedGather(SDVTList VTs, EVT VT, const SDLoc &dl,
 SDValue SelectionDAG::getMaskedScatter(SDVTList VTs, EVT VT, const SDLoc &dl,
                                        ArrayRef<SDValue> Ops,
                                        MachineMemOperand *MMO,
-                                       ISD::MemIndexType IndexType) {
+                                       ISD::MemIndexType IndexType,
+                                       bool IsTrunc) {
   assert(Ops.size() == 6 && "Incompatible number of operands");
 
   FoldingSetNodeID ID;
   AddNodeIDNode(ID, ISD::MSCATTER, VTs, Ops);
   ID.AddInteger(VT.getRawBits());
   ID.AddInteger(getSyntheticNodeSubclassData<MaskedScatterSDNode>(
-      dl.getIROrder(), VTs, VT, MMO, IndexType));
+      dl.getIROrder(), VTs, VT, MMO, IndexType, IsTrunc));
   ID.AddInteger(MMO->getPointerInfo().getAddrSpace());
   void *IP = nullptr;
   if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) {
     cast<MaskedScatterSDNode>(E)->refineAlignment(MMO);
     return SDValue(E, 0);
   }
+
   auto *N = newSDNode<MaskedScatterSDNode>(dl.getIROrder(), dl.getDebugLoc(),
-                                           VTs, VT, MMO, IndexType);
+                                           VTs, VT, MMO, IndexType, IsTrunc);
   createOperands(N, Ops);
 
   assert(N->getMask().getValueType().getVectorNumElements() ==

diff  --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index a89cbc7e5fa7..95dddb64f31d 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -4302,7 +4302,7 @@ void SelectionDAGBuilder::visitMaskedScatter(const CallInst &I) {
   }
   SDValue Ops[] = { getMemoryRoot(), Src0, Mask, Base, Index, Scale };
   SDValue Scatter = DAG.getMaskedScatter(DAG.getVTList(MVT::Other), VT, sdl,
-                                         Ops, MMO, IndexType);
+                                         Ops, MMO, IndexType, false);
   DAG.setRoot(Scatter);
   setValue(&I, Scatter);
 }

diff  --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
index 43be03e9332f..df35f057072e 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
@@ -735,7 +735,19 @@ void SDNode::print_details(raw_ostream &OS, const SelectionDAG *G) const {
       OS << ", compressing";
 
     OS << ">";
-  } else if (const MemSDNode* M = dyn_cast<MemSDNode>(this)) {
+  } else if (const auto *MScatter = dyn_cast<MaskedScatterSDNode>(this)) {
+    OS << "<";
+    printMemOperand(OS, *MScatter->getMemOperand(), G);
+
+    if (MScatter->isTruncatingStore())
+      OS << ", trunc to " << MScatter->getMemoryVT().getEVTString();
+
+    auto Signed = MScatter->isIndexSigned() ? "signed" : "unsigned";
+    auto Scaled = MScatter->isIndexScaled() ? "scaled" : "unscaled";
+    OS << ", " << Signed << " " << Scaled << " offset";
+
+    OS << ">";
+  } else if (const MemSDNode *M = dyn_cast<MemSDNode>(this)) {
     OS << "<";
     printMemOperand(OS, *M->getMemOperand(), G);
     OS << ">";

diff  --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 31fa4f447406..59dd36c9a768 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -47560,7 +47560,8 @@ static SDValue rebuildGatherScatter(MaskedGatherScatterSDNode *GorS,
   return DAG.getMaskedScatter(Scatter->getVTList(),
                               Scatter->getMemoryVT(), DL,
                               Ops, Scatter->getMemOperand(),
-                              Scatter->getIndexType());
+                              Scatter->getIndexType(),
+                              Scatter->isTruncatingStore());
 }
 
 static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,


        


More information about the llvm-commits mailing list