[llvm] r296363 - Fix a bug when unswitching on partial LIV for SwitchInst

Xin Tong via llvm-commits llvm-commits at lists.llvm.org
Mon Feb 27 10:00:13 PST 2017


Author: trentxintong
Date: Mon Feb 27 12:00:13 2017
New Revision: 296363

URL: http://llvm.org/viewvc/llvm-project?rev=296363&view=rev
Log:
Fix a bug when unswitching on partial LIV for SwitchInst

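Summary: Fix a bug when unswitching a SwitchInst on a partial loop-invariant value (LIV), i.e. an invariant operand found by walking up an AND/OR operator chain. Unswitching now tracks the operator chain: an AND-only chain is unswitched on 0, an OR-only chain on ~0, and a mixed AND/OR chain is not unswitched. The unswitched case value is recorded for the switch after unswitching succeeds, because a switch on a partial LIV is not a direct user of the invariant value.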

Reviewers: hfinkel, efriedma, sanjoy

Reviewed By: sanjoy

Subscribers: david2050, mzolotukhin, llvm-commits

Differential Revision: https://reviews.llvm.org/D29107

Modified:
    llvm/trunk/lib/Transforms/Scalar/LoopUnswitch.cpp
    llvm/trunk/test/Transforms/LoopUnswitch/basictest.ll

Modified: llvm/trunk/lib/Transforms/Scalar/LoopUnswitch.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Transforms/Scalar/LoopUnswitch.cpp?rev=296363&r1=296362&r2=296363&view=diff
==============================================================================
--- llvm/trunk/lib/Transforms/Scalar/LoopUnswitch.cpp (original)
+++ llvm/trunk/lib/Transforms/Scalar/LoopUnswitch.cpp Mon Feb 27 12:00:13 2017
@@ -374,9 +374,27 @@ Pass *llvm::createLoopUnswitchPass(bool
   return new LoopUnswitch(Os);
 }
 
+/// Operator chain lattice.
+enum OperatorChain {
+  OC_OpChainNone,    ///< There is no operator.
+  OC_OpChainOr,      ///< There are only ORs.
+  OC_OpChainAnd,     ///< There are only ANDs.
+  OC_OpChainMixed    ///< There are ANDs and ORs.
+};
+
 /// Cond is a condition that occurs in L. If it is invariant in the loop, or has
 /// an invariant piece, return the invariant. Otherwise, return null.
+///
+/// NOTE: FindLIVLoopCondition will not return a partial LIV by walking up a
+/// mixed operator chain, as we cannot reliably find a value which will simplify
+/// the operator chain. If the chain is AND-only or OR-only, we can use 0 or ~0
+/// to simplify the chain.
+///
+/// NOTE: In the case of a partial LIV in a mixed operator chain, we may be able
+/// to simplify the condition itself to a loop variant condition, but at the
+/// cost of creating an entirely new loop.
 static Value *FindLIVLoopCondition(Value *Cond, Loop *L, bool &Changed,
+                                   OperatorChain &ParentChain,
                                    DenseMap<Value *, Value *> &Cache) {
   auto CacheIt = Cache.find(Cond);
   if (CacheIt != Cache.end())
@@ -400,21 +418,53 @@ static Value *FindLIVLoopCondition(Value
     return Cond;
   }
 
+  // Walk up the operator chain to find partial invariant conditions.
   if (BinaryOperator *BO = dyn_cast<BinaryOperator>(Cond))
     if (BO->getOpcode() == Instruction::And ||
         BO->getOpcode() == Instruction::Or) {
-      // If either the left or right side is invariant, we can unswitch on this,
-      // which will cause the branch to go away in one loop and the condition to
-      // simplify in the other one.
-      if (Value *LHS =
-              FindLIVLoopCondition(BO->getOperand(0), L, Changed, Cache)) {
-        Cache[Cond] = LHS;
-        return LHS;
+      // Given the previous operator, compute the current operator chain status.
+      OperatorChain NewChain;
+      switch (ParentChain) {
+      case OC_OpChainNone:
+        NewChain = BO->getOpcode() == Instruction::And ? OC_OpChainAnd :
+                                      OC_OpChainOr;
+        break;
+      case OC_OpChainOr:
+        NewChain = BO->getOpcode() == Instruction::Or ? OC_OpChainOr :
+                                      OC_OpChainMixed;
+        break;
+      case OC_OpChainAnd:
+        NewChain = BO->getOpcode() == Instruction::And ? OC_OpChainAnd :
+                                      OC_OpChainMixed;
+        break;
+      case OC_OpChainMixed:
+        NewChain = OC_OpChainMixed;
+        break;
       }
-      if (Value *RHS =
-              FindLIVLoopCondition(BO->getOperand(1), L, Changed, Cache)) {
-        Cache[Cond] = RHS;
-        return RHS;
+
+      // If we reach a Mixed state, we do not want to keep walking up as we can not
+      // reliably find a value that will simplify the chain. With this check, we
+      // will return null on the first sight of mixed chain and the caller will
+      // either backtrack to find partial LIV in other operand or return null.
+      if (NewChain != OC_OpChainMixed) {
+        // Update the current operator chain type before we search up the chain.
+        ParentChain = NewChain;
+        // If either the left or right side is invariant, we can unswitch on this,
+        // which will cause the branch to go away in one loop and the condition to
+        // simplify in the other one.
+        if (Value *LHS = FindLIVLoopCondition(BO->getOperand(0), L, Changed,
+                                              ParentChain, Cache)) {
+          Cache[Cond] = LHS;
+          return LHS;
+        }
+        // We did not manage to find a partial LIV in operand(0). Backtrack and try
+        // operand(1).
+        ParentChain = NewChain;
+        if (Value *RHS = FindLIVLoopCondition(BO->getOperand(1), L, Changed,
+                                              ParentChain, Cache)) {
+          Cache[Cond] = RHS;
+          return RHS;
+        }
       }
     }
 
@@ -422,9 +472,21 @@ static Value *FindLIVLoopCondition(Value
   return nullptr;
 }
 
-static Value *FindLIVLoopCondition(Value *Cond, Loop *L, bool &Changed) {
+/// Cond is a condition that occurs in L. If it is invariant in the loop, or has
+/// an invariant piece, return the invariant along with the operator chain type.
+/// Otherwise, return null.
+static std::pair<Value *, OperatorChain> FindLIVLoopCondition(Value *Cond,
+                                                              Loop *L,
+                                                              bool &Changed) {
   DenseMap<Value *, Value *> Cache;
-  return FindLIVLoopCondition(Cond, L, Changed, Cache);
+  OperatorChain OpChain = OC_OpChainNone;
+  Value *FCond = FindLIVLoopCondition(Cond, L, Changed, OpChain, Cache);
+
+  // In case we do find a LIV, it can not be obtained by walking up a mixed
+  // operator chain.
+  assert((!FCond || OpChain != OC_OpChainMixed) &&
+        "Do not expect a partial LIV with mixed operator chain");
+  return {FCond, OpChain};
 }
 
 bool LoopUnswitch::runOnLoop(Loop *L, LPPassManager &LPM_Ref) {
@@ -556,7 +618,7 @@ bool LoopUnswitch::processCurrentLoop()
 
   for (IntrinsicInst *Guard : Guards) {
     Value *LoopCond =
-        FindLIVLoopCondition(Guard->getOperand(0), currentLoop, Changed);
+        FindLIVLoopCondition(Guard->getOperand(0), currentLoop, Changed).first;
     if (LoopCond &&
         UnswitchIfProfitable(LoopCond, ConstantInt::getTrue(Context))) {
       // NB! Unswitching (if successful) could have erased some of the
@@ -597,7 +659,7 @@ bool LoopUnswitch::processCurrentLoop()
         // See if this, or some part of it, is loop invariant.  If so, we can
         // unswitch on it if we desire.
         Value *LoopCond = FindLIVLoopCondition(BI->getCondition(),
-                                               currentLoop, Changed);
+                                               currentLoop, Changed).first;
         if (LoopCond &&
             UnswitchIfProfitable(LoopCond, ConstantInt::getTrue(Context), TI)) {
           ++NumBranches;
@@ -605,24 +667,49 @@ bool LoopUnswitch::processCurrentLoop()
         }
       }
     } else if (SwitchInst *SI = dyn_cast<SwitchInst>(TI)) {
-      Value *LoopCond = FindLIVLoopCondition(SI->getCondition(),
-                                             currentLoop, Changed);
+      Value *SC = SI->getCondition();
+      Value *LoopCond;
+      OperatorChain OpChain;
+      std::tie(LoopCond, OpChain) =
+        FindLIVLoopCondition(SC, currentLoop, Changed);
+
       unsigned NumCases = SI->getNumCases();
       if (LoopCond && NumCases) {
         // Find a value to unswitch on:
         // FIXME: this should chose the most expensive case!
         // FIXME: scan for a case with a non-critical edge?
         Constant *UnswitchVal = nullptr;
-
-        // Do not process same value again and again.
-        // At this point we have some cases already unswitched and
-        // some not yet unswitched. Let's find the first not yet unswitched one.
-        for (SwitchInst::CaseIt i = SI->case_begin(), e = SI->case_end();
-             i != e; ++i) {
-          Constant *UnswitchValCandidate = i.getCaseValue();
-          if (!BranchesInfo.isUnswitched(SI, UnswitchValCandidate)) {
-            UnswitchVal = UnswitchValCandidate;
-            break;
+        // Find a case value such that at least one case value is unswitched
+        // out.
+        if (OpChain == OC_OpChainAnd) {
+          // If the chain only has ANDs and the switch has a case value of 0.
+          // Dropping in a 0 to the chain will unswitch out the 0-casevalue.
+          auto *AllZero = cast<ConstantInt>(Constant::getNullValue(SC->getType()));
+          if (BranchesInfo.isUnswitched(SI, AllZero))
+            continue;
+          // We are unswitching 0 out.
+          UnswitchVal = AllZero;
+        } else if (OpChain == OC_OpChainOr) {
+          // If the chain only has ORs and the switch has a case value of ~0.
+          // Dropping in a ~0 to the chain will unswitch out the ~0-casevalue.
+          auto *AllOne = cast<ConstantInt>(Constant::getAllOnesValue(SC->getType()));
+          if (BranchesInfo.isUnswitched(SI, AllOne))
+            continue;
+          // We are unswitching ~0 out.
+          UnswitchVal = AllOne;
+        } else {
+          assert(OpChain == OC_OpChainNone && 
+                 "Expect to unswitch on trivial chain");
+          // Do not process same value again and again.
+          // At this point we have some cases already unswitched and
+          // some not yet unswitched. Let's find the first not yet unswitched one.
+          for (SwitchInst::CaseIt i = SI->case_begin(), e = SI->case_end();
+               i != e; ++i) {
+            Constant *UnswitchValCandidate = i.getCaseValue();
+            if (!BranchesInfo.isUnswitched(SI, UnswitchValCandidate)) {
+              UnswitchVal = UnswitchValCandidate;
+              break;
+            }
           }
         }
 
@@ -631,6 +718,11 @@ bool LoopUnswitch::processCurrentLoop()
 
         if (UnswitchIfProfitable(LoopCond, UnswitchVal)) {
           ++NumSwitches;
+          // In case of a full LIV, UnswitchVal is the value we unswitched out.
+          // In case of a partial LIV, we only unswitch when its an AND-chain
+          // or OR-chain. In both cases switch input value simplifies to
+          // UnswitchVal.
+          BranchesInfo.setUnswitched(SI, UnswitchVal);
           return true;
         }
       }
@@ -641,7 +733,7 @@ bool LoopUnswitch::processCurrentLoop()
          BBI != E; ++BBI)
       if (SelectInst *SI = dyn_cast<SelectInst>(BBI)) {
         Value *LoopCond = FindLIVLoopCondition(SI->getCondition(),
-                                               currentLoop, Changed);
+                                               currentLoop, Changed).first;
         if (LoopCond && UnswitchIfProfitable(LoopCond,
                                              ConstantInt::getTrue(Context))) {
           ++NumSelects;
@@ -900,7 +992,7 @@ bool LoopUnswitch::TryTrivialLoopUnswitc
       return false;
 
     Value *LoopCond = FindLIVLoopCondition(BI->getCondition(),
-                                           currentLoop, Changed);
+                                           currentLoop, Changed).first;
 
     // Unswitch only if the trivial condition itself is an LIV (not
     // partial LIV which could occur in and/or)
@@ -931,7 +1023,7 @@ bool LoopUnswitch::TryTrivialLoopUnswitc
   } else if (SwitchInst *SI = dyn_cast<SwitchInst>(CurrentTerm)) {
     // If this isn't switching on an invariant condition, we can't unswitch it.
     Value *LoopCond = FindLIVLoopCondition(SI->getCondition(),
-                                           currentLoop, Changed);
+                                           currentLoop, Changed).first;
 
     // Unswitch only if the trivial condition itself is an LIV (not
     // partial LIV which could occur in and/or)
@@ -969,6 +1061,9 @@ bool LoopUnswitch::TryTrivialLoopUnswitc
 
     UnswitchTrivialCondition(currentLoop, LoopCond, CondVal, LoopExitBB,
                              nullptr);
+
+    // We are only unswitching full LIV.
+    BranchesInfo.setUnswitched(SI, CondVal);
     ++NumSwitches;
     return true;
   }
@@ -1250,6 +1345,9 @@ void LoopUnswitch::RewriteLoopBodyWithCo
     SwitchInst *SI = dyn_cast<SwitchInst>(UI);
     if (!SI || !isa<ConstantInt>(Val)) continue;
 
+    // NOTE: if a case value for the switch is unswitched out, we record it
+    // after the unswitch finishes. We can not record it here as the switch
+    // is not a direct user of the partial LIV.
     SwitchInst::CaseIt DeadCase = SI->findCaseValue(cast<ConstantInt>(Val));
     // Default case is live for multiple values.
     if (DeadCase == SI->case_default()) continue;
@@ -1262,8 +1360,6 @@ void LoopUnswitch::RewriteLoopBodyWithCo
     BasicBlock *SISucc = DeadCase.getCaseSuccessor();
     BasicBlock *Latch = L->getLoopLatch();
 
-    BranchesInfo.setUnswitched(SI, Val);
-
     if (!SI->findCaseDest(SISucc)) continue;  // Edge is critical.
     // If the DeadCase successor dominates the loop latch, then the
     // transformation isn't safe since it will delete the sole predecessor edge

Modified: llvm/trunk/test/Transforms/LoopUnswitch/basictest.ll
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/Transforms/LoopUnswitch/basictest.ll?rev=296363&r1=296362&r2=296363&view=diff
==============================================================================
--- llvm/trunk/test/Transforms/LoopUnswitch/basictest.ll (original)
+++ llvm/trunk/test/Transforms/LoopUnswitch/basictest.ll Mon Feb 27 12:00:13 2017
@@ -101,6 +101,217 @@ loop_exit:
 ; CHECK: }
 }
 
+; Make sure we unswitch %a == 0 out of the loop.
+;
+; CHECK: define void @and_i2_as_switch_input(i2
+; CHECK: entry:
+; This is an indication that the loop has been unswitched.
+; CHECK: icmp eq i2 %a, 0
+; CHECK: br
+; There should be no more unswitching after the 1st unswitch.
+; CHECK-NOT: icmp eq
+; CHECK: ret
+define void @and_i2_as_switch_input(i2 %a) {
+entry:
+  br label %for.body
+
+for.body:
+  %i = phi i2 [ 0, %entry ], [ %inc, %for.inc ]
+  %and = and i2 %a, %i
+  %and1 = and i2 %and, %i
+  switch i2 %and1, label %sw.default [
+    i2 0, label %sw.bb
+    i2 1, label %sw.bb1
+  ]
+
+sw.bb:
+  br label %sw.epilog
+
+sw.bb1:
+  br label %sw.epilog
+
+sw.default:
+  br label %sw.epilog
+
+sw.epilog:
+  br label %for.inc
+
+for.inc:
+  %inc = add nsw i2 %i, 1
+  %cmp = icmp slt i2 %inc, 3 
+  br i1 %cmp, label %for.body, label %for.end
+
+for.end:
+  ret void
+}
+
+; Make sure we unswitch %a == ~0 (i.e. -1) out of the loop.
+;
+; CHECK: define void @or_i2_as_switch_input(i2
+; CHECK: entry:
+; This is an indication that the loop has been unswitched.
+; CHECK: icmp eq i2 %a, -1
+; CHECK: br
+; There should be no more unswitching after the 1st unswitch.
+; CHECK-NOT: icmp eq
+; CHECK: ret
+define void @or_i2_as_switch_input(i2 %a) {
+entry:
+  br label %for.body
+
+for.body:
+  %i = phi i2 [ 0, %entry ], [ %inc, %for.inc ]
+  %or = or i2 %a, %i
+  %or1 = or i2 %or, %i
+  switch i2 %or1, label %sw.default [
+    i2 2, label %sw.bb
+    i2 3, label %sw.bb1
+  ]
+
+sw.bb:
+  br label %sw.epilog
+
+sw.bb1:
+  br label %sw.epilog
+
+sw.default:
+  br label %sw.epilog
+
+sw.epilog:
+  br label %for.inc
+
+for.inc:
+  %inc = add nsw i2 %i, 1
+  %cmp = icmp slt i2 %inc, 3 
+  br i1 %cmp, label %for.body, label %for.end
+
+for.end:
+  ret void
+}
+
+; Make sure we unswitch %a == ~0 (i.e. -1) out of the loop, even though we do
+; not have it as a case value. Unswitching it out allows us to simplify
+; the OR operator chain.
+;
+; CHECK: define void @or_i2_as_switch_input_unswitch_default(i2
+; CHECK: entry:
+; This is an indication that the loop has been unswitched.
+; CHECK: icmp eq i2 %a, -1
+; CHECK: br
+; There should be no more unswitching after the 1st unswitch.
+; CHECK-NOT: icmp eq
+; CHECK: ret
+define void @or_i2_as_switch_input_unswitch_default(i2 %a) {
+entry:
+  br label %for.body
+
+for.body:
+  %i = phi i2 [ 0, %entry ], [ %inc, %for.inc ]
+  %or = or i2 %a, %i
+  %or1 = or i2 %or, %i
+  switch i2 %or1, label %sw.default [
+    i2 1, label %sw.bb
+    i2 2, label %sw.bb1
+  ]
+
+sw.bb:
+  br label %sw.epilog
+
+sw.bb1:
+  br label %sw.epilog
+
+sw.default:
+  br label %sw.epilog
+
+sw.epilog:
+  br label %for.inc
+
+for.inc:
+  %inc = add nsw i2 %i, 1
+  %cmp = icmp slt i2 %inc, 3 
+  br i1 %cmp, label %for.body, label %for.end
+
+for.end:
+  ret void
+}
+
+; Make sure we don't unswitch, as we can not find an input value %a
+; that will effectively unswitch 0 or 3 out of the loop.
+;
+; CHECK: define void @and_or_i2_as_switch_input(i2
+; CHECK: entry:
+; This is an indication that the loop has NOT been unswitched.
+; CHECK-NOT: icmp
+; CHECK: br
+define void @and_or_i2_as_switch_input(i2 %a) {
+entry:
+  br label %for.body
+
+for.body:
+  %i = phi i2 [ 0, %entry ], [ %inc, %for.inc ]
+  %and = and i2 %a, %i 
+  %or = or i2 %and, %i
+  switch i2 %or, label %sw.default [
+    i2 0, label %sw.bb
+    i2 3, label %sw.bb1
+  ]
+
+sw.bb:
+  br label %sw.epilog
+
+sw.bb1:
+  br label %sw.epilog
+
+sw.default:
+  br label %sw.epilog
+
+sw.epilog:
+  br label %for.inc
+
+for.inc:
+  %inc = add nsw i2 %i, 1
+  %cmp = icmp slt i2 %inc, 3 
+  br i1 %cmp, label %for.body, label %for.end
+
+for.end:
+  ret void
+}
+
+; Make sure we don't unswitch, as we can not find an input value %a
+; that will effectively unswitch true/false out of the loop.
+;
+; CHECK: define void @and_or_i1_as_branch_input(i1
+; CHECK: entry:
+; This is an indication that the loop has NOT been unswitched.
+; CHECK-NOT: icmp
+; CHECK: br
+define void @and_or_i1_as_branch_input(i1 %a) {
+entry:
+  br label %for.body
+
+for.body:
+  %i = phi i1 [ 0, %entry ], [ %inc, %for.inc ]
+  %and = and i1 %a, %i 
+  %or = or i1 %and, %i
+  br i1 %or, label %sw.bb, label %sw.bb1
+
+sw.bb:
+  br label %sw.epilog
+
+sw.bb1:
+  br label %sw.epilog
+
+sw.epilog:
+  br label %for.inc
+
+for.inc:
+  %inc = add nsw i1 %i, 1
+  %cmp = icmp slt i1 %inc, 1 
+  br i1 %cmp, label %for.body, label %for.end
+
+for.end:
+  ret void
+}
 
 declare void @incf() noreturn
 declare void @decf() noreturn




More information about the llvm-commits mailing list