[llvm] 17b071d - [RISCV] Rework gather/scatter DAG combine structure [NFC]

Philip Reames via llvm-commits llvm-commits at lists.llvm.org
Tue Sep 12 10:57:22 PDT 2023


Author: Philip Reames
Date: 2023-09-12T10:57:12-07:00
New Revision: 17b071db6a39e98c4cc5d06521096e91f72237c4

URL: https://github.com/llvm/llvm-project/commit/17b071db6a39e98c4cc5d06521096e91f72237c4
DIFF: https://github.com/llvm/llvm-project/commit/17b071db6a39e98c4cc5d06521096e91f72237c4.diff

LOG: [RISCV] Rework gather/scatter DAG combine structure [NFC]

Instead of switching on type before and after common code, use a helper function.  This matches the style of DAGCombine.cpp more closely, and makes porting candidate changes from one place to the other much easier.

Added: 
    

Modified: 
    llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 0dd03076cc05b36..a470ceae90ce591 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -13482,6 +13482,34 @@ static SDValue combineToVWMACC(SDNode *N, SelectionDAG &DAG,
   return DAG.getNode(Opc, DL, VT, Ops);
 }
 
+// Promotes a signed gather/scatter index narrower than XLEN to XLenVT and
+// marks it unsigned, before legalization only. Returns true on change.
+static bool
+legalizeScatterGatherIndexType(const SDLoc &DL, SDValue &Index,
+                               ISD::MemIndexType &IndexType,
+                               RISCVTargetLowering::DAGCombinerInfo &DCI) {
+  if (!DCI.isBeforeLegalize())
+    return false;
+
+  SelectionDAG &DAG = DCI.DAG;
+  const MVT XLenVT =
+      DAG.getMachineFunction().getSubtarget<RISCVSubtarget>().getXLenVT();
+  const EVT IndexVT = Index.getValueType();
+  // RISC-V indexed loads only support the "unsigned unscaled" addressing
+  // mode, so a signed index narrower than XLEN must be manually legalized.
+  if (!isIndexTypeSigned(IndexType) ||
+      !IndexVT.getVectorElementType().bitsLT(XLenVT))
+    return false;
+
+  // Promote to XLenVT first so no bits are lost when scaling. This may create
+  // an illegal index type; LLVM's legalization takes care of the splitting.
+  // FIXME: LLVM can't split VP_GATHER or VP_SCATTER yet.
+  Index = DAG.getNode(ISD::SIGN_EXTEND, DL,
+                      IndexVT.changeVectorElementType(XLenVT), Index);
+  IndexType = ISD::UNSIGNED_SCALED; // Callers assert the index is unscaled.
+  return true;
+}
+
 SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
                                                DAGCombinerInfo &DCI) const {
   SelectionDAG &DAG = DCI.DAG;
@@ -13827,74 +13855,73 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
     return DAG.getNode(ISD::FCOPYSIGN, DL, VT, N->getOperand(0),
                        DAG.getNode(ISD::FNEG, DL, VT, NewFPExtRound));
   }
