[llvm] IR: introduce CmpInst::isEquivalence (PR #111979)

Ramkumar Ramachandra via llvm-commits llvm-commits at lists.llvm.org
Mon Oct 21 05:50:06 PDT 2024


https://github.com/artagnon updated https://github.com/llvm/llvm-project/pull/111979

>From 28da0cf548ef21058f12838ec0429ecb377ea5ec Mon Sep 17 00:00:00 2001
From: Ramkumar Ramachandra <ramkumar.ramachandra at codasip.com>
Date: Fri, 11 Oct 2024 11:38:06 +0100
Subject: [PATCH 1/4] IR: introduce CmpInst::is{Eq,Ne}Equivalence

Move impliesEquivalanceIf{True,False} (sic) out of GVN, and extend it to
handle floating-point constant vectors and to reject denormal constants.
Since InstCombine also performs GVN-like replacements, introduce
CmpInst::is{Eq,Ne}Equivalence and remove the corresponding code in GVN,
with the intent of using it in more places.
---
 llvm/include/llvm/IR/InstrTypes.h   | 10 +++++
 llvm/include/llvm/IR/PatternMatch.h | 10 +++++
 llvm/lib/IR/Instructions.cpp        | 49 ++++++++++++++++++++++++
 llvm/lib/Transforms/Scalar/GVN.cpp  | 59 ++---------------------------
 4 files changed, 72 insertions(+), 56 deletions(-)

diff --git a/llvm/include/llvm/IR/InstrTypes.h b/llvm/include/llvm/IR/InstrTypes.h
index 99f72792ce4024..85e84afda738c3 100644
--- a/llvm/include/llvm/IR/InstrTypes.h
+++ b/llvm/include/llvm/IR/InstrTypes.h
@@ -912,6 +912,16 @@ class CmpInst : public Instruction {
   /// Determine if this is an equals/not equals predicate.
   bool isEquality() const { return isEquality(getPredicate()); }
 
+  /// Determine if this is an equals predicate that is also an equivalence. This
+  /// is useful in GVN-like transformations, where we can replace RHS by LHS in
+  /// the true branch of the CmpInst.
+  bool isEqEquivalence() const;
+
+  /// Determine if this is a not-equals predicate that is also an equivalence.
+  /// This is useful in GVN-like transformations, where we can replace RHS by
+  /// LHS in the false branch of the CmpInst.
+  bool isNeEquivalence() const;
+
   /// Return true if the predicate is relational (not EQ or NE).
   static bool isRelational(Predicate P) { return !isEquality(P); }
 
diff --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h
index c3349c9772c7ad..0d6df727906324 100644
--- a/llvm/include/llvm/IR/PatternMatch.h
+++ b/llvm/include/llvm/IR/PatternMatch.h
@@ -792,6 +792,16 @@ inline cstfp_pred_ty<is_non_zero_fp> m_NonZeroFP() {
   return cstfp_pred_ty<is_non_zero_fp>();
 }
 
+struct is_non_zero_not_denormal_fp {
+  bool isValue(const APFloat &C) { return !C.isDenormal() && C.isNonZero(); }
+};
+
+/// Match a floating-point constant that is neither zero nor a denormal.
+/// For vectors, this includes constants with undefined elements.
+inline cstfp_pred_ty<is_non_zero_not_denormal_fp> m_NonZeroNotDenormalFP() {
+  return cstfp_pred_ty<is_non_zero_not_denormal_fp>();
+}
+
 ///////////////////////////////////////////////////////////////////////////////
 
 template <typename Class> struct bind_ty {
diff --git a/llvm/lib/IR/Instructions.cpp b/llvm/lib/IR/Instructions.cpp
index 009e0c03957c97..98b474f5bbc36c 100644
--- a/llvm/lib/IR/Instructions.cpp
+++ b/llvm/lib/IR/Instructions.cpp
@@ -32,6 +32,7 @@
 #include "llvm/IR/Metadata.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/Operator.h"
+#include "llvm/IR/PatternMatch.h"
 #include "llvm/IR/ProfDataUtils.h"
 #include "llvm/IR/Type.h"
 #include "llvm/IR/Value.h"
@@ -3471,6 +3472,54 @@ bool CmpInst::isEquality(Predicate P) {
   llvm_unreachable("Unsupported predicate kind");
 }
 
+// Returns true if either operand of CmpInst is a provably non-zero,
+// non-denormal floating-point constant (scalar or vector).
+static bool hasNonZeroFPOperands(const CmpInst *Cmp) {
+  using namespace llvm::PatternMatch;
+  // Check both operands independently: a constant LHS that fails the match
+  // (e.g. a constant expression) must not hide a non-zero constant RHS.
+  auto IsNonZeroFP = [](Value *V) {
+    return match(V, m_NonZeroNotDenormalFP());
+  };
+  return IsNonZeroFP(Cmp->getOperand(0)) || IsNonZeroFP(Cmp->getOperand(1));
+}
+
+// Floating-point equality is not an equivalence when comparing +0.0 with
+// -0.0, when comparing NaN with another value, or when flushing
+// denormals-to-zero.
+bool CmpInst::isEqEquivalence() const {
+  switch (getPredicate()) {
+  case CmpInst::Predicate::ICMP_EQ:
+    return true;
+  case CmpInst::Predicate::FCMP_UEQ:
+    if (!hasNoNaNs())
+      return false;
+    [[fallthrough]];
+  case CmpInst::Predicate::FCMP_OEQ:
+    return hasNonZeroFPOperands(this);
+  default:
+    return false;
+  }
+}
+
+// Floating-point equality is not an equivalence when comparing +0.0 with
+// -0.0, when comparing NaN with another value, or when flushing
+// denormals-to-zero.
+bool CmpInst::isNeEquivalence() const {
+  switch (getPredicate()) {
+  case CmpInst::Predicate::ICMP_NE:
+    return true;
+  case CmpInst::Predicate::FCMP_ONE:
+    if (!hasNoNaNs())
+      return false;
+    [[fallthrough]];
+  case CmpInst::Predicate::FCMP_UNE:
+    return hasNonZeroFPOperands(this);
+  default:
+    return false;
+  }
+}
+
 CmpInst::Predicate CmpInst::getInversePredicate(Predicate pred) {
   switch (pred) {
     default: llvm_unreachable("Unknown cmp predicate!");
diff --git a/llvm/lib/Transforms/Scalar/GVN.cpp b/llvm/lib/Transforms/Scalar/GVN.cpp
index 2ba600497e00d3..cdd2a9dc06af64 100644
--- a/llvm/lib/Transforms/Scalar/GVN.cpp
+++ b/llvm/lib/Transforms/Scalar/GVN.cpp
@@ -1989,59 +1989,6 @@ bool GVNPass::processNonLocalLoad(LoadInst *Load) {
   return Changed;
 }
 
-static bool impliesEquivalanceIfTrue(CmpInst* Cmp) {
-  if (Cmp->getPredicate() == CmpInst::Predicate::ICMP_EQ)
-    return true;
-
-  // Floating point comparisons can be equal, but not equivalent.  Cases:
-  // NaNs for unordered operators
-  // +0.0 vs 0.0 for all operators
-  if (Cmp->getPredicate() == CmpInst::Predicate::FCMP_OEQ ||
-      (Cmp->getPredicate() == CmpInst::Predicate::FCMP_UEQ &&
-       Cmp->getFastMathFlags().noNaNs())) {
-      Value *LHS = Cmp->getOperand(0);
-      Value *RHS = Cmp->getOperand(1);
-      // If we can prove either side non-zero, then equality must imply
-      // equivalence.
-      // FIXME: We should do this optimization if 'no signed zeros' is
-      // applicable via an instruction-level fast-math-flag or some other
-      // indicator that relaxed FP semantics are being used.
-      if (isa<ConstantFP>(LHS) && !cast<ConstantFP>(LHS)->isZero())
-        return true;
-      if (isa<ConstantFP>(RHS) && !cast<ConstantFP>(RHS)->isZero())
-        return true;
-      // TODO: Handle vector floating point constants
-  }
-  return false;
-}
-
-static bool impliesEquivalanceIfFalse(CmpInst* Cmp) {
-  if (Cmp->getPredicate() == CmpInst::Predicate::ICMP_NE)
-    return true;
-
-  // Floating point comparisons can be equal, but not equivelent.  Cases:
-  // NaNs for unordered operators
-  // +0.0 vs 0.0 for all operators
-  if ((Cmp->getPredicate() == CmpInst::Predicate::FCMP_ONE &&
-       Cmp->getFastMathFlags().noNaNs()) ||
-      Cmp->getPredicate() == CmpInst::Predicate::FCMP_UNE) {
-      Value *LHS = Cmp->getOperand(0);
-      Value *RHS = Cmp->getOperand(1);
-      // If we can prove either side non-zero, then equality must imply
-      // equivalence.
-      // FIXME: We should do this optimization if 'no signed zeros' is
-      // applicable via an instruction-level fast-math-flag or some other
-      // indicator that relaxed FP semantics are being used.
-      if (isa<ConstantFP>(LHS) && !cast<ConstantFP>(LHS)->isZero())
-        return true;
-      if (isa<ConstantFP>(RHS) && !cast<ConstantFP>(RHS)->isZero())
-        return true;
-      // TODO: Handle vector floating point constants
-  }
-  return false;
-}
-
-
 static bool hasUsersIn(Value *V, BasicBlock *BB) {
   return llvm::any_of(V->users(), [BB](User *U) {
     auto *I = dyn_cast<Instruction>(U);
@@ -2143,7 +2090,7 @@ bool GVNPass::processAssumeIntrinsic(AssumeInst *IntrinsicI) {
   // call void @llvm.assume(i1 %cmp)
   // ret float %load ; will change it to ret float %0
   if (auto *CmpI = dyn_cast<CmpInst>(V)) {
-    if (impliesEquivalanceIfTrue(CmpI)) {
+    if (CmpI->isEqEquivalence()) {
       Value *CmpLHS = CmpI->getOperand(0);
       Value *CmpRHS = CmpI->getOperand(1);
       // Heuristically pick the better replacement -- the choice of heuristic
@@ -2567,8 +2514,8 @@ bool GVNPass::propagateEquality(Value *LHS, Value *RHS,
       // If "A == B" is known true, or "A != B" is known false, then replace
       // A with B everywhere in the scope.  For floating point operations, we
       // have to be careful since equality does not always imply equivalance.
-      if ((isKnownTrue && impliesEquivalanceIfTrue(Cmp)) ||
-          (isKnownFalse && impliesEquivalanceIfFalse(Cmp)))
+      if ((isKnownTrue && Cmp->isEqEquivalence()) ||
+          (isKnownFalse && Cmp->isNeEquivalence()))
         Worklist.push_back(std::make_pair(Op0, Op1));
 
       // If "A >= B" is known true, replace "A < B" with false everywhere.

>From 7bcf6c98c876b2c21f5e18b43e7ee158562d2572 Mon Sep 17 00:00:00 2001
From: Ramkumar Ramachandra <ramkumar.ramachandra at codasip.com>
Date: Tue, 15 Oct 2024 12:29:06 +0100
Subject: [PATCH 2/4] GVN: add test for denormals

---
 llvm/test/Transforms/GVN/edge.ll | 30 +++++++++++++++++++++++++++++-
 1 file changed, 29 insertions(+), 1 deletion(-)

diff --git a/llvm/test/Transforms/GVN/edge.ll b/llvm/test/Transforms/GVN/edge.ll
index 9703195d3b642e..83c4c336f6474a 100644
--- a/llvm/test/Transforms/GVN/edge.ll
+++ b/llvm/test/Transforms/GVN/edge.ll
@@ -224,6 +224,34 @@ return:
   ret double %retval
 }
 
+; Denormals may be flushed to zero in some cases by the backend.
+; Hence, treat denormals as 0.
+define float @fcmp_oeq_denormal(float %x, float %y) {
+; CHECK-LABEL: define float @fcmp_oeq_denormal(
+; CHECK-SAME: float [[X:%.*]], float [[Y:%.*]]) {
+; CHECK-NEXT:  [[ENTRY:.*]]:
+; CHECK-NEXT:    [[CMP:%.*]] = fcmp oeq float [[Y]], 0x3800000000000000
+; CHECK-NEXT:    br i1 [[CMP]], label %[[IF:.*]], label %[[RETURN:.*]]
+; CHECK:       [[IF]]:
+; CHECK-NEXT:    [[DIV:%.*]] = fdiv float [[X]], [[Y]]
+; CHECK-NEXT:    br label %[[RETURN]]
+; CHECK:       [[RETURN]]:
+; CHECK-NEXT:    [[RETVAL:%.*]] = phi float [ [[DIV]], %[[IF]] ], [ [[X]], %[[ENTRY]] ]
+; CHECK-NEXT:    ret float [[RETVAL]]
+;
+entry:
+  %cmp = fcmp oeq float %y, 0x3800000000000000
+  br i1 %cmp, label %if, label %return
+
+if:
+  %div = fdiv float %x, %y
+  br label %return
+
+return:
+  %retval = phi float [ %div, %if ], [ %x, %entry ]
+  ret float %retval
+}
+
 define double @fcmp_une_zero(double %x, double %y) {
 ; CHECK-LABEL: define double @fcmp_une_zero(
 ; CHECK-SAME: double [[X:%.*]], double [[Y:%.*]]) {
@@ -251,7 +279,7 @@ return:
 }
 
 ; We also cannot propagate a value if it's not a constant.
-; This is because the value could be 0.0 or -0.0.
+; This is because the value could be 0.0, -0.0, or a denormal.
 
 define double @fcmp_oeq_maybe_zero(double %x, double %y, double %z1, double %z2) {
 ; CHECK-LABEL: define double @fcmp_oeq_maybe_zero(

>From e57f9eb52340f66ff564da6512a0dd16873cf9da Mon Sep 17 00:00:00 2001
From: Ramkumar Ramachandra <ramkumar.ramachandra at codasip.com>
Date: Mon, 21 Oct 2024 13:01:13 +0100
Subject: [PATCH 3/4] CmpInst: merge functions; address review

---
 llvm/include/llvm/IR/InstrTypes.h  | 14 +++++---------
 llvm/lib/IR/Instructions.cpp       | 22 ++++++----------------
 llvm/lib/Transforms/Scalar/GVN.cpp |  6 +++---
 3 files changed, 14 insertions(+), 28 deletions(-)

diff --git a/llvm/include/llvm/IR/InstrTypes.h b/llvm/include/llvm/IR/InstrTypes.h
index 85e84afda738c3..d93cd5958bc07f 100644
--- a/llvm/include/llvm/IR/InstrTypes.h
+++ b/llvm/include/llvm/IR/InstrTypes.h
@@ -912,15 +912,11 @@ class CmpInst : public Instruction {
   /// Determine if this is an equals/not equals predicate.
   bool isEquality() const { return isEquality(getPredicate()); }
 
-  /// Determine if this is an equals predicate that is also an equivalence. This
-  /// is useful in GVN-like transformations, where we can replace RHS by LHS in
-  /// the true branch of the CmpInst.
-  bool isEqEquivalence() const;
-
-  /// Determine if this is a not-equals predicate that is also an equivalence.
-  /// This is useful in GVN-like transformations, where we can replace RHS by
-  /// LHS in the false branch of the CmpInst.
-  bool isNeEquivalence() const;
+  /// Determine if one operand of this compare can always be replaced by the
+  /// other operand, ignoring provenance considerations. If \p Invert is false,
+  /// check for equivalence with an equals predicate; otherwise, check for
+  /// equivalence with a not-equals predicate.
+  bool isEquivalence(bool Invert = false) const;
 
   /// Return true if the predicate is relational (not EQ or NE).
   static bool isRelational(Predicate P) { return !isEquality(P); }
diff --git a/llvm/lib/IR/Instructions.cpp b/llvm/lib/IR/Instructions.cpp
index 98b474f5bbc36c..63f3568f359b86 100644
--- a/llvm/lib/IR/Instructions.cpp
+++ b/llvm/lib/IR/Instructions.cpp
@@ -3487,34 +3487,24 @@ static bool hasNonZeroFPOperands(const CmpInst *Cmp) {
 // Floating-point equality is not an equivalence when comparing +0.0 with
 // -0.0, when comparing NaN with another value, or when flushing
 // denormals-to-zero.
-bool CmpInst::isEqEquivalence() const {
+bool CmpInst::isEquivalence(bool Invert) const {
   switch (getPredicate()) {
   case CmpInst::Predicate::ICMP_EQ:
-    return true;
+    return !Invert;
+  case CmpInst::Predicate::ICMP_NE:
+    return Invert;
   case CmpInst::Predicate::FCMP_UEQ:
     if (!hasNoNaNs())
       return false;
     [[fallthrough]];
   case CmpInst::Predicate::FCMP_OEQ:
-    return hasNonZeroFPOperands(this);
-  default:
-    return false;
-  }
-}
-
-// Floating-point equality is not an equivalence when comparing +0.0 with
-// -0.0, when comparing NaN with another value, or when flushing
-// denormals-to-zero.
-bool CmpInst::isNeEquivalence() const {
-  switch (getPredicate()) {
-  case CmpInst::Predicate::ICMP_NE:
-    return true;
+    return !Invert && hasNonZeroFPOperands(this);
   case CmpInst::Predicate::FCMP_ONE:
     if (!hasNoNaNs())
       return false;
     [[fallthrough]];
   case CmpInst::Predicate::FCMP_UNE:
-    return hasNonZeroFPOperands(this);
+    return Invert && hasNonZeroFPOperands(this);
   default:
     return false;
   }
diff --git a/llvm/lib/Transforms/Scalar/GVN.cpp b/llvm/lib/Transforms/Scalar/GVN.cpp
index cdd2a9dc06af64..adfac2b5914e8e 100644
--- a/llvm/lib/Transforms/Scalar/GVN.cpp
+++ b/llvm/lib/Transforms/Scalar/GVN.cpp
@@ -2090,7 +2090,7 @@ bool GVNPass::processAssumeIntrinsic(AssumeInst *IntrinsicI) {
   // call void @llvm.assume(i1 %cmp)
   // ret float %load ; will change it to ret float %0
   if (auto *CmpI = dyn_cast<CmpInst>(V)) {
-    if (CmpI->isEqEquivalence()) {
+    if (CmpI->isEquivalence()) {
       Value *CmpLHS = CmpI->getOperand(0);
       Value *CmpRHS = CmpI->getOperand(1);
       // Heuristically pick the better replacement -- the choice of heuristic
@@ -2514,8 +2514,8 @@ bool GVNPass::propagateEquality(Value *LHS, Value *RHS,
       // If "A == B" is known true, or "A != B" is known false, then replace
       // A with B everywhere in the scope.  For floating point operations, we
       // have to be careful since equality does not always imply equivalance.
-      if ((isKnownTrue && Cmp->isEqEquivalence()) ||
-          (isKnownFalse && Cmp->isNeEquivalence()))
+      if ((isKnownTrue && Cmp->isEquivalence()) ||
+          (isKnownFalse && Cmp->isEquivalence(/*Invert=*/true)))
         Worklist.push_back(std::make_pair(Op0, Op1));
 
       // If "A >= B" is known true, replace "A < B" with false everywhere.

>From 7fe374186cd4e52e828efce894c58afd34b7c4b4 Mon Sep 17 00:00:00 2001
From: Ramkumar Ramachandra <ramkumar.ramachandra at codasip.com>
Date: Mon, 21 Oct 2024 13:49:25 +0100
Subject: [PATCH 4/4] CmpInst: de-duplicate code; address review

---
 llvm/include/llvm/IR/InstrTypes.h |  5 ++---
 llvm/lib/IR/Instructions.cpp      | 14 +++-----------
 2 files changed, 5 insertions(+), 14 deletions(-)

diff --git a/llvm/include/llvm/IR/InstrTypes.h b/llvm/include/llvm/IR/InstrTypes.h
index d93cd5958bc07f..1c60eae7f2f85b 100644
--- a/llvm/include/llvm/IR/InstrTypes.h
+++ b/llvm/include/llvm/IR/InstrTypes.h
@@ -913,9 +913,8 @@ class CmpInst : public Instruction {
   bool isEquality() const { return isEquality(getPredicate()); }
 
   /// Determine if one operand of this compare can always be replaced by the
-  /// other operand, ignoring provenance considerations. If \p Invert is false,
-  /// check for equivalence with an equals predicate; otherwise, check for
-  /// equivalence with a not-equals predicate.
+  /// other operand, ignoring provenance considerations. If \p Invert, check for
+  /// equivalence with the inverse predicate.
   bool isEquivalence(bool Invert = false) const;
 
   /// Return true if the predicate is relational (not EQ or NE).
diff --git a/llvm/lib/IR/Instructions.cpp b/llvm/lib/IR/Instructions.cpp
index 63f3568f359b86..05e340ffa20a07 100644
--- a/llvm/lib/IR/Instructions.cpp
+++ b/llvm/lib/IR/Instructions.cpp
@@ -3488,23 +3488,15 @@ static bool hasNonZeroFPOperands(const CmpInst *Cmp) {
 // -0.0, when comparing NaN with another value, or when flushing
 // denormals-to-zero.
 bool CmpInst::isEquivalence(bool Invert) const {
-  switch (getPredicate()) {
+  switch (Invert ? getInversePredicate() : getPredicate()) {
   case CmpInst::Predicate::ICMP_EQ:
-    return !Invert;
-  case CmpInst::Predicate::ICMP_NE:
-    return Invert;
+    return true;
   case CmpInst::Predicate::FCMP_UEQ:
     if (!hasNoNaNs())
       return false;
     [[fallthrough]];
   case CmpInst::Predicate::FCMP_OEQ:
-    return !Invert && hasNonZeroFPOperands(this);
-  case CmpInst::Predicate::FCMP_ONE:
-    if (!hasNoNaNs())
-      return false;
-    [[fallthrough]];
-  case CmpInst::Predicate::FCMP_UNE:
-    return Invert && hasNonZeroFPOperands(this);
+    return hasNonZeroFPOperands(this);
   default:
     return false;
   }



More information about the llvm-commits mailing list