[llvm] [InstCombine] Fix failure to convert vector fp comparisons that can be represented as integers #82241 (PR #83274)

via llvm-commits llvm-commits at lists.llvm.org
Wed Feb 28 07:24:12 PST 2024


https://github.com/SahilPatidar created https://github.com/llvm/llvm-project/pull/83274

Resolves #82241.

>From 07773ad4c428bc77947b9fd4c0374e059b3ca334 Mon Sep 17 00:00:00 2001
From: SahilPatidar <patidarsahil at 2001gmail.com>
Date: Wed, 28 Feb 2024 09:42:29 +0530
Subject: [PATCH] [InstCombine] Fix Failure to convert vector fp comparisons
 that can be represented as integers #82241

---
 .../InstCombine/InstCombineCompares.cpp       | 57 ++++++++++++------
 .../InstCombine/cast-int-fcmp-eq-0.ll         | 58 +++++++++++++++++++
 .../Transforms/InstCombine/clamp-to-minmax.ll | 17 ++++++
 llvm/test/Transforms/InstCombine/sitofp.ll    | 37 ++++++++++++
 4 files changed, 151 insertions(+), 18 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 49e597171b1c6f..9cbf920bc76a49 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -7286,6 +7286,10 @@ Instruction *InstCombinerImpl::visitICmpInst(ICmpInst &I) {
 Instruction *InstCombinerImpl::foldFCmpIntToFPConst(FCmpInst &I,
                                                     Instruction *LHSI,
                                                     Constant *RHSC) {
+  if (RHSC->getType()->isVectorTy()) {
+    if (Constant *CV = RHSC->getSplatValue(false))
+      RHSC = CV;
+  }
   if (!isa<ConstantFP>(RHSC)) return nullptr;
   const APFloat &RHS = cast<ConstantFP>(RHSC)->getValueAPF();
 
@@ -7298,6 +7302,16 @@ Instruction *InstCombinerImpl::foldFCmpIntToFPConst(FCmpInst &I,
   unsigned IntWidth = IntTy->getScalarSizeInBits();
   bool LHSUnsigned = isa<UIToFPInst>(LHSI);
 
+  Value *False, *True;
+  if (IntTy->isVectorTy()) {
+    ElementCount EC = cast<VectorType>(IntTy)->getElementCount();
+    True = Builder.CreateVectorSplat(EC, Builder.getTrue());
+    False = Builder.CreateVectorSplat(EC, Builder.getFalse());
+  } else {
+    True = Builder.getTrue();
+    False = Builder.getFalse();
+  }
+
   if (I.isEquality()) {
     FCmpInst::Predicate P = I.getPredicate();
     bool IsExact = false;
@@ -7312,10 +7326,10 @@ Instruction *InstCombinerImpl::foldFCmpIntToFPConst(FCmpInst &I,
       RHSRoundInt.roundToIntegral(APFloat::rmNearestTiesToEven);
       if (RHS != RHSRoundInt) {
         if (P == FCmpInst::FCMP_OEQ || P == FCmpInst::FCMP_UEQ)
-          return replaceInstUsesWith(I, Builder.getFalse());
+          return replaceInstUsesWith(I, False);
 
         assert(P == FCmpInst::FCMP_ONE || P == FCmpInst::FCMP_UNE);
-        return replaceInstUsesWith(I, Builder.getTrue());
+        return replaceInstUsesWith(I, True);
       }
     }
 
@@ -7380,9 +7394,9 @@ Instruction *InstCombinerImpl::foldFCmpIntToFPConst(FCmpInst &I,
     Pred = ICmpInst::ICMP_NE;
     break;
   case FCmpInst::FCMP_ORD:
-    return replaceInstUsesWith(I, Builder.getTrue());
+    return replaceInstUsesWith(I, True);
   case FCmpInst::FCMP_UNO:
-    return replaceInstUsesWith(I, Builder.getFalse());
+    return replaceInstUsesWith(I, False);
   }
 
   // Now we know that the APFloat is a normal number, zero or inf.
@@ -7398,8 +7412,8 @@ Instruction *InstCombinerImpl::foldFCmpIntToFPConst(FCmpInst &I,
     if (SMax < RHS) { // smax < 13123.0
       if (Pred == ICmpInst::ICMP_NE  || Pred == ICmpInst::ICMP_SLT ||
           Pred == ICmpInst::ICMP_SLE)
-        return replaceInstUsesWith(I, Builder.getTrue());
-      return replaceInstUsesWith(I, Builder.getFalse());
+        return replaceInstUsesWith(I, True);
+      return replaceInstUsesWith(I, False);
     }
   } else {
     // If the RHS value is > UnsignedMax, fold the comparison. This handles
@@ -7410,8 +7424,8 @@ Instruction *InstCombinerImpl::foldFCmpIntToFPConst(FCmpInst &I,
     if (UMax < RHS) { // umax < 13123.0
       if (Pred == ICmpInst::ICMP_NE  || Pred == ICmpInst::ICMP_ULT ||
           Pred == ICmpInst::ICMP_ULE)
-        return replaceInstUsesWith(I, Builder.getTrue());
-      return replaceInstUsesWith(I, Builder.getFalse());
+        return replaceInstUsesWith(I, True);
+      return replaceInstUsesWith(I, False);
     }
   }
 
@@ -7423,8 +7437,8 @@ Instruction *InstCombinerImpl::foldFCmpIntToFPConst(FCmpInst &I,
     if (SMin > RHS) { // smin > 12312.0
       if (Pred == ICmpInst::ICMP_NE || Pred == ICmpInst::ICMP_SGT ||
           Pred == ICmpInst::ICMP_SGE)
-        return replaceInstUsesWith(I, Builder.getTrue());
-      return replaceInstUsesWith(I, Builder.getFalse());
+        return replaceInstUsesWith(I, True);
+      return replaceInstUsesWith(I, False);
     }
   } else {
     // See if the RHS value is < UnsignedMin.
@@ -7434,8 +7448,8 @@ Instruction *InstCombinerImpl::foldFCmpIntToFPConst(FCmpInst &I,
     if (UMin > RHS) { // umin > 12312.0
       if (Pred == ICmpInst::ICMP_NE || Pred == ICmpInst::ICMP_UGT ||
           Pred == ICmpInst::ICMP_UGE)
-        return replaceInstUsesWith(I, Builder.getTrue());
-      return replaceInstUsesWith(I, Builder.getFalse());
+        return replaceInstUsesWith(I, True);
+      return replaceInstUsesWith(I, False);
     }
   }
 
@@ -7454,14 +7468,14 @@ Instruction *InstCombinerImpl::foldFCmpIntToFPConst(FCmpInst &I,
       switch (Pred) {
       default: llvm_unreachable("Unexpected integer comparison!");
       case ICmpInst::ICMP_NE:  // (float)int != 4.4   --> true
-        return replaceInstUsesWith(I, Builder.getTrue());
+        return replaceInstUsesWith(I, True);
       case ICmpInst::ICMP_EQ:  // (float)int == 4.4   --> false
-        return replaceInstUsesWith(I, Builder.getFalse());
+        return replaceInstUsesWith(I, False);
       case ICmpInst::ICMP_ULE:
         // (float)int <= 4.4   --> int <= 4
         // (float)int <= -4.4  --> false
         if (RHS.isNegative())
-          return replaceInstUsesWith(I, Builder.getFalse());
+          return replaceInstUsesWith(I, False);
         break;
       case ICmpInst::ICMP_SLE:
         // (float)int <= 4.4   --> int <= 4
@@ -7473,7 +7487,7 @@ Instruction *InstCombinerImpl::foldFCmpIntToFPConst(FCmpInst &I,
         // (float)int < -4.4   --> false
         // (float)int < 4.4    --> int <= 4
         if (RHS.isNegative())
-          return replaceInstUsesWith(I, Builder.getFalse());
+          return replaceInstUsesWith(I, False);
         Pred = ICmpInst::ICMP_ULE;
         break;
       case ICmpInst::ICMP_SLT:
@@ -7486,7 +7500,7 @@ Instruction *InstCombinerImpl::foldFCmpIntToFPConst(FCmpInst &I,
         // (float)int > 4.4    --> int > 4
         // (float)int > -4.4   --> true
         if (RHS.isNegative())
-          return replaceInstUsesWith(I, Builder.getTrue());
+          return replaceInstUsesWith(I, True);
         break;
       case ICmpInst::ICMP_SGT:
         // (float)int > 4.4    --> int > 4
@@ -7498,7 +7512,7 @@ Instruction *InstCombinerImpl::foldFCmpIntToFPConst(FCmpInst &I,
         // (float)int >= -4.4   --> true
         // (float)int >= 4.4    --> int > 4
         if (RHS.isNegative())
-          return replaceInstUsesWith(I, Builder.getTrue());
+          return replaceInstUsesWith(I, True);
         Pred = ICmpInst::ICMP_UGT;
         break;
       case ICmpInst::ICMP_SGE:
@@ -7511,6 +7525,13 @@ Instruction *InstCombinerImpl::foldFCmpIntToFPConst(FCmpInst &I,
     }
   }
 
+  if (IntTy->isVectorTy()) {
+    Value *RHSV = Builder.CreateVectorSplat(
+                     cast<VectorType>(IntTy)->getElementCount(),
+                     Builder.getInt(RHSInt));
+    return new ICmpInst(Pred, LHSI->getOperand(0), RHSV);
+  }
+
   // Lower this FP comparison into an appropriate integer version of the
   // comparison.
   return new ICmpInst(Pred, LHSI->getOperand(0), Builder.getInt(RHSInt));
diff --git a/llvm/test/Transforms/InstCombine/cast-int-fcmp-eq-0.ll b/llvm/test/Transforms/InstCombine/cast-int-fcmp-eq-0.ll
index b04033d94a26f4..68a386ec004237 100644
--- a/llvm/test/Transforms/InstCombine/cast-int-fcmp-eq-0.ll
+++ b/llvm/test/Transforms/InstCombine/cast-int-fcmp-eq-0.ll
@@ -509,3 +509,61 @@ define i1 @i128_cast_cmp_oeq_int_inf_uitofp(i128 %i) {
   %cmp = fcmp oeq float %f, 0x7FF0000000000000
   ret i1 %cmp
 }
+
+define <2 x i1> @i32_vec_cast_cmp_oeq_vec_int_0_sitofp(<2 x i32> %i) {
+; CHECK-LABEL: @i32_vec_cast_cmp_oeq_vec_int_0_sitofp(
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq <2 x i32> [[I:%.*]], zeroinitializer
+; CHECK-NEXT:    ret <2 x i1> [[CMP]]
+;
+  %f = sitofp <2 x i32> %i to <2 x float>
+  %cmp = fcmp oeq <2 x float> %f, <float 0.0, float 0.0>
+  ret <2 x i1> %cmp
+}
+
+define <2 x i1> @i32_vec_cast_cmp_oeq_vec_int_n0_sitofp(<2 x i32> %i) {
+; CHECK-LABEL: @i32_vec_cast_cmp_oeq_vec_int_n0_sitofp(
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq <2 x i32> [[I:%.*]], zeroinitializer
+; CHECK-NEXT:    ret <2 x i1> [[CMP]]
+;
+  %f = sitofp <2 x i32> %i to <2 x float>
+  %cmp = fcmp oeq <2 x float> %f, <float -0.0, float -0.0>
+  ret <2 x i1> %cmp
+}
+
+define <2 x i1> @i32_vec_cast_cmp_oeq_vec_int_i32imax_sitofp(<2 x i32> %i) {
+; CHECK-LABEL: @i32_vec_cast_cmp_oeq_vec_int_i32imax_sitofp(
+; CHECK-NEXT:    [[F:%.*]] = sitofp <2 x i32> [[I:%.*]] to <2 x float>
+; CHECK-NEXT:    [[CMP:%.*]] = fcmp oeq <2 x float> [[F]], <float 0x41E0000000000000, float 0x41E0000000000000>
+; CHECK-NEXT:    ret <2 x i1> [[CMP]]
+;
+  %f = sitofp <2 x i32> %i to <2 x float>
+  %cmp = fcmp oeq <2 x float> %f, <float 0x41E0000000000000, float 0x41E0000000000000>
+  ret <2 x i1> %cmp
+}
+
+define <2 x i1> @i32_vec_cast_cmp_oeq_vec_int_negi32umax_sitofp(<2 x i32> %i) {
+; CHECK-LABEL: @i32_vec_cast_cmp_oeq_vec_int_negi32umax_sitofp(
+; CHECK-NEXT:    ret <2 x i1> zeroinitializer
+;
+  %f = sitofp <2 x i32> %i to <2 x float>
+  %cmp = fcmp oeq <2 x float> %f, <float 0xC1F0000000000000, float 0xC1F0000000000000>
+  ret <2 x i1> %cmp
+}
+
+define <2 x i1> @i32_vec_cast_cmp_oeq_vec_half_sitofp(<2 x i32> %i) {
+; CHECK-LABEL: @i32_vec_cast_cmp_oeq_vec_half_sitofp(
+; CHECK-NEXT:    ret <2 x i1> zeroinitializer
+;
+  %f = sitofp <2 x i32> %i to <2 x float>
+  %cmp = fcmp oeq <2 x float> %f, <float 0.5, float 0.5>
+  ret <2 x i1> %cmp
+}
+
+define <2 x i1> @i32_vec_cast_cmp_oeq_vec_int_inf_sitofp(<2 x i32> %i) {
+; CHECK-LABEL: @i32_vec_cast_cmp_oeq_vec_int_inf_sitofp(
+; CHECK-NEXT:    ret <2 x i1> zeroinitializer
+;
+  %f = sitofp <2 x i32> %i to <2 x float>
+  %cmp = fcmp oeq <2 x float> %f, <float 0x7FF0000000000000, float 0x7FF0000000000000>
+  ret <2 x i1> %cmp
+}
diff --git a/llvm/test/Transforms/InstCombine/clamp-to-minmax.ll b/llvm/test/Transforms/InstCombine/clamp-to-minmax.ll
index fad1176cc18fac..9da9eb36d381f0 100644
--- a/llvm/test/Transforms/InstCombine/clamp-to-minmax.ll
+++ b/llvm/test/Transforms/InstCombine/clamp-to-minmax.ll
@@ -566,3 +566,20 @@ define i32 @mixed_clamp_to_i32_2(float %x) {
   %r = select i1 %lo_cmp, i32 1, i32 %i32_min
   ret i32 %r
 }
+
+
+define <2 x float> @mixed_clamp_to_float_vec(<2 x i32> %x) {
+; CHECK-LABEL: @mixed_clamp_to_float_vec(
+; CHECK-NEXT:    [[SI_MIN:%.*]] = call <2 x i32> @llvm.smin.v2i32(<2 x i32> [[X:%.*]], <2 x i32> <i32 255, i32 255>)
+; CHECK-NEXT:    [[R1:%.*]] = call <2 x i32> @llvm.smax.v2i32(<2 x i32> [[SI_MIN]], <2 x i32> <i32 1, i32 1>)
+; CHECK-NEXT:    [[R:%.*]] = sitofp <2 x i32> [[R1]] to <2 x float>
+; CHECK-NEXT:    ret <2 x float> [[R]]
+;
+  %si_min_cmp = icmp sgt <2 x i32> %x, <i32 255, i32 255>
+  %si_min = select <2 x i1> %si_min_cmp, <2 x i32> <i32 255, i32 255>, <2 x i32> %x
+  %f_min = sitofp <2 x i32> %si_min to <2 x float>
+  %f_x = sitofp <2 x i32> %x to <2 x float>
+  %lo_cmp = fcmp ult <2 x float> %f_x, <float 1.0, float 1.0>
+  %r = select <2 x i1> %lo_cmp, <2 x float> <float 1.0, float 1.0>, <2 x float> %f_min
+  ret <2 x float> %r
+}
diff --git a/llvm/test/Transforms/InstCombine/sitofp.ll b/llvm/test/Transforms/InstCombine/sitofp.ll
index 5e0cf944880071..cc6b6425eb03c8 100644
--- a/llvm/test/Transforms/InstCombine/sitofp.ll
+++ b/llvm/test/Transforms/InstCombine/sitofp.ll
@@ -378,3 +378,40 @@ define i12 @u32_half_u12(i32 %x) {
   %r = fptoui half %h to i12
   ret i12 %r
 }
+
+define <2 x i1> @i8_vec_sitofp_test1(<2 x i8> %A) {
+; CHECK-LABEL: @i8_vec_sitofp_test1(
+; CHECK-NEXT:    ret <2 x i1> <i1 true, i1 true>
+;
+  %B = sitofp <2 x i8> %A to <2 x double>
+  %C = fcmp ult <2 x double> %B, <double 128.0, double 128.0>
+  ret <2 x i1> %C
+}
+
+define <2 x i1> @i8_vec_sitofp_test2(<2 x i8> %A) {
+; CHECK-LABEL: @i8_vec_sitofp_test2(
+; CHECK-NEXT:    ret <2 x i1> <i1 true, i1 true>
+;
+  %B = sitofp <2 x i8> %A to <2 x double>
+  %C = fcmp ugt <2 x double> %B, <double -128.1, double -128.1>
+  ret <2 x i1> %C
+}
+
+define <2 x i1> @i8_vec_sitofp_test3(<2 x i8> %A) {
+; CHECK-LABEL: @i8_vec_sitofp_test3(
+; CHECK-NEXT:    ret <2 x i1> <i1 true, i1 true>
+;
+  %B = sitofp <2 x i8> %A to <2 x double>
+  %C = fcmp ule <2 x double> %B, <double 127.0, double 127.0>
+  ret <2 x i1> %C
+}
+
+define <2 x i1> @i8_vec_sitofp_test4(<2 x i8> %A) {
+; CHECK-LABEL: @i8_vec_sitofp_test4(
+; CHECK-NEXT:    [[C:%.*]] = icmp ne <2 x i8> [[A:%.*]], <i8 127, i8 127>
+; CHECK-NEXT:    ret <2 x i1> [[C]]
+;
+  %B = sitofp <2 x i8> %A to <2 x double>
+  %C = fcmp ult <2 x double> %B, <double 127.0, double 127.0>
+  ret <2 x i1> %C
+}



More information about the llvm-commits mailing list