[llvm] [AArch64][SVE] Improve code quality of vector unsigned add reduction. (PR #97339)
Dinar Temirbulatov via llvm-commits
llvm-commits at lists.llvm.org
Tue Jul 2 12:52:31 PDT 2024
================
@@ -17455,6 +17455,99 @@ static SDValue performVecReduceAddCombineWithUADDLP(SDNode *N,
return DAG.getNode(ISD::VECREDUCE_ADD, DL, MVT::i32, UADDLP);
}
+static SDValue
+performVecReduceAddZextCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
+ const AArch64TargetLowering &TLI) {
+ if (N->getOperand(0).getOpcode() != ISD::ZERO_EXTEND)
+ return SDValue();
+
+ SelectionDAG &DAG = DCI.DAG;
+ auto &Subtarget = DAG.getSubtarget<AArch64Subtarget>();
+ SDNode *ZEXT = N->getOperand(0).getNode();
+ EVT VecVT = ZEXT->getOperand(0).getValueType();
+ SDLoc DL(N);
+
+ SDValue VecOp = ZEXT->getOperand(0);
+ VecVT = VecOp.getValueType();
+ bool IsScalableType = VecVT.isScalableVector();
+
+ if (TLI.isTypeLegal(VecVT)) {
+ if (!IsScalableType &&
+ !TLI.useSVEForFixedLengthVectorVT(
+ VecVT,
+ /*OverrideNEON=*/Subtarget.useSVEForFixedLengthVectors(VecVT)))
+ return SDValue();
+
+ if (!IsScalableType) {
+ EVT ContainerVT = getContainerForFixedLengthVector(DAG, VecVT);
+ VecOp = convertToScalableVector(DAG, ContainerVT, VecOp);
+ }
+ VecVT = VecOp.getValueType();
+ EVT RdxVT = N->getValueType(0);
+ RdxVT = getPackedSVEVectorVT(RdxVT);
+ SDValue Pg = getPredicateForVector(DAG, DL, VecVT);
+ SDValue Res = DAG.getNode(
+ ISD::INTRINSIC_WO_CHAIN, DL, MVT::i64,
+ DAG.getConstant(Intrinsic::aarch64_sve_uaddv, DL, MVT::i64), Pg, VecOp);
+ EVT ResVT = MVT::i64;
+ if (ResVT != N->getValueType(0))
+ Res = DAG.getAnyExtOrTrunc(Res, DL, N->getValueType(0));
+ return Res;
+ }
+
+ SmallVector<SDValue, 4> SplitVals;
+ SmallVector<SDValue, 4> PrevVals;
+ PrevVals.push_back(VecOp);
+ while (true) {
+
+ if (!VecVT.isScalableVector() &&
+ !PrevVals[0].getValueType().getVectorElementCount().isKnownEven())
+ return SDValue();
+
+ for (SDValue Vec : PrevVals) {
+ SDValue Lo, Hi;
+ std::tie(Lo, Hi) = DAG.SplitVector(Vec, DL);
+ SplitVals.push_back(Lo);
+ SplitVals.push_back(Hi);
+ }
+ if (TLI.isTypeLegal(SplitVals[0].getValueType()))
+ break;
+ PrevVals.clear();
+ std::copy(SplitVals.begin(), SplitVals.end(), std::back_inserter(PrevVals));
+ SplitVals.clear();
+ }
+ SDNode *VecRed = N;
+ EVT ElemType = VecRed->getValueType(0);
+ SmallVector<SDValue, 4> Results;
+
+ if (!IsScalableType &&
+ !TLI.useSVEForFixedLengthVectorVT(
+ SplitVals[0].getValueType(),
+ /*OverrideNEON=*/Subtarget.useSVEForFixedLengthVectors(
+ SplitVals[0].getValueType())))
+ return SDValue();
+
+ for (unsigned Num = 0; Num < SplitVals.size(); ++Num) {
----------------
dtemirbulatov wrote:
Done.
https://github.com/llvm/llvm-project/pull/97339
More information about the llvm-commits mailing list