[llvm] eaf48dd - [VPlan] Replace BranchOnCount with BranchOnCond if TC <= UF * VF.

Florian Hahn via llvm-commits llvm-commits at lists.llvm.org
Mon Jun 6 01:39:13 PDT 2022

Author: Florian Hahn
Date: 2022-06-06T09:38:53+01:00
New Revision: eaf48dd9b079c005d92ed9ef858f12bc452e71ef

URL: https://github.com/llvm/llvm-project/commit/eaf48dd9b079c005d92ed9ef858f12bc452e71ef
DIFF: https://github.com/llvm/llvm-project/commit/eaf48dd9b079c005d92ed9ef858f12bc452e71ef.diff

LOG: [VPlan] Replace BranchOnCount with BranchOnCond if TC <= UF * VF.

Try to simplify BranchOnCount to `BranchOnCond true` if TC <= UF * VF.

This is an alternative to D121899: instead of applying the simplification
late during code-gen, this patch simplifies the VPlan directly.

The potential benefit of doing this in VPlan is that it may help
cost-modeling in the future. For now, the simplification is done in
prepareToExecute because a single plan may be used for multiple VFs/UFs,
and the concrete VF and UF are only known at that point.

There are further simplifications that can be applied as follow-ups:

1. Replace inductions with constants
2. Replace vector region with regular block.

Fixes #55354.

Depends on D126679.

Reviewed By: Ayal

Differential Revision: https://reviews.llvm.org/D126680




