[llvm] [AArch64][CostModel] Reduce the cost of fadd reduction with fast flag (PR #108791)
David Green via llvm-commits
llvm-commits at lists.llvm.org
Tue Sep 24 00:19:22 PDT 2024
================
@@ -4153,6 +4153,47 @@ AArch64TTIImpl::getArithmeticReductionCost(unsigned Opcode, VectorType *ValTy,
switch (ISD) {
default:
break;
+ case ISD::FADD:
+ if (Type *EltTy = ValTy->getScalarType();
+ // FIXME: We would be restricting the input scalar type to following
+ // types since for some of the types, codegen might be different e.g.
+ // fp128. Also, for half types without fullfp16 support, the cost maybe
+ // still be higher than what is expected from codegen.
+ MTy.isVector() && (EltTy->isFloatTy() || EltTy->isDoubleTy() ||
+ (EltTy->isHalfTy() && ST->hasFullFP16()))) {
+ const unsigned NElts = MTy.getVectorNumElements();
+ if (ValTy->getElementCount().getFixedValue() >= 2 && NElts >= 2 &&
+ isPowerOf2_32(NElts))
+ // Reduction corresponding to series of fadd instructions is lowered to
+ // series of faddp instructions. faddp has latency/throughput that
+ // matches fadd instruction and hence, every faddp instruction can be
+ // considered to have a relative cost = 1 with
+ // CostKind = TCK_RecipThroughput.
+ //
+ // Semantics of faddp is it concatenates first vector after second
+ // vector and then does pairwise addition. So, every time pairwise
+ // addition is performed, size of input vector to reduction reduces to
+ // half. e.g. 1st step of reducing v0 can be depicted as
+ // v0 v0
+ // ---------------- -----------------
+ // |3 | 2 | 1 | 0 | | 3 | 2 | 1 | 0 |
+ // ---------------- -----------------
+ //
+ // Step 1: faddp v1.4h, v0.4h, v0.4h
+ //
+ // ---------------------------------
+ // concatenated vector | 3 | 2 | 1 | 0 | 3 | 2 | 1 | 0 |
+ // ---------------------------------
+ //
+ // -------------------------
+ // pairwise addition | 2+3 | 0+1 | 2+3 | 0+1 |
+ // -------------------------
+ // Next faddp would give us the result of 0+1+2+3.
+ // Since, the size of input vector reduces by half every time,
+ // #(faddp instructions) = log2_32(NElts)
----------------
davemgreen wrote:
I'm not sure explaining how faddp works here is super useful, and the comment is quite verbose with all the ASCII art. Perhaps simplify it to something like `An faddp will pairwise add vector elements, so the size of the input vector reduces by half every time, requiring #(faddp instructions) = log2_32(NElts).`
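
As a quick sanity check of the log2 count, here is a minimal standalone sketch. It is not LLVM code: `pairwiseAdd` and `countPairwiseSteps` are hypothetical helpers that model only the useful half of each faddp result, and the power-of-two length is an assumption matching the `isPowerOf2_32(NElts)` guard in the patch.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper (not an LLVM API): one pairwise-add step over a single
// input vector. It keeps only the unique sums, so the element count halves.
static std::vector<float> pairwiseAdd(const std::vector<float> &V) {
  std::vector<float> Out;
  for (std::size_t I = 0; I + 1 < V.size(); I += 2)
    Out.push_back(V[I] + V[I + 1]);
  return Out;
}

// Hypothetical helper: reduce a power-of-two-length vector with repeated
// pairwise adds, returning the number of steps and writing the total to Sum.
static unsigned countPairwiseSteps(std::vector<float> V, float &Sum) {
  // Assumption: non-empty, power-of-two length.
  assert(!V.empty() && (V.size() & (V.size() - 1)) == 0);
  unsigned Steps = 0;
  while (V.size() > 1) {
    V = pairwiseAdd(V);
    ++Steps;
  }
  Sum = V.front();
  return Steps;
}
```

For the four-element input {0, 1, 2, 3} this takes two steps ({1, 5}, then {6}), which matches log2(4) = 2.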
https://github.com/llvm/llvm-project/pull/108791