[llvm] [DFAJumpThreading] Add an early exit heuristic for unpredictable values (PR #85015)

Usman Nadeem via llvm-commits llvm-commits at lists.llvm.org
Tue Mar 12 21:05:00 PDT 2024


https://github.com/UsmanNadeem updated https://github.com/llvm/llvm-project/pull/85015

>From 90c031f3d7ef4ff01f1673f7a6d0cbafef9b3f66 Mon Sep 17 00:00:00 2001
From: "Nadeem, Usman" <mnadeem at quicinc.com>
Date: Tue, 12 Mar 2024 17:20:43 -0700
Subject: [PATCH] [DFAJumpThreading] Add an early exit heuristic for
 unpredictable values

Currently the algorithm does not exit early on unpredictable values.
It waits until all the paths have been enumerated and then checks
whether any of those paths contain that value. Waiting this long
leads to a lot of wasted computation and higher compile time.

In this patch I have added a heuristic that checks whether the value
comes from the same inner loop as the switch. If so, it is likely
that the value will also be seen on a threadable path, in which case
the code in `getStateDefMap()` would return an empty map anyway, so
we can give up early.

I tested this on the LLVM test suite, and the only change in the
number of threaded switches was in 7zip (before 23, after 18).
In all of the affected cases the current algorithm was only partially
threading the loop because it was hitting the limit on the number of
paths to be explored. When this limit is increased, even the current
algorithm finds paths on which the unpredictable value is seen.

Compile time (with the pass enabled by default and this patch):
  https://llvm-compile-time-tracker.com/compare.php?from=8c5e9cf737138aba22a4a8f64ef2c5efc80dd7f9&to=42c75d888058b35c6d15901b34e36251d8f766b9&stat=instructions:u

Change-Id: Id6b61a2ce177cdb433c97b7916218a7fc2092d73
---
 .../Transforms/Scalar/DFAJumpThreading.cpp    |  53 ++++++--
 .../DFAJumpThreading/dfa-unfold-select.ll     |   2 +-
 .../unpredictable-heuristic.ll                | 124 ++++++++++++++++++
 3 files changed, 164 insertions(+), 15 deletions(-)
 create mode 100644 llvm/test/Transforms/DFAJumpThreading/unpredictable-heuristic.ll

diff --git a/llvm/lib/Transforms/Scalar/DFAJumpThreading.cpp b/llvm/lib/Transforms/Scalar/DFAJumpThreading.cpp
index 85d4065286e41f..768e018d70b79c 100644
--- a/llvm/lib/Transforms/Scalar/DFAJumpThreading.cpp
+++ b/llvm/lib/Transforms/Scalar/DFAJumpThreading.cpp
@@ -65,6 +65,7 @@
 #include "llvm/Analysis/AssumptionCache.h"
 #include "llvm/Analysis/CodeMetrics.h"
 #include "llvm/Analysis/DomTreeUpdater.h"
+#include "llvm/Analysis/LoopInfo.h"
 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
 #include "llvm/Analysis/TargetTransformInfo.h"
 #include "llvm/IR/CFG.h"
@@ -95,6 +96,11 @@ static cl::opt<bool>
                     cl::desc("View the CFG before DFA Jump Threading"),
                     cl::Hidden, cl::init(false));
 
