[llvm] [ValueTracking] Move `getFlippedStrictnessPredicateAndConstant` into ValueTracking. NFC. (PR #122064)

Yingwei Zheng via llvm-commits llvm-commits at lists.llvm.org
Tue Jan 7 23:23:44 PST 2025


https://github.com/dtcxzyw created https://github.com/llvm/llvm-project/pull/122064

Needed by https://github.com/llvm/llvm-project/pull/121958.


>From daeb94d7246474ac3e2364fa4741671f00a3ee2b Mon Sep 17 00:00:00 2001
From: Yingwei Zheng <dtcxzyw2333 at gmail.com>
Date: Wed, 8 Jan 2025 15:22:05 +0800
Subject: [PATCH] [ValueTracking] Move
 `getFlippedStrictnessPredicateAndConstant` into ValueTracking. NFC.

---
 llvm/include/llvm/Analysis/ValueTracking.h    |  7 ++
 .../Transforms/InstCombine/InstCombiner.h     |  6 --
 llvm/lib/Analysis/ValueTracking.cpp           | 74 ++++++++++++++++
 .../InstCombine/InstCombineCompares.cpp       | 84 +------------------
 .../InstCombine/InstCombineSelect.cpp         |  6 +-
 5 files changed, 87 insertions(+), 90 deletions(-)

diff --git a/llvm/include/llvm/Analysis/ValueTracking.h b/llvm/include/llvm/Analysis/ValueTracking.h
index 8aa024a72afc88..b4918c2d1e8a18 100644
--- a/llvm/include/llvm/Analysis/ValueTracking.h
+++ b/llvm/include/llvm/Analysis/ValueTracking.h
@@ -1102,6 +1102,13 @@ bool mustExecuteUBIfPoisonOnPathTo(Instruction *Root,
                                    Instruction *OnPathTo,
                                    DominatorTree *DT);
 
+/// Convert an integer comparison with a constant RHS into an equivalent
+/// form with the strictness flipped predicate. Return the new predicate and
+/// corresponding constant RHS if possible. Otherwise return std::nullopt.
+/// E.g., (icmp sgt X, 0) -> (icmp sge X, 1).
+std::optional<std::pair<CmpPredicate, Constant *>>
+getFlippedStrictnessPredicateAndConstant(CmpPredicate Pred, Constant *C);
+
 /// Specific patterns of select instructions we can match.
 enum SelectPatternFlavor {
   SPF_UNKNOWN = 0,
diff --git a/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h b/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h
index 71592058e34563..fa6b60cba15aaf 100644
--- a/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h
+++ b/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h
@@ -184,12 +184,6 @@ class LLVM_LIBRARY_VISIBILITY InstCombiner {
     return ConstantExpr::getSub(C, ConstantInt::get(C->getType(), 1));
   }
 
-  std::optional<std::pair<
-      CmpPredicate,
-      Constant *>> static getFlippedStrictnessPredicateAndConstant(CmpPredicate
-                                                                       Pred,
-                                                                   Constant *C);
-
   static bool shouldAvoidAbsorbingNotIntoSelect(const SelectInst &SI) {
     // a ? b : false and a ? true : b are the canonical form of logical and/or.
     // This includes !a ? b : false and !a ? true : b. Absorbing the not into
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index 2f6e869ae7b735..0eb43dd581acc6 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -8641,6 +8641,80 @@ SelectPatternResult llvm::getSelectPattern(CmpInst::Predicate Pred,
   }
 }
 
+std::optional<std::pair<CmpPredicate, Constant *>>
+llvm::getFlippedStrictnessPredicateAndConstant(CmpPredicate Pred, Constant *C) {
+  assert(ICmpInst::isRelational(Pred) && ICmpInst::isIntPredicate(Pred) &&
+         "Only for relational integer predicates.");
+  if (isa<UndefValue>(C))
+    return std::nullopt;
+
+  Type *Type = C->getType();
+  bool IsSigned = ICmpInst::isSigned(Pred);
+
+  CmpInst::Predicate UnsignedPred = ICmpInst::getUnsignedPredicate(Pred);
+  bool WillIncrement =
+      UnsignedPred == ICmpInst::ICMP_ULE || UnsignedPred == ICmpInst::ICMP_UGT;
+
+  // Check if the constant operand can be safely incremented/decremented
+  // without overflowing/underflowing.
+  auto ConstantIsOk = [WillIncrement, IsSigned](ConstantInt *C) {
+    return WillIncrement ? !C->isMaxValue(IsSigned) : !C->isMinValue(IsSigned);
+  };
+
+  Constant *SafeReplacementConstant = nullptr;
+  if (auto *CI = dyn_cast<ConstantInt>(C)) {
+    // Bail out if the constant can't be safely incremented/decremented.
+    if (!ConstantIsOk(CI))
+      return std::nullopt;
+  } else if (auto *FVTy = dyn_cast<FixedVectorType>(Type)) {
+    unsigned NumElts = FVTy->getNumElements();
+    for (unsigned i = 0; i != NumElts; ++i) {
+      Constant *Elt = C->getAggregateElement(i);
+      if (!Elt)
+        return std::nullopt;
+
+      if (isa<UndefValue>(Elt))
+        continue;
+
+      // Bail out if we can't determine if this constant is min/max or if we
+      // know that this constant is min/max.
+      auto *CI = dyn_cast<ConstantInt>(Elt);
+      if (!CI || !ConstantIsOk(CI))
+        return std::nullopt;
+
+      if (!SafeReplacementConstant)
+        SafeReplacementConstant = CI;
+    }
+  } else if (isa<VectorType>(C->getType())) {
+    // Non-fixed (scalable) vector: only a splat of a ConstantInt is handled.
+    Value *SplatC = C->getSplatValue();
+    auto *CI = dyn_cast_or_null<ConstantInt>(SplatC);
+    // Bail out if the constant can't be safely incremented/decremented.
+    if (!CI || !ConstantIsOk(CI))
+      return std::nullopt;
+  } else {
+    // Neither a ConstantInt nor of vector type (e.g. a ConstantExpr); give up.
+    return std::nullopt;
+  }
+
+  // It may not be safe to change a compare predicate in the presence of
+  // undefined elements, so replace those elements with the first safe constant
+  // that we found.
+  // TODO: in case of poison, it is safe; let's replace undefs only.
+  if (C->containsUndefOrPoisonElement()) {
+    assert(SafeReplacementConstant && "Replacement constant not set");
+    C = Constant::replaceUndefsWith(C, SafeReplacementConstant);
+  }
+
+  CmpInst::Predicate NewPred = CmpInst::getFlippedStrictnessPredicate(Pred);
+
+  // Adjust the constant by +1 (le/gt predicates) or -1 (lt/ge predicates).
+  Constant *OneOrNegOne = ConstantInt::get(Type, WillIncrement ? 1 : -1, true);
+  Constant *NewC = ConstantExpr::getAdd(C, OneOrNegOne);
+
+  return std::make_pair(NewPred, NewC);
+}
+
 static SelectPatternResult matchSelectPattern(CmpInst::Predicate Pred,
                                               FastMathFlags FMF,
                                               Value *CmpLHS, Value *CmpRHS,
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 8b23583c510637..c2d659035877ed 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -2485,9 +2485,8 @@ Instruction *InstCombinerImpl::foldICmpShlConstant(ICmpInst &Cmp,
       // icmp ule i64 (shl X, 32), 8589934592 ->
       // icmp ule i32 (trunc X, i32), 2 ->
       // icmp ult i32 (trunc X, i32), 3
-      if (auto FlippedStrictness =
-              InstCombiner::getFlippedStrictnessPredicateAndConstant(
-                  Pred, ConstantInt::get(ShType->getContext(), C))) {
+      if (auto FlippedStrictness = getFlippedStrictnessPredicateAndConstant(
+              Pred, ConstantInt::get(ShType->getContext(), C))) {
         CmpPred = FlippedStrictness->first;
         RHSC = cast<ConstantInt>(FlippedStrictness->second)->getValue();
       }
@@ -3280,8 +3279,7 @@ bool InstCombinerImpl::matchThreeWayIntCompare(SelectInst *SI, Value *&LHS,
   if (PredB == ICmpInst::ICMP_SGT && isa<Constant>(RHS2)) {
     // x sgt C-1  <-->  x sge C  <-->  not(x slt C)
     auto FlippedStrictness =
-        InstCombiner::getFlippedStrictnessPredicateAndConstant(
-            PredB, cast<Constant>(RHS2));
+        getFlippedStrictnessPredicateAndConstant(PredB, cast<Constant>(RHS2));
     if (!FlippedStrictness)
       return false;
     assert(FlippedStrictness->first == ICmpInst::ICMP_SGE &&
@@ -6908,79 +6906,6 @@ Instruction *InstCombinerImpl::foldICmpUsingBoolRange(ICmpInst &I) {
   return nullptr;
 }
 
-std::optional<std::pair<CmpPredicate, Constant *>>
-InstCombiner::getFlippedStrictnessPredicateAndConstant(CmpPredicate Pred,
-                                                       Constant *C) {
-  assert(ICmpInst::isRelational(Pred) && ICmpInst::isIntPredicate(Pred) &&
-         "Only for relational integer predicates.");
-
-  Type *Type = C->getType();
-  bool IsSigned = ICmpInst::isSigned(Pred);
-
-  CmpInst::Predicate UnsignedPred = ICmpInst::getUnsignedPredicate(Pred);
-  bool WillIncrement =
-      UnsignedPred == ICmpInst::ICMP_ULE || UnsignedPred == ICmpInst::ICMP_UGT;
-
-  // Check if the constant operand can be safely incremented/decremented
-  // without overflowing/underflowing.
-  auto ConstantIsOk = [WillIncrement, IsSigned](ConstantInt *C) {
-    return WillIncrement ? !C->isMaxValue(IsSigned) : !C->isMinValue(IsSigned);
-  };
-
-  Constant *SafeReplacementConstant = nullptr;
-  if (auto *CI = dyn_cast<ConstantInt>(C)) {
-    // Bail out if the constant can't be safely incremented/decremented.
-    if (!ConstantIsOk(CI))
-      return std::nullopt;
-  } else if (auto *FVTy = dyn_cast<FixedVectorType>(Type)) {
-    unsigned NumElts = FVTy->getNumElements();
-    for (unsigned i = 0; i != NumElts; ++i) {
-      Constant *Elt = C->getAggregateElement(i);
-      if (!Elt)
-        return std::nullopt;
-
-      if (isa<UndefValue>(Elt))
-        continue;
-
-      // Bail out if we can't determine if this constant is min/max or if we
-      // know that this constant is min/max.
-      auto *CI = dyn_cast<ConstantInt>(Elt);
-      if (!CI || !ConstantIsOk(CI))
-        return std::nullopt;
-
-      if (!SafeReplacementConstant)
-        SafeReplacementConstant = CI;
-    }
-  } else if (isa<VectorType>(C->getType())) {
-    // Handle scalable splat
-    Value *SplatC = C->getSplatValue();
-    auto *CI = dyn_cast_or_null<ConstantInt>(SplatC);
-    // Bail out if the constant can't be safely incremented/decremented.
-    if (!CI || !ConstantIsOk(CI))
-      return std::nullopt;
-  } else {
-    // ConstantExpr?
-    return std::nullopt;
-  }
-
-  // It may not be safe to change a compare predicate in the presence of
-  // undefined elements, so replace those elements with the first safe constant
-  // that we found.
-  // TODO: in case of poison, it is safe; let's replace undefs only.
-  if (C->containsUndefOrPoisonElement()) {
-    assert(SafeReplacementConstant && "Replacement constant not set");
-    C = Constant::replaceUndefsWith(C, SafeReplacementConstant);
-  }
-
-  CmpInst::Predicate NewPred = CmpInst::getFlippedStrictnessPredicate(Pred);
-
-  // Increment or decrement the constant.
-  Constant *OneOrNegOne = ConstantInt::get(Type, WillIncrement ? 1 : -1, true);
-  Constant *NewC = ConstantExpr::getAdd(C, OneOrNegOne);
-
-  return std::make_pair(NewPred, NewC);
-}
-
 /// If we have an icmp le or icmp ge instruction with a constant operand, turn
 /// it into the appropriate icmp lt or icmp gt instruction. This transform
 /// allows them to be folded in visitICmpInst.
@@ -6996,8 +6921,7 @@ static ICmpInst *canonicalizeCmpWithConstant(ICmpInst &I) {
   if (!Op1C)
     return nullptr;
 
-  auto FlippedStrictness =
-      InstCombiner::getFlippedStrictnessPredicateAndConstant(Pred, Op1C);
+  auto FlippedStrictness = getFlippedStrictnessPredicateAndConstant(Pred, Op1C);
   if (!FlippedStrictness)
     return nullptr;
 
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 7fd91c72a2fb0e..eca518aa640700 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -1689,8 +1689,7 @@ tryToReuseConstantFromSelectInComparison(SelectInst &Sel, ICmpInst &Cmp,
     return nullptr;
 
   // Check the constant we'd have with flipped-strictness predicate.
-  auto FlippedStrictness =
-      InstCombiner::getFlippedStrictnessPredicateAndConstant(Pred, C0);
+  auto FlippedStrictness = getFlippedStrictnessPredicateAndConstant(Pred, C0);
   if (!FlippedStrictness)
     return nullptr;
 
@@ -1970,8 +1969,7 @@ static Value *foldSelectWithConstOpToBinOp(ICmpInst *Cmp, Value *TrueVal,
   Value *RHS;
   SelectPatternFlavor SPF;
   const DataLayout &DL = BOp->getDataLayout();
-  auto Flipped =
-      InstCombiner::getFlippedStrictnessPredicateAndConstant(Predicate, C1);
+  auto Flipped = getFlippedStrictnessPredicateAndConstant(Predicate, C1);
 
   if (C3 == ConstantFoldBinaryOpOperands(Opcode, C1, C2, DL)) {
     SPF = getSelectPattern(Predicate).Flavor;



More information about the llvm-commits mailing list