[llvm] [InstCombine] Generalize fold of `fcmp + copysign` (PR #86387)

Yingwei Zheng via llvm-commits llvm-commits at lists.llvm.org
Thu Apr 11 07:03:33 PDT 2024


https://github.com/dtcxzyw updated https://github.com/llvm/llvm-project/pull/86387

From 17224b39db809a50de5818b1cf27478456c02457 Mon Sep 17 00:00:00 2001
From: Yingwei Zheng <dtcxzyw2333 at gmail.com>
Date: Thu, 11 Apr 2024 21:55:20 +0800
Subject: [PATCH 1/2] [InstCombine] Add pre-commit tests. NFC.

---
 llvm/test/Transforms/InstCombine/fcmp.ll | 115 +++++++++++++++++++++++
 1 file changed, 115 insertions(+)

diff --git a/llvm/test/Transforms/InstCombine/fcmp.ll b/llvm/test/Transforms/InstCombine/fcmp.ll
index f2701d16d0f3d1..3e7c6c229f56cc 100644
--- a/llvm/test/Transforms/InstCombine/fcmp.ll
+++ b/llvm/test/Transforms/InstCombine/fcmp.ll
@@ -736,6 +736,121 @@ define i1 @is_signbit_set_simplify_nan(double %x) {
   ret i1 %r
 }
 
+define i1 @test_oeq(float %a) {
+; CHECK-LABEL: @test_oeq(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[RES:%.*]] = call float @llvm.copysign.f32(float 1.000000e+00, float [[A:%.*]])
+; CHECK-NEXT:    [[CMP:%.*]] = fcmp oeq float [[RES]], 1.000000e+00
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+entry:
+  %res = call float @llvm.copysign.f32(float 1.0, float %a)
+  %cmp = fcmp oeq float %res, 1.0
+  ret i1 %cmp
+}
+
+define <2 x i1> @test_oeq_vec_splat(<2 x float> %a) {
+; CHECK-LABEL: @test_oeq_vec_splat(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[RES:%.*]] = call <2 x float> @llvm.copysign.v2f32(<2 x float> <float 1.000000e+00, float 1.000000e+00>, <2 x float> [[A:%.*]])
+; CHECK-NEXT:    [[TMP1:%.*]] = fcmp oeq <2 x float> [[RES]], <float 1.000000e+00, float 1.000000e+00>
+; CHECK-NEXT:    ret <2 x i1> [[TMP1]]
+;
+entry:
+  %res = call <2 x float> @llvm.copysign.v2f32(<2 x float> splat(float 1.0), <2 x float> %a)
+  %cmp = fcmp oeq <2 x float> %res, splat(float 1.0)
+  ret <2 x i1> %cmp
+}
+
+define <2 x i1> @test_oeq_vec_nonsplat(<2 x float> %a) {
+; CHECK-LABEL: @test_oeq_vec_nonsplat(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[RES:%.*]] = call <2 x float> @llvm.copysign.v2f32(<2 x float> <float 1.000000e+00, float 1.000000e+00>, <2 x float> [[A:%.*]])
+; CHECK-NEXT:    [[CMP:%.*]] = fcmp oeq <2 x float> [[RES]], <float 0.000000e+00, float 2.000000e+00>
+; CHECK-NEXT:    ret <2 x i1> [[CMP]]
+;
+entry:
+  %res = call <2 x float> @llvm.copysign.v2f32(<2 x float> splat(float 1.0), <2 x float> %a)
+  %cmp = fcmp oeq <2 x float> %res, <float 0.0, float 2.0>
+  ret <2 x i1> %cmp
+}
+
+define i1 @test_one(float %a) {
+; CHECK-LABEL: @test_one(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[RES:%.*]] = call float @llvm.copysign.f32(float 1.000000e+00, float [[A:%.*]])
+; CHECK-NEXT:    [[CMP:%.*]] = fcmp one float [[RES]], 1.000000e+00
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+entry:
+  %res = call float @llvm.copysign.f32(float 1.0, float %a)
+  %cmp = fcmp one float %res, 1.0
+  ret i1 %cmp
+}
+
+define i1 @test_ogt_false(float %a) {
+; CHECK-LABEL: @test_ogt_false(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[RES:%.*]] = call float @llvm.copysign.f32(float 1.000000e+00, float [[A:%.*]])
+; CHECK-NEXT:    [[CMP:%.*]] = fcmp ogt float [[RES]], 2.000000e+00
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+entry:
+  %res = call float @llvm.copysign.f32(float 1.0, float %a)
+  %cmp = fcmp ogt float %res, 2.0
+  ret i1 %cmp
+}
+
+define i1 @test_olt_true(float %a) {
+; CHECK-LABEL: @test_olt_true(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[RES:%.*]] = call float @llvm.copysign.f32(float 1.000000e+00, float [[A:%.*]])
+; CHECK-NEXT:    [[CMP:%.*]] = fcmp olt float [[RES]], 2.000000e+00
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+entry:
+  %res = call float @llvm.copysign.f32(float 1.0, float %a)
+  %cmp = fcmp olt float %res, 2.0
+  ret i1 %cmp
+}
+
+define i1 @test_oeq_nan_false(float %a) {
+; CHECK-LABEL: @test_oeq_nan_false(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[RES:%.*]] = call float @llvm.copysign.f32(float 0x7FF8000000000000, float [[A:%.*]])
+; CHECK-NEXT:    [[CMP:%.*]] = fcmp oeq float [[RES]], 1.000000e+00
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+entry:
+  %res = call float @llvm.copysign.f32(float 0x7FF8000000000000, float %a)
+  %cmp = fcmp oeq float %res, 1.0
+  ret i1 %cmp
+}
+
+define i1 @test_uno_nan_true(float %a) {
+; CHECK-LABEL: @test_uno_nan_true(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    ret i1 true
+;
+entry:
+  %res = call float @llvm.copysign.f32(float 0x7FF8000000000000, float %a)
+  %cmp = fcmp uno float %res, 1.0
+  ret i1 %cmp
+}
+
+define i1 @test_oge_zero_false(float %a) {
+; CHECK-LABEL: @test_oge_zero_false(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[RES:%.*]] = call float @llvm.copysign.f32(float 0.000000e+00, float [[A:%.*]])
+; CHECK-NEXT:    [[CMP:%.*]] = fcmp oge float [[RES]], 1.000000e+00
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+entry:
+  %res = call float @llvm.copysign.f32(float 0.0, float %a)
+  %cmp = fcmp oge float %res, 1.0
+  ret i1 %cmp
+}
+
 define <2 x i1> @lossy_oeq(<2 x float> %x) {
 ; CHECK-LABEL: @lossy_oeq(
 ; CHECK-NEXT:    ret <2 x i1> zeroinitializer

From 608580289dc78d23de8f19a8558d9e6edf2f2b2b Mon Sep 17 00:00:00 2001
From: Yingwei Zheng <dtcxzyw2333 at gmail.com>
Date: Thu, 11 Apr 2024 21:57:06 +0800
Subject: [PATCH 2/2] [InstCombine] Generalize fold of `fcmp + copysign`

---
 .../InstCombine/InstCombineCompares.cpp       | 22 +++---
 llvm/test/Transforms/InstCombine/fcmp.ll      | 72 +++++++------------
 2 files changed, 39 insertions(+), 55 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 9ff1e3aa5502e6..018c6b699b4222 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -8117,22 +8117,24 @@ Instruction *InstCombinerImpl::visitFCmpInst(FCmpInst &I) {
   }
 
   // Convert a sign-bit test of an FP value into a cast and integer compare.
-  // TODO: Simplify if the copysign constant is 0.0 or NaN.
-  // TODO: Handle non-zero compare constants.
-  // TODO: Handle other predicates.
+  // fcmp pred copysign(C1, X), C2 ->
+  // select !signbit(X), (fcmp pred abs(C1), C2), (fcmp pred nabs(C1), C2)
   if (match(Op0, m_OneUse(m_Intrinsic<Intrinsic::copysign>(m_APFloat(C),
                                                            m_Value(X)))) &&
-      match(Op1, m_AnyZeroFP()) && !C->isZero() && !C->isNaN()) {
+      match(Op1, m_ImmConstant(RHSC))) {
     Type *IntType = Builder.getIntNTy(X->getType()->getScalarSizeInBits());
     if (auto *VecTy = dyn_cast<VectorType>(OpType))
       IntType = VectorType::get(IntType, VecTy->getElementCount());
 
-    // copysign(non-zero constant, X) < 0.0 --> (bitcast X) < 0
-    if (Pred == FCmpInst::FCMP_OLT) {
-      Value *IntX = Builder.CreateBitCast(X, IntType);
-      return new ICmpInst(ICmpInst::ICMP_SLT, IntX,
-                          ConstantInt::getNullValue(IntType));
-    }
+    APFloat PosC = abs(*C);
+    if (Value *CmpPos = ConstantFoldCompareInstOperands(
+            Pred, ConstantFP::get(X->getType(), PosC), RHSC, DL, &TLI, &I))
+      if (Value *CmpNeg = ConstantFoldCompareInstOperands(
+              Pred, ConstantFP::get(X->getType(), -PosC), RHSC, DL, &TLI, &I)) {
+        Value *IntX = Builder.CreateBitCast(X, IntType);
+        Value *NotNeg = Builder.CreateIsNotNeg(IntX);
+        return SelectInst::Create(NotNeg, CmpPos, CmpNeg);
+      }
   }
 
   {
diff --git a/llvm/test/Transforms/InstCombine/fcmp.ll b/llvm/test/Transforms/InstCombine/fcmp.ll
index 3e7c6c229f56cc..65de93e0f15c3f 100644
--- a/llvm/test/Transforms/InstCombine/fcmp.ll
+++ b/llvm/test/Transforms/InstCombine/fcmp.ll
@@ -596,8 +596,8 @@ define i1 @is_signbit_set(double %x) {
 
 define i1 @is_signbit_set_1(double %x) {
 ; CHECK-LABEL: @is_signbit_set_1(
-; CHECK-NEXT:    [[S:%.*]] = call double @llvm.copysign.f64(double 1.000000e+00, double [[X:%.*]])
-; CHECK-NEXT:    [[R:%.*]] = fcmp ult double [[S]], 0.000000e+00
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast double [[X:%.*]] to i64
+; CHECK-NEXT:    [[R:%.*]] = icmp slt i64 [[TMP1]], 0
 ; CHECK-NEXT:    ret i1 [[R]]
 ;
   %s = call double @llvm.copysign.f64(double 1.0, double %x)
@@ -607,8 +607,8 @@ define i1 @is_signbit_set_1(double %x) {
 
 define i1 @is_signbit_set_2(double %x) {
 ; CHECK-LABEL: @is_signbit_set_2(
-; CHECK-NEXT:    [[S:%.*]] = call double @llvm.copysign.f64(double 1.000000e+00, double [[X:%.*]])
-; CHECK-NEXT:    [[R:%.*]] = fcmp ole double [[S]], 0.000000e+00
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast double [[X:%.*]] to i64
+; CHECK-NEXT:    [[R:%.*]] = icmp slt i64 [[TMP1]], 0
 ; CHECK-NEXT:    ret i1 [[R]]
 ;
   %s = call double @llvm.copysign.f64(double 1.0, double %x)
@@ -618,8 +618,8 @@ define i1 @is_signbit_set_2(double %x) {
 
 define i1 @is_signbit_set_3(double %x) {
 ; CHECK-LABEL: @is_signbit_set_3(
-; CHECK-NEXT:    [[S:%.*]] = call double @llvm.copysign.f64(double 1.000000e+00, double [[X:%.*]])
-; CHECK-NEXT:    [[R:%.*]] = fcmp ule double [[S]], 0.000000e+00
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast double [[X:%.*]] to i64
+; CHECK-NEXT:    [[R:%.*]] = icmp slt i64 [[TMP1]], 0
 ; CHECK-NEXT:    ret i1 [[R]]
 ;
   %s = call double @llvm.copysign.f64(double 1.0, double %x)
@@ -640,12 +640,10 @@ define <2 x i1> @is_signbit_set_anyzero(<2 x double> %x) {
   ret <2 x i1> %r
 }
 
-; TODO: Handle different predicates.
-
 define i1 @is_signbit_clear(double %x) {
 ; CHECK-LABEL: @is_signbit_clear(
-; CHECK-NEXT:    [[S:%.*]] = call double @llvm.copysign.f64(double 4.200000e+01, double [[X:%.*]])
-; CHECK-NEXT:    [[R:%.*]] = fcmp ogt double [[S]], 0.000000e+00
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast double [[X:%.*]] to i64
+; CHECK-NEXT:    [[R:%.*]] = icmp sgt i64 [[TMP1]], -1
 ; CHECK-NEXT:    ret i1 [[R]]
 ;
   %s = call double @llvm.copysign.f64(double -42.0, double %x)
@@ -655,8 +653,8 @@ define i1 @is_signbit_clear(double %x) {
 
 define i1 @is_signbit_clear_1(double %x) {
 ; CHECK-LABEL: @is_signbit_clear_1(
-; CHECK-NEXT:    [[S:%.*]] = call double @llvm.copysign.f64(double 4.200000e+01, double [[X:%.*]])
-; CHECK-NEXT:    [[R:%.*]] = fcmp ugt double [[S]], 0.000000e+00
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast double [[X:%.*]] to i64
+; CHECK-NEXT:    [[R:%.*]] = icmp sgt i64 [[TMP1]], -1
 ; CHECK-NEXT:    ret i1 [[R]]
 ;
   %s = call double @llvm.copysign.f64(double -42.0, double %x)
@@ -666,8 +664,8 @@ define i1 @is_signbit_clear_1(double %x) {
 
 define i1 @is_signbit_clear_2(double %x) {
 ; CHECK-LABEL: @is_signbit_clear_2(
-; CHECK-NEXT:    [[S:%.*]] = call double @llvm.copysign.f64(double 4.200000e+01, double [[X:%.*]])
-; CHECK-NEXT:    [[R:%.*]] = fcmp oge double [[S]], 0.000000e+00
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast double [[X:%.*]] to i64
+; CHECK-NEXT:    [[R:%.*]] = icmp sgt i64 [[TMP1]], -1
 ; CHECK-NEXT:    ret i1 [[R]]
 ;
   %s = call double @llvm.copysign.f64(double -42.0, double %x)
@@ -677,8 +675,8 @@ define i1 @is_signbit_clear_2(double %x) {
 
 define i1 @is_signbit_clear_3(double %x) {
 ; CHECK-LABEL: @is_signbit_clear_3(
-; CHECK-NEXT:    [[S:%.*]] = call double @llvm.copysign.f64(double 4.200000e+01, double [[X:%.*]])
-; CHECK-NEXT:    [[R:%.*]] = fcmp uge double [[S]], 0.000000e+00
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast double [[X:%.*]] to i64
+; CHECK-NEXT:    [[R:%.*]] = icmp sgt i64 [[TMP1]], -1
 ; CHECK-NEXT:    ret i1 [[R]]
 ;
   %s = call double @llvm.copysign.f64(double -42.0, double %x)
@@ -701,12 +699,10 @@ define i1 @is_signbit_set_extra_use(double %x, ptr %p) {
   ret i1 %r
 }
 
-; TODO: Handle non-zero compare constant.
-
 define i1 @is_signbit_clear_nonzero(double %x) {
 ; CHECK-LABEL: @is_signbit_clear_nonzero(
-; CHECK-NEXT:    [[S:%.*]] = call double @llvm.copysign.f64(double 4.200000e+01, double [[X:%.*]])
-; CHECK-NEXT:    [[R:%.*]] = fcmp ogt double [[S]], 1.000000e+00
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast double [[X:%.*]] to i64
+; CHECK-NEXT:    [[R:%.*]] = icmp sgt i64 [[TMP1]], -1
 ; CHECK-NEXT:    ret i1 [[R]]
 ;
   %s = call double @llvm.copysign.f64(double -42.0, double %x)
@@ -714,8 +710,6 @@ define i1 @is_signbit_clear_nonzero(double %x) {
   ret i1 %r
 }
 
-; TODO: Handle zero copysign constant.
-
 define i1 @is_signbit_set_simplify_zero(double %x) {
 ; CHECK-LABEL: @is_signbit_set_simplify_zero(
 ; CHECK-NEXT:    ret i1 false
@@ -725,8 +719,6 @@ define i1 @is_signbit_set_simplify_zero(double %x) {
   ret i1 %r
 }
 
-; TODO: Handle NaN copysign constant.
-
 define i1 @is_signbit_set_simplify_nan(double %x) {
 ; CHECK-LABEL: @is_signbit_set_simplify_nan(
 ; CHECK-NEXT:    ret i1 false
@@ -739,8 +731,8 @@ define i1 @is_signbit_set_simplify_nan(double %x) {
 define i1 @test_oeq(float %a) {
 ; CHECK-LABEL: @test_oeq(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[RES:%.*]] = call float @llvm.copysign.f32(float 1.000000e+00, float [[A:%.*]])
-; CHECK-NEXT:    [[CMP:%.*]] = fcmp oeq float [[RES]], 1.000000e+00
+; CHECK-NEXT:    [[TMP0:%.*]] = bitcast float [[A:%.*]] to i32
+; CHECK-NEXT:    [[CMP:%.*]] = icmp sgt i32 [[TMP0]], -1
 ; CHECK-NEXT:    ret i1 [[CMP]]
 ;
 entry:
@@ -752,8 +744,8 @@ entry:
 define <2 x i1> @test_oeq_vec_splat(<2 x float> %a) {
 ; CHECK-LABEL: @test_oeq_vec_splat(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[RES:%.*]] = call <2 x float> @llvm.copysign.v2f32(<2 x float> <float 1.000000e+00, float 1.000000e+00>, <2 x float> [[A:%.*]])
-; CHECK-NEXT:    [[TMP1:%.*]] = fcmp oeq <2 x float> [[RES]], <float 1.000000e+00, float 1.000000e+00>
+; CHECK-NEXT:    [[TMP0:%.*]] = bitcast <2 x float> [[A:%.*]] to <2 x i32>
+; CHECK-NEXT:    [[TMP1:%.*]] = icmp sgt <2 x i32> [[TMP0]], <i32 -1, i32 -1>
 ; CHECK-NEXT:    ret <2 x i1> [[TMP1]]
 ;
 entry:
@@ -765,9 +757,7 @@ entry:
 define <2 x i1> @test_oeq_vec_nonsplat(<2 x float> %a) {
 ; CHECK-LABEL: @test_oeq_vec_nonsplat(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[RES:%.*]] = call <2 x float> @llvm.copysign.v2f32(<2 x float> <float 1.000000e+00, float 1.000000e+00>, <2 x float> [[A:%.*]])
-; CHECK-NEXT:    [[CMP:%.*]] = fcmp oeq <2 x float> [[RES]], <float 0.000000e+00, float 2.000000e+00>
-; CHECK-NEXT:    ret <2 x i1> [[CMP]]
+; CHECK-NEXT:    ret <2 x i1> zeroinitializer
 ;
 entry:
   %res = call <2 x float> @llvm.copysign.v2f32(<2 x float> splat(float 1.0), <2 x float> %a)
@@ -778,8 +768,8 @@ entry:
 define i1 @test_one(float %a) {
 ; CHECK-LABEL: @test_one(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[RES:%.*]] = call float @llvm.copysign.f32(float 1.000000e+00, float [[A:%.*]])
-; CHECK-NEXT:    [[CMP:%.*]] = fcmp one float [[RES]], 1.000000e+00
+; CHECK-NEXT:    [[TMP0:%.*]] = bitcast float [[A:%.*]] to i32
+; CHECK-NEXT:    [[CMP:%.*]] = icmp slt i32 [[TMP0]], 0
 ; CHECK-NEXT:    ret i1 [[CMP]]
 ;
 entry:
@@ -791,9 +781,7 @@ entry:
 define i1 @test_ogt_false(float %a) {
 ; CHECK-LABEL: @test_ogt_false(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[RES:%.*]] = call float @llvm.copysign.f32(float 1.000000e+00, float [[A:%.*]])
-; CHECK-NEXT:    [[CMP:%.*]] = fcmp ogt float [[RES]], 2.000000e+00
-; CHECK-NEXT:    ret i1 [[CMP]]
+; CHECK-NEXT:    ret i1 false
 ;
 entry:
   %res = call float @llvm.copysign.f32(float 1.0, float %a)
@@ -804,9 +792,7 @@ entry:
 define i1 @test_olt_true(float %a) {
 ; CHECK-LABEL: @test_olt_true(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[RES:%.*]] = call float @llvm.copysign.f32(float 1.000000e+00, float [[A:%.*]])
-; CHECK-NEXT:    [[CMP:%.*]] = fcmp olt float [[RES]], 2.000000e+00
-; CHECK-NEXT:    ret i1 [[CMP]]
+; CHECK-NEXT:    ret i1 true
 ;
 entry:
   %res = call float @llvm.copysign.f32(float 1.0, float %a)
@@ -817,9 +803,7 @@ entry:
 define i1 @test_oeq_nan_false(float %a) {
 ; CHECK-LABEL: @test_oeq_nan_false(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[RES:%.*]] = call float @llvm.copysign.f32(float 0x7FF8000000000000, float [[A:%.*]])
-; CHECK-NEXT:    [[CMP:%.*]] = fcmp oeq float [[RES]], 1.000000e+00
-; CHECK-NEXT:    ret i1 [[CMP]]
+; CHECK-NEXT:    ret i1 false
 ;
 entry:
   %res = call float @llvm.copysign.f32(float 0x7FF8000000000000, float %a)
@@ -841,9 +825,7 @@ entry:
 define i1 @test_oge_zero_false(float %a) {
 ; CHECK-LABEL: @test_oge_zero_false(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[RES:%.*]] = call float @llvm.copysign.f32(float 0.000000e+00, float [[A:%.*]])
-; CHECK-NEXT:    [[CMP:%.*]] = fcmp oge float [[RES]], 1.000000e+00
-; CHECK-NEXT:    ret i1 [[CMP]]
+; CHECK-NEXT:    ret i1 false
 ;
 entry:
   %res = call float @llvm.copysign.f32(float 0.0, float %a)



More information about the llvm-commits mailing list