[llvm] IR: introduce CmpInst::is{Eq,Ne}Equivalence (PR #111979)

Ramkumar Ramachandra via llvm-commits llvm-commits at lists.llvm.org
Fri Oct 11 03:56:57 PDT 2024


https://github.com/artagnon created https://github.com/llvm/llvm-project/pull/111979

Move impliesEquivalanceIf{True,False} (sic) out of GVN and extend it to handle floating-point constant vectors. Since InstCombine also performs GVN-like replacements, introduce CmpInst::is{Eq,Ne}Equivalence and remove the corresponding code from GVN, with the intent of using the new API in more places.

>From 9adc5d1e995183f129fa7f6b154704312231ee0d Mon Sep 17 00:00:00 2001
From: Ramkumar Ramachandra <ramkumar.ramachandra at codasip.com>
Date: Fri, 11 Oct 2024 11:38:06 +0100
Subject: [PATCH] IR: introduce CmpInst::is{Eq,Ne}Equivalence

Move impliesEquivalanceIf{True,False} (sic) out of GVN and extend it to
handle floating-point constant vectors. Since InstCombine also performs
GVN-like replacements, introduce CmpInst::is{Eq,Ne}Equivalence and
remove the corresponding code from GVN, with the intent of using the
new API in more places.
---
 llvm/include/llvm/IR/InstrTypes.h  | 10 +++++
 llvm/lib/IR/Instructions.cpp       | 57 +++++++++++++++++++++++++++++
 llvm/lib/Transforms/Scalar/GVN.cpp | 59 ++----------------------------
 3 files changed, 70 insertions(+), 56 deletions(-)

diff --git a/llvm/include/llvm/IR/InstrTypes.h b/llvm/include/llvm/IR/InstrTypes.h
index 86d88da3d9460e..1d12dc06c365d8 100644
--- a/llvm/include/llvm/IR/InstrTypes.h
+++ b/llvm/include/llvm/IR/InstrTypes.h
@@ -912,6 +912,16 @@ class CmpInst : public Instruction {
   /// Determine if this is an equals/not equals predicate.
   bool isEquality() const { return isEquality(getPredicate()); }
 
+  /// Return true if this is an equality predicate such that, when the compare
+  /// is known to be true, either operand may be replaced by the other, as in
+  /// GVN-like transforms. Pointer provenance is not taken into account.
+  bool isEqEquivalence() const;
+
+  /// Return true if this is an inequality predicate such that, when the
+  /// compare is known to be false, either operand may be replaced by the
+  /// other, as in GVN-like transforms. Pointer provenance is not considered.
+  bool isNeEquivalence() const;
+
   /// Return true if the predicate is relational (not EQ or NE).
   static bool isRelational(Predicate P) { return !isEquality(P); }
 
diff --git a/llvm/lib/IR/Instructions.cpp b/llvm/lib/IR/Instructions.cpp
index 009e0c03957c97..1c13f925e1388d 100644
--- a/llvm/lib/IR/Instructions.cpp
+++ b/llvm/lib/IR/Instructions.cpp
@@ -3471,6 +3471,88 @@ bool CmpInst::isEquality(Predicate P) {
   llvm_unreachable("Unsupported predicate kind");
 }
 
+// Returns true if C is a floating-point constant that is provably non-zero:
+// either a scalar, or a vector all of whose elements are non-zero scalars.
+// Undef/poison and non-floating-point elements are treated as possibly zero.
+static bool isNonZeroFPConstant(const Constant *C) {
+  if (const auto *CFP = dyn_cast<ConstantFP>(C))
+    return !CFP->isZero();
+  if (!C->getType()->isVectorTy())
+    return false;
+  // Constant::getSplatValue recognizes splats in any constant vector
+  // representation (e.g. ConstantDataVector, the usual form of FP splats),
+  // whereas ConstantVector::getSplatValue only handles ConstantVector.
+  if (const auto *SplatCFP = dyn_cast_or_null<ConstantFP>(C->getSplatValue()))
+    return !SplatCFP->isZero();
+  // A non-splat fixed-width vector is non-zero only if every lane is.
+  const auto *VTy = dyn_cast<FixedVectorType>(C->getType());
+  if (!VTy)
+    return false;
+  for (unsigned I = 0, E = VTy->getNumElements(); I != E; ++I) {
+    const auto *EltCFP =
+        dyn_cast_or_null<ConstantFP>(C->getAggregateElement(I));
+    if (!EltCFP || EltCFP->isZero())
+      return false;
+  }
+  return true;
+}
+
+// Returns true if equality of the operands of the floating-point compare Cmp
+// implies that they are interchangeable, i.e. that they cannot be +0.0 and
+// -0.0. This holds if either operand is a provably non-zero constant. If both
+// operands are constants, they are interchangeable only if they are the same
+// constant.
+// FIXME: We should also return true if 'no signed zeros' is applicable via an
+// instruction-level fast-math flag or some other indicator that relaxed FP
+// semantics are being used.
+static bool hasNonZeroFPOperands(const CmpInst *Cmp) {
+  auto *LHS = dyn_cast<Constant>(Cmp->getOperand(0));
+  auto *RHS = dyn_cast<Constant>(Cmp->getOperand(1));
+  if (LHS && RHS)
+    return LHS == RHS;
+  if (const Constant *C = LHS ? LHS : RHS)
+    return isNonZeroFPConstant(C);
+  return false;
+}
+
+// A true 'icmp eq' means its operands are identical (pointer provenance is
+// not considered here). A true 'fcmp oeq'/'fcmp ueq' does not guarantee this,
+// since +0.0 compares equal to -0.0, and 'fcmp ueq' is also true when either
+// operand is NaN, so it additionally requires the 'nnan' fast-math flag.
+bool CmpInst::isEqEquivalence() const {
+  switch (getPredicate()) {
+  case CmpInst::Predicate::ICMP_EQ:
+    return true;
+  case CmpInst::Predicate::FCMP_UEQ:
+    if (!hasNoNaNs())
+      return false;
+    [[fallthrough]];
+  case CmpInst::Predicate::FCMP_OEQ:
+    return hasNonZeroFPOperands(this);
+  default:
+    return false;
+  }
+}
+
+// A false 'icmp ne' means its operands are identical (pointer provenance is
+// not considered here). A false 'fcmp une'/'fcmp one' does not guarantee this,
+// since +0.0 compares equal to -0.0, and 'fcmp one' is also false when either
+// operand is NaN, so it additionally requires the 'nnan' fast-math flag.
+bool CmpInst::isNeEquivalence() const {
+  switch (getPredicate()) {
+  case CmpInst::Predicate::ICMP_NE:
+    return true;
+  case CmpInst::Predicate::FCMP_ONE:
+    if (!hasNoNaNs())
+      return false;
+    [[fallthrough]];
+  case CmpInst::Predicate::FCMP_UNE:
+    return hasNonZeroFPOperands(this);
+  default:
+    return false;
+  }
+}
+
 CmpInst::Predicate CmpInst::getInversePredicate(Predicate pred) {
   switch (pred) {
     default: llvm_unreachable("Unknown cmp predicate!");
diff --git a/llvm/lib/Transforms/Scalar/GVN.cpp b/llvm/lib/Transforms/Scalar/GVN.cpp
index 2ba600497e00d3..cdd2a9dc06af64 100644
--- a/llvm/lib/Transforms/Scalar/GVN.cpp
+++ b/llvm/lib/Transforms/Scalar/GVN.cpp
@@ -1989,59 +1989,6 @@ bool GVNPass::processNonLocalLoad(LoadInst *Load) {
   return Changed;
 }
 
-static bool impliesEquivalanceIfTrue(CmpInst* Cmp) {
-  if (Cmp->getPredicate() == CmpInst::Predicate::ICMP_EQ)
-    return true;
-
-  // Floating point comparisons can be equal, but not equivalent.  Cases:
-  // NaNs for unordered operators
-  // +0.0 vs 0.0 for all operators
-  if (Cmp->getPredicate() == CmpInst::Predicate::FCMP_OEQ ||
-      (Cmp->getPredicate() == CmpInst::Predicate::FCMP_UEQ &&
-       Cmp->getFastMathFlags().noNaNs())) {
-      Value *LHS = Cmp->getOperand(0);
-      Value *RHS = Cmp->getOperand(1);
-      // If we can prove either side non-zero, then equality must imply
-      // equivalence.
-      // FIXME: We should do this optimization if 'no signed zeros' is
-      // applicable via an instruction-level fast-math-flag or some other
-      // indicator that relaxed FP semantics are being used.
-      if (isa<ConstantFP>(LHS) && !cast<ConstantFP>(LHS)->isZero())
-        return true;
-      if (isa<ConstantFP>(RHS) && !cast<ConstantFP>(RHS)->isZero())
-        return true;
-      // TODO: Handle vector floating point constants
-  }
-  return false;
-}
-
-static bool impliesEquivalanceIfFalse(CmpInst* Cmp) {
-  if (Cmp->getPredicate() == CmpInst::Predicate::ICMP_NE)
-    return true;
-
-  // Floating point comparisons can be equal, but not equivelent.  Cases:
-  // NaNs for unordered operators
-  // +0.0 vs 0.0 for all operators
-  if ((Cmp->getPredicate() == CmpInst::Predicate::FCMP_ONE &&
-       Cmp->getFastMathFlags().noNaNs()) ||
-      Cmp->getPredicate() == CmpInst::Predicate::FCMP_UNE) {
-      Value *LHS = Cmp->getOperand(0);
-      Value *RHS = Cmp->getOperand(1);
-      // If we can prove either side non-zero, then equality must imply
-      // equivalence.
-      // FIXME: We should do this optimization if 'no signed zeros' is
-      // applicable via an instruction-level fast-math-flag or some other
-      // indicator that relaxed FP semantics are being used.
-      if (isa<ConstantFP>(LHS) && !cast<ConstantFP>(LHS)->isZero())
-        return true;
-      if (isa<ConstantFP>(RHS) && !cast<ConstantFP>(RHS)->isZero())
-        return true;
-      // TODO: Handle vector floating point constants
-  }
-  return false;
-}
-
-
 static bool hasUsersIn(Value *V, BasicBlock *BB) {
   return llvm::any_of(V->users(), [BB](User *U) {
     auto *I = dyn_cast<Instruction>(U);
@@ -2143,7 +2090,7 @@ bool GVNPass::processAssumeIntrinsic(AssumeInst *IntrinsicI) {
   // call void @llvm.assume(i1 %cmp)
   // ret float %load ; will change it to ret float %0
   if (auto *CmpI = dyn_cast<CmpInst>(V)) {
-    if (impliesEquivalanceIfTrue(CmpI)) {
+    if (CmpI->isEqEquivalence()) {
       Value *CmpLHS = CmpI->getOperand(0);
       Value *CmpRHS = CmpI->getOperand(1);
       // Heuristically pick the better replacement -- the choice of heuristic
@@ -2567,8 +2514,8 @@ bool GVNPass::propagateEquality(Value *LHS, Value *RHS,
       // If "A == B" is known true, or "A != B" is known false, then replace
       // A with B everywhere in the scope.  For floating point operations, we
       // have to be careful since equality does not always imply equivalance.
-      if ((isKnownTrue && impliesEquivalanceIfTrue(Cmp)) ||
-          (isKnownFalse && impliesEquivalanceIfFalse(Cmp)))
+      if ((isKnownTrue && Cmp->isEqEquivalence()) ||
+          (isKnownFalse && Cmp->isNeEquivalence()))
         Worklist.push_back(std::make_pair(Op0, Op1));
 
       // If "A >= B" is known true, replace "A < B" with false everywhere.



More information about the llvm-commits mailing list