[llvm] [ConstraintElim] Try to use info from loop latch (PR #126118)
Florian Hahn via llvm-commits
llvm-commits at lists.llvm.org
Thu Feb 6 11:34:11 PST 2025
https://github.com/fhahn created https://github.com/llvm/llvm-project/pull/126118
(No description was provided for this pull request.)
>From 87eeb375a25b5ccf1fc09eb0adb30f78f54b0efb Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Thu, 6 Feb 2025 19:21:33 +0000
Subject: [PATCH 1/2] [ConstraintElimination] Prepare.
---
.../Scalar/ConstraintElimination.cpp | 26 ++++++++++++++++---
1 file changed, 22 insertions(+), 4 deletions(-)
diff --git a/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp b/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp
index 6dd26910f684694..684df6d1003c679 100644
--- a/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp
+++ b/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp
@@ -1083,6 +1083,7 @@ void State::addInfoForInductions(BasicBlock &BB) {
}
}
+
void State::addInfoFor(BasicBlock &BB) {
addInfoForInductions(BB);
@@ -1096,7 +1097,15 @@ void State::addInfoFor(BasicBlock &BB) {
auto *DTN = DT.getNode(UserI->getParent());
if (!DTN)
continue;
+ auto *L = LI.getLoopFor(cast<Instruction>(U.getUser())->getParent());
WorkList.push_back(FactOrCheck::getCheck(DTN, &U));
+ if (L && L->getLoopLatch() && L->isLoopExiting(L->getLoopLatch())) {
+ auto *DTNLatch = DT.getNode(L->getLoopLatch());
+ if (DTNLatch->getDFSNumIn() >= DTN->getDFSNumIn() &&
+ DTNLatch->getDFSNumOut() <= DTN->getDFSNumOut())
+ WorkList.back().NumOut =
+ std::min(WorkList.back().NumOut, DTNLatch->getDFSNumIn());
+ }
}
continue;
}
@@ -1430,16 +1439,25 @@ static bool checkAndReplaceCondition(
CmpInst *Cmp, ConstraintInfo &Info, unsigned NumIn, unsigned NumOut,
Instruction *ContextInst, Module *ReproducerModule,
ArrayRef<ReproducerEntry> ReproducerCondStack, DominatorTree &DT,
- SmallVectorImpl<Instruction *> &ToRemove) {
+ SmallVectorImpl<Instruction *> &ToRemove, LoopInfo &LI) {
auto ReplaceCmpWithConstant = [&](CmpInst *Cmp, bool IsTrue) {
generateReproducer(Cmp, ReproducerModule, ReproducerCondStack, Info, DT);
Constant *ConstantC = ConstantInt::getBool(
CmpInst::makeCmpResultType(Cmp->getType()), IsTrue);
- Cmp->replaceUsesWithIf(ConstantC, [&DT, NumIn, NumOut,
+ Cmp->replaceUsesWithIf(ConstantC, [&DT, NumIn, NumOut, Cmp, &LI,
ContextInst](Use &U) {
auto *UserI = getContextInstForUse(U);
auto *DTN = DT.getNode(UserI->getParent());
- if (!DTN || DTN->getDFSNumIn() < NumIn || DTN->getDFSNumOut() > NumOut)
+ // DT.getNode() returns null for blocks outside the dominator tree (e.g.
+ // unreachable ones); bail out before DTN is dereferenced below.
+ if (!DTN)
+ return false;
+ auto *L = LI.getLoopFor(Cmp->getParent());
+ bool IsSameLoop = L && LI.getLoopFor(UserI->getParent()) == L;
+ unsigned UserNumOut = DTN->getDFSNumOut();
+ // If the use's block dominates the exiting latch of Cmp's loop, narrow its
+ // DFS-out number to the latch's DFS-in number, mirroring how checks are
+ // scoped in addInfoFor.
+ if (IsSameLoop && L->getLoopLatch() && L->isLoopExiting(L->getLoopLatch())) {
+ auto *DTNLatch = DT.getNode(L->getLoopLatch());
+ if (DTNLatch->getDFSNumIn() >= DTN->getDFSNumIn() &&
+ DTNLatch->getDFSNumOut() <= DTN->getDFSNumOut())
+ UserNumOut = std::min(UserNumOut, DTNLatch->getDFSNumIn());
+ }
+ if (DTN->getDFSNumIn() < NumIn || UserNumOut > NumOut)
return false;
if (UserI->getParent() == ContextInst->getParent() &&
UserI->comesBefore(ContextInst))
@@ -1814,7 +1832,7 @@ static bool eliminateConstraints(Function &F, DominatorTree &DT, LoopInfo &LI,
} else if (auto *Cmp = dyn_cast<ICmpInst>(Inst)) {
bool Simplified = checkAndReplaceCondition(
Cmp, Info, CB.NumIn, CB.NumOut, CB.getContextInst(),
- ReproducerModule.get(), ReproducerCondStack, S.DT, ToRemove);
+ ReproducerModule.get(), ReproducerCondStack, S.DT, ToRemove, LI);
if (!Simplified &&
match(CB.getContextInst(), m_LogicalOp(m_Value(), m_Value()))) {
Simplified =
>From a2292a602398968ca61731695598e7658105f5e5 Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Tue, 4 Feb 2025 13:24:09 +0000
Subject: [PATCH 2/2] Add
Fix
step
---
.../Scalar/ConstraintElimination.cpp | 131 ++++++++++++++++++
.../ConstraintElimination/latch-test.ll | 94 +++++++++++++
2 files changed, 225 insertions(+)
create mode 100644 llvm/test/Transforms/ConstraintElimination/latch-test.ll
diff --git a/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp b/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp
index 684df6d1003c679..9f05b726747b4ac 100644
--- a/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp
+++ b/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp
@@ -138,12 +138,23 @@ struct FactOrCheck {
: Cond(Pred, Op0, Op1), DoesHold(Precond), NumIn(DTN->getDFSNumIn()),
NumOut(DTN->getDFSNumOut()), Ty(EntryTy::ConditionFact) {}
+  /// Creates a condition fact whose scope is the explicit dominator-tree DFS
+  /// interval [\p DFSIn, \p DFSOut] instead of the interval of a single node.
+  FactOrCheck(unsigned DFSIn, unsigned DFSOut, CmpPredicate Pred, Value *Op0,
+              Value *Op1, ConditionTy Precond = {})
+      : Cond(Pred, Op0, Op1), DoesHold(Precond), NumIn(DFSIn), NumOut(DFSOut),
+        Ty(EntryTy::ConditionFact) {}
+
static FactOrCheck getConditionFact(DomTreeNode *DTN, CmpPredicate Pred,
Value *Op0, Value *Op1,
ConditionTy Precond = {}) {
return FactOrCheck(DTN, Pred, Op0, Op1, Precond);
}
+  /// Returns a condition fact `Op0 Pred Op1` (guarded by \p Precond) that is
+  /// scoped to the dominator-tree DFS interval [\p NumIn, \p NumOut].
+  static FactOrCheck getConditionFact(unsigned NumIn, unsigned NumOut,
+                                      CmpPredicate Pred, Value *Op0, Value *Op1,
+                                      ConditionTy Precond = {}) {
+    return {NumIn, NumOut, Pred, Op0, Op1, Precond};
+  }
+
static FactOrCheck getInstFact(DomTreeNode *DTN, Instruction *Inst) {
return FactOrCheck(EntryTy::InstFact, DTN, Inst);
}
@@ -194,6 +205,7 @@ struct State {
/// Try to add facts for loop inductions (AddRecs) in EQ/NE compares
/// controlling the loop header.
void addInfoForInductions(BasicBlock &BB);
+ void addInfoForInductions2(BasicBlock &BB);
/// Returns true if we can add a known condition from BB to its successor
/// block Succ.
@@ -1083,9 +1095,128 @@ void State::addInfoForInductions(BasicBlock &BB) {
}
}
+/// Returns true if \p Root, or an instruction in \p L that transitively uses
+/// it, is an ICmpInst located in \p L but outside the loop latch.
+/// Loads, stores and instructions outside \p L are not traversed through.
+/// The search is bounded: it conservatively returns false once more than
+/// MaxVisited instructions have been queued.
+static bool hasLoopCmpUser(Instruction *Root, Loop *L) {
+  constexpr unsigned MaxVisited = 16;
+  // Computed once instead of for every visited instruction.
+  const BasicBlock *Latch = L->getLoopLatch();
+  // The SetVector doubles as the visited set, so each instruction is
+  // processed at most once.
+  SetVector<Instruction *> Worklist;
+  Worklist.insert(Root);
+  for (unsigned Idx = 0; Idx != Worklist.size(); ++Idx) {
+    Instruction *Curr = Worklist[Idx];
+    if (isa<LoadInst, StoreInst>(Curr) || !L->contains(Curr))
+      continue;
+    if (Worklist.size() > MaxVisited)
+      return false;
+    if (isa<ICmpInst>(Curr) && Curr->getParent() != Latch)
+      return true;
+    for (User *U : Curr->users())
+      Worklist.insert(cast<Instruction>(U));
+  }
+  return false;
+}
+
+/// Adds facts for a header phi PN of the loop whose latch is \p BB, derived
+/// from a latch exit test of the form `br (icmp eq (PN + Step), B), Exit, ...`:
+///  * PN >= Start (unsigned and signed), and
+///  * PN < B (unsigned and signed), under the precondition Start + Step <= B,
+/// where Start and Step are constants, Step is positive, B is loop-invariant
+/// and B - Start is a multiple of Step. The facts are scoped to the
+/// dominator-tree DFS interval [DFSIn(header), DFSIn(latch)], which excludes
+/// the latch itself.
+void State::addInfoForInductions2(BasicBlock &BB) {
+  Loop *L = LI.getLoopFor(&BB);
+  if (!L)
+    return;
+  BasicBlock *Latch = L->getLoopLatch();
+  if (&BB != Latch || !L->isLoopExiting(Latch) || Latch == L->getHeader())
+    return;
+
+  // Match the exit test; its true successor must leave the loop. IncValue is
+  // the add compared against B.
+  Value *IncValue;
+  Value *A;
+  Value *B;
+  CmpPredicate Pred;
+  BasicBlock *TrueSucc;
+  if (!match(Latch->getTerminator(),
+             m_Br(m_ICmp(Pred,
+                         m_CombineAnd(m_Value(IncValue),
+                                      m_Add(m_Value(A), m_Value())),
+                         m_Value(B)),
+                  m_BasicBlock(TrueSucc), m_Value())) ||
+      Pred != CmpInst::ICMP_EQ || L->contains(TrueSucc))
+    return;
+
+  auto *PN = dyn_cast<PHINode>(A);
+  if (!PN || PN->getParent() != L->getHeader() ||
+      PN->getNumIncomingValues() != 2 || !hasLoopCmpUser(PN, L) ||
+      !SE.isSCEVable(PN->getType()))
+    return;
+
+  auto *AR = dyn_cast<SCEVAddRecExpr>(SE.getSCEV(IncValue));
+  if (!AR || AR->getLoop() != L || !L->getLoopPredecessor() ||
+      !isa<SCEVConstant>(AR->getStart()) ||
+      !isa<SCEVConstant>(AR->getStepRecurrence(SE)))
+    return;
+
+  // SCEV may not model PN itself as an add recurrence of L; bail out instead
+  // of asserting in cast<>.
+  auto *PNAR = dyn_cast<SCEVAddRecExpr>(SE.getSCEV(PN));
+  if (!PNAR || PNAR->getLoop() != L)
+    return;
+  // The reasoning below needs the exit test to compare PN's post-increment
+  // value (PN + Step). If IncValue adds something else to PN, PN can step past
+  // B without the exit being taken, and PN < B would not hold.
+  if (AR != PNAR->getPostIncExpr(SE))
+    return;
+
+  auto *StartC = dyn_cast<SCEVConstant>(PNAR->getStart());
+  if (!StartC)
+    return;
+  const SCEV *StartSCEV = StartC;
+  Value *StartValue = StartC->getValue();
+
+  DomTreeNode *DTN = DT.getNode(L->getHeader());
+  DomTreeNode *DTNLatch = DT.getNode(Latch);
+  unsigned NumIn = DTN->getDFSNumIn();
+  unsigned NumOut = DTNLatch->getDFSNumIn();
+
+  auto IncUnsigned = SE.getMonotonicPredicateType(PNAR, CmpInst::ICMP_UGT);
+  auto IncSigned = SE.getMonotonicPredicateType(PNAR, CmpInst::ICMP_SGT);
+  bool MonotonicallyIncreasingUnsigned =
+      IncUnsigned && *IncUnsigned == ScalarEvolution::MonotonicallyIncreasing;
+  bool MonotonicallyIncreasingSigned =
+      IncSigned && *IncSigned == ScalarEvolution::MonotonicallyIncreasing;
+  // NOTE(review): this requires PN to be monotonically increasing in *both*
+  // the signed and the unsigned sense, which makes the per-predicate checks
+  // below always true. Presumably intentionally conservative; confirm with
+  // tests on wrapping inductions before relaxing it.
+  if (!MonotonicallyIncreasingSigned || !MonotonicallyIncreasingUnsigned)
+    return;
+
+  // If SCEV guarantees that PN does not wrap, PN >= StartValue can be added
+  // unconditionally.
+  if (MonotonicallyIncreasingUnsigned)
+    WorkList.push_back(FactOrCheck::getConditionFact(
+        NumIn, NumOut, CmpInst::ICMP_UGE, PN, StartValue));
+  if (MonotonicallyIncreasingSigned)
+    WorkList.push_back(FactOrCheck::getConditionFact(
+        NumIn, NumOut, CmpInst::ICMP_SGE, PN, StartValue));
+
+  // The bound B must be loop-invariant.
+  if (!L->isLoopInvariant(B))
+    return;
+
+  // Require a positive step and, unless it is 1, that B - Start is a multiple
+  // of it, so PN + Step hits B exactly instead of skipping over it.
+  const SCEV *StepSCEV = AR->getStepRecurrence(SE);
+  const APInt &StepOffset = cast<SCEVConstant>(StepSCEV)->getAPInt();
+  if (StepOffset.isZero() || StepOffset.isNegative())
+    return;
+  if (!StepOffset.isOne()) {
+    const SCEV *BMinusStart = SE.getMinusSCEV(SE.getSCEV(B), StartSCEV);
+    if (isa<SCEVCouldNotCompute>(BMinusStart) ||
+        !SE.getConstantMultiple(BMinusStart).urem(StepOffset).isZero())
+      return;
+  }
+
+  // PN < B is only added under the precondition that the first
+  // post-increment value, Start + Step, does not exceed B.
+  Value *FirstIncValue =
+      cast<SCEVConstant>(SE.getAddExpr(StartSCEV, StepSCEV))->getValue();
+  if (MonotonicallyIncreasingUnsigned)
+    WorkList.push_back(FactOrCheck::getConditionFact(
+        NumIn, NumOut, CmpInst::ICMP_ULT, PN, B,
+        ConditionTy(CmpInst::ICMP_ULE, FirstIncValue, B)));
+  if (MonotonicallyIncreasingSigned)
+    WorkList.push_back(FactOrCheck::getConditionFact(
+        NumIn, NumOut, CmpInst::ICMP_SLT, PN, B,
+        ConditionTy(CmpInst::ICMP_SLE, FirstIncValue, B)));
+}
void State::addInfoFor(BasicBlock &BB) {
addInfoForInductions(BB);
+ addInfoForInductions2(BB);
// True as long as long as the current instruction is guaranteed to execute.
bool GuaranteedToExecute = true;
diff --git a/llvm/test/Transforms/ConstraintElimination/latch-test.ll b/llvm/test/Transforms/ConstraintElimination/latch-test.ll
new file mode 100644
index 000000000000000..744c3c7a8503c62
--- /dev/null
+++ b/llvm/test/Transforms/ConstraintElimination/latch-test.ll
@@ -0,0 +1,94 @@
+; NOTE(review): there are no CHECK lines yet, so this RUN line only verifies
+; that the module parses and constraint-elimination runs on it without
+; crashing. Consider generating checks with llvm/utils/update_test_checks.py
+; so that miscompiles are caught as well.
+; RUN: opt -passes=constraint-elimination -disable-output %s
+
+target datalayout = "e-m:o-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-n32:64-S128-Fn32"
+target triple = "arm64-apple-macosx15.0.0"
+
+; Single-block loop (header == latch): the IV starts at -1 and the loop exits
+; once iv.next == 0.
+define fastcc i1 @test2() {
+entry:
+  br label %loop
+
+loop:                                             ; preds = %loop, %entry
+  %iv = phi i64 [ %iv.next, %loop ], [ -1, %entry ]
+  %iv.next = add i64 %iv, 1
+  %ec = icmp eq i64 %iv.next, 0
+  br i1 %ec, label %for.cond, label %loop
+
+for.cond:                                         ; preds = %loop
+  ret i1 false
+}
+
+
+; Same loop shape as @test2, with extra instructions in the exit block.
+define fastcc i1 @test1(i32 %bf.load.i.i.i.i724) {
+entry:
+  br label %for.body4.i.i.i.i
+
+for.body4.i.i.i.i:                                ; preds = %for.body4.i.i.i.i, %entry
+  %__n.addr.116.i.i.i.i = phi i64 [ %inc.i.i.i.i, %for.body4.i.i.i.i ], [ -1, %entry ]
+  %inc.i.i.i.i = add i64 %__n.addr.116.i.i.i.i, 1
+  %exitcond.not.i.i.i.i = icmp eq i64 %inc.i.i.i.i, 0
+  br i1 %exitcond.not.i.i.i.i, label %for.cond, label %for.body4.i.i.i.i
+
+for.cond:                                         ; preds = %for.body4.i.i.i.i
+  %tobool.not.i.i.i.i.i.i727 = icmp eq i32 %bf.load.i.i.i.i724, 0
+  %cond.i.i.i.i.i.i = select i1 %tobool.not.i.i.i.i.i.i727, ptr null, ptr null
+  ret i1 false
+}
+
+
+; %loop.header leaves the loop via `br i1 true`, so %loop.latch is never
+; executed at run time.
+define i8 @test3(i32 %0) {
+entry:
+  br label %loop.header
+
+loop.header:
+  %iv = phi i32 [ 0, %entry ], [ %iv.next, %loop.latch ]
+  %c.1 = icmp eq i32 %iv, 0
+  br i1 true, label %then, label %loop.latch
+
+loop.latch:                                       ; preds = %loop.header
+  %iv.next = add i32 %iv, 1
+  %c.2 = icmp eq i32 %iv.next, 0
+  br i1 %c.2, label %exit.3, label %loop.header
+
+then:
+  %c.3 = icmp eq i32 %0, 31
+  br i1 %c.3, label %exit.2, label %exit
+
+
+exit:
+  ret i8 1
+
+exit.2:
+  ret i8 0
+
+
+exit.3:
+  ret i8 2
+}
+
+
+; The IV starts at -1; the header leaves the loop to %if.then502.i when the
+; IV is 0, and the latch exits once the incremented IV is 0.
+define void @test4(i32 %0) {
+entry:
+  br label %for.body496.i1357
+
+for.cond.cleanup495.i1365:                        ; preds = %if.else549.i
+  ret void
+
+for.body496.i1357:                                ; preds = %if.else549.i, %entry
+  %K.01766.i = phi i32 [ -1, %entry ], [ %inc.i1363, %if.else549.i ]
+  %cmp499.i1359 = icmp eq i32 %K.01766.i, 0
+  br i1 %cmp499.i1359, label %if.then502.i, label %if.else549.i
+
+if.then502.i:                                     ; preds = %for.body496.i1357
+  %cmp4.not.i1467.i = icmp eq i32 %0, 31
+  br i1 %cmp4.not.i1467.i, label %if.end6.i1470.i, label %_ZN4llvm18SaturatingMultiplyIjEENSt3__19enable_ifIXsr3stdE13is_unsigned_vIT_EES3_E4typeES3_S3_Pb.exit1485.i
+
+if.end6.i1470.i:                                  ; preds = %if.then502.i
+  ret void
+
+_ZN4llvm18SaturatingMultiplyIjEENSt3__19enable_ifIXsr3stdE13is_unsigned_vIT_EES3_E4typeES3_S3_Pb.exit1485.i: ; preds = %if.then502.i
+  ret void
+
+if.else549.i:                                     ; preds = %for.body496.i1357
+  %inc.i1363 = add i32 %K.01766.i, 1
+  %exitcond.not.i1364 = icmp eq i32 %inc.i1363, 0
+  br i1 %exitcond.not.i1364, label %for.cond.cleanup495.i1365, label %for.body496.i1357
+}
+
More information about the llvm-commits mailing list