[llvm] 0242a6a - [InstCombine] Support splat vectors in some or of icmp folds

Nikita Popov via llvm-commits llvm-commits at lists.llvm.org
Wed Nov 10 13:59:17 PST 2021


Author: Nikita Popov
Date: 2021-11-10T22:59:09+01:00
New Revision: 0242a6adf73a7cdee55e90989d1d075b607320e1

URL: https://github.com/llvm/llvm-project/commit/0242a6adf73a7cdee55e90989d1d075b607320e1
DIFF: https://github.com/llvm/llvm-project/commit/0242a6adf73a7cdee55e90989d1d075b607320e1.diff

LOG: [InstCombine] Support splat vectors in some or of icmp folds

Replace m_ConstantInt() with m_APInt() in order to support splat
constants in addition to scalar integers.

Added: 
    

Modified: 
    llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
    llvm/test/Transforms/InstCombine/or.ll

Removed: 
    


################################################################################
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index fae19549f081..230046a20a22 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -2340,8 +2340,9 @@ Value *InstCombinerImpl::foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS,
   ICmpInst::Predicate PredL = LHS->getPredicate(), PredR = RHS->getPredicate();
   Value *LHS0 = LHS->getOperand(0), *RHS0 = RHS->getOperand(0);
   Value *LHS1 = LHS->getOperand(1), *RHS1 = RHS->getOperand(1);
-  auto *LHSC = dyn_cast<ConstantInt>(LHS1);
-  auto *RHSC = dyn_cast<ConstantInt>(RHS1);
+  const APInt *LHSC = nullptr, *RHSC = nullptr;
+  match(LHS1, m_APInt(LHSC));
+  match(RHS1, m_APInt(RHSC));
 
   // Fold (icmp ult/ule (A + C1), C3) | (icmp ult/ule (A + C2), C3)
   //                   -->  (icmp ult/ule ((A & ~(C1 ^ C2)) + max(C1, C2)), C3)
