[llvm] ba45453 - [SimplifyCFG] Skip threading if the target may have divergent branches

via llvm-commits llvm-commits at lists.llvm.org
Fri Jul 26 09:15:52 PDT 2024


Author: darkbuck
Date: 2024-07-26T12:15:49-04:00
New Revision: ba45453c0a5df3e6c3eddee647e14c97e02243fa

URL: https://github.com/llvm/llvm-project/commit/ba45453c0a5df3e6c3eddee647e14c97e02243fa
DIFF: https://github.com/llvm/llvm-project/commit/ba45453c0a5df3e6c3eddee647e14c97e02243fa.diff

LOG: [SimplifyCFG] Skip threading if the target may have divergent branches

- This patch skips threading on known values if the target may have
  divergent branches.
- So far, threading on known values has been skipped only when the basic
  block contains convergent calls. However, even without convergent calls,
  if the branch condition is divergent, threading duplicates the execution
  of the threaded block and hence results in lower performance. E.g.,
  ```
  BB1:
    if (cond) BB3, BB2

  BB2:
    // work2
    br BB3

  BB3:
    // work3
    if (cond) BB5, BB4

  BB4:
    // work4
    br BB5

  BB5:
  ```

  after threading,

  ```
  BB1:
    if (cond) BB2', BB3'

  BB2':
    // work3
    br BB5

  BB3':
    // work2
    // work3
    // work4
    br BB5

  BB5:

  ```

  After threading, work3 is executed twice if 'cond' is divergent, since
  both successors of BB1 now contain a copy of work3 and a divergent branch
  executes both of them.

Reviewers: yxsamliu, nikic

Pull Request: https://github.com/llvm/llvm-project/pull/100185

Added: 
    llvm/test/Transforms/SimplifyCFG/AMDGPU/skip-threading.ll

Modified: 
    llvm/lib/Transforms/Utils/SimplifyCFG.cpp
    llvm/test/Transforms/SimplifyCFG/convergent.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
index f23e28888931d..1a17524b826a1 100644
--- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
@@ -3246,7 +3246,12 @@ bool SimplifyCFGOpt::SpeculativelyExecuteBB(BranchInst *BI,
 }
 
 /// Return true if we can thread a branch across this block.
