[llvm] [VPlan] Impl VPlan-based pattern match for ExtendedRed and MulAccRed (NFCI) (PR #113903)

Elvis Wang via llvm-commits llvm-commits at lists.llvm.org
Mon Oct 28 05:56:49 PDT 2024


https://github.com/ElvisWang123 created https://github.com/llvm/llvm-project/pull/113903

This patch implements VPlan-based pattern matching for extended reductions and mul-accumulate (MulAcc) reductions. In the reduction patterns below, the extend and mul instructions can be folded into the reduction instruction, so they are not costed separately.

We add `FoldedRecipes` to `VPCostContext` to record recipes that are folded into other recipes.

ExtendedReductionPatterns:
    reduce(ext(...))
MulAccReductionPatterns:
    reduce.add(mul(...))
    reduce.add(mul(ext(...), ext(...)))
    reduce.add(ext(mul(...)))
    reduce.add(ext(mul(ext(...), ext(...))))

Ref: Original instruction based implementation:
https://reviews.llvm.org/D93476

This patch is based on #113902.

>From 35bdec7786d3c881c8989e0a8c459fe8183d6d59 Mon Sep 17 00:00:00 2001
From: Elvis Wang <elvis.wang at sifive.com>
Date: Mon, 28 Oct 2024 01:40:45 -0700
Subject: [PATCH 1/2] [LV] Reverse recipes cost calculation.

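This patch reverses the order in which VPRecipe costs are accumulated in a
VPBasicBlock, so that, within a block, a recipe is costed before the
recipes defining its operands. Follow-up changes rely on this to mark
operands that are folded into a user as free before those operands are
costed.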
---
 llvm/lib/Transforms/Vectorize/VPlan.cpp       |  2 +-
 .../LoopVectorize/RISCV/interleaved-cost.ll   | 72 +++++++++----------
 2 files changed, 37 insertions(+), 37 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/VPlan.cpp b/llvm/lib/Transforms/Vectorize/VPlan.cpp
index 6ab8fb45c351b4..49e93e1e7b5501 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlan.cpp
@@ -785,7 +785,10 @@ void VPRegionBlock::execute(VPTransformState *State) {
 
 InstructionCost VPBasicBlock::cost(ElementCount VF, VPCostContext &Ctx) {
   InstructionCost Cost = 0;
-  for (VPRecipeBase &R : Recipes)
+  // Cost recipes in reverse order so that, within this block, a recipe is
+  // visited before the recipes defining its operands. This lets a recipe
+  // account for operands folded into it before those are costed themselves.
+  for (VPRecipeBase &R : reverse(Recipes))
     Cost += R.cost(VF, Ctx);
   return Cost;
 }
diff --git a/llvm/test/Transforms/LoopVectorize/RISCV/interleaved-cost.ll b/llvm/test/Transforms/LoopVectorize/RISCV/interleaved-cost.ll
index fa346b4eac02d4..f2e36399c85f5d 100644
--- a/llvm/test/Transforms/LoopVectorize/RISCV/interleaved-cost.ll
+++ b/llvm/test/Transforms/LoopVectorize/RISCV/interleaved-cost.ll
@@ -6,26 +6,26 @@ define void @i8_factor_2(ptr %data, i64 %n) {
 entry:
   br label %for.body
 ; CHECK-LABEL: Checking a loop in 'i8_factor_2'
-; CHECK: Cost of 2 for VF 2: INTERLEAVE-GROUP with factor 2 at %l0, ir<%p0>
 ; CHECK: Cost of 2 for VF 2: INTERLEAVE-GROUP with factor 2 at <badref>, ir<%p0>
-; CHECK: Cost of 2 for VF 4: INTERLEAVE-GROUP with factor 2 at %l0, ir<%p0>
+; CHECK: Cost of 2 for VF 2: INTERLEAVE-GROUP with factor 2 at %l0, ir<%p0>
 ; CHECK: Cost of 2 for VF 4: INTERLEAVE-GROUP with factor 2 at <badref>, ir<%p0>
-; CHECK: Cost of 2 for VF 8: INTERLEAVE-GROUP with factor 2 at %l0, ir<%p0>
+; CHECK: Cost of 2 for VF 4: INTERLEAVE-GROUP with factor 2 at %l0, ir<%p0>
 ; CHECK: Cost of 2 for VF 8: INTERLEAVE-GROUP with factor 2 at <badref>, ir<%p0>
-; CHECK: Cost of 3 for VF 16: INTERLEAVE-GROUP with factor 2 at %l0, ir<%p0>
+; CHECK: Cost of 2 for VF 8: INTERLEAVE-GROUP with factor 2 at %l0, ir<%p0>
 ; CHECK: Cost of 3 for VF 16: INTERLEAVE-GROUP with factor 2 at <badref>, ir<%p0>
-; CHECK: Cost of 5 for VF 32: INTERLEAVE-GROUP with factor 2 at %l0, ir<%p0>
+; CHECK: Cost of 3 for VF 16: INTERLEAVE-GROUP with factor 2 at %l0, ir<%p0>
 ; CHECK: Cost of 5 for VF 32: INTERLEAVE-GROUP with factor 2 at <badref>, ir<%p0>
-; CHECK: Cost of 2 for VF vscale x 1: INTERLEAVE-GROUP with factor 2 at %l0, ir<%p0>
+; CHECK: Cost of 5 for VF 32: INTERLEAVE-GROUP with factor 2 at %l0, ir<%p0>
 ; CHECK: Cost of 2 for VF vscale x 1: INTERLEAVE-GROUP with factor 2 at <badref>, ir<%p0>
-; CHECK: Cost of 2 for VF vscale x 2: INTERLEAVE-GROUP with factor 2 at %l0, ir<%p0>
+; CHECK: Cost of 2 for VF vscale x 1: INTERLEAVE-GROUP with factor 2 at %l0, ir<%p0>
 ; CHECK: Cost of 2 for VF vscale x 2: INTERLEAVE-GROUP with factor 2 at <badref>, ir<%p0>
-; CHECK: Cost of 2 for VF vscale x 4: INTERLEAVE-GROUP with factor 2 at %l0, ir<%p0>
+; CHECK: Cost of 2 for VF vscale x 2: INTERLEAVE-GROUP with factor 2 at %l0, ir<%p0>
 ; CHECK: Cost of 2 for VF vscale x 4: INTERLEAVE-GROUP with factor 2 at <badref>, ir<%p0>
-; CHECK: Cost of 3 for VF vscale x 8: INTERLEAVE-GROUP with factor 2 at %l0, ir<%p0>
+; CHECK: Cost of 2 for VF vscale x 4: INTERLEAVE-GROUP with factor 2 at %l0, ir<%p0>
 ; CHECK: Cost of 3 for VF vscale x 8: INTERLEAVE-GROUP with factor 2 at <badref>, ir<%p0>
-; CHECK: Cost of 5 for VF vscale x 16: INTERLEAVE-GROUP with factor 2 at %l0, ir<%p0>
+; CHECK: Cost of 3 for VF vscale x 8: INTERLEAVE-GROUP with factor 2 at %l0, ir<%p0>
 ; CHECK: Cost of 5 for VF vscale x 16: INTERLEAVE-GROUP with factor 2 at <badref>, ir<%p0>
+; CHECK: Cost of 5 for VF vscale x 16: INTERLEAVE-GROUP with factor 2 at %l0, ir<%p0>
 for.body:
   %i = phi i64 [ 0, %entry ], [ %i.next, %for.body ]
   %p0 = getelementptr inbounds %i8.2, ptr %data, i64 %i, i32 0
@@ -49,16 +49,16 @@ define void @i8_factor_3(ptr %data, i64 %n) {
 entry:
   br label %for.body
 ; CHECK-LABEL: Checking a loop in 'i8_factor_3'
-; CHECK: Cost of 2 for VF 2: INTERLEAVE-GROUP with factor 3 at %l0, ir<%p0>
 ; CHECK: Cost of 2 for VF 2: INTERLEAVE-GROUP with factor 3 at <badref>, ir<%p0>
-; CHECK: Cost of 2 for VF 4: INTERLEAVE-GROUP with factor 3 at %l0, ir<%p0>
+; CHECK: Cost of 2 for VF 2: INTERLEAVE-GROUP with factor 3 at %l0, ir<%p0>
 ; CHECK: Cost of 2 for VF 4: INTERLEAVE-GROUP with factor 3 at <badref>, ir<%p0>
-; CHECK: Cost of 3 for VF 8: INTERLEAVE-GROUP with factor 3 at %l0, ir<%p0>
+; CHECK: Cost of 2 for VF 4: INTERLEAVE-GROUP with factor 3 at %l0, ir<%p0>
 ; CHECK: Cost of 3 for VF 8: INTERLEAVE-GROUP with factor 3 at <badref>, ir<%p0>
-; CHECK: Cost of 5 for VF 16: INTERLEAVE-GROUP with factor 3 at %l0, ir<%p0>
+; CHECK: Cost of 3 for VF 8: INTERLEAVE-GROUP with factor 3 at %l0, ir<%p0>
 ; CHECK: Cost of 5 for VF 16: INTERLEAVE-GROUP with factor 3 at <badref>, ir<%p0>
-; CHECK: Cost of 9 for VF 32: INTERLEAVE-GROUP with factor 3 at %l0, ir<%p0>
+; CHECK: Cost of 5 for VF 16: INTERLEAVE-GROUP with factor 3 at %l0, ir<%p0>
 ; CHECK: Cost of 9 for VF 32: INTERLEAVE-GROUP with factor 3 at <badref>, ir<%p0>
+; CHECK: Cost of 9 for VF 32: INTERLEAVE-GROUP with factor 3 at %l0, ir<%p0>
 for.body:
   %i = phi i64 [ 0, %entry ], [ %i.next, %for.body ]
   %p0 = getelementptr inbounds %i8.3, ptr %data, i64 %i, i32 0
@@ -86,16 +86,16 @@ define void @i8_factor_4(ptr %data, i64 %n) {
 entry:
   br label %for.body
 ; CHECK-LABEL: Checking a loop in 'i8_factor_4'
-; CHECK: Cost of 2 for VF 2: INTERLEAVE-GROUP with factor 4 at %l0, ir<%p0>
 ; CHECK: Cost of 2 for VF 2: INTERLEAVE-GROUP with factor 4 at <badref>, ir<%p0>
-; CHECK: Cost of 2 for VF 4: INTERLEAVE-GROUP with factor 4 at %l0, ir<%p0>
+; CHECK: Cost of 2 for VF 2: INTERLEAVE-GROUP with factor 4 at %l0, ir<%p0>
 ; CHECK: Cost of 2 for VF 4: INTERLEAVE-GROUP with factor 4 at <badref>, ir<%p0>
-; CHECK: Cost of 3 for VF 8: INTERLEAVE-GROUP with factor 4 at %l0, ir<%p0>
+; CHECK: Cost of 2 for VF 4: INTERLEAVE-GROUP with factor 4 at %l0, ir<%p0>
 ; CHECK: Cost of 3 for VF 8: INTERLEAVE-GROUP with factor 4 at <badref>, ir<%p0>
-; CHECK: Cost of 5 for VF 16: INTERLEAVE-GROUP with factor 4 at %l0, ir<%p0>
+; CHECK: Cost of 3 for VF 8: INTERLEAVE-GROUP with factor 4 at %l0, ir<%p0>
 ; CHECK: Cost of 5 for VF 16: INTERLEAVE-GROUP with factor 4 at <badref>, ir<%p0>
-; CHECK: Cost of 9 for VF 32: INTERLEAVE-GROUP with factor 4 at %l0, ir<%p0>
+; CHECK: Cost of 5 for VF 16: INTERLEAVE-GROUP with factor 4 at %l0, ir<%p0>
 ; CHECK: Cost of 9 for VF 32: INTERLEAVE-GROUP with factor 4 at <badref>, ir<%p0>
+; CHECK: Cost of 9 for VF 32: INTERLEAVE-GROUP with factor 4 at %l0, ir<%p0>
 for.body:
   %i = phi i64 [ 0, %entry ], [ %i.next, %for.body ]
   %p0 = getelementptr inbounds %i8.4, ptr %data, i64 %i, i32 0
@@ -127,14 +127,14 @@ define void @i8_factor_5(ptr %data, i64 %n) {
 entry:
   br label %for.body
 ; CHECK-LABEL: Checking a loop in 'i8_factor_5'
-; CHECK: Cost of 2 for VF 2: INTERLEAVE-GROUP with factor 5 at %l0, ir<%p0>
 ; CHECK: Cost of 2 for VF 2: INTERLEAVE-GROUP with factor 5 at <badref>, ir<%p0>
-; CHECK: Cost of 3 for VF 4: INTERLEAVE-GROUP with factor 5 at %l0, ir<%p0>
+; CHECK: Cost of 2 for VF 2: INTERLEAVE-GROUP with factor 5 at %l0, ir<%p0>
 ; CHECK: Cost of 3 for VF 4: INTERLEAVE-GROUP with factor 5 at <badref>, ir<%p0>
-; CHECK: Cost of 5 for VF 8: INTERLEAVE-GROUP with factor 5 at %l0, ir<%p0>
+; CHECK: Cost of 3 for VF 4: INTERLEAVE-GROUP with factor 5 at %l0, ir<%p0>
 ; CHECK: Cost of 5 for VF 8: INTERLEAVE-GROUP with factor 5 at <badref>, ir<%p0>
-; CHECK: Cost of 9 for VF 16: INTERLEAVE-GROUP with factor 5 at %l0, ir<%p0>
+; CHECK: Cost of 5 for VF 8: INTERLEAVE-GROUP with factor 5 at %l0, ir<%p0>
 ; CHECK: Cost of 9 for VF 16: INTERLEAVE-GROUP with factor 5 at <badref>, ir<%p0>
+; CHECK: Cost of 9 for VF 16: INTERLEAVE-GROUP with factor 5 at %l0, ir<%p0>
 for.body:
   %i = phi i64 [ 0, %entry ], [ %i.next, %for.body ]
   %p0 = getelementptr inbounds %i8.5, ptr %data, i64 %i, i32 0
@@ -170,14 +170,14 @@ define void @i8_factor_6(ptr %data, i64 %n) {
 entry:
   br label %for.body
 ; CHECK-LABEL: Checking a loop in 'i8_factor_6'
-; CHECK: Cost of 2 for VF 2: INTERLEAVE-GROUP with factor 6 at %l0, ir<%p0>
 ; CHECK: Cost of 2 for VF 2: INTERLEAVE-GROUP with factor 6 at <badref>, ir<%p0>
-; CHECK: Cost of 3 for VF 4: INTERLEAVE-GROUP with factor 6 at %l0, ir<%p0>
+; CHECK: Cost of 2 for VF 2: INTERLEAVE-GROUP with factor 6 at %l0, ir<%p0>
 ; CHECK: Cost of 3 for VF 4: INTERLEAVE-GROUP with factor 6 at <badref>, ir<%p0>
-; CHECK: Cost of 5 for VF 8: INTERLEAVE-GROUP with factor 6 at %l0, ir<%p0>
+; CHECK: Cost of 3 for VF 4: INTERLEAVE-GROUP with factor 6 at %l0, ir<%p0>
 ; CHECK: Cost of 5 for VF 8: INTERLEAVE-GROUP with factor 6 at <badref>, ir<%p0>
-; CHECK: Cost of 9 for VF 16: INTERLEAVE-GROUP with factor 6 at %l0, ir<%p0>
+; CHECK: Cost of 5 for VF 8: INTERLEAVE-GROUP with factor 6 at %l0, ir<%p0>
 ; CHECK: Cost of 9 for VF 16: INTERLEAVE-GROUP with factor 6 at <badref>, ir<%p0>
+; CHECK: Cost of 9 for VF 16: INTERLEAVE-GROUP with factor 6 at %l0, ir<%p0>
 for.body:
   %i = phi i64 [ 0, %entry ], [ %i.next, %for.body ]
   %p0 = getelementptr inbounds %i8.6, ptr %data, i64 %i, i32 0
@@ -217,14 +217,14 @@ define void @i8_factor_7(ptr %data, i64 %n) {
 entry:
   br label %for.body
 ; CHECK-LABEL: Checking a loop in 'i8_factor_7'
-; CHECK: Cost of 2 for VF 2: INTERLEAVE-GROUP with factor 7 at %l0, ir<%p0>
 ; CHECK: Cost of 2 for VF 2: INTERLEAVE-GROUP with factor 7 at <badref>, ir<%p0>
-; CHECK: Cost of 3 for VF 4: INTERLEAVE-GROUP with factor 7 at %l0, ir<%p0>
+; CHECK: Cost of 2 for VF 2: INTERLEAVE-GROUP with factor 7 at %l0, ir<%p0>
 ; CHECK: Cost of 3 for VF 4: INTERLEAVE-GROUP with factor 7 at <badref>, ir<%p0>
-; CHECK: Cost of 5 for VF 8: INTERLEAVE-GROUP with factor 7 at %l0, ir<%p0>
+; CHECK: Cost of 3 for VF 4: INTERLEAVE-GROUP with factor 7 at %l0, ir<%p0>
 ; CHECK: Cost of 5 for VF 8: INTERLEAVE-GROUP with factor 7 at <badref>, ir<%p0>
-; CHECK: Cost of 9 for VF 16: INTERLEAVE-GROUP with factor 7 at %l0, ir<%p0>
+; CHECK: Cost of 5 for VF 8: INTERLEAVE-GROUP with factor 7 at %l0, ir<%p0>
 ; CHECK: Cost of 9 for VF 16: INTERLEAVE-GROUP with factor 7 at <badref>, ir<%p0>
+; CHECK: Cost of 9 for VF 16: INTERLEAVE-GROUP with factor 7 at %l0, ir<%p0>
 for.body:
   %i = phi i64 [ 0, %entry ], [ %i.next, %for.body ]
   %p0 = getelementptr inbounds %i8.7, ptr %data, i64 %i, i32 0
@@ -268,14 +268,14 @@ define void @i8_factor_8(ptr %data, i64 %n) {
 entry:
   br label %for.body
 ; CHECK-LABEL: Checking a loop in 'i8_factor_8'
-; CHECK: Cost of 2 for VF 2: INTERLEAVE-GROUP with factor 8 at %l0, ir<%p0>
 ; CHECK: Cost of 2 for VF 2: INTERLEAVE-GROUP with factor 8 at <badref>, ir<%p0>
-; CHECK: Cost of 3 for VF 4: INTERLEAVE-GROUP with factor 8 at %l0, ir<%p0>
+; CHECK: Cost of 2 for VF 2: INTERLEAVE-GROUP with factor 8 at %l0, ir<%p0>
 ; CHECK: Cost of 3 for VF 4: INTERLEAVE-GROUP with factor 8 at <badref>, ir<%p0>
-; CHECK: Cost of 5 for VF 8: INTERLEAVE-GROUP with factor 8 at %l0, ir<%p0>
+; CHECK: Cost of 3 for VF 4: INTERLEAVE-GROUP with factor 8 at %l0, ir<%p0>
 ; CHECK: Cost of 5 for VF 8: INTERLEAVE-GROUP with factor 8 at <badref>, ir<%p0>
-; CHECK: Cost of 9 for VF 16: INTERLEAVE-GROUP with factor 8 at %l0, ir<%p0>
+; CHECK: Cost of 5 for VF 8: INTERLEAVE-GROUP with factor 8 at %l0, ir<%p0>
 ; CHECK: Cost of 9 for VF 16: INTERLEAVE-GROUP with factor 8 at <badref>, ir<%p0>
+; CHECK: Cost of 9 for VF 16: INTERLEAVE-GROUP with factor 8 at %l0, ir<%p0>
 for.body:
   %i = phi i64 [ 0, %entry ], [ %i.next, %for.body ]
   %p0 = getelementptr inbounds %i8.8, ptr %data, i64 %i, i32 0

>From 442d1dd7269c6820daf73cf7fcc7a8de54029363 Mon Sep 17 00:00:00 2001
From: Elvis Wang <elvis.wang at sifive.com>
Date: Mon, 28 Oct 2024 05:39:35 -0700
Subject: [PATCH 2/2] [VPlan] Impl VPlan-based pattern match for ExtendedRed
 and MulAccRed. NFCI

This patch implements VPlan-based pattern matching for extended reductions
and mul-accumulate (MulAcc) reductions. In the reduction patterns below,
the extend and mul instructions can be folded into the reduction
instruction, so they are not costed separately.

We add `FoldedRecipes` to `VPCostContext` to record recipes that are
folded into other recipes.

ExtendedReductionPatterns:
    reduce(ext(...))
MulAccReductionPatterns:
    reduce.add(mul(...))
    reduce.add(mul(ext(...), ext(...)))
    reduce.add(ext(mul(...)))
    reduce.add(ext(mul(ext(...), ext(...))))

Ref: Original instruction based implementation:
https://reviews.llvm.org/D93476
---
 .../Transforms/Vectorize/LoopVectorize.cpp    |  45 ------
 llvm/lib/Transforms/Vectorize/VPlan.h         |   2 +
 .../lib/Transforms/Vectorize/VPlanRecipes.cpp | 139 ++++++++++++++++--
 3 files changed, 129 insertions(+), 57 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 60a94ca1f86e42..483e039fe133d6 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -7303,51 +7303,6 @@ LoopVectorizationPlanner::precomputeCosts(VPlan &Plan, ElementCount VF,
       Cost += ReductionCost;
       continue;
     }
-
-    const auto &ChainOps = RdxDesc.getReductionOpChain(RedPhi, OrigLoop);
-    SetVector<Instruction *> ChainOpsAndOperands(ChainOps.begin(),
-                                                 ChainOps.end());
-    auto IsZExtOrSExt = [](const unsigned Opcode) -> bool {
-      return Opcode == Instruction::ZExt || Opcode == Instruction::SExt;
-    };
-    // Also include the operands of instructions in the chain, as the cost-model
-    // may mark extends as free.
-    //
-    // For ARM, some of the instruction can folded into the reducion
-    // instruction. So we need to mark all folded instructions free.
-    // For example: We can fold reduce(mul(ext(A), ext(B))) into one
-    // instruction.
-    for (auto *ChainOp : ChainOps) {
-      for (Value *Op : ChainOp->operands()) {
-        if (auto *I = dyn_cast<Instruction>(Op)) {
-          ChainOpsAndOperands.insert(I);
-          if (I->getOpcode() == Instruction::Mul) {
-            auto *Ext0 = dyn_cast<Instruction>(I->getOperand(0));
-            auto *Ext1 = dyn_cast<Instruction>(I->getOperand(1));
-            if (Ext0 && IsZExtOrSExt(Ext0->getOpcode()) && Ext1 &&
-                Ext0->getOpcode() == Ext1->getOpcode()) {
-              ChainOpsAndOperands.insert(Ext0);
-              ChainOpsAndOperands.insert(Ext1);
-            }
-          }
-        }
-      }
-    }
-
-    // Pre-compute the cost for I, if it has a reduction pattern cost.
-    for (Instruction *I : ChainOpsAndOperands) {
-      auto ReductionCost = CM.getReductionPatternCost(
-          I, VF, ToVectorTy(I->getType(), VF), TTI::TCK_RecipThroughput);
-      if (!ReductionCost)
-        continue;
-
-      assert(!CostCtx.SkipCostComputation.contains(I) &&
-             "reduction op visited multiple times");
-      CostCtx.SkipCostComputation.insert(I);
-      LLVM_DEBUG(dbgs() << "Cost of " << ReductionCost << " for VF " << VF
-                        << ":\n in-loop reduction " << *I << "\n");
-      Cost += *ReductionCost;
-    }
   }
 
   // Pre-compute the costs for branches except for the backedge, as the number
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index 6a192bdf01c4ff..b26fd460a278f5 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -725,6 +725,8 @@ struct VPCostContext {
   LLVMContext &LLVMCtx;
   LoopVectorizationCostModel &CM;
   SmallPtrSet<Instruction *, 8> SkipCostComputation;
+  /// Contains recipes that are folded into other recipes.
+  SmallDenseMap<ElementCount, SmallPtrSet<VPRecipeBase *, 4>, 4> FoldedRecipes;
 
   VPCostContext(const TargetTransformInfo &TTI, const TargetLibraryInfo &TLI,
                 Type *CanIVTy, LoopVectorizationCostModel &CM)
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 0eb4f7c7c88cee..5f59a1e96df9f8 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -299,7 +299,12 @@ InstructionCost VPRecipeBase::cost(ElementCount VF, VPCostContext &Ctx) {
     UI = &WidenMem->getIngredient();

   InstructionCost RecipeCost;
-  if (UI && Ctx.skipCostComputation(UI, VF.isVector())) {
+  // Recipes folded into a user (e.g. an extend folded into a reduction) are
+  // already accounted for in that user's cost.
+  auto FoldedIt = Ctx.FoldedRecipes.find(VF);
+  bool IsFolded = FoldedIt != Ctx.FoldedRecipes.end() &&
+                  FoldedIt->second.contains(this);
+  if ((UI && Ctx.skipCostComputation(UI, VF.isVector())) || IsFolded) {
     RecipeCost = 0;
   } else {
     RecipeCost = computeCost(VF, Ctx);
@@ -2188,30 +2193,159 @@ InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
   TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
   unsigned Opcode = RdxDesc.getOpcode();

-  // TODO: Support any-of and in-loop reductions.
+  // TODO: Support any-of reductions.
   assert(
       (!RecurrenceDescriptor::isAnyOfRecurrenceKind(RdxKind) ||
        ForceTargetInstructionCost.getNumOccurrences() > 0) &&
       "Any-of reduction not implemented in VPlan-based cost model currently.");
-  assert(
-      (!cast<VPReductionPHIRecipe>(getOperand(0))->isInLoop() ||
-       ForceTargetInstructionCost.getNumOccurrences() > 0) &&
-      "In-loop reduction not implemented in VPlan-based cost model currently.");

   assert(ElementTy->getTypeID() == RdxDesc.getRecurrenceType()->getTypeID() &&
          "Inferred type and recurrence type mismatch.");

-  // Cost = Reduction cost + BinOp cost
-  InstructionCost Cost =
+  // BaseCost = Reduction cost + BinOp cost
+  InstructionCost BaseCost =
       Ctx.TTI.getArithmeticInstrCost(Opcode, ElementTy, CostKind);
   if (RecurrenceDescriptor::isMinMaxRecurrenceKind(RdxKind)) {
     Intrinsic::ID Id = getMinMaxReductionIntrinsicOp(RdxKind);
-    return Cost + Ctx.TTI.getMinMaxReductionCost(
-                      Id, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+    BaseCost += Ctx.TTI.getMinMaxReductionCost(
+        Id, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+  } else {
+    BaseCost += Ctx.TTI.getArithmeticReductionCost(
+        Opcode, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
   }

-  return Cost + Ctx.TTI.getArithmeticReductionCost(
-                    Opcode, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+  using namespace llvm::VPlanPatternMatch;
+
+  // Try to fold the extends and the mul feeding an add reduction into a single
+  // multiply-accumulate reduction. Matched patterns:
+  //   reduce.add(mul(A, B))
+  //   reduce.add(mul(ext(A), ext(B)))
+  //   reduce.add(ext(mul(ext(A), ext(B))))
+  // An outer extend is only folded together with inner extends: without them
+  // the narrow mul may wrap, so it is not a mul-acc of the extended operands.
+  // On success, the folded recipes are recorded in Ctx.FoldedRecipes[VF] (so
+  // VPRecipeBase::cost reports 0 for them) and the combined cost is returned;
+  // otherwise an invalid cost is returned and nothing is recorded.
+  auto GetMulAccReductionCost =
+      [&](const VPReductionRecipe *Red) -> InstructionCost {
+    // Only integer add reductions have a mul-acc form.
+    if (Opcode != Instruction::Add)
+      return InstructionCost::getInvalid();
+
+    // Optional outer extend: reduce.add(ext(...)).
+    VPValue *MulOp = Red->getVecOp();
+    auto *Ext =
+        dyn_cast_if_present<VPWidenCastRecipe>(MulOp->getDefiningRecipe());
+    bool HasOuterExt = Ext && match(Ext, m_ZExtOrSExt(m_VPValue())) &&
+                       Ext->getNumUsers() == 1;
+    if (HasOuterExt)
+      MulOp = Ext->getOperand(0);
+
+    // m_Mul also matches replicate recipes and VPInstructions, so check for a
+    // widened recipe first instead of cast<>-ing after the match.
+    VPValue *A, *B;
+    auto *Mul = dyn_cast_if_present<VPWidenRecipe>(MulOp->getDefiningRecipe());
+    if (!Mul || !match(Mul, m_Mul(m_VPValue(A), m_VPValue(B))) ||
+        Mul->getNumUsers() != 1)
+      return InstructionCost::getInvalid();
+
+    // Inner extends must share opcode and source type and feed only the mul.
+    // With an outer extend, they must also match its opcode, unless both mul
+    // operands are the same extend.
+    auto *InnerExt0 =
+        dyn_cast_if_present<VPWidenCastRecipe>(A->getDefiningRecipe());
+    auto *InnerExt1 =
+        dyn_cast_if_present<VPWidenCastRecipe>(B->getDefiningRecipe());
+    bool HasInnerExt =
+        InnerExt0 && InnerExt1 &&
+        match(InnerExt0, m_ZExtOrSExt(m_VPValue())) &&
+        match(InnerExt1, m_ZExtOrSExt(m_VPValue())) &&
+        InnerExt0->getOpcode() == InnerExt1->getOpcode() &&
+        Ctx.Types.inferScalarType(InnerExt0->getOperand(0)) ==
+            Ctx.Types.inferScalarType(InnerExt1->getOperand(0)) &&
+        InnerExt0->getNumUsers() > 0 &&
+        !InnerExt0->hasMoreThanOneUniqueUser() &&
+        InnerExt1->getNumUsers() > 0 &&
+        !InnerExt1->hasMoreThanOneUniqueUser() &&
+        (!HasOuterExt || InnerExt0 == InnerExt1 ||
+         InnerExt0->getOpcode() == Ext->getOpcode());
+    if (HasOuterExt && !HasInnerExt)
+      return InstructionCost::getInvalid();
+
+    // Cost of the recipes that would be folded away.
+    InstructionCost UnfoldedCost = BaseCost + Mul->computeCost(VF, Ctx);
+    if (HasOuterExt)
+      UnfoldedCost += Ext->computeCost(VF, Ctx);
+    // Without extends, the mul-acc operates on the reduction type itself.
+    VectorType *SrcVecTy = VectorTy;
+    bool IsUnsigned = true;
+    if (HasInnerExt) {
+      UnfoldedCost += InnerExt0->computeCost(VF, Ctx);
+      // A shared extend, as in mul(ext(A), ext(A)), is only paid for once.
+      if (InnerExt1 != InnerExt0)
+        UnfoldedCost += InnerExt1->computeCost(VF, Ctx);
+      IsUnsigned = InnerExt0->getOpcode() == Instruction::CastOps::ZExt;
+      SrcVecTy = cast<VectorType>(ToVectorTy(
+          Ctx.Types.inferScalarType(InnerExt0->getOperand(0)), VF));
+    }
+
+    InstructionCost MulAccRedCost = Ctx.TTI.getMulAccReductionCost(
+        IsUnsigned, ElementTy, SrcVecTy, CostKind);
+    // Only fold when the mul-acc reduction is cheaper than its parts.
+    if (!MulAccRedCost.isValid() || MulAccRedCost >= UnfoldedCost)
+      return InstructionCost::getInvalid();
+
+    auto &Folded = Ctx.FoldedRecipes[VF];
+    Folded.insert(Mul);
+    if (HasOuterExt)
+      Folded.insert(Ext);
+    if (HasInnerExt) {
+      Folded.insert(InnerExt0);
+      Folded.insert(InnerExt1);
+    }
+    return MulAccRedCost;
+  };
+
+  // Try to fold an extend feeding the reduction: reduce(ext(A)). On success,
+  // records the extend in Ctx.FoldedRecipes[VF] and returns the combined cost;
+  // otherwise returns an invalid cost.
+  auto GetExtendedReductionCost =
+      [&](const VPReductionRecipe *Red) -> InstructionCost {
+    // m_ZExtOrSExt also matches non-widened casts; only a widened cast is
+    // folded here, so check the recipe kind before matching.
+    VPValue *A;
+    auto *Ext = dyn_cast_if_present<VPWidenCastRecipe>(
+        Red->getVecOp()->getDefiningRecipe());
+    if (!Ext || !match(Ext, m_ZExtOrSExt(m_VPValue(A))) ||
+        Ext->getNumUsers() != 1)
+      return InstructionCost::getInvalid();
+
+    bool IsUnsigned = Ext->getOpcode() == Instruction::CastOps::ZExt;
+    auto *ExtVecTy =
+        cast<VectorType>(ToVectorTy(Ctx.Types.inferScalarType(A), VF));
+    InstructionCost ExtendedRedCost = Ctx.TTI.getExtendedReductionCost(
+        Opcode, IsUnsigned, ElementTy, ExtVecTy, RdxDesc.getFastMathFlags(),
+        CostKind);
+    // Only fold when the extended reduction is cheaper than ext + reduction.
+    if (!ExtendedRedCost.isValid() ||
+        ExtendedRedCost >= Ext->computeCost(VF, Ctx) + BaseCost)
+      return InstructionCost::getInvalid();
+
+    Ctx.FoldedRecipes[VF].insert(Ext);
+    return ExtendedRedCost;
+  };
+
+  // Prefer the mul-acc form, which may fold more recipes.
+  InstructionCost MulAccCost = GetMulAccReductionCost(this);
+  if (MulAccCost.isValid())
+    return MulAccCost;
+
+  InstructionCost ExtendedCost = GetExtendedReductionCost(this);
+  if (ExtendedCost.isValid())
+    return ExtendedCost;
+
+  // Nothing could be folded: plain reduction cost.
+  return BaseCost;
 }

#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)



More information about the llvm-commits mailing list