[llvm] 82dcfe0 - [InstCombine] allow matching vector types for icmp-of-mask/cast

Sanjay Patel via llvm-commits llvm-commits at lists.llvm.org
Mon Jan 23 15:25:36 PST 2023


Author: Sanjay Patel
Date: 2023-01-23T18:23:43-05:00
New Revision: 82dcfe0dbbfac68056c6cb9c1ba5222ff836317c

URL: https://github.com/llvm/llvm-project/commit/82dcfe0dbbfac68056c6cb9c1ba5222ff836317c
DIFF: https://github.com/llvm/llvm-project/commit/82dcfe0dbbfac68056c6cb9c1ba5222ff836317c.diff

LOG: [InstCombine] allow matching vector types for icmp-of-mask/cast

Also use a more specific matcher (m_LowBitMask) to simplify
comparing the mask width to the type size.

Added: 
    

Modified: 
    llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
    llvm/test/Transforms/InstCombine/icmp.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index fe04b5949770..c3a53df5fb4d 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -4655,16 +4655,12 @@ Instruction *InstCombinerImpl::foldICmpEquality(ICmpInst &I) {
 
   // (B & (Pow2C-1)) == zext A --> A == trunc B
   // (B & (Pow2C-1)) != zext A --> A != trunc B
-  // TODO: This can be generalized for vector types.
-  ConstantInt *Cst1;
-  if (match(Op0, m_And(m_Value(B), m_ConstantInt(Cst1))) &&
+  const APInt *MaskC;
+  if (match(Op0, m_And(m_Value(B), m_LowBitMask(MaskC))) &&
       match(Op1, m_ZExt(m_Value(A))) &&
-      (Op0->hasOneUse() || Op1->hasOneUse())) {
-    APInt Pow2 = Cst1->getValue() + 1;
-    if (Pow2.isPowerOf2() && isa<IntegerType>(A->getType()) &&
-        Pow2.logBase2() == cast<IntegerType>(A->getType())->getBitWidth())
-      return new ICmpInst(Pred, A, Builder.CreateTrunc(B, A->getType()));
-  }
+      MaskC->countTrailingOnes() == A->getType()->getScalarSizeInBits() &&
+      (Op0->hasOneUse() || Op1->hasOneUse()))
+    return new ICmpInst(Pred, A, Builder.CreateTrunc(B, A->getType()));
 
   // (A >> C) == (B >> C) --> (A^B) u< (1 << C)
   // For lshr and ashr pairs.
@@ -4687,6 +4683,7 @@ Instruction *InstCombinerImpl::foldICmpEquality(ICmpInst &I) {
   }
 
   // (A << C) == (B << C) --> ((A^B) & (~0U >> C)) == 0
+  ConstantInt *Cst1;
   if (match(Op0, m_OneUse(m_Shl(m_Value(A), m_ConstantInt(Cst1)))) &&
       match(Op1, m_OneUse(m_Shl(m_Value(B), m_Specific(Cst1))))) {
     unsigned TypeBits = Cst1->getBitWidth();

diff  --git a/llvm/test/Transforms/InstCombine/icmp.ll b/llvm/test/Transforms/InstCombine/icmp.ll
index 89a4c049d4c7..206b763d193b 100644
--- a/llvm/test/Transforms/InstCombine/icmp.ll
+++ b/llvm/test/Transforms/InstCombine/icmp.ll
@@ -1140,6 +1140,36 @@ define i1 @low_mask_eq_zext_commute(i8 %a, i32 %b) {
   ret i1 %c
 }
 
+; negative test - mask 127 has only 7 low bits set, fewer than the 8 bits of the zext source (i8)
+
+define i1 @wrong_low_mask_eq_zext(i8 %a, i32 %b) {
+; CHECK-LABEL: @wrong_low_mask_eq_zext(
+; CHECK-NEXT:    [[T:%.*]] = and i32 [[B:%.*]], 127
+; CHECK-NEXT:    [[Z:%.*]] = zext i8 [[A:%.*]] to i32
+; CHECK-NEXT:    [[C:%.*]] = icmp eq i32 [[T]], [[Z]]
+; CHECK-NEXT:    ret i1 [[C]]
+;
+  %t = and i32 %b, 127
+  %z = zext i8 %a to i32
+  %c = icmp eq i32 %t, %z
+  ret i1 %c
+}
+
+; negative test - mask 254 (0b11111110) is not a low-bit mask because bit 0 is clear
+
+define i1 @wrong_low_mask_eq_zext2(i8 %a, i32 %b) {
+; CHECK-LABEL: @wrong_low_mask_eq_zext2(
+; CHECK-NEXT:    [[T:%.*]] = and i32 [[B:%.*]], 254
+; CHECK-NEXT:    [[Z:%.*]] = zext i8 [[A:%.*]] to i32
+; CHECK-NEXT:    [[C:%.*]] = icmp eq i32 [[T]], [[Z]]
+; CHECK-NEXT:    ret i1 [[C]]
+;
+  %t = and i32 %b, 254
+  %z = zext i8 %a to i32
+  %c = icmp eq i32 %t, %z
+  ret i1 %c
+}
+
 define i1 @low_mask_eq_zext_use1(i8 %a, i32 %b) {
 ; CHECK-LABEL: @low_mask_eq_zext_use1(
 ; CHECK-NEXT:    [[T:%.*]] = and i32 [[B:%.*]], 255
@@ -1189,9 +1219,8 @@ define i1 @low_mask_eq_zext_use3(i8 %a, i32 %b) {
 
 define <2 x i1> @low_mask_eq_zext_vec_splat(<2 x i8> %a, <2 x i32> %b) {
 ; CHECK-LABEL: @low_mask_eq_zext_vec_splat(
-; CHECK-NEXT:    [[T:%.*]] = and <2 x i32> [[B:%.*]], <i32 255, i32 255>
-; CHECK-NEXT:    [[Z:%.*]] = zext <2 x i8> [[A:%.*]] to <2 x i32>
-; CHECK-NEXT:    [[C:%.*]] = icmp eq <2 x i32> [[T]], [[Z]]
+; CHECK-NEXT:    [[TMP1:%.*]] = trunc <2 x i32> [[B:%.*]] to <2 x i8>
+; CHECK-NEXT:    [[C:%.*]] = icmp eq <2 x i8> [[TMP1]], [[A:%.*]]
 ; CHECK-NEXT:    ret <2 x i1> [[C]]
 ;
   %t = and <2 x i32> %b, <i32 255, i32 255>


        


More information about the llvm-commits mailing list