[llvm] 905d170 - [InstCombine] allow matching vector splat constants in foldLogOpOfMaskedICmps()

Sanjay Patel via llvm-commits llvm-commits at lists.llvm.org
Wed Oct 13 07:18:04 PDT 2021


Author: Sanjay Patel
Date: 2021-10-13T10:15:26-04:00
New Revision: 905d170803b0aafd5b4872ce084ed3f84746b6c9

URL: https://github.com/llvm/llvm-project/commit/905d170803b0aafd5b4872ce084ed3f84746b6c9
DIFF: https://github.com/llvm/llvm-project/commit/905d170803b0aafd5b4872ce084ed3f84746b6c9.diff

LOG: [InstCombine] allow matching vector splat constants in foldLogOpOfMaskedICmps()

This is intended to be NFC (no functional change) for scalar code. There are
still unnecessary m_ConstantInt restrictions in the surrounding code, so this
is not a complete fix.

This prevents regressions seen with a planned follow-on to D111410.

Added: 
    

Modified: 
    llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
    llvm/test/Transforms/InstCombine/bit-checks.ll
    llvm/test/Transforms/InstCombine/icmp-logical.ll
    llvm/test/Transforms/InstCombine/onehot_merge.ll
    llvm/test/Transforms/InstCombine/or.ll
    llvm/test/Transforms/InstCombine/sign-test-and-or.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index 88ad6907b98be..3d35399b8fc68 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -185,14 +185,15 @@ enum MaskedICmpType {
 /// satisfies.
 static unsigned getMaskedICmpType(Value *A, Value *B, Value *C,
                                   ICmpInst::Predicate Pred) {
-  ConstantInt *ACst = dyn_cast<ConstantInt>(A);
-  ConstantInt *BCst = dyn_cast<ConstantInt>(B);
-  ConstantInt *CCst = dyn_cast<ConstantInt>(C);
+  const APInt *ConstA = nullptr, *ConstB = nullptr, *ConstC = nullptr;
+  match(A, m_APInt(ConstA));
+  match(B, m_APInt(ConstB));
+  match(C, m_APInt(ConstC));
   bool IsEq = (Pred == ICmpInst::ICMP_EQ);
-  bool IsAPow2 = (ACst && !ACst->isZero() && ACst->getValue().isPowerOf2());
-  bool IsBPow2 = (BCst && !BCst->isZero() && BCst->getValue().isPowerOf2());
+  bool IsAPow2 = ConstA && ConstA->isPowerOf2();
+  bool IsBPow2 = ConstB && ConstB->isPowerOf2();
   unsigned MaskVal = 0;
-  if (CCst && CCst->isZero()) {
+  if (ConstC && ConstC->isZero()) {
     // if C is zero, then both A and B qualify as mask
     MaskVal |= (IsEq ? (Mask_AllZeros | AMask_Mixed | BMask_Mixed)
                      : (Mask_NotAllZeros | AMask_NotMixed | BMask_NotMixed));
@@ -211,7 +212,7 @@ static unsigned getMaskedICmpType(Value *A, Value *B, Value *C,
     if (IsAPow2)
       MaskVal |= (IsEq ? (Mask_NotAllZeros | AMask_NotMixed)
                        : (Mask_AllZeros | AMask_Mixed));
-  } else if (ACst && CCst && ConstantExpr::getAnd(ACst, CCst) == CCst) {
+  } else if (ConstA && ConstC && ConstC->isSubsetOf(*ConstA)) {
     MaskVal |= (IsEq ? AMask_Mixed : AMask_NotMixed);
   }
 
@@ -221,7 +222,7 @@ static unsigned getMaskedICmpType(Value *A, Value *B, Value *C,
     if (IsBPow2)
       MaskVal |= (IsEq ? (Mask_NotAllZeros | BMask_NotMixed)
                        : (Mask_AllZeros | BMask_Mixed));
-  } else if (BCst && CCst && ConstantExpr::getAnd(BCst, CCst) == CCst) {
+  } else if (ConstB && ConstC && ConstC->isSubsetOf(*ConstB)) {
     MaskVal |= (IsEq ? BMask_Mixed : BMask_NotMixed);
   }
 
@@ -269,9 +270,9 @@ getMaskedTypeForICmpPair(Value *&A, Value *&B, Value *&C,
                          ICmpInst *RHS,
                          ICmpInst::Predicate &PredL,
                          ICmpInst::Predicate &PredR) {
-  // vectors are not (yet?) supported. Don't support pointers either.
-  if (!LHS->getOperand(0)->getType()->isIntegerTy() ||
-      !RHS->getOperand(0)->getType()->isIntegerTy())
+  // Don't allow pointers. Splat vectors are fine.
+  if (!LHS->getOperand(0)->getType()->isIntOrIntVectorTy() ||
+      !RHS->getOperand(0)->getType()->isIntOrIntVectorTy())
     return None;
 
   // Here comes the tricky part:
@@ -619,8 +620,8 @@ static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd,
   // Remaining cases assume at least that B and D are constant, and depend on
   // their actual values. This isn't strictly necessary, just a "handle the
   // easy cases for now" decision.
-  ConstantInt *BCst, *DCst;
-  if (!match(B, m_ConstantInt(BCst)) || !match(D, m_ConstantInt(DCst)))
+  const APInt *ConstB, *ConstD;
+  if (!match(B, m_APInt(ConstB)) || !match(D, m_APInt(ConstD)))
     return nullptr;
 
   if (Mask & (Mask_NotAllZeros | BMask_NotAllOnes)) {
@@ -629,11 +630,10 @@ static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd,
     //     -> (icmp ne (A & B), 0) or (icmp ne (A & D), 0)
     // Only valid if one of the masks is a superset of the other (check "B&D" is
     // the same as either B or D).
-    APInt NewMask = BCst->getValue() & DCst->getValue();
-
-    if (NewMask == BCst->getValue())
+    APInt NewMask = *ConstB & *ConstD;
+    if (NewMask == *ConstB)
       return LHS;
-    else if (NewMask == DCst->getValue())
+    else if (NewMask == *ConstD)
       return RHS;
   }
 
@@ -642,11 +642,10 @@ static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd,
     //     -> (icmp ne (A & B), A) or (icmp ne (A & D), A)
     // Only valid if one of the masks is a superset of the other (check "B|D" is
     // the same as either B or D).
-    APInt NewMask = BCst->getValue() | DCst->getValue();
-
-    if (NewMask == BCst->getValue())
+    APInt NewMask = *ConstB | *ConstD;
+    if (NewMask == *ConstB)
       return LHS;
-    else if (NewMask == DCst->getValue())
+    else if (NewMask == *ConstD)
       return RHS;
   }
 
@@ -661,23 +660,21 @@ static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd,
     // We can't simply use C and E because we might actually handle
     //   (icmp ne (A & B), B) & (icmp eq (A & D), D)
     // with B and D, having a single bit set.
-    ConstantInt *CCst, *ECst;
-    if (!match(C, m_ConstantInt(CCst)) || !match(E, m_ConstantInt(ECst)))
+    const APInt *OldConstC, *OldConstE;
+    if (!match(C, m_APInt(OldConstC)) || !match(E, m_APInt(OldConstE)))
       return nullptr;
-    if (PredL != NewCC)
-      CCst = cast<ConstantInt>(ConstantExpr::getXor(BCst, CCst));
-    if (PredR != NewCC)
-      ECst = cast<ConstantInt>(ConstantExpr::getXor(DCst, ECst));
+
+    const APInt ConstC = PredL != NewCC ? *ConstB ^ *OldConstC : *OldConstC;
+    const APInt ConstE = PredR != NewCC ? *ConstD ^ *OldConstE : *OldConstE;
 
     // If there is a conflict, we should actually return a false for the
     // whole construct.
-    if (((BCst->getValue() & DCst->getValue()) &
-         (CCst->getValue() ^ ECst->getValue())).getBoolValue())
+    if (((*ConstB & *ConstD) & (ConstC ^ ConstE)).getBoolValue())
       return ConstantInt::get(LHS->getType(), !IsAnd);
 
     Value *NewOr1 = Builder.CreateOr(B, D);
-    Value *NewOr2 = ConstantExpr::getOr(CCst, ECst);
     Value *NewAnd = Builder.CreateAnd(A, NewOr1);
+    Constant *NewOr2 = ConstantInt::get(A->getType(), ConstC | ConstE);
     return Builder.CreateICmp(NewCC, NewAnd, NewOr2);
   }
 

diff  --git a/llvm/test/Transforms/InstCombine/bit-checks.ll b/llvm/test/Transforms/InstCombine/bit-checks.ll
index 9937a71747f79..72b82db17a2b2 100644
--- a/llvm/test/Transforms/InstCombine/bit-checks.ll
+++ b/llvm/test/Transforms/InstCombine/bit-checks.ll
@@ -290,12 +290,9 @@ define i32 @main4(i32 %argc) {
 
 define <2 x i32> @main4_splat(<2 x i32> %argc) {
 ; CHECK-LABEL: @main4_splat(
-; CHECK-NEXT:    [[AND:%.*]] = and <2 x i32> [[ARGC:%.*]], <i32 7, i32 7>
-; CHECK-NEXT:    [[TOBOOL:%.*]] = icmp ne <2 x i32> [[AND]], <i32 7, i32 7>
-; CHECK-NEXT:    [[AND2:%.*]] = and <2 x i32> [[ARGC]], <i32 48, i32 48>
-; CHECK-NEXT:    [[TOBOOL3:%.*]] = icmp ne <2 x i32> [[AND2]], <i32 48, i32 48>
-; CHECK-NEXT:    [[NOT_AND_COND:%.*]] = or <2 x i1> [[TOBOOL]], [[TOBOOL3]]
-; CHECK-NEXT:    [[STOREMERGE:%.*]] = zext <2 x i1> [[NOT_AND_COND]] to <2 x i32>
+; CHECK-NEXT:    [[TMP1:%.*]] = and <2 x i32> [[ARGC:%.*]], <i32 55, i32 55>
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp ne <2 x i32> [[TMP1]], <i32 55, i32 55>
+; CHECK-NEXT:    [[STOREMERGE:%.*]] = zext <2 x i1> [[TMP2]] to <2 x i32>
 ; CHECK-NEXT:    ret <2 x i32> [[STOREMERGE]]
 ;
   %and = and <2 x i32> %argc, <i32 7, i32 7>

diff  --git a/llvm/test/Transforms/InstCombine/icmp-logical.ll b/llvm/test/Transforms/InstCombine/icmp-logical.ll
index a9edfeb47a024..15893f5677d34 100644
--- a/llvm/test/Transforms/InstCombine/icmp-logical.ll
+++ b/llvm/test/Transforms/InstCombine/icmp-logical.ll
@@ -19,10 +19,7 @@ define <2 x i1> @masked_and_notallzeroes_splat(<2 x i32> %A) {
 ; CHECK-LABEL: @masked_and_notallzeroes_splat(
 ; CHECK-NEXT:    [[MASK1:%.*]] = and <2 x i32> [[A:%.*]], <i32 7, i32 7>
 ; CHECK-NEXT:    [[TST1:%.*]] = icmp ne <2 x i32> [[MASK1]], zeroinitializer
-; CHECK-NEXT:    [[MASK2:%.*]] = and <2 x i32> [[A]], <i32 39, i32 39>
-; CHECK-NEXT:    [[TST2:%.*]] = icmp ne <2 x i32> [[MASK2]], zeroinitializer
-; CHECK-NEXT:    [[RES:%.*]] = and <2 x i1> [[TST1]], [[TST2]]
-; CHECK-NEXT:    ret <2 x i1> [[RES]]
+; CHECK-NEXT:    ret <2 x i1> [[TST1]]
 ;
   %mask1 = and <2 x i32> %A, <i32 7, i32 7>
   %tst1 = icmp ne <2 x i32> %mask1, <i32 0, i32 0>

diff  --git a/llvm/test/Transforms/InstCombine/onehot_merge.ll b/llvm/test/Transforms/InstCombine/onehot_merge.ll
index 31183fe8a74a2..fde4dd54fa517 100644
--- a/llvm/test/Transforms/InstCombine/onehot_merge.ll
+++ b/llvm/test/Transforms/InstCombine/onehot_merge.ll
@@ -578,11 +578,10 @@ define i1 @foo1_or_signbit_lshr_without_shifting_signbit_both_sides(i32 %k, i32
 define <2 x i1> @foo1_or_signbit_lshr_without_shifting_signbit_both_sides_splat(<2 x i32> %k, <2 x i32> %c1, <2 x i32> %c2) {
 ; CHECK-LABEL: @foo1_or_signbit_lshr_without_shifting_signbit_both_sides_splat(
 ; CHECK-NEXT:    [[T0:%.*]] = shl <2 x i32> [[K:%.*]], [[C1:%.*]]
-; CHECK-NEXT:    [[T1:%.*]] = icmp slt <2 x i32> [[T0]], zeroinitializer
 ; CHECK-NEXT:    [[T2:%.*]] = shl <2 x i32> [[K]], [[C2:%.*]]
-; CHECK-NEXT:    [[T3:%.*]] = icmp slt <2 x i32> [[T2]], zeroinitializer
-; CHECK-NEXT:    [[OR:%.*]] = and <2 x i1> [[T1]], [[T3]]
-; CHECK-NEXT:    ret <2 x i1> [[OR]]
+; CHECK-NEXT:    [[TMP1:%.*]] = and <2 x i32> [[T0]], [[T2]]
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp slt <2 x i32> [[TMP1]], zeroinitializer
+; CHECK-NEXT:    ret <2 x i1> [[TMP2]]
 ;
   %t0 = shl <2 x i32> %k, %c1
   %t1 = icmp slt <2 x i32> %t0, zeroinitializer

diff  --git a/llvm/test/Transforms/InstCombine/or.ll b/llvm/test/Transforms/InstCombine/or.ll
index 00af70be740a6..fcc6a0fbac6fd 100644
--- a/llvm/test/Transforms/InstCombine/or.ll
+++ b/llvm/test/Transforms/InstCombine/or.ll
@@ -1459,11 +1459,8 @@ define i1 @cmp_overlap(i32 %x) {
 
 define <2 x i1> @cmp_overlap_splat(<2 x i5> %x) {
 ; CHECK-LABEL: @cmp_overlap_splat(
-; CHECK-NEXT:    [[ISNEG:%.*]] = icmp slt <2 x i5> [[X:%.*]], zeroinitializer
-; CHECK-NEXT:    [[NOTSUB:%.*]] = add <2 x i5> [[X]], <i5 -1, i5 -1>
-; CHECK-NEXT:    [[ISNOTNEG:%.*]] = icmp slt <2 x i5> [[NOTSUB]], zeroinitializer
-; CHECK-NEXT:    [[R:%.*]] = or <2 x i1> [[ISNEG]], [[ISNOTNEG]]
-; CHECK-NEXT:    ret <2 x i1> [[R]]
+; CHECK-NEXT:    [[TMP1:%.*]] = icmp slt <2 x i5> [[X:%.*]], <i5 1, i5 1>
+; CHECK-NEXT:    ret <2 x i1> [[TMP1]]
 ;
   %isneg = icmp slt <2 x i5> %x, zeroinitializer
   %negx = sub <2 x i5> zeroinitializer, %x

diff  --git a/llvm/test/Transforms/InstCombine/sign-test-and-or.ll b/llvm/test/Transforms/InstCombine/sign-test-and-or.ll
index c968e76cb7ac9..7a88fc74500ae 100644
--- a/llvm/test/Transforms/InstCombine/sign-test-and-or.ll
+++ b/llvm/test/Transforms/InstCombine/sign-test-and-or.ll
@@ -17,10 +17,9 @@ define i1 @test1(i32 %a, i32 %b) {
 
 define <2 x i1> @test1_splat(<2 x i32> %a, <2 x i32> %b) {
 ; CHECK-LABEL: @test1_splat(
-; CHECK-NEXT:    [[TMP1:%.*]] = icmp slt <2 x i32> [[A:%.*]], zeroinitializer
-; CHECK-NEXT:    [[TMP2:%.*]] = icmp slt <2 x i32> [[B:%.*]], zeroinitializer
-; CHECK-NEXT:    [[OR_COND:%.*]] = or <2 x i1> [[TMP1]], [[TMP2]]
-; CHECK-NEXT:    ret <2 x i1> [[OR_COND]]
+; CHECK-NEXT:    [[TMP1:%.*]] = or <2 x i32> [[A:%.*]], [[B:%.*]]
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp slt <2 x i32> [[TMP1]], zeroinitializer
+; CHECK-NEXT:    ret <2 x i1> [[TMP2]]
 ;
   %1 = icmp slt <2 x i32> %a, zeroinitializer
   %2 = icmp slt <2 x i32> %b, zeroinitializer


        


More information about the llvm-commits mailing list