[llvm] [LV] Fix the cost of min/max reductions. (PR #98453)

via llvm-commits llvm-commits at lists.llvm.org
Thu Jul 11 02:31:34 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-transforms

Author: Mel Chen (Mel-Chen)

<details>
<summary>Changes</summary>

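This patch updates `getReductionPatternCost` so that the base cost of a min/max reduction is computed with `TTI.getMinMaxReductionCost` instead of `TTI.getArithmeticReductionCost`. All other recurrence kinds keep using `TTI.getArithmeticReductionCost`.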
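The dispatch can be illustrated with a small standalone sketch. Everything in it is a toy stand-in, not LLVM code: `ToyTTI`, its three cost methods, the trimmed-down `RecurKind` enum, and the cost numbers are hypothetical and only model which query is chosen. The real TTI hooks additionally take the intrinsic ID or opcode, the vector type, the fast-math flags, and the cost kind, as the diff shows.

```cpp
// Toy model of the cost dispatch; none of these types are LLVM's.
enum class RecurKind { Add, FMulAdd, SMin, SMax, UMin, UMax, FMin, FMax };

// Hypothetical mirror of RecurrenceDescriptor::isMinMaxRecurrenceKind,
// restricted to the kinds listed in the toy enum above.
inline bool isMinMaxRecurrenceKind(RecurKind RK) {
  switch (RK) {
  case RecurKind::SMin:
  case RecurKind::SMax:
  case RecurKind::UMin:
  case RecurKind::UMax:
  case RecurKind::FMin:
  case RecurKind::FMax:
    return true;
  default:
    return false;
  }
}

// Hypothetical target cost table; the numbers are made up for illustration.
struct ToyTTI {
  int getArithmeticReductionCost() const { return 4; }
  int getMinMaxReductionCost() const { return 2; }
  int getFMulCost() const { return 1; }
};

// Selects the base reduction cost the way the patched code does:
// min/max kinds use the dedicated min/max query, everything else uses
// the arithmetic query, and FMulAdd also pays for one fmul.
inline int baseReductionCost(const ToyTTI &TTI, RecurKind RK) {
  int BaseCost = isMinMaxRecurrenceKind(RK)
                     ? TTI.getMinMaxReductionCost()
                     : TTI.getArithmeticReductionCost();
  if (RK == RecurKind::FMulAdd)
    BaseCost += TTI.getFMulCost();
  return BaseCost;
}
```

With this toy table, `baseReductionCost(ToyTTI{}, RecurKind::SMax)` goes through the min/max query, while `RecurKind::Add` and `RecurKind::FMulAdd` still go through the arithmetic query.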

---
Full diff: https://github.com/llvm/llvm-project/pull/98453.diff


1 File Affected:

- (modified) llvm/lib/Transforms/Vectorize/LoopVectorize.cpp (+11-3) 


``````````diff
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index ed63453f1e357..c955e192e3a46 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -6138,12 +6138,20 @@ LoopVectorizationCostModel::getReductionPatternCost(
   const RecurrenceDescriptor &RdxDesc =
       Legal->getReductionVars().find(cast<PHINode>(ReductionPhi))->second;
 
-  InstructionCost BaseCost = TTI.getArithmeticReductionCost(
-      RdxDesc.getOpcode(), VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+  InstructionCost BaseCost;
+  RecurKind RK = RdxDesc.getRecurrenceKind();
+  if (RecurrenceDescriptor::isMinMaxRecurrenceKind(RK)) {
+    Intrinsic::ID MinMaxID = getMinMaxReductionIntrinsicOp(RK);
+    BaseCost = TTI.getMinMaxReductionCost(MinMaxID, VectorTy,
+                                          RdxDesc.getFastMathFlags(), CostKind);
+  } else {
+    BaseCost = TTI.getArithmeticReductionCost(
+        RdxDesc.getOpcode(), VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+  }
 
   // For a call to the llvm.fmuladd intrinsic we need to add the cost of a
   // normal fmul instruction to the cost of the fadd reduction.
-  if (RdxDesc.getRecurrenceKind() == RecurKind::FMulAdd)
+  if (RK == RecurKind::FMulAdd)
     BaseCost +=
         TTI.getArithmeticInstrCost(Instruction::FMul, VectorTy, CostKind);
 

``````````

</details>


https://github.com/llvm/llvm-project/pull/98453

