[llvm] [VPlan] Introduce ComputeReductionResult VPInstruction opcode. (PR #70253)

Florian Hahn via llvm-commits llvm-commits at lists.llvm.org
Mon Nov 6 13:01:21 PST 2023


================
@@ -403,6 +404,138 @@ Value *VPInstruction::generateInstruction(VPTransformState &State,
     Builder.GetInsertBlock()->getTerminator()->eraseFromParent();
     return CondBr;
   }
+  case VPInstruction::ComputeReductionResult: {
+    if (Part != 0)
+      return State.get(
+          this, VPIteration(State.UF - 1, VPLane::getLastLaneForVF(State.VF)));
+
+    auto *PhiR = dyn_cast<VPReductionPHIRecipe>(getOperand(0));
+    PHINode *OrigPhi = cast<PHINode>(PhiR->getUnderlyingValue());
+    // Get it's reduction variable descriptor.
+    const RecurrenceDescriptor &RdxDesc = PhiR->getRecurrenceDescriptor();
+
+    BasicBlock *LoopMiddleBlock = State.CFG.VPBB2IRBB[getParent()];
+    Builder.SetInsertPoint(LoopMiddleBlock,
+                           LoopMiddleBlock->getFirstInsertionPt());
+
+    RecurKind RK = RdxDesc.getRecurrenceKind();
+    TrackingVH<Value> ReductionStartValue = RdxDesc.getRecurrenceStartValue();
+    Instruction *LoopExitInst = RdxDesc.getLoopExitInstr();
+    if (auto *I = dyn_cast<Instruction>(&*ReductionStartValue))
+      State.setDebugLocFrom(I->getDebugLoc());
+
+    VPValue *LoopExitInstDef = getOperand(1);
+    // This is the vector-clone of the value that leaves the loop.
+    // State.setDebugLocFrom(LoopExitInst->getDebugLoc());
+
+    Type *PhiTy = OrigPhi->getType();
+    VectorParts RdxParts(State.UF);
+    for (unsigned Part = 0; Part < State.UF; ++Part)
+      RdxParts[Part] = State.get(LoopExitInstDef, Part);
+
+    // If the vector reduction can be performed in a smaller type, we truncate
+    // then extend the loop exit value to enable InstCombine to evaluate the
+    // entire expression in the smaller type.
+    if (State.VF.isVector() && PhiTy != RdxDesc.getRecurrenceType()) {
+      Builder.SetInsertPoint(LoopMiddleBlock,
+                             LoopMiddleBlock->getFirstInsertionPt());
+      Type *RdxVecTy = VectorType::get(RdxDesc.getRecurrenceType(), State.VF);
+      for (unsigned Part = 0; Part < State.UF; ++Part)
+        RdxParts[Part] = Builder.CreateTrunc(RdxParts[Part], RdxVecTy);
+    }
+    // Reduce all of the unrolled parts into a single vector.
+    Value *ReducedPartRdx = RdxParts[0];
+    unsigned Op = RecurrenceDescriptor::getOpcode(RK);
+
+    // The middle block terminator has already been assigned a DebugLoc here
+    // (the OrigLoop's single latch terminator). We want the whole middle block
+    // to appear to execute on this line because: (a) it is all compiler
+    // generated, (b) these instructions are always executed after evaluating
+    // the latch conditional branch, and (c) other passes may add new
+    // predecessors which terminate on this line. This is the easiest way to
+    // ensure we don't accidentally cause an extra step back into the loop while
+    // debugging.
+    State.setDebugLocFrom(LoopMiddleBlock->getTerminator()->getDebugLoc());
+    if (PhiR->isOrdered())
+      ReducedPartRdx = RdxParts[State.UF - 1];
+    else {
----------------
fhahn wrote:

Done, thanks!

https://github.com/llvm/llvm-project/pull/70253


More information about the llvm-commits mailing list