[llvm] [AArch64][SVE] Improve code quality of vector unsigned/signed add reductions. (PR #97339)

Dinar Temirbulatov via llvm-commits llvm-commits at lists.llvm.org
Wed Jul 17 07:34:07 PDT 2024


================
@@ -17455,6 +17456,77 @@ static SDValue performVecReduceAddCombineWithUADDLP(SDNode *N,
   return DAG.getNode(ISD::VECREDUCE_ADD, DL, MVT::i32, UADDLP);
 }
 
+// Turn [sign|zero]_extend(vecreduce_add()) into SVE's  SADDV|UADDV
+// instructions.
+static SDValue
+performVecReduceAddExtCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
+                              const AArch64TargetLowering &TLI) {
+  if (N->getOperand(0).getOpcode() != ISD::ZERO_EXTEND &&
+      N->getOperand(0).getOpcode() != ISD::SIGN_EXTEND)
+    return SDValue();
+  bool IsSigned = N->getOperand(0).getOpcode() == ISD::SIGN_EXTEND;
+
+  SelectionDAG &DAG = DCI.DAG;
+  auto &Subtarget = DAG.getSubtarget<AArch64Subtarget>();
+  SDValue VecOp = N->getOperand(0).getOperand(0);
+  SDLoc DL(N);
+
+  bool IsScalableType = VecOp.getValueType().isScalableVector();
+  std::deque<SDValue> ResultValues;
+  ResultValues.push_back(VecOp);
+
+  // Split the input vectors if not legal.
+  while (!TLI.isTypeLegal(ResultValues.front().getValueType())) {
+    if (!ResultValues.front()
+             .getValueType()
+             .getVectorElementCount()
+             .isKnownEven())
+      return SDValue();
+    EVT CurVT = ResultValues.front().getValueType();
+    while (true) {
+      SDValue Vec = ResultValues.front();
+      if (Vec.getValueType() != CurVT)
+        break;
+      ResultValues.pop_front();
+      SDValue Lo, Hi;
+      std::tie(Lo, Hi) = DAG.SplitVector(Vec, DL);
+      ResultValues.push_back(Lo);
+      ResultValues.push_back(Hi);
+    }
+  }
+
+  EVT ElemType = N->getValueType(0);
+  SmallVector<SDValue, 2> Results;
+  if (!IsScalableType &&
+      !TLI.useSVEForFixedLengthVectorVT(
+          ResultValues[0].getValueType(),
+          /*OverrideNEON=*/Subtarget.useSVEForFixedLengthVectors(
+              ResultValues[0].getValueType())))
----------------
dtemirbulatov wrote:

Done.

https://github.com/llvm/llvm-project/pull/97339


More information about the llvm-commits mailing list