[llvm] [InstCombine] Fold `fcmp pred sqrt(X), 0.0 -> fcmp pred2 X, 0.0` (PR #101626)

Yingwei Zheng via llvm-commits llvm-commits at lists.llvm.org
Fri Aug 2 09:54:06 PDT 2024


https://github.com/dtcxzyw updated https://github.com/llvm/llvm-project/pull/101626

>From 96ca3f3c8639141a48fa7116fe8527e702469e96 Mon Sep 17 00:00:00 2001
From: Yingwei Zheng <dtcxzyw2333 at gmail.com>
Date: Fri, 2 Aug 2024 15:04:36 +0800
Subject: [PATCH 1/4] [InstCombine] Add pre-commit tests. NFC.

---
 llvm/test/Transforms/InstCombine/fcmp.ll | 185 +++++++++++++++++++++++
 1 file changed, 185 insertions(+)

diff --git a/llvm/test/Transforms/InstCombine/fcmp.ll b/llvm/test/Transforms/InstCombine/fcmp.ll
index 656b3d2c49206..aafb1e25e13a5 100644
--- a/llvm/test/Transforms/InstCombine/fcmp.ll
+++ b/llvm/test/Transforms/InstCombine/fcmp.ll
@@ -2117,3 +2117,188 @@ define <8 x i1> @fcmp_ogt_fsub_const_vec_denormal_preserve-sign(<8 x float> %x,
   %cmp = fcmp ogt <8 x float> %fs, zeroinitializer
   ret <8 x i1> %cmp
 }
+
+define i1 @fcmp_sqrt_zero_olt(half %x) {
+; CHECK-LABEL: @fcmp_sqrt_zero_olt(
+; CHECK-NEXT:    ret i1 false
+;
+  %sqrt = call half @llvm.sqrt(half %x)
+  %cmp = fcmp olt half %sqrt, 0.0
+  ret i1 %cmp
+}
+
+define i1 @fcmp_sqrt_zero_ult(half %x) {
+; CHECK-LABEL: @fcmp_sqrt_zero_ult(
+; CHECK-NEXT:    [[SQRT:%.*]] = call half @llvm.sqrt.f16(half [[X:%.*]])
+; CHECK-NEXT:    [[CMP:%.*]] = fcmp ult half [[SQRT]], 0xH0000
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %sqrt = call half @llvm.sqrt(half %x)
+  %cmp = fcmp ult half %sqrt, 0.0
+  ret i1 %cmp
+}
+
+define <2 x i1> @fcmp_sqrt_zero_ult_vec(<2 x half> %x) {
+; CHECK-LABEL: @fcmp_sqrt_zero_ult_vec(
+; CHECK-NEXT:    [[SQRT:%.*]] = call <2 x half> @llvm.sqrt.v2f16(<2 x half> [[X:%.*]])
+; CHECK-NEXT:    [[CMP:%.*]] = fcmp ult <2 x half> [[SQRT]], zeroinitializer
+; CHECK-NEXT:    ret <2 x i1> [[CMP]]
+;
+  %sqrt = call <2 x half> @llvm.sqrt(<2 x half> %x)
+  %cmp = fcmp ult <2 x half> %sqrt, zeroinitializer
+  ret <2 x i1> %cmp
+}
+
+define i1 @fcmp_sqrt_zero_ole(half %x) {
+; CHECK-LABEL: @fcmp_sqrt_zero_ole(
+; CHECK-NEXT:    [[SQRT:%.*]] = call half @llvm.sqrt.f16(half [[X:%.*]])
+; CHECK-NEXT:    [[CMP:%.*]] = fcmp ole half [[SQRT]], 0xH0000
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %sqrt = call half @llvm.sqrt(half %x)
+  %cmp = fcmp ole half %sqrt, 0.0
+  ret i1 %cmp
+}
+
+define i1 @fcmp_sqrt_zero_ule(half %x) {
+; CHECK-LABEL: @fcmp_sqrt_zero_ule(
+; CHECK-NEXT:    [[SQRT:%.*]] = call half @llvm.sqrt.f16(half [[X:%.*]])
+; CHECK-NEXT:    [[CMP:%.*]] = fcmp ule half [[SQRT]], 0xH0000
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %sqrt = call half @llvm.sqrt(half %x)
+  %cmp = fcmp ule half %sqrt, 0.0
+  ret i1 %cmp
+}
+
+define i1 @fcmp_sqrt_zero_ogt(half %x) {
+; CHECK-LABEL: @fcmp_sqrt_zero_ogt(
+; CHECK-NEXT:    [[SQRT:%.*]] = call half @llvm.sqrt.f16(half [[X:%.*]])
+; CHECK-NEXT:    [[CMP:%.*]] = fcmp ogt half [[SQRT]], 0xH0000
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %sqrt = call half @llvm.sqrt(half %x)
+  %cmp = fcmp ogt half %sqrt, 0.0
+  ret i1 %cmp
+}
+
+define i1 @fcmp_sqrt_zero_ugt(half %x) {
+; CHECK-LABEL: @fcmp_sqrt_zero_ugt(
+; CHECK-NEXT:    [[SQRT:%.*]] = call half @llvm.sqrt.f16(half [[X:%.*]])
+; CHECK-NEXT:    [[CMP:%.*]] = fcmp ugt half [[SQRT]], 0xH0000
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %sqrt = call half @llvm.sqrt(half %x)
+  %cmp = fcmp ugt half %sqrt, 0.0
+  ret i1 %cmp
+}
+
+define i1 @fcmp_sqrt_zero_oge(half %x) {
+; CHECK-LABEL: @fcmp_sqrt_zero_oge(
+; CHECK-NEXT:    [[SQRT:%.*]] = call half @llvm.sqrt.f16(half [[X:%.*]])
+; CHECK-NEXT:    [[CMP:%.*]] = fcmp oge half [[SQRT]], 0xH0000
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %sqrt = call half @llvm.sqrt(half %x)
+  %cmp = fcmp oge half %sqrt, 0.0
+  ret i1 %cmp
+}
+
+define i1 @fcmp_sqrt_zero_uge(half %x) {
+; CHECK-LABEL: @fcmp_sqrt_zero_uge(
+; CHECK-NEXT:    ret i1 true
+;
+  %sqrt = call half @llvm.sqrt(half %x)
+  %cmp = fcmp uge half %sqrt, 0.0
+  ret i1 %cmp
+}
+
+define i1 @fcmp_sqrt_zero_oeq(half %x) {
+; CHECK-LABEL: @fcmp_sqrt_zero_oeq(
+; CHECK-NEXT:    [[SQRT:%.*]] = call half @llvm.sqrt.f16(half [[X:%.*]])
+; CHECK-NEXT:    [[CMP:%.*]] = fcmp oeq half [[SQRT]], 0xH0000
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %sqrt = call half @llvm.sqrt(half %x)
+  %cmp = fcmp oeq half %sqrt, 0.0
+  ret i1 %cmp
+}
+
+define i1 @fcmp_sqrt_zero_ueq(half %x) {
+; CHECK-LABEL: @fcmp_sqrt_zero_ueq(
+; CHECK-NEXT:    [[SQRT:%.*]] = call half @llvm.sqrt.f16(half [[X:%.*]])
+; CHECK-NEXT:    [[CMP:%.*]] = fcmp ueq half [[SQRT]], 0xH0000
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %sqrt = call half @llvm.sqrt(half %x)
+  %cmp = fcmp ueq half %sqrt, 0.0
+  ret i1 %cmp
+}
+
+define i1 @fcmp_sqrt_zero_one(half %x) {
+; CHECK-LABEL: @fcmp_sqrt_zero_one(
+; CHECK-NEXT:    [[SQRT:%.*]] = call half @llvm.sqrt.f16(half [[X:%.*]])
+; CHECK-NEXT:    [[CMP:%.*]] = fcmp one half [[SQRT]], 0xH0000
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %sqrt = call half @llvm.sqrt(half %x)
+  %cmp = fcmp one half %sqrt, 0.0
+  ret i1 %cmp
+}
+
+define i1 @fcmp_sqrt_zero_une(half %x) {
+; CHECK-LABEL: @fcmp_sqrt_zero_une(
+; CHECK-NEXT:    [[SQRT:%.*]] = call half @llvm.sqrt.f16(half [[X:%.*]])
+; CHECK-NEXT:    [[CMP:%.*]] = fcmp une half [[SQRT]], 0xH0000
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %sqrt = call half @llvm.sqrt(half %x)
+  %cmp = fcmp une half %sqrt, 0.0
+  ret i1 %cmp
+}
+
+define i1 @fcmp_sqrt_zero_ord(half %x) {
+; CHECK-LABEL: @fcmp_sqrt_zero_ord(
+; CHECK-NEXT:    [[SQRT:%.*]] = call half @llvm.sqrt.f16(half [[X:%.*]])
+; CHECK-NEXT:    [[CMP:%.*]] = fcmp ord half [[SQRT]], 0xH0000
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %sqrt = call half @llvm.sqrt(half %x)
+  %cmp = fcmp ord half %sqrt, 0.0
+  ret i1 %cmp
+}
+
+define i1 @fcmp_sqrt_zero_uno(half %x) {
+; CHECK-LABEL: @fcmp_sqrt_zero_uno(
+; CHECK-NEXT:    [[SQRT:%.*]] = call half @llvm.sqrt.f16(half [[X:%.*]])
+; CHECK-NEXT:    [[CMP:%.*]] = fcmp uno half [[SQRT]], 0xH0000
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %sqrt = call half @llvm.sqrt(half %x)
+  %cmp = fcmp uno half %sqrt, 0.0
+  ret i1 %cmp
+}
+
+; negative tests
+
+define i1 @fcmp_sqrt_zero_ult_var(half %x, half %y) {
+; CHECK-LABEL: @fcmp_sqrt_zero_ult_var(
+; CHECK-NEXT:    [[SQRT:%.*]] = call half @llvm.sqrt.f16(half [[X:%.*]])
+; CHECK-NEXT:    [[CMP:%.*]] = fcmp ult half [[SQRT]], [[Y:%.*]]
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %sqrt = call half @llvm.sqrt(half %x)
+  %cmp = fcmp ult half %sqrt, %y
+  ret i1 %cmp
+}
+
+define i1 @fcmp_sqrt_zero_ult_nonzero(half %x) {
+; CHECK-LABEL: @fcmp_sqrt_zero_ult_nonzero(
+; CHECK-NEXT:    [[SQRT:%.*]] = call half @llvm.sqrt.f16(half [[X:%.*]])
+; CHECK-NEXT:    [[CMP:%.*]] = fcmp ult half [[SQRT]], 0xH3C00
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %sqrt = call half @llvm.sqrt(half %x)
+  %cmp = fcmp ult half %sqrt, 1.000000e+00
+  ret i1 %cmp
+}

>From ab676809fcb52a69044534c3b914462c0fee2a86 Mon Sep 17 00:00:00 2001
From: Yingwei Zheng <dtcxzyw2333 at gmail.com>
Date: Fri, 2 Aug 2024 15:12:44 +0800
Subject: [PATCH 2/4] [InstCombine] Fold `fcmp pred sqrt(X), 0.0 -> fcmp pred2
 X, 0.0`

---
 .../InstCombine/InstCombineCompares.cpp       | 60 +++++++++++++++++++
 llvm/test/Transforms/InstCombine/fcmp.ll      | 39 ++++--------
 .../Transforms/InstCombine/known-never-nan.ll |  4 +-
 3 files changed, 74 insertions(+), 29 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 3b6df2760ecc2..622e7a420dd95 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -7980,6 +7980,63 @@ static Instruction *foldFabsWithFcmpZero(FCmpInst &I, InstCombinerImpl &IC) {
   }
 }
 
+/// Fold fcmp pred (sqrt X), 0.0 --> fcmp pred2 X, 0.0; nullptr if no match.
+static Instruction *foldSqrtWithFcmpZero(FCmpInst &I, InstCombinerImpl &IC) {
+  Value *X;
+  if (!match(I.getOperand(0), m_Sqrt(m_Value(X))))
+    return nullptr; // LHS is not sqrt(X).
+
+  if (!match(I.getOperand(1), m_PosZeroFP()))
+    return nullptr; // RHS is not +0.0.
+
+  auto ReplacePredAndOp0 = [&](FCmpInst::Predicate P) {
+    I.setPredicate(P); // I is rewritten in place, not recreated.
+    return IC.replaceOperand(I, 0, X);
+  };
+
+  switch (I.getPredicate()) {
+  case FCmpInst::FCMP_OLT:
+  case FCmpInst::FCMP_UGE:
+    // sqrt(X) < 0.0 --> false
+    // sqrt(X) u>= 0.0 --> true
+    llvm_unreachable("fcmp should have simplified");
+  case FCmpInst::FCMP_ULT:
+  case FCmpInst::FCMP_ULE:
+  case FCmpInst::FCMP_OGT:
+  case FCmpInst::FCMP_OGE:
+  case FCmpInst::FCMP_OEQ:
+  case FCmpInst::FCMP_UNE:
+    // sqrt(X) u< 0.0 --> X u< 0.0
+    // sqrt(X) u<= 0.0 --> X u<= 0.0
+    // sqrt(X) > 0.0 --> X > 0.0
+    // sqrt(X) >= 0.0 --> X >= 0.0
+    // sqrt(X) == 0.0 --> X == 0.0
+    // sqrt(X) u!= 0.0 --> X u!= 0.0
+    return IC.replaceOperand(I, 0, X); // Predicate is kept as-is.
+
+  case FCmpInst::FCMP_OLE:
+    // sqrt(X) <= 0.0 --> X == 0.0
+    return ReplacePredAndOp0(FCmpInst::FCMP_OEQ);
+  case FCmpInst::FCMP_UGT:
+    // sqrt(X) u> 0.0 --> X u!= 0.0
+    return ReplacePredAndOp0(FCmpInst::FCMP_UNE);
+  case FCmpInst::FCMP_UEQ:
+    // sqrt(X) u== 0.0 --> X u<= 0.0
+    return ReplacePredAndOp0(FCmpInst::FCMP_ULE);
+  case FCmpInst::FCMP_ONE:
+    // sqrt(X) != 0.0 --> X > 0.0
+    return ReplacePredAndOp0(FCmpInst::FCMP_OGT);
+  case FCmpInst::FCMP_ORD:
+    // !isnan(sqrt(X)) --> X >= 0.0
+    return ReplacePredAndOp0(FCmpInst::FCMP_OGE);
+  case FCmpInst::FCMP_UNO:
+    // isnan(sqrt(X)) --> X u< 0.0
+    return ReplacePredAndOp0(FCmpInst::FCMP_ULT);
+  default: // Only FCMP_FALSE and FCMP_TRUE are not listed above.
+    llvm_unreachable("Unexpected predicate!");
+  }
+}
+
 static Instruction *foldFCmpFNegCommonOp(FCmpInst &I) {
   CmpInst::Predicate Pred = I.getPredicate();
   Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
@@ -8247,6 +8304,9 @@ Instruction *InstCombinerImpl::visitFCmpInst(FCmpInst &I) {
   if (Instruction *R = foldFabsWithFcmpZero(I, *this))
     return R;
 
+  if (Instruction *R = foldSqrtWithFcmpZero(I, *this))
+    return R;
+
   if (match(Op0, m_FNeg(m_Value(X)))) {
     // fcmp pred (fneg X), C --> fcmp swap(pred) X, -C
     Constant *C;
diff --git a/llvm/test/Transforms/InstCombine/fcmp.ll b/llvm/test/Transforms/InstCombine/fcmp.ll
index aafb1e25e13a5..3ea93149094c8 100644
--- a/llvm/test/Transforms/InstCombine/fcmp.ll
+++ b/llvm/test/Transforms/InstCombine/fcmp.ll
@@ -2129,8 +2129,7 @@ define i1 @fcmp_sqrt_zero_olt(half %x) {
 
 define i1 @fcmp_sqrt_zero_ult(half %x) {
 ; CHECK-LABEL: @fcmp_sqrt_zero_ult(
-; CHECK-NEXT:    [[SQRT:%.*]] = call half @llvm.sqrt.f16(half [[X:%.*]])
-; CHECK-NEXT:    [[CMP:%.*]] = fcmp ult half [[SQRT]], 0xH0000
+; CHECK-NEXT:    [[CMP:%.*]] = fcmp ult half [[X:%.*]], 0xH0000
 ; CHECK-NEXT:    ret i1 [[CMP]]
 ;
   %sqrt = call half @llvm.sqrt(half %x)
@@ -2140,8 +2139,7 @@ define i1 @fcmp_sqrt_zero_ult(half %x) {
 
 define <2 x i1> @fcmp_sqrt_zero_ult_vec(<2 x half> %x) {
 ; CHECK-LABEL: @fcmp_sqrt_zero_ult_vec(
-; CHECK-NEXT:    [[SQRT:%.*]] = call <2 x half> @llvm.sqrt.v2f16(<2 x half> [[X:%.*]])
-; CHECK-NEXT:    [[CMP:%.*]] = fcmp ult <2 x half> [[SQRT]], zeroinitializer
+; CHECK-NEXT:    [[CMP:%.*]] = fcmp ult <2 x half> [[X:%.*]], zeroinitializer
 ; CHECK-NEXT:    ret <2 x i1> [[CMP]]
 ;
   %sqrt = call <2 x half> @llvm.sqrt(<2 x half> %x)
@@ -2151,8 +2149,7 @@ define <2 x i1> @fcmp_sqrt_zero_ult_vec(<2 x half> %x) {
 
 define i1 @fcmp_sqrt_zero_ole(half %x) {
 ; CHECK-LABEL: @fcmp_sqrt_zero_ole(
-; CHECK-NEXT:    [[SQRT:%.*]] = call half @llvm.sqrt.f16(half [[X:%.*]])
-; CHECK-NEXT:    [[CMP:%.*]] = fcmp ole half [[SQRT]], 0xH0000
+; CHECK-NEXT:    [[CMP:%.*]] = fcmp oeq half [[X:%.*]], 0xH0000
 ; CHECK-NEXT:    ret i1 [[CMP]]
 ;
   %sqrt = call half @llvm.sqrt(half %x)
@@ -2162,8 +2159,7 @@ define i1 @fcmp_sqrt_zero_ole(half %x) {
 
 define i1 @fcmp_sqrt_zero_ule(half %x) {
 ; CHECK-LABEL: @fcmp_sqrt_zero_ule(
-; CHECK-NEXT:    [[SQRT:%.*]] = call half @llvm.sqrt.f16(half [[X:%.*]])
-; CHECK-NEXT:    [[CMP:%.*]] = fcmp ule half [[SQRT]], 0xH0000
+; CHECK-NEXT:    [[CMP:%.*]] = fcmp ule half [[X:%.*]], 0xH0000
 ; CHECK-NEXT:    ret i1 [[CMP]]
 ;
   %sqrt = call half @llvm.sqrt(half %x)
@@ -2173,8 +2169,7 @@ define i1 @fcmp_sqrt_zero_ule(half %x) {
 
 define i1 @fcmp_sqrt_zero_ogt(half %x) {
 ; CHECK-LABEL: @fcmp_sqrt_zero_ogt(
-; CHECK-NEXT:    [[SQRT:%.*]] = call half @llvm.sqrt.f16(half [[X:%.*]])
-; CHECK-NEXT:    [[CMP:%.*]] = fcmp ogt half [[SQRT]], 0xH0000
+; CHECK-NEXT:    [[CMP:%.*]] = fcmp ogt half [[X:%.*]], 0xH0000
 ; CHECK-NEXT:    ret i1 [[CMP]]
 ;
   %sqrt = call half @llvm.sqrt(half %x)
@@ -2184,8 +2179,7 @@ define i1 @fcmp_sqrt_zero_ogt(half %x) {
 
 define i1 @fcmp_sqrt_zero_ugt(half %x) {
 ; CHECK-LABEL: @fcmp_sqrt_zero_ugt(
-; CHECK-NEXT:    [[SQRT:%.*]] = call half @llvm.sqrt.f16(half [[X:%.*]])
-; CHECK-NEXT:    [[CMP:%.*]] = fcmp ugt half [[SQRT]], 0xH0000
+; CHECK-NEXT:    [[CMP:%.*]] = fcmp une half [[X:%.*]], 0xH0000
 ; CHECK-NEXT:    ret i1 [[CMP]]
 ;
   %sqrt = call half @llvm.sqrt(half %x)
@@ -2195,8 +2189,7 @@ define i1 @fcmp_sqrt_zero_ugt(half %x) {
 
 define i1 @fcmp_sqrt_zero_oge(half %x) {
 ; CHECK-LABEL: @fcmp_sqrt_zero_oge(
-; CHECK-NEXT:    [[SQRT:%.*]] = call half @llvm.sqrt.f16(half [[X:%.*]])
-; CHECK-NEXT:    [[CMP:%.*]] = fcmp oge half [[SQRT]], 0xH0000
+; CHECK-NEXT:    [[CMP:%.*]] = fcmp oge half [[X:%.*]], 0xH0000
 ; CHECK-NEXT:    ret i1 [[CMP]]
 ;
   %sqrt = call half @llvm.sqrt(half %x)
@@ -2215,8 +2208,7 @@ define i1 @fcmp_sqrt_zero_uge(half %x) {
 
 define i1 @fcmp_sqrt_zero_oeq(half %x) {
 ; CHECK-LABEL: @fcmp_sqrt_zero_oeq(
-; CHECK-NEXT:    [[SQRT:%.*]] = call half @llvm.sqrt.f16(half [[X:%.*]])
-; CHECK-NEXT:    [[CMP:%.*]] = fcmp oeq half [[SQRT]], 0xH0000
+; CHECK-NEXT:    [[CMP:%.*]] = fcmp oeq half [[X:%.*]], 0xH0000
 ; CHECK-NEXT:    ret i1 [[CMP]]
 ;
   %sqrt = call half @llvm.sqrt(half %x)
@@ -2226,8 +2218,7 @@ define i1 @fcmp_sqrt_zero_oeq(half %x) {
 
 define i1 @fcmp_sqrt_zero_ueq(half %x) {
 ; CHECK-LABEL: @fcmp_sqrt_zero_ueq(
-; CHECK-NEXT:    [[SQRT:%.*]] = call half @llvm.sqrt.f16(half [[X:%.*]])
-; CHECK-NEXT:    [[CMP:%.*]] = fcmp ueq half [[SQRT]], 0xH0000
+; CHECK-NEXT:    [[CMP:%.*]] = fcmp ule half [[X:%.*]], 0xH0000
 ; CHECK-NEXT:    ret i1 [[CMP]]
 ;
   %sqrt = call half @llvm.sqrt(half %x)
@@ -2237,8 +2228,7 @@ define i1 @fcmp_sqrt_zero_ueq(half %x) {
 
 define i1 @fcmp_sqrt_zero_one(half %x) {
 ; CHECK-LABEL: @fcmp_sqrt_zero_one(
-; CHECK-NEXT:    [[SQRT:%.*]] = call half @llvm.sqrt.f16(half [[X:%.*]])
-; CHECK-NEXT:    [[CMP:%.*]] = fcmp one half [[SQRT]], 0xH0000
+; CHECK-NEXT:    [[CMP:%.*]] = fcmp ogt half [[X:%.*]], 0xH0000
 ; CHECK-NEXT:    ret i1 [[CMP]]
 ;
   %sqrt = call half @llvm.sqrt(half %x)
@@ -2248,8 +2238,7 @@ define i1 @fcmp_sqrt_zero_one(half %x) {
 
 define i1 @fcmp_sqrt_zero_une(half %x) {
 ; CHECK-LABEL: @fcmp_sqrt_zero_une(
-; CHECK-NEXT:    [[SQRT:%.*]] = call half @llvm.sqrt.f16(half [[X:%.*]])
-; CHECK-NEXT:    [[CMP:%.*]] = fcmp une half [[SQRT]], 0xH0000
+; CHECK-NEXT:    [[CMP:%.*]] = fcmp une half [[X:%.*]], 0xH0000
 ; CHECK-NEXT:    ret i1 [[CMP]]
 ;
   %sqrt = call half @llvm.sqrt(half %x)
@@ -2259,8 +2248,7 @@ define i1 @fcmp_sqrt_zero_une(half %x) {
 
 define i1 @fcmp_sqrt_zero_ord(half %x) {
 ; CHECK-LABEL: @fcmp_sqrt_zero_ord(
-; CHECK-NEXT:    [[SQRT:%.*]] = call half @llvm.sqrt.f16(half [[X:%.*]])
-; CHECK-NEXT:    [[CMP:%.*]] = fcmp ord half [[SQRT]], 0xH0000
+; CHECK-NEXT:    [[CMP:%.*]] = fcmp oge half [[X:%.*]], 0xH0000
 ; CHECK-NEXT:    ret i1 [[CMP]]
 ;
   %sqrt = call half @llvm.sqrt(half %x)
@@ -2270,8 +2258,7 @@ define i1 @fcmp_sqrt_zero_ord(half %x) {
 
 define i1 @fcmp_sqrt_zero_uno(half %x) {
 ; CHECK-LABEL: @fcmp_sqrt_zero_uno(
-; CHECK-NEXT:    [[SQRT:%.*]] = call half @llvm.sqrt.f16(half [[X:%.*]])
-; CHECK-NEXT:    [[CMP:%.*]] = fcmp uno half [[SQRT]], 0xH0000
+; CHECK-NEXT:    [[CMP:%.*]] = fcmp ult half [[X:%.*]], 0xH0000
 ; CHECK-NEXT:    ret i1 [[CMP]]
 ;
   %sqrt = call half @llvm.sqrt(half %x)
diff --git a/llvm/test/Transforms/InstCombine/known-never-nan.ll b/llvm/test/Transforms/InstCombine/known-never-nan.ll
index a1cabc29682b4..82075b37b4361 100644
--- a/llvm/test/Transforms/InstCombine/known-never-nan.ll
+++ b/llvm/test/Transforms/InstCombine/known-never-nan.ll
@@ -9,9 +9,7 @@
 
 define i1 @fabs_sqrt_src_maybe_nan(double %arg0, double %arg1) {
 ; CHECK-LABEL: @fabs_sqrt_src_maybe_nan(
-; CHECK-NEXT:    [[FABS:%.*]] = call double @llvm.fabs.f64(double [[ARG0:%.*]])
-; CHECK-NEXT:    [[OP:%.*]] = call double @llvm.sqrt.f64(double [[FABS]])
-; CHECK-NEXT:    [[TMP:%.*]] = fcmp ord double [[OP]], 0.000000e+00
+; CHECK-NEXT:    [[TMP:%.*]] = fcmp ord double [[ARG0:%.*]], 0.000000e+00
 ; CHECK-NEXT:    ret i1 [[TMP]]
 ;
   %fabs = call double @llvm.fabs.f64(double %arg0)

>From 33ca92adfab1607d77c3af41792604c4141ac609 Mon Sep 17 00:00:00 2001
From: Yingwei Zheng <dtcxzyw2333 at gmail.com>
Date: Fri, 2 Aug 2024 16:22:31 +0800
Subject: [PATCH 3/4] [InstCombine] Address review comments.

---
 .../InstCombine/InstCombineCompares.cpp       |  3 ++
 llvm/test/Transforms/InstCombine/fcmp.ll      | 41 +++++++++++++++++++
 2 files changed, 44 insertions(+)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 622e7a420dd95..b7ee18c39e141 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -7994,6 +7994,9 @@ static Instruction *foldSqrtWithFcmpZero(FCmpInst &I, InstCombinerImpl &IC) {
     return IC.replaceOperand(I, 0, X);
   };
 
+  // Clear ninf flag.
+  I.setHasNoInfs(false);
+
   switch (I.getPredicate()) {
   case FCmpInst::FCMP_OLT:
   case FCmpInst::FCMP_UGE:
diff --git a/llvm/test/Transforms/InstCombine/fcmp.ll b/llvm/test/Transforms/InstCombine/fcmp.ll
index 3ea93149094c8..f688ac2c68dbe 100644
--- a/llvm/test/Transforms/InstCombine/fcmp.ll
+++ b/llvm/test/Transforms/InstCombine/fcmp.ll
@@ -2137,6 +2137,26 @@ define i1 @fcmp_sqrt_zero_ult(half %x) {
   ret i1 %cmp
 }
 
+define i1 @fcmp_sqrt_zero_ult_fmf(half %x) {
+; CHECK-LABEL: @fcmp_sqrt_zero_ult_fmf(
+; CHECK-NEXT:    [[CMP:%.*]] = fcmp nsz ult half [[X:%.*]], 0xH0000
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %sqrt = call half @llvm.sqrt(half %x)
+  %cmp = fcmp ninf nsz ult half %sqrt, 0.0
+  ret i1 %cmp
+}
+
+define i1 @fcmp_sqrt_zero_ult_nzero(half %x) {
+; CHECK-LABEL: @fcmp_sqrt_zero_ult_nzero(
+; CHECK-NEXT:    [[CMP:%.*]] = fcmp ult half [[X:%.*]], 0xH0000
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %sqrt = call half @llvm.sqrt(half %x)
+  %cmp = fcmp ult half %sqrt, -0.0
+  ret i1 %cmp
+}
+
 define <2 x i1> @fcmp_sqrt_zero_ult_vec(<2 x half> %x) {
 ; CHECK-LABEL: @fcmp_sqrt_zero_ult_vec(
 ; CHECK-NEXT:    [[CMP:%.*]] = fcmp ult <2 x half> [[X:%.*]], zeroinitializer
@@ -2147,6 +2167,16 @@ define <2 x i1> @fcmp_sqrt_zero_ult_vec(<2 x half> %x) {
   ret <2 x i1> %cmp
 }
 
+define <2 x i1> @fcmp_sqrt_zero_ult_vec_mixed_zero(<2 x half> %x) {
+; CHECK-LABEL: @fcmp_sqrt_zero_ult_vec_mixed_zero(
+; CHECK-NEXT:    [[CMP:%.*]] = fcmp ult <2 x half> [[X:%.*]], zeroinitializer
+; CHECK-NEXT:    ret <2 x i1> [[CMP]]
+;
+  %sqrt = call <2 x half> @llvm.sqrt(<2 x half> %x)
+  %cmp = fcmp ult <2 x half> %sqrt, <half 0.0, half -0.0>
+  ret <2 x i1> %cmp
+}
+
 define i1 @fcmp_sqrt_zero_ole(half %x) {
 ; CHECK-LABEL: @fcmp_sqrt_zero_ole(
 ; CHECK-NEXT:    [[CMP:%.*]] = fcmp oeq half [[X:%.*]], 0xH0000
@@ -2266,6 +2296,17 @@ define i1 @fcmp_sqrt_zero_uno(half %x) {
   ret i1 %cmp
 }
 
+; Make sure that ninf is cleared on the folded fcmp when sqrt lacks ninf.
+define i1 @fcmp_sqrt_zero_uno_fmf(half %x) {
+; CHECK-LABEL: @fcmp_sqrt_zero_uno_fmf(
+; CHECK-NEXT:    [[CMP:%.*]] = fcmp ult half [[X:%.*]], 0xH0000
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %sqrt = call half @llvm.sqrt(half %x)
+  %cmp = fcmp ninf uno half %sqrt, 0.0
+  ret i1 %cmp
+}
+
 ; negative tests
 
 define i1 @fcmp_sqrt_zero_ult_var(half %x, half %y) {

>From 2aa5a3e491a53b8d90d3e335182400df944c8522 Mon Sep 17 00:00:00 2001
From: Yingwei Zheng <dtcxzyw2333 at gmail.com>
Date: Sat, 3 Aug 2024 00:53:24 +0800
Subject: [PATCH 4/4] [InstCombine] Address review comments.

---
 .../InstCombine/InstCombineCompares.cpp       |  5 +-
 llvm/test/Transforms/InstCombine/fcmp.ll      | 62 ++++++++++++-------
 2 files changed, 44 insertions(+), 23 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index b7ee18c39e141..94786f0b9ec54 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -7994,8 +7994,9 @@ static Instruction *foldSqrtWithFcmpZero(FCmpInst &I, InstCombinerImpl &IC) {
     return IC.replaceOperand(I, 0, X);
   };
 
-  // Clear ninf flag.
-  I.setHasNoInfs(false);
+  // Clear ninf flag if sqrt doesn't have it.
+  if (!cast<Instruction>(I.getOperand(0))->hasNoInfs())
+    I.setHasNoInfs(false);
 
   switch (I.getPredicate()) {
   case FCmpInst::FCMP_OLT:
diff --git a/llvm/test/Transforms/InstCombine/fcmp.ll b/llvm/test/Transforms/InstCombine/fcmp.ll
index f688ac2c68dbe..8afb6463b669d 100644
--- a/llvm/test/Transforms/InstCombine/fcmp.ll
+++ b/llvm/test/Transforms/InstCombine/fcmp.ll
@@ -2122,7 +2122,7 @@ define i1 @fcmp_sqrt_zero_olt(half %x) {
 ; CHECK-LABEL: @fcmp_sqrt_zero_olt(
 ; CHECK-NEXT:    ret i1 false
 ;
-  %sqrt = call half @llvm.sqrt(half %x)
+  %sqrt = call half @llvm.sqrt.f16(half %x)
   %cmp = fcmp olt half %sqrt, 0.0
   ret i1 %cmp
 }
@@ -2132,7 +2132,7 @@ define i1 @fcmp_sqrt_zero_ult(half %x) {
 ; CHECK-NEXT:    [[CMP:%.*]] = fcmp ult half [[X:%.*]], 0xH0000
 ; CHECK-NEXT:    ret i1 [[CMP]]
 ;
-  %sqrt = call half @llvm.sqrt(half %x)
+  %sqrt = call half @llvm.sqrt.f16(half %x)
   %cmp = fcmp ult half %sqrt, 0.0
   ret i1 %cmp
 }
@@ -2142,7 +2142,17 @@ define i1 @fcmp_sqrt_zero_ult_fmf(half %x) {
 ; CHECK-NEXT:    [[CMP:%.*]] = fcmp nsz ult half [[X:%.*]], 0xH0000
 ; CHECK-NEXT:    ret i1 [[CMP]]
 ;
-  %sqrt = call half @llvm.sqrt(half %x)
+  %sqrt = call half @llvm.sqrt.f16(half %x)
+  %cmp = fcmp ninf nsz ult half %sqrt, 0.0
+  ret i1 %cmp
+}
+
+define i1 @fcmp_sqrt_zero_ult_fmf_sqrt_ninf(half %x) {
+; CHECK-LABEL: @fcmp_sqrt_zero_ult_fmf_sqrt_ninf(
+; CHECK-NEXT:    [[CMP:%.*]] = fcmp ninf nsz ult half [[X:%.*]], 0xH0000
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %sqrt = call ninf half @llvm.sqrt.f16(half %x)
   %cmp = fcmp ninf nsz ult half %sqrt, 0.0
   ret i1 %cmp
 }
@@ -2152,7 +2162,7 @@ define i1 @fcmp_sqrt_zero_ult_nzero(half %x) {
 ; CHECK-NEXT:    [[CMP:%.*]] = fcmp ult half [[X:%.*]], 0xH0000
 ; CHECK-NEXT:    ret i1 [[CMP]]
 ;
-  %sqrt = call half @llvm.sqrt(half %x)
+  %sqrt = call half @llvm.sqrt.f16(half %x)
   %cmp = fcmp ult half %sqrt, -0.0
   ret i1 %cmp
 }
@@ -2162,7 +2172,7 @@ define <2 x i1> @fcmp_sqrt_zero_ult_vec(<2 x half> %x) {
 ; CHECK-NEXT:    [[CMP:%.*]] = fcmp ult <2 x half> [[X:%.*]], zeroinitializer
 ; CHECK-NEXT:    ret <2 x i1> [[CMP]]
 ;
-  %sqrt = call <2 x half> @llvm.sqrt(<2 x half> %x)
+  %sqrt = call <2 x half> @llvm.sqrt.v2f16(<2 x half> %x)
   %cmp = fcmp ult <2 x half> %sqrt, zeroinitializer
   ret <2 x i1> %cmp
 }
@@ -2172,7 +2182,7 @@ define <2 x i1> @fcmp_sqrt_zero_ult_vec_mixed_zero(<2 x half> %x) {
 ; CHECK-NEXT:    [[CMP:%.*]] = fcmp ult <2 x half> [[X:%.*]], zeroinitializer
 ; CHECK-NEXT:    ret <2 x i1> [[CMP]]
 ;
-  %sqrt = call <2 x half> @llvm.sqrt(<2 x half> %x)
+  %sqrt = call <2 x half> @llvm.sqrt.v2f16(<2 x half> %x)
   %cmp = fcmp ult <2 x half> %sqrt, <half 0.0, half -0.0>
   ret <2 x i1> %cmp
 }
@@ -2182,7 +2192,7 @@ define i1 @fcmp_sqrt_zero_ole(half %x) {
 ; CHECK-NEXT:    [[CMP:%.*]] = fcmp oeq half [[X:%.*]], 0xH0000
 ; CHECK-NEXT:    ret i1 [[CMP]]
 ;
-  %sqrt = call half @llvm.sqrt(half %x)
+  %sqrt = call half @llvm.sqrt.f16(half %x)
   %cmp = fcmp ole half %sqrt, 0.0
   ret i1 %cmp
 }
@@ -2192,7 +2202,7 @@ define i1 @fcmp_sqrt_zero_ule(half %x) {
 ; CHECK-NEXT:    [[CMP:%.*]] = fcmp ule half [[X:%.*]], 0xH0000
 ; CHECK-NEXT:    ret i1 [[CMP]]
 ;
-  %sqrt = call half @llvm.sqrt(half %x)
+  %sqrt = call half @llvm.sqrt.f16(half %x)
   %cmp = fcmp ule half %sqrt, 0.0
   ret i1 %cmp
 }
@@ -2202,7 +2212,7 @@ define i1 @fcmp_sqrt_zero_ogt(half %x) {
 ; CHECK-NEXT:    [[CMP:%.*]] = fcmp ogt half [[X:%.*]], 0xH0000
 ; CHECK-NEXT:    ret i1 [[CMP]]
 ;
-  %sqrt = call half @llvm.sqrt(half %x)
+  %sqrt = call half @llvm.sqrt.f16(half %x)
   %cmp = fcmp ogt half %sqrt, 0.0
   ret i1 %cmp
 }
@@ -2212,7 +2222,7 @@ define i1 @fcmp_sqrt_zero_ugt(half %x) {
 ; CHECK-NEXT:    [[CMP:%.*]] = fcmp une half [[X:%.*]], 0xH0000
 ; CHECK-NEXT:    ret i1 [[CMP]]
 ;
-  %sqrt = call half @llvm.sqrt(half %x)
+  %sqrt = call half @llvm.sqrt.f16(half %x)
   %cmp = fcmp ugt half %sqrt, 0.0
   ret i1 %cmp
 }
@@ -2222,7 +2232,7 @@ define i1 @fcmp_sqrt_zero_oge(half %x) {
 ; CHECK-NEXT:    [[CMP:%.*]] = fcmp oge half [[X:%.*]], 0xH0000
 ; CHECK-NEXT:    ret i1 [[CMP]]
 ;
-  %sqrt = call half @llvm.sqrt(half %x)
+  %sqrt = call half @llvm.sqrt.f16(half %x)
   %cmp = fcmp oge half %sqrt, 0.0
   ret i1 %cmp
 }
@@ -2231,7 +2241,7 @@ define i1 @fcmp_sqrt_zero_uge(half %x) {
 ; CHECK-LABEL: @fcmp_sqrt_zero_uge(
 ; CHECK-NEXT:    ret i1 true
 ;
-  %sqrt = call half @llvm.sqrt(half %x)
+  %sqrt = call half @llvm.sqrt.f16(half %x)
   %cmp = fcmp uge half %sqrt, 0.0
   ret i1 %cmp
 }
@@ -2241,7 +2251,7 @@ define i1 @fcmp_sqrt_zero_oeq(half %x) {
 ; CHECK-NEXT:    [[CMP:%.*]] = fcmp oeq half [[X:%.*]], 0xH0000
 ; CHECK-NEXT:    ret i1 [[CMP]]
 ;
-  %sqrt = call half @llvm.sqrt(half %x)
+  %sqrt = call half @llvm.sqrt.f16(half %x)
   %cmp = fcmp oeq half %sqrt, 0.0
   ret i1 %cmp
 }
@@ -2251,7 +2261,7 @@ define i1 @fcmp_sqrt_zero_ueq(half %x) {
 ; CHECK-NEXT:    [[CMP:%.*]] = fcmp ule half [[X:%.*]], 0xH0000
 ; CHECK-NEXT:    ret i1 [[CMP]]
 ;
-  %sqrt = call half @llvm.sqrt(half %x)
+  %sqrt = call half @llvm.sqrt.f16(half %x)
   %cmp = fcmp ueq half %sqrt, 0.0
   ret i1 %cmp
 }
@@ -2261,7 +2271,7 @@ define i1 @fcmp_sqrt_zero_one(half %x) {
 ; CHECK-NEXT:    [[CMP:%.*]] = fcmp ogt half [[X:%.*]], 0xH0000
 ; CHECK-NEXT:    ret i1 [[CMP]]
 ;
-  %sqrt = call half @llvm.sqrt(half %x)
+  %sqrt = call half @llvm.sqrt.f16(half %x)
   %cmp = fcmp one half %sqrt, 0.0
   ret i1 %cmp
 }
@@ -2271,7 +2281,7 @@ define i1 @fcmp_sqrt_zero_une(half %x) {
 ; CHECK-NEXT:    [[CMP:%.*]] = fcmp une half [[X:%.*]], 0xH0000
 ; CHECK-NEXT:    ret i1 [[CMP]]
 ;
-  %sqrt = call half @llvm.sqrt(half %x)
+  %sqrt = call half @llvm.sqrt.f16(half %x)
   %cmp = fcmp une half %sqrt, 0.0
   ret i1 %cmp
 }
@@ -2281,7 +2291,7 @@ define i1 @fcmp_sqrt_zero_ord(half %x) {
 ; CHECK-NEXT:    [[CMP:%.*]] = fcmp oge half [[X:%.*]], 0xH0000
 ; CHECK-NEXT:    ret i1 [[CMP]]
 ;
-  %sqrt = call half @llvm.sqrt(half %x)
+  %sqrt = call half @llvm.sqrt.f16(half %x)
   %cmp = fcmp ord half %sqrt, 0.0
   ret i1 %cmp
 }
@@ -2291,7 +2301,7 @@ define i1 @fcmp_sqrt_zero_uno(half %x) {
 ; CHECK-NEXT:    [[CMP:%.*]] = fcmp ult half [[X:%.*]], 0xH0000
 ; CHECK-NEXT:    ret i1 [[CMP]]
 ;
-  %sqrt = call half @llvm.sqrt(half %x)
+  %sqrt = call half @llvm.sqrt.f16(half %x)
   %cmp = fcmp uno half %sqrt, 0.0
   ret i1 %cmp
 }
@@ -2302,7 +2312,17 @@ define i1 @fcmp_sqrt_zero_uno_fmf(half %x) {
 ; CHECK-NEXT:    [[CMP:%.*]] = fcmp ult half [[X:%.*]], 0xH0000
 ; CHECK-NEXT:    ret i1 [[CMP]]
 ;
-  %sqrt = call half @llvm.sqrt(half %x)
+  %sqrt = call half @llvm.sqrt.f16(half %x)
+  %cmp = fcmp ninf uno half %sqrt, 0.0
+  ret i1 %cmp
+}
+
+define i1 @fcmp_sqrt_zero_uno_fmf_sqrt_ninf(half %x) {
+; CHECK-LABEL: @fcmp_sqrt_zero_uno_fmf_sqrt_ninf(
+; CHECK-NEXT:    [[CMP:%.*]] = fcmp ninf ult half [[X:%.*]], 0xH0000
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %sqrt = call ninf half @llvm.sqrt.f16(half %x)
   %cmp = fcmp ninf uno half %sqrt, 0.0
   ret i1 %cmp
 }
@@ -2315,7 +2335,7 @@ define i1 @fcmp_sqrt_zero_ult_var(half %x, half %y) {
 ; CHECK-NEXT:    [[CMP:%.*]] = fcmp ult half [[SQRT]], [[Y:%.*]]
 ; CHECK-NEXT:    ret i1 [[CMP]]
 ;
-  %sqrt = call half @llvm.sqrt(half %x)
+  %sqrt = call half @llvm.sqrt.f16(half %x)
   %cmp = fcmp ult half %sqrt, %y
   ret i1 %cmp
 }
@@ -2326,7 +2346,7 @@ define i1 @fcmp_sqrt_zero_ult_nonzero(half %x) {
 ; CHECK-NEXT:    [[CMP:%.*]] = fcmp ult half [[SQRT]], 0xH3C00
 ; CHECK-NEXT:    ret i1 [[CMP]]
 ;
-  %sqrt = call half @llvm.sqrt(half %x)
+  %sqrt = call half @llvm.sqrt.f16(half %x)
   %cmp = fcmp ult half %sqrt, 1.000000e+00
   ret i1 %cmp
 }



More information about the llvm-commits mailing list