[llvm] IR: introduce CmpInst::is{Eq,Ne}Equivalence (PR #111979)

Ramkumar Ramachandra via llvm-commits llvm-commits at lists.llvm.org
Fri Oct 11 05:10:20 PDT 2024


https://github.com/artagnon updated https://github.com/llvm/llvm-project/pull/111979

>From 9adc5d1e995183f129fa7f6b154704312231ee0d Mon Sep 17 00:00:00 2001
From: Ramkumar Ramachandra <ramkumar.ramachandra at codasip.com>
Date: Fri, 11 Oct 2024 11:38:06 +0100
Subject: [PATCH 1/2] IR: introduce CmpInst::is{Eq,Ne}Equivalence

Steal impliesEquivalanceIf{True,False} (sic) from GVN, and extend it for
floating-point constant vectors. Since InstCombine also performs
GVN-like replacements, introduce CmpInst::is{Eq,Ne}Equivalence, and
remove the corresponding code in GVN, with the intent of using it in
more places.
---
 llvm/include/llvm/IR/InstrTypes.h  | 10 +++++
 llvm/lib/IR/Instructions.cpp       | 57 +++++++++++++++++++++++++++++
 llvm/lib/Transforms/Scalar/GVN.cpp | 59 ++----------------------------
 3 files changed, 70 insertions(+), 56 deletions(-)

diff --git a/llvm/include/llvm/IR/InstrTypes.h b/llvm/include/llvm/IR/InstrTypes.h
index 86d88da3d9460e..1d12dc06c365d8 100644
--- a/llvm/include/llvm/IR/InstrTypes.h
+++ b/llvm/include/llvm/IR/InstrTypes.h
@@ -912,6 +912,16 @@ class CmpInst : public Instruction {
   /// Determine if this is an equals/not equals predicate.
   bool isEquality() const { return isEquality(getPredicate()); }
 
+  /// Determine if this is an equals predicate that is also an equivalence,
+  /// i.e. whether the operands are interchangeable wherever the comparison is
+  /// known true (e.g. for GVN-like operand replacement).
+  bool isEqEquivalence() const;
+
+  /// Determine if this is a not-equals predicate whose negation is an
+  /// equivalence, i.e. whether the operands are interchangeable wherever the
+  /// comparison is known false (e.g. for GVN-like operand replacement).
+  bool isNeEquivalence() const;
+
   /// Return true if the predicate is relational (not EQ or NE).
   static bool isRelational(Predicate P) { return !isEquality(P); }
 
diff --git a/llvm/lib/IR/Instructions.cpp b/llvm/lib/IR/Instructions.cpp
index 009e0c03957c97..1c13f925e1388d 100644
--- a/llvm/lib/IR/Instructions.cpp
+++ b/llvm/lib/IR/Instructions.cpp
@@ -3471,6 +3471,63 @@ bool CmpInst::isEquality(Predicate P) {
   llvm_unreachable("Unsupported predicate kind");
 }
 
+// Returns true if either operand of CmpInst is a provably non-zero
+// floating-point constant. If both operands are constants, simply return the
+// equality of the constants.
+static bool hasNonZeroFPOperands(const CmpInst *Cmp) {
+  auto *LHS = dyn_cast<Constant>(Cmp->getOperand(0));
+  auto *RHS = dyn_cast<Constant>(Cmp->getOperand(1));
+  if (LHS && RHS)
+    return LHS == RHS;
+  if (auto *Const = LHS ? LHS : RHS) {
+    if (auto *ConstFP = dyn_cast<ConstantFP>(Const)) {
+      if (!ConstFP->isZero())
+        return true;
+    } else if (auto *ConstVec = dyn_cast<ConstantVector>(Const)) {
+      if (auto *SplatCFP =
+              dyn_cast_or_null<ConstantFP>(ConstVec->getSplatValue())) {
+        if (!SplatCFP->isZero())
+          return true;
+      }
+    }
+  }
+  return false;
+}
+
+// Floating-point equality is not an equivalence when comparing +0.0 with
+// -0.0 or when comparing NaN with another value.
+bool CmpInst::isEqEquivalence() const {
+  switch (getPredicate()) {
+  case CmpInst::Predicate::ICMP_EQ:
+    return true;
+  case CmpInst::Predicate::FCMP_UEQ:
+    if (!hasNoNaNs())
+      return false;
+    [[fallthrough]];
+  case CmpInst::Predicate::FCMP_OEQ:
+    return hasNonZeroFPOperands(this);
+  default:
+    return false;
+  }
+}
+
+// Floating-point equality is not an equivalence when comparing +0.0 with
+// -0.0 or when comparing NaN with another value.
+bool CmpInst::isNeEquivalence() const {
+  switch (getPredicate()) {
+  case CmpInst::Predicate::ICMP_NE:
+    return true;
+  case CmpInst::Predicate::FCMP_ONE:
+    if (!hasNoNaNs())
+      return false;
+    [[fallthrough]];
+  case CmpInst::Predicate::FCMP_UNE:
+    return hasNonZeroFPOperands(this);
+  default:
+    return false;
+  }
+}
+
 CmpInst::Predicate CmpInst::getInversePredicate(Predicate pred) {
   switch (pred) {
     default: llvm_unreachable("Unknown cmp predicate!");
diff --git a/llvm/lib/Transforms/Scalar/GVN.cpp b/llvm/lib/Transforms/Scalar/GVN.cpp
index 2ba600497e00d3..cdd2a9dc06af64 100644
--- a/llvm/lib/Transforms/Scalar/GVN.cpp
+++ b/llvm/lib/Transforms/Scalar/GVN.cpp
@@ -1989,59 +1989,6 @@ bool GVNPass::processNonLocalLoad(LoadInst *Load) {
   return Changed;
 }
 
-static bool impliesEquivalanceIfTrue(CmpInst* Cmp) {
-  if (Cmp->getPredicate() == CmpInst::Predicate::ICMP_EQ)
-    return true;
-
-  // Floating point comparisons can be equal, but not equivalent.  Cases:
-  // NaNs for unordered operators
-  // +0.0 vs 0.0 for all operators
-  if (Cmp->getPredicate() == CmpInst::Predicate::FCMP_OEQ ||
-      (Cmp->getPredicate() == CmpInst::Predicate::FCMP_UEQ &&
-       Cmp->getFastMathFlags().noNaNs())) {
-      Value *LHS = Cmp->getOperand(0);
-      Value *RHS = Cmp->getOperand(1);
-      // If we can prove either side non-zero, then equality must imply
-      // equivalence.
-      // FIXME: We should do this optimization if 'no signed zeros' is
-      // applicable via an instruction-level fast-math-flag or some other
-      // indicator that relaxed FP semantics are being used.
-      if (isa<ConstantFP>(LHS) && !cast<ConstantFP>(LHS)->isZero())
-        return true;
-      if (isa<ConstantFP>(RHS) && !cast<ConstantFP>(RHS)->isZero())
-        return true;
-      // TODO: Handle vector floating point constants
-  }
-  return false;
-}
-
-static bool impliesEquivalanceIfFalse(CmpInst* Cmp) {
-  if (Cmp->getPredicate() == CmpInst::Predicate::ICMP_NE)
-    return true;
-
-  // Floating point comparisons can be equal, but not equivelent.  Cases:
-  // NaNs for unordered operators
-  // +0.0 vs 0.0 for all operators
-  if ((Cmp->getPredicate() == CmpInst::Predicate::FCMP_ONE &&
-       Cmp->getFastMathFlags().noNaNs()) ||
-      Cmp->getPredicate() == CmpInst::Predicate::FCMP_UNE) {
-      Value *LHS = Cmp->getOperand(0);
-      Value *RHS = Cmp->getOperand(1);
-      // If we can prove either side non-zero, then equality must imply
-      // equivalence.
-      // FIXME: We should do this optimization if 'no signed zeros' is
-      // applicable via an instruction-level fast-math-flag or some other
-      // indicator that relaxed FP semantics are being used.
-      if (isa<ConstantFP>(LHS) && !cast<ConstantFP>(LHS)->isZero())
-        return true;
-      if (isa<ConstantFP>(RHS) && !cast<ConstantFP>(RHS)->isZero())
-        return true;
-      // TODO: Handle vector floating point constants
-  }
-  return false;
-}
-
-
 static bool hasUsersIn(Value *V, BasicBlock *BB) {
   return llvm::any_of(V->users(), [BB](User *U) {
     auto *I = dyn_cast<Instruction>(U);
@@ -2143,7 +2090,7 @@ bool GVNPass::processAssumeIntrinsic(AssumeInst *IntrinsicI) {
   // call void @llvm.assume(i1 %cmp)
   // ret float %load ; will change it to ret float %0
   if (auto *CmpI = dyn_cast<CmpInst>(V)) {
-    if (impliesEquivalanceIfTrue(CmpI)) {
+    if (CmpI->isEqEquivalence()) {
       Value *CmpLHS = CmpI->getOperand(0);
       Value *CmpRHS = CmpI->getOperand(1);
       // Heuristically pick the better replacement -- the choice of heuristic
@@ -2567,8 +2514,8 @@ bool GVNPass::propagateEquality(Value *LHS, Value *RHS,
       // If "A == B" is known true, or "A != B" is known false, then replace
       // A with B everywhere in the scope.  For floating point operations, we
       // have to be careful since equality does not always imply equivalance.
-      if ((isKnownTrue && impliesEquivalanceIfTrue(Cmp)) ||
-          (isKnownFalse && impliesEquivalanceIfFalse(Cmp)))
+      if ((isKnownTrue && Cmp->isEqEquivalence()) ||
+          (isKnownFalse && Cmp->isNeEquivalence()))
         Worklist.push_back(std::make_pair(Op0, Op1));
 
       // If "A >= B" is known true, replace "A < B" with false everywhere.

>From 5605c343b64b9c4f9cb41fc10d1c5edce88bb518 Mon Sep 17 00:00:00 2001
From: Ramkumar Ramachandra <ramkumar.ramachandra at codasip.com>
Date: Fri, 11 Oct 2024 12:55:21 +0100
Subject: [PATCH 2/2] Instructions: address review

---
 llvm/lib/IR/Instructions.cpp | 24 ++++++++----------------
 1 file changed, 8 insertions(+), 16 deletions(-)

diff --git a/llvm/lib/IR/Instructions.cpp b/llvm/lib/IR/Instructions.cpp
index 1c13f925e1388d..caa8c4e16ca093 100644
--- a/llvm/lib/IR/Instructions.cpp
+++ b/llvm/lib/IR/Instructions.cpp
@@ -32,6 +32,7 @@
 #include "llvm/IR/Metadata.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/Operator.h"
+#include "llvm/IR/PatternMatch.h"
 #include "llvm/IR/ProfDataUtils.h"
 #include "llvm/IR/Type.h"
 #include "llvm/IR/Value.h"
@@ -3472,30 +3473,20 @@ bool CmpInst::isEquality(Predicate P) {
 }
 
 // Returns true if either operand of CmpInst is a provably non-zero
-// floating-point constant. If both operands are constants, simply return the
-// equality of the constants.
+// normal FP constant; if both operands are constants, only LHS is examined.
 static bool hasNonZeroFPOperands(const CmpInst *Cmp) {
   auto *LHS = dyn_cast<Constant>(Cmp->getOperand(0));
   auto *RHS = dyn_cast<Constant>(Cmp->getOperand(1));
-  if (LHS && RHS)
-    return LHS == RHS;
   if (auto *Const = LHS ? LHS : RHS) {
-    if (auto *ConstFP = dyn_cast<ConstantFP>(Const)) {
-      if (!ConstFP->isZero())
-        return true;
-    } else if (auto *ConstVec = dyn_cast<ConstantVector>(Const)) {
-      if (auto *SplatCFP =
-              dyn_cast_or_null<ConstantFP>(ConstVec->getSplatValue())) {
-        if (!SplatCFP->isZero())
-          return true;
-      }
-    }
+    using namespace llvm::PatternMatch;
+    return Const->isNormalFP() && match(Const, m_NonZeroFP());
   }
   return false;
 }
 
 // Floating-point equality is not an equivalence when comparing +0.0 with
-// -0.0 or when comparing NaN with another value.
+// -0.0, when comparing NaN with another value, or when flushing
+// denormals-to-zero.
 bool CmpInst::isEqEquivalence() const {
   switch (getPredicate()) {
   case CmpInst::Predicate::ICMP_EQ:
@@ -3512,7 +3503,8 @@ bool CmpInst::isEqEquivalence() const {
 }
 
 // Floating-point equality is not an equivalence when comparing +0.0 with
-// -0.0 or when comparing NaN with another value.
+// -0.0, when comparing NaN with another value, or when flushing
+// denormals-to-zero.
 bool CmpInst::isNeEquivalence() const {
   switch (getPredicate()) {
   case CmpInst::Predicate::ICMP_NE:



More information about the llvm-commits mailing list