[llvm] [ValueTracking] Augment isImpliedByDomCondition by data-relation (PR #187224)

Kunqiu Chen via llvm-commits llvm-commits at lists.llvm.org
Wed Mar 18 08:08:03 PDT 2026


================
@@ -9414,86 +9414,568 @@ bool llvm::matchSimpleTernaryIntrinsicRecurrence(const IntrinsicInst *I,
          II == I;
 }
 
+namespace {
+enum PatternKind { // Operand-sharing shape recognized for a comparison.
+  MATCH_RHS,  // Match Pred LHS, (op LHS, ...): RHS is an inst using LHS
+  MATCH_LHS,  // Match Pred (op RHS, ...), RHS: LHS is an inst using RHS
+  MATCH_BOTH, // Match Pred (op1 X, ...), (op2 X, ...): shared operand X
+  MATCH_NONE  // Match none: no supported shape was recognized
+};
+
+/// Return the value that \p C is compared as under the input denormal mode
+/// of \p Mode.
+///
+/// A non-denormal \p C, or any \p C in IEEE mode, is returned unchanged. In
+/// dynamic mode a denormal \p C yields std::nullopt. In every other mode a
+/// denormal \p C is flushed to a zero of the same semantics; the zero keeps
+/// the sign of \p C only in PreserveSign mode and is positive otherwise.
+std::optional<APFloat> getCompareAPFloat(const APFloat &C,
+                                         const DenormalMode Mode) {
+  using ModeKind = DenormalMode::DenormalModeKind;
+  const ModeKind InMode = Mode.Input;
+
+  // Nothing is flushed unless C is denormal and the mode is not IEEE.
+  if (!C.isDenormal() || InMode == ModeKind::IEEE)
+    return C;
+
+  // The mode is only known at run time, so the compared value is unknown.
+  if (InMode == ModeKind::Dynamic)
+    return std::nullopt;
+
+  assert(InMode != DenormalMode::DenormalModeKind::Invalid &&
+         InMode != DenormalMode::DenormalModeKind::Dynamic &&
+         "Expected a concrete denormal input mode");
+
+  // Flush the denormal input to zero, keeping its sign only if requested.
+  const bool KeepSign = InMode == ModeKind::PreserveSign;
+  return APFloat::getZero(C.getSemantics(), KeepSign && C.isNegative());
+}
+
+// Return true if "LHS Pred RHS" is always true for the two constants.
+// This is FCmpInst::compare applied after each operand has been adjusted for
+// the input denormal mode of CxtF (looked up with the semantics of LHS). A
+// null CxtF is treated as the dynamic denormal mode. If either adjusted
+// operand is unknown, the result is false.
+bool compareFloat(const FCmpInst::Predicate Pred, const APFloat &LHS,
+                  const APFloat &RHS, const Function *CxtF) {
+  const DenormalMode Mode = CxtF ? CxtF->getDenormalMode(LHS.getSemantics())
+                                 : DenormalMode::getDynamic();
+
+  const std::optional<APFloat> AdjustedLHS = getCompareAPFloat(LHS, Mode);
+  if (!AdjustedLHS)
+    return false;
+
+  const std::optional<APFloat> AdjustedRHS = getCompareAPFloat(RHS, Mode);
+  if (!AdjustedRHS)
+    return false;
+
+  return FCmpInst::compare(*AdjustedLHS, *AdjustedRHS, Pred);
+}
+
+/// Classify a comparison into one of the simple operand-sharing patterns used
+/// by isTrueIntPredicate()/isTrueFPPredicate(), and optionally replace a
+/// constant side with a stronger constant operand found on the opposite side.
+///
+/// The checks run in this order and the first hit wins. Only operands whose
+/// type equals the type of \p LHS are considered.
+///  1. RHS is an instruction with LHS among its operands  -> MATCH_RHS.
+///  2. LHS is an instruction with RHS among its operands  -> MATCH_LHS.
+///  3. LHS and RHS are instructions sharing an operand X  -> MATCH_BOTH.
+///  4. One side is a supported constant C and the other side is an
+///     instruction with a constant operand C' for which `C Pred C'` (constant
+///     LHS) or `C' Pred C` (constant RHS) holds: the constant side is
+///     rewritten to C' and MATCH_RHS / MATCH_LHS is returned.
+///
+/// \param Pred The original comparison predicate.
+/// \param X Output common operand anchor. Set on every return other than
+///          MATCH_NONE: the shared operand for MATCH_BOTH, the (possibly
+///          rewritten) LHS for MATCH_RHS, and the (possibly rewritten) RHS
+///          for MATCH_LHS. Left untouched when MATCH_NONE is returned.
+/// \param LHS In/out comparison LHS. May be rewritten to a constant operand
+///            from the other side when the original LHS is constant and that
+///            operand is already sufficient to imply the original predicate.
+/// \param RHS In/out comparison RHS. Symmetric to \p LHS.
+/// \param CLHS Parsed constant for the original LHS, or null if it is not a
+///             supported constant.
+/// \param CRHS Parsed constant for the original RHS, or null if it is not a
+///             supported constant.
+/// \param CxtF Context function to extract DenormalMode for float computing.
+///
+/// \returns The recognized PatternKind, or MATCH_NONE if this helper cannot
+///          normalize the comparison into a supported shape.
+template <typename ConstantT>
+PatternKind classifyCmpPatternAndAnchorConstants(
+    CmpInst::Predicate Pred, const Value *&X, const Value *&LHS,
+    const Value *&RHS, const ConstantT *CLHS, const ConstantT *CRHS,
+    const Function *CxtF = nullptr) {
+  static_assert(std::is_same_v<ConstantT, APInt> ||
+                    std::is_same_v<ConstantT, APFloat>,
+                "Only APInt and APFloat are supported");
+
+  // Candidate Operands: those ops with the same type
+  SmallSetVector<const Value *, 4> LHSOps, RHSOps;
+  // Collect the same-typed operands of I into Ops. Returns true, and points X
+  // at CmpTo, if CmpTo is one of them. The scan is not cut short on a hit so
+  // that Ops is always complete.
+  auto CmpAndCollectOps = [&](SmallSetVector<const Value *, 4> &Ops,
+                              const Instruction *I,
+                              const Value *CmpTo) -> bool {
+    bool Contains = false;
+    for (const auto &Op : I->operands()) {
+      if (Op->getType() != LHS->getType())
+        continue;
+      if (Op.get() == CmpTo) {
+        X = CmpTo;
+        Contains = true;
+      }
+      Ops.insert(Op.get());
+    }
+    return Contains;
+  };
+  // Category 1: Match Pred LHS, (op LHS, ...)
+  if (auto *RHSInst = dyn_cast<Instruction>(RHS);
+      RHSInst && CmpAndCollectOps(RHSOps, RHSInst, LHS))
+    return MATCH_RHS;
+  // Category 2: Match Pred (op RHS, ...), RHS
+  if (auto *LHSInst = dyn_cast<Instruction>(LHS);
+      LHSInst && CmpAndCollectOps(LHSOps, LHSInst, RHS))
+    return MATCH_LHS;
+  // Category 3: Match Pred (op1 X, ...), (op2 X, ...)
+  // LHSOps/RHSOps are only populated for sides that are instructions, so this
+  // can only hit when both sides are instructions.
+  if (const auto *It = find_if(
+          LHSOps, [&RHSOps](const Value *Op) { return RHSOps.contains(Op); });
+      It != LHSOps.end() && (X = *It))
+    return MATCH_BOTH;
+  // If one of LHS and RHS is constant, try to find a new LHS/RHS to continue.
+  // E.g., if pred is < :
+  //  CLHS < CLHS' < RHS  --> CLHS < RHS is true, hence we set LHS as CLHS'.
+  //  LHS  < CRHS' < CRHS --> LHS < CRHS is true, hence we set RHS as CRHS'.
+  // Only the constant-vs-constant half of each chain is checked here; the
+  // caller still has to prove the remaining half on the rewritten operands.
+  // NOTE(review): this relies on Pred being transitive. Confirm that callers
+  // never reach here with a predicate for which that fails (e.g. an unordered
+  // FP predicate combined with a NaN constant operand).
+  if (CLHS || CRHS) {
+    const bool IsConstLHS = CLHS != nullptr;
+    const auto &Ops = IsConstLHS ? RHSOps : LHSOps;
+    const auto *It = find_if(Ops, [&](const Value *Op) {
+      const ConstantT *C;
+      if constexpr (std::is_same_v<ConstantT, APInt>) {
+        if (!match(Op, m_APInt(C)))
+          return false;
+        return IsConstLHS ? ICmpInst::compare(*CLHS, *C, Pred)
+                          : ICmpInst::compare(*C, *CRHS, Pred);
+      } else {
+        if (!match(Op, m_APFloat(C)))
+          return false;
+        return IsConstLHS ? compareFloat(Pred, *CLHS, *C, CxtF)
+                          : compareFloat(Pred, *C, *CRHS, CxtF);
+      }
+    });
+    if (It == Ops.end())
+      return MATCH_NONE;
+
+    // Publish the anchored operand through X too, so X is never left
+    // indeterminate on a successful return (callers may pass it in
+    // uninitialized). This mirrors the direct-match cases, where X equals the
+    // side that appears as an operand of the other side.
+    X = *It;
+    if (IsConstLHS) {
+      LHS = *It;
+      return MATCH_RHS;
+    }
+    RHS = *It;
+    return MATCH_LHS;
+  }
+  // Fast path to quit: We do not handle other patterns for now.
+  return MATCH_NONE;
+}
+
 /// Return true if "icmp Pred LHS RHS" is always true.
-static bool isTruePredicate(CmpInst::Predicate Pred, const Value *LHS,
-                            const Value *RHS) {
-  if (ICmpInst::isTrueWhenEqual(Pred) && LHS == RHS)
+bool isTrueIntPredicate(CmpInst::Predicate Pred, const Value *LHS,
+                        const Value *RHS) {
+  if (LHS->getType() != RHS->getType())
+    return false;
+
+  switch (Pred) {
+  default:
+    break;
+  case CmpInst::ICMP_SGT:
+    return isTrueIntPredicate(CmpInst::ICMP_SLT, RHS, LHS);
+  case CmpInst::ICMP_SGE:
+    return isTrueIntPredicate(CmpInst::ICMP_SLE, RHS, LHS);
+  case CmpInst::ICMP_UGT:
+    return isTrueIntPredicate(CmpInst::ICMP_ULT, RHS, LHS);
+  case CmpInst::ICMP_UGE:
+    return isTrueIntPredicate(CmpInst::ICMP_ULE, RHS, LHS);
+  }
+
+  const APInt *CLHS = nullptr, *CRHS = nullptr;
+  match(LHS, m_APInt(CLHS));
+  match(RHS, m_APInt(CRHS));
+  // If both CLHS and CRHS are constant integers.
+  if (CLHS && CRHS)
+    return ICmpInst::compare(*CLHS, *CRHS, Pred);
+
+  // If the predicate is true when equal?
+  const bool CanEq = ICmpInst::isTrueWhenEqual(Pred);
+  if (CanEq && LHS == RHS)
     return true;
 
+  // Exclude NE/EQ
+  if (ICmpInst::isEquality(Pred))
+    return false;
+
+  // Represent the common operand between LHS and RHS
+  const Value *X;
+
+  // Derive possible match pattern
+  PatternKind PK =
+      classifyCmpPatternAndAnchorConstants(Pred, X, LHS, RHS, CLHS, CRHS);
+
+  // The pattern is too complex to analyze, quit early.
+  if (PK == MATCH_NONE)
+    return false;
+
+  const APInt *C;
+  const Value *V;
+  bool m;
+
   switch (Pred) {
   default:
     return false;
 
+  case CmpInst::ICMP_SLT:
+    // Delegate to CmpInst::ICMP_SLE to share common patterns.
   case CmpInst::ICMP_SLE: {
-    const APInt *C;
+    // TODO: handle select/phi.
+
+    // Category 1: Match Pred LHS, (op LHS, ...)
+    if (PK == MATCH_RHS) {
+      // LHS s<= LHS +_{nsw} C   if C >= 0
+      // LHS s<  LHS +_{nsw} C   if C > 0
+      if (match(RHS, m_c_NSWAdd(m_Specific(LHS), m_APInt(C))))
+        return CanEq ? C->isNonNegative() : C->isStrictlyPositive();
+      // LHS s<= LHS -_{nsw} C   if C <= 0
+      // LHS s<  LHS -_{nsw} C   if C < 0
+      if (match(RHS, m_NSWSub(m_Specific(LHS), m_APInt(C))))
+        return CanEq ? C->isNonPositive() : C->isNegative();
+      // LHS s<= LHS <<_{nsw,nuw} V  for any V (V < 0 is UB)
+      // slt: cannot exclude LHS == 0
+      if (CanEq && match(RHS, m_NSWShl(m_Specific(LHS), m_Value(V))) &&
+          cast<OverflowingBinaryOperator>(RHS)->hasNoUnsignedWrap())
+        return true;
----------------
Camsyn wrote:

Assume no UB.

https://github.com/llvm/llvm-project/pull/187224


More information about the llvm-commits mailing list