[llvm] [LV] Fix the cost of min/max reductions. (PR #98453)

Mel Chen via llvm-commits llvm-commits at lists.llvm.org
Thu Jul 11 02:30:59 PDT 2024


https://github.com/Mel-Chen created https://github.com/llvm/llvm-project/pull/98453

This patch updates `getReductionPatternCost` to compute the base cost of min/max reductions with `TTI.getMinMaxReductionCost` instead of `TTI.getArithmeticReductionCost`.
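
For readers without the LLVM tree at hand, the dispatch can be sketched in isolation. This is only an illustration of the control flow, not the code in the patch: `MockTTI`, its cost values, `getReductionBaseCost`, and the trimmed `RecurKind` enum are hypothetical stand-ins and do not reflect any real target's cost model.

```cpp
// Simplified stand-in for llvm::RecurKind (assumption: only a few kinds shown).
enum class RecurKind { Add, FMulAdd, SMin, SMax, UMin, UMax, FMin, FMax };

// Mirrors the role of RecurrenceDescriptor::isMinMaxRecurrenceKind for the
// trimmed enum above.
constexpr bool isMinMaxRecurrenceKind(RecurKind RK) {
  return RK == RecurKind::SMin || RK == RecurKind::SMax ||
         RK == RecurKind::UMin || RK == RecurKind::UMax ||
         RK == RecurKind::FMin || RK == RecurKind::FMax;
}

// Hypothetical target costs; the numbers are made up for illustration.
struct MockTTI {
  int ArithmeticReductionCost = 4;
  int MinMaxReductionCost = 2;
  int FMulCost = 1;
};

// Min/max kinds use the dedicated min/max reduction cost. Every other kind
// uses the arithmetic reduction cost, and FMulAdd additionally pays for one
// fmul on top of the fadd reduction.
constexpr int getReductionBaseCost(RecurKind RK, MockTTI TTI) {
  return isMinMaxRecurrenceKind(RK)
             ? TTI.MinMaxReductionCost
             : TTI.ArithmeticReductionCost +
                   (RK == RecurKind::FMulAdd ? TTI.FMulCost : 0);
}
```

With these made-up numbers, `getReductionBaseCost(RecurKind::SMax, MockTTI{})` takes the min/max path, while `RecurKind::Add` and `RecurKind::FMulAdd` take the arithmetic path.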

From d266f479ee1e1f3c9021b227fac3c568d16ee901 Mon Sep 17 00:00:00 2001
From: Mel Chen <mel.chen at sifive.com>
Date: Thu, 11 Jul 2024 01:47:26 -0700
Subject: [PATCH] [LV] Fix the cost of in-loop minmax reduction.

---
 llvm/lib/Transforms/Vectorize/LoopVectorize.cpp | 14 +++++++++++---
 1 file changed, 11 insertions(+), 3 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index ed63453f1e357..c955e192e3a46 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -6138,12 +6138,20 @@ LoopVectorizationCostModel::getReductionPatternCost(
   const RecurrenceDescriptor &RdxDesc =
       Legal->getReductionVars().find(cast<PHINode>(ReductionPhi))->second;
 
-  InstructionCost BaseCost = TTI.getArithmeticReductionCost(
-      RdxDesc.getOpcode(), VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+  InstructionCost BaseCost;
+  RecurKind RK = RdxDesc.getRecurrenceKind();
+  if (RecurrenceDescriptor::isMinMaxRecurrenceKind(RK)) {
+    Intrinsic::ID MinMaxID = getMinMaxReductionIntrinsicOp(RK);
+    BaseCost = TTI.getMinMaxReductionCost(MinMaxID, VectorTy,
+                                          RdxDesc.getFastMathFlags(), CostKind);
+  } else {
+    BaseCost = TTI.getArithmeticReductionCost(
+        RdxDesc.getOpcode(), VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+  }
 
   // For a call to the llvm.fmuladd intrinsic we need to add the cost of a
   // normal fmul instruction to the cost of the fadd reduction.
-  if (RdxDesc.getRecurrenceKind() == RecurKind::FMulAdd)
+  if (RK == RecurKind::FMulAdd)
     BaseCost +=
         TTI.getArithmeticInstrCost(Instruction::FMul, VectorTy, CostKind);
 


