[llvm] IR: introduce CmpInst::is{Eq,Ne}Equivalence (PR #111979)
via llvm-commits
llvm-commits at lists.llvm.org
Fri Oct 11 03:57:28 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-transforms
Author: Ramkumar Ramachandra (artagnon)
<details>
<summary>Changes</summary>
Steal impliesEquivalanceIf{True,False} (sic) from GVN, and extend it for floating-point constant vectors. Since InstCombine also performs GVN-like replacements, introduce CmpInst::is{Eq,Ne}Equivalence, and remove the corresponding code in GVN, with the intent of using it in more places.
---
Full diff: https://github.com/llvm/llvm-project/pull/111979.diff
3 Files Affected:
- (modified) llvm/include/llvm/IR/InstrTypes.h (+10)
- (modified) llvm/lib/IR/Instructions.cpp (+57)
- (modified) llvm/lib/Transforms/Scalar/GVN.cpp (+3-56)
``````````diff
diff --git a/llvm/include/llvm/IR/InstrTypes.h b/llvm/include/llvm/IR/InstrTypes.h
index 86d88da3d9460e..1d12dc06c365d8 100644
--- a/llvm/include/llvm/IR/InstrTypes.h
+++ b/llvm/include/llvm/IR/InstrTypes.h
@@ -912,6 +912,16 @@ class CmpInst : public Instruction {
/// Determine if this is an equals/not equals predicate.
bool isEquality() const { return isEquality(getPredicate()); }
+ /// Determine if this is an equals predicate that is also an equivalence. This
+ /// is useful in GVN-like transformations, where we can replace RHS by LHS in
+ /// the true branch of the CmpInst.
+ bool isEqEquivalence() const;
+
+ /// Determine if this is a not-equals predicate that is also an equivalence.
+ /// This is useful in GVN-like transformations, where we can replace RHS by
+ /// LHS in the false branch of the CmpInst.
+ bool isNeEquivalence() const;
+
/// Return true if the predicate is relational (not EQ or NE).
static bool isRelational(Predicate P) { return !isEquality(P); }
diff --git a/llvm/lib/IR/Instructions.cpp b/llvm/lib/IR/Instructions.cpp
index 009e0c03957c97..1c13f925e1388d 100644
--- a/llvm/lib/IR/Instructions.cpp
+++ b/llvm/lib/IR/Instructions.cpp
@@ -3471,6 +3471,63 @@ bool CmpInst::isEquality(Predicate P) {
llvm_unreachable("Unsupported predicate kind");
}
+// Returns true if either operand of CmpInst is a provably non-zero
+// floating-point constant. If both operands are constants, simply return the
+// equality of the constants.
+static bool hasNonZeroFPOperands(const CmpInst *Cmp) {
+ auto *LHS = dyn_cast<Constant>(Cmp->getOperand(0));
+ auto *RHS = dyn_cast<Constant>(Cmp->getOperand(1));
+ if (LHS && RHS)
+ return LHS == RHS;
+ if (auto *Const = LHS ? LHS : RHS) {
+ if (auto *ConstFP = dyn_cast<ConstantFP>(Const)) {
+ if (!ConstFP->isZero())
+ return true;
+ } else if (auto *ConstVec = dyn_cast<ConstantVector>(Const)) {
+ if (auto *SplatCFP =
+ dyn_cast_or_null<ConstantFP>(ConstVec->getSplatValue())) {
+ if (!SplatCFP->isZero())
+ return true;
+ }
+ }
+ }
+ return false;
+}
+
+// Floating-point equality is not an equivalence when comparing +0.0 with
+// -0.0 or when comparing NaN with another value.
+bool CmpInst::isEqEquivalence() const {
+ switch (getPredicate()) {
+ case CmpInst::Predicate::ICMP_EQ:
+ return true;
+ case CmpInst::Predicate::FCMP_UEQ:
+ if (!hasNoNaNs())
+ return false;
+ [[fallthrough]];
+ case CmpInst::Predicate::FCMP_OEQ:
+ return hasNonZeroFPOperands(this);
+ default:
+ return false;
+ }
+}
+
+// Floating-point equality is not an equivalence when comparing +0.0 with
+// -0.0 or when comparing NaN with another value.
+bool CmpInst::isNeEquivalence() const {
+ switch (getPredicate()) {
+ case CmpInst::Predicate::ICMP_NE:
+ return true;
+ case CmpInst::Predicate::FCMP_ONE:
+ if (!hasNoNaNs())
+ return false;
+ [[fallthrough]];
+ case CmpInst::Predicate::FCMP_UNE:
+ return hasNonZeroFPOperands(this);
+ default:
+ return false;
+ }
+}
+
CmpInst::Predicate CmpInst::getInversePredicate(Predicate pred) {
switch (pred) {
default: llvm_unreachable("Unknown cmp predicate!");
diff --git a/llvm/lib/Transforms/Scalar/GVN.cpp b/llvm/lib/Transforms/Scalar/GVN.cpp
index 2ba600497e00d3..cdd2a9dc06af64 100644
--- a/llvm/lib/Transforms/Scalar/GVN.cpp
+++ b/llvm/lib/Transforms/Scalar/GVN.cpp
@@ -1989,59 +1989,6 @@ bool GVNPass::processNonLocalLoad(LoadInst *Load) {
return Changed;
}
-static bool impliesEquivalanceIfTrue(CmpInst* Cmp) {
- if (Cmp->getPredicate() == CmpInst::Predicate::ICMP_EQ)
- return true;
-
- // Floating point comparisons can be equal, but not equivalent. Cases:
- // NaNs for unordered operators
- // +0.0 vs 0.0 for all operators
- if (Cmp->getPredicate() == CmpInst::Predicate::FCMP_OEQ ||
- (Cmp->getPredicate() == CmpInst::Predicate::FCMP_UEQ &&
- Cmp->getFastMathFlags().noNaNs())) {
- Value *LHS = Cmp->getOperand(0);
- Value *RHS = Cmp->getOperand(1);
- // If we can prove either side non-zero, then equality must imply
- // equivalence.
- // FIXME: We should do this optimization if 'no signed zeros' is
- // applicable via an instruction-level fast-math-flag or some other
- // indicator that relaxed FP semantics are being used.
- if (isa<ConstantFP>(LHS) && !cast<ConstantFP>(LHS)->isZero())
- return true;
- if (isa<ConstantFP>(RHS) && !cast<ConstantFP>(RHS)->isZero())
- return true;
- // TODO: Handle vector floating point constants
- }
- return false;
-}
-
-static bool impliesEquivalanceIfFalse(CmpInst* Cmp) {
- if (Cmp->getPredicate() == CmpInst::Predicate::ICMP_NE)
- return true;
-
- // Floating point comparisons can be equal, but not equivelent. Cases:
- // NaNs for unordered operators
- // +0.0 vs 0.0 for all operators
- if ((Cmp->getPredicate() == CmpInst::Predicate::FCMP_ONE &&
- Cmp->getFastMathFlags().noNaNs()) ||
- Cmp->getPredicate() == CmpInst::Predicate::FCMP_UNE) {
- Value *LHS = Cmp->getOperand(0);
- Value *RHS = Cmp->getOperand(1);
- // If we can prove either side non-zero, then equality must imply
- // equivalence.
- // FIXME: We should do this optimization if 'no signed zeros' is
- // applicable via an instruction-level fast-math-flag or some other
- // indicator that relaxed FP semantics are being used.
- if (isa<ConstantFP>(LHS) && !cast<ConstantFP>(LHS)->isZero())
- return true;
- if (isa<ConstantFP>(RHS) && !cast<ConstantFP>(RHS)->isZero())
- return true;
- // TODO: Handle vector floating point constants
- }
- return false;
-}
-
-
static bool hasUsersIn(Value *V, BasicBlock *BB) {
return llvm::any_of(V->users(), [BB](User *U) {
auto *I = dyn_cast<Instruction>(U);
@@ -2143,7 +2090,7 @@ bool GVNPass::processAssumeIntrinsic(AssumeInst *IntrinsicI) {
// call void @llvm.assume(i1 %cmp)
// ret float %load ; will change it to ret float %0
if (auto *CmpI = dyn_cast<CmpInst>(V)) {
- if (impliesEquivalanceIfTrue(CmpI)) {
+ if (CmpI->isEqEquivalence()) {
Value *CmpLHS = CmpI->getOperand(0);
Value *CmpRHS = CmpI->getOperand(1);
// Heuristically pick the better replacement -- the choice of heuristic
@@ -2567,8 +2514,8 @@ bool GVNPass::propagateEquality(Value *LHS, Value *RHS,
// If "A == B" is known true, or "A != B" is known false, then replace
// A with B everywhere in the scope. For floating point operations, we
// have to be careful since equality does not always imply equivalance.
- if ((isKnownTrue && impliesEquivalanceIfTrue(Cmp)) ||
- (isKnownFalse && impliesEquivalanceIfFalse(Cmp)))
+ if ((isKnownTrue && Cmp->isEqEquivalence()) ||
+ (isKnownFalse && Cmp->isNeEquivalence()))
Worklist.push_back(std::make_pair(Op0, Op1));
// If "A >= B" is known true, replace "A < B" with false everywhere.
``````````
</details>
https://github.com/llvm/llvm-project/pull/111979
More information about the llvm-commits mailing list