[llvm] [ConstraintElim] Try to use info from loop latch (PR #126118)

Florian Hahn via llvm-commits llvm-commits at lists.llvm.org
Thu Feb 6 11:34:11 PST 2025


https://github.com/fhahn created https://github.com/llvm/llvm-project/pull/126118

None

>From 87eeb375a25b5ccf1fc09eb0adb30f78f54b0efb Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Thu, 6 Feb 2025 19:21:33 +0000
Subject: [PATCH 1/2] [ConstraintElimination] Prepare: clamp check scopes at exiting latches.

---
 .../Scalar/ConstraintElimination.cpp          | 26 ++++++++++++++++---
 1 file changed, 22 insertions(+), 4 deletions(-)

diff --git a/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp b/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp
index 6dd26910f684694..684df6d1003c679 100644
--- a/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp
+++ b/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp
@@ -1083,6 +1083,7 @@ void State::addInfoForInductions(BasicBlock &BB) {
   }
 }
 
+
 void State::addInfoFor(BasicBlock &BB) {
   addInfoForInductions(BB);
 
@@ -1096,7 +1097,20 @@ void State::addInfoFor(BasicBlock &BB) {
         auto *DTN = DT.getNode(UserI->getParent());
         if (!DTN)
           continue;
         WorkList.push_back(FactOrCheck::getCheck(DTN, &U));
+        // If the check's block dominates the exiting latch of its loop, end
+        // the check's DFS range at the latch's DFSNumIn (presumably so facts
+        // whose range ends there stay in scope). Use UserI's block, like DTN,
+        // so the loop and the dominator-tree node refer to the same block.
+        Loop *L = LI.getLoopFor(UserI->getParent());
+        BasicBlock *Latch = L ? L->getLoopLatch() : nullptr;
+        if (Latch && L->isLoopExiting(Latch)) {
+          auto *DTNLatch = DT.getNode(Latch);
+          if (DTNLatch->getDFSNumIn() >= DTN->getDFSNumIn() &&
+              DTNLatch->getDFSNumOut() <= DTN->getDFSNumOut())
+            WorkList.back().NumOut =
+                std::min(WorkList.back().NumOut, DTNLatch->getDFSNumIn());
+        }
       }
       continue;
     }
@@ -1430,16 +1439,34 @@ static bool checkAndReplaceCondition(
     CmpInst *Cmp, ConstraintInfo &Info, unsigned NumIn, unsigned NumOut,
     Instruction *ContextInst, Module *ReproducerModule,
     ArrayRef<ReproducerEntry> ReproducerCondStack, DominatorTree &DT,
-    SmallVectorImpl<Instruction *> &ToRemove) {
+    SmallVectorImpl<Instruction *> &ToRemove, LoopInfo &LI) {
   auto ReplaceCmpWithConstant = [&](CmpInst *Cmp, bool IsTrue) {
     generateReproducer(Cmp, ReproducerModule, ReproducerCondStack, Info, DT);
     Constant *ConstantC = ConstantInt::getBool(
         CmpInst::makeCmpResultType(Cmp->getType()), IsTrue);
-    Cmp->replaceUsesWithIf(ConstantC, [&DT, NumIn, NumOut,
+    // Cmp's loop is the same for every use, so look it up only once.
+    Loop *CmpLoop = LI.getLoopFor(Cmp->getParent());
+    Cmp->replaceUsesWithIf(ConstantC, [&DT, NumIn, NumOut, CmpLoop, &LI,
                                        ContextInst](Use &U) {
       auto *UserI = getContextInstForUse(U);
       auto *DTN = DT.getNode(UserI->getParent());
-      if (!DTN || DTN->getDFSNumIn() < NumIn || DTN->getDFSNumOut() > NumOut)
+      // DT has no node for blocks it does not cover (e.g. unreachable ones);
+      // bail out before DTN is dereferenced below.
+      if (!DTN)
+        return false;
+      // For a use in Cmp's loop whose block dominates the loop's exiting
+      // latch, end the use's DFS range at the latch's DFSNumIn, mirroring how
+      // State::addInfoFor clamps the range of checks in that situation.
+      unsigned UserNumOut = DTN->getDFSNumOut();
+      BasicBlock *Latch = CmpLoop ? CmpLoop->getLoopLatch() : nullptr;
+      if (Latch && LI.getLoopFor(UserI->getParent()) == CmpLoop &&
+          CmpLoop->isLoopExiting(Latch)) {
+        auto *DTNLatch = DT.getNode(Latch);
+        if (DTNLatch->getDFSNumIn() >= DTN->getDFSNumIn() &&
+            DTNLatch->getDFSNumOut() <= DTN->getDFSNumOut())
+          UserNumOut = std::min(UserNumOut, DTNLatch->getDFSNumIn());
+      }
+      if (DTN->getDFSNumIn() < NumIn || UserNumOut > NumOut)
         return false;
       if (UserI->getParent() == ContextInst->getParent() &&
           UserI->comesBefore(ContextInst))
@@ -1814,7 +1832,7 @@ static bool eliminateConstraints(Function &F, DominatorTree &DT, LoopInfo &LI,
       } else if (auto *Cmp = dyn_cast<ICmpInst>(Inst)) {
         bool Simplified = checkAndReplaceCondition(
             Cmp, Info, CB.NumIn, CB.NumOut, CB.getContextInst(),
-            ReproducerModule.get(), ReproducerCondStack, S.DT, ToRemove);
+            ReproducerModule.get(), ReproducerCondStack, S.DT, ToRemove, LI);
         if (!Simplified &&
             match(CB.getContextInst(), m_LogicalOp(m_Value(), m_Value()))) {
           Simplified =

>From a2292a602398968ca61731695598e7658105f5e5 Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Tue, 4 Feb 2025 13:24:09 +0000
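Subject: [PATCH 2/2] [ConstraintElimination] Add induction facts from latches

Add State::addInfoForInductions2, which derives facts about a loop-header
induction PHI from an exiting-latch compare `icmp eq (add PN, Step), B`,
together with the latch-test.ll test. (Squashes the commits "Add", "Fix" and
"step".)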
---
 .../Scalar/ConstraintElimination.cpp          | 131 ++++++++++++++++++
 .../ConstraintElimination/latch-test.ll       |  94 +++++++++++++
 2 files changed, 225 insertions(+)
 create mode 100644 llvm/test/Transforms/ConstraintElimination/latch-test.ll

diff --git a/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp b/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp
index 684df6d1003c679..9f05b726747b4ac 100644
--- a/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp
+++ b/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp
@@ -138,12 +138,27 @@ struct FactOrCheck {
       : Cond(Pred, Op0, Op1), DoesHold(Precond), NumIn(DTN->getDFSNumIn()),
         NumOut(DTN->getDFSNumOut()), Ty(EntryTy::ConditionFact) {}
 
+  /// Construct a condition fact scoped to the explicit DFS range
+  /// [NumIn, NumOut], which need not match a dominator-tree subtree.
+  FactOrCheck(unsigned NumIn, unsigned NumOut, CmpPredicate Pred, Value *Op0,
+              Value *Op1, ConditionTy Precond = {})
+      : Cond(Pred, Op0, Op1), DoesHold(Precond), NumIn(NumIn), NumOut(NumOut),
+        Ty(EntryTy::ConditionFact) {}
+
   static FactOrCheck getConditionFact(DomTreeNode *DTN, CmpPredicate Pred,
                                       Value *Op0, Value *Op1,
                                       ConditionTy Precond = {}) {
     return FactOrCheck(DTN, Pred, Op0, Op1, Precond);
   }
 
+  /// Create a condition fact scoped to the explicit DFS range
+  /// [NumIn, NumOut] instead of to the subtree of a dominator-tree node.
+  static FactOrCheck getConditionFact(unsigned NumIn, unsigned NumOut,
+                                      CmpPredicate Pred, Value *Op0, Value *Op1,
+                                      ConditionTy Precond = {}) {
+    return FactOrCheck(NumIn, NumOut, Pred, Op0, Op1, Precond);
+  }
+
   static FactOrCheck getInstFact(DomTreeNode *DTN, Instruction *Inst) {
     return FactOrCheck(EntryTy::InstFact, DTN, Inst);
   }
@@ -194,6 +205,9 @@ struct State {
   /// Try to add facts for loop inductions (AddRecs) in EQ/NE compares
   /// controlling the loop header.
   void addInfoForInductions(BasicBlock &BB);
+  /// Try to add facts for a loop induction PHI from an EQ compare of its
+  /// increment that controls the exiting latch \p BB of its loop.
+  void addInfoForInductions2(BasicBlock &BB);
 
   /// Returns true if we can add a known condition from BB to its successor
   /// block Succ.
@@ -1083,9 +1095,148 @@ void State::addInfoForInductions(BasicBlock &BB) {
   }
 }
 
+/// Returns true if \p I or one of its transitive users is an ICmp inside
+/// loop \p L that is not in L's latch. Instructions outside \p L are neither
+/// counted nor expanded, loads and stores are not looked through, and the
+/// search conservatively gives up (returning false) once the worklist grows
+/// beyond 16 entries.
+static bool hasLoopCmpUser(Instruction *I, Loop *L) {
+  SetVector<Instruction *> WorkList;
+  WorkList.insert(I);
+
+  for (unsigned Idx = 0; Idx != WorkList.size(); ++Idx) {
+    Instruction *Curr = WorkList[Idx];
+    if (isa<LoadInst, StoreInst>(Curr))
+      continue;
+    if (!L->contains(Curr))
+      continue;
+    if (WorkList.size() > 16)
+      return false;
+    if (isa<ICmpInst>(Curr) && Curr->getParent() != L->getLoopLatch())
+      return true;
+    for (User *U : Curr->users())
+      WorkList.insert(cast<Instruction>(U));
+  }
+  return false;
+}
+
+/// If \p BB is the exiting latch of its loop L and ends in
+///   br (icmp eq (add PN, Step), B), Exit, ...
+/// with Exit outside L and PN an induction PHI in L's header with constant
+/// start and step, add PN >= Start and, when B is loop-invariant and is
+/// reached exactly by the increment, PN < B under the precondition
+/// Start + Step <= B, in each signedness for which SCEV proves PN
+/// monotonically increasing. All facts are scoped to the DFS range
+/// [header DFSNumIn, latch DFSNumIn] rather than to a dominator subtree.
+void State::addInfoForInductions2(BasicBlock &BB) {
+  Loop *L = LI.getLoopFor(&BB);
+  if (!L)
+    return;
+  BasicBlock *Latch = L->getLoopLatch();
+  if (&BB != Latch || !L->isLoopExiting(Latch) || Latch == L->getHeader())
+    return;
+
+  // Match the latch exit `br (icmp eq (add A, X), B), TrueSucc, _` with
+  // TrueSucc outside the loop.
+  CmpPredicate Pred;
+  Value *Inc;
+  Value *A;
+  Value *B;
+  BasicBlock *TrueSucc;
+  if (!match(Latch->getTerminator(),
+             m_Br(m_ICmp(Pred, m_Value(Inc), m_Value(B)),
+                  m_BasicBlock(TrueSucc), m_BasicBlock())) ||
+      Pred != CmpInst::ICMP_EQ || L->contains(TrueSucc) ||
+      !match(Inc, m_Add(m_Value(A), m_Value())))
+    return;
+
+  auto *PN = dyn_cast<PHINode>(A);
+  if (!PN || PN->getParent() != L->getHeader() ||
+      PN->getNumIncomingValues() != 2 || !hasLoopCmpUser(PN, L) ||
+      !SE.isSCEVable(PN->getType()))
+    return;
+
+  // Both PN and the increment must be AddRecs of L, and the increment must
+  // be exactly PN's post-increment AddRec ({Start + Step,+,Step}). Otherwise
+  // the add's other operand is not PN's step and the facts below, which
+  // reason about PN via that step, would be unsound.
+  auto *AR = dyn_cast<SCEVAddRecExpr>(SE.getSCEV(Inc));
+  auto *PNAR = dyn_cast<SCEVAddRecExpr>(SE.getSCEV(PN));
+  if (!AR || !PNAR || AR->getLoop() != L || PNAR->getLoop() != L ||
+      !L->getLoopPredecessor() || AR != PNAR->getPostIncExpr(SE))
+    return;
+
+  auto *StartC = dyn_cast<SCEVConstant>(PNAR->getStart());
+  auto *StepC = dyn_cast<SCEVConstant>(PNAR->getStepRecurrence(SE));
+  if (!StartC || !StepC)
+    return;
+  const SCEV *StartSCEV = StartC;
+  Value *StartValue = StartC->getValue();
+  const APInt &StepOffset = StepC->getAPInt();
+
+  // All facts below are scoped to [header DFSNumIn, latch DFSNumIn].
+  unsigned NumIn = DT.getNode(L->getHeader())->getDFSNumIn();
+  unsigned NumOut = DT.getNode(Latch)->getDFSNumIn();
+  auto IncUnsigned = SE.getMonotonicPredicateType(PNAR, CmpInst::ICMP_UGT);
+  auto IncSigned = SE.getMonotonicPredicateType(PNAR, CmpInst::ICMP_SGT);
+  bool MonotonicallyIncreasingUnsigned =
+      IncUnsigned && *IncUnsigned == ScalarEvolution::MonotonicallyIncreasing;
+  bool MonotonicallyIncreasingSigned =
+      IncSigned && *IncSigned == ScalarEvolution::MonotonicallyIncreasing;
+  // Each signedness is handled independently below, so only bail out if PN
+  // is monotonically increasing in neither.
+  if (!MonotonicallyIncreasingUnsigned && !MonotonicallyIncreasingSigned)
+    return;
+  // If SCEV guarantees that PN does not wrap, PN >= StartValue can be added
+  // unconditionally.
+  if (MonotonicallyIncreasingUnsigned)
+    WorkList.push_back(FactOrCheck::getConditionFact(
+        NumIn, NumOut, CmpInst::ICMP_UGE, PN, StartValue));
+  if (MonotonicallyIncreasingSigned)
+    WorkList.push_back(FactOrCheck::getConditionFact(
+        NumIn, NumOut, CmpInst::ICMP_SGE, PN, StartValue));
+
+  // Make sure the bound B is loop-invariant.
+  if (!L->isLoopInvariant(B))
+    return;
+
+  // Make sure PN steps by a positive amount and, unless the step is 1, that
+  // B - Start is a multiple of the step, so the increment reaches B exactly
+  // rather than stepping over it.
+  if (StepOffset.isZero() || StepOffset.isNegative())
+    return;
+
+  if (!StepOffset.isOne()) {
+    // Check whether B-Start is known to be a multiple of StepOffset.
+    const SCEV *BMinusStart = SE.getMinusSCEV(SE.getSCEV(B), StartSCEV);
+    if (isa<SCEVCouldNotCompute>(BMinusStart) ||
+        !SE.getConstantMultiple(BMinusStart).urem(StepOffset).isZero())
+      return;
+  }
+
+  // The PN < B facts are conditional on Start + Step <= B. Compute
+  // Start + Step with overflow checks: if it wraps in a given signedness
+  // (e.g. Start = -1, Step = 1 unsigned), the wrapped constant could make the
+  // precondition hold although PN < B is false, so skip that fact.
+  const APInt &Start = StartC->getAPInt();
+  bool OverflowUnsigned = false;
+  bool OverflowSigned = false;
+  APInt Next = Start.uadd_ov(StepOffset, OverflowUnsigned);
+  (void)Start.sadd_ov(StepOffset, OverflowSigned);
+  Constant *NextC = ConstantInt::get(PN->getType(), Next);
+  if (MonotonicallyIncreasingUnsigned && !OverflowUnsigned)
+    WorkList.push_back(FactOrCheck::getConditionFact(
+        NumIn, NumOut, CmpInst::ICMP_ULT, PN, B,
+        ConditionTy(CmpInst::ICMP_ULE, NextC, B)));
+  if (MonotonicallyIncreasingSigned && !OverflowSigned)
+    WorkList.push_back(FactOrCheck::getConditionFact(
+        NumIn, NumOut, CmpInst::ICMP_SLT, PN, B,
+        ConditionTy(CmpInst::ICMP_SLE, NextC, B)));
+}

 void State::addInfoFor(BasicBlock &BB) {
   addInfoForInductions(BB);
+  addInfoForInductions2(BB);

   // True as long as long as the current instruction is guaranteed to execute.
   bool GuaranteedToExecute = true;
diff --git a/llvm/test/Transforms/ConstraintElimination/latch-test.ll b/llvm/test/Transforms/ConstraintElimination/latch-test.ll
new file mode 100644
index 000000000000000..744c3c7a8503c62
--- /dev/null
+++ b/llvm/test/Transforms/ConstraintElimination/latch-test.ll
@@ -0,0 +1,102 @@
+; RUN: opt -passes=constraint-elimination -S %s | FileCheck %s
+
+; TODO: Generate complete CHECK lines with utils/update_test_checks.py. The
+; CHECK-LABEL lines below only check that each function is still present.
+
+target datalayout = "e-m:o-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-n32:64-S128-Fn32"
+
+; CHECK-LABEL: define fastcc i1 @test2(
+define fastcc i1 @test2() {
+entry:
+  br label %loop
+
+loop:                                ; preds = %loop, %entry
+  %iv = phi i64 [ %iv.next, %loop ], [ -1, %entry ]
+  %iv.next = add i64 %iv, 1
+  %ec = icmp eq i64 %iv.next, 0
+  br i1 %ec, label %for.cond, label %loop
+
+for.cond:                                         ; preds = %loop
+  ret i1 false
+}
+
+
+; CHECK-LABEL: define fastcc i1 @test1(
+define fastcc i1 @test1(i32 %bf.load.i.i.i.i724) {
+entry:
+  br label %for.body4.i.i.i.i
+
+for.body4.i.i.i.i:                                ; preds = %for.body4.i.i.i.i, %entry
+  %__n.addr.116.i.i.i.i = phi i64 [ %inc.i.i.i.i, %for.body4.i.i.i.i ], [ -1, %entry ]
+  %inc.i.i.i.i = add i64 %__n.addr.116.i.i.i.i, 1
+  %exitcond.not.i.i.i.i = icmp eq i64 %inc.i.i.i.i, 0
+  br i1 %exitcond.not.i.i.i.i, label %for.cond, label %for.body4.i.i.i.i
+
+for.cond:                                         ; preds = %for.body4.i.i.i.i
+  %tobool.not.i.i.i.i.i.i727 = icmp eq i32 %bf.load.i.i.i.i724, 0
+  %cond.i.i.i.i.i.i = select i1 %tobool.not.i.i.i.i.i.i727, ptr null, ptr null
+  ret i1 false
+}
+
+
+; CHECK-LABEL: define i8 @test3(
+define i8 @test3(i32 %0) {
+entry:
+  br label %loop.header
+
+loop.header:
+  %iv = phi i32 [ 0, %entry ], [ %iv.next, %loop.latch ]
+  %c.1 = icmp eq i32 %iv, 0
+  br i1 true, label %then, label %loop.latch
+
+loop.latch:                                     ; preds = %loop.header
+  %iv.next = add i32 %iv, 1
+  %c.2 = icmp eq i32 %iv.next, 0
+  br i1 %c.2, label %exit.3, label %loop.header
+
+then:
+  %c.3 = icmp eq i32 %0, 31
+  br i1 %c.3, label %exit.2, label %exit
+
+
+exit:
+  ret i8 1
+
+exit.2:
+  ret i8 0
+
+
+exit.3:
+  ret i8 2
+}
+
+
+; CHECK-LABEL: define void @test4(
+define void @test4(i32 %0) {
+entry:
+  br label %for.body496.i1357
+
+for.cond.cleanup495.i1365:                        ; preds = %if.else549.i
+  ret void
+
+for.body496.i1357:                                ; preds = %if.else549.i, %entry
+  %K.01766.i = phi i32 [ -1, %entry ], [ %inc.i1363, %if.else549.i ]
+  %cmp499.i1359 = icmp eq i32 %K.01766.i, 0
+  br i1 %cmp499.i1359, label %if.then502.i, label %if.else549.i
+
+if.then502.i:                                     ; preds = %for.body496.i1357
+  %cmp4.not.i1467.i = icmp eq i32 %0, 31
+  br i1 %cmp4.not.i1467.i, label %if.end6.i1470.i, label %_ZN4llvm18SaturatingMultiplyIjEENSt3__19enable_ifIXsr3stdE13is_unsigned_vIT_EES3_E4typeES3_S3_Pb.exit1485.i
+
+if.end6.i1470.i:                                  ; preds = %if.then502.i
+  ret void
+
+_ZN4llvm18SaturatingMultiplyIjEENSt3__19enable_ifIXsr3stdE13is_unsigned_vIT_EES3_E4typeES3_S3_Pb.exit1485.i: ; preds = %if.then502.i
+  ret void
+
+if.else549.i:                                     ; preds = %for.body496.i1357
+  %inc.i1363 = add i32 %K.01766.i, 1
+  %exitcond.not.i1364 = icmp eq i32 %inc.i1363, 0
+  br i1 %exitcond.not.i1364, label %for.cond.cleanup495.i1365, label %for.body496.i1357
+}
+



More information about the llvm-commits mailing list