[llvm] 17b071d - [RISCV] Rework gather/scatter DAG combine structure [NFC]
Philip Reames via llvm-commits
llvm-commits at lists.llvm.org
Tue Sep 12 10:57:22 PDT 2023
Author: Philip Reames
Date: 2023-09-12T10:57:12-07:00
New Revision: 17b071db6a39e98c4cc5d06521096e91f72237c4
URL: https://github.com/llvm/llvm-project/commit/17b071db6a39e98c4cc5d06521096e91f72237c4
DIFF: https://github.com/llvm/llvm-project/commit/17b071db6a39e98c4cc5d06521096e91f72237c4.diff
LOG: [RISCV] Rework gather/scatter DAG combine structure [NFC]
Instead of switching on type before and after common code, use a helper function. This matches the style of DAGCombiner.cpp more closely, and makes porting candidate changes from one place to the other much easier.
Added:
Modified:
llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Removed:
################################################################################
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 0dd03076cc05b36..a470ceae90ce591 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -13482,6 +13482,34 @@ static SDValue combineToVWMACC(SDNode *N, SelectionDAG &DAG,
return DAG.getNode(Opc, DL, VT, Ops);
}
+/// Before legalization, sign-extend a signed sub-XLen gather/scatter index to
+/// XLenVT, marking it UNSIGNED_SCALED. Returns true iff Index was rewritten.
+static bool
+legalizeScatterGatherIndexType(const SDLoc &DL, SDValue &Index,
+                               ISD::MemIndexType &IndexType,
+                               RISCVTargetLowering::DAGCombinerInfo &DCI) {
+  if (!DCI.isBeforeLegalize())
+    return false;
+  SelectionDAG &DAG = DCI.DAG;
+  const MVT XLenVT =
+      DAG.getMachineFunction().getSubtarget<RISCVSubtarget>().getXLenVT();
+  const EVT IndexVT = Index.getValueType();
+  // RISC-V indexed loads only support the "unsigned unscaled" addressing
+  // mode, so a signed index narrower than XLen must be manually legalized.
+  if (!isIndexTypeSigned(IndexType) ||
+      !IndexVT.getVectorElementType().bitsLT(XLenVT))
+    return false;
+
+  // Any index legalization should first promote to XLenVT, so we don't lose
+  // bits when scaling. This may create an illegal index type so we let
+  // LLVM's legalization take care of the splitting.
+  // FIXME: LLVM can't split VP_GATHER or VP_SCATTER yet.
+  Index = DAG.getNode(ISD::SIGN_EXTEND, DL,
+                      IndexVT.changeVectorElementType(XLenVT), Index);
+  IndexType = ISD::UNSIGNED_SCALED;
+  return true;
+}
+
SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
DAGCombinerInfo &DCI) const {
SelectionDAG &DAG = DCI.DAG;
@@ -13827,74 +13855,73 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
return DAG.getNode(ISD::FCOPYSIGN, DL, VT, N->getOperand(0),
DAG.getNode(ISD::FNEG, DL, VT, NewFPExtRound));
}
- case ISD::MGATHER:
- case ISD::MSCATTER:
- case ISD::VP_GATHER:
- case ISD::VP_SCATTER: {
- if (!DCI.isBeforeLegalize())
- break;
- SDValue Index, ScaleOp;
- bool IsIndexSigned = false;
- if (const auto *VPGSN = dyn_cast<VPGatherScatterSDNode>(N)) {
- Index = VPGSN->getIndex();
- ScaleOp = VPGSN->getScale();
- IsIndexSigned = VPGSN->isIndexSigned();
- assert(!VPGSN->isIndexScaled() &&
- "Scaled gather/scatter should not be formed");
- } else {
- const auto *MGSN = cast<MaskedGatherScatterSDNode>(N);
- Index = MGSN->getIndex();
- ScaleOp = MGSN->getScale();
- IsIndexSigned = MGSN->isIndexSigned();
- assert(!MGSN->isIndexScaled() &&
- "Scaled gather/scatter should not be formed");
-
- }
- EVT IndexVT = Index.getValueType();
- // RISC-V indexed loads only support the "unsigned unscaled" addressing
- // mode, so anything else must be manually legalized.
- bool NeedsIdxLegalization =
- (IsIndexSigned && IndexVT.getVectorElementType().bitsLT(XLenVT));
- if (!NeedsIdxLegalization)
- break;
+ case ISD::MGATHER: {
+ const auto *MGN = cast<MaskedGatherSDNode>(N);
+ SDValue Index = MGN->getIndex();
+ SDValue ScaleOp = MGN->getScale();
+ ISD::MemIndexType IndexType = MGN->getIndexType();
+ assert(!MGN->isIndexScaled() &&
+ "Scaled gather/scatter should not be formed");
SDLoc DL(N);
+ if (legalizeScatterGatherIndexType(DL, Index, IndexType, DCI))
+ return DAG.getMaskedGather(
+ N->getVTList(), MGN->getMemoryVT(), DL,
+ {MGN->getChain(), MGN->getPassThru(), MGN->getMask(),
+ MGN->getBasePtr(), Index, ScaleOp},
+ MGN->getMemOperand(), IndexType, MGN->getExtensionType());
+ break;
+ }
+ case ISD::MSCATTER: {
+ const auto *MSN = cast<MaskedScatterSDNode>(N);
+ SDValue Index = MSN->getIndex();
+ SDValue ScaleOp = MSN->getScale();
+ ISD::MemIndexType IndexType = MSN->getIndexType();
+ assert(!MSN->isIndexScaled() &&
+ "Scaled gather/scatter should not be formed");
- // Any index legalization should first promote to XLenVT, so we don't lose
- // bits when scaling. This may create an illegal index type so we let
- // LLVM's legalization take care of the splitting.
- // FIXME: LLVM can't split VP_GATHER or VP_SCATTER yet.
- if (IndexVT.getVectorElementType().bitsLT(XLenVT)) {
- IndexVT = IndexVT.changeVectorElementType(XLenVT);
- Index = DAG.getNode(IsIndexSigned ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND,
- DL, IndexVT, Index);
- }
+ SDLoc DL(N);
+ if (legalizeScatterGatherIndexType(DL, Index, IndexType, DCI))
+ return DAG.getMaskedScatter(
+ N->getVTList(), MSN->getMemoryVT(), DL,
+ {MSN->getChain(), MSN->getValue(), MSN->getMask(), MSN->getBasePtr(),
+ Index, ScaleOp},
+ MSN->getMemOperand(), IndexType, MSN->isTruncatingStore());
+ break;
+ }
+ case ISD::VP_GATHER: {
+ const auto *VPGN = cast<VPGatherSDNode>(N);
+ SDValue Index = VPGN->getIndex();
+ SDValue ScaleOp = VPGN->getScale();
+ ISD::MemIndexType IndexType = VPGN->getIndexType();
+ assert(!VPGN->isIndexScaled() &&
+ "Scaled gather/scatter should not be formed");
- ISD::MemIndexType NewIndexTy = ISD::UNSIGNED_SCALED;
- if (const auto *VPGN = dyn_cast<VPGatherSDNode>(N))
+ SDLoc DL(N);
+ if (legalizeScatterGatherIndexType(DL, Index, IndexType, DCI))
return DAG.getGatherVP(N->getVTList(), VPGN->getMemoryVT(), DL,
{VPGN->getChain(), VPGN->getBasePtr(), Index,
ScaleOp, VPGN->getMask(),
VPGN->getVectorLength()},
- VPGN->getMemOperand(), NewIndexTy);
- if (const auto *VPSN = dyn_cast<VPScatterSDNode>(N))
+ VPGN->getMemOperand(), IndexType);
+ break;
+ }
+ case ISD::VP_SCATTER: {
+ const auto *VPSN = cast<VPScatterSDNode>(N);
+ SDValue Index = VPSN->getIndex();
+ SDValue ScaleOp = VPSN->getScale();
+ ISD::MemIndexType IndexType = VPSN->getIndexType();
+ assert(!VPSN->isIndexScaled() &&
+ "Scaled gather/scatter should not be formed");
+
+ SDLoc DL(N);
+ if (legalizeScatterGatherIndexType(DL, Index, IndexType, DCI))
return DAG.getScatterVP(N->getVTList(), VPSN->getMemoryVT(), DL,
{VPSN->getChain(), VPSN->getValue(),
VPSN->getBasePtr(), Index, ScaleOp,
VPSN->getMask(), VPSN->getVectorLength()},
- VPSN->getMemOperand(), NewIndexTy);
- if (const auto *MGN = dyn_cast<MaskedGatherSDNode>(N))
- return DAG.getMaskedGather(
- N->getVTList(), MGN->getMemoryVT(), DL,
- {MGN->getChain(), MGN->getPassThru(), MGN->getMask(),
- MGN->getBasePtr(), Index, ScaleOp},
- MGN->getMemOperand(), NewIndexTy, MGN->getExtensionType());
- const auto *MSN = cast<MaskedScatterSDNode>(N);
- return DAG.getMaskedScatter(
- N->getVTList(), MSN->getMemoryVT(), DL,
- {MSN->getChain(), MSN->getValue(), MSN->getMask(), MSN->getBasePtr(),
- Index, ScaleOp},
- MSN->getMemOperand(), NewIndexTy, MSN->isTruncatingStore());
+ VPSN->getMemOperand(), IndexType);
+ break;
}
case RISCVISD::SRA_VL:
case RISCVISD::SRL_VL:
More information about the llvm-commits
mailing list