[llvm] 99166fd - [SCEVExpander] Add option to preserve LCSSA directly.

Florian Hahn via llvm-commits llvm-commits at lists.llvm.org
Wed Jul 29 07:08:15 PDT 2020


Author: Florian Hahn
Date: 2020-07-29T15:07:37+01:00
New Revision: 99166fd4fb422351f131fb1265cb85d5f6c5b8da

URL: https://github.com/llvm/llvm-project/commit/99166fd4fb422351f131fb1265cb85d5f6c5b8da
DIFF: https://github.com/llvm/llvm-project/commit/99166fd4fb422351f131fb1265cb85d5f6c5b8da.diff

LOG: [SCEVExpander] Add option to preserve LCSSA directly.

This patch teaches SCEVExpander to directly preserve LCSSA.

Currently, SCEV does not look through PHI nodes in loops, because
doing so might break LCSSA form. Once SCEVExpander can preserve
LCSSA form, it should be safe for SCEV to look through PHIs.

To preserve LCSSA form, this patch uses formLCSSAForInstructions
on the operands of newly created instructions whenever an operand is
defined inside a loop that does not contain the new instruction.

The final value returned from expandCodeFor may also need LCSSA
phis, depending on the insert point. As no user for it exists there yet,
a temporary instruction is created at the insert point and passed
to formLCSSAForInstructions. This temporary instruction is removed
after LCSSA construction.

Reviewed By: mkazantsev

Differential Revision: https://reviews.llvm.org/D71538

Added: 
    

Modified: 
    llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h
    llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
    llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
    llvm/unittests/Transforms/Utils/ScalarEvolutionExpanderTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h b/llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h
