[llvm] InstSimplify: support floating-point equivalences (PR #115152)

via llvm-commits llvm-commits at lists.llvm.org
Wed Nov 6 04:10:46 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-analysis

Author: Ramkumar Ramachandra (artagnon)

<details>
<summary>Changes</summary>

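Since cd16b07 (IR: introduce CmpInst::isEquivalence), CmpInst provides an isEquivalence routine that simplifySelectWithICmpEq can use to determine whether a comparison makes its two operands interchangeable. Use it to extend the simplification from integer equalities to both integer and floating-point equivalences. Plain equality is not sufficient for floating point: `fcmp oeq` is true for +0.0 and -0.0 even though they are distinct values, and `fcmp ueq` is also true when an operand is NaN. In those cases one operand cannot simply be substituted for the other.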

---
Full diff: https://github.com/llvm/llvm-project/pull/115152.diff


2 Files Affected:

- (modified) llvm/lib/Analysis/InstructionSimplify.cpp (+28-32) 
- (added) llvm/test/Transforms/InstSimplify/select-equivalence-fp.ll (+156) 


``````````diff
diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
index 2cb2612bf611e3..198707c5667c8c 100644
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -4617,10 +4617,10 @@ static Value *simplifySelectWithFakeICmpEq(Value *CmpLHS, Value *CmpRHS,
 
 /// Try to simplify a select instruction when its condition operand is an
 /// integer equality comparison.
-static Value *simplifySelectWithICmpEq(Value *CmpLHS, Value *CmpRHS,
-                                       Value *TrueVal, Value *FalseVal,
-                                       const SimplifyQuery &Q,
-                                       unsigned MaxRecurse) {
+static Value *simplifySelectWithEquivalence(Value *CmpLHS, Value *CmpRHS,
+                                            Value *TrueVal, Value *FalseVal,
+                                            const SimplifyQuery &Q,
+                                            unsigned MaxRecurse) {
   if (simplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, Q.getWithoutUndef(),
                              /* AllowRefinement */ false,
                              /* DropFlags */ nullptr, MaxRecurse) == TrueVal)
@@ -4635,23 +4635,21 @@ static Value *simplifySelectWithICmpEq(Value *CmpLHS, Value *CmpRHS,
 
 /// Try to simplify a select instruction when its condition operand is an
 /// integer comparison.
-static Value *simplifySelectWithICmpCond(Value *CondVal, Value *TrueVal,
-                                         Value *FalseVal,
-                                         const SimplifyQuery &Q,
-                                         unsigned MaxRecurse) {
+static Value *simplifySelectWithCmpCond(Value *CondVal, Value *TrueVal,
+                                        Value *FalseVal, const SimplifyQuery &Q,
+                                        unsigned MaxRecurse) {
   ICmpInst::Predicate Pred;
   Value *CmpLHS, *CmpRHS;
-  if (!match(CondVal, m_ICmp(Pred, m_Value(CmpLHS), m_Value(CmpRHS))))
+  if (!match(CondVal, m_Cmp(Pred, m_Value(CmpLHS), m_Value(CmpRHS))))
     return nullptr;
+  auto *CI = cast<CmpInst>(CondVal);
 
   if (Value *V = simplifyCmpSelOfMaxMin(CmpLHS, CmpRHS, Pred, TrueVal, FalseVal))
     return V;
 
-  // Canonicalize ne to eq predicate.
-  if (Pred == ICmpInst::ICMP_NE) {
-    Pred = ICmpInst::ICMP_EQ;
+  // Canonicalize the equivalence, of which equality is a subset.
+  if (CI->isEquivalence(/*Invert=*/true))
     std::swap(TrueVal, FalseVal);
-  }
 
   // Check for integer min/max with a limit constant:
   // X > MIN_INT ? X : MIN_INT --> X
@@ -4659,9 +4657,7 @@ static Value *simplifySelectWithICmpCond(Value *CondVal, Value *TrueVal,
   if (TrueVal->getType()->isIntOrIntVectorTy()) {
     Value *X, *Y;
     SelectPatternFlavor SPF =
-        matchDecomposedSelectPattern(cast<ICmpInst>(CondVal), TrueVal, FalseVal,
-                                     X, Y)
-            .Flavor;
+        matchDecomposedSelectPattern(CI, TrueVal, FalseVal, X, Y).Flavor;
     if (SelectPatternResult::isMinOrMax(SPF) && Pred == getMinMaxPred(SPF)) {
       APInt LimitC = getMinMaxLimit(getInverseMinMaxFlavor(SPF),
                                     X->getType()->getScalarSizeInBits());
@@ -4670,7 +4666,7 @@ static Value *simplifySelectWithICmpCond(Value *CondVal, Value *TrueVal,
     }
   }
 
-  if (Pred == ICmpInst::ICMP_EQ && match(CmpRHS, m_Zero())) {
+  if (CI->isEquality() && match(CmpRHS, m_Zero())) {
     Value *X;
     const APInt *Y;
     if (match(CmpLHS, m_And(m_Value(X), m_APInt(Y))))
@@ -4698,7 +4694,7 @@ static Value *simplifySelectWithICmpCond(Value *CondVal, Value *TrueVal,
     // (ShAmt == 0) ? X : fshl(X, X, ShAmt) --> fshl(X, X, ShAmt)
     // (ShAmt == 0) ? X : fshr(X, X, ShAmt) --> fshr(X, X, ShAmt)
     if (match(FalseVal, isRotate) && TrueVal == X && CmpLHS == ShAmt &&
-        Pred == ICmpInst::ICMP_EQ)
+        CI->isEquality())
       return FalseVal;
 
     // X == 0 ? abs(X) : -abs(X) --> -abs(X)
@@ -4720,12 +4716,12 @@ static Value *simplifySelectWithICmpCond(Value *CondVal, Value *TrueVal,
   // If we have a scalar equality comparison, then we know the value in one of
   // the arms of the select. See if substituting this value into the arm and
   // simplifying the result yields the same value as the other arm.
-  if (Pred == ICmpInst::ICMP_EQ) {
-    if (Value *V = simplifySelectWithICmpEq(CmpLHS, CmpRHS, TrueVal, FalseVal,
-                                            Q, MaxRecurse))
+  if (CI->isEquivalence() || CI->isEquivalence(/*Invert=*/true)) {
+    if (Value *V = simplifySelectWithEquivalence(CmpLHS, CmpRHS, TrueVal,
+                                                 FalseVal, Q, MaxRecurse))
       return V;
-    if (Value *V = simplifySelectWithICmpEq(CmpRHS, CmpLHS, TrueVal, FalseVal,
-                                            Q, MaxRecurse))
+    if (Value *V = simplifySelectWithEquivalence(CmpRHS, CmpLHS, TrueVal,
+                                                 FalseVal, Q, MaxRecurse))
       return V;
 
     Value *X;
@@ -4734,11 +4730,11 @@ static Value *simplifySelectWithICmpCond(Value *CondVal, Value *TrueVal,
     if (match(CmpLHS, m_Or(m_Value(X), m_Value(Y))) &&
         match(CmpRHS, m_Zero())) {
       // (X | Y) == 0 implies X == 0 and Y == 0.
-      if (Value *V = simplifySelectWithICmpEq(X, CmpRHS, TrueVal, FalseVal, Q,
-                                              MaxRecurse))
+      if (Value *V = simplifySelectWithEquivalence(X, CmpRHS, TrueVal, FalseVal,
+                                                   Q, MaxRecurse))
         return V;
-      if (Value *V = simplifySelectWithICmpEq(Y, CmpRHS, TrueVal, FalseVal, Q,
-                                              MaxRecurse))
+      if (Value *V = simplifySelectWithEquivalence(Y, CmpRHS, TrueVal, FalseVal,
+                                                   Q, MaxRecurse))
         return V;
     }
 
@@ -4746,11 +4742,11 @@ static Value *simplifySelectWithICmpCond(Value *CondVal, Value *TrueVal,
     if (match(CmpLHS, m_And(m_Value(X), m_Value(Y))) &&
         match(CmpRHS, m_AllOnes())) {
       // (X & Y) == -1 implies X == -1 and Y == -1.
-      if (Value *V = simplifySelectWithICmpEq(X, CmpRHS, TrueVal, FalseVal, Q,
-                                              MaxRecurse))
+      if (Value *V = simplifySelectWithEquivalence(X, CmpRHS, TrueVal, FalseVal,
+                                                   Q, MaxRecurse))
         return V;
-      if (Value *V = simplifySelectWithICmpEq(Y, CmpRHS, TrueVal, FalseVal, Q,
-                                              MaxRecurse))
+      if (Value *V = simplifySelectWithEquivalence(Y, CmpRHS, TrueVal, FalseVal,
+                                                   Q, MaxRecurse))
         return V;
     }
   }
@@ -4952,7 +4948,7 @@ static Value *simplifySelectInst(Value *Cond, Value *TrueVal, Value *FalseVal,
   }
 
   if (Value *V =
-          simplifySelectWithICmpCond(Cond, TrueVal, FalseVal, Q, MaxRecurse))
+          simplifySelectWithCmpCond(Cond, TrueVal, FalseVal, Q, MaxRecurse))
     return V;
 
   if (Value *V = simplifySelectWithFCmp(Cond, TrueVal, FalseVal, Q))
diff --git a/llvm/test/Transforms/InstSimplify/select-equivalence-fp.ll b/llvm/test/Transforms/InstSimplify/select-equivalence-fp.ll
new file mode 100644
index 00000000000000..a59139246b00a6
--- /dev/null
+++ b/llvm/test/Transforms/InstSimplify/select-equivalence-fp.ll
@@ -0,0 +1,156 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt < %s -passes=instsimplify -S | FileCheck %s
+
+define float @select_fcmp_fsub_oeq(float %x, float %y) {
+; CHECK-LABEL: @select_fcmp_fsub_oeq(
+; CHECK-NEXT:    ret float 0.000000e+00
+;
+  %fcmp = fcmp oeq float %y, 2.
+  %fadd = fsub float %y, 2.
+  %sel = select i1 %fcmp, float %fadd, float 0.
+  ret float %sel
+}
+
+define float @select_fcmp_fsub_oeq_zero(float %x, float %y) {
+; CHECK-LABEL: @select_fcmp_fsub_oeq_zero(
+; CHECK-NEXT:    [[FCMP:%.*]] = fcmp oeq float [[Y:%.*]], 0.000000e+00
+; CHECK-NEXT:    [[FADD:%.*]] = fsub float [[Y]], 2.000000e+00
+; CHECK-NEXT:    [[SEL:%.*]] = select i1 [[FCMP]], float [[FADD]], float 2.000000e+00
+; CHECK-NEXT:    ret float [[SEL]]
+;
+  %fcmp = fcmp oeq float %y, 0.
+  %fadd = fsub float %y, 2.
+  %sel = select i1 %fcmp, float %fadd, float 2.
+  ret float %sel
+}
+
+define float @select_fcmp_fsub_ueq(float %x, float %y) {
+; CHECK-LABEL: @select_fcmp_fsub_ueq(
+; CHECK-NEXT:    [[FCMP:%.*]] = fcmp ueq float [[Y:%.*]], 2.000000e+00
+; CHECK-NEXT:    [[FADD:%.*]] = fsub float [[Y]], 2.000000e+00
+; CHECK-NEXT:    [[SEL:%.*]] = select i1 [[FCMP]], float [[FADD]], float 0.000000e+00
+; CHECK-NEXT:    ret float [[SEL]]
+;
+  %fcmp = fcmp ueq float %y, 2.
+  %fadd = fsub float %y, 2.
+  %sel = select i1 %fcmp, float %fadd, float 0.
+  ret float %sel
+}
+
+define float @select_fcmp_fsub_ueq_nnan(float %x, float %y) {
+; CHECK-LABEL: @select_fcmp_fsub_ueq_nnan(
+; CHECK-NEXT:    ret float 0.000000e+00
+;
+  %fcmp = fcmp nnan ueq float %y, 2.
+  %fadd = fsub float %y, 2.
+  %sel = select i1 %fcmp, float %fadd, float 0.
+  ret float %sel
+}
+
+define float @select_fcmp_fsub_une(float %x, float %y) {
+; CHECK-LABEL: @select_fcmp_fsub_une(
+; CHECK-NEXT:    ret float 0.000000e+00
+;
+  %fcmp = fcmp une float %y, 2.
+  %fadd = fsub float %y, 2.
+  %sel = select i1 %fcmp, float 0., float %fadd
+  ret float %sel
+}
+
+define float @select_fcmp_fsub_une_zero(float %x, float %y) {
+; CHECK-LABEL: @select_fcmp_fsub_une_zero(
+; CHECK-NEXT:    [[FCMP:%.*]] = fcmp une float [[Y:%.*]], 0.000000e+00
+; CHECK-NEXT:    [[FADD:%.*]] = fsub float [[Y]], 2.000000e+00
+; CHECK-NEXT:    [[SEL:%.*]] = select i1 [[FCMP]], float 2.000000e+00, float [[FADD]]
+; CHECK-NEXT:    ret float [[SEL]]
+;
+  %fcmp = fcmp une float %y, 0.
+  %fadd = fsub float %y, 2.
+  %sel = select i1 %fcmp, float 2., float %fadd
+  ret float %sel
+}
+
+define float @select_fcmp_fsub_one(float %x, float %y) {
+; CHECK-LABEL: @select_fcmp_fsub_one(
+; CHECK-NEXT:    [[FCMP:%.*]] = fcmp one float [[Y:%.*]], 2.000000e+00
+; CHECK-NEXT:    [[FADD:%.*]] = fsub float [[Y]], 2.000000e+00
+; CHECK-NEXT:    [[SEL:%.*]] = select i1 [[FCMP]], float 0.000000e+00, float [[FADD]]
+; CHECK-NEXT:    ret float [[SEL]]
+;
+  %fcmp = fcmp one float %y, 2.
+  %fadd = fsub float %y, 2.
+  %sel = select i1 %fcmp, float 0., float %fadd
+  ret float %sel
+}
+
+define float @select_fcmp_fsub_one_nnan(float %x, float %y) {
+; CHECK-LABEL: @select_fcmp_fsub_one_nnan(
+; CHECK-NEXT:    ret float 0.000000e+00
+;
+  %fcmp = fcmp nnan one float %y, 2.
+  %fadd = fsub float %y, 2.
+  %sel = select i1 %fcmp, float 0., float %fadd
+  ret float %sel
+}
+
+define float @select_fcmp_fadd(float %x, float %y) {
+; CHECK-LABEL: @select_fcmp_fadd(
+; CHECK-NEXT:    ret float 4.000000e+00
+;
+  %fcmp = fcmp oeq float %y, 2.
+  %fadd = fadd float %y, 2.
+  %sel = select i1 %fcmp, float %fadd, float 4.
+  ret float %sel
+}
+
+define <2 x float> @select_fcmp_fadd_vec(<2 x float> %x, <2 x float> %y) {
+; CHECK-LABEL: @select_fcmp_fadd_vec(
+; CHECK-NEXT:    ret <2 x float> <float 4.000000e+00, float 4.000000e+00>
+;
+  %fcmp = fcmp oeq <2 x float> %y, <float 2., float 2.>
+  %fadd = fadd <2 x float> %y, <float 2., float 2.>
+  %sel = select <2 x i1> %fcmp, <2 x float> %fadd, <2 x float> <float 4., float 4.>
+  ret <2 x float> %sel
+}
+
+
+define float @select_fcmp_fdiv(float %x, float %y) {
+; CHECK-LABEL: @select_fcmp_fdiv(
+; CHECK-NEXT:    ret float 1.000000e+00
+;
+  %fcmp = fcmp oeq float %y, 2.
+  %fdiv = fdiv float %y, 2.
+  %sel = select i1 %fcmp, float %fdiv, float 1.
+  ret float %sel
+}
+
+define float @select_fcmp_frem(float %x, float %y) {
+; CHECK-LABEL: @select_fcmp_frem(
+; CHECK-NEXT:    ret float 1.000000e+00
+;
+  %fcmp = fcmp oeq float %y, 3.
+  %frem = frem float %y, 2.
+  %sel = select i1 %fcmp, float %frem, float 1.
+  ret float %sel
+}
+
+define <2 x float> @select_fcmp_insertelement(<2 x float> %x, <2 x float> %y) {
+; CHECK-LABEL: @select_fcmp_insertelement(
+; CHECK-NEXT:    ret <2 x float> <float 4.000000e+00, float 2.000000e+00>
+;
+  %fcmp = fcmp oeq <2 x float> %y, <float 2., float 2.>
+  %insert = insertelement <2 x float> %y, float 4., i64 0
+  %sel = select <2 x i1> %fcmp, <2 x float> %insert, <2 x float> <float 4., float 2.>
+  ret <2 x float> %sel
+}
+
+define <4 x float> @select_fcmp_shufflevector_select(<4 x float> %x, <4 x float> %y) {
+; CHECK-LABEL: @select_fcmp_shufflevector_select(
+; CHECK-NEXT:    ret <4 x float> <float poison, float 2.000000e+00, float poison, float 2.000000e+00>
+;
+  %fcmp = fcmp oeq <4 x float> %y, <float 2., float 2., float 2., float 2.>
+  %shuffle = shufflevector <4 x float> %y, <4 x float> poison, <4 x i32> <i32 4, i32 1, i32 6, i32 3>
+  %sel = select <4 x i1> %fcmp, <4 x float> %shuffle, <4 x float> <float poison, float 2., float poison, float 2.>
+  ret <4 x float> %sel
+}
+

``````````

</details>


https://github.com/llvm/llvm-project/pull/115152


More information about the llvm-commits mailing list