[llvm] 87c97d0 - [InstSimplify] Extend simplifications for `(icmp ({z|s}ext X), C)` where `C` is vector

Noah Goldstein via llvm-commits llvm-commits at lists.llvm.org
Mon Apr 3 09:05:44 PDT 2023


Author: Noah Goldstein
Date: 2023-04-03T11:04:57-05:00
New Revision: 87c97d052cfd6dc0c03e5e36be1315f659f9f0ac

URL: https://github.com/llvm/llvm-project/commit/87c97d052cfd6dc0c03e5e36be1315f659f9f0ac
DIFF: https://github.com/llvm/llvm-project/commit/87c97d052cfd6dc0c03e5e36be1315f659f9f0ac.diff

LOG: [InstSimplify] Extend simplifications for `(icmp ({z|s}ext X), C)` where `C` is vector

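Previously this logic applied only when `C` was a `ConstantInt`, so it missed
all vector cases. The new code also handles splat and non-splat vector
constants. A fold fires only when every element of `C` agrees: either all
elements survive the trunc/re-extend round trip, or none do. Vectors with
mixed elements are left unchanged, and the signed-predicate results are
computed per element. The underlying simplifications are unchanged.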

Reviewed By: nikic

Differential Revision: https://reviews.llvm.org/D147275

Added: 
    

Modified: 
    llvm/lib/Analysis/InstructionSimplify.cpp
    llvm/test/Transforms/InstSimplify/vec-icmp-of-cast.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
index eaf0af92484d7..b82b0e784e425 100644
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -3818,22 +3818,27 @@ static Value *simplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS,
       }
       // Turn icmp (zext X), Cst into a compare of X and Cst if Cst is extended
       // too.  If not, then try to deduce the result of the comparison.
-      else if (ConstantInt *CI = dyn_cast<ConstantInt>(RHS)) {
+      else if (match(RHS, m_ImmConstant())) {
+        Constant *C = dyn_cast<Constant>(RHS);
+        assert(C != nullptr);
+
         // Compute the constant that would happen if we truncated to SrcTy then
         // reextended to DstTy.
-        Constant *Trunc = ConstantExpr::getTrunc(CI, SrcTy);
+        Constant *Trunc = ConstantExpr::getTrunc(C, SrcTy);
         Constant *RExt = ConstantExpr::getCast(CastInst::ZExt, Trunc, DstTy);
+        Constant *AnyEq = ConstantExpr::getICmp(ICmpInst::ICMP_EQ, RExt, C);
 
-        // If the re-extended constant didn't change then this is effectively
-        // also a case of comparing two zero-extended values.
-        if (RExt == CI && MaxRecurse)
+        // If the re-extended constant didn't change any of the elements then
+        // this is effectively also a case of comparing two zero-extended
+        // values.
+        if (AnyEq->isAllOnesValue() && MaxRecurse)
           if (Value *V = simplifyICmpInst(ICmpInst::getUnsignedPredicate(Pred),
                                           SrcOp, Trunc, Q, MaxRecurse - 1))
             return V;
 
         // Otherwise the upper bits of LHS are zero while RHS has a non-zero bit
         // there.  Use this to work out the result of the comparison.
-        if (RExt != CI) {
+        if (AnyEq->isNullValue()) {
           switch (Pred) {
           default:
             llvm_unreachable("Unknown ICmp predicate!");
@@ -3841,26 +3846,23 @@ static Value *simplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS,
           case ICmpInst::ICMP_EQ:
           case ICmpInst::ICMP_UGT:
           case ICmpInst::ICMP_UGE:
-            return ConstantInt::getFalse(CI->getContext());
+            return Constant::getNullValue(ITy);
 
           case ICmpInst::ICMP_NE:
           case ICmpInst::ICMP_ULT:
           case ICmpInst::ICMP_ULE:
-            return ConstantInt::getTrue(CI->getContext());
+            return Constant::getAllOnesValue(ITy);
 
           // LHS is non-negative.  If RHS is negative then LHS >s LHS.  If RHS
           // is non-negative then LHS <s RHS.
           case ICmpInst::ICMP_SGT:
           case ICmpInst::ICMP_SGE:
-            return CI->getValue().isNegative()
-                       ? ConstantInt::getTrue(CI->getContext())
-                       : ConstantInt::getFalse(CI->getContext());
-
+            return ConstantExpr::getICmp(ICmpInst::ICMP_SLT, C,
+                                         Constant::getNullValue(C->getType()));
           case ICmpInst::ICMP_SLT:
           case ICmpInst::ICMP_SLE:
-            return CI->getValue().isNegative()
-                       ? ConstantInt::getFalse(CI->getContext())
-                       : ConstantInt::getTrue(CI->getContext());
+            return ConstantExpr::getICmp(ICmpInst::ICMP_SGE, C,
+                                         Constant::getNullValue(C->getType()));
           }
         }
       }
@@ -3887,42 +3889,44 @@ static Value *simplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS,
       }
       // Turn icmp (sext X), Cst into a compare of X and Cst if Cst is extended
       // too.  If not, then try to deduce the result of the comparison.
-      else if (ConstantInt *CI = dyn_cast<ConstantInt>(RHS)) {
+      else if (match(RHS, m_ImmConstant())) {
+        Constant *C = dyn_cast<Constant>(RHS);
+        assert(C != nullptr);
+
         // Compute the constant that would happen if we truncated to SrcTy then
         // reextended to DstTy.
-        Constant *Trunc = ConstantExpr::getTrunc(CI, SrcTy);
+        Constant *Trunc = ConstantExpr::getTrunc(C, SrcTy);
         Constant *RExt = ConstantExpr::getCast(CastInst::SExt, Trunc, DstTy);
+        Constant *AnyEq = ConstantExpr::getICmp(ICmpInst::ICMP_EQ, RExt, C);
 
         // If the re-extended constant didn't change then this is effectively
         // also a case of comparing two sign-extended values.
-        if (RExt == CI && MaxRecurse)
+        if (AnyEq->isAllOnesValue() && MaxRecurse)
           if (Value *V =
                   simplifyICmpInst(Pred, SrcOp, Trunc, Q, MaxRecurse - 1))
             return V;
 
         // Otherwise the upper bits of LHS are all equal, while RHS has varying
         // bits there.  Use this to work out the result of the comparison.
-        if (RExt != CI) {
+        if (AnyEq->isNullValue()) {
           switch (Pred) {
           default:
             llvm_unreachable("Unknown ICmp predicate!");
           case ICmpInst::ICMP_EQ:
-            return ConstantInt::getFalse(CI->getContext());
+            return Constant::getNullValue(ITy);
           case ICmpInst::ICMP_NE:
-            return ConstantInt::getTrue(CI->getContext());
+            return Constant::getAllOnesValue(ITy);
 
           // If RHS is non-negative then LHS <s RHS.  If RHS is negative then
           // LHS >s RHS.
           case ICmpInst::ICMP_SGT:
           case ICmpInst::ICMP_SGE:
-            return CI->getValue().isNegative()
-                       ? ConstantInt::getTrue(CI->getContext())
-                       : ConstantInt::getFalse(CI->getContext());
+            return ConstantExpr::getICmp(ICmpInst::ICMP_SLT, C,
+                                         Constant::getNullValue(C->getType()));
           case ICmpInst::ICMP_SLT:
           case ICmpInst::ICMP_SLE:
-            return CI->getValue().isNegative()
-                       ? ConstantInt::getFalse(CI->getContext())
-                       : ConstantInt::getTrue(CI->getContext());
+            return ConstantExpr::getICmp(ICmpInst::ICMP_SGE, C,
+                                         Constant::getNullValue(C->getType()));
 
           // If LHS is non-negative then LHS <u RHS.  If LHS is negative then
           // LHS >u RHS.

diff  --git a/llvm/test/Transforms/InstSimplify/vec-icmp-of-cast.ll b/llvm/test/Transforms/InstSimplify/vec-icmp-of-cast.ll
index d3240d6f98f6d..4acf2fba1934f 100644
--- a/llvm/test/Transforms/InstSimplify/vec-icmp-of-cast.ll
+++ b/llvm/test/Transforms/InstSimplify/vec-icmp-of-cast.ll
@@ -3,9 +3,7 @@
 
 define <2 x i1> @icmp_eq_zext_is_false(<2 x i8> %x) {
 ; CHECK-LABEL: @icmp_eq_zext_is_false(
-; CHECK-NEXT:    [[XEXT:%.*]] = zext <2 x i8> [[X:%.*]] to <2 x i32>
-; CHECK-NEXT:    [[CMP:%.*]] = icmp eq <2 x i32> [[XEXT]], <i32 511, i32 1234>
-; CHECK-NEXT:    ret <2 x i1> [[CMP]]
+; CHECK-NEXT:    ret <2 x i1> zeroinitializer
 ;
   %xext = zext <2 x i8> %x to <2 x i32>
   %cmp = icmp eq <2 x i32> %xext, <i32 511, i32 1234>
@@ -14,9 +12,7 @@ define <2 x i1> @icmp_eq_zext_is_false(<2 x i8> %x) {
 
 define <2 x i1> @icmp_ugt_zext_is_false(<2 x i8> %x) {
 ; CHECK-LABEL: @icmp_ugt_zext_is_false(
-; CHECK-NEXT:    [[XEXT:%.*]] = zext <2 x i8> [[X:%.*]] to <2 x i32>
-; CHECK-NEXT:    [[CMP:%.*]] = icmp ugt <2 x i32> [[XEXT]], <i32 256, i32 1234>
-; CHECK-NEXT:    ret <2 x i1> [[CMP]]
+; CHECK-NEXT:    ret <2 x i1> zeroinitializer
 ;
   %xext = zext <2 x i8> %x to <2 x i32>
   %cmp = icmp ugt <2 x i32> %xext, <i32 256, i32 1234>
@@ -36,9 +32,7 @@ define <2 x i1> @icmp_ugt_zext_todo_off_by1(<2 x i8> %x) {
 
 define <2 x i1> @icmp_uge_zext_is_false(<2 x i8> %x) {
 ; CHECK-LABEL: @icmp_uge_zext_is_false(
-; CHECK-NEXT:    [[XEXT:%.*]] = zext <2 x i8> [[X:%.*]] to <2 x i32>
-; CHECK-NEXT:    [[CMP:%.*]] = icmp uge <2 x i32> [[XEXT]], <i32 256, i32 1234>
-; CHECK-NEXT:    ret <2 x i1> [[CMP]]
+; CHECK-NEXT:    ret <2 x i1> zeroinitializer
 ;
   %xext = zext <2 x i8> %x to <2 x i32>
   %cmp = icmp uge <2 x i32> %xext, <i32 256, i32 1234>
@@ -69,9 +63,7 @@ define <2 x i1> @icmp_eq_zext_unused(<2 x i8> %x) {
 
 define <2 x i1> @icmp_ne_zext_is_true(<2 x i8> %x) {
 ; CHECK-LABEL: @icmp_ne_zext_is_true(
-; CHECK-NEXT:    [[XEXT:%.*]] = zext <2 x i8> [[X:%.*]] to <2 x i32>
-; CHECK-NEXT:    [[CMP:%.*]] = icmp ne <2 x i32> [[XEXT]], <i32 256, i32 1234>
-; CHECK-NEXT:    ret <2 x i1> [[CMP]]
+; CHECK-NEXT:    ret <2 x i1> <i1 true, i1 true>
 ;
   %xext = zext <2 x i8> %x to <2 x i32>
   %cmp = icmp ne <2 x i32> %xext, <i32 256, i32 1234>
@@ -80,9 +72,7 @@ define <2 x i1> @icmp_ne_zext_is_true(<2 x i8> %x) {
 
 define <2 x i1> @icmp_ult_zext_is_true(<2 x i8> %x) {
 ; CHECK-LABEL: @icmp_ult_zext_is_true(
-; CHECK-NEXT:    [[XEXT:%.*]] = zext <2 x i8> [[X:%.*]] to <2 x i32>
-; CHECK-NEXT:    [[CMP:%.*]] = icmp ult <2 x i32> [[XEXT]], <i32 256, i32 1234>
-; CHECK-NEXT:    ret <2 x i1> [[CMP]]
+; CHECK-NEXT:    ret <2 x i1> <i1 true, i1 true>
 ;
   %xext = zext <2 x i8> %x to <2 x i32>
   %cmp = icmp ult <2 x i32> %xext, <i32 256, i32 1234>
@@ -91,9 +81,7 @@ define <2 x i1> @icmp_ult_zext_is_true(<2 x i8> %x) {
 
 define <2 x i1> @icmp_ule_zext_is_true(<2 x i8> %x) {
 ; CHECK-LABEL: @icmp_ule_zext_is_true(
-; CHECK-NEXT:    [[XEXT:%.*]] = zext <2 x i8> [[X:%.*]] to <2 x i32>
-; CHECK-NEXT:    [[CMP:%.*]] = icmp ule <2 x i32> [[XEXT]], <i32 256, i32 -1>
-; CHECK-NEXT:    ret <2 x i1> [[CMP]]
+; CHECK-NEXT:    ret <2 x i1> <i1 true, i1 true>
 ;
   %xext = zext <2 x i8> %x to <2 x i32>
   %cmp = icmp ule <2 x i32> %xext, <i32 256, i32 -1>
@@ -124,9 +112,7 @@ define <2 x i1> @icmp_ne_zext_unused(<2 x i8> %x) {
 
 define <2 x i1> @icmp_sge_zext_is_false_true(<2 x i8> %x) {
 ; CHECK-LABEL: @icmp_sge_zext_is_false_true(
-; CHECK-NEXT:    [[XEXT:%.*]] = zext <2 x i8> [[X:%.*]] to <2 x i32>
-; CHECK-NEXT:    [[CMP:%.*]] = icmp sge <2 x i32> [[XEXT]], <i32 257, i32 -450>
-; CHECK-NEXT:    ret <2 x i1> [[CMP]]
+; CHECK-NEXT:    ret <2 x i1> <i1 false, i1 true>
 ;
   %xext = zext <2 x i8> %x to <2 x i32>
   %cmp = icmp sge <2 x i32> %xext, <i32 257, i32 -450>
@@ -135,9 +121,7 @@ define <2 x i1> @icmp_sge_zext_is_false_true(<2 x i8> %x) {
 
 define <2 x i1> @icmp_sle_zext_is_false_false(<2 x i8> %x) {
 ; CHECK-LABEL: @icmp_sle_zext_is_false_false(
-; CHECK-NEXT:    [[XEXT:%.*]] = zext <2 x i8> [[X:%.*]] to <2 x i32>
-; CHECK-NEXT:    [[CMP:%.*]] = icmp sle <2 x i32> [[XEXT]], <i32 -256, i32 -450>
-; CHECK-NEXT:    ret <2 x i1> [[CMP]]
+; CHECK-NEXT:    ret <2 x i1> zeroinitializer
 ;
   %xext = zext <2 x i8> %x to <2 x i32>
   %cmp = icmp sle <2 x i32> %xext, <i32 -256, i32 -450>
@@ -146,9 +130,7 @@ define <2 x i1> @icmp_sle_zext_is_false_false(<2 x i8> %x) {
 
 define <2 x i1> @icmp_eq_sext_is_false(<2 x i8> %x) {
 ; CHECK-LABEL: @icmp_eq_sext_is_false(
-; CHECK-NEXT:    [[XEXT:%.*]] = sext <2 x i8> [[X:%.*]] to <2 x i32>
-; CHECK-NEXT:    [[CMP:%.*]] = icmp eq <2 x i32> [[XEXT]], <i32 255, i32 129>
-; CHECK-NEXT:    ret <2 x i1> [[CMP]]
+; CHECK-NEXT:    ret <2 x i1> zeroinitializer
 ;
   %xext = sext <2 x i8> %x to <2 x i32>
   %cmp = icmp eq <2 x i32> %xext, <i32 255, i32 129>
@@ -168,9 +150,7 @@ define <2 x i1> @icmp_eq_sext_fail(<2 x i8> %x) {
 
 define <2 x i1> @icmp_ne_sext_is_true(<2 x i8> %x) {
 ; CHECK-LABEL: @icmp_ne_sext_is_true(
-; CHECK-NEXT:    [[XEXT:%.*]] = sext <2 x i8> [[X:%.*]] to <2 x i32>
-; CHECK-NEXT:    [[CMP:%.*]] = icmp ne <2 x i32> [[XEXT]], <i32 199, i32 1234>
-; CHECK-NEXT:    ret <2 x i1> [[CMP]]
+; CHECK-NEXT:    ret <2 x i1> <i1 true, i1 true>
 ;
   %xext = sext <2 x i8> %x to <2 x i32>
   %cmp = icmp ne <2 x i32> %xext, <i32 199, i32 1234>
@@ -179,9 +159,7 @@ define <2 x i1> @icmp_ne_sext_is_true(<2 x i8> %x) {
 
 define <2 x i1> @icmp_sgt_sext_is_true_false(<2 x i8> %x) {
 ; CHECK-LABEL: @icmp_sgt_sext_is_true_false(
-; CHECK-NEXT:    [[XEXT:%.*]] = sext <2 x i8> [[X:%.*]] to <2 x i32>
-; CHECK-NEXT:    [[CMP:%.*]] = icmp sgt <2 x i32> [[XEXT]], <i32 -250, i32 450>
-; CHECK-NEXT:    ret <2 x i1> [[CMP]]
+; CHECK-NEXT:    ret <2 x i1> <i1 true, i1 false>
 ;
   %xext = sext <2 x i8> %x to <2 x i32>
   %cmp = icmp sgt <2 x i32> %xext, <i32 -250, i32 450>
@@ -190,9 +168,7 @@ define <2 x i1> @icmp_sgt_sext_is_true_false(<2 x i8> %x) {
 
 define <2 x i1> @icmp_slt_sext_is_true_false(<2 x i8> %x) {
 ; CHECK-LABEL: @icmp_slt_sext_is_true_false(
-; CHECK-NEXT:    [[XEXT:%.*]] = sext <2 x i8> [[X:%.*]] to <2 x i32>
-; CHECK-NEXT:    [[CMP:%.*]] = icmp slt <2 x i32> [[XEXT]], <i32 257, i32 -450>
-; CHECK-NEXT:    ret <2 x i1> [[CMP]]
+; CHECK-NEXT:    ret <2 x i1> <i1 true, i1 false>
 ;
   %xext = sext <2 x i8> %x to <2 x i32>
   %cmp = icmp slt <2 x i32> %xext, <i32 257, i32 -450>


        


More information about the llvm-commits mailing list