+static cl::opt<bool> EarlyExitHeuristic(
+    "dfa-early-exit-heuristic",
+    cl::desc("Exit early if an unpredictable value comes from the same loop"),
+    cl::Hidden, cl::init(true));
+
 static cl::opt<unsigned> MaxPathLength(
     "dfa-max-path-length",
     cl::desc("Max number of blocks searched to find a threading path"),
@@ -131,9 +137,9 @@ void unfold(DomTreeUpdater *DTU, SelectInstToUnfold SIToUnfold,
 
 class DFAJumpThreading {
 public:
-  DFAJumpThreading(AssumptionCache *AC, DominatorTree *DT,
+  DFAJumpThreading(AssumptionCache *AC, DominatorTree *DT, LoopInfo *LI,
                    TargetTransformInfo *TTI, OptimizationRemarkEmitter *ORE)
-      : AC(AC), DT(DT), TTI(TTI), ORE(ORE) {}
+      : AC(AC), DT(DT), LI(LI), TTI(TTI), ORE(ORE) {}
 
   bool run(Function &F);
 
@@ -161,6 +167,7 @@ class DFAJumpThreading {
 
   AssumptionCache *AC;
   DominatorTree *DT;
+  LoopInfo *LI; // Forwarded to MainSwitch for its early-exit heuristic.
   TargetTransformInfo *TTI;
   OptimizationRemarkEmitter *ORE;
 };
@@ -378,7 +385,8 @@ inline raw_ostream &operator<<(raw_ostream &OS, const ThreadingPath &TPath) {
 #endif
 
 struct MainSwitch {
-  MainSwitch(SwitchInst *SI, OptimizationRemarkEmitter *ORE) {
+  MainSwitch(SwitchInst *SI, LoopInfo *LI, OptimizationRemarkEmitter *ORE)
+      : LI(LI) {
     if (isCandidate(SI)) {
       Instr = SI;
     } else {
@@ -402,7 +410,7 @@ struct MainSwitch {
   ///
   /// Also, collect select instructions to unfold.
   bool isCandidate(const SwitchInst *SI) {
-    std::deque<Value *> Q;
+    std::deque<std::pair<Value *, BasicBlock *>> Q;
     SmallSet<Value *, 16> SeenValues;
     SelectInsts.clear();
 
@@ -411,22 +419,24 @@ struct MainSwitch {
     if (!isa<PHINode>(SICond))
       return false;
 
-    addToQueue(SICond, Q, SeenValues);
+    addToQueue(SICond, nullptr, Q, SeenValues); // Root: no incoming block.
 
     while (!Q.empty()) {
-      Value *Current = Q.front();
+      Value *Current = Q.front().first;
+      BasicBlock *CurrentIncomingBB = Q.front().second;
       Q.pop_front();
 
       if (auto *Phi = dyn_cast<PHINode>(Current)) {
-        for (Value *Incoming : Phi->incoming_values()) {
-          addToQueue(Incoming, Q, SeenValues);
+        for (unsigned I = 0, E = Phi->getNumIncomingValues(); I != E; ++I) {
+          addToQueue(Phi->getIncomingValue(I), Phi->getIncomingBlock(I), Q,
+                     SeenValues);
         }
         LLVM_DEBUG(dbgs() << "\tphi: " << *Phi << "\n");
       } else if (SelectInst *SelI = dyn_cast<SelectInst>(Current)) {
         if (!isValidSelectInst(SelI))
           return false;
-        addToQueue(SelI->getTrueValue(), Q, SeenValues);
-        addToQueue(SelI->getFalseValue(), Q, SeenValues);
+        addToQueue(SelI->getTrueValue(), CurrentIncomingBB, Q, SeenValues);
+        addToQueue(SelI->getFalseValue(), CurrentIncomingBB, Q, SeenValues);
         LLVM_DEBUG(dbgs() << "\tselect: " << *SelI << "\n");
         if (auto *SelIUse = dyn_cast<PHINode>(SelI->user_back()))
           SelectInsts.push_back(SelectInstToUnfold(SelI, SelIUse));
@@ -439,6 +449,18 @@ struct MainSwitch {
         // initial switch values that can be ignored (they will hit the
         // unthreaded switch) but this assumption will get checked later after
         // paths have been enumerated (in function getStateDefMap).
+
+        // If the unpredictable value flows in through a block of the switch's
+        // own innermost loop, it will likely also appear on the enumerated
+        // paths and we would only bail out after enumerating all of them.
+        // Exiting here saves compile time, since that search can be expensive.
+        if (EarlyExitHeuristic && LI->getLoopFor(SI->getParent()) ==
+                                      LI->getLoopFor(CurrentIncomingBB)) {
+          LLVM_DEBUG(dbgs()
+                     << "\tExiting early due to unpredictability heuristic.\n");
+          return false;
+        }
+
         continue;
       }
     }
@@ -446,11 +468,12 @@ struct MainSwitch {
     return true;
   }
 
-  void addToQueue(Value *Val, std::deque<Value *> &Q,
+  void addToQueue(Value *Val, BasicBlock *BB,
+                  std::deque<std::pair<Value *, BasicBlock *>> &Q,
                   SmallSet<Value *, 16> &SeenValues) {
     if (SeenValues.contains(Val))
       return;
-    Q.push_back(Val);
+    Q.push_back({Val, BB});
     SeenValues.insert(Val);
   }
 
@@ -488,6 +511,7 @@ struct MainSwitch {
     return true;
   }
 
+  LoopInfo *LI; // Used only by the early-exit heuristic in isCandidate().
   SwitchInst *Instr = nullptr;
   SmallVector<SelectInstToUnfold, 4> SelectInsts;
 };
@@ -1262,7 +1286,7 @@ bool DFAJumpThreading::run(Function &F) {
 
     LLVM_DEBUG(dbgs() << "\nCheck if SwitchInst in BB " << BB.getName()
                       << " is a candidate\n");
-    MainSwitch Switch(SI, ORE);
+    MainSwitch Switch(SI, LI, ORE);
 
     if (!Switch.getInstr())
       continue;
@@ -1315,10 +1339,11 @@ PreservedAnalyses DFAJumpThreadingPass::run(Function &F,
                                             FunctionAnalysisManager &AM) {
   AssumptionCache &AC = AM.getResult<AssumptionAnalysis>(F);
   DominatorTree &DT = AM.getResult<DominatorTreeAnalysis>(F);
+  LoopInfo &LI = AM.getResult<LoopAnalysis>(F);
   TargetTransformInfo &TTI = AM.getResult<TargetIRAnalysis>(F);
   OptimizationRemarkEmitter ORE(&F);
 
-  if (!DFAJumpThreading(&AC, &DT, &TTI, &ORE).run(F))
+  if (!DFAJumpThreading(&AC, &DT, &LI, &TTI, &ORE).run(F))
     return PreservedAnalyses::all();
 
   PreservedAnalyses PA;
diff --git a/llvm/test/Transforms/DFAJumpThreading/dfa-unfold-select.ll b/llvm/test/Transforms/DFAJumpThreading/dfa-unfold-select.ll
index df725b9a7fa47d..696bd55d2dfdd9 100644
--- a/llvm/test/Transforms/DFAJumpThreading/dfa-unfold-select.ll
+++ b/llvm/test/Transforms/DFAJumpThreading/dfa-unfold-select.ll
@@ -1,5 +1,5 @@
 ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
-; RUN: opt -S -passes=dfa-jump-threading %s | FileCheck %s
+; RUN: opt -S -passes=dfa-jump-threading -dfa-early-exit-heuristic=false %s | FileCheck %s
 
 ; These tests check if selects are unfolded properly for jump threading
 ; opportunities. There are three different patterns to consider:
diff --git a/llvm/test/Transforms/DFAJumpThreading/unpredictable-heuristic.ll b/llvm/test/Transforms/DFAJumpThreading/unpredictable-heuristic.ll
new file mode 100644
index 00000000000000..9743f0acc8165e
--- /dev/null
+++ b/llvm/test/Transforms/DFAJumpThreading/unpredictable-heuristic.ll
@@ -0,0 +1,124 @@
+; REQUIRES: asserts
+; RUN: opt -S -passes=dfa-jump-threading %s -debug-only=dfa-jump-threading 2>&1 | FileCheck %s
+
+; CHECK-COUNT-3: Exiting early due to unpredictability heuristic.
+
+@.str.1 = private unnamed_addr constant [3 x i8] c"10\00", align 1
+@.str.2 = private unnamed_addr constant [3 x i8] c"30\00", align 1
+@.str.3 = private unnamed_addr constant [3 x i8] c"20\00", align 1
+@.str.4 = private unnamed_addr constant [3 x i8] c"40\00", align 1
+
+define void @test1(i32 noundef %num, i32 noundef %num2) { ; unpredictable %add is computed inside the switch's loop
+entry:
+  br label %while.body
+
+while.body:                                       ; preds = %entry, %sw.epilog
+  %num.addr.0 = phi i32 [ %num, %entry ], [ %num.addr.1, %sw.epilog ] ; %num enters from %entry (outside the loop), so it does not trigger the heuristic
+  switch i32 %num.addr.0, label %sw.default [
+    i32 10, label %sw.bb
+    i32 30, label %sw.bb1
+    i32 20, label %sw.bb2
+    i32 40, label %sw.bb3
+  ]
+
+sw.bb:                                            ; preds = %while.body
+  %call.i = tail call i32 @bar(ptr noundef nonnull @.str.1)
+  br label %sw.epilog
+
+sw.bb1:                                           ; preds = %while.body
+  %call.i4 = tail call i32 @bar(ptr noundef nonnull @.str.2)
+  br label %sw.epilog
+
+sw.bb2:                                           ; preds = %while.body
+  %call.i5 = tail call i32 @bar(ptr noundef nonnull @.str.3)
+  br label %sw.epilog
+
+sw.bb3:                                           ; preds = %while.body
+  %call.i6 = tail call i32 @bar(ptr noundef nonnull @.str.4)
+  %call = tail call noundef i32 @foo()
+  %add = add nsw i32 %call, %num2
+  br label %sw.epilog
+
+sw.default:                                       ; preds = %while.body
+  ret void
+
+sw.epilog:                                        ; preds = %sw.bb3, %sw.bb2, %sw.bb1, %sw.bb
+  %num.addr.1 = phi i32 [ %add, %sw.bb3 ], [ 40, %sw.bb2 ], [ 20, %sw.bb1 ], [ 30, %sw.bb ] ; %add arrives via %sw.bb3, in the switch's loop: early exit
+  br label %while.body
+}
+
+
+define void @test2(i32 noundef %num, i32 noundef %num2) { ; unpredictable argument %num2 reaches the switch via %sw.bb3, inside its loop
+entry:
+  br label %while.body
+
+while.body:                                       ; preds = %entry, %sw.epilog
+  %num.addr.0 = phi i32 [ %num, %entry ], [ %num.addr.1, %sw.epilog ]
+  switch i32 %num.addr.0, label %sw.default [
+    i32 10, label %sw.epilog
+    i32 30, label %sw.bb1
+    i32 20, label %sw.bb2
+    i32 40, label %sw.bb3
+  ]
+
+sw.bb1:                                           ; preds = %while.body
+  br label %sw.epilog
+
+sw.bb2:                                           ; preds = %while.body
+  br label %sw.epilog
+
+sw.bb3:                                           ; preds = %while.body
+  br label %sw.epilog
+
+sw.default:                                       ; preds = %while.body
+  ret void
+
+sw.epilog:                                        ; preds = %while.body, %sw.bb3, %sw.bb2, %sw.bb1
+  %.str.4.sink = phi ptr [ @.str.4, %sw.bb3 ], [ @.str.3, %sw.bb2 ], [ @.str.2, %sw.bb1 ], [ @.str.1, %while.body ]
+  %num.addr.1 = phi i32 [ %num2, %sw.bb3 ], [ 40, %sw.bb2 ], [ 20, %sw.bb1 ], [ 30, %while.body ]
+  %call.i6 = tail call i32 @bar(ptr noundef nonnull %.str.4.sink)
+  br label %while.body
+}
+
+
+define void @test3(i32 noundef %num, i32 noundef %num2) { ; %add is defined outside the loop; the heuristic keys on the incoming block (%sw.bb3), not the definition
+entry:
+  %add = add nsw i32 %num2, 40
+  br label %while.body
+
+while.body:                                       ; preds = %entry, %sw.epilog
+  %num.addr.0 = phi i32 [ %num, %entry ], [ %num.addr.1, %sw.epilog ]
+  switch i32 %num.addr.0, label %sw.default [
+    i32 10, label %sw.bb
+    i32 30, label %sw.bb1
+    i32 20, label %sw.bb2
+    i32 40, label %sw.bb3
+  ]
+
+sw.bb:                                            ; preds = %while.body
+  %call.i = tail call i32 @bar(ptr noundef nonnull @.str.1)
+  br label %sw.epilog
+
+sw.bb1:                                           ; preds = %while.body
+  %call.i5 = tail call i32 @bar(ptr noundef nonnull @.str.2)
+  br label %sw.epilog
+
+sw.bb2:                                           ; preds = %while.body
+  %call.i6 = tail call i32 @bar(ptr noundef nonnull @.str.3)
+  br label %sw.epilog
+
+sw.bb3:                                           ; preds = %while.body
+  %call.i7 = tail call i32 @bar(ptr noundef nonnull @.str.4)
+  br label %sw.epilog
+
+sw.default:                                       ; preds = %while.body
+  ret void
+
+sw.epilog:                                        ; preds = %sw.bb3, %sw.bb2, %sw.bb1, %sw.bb
+  %num.addr.1 = phi i32 [ %add, %sw.bb3 ], [ 40, %sw.bb2 ], [ 20, %sw.bb1 ], [ 30, %sw.bb ]
+  br label %while.body
+}
+
+
+declare noundef i32 @foo() ; opaque result feeds the unpredictable %add in @test1
+declare noundef i32 @bar(ptr nocapture noundef readonly) ; called from the switch cases



More information about the llvm-commits mailing list