[llvm] b2df961 - [IndVarSimplify][LoopUtils] Avoid TOCTOU/ordering issues (PR45835)

Roman Lebedev via llvm-commits llvm-commits at lists.llvm.org
Thu May 21 03:08:33 PDT 2020


Author: Roman Lebedev
Date: 2020-05-21T13:05:55+03:00
New Revision: b2df96123198deadad74634c978e84912314da26

URL: https://github.com/llvm/llvm-project/commit/b2df96123198deadad74634c978e84912314da26
DIFF: https://github.com/llvm/llvm-project/commit/b2df96123198deadad74634c978e84912314da26.diff

LOG: [IndVarSimplify][LoopUtils] Avoid TOCTOU/ordering issues (PR45835)

Summary:
Currently, the logic of `rewriteLoopExitValues()` is roughly as follows:
> Loop over each incoming value in each PHI node.
> Query whether the SCEV for that incoming value is high-cost.
> Expand the SCEV.
> Perform a sanity check (`isValidRewrite()`, D51582).
> Record the info.
> Afterwards, see if we can drop the loop given the replacements.
> Maybe perform the replacements.

The problem is that we interleave SCEV cost checking and expansion.
This is A Problem, because `isHighCostExpansion()` takes special care
not to bill for expansions that have already been performed and that we can reuse.
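To make that discount concrete, here is a toy model in C++. It is a sketch under simplifying assumptions, not LLVM's `SCEVExpander`: `Expr`, `CostTable`, and `costOf()` are hypothetical names, an expression is reduced to the list of nodes it needs, and any node that was already expanded is assumed to be reusable for free.

```cpp
#include <cassert>
#include <map>
#include <set>
#include <string>
#include <vector>

// Hypothetical toy model, not LLVM's SCEVExpander: an expression is just
// the list of nodes that must be materialized to compute it.
using Expr = std::vector<std::string>;
using CostTable = std::map<std::string, unsigned>;

// Sum the cost of every node of E that has not been expanded yet.
// Nodes in Expanded are assumed to be reusable for free.
unsigned costOf(const Expr &E, const CostTable &Costs,
                const std::set<std::string> &Expanded) {
  unsigned C = 0;
  for (const std::string &N : E) {
    if (Expanded.count(N))
      continue; // Already expanded: not billed.
    assert(Costs.count(N) && "node missing from cost table");
    C += Costs.at(N);
  }
  return C;
}
```

In this model, an expression `A` that needs `x` (cost 3) and `add1` (cost 1) costs 4 on a fresh state but only 1 once `x` has been expanded, so whether `A` counts as high-cost depends on which expansions happened before it was queried.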

Not billing for reusable expansions makes sense in general: if we know that
we will expand some SCEV, the costs of all the other SCEVs should account for that,
which might make some of them non-high-cost too and cause a chain reaction.

But that isn't what we are doing here. We expand *all* SCEVs unconditionally,
so the cost of each subsequent SCEV is affected by the expansions already
performed for the previous SCEVs, even if we are not planning on keeping
some of those expansions.

Worse yet, this current "bonus" depends on the exact order in which
PHI node incoming values are processed. This is completely wrong.
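The order dependence can be reproduced in the same kind of toy model (again hypothetical, not LLVM code). `interleavedHighCost()` mimics the old loop: it queries each expression's cost and then expands that expression unconditionally, so every later query sees the earlier, possibly temporary, expansions.

```cpp
#include <cassert>
#include <map>
#include <set>
#include <string>
#include <vector>

// Hypothetical toy model: an expression is the list of nodes it needs.
using Expr = std::vector<std::string>;
using CostTable = std::map<std::string, unsigned>;

// Cost of E, not billing nodes that are already expanded.
unsigned costOf(const Expr &E, const CostTable &Costs,
                const std::set<std::string> &Expanded) {
  unsigned C = 0;
  for (const std::string &N : E) {
    if (Expanded.count(N))
      continue;
    assert(Costs.count(N) && "node missing from cost table");
    C += Costs.at(N);
  }
  return C;
}

// Old-style processing: query the cost, then expand right away, whether or
// not the expansion will be kept. Returns one high-cost flag per input.
std::vector<bool> interleavedHighCost(const std::vector<Expr> &Exprs,
                                      const CostTable &Costs,
                                      unsigned Budget) {
  std::set<std::string> Expanded;
  std::vector<bool> High;
  for (const Expr &E : Exprs) {
    High.push_back(costOf(E, Costs, Expanded) > Budget);
    Expanded.insert(E.begin(), E.end()); // Expanded unconditionally.
  }
  return High;
}
```

With nodes `x` (cost 3), `add1` and `add2` (cost 1 each), expressions `A = {x, add1}` and `B = {x, add2}`, and a budget of 2, whichever expression is processed first is high-cost and the second one is cheap, so the classification of `A` flips with the processing order even though `A` itself did not change.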

As an example of the issue, see @dmajor's `pr45835.ll`: if we happen to have
a PHI node with two(!) identical high-cost incoming values for the same basic block,
we would decide the first time around that it is high-cost and expand it,
then immediately decide that it is not high-cost, because we have an expansion
that we could reuse (because we had just expanded it, temporarily),
and replace the second incoming value but not the first one,
thus resulting in a broken PHI.
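The failure can be sketched the same way. This is a hypothetical model, not the IR from `pr45835.ll`: a PHI has two incoming values from the same block that are the same expression, and only values judged cheap are replaced, loosely in the spirit of `-replexitval=cheap` with a small budget.

```cpp
#include <cassert>
#include <map>
#include <set>
#include <string>
#include <vector>

// Hypothetical toy model: an expression is the list of nodes it needs.
using Expr = std::vector<std::string>;
using CostTable = std::map<std::string, unsigned>;

// Cost of E, not billing nodes that are already expanded.
unsigned costOf(const Expr &E, const CostTable &Costs,
                const std::set<std::string> &Expanded) {
  unsigned C = 0;
  for (const std::string &N : E) {
    if (Expanded.count(N))
      continue;
    assert(Costs.count(N) && "node missing from cost table");
    C += Costs.at(N);
  }
  return C;
}

// Hypothetical "cheap only" rewrite of one PHI's incoming values using the
// old interleaved order. Returns, per incoming value, whether it is replaced.
std::vector<bool> rewriteCheapInterleaved(const std::vector<Expr> &Incoming,
                                          const CostTable &Costs,
                                          unsigned Budget) {
  std::set<std::string> Expanded;
  std::vector<bool> Replaced;
  for (const Expr &E : Incoming) {
    bool HighCost = costOf(E, Costs, Expanded) > Budget;
    Expanded.insert(E.begin(), E.end()); // Temporary expansion, reused below.
    Replaced.push_back(!HighCost);
  }
  return Replaced;
}
```

With a single node `x` of cost 3 and a budget of 1, the first copy is high-cost and kept, while the second reuses the temporary expansion, looks free, and is replaced, so the two incoming values for the same block end up different.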

What we should do instead, for now, is not perform any expansions
until after we have queried all the costs.
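In the toy model, this fix amounts to splitting the loop into two phases: bill every expression against the same, untouched state, and only then act on the results. The names are hypothetical; in the patch itself, the first phase records a `RewritePhi` entry per incoming value and the second phase calls `expandCodeFor()` and `isValidRewrite()`, whereas this sketch's second phase merely decides replacement from the phase-1 flag.

```cpp
#include <cassert>
#include <map>
#include <set>
#include <string>
#include <vector>

// Hypothetical toy model: an expression is the list of nodes it needs.
using Expr = std::vector<std::string>;
using CostTable = std::map<std::string, unsigned>;

// Cost of E, not billing nodes that are already expanded.
unsigned costOf(const Expr &E, const CostTable &Costs,
                const std::set<std::string> &Expanded) {
  unsigned C = 0;
  for (const std::string &N : E) {
    if (Expanded.count(N))
      continue;
    assert(Costs.count(N) && "node missing from cost table");
    C += Costs.at(N);
  }
  return C;
}

struct Candidate {
  Expr E;        // Expression for one incoming value.
  bool HighCost; // Decided in phase 1, before any expansion.
  bool Replaced; // Decided in phase 2.
};

std::vector<Candidate> rewriteCheapTwoPhase(const std::vector<Expr> &Incoming,
                                            const CostTable &Costs,
                                            unsigned Budget) {
  std::vector<Candidate> Cands;
  const std::set<std::string> NothingExpanded{};
  // Phase 1: query every cost against the same, untouched state.
  for (const Expr &E : Incoming)
    Cands.push_back({E, costOf(E, Costs, NothingExpanded) > Budget, false});
  // Phase 2: only now act on the decisions (here: replace cheap ones).
  for (Candidate &C : Cands)
    C.Replaced = !C.HighCost;
  return Cands;
}
```

Identical incoming values now always receive identical decisions, and an expression's high-cost flag no longer depends on where it appears in the input.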

Later, in particular once `isValidRewrite()` is an assertion (D51582),
we could improve upon this, but in a more coherent fashion.

See [[ https://bugs.llvm.org/show_bug.cgi?id=45835 | PR45835 ]]

Reviewers: dmajor, reames, mkazantsev, fhahn, efriedma

Reviewed By: dmajor, mkazantsev

Subscribers: smeenai, nikic, hiraditya, javed.absar, llvm-commits, dmajor

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D79787

Added: 
    llvm/test/Transforms/IndVarSimplify/pr45835.ll

Modified: 
    llvm/lib/Transforms/Utils/LoopUtils.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/Utils/LoopUtils.cpp b/llvm/lib/Transforms/Utils/LoopUtils.cpp
index 972a7fc8536c..8c7475eae6e3 100644
--- a/llvm/lib/Transforms/Utils/LoopUtils.cpp
+++ b/llvm/lib/Transforms/Utils/LoopUtils.cpp
@@ -1216,13 +1216,19 @@ static bool hasHardUserWithinLoop(const Loop *L, const Instruction *I) {
 // Collect information about PHI nodes which can be transformed in
 // rewriteLoopExitValues.
 struct RewritePhi {
-  PHINode *PN;
-  unsigned Ith;   // Ith incoming value.
-  Value *Val;     // Exit value after expansion.
-  bool HighCost;  // High Cost when expansion.
-
-  RewritePhi(PHINode *P, unsigned I, Value *V, bool H)
-      : PN(P), Ith(I), Val(V), HighCost(H) {}
+  PHINode *PN;               // For which PHI node is this replacement?
+  unsigned Ith;              // For which incoming value?
+  const SCEV *ExpansionSCEV; // The SCEV of the incoming value we are rewriting.
+  Instruction *ExpansionPoint; // Where we'd like to expand that SCEV?
+  bool HighCost;               // Is this expansion high-cost?
+
+  Value *Expansion = nullptr;
+  bool ValidRewrite = false;
+
+  RewritePhi(PHINode *P, unsigned I, const SCEV *Val, Instruction *ExpansionPt,
+             bool H)
+      : PN(P), Ith(I), ExpansionSCEV(Val), ExpansionPoint(ExpansionPt),
+        HighCost(H) {}
 };
 
 // Check whether it is possible to delete the loop after rewriting exit
@@ -1255,6 +1261,8 @@ static bool canLoopBeDeleted(Loop *L, SmallVector<RewritePhi, 8> &RewritePhiSet)
     // phase later. Skip it in the loop invariant check below.
     bool found = false;
     for (const RewritePhi &Phi : RewritePhiSet) {
+      if (!Phi.ValidRewrite)
+        continue;
       unsigned i = Phi.Ith;
       if (Phi.PN == P && (Phi.PN)->getIncomingValue(i) == Incoming) {
         found = true;
@@ -1372,42 +1380,66 @@ int llvm::rewriteLoopExitValues(Loop *L, LoopInfo *LI, TargetLibraryInfo *TLI,
             !isa<SCEVUnknown>(ExitValue) && hasHardUserWithinLoop(L, Inst))
           continue;
 
+        // Check if expansions of this SCEV would count as being high cost.
         bool HighCost = Rewriter.isHighCostExpansion(
             ExitValue, L, SCEVCheapExpansionBudget, TTI, Inst);
-        Value *ExitVal = Rewriter.expandCodeFor(ExitValue, PN->getType(), Inst);
-
-        LLVM_DEBUG(dbgs() << "rewriteLoopExitValues: AfterLoopVal = "
-                          << *ExitVal << '\n' << "  LoopVal = " << *Inst
-                          << "\n");
-
-        if (!isValidRewrite(SE, Inst, ExitVal)) {
-          DeadInsts.push_back(ExitVal);
-          continue;
-        }
 
-#ifndef NDEBUG
-        // If we reuse an instruction from a loop which is neither L nor one of
-        // its containing loops, we end up breaking LCSSA form for this loop by
-        // creating a new use of its instruction.
-        if (auto *ExitInsn = dyn_cast<Instruction>(ExitVal))
-          if (auto *EVL = LI->getLoopFor(ExitInsn->getParent()))
-            if (EVL != L)
-              assert(EVL->contains(L) && "LCSSA breach detected!");
-#endif
+        // Note that we must not perform expansions until after
+        // we query *all* the costs, because if we perform a temporary expansion
+        // in between, one that we might not intend to keep, said expansion
+        // *may* affect the cost calculation of the next SCEVs we query,
+        // and the next SCEV may erroneously get a smaller cost.
 
         // Collect all the candidate PHINodes to be rewritten.
-        RewritePhiSet.emplace_back(PN, i, ExitVal, HighCost);
+        RewritePhiSet.emplace_back(PN, i, ExitValue, Inst, HighCost);
       }
     }
   }
 
+  // Now that we've done preliminary filtering and billed all the SCEV's,
+  // we can perform the last sanity check - the expansion must be valid.
+  for (RewritePhi &Phi : RewritePhiSet) {
+    Phi.Expansion = Rewriter.expandCodeFor(Phi.ExpansionSCEV, Phi.PN->getType(),
+                                           Phi.ExpansionPoint);
+
+    LLVM_DEBUG(dbgs() << "rewriteLoopExitValues: AfterLoopVal = "
+                      << *(Phi.Expansion) << '\n'
+                      << "  LoopVal = " << *(Phi.ExpansionPoint) << "\n");
+
+    // FIXME: isValidRewrite() is a hack. It should be an assert, eventually.
+    Phi.ValidRewrite = isValidRewrite(SE, Phi.ExpansionPoint, Phi.Expansion);
+    if (!Phi.ValidRewrite) {
+      DeadInsts.push_back(Phi.Expansion);
+      continue;
+    }
+
+#ifndef NDEBUG
+    // If we reuse an instruction from a loop which is neither L nor one of
+    // its containing loops, we end up breaking LCSSA form for this loop by
+    // creating a new use of its instruction.
+    if (auto *ExitInsn = dyn_cast<Instruction>(Phi.Expansion))
+      if (auto *EVL = LI->getLoopFor(ExitInsn->getParent()))
+        if (EVL != L)
+          assert(EVL->contains(L) && "LCSSA breach detected!");
+#endif
+  }
+
+  // TODO: after isValidRewrite() is an assertion, evaluate whether
+  // it is beneficial to change how we calculate high-cost:
+  // if we have SCEV 'A' which we know we will expand, should we calculate
+  // the cost of other SCEV's after expanding SCEV 'A',
+  // thus potentially giving cost bonus to those other SCEV's?
+
   bool LoopCanBeDel = canLoopBeDeleted(L, RewritePhiSet);
   int NumReplaced = 0;
 
   // Transformation.
   for (const RewritePhi &Phi : RewritePhiSet) {
+    if (!Phi.ValidRewrite)
+      continue;
+
     PHINode *PN = Phi.PN;
-    Value *ExitVal = Phi.Val;
+    Value *ExitVal = Phi.Expansion;
 
     // Only do the rewrite when the ExitValue can be expanded cheaply.
     // If LoopCanBeDel is true, rewrite exit value aggressively.

diff  --git a/llvm/test/Transforms/IndVarSimplify/pr45835.ll b/llvm/test/Transforms/IndVarSimplify/pr45835.ll
new file mode 100644
index 000000000000..d5bab5ada626
--- /dev/null
+++ b/llvm/test/Transforms/IndVarSimplify/pr45835.ll
@@ -0,0 +1,38 @@
+; RUN: opt < %s -indvars -replexitval=always -S | FileCheck %s --check-prefix=ALWAYS
+; RUN: opt < %s -indvars -replexitval=never -S | FileCheck %s --check-prefix=NEVER
+; RUN: opt < %s -indvars -replexitval=cheap -scev-cheap-expansion-budget=1 -S | FileCheck %s --check-prefix=CHEAP
+
+; rewriteLoopExitValues() must rewrite all or none of a PHI's values from a given block.
+
+target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
+
+@a = common global i8 0, align 1
+
+define internal fastcc void @d(i8* %c) unnamed_addr #0 {
+entry:
+  %cmp = icmp ule i8* %c, getelementptr inbounds (i8, i8* @a, i64 65535)
+  %add.ptr = getelementptr inbounds i8, i8* %c, i64 -65535
+  br label %while.cond
+
+while.cond:
+  br i1 icmp ne (i8 0, i8 0), label %cont, label %while.end
+
+cont:
+  %a.mux = select i1 %cmp, i8* @a, i8* %add.ptr
+  switch i64 0, label %while.cond [
+    i64 -1, label %handler.pointer_overflow.i
+    i64 0, label %handler.pointer_overflow.i
+  ]
+
+handler.pointer_overflow.i:
+  %a.mux.lcssa4 = phi i8* [ %a.mux, %cont ], [ %a.mux, %cont ]
+; ALWAYS: [ %scevgep, %cont ], [ %scevgep, %cont ]
+; NEVER: [ %a.mux, %cont ], [ %a.mux, %cont ]
+; In cheap mode, use either one as long as it's consistent.
+; CHEAP: [ %[[VAL:.*]], %cont ], [ %[[VAL]], %cont ]
+  %x5 = ptrtoint i8* %a.mux.lcssa4 to i64
+  br label %while.end
+
+while.end:
+  ret void
+}
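The invariant stated in the test, that all or none of a PHI's values from a given block are rewritten, can be phrased as a small predicate. This sketch is hypothetical and does not use LLVM's `PHINode` API: a PHI is modeled as a list of (value, block) pairs, and the predicate checks that every block maps to a single value, which is what the `CHEAP` line's `[[VAL]]` capture enforces.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Hypothetical PHI model: (incoming value, incoming block) pairs.
using Incoming = std::pair<std::string, std::string>;

// True if, for each incoming block, all incoming values agree.
bool isConsistentPhi(const std::vector<Incoming> &Ops) {
  std::map<std::string, std::string> ValueForBlock;
  for (const Incoming &Op : Ops) {
    auto It = ValueForBlock.find(Op.second);
    if (It == ValueForBlock.end())
      ValueForBlock.emplace(Op.second, Op.first); // First value seen for block.
    else if (It->second != Op.first)
      return false; // Same block, different values: broken PHI.
  }
  return true;
}
```

Both the `ALWAYS` and `NEVER` outputs satisfy this predicate; a PHI that mixes `%a.mux` and `%scevgep` for `%cont` does not.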


        

