[llvm] [VectorCombine] Fold binary op of reductions. (PR #121567)

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Wed Feb 12 04:39:50 PST 2025


================
@@ -1182,6 +1183,135 @@ bool VectorCombine::foldExtractedCmps(Instruction &I) {
   return true;
 }
 
+/// Estimate the costs involved in the single vector reduction intrinsic \p II.
+///
+/// On return, \p CostAfterReduction holds the cost of the reduction itself.
+/// That cost is, where a pattern is recognised, folded with the instructions
+/// feeding it: either an extended reduction for reduce(ext(X)), or a
+/// multiply-accumulate reduction for reduce.add(ext(mul(ext(A), ext(B)))).
+/// When such a fold is recognised, \p CostBeforeReduction receives the cost of
+/// the folded operand instructions (the extends/mul) as standalone vector ops.
+/// Otherwise \p CostBeforeReduction is left untouched, so callers should
+/// initialise it.
+static void analyzeCostOfVecReduction(const IntrinsicInst &II,
+                                      TTI::TargetCostKind CostKind,
+                                      const TargetTransformInfo &TTI,
+                                      InstructionCost &CostBeforeReduction,
+                                      InstructionCost &CostAfterReduction) {
+  using namespace llvm::PatternMatch;
+  Instruction *Op0 = nullptr, *Op1 = nullptr;
+  auto *RedOp = dyn_cast<Instruction>(II.getOperand(0));
+  auto *VecRedTy = cast<VectorType>(II.getOperand(0)->getType());
+  unsigned ReductionOpc =
+      getArithmeticReductionInstruction(II.getIntrinsicID());
+
+  // The multiply-accumulate pattern must be tested before the generic
+  // reduce(ext(X)) pattern below: every ext(mul(...)) also matches
+  // m_ZExtOrSExt(m_Value()), so checking it second would make it dead code.
+  if (RedOp && II.getIntrinsicID() == Intrinsic::vector_reduce_add &&
+      match(RedOp,
+            m_ZExtOrSExt(m_Mul(m_Instruction(Op0), m_Instruction(Op1)))) &&
+      match(Op0, m_ZExtOrSExt(m_Value())) &&
+      Op0->getOpcode() == Op1->getOpcode() &&
+      Op0->getOperand(0)->getType() == Op1->getOperand(0)->getType() &&
+      (Op0->getOpcode() == RedOp->getOpcode() || Op0 == Op1)) {
+    // Matched reduce.add(ext(mul(ext(A), ext(B)))
+    bool IsUnsigned = isa<ZExtInst>(Op0);
+    auto *ExtType = cast<VectorType>(Op0->getOperand(0)->getType());
+    // The inner extends already produce the vector type the mul operates on.
+    // Passing Op0's (vector) type as an element type to VectorType::get would
+    // be invalid, so use it directly.
+    auto *MulType = cast<VectorType>(Op0->getType());
+
+    InstructionCost ExtCost =
+        TTI.getCastInstrCost(Op0->getOpcode(), MulType, ExtType,
+                             TTI::CastContextHint::None, CostKind, Op0);
+    InstructionCost MulCost =
+        TTI.getArithmeticInstrCost(Instruction::Mul, MulType, CostKind);
+    InstructionCost Ext2Cost =
+        TTI.getCastInstrCost(RedOp->getOpcode(), VecRedTy, MulType,
+                             TTI::CastContextHint::None, CostKind, RedOp);
+
+    CostBeforeReduction = ExtCost * 2 + MulCost + Ext2Cost;
+    CostAfterReduction =
+        TTI.getMulAccReductionCost(IsUnsigned, II.getType(), ExtType, CostKind);
+    return;
+  }
+
+  // Matched reduce(ext(X)): the extend can be folded into the reduction.
+  if (RedOp && match(RedOp, m_ZExtOrSExt(m_Value()))) {
+    bool IsUnsigned = isa<ZExtInst>(RedOp);
+    auto *ExtType = cast<VectorType>(RedOp->getOperand(0)->getType());
+
+    CostBeforeReduction =
+        TTI.getCastInstrCost(RedOp->getOpcode(), VecRedTy, ExtType,
+                             TTI::CastContextHint::None, CostKind, RedOp);
+    CostAfterReduction =
+        TTI.getExtendedReductionCost(ReductionOpc, IsUnsigned, II.getType(),
+                                     ExtType, FastMathFlags(), CostKind);
+    return;
+  }
+
+  // No foldable operand pattern: cost a plain arithmetic reduction.
+  CostAfterReduction = TTI.getArithmeticReductionCost(ReductionOpc, VecRedTy,
+                                                      std::nullopt, CostKind);
+}
+
+bool VectorCombine::foldBinopOfReductions(Instruction &I) {
+  Instruction::BinaryOps BinOpOpc = cast<BinaryOperator>(&I)->getOpcode();
+  Intrinsic::ID ReductionIID = getReductionForBinop(BinOpOpc);
+  if (BinOpOpc == Instruction::Sub)
+    ReductionIID = Intrinsic::vector_reduce_add;
+  if (ReductionIID == Intrinsic::not_intrinsic)
+    return false;
+
+  auto checkIntrinsicAndGetItsArgument = [](Value *V,
+                                            Intrinsic::ID IID) -> Value * {
+    IntrinsicInst *II = dyn_cast<IntrinsicInst>(V);
+    if (!II)
+      return nullptr;
+    if (II->getIntrinsicID() == IID && II->hasOneUse())
+      return II->getArgOperand(0);
+    return nullptr;
+  };
+
+  Value *V0 = checkIntrinsicAndGetItsArgument(I.getOperand(0), ReductionIID);
+  if (!V0)
+    return false;
+  Value *V1 = checkIntrinsicAndGetItsArgument(I.getOperand(1), ReductionIID);
+  if (!V1)
+    return false;
+
+  VectorType *VTy = cast<VectorType>(V0->getType());
+  if (V1->getType() != VTy)
+    return false;
+  const IntrinsicInst &II0 = *cast<IntrinsicInst>(I.getOperand(0));
+  const IntrinsicInst &II1 = *cast<IntrinsicInst>(I.getOperand(1));
+  unsigned ReductionOpc =
+      getArithmeticReductionInstruction(II0.getIntrinsicID());
+
+  InstructionCost OldCost = 0;
+  InstructionCost NewCost = 0;
+  InstructionCost CostOfRedOperand0 = 0;
+  InstructionCost CostOfRed0 = 0;
+  InstructionCost CostOfRedOperand1 = 0;
+  InstructionCost CostOfRed1 = 0;
+  analyzeCostOfVecReduction(II0, CostKind, TTI, CostOfRedOperand0, CostOfRed0);
+  analyzeCostOfVecReduction(II1, CostKind, TTI, CostOfRedOperand1, CostOfRed1);
+  OldCost = CostOfRed0 + CostOfRed1 + TTI.getInstructionCost(&I, CostKind);
+  NewCost =
+      CostOfRedOperand0 + CostOfRedOperand1 +
+      TTI.getArithmeticInstrCost(BinOpOpc, VTy, CostKind) +
+      TTI.getArithmeticReductionCost(ReductionOpc, VTy, std::nullopt, CostKind);
+  // TODO: remove this
----------------
RKSimon wrote:

Please clean up the TODO and the debug output; we usually just do something like:
```
  LLVM_DEBUG(dbgs() << "Found mergeable reductions: " << I
                    << "\n  OldCost: " << OldCost << " vs NewCost: " << NewCost
                    << "\n");
```

https://github.com/llvm/llvm-project/pull/121567


More information about the llvm-commits mailing list