[llvm] [AArch64][CostModel] Reduce the cost of fadd reduction with fast flag (PR #108791)

Sushant Gokhale via llvm-commits llvm-commits at lists.llvm.org
Tue Sep 24 01:13:12 PDT 2024


================
@@ -4153,6 +4153,47 @@ AArch64TTIImpl::getArithmeticReductionCost(unsigned Opcode, VectorType *ValTy,
   switch (ISD) {
   default:
     break;
+  case ISD::FADD:
+    if (Type *EltTy = ValTy->getScalarType();
+        // FIXME: We restrict the input scalar type to the following types
+        // because codegen might differ for some other types, e.g. fp128.
+        // Also, for half types without fullfp16 support, the cost may still
+        // be higher than what is expected from codegen.
+        MTy.isVector() && (EltTy->isFloatTy() || EltTy->isDoubleTy() ||
+                           (EltTy->isHalfTy() && ST->hasFullFP16()))) {
+      const unsigned NElts = MTy.getVectorNumElements();
+      if (ValTy->getElementCount().getFixedValue() >= 2 && NElts >= 2 &&
+          isPowerOf2_32(NElts))
+        // A reduction corresponding to a series of fadd instructions is
+        // lowered to a series of faddp instructions. faddp has the same
+        // latency/throughput as fadd, so every faddp instruction can be
+        // considered to have a relative cost of 1 with
+        // CostKind = TCK_RecipThroughput.
+        //
+        // The semantics of faddp: it concatenates the first vector after the
+        // second vector and then performs pairwise addition. So, each time a
+        // pairwise addition is performed, the size of the input vector to the
+        // reduction is halved. E.g. the first step of reducing v0 looks like
+        //        v0                  v0
+        // ----------------   -----------------
+        // |3 | 2 | 1 | 0 |   | 3 | 2 | 1 | 0 |
+        // ----------------   -----------------
+        //
+        // Step 1: faddp v1.4h, v0.4h, v0.4h
+        //
+        //                       ---------------------------------
+        // concatenated vector   | 3 | 2 | 1 | 0 | 3 | 2 | 1 | 0 |
+        //                       ---------------------------------
+        //
+        //                       -------------------------
+        // pairwise addition     | 2+3 | 0+1 | 2+3 | 0+1 |
+        //                       -------------------------
+        // The next faddp gives the result of 0+1+2+3.
+        // Since the size of the input vector is halved every time,
+        // #(faddp instructions) = Log2_32(NElts).
----------------
sushgokh wrote:

sure, thanks.

https://github.com/llvm/llvm-project/pull/108791

