[llvm] [ValueTracking] Augment isImpliedByDomCondition by data-relation (PR #187224)
Kunqiu Chen via llvm-commits
llvm-commits at lists.llvm.org
Wed Mar 18 08:08:03 PDT 2026
================
@@ -9414,86 +9414,568 @@ bool llvm::matchSimpleTernaryIntrinsicRecurrence(const IntrinsicInst *I,
II == I;
}
+namespace {
+/// Operand-sharing shape of a comparison "Pred LHS, RHS", as reported by
+/// classifyCmpPatternAndAnchorConstants(). The enumerator name refers to the
+/// side that is an instruction containing the shared operand.
+enum PatternKind {
+ MATCH_RHS, // Match Pred LHS, (op LHS, ...): RHS is an instruction using LHS
+ MATCH_LHS, // Match Pred (op RHS, ...), RHS: LHS is an instruction using RHS
+ MATCH_BOTH, // Match Pred (op1 X, ...), (op2 X, ...): both sides use X
+ MATCH_NONE // Match none: no supported shape was recognized
+};
+
+/// Return \p C as a floating-point comparison would observe it under the
+/// input denormal mode \p Mode.Input:
+///  - IEEE:         denormals are kept as-is;
+///  - PreserveSign: denormals are flushed to a zero with the input's sign;
+///  - PositiveZero: denormals are flushed to +0.0.
+/// Returns std::nullopt when the flushing behaviour is not known statically
+/// (Dynamic, or an Invalid mode), so callers must not fold the comparison.
+std::optional<APFloat> getCompareAPFloat(const APFloat &C,
+                                         const DenormalMode Mode) {
+  // Only denormal inputs are affected by the denormal mode.
+  if (!C.isDenormal())
+    return C;
+
+  const DenormalMode::DenormalModeKind InMode = Mode.Input;
+  assert(InMode != DenormalMode::Invalid &&
+         "Expected a valid denormal input mode");
+  switch (InMode) {
+  case DenormalMode::IEEE:
+    return C;
+  case DenormalMode::PreserveSign:
+    return APFloat::getZero(C.getSemantics(), C.isNegative());
+  case DenormalMode::PositiveZero:
+    return APFloat::getZero(C.getSemantics(), /*Negative=*/false);
+  case DenormalMode::Dynamic:
+  case DenormalMode::Invalid:
+    // Unknown flushing behaviour: be conservative rather than guessing a
+    // flush result (an Invalid mode used to be treated as PositiveZero in
+    // release builds).
+    break;
+  }
+  return std::nullopt;
+}
+
+/// Return true if "LHS Pred RHS" is guaranteed to hold. This is
+/// FCmpInst::compare evaluated on the operands as the comparison would see
+/// them after denormal flushing under the denormal mode of \p CxtF.
+/// Without a context function the mode is treated as dynamic, so a denormal
+/// operand makes the outcome unknown and false is returned.
+bool compareFloat(const FCmpInst::Predicate Pred, const APFloat &LHS,
+                  const APFloat &RHS, const Function *CxtF) {
+  const DenormalMode Mode =
+      CxtF == nullptr ? DenormalMode::getDynamic()
+                      : CxtF->getDenormalMode(LHS.getSemantics());
+
+  const std::optional<APFloat> FlushedLHS = getCompareAPFloat(LHS, Mode);
+  if (!FlushedLHS)
+    return false;
+  const std::optional<APFloat> FlushedRHS = getCompareAPFloat(RHS, Mode);
+  if (!FlushedRHS)
+    return false;
+
+  return FCmpInst::compare(*FlushedLHS, *FlushedRHS, Pred);
+}
+
+/// Classify a comparison into one of the simple operand-sharing patterns used
+/// by isTrueIntPredicate()/isTrueFPPredicate(), and optionally replace a
+/// constant side with a stronger constant operand found on the opposite side.
+///
+/// \param Pred The original comparison predicate.
+/// \param X Output operand shared by both sides. It is set for every result
+///          other than MATCH_NONE (for which it is null). When a constant
+///          side is re-anchored, X is the new constant operand.
+/// \param LHS In/out comparison LHS. May be rewritten to a constant operand
+///            of RHS when the original LHS is constant and proving the
+///            predicate for that operand is sufficient to prove the original.
+/// \param RHS In/out comparison RHS. Symmetric to \p LHS.
+/// \param CLHS Parsed constant for the original LHS, or null if it is not a
+///             supported constant.
+/// \param CRHS Parsed constant for the original RHS, or null if it is not a
+///             supported constant.
+/// \param CxtF Context function used to obtain the DenormalMode for
+///             floating-point constant comparisons.
+///
+/// \returns The recognized PatternKind, or MATCH_NONE if this helper cannot
+///          normalize the comparison into a supported shape.
+template <typename ConstantT>
+PatternKind classifyCmpPatternAndAnchorConstants(
+    CmpInst::Predicate Pred, const Value *&X, const Value *&LHS,
+    const Value *&RHS, const ConstantT *CLHS, const ConstantT *CRHS,
+    const Function *CxtF = nullptr) {
+  static_assert(std::is_same_v<ConstantT, APInt> ||
+                    std::is_same_v<ConstantT, APFloat>,
+                "Only APInt and APFloat are supported");
+
+  // Never leave the caller's anchor uninitialized.
+  X = nullptr;
+
+  // Candidate operands: operands of LHS/RHS with the same type as the
+  // compared values.
+  SmallSetVector<const Value *, 4> LHSOps, RHSOps;
+  auto CmpAndCollectOps = [&](SmallSetVector<const Value *, 4> &Ops,
+                              const Instruction *I,
+                              const Value *CmpTo) -> bool {
+    bool Contains = false;
+    for (const auto &Op : I->operands()) {
+      if (Op->getType() != LHS->getType())
+        continue;
+      if (Op.get() == CmpTo) {
+        X = CmpTo;
+        Contains = true;
+      }
+      Ops.insert(Op.get());
+    }
+    return Contains;
+  };
+
+  // Category 1: Match Pred LHS, (op LHS, ...)
+  if (auto *RHSInst = dyn_cast<Instruction>(RHS);
+      RHSInst && CmpAndCollectOps(RHSOps, RHSInst, LHS))
+    return MATCH_RHS;
+  // Category 2: Match Pred (op RHS, ...), RHS
+  if (auto *LHSInst = dyn_cast<Instruction>(LHS);
+      LHSInst && CmpAndCollectOps(LHSOps, LHSInst, RHS))
+    return MATCH_LHS;
+  // Category 3: Match Pred (op1 X, ...), (op2 X, ...)
+  auto Common = find_if(
+      LHSOps, [&RHSOps](const Value *Op) { return RHSOps.contains(Op); });
+  if (Common != LHSOps.end()) {
+    X = *Common;
+    return MATCH_BOTH;
+  }
+
+  // Fast path to quit: other patterns are only handled when one side is a
+  // constant.
+  if (!CLHS && !CRHS)
+    return MATCH_NONE;
+
+  // If one of LHS and RHS is constant, try to find a new LHS/RHS to continue.
+  // E.g., if Pred is < :
+  //   CLHS < C and C < RHS --> CLHS < RHS, hence we set LHS to C, where C is
+  //   a constant operand of RHS.
+  //   LHS < C and C < CRHS --> LHS < CRHS, hence we set RHS to C, where C is
+  //   a constant operand of LHS.
+  // This relies on Pred being transitive, which the inequality predicates
+  // are not (1 != 2 and 2 != 1, but 1 == 1).
+  if (Pred == CmpInst::ICMP_NE || Pred == CmpInst::FCMP_ONE ||
+      Pred == CmpInst::FCMP_UNE)
+    return MATCH_NONE;
+
+  const bool IsConstLHS = CLHS != nullptr;
+  const auto &Ops = IsConstLHS ? RHSOps : LHSOps;
+  auto Anchor = find_if(Ops, [&](const Value *Op) {
+    const ConstantT *C = nullptr;
+    if constexpr (std::is_same_v<ConstantT, APInt>) {
+      if (!match(Op, m_APInt(C)))
+        return false;
+      return IsConstLHS ? ICmpInst::compare(*CLHS, *C, Pred)
+                        : ICmpInst::compare(*C, *CRHS, Pred);
+    } else {
+      // A NaN in the middle breaks transitivity of unordered predicates:
+      // 5 ult NaN and NaN ult 3 both hold, but 5 ult 3 does not.
+      if (!match(Op, m_APFloat(C)) || C->isNaN())
+        return false;
+      return IsConstLHS ? compareFloat(Pred, *CLHS, *C, CxtF)
+                        : compareFloat(Pred, *C, *CRHS, CxtF);
+    }
+  });
+  if (Anchor == Ops.end())
+    return MATCH_NONE;
+
+  // The new constant side is now an operand of the other side, i.e. the
+  // shared operand, exactly as in the direct-match categories.
+  X = *Anchor;
+  if (IsConstLHS) {
+    LHS = *Anchor;
+    return MATCH_RHS;
+  }
+  RHS = *Anchor;
+  return MATCH_LHS;
+}
+
/// Return true if "icmp Pred LHS RHS" is always true.
-static bool isTruePredicate(CmpInst::Predicate Pred, const Value *LHS,
- const Value *RHS) {
- if (ICmpInst::isTrueWhenEqual(Pred) && LHS == RHS)
+bool isTrueIntPredicate(CmpInst::Predicate Pred, const Value *LHS,
+ const Value *RHS) {
+ if (LHS->getType() != RHS->getType())
+ return false;
+
+ switch (Pred) {
+ default:
+ break;
+ case CmpInst::ICMP_SGT:
+ return isTrueIntPredicate(CmpInst::ICMP_SLT, RHS, LHS);
+ case CmpInst::ICMP_SGE:
+ return isTrueIntPredicate(CmpInst::ICMP_SLE, RHS, LHS);
+ case CmpInst::ICMP_UGT:
+ return isTrueIntPredicate(CmpInst::ICMP_ULT, RHS, LHS);
+ case CmpInst::ICMP_UGE:
+ return isTrueIntPredicate(CmpInst::ICMP_ULE, RHS, LHS);
+ }
+
+ const APInt *CLHS = nullptr, *CRHS = nullptr;
+ match(LHS, m_APInt(CLHS));
+ match(RHS, m_APInt(CRHS));
+ // If both CLHS and CRHS are constant integers.
+ if (CLHS && CRHS)
+ return ICmpInst::compare(*CLHS, *CRHS, Pred);
+
+ // If the predicate is true when equal?
+ const bool CanEq = ICmpInst::isTrueWhenEqual(Pred);
+ if (CanEq && LHS == RHS)
return true;
+ // Exclude NE/EQ
+ if (ICmpInst::isEquality(Pred))
+ return false;
+
+ // Represent the common operand between LHS and RHS
+ const Value *X;
+
+ // Derive possible match pattern
+ PatternKind PK =
+ classifyCmpPatternAndAnchorConstants(Pred, X, LHS, RHS, CLHS, CRHS);
+
+ // The pattern is too complex to analyze, quit early.
+ if (PK == MATCH_NONE)
+ return false;
+
+ const APInt *C;
+ const Value *V;
+ bool m;
+
switch (Pred) {
default:
return false;
+ case CmpInst::ICMP_SLT:
+ // Delegate to CmpInst::ICMP_SLE to share common patterns.
case CmpInst::ICMP_SLE: {
- const APInt *C;
+ // TODO: handle select/phi.
+
+ // Category 1: Match Pred LHS, (op LHS, ...)
+ if (PK == MATCH_RHS) {
+ // LHS s<= LHS +_{nsw} C if C >= 0
+ // LHS s< LHS +_{nsw} C if C > 0
+ if (match(RHS, m_c_NSWAdd(m_Specific(LHS), m_APInt(C))))
+ return CanEq ? C->isNonNegative() : C->isStrictlyPositive();
+ // LHS s<= LHS -_{nsw} C if C <= 0
+ // LHS s< LHS -_{nsw} C if C < 0
+ if (match(RHS, m_NSWSub(m_Specific(LHS), m_APInt(C))))
+ return CanEq ? C->isNonPositive() : C->isNegative();
+ // LHS s<= LHS <<_{nsw,nuw} V for any V (V < 0 is UB)
+ // slt: cannot exclude LHS == 0
+ if (CanEq && match(RHS, m_NSWShl(m_Specific(LHS), m_Value(V))) &&
+ cast<OverflowingBinaryOperator>(RHS)->hasNoUnsignedWrap())
+ return true;
----------------
Camsyn wrote:
Assume no UB.
https://github.com/llvm/llvm-project/pull/187224
More information about the llvm-commits mailing list