[llvm] [SimplifyCFG] Fold `select` of equality comparison into switch predecessor (PR #79177)

via llvm-commits llvm-commits at lists.llvm.org
Tue Jan 23 09:44:10 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-transforms

Author: Antonio Frighetto (antoniofrighetto)

<details>
<summary>Changes</summary>

When a conditional basic block is speculatively executed, the phi operands of the fall-through block are rewritten with a `select`. Revisit `FoldValueComparisonIntoPredecessors` to also handle a branch whose condition is a `select` with an equality comparison in its false arm and the constant `true` as its true value.

While extending `FoldValueComparisonIntoPredecessors` seems to fit here (as it allows leveraging a good amount of the existing code), I’m not sure whether it contradicts the original semantics of the function (the code seems to have been there for a long time).

After the proposed transformation, the following:


```
bb9:                                              ; preds = %entry, %entry, %entry, %switch.edge, %bb8
  %_3.0 = phi i1 [ false, %bb8 ], [ true, %entry ], [ true, %entry ], [ true, %switch.edge ], [ true, %entry ]
  br i1 %_3.0, label %bb3, label %bb2

bb2:                                              ; preds = %bb9
  %_12 = icmp eq i32 %c, 119
  br label %bb3

```

is folded into:
```
bb9:                                              ; preds = %entry, %entry, %entry, %switch.edge, %bb8
  %_3.0 = phi i1 [ false, %bb8 ], [ true, %entry ], [ true, %entry ], [ true, %switch.edge ], [ true, %entry ]
  %_12 = icmp eq i32 %c, 119
  %spec.select = select i1 %_3.0, i1 true, i1 %_12
  br i1 %_3.0, label %bb3, label %bb2

bb2:                                              ; preds = %bb9
  br label %bb3

```

This happens because the `icmp` gets speculatively executed. What is left to be done may be to speculate the branch condition into the select as well, so as to further simplify the last case.

Partially fixes #<!-- -->63470.

---
Full diff: https://github.com/llvm/llvm-project/pull/79177.diff


2 Files Affected:

- (modified) llvm/lib/Transforms/Utils/SimplifyCFG.cpp (+92-31) 
- (added) llvm/test/Transforms/SimplifyCFG/switch-multiple-comparisons-consolidation.ll (+57) 


``````````diff
diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
index 13eae549b2ce41b..3facf8f84d26ccd 100644
--- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
@@ -769,8 +769,9 @@ static void EraseTerminatorAndDCECond(Instruction *TI,
     RecursivelyDeleteTriviallyDeadInstructions(Cond, nullptr, MSSAU);
 }
 
-/// Return true if the specified terminator checks
-/// to see if a value is equal to constant integer value.
+/// Return true if the specified terminator checks to see if a value is equal to
+/// a constant integer value, or is equal to a select of a constant integer
+/// value, appearing in the false arm.
 Value *SimplifyCFGOpt::isValueEqualityComparison(Instruction *TI) {
   Value *CV = nullptr;
   if (SwitchInst *SI = dyn_cast<SwitchInst>(TI)) {
@@ -778,12 +779,28 @@ Value *SimplifyCFGOpt::isValueEqualityComparison(Instruction *TI) {
     // predecessors unless there is only one predecessor.
     if (!SI->getParent()->hasNPredecessorsOrMore(128 / SI->getNumSuccessors()))
       CV = SI->getCondition();
-  } else if (BranchInst *BI = dyn_cast<BranchInst>(TI))
-    if (BI->isConditional() && BI->getCondition()->hasOneUse())
-      if (ICmpInst *ICI = dyn_cast<ICmpInst>(BI->getCondition())) {
-        if (ICI->isEquality() && GetConstantInt(ICI->getOperand(1), DL))
-          CV = ICI->getOperand(0);
+  } else if (BranchInst *BI = dyn_cast<BranchInst>(TI)) {
+    auto HandleICmpI = [&](ICmpInst *ICI, auto &&IsEquality) -> Value * {
+      if (ICI->hasOneUse() && IsEquality(ICI) &&
+          GetConstantInt(ICI->getOperand(1), DL))
+        return ICI->getOperand(0);
+      return nullptr;
+    };
+
+    if (BI->isConditional()) {
+      if (auto *ICI = dyn_cast<ICmpInst>(BI->getCondition())) {
+        CV = HandleICmpI(ICI, [](ICmpInst *ICI) { return ICI->isEquality(); });
+      } else if (auto *SI = dyn_cast<SelectInst>(BI->getCondition())) {
+        // Chain of comparisons are already handled.
+        if (isa<ICmpInst>(SI->getCondition()))
+          return nullptr;
+        if (auto *ICI = dyn_cast<ICmpInst>(SI->getFalseValue()))
+          CV = HandleICmpI(ICI, [](ICmpInst *ICI) {
+            return ICI->getPredicate() == ICmpInst::ICMP_EQ;
+          });
       }
+    }
+  }
 
   // Unwrap any lossless ptrtoint cast.
   if (CV) {
@@ -809,7 +826,14 @@ BasicBlock *SimplifyCFGOpt::GetValueEqualityComparisonCases(
   }
 
   BranchInst *BI = cast<BranchInst>(TI);
-  ICmpInst *ICI = cast<ICmpInst>(BI->getCondition());
+  ICmpInst *ICI = nullptr;
+  if (auto *SI = dyn_cast<SelectInst>(BI->getCondition())) {
+    // We have already checked in `isValueEqualityComparison` that Succ and
+    // 'default' block depend on the icmp, and the latter is on the false arm.
+    ICI = cast<ICmpInst>(SI->getFalseValue());
+  } else {
+    ICI = cast<ICmpInst>(BI->getCondition());
+  }
   BasicBlock *Succ = BI->getSuccessor(ICI->getPredicate() == ICmpInst::ICMP_NE);
   Cases.push_back(ValueEqualityComparisonCase(
       GetConstantInt(ICI->getOperand(1), DL), Succ));
@@ -1204,7 +1228,19 @@ bool SimplifyCFGOpt::PerformValueComparisonIntoPredecessorFolding(
   } else if (PredHasWeights)
     SuccWeights.assign(1 + BBCases.size(), 1);
 
-  if (PredDefault == BB) {
+  auto I = BB->instructionsWithoutDebug(true).begin();
+  if (auto *BI = dyn_cast<BranchInst>(TI);
+      BI && isa<SelectInst>(BI->getCondition())) {
+    // TODO: Preserve branch weight metadata
+    // We handle a phi node by the time we fold the select of a comparison.
+    PHINode &PHI = cast<PHINode>(*I);
+    auto *OldCond = BI->getCondition();
+    BI->setCondition(&PHI);
+    // We have harvested only one comparison.
+    PredCases.push_back(ValueEqualityComparisonCase(BBCases[0].Value, BB));
+    ++NewSuccessors[BB];
+    RecursivelyDeleteTriviallyDeadInstructions(OldCond);
+  } else if (PredDefault == BB) {
     // If this is the default destination from PTI, only the edges in TI
     // that don't occur in PTI, or that branch to BB will be activated.
     std::set<ConstantInt *, ConstantIntOrdering> PTIHandled;
@@ -1315,7 +1351,10 @@ bool SimplifyCFGOpt::PerformValueComparisonIntoPredecessorFolding(
        NewSuccessors) {
     for (auto I : seq(NewSuccessor.second)) {
       (void)I;
-      AddPredecessorToBlock(NewSuccessor.first, Pred, BB);
+      // If the new successor happens to be `BB` itself, we are dealing with the
+      // case of the select of a comparison.
+      AddPredecessorToBlock(NewSuccessor.first, Pred,
+                            NewSuccessor.first != BB ? BB : Pred);
     }
     if (DTU && !SuccsOfPred.contains(NewSuccessor.first))
       Updates.push_back({DominatorTree::Insert, Pred, NewSuccessor.first});
@@ -1345,30 +1384,36 @@ bool SimplifyCFGOpt::PerformValueComparisonIntoPredecessorFolding(
 
   EraseTerminatorAndDCECond(PTI);
 
-  // Okay, last check.  If BB is still a successor of PSI, then we must
-  // have an infinite loop case.  If so, add an infinitely looping block
-  // to handle the case to preserve the behavior of the code.
+  // Okay, last check. If we are not handling a select of comparison, and BB is
+  // still a successor of PSI, then we must have an infinite loop case.  If so,
+  // add an infinitely looping block to handle the case to preserve the behavior
+  // of the code.
   BasicBlock *InfLoopBlock = nullptr;
-  for (unsigned i = 0, e = NewSI->getNumSuccessors(); i != e; ++i)
-    if (NewSI->getSuccessor(i) == BB) {
-      if (!InfLoopBlock) {
-        // Insert it at the end of the function, because it's either code,
-        // or it won't matter if it's hot. :)
-        InfLoopBlock =
-            BasicBlock::Create(BB->getContext(), "infloop", BB->getParent());
-        BranchInst::Create(InfLoopBlock, InfLoopBlock);
-        if (DTU)
-          Updates.push_back(
-              {DominatorTree::Insert, InfLoopBlock, InfLoopBlock});
+  if (!isa<PHINode>(*I)) {
+    for (unsigned i = 0, e = NewSI->getNumSuccessors(); i != e; ++i)
+      if (NewSI->getSuccessor(i) == BB) {
+        if (!InfLoopBlock) {
+          // Insert it at the end of the function, because it's either code,
+          // or it won't matter if it's hot. :)
+          InfLoopBlock =
+              BasicBlock::Create(BB->getContext(), "infloop", BB->getParent());
+          BranchInst::Create(InfLoopBlock, InfLoopBlock);
+          if (DTU)
+            Updates.push_back(
+                {DominatorTree::Insert, InfLoopBlock, InfLoopBlock});
+        }
+        NewSI->setSuccessor(i, InfLoopBlock);
       }
-      NewSI->setSuccessor(i, InfLoopBlock);
-    }
+  }
 
   if (DTU) {
     if (InfLoopBlock)
       Updates.push_back({DominatorTree::Insert, Pred, InfLoopBlock});
 
-    Updates.push_back({DominatorTree::Delete, Pred, BB});
+    if (!isa<PHINode>(*I))
+      Updates.push_back({DominatorTree::Delete, Pred, BB});
+    else
+      Updates.push_back({DominatorTree::Insert, Pred, BB});
 
     DTU->applyUpdates(Updates);
   }
@@ -1377,10 +1422,20 @@ bool SimplifyCFGOpt::PerformValueComparisonIntoPredecessorFolding(
   return true;
 }
 
-/// The specified terminator is a value equality comparison instruction
-/// (either a switch or a branch on "X == c").
-/// See if any of the predecessors of the terminator block are value comparisons
-/// on the same value.  If so, and if safe to do so, fold them together.
+/// The specified terminator is a value equality comparison instruction, either
+/// a switch or a branch on "X == c", or a branch on select whose false arm is
+/// "X == c". If we happen to have a select, previously generated after
+/// speculatively executing its fall-through basic block, of the following kind:
+/// \code
+///   BB:
+///     %phi = phi i1 [ false, %edge ], [ true, %switch ], [ true, %switch ]
+///     %icmp = icmp eq i32 %c, X
+///     %spec.select = select i1 %phi, i1 true, i1 %icmp
+///     br i1 %spec.select1, label %EndBB, label %ThenBB
+/// \endcode
+/// We attempt folding them into its predecessor. To do so, see if any of the
+/// predecessors of the terminator block are value comparisons on the same
+/// value.
 bool SimplifyCFGOpt::FoldValueComparisonIntoPredecessors(Instruction *TI,
                                                          IRBuilder<> &Builder) {
   BasicBlock *BB = TI->getParent();
@@ -7285,6 +7340,12 @@ bool SimplifyCFGOpt::simplifyCondBranch(BranchInst *BI, IRBuilder<> &Builder) {
       ++I;
       if (&*I == BI && FoldValueComparisonIntoPredecessors(BI, Builder))
         return requestResimplify();
+    } else if (isa<PHINode>(*I)) {
+      if (auto *SI = dyn_cast<SelectInst>(BI->getCondition()))
+        if (auto *CI = dyn_cast<ConstantInt>(SI->getTrueValue());
+            CI && CI->isOne())
+          if (FoldValueComparisonIntoPredecessors(BI, Builder))
+            return requestResimplify();
     }
   }
 
diff --git a/llvm/test/Transforms/SimplifyCFG/switch-multiple-comparisons-consolidation.ll b/llvm/test/Transforms/SimplifyCFG/switch-multiple-comparisons-consolidation.ll
new file mode 100644
index 000000000000000..fb39e50c15d0ab0
--- /dev/null
+++ b/llvm/test/Transforms/SimplifyCFG/switch-multiple-comparisons-consolidation.ll
@@ -0,0 +1,57 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
+; RUN: opt < %s -passes=simplifycfg -simplifycfg-require-and-preserve-domtree=1 -S | FileCheck %s
+
+define i1 @test(i32 %c) {
+; CHECK-LABEL: define i1 @test(
+; CHECK-SAME: i32 [[C:%.*]]) {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    switch i32 [[C]], label [[BB8:%.*]] [
+; CHECK-NEXT:      i32 115, label [[BB9:%.*]]
+; CHECK-NEXT:      i32 109, label [[BB9]]
+; CHECK-NEXT:      i32 104, label [[BB9]]
+; CHECK-NEXT:      i32 100, label [[BB9]]
+; CHECK-NEXT:    ]
+; CHECK:       bb8:
+; CHECK-NEXT:    br label [[BB9]]
+; CHECK:       bb9:
+; CHECK-NEXT:    [[_3_0:%.*]] = phi i1 [ false, [[BB8]] ], [ true, [[ENTRY:%.*]] ], [ true, [[ENTRY]] ], [ true, [[ENTRY]] ], [ true, [[ENTRY]] ]
+; CHECK-NEXT:    [[_12:%.*]] = icmp eq i32 [[C]], 119
+; CHECK-NEXT:    [[SPEC_SELECT:%.*]] = select i1 [[_3_0]], i1 true, i1 [[_12]]
+; CHECK-NEXT:    ret i1 [[SPEC_SELECT]]
+;
+entry:
+  %i1 = icmp eq i32 %c, 115
+  br i1 %i1, label %bb12, label %bb11
+
+bb11:                                             ; preds = %entry
+  %_6 = icmp eq i32 %c, 109
+  br label %bb12
+
+bb12:                                             ; preds = %entry, %bb11
+  %_4.0 = phi i1 [ %_6, %bb11 ], [ true, %entry ]
+  br i1 %_4.0, label %bb9, label %bb8
+
+bb8:                                              ; preds = %bb12
+  %_8 = icmp eq i32 %c, 104
+  br label %bb9
+
+bb9:                                              ; preds = %bb12, %bb8
+  %_3.0 = phi i1 [ %_8, %bb8 ], [ true, %bb12 ]
+  br i1 %_3.0, label %bb6, label %bb5
+
+bb5:                                              ; preds = %bb9
+  %_10 = icmp eq i32 %c, 100
+  br label %bb6
+
+bb6:                                              ; preds = %bb9, %bb5
+  %_2.0 = phi i1 [ %_10, %bb5 ], [ true, %bb9 ]
+  br i1 %_2.0, label %bb3, label %bb2
+
+bb2:                                              ; preds = %bb6
+  %_12 = icmp eq i32 %c, 119
+  br label %bb3
+
+bb3:                                              ; preds = %bb6, %bb2
+  %i.0 = phi i1 [ %_12, %bb2 ], [ true, %bb6 ]
+  ret i1 %i.0
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/79177


More information about the llvm-commits mailing list