index 6e53d15d25fc7..cb212b2705eed 100644
--- a/llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h
+++ b/llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h
@@ -52,6 +52,9 @@ class SCEVExpander : public SCEVVisitor<SCEVExpander, Value *> {
   // New instructions receive a name to identify them with the current pass.
   const char *IVName;
 
+  /// Indicates whether LCSSA phis should be created for inserted values.
+  bool PreserveLCSSA;
+
   // InsertedExpressions caches Values for reuse, so must track RAUW.
   DenseMap<std::pair<const SCEV *, Instruction *>, TrackingVH<Value>>
       InsertedExpressions;
@@ -146,9 +149,10 @@ class SCEVExpander : public SCEVVisitor<SCEVExpander, Value *> {
 public:
   /// Construct a SCEVExpander in "canonical" mode.
   explicit SCEVExpander(ScalarEvolution &se, const DataLayout &DL,
-                        const char *name)
-      : SE(se), DL(DL), IVName(name), IVIncInsertLoop(nullptr),
-        IVIncInsertPos(nullptr), CanonicalMode(true), LSRMode(false),
+                        const char *name, bool PreserveLCSSA = true)
+      : SE(se), DL(DL), IVName(name), PreserveLCSSA(PreserveLCSSA),
+        IVIncInsertLoop(nullptr), IVIncInsertPos(nullptr), CanonicalMode(true),
+        LSRMode(false),
         Builder(se.getContext(), TargetFolder(DL),
                 IRBuilderCallbackInserter(
                     [this](Instruction *I) { rememberInstruction(I); })) {
@@ -223,14 +227,18 @@ class SCEVExpander : public SCEVVisitor<SCEVExpander, Value *> {
                                const TargetTransformInfo *TTI = nullptr);
 
   /// Insert code to directly compute the specified SCEV expression into the
-  /// program.  The inserted code is inserted into the specified block.
-  Value *expandCodeFor(const SCEV *SH, Type *Ty, Instruction *I);
+  /// program.  The code is inserted into the specified block.
+  Value *expandCodeFor(const SCEV *SH, Type *Ty, Instruction *I) {
+    return expandCodeForImpl(SH, Ty, I, true);
+  }
 
   /// Insert code to directly compute the specified SCEV expression into the
-  /// program.  The inserted code is inserted into the SCEVExpander's current
+  /// program.  The code is inserted into the SCEVExpander's current
   /// insertion point. If a type is specified, the result will be expanded to
   /// have that type, with a cast if necessary.
-  Value *expandCodeFor(const SCEV *SH, Type *Ty = nullptr);
+  Value *expandCodeFor(const SCEV *SH, Type *Ty = nullptr) {
+    return expandCodeForImpl(SH, Ty, true);
+  }
 
   /// Generates a code sequence that evaluates this predicate.  The inserted
   /// instructions will be at position \p Loc.  The result will be of type i1
@@ -338,6 +346,20 @@ class SCEVExpander : public SCEVVisitor<SCEVExpander, Value *> {
 private:
   LLVMContext &getContext() const { return SE.getContext(); }
 
+  /// Insert code to directly compute the specified SCEV expression into the
+  /// program. The code is inserted into the SCEVExpander's current
+  /// insertion point. If a type is specified, the result will be expanded to
+  /// have that type, with a cast if necessary. If \p Root is true, this
+  /// indicates that \p SH is the top-level expression to expand passed from
+  /// an external client call.
+  Value *expandCodeForImpl(const SCEV *SH, Type *Ty, bool Root);
+
+  /// Insert code to directly compute the specified SCEV expression into the
+  /// program. The code is inserted into the specified block. If \p
+  /// Root is true, this indicates that \p SH is the top-level expression to
+  /// expand passed from an external client call.
+  Value *expandCodeForImpl(const SCEV *SH, Type *Ty, Instruction *I, bool Root);
+
   /// Recursive helper function for isHighCostExpansion.
   bool isHighCostExpansionHelper(const SCEV *S, Loop *L, const Instruction &At,
                                  int &BudgetRemaining,
@@ -419,6 +441,11 @@ class SCEVExpander : public SCEVVisitor<SCEVExpander, Value *> {
                       Instruction *Pos, PHINode *LoopPhi);
 
   void fixupInsertPoints(Instruction *I);
+
+  /// If required, create LCSSA PHIs for \p User's operand \p OpIdx. If new
+  /// LCSSA PHIs have been created, return the LCSSA PHI available at \p User.
+  /// If no PHIs have been created, return the unchanged operand \p OpIdx.
+  Value *fixupLCSSAFormFor(Instruction *User, unsigned OpIdx);
 };
 } // namespace llvm
 

diff  --git a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
index cf02ef1e83f3f..c3e46c1fadef3 100644
--- a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
@@ -5514,8 +5514,8 @@ void LSRInstance::ImplementSolution(
   // we can remove them after we are done working.
   SmallVector<WeakTrackingVH, 16> DeadInsts;
 
-  SCEVExpander Rewriter(SE, L->getHeader()->getModule()->getDataLayout(),
-                        "lsr");
+  SCEVExpander Rewriter(SE, L->getHeader()->getModule()->getDataLayout(), "lsr",
+                        false);
 #ifndef NDEBUG
   Rewriter.setDebugType(DEBUG_TYPE);
 #endif
@@ -5780,7 +5780,7 @@ static bool ReduceLoopStrength(Loop *L, IVUsers &IU, ScalarEvolution &SE,
   if (EnablePhiElim && L->isLoopSimplifyForm()) {
     SmallVector<WeakTrackingVH, 16> DeadInsts;
     const DataLayout &DL = L->getHeader()->getModule()->getDataLayout();
-    SCEVExpander Rewriter(SE, DL, "lsr");
+    SCEVExpander Rewriter(SE, DL, "lsr", false);
 #ifndef NDEBUG
     Rewriter.setDebugType(DEBUG_TYPE);
 #endif

diff  --git a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
index c99d612a7768f..1a10e580c68c0 100644
--- a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
+++ b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
@@ -27,6 +27,7 @@
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"
+#include "llvm/Transforms/Utils/LoopUtils.h"
 
 using namespace llvm;
 
@@ -461,9 +462,10 @@ Value *SCEVExpander::expandAddToGEP(const SCEV *const *op_begin,
     // we didn't find any operands that could be factored, tentatively
     // assume that element zero was selected (since the zero offset
     // would obviously be folded away).
-    Value *Scaled = ScaledOps.empty() ?
-                    Constant::getNullValue(Ty) :
-                    expandCodeFor(SE.getAddExpr(ScaledOps), Ty);
+    Value *Scaled =
+        ScaledOps.empty()
+            ? Constant::getNullValue(Ty)
+            : expandCodeForImpl(SE.getAddExpr(ScaledOps), Ty, false);
     GepIndices.push_back(Scaled);
 
     // Collect struct field index operands.
@@ -522,7 +524,7 @@ Value *SCEVExpander::expandAddToGEP(const SCEV *const *op_begin,
            SE.DT.dominates(cast<Instruction>(V), &*Builder.GetInsertPoint()));
 
     // Expand the operands for a plain byte offset.
-    Value *Idx = expandCodeFor(SE.getAddExpr(Ops), Ty);
+    Value *Idx = expandCodeForImpl(SE.getAddExpr(Ops), Ty, false);
 
     // Fold a GEP with constant operands.
     if (Constant *CLHS = dyn_cast<Constant>(V))
@@ -743,14 +745,14 @@ Value *SCEVExpander::visitAddExpr(const SCEVAddExpr *S) {
       Sum = expandAddToGEP(NewOps.begin(), NewOps.end(), PTy, Ty, expand(Op));
     } else if (Op->isNonConstantNegative()) {
       // Instead of doing a negate and add, just do a subtract.
-      Value *W = expandCodeFor(SE.getNegativeSCEV(Op), Ty);
+      Value *W = expandCodeForImpl(SE.getNegativeSCEV(Op), Ty, false);
       Sum = InsertNoopCastOfTo(Sum, Ty);
       Sum = InsertBinop(Instruction::Sub, Sum, W, SCEV::FlagAnyWrap,
                         /*IsSafeToHoist*/ true);
       ++I;
     } else {
       // A simple add.
-      Value *W = expandCodeFor(Op, Ty);
+      Value *W = expandCodeForImpl(Op, Ty, false);
       Sum = InsertNoopCastOfTo(Sum, Ty);
       // Canonicalize a constant to the RHS.
       if (isa<Constant>(Sum)) std::swap(Sum, W);
@@ -802,7 +804,7 @@ Value *SCEVExpander::visitMulExpr(const SCEVMulExpr *S) {
 
     // Calculate powers with exponents 1, 2, 4, 8 etc. and include those of them
     // that are needed into the result.
-    Value *P = expandCodeFor(I->second, Ty);
+    Value *P = expandCodeForImpl(I->second, Ty, false);
     Value *Result = nullptr;
     if (Exponent & 1)
       Result = P;
@@ -861,7 +863,7 @@ Value *SCEVExpander::visitMulExpr(const SCEVMulExpr *S) {
 Value *SCEVExpander::visitUDivExpr(const SCEVUDivExpr *S) {
   Type *Ty = SE.getEffectiveSCEVType(S->getType());
 
-  Value *LHS = expandCodeFor(S->getLHS(), Ty);
+  Value *LHS = expandCodeForImpl(S->getLHS(), Ty, false);
   if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(S->getRHS())) {
     const APInt &RHS = SC->getAPInt();
     if (RHS.isPowerOf2())
@@ -870,7 +872,7 @@ Value *SCEVExpander::visitUDivExpr(const SCEVUDivExpr *S) {
                          SCEV::FlagAnyWrap, /*IsSafeToHoist*/ true);
   }
 
-  Value *RHS = expandCodeFor(S->getRHS(), Ty);
+  Value *RHS = expandCodeForImpl(S->getRHS(), Ty, false);
   return InsertBinop(Instruction::UDiv, LHS, RHS, SCEV::FlagAnyWrap,
                      /*IsSafeToHoist*/ SE.isKnownNonZero(S->getRHS()));
 }
@@ -1265,8 +1267,9 @@ SCEVExpander::getAddRecExprPHILiterally(const SCEVAddRecExpr *Normalized,
   // Expand code for the start value into the loop preheader.
   assert(L->getLoopPreheader() &&
          "Can't expand add recurrences without a loop preheader!");
-  Value *StartV = expandCodeFor(Normalized->getStart(), ExpandTy,
-                                L->getLoopPreheader()->getTerminator());
+  Value *StartV =
+      expandCodeForImpl(Normalized->getStart(), ExpandTy,
+                        L->getLoopPreheader()->getTerminator(), false);
 
   // StartV must have been be inserted into L's preheader to dominate the new
   // phi.
@@ -1284,8 +1287,8 @@ SCEVExpander::getAddRecExprPHILiterally(const SCEVAddRecExpr *Normalized,
   if (useSubtract)
     Step = SE.getNegativeSCEV(Step);
   // Expand the step somewhere that dominates the loop header.
-  Value *StepV = expandCodeFor(Step, IntTy,
-                               &*L->getHeader()->getFirstInsertionPt());
+  Value *StepV = expandCodeForImpl(
+      Step, IntTy, &*L->getHeader()->getFirstInsertionPt(), false);
 
   // The no-wrap behavior proved by IsIncrement(NUW|NSW) is only applicable if
   // we actually do emit an addition.  It does not apply if we emit a
@@ -1430,8 +1433,8 @@ Value *SCEVExpander::expandAddRecExprLiterally(const SCEVAddRecExpr *S) {
       {
         // Expand the step somewhere that dominates the loop header.
         SCEVInsertPointGuard Guard(Builder, this);
-        StepV = expandCodeFor(Step, IntTy,
-                              &*L->getHeader()->getFirstInsertionPt());
+        StepV = expandCodeForImpl(
+            Step, IntTy, &*L->getHeader()->getFirstInsertionPt(), false);
       }
       Result = expandIVInc(PN, StepV, L, ExpandTy, IntTy, useSubtract);
     }
@@ -1450,8 +1453,8 @@ Value *SCEVExpander::expandAddRecExprLiterally(const SCEVAddRecExpr *S) {
 
     // Invert the result.
     if (InvertStep)
-      Result = Builder.CreateSub(expandCodeFor(Normalized->getStart(), TruncTy),
-                                 Result);
+      Result = Builder.CreateSub(
+          expandCodeForImpl(Normalized->getStart(), TruncTy, false), Result);
   }
 
   // Re-apply any non-loop-dominating scale.
@@ -1459,22 +1462,22 @@ Value *SCEVExpander::expandAddRecExprLiterally(const SCEVAddRecExpr *S) {
     assert(S->isAffine() && "Can't linearly scale non-affine recurrences.");
     Result = InsertNoopCastOfTo(Result, IntTy);
     Result = Builder.CreateMul(Result,
-                               expandCodeFor(PostLoopScale, IntTy));
+                               expandCodeForImpl(PostLoopScale, IntTy, false));
   }
 
   // Re-apply any non-loop-dominating offset.
   if (PostLoopOffset) {
     if (PointerType *PTy = dyn_cast<PointerType>(ExpandTy)) {
       if (Result->getType()->isIntegerTy()) {
-        Value *Base = expandCodeFor(PostLoopOffset, ExpandTy);
+        Value *Base = expandCodeForImpl(PostLoopOffset, ExpandTy, false);
         Result = expandAddToGEP(SE.getUnknown(Result), PTy, IntTy, Base);
       } else {
         Result = expandAddToGEP(PostLoopOffset, PTy, IntTy, Result);
       }
     } else {
       Result = InsertNoopCastOfTo(Result, IntTy);
-      Result = Builder.CreateAdd(Result,
-                                 expandCodeFor(PostLoopOffset, IntTy));
+      Result = Builder.CreateAdd(
+          Result, expandCodeForImpl(PostLoopOffset, IntTy, false));
     }
   }
 
@@ -1516,8 +1519,8 @@ Value *SCEVExpander::visitAddRecExpr(const SCEVAddRecExpr *S) {
                                        S->getNoWrapFlags(SCEV::FlagNW)));
     BasicBlock::iterator NewInsertPt =
         findInsertPointAfter(cast<Instruction>(V), Builder.GetInsertBlock());
-    V = expandCodeFor(SE.getTruncateExpr(SE.getUnknown(V), Ty), nullptr,
-                      &*NewInsertPt);
+    V = expandCodeForImpl(SE.getTruncateExpr(SE.getUnknown(V), Ty), nullptr,
+                          &*NewInsertPt, false);
     return V;
   }
 
@@ -1632,22 +1635,25 @@ Value *SCEVExpander::visitAddRecExpr(const SCEVAddRecExpr *S) {
 
 Value *SCEVExpander::visitTruncateExpr(const SCEVTruncateExpr *S) {
   Type *Ty = SE.getEffectiveSCEVType(S->getType());
-  Value *V = expandCodeFor(S->getOperand(),
-                           SE.getEffectiveSCEVType(S->getOperand()->getType()));
+  Value *V = expandCodeForImpl(
+      S->getOperand(), SE.getEffectiveSCEVType(S->getOperand()->getType()),
+      false);
   return Builder.CreateTrunc(V, Ty);
 }
 
 Value *SCEVExpander::visitZeroExtendExpr(const SCEVZeroExtendExpr *S) {
   Type *Ty = SE.getEffectiveSCEVType(S->getType());
-  Value *V = expandCodeFor(S->getOperand(),
-                           SE.getEffectiveSCEVType(S->getOperand()->getType()));
+  Value *V = expandCodeForImpl(
+      S->getOperand(), SE.getEffectiveSCEVType(S->getOperand()->getType()),
+      false);
   return Builder.CreateZExt(V, Ty);
 }
 
 Value *SCEVExpander::visitSignExtendExpr(const SCEVSignExtendExpr *S) {
   Type *Ty = SE.getEffectiveSCEVType(S->getType());
-  Value *V = expandCodeFor(S->getOperand(),
-                           SE.getEffectiveSCEVType(S->getOperand()->getType()));
+  Value *V = expandCodeForImpl(
+      S->getOperand(), SE.getEffectiveSCEVType(S->getOperand()->getType()),
+      false);
   return Builder.CreateSExt(V, Ty);
 }
 
@@ -1662,7 +1668,7 @@ Value *SCEVExpander::visitSMaxExpr(const SCEVSMaxExpr *S) {
       Ty = SE.getEffectiveSCEVType(Ty);
       LHS = InsertNoopCastOfTo(LHS, Ty);
     }
-    Value *RHS = expandCodeFor(S->getOperand(i), Ty);
+    Value *RHS = expandCodeForImpl(S->getOperand(i), Ty, false);
     Value *ICmp = Builder.CreateICmpSGT(LHS, RHS);
     Value *Sel = Builder.CreateSelect(ICmp, LHS, RHS, "smax");
     LHS = Sel;
@@ -1685,7 +1691,7 @@ Value *SCEVExpander::visitUMaxExpr(const SCEVUMaxExpr *S) {
       Ty = SE.getEffectiveSCEVType(Ty);
       LHS = InsertNoopCastOfTo(LHS, Ty);
     }
-    Value *RHS = expandCodeFor(S->getOperand(i), Ty);
+    Value *RHS = expandCodeForImpl(S->getOperand(i), Ty, false);
     Value *ICmp = Builder.CreateICmpUGT(LHS, RHS);
     Value *Sel = Builder.CreateSelect(ICmp, LHS, RHS, "umax");
     LHS = Sel;
@@ -1708,7 +1714,7 @@ Value *SCEVExpander::visitSMinExpr(const SCEVSMinExpr *S) {
       Ty = SE.getEffectiveSCEVType(Ty);
       LHS = InsertNoopCastOfTo(LHS, Ty);
     }
-    Value *RHS = expandCodeFor(S->getOperand(i), Ty);
+    Value *RHS = expandCodeForImpl(S->getOperand(i), Ty, false);
     Value *ICmp = Builder.CreateICmpSLT(LHS, RHS);
     Value *Sel = Builder.CreateSelect(ICmp, LHS, RHS, "smin");
     LHS = Sel;
@@ -1731,7 +1737,7 @@ Value *SCEVExpander::visitUMinExpr(const SCEVUMinExpr *S) {
       Ty = SE.getEffectiveSCEVType(Ty);
       LHS = InsertNoopCastOfTo(LHS, Ty);
     }
-    Value *RHS = expandCodeFor(S->getOperand(i), Ty);
+    Value *RHS = expandCodeForImpl(S->getOperand(i), Ty, false);
     Value *ICmp = Builder.CreateICmpULT(LHS, RHS);
     Value *Sel = Builder.CreateSelect(ICmp, LHS, RHS, "umin");
     LHS = Sel;
@@ -1743,15 +1749,43 @@ Value *SCEVExpander::visitUMinExpr(const SCEVUMinExpr *S) {
   return LHS;
 }
 
-Value *SCEVExpander::expandCodeFor(const SCEV *SH, Type *Ty,
-                                   Instruction *IP) {
+Value *SCEVExpander::expandCodeForImpl(const SCEV *SH, Type *Ty,
+                                       Instruction *IP, bool Root) {
   setInsertPoint(IP);
-  return expandCodeFor(SH, Ty);
+  Value *V = expandCodeForImpl(SH, Ty, Root);
+  return V;
 }
 
-Value *SCEVExpander::expandCodeFor(const SCEV *SH, Type *Ty) {
+Value *SCEVExpander::expandCodeForImpl(const SCEV *SH, Type *Ty, bool Root) {
   // Expand the code for this SCEV.
   Value *V = expand(SH);
+
+  if (PreserveLCSSA) {
+    if (auto *Inst = dyn_cast<Instruction>(V)) {
+      // Create a temporary instruction at the current insertion point, so we
+      // can hand it off to the helper to create LCSSA PHIs if required for the
+      // new use.
+      // FIXME: Ideally formLCSSAForInstructions (used in fixupLCSSAFormFor)
+      // would accept an insertion point and return an LCSSA phi for that
+      // insertion point, so there is no need to insert & remove the temporary
+      // instruction.
+      Instruction *Tmp;
+      if (Inst->getType()->isIntegerTy())
+        Tmp = cast<Instruction>(Builder.CreateAdd(Inst, Inst));
+      else {
+        assert(Inst->getType()->isPointerTy());
+        Tmp = cast<Instruction>(Builder.CreateGEP(Inst, Builder.getInt32(1)));
+      }
+      V = fixupLCSSAFormFor(Tmp, 0);
+
+      // Clean up temporary instruction.
+      InsertedValues.erase(Tmp);
+      InsertedPostIncValues.erase(Tmp);
+      Tmp->eraseFromParent();
+    }
+  }
+
+  InsertedExpressions[std::make_pair(SH, &*Builder.GetInsertPoint())] = V;
   if (Ty) {
     assert(SE.getTypeSizeInBits(Ty) == SE.getTypeSizeInBits(SH->getType()) &&
            "non-trivial casts should be done with the SCEVs directly!");
@@ -1891,10 +1925,28 @@ Value *SCEVExpander::expand(const SCEV *S) {
 }
 
 void SCEVExpander::rememberInstruction(Value *I) {
-  if (!PostIncLoops.empty())
-    InsertedPostIncValues.insert(I);
-  else
-    InsertedValues.insert(I);
+  auto DoInsert = [this](Value *V) {
+    if (!PostIncLoops.empty())
+      InsertedPostIncValues.insert(V);
+    else
+      InsertedValues.insert(V);
+  };
+  DoInsert(I);
+
+  if (!PreserveLCSSA)
+    return;
+
+  if (auto *Inst = dyn_cast<Instruction>(I)) {
+    // A new instruction has been added, which might introduce new uses outside
+    // a defining loop. Fix LCSSA form for each operand of the new instruction,
+    // if required.
+    for (unsigned OpIdx = 0, OpEnd = Inst->getNumOperands(); OpIdx != OpEnd;
+         OpIdx++) {
+      auto *V = fixupLCSSAFormFor(Inst, OpIdx);
+      if (V != I)
+        DoInsert(V);
+    }
+  }
 }
 
 /// getOrInsertCanonicalInductionVariable - This method returns the
@@ -1913,9 +1965,8 @@ SCEVExpander::getOrInsertCanonicalInductionVariable(const Loop *L,
 
   // Emit code for it.
   SCEVInsertPointGuard Guard(Builder, this);
-  PHINode *V =
-      cast<PHINode>(expandCodeFor(H, nullptr,
-                                  &*L->getHeader()->getFirstInsertionPt()));
+  PHINode *V = cast<PHINode>(expandCodeForImpl(
+      H, nullptr, &*L->getHeader()->getFirstInsertionPt(), false));
 
   return V;
 }
@@ -2315,8 +2366,10 @@ Value *SCEVExpander::expandCodeForPredicate(const SCEVPredicate *Pred,
 
 Value *SCEVExpander::expandEqualPredicate(const SCEVEqualPredicate *Pred,
                                           Instruction *IP) {
-  Value *Expr0 = expandCodeFor(Pred->getLHS(), Pred->getLHS()->getType(), IP);
-  Value *Expr1 = expandCodeFor(Pred->getRHS(), Pred->getRHS()->getType(), IP);
+  Value *Expr0 =
+      expandCodeForImpl(Pred->getLHS(), Pred->getLHS()->getType(), IP, false);
+  Value *Expr1 =
+      expandCodeForImpl(Pred->getRHS(), Pred->getRHS()->getType(), IP, false);
 
   Builder.SetInsertPoint(IP);
   auto *I = Builder.CreateICmpNE(Expr0, Expr1, "ident.check");
@@ -2348,15 +2401,16 @@ Value *SCEVExpander::generateOverflowCheck(const SCEVAddRecExpr *AR,
 
   IntegerType *CountTy = IntegerType::get(Loc->getContext(), SrcBits);
   Builder.SetInsertPoint(Loc);
-  Value *TripCountVal = expandCodeFor(ExitCount, CountTy, Loc);
+  Value *TripCountVal = expandCodeForImpl(ExitCount, CountTy, Loc, false);
 
   IntegerType *Ty =
       IntegerType::get(Loc->getContext(), SE.getTypeSizeInBits(ARTy));
   Type *ARExpandTy = DL.isNonIntegralPointerType(ARTy) ? ARTy : Ty;
 
-  Value *StepValue = expandCodeFor(Step, Ty, Loc);
-  Value *NegStepValue = expandCodeFor(SE.getNegativeSCEV(Step), Ty, Loc);
-  Value *StartValue = expandCodeFor(Start, ARExpandTy, Loc);
+  Value *StepValue = expandCodeForImpl(Step, Ty, Loc, false);
+  Value *NegStepValue =
+      expandCodeForImpl(SE.getNegativeSCEV(Step), Ty, Loc, false);
+  Value *StartValue = expandCodeForImpl(Start, ARExpandTy, Loc, false);
 
   ConstantInt *Zero =
       ConstantInt::get(Loc->getContext(), APInt::getNullValue(DstBits));
@@ -2459,6 +2513,25 @@ Value *SCEVExpander::expandUnionPredicate(const SCEVUnionPredicate *Union,
   return Check;
 }
 
+Value *SCEVExpander::fixupLCSSAFormFor(Instruction *User, unsigned OpIdx) {
+  assert(PreserveLCSSA);
+  SmallVector<Instruction *, 1> ToUpdate;
+
+  auto *OpV = User->getOperand(OpIdx);
+  auto *OpI = dyn_cast<Instruction>(OpV);
+  if (!OpI)
+    return OpV;
+
+  Loop *DefLoop = SE.LI.getLoopFor(OpI->getParent());
+  Loop *UseLoop = SE.LI.getLoopFor(User->getParent());
+  if (!DefLoop || UseLoop == DefLoop || DefLoop->contains(UseLoop))
+    return OpV;
+
+  ToUpdate.push_back(OpI);
+  formLCSSAForInstructions(ToUpdate, SE.DT, SE.LI, &SE);
+  return User->getOperand(OpIdx);
+}
+
 namespace {
 // Search for a SCEV subexpression that is not safe to expand.  Any expression
 // that may expand to a !isSafeToSpeculativelyExecute value is unsafe, namely

diff  --git a/llvm/unittests/Transforms/Utils/ScalarEvolutionExpanderTest.cpp b/llvm/unittests/Transforms/Utils/ScalarEvolutionExpanderTest.cpp
index c146d155ffbea..6e83370dfeb95 100644
--- a/llvm/unittests/Transforms/Utils/ScalarEvolutionExpanderTest.cpp
+++ b/llvm/unittests/Transforms/Utils/ScalarEvolutionExpanderTest.cpp
@@ -265,7 +265,7 @@ TEST_F(ScalarEvolutionExpanderTest, SCEVExpanderIsSafeToExpandAt) {
   Phi->addIncoming(Add, L);
 
   Builder.SetInsertPoint(Post);
-  Builder.CreateRetVoid();
+  Instruction *Ret = Builder.CreateRetVoid();
 
   ScalarEvolution SE = buildSE(*F);
   const SCEV *S = SE.getSCEV(Phi);
@@ -276,6 +276,11 @@ TEST_F(ScalarEvolutionExpanderTest, SCEVExpanderIsSafeToExpandAt) {
   EXPECT_FALSE(isSafeToExpandAt(AR, LPh->getTerminator(), SE));
   EXPECT_TRUE(isSafeToExpandAt(AR, L->getTerminator(), SE));
   EXPECT_TRUE(isSafeToExpandAt(AR, Post->getTerminator(), SE));
+
+  EXPECT_TRUE(LI->getLoopFor(L)->isLCSSAForm(*DT));
+  SCEVExpander Exp(SE, M.getDataLayout(), "expander");
+  Exp.expandCodeFor(SE.getSCEV(Add), nullptr, Ret);
+  EXPECT_TRUE(LI->getLoopFor(L)->isLCSSAForm(*DT));
 }
 
 // Check that SCEV expander does not use the nuw instruction


        


More information about the llvm-commits mailing list