diff  --git a/llvm/lib/Transforms/Vectorize/VPlan.cpp b/llvm/lib/Transforms/Vectorize/VPlan.cpp
index 729931a1c5e68..0f15b9949239d 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlan.cpp
@@ -767,6 +767,7 @@ void VPInstruction::generateInstruction(VPTransformState &State,
   case VPInstruction::BranchOnCond: {
     if (Part != 0)
     Value *Cond = State.get(getOperand(0), VPIteration(Part, 0));
     VPRegionBlock *ParentRegion = getParent()->getParent();
     VPBasicBlock *Header = ParentRegion->getEntryBasicBlock();
@@ -898,6 +899,28 @@ void VPInstruction::setFastMathFlags(FastMathFlags FMFNew) {
 void VPlan::prepareToExecute(Value *TripCountV, Value *VectorTripCountV,
                              Value *CanonicalIVStartValue,
                              VPTransformState &State) {
+  VPBasicBlock *ExitingVPBB = getVectorLoopRegion()->getExitingBasicBlock();
+  auto *Term = dyn_cast<VPInstruction>(&ExitingVPBB->back());
+  // Try to simplify BranchOnCount to 'BranchOnCond true' if TC <= VF * UF when
+  // preparing to execute the plan for the main vector loop.
+  if (!CanonicalIVStartValue && Term &&
+      Term->getOpcode() == VPInstruction::BranchOnCount &&
+      isa<ConstantInt>(TripCountV)) {
+    ConstantInt *C = cast<ConstantInt>(TripCountV);
+    uint64_t TCVal = C->getZExtValue();
+    if (TCVal && TCVal <= State.VF.getKnownMinValue() * State.UF) {
+      auto *BOC =
+          new VPInstruction(VPInstruction::BranchOnCond,
+                            {getOrAddExternalDef(State.Builder.getTrue())});
+      Term->eraseFromParent();
+      ExitingVPBB->appendRecipe(BOC);
+      // TODO: Further simplifications are possible
+      //      1. Replace inductions with constants.
+      //      2. Replace vector loop region with VPBasicBlock.
+    }
+  }
   // Check if the trip count is needed, and if so build it.
   if (TripCount && TripCount->getNumUsers()) {
     for (unsigned Part = 0, UF = State.UF; Part < UF; ++Part)

diff  --git a/llvm/test/Transforms/LoopVectorize/AArch64/sve-low-trip-count.ll b/llvm/test/Transforms/LoopVectorize/AArch64/sve-low-trip-count.ll
index e299d35a2dcd3..15367cbac716f 100644
--- a/llvm/test/Transforms/LoopVectorize/AArch64/sve-low-trip-count.ll
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/sve-low-trip-count.ll
@@ -47,8 +47,7 @@ define void @trip5_i8(i8* noalias nocapture noundef %dst, i8* noalias nocapture
 ; CHECK:         [[VSCALE:%.*]] = call i64 @llvm.vscale.i64()
 ; CHECK-NEXT:    [[VF:%.*]] = mul i64 [[VSCALE]], 16
 ; CHECK-NEXT:    [[INDEX_NEXT]] = add i64 [[INDEX]], [[VF]]
-; CHECK-NEXT:    [[COND:%.*]] = icmp eq i64 [[INDEX_NEXT]], {{%.*}}
-; CHECK-NEXT:    br i1 [[COND]], label %middle.block, label %vector.body
+; CHECK-NEXT:    br i1 true, label %middle.block, label %vector.body
   br label %for.body

diff  --git a/llvm/test/Transforms/LoopVectorize/X86/constant-fold.ll b/llvm/test/Transforms/LoopVectorize/X86/constant-fold.ll
index 90377a0500a66..2a5537f283dfd 100644
--- a/llvm/test/Transforms/LoopVectorize/X86/constant-fold.ll
+++ b/llvm/test/Transforms/LoopVectorize/X86/constant-fold.ll
@@ -26,8 +26,7 @@ define void @f1() {
 ; CHECK-NEXT:    [[TMP4:%.*]] = bitcast i16** [[TMP3]] to <2 x i16*>*
 ; CHECK-NEXT:    store <2 x i16*> <i16* getelementptr inbounds ([1 x %rec8], [1 x %rec8]* @a, i32 0, i32 0, i32 0), i16* getelementptr inbounds ([1 x %rec8], [1 x %rec8]* @a, i32 0, i32 0, i32 0)>, <2 x i16*>* [[TMP4]], align 8
 ; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 2
-; CHECK-NEXT:    [[TMP5:%.*]] = icmp eq i32 [[INDEX_NEXT]], 2
-; CHECK-NEXT:    br i1 [[TMP5]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], [[LOOP0:!llvm.loop !.*]]
+; CHECK-NEXT:    br i1 true, label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], [[LOOP0:!llvm.loop !.*]]
 ; CHECK:       middle.block:
 ; CHECK-NEXT:    [[CMP_N:%.*]] = icmp eq i32 2, 2
 ; CHECK-NEXT:    br i1 [[CMP_N]], label [[BB3:%.*]], label [[SCALAR_PH]]

diff  --git a/llvm/test/Transforms/LoopVectorize/X86/outer_loop_test1_no_explicit_vect_width.ll b/llvm/test/Transforms/LoopVectorize/X86/outer_loop_test1_no_explicit_vect_width.ll
index 2623257149cc2..e85dc71f63aba 100644
--- a/llvm/test/Transforms/LoopVectorize/X86/outer_loop_test1_no_explicit_vect_width.ll
+++ b/llvm/test/Transforms/LoopVectorize/X86/outer_loop_test1_no_explicit_vect_width.ll
@@ -71,10 +71,9 @@
 ; AVX: br i1 %[[InnerCond]], label %[[ForInc]], label %[[InnerLoop]]
 ; AVX: [[ForInc]]:
-; AVX: %[[IndNext]] = add nuw i64 %[[Ind]], 8
 ; AVX: %[[VecIndNext]] = add <8 x i64> %[[VecInd]], <i64 8, i64 8, i64 8, i64 8, i64 8, i64 8, i64 8, i64 8>
-; AVX: %[[Cmp:.*]] = icmp eq i64 %[[IndNext]], 8
-; AVX: br i1 %[[Cmp]], label %middle.block, label %vector.body
+; AVX: %[[IndNext]] = add nuw i64 %[[Ind]], 8
+; AVX: br i1 true, label %middle.block, label %vector.body
 @arr2 = external global [8 x i32], align 16
 @arr = external global [8 x [8 x i32]], align 16

diff  --git a/llvm/test/Transforms/LoopVectorize/X86/pr34438.ll b/llvm/test/Transforms/LoopVectorize/X86/pr34438.ll
index 84df4aaf08af6..005ad9bc9b28d 100644
--- a/llvm/test/Transforms/LoopVectorize/X86/pr34438.ll
+++ b/llvm/test/Transforms/LoopVectorize/X86/pr34438.ll
@@ -30,8 +30,7 @@ define void @small_tc(float* noalias nocapture %A, float* noalias nocapture read
 ; CHECK-NEXT:    [[TMP8:%.*]] = bitcast float* [[TMP5]] to <8 x float>*
 ; CHECK-NEXT:    store <8 x float> [[TMP7]], <8 x float>* [[TMP8]], align 4, !llvm.access.group !0
 ; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 8
-; CHECK-NEXT:    [[TMP9:%.*]] = icmp eq i64 [[INDEX_NEXT]], 8
-; CHECK-NEXT:    br i1 [[TMP9]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], [[LOOP1:!llvm.loop !.*]]
+; CHECK-NEXT:    br i1 true, label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], [[LOOP1:!llvm.loop !.*]]
 ; CHECK:       middle.block:
 ; CHECK-NEXT:    [[CMP_N:%.*]] = icmp eq i64 8, 8
 ; CHECK-NEXT:    br i1 [[CMP_N]], label [[FOR_END:%.*]], label [[SCALAR_PH]]

diff  --git a/llvm/test/Transforms/LoopVectorize/X86/pr42674.ll b/llvm/test/Transforms/LoopVectorize/X86/pr42674.ll
index 7516c055ab732..7d469e7e76e16 100644
--- a/llvm/test/Transforms/LoopVectorize/X86/pr42674.ll
+++ b/llvm/test/Transforms/LoopVectorize/X86/pr42674.ll
@@ -9,26 +9,18 @@
 define zeroext i8 @sum() {
 ; CHECK-LABEL: @sum(
 ; CHECK-NEXT:  iter.check:
-; CHECK-NEXT:    br label [[VECTOR_BODY:%.*]]
-; CHECK:       vector.body:
-; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[ENTRY:%.*]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
-; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi <64 x i8> [ zeroinitializer, [[ENTRY]] ], [ [[TMP4:%.*]], [[VECTOR_BODY]] ]
-; CHECK-NEXT:    [[VEC_PHI1:%.*]] = phi <64 x i8> [ zeroinitializer, [[ENTRY]] ], [ [[TMP5:%.*]], [[VECTOR_BODY]] ]
-; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds [128 x i8], [128 x i8]* @bytes, i64 0, i64 [[INDEX]]
+; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds [128 x i8], [128 x i8]* @bytes, i64 0, i64 0
 ; CHECK-NEXT:    [[TMP1:%.*]] = bitcast i8* [[TMP0]] to <64 x i8>*
 ; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <64 x i8>, <64 x i8>* [[TMP1]], align 16
 ; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr inbounds i8, i8* [[TMP0]], i64 64
 ; CHECK-NEXT:    [[TMP3:%.*]] = bitcast i8* [[TMP2]] to <64 x i8>*
 ; CHECK-NEXT:    [[WIDE_LOAD2:%.*]] = load <64 x i8>, <64 x i8>* [[TMP3]], align 16
-; CHECK-NEXT:    [[TMP4]] = add <64 x i8> [[WIDE_LOAD]], [[VEC_PHI]]
-; CHECK-NEXT:    [[TMP5]] = add <64 x i8> [[WIDE_LOAD2]], [[VEC_PHI1]]
-; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 128
-; CHECK-NEXT:    [[TMP6:%.*]] = icmp eq i64 [[INDEX]], 0
-; CHECK-NEXT:    br i1 [[TMP6]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop !0
-; CHECK:       middle.block:
+; CHECK-NEXT:    [[TMP4:%.*]] = add <64 x i8> [[WIDE_LOAD]], zeroinitializer
+; CHECK-NEXT:    [[TMP5:%.*]] = add <64 x i8> [[WIDE_LOAD2]], zeroinitializer
+; CHECK-NEXT:    [[INDEX_NEXT:%.*]] = add nuw i64 0, 128
 ; CHECK-NEXT:    [[BIN_RDX:%.*]] = add <64 x i8> [[TMP5]], [[TMP4]]
-; CHECK-NEXT:    [[TMP7:%.*]] = call i8 @llvm.vector.reduce.add.v64i8(<64 x i8> [[BIN_RDX]])
-; CHECK-NEXT:    ret i8 [[TMP7]]
+; CHECK-NEXT:    [[TMP6:%.*]] = call i8 @llvm.vector.reduce.add.v64i8(<64 x i8> [[BIN_RDX]])
+; CHECK-NEXT:    ret i8 [[TMP6]]
   br label %for.body


More information about the llvm-commits mailing list