-static bool BlockIsSimpleEnoughToThreadThrough(BasicBlock *BB) {
+static bool BlockIsSimpleEnoughToThreadThrough(BasicBlock *BB,
+                                               const TargetTransformInfo &TTI) {
+  // Skip threading if the branch may be divergent.
+  if (TTI.hasBranchDivergence(BB->getParent()))
+    return false;
+
   int Size = 0;
   EphemeralValueTracker EphTracker;
 
@@ -3301,10 +3306,9 @@ static ConstantInt *getKnownValueOnEdge(Value *V, BasicBlock *From,
 /// If we have a conditional branch on something for which we know the constant
 /// value in predecessors (e.g. a phi node in the current block), thread edges
 /// from the predecessor to their ultimate destination.
-static std::optional<bool>
-FoldCondBranchOnValueKnownInPredecessorImpl(BranchInst *BI, DomTreeUpdater *DTU,
-                                            const DataLayout &DL,
-                                            AssumptionCache *AC) {
+static std::optional<bool> FoldCondBranchOnValueKnownInPredecessorImpl(
+    BranchInst *BI, DomTreeUpdater *DTU, const DataLayout &DL,
+    const TargetTransformInfo &TTI, AssumptionCache *AC) {
   SmallMapVector<ConstantInt *, SmallSetVector<BasicBlock *, 2>, 2> KnownValues;
   BasicBlock *BB = BI->getParent();
   Value *Cond = BI->getCondition();
@@ -3332,7 +3336,7 @@ FoldCondBranchOnValueKnownInPredecessorImpl(BranchInst *BI, DomTreeUpdater *DTU,
   // Now we know that this block has multiple preds and two succs.
   // Check that the block is small enough and values defined in the block are
   // not used outside of it.
-  if (!BlockIsSimpleEnoughToThreadThrough(BB))
+  if (!BlockIsSimpleEnoughToThreadThrough(BB, TTI))
     return false;
 
   for (const auto &Pair : KnownValues) {
@@ -3459,15 +3463,14 @@ FoldCondBranchOnValueKnownInPredecessorImpl(BranchInst *BI, DomTreeUpdater *DTU,
   return false;
 }
 
-static bool FoldCondBranchOnValueKnownInPredecessor(BranchInst *BI,
-                                                    DomTreeUpdater *DTU,
-                                                    const DataLayout &DL,
-                                                    AssumptionCache *AC) {
+static bool FoldCondBranchOnValueKnownInPredecessor(
+    BranchInst *BI, DomTreeUpdater *DTU, const DataLayout &DL,
+    const TargetTransformInfo &TTI, AssumptionCache *AC) {
   std::optional<bool> Result;
   bool EverChanged = false;
   do {
     // Note that None means "we changed things, but recurse further."
-    Result = FoldCondBranchOnValueKnownInPredecessorImpl(BI, DTU, DL, AC);
+    Result = FoldCondBranchOnValueKnownInPredecessorImpl(BI, DTU, DL, TTI, AC);
     EverChanged |= Result == std::nullopt || *Result;
   } while (Result == std::nullopt);
   return EverChanged;
@@ -7543,7 +7546,7 @@ bool SimplifyCFGOpt::simplifyCondBranch(BranchInst *BI, IRBuilder<> &Builder) {
   // If this is a branch on something for which we know the constant value in
   // predecessors (e.g. a phi node in the current block), thread control
   // through this block.
-  if (FoldCondBranchOnValueKnownInPredecessor(BI, DTU, DL, Options.AC))
+  if (FoldCondBranchOnValueKnownInPredecessor(BI, DTU, DL, TTI, Options.AC))
     return requestResimplify();
 
   // Scan predecessor blocks for conditional branches.

diff  --git a/llvm/test/Transforms/SimplifyCFG/AMDGPU/skip-threading.ll b/llvm/test/Transforms/SimplifyCFG/AMDGPU/skip-threading.ll
new file mode 100644
index 0000000000000..b1262e294c6d0
--- /dev/null
+++ b/llvm/test/Transforms/SimplifyCFG/AMDGPU/skip-threading.ll
@@ -0,0 +1,44 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt -mtriple=amdgcn -S -passes=simplifycfg < %s | FileCheck %s
+
+declare void @bar1()
+declare void @bar2()
+declare void @bar3()
+
+define i32 @test_01a(i32 %a) {
+; CHECK-LABEL: define i32 @test_01a(
+; CHECK-SAME: i32 [[A:%.*]]) {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[COND:%.*]] = icmp eq i32 [[A]], 0
+; CHECK-NEXT:    br i1 [[COND]], label %[[MERGE:.*]], label %[[IF_FALSE:.*]]
+; CHECK:       [[IF_FALSE]]:
+; CHECK-NEXT:    call void @bar1()
+; CHECK-NEXT:    br label %[[MERGE]]
+; CHECK:       [[MERGE]]:
+; CHECK-NEXT:    call void @bar2()
+; CHECK-NEXT:    br i1 [[COND]], label %[[EXIT:.*]], label %[[IF_FALSE_2:.*]]
+; CHECK:       [[IF_FALSE_2]]:
+; CHECK-NEXT:    call void @bar3()
+; CHECK-NEXT:    br label %[[EXIT]]
+; CHECK:       [[EXIT]]:
+; CHECK-NEXT:    ret i32 [[A]]
+;
+entry:
+  %cond = icmp eq i32 %a, 0
+  br i1 %cond, label %merge, label %if.false
+
+if.false:
+  call void @bar1()
+  br label %merge
+
+merge:
+  call void @bar2()
+  br i1 %cond, label %exit, label %if.false.2
+
+if.false.2:
+  call void @bar3()
+  br label %exit
+
+exit:
+  ret i32 %a
+}

diff  --git a/llvm/test/Transforms/SimplifyCFG/convergent.ll b/llvm/test/Transforms/SimplifyCFG/convergent.ll
index 6ba51e06460c2..d148063589de6 100644
--- a/llvm/test/Transforms/SimplifyCFG/convergent.ll
+++ b/llvm/test/Transforms/SimplifyCFG/convergent.ll
@@ -4,6 +4,9 @@
 ; RUN: opt -S -passes='simplifycfg<hoist-common-insts;sink-common-insts>' < %s | FileCheck -check-prefixes=CHECK,SINK %s
 
 declare void @foo() convergent
+declare void @bar1()
+declare void @bar2()
+declare void @bar3()
 declare i32 @tid()
 declare i32 @mbcnt(i32 %a, i32 %b) convergent
 declare i32 @bpermute(i32 %a, i32 %b) convergent
@@ -45,6 +48,42 @@ exit:
   ret i32 %a
 }
 
+define i32 @test_01a(i32 %a) {
+; CHECK-LABEL: @test_01a(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[COND:%.*]] = icmp eq i32 [[A:%.*]], 0
+; CHECK-NEXT:    br i1 [[COND]], label [[EXIT_CRITEDGE:%.*]], label [[IF_FALSE:%.*]]
+; CHECK:       if.false:
+; CHECK-NEXT:    call void @bar1()
+; CHECK-NEXT:    call void @bar2()
+; CHECK-NEXT:    call void @bar3()
+; CHECK-NEXT:    br label [[EXIT:%.*]]
+; CHECK:       exit.critedge:
+; CHECK-NEXT:    call void @bar2()
+; CHECK-NEXT:    br label [[EXIT]]
+; CHECK:       exit:
+; CHECK-NEXT:    ret i32 [[A]]
+;
+entry:
+  %cond = icmp eq i32 %a, 0
+  br i1 %cond, label %merge, label %if.false
+
+if.false:
+  call void @bar1()
+  br label %merge
+
+merge:
+  call void @bar2()
+  br i1 %cond, label %exit, label %if.false.2
+
+if.false.2:
+  call void @bar3()
+  br label %exit
+
+exit:
+  ret i32 %a
+}
+
 define void @test_02(ptr %y.coerce) convergent {
 ; NOSINK-LABEL: @test_02(
 ; NOSINK-NEXT:  entry:


        


More information about the llvm-commits mailing list