[llvm] [InstCombine] Fold copysign of selects from sign comparison to sign operand (PR #85627)

Krishna Narayanan via llvm-commits llvm-commits at lists.llvm.org
Mon Mar 25 09:23:43 PDT 2024


https://github.com/Krishna-13-cyber updated https://github.com/llvm/llvm-project/pull/85627

>From 96403f5361e5331a35fe8d158859e01bf208a22f Mon Sep 17 00:00:00 2001
From: Krishna-13-cyber <krishnanarayanan132002 at gmail.com>
Date: Mon, 18 Mar 2024 14:12:15 +0530
Subject: [PATCH 1/2] Add support for folding sign comparison to sign operand

---
 .../InstCombine/InstCombineSelect.cpp         | 45 +++++++++++++++++++
 llvm/test/Transforms/InstCombine/fcmp.ll      | 16 +++++++
 2 files changed, 61 insertions(+)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index ee76a6294428b3..49e10646eb5d15 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -2790,6 +2790,47 @@ static Instruction *foldSelectWithFCmpToFabs(SelectInst &SI,
   return ChangedFMF ? &SI : nullptr;
 }
 
+// Canonicalize select with fcmp -> select
+static Instruction *foldSelectWithFCmp(SelectInst &SI, InstCombinerImpl &IC) {
+  /* From
+    %4 = fcmp olt float %1, 0.000000e+00
+    %5 = and i1 %4, %0
+    %6 = select i1 %5, float -1.000000e+00, float 1.000000e+00
+  */
+  /* To
+   %4 = select i1 %0, float %1, float 1.000000e+00
+  */
+  Value *CondVal = SI.getCondition();
+  Value *TrueVal = SI.getTrueValue();
+  Value *FalseVal = SI.getFalseValue();
+  Value *One = Constant::getAllOnesValue(FalseVal->getType());
+  Value *X, *C, *Op;
+  const APFloat *A, *E;
+  CmpInst::Predicate Pred;
+  for (bool Swap : {false, true}) {
+    if (Swap)
+      std::swap(TrueVal, FalseVal);
+    if (match(&SI, (m_Value(CondVal), m_APFloat(A), m_APFloat(E)))) {
+      if (!match(TrueVal, m_APFloatAllowUndef(A)) &&
+          !match(FalseVal, m_APFloatAllowUndef(E)))
+        return nullptr;
+      if (!match(CondVal, m_And(m_FCmp(Pred, m_Specific(X), m_PosZeroFP()),
+                                m_Value(C))) &&
+          (X->hasOneUse() && C->hasOneUse()))
+        return nullptr;
+      if (!A->isNegative() && E->isNegative())
+        return nullptr;
+      if (!Swap && (Pred == FCmpInst::FCMP_OLT)) {
+        return SelectInst::Create(C, X, One);
+      }
+      if (Swap && (Pred == FCmpInst::FCMP_OGT)) {
+        return SelectInst::Create(C, X, One);
+      }
+    }
+  }
+  return nullptr;
+}
+
 // Match the following IR pattern:
 //   %x.lowbits = and i8 %x, %lowbitmask
 //   %x.lowbits.are.zero = icmp eq i8 %x.lowbits, 0
@@ -3508,6 +3549,10 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
   if (Instruction *Fabs = foldSelectWithFCmpToFabs(SI, *this))
     return Fabs;
 
+  // Fold selecting to ffold.
+  if (Instruction *Ffold = foldSelectWithFCmp(SI, *this))
+    return Ffold;
+
   // See if we are selecting two values based on a comparison of the two values.
   if (ICmpInst *ICI = dyn_cast<ICmpInst>(CondVal))
     if (Instruction *Result = foldSelectInstWithICmp(SI, ICI))
diff --git a/llvm/test/Transforms/InstCombine/fcmp.ll b/llvm/test/Transforms/InstCombine/fcmp.ll
index 159c84d0dd8aa9..574fb1f9417742 100644
--- a/llvm/test/Transforms/InstCombine/fcmp.ll
+++ b/llvm/test/Transforms/InstCombine/fcmp.ll
@@ -1284,3 +1284,19 @@ define <1 x i1> @bitcast_1vec_eq0(i32 %x) {
   %cmp = fcmp oeq <1 x float> %f, zeroinitializer
   ret <1 x i1> %cmp
 }
+
+define float @copysign_conditional(i1 noundef zeroext %0, float %1, float %2) {
+; CHECK-LABEL: define float @copysign_conditional(
+; CHECK-SAME: i1 noundef zeroext [[TMP0:%.*]], float [[TMP1:%.*]], float [[TMP2:%.*]]) {
+; CHECK-NEXT:    [[TMP4:%.*]] = fcmp olt float [[TMP1]], 0.000000e+00
+; CHECK-NEXT:    [[TMP5:%.*]] = and i1 [[TMP4]], [[TMP0]]
+; CHECK-NEXT:    [[TMP6:%.*]] = select i1 [[TMP5]], float -1.000000e+00, float 1.000000e+00
+; CHECK-NEXT:    [[TMP7:%.*]] = tail call float @llvm.copysign.f32(float [[TMP2]], float [[TMP6]])
+; CHECK-NEXT:    ret float [[TMP7]]
+;
+  %4 = fcmp olt float %1, 0.000000e+00
+  %5 = and i1 %4, %0
+  %6 = select i1 %5, float -1.000000e+00, float 1.000000e+00
+  %7 = tail call float @llvm.copysign.f32(float %2, float %6)
+  ret float %7
+}
\ No newline at end of file

>From 4626d38314ea4883aaa3564e1ce167dd3a06c0d6 Mon Sep 17 00:00:00 2001
From: Krishna-13-cyber <krishnanarayanan132002 at gmail.com>
Date: Mon, 25 Mar 2024 21:38:56 +0530
Subject: [PATCH 2/2] Update with copysign as root inst

---
 .../InstCombine/InstCombineCalls.cpp          | 41 +++++++++++++++++
 .../InstCombine/InstCombineSelect.cpp         | 45 -------------------
 llvm/test/Transforms/InstCombine/copysign.ll  | 15 +++++++
 llvm/test/Transforms/InstCombine/fcmp.ll      | 16 -------
 4 files changed, 56 insertions(+), 61 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 8537dbc6fe531b..6a372ca5a34cea 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -2479,6 +2479,47 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
     if (match(Mag, m_FAbs(m_Value(X))) || match(Mag, m_FNeg(m_Value(X))))
       return replaceOperand(*II, 0, X);
 
+    Value *A, *B;
+    CmpInst::Predicate Pred;
+    const APFloat *TC, *FC;
+    if (!match(Sign, m_Select((m_And(m_Value(B),
+                                     m_FCmp(Pred, m_Value(A), m_PosZeroFP()))),
+                              m_APFloat(TC), m_APFloat(FC))))
+      return nullptr;
+    // Expected form: select ?, TC, FC where TC and FC have equal magnitude
+    // but opposite sign. The 8 cases considered are listed below.
+    /*
+    copysign(Mag, B & (A < 0.0) ? -TC : TC) --> copysign(Mag, A) B->true, A<0.
+    copysign(Mag, B & (A < 0.0) ? -TC : TC) --> copysign(Mag, A) B->true, A>0.
+    copysign(Mag, B & (A < 0.0) ? -TC : TC) --> copysign(Mag, A) B->false, A>0.
+    copysign(Mag, B & (A < 0.0) ? -TC : TC) --> copysign(Mag, -A) B->false, A<0.
+    */
+    /*
+    copysign(Mag, B & (A > 0.0) ? -TC : TC) --> copysign(Mag, -A) B->true, A<0.
+    copysign(Mag, B & (A > 0.0) ? -TC : TC) --> copysign(Mag, -A) B->true, A>0.
+    copysign(Mag, B & (A > 0.0) ? -TC : TC) --> copysign(Mag, A) B->false, A>0.
+    copysign(Mag, B & (A > 0.0) ? -TC : TC) --> copysign(Mag, -A) B->false, A<0.
+    */
+    assert(TC != FC);
+
+    if (Pred == CmpInst::FCMP_OLT)
+      if (match(A, m_Negative())) {
+        if (!match(B, m_ZeroInt()) && TC->isNegative())
+          return replaceOperand(*II, 1, A);
+        if (match(B, m_ZeroInt()) && TC->isNegative())
+          A = Builder.CreateFNeg(A);
+        return replaceOperand(*II, 1, A);
+      } else
+        return replaceOperand(*II, 1, A);
+    if (Pred == CmpInst::FCMP_OGT)
+      if (match(A, m_Negative())) {
+        A = Builder.CreateFNeg(A);
+        return replaceOperand(*II, 1, A);
+      } else if (!match(B, m_ZeroInt()) && TC->isNegative())
+        A = Builder.CreateFNeg(A);
+    return replaceOperand(*II, 1, A);
+    if (match(B, m_ZeroInt()) && TC->isNegative())
+      return replaceOperand(*II, 1, A);
     break;
   }
   case Intrinsic::fabs: {
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 49e10646eb5d15..ee76a6294428b3 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -2790,47 +2790,6 @@ static Instruction *foldSelectWithFCmpToFabs(SelectInst &SI,
   return ChangedFMF ? &SI : nullptr;
 }
 
-// Canonicalize select with fcmp -> select
-static Instruction *foldSelectWithFCmp(SelectInst &SI, InstCombinerImpl &IC) {
-  /* From
-    %4 = fcmp olt float %1, 0.000000e+00
-    %5 = and i1 %4, %0
-    %6 = select i1 %5, float -1.000000e+00, float 1.000000e+00
-  */
-  /* To
-   %4 = select i1 %0, float %1, float 1.000000e+00
-  */
-  Value *CondVal = SI.getCondition();
-  Value *TrueVal = SI.getTrueValue();
-  Value *FalseVal = SI.getFalseValue();
-  Value *One = Constant::getAllOnesValue(FalseVal->getType());
-  Value *X, *C, *Op;
-  const APFloat *A, *E;
-  CmpInst::Predicate Pred;
-  for (bool Swap : {false, true}) {
-    if (Swap)
-      std::swap(TrueVal, FalseVal);
-    if (match(&SI, (m_Value(CondVal), m_APFloat(A), m_APFloat(E)))) {
-      if (!match(TrueVal, m_APFloatAllowUndef(A)) &&
-          !match(FalseVal, m_APFloatAllowUndef(E)))
-        return nullptr;
-      if (!match(CondVal, m_And(m_FCmp(Pred, m_Specific(X), m_PosZeroFP()),
-                                m_Value(C))) &&
-          (X->hasOneUse() && C->hasOneUse()))
-        return nullptr;
-      if (!A->isNegative() && E->isNegative())
-        return nullptr;
-      if (!Swap && (Pred == FCmpInst::FCMP_OLT)) {
-        return SelectInst::Create(C, X, One);
-      }
-      if (Swap && (Pred == FCmpInst::FCMP_OGT)) {
-        return SelectInst::Create(C, X, One);
-      }
-    }
-  }
-  return nullptr;
-}
-
 // Match the following IR pattern:
 //   %x.lowbits = and i8 %x, %lowbitmask
 //   %x.lowbits.are.zero = icmp eq i8 %x.lowbits, 0
@@ -3549,10 +3508,6 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
   if (Instruction *Fabs = foldSelectWithFCmpToFabs(SI, *this))
     return Fabs;
 
-  // Fold selecting to ffold.
-  if (Instruction *Ffold = foldSelectWithFCmp(SI, *this))
-    return Ffold;
-
   // See if we are selecting two values based on a comparison of the two values.
   if (ICmpInst *ICI = dyn_cast<ICmpInst>(CondVal))
     if (Instruction *Result = foldSelectInstWithICmp(SI, ICI))
diff --git a/llvm/test/Transforms/InstCombine/copysign.ll b/llvm/test/Transforms/InstCombine/copysign.ll
index abc707acf0cd3f..c6f3eb11683013 100644
--- a/llvm/test/Transforms/InstCombine/copysign.ll
+++ b/llvm/test/Transforms/InstCombine/copysign.ll
@@ -109,3 +109,18 @@ define float @fabs_mag(float %x, float %y) {
   %r = call float @llvm.copysign.f32(float %a, float %y)
   ret float %r
 }
+
+define float @copysign_conditional(i1 noundef zeroext %x, float %y, float %z) {
+; CHECK-LABEL: @copysign_conditional(
+; CHECK-NEXT:    [[CMP:%.*]] = fcmp olt float [[Y:%.*]], 0.000000e+00
+; CHECK-NEXT:    [[AND:%.*]] = and i1 [[CMP]], [[X:%.*]]
+; CHECK-NEXT:    [[SEL:%.*]] = select i1 [[AND]], float -1.000000e+00, float 1.000000e+00
+; CHECK-NEXT:    [[RES:%.*]] = tail call float @llvm.copysign.f32(float [[Z:%.*]], float [[SEL]])
+; CHECK-NEXT:    ret float [[RES]]
+;
+  %cmp = fcmp olt float %y, 0.000000e+00
+  %and = and i1 %cmp, %x
+  %sel = select i1 %and, float -1.000000e+00, float 1.000000e+00
+  %res = tail call float @llvm.copysign.f32(float %z, float %sel)
+  ret float %res
+}
diff --git a/llvm/test/Transforms/InstCombine/fcmp.ll b/llvm/test/Transforms/InstCombine/fcmp.ll
index 574fb1f9417742..159c84d0dd8aa9 100644
--- a/llvm/test/Transforms/InstCombine/fcmp.ll
+++ b/llvm/test/Transforms/InstCombine/fcmp.ll
@@ -1284,19 +1284,3 @@ define <1 x i1> @bitcast_1vec_eq0(i32 %x) {
   %cmp = fcmp oeq <1 x float> %f, zeroinitializer
   ret <1 x i1> %cmp
 }
-
-define float @copysign_conditional(i1 noundef zeroext %0, float %1, float %2) {
-; CHECK-LABEL: define float @copysign_conditional(
-; CHECK-SAME: i1 noundef zeroext [[TMP0:%.*]], float [[TMP1:%.*]], float [[TMP2:%.*]]) {
-; CHECK-NEXT:    [[TMP4:%.*]] = fcmp olt float [[TMP1]], 0.000000e+00
-; CHECK-NEXT:    [[TMP5:%.*]] = and i1 [[TMP4]], [[TMP0]]
-; CHECK-NEXT:    [[TMP6:%.*]] = select i1 [[TMP5]], float -1.000000e+00, float 1.000000e+00
-; CHECK-NEXT:    [[TMP7:%.*]] = tail call float @llvm.copysign.f32(float [[TMP2]], float [[TMP6]])
-; CHECK-NEXT:    ret float [[TMP7]]
-;
-  %4 = fcmp olt float %1, 0.000000e+00
-  %5 = and i1 %4, %0
-  %6 = select i1 %5, float -1.000000e+00, float 1.000000e+00
-  %7 = tail call float @llvm.copysign.f32(float %2, float %6)
-  ret float %7
-}
\ No newline at end of file



More information about the llvm-commits mailing list