-  case ISD::MGATHER:
-  case ISD::MSCATTER:
-  case ISD::VP_GATHER:
-  case ISD::VP_SCATTER: {
-    if (!DCI.isBeforeLegalize())
-      break;
-    SDValue Index, ScaleOp;
-    bool IsIndexSigned = false;
-    if (const auto *VPGSN = dyn_cast<VPGatherScatterSDNode>(N)) {
-      Index = VPGSN->getIndex();
-      ScaleOp = VPGSN->getScale();
-      IsIndexSigned = VPGSN->isIndexSigned();
-      assert(!VPGSN->isIndexScaled() &&
-             "Scaled gather/scatter should not be formed");
-    } else {
-      const auto *MGSN = cast<MaskedGatherScatterSDNode>(N);
-      Index = MGSN->getIndex();
-      ScaleOp = MGSN->getScale();
-      IsIndexSigned = MGSN->isIndexSigned();
-      assert(!MGSN->isIndexScaled() &&
-             "Scaled gather/scatter should not be formed");
-
-    }
-    EVT IndexVT = Index.getValueType();
-    // RISC-V indexed loads only support the "unsigned unscaled" addressing
-    // mode, so anything else must be manually legalized.
-    bool NeedsIdxLegalization =
-        (IsIndexSigned && IndexVT.getVectorElementType().bitsLT(XLenVT));
-    if (!NeedsIdxLegalization)
-      break;
+  case ISD::MGATHER: {
+    const auto *MGN = cast<MaskedGatherSDNode>(N);
+    SDValue Index = MGN->getIndex();
+    SDValue ScaleOp = MGN->getScale();
+    ISD::MemIndexType IndexType = MGN->getIndexType();
+    assert(!MGN->isIndexScaled() &&
+           "Scaled gather/scatter should not be formed");
 
     SDLoc DL(N);
+    if (legalizeScatterGatherIndexType(DL, Index, IndexType, DCI))
+      return DAG.getMaskedGather(
+          N->getVTList(), MGN->getMemoryVT(), DL,
+          {MGN->getChain(), MGN->getPassThru(), MGN->getMask(),
+           MGN->getBasePtr(), Index, ScaleOp},
+          MGN->getMemOperand(), IndexType, MGN->getExtensionType());
+    break;
+  }
+  case ISD::MSCATTER: {
+    const auto *MSN = cast<MaskedScatterSDNode>(N);
+    SDValue Index = MSN->getIndex();
+    SDValue ScaleOp = MSN->getScale();
+    ISD::MemIndexType IndexType = MSN->getIndexType();
+    assert(!MSN->isIndexScaled() &&
+           "Scaled gather/scatter should not be formed");
 
-    // Any index legalization should first promote to XLenVT, so we don't lose
-    // bits when scaling. This may create an illegal index type so we let
-    // LLVM's legalization take care of the splitting.
-    // FIXME: LLVM can't split VP_GATHER or VP_SCATTER yet.
-    if (IndexVT.getVectorElementType().bitsLT(XLenVT)) {
-      IndexVT = IndexVT.changeVectorElementType(XLenVT);
-      Index = DAG.getNode(IsIndexSigned ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND,
-                          DL, IndexVT, Index);
-    }
+    SDLoc DL(N);
+    if (legalizeScatterGatherIndexType(DL, Index, IndexType, DCI))
+      return DAG.getMaskedScatter(
+          N->getVTList(), MSN->getMemoryVT(), DL,
+          {MSN->getChain(), MSN->getValue(), MSN->getMask(), MSN->getBasePtr(),
+           Index, ScaleOp},
+          MSN->getMemOperand(), IndexType, MSN->isTruncatingStore());
+    break;
+  }
+  case ISD::VP_GATHER: {
+    const auto *VPGN = cast<VPGatherSDNode>(N);
+    SDValue Index = VPGN->getIndex();
+    SDValue ScaleOp = VPGN->getScale();
+    ISD::MemIndexType IndexType = VPGN->getIndexType();
+    assert(!VPGN->isIndexScaled() &&
+           "Scaled gather/scatter should not be formed");
 
-    ISD::MemIndexType NewIndexTy = ISD::UNSIGNED_SCALED;
-    if (const auto *VPGN = dyn_cast<VPGatherSDNode>(N))
+    SDLoc DL(N);
+    if (legalizeScatterGatherIndexType(DL, Index, IndexType, DCI))
       return DAG.getGatherVP(N->getVTList(), VPGN->getMemoryVT(), DL,
                              {VPGN->getChain(), VPGN->getBasePtr(), Index,
                               ScaleOp, VPGN->getMask(),
                               VPGN->getVectorLength()},
-                             VPGN->getMemOperand(), NewIndexTy);
-    if (const auto *VPSN = dyn_cast<VPScatterSDNode>(N))
+                             VPGN->getMemOperand(), IndexType);
+    break;
+  }
+  case ISD::VP_SCATTER: {
+    const auto *VPSN = cast<VPScatterSDNode>(N);
+    SDValue Index = VPSN->getIndex();
+    SDValue ScaleOp = VPSN->getScale();
+    ISD::MemIndexType IndexType = VPSN->getIndexType();
+    assert(!VPSN->isIndexScaled() &&
+           "Scaled gather/scatter should not be formed");
+
+    SDLoc DL(N);
+    if (legalizeScatterGatherIndexType(DL, Index, IndexType, DCI))
       return DAG.getScatterVP(N->getVTList(), VPSN->getMemoryVT(), DL,
                               {VPSN->getChain(), VPSN->getValue(),
                                VPSN->getBasePtr(), Index, ScaleOp,
                                VPSN->getMask(), VPSN->getVectorLength()},
-                              VPSN->getMemOperand(), NewIndexTy);
-    if (const auto *MGN = dyn_cast<MaskedGatherSDNode>(N))
-      return DAG.getMaskedGather(
-          N->getVTList(), MGN->getMemoryVT(), DL,
-          {MGN->getChain(), MGN->getPassThru(), MGN->getMask(),
-           MGN->getBasePtr(), Index, ScaleOp},
-          MGN->getMemOperand(), NewIndexTy, MGN->getExtensionType());
-    const auto *MSN = cast<MaskedScatterSDNode>(N);
-    return DAG.getMaskedScatter(
-        N->getVTList(), MSN->getMemoryVT(), DL,
-        {MSN->getChain(), MSN->getValue(), MSN->getMask(), MSN->getBasePtr(),
-         Index, ScaleOp},
-        MSN->getMemOperand(), NewIndexTy, MSN->isTruncatingStore());
+                              VPSN->getMemOperand(), IndexType);
+    break;
   }
   case RISCVISD::SRA_VL:
   case RISCVISD::SRL_VL:


        


More information about the llvm-commits mailing list