[llvm] [VPlan] Introduce ComputeReductionResult VPInstruction opcode. (PR #70253)
    Florian Hahn via llvm-commits 
    llvm-commits at lists.llvm.org
       
    Mon Nov  6 13:01:26 PST 2023
    
    
  
================
@@ -403,6 +404,138 @@ Value *VPInstruction::generateInstruction(VPTransformState &State,
     Builder.GetInsertBlock()->getTerminator()->eraseFromParent();
     return CondBr;
   }
+  case VPInstruction::ComputeReductionResult: {
+    if (Part != 0)
+      return State.get(
+          this, VPIteration(State.UF - 1, VPLane::getLastLaneForVF(State.VF)));
+
+    auto *PhiR = dyn_cast<VPReductionPHIRecipe>(getOperand(0));
+    PHINode *OrigPhi = cast<PHINode>(PhiR->getUnderlyingValue());
+    // Get it's reduction variable descriptor.
+    const RecurrenceDescriptor &RdxDesc = PhiR->getRecurrenceDescriptor();
+
+    BasicBlock *LoopMiddleBlock = State.CFG.VPBB2IRBB[getParent()];
+    Builder.SetInsertPoint(LoopMiddleBlock,
+                           LoopMiddleBlock->getFirstInsertionPt());
+
+    RecurKind RK = RdxDesc.getRecurrenceKind();
+    TrackingVH<Value> ReductionStartValue = RdxDesc.getRecurrenceStartValue();
+    Instruction *LoopExitInst = RdxDesc.getLoopExitInstr();
+    if (auto *I = dyn_cast<Instruction>(&*ReductionStartValue))
+      State.setDebugLocFrom(I->getDebugLoc());
+
+    VPValue *LoopExitInstDef = getOperand(1);
+    // This is the vector-clone of the value that leaves the loop.
+    // State.setDebugLocFrom(LoopExitInst->getDebugLoc());
+
+    Type *PhiTy = OrigPhi->getType();
+    VectorParts RdxParts(State.UF);
+    for (unsigned Part = 0; Part < State.UF; ++Part)
+      RdxParts[Part] = State.get(LoopExitInstDef, Part);
+
+    // If the vector reduction can be performed in a smaller type, we truncate
+    // then extend the loop exit value to enable InstCombine to evaluate the
+    // entire expression in the smaller type.
+    if (State.VF.isVector() && PhiTy != RdxDesc.getRecurrenceType()) {
+      Builder.SetInsertPoint(LoopMiddleBlock,
+                             LoopMiddleBlock->getFirstInsertionPt());
+      Type *RdxVecTy = VectorType::get(RdxDesc.getRecurrenceType(), State.VF);
+      for (unsigned Part = 0; Part < State.UF; ++Part)
+        RdxParts[Part] = Builder.CreateTrunc(RdxParts[Part], RdxVecTy);
+    }
+    // Reduce all of the unrolled parts into a single vector.
+    Value *ReducedPartRdx = RdxParts[0];
+    unsigned Op = RecurrenceDescriptor::getOpcode(RK);
+
+    // The middle block terminator has already been assigned a DebugLoc here
+    // (the OrigLoop's single latch terminator). We want the whole middle block
+    // to appear to execute on this line because: (a) it is all compiler
+    // generated, (b) these instructions are always executed after evaluating
+    // the latch conditional branch, and (c) other passes may add new
+    // predecessors which terminate on this line. This is the easiest way to
+    // ensure we don't accidentally cause an extra step back into the loop while
+    // debugging.
+    State.setDebugLocFrom(LoopMiddleBlock->getTerminator()->getDebugLoc());
+    if (PhiR->isOrdered())
+      ReducedPartRdx = RdxParts[State.UF - 1];
+    else {
+      // Floating-point operations should have some FMF to enable the reduction.
+      IRBuilderBase::FastMathFlagGuard FMFG(Builder);
+      Builder.setFastMathFlags(RdxDesc.getFastMathFlags());
+      for (unsigned Part = 1; Part < State.UF; ++Part) {
+        Value *RdxPart = RdxParts[Part];
+        if (Op != Instruction::ICmp && Op != Instruction::FCmp)
+          ReducedPartRdx = Builder.CreateBinOp(
+              (Instruction::BinaryOps)Op, RdxPart, ReducedPartRdx, "bin.rdx");
+        else if (RecurrenceDescriptor::isAnyOfRecurrenceKind(RK))
+          ReducedPartRdx = createAnyOfOp(Builder, ReductionStartValue, RK,
+                                         ReducedPartRdx, RdxPart);
+        else
+          ReducedPartRdx = createMinMaxOp(Builder, RK, ReducedPartRdx, RdxPart);
+      }
+    }
+
+    // Create the reduction after the loop. Note that inloop reductions create
+    // the target reduction in the loop using a Reduction recipe.
+    if (State.VF.isVector() && !PhiR->isInLoop()) {
+      ReducedPartRdx =
+          createTargetReduction(Builder, RdxDesc, ReducedPartRdx, OrigPhi);
+      // If the reduction can be performed in a smaller type, we need to extend
+      // the reduction to the wider type before we branch to the original loop.
+      if (PhiTy != RdxDesc.getRecurrenceType())
+        ReducedPartRdx = RdxDesc.isSigned()
+                             ? Builder.CreateSExt(ReducedPartRdx, PhiTy)
+                             : Builder.CreateZExt(ReducedPartRdx, PhiTy);
+    }
+
+    PHINode *ResumePhi =
+        dyn_cast<PHINode>(PhiR->getStartValue()->getUnderlyingValue());
+
+    auto *OrigLoop = State.LI->getLoopFor(OrigPhi->getParent());
+    // TODO: bc.merge.rdx should not be created here, instead it should be
+    // modeled in VPlan.
+    BasicBlock *LoopScalarPreHeader = OrigLoop->getLoopPreheader();
+    // Create a phi node that merges control-flow from the backedge-taken check
+    // block and the middle block.
+    PHINode *BCBlockPhi = PHINode::Create(PhiTy, 2, "bc.merge.rdx",
+                                          LoopScalarPreHeader->getTerminator());
+
+    // If we are fixing reductions in the epilogue loop then we should already
+    // have created a bc.merge.rdx Phi after the main vector body. Ensure that
+    // we carry over the incoming values correctly.
+    for (auto *Incoming : predecessors(LoopScalarPreHeader)) {
+      if (Incoming == LoopMiddleBlock)
+        BCBlockPhi->addIncoming(ReducedPartRdx, Incoming);
+      else if (ResumePhi && llvm::is_contained(ResumePhi->blocks(), Incoming))
+        BCBlockPhi->addIncoming(ResumePhi->getIncomingValueForBlock(Incoming),
+                                Incoming);
+      else
+        BCBlockPhi->addIncoming(ReductionStartValue, Incoming);
+    }
+
+    // If there were stores of the reduction value to a uniform memory address
+    // inside the loop, create the final store here.
+    if (StoreInst *SI = RdxDesc.IntermediateStore) {
+      StoreInst *NewSI = Builder.CreateAlignedStore(
----------------
fhahn wrote:
Done, thanks!
https://github.com/llvm/llvm-project/pull/70253
    
    
More information about the llvm-commits
mailing list