[llvm] [RISCV][CostModel] Updates reduction and shuffle cost (PR #77342)

Shih-Po Hung via llvm-commits llvm-commits at lists.llvm.org
Thu Jan 11 09:27:00 PST 2024


================
@@ -1392,20 +1426,58 @@ RISCVTTIImpl::getArithmeticReductionCost(unsigned Opcode, VectorType *Ty,
     return BaseT::getArithmeticReductionCost(Opcode, Ty, FMF, CostKind);
 
   std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(Ty);
-  if (Ty->getElementType()->isIntegerTy(1))
+  SmallVector<unsigned, 3> Opcodes;
+  Type *ElementTy = Ty->getElementType();
+  if (ElementTy->isIntegerTy(1)) {
     // vcpop sequences, see vreduction-mask.ll
-    return (LT.first - 1) + (ISD == ISD::AND ? 3 : 2);
+    if (ISD == ISD::AND) {
+      Opcodes = {RISCV::VMNAND_MM, RISCV::VCPOP_M};
+      return (LT.first - 1) +
+             getRISCVInstructionCost(Opcodes, LT.second, CostKind) +
+             getCmpSelInstrCost(Instruction::Select, ElementTy, ElementTy,
+                                CmpInst::ICMP_EQ, CostKind);
+    } else {
+      Opcodes = {RISCV::VCPOP_M};
+      return (LT.first - 1) +
+             getRISCVInstructionCost(Opcodes, LT.second, CostKind) +
+             getCmpSelInstrCost(Instruction::Select, ElementTy, ElementTy,
+                                CmpInst::ICMP_NE, CostKind);
+    }
+  }
 
   // IR Reduction is composed by two vmv and one rvv reduction instruction.
-  InstructionCost BaseCost = 2;
-
-  if (CostKind == TTI::TCK_CodeSize)
-    return (LT.first - 1) + BaseCost;
-
-  unsigned VL = getEstimatedVLFor(Ty);
-  if (TTI::requiresOrderedReduction(FMF))
-    return (LT.first - 1) + BaseCost + VL;
-  return (LT.first - 1) + BaseCost + Log2_32_Ceil(VL);
+  unsigned SplitOp;
+  switch (ISD) {
+  case ISD::ADD:
+    SplitOp = RISCV::VADD_VV;
+    Opcodes = {RISCV::VMV_S_X, RISCV::VREDSUM_VS, RISCV::VMV_X_S};
+    break;
+  case ISD::OR:
+    SplitOp = RISCV::VOR_VV;
+    Opcodes = {RISCV::VMV_S_X, RISCV::VREDOR_VS, RISCV::VMV_X_S};
+    break;
+  case ISD::XOR:
+    SplitOp = RISCV::VXOR_VV;
+    Opcodes = {RISCV::VMV_S_X, RISCV::VREDXOR_VS, RISCV::VMV_X_S};
+    break;
+  case ISD::AND:
+    SplitOp = RISCV::VAND_VV;
+    Opcodes = {RISCV::VMV_S_X, RISCV::VREDAND_VS, RISCV::VMV_X_S};
+    break;
+  case ISD::FADD:
+    SplitOp = RISCV::VFADD_VV;
+    if (TTI::requiresOrderedReduction(FMF))
+      Opcodes = {RISCV::VFMV_S_F, RISCV::VFREDOSUM_VS, RISCV::VFMV_F_S};
----------------
arcbbb wrote:

Thanks for pointing this out! Fixed.

https://github.com/llvm/llvm-project/pull/77342


More information about the llvm-commits mailing list