[llvm] 9b4fcfa - [SVE][CodeGen] Remove performMaskedGatherScatterCombine

Kerry McLaughlin via llvm-commits llvm-commits at lists.llvm.org
Mon Feb 1 06:11:27 PST 2021


Author: Kerry McLaughlin
Date: 2021-02-01T14:10:00Z
New Revision: 9b4fcfaa9e8f19f250c45e92dd2e5a305156b701

URL: https://github.com/llvm/llvm-project/commit/9b4fcfaa9e8f19f250c45e92dd2e5a305156b701
DIFF: https://github.com/llvm/llvm-project/commit/9b4fcfaa9e8f19f250c45e92dd2e5a305156b701.diff

LOG: [SVE][CodeGen] Remove performMaskedGatherScatterCombine

The AArch64 DAG combine added by D90945 & D91433 extends the index
of a scalable masked gather or scatter to i32 if necessary.

This patch removes the combine and instead adds a new TargetLowering hook,
shouldExtendGSIndex, which is queried by visitMaskedGather/Scatter in
SelectionDAGBuilder to determine whether the index should be extended before
calling getMaskedGather/Scatter.

Reviewed By: david-arm

Differential Revision: https://reviews.llvm.org/D94525
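
For illustration, a back end that wants narrow gather/scatter indices widened
can now opt in by overriding the new hook rather than installing a DAG combine.
The sketch below is modelled on the AArch64 override added in this patch;
"MyTargetLowering" is a placeholder class name, not part of this commit.

    // Sketch only: a hypothetical target requesting index extension, by
    // analogy with the AArch64 override in this patch. When the hook returns
    // true, SelectionDAGBuilder sign-extends the i8/i16 index vector to an
    // i32 index vector before building the MGATHER/MSCATTER node.
    bool MyTargetLowering::shouldExtendGSIndex(EVT VT, EVT &EltTy) const {
      if (VT.getVectorElementType() == MVT::i8 ||
          VT.getVectorElementType() == MVT::i16) {
        EltTy = MVT::i32;
        return true;
      }
      return false;
    }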

Added: 
    

Modified: 
    llvm/include/llvm/CodeGen/TargetLowering.h
    llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
    llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
    llvm/lib/Target/AArch64/AArch64ISelLowering.h

Removed: 
    


################################################################################
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 6eafc146b900..7d27bf390e65 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -1318,6 +1318,10 @@ class TargetLoweringBase {
             getIndexedMaskedStoreAction(IdxMode, VT.getSimpleVT()) == Custom);
   }
 
+  /// Returns true if the index type for a masked gather/scatter requires
+  /// extending
+  virtual bool shouldExtendGSIndex(EVT VT, EVT &EltTy) const { return false; }
+
   // Returns true if VT is a legal index type for masked gathers/scatters
   // on this target
   virtual bool shouldRemoveExtendFromGSIndex(EVT VT) const { return false; }

diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 2194b73b5768..5c94a83f719c 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -4339,6 +4339,14 @@ void SelectionDAGBuilder::visitMaskedScatter(const CallInst &I) {
     IndexType = ISD::SIGNED_UNSCALED;
     Scale = DAG.getTargetConstant(1, sdl, TLI.getPointerTy(DAG.getDataLayout()));
   }
+
+  EVT IdxVT = Index.getValueType();
+  EVT EltTy = IdxVT.getVectorElementType();
+  if (TLI.shouldExtendGSIndex(IdxVT, EltTy)) {
+    EVT NewIdxVT = IdxVT.changeVectorElementType(EltTy);
+    Index = DAG.getNode(ISD::SIGN_EXTEND, sdl, NewIdxVT, Index);
+  }
+
   SDValue Ops[] = { getMemoryRoot(), Src0, Mask, Base, Index, Scale };
   SDValue Scatter = DAG.getMaskedScatter(DAG.getVTList(MVT::Other), VT, sdl,
                                          Ops, MMO, IndexType, false);
@@ -4450,6 +4458,14 @@ void SelectionDAGBuilder::visitMaskedGather(const CallInst &I) {
     IndexType = ISD::SIGNED_UNSCALED;
     Scale = DAG.getTargetConstant(1, sdl, TLI.getPointerTy(DAG.getDataLayout()));
   }
+
+  EVT IdxVT = Index.getValueType();
+  EVT EltTy = IdxVT.getVectorElementType();
+  if (TLI.shouldExtendGSIndex(IdxVT, EltTy)) {
+    EVT NewIdxVT = IdxVT.changeVectorElementType(EltTy);
+    Index = DAG.getNode(ISD::SIGN_EXTEND, sdl, NewIdxVT, Index);
+  }
+
   SDValue Ops[] = { Root, Src0, Mask, Base, Index, Scale };
   SDValue Gather = DAG.getMaskedGather(DAG.getVTList(VT, MVT::Other), VT, sdl,
                                        Ops, MMO, IndexType, ISD::NON_EXTLOAD);

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index b4d5f5545927..36f304b1fd94 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -873,9 +873,6 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
   if (Subtarget->supportsAddressTopByteIgnored())
     setTargetDAGCombine(ISD::LOAD);
 
-  setTargetDAGCombine(ISD::MGATHER);
-  setTargetDAGCombine(ISD::MSCATTER);
-
   setTargetDAGCombine(ISD::MUL);
 
   setTargetDAGCombine(ISD::SELECT);
@@ -3825,6 +3822,15 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
   }
 }
 
+bool AArch64TargetLowering::shouldExtendGSIndex(EVT VT, EVT &EltTy) const {
+  if (VT.getVectorElementType() == MVT::i8 ||
+      VT.getVectorElementType() == MVT::i16) {
+    EltTy = MVT::i32;
+    return true;
+  }
+  return false;
+}
+
 bool AArch64TargetLowering::shouldRemoveExtendFromGSIndex(EVT VT) const {
   if (VT.getVectorElementType() == MVT::i32 &&
       VT.getVectorElementCount().getKnownMinValue() >= 4)
@@ -14395,55 +14401,6 @@ static SDValue performSTORECombine(SDNode *N,
   return SDValue();
 }
 
-static SDValue performMaskedGatherScatterCombine(SDNode *N,
-                                      TargetLowering::DAGCombinerInfo &DCI,
-                                      SelectionDAG &DAG) {
-  MaskedGatherScatterSDNode *MGS = cast<MaskedGatherScatterSDNode>(N);
-  assert(MGS && "Can only combine gather load or scatter store nodes");
-
-  SDLoc DL(MGS);
-  SDValue Chain = MGS->getChain();
-  SDValue Scale = MGS->getScale();
-  SDValue Index = MGS->getIndex();
-  SDValue Mask = MGS->getMask();
-  SDValue BasePtr = MGS->getBasePtr();
-  ISD::MemIndexType IndexType = MGS->getIndexType();
-
-  EVT IdxVT = Index.getValueType();
-
-  if (DCI.isBeforeLegalize()) {
-    // SVE gather/scatter requires indices of i32/i64. Promote anything smaller
-    // prior to legalisation so the result can be split if required.
-    if ((IdxVT.getVectorElementType() == MVT::i8) ||
-        (IdxVT.getVectorElementType() == MVT::i16)) {
-      EVT NewIdxVT = IdxVT.changeVectorElementType(MVT::i32);
-      if (MGS->isIndexSigned())
-        Index = DAG.getNode(ISD::SIGN_EXTEND, DL, NewIdxVT, Index);
-      else
-        Index = DAG.getNode(ISD::ZERO_EXTEND, DL, NewIdxVT, Index);
-
-      if (auto *MGT = dyn_cast<MaskedGatherSDNode>(MGS)) {
-        SDValue PassThru = MGT->getPassThru();
-        SDValue Ops[] = { Chain, PassThru, Mask, BasePtr, Index, Scale };
-        return DAG.getMaskedGather(DAG.getVTList(N->getValueType(0), MVT::Other),
-                                   PassThru.getValueType(), DL, Ops,
-                                   MGT->getMemOperand(),
-                                   MGT->getIndexType(), MGT->getExtensionType());
-      } else {
-        auto *MSC = cast<MaskedScatterSDNode>(MGS);
-        SDValue Data = MSC->getValue();
-        SDValue Ops[] = { Chain, Data, Mask, BasePtr, Index, Scale };
-        return DAG.getMaskedScatter(DAG.getVTList(MVT::Other),
-                                    MSC->getMemoryVT(), DL, Ops,
-                                    MSC->getMemOperand(), IndexType,
-                                    MSC->isTruncatingStore());
-      }
-    }
-  }
-
-  return SDValue();
-}
-
 /// Target-specific DAG combine function for NEON load/store intrinsics
 /// to merge base address updates.
 static SDValue performNEONPostLDSTCombine(SDNode *N,
@@ -15638,9 +15595,6 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
     break;
   case ISD::STORE:
     return performSTORECombine(N, DCI, DAG, Subtarget);
-  case ISD::MGATHER:
-  case ISD::MSCATTER:
-    return performMaskedGatherScatterCombine(N, DCI, DAG);
   case AArch64ISD::BRCOND:
     return performBRCONDCombine(N, DCI, DAG);
   case AArch64ISD::TBNZ:

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index 9550197159e6..8aec29478b72 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -996,6 +996,7 @@ class AArch64TargetLowering : public TargetLowering {
     return TargetLowering::getInlineAsmMemConstraint(ConstraintCode);
   }
 
+  bool shouldExtendGSIndex(EVT VT, EVT &EltTy) const override;
   bool shouldRemoveExtendFromGSIndex(EVT VT) const override;
   bool isVectorLoadExtDesirable(SDValue ExtVal) const override;
   bool isUsedByReturnOnly(SDNode *N, SDValue &Chain) const override;

