[llvm] 9b4fcfa - [SVE][CodeGen] Remove performMaskedGatherScatterCombine
Kerry McLaughlin via llvm-commits
llvm-commits at lists.llvm.org
Mon Feb 1 06:11:27 PST 2021
Author: Kerry McLaughlin
Date: 2021-02-01T14:10:00Z
New Revision: 9b4fcfaa9e8f19f250c45e92dd2e5a305156b701
URL: https://github.com/llvm/llvm-project/commit/9b4fcfaa9e8f19f250c45e92dd2e5a305156b701
DIFF: https://github.com/llvm/llvm-project/commit/9b4fcfaa9e8f19f250c45e92dd2e5a305156b701.diff
LOG: [SVE][CodeGen] Remove performMaskedGatherScatterCombine
The AArch64 DAG combine added by D90945 & D91433 extends the index
of a scalable masked gather or scatter to i32 if necessary.
This patch removes the combine and instead adds shouldExtendGSIndex, which
is used by visitMaskedGather/Scatter in SelectionDAGBuilder to query whether
the index should be extended before calling getMaskedGather/Scatter.
Reviewed By: david-arm
Differential Revision: https://reviews.llvm.org/D94525
Added:
Modified:
llvm/include/llvm/CodeGen/TargetLowering.h
llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
llvm/lib/Target/AArch64/AArch64ISelLowering.h
Removed:
################################################################################
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 6eafc146b900..7d27bf390e65 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -1318,6 +1318,10 @@ class TargetLoweringBase {
getIndexedMaskedStoreAction(IdxMode, VT.getSimpleVT()) == Custom);
}
+ /// Returns true if the index type for a masked gather/scatter requires
+ /// extending. If so, EltTy is set to the element type to extend it to.
+ virtual bool shouldExtendGSIndex(EVT VT, EVT &EltTy) const { return false; }
+
// Returns true if VT is a legal index type for masked gathers/scatters
// on this target
virtual bool shouldRemoveExtendFromGSIndex(EVT VT) const { return false; }
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 2194b73b5768..5c94a83f719c 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -4339,6 +4339,14 @@ void SelectionDAGBuilder::visitMaskedScatter(const CallInst &I) {
IndexType = ISD::SIGNED_UNSCALED;
Scale = DAG.getTargetConstant(1, sdl, TLI.getPointerTy(DAG.getDataLayout()));
}
+
+ EVT IdxVT = Index.getValueType();
+ EVT EltTy = IdxVT.getVectorElementType();
+ if (TLI.shouldExtendGSIndex(IdxVT, EltTy)) {
+ EVT NewIdxVT = IdxVT.changeVectorElementType(EltTy);
+ Index = DAG.getNode(ISD::SIGN_EXTEND, sdl, NewIdxVT, Index);
+ }
+
SDValue Ops[] = { getMemoryRoot(), Src0, Mask, Base, Index, Scale };
SDValue Scatter = DAG.getMaskedScatter(DAG.getVTList(MVT::Other), VT, sdl,
Ops, MMO, IndexType, false);
@@ -4450,6 +4458,14 @@ void SelectionDAGBuilder::visitMaskedGather(const CallInst &I) {
IndexType = ISD::SIGNED_UNSCALED;
Scale = DAG.getTargetConstant(1, sdl, TLI.getPointerTy(DAG.getDataLayout()));
}
+
+ EVT IdxVT = Index.getValueType();
+ EVT EltTy = IdxVT.getVectorElementType();
+ if (TLI.shouldExtendGSIndex(IdxVT, EltTy)) {
+ EVT NewIdxVT = IdxVT.changeVectorElementType(EltTy);
+ Index = DAG.getNode(ISD::SIGN_EXTEND, sdl, NewIdxVT, Index);
+ }
+
SDValue Ops[] = { Root, Src0, Mask, Base, Index, Scale };
SDValue Gather = DAG.getMaskedGather(DAG.getVTList(VT, MVT::Other), VT, sdl,
Ops, MMO, IndexType, ISD::NON_EXTLOAD);
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index b4d5f5545927..36f304b1fd94 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -873,9 +873,6 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
if (Subtarget->supportsAddressTopByteIgnored())
setTargetDAGCombine(ISD::LOAD);
- setTargetDAGCombine(ISD::MGATHER);
- setTargetDAGCombine(ISD::MSCATTER);
-
setTargetDAGCombine(ISD::MUL);
setTargetDAGCombine(ISD::SELECT);
@@ -3825,6 +3822,15 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
}
}
+bool AArch64TargetLowering::shouldExtendGSIndex(EVT VT, EVT &EltTy) const {
+ if (VT.getVectorElementType() == MVT::i8 ||
+ VT.getVectorElementType() == MVT::i16) { // SVE needs i32/i64 indices.
+ EltTy = MVT::i32; // Widen i8/i16 index elements to i32.
+ return true;
+ }
+ return false; // i32/i64 (and any other) index types are left unchanged.
+}
+
bool AArch64TargetLowering::shouldRemoveExtendFromGSIndex(EVT VT) const {
if (VT.getVectorElementType() == MVT::i32 &&
VT.getVectorElementCount().getKnownMinValue() >= 4)
@@ -14395,55 +14401,6 @@ static SDValue performSTORECombine(SDNode *N,
return SDValue();
}
-static SDValue performMaskedGatherScatterCombine(SDNode *N,
- TargetLowering::DAGCombinerInfo &DCI,
- SelectionDAG &DAG) {
- MaskedGatherScatterSDNode *MGS = cast<MaskedGatherScatterSDNode>(N);
- assert(MGS && "Can only combine gather load or scatter store nodes");
-
- SDLoc DL(MGS);
- SDValue Chain = MGS->getChain();
- SDValue Scale = MGS->getScale();
- SDValue Index = MGS->getIndex();
- SDValue Mask = MGS->getMask();
- SDValue BasePtr = MGS->getBasePtr();
- ISD::MemIndexType IndexType = MGS->getIndexType();
-
- EVT IdxVT = Index.getValueType();
-
- if (DCI.isBeforeLegalize()) {
- // SVE gather/scatter requires indices of i32/i64. Promote anything smaller
- // prior to legalisation so the result can be split if required.
- if ((IdxVT.getVectorElementType() == MVT::i8) ||
- (IdxVT.getVectorElementType() == MVT::i16)) {
- EVT NewIdxVT = IdxVT.changeVectorElementType(MVT::i32);
- if (MGS->isIndexSigned())
- Index = DAG.getNode(ISD::SIGN_EXTEND, DL, NewIdxVT, Index);
- else
- Index = DAG.getNode(ISD::ZERO_EXTEND, DL, NewIdxVT, Index);
-
- if (auto *MGT = dyn_cast<MaskedGatherSDNode>(MGS)) {
- SDValue PassThru = MGT->getPassThru();
- SDValue Ops[] = { Chain, PassThru, Mask, BasePtr, Index, Scale };
- return DAG.getMaskedGather(DAG.getVTList(N->getValueType(0), MVT::Other),
- PassThru.getValueType(), DL, Ops,
- MGT->getMemOperand(),
- MGT->getIndexType(), MGT->getExtensionType());
- } else {
- auto *MSC = cast<MaskedScatterSDNode>(MGS);
- SDValue Data = MSC->getValue();
- SDValue Ops[] = { Chain, Data, Mask, BasePtr, Index, Scale };
- return DAG.getMaskedScatter(DAG.getVTList(MVT::Other),
- MSC->getMemoryVT(), DL, Ops,
- MSC->getMemOperand(), IndexType,
- MSC->isTruncatingStore());
- }
- }
- }
-
- return SDValue();
-}
-
/// Target-specific DAG combine function for NEON load/store intrinsics
/// to merge base address updates.
static SDValue performNEONPostLDSTCombine(SDNode *N,
@@ -15638,9 +15595,6 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
break;
case ISD::STORE:
return performSTORECombine(N, DCI, DAG, Subtarget);
- case ISD::MGATHER:
- case ISD::MSCATTER:
- return performMaskedGatherScatterCombine(N, DCI, DAG);
case AArch64ISD::BRCOND:
return performBRCONDCombine(N, DCI, DAG);
case AArch64ISD::TBNZ:
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index 9550197159e6..8aec29478b72 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -996,6 +996,7 @@ class AArch64TargetLowering : public TargetLowering {
return TargetLowering::getInlineAsmMemConstraint(ConstraintCode);
}
+ bool shouldExtendGSIndex(EVT VT, EVT &EltTy) const override;
bool shouldRemoveExtendFromGSIndex(EVT VT) const override;
bool isVectorLoadExtDesirable(SDValue ExtVal) const override;
bool isUsedByReturnOnly(SDNode *N, SDValue &Chain) const override;
More information about the llvm-commits
mailing list