@@ -2355,40 +2356,41 @@ Value *InstCombinerImpl::foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS,
   // This implies all values in the two ranges differ by exactly one bit.
   if ((PredL == ICmpInst::ICMP_ULT || PredL == ICmpInst::ICMP_ULE) &&
       PredL == PredR && LHSC && RHSC && LHS->hasOneUse() && RHS->hasOneUse() &&
-      LHSC->getType() == RHSC->getType() &&
-      LHSC->getValue() == (RHSC->getValue())) {
+      LHSC->getBitWidth() == RHSC->getBitWidth() && *LHSC == *RHSC) {
 
     Value *AddOpnd;
-    ConstantInt *LAddC, *RAddC;
-    if (match(LHS0, m_Add(m_Value(AddOpnd), m_ConstantInt(LAddC))) &&
-        match(RHS0, m_Add(m_Specific(AddOpnd), m_ConstantInt(RAddC))) &&
-        LAddC->getValue().ugt(LHSC->getValue()) &&
-        RAddC->getValue().ugt(LHSC->getValue())) {
+    const APInt *LAddC, *RAddC;
+    if (match(LHS0, m_Add(m_Value(AddOpnd), m_APInt(LAddC))) &&
+        match(RHS0, m_Add(m_Specific(AddOpnd), m_APInt(RAddC))) &&
+        LAddC->ugt(*LHSC) && RAddC->ugt(*LHSC)) {
 
-      APInt DiffC = LAddC->getValue() ^ RAddC->getValue();
+      APInt DiffC = *LAddC ^ *RAddC;
       if (DiffC.isPowerOf2()) {
-        ConstantInt *MaxAddC = nullptr;
-        if (LAddC->getValue().ult(RAddC->getValue()))
+        const APInt *MaxAddC = nullptr;
+        if (LAddC->ult(*RAddC))
           MaxAddC = RAddC;
         else
           MaxAddC = LAddC;
 
-        APInt RRangeLow = -RAddC->getValue();
-        APInt RRangeHigh = RRangeLow + LHSC->getValue();
-        APInt LRangeLow = -LAddC->getValue();
-        APInt LRangeHigh = LRangeLow + LHSC->getValue();
+        APInt RRangeLow = -*RAddC;
+        APInt RRangeHigh = RRangeLow + *LHSC;
+        APInt LRangeLow = -*LAddC;
+        APInt LRangeHigh = LRangeLow + *LHSC;
         APInt LowRangeDiff = RRangeLow ^ LRangeLow;
         APInt HighRangeDiff = RRangeHigh ^ LRangeHigh;
         APInt RangeDiff = LRangeLow.sgt(RRangeLow) ? LRangeLow - RRangeLow
                                                    : RRangeLow - LRangeLow;
 
         if (LowRangeDiff.isPowerOf2() && LowRangeDiff == HighRangeDiff &&
-            RangeDiff.ugt(LHSC->getValue())) {
-          Value *MaskC = ConstantInt::get(LAddC->getType(), ~DiffC);
+            RangeDiff.ugt(*LHSC)) {
+          Type *Ty = AddOpnd->getType();
+          Value *MaskC = ConstantInt::get(Ty, ~DiffC);
 
           Value *NewAnd = Builder.CreateAnd(AddOpnd, MaskC);
-          Value *NewAdd = Builder.CreateAdd(NewAnd, MaxAddC);
-          return Builder.CreateICmp(LHS->getPredicate(), NewAdd, LHSC);
+          Value *NewAdd = Builder.CreateAdd(NewAnd,
+                                            ConstantInt::get(Ty, *MaxAddC));
+          return Builder.CreateICmp(LHS->getPredicate(), NewAdd,
+                                    ConstantInt::get(Ty, *LHSC));
         }
       }
     }
@@ -2480,8 +2482,7 @@ Value *InstCombinerImpl::foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS,
   if (!LHSC || !RHSC)
     return nullptr;
 
-  return foldAndOrOfICmpsUsingRanges(PredL, LHS0, LHSC->getValue(),
-                                     PredR, RHS0, RHSC->getValue(),
+  return foldAndOrOfICmpsUsingRanges(PredL, LHS0, *LHSC, PredR, RHS0, *RHSC,
                                      Builder, /* IsAnd */ false);
 }
 

diff --git a/llvm/test/Transforms/InstCombine/or.ll b/llvm/test/Transforms/InstCombine/or.ll
index ce64662d191b..013183a390ac 100644
--- a/llvm/test/Transforms/InstCombine/or.ll
+++ b/llvm/test/Transforms/InstCombine/or.ll
@@ -121,13 +121,11 @@ define i1 @test18_logical(i32 %A) {
   ret i1 %D
 }
 
-; FIXME: Vectors should fold too.
 define <2 x i1> @test18vec(<2 x i32> %A) {
 ; CHECK-LABEL: @test18vec(
-; CHECK-NEXT:    [[B:%.*]] = icmp sgt <2 x i32> [[A:%.*]], <i32 99, i32 99>
-; CHECK-NEXT:    [[C:%.*]] = icmp slt <2 x i32> [[A]], <i32 50, i32 50>
-; CHECK-NEXT:    [[D:%.*]] = or <2 x i1> [[B]], [[C]]
-; CHECK-NEXT:    ret <2 x i1> [[D]]
+; CHECK-NEXT:    [[TMP1:%.*]] = add <2 x i32> [[A:%.*]], <i32 -100, i32 -100>
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp ult <2 x i32> [[TMP1]], <i32 -50, i32 -50>
+; CHECK-NEXT:    ret <2 x i1> [[TMP2]]
 ;
   %B = icmp sge <2 x i32> %A, <i32 100, i32 100>
   %C = icmp slt <2 x i32> %A, <i32 50, i32 50>
@@ -481,9 +479,9 @@ define i1 @test36_logical(i32 %x) {
 
 define i1 @test37(i32 %x) {
 ; CHECK-LABEL: @test37(
-; CHECK-NEXT:    [[ADD1:%.*]] = add i32 [[X:%.*]], 7
-; CHECK-NEXT:    [[TMP1:%.*]] = icmp ult i32 [[ADD1]], 31
-; CHECK-NEXT:    ret i1 [[TMP1]]
+; CHECK-NEXT:    [[TMP1:%.*]] = add i32 [[X:%.*]], 7
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp ult i32 [[TMP1]], 31
+; CHECK-NEXT:    ret i1 [[TMP2]]
 ;
   %add1 = add i32 %x, 7
   %cmp1 = icmp ult i32 %add1, 30
@@ -494,9 +492,9 @@ define i1 @test37(i32 %x) {
 
 define i1 @test37_logical(i32 %x) {
 ; CHECK-LABEL: @test37_logical(
-; CHECK-NEXT:    [[ADD1:%.*]] = add i32 [[X:%.*]], 7
-; CHECK-NEXT:    [[TMP1:%.*]] = icmp ult i32 [[ADD1]], 31
-; CHECK-NEXT:    ret i1 [[TMP1]]
+; CHECK-NEXT:    [[TMP1:%.*]] = add i32 [[X:%.*]], 7
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp ult i32 [[TMP1]], 31
+; CHECK-NEXT:    ret i1 [[TMP2]]
 ;
   %add1 = add i32 %x, 7
   %cmp1 = icmp ult i32 %add1, 30
@@ -507,11 +505,9 @@ define i1 @test37_logical(i32 %x) {
 
 define <2 x i1> @test37_uniform(<2 x i32> %x) {
 ; CHECK-LABEL: @test37_uniform(
-; CHECK-NEXT:    [[ADD1:%.*]] = add <2 x i32> [[X:%.*]], <i32 7, i32 7>
-; CHECK-NEXT:    [[CMP1:%.*]] = icmp ult <2 x i32> [[ADD1]], <i32 30, i32 30>
-; CHECK-NEXT:    [[CMP2:%.*]] = icmp eq <2 x i32> [[X]], <i32 23, i32 23>
-; CHECK-NEXT:    [[RET1:%.*]] = or <2 x i1> [[CMP1]], [[CMP2]]
-; CHECK-NEXT:    ret <2 x i1> [[RET1]]
+; CHECK-NEXT:    [[TMP1:%.*]] = add <2 x i32> [[X:%.*]], <i32 7, i32 7>
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp ult <2 x i32> [[TMP1]], <i32 31, i32 31>
+; CHECK-NEXT:    ret <2 x i1> [[TMP2]]
 ;
   %add1 = add <2 x i32> %x, <i32 7, i32 7>
   %cmp1 = icmp ult <2 x i32> %add1, <i32 30, i32 30>
@@ -792,12 +788,10 @@ define i1 @test46_logical(i8 signext %c)  {
 
 define <2 x i1> @test46_uniform(<2 x i8> %c)  {
 ; CHECK-LABEL: @test46_uniform(
-; CHECK-NEXT:    [[C_OFF:%.*]] = add <2 x i8> [[C:%.*]], <i8 -97, i8 -97>
-; CHECK-NEXT:    [[CMP1:%.*]] = icmp ult <2 x i8> [[C_OFF]], <i8 26, i8 26>
-; CHECK-NEXT:    [[C_OFF17:%.*]] = add <2 x i8> [[C]], <i8 -65, i8 -65>
-; CHECK-NEXT:    [[CMP2:%.*]] = icmp ult <2 x i8> [[C_OFF17]], <i8 26, i8 26>
-; CHECK-NEXT:    [[OR:%.*]] = or <2 x i1> [[CMP1]], [[CMP2]]
-; CHECK-NEXT:    ret <2 x i1> [[OR]]
+; CHECK-NEXT:    [[TMP1:%.*]] = and <2 x i8> [[C:%.*]], <i8 -33, i8 -33>
+; CHECK-NEXT:    [[TMP2:%.*]] = add <2 x i8> [[TMP1]], <i8 -65, i8 -65>
+; CHECK-NEXT:    [[TMP3:%.*]] = icmp ult <2 x i8> [[TMP2]], <i8 26, i8 26>
+; CHECK-NEXT:    ret <2 x i1> [[TMP3]]
 ;
   %c.off = add <2 x i8> %c, <i8 -97, i8 -97>
   %cmp1 = icmp ult <2 x i8> %c.off, <i8 26, i8 26>


        


More information about the llvm-commits mailing list