[llvm] [ValueTracking] Move `getFlippedStrictnessPredicateAndConstant` into ValueTracking. NFC. (PR #122064)
via llvm-commits
llvm-commits at lists.llvm.org
Tue Jan 7 23:24:15 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-analysis
Author: Yingwei Zheng (dtcxzyw)
<details>
<summary>Changes</summary>
Needed by https://github.com/llvm/llvm-project/pull/121958.
---
Full diff: https://github.com/llvm/llvm-project/pull/122064.diff
5 Files Affected:
- (modified) llvm/include/llvm/Analysis/ValueTracking.h (+7)
- (modified) llvm/include/llvm/Transforms/InstCombine/InstCombiner.h (-6)
- (modified) llvm/lib/Analysis/ValueTracking.cpp (+74)
- (modified) llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp (+4-80)
- (modified) llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp (+2-4)
``````````diff
diff --git a/llvm/include/llvm/Analysis/ValueTracking.h b/llvm/include/llvm/Analysis/ValueTracking.h
index 8aa024a72afc88..b4918c2d1e8a18 100644
--- a/llvm/include/llvm/Analysis/ValueTracking.h
+++ b/llvm/include/llvm/Analysis/ValueTracking.h
@@ -1102,6 +1102,13 @@ bool mustExecuteUBIfPoisonOnPathTo(Instruction *Root,
Instruction *OnPathTo,
DominatorTree *DT);
+/// Convert an integer comparison with a constant RHS into an equivalent
+/// form with a strictness-flipped predicate. Return the new predicate and
+/// corresponding constant RHS if possible. Otherwise return std::nullopt.
+/// E.g., (icmp sgt X, 0) -> (icmp sge X, 1).
+std::optional<std::pair<CmpPredicate, Constant *>>
+getFlippedStrictnessPredicateAndConstant(CmpPredicate Pred, Constant *C);
+
/// Specific patterns of select instructions we can match.
enum SelectPatternFlavor {
SPF_UNKNOWN = 0,
diff --git a/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h b/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h
index 71592058e34563..fa6b60cba15aaf 100644
--- a/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h
+++ b/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h
@@ -184,12 +184,6 @@ class LLVM_LIBRARY_VISIBILITY InstCombiner {
return ConstantExpr::getSub(C, ConstantInt::get(C->getType(), 1));
}
- std::optional<std::pair<
- CmpPredicate,
- Constant *>> static getFlippedStrictnessPredicateAndConstant(CmpPredicate
- Pred,
- Constant *C);
-
static bool shouldAvoidAbsorbingNotIntoSelect(const SelectInst &SI) {
// a ? b : false and a ? true : b are the canonical form of logical and/or.
// This includes !a ? b : false and !a ? true : b. Absorbing the not into
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index 2f6e869ae7b735..0eb43dd581acc6 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -8641,6 +8641,80 @@ SelectPatternResult llvm::getSelectPattern(CmpInst::Predicate Pred,
}
}
+std::optional<std::pair<CmpPredicate, Constant *>>
+llvm::getFlippedStrictnessPredicateAndConstant(CmpPredicate Pred, Constant *C) {
+ assert(ICmpInst::isRelational(Pred) && ICmpInst::isIntPredicate(Pred) &&
+ "Only for relational integer predicates.");
+ if (isa<UndefValue>(C))
+ return std::nullopt;
+
+ Type *Type = C->getType();
+ bool IsSigned = ICmpInst::isSigned(Pred);
+
+ CmpInst::Predicate UnsignedPred = ICmpInst::getUnsignedPredicate(Pred);
+ bool WillIncrement =
+ UnsignedPred == ICmpInst::ICMP_ULE || UnsignedPred == ICmpInst::ICMP_UGT;
+
+ // Check if the constant operand can be safely incremented/decremented
+ // without overflowing/underflowing.
+ auto ConstantIsOk = [WillIncrement, IsSigned](ConstantInt *C) {
+ return WillIncrement ? !C->isMaxValue(IsSigned) : !C->isMinValue(IsSigned);
+ };
+
+ Constant *SafeReplacementConstant = nullptr;
+ if (auto *CI = dyn_cast<ConstantInt>(C)) {
+ // Bail out if the constant can't be safely incremented/decremented.
+ if (!ConstantIsOk(CI))
+ return std::nullopt;
+ } else if (auto *FVTy = dyn_cast<FixedVectorType>(Type)) {
+ unsigned NumElts = FVTy->getNumElements();
+ for (unsigned i = 0; i != NumElts; ++i) {
+ Constant *Elt = C->getAggregateElement(i);
+ if (!Elt)
+ return std::nullopt;
+
+ if (isa<UndefValue>(Elt))
+ continue;
+
+ // Bail out if we can't determine if this constant is min/max or if we
+ // know that this constant is min/max.
+ auto *CI = dyn_cast<ConstantInt>(Elt);
+ if (!CI || !ConstantIsOk(CI))
+ return std::nullopt;
+
+ if (!SafeReplacementConstant)
+ SafeReplacementConstant = CI;
+ }
+ } else if (isa<VectorType>(C->getType())) {
+ // Handle scalable splat
+ Value *SplatC = C->getSplatValue();
+ auto *CI = dyn_cast_or_null<ConstantInt>(SplatC);
+ // Bail out if the constant can't be safely incremented/decremented.
+ if (!CI || !ConstantIsOk(CI))
+ return std::nullopt;
+ } else {
+ // ConstantExpr?
+ return std::nullopt;
+ }
+
+ // It may not be safe to change a compare predicate in the presence of
+ // undefined elements, so replace those elements with the first safe constant
+ // that we found.
+ // TODO: in case of poison, it is safe; let's replace undefs only.
+ if (C->containsUndefOrPoisonElement()) {
+ assert(SafeReplacementConstant && "Replacement constant not set");
+ C = Constant::replaceUndefsWith(C, SafeReplacementConstant);
+ }
+
+ CmpInst::Predicate NewPred = CmpInst::getFlippedStrictnessPredicate(Pred);
+
+ // Increment or decrement the constant.
+ Constant *OneOrNegOne = ConstantInt::get(Type, WillIncrement ? 1 : -1, true);
+ Constant *NewC = ConstantExpr::getAdd(C, OneOrNegOne);
+
+ return std::make_pair(NewPred, NewC);
+}
+
static SelectPatternResult matchSelectPattern(CmpInst::Predicate Pred,
FastMathFlags FMF,
Value *CmpLHS, Value *CmpRHS,
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 8b23583c510637..c2d659035877ed 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -2485,9 +2485,8 @@ Instruction *InstCombinerImpl::foldICmpShlConstant(ICmpInst &Cmp,
// icmp ule i64 (shl X, 32), 8589934592 ->
// icmp ule i32 (trunc X, i32), 2 ->
// icmp ult i32 (trunc X, i32), 3
- if (auto FlippedStrictness =
- InstCombiner::getFlippedStrictnessPredicateAndConstant(
- Pred, ConstantInt::get(ShType->getContext(), C))) {
+ if (auto FlippedStrictness = getFlippedStrictnessPredicateAndConstant(
+ Pred, ConstantInt::get(ShType->getContext(), C))) {
CmpPred = FlippedStrictness->first;
RHSC = cast<ConstantInt>(FlippedStrictness->second)->getValue();
}
@@ -3280,8 +3279,7 @@ bool InstCombinerImpl::matchThreeWayIntCompare(SelectInst *SI, Value *&LHS,
if (PredB == ICmpInst::ICMP_SGT && isa<Constant>(RHS2)) {
// x sgt C-1 <--> x sge C <--> not(x slt C)
auto FlippedStrictness =
- InstCombiner::getFlippedStrictnessPredicateAndConstant(
- PredB, cast<Constant>(RHS2));
+ getFlippedStrictnessPredicateAndConstant(PredB, cast<Constant>(RHS2));
if (!FlippedStrictness)
return false;
assert(FlippedStrictness->first == ICmpInst::ICMP_SGE &&
@@ -6908,79 +6906,6 @@ Instruction *InstCombinerImpl::foldICmpUsingBoolRange(ICmpInst &I) {
return nullptr;
}
-std::optional<std::pair<CmpPredicate, Constant *>>
-InstCombiner::getFlippedStrictnessPredicateAndConstant(CmpPredicate Pred,
- Constant *C) {
- assert(ICmpInst::isRelational(Pred) && ICmpInst::isIntPredicate(Pred) &&
- "Only for relational integer predicates.");
-
- Type *Type = C->getType();
- bool IsSigned = ICmpInst::isSigned(Pred);
-
- CmpInst::Predicate UnsignedPred = ICmpInst::getUnsignedPredicate(Pred);
- bool WillIncrement =
- UnsignedPred == ICmpInst::ICMP_ULE || UnsignedPred == ICmpInst::ICMP_UGT;
-
- // Check if the constant operand can be safely incremented/decremented
- // without overflowing/underflowing.
- auto ConstantIsOk = [WillIncrement, IsSigned](ConstantInt *C) {
- return WillIncrement ? !C->isMaxValue(IsSigned) : !C->isMinValue(IsSigned);
- };
-
- Constant *SafeReplacementConstant = nullptr;
- if (auto *CI = dyn_cast<ConstantInt>(C)) {
- // Bail out if the constant can't be safely incremented/decremented.
- if (!ConstantIsOk(CI))
- return std::nullopt;
- } else if (auto *FVTy = dyn_cast<FixedVectorType>(Type)) {
- unsigned NumElts = FVTy->getNumElements();
- for (unsigned i = 0; i != NumElts; ++i) {
- Constant *Elt = C->getAggregateElement(i);
- if (!Elt)
- return std::nullopt;
-
- if (isa<UndefValue>(Elt))
- continue;
-
- // Bail out if we can't determine if this constant is min/max or if we
- // know that this constant is min/max.
- auto *CI = dyn_cast<ConstantInt>(Elt);
- if (!CI || !ConstantIsOk(CI))
- return std::nullopt;
-
- if (!SafeReplacementConstant)
- SafeReplacementConstant = CI;
- }
- } else if (isa<VectorType>(C->getType())) {
- // Handle scalable splat
- Value *SplatC = C->getSplatValue();
- auto *CI = dyn_cast_or_null<ConstantInt>(SplatC);
- // Bail out if the constant can't be safely incremented/decremented.
- if (!CI || !ConstantIsOk(CI))
- return std::nullopt;
- } else {
- // ConstantExpr?
- return std::nullopt;
- }
-
- // It may not be safe to change a compare predicate in the presence of
- // undefined elements, so replace those elements with the first safe constant
- // that we found.
- // TODO: in case of poison, it is safe; let's replace undefs only.
- if (C->containsUndefOrPoisonElement()) {
- assert(SafeReplacementConstant && "Replacement constant not set");
- C = Constant::replaceUndefsWith(C, SafeReplacementConstant);
- }
-
- CmpInst::Predicate NewPred = CmpInst::getFlippedStrictnessPredicate(Pred);
-
- // Increment or decrement the constant.
- Constant *OneOrNegOne = ConstantInt::get(Type, WillIncrement ? 1 : -1, true);
- Constant *NewC = ConstantExpr::getAdd(C, OneOrNegOne);
-
- return std::make_pair(NewPred, NewC);
-}
-
/// If we have an icmp le or icmp ge instruction with a constant operand, turn
/// it into the appropriate icmp lt or icmp gt instruction. This transform
/// allows them to be folded in visitICmpInst.
@@ -6996,8 +6921,7 @@ static ICmpInst *canonicalizeCmpWithConstant(ICmpInst &I) {
if (!Op1C)
return nullptr;
- auto FlippedStrictness =
- InstCombiner::getFlippedStrictnessPredicateAndConstant(Pred, Op1C);
+ auto FlippedStrictness = getFlippedStrictnessPredicateAndConstant(Pred, Op1C);
if (!FlippedStrictness)
return nullptr;
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 7fd91c72a2fb0e..eca518aa640700 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -1689,8 +1689,7 @@ tryToReuseConstantFromSelectInComparison(SelectInst &Sel, ICmpInst &Cmp,
return nullptr;
// Check the constant we'd have with flipped-strictness predicate.
- auto FlippedStrictness =
- InstCombiner::getFlippedStrictnessPredicateAndConstant(Pred, C0);
+ auto FlippedStrictness = getFlippedStrictnessPredicateAndConstant(Pred, C0);
if (!FlippedStrictness)
return nullptr;
@@ -1970,8 +1969,7 @@ static Value *foldSelectWithConstOpToBinOp(ICmpInst *Cmp, Value *TrueVal,
Value *RHS;
SelectPatternFlavor SPF;
const DataLayout &DL = BOp->getDataLayout();
- auto Flipped =
- InstCombiner::getFlippedStrictnessPredicateAndConstant(Predicate, C1);
+ auto Flipped = getFlippedStrictnessPredicateAndConstant(Predicate, C1);
if (C3 == ConstantFoldBinaryOpOperands(Opcode, C1, C2, DL)) {
SPF = getSelectPattern(Predicate).Flavor;
``````````
</details>
https://github.com/llvm/llvm-project/pull/122064
More information about the llvm-commits
mailing list